17
17
from __future__ import annotations
18
18
19
19
from unittest import mock
20
+ from unittest .mock import MagicMock
20
21
21
22
from google .api_core .gapic_v1 .method import DEFAULT
22
23
from google .api_core .retry import Retry
@@ -783,7 +784,12 @@ class TestVertexAICreateAutoMLTabularTrainingJobOperator:
783
784
@mock .patch ("google.cloud.aiplatform.datasets.TabularDataset" )
784
785
@mock .patch (VERTEX_AI_PATH .format ("auto_ml.AutoMLHook" ))
785
786
def test_execute (self , mock_hook , mock_dataset ):
786
- mock_hook .return_value .create_auto_ml_tabular_training_job .return_value = (None , "training_id" )
787
+ mock_hook .return_value = MagicMock (
788
+ ** {
789
+ "create_auto_ml_tabular_training_job.return_value" : (None , "training_id" ),
790
+ "get_credentials_and_project_id.return_value" : ("creds" , "project_id" ),
791
+ }
792
+ )
787
793
op = CreateAutoMLTabularTrainingJobOperator (
788
794
task_id = TASK_ID ,
789
795
gcp_conn_id = GCP_CONN_ID ,
@@ -798,7 +804,9 @@ def test_execute(self, mock_hook, mock_dataset):
798
804
)
799
805
op .execute (context = {"ti" : mock .MagicMock ()})
800
806
mock_hook .assert_called_once_with (gcp_conn_id = GCP_CONN_ID , impersonation_chain = IMPERSONATION_CHAIN )
801
- mock_dataset .assert_called_once_with (dataset_name = TEST_DATASET_ID )
807
+ mock_dataset .assert_called_once_with (
808
+ dataset_name = TEST_DATASET_ID , project = GCP_PROJECT , credentials = "creds"
809
+ )
802
810
mock_hook .return_value .create_auto_ml_tabular_training_job .assert_called_once_with (
803
811
project_id = GCP_PROJECT ,
804
812
region = GCP_LOCATION ,
0 commit comments