47
47
},
48
48
}
49
49
50
+ _EMBEDDING_GECKO_PUBLISHER_MODEL_DICT = {
51
+ "name" : "publishers/google/models/textembedding-gecko" ,
52
+ "version_id" : "003" ,
53
+ "open_source_category" : "PROPRIETARY" ,
54
+ "launch_stage" : gca_publisher_model .PublisherModel .LaunchStage .GA ,
55
+ "publisher_model_template" : "projects/{user-project}/locations/{location}/publishers/google/models/textembedding-gecko@003" ,
56
+ "predict_schemata" : {
57
+ "instance_schema_uri" : "gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml" ,
58
+ "parameters_schema_uri" : "gs://google-cloud-aiplatfrom/schema/predict/params/text_embedding_1.0.0.yaml" ,
59
+ "prediction_schema_uri" : "gs://google-cloud-aiplatform/schema/predict/prediction/text_embedding_1.0.0.yaml" ,
60
+ },
61
+ }
62
+
50
63
51
64
@pytest .mark .usefixtures ("google_auth_mock" )
52
65
class TestModelGardenModels :
53
66
"""Unit tests for the _ModelGardenModel base class."""
54
67
55
- class FakeModelGardenModel (_model_garden_models ._ModelGardenModel ):
68
+ class FakeModelGardenBisonModel (_model_garden_models ._ModelGardenModel ):
56
69
57
70
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml"
58
71
72
+ class FakeModelGardenGeckoModel (_model_garden_models ._ModelGardenModel ):
73
+
74
+ _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml"
75
+
59
76
def setup_method (self ):
60
77
reload (initializer )
61
78
reload (aiplatform )
62
79
63
80
def teardown_method (self ):
64
81
initializer .global_pool .shutdown (wait = True )
65
82
66
- def test_init_model_garden_model_with_from_pretrained (self ):
83
+ def test_init_model_garden_bison_model_with_from_pretrained (self ):
67
84
"""Tests the text generation model."""
68
85
aiplatform .init (
69
86
project = test_constants .ProjectConstants ._TEST_PROJECT ,
@@ -76,9 +93,29 @@ def test_init_model_garden_model_with_from_pretrained(self):
76
93
_TEXT_BISON_PUBLISHER_MODEL_DICT
77
94
),
78
95
) as mock_get_publisher_model :
79
- self .FakeModelGardenModel .from_pretrained ("text-bison@001" )
96
+ self .FakeModelGardenBisonModel .from_pretrained ("text-bison@001" )
80
97
81
98
mock_get_publisher_model .assert_called_once_with (
82
99
name = "publishers/google/models/text-bison@001" ,
83
100
retry = base ._DEFAULT_RETRY ,
84
101
)
102
+
103
+ def test_init_model_garden_gecko_model_with_from_pretrained (self ):
104
+ """Tests the text generation model."""
105
+ aiplatform .init (
106
+ project = test_constants .ProjectConstants ._TEST_PROJECT ,
107
+ location = test_constants .ProjectConstants ._TEST_LOCATION ,
108
+ )
109
+ with mock .patch .object (
110
+ target = model_garden_service_client_v1 .ModelGardenServiceClient ,
111
+ attribute = "get_publisher_model" ,
112
+ return_value = gca_publisher_model .PublisherModel (
113
+ _EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
114
+ ),
115
+ ) as mock_get_publisher_model :
116
+ self .FakeModelGardenGeckoModel .from_pretrained ("textembedding-gecko@003" )
117
+
118
+ mock_get_publisher_model .assert_called_once_with (
119
+ name = "publishers/google/models/textembedding-gecko@003" ,
120
+ retry = base ._DEFAULT_RETRY ,
121
+ )
0 commit comments