Skip to content

Commit 77a741e

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Also support unhashable objects to be serialized with extra args
PiperOrigin-RevId: 577998940
1 parent 1e4a4ec commit 77a741e

File tree

7 files changed

+180
-21
lines changed

7 files changed

+180
-21
lines changed

tests/unit/vertexai/test_remote_training.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,9 @@ def test_remote_training_sklearn_with_remote_configs(
972972
_TEST_TRAINING_CONFIG_CONTAINER_URI
973973
)
974974
model.fit.vertex.remote_config.machine_type = _TEST_TRAINING_CONFIG_MACHINE_TYPE
975-
model.fit.vertex.remote_config.serializer_args = {model: {"extra_params": 1}}
975+
model.fit.vertex.remote_config.serializer_args[model] = {"extra_params": 1}
976+
# X_TRAIN is a numpy array that is not hashable.
977+
model.fit.vertex.remote_config.serializer_args[_X_TRAIN] = {"extra_params": 2}
976978

977979
model.fit(_X_TRAIN, _Y_TRAIN)
978980

@@ -991,7 +993,7 @@ def test_remote_training_sklearn_with_remote_configs(
991993
mock_any_serializer_sklearn.return_value.serialize.assert_any_call(
992994
to_serialize=_X_TRAIN,
993995
gcs_path=os.path.join(remote_job_base_path, "input/X"),
994-
**{},
996+
**{"extra_params": 2},
995997
)
996998
mock_any_serializer_sklearn.return_value.serialize.assert_any_call(
997999
to_serialize=_Y_TRAIN,
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2023 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
from vertexai.preview._workflow.serialization_engine import (
18+
serializers_base,
19+
)
20+
21+
22+
class TestSerializerArgs:
23+
def test_object_id_is_saved(self):
24+
class TestClass:
25+
pass
26+
27+
test_obj = TestClass()
28+
serializer_args = serializers_base.SerializerArgs({test_obj: {"a": 1, "b": 2}})
29+
assert id(test_obj) in serializer_args
30+
assert test_obj not in serializer_args
31+
32+
def test_getitem_support_original_object(self):
33+
class TestClass:
34+
pass
35+
36+
test_obj = TestClass()
37+
serializer_args = serializers_base.SerializerArgs({test_obj: {"a": 1, "b": 2}})
38+
assert serializer_args[test_obj] == {"a": 1, "b": 2}
39+
40+
def test_get_support_original_object(self):
41+
class TestClass:
42+
pass
43+
44+
test_obj = TestClass()
45+
serializer_args = serializers_base.SerializerArgs({test_obj: {"a": 1, "b": 2}})
46+
assert serializer_args.get(test_obj) == {"a": 1, "b": 2}
47+
48+
def test_unhashable_obj_saved_successfully(self):
49+
unhashable = [1, 2, 3]
50+
serializer_args = serializers_base.SerializerArgs()
51+
serializer_args[unhashable] = {"a": 1, "b": 2}
52+
assert id(unhashable) in serializer_args
53+
54+
def test_getitem_support_original_unhashable(self):
55+
unhashable = [1, 2, 3]
56+
serializer_args = serializers_base.SerializerArgs()
57+
serializer_args[unhashable] = {"a": 1, "b": 2}
58+
assert serializer_args[unhashable] == {"a": 1, "b": 2}
59+
60+
def test_get_support_original_unhashable(self):
61+
unhashable = [1, 2, 3]
62+
serializers_args = serializers_base.SerializerArgs()
63+
serializers_args[unhashable] = {"a": 1, "b": 2}
64+
assert serializers_args.get(unhashable) == {"a": 1, "b": 2}

vertexai/preview/_workflow/executor/training.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import re
2222
import sys
2323
import time
24-
from typing import Any, Dict, List, Optional, Set, Tuple, Union, Hashable
24+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
2525
import warnings
2626

2727
from google.api_core import exceptions as api_exceptions
@@ -495,6 +495,8 @@ def remote_training(invokable: shared._Invokable, rewrapper: Any):
495495
bound_args = invokable.bound_arguments
496496
config = invokable.vertex_config.remote_config
497497
serializer_args = invokable.vertex_config.remote_config.serializer_args
498+
if not isinstance(serializer_args, serializers_base.SerializerArgs):
499+
raise ValueError("serializer_args must be an instance of SerializerArgs.")
498500

499501
autolog = vertexai.preview.global_config.autolog
500502
service_account = _get_service_account(config, autolog=autolog)
@@ -609,17 +611,13 @@ def remote_training(invokable: shared._Invokable, rewrapper: Any):
609611
to_serialize=arg_value,
610612
gcs_path=os.path.join(remote_job_input_path, f"{arg_name}"),
611613
framework=detected_framework,
612-
**serializer_args.get(arg_value, {})
613-
if isinstance(arg_value, Hashable)
614-
else {},
614+
**serializer_args.get(arg_value, {}),
615615
)
616616
else:
617617
serialization_metadata = serializer.serialize(
618618
to_serialize=arg_value,
619619
gcs_path=os.path.join(remote_job_input_path, f"{arg_name}"),
620-
**serializer_args.get(arg_value, {})
621-
if isinstance(arg_value, Hashable)
622-
else {},
620+
**serializer_args.get(arg_value, {}),
623621
)
624622
# serializer.get_dependencies() must be run after serializer.serialize()
625623
requirements += serialization_metadata[

vertexai/preview/_workflow/serialization_engine/serializers_base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
2626

2727
from google.cloud.aiplatform.utils import gcs_utils
28-
28+
from vertexai.preview._workflow.shared import data_structures
2929

3030
T = TypeVar("T")
3131
SERIALIZATION_METADATA_FILENAME = "serialization_metadata"
@@ -34,6 +34,9 @@
3434
SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY = "custom_commands"
3535

3636

37+
SerializerArgs = data_structures.IdAsKeyDict
38+
39+
3740
@dataclasses.dataclass
3841
class SerializationMetadata:
3942
"""Metadata of Serializer classes.

vertexai/preview/_workflow/shared/configs.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
# limitations under the License.
1616
#
1717
import dataclasses
18-
from typing import List, Optional, Dict, Any
18+
from typing import List, Optional
19+
from vertexai.preview._workflow.serialization_engine import (
20+
serializers_base,
21+
)
1922

2023

2124
@dataclasses.dataclass
@@ -72,16 +75,33 @@ class RemoteConfig(_BaseConfig):
7275
]
7376
7477
# Specify the extra parameters needed for serializing objects.
75-
model.train.vertex.remote_config.serializer_args = {
76-
model: {
77-
"extra_serializer_param1_for_model": param1_value,
78-
"extra_serializer_param2_for_model": param2_value,
78+
from vertexai.preview.developer import SerializerArgs
79+
80+
# You can put all the hashable objects with their arguments in the
81+
# SerializerArgs all at once in a dict. Here we assume "model" is
82+
# hashable.
83+
model.train.vertex.remote_config.serializer_args = SerializerArgs({
84+
model: {
85+
"extra_serializer_param1_for_model": param1_value,
86+
"extra_serializer_param2_for_model": param2_value,
87+
},
88+
hashable_obj2: {
89+
"extra_serializer_param1_for_hashable2": param1_value,
90+
"extra_serializer_param2_for_hashable2": param2_value,
91+
},
92+
})
93+
# Or if the object to be serialized is unhashable, put them into the
94+
# serializer_args one by one. If this is the only use case, there is
95+
# no need to import `SerializerArgs`. Here we assume "X_train" and
96+
# "y_train" is not hashable.
97+
model.train.vertex.remote_config.serializer_args[X_train] = {
98+
"extra_serializer_param1_for_X_train": param1_value,
99+
"extra_serializer_param2_for_X_train": param2_value,
79100
},
80-
X_train: {
81-
"extra_serializer_param1": param1_value,
82-
"extra_serializer_param2": param2_value,
101+
model.train.vertex.remote_config.serializer_args[y_train] = {
102+
"extra_serializer_param1_for_y_train": param1_value,
103+
"extra_serializer_param2_for_y_train": param2_value,
83104
}
84-
}
85105
86106
# Train the model as usual
87107
model.train(X_train, y_train)
@@ -132,7 +152,7 @@ class RemoteConfig(_BaseConfig):
132152
custom_commands (List[str]):
133153
List of custom commands to be run in the remote job environment.
134154
These commands will be run before the requirements are installed.
135-
serializer_args (Dict[Any, Dict[str, Any]]):
155+
serializer_args: serializers_base.SerializerArgs:
136156
Map from object to extra arguments when serializing the object. The extra
137157
arguments is a dictionary from the argument names to the argument values.
138158
"""
@@ -143,7 +163,9 @@ class RemoteConfig(_BaseConfig):
143163
service_account: Optional[str] = None
144164
requirements: List[str] = dataclasses.field(default_factory=list)
145165
custom_commands: List[str] = dataclasses.field(default_factory=list)
146-
serializer_args: Dict[Any, Dict[str, Any]] = dataclasses.field(default_factory=dict)
166+
serializer_args: serializers_base.SerializerArgs = dataclasses.field(
167+
default_factory=serializers_base.SerializerArgs
168+
)
147169

148170

149171
@dataclasses.dataclass
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2023 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
19+
class IdAsKeyDict(dict):
20+
"""Customized dict that maps each key to its id before storing the data.
21+
22+
This subclass of dict still allows one to use the original key during
23+
subscription ([] operator) or via `get()` method. But under the hood, the
24+
keys are the ids of the original keys.
25+
26+
Example:
27+
# add some hashable objects (key1 and key2) to the dict
28+
id_as_key_dict = IdAsKeyDict({key1: value1, key2: value2})
29+
# add a unhashable object (key3) to the dict
30+
id_as_key_dict[key3] = value3
31+
32+
# can access the value via subscription using the original key
33+
assert id_as_key_dict[key1] == value1
34+
assert id_as_key_dict[key2] == value2
35+
assert id_as_key_dict[key3] == value3
36+
# can access the value via get method using the original key
37+
assert id_as_key_dict.get(key1) == value1
38+
assert id_as_key_dict.get(key2) == value2
39+
assert id_as_key_dict.get(key3) == value3
40+
# but the original keys are not in the dict - the ids are
41+
assert id(key1) in id_as_key_dict
42+
assert id(key2) in id_as_key_dict
43+
assert id(key3) in id_as_key_dict
44+
assert key1 not in id_as_key_dict
45+
assert key2 not in id_as_key_dict
46+
assert key3 not in id_as_key_dict
47+
"""
48+
49+
def __init__(self, *args, **kwargs):
50+
internal_dict = {}
51+
for arg in args:
52+
for k, v in arg.items():
53+
internal_dict[id(k)] = v
54+
for k, v in kwargs.items():
55+
internal_dict[id(k)] = v
56+
super().__init__(internal_dict)
57+
58+
def __getitem__(self, _key):
59+
internal_key = id(_key)
60+
return super().__getitem__(internal_key)
61+
62+
def __setitem__(self, _key, _value):
63+
internal_key = id(_key)
64+
return super().__setitem__(internal_key, _value)
65+
66+
def get(self, key, default=None):
67+
internal_key = id(key)
68+
return super().get(internal_key, default)

vertexai/preview/developer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
PersistentResourceConfig = configs.PersistentResourceConfig
3030
Serializer = serializers_base.Serializer
3131
SerializationMetadata = serializers_base.SerializationMetadata
32+
SerializerArgs = serializers_base.SerializerArgs
3233
RemoteConfig = configs.RemoteConfig
3334
WorkerPoolSpec = remote_specs.WorkerPoolSpec
3435
WorkerPoolSepcs = remote_specs.WorkerPoolSpecs
@@ -41,6 +42,7 @@
4142
"PersistentResourceConfig",
4243
"register_serializer",
4344
"Serializer",
45+
"SerializerArgs",
4446
"SerializationMetadata",
4547
"RemoteConfig",
4648
"WorkerPoolSpec",

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy