diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py
index 8e4d819098c5a..ee1e60f1b210a 100644
--- a/airflow/cli/cli_parser.py
+++ b/airflow/cli/cli_parser.py
@@ -24,6 +24,7 @@
 from __future__ import annotations
 
 import argparse
+import collections
 import logging
 from argparse import Action
 from functools import lru_cache
@@ -41,11 +42,12 @@
     GroupCommand,
     core_commands,
 )
+from airflow.cli.utils import CliConflictError
 from airflow.exceptions import AirflowException
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.utils.helpers import partition
 
-airflow_commands = core_commands
+airflow_commands = core_commands.copy()  # make a copy to prevent bad interactions in tests
 log = logging.getLogger(__name__)
 
 try:
@@ -59,13 +61,23 @@
         "a 3.3.0+ version of the Celery provider. If using a Kubernetes executor, install a "
         "7.4.0+ version of the CNCF provider"
     )
-    # Do no re-raise the exception since we want the CLI to still function for
+    # Do not re-raise the exception since we want the CLI to still function for
     # other commands.
 
 
 ALL_COMMANDS_DICT: dict[str, CLICommand] = {sp.name: sp for sp in airflow_commands}
 
 
+# Check if sub-commands are defined twice, which could be an issue.
+if len(ALL_COMMANDS_DICT) < len(airflow_commands):
+    dup = {k for k, v in collections.Counter([c.name for c in airflow_commands]).items() if v > 1}
+    raise CliConflictError(
+        f"The following CLI {len(dup)} command(s) are defined more than once: {sorted(dup)}\n"
+        f"This can be due to the executor '{ExecutorLoader.get_default_executor_name()}' "
+        f"redefining core airflow CLI commands."
+    )
+
+
 class AirflowHelpFormatter(RichHelpFormatter):
     """
     Custom help formatter to display help message.
diff --git a/airflow/cli/utils.py b/airflow/cli/utils.py
index 718d34a6eb75c..a300798ec716c 100644
--- a/airflow/cli/utils.py
+++ b/airflow/cli/utils.py
@@ -21,6 +21,12 @@
 import sys
 
 
+class CliConflictError(Exception):
+    """Error for when CLI commands are defined twice by different sources."""
+
+    pass
+
+
 def is_stdout(fileio: io.IOBase) -> bool:
     """Check whether a file IO is stdout.
 
diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 81c441b521f83..5a9c5f30fae68 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -488,6 +488,7 @@ def get_cli_commands() -> list[GroupCommand]:
 
         Override this method to expose commands via Airflow CLI to manage this executor. This can
         be commands to setup/teardown the executor, inspect state, etc.
+        Make sure to choose unique names for those commands, to avoid collisions.
         """
         return []
 
diff --git a/tests/cli/commands/test_celery_command.py b/tests/cli/commands/test_celery_command.py
index f3e3d391180e6..b97278b17a61c 100644
--- a/tests/cli/commands/test_celery_command.py
+++ b/tests/cli/commands/test_celery_command.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import importlib
 from argparse import Namespace
 from tempfile import NamedTemporaryFile
 from unittest import mock
@@ -65,11 +66,12 @@ def test_validate_session_dbapi_exception(self, mock_session):
 class TestCeleryStopCommand:
     @classmethod
     def setup_class(cls):
-        cls.parser = cli_parser.get_parser()
+        with conf_vars({("core", "executor"): "CeleryExecutor"}):
+            importlib.reload(cli_parser)
+            cls.parser = cli_parser.get_parser()
 
     @mock.patch("airflow.cli.commands.celery_command.setup_locations")
     @mock.patch("airflow.cli.commands.celery_command.psutil.Process")
-    @conf_vars({("core", "executor"): "CeleryExecutor"})
     def test_if_right_pid_is_read(self, mock_process, mock_setup_locations):
         args = self.parser.parse_args(["celery", "stop"])
         pid = "123"
@@ -90,7 +92,6 @@ def test_if_right_pid_is_read(self, mock_process, mock_setup_locations):
     @mock.patch("airflow.cli.commands.celery_command.read_pid_from_pidfile")
     @mock.patch("airflow.providers.celery.executors.celery_executor.app")
     @mock.patch("airflow.cli.commands.celery_command.setup_locations")
-    @conf_vars({("core", "executor"): "CeleryExecutor"})
     def test_same_pid_file_is_used_in_start_and_stop(
         self, mock_setup_locations, mock_celery_app, mock_read_pid_from_pidfile
     ):
@@ -116,7 +117,6 @@ def test_same_pid_file_is_used_in_start_and_stop(
     @mock.patch("airflow.providers.celery.executors.celery_executor.app")
     @mock.patch("airflow.cli.commands.celery_command.psutil.Process")
     @mock.patch("airflow.cli.commands.celery_command.setup_locations")
-    @conf_vars({("core", "executor"): "CeleryExecutor"})
     def test_custom_pid_file_is_used_in_start_and_stop(
         self,
         mock_setup_locations,
@@ -147,12 +147,13 @@ def test_custom_pid_file_is_used_in_start_and_stop(
 class TestWorkerStart:
     @classmethod
     def setup_class(cls):
-        cls.parser = cli_parser.get_parser()
+        with conf_vars({("core", "executor"): "CeleryExecutor"}):
+            importlib.reload(cli_parser)
+            cls.parser = cli_parser.get_parser()
 
     @mock.patch("airflow.cli.commands.celery_command.setup_locations")
     @mock.patch("airflow.cli.commands.celery_command.Process")
     @mock.patch("airflow.providers.celery.executors.celery_executor.app")
-    @conf_vars({("core", "executor"): "CeleryExecutor"})
     def test_worker_started_with_required_arguments(self, mock_celery_app, mock_popen, mock_locations):
         pid_file = "pid_file"
         mock_locations.return_value = (pid_file, None, None, None)
@@ -208,11 +209,12 @@ def test_worker_started_with_required_arguments(self, mock_celery_app, mock_pope
 class TestWorkerFailure:
     @classmethod
     def setup_class(cls):
-        cls.parser = cli_parser.get_parser()
+        with conf_vars({("core", "executor"): "CeleryExecutor"}):
+            importlib.reload(cli_parser)
+            cls.parser = cli_parser.get_parser()
 
     @mock.patch("airflow.cli.commands.celery_command.Process")
     @mock.patch("airflow.providers.celery.executors.celery_executor.app")
-    @conf_vars({("core", "executor"): "CeleryExecutor"})
     def test_worker_failure_gracefull_shutdown(self, mock_celery_app, mock_popen):
         args = self.parser.parse_args(["celery", "worker"])
         mock_celery_app.run.side_effect = Exception("Mock exception to trigger runtime error")
@@ -226,10 +228,11 @@ def test_worker_failure_gracefull_shutdown(self, mock_celery_app, mock_popen):
 class TestFlowerCommand:
     @classmethod
     def setup_class(cls):
-        cls.parser = cli_parser.get_parser()
+        with conf_vars({("core", "executor"): "CeleryExecutor"}):
+            importlib.reload(cli_parser)
+            cls.parser = cli_parser.get_parser()
 
     @mock.patch("airflow.providers.celery.executors.celery_executor.app")
-    @conf_vars({("core", "executor"): "CeleryExecutor"})
     def test_run_command(self, mock_celery_app):
         args = self.parser.parse_args(
             [
@@ -268,7 +271,6 @@ def test_run_command(self, mock_celery_app):
     @mock.patch("airflow.cli.commands.celery_command.setup_locations")
     @mock.patch("airflow.cli.commands.celery_command.daemon")
     @mock.patch("airflow.providers.celery.executors.celery_executor.app")
-    @conf_vars({("core", "executor"): "CeleryExecutor"})
     def test_run_command_daemon(self, mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file):
         mock_setup_locations.return_value = (
             mock.MagicMock(name="pidfile"),
diff --git a/tests/cli/commands/test_kubernetes_command.py b/tests/cli/commands/test_kubernetes_command.py
index 1f76220f52e6b..3957790fc9b7d 100644
--- a/tests/cli/commands/test_kubernetes_command.py
+++ b/tests/cli/commands/test_kubernetes_command.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import importlib
 import os
 import tempfile
 from unittest import mock
@@ -26,12 +27,15 @@
 
 from airflow.cli import cli_parser
 from airflow.cli.commands import kubernetes_command
+from tests.test_utils.config import conf_vars
 
 
 class TestGenerateDagYamlCommand:
     @classmethod
     def setup_class(cls):
-        cls.parser = cli_parser.get_parser()
+        with conf_vars({("core", "executor"): "KubernetesExecutor"}):
+            importlib.reload(cli_parser)
+            cls.parser = cli_parser.get_parser()
 
     def test_generate_dag_yaml(self):
         with tempfile.TemporaryDirectory("airflow_dry_run_test/") as directory:
@@ -61,7 +65,9 @@ class TestCleanUpPodsCommand:
 
     @classmethod
     def setup_class(cls):
-        cls.parser = cli_parser.get_parser()
+        with conf_vars({("core", "executor"): "KubernetesExecutor"}):
+            importlib.reload(cli_parser)
+            cls.parser = cli_parser.get_parser()
 
     @mock.patch("kubernetes.client.CoreV1Api.delete_namespaced_pod")
     @mock.patch("airflow.providers.cncf.kubernetes.kube_client.config.load_incluster_config")
diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py
index dcb8f5a9a7ad8..9987afb6833ab 100644
--- a/tests/cli/conftest.py
+++ b/tests/cli/conftest.py
@@ -23,7 +23,9 @@
 
 from airflow import models
 from airflow.cli import cli_parser
+from airflow.executors import local_executor
 from airflow.providers.celery.executors import celery_executor, celery_kubernetes_executor
+from airflow.providers.cncf.kubernetes.executors import kubernetes_executor, local_kubernetes_executor
 from tests.test_utils.config import conf_vars
 
 # Create custom executors here because conftest is imported first
@@ -34,17 +36,14 @@
 custom_executor_module.CustomCeleryKubernetesExecutor = type(  # type: ignore
     "CustomCeleryKubernetesExecutor", (celery_kubernetes_executor.CeleryKubernetesExecutor,), {}
 )
-custom_executor_module.CustomCeleryExecutor = type(  # type: ignore
-    "CustomLocalExecutor", (celery_executor.CeleryExecutor,), {}
-)
-custom_executor_module.CustomCeleryKubernetesExecutor = type(  # type: ignore
-    "CustomLocalKubernetesExecutor", (celery_kubernetes_executor.CeleryKubernetesExecutor,), {}
+custom_executor_module.CustomLocalExecutor = type(  # type: ignore
+    "CustomLocalExecutor", (local_executor.LocalExecutor,), {}
 )
-custom_executor_module.CustomCeleryExecutor = type(  # type: ignore
-    "CustomKubernetesExecutor", (celery_executor.CeleryExecutor,), {}
+custom_executor_module.CustomLocalKubernetesExecutor = type(  # type: ignore
+    "CustomLocalKubernetesExecutor", (local_kubernetes_executor.LocalKubernetesExecutor,), {}
 )
-custom_executor_module.CustomCeleryKubernetesExecutor = type(  # type: ignore
-    "CustomCeleryKubernetesExecutor", (celery_kubernetes_executor.CeleryKubernetesExecutor,), {}
+custom_executor_module.CustomKubernetesExecutor = type(  # type: ignore
+    "CustomKubernetesExecutor", (kubernetes_executor.KubernetesExecutor,), {}
 )
 
 sys.modules["custom_executor"] = custom_executor_module
diff --git a/tests/cli/test_cli_parser.py b/tests/cli/test_cli_parser.py
index 28e0eef8c3ff2..6235bc63a58f0 100644
--- a/tests/cli/test_cli_parser.py
+++ b/tests/cli/test_cli_parser.py
@@ -29,13 +29,15 @@
 from collections import Counter
 from importlib import reload
 from pathlib import Path
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
 
 from airflow.cli import cli_config, cli_parser
-from airflow.cli.cli_config import ActionCommand, lazy_load_command
+from airflow.cli.cli_config import ActionCommand, core_commands, lazy_load_command
+from airflow.cli.utils import CliConflictError
 from airflow.configuration import AIRFLOW_HOME
+from airflow.executors.local_executor import LocalExecutor
 from tests.test_utils.config import conf_vars
 
 # Can not be `--snake_case` or contain uppercase letter
@@ -133,6 +135,28 @@ def test_subcommand_arg_flag_conflict(self):
             f"short option flags {conflict_short_option}"
         )
 
+    @patch.object(LocalExecutor, "get_cli_commands")
+    def test_dynamic_conflict_detection(self, cli_commands_mock: MagicMock):
+        core_commands.append(
+            ActionCommand(
+                name="test_command",
+                help="does nothing",
+                func=lambda: None,
+                args=[],
+            )
+        )
+        cli_commands_mock.return_value = [
+            ActionCommand(
+                name="test_command",
+                help="just a command that'll conflict with one defined in core",
+                func=lambda: None,
+                args=[],
+            )
+        ]
+        with pytest.raises(CliConflictError, match="test_command"):
+            # force re-evaluation of cli commands (done in top level code)
+            reload(cli_parser)
+
     def test_falsy_default_value(self):
         arg = cli_parser.Arg(("--test",), default=0, type=int)
         parser = argparse.ArgumentParser()
@@ -205,57 +229,38 @@ def test_positive_int(self):
             cli_config.positive_int(allow_zero=True)("-1")
 
     @pytest.mark.parametrize(
-        "executor",
+        "command",
         [
             "celery",
             "kubernetes",
         ],
     )
-    def test_dag_parser_require_celery_executor(self, executor):
+    def test_executor_specific_commands_not_accessible(self, command):
         with conf_vars({("core", "executor"): "SequentialExecutor"}), contextlib.redirect_stderr(
             io.StringIO()
         ) as stderr:
             reload(cli_parser)
             parser = cli_parser.get_parser()
             with pytest.raises(SystemExit):
-                parser.parse_args([executor])
-            stderr = stderr.getvalue()
-        assert (f"airflow command error: argument GROUP_OR_COMMAND: invalid choice: '{executor}'") in stderr
-
-    @pytest.mark.parametrize(
-        "executor",
-        [
-            "CeleryExecutor",
-            "CeleryKubernetesExecutor",
-            "custom_executor.CustomCeleryExecutor",
-            "custom_executor.CustomCeleryKubernetesExecutor",
-        ],
-    )
-    def test_dag_parser_celery_command_accept_celery_executor(self, executor):
-        with conf_vars({("core", "executor"): executor}), contextlib.redirect_stderr(io.StringIO()) as stderr:
-            reload(cli_parser)
-            parser = cli_parser.get_parser()
-            with pytest.raises(SystemExit):
-                parser.parse_args(["celery"])
+                parser.parse_args([command])
             stderr = stderr.getvalue()
-            assert (
-                "airflow celery command error: the following arguments are required: COMMAND, see help above."
-            ) in stderr
+        assert (f"airflow command error: argument GROUP_OR_COMMAND: invalid choice: '{command}'") in stderr
 
     @pytest.mark.parametrize(
         "executor,expected_args",
         [
             ("CeleryExecutor", ["celery"]),
             ("CeleryKubernetesExecutor", ["celery", "kubernetes"]),
-            ("custom_executor.CustomCeleryExecutor", ["celery"]),
-            ("custom_executor.CustomCeleryKubernetesExecutor", ["celery", "kubernetes"]),
             ("KubernetesExecutor", ["kubernetes"]),
-            ("custom_executor.KubernetesExecutor", ["kubernetes"]),
             ("LocalExecutor", []),
             ("LocalKubernetesExecutor", ["kubernetes"]),
-            ("custom_executor.LocalExecutor", []),
-            ("custom_executor.LocalKubernetesExecutor", ["kubernetes"]),
             ("SequentialExecutor", []),
+            # custom executors are mapped to the regular ones in `conftest.py`
+            ("custom_executor.CustomLocalExecutor", []),
+            ("custom_executor.CustomLocalKubernetesExecutor", ["kubernetes"]),
+            ("custom_executor.CustomCeleryExecutor", ["celery"]),
+            ("custom_executor.CustomCeleryKubernetesExecutor", ["celery", "kubernetes"]),
+            ("custom_executor.CustomKubernetesExecutor", ["kubernetes"]),
         ],
     )
     def test_cli_parser_executors(self, executor, expected_args):
@@ -266,20 +271,12 @@ def test_cli_parser_executors(self, executor, expected_args):
             ) as stderr:
                 reload(cli_parser)
                 parser = cli_parser.get_parser()
-                with pytest.raises(SystemExit) as e:
+                with pytest.raises(SystemExit) as e:  # running the help command exits, so we prevent that
                     parser.parse_args([expected_arg, "--help"])
-                assert e.value.code == 0
+                assert e.value.code == 0, stderr.getvalue()  # return code 0 == no problem
             stderr = stderr.getvalue()
             assert "airflow command error" not in stderr
 
-    def test_dag_parser_config_command_dont_required_celery_executor(self):
-        with conf_vars({("core", "executor"): "CeleryExecutor"}), contextlib.redirect_stderr(
-            io.StringIO()
-        ) as stdout:
-            parser = cli_parser.get_parser()
-            parser.parse_args(["config", "get-value", "celery", "broker-url"])
-            assert stdout is not None
-
     def test_non_existing_directory_raises_when_metavar_is_dir_for_db_export_cleaned(self):
         """Test that the error message is correct when the directory does not exist."""
         with contextlib.redirect_stderr(io.StringIO()) as stderr:
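
The duplicate-command check that this patch adds to airflow/cli/cli_parser.py can be read in isolation: it builds a dict keyed by command name, and if that dict is smaller than the raw command list, a name was defined twice, and collections.Counter is used to report which ones. Executors that override BaseExecutor.get_cli_commands() extend the same list, which is why the docstring change asks for unique command names. The following is a minimal, self-contained sketch of that logic only; `Cmd` and `check_for_duplicate_commands` are illustrative stand-ins, not Airflow APIs.

from __future__ import annotations

import collections
from dataclasses import dataclass


class CliConflictError(Exception):
    """Error for when CLI commands are defined twice by different sources."""


@dataclass
class Cmd:
    """Stand-in for Airflow's CLICommand; only the name matters for the check."""

    name: str


def check_for_duplicate_commands(airflow_commands: list[Cmd]) -> None:
    """Raise CliConflictError if two commands share a name (mirrors the patched cli_parser logic)."""
    all_commands_dict = {c.name: c for c in airflow_commands}
    # Keying a dict by name collapses duplicates, so a size mismatch means a conflict.
    if len(all_commands_dict) < len(airflow_commands):
        dup = {k for k, v in collections.Counter(c.name for c in airflow_commands).items() if v > 1}
        raise CliConflictError(
            f"The following CLI {len(dup)} command(s) are defined more than once: {sorted(dup)}"
        )


check_for_duplicate_commands([Cmd("dags"), Cmd("celery")])  # passes silently
# check_for_duplicate_commands([Cmd("celery"), Cmd("celery")])  # raises CliConflictError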