Content-Length: 566982 | pFad | http://www.github.com/googleapis/python-aiplatform/commit/8b9376e9c961a751799f5b80d1b19917c8c353f8

D95 feat: default to custom job display name if experiment name looks lik… · googleapis/python-aiplatform@8b9376e · GitHub
Skip to content

Commit 8b9376e

Browse files
yfang1Yicheng Fang
and
Yicheng Fang
authored
feat: default to custom job display name if experiment name looks like a custom job ID (#833)
Co-authored-by: Yicheng Fang <yichengfang@google.com>
1 parent e0fc3d9 commit 8b9376e

File tree

2 files changed

+146
-2
lines changed

2 files changed

+146
-2
lines changed

google/cloud/aiplatform/tensorboard/uploader_main.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@
2828
from tensorboard.plugins.image import metadata as images_metadata
2929
from tensorboard.plugins.graph import metadata as graphs_metadata
3030

31+
from google.api_core import exceptions
3132
from google.cloud import storage
3233
from google.cloud import aiplatform
34+
from google.cloud.aiplatform import jobs
3335
from google.cloud.aiplatform.tensorboard import uploader
3436
from google.cloud.aiplatform.utils import TensorboardClientWithOverride
3537

@@ -123,9 +125,14 @@ def main(argv):
123125
exitcode=0,
124126
)
125127

128+
experiment_name = FLAGS.experiment_name
129+
experiment_display_name = get_experiment_display_name_with_override(
130+
experiment_name, FLAGS.experiment_display_name, project_id, region
131+
)
132+
126133
tb_uploader = uploader.TensorBoardUploader(
127-
experiment_name=FLAGS.experiment_name,
128-
experiment_display_name=FLAGS.experiment_display_name,
134+
experiment_name=experiment_name,
135+
experiment_display_name=experiment_display_name,
129136
tensorboard_resource_name=tensorboard.name,
130137
blob_storage_bucket=blob_storage_bucket,
131138
blob_storage_folder=blob_storage_folder,
@@ -149,6 +156,19 @@ def main(argv):
149156
tb_uploader.start_uploading()
150157

151158

159+
def get_experiment_display_name_with_override(
160+
experiment_name, experiment_display_name, project_id, region
161+
):
162+
if experiment_name.isdecimal() and not experiment_display_name:
163+
try:
164+
return jobs.CustomJob.get(
165+
resource_name=experiment_name, project=project_id, location=region,
166+
).display_name
167+
except exceptions.NotFound:
168+
return experiment_display_name
169+
return experiment_display_name
170+
171+
152172
def flags_parser(args):
153173
# Plumbs the flags defined in this file to the main module, mostly for the
154174
# console script wrapper tb-gcp-uploader.
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2021 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
import pytest
18+
19+
from importlib import reload
20+
from unittest.mock import patch
21+
22+
from google.api_core import exceptions
23+
from google.cloud import aiplatform
24+
from google.cloud.aiplatform import initializer
25+
from google.cloud.aiplatform.tensorboard import uploader_main
26+
from google.cloud.aiplatform.compat.types import job_state as gca_job_state_compat
27+
from google.cloud.aiplatform.compat.types import custom_job as gca_custom_job_compat
28+
from google.cloud.aiplatform_v1.services.job_service import client as job_service_client
29+
30+
_TEST_PROJECT = "test-project"
31+
_TEST_LOCATION = "us-central1"
32+
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
33+
_TEST_CUSTOM_JOB_ID = "445768"
34+
_TEST_CUSTOM_JOB_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_CUSTOM_JOB_ID}"
35+
_TEST_CUSTOM_JOBS_DISPLAY_NAME = "a custom job display name"
36+
_TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME = "someDisplayName"
37+
38+
39+
def _get_custom_job_proto(state=None, name=None):
40+
custom_job_proto = gca_custom_job_compat.CustomJob()
41+
custom_job_proto.name = name
42+
custom_job_proto.state = state
43+
custom_job_proto.display_name = _TEST_CUSTOM_JOBS_DISPLAY_NAME
44+
return custom_job_proto
45+
46+
47+
@pytest.fixture
48+
def get_custom_job_mock_not_found():
49+
with patch.object(
50+
job_service_client.JobServiceClient, "get_custom_job"
51+
) as get_custom_job_mock:
52+
get_custom_job_mock.side_effect = exceptions.NotFound("not found")
53+
yield get_custom_job_mock
54+
55+
56+
@pytest.fixture
57+
def get_custom_job_mock():
58+
with patch.object(
59+
job_service_client.JobServiceClient, "get_custom_job"
60+
) as get_custom_job_mock:
61+
get_custom_job_mock.side_effect = [
62+
_get_custom_job_proto(
63+
name=_TEST_CUSTOM_JOB_NAME,
64+
state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED,
65+
),
66+
]
67+
yield get_custom_job_mock
68+
69+
70+
class TestUploaderMain:
71+
def setup_method(self):
72+
reload(initializer)
73+
reload(aiplatform)
74+
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
75+
76+
def teardown_method(self):
77+
initializer.global_pool.shutdown(wait=True)
78+
79+
def test_get_default_custom_job_display_name(self, get_custom_job_mock):
80+
aiplatform.init(project=_TEST_PROJECT)
81+
assert (
82+
uploader_main.get_experiment_display_name_with_override(
83+
_TEST_CUSTOM_JOB_ID, None, _TEST_PROJECT, _TEST_LOCATION
84+
)
85+
== _TEST_CUSTOM_JOBS_DISPLAY_NAME
86+
)
87+
88+
def test_non_decimal_experiment_name(self, get_custom_job_mock):
89+
aiplatform.init(project=_TEST_PROJECT)
90+
assert (
91+
uploader_main.get_experiment_display_name_with_override(
92+
"someExperimentName",
93+
_TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME,
94+
_TEST_PROJECT,
95+
_TEST_LOCATION,
96+
)
97+
== _TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME
98+
)
99+
get_custom_job_mock.assert_not_called()
100+
101+
def test_display_name_already_specified(self, get_custom_job_mock):
102+
aiplatform.init(project=_TEST_PROJECT)
103+
assert (
104+
uploader_main.get_experiment_display_name_with_override(
105+
_TEST_CUSTOM_JOB_ID,
106+
_TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME,
107+
_TEST_PROJECT,
108+
_TEST_LOCATION,
109+
)
110+
== _TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME
111+
)
112+
get_custom_job_mock.assert_not_called()
113+
114+
def test_custom_job_not_found(self, get_custom_job_mock_not_found):
115+
aiplatform.init(project=_TEST_PROJECT)
116+
assert (
117+
uploader_main.get_experiment_display_name_with_override(
118+
_TEST_CUSTOM_JOB_ID,
119+
_TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME,
120+
_TEST_PROJECT,
121+
_TEST_LOCATION,
122+
)
123+
== _TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME
124+
)

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://www.github.com/googleapis/python-aiplatform/commit/8b9376e9c961a751799f5b80d1b19917c8c353f8

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy