32
32
Literal ,
33
33
Optional ,
34
34
Sequence ,
35
+ Type ,
36
+ TypeVar ,
35
37
Union ,
36
38
overload ,
37
39
TYPE_CHECKING ,
66
68
except ImportError :
67
69
PIL_Image = None
68
70
71
+
72
+ T = TypeVar ("T" )
73
+
74
+
69
75
# Re-exporting some GAPIC types
70
76
71
77
# GAPIC types used in request
@@ -1627,8 +1633,9 @@ def __init__(
1627
1633
if response_schema is None :
1628
1634
raw_schema = None
1629
1635
else :
1630
- gapic_schema_dict = _convert_schema_dict_to_gapic (response_schema )
1631
- raw_schema = aiplatform_types .Schema (gapic_schema_dict )
1636
+ raw_schema = FunctionDeclaration (
1637
+ name = "tmp" , parameters = response_schema
1638
+ )._raw_function_declaration .parameters
1632
1639
self ._raw_generation_config = gapic_content_types .GenerationConfig (
1633
1640
temperature = temperature ,
1634
1641
top_p = top_p ,
@@ -1660,8 +1667,12 @@ def _from_gapic(
1660
1667
1661
1668
@classmethod
1662
1669
def from_dict (cls , generation_config_dict : Dict [str , Any ]) -> "GenerationConfig" :
1663
- raw_generation_config = gapic_content_types .GenerationConfig (
1664
- generation_config_dict
1670
+ generation_config_dict = copy .deepcopy (generation_config_dict )
1671
+ response_schema = generation_config_dict .get ("response_schema" )
1672
+ if response_schema :
1673
+ _fix_schema_dict_for_gapic_in_place (response_schema )
1674
+ raw_generation_config = _dict_to_proto (
1675
+ gapic_content_types .GenerationConfig , generation_config_dict
1665
1676
)
1666
1677
return cls ._from_gapic (raw_generation_config = raw_generation_config )
1667
1678
@@ -1872,12 +1883,11 @@ def _from_gapic(
1872
1883
@classmethod
1873
1884
def from_dict (cls , tool_dict : Dict [str , Any ]) -> "Tool" :
1874
1885
tool_dict = copy .deepcopy (tool_dict )
1875
- function_declarations = tool_dict ["function_declarations" ]
1876
- for function_declaration in function_declarations :
1877
- function_declaration ["parameters" ] = _convert_schema_dict_to_gapic (
1878
- function_declaration ["parameters" ]
1879
- )
1880
- raw_tool = gapic_tool_types .Tool (tool_dict )
1886
+ for function_declaration in tool_dict .get ("function_declarations" ) or []:
1887
+ parameters = function_declaration .get ("parameters" )
1888
+ if parameters :
1889
+ _fix_schema_dict_for_gapic_in_place (parameters )
1890
+ raw_tool = _dict_to_proto (aiplatform_types .Tool , tool_dict )
1881
1891
return cls ._from_gapic (raw_tool = raw_tool )
1882
1892
1883
1893
def to_dict (self ) -> Dict [str , Any ]:
@@ -2035,8 +2045,9 @@ def __init__(
2035
2045
description: Description and purpose of the function.
2036
2046
Model uses it to decide how and whether to call the function.
2037
2047
"""
2038
- gapic_schema_dict = _convert_schema_dict_to_gapic (parameters )
2039
- raw_schema = aiplatform_types .Schema (gapic_schema_dict )
2048
+ parameters = copy .deepcopy (parameters )
2049
+ _fix_schema_dict_for_gapic_in_place (parameters )
2050
+ raw_schema = _dict_to_proto (aiplatform_types .Schema , parameters )
2040
2051
self ._raw_function_declaration = gapic_tool_types .FunctionDeclaration (
2041
2052
name = name , description = description , parameters = raw_schema
2042
2053
)
@@ -2052,6 +2063,7 @@ def __repr__(self) -> str:
2052
2063
return self ._raw_function_declaration .__repr__ ()
2053
2064
2054
2065
2066
+ # TODO: Remove this function once Reasoning Engines moves away from it.
2055
2067
def _convert_schema_dict_to_gapic (schema_dict : Dict [str , Any ]) -> Dict [str , Any ]:
2056
2068
"""Converts a JsonSchema to a dict that the GAPIC Schema class accepts."""
2057
2069
gapic_schema_dict = copy .copy (schema_dict )
@@ -2070,6 +2082,20 @@ def _convert_schema_dict_to_gapic(schema_dict: Dict[str, Any]) -> Dict[str, Any]
2070
2082
return gapic_schema_dict
2071
2083
2072
2084
2085
+ def _fix_schema_dict_for_gapic_in_place (schema_dict : Dict [str , Any ]) -> None :
2086
+ """Converts a JsonSchema to a dict that the Schema proto class accepts."""
2087
+ schema_dict ["type" ] = schema_dict ["type" ].upper ()
2088
+
2089
+ items_schema = schema_dict .get ("items" )
2090
+ if items_schema :
2091
+ _fix_schema_dict_for_gapic_in_place (items_schema )
2092
+
2093
+ properties = schema_dict .get ("properties" )
2094
+ if properties :
2095
+ for property_schema in properties .values ():
2096
+ _fix_schema_dict_for_gapic_in_place (property_schema )
2097
+
2098
+
2073
2099
class CallableFunctionDeclaration (FunctionDeclaration ):
2074
2100
"""A function declaration plus a function."""
2075
2101
@@ -2139,8 +2165,9 @@ def _from_gapic(
2139
2165
2140
2166
@classmethod
2141
2167
def from_dict (cls , response_dict : Dict [str , Any ]) -> "GenerationResponse" :
2142
- raw_response = gapic_prediction_service_types .GenerateContentResponse ()
2143
- json_format .ParseDict (response_dict , raw_response ._pb )
2168
+ raw_response = _dict_to_proto (
2169
+ gapic_prediction_service_types .GenerateContentResponse , response_dict
2170
+ )
2144
2171
return cls ._from_gapic (raw_response = raw_response )
2145
2172
2146
2173
def to_dict (self ) -> Dict [str , Any ]:
@@ -2209,8 +2236,7 @@ def _from_gapic(cls, raw_candidate: gapic_content_types.Candidate) -> "Candidate
2209
2236
2210
2237
@classmethod
2211
2238
def from_dict (cls , candidate_dict : Dict [str , Any ]) -> "Candidate" :
2212
- raw_candidate = gapic_content_types .Candidate ()
2213
- json_format .ParseDict (candidate_dict , raw_candidate ._pb )
2239
+ raw_candidate = _dict_to_proto (gapic_content_types .Candidate , candidate_dict )
2214
2240
return cls ._from_gapic (raw_candidate = raw_candidate )
2215
2241
2216
2242
def to_dict (self ) -> Dict [str , Any ]:
@@ -2310,8 +2336,7 @@ def _from_gapic(cls, raw_content: gapic_content_types.Content) -> "Content":
2310
2336
2311
2337
@classmethod
2312
2338
def from_dict (cls , content_dict : Dict [str , Any ]) -> "Content" :
2313
- raw_content = gapic_content_types .Content ()
2314
- json_format .ParseDict (content_dict , raw_content ._pb )
2339
+ raw_content = _dict_to_proto (gapic_content_types .Content , content_dict )
2315
2340
return cls ._from_gapic (raw_content = raw_content )
2316
2341
2317
2342
def to_dict (self ) -> Dict [str , Any ]:
@@ -2381,8 +2406,7 @@ def _from_gapic(cls, raw_part: gapic_content_types.Part) -> "Part":
2381
2406
2382
2407
@classmethod
2383
2408
def from_dict (cls , part_dict : Dict [str , Any ]) -> "Part" :
2384
- raw_part = gapic_content_types .Part ()
2385
- json_format .ParseDict (part_dict , raw_part ._pb )
2409
+ raw_part = _dict_to_proto (gapic_content_types .Part , part_dict )
2386
2410
return cls ._from_gapic (raw_part = raw_part )
2387
2411
2388
2412
def __repr__ (self ) -> str :
@@ -2510,7 +2534,9 @@ def _from_gapic(
2510
2534
2511
2535
@classmethod
2512
2536
def from_dict (cls , safety_setting_dict : Dict [str , Any ]) -> "SafetySetting" :
2513
- raw_safety_setting = gapic_content_types .SafetySetting (safety_setting_dict )
2537
+ raw_safety_setting = _dict_to_proto (
2538
+ aiplatform_types .SafetySetting , safety_setting_dict
2539
+ )
2514
2540
return cls ._from_gapic (raw_safety_setting = raw_safety_setting )
2515
2541
2516
2542
def to_dict (self ) -> Dict [str , Any ]:
@@ -2760,6 +2786,15 @@ def _proto_to_dict(message) -> Dict[str, Any]:
2760
2786
)
2761
2787
2762
2788
2789
+ def _dict_to_proto (message_type : Type [T ], message_dict : Dict [str , Any ]) -> T :
2790
+ """Converts a dictionary to a proto-plus protobuf message."""
2791
+ # We cannot just use `message = message_type(message_dict)` because
2792
+ # it fails for classes where GAPIC has renamed proto fields.
2793
+ message = message_type ()
2794
+ json_format .ParseDict (message_dict , message ._pb )
2795
+ return message
2796
+
2797
+
2763
2798
def _dict_to_pretty_string (d : dict ) -> str :
2764
2799
"""Format dict as a pretty-printed JSON string."""
2765
2800
return json .dumps (d , indent = 2 )
0 commit comments