Skip to content

Commit 98459aa

Browse files
fthoele authored and copybara-github committed
feat: Add validation of the BigQuery location when creating a MultimodalDataset
PiperOrigin-RevId: 741515869
1 parent 184cca5 commit 98459aa

File tree

2 files changed

+133
-26
lines changed

2 files changed

+133
-26
lines changed

google/cloud/aiplatform/preview/datasets.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,46 @@ def _get_metadata_for_bq(
8888
return json_format.ParseDict(input_config, struct_pb2.Value())
8989

9090

91-
def _normalize_and_validate_table_id(
    *,
    table_id: str,
    project: Optional[str] = None,
    vertex_location: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
):
    """Normalizes a BigQuery table id and validates its project and location.

    The table id may be given as "dataset.table" or "project.dataset.table";
    a missing project defaults to the multimodal dataset's project. The
    referenced BigQuery dataset is fetched to verify that it lives in the
    same location as the multimodal dataset.

    Args:
        table_id: The BigQuery table id to normalize and validate.
        project: Optional. Project of the multimodal dataset; falls back to
            the globally initialized project.
        vertex_location: Optional. Location of the multimodal dataset; falls
            back to the globally initialized location.
        credentials: Optional. Credentials used to query BigQuery; fall back
            to the globally initialized credentials.

    Returns:
        The fully qualified table id "project.dataset.table".

    Raises:
        ValueError: If the table id is invalid, if the table's project
            differs from the dataset's project, or if the BigQuery dataset
            is in a different location than the multimodal dataset.
    """
    # Deferred import so the BigQuery client library is only required when
    # this validation actually runs.
    from google.cloud import bigquery  # pylint: disable=g-import-not-at-top

    # Fall back to the SDK-wide configuration for anything not provided.
    project = project or initializer.global_config.project
    vertex_location = vertex_location or initializer.global_config.location
    credentials = credentials or initializer.global_config.credentials

    # Raises ValueError for malformed table ids; fills in the project when
    # the id is of the form "dataset.table".
    ref = bigquery.TableReference.from_string(table_id, default_project=project)
    if ref.project != project:
        raise ValueError(
            f"The BigQuery table "
            f"`{ref.project}.{ref.dataset_id}.{ref.table_id}`"
            " must be in the same project as the multimodal dataset."
            f" The multimodal dataset is in `{project}`, but the BigQuery table"
            f" is in `{ref.project}`."
        )

    # Look up the containing dataset to compare its location against the
    # multimodal dataset's location.
    ds_ref = bigquery.DatasetReference(
        project=ref.project, dataset_id=ref.dataset_id
    )
    bq_client = bigquery.Client(project=project, credentials=credentials)
    bq_dataset = bq_client.get_dataset(dataset_ref=ds_ref)
    if bq_dataset.location != vertex_location:
        raise ValueError(
            f"The BigQuery dataset"
            f" `{ds_ref.project}.{ds_ref.dataset_id}` must be in the"
            " same location as the multimodal dataset. The multimodal dataset"
            f" is in `{vertex_location}`, but the BigQuery dataset is in"
            f" `{bq_dataset.location}`."
        )
    return f"{ref.project}.{ref.dataset_id}.{ref.table_id}"
98131

99132

100133
class GeminiExample:
@@ -577,7 +610,8 @@ def from_pandas(
577610
table id can be in the format of "dataset.table" or
578611
"project.dataset.table". If a table already exists with the
579612
given table id, it will be overwritten. Note that the BigQuery
580-
dataset must already exist.
613+
dataset must already exist and be in the same location as the
614+
multimodal dataset.
581615
display_name (str):
582616
Optional. The user-defined name of the dataset. The name can be
583617
up to 128 characters long and can consist of any UTF-8
@@ -614,12 +648,15 @@ def from_pandas(
614648
The created multimodal dataset.
615649
"""
616650
bigframes = _try_import_bigframes()
617-
if not project:
618-
project = initializer.global_config.project
619651
# TODO(b/400355374): `table_id` should be optional, and if not provided,
620652
# we generate a random table id. Also, check if we can use a default
621653
# dataset that's created from the SDK.
622-
target_table_id = _normalize_table_id(table_id=target_table_id, project=project)
654+
target_table_id = _normalize_and_validate_table_id(
655+
table_id=target_table_id,
656+
project=project,
657+
vertex_location=location,
658+
credentials=credentials,
659+
)
623660

624661
temp_bigframes_df = bigframes.pandas.read_pandas(dataframe)
625662
temp_bigframes_df.to_gbq(
@@ -662,7 +699,8 @@ def from_bigframes(
662699
table id can be in the format of "dataset.table" or
663700
"project.dataset.table". If a table already exists with the
664701
given table id, it will be overwritten. Note that the BigQuery
665-
dataset must already exist.
702+
dataset must already exist and be in the same location as the
703+
multimodal dataset.
666704
display_name (str):
667705
Optional. The user-defined name of the dataset. The name can be
668706
up to 128 characters long and can consist of any UTF-8
@@ -697,12 +735,14 @@ def from_bigframes(
697735
Returns:
698736
The created multimodal dataset.
699737
"""
700-
project_id = project or initializer.global_config.project
701738
# TODO(b/400355374): `table_id` should be optional, and if not provided,
702739
# we generate a random table id. Also, check if we can use a default
703740
# dataset that's created from the SDK.
704-
target_table_id = _normalize_table_id(
705-
table_id=target_table_id, project=project_id
741+
target_table_id = _normalize_and_validate_table_id(
742+
table_id=target_table_id,
743+
project=project,
744+
vertex_location=location,
745+
credentials=credentials,
706746
)
707747
dataframe.to_gbq(
708748
destination_table=target_table_id,

tests/unit/aiplatform/test_multimodal_datasets.py

Lines changed: 78 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from google import auth
2222
from google.api_core import operation
2323
from google.auth import credentials as auth_credentials
24+
from google.cloud import bigquery
2425
from google.cloud import aiplatform
2526
from google.cloud.aiplatform import base
2627
from google.cloud.aiplatform import initializer
@@ -42,6 +43,7 @@
4243

4344
_TEST_PROJECT = "test-project"
4445
_TEST_LOCATION = "us-central1"
46+
_TEST_ALTERNATE_LOCATION = "europe-west6"
4547
_TEST_ID = "1028944691210842416"
4648
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
4749
_TEST_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/datasets/{_TEST_ID}"
@@ -53,6 +55,8 @@
5355
)
5456

5557
_TEST_SOURCE_URI_BQ = "bq://my-project.my-dataset.table"
58+
_TEST_TARGET_BQ_DATASET = f"{_TEST_PROJECT}.target-dataset"
59+
_TEST_TARGET_BQ_TABLE = f"{_TEST_TARGET_BQ_DATASET}.target-table"
5660
_TEST_DISPLAY_NAME = "my_dataset_1234"
5761
_TEST_METADATA_SCHEMA_URI_MULTIMODAL = (
5862
"gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml"
@@ -168,6 +172,24 @@ def bigframes_import_mock():
168172
del sys.modules["bigframes.pandas"]
169173

170174

175+
@pytest.fixture
def get_bq_dataset_mock():
    """Patches bigquery.Client.get_dataset to return a dataset located in _TEST_LOCATION."""
    with mock.patch.object(bigquery.Client, "get_dataset") as patched_get_dataset:
        patched_get_dataset.return_value = mock.Mock(location=_TEST_LOCATION)
        yield patched_get_dataset
182+
183+
184+
@pytest.fixture
def get_bq_dataset_alternate_location_mock():
    """Patches bigquery.Client.get_dataset to return a dataset in a different location."""
    with mock.patch.object(bigquery.Client, "get_dataset") as patched_get_dataset:
        patched_get_dataset.return_value = mock.Mock(
            location=_TEST_ALTERNATE_LOCATION
        )
        yield patched_get_dataset
191+
192+
171193
@pytest.fixture
172194
def update_dataset_with_template_config_mock():
173195
with mock.patch.object(
@@ -259,7 +281,7 @@ def test_create_dataset_from_bigquery(self, create_dataset_mock, sync):
259281
)
260282

261283
@pytest.mark.skip(reason="flaky with other tests mocking bigframes")
262-
@pytest.mark.usefixtures("get_dataset_mock")
284+
@pytest.mark.usefixtures("get_dataset_mock", "get_bq_dataset_mock")
263285
def test_create_dataset_from_pandas(
264286
self, create_dataset_mock, bigframes_import_mock
265287
):
@@ -273,55 +295,100 @@ def test_create_dataset_from_pandas(
273295
"answer": ["answer"],
274296
}
275297
)
276-
bq_table = "my-project.my-dataset.my-table"
277298
ummd.MultimodalDataset.from_pandas(
278299
dataframe=dataframe,
279-
target_table_id=bq_table,
300+
target_table_id=_TEST_TARGET_BQ_TABLE,
280301
display_name=_TEST_DISPLAY_NAME,
281302
)
282303
expected_dataset = gca_dataset.Dataset(
283304
display_name=_TEST_DISPLAY_NAME,
284305
metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_MULTIMODAL,
285-
metadata={"inputConfig": {"bigquerySource": {"uri": f"bq://{bq_table}"}}},
306+
metadata={
307+
"inputConfig": {
308+
"bigquerySource": {"uri": f"bq://{_TEST_TARGET_BQ_TABLE}"}
309+
}
310+
},
286311
)
287312
create_dataset_mock.assert_called_once_with(
288313
dataset=expected_dataset,
289314
parent=_TEST_PARENT,
290315
timeout=None,
291316
)
292317
bigframes_mock.to_gbq.assert_called_once_with(
293-
destination_table=bq_table,
318+
destination_table=_TEST_TARGET_BQ_TABLE,
294319
if_exists="replace",
295320
)
296321

297322
@pytest.mark.skip(reason="flaky with other tests mocking bigframes")
298-
@pytest.mark.usefixtures("bigframes_import_mock")
299-
@pytest.mark.usefixtures("get_dataset_mock")
323+
@pytest.mark.usefixtures(
324+
"bigframes_import_mock", "get_dataset_mock", "get_bq_dataset_mock"
325+
)
300326
def test_create_dataset_from_bigframes(self, create_dataset_mock):
301327
aiplatform.init(project=_TEST_PROJECT)
302328
bigframes_df = mock.Mock()
303-
bq_table = "my-project.my-dataset.my-table"
304329
ummd.MultimodalDataset.from_bigframes(
305330
dataframe=bigframes_df,
306-
target_table_id=bq_table,
331+
target_table_id=_TEST_TARGET_BQ_TABLE,
307332
display_name=_TEST_DISPLAY_NAME,
308333
)
309334

310335
bigframes_df.to_gbq.assert_called_once_with(
311-
destination_table=bq_table,
336+
destination_table=_TEST_TARGET_BQ_TABLE,
312337
if_exists="replace",
313338
)
314339
expected_dataset = gca_dataset.Dataset(
315340
display_name=_TEST_DISPLAY_NAME,
316341
metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_MULTIMODAL,
317-
metadata={"inputConfig": {"bigquerySource": {"uri": f"bq://{bq_table}"}}},
342+
metadata={
343+
"inputConfig": {
344+
"bigquerySource": {"uri": f"bq://{_TEST_TARGET_BQ_TABLE}"}
345+
}
346+
},
318347
)
319348
create_dataset_mock.assert_called_once_with(
320349
dataset=expected_dataset,
321350
parent=_TEST_PARENT,
322351
timeout=None,
323352
)
324353

354+
@pytest.mark.skip(reason="flaky with other tests mocking bigframes")
@pytest.mark.usefixtures("bigframes_import_mock")
def test_create_dataset_from_bigframes_different_project_throws_error(self):
    """A target table in a different project must be rejected with ValueError."""
    aiplatform.init(project=_TEST_PROJECT)
    df_mock = mock.Mock()
    with pytest.raises(ValueError):
        ummd.MultimodalDataset.from_bigframes(
            dataframe=df_mock,
            target_table_id="another_project.dataset.table",
            display_name=_TEST_DISPLAY_NAME,
        )
365+
366+
@pytest.mark.skip(reason="flaky with other tests mocking bigframes")
@pytest.mark.usefixtures(
    "bigframes_import_mock", "get_bq_dataset_alternate_location_mock"
)
def test_create_dataset_from_bigframes_different_location_throws_error(self):
    """A BigQuery dataset in a different location must be rejected with ValueError."""
    aiplatform.init(project=_TEST_PROJECT)
    df_mock = mock.Mock()
    with pytest.raises(ValueError):
        ummd.MultimodalDataset.from_bigframes(
            dataframe=df_mock,
            target_table_id=_TEST_TARGET_BQ_TABLE,
            display_name=_TEST_DISPLAY_NAME,
        )
379+
380+
@pytest.mark.skip(reason="flaky with other tests mocking bigframes")
@pytest.mark.usefixtures("bigframes_import_mock")
def test_create_dataset_from_bigframes_invalid_target_table_id_throws_error(self):
    """A malformed target table id must be rejected with ValueError."""
    aiplatform.init(project=_TEST_PROJECT)
    df_mock = mock.Mock()
    with pytest.raises(ValueError):
        ummd.MultimodalDataset.from_bigframes(
            dataframe=df_mock,
            target_table_id="invalid-table",
            display_name=_TEST_DISPLAY_NAME,
        )
391+
325392
@pytest.mark.usefixtures("get_dataset_mock")
326393
def test_update_dataset(self, update_dataset_mock):
327394
aiplatform.init(project=_TEST_PROJECT)

0 commit comments

Comments (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy