Skip to content

Commit 653ba88

Browse files
shawn-yang-googlecopybara-github
authored andcommitted
chore: Dynamic set query method
PiperOrigin-RevId: 698923437
1 parent 58ba55e commit 653ba88

File tree

2 files changed

+106
-46
lines changed

2 files changed

+106
-46
lines changed

tests/unit/vertex_langchain/test_reasoning_engines.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def register_operations(self) -> Dict[str, List[str]]:
146146
_TEST_STANDARD_API_MODE = _reasoning_engines._STANDARD_API_MODE
147147
_TEST_MODE_KEY_IN_SCHEMA = _reasoning_engines._MODE_KEY_IN_SCHEMA
148148
_TEST_DEFAULT_METHOD_NAME = _reasoning_engines._DEFAULT_METHOD_NAME
149+
_TEST_DEFAULT_METHOD_DOCSTRING = _reasoning_engines._DEFAULT_METHOD_DOCSTRING
149150
_TEST_CUSTOM_METHOD_NAME = "custom_method"
150151
_TEST_QUERY_PROMPT = "Find the first fibonacci number greater than 999"
151152
_TEST_REASONING_ENGINE_GCS_URI = "{}/{}/{}".format(
@@ -413,13 +414,6 @@ def query_reasoning_engine_mock():
413414
yield query_reasoning_engine_mock
414415

415416

416-
@pytest.fixture(scope="module")
417-
def to_dict_mock():
418-
with mock.patch.object(_utils, "to_dict") as to_dict_mock:
419-
to_dict_mock.return_value = {}
420-
yield to_dict_mock
421-
422-
423417
# Function scope is required for the pytest parameterized tests.
424418
@pytest.fixture(scope="function")
425419
def types_reasoning_engine_mock():
@@ -853,23 +847,49 @@ def test_delete_after_get_reasoning_engine(
853847
name=test_reasoning_engine.resource_name,
854848
)
855849

850+
def test_query_after_create_reasoning_engine(
851+
self,
852+
get_reasoning_engine_mock,
853+
query_reasoning_engine_mock,
854+
get_gca_resource_mock,
855+
):
856+
test_reasoning_engine = reasoning_engines.ReasoningEngine.create(
857+
self.test_app,
858+
display_name=_TEST_REASONING_ENGINE_DISPLAY_NAME,
859+
requirements=_TEST_REASONING_ENGINE_REQUIREMENTS,
860+
extra_packages=[_TEST_REASONING_ENGINE_EXTRA_PACKAGE_PATH],
861+
)
862+
get_reasoning_engine_mock.assert_called_with(
863+
name=_TEST_REASONING_ENGINE_RESOURCE_NAME,
864+
retry=_TEST_RETRY,
865+
)
866+
with mock.patch.object(_utils, "to_dict") as to_dict_mock:
867+
to_dict_mock.return_value = {}
868+
test_reasoning_engine.query(query=_TEST_QUERY_PROMPT)
869+
assert test_reasoning_engine.query.__doc__ == _TEST_DEFAULT_METHOD_DOCSTRING
870+
query_reasoning_engine_mock.assert_called_with(
871+
request=_TEST_REASONING_ENGINE_QUERY_REQUEST
872+
)
873+
to_dict_mock.assert_called_once()
874+
856875
def test_query_reasoning_engine(
857876
self,
858877
get_reasoning_engine_mock,
859878
query_reasoning_engine_mock,
860-
to_dict_mock,
861879
get_gca_resource_mock,
862880
):
863881
test_reasoning_engine = reasoning_engines.ReasoningEngine(_TEST_RESOURCE_ID)
864882
get_reasoning_engine_mock.assert_called_with(
865883
name=_TEST_REASONING_ENGINE_RESOURCE_NAME,
866884
retry=_TEST_RETRY,
867885
)
868-
test_reasoning_engine.query(query=_TEST_QUERY_PROMPT)
869-
query_reasoning_engine_mock.assert_called_with(
870-
request=_TEST_REASONING_ENGINE_QUERY_REQUEST
871-
)
872-
to_dict_mock.assert_called_once()
886+
with mock.patch.object(_utils, "to_dict") as to_dict_mock:
887+
to_dict_mock.return_value = {}
888+
test_reasoning_engine.query(query=_TEST_QUERY_PROMPT)
889+
query_reasoning_engine_mock.assert_called_with(
890+
request=_TEST_REASONING_ENGINE_QUERY_REQUEST
891+
)
892+
to_dict_mock.assert_called_once()
873893

874894
def test_operation_schemas(
875895
self,

vertexai/reasoning_engines/_reasoning_engines.py

Lines changed: 73 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
import os
2020
import sys
2121
import tarfile
22+
import types
2223
import typing
23-
from typing import Any, Dict, List, Optional, Protocol, Sequence, Union
24+
from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Union
2425

2526
import proto
2627

@@ -29,7 +30,7 @@
2930
from google.cloud.aiplatform import base
3031
from google.cloud.aiplatform import initializer
3132
from google.cloud.aiplatform import utils as aip_utils
32-
from google.cloud.aiplatform_v1beta1 import types
33+
from google.cloud.aiplatform_v1beta1 import types as aip_types
3334
from google.cloud.aiplatform_v1beta1.types import reasoning_engine_service
3435
from vertexai.reasoning_engines import _utils
3536
from google.protobuf import field_mask_pb2
@@ -44,6 +45,19 @@
4445
_STANDARD_API_MODE = ""
4546
_MODE_KEY_IN_SCHEMA = "api_mode"
4647
_DEFAULT_METHOD_NAME = "query"
48+
_DEFAULT_METHOD_DOCSTRING = """
49+
Runs the Reasoning Engine to serve the user query.
50+
51+
This will be based on the `.query(...)` method of the python object that
52+
was passed in when creating the Reasoning Engine.
53+
54+
Args:
55+
**kwargs:
56+
Optional. The arguments of the `.query(...)` method.
57+
58+
Returns:
59+
dict[str, Any]: The response from serving the user query.
60+
"""
4761

4862

4963
@typing.runtime_checkable
@@ -73,7 +87,7 @@ def register_operations(self, **kwargs):
7387
"""Register the user provided operations (modes and methods)."""
7488

7589

76-
class ReasoningEngine(base.VertexAiResourceNounWithFutureManager, Queryable):
90+
class ReasoningEngine(base.VertexAiResourceNounWithFutureManager):
7791
"""Represents a Vertex AI Reasoning Engine resource."""
7892

7993
client_class = aip_utils.ReasoningEngineClientWithOverride
@@ -98,6 +112,7 @@ def __init__(self, reasoning_engine_name: str):
98112
client_class=aip_utils.ReasoningEngineExecutionClientWithOverride,
99113
)
100114
self._gca_resource = self._get_gca_resource(resource_name=reasoning_engine_name)
115+
_register_api_method(self)
101116
self._operation_schemas = None
102117

103118
@property
@@ -233,7 +248,7 @@ def create(
233248
extra_packages=extra_packages,
234249
)
235250
# Update the package spec.
236-
package_spec = types.ReasoningEngineSpec.PackageSpec(
251+
package_spec = aip_types.ReasoningEngineSpec.PackageSpec(
237252
python_version=sys_version,
238253
pickle_object_gcs_uri="{}/{}/{}".format(
239254
staging_bucket,
@@ -253,7 +268,7 @@ def create(
253268
gcs_dir_name,
254269
_REQUIREMENTS_FILE,
255270
)
256-
reasoning_engine_spec = types.ReasoningEngineSpec(
271+
reasoning_engine_spec = aip_types.ReasoningEngineSpec(
257272
package_spec=package_spec,
258273
)
259274
class_methods_spec = _generate_class_methods_spec_or_raise(
@@ -264,7 +279,7 @@ def create(
264279
parent=initializer.global_config.common_location_path(
265280
project=sdk_resource.project, location=sdk_resource.location
266281
),
267-
reasoning_engine=types.ReasoningEngine(
282+
reasoning_engine=aip_types.ReasoningEngine(
268283
name=reasoning_engine_name,
269284
display_name=display_name,
270285
description=description,
@@ -289,6 +304,7 @@ def create(
289304
credentials=sdk_resource.credentials,
290305
location_override=sdk_resource.location,
291306
)
307+
_register_api_method(sdk_resource)
292308
sdk_resource._operation_schemas = None
293309
return sdk_resource
294310

@@ -431,30 +447,6 @@ def operation_schemas(self) -> Sequence[_utils.JsonDict]:
431447
self._operation_schemas = spec.get("class_methods", [])
432448
return self._operation_schemas
433449

434-
def query(self, **kwargs) -> _utils.JsonDict:
435-
"""Runs the Reasoning Engine to serve the user query.
436-
437-
This will be based on the `.query(...)` method of the python object that
438-
was passed in when creating the Reasoning Engine.
439-
440-
Args:
441-
**kwargs:
442-
Optional. The arguments of the `.query(...)` method.
443-
444-
Returns:
445-
dict[str, Any]: The response from serving the user query.
446-
"""
447-
response = self.execution_api_client.query_reasoning_engine(
448-
request=types.QueryReasoningEngineRequest(
449-
name=self.resource_name,
450-
input=kwargs,
451-
),
452-
)
453-
output = _utils.to_dict(response)
454-
if "output" in output:
455-
return output.get("output")
456-
return output
457-
458450

459451
def _validate_sys_version_or_raise(sys_version: str) -> None:
460452
"""Tries to validate the python system version."""
@@ -630,8 +622,8 @@ def _generate_update_request_or_raise(
630622
"""Tries to generates the update request for the reasoning engine."""
631623
is_spec_update = False
632624
update_masks: List[str] = []
633-
reasoning_engine_spec = types.ReasoningEngineSpec()
634-
package_spec = types.ReasoningEngineSpec.PackageSpec()
625+
reasoning_engine_spec = aip_types.ReasoningEngineSpec()
626+
package_spec = aip_types.ReasoningEngineSpec.PackageSpec()
635627
if requirements is not None:
636628
is_spec_update = True
637629
update_masks.append("spec.package_spec.requirements_gcs_uri")
@@ -662,7 +654,7 @@ def _generate_update_request_or_raise(
662654
reasoning_engine_spec.class_methods.extend(class_methods_spec)
663655
update_masks.append("spec.class_methods")
664656

665-
reasoning_engine_message = types.ReasoningEngine(name=resource_name)
657+
reasoning_engine_message = aip_types.ReasoningEngine(name=resource_name)
666658
if is_spec_update:
667659
reasoning_engine_spec.package_spec = package_spec
668660
reasoning_engine_message.spec = reasoning_engine_spec
@@ -684,6 +676,54 @@ def _generate_update_request_or_raise(
684676
)
685677

686678

679+
def _wrap_query_operation(method_name: str, doc: str) -> Callable[..., _utils.JsonDict]:
680+
"""Wraps a Reasoning Engine method, creating a callable for `query` API.
681+
682+
This function creates a callable object that executes the specified
683+
Reasoning Engine method using the `query` API. It handles the creation of
684+
the API request and the processing of the API response.
685+
686+
Args:
687+
method_name: The name of the Reasoning Engine method to call.
688+
doc: Documentation string for the method.
689+
690+
Returns:
691+
A callable object that executes the method on the Reasoning Engine via
692+
the `query` API.
693+
"""
694+
695+
def _method(self, **kwargs) -> _utils.JsonDict:
696+
response = self.execution_api_client.query_reasoning_engine(
697+
request=aip_types.QueryReasoningEngineRequest(
698+
name=self.resource_name,
699+
input=kwargs,
700+
),
701+
)
702+
output = _utils.to_dict(response)
703+
return output.get("output", output)
704+
705+
_method.__name__ = method_name
706+
_method.__doc__ = doc
707+
708+
return _method
709+
710+
711+
def _register_api_method(obj: "ReasoningEngine"):
712+
"""Registers Reasoning Engine API methods based on operation schemas.
713+
714+
This function registers `query` method on the ReasoningEngine object
715+
to handle API calls based on the specified API mode.
716+
717+
Args:
718+
obj: The ReasoningEngine object to augment with API methods.
719+
"""
720+
query_method = _wrap_query_operation(
721+
method_name=_DEFAULT_METHOD_NAME, doc=_DEFAULT_METHOD_DOCSTRING
722+
)
723+
# Binds the method to the object.
724+
setattr(obj, _DEFAULT_METHOD_NAME, types.MethodType(query_method, obj))
725+
726+
687727
def _get_registered_operations(reasoning_engine: Any) -> Dict[str, List[str]]:
688728
"""Retrieves registered operations for a ReasoningEngine."""
689729
if isinstance(reasoning_engine, OperationRegistrable):

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy