Content-Length: 636343 | pFad | https://github.com/googleapis/python-aiplatform/commit/e3eb82f59d3f28dfedd71b9e69a0e967a01eada5

EE feat: support dataset update (#1416) · googleapis/python-aiplatform@e3eb82f · GitHub
Skip to content

Commit e3eb82f

Browse files
authored
feat: support dataset update (#1416)
* feat: add update() method and system test * fix: fix and add unit test * remove superfluous line in system test
1 parent b91db66 commit e3eb82f

File tree

3 files changed

+137
-4
lines changed

3 files changed

+137
-4
lines changed

google/cloud/aiplatform/datasets/dataset.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2020 Google LLC
3+
# Copyright 2022 Google LLC
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -31,6 +31,7 @@
3131
io as gca_io,
3232
)
3333
from google.cloud.aiplatform.datasets import _datasources
34+
from google.protobuf import field_mask_pb2
3435

3536
_LOGGER = base.Logger(__name__)
3637

@@ -597,8 +598,69 @@ def export_data(self, output_dir: str) -> Sequence[str]:
597598

598599
return export_data_response.exported_files
599600

600-
def update(self):
601-
raise NotImplementedError("Update dataset has not been implemented yet")
601+
def update(
602+
self,
603+
*,
604+
display_name: Optional[str] = None,
605+
labels: Optional[Dict[str, str]] = None,
606+
description: Optional[str] = None,
607+
update_request_timeout: Optional[float] = None,
608+
) -> "_Dataset":
609+
"""Update the dataset.
610+
Updatable fields:
611+
- ``display_name``
612+
- ``description``
613+
- ``labels``
614+
615+
Args:
616+
display_name (str):
617+
Optional. The user-defined name of the Dataset.
618+
The name can be up to 128 characters long and can be consist
619+
of any UTF-8 characters.
620+
labels (Dict[str, str]):
621+
Optional. Labels with user-defined metadata to organize your Tensorboards.
622+
Label keys and values can be no longer than 64 characters
623+
(Unicode codepoints), can only contain lowercase letters, numeric
624+
characters, underscores and dashes. International characters are allowed.
625+
No more than 64 user labels can be associated with one Tensorboard
626+
(System labels are excluded).
627+
See https://goo.gl/xmQnxf for more information and examples of labels.
628+
System reserved label keys are prefixed with "aiplatform.googleapis.com/"
629+
and are immutable.
630+
description (str):
631+
Optional. The description of the Dataset.
632+
update_request_timeout (float):
633+
Optional. The timeout for the update request in seconds.
634+
635+
Returns:
636+
dataset (Dataset):
637+
Updated dataset.
638+
"""
639+
640+
update_mask = field_mask_pb2.FieldMask()
641+
if display_name:
642+
update_mask.paths.append("display_name")
643+
644+
if labels:
645+
update_mask.paths.append("labels")
646+
647+
if description:
648+
update_mask.paths.append("description")
649+
650+
update_dataset = gca_dataset.Dataset(
651+
name=self.resource_name,
652+
display_name=display_name,
653+
description=description,
654+
labels=labels,
655+
)
656+
657+
self._gca_resource = self.api_client.update_dataset(
658+
dataset=update_dataset,
659+
update_mask=update_mask,
660+
timeout=update_request_timeout,
661+
)
662+
663+
return self
602664

603665
@classmethod
604666
def list(

tests/system/aiplatform/test_dataset.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2020 Google LLC
3+
# Copyright 2022 Google LLC
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -50,6 +50,8 @@
5050
"6203215905493614592" # permanent_text_entity_extraction_dataset
5151
)
5252
_TEST_DATASET_DISPLAY_NAME = "permanent_50_flowers_dataset"
53+
_TEST_DATASET_LABELS = {"test": "labels"}
54+
_TEST_DATASET_DESCRIPTION = "test description"
5355
_TEST_TABULAR_CLASSIFICATION_GCS_SOURCE = "gs://ucaip-sample-resources/iris_1000.csv"
5456
_TEST_FORECASTING_BQ_SOURCE = (
5557
"bq://ucaip-sample-tests:ucaip_test_us_central1.2020_sales_train"
@@ -350,3 +352,26 @@ def test_export_data(self, storage_client, staging_bucket):
350352
blob = bucket.get_blob(prefix)
351353

352354
assert blob # Verify the returned GCS export path exists
355+
356+
def test_update_dataset(self):
357+
"""Create a new dataset and use update() method to change its display_name, labels, and description.
358+
Then confirm these fields of the dataset was successfully modifed."""
359+
360+
try:
361+
dataset = aiplatform.ImageDataset.create()
362+
labels = dataset.labels
363+
364+
dataset = dataset.update(
365+
display_name=_TEST_DATASET_DISPLAY_NAME,
366+
labels=_TEST_DATASET_LABELS,
367+
description=_TEST_DATASET_DESCRIPTION,
368+
update_request_timeout=None,
369+
)
370+
labels.update(_TEST_DATASET_LABELS)
371+
372+
assert dataset.display_name == _TEST_DATASET_DISPLAY_NAME
373+
assert dataset.labels == labels
374+
assert dataset.gca_resource.description == _TEST_DATASET_DESCRIPTION
375+
376+
finally:
377+
dataset.delete()

tests/unit/aiplatform/test_datasets.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from google.cloud.aiplatform import schema
3737
from google.cloud import bigquery
3838
from google.cloud import storage
39+
from google.protobuf import field_mask_pb2
3940

4041
from google.cloud.aiplatform.compat.services import dataset_service_client
4142

@@ -59,6 +60,7 @@
5960
_TEST_ID = "1028944691210842416"
6061
_TEST_DISPLAY_NAME = "my_dataset_1234"
6162
_TEST_DATA_LABEL_ITEMS = None
63+
_TEST_DESCRIPTION = "test description"
6264

6365
_TEST_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/datasets/{_TEST_ID}"
6466
_TEST_ALT_NAME = (
@@ -425,6 +427,20 @@ def export_data_mock():
425427
yield export_data_mock
426428

427429

430+
@pytest.fixture
431+
def update_dataset_mock():
432+
with patch.object(
433+
dataset_service_client.DatasetServiceClient, "update_dataset"
434+
) as update_dataset_mock:
435+
update_dataset_mock.return_value = gca_dataset.Dataset(
436+
name=_TEST_NAME,
437+
display_name=f"update_{_TEST_DISPLAY_NAME}",
438+
labels=_TEST_LABELS,
439+
description=_TEST_DESCRIPTION,
440+
)
441+
yield update_dataset_mock
442+
443+
428444
@pytest.fixture
429445
def list_datasets_mock():
430446
with patch.object(
@@ -996,6 +1012,36 @@ def test_delete_dataset(self, delete_dataset_mock, sync):
9961012

9971013
delete_dataset_mock.assert_called_once_with(name=my_dataset.resource_name)
9981014

1015+
@pytest.mark.usefixtures("get_dataset_mock")
1016+
def test_update_dataset(self, update_dataset_mock):
1017+
aiplatform.init(project=_TEST_PROJECT)
1018+
1019+
my_dataset = datasets._Dataset(dataset_name=_TEST_NAME)
1020+
1021+
my_dataset = my_dataset.update(
1022+
display_name=f"update_{_TEST_DISPLAY_NAME}",
1023+
labels=_TEST_LABELS,
1024+
description=_TEST_DESCRIPTION,
1025+
update_request_timeout=None,
1026+
)
1027+
1028+
expected_dataset = gca_dataset.Dataset(
1029+
name=_TEST_NAME,
1030+
display_name=f"update_{_TEST_DISPLAY_NAME}",
1031+
labels=_TEST_LABELS,
1032+
description=_TEST_DESCRIPTION,
1033+
)
1034+
1035+
expected_mask = field_mask_pb2.FieldMask(
1036+
paths=["display_name", "labels", "description"]
1037+
)
1038+
1039+
update_dataset_mock.assert_called_once_with(
1040+
dataset=expected_dataset,
1041+
update_mask=expected_mask,
1042+
timeout=None,
1043+
)
1044+
9991045

10001046
@pytest.mark.usefixtures("google_auth_mock")
10011047
class TestImageDataset:

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: https://github.com/googleapis/python-aiplatform/commit/e3eb82f59d3f28dfedd71b9e69a0e967a01eada5

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy