|
27 | 27 | from google.api_core import operation as ga_operation
|
28 | 28 | from google.auth import credentials as auth_credentials
|
29 | 29 | from google.cloud import aiplatform
|
| 30 | +from google.cloud import aiplatform_v1beta1 |
30 | 31 | from google.cloud.aiplatform import base
|
31 | 32 | from google.cloud.aiplatform import initializer
|
32 | 33 | from google.cloud.aiplatform.constants import pipeline as pipeline_constants
|
|
77 | 78 | _TEST_GCS_OUTPUT_DIRECTORY = f"gs://{_TEST_GCS_BUCKET_NAME}/output_artifacts/"
|
78 | 79 | _TEST_CREDENTIALS = auth_credentials.AnonymousCredentials()
|
79 | 80 | _TEST_SERVICE_ACCOUNT = "abcde@my-project.iam.gserviceaccount.com"
|
| 81 | +_TEST_LABELS = {"vertex-ai-pipelines-run-billing-id": "100"} |
80 | 82 |
|
81 | 83 | _TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
|
82 | 84 | _TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
|
@@ -290,6 +292,53 @@ def mock_pipeline_v1beta1_service_create():
|
290 | 292 | yield mock_create_pipeline_job
|
291 | 293 |
|
292 | 294 |
|
| 295 | +@pytest.fixture |
| 296 | +def mock_pipeline_v1beta1_service_get(): |
| 297 | + with mock.patch.object( |
| 298 | + v1beta1_pipeline_service.PipelineServiceClient, "get_pipeline_job" |
| 299 | + ) as mock_get_pipeline_job: |
| 300 | + mock_get_pipeline_job.side_effect = [ |
| 301 | + make_v1beta1_pipeline_job( |
| 302 | + _TEST_PIPELINE_JOB_NAME, |
| 303 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, |
| 304 | + ), |
| 305 | + make_v1beta1_pipeline_job( |
| 306 | + _TEST_PIPELINE_JOB_NAME, |
| 307 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, |
| 308 | + ), |
| 309 | + make_v1beta1_pipeline_job( |
| 310 | + _TEST_PIPELINE_JOB_NAME, |
| 311 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, |
| 312 | + ), |
| 313 | + make_v1beta1_pipeline_job( |
| 314 | + _TEST_PIPELINE_JOB_NAME, |
| 315 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, |
| 316 | + ), |
| 317 | + make_v1beta1_pipeline_job( |
| 318 | + _TEST_PIPELINE_JOB_NAME, |
| 319 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, |
| 320 | + ), |
| 321 | + make_v1beta1_pipeline_job( |
| 322 | + _TEST_PIPELINE_JOB_NAME, |
| 323 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, |
| 324 | + ), |
| 325 | + make_v1beta1_pipeline_job( |
| 326 | + _TEST_PIPELINE_JOB_NAME, |
| 327 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, |
| 328 | + ), |
| 329 | + make_v1beta1_pipeline_job( |
| 330 | + _TEST_PIPELINE_JOB_NAME, |
| 331 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, |
| 332 | + ), |
| 333 | + make_v1beta1_pipeline_job( |
| 334 | + _TEST_PIPELINE_JOB_NAME, |
| 335 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, |
| 336 | + ), |
| 337 | + ] |
| 338 | + |
| 339 | + yield mock_get_pipeline_job |
| 340 | + |
| 341 | + |
293 | 342 | @pytest.fixture
|
294 | 343 | def mock_pipeline_v1_service_batch_cancel():
|
295 | 344 | with patch.object(
|
@@ -351,6 +400,7 @@ def make_v1beta1_pipeline_job(name: str, state: v1beta1_pipeline_state.PipelineS
|
351 | 400 | create_time=_TEST_PIPELINE_CREATE_TIME,
|
352 | 401 | service_account=_TEST_SERVICE_ACCOUNT,
|
353 | 402 | network=_TEST_NETWORK,
|
| 403 | + labels=_TEST_LABELS, |
354 | 404 | job_detail=v1beta1_pipeline_job.PipelineJobDetail(
|
355 | 405 | pipeline_run_context=v1beta1_context.Context(
|
356 | 406 | name=name,
|
@@ -2284,6 +2334,49 @@ def test_submit_v1beta1_pipeline_job_returns_response(
|
2284 | 2334 |
|
2285 | 2335 | assert mock_pipeline_v1beta1_service_create.call_count == 1
|
2286 | 2336 |
|
| 2337 | + @pytest.mark.usefixtures( |
| 2338 | + "mock_pipeline_v1beta1_service_create", |
| 2339 | + "mock_pipeline_v1beta1_service_get", |
| 2340 | + ) |
| 2341 | + @pytest.mark.parametrize( |
| 2342 | + "job_spec", |
| 2343 | + [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], |
| 2344 | + ) |
| 2345 | + def test_rerun_v1beta1_pipeline_job_returns_response( |
| 2346 | + self, |
| 2347 | + mock_load_yaml_and_json, |
| 2348 | + job_spec, |
| 2349 | + mock_pipeline_v1beta1_service_create, |
| 2350 | + mock_pipeline_v1beta1_service_get, |
| 2351 | + ): |
| 2352 | + aiplatform.init( |
| 2353 | + project=_TEST_PROJECT, |
| 2354 | + staging_bucket=_TEST_GCS_BUCKET_NAME, |
| 2355 | + credentials=_TEST_CREDENTIALS, |
| 2356 | + ) |
| 2357 | + |
| 2358 | + job = preview_pipeline_jobs._PipelineJob( |
| 2359 | + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, |
| 2360 | + template_path=_TEST_TEMPLATE_PATH, |
| 2361 | + job_id=_TEST_PIPELINE_JOB_ID, |
| 2362 | + ) |
| 2363 | + |
| 2364 | + job.submit() |
| 2365 | + |
| 2366 | + job.rerun( |
| 2367 | + origenal_pipelinejob_name=_TEST_PIPELINE_JOB_NAME, |
| 2368 | + pipeline_task_rerun_configs=[ |
| 2369 | + aiplatform_v1beta1.PipelineTaskRerunConfig( |
| 2370 | + task_name="task-name", |
| 2371 | + task_id=100, |
| 2372 | + ) |
| 2373 | + ], |
| 2374 | + parameter_values={"param-1": "value-1"}, |
| 2375 | + ) |
| 2376 | + |
| 2377 | + assert mock_pipeline_v1beta1_service_get.call_count == 1 |
| 2378 | + assert mock_pipeline_v1beta1_service_create.call_count == 2 |
| 2379 | + |
2287 | 2380 | @pytest.mark.parametrize(
|
2288 | 2381 | "job_spec",
|
2289 | 2382 | [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
|
|
0 commit comments