@@ -119,6 +119,91 @@ def register_operations(self, **kwargs):
119
119
"""Register the user provided operations (modes and methods)."""
120
120
121
121
122
+ def _wrap_agent_operation (agent : Any , operation : str ):
123
+ def _method (self , ** kwargs ):
124
+ if not self ._tmpl_attrs .get ("agent" ):
125
+ self .set_up ()
126
+ return getattr (self ._tmpl_attrs ["agent" ], operation )(** kwargs )
127
+
128
+ _method .__name__ = operation
129
+ _method .__doc__ = getattr (agent , operation ).__doc__
130
+ return _method
131
+
132
+
133
+ class ModuleAgent (Cloneable , OperationRegistrable ):
134
+ """Agent that is defined by a module and an agent name.
135
+
136
+ This agent is instantiated by importing a module and instantiating an agent
137
+ from that module. It also allows to register operations that are defined in
138
+ the agent.
139
+ """
140
+
141
+ def __init__ (
142
+ self ,
143
+ * ,
144
+ module_name : str ,
145
+ agent_name : str ,
146
+ register_operations : Dict [str , Sequence [str ]],
147
+ ):
148
+ """Initializes a module-based agent.
149
+
150
+ Args:
151
+ module_name (str):
152
+ Required. The name of the module to import.
153
+ agent_name (str):
154
+ Required. The name of the agent in the module to instantiate.
155
+ register_operations (Dict[str, Sequence[str]]):
156
+ Required. A dictionary of API modes to a list of method names.
157
+ """
158
+ self ._tmpl_attrs = {
159
+ "module_name" : module_name ,
160
+ "agent_name" : agent_name ,
161
+ "register_operations" : register_operations ,
162
+ }
163
+
164
+ def clone (self ):
165
+ """Return a clone of the agent."""
166
+ return ModuleAgent (
167
+ module_name = self ._tmpl_attrs .get ("module_name" ),
168
+ agent_name = self ._tmpl_attrs .get ("agent_name" ),
169
+ register_operations = self ._tmpl_attrs .get ("register_operations" ),
170
+ )
171
+
172
+ def register_operations (self ) -> Dict [str , Sequence [str ]]:
173
+ return self ._tmpl_attrs .get ("register_operations" )
174
+
175
+ def set_up (self ) -> None :
176
+ """Sets up the agent for execution of queries at runtime.
177
+
178
+ It runs the code to import the agent from the module, and registers the
179
+ operations of the agent.
180
+ """
181
+ import importlib
182
+
183
+ module = importlib .import_module (self ._tmpl_attrs .get ("module_name" ))
184
+ try :
185
+ importlib .reload (module )
186
+ except Exception as e :
187
+ _LOGGER .warning (
188
+ f"Failed to reload module { self ._tmpl_attrs .get ('module_name' )} : { e } "
189
+ )
190
+ agent_name = self ._tmpl_attrs .get ("agent_name" )
191
+ try :
192
+ agent = getattr (module , agent_name )
193
+ except AttributeError as e :
194
+ raise AttributeError (
195
+ f"Agent { agent_name } not found in module "
196
+ f"{ self ._tmpl_attrs .get ('module_name' )} "
197
+ ) from e
198
+ self ._tmpl_attrs ["agent" ] = agent
199
+ if hasattr (agent , "set_up" ):
200
+ agent .set_up ()
201
+ for operations in self .register_operations ().values ():
202
+ for operation in operations :
203
+ op = _wrap_agent_operation (agent , operation )
204
+ setattr (self , operation , types .MethodType (op , self ))
205
+
206
+
122
207
class AgentEngine (base .VertexAiResourceNounWithFutureManager ):
123
208
"""Represents a Vertex AI Agent Engine resource."""
124
209
@@ -1160,6 +1245,16 @@ def _generate_class_methods_spec_or_raise(
1160
1245
ValueError: If a method defined in `register_operations` is not found on
1161
1246
the AgentEngine.
1162
1247
"""
1248
+ if isinstance (agent_engine , ModuleAgent ):
1249
+ # We do a dry-run of setting up the agent engine to have the operations
1250
+ # needed for registration.
1251
+ agent_engine = agent_engine .clone ()
1252
+ try :
1253
+ agent_engine .set_up ()
1254
+ except Exception as e :
1255
+ raise ValueError (
1256
+ f"Failed to set up agent engine { agent_engine } : { e } "
1257
+ ) from e
1163
1258
class_methods_spec = []
1164
1259
for mode , method_names in operations .items ():
1165
1260
for method_name in method_names :
0 commit comments