diff --git a/airflow-core/src/airflow/dag_processing/manager.py b/airflow-core/src/airflow/dag_processing/manager.py index b9b2cfcdee8fa..4926f26acf831 100644 --- a/airflow-core/src/airflow/dag_processing/manager.py +++ b/airflow-core/src/airflow/dag_processing/manager.py @@ -48,6 +48,7 @@ from airflow._shared.timezones import timezone from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI from airflow.configuration import conf +from airflow.dag_processing.bundles.base import BundleUsageTrackingManager from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.dag_processing.collection import update_dag_parsing_results_in_db from airflow.dag_processing.processor import DagFileParsingResult, DagFileProcessorProcess @@ -175,6 +176,9 @@ class DagFileProcessorManager(LoggingMixin): parsing_cleanup_interval: float = attrs.field( factory=_config_int_factory("scheduler", "parsing_cleanup_interval") ) + stale_bundle_cleanup_interval: float = attrs.field( + factory=_config_int_factory("dag_processor", "stale_bundle_cleanup_interval") + ) _file_process_interval: float = attrs.field( factory=_config_int_factory("dag_processor", "min_file_process_interval") ) @@ -183,6 +187,7 @@ class DagFileProcessorManager(LoggingMixin): ) _last_deactivate_stale_dags_time: float = attrs.field(default=0, init=False) + _last_stale_bundle_cleanup_time: float = attrs.field(default=0, init=False) print_stats_interval: float = attrs.field( factory=_config_int_factory("dag_processor", "print_stats_interval") ) @@ -299,6 +304,20 @@ def _scan_stale_dags(self): self.deactivate_stale_dags(last_parsed=last_parsed) self._last_deactivate_stale_dags_time = time.monotonic() + def _cleanup_stale_bundle_versions(self): + if self.stale_bundle_cleanup_interval <= 0: + return + now = time.monotonic() + elapsed_time_since_cleanup = now - self._last_stale_bundle_cleanup_time + if elapsed_time_since_cleanup < self.stale_bundle_cleanup_interval: + return + try: + BundleUsageTrackingManager().remove_stale_bundle_versions() + except Exception: + self.log.exception("Error removing stale bundle versions") + finally: + self._last_stale_bundle_cleanup_time = now + @provide_session def deactivate_stale_dags( self, @@ -374,6 +393,7 @@ def _run_parsing_loop(self): for callback in self._fetch_callbacks(): self._add_callback_to_queue(callback) self._scan_stale_dags() + self._cleanup_stale_bundle_versions() DagWarning.purge_inactive_dag_warnings() # Update number of loop iteration. @@ -491,6 +511,16 @@ def _add_callback_to_queue(self, request: CallbackRequest): # Bundle no longer configured self.log.error("Bundle %s no longer configured, skipping callback", request.bundle_name) return None + if bundle.supports_versioning and request.bundle_version: + try: + bundle.initialize() + except Exception: + self.log.exception( + "Error initializing bundle %s version %s for callback, skipping", + request.bundle_name, + request.bundle_version, + ) + return None file_info = DagFileInfo( rel_path=Path(request.filepath), diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 6492e4bab777b..b3b31765ee61d 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -35,6 +35,7 @@ TaskCallbackRequest, ) from airflow.configuration import conf +from airflow.dag_processing.bundles.base import BundleVersionLock from airflow.dag_processing.dagbag import BundleDagBag, DagBag from airflow.sdk.exceptions import TaskNotFound from airflow.sdk.execution_time.comms import ( @@ -287,12 +288,16 @@ def _execute_callbacks( ) -> None: for request in callback_requests: log.debug("Processing Callback Request", request=request.to_json()) - if isinstance(request, TaskCallbackRequest): - _execute_task_callbacks(dagbag, request, log) - elif isinstance(request, DagCallbackRequest): - _execute_dag_callbacks(dagbag, request, log) - elif isinstance(request, EmailRequest): - _execute_email_callbacks(dagbag, request, log) + with BundleVersionLock( + bundle_name=request.bundle_name, + bundle_version=request.bundle_version, + ): + if isinstance(request, TaskCallbackRequest): + _execute_task_callbacks(dagbag, request, log) + elif isinstance(request, DagCallbackRequest): + _execute_dag_callbacks(dagbag, request, log) + elif isinstance(request, EmailRequest): + _execute_email_callbacks(dagbag, request, log) def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: FilteringBoundLogger) -> None: diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py index 83d10d79e7191..27f92e423996a 100644 --- a/airflow-core/tests/unit/dag_processing/test_manager.py +++ b/airflow-core/tests/unit/dag_processing/test_manager.py @@ -593,6 +593,20 @@ def test_scan_stale_dags(self, session): # SerializedDagModel gives history about Dags assert serialized_dag_count == 1 + @mock.patch("airflow.dag_processing.manager.BundleUsageTrackingManager") + def test_cleanup_stale_bundle_versions_interval(self, mock_bundle_manager): + manager = DagFileProcessorManager(max_runs=1) + manager.stale_bundle_cleanup_interval = 10 + + manager._last_stale_bundle_cleanup_time = time.monotonic() - 20 + manager._cleanup_stale_bundle_versions() + mock_bundle_manager.return_value.remove_stale_bundle_versions.assert_called_once() + + mock_bundle_manager.return_value.remove_stale_bundle_versions.reset_mock() + manager._last_stale_bundle_cleanup_time = time.monotonic() + manager._cleanup_stale_bundle_versions() + mock_bundle_manager.return_value.remove_stale_bundle_versions.assert_not_called() + def test_kill_timed_out_processors_kill(self): manager = DagFileProcessorManager(max_runs=1, processor_timeout=5) # Set start_time to ensure timeout occurs: start_time = current_time - (timeout + 1) = always (timeout + 1) seconds @@ -1138,6 +1152,51 @@ def test_callback_queue(self, mock_get_logger, configure_testing_dag_bundle): assert dag1_path not in manager._callback_to_execute assert dag2_path not in manager._callback_to_execute + @mock.patch("airflow.dag_processing.manager.DagBundlesManager") + def test_add_callback_initializes_versioned_bundle(self, mock_bundle_manager): + manager = DagFileProcessorManager(max_runs=1) + bundle = MagicMock() + bundle.supports_versioning = True + bundle.path = Path("/tmp/bundle") + mock_bundle_manager.return_value.get_bundle.return_value = bundle + + request = DagCallbackRequest( + filepath="file1.py", + dag_id="dag1", + run_id="run1", + is_failure_callback=False, + bundle_name="testing", + bundle_version="some_commit_hash", + msg=None, + ) + + manager._add_callback_to_queue(request) + + bundle.initialize.assert_called_once() + + @mock.patch("airflow.dag_processing.manager.DagBundlesManager") + def test_add_callback_skips_when_bundle_init_fails(self, mock_bundle_manager): + manager = DagFileProcessorManager(max_runs=1) + bundle = MagicMock() + bundle.supports_versioning = True + bundle.initialize.side_effect = Exception("clone failed") + mock_bundle_manager.return_value.get_bundle.return_value = bundle + + request = DagCallbackRequest( + filepath="file1.py", + dag_id="dag1", + run_id="run1", + is_failure_callback=False, + bundle_name="testing", + bundle_version="some_commit_hash", + msg=None, + ) + + manager._add_callback_to_queue(request) + + bundle.initialize.assert_called_once() + assert len(manager._callback_to_execute) == 0 + def test_dag_with_assets(self, session, configure_testing_dag_bundle): """'Integration' test to ensure that the assets get parsed and stored correctly for parsed dags.""" test_dag_path = str(TEST_DAG_FOLDER / "test_assets.py") diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 960b8440902de..f41b5cbeba42a 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -56,6 +56,7 @@ DagFileProcessorProcess, ToDagProcessor, ToManager, + _execute_callbacks, _execute_dag_callbacks, _execute_email_callbacks, _execute_task_callbacks, @@ -666,6 +667,34 @@ def test_import_error_updates_timestamps(session): assert stat.import_errors == 1 +class TestExecuteCallbacks: + def test_execute_callbacks_locks_bundle_version(self): + callbacks = [ + DagCallbackRequest( + filepath="test.py", + dag_id="test_dag", + run_id="test_run", + bundle_name="testing", + bundle_version="some_commit_hash", + is_failure_callback=False, + msg=None, + ) + ] + log = MagicMock(spec=FilteringBoundLogger) + dagbag = MagicMock(spec=DagBag) + + with ( + patch("airflow.dag_processing.processor.BundleVersionLock") as mock_lock, + patch("airflow.dag_processing.processor._execute_dag_callbacks") as mock_execute, + ): + _execute_callbacks(dagbag, callbacks, log) + + mock_lock.assert_called_once_with(bundle_name="testing", bundle_version="some_commit_hash") + mock_lock.return_value.__enter__.assert_called_once() + mock_lock.return_value.__exit__.assert_called_once() + mock_execute.assert_called_once_with(dagbag, callbacks[0], log) + + class TestExecuteDagCallbacks: """Test the _execute_dag_callbacks function with context_from_server"""