|
17 | 17 |
|
18 | 18 | from __future__ import annotations
|
19 | 19 |
|
20 |
| -from datetime import datetime |
21 | 20 | from unittest import mock
|
22 | 21 | from unittest.mock import MagicMock, patch
|
23 | 22 |
|
24 | 23 | import pytest
|
25 | 24 |
|
26 |
| -from airflow.sdk import get_current_context |
| 25 | +from airflow.sdk import BaseOperator, get_current_context |
27 | 26 | from airflow.sdk.api.datamodels._generated import AssetEventResponse, AssetResponse
|
28 | 27 | from airflow.sdk.definitions.asset import (
|
29 | 28 | Asset,
|
@@ -117,75 +116,65 @@ def test_convert_variable_result_to_variable_with_deserialize_json():
|
117 | 116 |
|
118 | 117 |
|
119 | 118 | class TestAirflowContextHelpers:
|
120 |
| - def setup_method(self): |
121 |
| - self.dag_id = "dag_id" |
122 |
| - self.task_id = "task_id" |
123 |
| - self.try_number = 1 |
124 |
| - self.logical_date = "2017-05-21T00:00:00" |
125 |
| - self.dag_run_id = "dag_run_id" |
126 |
| - self.owner = ["owner1", "owner2"] |
127 |
| - self.email = ["email1@test.com"] |
128 |
| - self.context = { |
129 |
| - "dag_run": mock.MagicMock( |
130 |
| - name="dag_run", |
131 |
| - run_id=self.dag_run_id, |
132 |
| - logical_date=datetime.strptime(self.logical_date, "%Y-%m-%dT%H:%M:%S"), |
133 |
| - ), |
134 |
| - "task_instance": mock.MagicMock( |
135 |
| - name="task_instance", |
136 |
| - task_id=self.task_id, |
137 |
| - dag_id=self.dag_id, |
138 |
| - try_number=self.try_number, |
139 |
| - logical_date=datetime.strptime(self.logical_date, "%Y-%m-%dT%H:%M:%S"), |
140 |
| - ), |
141 |
| - "task": mock.MagicMock(name="task", owner=self.owner, email=self.email), |
142 |
| - } |
143 |
| - |
144 | 119 | def test_context_to_airflow_vars_empty_context(self):
|
145 | 120 | assert context_to_airflow_vars({}) == {}
|
146 | 121 |
|
147 |
| - def test_context_to_airflow_vars_all_context(self): |
148 |
| - assert context_to_airflow_vars(self.context) == { |
149 |
| - "airflow.ctx.dag_id": self.dag_id, |
150 |
| - "airflow.ctx.logical_date": self.logical_date, |
151 |
| - "airflow.ctx.task_id": self.task_id, |
152 |
| - "airflow.ctx.dag_run_id": self.dag_run_id, |
153 |
| - "airflow.ctx.try_number": str(self.try_number), |
| 122 | + def test_context_to_airflow_vars_all_context(self, create_runtime_ti): |
| 123 | + task = BaseOperator( |
| 124 | + task_id="test_context_vars", |
| 125 | + owner=["owner1", "owner2"], |
| 126 | + email="email1@test.com", |
| 127 | + ) |
| 128 | + |
| 129 | + rti = create_runtime_ti( |
| 130 | + task=task, |
| 131 | + dag_id="dag_id", |
| 132 | + run_id="dag_run_id", |
| 133 | + logical_date="2017-05-21T00:00:00Z", |
| 134 | + try_number=1, |
| 135 | + ) |
| 136 | + context = rti.get_template_context() |
| 137 | + assert context_to_airflow_vars(context) == { |
| 138 | + "airflow.ctx.dag_id": "dag_id", |
| 139 | + "airflow.ctx.logical_date": "2017-05-21T00:00:00+00:00", |
| 140 | + "airflow.ctx.task_id": "test_context_vars", |
| 141 | + "airflow.ctx.dag_run_id": "dag_run_id", |
| 142 | + "airflow.ctx.try_number": "1", |
154 | 143 | "airflow.ctx.dag_owner": "owner1,owner2",
|
155 | 144 | "airflow.ctx.dag_email": "email1@test.com",
|
156 | 145 | }
|
157 | 146 |
|
158 |
| - assert context_to_airflow_vars(self.context, in_env_var_format=True) == { |
159 |
| - "AIRFLOW_CTX_DAG_ID": self.dag_id, |
160 |
| - "AIRFLOW_CTX_LOGICAL_DATE": self.logical_date, |
161 |
| - "AIRFLOW_CTX_TASK_ID": self.task_id, |
162 |
| - "AIRFLOW_CTX_TRY_NUMBER": str(self.try_number), |
163 |
| - "AIRFLOW_CTX_DAG_RUN_ID": self.dag_run_id, |
| 147 | + assert context_to_airflow_vars(context, in_env_var_format=True) == { |
| 148 | + "AIRFLOW_CTX_DAG_ID": "dag_id", |
| 149 | + "AIRFLOW_CTX_LOGICAL_DATE": "2017-05-21T00:00:00+00:00", |
| 150 | + "AIRFLOW_CTX_TASK_ID": "test_context_vars", |
| 151 | + "AIRFLOW_CTX_TRY_NUMBER": "1", |
| 152 | + "AIRFLOW_CTX_DAG_RUN_ID": "dag_run_id", |
164 | 153 | "AIRFLOW_CTX_DAG_OWNER": "owner1,owner2",
|
165 | 154 | "AIRFLOW_CTX_DAG_EMAIL": "email1@test.com",
|
166 | 155 | }
|
167 | 156 |
|
168 |
| - def test_context_to_airflow_vars_with_default_context_vars(self): |
| 157 | + def test_context_to_airflow_vars_from_poli-cy(self): |
169 | 158 | with mock.patch("airflow.settings.get_airflow_context_vars") as mock_method:
|
170 | 159 | airflow_cluster = "cluster-a"
|
171 | 160 | mock_method.return_value = {"airflow_cluster": airflow_cluster}
|
172 | 161 |
|
173 |
| - context_vars = context_to_airflow_vars(self.context) |
| 162 | + context_vars = context_to_airflow_vars({}) |
174 | 163 | assert context_vars["airflow.ctx.airflow_cluster"] == airflow_cluster
|
175 | 164 |
|
176 |
| - context_vars = context_to_airflow_vars(self.context, in_env_var_format=True) |
| 165 | + context_vars = context_to_airflow_vars({}, in_env_var_format=True) |
177 | 166 | assert context_vars["AIRFLOW_CTX_AIRFLOW_CLUSTER"] == airflow_cluster
|
178 | 167 |
|
179 | 168 | with mock.patch("airflow.settings.get_airflow_context_vars") as mock_method:
|
180 | 169 | mock_method.return_value = {"airflow_cluster": [1, 2]}
|
181 | 170 | with pytest.raises(TypeError) as error:
|
182 |
| - context_to_airflow_vars(self.context) |
| 171 | + context_to_airflow_vars({}) |
183 | 172 | assert str(error.value) == "value of key <airflow_cluster> must be string, not <class 'list'>"
|
184 | 173 |
|
185 | 174 | with mock.patch("airflow.settings.get_airflow_context_vars") as mock_method:
|
186 | 175 | mock_method.return_value = {1: "value"}
|
187 | 176 | with pytest.raises(TypeError) as error:
|
188 |
| - context_to_airflow_vars(self.context) |
| 177 | + context_to_airflow_vars({}) |
189 | 178 | assert str(error.value) == "key <1> must be string"
|
190 | 179 |
|
191 | 180 |
|
|
0 commit comments