Content-Length: 932629 | pFad | https://github.com/apache/airflow/commit/5b680e27e8118861ef484c00a4b87c6885b0a518

8B Don't use connection to store task handler credentials (#9381) · apache/airflow@5b680e2 · GitHub
Skip to content

Commit 5b680e2

Browse files
authored
Don't use connection to store task handler credentials (#9381)
1 parent 583f213 commit 5b680e2

File tree

7 files changed

+84
-46
lines changed

7 files changed

+84
-46
lines changed

airflow/config_templates/airflow_local_settings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,15 +218,15 @@
218218

219219
DEFAULT_LOGGING_CONFIG['handlers'].update(WASB_REMOTE_HANDLERS)
220220
elif REMOTE_BASE_LOG_FOLDER.startswith('stackdriver://'):
221-
gcp_conn_id = conf.get('core', 'REMOTE_LOG_CONN_ID', fallback=None)
221+
key_path = conf.get('logging', 'STACKDRIVER_KEY_PATH', fallback=None)
222222
# stackdriver://github.com/airflow-tasks => airflow-tasks
223223
log_name = urlparse(REMOTE_BASE_LOG_FOLDER).path[1:]
224224
STACKDRIVER_REMOTE_HANDLERS = {
225225
'task': {
226226
'class': 'airflow.utils.log.stackdriver_task_handler.StackdriverTaskHandler',
227227
'formatter': 'airflow',
228228
'name': log_name,
229-
'gcp_conn_id': gcp_conn_id
229+
'gcp_key_path': key_path
230230
}
231231
}
232232

airflow/config_templates/config.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,16 @@
394394
type: string
395395
example: ~
396396
default: ""
397+
- name: stackdriver_key_path
398+
description: |
399+
Path to GCP Credential JSON file. If omitted, authorization based on `the Application Default
400+
Credentials
401+
<https://cloud.google.com/docs/authentication/production#finding_credentials_automatically>`__ will
402+
be used.
403+
version_added: ~
404+
type: string
405+
example: ~
406+
default: ""
397407
- name: remote_base_log_folder
398408
description: |
399409
Storage bucket URL for remote logging

airflow/config_templates/default_airflow.cfg

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,12 @@ remote_logging = False
225225
# location.
226226
remote_log_conn_id =
227227

228+
# Path to GCP Credential JSON file. If omitted, authorization based on `the Application Default
229+
# Credentials
230+
# <https://cloud.google.com/docs/authentication/production#finding_credentials_automatically>`__ will
231+
# be used.
232+
stackdriver_key_path =
233+
228234
# Storage bucket URL for remote logging
229235
# S3 buckets should start with "s3://"
230236
# Cloudwatch log groups should start with "cloudwatch://"

airflow/providers/google/cloud/utils/credentials_provider.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import logging
2424
import tempfile
2525
from contextlib import ExitStack, contextmanager
26-
from typing import Dict, Optional, Sequence, Tuple
26+
from typing import Collection, Dict, Optional, Sequence, Tuple
2727
from urllib.parse import urlencode
2828

2929
import google.auth
@@ -179,7 +179,8 @@ def provide_gcp_conn_and_credentials(
179179
def get_credentials_and_project_id(
180180
key_path: Optional[str] = None,
181181
keyfile_dict: Optional[Dict[str, str]] = None,
182-
scopes: Optional[Sequence[str]] = None,
182+
# See: https://github.com/PyCQA/pylint/issues/2377
183+
scopes: Optional[Collection[str]] = None, # pylint: disable=unsubscriptable-object
183184
delegate_to: Optional[str] = None
184185
) -> Tuple[google.auth.credentials.Credentials, str]:
185186
"""

airflow/utils/log/stackdriver_task_handler.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
Handler that integrates with Stackdriver
1919
"""
2020
import logging
21-
from typing import Dict, List, Optional, Tuple, Type
21+
from typing import Collection, Dict, List, Optional, Tuple, Type
2222

2323
from cached_property import cached_property
2424
from google.api_core.gapic_v1.client_info import ClientInfo
@@ -28,10 +28,16 @@
2828

2929
from airflow import version
3030
from airflow.models import TaskInstance
31+
from airflow.providers.google.cloud.utils.credentials_provider import get_credentials_and_project_id
3132

3233
DEFAULT_LOGGER_NAME = "airflow"
3334
_GLOBAL_RESOURCE = Resource(type="global", labels={})
3435

36+
_DEFAULT_SCOPESS = frozenset([
37+
"https://www.googleapis.com/auth/logging.read",
38+
"https://www.googleapis.com/auth/logging.write"
39+
])
40+
3541

3642
class StackdriverTaskHandler(logging.Handler):
3743
"""Handler that directly makes Stackdriver logging API calls.
@@ -45,11 +51,14 @@ class StackdriverTaskHandler(logging.Handler):
4551
4652
This handler supports both an asynchronous and synchronous transport.
4753
48-
:param gcp_conn_id: Connection ID that will be used for authorization to the Google Cloud Platform.
49-
If omitted, authorization based on `the Application Default Credentials
54+
55+
:param gcp_key_path: Path to GCP Credential JSON file.
56+
If omitted, authorization based on `the Application Default Credentials
5057
<https://cloud.google.com/docs/authentication/production#finding_credentials_automatically>`__ will
5158
be used.
52-
:type gcp_conn_id: str
59+
:type gcp_key_path: str
60+
:param scopes: OAuth scopes for the credentials.
61+
:type scopes: Collection[str]
5362
:param name: the name of the custom log in Stackdriver Logging. Defaults
5463
to 'airflow'. The name of the Python logger will be represented
5564
in the ``python_logger`` field.
@@ -74,14 +83,18 @@ class StackdriverTaskHandler(logging.Handler):
7483

7584
def __init__(
7685
self,
77-
gcp_conn_id: Optional[str] = None,
86+
gcp_key_path: Optional[str] = None,
87+
# See: https://github.com/PyCQA/pylint/issues/2377
88+
scopes: Optional[Collection[str]] = _DEFAULT_SCOPESS, # pylint: disable=unsubscriptable-object
7889
name: str = DEFAULT_LOGGER_NAME,
7990
transport: Type[Transport] = BackgroundThreadTransport,
8091
resource: Resource = _GLOBAL_RESOURCE,
8192
labels: Optional[Dict[str, str]] = None,
8293
):
8394
super().__init__()
84-
self.gcp_conn_id = gcp_conn_id
95+
self.gcp_key_path: Optional[str] = gcp_key_path
96+
# See: https://github.com/PyCQA/pylint/issues/2377
97+
self.scopes: Optional[Collection[str]] = scopes # pylint: disable=unsubscriptable-object
8598
self.name: str = name
8699
self.transport_type: Type[Transport] = transport
87100
self.resource: Resource = resource
@@ -91,14 +104,10 @@ def __init__(
91104
@cached_property
92105
def _client(self) -> gcp_logging.Client:
93106
"""Google Cloud Library API client"""
94-
if self.gcp_conn_id:
95-
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
96-
97-
hook = GoogleBaseHook(gcp_conn_id=self.gcp_conn_id)
98-
credentials = hook._get_credentials() # pylint: disable=protected-access
99-
else:
100-
# Use Application Default Credentials
101-
credentials = None
107+
credentials = get_credentials_and_project_id(
108+
key_path=self.gcp_key_path,
109+
scopes=self.scopes,
110+
)
102111
client = gcp_logging.Client(
103112
credentials=credentials,
104113
client_info=ClientInfo(client_library_version='airflow_v' + version.version)

docs/howto/write-logs.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,12 @@ For integration with Stackdriver, this option should start with ``stackdriver://
309309
The path section of the URL specifies the name of the log e.g. ``stackdriver://airflow-tasks`` writes
310310
logs under the name ``airflow-tasks``.
311311

312+
You can set the ``stackdriver_key_path`` option in the ``[logging]`` section to specify the path to `the service
313+
account key file <https://cloud.google.com/iam/docs/service-accounts>`__.
314+
If omitted, authorization based on `the Application Default Credentials
315+
<https://cloud.google.com/docs/authentication/production#finding_credentials_automatically>`__ will
316+
be used.
317+
312318
By using the ``logging_config_class`` option you can get :ref:`advanced features <write-logs-advanced>` of
313319
this handler. Details are available in the handler's documentation -
314320
:class:`~airflow.utils.log.stackdriver_task_handler.StackdriverTaskHandler`.

tests/utils/log/test_stackdriver_task_handler.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ def _create_list_response(messages, token):
3636

3737
class TestStackdriverLoggingHandlerStandalone(unittest.TestCase):
3838

39+
@mock.patch('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id')
3940
@mock.patch('airflow.utils.log.stackdriver_task_handler.gcp_logging.Client')
40-
def test_should_pass_message_to_client(self, mock_client):
41+
def test_should_pass_message_to_client(self, mock_client, mock_get_creds_and_project_id):
4142
transport_type = mock.MagicMock()
4243
stackdriver_task_handler = StackdriverTaskHandler(
43-
gcp_conn_id=None,
4444
transport=transport_type,
4545
labels={"key": 'value'}
4646
)
@@ -54,7 +54,10 @@ def test_should_pass_message_to_client(self, mock_client):
5454
transport_type.return_value.send.assert_called_once_with(
5555
mock.ANY, 'test-message', labels={"key": 'value'}, resource=Resource(type='global', labels={})
5656
)
57-
mock_client.assert_called_once()
57+
mock_client.assert_called_once_with(
58+
credentials=mock_get_creds_and_project_id.return_value,
59+
client_info=mock.ANY
60+
)
5861

5962

6063
class TestStackdriverLoggingHandlerTask(unittest.TestCase):
@@ -73,8 +76,9 @@ def setUp(self) -> None:
7376
self.ti.state = State.RUNNING
7477
self.addCleanup(self.dag.clear)
7578

79+
@mock.patch('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id')
7680
@mock.patch('airflow.utils.log.stackdriver_task_handler.gcp_logging.Client')
77-
def test_should_set_labels(self, mock_client):
81+
def test_should_set_labels(self, mock_client, mock_get_creds_and_project_id):
7882
self.stackdriver_task_handler.set_context(self.ti)
7983
self.logger.addHandler(self.stackdriver_task_handler)
8084

@@ -92,8 +96,9 @@ def test_should_set_labels(self, mock_client):
9296
mock.ANY, 'test-message', labels=labels, resource=resource
9397
)
9498

99+
@mock.patch('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id')
95100
@mock.patch('airflow.utils.log.stackdriver_task_handler.gcp_logging.Client')
96-
def test_should_append_labels(self, mock_client):
101+
def test_should_append_labels(self, mock_client, mock_get_creds_and_project_id):
97102
self.stackdriver_task_handler = StackdriverTaskHandler(
98103
transport=self.transport_mock,
99104
labels={"product.googleapis.com/task_id": "test-value"}
@@ -116,11 +121,12 @@ def test_should_append_labels(self, mock_client):
116121
mock.ANY, 'test-message', labels=labels, resource=resource
117122
)
118123

124+
@mock.patch('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id')
119125
@mock.patch(
120126
'airflow.utils.log.stackdriver_task_handler.gcp_logging.Client',
121127
**{'return_value.project': 'asf-project'} # type: ignore
122128
)
123-
def test_should_read_logs_for_all_try(self, mock_client):
129+
def test_should_read_logs_for_all_try(self, mock_client, mock_get_creds_and_project_id):
124130
mock_client.return_value.list_entries.return_value = _create_list_response(["MSG1", "MSG2"], None)
125131

126132
logs, metadata = self.stackdriver_task_handler.read(self.ti)
@@ -135,11 +141,12 @@ def test_should_read_logs_for_all_try(self, mock_client):
135141
self.assertEqual(['MSG1\nMSG2'], logs)
136142
self.assertEqual([{'end_of_log': True}], metadata)
137143

144+
@mock.patch('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id')
138145
@mock.patch( # type: ignore
139146
'airflow.utils.log.stackdriver_task_handler.gcp_logging.Client',
140147
**{'return_value.project': 'asf-project'} # type: ignore
141148
)
142-
def test_should_read_logs_for_task_with_quote(self, mock_client):
149+
def test_should_read_logs_for_task_with_quote(self, mock_client, mock_get_creds_and_project_id):
143150
mock_client.return_value.list_entries.return_value = _create_list_response(["MSG1", "MSG2"], None)
144151
self.ti.task_id = "K\"OT"
145152
logs, metadata = self.stackdriver_task_handler.read(self.ti)
@@ -154,11 +161,12 @@ def test_should_read_logs_for_task_with_quote(self, mock_client):
154161
self.assertEqual(['MSG1\nMSG2'], logs)
155162
self.assertEqual([{'end_of_log': True}], metadata)
156163

164+
@mock.patch('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id')
157165
@mock.patch(
158166
'airflow.utils.log.stackdriver_task_handler.gcp_logging.Client',
159167
**{'return_value.project': 'asf-project'} # type: ignore
160168
)
161-
def test_should_read_logs_for_single_try(self, mock_client):
169+
def test_should_read_logs_for_single_try(self, mock_client, mock_get_creds_and_project_id):
162170
mock_client.return_value.list_entries.return_value = _create_list_response(["MSG1", "MSG2"], None)
163171

164172
logs, metadata = self.stackdriver_task_handler.read(self.ti, 3)
@@ -174,8 +182,9 @@ def test_should_read_logs_for_single_try(self, mock_client):
174182
self.assertEqual(['MSG1\nMSG2'], logs)
175183
self.assertEqual([{'end_of_log': True}], metadata)
176184

185+
@mock.patch('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id')
177186
@mock.patch('airflow.utils.log.stackdriver_task_handler.gcp_logging.Client')
178-
def test_should_read_logs_with_pagination(self, mock_client):
187+
def test_should_read_logs_with_pagination(self, mock_client, mock_get_creds_and_project_id):
179188
mock_client.return_value.list_entries.side_effect = [
180189
_create_list_response(["MSG1", "MSG2"], "TOKEN1"),
181190
_create_list_response(["MSG3", "MSG4"], None),
@@ -195,8 +204,9 @@ def test_should_read_logs_with_pagination(self, mock_client):
195204
self.assertEqual(['MSG3\nMSG4'], logs)
196205
self.assertEqual([{'end_of_log': True}], metadata2)
197206

207+
@mock.patch('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id')
198208
@mock.patch('airflow.utils.log.stackdriver_task_handler.gcp_logging.Client')
199-
def test_should_read_logs_with_download(self, mock_client):
209+
def test_should_read_logs_with_download(self, mock_client, mock_get_creds_and_project_id):
200210
mock_client.return_value.list_entries.side_effect = [
201211
_create_list_response(["MSG1", "MSG2"], "TOKEN1"),
202212
_create_list_response(["MSG3", "MSG4"], None),
@@ -207,11 +217,12 @@ def test_should_read_logs_with_download(self, mock_client):
207217
self.assertEqual(['MSG1\nMSG2\nMSG3\nMSG4'], logs)
208218
self.assertEqual([{'end_of_log': True}], metadata1)
209219

220+
@mock.patch('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id')
210221
@mock.patch(
211222
'airflow.utils.log.stackdriver_task_handler.gcp_logging.Client',
212223
**{'return_value.project': 'asf-project'} # type: ignore
213224
)
214-
def test_should_read_logs_with_custom_resources(self, mock_client):
225+
def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_creds_and_project_id):
215226
resource = Resource(
216227
type="cloud_composer_environment",
217228
labels={
@@ -245,31 +256,26 @@ def test_should_read_logs_with_custom_resources(self, mock_client):
245256
self.assertEqual(['TEXT\nTEXT'], logs)
246257
self.assertEqual([{'end_of_log': True}], metadata)
247258

248-
249-
class TestStackdriverTaskHandlerAuthorization(unittest.TestCase):
250-
259+
@mock.patch('airflow.utils.log.stackdriver_task_handler.get_credentials_and_project_id')
251260
@mock.patch('airflow.utils.log.stackdriver_task_handler.gcp_logging.Client')
252-
def test_should_fallback_to_adc(self, mock_client):
261+
def test_should_use_credentials(self, mock_client, mock_get_creds_and_project_id):
253262
stackdriver_task_handler = StackdriverTaskHandler(
254-
gcp_conn_id=None
263+
gcp_key_path="KEY_PATH",
255264
)
256265

257266
client = stackdriver_task_handler._client
258267

259-
mock_client.assert_called_once_with(credentials=None, client_info=mock.ANY)
260-
self.assertEqual(mock_client.return_value, client)
261-
262-
@mock.patch("airflow.providers.google.common.hooks.base_google.GoogleBaseHook")
263-
@mock.patch('airflow.utils.log.stackdriver_task_handler.gcp_logging.Client')
264-
def test_should_support_gcp_conn_id(self, mock_client, mock_hook):
265-
stackdriver_task_handler = StackdriverTaskHandler(
266-
gcp_conn_id="test-gcp-conn"
268+
mock_get_creds_and_project_id.assert_called_once_with(
269+
key_path='KEY_PATH',
270+
scopes=frozenset(
271+
{
272+
'https://www.googleapis.com/auth/logging.write',
273+
'https://www.googleapis.com/auth/logging.read'
274+
}
275+
)
267276
)
268-
269-
client = stackdriver_task_handler._client
270-
271277
mock_client.assert_called_once_with(
272-
credentials=mock_hook.return_value._get_credentials.return_value,
278+
credentials=mock_get_creds_and_project_id.return_value,
273279
client_info=mock.ANY
274280
)
275281
self.assertEqual(mock_client.return_value, client)

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: https://github.com/apache/airflow/commit/5b680e27e8118861ef484c00a4b87c6885b0a518

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy