Content-Length: 721524 | pFad | https://github.com/googleapis/python-aiplatform/commit/8bf30b74828c976e315879d9a7b61cb718e1bcfe

feat: GenAI - Support batch prediction in Model Garden OpenModel. · googleapis/python-aiplatform@8bf30b7 · GitHub
Skip to content

Commit 8bf30b7

Browse files
vertex-sdk-botcopybara-github
authored and committed
feat: GenAI - Support batch prediction in Model Garden OpenModel.
PiperOrigin-RevId: 752953157
1 parent e0a54df commit 8bf30b7

File tree

2 files changed

+162
-1
lines changed

2 files changed

+162
-1
lines changed

tests/unit/vertexai/model_garden/test_model_garden.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,25 @@
2020
from google.api_core import operation as ga_operation
2121
from google.auth import credentials as auth_credentials
2222
from google.cloud import aiplatform
23+
from google.cloud.aiplatform.compat.services import job_service_client
24+
from google.cloud.aiplatform.compat.types import (
25+
batch_prediction_job as gca_batch_prediction_job_compat,
26+
)
27+
from google.cloud.aiplatform.compat.types import io as gca_io_compat
28+
from google.cloud.aiplatform.compat.types import (
29+
job_state as gca_job_state_compat,
30+
)
31+
from google.cloud.aiplatform_v1.types import machine_resources
32+
from google.cloud.aiplatform_v1.types import manual_batch_tuning_parameters
2333
from google.cloud.aiplatform_v1beta1 import types
2434
from google.cloud.aiplatform_v1beta1.services import model_garden_service
35+
from vertexai import batch_prediction
2536
from vertexai.preview import model_garden
2637
import pytest
2738

2839
from google.protobuf import duration_pb2
2940

41+
3042
_TEST_PROJECT = "test-project"
3143
_TEST_LOCATION = "us-central1"
3244

@@ -73,6 +85,24 @@
7385
timeout_seconds=10,
7486
),
7587
)
88+
_TEST_BATCH_PREDICTION_JOB_ID = "123456789"
89+
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
90+
_TEST_BATCH_PREDICTION_JOB_NAME = (
91+
f"{_TEST_PARENT}/batchPredictionJobs/{_TEST_BATCH_PREDICTION_JOB_ID}"
92+
)
93+
_TEST_BATCH_PREDICTION_MODEL_FULL_RESOURCE_NAME = (
94+
"publishers/google/models/gemma@gemma-2b-it"
95+
)
96+
_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME = "test-batch-prediction-job"
97+
_TEST_JOB_STATE_RUNNING = gca_job_state_compat.JobState(3)
98+
_TEST_GAPIC_BATCH_PREDICTION_JOB = gca_batch_prediction_job_compat.BatchPredictionJob(
99+
name=_TEST_BATCH_PREDICTION_JOB_NAME,
100+
display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
101+
model=_TEST_BATCH_PREDICTION_MODEL_FULL_RESOURCE_NAME,
102+
state=_TEST_JOB_STATE_RUNNING,
103+
)
104+
_TEST_BQ_INPUT_URI = "bq://test-project.test-dataset.test-input"
105+
_TEST_BQ_OUTPUT_PREFIX = "bq://test-project.test-dataset.test-output"
76106

77107

78108
@pytest.fixture(scope="module")
@@ -117,6 +147,25 @@ def deploy_mock():
117147
yield deploy
118148

119149

150+
@pytest.fixture
def batch_prediction_mock():
    """Yields a patched ``JobServiceClient.create_batch_prediction_job``.

    The patched method returns the canned ``_TEST_GAPIC_BATCH_PREDICTION_JOB``
    proto so tests can assert on the request without hitting the API.
    """
    patcher = mock.patch.object(
        job_service_client.JobServiceClient, "create_batch_prediction_job"
    )
    with patcher as mocked_create:
        mocked_create.return_value = _TEST_GAPIC_BATCH_PREDICTION_JOB
        yield mocked_create
158+
159+
160+
@pytest.fixture
def complete_bq_uri_mock():
    """Stubs ``BatchPredictionJob._complete_bq_uri`` with a fixed output prefix.

    Keeps the generated BigQuery destination deterministic
    (``_TEST_BQ_OUTPUT_PREFIX``) instead of a timestamped table name.
    """
    patcher = mock.patch.object(
        batch_prediction.BatchPredictionJob, "_complete_bq_uri"
    )
    with patcher as uri_mock:
        uri_mock.return_value = _TEST_BQ_OUTPUT_PREFIX
        yield uri_mock
167+
168+
120169
@pytest.fixture
121170
def get_publisher_model_mock():
122171
with mock.patch.object(
@@ -355,6 +404,8 @@ def list_publisher_models_mock():
355404
"get_publisher_model_mock",
356405
"list_publisher_models_mock",
357406
"export_publisher_model_mock",
407+
"batch_prediction_mock",
408+
"complete_bq_uri_mock",
358409
)
359410
class TestModelGarden:
360411
"""Test cases for ModelGarden class."""
@@ -897,3 +948,54 @@ def test_list_deployable_models(self, list_publisher_models_mock):
897948
"google/gemma-2-2b",
898949
"google/gemma-2-2b",
899950
]
951+
952+
def test_batch_prediction_success(self, batch_prediction_mock):
    """Checks that OpenModel.batch_predict builds and submits the expected job.

    Relies on the class-level fixtures: ``batch_prediction_mock`` intercepts
    the create call and ``complete_bq_uri_mock`` pins the output destination.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
    )
    model = model_garden.OpenModel(
        model_name=_TEST_BATCH_PREDICTION_MODEL_FULL_RESOURCE_NAME
    )
    job = model.batch_predict(
        input_dataset=_TEST_BQ_INPUT_URI,
        job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        machine_type="g2-standard-12",
        accelerator_type="NVIDIA_L4",
        accelerator_count=1,
        starting_replica_count=1,
    )

    # The SDK wraps whatever proto the service returned.
    assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB

    # Build the expected request piecewise; proto equality is by value, so
    # composing from named parts is equivalent to one inline literal.
    expected_input_config = (
        gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig(
            instances_format="bigquery",
            bigquery_source=gca_io_compat.BigQuerySource(
                input_uri=_TEST_BQ_INPUT_URI
            ),
        )
    )
    expected_output_config = (
        gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig(
            bigquery_destination=gca_io_compat.BigQueryDestination(
                output_uri=_TEST_BQ_OUTPUT_PREFIX
            ),
            predictions_format="bigquery",
        )
    )
    expected_resources = machine_resources.BatchDedicatedResources(
        machine_spec=machine_resources.MachineSpec(
            machine_type="g2-standard-12",
            accelerator_type="NVIDIA_L4",
            accelerator_count=1,
        ),
        starting_replica_count=1,
    )
    expected_job = gca_batch_prediction_job_compat.BatchPredictionJob(
        display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        model=_TEST_BATCH_PREDICTION_MODEL_FULL_RESOURCE_NAME,
        input_config=expected_input_config,
        output_config=expected_output_config,
        dedicated_resources=expected_resources,
        manual_batch_tuning_parameters=(
            manual_batch_tuning_parameters.ManualBatchTuningParameters()
        ),
    )

    batch_prediction_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        batch_prediction_job=expected_job,
        timeout=None,
    )

vertexai/model_garden/_model_garden.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import datetime
2020
import functools
2121
import re
22-
from typing import Dict, List, Optional, Sequence
22+
from typing import Dict, List, Optional, Sequence, Union
2323

2424
from google.cloud import aiplatform
2525
from google.cloud.aiplatform import base
@@ -29,6 +29,8 @@
2929
from google.cloud.aiplatform import utils
3030
from google.cloud.aiplatform_v1beta1 import types
3131
from google.cloud.aiplatform_v1beta1.services import model_garden_service
32+
from vertexai import batch_prediction
33+
3234

3335
from google.protobuf import duration_pb2
3436

@@ -656,3 +658,60 @@ def list_deploy_options(
656658
" to find out which ones currently support deployment."
657659
)
658660
return multi_deploy
661+
662+
def batch_predict(
    self,
    input_dataset: Union[str, List[str]],
    *,
    output_uri_prefix: Optional[str] = None,
    job_display_name: Optional[str] = None,
    machine_type: Optional[str] = None,
    accelerator_type: Optional[str] = None,
    accelerator_count: Optional[int] = None,
    starting_replica_count: Optional[int] = None,
    max_replica_count: Optional[int] = None,
) -> batch_prediction.BatchPredictionJob:
    """Runs batch prediction against this model.

    Args:
        input_dataset (Union[str, List[str]]):
            GCS URI(-s) or BigQuery URI to your input data to run batch
            prediction on. Example: "gs://path/to/input/data.jsonl" or
            "bq://projectId.bqDatasetId.bqTableId"
        output_uri_prefix (Optional[str]):
            GCS or BigQuery URI prefix for the output predictions. Example:
            "gs://path/to/output/data" or "bq://projectId.bqDatasetId"
            If not specified, f"{STAGING_BUCKET}/gen-ai-batch-prediction" will
            be used for GCS source and
            f"bq://projectId.gen_ai_batch_prediction.predictions_{TIMESTAMP}"
            will be used for BigQuery source.
        job_display_name (Optional[str]):
            The user-defined name of the BatchPredictionJob. The name can be
            up to 128 characters long and can consist of any UTF-8 characters.
        machine_type (Optional[str]):
            The machine type for the batch prediction job.
        accelerator_type (Optional[str]):
            The accelerator type for the batch prediction job.
        accelerator_count (Optional[int]):
            The accelerator count for the batch prediction job.
        starting_replica_count (Optional[int]):
            The starting replica count for the batch prediction job.
        max_replica_count (Optional[int]):
            The maximum replica count for the batch prediction job.

    Returns:
        batch_prediction.BatchPredictionJob:
            The batch prediction job.
    """
    # Collect the forwarded arguments first, then delegate job creation to
    # the shared BatchPredictionJob.submit implementation.
    submit_kwargs = dict(
        source_model=self._publisher_model_name,
        input_dataset=input_dataset,
        output_uri_prefix=output_uri_prefix,
        job_display_name=job_display_name,
        machine_type=machine_type,
        accelerator_type=accelerator_type,
        accelerator_count=accelerator_count,
        starting_replica_count=starting_replica_count,
        max_replica_count=max_replica_count,
    )
    return batch_prediction.BatchPredictionJob.submit(**submit_kwargs)

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: https://github.com/googleapis/python-aiplatform/commit/8bf30b74828c976e315879d9a7b61cb718e1bcfe

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy