|
28 | 28 |
|
29 | 29 | from google.auth import credentials as auth_credentials
|
30 | 30 | from google.protobuf import duration_pb2 # type: ignore
|
| 31 | +from google.protobuf import field_mask_pb2 # type: ignore |
31 | 32 | from google.rpc import status_pb2
|
32 | 33 |
|
33 | 34 | from google.cloud import aiplatform
|
@@ -1952,7 +1953,9 @@ def __init__(
|
1952 | 1953 | location=location,
|
1953 | 1954 | credentials=credentials,
|
1954 | 1955 | )
|
1955 |
| - self._gca_resource = self._get_gca_resource(resource_name=self.job_name) |
| 1956 | + self._gca_resource = self._get_gca_resource( |
| 1957 | + resource_name=model_deployment_monitoring_job_name |
| 1958 | + ) |
1956 | 1959 | self._endpoint_resource_name = ""
|
1957 | 1960 |
|
1958 | 1961 | @classmethod
|
@@ -1985,6 +1988,10 @@ def _parse_configs(
|
1985 | 1988 | all_configs = []
|
1986 | 1989 | all_models = []
|
1987 | 1990 | default_endpoint = "aiplatform.googleapis.com"
|
| 1991 | + if aiplatform.initializer.global_config._location is None: |
| 1992 | + raise ValueError( |
| 1993 | + "Error parsing model monitoring objective configs: project location is not set" |
| 1994 | + ) |
1988 | 1995 | client_options = dict(
|
1989 | 1996 | api_endpoint=f"{aiplatform.initializer.global_config._location}-{default_endpoint}"
|
1990 | 1997 | )
|
@@ -2309,7 +2316,7 @@ def update(
|
2309 | 2316 | ) -> "ModelDeploymentMonitoringJob":
|
2310 | 2317 | """"""
|
2311 | 2318 | current_job = self.api_client.get_model_deployment_monitoring_job(
|
2312 |
| - name=self.model_deployment_monitoring_job_name |
| 2319 | + name=self._gca_resource.name |
2313 | 2320 | )
|
2314 | 2321 | update_mask: List[str] = []
|
2315 | 2322 | if display_name:
|
@@ -2339,27 +2346,40 @@ def update(
|
2339 | 2346 | update_mask.append("model_deployment_monitoring_objective_configs")
|
2340 | 2347 | current_job.model_deployment_monitoring_objective_configs = (
|
2341 | 2348 | ModelDeploymentMonitoringJob._parse_configs(
|
2342 |
| - objective_configs, self._endpoint_resource_name |
| 2349 | + objective_configs, current_job.endpoint, deployed_model_ids |
2343 | 2350 | )
|
2344 | 2351 | )
|
2345 | 2352 | self.api_client.update_model_deployment_monitoring_job(
|
2346 |
| - model_deployment_monitoring_job=current_job, update_mask=update_mask |
| 2353 | + model_deployment_monitoring_job=current_job, |
| 2354 | + update_mask=field_mask_pb2.FieldMask(paths=update_mask), |
2347 | 2355 | )
|
2348 | 2356 |
|
2349 | 2357 | def pause(self) -> "ModelDeploymentMonitoringJob":
|
2350 | 2358 | """"""
|
2351 |
| - self.api_client.pause_model_deployment_monitoring_job( |
2352 |
| - self.model_deployment_monitoring_job_name |
2353 |
| - ) |
| 2359 | + if self.state == gca_job_state.JobState.JOB_STATE_RUNNING: |
| 2360 | + self.api_client.pause_model_deployment_monitoring_job( |
| 2361 | + name=self._gca_resource.name |
| 2362 | + ) |
| 2363 | + else: |
| 2364 | + raise RuntimeError( |
| 2365 | + "The monitoring job can only be paused under running / pending state, the current state is: %s" |
| 2366 | + % self.state |
| 2367 | + ) |
2354 | 2368 |
|
2355 | 2369 | def resume(self) -> "ModelDeploymentMonitoringJob":
|
2356 | 2370 | """"""
|
2357 |
| - self.api_client.resume_model_deployment_monitoring_job( |
2358 |
| - self.model_deployment_monitoring_job_name |
2359 |
| - ) |
| 2371 | + if self.state == gca_job_state.JobState.JOB_STATE_PAUSED: |
| 2372 | + self.api_client.resume_model_deployment_monitoring_job( |
| 2373 | + name=self._gca_resource.name |
| 2374 | + ) |
| 2375 | + else: |
| 2376 | + raise RuntimeError( |
| 2377 | + "The monitoring job can only be resumed under paused state" |
| 2378 | + ) |
2360 | 2379 |
|
2361 | 2380 | def delete(self) -> "ModelDeploymentMonitoringJob":
|
2362 | 2381 | """"""
|
| 2382 | + self.pause() |
2363 | 2383 | self.api_client.delete_model_deployment_monitoring_job(
|
2364 |
| - self.model_deployment_monitoring_job_name |
| 2384 | + name=self._gca_resource.name |
2365 | 2385 | )
|
0 commit comments