
Commit cc59e60

feat: Add timeout arguments to Endpoint.predict and Endpoint.explain (#1094)
Fixes b/224990641 🦕
1 parent 25b546a commit cc59e60


4 files changed (+56, -4 lines)


google/cloud/aiplatform/models.py

Lines changed: 14 additions & 2 deletions
@@ -1167,7 +1167,12 @@ def _instantiate_prediction_client(
             prediction_client=True,
         )
 
-    def predict(self, instances: List, parameters: Optional[Dict] = None) -> Prediction:
+    def predict(
+        self,
+        instances: List,
+        parameters: Optional[Dict] = None,
+        timeout: Optional[float] = None,
+    ) -> Prediction:
         """Make a prediction against this Endpoint.
 
         Args:
@@ -1190,13 +1195,17 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict
                 ][google.cloud.aiplatform.v1beta1.DeployedModel.model]
                 [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
                 ``parameters_schema_uri``.
+            timeout (float): Optional. The timeout for this request in seconds.
         Returns:
             prediction: Prediction with returned predictions and Model Id.
         """
         self.wait()
 
         prediction_response = self._prediction_client.predict(
-            endpoint=self._gca_resource.name, instances=instances, parameters=parameters
+            endpoint=self._gca_resource.name,
+            instances=instances,
+            parameters=parameters,
+            timeout=timeout,
         )
 
         return Prediction(
@@ -1212,6 +1221,7 @@ def explain(
         instances: List[Dict],
         parameters: Optional[Dict] = None,
         deployed_model_id: Optional[str] = None,
+        timeout: Optional[float] = None,
     ) -> Prediction:
         """Make a prediction with explanations against this Endpoint.
 
@@ -1242,6 +1252,7 @@ def explain(
             deployed_model_id (str):
                 Optional. If specified, this ExplainRequest will be served by the
                 chosen DeployedModel, overriding this Endpoint's traffic split.
+            timeout (float): Optional. The timeout for this request in seconds.
         Returns:
             prediction: Prediction with returned predictions, explanations and Model Id.
         """
@@ -1252,6 +1263,7 @@ def explain(
             instances=instances,
             parameters=parameters,
             deployed_model_id=deployed_model_id,
+            timeout=timeout,
         )
 
         return Prediction(
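With this commit, a caller can bound how long a single online prediction or explanation request is allowed to take. A minimal usage sketch, assuming an endpoint has already been deployed; the project, region, endpoint ID, and instance payload below are placeholders, not values from this change:

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholder project/region

# Placeholder endpoint ID; substitute a real deployed endpoint.
endpoint = aiplatform.Endpoint("1234567890")

# Fail the online prediction RPC if the server does not respond within 30 seconds.
prediction = endpoint.predict(
    instances=[{"feature_a": 1.0, "feature_b": 2.0}],
    timeout=30.0,
)

# explain() takes the same keyword; leaving timeout=None keeps the client's
# default behavior, matching the unchanged call sites in the tests below.
explanation = endpoint.explain(
    instances=[{"feature_a": 1.0, "feature_b": 2.0}],
    timeout=30.0,
)

As the diff shows, the value is interpreted as seconds and is forwarded unchanged to the underlying prediction client call.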

tests/system/aiplatform/test_e2e_tabular.py

Lines changed: 3 additions & 2 deletions
@@ -164,13 +164,14 @@ def test_end_to_end_tabular(self, shared_state):
             is True
         )
 
-        custom_prediction = custom_endpoint.predict([_INSTANCE])
+        custom_prediction = custom_endpoint.predict([_INSTANCE], timeout=180.0)
 
         custom_batch_prediction_job.wait()
 
         automl_endpoint.wait()
         automl_prediction = automl_endpoint.predict(
-            [{k: str(v) for k, v in _INSTANCE.items()}]  # Cast int values to strings
+            [{k: str(v) for k, v in _INSTANCE.items()}],  # Cast int values to strings
+            timeout=180.0,
         )
 
         # Test lazy loading of Endpoint, check getter was never called after predict()

tests/unit/aiplatform/test_end_to_end.py

Lines changed: 1 addition & 0 deletions
@@ -174,6 +174,7 @@ def test_dataset_create_to_model_predict(
             endpoint=test_endpoints._TEST_ENDPOINT_NAME,
             instances=[[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]],
             parameters={"param": 3.0},
+            timeout=None,
         )
 
         expected_dataset = gca_dataset.Dataset(

tests/unit/aiplatform/test_endpoints.py

Lines changed: 38 additions & 0 deletions
@@ -1162,6 +1162,7 @@ def test_predict(self, get_endpoint_mock, predict_client_predict_mock):
             endpoint=_TEST_ENDPOINT_NAME,
             instances=_TEST_INSTANCES,
             parameters={"param": 3.0},
+            timeout=None,
         )
 
     def test_explain(self, get_endpoint_mock, predict_client_explain_mock):
@@ -1187,6 +1188,43 @@ def test_explain(self, get_endpoint_mock, predict_client_explain_mock):
             instances=_TEST_INSTANCES,
             parameters={"param": 3.0},
             deployed_model_id=_TEST_MODEL_ID,
+            timeout=None,
+        )
+
+    @pytest.mark.usefixtures("get_endpoint_mock")
+    def test_predict_with_timeout(self, predict_client_predict_mock):
+
+        test_endpoint = models.Endpoint(_TEST_ID)
+
+        test_endpoint.predict(
+            instances=_TEST_INSTANCES, parameters={"param": 3.0}, timeout=10.0
+        )
+
+        predict_client_predict_mock.assert_called_once_with(
+            endpoint=_TEST_ENDPOINT_NAME,
+            instances=_TEST_INSTANCES,
+            parameters={"param": 3.0},
+            timeout=10.0,
+        )
+
+    @pytest.mark.usefixtures("get_endpoint_mock")
+    def test_explain_with_timeout(self, predict_client_explain_mock):
+
+        test_endpoint = models.Endpoint(_TEST_ID)
+
+        test_endpoint.explain(
+            instances=_TEST_INSTANCES,
+            parameters={"param": 3.0},
+            deployed_model_id=_TEST_MODEL_ID,
+            timeout=10.0,
+        )
+
+        predict_client_explain_mock.assert_called_once_with(
+            endpoint=_TEST_ENDPOINT_NAME,
+            instances=_TEST_INSTANCES,
+            parameters={"param": 3.0},
+            deployed_model_id=_TEST_MODEL_ID,
+            timeout=10.0,
+        )
         )
 
     def test_list_models(self, get_endpoint_with_models_mock):
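A practical consequence of passing a finite timeout is that the call can now fail fast instead of blocking indefinitely. A hedged sketch of how a caller might handle that; it assumes the deadline surfaces as google.api_core.exceptions.DeadlineExceeded, which is how google-api-core reports an expired gRPC deadline, though the exact exception can depend on client configuration:

from google.api_core import exceptions
from google.cloud import aiplatform

endpoint = aiplatform.Endpoint("1234567890")  # placeholder endpoint ID

try:
    # A tight five-second budget for a latency-sensitive caller.
    prediction = endpoint.predict(instances=[[1.0, 2.0, 3.0]], timeout=5.0)
except exceptions.DeadlineExceeded:
    # The server did not answer within the budget; fall back, or retry
    # with a larger timeout, as appropriate for the application.
    prediction = None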
