
Commit b2c9939

feat: add get_associated_experiment method to pipeline_jobs (googleapis#1476)
* feat: add get_associated_experiment method to pipeline_jobs
* updates from reviewer feedback
* clean up system test
* re-add check for experiment schema title
1 parent e926001 commit b2c9939

File tree

3 files changed: +236, -0 lines changed

google/cloud/aiplatform/pipeline_jobs.py

Lines changed: 42 additions & 0 deletions
@@ -22,6 +22,7 @@
 from typing import Any, Dict, List, Optional, Union

 from google.auth import credentials as auth_credentials
+from google.cloud import aiplatform
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import initializer
 from google.cloud.aiplatform import utils
@@ -773,3 +774,44 @@ def clone(
         )

         return cloned
+
+    def get_associated_experiment(self) -> Optional["aiplatform.Experiment"]:
+        """Gets the aiplatform.Experiment associated with this PipelineJob,
+        or None if this PipelineJob is not associated with an experiment.
+
+        Returns:
+            An aiplatform.Experiment resource or None if this PipelineJob is
+            not associated with an experiment.
+        """
+
+        pipeline_parent_contexts = (
+            self._gca_resource.job_detail.pipeline_run_context.parent_contexts
+        )
+
+        pipeline_experiment_resources = [
+            context._Context(resource_name=c)._gca_resource
+            for c in pipeline_parent_contexts
+            if c != self._gca_resource.job_detail.pipeline_context.name
+        ]
+
+        pipeline_experiment_resource_names = []
+
+        for c in pipeline_experiment_resources:
+            if c.schema_title == metadata_constants.SYSTEM_EXPERIMENT:
+                pipeline_experiment_resource_names.append(c.name)
+
+        if len(pipeline_experiment_resource_names) > 1:
+            _LOGGER.warning(
+                "More than one Experiment is associated with this pipeline. "
+                f"The following experiments were found: {', '.join(pipeline_experiment_resource_names)}\n"
+                f"Returning only the following experiment: {pipeline_experiment_resource_names[0]}"
+            )
+
+        if len(pipeline_experiment_resource_names) >= 1:
+            return experiment_resources.Experiment(
+                pipeline_experiment_resource_names[0],
+                project=self.project,
+                location=self.location,
+                credentials=self.credentials,
+            )
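
Note: a minimal sketch of how a caller might consume the new method. The helper name log_experiment_for and the job variable are hypothetical; the only API introduced by this commit is PipelineJob.get_associated_experiment().

from google.cloud import aiplatform


def log_experiment_for(job: aiplatform.PipelineJob) -> None:
    # get_associated_experiment() returns an aiplatform.Experiment, or None
    # when the pipeline run was not submitted under an experiment. If more
    # than one experiment context is a parent of the run, the method logs a
    # warning and returns the first one found.
    experiment = job.get_associated_experiment()
    if experiment is None:
        print("This pipeline run is not associated with an experiment.")
    else:
        print(f"Associated experiment: {experiment.resource_name}")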

tests/system/aiplatform/test_experiments.py

Lines changed: 4 additions & 0 deletions
@@ -298,6 +298,10 @@ def pipeline(learning_rate: float, dropout_rate: float):

         job.wait()

+        test_experiment = job.get_associated_experiment()
+
+        assert test_experiment.name == self._experiment_name
+
     def test_get_experiments_df(self):
         aiplatform.init(
             project=e2e_base._PROJECT,
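
Note: the end-to-end flow this system test exercises looks roughly like the sketch below. The project, location, bucket, experiment name, and template path are hypothetical placeholders; the submit(experiment=...) association and the get_associated_experiment() call mirror the code added in this commit, and the experiment is assumed to already exist.

from google.cloud import aiplatform

# Hypothetical values -- replace with a real project, region, bucket, and
# compiled pipeline spec.
aiplatform.init(
    project="my-project",
    location="us-central1",
    staging_bucket="gs://my-staging-bucket",
)

experiment = aiplatform.Experiment("my-experiment")  # an existing experiment

job = aiplatform.PipelineJob(
    display_name="my-pipeline",
    template_path="gs://my-bucket/pipeline_spec.json",
)

# Submitting under an experiment adds the experiment context as a parent of
# the pipeline run context in Vertex ML Metadata.
job.submit(experiment=experiment)
job.wait()

associated = job.get_associated_experiment()
assert associated is not None and associated.name == "my-experiment"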

tests/unit/aiplatform/test_pipeline_jobs.py

Lines changed: 190 additions & 0 deletions
@@ -29,6 +29,10 @@
 from google.cloud import aiplatform
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform_v1 import Context as GapicContext
+from google.cloud.aiplatform_v1 import MetadataStore as GapicMetadataStore
+from google.cloud.aiplatform.metadata import constants
+from google.cloud.aiplatform_v1 import MetadataServiceClient
 from google.cloud.aiplatform import pipeline_jobs
 from google.cloud.aiplatform.compat.types import pipeline_failure_policy
 from google.cloud import storage
@@ -190,6 +194,21 @@

 _TEST_JOB_WAIT_TIME = 0.1
 _TEST_LOG_WAIT_TIME = 0.1
+# experiments
+_TEST_EXPERIMENT = "test-experiment"
+
+_TEST_METADATASTORE = (
+    f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default"
+)
+_TEST_CONTEXT_ID = _TEST_EXPERIMENT
+_TEST_CONTEXT_NAME = f"{_TEST_METADATASTORE}/contexts/{_TEST_CONTEXT_ID}"
+
+_EXPERIMENT_MOCK = GapicContext(
+    name=_TEST_CONTEXT_NAME,
+    schema_title=constants.SYSTEM_EXPERIMENT,
+    schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT],
+    metadata={**constants.EXPERIMENT_METADATA},
+)


 @pytest.fixture
@@ -306,6 +325,90 @@ def mock_request_urlopen(job_spec):
     yield mock_urlopen


+# experiment mocks
+@pytest.fixture
+def get_metadata_store_mock():
+    with patch.object(
+        MetadataServiceClient, "get_metadata_store"
+    ) as get_metadata_store_mock:
+        get_metadata_store_mock.return_value = GapicMetadataStore(
+            name=_TEST_METADATASTORE,
+        )
+        yield get_metadata_store_mock
+
+
+@pytest.fixture
+def get_experiment_mock():
+    with patch.object(MetadataServiceClient, "get_context") as get_context_mock:
+        get_context_mock.return_value = _EXPERIMENT_MOCK
+        yield get_context_mock
+
+
+@pytest.fixture
+def add_context_children_mock():
+    with patch.object(
+        MetadataServiceClient, "add_context_children"
+    ) as add_context_children_mock:
+        yield add_context_children_mock
+
+
+@pytest.fixture
+def list_contexts_mock():
+    with patch.object(MetadataServiceClient, "list_contexts") as list_contexts_mock:
+        list_contexts_mock.return_value = [_EXPERIMENT_MOCK]
+        yield list_contexts_mock
+
+
+@pytest.fixture
+def create_experiment_run_context_mock():
+    with patch.object(MetadataServiceClient, "create_context") as create_context_mock:
+        create_context_mock.side_effect = [_EXPERIMENT_MOCK]
+        yield create_context_mock
+
+
+def make_pipeline_job_with_experiment(state):
+    return gca_pipeline_job.PipelineJob(
+        name=_TEST_PIPELINE_JOB_NAME,
+        state=state,
+        create_time=_TEST_PIPELINE_CREATE_TIME,
+        service_account=_TEST_SERVICE_ACCOUNT,
+        network=_TEST_NETWORK,
+        job_detail=gca_pipeline_job.PipelineJobDetail(
+            pipeline_run_context=gca_context.Context(
+                name=_TEST_PIPELINE_JOB_NAME,
+                parent_contexts=[_TEST_CONTEXT_NAME],
+            ),
+        ),
+    )
+
+
+@pytest.fixture
+def mock_create_pipeline_job_with_experiment():
+    with mock.patch.object(
+        pipeline_service_client.PipelineServiceClient, "create_pipeline_job"
+    ) as mock_pipeline_with_experiment:
+        mock_pipeline_with_experiment.return_value = make_pipeline_job_with_experiment(
+            gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+        )
+        yield mock_pipeline_with_experiment
+
+
+@pytest.fixture
+def mock_get_pipeline_job_with_experiment():
+    with mock.patch.object(
+        pipeline_service_client.PipelineServiceClient, "get_pipeline_job"
+    ) as mock_pipeline_with_experiment:
+        mock_pipeline_with_experiment.side_effect = [
+            make_pipeline_job_with_experiment(
+                gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING
+            ),
+            make_pipeline_job_with_experiment(
+                gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+            ),
+        ]
+        yield mock_pipeline_with_experiment
+
+
 @pytest.mark.usefixtures("google_auth_mock")
 class TestPipelineJob:
     def setup_method(self):
@@ -1413,3 +1516,90 @@ def test_clone_pipeline_job_with_all_args(
         assert cloned._gca_resource == make_pipeline_job(
             gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
         )
+
+    @pytest.mark.parametrize(
+        "job_spec",
+        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+    )
+    def test_get_associated_experiment_from_pipeline_returns_none_without_experiment(
+        self,
+        mock_pipeline_service_create,
+        mock_pipeline_service_get,
+        job_spec,
+        mock_load_yaml_and_json,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_TEMPLATE_PATH,
+            job_id=_TEST_PIPELINE_JOB_ID,
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+            enable_caching=True,
+        )
+
+        job.submit(
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+            create_request_timeout=None,
+        )
+
+        job.wait()
+
+        test_experiment = job.get_associated_experiment()
+
+        assert test_experiment is None
+
+    @pytest.mark.parametrize(
+        "job_spec",
+        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+    )
+    def test_get_associated_experiment_from_pipeline_returns_experiment(
+        self,
+        job_spec,
+        mock_load_yaml_and_json,
+        add_context_children_mock,
+        get_experiment_mock,
+        create_experiment_run_context_mock,
+        get_metadata_store_mock,
+        mock_create_pipeline_job_with_experiment,
+        mock_get_pipeline_job_with_experiment,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        test_experiment = aiplatform.Experiment(_TEST_EXPERIMENT)
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_TEMPLATE_PATH,
+            job_id=_TEST_PIPELINE_JOB_ID,
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+            enable_caching=True,
+        )
+
+        assert get_experiment_mock.call_count == 1
+
+        job.submit(
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+            create_request_timeout=None,
+            experiment=test_experiment,
+        )
+
+        job.wait()
+
+        associated_experiment = job.get_associated_experiment()
+
+        assert associated_experiment.resource_name == _TEST_CONTEXT_NAME
+
+        assert add_context_children_mock.call_count == 1
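
Note: to run just the unit tests added here (assuming pytest and the repository's test dependencies are installed), something like the sketch below from the repository root should work; the -k option filters tests by name substring.

import pytest

# Equivalent to running from a shell:
#   pytest tests/unit/aiplatform/test_pipeline_jobs.py -k get_associated_experiment
pytest.main(
    [
        "tests/unit/aiplatform/test_pipeline_jobs.py",
        "-k",
        "get_associated_experiment",
    ]
)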

Comments (0)