diff --git a/airflow-core/docs/templates-ref.rst b/airflow-core/docs/templates-ref.rst index 1d226b1ec12fb..401253961bb33 100644 --- a/airflow-core/docs/templates-ref.rst +++ b/airflow-core/docs/templates-ref.rst @@ -83,6 +83,11 @@ Variable Type Description list[AssetEvent]] | (there may be more than one, if there are multiple Assets with different frequencies). | Read more here :doc:`Assets `. | Added in version 2.4. +``{{ mark_success_url }}`` str | None |URL to mark the DAG run as successful in the Airflow UI. +``{{ log_url }}`` str | None |URL to the log for the current DAG run or task instance. +``{{ dag_run_url }}`` str | None |URL to the DAG run details page in the Airflow UI. +``{{ end_date }}`` DateTime | None |The end date/time of the DAG run. +``{{ max_tries }}`` int | None |The maximum number of tries for the task instance. =========================================== ===================== =================================================================== The following are only available when the DagRun has a ``logical_date`` diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 8dcd00fef0cbc..0867dc3843246 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1352,21 +1352,48 @@ def notify_dagrun_state_changed(self, msg: str = ""): def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "success"): """Only needed for `dag.test` where `execute_callbacks=True` is passed to `update_state`.""" + task_instances = self.get_task_instances() + + # Identify the most relevant task instance + last_relevant_ti = None + if not success: + failed_tis = [ti for ti in task_instances if ti.state in State.failed_states and ti.end_date] + failed_tis.sort(key=lambda x: x.end_date, reverse=True) + last_relevant_ti = failed_tis[0] if failed_tis else None + else: + success_tis = [ti for ti in task_instances if ti.state in State.success_states and ti.end_date] + success_tis.sort(key=lambda x: x.end_date, reverse=True) + last_relevant_ti = success_tis[0] if success_tis else None + + # Enrich DAG-level callback context context: Context = { # type: ignore[assignment] "dag": dag, "run_id": str(self.run_id), + "start_date": self.start_date, + "end_date": self.end_date, + "data_interval_start": self.data_interval_start, + "data_interval_end": self.data_interval_end, "reason": reason, } + # Add task-level metadata if available + if last_relevant_ti: + context.update( + { + "task_instance": last_relevant_ti, + "ti": last_relevant_ti, + } + ) + callbacks = dag.on_success_callback if success else dag.on_failure_callback if not callbacks: - self.log.warning("Callback requested, but dag didn't have any for DAG: %s.", dag.dag_id) + self.log.warning("Callback requested, but DAG didn't have any for DAG: %s.", dag.dag_id) return callbacks = callbacks if isinstance(callbacks, list) else [callbacks] for callback in callbacks: self.log.info( - "Executing on_%s dag callback: %s", + "Executing on_%s DAG callback: %s", "success" if success else "failure", callback.__name__ if hasattr(callback, "__name__") else repr(callback), ) @@ -2042,3 +2069,5 @@ def __repr__(self): if self.map_index != -1: prefix += f" map_index={self.map_index}" return prefix + ">" + + return prefix + ">" diff --git a/airflow-core/src/airflow/utils/context.py b/airflow-core/src/airflow/utils/context.py index c27032c7c3c20..806ced8293be1 100644 --- a/airflow-core/src/airflow/utils/context.py +++ b/airflow-core/src/airflow/utils/context.py @@ -86,6 +86,11 @@ "ts_nodash_with_tz", "try_number", "var", + "dag_run_url", + "end_date", + "log_url", + "mark_success_url", + "max_tries", } diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index 0a95ffcddc7a1..00d3f6ba4c092 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -907,6 +907,172 @@ def test_already_added_task_instances_can_be_ignored(self, dag_maker, session): first_ti.refresh_from_db() assert first_ti.state is None + def test_dag_callback_context_with_task_metadata(self, dag_maker, session): + """Test that DAG-level on_success_callback receives the correct context including task metadata.""" + + callback_invoked = {"flag": False} # Use mutable object to modify inside nested function + + def on_success_callable(context): + callback_invoked["flag"] = True + + # Base context assertions + assert context["dag_run"].dag_id == "test_dag_callback_context_with_task_metadata" + assert context["reason"] == "success" + assert "dag" in context + assert "run_id" in context + assert "execution_date" in context + assert "data_interval_start" in context + assert "data_interval_end" in context + assert "dag_run_url" in context + + # Task-level metadata + assert "task_instance" in context + assert "try_number" in context + assert "max_tries" in context + assert "log_url" in context + assert "mark_success_url" in context + + # Verify task instance content + ti = context["task_instance"] + assert ti.task_id == "task3" + assert ti.state == TaskInstanceState.SUCCESS + assert context["try_number"] == ti.try_number + assert context["max_tries"] == ti.max_tries + assert context["log_url"] == ti.log_url + assert context["mark_success_url"] == ti.mark_success_url + + # Define DAG with success callback + with dag_maker( + dag_id="test_dag_callback_context_with_task_metadata", + on_success_callback=on_success_callable, + ) as dag: + task1 = EmptyOperator(task_id="task1") + task2 = EmptyOperator(task_id="task2") + task3 = EmptyOperator(task_id="task3") + task1 >> task2 >> task3 + + initial_task_states = { + "task1": TaskInstanceState.RUNNING, + "task2": TaskInstanceState.RUNNING, + "task3": TaskInstanceState.RUNNING, + } + + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + session.commit() + + # Create DAG run + dagrun = self.create_dag_run( + dag=dag, + task_states=initial_task_states, + state=DagRunState.RUNNING, + session=session, + ) + + # Simulate task completions + ti1 = dagrun.get_task_instance("task1", session) + ti2 = dagrun.get_task_instance("task2", session) + ti3 = dagrun.get_task_instance("task3", session) + + now = timezone.utcnow() + ti1.set_state(TaskInstanceState.SUCCESS, session=session) + ti1.end_date = now + + ti2.set_state(TaskInstanceState.SUCCESS, session=session) + ti2.end_date = now + datetime.timedelta(minutes=1) + + ti3.set_state(TaskInstanceState.SUCCESS, session=session) + ti3.end_date = now + datetime.timedelta(minutes=2) + + # Reattach callback if it was stripped (safeguard against test infra / DagBag side effects) + dag.on_success_callback = on_success_callable + + # Finalize DAG run and trigger callback + dagrun.update_state(session=session, execute_callbacks=True) + + # Ensure callback was actually invoked + assert callback_invoked["flag"], "DAG on_success_callback was not triggered" + + def test_dag_callback_context_with_task_metadata_failure(self, dag_maker, session): + """Test that DAG-level on_failure_callback receives the last failed task instance in the context.""" + callback_invoked = {"flag": False} + + def on_failure_callable(context): + callback_invoked["flag"] = True + + # Base context assertions + assert context["dag_run"].dag_id == "test_dag_callback_context_failure" + assert context["reason"] == "failure" + assert "dag" in context + assert "run_id" in context + assert "execution_date" in context + assert "data_interval_start" in context + assert "data_interval_end" in context + assert "dag_run_url" in context + + # Task-level metadata + assert "task_instance" in context + assert "try_number" in context + assert "max_tries" in context + assert "log_url" in context + assert "mark_success_url" in context + + # Verify task instance content + ti = context["task_instance"] + assert ti.task_id == "task3" + assert ti.state == TaskInstanceState.FAILED + assert context["try_number"] == ti.try_number + assert context["max_tries"] == ti.max_tries + assert context["log_url"] == ti.log_url + assert context["mark_success_url"] == ti.mark_success_url + + # Define DAG with failure callback + with dag_maker( + dag_id="test_dag_callback_context_failure", + schedule=datetime.timedelta(days=1), + start_date=datetime.datetime(2017, 1, 1), + on_failure_callback=on_failure_callable, + ) as dag: + task1 = EmptyOperator(task_id="task1") + task2 = EmptyOperator(task_id="task2") + task3 = EmptyOperator(task_id="task3") + task1 >> task2 >> task3 + + initial_task_states = { + "task1": TaskInstanceState.RUNNING, + "task2": TaskInstanceState.RUNNING, + "task3": TaskInstanceState.RUNNING, + } + + session.commit() + + dagrun = self.create_dag_run( + dag=dag, + task_states=initial_task_states, + state=DagRunState.RUNNING, + session=session, + ) + + now = timezone.utcnow() + ti1 = dagrun.get_task_instance("task1", session) + ti2 = dagrun.get_task_instance("task2", session) + ti3 = dagrun.get_task_instance("task3", session) + + ti1.set_state(TaskInstanceState.SUCCESS, session=session) + ti1.end_date = now + + ti2.set_state(TaskInstanceState.SUCCESS, session=session) + ti2.end_date = now + datetime.timedelta(minutes=1) + + ti3.set_state(TaskInstanceState.FAILED, session=session) + ti3.end_date = now + datetime.timedelta(minutes=2) + + # Reattach callback (required due to potential DAG context loss) + dag.on_failure_callback = on_failure_callable + + dagrun.update_state(session=session, execute_callbacks=True) + assert dagrun.state == DagRunState.FAILED + assert callback_invoked["flag"], "DAG on_failure_callback was not triggered" + @pytest.mark.parametrize("state", State.task_states) @mock.patch.object(settings, "task_instance_mutation_hook", autospec=True) def test_task_instance_mutation_hook(self, mock_hook, dag_maker, session, state): diff --git a/task-sdk/src/airflow/sdk/definitions/context.py b/task-sdk/src/airflow/sdk/definitions/context.py index 082ad36202ec2..34dc76b03ac7f 100644 --- a/task-sdk/src/airflow/sdk/definitions/context.py +++ b/task-sdk/src/airflow/sdk/definitions/context.py @@ -79,6 +79,13 @@ class Context(TypedDict, total=False): ts_nodash_with_tz: str var: Any + # --- Added for enriched DAG-level callback context --- + end_date: DateTime | None + dag_run_url: str | None + max_tries: int | None + log_url: str | None + mark_success_url: str | None + def get_current_context() -> Context: """ diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 6c6e597f65e5c..fc727ddec56ae 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -193,6 +193,11 @@ def get_template_context(self) -> Context: "value": VariableAccessor(deserialize_json=False), }, "conn": ConnectionAccessor(), + "dag_run_url": None, # Will be populated in callbacks + "end_date": None, # Will be populated in callbacks + "log_url": None, # Will be populated in callbacks + "mark_success_url": None, # Will be populated in callbacks + "max_tries": None, # Will be populated in callbacks } if from_server: dag_run = from_server.dag_run