|
92 | 92 | "schemaVersion": "2.1.0",
|
93 | 93 | "components": {},
|
94 | 94 | }
|
| 95 | +_TEST_TFX_PIPELINE_SPEC = { |
| 96 | + "pipelineInfo": {"name": "my-pipeline"}, |
| 97 | + "root": { |
| 98 | + "dag": {"tasks": {}}, |
| 99 | + "inputDefinitions": {"parameters": {"string_param": {"type": "STRING"}}}, |
| 100 | + }, |
| 101 | + "schemaVersion": "2.0.0", |
| 102 | + "sdkVersion": "tfx-1.4.0", |
| 103 | + "components": {}, |
| 104 | +} |
95 | 105 |
|
96 | 106 | _TEST_PIPELINE_JOB_LEGACY = {
|
97 | 107 | "runtimeConfig": {},
|
|
101 | 111 | "runtimeConfig": {"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES},
|
102 | 112 | "pipelineSpec": _TEST_PIPELINE_SPEC,
|
103 | 113 | }
|
| 114 | +_TEST_PIPELINE_JOB_TFX = { |
| 115 | + "runtimeConfig": {}, |
| 116 | + "pipelineSpec": _TEST_TFX_PIPELINE_SPEC, |
| 117 | +} |
104 | 118 |
|
105 | 119 | _TEST_PIPELINE_GET_METHOD_NAME = "get_fake_pipeline_job"
|
106 | 120 | _TEST_PIPELINE_LIST_METHOD_NAME = "list_fake_pipeline_jobs"
|
@@ -378,6 +392,78 @@ def test_run_call_pipeline_service_create_legacy(
|
378 | 392 | gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
|
379 | 393 | )
|
380 | 394 |
|
| 395 | + @pytest.mark.parametrize( |
| 396 | + "job_spec_json", [_TEST_TFX_PIPELINE_SPEC, _TEST_PIPELINE_JOB_TFX], |
| 397 | + ) |
| 398 | + @pytest.mark.parametrize("sync", [True, False]) |
| 399 | + def test_run_call_pipeline_service_create_tfx( |
| 400 | + self, |
| 401 | + mock_pipeline_service_create, |
| 402 | + mock_pipeline_service_get, |
| 403 | + job_spec_json, |
| 404 | + mock_load_json, |
| 405 | + sync, |
| 406 | + ): |
| 407 | + aiplatform.init( |
| 408 | + project=_TEST_PROJECT, |
| 409 | + staging_bucket=_TEST_GCS_BUCKET_NAME, |
| 410 | + location=_TEST_LOCATION, |
| 411 | + credentials=_TEST_CREDENTIALS, |
| 412 | + ) |
| 413 | + |
| 414 | + job = pipeline_jobs.PipelineJob( |
| 415 | + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, |
| 416 | + template_path=_TEST_TEMPLATE_PATH, |
| 417 | + job_id=_TEST_PIPELINE_JOB_ID, |
| 418 | + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES_LEGACY, |
| 419 | + enable_caching=True, |
| 420 | + ) |
| 421 | + |
| 422 | + job.run( |
| 423 | + service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK, sync=sync, |
| 424 | + ) |
| 425 | + |
| 426 | + if not sync: |
| 427 | + job.wait() |
| 428 | + |
| 429 | + expected_runtime_config_dict = { |
| 430 | + "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, |
| 431 | + "parameters": {"string_param": {"stringValue": "hello"}}, |
| 432 | + } |
| 433 | + runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb |
| 434 | + json_format.ParseDict(expected_runtime_config_dict, runtime_config) |
| 435 | + |
| 436 | + pipeline_spec = job_spec_json.get("pipelineSpec") or job_spec_json |
| 437 | + |
| 438 | + # Construct expected request |
| 439 | + expected_gapic_pipeline_job = gca_pipeline_job_v1.PipelineJob( |
| 440 | + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, |
| 441 | + pipeline_spec={ |
| 442 | + "components": {}, |
| 443 | + "pipelineInfo": pipeline_spec["pipelineInfo"], |
| 444 | + "root": pipeline_spec["root"], |
| 445 | + "schemaVersion": "2.0.0", |
| 446 | + "sdkVersion": "tfx-1.4.0", |
| 447 | + }, |
| 448 | + runtime_config=runtime_config, |
| 449 | + service_account=_TEST_SERVICE_ACCOUNT, |
| 450 | + network=_TEST_NETWORK, |
| 451 | + ) |
| 452 | + |
| 453 | + mock_pipeline_service_create.assert_called_once_with( |
| 454 | + parent=_TEST_PARENT, |
| 455 | + pipeline_job=expected_gapic_pipeline_job, |
| 456 | + pipeline_job_id=_TEST_PIPELINE_JOB_ID, |
| 457 | + ) |
| 458 | + |
| 459 | + mock_pipeline_service_get.assert_called_with( |
| 460 | + name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY |
| 461 | + ) |
| 462 | + |
| 463 | + assert job._gca_resource == make_pipeline_job( |
| 464 | + gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED |
| 465 | + ) |
| 466 | + |
381 | 467 | @pytest.mark.parametrize(
|
382 | 468 | "job_spec_json", [_TEST_PIPELINE_SPEC, _TEST_PIPELINE_JOB],
|
383 | 469 | )
|
|
0 commit comments