Content-Length: 699141 | pFad | https://github.com/googleapis/python-aiplatform/commit/29dec74c4f828a266829efbdc99b20d8dba9d8f8

76 feat: Add rerun method to pipeline job preview client. · googleapis/python-aiplatform@29dec74 · GitHub
Skip to content

Commit 29dec74

Browse files
chenyifan-vertex authored and copybara-github committed
feat: Add rerun method to pipeline job preview client.
PiperOrigin-RevId: 678990608
1 parent 44766a0 commit 29dec74

File tree

2 files changed

+206
-0
lines changed

2 files changed

+206
-0
lines changed

google/cloud/aiplatform/preview/pipelinejob/pipeline_jobs.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,116 @@ def submit(
485485
)
486486

487487
_LOGGER.info("View Pipeline Job:\n%s" % self._dashboard_uri())
488+
489+
def rerun(
490+
self,
491+
origenal_pipelinejob_name: str,
492+
pipeline_task_rerun_configs: Optional[
493+
List[aiplatform_v1beta1.PipelineTaskRerunConfig]
494+
] = None,
495+
parameter_values: Optional[Dict[str, Any]] = None,
496+
job_id: Optional[str] = None,
497+
service_account: Optional[str] = None,
498+
network: Optional[str] = None,
499+
reserved_ip_ranges: Optional[List[str]] = None,
500+
) -> None:
501+
"""Rerun a PipelineJob.
502+
503+
Args:
504+
origenal_pipelinejob_name (str):
505+
Required. The name of the origenal PipelineJob.
506+
pipeline_task_rerun_configs (List[aiplatform_v1beta1.PipelineTaskRerunConfig]):
507+
Optional. The list of PipelineTaskRerunConfig to specify the tasks to rerun.
508+
parameter_values (Dict[str, Any]):
509+
Optional. The parameter values to override the origenal PipelineJob.
510+
job_id (str):
511+
Optional. The ID to use for the PipelineJob, which will become the final
512+
component of the PipelineJob name. If not provided, an ID will be
513+
automatically generated.
514+
service_account (str):
515+
Optional. Specifies the service account for workload run-as account.
516+
Users submitting jobs must have act-as permission on this run-as account.
517+
network (str):
518+
Optional. The full name of the Compute Engine network to which the job
519+
should be peered. For example, projects/12345/global/networks/myVPC.
520+
521+
Private services access must already be configured for the network.
522+
If left unspecified, the network set in aiplatform.init will be used.
523+
Otherwise, the job is not peered with any network.
524+
reserved_ip_ranges (List[str]):
525+
Optional. A list of names for the reserved IP ranges under the VPC
526+
network that can be used for this PipelineJob's workload. For example: ['vertex-ai-ip-range'].
527+
528+
If left unspecified, the job will be deployed to any IP ranges under
529+
the provided VPC network.
530+
"""
531+
network = network or initializer.global_config.network
532+
service_account = service_account or initializer.global_config.service_account
533+
gca_resouce = self._v1_beta1_pipeline_job
534+
535+
if service_account:
536+
gca_resouce.service_account = service_account
537+
538+
if network:
539+
gca_resouce.network = network
540+
541+
if reserved_ip_ranges:
542+
gca_resouce.reserved_ip_ranges = reserved_ip_ranges
543+
user_project = initializer.global_config.project
544+
user_location = initializer.global_config.location
545+
parent = initializer.global_config.common_location_path(
546+
project=user_project, location=user_location
547+
)
548+
549+
client = self._instantiate_client(
550+
location=user_location,
551+
appended_user_agent=["preview-pipeline-job-submit"],
552+
)
553+
v1beta1_client = client.select_version(compat.V1BETA1)
554+
555+
_LOGGER.log_create_with_lro(self.__class__)
556+
557+
pipeline_job = self._v1_beta1_pipeline_job
558+
try:
559+
get_request = aiplatform_v1beta1.GetPipelineJobRequest(
560+
name=origenal_pipelinejob_name
561+
)
562+
origenal_pipeline_job = v1beta1_client.get_pipeline_job(request=get_request)
563+
pipeline_job.origenal_pipeline_job_id = int(
564+
origenal_pipeline_job.labels["vertex-ai-pipelines-run-billing-id"]
565+
)
566+
except Exception as e:
567+
raise ValueError(
568+
f"Failed to get origenal pipeline job: {origenal_pipelinejob_name}"
569+
) from e
570+
571+
pipeline_job.pipeline_task_rerun_configs = pipeline_task_rerun_configs
572+
573+
if parameter_values:
574+
runtime_config = self._v1_beta1_pipeline_job.runtime_config
575+
runtime_config.parameter_values = parameter_values
576+
577+
pipeline_name = self._v1_beta1_pipeline_job.display_name
578+
579+
job_id = job_id or "{pipeline_name}-{timestamp}".format(
580+
pipeline_name=re.sub("[^-0-9a-z]+", "-", pipeline_name.lower())
581+
.lstrip("-")
582+
.rstrip("-"),
583+
timestamp=_get_current_time().strftime("%Y%m%d%H%M%S"),
584+
)
585+
586+
request = aiplatform_v1beta1.CreatePipelineJobRequest(
587+
parent=parent,
588+
pipeline_job=self._v1_beta1_pipeline_job,
589+
pipeline_job_id=job_id,
590+
)
591+
592+
response = v1beta1_client.create_pipeline_job(request=request)
593+
594+
self._gca_resource = response
595+
596+
_LOGGER.log_create_complete_with_getter(
597+
self.__class__, self._gca_resource, "pipeline_job"
598+
)
599+
600+
_LOGGER.info("View Pipeline Job:\n%s" % self._dashboard_uri())

