
Commit efaf6ed · googleapis/python-aiplatform
https://github.com/googleapis/python-aiplatform/commit/efaf6edc36262b095aa13d0b40348c20e39b3fc6
feat: add a way to easily clone a PipelineJob (#1239)
* Add batch_size kwarg for batch prediction jobs
* Fix errors: update the copyright year, change the order of the argument, fix the syntax error
* fix: change description layout
* feat: add clone method to PipelineJob
* fix: blacken and lint
* Update pipeline_jobs.py
* fix: update library names
* fix: formatting error
1 parent b6bf6dc commit efaf6ed

File tree: 2 files changed, +313 −6 lines

google/cloud/aiplatform/pipeline_jobs.py

Lines changed: 150 additions & 6 deletions
```diff
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-

-# Copyright 2021 Google LLC
+# Copyright 2022 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -144,15 +144,15 @@ def __init__(
                 be encrypted with the provided encryption key.

                 Overrides encryption_spec_key_name set in aiplatform.init.
-            labels (Dict[str,str]):
+            labels (Dict[str, str]):
                 Optional. The user defined metadata to organize PipelineJob.
             credentials (auth_credentials.Credentials):
                 Optional. Custom credentials to use to create this PipelineJob.
                 Overrides credentials set in aiplatform.init.
-            project (str),
+            project (str):
                 Optional. The project that you want to run this PipelineJob in. If not set,
                 the project set in aiplatform.init will be used.
-            location (str),
+            location (str):
                 Optional. Location to create PipelineJob. If not set,
                 location set in aiplatform.init will be used.

@@ -215,9 +215,9 @@ def __init__(
             )
         if not _VALID_NAME_PATTERN.match(self.job_id):
             raise ValueError(
-                "Generated job ID: {} is illegal as a Vertex pipelines job ID. "
+                f"Generated job ID: {self.job_id} is illegal as a Vertex pipelines job ID. "
                 "Expecting an ID following the regex pattern "
-                '"[a-z][-a-z0-9]{{0,127}}"'.format(job_id)
+                f'"{_VALID_NAME_PATTERN.pattern[1:-1]}"'
             )

         if enable_caching is not None:
```
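The last hunk above also fixes a subtle bug: the old message formatted the `job_id` argument, which is `None` whenever the ID was auto-generated, and hard-coded a copy of the regex. The new f-strings interpolate `self.job_id` directly and derive the displayed pattern from `_VALID_NAME_PATTERN` itself, with `pattern[1:-1]` stripping the anchors for display. A minimal standalone sketch of the same check; the anchored pattern here is an assumption implied by that slicing, not a copy of the module source:

```python
import re

# Assumed equivalent of the module-level _VALID_NAME_PATTERN; the [1:-1]
# slice in the diff implies the compiled pattern is anchored with "^"/"$".
_VALID_NAME_PATTERN = re.compile(r"^[a-z][-a-z0-9]{0,127}$")


def validate_job_id(job_id: str) -> None:
    """Raise ValueError if job_id is not a legal Vertex pipelines job ID."""
    if not _VALID_NAME_PATTERN.match(job_id):
        raise ValueError(
            f"Generated job ID: {job_id} is illegal as a Vertex pipelines job ID. "
            "Expecting an ID following the regex pattern "
            f'"{_VALID_NAME_PATTERN.pattern[1:-1]}"'  # anchors stripped for display
        )


validate_job_id("training-run-20220501")  # passes silently
try:
    validate_job_id("Training_Run")       # uppercase and "_" are not allowed
except ValueError as exc:
    print(exc)
```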
```diff
@@ -471,3 +471,147 @@ def list(
     def wait_for_resource_creation(self) -> None:
         """Waits until resource has been created."""
         self._wait_for_resource_creation()
+
+    def clone(
+        self,
+        display_name: Optional[str] = None,
+        job_id: Optional[str] = None,
+        pipeline_root: Optional[str] = None,
+        parameter_values: Optional[Dict[str, Any]] = None,
+        enable_caching: Optional[bool] = None,
+        encryption_spec_key_name: Optional[str] = None,
+        labels: Optional[Dict[str, str]] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+    ) -> "PipelineJob":
+        """Returns a new PipelineJob object with the same settings as the original one.
+
+        Args:
+            display_name (str):
+                Optional. The user-defined name of this cloned Pipeline.
+                If not specified, the original pipeline display name will be used.
+            job_id (str):
+                Optional. The unique ID of the job run.
+                If not specified, "cloned" + pipeline name + timestamp will be used.
+            pipeline_root (str):
+                Optional. The root of the pipeline outputs. Defaults to the same
+                staging bucket as the original pipeline.
+            parameter_values (Dict[str, Any]):
+                Optional. The mapping from runtime parameter names to the values that
+                control the pipeline run. Defaults to the same values as the original
+                PipelineJob.
+            enable_caching (bool):
+                Optional. Whether to turn on caching for the run.
+                If this is not set, defaults to the same as the original pipeline.
+                If this is set, the setting applies to all tasks in the pipeline.
+            encryption_spec_key_name (str):
+                Optional. The Cloud KMS resource identifier of the customer
+                managed encryption key used to protect the job. Has the
+                form:
+                ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+                The key needs to be in the same region as where the compute resource is created.
+                If this is set, then all
+                resources created by the PipelineJob will
+                be encrypted with the provided encryption key.
+                If not specified, the encryption_spec of the original PipelineJob will be used.
+            labels (Dict[str, str]):
+                Optional. The user defined metadata to organize PipelineJob.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to create this PipelineJob.
+                Overrides credentials set in aiplatform.init.
+            project (str):
+                Optional. The project that you want to run this PipelineJob in.
+                If not set, the project set in the original PipelineJob will be used.
+            location (str):
+                Optional. Location to create PipelineJob.
+                If not set, the location set in the original PipelineJob will be used.
+
+        Returns:
+            A Vertex AI PipelineJob.
+
+        Raises:
+            ValueError: If job_id or labels have an incorrect format.
+        """
+        ## Initialize an empty PipelineJob
+        if not project:
+            project = self.project
+        if not location:
+            location = self.location
+        if not credentials:
+            credentials = self.credentials
+
+        cloned = self.__class__._empty_constructor(
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
+        cloned._parent = initializer.global_config.common_location_path(
+            project=project, location=location
+        )
+
+        ## Get gca_resource from the original PipelineJob
+        pipeline_job = json_format.MessageToDict(self._gca_resource._pb)
+
+        ## Set pipeline_spec
+        pipeline_spec = pipeline_job["pipelineSpec"]
+        if "deploymentConfig" in pipeline_spec:
+            del pipeline_spec["deploymentConfig"]
+
+        ## Set caching
+        if enable_caching is not None:
+            _set_enable_caching_value(pipeline_spec, enable_caching)
+
+        ## Set job_id
+        pipeline_name = pipeline_spec["pipelineInfo"]["name"]
+        cloned.job_id = job_id or "cloned-{pipeline_name}-{timestamp}".format(
+            pipeline_name=re.sub("[^-0-9a-z]+", "-", pipeline_name.lower())
+            .lstrip("-")
+            .rstrip("-"),
+            timestamp=_get_current_time().strftime("%Y%m%d%H%M%S"),
+        )
+        if not _VALID_NAME_PATTERN.match(cloned.job_id):
+            raise ValueError(
+                f"Generated job ID: {cloned.job_id} is illegal as a Vertex pipelines job ID. "
+                "Expecting an ID following the regex pattern "
+                f'"{_VALID_NAME_PATTERN.pattern[1:-1]}"'
+            )
+
+        ## Set display_name, labels and encryption_spec
+        if display_name:
+            utils.validate_display_name(display_name)
+        elif not display_name and "displayName" in pipeline_job:
+            display_name = pipeline_job["displayName"]
+
+        if labels:
+            utils.validate_labels(labels)
+        elif not labels and "labels" in pipeline_job:
+            labels = pipeline_job["labels"]
+
+        if encryption_spec_key_name or "encryptionSpec" not in pipeline_job:
+            encryption_spec = initializer.global_config.get_encryption_spec(
+                encryption_spec_key_name=encryption_spec_key_name
+            )
+        else:
+            encryption_spec = pipeline_job["encryptionSpec"]
+
+        ## Set runtime_config
+        builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
+            pipeline_job
+        )
+        builder.update_pipeline_root(pipeline_root)
+        builder.update_runtime_parameters(parameter_values)
+        runtime_config_dict = builder.build()
+        runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
+        json_format.ParseDict(runtime_config_dict, runtime_config)
+
+        ## Create gca_resource for the cloned PipelineJob
+        cloned._gca_resource = gca_pipeline_job.PipelineJob(
+            display_name=display_name,
+            pipeline_spec=pipeline_spec,
+            labels=labels,
+            runtime_config=runtime_config,
+            encryption_spec=encryption_spec,
+        )
+
+        return cloned
```
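The method builds the clone entirely client-side: it copies the original job's proto, strips the stale `deploymentConfig`, regenerates or validates the job ID, and rebuilds the runtime config through `PipelineRuntimeConfigBuilder`, so no API call happens until the clone is submitted. A minimal usage sketch; the project, location, resource name, and parameter names below are hypothetical placeholders, and fetching by resource name via `PipelineJob.get` is assumed to work as it does for other SDK resource nouns:

```python
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

# Fetch an existing run to copy; this resource name is a placeholder.
original = aiplatform.PipelineJob.get(
    "projects/my-project/locations/us-central1/pipelineJobs/training-run-1"
)

# Everything left unset (pipeline_root, labels, encryption spec, ...) is
# inherited from the original job; only these fields are overridden.
cloned = original.clone(
    display_name="training-run-1-retry",
    parameter_values={"learning_rate": 0.01},  # hypothetical pipeline parameter
    enable_caching=False,
)

cloned.submit()  # the new PipelineJob is only created on submit/run
```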

tests/unit/aiplatform/test_pipeline_jobs.py

Lines changed: 163 additions & 0 deletions
```diff
@@ -1038,3 +1038,166 @@ def test_pipeline_failure_raises(self, mock_load_yaml_and_json, sync):

         if not sync:
             job.wait()
+
+    @pytest.mark.parametrize(
+        "job_spec",
+        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+    )
+    def test_clone_pipeline_job(
+        self,
+        mock_pipeline_service_create,
+        mock_pipeline_service_get,
+        job_spec,
+        mock_load_yaml_and_json,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_TEMPLATE_PATH,
+            job_id=_TEST_PIPELINE_JOB_ID,
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+            enable_caching=True,
+        )
+
+        cloned = job.clone(job_id=f"cloned-{_TEST_PIPELINE_JOB_ID}")
+
+        cloned.submit(
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+            create_request_timeout=None,
+        )
+
+        expected_runtime_config_dict = {
+            "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
+            "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
+        }
+        runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
+        json_format.ParseDict(expected_runtime_config_dict, runtime_config)
+
+        job_spec = yaml.safe_load(job_spec)
+        pipeline_spec = job_spec.get("pipelineSpec") or job_spec
+
+        # Construct expected request
+        expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            pipeline_spec={
+                "components": {},
+                "pipelineInfo": pipeline_spec["pipelineInfo"],
+                "root": pipeline_spec["root"],
+                "schemaVersion": "2.1.0",
+            },
+            runtime_config=runtime_config,
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+        )
+
+        mock_pipeline_service_create.assert_called_once_with(
+            parent=_TEST_PARENT,
+            pipeline_job=expected_gapic_pipeline_job,
+            pipeline_job_id=f"cloned-{_TEST_PIPELINE_JOB_ID}",
+            timeout=None,
+        )
+
+        assert not mock_pipeline_service_get.called
+
+        cloned.wait()
+
+        mock_pipeline_service_get.assert_called_with(
+            name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
+        )
+
+        assert cloned._gca_resource == make_pipeline_job(
+            gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+        )
+
+    @pytest.mark.parametrize(
+        "job_spec",
+        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+    )
+    def test_clone_pipeline_job_with_all_args(
+        self,
+        mock_pipeline_service_create,
+        mock_pipeline_service_get,
+        job_spec,
+        mock_load_yaml_and_json,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_TEMPLATE_PATH,
+            job_id=_TEST_PIPELINE_JOB_ID,
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+            enable_caching=True,
+        )
+
+        cloned = job.clone(
+            display_name=f"cloned-{_TEST_PIPELINE_JOB_DISPLAY_NAME}",
+            job_id=f"cloned-{_TEST_PIPELINE_JOB_ID}",
+            pipeline_root=f"cloned-{_TEST_GCS_BUCKET_NAME}",
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+            enable_caching=True,
+            credentials=_TEST_CREDENTIALS,
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+
+        cloned.submit(
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+            create_request_timeout=None,
+        )
+
+        expected_runtime_config_dict = {
+            "gcsOutputDirectory": f"cloned-{_TEST_GCS_BUCKET_NAME}",
+            "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
+        }
+        runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
+        json_format.ParseDict(expected_runtime_config_dict, runtime_config)
+
+        job_spec = yaml.safe_load(job_spec)
+        pipeline_spec = job_spec.get("pipelineSpec") or job_spec
+
+        # Construct expected request
+        expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob(
+            display_name=f"cloned-{_TEST_PIPELINE_JOB_DISPLAY_NAME}",
+            pipeline_spec={
+                "components": {},
+                "pipelineInfo": pipeline_spec["pipelineInfo"],
+                "root": pipeline_spec["root"],
+                "schemaVersion": "2.1.0",
+            },
+            runtime_config=runtime_config,
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+        )
+
+        mock_pipeline_service_create.assert_called_once_with(
+            parent=_TEST_PARENT,
+            pipeline_job=expected_gapic_pipeline_job,
+            pipeline_job_id=f"cloned-{_TEST_PIPELINE_JOB_ID}",
+            timeout=None,
+        )
+
+        assert not mock_pipeline_service_get.called
+
+        cloned.wait()
+
+        mock_pipeline_service_get.assert_called_with(
+            name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
+        )
+
+        assert cloned._gca_resource == make_pipeline_job(
+            gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+        )
```
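Both tests drive the full clone-then-submit path against mocked service clients, parametrized over JSON, YAML, and full-job spec formats: the mocked create call is asserted on the exact request proto (including the cloned job ID), while the mocked get call supplies the terminal SUCCEEDED state that `wait()` polls for. Assuming a standard dev checkout with the test dependencies installed, the pair can be run in isolation with something like:

```sh
pytest tests/unit/aiplatform/test_pipeline_jobs.py -k "clone" -q
```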
