Skip to content

Commit 84b36c4

Browse files
committed
Implement slice on LazyXComSequence
I decided to split index and slice access into separate endpoints instead of reusing the GetXCom endpoint. This duplicates some code, but the input parameters are now much easier to reason about. Unfortunately, FastAPI does not natively allow unions on Query(); if it did, this could be implemented much more cleanly.
1 parent dc271d0 commit 84b36c4

File tree

11 files changed

+516
-72
lines changed

11 files changed

+516
-72
lines changed

airflow-core/src/airflow/api_fastapi/execution_api/datamodels/xcom.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import sys
2121
from typing import Any
2222

23-
from pydantic import JsonValue
23+
from pydantic import JsonValue, RootModel
2424

2525
from airflow.api_fastapi.core_api.base import BaseModel
2626

@@ -36,3 +36,15 @@ class XComResponse(BaseModel):
3636
key: str
3737
value: JsonValue
3838
"""The returned XCom value in a JSON-compatible format."""
39+
40+
41+
class XComSequenceIndexResponse(RootModel):
    """XCom schema with minimal structure for index-based access."""

    # The payload is the bare JSON value itself rather than an object wrapping it.
    root: JsonValue
45+
46+
47+
class XComSequenceSliceResponse(RootModel):
    """XCom schema with minimal structure for slice-based access."""

    # The payload is a plain JSON array of values rather than an object wrapping one.
    root: list[JsonValue]

airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@
2727
from sqlalchemy.sql.selectable import Select
2828

2929
from airflow.api_fastapi.common.db.common import SessionDep
30-
from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
30+
from airflow.api_fastapi.execution_api.datamodels.xcom import (
31+
XComResponse,
32+
XComSequenceIndexResponse,
33+
XComSequenceSliceResponse,
34+
)
3135
from airflow.api_fastapi.execution_api.deps import JWTBearerDep
3236
from airflow.models.taskmap import TaskMap
3337
from airflow.models.xcom import XComModel
@@ -184,6 +188,132 @@ def get_xcom(
184188
return XComResponse(key=key, value=result.value)
185189

186190

191+
@router.get(
    "/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}",
    description="Get a single XCom value from a mapped task by sequence index",
)
def get_mapped_xcom_by_index(
    dag_id: str,
    run_id: str,
    task_id: str,
    key: str,
    offset: int,
    session: SessionDep,
) -> XComSequenceIndexResponse:
    """
    Return the XCom value found at position ``offset`` when rows are ordered by map index.

    A non-negative ``offset`` counts from the lowest map index; a negative one
    counts back from the highest, so ``-1`` selects the last stored value.

    :raises HTTPException: 404 when no row exists at the requested position.
    """
    # Drop whatever ordering get_many() applied; only map_index order matters here.
    unordered = XComModel.get_many(
        run_id=run_id,
        key=key,
        task_ids=task_id,
        dag_ids=dag_id,
        session=session,
    ).order_by(None)

    # Counting from the end is done by reversing the order and skipping from
    # the front, which avoids needing to know the total number of rows.
    if offset < 0:
        ordering, rows_to_skip = XComModel.map_index.desc(), -1 - offset
    else:
        ordering, rows_to_skip = XComModel.map_index.asc(), offset

    row = unordered.order_by(ordering).offset(rows_to_skip).limit(1).first()
    if row is not None:
        return XComSequenceIndexResponse(row.value)

    message = f"XCom with {key=} {offset=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail={"reason": "not_found", "message": message},
    )
225+
226+
227+
class GetXComSliceFilterParams(BaseModel):
    """Class to house slice params."""

    # Each bound is optional and is None when it is not supplied.
    start: int | None = None
    stop: int | None = None
    step: int | None = None
233+
234+
235+
def _clamp_slice_bounds(lower: int, upper: int) -> tuple[int, int]:
    """
    Turn a possibly out-of-range row window into one that is safe to query.

    A negative lower bound is raised to zero, and an upper bound below the
    lower bound is raised to it (an empty window). This mirrors how Python
    clamps out-of-range slice bounds, and keeps a negative OFFSET or LIMIT
    from ever reaching the database.
    """
    lower = max(lower, 0)
    return lower, max(upper, lower)


@router.get(
    "/{dag_id}/{run_id}/{task_id}/{key}/slice",
    description="Get XCom values from a mapped task by sequence slice",
)
def get_mapped_xcom_by_slice(
    dag_id: str,
    run_id: str,
    task_id: str,
    key: str,
    params: Annotated[GetXComSliceFilterParams, Query()],
    session: SessionDep,
) -> XComSequenceSliceResponse:
    """
    Return the XCom values selected by a ``start:stop:step`` slice over rows ordered by map index.

    Bounds are narrowed in the database with LIMIT/OFFSET where possible; the
    step (and any reversal) is applied afterwards to the fetched values.
    Bounds that describe an empty or partly out-of-range window produce an
    empty or truncated list, as slicing a Python list would.

    :param params: Optional ``start``, ``stop`` and ``step`` values. A missing
        or zero ``step`` is treated as 1.
    """
    query = XComModel.get_many(
        run_id=run_id,
        key=key,
        task_ids=task_id,
        dag_ids=dag_id,
        session=session,
    )
    query = query.order_by(None)

    step = params.step or 1

    # We want to optimize negative slicing (e.g. seq[-10:]) by not doing an
    # additional COUNT query if possible. This is possible unless both start and
    # stop are explicitly given and have different signs.
    #
    # Whenever rows are fetched in descending order, the sign of ``step`` is
    # flipped so that the final ``values[::step]`` restores the requested order.
    if (start := params.start) is None:
        if (stop := params.stop) is None:
            if step >= 0:
                query = query.order_by(XComModel.map_index.asc())
            else:
                query = query.order_by(XComModel.map_index.desc())
                step = -step
        elif stop >= 0:
            query = query.order_by(XComModel.map_index.asc())
            if step >= 0:
                query = query.limit(stop)
            else:
                query = query.offset(stop + 1)
        else:
            query = query.order_by(XComModel.map_index.desc())
            step = -step
            if step > 0:
                query = query.limit(-stop - 1)
            else:
                query = query.offset(-stop)
    elif start >= 0:
        query = query.order_by(XComModel.map_index.asc())
        if (stop := params.stop) is None:
            if step >= 0:
                query = query.offset(start)
            else:
                query = query.limit(start + 1)
        else:
            if stop < 0:
                # Convert the from-the-end bound to an absolute position.
                stop += get_query_count(query, session=session)
            # The window may be empty (stop <= start) or reach below zero
            # (stop still negative after the adjustment), so clamp it.
            if step >= 0:
                lower, upper = _clamp_slice_bounds(start, stop)
            else:
                lower, upper = _clamp_slice_bounds(stop + 1, start + 1)
            query = query.slice(lower, upper)
    else:
        query = query.order_by(XComModel.map_index.desc())
        step = -step
        if (stop := params.stop) is None:
            if step > 0:
                query = query.offset(-start - 1)
            else:
                query = query.limit(-start)
        else:
            if stop >= 0:
                # Convert the absolute bound to a from-the-end position.
                stop -= get_query_count(query, session=session)
            # Positions here count from the end of the sequence; clamp for the
            # same reasons as in the ascending case.
            if step > 0:
                lower, upper = _clamp_slice_bounds(-1 - start, -1 - stop)
            else:
                lower, upper = _clamp_slice_bounds(-stop, -start)
            query = query.slice(lower, upper)

    values = [row.value for row in query.with_entities(XComModel.value)]
    if step != 1:
        values = values[::step]
    return XComSequenceSliceResponse(values)
315+
316+
187317
if sys.version_info < (3, 12):
188318
# zmievsa/cadwyn#262
189319
# Setting this to "Any" doesn't have any impact on the API as it has to be parsed as valid JSON regardless

airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Licensed to the Apache Software Foundation (ASF) under one
2-
# Licensed to the Apache Software Foundation (ASF) under one
32
# or more contributor license agreements. See the NOTICE file
43
# distributed with this work for additional information
54
# regarding copyright ownership. The ASF licenses this file
@@ -20,6 +19,7 @@
2019

2120
import contextlib
2221
import logging
22+
import urllib.parse
2323

2424
import httpx
2525
import pytest
@@ -148,12 +148,12 @@ def test_xcom_access_denied(self, client, caplog):
148148
},
149149
id="-4",
150150
),
151-
pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"),
152-
pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"),
153-
pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"),
154-
pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"),
155-
pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"),
156-
pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"),
151+
pytest.param(-3, 200, "f", id="-3"),
152+
pytest.param(-2, 200, "o", id="-2"),
153+
pytest.param(-1, 200, "b", id="-1"),
154+
pytest.param(0, 200, "f", id="0"),
155+
pytest.param(1, 200, "o", id="1"),
156+
pytest.param(2, 200, "b", id="2"),
157157
pytest.param(
158158
3,
159159
404,
@@ -207,10 +207,72 @@ def __init__(self, *, x, **kwargs):
207207
session.add(x)
208208
session.commit()
209209

210-
response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
210+
response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1/item/{offset}")
211211
assert response.status_code == expected_status
212212
assert response.json() == expected_json
213213

214+
@pytest.mark.parametrize(
    "key",
    [
        pytest.param(slice(None, None, None), id=":"),
        pytest.param(slice(None, None, -2), id="::-2"),
        pytest.param(slice(None, 2, None), id=":2"),
        pytest.param(slice(None, 2, -1), id=":2:-1"),
        pytest.param(slice(None, -2, None), id=":-2"),
        pytest.param(slice(None, -2, -1), id=":-2:-1"),
        pytest.param(slice(1, None, None), id="1:"),
        pytest.param(slice(2, None, -1), id="2::-1"),
        pytest.param(slice(1, 2, None), id="1:2"),
        pytest.param(slice(2, 1, -1), id="2:1:-1"),
        pytest.param(slice(1, -1, None), id="1:-1"),
        pytest.param(slice(2, -2, -1), id="2:-2:-1"),
        pytest.param(slice(-2, None, None), id="-2:"),
        pytest.param(slice(-1, None, -1), id="-1::-1"),
        pytest.param(slice(-2, -1, None), id="-2:-1"),
        pytest.param(slice(-1, -3, -1), id="-1:-3:-1"),
    ],
)
def test_xcom_get_with_slice(self, client, dag_maker, session, key):
    """The slice endpoint returns what slicing the stored values in Python would."""
    xcom_values = ["f", None, "o", "b"]

    class MyOperator(EmptyOperator):
        def __init__(self, *, x, **kwargs):
            super().__init__(**kwargs)
            self.x = x

    with dag_maker(dag_id="dag"):
        MyOperator.partial(task_id="task").expand(x=xcom_values)
    dag_run = dag_maker.create_dagrun(run_id="runid")
    ti_by_map_index = {ti.map_index: ti for ti in dag_run.task_instances}

    # We don't put None to XCom, so that map index is left without a row.
    stored = ((index, value) for index, value in enumerate(xcom_values) if value is not None)
    for map_index, db_value in stored:
        ti = ti_by_map_index[map_index]
        session.add(
            XComModel(
                key="xcom_1",
                value=db_value,
                dag_run_id=ti.dag_run.id,
                run_id=ti.run_id,
                task_id=ti.task_id,
                dag_id=ti.dag_id,
                map_index=map_index,
            )
        )
    session.commit()

    # Only bounds that were actually given are sent as query parameters.
    bounds = (("start", key.start), ("stop", key.stop), ("step", key.step))
    qs = {name: bound for name, bound in bounds if bound is not None}

    response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1/slice?{urllib.parse.urlencode(qs)}")
    assert response.status_code == 200
    assert response.json() == ["f", "o", "b"][key]
275+
214276

215277
class TestXComsSetEndpoint:
216278
@pytest.mark.parametrize(
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
import pytest
21+
22+
from airflow.models.xcom import XComModel
23+
from airflow.providers.standard.operators.empty import EmptyOperator
24+
25+
pytestmark = pytest.mark.db_test
26+
27+
28+
class TestXComsGetEndpoint:
    """Checks for fetching a mapped task's XCom through the ``offset`` query parameter."""

    @pytest.mark.parametrize(
        "offset, expected_status, expected_json",
        [
            pytest.param(
                -4,
                404,
                {
                    "detail": {
                        "reason": "not_found",
                        "message": (
                            "XCom with key='xcom_1' offset=-4 not found "
                            "for task 'task' in DAG run 'runid' of 'dag'"
                        ),
                    },
                },
                id="-4",
            ),
            pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"),
            pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"),
            pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"),
            pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"),
            pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"),
            pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"),
            pytest.param(
                3,
                404,
                {
                    "detail": {
                        "reason": "not_found",
                        "message": (
                            "XCom with key='xcom_1' offset=3 not found "
                            "for task 'task' in DAG run 'runid' of 'dag'"
                        ),
                    },
                },
                id="3",
            ),
        ],
    )
    def test_xcom_get_with_offset(
        self,
        client,
        dag_maker,
        session,
        offset,
        expected_status,
        expected_json,
    ):
        """Valid offsets return the matching value; out-of-range offsets return 404."""
        xcom_values = ["f", None, "o", "b"]

        class MyOperator(EmptyOperator):
            def __init__(self, *, x, **kwargs):
                super().__init__(**kwargs)
                self.x = x

        with dag_maker(dag_id="dag"):
            MyOperator.partial(task_id="task").expand(x=xcom_values)

        dag_run = dag_maker.create_dagrun(run_id="runid")
        ti_by_map_index = {ti.map_index: ti for ti in dag_run.task_instances}

        # We don't put None to XCom, so that map index is left without a row.
        stored = ((index, value) for index, value in enumerate(xcom_values) if value is not None)
        for map_index, db_value in stored:
            ti = ti_by_map_index[map_index]
            session.add(
                XComModel(
                    key="xcom_1",
                    value=db_value,
                    dag_run_id=ti.dag_run.id,
                    run_id=ti.run_id,
                    task_id=ti.task_id,
                    dag_id=ti.dag_id,
                    map_index=map_index,
                )
            )
        session.commit()

        response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
        assert response.status_code == expected_status
        assert response.json() == expected_json

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy