From 4af8d8cae419eeb706ad31946a8d586feb952dfd Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 22 May 2025 23:38:54 +0530 Subject: [PATCH] Port ``ti.run`` to Task SDK execution path (#50141) This is the last thing to remove the parallel execution path. For simplicity, `ti.run` and `ti._run_raw_task` have been retained, but they now use the Task SDK execution path. They have been retained so we don't have to make big bang changes in tests and PR remains relatively review-able. There are opportunities for evaluating & cleanup after this PR is merged: - `get_template_context`, - `handle_failure`, - `check_and_change_state_before_execution` - and probably a lot more. The following bugs/missing features were identified and implemented: - Running `on_kill` on the Task SDK execution path - Resolving lazy_object_proxies in the Context dict when running with a VirtualEnvOperator - https://github.com/apache/airflow/pull/50898 (cherry picked from commit f1ca1d1b90da1b573df90183be444a0a71f3bff9) --- .../src/airflow/cli/commands/task_command.py | 2 +- .../src/airflow/models/taskinstance.py | 719 +------ .../api_fastapi/execution_api/conftest.py | 31 +- .../versions/head/test_task_instances.py | 17 +- .../tests/unit/listeners/test_listeners.py | 20 +- .../tests/unit/models/test_taskinstance.py | 1815 +---------------- .../deps/test_not_previously_skipped_dep.py | 6 +- .../tests/unit/utils/test_log_handlers.py | 2 + .../tests_common/test_utils/version_compat.py | 1 + .../log_handlers/test_log_handlers.py | 4 +- .../unit/common/sql/operators/test_sql.py | 12 +- .../microsoft/azure/operators/test_adx.py | 25 +- .../unit/oracle/operators/test_oracle.py | 22 +- .../snowflake/decorators/test_snowpark.py | 29 +- .../providers/standard/operators/python.py | 14 +- .../decorators/test_branch_external_python.py | 11 +- .../standard/decorators/test_branch_python.py | 10 +- .../decorators/test_branch_virtualenv.py | 11 +- .../unit/standard/decorators/test_python.py | 43 +- .../standard/decorators/test_short_circuit.py | 13 +- .../operators/test_branch_operator.py | 80 +- .../unit/standard/operators/test_datetime.py | 16 +- .../operators/test_latest_only_operator.py | 115 +- .../unit/standard/operators/test_python.py | 229 ++- .../unit/standard/operators/test_weekday.py | 21 +- pyproject.toml | 2 +- task-sdk/src/airflow/sdk/definitions/dag.py | 12 +- .../airflow/sdk/execution_time/task_runner.py | 2 +- 28 files changed, 514 insertions(+), 2770 deletions(-) diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py index cf074cd32c244..69859af3f5610 100644 --- a/airflow-core/src/airflow/cli/commands/task_command.py +++ b/airflow-core/src/airflow/cli/commands/task_command.py @@ -381,7 +381,7 @@ def task_test(args, dag: DAG | None = None) -> None: ) try: with redirect_stdout(RedactedIO()): - _run_task(ti=ti) + _run_task(ti=ti, run_triggerer=True) if ti.state == State.FAILED and args.post_mortem: debugger = _guess_debugger() debugger.set_trace() diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 5a44596d4ca9d..27b0305f8d478 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -24,15 +24,12 @@ import math import operator import os -import signal -import traceback from collections import defaultdict -from collections.abc import Collection, Generator, Iterable, Mapping, Sequence +from collections.abc import Collection, Generator, Iterable, Sequence from datetime import timedelta -from enum import Enum from functools import cache from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any from urllib.parse import quote import attrs @@ -40,7 +37,6 @@ import jinja2 import lazy_object_proxy import uuid6 -from jinja2 import TemplateAssertionError, UndefinedError from sqlalchemy import ( Column, Float, @@ -77,32 +73,19 @@ from airflow.assets.manager import asset_manager from airflow.configuration import conf from airflow.exceptions import ( - AirflowException, - AirflowFailException, AirflowInactiveAssetInInletOrOutletException, - AirflowRescheduleException, - AirflowSensorTimeout, - AirflowSkipException, - AirflowTaskTerminated, - AirflowTaskTimeout, TaskDeferralError, TaskDeferred, - UnmappableXComLengthPushed, - UnmappableXComTypePushed, - XComForMappingNotPushed, ) from airflow.listeners.listener import get_listener_manager -from airflow.models.asset import AssetActive, AssetEvent, AssetModel +from airflow.models.asset import AssetEvent, AssetModel from airflow.models.base import Base, StringID, TaskInstanceDependencies from airflow.models.log import Log -from airflow.models.renderedtifields import get_serialized_template_fields from airflow.models.taskinstancekey import TaskInstanceKey from airflow.models.taskmap import TaskMap from airflow.models.taskreschedule import TaskReschedule from airflow.models.xcom import LazyXComSelectSequence, XComModel from airflow.plugins_manager import integrate_macros_plugins -from airflow.sdk.execution_time.context import context_to_airflow_vars -from airflow.sentry import Sentry from airflow.settings import task_instance_mutation_hook from airflow.stats import Stats from airflow.ti_deps.dep_context import DepContext @@ -118,8 +101,6 @@ from airflow.utils.span_status import SpanStatus from airflow.utils.sqlalchemy import ExecutorConfigType, ExtendedJSON, UtcDateTime from airflow.utils.state import DagRunState, State, TaskInstanceState -from airflow.utils.task_instance_session import set_current_task_instance_session -from airflow.utils.timeout import timeout from airflow.utils.xcom import XCOM_RETURN_KEY TR = TaskReschedule @@ -130,7 +111,6 @@ if TYPE_CHECKING: from datetime import datetime from pathlib import PurePath - from types import TracebackType import pendulum from sqlalchemy.engine import Connection as SAConnection, Engine @@ -139,12 +119,11 @@ from sqlalchemy.sql.elements import BooleanClauseList from sqlalchemy.sql.expression import ColumnOperators - from airflow.models.abstractoperator import TaskStateChangeCallback from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG as SchedulerDAG, DagModel from airflow.models.dagrun import DagRun from airflow.sdk.api.datamodels._generated import AssetProfile - from airflow.sdk.definitions._internal.abstractoperator import Operator + from airflow.sdk.definitions._internal.abstractoperator import Operator, TaskStateChangeCallback from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.taskgroup import MappedTaskGroup @@ -157,17 +136,6 @@ PAST_DEPENDS_MET = "past_depends_met" -class TaskReturnCode(Enum): - """ - Enum to signal manner of exit for task run command. - - :meta private: - """ - - DEFERRED = 100 - """When task exits with deferral to trigger.""" - - @provide_session def _add_log( event, @@ -352,46 +320,6 @@ def _creator_note(val): return TaskInstanceNote(*val) -@provide_session -def _record_task_map_for_downstreams( - *, - task_instance: TaskInstance, - task: Operator, - value: Any, - session: Session, -) -> None: - """ - Record the task map for downstream tasks. - - :param task_instance: the task instance - :param task: The task object - :param dag: the dag associated with the task - :param value: The value - :param session: SQLAlchemy ORM Session - - :meta private: - """ - from airflow.sdk.definitions.mappedoperator import MappedOperator, is_mappable_value - - if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate. - return - # TODO: We don't push TaskMap for mapped task instances because it's not - # currently possible for a downstream to depend on one individual mapped - # task instance. This will change when we implement task mapping inside - # a mapped task group, and we'll need to further analyze the case. - if isinstance(task, MappedOperator): - return - if value is None: - raise XComForMappingNotPushed() - if not is_mappable_value(value): - raise UnmappableXComTypePushed(value) - task_map = TaskMap.from_task_instance_xcom(task_instance, value) - max_map_length = conf.getint("core", "max_map_length", fallback=1024) - if task_map.length > max_map_length: - raise UnmappableXComLengthPushed(value, max_map_length) - session.merge(task_map) - - def _get_email_subject_content( *, task_instance: TaskInstance | RuntimeTaskInstanceProtocol, @@ -1128,36 +1056,6 @@ def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> # Re-apply cluster policy here so that task default do not overload previous data task_instance_mutation_hook(self) - @staticmethod - @provide_session - def _clear_xcom_data(ti: TaskInstance, session: Session = NEW_SESSION) -> None: - """ - Clear all XCom data from the database for the task instance. - - If the task is unmapped, all XComs matching this task ID in the same DAG - run are removed. If the task is mapped, only the one with matching map - index is removed. - - :param ti: The TI for which we need to clear xcoms. - :param session: SQLAlchemy ORM Session - """ - ti.log.debug("Clearing XCom data") - if ti.map_index < 0: - map_index: int | None = None - else: - map_index = ti.map_index - XComModel.clear( - dag_id=ti.dag_id, - task_id=ti.task_id, - run_id=ti.run_id, - map_index=map_index, - session=session, - ) - - @provide_session - def clear_xcom_data(self, session: Session = NEW_SESSION): - self._clear_xcom_data(ti=self, session=session) - @property def key(self) -> TaskInstanceKey: """Returns a tuple that identifies the task instance uniquely.""" @@ -1670,184 +1568,25 @@ def clear_next_method_args(self) -> None: self.next_kwargs = None @provide_session - @Sentry.enrich_errors def _run_raw_task( self, mark_success: bool = False, - test_mode: bool = False, - pool: str | None = None, - raise_on_defer: bool = False, session: Session = NEW_SESSION, - ) -> TaskReturnCode | None: - """ - Run a task, update the state upon completion, and run any appropriate callbacks. - - Immediately runs the task (without checking or changing db state - before execution) and then sets the appropriate final state after - completion and runs any post-execute callbacks. Meant to be called - only after another function changes the state to running. - - :param mark_success: Don't run the task, mark its state as success - :param test_mode: Doesn't record success or failure in the DB - :param pool: specifies the pool to use to run the task instance - :param session: SQLAlchemy ORM Session - """ - if TYPE_CHECKING: - assert self.task - - if TYPE_CHECKING: - assert isinstance(self.task, BaseOperator) - - self.test_mode = test_mode - self.refresh_from_task(self.task, pool_override=pool) - self.refresh_from_db(session=session) - self.hostname = get_hostname() - self.pid = os.getpid() - if not test_mode: - TaskInstance.save_to_db(ti=self, session=session) - actual_start_date = timezone.utcnow() - Stats.incr(f"ti.start.{self.task.dag_id}.{self.task.task_id}", tags=self.stats_tags) - # Same metric with tagging - Stats.incr("ti.start", tags=self.stats_tags) - # Initialize final state counters at zero - for state in State.task_states: - Stats.incr( - f"ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}", - count=0, - tags=self.stats_tags, - ) - # Same metric with tagging - Stats.incr( - "ti.finish", - count=0, - tags={**self.stats_tags, "state": str(state)}, - ) - with set_current_task_instance_session(session=session): - self.task = self.task.prepare_for_execution() - context = self.get_template_context(ignore_param_exceptions=False, session=session) - - try: - if self.task: - from airflow.sdk.definitions.asset import Asset - - inlets = [asset.asprofile() for asset in self.task.inlets if isinstance(asset, Asset)] - outlets = [asset.asprofile() for asset in self.task.outlets if isinstance(asset, Asset)] - TaskInstance.validate_inlet_outlet_assets_activeness(inlets, outlets, session=session) - if not mark_success: - TaskInstance._execute_task_with_callbacks( - self=self, # type: ignore[arg-type] - context=context, - test_mode=test_mode, - session=session, - ) - if not test_mode: - self.refresh_from_db(lock_for_update=True, session=session, keep_local_changes=True) - self.state = TaskInstanceState.SUCCESS - except TaskDeferred as defer: - # The task has signalled it wants to defer execution based on - # a trigger. - if raise_on_defer: - raise - self.defer_task(exception=defer, session=session) - self.log.info( - "Pausing task as DEFERRED. dag_id=%s, task_id=%s, run_id=%s, logical_date=%s, start_date=%s", - self.dag_id, - self.task_id, - self.run_id, - _date_or_empty(task_instance=self, attr="logical_date"), - _date_or_empty(task_instance=self, attr="start_date"), - ) - return TaskReturnCode.DEFERRED - except AirflowSkipException as e: - # Recording SKIP - # log only if exception has any arguments to prevent log flooding - if e.args: - self.log.info(e) - if not test_mode: - self.refresh_from_db(lock_for_update=True, session=session, keep_local_changes=True) - self.state = TaskInstanceState.SKIPPED - _run_finished_callback(callbacks=self.task.on_skipped_callback, context=context) - TaskInstance.save_to_db(ti=self, session=session) - except AirflowRescheduleException as reschedule_exception: - self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session) - self.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE") - return None - except (AirflowFailException, AirflowSensorTimeout) as e: - # If AirflowFailException is raised, task should not retry. - # If a sensor in reschedule mode reaches timeout, task should not retry. - self.handle_failure( - e, test_mode, context, force_fail=True, session=session - ) # already saves to db - raise - except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated) as e: - if not test_mode: - self.refresh_from_db(lock_for_update=True, session=session) - # for case when task is marked as success/failed externally - # or dagrun timed out and task is marked as skipped - # current behavior doesn't hit the callbacks - if self.state in State.finished: - self.clear_next_method_args() - TaskInstance.save_to_db(ti=self, session=session) - return None - self.handle_failure(e, test_mode, context, session=session) - raise - except SystemExit as e: - # We have already handled SystemExit with success codes (0 and None) in the `_execute_task`. - # Therefore, here we must handle only error codes. - msg = f"Task failed due to SystemExit({e.code})" - self.handle_failure(msg, test_mode, context, session=session) - raise AirflowException(msg) - except BaseException as e: - self.handle_failure(e, test_mode, context, session=session) - raise - finally: - # Print a marker post execution for internals of post task processing - log.info("::group::Post task execution logs") - - Stats.incr( - f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", - tags=self.stats_tags, - ) - # Same metric with tagging - Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)}) - - # Recording SKIPPED or SUCCESS - self.clear_next_method_args() - self.end_date = timezone.utcnow() - _log_state(task_instance=self) - self.set_duration() - - # run on_success_callback before db committing - # otherwise, the LocalTaskJob sees the state is changed to `success`, - # but the task_runner is still running, LocalTaskJob then treats the state is set externally! - if self.state == TaskInstanceState.SUCCESS: - _run_finished_callback(callbacks=self.task.on_success_callback, context=context) - - if not test_mode: - _add_log(event=self.state, task_instance=self, session=session) - if self.state == TaskInstanceState.SUCCESS: - from airflow.sdk.execution_time.task_runner import ( - _build_asset_profiles, - _serialize_outlet_events, - ) - - TaskInstance.register_asset_changes_in_db( - self, - list(_build_asset_profiles(self.task.outlets)), - list(_serialize_outlet_events(context["outlet_events"])), - session=session, - ) + **kwargs: Any, + ) -> None: + """Only kept for tests.""" + from airflow.sdk.definitions.dag import _run_task - TaskInstance.save_to_db(ti=self, session=session) - if self.state == TaskInstanceState.SUCCESS: - try: - get_listener_manager().hook.on_task_instance_success( - previous_state=TaskInstanceState.RUNNING, task_instance=self - ) - except Exception: - log.exception("error calling listener") + if mark_success: + self.set_state(TaskInstanceState.SUCCESS) + log.info("[DAG TEST] Marking success for %s ", self.task_id) return None + taskrun_result = _run_task(ti=self) + if taskrun_result is not None and taskrun_result.error: + raise taskrun_result.error + return None + @staticmethod @provide_session def register_asset_changes_in_db( @@ -1988,252 +1727,6 @@ def update_rtif(self, rendered_fields, session: Session = NEW_SESSION): session.flush() RenderedTaskInstanceFields.delete_old_records(self.task_id, self.dag_id, session=session) - def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session): - """Prepare Task for Execution.""" - from airflow.sdk.execution_time.callback_runner import create_executable_runner - from airflow.sdk.execution_time.context import context_get_outlet_events - - if TYPE_CHECKING: - assert self.task - - parent_pid = os.getpid() - - def signal_handler(signum, frame): - pid = os.getpid() - - # If a task forks during execution (from DAG code) for whatever - # reason, we want to make sure that we react to the signal only in - # the process that we've spawned ourselves (referred to here as the - # parent process). - if pid != parent_pid: - os._exit(1) - return - self.log.error("Received SIGTERM. Terminating subprocesses.") - self.log.error("Stacktrace: \n%s", "".join(traceback.format_stack())) - self.task.on_kill() - raise AirflowTaskTerminated( - f"Task received SIGTERM signal {self.task_id=} {self.dag_id=} {self.run_id=} {self.map_index=}" - ) - - signal.signal(signal.SIGTERM, signal_handler) - - # Don't clear Xcom until the task is certain to execute, and check if we are resuming from deferral. - if not self.next_method: - self.clear_xcom_data() - - with ( - Stats.timer(f"dag.{self.task.dag_id}.{self.task.task_id}.duration"), - Stats.timer("task.duration", tags=self.stats_tags), - ): - # Set the validated/merged params on the task object. - self.task.params = context["params"] - - with set_current_context(context): - dag = self.task.get_dag() - if dag is not None: - jinja_env = dag.get_template_env() - else: - jinja_env = None - task_orig = self.render_templates(context=context, jinja_env=jinja_env) - - # The task is never MappedOperator at this point. - if TYPE_CHECKING: - assert isinstance(self.task, BaseOperator) - - if not test_mode: - rendered_fields = get_serialized_template_fields(task=self.task) - self.update_rtif(rendered_fields=rendered_fields) - # Export context to make it available for operators to use. - airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) - os.environ.update(airflow_context_vars) - - # Log context only for the default execution method, the assumption - # being that otherwise we're resuming a deferred task (in which - # case there's no need to log these again). - if not self.next_method: - self.log.info( - "Exporting env vars: %s", - " ".join(f"{k}={v!r}" for k, v in airflow_context_vars.items()), - ) - - # Run pre_execute callback - if self.task._pre_execute_hook: - create_executable_runner( - self.task._pre_execute_hook, - context_get_outlet_events(context), - logger=self.log, - ).run(context) - create_executable_runner( - self.task.pre_execute, - context_get_outlet_events(context), - logger=self.log, - ).run(context) - - # Run on_execute callback - self._run_execute_callback(context, self.task) - - # Run on_task_instance_running event - try: - get_listener_manager().hook.on_task_instance_running( - previous_state=TaskInstanceState.QUEUED, task_instance=self - ) - except Exception: - log.exception("error calling listener") - - def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None: - """Render named map index if the DAG author defined map_index_template at the task level.""" - if jinja_env is None or (template := context.get("map_index_template")) is None: - return None - rendered_map_index = jinja_env.from_string(template).render(context) - log.debug("Map index rendered as %s", rendered_map_index) - return rendered_map_index - - # Execute the task. - with set_current_context(context): - try: - result = self._execute_task(context, task_orig) - except Exception: - # If the task failed, swallow rendering error so it doesn't mask the main error. - with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError): - self._rendered_map_index = _render_map_index(context, jinja_env=jinja_env) - raise - else: # If the task succeeded, render normally to let rendering error bubble up. - self._rendered_map_index = _render_map_index(context, jinja_env=jinja_env) - - # Run post_execute callback - if self.task._post_execute_hook: - create_executable_runner( - self.task._post_execute_hook, - context_get_outlet_events(context), - logger=self.log, - ).run(context, result) - create_executable_runner( - self.task.post_execute, - context_get_outlet_events(context), - logger=self.log, - ).run(context, result) - - Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags) - # Same metric with tagging - Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type}) - Stats.incr("ti_successes", tags=self.stats_tags) - - def _execute_task(self, context: Context, task_orig: Operator): - """ - Execute Task (optionally with a Timeout) and push Xcom results. - - :param context: Jinja2 context - :param task_orig: origin task - """ - from airflow.sdk.bases.operator import ExecutorSafeguard - from airflow.sdk.definitions.mappedoperator import MappedOperator - - task_to_execute = self.task - - if TYPE_CHECKING: - # TODO: TaskSDK this function will need 100% re-writing - # This only works with a "rich" BaseOperator, not the SDK version - assert isinstance(task_to_execute, BaseOperator) - - if isinstance(task_to_execute, MappedOperator): - raise AirflowException("MappedOperator cannot be executed.") - - # If the task has been deferred and is being executed due to a trigger, - # then we need to pick the right method to come back to, otherwise - # we go for the default execute - execute_callable_kwargs: dict[str, Any] = {} - execute_callable: Callable - if self.next_method: - execute_callable = task_to_execute.resume_execution - execute_callable_kwargs["next_method"] = self.next_method - # We don't want modifictions we make here to be tracked by SQLA - execute_callable_kwargs["next_kwargs"] = {**(self.next_kwargs or {})} - if self.next_method == "execute": - execute_callable_kwargs["next_kwargs"][f"{task_to_execute.__class__.__name__}__sentinel"] = ( - ExecutorSafeguard.sentinel_value - ) - else: - execute_callable = task_to_execute.execute - if execute_callable.__name__ == "execute": - execute_callable_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = ( - ExecutorSafeguard.sentinel_value - ) - - def _execute_callable(context: Context, **execute_callable_kwargs): - from airflow.sdk.execution_time.callback_runner import create_executable_runner - from airflow.sdk.execution_time.context import context_get_outlet_events - - try: - # Print a marker for log grouping of details before task execution - log.info("::endgroup::") - - return create_executable_runner( - execute_callable, - context_get_outlet_events(context), - logger=log, - ).run(context=context, **execute_callable_kwargs) - except SystemExit as e: - # Handle only successful cases here. Failure cases will be handled upper - # in the exception chain. - if e.code is not None and e.code != 0: - raise - return None - - # If a timeout is specified for the task, make it fail - # if it goes beyond - if task_to_execute.execution_timeout: - # If we are coming in with a next_method (i.e. from a deferral), - # calculate the timeout from our start_date. - if self.next_method and self.start_date: - timeout_seconds = ( - task_to_execute.execution_timeout - (timezone.utcnow() - self.start_date) - ).total_seconds() - else: - timeout_seconds = task_to_execute.execution_timeout.total_seconds() - try: - # It's possible we're already timed out, so fast-fail if true - if timeout_seconds <= 0: - raise AirflowTaskTimeout() - # Run task in timeout wrapper - with timeout(timeout_seconds): - result = _execute_callable(context=context, **execute_callable_kwargs) - except AirflowTaskTimeout: - task_to_execute.on_kill() - raise - else: - result = _execute_callable(context=context, **execute_callable_kwargs) - cm = create_session() - with cm as session_or_null: - if task_to_execute.do_xcom_push: - xcom_value = result - else: - xcom_value = None - if xcom_value is not None: # If the task returns a result, push an XCom containing it. - if task_to_execute.multiple_outputs: - if not isinstance(xcom_value, Mapping): - raise AirflowException( - f"Returned output was type {type(xcom_value)} " - "expected dictionary for multiple_outputs" - ) - for key in xcom_value.keys(): - if not isinstance(key, str): - raise AirflowException( - "Returned dictionary keys must be strings when using " - f"multiple_outputs, found {key} ({type(key)}) instead" - ) - for key, value in xcom_value.items(): - self.xcom_push(key=key, value=value, session=session_or_null) - self.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session_or_null) - if TYPE_CHECKING: - assert task_orig.dag - _record_task_map_for_downstreams( - task_instance=self, - task=task_orig, - value=xcom_value, - session=session_or_null, - ) - return result - def update_heartbeat(self): with create_session() as session: session.execute( @@ -2318,16 +1811,6 @@ def defer_task(self, exception: TaskDeferred | None, session: Session = NEW_SESS session.merge(self) session.commit() - def _run_execute_callback(self, context: Context, task: BaseOperator) -> None: - """Functions that need to be run before a Task is executed.""" - if not (callbacks := task.on_execute_callback): - return - for callback in callbacks if isinstance(callbacks, list) else [callbacks]: - try: - callback(context) - except Exception: - self.log.exception("Failed when executing execute callback") - @provide_session def run( self, @@ -2343,7 +1826,7 @@ def run( session: Session = NEW_SESSION, raise_on_defer: bool = False, ) -> None: - """Run TaskInstance.""" + """Run TaskInstance (only kept for tests).""" res = self.check_and_change_state_before_execution( verbose=verbose, ignore_all_deps=ignore_all_deps, @@ -2359,13 +1842,7 @@ def run( if not res: return - self._run_raw_task( - mark_success=mark_success, - test_mode=test_mode, - pool=pool, - session=session, - raise_on_defer=raise_on_defer, - ) + self._run_raw_task(mark_success=mark_success) def dry_run(self) -> None: """Only Renders Templates for the TI.""" @@ -2378,65 +1855,6 @@ def dry_run(self) -> None: assert isinstance(self.task, BaseOperator) self.task.dry_run() - @provide_session - def _handle_reschedule( - self, - actual_start_date: datetime, - reschedule_exception: AirflowRescheduleException, - test_mode: bool = False, - session: Session = NEW_SESSION, - ): - # Don't record reschedule request in test mode - if test_mode: - return - - self.refresh_from_db(session) - - if TYPE_CHECKING: - assert self.task - - self.end_date = timezone.utcnow() - self.set_duration() - - # set state - self.state = TaskInstanceState.UP_FOR_RESCHEDULE - - self.clear_next_method_args() - - session.merge(self) - session.commit() - - # we add this in separate commit to reduce likelihood of deadlock - # see https://github.com/apache/airflow/pull/21362 for more info - session.add( - TaskReschedule( - self.id, - actual_start_date, - self.end_date, - reschedule_exception.reschedule_date, - ) - ) - session.commit() - return self - - @staticmethod - def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -> TracebackType | None: - """ - Truncate the traceback of an exception to the first frame called from within a given function. - - :param error: exception to get traceback from - :param truncate_to: Function to truncate TB to. Must have a ``__code__`` attribute - - :meta private: - """ - tb = error.__traceback__ - code = truncate_to.__func__.__code__ # type: ignore[attr-defined] - while tb is not None: - if tb.tb_frame.f_code is code: - return tb.tb_next - tb = tb.tb_next - return tb or error.__traceback__ - @classmethod def fetch_handle_failure_context( cls, @@ -2461,11 +1879,7 @@ def fetch_handle_failure_context( :param fail_fast: if True, fail all downstream tasks """ if error: - if isinstance(error, BaseException): - tb = TaskInstance.get_truncated_error_traceback(error, truncate_to=ti._execute_task) - cls.logger().error("Task failed with exception", exc_info=(type(error), error, tb)) - else: - cls.logger().error("%s", error) + cls.logger().error("%s", error) if not test_mode: ti.refresh_from_db(session) @@ -2778,47 +2192,6 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]: return context - @provide_session - def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None: - """ - Update task with rendered template fields for presentation in UI. - - If task has already run, will fetch from DB; otherwise will render. - """ - from airflow.models.renderedtifields import RenderedTaskInstanceFields - - if TYPE_CHECKING: - assert isinstance(self.task, BaseOperator) - - rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session) - if rendered_task_instance_fields: - self.task = self.task.unmap(None) - for field_name, rendered_value in rendered_task_instance_fields.items(): - setattr(self.task, field_name, rendered_value) - return - - try: - # If we get here, either the task hasn't run or the RTIF record was purged. - from airflow.sdk.execution_time.secrets_masker import redact - - self.render_templates() - for field_name in self.task.template_fields: - rendered_value = getattr(self.task, field_name) - setattr(self.task, field_name, redact(rendered_value, field_name)) - except (TemplateAssertionError, UndefinedError) as e: - raise AirflowException( - "Webserver does not have access to User-defined Macros or Filters " - "when Dag Serialization is enabled. Hence for the task that have not yet " - "started running, please use 'airflow tasks render' for debugging the " - "rendering of template_fields." - ) from e - - def overwrite_params_with_dag_run_conf(self, params: dict, dag_run: DagRun): - """Overwrite Task Params with DagRun.conf.""" - if dag_run and dag_run.conf: - self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf) - params.update(dag_run.conf) - def render_templates( self, context: Context | None = None, jinja_env: jinja2.Environment | None = None ) -> Operator: @@ -3263,60 +2636,6 @@ def duration_expression_update( } ) - @staticmethod - def validate_inlet_outlet_assets_activeness( - inlets: list[AssetProfile], outlets: list[AssetProfile], session: Session - ) -> None: - from airflow.sdk.definitions.asset import AssetUniqueKey - - if not (inlets or outlets): - return - - all_asset_unique_keys = { - AssetUniqueKey.from_asset(inlet_or_outlet) # type: ignore - for inlet_or_outlet in itertools.chain(inlets, outlets) - } - inactive_asset_unique_keys = TaskInstance._get_inactive_asset_unique_keys( - all_asset_unique_keys, session - ) - if inactive_asset_unique_keys: - raise AirflowInactiveAssetInInletOrOutletException(inactive_asset_unique_keys) - - @staticmethod - def _get_inactive_asset_unique_keys( - asset_unique_keys: set[AssetUniqueKey], session: Session - ) -> set[AssetUniqueKey]: - from airflow.sdk.definitions.asset import AssetUniqueKey - - active_asset_unique_keys = { - AssetUniqueKey(name, uri) - for name, uri in session.execute( - select(AssetActive.name, AssetActive.uri).where( - tuple_(AssetActive.name, AssetActive.uri).in_( - attrs.astuple(key) for key in asset_unique_keys - ) - ) - ) - } - return asset_unique_keys - active_asset_unique_keys - - def get_first_reschedule_date(self, context: Context) -> datetime | None: - """Get the first reschedule date for the task instance.""" - if TYPE_CHECKING: - assert isinstance(self.task, BaseOperator) - - with create_session() as session: - start_date = session.scalar( - select(TaskReschedule) - .where( - TaskReschedule.ti_id == str(self.id), - ) - .order_by(TaskReschedule.id.asc()) - .with_only_columns(TaskReschedule.start_date) - .limit(1) - ) - return start_date - def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None: """Given two operators, find their innermost common mapped task group.""" diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py index 9c7a6e39a0e96..9e26937b63c06 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py @@ -32,8 +32,35 @@ def client(request: pytest.FixtureRequest): with TestClient(app, headers={"Authorization": "Bearer fake"}) as client: auth = AsyncMock(spec=JWTValidator) - auth.avalidated_claims.return_value = {"sub": "edb09971-4e0e-4221-ad3f-800852d38085"} - # Inject our fake JWTValidator object. Can be over-ridden by tests if they want + # Create a side_effect function that dynamically extracts the task instance ID from validators + def smart_validated_claims(cred, validators=None): + # Extract task instance ID from validators if present + # This handles the JWTBearerTIPathDep case where the validator contains the task ID from the path + if ( + validators + and "sub" in validators + and isinstance(validators["sub"], dict) + and "value" in validators["sub"] + ): + return { + "sub": validators["sub"]["value"], + "exp": 9999999999, # Far future expiration + "iat": 1000000000, # Past issuance time + "aud": "test-audience", + } + + # For other cases (like JWTBearerDep) where no specific validators are provided + # Return a default UUID with all required claims + return { + "sub": "00000000-0000-0000-0000-000000000000", + "exp": 9999999999, # Far future expiration + "iat": 1000000000, # Past issuance time + "aud": "test-audience", + } + + # Set the side_effect for avalidated_claims + auth.avalidated_claims.side_effect = smart_validated_claims lifespan.registry.register_value(JWTValidator, auth) + yield client diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index cc2b1baa64ac6..f1b8982eb04f1 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -17,7 +17,6 @@ from __future__ import annotations -import operator from datetime import datetime from unittest import mock from uuid import uuid4 @@ -962,22 +961,18 @@ def test_ti_skip_downstream(self, client, session, create_task_instance, dag_mak t1 = EmptyOperator(task_id="t1") t0 >> t1 dr = dag_maker.create_dagrun(run_id="run") - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in sorted(decision.schedulable_tis, key=operator.attrgetter("task_id")): - # TODO: TaskSDK #45549 - ti.task = dag_maker.dag.get_task(ti.task_id) - ti.run(session=session) - t0 = dr.get_task_instance("t0") + ti0 = dr.get_task_instance("t0") + ti0.set_state(State.SUCCESS) + response = client.patch( - f"/execution/task-instances/{t0.id}/skip-downstream", + f"/execution/task-instances/{ti0.id}/skip-downstream", json=_json, ) - t1 = dr.get_task_instance("t1") + ti1 = dr.get_task_instance("t1") assert response.status_code == 204 - assert decision.schedulable_tis[0].state == State.SUCCESS - assert t1.state == State.SKIPPED + assert ti1.state == State.SKIPPED class TestTIHealthEndpoint: diff --git a/airflow-core/tests/unit/listeners/test_listeners.py b/airflow-core/tests/unit/listeners/test_listeners.py index 68337c905b24c..3fceaaf0843cc 100644 --- a/airflow-core/tests/unit/listeners/test_listeners.py +++ b/airflow-core/tests/unit/listeners/test_listeners.py @@ -120,14 +120,13 @@ def test_listener_gets_only_subscribed_calls(create_task_instance, session=None) @provide_session -def test_listener_suppresses_exceptions(create_task_instance, session, caplog): +def test_listener_suppresses_exceptions(create_task_instance, session, cap_structlog): lm = get_listener_manager() lm.add_listener(throwing_listener) ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED) - with caplog.at_level(logging.ERROR): - ti._run_raw_task() - assert "error calling listener" in caplog.messages + ti.run() + assert "error calling listener" in cap_structlog @provide_session @@ -139,7 +138,7 @@ def test_listener_captures_failed_taskinstances(create_task_instance_of_operator BashOperator, dag_id=DAG_ID, logical_date=LOGICAL_DATE, task_id=TASK_ID, bash_command="exit 1" ) with pytest.raises(AirflowException): - ti._run_raw_task() + ti.run() assert full_listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED] assert len(full_listener.state) == 2 @@ -153,7 +152,7 @@ def test_listener_captures_longrunning_taskinstances(create_task_instance_of_ope ti = create_task_instance_of_operator( BashOperator, dag_id=DAG_ID, logical_date=LOGICAL_DATE, task_id=TASK_ID, bash_command="sleep 5" ) - ti._run_raw_task() + ti.run() assert full_listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS] assert len(full_listener.state) == 2 @@ -166,13 +165,9 @@ def test_class_based_listener(create_task_instance, session=None): lm.add_listener(listener) ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED) - # Using ti.run() instead of ti._run_raw_task() to capture state change to RUNNING - # that only happens on `check_and_change_state_before_execution()` that is called before - # `run()` calls `_run_raw_task()` ti.run() - assert len(listener.state) == 2 - assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS] + assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS, DagRunState.SUCCESS] def test_listener_logs_call(caplog, create_task_instance, session): @@ -181,10 +176,9 @@ def test_listener_logs_call(caplog, create_task_instance, session): lm.add_listener(full_listener) ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED) - ti._run_raw_task() + ti.run() listener_logs = [r for r in caplog.record_tuples if r[0] == "airflow.listeners.listener"] - assert len(listener_logs) == 6 assert all(r[:-1] == ("airflow.listeners.listener", logging.DEBUG) for r in listener_logs) assert listener_logs[0][-1].startswith("Calling 'on_task_instance_running' with {'") assert listener_logs[1][-1].startswith("Hook impls: [" - assert body.startswith("Try 0") # try number only incremented by the scheduler - assert "test_email_alert" in body - def test_set_duration(self): task = EmptyOperator(task_id="op", email="test@test.test") ti = TI(task=task) @@ -2117,173 +1558,6 @@ def test_set_duration_empty_dates(self): ti.set_duration() assert ti.duration is None - def test_success_callback_no_race_condition(self, create_task_instance): - callback_wrapper = CallbackWrapper() - ti = create_task_instance( - on_success_callback=callback_wrapper.success_handler, - end_date=timezone.utcnow() + datetime.timedelta(days=10), - logical_date=timezone.utcnow(), - state=State.RUNNING, - ) - - session = settings.Session() - session.merge(ti) - session.commit() - - callback_wrapper.wrap_task_instance(ti) - ti._run_raw_task() - assert callback_wrapper.callback_ran - assert callback_wrapper.task_state_in_callback == State.SUCCESS - ti.refresh_from_db() - assert ti.state == State.SUCCESS - - def test_outlet_assets(self, create_task_instance, testing_dag_bundle): - """ - Verify that when we have an outlet asset on a task, and the task - completes successfully, an AssetDagRunQueue is logged. - """ - from airflow.example_dags import example_assets - from airflow.example_dags.example_assets import dag1 - - session = settings.Session() - dagbag = DagBag(dag_folder=example_assets.__file__) - dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db("testing", None, session=session) - - asset_models = session.scalars(select(AssetModel)).all() - SchedulerJobRunner._activate_referenced_assets(asset_models, session=session) - session.flush() - - run_id = str(uuid4()) - dr = DagRun( - dag1.dag_id, - run_id=run_id, - run_type="manual", - state=DagRunState.RUNNING, - logical_date=timezone.utcnow(), - ) - session.merge(dr) - task = dag1.get_task("producing_task_1") - task.bash_command = "echo 1" # make it go faster - ti = TaskInstance(task, run_id=run_id) - session.merge(ti) - session.commit() - ti._run_raw_task() - ti.refresh_from_db() - assert ti.state == TaskInstanceState.SUCCESS - - # check that no other asset events recorded - event = ( - session.query(AssetEvent) - .join(AssetEvent.asset) - .filter(AssetEvent.source_task_instance == ti) - .one() - ) - assert event - assert event.asset - - # check that one queue record created for each dag that depends on asset 1 - assert session.query(AssetDagRunQueue.target_dag_id).filter_by(asset_id=event.asset.id).order_by( - AssetDagRunQueue.target_dag_id - ).all() == [ - ("asset_consumes_1",), - ("asset_consumes_1_and_2",), - ("asset_consumes_1_never_scheduled",), - ("conditional_asset_and_time_based_timetable",), - ("consume_1_and_2_with_asset_expressions",), - ("consume_1_or_2_with_asset_expressions",), - ("consume_1_or_both_2_and_3_with_asset_expressions",), - ] - - # check that one event record created for asset1 and this TI - assert session.query(AssetModel.uri).join(AssetEvent.asset).filter( - AssetEvent.source_task_instance == ti - ).one() == ("s3://dag1/output_1.txt",) - - # check that the asset event has an earlier timestamp than the ADRQ's - adrq_timestamps = session.query(AssetDagRunQueue.created_at).filter_by(asset_id=event.asset.id).all() - assert all(event.timestamp < adrq_timestamp for (adrq_timestamp,) in adrq_timestamps), ( - f"Some items in {[str(t) for t in adrq_timestamps]} are earlier than {event.timestamp}" - ) - - def test_outlet_assets_failed(self, create_task_instance, testing_dag_bundle): - """ - Verify that when we have an outlet asset on a task, and the task - failed, an AssetDagRunQueue is not logged, and an AssetEvent is - not generated - """ - from unit.dags import test_assets - from unit.dags.test_assets import dag_with_fail_task - - session = settings.Session() - dagbag = DagBag(dag_folder=test_assets.__file__) - dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db("testing", None, session=session) - run_id = str(uuid4()) - dr = DagRun( - dag_with_fail_task.dag_id, - run_id=run_id, - run_type="manual", - state=DagRunState.RUNNING, - logical_date=timezone.utcnow(), - ) - session.merge(dr) - task = dag_with_fail_task.get_task("fail_task") - ti = TaskInstance(task, run_id=run_id) - session.merge(ti) - session.commit() - with pytest.raises(AirflowFailException): - ti._run_raw_task() - ti.refresh_from_db() - assert ti.state == TaskInstanceState.FAILED - - # check that no dagruns were queued - assert session.query(AssetDagRunQueue).count() == 0 - - # check that no asset events were generated - assert session.query(AssetEvent).count() == 0 - - def test_outlet_assets_skipped(self, testing_dag_bundle): - """ - Verify that when we have an outlet asset on a task, and the task - is skipped, an AssetDagRunQueue is not logged, and an AssetEvent is - not generated - """ - from unit.dags import test_assets - from unit.dags.test_assets import dag_with_skip_task - - session = settings.Session() - dagbag = DagBag(dag_folder=test_assets.__file__) - dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db("testing", None, session=session) - - asset_models = session.scalars(select(AssetModel)).all() - SchedulerJobRunner._activate_referenced_assets(asset_models, session=session) - session.flush() - - run_id = str(uuid4()) - dr = DagRun( - dag_with_skip_task.dag_id, - run_id=run_id, - run_type="manual", - state=DagRunState.RUNNING, - logical_date=timezone.utcnow(), - ) - session.merge(dr) - task = dag_with_skip_task.get_task("skip_task") - ti = TaskInstance(task, run_id=run_id) - session.merge(ti) - session.commit() - ti._run_raw_task() - ti.refresh_from_db() - assert ti.state == TaskInstanceState.SKIPPED - - # check that no dagruns were queued - assert session.query(AssetDagRunQueue).count() == 0 - - # check that no asset events were generated - assert session.query(AssetEvent).count() == 0 - @pytest.mark.want_activate_assets(True) def test_outlet_asset_extra(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset @@ -2347,77 +1621,9 @@ def write(*, outlet_events): assert event.source_task_id == "write" assert event.extra == {"one": 1} - @pytest.mark.want_activate_assets(True) - def test_outlet_asset_extra_yield(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset - from airflow.sdk.definitions.asset.metadata import Metadata - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(outlets=Asset("test_outlet_asset_extra_1")) - def write1(): - result = "write_1 result" - yield Metadata(Asset(name="test_outlet_asset_extra_1"), {"foo": "bar"}) - return result - - write1() - - def _write2_post_execute(context, result): - yield Metadata(Asset(name="test_outlet_asset_extra_2", uri="test://asset-2"), extra={"x": 1}) - - BashOperator( - task_id="write2", - bash_command=":", - outlets=Asset(name="test_outlet_asset_extra_2", uri="test://asset-2"), - post_execute=_write2_post_execute, - ) - - @task(outlets=Asset("test_outlet_asset_extra_3")) - def write3(): - result = "write_3 result" - yield Metadata(Asset(name="test_outlet_asset_extra_3")) - return result - - write3() - - dr: DagRun = dag_maker.create_dagrun() - for ti in dr.get_task_instances(session=session): - ti.run(session=session) - - xcom = session.scalars( - select(XComModel).filter_by( - dag_id=dr.dag_id, run_id=dr.run_id, task_id="write1", key="return_value" - ) - ).one() - assert xcom.value == json.dumps("write_1 result") - - events = dict(iter(session.execute(select(AssetEvent.source_task_id, AssetEvent)))) - assert set(events) == {"write1", "write2", "write3"} - - assert events["write1"].source_dag_id == dr.dag_id - assert events["write1"].source_run_id == dr.run_id - assert events["write1"].source_task_id == "write1" - assert events["write1"].asset.uri == "test_outlet_asset_extra_1" - assert events["write1"].asset.name == "test_outlet_asset_extra_1" - assert events["write1"].extra == {"foo": "bar"} - - assert events["write2"].source_dag_id == dr.dag_id - assert events["write2"].source_run_id == dr.run_id - assert events["write2"].source_task_id == "write2" - assert events["write2"].asset.uri == "test://asset-2/" - assert events["write2"].asset.name == "test_outlet_asset_extra_2" - assert events["write2"].extra == {"x": 1} - - assert events["write3"].source_dag_id == dr.dag_id - assert events["write3"].source_run_id == dr.run_id - assert events["write3"].source_task_id == "write3" - assert events["write3"].asset.uri == "test_outlet_asset_extra_3" - assert events["write3"].asset.name == "test_outlet_asset_extra_3" - assert events["write3"].extra == {} - @pytest.mark.want_activate_assets(True) def test_outlet_asset_alias(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset, AssetAlias + from airflow.sdk.definitions.asset import Asset asset_uri = "test_outlet_asset_alias_test_case_ds" alias_name_1 = "test_outlet_asset_alias_test_case_asset_alias_1" @@ -2465,7 +1671,7 @@ def producer(*, outlet_events): @pytest.mark.want_activate_assets(True) def test_outlet_multiple_asset_alias(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset, AssetAlias + from airflow.sdk.definitions.asset import Asset asset_uri = "test_outlet_maa_ds" asset_alias_name_1 = "test_outlet_maa_asset_alias_1" @@ -2538,7 +1744,6 @@ def producer(*, outlet_events): @pytest.mark.want_activate_assets(True) def test_outlet_asset_alias_through_metadata(self, dag_maker, session): - from airflow.sdk.definitions.asset import AssetAlias from airflow.sdk.definitions.asset.metadata import Metadata asset_uri = "test_outlet_asset_alias_through_metadata_ds" @@ -2582,7 +1787,7 @@ def producer(*, outlet_events): @pytest.mark.want_activate_assets(True) def test_outlet_asset_alias_asset_not_exists(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset, AssetAlias + from airflow.sdk.definitions.asset import Asset asset_alias_name = "test_outlet_asset_alias_asset_not_exists_asset_alias" asset_uri = "does_not_exist" @@ -2611,7 +1816,7 @@ def producer(*, outlet_events): assert session.scalars(asset_event_check_stmt).one().uri == asset_uri def test_outlet_asset_alias_asset_inactive(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset, AssetAlias + from airflow.sdk.definitions.asset import Asset asset1 = Asset("asset1") asset2 = Asset("asset2") @@ -2751,89 +1956,6 @@ def read(*, inlet_events): assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis assert read_task_evaluated - @pytest.mark.want_activate_assets(True) - @pytest.mark.need_serialized_dag - def test_inlet_asset_alias_extra(self, dag_maker, session, mock_supervisor_comms): - from airflow.sdk.definitions.asset import Asset, AssetAlias - - mock_supervisor_comms.get_message.return_value = AssetEventsResult( - asset_events=[ - AssetEventResponse( - id=1, - created_dagruns=[], - timestamp=timezone.utcnow(), - extra={"from": f"write{i}"}, - asset=AssetResponse( - name="test_inlet_asset_extra_ds", uri="test_inlet_asset_extra_ds", group="asset" - ), - ) - for i in (1, 2, 3) - ] - ) - - asset_uri = "test_inlet_asset_extra_ds" - asset_alias_name = "test_inlet_asset_extra_asset_alias" - - asset_model = AssetModel(id=1, uri=asset_uri, group="asset") - asset_alias_model = AssetAliasModel(name=asset_alias_name) - asset_alias_model.assets.append(asset_model) - session.add_all([asset_model, asset_alias_model, AssetActive.for_asset(Asset(asset_uri))]) - session.commit() - - read_task_evaluated = False - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(outlets=AssetAlias(asset_alias_name)) - def write(*, ti, outlet_events): - outlet_events[AssetAlias(asset_alias_name)].add(Asset(asset_uri), extra={"from": ti.task_id}) - - @task(inlets=AssetAlias(asset_alias_name)) - def read(*, inlet_events): - second_event = inlet_events[AssetAlias(asset_alias_name)][1] - assert second_event.asset.uri == asset_uri - assert second_event.extra == {"from": "write2"} - - last_event = inlet_events[AssetAlias(asset_alias_name)][-1] - assert last_event.asset.uri == asset_uri - assert last_event.extra == {"from": "write3"} - - with pytest.raises(KeyError): - inlet_events[Asset("does_not_exist")] - with pytest.raises(KeyError): - inlet_events[AssetAlias("does_not_exist")] - with pytest.raises(IndexError): - inlet_events[AssetAlias(asset_alias_name)][5] - - nonlocal read_task_evaluated - read_task_evaluated = True - - [ - write.override(task_id="write1")(), - write.override(task_id="write2")(), - write.override(task_id="write3")(), - ] >> read() - - dr: DagRun = dag_maker.create_dagrun() - - # Run "write1", "write2", and "write3" (in this order). - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in sorted(decision.schedulable_tis, key=operator.attrgetter("task_id")): - # TODO: TaskSDK #45549 - ti.task = dag_maker.dag.get_task(ti.task_id) - ti.run(session=session) - - # Run "read". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - # TODO: TaskSDK #45549 - ti.task = dag_maker.dag.get_task(ti.task_id) - ti.run(session=session) - - # Should be done. - assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis - assert read_task_evaluated - @pytest.mark.need_serialized_dag def test_inlet_unresolved_asset_alias(self, dag_maker, session, mock_supervisor_comms): asset_alias_name = "test_inlet_asset_extra_asset_alias" @@ -2843,8 +1965,6 @@ def test_inlet_unresolved_asset_alias(self, dag_maker, session, mock_supervisor_ session.add(asset_alias_model) session.commit() - from airflow.sdk.definitions.asset import AssetAlias - with dag_maker(schedule=None, session=session): @task(inlets=AssetAlias(asset_alias_name)) @@ -2863,146 +1983,6 @@ def read(*, inlet_events): # Should be done. assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis - @pytest.mark.want_activate_assets(True) - @pytest.mark.parametrize( - "slicer, expected", - [ - (lambda x: x[-2:], [{"from": 8}, {"from": 9}]), - (lambda x: x[-5:-3], [{"from": 5}, {"from": 6}]), - (lambda x: x[:-8], [{"from": 0}, {"from": 1}]), - (lambda x: x[1:-7], [{"from": 1}, {"from": 2}]), - (lambda x: x[-8:4], [{"from": 2}, {"from": 3}]), - (lambda x: x[-5:5], []), - ], - ) - def test_inlet_asset_extra_slice(self, dag_maker, session, slicer, expected, mock_supervisor_comms): - from airflow.sdk.definitions.asset import Asset - - asset_uri = "test_inlet_asset_extra_slice" - mock_supervisor_comms.get_message.return_value = AssetEventsResult( - asset_events=[ - AssetEventResponse( - id=1, - created_dagruns=[], - timestamp=timezone.utcnow(), - extra={"from": i}, - asset=AssetResponse(name=asset_uri, uri=asset_uri, group="asset"), - ) - for i in range(0, 10) - ] - ) - - with dag_maker(dag_id="write", serialized=True, schedule="@daily", params={"i": -1}, session=session): - - @task(outlets=Asset(asset_uri)) - def write(*, params, outlet_events): - outlet_events[Asset(asset_uri)].extra = {"from": params["i"]} - - write() - - # Run the write DAG 10 times. - dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, conf={"i": 0}) - for ti in dr.get_task_instances(session=session): - ti.run(session=session) - for i in range(1, 10): - dr = dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, conf={"i": i}) - for ti in dr.get_task_instances(session=session): - ti.run(session=session) - - result = "the task does not run" - - with dag_maker(dag_id="read", schedule=None, session=session): - - @task(inlets=Asset(asset_uri)) - def read(*, inlet_events): - nonlocal result - events = inlet_events[Asset(asset_uri)] - result = [e.extra for e in slicer(events)] - - read() - - # Run the read DAG. - dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(session=session): - ti.run(session=session) - - # Should be done. - assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis - assert result == expected - - @pytest.mark.parametrize( - "slicer, expected", - [ - (lambda x: x[-2:], [{"from": 8}, {"from": 9}]), - (lambda x: x[-5:-3], [{"from": 5}, {"from": 6}]), - (lambda x: x[:-8], [{"from": 0}, {"from": 1}]), - (lambda x: x[1:-7], [{"from": 1}, {"from": 2}]), - (lambda x: x[-8:4], [{"from": 2}, {"from": 3}]), - (lambda x: x[-5:5], []), - ], - ) - @pytest.mark.want_activate_assets(True) - def test_inlet_asset_alias_extra_slice(self, dag_maker, session, slicer, expected, mock_supervisor_comms): - from airflow.sdk.definitions.asset import Asset - - asset_uri = "test_inlet_asset_alias_extra_slice_ds" - mock_supervisor_comms.get_message.return_value = AssetEventsResult( - asset_events=[ - AssetEventResponse( - id=1, - created_dagruns=[], - timestamp=timezone.utcnow(), - extra={"from": i}, - asset=AssetResponse(name=asset_uri, uri=asset_uri, group="asset"), - ) - for i in range(0, 10) - ] - ) - asset_alias_name = "test_inlet_asset_alias_extra_slice_asset_alias" - - asset_model = AssetModel(id=1, uri=asset_uri) - asset_alias_model = AssetAliasModel(name=asset_alias_name) - asset_alias_model.assets.append(asset_model) - session.add_all([asset_model, asset_alias_model, AssetActive.for_asset(Asset(asset_uri))]) - session.commit() - - with dag_maker(dag_id="write", schedule="@daily", params={"i": -1}, serialized=True, session=session): - - @task(outlets=AssetAlias(asset_alias_name)) - def write(*, params, outlet_events): - outlet_events[AssetAlias(asset_alias_name)].add(Asset(asset_uri), {"from": params["i"]}) - - write() - - # Run the write DAG 10 times. - dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, conf={"i": 0}) - for ti in dr.get_task_instances(session=session): - ti.run(session=session) - for i in range(1, 10): - dr = dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, conf={"i": i}) - for ti in dr.get_task_instances(session=session): - ti.run(session=session) - - result = "the task does not run" - - with dag_maker(dag_id="read", schedule=None, serialized=True, session=session): - - @task(inlets=AssetAlias(asset_alias_name)) - def read(*, inlet_events): - nonlocal result - result = [e.extra for e in slicer(inlet_events[AssetAlias(asset_alias_name)])] - - read() - - # Run the read DAG. - dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(session=session): - ti.run(session=session) - - # Should be done. - assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis - assert result == expected - def test_changing_of_asset_when_adrq_is_already_populated(self, dag_maker): """ Test that when a task that produces asset has ran, that changing the consumer @@ -3299,76 +2279,6 @@ def test_template_with_json_variable_missing(self, create_task_instance, session with pytest.raises(KeyError): ti.task.render_template('{{ var.json.get("missing_variable") }}', context) - def test_execute_callback(self, create_task_instance): - called = False - - def on_execute_callable(context): - nonlocal called - called = True - assert context["dag_run"].dag_id == "test_dagrun_execute_callback" - - for i, callback_input in enumerate([[on_execute_callable], on_execute_callable]): - ti = create_task_instance( - dag_id=f"test_execute_callback_{i}", - on_execute_callback=callback_input, - state=State.RUNNING, - ) - - session = settings.Session() - - session.merge(ti) - session.commit() - - ti._run_raw_task() - assert called - ti.refresh_from_db() - assert ti.state == State.SUCCESS - - def test_finished_callbacks_callable_handle_and_log_exception(self, caplog): - called = completed = False - - def on_finish_callable(context): - nonlocal called, completed - called = True - raise KeyError - completed = True - - for callback_input in [[on_finish_callable], on_finish_callable]: - called = completed = False - caplog.clear() - _run_finished_callback(callbacks=callback_input, context={}) - - assert called - assert not completed - callback_name = callback_input[0] if isinstance(callback_input, list) else callback_input - callback_name = qualname(callback_name).split(".")[-1] - assert "Executing callback at index 0: on_finish_callable" in caplog.text - assert "Error in callback at index 0: on_finish_callable" in caplog.text - - def test_finished_callbacks_notifier_handle_and_log_exception(self, caplog): - class OnFinishNotifier(BaseNotifier): - """ - error captured by BaseNotifier - """ - - def __init__(self, error: bool): - super().__init__() - self.raise_error = error - - def notify(self, context): - self.execute() - - def execute(self) -> None: - if self.raise_error: - raise KeyError - - caplog.clear() - callbacks = [OnFinishNotifier(error=False), OnFinishNotifier(error=True)] - _run_finished_callback(callbacks=callbacks, context={}) - assert "Executing callback at index 0: OnFinishNotifier" in caplog.text - assert "Executing callback at index 1: OnFinishNotifier" in caplog.text - assert "KeyError" in caplog.text - @provide_session def test_handle_failure(self, create_dummy_dag, session=None): start_date = timezone.datetime(2016, 6, 1) @@ -3630,30 +2540,6 @@ def fail(): ti.run() assert ti.state == State.UP_FOR_RETRY - @patch.object(TaskInstance, "logger") - def test_stacktrace_on_failure_starts_with_task_execute_method(self, mock_get_log, dag_maker): - def fail(): - raise AirflowException("maybe this will pass?") - - with dag_maker(dag_id="test_retries_on_other_exceptions"): - task = PythonOperator( - task_id="test_raise_other_exception", - python_callable=fail, - retries=1, - ) - ti = dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances[0] - ti.task = task - mock_log = mock.Mock() - mock_get_log.return_value = mock_log - with pytest.raises(AirflowException): - ti.run() - mock_log.error.assert_called_once() - assert mock_log.error.call_args.args == ("Task failed with exception",) - exc_info = mock_log.error.call_args.kwargs["exc_info"] - filename = exc_info[2].tb_frame.f_code.co_filename - formatted_exc = format_exception(*exc_info) - assert sys.modules[TaskInstance.__module__].__file__ == filename, "".join(formatted_exc) - def _env_var_check_callback(self): assert os.environ["AIRFLOW_CTX_DAG_ID"] == "test_echo_env_variables" assert os.environ["AIRFLOW_CTX_TASK_ID"] == "hive_in_python_op" @@ -3665,198 +2551,6 @@ def _env_var_check_callback(self): == os.environ["AIRFLOW_CTX_DAG_RUN_ID"] ) - def test_echo_env_variables(self, dag_maker): - with dag_maker( - "test_echo_env_variables", - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=10), - ): - op = PythonOperator(task_id="hive_in_python_op", python_callable=self._env_var_check_callback) - dr = dag_maker.create_dagrun( - run_type=DagRunType.MANUAL, - ) - ti = dr.get_task_instance(op.task_id) - ti.state = State.RUNNING - session = settings.Session() - session.merge(ti) - session.commit() - ti._run_raw_task() - ti.refresh_from_db() - assert ti.state == State.SUCCESS - - @pytest.mark.parametrize( - "code, expected_state", - [ - pytest.param(1, State.FAILED, id="code-positive-number"), - pytest.param(-1, State.FAILED, id="code-negative-number"), - pytest.param("error", State.FAILED, id="code-text"), - pytest.param(0, State.SUCCESS, id="code-zero"), - pytest.param(None, State.SUCCESS, id="code-none"), - ], - ) - def test_handle_system_exit_failed(self, dag_maker, code, expected_state): - with dag_maker(): - - def f(*args, **kwargs): - exit(code) - - task = PythonOperator(task_id="mytask", python_callable=f) - - dr = dag_maker.create_dagrun() - ti = dr.get_task_instance(task.task_id) - ti.state = State.RUNNING - session = settings.Session() - session.merge(ti) - session.commit() - - if expected_state == State.SUCCESS: - ctx = contextlib.nullcontext() - else: - ctx = pytest.raises(AirflowException, match=rf"Task failed due to SystemExit\({code}\)") - - with ctx: - ti._run_raw_task() - ti.refresh_from_db() - assert ti.state == expected_state - - def test_get_current_context_works_in_template(self, dag_maker): - def user_defined_macro(): - from airflow.providers.standard.operators.python import get_current_context - - get_current_context() - - with dag_maker( - "test_context_inside_template", - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=10), - user_defined_macros={"user_defined_macro": user_defined_macro}, - ): - - def foo(arg): - print(arg) - - PythonOperator( - task_id="context_inside_template", - python_callable=foo, - op_kwargs={"arg": "{{ user_defined_macro() }}"}, - ) - dagrun = dag_maker.create_dagrun() - tis = dagrun.get_task_instances() - ti: TaskInstance = next(x for x in tis if x.task_id == "context_inside_template") - ti._run_raw_task() - assert ti.state == State.SUCCESS - - @patch.object(Stats, "incr") - def test_task_stats(self, stats_mock, create_task_instance): - ti = create_task_instance( - dag_id="test_task_start_end_stats", - end_date=timezone.utcnow() + datetime.timedelta(days=10), - state=State.RUNNING, - ) - stats_mock.reset_mock() - - session = settings.Session() - session.merge(ti) - session.commit() - ti._run_raw_task() - ti.refresh_from_db() - stats_mock.assert_any_call( - f"ti.finish.{ti.dag_id}.{ti.task_id}.{ti.state}", - tags={"dag_id": ti.dag_id, "task_id": ti.task_id}, - ) - stats_mock.assert_any_call( - "ti.finish", - tags={"dag_id": ti.dag_id, "task_id": ti.task_id, "state": ti.state}, - ) - for state in State.task_states: - assert ( - call( - f"ti.finish.{ti.dag_id}.{ti.task_id}.{state}", - count=0, - tags={"dag_id": ti.dag_id, "task_id": ti.task_id}, - ) - in stats_mock.mock_calls - ) - assert ( - call( - "ti.finish", - count=0, - tags={"dag_id": ti.dag_id, "task_id": ti.task_id, "state": str(state)}, - ) - in stats_mock.mock_calls - ) - assert ( - call(f"ti.start.{ti.dag_id}.{ti.task_id}", tags={"dag_id": ti.dag_id, "task_id": ti.task_id}) - in stats_mock.mock_calls - ) - assert call("ti.start", tags={"dag_id": ti.dag_id, "task_id": ti.task_id}) in stats_mock.mock_calls - assert stats_mock.call_count == (2 * len(State.task_states)) + 7 - - def test_command_as_list(self, dag_maker): - with dag_maker(): - PythonOperator(python_callable=print, task_id="hi") - dr = dag_maker.create_dagrun() - ti = dr.task_instances[0] - assert ti.command_as_list() == [ - "airflow", - "tasks", - "run", - ti.dag_id, - ti.task_id, - ti.run_id, - "--subdir", - "DAGS_FOLDER/test_taskinstance.py", - ] - - def test_generate_command_default_param(self): - dag_id = "test_generate_command_default_param" - task_id = "task" - assert_command = ["airflow", "tasks", "run", dag_id, task_id, "run_1"] - generate_command = TI.generate_command(dag_id=dag_id, task_id=task_id, run_id="run_1") - assert assert_command == generate_command - - def test_generate_command_specific_param(self): - dag_id = "test_generate_command_specific_param" - task_id = "task" - assert_command = [ - "airflow", - "tasks", - "run", - dag_id, - task_id, - "run_1", - "--mark-success", - "--map-index", - "0", - ] - generate_command = TI.generate_command( - dag_id=dag_id, task_id=task_id, run_id="run_1", mark_success=True, map_index=0 - ) - assert assert_command == generate_command - - @provide_session - def test_get_rendered_template_fields(self, dag_maker, session=None): - with dag_maker("test-dag", session=session) as dag: - task = BashOperator(task_id="op1", bash_command="{{ task.task_id }}") - dag.fileloc = TEST_DAGS_FOLDER / "test_get_k8s_pod_yaml.py" - ti = dag_maker.create_dagrun().task_instances[0] - ti.task = task - - session.add(RenderedTaskInstanceFields(ti)) - session.flush() - - # Create new TI for the same Task - new_task = BashOperator(task_id="op12", bash_command="{{ task.task_id }}", dag=dag) - - new_ti = TI(task=new_task, run_id=ti.run_id) - new_ti.get_rendered_template_fields(session=session) - - assert ti.task.bash_command == "op1" - - # CleanUp - with create_session() as session: - session.query(RenderedTaskInstanceFields).delete() - def test_set_state_up_for_retry(self, create_task_instance): ti = create_task_instance(state=State.RUNNING) @@ -4072,123 +2766,6 @@ def duplicate_asset_task_in_outlet(*, outlet_events): assert "Asset(name='asset_second', uri='asset_second')" in str(exc.value) assert "Asset(name='asset_first', uri='test://asset/')" in str(exc.value) - @pytest.mark.want_activate_assets(True) - def test_run_with_inactive_assets_in_outlets_within_the_same_dag(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(outlets=Asset("asset_first")) - def first_asset_task(*, outlet_events): - outlet_events[Asset("asset_first")].extra = {"foo": "bar"} - - @task(outlets=Asset(name="asset_first", uri="test://asset")) - def duplicate_asset_task(*, outlet_events): - outlet_events[Asset(name="asset_first", uri="test://asset")].extra = {"foo": "bar"} - - first_asset_task() >> duplicate_asset_task() - - tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances} - tis["first_asset_task"].run(session=session) - with pytest.raises(AirflowInactiveAssetInInletOrOutletException) as exc: - tis["duplicate_asset_task"].run(session=session) - - assert str(exc.value) == ( - "Task has the following inactive assets in its inlets or outlets: " - "Asset(name='asset_first', uri='test://asset/')" - ) - - @pytest.mark.skip( - reason="This test has some issues that were surfaced when dag_maker started allowing multiple serdag versions. Issue #48539 will track fixing this." - ) - @pytest.mark.want_activate_assets(True) - def test_run_with_inactive_assets_in_outlets_in_different_dag(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(outlets=Asset("asset_first")) - def first_asset_task(*, outlet_events): - outlet_events[Asset("asset_first")].extra = {"foo": "bar"} - - first_asset_task() - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(outlets=Asset(name="asset_first", uri="test://asset")) - def duplicate_asset_task(*, outlet_events): - outlet_events[Asset(name="asset_first", uri="test://asset")].extra = {"foo": "bar"} - - duplicate_asset_task() - - tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances} - with pytest.raises(AirflowInactiveAssetInInletOrOutletException) as exc: - tis["duplicate_asset_task"].run(session=session) - - assert str(exc.value) == ( - "Task has the following inactive assets in its inlets or outlets: " - "Asset(name='asset_first', uri='test://asset/')" - ) - - @pytest.mark.want_activate_assets(False) - def test_run_with_inactive_assets_in_inlets_within_the_same_dag(self, dag_maker, session): - valid_asset = Asset("asset_first") - conflict_asset = Asset(name="asset_first", uri="test://asset/") - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(inlets=valid_asset) - def first_asset_task(): - pass - - @task(inlets=conflict_asset) - def conflict_asset_task(): - pass - - first_asset_task() >> conflict_asset_task() - - session.execute(delete(AssetActive)) - session.add(AssetActive.for_asset(valid_asset)) - - tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances} - tis["first_asset_task"].run(session=session) - with pytest.raises(AirflowInactiveAssetInInletOrOutletException) as exc: - tis["conflict_asset_task"].run(session=session) - - assert str(exc.value) == ( - "Task has the following inactive assets in its inlets or outlets: " - "Asset(name='asset_first', uri='test://asset/')" - ) - - @pytest.mark.want_activate_assets(True) - def test_run_with_inactive_assets_in_inlets_in_different_dag(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(inlets=Asset("asset_first")) - def first_asset_task(*, outlet_events): - pass - - first_asset_task() - - with dag_maker(schedule=None, serialized=True, session=session): - - @task(inlets=Asset(name="asset_first", uri="test://asset")) - def duplicate_asset_task(*, outlet_events): - pass - - duplicate_asset_task() - - tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances} - with pytest.raises(AirflowInactiveAssetInInletOrOutletException) as exc: - tis["duplicate_asset_task"].run(session=session) - - assert str(exc.value) == ( - "Task has the following inactive assets in its inlets or outlets: " - "Asset(name='asset_first', uri='test://asset/')" - ) - @pytest.mark.parametrize("pool_override", [None, "test_pool2"]) @pytest.mark.parametrize("queue_by_policy", [None, "forced_queue"]) @@ -4350,343 +2927,6 @@ def tg(arg): tis["push_4"].run() assert dag_maker.session.query(TaskMap).count() == 2 - @pytest.mark.parametrize( - "return_value, exception_type, error_message", - [ - ("abc", UnmappableXComTypePushed, "unmappable return type 'str'"), - (None, XComForMappingNotPushed, "did not push XCom for task mapping"), - ], - ) - def test_expand_error_if_unmappable_type(self, dag_maker, return_value, exception_type, error_message): - """If an unmappable return value is used for expand(), fail the task that pushed the XCom.""" - with dag_maker(dag_id="test_expand_error_if_unmappable_type") as dag: - - @dag.task() - def push_something(): - return return_value - - @dag.task() - def pull_something(value): - print(value) - - pull_something.expand(value=push_something()) - - ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push_something") - with pytest.raises(exception_type) as ctx: - ti.run() - - assert dag_maker.session.query(TaskMap).count() == 0 - assert ti.state == TaskInstanceState.FAILED - assert str(ctx.value) == error_message - - @pytest.mark.parametrize( - "return_value, exception_type, error_message", - [ - (123, UnmappableXComTypePushed, "unmappable return type 'int'"), - (None, XComForMappingNotPushed, "did not push XCom for task mapping"), - ], - ) - def test_expand_kwargs_error_if_unmappable_type( - self, - dag_maker, - return_value, - exception_type, - error_message, - ): - """If an unmappable return value is used for expand_kwargs(), fail the task that pushed the XCom.""" - with dag_maker(dag_id="test_expand_kwargs_error_if_unmappable_type") as dag: - - @dag.task() - def push(): - return return_value - - MockOperator.partial(task_id="pull").expand_kwargs(push()) - - ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push") - with pytest.raises(exception_type) as ctx: - ti.run() - - assert dag_maker.session.query(TaskMap).count() == 0 - assert ti.state == TaskInstanceState.FAILED - assert str(ctx.value) == error_message - - @pytest.mark.parametrize( - "return_value, exception_type, error_message", - [ - (123, UnmappableXComTypePushed, "unmappable return type 'int'"), - (None, XComForMappingNotPushed, "did not push XCom for task mapping"), - ], - ) - def test_task_group_expand_error_if_unmappable_type( - self, - dag_maker, - return_value, - exception_type, - error_message, - ): - """If an unmappable return value is used , fail the task that pushed the XCom.""" - with dag_maker(dag_id="test_task_group_expand_error_if_unmappable_type") as dag: - - @dag.task() - def push(): - return return_value - - @task_group - def tg(arg): - MockOperator(task_id="pull", arg1=arg) - - tg.expand(arg=push()) - - ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push") - with pytest.raises(exception_type) as ctx: - ti.run() - - assert dag_maker.session.query(TaskMap).count() == 0 - assert ti.state == TaskInstanceState.FAILED - assert str(ctx.value) == error_message - - @pytest.mark.parametrize( - "return_value, exception_type, error_message", - [ - (123, UnmappableXComTypePushed, "unmappable return type 'int'"), - (None, XComForMappingNotPushed, "did not push XCom for task mapping"), - ], - ) - def test_task_group_expand_kwargs_error_if_unmappable_type( - self, - dag_maker, - return_value, - exception_type, - error_message, - ): - """If an unmappable return value is used, fail the task that pushed the XCom.""" - with dag_maker(dag_id="test_task_group_expand_kwargs_error_if_unmappable_type") as dag: - - @dag.task() - def push(): - return return_value - - @task_group - def tg(arg): - MockOperator(task_id="pull", arg1=arg) - - tg.expand_kwargs(push()) - - ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push") - with pytest.raises(exception_type) as ctx: - ti.run() - - assert dag_maker.session.query(TaskMap).count() == 0 - assert ti.state == TaskInstanceState.FAILED - assert str(ctx.value) == error_message - - @pytest.mark.parametrize( - "create_upstream", - [ - # The task returns an invalid expand_kwargs() input (a list[int] instead of list[dict]). - pytest.param(lambda: task(task_id="push")(lambda: [0])(), id="normal"), - # This task returns a list[dict] (correct), but we use map() to transform it to list[int] (wrong). - pytest.param(lambda: task(task_id="push")(lambda: [{"v": ""}])().map(lambda _: 0), id="mapped"), - ], - ) - def test_expand_kwargs_error_if_received_invalid(self, dag_maker, session, create_upstream): - with dag_maker(dag_id="test_expand_kwargs_error_if_received_invalid", session=session): - push_task = create_upstream() - - @task() - def pull(v): - print(v) - - pull.expand_kwargs(push_task) - - dr = dag_maker.create_dagrun() - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis - for ti in decision.schedulable_tis: - ti.run() - - # Run "pull". - decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis - for ti in decision.schedulable_tis: - with pytest.raises(ValueError) as ctx: - ti.run() - assert str(ctx.value) == "expand_kwargs() expects a list[dict], not list[int]" - - @pytest.mark.parametrize( - "downstream, error_message", - [ - ("taskflow", "mapping already partial argument: arg2"), - ("classic", "unmappable or already specified argument: arg2"), - ], - ids=["taskflow", "classic"], - ) - @pytest.mark.parametrize("strict", [True, False], ids=["strict", "override"]) - def test_expand_kwargs_override_partial(self, dag_maker, session, downstream, error_message, strict): - class ClassicOperator(MockOperator): - def execute(self, context): - return (self.arg1, self.arg2) - - with dag_maker(dag_id="test_expand_kwargs_override_partial", session=session) as dag: - - @dag.task() - def push(): - return [{"arg1": "a"}, {"arg1": "b", "arg2": "c"}] - - push_task = push() - - ClassicOperator.partial(task_id="classic", arg2="d").expand_kwargs(push_task, strict=strict) - - @dag.task(task_id="taskflow") - def pull(arg1, arg2): - return (arg1, arg2) - - pull.partial(arg2="d").expand_kwargs(push_task, strict=strict) - - dr = dag_maker.create_dagrun() - next(ti for ti in dr.task_instances if ti.task_id == "push").run() - - decision = dr.task_instance_scheduling_decisions(session=session) - tis = {(ti.task_id, ti.map_index, ti.state): ti for ti in decision.schedulable_tis} - assert sorted(tis) == [ - ("classic", 0, None), - ("classic", 1, None), - ("taskflow", 0, None), - ("taskflow", 1, None), - ] - - ti = tis[(downstream, 0, None)] - ti.run() - ti.xcom_pull(task_ids=downstream, map_indexes=0, session=session) == ["a", "d"] - - ti = tis[(downstream, 1, None)] - if strict: - with pytest.raises(TypeError) as ctx: - ti.run() - assert str(ctx.value) == error_message - else: - ti.run() - ti.xcom_pull(task_ids=downstream, map_indexes=1, session=session) == ["b", "c"] - - def test_error_if_upstream_does_not_push(self, dag_maker): - """Fail the upstream task if it fails to push the XCom used for task mapping.""" - with dag_maker(dag_id="test_not_recorded_for_unused") as dag: - - @dag.task(do_xcom_push=False) - def push_something(): - return [1, 2] - - @dag.task() - def pull_something(value): - print(value) - - pull_something.expand(value=push_something()) - - ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push_something") - with pytest.raises(XComForMappingNotPushed) as ctx: - ti.run() - - assert dag_maker.session.query(TaskMap).count() == 0 - assert ti.state == TaskInstanceState.FAILED - assert str(ctx.value) == "did not push XCom for task mapping" - - @conf_vars({("core", "max_map_length"): "1"}) - def test_error_if_unmappable_length(self, dag_maker): - """If an unmappable return value is used to map, fail the task that pushed the XCom.""" - with dag_maker(dag_id="test_not_recorded_for_unused") as dag: - - @dag.task() - def push_something(): - return [1, 2] - - @dag.task() - def pull_something(value): - print(value) - - pull_something.expand(value=push_something()) - - ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push_something") - with pytest.raises(UnmappableXComLengthPushed) as ctx: - ti.run() - - assert dag_maker.session.query(TaskMap).count() == 0 - assert ti.state == TaskInstanceState.FAILED - assert str(ctx.value) == "unmappable return value length: 2 > 1" - - @pytest.mark.parametrize( - "xcom_value, expected_length, expected_keys", - [ - ([1, 2, 3], 3, None), - ({"a": 1, "b": 2}, 2, ["a", "b"]), - ], - ) - def test_written_task_map(self, dag_maker, xcom_value, expected_length, expected_keys): - """Return value should be recorded in TaskMap if it's used by a downstream to map.""" - with dag_maker(dag_id="test_written_task_map") as dag: - - @dag.task() - def push_something(): - return xcom_value - - @dag.task() - def pull_something(value): - print(value) - - pull_something.expand(value=push_something()) - - dag_run = dag_maker.create_dagrun() - ti = next(ti for ti in dag_run.task_instances if ti.task_id == "push_something") - ti.run() - - task_map = dag_maker.session.query(TaskMap).one() - assert task_map.dag_id == "test_written_task_map" - assert task_map.task_id == "push_something" - assert task_map.run_id == dag_run.run_id - assert task_map.map_index == -1 - assert task_map.length == expected_length - assert task_map.keys == expected_keys - - @pytest.mark.xfail( - reason="not clear what this is really testing; " - "there's no API for removing a task; " - "and when a serialized dag is there, this fails; " - "and we need a serialized dag for dag.clear to work now" - ) - def test_no_error_on_changing_from_non_mapped_to_mapped(self, dag_maker, session): - """If a task changes from non-mapped to mapped, don't fail on integrity error.""" - with dag_maker(dag_id="test_no_error_on_changing_from_non_mapped_to_mapped") as dag: - - @dag.task() - def add_one(x): - return [x + 1] - - @dag.task() - def add_two(x): - return x + 2 - - task1 = add_one(2) - add_two.expand(x=task1) - - dr = dag_maker.create_dagrun() - ti = dr.get_task_instance(task_id="add_one") - ti.run() - assert ti.state == TaskInstanceState.SUCCESS - dag._remove_task("add_one") - with dag: - task1 = add_one.expand(x=[1, 2, 3]).operator - serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dr.dag = serialized_dag - dr.verify_integrity(session=session) - ti = dr.get_task_instance(task_id="add_one") - assert ti.state == TaskInstanceState.REMOVED - dag.clear() - ti.refresh_from_task(task1) - # This should not raise an integrity error - dr.task_instance_scheduling_decisions() - class TestMappedTaskInstanceReceiveValue: @pytest.mark.parametrize( @@ -4722,29 +2962,6 @@ def show(value): ti.run() assert outputs == expected_outputs - def test_map_has_dag_version(self, dag_maker, session): - from airflow.models.dag_version import DagVersion - - known_versions = [] - - with dag_maker(dag_id="test_89eug7u6f7y", session=session) as dag: - - @dag.task - def show(value, *, ti): - # let's record the dag version ids we observe on the tis - known_versions.append(ti.dag_version_id) - - show.expand(value=[1, 2, 3]) - # get the dag version for the dag - dag_version = session.scalar(select(DagVersion).where(DagVersion.dag_id == dag.dag_id)) - dag_maker.create_dagrun(session=session) - task = dag.get_task("show") - for ti in session.scalars(select(TI)): - ti.refresh_from_task(task) - ti.run(session=session) - # verify that we only saw the dag version we created - assert known_versions == [dag_version.id] * 3 - @pytest.mark.parametrize( "upstream_return, expected_outputs", [ diff --git a/airflow-core/tests/unit/ti_deps/deps/test_not_previously_skipped_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_not_previously_skipped_dep.py index 11d913e1ec568..d22dc6f1a5825 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_not_previously_skipped_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_not_previously_skipped_dep.py @@ -20,7 +20,6 @@ import pendulum import pytest -from airflow.exceptions import DownstreamTasksSkipped from airflow.models import DagRun, TaskInstance from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import BranchPythonOperator @@ -130,10 +129,7 @@ def test_parent_skip_branch(session, dag_maker): ti.task_id: ti for ti in dag_maker.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING).task_instances } - with pytest.raises(DownstreamTasksSkipped) as exc_info: - tis["op1"].run() - - assert exc_info.value.tasks == [("op2", -1)] + tis["op1"].run() dep = NotPreviouslySkippedDep() assert len(list(dep.get_dep_statuses(tis["op2"], session, DepContext()))) == 1 diff --git a/airflow-core/tests/unit/utils/test_log_handlers.py b/airflow-core/tests/unit/utils/test_log_handlers.py index e9a0a1823cf19..6d06137573973 100644 --- a/airflow-core/tests/unit/utils/test_log_handlers.py +++ b/airflow-core/tests/unit/utils/test_log_handlers.py @@ -105,6 +105,7 @@ def test_default_task_logging_setup(self): handler = handlers[0] assert handler.name == FILE_TASK_HANDLER + @pytest.mark.xfail(reason="TODO: Needs to be ported over to the new structlog based logging") def test_file_task_handler_when_ti_value_is_invalid(self, dag_maker): def task_callable(ti): ti.log.info("test") @@ -149,6 +150,7 @@ def task_callable(ti): # Remove the generated tmp log file. os.remove(log_filename) + @pytest.mark.xfail(reason="TODO: Needs to be ported over to the new structlog based logging") def test_file_task_handler(self, dag_maker, session): def task_callable(ti): ti.log.info("test") diff --git a/devel-common/src/tests_common/test_utils/version_compat.py b/devel-common/src/tests_common/test_utils/version_compat.py index 7227de2d85962..4877ced6c8eca 100644 --- a/devel-common/src/tests_common/test_utils/version_compat.py +++ b/devel-common/src/tests_common/test_utils/version_compat.py @@ -33,5 +33,6 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_2_10_PLUS = get_base_airflow_version_tuple() >= (2, 10, 0) +AIRFLOW_V_3_0_1 = get_base_airflow_version_tuple() == (3, 0, 1) AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) [].sort() diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py index cc63b8b15b310..e467bcda87633 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py @@ -117,7 +117,7 @@ def test_read_from_k8s_under_multi_namespace_mode( mock_list_pod = mock_kube_client.return_value.list_namespaced_pod def task_callable(ti): - ti.log.info("test") + ti.task.log.info("test") with DAG("dag_for_testing_file_task_handler", schedule=None, start_date=DEFAULT_DATE) as dag: task = PythonOperator( @@ -146,7 +146,7 @@ def task_callable(ti): ti.executor = "KubernetesExecutor" logger = ti.log - ti.log.disabled = False + ti.task.log.disabled = False file_handler = next((h for h in logger.handlers if h.name == FILE_TASK_HANDLER), None) set_context(logger, ti) diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py index c27e94773673f..f7cba7e6a00b1 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py @@ -48,7 +48,7 @@ from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker from tests_common.test_utils.providers import get_provider_min_airflow_version -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: from airflow.models.xcom import XComModel as XCom @@ -1202,7 +1202,7 @@ def test_branch_single_value_with_dag_run(self, mock_get_db_hook, branch_op): mock_get_records.return_value = 1 - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -1250,7 +1250,7 @@ def test_branch_true_with_dag_run(self, mock_get_db_hook, true_value, branch_op) mock_get_records = mock_get_db_hook.return_value.get_first mock_get_records.return_value = true_value - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -1298,7 +1298,7 @@ def test_branch_false_with_dag_run(self, mock_get_db_hook, false_value, branch_o mock_get_records = mock_get_db_hook.return_value.get_first mock_get_records.return_value = false_value - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -1355,7 +1355,7 @@ def test_branch_list_with_dag_run(self, mock_get_db_hook): mock_get_records = mock_get_db_hook.return_value.get_first mock_get_records.return_value = [["1"]] - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -1473,7 +1473,7 @@ def test_with_skip_in_branch_downstream_dependencies2(self, mock_get_db_hook, fa mock_get_records = mock_get_db_hook.return_value.get_first mock_get_records.return_value = [false_value] - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_adx.py b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_adx.py index ec63bba5b2a30..ec83ed1863f6c 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_adx.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_adx.py @@ -25,6 +25,7 @@ from airflow.models import DAG from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook from airflow.providers.microsoft.azure.operators.adx import AzureDataExplorerQueryOperator +from airflow.providers.microsoft.azure.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.timezone import datetime TEST_DAG_ID = "unit_tests" @@ -88,12 +89,20 @@ def test_azure_data_explorer_query_operator_xcom_push_and_pull( mock_conn, mock_run_query, create_task_instance_of_operator, + request, ): - ti = create_task_instance_of_operator( - AzureDataExplorerQueryOperator, - dag_id="test_azure_data_explorer_query_operator_xcom_push_and_pull", - **MOCK_DATA, - ) - ti.run() - - assert ti.xcom_pull(task_ids=MOCK_DATA["task_id"]) == str(MOCK_RESULT) + if AIRFLOW_V_3_0_PLUS: + run_task = request.getfixturevalue("run_task") + task = AzureDataExplorerQueryOperator(**MOCK_DATA) + run_task(task=task) + + assert run_task.xcom.get(key="return_value", task_id=task.task_id) == str(MOCK_RESULT) + else: + ti = create_task_instance_of_operator( + AzureDataExplorerQueryOperator, + dag_id="test_azure_data_explorer_query_operator_xcom_push_and_pull", + **MOCK_DATA, + ) + ti.run() + + assert ti.xcom_pull(task_ids=MOCK_DATA["task_id"]) == str(MOCK_RESULT) diff --git a/providers/oracle/tests/unit/oracle/operators/test_oracle.py b/providers/oracle/tests/unit/oracle/operators/test_oracle.py index 2f06e1513472a..02bc3f391de11 100644 --- a/providers/oracle/tests/unit/oracle/operators/test_oracle.py +++ b/providers/oracle/tests/unit/oracle/operators/test_oracle.py @@ -27,6 +27,8 @@ from airflow.providers.oracle.hooks.oracle import OracleHook from airflow.providers.oracle.operators.oracle import OracleStoredProcedureOperator +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + class TestOracleStoredProcedureOperator: @mock.patch.object(OracleHook, "run", autospec=OracleHook.run) @@ -65,12 +67,20 @@ def test_push_oracle_exit_to_xcom(self, mock_callproc, request, dag_maker): error = f"ORA-{ora_exit_code}: This is a five-digit ORA error code" mock_callproc.side_effect = oracledb.DatabaseError(error) - with dag_maker(dag_id=f"dag_{request.node.name}"): + if AIRFLOW_V_3_0_PLUS: + run_task = request.getfixturevalue("run_task") task = OracleStoredProcedureOperator( procedure=procedure, oracle_conn_id=oracle_conn_id, parameters=parameters, task_id=task_id ) - dr = dag_maker.create_dagrun(run_id=task_id) - ti = TaskInstance(task=task, run_id=dr.run_id) - with pytest.raises(oracledb.DatabaseError, match=re.escape(error)): - ti.run() - assert ti.xcom_pull(task_ids=task.task_id, key="ORA") == ora_exit_code + run_task(task=task) + assert run_task.xcom.get(task_id=task.task_id, key="ORA") == ora_exit_code + else: + with dag_maker(dag_id=f"dag_{request.node.name}"): + task = OracleStoredProcedureOperator( + procedure=procedure, oracle_conn_id=oracle_conn_id, parameters=parameters, task_id=task_id + ) + dr = dag_maker.create_dagrun(run_id=task_id) + ti = TaskInstance(task=task, run_id=dr.run_id) + with pytest.raises(oracledb.DatabaseError, match=re.escape(error)): + ti.run() + assert ti.xcom_pull(task_ids=task.task_id, key="ORA") == ora_exit_code diff --git a/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py index b14b6bd5c0df1..f51b5d6a6acca 100644 --- a/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py @@ -24,6 +24,7 @@ import pytest from airflow.decorators import task +from airflow.providers.snowflake.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils import timezone if TYPE_CHECKING: @@ -156,7 +157,7 @@ def func(session: Session): mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once() @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook") - def test_snowpark_decorator_multiple_output(self, mock_snowflake_hook, dag_maker): + def test_snowpark_decorator_multiple_output(self, mock_snowflake_hook, dag_maker, request): @task.snowpark( task_id=TASK_ID, snowflake_conn_id=CONN_ID, @@ -171,15 +172,23 @@ def func(session: Session): assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value return {"a": 1, "b": "2"} - with dag_maker(dag_id=TEST_DAG_ID): - ret = func() - - dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - ti = dr.get_task_instances()[0] - assert ti.xcom_pull(key="a") == 1 - assert ti.xcom_pull(key="b") == "2" - assert ti.xcom_pull() == {"a": 1, "b": "2"} + if AIRFLOW_V_3_0_PLUS: + run_task = request.getfixturevalue("run_task") + op = func().operator + run_task(task=op) + assert run_task.xcom.get(key="a") == 1 + assert run_task.xcom.get(key="b") == "2" + assert run_task.xcom.get(key="return_value") == {"a": 1, "b": "2"} + else: + with dag_maker(dag_id=TEST_DAG_ID): + ret = func() + + dr = dag_maker.create_dagrun() + ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + ti = dr.get_task_instances()[0] + assert ti.xcom_pull(key="a") == 1 + assert ti.xcom_pull(key="b") == "2" + assert ti.xcom_pull() == {"a": 1, "b": "2"} mock_snowflake_hook.assert_called_once() mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once() diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py index 1f6759ef8b89f..f65b7e0ea6c95 100644 --- a/providers/standard/src/airflow/providers/standard/operators/python.py +++ b/providers/standard/src/airflow/providers/standard/operators/python.py @@ -492,9 +492,21 @@ def get_python_source(self): return textwrap.dedent(inspect.getsource(self.python_callable)) def _write_args(self, file: Path): + def resolve_proxies(obj): + """Recursively replaces lazy_object_proxy.Proxy instances with their resolved values.""" + if isinstance(obj, lazy_object_proxy.Proxy): + return obj.__wrapped__ # force evaluation + if isinstance(obj, dict): + return {k: resolve_proxies(v) for k, v in obj.items()} + if isinstance(obj, list): + return [resolve_proxies(v) for v in obj] + return obj + if self.op_args or self.op_kwargs: self.log.info("Use %r as serializer.", self.serializer) - file.write_bytes(self.pickling_library.dumps({"args": self.op_args, "kwargs": self.op_kwargs})) + file.write_bytes( + self.pickling_library.dumps({"args": self.op_args, "kwargs": resolve_proxies(self.op_kwargs)}) + ) def _write_string_args(self, file: Path): file.write_text("\n".join(map(str, self.string_args))) diff --git a/providers/standard/tests/unit/standard/decorators/test_branch_external_python.py b/providers/standard/tests/unit/standard/decorators/test_branch_external_python.py index 43e9ae1d91a37..f0283c0307493 100644 --- a/providers/standard/tests/unit/standard/decorators/test_branch_external_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_branch_external_python.py @@ -22,12 +22,13 @@ import pytest from airflow.decorators import task -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.utils.state import State -if AIRFLOW_V_3_0_PLUS: +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1 + +if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped -else: - from airflow.utils.state import State + pytestmark = pytest.mark.db_test @@ -79,7 +80,7 @@ def branch_operator(): dr = dag_maker.create_dagrun() df.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: branchoperator.operator.run( start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True diff --git a/providers/standard/tests/unit/standard/decorators/test_branch_python.py b/providers/standard/tests/unit/standard/decorators/test_branch_python.py index a78050b6a3cfa..3d8a46d7a37cd 100644 --- a/providers/standard/tests/unit/standard/decorators/test_branch_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_branch_python.py @@ -20,12 +20,12 @@ import pytest from airflow.decorators import task -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.utils.state import State -if AIRFLOW_V_3_0_PLUS: +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1 + +if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped -else: - from airflow.utils.state import State pytestmark = pytest.mark.db_test @@ -67,7 +67,7 @@ def branch_operator(): dr = dag_maker.create_dagrun() df.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: branchoperator.operator.run( start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True diff --git a/providers/standard/tests/unit/standard/decorators/test_branch_virtualenv.py b/providers/standard/tests/unit/standard/decorators/test_branch_virtualenv.py index ab616b37435cd..170916c21a31b 100644 --- a/providers/standard/tests/unit/standard/decorators/test_branch_virtualenv.py +++ b/providers/standard/tests/unit/standard/decorators/test_branch_virtualenv.py @@ -22,12 +22,13 @@ import pytest from airflow.decorators import task -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.utils.state import State -if AIRFLOW_V_3_0_PLUS: +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped -else: - from airflow.utils.state import State + pytestmark = pytest.mark.db_test @@ -95,7 +96,7 @@ def branch_operator(): dr = dag_maker.create_dagrun() df.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: branchoperator.operator.run( start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py index 8dfc7b4d8e40f..ed8d5ef0cdb97 100644 --- a/providers/standard/tests/unit/standard/decorators/test_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_python.py @@ -33,7 +33,7 @@ from airflow.utils.types import DagRunType from airflow.utils.xcom import XCOM_RETURN_KEY -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS from unit.standard.operators.test_python import BasePythonTest if AIRFLOW_V_3_0_PLUS: @@ -215,7 +215,22 @@ def identity_notyping_with_decorator_call(x: int): assert identity_notyping_with_decorator_call(5).operator.multiple_outputs is False - def test_manual_multiple_outputs_false_with_typings(self): + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2") + def test_manual_multiple_outputs_false_with_typings(self, run_task): + @task_decorator(multiple_outputs=False) + def identity2(x: int, y: int) -> tuple[int, int]: + return x, y + + res = identity2(8, 4) + run_task(task=res.operator) + + assert not res.operator.multiple_outputs + assert run_task.xcom.get(key=res.key) == (8, 4) + assert run_task.xcom.get(key="return_value_0") is None + assert run_task.xcom.get(key="return_value_1") is None + + @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for AF 3") + def test_manual_multiple_outputs_false_with_typings_af2(self): @task_decorator(multiple_outputs=False) def identity2(x: int, y: int) -> tuple[int, int]: return x, y @@ -233,7 +248,22 @@ def identity2(x: int, y: int) -> tuple[int, int]: assert ti.xcom_pull(key="return_value_0") is None assert ti.xcom_pull(key="return_value_1") is None - def test_multiple_outputs_ignore_typing(self): + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2") + def test_multiple_outputs_ignore_typing(self, run_task): + @task_decorator + def identity_tuple(x: int, y: int) -> tuple[int, int]: + return x, y + + ident = identity_tuple(35, 36) + run_task(task=ident.operator) + + assert not ident.operator.multiple_outputs + assert run_task.xcom.get(key=ident.key) == (35, 36) + assert run_task.xcom.get(key="return_value_0") is None + assert run_task.xcom.get(key="return_value_1") is None + + @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for AF 3") + def test_multiple_outputs_ignore_typing_af2(self): @task_decorator def identity_tuple(x: int, y: int) -> tuple[int, int]: return x, y @@ -296,7 +326,9 @@ def add_number(num: int): ret = add_number(2) self.create_dag_run() - with pytest.raises(AirflowException): + + error_expected = AirflowException if (not AIRFLOW_V_3_0_PLUS or AIRFLOW_V_3_0_1) else TypeError + with pytest.raises(error_expected): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) def test_fail_multiple_outputs_no_dict(self): @@ -308,7 +340,8 @@ def add_number(num: int): ret = add_number(2) self.create_dag_run() - with pytest.raises(AirflowException): + error_expected = AirflowException if (not AIRFLOW_V_3_0_PLUS or AIRFLOW_V_3_0_1) else TypeError + with pytest.raises(error_expected): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) def test_multiple_outputs_empty_dict(self): diff --git a/providers/standard/tests/unit/standard/decorators/test_short_circuit.py b/providers/standard/tests/unit/standard/decorators/test_short_circuit.py index 3992870b52f09..3ead1c252bbb3 100644 --- a/providers/standard/tests/unit/standard/decorators/test_short_circuit.py +++ b/providers/standard/tests/unit/standard/decorators/test_short_circuit.py @@ -21,10 +21,11 @@ from pendulum import datetime from airflow.decorators import task -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.state import State from airflow.utils.trigger_rule import TriggerRule +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS + if AIRFLOW_V_3_0_PLUS: from airflow.exceptions import DownstreamTasksSkipped @@ -34,8 +35,8 @@ DEFAULT_DATE = datetime(2022, 8, 17) -@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test doesn't run on AF3. Companion test below.") -def test_short_circuit_decorator_af2(dag_maker): +@pytest.mark.skipif(AIRFLOW_V_3_0_1, reason="Test doesn't run on AF 3.0.1. Companion test below.") +def test_short_circuit_decorator(dag_maker): with dag_maker(serialized=True): @task @@ -82,9 +83,9 @@ def short_circuit(condition): assert ti.state == task_state_mapping[ti.task_id] -@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test only runs on AF3") +@pytest.mark.skipif(not AIRFLOW_V_3_0_1, reason="Test only runs on AF3.0.1") @pytest.mark.parametrize(["condition", "should_be_skipped"], [(True, False), (False, True)]) -def test_short_circuit_decorator_af3(dag_maker, session, condition, should_be_skipped): +def test_short_circuit_decorator_af301(dag_maker, session, condition, should_be_skipped): with dag_maker(serialized=True, session=session): @task.short_circuit() @@ -112,7 +113,7 @@ def empty(): ... ti_sc.run() -@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test only runs on AF3") +@pytest.mark.skipif(not AIRFLOW_V_3_0_1, reason="Test only runs on AF3.0.1") @pytest.mark.parametrize( ["ignore_downstream_trigger_rules", "expected"], [(True, State.SKIPPED), (False, State.SUCCESS)] ) diff --git a/providers/standard/tests/unit/standard/operators/test_branch_operator.py b/providers/standard/tests/unit/standard/operators/test_branch_operator.py index c38a1074bcaa9..821e7cfb9c675 100644 --- a/providers/standard/tests/unit/standard/operators/test_branch_operator.py +++ b/providers/standard/tests/unit/standard/operators/test_branch_operator.py @@ -24,13 +24,14 @@ from airflow.models.taskinstance import TaskInstance as TI from airflow.providers.standard.operators.branch import BaseBranchOperator from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.providers.standard.utils.skipmixin import XCOM_SKIPMIXIN_FOLLOWED, XCOM_SKIPMIXIN_KEY from airflow.timetables.base import DataInterval from airflow.utils import timezone from airflow.utils.state import State from airflow.utils.task_group import TaskGroup from airflow.utils.types import DagRunType -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: from airflow.exceptions import DownstreamTasksSkipped @@ -75,7 +76,7 @@ def test_without_dag_run(self, dag_maker): branch_2.set_upstream(branch_op) dag_maker.create_dagrun(**triggered_by_kwargs) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -83,9 +84,9 @@ def test_without_dag_run(self, dag_maker): else: branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - for ti in dag_maker.session.query(TI).filter( - TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE - ): + ti_date = TI.logical_date if AIRFLOW_V_3_0_PLUS else TI.execution_date + + for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, ti_date == DEFAULT_DATE): if ti.task_id == "make_choice": assert ti.state == State.SUCCESS elif ti.task_id == "branch_1": @@ -115,7 +116,7 @@ def test_branch_list_without_dag_run(self, dag_maker): branch_3.set_upstream(branch_op) dag_maker.create_dagrun(**triggered_by_kwargs) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -130,9 +131,9 @@ def test_branch_list_without_dag_run(self, dag_maker): "branch_3": State.SKIPPED, } - for ti in dag_maker.session.query(TI).filter( - TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE - ): + ti_date = TI.logical_date if AIRFLOW_V_3_0_PLUS else TI.execution_date + + for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, ti_date == DEFAULT_DATE): if ti.task_id in expected: assert ti.state == expected[ti.task_id] else: @@ -152,7 +153,7 @@ def test_with_dag_run(self, dag_maker): branch_op = ChooseBranchOne(task_id="make_choice") branch_1.set_upstream(branch_op) branch_2.set_upstream(branch_op) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: dag_maker.create_dagrun( run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), @@ -184,9 +185,9 @@ def test_with_dag_run(self, dag_maker): "branch_2": State.SKIPPED, } - for ti in dag_maker.session.query(TI).filter( - TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE - ): + ti_date = TI.logical_date if AIRFLOW_V_3_0_PLUS else TI.execution_date + + for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, ti_date == DEFAULT_DATE): if ti.task_id in expected: assert ti.state == expected[ti.task_id] else: @@ -244,7 +245,15 @@ def test_with_skip_in_branch_downstream_dependencies(self, dag_maker): def test_xcom_push(self, dag_maker): dag_id = "branch_operator_test" - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} + + triggered_by_kwargs = ( + { + "triggered_by": DagRunTriggeredByType.TEST, + "logical_date": DEFAULT_DATE, + } + if AIRFLOW_V_3_0_PLUS + else {"execution_date": DEFAULT_DATE} + ) with dag_maker( dag_id, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, @@ -257,36 +266,25 @@ def test_xcom_push(self, dag_maker): branch_1.set_upstream(branch_op) branch_2.set_upstream(branch_op) - if AIRFLOW_V_3_0_PLUS: - dag_maker.create_dagrun( - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - logical_date=DEFAULT_DATE, - state=State.RUNNING, - data_interval=DataInterval(DEFAULT_DATE, DEFAULT_DATE), - **triggered_by_kwargs, - ) + dr = dag_maker.create_dagrun( + run_type=DagRunType.MANUAL, + start_date=timezone.utcnow(), + state=State.RUNNING, + data_interval=DataInterval(DEFAULT_DATE, DEFAULT_DATE), + **triggered_by_kwargs, + ) + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) assert exc_info.value.tasks == [("branch_2", -1)] else: - dag_maker.create_dagrun( - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING, - data_interval=DataInterval(DEFAULT_DATE, DEFAULT_DATE), - **triggered_by_kwargs, - ) branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - ti_date = TI.logical_date if AIRFLOW_V_3_0_PLUS else TI.execution_date - - for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, ti_date == DEFAULT_DATE): - if ti.task_id == "make_choice": - assert ti.xcom_pull(task_ids="make_choice") == "branch_1" + branch_op_ti = dr.get_task_instance(branch_op.task_id) + assert branch_op_ti.xcom_pull(task_ids="make_choice", key=XCOM_SKIPMIXIN_KEY) == { + XCOM_SKIPMIXIN_FOLLOWED: ["branch_1"] + } def test_with_dag_run_task_groups(self, dag_maker): dag_id = "branch_operator_test" @@ -307,7 +305,7 @@ def test_with_dag_run_task_groups(self, dag_maker): branch_2.set_upstream(branch_op) branch_3.set_upstream(branch_op) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: dag_maker.create_dagrun( run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), @@ -332,9 +330,9 @@ def test_with_dag_run_task_groups(self, dag_maker): ) branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - for ti in dag_maker.session.query(TI).filter( - TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE - ): + ti_date = TI.logical_date if AIRFLOW_V_3_0_PLUS else TI.execution_date + + for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, ti_date == DEFAULT_DATE): if ti.task_id == "make_choice": assert ti.state == State.SUCCESS elif ti.task_id == "branch_1": diff --git a/providers/standard/tests/unit/standard/operators/test_datetime.py b/providers/standard/tests/unit/standard/operators/test_datetime.py index eab06756610d9..0c6e40381793f 100644 --- a/providers/standard/tests/unit/standard/operators/test_datetime.py +++ b/providers/standard/tests/unit/standard/operators/test_datetime.py @@ -32,7 +32,7 @@ from airflow.utils.session import create_session from airflow.utils.state import State -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS pytestmark = pytest.mark.db_test @@ -124,7 +124,7 @@ def test_branch_datetime_operator_falls_within_range(self, target_lower, target_ """Check BranchDateTimeOperator branch operation""" self.branch_op.target_lower = target_lower self.branch_op.target_upper = target_upper - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -157,7 +157,7 @@ def test_branch_datetime_operator_falls_outside_range(self, date, target_lower, self.branch_op.target_lower = target_lower self.branch_op.target_upper = target_upper - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info, time_machine.travel(date): @@ -183,7 +183,7 @@ def test_branch_datetime_operator_upper_comparison_within_range(self, target_upp self.branch_op.target_upper = target_upper self.branch_op.target_lower = None - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -208,7 +208,7 @@ def test_branch_datetime_operator_lower_comparison_within_range(self, target_low self.branch_op.target_lower = target_lower self.branch_op.target_upper = None - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -233,7 +233,7 @@ def test_branch_datetime_operator_upper_comparison_outside_range(self, target_up self.branch_op.target_upper = target_upper self.branch_op.target_lower = None - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -258,7 +258,7 @@ def test_branch_datetime_operator_lower_comparison_outside_range(self, target_lo self.branch_op.target_lower = target_lower self.branch_op.target_upper = None - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -295,7 +295,7 @@ def test_branch_datetime_operator_use_task_logical_date(self, dag_maker, target_ self.branch_op.target_lower = target_lower self.branch_op.target_upper = target_upper - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: diff --git a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py index b976f41fa9a2e..fce99a64b8278 100644 --- a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py +++ b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py @@ -18,6 +18,7 @@ from __future__ import annotations import datetime +import operator import pytest import time_machine @@ -33,7 +34,7 @@ from airflow.utils.types import DagRunType from tests_common.test_utils.db import clear_db_runs, clear_db_xcom -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: from airflow.sdk import DAG @@ -115,9 +116,7 @@ def test_skipping_non_latest(self, dag_maker): start_date=timezone.utcnow(), logical_date=timezone.datetime(2016, 1, 1, 12), state=State.RUNNING, - data_interval=DataInterval( - timezone.datetime(2016, 1, 1, 12), timezone.datetime(2016, 1, 1, 12) + INTERVAL - ), + data_interval=DataInterval(timezone.datetime(2016, 1, 1, 12), timezone.datetime(2016, 1, 1, 12)), **triggered_by_kwargs, ) @@ -126,11 +125,11 @@ def test_skipping_non_latest(self, dag_maker): start_date=timezone.utcnow(), logical_date=END_DATE, state=State.RUNNING, - data_interval=DataInterval(END_DATE, END_DATE + INTERVAL), + data_interval=DataInterval(END_DATE + INTERVAL, END_DATE + INTERVAL), **triggered_by_kwargs, ) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped # AIP-72 @@ -145,6 +144,7 @@ def test_skipping_non_latest(self, dag_maker): latest_ti0.run() assert exc_info.value.tasks == [("downstream", -1)] + # TODO: Set state is needed until #45549 is completed. latest_ti0.set_state(State.SUCCESS) dr0.get_task_instance(task_id="downstream").set_state(State.SKIPPED) @@ -156,6 +156,7 @@ def test_skipping_non_latest(self, dag_maker): latest_ti1.run() assert exc_info.value.tasks == [("downstream", -1)] + # TODO: Set state is needed until #45549 is completed. latest_ti1.set_state(State.SUCCESS) dr1.get_task_instance(task_id="downstream").set_state(State.SKIPPED) @@ -164,78 +165,52 @@ def test_skipping_non_latest(self, dag_maker): latest_ti2 = dr2.get_task_instance(task_id="latest") latest_ti2.task = latest_task latest_ti2.run() - - latest_ti2.set_state(State.SUCCESS) - - # Verify the state of the other downstream tasks - downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) - downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE) - downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE) - - downstream_instances = get_task_instances("downstream") - exec_date_to_downstream_state = {ti.logical_date: ti.state for ti in downstream_instances} - assert exec_date_to_downstream_state == { - timezone.datetime(2016, 1, 1): "skipped", - timezone.datetime(2016, 1, 1, 12): "skipped", - timezone.datetime(2016, 1, 2): "success", - } - - downstream_instances = get_task_instances("downstream_2") - exec_date_to_downstream_state = {ti.logical_date: ti.state for ti in downstream_instances} - assert exec_date_to_downstream_state == { - timezone.datetime(2016, 1, 1): None, - timezone.datetime(2016, 1, 1, 12): None, - timezone.datetime(2016, 1, 2): "success", - } - - downstream_instances = get_task_instances("downstream_3") - exec_date_to_downstream_state = {ti.logical_date: ti.state for ti in downstream_instances} - assert exec_date_to_downstream_state == { - timezone.datetime(2016, 1, 1): "success", - timezone.datetime(2016, 1, 1, 12): "success", - timezone.datetime(2016, 1, 2): "success", - } - else: latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) - downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) - downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE) - downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE) + if AIRFLOW_V_3_0_PLUS: + date_getter = operator.attrgetter("logical_date") + else: + date_getter = operator.attrgetter("execution_date") - latest_instances = get_task_instances("latest") - exec_date_to_latest_state = {ti.execution_date: ti.state for ti in latest_instances} - assert exec_date_to_latest_state == { - timezone.datetime(2016, 1, 1): "success", - timezone.datetime(2016, 1, 1, 12): "success", - timezone.datetime(2016, 1, 2): "success", - } + latest_instances = get_task_instances("latest") + exec_date_to_latest_state = {date_getter(ti): ti.state for ti in latest_instances} + assert exec_date_to_latest_state == { + timezone.datetime(2016, 1, 1): "success", + timezone.datetime(2016, 1, 1, 12): "success", + timezone.datetime(2016, 1, 2): "success", + } - downstream_instances = get_task_instances("downstream") - exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances} - assert exec_date_to_downstream_state == { - timezone.datetime(2016, 1, 1): "skipped", - timezone.datetime(2016, 1, 1, 12): "skipped", - timezone.datetime(2016, 1, 2): "success", - } + # Verify the state of the other downstream tasks + downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) + downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE) + downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE) - downstream_instances = get_task_instances("downstream_2") - exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances} - assert exec_date_to_downstream_state == { - timezone.datetime(2016, 1, 1): None, - timezone.datetime(2016, 1, 1, 12): None, - timezone.datetime(2016, 1, 2): "success", - } + downstream_instances = get_task_instances("downstream") + exec_date_to_downstream_state = {date_getter(ti): ti.state for ti in downstream_instances} + assert exec_date_to_downstream_state == { + timezone.datetime(2016, 1, 1): "skipped", + timezone.datetime(2016, 1, 1, 12): "skipped", + timezone.datetime(2016, 1, 2): "success", + } - downstream_instances = get_task_instances("downstream_3") - exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances} - assert exec_date_to_downstream_state == { - timezone.datetime(2016, 1, 1): "success", - timezone.datetime(2016, 1, 1, 12): "success", - timezone.datetime(2016, 1, 2): "success", - } + downstream_instances = get_task_instances("downstream_2") + exec_date_to_downstream_state = {date_getter(ti): ti.state for ti in downstream_instances} + assert exec_date_to_downstream_state == { + timezone.datetime(2016, 1, 1): None, + timezone.datetime(2016, 1, 1, 12): None, + timezone.datetime(2016, 1, 2): "success", + } + + downstream_instances = get_task_instances("downstream_3") + exec_date_to_downstream_state = {date_getter(ti): ti.state for ti in downstream_instances} + assert exec_date_to_downstream_state == { + timezone.datetime(2016, 1, 1): "success", + timezone.datetime(2016, 1, 1, 12): "success", + timezone.datetime(2016, 1, 2): "success", + } - def test_not_skipping_external(self, dag_maker): + def test_not_skipping_manual(self, dag_maker): with dag_maker( default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, schedule=INTERVAL, serialized=True ): diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py index 1da002de755a0..4474795afbd50 100644 --- a/providers/standard/tests/unit/standard/operators/test_python.py +++ b/providers/standard/tests/unit/standard/operators/test_python.py @@ -44,6 +44,7 @@ from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG from airflow.exceptions import ( AirflowException, + AirflowProviderDeprecationWarning, DeserializingResultError, ) from airflow.models.baseoperator import BaseOperator @@ -70,7 +71,7 @@ from airflow.utils.types import NOTSET, DagRunType from tests_common.test_utils.db import clear_db_runs -from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: from airflow.models.dagrun import DagRun @@ -90,7 +91,7 @@ CLOUDPICKLE_INSTALLED = find_spec("cloudpickle") is not None CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED, reason="`cloudpickle` is not installed") -if AIRFLOW_V_3_0_PLUS: +if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped @@ -405,7 +406,7 @@ def f(): branch_op >> [self.branch_1, self.branch_2] dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as dts: branch_op.run(start_date=self.default_date, end_date=self.default_date) assert dts.value.tasks == [("branch_2", -1)] @@ -444,7 +445,7 @@ def f(): branch_op >> self.branch_2 dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as dts: branch_op.run(start_date=self.default_date, end_date=self.default_date) assert dts.value.tasks == [("branch_1", -1)] @@ -469,7 +470,7 @@ def f(): branch_op >> branches dr = dag_maker.create_dagrun() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with create_session() as session: @@ -502,7 +503,10 @@ def f(): tis = dr.get_task_instances() children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()] with create_session() as session: - clear_task_instances(children_tis, session=session, dag=branch_op.dag) + if AIRFLOW_V_3_0_PLUS: + clear_task_instances(children_tis, session=session) + else: + clear_task_instances(children_tis, session=session, dag=branch_op.dag) # Run the cleared tasks again. for task in branches: @@ -562,7 +566,7 @@ def f(): for task_id in task_ids: # Mimic the specific order the scheduling would run the tests. task_instance = tis[task_id] task_instance.refresh_from_task(self.dag_non_serialized.get_task(task_id)) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped try: @@ -720,7 +724,7 @@ def test_short_circuiting( self.op2.trigger_rule = test_trigger_rule dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped if expected_skipped_tasks: @@ -751,7 +755,7 @@ def test_clear_skipped_downstream_task(self): short_circuit >> self.op1 >> self.op2 dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with create_session() as session: @@ -786,9 +790,12 @@ def test_clear_skipped_downstream_task(self): # Clear downstream task "op1" that was previously executed. tis = dr.get_task_instances() with create_session() as session: - clear_task_instances( - [ti for ti in tis if ti.task_id == "op1"], session=session, dag=short_circuit.dag - ) + if AIRFLOW_V_3_0_PLUS: + clear_task_instances([ti for ti in tis if ti.task_id == "op1"], session=session) + else: + clear_task_instances( + [ti for ti in tis if ti.task_id == "op1"], session=session, dag=short_circuit.dag + ) self.op1.run(start_date=self.default_date, end_date=self.default_date) self.assert_expected_task_states(dr, expected_states) @@ -818,7 +825,7 @@ def test_xcom_push_skipped_tasks(self): empty_task = EmptyOperator(task_id="empty_task") short_op_push_xcom >> empty_task dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped): @@ -936,11 +943,10 @@ def test_virtualenv_serializable_context_fields(self, create_task_instance): "conn", # Accessor for Connection. "map_index_template", } - if AIRFLOW_V_2_10_PLUS: - intentionally_excluded_context_keys |= { - "inlet_events", - "outlet_events", - } + intentionally_excluded_context_keys |= { + "inlet_events", + "outlet_events", + } ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, schedule=None) context = ti.get_template_context() @@ -1155,7 +1161,9 @@ def f(): return True raise RuntimeError - self.run_as_task(f, system_site_packages=False, requirements=extra_requirements) + self.run_as_task( + f, system_site_packages=False, requirements=extra_requirements, serializer=serializer + ) def test_system_site_packages(self): def f(): @@ -1201,7 +1209,12 @@ def test_unpinned_requirements(self, serializer, extra_requirements): def f(): import funcsigs # noqa: F401 - self.run_as_task(f, requirements=["funcsigs", *extra_requirements], system_site_packages=False) + self.run_as_task( + f, + requirements=["funcsigs", *extra_requirements], + system_site_packages=False, + serializer=serializer, + ) @pytest.mark.parametrize( "serializer, extra_requirements", @@ -1216,7 +1229,12 @@ def test_range_requirements(self, serializer, extra_requirements): def f(): import funcsigs # noqa: F401 - self.run_as_task(f, requirements=["funcsigs>1.0", *extra_requirements], system_site_packages=False) + self.run_as_task( + f, + requirements=["funcsigs>1.0", *extra_requirements], + system_site_packages=False, + serializer=serializer, + ) def test_requirements_file(self): def f(): @@ -1360,7 +1378,6 @@ def f( params, run_id, task_instance_key_str, - test_mode, ts, ts_nodash, ts_nodash_with_tz, @@ -1396,7 +1413,6 @@ def f( outlets, run_id, task_instance_key_str, - test_mode, ts, ts_nodash, ts_nodash_with_tz, @@ -1428,7 +1444,6 @@ def f( outlets, run_id, task_instance_key_str, - test_mode, ts, ts_nodash, ts_nodash_with_tz, @@ -1439,6 +1454,103 @@ def f( self.run_as_task(f, serializer=serializer, system_site_packages=False, requirements=None) + @pytest.mark.parametrize( + "requirements, system_site, want_airflow, want_pendulum", + [ + # nothing → just base keys + ([], False, False, False), + # site-packages → base keys + pendulum keys + ([], True, True, True), + # apache-airflow / no version constraint + (["apache-airflow"], False, True, True), + # specific version + (["apache-airflow==2.10.2"], False, True, True), + # minimum version + (["apache-airflow>=2.10"], False, True, True), + # pendulum / no version constraint + (["pendulum"], False, False, True), + # compatible release + (["pendulum~=2.1.0"], False, False, True), + # other package + (["foo==1.0.0"], False, False, False), + # with other package + (["apache-airflow", "foo"], False, True, True), + # full-line comment only + (["# comment"], False, False, False), + # inline comment after requirement + (["apache-airflow==2.10.2 # comment"], False, True, True), + # blank line + requirement + (["", "pendulum"], False, False, True), + # indented comment + requirement + ([" # comment", "pendulum~=2.1.0"], False, False, True), + # requirements passed as multi-line strings + ("funcsigs==0.4\nattrs==23.1.0", False, False, False), + (["funcsigs==0.4\nattrs==23.1.0"], False, False, False), + ("pendulum==2.1.2 # pinned version\nattrs==23.1.0 # optional", False, False, True), + ], + ) + def test_iter_serializable_context_keys(self, requirements, system_site, want_airflow, want_pendulum): + def func(): + return "test_return_value" + + op = PythonVirtualenvOperator( + task_id="task", + python_callable=func, + requirements=requirements, + system_site_packages=system_site, + ) + keys = set(op._iter_serializable_context_keys()) + + base_keys = set(op.BASE_SERIALIZABLE_CONTEXT_KEYS) + airflow_keys = set(op.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS) + pendulum_keys = set(op.PENDULUM_SERIALIZABLE_CONTEXT_KEYS) + + # BASE keys always present + assert base_keys <= keys + + # AIRFLOW keys only when expected + if want_airflow: + assert airflow_keys <= keys, f"expected AIRFLOW keys for requirements: {requirements}" + else: + assert not (airflow_keys & keys), f"unexpected AIRFLOW keys for requirements: {requirements}" + + # PENDULUM keys only when expected + if want_pendulum: + assert pendulum_keys <= keys, f"expected PENDULUM keys for requirements: {requirements}" + else: + assert not (pendulum_keys & keys), f"unexpected PENDULUM keys for requirements: {requirements}" + + @pytest.mark.parametrize( + "invalid_requirement", + [ + # invalid version format + "pendulum==3..0", + # invalid operator (=< instead of <=) + "apache-airflow=<2.0", + # same invalid operator on pendulum + "pendulum=<3.0", + # totally malformed + "invalid requirement", + ], + ) + def test_iter_serializable_context_keys_invalid_requirement(self, invalid_requirement): + def func(): + return "test_return_value" + + op = PythonVirtualenvOperator( + task_id="task", + python_callable=func, + requirements=[invalid_requirement], + system_site_packages=False, + ) + + with pytest.raises(ValueError) as exc_info: + # Consume the generator to trigger parsing + list(op._iter_serializable_context_keys()) + + msg = str(exc_info.value) + assert f"Invalid requirement '{invalid_requirement}'" in msg + # when venv tests are run in parallel to other test they create new processes and this might take # quite some time in shared docker environment and get some contention even between different containers @@ -1628,7 +1740,7 @@ def f(): branch_op >> [self.branch_1, self.branch_2] dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as dts: branch_op.run(start_date=self.default_date, end_date=self.default_date) @@ -1669,7 +1781,7 @@ def f(): dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as dts: branch_op.run(start_date=self.default_date, end_date=self.default_date) @@ -1696,7 +1808,7 @@ def f(): dr = self.create_dag_run() - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with create_session() as session: @@ -1731,7 +1843,10 @@ def f(): tis = dr.get_task_instances() children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()] with create_session() as session: - clear_task_instances(children_tis, session=session, dag=branch_op.dag) + if AIRFLOW_V_3_0_PLUS: + clear_task_instances(children_tis, session=session) + else: + clear_task_instances(children_tis, session=session, dag=branch_op.dag) # Run the cleared tasks again. for task in branches: @@ -1799,22 +1914,35 @@ def default_kwargs(*, python_version=DEFAULT_PYTHON_VERSION, **kwargs): class TestCurrentContext: def test_current_context_no_context_raise(self): - with pytest.raises(RuntimeError): - get_current_context() + if AIRFLOW_V_3_0_PLUS: + with pytest.warns(AirflowProviderDeprecationWarning): + with pytest.raises(RuntimeError): + get_current_context() + else: + with pytest.raises(RuntimeError): + get_current_context() def test_current_context_roundtrip(self): example_context = {"Hello": "World"} - with set_current_context(example_context): - assert get_current_context() == example_context + if AIRFLOW_V_3_0_PLUS: + with pytest.warns(AirflowProviderDeprecationWarning): + assert get_current_context() == example_context + else: + assert get_current_context() == example_context def test_context_removed_after_exit(self): example_context = {"Hello": "World"} with set_current_context(example_context): pass - with pytest.raises(RuntimeError): - get_current_context() + if AIRFLOW_V_3_0_PLUS: + with pytest.warns(AirflowProviderDeprecationWarning): + with pytest.raises(RuntimeError): + get_current_context() + else: + with pytest.raises(RuntimeError): + get_current_context() def test_nested_context(self): """ @@ -1831,12 +1959,21 @@ def test_nested_context(self): ctx_obj = set_current_context(new_context) ctx_obj.__enter__() ctx_list.append(ctx_obj) - for i in reversed(range(max_stack_depth)): - # Iterate over contexts in reverse order - stack is LIFO - ctx = get_current_context() - assert ctx["ContextId"] == i - # End of with statement - ctx_list[i].__exit__(None, None, None) + if AIRFLOW_V_3_0_PLUS: + with pytest.warns(AirflowProviderDeprecationWarning): + for i in reversed(range(max_stack_depth)): + # Iterate over contexts in reverse order - stack is LIFO + ctx = get_current_context() + assert ctx["ContextId"] == i + # End of with statement + ctx_list[i].__exit__(None, None, None) + else: + for i in reversed(range(max_stack_depth)): + # Iterate over contexts in reverse order - stack is LIFO + ctx = get_current_context() + assert ctx["ContextId"] == i + # End of with statement + ctx_list[i].__exit__(None, None, None) class MyContextAssertOperator(BaseOperator): @@ -1878,12 +2015,20 @@ class TestCurrentContextRuntime: def test_context_in_task(self): with DAG(dag_id="assert_context_dag", default_args=DEFAULT_ARGS, schedule="@once"): op = MyContextAssertOperator(task_id="assert_context") - op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) + if AIRFLOW_V_3_0_1: + with pytest.warns(AirflowProviderDeprecationWarning): + op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) + else: + op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) def test_get_context_in_old_style_context_task(self): with DAG(dag_id="edge_case_context_dag", default_args=DEFAULT_ARGS, schedule="@once"): op = PythonOperator(python_callable=get_all_the_context, task_id="get_all_the_context") - op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) + if AIRFLOW_V_3_0_1: + with pytest.warns(AirflowProviderDeprecationWarning): + op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) + else: + op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) @pytest.mark.need_serialized_dag(False) diff --git a/providers/standard/tests/unit/standard/operators/test_weekday.py b/providers/standard/tests/unit/standard/operators/test_weekday.py index 583f20fd663be..0372669c96179 100644 --- a/providers/standard/tests/unit/standard/operators/test_weekday.py +++ b/providers/standard/tests/unit/standard/operators/test_weekday.py @@ -29,12 +29,13 @@ from airflow.providers.standard.operators.weekday import BranchDayOfWeekOperator from airflow.providers.standard.utils.skipmixin import XCOM_SKIPMIXIN_FOLLOWED, XCOM_SKIPMIXIN_KEY from airflow.providers.standard.utils.weekday import WeekDay -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS from airflow.timetables.base import DataInterval from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import State +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS + if AIRFLOW_V_3_0_PLUS: from airflow.models.xcom import XComModel as XCom else: @@ -115,7 +116,7 @@ def test_branch_follow_true(self, weekday, dag_maker): data_interval=DataInterval(DEFAULT_DATE, DEFAULT_DATE), ) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -161,7 +162,7 @@ def test_branch_follow_true_with_logical_date(self, dag_maker): data_interval=DataInterval(DEFAULT_DATE, DEFAULT_DATE), ) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -230,7 +231,7 @@ def test_branch_follow_false(self, dag_maker): data_interval=DataInterval(DEFAULT_DATE, DEFAULT_DATE), ) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -336,20 +337,16 @@ def test_branch_xcom_push_true_branch(self, dag_maker): ) branch_op_ti = dr.get_task_instance(branch_op.task_id) - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: branch_op_ti.run() assert exc_info.value.tasks == [("branch_2", -1)] - assert branch_op_ti.xcom_pull(task_ids="make_choice", key=XCOM_SKIPMIXIN_KEY) == { - XCOM_SKIPMIXIN_FOLLOWED: ["branch_1"] - } else: branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - tis = dr.get_task_instances() - for ti in tis: - if ti.task_id == "make_choice": - assert ti.xcom_pull(task_ids="make_choice") == "branch_1" + assert branch_op_ti.xcom_pull(task_ids="make_choice", key=XCOM_SKIPMIXIN_KEY) == { + XCOM_SKIPMIXIN_FOLLOWED: ["branch_1"] + } diff --git a/pyproject.toml b/pyproject.toml index faf90779316a9..632c0a7ca01ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -841,7 +841,7 @@ fixture-parentheses = false ## pytest settings ## [tool.pytest.ini_options] addopts = [ - "--tb=short", + "--tb=no", "-rasl", "--verbosity=2", # Disable `flaky` plugin for pytest. This plugin conflicts with `rerunfailures` because provide the same marker. diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 36f79bb173372..79616f43ffd3c 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -1198,7 +1198,7 @@ def add_logger_if_needed(ti: TaskInstance): ti.set_state(State.SUCCESS) log.info("[DAG TEST] Marking success for %s on %s", ti.task, ti.logical_date) else: - _run_task(ti=ti) + _run_task(ti=ti, run_triggerer=True) except Exception: log.exception("Task failed; ti=%s", ti) if use_executor: @@ -1213,7 +1213,7 @@ def add_logger_if_needed(ti: TaskInstance): return dr -def _run_task(*, ti): +def _run_task(*, ti, run_triggerer=False): """ Run a single task instance, and push result to Xcom for downstream tasks. @@ -1250,8 +1250,9 @@ def _run_task(*, ti): msg = taskrun_result.msg ti.set_state(taskrun_result.ti.state) + ti.task = taskrun_result.ti.task - if ti.state == State.DEFERRED and isinstance(msg, DeferTask): + if ti.state == State.DEFERRED and isinstance(msg, DeferTask) and run_triggerer: # API Server expects the task instance to be in QUEUED state before # resuming from deferral. ti.set_state(State.QUEUED) @@ -1260,11 +1261,12 @@ def _run_task(*, ti): trigger = import_string(msg.classpath)(**msg.trigger_kwargs) event = _run_inline_trigger(trigger) ti.next_method = msg.next_method - ti.next_kwargs = {"event": event.payload} if event else msg.kwargs + ti.next_kwargs = {"event": event.payload} if event else msg.next_kwargs log.info("[DAG TEST] Trigger completed") ti.set_state(State.SUCCESS) - break + + return taskrun_result except Exception: log.exception("[DAG TEST] Error running task %s", ti) if ti.state not in State.finished: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 630b5cb8a849a..2036ee22b4717 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1127,7 +1127,7 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): with timeout(timeout_seconds): result = ctx.run(execute, context=context) except AirflowTaskTimeout: - # TODO: handle on kill callback here + task.on_kill() raise else: result = ctx.run(execute, context=context)