Content-Length: 844877 | pFad | http://github.com/apache/airflow/commit/5458e7e7be86c6de034d7a589bd26db85c532308

C9 fix(task_instances): handle upstream_mapped_index when xcom access is… · apache/airflow@5458e7e · GitHub
Skip to content

Commit 5458e7e

Browse files
authored
fix(task_instances): handle upstream_mapped_index when xcom access is needed (#50641)
* fix(task_instances): handle upstream_mapped_index when xcom access is needed * style(expand_input): fix expand_input and SchedulerExpandInput types * test(task_instances): add test_dynamic_task_mapping_with_parse_time_value * test(task_instance): add test_dynamic_task_mapping_with_xcom * style: import typing * style: move the SchedulerExpandInput into type checking block * Revert "style: move the SchedulerExpandInput into type checking block" This reverts commit c2c87ca.
1 parent e033afa commit 5458e7e

File tree

6 files changed

+153
-14
lines changed

6 files changed

+153
-14
lines changed

airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from airflow.models.taskreschedule import TaskReschedule
5858
from airflow.models.trigger import Trigger
5959
from airflow.models.xcom import XComModel
60+
from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
6061
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
6162
from airflow.utils import timezone
6263
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
@@ -244,7 +245,9 @@ def ti_run(
244245
)
245246

246247
if dag := dag_bag.get_dag(ti.dag_id):
247-
upstream_map_indexes = dict(_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index))
248+
upstream_map_indexes = dict(
249+
_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index, ti.run_id, session)
250+
)
248251
else:
249252
upstream_map_indexes = None
250253

@@ -274,7 +277,7 @@ def ti_run(
274277

275278

276279
def _get_upstream_map_indexes(
277-
task: Operator, ti_map_index: int
280+
task: Operator, ti_map_index: int, run_id: str, session: SessionDep
278281
) -> Iterator[tuple[str, int | list[int] | None]]:
279282
for upstream_task in task.upstream_list:
280283
map_indexes: int | list[int] | None
@@ -287,8 +290,17 @@ def _get_upstream_map_indexes(
287290
map_indexes = ti_map_index
288291
else:
289292
# tasks not in the same mapped task group
290-
# the upstream mapped task group should combine the xcom as a list and return it
291-
mapped_ti_count: int = upstream_task.task_group.get_parse_time_mapped_ti_count()
293+
# the upstream mapped task group should combine the return xcom as a list and return it
294+
mapped_ti_count: int
295+
upstream_mapped_group = upstream_task.task_group
296+
try:
297+
# for cases that does not need to resolve xcom
298+
mapped_ti_count = upstream_mapped_group.get_parse_time_mapped_ti_count()
299+
except NotFullyPopulated:
300+
# for cases that needs to resolve xcom to get the correct count
301+
mapped_ti_count = upstream_mapped_group._expand_input.get_total_map_length(
302+
run_id, session=session
303+
)
292304
map_indexes = list(range(mapped_ti_count)) if mapped_ti_count is not None else None
293305

294306
yield upstream_task.task_id, map_indexes

airflow-core/src/airflow/models/expandinput.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import functools
2121
import operator
2222
from collections.abc import Iterable, Sized
23-
from typing import TYPE_CHECKING, Any
23+
from typing import TYPE_CHECKING, Any, ClassVar, Union
2424

2525
import attrs
2626

@@ -32,7 +32,6 @@
3232

3333
from airflow.sdk.definitions._internal.expandinput import (
3434
DictOfListsExpandInput,
35-
ExpandInput,
3635
ListOfDictsExpandInput,
3736
MappedArgument,
3837
NotFullyPopulated,
@@ -62,6 +61,8 @@ def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArg
6261
class SchedulerDictOfListsExpandInput:
6362
value: dict
6463

64+
EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists"
65+
6566
def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
6667
"""Generate kwargs with values available on parse-time."""
6768
return ((k, v) for k, v in self.value.items() if not _needs_run_time_resolution(v))
@@ -114,6 +115,8 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
114115
class SchedulerListOfDictsExpandInput:
115116
value: list
116117

118+
EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts"
119+
117120
def get_parse_time_mapped_ti_count(self) -> int:
118121
if isinstance(self.value, Sized):
119122
return len(self.value)
@@ -130,11 +133,13 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
130133
return length
131134

132135

133-
_EXPAND_INPUT_TYPES = {
136+
_EXPAND_INPUT_TYPES: dict[str, type[SchedulerExpandInput]] = {
134137
"dict-of-lists": SchedulerDictOfListsExpandInput,
135138
"list-of-dicts": SchedulerListOfDictsExpandInput,
136139
}
137140

141+
SchedulerExpandInput = Union[SchedulerDictOfListsExpandInput, SchedulerListOfDictsExpandInput]
142+
138143

139-
def create_expand_input(kind: str, value: Any) -> ExpandInput:
144+
def create_expand_input(kind: str, value: Any) -> SchedulerExpandInput:
140145
return _EXPAND_INPUT_TYPES[kind](value)

airflow-core/src/airflow/serialization/serialized_objects.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@
100100
from inspect import Parameter
101101

102102
from airflow.models import DagRun
103-
from airflow.models.expandinput import ExpandInput
103+
from airflow.models.expandinput import SchedulerExpandInput
104104
from airflow.sdk import BaseOperatorLink
105105
from airflow.sdk.definitions._internal.node import DAGNode
106106
from airflow.sdk.types import Operator
@@ -557,7 +557,7 @@ def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None:
557557
possible ExpandInput cases.
558558
"""
559559

560-
def deref(self, dag: DAG) -> ExpandInput:
560+
def deref(self, dag: DAG) -> SchedulerExpandInput:
561561
"""
562562
De-reference into a concrete ExpandInput object.
563563

airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from airflow.models.taskinstance import TaskInstance
3535
from airflow.models.taskinstancehistory import TaskInstanceHistory
3636
from airflow.providers.standard.operators.empty import EmptyOperator
37-
from airflow.sdk import TaskGroup
37+
from airflow.sdk import TaskGroup, task, task_group
3838
from airflow.utils import timezone
3939
from airflow.utils.state import State, TaskInstanceState, TerminalTIState
4040

@@ -237,6 +237,128 @@ def test_ti_run_state_to_running(
237237
)
238238
assert response.status_code == 409
239239

240+
def test_dynamic_task_mapping_with_parse_time_value(self, client, dag_maker):
241+
"""
242+
Test that the Task Instance upstream_map_indexes is correctly fetched when to running the Task Instances
243+
"""
244+
245+
with dag_maker("test_dynamic_task_mapping_with_parse_time_value", serialized=True):
246+
247+
@task_group
248+
def task_group_1(arg1):
249+
@task
250+
def group1_task_1(arg1):
251+
return {"a": arg1}
252+
253+
@task
254+
def group1_task_2(arg2):
255+
return arg2
256+
257+
group1_task_2(group1_task_1(arg1))
258+
259+
@task
260+
def task2():
261+
return None
262+
263+
task_group_1.expand(arg1=[0, 1]) >> task2()
264+
265+
dr = dag_maker.create_dagrun()
266+
for ti in dr.get_task_instances():
267+
ti.set_state(State.QUEUED)
268+
dag_maker.session.flush()
269+
270+
# key: (task_id, map_index)
271+
# value: result upstream_map_indexes ({task_id: map_indexes})
272+
expected_upstream_map_indexes = {
273+
# no upstream task for task_group_1.group_task_1
274+
("task_group_1.group1_task_1", 0): {},
275+
("task_group_1.group1_task_1", 1): {},
276+
# the upstream task for task_group_1.group_task_2 is task_group_1.group_task_2
277+
# since they are in the same task group, the upstream map index should be the same as the task
278+
("task_group_1.group1_task_2", 0): {"task_group_1.group1_task_1": 0},
279+
("task_group_1.group1_task_2", 1): {"task_group_1.group1_task_1": 1},
280+
# the upstream task for task2 is the last tasks of task_group_1, which is
281+
# task_group_1.group_task_2
282+
# since they are not in the same task group, the upstream map index should include all the
283+
# expanded tasks
284+
("task2", -1): {"task_group_1.group1_task_2": [0, 1]},
285+
}
286+
287+
for ti in dr.get_task_instances():
288+
response = client.patch(
289+
f"/execution/task-instances/{ti.id}/run",
290+
json={
291+
"state": "running",
292+
"hostname": "random-hostname",
293+
"unixname": "random-unixname",
294+
"pid": 100,
295+
"start_date": "2024-09-30T12:00:00Z",
296+
},
297+
)
298+
299+
assert response.status_code == 200
300+
upstream_map_indexes = response.json()["upstream_map_indexes"]
301+
assert upstream_map_indexes == expected_upstream_map_indexes[(ti.task_id, ti.map_index)]
302+
303+
def test_dynamic_task_mapping_with_xcom(self, client, dag_maker, create_task_instance, session, run_task):
304+
"""
305+
Test that the Task Instance upstream_map_indexes is correctly fetched when to running the Task Instances with xcom
306+
"""
307+
from airflow.models.taskmap import TaskMap
308+
309+
with dag_maker(session=session):
310+
311+
@task
312+
def task_1():
313+
return [0, 1]
314+
315+
@task_group
316+
def tg(x, y):
317+
@task
318+
def task_2():
319+
pass
320+
321+
task_2()
322+
323+
@task
324+
def task_3():
325+
pass
326+
327+
tg.expand(x=task_1(), y=[1, 2, 3]) >> task_3()
328+
329+
dr = dag_maker.create_dagrun()
330+
331+
decision = dr.task_instance_scheduling_decisions(session=session)
332+
333+
# Simulate task_1 execution to produce TaskMap.
334+
(ti_1,) = decision.schedulable_tis
335+
# ti_1 = dr.get_task_instance(task_id="task_1")
336+
ti_1.state = TaskInstanceState.SUCCESS
337+
session.add(TaskMap.from_task_instance_xcom(ti_1, [0, 1]))
338+
session.flush()
339+
340+
# Now task_2 in mapped tagk group is expanded.
341+
decision = dr.task_instance_scheduling_decisions(session=session)
342+
for ti in decision.schedulable_tis:
343+
ti.state = TaskInstanceState.SUCCESS
344+
session.flush()
345+
346+
decision = dr.task_instance_scheduling_decisions(session=session)
347+
(task_3_ti,) = decision.schedulable_tis
348+
task_3_ti.set_state(State.QUEUED)
349+
350+
response = client.patch(
351+
f"/execution/task-instances/{task_3_ti.id}/run",
352+
json={
353+
"state": "running",
354+
"hostname": "random-hostname",
355+
"unixname": "random-unixname",
356+
"pid": 100,
357+
"start_date": "2024-09-30T12:00:00Z",
358+
},
359+
)
360+
assert response.json()["upstream_map_indexes"] == {"tg.task_2": [0, 1, 2, 3, 4, 5]}
361+
240362
def test_next_kwargs_still_encoded(self, client, session, create_task_instance, time_machine):
241363
instant_str = "2024-09-30T12:00:00Z"
242364
instant = timezone.parse(instant_str)

task-sdk/src/airflow/sdk/definitions/mappedoperator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@
6464
TaskStateChangeCallback,
6565
)
6666
from airflow.models.expandinput import (
67-
ExpandInput,
6867
OperatorExpandArgument,
6968
OperatorExpandKwargsArgument,
7069
)
7170
from airflow.sdk.bases.operator import BaseOperator
7271
from airflow.sdk.bases.operatorlink import BaseOperatorLink
72+
from airflow.sdk.definitions._internal.expandinput import ExpandInput
7373
from airflow.sdk.definitions.dag import DAG
7474
from airflow.sdk.definitions.param import ParamsDict
7575
from airflow.sdk.definitions.xcom_arg import XComArg

task-sdk/src/airflow/sdk/definitions/taskgroup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from airflow.utils.trigger_rule import TriggerRule
4141

4242
if TYPE_CHECKING:
43-
from airflow.models.expandinput import ExpandInput
43+
from airflow.models.expandinput import SchedulerExpandInput
4444
from airflow.sdk.bases.operator import BaseOperator
4545
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
4646
from airflow.sdk.definitions._internal.mixins import DependencyMixin
@@ -613,7 +613,7 @@ class MappedTaskGroup(TaskGroup):
613613
a ``@task_group`` function instead.
614614
"""
615615

616-
def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
616+
def __init__(self, *, expand_input: SchedulerExpandInput, **kwargs: Any) -> None:
617617
super().__init__(**kwargs)
618618
self._expand_input = expand_input
619619

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/apache/airflow/commit/5458e7e7be86c6de034d7a589bd26db85c532308

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy