Skip to content

Commit 0008735

Browse files
Christine Bettscopybara-github
Christine Betts
authored andcommitted
feat: add support for query method in Vertex AI Extension SDK
PiperOrigin-RevId: 662504522
1 parent 659ba3f commit 0008735

File tree

2 files changed

+132
-4
lines changed

2 files changed

+132
-4
lines changed

tests/unit/vertexai/test_extensions.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,13 @@
2323
from google.cloud.aiplatform import initializer
2424
from google.cloud.aiplatform import utils as aip_utils
2525
from google.cloud.aiplatform_v1beta1 import types
26-
from google.cloud.aiplatform_v1beta1.services import extension_execution_service
27-
from google.cloud.aiplatform_v1beta1.services import extension_registry_service
26+
from google.cloud.aiplatform_v1beta1.services import (
27+
extension_execution_service,
28+
)
29+
from google.cloud.aiplatform_v1beta1.services import (
30+
extension_registry_service,
31+
)
32+
from vertexai.generative_models import _generative_models
2833
from vertexai.preview import extensions
2934
from vertexai.reasoning_engines import _utils
3035
import pytest
@@ -180,6 +185,33 @@ def execute_extension_mock():
180185
yield execute_extension_mock
181186

182187

188+
@pytest.fixture
189+
def query_extension_mock():
190+
with mock.patch.object(
191+
extension_execution_service.ExtensionExecutionServiceClient, "query_extension"
192+
) as query_extension_mock:
193+
query_extension_mock.return_value.steps = [
194+
types.Content(
195+
role="user",
196+
parts=[
197+
types.Part(
198+
text=_TEST_QUERY_PROMPT,
199+
)
200+
],
201+
),
202+
types.Content(
203+
role="extension",
204+
parts=[
205+
types.Part(
206+
text=_TEST_RESPONSE_CONTENT,
207+
)
208+
],
209+
),
210+
]
211+
query_extension_mock.return_value.failure_message = ""
212+
yield query_extension_mock
213+
214+
183215
@pytest.fixture
184216
def delete_extension_mock():
185217
with mock.patch.object(
@@ -325,6 +357,49 @@ def test_execute_extension(
325357
),
326358
)
327359

360+
def test_query_extension(
361+
self,
362+
get_extension_mock,
363+
query_extension_mock,
364+
load_yaml_mock,
365+
):
366+
test_extension = extensions.Extension(_TEST_RESOURCE_ID)
367+
get_extension_mock.assert_called_once_with(
368+
name=_TEST_EXTENSION_RESOURCE_NAME,
369+
retry=aiplatform.base._DEFAULT_RETRY,
370+
)
371+
# Manually set _gca_resource here to prevent the mocks from propagating.
372+
test_extension._gca_resource = _TEST_EXTENSION_OBJ
373+
response = test_extension.query(
374+
contents=[
375+
_generative_models.Content(
376+
parts=[
377+
_generative_models.Part.from_text(
378+
_TEST_QUERY_PROMPT,
379+
)
380+
],
381+
role="user",
382+
)
383+
],
384+
)
385+
assert response.steps[-1].parts[0].text == _TEST_RESPONSE_CONTENT
386+
387+
query_extension_mock.assert_called_once_with(
388+
types.QueryExtensionRequest(
389+
name=_TEST_EXTENSION_RESOURCE_NAME,
390+
contents=[
391+
types.Content(
392+
role="user",
393+
parts=[
394+
types.Part(
395+
text=_TEST_QUERY_PROMPT,
396+
)
397+
],
398+
)
399+
],
400+
),
401+
)
402+
328403
def test_api_spec_from_yaml(self, get_extension_mock, load_yaml_mock):
329404
test_extension = extensions.Extension(_TEST_RESOURCE_ID)
330405
get_extension_mock.assert_called_once_with(

vertexai/extensions/_extensions.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
# limitations under the License.
1515
#
1616
import json
17-
from typing import Optional, Sequence, Union
17+
from typing import List, Optional, Sequence, Union
1818

1919
from google.cloud.aiplatform import base
2020
from google.cloud.aiplatform import initializer
2121
from google.cloud.aiplatform import utils as aip_utils
2222
from google.cloud.aiplatform_v1beta1 import types
23+
from vertexai.generative_models import _generative_models
2324
from vertexai.reasoning_engines import _utils
24-
2525
from google.protobuf import struct_pb2
2626

2727
_LOGGER = base.Logger(__name__)
@@ -248,6 +248,36 @@ def execute(
248248
response = self.execution_api_client.execute_extension(request)
249249
return _try_parse_execution_response(response)
250250

251+
def query(
252+
self,
253+
contents: _generative_models.ContentsType,
254+
) -> "QueryExtensionResponse":
255+
"""Queries an extension with the specified contents.
256+
257+
Args:
258+
contents (ContentsType):
259+
Required. The content of the current
260+
conversation with the model.
261+
For single-turn queries, this is a single
262+
instance. For multi-turn queries, this is a
263+
repeated field that contains conversation
264+
history + latest request.
265+
266+
Returns:
267+
The result of querying the extension.
268+
269+
Raises:
270+
RuntimeError: If the response contains an error.
271+
"""
272+
request = types.QueryExtensionRequest(
273+
name=self.resource_name,
274+
contents=_generative_models._content_types_to_gapic_contents(contents),
275+
)
276+
response = self.execution_api_client.query_extension(request)
277+
if response.failure_message:
278+
raise RuntimeError(response.failure_message)
279+
return QueryExtensionResponse._from_gapic(response)
280+
251281
@classmethod
252282
def from_hub(
253283
cls,
@@ -317,6 +347,29 @@ def from_hub(
317347
)
318348

319349

350+
class QueryExtensionResponse:
351+
"""A class representing the response from querying an extension."""
352+
353+
def __init__(self, steps: List[_generative_models.Content]):
354+
"""Initializes the QueryExtensionResponse with the given steps."""
355+
self.steps = steps
356+
357+
@classmethod
358+
def _from_gapic(
359+
cls, response: types.QueryExtensionResponse
360+
) -> "QueryExtensionResponse":
361+
"""Creates a QueryExtensionResponse from a gapic response."""
362+
return cls(
363+
steps=[
364+
_generative_models.Content(
365+
parts=[_generative_models.Part._from_gapic(p) for p in c.parts],
366+
role=c.role,
367+
)
368+
for c in response.steps
369+
]
370+
)
371+
372+
320373
def _try_parse_execution_response(
321374
response: types.ExecuteExtensionResponse,
322375
) -> Union[_utils.JsonDict, str]:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy