Skip to content

Commit b1cab3f

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Add support for output_dimensionality parameter through get_embeddings.
PiperOrigin-RevId: 617251035
1 parent be4922a commit b1cab3f

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

tests/unit/aiplatform/test_language_models.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4164,6 +4164,7 @@ def test_text_embedding(self):
41644164
),
41654165
],
41664166
auto_truncate=False,
4167+
output_dimensionality=3,
41674168
)
41684169
prediction_instances = mock_predict.call_args[1]["instances"]
41694170
assert prediction_instances == [
@@ -4180,6 +4181,7 @@ def test_text_embedding(self):
41804181
]
41814182
prediction_parameters = mock_predict.call_args[1]["parameters"]
41824183
assert not prediction_parameters["autoTruncate"]
4184+
assert prediction_parameters["outputDimensionality"] == 3
41834185
assert embeddings
41844186
for embedding in embeddings:
41854187
vector = embedding.values

vertexai/language_models/_language_models.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2003,12 +2003,14 @@ def _prepare_text_embedding_request(
20032003
texts: List[Union[str, TextEmbeddingInput]],
20042004
*,
20052005
auto_truncate: bool = True,
2006+
output_dimensionality: Optional[int] = None,
20062007
) -> _MultiInstancePredictionRequest:
20072008
"""Asynchronously calculates embeddings for the given texts.
20082009
20092010
Args:
20102011
texts(str): A list of texts or `TextEmbeddingInput` objects to embed.
20112012
auto_truncate(bool): Whether to automatically truncate long texts. Default: True.
2013+
output_dimensionality: Optional dimensions of embeddings. Range: [1, 768]. Default: None.
20122014
20132015
Returns:
20142016
A `_MultiInstancePredictionRequest` object.
@@ -2029,6 +2031,8 @@ def _prepare_text_embedding_request(
20292031
raise TypeError(f"Unsupported text embedding input type: {text}.")
20302032
instances.append(instance)
20312033
parameters = {"autoTruncate": auto_truncate}
2034+
if output_dimensionality is not None:
2035+
parameters["outputDimensionality"] = output_dimensionality
20322036
return _MultiInstancePredictionRequest(
20332037
instances=instances,
20342038
parameters=parameters,
@@ -2057,19 +2061,22 @@ def get_embeddings(
20572061
texts: List[Union[str, TextEmbeddingInput]],
20582062
*,
20592063
auto_truncate: bool = True,
2064+
output_dimensionality: Optional[int] = None
20602065
) -> List["TextEmbedding"]:
20612066
"""Calculates embeddings for the given texts.
20622067
20632068
Args:
2064-
texts(str): A list of texts or `TextEmbeddingInput` objects to embed.
2065-
auto_truncate(bool): Whether to automatically truncate long texts. Default: True.
2069+
texts: A list of texts or `TextEmbeddingInput` objects to embed.
2070+
auto_truncate: Whether to automatically truncate long texts. Default: True.
2071+
output_dimensionality: Optional dimensions of embeddings. Range: [1, 768]. Default: None.
20662072
20672073
Returns:
20682074
A list of `TextEmbedding` objects.
20692075
"""
20702076
prediction_request = self._prepare_text_embedding_request(
20712077
texts=texts,
20722078
auto_truncate=auto_truncate,
2079+
output_dimensionality=output_dimensionality,
20732080
)
20742081

20752082
prediction_response = self._endpoint.predict(
@@ -2092,19 +2099,22 @@ async def get_embeddings_async(
20922099
texts: List[Union[str, TextEmbeddingInput]],
20932100
*,
20942101
auto_truncate: bool = True,
2102+
output_dimensionality: Optional[int] = None,
20952103
) -> List["TextEmbedding"]:
20962104
"""Asynchronously calculates embeddings for the given texts.
20972105
20982106
Args:
2099-
texts(str): A list of texts or `TextEmbeddingInput` objects to embed.
2100-
auto_truncate(bool): Whether to automatically truncate long texts. Default: True.
2107+
texts: A list of texts or `TextEmbeddingInput` objects to embed.
2108+
auto_truncate: Whether to automatically truncate long texts. Default: True.
2109+
output_dimensionality: Optional dimensions of embeddings. Range: [1, 768]. Default: None.
21012110
21022111
Returns:
21032112
A list of `TextEmbedding` objects.
21042113
"""
21052114
prediction_request = self._prepare_text_embedding_request(
21062115
texts=texts,
21072116
auto_truncate=auto_truncate,
2117+
output_dimensionality=output_dimensionality
21082118
)
21092119

21102120
prediction_response = await self._endpoint.predict_async(

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy