From 84b36c4a38daa9fc7e9ced598abdb973cdbf525d Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Fri, 2 May 2025 16:54:23 +0800 Subject: [PATCH] Implement slice on LazyXComSequence I decided to split index and slice access to their separate endpoints, instead of reusing the GetXCom endpoint. This duplicates code a bit, but the input parameters are now a lot easier to reason about. It's unfortunate FastAPI does not natively allow unions on Query(), or this could be implemented much more cleanly. --- .../execution_api/datamodels/xcom.py | 14 +- .../api_fastapi/execution_api/routes/xcoms.py | 132 +++++++++++++++++- .../execution_api/versions/head/test_xcoms.py | 78 +++++++++-- .../versions/v2025_04_28/test_xcom.py | 107 ++++++++++++++ task-sdk/src/airflow/sdk/api/client.py | 29 +++- .../airflow/sdk/api/datamodels/_generated.py | 26 +++- .../src/airflow/sdk/execution_time/comms.py | 36 ++++- .../sdk/execution_time/lazy_sequence.py | 94 +++++++------ .../airflow/sdk/execution_time/supervisor.py | 14 +- .../execution_time/test_lazy_sequence.py | 33 ++++- .../execution_time/test_supervisor.py | 25 +++- 11 files changed, 516 insertions(+), 72 deletions(-) create mode 100644 airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_xcom.py diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/xcom.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/xcom.py index ae7ddd26761cd..4df3e3f74f059 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/xcom.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/xcom.py @@ -20,7 +20,7 @@ import sys from typing import Any -from pydantic import JsonValue +from pydantic import JsonValue, RootModel from airflow.api_fastapi.core_api.base import BaseModel @@ -36,3 +36,15 @@ class XComResponse(BaseModel): key: str value: JsonValue """The returned XCom value in a JSON-compatible format.""" + + +class XComSequenceIndexResponse(RootModel): + """XCom schema 
with minimal structure for index-based access.""" + + root: JsonValue + + +class XComSequenceSliceResponse(RootModel): + """XCom schema with minimal structure for slice-based access.""" + + root: list[JsonValue] diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py index d4c8c5160cd97..a9ed4a5b48d20 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -27,7 +27,11 @@ from sqlalchemy.sql.selectable import Select from airflow.api_fastapi.common.db.common import SessionDep -from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse +from airflow.api_fastapi.execution_api.datamodels.xcom import ( + XComResponse, + XComSequenceIndexResponse, + XComSequenceSliceResponse, +) from airflow.api_fastapi.execution_api.deps import JWTBearerDep from airflow.models.taskmap import TaskMap from airflow.models.xcom import XComModel @@ -184,6 +188,132 @@ def get_xcom( return XComResponse(key=key, value=result.value) +@router.get( + "/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}", + description="Get a single XCom value from a mapped task by sequence index", +) +def get_mapped_xcom_by_index( + dag_id: str, + run_id: str, + task_id: str, + key: str, + offset: int, + session: SessionDep, +) -> XComSequenceIndexResponse: + xcom_query = XComModel.get_many( + run_id=run_id, + key=key, + task_ids=task_id, + dag_ids=dag_id, + session=session, + ) + xcom_query = xcom_query.order_by(None) + if offset >= 0: + xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(offset) + else: + xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - offset) + + if (result := xcom_query.limit(1).first()) is None: + message = ( + f"XCom with {key=} {offset=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}" + ) + raise HTTPException( + 
status_code=status.HTTP_404_NOT_FOUND, + detail={"reason": "not_found", "message": message}, + ) + return XComSequenceIndexResponse(result.value) + + +class GetXComSliceFilterParams(BaseModel): + """Class to house slice params.""" + + start: int | None = None + stop: int | None = None + step: int | None = None + + +@router.get( + "/{dag_id}/{run_id}/{task_id}/{key}/slice", + description="Get XCom values from a mapped task by sequence slice", +) +def get_mapped_xcom_by_slice( + dag_id: str, + run_id: str, + task_id: str, + key: str, + params: Annotated[GetXComSliceFilterParams, Query()], + session: SessionDep, +) -> XComSequenceSliceResponse: + query = XComModel.get_many( + run_id=run_id, + key=key, + task_ids=task_id, + dag_ids=dag_id, + session=session, + ) + query = query.order_by(None) + + step = params.step or 1 + + # We want to optimize negative slicing (e.g. seq[-10:]) by not doing an + # additional COUNT query if possible. This is possible unless both start and + # stop are explicitly given and have different signs. 
+ if (start := params.start) is None: + if (stop := params.stop) is None: + if step >= 0: + query = query.order_by(XComModel.map_index.asc()) + else: + query = query.order_by(XComModel.map_index.desc()) + step = -step + elif stop >= 0: + query = query.order_by(XComModel.map_index.asc()) + if step >= 0: + query = query.limit(stop) + else: + query = query.offset(stop + 1) + else: + query = query.order_by(XComModel.map_index.desc()) + step = -step + if step > 0: + query = query.limit(-stop - 1) + else: + query = query.offset(-stop) + elif start >= 0: + query = query.order_by(XComModel.map_index.asc()) + if (stop := params.stop) is None: + if step >= 0: + query = query.offset(start) + else: + query = query.limit(start + 1) + else: + if stop < 0: + stop += get_query_count(query, session=session) + if step >= 0: + query = query.slice(start, stop) + else: + query = query.slice(stop + 1, start + 1) + else: + query = query.order_by(XComModel.map_index.desc()) + step = -step + if (stop := params.stop) is None: + if step > 0: + query = query.offset(-start - 1) + else: + query = query.limit(-start) + else: + if stop >= 0: + stop -= get_query_count(query, session=session) + if step > 0: + query = query.slice(-1 - start, -1 - stop) + else: + query = query.slice(-stop, -start) + + values = [row.value for row in query.with_entities(XComModel.value)] + if step != 1: + values = values[::step] + return XComSequenceSliceResponse(values) + + if sys.version_info < (3, 12): # zmievsa/cadwyn#262 # Setting this to "Any" doesn't have any impact on the API as it has to be parsed as valid JSON regardless diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py index 951fbf5cbae0a..1b10e81cd2338 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py @@ -1,5 +1,4 
@@ # Licensed to the Apache Software Foundation (ASF) under one -# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file @@ -20,6 +19,7 @@ import contextlib import logging +import urllib.parse import httpx import pytest @@ -148,12 +148,12 @@ def test_xcom_access_denied(self, client, caplog): }, id="-4", ), - pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"), - pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"), - pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"), - pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"), - pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"), - pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"), + pytest.param(-3, 200, "f", id="-3"), + pytest.param(-2, 200, "o", id="-2"), + pytest.param(-1, 200, "b", id="-1"), + pytest.param(0, 200, "f", id="0"), + pytest.param(1, 200, "o", id="1"), + pytest.param(2, 200, "b", id="2"), pytest.param( 3, 404, @@ -207,10 +207,72 @@ def __init__(self, *, x, **kwargs): session.add(x) session.commit() - response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}") + response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1/item/{offset}") assert response.status_code == expected_status assert response.json() == expected_json + @pytest.mark.parametrize( + "key", + [ + pytest.param(slice(None, None, None), id=":"), + pytest.param(slice(None, None, -2), id="::-2"), + pytest.param(slice(None, 2, None), id=":2"), + pytest.param(slice(None, 2, -1), id=":2:-1"), + pytest.param(slice(None, -2, None), id=":-2"), + pytest.param(slice(None, -2, -1), id=":-2:-1"), + pytest.param(slice(1, None, None), id="1:"), + pytest.param(slice(2, None, -1), id="2::-1"), + pytest.param(slice(1, 2, None), id="1:2"), + pytest.param(slice(2, 1, -1), id="2:1:-1"), + 
pytest.param(slice(1, -1, None), id="1:-1"), + pytest.param(slice(2, -2, -1), id="2:-2:-1"), + pytest.param(slice(-2, None, None), id="-2:"), + pytest.param(slice(-1, None, -1), id="-1::-1"), + pytest.param(slice(-2, -1, None), id="-2:-1"), + pytest.param(slice(-1, -3, -1), id="-1:-3:-1"), + ], + ) + def test_xcom_get_with_slice(self, client, dag_maker, session, key): + xcom_values = ["f", None, "o", "b"] + + class MyOperator(EmptyOperator): + def __init__(self, *, x, **kwargs): + super().__init__(**kwargs) + self.x = x + + with dag_maker(dag_id="dag"): + MyOperator.partial(task_id="task").expand(x=xcom_values) + dag_run = dag_maker.create_dagrun(run_id="runid") + tis = {ti.map_index: ti for ti in dag_run.task_instances} + + for map_index, db_value in enumerate(xcom_values): + if db_value is None: # We don't put None to XCom. + continue + ti = tis[map_index] + x = XComModel( + key="xcom_1", + value=db_value, + dag_run_id=ti.dag_run.id, + run_id=ti.run_id, + task_id=ti.task_id, + dag_id=ti.dag_id, + map_index=map_index, + ) + session.add(x) + session.commit() + + qs = {} + if key.start is not None: + qs["start"] = key.start + if key.stop is not None: + qs["stop"] = key.stop + if key.step is not None: + qs["step"] = key.step + + response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1/slice?{urllib.parse.urlencode(qs)}") + assert response.status_code == 200 + assert response.json() == ["f", "o", "b"][key] + class TestXComsSetEndpoint: @pytest.mark.parametrize( diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_xcom.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_xcom.py new file mode 100644 index 0000000000000..1de65d493d1ac --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_xcom.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pytest + +from airflow.models.xcom import XComModel +from airflow.providers.standard.operators.empty import EmptyOperator + +pytestmark = pytest.mark.db_test + + +class TestXComsGetEndpoint: + @pytest.mark.parametrize( + "offset, expected_status, expected_json", + [ + pytest.param( + -4, + 404, + { + "detail": { + "reason": "not_found", + "message": ( + "XCom with key='xcom_1' offset=-4 not found " + "for task 'task' in DAG run 'runid' of 'dag'" + ), + }, + }, + id="-4", + ), + pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"), + pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"), + pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"), + pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"), + pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"), + pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"), + pytest.param( + 3, + 404, + { + "detail": { + "reason": "not_found", + "message": ( + "XCom with key='xcom_1' offset=3 not found " + "for task 'task' in DAG run 'runid' of 'dag'" + ), + }, + }, + id="3", + ), + ], + ) + def test_xcom_get_with_offset( + self, + client, + dag_maker, + session, + offset, + expected_status, + expected_json, + ): + xcom_values = ["f", None, "o", "b"] 
+ + class MyOperator(EmptyOperator): + def __init__(self, *, x, **kwargs): + super().__init__(**kwargs) + self.x = x + + with dag_maker(dag_id="dag"): + MyOperator.partial(task_id="task").expand(x=xcom_values) + + dag_run = dag_maker.create_dagrun(run_id="runid") + tis = {ti.map_index: ti for ti in dag_run.task_instances} + for map_index, db_value in enumerate(xcom_values): + if db_value is None: # We don't put None to XCom. + continue + ti = tis[map_index] + x = XComModel( + key="xcom_1", + value=db_value, + dag_run_id=ti.dag_run.id, + run_id=ti.run_id, + task_id=ti.task_id, + dag_id=ti.dag_id, + map_index=map_index, + ) + session.add(x) + session.commit() + + response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}") + assert response.status_code == expected_status + assert response.json() == expected_json diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 1fd548319e71f..b9d0a4511ea10 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -58,6 +58,8 @@ VariablePostBody, VariableResponse, XComResponse, + XComSequenceIndexResponse, + XComSequenceSliceResponse, ) from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import ( @@ -442,10 +444,9 @@ def get_sequence_item( task_id: str, key: str, offset: int, - ) -> XComResponse | ErrorResponse: - params = {"offset": offset} + ) -> XComSequenceIndexResponse | ErrorResponse: try: - resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params) + resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}") except ServerResponseError as e: if e.response.status_code == HTTPStatus.NOT_FOUND: log.error( @@ -469,7 +470,27 @@ def get_sequence_item( }, ) raise - return XComResponse.model_validate_json(resp.read()) + return XComSequenceIndexResponse.model_validate_json(resp.read()) + + def get_sequence_slice( + self, + dag_id: str, + run_id: str, + 
task_id: str, + key: str, + start: int | None, + stop: int | None, + step: int | None, + ) -> XComSequenceSliceResponse: + params = {} + if start is not None: + params["start"] = start + if stop is not None: + params["stop"] = stop + if step is not None: + params["step"] = step + resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/slice", params=params) + return XComSequenceSliceResponse.model_validate_json(resp.read()) class AssetOperations: diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 3efae80e5b671..f6b1c907ef529 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -25,7 +25,7 @@ from typing import Annotated, Any, Final, Literal from uuid import UUID -from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, RootModel API_VERSION: Final[str] = "2025-05-20" @@ -356,6 +356,30 @@ class XComResponse(BaseModel): value: JsonValue +class XComSequenceIndexResponse(RootModel[JsonValue]): + root: Annotated[ + JsonValue, + Field( + description="XCom schema with minimal structure for index-based access.", + title="XComSequenceIndexResponse", + ), + ] + + +class XComSequenceSliceResponse(RootModel[list[JsonValue]]): + """ + XCom schema with minimal structure for slice-based access. + """ + + root: Annotated[ + list[JsonValue], + Field( + description="XCom schema with minimal structure for slice-based access.", + title="XComSequenceSliceResponse", + ), + ] + + class TaskInstance(BaseModel): """ Schema for TaskInstance model with minimal required fields needed for Runtime. 
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index a25ba5745827d..ecc34852252e0 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -74,6 +74,8 @@ TriggerDAGRunPayload, VariableResponse, XComResponse, + XComSequenceIndexResponse, + XComSequenceSliceResponse, ) from airflow.sdk.exceptions import ErrorType @@ -227,6 +229,24 @@ class XComCountResponse(BaseModel): type: Literal["XComLengthResponse"] = "XComLengthResponse" +class XComSequenceIndexResult(BaseModel): + root: JsonValue + type: Literal["XComSequenceIndexResult"] = "XComSequenceIndexResult" + + @classmethod + def from_response(cls, response: XComSequenceIndexResponse) -> XComSequenceIndexResult: + return cls(root=response.root, type="XComSequenceIndexResult") + + +class XComSequenceSliceResult(BaseModel): + root: list[JsonValue] + type: Literal["XComSequenceSliceResult"] = "XComSequenceSliceResult" + + @classmethod + def from_response(cls, response: XComSequenceSliceResponse) -> XComSequenceSliceResult: + return cls(root=response.root, type="XComSequenceSliceResult") + + class ConnectionResult(ConnectionResponse): type: Literal["ConnectionResult"] = "ConnectionResult" @@ -352,8 +372,10 @@ class OKResponse(BaseModel): TICount, TaskStatesResult, VariableResult, - XComResult, XComCountResponse, + XComResult, + XComSequenceIndexResult, + XComSequenceSliceResult, OKResponse, ], Field(discriminator="type"), @@ -451,6 +473,17 @@ class GetXComSequenceItem(BaseModel): type: Literal["GetXComSequenceItem"] = "GetXComSequenceItem" +class GetXComSequenceSlice(BaseModel): + key: str + dag_id: str + run_id: str + task_id: str + start: int | None + stop: int | None + step: int | None + type: Literal["GetXComSequenceSlice"] = "GetXComSequenceSlice" + + class SetXCom(BaseModel): key: str value: Annotated[ @@ -616,6 +649,7 @@ class GetDRCount(BaseModel): GetXCom, GetXComCount, 
GetXComSequenceItem, + GetXComSequenceSlice, PutVariable, RescheduleTask, RetryTask, diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py index 0fbfcf39498e4..9cf9acfac81bb 100644 --- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py +++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import collections import itertools from collections.abc import Iterator, Sequence from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload @@ -30,6 +31,11 @@ T = TypeVar("T") +# This is used to wrap values from the API so the structure is compatible with +# ``XCom.deserialize_value``. We don't want to wrap the API values in a nested +# {"value": value} dict since it wastes bandwidth. +_XComWrapper = collections.namedtuple("_XComWrapper", "value") + log = structlog.get_logger(logger_name=__name__) @@ -98,7 +104,7 @@ def __len__(self) -> int: if isinstance(msg, ErrorResponse): raise RuntimeError(msg) if not isinstance(msg, XComCountResponse): - raise TypeError(f"Got unexpected response to GetXComCount: {msg}") + raise TypeError(f"Got unexpected response to GetXComCount: {msg!r}") self._len = msg.len return self._len @@ -109,41 +115,42 @@ def __getitem__(self, key: int) -> T: ... def __getitem__(self, key: slice) -> Sequence[T]: ... def __getitem__(self, key: int | slice) -> T | Sequence[T]: - if not isinstance(key, (int, slice)): - raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}") - - if isinstance(key, slice): - raise TypeError("slice is not implemented yet") - # TODO... - # This implements the slicing syntax. We want to optimize negative slicing (e.g. seq[-10:]) by not - # doing an additional COUNT query (via HEAD http request) if possible. We can do this unless the - # start and stop have different signs (i.e. one is positive and another negative). 
- # start, stop, reverse = _coerce_slice(key) - # if start >= 0: - # if stop is None: - # stmt = self._select_asc.offset(start) - # elif stop >= 0: - # stmt = self._select_asc.slice(start, stop) - # else: - # stmt = self._select_asc.slice(start, len(self) + stop) - # rows = [self._process_row(row) for row in self._session.execute(stmt)] - # if reverse: - # rows.reverse() - # else: - # if stop is None: - # stmt = self._select_desc.limit(-start) - # elif stop < 0: - # stmt = self._select_desc.slice(-stop, -start) - # else: - # stmt = self._select_desc.slice(len(self) - stop, -start) - # rows = [self._process_row(row) for row in self._session.execute(stmt)] - # if not reverse: - # rows.reverse() - # return rows - from airflow.sdk.execution_time.comms import GetXComSequenceItem, XComResult + from airflow.sdk.execution_time.comms import ( + ErrorResponse, + GetXComSequenceItem, + GetXComSequenceSlice, + XComSequenceIndexResult, + XComSequenceSliceResult, + ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS from airflow.sdk.execution_time.xcom import XCom + if isinstance(key, slice): + start, stop, step = _coerce_slice(key) + with SUPERVISOR_COMMS.lock: + source = (xcom_arg := self._xcom_arg).operator + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetXComSequenceSlice( + key=xcom_arg.key, + dag_id=source.dag_id, + task_id=source.task_id, + run_id=self._ti.run_id, + start=start, + stop=stop, + step=step, + ), + ) + msg = SUPERVISOR_COMMS.get_message() + if not isinstance(msg, XComSequenceSliceResult): + raise TypeError(f"Got unexpected response to GetXComSequenceSlice: {msg!r}") + return [XCom.deserialize_value(_XComWrapper(value)) for value in msg.root] + + if not isinstance(key, int): + if (index := getattr(key, "__index__", None)) is None: + raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}") + key = index() # key is not an int but implements __index__: use its integer value instead of rejecting it + with SUPERVISOR_COMMS.lock: source = (xcom_arg := self._xcom_arg).operator 
SUPERVISOR_COMMS.send_request( @@ -157,13 +164,14 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: ), ) msg = SUPERVISOR_COMMS.get_message() - - if not isinstance(msg, XComResult): + if isinstance(msg, ErrorResponse): raise IndexError(key) - return XCom.deserialize_value(msg) + if not isinstance(msg, XComSequenceIndexResult): + raise TypeError(f"Got unexpected response to GetXComSequenceItem: {msg!r}") + return XCom.deserialize_value(_XComWrapper(msg.root)) -def _coerce_index(value: Any) -> int | None: +def _coerce_slice_index(value: Any) -> int | None: """ Check slice attribute's type and convert it to int. @@ -177,17 +185,13 @@ def _coerce_index(value: Any) -> int | None: raise TypeError("slice indices must be integers or None or have an __index__ method") -def _coerce_slice(key: slice) -> tuple[int, int | None, bool]: +def _coerce_slice(key: slice) -> tuple[int | None, int | None, int | None]: """ Check slice content and convert it for SQL. See CPython documentation on this: https://docs.python.org/3/reference/datamodel.html#slice-objects """ - if key.step is None or key.step == 1: - reverse = False - elif key.step == -1: - reverse = True - else: - raise ValueError("non-trivial slice step not supported") - return _coerce_index(key.start) or 0, _coerce_index(key.stop), reverse + if (step := _coerce_slice_index(key.step)) == 0: + raise ValueError("slice step cannot be zero") + return _coerce_slice_index(key.start), _coerce_slice_index(key.stop), step diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 1006b861378d0..65d05cc023d51 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -60,7 +60,7 @@ TaskInstanceState, TaskStatesResponse, VariableResponse, - XComResponse, + XComSequenceIndexResponse, ) from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import ( @@ -87,6 
+87,7 @@ GetXCom, GetXComCount, GetXComSequenceItem, + GetXComSequenceSlice, PrevSuccessfulDagRunResult, PutVariable, RescheduleTask, @@ -103,6 +104,8 @@ VariableResult, XComCountResponse, XComResult, + XComSequenceIndexResult, + XComSequenceSliceResult, ) from airflow.sdk.execution_time.secrets_masker import mask_secret @@ -1108,10 +1111,15 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): xcom = self.client.xcoms.get_sequence_item( msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset ) - if isinstance(xcom, XComResponse): - resp = XComResult.from_xcom_response(xcom) + if isinstance(xcom, XComSequenceIndexResponse): + resp = XComSequenceIndexResult.from_response(xcom) else: resp = xcom + elif isinstance(msg, GetXComSequenceSlice): + xcoms = self.client.xcoms.get_sequence_slice( + msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.start, msg.stop, msg.step + ) + resp = XComSequenceSliceResult.from_response(xcoms) elif isinstance(msg, DeferTask): self._terminal_state = TaskInstanceState.DEFERRED self._rendered_map_index = msg.rendered_map_index diff --git a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py index 2430f85f35e5e..e4943196a09da 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py +++ b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py @@ -28,8 +28,10 @@ ErrorResponse, GetXComCount, GetXComSequenceItem, + GetXComSequenceSlice, XComCountResponse, - XComResult, + XComSequenceIndexResult, + XComSequenceSliceResult, ) from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence from airflow.sdk.execution_time.xcom import resolve_xcom_backend @@ -75,7 +77,7 @@ def test_iter(mock_supervisor_comms, lazy_sequence): it = iter(lazy_sequence) mock_supervisor_comms.get_message.side_effect = [ - XComResult(key="return_value", value="f"), + XComSequenceIndexResult(root="f"), ErrorResponse(error=ErrorType.XCOM_NOT_FOUND, 
detail={"oops": "sorry!"}), ] assert list(it) == ["f"] @@ -104,7 +106,7 @@ def test_iter(mock_supervisor_comms, lazy_sequence): def test_getitem_index(mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value="f") + mock_supervisor_comms.get_message.return_value = XComSequenceIndexResult(root="f") assert lazy_sequence[4] == "f" assert mock_supervisor_comms.send_request.mock_calls == [ call( @@ -121,12 +123,12 @@ def test_getitem_index(mock_supervisor_comms, lazy_sequence): @conf_vars({("core", "xcom_backend"): "task_sdk.execution_time.test_lazy_sequence.CustomXCom"}) -def test_getitem_calls_correct_deserialise(mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value="some-value") +def test_getitem_calls_correct_deserialise(monkeypatch, mock_supervisor_comms, lazy_sequence): + mock_supervisor_comms.get_message.return_value = XComSequenceIndexResult(root="some-value") xcom = resolve_xcom_backend() assert xcom.__name__ == "CustomXCom" - airflow.sdk.execution_time.xcom.XCom = xcom + monkeypatch.setattr(airflow.sdk.execution_time.xcom, "XCom", xcom) assert lazy_sequence[4] == "Made with CustomXCom: some-value" assert mock_supervisor_comms.send_request.mock_calls == [ @@ -163,3 +165,22 @@ def test_getitem_indexerror(mock_supervisor_comms, lazy_sequence): ), ), ] + + +def test_getitem_slice(mock_supervisor_comms, lazy_sequence): + mock_supervisor_comms.get_message.return_value = XComSequenceSliceResult(root=[6, 4, 1]) + assert lazy_sequence[:5] == [6, 4, 1] + assert mock_supervisor_comms.send_request.mock_calls == [ + call( + log=ANY, + msg=GetXComSequenceSlice( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + start=None, + stop=5, + step=None, + ), + ), + ] diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 
4696908bae72f..86a2e747e0f0f 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -76,6 +76,7 @@ GetVariable, GetXCom, GetXComSequenceItem, + GetXComSequenceSlice, OKResponse, PrevSuccessfulDagRunResult, PutVariable, @@ -91,6 +92,8 @@ TriggerDagRun, VariableResult, XComResult, + XComSequenceIndexResult, + XComSequenceSliceResult, ) from airflow.sdk.execution_time.supervisor import ( BUFFER_SIZE, @@ -1618,11 +1621,11 @@ def watched_subprocess(self, mocker): task_id="test_task", offset=0, ), - b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', + b'{"root":"test_value","type":"XComSequenceIndexResult"}\n', "xcoms.get_sequence_item", ("test_dag", "test_run", "test_task", "test_key", 0), {}, - XComResult(key="test_key", value="test_value"), + XComSequenceIndexResult(root="test_value"), None, id="get_xcom_seq_item", ), @@ -1642,6 +1645,24 @@ def watched_subprocess(self, mocker): None, id="get_xcom_seq_item_not_found", ), + pytest.param( + GetXComSequenceSlice( + key="test_key", + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + start=None, + stop=None, + step=None, + ), + b'{"root":["foo","bar"],"type":"XComSequenceSliceResult"}\n', + "xcoms.get_sequence_slice", + ("test_dag", "test_run", "test_task", "test_key", None, None, None), + {}, + XComSequenceSliceResult(root=["foo", "bar"]), + None, + id="get_xcom_seq_slice", + ), ], ) def test_handle_requests(