|
32 | 32 | from google.cloud.aiplatform import jobs
|
33 | 33 | from google.cloud.aiplatform.compat.types import (
|
34 | 34 | custom_job as gca_custom_job_compat,
|
| 35 | + tensorboard_run as gca_tensorboard_run, |
35 | 36 | io,
|
36 | 37 | )
|
37 | 38 |
|
|
55 | 56 | _TEST_PARENT = test_constants.ProjectConstants._TEST_PARENT
|
56 | 57 |
|
57 | 58 | _TEST_CUSTOM_JOB_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_ID}"
|
58 |
| -_TEST_TENSORBOARD_NAME = f"{_TEST_PARENT}/tensorboards/{_TEST_ID}" |
| 59 | +_TEST_TENSORBOARD_ID = "987654321" |
| 60 | +_TEST_TENSORBOARD_NAME = f"{_TEST_PARENT}/tensorboards/{_TEST_TENSORBOARD_ID}" |
59 | 61 | _TEST_ENABLE_WEB_ACCESS = test_constants.TrainingJobConstants._TEST_ENABLE_WEB_ACCESS
|
60 | 62 | _TEST_WEB_ACCESS_URIS = test_constants.TrainingJobConstants._TEST_WEB_ACCESS_URIS
|
61 | 63 | _TEST_TRAINING_CONTAINER_IMAGE = (
|
|
162 | 164 | _TEST_EXPERIMENT_RUN_CONTEXT_NAME = (
|
163 | 165 | f"{_TEST_PARENT_METADATA}/contexts/{_TEST_EXECUTION_ID}"
|
164 | 166 | )
|
# Resource name used for the TensorboardRun returned by the mocked
# `TensorboardRun.list` call in the tensorboard custom-job tests.
_TEST_TENSORBOARD_RUN_NAME = (
    f"{_TEST_PARENT}/tensorboards/{_TEST_TENSORBOARD_ID}"
    f"/experiments/{_TEST_ID}/runs/{_TEST_RUN}"
)
# "<experiment id>-<run id>" identifier the tests expect to see in the
# "Failed to end experiment run ..." log message.
_TEST_TENSORBOARD_RUN_CONTEXT_NAME = f"{_TEST_ID}-{_TEST_RUN}"
165 | 169 |
|
166 | 170 | _EXPERIMENT_MOCK = GapicContext(
|
167 | 171 | name=_TEST_CONTEXT_NAME,
|
@@ -207,6 +211,16 @@ def _get_custom_job_proto_with_experiments(state=None, name=None, error=None):
|
207 | 211 | return custom_job_proto
|
208 | 212 |
|
209 | 213 |
|
def _get_custom_job_proto_with_tensorboard(state=None, name=None, error=None):
    """Builds a test CustomJob proto whose job spec references a tensorboard.

    Args:
        state: Value assigned to the proto's `state` field.
        name: Value assigned to the proto's `name` field.
        error: Value assigned to the proto's `error` field.

    Returns:
        A deep copy of `_TEST_BASE_CUSTOM_JOB_PROTO` with the worker pool
        specs, name, state and error populated and `job_spec.tensorboard`
        set to `_TEST_TENSORBOARD_NAME`.
    """
    # Work on a copy so the shared base proto is never mutated by a test.
    proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO)

    proto.name = name
    proto.state = state
    proto.error = error

    proto.job_spec.worker_pool_specs = _TEST_WORKER_POOL_SPEC
    proto.job_spec.tensorboard = _TEST_TENSORBOARD_NAME

    return proto
| 222 | + |
| 223 | + |
210 | 224 | def _get_custom_job_proto_with_enable_web_access(state=None, name=None, error=None):
|
211 | 225 | custom_job_proto = _get_custom_job_proto(state=state, name=name, error=error)
|
212 | 226 | custom_job_proto.job_spec.enable_web_access = _TEST_ENABLE_WEB_ACCESS
|
@@ -284,6 +298,28 @@ def get_custom_job_with_experiments_mock():
|
284 | 298 | yield get_custom_job_mock
|
285 | 299 |
|
286 | 300 |
|
@pytest.fixture
def get_custom_job_with_tensorboard_mock():
    """Patches `JobServiceClient.get_custom_job` with a three-step sequence.

    Successive calls return a PENDING proto, a RUNNING proto (both built by
    `_get_custom_job_proto`) and finally a SUCCEEDED proto built by
    `_get_custom_job_proto_with_tensorboard`.

    Yields:
        The mock that replaces `get_custom_job`.
    """
    job_state = gca_job_state_compat.JobState

    # One entry is consumed per `get_custom_job` call, in this order.
    responses = [
        _get_custom_job_proto(
            name=_TEST_CUSTOM_JOB_NAME,
            state=job_state.JOB_STATE_PENDING,
        ),
        _get_custom_job_proto(
            name=_TEST_CUSTOM_JOB_NAME,
            state=job_state.JOB_STATE_RUNNING,
        ),
        _get_custom_job_proto_with_tensorboard(
            name=_TEST_CUSTOM_JOB_NAME,
            state=job_state.JOB_STATE_SUCCEEDED,
        ),
    ]

    with patch.object(
        job_service_client.JobServiceClient, "get_custom_job"
    ) as get_custom_job_mock:
        get_custom_job_mock.side_effect = responses
        yield get_custom_job_mock
| 321 | + |
| 322 | + |
287 | 323 | @pytest.fixture
|
288 | 324 | def get_custom_tpu_v5e_job_mock():
|
289 | 325 | with patch.object(
|
@@ -822,6 +858,98 @@ def test_run_custom_job_with_experiment_run_warning(self, caplog):
|
822 | 858 | in caplog.text
|
823 | 859 | )
|
824 | 860 |
|
@pytest.mark.usefixtures(
    "get_experiment_run_not_found_mock",
    "get_tensorboard_run_artifact_not_found_mock",
)
def test_run_custom_job_with_tensorboard_cannot_list_experiment_runs(
    self,
    create_custom_job_mock_with_tensorboard,
    get_custom_job_with_tensorboard_mock,
    caplog,
):
    """Checks the log output when a job is run with a tensorboard attached.

    The job is created and run against the mocked job service, and the test
    then asserts that "Failed to list experiment runs for tensorboard" was
    captured in the log.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    custom_job = aiplatform.CustomJob(
        display_name=_TEST_DISPLAY_NAME,
        worker_pool_specs=_TEST_WORKER_POOL_SPEC,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        labels=_TEST_LABELS,
    )

    run_kwargs = dict(
        service_account=_TEST_SERVICE_ACCOUNT,
        tensorboard=_TEST_TENSORBOARD_NAME,
        network=_TEST_NETWORK,
        timeout=_TEST_TIMEOUT,
        restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        create_request_timeout=None,
        disable_retries=_TEST_DISABLE_RETRIES,
        max_wait_duration=_TEST_MAX_WAIT_DURATION,
    )
    custom_job.run(**run_kwargs)
    custom_job.wait()

    expected_message = "Failed to list experiment runs for tensorboard"
    assert expected_message in caplog.text
| 900 | + |
@pytest.mark.usefixtures(
    "get_experiment_run_not_found_mock",
    "get_tensorboard_run_artifact_not_found_mock",
)
def test_run_custom_job_with_tensorboard_cannot_end_experiment_run(
    self,
    create_custom_job_mock_with_tensorboard,
    get_custom_job_with_tensorboard_mock,
    caplog,
):
    """Checks the log output when the tensorboard lists one run.

    `aiplatform.TensorboardRun.list` is patched to return a single run named
    `_TEST_TENSORBOARD_RUN_NAME`. After the job is run, the test asserts that
    "Failed to end experiment run <context name> due to:" was captured in
    the log.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    custom_job = aiplatform.CustomJob(
        display_name=_TEST_DISPLAY_NAME,
        worker_pool_specs=_TEST_WORKER_POOL_SPEC,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        labels=_TEST_LABELS,
    )

    # The only run the patched `TensorboardRun.list` will report.
    listed_run = gca_tensorboard_run.TensorboardRun(
        name=_TEST_TENSORBOARD_RUN_NAME,
        display_name=_TEST_DISPLAY_NAME,
    )

    with mock.patch.object(
        aiplatform.TensorboardRun, "list", return_value=[listed_run]
    ):
        custom_job.run(
            service_account=_TEST_SERVICE_ACCOUNT,
            tensorboard=_TEST_TENSORBOARD_NAME,
            network=_TEST_NETWORK,
            timeout=_TEST_TIMEOUT,
            restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
            create_request_timeout=None,
            disable_retries=_TEST_DISABLE_RETRIES,
            max_wait_duration=_TEST_MAX_WAIT_DURATION,
        )
        custom_job.wait()

        expected_message = (
            f"Failed to end experiment run {_TEST_TENSORBOARD_RUN_CONTEXT_NAME} due to:"
        )
        assert expected_message in caplog.text
| 952 | + |
825 | 953 | @pytest.mark.parametrize("sync", [True, False])
|
826 | 954 | def test_run_custom_job_with_fail_raises(
|
827 | 955 | self, create_custom_job_mock, get_custom_job_mock_with_fail, sync
|
|
0 commit comments