Skip to content

Commit 9e77c61

Browse files
fix: added proto message conversion to MDMJob.update fields (#1718)
* fix: added proto message conversion to MDMJob.update fields * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * addressed PR comment * formatting * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * replaced string literal with constant * adding _gca_resource re-assignmnet to mdm job class * Added side effects in get_mdm_job pytest mock * fixing side effects * formatting * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * minor edits to variable names * Addressed PR feedback * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * addressed more PR commentes * addressed PR comments * fix linter errors Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 3747ce3 commit 9e77c61

File tree

2 files changed

+126
-49
lines changed

2 files changed

+126
-49
lines changed

google/cloud/aiplatform/jobs.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2427,7 +2427,8 @@ def update(
24272427
are allowed. See https://goo.gl/xmQnxf for more information
24282428
and examples of labels.
24292429
bigquery_tables_log_ttl (int):
2430-
Optional. The TTL(time to live) of BigQuery tables in user projects
2430+
Optional. The number of days for which the logs are stored.
2431+
The TTL(time to live) of BigQuery tables in user projects
24312432
which stores logs. A day is the basic unit of
24322433
the TTL and we take the ceil of TTL/86400(a
24332434
day). e.g. { second: 3600} indicates ttl = 1
@@ -2453,28 +2454,30 @@ def update(
24532454
will be applied to all deployed models.
24542455
"""
24552456
self._sync_gca_resource()
2456-
current_job = self.api_client.get_model_deployment_monitoring_job(
2457-
name=self._gca_resource.name
2458-
)
2457+
current_job = copy.deepcopy(self._gca_resource)
24592458
update_mask: List[str] = []
24602459
if display_name is not None:
24612460
update_mask.append("display_name")
24622461
current_job.display_name = display_name
24632462
if schedule_config is not None:
24642463
update_mask.append("model_deployment_monitoring_schedule_config")
2465-
current_job.model_deployment_monitoring_schedule_config = schedule_config
2464+
current_job.model_deployment_monitoring_schedule_config = (
2465+
schedule_config.as_proto()
2466+
)
24662467
if alert_config is not None:
24672468
update_mask.append("model_monitoring_alert_config")
2468-
current_job.model_monitoring_alert_config = alert_config
2469+
current_job.model_monitoring_alert_config = alert_config.as_proto()
24692470
if logging_sampling_strategy is not None:
24702471
update_mask.append("logging_sampling_strategy")
2471-
current_job.logging_sampling_strategy = logging_sampling_strategy
2472+
current_job.logging_sampling_strategy = logging_sampling_strategy.as_proto()
24722473
if labels is not None:
24732474
update_mask.append("labels")
2474-
current_job.lables = labels
2475+
current_job.labels = labels
24752476
if bigquery_tables_log_ttl is not None:
24762477
update_mask.append("log_ttl")
2477-
current_job.log_ttl = bigquery_tables_log_ttl
2478+
current_job.log_ttl = duration_pb2.Duration(
2479+
seconds=bigquery_tables_log_ttl * 86400
2480+
)
24782481
if enable_monitoring_pipeline_logs is not None:
24792482
update_mask.append("enable_monitoring_pipeline_logs")
24802483
current_job.enable_monitoring_pipeline_logs = (
@@ -2491,10 +2494,12 @@ def update(
24912494
deployed_model_ids=deployed_model_ids,
24922495
)
24932496
)
2494-
self.api_client.update_model_deployment_monitoring_job(
2497+
# TODO: b/254285776 add optional_sync support to model monitoring job
2498+
lro = self.api_client.update_model_deployment_monitoring_job(
24952499
model_deployment_monitoring_job=current_job,
24962500
update_mask=field_mask_pb2.FieldMask(paths=update_mask),
24972501
)
2502+
self._gca_resource = lro.result()
24982503
return self
24992504

25002505
def pause(self) -> "ModelDeploymentMonitoringJob":

tests/unit/aiplatform/test_jobs.py

Lines changed: 111 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#
1717

1818
import pytest
19+
import copy
1920

2021
from unittest import mock
2122
from importlib import reload
@@ -24,6 +25,7 @@
2425
from google.cloud import storage
2526
from google.cloud import bigquery
2627

28+
from google.api_core import operation
2729
from google.auth import credentials as auth_credentials
2830

2931
from google.cloud import aiplatform
@@ -46,7 +48,9 @@
4648
job_service_client,
4749
)
4850
from google.protobuf import field_mask_pb2 # type: ignore
51+
from google.protobuf import duration_pb2 # type: ignore
4952

53+
import test_endpoints # noqa: F401
5054
from test_endpoints import get_endpoint_with_models_mock # noqa: F401
5155

5256
_TEST_API_CLIENT = job_service_client.JobServiceClient
@@ -175,6 +179,58 @@
175179
_TEST_JOB_RESOURCE_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_ID}"
176180

177181
_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG = {"TEST_KEY": 0.01}
182+
_TEST_MDM_USER_EMAIL = "TEST_EMAIL"
183+
_TEST_MDM_SAMPLE_RATE = 0.5
184+
_TEST_MDM_LABEL = {"TEST KEY": "TEST VAL"}
185+
_TEST_LOG_TTL_IN_DAYS = 1
186+
_TEST_MDM_NEW_NAME = "NEW_NAME"
187+
188+
_TEST_MDM_OLD_JOB = (
189+
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
190+
name=_TEST_MDM_JOB_NAME,
191+
display_name=_TEST_DISPLAY_NAME,
192+
endpoint=_TEST_ENDPOINT,
193+
state=_TEST_JOB_STATE_RUNNING,
194+
)
195+
)
196+
197+
_TEST_MDM_EXPECTED_NEW_JOB = gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
198+
name=_TEST_MDM_JOB_NAME,
199+
display_name=_TEST_MDM_NEW_NAME,
200+
endpoint=_TEST_ENDPOINT,
201+
state=_TEST_JOB_STATE_RUNNING,
202+
model_deployment_monitoring_objective_configs=[
203+
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig(
204+
deployed_model_id=model_id,
205+
objective_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig(
206+
prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig(
207+
drift_thresholds={
208+
"TEST_KEY": gca_model_monitoring_compat.ThresholdConfig(
209+
value=0.01
210+
)
211+
}
212+
)
213+
),
214+
)
215+
for model_id in [model.id for model in test_endpoints._TEST_DEPLOYED_MODELS]
216+
],
217+
logging_sampling_strategy=gca_model_monitoring_compat.SamplingStrategy(
218+
random_sample_config=gca_model_monitoring_compat.SamplingStrategy.RandomSampleConfig(
219+
sample_rate=_TEST_MDM_SAMPLE_RATE
220+
)
221+
),
222+
labels=_TEST_MDM_LABEL,
223+
model_monitoring_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig(
224+
email_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig.EmailAlertConfig(
225+
user_emails=[_TEST_MDM_USER_EMAIL]
226+
)
227+
),
228+
model_deployment_monitoring_schedule_config=gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringScheduleConfig(
229+
monitor_interval=duration_pb2.Duration(seconds=3600)
230+
),
231+
log_ttl=duration_pb2.Duration(seconds=_TEST_LOG_TTL_IN_DAYS * 86400),
232+
enable_monitoring_pipeline_logs=True,
233+
)
178234

179235
# TODO(b/171333554): Move reusable test fixtures to conftest.py file
180236

@@ -988,48 +1044,23 @@ def get_mdm_job_mock():
9881044
with mock.patch.object(
9891045
_TEST_API_CLIENT, "get_model_deployment_monitoring_job"
9901046
) as get_mdm_job_mock:
991-
get_mdm_job_mock.return_value = (
992-
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
993-
name=_TEST_MDM_JOB_NAME,
994-
display_name=_TEST_DISPLAY_NAME,
995-
state=_TEST_JOB_STATE_RUNNING,
996-
endpoint=_TEST_ENDPOINT,
997-
)
998-
)
1047+
get_mdm_job_mock.side_effect = [
1048+
_TEST_MDM_OLD_JOB,
1049+
_TEST_MDM_OLD_JOB,
1050+
_TEST_MDM_OLD_JOB,
1051+
_TEST_MDM_EXPECTED_NEW_JOB,
1052+
]
9991053
yield get_mdm_job_mock
10001054

10011055

10021056
@pytest.fixture
1003-
@pytest.mark.usefixtures("get_mdm_job_mock")
10041057
def update_mdm_job_mock(get_endpoint_with_models_mock): # noqa: F811
10051058
with mock.patch.object(
10061059
_TEST_API_CLIENT, "update_model_deployment_monitoring_job"
10071060
) as update_mdm_job_mock:
1008-
expected_objective_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig(
1009-
prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig(
1010-
drift_thresholds={
1011-
"TEST_KEY": gca_model_monitoring_compat.ThresholdConfig(value=0.01)
1012-
}
1013-
)
1014-
)
1015-
all_configs = []
1016-
for model in get_endpoint_with_models_mock.return_value.deployed_models:
1017-
all_configs.append(
1018-
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig(
1019-
deployed_model_id=model.id,
1020-
objective_config=expected_objective_config,
1021-
)
1022-
)
1023-
1024-
update_mdm_job_mock.return_vaue.result_type = (
1025-
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
1026-
name=_TEST_MDM_JOB_NAME,
1027-
display_name=_TEST_DISPLAY_NAME,
1028-
state=_TEST_JOB_STATE_RUNNING,
1029-
endpoint=_TEST_ENDPOINT,
1030-
model_deployment_monitoring_objective_configs=all_configs,
1031-
)
1032-
)
1061+
update_mdm_job_lro_mock = mock.Mock(operation.Operation)
1062+
update_mdm_job_lro_mock.result.return_value = _TEST_MDM_EXPECTED_NEW_JOB
1063+
update_mdm_job_mock.return_value = update_mdm_job_lro_mock
10331064
yield update_mdm_job_mock
10341065

10351066

@@ -1046,25 +1077,66 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock):
10461077
job = jobs.ModelDeploymentMonitoringJob(
10471078
model_deployment_monitoring_job_name=_TEST_MDM_JOB_NAME
10481079
)
1080+
old_job = copy.deepcopy(job._gca_resource)
10491081
drift_detection_config = aiplatform.model_monitoring.DriftDetectionConfig(
10501082
drift_thresholds=_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG
10511083
)
1084+
schedule_config = aiplatform.model_monitoring.ScheduleConfig(monitor_interval=1)
1085+
alert_config = aiplatform.model_monitoring.EmailAlertConfig(
1086+
user_emails=[_TEST_MDM_USER_EMAIL]
1087+
)
1088+
sampling_strategy = aiplatform.model_monitoring.RandomSampleConfig(
1089+
sample_rate=_TEST_MDM_SAMPLE_RATE
1090+
)
1091+
labels = _TEST_MDM_LABEL
1092+
log_ttl = _TEST_LOG_TTL_IN_DAYS
1093+
display_name = _TEST_MDM_NEW_NAME
10521094
new_config = aiplatform.model_monitoring.ObjectiveConfig(
10531095
drift_detection_config=drift_detection_config
10541096
)
1055-
job.update(objective_configs=new_config)
1097+
job.update(
1098+
display_name=display_name,
1099+
schedule_config=schedule_config,
1100+
alert_config=alert_config,
1101+
logging_sampling_strategy=sampling_strategy,
1102+
labels=labels,
1103+
bigquery_tables_log_ttl=log_ttl,
1104+
enable_monitoring_pipeline_logs=True,
1105+
objective_configs=new_config,
1106+
)
1107+
new_job = job._gca_resource
1108+
assert old_job != new_job
1109+
assert new_job.display_name == display_name
1110+
assert new_job.logging_sampling_strategy == sampling_strategy.as_proto()
1111+
assert (
1112+
new_job.model_deployment_monitoring_schedule_config
1113+
== schedule_config.as_proto()
1114+
)
1115+
assert new_job.labels == labels
1116+
assert new_job.model_monitoring_alert_config == alert_config.as_proto()
1117+
assert new_job.log_ttl.days == _TEST_LOG_TTL_IN_DAYS
1118+
assert new_job.enable_monitoring_pipeline_logs
10561119
assert (
1057-
job._gca_resource.model_deployment_monitoring_objective_configs[
1120+
new_job.model_deployment_monitoring_objective_configs[
10581121
0
10591122
].objective_config.prediction_drift_detection_config
10601123
== drift_detection_config.as_proto()
10611124
)
10621125
get_mdm_job_mock.assert_called_with(
1063-
name=_TEST_MDM_JOB_NAME,
1126+
name=_TEST_MDM_JOB_NAME, retry=base._DEFAULT_RETRY
10641127
)
10651128
update_mdm_job_mock.assert_called_once_with(
1066-
model_deployment_monitoring_job=get_mdm_job_mock.return_value,
1129+
model_deployment_monitoring_job=new_job,
10671130
update_mask=field_mask_pb2.FieldMask(
1068-
paths=["model_deployment_monitoring_objective_configs"]
1131+
paths=[
1132+
"display_name",
1133+
"model_deployment_monitoring_schedule_config",
1134+
"model_monitoring_alert_config",
1135+
"logging_sampling_strategy",
1136+
"labels",
1137+
"log_ttl",
1138+
"enable_monitoring_pipeline_logs",
1139+
"model_deployment_monitoring_objective_configs",
1140+
]
10691141
),
10701142
)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy