From 589a4e6a0509ef29c9d470d82d4d75b3e0b1b3c2 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 18 Sep 2025 18:30:11 +0100 Subject: [PATCH 1/2] Optimize DAG serialization by excluding schema default values This change reduces serialized DAG size by automatically excluding fields that match their schema default values, similar to how operator serialization works. Fields like `catchup=False`, `max_active_runs=16`, and `fail_fast=False` are no longer stored when they have default values. Follow-up of https://github.com/apache/airflow/pull/54569 --- .../src/airflow/serialization/schema.json | 18 +-- .../serialization/serialized_objects.py | 77 +++++++---- .../serialization/test_dag_serialization.py | 109 ++++++++++++--- .../in_container/run_schema_defaults_check.py | 125 +++++++++++++++--- 4 files changed, 261 insertions(+), 68 deletions(-) diff --git a/airflow-core/src/airflow/serialization/schema.json b/airflow-core/src/airflow/serialization/schema.json index c4740f346f3f8..7707494fce546 100644 --- a/airflow-core/src/airflow/serialization/schema.json +++ b/airflow-core/src/airflow/serialization/schema.json @@ -175,8 +175,8 @@ "value": { "$ref": "#/definitions/dict" } } }, - "catchup": { "type": "boolean" }, - "fail_fast": { "type": "boolean" }, + "catchup": { "type": "boolean", "default": false }, + "fail_fast": { "type": "boolean", "default": false }, "fileloc": { "type" : "string"}, "relative_fileloc": { "type" : "string"}, "_processor_dags_folder": { @@ -198,9 +198,9 @@ ] }, "_concurrency": { "type" : "number"}, - "max_active_tasks": { "type" : "number"}, - "max_active_runs": { "type" : "number"}, - "max_consecutive_failed_dag_runs": { "type" : "number"}, + "max_active_tasks": { "type" : "number", "default": 16}, + "max_active_runs": { "type" : "number", "default": 16}, + "max_consecutive_failed_dag_runs": { "type" : "number", "default": 0}, "default_args": { "$ref": "#/definitions/dict" }, "start_date": { "$ref": "#/definitions/datetime" }, "end_date": { "$ref": "#/definitions/datetime" }, @@ -208,9 +208,9 @@ "doc_md": { "type" : "string"}, "access_control": {"$ref": "#/definitions/dict" }, "is_paused_upon_creation": { "type": "boolean" }, - "has_on_success_callback": { "type": "boolean" }, - "has_on_failure_callback": { "type": "boolean" }, - "render_template_as_native_obj": { "type": "boolean" }, + "has_on_success_callback": { "type": "boolean", "default": false }, + "has_on_failure_callback": { "type": "boolean", "default": false }, + "render_template_as_native_obj": { "type": "boolean", "default": false }, "tags": { "type": "array" }, "task_group": {"anyOf": [ { "type": "null" }, @@ -218,7 +218,7 @@ ]}, "edge_info": { "$ref": "#/definitions/edge_info" }, "dag_dependencies": { "$ref": "#/definitions/dag_dependencies" }, - "disable_bundle_versioning": {"type": "boolean"} + "disable_bundle_versioning": {"type": "boolean", "default": false } }, "required": [ "dag_id", diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index ad85dcc380642..8eecae3170285 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -1093,21 +1093,6 @@ def _deserialize_params_dict(cls, encoded_params: list[tuple[str, dict]]) -> Par return ParamsDict(op_params) - @classmethod - def get_operator_optional_fields_from_schema(cls) -> set[str]: - schema_loader = cls._json_schema - - if schema_loader is None: - return set() - - schema_data = schema_loader.schema - operator_def = schema_data.get("definitions", {}).get("operator", {}) - operator_fields = set(operator_def.get("properties", {}).keys()) - required_fields = set(operator_def.get("required", [])) - - optional_fields = operator_fields - required_fields - return optional_fields - @classmethod def get_schema_defaults(cls, object_type: str) -> dict[str, Any]: """ @@ -1713,6 +1698,21 @@ def set_task_dag_references(task: SerializedOperator | MappedOperator, dag: Seri # Bypass set_upstream etc here - it does more than we want dag.task_dict[task_id].upstream_task_ids.add(task.task_id) + @classmethod + def get_operator_optional_fields_from_schema(cls) -> set[str]: + schema_loader = cls._json_schema + + if schema_loader is None: + return set() + + schema_data = schema_loader.schema + operator_def = schema_data.get("definitions", {}).get("operator", {}) + operator_fields = set(operator_def.get("properties", {}).keys()) + required_fields = set(operator_def.get("required", [])) + + optional_fields = operator_fields - required_fields + return optional_fields + @classmethod def deserialize_operator( cls, @@ -1814,7 +1814,7 @@ def detect_dependencies(cls, op: SdkOperator) -> set[DagDependency]: return deps @classmethod - def _matches_client_defaults(cls, var: Any, attrname: str, op: DAGNode) -> bool: + def _matches_client_defaults(cls, var: Any, attrname: str) -> bool: """ Check if a field value matches client_defaults and should be excluded. @@ -1823,7 +1823,6 @@ def _matches_client_defaults(cls, var: Any, attrname: str, op: DAGNode) -> bool: :param var: The value to check :param attrname: The attribute name - :param op: The operator instance :return: True if value matches client_defaults and should be excluded """ try: @@ -1851,7 +1850,7 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode): :return: True if a variable is excluded, False otherwise. """ # Check if value matches client_defaults (hierarchical defaults optimization) - if cls._matches_client_defaults(var, attrname, op): + if cls._matches_client_defaults(var, attrname): return True schema_defaults = cls.get_schema_defaults("operator") @@ -2384,13 +2383,13 @@ class SerializedDAG(BaseSerialization): _processor_dags_folder: str def __init__(self, *, dag_id: str) -> None: - self.catchup = airflow_conf.getboolean("scheduler", "catchup_by_default") + self.catchup = False # Schema default self.dag_id = self.dag_display_name = dag_id self.dagrun_timeout = None self.deadline = None self.default_args = {} self.description = None - self.disable_bundle_versioning = airflow_conf.getboolean("dag_processor", "disable_bundle_versioning") + self.disable_bundle_versioning = False self.doc_md = None self.edge_info = {} self.end_date = None @@ -2398,11 +2397,9 @@ def __init__(self, *, dag_id: str) -> None: self.has_on_failure_callback = False self.has_on_success_callback = False self.is_paused_upon_creation = None - self.max_active_runs = airflow_conf.getint("core", "max_active_runs_per_dag") - self.max_active_tasks = airflow_conf.getint("core", "max_active_tasks_per_dag") - self.max_consecutive_failed_dag_runs = airflow_conf.getint( - "core", "max_consecutive_failed_dag_runs_per_dag" - ) + self.max_active_runs = 16 # Schema default + self.max_active_tasks = 16 # Schema default + self.max_consecutive_failed_dag_runs = 0 # Schema default self.owner_links = {} self.params = ParamsDict() self.partial = False @@ -2624,8 +2621,37 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode): return False if attrname == "dag_display_name" and var == op.dag_id: return True + + # DAG schema defaults exclusion (same pattern as SerializedBaseOperator) + dag_schema_defaults = cls.get_schema_defaults("dag") + if attrname in dag_schema_defaults: + if dag_schema_defaults[attrname] == var: + return True + + optional_fields = cls.get_dag_optional_fields_from_schema() + if var is None: + return True + if attrname in optional_fields: + if var in [[], (), set(), {}]: + return True + return super()._is_excluded(var, attrname, op) + @classmethod + def get_dag_optional_fields_from_schema(cls) -> set[str]: + schema_loader = cls._json_schema + + if schema_loader is None: + return set() + + schema_data = schema_loader.schema + operator_def = schema_data.get("definitions", {}).get("dag", {}) + operator_fields = set(operator_def.get("properties", {}).keys()) + required_fields = set(operator_def.get("required", [])) + + optional_fields = operator_fields - required_fields + return optional_fields + @classmethod def to_dict(cls, var: Any) -> dict: """Stringifies DAGs and operators contained by var and returns a dict of var.""" @@ -3798,6 +3824,7 @@ class LazyDeserializedDAG(pydantic.BaseModel): "dag_display_name", "has_on_success_callback", "has_on_failure_callback", + "tags", # Attr properties that are nullable, or have a default that loads from config "description", "start_date", diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index 4ed84c32fbe4c..156e1feb9a4d9 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -152,13 +152,8 @@ "downstream_task_ids": [], }, "is_paused_upon_creation": False, - "max_active_runs": 16, - "max_active_tasks": 16, - "max_consecutive_failed_dag_runs": 0, "dag_id": "simple_dag", "deadline": None, - "catchup": False, - "disable_bundle_versioning": False, "doc_md": "### DAG Tutorial Documentation", "fileloc": None, "_processor_dags_folder": ( @@ -269,7 +264,6 @@ }, ], "params": [], - "tags": [], }, } @@ -2163,7 +2157,7 @@ def test_dag_on_failure_callback_roundtrip(self, passed_failure_callback, expect (True, "False", True), (False, "True", False), (False, "False", False), - (None, "True", True), + (None, "True", False), (None, "False", False), ], ) @@ -3299,17 +3293,34 @@ def test_handle_v1_serdag(): SerializedDAG.conversion_v1_to_v2(v1) SerializedDAG.conversion_v2_to_v3(v1) - # Update a few subtle differences - v1["dag"]["tags"] = [] - v1["dag"]["catchup"] = False - v1["dag"]["disable_bundle_versioning"] = False + dag = SerializedDAG.from_dict(v1) - expected = copy.deepcopy(serialized_simple_dag_ground_truth) - expected["dag"]["dag_dependencies"] = expected_dag_dependencies - del expected["dag"]["tasks"][1]["__var"]["_operator_extra_links"] + expected_sdag = copy.deepcopy(serialized_simple_dag_ground_truth) + expected = SerializedDAG.from_dict(expected_sdag) + + fields_to_verify = set(vars(expected).keys()) - { + "task_group", # Tested separately + "dag_dependencies", # Tested separately + "last_loaded", # Dynamically set to utcnow + } + + for f in fields_to_verify: + dag_value = getattr(dag, f) + expected_value = getattr(expected, f) + + assert dag_value == expected_value, ( + f"V2 DAG field '{f}' differs from V3: V2={dag_value!r} != V3={expected_value!r}" + ) + + for f in set(vars(expected.task_group).keys()) - {"dag"}: + dag_tg_value = getattr(dag.task_group, f) + expected_tg_value = getattr(expected.task_group, f) + + assert dag_tg_value == expected_tg_value, ( + f"V2 task_group field '{f}' differs: V2={dag_tg_value!r} != V3={expected_tg_value!r}" + ) - del expected["client_defaults"] - assert v1 == expected + assert getattr(dag, "dag_dependencies") == expected_dag_dependencies def test_handle_v2_serdag(): @@ -3514,6 +3525,72 @@ def test_handle_v2_serdag(): ) +def test_dag_schema_defaults_optimization(): + """Test that DAG fields matching schema defaults are excluded from serialization.""" + + # Create DAG with all schema default values + dag_with_defaults = DAG( + dag_id="test_defaults_dag", + start_date=datetime(2023, 1, 1), + # These should match schema defaults and be excluded + catchup=False, + fail_fast=False, + max_active_runs=16, + max_active_tasks=16, + max_consecutive_failed_dag_runs=0, + render_template_as_native_obj=False, + disable_bundle_versioning=False, + # These should be excluded as None + description=None, + doc_md=None, + ) + + # Serialize and check exclusions + serialized = SerializedDAG.to_dict(dag_with_defaults) + dag_data = serialized["dag"] + + # Schema default fields should be excluded + for field in SerializedDAG.get_schema_defaults("dag").keys(): + assert field not in dag_data, f"Schema default field '{field}' should be excluded" + + # None fields should also be excluded + none_fields = ["description", "doc_md"] + for field in none_fields: + assert field not in dag_data, f"None field '{field}' should be excluded" + + # Test deserialization restores defaults correctly + deserialized_dag = SerializedDAG.from_dict(serialized) + + # Verify schema defaults are restored + assert deserialized_dag.catchup is False + assert deserialized_dag.fail_fast is False + assert deserialized_dag.max_active_runs == 16 + assert deserialized_dag.max_active_tasks == 16 + assert deserialized_dag.max_consecutive_failed_dag_runs == 0 + assert deserialized_dag.render_template_as_native_obj is False + assert deserialized_dag.disable_bundle_versioning is False + + # Test with non-default values (should be included) + dag_non_defaults = DAG( + dag_id="test_non_defaults_dag", + start_date=datetime(2023, 1, 1), + catchup=True, # Non-default + max_active_runs=32, # Non-default + description="Test description", # Non-None + ) + + serialized_non_defaults = SerializedDAG.to_dict(dag_non_defaults) + dag_non_defaults_data = serialized_non_defaults["dag"] + + # Non-default values should be included + assert "catchup" in dag_non_defaults_data + assert dag_non_defaults_data["catchup"] is True + assert "max_active_runs" in dag_non_defaults_data + assert dag_non_defaults_data["max_active_runs"] == 32 + assert "description" in dag_non_defaults_data + assert dag_non_defaults_data["description"] == "Test description" + + def test_email_optimization_removes_email_attrs_when_email_empty(): """Test that email_on_failure and email_on_retry are removed when email is empty.""" with DAG(dag_id="test_email_optimization") as dag: diff --git a/scripts/in_container/run_schema_defaults_check.py b/scripts/in_container/run_schema_defaults_check.py index ef672c297ca0c..bc7c1e2844cb3 100755 --- a/scripts/in_container/run_schema_defaults_check.py +++ b/scripts/in_container/run_schema_defaults_check.py @@ -32,7 +32,7 @@ from typing import Any -def load_schema_defaults() -> dict[str, Any]: +def load_schema_defaults(object_type: str = "operator") -> dict[str, Any]: """Load default values from the JSON schema.""" schema_path = Path("airflow-core/src/airflow/serialization/schema.json") @@ -43,9 +43,9 @@ def load_schema_defaults() -> dict[str, Any]: with open(schema_path) as f: schema = json.load(f) - # Extract defaults from the operator definition - operator_def = schema.get("definitions", {}).get("operator", {}) - properties = operator_def.get("properties", {}) + # Extract defaults from the specified object type definition + object_def = schema.get("definitions", {}).get(object_type, {}) + properties = object_def.get("properties", {}) defaults = {} for field_name, field_def in properties.items(): @@ -55,7 +55,7 @@ def load_schema_defaults() -> dict[str, Any]: return defaults -def get_server_side_defaults() -> dict[str, Any]: +def get_server_side_operator_defaults() -> dict[str, Any]: """Get default values from server-side SerializedBaseOperator class.""" try: from airflow.serialization.serialized_objects import SerializedBaseOperator @@ -92,14 +92,46 @@ def get_server_side_defaults() -> dict[str, Any]: sys.exit(1) -def compare_defaults() -> list[str]: - """Compare schema defaults with server-side defaults and return discrepancies.""" - schema_defaults = load_schema_defaults() - server_defaults = get_server_side_defaults() +def get_server_side_dag_defaults() -> dict[str, Any]: + """Get default values from server-side SerializedDAG class.""" + try: + from airflow.serialization.serialized_objects import SerializedDAG + + # DAG defaults are set in __init__, so we create a temporary instance + temp_dag = SerializedDAG(dag_id="temp") + + # Get all serializable DAG fields from the server-side class + serialized_fields = SerializedDAG.get_serialized_fields() + + server_defaults = {} + for field_name in serialized_fields: + if hasattr(temp_dag, field_name): + default_value = getattr(temp_dag, field_name) + # Only include actual default values that are not None, callables, or descriptors + if not callable(default_value) and not isinstance(default_value, (property, type)): + if isinstance(default_value, (set, tuple)): + # Convert to list since schema.json is pure JSON + default_value = list(default_value) + server_defaults[field_name] = default_value + + return server_defaults + + except ImportError as e: + print(f"Error importing SerializedDAG: {e}") + sys.exit(1) + except Exception as e: + print(f"Error getting server-side DAG defaults: {e}") + sys.exit(1) + + +def compare_operator_defaults() -> list[str]: + """Compare operator schema defaults with server-side defaults and return discrepancies.""" + schema_defaults = load_schema_defaults("operator") + server_defaults = get_server_side_operator_defaults() errors = [] - print(f"Found {len(schema_defaults)} schema defaults") - print(f"Found {len(server_defaults)} server-side defaults") + print(f"Found {len(schema_defaults)} operator schema defaults") + print(f"Found {len(server_defaults)} operator server-side defaults") # Check each server default against schema for field_name, server_value in server_defaults.items(): @@ -141,25 +173,82 @@ def compare_defaults() -> list[str]: return errors +def compare_dag_defaults() -> list[str]: + """Compare DAG schema defaults with server-side defaults and return discrepancies.""" + schema_defaults = load_schema_defaults("dag") + server_defaults = get_server_side_dag_defaults() + errors = [] + + print(f"Found {len(schema_defaults)} DAG schema defaults") + print(f"Found {len(server_defaults)} DAG server-side defaults") + + # Check each server default against schema + for field_name, server_value in server_defaults.items(): + schema_value = schema_defaults.get(field_name) + + # Check if field exists in schema + if field_name not in schema_defaults: + # Some server fields don't need defaults in schema (like None values, empty collections, or computed fields) + if ( + server_value is not None + and server_value not in [[], {}, (), set()] + and field_name not in ["dag_id", "dag_display_name"] + ): + errors.append( + f"DAG server field '{field_name}' has default {server_value!r} but no schema default" + ) + continue + + # Direct comparison + if schema_value != server_value: + errors.append( + f"DAG field '{field_name}': schema default is {schema_value!r}, " + f"server default is {server_value!r}" + ) + + # Check for schema defaults that don't have corresponding server defaults + for field_name, schema_value in schema_defaults.items(): + if field_name not in server_defaults: + # Some schema fields are computed properties (like has_on_*_callback) + computed_properties = { + "has_on_success_callback", + "has_on_failure_callback", + } + if field_name not in computed_properties: + errors.append( + f"DAG schema has default for '{field_name}' = {schema_value!r} but no corresponding server default" + ) + + return errors + + def main(): """Main function to run the schema defaults check.""" - print("Checking schema defaults against server-side SerializedBaseOperator...") + print("Checking schema defaults against server-side serialization classes...") + + # Check Operator defaults + print("\n1. Checking Operator defaults...") + operator_errors = compare_operator_defaults() + + # Check Dag defaults + print("\n2. Checking Dag defaults...") + dag_errors = compare_dag_defaults() - errors = compare_defaults() + all_errors = operator_errors + dag_errors - if errors: - print("❌ Found discrepancies between schema and server defaults:") - for error in errors: + if all_errors: + print("\n❌ Found discrepancies between schema and server defaults:") + for error in all_errors: print(f" • {error}") print() print("To fix these issues:") print("1. Update airflow-core/src/airflow/serialization/schema.json to match server defaults, OR") print( - "2. Update airflow-core/src/airflow/serialization/serialized_objects.py class defaults to match schema" + "2. Update airflow-core/src/airflow/serialization/serialized_objects.py class/init defaults to match schema" ) sys.exit(1) else: - print("✅ All schema defaults match server-side defaults!") + print("\n✅ All schema defaults match server-side defaults!") if __name__ == "__main__": From a8fd7ea9ca9a513e2999444f1dea01faec101c8e Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 18 Sep 2025 18:54:09 +0100 Subject: [PATCH 2/2] fixup! Optimize DAG serialization by excluding schema default values --- airflow-core/src/airflow/serialization/serialized_objects.py | 3 +++ .../tests/unit/serialization/test_dag_serialization.py | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index 8eecae3170285..d7a4007e5888e 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -1094,6 +1094,7 @@ def _deserialize_params_dict(cls, encoded_params: list[tuple[str, dict]]) -> Par return ParamsDict(op_params) @classmethod + @lru_cache(maxsize=4) # Cache for "operator", "dag", and a few others def get_schema_defaults(cls, object_type: str) -> dict[str, Any]: """ Extract default values from JSON schema for any object type. @@ -1699,6 +1700,7 @@ def set_task_dag_references(task: SerializedOperator | MappedOperator, dag: Seri dag.task_dict[task_id].upstream_task_ids.add(task.task_id) @classmethod + @lru_cache(maxsize=1) # Only one type: "operator" def get_operator_optional_fields_from_schema(cls) -> set[str]: schema_loader = cls._json_schema @@ -2638,6 +2640,7 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode): return super()._is_excluded(var, attrname, op) @classmethod + @lru_cache(maxsize=1) # Only one type: "dag" def get_dag_optional_fields_from_schema(cls) -> set[str]: schema_loader = cls._json_schema diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index 156e1feb9a4d9..102b092770f17 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -2157,7 +2157,7 @@ def test_dag_on_failure_callback_roundtrip(self, passed_failure_callback, expect (True, "False", True), (False, "True", False), (False, "False", False), - (None, "True", False), + (None, "True", True), (None, "False", False), ], ) @@ -2171,7 +2171,8 @@ def test_dag_disable_bundle_versioning_roundtrip(self, dag_arg, conf_arg, expect """ with conf_vars({("dag_processor", "disable_bundle_versioning"): conf_arg}): kwargs = {} - kwargs["disable_bundle_versioning"] = dag_arg + if dag_arg is not None: + kwargs["disable_bundle_versioning"] = dag_arg dag = DAG( dag_id="test_dag_disable_bundle_versioning_roundtrip", schedule=None,