|
20 | 20 | from google.api_core import operation as ga_operation
|
21 | 21 | from google.auth import credentials as auth_credentials
|
22 | 22 | from google.cloud import aiplatform
|
| 23 | +from google.cloud.aiplatform.compat.services import job_service_client |
| 24 | +from google.cloud.aiplatform.compat.types import ( |
| 25 | + batch_prediction_job as gca_batch_prediction_job_compat, |
| 26 | +) |
| 27 | +from google.cloud.aiplatform.compat.types import io as gca_io_compat |
| 28 | +from google.cloud.aiplatform.compat.types import ( |
| 29 | + job_state as gca_job_state_compat, |
| 30 | +) |
| 31 | +from google.cloud.aiplatform_v1.types import machine_resources |
| 32 | +from google.cloud.aiplatform_v1.types import manual_batch_tuning_parameters |
23 | 33 | from google.cloud.aiplatform_v1beta1 import types
|
24 | 34 | from google.cloud.aiplatform_v1beta1.services import model_garden_service
|
| 35 | +from vertexai import batch_prediction |
25 | 36 | from vertexai.preview import model_garden
|
26 | 37 | import pytest
|
27 | 38 |
|
28 | 39 | from google.protobuf import duration_pb2
|
29 | 40 |
|
| 41 | + |
30 | 42 | _TEST_PROJECT = "test-project"
|
31 | 43 | _TEST_LOCATION = "us-central1"
|
32 | 44 |
|
|
73 | 85 | timeout_seconds=10,
|
74 | 86 | ),
|
75 | 87 | )
|
| 88 | +_TEST_BATCH_PREDICTION_JOB_ID = "123456789" |
| 89 | +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" |
| 90 | +_TEST_BATCH_PREDICTION_JOB_NAME = ( |
| 91 | + f"{_TEST_PARENT}/batchPredictionJobs/{_TEST_BATCH_PREDICTION_JOB_ID}" |
| 92 | +) |
| 93 | +_TEST_BATCH_PREDICTION_MODEL_FULL_RESOURCE_NAME = ( |
| 94 | + "publishers/google/models/gemma@gemma-2b-it" |
| 95 | +) |
| 96 | +_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME = "test-batch-prediction-job" |
| 97 | +_TEST_JOB_STATE_RUNNING = gca_job_state_compat.JobState(3) |
| 98 | +_TEST_GAPIC_BATCH_PREDICTION_JOB = gca_batch_prediction_job_compat.BatchPredictionJob( |
| 99 | + name=_TEST_BATCH_PREDICTION_JOB_NAME, |
| 100 | + display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, |
| 101 | + model=_TEST_BATCH_PREDICTION_MODEL_FULL_RESOURCE_NAME, |
| 102 | + state=_TEST_JOB_STATE_RUNNING, |
| 103 | +) |
| 104 | +_TEST_BQ_INPUT_URI = "bq://test-project.test-dataset.test-input" |
| 105 | +_TEST_BQ_OUTPUT_PREFIX = "bq://test-project.test-dataset.test-output" |
76 | 106 |
|
77 | 107 |
|
78 | 108 | @pytest.fixture(scope="module")
|
@@ -117,6 +147,25 @@ def deploy_mock():
|
117 | 147 | yield deploy
|
118 | 148 |
|
119 | 149 |
|
| 150 | +@pytest.fixture |
| 151 | +def batch_prediction_mock(): |
| 152 | + """Mocks the create_batch_prediction_job method.""" |
| 153 | + with mock.patch.object( |
| 154 | + job_service_client.JobServiceClient, "create_batch_prediction_job" |
| 155 | + ) as create_batch_prediction_job_mock: |
| 156 | + create_batch_prediction_job_mock.return_value = _TEST_GAPIC_BATCH_PREDICTION_JOB |
| 157 | + yield create_batch_prediction_job_mock |
| 158 | + |
| 159 | + |
| 160 | +@pytest.fixture |
| 161 | +def complete_bq_uri_mock(): |
| 162 | + with mock.patch.object( |
| 163 | + batch_prediction.BatchPredictionJob, "_complete_bq_uri" |
| 164 | + ) as complete_bq_uri_mock: |
| 165 | + complete_bq_uri_mock.return_value = _TEST_BQ_OUTPUT_PREFIX |
| 166 | + yield complete_bq_uri_mock |
| 167 | + |
| 168 | + |
120 | 169 | @pytest.fixture
|
121 | 170 | def get_publisher_model_mock():
|
122 | 171 | with mock.patch.object(
|
@@ -355,6 +404,8 @@ def list_publisher_models_mock():
|
355 | 404 | "get_publisher_model_mock",
|
356 | 405 | "list_publisher_models_mock",
|
357 | 406 | "export_publisher_model_mock",
|
| 407 | + "batch_prediction_mock", |
| 408 | + "complete_bq_uri_mock", |
358 | 409 | )
|
359 | 410 | class TestModelGarden:
|
360 | 411 | """Test cases for ModelGarden class."""
|
@@ -897,3 +948,54 @@ def test_list_deployable_models(self, list_publisher_models_mock):
|
897 | 948 | "google/gemma-2-2b",
|
898 | 949 | "google/gemma-2-2b",
|
899 | 950 | ]
|
| 951 | + |
| 952 | + def test_batch_prediction_success(self, batch_prediction_mock): |
| 953 | + aiplatform.init( |
| 954 | + project=_TEST_PROJECT, |
| 955 | + location=_TEST_LOCATION, |
| 956 | + ) |
| 957 | + model = model_garden.OpenModel( |
| 958 | + model_name=_TEST_BATCH_PREDICTION_MODEL_FULL_RESOURCE_NAME |
| 959 | + ) |
| 960 | + job = model.batch_predict( |
| 961 | + input_dataset=_TEST_BQ_INPUT_URI, |
| 962 | + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, |
| 963 | + machine_type="g2-standard-12", |
| 964 | + accelerator_type="NVIDIA_L4", |
| 965 | + accelerator_count=1, |
| 966 | + starting_replica_count=1, |
| 967 | + ) |
| 968 | + |
| 969 | + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB |
| 970 | + |
| 971 | + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( |
| 972 | + display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, |
| 973 | + model=_TEST_BATCH_PREDICTION_MODEL_FULL_RESOURCE_NAME, |
| 974 | + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( |
| 975 | + instances_format="bigquery", |
| 976 | + bigquery_source=gca_io_compat.BigQuerySource( |
| 977 | + input_uri=_TEST_BQ_INPUT_URI |
| 978 | + ), |
| 979 | + ), |
| 980 | + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( |
| 981 | + bigquery_destination=gca_io_compat.BigQueryDestination( |
| 982 | + output_uri=_TEST_BQ_OUTPUT_PREFIX |
| 983 | + ), |
| 984 | + predictions_format="bigquery", |
| 985 | + ), |
| 986 | + dedicated_resources=machine_resources.BatchDedicatedResources( |
| 987 | + machine_spec=machine_resources.MachineSpec( |
| 988 | + machine_type="g2-standard-12", |
| 989 | + accelerator_type="NVIDIA_L4", |
| 990 | + accelerator_count=1, |
| 991 | + ), |
| 992 | + starting_replica_count=1, |
| 993 | + ), |
| 994 | + manual_batch_tuning_parameters=manual_batch_tuning_parameters.ManualBatchTuningParameters(), |
| 995 | + ) |
| 996 | + |
| 997 | + batch_prediction_mock.assert_called_once_with( |
| 998 | + parent=_TEST_PARENT, |
| 999 | + batch_prediction_job=expected_gapic_batch_prediction_job, |
| 1000 | + timeout=None, |
| 1001 | + ) |
0 commit comments