Skip to content

Commit d8109dd

Browse files
committed
Fix short circuit in mapped tasks
1 parent 4b27c3f commit d8109dd

File tree

7 files changed

+187
-45
lines changed

7 files changed

+187
-45
lines changed

airflow/models/mappedoperator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,8 @@ def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) ->
830830
op.is_setup = is_setup
831831
op.is_teardown = is_teardown
832832
op.on_failure_fail_dagrun = on_failure_fail_dagrun
833+
op.downstream_task_ids = self.downstream_task_ids
834+
op.upstream_task_ids = self.upstream_task_ids
833835
return op
834836

835837
# After a mapped operator is serialized, there's no real way to actually

airflow/models/skipmixin.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,11 @@ def _skip(
161161
raise ValueError("dag_run is required")
162162

163163
task_ids_list = [d.task_id for d in task_list]
164-
SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
165-
session.commit()
164+
165+
# The bulk state update below applies only to non-mapped tasks (map_index == -1)
166+
if map_index == -1:
167+
SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
168+
session.commit()
166169

167170
if task_id is not None:
168171
from airflow.models.xcom import XCom
@@ -177,8 +180,8 @@ def _skip(
177180
session=session,
178181
)
179182

183+
@staticmethod
180184
def skip_all_except(
181-
self,
182185
ti: TaskInstance | TaskInstancePydantic,
183186
branch_task_ids: None | str | Iterable[str],
184187
):

airflow/ti_deps/deps/not_previously_skipped_dep.py

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from airflow.models.taskinstance import PAST_DEPENDS_MET
2121
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
22+
from airflow.utils.db import LazySelectSequence
2223

2324

2425
class NotPreviouslySkippedDep(BaseTIDep):
@@ -38,7 +39,6 @@ def _get_dep_statuses(self, ti, session, dep_context):
3839
XCOM_SKIPMIXIN_FOLLOWED,
3940
XCOM_SKIPMIXIN_KEY,
4041
XCOM_SKIPMIXIN_SKIPPED,
41-
SkipMixin,
4242
)
4343
from airflow.utils.state import TaskInstanceState
4444

@@ -49,46 +49,47 @@ def _get_dep_statuses(self, ti, session, dep_context):
4949
finished_task_ids = {t.task_id for t in finished_tis}
5050

5151
for parent in upstream:
52-
if isinstance(parent, SkipMixin):
53-
if parent.task_id not in finished_task_ids:
54-
# This can happen if the parent task has not yet run.
55-
continue
52+
if parent.task_id not in finished_task_ids:
53+
# This can happen if the parent task has not yet run.
54+
continue
5655

57-
prev_result = ti.xcom_pull(task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session)
56+
prev_result = ti.xcom_pull(
57+
task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session, map_indexes=ti.map_index
58+
)
5859

59-
if prev_result is None:
60-
# This can happen if the parent task has not yet run.
61-
continue
60+
if isinstance(prev_result, LazySelectSequence):
61+
prev_result = next(iter(prev_result))
6262

63-
should_skip = False
64-
if (
65-
XCOM_SKIPMIXIN_FOLLOWED in prev_result
66-
and ti.task_id not in prev_result[XCOM_SKIPMIXIN_FOLLOWED]
67-
):
68-
# Skip any tasks that are not in "followed"
69-
should_skip = True
70-
elif (
71-
XCOM_SKIPMIXIN_SKIPPED in prev_result
72-
and ti.task_id in prev_result[XCOM_SKIPMIXIN_SKIPPED]
73-
):
74-
# Skip any tasks that are in "skipped"
75-
should_skip = True
63+
if prev_result is None:
64+
# This can happen if the parent task has not yet run.
65+
continue
7666

77-
if should_skip:
78-
# If the parent SkipMixin has run, and the XCom result stored indicates this
79-
# ti should be skipped, set ti.state to SKIPPED and fail the rule so that the
80-
# ti does not execute.
81-
if dep_context.wait_for_past_depends_before_skipping:
82-
past_depends_met = ti.xcom_pull(
83-
task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
84-
)
85-
if not past_depends_met:
86-
yield self._failing_status(
87-
reason=("Task should be skipped but the past depends are not met")
88-
)
89-
return
90-
ti.set_state(TaskInstanceState.SKIPPED, session)
91-
yield self._failing_status(
92-
reason=f"Skipping because of previous XCom result from parent task {parent.task_id}"
67+
should_skip = False
68+
if (
69+
XCOM_SKIPMIXIN_FOLLOWED in prev_result
70+
and ti.task_id not in prev_result[XCOM_SKIPMIXIN_FOLLOWED]
71+
):
72+
# Skip any tasks that are not in "followed"
73+
should_skip = True
74+
elif XCOM_SKIPMIXIN_SKIPPED in prev_result and ti.task_id in prev_result[XCOM_SKIPMIXIN_SKIPPED]:
75+
# Skip any tasks that are in "skipped"
76+
should_skip = True
77+
78+
if should_skip:
79+
# If the parent SkipMixin has run, and the XCom result stored indicates this
80+
# ti should be skipped, set ti.state to SKIPPED and fail the rule so that the
81+
# ti does not execute.
82+
if dep_context.wait_for_past_depends_before_skipping:
83+
past_depends_met = ti.xcom_pull(
84+
task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
9385
)
94-
return
86+
if not past_depends_met:
87+
yield self._failing_status(
88+
reason="Task should be skipped but the past depends are not met"
89+
)
90+
return
91+
ti.set_state(TaskInstanceState.SKIPPED, session)
92+
yield self._failing_status(
93+
reason=f"Skipping because of previous XCom result from parent task {parent.task_id}"
94+
)
95+
return

newsfragments/44912.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix the short circuit operator in mapped tasks. Until now the operator did not work in mapped tasks due to a bug in ``NotPreviouslySkippedDep``. Please note that, at the time of merging, this fix has been applied only to Airflow versions > 2.10.4 and < 3, and should be ported to v3 after PR #44925 is merged.

tests/models/test_mappedoperator.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from datetime import timedelta
2222
from typing import TYPE_CHECKING
2323
from unittest import mock
24-
from unittest.mock import patch
24+
from unittest.mock import MagicMock, patch
2525

2626
import pendulum
2727
import pytest
@@ -1868,3 +1868,58 @@ def _one_scheduling_decision_iteration() -> dict[tuple[str, int], TaskInstance]:
18681868
assert tis["group.last", 0].state == State.SUCCESS
18691869
assert dr.get_task_instance("group.last", map_index=1, session=session).state == State.SKIPPED
18701870
assert tis["group.last", 2].state == State.SUCCESS
1871+
1872+
1873+
class TestMappedOperator:
1874+
@pytest.fixture
1875+
def mock_operator_class(self):
1876+
return MagicMock(spec=type(BaseOperator))
1877+
1878+
@pytest.fixture
1879+
@patch("airflow.serialization.serialized_objects.SerializedBaseOperator")
1880+
def mapped_operator(self, _, mock_operator_class):
1881+
return MappedOperator(
1882+
operator_class=mock_operator_class,
1883+
expand_input=MagicMock(),
1884+
partial_kwargs={"task_id": "test_task"},
1885+
task_id="test_task",
1886+
params={},
1887+
deps=frozenset(),
1888+
operator_extra_links=[],
1889+
template_ext=[],
1890+
template_fields=[],
1891+
template_fields_renderers={},
1892+
ui_color="",
1893+
ui_fgcolor="",
1894+
start_trigger_args=None,
1895+
start_from_trigger=False,
1896+
dag=None,
1897+
task_group=None,
1898+
start_date=None,
1899+
end_date=None,
1900+
is_empty=False,
1901+
task_module=MagicMock(),
1902+
task_type="taske_type",
1903+
operator_name="operator_name",
1904+
disallow_kwargs_override=False,
1905+
expand_input_attr="expand_input",
1906+
)
1907+
1908+
def test_unmap_with_resolved_kwargs(self, mapped_operator, mock_operator_class):
1909+
mapped_operator.upstream_task_ids = ["a"]
1910+
mapped_operator.downstream_task_ids = ["b"]
1911+
resolve = {"param1": "value1"}
1912+
result = mapped_operator.unmap(resolve)
1913+
assert result == mock_operator_class.return_value
1914+
assert result.task_id == "test_task"
1915+
assert result.is_setup is False
1916+
assert result.is_teardown is False
1917+
assert result.on_failure_fail_dagrun is False
1918+
assert result.upstream_task_ids == ["a"]
1919+
assert result.downstream_task_ids == ["b"]
1920+
1921+
def test_unmap_runtime_error(self, mapped_operator):
1922+
mapped_operator.upstream_task_ids = ["a"]
1923+
mapped_operator.downstream_task_ids = ["b"]
1924+
with pytest.raises(RuntimeError):
1925+
mapped_operator.unmap(None)

tests/models/test_skipmixin.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818
from __future__ import annotations
1919

2020
import datetime
21-
from unittest.mock import Mock, patch
21+
from unittest.mock import MagicMock, Mock, patch
2222

2323
import pendulum
2424
import pytest
2525

2626
from airflow import settings
2727
from airflow.decorators import task, task_group
2828
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
29+
from airflow.models import DagRun, MappedOperator
2930
from airflow.models.skipmixin import SkipMixin
3031
from airflow.models.taskinstance import TaskInstance as TI
3132
from airflow.operators.empty import EmptyOperator
@@ -53,6 +54,10 @@ def setup_method(self):
5354
def teardown_method(self):
5455
self.clean_db()
5556

57+
@pytest.fixture
58+
def mock_session(self):
59+
return Mock(spec=settings.Session)
60+
5661
@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
5762
@patch("airflow.utils.timezone.utcnow")
5863
def test_skip(self, mock_now, dag_maker):
@@ -104,10 +109,40 @@ def test_skip_none_dagrun(self, mock_now, dag_maker):
104109

105110
def test_skip_none_tasks(self):
106111
session = Mock()
107-
SkipMixin().skip(dag_run=None, execution_date=None, tasks=[])
112+
assert (
113+
SkipMixin()._skip(dag_run=None, task_id=None, execution_date=None, tasks=[], session=session)
114+
is None
115+
)
108116
assert not session.query.called
109117
assert not session.commit.called
110118

119+
def test_skip_mapped_task(self, mock_session):
120+
SkipMixin()._skip(
121+
dag_run=MagicMock(spec=DagRun),
122+
task_id=None,
123+
execution_date=None,
124+
tasks=[MagicMock(spec=MappedOperator)],
125+
session=mock_session,
126+
map_index=2,
127+
)
128+
mock_session.execute.assert_not_called()
129+
mock_session.commit.assert_not_called()
130+
131+
@patch("airflow.models.skipmixin.update")
132+
def test_skip_none_mapped_task(self, mock_update, mock_session):
133+
SkipMixin()._skip(
134+
dag_run=MagicMock(spec=DagRun),
135+
task_id=None,
136+
execution_date=None,
137+
tasks=[MagicMock(spec=MappedOperator)],
138+
session=mock_session,
139+
map_index=-1,
140+
)
141+
mock_session.execute.assert_called_once_with(
142+
mock_update.return_value.where.return_value.values.return_value.execution_options.return_value
143+
)
144+
mock_session.commit.assert_called_once()
145+
111146
@pytest.mark.parametrize(
112147
"branch_task_ids, expected_states",
113148
[

tests/ti_deps/deps/test_not_previously_skipped_dep.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pendulum
2121
import pytest
2222

23+
from airflow.decorators import task
2324
from airflow.models import DagRun, TaskInstance
2425
from airflow.operators.empty import EmptyOperator
2526
from airflow.operators.python import BranchPythonOperator
@@ -84,6 +85,50 @@ def test_no_skipmixin_parent(session, dag_maker):
8485
assert ti2.state != State.SKIPPED
8586

8687

88+
@pytest.mark.parametrize("condition, final_state", [(True, State.SUCCESS), (False, State.SKIPPED)])
89+
def test_parent_is_mapped_short_circuit(session, dag_maker, condition, final_state):
90+
with dag_maker(session=session):
91+
92+
@task
93+
def op1():
94+
return [1]
95+
96+
@task.short_circuit
97+
def op2(i: int):
98+
return condition
99+
100+
@task
101+
def op3(res: bool):
102+
pass
103+
104+
op3.expand(res=op2.expand(i=op1()))
105+
106+
dr = dag_maker.create_dagrun()
107+
108+
def _one_scheduling_decision_iteration() -> dict[tuple[str, int], TaskInstance]:
109+
decision = dr.task_instance_scheduling_decisions(session=session)
110+
return {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis}
111+
112+
tis = _one_scheduling_decision_iteration()
113+
114+
tis["op1", -1].run()
115+
assert tis["op1", -1].state == State.SUCCESS
116+
117+
tis = _one_scheduling_decision_iteration()
118+
tis["op2", 0].run()
119+
120+
assert tis["op2", 0].state == State.SUCCESS
121+
tis = _one_scheduling_decision_iteration()
122+
123+
if condition:
124+
ti3 = tis["op3", 0]
125+
ti3.run()
126+
else:
127+
ti3 = dr.get_task_instance("op3", map_index=0, session=session)
128+
129+
assert ti3.state == final_state
130+
131+
87132
def test_parent_follow_branch(session, dag_maker):
88133
"""
89134
A simple DAG with a BranchPythonOperator that follows op2. NotPreviouslySkippedDep is met.

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy