Skip to content

Commit bbec998

Browse files
authored
feat: support model monitoring for batch prediction in Vertex SDK (#1570)
* feat: support model monitoring for batch prediction in Vertex SDK * fixed broken tests * fixing syntax error * addressed comments * updated test variable name
1 parent 3d3e0aa commit bbec998

File tree

6 files changed

+224
-58
lines changed

6 files changed

+224
-58
lines changed

google/cloud/aiplatform/jobs.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,13 @@ def create(
385385
sync: bool = True,
386386
create_request_timeout: Optional[float] = None,
387387
batch_size: Optional[int] = None,
388+
model_monitoring_objective_config: Optional[
389+
"aiplatform.model_monitoring.ObjectiveConfig"
390+
] = None,
391+
model_monitoring_alert_config: Optional[
392+
"aiplatform.model_monitoring.AlertConfig"
393+
] = None,
394+
analysis_instance_schema_uri: Optional[str] = None,
388395
) -> "BatchPredictionJob":
389396
"""Create a batch prediction job.
390397
@@ -551,6 +558,23 @@ def create(
551558
but too high value will result in a whole batch not fitting in a machine's memory,
552559
and the whole operation will fail.
553560
The default value is 64.
561+
model_monitoring_objective_config (aiplatform.model_monitoring.ObjectiveConfig):
562+
Optional. The objective config for model monitoring. Passing this parameter enables
563+
monitoring on the model associated with this batch prediction job.
564+
model_monitoring_alert_config (aiplatform.model_monitoring.EmailAlertConfig):
565+
Optional. Configures how model monitoring alerts are sent to the user. Right now
566+
only email alert is supported.
567+
analysis_instance_schema_uri (str):
568+
Optional. Only applicable if model_monitoring_objective_config is also passed.
569+
This parameter specifies the YAML schema file uri describing the format of a single
570+
instance that you want Tensorflow Data Validation (TFDV) to
571+
analyze. If this field is empty, all the feature data types are
572+
inferred from predict_instance_schema_uri, meaning that TFDV
573+
will use the data in the exact format as prediction request/response.
574+
If there are any data type differences between predict instance
575+
and TFDV instance, this field can be used to override the schema.
576+
For models trained with Vertex AI, this field must be set as all the
577+
fields in predict instance formatted as string.
554578
Returns:
555579
(jobs.BatchPredictionJob):
556580
Instantiated representation of the created batch prediction job.
@@ -601,7 +625,18 @@ def create(
601625
f"{predictions_format} is not an accepted prediction format "
602626
f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
603627
)
604-
628+
# TODO: remove temporary import statements once model monitoring for batch prediction is GA
629+
if model_monitoring_objective_config:
630+
from google.cloud.aiplatform.compat.types import (
631+
io_v1beta1 as gca_io_compat,
632+
batch_prediction_job_v1beta1 as gca_bp_job_compat,
633+
model_monitoring_v1beta1 as gca_model_monitoring_compat,
634+
)
635+
else:
636+
from google.cloud.aiplatform.compat.types import (
637+
io as gca_io_compat,
638+
batch_prediction_job as gca_bp_job_compat,
639+
)
605640
gapic_batch_prediction_job = gca_bp_job_compat.BatchPredictionJob()
606641

607642
# Required Fields
@@ -688,6 +723,28 @@ def create(
688723
)
689724
)
690725

726+
# Model Monitoring
727+
if model_monitoring_objective_config:
728+
if model_monitoring_objective_config.drift_detection_config:
729+
_LOGGER.info(
730+
"Drift detection config is currently not supported for monitoring models associated with batch prediction jobs."
731+
)
732+
if model_monitoring_objective_config.explanation_config:
733+
_LOGGER.info(
734+
"XAI config is currently not supported for monitoring models associated with batch prediction jobs."
735+
)
736+
gapic_batch_prediction_job.model_monitoring_config = (
737+
gca_model_monitoring_compat.ModelMonitoringConfig(
738+
objective_configs=[
739+
model_monitoring_objective_config.as_proto(config_for_bp=True)
740+
],
741+
alert_config=model_monitoring_alert_config.as_proto(
742+
config_for_bp=True
743+
),
744+
analysis_instance_schema_uri=analysis_instance_schema_uri,
745+
)
746+
)
747+
691748
empty_batch_prediction_job = cls._empty_constructor(
692749
project=project,
693750
location=location,
@@ -702,6 +759,11 @@ def create(
702759
sync=sync,
703760
create_request_timeout=create_request_timeout,
704761
)
762+
# TODO: b/242108750
763+
from google.cloud.aiplatform.compat.types import (
764+
io as gca_io_compat,
765+
batch_prediction_job as gca_bp_job_compat,
766+
)
705767

706768
@classmethod
707769
@base.optional_sync(return_input_arg="empty_batch_prediction_job")

google/cloud/aiplatform/model_monitoring/alert.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,16 @@
1717

1818
from typing import Optional, List
1919
from google.cloud.aiplatform_v1.types import (
20-
model_monitoring as gca_model_monitoring,
20+
model_monitoring as gca_model_monitoring_v1,
2121
)
2222

23+
# TODO: remove imports from v1beta1 once model monitoring for batch prediction is GA
24+
from google.cloud.aiplatform_v1beta1.types import (
25+
model_monitoring as gca_model_monitoring_v1beta1,
26+
)
27+
28+
gca_model_monitoring = gca_model_monitoring_v1
29+
2330

2431
class EmailAlertConfig:
2532
def __init__(
@@ -40,8 +47,19 @@ def __init__(
4047
self.enable_logging = enable_logging
4148
self.user_emails = user_emails
4249

43-
def as_proto(self):
44-
"""Returns EmailAlertConfig as a proto message."""
50+
# TODO: remove config_for_bp parameter when model monitoring for batch prediction is GA
51+
def as_proto(self, config_for_bp: bool = False):
52+
"""Returns EmailAlertConfig as a proto message.
53+
54+
Args:
55+
config_for_bp (bool):
56+
Optional. Set this parameter to True if the config object
57+
is used for model monitoring on a batch prediction job.
58+
"""
59+
if config_for_bp:
60+
gca_model_monitoring = gca_model_monitoring_v1beta1
61+
else:
62+
gca_model_monitoring = gca_model_monitoring_v1
4563
user_email_alert_config = (
4664
gca_model_monitoring.ModelMonitoringAlertConfig.EmailAlertConfig(
4765
user_emails=self.user_emails

google/cloud/aiplatform/model_monitoring/objective.py

Lines changed: 63 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,19 @@
1818
from typing import Optional, Dict
1919

2020
from google.cloud.aiplatform_v1.types import (
21-
io as gca_io,
22-
ThresholdConfig as gca_threshold_config,
23-
model_monitoring as gca_model_monitoring,
21+
io as gca_io_v1,
22+
model_monitoring as gca_model_monitoring_v1,
2423
)
2524

25+
# TODO: b/242108750
26+
from google.cloud.aiplatform_v1beta1.types import (
27+
io as gca_io_v1beta1,
28+
model_monitoring as gca_model_monitoring_v1beta1,
29+
)
30+
31+
gca_model_monitoring = gca_model_monitoring_v1
32+
gca_io = gca_io_v1
33+
2634
TF_RECORD = "tf-record"
2735
CSV = "csv"
2836
JSONL = "jsonl"
@@ -80,19 +88,20 @@ def __init__(
8088
self.attribute_skew_thresholds = attribute_skew_thresholds
8189
self.data_format = data_format
8290
self.target_field = target_field
83-
self.training_dataset = None
8491

8592
def as_proto(self):
8693
"""Returns _SkewDetectionConfig as a proto message."""
8794
skew_thresholds_mapping = {}
8895
attribution_score_skew_thresholds_mapping = {}
8996
if self.skew_thresholds is not None:
9097
for key in self.skew_thresholds.keys():
91-
skew_threshold = gca_threshold_config(value=self.skew_thresholds[key])
98+
skew_threshold = gca_model_monitoring.ThresholdConfig(
99+
value=self.skew_thresholds[key]
100+
)
92101
skew_thresholds_mapping[key] = skew_threshold
93102
if self.attribute_skew_thresholds is not None:
94103
for key in self.attribute_skew_thresholds.keys():
95-
attribution_score_skew_threshold = gca_threshold_config(
104+
attribution_score_skew_threshold = gca_model_monitoring.ThresholdConfig(
96105
value=self.attribute_skew_thresholds[key]
97106
)
98107
attribution_score_skew_thresholds_mapping[
@@ -134,12 +143,16 @@ def as_proto(self):
134143
attribution_score_drift_thresholds_mapping = {}
135144
if self.drift_thresholds is not None:
136145
for key in self.drift_thresholds.keys():
137-
drift_threshold = gca_threshold_config(value=self.drift_thresholds[key])
146+
drift_threshold = gca_model_monitoring.ThresholdConfig(
147+
value=self.drift_thresholds[key]
148+
)
138149
drift_thresholds_mapping[key] = drift_threshold
139150
if self.attribute_drift_thresholds is not None:
140151
for key in self.attribute_drift_thresholds.keys():
141-
attribution_score_drift_threshold = gca_threshold_config(
142-
value=self.attribute_drift_thresholds[key]
152+
attribution_score_drift_threshold = (
153+
gca_model_monitoring.ThresholdConfig(
154+
value=self.attribute_drift_thresholds[key]
155+
)
143156
)
144157
attribution_score_drift_thresholds_mapping[
145158
key
@@ -186,11 +199,49 @@ def __init__(
186199
self.drift_detection_config = drift_detection_config
187200
self.explanation_config = explanation_config
188201

189-
def as_proto(self):
190-
"""Returns _ObjectiveConfig as a proto message."""
202+
# TODO: b/242108750
203+
def as_proto(self, config_for_bp: bool = False):
204+
"""Returns _SkewDetectionConfig as a proto message.
205+
206+
Args:
207+
config_for_bp (bool):
208+
Optional. Set this parameter to True if the config object
209+
is used for model monitoring on a batch prediction job.
210+
"""
211+
if config_for_bp:
212+
gca_io = gca_io_v1beta1
213+
gca_model_monitoring = gca_model_monitoring_v1beta1
214+
else:
215+
gca_io = gca_io_v1
216+
gca_model_monitoring = gca_model_monitoring_v1
191217
training_dataset = None
192218
if self.skew_detection_config is not None:
193-
training_dataset = self.skew_detection_config.training_dataset
219+
training_dataset = (
220+
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
221+
target_field=self.skew_detection_config.target_field
222+
)
223+
)
224+
if self.skew_detection_config.data_source.startswith("bq:/"):
225+
training_dataset.bigquery_source = gca_io.BigQuerySource(
226+
input_uri=self.skew_detection_config.data_source
227+
)
228+
elif self.skew_detection_config.data_source.startswith("gs:/"):
229+
training_dataset.gcs_source = gca_io.GcsSource(
230+
uris=[self.skew_detection_config.data_source]
231+
)
232+
if (
233+
self.skew_detection_config.data_format is not None
234+
and self.skew_detection_config.data_format
235+
not in [TF_RECORD, CSV, JSONL]
236+
):
237+
raise ValueError(
238+
"Unsupported value in skew detection config. `data_format` must be one of %s, %s, or %s"
239+
% (TF_RECORD, CSV, JSONL)
240+
)
241+
training_dataset.data_format = self.skew_detection_config.data_format
242+
else:
243+
training_dataset.dataset = self.skew_detection_config.data_source
244+
194245
return gca_model_monitoring.ModelMonitoringObjectiveConfig(
195246
training_dataset=training_dataset,
196247
training_prediction_skew_detection_config=self.skew_detection_config.as_proto()
@@ -271,27 +322,6 @@ def __init__(
271322
data_format,
272323
)
273324

274-
training_dataset = (
275-
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
276-
target_field=target_field
277-
)
278-
)
279-
if data_source.startswith("bq:/"):
280-
training_dataset.bigquery_source = gca_io.BigQuerySource(
281-
input_uri=data_source
282-
)
283-
elif data_source.startswith("gs:/"):
284-
training_dataset.gcs_source = gca_io.GcsSource(uris=[data_source])
285-
if data_format is not None and data_format not in [TF_RECORD, CSV, JSONL]:
286-
raise ValueError(
287-
"Unsupported value. `data_format` must be one of %s, %s, or %s"
288-
% (TF_RECORD, CSV, JSONL)
289-
)
290-
training_dataset.data_format = data_format
291-
else:
292-
training_dataset.dataset = data_source
293-
self.training_dataset = training_dataset
294-
295325

296326
class DriftDetectionConfig(_DriftDetectionConfig):
297327
"""A class that configures prediction drift detection for models deployed to an endpoint.

tests/system/aiplatform/test_model_monitoring.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,15 @@
2424
from google.api_core import exceptions as core_exceptions
2525
from tests.system.aiplatform import e2e_base
2626

27+
from google.cloud.aiplatform_v1.types import (
28+
io as gca_io,
29+
model_monitoring as gca_model_monitoring,
30+
)
31+
2732
# constants used for testing
2833
USER_EMAIL = ""
29-
MODEL_NAME = "churn"
30-
MODEL_NAME2 = "churn2"
34+
MODEL_DISPLAYNAME_KEY = "churn"
35+
MODEL_DISPLAYNAME_KEY2 = "churn2"
3136
IMAGE = "us-docker.pkg.dev/cloud-aiplatform/prediction/tf2-cpu.2-5:latest"
3237
ENDPOINT = "us-central1-aiplatform.googleapis.com"
3338
CHURN_MODEL_PATH = "gs://mco-mm/churn"
@@ -139,7 +144,7 @@ def temp_endpoint(self, shared_state):
139144
)
140145

141146
model = aiplatform.Model.upload(
142-
display_name=self._make_display_name(key=MODEL_NAME),
147+
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY),
143148
artifact_uri=CHURN_MODEL_PATH,
144149
serving_container_image_uri=IMAGE,
145150
)
@@ -157,19 +162,19 @@ def temp_endpoint_with_two_models(self, shared_state):
157162
)
158163

159164
model1 = aiplatform.Model.upload(
160-
display_name=self._make_display_name(key=MODEL_NAME),
165+
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY),
161166
artifact_uri=CHURN_MODEL_PATH,
162167
serving_container_image_uri=IMAGE,
163168
)
164169

165170
model2 = aiplatform.Model.upload(
166-
display_name=self._make_display_name(key=MODEL_NAME),
171+
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY2),
167172
artifact_uri=CHURN_MODEL_PATH,
168173
serving_container_image_uri=IMAGE,
169174
)
170175
shared_state["resources"] = [model1, model2]
171176
endpoint = aiplatform.Endpoint.create(
172-
display_name=self._make_display_name(key=MODEL_NAME)
177+
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY)
173178
)
174179
endpoint.deploy(
175180
model=model1, machine_type="n1-standard-2", traffic_percentage=100
@@ -224,7 +229,14 @@ def test_mdm_one_model_one_valid_config(self, shared_state):
224229
gca_obj_config = gapic_job.model_deployment_monitoring_objective_configs[
225230
0
226231
].objective_config
227-
assert gca_obj_config.training_dataset == skew_config.training_dataset
232+
233+
expected_training_dataset = (
234+
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
235+
bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI),
236+
target_field=TARGET,
237+
)
238+
)
239+
assert gca_obj_config.training_dataset == expected_training_dataset
228240
assert (
229241
gca_obj_config.training_prediction_skew_detection_config
230242
== skew_config.as_proto()
@@ -297,12 +309,18 @@ def test_mdm_two_models_two_valid_configs(self, shared_state):
297309
)
298310
assert gapic_job.model_monitoring_alert_config.enable_logging
299311

312+
expected_training_dataset = (
313+
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
314+
bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI),
315+
target_field=TARGET,
316+
)
317+
)
318+
300319
for config in gapic_job.model_deployment_monitoring_objective_configs:
301320
gca_obj_config = config.objective_config
302321
deployed_model_id = config.deployed_model_id
303322
assert (
304-
gca_obj_config.training_dataset
305-
== all_configs[deployed_model_id].skew_detection_config.training_dataset
323+
gca_obj_config.as_proto().training_dataset == expected_training_dataset
306324
)
307325
assert (
308326
gca_obj_config.training_prediction_skew_detection_config

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy