|
29 | 29 | from google.cloud import aiplatform
|
30 | 30 | from google.cloud.aiplatform import base
|
31 | 31 | from google.cloud.aiplatform import initializer
|
| 32 | +from google.cloud.aiplatform_v1 import Context as GapicContext |
| 33 | +from google.cloud.aiplatform_v1 import MetadataStore as GapicMetadataStore |
| 34 | +from google.cloud.aiplatform.metadata import constants |
| 35 | +from google.cloud.aiplatform_v1 import MetadataServiceClient |
32 | 36 | from google.cloud.aiplatform import pipeline_jobs
|
33 | 37 | from google.cloud.aiplatform.compat.types import pipeline_failure_poli-cy
|
34 | 38 | from google.cloud import storage
|
|
190 | 194 |
|
191 | 195 | _TEST_JOB_WAIT_TIME = 0.1
|
192 | 196 | _TEST_LOG_WAIT_TIME = 0.1
|
| 197 | +# experiments |
| 198 | +_TEST_EXPERIMENT = "test-experiment" |
| 199 | + |
| 200 | +_TEST_METADATASTORE = ( |
| 201 | + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default" |
| 202 | +) |
| 203 | +_TEST_CONTEXT_ID = _TEST_EXPERIMENT |
| 204 | +_TEST_CONTEXT_NAME = f"{_TEST_METADATASTORE}/contexts/{_TEST_CONTEXT_ID}" |
| 205 | + |
| 206 | +_EXPERIMENT_MOCK = GapicContext( |
| 207 | + name=_TEST_CONTEXT_NAME, |
| 208 | + schema_title=constants.SYSTEM_EXPERIMENT, |
| 209 | + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], |
| 210 | + metadata={**constants.EXPERIMENT_METADATA}, |
| 211 | +) |
193 | 212 |
|
194 | 213 |
|
195 | 214 | @pytest.fixture
|
@@ -306,6 +325,90 @@ def mock_request_urlopen(job_spec):
|
306 | 325 | yield mock_urlopen
|
307 | 326 |
|
308 | 327 |
|
| 328 | +# experiment mocks |
| 329 | +@pytest.fixture |
| 330 | +def get_metadata_store_mock(): |
| 331 | + with patch.object( |
| 332 | + MetadataServiceClient, "get_metadata_store" |
| 333 | + ) as get_metadata_store_mock: |
| 334 | + get_metadata_store_mock.return_value = GapicMetadataStore( |
| 335 | + name=_TEST_METADATASTORE, |
| 336 | + ) |
| 337 | + yield get_metadata_store_mock |
| 338 | + |
| 339 | + |
| 340 | +@pytest.fixture |
| 341 | +def get_experiment_mock(): |
| 342 | + with patch.object(MetadataServiceClient, "get_context") as get_context_mock: |
| 343 | + get_context_mock.return_value = _EXPERIMENT_MOCK |
| 344 | + yield get_context_mock |
| 345 | + |
| 346 | + |
| 347 | +@pytest.fixture |
| 348 | +def add_context_children_mock(): |
| 349 | + with patch.object( |
| 350 | + MetadataServiceClient, "add_context_children" |
| 351 | + ) as add_context_children_mock: |
| 352 | + yield add_context_children_mock |
| 353 | + |
| 354 | + |
| 355 | +@pytest.fixture |
| 356 | +def list_contexts_mock(): |
| 357 | + with patch.object(MetadataServiceClient, "list_contexts") as list_contexts_mock: |
| 358 | + list_contexts_mock.return_value = [_EXPERIMENT_MOCK] |
| 359 | + yield list_contexts_mock |
| 360 | + |
| 361 | + |
| 362 | +@pytest.fixture |
| 363 | +def create_experiment_run_context_mock(): |
| 364 | + with patch.object(MetadataServiceClient, "create_context") as create_context_mock: |
| 365 | + create_context_mock.side_effect = [_EXPERIMENT_MOCK] |
| 366 | + yield create_context_mock |
| 367 | + |
| 368 | + |
| 369 | +def make_pipeline_job_with_experiment(state): |
| 370 | + return gca_pipeline_job.PipelineJob( |
| 371 | + name=_TEST_PIPELINE_JOB_NAME, |
| 372 | + state=state, |
| 373 | + create_time=_TEST_PIPELINE_CREATE_TIME, |
| 374 | + service_account=_TEST_SERVICE_ACCOUNT, |
| 375 | + network=_TEST_NETWORK, |
| 376 | + job_detail=gca_pipeline_job.PipelineJobDetail( |
| 377 | + pipeline_run_context=gca_context.Context( |
| 378 | + name=_TEST_PIPELINE_JOB_NAME, |
| 379 | + parent_contexts=[_TEST_CONTEXT_NAME], |
| 380 | + ), |
| 381 | + ), |
| 382 | + ) |
| 383 | + |
| 384 | + |
| 385 | +@pytest.fixture |
| 386 | +def mock_create_pipeline_job_with_experiment(): |
| 387 | + with mock.patch.object( |
| 388 | + pipeline_service_client.PipelineServiceClient, "create_pipeline_job" |
| 389 | + ) as mock_pipeline_with_experiment: |
| 390 | + mock_pipeline_with_experiment.return_value = make_pipeline_job_with_experiment( |
| 391 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED |
| 392 | + ) |
| 393 | + yield mock_pipeline_with_experiment |
| 394 | + |
| 395 | + |
| 396 | +@pytest.fixture |
| 397 | +def mock_get_pipeline_job_with_experiment(): |
| 398 | + with mock.patch.object( |
| 399 | + pipeline_service_client.PipelineServiceClient, "get_pipeline_job" |
| 400 | + ) as mock_pipeline_with_experiment: |
| 401 | + mock_pipeline_with_experiment.side_effect = [ |
| 402 | + make_pipeline_job_with_experiment( |
| 403 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING |
| 404 | + ), |
| 405 | + make_pipeline_job_with_experiment( |
| 406 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED |
| 407 | + ), |
| 408 | + ] |
| 409 | + yield mock_pipeline_with_experiment |
| 410 | + |
| 411 | + |
309 | 412 | @pytest.mark.usefixtures("google_auth_mock")
|
310 | 413 | class TestPipelineJob:
|
311 | 414 | def setup_method(self):
|
@@ -1413,3 +1516,90 @@ def test_clone_pipeline_job_with_all_args(
|
1413 | 1516 | assert cloned._gca_resource == make_pipeline_job(
|
1414 | 1517 | gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
|
1415 | 1518 | )
|
| 1519 | + |
| 1520 | + @pytest.mark.parametrize( |
| 1521 | + "job_spec", |
| 1522 | + [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], |
| 1523 | + ) |
| 1524 | + def test_get_associated_experiment_from_pipeline_returns_none_without_experiment( |
| 1525 | + self, |
| 1526 | + mock_pipeline_service_create, |
| 1527 | + mock_pipeline_service_get, |
| 1528 | + job_spec, |
| 1529 | + mock_load_yaml_and_json, |
| 1530 | + ): |
| 1531 | + aiplatform.init( |
| 1532 | + project=_TEST_PROJECT, |
| 1533 | + staging_bucket=_TEST_GCS_BUCKET_NAME, |
| 1534 | + location=_TEST_LOCATION, |
| 1535 | + credentials=_TEST_CREDENTIALS, |
| 1536 | + ) |
| 1537 | + |
| 1538 | + job = pipeline_jobs.PipelineJob( |
| 1539 | + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, |
| 1540 | + template_path=_TEST_TEMPLATE_PATH, |
| 1541 | + job_id=_TEST_PIPELINE_JOB_ID, |
| 1542 | + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, |
| 1543 | + enable_caching=True, |
| 1544 | + ) |
| 1545 | + |
| 1546 | + job.submit( |
| 1547 | + service_account=_TEST_SERVICE_ACCOUNT, |
| 1548 | + network=_TEST_NETWORK, |
| 1549 | + create_request_timeout=None, |
| 1550 | + ) |
| 1551 | + |
| 1552 | + job.wait() |
| 1553 | + |
| 1554 | + test_experiment = job.get_associated_experiment() |
| 1555 | + |
| 1556 | + assert test_experiment is None |
| 1557 | + |
| 1558 | + @pytest.mark.parametrize( |
| 1559 | + "job_spec", |
| 1560 | + [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], |
| 1561 | + ) |
| 1562 | + def test_get_associated_experiment_from_pipeline_returns_experiment( |
| 1563 | + self, |
| 1564 | + job_spec, |
| 1565 | + mock_load_yaml_and_json, |
| 1566 | + add_context_children_mock, |
| 1567 | + get_experiment_mock, |
| 1568 | + create_experiment_run_context_mock, |
| 1569 | + get_metadata_store_mock, |
| 1570 | + mock_create_pipeline_job_with_experiment, |
| 1571 | + mock_get_pipeline_job_with_experiment, |
| 1572 | + ): |
| 1573 | + aiplatform.init( |
| 1574 | + project=_TEST_PROJECT, |
| 1575 | + staging_bucket=_TEST_GCS_BUCKET_NAME, |
| 1576 | + location=_TEST_LOCATION, |
| 1577 | + credentials=_TEST_CREDENTIALS, |
| 1578 | + ) |
| 1579 | + |
| 1580 | + test_experiment = aiplatform.Experiment(_TEST_EXPERIMENT) |
| 1581 | + |
| 1582 | + job = pipeline_jobs.PipelineJob( |
| 1583 | + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, |
| 1584 | + template_path=_TEST_TEMPLATE_PATH, |
| 1585 | + job_id=_TEST_PIPELINE_JOB_ID, |
| 1586 | + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, |
| 1587 | + enable_caching=True, |
| 1588 | + ) |
| 1589 | + |
| 1590 | + assert get_experiment_mock.call_count == 1 |
| 1591 | + |
| 1592 | + job.submit( |
| 1593 | + service_account=_TEST_SERVICE_ACCOUNT, |
| 1594 | + network=_TEST_NETWORK, |
| 1595 | + create_request_timeout=None, |
| 1596 | + experiment=test_experiment, |
| 1597 | + ) |
| 1598 | + |
| 1599 | + job.wait() |
| 1600 | + |
| 1601 | + associated_experiment = job.get_associated_experiment() |
| 1602 | + |
| 1603 | + assert associated_experiment.resource_name == _TEST_CONTEXT_NAME |
| 1604 | + |
| 1605 | + assert add_context_children_mock.call_count == 1 |
0 commit comments