From 768e77ee881894613b561cd8263db7a2a5a1865f Mon Sep 17 00:00:00 2001 From: kiran2706 Date: Fri, 20 Jun 2025 12:37:12 +0530 Subject: [PATCH 1/7] Feature(call-back): Enhance DAG-level callback context with enriched metadata and test coverage --- airflow-core/src/airflow/models/dagrun.py | 44 ++++- airflow-core/tests/unit/models/test_dagrun.py | 166 ++++++++++++++++++ 2 files changed, 207 insertions(+), 3 deletions(-) diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 8dcd00fef0cbc..dbf215f4edd18 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1351,22 +1351,58 @@ def notify_dagrun_state_changed(self, msg: str = ""): # or LocalTaskJob, so we don't want to "falsely advertise" we notify about that def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "success"): - """Only needed for `dag.test` where `execute_callbacks=True` is passed to `update_state`.""" + """Handle DAG-level callbacks (on_success_callback, on_failure_callback) with enriched context.""" + + task_instances = self.get_task_instances() + + # Identify the most relevant task instance + last_relevant_ti = None + if not success: + failed_tis = [ti for ti in task_instances if ti.state in State.failed_states and ti.end_date] + failed_tis.sort(key=lambda x: x.end_date, reverse=True) + last_relevant_ti = failed_tis[0] if failed_tis else None + else: + success_tis = [ti for ti in task_instances if ti.state in State.success_states and ti.end_date] + success_tis.sort(key=lambda x: x.end_date, reverse=True) + last_relevant_ti = success_tis[0] if success_tis else None + + # Enrich DAG-level callback context context: Context = { # type: ignore[assignment] "dag": dag, "run_id": str(self.run_id), + "execution_date": self.logical_date, + "start_date": self.start_date, + "end_date": self.end_date, + "data_interval_start": self.data_interval_start, + "data_interval_end": self.data_interval_end, "reason": reason, + "run_duration": ( + (self.end_date - self.start_date).total_seconds() + if self.start_date and self.end_date + else None + ), } + # Add task-level metadata if available + if last_relevant_ti: + context.update({ + "task_instance": last_relevant_ti, + "ti": last_relevant_ti, + "try_number": last_relevant_ti.try_number, + "max_tries": last_relevant_ti.max_tries, + "log_url": last_relevant_ti.log_url, + "mark_success_url": last_relevant_ti.mark_success_url, + }) + callbacks = dag.on_success_callback if success else dag.on_failure_callback if not callbacks: - self.log.warning("Callback requested, but dag didn't have any for DAG: %s.", dag.dag_id) + self.log.warning("Callback requested, but DAG didn't have any for DAG: %s.", dag.dag_id) return callbacks = callbacks if isinstance(callbacks, list) else [callbacks] for callback in callbacks: self.log.info( - "Executing on_%s dag callback: %s", + "Executing on_%s DAG callback: %s", "success" if success else "failure", callback.__name__ if hasattr(callback, "__name__") else repr(callback), ) @@ -2042,3 +2078,5 @@ def __repr__(self): if self.map_index != -1: prefix += f" map_index={self.map_index}" return prefix + ">" + + return prefix + ">" diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index 0a95ffcddc7a1..00d3f6ba4c092 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -907,6 +907,172 @@ def test_already_added_task_instances_can_be_ignored(self, dag_maker, session): first_ti.refresh_from_db() assert first_ti.state is None + def test_dag_callback_context_with_task_metadata(self, dag_maker, session): + """Test that DAG-level on_success_callback receives the correct context including task metadata.""" + + callback_invoked = {"flag": False} # Use mutable object to modify inside nested function + + def on_success_callable(context): + callback_invoked["flag"] = True + + # Base context assertions + assert context["dag_run"].dag_id == "test_dag_callback_context_with_task_metadata" + assert context["reason"] == "success" + assert "dag" in context + assert "run_id" in context + assert "execution_date" in context + assert "data_interval_start" in context + assert "data_interval_end" in context + assert "dag_run_url" in context + + # Task-level metadata + assert "task_instance" in context + assert "try_number" in context + assert "max_tries" in context + assert "log_url" in context + assert "mark_success_url" in context + + # Verify task instance content + ti = context["task_instance"] + assert ti.task_id == "task3" + assert ti.state == TaskInstanceState.SUCCESS + assert context["try_number"] == ti.try_number + assert context["max_tries"] == ti.max_tries + assert context["log_url"] == ti.log_url + assert context["mark_success_url"] == ti.mark_success_url + + # Define DAG with success callback + with dag_maker( + dag_id="test_dag_callback_context_with_task_metadata", + on_success_callback=on_success_callable, + ) as dag: + task1 = EmptyOperator(task_id="task1") + task2 = EmptyOperator(task_id="task2") + task3 = EmptyOperator(task_id="task3") + task1 >> task2 >> task3 + + initial_task_states = { + "task1": TaskInstanceState.RUNNING, + "task2": TaskInstanceState.RUNNING, + "task3": TaskInstanceState.RUNNING, + } + + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + session.commit() + + # Create DAG run + dagrun = self.create_dag_run( + dag=dag, + task_states=initial_task_states, + state=DagRunState.RUNNING, + session=session, + ) + + # Simulate task completions + ti1 = dagrun.get_task_instance("task1", session) + ti2 = dagrun.get_task_instance("task2", session) + ti3 = dagrun.get_task_instance("task3", session) + + now = timezone.utcnow() + ti1.set_state(TaskInstanceState.SUCCESS, session=session) + ti1.end_date = now + + ti2.set_state(TaskInstanceState.SUCCESS, session=session) + ti2.end_date = now + datetime.timedelta(minutes=1) + + ti3.set_state(TaskInstanceState.SUCCESS, session=session) + ti3.end_date = now + datetime.timedelta(minutes=2) + + # Reattach callback if it was stripped (safeguard against test infra / DagBag side effects) + dag.on_success_callback = on_success_callable + + # Finalize DAG run and trigger callback + dagrun.update_state(session=session, execute_callbacks=True) + + # Ensure callback was actually invoked + assert callback_invoked["flag"], "DAG on_success_callback was not triggered" + + def test_dag_callback_context_with_task_metadata_failure(self, dag_maker, session): + """Test that DAG-level on_failure_callback receives the last failed task instance in the context.""" + callback_invoked = {"flag": False} + + def on_failure_callable(context): + callback_invoked["flag"] = True + + # Base context assertions + assert context["dag_run"].dag_id == "test_dag_callback_context_failure" + assert context["reason"] == "failure" + assert "dag" in context + assert "run_id" in context + assert "execution_date" in context + assert "data_interval_start" in context + assert "data_interval_end" in context + assert "dag_run_url" in context + + # Task-level metadata + assert "task_instance" in context + assert "try_number" in context + assert "max_tries" in context + assert "log_url" in context + assert "mark_success_url" in context + + # Verify task instance content + ti = context["task_instance"] + assert ti.task_id == "task3" + assert ti.state == TaskInstanceState.FAILED + assert context["try_number"] == ti.try_number + assert context["max_tries"] == ti.max_tries + assert context["log_url"] == ti.log_url + assert context["mark_success_url"] == ti.mark_success_url + + # Define DAG with failure callback + with dag_maker( + dag_id="test_dag_callback_context_failure", + schedule=datetime.timedelta(days=1), + start_date=datetime.datetime(2017, 1, 1), + on_failure_callback=on_failure_callable, + ) as dag: + task1 = EmptyOperator(task_id="task1") + task2 = EmptyOperator(task_id="task2") + task3 = EmptyOperator(task_id="task3") + task1 >> task2 >> task3 + + initial_task_states = { + "task1": TaskInstanceState.RUNNING, + "task2": TaskInstanceState.RUNNING, + "task3": TaskInstanceState.RUNNING, + } + + session.commit() + + dagrun = self.create_dag_run( + dag=dag, + task_states=initial_task_states, + state=DagRunState.RUNNING, + session=session, + ) + + now = timezone.utcnow() + ti1 = dagrun.get_task_instance("task1", session) + ti2 = dagrun.get_task_instance("task2", session) + ti3 = dagrun.get_task_instance("task3", session) + + ti1.set_state(TaskInstanceState.SUCCESS, session=session) + ti1.end_date = now + + ti2.set_state(TaskInstanceState.SUCCESS, session=session) + ti2.end_date = now + datetime.timedelta(minutes=1) + + ti3.set_state(TaskInstanceState.FAILED, session=session) + ti3.end_date = now + datetime.timedelta(minutes=2) + + # Reattach callback (required due to potential DAG context loss) + dag.on_failure_callback = on_failure_callable + + dagrun.update_state(session=session, execute_callbacks=True) + assert dagrun.state == DagRunState.FAILED + assert callback_invoked["flag"], "DAG on_failure_callback was not triggered" + @pytest.mark.parametrize("state", State.task_states) @mock.patch.object(settings, "task_instance_mutation_hook", autospec=True) def test_task_instance_mutation_hook(self, mock_hook, dag_maker, session, state): From 6f8c5295e407a1546450c3f6754a12170854f69e Mon Sep 17 00:00:00 2001 From: kiran2706 Date: Fri, 20 Jun 2025 12:45:47 +0530 Subject: [PATCH 2/7] Docstring --- airflow-core/src/airflow/models/dagrun.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index dbf215f4edd18..e185876045d00 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1351,7 +1351,7 @@ def notify_dagrun_state_changed(self, msg: str = ""): # or LocalTaskJob, so we don't want to "falsely advertise" we notify about that def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "success"): - """Handle DAG-level callbacks (on_success_callback, on_failure_callback) with enriched context.""" + """Only needed for `dag.test` where `execute_callbacks=True` is passed to `update_state`.""" task_instances = self.get_task_instances() From 6bbfc3466c7686e011893d021261344166df005a Mon Sep 17 00:00:00 2001 From: kiran2706 Date: Fri, 20 Jun 2025 14:58:59 +0530 Subject: [PATCH 3/7] Updated the Context model --- task-sdk/src/airflow/sdk/definitions/context.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/task-sdk/src/airflow/sdk/definitions/context.py b/task-sdk/src/airflow/sdk/definitions/context.py index 082ad36202ec2..028cf79871b3d 100644 --- a/task-sdk/src/airflow/sdk/definitions/context.py +++ b/task-sdk/src/airflow/sdk/definitions/context.py @@ -79,6 +79,16 @@ class Context(TypedDict, total=False): ts_nodash_with_tz: str var: Any + # --- Added for enriched DAG-level callback context --- + execution_date: DateTime + end_date: DateTime + run_duration: float | None + dag_run_url: str + max_tries: int | None + log_url: str + mark_success_url: str + + def get_current_context() -> Context: """ From 6830c6f2d03b9d208d2872385546896f2e6b1872 Mon Sep 17 00:00:00 2001 From: kiran2706 Date: Fri, 20 Jun 2025 16:01:02 +0530 Subject: [PATCH 4/7] Formatted --- airflow-core/src/airflow/models/dagrun.py | 19 ++++++++++--------- .../src/airflow/sdk/definitions/context.py | 1 - 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index e185876045d00..450677cbb51dd 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1352,7 +1352,6 @@ def notify_dagrun_state_changed(self, msg: str = ""): def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "success"): """Only needed for `dag.test` where `execute_callbacks=True` is passed to `update_state`.""" - task_instances = self.get_task_instances() # Identify the most relevant task instance @@ -1385,14 +1384,16 @@ def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = " # Add task-level metadata if available if last_relevant_ti: - context.update({ - "task_instance": last_relevant_ti, - "ti": last_relevant_ti, - "try_number": last_relevant_ti.try_number, - "max_tries": last_relevant_ti.max_tries, - "log_url": last_relevant_ti.log_url, - "mark_success_url": last_relevant_ti.mark_success_url, - }) + context.update( + { + "task_instance": last_relevant_ti, + "ti": last_relevant_ti, + "try_number": last_relevant_ti.try_number, + "max_tries": last_relevant_ti.max_tries, + "log_url": last_relevant_ti.log_url, + "mark_success_url": last_relevant_ti.mark_success_url, + } + ) callbacks = dag.on_success_callback if success else dag.on_failure_callback if not callbacks: diff --git a/task-sdk/src/airflow/sdk/definitions/context.py b/task-sdk/src/airflow/sdk/definitions/context.py index 028cf79871b3d..08bc77b527647 100644 --- a/task-sdk/src/airflow/sdk/definitions/context.py +++ b/task-sdk/src/airflow/sdk/definitions/context.py @@ -89,7 +89,6 @@ class Context(TypedDict, total=False): mark_success_url: str - def get_current_context() -> Context: """ Retrieve the execution context dictionary without altering user method's signature. From 0e3eca3be38928d2dcc805cae920136cd321e506 Mon Sep 17 00:00:00 2001 From: Kiran R <71453237+kiran2706@users.noreply.github.com> Date: Mon, 23 Jun 2025 09:59:03 +0530 Subject: [PATCH 5/7] Update airflow-core/src/airflow/models/dagrun.py Co-authored-by: Tzu-ping Chung --- airflow-core/src/airflow/models/dagrun.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 450677cbb51dd..ac935e724bd56 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1385,14 +1385,12 @@ def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = " # Add task-level metadata if available if last_relevant_ti: context.update( - { - "task_instance": last_relevant_ti, - "ti": last_relevant_ti, - "try_number": last_relevant_ti.try_number, - "max_tries": last_relevant_ti.max_tries, - "log_url": last_relevant_ti.log_url, - "mark_success_url": last_relevant_ti.mark_success_url, - } + task_instance=last_relevant_ti, + ti=last_relevant_ti, + try_number=last_relevant_ti.try_number, + max_tries=last_relevant_ti.max_tries, + log_url=last_relevant_ti.log_url, + mark_success_url=last_relevant_ti.mark_success_url, ) callbacks = dag.on_success_callback if success else dag.on_failure_callback From dbadcf4d0879582629ad6f49ff0c4b761f020e40 Mon Sep 17 00:00:00 2001 From: kiran2706 Date: Mon, 23 Jun 2025 10:23:11 +0530 Subject: [PATCH 6/7] removed unnecessary keys --- airflow-core/src/airflow/models/dagrun.py | 10 ---------- task-sdk/src/airflow/sdk/definitions/context.py | 2 -- 2 files changed, 12 deletions(-) diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index ac935e724bd56..2f1058a931e14 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1369,17 +1369,11 @@ def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = " context: Context = { # type: ignore[assignment] "dag": dag, "run_id": str(self.run_id), - "execution_date": self.logical_date, "start_date": self.start_date, "end_date": self.end_date, "data_interval_start": self.data_interval_start, "data_interval_end": self.data_interval_end, "reason": reason, - "run_duration": ( - (self.end_date - self.start_date).total_seconds() - if self.start_date and self.end_date - else None - ), } # Add task-level metadata if available @@ -1387,10 +1381,6 @@ def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = " context.update( task_instance=last_relevant_ti, ti=last_relevant_ti, - try_number=last_relevant_ti.try_number, - max_tries=last_relevant_ti.max_tries, - log_url=last_relevant_ti.log_url, - mark_success_url=last_relevant_ti.mark_success_url, ) callbacks = dag.on_success_callback if success else dag.on_failure_callback diff --git a/task-sdk/src/airflow/sdk/definitions/context.py b/task-sdk/src/airflow/sdk/definitions/context.py index 08bc77b527647..b1b9efe6123bc 100644 --- a/task-sdk/src/airflow/sdk/definitions/context.py +++ b/task-sdk/src/airflow/sdk/definitions/context.py @@ -80,9 +80,7 @@ class Context(TypedDict, total=False): var: Any # --- Added for enriched DAG-level callback context --- - execution_date: DateTime end_date: DateTime - run_duration: float | None dag_run_url: str max_tries: int | None log_url: str From 9493126251eaf16654fa9e32983c12ef9df6c7f2 Mon Sep 17 00:00:00 2001 From: kiran2706 Date: Mon, 23 Jun 2025 22:32:43 +0530 Subject: [PATCH 7/7] Fixed test failures --- airflow-core/docs/templates-ref.rst | 5 +++++ airflow-core/src/airflow/models/dagrun.py | 6 ++++-- airflow-core/src/airflow/utils/context.py | 5 +++++ task-sdk/src/airflow/sdk/definitions/context.py | 8 ++++---- task-sdk/src/airflow/sdk/execution_time/task_runner.py | 5 +++++ 5 files changed, 23 insertions(+), 6 deletions(-) diff --git a/airflow-core/docs/templates-ref.rst b/airflow-core/docs/templates-ref.rst index 1d226b1ec12fb..401253961bb33 100644 --- a/airflow-core/docs/templates-ref.rst +++ b/airflow-core/docs/templates-ref.rst @@ -83,6 +83,11 @@ Variable Type Description list[AssetEvent]] | (there may be more than one, if there are multiple Assets with different frequencies). | Read more here :doc:`Assets `. | Added in version 2.4. +``{{ mark_success_url }}`` str | None |URL to mark the DAG run as successful in the Airflow UI. +``{{ log_url }}`` str | None |URL to the log for the current DAG run or task instance. +``{{ dag_run_url }}`` str | None |URL to the DAG run details page in the Airflow UI. +``{{ end_date }}`` DateTime | None |The end date/time of the DAG run. +``{{ max_tries }}`` int | None |The maximum number of tries for the task instance. =========================================== ===================== =================================================================== The following are only available when the DagRun has a ``logical_date`` diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 2f1058a931e14..0867dc3843246 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1379,8 +1379,10 @@ def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = " # Add task-level metadata if available if last_relevant_ti: context.update( - task_instance=last_relevant_ti, - ti=last_relevant_ti, + { + "task_instance": last_relevant_ti, + "ti": last_relevant_ti, + } ) callbacks = dag.on_success_callback if success else dag.on_failure_callback diff --git a/airflow-core/src/airflow/utils/context.py b/airflow-core/src/airflow/utils/context.py index c27032c7c3c20..806ced8293be1 100644 --- a/airflow-core/src/airflow/utils/context.py +++ b/airflow-core/src/airflow/utils/context.py @@ -86,6 +86,11 @@ "ts_nodash_with_tz", "try_number", "var", + "dag_run_url", + "end_date", + "log_url", + "mark_success_url", + "max_tries", } diff --git a/task-sdk/src/airflow/sdk/definitions/context.py b/task-sdk/src/airflow/sdk/definitions/context.py index b1b9efe6123bc..34dc76b03ac7f 100644 --- a/task-sdk/src/airflow/sdk/definitions/context.py +++ b/task-sdk/src/airflow/sdk/definitions/context.py @@ -80,11 +80,11 @@ class Context(TypedDict, total=False): var: Any # --- Added for enriched DAG-level callback context --- - end_date: DateTime - dag_run_url: str + end_date: DateTime | None + dag_run_url: str | None max_tries: int | None - log_url: str - mark_success_url: str + log_url: str | None + mark_success_url: str | None def get_current_context() -> Context: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 6c6e597f65e5c..fc727ddec56ae 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -193,6 +193,11 @@ def get_template_context(self) -> Context: "value": VariableAccessor(deserialize_json=False), }, "conn": ConnectionAccessor(), + "dag_run_url": None, # Will be populated in callbacks + "end_date": None, # Will be populated in callbacks + "log_url": None, # Will be populated in callbacks + "mark_success_url": None, # Will be populated in callbacks + "max_tries": None, # Will be populated in callbacks } if from_server: dag_run = from_server.dag_run