19
19
import os
20
20
import sys
21
21
import tarfile
22
+ import types
22
23
import typing
23
- from typing import Any , Dict , List , Optional , Protocol , Sequence , Union
24
+ from typing import Any , Callable , Dict , List , Optional , Protocol , Sequence , Union
24
25
25
26
import proto
26
27
29
30
from google .cloud .aiplatform import base
30
31
from google .cloud .aiplatform import initializer
31
32
from google .cloud .aiplatform import utils as aip_utils
32
- from google .cloud .aiplatform_v1beta1 import types
33
+ from google .cloud .aiplatform_v1beta1 import types as aip_types
33
34
from google .cloud .aiplatform_v1beta1 .types import reasoning_engine_service
34
35
from vertexai .reasoning_engines import _utils
35
36
from google .protobuf import field_mask_pb2
44
45
_STANDARD_API_MODE = ""
45
46
_MODE_KEY_IN_SCHEMA = "api_mode"
46
47
_DEFAULT_METHOD_NAME = "query"
48
+ _DEFAULT_METHOD_DOCSTRING = """
49
+ Runs the Reasoning Engine to serve the user query.
50
+
51
+ This will be based on the `.query(...)` method of the python object that
52
+ was passed in when creating the Reasoning Engine.
53
+
54
+ Args:
55
+ **kwargs:
56
+ Optional. The arguments of the `.query(...)` method.
57
+
58
+ Returns:
59
+ dict[str, Any]: The response from serving the user query.
60
+ """
47
61
48
62
49
63
@typing .runtime_checkable
@@ -73,7 +87,7 @@ def register_operations(self, **kwargs):
73
87
"""Register the user provided operations (modes and methods)."""
74
88
75
89
76
- class ReasoningEngine (base .VertexAiResourceNounWithFutureManager , Queryable ):
90
+ class ReasoningEngine (base .VertexAiResourceNounWithFutureManager ):
77
91
"""Represents a Vertex AI Reasoning Engine resource."""
78
92
79
93
client_class = aip_utils .ReasoningEngineClientWithOverride
@@ -98,6 +112,7 @@ def __init__(self, reasoning_engine_name: str):
98
112
client_class = aip_utils .ReasoningEngineExecutionClientWithOverride ,
99
113
)
100
114
self ._gca_resource = self ._get_gca_resource (resource_name = reasoning_engine_name )
115
+ _register_api_method (self )
101
116
self ._operation_schemas = None
102
117
103
118
@property
@@ -233,7 +248,7 @@ def create(
233
248
extra_packages = extra_packages ,
234
249
)
235
250
# Update the package spec.
236
- package_spec = types .ReasoningEngineSpec .PackageSpec (
251
+ package_spec = aip_types .ReasoningEngineSpec .PackageSpec (
237
252
python_version = sys_version ,
238
253
pickle_object_gcs_uri = "{}/{}/{}" .format (
239
254
staging_bucket ,
@@ -253,7 +268,7 @@ def create(
253
268
gcs_dir_name ,
254
269
_REQUIREMENTS_FILE ,
255
270
)
256
- reasoning_engine_spec = types .ReasoningEngineSpec (
271
+ reasoning_engine_spec = aip_types .ReasoningEngineSpec (
257
272
package_spec = package_spec ,
258
273
)
259
274
class_methods_spec = _generate_class_methods_spec_or_raise (
@@ -264,7 +279,7 @@ def create(
264
279
parent = initializer .global_config .common_location_path (
265
280
project = sdk_resource .project , location = sdk_resource .location
266
281
),
267
- reasoning_engine = types .ReasoningEngine (
282
+ reasoning_engine = aip_types .ReasoningEngine (
268
283
name = reasoning_engine_name ,
269
284
display_name = display_name ,
270
285
description = description ,
@@ -289,6 +304,7 @@ def create(
289
304
credentials = sdk_resource .credentials ,
290
305
location_override = sdk_resource .location ,
291
306
)
307
+ _register_api_method (sdk_resource )
292
308
sdk_resource ._operation_schemas = None
293
309
return sdk_resource
294
310
@@ -431,30 +447,6 @@ def operation_schemas(self) -> Sequence[_utils.JsonDict]:
431
447
self ._operation_schemas = spec .get ("class_methods" , [])
432
448
return self ._operation_schemas
433
449
434
- def query (self , ** kwargs ) -> _utils .JsonDict :
435
- """Runs the Reasoning Engine to serve the user query.
436
-
437
- This will be based on the `.query(...)` method of the python object that
438
- was passed in when creating the Reasoning Engine.
439
-
440
- Args:
441
- **kwargs:
442
- Optional. The arguments of the `.query(...)` method.
443
-
444
- Returns:
445
- dict[str, Any]: The response from serving the user query.
446
- """
447
- response = self .execution_api_client .query_reasoning_engine (
448
- request = types .QueryReasoningEngineRequest (
449
- name = self .resource_name ,
450
- input = kwargs ,
451
- ),
452
- )
453
- output = _utils .to_dict (response )
454
- if "output" in output :
455
- return output .get ("output" )
456
- return output
457
-
458
450
459
451
def _validate_sys_version_or_raise (sys_version : str ) -> None :
460
452
"""Tries to validate the python system version."""
@@ -630,8 +622,8 @@ def _generate_update_request_or_raise(
630
622
"""Tries to generates the update request for the reasoning engine."""
631
623
is_spec_update = False
632
624
update_masks : List [str ] = []
633
- reasoning_engine_spec = types .ReasoningEngineSpec ()
634
- package_spec = types .ReasoningEngineSpec .PackageSpec ()
625
+ reasoning_engine_spec = aip_types .ReasoningEngineSpec ()
626
+ package_spec = aip_types .ReasoningEngineSpec .PackageSpec ()
635
627
if requirements is not None :
636
628
is_spec_update = True
637
629
update_masks .append ("spec.package_spec.requirements_gcs_uri" )
@@ -662,7 +654,7 @@ def _generate_update_request_or_raise(
662
654
reasoning_engine_spec .class_methods .extend (class_methods_spec )
663
655
update_masks .append ("spec.class_methods" )
664
656
665
- reasoning_engine_message = types .ReasoningEngine (name = resource_name )
657
+ reasoning_engine_message = aip_types .ReasoningEngine (name = resource_name )
666
658
if is_spec_update :
667
659
reasoning_engine_spec .package_spec = package_spec
668
660
reasoning_engine_message .spec = reasoning_engine_spec
@@ -684,6 +676,54 @@ def _generate_update_request_or_raise(
684
676
)
685
677
686
678
679
+ def _wrap_query_operation (method_name : str , doc : str ) -> Callable [..., _utils .JsonDict ]:
680
+ """Wraps a Reasoning Engine method, creating a callable for `query` API.
681
+
682
+ This function creates a callable object that executes the specified
683
+ Reasoning Engine method using the `query` API. It handles the creation of
684
+ the API request and the processing of the API response.
685
+
686
+ Args:
687
+ method_name: The name of the Reasoning Engine method to call.
688
+ doc: Documentation string for the method.
689
+
690
+ Returns:
691
+ A callable object that executes the method on the Reasoning Engine via
692
+ the `query` API.
693
+ """
694
+
695
+ def _method (self , ** kwargs ) -> _utils .JsonDict :
696
+ response = self .execution_api_client .query_reasoning_engine (
697
+ request = aip_types .QueryReasoningEngineRequest (
698
+ name = self .resource_name ,
699
+ input = kwargs ,
700
+ ),
701
+ )
702
+ output = _utils .to_dict (response )
703
+ return output .get ("output" , output )
704
+
705
+ _method .__name__ = method_name
706
+ _method .__doc__ = doc
707
+
708
+ return _method
709
+
710
+
711
+ def _register_api_method (obj : "ReasoningEngine" ):
712
+ """Registers Reasoning Engine API methods based on operation schemas.
713
+
714
+ This function registers `query` method on the ReasoningEngine object
715
+ to handle API calls based on the specified API mode.
716
+
717
+ Args:
718
+ obj: The ReasoningEngine object to augment with API methods.
719
+ """
720
+ query_method = _wrap_query_operation (
721
+ method_name = _DEFAULT_METHOD_NAME , doc = _DEFAULT_METHOD_DOCSTRING
722
+ )
723
+ # Binds the method to the object.
724
+ setattr (obj , _DEFAULT_METHOD_NAME , types .MethodType (query_method , obj ))
725
+
726
+
687
727
def _get_registered_operations (reasoning_engine : Any ) -> Dict [str , List [str ]]:
688
728
"""Retrieves registered operations for a ReasoningEngine."""
689
729
if isinstance (reasoning_engine , OperationRegistrable ):
0 commit comments