Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import sys
from typing import Any

from pydantic import JsonValue
from pydantic import JsonValue, RootModel

from airflow.api_fastapi.core_api.base import BaseModel

Expand All @@ -36,3 +36,15 @@ class XComResponse(BaseModel):
key: str
value: JsonValue
"""The returned XCom value in a JSON-compatible format."""


class XComSequenceIndexResponse(RootModel):
"""XCom schema with minimal structure for index-based access."""

root: JsonValue


class XComSequenceSliceResponse(RootModel):
"""XCom schema with minimal structure for slice-based access."""

root: list[JsonValue]
132 changes: 131 additions & 1 deletion airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
from sqlalchemy.sql.selectable import Select

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
from airflow.api_fastapi.execution_api.datamodels.xcom import (
XComResponse,
XComSequenceIndexResponse,
XComSequenceSliceResponse,
)
from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XComModel
Expand Down Expand Up @@ -184,6 +188,132 @@ def get_xcom(
return XComResponse(key=key, value=result.value)


@router.get(
"/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}",
description="Get a single XCom value from a mapped task by sequence index",
)
def get_mapped_xcom_by_index(
dag_id: str,
run_id: str,
task_id: str,
key: str,
offset: int,
session: SessionDep,
) -> XComSequenceIndexResponse:
xcom_query = XComModel.get_many(
run_id=run_id,
key=key,
task_ids=task_id,
dag_ids=dag_id,
session=session,
)
xcom_query = xcom_query.order_by(None)
if offset >= 0:
xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(offset)
else:
xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - offset)

if (result := xcom_query.limit(1).first()) is None:
message = (
f"XCom with {key=} {offset=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={"reason": "not_found", "message": message},
)
return XComSequenceIndexResponse(result.value)


class GetXComSliceFilterParams(BaseModel):
"""Class to house slice params."""

start: int | None = None
stop: int | None = None
step: int | None = None


@router.get(
"/{dag_id}/{run_id}/{task_id}/{key}/slice",
description="Get XCom values from a mapped task by sequence slice",
)
def get_mapped_xcom_by_slice(
dag_id: str,
run_id: str,
task_id: str,
key: str,
params: Annotated[GetXComSliceFilterParams, Query()],
session: SessionDep,
) -> XComSequenceSliceResponse:
query = XComModel.get_many(
run_id=run_id,
key=key,
task_ids=task_id,
dag_ids=dag_id,
session=session,
)
query = query.order_by(None)

step = params.step or 1

# We want to optimize negative slicing (e.g. seq[-10:]) by not doing an
# additional COUNT query if possible. This is possible unless both start and
# stop are explicitly given and have different signs.
if (start := params.start) is None:
if (stop := params.stop) is None:
if step >= 0:
query = query.order_by(XComModel.map_index.asc())
else:
query = query.order_by(XComModel.map_index.desc())
step = -step
elif stop >= 0:
query = query.order_by(XComModel.map_index.asc())
if step >= 0:
query = query.limit(stop)
else:
query = query.offset(stop + 1)
else:
query = query.order_by(XComModel.map_index.desc())
step = -step
if step > 0:
query = query.limit(-stop - 1)
else:
query = query.offset(-stop)
elif start >= 0:
query = query.order_by(XComModel.map_index.asc())
if (stop := params.stop) is None:
if step >= 0:
query = query.offset(start)
else:
query = query.limit(start + 1)
else:
if stop < 0:
stop += get_query_count(query, session=session)
if step >= 0:
query = query.slice(start, stop)
else:
query = query.slice(stop + 1, start + 1)
else:
query = query.order_by(XComModel.map_index.desc())
step = -step
if (stop := params.stop) is None:
if step > 0:
query = query.offset(-start - 1)
else:
query = query.limit(-start)
else:
if stop >= 0:
stop -= get_query_count(query, session=session)
if step > 0:
query = query.slice(-1 - start, -1 - stop)
else:
query = query.slice(-stop, -start)

values = [row.value for row in query.with_entities(XComModel.value)]
if step != 1:
values = values[::step]
return XComSequenceSliceResponse(values)


if sys.version_info < (3, 12):
# zmievsa/cadwyn#262
# Setting this to "Any" doesn't have any impact on the API as it has to be parsed as valid JSON regardless
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Licensed to the Apache Software Foundation (ASF) under one
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
Expand All @@ -20,6 +19,7 @@

import contextlib
import logging
import urllib.parse

import httpx
import pytest
Expand Down Expand Up @@ -148,12 +148,12 @@ def test_xcom_access_denied(self, client, caplog):
},
id="-4",
),
pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"),
pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"),
pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"),
pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"),
pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"),
pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"),
pytest.param(-3, 200, "f", id="-3"),
pytest.param(-2, 200, "o", id="-2"),
pytest.param(-1, 200, "b", id="-1"),
pytest.param(0, 200, "f", id="0"),
pytest.param(1, 200, "o", id="1"),
pytest.param(2, 200, "b", id="2"),
pytest.param(
3,
404,
Expand Down Expand Up @@ -207,10 +207,72 @@ def __init__(self, *, x, **kwargs):
session.add(x)
session.commit()

response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1/item/{offset}")
assert response.status_code == expected_status
assert response.json() == expected_json

@pytest.mark.parametrize(
"key",
[
pytest.param(slice(None, None, None), id=":"),
pytest.param(slice(None, None, -2), id="::-2"),
pytest.param(slice(None, 2, None), id=":2"),
pytest.param(slice(None, 2, -1), id=":2:-1"),
pytest.param(slice(None, -2, None), id=":-2"),
pytest.param(slice(None, -2, -1), id=":-2:-1"),
pytest.param(slice(1, None, None), id="1:"),
pytest.param(slice(2, None, -1), id="2::-1"),
pytest.param(slice(1, 2, None), id="1:2"),
pytest.param(slice(2, 1, -1), id="2:1:-1"),
pytest.param(slice(1, -1, None), id="1:-1"),
pytest.param(slice(2, -2, -1), id="2:-2:-1"),
pytest.param(slice(-2, None, None), id="-2:"),
pytest.param(slice(-1, None, -1), id="-1::-1"),
pytest.param(slice(-2, -1, None), id="-2:-1"),
pytest.param(slice(-1, -3, -1), id="-1:-3:-1"),
],
)
def test_xcom_get_with_slice(self, client, dag_maker, session, key):
xcom_values = ["f", None, "o", "b"]

class MyOperator(EmptyOperator):
def __init__(self, *, x, **kwargs):
super().__init__(**kwargs)
self.x = x

with dag_maker(dag_id="dag"):
MyOperator.partial(task_id="task").expand(x=xcom_values)
dag_run = dag_maker.create_dagrun(run_id="runid")
tis = {ti.map_index: ti for ti in dag_run.task_instances}

for map_index, db_value in enumerate(xcom_values):
if db_value is None: # We don't put None to XCom.
continue
ti = tis[map_index]
x = XComModel(
key="xcom_1",
value=db_value,
dag_run_id=ti.dag_run.id,
run_id=ti.run_id,
task_id=ti.task_id,
dag_id=ti.dag_id,
map_index=map_index,
)
session.add(x)
session.commit()

qs = {}
if key.start is not None:
qs["start"] = key.start
if key.stop is not None:
qs["stop"] = key.stop
if key.step is not None:
qs["step"] = key.step

response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1/slice?{urllib.parse.urlencode(qs)}")
assert response.status_code == 200
assert response.json() == ["f", "o", "b"][key]


class TestXComsSetEndpoint:
@pytest.mark.parametrize(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import pytest

from airflow.models.xcom import XComModel
from airflow.providers.standard.operators.empty import EmptyOperator

pytestmark = pytest.mark.db_test


class TestXComsGetEndpoint:
@pytest.mark.parametrize(
"offset, expected_status, expected_json",
[
pytest.param(
-4,
404,
{
"detail": {
"reason": "not_found",
"message": (
"XCom with key='xcom_1' offset=-4 not found "
"for task 'task' in DAG run 'runid' of 'dag'"
),
},
},
id="-4",
),
pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"),
pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"),
pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"),
pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"),
pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"),
pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"),
pytest.param(
3,
404,
{
"detail": {
"reason": "not_found",
"message": (
"XCom with key='xcom_1' offset=3 not found "
"for task 'task' in DAG run 'runid' of 'dag'"
),
},
},
id="3",
),
],
)
def test_xcom_get_with_offset(
self,
client,
dag_maker,
session,
offset,
expected_status,
expected_json,
):
xcom_values = ["f", None, "o", "b"]

class MyOperator(EmptyOperator):
def __init__(self, *, x, **kwargs):
super().__init__(**kwargs)
self.x = x

with dag_maker(dag_id="dag"):
MyOperator.partial(task_id="task").expand(x=xcom_values)

dag_run = dag_maker.create_dagrun(run_id="runid")
tis = {ti.map_index: ti for ti in dag_run.task_instances}
for map_index, db_value in enumerate(xcom_values):
if db_value is None: # We don't put None to XCom.
continue
ti = tis[map_index]
x = XComModel(
key="xcom_1",
value=db_value,
dag_run_id=ti.dag_run.id,
run_id=ti.run_id,
task_id=ti.task_id,
dag_id=ti.dag_id,
map_index=map_index,
)
session.add(x)
session.commit()

response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
assert response.status_code == expected_status
assert response.json() == expected_json
Loading
Loading