From 054c93a667c5e0e07042b868c413d57d57c2cd6c Mon Sep 17 00:00:00 2001 From: Karthikeyan Singaravelan Date: Sat, 26 Apr 2025 00:27:31 +0530 Subject: [PATCH 1/4] Implement offset to get the xcom for a given task by offset. --- .../api_fastapi/execution_api/routes/xcoms.py | 9 +++- task-sdk/src/airflow/sdk/api/client.py | 3 ++ task-sdk/src/airflow/sdk/bases/xcom.py | 9 ++-- .../src/airflow/sdk/execution_time/comms.py | 1 + .../sdk/execution_time/lazy_sequence.py | 4 +- .../airflow/sdk/execution_time/supervisor.py | 8 ++- .../airflow/sdk/execution_time/task_runner.py | 51 ++++++++++++------- 7 files changed, 57 insertions(+), 28 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py index 53a0582015aab..cf119d789494a 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -126,6 +126,7 @@ class GetXcomFilterParams(BaseModel): map_index: int = -1 include_prior_dates: bool = False + offset: int | None = None @router.get( @@ -148,11 +149,15 @@ def get_xcom( key=key, task_ids=task_id, dag_ids=dag_id, - map_indexes=params.map_index, + map_indexes=None if params.offset is not None else params.map_index, include_prior_dates=params.include_prior_dates, session=session, ) - xcom_query = xcom_query.filter(XComModel.map_index == params.map_index) + + if params.offset is not None: + xcom_query = xcom_query.offset(params.offset) + else: + xcom_query = xcom_query.filter(XComModel.map_index == params.map_index) # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one` diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index c64f721b9ae61..ad3ad879aafe9 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -357,6 +357,7 @@ def get( key: str, map_index: int | None = None, include_prior_dates: bool = False, + offset: int | None = None, ) -> XComResponse: """Get a XCom value from the API server.""" # TODO: check if we need to use map_index as params in the uri @@ -364,6 +365,8 @@ def get( params = {} if map_index is not None and map_index >= 0: params.update({"map_index": map_index}) + if offset is not None: + params.update({"offset": offset}) if include_prior_dates: params.update({"include_prior_dates": include_prior_dates}) try: diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index b8b301e1511d9..5154bcaac0753 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -189,11 +189,7 @@ def _get_xcom_db_ref( SUPERVISOR_COMMS.send_request( log=log, msg=GetXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, + key=key, dag_id=dag_id, task_id=task_id, run_id=run_id, map_index=map_index, offset=None ), ) @@ -212,6 +208,7 @@ def get_one( task_id: str, run_id: str, map_index: int | None = None, + offset: int | None = None, include_prior_dates: bool = False, ) -> Any | None: """ @@ -245,6 +242,7 @@ def get_one( # we need to make sure that we "atomically" send a request and get the response to that # back so that two triggers don't end up interleaving requests and create a possible # race condition where the wrong trigger reads the response. + print(f"{offset = }") with SUPERVISOR_COMMS.lock: SUPERVISOR_COMMS.send_request( log=log, @@ -255,6 +253,7 @@ def get_one( run_id=run_id, map_index=map_index, include_prior_dates=include_prior_dates, + offset=offset, ), ) msg = SUPERVISOR_COMMS.get_message() diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index b4d68086b0c4d..42fd39d789a8b 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -427,6 +427,7 @@ class GetXCom(BaseModel): run_id: str task_id: str map_index: int | None = None + offset: int | None = None include_prior_dates: bool = False type: Literal["GetXCom"] = "GetXCom" diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py index 095ab051fb1c4..911458b495f5b 100644 --- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py +++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py @@ -151,9 +151,7 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: def _get_item(self, index: int) -> T: # TODO: maybe we need to call SUPERVISOR_COMMS manually so we can handle not found here? return self._ti.xcom_pull( - task_ids=self._xcom_arg.operator.task_id, - key=self._xcom_arg.key, - map_indexes=index, + task_ids=self._xcom_arg.operator.task_id, key=self._xcom_arg.key, offset=index ) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index c90d6ea5a0241..9689cb37e5b14 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1027,7 +1027,13 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): resp = var elif isinstance(msg, GetXCom): xcom = self.client.xcoms.get( - msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates + msg.dag_id, + msg.run_id, + msg.task_id, + msg.key, + msg.map_index, + msg.include_prior_dates, + msg.offset, ) xcom_result = XComResult.from_xcom_response(xcom) resp = xcom_result diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 00c09528c19a1..9e530bd1604ca 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -277,6 +277,7 @@ def xcom_pull( include_prior_dates: bool = False, # TODO: Add support for this *, map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET, + offset: int | None = None, default: Any = None, run_id: str | None = None, ) -> Any: @@ -348,23 +349,39 @@ def xcom_pull( ) xcoms = [] - # TODO: AIP 72 Execution API only allows working with a single map_index at a time - # this is inefficient and leads to task_id * map_index requests to the API. - # And we can't achieve the original behavior of XCom pull with multiple tasks - # directly now. - for t_id, m_idx in product(task_ids, map_indexes_iterable): - value = XCom.get_one( - run_id=run_id, - key=key, - task_id=t_id, - dag_id=dag_id, - map_index=m_idx, - include_prior_dates=include_prior_dates, - ) - if value is None: - xcoms.append(default) - else: - xcoms.append(value) + + if offset is not None: + for t_id, offset_idx in product(task_ids, [offset]): + value = XCom.get_one( + run_id=run_id, + key=key, + task_id=t_id, + dag_id=dag_id, + offset=offset_idx, + include_prior_dates=include_prior_dates, + ) + if value is None: + xcoms.append(default) + else: + xcoms.append(value) + else: + # TODO: AIP 72 Execution API only allows working with a single map_index at a time + # this is inefficient and leads to task_id * map_index requests to the API. + # And we can't achieve the original behavior of XCom pull with multiple tasks + # directly now. + for t_id, m_idx in product(task_ids, map_indexes_iterable): + value = XCom.get_one( + run_id=run_id, + key=key, + task_id=t_id, + dag_id=dag_id, + map_index=m_idx, + include_prior_dates=include_prior_dates, + ) + if value is None: + xcoms.append(default) + else: + xcoms.append(value) if single_task_requested and single_map_index_requested: return xcoms[0] From 5df6a79d50d0c4f0ad1d33345f2fa775fe4fb2dd Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 30 Apr 2025 16:58:55 +0800 Subject: [PATCH 2/4] Call API directly in lazy sequence getitem This allows us to better handle errors, and also avoids overloading XCom interfaces with different use cases. --- .../api_fastapi/execution_api/routes/xcoms.py | 5 +- airflow-core/src/airflow/models/xcom.py | 10 +- task-sdk/src/airflow/sdk/api/client.py | 39 ++++++- task-sdk/src/airflow/sdk/bases/xcom.py | 9 +- .../src/airflow/sdk/execution_time/comms.py | 11 +- .../sdk/execution_time/lazy_sequence.py | 101 ++++++++++-------- .../airflow/sdk/execution_time/supervisor.py | 18 ++-- .../airflow/sdk/execution_time/task_runner.py | 51 +++------ 8 files changed, 145 insertions(+), 99 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py index cf119d789494a..cabf3c6114706 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -151,13 +151,14 @@ def get_xcom( dag_ids=dag_id, map_indexes=None if params.offset is not None else params.map_index, include_prior_dates=params.include_prior_dates, + latest_first=(params.offset is None or params.offset >= 0), session=session, ) - if params.offset is not None: - xcom_query = xcom_query.offset(params.offset) + xcom_query = xcom_query.offset(abs(params.offset)) else: xcom_query = xcom_query.filter(XComModel.map_index == params.map_index) + # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one` diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py index f828d9b9343c5..dbee6332a9991 100644 --- a/airflow-core/src/airflow/models/xcom.py +++ b/airflow-core/src/airflow/models/xcom.py @@ -252,6 +252,7 @@ def get_many( map_indexes: int | Iterable[int] | None = None, include_prior_dates: bool = False, limit: int | None = None, + latest_first: bool = True, session: Session = NEW_SESSION, ) -> Query: """ @@ -274,7 +275,9 @@ def get_many( returned regardless of the run it belongs to. :param session: Database session. If not given, a new session will be created for this function. - :param limit: Limiting returning XComs + :param limit: Limiting returning XComs. + :param latest_first: If *True* (default), returning XComs are ordered + latest-first. Otherwise earlier XComs are returned first. """ from airflow.models.dagrun import DagRun @@ -318,7 +321,10 @@ def get_many( else: query = query.filter(cls.run_id == run_id) - query = query.order_by(DagRun.logical_date.desc(), cls.timestamp.desc()) + if latest_first: + query = query.order_by(DagRun.logical_date.desc(), cls.timestamp.desc()) + else: + query = query.order_by(DagRun.logical_date.asc(), cls.timestamp.asc()) if limit: return query.limit(limit) return query diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index ad3ad879aafe9..399954ddfcbad 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -357,7 +357,6 @@ def get( key: str, map_index: int | None = None, include_prior_dates: bool = False, - offset: int | None = None, ) -> XComResponse: """Get a XCom value from the API server.""" # TODO: check if we need to use map_index as params in the uri @@ -365,8 +364,6 @@ def get( params = {} if map_index is not None and map_index >= 0: params.update({"map_index": map_index}) - if offset is not None: - params.update({"offset": offset}) if include_prior_dates: params.update({"include_prior_dates": include_prior_dates}) try: @@ -432,6 +429,42 @@ def delete( # decouple from the server response string return OKResponse(ok=True) + def get_sequence_item( + self, + dag_id: str, + run_id: str, + task_id: str, + key: str, + offset: int, + ) -> XComResponse | ErrorResponse: + params = {"offset": offset} + try: + resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params) + except ServerResponseError as e: + if e.response.status_code == HTTPStatus.NOT_FOUND: + log.error( + "XCom not found", + dag_id=dag_id, + run_id=run_id, + task_id=task_id, + key=key, + offset=offset, + detail=e.detail, + status_code=e.response.status_code, + ) + return ErrorResponse( + error=ErrorType.XCOM_NOT_FOUND, + detail={ + "dag_id": dag_id, + "run_id": run_id, + "task_id": task_id, + "key": key, + "offset": offset, + }, + ) + raise + return XComResponse.model_validate_json(resp.read()) + class AssetOperations: __slots__ = ("client",) diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index 5154bcaac0753..b8b301e1511d9 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -189,7 +189,11 @@ def _get_xcom_db_ref( SUPERVISOR_COMMS.send_request( log=log, msg=GetXCom( - key=key, dag_id=dag_id, task_id=task_id, run_id=run_id, map_index=map_index, offset=None + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, ), ) @@ -208,7 +212,6 @@ def get_one( task_id: str, run_id: str, map_index: int | None = None, - offset: int | None = None, include_prior_dates: bool = False, ) -> Any | None: """ @@ -242,7 +245,6 @@ def get_one( # we need to make sure that we "atomically" send a request and get the response to that # back so that two triggers don't end up interleaving requests and create a possible # race condition where the wrong trigger reads the response. - print(f"{offset = }") with SUPERVISOR_COMMS.lock: SUPERVISOR_COMMS.send_request( log=log, @@ -253,7 +255,6 @@ def get_one( run_id=run_id, map_index=map_index, include_prior_dates=include_prior_dates, - offset=offset, ), ) msg = SUPERVISOR_COMMS.get_message() diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 42fd39d789a8b..039e2a5409ca3 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -427,7 +427,6 @@ class GetXCom(BaseModel): run_id: str task_id: str map_index: int | None = None - offset: int | None = None include_prior_dates: bool = False type: Literal["GetXCom"] = "GetXCom" @@ -442,6 +441,15 @@ class GetXComCount(BaseModel): type: Literal["GetNumberXComs"] = "GetNumberXComs" +class GetXComSequenceItem(BaseModel): + key: str + dag_id: str + run_id: str + task_id: str + offset: int + type: Literal["GetXComSequenceItem"] = "GetXComSequenceItem" + + class SetXCom(BaseModel): key: str value: Annotated[ @@ -606,6 +614,7 @@ class GetDRCount(BaseModel): GetVariable, GetXCom, GetXComCount, + GetXComSequenceItem, PutVariable, RescheduleTask, RetryTask, diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py index 911458b495f5b..79822787f3881 100644 --- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py +++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py @@ -43,10 +43,10 @@ def __next__(self) -> T: if self.index < 0: # When iterating backwards, avoid extra HTTP request raise StopIteration() - val = self.seq._get_item(self.index) - if val is None: - # None isn't the best signal (it's bad in fact) but it's the best we can do until https://github.com/apache/airflow/issues/46426 - raise StopIteration() + try: + val = self.seq[self.index] + except IndexError: + raise StopIteration from None self.index += self.dir return val @@ -109,50 +109,59 @@ def __getitem__(self, key: int) -> T: ... def __getitem__(self, key: slice) -> Sequence[T]: ... def __getitem__(self, key: int | slice) -> T | Sequence[T]: - if isinstance(key, int): - if key >= 0: - return self._get_item(key) - # val[-1] etc. - return self._get_item(len(self) + key) + if not isinstance(key, (int, slice)): + raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}") if isinstance(key, slice): - # This implements the slicing syntax. We want to optimize negative slicing (e.g. seq[-10:]) by not - # doing an additional COUNT query (via HEAD http request) if possible. We can do this unless the - # start and stop have different signs (i.e. one is positive and another negative). - ... - """ - Todo? - elif isinstance(key, slice): - start, stop, reverse = _coerce_slice(key) - if start >= 0: - if stop is None: - stmt = self._select_asc.offset(start) - elif stop >= 0: - stmt = self._select_asc.slice(start, stop) - else: - stmt = self._select_asc.slice(start, len(self) + stop) - rows = [self._process_row(row) for row in self._session.execute(stmt)] - if reverse: - rows.reverse() - else: - if stop is None: - stmt = self._select_desc.limit(-start) - elif stop < 0: - stmt = self._select_desc.slice(-stop, -start) - else: - stmt = self._select_desc.slice(len(self) - stop, -start) - rows = [self._process_row(row) for row in self._session.execute(stmt)] - if not reverse: - rows.reverse() - return rows - """ - raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}") - - def _get_item(self, index: int) -> T: - # TODO: maybe we need to call SUPERVISOR_COMMS manually so we can handle not found here? - return self._ti.xcom_pull( - task_ids=self._xcom_arg.operator.task_id, key=self._xcom_arg.key, offset=index - ) + raise TypeError("slice is not implemented yet") + # TODO... + # This implements the slicing syntax. We want to optimize negative slicing (e.g. seq[-10:]) by not + # doing an additional COUNT query (via HEAD http request) if possible. We can do this unless the + # start and stop have different signs (i.e. one is positive and another negative). + # start, stop, reverse = _coerce_slice(key) + # if start >= 0: + # if stop is None: + # stmt = self._select_asc.offset(start) + # elif stop >= 0: + # stmt = self._select_asc.slice(start, stop) + # else: + # stmt = self._select_asc.slice(start, len(self) + stop) + # rows = [self._process_row(row) for row in self._session.execute(stmt)] + # if reverse: + # rows.reverse() + # else: + # if stop is None: + # stmt = self._select_desc.limit(-start) + # elif stop < 0: + # stmt = self._select_desc.slice(-stop, -start) + # else: + # stmt = self._select_desc.slice(len(self) - stop, -start) + # rows = [self._process_row(row) for row in self._session.execute(stmt)] + # if not reverse: + # rows.reverse() + # return rows + + from airflow.sdk.bases.xcom import BaseXCom + from airflow.sdk.execution_time.comms import GetXComSequenceItem, XComResult + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + with SUPERVISOR_COMMS.lock: + source = (xcom_arg := self._xcom_arg).operator + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetXComSequenceItem( + key=xcom_arg.key, + dag_id=source.dag_id, + task_id=source.task_id, + run_id=self._ti.run_id, + offset=key, + ), + ) + msg = SUPERVISOR_COMMS.get_message() + + if not isinstance(msg, XComResult): + raise IndexError(key) + return BaseXCom.deserialize_value(msg) def _coerce_index(value: Any) -> int | None: diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 9689cb37e5b14..ff804f3f76962 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -59,6 +59,7 @@ TaskStatesResponse, TerminalTIState, VariableResponse, + XComResponse, ) from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import ( @@ -84,6 +85,7 @@ GetVariable, GetXCom, GetXComCount, + GetXComSequenceItem, PrevSuccessfulDagRunResult, PutVariable, RescheduleTask, @@ -1027,19 +1029,21 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): resp = var elif isinstance(msg, GetXCom): xcom = self.client.xcoms.get( - msg.dag_id, - msg.run_id, - msg.task_id, - msg.key, - msg.map_index, - msg.include_prior_dates, - msg.offset, + msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates ) xcom_result = XComResult.from_xcom_response(xcom) resp = xcom_result elif isinstance(msg, GetXComCount): len = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key) resp = XComCountResponse(len=len) + elif isinstance(msg, GetXComSequenceItem): + xcom = self.client.xcoms.get_sequence_item( + msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset + ) + if isinstance(xcom, XComResponse): + resp = XComResult.from_xcom_response(xcom) + else: + resp = xcom elif isinstance(msg, DeferTask): self._terminal_state = IntermediateTIState.DEFERRED self.client.task_instances.defer(self.id, msg) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 9e530bd1604ca..00c09528c19a1 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -277,7 +277,6 @@ def xcom_pull( include_prior_dates: bool = False, # TODO: Add support for this *, map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET, - offset: int | None = None, default: Any = None, run_id: str | None = None, ) -> Any: @@ -349,39 +348,23 @@ def xcom_pull( ) xcoms = [] - - if offset is not None: - for t_id, offset_idx in product(task_ids, [offset]): - value = XCom.get_one( - run_id=run_id, - key=key, - task_id=t_id, - dag_id=dag_id, - offset=offset_idx, - include_prior_dates=include_prior_dates, - ) - if value is None: - xcoms.append(default) - else: - xcoms.append(value) - else: - # TODO: AIP 72 Execution API only allows working with a single map_index at a time - # this is inefficient and leads to task_id * map_index requests to the API. - # And we can't achieve the original behavior of XCom pull with multiple tasks - # directly now. - for t_id, m_idx in product(task_ids, map_indexes_iterable): - value = XCom.get_one( - run_id=run_id, - key=key, - task_id=t_id, - dag_id=dag_id, - map_index=m_idx, - include_prior_dates=include_prior_dates, - ) - if value is None: - xcoms.append(default) - else: - xcoms.append(value) + # TODO: AIP 72 Execution API only allows working with a single map_index at a time + # this is inefficient and leads to task_id * map_index requests to the API. + # And we can't achieve the original behavior of XCom pull with multiple tasks + # directly now. + for t_id, m_idx in product(task_ids, map_indexes_iterable): + value = XCom.get_one( + run_id=run_id, + key=key, + task_id=t_id, + dag_id=dag_id, + map_index=m_idx, + include_prior_dates=include_prior_dates, + ) + if value is None: + xcoms.append(default) + else: + xcoms.append(value) if single_task_requested and single_map_index_requested: return xcoms[0] From b2a9f240504db0a0b23d2950dd19e5176d003e34 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 30 Apr 2025 18:07:38 +0800 Subject: [PATCH 3/4] Fix test mock using relevant interface --- .../api_fastapi/execution_api/routes/xcoms.py | 26 ++-- airflow-core/src/airflow/models/xcom.py | 10 +- .../execution_api/versions/head/test_xcoms.py | 81 +++++++++++ .../tests/unit/models/test_mappedoperator.py | 12 +- .../src/airflow/sdk/definitions/xcom_arg.py | 2 +- .../execution_time/test_lazy_sequence.py | 131 ++++++++++++++++++ 6 files changed, 241 insertions(+), 21 deletions(-) create mode 100644 task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py index cabf3c6114706..9f5e7d686a3cb 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -142,20 +142,20 @@ def get_xcom( params: Annotated[GetXcomFilterParams, Query()], ) -> XComResponse: """Get an Airflow XCom from database - not other XCom Backends.""" - # The xcom_query allows no map_index to be passed. This endpoint should always return just a single item, - # so we override that query value xcom_query = XComModel.get_many( run_id=run_id, key=key, task_ids=task_id, dag_ids=dag_id, - map_indexes=None if params.offset is not None else params.map_index, include_prior_dates=params.include_prior_dates, - latest_first=(params.offset is None or params.offset >= 0), session=session, ) if params.offset is not None: - xcom_query = xcom_query.offset(abs(params.offset)) + xcom_query = xcom_query.filter(XComModel.value.is_not(None)).order_by(None) + if params.offset >= 0: + xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(params.offset) + else: + xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - params.offset) else: xcom_query = xcom_query.filter(XComModel.map_index == params.map_index) @@ -166,13 +166,19 @@ def get_xcom( # performance hits from retrieving large data files into the API server. result = xcom_query.limit(1).first() if result is None: - map_index = params.map_index + if params.offset is None: + message = ( + f"XCom with {key=} map_index={params.map_index} not found for " + f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}" + ) + else: + message = ( + f"XCom with {key=} offset={params.offset} not found for " + f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}" + ) raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail={ - "reason": "not_found", - "message": f"XCom with {key=} {map_index=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}", - }, + detail={"reason": "not_found", "message": message}, ) return XComResponse(key=key, value=result.value) diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py index dbee6332a9991..f828d9b9343c5 100644 --- a/airflow-core/src/airflow/models/xcom.py +++ b/airflow-core/src/airflow/models/xcom.py @@ -252,7 +252,6 @@ def get_many( map_indexes: int | Iterable[int] | None = None, include_prior_dates: bool = False, limit: int | None = None, - latest_first: bool = True, session: Session = NEW_SESSION, ) -> Query: """ @@ -275,9 +274,7 @@ def get_many( returned regardless of the run it belongs to. :param session: Database session. If not given, a new session will be created for this function. - :param limit: Limiting returning XComs. - :param latest_first: If *True* (default), returning XComs are ordered - latest-first. Otherwise earlier XComs are returned first. + :param limit: Limiting returning XComs """ from airflow.models.dagrun import DagRun @@ -321,10 +318,7 @@ def get_many( else: query = query.filter(cls.run_id == run_id) - if latest_first: - query = query.order_by(DagRun.logical_date.desc(), cls.timestamp.desc()) - else: - query = query.order_by(DagRun.logical_date.asc(), cls.timestamp.asc()) + query = query.order_by(DagRun.logical_date.desc(), cls.timestamp.desc()) if limit: return query.limit(limit) return query diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py index c2b49841b3ab8..951fbf5cbae0a 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py @@ -29,6 +29,7 @@ from airflow.models.dagrun import DagRun from airflow.models.taskmap import TaskMap from airflow.models.xcom import XComModel +from airflow.providers.standard.operators.empty import EmptyOperator from airflow.serialization.serde import deserialize, serialize from airflow.utils.session import create_session @@ -130,6 +131,86 @@ def test_xcom_access_denied(self, client, caplog): } assert any(msg.startswith("Checking read XCom access") for msg in caplog.messages) + @pytest.mark.parametrize( + "offset, expected_status, expected_json", + [ + pytest.param( + -4, + 404, + { + "detail": { + "reason": "not_found", + "message": ( + "XCom with key='xcom_1' offset=-4 not found " + "for task 'task' in DAG run 'runid' of 'dag'" + ), + }, + }, + id="-4", + ), + pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"), + pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"), + pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"), + pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"), + pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"), + pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"), + pytest.param( + 3, + 404, + { + "detail": { + "reason": "not_found", + "message": ( + "XCom with key='xcom_1' offset=3 not found " + "for task 'task' in DAG run 'runid' of 'dag'" + ), + }, + }, + id="3", + ), + ], + ) + def test_xcom_get_with_offset( + self, + client, + dag_maker, + session, + offset, + expected_status, + expected_json, + ): + xcom_values = ["f", None, "o", "b"] + + class MyOperator(EmptyOperator): + def __init__(self, *, x, **kwargs): + super().__init__(**kwargs) + self.x = x + + with dag_maker(dag_id="dag"): + MyOperator.partial(task_id="task").expand(x=xcom_values) + dag_run = dag_maker.create_dagrun(run_id="runid") + tis = {ti.map_index: ti for ti in dag_run.task_instances} + + for map_index, db_value in enumerate(xcom_values): + if db_value is None: # We don't put None to XCom. + continue + ti = tis[map_index] + x = XComModel( + key="xcom_1", + value=db_value, + dag_run_id=ti.dag_run.id, + run_id=ti.run_id, + task_id=ti.task_id, + dag_id=ti.dag_id, + map_index=map_index, + ) + session.add(x) + session.commit() + + response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}") + assert response.status_code == expected_status + assert response.json() == expected_json + class TestXComsSetEndpoint: @pytest.mark.parametrize( diff --git a/airflow-core/tests/unit/models/test_mappedoperator.py b/airflow-core/tests/unit/models/test_mappedoperator.py index 7f2f3770bb4d6..dd35f9461dc9c 100644 --- a/airflow-core/tests/unit/models/test_mappedoperator.py +++ b/airflow-core/tests/unit/models/test_mappedoperator.py @@ -32,7 +32,7 @@ from airflow.models.taskmap import TaskMap from airflow.providers.standard.operators.python import PythonOperator from airflow.sdk import setup, task, task_group, teardown -from airflow.sdk.execution_time.comms import XComCountResponse +from airflow.sdk.execution_time.comms import XComCountResponse, XComResult from airflow.utils.state import TaskInstanceState from airflow.utils.task_group import TaskGroup from airflow.utils.trigger_rule import TriggerRule @@ -1270,8 +1270,16 @@ def my_teardown(val): ) as supervisor_comms: # TODO: TaskSDK: this is a bit of a hack that we need to stub this at all. `dag.test()` should # really work without this! - supervisor_comms.get_message.return_value = XComCountResponse(len=3) + supervisor_comms.get_message.side_effect = [ + XComCountResponse(len=3), + XComResult(key="return_value", value=1), + XComCountResponse(len=3), + XComResult(key="return_value", value=2), + XComCountResponse(len=3), + XComResult(key="return_value", value=3), + ] dr = dag.test() + assert supervisor_comms.get_message.call_count == 6 states = self.get_states(dr) expected = { "tg_1.my_pre_setup": "success", diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index ab3a5a82b22c1..1adcb7efaa7a6 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -337,7 +337,7 @@ def resolve(self, context: Mapping[str, Any]) -> Any: task_id = self.operator.task_id if self.operator.is_mapped: - return LazyXComSequence[Any](xcom_arg=self, ti=ti) + return LazyXComSequence(xcom_arg=self, ti=ti) tg = self.operator.get_closest_mapped_task_group() result = None if tg is None: diff --git a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py new file mode 100644 index 0000000000000..a42572e5df1fa --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from unittest.mock import ANY, Mock, call + +import pytest + +from airflow.sdk.exceptions import ErrorType +from airflow.sdk.execution_time.comms import ( + ErrorResponse, + GetXComCount, + GetXComSequenceItem, + XComCountResponse, + XComResult, +) +from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence + + +@pytest.fixture +def mock_operator(): + return Mock(spec=["dag_id", "task_id"], dag_id="dag", task_id="task") + + +@pytest.fixture +def mock_xcom_arg(mock_operator): + return Mock(spec=["operator", "key"], operator=mock_operator, key="return_value") + + +@pytest.fixture +def mock_ti(): + return Mock(spec=["run_id"], run_id="run") + + +@pytest.fixture +def lazy_sequence(mock_xcom_arg, mock_ti): + return LazyXComSequence(mock_xcom_arg, mock_ti) + + +def test_len(mock_supervisor_comms, lazy_sequence): + mock_supervisor_comms.get_message.return_value = XComCountResponse(len=3) + assert len(lazy_sequence) == 3 + assert mock_supervisor_comms.send_request.mock_calls == [ + call(log=ANY, msg=GetXComCount(key="return_value", dag_id="dag", task_id="task", run_id="run")), + ] + + +def test_iter(mock_supervisor_comms, lazy_sequence): + it = iter(lazy_sequence) + + mock_supervisor_comms.get_message.side_effect = [ + XComResult(key="return_value", value="f"), + ErrorResponse(error=ErrorType.XCOM_NOT_FOUND, detail={"oops": "sorry!"}), + ] + assert list(it) == ["f"] + assert mock_supervisor_comms.send_request.mock_calls == [ + call( + log=ANY, + msg=GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=0, + ), + ), + call( + log=ANY, + msg=GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=1, + ), + ), + ] + + +def test_getitem_index(mock_supervisor_comms, lazy_sequence): + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value="f") + assert lazy_sequence[4] == "f" + assert mock_supervisor_comms.send_request.mock_calls == [ + call( + log=ANY, + msg=GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=4, + ), + ), + ] + + +def test_getitem_indexerror(mock_supervisor_comms, lazy_sequence): + mock_supervisor_comms.get_message.return_value = ErrorResponse( + error=ErrorType.XCOM_NOT_FOUND, + detail={"oops": "sorry!"}, + ) + with pytest.raises(IndexError) as ctx: + lazy_sequence[4] + assert ctx.value.args == (4,) + assert mock_supervisor_comms.send_request.mock_calls == [ + call( + log=ANY, + msg=GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=4, + ), + ), + ] From ffd51c885dd5c5e40fb15ac8e7038c287e6f1981 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 30 Apr 2025 21:05:55 +0800 Subject: [PATCH 4/4] Add supervisor handling test --- .../execution_time/test_supervisor.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 4aff50a5fd486..a3eef617a4ee0 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -74,6 +74,7 @@ GetTICount, GetVariable, GetXCom, + GetXComSequenceItem, OKResponse, PrevSuccessfulDagRunResult, PutVariable, @@ -1436,6 +1437,36 @@ def watched_subprocess(self, mocker): TaskStatesResult(task_states={"run_id": {"task1": "success", "task2": "failed"}}), id="get_task_states", ), + pytest.param( + GetXComSequenceItem( + key="test_key", + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + offset=0, + ), + b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', + "xcoms.get_sequence_item", + ("test_dag", "test_run", "test_task", "test_key", 0), + {}, + XComResult(key="test_key", value="test_value"), + id="get_xcom_seq_item", + ), + pytest.param( + GetXComSequenceItem( + key="test_key", + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + offset=2, + ), + b'{"error":"XCOM_NOT_FOUND","detail":null,"type":"ErrorResponse"}\n', + "xcoms.get_sequence_item", + ("test_dag", "test_run", "test_task", "test_key", 2), + {}, + ErrorResponse(error=ErrorType.XCOM_NOT_FOUND), + id="get_xcom_seq_item_not_found", + ), ], ) def test_handle_requests(