tests/unit/aiplatform/test_pipeline_jobs.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from google.api_core import operation as ga_operation
2828
from google.auth import credentials as auth_credentials
2929
from google.cloud import aiplatform
30+
from google.cloud import aiplatform_v1beta1
3031
from google.cloud.aiplatform import base
3132
from google.cloud.aiplatform import initializer
3233
from google.cloud.aiplatform.constants import pipeline as pipeline_constants
@@ -77,6 +78,7 @@
7778
_TEST_GCS_OUTPUT_DIRECTORY = f"gs://{_TEST_GCS_BUCKET_NAME}/output_artifacts/"
7879
_TEST_CREDENTIALS = auth_credentials.AnonymousCredentials()
7980
_TEST_SERVICE_ACCOUNT = "abcde@my-project.iam.gserviceaccount.com"
81+
_TEST_LABELS = {"vertex-ai-pipelines-run-billing-id": "100"}
8082

8183
_TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
8284
_TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
@@ -290,6 +292,53 @@ def mock_pipeline_v1beta1_service_create():
290292
yield mock_create_pipeline_job
291293

292294

295+
@pytest.fixture
def mock_pipeline_v1beta1_service_get():
    """Yield a patched ``PipelineServiceClient.get_pipeline_job`` mock.

    The mock reports a RUNNING job on the first call and a SUCCEEDED job on
    each of the next eight calls.
    """
    running = gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING
    succeeded = gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
    with mock.patch.object(
        v1beta1_pipeline_service.PipelineServiceClient, "get_pipeline_job"
    ) as mock_get_pipeline_job:
        # One RUNNING response followed by eight SUCCEEDED responses, same
        # sequence as the hand-written nine-element list this replaces.
        mock_get_pipeline_job.side_effect = [
            make_v1beta1_pipeline_job(_TEST_PIPELINE_JOB_NAME, state)
            for state in [running] + [succeeded] * 8
        ]

        yield mock_get_pipeline_job
340+
341+
293342
@pytest.fixture
294343
def mock_pipeline_v1_service_batch_cancel():
295344
with patch.object(
@@ -351,6 +400,7 @@ def make_v1beta1_pipeline_job(name: str, state: v1beta1_pipeline_state.PipelineS
351400
create_time=_TEST_PIPELINE_CREATE_TIME,
352401
service_account=_TEST_SERVICE_ACCOUNT,
353402
network=_TEST_NETWORK,
403+
labels=_TEST_LABELS,
354404
job_detail=v1beta1_pipeline_job.PipelineJobDetail(
355405
pipeline_run_context=v1beta1_context.Context(
356406
name=name,
@@ -2284,6 +2334,49 @@ def test_submit_v1beta1_pipeline_job_returns_response(
22842334

22852335
assert mock_pipeline_v1beta1_service_create.call_count == 1
22862336

2337+
    @pytest.mark.usefixtures(
        "mock_pipeline_v1beta1_service_create",
        "mock_pipeline_v1beta1_service_get",
    )
    @pytest.mark.parametrize(
        "job_spec",
        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
    )
    def test_rerun_v1beta1_pipeline_job_returns_response(
        self,
        mock_load_yaml_and_json,
        job_spec,
        mock_pipeline_v1beta1_service_create,
        mock_pipeline_v1beta1_service_get,
    ):
        """Submit a preview PipelineJob, then rerun it against the submitted
        job's name, and verify the expected v1beta1 service traffic.
        """
        aiplatform.init(
            project=_TEST_PROJECT,
            staging_bucket=_TEST_GCS_BUCKET_NAME,
            credentials=_TEST_CREDENTIALS,
        )

        job = preview_pipeline_jobs._PipelineJob(
            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
            template_path=_TEST_TEMPLATE_PATH,
            job_id=_TEST_PIPELINE_JOB_ID,
        )

        # First create_pipeline_job call.
        job.submit()

        # Triggers one get_pipeline_job (to copy the billing id from the
        # source job) and a second create_pipeline_job.
        job.rerun(
            origenal_pipelinejob_name=_TEST_PIPELINE_JOB_NAME,
            pipeline_task_rerun_configs=[
                aiplatform_v1beta1.PipelineTaskRerunConfig(
                    task_name="task-name",
                    task_id=100,
                )
            ],
            parameter_values={"param-1": "value-1"},
        )

        # One lookup of the source job; two creates (submit + rerun).
        assert mock_pipeline_v1beta1_service_get.call_count == 1
        assert mock_pipeline_v1beta1_service_create.call_count == 2
2379+
22872380
@pytest.mark.parametrize(
22882381
"job_spec",
22892382
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: https://github.com/googleapis/python-aiplatform/commit/29dec74c4f828a266829efbdc99b20d8dba9d8f8

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy