Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions airflow-core/docs/templates-ref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ Variable Type Description
list[AssetEvent]] | (there may be more than one, if there are multiple Assets with different frequencies).
| Read more here :doc:`Assets <authoring-and-scheduling/asset-scheduling>`.
| Added in version 2.4.
``{{ mark_success_url }}`` str | None | URL to mark the DAG run as successful in the Airflow UI.
``{{ log_url }}`` str | None | URL to the log for the current DAG run or task instance.
``{{ dag_run_url }}`` str | None | URL to the DAG run details page in the Airflow UI.
``{{ end_date }}`` DateTime | None | The end date/time of the DAG run.
``{{ max_tries }}`` int | None | The maximum number of tries for the task instance.
Comment on lines +86 to +90
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be removed now that these variables are no longer being added.

=========================================== ===================== ===================================================================

The following are only available when the DagRun has a ``logical_date``
Expand Down
33 changes: 31 additions & 2 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,21 +1352,48 @@ def notify_dagrun_state_changed(self, msg: str = ""):

def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "success"):
"""Only needed for `dag.test` where `execute_callbacks=True` is passed to `update_state`."""
task_instances = self.get_task_instances()

# Identify the most relevant task instance
last_relevant_ti = None
if not success:
failed_tis = [ti for ti in task_instances if ti.state in State.failed_states and ti.end_date]
failed_tis.sort(key=lambda x: x.end_date, reverse=True)
last_relevant_ti = failed_tis[0] if failed_tis else None
else:
success_tis = [ti for ti in task_instances if ti.state in State.success_states and ti.end_date]
success_tis.sort(key=lambda x: x.end_date, reverse=True)
last_relevant_ti = success_tis[0] if success_tis else None
Comment on lines +1365 to +1366
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be something like

last_relevant_ti = max(success_tis, ...) if success_tis else None

instead.

I kind of wonder if we can even avoid building the list at all.

Also is simply sorting by end_date correct? Especially with trigger_rule, the last success/failed ti might not necessarily be the ti that causes the dag run to be marked as success/failed. Can you check the logic in 2.x to see how the ti is selected?


# Enrich DAG-level callback context
context: Context = { # type: ignore[assignment]
"dag": dag,
"run_id": str(self.run_id),
"start_date": self.start_date,
"end_date": self.end_date,
"data_interval_start": self.data_interval_start,
"data_interval_end": self.data_interval_end,
"reason": reason,
}

# Add task-level metadata if available
if last_relevant_ti:
context.update(
{
"task_instance": last_relevant_ti,
"ti": last_relevant_ti,
}
)

callbacks = dag.on_success_callback if success else dag.on_failure_callback
if not callbacks:
self.log.warning("Callback requested, but dag didn't have any for DAG: %s.", dag.dag_id)
self.log.warning("Callback requested, but DAG didn't have any for DAG: %s.", dag.dag_id)
return
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]

for callback in callbacks:
self.log.info(
"Executing on_%s dag callback: %s",
"Executing on_%s DAG callback: %s",
"success" if success else "failure",
callback.__name__ if hasattr(callback, "__name__") else repr(callback),
)
Expand Down Expand Up @@ -2042,3 +2069,5 @@ def __repr__(self):
if self.map_index != -1:
prefix += f" map_index={self.map_index}"
return prefix + ">"

return prefix + ">"
Comment on lines +2072 to +2073
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accidental?

5 changes: 5 additions & 0 deletions airflow-core/src/airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@
"ts_nodash_with_tz",
"try_number",
"var",
"dag_run_url",
"end_date",
"log_url",
"mark_success_url",
"max_tries",
Comment on lines +89 to +93
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, should be removed

}


Expand Down
166 changes: 166 additions & 0 deletions airflow-core/tests/unit/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,172 @@ def test_already_added_task_instances_can_be_ignored(self, dag_maker, session):
first_ti.refresh_from_db()
assert first_ti.state is None

def test_dag_callback_context_with_task_metadata(self, dag_maker, session):
    """Test that DAG-level on_success_callback receives the correct context including task metadata."""

    # Mutable container so the nested callback can signal that it ran;
    # rebinding a plain bool inside the closure would not be visible here.
    callback_invoked = {"flag": False}  # Use mutable object to modify inside nested function

    def on_success_callable(context):
        callback_invoked["flag"] = True

        # Base context assertions
        # NOTE(review): "execution_date" and "dag_run_url" are asserted here,
        # but handle_dag_callback in this change does not visibly set them —
        # confirm they are injected elsewhere in context construction.
        assert context["dag_run"].dag_id == "test_dag_callback_context_with_task_metadata"
        assert context["reason"] == "success"
        assert "dag" in context
        assert "run_id" in context
        assert "execution_date" in context
        assert "data_interval_start" in context
        assert "data_interval_end" in context
        assert "dag_run_url" in context

        # Task-level metadata
        assert "task_instance" in context
        assert "try_number" in context
        assert "max_tries" in context
        assert "log_url" in context
        assert "mark_success_url" in context

        # Verify task instance content: task3 is given the latest end_date
        # below, so it should be picked as the most relevant successful TI.
        ti = context["task_instance"]
        assert ti.task_id == "task3"
        assert ti.state == TaskInstanceState.SUCCESS
        assert context["try_number"] == ti.try_number
        assert context["max_tries"] == ti.max_tries
        assert context["log_url"] == ti.log_url
        assert context["mark_success_url"] == ti.mark_success_url

    # Define DAG with success callback
    with dag_maker(
        dag_id="test_dag_callback_context_with_task_metadata",
        on_success_callback=on_success_callable,
    ) as dag:
        task1 = EmptyOperator(task_id="task1")
        task2 = EmptyOperator(task_id="task2")
        task3 = EmptyOperator(task_id="task3")
        task1 >> task2 >> task3

    # All tasks start RUNNING; they are transitioned to SUCCESS below so
    # update_state() concludes the run and fires the success callback.
    initial_task_states = {
        "task1": TaskInstanceState.RUNNING,
        "task2": TaskInstanceState.RUNNING,
        "task3": TaskInstanceState.RUNNING,
    }

    # Round-trip through serialization, mirroring how the scheduler sees DAGs.
    dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
    session.commit()

    # Create DAG run
    dagrun = self.create_dag_run(
        dag=dag,
        task_states=initial_task_states,
        state=DagRunState.RUNNING,
        session=session,
    )

    # Simulate task completions
    ti1 = dagrun.get_task_instance("task1", session)
    ti2 = dagrun.get_task_instance("task2", session)
    ti3 = dagrun.get_task_instance("task3", session)

    # Staggered end_dates make the "latest finished TI" selection
    # deterministic: task3 finishes last.
    now = timezone.utcnow()
    ti1.set_state(TaskInstanceState.SUCCESS, session=session)
    ti1.end_date = now

    ti2.set_state(TaskInstanceState.SUCCESS, session=session)
    ti2.end_date = now + datetime.timedelta(minutes=1)

    ti3.set_state(TaskInstanceState.SUCCESS, session=session)
    ti3.end_date = now + datetime.timedelta(minutes=2)

    # Reattach callback if it was stripped (safeguard against test infra / DagBag side effects)
    dag.on_success_callback = on_success_callable

    # Finalize DAG run and trigger callback
    dagrun.update_state(session=session, execute_callbacks=True)

    # Ensure callback was actually invoked
    assert callback_invoked["flag"], "DAG on_success_callback was not triggered"

