
Commit 0359f1d

Ark-kun authored and copybara-github committed
feat: LLM - Support streaming prediction for code chat models
PiperOrigin-RevId: 558364254
1 parent 3a8348b commit 0359f1d
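
For context, the feature this commit lands is a `send_message_streaming` method on code chat sessions. A minimal usage sketch (not part of the commit; it assumes an authenticated environment, and the project/location values are placeholders):

# Minimal usage sketch for the streaming code-chat API added here.
# Not part of this commit; "my-project" and "us-central1" are placeholders.
from google.cloud import aiplatform
from vertexai.language_models import CodeChatModel

aiplatform.init(project="my-project", location="us-central1")

chat = CodeChatModel.from_pretrained("codechat-bison@001").start_chat()

# Each iteration yields a TextGenerationResponse with a partial chunk;
# the exchange is added to chat.message_history only once the stream
# has been fully consumed.
for response in chat.send_message_streaming(
    "Please help write a function to calculate the max of two numbers"
):
    print(response.text, end="")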

File tree

3 files changed: +85 additions, −8 deletions


tests/system/aiplatform/test_language_models.py

Lines changed: 12 additions & 1 deletion
@@ -260,8 +260,19 @@ def test_code_generation_streaming(self):
 
         for response in model.predict_streaming(
             prefix="def reverse_string(s):",
-            suffix=" return s",
+            # code-bison does not support suffix
+            # suffix=" return s",
             max_output_tokens=128,
             temperature=0,
         ):
             assert response.text
+
+    def test_code_chat_model_send_message_streaming(self):
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        chat_model = language_models.CodeChatModel.from_pretrained("codechat-bison@001")
+        chat = chat_model.start_chat()
+
+        message1 = "Please help write a function to calculate the max of two numbers"
+        for response in chat.send_message_streaming(message1):
+            assert response.text
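
Consumers of these streaming APIs typically accumulate the partial chunks into the final text. A small sketch under the same assumptions as the system test above (`model` is a code generation model such as code-bison, which accepts a prefix but no suffix):

# Sketch: collecting streamed chunks into one completion string.
# Assumes `model` was created as in the test above, e.g.
# model = language_models.CodeGenerationModel.from_pretrained("code-bison@001")
chunks = []
for response in model.predict_streaming(
    prefix="def reverse_string(s):",
    max_output_tokens=128,
    temperature=0,
):
    chunks.append(response.text)
completion = "".join(chunks)
print(completion)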

tests/unit/aiplatform/test_language_models.py

Lines changed: 45 additions & 0 deletions
@@ -1938,6 +1938,51 @@ def test_code_chat(self):
         assert prediction_parameters["temperature"] == message_temperature
         assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
 
+    def test_code_chat_model_send_message_streaming(self):
+        """Tests the chat generation model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CODECHAT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = language_models.CodeChatModel.from_pretrained("codechat-bison@001")
+
+        chat = model.start_chat(temperature=0.0)
+
+        # Using list instead of a generator so that it can be reused.
+        response_generator = [
+            gca_prediction_service.StreamingPredictResponse(
+                outputs=[_streaming_prediction.value_to_tensor(response_dict)]
+            )
+            for response_dict in _TEST_CHAT_PREDICTION_STREAMING
+        ]
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="server_streaming_predict",
+            return_value=response_generator,
+        ):
+            message_text1 = (
+                "Please help write a function to calculate the max of two numbers"
+            )
+            # New messages are not added until the response is fully read
+            assert not chat.message_history
+            for response in chat.send_message_streaming(message_text1):
+                assert len(response.text) > 10
+            # New messages are only added after the response is fully read
+            assert chat.message_history
+
+        assert len(chat.message_history) == 2
+        assert chat.message_history[0].author == chat.USER_AUTHOR
+        assert chat.message_history[0].content == message_text1
+        assert chat.message_history[1].author == chat.MODEL_AUTHOR
+
     def test_code_generation(self):
         """Tests code generation with the code generation model."""
         aiplatform.init(
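
One detail worth calling out in this test: the mocked `server_streaming_predict` returns a list, not a generator. A generator would be exhausted after the first pass, so any second iteration through the patched client would see an empty stream. A self-contained illustration (the chunk contents here are hypothetical):

# Why the test builds a list rather than a generator for the mocked
# streaming responses: a generator can only be consumed once.
def make_chunks():
    yield {"content": "def max_of_two(a, b):"}
    yield {"content": "    return max(a, b)"}

gen = make_chunks()
assert len(list(gen)) == 2  # first pass consumes the generator
assert len(list(gen)) == 0  # second pass yields nothing

chunks = list(make_chunks())
assert len(list(chunks)) == 2  # a list can be iterated again
assert len(list(chunks)) == 2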

vertexai/language_models/_language_models.py

Lines changed: 28 additions & 7 deletions
@@ -59,13 +59,6 @@ def _get_model_id_from_tuning_model_id(tuning_model_id: str) -> str:
     return f"publishers/google/models/{model_name}@{version}"
 
 
-@dataclasses.dataclass
-class _PredictionRequest:
-    """A single-instance prediction request."""
-    instance: Dict[str, Any]
-    parameters: Optional[Dict[str, Any]] = None
-
-
 class _LanguageModel(_model_garden_models._ModelGardenModel):
     """_LanguageModel is a base class for all language models."""
 
@@ -1234,6 +1227,34 @@ def send_message(
             temperature=temperature,
         )
 
+    def send_message_streaming(
+        self,
+        message: str,
+        *,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+    ) -> Iterator[TextGenerationResponse]:
+        """Sends message to the language model and gets a streamed response.
+
+        The response is only added to the history once it's fully read.
+
+        Args:
+            message: Message to send to the model
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+
+        Returns:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        return super().send_message_streaming(
+            message=message,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+        )
+
 
 class CodeGenerationModel(_LanguageModel):
     """A language model that generates code.
