diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py index 53a0582015aab..9f5e7d686a3cb 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -126,6 +126,7 @@ class GetXcomFilterParams(BaseModel): map_index: int = -1 include_prior_dates: bool = False + offset: int | None = None @router.get( @@ -141,18 +142,23 @@ def get_xcom( params: Annotated[GetXcomFilterParams, Query()], ) -> XComResponse: """Get an Airflow XCom from database - not other XCom Backends.""" - # The xcom_query allows no map_index to be passed. This endpoint should always return just a single item, - # so we override that query value xcom_query = XComModel.get_many( run_id=run_id, key=key, task_ids=task_id, dag_ids=dag_id, - map_indexes=params.map_index, include_prior_dates=params.include_prior_dates, session=session, ) - xcom_query = xcom_query.filter(XComModel.map_index == params.map_index) + if params.offset is not None: + xcom_query = xcom_query.filter(XComModel.value.is_not(None)).order_by(None) + if params.offset >= 0: + xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(params.offset) + else: + xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - params.offset) + else: + xcom_query = xcom_query.filter(XComModel.map_index == params.map_index) + # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one` @@ -160,13 +166,19 @@ def get_xcom( # performance hits from retrieving large data files into the API server. 
@pytest.mark.parametrize(
    "offset, expected_status, expected_json",
    [
        pytest.param(
            -4,
            404,
            {
                "detail": {
                    "reason": "not_found",
                    "message": (
                        "XCom with key='xcom_1' offset=-4 not found "
                        "for task 'task' in DAG run 'runid' of 'dag'"
                    ),
                },
            },
            id="-4",
        ),
        # The six in-range offsets: negative offsets count back from the last stored value,
        # non-negative offsets count forward from the first one.
        *(
            pytest.param(position, 200, {"key": "xcom_1", "value": letter}, id=str(position))
            for position, letter in [(-3, "f"), (-2, "o"), (-1, "b"), (0, "f"), (1, "o"), (2, "b")]
        ),
        pytest.param(
            3,
            404,
            {
                "detail": {
                    "reason": "not_found",
                    "message": (
                        "XCom with key='xcom_1' offset=3 not found "
                        "for task 'task' in DAG run 'runid' of 'dag'"
                    ),
                },
            },
            id="3",
        ),
    ],
)
def test_xcom_get_with_offset(
    self,
    client,
    dag_maker,
    session,
    offset,
    expected_status,
    expected_json,
):
    """
    Check the ``offset`` query parameter of the XCom GET endpoint.

    A task is expanded over four inputs, but only three XCom rows are stored
    (map indexes 0, 2 and 3 holding "f", "o" and "b"); no row is written for the
    ``None`` input. The request must return the stored value at the requested
    position, or a 404 with the ``not_found`` payload when the offset is out of range.
    """
    inputs = ["f", None, "o", "b"]

    class MyOperator(EmptyOperator):
        def __init__(self, *, x, **kwargs):
            super().__init__(**kwargs)
            self.x = x

    with dag_maker(dag_id="dag"):
        MyOperator.partial(task_id="task").expand(x=inputs)
    run = dag_maker.create_dagrun(run_id="runid")
    ti_by_map_index = {ti.map_index: ti for ti in run.task_instances}

    for index, stored in enumerate(inputs):
        if stored is not None:  # We don't put None to XCom.
            owner = ti_by_map_index[index]
            session.add(
                XComModel(
                    key="xcom_1",
                    value=stored,
                    dag_run_id=owner.dag_run.id,
                    run_id=owner.run_id,
                    task_id=owner.task_id,
                    dag_id=owner.dag_id,
                    map_index=index,
                )
            )
    session.commit()

    response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
    assert response.status_code == expected_status
    assert response.json() == expected_json
def get_sequence_item(
    self,
    dag_id: str,
    run_id: str,
    task_id: str,
    key: str,
    offset: int,
) -> XComResponse | ErrorResponse:
    """
    Request a single XCom value by position via the ``offset`` query parameter.

    :param dag_id: DAG that owns the source task.
    :param run_id: DAG run to read from.
    :param task_id: Task whose XCom is requested.
    :param key: XCom key.
    :param offset: Sent to the API server as the ``offset`` query parameter.
    :return: The parsed ``XComResponse`` on success, or an ``ErrorResponse`` with
        ``ErrorType.XCOM_NOT_FOUND`` when the server answers 404.
    :raises ServerResponseError: For any server error other than 404.
    """
    # One mapping serves both as the structured log fields and as the error detail.
    identity = {
        "dag_id": dag_id,
        "run_id": run_id,
        "task_id": task_id,
        "key": key,
        "offset": offset,
    }
    try:
        resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params={"offset": offset})
    except ServerResponseError as exc:
        # Only "not found" is translated into a response object; everything else propagates.
        if exc.response.status_code != HTTPStatus.NOT_FOUND:
            raise
        log.error(
            "XCom not found",
            **identity,
            detail=exc.detail,
            status_code=exc.response.status_code,
        )
        return ErrorResponse(error=ErrorType.XCOM_NOT_FOUND, detail=identity)
    return XComResponse.model_validate_json(resp.read())
def __getitem__(self, key: int | slice) -> T | Sequence[T]:
    """
    Return the XCom value at position *key* by asking the supervisor for it.

    A ``GetXComSequenceItem`` request carrying *key* as its ``offset`` is sent over
    ``SUPERVISOR_COMMS`` and the reply is deserialized with ``BaseXCom.deserialize_value``.

    :param key: Integer position to fetch. Slices are accepted by the signature
        but are not implemented yet.
    :return: The deserialized XCom value.
    :raises TypeError: If *key* is neither an ``int`` nor a ``slice``, if *key* is a
        slice (not implemented yet), or if the supervisor replies with a message
        that is neither an ``XComResult`` nor an ``ErrorResponse``.
    :raises IndexError: If the supervisor replies with an ``ErrorResponse``,
        i.e. no item exists at that position.
    """
    if not isinstance(key, (int, slice)):
        raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}")

    if isinstance(key, slice):
        # TODO: Implement slicing. The intent is to optimize negative slicing (e.g. seq[-10:])
        # by not doing an additional COUNT query (via HEAD http request) where possible, which
        # works unless start and stop have different signs.
        raise TypeError("slice is not implemented yet")

    # Function-scope imports are kept from the original implementation
    # (presumably to avoid import cycles at module load -- confirm before hoisting).
    from airflow.sdk.bases.xcom import BaseXCom
    from airflow.sdk.execution_time.comms import ErrorResponse, GetXComSequenceItem, XComResult
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    xcom_arg = self._xcom_arg
    source = xcom_arg.operator
    request = GetXComSequenceItem(
        key=xcom_arg.key,
        dag_id=source.dag_id,
        task_id=source.task_id,
        run_id=self._ti.run_id,
        offset=key,
    )

    # Hold the lock across send + receive so the reply read here belongs to this request.
    with SUPERVISOR_COMMS.lock:
        SUPERVISOR_COMMS.send_request(log=log, msg=request)
        msg = SUPERVISOR_COMMS.get_message()

    if isinstance(msg, ErrorResponse):
        # IndexError is what the sequence protocol (and our iterator) uses as "no such item".
        raise IndexError(key)
    if not isinstance(msg, XComResult):
        # Anything else is a protocol mismatch. Reporting it as IndexError would make
        # iteration stop silently instead of surfacing the problem.
        raise TypeError(f"Got unexpected response to GetXComSequenceItem: {msg!r}")
    return BaseXCom.deserialize_value(msg)
def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): elif isinstance(msg, GetXComCount): len = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key) resp = XComCountResponse(len=len) + elif isinstance(msg, GetXComSequenceItem): + xcom = self.client.xcoms.get_sequence_item( + msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset + ) + if isinstance(xcom, XComResponse): + resp = XComResult.from_xcom_response(xcom) + else: + resp = xcom elif isinstance(msg, DeferTask): self._terminal_state = IntermediateTIState.DEFERRED self.client.task_instances.defer(self.id, msg) diff --git a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py new file mode 100644 index 0000000000000..a42572e5df1fa --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
def _item_request(offset):
    """Build the ``send_request`` call expected when the item at *offset* is fetched."""
    return call(
        log=ANY,
        msg=GetXComSequenceItem(
            key="return_value",
            dag_id="dag",
            task_id="task",
            run_id="run",
            offset=offset,
        ),
    )


@pytest.fixture
def mock_operator():
    """Stand-in operator exposing only ``dag_id`` and ``task_id``."""
    operator = Mock(spec=["dag_id", "task_id"], dag_id="dag", task_id="task")
    return operator


@pytest.fixture
def mock_xcom_arg(mock_operator):
    """Stand-in XComArg pointing at ``mock_operator`` with the ``return_value`` key."""
    xcom_arg = Mock(spec=["operator", "key"], operator=mock_operator, key="return_value")
    return xcom_arg


@pytest.fixture
def mock_ti():
    """Stand-in task instance exposing only ``run_id``."""
    ti = Mock(spec=["run_id"], run_id="run")
    return ti


@pytest.fixture
def lazy_sequence(mock_xcom_arg, mock_ti):
    """Sequence under test, built from the mocked XComArg and task instance."""
    return LazyXComSequence(mock_xcom_arg, mock_ti)


def test_len(mock_supervisor_comms, lazy_sequence):
    """``len()`` sends one ``GetXComCount`` request and returns the reported length."""
    mock_supervisor_comms.get_message.return_value = XComCountResponse(len=3)
    assert len(lazy_sequence) == 3

    count_request = GetXComCount(key="return_value", dag_id="dag", task_id="task", run_id="run")
    assert mock_supervisor_comms.send_request.mock_calls == [call(log=ANY, msg=count_request)]


def test_iter(mock_supervisor_comms, lazy_sequence):
    """Iteration yields values until the supervisor answers with an error response."""
    it = iter(lazy_sequence)

    not_found = ErrorResponse(error=ErrorType.XCOM_NOT_FOUND, detail={"oops": "sorry!"})
    mock_supervisor_comms.get_message.side_effect = [
        XComResult(key="return_value", value="f"),
        not_found,
    ]
    assert list(it) == ["f"]
    assert mock_supervisor_comms.send_request.mock_calls == [_item_request(0), _item_request(1)]


def test_getitem_index(mock_supervisor_comms, lazy_sequence):
    """Indexing sends one item request carrying the index as its offset."""
    mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value="f")
    assert lazy_sequence[4] == "f"
    assert mock_supervisor_comms.send_request.mock_calls == [_item_request(4)]


def test_getitem_indexerror(mock_supervisor_comms, lazy_sequence):
    """An error response for the requested offset surfaces as ``IndexError(index)``."""
    mock_supervisor_comms.get_message.return_value = ErrorResponse(
        error=ErrorType.XCOM_NOT_FOUND,
        detail={"oops": "sorry!"},
    )
    with pytest.raises(IndexError) as ctx:
        lazy_sequence[4]
    assert ctx.value.args == (4,)
    assert mock_supervisor_comms.send_request.mock_calls == [_item_request(4)]
ErrorResponse(error=ErrorType.XCOM_NOT_FOUND), + id="get_xcom_seq_item_not_found", + ), ], ) def test_handle_requests(