Skip to content

Commit e138cfd

Browse files
feat: add support for accepting an Artifact Registry URL in pipeline_job (#1405)
* Add support for Artifact Registry in template_path * fix typo * update tests * fix AR path * remove unused project * add code for refreshing credentials * add import for google.auth.transport * fix AR path * fix AR path * fix runtime_config * test removing v1beta1 * try using v1 directly instead * update to use v1beta1 * use select_version * add back template_uri * try adding back v1beta1 * use select_version * differentiate when to use select_version * test removing v1beta1 for pipeline_complete_states * add tests for creating pipelines using v1beta1 * fix merge * fix typo * fix lint using blacken * fix regex * update to use v1 instead of v1beta1 * add test for invalid url * update error type * implement failure_policy * use urllib.request instead of requests * Revert "implement failure_policy" This reverts commit 72cdd9e. * fix lint Co-authored-by: Anthonios Partheniou <partheniou@google.com>
1 parent 82f678e commit e138cfd

File tree

4 files changed

+177
-11
lines changed

4 files changed

+177
-11
lines changed

google/cloud/aiplatform/pipeline_jobs.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@
5656
# Pattern for a valid name used as a Vertex resource name.
5757
_VALID_NAME_PATTERN = re.compile("^[a-z][-a-z0-9]{0,127}$")
5858

59+
# Pattern for an Artifact Registry URL.
60+
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")
61+
5962

6063
def _get_current_time() -> datetime.datetime:
6164
"""Gets the current timestamp."""
@@ -125,8 +128,9 @@ def __init__(
125128
Required. The user-defined name of this Pipeline.
126129
template_path (str):
127130
Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It
128-
can be a local path or a Google Cloud Storage URI.
129-
Example: "gs://project.name"
131+
can be a local path, a Google Cloud Storage URI (e.g. "gs://project.name"),
132+
or an Artifact Registry URI (e.g.
133+
"https://us-central1-kfp.pkg.dev/proj/repo/pack/latest").
130134
job_id (str):
131135
Optional. The unique ID of the job run.
132136
If not specified, pipeline name + timestamp will be used.
@@ -237,15 +241,20 @@ def __init__(
237241
if enable_caching is not None:
238242
_set_enable_caching_value(pipeline_job["pipelineSpec"], enable_caching)
239243

240-
self._gca_resource = gca_pipeline_job.PipelineJob(
241-
display_name=display_name,
242-
pipeline_spec=pipeline_job["pipelineSpec"],
243-
labels=labels,
244-
runtime_config=runtime_config,
245-
encryption_spec=initializer.global_config.get_encryption_spec(
244+
pipeline_job_args = {
245+
"display_name": display_name,
246+
"pipeline_spec": pipeline_job["pipelineSpec"],
247+
"labels": labels,
248+
"runtime_config": runtime_config,
249+
"encryption_spec": initializer.global_config.get_encryption_spec(
246250
encryption_spec_key_name=encryption_spec_key_name
247251
),
248-
)
252+
}
253+
254+
if _VALID_AR_URL.match(template_path):
255+
pipeline_job_args["template_uri"] = template_path
256+
257+
self._gca_resource = gca_pipeline_job.PipelineJob(**pipeline_job_args)
249258

250259
@base.optional_sync()
251260
def run(

google/cloud/aiplatform/utils/yaml_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,17 @@
1515
# limitations under the License.
1616
#
1717

18+
import re
1819
from typing import Any, Dict, Optional
20+
from urllib import request
1921

2022
from google.auth import credentials as auth_credentials
23+
from google.auth import transport
2124
from google.cloud import storage
2225

26+
# Pattern for an Artifact Registry URL.
27+
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")
28+
2329

2430
def load_yaml(
2531
path: str,
@@ -42,6 +48,8 @@ def load_yaml(
4248
"""
4349
if path.startswith("gs://"):
4450
return _load_yaml_from_gs_uri(path, project, credentials)
51+
elif _VALID_AR_URL.match(path):
52+
return _load_yaml_from_ar_uri(path, credentials)
4553
else:
4654
return _load_yaml_from_local_file(path)
4755

@@ -95,3 +103,37 @@ def _load_yaml_from_local_file(file_path: str) -> Dict[str, Any]:
95103
)
96104
with open(file_path) as f:
97105
return yaml.safe_load(f)
106+
107+
108+
def _load_yaml_from_ar_uri(
109+
uri: str,
110+
credentials: Optional[auth_credentials.Credentials] = None,
111+
) -> Dict[str, Any]:
112+
"""Loads data from a YAML document referenced by an Artifact Registry URI.
113+
114+
Args:
115+
uri (str):
116+
Required. Artifact Registry URI for YAML document.
117+
credentials (auth_credentials.Credentials):
118+
Optional. Credentials to use with Artifact Registry.
119+
120+
Returns:
121+
A Dict object representing the YAML document.
122+
"""
123+
try:
124+
import yaml
125+
except ImportError:
126+
raise ImportError(
127+
"pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
128+
'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
129+
)
130+
req = request.Request(uri)
131+
132+
if credentials:
133+
if not credentials.valid:
134+
credentials.refresh(transport.requests.Request())
135+
if credentials.token:
136+
req.add_header("Authorization", "Bearer " + credentials.token)
137+
response = request.urlopen(req)
138+
139+
return yaml.safe_load(response.read().decode("utf-8"))

tests/unit/aiplatform/test_pipeline_jobs.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from unittest import mock
2323
from importlib import reload
2424
from unittest.mock import patch
25+
from urllib import request
2526
from datetime import datetime
2627

2728
from google.auth import credentials as auth_credentials
@@ -50,6 +51,7 @@
5051
_TEST_SERVICE_ACCOUNT = "abcde@my-project.iam.gserviceaccount.com"
5152

5253
_TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
54+
_TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
5355
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
5456
_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}"
5557

@@ -289,6 +291,17 @@ def mock_load_yaml_and_json(job_spec):
289291
yield mock_load_yaml_and_json
290292

291293

294+
@pytest.fixture
295+
def mock_request_urlopen(job_spec):
296+
with patch.object(request, "urlopen") as mock_urlopen:
297+
mock_read_response = mock.MagicMock()
298+
mock_decode_response = mock.MagicMock()
299+
mock_decode_response.return_value = job_spec.encode()
300+
mock_read_response.return_value.decode = mock_decode_response
301+
mock_urlopen.return_value.read = mock_read_response
302+
yield mock_urlopen
303+
304+
292305
@pytest.mark.usefixtures("google_auth_mock")
293306
class TestPipelineJob:
294307
def setup_method(self):
@@ -376,6 +389,85 @@ def test_run_call_pipeline_service_create(
376389
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
377390
)
378391

392+
@pytest.mark.parametrize(
393+
"job_spec",
394+
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
395+
)
396+
@pytest.mark.parametrize("sync", [True, False])
397+
def test_run_call_pipeline_service_create_artifact_registry(
398+
self,
399+
mock_pipeline_service_create,
400+
mock_pipeline_service_get,
401+
mock_request_urlopen,
402+
job_spec,
403+
mock_load_yaml_and_json,
404+
sync,
405+
):
406+
aiplatform.init(
407+
project=_TEST_PROJECT,
408+
staging_bucket=_TEST_GCS_BUCKET_NAME,
409+
location=_TEST_LOCATION,
410+
credentials=_TEST_CREDENTIALS,
411+
)
412+
413+
job = pipeline_jobs.PipelineJob(
414+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
415+
template_path=_TEST_AR_TEMPLATE_PATH,
416+
job_id=_TEST_PIPELINE_JOB_ID,
417+
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
418+
enable_caching=True,
419+
)
420+
421+
job.run(
422+
service_account=_TEST_SERVICE_ACCOUNT,
423+
network=_TEST_NETWORK,
424+
sync=sync,
425+
create_request_timeout=None,
426+
)
427+
428+
if not sync:
429+
job.wait()
430+
431+
expected_runtime_config_dict = {
432+
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
433+
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
434+
}
435+
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
436+
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
437+
438+
job_spec = yaml.safe_load(job_spec)
439+
pipeline_spec = job_spec.get("pipelineSpec") or job_spec
440+
441+
# Construct expected request
442+
expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob(
443+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
444+
pipeline_spec={
445+
"components": {},
446+
"pipelineInfo": pipeline_spec["pipelineInfo"],
447+
"root": pipeline_spec["root"],
448+
"schemaVersion": "2.1.0",
449+
},
450+
runtime_config=runtime_config,
451+
service_account=_TEST_SERVICE_ACCOUNT,
452+
network=_TEST_NETWORK,
453+
template_uri=_TEST_AR_TEMPLATE_PATH,
454+
)
455+
456+
mock_pipeline_service_create.assert_called_once_with(
457+
parent=_TEST_PARENT,
458+
pipeline_job=expected_gapic_pipeline_job,
459+
pipeline_job_id=_TEST_PIPELINE_JOB_ID,
460+
timeout=None,
461+
)
462+
463+
mock_pipeline_service_get.assert_called_with(
464+
name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
465+
)
466+
467+
assert job._gca_resource == make_pipeline_job(
468+
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
469+
)
470+
379471
@pytest.mark.parametrize(
380472
"job_spec",
381473
[

tests/unit/aiplatform/test_utils.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import json
2121
import os
2222
from typing import Callable, Dict, Optional
23+
from unittest import mock
24+
from urllib import request
2325

2426
import pytest
2527
import yaml
@@ -564,13 +566,34 @@ def json_file(tmp_path):
564566
yield json_file_path
565567

566568

569+
@pytest.fixture(scope="function")
570+
def mock_request_urlopen():
571+
data = {"key": "val", "list": ["1", 2, 3.0]}
572+
with mock.patch.object(request, "urlopen") as mock_urlopen:
573+
mock_read_response = mock.MagicMock()
574+
mock_decode_response = mock.MagicMock()
575+
mock_decode_response.return_value = json.dumps(data)
576+
mock_read_response.return_value.decode = mock_decode_response
577+
mock_urlopen.return_value.read = mock_read_response
578+
yield "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
579+
580+
567581
class TestYamlUtils:
568-
def test_load_yaml_from_local_file__with_json(self, yaml_file):
582+
def test_load_yaml_from_local_file__with_yaml(self, yaml_file):
569583
actual = yaml_utils.load_yaml(yaml_file)
570584
expected = {"key": "val", "list": ["1", 2, 3.0]}
571585
assert actual == expected
572586

573-
def test_load_yaml_from_local_file__with_yaml(self, json_file):
587+
def test_load_yaml_from_local_file__with_json(self, json_file):
574588
actual = yaml_utils.load_yaml(json_file)
575589
expected = {"key": "val", "list": ["1", 2, 3.0]}
576590
assert actual == expected
591+
592+
def test_load_yaml_from_ar_uri(self, mock_request_urlopen):
593+
actual = yaml_utils.load_yaml(mock_request_urlopen)
594+
expected = {"key": "val", "list": ["1", 2, 3.0]}
595+
assert actual == expected
596+
597+
def test_load_yaml_from_invalid_uri(self):
598+
with pytest.raises(FileNotFoundError):
599+
yaml_utils.load_yaml("https://us-docker.pkg.dev/v2/proj/repo/img/tags/list")

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy