From dc964046068e8032becb48b20185c84655a9e590 Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Thu, 27 Mar 2025 17:58:42 +0530
Subject: [PATCH 1/2] Add a fixture to easily replace `ti.run` usage

As we are replacing BaseOperator usage from Core to Task SDK, we are
running into several issues, one of the common ones being the over-use
of `task.run()`.

While some cases can easily be replaced with `task.execute()`, others
need actual execution of the tasks, sharing of XComs between them,
checking the task state, asserting the correct exception, etc.

To make this easier, I have added a `run_task` fixture which I have been
using in https://github.com/apache/airflow/pull/48244 and it has worked
out well.

Example:

---
 .../src/tests_common/pytest_plugin.py         | 425 ++++++++++++++++++
 1 file changed, 425 insertions(+)

diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index ad885d176c76b..1475c20510a01 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -35,6 +35,8 @@
 import time_machine
 
 if TYPE_CHECKING:
+    from uuid import UUID
+
     from itsdangerous import URLSafeSerializer
     from sqlalchemy.orm import Session
 
@@ -43,6 +45,10 @@
     from airflow.models.dagrun import DagRun, DagRunType
     from airflow.models.taskinstance import TaskInstance
     from airflow.providers.standard.operators.empty import EmptyOperator
+    from airflow.sdk.api.datamodels._generated import IntermediateTIState, TerminalTIState
+    from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
+    from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor
+    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
     from airflow.timetables.base import DataInterval
     from airflow.typing_compat import Self
     from airflow.utils.state import DagRunState, TaskInstanceState
@@ -1872,3 +1878,422 @@ def mock_supervisor_comms():
         "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
     ) as supervisor_comms:
         yield supervisor_comms
+
+
+@pytest.fixture
+def mocked_parse(spy_agency):
+    """
+    Fixture to set up an inline DAG and use it in a stubbed `parse` function.
+
+    Use this fixture if you want to isolate and test `parse` or `run` logic without having to define a DAG file.
+
+    This fixture returns a helper function `set_dag` that:
+    1. Creates an inline DAG with the given `dag_id` and `task` (limited to one task)
+    2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task.
+    3. Stubs the `parse` function using `spy_agency` to return the mocked `RuntimeTaskInstance`.
+ + After adding the fixture in your test function signature, you can use it like this :: + + mocked_parse( + StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1), + file="", + requests_fd=0, + ), + "example_dag_id", + CustomOperator(task_id="hello"), + ) + """ + + def set_dag(what: StartupDetails, dag_id: str, task: TaskSDKBaseOperator) -> RuntimeTaskInstance: + from airflow.sdk.definitions.dag import DAG + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse + from airflow.utils import timezone + + if not task.has_dag(): + dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3)) + task.dag = dag # type: ignore[assignment] + task = dag.task_dict[task.task_id] + else: + dag = task.dag + if what.ti_context.dag_run.conf: + dag.params = what.ti_context.dag_run.conf # type: ignore[assignment] + ti = RuntimeTaskInstance.model_construct( + **what.ti.model_dump(exclude_unset=True), + task=task, + _ti_context_from_server=what.ti_context, + max_tries=what.ti_context.max_tries, + start_date=what.start_date, + ) + if hasattr(parse, "spy"): + spy_agency.unspy(parse) + spy_agency.spy_on(parse, call_fake=lambda _: ti) + return ti + + return set_dag + + +class _XComHelperProtocol(Protocol): + def get( + self, + key: str, + task_id: str | None = None, + dag_id: str | None = None, + run_id: str | None = None, + map_index: int | None = None, + ) -> Any: ... + + def assert_pushed( + self, + key: str, + value: Any, + task_id: str | None = None, + dag_id: str | None = None, + run_id: str | None = None, + map_index: int | None = None, + **kwargs, + ) -> None: ... + + def clear(self): ... + + +class RunTaskCallable(Protocol): + """Protocol for better type hints for the fixture `run_task`.""" + + @property + def state(self) -> IntermediateTIState | TerminalTIState: ... + + @property + def msg(self) -> ToSupervisor | None: ... + + @property + def error(self) -> BaseException | None: ... + + xcom: _XComHelperProtocol + + def __call__( + self, + task: BaseOperator, + dag_id: str = ..., + run_id: str = ..., + logical_date: datetime | None = None, + start_date: datetime | None = None, + run_type: str = ..., + try_number: int = ..., + map_index: int | None = ..., + ti_id: UUID | None = None, + max_tries: int | None = None, + context_update: dict[str, Any] | None = None, + ) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None, BaseException | None]: ... + + +@pytest.fixture +def create_runtime_ti(mocked_parse): + """ + Fixture to create a Runtime TaskInstance for testing purposes without defining a dag file. + + It mimics the behavior of the `parse` function by creating a `RuntimeTaskInstance` based on the provided + `StartupDetails` (formed from arguments) and task. This allows you to test the logic of a task without + having to define a DAG file, parse it, get context from the server, etc. + + Example usage: :: + + def test_custom_task_instance(create_runtime_ti): + class MyTaskOperator(BaseOperator): + def execute(self, context): + assert context["dag_run"].run_id == "test_run" + + task = MyTaskOperator(task_id="test_task") + ti = create_runtime_ti(task) + # Further test logic... 
+ """ + from uuid6 import uuid7 + + from airflow.sdk.api.datamodels._generated import TaskInstance + from airflow.sdk.definitions.dag import DAG + from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails + from airflow.utils import timezone + + def _create_task_instance( + task: BaseOperator, + dag_id: str = "test_dag", + run_id: str = "test_run", + logical_date: str | datetime = "2024-12-01T01:00:00Z", + start_date: str | datetime = "2024-12-01T01:00:00Z", + run_type: str = "manual", + try_number: int = 1, + map_index: int | None = -1, + upstream_map_indexes: dict[str, int] | None = None, + task_reschedule_count: int = 0, + ti_id: UUID | None = None, + conf: dict[str, Any] | None = None, + should_retry: bool | None = None, + max_tries: int | None = None, + ) -> RuntimeTaskInstance: + from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext + + if not ti_id: + ti_id = uuid7() + + if not task.has_dag(): + dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3)) + task.dag = dag # type: ignore[assignment] + task = dag.task_dict[task.task_id] + + if task.dag.timetable: + data_interval_start, data_interval_end = task.dag.timetable.infer_manual_data_interval( + run_after=logical_date # type: ignore + ) + else: + data_interval_start = None + data_interval_end = None + + dag_id = task.dag.dag_id + task_retries = task.retries or 0 + run_after = data_interval_end or logical_date or timezone.utcnow() + + ti_context = TIRunContext( + dag_run=DagRun( + dag_id=dag_id, + run_id=run_id, + logical_date=logical_date, # type: ignore + data_interval_start=data_interval_start, + data_interval_end=data_interval_end, + start_date=start_date, # type: ignore + run_type=run_type, # type: ignore + run_after=run_after, # type: ignore + conf=conf, + ), + task_reschedule_count=task_reschedule_count, + max_tries=task_retries if max_tries is None else max_tries, + should_retry=should_retry if should_retry is not None else try_number <= task_retries, + ) + + if upstream_map_indexes is not None: + ti_context.upstream_map_indexes = upstream_map_indexes + + startup_details = StartupDetails( + ti=TaskInstance( + id=ti_id, + task_id=task.task_id, + dag_id=dag_id, + run_id=run_id, + try_number=try_number, + map_index=map_index, + ), + dag_rel_path="", + bundle_info=BundleInfo(name="anything", version="any"), + requests_fd=0, + ti_context=ti_context, + start_date=start_date, # type: ignore + ) + + ti = mocked_parse(startup_details, dag_id, task) + return ti + + return _create_task_instance + + +@pytest.fixture +def run_task(create_runtime_ti, mock_supervisor_comms, spy_agency) -> RunTaskCallable: + """ + Fixture to run a task without defining a dag file. + + This fixture builds on top of create_runtime_ti to provide a convenient way to execute tasks and get their results. + + The fixture provides: + - run_task.state - Get the task state + - run_task.msg - Get the task message + - run_task.error - Get the task error + - run_task.xcom.get(key) - Get an XCom value + - run_task.xcom.assert_pushed(key, value, ...) 
- Assert an XCom was pushed + + Example usage: :: + + def test_custom_task(run_task): + class MyTaskOperator(BaseOperator): + def execute(self, context): + return "hello" + + task = MyTaskOperator(task_id="test_task") + run_task(task) + assert run_task.state == TerminalTIState.SUCCESS + assert run_task.error is None + """ + import structlog + + from airflow.sdk.execution_time.task_runner import run + from airflow.sdk.execution_time.xcom import XCom + from airflow.utils import timezone + + # Set up spies once at fixture level + if hasattr(XCom.set, "spy"): + spy_agency.unspy(XCom.set) + if hasattr(XCom.get_one, "spy"): + spy_agency.unspy(XCom.get_one) + spy_agency.spy_on(XCom.set, call_original=True) + spy_agency.spy_on( + XCom.get_one, call_fake=lambda cls, *args, **kwargs: _get_one_from_set_calls(*args, **kwargs) + ) + + def _get_one_from_set_calls(*args, **kwargs) -> Any | None: + """Get the most recent value from XCom.set calls that matches the criteria.""" + key = kwargs.get("key") + task_id = kwargs.get("task_id") + dag_id = kwargs.get("dag_id") + run_id = kwargs.get("run_id") + map_index = kwargs.get("map_index") or -1 + + for call in reversed(XCom.set.calls): + if ( + call.kwargs.get("task_id") == task_id + and call.kwargs.get("dag_id") == dag_id + and call.kwargs.get("run_id") == run_id + and call.kwargs.get("map_index") == map_index + ): + if call.args and len(call.args) >= 2: + call_key, value = call.args + if call_key == key: + return value + return None + + class XComHelper: + def __init__(self): + self._ti = None + + def get( + self, + key: str, + task_id: str | None = None, + dag_id: str | None = None, + run_id: str | None = None, + map_index: int | None = None, + ) -> Any: + # Use task instance values as defaults + task_id = task_id or self._ti.task_id + dag_id = dag_id or self._ti.dag_id + run_id = run_id or self._ti.run_id + map_index = map_index if map_index is not None else self._ti.map_index + + return XCom.get_one( + key=key, + task_id=task_id, + dag_id=dag_id, + run_id=run_id, + map_index=map_index, + ) + + def assert_pushed( + self, + key: str, + value: Any, + task_id: str | None = None, + dag_id: str | None = None, + run_id: str | None = None, + map_index: int | None = None, + **kwargs, + ): + """Assert that an XCom was pushed with the given key and value.""" + task_id = task_id or self._ti.task_id + dag_id = dag_id or self._ti.dag_id + run_id = run_id or self._ti.run_id + map_index = map_index if map_index is not None else self._ti.map_index + + spy_agency.assert_spy_called_with( + XCom.set, + key, + value, + task_id=task_id, + dag_id=dag_id, + run_id=run_id, + map_index=map_index, + **kwargs, + ) + + def clear(self): + """Clear all XCom calls.""" + if hasattr(XCom.set, "spy"): + spy_agency.unspy(XCom.set) + if hasattr(XCom.get_one, "spy"): + spy_agency.unspy(XCom.get_one) + + class RunTaskWithXCom: + def __init__(self, create_runtime_ti): + self.create_runtime_ti = create_runtime_ti + self.xcom = XComHelper() + self._state = None + self._msg = None + self._error = None + + @property + def state(self) -> IntermediateTIState | TerminalTIState: + """Get the task state.""" + return self._state + + @property + def msg(self) -> ToSupervisor | None: + """Get the task message to send to supervisor.""" + return self._msg + + @property + def error(self) -> BaseException | None: + """Get the error message if there was any.""" + return self._error + + def __call__( + self, + task: BaseOperator, + dag_id: str = "test_dag", + run_id: str = "test_run", + logical_date: datetime | 
None = None,
+            start_date: datetime | None = None,
+            run_type: str = "manual",
+            try_number: int = 1,
+            map_index: int | None = -1,
+            ti_id: UUID | None = None,
+            max_tries: int | None = None,
+            context_update: dict[str, Any] | None = None,
+        ) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None, BaseException | None]:
+            now = timezone.utcnow()
+            if logical_date is None:
+                logical_date = now
+
+            if start_date is None:
+                start_date = now
+
+            ti = self.create_runtime_ti(
+                task=task,
+                dag_id=dag_id,
+                run_id=run_id,
+                logical_date=logical_date,
+                start_date=start_date,
+                run_type=run_type,
+                try_number=try_number,
+                map_index=map_index,
+                ti_id=ti_id,
+                max_tries=max_tries,
+            )
+
+            context = ti.get_template_context()
+            if context_update:
+                context.update(context_update)
+            log = structlog.get_logger(logger_name="task")
+
+            # Store the task instance for XCom operations
+            self.xcom._ti = ti
+
+            # Run the task
+            state, msg, error = run(ti, context, log)
+            self._state = state
+            self._msg = msg
+            self._error = error
+
+            return state, msg, error
+
+    return RunTaskWithXCom(create_runtime_ti)
+
+
+@pytest.fixture
+def mock_xcom_backend():
+    with mock.patch("airflow.sdk.execution_time.task_runner.XCom", create=True) as xcom_backend:
+        yield xcom_backend

From 446ca47384e2390c8eb245ec5dd75ea0ec656c42 Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Thu, 27 Mar 2025 18:51:54 +0530
Subject: [PATCH 2/2] fixup! Add a fixture to easily replace `ti.run` usage

---
 devel-common/src/tests_common/pytest_plugin.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index 1475c20510a01..7d04f4fe22189 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -1886,6 +1886,8 @@ def mocked_parse(spy_agency):
     Fixture to set up an inline DAG and use it in a stubbed `parse` function.
 
     Use this fixture if you want to isolate and test `parse` or `run` logic without having to define a DAG file.
+    In most cases, prefer the `create_runtime_ti` fixture, which lets you pass an operator directly
+    rather than building lower-level AIP-72 constructs like `StartupDetails`.
 
     This fixture returns a helper function `set_dag` that:
    1. Creates an inline DAG with the given `dag_id` and `task` (limited to one task)
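
Illustrative usage (not part of either patch): a minimal sketch of what a test migrated away from
`ti.run()` could look like with the `run_task` fixture from patch 1. `MyOperator`, `BoomOperator`,
the task ids, and the asserted values are hypothetical, and the XCom assertion assumes the task's
return value is routed through `XCom.set`, which the fixture spies on:

    from airflow.sdk.api.datamodels._generated import TerminalTIState
    from airflow.sdk.definitions.baseoperator import BaseOperator


    class MyOperator(BaseOperator):
        """Toy operator used only for this sketch."""

        def execute(self, context):
            return "hello"


    def test_my_operator(run_task):
        # No DAG file, DB session, or ti.run() needed: the fixture builds a
        # RuntimeTaskInstance around an inline DAG and executes the task in-process.
        state, _, error = run_task(MyOperator(task_id="my_task"))

        assert error is None
        assert state == TerminalTIState.SUCCESS
        # The return value is pushed as the implicit "return_value" XCom
        # (assumption: the push goes through XCom.set, which run_task spies on).
        run_task.xcom.assert_pushed("return_value", "hello")


    def test_failing_operator(run_task):
        class BoomOperator(BaseOperator):
            def execute(self, context):
                raise ValueError("boom")

        run_task(BoomOperator(task_id="boom"))

        # The raised exception is surfaced via run_task.error instead of being
        # swallowed the way ti.run() would handle it.
        assert isinstance(run_task.error, ValueError)
        assert run_task.state != TerminalTIState.SUCCESS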