@@ -274,6 +274,8 @@ def register_operations(self) -> Dict[str, List[str]]:
274
274
"""
275
275
_TEST_METHOD_TO_BE_UNREGISTERED_NAME = "method_to_be_unregistered"
276
276
_TEST_QUERY_PROMPT = "Find the first fibonacci number greater than 999"
277
+ _TEST_AGENT_ENGINE_ENV_KEY = "GOOGLE_CLOUD_AGENT_ENGINE_ENV"
278
+ _TEST_AGENT_ENGINE_ENV_VALUE = "test_env_value"
277
279
_TEST_AGENT_ENGINE_GCS_URI = "{}/{}/{}" .format (
278
280
_TEST_STAGING_BUCKET ,
279
281
_TEST_GCS_DIR_NAME ,
@@ -673,6 +675,8 @@ class TestAgentEngine:
673
675
def setup_method (self ):
674
676
importlib .reload (initializer )
675
677
importlib .reload (aiplatform )
678
+ importlib .reload (os )
679
+ os .environ [_TEST_AGENT_ENGINE_ENV_KEY ] = _TEST_AGENT_ENGINE_ENV_VALUE
676
680
aiplatform .init (
677
681
project = _TEST_PROJECT ,
678
682
location = _TEST_LOCATION ,
@@ -801,6 +805,119 @@ def test_create_agent_engine_requirements_from_file(
801
805
retry = _TEST_RETRY ,
802
806
)
803
807
808
+ def test_create_agent_engine_with_env_vars_dict (
809
+ self ,
810
+ create_agent_engine_mock ,
811
+ cloud_storage_create_bucket_mock ,
812
+ tarfile_open_mock ,
813
+ cloudpickle_dump_mock ,
814
+ cloudpickle_load_mock ,
815
+ importlib_metadata_version_mock ,
816
+ get_agent_engine_mock ,
817
+ get_gca_resource_mock ,
818
+ ):
819
+ agent_engines .create (
820
+ self .test_agent ,
821
+ display_name = _TEST_AGENT_ENGINE_DISPLAY_NAME ,
822
+ requirements = _TEST_AGENT_ENGINE_REQUIREMENTS ,
823
+ extra_packages = [_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH ],
824
+ env_vars = {
825
+ "TEST_ENV_VAR" : "TEST_ENV_VAR_VALUE" ,
826
+ "TEST_ENV_VAR_2" : "TEST_ENV_VAR_VALUE_2" ,
827
+ "TEST_SECRET_ENV_VAR" : {
828
+ "secret" : "TEST_SECRET_NAME_1" ,
829
+ "version" : "TEST_SECRET_VERSION_1" ,
830
+ },
831
+ "TEST_SECRET_ENV_VAR_2" : types .SecretRef (
832
+ secret = "TEST_SECRET_NAME_2" ,
833
+ version = "TEST_SECRET_VERSION_2" ,
834
+ ),
835
+ },
836
+ )
837
+ test_spec = types .ReasoningEngineSpec (
838
+ package_spec = _TEST_AGENT_ENGINE_PACKAGE_SPEC ,
839
+ deployment_spec = types .ReasoningEngineSpec .DeploymentSpec (
840
+ env = [
841
+ types .EnvVar (name = "TEST_ENV_VAR" , value = "TEST_ENV_VAR_VALUE" ),
842
+ types .EnvVar (name = "TEST_ENV_VAR_2" , value = "TEST_ENV_VAR_VALUE_2" ),
843
+ ],
844
+ secret_env = [
845
+ types .SecretEnvVar (
846
+ name = "TEST_SECRET_ENV_VAR" ,
847
+ secret_ref = {
848
+ "secret" : "TEST_SECRET_NAME_1" ,
849
+ "version" : "TEST_SECRET_VERSION_1" ,
850
+ },
851
+ ),
852
+ types .SecretEnvVar (
853
+ name = "TEST_SECRET_ENV_VAR_2" ,
854
+ secret_ref = types .SecretRef (
855
+ secret = "TEST_SECRET_NAME_2" ,
856
+ version = "TEST_SECRET_VERSION_2" ,
857
+ ),
858
+ ),
859
+ ],
860
+ ),
861
+ )
862
+ test_spec .class_methods .append (_TEST_AGENT_ENGINE_QUERY_SCHEMA )
863
+ create_agent_engine_mock .assert_called_with (
864
+ parent = _TEST_PARENT ,
865
+ reasoning_engine = types .ReasoningEngine (
866
+ display_name = _TEST_AGENT_ENGINE_DISPLAY_NAME ,
867
+ spec = test_spec ,
868
+ ),
869
+ )
870
+ get_agent_engine_mock .assert_called_with (
871
+ name = _TEST_AGENT_ENGINE_RESOURCE_NAME ,
872
+ retry = _TEST_RETRY ,
873
+ )
874
+
875
+ def test_create_agent_engine_with_env_vars_list (
876
+ self ,
877
+ create_agent_engine_mock ,
878
+ cloud_storage_create_bucket_mock ,
879
+ tarfile_open_mock ,
880
+ cloudpickle_dump_mock ,
881
+ cloudpickle_load_mock ,
882
+ importlib_metadata_version_mock ,
883
+ get_agent_engine_mock ,
884
+ get_gca_resource_mock ,
885
+ ):
886
+ agent_engines .create (
887
+ self .test_agent ,
888
+ display_name = _TEST_AGENT_ENGINE_DISPLAY_NAME ,
889
+ requirements = _TEST_AGENT_ENGINE_REQUIREMENTS ,
890
+ extra_packages = [_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH ],
891
+ env_vars = [_TEST_AGENT_ENGINE_ENV_KEY , _TEST_AGENT_ENGINE_ENV_KEY ],
892
+ )
893
+ test_spec = types .ReasoningEngineSpec (
894
+ package_spec = _TEST_AGENT_ENGINE_PACKAGE_SPEC ,
895
+ deployment_spec = types .ReasoningEngineSpec .DeploymentSpec (
896
+ env = [
897
+ types .EnvVar (
898
+ name = _TEST_AGENT_ENGINE_ENV_KEY ,
899
+ value = _TEST_AGENT_ENGINE_ENV_VALUE ,
900
+ ),
901
+ types .EnvVar (
902
+ name = _TEST_AGENT_ENGINE_ENV_KEY ,
903
+ value = _TEST_AGENT_ENGINE_ENV_VALUE ,
904
+ ),
905
+ ],
906
+ ),
907
+ )
908
+ test_spec .class_methods .append (_TEST_AGENT_ENGINE_QUERY_SCHEMA )
909
+ create_agent_engine_mock .assert_called_with (
910
+ parent = _TEST_PARENT ,
911
+ reasoning_engine = types .ReasoningEngine (
912
+ display_name = _TEST_AGENT_ENGINE_DISPLAY_NAME ,
913
+ spec = test_spec ,
914
+ ),
915
+ )
916
+ get_agent_engine_mock .assert_called_with (
917
+ name = _TEST_AGENT_ENGINE_RESOURCE_NAME ,
918
+ retry = _TEST_RETRY ,
919
+ )
920
+
804
921
# pytest does not allow absl.testing.parameterized.named_parameters.
805
922
@pytest .mark .parametrize (
806
923
"test_case_name, test_kwargs, want_request" ,
@@ -923,6 +1040,48 @@ def test_create_agent_engine_requirements_from_file(
923
1040
update_mask = field_mask_pb2 .FieldMask (paths = ["description" ]),
924
1041
),
925
1042
),
1043
+ (
1044
+ "Update the environment variables" ,
1045
+ {
1046
+ "env_vars" : {
1047
+ _TEST_AGENT_ENGINE_ENV_KEY : _TEST_AGENT_ENGINE_ENV_VALUE ,
1048
+ "TEST_SECRET_ENV_VAR" : {
1049
+ "secret" : "TEST_SECRET_NAME" ,
1050
+ "version" : "TEST_SECRET_VERSION" ,
1051
+ },
1052
+ },
1053
+ },
1054
+ types .reasoning_engine_service .UpdateReasoningEngineRequest (
1055
+ reasoning_engine = types .ReasoningEngine (
1056
+ name = _TEST_AGENT_ENGINE_RESOURCE_NAME ,
1057
+ spec = types .ReasoningEngineSpec (
1058
+ deployment_spec = types .ReasoningEngineSpec .DeploymentSpec (
1059
+ env = [
1060
+ types .EnvVar (
1061
+ name = _TEST_AGENT_ENGINE_ENV_KEY ,
1062
+ value = _TEST_AGENT_ENGINE_ENV_VALUE ,
1063
+ ),
1064
+ ],
1065
+ secret_env = [
1066
+ types .SecretEnvVar (
1067
+ name = "TEST_SECRET_ENV_VAR" ,
1068
+ secret_ref = types .SecretRef (
1069
+ secret = "TEST_SECRET_NAME" ,
1070
+ version = "TEST_SECRET_VERSION" ,
1071
+ ),
1072
+ ),
1073
+ ],
1074
+ ),
1075
+ ),
1076
+ ),
1077
+ update_mask = field_mask_pb2 .FieldMask (
1078
+ paths = [
1079
+ "spec.deployment_spec.env" ,
1080
+ "spec.deployment_spec.secret_env" ,
1081
+ ],
1082
+ ),
1083
+ ),
1084
+ ),
926
1085
],
927
1086
)
928
1087
def test_update_agent_engine (
@@ -1832,6 +1991,78 @@ def test_create_agent_engine_with_invalid_register_operations_method(
1832
1991
requirements = _TEST_AGENT_ENGINE_REQUIREMENTS ,
1833
1992
)
1834
1993
1994
+ def test_create_agent_engine_with_invalid_secret_ref_env_var (
1995
+ self ,
1996
+ create_agent_engine_mock ,
1997
+ cloud_storage_create_bucket_mock ,
1998
+ tarfile_open_mock ,
1999
+ cloudpickle_dump_mock ,
2000
+ cloudpickle_load_mock ,
2001
+ importlib_metadata_version_mock ,
2002
+ get_agent_engine_mock ,
2003
+ ):
2004
+ with pytest .raises (ValueError , match = "Failed to convert to secret ref" ):
2005
+ agent_engines .create (
2006
+ self .test_agent ,
2007
+ display_name = _TEST_AGENT_ENGINE_DISPLAY_NAME ,
2008
+ requirements = _TEST_AGENT_ENGINE_REQUIREMENTS ,
2009
+ env_vars = {
2010
+ "TEST_ENV_VAR" : {
2011
+ "name" : "TEST_SECRET_NAME" , # "name" should be "secret"
2012
+ "version" : "TEST_SECRET_VERSION" ,
2013
+ },
2014
+ },
2015
+ )
2016
+
2017
+ def test_create_agent_engine_with_unknown_env_var (
2018
+ self ,
2019
+ create_agent_engine_mock ,
2020
+ cloud_storage_create_bucket_mock ,
2021
+ tarfile_open_mock ,
2022
+ cloudpickle_dump_mock ,
2023
+ cloudpickle_load_mock ,
2024
+ importlib_metadata_version_mock ,
2025
+ get_agent_engine_mock ,
2026
+ ):
2027
+ with pytest .raises (ValueError , match = "Env var not found in os.environ" ):
2028
+ agent_engines .create (
2029
+ self .test_agent ,
2030
+ display_name = _TEST_AGENT_ENGINE_DISPLAY_NAME ,
2031
+ requirements = _TEST_AGENT_ENGINE_REQUIREMENTS ,
2032
+ # Assumption: "UNKNOWN_TEST_ENV_VAR" not in os.environ
2033
+ env_vars = ["UNKNOWN_TEST_ENV_VAR" ],
2034
+ )
2035
+
2036
+ def test_create_agent_engine_with_invalid_type_env_var (
2037
+ self ,
2038
+ create_agent_engine_mock ,
2039
+ cloud_storage_create_bucket_mock ,
2040
+ tarfile_open_mock ,
2041
+ cloudpickle_dump_mock ,
2042
+ cloudpickle_load_mock ,
2043
+ importlib_metadata_version_mock ,
2044
+ get_agent_engine_mock ,
2045
+ ):
2046
+ with pytest .raises (TypeError , match = "Unknown value type in env_vars" ):
2047
+ agent_engines .create (
2048
+ self .test_agent ,
2049
+ display_name = _TEST_AGENT_ENGINE_DISPLAY_NAME ,
2050
+ requirements = _TEST_AGENT_ENGINE_REQUIREMENTS ,
2051
+ env_vars = {
2052
+ "TEST_ENV_VAR" : 0.01 , # should be a string or dict or SecretRef
2053
+ },
2054
+ )
2055
+ with pytest .raises (TypeError , match = "env_vars must be a list or a dict" ):
2056
+ agent_engines .create (
2057
+ self .test_agent ,
2058
+ display_name = _TEST_AGENT_ENGINE_DISPLAY_NAME ,
2059
+ requirements = _TEST_AGENT_ENGINE_REQUIREMENTS ,
2060
+ env_vars = types .SecretRef ( # should be a list or dict
2061
+ secret = "TEST_SECRET_NAME" ,
2062
+ version = "TEST_SECRET_VERSION" ,
2063
+ ),
2064
+ )
2065
+
1835
2066
def test_update_agent_engine_unspecified_staging_bucket (
1836
2067
self ,
1837
2068
update_agent_engine_mock ,
@@ -1963,7 +2194,7 @@ def test_update_agent_engine_with_no_updates(
1963
2194
ValueError ,
1964
2195
match = (
1965
2196
"At least one of `agent_engine`, `requirements`, "
1966
- "`extra_packages`, `display_name`, or `description ` "
2197
+ "`extra_packages`, `display_name`, `description`, or `env_vars ` "
1967
2198
"must be specified."
1968
2199
),
1969
2200
):
0 commit comments