def test_dag_callback_context_with_task_metadata_failure(self, dag_maker, session):
    """Test that DAG-level on_failure_callback receives the last failed task instance in the context."""
    # Mutable container so the nested callback can signal that it ran.
    callback_invoked = {"flag": False}

    def on_failure_callable(context):
        callback_invoked["flag"] = True

        # Base context assertions
        # NOTE(review): "execution_date" and "dag_run_url" are asserted here,
        # but handle_dag_callback in this change does not visibly set them —
        # confirm they are injected elsewhere in context construction.
        assert context["dag_run"].dag_id == "test_dag_callback_context_failure"
        assert context["reason"] == "failure"
        assert "dag" in context
        assert "run_id" in context
        assert "execution_date" in context
        assert "data_interval_start" in context
        assert "data_interval_end" in context
        assert "dag_run_url" in context

        # Task-level metadata
        assert "task_instance" in context
        assert "try_number" in context
        assert "max_tries" in context
        assert "log_url" in context
        assert "mark_success_url" in context

        # Verify task instance content: task3 is the only FAILED TI and has
        # the latest end_date, so it must be the TI exposed in the context.
        ti = context["task_instance"]
        assert ti.task_id == "task3"
        assert ti.state == TaskInstanceState.FAILED
        assert context["try_number"] == ti.try_number
        assert context["max_tries"] == ti.max_tries
        assert context["log_url"] == ti.log_url
        assert context["mark_success_url"] == ti.mark_success_url

    # Define DAG with failure callback
    with dag_maker(
        dag_id="test_dag_callback_context_failure",
        schedule=datetime.timedelta(days=1),
        start_date=datetime.datetime(2017, 1, 1),
        on_failure_callback=on_failure_callable,
    ) as dag:
        task1 = EmptyOperator(task_id="task1")
        task2 = EmptyOperator(task_id="task2")
        task3 = EmptyOperator(task_id="task3")
        task1 >> task2 >> task3

    # All tasks start RUNNING; the terminal FAILED state on task3 below
    # drives the run to FAILED and fires the failure callback.
    initial_task_states = {
        "task1": TaskInstanceState.RUNNING,
        "task2": TaskInstanceState.RUNNING,
        "task3": TaskInstanceState.RUNNING,
    }

    session.commit()

    dagrun = self.create_dag_run(
        dag=dag,
        task_states=initial_task_states,
        state=DagRunState.RUNNING,
        session=session,
    )

    # Staggered end_dates: task3 both fails and finishes last.
    now = timezone.utcnow()
    ti1 = dagrun.get_task_instance("task1", session)
    ti2 = dagrun.get_task_instance("task2", session)
    ti3 = dagrun.get_task_instance("task3", session)

    ti1.set_state(TaskInstanceState.SUCCESS, session=session)
    ti1.end_date = now

    ti2.set_state(TaskInstanceState.SUCCESS, session=session)
    ti2.end_date = now + datetime.timedelta(minutes=1)

    ti3.set_state(TaskInstanceState.FAILED, session=session)
    ti3.end_date = now + datetime.timedelta(minutes=2)

    # Reattach callback (required due to potential DAG context loss)
    dag.on_failure_callback = on_failure_callable

    dagrun.update_state(session=session, execute_callbacks=True)
    assert dagrun.state == DagRunState.FAILED
    assert callback_invoked["flag"], "DAG on_failure_callback was not triggered"

@pytest.mark.parametrize("state", State.task_states)
@mock.patch.object(settings, "task_instance_mutation_hook", autospec=True)
def test_task_instance_mutation_hook(self, mock_hook, dag_maker, session, state):
Expand Down
7 changes: 7 additions & 0 deletions task-sdk/src/airflow/sdk/definitions/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ class Context(TypedDict, total=False):
ts_nodash_with_tz: str
var: Any

# --- Added for enriched DAG-level callback context ---
end_date: DateTime | None
dag_run_url: str | None
max_tries: int | None
log_url: str | None
mark_success_url: str | None


def get_current_context() -> Context:
"""
Expand Down
5 changes: 5 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ def get_template_context(self) -> Context:
"value": VariableAccessor(deserialize_json=False),
},
"conn": ConnectionAccessor(),
"dag_run_url": None, # Will be populated in callbacks
"end_date": None, # Will be populated in callbacks
"log_url": None, # Will be populated in callbacks
"mark_success_url": None, # Will be populated in callbacks
"max_tries": None, # Will be populated in callbacks
}
if from_server:
dag_run = from_server.dag_run
Expand Down
Loading