
Commit 7dc8771

matthew29tang authored and copybara-github committed
feat: Support custom batch size for Bigframes Tensorflow
PiperOrigin-RevId: 589190954
1 parent 0cb1a7b commit 7dc8771
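With this change, callers can override the previously hardcoded TensorFlow batch size of 32 by attaching per-object serializer options before launching remote training. A minimal usage sketch, mirroring the system test updated in this commit (it assumes `model` is a Keras model wrapped for Vertex remote execution and `train` is a bigframes DataFrame, as in that test):

# Batch the deserialized tf.data.Dataset by 10 instead of the default 32.
model.fit.vertex.remote_config.serializer_args[train] = {"batch_size": 10}

# Train remotely on Vertex; the BigframeSerializer picks up the option.
model.fit(train, epochs=10)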

File tree: 3 files changed (+16, -8 lines)

tests/system/vertexai/test_bigframes_tensorflow.py
Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,6 @@
     "prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources"
 )
 class TestRemoteExecutionBigframesTensorflow(e2e_base.TestEndToEnd):
-
     _temp_prefix = "temp-vertexai-remote-execution"

     def test_remote_execution_keras(self, shared_state):
@@ -97,6 +96,7 @@ def test_remote_execution_keras(self, shared_state):
             enable_cuda=True,
             display_name=self._make_display_name("bigframes-keras-training"),
         )
+        model.fit.vertex.remote_config.serializer_args[train] = {"batch_size": 10}

         # Train model on Vertex
         model.fit(train, epochs=10)
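Note that `serializer_args` is keyed by the object to be serialized (here the `train` DataFrame itself, not a string parameter name), so each input to the remote call can carry its own options. A hedged sketch of that pattern (the `validation` DataFrame and its entry are hypothetical):

# Hypothetical second input with its own batch size, assuming serializer_args
# accepts one entry per serialized object.
model.fit.vertex.remote_config.serializer_args[validation] = {"batch_size": 64}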

tests/unit/vertexai/test_any_serializer.py
Lines changed: 1 addition & 0 deletions

@@ -1105,6 +1105,7 @@ def test_any_serializer_deserialize_bigframe_tensorflow(
         mock_bigframe_deserialize_tensorflow.assert_called_once_with(
             any_serializer_instance._instances[serializers.BigframeSerializer],
             serialized_gcs_path=fake_gcs_path,
+            batch_size=None,
         )

     def test_any_serializer_deserialize_tf_dataset(
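The new `batch_size=None` in the assertion pins down the pass-through behavior: when the caller supplies no serializer args, `AnySerializer` forwards `None` and the TensorFlow path substitutes the module default. A condensed sketch of that plumbing, simplified from the serializers.py diff below (standalone functions instead of methods):

DEFAULT_TENSORFLOW_BATCHSIZE = 32  # module-level default added by this commit

def deserialize(serialized_gcs_path, **kwargs):
    # kwargs.get("batch_size") is None when unset, matching the
    # batch_size=None asserted in the unit test above.
    return _deserialize_tensorflow(serialized_gcs_path, kwargs.get("batch_size"))

def _deserialize_tensorflow(serialized_gcs_path, batch_size=None):
    # Fall back to the default when the caller passed None.
    batch_size = batch_size or DEFAULT_TENSORFLOW_BATCHSIZE
    return batch_size  # the real method returns a batched tf.data.Dataset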

vertexai/preview/_workflow/serialization_engine/serializers.py
Lines changed: 14 additions & 7 deletions

@@ -90,6 +90,7 @@
     "2.12": "0.32.0",
     "2.13": "0.34.0",  # TODO(b/295580335): Support TF 2.13
 }
+DEFAULT_TENSORFLOW_BATCHSIZE = 32


 def get_uri_prefix(gcs_uri: str) -> str:
@@ -1174,7 +1175,9 @@ def serialize(
         # Convert bigframes.dataframe.DataFrame to Parquet (GCS)
         parquet_gcs_path = gcs_path + "/*"  # path is required to contain '*'
         to_serialize.to_parquet(parquet_gcs_path, index=True)
-        return parquet_gcs_path
+
+        # Return original gcs_path to retrieve the metadata for later
+        return gcs_path

     def _get_tfio_verison(self):
         major, minor, _ = version.Version(tf.__version__).release
@@ -1190,15 +1193,15 @@ def _get_tfio_verison(self):
     def deserialize(
         self, serialized_gcs_path: str, **kwargs
     ) -> Union["pandas.DataFrame", "bigframes.dataframe.DataFrame"]:  # noqa: F821
-        del kwargs
-
         detected_framework = BigframeSerializer._metadata.framework
         if detected_framework == "sklearn":
             return self._deserialize_sklearn(serialized_gcs_path)
         elif detected_framework == "torch":
             return self._deserialize_torch(serialized_gcs_path)
         elif detected_framework == "tensorflow":
-            return self._deserialize_tensorflow(serialized_gcs_path)
+            return self._deserialize_tensorflow(
+                serialized_gcs_path, kwargs.get("batch_size")
+            )
         else:
             raise ValueError(f"Unsupported framework: {detected_framework}")

@@ -1269,11 +1272,16 @@ def reduce_tensors(a, b):

         return functools.reduce(reduce_tensors, list(parquet_df_dp))

-    def _deserialize_tensorflow(self, serialized_gcs_path: str) -> TFDataset:
+    def _deserialize_tensorflow(
+        self, serialized_gcs_path: str, batch_size: Optional[int] = None
+    ) -> TFDataset:
         """Tensorflow deserializes parquet (GCS) --> tf.data.Dataset

         serialized_gcs_path is a folder containing one or more parquet files.
         """
+        # Set default batch_size
+        batch_size = batch_size or DEFAULT_TENSORFLOW_BATCHSIZE
+
         # Deserialization at remote environment
         try:
             import tensorflow_io as tfio
@@ -1307,8 +1315,7 @@ def reduce_fn(a, b):

             return functools.reduce(reduce_fn, row.values()), target

-        # TODO(b/295535730): Remove hardcoded batch_size of 32
-        return ds.map(map_fn).batch(32)
+        return ds.map(map_fn).batch(batch_size)


 class CloudPickleSerializer(serializers_base.Serializer):
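For context, the TensorFlow branch reads the Parquet shards with tensorflow_io, maps each row to a (features, target) pair, and batches with the now-configurable batch_size. The following is a self-contained sketch under stated assumptions: it reads a single local Parquet file rather than resolving shards under serialized_gcs_path, takes the target column name as a parameter rather than recovering it from serialized metadata, packs features with tf.stack instead of the functools.reduce used in the diff, and assumes numeric columns of a common dtype:

import tensorflow as tf
import tensorflow_io as tfio  # only required in the remote environment

DEFAULT_TENSORFLOW_BATCHSIZE = 32

def parquet_to_tf_dataset(parquet_file, target_col, batch_size=None):
    # Fall back to the old hardcoded behavior (batches of 32) when unset.
    batch_size = batch_size or DEFAULT_TENSORFLOW_BATCHSIZE
    ds = tfio.IODataset.from_parquet(parquet_file)

    def map_fn(row):
        # row maps column names to scalar tensors; split features from target.
        # Depending on the schema, from_parquet may key columns as bytes.
        target = row.pop(target_col)
        features = tf.stack(list(row.values()))
        return features, target

    return ds.map(map_fn).batch(batch_size)

With batch_size=None this reproduces the pre-change behavior; the system test above exercises the override with {"batch_size": 10}.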

0 commit comments