
Commit 01989b1

sararob authored and copybara-github committed

feat: LLM - Added count_tokens support to ChatModel (preview)

PiperOrigin-RevId: 575006811

1 parent eb6071f · commit 01989b1
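Based on the tests and docstrings added in this commit, the new API can be exercised roughly as follows. This is a sketch, not part of the commit: the project and location values are placeholders, and count_tokens is only available on the preview ChatModel/CodeChatModel classes.

from google.cloud import aiplatform
from vertexai.preview.language_models import ChatModel

# Placeholder project/location -- supply your own values.
aiplatform.init(project="my-project", location="us-central1")

chat = ChatModel.from_pretrained("chat-bison@001").start_chat()
response = chat.count_tokens("What should I do today?")

# Fields asserted in the tests below; no prediction request is made.
print(response.total_tokens)
print(response.total_billable_characters)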

File tree

4 files changed: +264 −6 lines changed

tests/system/aiplatform/test_language_models.py
tests/unit/aiplatform/test_language_models.py
vertexai/language_models/_language_models.py
vertexai/preview/language_models.py

tests/system/aiplatform/test_language_models.py
Lines changed: 23 additions & 0 deletions

@@ -159,6 +159,29 @@ def test_chat_on_chat_model(self):
         assert chat.message_history[2].content == message2
         assert chat.message_history[3].author == chat.MODEL_AUTHOR
 
+    def test_chat_model_preview_count_tokens(self):
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        chat_model = ChatModel.from_pretrained("google/chat-bison@001")
+
+        chat = chat_model.start_chat()
+
+        chat.send_message("What should I do today?")
+
+        response_with_history = chat.count_tokens("Any ideas?")
+
+        response_without_history = chat_model.start_chat().count_tokens(
+            "What should I do today?"
+        )
+
+        assert (
+            response_with_history.total_tokens > response_without_history.total_tokens
+        )
+        assert (
+            response_with_history.total_billable_characters
+            > response_without_history.total_billable_characters
+        )
+
     @pytest.mark.asyncio
     async def test_chat_model_async(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
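The with-history/without-history comparison above works because count_tokens on a chat session includes any prior messages, context, and examples in the counted request (see the mixin docstring later in this diff). A sketch of counting a single message in isolation -- start a fresh session and count before sending anything:

# Sketch: a fresh session has no history, so only this message is counted.
fresh_chat = chat_model.start_chat()
single_count = fresh_chat.count_tokens("Any ideas?")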

tests/unit/aiplatform/test_language_models.py
Lines changed: 78 additions & 0 deletions

@@ -2377,6 +2377,44 @@ def test_chat_model_send_message_streaming(self):
         assert chat.message_history[2].content == message_text1
         assert chat.message_history[3].author == chat.MODEL_AUTHOR
 
+    def test_chat_model_preview_count_tokens(self):
+        """Tests count_tokens on the preview chat model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CHAT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_language_models.ChatModel.from_pretrained("chat-bison@001")
+
+        chat = model.start_chat()
+        assert isinstance(chat, preview_language_models.ChatSession)
+
+        gca_count_tokens_response = gca_prediction_service_v1beta1.CountTokensResponse(
+            total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
+            total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
+                "total_billable_characters"
+            ],
+        )
+
+        with mock.patch.object(
+            target=prediction_service_client_v1beta1.PredictionServiceClient,
+            attribute="count_tokens",
+            return_value=gca_count_tokens_response,
+        ):
+            response = chat.count_tokens("What is the best recipe for banana bread?")
+
+        assert response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
+        assert (
+            response.total_billable_characters
+            == _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
+        )
+
     def test_code_chat(self):
         """Tests the code chat model."""
         aiplatform.init(
@@ -2577,6 +2615,46 @@ def test_code_chat_model_send_message_streaming(self):
         assert chat.message_history[0].content == message_text1
         assert chat.message_history[1].author == chat.MODEL_AUTHOR
 
+    def test_code_chat_model_preview_count_tokens(self):
+        """Tests count_tokens on the preview code chat model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CODECHAT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_language_models.CodeChatModel.from_pretrained(
+                "codechat-bison@001"
+            )
+
+        chat = model.start_chat()
+        assert isinstance(chat, preview_language_models.CodeChatSession)
+
+        gca_count_tokens_response = gca_prediction_service_v1beta1.CountTokensResponse(
+            total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
+            total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
+                "total_billable_characters"
+            ],
+        )
+
+        with mock.patch.object(
+            target=prediction_service_client_v1beta1.PredictionServiceClient,
+            attribute="count_tokens",
+            return_value=gca_count_tokens_response,
+        ):
+            response = chat.count_tokens("What is the best recipe for banana bread?")
+
+        assert response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
+        assert (
+            response.total_billable_characters
+            == _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
+        )
+
     def test_code_generation(self):
         """Tests code generation with the code generation model."""
         aiplatform.init(
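Both unit tests read a _TEST_COUNT_TOKENS_RESPONSE constant that is defined elsewhere in the test module and not shown in this diff. Judging by the keys the tests access, it is presumably a dict of the following shape (the values here are made up for illustration):

# Hypothetical fixture; only the keys are confirmed by this diff.
_TEST_COUNT_TOKENS_RESPONSE = {
    "total_tokens": 5,
    "total_billable_characters": 25,
}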

vertexai/language_models/_language_models.py
Lines changed: 159 additions & 6 deletions
@@ -222,7 +222,9 @@ def tune_model(
         if eval_spec.evaluation_data:
             if isinstance(eval_spec.evaluation_data, str):
                 if eval_spec.evaluation_data.startswith("gs://"):
-                    tuning_parameters["evaluation_data_uri"] = eval_spec.evaluation_data
+                    tuning_parameters[
+                        "evaluation_data_uri"
+                    ] = eval_spec.evaluation_data
                 else:
                     raise ValueError("evaluation_data should be a GCS URI")
             else:
@@ -627,7 +629,7 @@ def count_tokens(
     ) -> CountTokensResponse:
         """Counts the tokens and billable characters for a given prompt.
 
-        Note: this does not make a request to the model, it only counts the tokens
+        Note: this does not make a prediction request to the model, it only counts the tokens
        in the request.
 
         Args:
@@ -802,7 +804,9 @@ def predict(
             parameters=prediction_request.parameters,
         )
 
-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )
 
     async def predict_async(
         self,
@@ -844,7 +848,9 @@ async def predict_async(
             parameters=prediction_request.parameters,
         )
 
-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )
 
     def predict_streaming(
         self,
@@ -1587,6 +1593,47 @@ class _PreviewChatModel(ChatModel, _PreviewTunableChatModelMixin):
 
     _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
 
+    def start_chat(
+        self,
+        *,
+        context: Optional[str] = None,
+        examples: Optional[List[InputOutputTextPair]] = None,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        message_history: Optional[List[ChatMessage]] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> "_PreviewChatSession":
+        """Starts a chat session with the model.
+
+        Args:
+            context: Context shapes how the model responds throughout the conversation.
+                For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style
+            examples: List of structured messages to the model to learn how to respond to the conversation.
+                A list of `InputOutputTextPair` objects.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+            top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
+            top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+            message_history: A list of previously sent and received messages.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+
+        Returns:
+            A `ChatSession` object.
+        """
+        return _PreviewChatSession(
+            model=self,
+            context=context,
+            examples=examples,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            message_history=message_history,
+            stop_sequences=stop_sequences,
+        )
+
 
 class CodeChatModel(_ChatModelBase):
     """CodeChatModel represents a model that is capable of completing code.
@@ -1646,6 +1693,47 @@ class _PreviewCodeChatModel(CodeChatModel, _TunableChatModelMixin):
 
     _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
 
+    def start_chat(
+        self,
+        *,
+        context: Optional[str] = None,
+        examples: Optional[List[InputOutputTextPair]] = None,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        message_history: Optional[List[ChatMessage]] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> "_PreviewCodeChatSession":
+        """Starts a chat session with the model.
+
+        Args:
+            context: Context shapes how the model responds throughout the conversation.
+                For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style
+            examples: List of structured messages to the model to learn how to respond to the conversation.
+                A list of `InputOutputTextPair` objects.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+            top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
+            top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+            message_history: A list of previously sent and received messages.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+
+        Returns:
+            A `ChatSession` object.
+        """
+        return _PreviewCodeChatSession(
+            model=self,
+            context=context,
+            examples=examples,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            message_history=message_history,
+            stop_sequences=stop_sequences,
+        )
+
 
 class _ChatSessionBase:
     """_ChatSessionBase is a base class for all chat sessions."""
@@ -2071,6 +2159,67 @@ async def send_message_streaming_async(
         )
 
 
+class _ChatSessionBaseWithCountTokensMixin(_ChatSessionBase):
+    """A mixin class for adding count_tokens to ChatSession."""
+
+    def count_tokens(
+        self,
+        message: str,
+    ) -> CountTokensResponse:
+        """Counts the tokens and billable characters for the provided chat message and any message history,
+        context, or examples set on the chat session.
+
+        If you've called `send_message()` in the current chat session before calling `count_tokens()`, the
+        response will include the total tokens and characters for the previously sent message and the one in the
+        `count_tokens()` request. To count the tokens for a single message, call `count_tokens()` right after
+        calling `start_chat()` before calling `send_message()`.
+
+        Note: this does not make a prediction request to the model, it only counts the tokens
+        in the request.
+
+        Examples::
+
+            model = ChatModel.from_pretrained("chat-bison@001")
+            chat_session = model.start_chat()
+            count_tokens_response = chat_session.count_tokens("How's it going?")
+
+            count_tokens_response.total_tokens
+            count_tokens_response.total_billable_characters
+
+        Args:
+            message (str):
+                Required. A chat message to count tokens for. For example: "How's it going?"
+        Returns:
+            A `CountTokensResponse` object that contains the number of tokens
+            in the text and the number of billable characters.
+        """
+
+        count_tokens_request = self._prepare_request(message=message)
+
+        count_tokens_response = self._model._endpoint._prediction_client.select_version(
+            "v1beta1"
+        ).count_tokens(
+            endpoint=self._model._endpoint_name,
+            instances=[count_tokens_request.instance],
+        )
+
+        return CountTokensResponse(
+            total_tokens=count_tokens_response.total_tokens,
+            total_billable_characters=count_tokens_response.total_billable_characters,
+            _count_tokens_response=count_tokens_response,
+        )
+
+
+class _PreviewChatSession(_ChatSessionBaseWithCountTokensMixin):
+
+    __module__ = "vertexai.preview.language_models"
+
+
+class _PreviewCodeChatSession(_ChatSessionBaseWithCountTokensMixin):
+
+    __module__ = "vertexai.preview.language_models"
+
+
 class ChatSession(_ChatSessionBase):
     """ChatSession represents a chat session with a language model.
 
@@ -2361,7 +2510,9 @@ def predict(
             instances=[prediction_request.instance],
             parameters=prediction_request.parameters,
         )
-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )
 
     async def predict_async(
         self,
@@ -2400,7 +2551,9 @@ async def predict_async(
             instances=[prediction_request.instance],
             parameters=prediction_request.parameters,
         )
-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )
 
     def predict_streaming(
         self,
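Because both preview session classes inherit from _ChatSessionBaseWithCountTokensMixin, code chat sessions gain count_tokens as well. A sketch, assuming the same preview imports and default credentials:

from vertexai.preview.language_models import CodeChatModel

code_chat = CodeChatModel.from_pretrained("codechat-bison@001").start_chat()
usage = code_chat.count_tokens("Write a function that reverses a string.")
print(usage.total_tokens, usage.total_billable_characters)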

vertexai/preview/language_models.py
Lines changed: 4 additions & 0 deletions

@@ -16,7 +16,9 @@
 
 from vertexai.language_models._language_models import (
     _PreviewChatModel,
+    _PreviewChatSession,
     _PreviewCodeChatModel,
+    _PreviewCodeChatSession,
     _PreviewCodeGenerationModel,
     _PreviewTextEmbeddingModel,
     _PreviewTextGenerationModel,
@@ -43,7 +45,9 @@
 
 
 ChatModel = _PreviewChatModel
+ChatSession = _PreviewChatSession
 CodeChatModel = _PreviewCodeChatModel
+CodeChatSession = _PreviewCodeChatSession
 CodeGenerationModel = _PreviewCodeGenerationModel
 TextGenerationModel = _PreviewTextGenerationModel
 TextEmbeddingModel = _PreviewTextEmbeddingModel
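With these re-exports, the preview namespace exposes the session classes under public names, which is what the unit tests' isinstance checks rely on. For example:

from vertexai.preview import language_models

chat = language_models.ChatModel.from_pretrained("chat-bison@001").start_chat()
# The preview model returns the preview session class, re-exported
# here as language_models.ChatSession.
assert isinstance(chat, language_models.ChatSession)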

Comments (0)