Content-Length: 1057249 | pFad | http://github.com/apache/airflow/commit/84b36c4a38daa9fc7e9ced598abdb973cdbf525d

Implement slice on LazyXComSequence · apache/airflow@84b36c4 · GitHub
Skip to content

Commit 84b36c4

Browse files
committed
Implement slice on LazyXComSequence
I decided to split index and slice access into separate endpoints instead of reusing the GetXCom endpoint. This duplicates some code, but the input parameters are now much easier to reason about. Unfortunately, FastAPI does not natively allow unions on Query(); if it did, this could be implemented much more cleanly.
1 parent dc271d0 commit 84b36c4

File tree

11 files changed

+516
-72
lines changed

11 files changed

+516
-72
lines changed

airflow-core/src/airflow/api_fastapi/execution_api/datamodels/xcom.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import sys
2121
from typing import Any
2222

23-
from pydantic import JsonValue
23+
from pydantic import JsonValue, RootModel
2424

2525
from airflow.api_fastapi.core_api.base import BaseModel
2626

@@ -36,3 +36,15 @@ class XComResponse(BaseModel):
3636
key: str
3737
value: JsonValue
3838
"""The returned XCom value in a JSON-compatible format."""
39+
40+
41+
class XComSequenceIndexResponse(RootModel):
42+
"""XCom schema with minimal structure for index-based access."""
43+
44+
root: JsonValue
45+
46+
47+
class XComSequenceSliceResponse(RootModel):
48+
"""XCom schema with minimal structure for slice-based access."""
49+
50+
root: list[JsonValue]

airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@
2727
from sqlalchemy.sql.selectable import Select
2828

2929
from airflow.api_fastapi.common.db.common import SessionDep
30-
from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
30+
from airflow.api_fastapi.execution_api.datamodels.xcom import (
31+
XComResponse,
32+
XComSequenceIndexResponse,
33+
XComSequenceSliceResponse,
34+
)
3135
from airflow.api_fastapi.execution_api.deps import JWTBearerDep
3236
from airflow.models.taskmap import TaskMap
3337
from airflow.models.xcom import XComModel
@@ -184,6 +188,132 @@ def get_xcom(
184188
return XComResponse(key=key, value=result.value)
185189

186190

191+
@router.get(
192+
"/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}",
193+
description="Get a single XCom value from a mapped task by sequence index",
194+
)
195+
def get_mapped_xcom_by_index(
196+
dag_id: str,
197+
run_id: str,
198+
task_id: str,
199+
key: str,
200+
offset: int,
201+
session: SessionDep,
202+
) -> XComSequenceIndexResponse:
203+
xcom_query = XComModel.get_many(
204+
run_id=run_id,
205+
key=key,
206+
task_ids=task_id,
207+
dag_ids=dag_id,
208+
session=session,
209+
)
210+
xcom_query = xcom_query.order_by(None)
211+
if offset >= 0:
212+
xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(offset)
213+
else:
214+
xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - offset)
215+
216+
if (result := xcom_query.limit(1).first()) is None:
217+
message = (
218+
f"XCom with {key=} {offset=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
219+
)
220+
raise HTTPException(
221+
status_code=status.HTTP_404_NOT_FOUND,
222+
detail={"reason": "not_found", "message": message},
223+
)
224+
return XComSequenceIndexResponse(result.value)
225+
226+
227+
class GetXComSliceFilterParams(BaseModel):
228+
"""Class to house slice params."""
229+
230+
start: int | None = None
231+
stop: int | None = None
232+
step: int | None = None
233+
234+
235+
@router.get(
236+
"/{dag_id}/{run_id}/{task_id}/{key}/slice",
237+
description="Get XCom values from a mapped task by sequence slice",
238+
)
239+
def get_mapped_xcom_by_slice(
240+
dag_id: str,
241+
run_id: str,
242+
task_id: str,
243+
key: str,
244+
params: Annotated[GetXComSliceFilterParams, Query()],
245+
session: SessionDep,
246+
) -> XComSequenceSliceResponse:
247+
query = XComModel.get_many(
248+
run_id=run_id,
249+
key=key,
250+
task_ids=task_id,
251+
dag_ids=dag_id,
252+
session=session,
253+
)
254+
query = query.order_by(None)
255+
256+
step = params.step or 1
257+
258+
# We want to optimize negative slicing (e.g. seq[-10:]) by not doing an
259+
# additional COUNT query if possible. This is possible unless both start and
260+
# stop are explicitly given and have different signs.
261+
if (start := params.start) is None:
262+
if (stop := params.stop) is None:
263+
if step >= 0:
264+
query = query.order_by(XComModel.map_index.asc())
265+
else:
266+
query = query.order_by(XComModel.map_index.desc())
267+
step = -step
268+
elif stop >= 0:
269+
query = query.order_by(XComModel.map_index.asc())
270+
if step >= 0:
271+
query = query.limit(stop)
272+
else:
273+
query = query.offset(stop + 1)
274+
else:
275+
query = query.order_by(XComModel.map_index.desc())
276+
step = -step
277+
if step > 0:
278+
query = query.limit(-stop - 1)
279+
else:
280+
query = query.offset(-stop)
281+
elif start >= 0:
282+
query = query.order_by(XComModel.map_index.asc())
283+
if (stop := params.stop) is None:
284+
if step >= 0:
285+
query = query.offset(start)
286+
else:
287+
query = query.limit(start + 1)
288+
else:
289+
if stop < 0:
290+
stop += get_query_count(query, session=session)
291+
if step >= 0:
292+
query = query.slice(start, stop)
293+
else:
294+
query = query.slice(stop + 1, start + 1)
295+
else:
296+
query = query.order_by(XComModel.map_index.desc())
297+
step = -step
298+
if (stop := params.stop) is None:
299+
if step > 0:
300+
query = query.offset(-start - 1)
301+
else:
302+
query = query.limit(-start)
303+
else:
304+
if stop >= 0:
305+
stop -= get_query_count(query, session=session)
306+
if step > 0:
307+
query = query.slice(-1 - start, -1 - stop)
308+
else:
309+
query = query.slice(-stop, -start)
310+
311+
values = [row.value for row in query.with_entities(XComModel.value)]
312+
if step != 1:
313+
values = values[::step]
314+
return XComSequenceSliceResponse(values)
315+
316+
187317
if sys.version_info < (3, 12):
188318
# zmievsa/cadwyn#262
189319
# Setting this to "Any" doesn't have any impact on the API as it has to be parsed as valid JSON regardless

airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Licensed to the Apache Software Foundation (ASF) under one
2-
# Licensed to the Apache Software Foundation (ASF) under one
32
# or more contributor license agreements. See the NOTICE file
43
# distributed with this work for additional information
54
# regarding copyright ownership. The ASF licenses this file
@@ -20,6 +19,7 @@
2019

2120
import contextlib
2221
import logging
22+
import urllib.parse
2323

2424
import httpx
2525
import pytest
@@ -148,12 +148,12 @@ def test_xcom_access_denied(self, client, caplog):
148148
},
149149
id="-4",
150150
),
151-
pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"),
152-
pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"),
153-
pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"),
154-
pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"),
155-
pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"),
156-
pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"),
151+
pytest.param(-3, 200, "f", id="-3"),
152+
pytest.param(-2, 200, "o", id="-2"),
153+
pytest.param(-1, 200, "b", id="-1"),
154+
pytest.param(0, 200, "f", id="0"),
155+
pytest.param(1, 200, "o", id="1"),
156+
pytest.param(2, 200, "b", id="2"),
157157
pytest.param(
158158
3,
159159
404,
@@ -207,10 +207,72 @@ def __init__(self, *, x, **kwargs):
207207
session.add(x)
208208
session.commit()
209209

210-
response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
210+
response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1/item/{offset}")
211211
assert response.status_code == expected_status
212212
assert response.json() == expected_json
213213

214+
@pytest.mark.parametrize(
215+
"key",
216+
[
217+
pytest.param(slice(None, None, None), id=":"),
218+
pytest.param(slice(None, None, -2), id="::-2"),
219+
pytest.param(slice(None, 2, None), id=":2"),
220+
pytest.param(slice(None, 2, -1), id=":2:-1"),
221+
pytest.param(slice(None, -2, None), id=":-2"),
222+
pytest.param(slice(None, -2, -1), id=":-2:-1"),
223+
pytest.param(slice(1, None, None), id="1:"),
224+
pytest.param(slice(2, None, -1), id="2::-1"),
225+
pytest.param(slice(1, 2, None), id="1:2"),
226+
pytest.param(slice(2, 1, -1), id="2:1:-1"),
227+
pytest.param(slice(1, -1, None), id="1:-1"),
228+
pytest.param(slice(2, -2, -1), id="2:-2:-1"),
229+
pytest.param(slice(-2, None, None), id="-2:"),
230+
pytest.param(slice(-1, None, -1), id="-1::-1"),
231+
pytest.param(slice(-2, -1, None), id="-2:-1"),
232+
pytest.param(slice(-1, -3, -1), id="-1:-3:-1"),
233+
],
234+
)
235+
def test_xcom_get_with_slice(self, client, dag_maker, session, key):
236+
xcom_values = ["f", None, "o", "b"]
237+
238+
class MyOperator(EmptyOperator):
239+
def __init__(self, *, x, **kwargs):
240+
super().__init__(**kwargs)
241+
self.x = x
242+
243+
with dag_maker(dag_id="dag"):
244+
MyOperator.partial(task_id="task").expand(x=xcom_values)
245+
dag_run = dag_maker.create_dagrun(run_id="runid")
246+
tis = {ti.map_index: ti for ti in dag_run.task_instances}
247+
248+
for map_index, db_value in enumerate(xcom_values):
249+
if db_value is None: # We don't put None to XCom.
250+
continue
251+
ti = tis[map_index]
252+
x = XComModel(
253+
key="xcom_1",
254+
value=db_value,
255+
dag_run_id=ti.dag_run.id,
256+
run_id=ti.run_id,
257+
task_id=ti.task_id,
258+
dag_id=ti.dag_id,
259+
map_index=map_index,
260+
)
261+
session.add(x)
262+
session.commit()
263+
264+
qs = {}
265+
if key.start is not None:
266+
qs["start"] = key.start
267+
if key.stop is not None:
268+
qs["stop"] = key.stop
269+
if key.step is not None:
270+
qs["step"] = key.step
271+
272+
response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1/slice?{urllib.parse.urlencode(qs)}")
273+
assert response.status_code == 200
274+
assert response.json() == ["f", "o", "b"][key]
275+
214276

215277
class TestXComsSetEndpoint:
216278
@pytest.mark.parametrize(
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
import pytest
21+
22+
from airflow.models.xcom import XComModel
23+
from airflow.providers.standard.operators.empty import EmptyOperator
24+
25+
pytestmark = pytest.mark.db_test
26+
27+
28+
class TestXComsGetEndpoint:
29+
@pytest.mark.parametrize(
30+
"offset, expected_status, expected_json",
31+
[
32+
pytest.param(
33+
-4,
34+
404,
35+
{
36+
"detail": {
37+
"reason": "not_found",
38+
"message": (
39+
"XCom with key='xcom_1' offset=-4 not found "
40+
"for task 'task' in DAG run 'runid' of 'dag'"
41+
),
42+
},
43+
},
44+
id="-4",
45+
),
46+
pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"),
47+
pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"),
48+
pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"),
49+
pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"),
50+
pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"),
51+
pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"),
52+
pytest.param(
53+
3,
54+
404,
55+
{
56+
"detail": {
57+
"reason": "not_found",
58+
"message": (
59+
"XCom with key='xcom_1' offset=3 not found "
60+
"for task 'task' in DAG run 'runid' of 'dag'"
61+
),
62+
},
63+
},
64+
id="3",
65+
),
66+
],
67+
)
68+
def test_xcom_get_with_offset(
69+
self,
70+
client,
71+
dag_maker,
72+
session,
73+
offset,
74+
expected_status,
75+
expected_json,
76+
):
77+
xcom_values = ["f", None, "o", "b"]
78+
79+
class MyOperator(EmptyOperator):
80+
def __init__(self, *, x, **kwargs):
81+
super().__init__(**kwargs)
82+
self.x = x
83+
84+
with dag_maker(dag_id="dag"):
85+
MyOperator.partial(task_id="task").expand(x=xcom_values)
86+
87+
dag_run = dag_maker.create_dagrun(run_id="runid")
88+
tis = {ti.map_index: ti for ti in dag_run.task_instances}
89+
for map_index, db_value in enumerate(xcom_values):
90+
if db_value is None: # We don't put None to XCom.
91+
continue
92+
ti = tis[map_index]
93+
x = XComModel(
94+
key="xcom_1",
95+
value=db_value,
96+
dag_run_id=ti.dag_run.id,
97+
run_id=ti.run_id,
98+
task_id=ti.task_id,
99+
dag_id=ti.dag_id,
100+
map_index=map_index,
101+
)
102+
session.add(x)
103+
session.commit()
104+
105+
response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
106+
assert response.status_code == expected_status
107+
assert response.json() == expected_json

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/apache/airflow/commit/84b36c4a38daa9fc7e9ced598abdb973cdbf525d

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy