Skip to content

Commit 2a08535

Browse files
Ark-kun authored and copybara-github committed
fix: LLM - Fixed batch prediction on tuned models
PiperOrigin-RevId: 560910428
1 parent 2e3090b commit 2a08535

File tree

2 files changed

+34 -9 lines changed

2 files changed

+34
-9
lines changed

tests/unit/aiplatform/test_language_models.py

Lines changed: 34 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@
4545
artifact as gca_artifact,
4646
prediction_service as gca_prediction_service,
4747
context as gca_context,
48-
endpoint as gca_endpoint,
48+
endpoint_v1 as gca_endpoint,
4949
pipeline_job as gca_pipeline_job,
5050
pipeline_state as gca_pipeline_state,
5151
deployed_model_ref_v1,
@@ -1030,6 +1030,11 @@ def get_endpoint_mock():
10301030
get_endpoint_mock.return_value = gca_endpoint.Endpoint(
10311031
display_name="test-display-name",
10321032
name=test_constants.EndpointConstants._TEST_ENDPOINT_NAME,
1033+
deployed_models=[
1034+
gca_endpoint.DeployedModel(
1035+
model=test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
1036+
),
1037+
],
10331038
)
10341039
yield get_endpoint_mock
10351040

@@ -2420,7 +2425,10 @@ def test_text_embedding_ga(self):
24202425
assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH
24212426
assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"]
24222427

2423-
def test_batch_prediction(self):
2428+
def test_batch_prediction(
2429+
self,
2430+
get_endpoint_mock,
2431+
):
24242432
"""Tests batch prediction."""
24252433
aiplatform.init(
24262434
project=_TEST_PROJECT,
@@ -2447,7 +2455,29 @@ def test_batch_prediction(self):
24472455
model_parameters={"temperature": 0.1},
24482456
)
24492457
mock_create.assert_called_once_with(
2450-
model_name="publishers/google/models/text-bison@001",
2458+
model_name=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/text-bison@001",
2459+
job_display_name=None,
2460+
gcs_source="gs://test-bucket/test_table.jsonl",
2461+
gcs_destination_prefix="gs://test-bucket/results/",
2462+
model_parameters={"temperature": 0.1},
2463+
)
2464+
2465+
# Testing tuned model batch prediction
2466+
tuned_model = language_models.TextGenerationModel(
2467+
model_id=model._model_id,
2468+
endpoint_name=test_constants.EndpointConstants._TEST_ENDPOINT_NAME,
2469+
)
2470+
with mock.patch.object(
2471+
target=aiplatform.BatchPredictionJob,
2472+
attribute="create",
2473+
) as mock_create:
2474+
tuned_model.batch_predict(
2475+
dataset="gs://test-bucket/test_table.jsonl",
2476+
destination_uri_prefix="gs://test-bucket/results/",
2477+
model_parameters={"temperature": 0.1},
2478+
)
2479+
mock_create.assert_called_once_with(
2480+
model_name=test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME,
24512481
job_display_name=None,
24522482
gcs_source="gs://test-bucket/test_table.jsonl",
24532483
gcs_destination_prefix="gs://test-bucket/results/",
@@ -2481,7 +2511,7 @@ def test_batch_prediction_for_text_embedding(self):
24812511
model_parameters={},
24822512
)
24832513
mock_create.assert_called_once_with(
2484-
model_name="publishers/google/models/textembedding-gecko@001",
2514+
model_name=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/textembedding-gecko@001",
24852515
job_display_name=None,
24862516
gcs_source="gs://test-bucket/test_table.jsonl",
24872517
gcs_destination_prefix="gs://test-bucket/results/",

vertexai/language_models/_language_models.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -839,11 +839,6 @@ def batch_predict(
839839
raise ValueError(f"Unsupported destination_uri: {destination_uri_prefix}")
840840

841841
model_name = self._model_resource_name
842-
# TODO(b/284512065): Batch prediction service does not support
843-
# fully qualified publisher model names yet
844-
publishers_index = model_name.index("/publishers/")
845-
if publishers_index > 0:
846-
model_name = model_name[publishers_index + 1 :]
847842

848843
job = aiplatform.BatchPredictionJob.create(
849844
model_name=model_name,

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy