diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index 680816a654555..d43a2375ef996 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -2507,6 +2507,23 @@ def serialize_dag(cls, dag: DAG) -> dict: serialized_dag["has_on_success_callback"] = True if dag.has_on_failure_callback: serialized_dag["has_on_failure_callback"] = True + + # TODO: Move this logic to a better place -- ideally before serializing contents of default_args. + # There is some duplication with this and SerializedBaseOperator.partial_kwargs serialization. + # Ideally default_args goes through same logic as fields of SerializedBaseOperator. + if serialized_dag.get("default_args", {}): + default_args_dict = serialized_dag["default_args"][Encoding.VAR] + callbacks_to_remove = [] + for k, v in list(default_args_dict.items()): + if k in [ + f"on_{x}_callback" for x in ("execute", "failure", "success", "retry", "skipped") + ]: + if bool(v): + default_args_dict[f"has_{k}"] = True + callbacks_to_remove.append(k) + for k in callbacks_to_remove: + del default_args_dict[k] + return serialized_dag except SerializationError: raise diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index 40864f18ee795..c06e3a47ed2a7 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -4292,3 +4292,65 @@ def test_partial_kwargs_end_to_end_deserialization(self): assert "owner" in deserialized_task.partial_kwargs assert deserialized_task.partial_kwargs["retry_delay"] == timedelta(seconds=600) assert deserialized_task.partial_kwargs["owner"] == "custom_owner" + + +@pytest.mark.parametrize( + ["callbacks", "expected_has_flags", "absent_keys"], + [ + pytest.param( + { + "on_failure_callback": lambda ctx: None, + "on_success_callback": lambda ctx: None, + "on_retry_callback": lambda ctx: None, + }, + ["has_on_failure_callback", "has_on_success_callback", "has_on_retry_callback"], + ["on_failure_callback", "on_success_callback", "on_retry_callback"], + id="multiple_callbacks", + ), + pytest.param( + {"on_failure_callback": lambda ctx: None}, + ["has_on_failure_callback"], + ["on_failure_callback", "has_on_success_callback", "on_success_callback"], + id="single_callback", + ), + pytest.param( + {"on_failure_callback": lambda ctx: None, "on_execute_callback": None}, + ["has_on_failure_callback"], + ["on_failure_callback", "has_on_execute_callback", "on_execute_callback"], + id="callback_with_none", + ), + pytest.param( + {}, + [], + [ + "has_on_execute_callback", + "has_on_failure_callback", + "has_on_success_callback", + "has_on_retry_callback", + "has_on_skipped_callback", + ], + id="no_callbacks", + ), + ], +) +def test_dag_default_args_callbacks_serialization(callbacks, expected_has_flags, absent_keys): + """Test callbacks in DAG default_args are serialized as boolean flags.""" + default_args = {"owner": "test_owner", "retries": 2, **callbacks} + + with DAG(dag_id="test_default_args_callbacks", default_args=default_args) as dag: + BashOperator(task_id="task1", bash_command="echo 1", dag=dag) + + serialized_dag_dict = SerializedDAG.serialize_dag(dag) + default_args_dict = serialized_dag_dict["default_args"][Encoding.VAR] + + for flag in expected_has_flags: + assert default_args_dict.get(flag) is True + + for key in absent_keys: + assert key not in default_args_dict + + assert default_args_dict["owner"] == "test_owner" + assert default_args_dict["retries"] == 2 + + deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag_dict) + assert deserialized_dag.dag_id == "test_default_args_callbacks"