diff --git a/airflow-core/src/airflow/utils/db_cleanup.py b/airflow-core/src/airflow/utils/db_cleanup.py
index 02b0307523251..96e803775f803 100644
--- a/airflow-core/src/airflow/utils/db_cleanup.py
+++ b/airflow-core/src/airflow/utils/db_cleanup.py
@@ -73,6 +73,7 @@ class _TableConfig:
         in the table. to ignore certain records even if they are the latest in the table, you can
         supply additional filters here (e.g. externally triggered dag runs)
     :param keep_last_group_by: if keeping the last record, can keep the last record for each group
+    :param dependent_tables: list of tables which have FK relationship with this table
     """
 
     table_name: str
@@ -81,6 +82,10 @@
     keep_last: bool = False
     keep_last_filters: Any | None = None
     keep_last_group_by: Any | None = None
+    # We explicitly list these tables instead of detecting foreign keys automatically,
+    # because the relationships are unlikely to change and the number of tables is small.
+    # Relying on automation here would increase complexity and reduce maintainability.
+    dependent_tables: list[str] | None = None
 
     def __post_init__(self):
         self.recency_column = column(self.recency_column_name)
@@ -104,7 +109,11 @@ def readable_config(self):
 
 config_list: list[_TableConfig] = [
     _TableConfig(table_name="job", recency_column_name="latest_heartbeat"),
-    _TableConfig(table_name="dag", recency_column_name="last_parsed_time"),
+    _TableConfig(
+        table_name="dag",
+        recency_column_name="last_parsed_time",
+        dependent_tables=["dag_version", "deadline"],
+    ),
     _TableConfig(
         table_name="dag_run",
         recency_column_name="start_date",
@@ -112,12 +121,17 @@
         keep_last=True,
         keep_last_filters=[column("run_type") != DagRunType.MANUAL],
         keep_last_group_by=["dag_id"],
+        dependent_tables=["task_instance", "deadline"],
     ),
     _TableConfig(table_name="asset_event", recency_column_name="timestamp"),
     _TableConfig(table_name="import_error", recency_column_name="timestamp"),
     _TableConfig(table_name="log", recency_column_name="dttm"),
     _TableConfig(table_name="sla_miss", recency_column_name="timestamp"),
-    _TableConfig(table_name="task_instance", recency_column_name="start_date"),
+    _TableConfig(
+        table_name="task_instance",
+        recency_column_name="start_date",
+        dependent_tables=["task_instance_history", "xcom"],
+    ),
     _TableConfig(table_name="task_instance_history", recency_column_name="start_date"),
     _TableConfig(table_name="task_reschedule", recency_column_name="start_date"),
     _TableConfig(table_name="xcom", recency_column_name="timestamp"),
@@ -125,8 +139,16 @@
     _TableConfig(table_name="callback_request", recency_column_name="created_at"),
     _TableConfig(table_name="celery_taskmeta", recency_column_name="date_done"),
     _TableConfig(table_name="celery_tasksetmeta", recency_column_name="date_done"),
-    _TableConfig(table_name="trigger", recency_column_name="created_date"),
-    _TableConfig(table_name="dag_version", recency_column_name="created_at"),
+    _TableConfig(
+        table_name="trigger",
+        recency_column_name="created_date",
+        dependent_tables=["task_instance"],
+    ),
+    _TableConfig(
+        table_name="dag_version",
+        recency_column_name="created_at",
+        dependent_tables=["task_instance", "dag_run"],
+    ),
     _TableConfig(table_name="deadline", recency_column_name="deadline_time"),
 ]
 
@@ -234,6 +256,7 @@ def _do_delete(
         logger.debug("delete statement:\n%s", delete.compile())
         session.execute(delete)
         session.commit()
+
     except BaseException as e:
         raise e
     finally:
@@ -414,17 +437,37 @@ def _suppress_with_logging(table: str, session: Session):
         session.rollback()
 
 
-def _effective_table_names(*, table_names: list[str] | None) -> tuple[set[str], dict[str, _TableConfig]]:
+def _effective_table_names(*, table_names: list[str] | None) -> tuple[list[str], dict[str, _TableConfig]]:
     desired_table_names = set(table_names or config_dict)
-    effective_config_dict = {k: v for k, v in config_dict.items() if k in desired_table_names}
-    effective_table_names = set(effective_config_dict)
-    if desired_table_names != effective_table_names:
-        outliers = desired_table_names - effective_table_names
+
+    outliers = desired_table_names - set(config_dict.keys())
+    if outliers:
         logger.warning(
-            "The following table(s) are not valid choices and will be skipped: %s", sorted(outliers)
+            "The following table(s) are not valid choices and will be skipped: %s",
+            sorted(outliers),
         )
-    if not effective_table_names:
+        desired_table_names = desired_table_names - outliers
+
+    visited: set[str] = set()
+    effective_table_names: list[str] = []
+
+    def collect_deps(table: str):
+        if table in visited:
+            return
+        visited.add(table)
+        config = config_dict[table]
+        for dep in config.dependent_tables or []:
+            collect_deps(dep)
+        effective_table_names.append(table)
+
+    for table_name in desired_table_names:
+        collect_deps(table_name)
+
+    effective_config_dict = {n: config_dict[n] for n in effective_table_names}
+
+    if not effective_config_dict:
         raise SystemExit("No tables selected for db cleanup. Please choose valid table names.")
+
     return effective_table_names, effective_config_dict
 
 
@@ -480,6 +523,8 @@ def run_cleanup(
     :param session: Session representing connection to the metadata database.
     """
     clean_before_timestamp = timezone.coerce_datetime(clean_before_timestamp)
+
+    # Get all tables to clean (root + dependents)
     effective_table_names, effective_config_dict = _effective_table_names(table_names=table_names)
     if dry_run:
         print("Performing dry run for db cleanup.")
@@ -491,6 +536,7 @@
     if not dry_run and confirm:
         _confirm_delete(date=clean_before_timestamp, tables=sorted(effective_table_names))
     existing_tables = reflect_tables(tables=None, session=session).tables
+
     for table_name, table_config in effective_config_dict.items():
         if table_name in existing_tables:
             with _suppress_with_logging(table_name, session):
diff --git a/airflow-core/tests/unit/utils/test_db_cleanup.py b/airflow-core/tests/unit/utils/test_db_cleanup.py
index 00c1cf4efd442..89cfe892f3817 100644
--- a/airflow-core/tests/unit/utils/test_db_cleanup.py
+++ b/airflow-core/tests/unit/utils/test_db_cleanup.py
@@ -26,7 +26,7 @@
 
 import pendulum
 import pytest
-from sqlalchemy import text
+from sqlalchemy import inspect, text
 from sqlalchemy.exc import OperationalError, SQLAlchemyError
 from sqlalchemy.ext.declarative import DeclarativeMeta
@@ -303,6 +303,51 @@ def test__cleanup_table(self, table_name, date_add_kwargs, expected_to_delete, r
         else:
             raise Exception("unexpected")
 
+    @pytest.mark.parametrize(
+        "table_name, expected_archived",
+        [
+            (
+                "dag_run",
+                {"dag_run", "task_instance"},  # Only these are populated
+            ),
+        ],
+    )
+    def test_run_cleanup_archival_integration(self, table_name, expected_archived):
+        """
+        Integration test that verifies:
+        1. Recursive FK-dependent tables are resolved via _effective_table_names().
+        2. run_cleanup() archives only tables with data.
+        3. Archive tables are not created for empty dependent tables.
+        """
+        base_date = pendulum.datetime(2022, 1, 1, tz="UTC")
+        num_tis = 5
+
+        # Create test data for DAG Run and TIs
+        if table_name in {"dag_run", "task_instance"}:
+            create_tis(base_date=base_date, num_tis=num_tis, run_type=DagRunType.MANUAL)
+
+        clean_before_date = base_date.add(days=10)
+
+        with create_session() as session:
+            run_cleanup(
+                clean_before_timestamp=clean_before_date,
+                table_names=[table_name],
+                dry_run=False,
+                confirm=False,
+                session=session,
+            )
+
+            # Inspect archive tables created
+            inspector = inspect(session.bind)
+            archive_tables = {
+                name for name in inspector.get_table_names() if name.startswith(ARCHIVE_TABLE_PREFIX)
+            }
+            actual_archived = {t.split("__", 1)[-1].split("__")[0] for t in archive_tables}
+
+            assert expected_archived <= actual_archived, (
+                f"Expected archive tables not found: {expected_archived - actual_archived}"
+            )
+
     @pytest.mark.parametrize(
         "skip_archive, expected_archives",
         [pytest.param(True, 1, id="skip_archive"), pytest.param(False, 2, id="do_archive")],