Skip to content

Commit 4241738

Browse files
authored
feat: MBSDK Tabular samples (#338)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: Tracking Bug: [MB SDK Samples - Milestone 1](https://buganizer.corp.google.com/issues/180729765) - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-aiplatform/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent c057083 commit 4241738

23 files changed

+693
-31
lines changed

samples/model-builder/conftest.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -138,17 +138,58 @@ def mock_import_text_dataset(mock_text_dataset):
138138

139139

140140
@pytest.fixture
141-
def mock_init_automl_image_training_job():
142-
with patch.object(
143-
aiplatform.training_jobs.AutoMLImageTrainingJob, "__init__"
144-
) as mock:
145-
mock.return_value = None
141+
def mock_custom_training_job():
142+
mock = MagicMock(aiplatform.training_jobs.CustomTrainingJob)
143+
yield mock
144+
145+
146+
@pytest.fixture
147+
def mock_image_training_job():
148+
mock = MagicMock(aiplatform.training_jobs.AutoMLImageTrainingJob)
149+
yield mock
150+
151+
152+
@pytest.fixture
153+
def mock_tabular_training_job():
154+
mock = MagicMock(aiplatform.training_jobs.AutoMLTabularTrainingJob)
155+
yield mock
156+
157+
158+
@pytest.fixture
159+
def mock_text_training_job():
160+
mock = MagicMock(aiplatform.training_jobs.AutoMLTextTrainingJob)
161+
yield mock
162+
163+
164+
@pytest.fixture
165+
def mock_video_training_job():
166+
mock = MagicMock(aiplatform.training_jobs.AutoMLVideoTrainingJob)
167+
yield mock
168+
169+
170+
@pytest.fixture
171+
def mock_get_automl_tabular_training_job(mock_tabular_training_job):
172+
with patch.object(aiplatform, "AutoMLTabularTrainingJob") as mock:
173+
mock.return_value = mock_tabular_training_job
174+
yield mock
175+
176+
177+
@pytest.fixture
178+
def mock_run_automl_tabular_training_job(mock_tabular_training_job):
179+
with patch.object(mock_tabular_training_job, "run") as mock:
180+
yield mock
181+
182+
183+
@pytest.fixture
184+
def mock_get_automl_image_training_job(mock_image_training_job):
185+
with patch.object(aiplatform, "AutoMLImageTrainingJob") as mock:
186+
mock.return_value = mock_image_training_job
146187
yield mock
147188

148189

149190
@pytest.fixture
150-
def mock_run_automl_image_training_job():
151-
with patch.object(aiplatform.training_jobs.AutoMLImageTrainingJob, "run") as mock:
191+
def mock_run_automl_image_training_job(mock_image_training_job):
192+
with patch.object(mock_image_training_job, "run") as mock:
152193
yield mock
153194

154195

@@ -173,15 +214,21 @@ def mock_run_custom_training_job():
173214

174215

175216
@pytest.fixture
176-
def mock_init_model():
177-
with patch.object(aiplatform.models.Model, "__init__") as mock:
178-
mock.return_value = None
217+
def mock_model():
218+
mock = MagicMock(aiplatform.models.Model)
219+
yield mock
220+
221+
222+
@pytest.fixture
223+
def mock_get_model(mock_model):
224+
with patch.object(aiplatform, "Model") as mock:
225+
mock.return_value = mock_model
179226
yield mock
180227

181228

182229
@pytest.fixture
183-
def mock_batch_predict_model():
184-
with patch.object(aiplatform.models.Model, "batch_predict") as mock:
230+
def mock_batch_predict_model(mock_model):
231+
with patch.object(mock_model, "batch_predict") as mock:
185232
yield mock
186233

187234

@@ -211,6 +258,12 @@ def mock_endpoint():
211258
yield mock
212259

213260

261+
@pytest.fixture
262+
def mock_create_endpoint():
263+
with patch.object(aiplatform.Endpoint, "create") as mock:
264+
yield mock
265+
266+
214267
@pytest.fixture
215268
def mock_get_endpoint(mock_endpoint):
216269
with patch.object(aiplatform, "Endpoint") as mock_get_endpoint:
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from google.cloud import aiplatform
17+
18+
19+
# [START aiplatform_sdk_create_and_import_dataset_tabular_bigquery_sample]
20+
def create_and_import_dataset_tabular_bigquery_sample(
21+
display_name: str, project: str, location: str, bigquery_source: str,
22+
):
23+
24+
aiplatform.init(project=project, location=location)
25+
26+
dataset = aiplatform.TabularDataset.create(
27+
display_name=display_name, bigquery_source=bigquery_source,
28+
)
29+
30+
dataset.wait()
31+
32+
print(f'\tDataset: "{dataset.display_name}"')
33+
print(f'\tname: "{dataset.resource_name}"')
34+
35+
36+
# [END aiplatform_sdk_create_and_import_dataset_tabular_bigquery_sample]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import create_and_import_dataset_tabular_bigquery_sample
17+
import test_constants as constants
18+
19+
20+
def test_create_and_import_dataset_tabular_bigquery_sample(
21+
mock_sdk_init, mock_create_tabular_dataset
22+
):
23+
24+
create_and_import_dataset_tabular_bigquery_sample.create_and_import_dataset_tabular_bigquery_sample(
25+
project=constants.PROJECT,
26+
location=constants.LOCATION,
27+
bigquery_source=constants.BIGQUERY_SOURCE,
28+
display_name=constants.DISPLAY_NAME,
29+
)
30+
31+
mock_sdk_init.assert_called_once_with(
32+
project=constants.PROJECT, location=constants.LOCATION
33+
)
34+
mock_create_tabular_dataset.assert_called_once_with(
35+
display_name=constants.DISPLAY_NAME, bigquery_source=constants.BIGQUERY_SOURCE,
36+
)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Union
16+
17+
from google.cloud import aiplatform
18+
19+
20+
# [START aiplatform_sdk_create_and_import_dataset_tabular_gcs_sample]
21+
def create_and_import_dataset_tabular_gcs_sample(
22+
display_name: str, project: str, location: str, gcs_source: Union[str, List[str]],
23+
):
24+
25+
aiplatform.init(project=project, location=location)
26+
27+
dataset = aiplatform.TabularDataset.create(
28+
display_name=display_name, gcs_source=gcs_source,
29+
)
30+
31+
dataset.wait()
32+
33+
print(f'\tDataset: "{dataset.display_name}"')
34+
print(f'\tname: "{dataset.resource_name}"')
35+
36+
37+
# [END aiplatform_sdk_create_and_import_dataset_tabular_gcs_sample]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import create_and_import_dataset_tabular_gcs_sample
17+
import test_constants as constants
18+
19+
20+
def test_create_and_import_dataset_tabular_gcs_sample(
21+
mock_sdk_init, mock_create_tabular_dataset
22+
):
23+
24+
create_and_import_dataset_tabular_gcs_sample.create_and_import_dataset_tabular_gcs_sample(
25+
project=constants.PROJECT,
26+
location=constants.LOCATION,
27+
gcs_source=constants.GCS_SOURCES,
28+
display_name=constants.DISPLAY_NAME,
29+
)
30+
31+
mock_sdk_init.assert_called_once_with(
32+
project=constants.PROJECT, location=constants.LOCATION
33+
)
34+
mock_create_tabular_dataset.assert_called_once_with(
35+
display_name=constants.DISPLAY_NAME, gcs_source=constants.GCS_SOURCES,
36+
)

samples/model-builder/create_batch_prediction_job_sample_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919

2020
def test_create_batch_prediction_job_sample(
21-
mock_sdk_init, mock_init_model, mock_batch_predict_model
21+
mock_sdk_init, mock_get_model, mock_batch_predict_model
2222
):
2323

2424
create_batch_prediction_job_sample.create_batch_prediction_job_sample(
@@ -33,7 +33,7 @@ def test_create_batch_prediction_job_sample(
3333
mock_sdk_init.assert_called_once_with(
3434
project=constants.PROJECT, location=constants.LOCATION
3535
)
36-
mock_init_model.assert_called_once_with(constants.MODEL_NAME)
36+
mock_get_model.assert_called_once_with(constants.MODEL_NAME)
3737
mock_batch_predict_model.assert_called_once_with(
3838
job_display_name=constants.DISPLAY_NAME,
3939
gcs_source=constants.GCS_SOURCES,
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.cloud import aiplatform
16+
17+
18+
# [START aiplatform_sdk_create_endpoint_sample]
19+
def create_endpoint_sample(
20+
project: str, display_name: str, location: str, sync: bool = True,
21+
):
22+
aiplatform.init(project=project, location=location)
23+
24+
endpoint = aiplatform.Endpoint.create(
25+
display_name=display_name, project=project, location=location,
26+
)
27+
28+
print(endpoint.display_name)
29+
print(endpoint.resource_name)
30+
print(endpoint.uri)
31+
return endpoint
32+
33+
34+
# [END aiplatform_sdk_create_endpoint_sample]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import create_endpoint_sample
17+
import test_constants as constants
18+
19+
20+
def test_create_endpoint_sample(mock_sdk_init, mock_create_endpoint):
21+
22+
create_endpoint_sample.create_endpoint_sample(
23+
project=constants.PROJECT,
24+
display_name=constants.DISPLAY_NAME,
25+
location=constants.LOCATION,
26+
)
27+
28+
mock_sdk_init.assert_called_once_with(
29+
project=constants.PROJECT, location=constants.LOCATION
30+
)
31+
32+
mock_create_endpoint.assert_called_once_with(
33+
display_name=constants.DISPLAY_NAME,
34+
project=constants.PROJECT,
35+
location=constants.LOCATION,
36+
)

samples/model-builder/create_training_pipeline_image_classification_sample_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
def test_create_training_pipeline_image_classification_sample(
2121
mock_sdk_init,
2222
mock_image_dataset,
23-
mock_init_automl_image_training_job,
23+
mock_get_automl_image_training_job,
2424
mock_run_automl_image_training_job,
2525
mock_get_image_dataset,
2626
):
@@ -43,7 +43,7 @@ def test_create_training_pipeline_image_classification_sample(
4343
mock_sdk_init.assert_called_once_with(
4444
project=constants.PROJECT, location=constants.LOCATION
4545
)
46-
mock_init_automl_image_training_job.assert_called_once_with(
46+
mock_get_automl_image_training_job.assert_called_once_with(
4747
display_name=constants.DISPLAY_NAME
4848
)
4949
mock_run_automl_image_training_job.assert_called_once_with(

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy