Skip to content

Commit 30cf221

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Automatically end Experiment runs when Tensorboard CustomJob is complete
PiperOrigin-RevId: 683653996
1 parent 2b8ae76 commit 30cf221

File tree

2 files changed

+152
-4
lines changed

2 files changed

+152
-4
lines changed

google/cloud/aiplatform/jobs.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1633,14 +1633,34 @@ def _block_until_complete(self):
16331633

16341634
if isinstance(self, CustomJob):
16351635
# End the experiment run associated with the custom job, if exists.
1636-
experiment_run = self._gca_resource.job_spec.experiment_run
1637-
if experiment_run:
1636+
experiment_runs = []
1637+
if self._gca_resource.job_spec.experiment_run:
1638+
experiment_runs = [self._gca_resource.job_spec.experiment_run]
1639+
elif self._gca_resource.job_spec.tensorboard:
1640+
tensorboard_id = self._gca_resource.job_spec.tensorboard.split("/")[-1]
1641+
try:
1642+
tb_runs = aiplatform.TensorboardRun.list(
1643+
tensorboard_experiment_name=self.name,
1644+
tensorboard_id=tensorboard_id,
1645+
)
1646+
experiment_runs = [
1647+
f"{self.name}-{tb_run.name.split('/')[-1]}"
1648+
for tb_run in tb_runs
1649+
]
1650+
except (ValueError, api_exceptions.GoogleAPIError) as e:
1651+
_LOGGER.warning(
1652+
f"Failed to list experiment runs for tensorboard "
1653+
f"{tensorboard_id} due to: {e}"
1654+
)
1655+
for experiment_run in experiment_runs:
16381656
try:
16391657
# sync resource before end run
16401658
experiment_run_context = aiplatform.Context(experiment_run)
16411659
experiment_run_context.update(
16421660
metadata={
1643-
metadata_constants._STATE_KEY: gca_execution_compat.Execution.State.COMPLETE.name
1661+
metadata_constants._STATE_KEY: (
1662+
gca_execution_compat.Execution.State.COMPLETE.name
1663+
)
16441664
}
16451665
)
16461666
except (ValueError, api_exceptions.GoogleAPIError) as e:

tests/unit/aiplatform/test_custom_job.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from google.cloud.aiplatform import jobs
3333
from google.cloud.aiplatform.compat.types import (
3434
custom_job as gca_custom_job_compat,
35+
tensorboard_run as gca_tensorboard_run,
3536
io,
3637
)
3738

@@ -55,7 +56,8 @@
5556
_TEST_PARENT = test_constants.ProjectConstants._TEST_PARENT
5657

5758
_TEST_CUSTOM_JOB_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_ID}"
58-
_TEST_TENSORBOARD_NAME = f"{_TEST_PARENT}/tensorboards/{_TEST_ID}"
59+
_TEST_TENSORBOARD_ID = "987654321"
60+
_TEST_TENSORBOARD_NAME = f"{_TEST_PARENT}/tensorboards/{_TEST_TENSORBOARD_ID}"
5961
_TEST_ENABLE_WEB_ACCESS = test_constants.TrainingJobConstants._TEST_ENABLE_WEB_ACCESS
6062
_TEST_WEB_ACCESS_URIS = test_constants.TrainingJobConstants._TEST_WEB_ACCESS_URIS
6163
_TEST_TRAINING_CONTAINER_IMAGE = (
@@ -162,6 +164,8 @@
162164
_TEST_EXPERIMENT_RUN_CONTEXT_NAME = (
163165
f"{_TEST_PARENT_METADATA}/contexts/{_TEST_EXECUTION_ID}"
164166
)
167+
_TEST_TENSORBOARD_RUN_NAME = f"{_TEST_PARENT}/tensorboards/{_TEST_TENSORBOARD_ID}/experiments/{_TEST_ID}/runs/{_TEST_RUN}"
168+
_TEST_TENSORBOARD_RUN_CONTEXT_NAME = f"{_TEST_ID}-{_TEST_RUN}"
165169

166170
_EXPERIMENT_MOCK = GapicContext(
167171
name=_TEST_CONTEXT_NAME,
@@ -207,6 +211,16 @@ def _get_custom_job_proto_with_experiments(state=None, name=None, error=None):
207211
return custom_job_proto
208212

209213

214+
def _get_custom_job_proto_with_tensorboard(state=None, name=None, error=None):
    """Return a copy of the base CustomJob proto wired to the test Tensorboard.

    Args:
        state: Value assigned to the proto's ``state`` field.
        name: Value assigned to the proto's ``name`` field.
        error: Value assigned to the proto's ``error`` field.

    Returns:
        A deep copy of ``_TEST_BASE_CUSTOM_JOB_PROTO`` whose job spec carries
        the test worker pool specs and ``_TEST_TENSORBOARD_NAME``.
    """
    # Deep-copy so the shared module-level base proto is never mutated.
    proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO)

    # Job-spec level settings.
    proto.job_spec.worker_pool_specs = _TEST_WORKER_POOL_SPEC
    proto.job_spec.tensorboard = _TEST_TENSORBOARD_NAME

    # Top-level fields supplied by the caller.
    proto.name = name
    proto.state = state
    proto.error = error
    return proto
222+
223+
210224
def _get_custom_job_proto_with_enable_web_access(state=None, name=None, error=None):
211225
custom_job_proto = _get_custom_job_proto(state=state, name=name, error=error)
212226
custom_job_proto.job_spec.enable_web_access = _TEST_ENABLE_WEB_ACCESS
@@ -284,6 +298,28 @@ def get_custom_job_with_experiments_mock():
284298
yield get_custom_job_mock
285299

286300

301+
@pytest.fixture
def get_custom_job_with_tensorboard_mock():
    """Patch ``JobServiceClient.get_custom_job`` with a three-step sequence.

    Successive calls return a PENDING job, a RUNNING job, and finally a
    SUCCEEDED job whose job spec has the test Tensorboard set.

    Yields:
        The mock that replaces ``get_custom_job``.
    """
    job_states = gca_job_state_compat.JobState
    with patch.object(
        job_service_client.JobServiceClient, "get_custom_job"
    ) as get_custom_job_mock:
        # Only the final (terminal) response carries the tensorboard field.
        responses = [
            _get_custom_job_proto(
                name=_TEST_CUSTOM_JOB_NAME,
                state=job_states.JOB_STATE_PENDING,
            ),
            _get_custom_job_proto(
                name=_TEST_CUSTOM_JOB_NAME,
                state=job_states.JOB_STATE_RUNNING,
            ),
            _get_custom_job_proto_with_tensorboard(
                name=_TEST_CUSTOM_JOB_NAME,
                state=job_states.JOB_STATE_SUCCEEDED,
            ),
        ]
        get_custom_job_mock.side_effect = responses
        yield get_custom_job_mock
321+
322+
287323
@pytest.fixture
288324
def get_custom_tpu_v5e_job_mock():
289325
with patch.object(
@@ -822,6 +858,98 @@ def test_run_custom_job_with_experiment_run_warning(self, caplog):
822858
in caplog.text
823859
)
824860

861+
@pytest.mark.usefixtures(
    "get_experiment_run_not_found_mock",
    "get_tensorboard_run_artifact_not_found_mock",
)
def test_run_custom_job_with_tensorboard_cannot_list_experiment_runs(
    self,
    create_custom_job_mock_with_tensorboard,
    get_custom_job_with_tensorboard_mock,
    caplog,
):
    """Run a Tensorboard-backed job and expect the list-failure message in the log."""
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    custom_job = aiplatform.CustomJob(
        display_name=_TEST_DISPLAY_NAME,
        worker_pool_specs=_TEST_WORKER_POOL_SPEC,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        labels=_TEST_LABELS,
    )

    run_options = {
        "service_account": _TEST_SERVICE_ACCOUNT,
        "tensorboard": _TEST_TENSORBOARD_NAME,
        "network": _TEST_NETWORK,
        "timeout": _TEST_TIMEOUT,
        "restart_job_on_worker_restart": _TEST_RESTART_JOB_ON_WORKER_RESTART,
        "create_request_timeout": None,
        "disable_retries": _TEST_DISABLE_RETRIES,
        "max_wait_duration": _TEST_MAX_WAIT_DURATION,
    }
    custom_job.run(**run_options)
    custom_job.wait()

    expected_warning = "Failed to list experiment runs for tensorboard"
    assert expected_warning in caplog.text
900+
901+
@pytest.mark.usefixtures(
    "get_experiment_run_not_found_mock",
    "get_tensorboard_run_artifact_not_found_mock",
)
def test_run_custom_job_with_tensorboard_cannot_end_experiment_run(
    self,
    create_custom_job_mock_with_tensorboard,
    get_custom_job_with_tensorboard_mock,
    caplog,
):
    """Run a job with one listed Tensorboard run and expect the end-run failure in the log."""
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    custom_job = aiplatform.CustomJob(
        display_name=_TEST_DISPLAY_NAME,
        worker_pool_specs=_TEST_WORKER_POOL_SPEC,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        labels=_TEST_LABELS,
    )

    run_options = {
        "service_account": _TEST_SERVICE_ACCOUNT,
        "tensorboard": _TEST_TENSORBOARD_NAME,
        "network": _TEST_NETWORK,
        "timeout": _TEST_TIMEOUT,
        "restart_job_on_worker_restart": _TEST_RESTART_JOB_ON_WORKER_RESTART,
        "create_request_timeout": None,
        "disable_retries": _TEST_DISABLE_RETRIES,
        "max_wait_duration": _TEST_MAX_WAIT_DURATION,
    }

    # The listing is stubbed to return exactly one Tensorboard run; the job
    # must run while the patch is active for the stub to be consulted.
    with mock.patch.object(aiplatform.TensorboardRun, "list") as list_runs_mock:
        list_runs_mock.return_value = [
            gca_tensorboard_run.TensorboardRun(
                name=_TEST_TENSORBOARD_RUN_NAME,
                display_name=_TEST_DISPLAY_NAME,
            )
        ]

        custom_job.run(**run_options)
        custom_job.wait()

    expected_warning = (
        f"Failed to end experiment run {_TEST_TENSORBOARD_RUN_CONTEXT_NAME} due to:"
    )
    assert expected_warning in caplog.text
952+
825953
@pytest.mark.parametrize("sync", [True, False])
826954
def test_run_custom_job_with_fail_raises(
827955
self, create_custom_job_mock, get_custom_job_mock_with_fail, sync

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy