@@ -36,11 +36,11 @@ def _create_list_response(messages, token):
36
36
37
37
class TestStackdriverLoggingHandlerStandalone (unittest .TestCase ):
38
38
39
+ @mock .patch ('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id' )
39
40
@mock .patch ('airflow.utils.log.stackdriver_task_handler.gcp_logging.Client' )
40
- def test_should_pass_message_to_client (self , mock_client ):
41
+ def test_should_pass_message_to_client (self , mock_client , mock_get_creds_and_project_id ):
41
42
transport_type = mock .MagicMock ()
42
43
stackdriver_task_handler = StackdriverTaskHandler (
43
- gcp_conn_id = None ,
44
44
transport = transport_type ,
45
45
labels = {"key" : 'value' }
46
46
)
@@ -54,7 +54,10 @@ def test_should_pass_message_to_client(self, mock_client):
54
54
transport_type .return_value .send .assert_called_once_with (
55
55
mock .ANY , 'test-message' , labels = {"key" : 'value' }, resource = Resource (type = 'global' , labels = {})
56
56
)
57
- mock_client .assert_called_once ()
57
+ mock_client .assert_called_once_with (
58
+ credentials = mock_get_creds_and_project_id .return_value ,
59
+ client_info = mock .ANY
60
+ )
58
61
59
62
60
63
class TestStackdriverLoggingHandlerTask (unittest .TestCase ):
@@ -73,8 +76,9 @@ def setUp(self) -> None:
73
76
self .ti .state = State .RUNNING
74
77
self .addCleanup (self .dag .clear )
75
78
79
+ @mock .patch ('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id' )
76
80
@mock .patch ('airflow.utils.log.stackdriver_task_handler.gcp_logging.Client' )
77
- def test_should_set_labels (self , mock_client ):
81
+ def test_should_set_labels (self , mock_client , mock_get_creds_and_project_id ):
78
82
self .stackdriver_task_handler .set_context (self .ti )
79
83
self .logger .addHandler (self .stackdriver_task_handler )
80
84
@@ -92,8 +96,9 @@ def test_should_set_labels(self, mock_client):
92
96
mock .ANY , 'test-message' , labels = labels , resource = resource
93
97
)
94
98
99
+ @mock .patch ('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id' )
95
100
@mock .patch ('airflow.utils.log.stackdriver_task_handler.gcp_logging.Client' )
96
- def test_should_append_labels (self , mock_client ):
101
+ def test_should_append_labels (self , mock_client , mock_get_creds_and_project_id ):
97
102
self .stackdriver_task_handler = StackdriverTaskHandler (
98
103
transport = self .transport_mock ,
99
104
labels = {"product.googleapis.com/task_id" : "test-value" }
@@ -116,11 +121,12 @@ def test_should_append_labels(self, mock_client):
116
121
mock .ANY , 'test-message' , labels = labels , resource = resource
117
122
)
118
123
124
+ @mock .patch ('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id' )
119
125
@mock .patch (
120
126
'airflow.utils.log.stackdriver_task_handler.gcp_logging.Client' ,
121
127
** {'return_value.project' : 'asf-project' } # type: ignore
122
128
)
123
- def test_should_read_logs_for_all_try (self , mock_client ):
129
+ def test_should_read_logs_for_all_try (self , mock_client , mock_get_creds_and_project_id ):
124
130
mock_client .return_value .list_entries .return_value = _create_list_response (["MSG1" , "MSG2" ], None )
125
131
126
132
logs , metadata = self .stackdriver_task_handler .read (self .ti )
@@ -135,11 +141,12 @@ def test_should_read_logs_for_all_try(self, mock_client):
135
141
self .assertEqual (['MSG1\n MSG2' ], logs )
136
142
self .assertEqual ([{'end_of_log' : True }], metadata )
137
143
144
+ @mock .patch ('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id' )
138
145
@mock .patch ( # type: ignore
139
146
'airflow.utils.log.stackdriver_task_handler.gcp_logging.Client' ,
140
147
** {'return_value.project' : 'asf-project' } # type: ignore
141
148
)
142
- def test_should_read_logs_for_task_with_quote (self , mock_client ):
149
+ def test_should_read_logs_for_task_with_quote (self , mock_client , mock_get_creds_and_project_id ):
143
150
mock_client .return_value .list_entries .return_value = _create_list_response (["MSG1" , "MSG2" ], None )
144
151
self .ti .task_id = "K\" OT"
145
152
logs , metadata = self .stackdriver_task_handler .read (self .ti )
@@ -154,11 +161,12 @@ def test_should_read_logs_for_task_with_quote(self, mock_client):
154
161
self .assertEqual (['MSG1\n MSG2' ], logs )
155
162
self .assertEqual ([{'end_of_log' : True }], metadata )
156
163
164
+ @mock .patch ('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id' )
157
165
@mock .patch (
158
166
'airflow.utils.log.stackdriver_task_handler.gcp_logging.Client' ,
159
167
** {'return_value.project' : 'asf-project' } # type: ignore
160
168
)
161
- def test_should_read_logs_for_single_try (self , mock_client ):
169
+ def test_should_read_logs_for_single_try (self , mock_client , mock_get_creds_and_project_id ):
162
170
mock_client .return_value .list_entries .return_value = _create_list_response (["MSG1" , "MSG2" ], None )
163
171
164
172
logs , metadata = self .stackdriver_task_handler .read (self .ti , 3 )
@@ -174,8 +182,9 @@ def test_should_read_logs_for_single_try(self, mock_client):
174
182
self .assertEqual (['MSG1\n MSG2' ], logs )
175
183
self .assertEqual ([{'end_of_log' : True }], metadata )
176
184
185
+ @mock .patch ('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id' )
177
186
@mock .patch ('airflow.utils.log.stackdriver_task_handler.gcp_logging.Client' )
178
- def test_should_read_logs_with_pagination (self , mock_client ):
187
+ def test_should_read_logs_with_pagination (self , mock_client , mock_get_creds_and_project_id ):
179
188
mock_client .return_value .list_entries .side_effect = [
180
189
_create_list_response (["MSG1" , "MSG2" ], "TOKEN1" ),
181
190
_create_list_response (["MSG3" , "MSG4" ], None ),
@@ -195,8 +204,9 @@ def test_should_read_logs_with_pagination(self, mock_client):
195
204
self .assertEqual (['MSG3\n MSG4' ], logs )
196
205
self .assertEqual ([{'end_of_log' : True }], metadata2 )
197
206
207
+ @mock .patch ('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id' )
198
208
@mock .patch ('airflow.utils.log.stackdriver_task_handler.gcp_logging.Client' )
199
- def test_should_read_logs_with_download (self , mock_client ):
209
+ def test_should_read_logs_with_download (self , mock_client , mock_get_creds_and_project_id ):
200
210
mock_client .return_value .list_entries .side_effect = [
201
211
_create_list_response (["MSG1" , "MSG2" ], "TOKEN1" ),
202
212
_create_list_response (["MSG3" , "MSG4" ], None ),
@@ -207,11 +217,12 @@ def test_should_read_logs_with_download(self, mock_client):
207
217
self .assertEqual (['MSG1\n MSG2\n MSG3\n MSG4' ], logs )
208
218
self .assertEqual ([{'end_of_log' : True }], metadata1 )
209
219
220
+ @mock .patch ('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id' )
210
221
@mock .patch (
211
222
'airflow.utils.log.stackdriver_task_handler.gcp_logging.Client' ,
212
223
** {'return_value.project' : 'asf-project' } # type: ignore
213
224
)
214
- def test_should_read_logs_with_custom_resources (self , mock_client ):
225
+ def test_should_read_logs_with_custom_resources (self , mock_client , mock_get_creds_and_project_id ):
215
226
resource = Resource (
216
227
type = "cloud_composer_environment" ,
217
228
labels = {
@@ -245,31 +256,26 @@ def test_should_read_logs_with_custom_resources(self, mock_client):
245
256
self .assertEqual (['TEXT\n TEXT' ], logs )
246
257
self .assertEqual ([{'end_of_log' : True }], metadata )
247
258
248
-
249
- class TestStackdriverTaskHandlerAuthorization (unittest .TestCase ):
250
-
259
+ @mock .patch ('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id' )
251
260
@mock .patch ('airflow.utils.log.stackdriver_task_handler.gcp_logging.Client' )
252
- def test_should_fallback_to_adc (self , mock_client ):
261
+ def test_should_use_credentials (self , mock_client , mock_get_creds_and_project_id ):
253
262
stackdriver_task_handler = StackdriverTaskHandler (
254
- gcp_conn_id = None
263
+ gcp_key_path = "KEY_PATH" ,
255
264
)
256
265
257
266
client = stackdriver_task_handler ._client
258
267
259
- mock_client .assert_called_once_with (credentials = None , client_info = mock . ANY )
260
- self . assertEqual ( mock_client . return_value , client )
261
-
262
- @ mock . patch ( "airflow.providers.google.common.hooks.base_google.GoogleBaseHook" )
263
- @ mock . patch ( 'airflow.utils.log.stackdriver_task_handler.gcp_logging.Client' )
264
- def test_should_support_gcp_conn_id ( self , mock_client , mock_hook ):
265
- stackdriver_task_handler = StackdriverTaskHandler (
266
- gcp_conn_id = "test-gcp-conn"
268
+ mock_get_creds_and_project_id .assert_called_once_with (
269
+ key_path = 'KEY_PATH' ,
270
+ scopes = frozenset (
271
+ {
272
+ 'https://www.googleapis.com/auth/logging.write' ,
273
+ 'https://www.googleapis.com/auth/logging.read'
274
+ }
275
+ )
267
276
)
268
-
269
- client = stackdriver_task_handler ._client
270
-
271
277
mock_client .assert_called_once_with (
272
- credentials = mock_hook . return_value . _get_credentials .return_value ,
278
+ credentials = mock_get_creds_and_project_id .return_value ,
273
279
client_info = mock .ANY
274
280
)
275
281
self .assertEqual (mock_client .return_value , client )
0 commit comments