diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 92134fa4bf1cc..503084e13a9db 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -2199,16 +2199,6 @@ scheduler: type: integer default: "20" see_also: ":ref:`scheduler:ha:tunables`" - parsing_pre_import_modules: - description: | - The scheduler reads dag files to extract the airflow modules that are going to be used, - and imports them ahead of time to avoid having to re-do it for each parsing process. - This flag can be set to ``False`` to disable this behavior in case an airflow module needs - to be freshly imported each time (at the cost of increased DAG parsing time). - version_added: 2.6.0 - type: boolean - example: ~ - default: "True" dag_stale_not_seen_duration: description: | Time in seconds after which dags, which were not updated by Dag Processor are deactivated. @@ -2563,3 +2553,13 @@ dag_processor: type: integer example: ~ default: "10" + parsing_pre_import_modules: + description: | + The dag_processor reads dag files to extract the airflow modules that are going to be used, + and imports them ahead of time to avoid having to re-do it for each parsing process. + This flag can be set to ``False`` to disable this behavior in case an airflow module needs + to be freshly imported each time (at the cost of increased DAG parsing time). 
+ version_added: 3.0.3 + type: boolean + example: ~ + default: "True" diff --git a/airflow-core/src/airflow/configuration.py b/airflow-core/src/airflow/configuration.py index 2bf3df9816387..5fd67f5334341 100644 --- a/airflow-core/src/airflow/configuration.py +++ b/airflow-core/src/airflow/configuration.py @@ -364,6 +364,7 @@ def sensitive_config_values(self) -> set[tuple[str, str]]: ("fab", "navbar_text_hover_color"): ("webserver", "navbar_text_hover_color", "3.0.2"), ("api", "secret_key"): ("webserver", "secret_key", "3.0.2"), ("api", "enable_swagger_ui"): ("webserver", "enable_swagger_ui", "3.0.2"), + ("dag_processor", "parsing_pre_import_modules"): ("scheduler", "parsing_pre_import_modules", "3.0.3"), } # A mapping of new section -> (old section, since_version). diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 011393f22c886..a54c56ddbe27e 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import importlib import os import sys import traceback @@ -45,6 +46,7 @@ from airflow.sdk.execution_time.supervisor import WatchedSubprocess from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG from airflow.stats import Stats +from airflow.utils.file import iter_airflow_imports if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger @@ -98,6 +100,27 @@ class DagFileParsingResult(BaseModel): ] +def _pre_import_airflow_modules(file_path: str, log: FilteringBoundLogger) -> None: + """ + Pre-import Airflow modules found in the given file. + + This prevents modules from being re-imported in each processing process, + saving CPU time and memory. 
+ (The default value of "parsing_pre_import_modules" is set to True) + + :param file_path: Path to the file to scan for imports + :param log: Logger instance to use for warnings + """ + if not conf.getboolean("dag_processor", "parsing_pre_import_modules", fallback=True): + return + + for module in iter_airflow_imports(file_path): + try: + importlib.import_module(module) + except ModuleNotFoundError as e: + log.warning("Error when trying to pre-import module '%s' found in %s: %s", module, file_path, e) + + def _parse_file_entrypoint(): import structlog @@ -127,6 +150,7 @@ def _parse_file_entrypoint(): def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult | None: # TODO: Set known_pool names on DagBag! + bag = DagBag( dag_folder=msg.file, bundle_path=msg.bundle_path, @@ -250,6 +274,10 @@ def start( # type: ignore[override] client: Client, **kwargs, ) -> Self: + logger = kwargs["logger"] + + _pre_import_airflow_modules(os.fspath(path), logger) + proc: Self = super().start(target=target, client=client, **kwargs) proc._on_child_started(callbacks, path, bundle_path) return proc diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 8d77da61cafb9..e8ab15a02227c 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -37,6 +37,7 @@ DagFileParsingResult, DagFileProcessorProcess, _parse_file, + _pre_import_airflow_modules, ) from airflow.models import DagBag, TaskInstance from airflow.models.baseoperator import BaseOperator @@ -140,8 +141,13 @@ def fake_collect_dags(dagbag: DagBag, *args, **kwargs): assert "a.py" in resp.import_errors def test_top_level_variable_access( - self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client + self, + spy_agency: SpyAgency, + tmp_path: pathlib.Path, + monkeypatch: pytest.MonkeyPatch, + 
inprocess_client, ): + logger = MagicMock() logger_filehandle = MagicMock() def dag_in_a_fn(): @@ -158,6 +164,7 @@ def dag_in_a_fn(): path=path, bundle_path=tmp_path, callbacks=[], + logger=logger, logger_filehandle=logger_filehandle, client=inprocess_client, ) @@ -171,8 +178,13 @@ def dag_in_a_fn(): assert result.serialized_dags[0].dag_id == "test_abc" def test_top_level_variable_access_not_found( - self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client + self, + spy_agency: SpyAgency, + tmp_path: pathlib.Path, + monkeypatch: pytest.MonkeyPatch, + inprocess_client, ): + logger = MagicMock() logger_filehandle = MagicMock() def dag_in_a_fn(): @@ -187,6 +199,7 @@ def dag_in_a_fn(): path=path, bundle_path=tmp_path, callbacks=[], + logger=logger, logger_filehandle=logger_filehandle, client=inprocess_client, ) @@ -203,6 +216,7 @@ def dag_in_a_fn(): def test_top_level_variable_set(self, tmp_path: pathlib.Path, inprocess_client): from airflow.models.variable import Variable as VariableORM + logger = MagicMock() logger_filehandle = MagicMock() def dag_in_a_fn(): @@ -218,6 +232,7 @@ def dag_in_a_fn(): path=path, bundle_path=tmp_path, callbacks=[], + logger=logger, logger_filehandle=logger_filehandle, client=inprocess_client, ) @@ -238,6 +253,7 @@ def dag_in_a_fn(): def test_top_level_variable_delete(self, tmp_path: pathlib.Path, inprocess_client): from airflow.models.variable import Variable as VariableORM + logger = MagicMock() logger_filehandle = MagicMock() def dag_in_a_fn(): @@ -259,6 +275,7 @@ def dag_in_a_fn(): path=path, bundle_path=tmp_path, callbacks=[], + logger=logger, logger_filehandle=logger_filehandle, client=inprocess_client, ) @@ -278,6 +295,7 @@ def dag_in_a_fn(): def test_top_level_connection_access( self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client ): + logger = MagicMock() logger_filehandle = MagicMock() def dag_in_a_fn(): @@ -295,6 +313,7 @@ def dag_in_a_fn(): path=path, 
bundle_path=tmp_path, callbacks=[], + logger=logger, logger_filehandle=logger_filehandle, client=inprocess_client, ) @@ -308,6 +327,7 @@ def dag_in_a_fn(): assert result.serialized_dags[0].dag_id == "test_my_conn" def test_top_level_connection_access_not_found(self, tmp_path: pathlib.Path, inprocess_client): + logger = MagicMock() logger_filehandle = MagicMock() def dag_in_a_fn(): @@ -323,6 +343,7 @@ def dag_in_a_fn(): path=path, bundle_path=tmp_path, callbacks=[], + logger=logger, logger_filehandle=logger_filehandle, client=inprocess_client, ) @@ -355,6 +376,7 @@ def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path, inprocess_cl path=dag1_path, bundle_path=tmp_path, callbacks=[], + logger=MagicMock(), logger_filehandle=MagicMock(), client=inprocess_client, ) @@ -366,6 +388,65 @@ def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path, inprocess_cl assert result.import_errors == {} assert result.serialized_dags[0].dag_id == "dag_name" + def test__pre_import_airflow_modules_when_disabled(self): + logger = MagicMock() + with ( + env_vars({"AIRFLOW__DAG_PROCESSOR__PARSING_PRE_IMPORT_MODULES": "false"}), + patch("airflow.dag_processing.processor.iter_airflow_imports") as mock_iter, + ): + _pre_import_airflow_modules("test.py", logger) + + mock_iter.assert_not_called() + logger.warning.assert_not_called() + + def test__pre_import_airflow_modules_when_enabled(self): + logger = MagicMock() + with ( + env_vars({"AIRFLOW__DAG_PROCESSOR__PARSING_PRE_IMPORT_MODULES": "true"}), + patch("airflow.dag_processing.processor.iter_airflow_imports", return_value=["airflow.models"]), + patch("airflow.dag_processing.processor.importlib.import_module") as mock_import, + ): + _pre_import_airflow_modules("test.py", logger) + + mock_import.assert_called_once_with("airflow.models") + logger.warning.assert_not_called() + + def test__pre_import_airflow_modules_warns_on_missing_module(self): + logger = MagicMock() + with ( + 
env_vars({"AIRFLOW__DAG_PROCESSOR__PARSING_PRE_IMPORT_MODULES": "true"}), + patch( + "airflow.dag_processing.processor.iter_airflow_imports", return_value=["non_existent_module"] + ), + patch( + "airflow.dag_processing.processor.importlib.import_module", side_effect=ModuleNotFoundError() + ), + ): + _pre_import_airflow_modules("test.py", logger) + + logger.warning.assert_called_once() + warning_args = logger.warning.call_args[0] + assert "Error when trying to pre-import module" in warning_args[0] + assert "non_existent_module" in warning_args[1] + assert "test.py" in warning_args[2] + + def test__pre_import_airflow_modules_partial_success_and_warning(self): + logger = MagicMock() + with ( + env_vars({"AIRFLOW__DAG_PROCESSOR__PARSING_PRE_IMPORT_MODULES": "true"}), + patch( + "airflow.dag_processing.processor.iter_airflow_imports", + return_value=["airflow.models", "non_existent_module"], + ), + patch( + "airflow.dag_processing.processor.importlib.import_module", + side_effect=[None, ModuleNotFoundError()], + ), + ): + _pre_import_airflow_modules("test.py", logger) + + assert logger.warning.call_count == 1 + def write_dag_in_a_fn_to_file(fn: Callable[[], None], folder: pathlib.Path) -> pathlib.Path: # Create the dag in a fn, and use inspect.getsource to write it to a file so that