Content-Length: 417979 | pFad | https://github.com/googleapis/python-aiplatform/commit/627992484ec16fbf7fdfc9c963046e10e3d7c6bf

D4 feat: Add mappings to pipeline templates for text-embedding models. · googleapis/python-aiplatform@6279924 · GitHub
Skip to content

Commit 6279924

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add mappings to pipeline templates for text-embedding models.
PiperOrigin-RevId: 625816165
1 parent db10338 commit 6279924

File tree

2 files changed

+42
-7
lines changed

2 files changed

+42
-7
lines changed

tests/unit/aiplatform/test_model_garden_models.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,23 +47,40 @@
4747
},
4848
}
4949

50+
_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT = {
51+
"name": "publishers/google/models/textembedding-gecko",
52+
"version_id": "003",
53+
"open_source_category": "PROPRIETARY",
54+
"launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA,
55+
"publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/textembedding-gecko@003",
56+
"predict_schemata": {
57+
"instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml",
58+
"parameters_schema_uri": "gs://google-cloud-aiplatfrom/schema/predict/params/text_embedding_1.0.0.yaml",
59+
"prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/text_embedding_1.0.0.yaml",
60+
},
61+
}
62+
5063

5164
@pytest.mark.usefixtures("google_auth_mock")
5265
class TestModelGardenModels:
5366
"""Unit tests for the _ModelGardenModel base class."""
5467

55-
class FakeModelGardenModel(_model_garden_models._ModelGardenModel):
68+
class FakeModelGardenBisonModel(_model_garden_models._ModelGardenModel):
5669

5770
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml"
5871

72+
class FakeModelGardenGeckoModel(_model_garden_models._ModelGardenModel):
73+
74+
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml"
75+
5976
def setup_method(self):
6077
reload(initializer)
6178
reload(aiplatform)
6279

6380
def teardown_method(self):
6481
initializer.global_pool.shutdown(wait=True)
6582

66-
def test_init_model_garden_model_with_from_pretrained(self):
83+
def test_init_model_garden_bison_model_with_from_pretrained(self):
6784
"""Tests the text generation model."""
6885
aiplatform.init(
6986
project=test_constants.ProjectConstants._TEST_PROJECT,
@@ -76,9 +93,29 @@ def test_init_model_garden_model_with_from_pretrained(self):
7693
_TEXT_BISON_PUBLISHER_MODEL_DICT
7794
),
7895
) as mock_get_publisher_model:
79-
self.FakeModelGardenModel.from_pretrained("text-bison@001")
96+
self.FakeModelGardenBisonModel.from_pretrained("text-bison@001")
8097

8198
mock_get_publisher_model.assert_called_once_with(
8299
name="publishers/google/models/text-bison@001",
83100
retry=base._DEFAULT_RETRY,
84101
)
102+
103+
def test_init_model_garden_gecko_model_with_from_pretrained(self):
104+
"""Tests the text generation model."""
105+
aiplatform.init(
106+
project=test_constants.ProjectConstants._TEST_PROJECT,
107+
location=test_constants.ProjectConstants._TEST_LOCATION,
108+
)
109+
with mock.patch.object(
110+
target=model_garden_service_client_v1.ModelGardenServiceClient,
111+
attribute="get_publisher_model",
112+
return_value=gca_publisher_model.PublisherModel(
113+
_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
114+
),
115+
) as mock_get_publisher_model:
116+
self.FakeModelGardenGeckoModel.from_pretrained("textembedding-gecko@003")
117+
118+
mock_get_publisher_model.assert_called_once_with(
119+
name="publishers/google/models/textembedding-gecko@003",
120+
retry=base._DEFAULT_RETRY,
121+
)

vertexai/_model_garden/_model_garden_models.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,6 @@
2828
# this is needed for class registration to _SUBCLASSES
2929
import vertexai # pylint:disable=unused-import
3030

31-
from google.cloud.aiplatform.compat.types import (
32-
publisher_model as gca_publisher_model,
33-
)
34-
3531
_SUPPORTED_PUBLISHERS = ["google"]
3632

3733
_SHORT_MODEL_ID_TO_TUNING_PIPELINE_MAP = {
@@ -43,6 +39,8 @@
4339
"chat-bison-32k": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
4440
"codechat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
4541
"codechat-bison-32k": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
42+
"textembedding-gecko": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.2",
43+
"textembedding-gecko-multilingual": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.2",
4644
}
4745

4846
_LOGGER = base.Logger(__name__)

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: https://github.com/googleapis/python-aiplatform/commit/627992484ec16fbf7fdfc9c963046e10e3d7c6bf

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy