diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py
index 8e4c984450338..02908ee8f0a34 100644
--- a/airflow/cli/cli_config.py
+++ b/airflow/cli/cli_config.py
@@ -24,7 +24,6 @@
 import json
 import os
 import textwrap
-from argparse import ArgumentError
 from typing import Callable, Iterable, NamedTuple, Union
 
 import lazy_object_proxy
@@ -32,8 +31,6 @@
 from airflow import settings
 from airflow.cli.commands.legacy_commands import check_legacy_command
 from airflow.configuration import conf
-from airflow.executors.executor_constants import CELERY_EXECUTOR, CELERY_KUBERNETES_EXECUTOR
-from airflow.executors.executor_loader import ExecutorLoader
 from airflow.settings import _ENABLE_AIP_44
 from airflow.utils.cli import ColorMode
 from airflow.utils.module_loading import import_string
@@ -61,46 +58,6 @@ class DefaultHelpParser(argparse.ArgumentParser):
 
     def _check_value(self, action, value):
         """Override _check_value and check conditionally added command."""
-        if action.dest == "subcommand" and value == "celery":
-            executor = conf.get("core", "EXECUTOR")
-            if executor not in (CELERY_EXECUTOR, CELERY_KUBERNETES_EXECUTOR):
-                executor_cls, _ = ExecutorLoader.import_executor_cls(executor)
-                classes = ()
-                try:
-                    from airflow.providers.celery.executors.celery_executor import CeleryExecutor
-
-                    classes += (CeleryExecutor,)
-                except ImportError:
-                    message = (
-                        "The celery subcommand requires that you pip install the celery module. "
-                        "To do it, run: pip install 'apache-airflow[celery]'"
-                    )
-                    raise ArgumentError(action, message)
-                try:
-                    from airflow.providers.celery.executors.celery_kubernetes_executor import (
-                        CeleryKubernetesExecutor,
-                    )
-
-                    classes += (CeleryKubernetesExecutor,)
-                except ImportError:
-                    pass
-                if not issubclass(executor_cls, classes):
-                    message = (
-                        f"celery subcommand works only with CeleryExecutor, CeleryKubernetesExecutor and "
-                        f"executors derived from them, your current executor: {executor}, subclassed from: "
-                        f'{", ".join([base_cls.__qualname__ for base_cls in executor_cls.__bases__])}'
-                    )
-                    raise ArgumentError(action, message)
-        if action.dest == "subcommand" and value == "kubernetes":
-            try:
-                import kubernetes.client  # noqa: F401
-            except ImportError:
-                message = (
-                    "The kubernetes subcommand requires that you pip install the kubernetes python client. "
" - "To do it, run: pip install 'apache-airflow[cncf.kubernetes]'" - ) - raise ArgumentError(action, message) - if action.choices is not None and value not in action.choices: check_legacy_command(action, value) @@ -823,25 +780,6 @@ def string_lower_type(val): action="store_true", ) -ARG_QUEUES = Arg( - ("-q", "--queues"), - help="Comma delimited list of queues to serve", - default=conf.get("operators", "DEFAULT_QUEUE"), -) -ARG_CONCURRENCY = Arg( - ("-c", "--concurrency"), - type=int, - help="The number of worker processes", - default=conf.getint("celery", "worker_concurrency"), -) -ARG_CELERY_HOSTNAME = Arg( - ("-H", "--celery-hostname"), - help="Set the hostname of celery worker if you have multiple workers on a single machine", -) -ARG_UMASK = Arg( - ("-u", "--umask"), - help="Set the umask of celery worker in daemon mode", -) ARG_WITHOUT_MINGLE = Arg( ("--without-mingle",), default=False, @@ -855,34 +793,6 @@ def string_lower_type(val): action="store_true", ) -# flower -ARG_BROKER_API = Arg(("-a", "--broker-api"), help="Broker API") -ARG_FLOWER_HOSTNAME = Arg( - ("-H", "--hostname"), - default=conf.get("celery", "FLOWER_HOST"), - help="Set the hostname on which to run the server", -) -ARG_FLOWER_PORT = Arg( - ("-p", "--port"), - default=conf.getint("celery", "FLOWER_PORT"), - type=int, - help="The port on which to run the server", -) -ARG_FLOWER_CONF = Arg(("-c", "--flower-conf"), help="Configuration file for flower") -ARG_FLOWER_URL_PREFIX = Arg( - ("-u", "--url-prefix"), - default=conf.get("celery", "FLOWER_URL_PREFIX"), - help="URL prefix for Flower", -) -ARG_FLOWER_BASIC_AUTH = Arg( - ("-A", "--basic-auth"), - default=conf.get("celery", "FLOWER_BASIC_AUTH"), - help=( - "Securing Flower with Basic Authentication. " - "Accepts user:password pairs separated by a comma. 
" - "Example: flower_basic_auth = user1:password1,user2:password2" - ), -) ARG_TASK_PARAMS = Arg(("-t", "--task-params"), help="Sends a JSON params dict to the task") ARG_POST_MORTEM = Arg( ("-m", "--post-mortem"), action="store_true", help="Open debugger on uncaught exception" @@ -1978,55 +1888,6 @@ class GroupCommand(NamedTuple): ), ) -CELERY_COMMANDS = ( - ActionCommand( - name="worker", - help="Start a Celery worker node", - func=lazy_load_command("airflow.cli.commands.celery_command.worker"), - args=( - ARG_QUEUES, - ARG_CONCURRENCY, - ARG_CELERY_HOSTNAME, - ARG_PID, - ARG_DAEMON, - ARG_UMASK, - ARG_STDOUT, - ARG_STDERR, - ARG_LOG_FILE, - ARG_AUTOSCALE, - ARG_SKIP_SERVE_LOGS, - ARG_WITHOUT_MINGLE, - ARG_WITHOUT_GOSSIP, - ARG_VERBOSE, - ), - ), - ActionCommand( - name="flower", - help="Start a Celery Flower", - func=lazy_load_command("airflow.cli.commands.celery_command.flower"), - args=( - ARG_FLOWER_HOSTNAME, - ARG_FLOWER_PORT, - ARG_FLOWER_CONF, - ARG_FLOWER_URL_PREFIX, - ARG_FLOWER_BASIC_AUTH, - ARG_BROKER_API, - ARG_PID, - ARG_DAEMON, - ARG_STDOUT, - ARG_STDERR, - ARG_LOG_FILE, - ARG_VERBOSE, - ), - ), - ActionCommand( - name="stop", - help="Stop the Celery worker gracefully", - func=lazy_load_command("airflow.cli.commands.celery_command.stop_worker"), - args=(ARG_PID, ARG_VERBOSE), - ), -) - CONFIG_COMMANDS = ( ActionCommand( name="get-value", @@ -2109,9 +1970,6 @@ class GroupCommand(NamedTuple): help="Manage DAGs", subcommands=DAGS_COMMANDS, ), - GroupCommand( - name="kubernetes", help="Tools to help run the KubernetesExecutor", subcommands=KUBERNETES_COMMANDS - ), GroupCommand( name="tasks", help="Manage tasks", @@ -2298,15 +2156,6 @@ class GroupCommand(NamedTuple): func=lazy_load_command("airflow.cli.commands.plugins_command.dump_plugins"), args=(ARG_OUTPUT, ARG_VERBOSE), ), - GroupCommand( - name="celery", - help="Celery components", - description=( - "Start celery components. Works only when using CeleryExecutor. For more information, see " - "https://airflow.apache.org/docs/apache-airflow/stable/executor/celery.html" - ), - subcommands=CELERY_COMMANDS, - ), ActionCommand( name="standalone", help="Run an all-in-one copy of Airflow", diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py index 9ccccd0f30676..1a8ec1b84402d 100644 --- a/airflow/cli/cli_parser.py +++ b/airflow/cli/cli_parser.py @@ -24,6 +24,7 @@ from __future__ import annotations import argparse +import logging from argparse import Action from functools import lru_cache from typing import Iterable @@ -41,10 +42,27 @@ core_commands, ) from airflow.exceptions import AirflowException +from airflow.executors.executor_loader import ExecutorLoader from airflow.utils.helpers import partition airflow_commands = core_commands +log = logging.getLogger(__name__) +try: + executor, _ = ExecutorLoader.import_default_executor_cls(validate=False) + airflow_commands.extend(executor.get_cli_commands()) +except Exception: + executor_name = ExecutorLoader.get_default_executor_name() + log.exception("Failed to load CLI commands from executor: %s", executor_name) + log.error( + "Ensure all dependencies are met and try again. If using a Celery based executor install " + "a 3.3.0+ version of the Celery provider. If using a Kubernetes executor, install a " + "7.4.0+ version of the CNCF provider" + ) + # Do no re-raise the exception since we want the CLI to still function for + # other commands. 
diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 1a44540b077ea..999125afe7ec2 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -27,6 +27,7 @@
 
 import pendulum
 
+from airflow.cli.cli_config import GroupCommand
 from airflow.configuration import conf
 from airflow.exceptions import RemovedInAirflow3Warning
 from airflow.stats import Stats
@@ -479,3 +480,12 @@ def send_callback(self, request: CallbackRequest) -> None:
         if not self.callback_sink:
             raise ValueError("Callback sink is not ready.")
         self.callback_sink.send(request)
+
+    @staticmethod
+    def get_cli_commands() -> list[GroupCommand]:
+        """Vends CLI commands to be included in Airflow CLI.
+
+        Override this method to expose commands via Airflow CLI to manage this executor. This can
+        be commands to setup/teardown the executor, inspect state, etc.
+        """
+        return []
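
Note: this hook is the extension point the rest of the change builds on. A minimal sketch of a
third-party executor vending its own command group (the executor class, group name, and module path
below are hypothetical):

    from __future__ import annotations

    from airflow.cli.cli_config import ActionCommand, GroupCommand, lazy_load_command
    from airflow.executors.base_executor import BaseExecutor

    class MyExecutor(BaseExecutor):  # hypothetical executor
        @staticmethod
        def get_cli_commands() -> list[GroupCommand]:
            return [
                GroupCommand(
                    name="my-executor",  # hypothetical group name
                    help="Tools to help run MyExecutor",
                    subcommands=(
                        ActionCommand(
                            name="status",
                            help="Show MyExecutor state",
                            func=lazy_load_command("my_package.cli.status"),  # hypothetical module
                            args=(),
                        ),
                    ),
                )
            ]
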
diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py
index ca21bdf05ac8a..4a4bda85831a3 100644
--- a/airflow/executors/executor_loader.py
+++ b/airflow/executors/executor_loader.py
@@ -119,18 +119,24 @@ def load_executor(cls, executor_name: str) -> BaseExecutor:
         return executor_cls()
 
     @classmethod
-    def import_executor_cls(cls, executor_name: str) -> tuple[type[BaseExecutor], ConnectorSource]:
+    def import_executor_cls(
+        cls, executor_name: str, validate: bool = True
+    ) -> tuple[type[BaseExecutor], ConnectorSource]:
         """
         Imports the executor class.
 
         Supports the same formats as ExecutorLoader.load_executor.
 
+        :param executor_name: Name of core executor or module path to an executor provided as a plugin.
+        :param validate: Whether or not to validate the executor before returning
+
+        :return: executor class via executor_name and executor import source
         """
 
         def _import_and_validate(path: str) -> type[BaseExecutor]:
             executor = import_string(path)
-            cls.validate_database_executor_compatibility(executor)
+            if validate:
+                cls.validate_database_executor_compatibility(executor)
             return executor
 
         if executor_name in cls.executors:
@@ -151,14 +157,16 @@ def _import_and_validate(path: str) -> type[BaseExecutor]:
         return _import_and_validate(executor_name), ConnectorSource.CUSTOM_PATH
 
     @classmethod
-    def import_default_executor_cls(cls) -> tuple[type[BaseExecutor], ConnectorSource]:
+    def import_default_executor_cls(cls, validate: bool = True) -> tuple[type[BaseExecutor], ConnectorSource]:
         """
         Imports the default executor class.
 
+        :param validate: Whether or not to validate the executor before returning
+
+        :return: executor class and executor import source
         """
         executor_name = cls.get_default_executor_name()
-        executor, source = cls.import_executor_cls(executor_name)
+        executor, source = cls.import_executor_cls(executor_name, validate=validate)
         return executor, source
 
     @classmethod
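
Note: the new validate flag is what allows the CLI to import an executor class before any database
compatibility check runs. A sketch of both call styles (assumes a stock Airflow configuration):

    from airflow.executors.executor_loader import ExecutorLoader

    # Skips validate_database_executor_compatibility(), mirroring cli_parser's usage.
    executor_cls, source = ExecutorLoader.import_executor_cls("LocalExecutor", validate=False)
    print(executor_cls.__name__, source)

    # The configured default executor can be imported the same way.
    executor_cls, _ = ExecutorLoader.import_default_executor_cls(validate=False)
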
" + "Example: flower_basic_auth = user1:password1,user2:password2" + ), +) + +# worker cli args +ARG_QUEUES = Arg( + ("-q", "--queues"), + help="Comma delimited list of queues to serve", + default=conf.get("operators", "DEFAULT_QUEUE"), +) +ARG_CONCURRENCY = Arg( + ("-c", "--concurrency"), + type=int, + help="The number of worker processes", + default=conf.getint("celery", "worker_concurrency"), +) +ARG_CELERY_HOSTNAME = Arg( + ("-H", "--celery-hostname"), + help="Set the hostname of celery worker if you have multiple workers on a single machine", +) +ARG_UMASK = Arg( + ("-u", "--umask"), + help="Set the umask of celery worker in daemon mode", +) + +ARG_WITHOUT_MINGLE = Arg( + ("--without-mingle",), + default=False, + help="Don't synchronize with other workers at start-up", + action="store_true", +) +ARG_WITHOUT_GOSSIP = Arg( + ("--without-gossip",), + default=False, + help="Don't subscribe to other workers events", + action="store_true", +) + +CELERY_COMMANDS = ( + ActionCommand( + name="worker", + help="Start a Celery worker node", + func=lazy_load_command("airflow.cli.commands.celery_command.worker"), + args=( + ARG_QUEUES, + ARG_CONCURRENCY, + ARG_CELERY_HOSTNAME, + ARG_PID, + ARG_DAEMON, + ARG_UMASK, + ARG_STDOUT, + ARG_STDERR, + ARG_LOG_FILE, + ARG_AUTOSCALE, + ARG_SKIP_SERVE_LOGS, + ARG_WITHOUT_MINGLE, + ARG_WITHOUT_GOSSIP, + ARG_VERBOSE, + ), + ), + ActionCommand( + name="flower", + help="Start a Celery Flower", + func=lazy_load_command("airflow.cli.commands.celery_command.flower"), + args=( + ARG_FLOWER_HOSTNAME, + ARG_FLOWER_PORT, + ARG_FLOWER_CONF, + ARG_FLOWER_URL_PREFIX, + ARG_FLOWER_BASIC_AUTH, + ARG_BROKER_API, + ARG_PID, + ARG_DAEMON, + ARG_STDOUT, + ARG_STDERR, + ARG_LOG_FILE, + ARG_VERBOSE, + ), + ), + ActionCommand( + name="stop", + help="Stop the Celery worker gracefully", + func=lazy_load_command("airflow.cli.commands.celery_command.stop_worker"), + args=(ARG_PID, ARG_VERBOSE), + ), +) + + class CeleryExecutor(BaseExecutor): """ CeleryExecutor is recommended for production use of Airflow. @@ -317,3 +464,17 @@ def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]: except Exception as ex: self.log.error("Error revoking task instance %s from celery: %s", task_instance_key, ex) return readable_tis + + @staticmethod + def get_cli_commands() -> list[GroupCommand]: + return [ + GroupCommand( + name="celery", + help="Celery components", + description=( + "Start celery components. Works only when using CeleryExecutor. 
diff --git a/airflow/providers/celery/executors/celery_kubernetes_executor.py b/airflow/providers/celery/executors/celery_kubernetes_executor.py
index c5bbaac081394..725e7fe5c4ce4 100644
--- a/airflow/providers/celery/executors/celery_kubernetes_executor.py
+++ b/airflow/providers/celery/executors/celery_kubernetes_executor.py
@@ -257,3 +257,7 @@ def send_callback(self, request: CallbackRequest) -> None:
         if not self.callback_sink:
             raise ValueError("Callback sink is not ready.")
         self.callback_sink.send(request)
+
+    @staticmethod
+    def get_cli_commands() -> list:
+        return CeleryExecutor.get_cli_commands() + KubernetesExecutor.get_cli_commands()
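
Note: because the hybrid executor concatenates the two lists, both command groups surface when it is
configured. A sketch, assuming both the Celery and CNCF Kubernetes providers are installed:

    from airflow.providers.celery.executors.celery_kubernetes_executor import CeleryKubernetesExecutor

    groups = CeleryKubernetesExecutor.get_cli_commands()
    print([group.name for group in groups])  # ["celery", "kubernetes"]
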
diff --git a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
index e57ca205fe3b6..1c4a0e91fb4da 100644
--- a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
+++ b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
@@ -36,6 +36,39 @@
 from sqlalchemy.orm import Session
 
 from airflow import AirflowException
+
+try:
+    from airflow.cli.cli_config import (
+        ARG_DAG_ID,
+        ARG_EXECUTION_DATE,
+        ARG_OUTPUT_PATH,
+        ARG_SUBDIR,
+        ARG_VERBOSE,
+        ActionCommand,
+        Arg,
+        GroupCommand,
+        lazy_load_command,
+        positive_int,
+    )
+except ImportError:
+    try:
+        from airflow import __version__ as airflow_version
+    except ImportError:
+        from airflow.version import version as airflow_version
+
+    import packaging.version
+
+    from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+    base_version = packaging.version.parse(airflow_version).base_version
+
+    if packaging.version.parse(base_version) < packaging.version.parse("2.7.0"):
+        raise AirflowOptionalProviderFeatureException(
+            "Kubernetes Executor from CNCF Provider should only be used with Airflow 2.7.0+.\n"
+            f"This is Airflow {airflow_version} and Kubernetes and CeleryKubernetesExecutor are "
+            f"available in the 'airflow.executors' package. You should not use "
+            f"the provider's executors in this version of Airflow."
+        )
 from airflow.configuration import conf
 from airflow.executors.base_executor import BaseExecutor
 from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_types import POD_EXECUTOR_DONE_KEY
@@ -70,6 +103,45 @@ class PodReconciliationError(AirflowException):
     """Raised when an error is encountered while trying to merge pod configs."""
 
 
+# CLI Args
+ARG_NAMESPACE = Arg(
+    ("--namespace",),
+    default=conf.get("kubernetes_executor", "namespace"),
+    help="Kubernetes Namespace. Default value is `[kubernetes] namespace` in configuration.",
+)
+
+ARG_MIN_PENDING_MINUTES = Arg(
+    ("--min-pending-minutes",),
+    default=30,
+    type=positive_int(allow_zero=False),
+    help=(
+        "Pending pods created before the time interval are to be cleaned up, "
+        "measured in minutes. Default value is 30(m). The minimum value is 5(m)."
+    ),
+)
+
+# CLI Commands
+KUBERNETES_COMMANDS = (
+    ActionCommand(
+        name="cleanup-pods",
+        help=(
+            "Clean up Kubernetes pods "
+            "(created by KubernetesExecutor/KubernetesPodOperator) "
+            "in evicted/failed/succeeded/pending states"
+        ),
+        func=lazy_load_command("airflow.cli.commands.kubernetes_command.cleanup_pods"),
+        args=(ARG_NAMESPACE, ARG_MIN_PENDING_MINUTES, ARG_VERBOSE),
+    ),
+    ActionCommand(
+        name="generate-dag-yaml",
+        help="Generate YAML files for all tasks in DAG. Useful for debugging tasks without "
+        "launching into a cluster",
+        func=lazy_load_command("airflow.cli.commands.kubernetes_command.generate_pod_yaml"),
+        args=(ARG_DAG_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_OUTPUT_PATH, ARG_VERBOSE),
+    ),
+)
+
+
 class KubernetesExecutor(BaseExecutor):
     """Executor for Kubernetes."""
 
@@ -644,3 +716,13 @@ def end(self) -> None:
 
     def terminate(self):
         """Terminate the executor is not doing anything."""
+
+    @staticmethod
+    def get_cli_commands() -> list[GroupCommand]:
+        return [
+            GroupCommand(
+                name="kubernetes",
+                help="Tools to help run the KubernetesExecutor",
+                subcommands=KUBERNETES_COMMANDS,
+            )
+        ]
diff --git a/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py b/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py
index ee82252096800..49977a35ef846 100644
--- a/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py
+++ b/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py
@@ -240,3 +240,7 @@ def send_callback(self, request: CallbackRequest) -> None:
         if not self.callback_sink:
             raise ValueError("Callback sink is not ready.")
         self.callback_sink.send(request)
+
+    @staticmethod
+    def get_cli_commands() -> list:
+        return KubernetesExecutor.get_cli_commands()
diff --git a/airflow/stats.py b/airflow/stats.py
index 2f45b338dcee5..5b84def6db76a 100644
--- a/airflow/stats.py
+++ b/airflow/stats.py
@@ -22,7 +22,6 @@
 from typing import TYPE_CHECKING, Callable
 
 from airflow.configuration import conf
-from airflow.metrics import datadog_logger, otel_logger, statsd_logger
 from airflow.metrics.base_stats_logger import NoStatsLogger, StatsLogger
 
 log = logging.getLogger(__name__)
@@ -46,10 +45,16 @@ def __init__(cls, *args, **kwargs) -> None:
         if not hasattr(cls.__class__, "factory"):
             is_datadog_enabled_defined = conf.has_option("metrics", "statsd_datadog_enabled")
             if is_datadog_enabled_defined and conf.getboolean("metrics", "statsd_datadog_enabled"):
+                from airflow.metrics import datadog_logger
+
                 cls.__class__.factory = datadog_logger.get_dogstatsd_logger
             elif conf.getboolean("metrics", "statsd_on"):
+                from airflow.metrics import statsd_logger
+
                 cls.__class__.factory = statsd_logger.get_statsd_logger
             elif conf.getboolean("metrics", "otel_on"):
+                from airflow.metrics import otel_logger
+
                 cls.__class__.factory = otel_logger.get_otel_logger
             else:
                 cls.__class__.factory = NoStatsLogger
diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py
index 3069726a059cf..dcb8f5a9a7ad8 100644
--- a/tests/cli/conftest.py
+++ b/tests/cli/conftest.py
@@ -34,6 +34,15 @@
 custom_executor_module.CustomCeleryKubernetesExecutor = type(  # type: ignore
     "CustomCeleryKubernetesExecutor", (celery_kubernetes_executor.CeleryKubernetesExecutor,), {}
 )
+custom_executor_module.CustomLocalExecutor = type(  # type: ignore
+    "CustomLocalExecutor", (local_executor.LocalExecutor,), {}
+)
+custom_executor_module.CustomLocalKubernetesExecutor = type(  # type: ignore
+    "CustomLocalKubernetesExecutor", (local_kubernetes_executor.LocalKubernetesExecutor,), {}
+)
+custom_executor_module.CustomKubernetesExecutor = type(  # type: ignore
+    "CustomKubernetesExecutor", (kubernetes_executor.KubernetesExecutor,), {}
+)
 sys.modules["custom_executor"] = custom_executor_module
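
Note: the synthetic custom_executor module registered above resolves through the normal import
machinery, which is what lets a dotted core.executor value exercise the subclass paths in the test
below. A sketch of the effect (valid only inside a test session where this conftest has run):

    import sys

    assert "custom_executor" in sys.modules  # registered by tests/cli/conftest.py

    from custom_executor import CustomCeleryExecutor

    # The subclass inherits get_cli_commands(), so it vends the "celery" group too.
    print(CustomCeleryExecutor.get_cli_commands()[0].name)  # "celery"
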
diff --git a/tests/cli/test_cli_parser.py b/tests/cli/test_cli_parser.py
index 6089961de94be..eaec31db99dbb 100644
--- a/tests/cli/test_cli_parser.py
+++ b/tests/cli/test_cli_parser.py
@@ -27,6 +27,7 @@
 import sys
 import timeit
 from collections import Counter
+from importlib import reload
 from pathlib import Path
 from unittest.mock import patch
 
@@ -203,31 +204,22 @@ def test_positive_int(self):
             cli_config.positive_int(allow_zero=True)("-1")
 
     @pytest.mark.parametrize(
-        "command",
+        "executor",
         [
-            ["celery"],
-            ["celery", "--help"],
-            ["celery", "worker", "--help"],
-            ["celery", "worker"],
-            ["celery", "flower", "--help"],
-            ["celery", "flower"],
-            ["celery", "stop_worker", "--help"],
-            ["celery", "stop_worker"],
+            "celery",
+            "kubernetes",
         ],
     )
-    def test_dag_parser_require_celery_executor(self, command):
+    def test_dag_parser_require_celery_executor(self, executor):
        with conf_vars({("core", "executor"): "SequentialExecutor"}), contextlib.redirect_stderr(
             io.StringIO()
         ) as stderr:
+            reload(cli_parser)
             parser = cli_parser.get_parser()
             with pytest.raises(SystemExit):
-                parser.parse_args(command)
+                parser.parse_args([executor])
         stderr = stderr.getvalue()
-        assert (
-            "airflow command error: argument GROUP_OR_COMMAND: celery subcommand "
-            "works only with CeleryExecutor, CeleryKubernetesExecutor and executors derived from them, "
-            "your current executor: SequentialExecutor, subclassed from: BaseExecutor, see help above."
-        ) in stderr
+        assert (f"airflow command error: argument GROUP_OR_COMMAND: invalid choice: '{executor}'") in stderr
 
     @pytest.mark.parametrize(
         "executor",
         [
@@ -240,6 +232,7 @@ def test_dag_parser_require_celery_executor(self, command):
     )
     def test_dag_parser_celery_command_accept_celery_executor(self, executor):
         with conf_vars({("core", "executor"): executor}), contextlib.redirect_stderr(io.StringIO()) as stderr:
+            reload(cli_parser)
             parser = cli_parser.get_parser()
             with pytest.raises(SystemExit):
                 parser.parse_args(["celery"])
@@ -248,6 +241,36 @@
             "airflow celery command error: the following arguments are required: COMMAND, see help above."
         ) in stderr
 
+    @pytest.mark.parametrize(
+        "executor,expected_args",
+        [
+            ("CeleryExecutor", ["celery"]),
+            ("CeleryKubernetesExecutor", ["celery", "kubernetes"]),
+            ("custom_executor.CustomCeleryExecutor", ["celery"]),
+            ("custom_executor.CustomCeleryKubernetesExecutor", ["celery", "kubernetes"]),
+            ("KubernetesExecutor", ["kubernetes"]),
+            ("custom_executor.CustomKubernetesExecutor", ["kubernetes"]),
+            ("LocalExecutor", []),
+            ("LocalKubernetesExecutor", ["kubernetes"]),
+            ("custom_executor.CustomLocalExecutor", []),
+            ("custom_executor.CustomLocalKubernetesExecutor", ["kubernetes"]),
+            ("SequentialExecutor", []),
+        ],
+    )
+    def test_cli_parser_executors(self, executor, expected_args):
+        """Test that CLI commands for the configured executor are present"""
+        for expected_arg in expected_args:
+            with conf_vars({("core", "executor"): executor}), contextlib.redirect_stderr(
+                io.StringIO()
+            ) as stderr:
+                reload(cli_parser)
+                parser = cli_parser.get_parser()
+                with pytest.raises(SystemExit) as e:
+                    parser.parse_args([expected_arg, "--help"])
+                assert e.value.code == 0
+            stderr = stderr.getvalue()
+            assert "airflow command error" not in stderr
+
     def test_dag_parser_config_command_dont_required_celery_executor(self):
         with conf_vars({("core", "executor"): "CeleryExecutor"}), contextlib.redirect_stderr(
             io.StringIO()
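
Note: the reload(cli_parser) calls exist because the command table is built at import time, so a test
that overrides core.executor must rebuild it; the integration test below uses the same pattern. A
condensed sketch (the conf_vars import path is assumed from Airflow's test utilities):

    from importlib import reload

    from airflow.cli import cli_parser
    from tests.test_utils.config import conf_vars  # assumed helper location

    with conf_vars({("core", "executor"): "CeleryExecutor"}):
        reload(cli_parser)  # rebuilds ALL_COMMANDS_DICT with the overridden executor
        parser = cli_parser.get_parser()
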
diff --git a/tests/integration/cli/commands/test_celery_command.py b/tests/integration/cli/commands/test_celery_command.py
index b421c949f0e1a..1b779272cc144 100644
--- a/tests/integration/cli/commands/test_celery_command.py
+++ b/tests/integration/cli/commands/test_celery_command.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+from importlib import reload
 from unittest import mock
 
 import pytest
@@ -31,7 +32,11 @@ class TestWorkerServeLogs:
     @classmethod
     def setup_class(cls):
-        cls.parser = cli_parser.get_parser()
+        with conf_vars({("core", "executor"): "CeleryExecutor"}):
+            # The cli_parser module is loaded during test collection. Reload it here with the
+            # executor overridden so that we get the expected commands loaded.
+            reload(cli_parser)
+            cls.parser = cli_parser.get_parser()
 
     @conf_vars({("core", "executor"): "CeleryExecutor"})
     def test_serve_logs_on_worker_start(self):
diff --git a/tests/providers/celery/executors/test_celery_executor.py b/tests/providers/celery/executors/test_celery_executor.py
index c8ddee8c4ee87..7390855f2e369 100644
--- a/tests/providers/celery/executors/test_celery_executor.py
+++ b/tests/providers/celery/executors/test_celery_executor.py
@@ -111,6 +111,9 @@ def test_supports_pickling(self):
     def test_supports_sentry(self):
         assert CeleryExecutor.supports_sentry
 
+    def test_cli_commands_vended(self):
+        assert CeleryExecutor.get_cli_commands()
+
     @pytest.mark.backend("mysql", "postgres")
     def test_exception_propagation(self, caplog):
         caplog.set_level(
diff --git a/tests/providers/celery/executors/test_celery_kubernetes_executor.py b/tests/providers/celery/executors/test_celery_kubernetes_executor.py
index 1950b93690b27..6c3857912b5ce 100644
--- a/tests/providers/celery/executors/test_celery_kubernetes_executor.py
+++ b/tests/providers/celery/executors/test_celery_kubernetes_executor.py
@@ -49,6 +49,9 @@ def test_serve_logs_default_value(self):
     def test_is_single_threaded_default_value(self):
         assert not CeleryKubernetesExecutor.is_single_threaded
 
+    def test_cli_commands_vended(self):
+        assert CeleryKubernetesExecutor.get_cli_commands()
+
     def test_queued_tasks(self):
         celery_executor_mock = mock.MagicMock()
         k8s_executor_mock = mock.MagicMock()
diff --git a/tests/providers/cncf/kubernetes/executors/test_kubernetes_executor.py b/tests/providers/cncf/kubernetes/executors/test_kubernetes_executor.py
index 0b9721f117a9b..f0d77d5c18c70 100644
--- a/tests/providers/cncf/kubernetes/executors/test_kubernetes_executor.py
+++ b/tests/providers/cncf/kubernetes/executors/test_kubernetes_executor.py
@@ -1192,6 +1192,9 @@ def test_supports_pickling(self):
     def test_supports_sentry(self):
         assert not KubernetesExecutor.supports_sentry
 
+    def test_cli_commands_vended(self):
+        assert KubernetesExecutor.get_cli_commands()
+
     def test_annotations_for_logging_task_metadata(self):
         annotations_test = {
             "dag_id": "dag",
diff --git a/tests/providers/cncf/kubernetes/executors/test_local_kubernetes_executor.py b/tests/providers/cncf/kubernetes/executors/test_local_kubernetes_executor.py
index ca40c95e45a4c..2b1817ac89853 100644
--- a/tests/providers/cncf/kubernetes/executors/test_local_kubernetes_executor.py
+++ b/tests/providers/cncf/kubernetes/executors/test_local_kubernetes_executor.py
@@ -46,6 +46,9 @@ def test_serve_logs_default_value(self):
     def test_is_single_threaded_default_value(self):
         assert not LocalKubernetesExecutor.is_single_threaded
 
+    def test_cli_commands_vended(self):
+        assert LocalKubernetesExecutor.get_cli_commands()
+
     def test_queued_tasks(self):
         local_executor_mock = mock.MagicMock()
         k8s_executor_mock = mock.MagicMock()
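
Note: end to end, a command group whose executor is not configured is now an ordinary argparse
"invalid choice" error rather than the bespoke compatibility error deleted in the first hunk. A
sketch, assuming no Celery-based executor is configured:

    import os

    os.environ["AIRFLOW__CORE__EXECUTOR"] = "SequentialExecutor"

    from airflow.cli import cli_parser

    parser = cli_parser.get_parser()
    # Raises SystemExit; stderr includes:
    # airflow command error: argument GROUP_OR_COMMAND: invalid choice: 'celery'
    parser.parse_args(["celery"])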