Skip to content

Commit 6a2f2aa

Browse files
sararob authored and copybara-github committed
feat: LLM - Added the count_tokens method to the preview TextGenerationModel and TextEmbeddingModel classes
PiperOrigin-RevId: 570108703
1 parent 69a67f2 commit 6a2f2aa

File tree

4 files changed

+192
-6
lines changed

4 files changed

+192
-6
lines changed

tests/system/aiplatform/test_language_models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,18 @@ def test_text_generation(self):
6060
stop_sequences=["# %%"],
6161
).text
6262

63+
def test_text_generation_preview_count_tokens(self):
    """Checks that the preview text generation model reports token counts."""
    aiplatform.init(
        project=e2e_base._PROJECT,
        location=e2e_base._LOCATION,
    )

    text_model = preview_language_models.TextGenerationModel.from_pretrained(
        "google/text-bison@001"
    )
    count_response = text_model.count_tokens(["How are you doing?"])

    # A non-empty prompt must yield non-zero token and character counts.
    assert count_response.total_tokens
    assert count_response.total_billable_characters
6375
@pytest.mark.asyncio
6476
async def test_text_generation_model_predict_async(self):
6577
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

tests/unit/aiplatform/test_language_models.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@
5858
model as gca_model,
5959
)
6060

61+
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
62+
client as prediction_service_client_v1beta1,
63+
)
64+
from google.cloud.aiplatform_v1beta1.types import (
65+
prediction_service as gca_prediction_service_v1beta1,
66+
)
67+
6168
import vertexai
6269
from vertexai.preview import (
6370
language_models as preview_language_models,
@@ -306,6 +313,11 @@ def reverse_string_2(s):""",
306313
}
307314
}
308315

316+
_TEST_COUNT_TOKENS_RESPONSE = {
317+
"total_tokens": 5,
318+
"total_billable_characters": 25,
319+
}
320+
309321

310322
_TEST_TEXT_BISON_TRAINING_DF = pd.DataFrame(
311323
{
@@ -1206,6 +1218,43 @@ def test_text_generation(self):
12061218
== _TEST_TEXT_GENERATION_PREDICTION["safetyAttributes"]["scores"][0]
12071219
)
12081220

1221+
def test_text_generation_preview_count_tokens(self):
    """Unit test for count_tokens on the preview text generation model."""
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
    )

    # Stub out Model Garden so from_pretrained resolves without a network call.
    get_model_patch = mock.patch.object(
        target=model_garden_service_client.ModelGardenServiceClient,
        attribute="get_publisher_model",
        return_value=gca_publisher_model.PublisherModel(
            _TEXT_BISON_PUBLISHER_MODEL_DICT
        ),
    )
    with get_model_patch:
        model = preview_language_models.TextGenerationModel.from_pretrained(
            "text-bison@001"
        )

    expected_response = gca_prediction_service_v1beta1.CountTokensResponse(
        total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
        total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
            "total_billable_characters"
        ],
    )

    count_tokens_patch = mock.patch.object(
        target=prediction_service_client_v1beta1.PredictionServiceClient,
        attribute="count_tokens",
        return_value=expected_response,
    )
    with count_tokens_patch:
        response = model.count_tokens(
            ["What is the best recipe for banana bread?"]
        )

    assert response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
    assert (
        response.total_billable_characters
        == _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
    )
1257+
12091258
def test_text_generation_ga(self):
12101259
"""Tests the text generation model."""
12111260
aiplatform.init(
@@ -2469,6 +2518,47 @@ def test_text_embedding(self):
24692518
== expected_embedding["statistics"]["truncated"]
24702519
)
24712520

2521+
def test_text_embedding_preview_count_tokens(self):
    """Unit test for count_tokens on the preview text embedding model."""
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
    )

    # Stub out Model Garden so from_pretrained resolves without a network call.
    get_model_patch = mock.patch.object(
        target=model_garden_service_client.ModelGardenServiceClient,
        attribute="get_publisher_model",
        return_value=gca_publisher_model.PublisherModel(
            _TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
        ),
    )
    with get_model_patch:
        model = preview_language_models.TextEmbeddingModel.from_pretrained(
            "textembedding-gecko@001"
        )

    expected_response = gca_prediction_service_v1beta1.CountTokensResponse(
        total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
        total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
            "total_billable_characters"
        ],
    )

    count_tokens_patch = mock.patch.object(
        target=prediction_service_client_v1beta1.PredictionServiceClient,
        attribute="count_tokens",
        return_value=expected_response,
    )
    with count_tokens_patch:
        response = model.count_tokens(["What is life?"])

    assert response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
    assert (
        response.total_billable_characters
        == _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
    )
2561+
24722562
def test_text_embedding_ga(self):
24732563
"""Tests the text embedding model."""
24742564
aiplatform.init(

vertexai/language_models/_language_models.py

Lines changed: 88 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,15 @@ def _model_resource_name(self) -> str:
9292
@dataclasses.dataclass
9393
class _PredictionRequest:
9494
"""A single-instance prediction request."""
95+
9596
instance: Dict[str, Any]
9697
parameters: Optional[Dict[str, Any]] = None
9798

9899

99100
@dataclasses.dataclass
100101
class _MultiInstancePredictionRequest:
101102
"""A multi-instance prediction request."""
103+
102104
instances: List[Dict[str, Any]]
103105
parameters: Optional[Dict[str, Any]] = None
104106

@@ -573,6 +575,62 @@ def tune_model(
573575
return job
574576

575577

578+
@dataclasses.dataclass
class CountTokensResponse:
    """The response from a count_tokens request.

    Attributes:
        total_tokens (int):
            The total number of tokens counted across all
            instances passed to the request.
        total_billable_characters (int):
            The total number of billable characters
            counted across all instances from the request.
    """

    # Totals aggregated over every instance sent in the request.
    total_tokens: int
    total_billable_characters: int
    # Raw underlying CountTokens RPC response, retained for debugging.
    _count_tokens_response: Any
593+
594+
595+
class _CountTokensMixin(_LanguageModel):
    """Mixin for models that support the CountTokens API"""

    def count_tokens(
        self,
        prompts: List[str],
    ) -> CountTokensResponse:
        """Counts the tokens and billable characters for a given prompt.

        Note: this does not make a request to the model, it only counts the tokens
        in the request.

        Args:
            prompts (List[str]):
                Required. A list of prompts to ask the model. For example: ["What should I do today?", "How's it going?"]

        Returns:
            A `CountTokensResponse` object that contains the number of tokens
            in the text and the number of billable characters.
        """
        # The CountTokens RPC is only exposed on the v1beta1 client surface.
        client_v1beta1 = self._endpoint._prediction_client.select_version(
            "v1beta1"
        )
        raw_response = client_v1beta1.count_tokens(
            endpoint=self._endpoint_name,
            instances=[{"content": prompt} for prompt in prompts],
        )

        return CountTokensResponse(
            total_tokens=raw_response.total_tokens,
            total_billable_characters=raw_response.total_billable_characters,
            _count_tokens_response=raw_response,
        )
632+
633+
576634
@dataclasses.dataclass
577635
class TuningEvaluationSpec:
578636
"""Specification for model evaluation to perform during tuning.
@@ -587,6 +645,7 @@ class TuningEvaluationSpec:
587645
tensorboard: Vertex Tensorboard where to write the evaluation metrics.
588646
The Tensorboard must be in the same location as the tuning job.
589647
"""
648+
590649
__module__ = "vertexai.language_models"
591650

592651
evaluation_data: str
@@ -605,6 +664,7 @@ class TextGenerationResponse:
605664
Learn more about the safety attributes here:
606665
https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions
607666
"""
667+
608668
__module__ = "vertexai.language_models"
609669

610670
text: str
@@ -761,7 +821,9 @@ def predict_streaming(
761821
)
762822

763823
prediction_service_client = self._endpoint._prediction_client
764-
for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
824+
for (
825+
prediction_dict
826+
) in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
765827
prediction_service_client=prediction_service_client,
766828
endpoint_name=self._endpoint_name,
767829
instance=prediction_request.instance,
@@ -955,6 +1017,7 @@ class _PreviewTextGenerationModel(
9551017
_PreviewTunableTextModelMixin,
9561018
_PreviewModelWithBatchPredict,
9571019
_evaluatable_language_models._EvaluatableLanguageModel,
1020+
_CountTokensMixin,
9581021
):
9591022
# Do not add docstring so that it's inherited from the base class.
9601023
__name__ = "TextGenerationModel"
@@ -1094,6 +1157,7 @@ class TextEmbeddingInput:
10941157
Specifies that the embeddings will be used for clustering.
10951158
title: Optional identifier of the text content.
10961159
"""
1160+
10971161
__module__ = "vertexai.language_models"
10981162

10991163
text: str
@@ -1113,6 +1177,7 @@ class TextEmbeddingModel(_LanguageModel):
11131177
vector = embedding.values
11141178
print(len(vector))
11151179
"""
1180+
11161181
__module__ = "vertexai.language_models"
11171182

11181183
_LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE
@@ -1173,7 +1238,8 @@ def _parse_text_embedding_response(
11731238
_prediction_response=prediction_response,
11741239
)
11751240

1176-
def get_embeddings(self,
1241+
def get_embeddings(
1242+
self,
11771243
texts: List[Union[str, TextEmbeddingInput]],
11781244
*,
11791245
auto_truncate: bool = True,
@@ -1207,7 +1273,8 @@ def get_embeddings(self,
12071273

12081274
return results
12091275

1210-
async def get_embeddings_async(self,
1276+
async def get_embeddings_async(
1277+
self,
12111278
texts: List[Union[str, TextEmbeddingInput]],
12121279
*,
12131280
auto_truncate: bool = True,
@@ -1242,7 +1309,9 @@ async def get_embeddings_async(self,
12421309
return results
12431310

12441311

1245-
class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict):
1312+
class _PreviewTextEmbeddingModel(
1313+
TextEmbeddingModel, _ModelWithBatchPredict, _CountTokensMixin
1314+
):
12461315
__name__ = "TextEmbeddingModel"
12471316
__module__ = "vertexai.preview.language_models"
12481317

@@ -1252,6 +1321,7 @@ class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict):
12521321
@dataclasses.dataclass
12531322
class TextEmbeddingStatistics:
12541323
"""Text embedding statistics."""
1324+
12551325
__module__ = "vertexai.language_models"
12561326

12571327
token_count: int
@@ -1261,6 +1331,7 @@ class TextEmbeddingStatistics:
12611331
@dataclasses.dataclass
12621332
class TextEmbedding:
12631333
"""Text embedding vector and statistics."""
1334+
12641335
__module__ = "vertexai.language_models"
12651336

12661337
values: List[float]
@@ -1271,6 +1342,7 @@ class TextEmbedding:
12711342
@dataclasses.dataclass
12721343
class InputOutputTextPair:
12731344
"""InputOutputTextPair represents a pair of input and output texts."""
1345+
12741346
__module__ = "vertexai.language_models"
12751347

12761348
input_text: str
@@ -1285,6 +1357,7 @@ class ChatMessage:
12851357
content: Content of the message.
12861358
author: Author of the message.
12871359
"""
1360+
12881361
__module__ = "vertexai.language_models"
12891362

12901363
content: str
@@ -1362,6 +1435,7 @@ class ChatModel(_ChatModelBase, _TunableChatModelMixin):
13621435
13631436
chat.send_message("Do you know any cool events this weekend?")
13641437
"""
1438+
13651439
__module__ = "vertexai.language_models"
13661440

13671441
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"
@@ -1388,6 +1462,7 @@ class CodeChatModel(_ChatModelBase):
13881462
13891463
code_chat.send_message("Please help write a function to calculate the min of two numbers")
13901464
"""
1465+
13911466
__module__ = "vertexai.language_models"
13921467

13931468
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/codechat_generation_1.0.0.yaml"
@@ -1739,7 +1814,9 @@ def send_message_streaming(
17391814

17401815
full_response_text = ""
17411816

1742-
for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
1817+
for (
1818+
prediction_dict
1819+
) in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
17431820
prediction_service_client=prediction_service_client,
17441821
endpoint_name=self._model._endpoint_name,
17451822
instance=prediction_request.instance,
@@ -1770,6 +1847,7 @@ class ChatSession(_ChatSessionBase):
17701847
17711848
Within a chat session, the model keeps context and remembers the previous conversation.
17721849
"""
1850+
17731851
__module__ = "vertexai.language_models"
17741852

17751853
def __init__(
@@ -1802,6 +1880,7 @@ class CodeChatSession(_ChatSessionBase):
18021880
18031881
Within a code chat session, the model keeps context and remembers the previous converstion.
18041882
"""
1883+
18051884
__module__ = "vertexai.language_models"
18061885

18071886
def __init__(
@@ -1924,6 +2003,7 @@ class CodeGenerationModel(_LanguageModel):
19242003
prefix="def reverse_string(s):",
19252004
))
19262005
"""
2006+
19272007
__module__ = "vertexai.language_models"
19282008

19292009
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml"
@@ -2074,7 +2154,9 @@ def predict_streaming(
20742154
)
20752155

20762156
prediction_service_client = self._endpoint._prediction_client
2077-
for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
2157+
for (
2158+
prediction_dict
2159+
) in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
20782160
prediction_service_client=prediction_service_client,
20792161
endpoint_name=self._endpoint_name,
20802162
instance=prediction_request.instance,

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy