Skip to content

Commit e3eb82f

Browse files
authored
feat: support dataset update (#1416)
* feat: add update() method and system test * fix: fix and add unit test * remove superfluous line in system test
1 parent b91db66 commit e3eb82f

File tree

3 files changed

+137
-4
lines changed

3 files changed

+137
-4
lines changed

google/cloud/aiplatform/datasets/dataset.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2020 Google LLC
3+
# Copyright 2022 Google LLC
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -31,6 +31,7 @@
3131
io as gca_io,
3232
)
3333
from google.cloud.aiplatform.datasets import _datasources
34+
from google.protobuf import field_mask_pb2
3435

3536
_LOGGER = base.Logger(__name__)
3637

@@ -597,8 +598,69 @@ def export_data(self, output_dir: str) -> Sequence[str]:
597598

598599
return export_data_response.exported_files
599600

600-
def update(self):
601-
raise NotImplementedError("Update dataset has not been implemented yet")
601+
def update(
602+
self,
603+
*,
604+
display_name: Optional[str] = None,
605+
labels: Optional[Dict[str, str]] = None,
606+
description: Optional[str] = None,
607+
update_request_timeout: Optional[float] = None,
608+
) -> "_Dataset":
609+
"""Update the dataset.
610+
Updatable fields:
611+
- ``display_name``
612+
- ``description``
613+
- ``labels``
614+
615+
Args:
616+
display_name (str):
617+
Optional. The user-defined name of the Dataset.
618+
The name can be up to 128 characters long and can be consist
619+
of any UTF-8 characters.
620+
labels (Dict[str, str]):
621+
Optional. Labels with user-defined metadata to organize your Tensorboards.
622+
Label keys and values can be no longer than 64 characters
623+
(Unicode codepoints), can only contain lowercase letters, numeric
624+
characters, underscores and dashes. International characters are allowed.
625+
No more than 64 user labels can be associated with one Tensorboard
626+
(System labels are excluded).
627+
See https://goo.gl/xmQnxf for more information and examples of labels.
628+
System reserved label keys are prefixed with "aiplatform.googleapis.com/"
629+
and are immutable.
630+
description (str):
631+
Optional. The description of the Dataset.
632+
update_request_timeout (float):
633+
Optional. The timeout for the update request in seconds.
634+
635+
Returns:
636+
dataset (Dataset):
637+
Updated dataset.
638+
"""
639+
640+
update_mask = field_mask_pb2.FieldMask()
641+
if display_name:
642+
update_mask.paths.append("display_name")
643+
644+
if labels:
645+
update_mask.paths.append("labels")
646+
647+
if description:
648+
update_mask.paths.append("description")
649+
650+
update_dataset = gca_dataset.Dataset(
651+
name=self.resource_name,
652+
display_name=display_name,
653+
description=description,
654+
labels=labels,
655+
)
656+
657+
self._gca_resource = self.api_client.update_dataset(
658+
dataset=update_dataset,
659+
update_mask=update_mask,
660+
timeout=update_request_timeout,
661+
)
662+
663+
return self
602664

603665
@classmethod
604666
def list(

tests/system/aiplatform/test_dataset.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2020 Google LLC
3+
# Copyright 2022 Google LLC
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -50,6 +50,8 @@
5050
"6203215905493614592" # permanent_text_entity_extraction_dataset
5151
)
5252
_TEST_DATASET_DISPLAY_NAME = "permanent_50_flowers_dataset"
53+
_TEST_DATASET_LABELS = {"test": "labels"}
54+
_TEST_DATASET_DESCRIPTION = "test description"
5355
_TEST_TABULAR_CLASSIFICATION_GCS_SOURCE = "gs://ucaip-sample-resources/iris_1000.csv"
5456
_TEST_FORECASTING_BQ_SOURCE = (
5557
"bq://ucaip-sample-tests:ucaip_test_us_central1.2020_sales_train"
@@ -350,3 +352,26 @@ def test_export_data(self, storage_client, staging_bucket):
350352
blob = bucket.get_blob(prefix)
351353

352354
assert blob # Verify the returned GCS export path exists
355+
356+
def test_update_dataset(self):
357+
"""Create a new dataset and use update() method to change its display_name, labels, and description.
358+
Then confirm these fields of the dataset was successfully modifed."""
359+
360+
try:
361+
dataset = aiplatform.ImageDataset.create()
362+
labels = dataset.labels
363+
364+
dataset = dataset.update(
365+
display_name=_TEST_DATASET_DISPLAY_NAME,
366+
labels=_TEST_DATASET_LABELS,
367+
description=_TEST_DATASET_DESCRIPTION,
368+
update_request_timeout=None,
369+
)
370+
labels.update(_TEST_DATASET_LABELS)
371+
372+
assert dataset.display_name == _TEST_DATASET_DISPLAY_NAME
373+
assert dataset.labels == labels
374+
assert dataset.gca_resource.description == _TEST_DATASET_DESCRIPTION
375+
376+
finally:
377+
dataset.delete()

tests/unit/aiplatform/test_datasets.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from google.cloud.aiplatform import schema
3737
from google.cloud import bigquery
3838
from google.cloud import storage
39+
from google.protobuf import field_mask_pb2
3940

4041
from google.cloud.aiplatform.compat.services import dataset_service_client
4142

@@ -59,6 +60,7 @@
5960
_TEST_ID = "1028944691210842416"
6061
_TEST_DISPLAY_NAME = "my_dataset_1234"
6162
_TEST_DATA_LABEL_ITEMS = None
63+
_TEST_DESCRIPTION = "test description"
6264

6365
_TEST_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/datasets/{_TEST_ID}"
6466
_TEST_ALT_NAME = (
@@ -425,6 +427,20 @@ def export_data_mock():
425427
yield export_data_mock
426428

427429

430+
@pytest.fixture
431+
def update_dataset_mock():
432+
with patch.object(
433+
dataset_service_client.DatasetServiceClient, "update_dataset"
434+
) as update_dataset_mock:
435+
update_dataset_mock.return_value = gca_dataset.Dataset(
436+
name=_TEST_NAME,
437+
display_name=f"update_{_TEST_DISPLAY_NAME}",
438+
labels=_TEST_LABELS,
439+
description=_TEST_DESCRIPTION,
440+
)
441+
yield update_dataset_mock
442+
443+
428444
@pytest.fixture
429445
def list_datasets_mock():
430446
with patch.object(
@@ -996,6 +1012,36 @@ def test_delete_dataset(self, delete_dataset_mock, sync):
9961012

9971013
delete_dataset_mock.assert_called_once_with(name=my_dataset.resource_name)
9981014

1015+
@pytest.mark.usefixtures("get_dataset_mock")
1016+
def test_update_dataset(self, update_dataset_mock):
1017+
aiplatform.init(project=_TEST_PROJECT)
1018+
1019+
my_dataset = datasets._Dataset(dataset_name=_TEST_NAME)
1020+
1021+
my_dataset = my_dataset.update(
1022+
display_name=f"update_{_TEST_DISPLAY_NAME}",
1023+
labels=_TEST_LABELS,
1024+
description=_TEST_DESCRIPTION,
1025+
update_request_timeout=None,
1026+
)
1027+
1028+
expected_dataset = gca_dataset.Dataset(
1029+
name=_TEST_NAME,
1030+
display_name=f"update_{_TEST_DISPLAY_NAME}",
1031+
labels=_TEST_LABELS,
1032+
description=_TEST_DESCRIPTION,
1033+
)
1034+
1035+
expected_mask = field_mask_pb2.FieldMask(
1036+
paths=["display_name", "labels", "description"]
1037+
)
1038+
1039+
update_dataset_mock.assert_called_once_with(
1040+
dataset=expected_dataset,
1041+
update_mask=expected_mask,
1042+
timeout=None,
1043+
)
1044+
9991045

10001046
@pytest.mark.usefixtures("google_auth_mock")
10011047
class TestImageDataset:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy