From a32358af3f1e67ebf30625c35fde1e8309fef056 Mon Sep 17 00:00:00 2001 From: blag Date: Thu, 31 Mar 2022 18:41:26 -0700 Subject: [PATCH 01/34] Prep for rewriting the Airflow CLI with Click --- airflow/cli/__init__.py | 101 +++++++++++++++++++++++++++++++++++ airflow/cli/__main__.py | 23 ++++++++ airflow/utils/cli.py | 29 ++++------ setup.cfg | 3 +- tests/utils/test_cli_util.py | 4 +- 5 files changed, 138 insertions(+), 22 deletions(-) create mode 100644 airflow/cli/__main__.py diff --git a/airflow/cli/__init__.py b/airflow/cli/__init__.py index 217e5db960782..451c34f4c5df7 100644 --- a/airflow/cli/__init__.py +++ b/airflow/cli/__init__.py @@ -15,3 +15,104 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import os + +import rich_click as click + +from airflow import settings +from airflow.utils.cli import ColorMode +from airflow.utils.timezone import parse as parsedate + +BUILD_DOCS = "BUILDING_AIRFLOW_DOCS" in os.environ + +click_color = click.option( + '--color', + choices=click.Choice({ColorMode.ON, ColorMode.OFF, ColorMode.AUTO}), + default=ColorMode.AUTO, + help="Do emit colored output (default: auto)", +) +click_conf = click.option( + '-c', '--conf', help="JSON string that gets pickled into the DagRun's conf attribute" +) +click_daemon = click.option( + "-D", "--daemon", 'daemon_', is_flag=True, help="Daemonize instead of running in the foreground" +) +click_dag_id = click.argument("dag_id", help="The id of the dag") +click_dag_id_opt = click.option("-d", "--dag-id", help="The id of the dag") +click_debug = click.option( + "-d", "--debug", is_flag=True, help="Use the server that ships with Flask in debug mode" +) +click_dry_run = click.option( + '-n', + '--dry-run', + default=False, + help="Perform a dry run for each task. Only renders Template Fields for each task, nothing else", +) +click_end_date = click.option( + "-e", + "--end-date", + type=parsedate, + help="Override end_date YYYY-MM-DD", +) +click_execution_date = click.argument("execution_date", help="The execution date of the DAG", type=parsedate) +click_execution_date_or_run_id = click.argument( + "execution_date_or_run_id", help="The execution_date of the DAG or run_id of the DAGRun" +) +click_log_file = click.option( + "-l", + "--log-file", + type=click.Path(exists=True, dir_okay=False, writable=True), + help="Location of the log file", +) +click_output = click.option( + "-o", + "--output", + choices=click.Choice(["table", "json", "yaml", "plain"]), + default="table", + help="Output format.", +) +click_pid = click.option("--pid", type=click.Path(exists=True), help="PID file location") +click_start_date = click.option( + "-s", + "--start-date", + type=parsedate, + help="Override start_date YYYY-MM-DD", +) +click_stderr = click.option( + "--stderr", + type=click.Path(exists=True, dir_okay=False, writable=True), + help="Redirect stderr to this file", +) +click_stdout = click.option( + "--stdout", + type=click.Path(exists=True, dir_okay=False, writable=True), + help="Redirect stdout to this file", +) +click_subdir = click.option( + "-S", + "--subdir", + default='[AIRFLOW_HOME]/dags' if BUILD_DOCS else settings.DAGS_FOLDER, + type=click.Path(), + help=( + "File location or directory from which to look for the dag. 
" + "Defaults to '[AIRFLOW_HOME]/dags' where [AIRFLOW_HOME] is the " + "value you set for 'AIRFLOW_HOME' config you set in 'airflow.cfg' " + ), +) +click_task_id = click.argument("task_id", help="The id of the task") +click_task_regex = click.option( + "-t", "--task-regex", help="The regex to filter specific task_ids to backfill (optional)" +) +click_verbose = click.option( + '-v', '--verbose', is_flag=True, default=False, help="Make logging output more verbose" +) +click_yes = click.option( + '-y', '--yes', is_flag=True, default=False, help="Do not prompt to confirm. Use with care!" +) + + +# https://click.palletsprojects.com/en/8.1.x/documentation/#help-parameter-customization +@click.group(context_settings={'help_option_names': ['-h', '--help']}) +@click.pass_context +def airflow_cmd(ctx): + pass diff --git a/airflow/cli/__main__.py b/airflow/cli/__main__.py new file mode 100644 index 0000000000000..2ae6c91cb39f1 --- /dev/null +++ b/airflow/cli/__main__.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.cli import airflow_cmd + +if __name__ == '__main__': + airflow_cmd(obj={}) diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py index a15e1b7f0265e..78b8701a98a50 100644 --- a/airflow/utils/cli.py +++ b/airflow/utils/cli.py @@ -47,10 +47,6 @@ def _check_cli_args(args): if not args: raise ValueError("Args should be set") - if not isinstance(args[0], Namespace): - raise ValueError( - "1st positional argument should be argparse.Namespace instance," f"but is {type(args[0])}" - ) def action_cli(func=None, check_db=True): @@ -79,15 +75,13 @@ def action_logging(f: T) -> T: @functools.wraps(f) def wrapper(*args, **kwargs): """ - An wrapper for cli functions. It assumes to have Namespace instance - at 1st positional argument + An wrapper for cli functions. - :param args: Positional argument. It assumes to have Namespace instance - at 1st positional argument + :param args: Positional argument. :param kwargs: A passthrough keyword argument """ _check_cli_args(args) - metrics = _build_metrics(f.__name__, args[0]) + metrics = _build_metrics(f.__name__, args, kwargs) cli_action_loggers.on_pre_execution(**metrics) try: # Check and run migrations if necessary @@ -111,15 +105,16 @@ def wrapper(*args, **kwargs): return action_logging -def _build_metrics(func_name, namespace): +def _build_metrics(func_name, args, kwargs): """ Builds metrics dict from function args - It assumes that function arguments is from airflow.bin.cli module's function - and has Namespace instance where it optionally contains "dag_id", "task_id", - and "execution_date". + If the first item in args is a Namespace instance, it assumes that it + optionally contains "dag_id", "task_id", and "execution_date". 
:param func_name: name of function - :param namespace: Namespace instance from argparse + :param args: Arguments from wrapped function, possibly including the Namespace instance from + argparse as the first argument + :param kwargs: Keyword arguments from wrapped function :return: dict with metrics """ from airflow.models import Log @@ -146,11 +141,7 @@ def _build_metrics(func_name, namespace): 'user': getuser(), } - if not isinstance(namespace, Namespace): - raise ValueError( - "namespace argument should be argparse.Namespace instance," f"but is {type(namespace)}" - ) - tmp_dic = vars(namespace) + tmp_dic = vars(args[0]) if isinstance(args[0], Namespace) else kwargs metrics['dag_id'] = tmp_dic.get('dag_id') metrics['task_id'] = tmp_dic.get('task_id') metrics['execution_date'] = tmp_dic.get('execution_date') diff --git a/setup.cfg b/setup.cfg index 11507cb4c65cf..7e304018d335f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,8 +45,8 @@ license_files = licenses/LICENSE-moment-strftime.txt licenses/LICENSE-moment.txt licenses/LICENSE-normalize.txt -# End of licences generated automatically licenses/LICENSES-ui.txt +# End of licences generated automatically classifiers = Development Status :: 5 - Production/Stable Environment :: Console @@ -192,6 +192,7 @@ airflow.utils= [options.entry_points] console_scripts= airflow=airflow.__main__:main + airflow-ng=airflow.cli.__main__:airflow_cmd [bdist_wheel] python-tag=py3 diff --git a/tests/utils/test_cli_util.py b/tests/utils/test_cli_util.py index d69e651a205f8..da0e47a495e30 100644 --- a/tests/utils/test_cli_util.py +++ b/tests/utils/test_cli_util.py @@ -38,7 +38,7 @@ def test_metrics_build(self): func_name = 'test' exec_date = datetime.utcnow() namespace = Namespace(dag_id='foo', task_id='bar', subcommand='test', execution_date=exec_date) - metrics = cli._build_metrics(func_name, namespace) + metrics = cli._build_metrics(func_name, [namespace], {}) expected = { 'user': os.environ.get('USER'), @@ -132,7 +132,7 @@ def test_cli_create_user_supplied_password_is_masked(self, given_command, expect exec_date = datetime.utcnow() namespace = Namespace(dag_id='foo', task_id='bar', subcommand='test', execution_date=exec_date) with mock.patch.object(sys, "argv", args): - metrics = cli._build_metrics(args[1], namespace) + metrics = cli._build_metrics(args[1], [namespace], {}) assert metrics.get('start_datetime') <= datetime.utcnow() From 5d27b031caefa710d301c835e2f8b170b2b31395 Mon Sep 17 00:00:00 2001 From: blag Date: Thu, 31 Mar 2022 18:48:03 -0700 Subject: [PATCH 02/34] Convert db subcommand to use Click --- airflow/cli/__main__.py | 1 + airflow/cli/commands/db.py | 331 +++++++++++++++++++++++++++++++++++++ 2 files changed, 332 insertions(+) create mode 100644 airflow/cli/commands/db.py diff --git a/airflow/cli/__main__.py b/airflow/cli/__main__.py index 2ae6c91cb39f1..e4c1e56d27768 100644 --- a/airflow/cli/__main__.py +++ b/airflow/cli/__main__.py @@ -18,6 +18,7 @@ # under the License. from airflow.cli import airflow_cmd +from airflow.cli.commands import db # noqa: F401 if __name__ == '__main__': airflow_cmd(obj={}) diff --git a/airflow/cli/commands/db.py b/airflow/cli/commands/db.py new file mode 100644 index 0000000000000..080df6f0b8e20 --- /dev/null +++ b/airflow/cli/commands/db.py @@ -0,0 +1,331 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Database sub-commands""" +import os +import textwrap +from tempfile import NamedTemporaryFile + +import rich_click as click +import wrapt +from packaging.version import parse as parse_version +from rich.console import Console + +from airflow import settings +from airflow.cli import airflow_cmd, click_dry_run, click_verbose, click_yes +from airflow.exceptions import AirflowException +from airflow.utils import cli as cli_utils, db as db_utils +from airflow.utils.db import REVISION_HEADS_MAP +from airflow.utils.db_cleanup import config_dict, run_cleanup +from airflow.utils.process_utils import execute_interactive + +click_revision = click.option( + '-r', + '--revision', + default=None, + help="(Optional) If provided, only run migrations up to and including this revision.", +) +click_version = click.option( + '-n', + '--version', + default=None, + help=( + "(Optional) The airflow version to upgrade to. Note: must provide either " + "`--revision` or `--version`." + ), +) +click_from_revision = click.option( + '--from-revision', default=None, help="(Optional) If generating sql, may supply a *from* revision" +) +click_from_version = click.option('--from-version', default=None, help="From version help text") +click_show_sql_only = click.option( + '-s', + '--show-sql-only', + is_flag=True, + help=( + "Don't actually run migrations; just print out sql scripts for offline migration. " + "Required if using either `--from-version` or `--from-version`." + ), +) + + +@airflow_cmd.group() +@click.pass_context +def db(ctx): + """Commands for the metadata database""" + + +@db.command('init') +@click.pass_context +def db_init(ctx): + """Initializes the metadata database""" + console = Console() + console.print(f"DB: {settings.engine.url}") + db_utils.initdb() + console.print("Initialization done") + + +@db.command('check-migrations') +@click.pass_context +@click.option( + '-t', + '--migration-wait-timeout', + default=60, + help="This command will wait for up to this time, specified in seconds", +) +def check_migrations(ctx, migration_wait_timeout): + """Function to wait for all airflow migrations to complete. Used for launching airflow in k8s""" + console = Console() + console.print(f"Waiting for {migration_wait_timeout}s") + db_utils.check_migrations(timeout=migration_wait_timeout) + + +@db.command('reset') +@click.pass_context +@click_yes +def db_reset(ctx, yes=False): + """Burn down and rebuild the metadata database""" + console = Console() + console.print(f"DB: {settings.engine.url}") + if yes or click.confirm("This will drop existing tables if they exist. Proceed? 
(y/n)"): + db_utils.resetdb() + else: + console.print("Cancelled") + + +@wrapt.decorator +def check_revision_and_version_options(wrapped, instance, args, kwargs): + """A decorator that defines upgrade/downgrade option checks in a single place""" + + def wrapper(ctx, revision, version, from_revision, from_version, *_args, show_sql_only=False, **_kwargs): + if revision is not None and version is not None: + raise SystemExit("Cannot supply both `--revision` and `--version`.") + if from_revision is not None and from_version is not None: + raise SystemExit("Cannot supply both `--from-revision` and `--from-version`") + if (from_revision is not None or from_version is not None) and not show_sql_only: + raise SystemExit( + "Args `--from-revision` and `--from-version` may only be used with `--show-sql-only`" + ) + if version is None and revision is None: + raise SystemExit("Must provide either --revision or --version.") + + if from_version is not None: + if parse_version(from_version) < parse_version('2.0.0'): + raise SystemExit("--from-version must be greater than or equal to 2.0.0") + from_revision = REVISION_HEADS_MAP.get(from_version) + if not from_revision: + raise SystemExit(f"Unknown version {from_version!r} supplied as `--from-version`.") + + if version is not None: + revision = REVISION_HEADS_MAP.get(version) + if not revision: + raise SystemExit(f"Upgrading to version {version} is not supported.") + + return wrapped( + ctx, revision, version, from_revision, from_version, *_args, show_sql_only=False, **_kwargs + ) + + return wrapper(*args, **kwargs) + + +@db.command('upgrade') +@click.pass_context +@click_revision +@click_version +@click_from_revision +@click_from_version +@click_show_sql_only +@click_yes +@check_revision_and_version_options +def upgrade(ctx, revision, version, from_revision, from_version, show_sql_only=False, yes=False): + """ + Upgrade the metadata database to latest version + + Upgrade the schema of the metadata database. + To print but not execute commands, use option ``--show-sql-only``. + If using options ``--from-revision`` or ``--from-version``, you must also use + ``--show-sql-only``, because if actually *running* migrations, we should only + migrate from the *current* revision. + """ + console = Console() + console.print(f"Using DB (engine: {settings.engine.url})") + + if not show_sql_only: + console.print(f"Performing upgrade with database {settings.engine.url}") + else: + console.print("Generating SQL for upgrade -- upgrade commands will *not* be submitted.") + + if show_sql_only or ( + yes + or click.confirm( + "\nWarning: About to run schema migrations for the airflow metastore. " + "Please ensure you have backed up your database before any migration " + "operation. Proceed? (y/n)\n" + ) + ): + db_utils.upgradedb(to_revision=revision, from_revision=from_revision, show_sql_only=show_sql_only) + if not show_sql_only: + console.print("Upgrades done") + else: + SystemExit("Cancelled") + + +@db.command('downgrade') +@click.pass_context +@click_revision +@click_version +@click_from_revision +@click_from_version +@click_show_sql_only +@click_yes +@check_revision_and_version_options +def downgrade(ctx, revision, version, from_revision, from_version, show_sql_only=False, yes=False): + """ + Downgrade the schema of the metadata database + + Downgrade the schema of the metadata database. + You must provide either `--revision` or `--version`. + To print but not execute commands, use option `--show-sql-only`. 
+ If using options `--from-revision` or `--from-version`, you must also use `--show-sql-only`, + because if actually *running* migrations, we should only migrate from the *current* revision. + """ + console = Console() + console.print(f"Using DB (engine: {settings.engine.url})") + + if not show_sql_only: + console.print(f"Performing downgrade with database {settings.engine.url}") + else: + console.print("Generating SQL for downgrade -- downgrade commands will *not* be submitted.") + + if show_sql_only or ( + yes + or click.confirm( + "\nWarning: About to reverse schema migrations for the airflow metastore. " + "Please ensure you have backed up your database before any migration " + "operation. Proceed? (y/n)\n" + ) + ): + db_utils.downgrade(to_revision=revision, from_revision=from_revision, show_sql_only=show_sql_only) + if not show_sql_only: + console.print("Downgrades done") + else: + SystemExit("Cancelled") + + +@db.command('shell') +@click.pass_context +@cli_utils.action_cli(check_db=False) +def shell(ctx): + """Runs a shell to access the database""" + url = settings.engine.url + console = Console() + console.print(f"DB: {url}") + + if url.get_backend_name() == 'mysql': + with NamedTemporaryFile(suffix="my.cnf") as f: + content = textwrap.dedent( + f""" + [client] + host = {url.host} + user = {url.username} + password = {url.password or ""} + port = {url.port or "3306"} + database = {url.database} + """ + ).strip() + f.write(content.encode()) + f.flush() + execute_interactive(["mysql", f"--defaults-extra-file={f.name}"]) + elif url.get_backend_name() == 'sqlite': + execute_interactive(["sqlite3", url.database]) + elif url.get_backend_name() == 'postgresql': + env = os.environ.copy() + env['PGHOST'] = url.host or "" + env['PGPORT'] = str(url.port or "5432") + env['PGUSER'] = url.username or "" + # PostgreSQL does not allow the use of PGPASSFILE if the current user is root. 
+ env["PGPASSWORD"] = url.password or "" + env['PGDATABASE'] = url.database + execute_interactive(["psql"], env=env) + elif url.get_backend_name() == 'mssql': + env = os.environ.copy() + env['MSSQL_CLI_SERVER'] = url.host + env['MSSQL_CLI_DATABASE'] = url.database + env['MSSQL_CLI_USER'] = url.username + env['MSSQL_CLI_PASSWORD'] = url.password + execute_interactive(["mssql-cli"], env=env) + else: + raise AirflowException(f"Unknown driver: {url.drivername}") + + +@db.command('check') +@click.pass_context +@click.option( + '-t', + '--migration-wait-timeout', + type=int, + default=60, + help="Tmeout to wait for the database to migrate", +) +@cli_utils.action_cli(check_db=False) +def check(ctx, migration_wait_timeout): + """Runs a check command that checks if db is reachable""" + console = Console() + console.print(f"Waiting for {migration_wait_timeout}s") + db_utils.check_migrations(timeout=migration_wait_timeout) + + +# lazily imported by CLI parser for `help` command +all_tables = sorted(config_dict) + + +@db.command('cleanup') +@click.pass_context +@click.option( + '-t', + '--tables', + multiple=True, + default=all_tables, + show_default=True, + help=( + "Table names to perform maintenance on (use comma-separated list).\n" + "Can be specified multiple times, all tables names will be used.\n" + ), +) +@click.option( + '--clean-before-timestamp', + type=str, + default=None, + help="The date or timestamp before which data should be purged.\n" + "If no timezone info is supplied then dates are assumed to be in airflow default timezone.\n" + "Example: '2022-01-01 00:00:00+01:00'", +) +@click_dry_run +@click_verbose +@click_yes +@cli_utils.action_cli(check_db=False) +def cleanup_tables(ctx, tables, clean_before_timestamp, dry_run, verbose, yes): + """Purge old records in metastore tables""" + split_tables = [] + for table in tables: + split_tables.extend(table.split(',')) + run_cleanup( + table_names=split_tables, + dry_run=dry_run, + clean_before_timestamp=clean_before_timestamp, + verbose=verbose, + confirm=not yes, + ) From ead0fcb27f1793ef21210c7105859837d51fca1e Mon Sep 17 00:00:00 2001 From: blag Date: Thu, 31 Mar 2022 15:58:24 -0700 Subject: [PATCH 03/34] Convert webserver commands to use click --- airflow/cli/__main__.py | 1 + airflow/cli/commands/webserver.py | 577 ++++++++++++++++++++++++++++++ 2 files changed, 578 insertions(+) create mode 100644 airflow/cli/commands/webserver.py diff --git a/airflow/cli/__main__.py b/airflow/cli/__main__.py index e4c1e56d27768..1d8dd73a58046 100644 --- a/airflow/cli/__main__.py +++ b/airflow/cli/__main__.py @@ -19,6 +19,7 @@ from airflow.cli import airflow_cmd from airflow.cli.commands import db # noqa: F401 +from airflow.cli.commands import webserver # noqa: F401 if __name__ == '__main__': airflow_cmd(obj={}) diff --git a/airflow/cli/commands/webserver.py b/airflow/cli/commands/webserver.py new file mode 100644 index 0000000000000..a675008bae0a9 --- /dev/null +++ b/airflow/cli/commands/webserver.py @@ -0,0 +1,577 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Webserver command""" +import hashlib +import logging +import os +import signal +import subprocess +import sys +import textwrap +import time +from contextlib import suppress +from time import sleep +from typing import Dict, List, NoReturn + +import daemon +import psutil +import rich_click as click +from daemon.pidfile import TimeoutPIDLockFile +from lockfile.pidlockfile import read_pid_from_pidfile +from rich.console import Console + +from airflow import settings +from airflow.cli import ( + airflow_cmd, + click_daemon, + click_debug, + click_log_file, + click_pid, + click_stderr, + click_stdout, +) +from airflow.configuration import conf +from airflow.exceptions import AirflowException, AirflowWebServerTimeout +from airflow.utils import cli as cli_utils +from airflow.utils.cli import setup_locations, setup_logging +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.process_utils import check_if_pidfile_process_is_running +from airflow.www.app import create_app + +log = logging.getLogger(__name__) + + +class GunicornMonitor(LoggingMixin): + """ + Runs forever, monitoring the child processes of @gunicorn_master_proc and + restarting workers occasionally or when files in the plug-in directory + has been modified. + + Each iteration of the loop traverses one edge of this state transition + diagram, where each state (node) represents + [ num_ready_workers_running / num_workers_running ]. We expect most time to + be spent in [n / n]. `bs` is the setting webserver.worker_refresh_batch_size. + The horizontal transition at ? happens after the new worker parses all the + dags (so it could take a while!) + V ────────────────────────────────────────────────────────────────────────┐ + [n / n] ──TTIN──> [ [n, n+bs) / n + bs ] ────?───> [n + bs / n + bs] ──TTOU─┘ + ^ ^───────────────┘ + │ + │ ┌────────────────v + └──────┴────── [ [0, n) / n ] <─── start + We change the number of workers by sending TTIN and TTOU to the gunicorn + master process, which increases and decreases the number of child workers + respectively. Gunicorn guarantees that on TTOU workers are terminated + gracefully and that the oldest worker is terminated. + + :param gunicorn_master_pid: PID for the main Gunicorn process + :param num_workers_expected: Number of workers to run the Gunicorn web server + :param master_timeout: Number of seconds the webserver waits before killing gunicorn master that + doesn't respond + :param worker_refresh_interval: Number of seconds to wait before refreshing a batch of workers. + :param worker_refresh_batch_size: Number of workers to refresh at a time. When set to 0, worker + refresh is disabled. When nonzero, airflow periodically refreshes webserver workers by + bringing up new ones and killing old ones. + :param reload_on_plugin_change: If set to True, Airflow will track files in plugins_folder directory. + When it detects changes, then reload the gunicorn. 
+ """ + + def __init__( + self, + gunicorn_master_pid: int, + num_workers_expected: int, + master_timeout: int, + worker_refresh_interval: int, + worker_refresh_batch_size: int, + reload_on_plugin_change: bool, + ): + super().__init__() + self.gunicorn_master_proc = psutil.Process(gunicorn_master_pid) + self.num_workers_expected = num_workers_expected + self.master_timeout = master_timeout + self.worker_refresh_interval = worker_refresh_interval + self.worker_refresh_batch_size = worker_refresh_batch_size + self.reload_on_plugin_change = reload_on_plugin_change + + self._num_workers_running = 0 + self._num_ready_workers_running = 0 + self._last_refresh_time = time.monotonic() if worker_refresh_interval > 0 else None + self._last_plugin_state = self._generate_plugin_state() if reload_on_plugin_change else None + self._restart_on_next_plugin_check = False + + def _generate_plugin_state(self) -> Dict[str, float]: + """ + Generate dict of filenames and last modification time of all files in settings.PLUGINS_FOLDER + directory. + """ + if not settings.PLUGINS_FOLDER: + return {} + + all_filenames: List[str] = [] + for (root, _, filenames) in os.walk(settings.PLUGINS_FOLDER): + all_filenames.extend(os.path.join(root, f) for f in filenames) + plugin_state = {f: self._get_file_hash(f) for f in sorted(all_filenames)} + return plugin_state + + @staticmethod + def _get_file_hash(fname: str): + """Calculate MD5 hash for file""" + hash_md5 = hashlib.md5() + with open(fname, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + def _get_num_ready_workers_running(self) -> int: + """Returns number of ready Gunicorn workers by looking for READY_PREFIX in process name""" + workers = psutil.Process(self.gunicorn_master_proc.pid).children() + + def ready_prefix_on_cmdline(proc): + try: + cmdline = proc.cmdline() + if len(cmdline) > 0: + return settings.GUNICORN_WORKER_READY_PREFIX in cmdline[0] + except psutil.NoSuchProcess: + pass + return False + + ready_workers = [proc for proc in workers if ready_prefix_on_cmdline(proc)] + return len(ready_workers) + + def _get_num_workers_running(self) -> int: + """Returns number of running Gunicorn workers processes""" + workers = psutil.Process(self.gunicorn_master_proc.pid).children() + return len(workers) + + def _wait_until_true(self, fn, timeout: int = 0) -> None: + """Sleeps until fn is true""" + start_time = time.monotonic() + while not fn(): + if 0 < timeout <= time.monotonic() - start_time: + raise AirflowWebServerTimeout(f"No response from gunicorn master within {timeout} seconds") + sleep(0.1) + + def _spawn_new_workers(self, count: int) -> None: + """ + Send signal to kill the worker. + + :param count: The number of workers to spawn + """ + excess = 0 + for _ in range(count): + # TTIN: Increment the number of processes by one + self.gunicorn_master_proc.send_signal(signal.SIGTTIN) + excess += 1 + self._wait_until_true( + lambda: self.num_workers_expected + excess == self._get_num_workers_running(), + timeout=self.master_timeout, + ) + + def _kill_old_workers(self, count: int) -> None: + """ + Send signal to kill the worker. 
+ + :param count: The number of workers to kill + """ + for _ in range(count): + count -= 1 + # TTOU: Decrement the number of processes by one + self.gunicorn_master_proc.send_signal(signal.SIGTTOU) + self._wait_until_true( + lambda: self.num_workers_expected + count == self._get_num_workers_running(), + timeout=self.master_timeout, + ) + + def _reload_gunicorn(self) -> None: + """ + Send signal to reload the gunicorn configuration. When gunicorn receive signals, it reload the + configuration, start the new worker processes with a new configuration and gracefully + shutdown older workers. + """ + # HUP: Reload the configuration. + self.gunicorn_master_proc.send_signal(signal.SIGHUP) + sleep(1) + self._wait_until_true( + lambda: self.num_workers_expected == self._get_num_workers_running(), timeout=self.master_timeout + ) + + def start(self) -> NoReturn: + """Starts monitoring the webserver.""" + try: + self._wait_until_true( + lambda: self.num_workers_expected == self._get_num_workers_running(), + timeout=self.master_timeout, + ) + while True: + if not self.gunicorn_master_proc.is_running(): + sys.exit(1) + self._check_workers() + # Throttle loop + sleep(1) + + except (AirflowWebServerTimeout, OSError) as err: + self.log.error(err) + self.log.error("Shutting down webserver") + try: + self.gunicorn_master_proc.terminate() + self.gunicorn_master_proc.wait() + finally: + sys.exit(1) + + def _check_workers(self) -> None: + num_workers_running = self._get_num_workers_running() + num_ready_workers_running = self._get_num_ready_workers_running() + + # Whenever some workers are not ready, wait until all workers are ready + if num_ready_workers_running < num_workers_running: + self.log.debug( + '[%d / %d] Some workers are starting up, waiting...', + num_ready_workers_running, + num_workers_running, + ) + sleep(1) + return + + # If there are too many workers, then kill a worker gracefully by asking gunicorn to reduce + # number of workers + if num_workers_running > self.num_workers_expected: + excess = min(num_workers_running - self.num_workers_expected, self.worker_refresh_batch_size) + self.log.debug( + '[%d / %d] Killing %s workers', num_ready_workers_running, num_workers_running, excess + ) + self._kill_old_workers(excess) + return + + # If there are too few workers, start a new worker by asking gunicorn + # to increase number of workers + if num_workers_running < self.num_workers_expected: + self.log.error( + "[%d / %d] Some workers seem to have died and gunicorn did not restart them as expected", + num_ready_workers_running, + num_workers_running, + ) + sleep(10) + num_workers_running = self._get_num_workers_running() + if num_workers_running < self.num_workers_expected: + new_worker_count = min( + self.num_workers_expected - num_workers_running, self.worker_refresh_batch_size + ) + # log at info since we are trying fix an error logged just above + self.log.info( + '[%d / %d] Spawning %d workers', + num_ready_workers_running, + num_workers_running, + new_worker_count, + ) + self._spawn_new_workers(new_worker_count) + return + + # Now the number of running and expected worker should be equal + + # If workers should be restarted periodically. 
+ if self.worker_refresh_interval > 0 and self._last_refresh_time: + # and we refreshed the workers a long time ago, refresh the workers + last_refresh_diff = time.monotonic() - self._last_refresh_time + if self.worker_refresh_interval < last_refresh_diff: + num_new_workers = self.worker_refresh_batch_size + self.log.debug( + '[%d / %d] Starting doing a refresh. Starting %d workers.', + num_ready_workers_running, + num_workers_running, + num_new_workers, + ) + self._spawn_new_workers(num_new_workers) + self._last_refresh_time = time.monotonic() + return + + # if we should check the directory with the plugin, + if self.reload_on_plugin_change: + # compare the previous and current contents of the directory + new_state = self._generate_plugin_state() + # If changed, wait until its content is fully saved. + if new_state != self._last_plugin_state: + self.log.debug( + '[%d / %d] Plugins folder changed. The gunicorn will be restarted the next time the ' + 'plugin directory is checked, if there is no change in it.', + num_ready_workers_running, + num_workers_running, + ) + self._restart_on_next_plugin_check = True + self._last_plugin_state = new_state + elif self._restart_on_next_plugin_check: + self.log.debug( + '[%d / %d] Starts reloading the gunicorn configuration.', + num_ready_workers_running, + num_workers_running, + ) + self._restart_on_next_plugin_check = False + self._last_refresh_time = time.monotonic() + self._reload_gunicorn() + + +@airflow_cmd.command('webserver') +@click.option( + '-p', + '--port', + default=conf.get('webserver', 'WEB_SERVER_PORT'), + type=int, + help="The port on which to run the server", +) +@click.option( + "-w", + "--workers", + default=conf.get('webserver', 'WORKERS'), + type=int, + help="Number of workers to run the webserver on", +) +@click.option( + "-k", + "--workerclass", + default=conf.get('webserver', 'WORKER_CLASS'), + type=click.Choice(['sync', 'eventlet', 'gevent', 'tornado']), + help="The worker class to use for Gunicorn", +) +@click.option( + "-t", + "--worker-timeout", + default=conf.get('webserver', 'WEB_SERVER_WORKER_TIMEOUT'), + type=int, + help="The timeout for waiting on webserver workers", +) +@click.option( + "-H", + "--hostname", + default=conf.get('webserver', 'WEB_SERVER_HOST'), + help="Set the hostname on which to run the web server", +) +@click_pid +@click_daemon +@click_stdout +@click_stderr +@click.option( + "-A", + "--access-logfile", + type=click.Path(exists=True, dir_okay=False, writable=True, allow_dash=True), + default=conf.get('webserver', 'ACCESS_LOGFILE'), + help="The logfile to store the webserver access log. Use '-' to print to stderr", +) +@click.option( + "-E", + "--error-logfile", + type=click.Path(exists=True, dir_okay=False, writable=True, allow_dash=True), + default=conf.get('webserver', 'ERROR_LOGFILE'), + help="The logfile to store the webserver error log. 
Use '-' to print to stderr", +) +@click.option( + "-L", + "--access-logformat", + default=conf.get('webserver', 'ACCESS_LOGFORMAT'), + help="The access log format for gunicorn logs", +) +@click_log_file +@click.option( + "--ssl-cert", + type=click.Path(exists=True, dir_okay=False, writable=True), + default=conf.get('webserver', 'WEB_SERVER_SSL_CERT'), + help="Path to the SSL certificate for the webserver", +) +@click.option( + "--ssl-key", + type=click.Path(exists=True, dir_okay=False, writable=True), + default=conf.get('webserver', 'WEB_SERVER_SSL_KEY'), + help="Path to the key to use with the SSL certificate", +) +@click_debug +@cli_utils.action_cli +def webserver( + ctx, + port, + workers, + workerclass, + worker_timeout, + hostname, + pid, + daemon_, + stdout, + stderr, + access_logfile, + error_logfile, + access_logformat, + log_file, + ssl_cert, + ssl_key, + debug, +): + """Starts Airflow Webserver""" + console = Console() + console.print(settings.HEADER) + + # Check for old/insecure config, and fail safe (i.e. don't launch) if the config is wildly insecure. + if conf.get('webserver', 'secret_key') == 'temporary_key': + from rich import print as rich_print + + rich_print( + "[red][bold]ERROR:[/bold] The `secret_key` setting under the webserver config has an insecure " + "value - Airflow has failed safe and refuses to start. Please change this value to a new, " + "per-environment, randomly generated string, for example using this command `[cyan]openssl rand " + "-hex 30[/cyan]`", + file=sys.stderr, + ) + sys.exit(1) + + access_logfile = access_logfile or conf.get('webserver', 'access_logfile') + error_logfile = error_logfile or conf.get('webserver', 'error_logfile') + access_logformat = access_logformat or conf.get('webserver', 'access_logformat') + num_workers = workers or conf.get('webserver', 'workers') + worker_timeout = worker_timeout or conf.get('webserver', 'web_server_worker_timeout') + ssl_cert = ssl_cert or conf.get('webserver', 'web_server_ssl_cert') + ssl_key = ssl_key or conf.get('webserver', 'web_server_ssl_key') + if not ssl_cert and ssl_key: + raise AirflowException('An SSL certificate must also be provided for use with ' + ssl_key) + if ssl_cert and not ssl_key: + raise AirflowException('An SSL key must also be provided for use with ' + ssl_cert) + + if debug: + console.print(f"Starting the web server on port {port} and host {hostname}.") + app = create_app(testing=conf.getboolean('core', 'unit_test_mode')) + app.run( + debug=True, + use_reloader=not app.config['TESTING'], + port=port, + host=hostname, + ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None, + ) + else: + + pid_file, stdout, stderr, log_file = setup_locations("webserver", pid, stdout, stderr, log_file) + + # Check if webserver is already running if not, remove old pidfile + check_if_pidfile_process_is_running(pid_file=pid_file, process_name="webserver") + + console.print( + textwrap.dedent( + f'''\ + Running the Gunicorn Server with: + Workers: {num_workers} {workerclass} + Host: {hostname}:{port} + Timeout: {worker_timeout} + Logfiles: {access_logfile} {error_logfile} + Access Logformat: {access_logformat} + =================================================================''' + ) + ) + + run_args = [ + sys.executable, + '-m', + 'gunicorn', + '--workers', + str(num_workers), + '--worker-class', + str(workerclass), + '--timeout', + str(worker_timeout), + '--bind', + hostname + ':' + str(port), + '--name', + 'airflow-webserver', + '--pid', + pid_file, + '--config', + 
'python:airflow.www.gunicorn_config', + ] + + if access_logfile: + run_args += ['--access-logfile', str(access_logfile)] + + if error_logfile: + run_args += ['--error-logfile', str(error_logfile)] + + if access_logformat and access_logformat.strip(): + run_args += ['--access-logformat', str(access_logformat)] + + if daemon_: + run_args += ['--daemon'] + + if ssl_cert: + run_args += ['--certfile', ssl_cert, '--keyfile', ssl_key] + + run_args += ["airflow.www.app:cached_app()"] + + gunicorn_master_proc = None + + def kill_proc(signum, _): + log.info("Received signal: %s. Closing gunicorn.", signum) + gunicorn_master_proc.terminate() + with suppress(TimeoutError): + gunicorn_master_proc.wait(timeout=30) + if gunicorn_master_proc.poll() is not None: + gunicorn_master_proc.kill() + sys.exit(0) + + def monitor_gunicorn(gunicorn_master_pid: int): + # Register signal handlers + signal.signal(signal.SIGINT, kill_proc) + signal.signal(signal.SIGTERM, kill_proc) + + # These run forever until SIG{INT, TERM, KILL, ...} signal is sent + GunicornMonitor( + gunicorn_master_pid=gunicorn_master_pid, + num_workers_expected=num_workers, + master_timeout=conf.getint('webserver', 'web_server_master_timeout'), + worker_refresh_interval=conf.getint('webserver', 'worker_refresh_interval', fallback=30), + worker_refresh_batch_size=conf.getint('webserver', 'worker_refresh_batch_size', fallback=1), + reload_on_plugin_change=conf.getboolean( + 'webserver', 'reload_on_plugin_change', fallback=False + ), + ).start() + + if daemon_: + # This makes possible errors get reported before daemonization + os.environ['SKIP_DAGS_PARSING'] = 'True' + app = create_app(None) + os.environ.pop('SKIP_DAGS_PARSING') + + handle = setup_logging(log_file) + + base, ext = os.path.splitext(pid_file) + with open(stdout, 'w+') as stdout, open(stderr, 'w+') as stderr: + ctx = daemon.DaemonContext( + pidfile=TimeoutPIDLockFile(f"{base}-monitor{ext}", -1), + files_preserve=[handle], + stdout=stdout, + stderr=stderr, + ) + with ctx: + subprocess.Popen(run_args, close_fds=True) + + # Reading pid of gunicorn master as it will be different that + # the one of process spawned above. 
+ while True: + sleep(0.1) + gunicorn_master_proc_pid = read_pid_from_pidfile(pid_file) + if gunicorn_master_proc_pid: + break + + # Run Gunicorn monitor + gunicorn_master_proc = psutil.Process(gunicorn_master_proc_pid) + monitor_gunicorn(gunicorn_master_proc.pid) + + else: + with subprocess.Popen(run_args, close_fds=True) as gunicorn_master_proc: + monitor_gunicorn(gunicorn_master_proc.pid) From 3119bff26f5dd564b1927b1d24f77ede95496b6b Mon Sep 17 00:00:00 2001 From: blag Date: Fri, 1 Apr 2022 16:30:32 -0700 Subject: [PATCH 04/34] Convert standalone commands to use click --- airflow/cli/__main__.py | 1 + airflow/cli/commands/standalone.py | 299 +++++++++++++++++++++++++++++ 2 files changed, 300 insertions(+) create mode 100644 airflow/cli/commands/standalone.py diff --git a/airflow/cli/__main__.py b/airflow/cli/__main__.py index 1d8dd73a58046..9fe7b30328f68 100644 --- a/airflow/cli/__main__.py +++ b/airflow/cli/__main__.py @@ -19,6 +19,7 @@ from airflow.cli import airflow_cmd from airflow.cli.commands import db # noqa: F401 +from airflow.cli.commands import standalone # noqa: F401 from airflow.cli.commands import webserver # noqa: F401 if __name__ == '__main__': diff --git a/airflow/cli/commands/standalone.py b/airflow/cli/commands/standalone.py new file mode 100644 index 0000000000000..0046cd45b8ba6 --- /dev/null +++ b/airflow/cli/commands/standalone.py @@ -0,0 +1,299 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +import os +import random +import socket +import subprocess +import threading +import time +from collections import deque +from typing import Dict, List + +import rich_click as click +from rich.console import Console +from termcolor import colored + +from airflow.cli import airflow_cmd +from airflow.configuration import AIRFLOW_HOME, conf +from airflow.executors import executor_constants +from airflow.jobs.scheduler_job import SchedulerJob +from airflow.jobs.triggerer_job import TriggererJob +from airflow.utils import db +from airflow.www.app import cached_app + + +@airflow_cmd.command('standalone') +@click.pass_context +def standalone(ctx): + """Run an all-in-one copy of Airflow""" + StandaloneCommand.entrypoint() + + +class StandaloneCommand: + """ + Runs all components of Airflow under a single parent process. + + Useful for local development. 
+ """ + + @classmethod + def entrypoint(cls, args): + """CLI entrypoint, called by the main CLI system.""" + StandaloneCommand().run() + + def __init__(self): + self.subcommands = {} + self.output_queue = deque() + self.user_info = {} + self.ready_time = None + self.ready_delay = 3 + self.console = Console() + + def run(self): + """Main run loop""" + self.print_output("standalone", "Starting Airflow Standalone") + # Silence built-in logging at INFO + logging.getLogger("").setLevel(logging.WARNING) + # Startup checks and prep + env = self.calculate_env() + self.initialize_database() + # Set up commands to run + self.subcommands["scheduler"] = SubCommand( + self, + name="scheduler", + command=["scheduler"], + env=env, + ) + self.subcommands["webserver"] = SubCommand( + self, + name="webserver", + command=["webserver"], + env=env, + ) + self.subcommands["triggerer"] = SubCommand( + self, + name="triggerer", + command=["triggerer"], + env=env, + ) + + self.web_server_port = conf.getint('webserver', 'WEB_SERVER_PORT', fallback=8080) + # Run subcommand threads + for command in self.subcommands.values(): + command.start() + # Run output loop + shown_ready = False + while True: + try: + # Print all the current lines onto the screen + self.update_output() + # Print info banner when all components are ready and the + # delay has passed + if not self.ready_time and self.is_ready(): + self.ready_time = time.monotonic() + if ( + not shown_ready + and self.ready_time + and time.monotonic() - self.ready_time > self.ready_delay + ): + self.print_ready() + shown_ready = True + # Ensure we idle-sleep rather than fast-looping + time.sleep(0.1) + except KeyboardInterrupt: + break + # Stop subcommand threads + self.print_output("standalone", "Shutting down components") + for command in self.subcommands.values(): + command.stop() + for command in self.subcommands.values(): + command.join() + self.print_output("standalone", "Complete") + + def update_output(self): + """Drains the output queue and prints its contents to the screen""" + while self.output_queue: + # Extract info + name, line = self.output_queue.popleft() + # Make line printable + line_str = line.decode("utf8").strip() + self.print_output(name, line_str) + + def print_output(self, name: str, output): + """ + Prints an output line with name and colouring. You can pass multiple + lines to output if you wish; it will be split for you. + """ + color = { + "webserver": "green", + "scheduler": "blue", + "triggerer": "cyan", + "standalone": "white", + }.get(name, "white") + colorised_name = colored("%10s" % name, color) + for line in output.split("\n"): + self.console.print(f"{colorised_name} | {line.strip()}") + + def print_error(self, name: str, output): + """ + Prints an error message to the console (this is the same as + print_output but with the text red) + """ + self.print_output(name, colored(output, "red")) + + def calculate_env(self): + """ + Works out the environment variables needed to run subprocesses. + We override some settings as part of being standalone. 
+ """ + env = dict(os.environ) + # Make sure we're using a local executor flavour + if conf.get("core", "executor") not in [ + executor_constants.LOCAL_EXECUTOR, + executor_constants.SEQUENTIAL_EXECUTOR, + ]: + if "sqlite" in conf.get("core", "sql_alchemy_conn"): + self.print_output("standalone", "Forcing executor to SequentialExecutor") + env["AIRFLOW__CORE__EXECUTOR"] = executor_constants.SEQUENTIAL_EXECUTOR + else: + self.print_output("standalone", "Forcing executor to LocalExecutor") + env["AIRFLOW__CORE__EXECUTOR"] = executor_constants.LOCAL_EXECUTOR + return env + + def initialize_database(self): + """Makes sure all the tables are created.""" + # Set up DB tables + self.print_output("standalone", "Checking database is initialized") + db.initdb() + self.print_output("standalone", "Database ready") + # See if a user needs creating + # We want a streamlined first-run experience, but we do not want to + # use a preset password as people will inevitably run this on a public + # server. Thus, we make a random password and store it in AIRFLOW_HOME, + # with the reasoning that if you can read that directory, you can see + # the database credentials anyway. + appbuilder = cached_app().appbuilder + user_exists = appbuilder.sm.find_user("admin") + password_path = os.path.join(AIRFLOW_HOME, "standalone_admin_password.txt") + we_know_password = os.path.isfile(password_path) + # If the user does not exist, make a random password and make it + if not user_exists: + self.print_output("standalone", "Creating admin user") + role = appbuilder.sm.find_role("Admin") + assert role is not None + password = "".join( + random.choice("abcdefghkmnpqrstuvwxyzABCDEFGHKMNPQRSTUVWXYZ23456789") for i in range(16) + ) + with open(password_path, "w") as file: + file.write(password) + appbuilder.sm.add_user("admin", "Admin", "User", "admin@example.com", role, password) + self.print_output("standalone", "Created admin user") + # If the user does exist and we know its password, read the password + elif user_exists and we_know_password: + with open(password_path) as file: + password = file.read().strip() + # Otherwise we don't know the password + else: + password = None + # Store what we know about the user for printing later in startup + self.user_info = {"username": "admin", "password": password} + + def is_ready(self): + """ + Detects when all Airflow components are ready to serve. + For now, it's simply time-based. + """ + return ( + self.port_open(self.web_server_port) + and self.job_running(SchedulerJob) + and self.job_running(TriggererJob) + ) + + def port_open(self, port): + """ + Checks if the given port is listening on the local machine. + (used to tell if webserver is alive) + """ + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(1) + sock.connect(("127.0.0.1", port)) + sock.close() + except (OSError, ValueError): + # Any exception means the socket is not available + return False + return True + + def job_running(self, job): + """ + Checks if the given job name is running and heartbeating correctly + (used to tell if scheduler is alive) + """ + recent = job.most_recent_job() + if not recent: + return False + return recent.is_alive() + + def print_ready(self): + """ + Prints the banner shown when Airflow is ready to go, with login + details. 
+ """ + self.print_output("standalone", "") + self.print_output("standalone", "Airflow is ready") + if self.user_info["password"]: + self.print_output( + "standalone", + f"Login with username: {self.user_info['username']} password: {self.user_info['password']}", + ) + self.print_output( + "standalone", + "Airflow Standalone is for development purposes only. Do not use this in production!", + ) + self.print_output("standalone", "") + + +class SubCommand(threading.Thread): + """ + Thread that launches a process and then streams its output back to the main + command. We use threads to avoid using select() and raw filehandles, and the + complex logic that brings doing line buffering. + """ + + def __init__(self, parent, name: str, command: List[str], env: Dict[str, str]): + super().__init__() + self.parent = parent + self.name = name + self.command = command + self.env = env + + def run(self): + """Runs the actual process and captures it output to a queue""" + self.process = subprocess.Popen( + ["airflow"] + self.command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=self.env, + ) + for line in self.process.stdout: + self.parent.output_queue.append((self.name, line)) + + def stop(self): + """Call to stop this process (and thus this thread)""" + self.process.terminate() From f320d0169eedf6dacb57d2430073360e27a5d5d5 Mon Sep 17 00:00:00 2001 From: blag Date: Fri, 1 Apr 2022 17:46:36 -0700 Subject: [PATCH 05/34] Convert scheduler commands to use click --- airflow/cli/__main__.py | 1 + airflow/cli/commands/scheduler.py | 129 ++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+) create mode 100644 airflow/cli/commands/scheduler.py diff --git a/airflow/cli/__main__.py b/airflow/cli/__main__.py index 9fe7b30328f68..e2458cb517925 100644 --- a/airflow/cli/__main__.py +++ b/airflow/cli/__main__.py @@ -19,6 +19,7 @@ from airflow.cli import airflow_cmd from airflow.cli.commands import db # noqa: F401 +from airflow.cli.commands import scheduler # noqa: F401 from airflow.cli.commands import standalone # noqa: F401 from airflow.cli.commands import webserver # noqa: F401 diff --git a/airflow/cli/commands/scheduler.py b/airflow/cli/commands/scheduler.py new file mode 100644 index 0000000000000..e9524f3fe62f7 --- /dev/null +++ b/airflow/cli/commands/scheduler.py @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Scheduler command""" +import signal +from multiprocessing import Process +from typing import Optional + +import daemon +import rich_click as click +from daemon.pidfile import TimeoutPIDLockFile +from rich.console import Console + +from airflow import settings +from airflow.cli import ( + airflow_cmd, + click_daemon, + click_log_file, + click_pid, + click_stderr, + click_stdout, + click_subdir, +) +from airflow.configuration import conf +from airflow.jobs.scheduler_job import SchedulerJob +from airflow.utils import cli as cli_utils +from airflow.utils.cli import process_subdir, setup_locations, setup_logging, sigint_handler, sigquit_handler + + +def _create_scheduler_job(subdir, num_runs, do_pickle): + job = SchedulerJob( + subdir=process_subdir(subdir), + num_runs=num_runs, + do_pickle=do_pickle, + ) + return job + + +def _run_scheduler_job(subdir, num_runs, do_pickle, skip_serve_logs): + job = _create_scheduler_job(subdir, num_runs, do_pickle) + sub_proc = _serve_logs(skip_serve_logs) + try: + job.run() + finally: + if sub_proc: + sub_proc.terminate() + + +@airflow_cmd.command('scheduler') +@click_daemon +@click.option( + "-p", + "--do-pickle", + is_flag=True, + default=False, + help=( + "Attempt to pickle the DAG object to send over to the workers, instead of letting workers " + "run their version of the code" + ), +) +@click_log_file +@click.option( + "-n", + "--num-runs", + type=int, + default=conf.get('scheduler', 'num_runs'), + help="Set the number of runs to execute before exiting", +) +@click_pid +@click.option( + "-s", + "--skip-serve-logs", + is_flag=True, + default=False, + help="Don't start the serve logs process along with the workers", +) +@click_stderr +@click_stdout +@click_subdir +@cli_utils.action_cli +def scheduler(ctx, daemon_, do_pickle, log_file, num_runs, pid, skip_serve_logs, stderr, stdout, subdir): + """Starts Airflow Scheduler""" + console = Console() + console.print(settings.HEADER) + + if daemon_: + pid, stdout, stderr, log_file = setup_locations("scheduler", pid, stdout, stderr, log_file) + handle = setup_logging(log_file) + with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle: + ctx = daemon.DaemonContext( + pidfile=TimeoutPIDLockFile(pid, -1), + files_preserve=[handle], + stdout=stdout_handle, + stderr=stderr_handle, + ) + with ctx: + _run_scheduler_job(subdir, num_runs, do_pickle, skip_serve_logs) + else: + signal.signal(signal.SIGINT, sigint_handler) + signal.signal(signal.SIGTERM, sigint_handler) + signal.signal(signal.SIGQUIT, sigquit_handler) + _run_scheduler_job(subdir, num_runs, do_pickle, skip_serve_logs) + + +def _serve_logs(skip_serve_logs: bool = False) -> Optional[Process]: + """Starts serve_logs sub-process""" + from airflow.configuration import conf + from airflow.utils.serve_logs import serve_logs + + if conf.get("core", "executor") in ["LocalExecutor", "SequentialExecutor"]: + if skip_serve_logs is False: + sub_proc = Process(target=serve_logs) + sub_proc.start() + return sub_proc + return None From a8a5be701024fe63c47c391a7a9587b7cd635e30 Mon Sep 17 00:00:00 2001 From: blag Date: Fri, 1 Apr 2022 21:12:03 -0700 Subject: [PATCH 06/34] Convert triggerer commands to use click --- airflow/cli/__main__.py | 1 + airflow/cli/commands/triggerer.py | 69 +++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 airflow/cli/commands/triggerer.py diff --git a/airflow/cli/__main__.py b/airflow/cli/__main__.py index e2458cb517925..2dc38e2e69684 100644 --- a/airflow/cli/__main__.py +++ 
b/airflow/cli/__main__.py @@ -21,6 +21,7 @@ from airflow.cli.commands import db # noqa: F401 from airflow.cli.commands import scheduler # noqa: F401 from airflow.cli.commands import standalone # noqa: F401 +from airflow.cli.commands import triggerer # noqa: F401 from airflow.cli.commands import webserver # noqa: F401 if __name__ == '__main__': diff --git a/airflow/cli/commands/triggerer.py b/airflow/cli/commands/triggerer.py new file mode 100644 index 0000000000000..38d93d8f11399 --- /dev/null +++ b/airflow/cli/commands/triggerer.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Triggerer command""" +import signal + +import daemon +import rich_click as click +from daemon.pidfile import TimeoutPIDLockFile +from rich.console import Console + +from airflow import settings +from airflow.cli import airflow_cmd, click_daemon, click_log_file, click_pid, click_stderr, click_stdout +from airflow.jobs.triggerer_job import TriggererJob +from airflow.utils import cli as cli_utils +from airflow.utils.cli import setup_locations, setup_logging, sigint_handler, sigquit_handler + + +@airflow_cmd.command('triggerer') +@click.option( + "--capacity", + type=click.IntRange(min=1), + help="The maximum number of triggers that a Triggerer will run at one time", +) +@click_daemon +@click_log_file +@click_pid +@click_stderr +@click_stdout +@cli_utils.action_cli +def triggerer(ctx, capacity, daemon_, log_file, pid, stderr, stdout): + """Starts Airflow Triggerer""" + console = Console() + settings.MASK_SECRETS_IN_LOGS = True + console.print(settings.HEADER) + job = TriggererJob(capacity=capacity) + + if daemon_: + pid, stdout, stderr, log_file = setup_locations("triggerer", pid, stdout, stderr, log_file) + handle = setup_logging(log_file) + with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle: + ctx = daemon.DaemonContext( + pidfile=TimeoutPIDLockFile(pid, -1), + files_preserve=[handle], + stdout=stdout_handle, + stderr=stderr_handle, + ) + with ctx: + job.run() + + else: + signal.signal(signal.SIGINT, sigint_handler) + signal.signal(signal.SIGTERM, sigint_handler) + signal.signal(signal.SIGQUIT, sigquit_handler) + job.run() From 7994c45637cf79e9ea4d923dcef7042928a12498 Mon Sep 17 00:00:00 2001 From: Karthikeyan Singaravelan Date: Sun, 3 Apr 2022 11:51:11 +0530 Subject: [PATCH 07/34] Convert sync-perm command to click and make tests compatible. 
--- airflow/cli/__main__.py | 1 + airflow/cli/commands/sync_perm.py | 36 +++++++++++++ tests/cli/commands/test_sync_perm_command.py | 54 ++++++++++++++------ 3 files changed, 75 insertions(+), 16 deletions(-) create mode 100644 airflow/cli/commands/sync_perm.py diff --git a/airflow/cli/__main__.py b/airflow/cli/__main__.py index 2dc38e2e69684..21a564595978f 100644 --- a/airflow/cli/__main__.py +++ b/airflow/cli/__main__.py @@ -21,6 +21,7 @@ from airflow.cli.commands import db # noqa: F401 from airflow.cli.commands import scheduler # noqa: F401 from airflow.cli.commands import standalone # noqa: F401 +from airflow.cli.commands import sync_perm # noqa: F401 from airflow.cli.commands import triggerer # noqa: F401 from airflow.cli.commands import webserver # noqa: F401 diff --git a/airflow/cli/commands/sync_perm.py b/airflow/cli/commands/sync_perm.py new file mode 100644 index 0000000000000..8272f53417bde --- /dev/null +++ b/airflow/cli/commands/sync_perm.py @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Sync permission command""" +import click + +from airflow.cli import airflow_cmd +from airflow.www.app import cached_app + + +@airflow_cmd.command("sync-perm") +@click.option("--include-dags", is_flag=True, help="If passed, DAG specific permissions will also be synced.") +def sync_perm(include_dags): + """Updates permissions for existing roles and DAGs""" + appbuilder = cached_app().appbuilder + print('Updating actions and resources for all existing roles') + # Add missing permissions for all the Base Views _before_ syncing/creating roles + appbuilder.add_permissions(update_perms=True) + appbuilder.sm.sync_roles() + if include_dags: + print('Updating permission on all DAG views') + appbuilder.sm.create_dag_specific_permissions() diff --git a/tests/cli/commands/test_sync_perm_command.py b/tests/cli/commands/test_sync_perm_command.py index 9fff6d686fa19..a1df6793d247f 100644 --- a/tests/cli/commands/test_sync_perm_command.py +++ b/tests/cli/commands/test_sync_perm_command.py @@ -16,37 +16,59 @@ # specific language governing permissions and limitations # under the License. 
# -import unittest from unittest import mock +import pytest +from click.testing import CliRunner + from airflow.cli import cli_parser +from airflow.cli.__main__ import airflow_cmd from airflow.cli.commands import sync_perm_command -class TestCliSyncPerm(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.parser = cli_parser.get_parser() +@pytest.fixture +def setup_test(request): + if request.param == "click": + module = "airflow.cli.commands.sync_perm.cached_app" + else: + module = "airflow.cli.commands.sync_perm_command.cached_app" - @mock.patch("airflow.cli.commands.sync_perm_command.cached_app") - def test_cli_sync_perm(self, mock_cached_app): - appbuilder = mock_cached_app.return_value.appbuilder + with mock.patch(module) as mocked_cached_app: + appbuilder = mocked_cached_app.return_value.appbuilder appbuilder.sm = mock.Mock() + yield request.param, appbuilder + - args = self.parser.parse_args(['sync-perm']) - sync_perm_command.sync_perm(args) +class TestCliSyncPerm: + @pytest.mark.parametrize("setup_test", ["click", "argparse"], indirect=True) + def test_cli_sync_perm_1(self, setup_test): + parser, appbuilder = setup_test + cli_args = ['sync-perm'] + + if parser == "click": + runner = CliRunner() + runner.invoke(airflow_cmd, cli_args) + else: + parser = cli_parser.get_parser() + args = parser.parse_args(cli_args) + sync_perm_command.sync_perm(args) appbuilder.add_permissions.assert_called_once_with(update_perms=True) appbuilder.sm.sync_roles.assert_called_once_with() appbuilder.sm.create_dag_specific_permissions.assert_not_called() - @mock.patch("airflow.cli.commands.sync_perm_command.cached_app") - def test_cli_sync_perm_include_dags(self, mock_cached_app): - appbuilder = mock_cached_app.return_value.appbuilder - appbuilder.sm = mock.Mock() + @pytest.mark.parametrize("setup_test", ["click", "argparse"], indirect=True) + def test_cli_sync_perm_include_dags(self, setup_test): + parser, appbuilder = setup_test + cli_args = ['sync-perm', '--include-dags'] - args = self.parser.parse_args(['sync-perm', '--include-dags']) - sync_perm_command.sync_perm(args) + if parser == "click": + runner = CliRunner() + runner.invoke(airflow_cmd, cli_args) + else: + parser = cli_parser.get_parser() + args = parser.parse_args(cli_args) + sync_perm_command.sync_perm(args) appbuilder.add_permissions.assert_called_once_with(update_perms=True) appbuilder.sm.sync_roles.assert_called_once_with() From 3dabb718b1ee5032d21a14f082d2df0fe12ef52c Mon Sep 17 00:00:00 2001 From: Karthikeyan Singaravelan Date: Sun, 3 Apr 2022 12:23:22 +0530 Subject: [PATCH 08/34] Match doc with original command. 
--- airflow/cli/commands/sync_perm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/cli/commands/sync_perm.py b/airflow/cli/commands/sync_perm.py index 8272f53417bde..436a122c62c74 100644 --- a/airflow/cli/commands/sync_perm.py +++ b/airflow/cli/commands/sync_perm.py @@ -25,7 +25,7 @@ @airflow_cmd.command("sync-perm") @click.option("--include-dags", is_flag=True, help="If passed, DAG specific permissions will also be synced.") def sync_perm(include_dags): - """Updates permissions for existing roles and DAGs""" + """Update permissions for existing roles and optionally DAGs""" appbuilder = cached_app().appbuilder print('Updating actions and resources for all existing roles') # Add missing permissions for all the Base Views _before_ syncing/creating roles From 0392107def2d3c9181a79ee65384a7e4c4a0c3ea Mon Sep 17 00:00:00 2001 From: Karthikeyan Singaravelan Date: Sun, 3 Apr 2022 14:09:06 +0530 Subject: [PATCH 09/34] Use local imports instead of toplevel imports to optimise startup time. --- airflow/cli/commands/db.py | 34 +++++++++++++++++++++++++----- airflow/cli/commands/scheduler.py | 3 ++- airflow/cli/commands/standalone.py | 11 ++++++---- airflow/cli/commands/triggerer.py | 3 ++- airflow/cli/commands/webserver.py | 3 ++- 5 files changed, 42 insertions(+), 12 deletions(-) diff --git a/airflow/cli/commands/db.py b/airflow/cli/commands/db.py index 080df6f0b8e20..bf7357d332d84 100644 --- a/airflow/cli/commands/db.py +++ b/airflow/cli/commands/db.py @@ -27,9 +27,7 @@ from airflow import settings from airflow.cli import airflow_cmd, click_dry_run, click_verbose, click_yes from airflow.exceptions import AirflowException -from airflow.utils import cli as cli_utils, db as db_utils -from airflow.utils.db import REVISION_HEADS_MAP -from airflow.utils.db_cleanup import config_dict, run_cleanup +from airflow.utils import cli as cli_utils from airflow.utils.process_utils import execute_interactive click_revision = click.option( @@ -72,6 +70,8 @@ def db(ctx): @click.pass_context def db_init(ctx): """Initializes the metadata database""" + from airflow.utils import db as db_utils + console = Console() console.print(f"DB: {settings.engine.url}") db_utils.initdb() @@ -88,6 +88,8 @@ def db_init(ctx): ) def check_migrations(ctx, migration_wait_timeout): """Function to wait for all airflow migrations to complete. Used for launching airflow in k8s""" + from airflow.utils import db as db_utils + console = Console() console.print(f"Waiting for {migration_wait_timeout}s") db_utils.check_migrations(timeout=migration_wait_timeout) @@ -98,9 +100,12 @@ def check_migrations(ctx, migration_wait_timeout): @click_yes def db_reset(ctx, yes=False): """Burn down and rebuild the metadata database""" + console = Console() console.print(f"DB: {settings.engine.url}") if yes or click.confirm("This will drop existing tables if they exist. Proceed? 
(y/n)"): + from airflow.utils import db as db_utils + db_utils.resetdb() else: console.print("Cancelled") @@ -125,6 +130,8 @@ def wrapper(ctx, revision, version, from_revision, from_version, *_args, show_sq if from_version is not None: if parse_version(from_version) < parse_version('2.0.0'): raise SystemExit("--from-version must be greater than or equal to 2.0.0") + from airflow.utils.db import REVISION_HEADS_MAP + from_revision = REVISION_HEADS_MAP.get(from_version) if not from_revision: raise SystemExit(f"Unknown version {from_version!r} supplied as `--from-version`.") @@ -176,6 +183,8 @@ def upgrade(ctx, revision, version, from_revision, from_version, show_sql_only=F "operation. Proceed? (y/n)\n" ) ): + from airflow.utils import db as db_utils + db_utils.upgradedb(to_revision=revision, from_revision=from_revision, show_sql_only=show_sql_only) if not show_sql_only: console.print("Upgrades done") @@ -218,6 +227,8 @@ def downgrade(ctx, revision, version, from_revision, from_version, show_sql_only "operation. Proceed? (y/n)\n" ) ): + from airflow.utils import db as db_utils + db_utils.downgrade(to_revision=revision, from_revision=from_revision, show_sql_only=show_sql_only) if not show_sql_only: console.print("Downgrades done") @@ -283,13 +294,24 @@ def shell(ctx): @cli_utils.action_cli(check_db=False) def check(ctx, migration_wait_timeout): """Runs a check command that checks if db is reachable""" + from airflow.utils import db as db_utils + console = Console() console.print(f"Waiting for {migration_wait_timeout}s") db_utils.check_migrations(timeout=migration_wait_timeout) # lazily imported by CLI parser for `help` command -all_tables = sorted(config_dict) +# Create a custom class that emulates a callable since click validates +# non-callable and make __str__ to return output for lazy processing. 
+class _CleanTableDefault: + def __call__(self): + pass + + def __str__(self): + from airflow.utils.db_cleanup import config_dict + + return str(sorted(config_dict)) @db.command('cleanup') @@ -298,7 +320,7 @@ def check(ctx, migration_wait_timeout): '-t', '--tables', multiple=True, - default=all_tables, + default=_CleanTableDefault(), show_default=True, help=( "Table names to perform maintenance on (use comma-separated list).\n" @@ -319,6 +341,8 @@ def check(ctx, migration_wait_timeout): @cli_utils.action_cli(check_db=False) def cleanup_tables(ctx, tables, clean_before_timestamp, dry_run, verbose, yes): """Purge old records in metastore tables""" + from airflow.utils.db_cleanup import run_cleanup + split_tables = [] for table in tables: split_tables.extend(table.split(',')) diff --git a/airflow/cli/commands/scheduler.py b/airflow/cli/commands/scheduler.py index e9524f3fe62f7..42a46ead05bde 100644 --- a/airflow/cli/commands/scheduler.py +++ b/airflow/cli/commands/scheduler.py @@ -36,12 +36,13 @@ click_subdir, ) from airflow.configuration import conf -from airflow.jobs.scheduler_job import SchedulerJob from airflow.utils import cli as cli_utils from airflow.utils.cli import process_subdir, setup_locations, setup_logging, sigint_handler, sigquit_handler def _create_scheduler_job(subdir, num_runs, do_pickle): + from airflow.jobs.scheduler_job import SchedulerJob + job = SchedulerJob( subdir=process_subdir(subdir), num_runs=num_runs, diff --git a/airflow/cli/commands/standalone.py b/airflow/cli/commands/standalone.py index 0046cd45b8ba6..dcebdfc2f3dd1 100644 --- a/airflow/cli/commands/standalone.py +++ b/airflow/cli/commands/standalone.py @@ -32,10 +32,6 @@ from airflow.cli import airflow_cmd from airflow.configuration import AIRFLOW_HOME, conf from airflow.executors import executor_constants -from airflow.jobs.scheduler_job import SchedulerJob -from airflow.jobs.triggerer_job import TriggererJob -from airflow.utils import db -from airflow.www.app import cached_app @airflow_cmd.command('standalone') @@ -179,6 +175,8 @@ def calculate_env(self): def initialize_database(self): """Makes sure all the tables are created.""" # Set up DB tables + from airflow.utils import db + self.print_output("standalone", "Checking database is initialized") db.initdb() self.print_output("standalone", "Database ready") @@ -188,6 +186,8 @@ def initialize_database(self): # server. Thus, we make a random password and store it in AIRFLOW_HOME, # with the reasoning that if you can read that directory, you can see # the database credentials anyway. + from airflow.www.app import cached_app + appbuilder = cached_app().appbuilder user_exists = appbuilder.sm.find_user("admin") password_path = os.path.join(AIRFLOW_HOME, "standalone_admin_password.txt") @@ -219,6 +219,9 @@ def is_ready(self): Detects when all Airflow components are ready to serve. For now, it's simply time-based. 
""" + from airflow.jobs.scheduler_job import SchedulerJob + from airflow.jobs.triggerer_job import TriggererJob + return ( self.port_open(self.web_server_port) and self.job_running(SchedulerJob) diff --git a/airflow/cli/commands/triggerer.py b/airflow/cli/commands/triggerer.py index 38d93d8f11399..c2365e2656104 100644 --- a/airflow/cli/commands/triggerer.py +++ b/airflow/cli/commands/triggerer.py @@ -25,7 +25,6 @@ from airflow import settings from airflow.cli import airflow_cmd, click_daemon, click_log_file, click_pid, click_stderr, click_stdout -from airflow.jobs.triggerer_job import TriggererJob from airflow.utils import cli as cli_utils from airflow.utils.cli import setup_locations, setup_logging, sigint_handler, sigquit_handler @@ -44,6 +43,8 @@ @cli_utils.action_cli def triggerer(ctx, capacity, daemon_, log_file, pid, stderr, stdout): """Starts Airflow Triggerer""" + from airflow.jobs.triggerer_job import TriggererJob + console = Console() settings.MASK_SECRETS_IN_LOGS = True console.print(settings.HEADER) diff --git a/airflow/cli/commands/webserver.py b/airflow/cli/commands/webserver.py index a675008bae0a9..18b782726bbfd 100644 --- a/airflow/cli/commands/webserver.py +++ b/airflow/cli/commands/webserver.py @@ -51,7 +51,6 @@ from airflow.utils.cli import setup_locations, setup_logging from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.process_utils import check_if_pidfile_process_is_running -from airflow.www.app import create_app log = logging.getLogger(__name__) @@ -421,6 +420,8 @@ def webserver( console = Console() console.print(settings.HEADER) + from airflow.www.app import create_app + # Check for old/insecure config, and fail safe (i.e. don't launch) if the config is wildly insecure. if conf.get('webserver', 'secret_key') == 'temporary_key': from rich import print as rich_print From 129c46efd96a8946ad0264de55ec45a75db3fa71 Mon Sep 17 00:00:00 2001 From: Karthikeyan Singaravelan Date: Sun, 3 Apr 2022 14:34:18 +0530 Subject: [PATCH 10/34] Use local import instead of toplevel imports to optimise startup time. --- airflow/cli/commands/sync_perm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/airflow/cli/commands/sync_perm.py b/airflow/cli/commands/sync_perm.py index 436a122c62c74..ea856e07d21de 100644 --- a/airflow/cli/commands/sync_perm.py +++ b/airflow/cli/commands/sync_perm.py @@ -19,13 +19,14 @@ import click from airflow.cli import airflow_cmd -from airflow.www.app import cached_app @airflow_cmd.command("sync-perm") @click.option("--include-dags", is_flag=True, help="If passed, DAG specific permissions will also be synced.") def sync_perm(include_dags): """Update permissions for existing roles and optionally DAGs""" + from airflow.www.app import cached_app + appbuilder = cached_app().appbuilder print('Updating actions and resources for all existing roles') # Add missing permissions for all the Base Views _before_ syncing/creating roles From 71356ea141f8f8b161cef27e4b4d822fd36032ed Mon Sep 17 00:00:00 2001 From: Karthikeyan Singaravelan Date: Thu, 7 Apr 2022 11:29:26 +0530 Subject: [PATCH 11/34] Use rich.print. --- airflow/cli/commands/sync_perm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/cli/commands/sync_perm.py b/airflow/cli/commands/sync_perm.py index ea856e07d21de..46819c2902e7f 100644 --- a/airflow/cli/commands/sync_perm.py +++ b/airflow/cli/commands/sync_perm.py @@ -17,6 +17,7 @@ # under the License. 
"""Sync permission command""" import click +from rich import print from airflow.cli import airflow_cmd From 39cd62e078d170b1150fd1886fdba9e6916574ea Mon Sep 17 00:00:00 2001 From: hankehly Date: Wed, 13 Apr 2022 08:21:59 +0900 Subject: [PATCH 12/34] Add cheat_sheet module --- airflow/cli/__main__.py | 1 + airflow/cli/commands/cheat_sheet.py | 65 +++++++++++++++++++++++++++++ airflow/cli/commands/db.py | 12 +++--- airflow/utils/cli.py | 26 ++++++++++++ 4 files changed, 98 insertions(+), 6 deletions(-) create mode 100644 airflow/cli/commands/cheat_sheet.py diff --git a/airflow/cli/__main__.py b/airflow/cli/__main__.py index 2dc38e2e69684..32d5187b018ea 100644 --- a/airflow/cli/__main__.py +++ b/airflow/cli/__main__.py @@ -18,6 +18,7 @@ # under the License. from airflow.cli import airflow_cmd +from airflow.cli.commands import cheat_sheet # noqa: F401 from airflow.cli.commands import db # noqa: F401 from airflow.cli.commands import scheduler # noqa: F401 from airflow.cli.commands import standalone # noqa: F401 diff --git a/airflow/cli/commands/cheat_sheet.py b/airflow/cli/commands/cheat_sheet.py new file mode 100644 index 0000000000000..7dd479816f93c --- /dev/null +++ b/airflow/cli/commands/cheat_sheet.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from typing import List, Optional + +import rich_click as click + +from airflow.cli import airflow_cmd, click_verbose +from airflow.cli.simple_table import AirflowConsole, SimpleTable +from airflow.utils.cli import suppress_logs_and_warning_click_compatible + + +@airflow_cmd.command("cheat-sheet") +@click_verbose +@suppress_logs_and_warning_click_compatible +def cheat_sheet(verbose): + """Display cheat-sheet""" + display_commands_index() + + +def display_commands_index(): + def display_recursive( + prefix: List[str], + command_group: click.Group, + help_msg: Optional[str] = None, + help_msg_length: int = 88, + ): + actions: List[click.Command] = [] + groups: List[click.Group] = [] + for command in command_group.commands.values(): + if isinstance(command, click.Group): + groups.append(command) + else: + actions.append(command) + + console = AirflowConsole() + if actions: + table = SimpleTable(title=help_msg or "Miscellaneous commands") + table.add_column(width=40) + table.add_column() + for action_command in sorted(actions, key=lambda d: d.name): + help_str = action_command.get_short_help_str(limit=help_msg_length) + table.add_row(" ".join([*prefix, action_command.name]), help_str) + console.print(table) + + if groups: + for group_command in sorted(groups, key=lambda d: d.name): + group_prefix = [*prefix, group_command.name] + help_str = group_command.get_short_help_str(limit=help_msg_length) + display_recursive(group_prefix, group_command, help_str) + + display_recursive(["airflow"], airflow_cmd) diff --git a/airflow/cli/commands/db.py b/airflow/cli/commands/db.py index bf7357d332d84..8e52850dd4d37 100644 --- a/airflow/cli/commands/db.py +++ b/airflow/cli/commands/db.py @@ -63,13 +63,13 @@ @airflow_cmd.group() @click.pass_context def db(ctx): - """Commands for the metadata database""" + """Database operations""" @db.command('init') @click.pass_context def db_init(ctx): - """Initializes the metadata database""" + """Initialize the metadata database""" from airflow.utils import db as db_utils console = Console() @@ -87,7 +87,7 @@ def db_init(ctx): help="This command will wait for up to this time, specified in seconds", ) def check_migrations(ctx, migration_wait_timeout): - """Function to wait for all airflow migrations to complete. Used for launching airflow in k8s""" + """Wait for all airflow migrations to complete (used for launching airflow in k8s)""" from airflow.utils import db as db_utils console = Console() @@ -289,11 +289,11 @@ def shell(ctx): '--migration-wait-timeout', type=int, default=60, - help="Tmeout to wait for the database to migrate", + help="Timeout to wait for the database to migrate", ) @cli_utils.action_cli(check_db=False) def check(ctx, migration_wait_timeout): - """Runs a check command that checks if db is reachable""" + """Check if the database can be reached""" from airflow.utils import db as db_utils console = Console() @@ -314,7 +314,7 @@ def __str__(self): return str(sorted(config_dict)) -@db.command('cleanup') +@db.command('clean') @click.pass_context @click.option( '-t', diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py index 887dfcf1f827c..c3779becc216a 100644 --- a/airflow/utils/cli.py +++ b/airflow/utils/cli.py @@ -335,6 +335,32 @@ def _wrapper(*args, **kwargs): return cast(T, _wrapper) +def suppress_logs_and_warning_click_compatible(f: T) -> T: + """ + Click compatible version of suppress_logs_and_warning. + Place after click_verbose decorator. + + Decorator to suppress logging and warning messages + in cli functions. 
+ """ + + @functools.wraps(f) + def _wrapper(*args, **kwargs): + if kwargs.get("verbose"): + f(*args, **kwargs) + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + logging.disable(logging.CRITICAL) + try: + f(*args, **kwargs) + finally: + # logging output again depends on the effective + # levels of individual loggers + logging.disable(logging.NOTSET) + + return cast(T, _wrapper) + def get_config_with_source(include_default: bool = False) -> str: """Return configuration along with source for each option.""" config_dict = conf.as_dict(display_source=True) From b29cfb99a0f52b57a391965651a5b408634514b8 Mon Sep 17 00:00:00 2001 From: hankehly Date: Thu, 14 Apr 2022 14:06:32 +0900 Subject: [PATCH 13/34] De-nest display_commands_index function --- airflow/cli/commands/cheat_sheet.py | 63 ++++++++++++++--------------- 1 file changed, 30 insertions(+), 33 deletions(-) diff --git a/airflow/cli/commands/cheat_sheet.py b/airflow/cli/commands/cheat_sheet.py index 7dd479816f93c..4e4a049310968 100644 --- a/airflow/cli/commands/cheat_sheet.py +++ b/airflow/cli/commands/cheat_sheet.py @@ -28,38 +28,35 @@ @suppress_logs_and_warning_click_compatible def cheat_sheet(verbose): """Display cheat-sheet""" - display_commands_index() - - -def display_commands_index(): - def display_recursive( - prefix: List[str], - command_group: click.Group, - help_msg: Optional[str] = None, - help_msg_length: int = 88, - ): - actions: List[click.Command] = [] - groups: List[click.Group] = [] - for command in command_group.commands.values(): - if isinstance(command, click.Group): - groups.append(command) - else: - actions.append(command) - - console = AirflowConsole() - if actions: - table = SimpleTable(title=help_msg or "Miscellaneous commands") - table.add_column(width=40) - table.add_column() - for action_command in sorted(actions, key=lambda d: d.name): - help_str = action_command.get_short_help_str(limit=help_msg_length) - table.add_row(" ".join([*prefix, action_command.name]), help_str) - console.print(table) + display_recursive(["airflow"], airflow_cmd) - if groups: - for group_command in sorted(groups, key=lambda d: d.name): - group_prefix = [*prefix, group_command.name] - help_str = group_command.get_short_help_str(limit=help_msg_length) - display_recursive(group_prefix, group_command, help_str) - display_recursive(["airflow"], airflow_cmd) +def display_recursive( + prefix: List[str], + command_group: click.Group, + help_msg: Optional[str] = None, + help_msg_length: int = 88, +): + actions: List[click.Command] = [] + groups: List[click.Group] = [] + for command in command_group.commands.values(): + if isinstance(command, click.Group): + groups.append(command) + else: + actions.append(command) + + console = AirflowConsole() + if actions: + table = SimpleTable(title=help_msg or "Miscellaneous commands") + table.add_column(width=40) + table.add_column() + for action_command in sorted(actions, key=lambda d: d.name): + help_str = action_command.get_short_help_str(limit=help_msg_length) + table.add_row(" ".join([*prefix, action_command.name]), help_str) + console.print(table) + + if groups: + for group_command in sorted(groups, key=lambda d: d.name): + group_prefix = [*prefix, group_command.name] + help_str = group_command.get_short_help_str(limit=help_msg_length) + display_recursive(group_prefix, group_command, help_str) From 4bcaa30ec7b169b9ffc5e464f513c45391e9ee26 Mon Sep 17 00:00:00 2001 From: hankehly Date: Fri, 15 Apr 2022 09:14:59 +0900 Subject: [PATCH 14/34] Add click-compatible celery command 
--- airflow/cli/__main__.py | 1 + airflow/cli/commands/celery.py | 323 +++++++++++++++++++++++++++++++++ 2 files changed, 324 insertions(+) create mode 100644 airflow/cli/commands/celery.py diff --git a/airflow/cli/__main__.py b/airflow/cli/__main__.py index 2dc38e2e69684..0bd2c1aa5d2fe 100644 --- a/airflow/cli/__main__.py +++ b/airflow/cli/__main__.py @@ -18,6 +18,7 @@ # under the License. from airflow.cli import airflow_cmd +from airflow.cli.commands import celery # noqa: F401 from airflow.cli.commands import db # noqa: F401 from airflow.cli.commands import scheduler # noqa: F401 from airflow.cli.commands import standalone # noqa: F401 diff --git a/airflow/cli/commands/celery.py b/airflow/cli/commands/celery.py new file mode 100644 index 0000000000000..aa1b543028465 --- /dev/null +++ b/airflow/cli/commands/celery.py @@ -0,0 +1,323 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from multiprocessing import Process +from typing import Optional + +import daemon +import psutil +import rich_click as click +import sqlalchemy.exc +from celery import maybe_patch_concurrency +from daemon.pidfile import TimeoutPIDLockFile +from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile + +from airflow import settings +from airflow.cli import airflow_cmd, click_daemon, click_log_file, click_pid, click_stderr, click_stdout +from airflow.configuration import conf +from airflow.executors.celery_executor import app as celery_app +from airflow.utils.cli import setup_locations, setup_logging +from airflow.utils.serve_logs import serve_logs + +WORKER_PROCESS_NAME = "worker" + + +click_flower_host = click.option( + "-H", + "--hostname", + default=conf.get("celery", "FLOWER_HOST"), + help="Set the hostname on which to run the server", +) +click_flower_port = click.option( + "-p", + "--port", + default=conf.get("celery", "FLOWER_PORT"), + type=int, + help="The port on which to run the server", +) +click_flower_broker_api = click.option("-a", "--broker-api", help="Broker API") +click_flower_url_prefix = click.option( + "-u", "--url-prefix", default=conf.get("celery", "FLOWER_URL_PREFIX"), help="URL prefix for Flower" +) +click_flower_basic_auth = click.option( + "-A", + "--basic-auth", + default=conf.get("celery", "FLOWER_BASIC_AUTH"), + help=( + "Securing Flower with Basic Authentication. " + "Accepts user:password pairs separated by a comma. 
" + "Example: flower_basic_auth = user1:password1,user2:password2" + ), +) +click_flower_conf = click.option("-c", "--flower-conf", help="Configuration file for flower") +click_worker_autoscale = click.option( + "-a", "--autoscale", help="Minimum and Maximum number of worker to autoscale" +) +click_worker_skip_serve_logs = click.option( + "-s", + "--skip-serve-logs", + is_flag=True, + default=False, + help="Don't start the serve logs process along with the workers", +) +click_worker_queues = click.option( + "-q", + "--queues", + default=conf.get("operators", "DEFAULT_QUEUE"), + help="Comma delimited list of queues to serve", +) +click_worker_concurrency = click.option( + "-c", + "--concurrency", + type=int, + default=conf.get("celery", "worker_concurrency"), + help="The number of worker processes", +) +click_worker_hostname = click.option( + "-H", + "--celery-hostname", + help="Set the hostname of celery worker if you have multiple workers on a single machine", +) +click_worker_umask = click.option( + "-u", + "--umask", + default=conf.get("celery", "worker_umask"), + help="Set the umask of celery worker in daemon mode", +) +click_worker_without_mingle = click.option( + "--without-mingle", is_flag=True, default=False, help="Don't synchronize with other workers at start-up" +) +click_worker_without_gossip = click.option( + "--without-gossip", is_flag=True, default=False, help="Don't subscribe to other workers events" +) + + +@airflow_cmd.group() +def celery(): + """Celery components""" + pass + + +@celery.command() +@click_flower_host +@click_flower_port +@click_flower_broker_api +@click_flower_url_prefix +@click_flower_basic_auth +@click_flower_conf +@click_stdout +@click_stderr +@click_pid +@click_daemon +@click_log_file +def flower( + hostname, port, broker_api, url_prefix, basic_auth, flower_conf, stdout, stderr, pid, daemon_, log_file +): + """Starts Flower, Celery monitoring tool""" + options = [ + "flower", + conf.get("celery", "BROKER_URL"), + f"--address={hostname}", + f"--port={port}", + ] + + if broker_api: + options.append(f"--broker-api={broker_api}") + + if url_prefix: + options.append(f"--url-prefix={url_prefix}") + + if basic_auth: + options.append(f"--basic-auth={basic_auth}") + + if flower_conf: + options.append(f"--conf={flower_conf}") + + if daemon_: + pidfile, stdout, stderr, _ = setup_locations( + process="flower", + pid=pid, + stdout=stdout, + stderr=stderr, + log=log_file, + ) + with open(stdout, "w+") as stdout, open(stderr, "w+") as stderr: + ctx = daemon.DaemonContext( + pidfile=TimeoutPIDLockFile(pidfile, -1), + stdout=stdout, + stderr=stderr, + ) + with ctx: + celery_app.start(options) + else: + celery_app.start(options) + + +def _serve_logs(skip_serve_logs: bool = False) -> Optional[Process]: + """Starts serve_logs sub-process""" + if skip_serve_logs is False: + sub_proc = Process(target=serve_logs) + sub_proc.start() + return sub_proc + return None + + +def _run_worker(options, skip_serve_logs): + sub_proc = _serve_logs(skip_serve_logs) + try: + celery_app.worker_main(options) + finally: + if sub_proc: + sub_proc.terminate() + + +@celery.command() +@click_pid +@click_daemon +@click_stdout +@click_stderr +@click_log_file +@click_worker_autoscale +@click_worker_skip_serve_logs +@click_worker_queues +@click_worker_concurrency +@click_worker_hostname +@click_worker_umask +@click_worker_without_mingle +@click_worker_without_gossip +def worker( + pid, + daemon_, + stdout, + stderr, + log_file, + autoscale, + skip_serve_logs, + queues, + concurrency, + 
celery_hostname, + umask, + without_mingle, + without_gossip, +): + """Starts Airflow Celery worker""" + # Disable connection pool so that celery worker does not hold an unnecessary db connection + settings.reconfigure_orm(disable_connection_pool=True) + if not settings.validate_session(): + raise SystemExit("Worker exiting, database connection precheck failed.") + + if autoscale is None and conf.has_option("celery", "worker_autoscale"): + autoscale = conf.get("celery", "worker_autoscale") + + # Setup locations + pid_file_path, stdout, stderr, log_file = setup_locations( + process=WORKER_PROCESS_NAME, + pid=pid, + stdout=stdout, + stderr=stderr, + log=log_file, + ) + + if hasattr(celery_app.backend, 'ResultSession'): + # Pre-create the database tables now, otherwise SQLA via Celery has a + # race condition where one of the subprocesses can die with "Table + # already exists" error, because SQLA checks for which tables exist, + # then issues a CREATE TABLE, rather than doing CREATE TABLE IF NOT + # EXISTS + try: + session = celery_app.backend.ResultSession() + session.close() + except sqlalchemy.exc.IntegrityError: + # At least on postgres, trying to create a table that already exist + # gives a unique constraint violation or the + # "pg_type_typname_nsp_index" table. If this happens we can ignore + # it, we raced to create the tables and lost. + pass + + # backwards-compatible: https://github.com/apache/airflow/pull/21506#pullrequestreview-879893763 + celery_log_level = conf.get('logging', 'CELERY_LOGGING_LEVEL') + if not celery_log_level: + celery_log_level = conf.get('logging', 'LOGGING_LEVEL') + # Setup Celery worker + options = [ + 'worker', + '-O', + 'fair', + '--queues', + queues, + '--concurrency', + concurrency, + '--hostname', + celery_hostname, + '--loglevel', + celery_log_level, + '--pidfile', + pid_file_path, + ] + if autoscale: + options.extend(['--autoscale', autoscale]) + if without_mingle: + options.append('--without-mingle') + if without_gossip: + options.append('--without-gossip') + + if conf.has_option("celery", "pool"): + pool = conf.get("celery", "pool") + options.extend(["--pool", pool]) + # Celery pools of type eventlet and gevent use greenlets, which + # requires monkey patching the app: + # https://eventlet.net/doc/patching.html#monkey-patch + # Otherwise task instances hang on the workers and are never + # executed. 
+ maybe_patch_concurrency(['-P', pool]) + + if daemon_: + # Run Celery worker as daemon + handle = setup_logging(log_file) + + with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle: + ctx = daemon.DaemonContext( + files_preserve=[handle], + umask=int(umask, 8), + stdout=stdout_handle, + stderr=stderr_handle, + ) + with ctx: + _run_worker(options=options, skip_serve_logs=skip_serve_logs) + else: + # Run Celery worker in the same process + _run_worker(options=options, skip_serve_logs=skip_serve_logs) + + +@celery.command("stop") +@click_pid +def stop_worker(pid): + """Stop the Celery worker gracefully by sending SIGTERM to worker""" + # Read PID from file + if pid: + pid_file_path = pid + else: + pid_file_path, _, _, _ = setup_locations(process=WORKER_PROCESS_NAME) + pid = read_pid_from_pidfile(pid_file_path) + + # Send SIGTERM + if pid: + worker_process = psutil.Process(pid) + worker_process.terminate() + + # Remove pid file + remove_existing_pidfile(pid_file_path) From 908a05ab12349dcbf9b7c70420c3e0c6051fab34 Mon Sep 17 00:00:00 2001 From: hankehly Date: Thu, 14 Apr 2022 14:07:11 +0900 Subject: [PATCH 15/34] Add click-compatible cheat-sheet command unit test (amended 2022-04-16 13:37 JST) --- tests/cli/commands/test_cheat_sheet.py | 101 +++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 tests/cli/commands/test_cheat_sheet.py diff --git a/tests/cli/commands/test_cheat_sheet.py b/tests/cli/commands/test_cheat_sheet.py new file mode 100644 index 0000000000000..88610b394aadd --- /dev/null +++ b/tests/cli/commands/test_cheat_sheet.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import contextlib +import io +import unittest +from typing import List + +import rich_click as click + +from airflow.cli.commands.cheat_sheet import display_recursive + + +@click.group() +def mock_cli(): + """Mock cli""" + pass + + +@mock_cli.group() +def cmd_a(): + """Help text A""" + pass + + +@cmd_a.command("cmd_b") +def cmd_b(): + """Help text B""" + pass + + +@cmd_a.command("cmd_c") +def cmd_c(): + """Help text C""" + pass + + +@mock_cli.group() +def cmd_e(): + """Help text E""" + pass + + +@cmd_e.command("cmd_f") +def cmd_f(): + """Help text F""" + pass + + +@cmd_e.command("cmd_g") +def cmd_g(): + """Help text G""" + pass + + +@mock_cli.command() +def cmd_h(): + """Help text H""" + pass + + +SECTION_MISC = """\ +Miscellaneous commands +airflow cmd-h | Help text H +""" + +SECTION_A = """\ +Help text A +airflow cmd-a cmd_b | Help text B +airflow cmd-a cmd_c | Help text C +""" + +SECTION_E = """\ +Help text E +airflow cmd-e cmd_f | Help text F +airflow cmd-e cmd_g | Help text G +""" + + +class TestCheatSheet(unittest.TestCase): + def test_display_recursive_commands(self): + with contextlib.redirect_stdout(io.StringIO()) as stdout: + display_recursive(["airflow"], mock_cli) + output = stdout.getvalue() + assert SECTION_MISC in output + assert SECTION_A in output + assert SECTION_E in output From dc319ab55f8b045e745302824b71868d28aa11db Mon Sep 17 00:00:00 2001 From: hankehly Date: Mon, 18 Apr 2022 08:01:39 +0900 Subject: [PATCH 16/34] Mark cli path existence as optional --- airflow/cli/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/cli/__init__.py b/airflow/cli/__init__.py index 451c34f4c5df7..f34e144110595 100644 --- a/airflow/cli/__init__.py +++ b/airflow/cli/__init__.py @@ -61,7 +61,7 @@ click_log_file = click.option( "-l", "--log-file", - type=click.Path(exists=True, dir_okay=False, writable=True), + type=click.Path(exists=False, dir_okay=False, writable=True), help="Location of the log file", ) click_output = click.option( @@ -71,7 +71,7 @@ default="table", help="Output format.", ) -click_pid = click.option("--pid", type=click.Path(exists=True), help="PID file location") +click_pid = click.option("--pid", type=click.Path(exists=False), help="PID file location") click_start_date = click.option( "-s", "--start-date", @@ -80,12 +80,12 @@ ) click_stderr = click.option( "--stderr", - type=click.Path(exists=True, dir_okay=False, writable=True), + type=click.Path(exists=False, dir_okay=False, writable=True), help="Redirect stderr to this file", ) click_stdout = click.option( "--stdout", - type=click.Path(exists=True, dir_okay=False, writable=True), + type=click.Path(exists=False, dir_okay=False, writable=True), help="Redirect stdout to this file", ) click_subdir = click.option( From d9b4c864556bf9b614b83f334de0245ddc07b09c Mon Sep 17 00:00:00 2001 From: hankehly Date: Mon, 18 Apr 2022 08:03:39 +0900 Subject: [PATCH 17/34] Add click compatible celery command unit test module --- tests/cli/commands/test_celery.py | 355 ++++++++++++++++++++++++++++++ 1 file changed, 355 insertions(+) create mode 100644 tests/cli/commands/test_celery.py diff --git a/tests/cli/commands/test_celery.py b/tests/cli/commands/test_celery.py new file mode 100644 index 0000000000000..373b3bee4fef3 --- /dev/null +++ b/tests/cli/commands/test_celery.py @@ -0,0 +1,355 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import unittest +from tempfile import NamedTemporaryFile +from unittest import mock + +import pytest +import sqlalchemy +from click.testing import CliRunner + +import airflow +from airflow.cli.commands import celery as celery_command +from airflow.configuration import conf +from tests.test_utils.config import conf_vars + + +class TestWorkerPrecheck(unittest.TestCase): + @mock.patch('airflow.settings.validate_session') + def test_error(self, mock_validate_session): + """ + Test to verify the exit mechanism of airflow-worker cli + by mocking validate_session method + """ + mock_validate_session.return_value = False + runner = CliRunner() + result = runner.invoke(celery_command.worker, ["--queues", 1, "--concurrency", "1"]) + assert result.exit_code == 1 + assert result.output.strip() == "Worker exiting, database connection precheck failed." + + @conf_vars({('celery', 'worker_precheck'): 'False'}) + def test_worker_precheck_exception(self): + """ + Test to check the behaviour of validate_session method + when worker_precheck is absent in airflow configuration + """ + assert airflow.settings.validate_session() + + @mock.patch('sqlalchemy.orm.session.Session.execute') + @conf_vars({('celery', 'worker_precheck'): 'True'}) + def test_validate_session_dbapi_exception(self, mock_session): + """ + Test to validate connection failure scenario on SELECT 1 query + """ + mock_session.side_effect = sqlalchemy.exc.OperationalError("m1", "m2", "m3", "m4") + assert airflow.settings.validate_session() is False + + +@pytest.mark.integration("redis") +@pytest.mark.integration("rabbitmq") +@pytest.mark.backend("mysql", "postgres") +class TestWorkerServeLogs(unittest.TestCase): + @mock.patch('celery.platforms.check_privileges', return_value=0) + @mock.patch('airflow.cli.commands.celery.Process') + @mock.patch('airflow.cli.commands.celery.celery_app') + @conf_vars({("core", "executor"): "CeleryExecutor"}) + def test_serve_logs_on_worker_start(self, mock_celery_app, mock_process, mock_privil): + runner = CliRunner() + runner.invoke(celery_command.worker, ["--concurrency", "1"]) + mock_process.assert_called() + + @mock.patch('celery.platforms.check_privileges', return_value=0) + @mock.patch('airflow.cli.commands.celery.Process') + @mock.patch('airflow.cli.commands.celery.celery_app') + @conf_vars({("core", "executor"): "CeleryExecutor"}) + def test_skip_serve_logs_on_worker_start(self, mock_celery_app, mock_process, mock_privil): + runner = CliRunner() + runner.invoke(celery_command.worker, ["--concurrency", "1", "--skip-serve-logs"]) + mock_process.assert_not_called() + + +@pytest.mark.backend("mysql", "postgres") +class TestCeleryStopCommand(unittest.TestCase): + @mock.patch("airflow.cli.commands.celery.setup_locations") + @mock.patch("airflow.cli.commands.celery.psutil.Process") + @conf_vars({("core", "executor"): "CeleryExecutor"}) + 
def test_if_right_pid_is_read(self, mock_process, mock_setup_locations): + pid = "123" + # Calling stop_worker should delete the temporary pid file + with pytest.raises(FileNotFoundError), NamedTemporaryFile("w+") as f: + # Create pid file + f.write(pid) + f.flush() + # Setup mock + mock_setup_locations.return_value = (f.name, None, None, None) + # Check if works as expected + runner = CliRunner() + result = runner.invoke(celery_command.stop_worker) + assert result.exit_code == 0 + mock_process.assert_called_once_with(int(pid)) + mock_process.return_value.terminate.assert_called_once_with() + + @mock.patch("airflow.cli.commands.celery.read_pid_from_pidfile") + @mock.patch('airflow.cli.commands.celery.celery_app') + @mock.patch("airflow.cli.commands.celery.setup_locations") + @conf_vars({("core", "executor"): "CeleryExecutor"}) + def test_same_pid_file_is_used_in_start_and_stop( + self, mock_setup_locations, mock_celery_app, mock_read_pid_from_pidfile + ): + pid_file = "test_pid_file" + mock_setup_locations.return_value = (pid_file, None, None, None) + mock_read_pid_from_pidfile.return_value = None + runner = CliRunner() + + # Call worker + result_worker = runner.invoke(celery_command.worker, ["--skip-serve-logs"]) + assert result_worker.exit_code == 0 + assert mock_celery_app.worker_main.call_args + args, _ = mock_celery_app.worker_main.call_args + args_str = ' '.join(map(str, args[0])) + assert f'--pidfile {pid_file}' in args_str + + # Call stop + result_stop_worker = runner.invoke(celery_command.stop_worker) + assert result_stop_worker.exit_code == 0 + mock_read_pid_from_pidfile.assert_called_once_with(pid_file) + + @mock.patch("airflow.cli.commands.celery.remove_existing_pidfile") + @mock.patch("airflow.cli.commands.celery.read_pid_from_pidfile") + @mock.patch('airflow.cli.commands.celery.celery_app') + @mock.patch("airflow.cli.commands.celery.psutil.Process") + @mock.patch("airflow.cli.commands.celery.setup_locations") + @conf_vars({("core", "executor"): "CeleryExecutor"}) + def test_custom_pid_file_is_used_in_start_and_stop( + self, + mock_setup_locations, + mock_process, + mock_celery_app, + mock_read_pid_from_pidfile, + mock_remove_existing_pidfile, + ): + pid = "123" + runner = CliRunner() + with NamedTemporaryFile("w+") as pid_file: + # Create pid file + pid_file.write(pid) + pid_file.flush() + mock_setup_locations.return_value = (pid_file.name, None, None, None) + + # Call worker + result_worker = runner.invoke( + celery_command.worker, ["--skip-serve-logs", "--pid", pid_file.name] + ) + assert result_worker.exit_code == 0 + assert mock_celery_app.worker_main.call_args + args, _ = mock_celery_app.worker_main.call_args + args_str = ' '.join(map(str, args[0])) + assert f'--pidfile {pid_file.name}' in args_str + + # Call stop + result_stop_worker = runner.invoke(celery_command.stop_worker) + assert result_stop_worker.exit_code == 0 + + mock_read_pid_from_pidfile.assert_called_once_with(pid_file.name) + mock_process.return_value.terminate.assert_called() + mock_remove_existing_pidfile.assert_called_once_with(pid_file.name) + + +@pytest.mark.backend("mysql", "postgres") +class TestWorkerStart(unittest.TestCase): + @mock.patch("airflow.cli.commands.celery.setup_locations") + @mock.patch('airflow.cli.commands.celery.Process') + @mock.patch('airflow.cli.commands.celery.celery_app') + @conf_vars({("core", "executor"): "CeleryExecutor"}) + def test_worker_started_with_required_arguments(self, mock_celery_app, mock_popen, mock_locations): + pid_file = "pid_file" + 
mock_locations.return_value = (pid_file, None, None, None) + concurrency = '1' + celery_hostname = "celery_hostname" + queues = "queue" + autoscale = "2,5" + runner = CliRunner() + runner.invoke( + celery_command.worker, + [ + '--autoscale', + autoscale, + '--concurrency', + concurrency, + '--celery-hostname', + celery_hostname, + '--queues', + queues, + '--without-mingle', + '--without-gossip', + ], + ) + mock_celery_app.worker_main.assert_called_once_with( + [ + 'worker', + '-O', + 'fair', + '--queues', + queues, + '--concurrency', + int(concurrency), + '--hostname', + celery_hostname, + '--loglevel', + conf.get('logging', 'CELERY_LOGGING_LEVEL'), + '--pidfile', + pid_file, + '--autoscale', + autoscale, + '--without-mingle', + '--without-gossip', + '--pool', + 'prefork', + ] + ) + + +@pytest.mark.backend("mysql", "postgres") +class TestWorkerFailure(unittest.TestCase): + @mock.patch('airflow.cli.commands.celery.Process') + @mock.patch('airflow.cli.commands.celery.celery_app') + @conf_vars({("core", "executor"): "CeleryExecutor"}) + def test_worker_failure_gracefull_shutdown(self, mock_celery_app, mock_popen): + mock_celery_app.run.side_effect = Exception('Mock exception to trigger runtime error') + runner = CliRunner() + runner.invoke(celery_command.worker) + mock_popen().terminate.assert_called() + + +@pytest.mark.backend("mysql", "postgres") +class TestFlowerCommand(unittest.TestCase): + @mock.patch('airflow.cli.commands.celery.celery_app') + @conf_vars({("core", "executor"): "CeleryExecutor"}) + def test_run_command(self, mock_celery_app): + runner = CliRunner() + runner.invoke( + celery_command.flower, + [ + '--basic-auth', + 'admin:admin', + '--broker-api', + 'http://username:password@rabbitmq-server-name:15672/api/', + '--flower-conf', + 'flower_config', + '--hostname', + 'my-hostname', + '--port', + '3333', + '--url-prefix', + 'flower-monitoring', + ], + ) + mock_celery_app.start.assert_called_once_with( + [ + 'flower', + 'amqp://guest:guest@rabbitmq:5672/', + '--address=my-hostname', + '--port=3333', + '--broker-api=http://username:password@rabbitmq-server-name:15672/api/', + '--url-prefix=flower-monitoring', + '--basic-auth=admin:admin', + '--conf=flower_config', + ] + ) + + @mock.patch('airflow.cli.commands.celery.TimeoutPIDLockFile') + @mock.patch('airflow.cli.commands.celery.setup_locations') + @mock.patch('airflow.cli.commands.celery.daemon') + @mock.patch('airflow.cli.commands.celery.celery_app') + @conf_vars({("core", "executor"): "CeleryExecutor"}) + def test_run_command_daemon(self, mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file): + mock_setup_locations.return_value = ( + mock.MagicMock(name='pidfile'), + mock.MagicMock(name='stdout'), + mock.MagicMock(name='stderr'), + mock.MagicMock(name="INVALID"), + ) + args = [ + '--basic-auth', + 'admin:admin', + '--broker-api', + 'http://username:password@rabbitmq-server-name:15672/api/', + '--flower-conf', + 'flower_config', + '--hostname', + 'my-hostname', + '--log-file', + '/tmp/flower.log', + '--pid', + '/tmp/flower.pid', + '--port', + '3333', + '--stderr', + '/tmp/flower-stderr.log', + '--stdout', + '/tmp/flower-stdout.log', + '--url-prefix', + 'flower-monitoring', + '--daemon', + ] + runner = CliRunner() + mock_open = mock.mock_open() + with mock.patch('airflow.cli.commands.celery.open', mock_open): + result = runner.invoke(celery_command.flower, args) + assert result.exit_code == 0 + + mock_celery_app.start.assert_called_once_with( + [ + 'flower', + 'amqp://guest:guest@rabbitmq:5672/', + 
'--address=my-hostname', + '--port=3333', + '--broker-api=http://username:password@rabbitmq-server-name:15672/api/', + '--url-prefix=flower-monitoring', + '--basic-auth=admin:admin', + '--conf=flower_config', + ] + ) + assert mock_daemon.mock_calls == [ + mock.call.DaemonContext( + pidfile=mock_pid_file.return_value, + stderr=mock_open.return_value, + stdout=mock_open.return_value, + ), + mock.call.DaemonContext().__enter__(), + mock.call.DaemonContext().__exit__(None, None, None), + ] + + assert mock_setup_locations.mock_calls == [ + mock.call( + log='/tmp/flower.log', + pid='/tmp/flower.pid', + process='flower', + stderr='/tmp/flower-stderr.log', + stdout='/tmp/flower-stdout.log', + ) + ] + mock_pid_file.assert_has_calls([mock.call(mock_setup_locations.return_value[0], -1)]) + assert mock_open.mock_calls == [ + mock.call(mock_setup_locations.return_value[1], 'w+'), + mock.call().__enter__(), + mock.call(mock_setup_locations.return_value[2], 'w+'), + mock.call().__enter__(), + mock.call().__exit__(None, None, None), + mock.call().__exit__(None, None, None), + ] From 232bcb333e742ae4db80c4bbf1f8d634b981bd51 Mon Sep 17 00:00:00 2001 From: Hank Ehly Date: Wed, 27 Apr 2022 08:55:40 +0900 Subject: [PATCH 18/34] Apply suggestions from code review Co-authored-by: blag --- airflow/cli/__init__.py | 4 ++- airflow/cli/commands/celery.py | 46 +++++++++++++++++++++++++++------- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/airflow/cli/__init__.py b/airflow/cli/__init__.py index f34e144110595..38decfd3239c8 100644 --- a/airflow/cli/__init__.py +++ b/airflow/cli/__init__.py @@ -71,7 +71,7 @@ default="table", help="Output format.", ) -click_pid = click.option("--pid", type=click.Path(exists=False), help="PID file location") +click_pid = click.option("--pid", metavar="PID", type=click.Path(exists=False), help="PID file location") click_start_date = click.option( "-s", "--start-date", @@ -80,11 +80,13 @@ ) click_stderr = click.option( "--stderr", + metavar="STDERR", type=click.Path(exists=False, dir_okay=False, writable=True), help="Redirect stderr to this file", ) click_stdout = click.option( "--stdout", + metavar="STDOUT", type=click.Path(exists=False, dir_okay=False, writable=True), help="Redirect stdout to this file", ) diff --git a/airflow/cli/commands/celery.py b/airflow/cli/commands/celery.py index aa1b543028465..d2cf13669ab3f 100644 --- a/airflow/cli/commands/celery.py +++ b/airflow/cli/commands/celery.py @@ -36,36 +36,60 @@ WORKER_PROCESS_NAME = "worker" -click_flower_host = click.option( +click_flower_hostname = click.option( "-H", "--hostname", + metavar="HOSTNAME", default=conf.get("celery", "FLOWER_HOST"), help="Set the hostname on which to run the server", ) click_flower_port = click.option( "-p", "--port", + metavar="PORT", default=conf.get("celery", "FLOWER_PORT"), type=int, help="The port on which to run the server", ) -click_flower_broker_api = click.option("-a", "--broker-api", help="Broker API") +click_flower_broker_api = click.option( + "-a", + "--broker-api", + metavar="BROKER_API", + help=""" + Broker API URL + + Examples: + + postgresql://user:secret@host1:5432,host2:5433/otherdb?connect_timeout=10&application_name=myapp + + redis://localhost:6379/0 + + amqp://myuser:mypassword@localhost:5672/myvhost + + sqs://ABCDEFGHIJKLMNOPQRST:ZYXK7NiynGlTogH8Nj+P9nlE73sq3@ + """ +) click_flower_url_prefix = click.option( "-u", "--url-prefix", default=conf.get("celery", "FLOWER_URL_PREFIX"), help="URL prefix for Flower" ) click_flower_basic_auth = click.option( "-A", 
"--basic-auth", + metavar="BASIC_AUTH", default=conf.get("celery", "FLOWER_BASIC_AUTH"), - help=( - "Securing Flower with Basic Authentication. " - "Accepts user:password pairs separated by a comma. " - "Example: flower_basic_auth = user1:password1,user2:password2" + help=""" + Securing Flower with Basic Authentication. + Accepts user:password pairs separated by a comma. + + Example: + + --basic-auth user1:password1,user2:password2 + """ ), ) -click_flower_conf = click.option("-c", "--flower-conf", help="Configuration file for flower") +click_flower_conf = click.option("-c", "--flower-conf", metavar="FLOWER_CONF", help="Configuration file for flower") click_worker_autoscale = click.option( - "-a", "--autoscale", help="Minimum and Maximum number of worker to autoscale" + "-a", "--autoscale", metavar="AUTOSCALE", help="Minimum and Maximum number of worker to autoscale" ) click_worker_skip_serve_logs = click.option( "-s", @@ -77,12 +101,14 @@ click_worker_queues = click.option( "-q", "--queues", + metavar="QUEUES", default=conf.get("operators", "DEFAULT_QUEUE"), help="Comma delimited list of queues to serve", ) click_worker_concurrency = click.option( "-c", "--concurrency", + metavar="CONCURRENCY", type=int, default=conf.get("celery", "worker_concurrency"), help="The number of worker processes", @@ -90,11 +116,13 @@ click_worker_hostname = click.option( "-H", "--celery-hostname", + metavar="CELERY_HOSTNAME", help="Set the hostname of celery worker if you have multiple workers on a single machine", ) click_worker_umask = click.option( "-u", "--umask", + metavar="UMASK", default=conf.get("celery", "worker_umask"), help="Set the umask of celery worker in daemon mode", ) @@ -113,7 +141,7 @@ def celery(): @celery.command() -@click_flower_host +@click_flower_hostname @click_flower_port @click_flower_broker_api @click_flower_url_prefix From 487fe3a3657e578021da4768a7b18a8a643a729e Mon Sep 17 00:00:00 2001 From: hankehly Date: Wed, 27 Apr 2022 09:02:40 +0900 Subject: [PATCH 19/34] Inline celery cli command click option definitions --- airflow/cli/commands/celery.py | 141 ++++++++++++++------------------- 1 file changed, 61 insertions(+), 80 deletions(-) diff --git a/airflow/cli/commands/celery.py b/airflow/cli/commands/celery.py index d2cf13669ab3f..871f56c6645fb 100644 --- a/airflow/cli/commands/celery.py +++ b/airflow/cli/commands/celery.py @@ -36,14 +36,21 @@ WORKER_PROCESS_NAME = "worker" -click_flower_hostname = click.option( +@airflow_cmd.group() +def celery(): + """Celery components""" + pass + + +@celery.command() +@click.option( "-H", "--hostname", metavar="HOSTNAME", default=conf.get("celery", "FLOWER_HOST"), help="Set the hostname on which to run the server", ) -click_flower_port = click.option( +@click.option( "-p", "--port", metavar="PORT", @@ -51,7 +58,7 @@ type=int, help="The port on which to run the server", ) -click_flower_broker_api = click.option( +@click.option( "-a", "--broker-api", metavar="BROKER_API", @@ -60,19 +67,19 @@ Examples: - postgresql://user:secret@host1:5432,host2:5433/otherdb?connect_timeout=10&application_name=myapp + postgresql://user:secret@host1:5432,host2:5433/otherdb?connect_timeout=10&application_name=myapp redis://localhost:6379/0 - + amqp://myuser:mypassword@localhost:5672/myvhost - - sqs://ABCDEFGHIJKLMNOPQRST:ZYXK7NiynGlTogH8Nj+P9nlE73sq3@ - """ + + sqs://sqs.us-east-1.amazonaws.com:80 + """, ) -click_flower_url_prefix = click.option( +@click.option( "-u", "--url-prefix", default=conf.get("celery", "FLOWER_URL_PREFIX"), help="URL prefix for Flower" ) 
-click_flower_basic_auth = click.option( +@click.option( "-A", "--basic-auth", metavar="BASIC_AUTH", @@ -84,69 +91,9 @@ Example: --basic-auth user1:password1,user2:password2 - """ - ), -) -click_flower_conf = click.option("-c", "--flower-conf", metavar="FLOWER_CONF", help="Configuration file for flower") -click_worker_autoscale = click.option( - "-a", "--autoscale", metavar="AUTOSCALE", help="Minimum and Maximum number of worker to autoscale" -) -click_worker_skip_serve_logs = click.option( - "-s", - "--skip-serve-logs", - is_flag=True, - default=False, - help="Don't start the serve logs process along with the workers", -) -click_worker_queues = click.option( - "-q", - "--queues", - metavar="QUEUES", - default=conf.get("operators", "DEFAULT_QUEUE"), - help="Comma delimited list of queues to serve", -) -click_worker_concurrency = click.option( - "-c", - "--concurrency", - metavar="CONCURRENCY", - type=int, - default=conf.get("celery", "worker_concurrency"), - help="The number of worker processes", -) -click_worker_hostname = click.option( - "-H", - "--celery-hostname", - metavar="CELERY_HOSTNAME", - help="Set the hostname of celery worker if you have multiple workers on a single machine", -) -click_worker_umask = click.option( - "-u", - "--umask", - metavar="UMASK", - default=conf.get("celery", "worker_umask"), - help="Set the umask of celery worker in daemon mode", -) -click_worker_without_mingle = click.option( - "--without-mingle", is_flag=True, default=False, help="Don't synchronize with other workers at start-up" + """, ) -click_worker_without_gossip = click.option( - "--without-gossip", is_flag=True, default=False, help="Don't subscribe to other workers events" -) - - -@airflow_cmd.group() -def celery(): - """Celery components""" - pass - - -@celery.command() -@click_flower_hostname -@click_flower_port -@click_flower_broker_api -@click_flower_url_prefix -@click_flower_basic_auth -@click_flower_conf +@click.option("-c", "--flower-conf", metavar="FLOWER_CONF", help="Configuration file for flower") @click_stdout @click_stderr @click_pid @@ -219,14 +166,48 @@ def _run_worker(options, skip_serve_logs): @click_stdout @click_stderr @click_log_file -@click_worker_autoscale -@click_worker_skip_serve_logs -@click_worker_queues -@click_worker_concurrency -@click_worker_hostname -@click_worker_umask -@click_worker_without_mingle -@click_worker_without_gossip +@click.option( + "-a", "--autoscale", metavar="AUTOSCALE", help="Minimum and Maximum number of worker to autoscale" +) +@click.option( + "-s", + "--skip-serve-logs", + is_flag=True, + default=False, + help="Don't start the serve logs process along with the workers", +) +@click.option( + "-q", + "--queues", + metavar="QUEUES", + default=conf.get("operators", "DEFAULT_QUEUE"), + help="Comma delimited list of queues to serve", +) +@click.option( + "-c", + "--concurrency", + metavar="CONCURRENCY", + type=int, + default=conf.get("celery", "worker_concurrency"), + help="The number of worker processes", +) +@click.option( + "-H", + "--celery-hostname", + metavar="CELERY_HOSTNAME", + help="Set the hostname of celery worker if you have multiple workers on a single machine", +) +@click.option( + "-u", + "--umask", + metavar="UMASK", + default=conf.get("celery", "worker_umask"), + help="Set the umask of celery worker in daemon mode", +) +@click.option( + "--without-mingle", is_flag=True, default=False, help="Don't synchronize with other workers at start-up" +) +@click.option("--without-gossip", is_flag=True, default=False, help="Don't subscribe to 
other workers events") def worker( pid, daemon_, From 789772376690d0cddc713e5f0a170873e14998c6 Mon Sep 17 00:00:00 2001 From: hankehly Date: Fri, 29 Apr 2022 14:03:17 +0900 Subject: [PATCH 20/34] Set click option metavars --- airflow/cli/__init__.py | 1 + airflow/cli/commands/celery.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/airflow/cli/__init__.py b/airflow/cli/__init__.py index 38decfd3239c8..2b63b6ca922be 100644 --- a/airflow/cli/__init__.py +++ b/airflow/cli/__init__.py @@ -61,6 +61,7 @@ click_log_file = click.option( "-l", "--log-file", + metavar="LOG_FILE", type=click.Path(exists=False, dir_okay=False, writable=True), help="Location of the log file", ) diff --git a/airflow/cli/commands/celery.py b/airflow/cli/commands/celery.py index 871f56c6645fb..7a2c05608f6e8 100644 --- a/airflow/cli/commands/celery.py +++ b/airflow/cli/commands/celery.py @@ -77,7 +77,11 @@ def celery(): """, ) @click.option( - "-u", "--url-prefix", default=conf.get("celery", "FLOWER_URL_PREFIX"), help="URL prefix for Flower" + "-u", + "--url-prefix", + default=conf.get("celery", "FLOWER_URL_PREFIX"), + metavar="URL_PREFIX", + help="URL prefix for Flower", ) @click.option( "-A", From 13749d0d4cbcc78d848d5c29d31f5dd4deca8eef Mon Sep 17 00:00:00 2001 From: hankehly Date: Mon, 2 May 2022 12:27:44 +0900 Subject: [PATCH 21/34] Port info command and unit tests to click --- airflow/cli/__init__.py | 4 +- airflow/cli/__main__.py | 1 + airflow/cli/commands/info.py | 401 ++++++++++++++++++++++++++++++++ tests/cli/commands/test_info.py | 177 ++++++++++++++ 4 files changed, 581 insertions(+), 2 deletions(-) create mode 100644 airflow/cli/commands/info.py create mode 100644 tests/cli/commands/test_info.py diff --git a/airflow/cli/__init__.py b/airflow/cli/__init__.py index 451c34f4c5df7..7a27c7a6ff60f 100644 --- a/airflow/cli/__init__.py +++ b/airflow/cli/__init__.py @@ -27,7 +27,7 @@ click_color = click.option( '--color', - choices=click.Choice({ColorMode.ON, ColorMode.OFF, ColorMode.AUTO}), + type=click.Choice({ColorMode.ON, ColorMode.OFF, ColorMode.AUTO}), default=ColorMode.AUTO, help="Do emit colored output (default: auto)", ) @@ -67,7 +67,7 @@ click_output = click.option( "-o", "--output", - choices=click.Choice(["table", "json", "yaml", "plain"]), + type=click.Choice(["table", "json", "yaml", "plain"]), default="table", help="Output format.", ) diff --git a/airflow/cli/__main__.py b/airflow/cli/__main__.py index 32d5187b018ea..9d9b10367cdfe 100644 --- a/airflow/cli/__main__.py +++ b/airflow/cli/__main__.py @@ -20,6 +20,7 @@ from airflow.cli import airflow_cmd from airflow.cli.commands import cheat_sheet # noqa: F401 from airflow.cli.commands import db # noqa: F401 +from airflow.cli.commands import info # noqa: F401 from airflow.cli.commands import scheduler # noqa: F401 from airflow.cli.commands import standalone # noqa: F401 from airflow.cli.commands import triggerer # noqa: F401 diff --git a/airflow/cli/commands/info.py b/airflow/cli/commands/info.py new file mode 100644 index 0000000000000..b2cd6df2c8000 --- /dev/null +++ b/airflow/cli/commands/info.py @@ -0,0 +1,401 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Config sub-commands""" +import locale +import logging +import os +import platform +import subprocess +import sys +from typing import List, Optional +from urllib.parse import urlsplit, urlunsplit + +import httpx +import rich_click as click +import tenacity + +from airflow import configuration +from airflow.cli import airflow_cmd, click_output, click_verbose +from airflow.cli.simple_table import AirflowConsole +from airflow.providers_manager import ProvidersManager +from airflow.typing_compat import Protocol +from airflow.utils.cli import suppress_logs_and_warning_click_compatible +from airflow.utils.platform import getuser +from airflow.version import version as airflow_version + +log = logging.getLogger(__name__) + + +class Anonymizer(Protocol): + """Anonymizer protocol.""" + + def process_path(self, value) -> str: + """Remove pii from paths""" + + def process_username(self, value) -> str: + """Remove pii from username""" + + def process_url(self, value) -> str: + """Remove pii from URL""" + + +class NullAnonymizer(Anonymizer): + """Do nothing.""" + + def _identity(self, value) -> str: + return value + + process_path = process_username = process_url = _identity + + del _identity + + +class PiiAnonymizer(Anonymizer): + """Remove personally identifiable info from path.""" + + def __init__(self): + home_path = os.path.expanduser("~") + username = getuser() + self._path_replacements = {home_path: "${HOME}", username: "${USER}"} + + def process_path(self, value) -> str: + if not value: + return value + for src, target in self._path_replacements.items(): + value = value.replace(src, target) + return value + + def process_username(self, value) -> str: + if not value: + return value + return value[0] + "..." 
+ value[-1] + + def process_url(self, value) -> str: + if not value: + return value + + url_parts = urlsplit(value) + netloc = None + if url_parts.netloc: + # unpack + userinfo = None + username = None + password = None + + if "@" in url_parts.netloc: + userinfo, _, host = url_parts.netloc.partition("@") + else: + host = url_parts.netloc + if userinfo: + if ":" in userinfo: + username, _, password = userinfo.partition(":") + else: + username = userinfo + + # anonymize + username = self.process_username(username) if username else None + password = "PASSWORD" if password else None + + # pack + if username and password and host: + netloc = username + ":" + password + "@" + host + elif username and host: + netloc = username + "@" + host + elif password and host: + netloc = ":" + password + "@" + host + elif host: + netloc = host + else: + netloc = "" + + return urlunsplit((url_parts.scheme, netloc, url_parts.path, url_parts.query, url_parts.fragment)) + + +class OperatingSystem: + """Operating system""" + + WINDOWS = "Windows" + LINUX = "Linux" + MACOSX = "Mac OS" + CYGWIN = "Cygwin" + + @staticmethod + def get_current() -> Optional[str]: + """Get current operating system""" + if os.name == "nt": + return OperatingSystem.WINDOWS + elif "linux" in sys.platform: + return OperatingSystem.LINUX + elif "darwin" in sys.platform: + return OperatingSystem.MACOSX + elif "cygwin" in sys.platform: + return OperatingSystem.CYGWIN + return None + + +class Architecture: + """Compute architecture""" + + X86_64 = "x86_64" + X86 = "x86" + PPC = "ppc" + ARM = "arm" + + @staticmethod + def get_current(): + """Get architecture""" + return _MACHINE_TO_ARCHITECTURE.get(platform.machine().lower()) + + +_MACHINE_TO_ARCHITECTURE = { + "amd64": Architecture.X86_64, + "x86_64": Architecture.X86_64, + "i686-64": Architecture.X86_64, + "i386": Architecture.X86, + "i686": Architecture.X86, + "x86": Architecture.X86, + "ia64": Architecture.X86, # Itanium is different x64 arch, treat it as the common x86. + "powerpc": Architecture.PPC, + "power macintosh": Architecture.PPC, + "ppc64": Architecture.PPC, + "armv6": Architecture.ARM, + "armv6l": Architecture.ARM, + "arm64": Architecture.ARM, + "armv7": Architecture.ARM, + "armv7l": Architecture.ARM, +} + + +class AirflowInfo: + """Renders information about Airflow instance""" + + def __init__(self, anonymizer): + self.anonymizer = anonymizer + + @staticmethod + def _get_version(cmd: List[str], grep: Optional[bytes] = None): + """Return tools version.""" + try: + with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc: + stdoutdata, _ = proc.communicate() + data = [f for f in stdoutdata.split(b"\n") if f] + if grep: + data = [line for line in data if grep in line] + if len(data) != 1: + return "NOT AVAILABLE" + else: + return data[0].decode() + except OSError: + return "NOT AVAILABLE" + + @staticmethod + def _task_logging_handler(): + """Returns task logging handler.""" + + def get_fullname(o): + module = o.__class__.__module__ + if module is None or module == str.__class__.__module__: + return o.__class__.__name__ # Avoid reporting __builtin__ + else: + return module + '.' 
+ o.__class__.__name__ + + try: + handler_names = [get_fullname(handler) for handler in logging.getLogger('airflow.task').handlers] + return ", ".join(handler_names) + except Exception: + return "NOT AVAILABLE" + + @property + def _airflow_info(self): + executor = configuration.conf.get("core", "executor") + sql_alchemy_conn = self.anonymizer.process_url( + configuration.conf.get("core", "SQL_ALCHEMY_CONN", fallback="NOT AVAILABLE") + ) + dags_folder = self.anonymizer.process_path( + configuration.conf.get("core", "dags_folder", fallback="NOT AVAILABLE") + ) + plugins_folder = self.anonymizer.process_path( + configuration.conf.get("core", "plugins_folder", fallback="NOT AVAILABLE") + ) + base_log_folder = self.anonymizer.process_path( + configuration.conf.get("logging", "base_log_folder", fallback="NOT AVAILABLE") + ) + remote_base_log_folder = self.anonymizer.process_path( + configuration.conf.get("logging", "remote_base_log_folder", fallback="NOT AVAILABLE") + ) + + return [ + ("version", airflow_version), + ("executor", executor), + ("task_logging_handler", self._task_logging_handler()), + ("sql_alchemy_conn", sql_alchemy_conn), + ("dags_folder", dags_folder), + ("plugins_folder", plugins_folder), + ("base_log_folder", base_log_folder), + ("remote_base_log_folder", remote_base_log_folder), + ] + + @property + def _system_info(self): + operating_system = OperatingSystem.get_current() + arch = Architecture.get_current() + uname = platform.uname() + _locale = locale.getdefaultlocale() + python_location = self.anonymizer.process_path(sys.executable) + python_version = sys.version.replace("\n", " ") + + return [ + ("OS", operating_system or "NOT AVAILABLE"), + ("architecture", arch or "NOT AVAILABLE"), + ("uname", str(uname)), + ("locale", str(_locale)), + ("python_version", python_version), + ("python_location", python_location), + ] + + @property + def _tools_info(self): + git_version = self._get_version(["git", "--version"]) + ssh_version = self._get_version(["ssh", "-V"]) + kubectl_version = self._get_version(["kubectl", "version", "--short=True", "--client=True"]) + gcloud_version = self._get_version(["gcloud", "version"], grep=b"Google Cloud SDK") + cloud_sql_proxy_version = self._get_version(["cloud_sql_proxy", "--version"]) + mysql_version = self._get_version(["mysql", "--version"]) + sqlite3_version = self._get_version(["sqlite3", "--version"]) + psql_version = self._get_version(["psql", "--version"]) + + return [ + ("git", git_version), + ("ssh", ssh_version), + ("kubectl", kubectl_version), + ("gcloud", gcloud_version), + ("cloud_sql_proxy", cloud_sql_proxy_version), + ("mysql", mysql_version), + ("sqlite3", sqlite3_version), + ("psql", psql_version), + ] + + @property + def _paths_info(self): + system_path = os.environ.get("PATH", "").split(os.pathsep) + airflow_home = self.anonymizer.process_path(configuration.get_airflow_home()) + system_path = [self.anonymizer.process_path(p) for p in system_path] + python_path = [self.anonymizer.process_path(p) for p in sys.path] + airflow_on_path = any(os.path.exists(os.path.join(path_elem, "airflow")) for path_elem in system_path) + + return [ + ("airflow_home", airflow_home), + ("system_path", os.pathsep.join(system_path)), + ("python_path", os.pathsep.join(python_path)), + ("airflow_on_path", str(airflow_on_path)), + ] + + @property + def _providers_info(self): + return [(p.data['package-name'], p.version) for p in ProvidersManager().providers.values()] + + def show(self, output: str, console: Optional[AirflowConsole] = None) -> None: + 
"""Shows information about Airflow instance""" + all_info = { + "Apache Airflow": self._airflow_info, + "System info": self._system_info, + "Tools info": self._tools_info, + "Paths info": self._paths_info, + "Providers info": self._providers_info, + } + + console = console or AirflowConsole(show_header=False) + if output in ("table", "plain"): + # Show each info as table with key, value column + for key, info in all_info.items(): + console.print(f"\n[bold][green]{key}[/bold][/green]", highlight=False) + console.print_as(data=[{"key": k, "value": v} for k, v in info], output=output) + else: + # Render info in given format, change keys to snake_case + console.print_as( + data=[{k.lower().replace(" ", "_"): dict(v)} for k, v in all_info.items()], output=output + ) + + def render_text(self, output: str) -> str: + """Exports the info to string""" + console = AirflowConsole(record=True) + with console.capture(): + self.show(output=output, console=console) + return console.export_text() + + +class FileIoException(Exception): + """Raises when error happens in FileIo.io integration""" + + +@tenacity.retry( + stop=tenacity.stop_after_attempt(5), + wait=tenacity.wait_exponential(multiplier=1, max=10), + retry=tenacity.retry_if_exception_type(FileIoException), + before=tenacity.before_log(log, logging.DEBUG), + after=tenacity.after_log(log, logging.DEBUG), +) +def _upload_text_to_fileio(content): + """Upload text file to File.io service and return lnk""" + resp = httpx.post("https://file.io", content=content) + if resp.status_code not in [200, 201]: + print(resp.json()) + raise FileIoException("Failed to send report to file.io service.") + try: + return resp.json()["link"] + except ValueError as e: + log.debug(e) + raise FileIoException("Failed to send report to file.io service.") + + +def _send_report_to_fileio(info): + print("Uploading report to file.io service.") + try: + link = _upload_text_to_fileio(str(info)) + print("Report uploaded.") + print(link) + print() + except FileIoException as ex: + print(str(ex)) + + +@airflow_cmd.command("info") +@click.option( + "--anonymize", + is_flag=True, + default=False, + help="Minimize any personal identifiable information. Use it when sharing output with others.", +) +@click.option( + "--file-io", + is_flag=True, + default=False, + help="Send output to file.io service and returns link.", +) +@click_output +@click_verbose +@suppress_logs_and_warning_click_compatible +def show_info(anonymize, file_io, output, verbose): + """Show information related to Airflow, system and other.""" + # Enforce anonymization, when file_io upload is tuned on. + anonymizer = PiiAnonymizer() if anonymize or file_io else NullAnonymizer() + info = AirflowInfo(anonymizer) + if file_io: + content = info.render_text(output) + _send_report_to_fileio(content) + else: + info.show(output) diff --git a/tests/cli/commands/test_info.py b/tests/cli/commands/test_info.py new file mode 100644 index 0000000000000..0a2ef8b6392c8 --- /dev/null +++ b/tests/cli/commands/test_info.py @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import importlib +import logging +import os +import unittest + +import pytest +from click.testing import CliRunner +from parameterized import parameterized +from rich.console import Console + +from airflow.cli.commands import info +from airflow.config_templates import airflow_local_settings +from airflow.logging_config import configure_logging +from airflow.version import version as airflow_version +from tests.test_utils.config import conf_vars + + +class TestPiiAnonymizer(unittest.TestCase): + def setUp(self) -> None: + self.instance = info.PiiAnonymizer() + + def test_should_remove_pii_from_path(self): + home_path = os.path.expanduser("~/airflow/config") + assert "${HOME}/airflow/config" == self.instance.process_path(home_path) + + @parameterized.expand( + [ + ( + "postgresql+psycopg2://postgres:airflow@postgres/airflow", + "postgresql+psycopg2://p...s:PASSWORD@postgres/airflow", + ), + ( + "postgresql+psycopg2://postgres@postgres/airflow", + "postgresql+psycopg2://p...s@postgres/airflow", + ), + ( + "postgresql+psycopg2://:airflow@postgres/airflow", + "postgresql+psycopg2://:PASSWORD@postgres/airflow", + ), + ( + "postgresql+psycopg2://postgres/airflow", + "postgresql+psycopg2://postgres/airflow", + ), + ] + ) + def test_should_remove_pii_from_url(self, before, after): + assert after == self.instance.process_url(before) + + +class TestAirflowInfo: + @classmethod + def teardown_class(cls) -> None: + for handler_ref in logging._handlerList[:]: # type: ignore + logging._removeHandlerRef(handler_ref) # type: ignore + importlib.reload(airflow_local_settings) + configure_logging() + + @staticmethod + def unique_items(items): + return {i[0] for i in items} + + @conf_vars( + { + ("core", "executor"): "TEST_EXECUTOR", + ("core", "dags_folder"): "TEST_DAGS_FOLDER", + ("core", "plugins_folder"): "TEST_PLUGINS_FOLDER", + ("logging", "base_log_folder"): "TEST_LOG_FOLDER", + ('core', 'sql_alchemy_conn'): 'postgresql+psycopg2://postgres:airflow@postgres/airflow', + ('logging', 'remote_logging'): 'True', + ('logging', 'remote_base_log_folder'): 's3://logs-name', + } + ) + def test_airflow_info(self): + importlib.reload(airflow_local_settings) + configure_logging() + instance = info.AirflowInfo(info.NullAnonymizer()) + expected = { + 'executor', + 'version', + 'task_logging_handler', + 'plugins_folder', + 'base_log_folder', + 'remote_base_log_folder', + 'dags_folder', + 'sql_alchemy_conn', + } + assert self.unique_items(instance._airflow_info) == expected + + def test_system_info(self): + instance = info.AirflowInfo(info.NullAnonymizer()) + expected = {'uname', 'architecture', 'OS', 'python_location', 'locale', 'python_version'} + assert self.unique_items(instance._system_info) == expected + + def test_paths_info(self): + instance = info.AirflowInfo(info.NullAnonymizer()) + expected = {'airflow_on_path', 'airflow_home', 'system_path', 'python_path'} + assert self.unique_items(instance._paths_info) == expected + + def test_tools_info(self): + instance = info.AirflowInfo(info.NullAnonymizer()) + expected = { + 'cloud_sql_proxy', + 'gcloud', + 'git', + 'kubectl', + 'mysql', + 'psql', + 
'sqlite3', + 'ssh', + } + assert self.unique_items(instance._tools_info) == expected + + @conf_vars( + { + ('core', 'sql_alchemy_conn'): 'postgresql+psycopg2://postgres:airflow@postgres/airflow', + } + ) + def test_show_info(self): + runner = CliRunner() + result = runner.invoke(info.show_info) + output = result.output.strip() + assert result.exit_code == 0 + assert airflow_version in output + assert "postgresql+psycopg2://postgres:airflow@postgres/airflow" in output + + @conf_vars( + { + ('core', 'sql_alchemy_conn'): 'postgresql+psycopg2://postgres:airflow@postgres/airflow', + } + ) + def test_show_info_anonymize(self): + runner = CliRunner() + result = runner.invoke(info.show_info, ["--anonymize"]) + output = result.output.strip() + assert airflow_version in output + assert "postgresql+psycopg2://p...s:PASSWORD@postgres/airflow" in output + + +class TestInfoCommandMockHttpx: + @conf_vars( + { + ('core', 'sql_alchemy_conn'): 'postgresql+psycopg2://postgres:airflow@postgres/airflow', + } + ) + def test_show_info_anonymize_fileio(self, httpx_mock): + httpx_mock.add_response( + url="https://file.io", + method="post", + json={ + "success": True, + "key": "f9U3zs3I", + "link": "https://file.io/TEST", + "expiry": "14 days", + }, + status_code=200, + ) + runner = CliRunner() + result = runner.invoke(info.show_info, ["--file-io"]) + output = result.output.strip() + assert "https://file.io/TEST" in output From 3fdbfce1d98c85a6af8b9d7e52816d568e421e54 Mon Sep 17 00:00:00 2001 From: blag Date: Mon, 6 Jun 2022 23:18:53 -0600 Subject: [PATCH 22/34] Fix a few more things --- airflow/cli/__init__.py | 1 + airflow/cli/commands/db.py | 125 +++++++++++++++++++++---------------- 2 files changed, 72 insertions(+), 54 deletions(-) diff --git a/airflow/cli/__init__.py b/airflow/cli/__init__.py index cde2d63c68211..33c3016aedf5f 100644 --- a/airflow/cli/__init__.py +++ b/airflow/cli/__init__.py @@ -45,6 +45,7 @@ click_dry_run = click.option( '-n', '--dry-run', + is_flag=True, default=False, help="Perform a dry run for each task. Only renders Template Fields for each task, nothing else", ) diff --git a/airflow/cli/commands/db.py b/airflow/cli/commands/db.py index 8e52850dd4d37..a08ba199386fb 100644 --- a/airflow/cli/commands/db.py +++ b/airflow/cli/commands/db.py @@ -29,20 +29,24 @@ from airflow.exceptions import AirflowException from airflow.utils import cli as cli_utils from airflow.utils.process_utils import execute_interactive +from airflow.utils.timezone import parse as parsedate -click_revision = click.option( +click_to_revision = click.option( '-r', - '--revision', + '--to-revision', default=None, - help="(Optional) If provided, only run migrations up to and including this revision.", + help=( + "(Optional) If provided, only run migrations up to and including this revision. Note: must " + "provide either `--to-revision` or `--to-version`." + ), ) -click_version = click.option( +click_to_version = click.option( '-n', - '--version', + '--to-version', default=None, help=( "(Optional) The airflow version to upgrade to. Note: must provide either " - "`--revision` or `--version`." + "`--to-revision` or `--to-version`." ), ) click_from_revision = click.option( @@ -53,6 +57,7 @@ '-s', '--show-sql-only', is_flag=True, + default=False, help=( "Don't actually run migrations; just print out sql scripts for offline migration. " "Required if using either `--from-version` or `--from-version`." 
@@ -97,8 +102,15 @@ def check_migrations(ctx, migration_wait_timeout): @db.command('reset') @click.pass_context +@click.option( + '-s', + '--skip-init', + help="Only remove tables; do not perform db init.", + is_flag=True, + default=False, +) @click_yes -def db_reset(ctx, yes=False): +def db_reset(ctx, skip_init, yes): """Burn down and rebuild the metadata database""" console = Console() @@ -106,26 +118,27 @@ def db_reset(ctx, yes=False): if yes or click.confirm("This will drop existing tables if they exist. Proceed? (y/n)"): from airflow.utils import db as db_utils - db_utils.resetdb() + db_utils.resetdb(skip_init=skip_init) else: console.print("Cancelled") @wrapt.decorator def check_revision_and_version_options(wrapped, instance, args, kwargs): - """A decorator that defines upgrade/downgrade option checks in a single place""" - - def wrapper(ctx, revision, version, from_revision, from_version, *_args, show_sql_only=False, **_kwargs): - if revision is not None and version is not None: - raise SystemExit("Cannot supply both `--revision` and `--version`.") + # Get the progressive aspect of the name of the wrapped function + # upgrade -> upgrading + # downgrade -> downgrading + verb = f'{wrapped.__name__[:-1]}ing' + + def wrapper(ctx, to_revision, to_version, from_revision, from_version, show_sql_only, *_args, **_kwargs): + if to_revision is not None and to_version is not None: + raise SystemExit("Cannot supply both `--to-revision` and `--to-version`.") if from_revision is not None and from_version is not None: raise SystemExit("Cannot supply both `--from-revision` and `--from-version`") if (from_revision is not None or from_version is not None) and not show_sql_only: raise SystemExit( "Args `--from-revision` and `--from-version` may only be used with `--show-sql-only`" ) - if version is None and revision is None: - raise SystemExit("Must provide either --revision or --version.") if from_version is not None: if parse_version(from_version) < parse_version('2.0.0'): @@ -136,13 +149,15 @@ def wrapper(ctx, revision, version, from_revision, from_version, *_args, show_sq if not from_revision: raise SystemExit(f"Unknown version {from_version!r} supplied as `--from-version`.") - if version is not None: - revision = REVISION_HEADS_MAP.get(version) - if not revision: - raise SystemExit(f"Upgrading to version {version} is not supported.") + if to_version is not None: + from airflow.utils.db import REVISION_HEADS_MAP + + to_revision = REVISION_HEADS_MAP.get(to_version) + if not to_revision: + raise SystemExit(f"{verb.capitalize()} to version {to_version} is not supported.") return wrapped( - ctx, revision, version, from_revision, from_version, *_args, show_sql_only=False, **_kwargs + ctx, to_revision, to_version, from_revision, from_version, show_sql_only, *_args, **_kwargs ) return wrapper(*args, **kwargs) @@ -150,14 +165,14 @@ def wrapper(ctx, revision, version, from_revision, from_version, *_args, show_sq @db.command('upgrade') @click.pass_context -@click_revision -@click_version +@click_to_revision +@click_to_version @click_from_revision @click_from_version @click_show_sql_only @click_yes @check_revision_and_version_options -def upgrade(ctx, revision, version, from_revision, from_version, show_sql_only=False, yes=False): +def upgrade(ctx, to_revision, to_version, from_revision, from_version, show_sql_only, yes): """ Upgrade the metadata database to latest version @@ -180,33 +195,32 @@ def upgrade(ctx, revision, version, from_revision, from_version, show_sql_only=F or click.confirm( "\nWarning: About 
to run schema migrations for the airflow metastore. " "Please ensure you have backed up your database before any migration " - "operation. Proceed? (y/n)\n" + "operation. Proceed?\n" ) ): from airflow.utils import db as db_utils - db_utils.upgradedb(to_revision=revision, from_revision=from_revision, show_sql_only=show_sql_only) + db_utils.upgradedb(to_revision=to_revision, from_revision=from_revision, show_sql_only=show_sql_only) if not show_sql_only: console.print("Upgrades done") else: - SystemExit("Cancelled") + raise SystemExit("Cancelled") @db.command('downgrade') @click.pass_context -@click_revision -@click_version +@click_to_revision +@click_to_version @click_from_revision @click_from_version @click_show_sql_only @click_yes @check_revision_and_version_options -def downgrade(ctx, revision, version, from_revision, from_version, show_sql_only=False, yes=False): +def downgrade(ctx, to_revision, to_version, from_revision, from_version, show_sql_only, yes): """ Downgrade the schema of the metadata database - Downgrade the schema of the metadata database. - You must provide either `--revision` or `--version`. + You must provide either `--to-revision` or `--to-version`. To print but not execute commands, use option `--show-sql-only`. If using options `--from-revision` or `--from-version`, you must also use `--show-sql-only`, because if actually *running* migrations, we should only migrate from the *current* revision. @@ -214,26 +228,27 @@ def downgrade(ctx, revision, version, from_revision, from_version, show_sql_only console = Console() console.print(f"Using DB (engine: {settings.engine.url})") + if not (to_version or to_revision): + raise SystemExit("Must provide either --to-revision or --to-version.") + if not show_sql_only: console.print(f"Performing downgrade with database {settings.engine.url}") else: console.print("Generating SQL for downgrade -- downgrade commands will *not* be submitted.") - if show_sql_only or ( - yes - or click.confirm( + if not (show_sql_only or yes): + click.confirm( "\nWarning: About to reverse schema migrations for the airflow metastore. " "Please ensure you have backed up your database before any migration " - "operation. Proceed? (y/n)\n" + "operation. 
Proceed?\n", + abort=True, ) - ): - from airflow.utils import db as db_utils - db_utils.downgrade(to_revision=revision, from_revision=from_revision, show_sql_only=show_sql_only) - if not show_sql_only: - console.print("Downgrades done") - else: - SystemExit("Cancelled") + from airflow.utils import db as db_utils + + db_utils.downgrade(to_revision=to_revision, from_revision=from_revision, show_sql_only=show_sql_only) + if not show_sql_only: + console.print("Downgrades done") @db.command('shell') @@ -311,7 +326,7 @@ def __call__(self): def __str__(self): from airflow.utils.db_cleanup import config_dict - return str(sorted(config_dict)) + return ','.join(sorted(config_dict)) @db.command('clean') @@ -323,17 +338,22 @@ def __str__(self): default=_CleanTableDefault(), show_default=True, help=( - "Table names to perform maintenance on (use comma-separated list).\n" - "Can be specified multiple times, all tables names will be used.\n" + "Table names to perform maintenance on.\n" + "Can be specified multiple times or use a comma-separated list.\n" + "If not specified, all tables names will be cleaned.\n" ), ) @click.option( '--clean-before-timestamp', - type=str, - default=None, - help="The date or timestamp before which data should be purged.\n" - "If no timezone info is supplied then dates are assumed to be in airflow default timezone.\n" - "Example: '2022-01-01 00:00:00+01:00'", + required=True, + metavar='TIMESTAMP', + type=parsedate, + help=( + "The date or timestamp before which data should be purged.\n" + "If no timezone info is supplied then dates are assumed to be in airflow default timezone.\n" + "\n" + "Example: '2022-01-01 00:00:00+01:00'\n" + ), ) @click_dry_run @click_verbose @@ -343,11 +363,8 @@ def cleanup_tables(ctx, tables, clean_before_timestamp, dry_run, verbose, yes): """Purge old records in metastore tables""" from airflow.utils.db_cleanup import run_cleanup - split_tables = [] - for table in tables: - split_tables.extend(table.split(',')) run_cleanup( - table_names=split_tables, + table_names=[t.strip() for table in tables for t in table.split(',')] or None, dry_run=dry_run, clean_before_timestamp=clean_before_timestamp, verbose=verbose, From ba9078d0e30a0cf67d36a4ce9fdd312af1c48765 Mon Sep 17 00:00:00 2001 From: blag Date: Mon, 6 Jun 2022 23:32:22 -0600 Subject: [PATCH 23/34] Add tests for new db subcommand --- tests/cli/commands/test_db.py | 499 ++++++++++++++++++++++++++++++++++ 1 file changed, 499 insertions(+) create mode 100644 tests/cli/commands/test_db.py diff --git a/tests/cli/commands/test_db.py b/tests/cli/commands/test_db.py new file mode 100644 index 0000000000000..3de15b3b38183 --- /dev/null +++ b/tests/cli/commands/test_db.py @@ -0,0 +1,499 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import re +from unittest import mock + +import pendulum +import pytest +from click.testing import CliRunner +from pytest import param +from sqlalchemy.engine.url import make_url + +from airflow.cli.commands import db +from airflow.exceptions import AirflowException + + +class TestCliDb: + @classmethod + def setup_class(cls): + cls.runner = CliRunner() + + @mock.patch("airflow.utils.db.initdb") + def test_cli_initdb(self, mock_initdb): + response = self.runner.invoke(db.db_init) + + assert response.exit_code == 0 + mock_initdb.assert_called_once_with() + + @mock.patch("airflow.utils.db.resetdb") + def test_cli_resetdb(self, mock_resetdb): + response = self.runner.invoke(db.db_reset, ['--yes']) + + assert response.exit_code == 0 + mock_resetdb.assert_called_once_with(skip_init=False) + + @mock.patch("airflow.utils.db.resetdb") + def test_cli_resetdb_skip_init(self, mock_resetdb): + response = self.runner.invoke(db.db_reset, ['--yes', '--skip-init']) + + assert response.exit_code == 0 + mock_resetdb.assert_called_once_with(skip_init=True) + + @mock.patch("airflow.utils.db.check_migrations") + def test_cli_check_migrations(self, mock_wait_for_migrations): + response = self.runner.invoke(db.check_migrations) + + assert response.exit_code == 0 + mock_wait_for_migrations.assert_called_once_with(timeout=60) + + @pytest.mark.parametrize( + 'args, called_with', + [ + param( + [], + dict(to_revision=None, from_revision=None, show_sql_only=False), + id="No flags", + ), + param( + ['--show-sql-only'], + dict(to_revision=None, from_revision=None, show_sql_only=True), + id="Just show SQL", + ), + param( + ['--to-revision', 'abc'], + dict(to_revision='abc', from_revision=None, show_sql_only=False), + id="Just --to-revision", + ), + param( + ['--to-revision', 'abc', '--show-sql-only'], + dict(to_revision='abc', from_revision=None, show_sql_only=True), + id="Show SQL with --to-revison", + ), + param( + ['--to-version', '2.2.2'], + dict(to_revision='7b2661a43ba3', from_revision=None, show_sql_only=False), + id="Just --to-version", + ), + param( + ['--to-version', '2.2.2', '--show-sql-only'], + dict(to_revision='7b2661a43ba3', from_revision=None, show_sql_only=True), + id="Show SQL with --to-version", + ), + param( + ['--to-revision', 'abc', '--from-revision', 'abc123', '--show-sql-only'], + dict(to_revision='abc', from_revision='abc123', show_sql_only=True), + id="Show SQL with from revision and to revision", + ), + param( + ['--to-revision', 'abc', '--from-version', '2.2.2', '--show-sql-only'], + dict(to_revision='abc', from_revision='7b2661a43ba3', show_sql_only=True), + id="Show SQL with from version and to revision", + ), + param( + ['--to-version', '2.2.4', '--from-revision', 'abc123', '--show-sql-only'], + dict(to_revision='587bdf053233', from_revision='abc123', show_sql_only=True), + id="Show SQL with from revision and to version", + ), + param( + ['--to-version', '2.2.4', '--from-version', '2.2.2', '--show-sql-only'], + dict(to_revision='587bdf053233', from_revision='7b2661a43ba3', show_sql_only=True), + id="Show SQL with from version and to version", + ), + ], + ) + @mock.patch("airflow.utils.db.upgradedb") + def test_cli_upgrade_success(self, mock_upgradedb, args, called_with): + response = self.runner.invoke(db.upgrade, args, input='y') + + assert response.exit_code == 0 + mock_upgradedb.assert_called_once_with(**called_with) + + @pytest.mark.parametrize( + 'args, pattern', + [ + param(['--to-version', '2.1.25'], 'not supported', id='bad version'), + param( + ['--to-revision', 'abc', 
'--from-revision', 'abc123'], + 'used with `--show-sql-only`', + id='requires offline', + ), + param( + ['--to-revision', 'abc', '--from-version', '2.0.2'], + 'used with `--show-sql-only`', + id='requires offline', + ), + param( + ['--to-revision', 'abc', '--from-version', '2.1.25', '--show-sql-only'], + 'Unknown version', + id='bad version', + ), + ], + ) + @mock.patch("airflow.utils.db.upgradedb") + def test_cli_upgrade_failure(self, mock_upgradedb, args, pattern): + response = self.runner.invoke(db.upgrade, args, input='y') + + assert response.exit_code != 0 + assert pattern in str(response.exception) + + @mock.patch("airflow.cli.commands.db.execute_interactive") + @mock.patch("airflow.cli.commands.db.NamedTemporaryFile") + @mock.patch("airflow.cli.commands.db.settings.engine.url", make_url("mysql://root@mysql:3306/airflow")) + def test_cli_shell_mysql(self, mock_tmp_file, mock_execute_interactive): + mock_tmp_file.return_value.__enter__.return_value.name = "/tmp/name" + + response = self.runner.invoke(db.shell) + + assert response.exit_code == 0 + mock_execute_interactive.assert_called_once_with(['mysql', '--defaults-extra-file=/tmp/name']) + mock_tmp_file.return_value.__enter__.return_value.write.assert_called_once_with( + b'[client]\nhost = mysql\nuser = root\npassword = \nport = 3306' + b'\ndatabase = airflow' + ) + + @mock.patch("airflow.cli.commands.db.execute_interactive") + @mock.patch("airflow.cli.commands.db.NamedTemporaryFile") + @mock.patch("airflow.cli.commands.db.settings.engine.url", make_url("mysql://root@mysql/airflow")) + def test_cli_shell_mysql_without_port(self, mock_tmp_file, mock_execute_interactive): + mock_tmp_file.return_value.__enter__.return_value.name = "/tmp/name" + + response = self.runner.invoke(db.shell) + + assert response.exit_code == 0 + mock_execute_interactive.assert_called_once_with(['mysql', '--defaults-extra-file=/tmp/name']) + mock_tmp_file.return_value.__enter__.return_value.write.assert_called_once_with( + b'[client]\nhost = mysql\nuser = root\npassword = \nport = 3306' + b'\ndatabase = airflow' + ) + + @mock.patch("airflow.cli.commands.db.execute_interactive") + @mock.patch("airflow.cli.commands.db.settings.engine.url", make_url("sqlite:////root/airflow/airflow.db")) + def test_cli_shell_sqlite(self, mock_execute_interactive): + response = self.runner.invoke(db.shell) + + assert response.exit_code == 0 + mock_execute_interactive.assert_called_once_with(['sqlite3', '/root/airflow/airflow.db']) + + @mock.patch("airflow.cli.commands.db.execute_interactive") + @mock.patch( + "airflow.cli.commands.db.settings.engine.url", + make_url("postgresql+psycopg2://postgres:airflow@postgres:5432/airflow"), + ) + def test_cli_shell_postgres(self, mock_execute_interactive): + response = self.runner.invoke(db.shell) + + assert response.exit_code == 0 + mock_execute_interactive.assert_called_once_with(['psql'], env=mock.ANY) + _, kwargs = mock_execute_interactive.call_args + env = kwargs['env'] + postgres_env = {k: v for k, v in env.items() if k.startswith('PG')} + assert { + 'PGDATABASE': 'airflow', + 'PGHOST': 'postgres', + 'PGPASSWORD': 'airflow', + 'PGPORT': '5432', + 'PGUSER': 'postgres', + } == postgres_env + + @mock.patch("airflow.cli.commands.db.execute_interactive") + @mock.patch( + "airflow.cli.commands.db.settings.engine.url", + make_url("postgresql+psycopg2://postgres:airflow@postgres/airflow"), + ) + def test_cli_shell_postgres_without_port(self, mock_execute_interactive): + response = self.runner.invoke(db.shell) + + assert response.exit_code == 0 
+ mock_execute_interactive.assert_called_once_with(['psql'], env=mock.ANY) + _, kwargs = mock_execute_interactive.call_args + env = kwargs['env'] + postgres_env = {k: v for k, v in env.items() if k.startswith('PG')} + assert { + 'PGDATABASE': 'airflow', + 'PGHOST': 'postgres', + 'PGPASSWORD': 'airflow', + 'PGPORT': '5432', + 'PGUSER': 'postgres', + } == postgres_env + + @mock.patch( + "airflow.cli.commands.db.settings.engine.url", + make_url("invalid+psycopg2://postgres:airflow@postgres/airflow"), + ) + def test_cli_shell_invalid(self): + response = self.runner.invoke(db.shell) + + assert response.exit_code != 0 + assert isinstance(response.exception, AirflowException) + assert "Unknown driver: invalid+psycopg2" in str(response.exception) + + @pytest.mark.parametrize( + 'args, pattern', + [ + param( + ['-y', '--to-revision', 'abc', '--to-version', '2.2.0'], + r'Cannot supply both .*', + id="Both --to-revision and --to-version", + ), + param( + ['-y', '--to-revision', 'abc1', '--from-revision', 'abc2'], + r'.* may only be used with `--show-sql-only`', + id="Only with --show-sql-only: --to-revision and --from-revision", + ), + param( + ['-y', '--to-revision', 'abc1', '--from-version', '2.2.2'], + r'.* may only be used with `--show-sql-only`', + id="Only with --show-sql-only: --to-revision and --from-version", + ), + param( + ['-y', '--to-version', '2.2.2', '--from-version', '2.2.2'], + r'.* only be used with `--show-sql-only`', + id="Only with --show-sql-only: --to-version and --from-version", + ), + param( + ['-y', '--to-revision', 'abc', '--from-version', '2.2.0', '--from-revision', 'abc'], + r'Cannot supply both .*', + id="Both --from-revision and --from-version", + ), + param( + ['-y', '--to-version', 'abc'], + r'Downgrading to .* is not supported\.', + id="Downgrading to version not supported", + ), + param(['-y'], 'Must provide either', id="Must provide either --to-revision or --to-version"), + ], + ) + @mock.patch("airflow.utils.db.downgrade") + def test_cli_downgrade_invalid(self, mock_dg, args, pattern): + """We test some options that should produce an error""" + + response = self.runner.invoke(db.downgrade, args) + + assert response.exit_code != 0 + assert re.match(pattern, str(response.exception)) + + @pytest.mark.parametrize( + 'args, expected', + [ + param( + ['-y', '--to-revision', 'abc1'], + dict(to_revision='abc1'), + id="Just --to-revision", + ), + param( + ['-y', '--to-revision', 'abc1', '--from-revision', 'abc2', '-s'], + dict(to_revision='abc1', from_revision='abc2', show_sql_only=True), + id="", + ), + param( + ['-y', '--to-revision', 'abc1', '--from-version', '2.2.2', '-s'], + dict(to_revision='abc1', from_revision='7b2661a43ba3', show_sql_only=True), + id="", + ), + param( + ['-y', '--to-version', '2.2.2', '--from-version', '2.2.2', '-s'], + dict(to_revision='7b2661a43ba3', from_revision='7b2661a43ba3', show_sql_only=True), + id="", + ), + param( + ['-y', '--to-version', '2.2.2'], + dict(to_revision='7b2661a43ba3'), + id="", + ), + ], + ) + @mock.patch("airflow.utils.db.downgrade") + def test_cli_downgrade_good(self, mock_dg, args, expected): + defaults = dict(from_revision=None, show_sql_only=False) + + response = self.runner.invoke(db.downgrade, args) + + assert response.exit_code == 0 + mock_dg.assert_called_with(**{**defaults, **expected}) + + @pytest.mark.parametrize( + 'resp, raise_', + [ + ('y', False), + ('Y', False), + ('n', True), + ('a', True), # any other value + ], + ) + @mock.patch("airflow.utils.db.downgrade") + def test_cli_downgrade_confirm(self, 
mock_dg, resp, raise_):
+        if raise_:
+            response = self.runner.invoke(db.downgrade, ['--to-revision', 'abc'], input=resp)
+
+            assert response.exit_code != 0
+            assert mock_dg.not_called
+
+        else:
+            response = self.runner.invoke(db.downgrade, ['--to-revision', 'abc'], input=resp)
+
+            assert response.exit_code == 0
+            mock_dg.assert_called_with(to_revision='abc', from_revision=None, show_sql_only=False)
+
+
+class TestCLIDBClean:
+    @classmethod
+    def setup_class(cls):
+        cls.runner = CliRunner()
+
+    @pytest.mark.parametrize('timezone', ['UTC', 'Europe/Berlin', 'America/Los_Angeles'])
+    @mock.patch('airflow.utils.db_cleanup.run_cleanup')
+    def test_date_timezone_omitted(self, run_cleanup_mock, timezone):
+        """
+        When timezone omitted we should always expect that the timestamp is
+        coerced to tz-aware with default timezone
+        """
+        timestamp = '2021-01-01 00:00:00'
+        with mock.patch('airflow.utils.timezone.TIMEZONE', pendulum.timezone(timezone)):
+            response = self.runner.invoke(
+                db.cleanup_tables, ['--clean-before-timestamp', f"{timestamp}", '-y']
+            )
+
+        assert response.exit_code == 0
+
+        run_cleanup_mock.assert_called_once_with(
+            table_names=None,
+            dry_run=False,
+            clean_before_timestamp=pendulum.parse(timestamp, tz=timezone),
+            verbose=False,
+            confirm=False,
+        )
+
+    @pytest.mark.parametrize('timezone', ['UTC', 'Europe/Berlin', 'America/Los_Angeles'])
+    @mock.patch('airflow.utils.db_cleanup.run_cleanup')
+    def test_date_timezone_supplied(self, run_cleanup_mock, timezone):
+        """
+        When tz included in the string then default timezone should not be used.
+        """
+        timestamp = '2021-01-01 00:00:00+03:00'
+        with mock.patch('airflow.utils.timezone.TIMEZONE', pendulum.timezone(timezone)):
+            response = self.runner.invoke(
+                db.cleanup_tables, ['--clean-before-timestamp', f"{timestamp}", '-y']
+            )
+
+        assert response.exit_code == 0
+
+        run_cleanup_mock.assert_called_once_with(
+            table_names=None,
+            dry_run=False,
+            clean_before_timestamp=pendulum.parse(timestamp),
+            verbose=False,
+            confirm=False,
+        )
+
+    @pytest.mark.parametrize('confirm_arg, expected', [(['-y'], False), ([], True)])
+    @mock.patch('airflow.utils.db_cleanup.run_cleanup')
+    def test_confirm(self, run_cleanup_mock, confirm_arg, expected):
+        """
+        The -y flag should suppress the confirmation prompt by passing confirm=False to run_cleanup.
+        """
+        response = self.runner.invoke(
+            db.cleanup_tables,
+            [
+                '--clean-before-timestamp',
+                '2021-01-01',
+                *confirm_arg,
+            ],
+        )
+
+        assert response.exit_code == 0
+        run_cleanup_mock.assert_called_once_with(
+            table_names=None,
+            dry_run=False,
+            clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'),
+            verbose=False,
+            confirm=expected,
+        )
+
+    @pytest.mark.parametrize('dry_run_arg, expected', [(['--dry-run'], True), ([], False)])
+    @mock.patch('airflow.utils.db_cleanup.run_cleanup')
+    def test_dry_run(self, run_cleanup_mock, dry_run_arg, expected):
+        """
+        The --dry-run flag should be passed through to run_cleanup as dry_run=True.
+ """ + response = self.runner.invoke( + db.cleanup_tables, + [ + '--clean-before-timestamp', + '2021-01-01', + *dry_run_arg, + ], + ) + + assert response.exit_code == 0 + run_cleanup_mock.assert_called_once_with( + table_names=None, + dry_run=expected, + clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'), + verbose=False, + confirm=True, + ) + + @pytest.mark.parametrize( + 'extra_args, expected', [(['--tables', 'hello, goodbye'], ['hello', 'goodbye']), ([], None)] + ) + @mock.patch('airflow.utils.db_cleanup.run_cleanup') + def test_tables(self, run_cleanup_mock, extra_args, expected): + """ + When tz included in the string then default timezone should not be used. + """ + response = self.runner.invoke( + db.cleanup_tables, + [ + '--clean-before-timestamp', + '2021-01-01', + *extra_args, + ], + ) + + assert response.exit_code == 0 + run_cleanup_mock.assert_called_once_with( + table_names=expected, + dry_run=False, + clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'), + verbose=False, + confirm=True, + ) + + @pytest.mark.parametrize('extra_args, expected', [(['--verbose'], True), ([], False)]) + @mock.patch('airflow.utils.db_cleanup.run_cleanup') + def test_verbose(self, run_cleanup_mock, extra_args, expected): + """ + When tz included in the string then default timezone should not be used. + """ + response = self.runner.invoke( + db.cleanup_tables, + [ + '--clean-before-timestamp', + '2021-01-01', + *extra_args, + ], + ) + + assert response.exit_code == 0 + run_cleanup_mock.assert_called_once_with( + table_names=None, + dry_run=False, + clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'), + verbose=expected, + confirm=True, + ) From f927b9414cdeda47852ac3d89e6ad0921908788b Mon Sep 17 00:00:00 2001 From: blag Date: Mon, 6 Jun 2022 23:42:25 -0600 Subject: [PATCH 24/34] Convert version command to use click --- airflow/cli/__main__.py | 1 + airflow/cli/commands/version.py | 28 ++++++++++++++++++++++++++ tests/cli/commands/test_version.py | 32 ++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+) create mode 100644 airflow/cli/commands/version.py create mode 100644 tests/cli/commands/test_version.py diff --git a/airflow/cli/__main__.py b/airflow/cli/__main__.py index 548f9e9b51a2a..ac28d0fdd25ff 100644 --- a/airflow/cli/__main__.py +++ b/airflow/cli/__main__.py @@ -26,6 +26,7 @@ from airflow.cli.commands import standalone # noqa: F401 from airflow.cli.commands import sync_perm # noqa: F401 from airflow.cli.commands import triggerer # noqa: F401 +from airflow.cli.commands import version # noqa: F401 from airflow.cli.commands import webserver # noqa: F401 if __name__ == '__main__': diff --git a/airflow/cli/commands/version.py b/airflow/cli/commands/version.py new file mode 100644 index 0000000000000..be4a0a9b4d17b --- /dev/null +++ b/airflow/cli/commands/version.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Version command""" +from rich.console import Console + +import airflow +from airflow.cli import airflow_cmd + + +@airflow_cmd.command('version') +def version(): + """Displays Airflow version at the command line""" + console = Console() + console.print(airflow.__version__) diff --git a/tests/cli/commands/test_version.py b/tests/cli/commands/test_version.py new file mode 100644 index 0000000000000..fa767345d5335 --- /dev/null +++ b/tests/cli/commands/test_version.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest + +from click.testing import CliRunner + +import airflow +from airflow.cli.commands.version import version + + +class TestCliVersion(unittest.TestCase): + def test_cli_version(self): + runner = CliRunner() + response = runner.invoke(version) + + assert response.exit_code == 0 + assert airflow.__version__ in response.output From 8ea28a473b44da89bd0a191ef8ff475f618bdba3 Mon Sep 17 00:00:00 2001 From: blag Date: Tue, 7 Jun 2022 01:32:02 -0600 Subject: [PATCH 25/34] A few more fixups --- airflow/cli/commands/scheduler.py | 7 ++++--- airflow/utils/cli.py | 9 +++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/airflow/cli/commands/scheduler.py b/airflow/cli/commands/scheduler.py index 42a46ead05bde..40ba644d45f06 100644 --- a/airflow/cli/commands/scheduler.py +++ b/airflow/cli/commands/scheduler.py @@ -62,6 +62,7 @@ def _run_scheduler_job(subdir, num_runs, do_pickle, skip_serve_logs): @airflow_cmd.command('scheduler') +@click.pass_context @click_daemon @click.option( "-p", @@ -92,7 +93,7 @@ def _run_scheduler_job(subdir, num_runs, do_pickle, skip_serve_logs): @click_stderr @click_stdout @click_subdir -@cli_utils.action_cli +@cli_utils.action_cli(check_cli_args=False) def scheduler(ctx, daemon_, do_pickle, log_file, num_runs, pid, skip_serve_logs, stderr, stdout, subdir): """Starts Airflow Scheduler""" console = Console() @@ -102,13 +103,13 @@ def scheduler(ctx, daemon_, do_pickle, log_file, num_runs, pid, skip_serve_logs, pid, stdout, stderr, log_file = setup_locations("scheduler", pid, stdout, stderr, log_file) handle = setup_logging(log_file) with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle: - ctx = daemon.DaemonContext( + daemon_ctx = daemon.DaemonContext( pidfile=TimeoutPIDLockFile(pid, -1), files_preserve=[handle], stdout=stdout_handle, 
stderr=stderr_handle, ) - with ctx: + with daemon_ctx: _run_scheduler_job(subdir, num_runs, do_pickle, skip_serve_logs) else: signal.signal(signal.SIGINT, sigint_handler) diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py index 6f4538b8e152a..807e01f6dd2a4 100644 --- a/airflow/utils/cli.py +++ b/airflow/utils/cli.py @@ -46,10 +46,10 @@ def _check_cli_args(args): if not args: - raise ValueError("Args should be set") + raise ValueError(f"Args should be set: {args} [{type(args)}]") -def action_cli(func=None, check_db=True): +def action_cli(func=None, check_db=True, check_cli_args=True): def action_logging(f: T) -> T: """ Decorates function to execute function at the same time submitting action_logging @@ -80,7 +80,8 @@ def wrapper(*args, **kwargs): :param args: Positional argument. :param kwargs: A passthrough keyword argument """ - _check_cli_args(args) + if check_cli_args: + _check_cli_args(args) metrics = _build_metrics(f.__name__, args, kwargs) cli_action_loggers.on_pre_execution(**metrics) try: @@ -141,7 +142,7 @@ def _build_metrics(func_name, args, kwargs): 'user': getuser(), } - tmp_dic = vars(args[0]) if isinstance(args[0], Namespace) else kwargs + tmp_dic = vars(args[0]) if (args and isinstance(args[0], Namespace)) else kwargs metrics['dag_id'] = tmp_dic.get('dag_id') metrics['task_id'] = tmp_dic.get('task_id') metrics['execution_date'] = tmp_dic.get('execution_date') From 62a84a39d785fbe68a52389117787d73adb3096c Mon Sep 17 00:00:00 2001 From: blag Date: Tue, 7 Jun 2022 01:56:14 -0600 Subject: [PATCH 26/34] Add tests for scheduler subcommand --- tests/cli/commands/test_scheduler.py | 95 ++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 tests/cli/commands/test_scheduler.py diff --git a/tests/cli/commands/test_scheduler.py b/tests/cli/commands/test_scheduler.py new file mode 100644 index 0000000000000..d1ac1472b63ca --- /dev/null +++ b/tests/cli/commands/test_scheduler.py @@ -0,0 +1,95 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from unittest import mock + +import pytest +from click.testing import CliRunner + +from airflow.cli.commands import scheduler +from airflow.utils.serve_logs import serve_logs +from tests.test_utils.config import conf_vars + + +class TestSchedulerCommand: + @classmethod + def setup_class(cls): + cls.runner = CliRunner() + + @pytest.mark.parametrize( + 'executor, expect_serve_logs', + [ + ("CeleryExecutor", False), + ("LocalExecutor", True), + ("SequentialExecutor", True), + ("KubernetesExecutor", False), + ], + ) + @mock.patch("airflow.jobs.scheduler_job.SchedulerJob") + @mock.patch("airflow.cli.commands.scheduler.Process") + def test_serve_logs_on_scheduler( + self, + mock_process, + mock_scheduler_job, + executor, + expect_serve_logs, + ): + with conf_vars({("core", "executor"): executor}): + response = self.runner.invoke(scheduler.scheduler) + + assert response.exit_code == 0 + + if expect_serve_logs: + mock_process.assert_called_once_with(target=serve_logs) + else: + mock_process.assert_not_called() + + @pytest.mark.parametrize( + 'executor', + [ + "LocalExecutor", + "SequentialExecutor", + ], + ) + @mock.patch("airflow.jobs.scheduler_job.SchedulerJob") + @mock.patch("airflow.cli.commands.scheduler.Process") + def test_skip_serve_logs(self, mock_process, mock_scheduler_job, executor): + with conf_vars({("core", "executor"): executor}): + response = self.runner.invoke(scheduler.scheduler, ['--skip-serve-logs']) + + assert response.exit_code == 0 + + mock_process.assert_not_called() + + @pytest.mark.parametrize( + 'executor', + [ + "LocalExecutor", + "SequentialExecutor", + ], + ) + @mock.patch("airflow.jobs.scheduler_job.SchedulerJob") + @mock.patch("airflow.cli.commands.scheduler.Process") + def test_graceful_shutdown(self, mock_process, mock_scheduler_job, executor): + with conf_vars({("core", "executor"): executor}): + mock_scheduler_job.run.side_effect = Exception('Mock exception to trigger runtime error') + try: + response = self.runner.invoke(scheduler.scheduler) + finally: + assert response.exit_code == 0 + + mock_process().terminate.assert_called() From bfe86e06ec9091778976e193fa59aec4fc7d84fd Mon Sep 17 00:00:00 2001 From: blag Date: Tue, 7 Jun 2022 02:02:39 -0600 Subject: [PATCH 27/34] Add tests for triggerer subcommand --- airflow/cli/commands/triggerer.py | 3 +- tests/cli/commands/test_triggerer.py | 43 ++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 tests/cli/commands/test_triggerer.py diff --git a/airflow/cli/commands/triggerer.py b/airflow/cli/commands/triggerer.py index c2365e2656104..42c5229047724 100644 --- a/airflow/cli/commands/triggerer.py +++ b/airflow/cli/commands/triggerer.py @@ -30,6 +30,7 @@ @airflow_cmd.command('triggerer') +@click.pass_context @click.option( "--capacity", type=click.IntRange(min=1), @@ -40,7 +41,7 @@ @click_pid @click_stderr @click_stdout -@cli_utils.action_cli +@cli_utils.action_cli(check_cli_args=False) def triggerer(ctx, capacity, daemon_, log_file, pid, stderr, stdout): """Starts Airflow Triggerer""" from airflow.jobs.triggerer_job import TriggererJob diff --git a/tests/cli/commands/test_triggerer.py b/tests/cli/commands/test_triggerer.py new file mode 100644 index 0000000000000..129efbd6e86eb --- /dev/null +++ b/tests/cli/commands/test_triggerer.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from unittest import mock + +from click.testing import CliRunner + +from airflow.cli.commands import triggerer + + +class TestTriggererCommand: + """ + Tests the CLI interface and that it correctly calls the TriggererJob + """ + + @classmethod + def setup_class(cls): + cls.runner = CliRunner() + + @mock.patch("airflow.jobs.triggerer_job.TriggererJob") + def test_capacity_argument( + self, + mock_scheduler_job, + ): + """Ensure that the capacity argument is passed correctly""" + response = self.runner.invoke(triggerer.triggerer, ['--capacity', '42']) + + assert response.exit_code == 0 + mock_scheduler_job.assert_called_once_with(capacity=42) From cd057190a1a698a96d81b4dca43ff9680d53339e Mon Sep 17 00:00:00 2001 From: blag Date: Wed, 8 Jun 2022 01:42:18 -0600 Subject: [PATCH 28/34] Fixup the webserver command --- airflow/cli/commands/webserver.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/airflow/cli/commands/webserver.py b/airflow/cli/commands/webserver.py index 18b782726bbfd..3ca802b573fa9 100644 --- a/airflow/cli/commands/webserver.py +++ b/airflow/cli/commands/webserver.py @@ -80,8 +80,8 @@ class GunicornMonitor(LoggingMixin): :param gunicorn_master_pid: PID for the main Gunicorn process :param num_workers_expected: Number of workers to run the Gunicorn web server - :param master_timeout: Number of seconds the webserver waits before killing gunicorn master that - doesn't respond + :param master_timeout: Number of seconds the webserver waits before terminating gunicorn master + that doesn't respond :param worker_refresh_interval: Number of seconds to wait before refreshing a batch of workers. :param worker_refresh_batch_size: Number of workers to refresh at a time. When set to 0, worker refresh is disabled. 
When nonzero, airflow periodically refreshes webserver workers by @@ -324,6 +324,7 @@ def _check_workers(self) -> None: @airflow_cmd.command('webserver') +@click.pass_context @click.option( '-p', '--port', @@ -385,18 +386,18 @@ def _check_workers(self) -> None: @click_log_file @click.option( "--ssl-cert", - type=click.Path(exists=True, dir_okay=False, writable=True), + type=click.Path(exists=False, dir_okay=False, writable=True), default=conf.get('webserver', 'WEB_SERVER_SSL_CERT'), help="Path to the SSL certificate for the webserver", ) @click.option( "--ssl-key", - type=click.Path(exists=True, dir_okay=False, writable=True), + type=click.Path(exists=False, dir_okay=False, writable=True), default=conf.get('webserver', 'WEB_SERVER_SSL_KEY'), help="Path to the key to use with the SSL certificate", ) @click_debug -@cli_utils.action_cli +@cli_utils.action_cli(check_cli_args=False) def webserver( ctx, port, From c64674837e53575e93a878d6ef7ecf5255783f39 Mon Sep 17 00:00:00 2001 From: blag Date: Wed, 8 Jun 2022 01:45:25 -0600 Subject: [PATCH 29/34] Add tests for webserver subcommand --- tests/cli/commands/test_webserver.py | 406 +++++++++++++++++++++++++++ 1 file changed, 406 insertions(+) create mode 100644 tests/cli/commands/test_webserver.py diff --git a/tests/cli/commands/test_webserver.py b/tests/cli/commands/test_webserver.py new file mode 100644 index 0000000000000..b921cb1e69f31 --- /dev/null +++ b/tests/cli/commands/test_webserver.py @@ -0,0 +1,406 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
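The switch from exists=True to exists=False on --ssl-cert and --ssl-key above relaxes click.Path validation, so the option value no longer has to point at an existing file at parse time. A rough sketch of the difference, using a throwaway command that is not part of the patch:

    import rich_click as click
    from click.testing import CliRunner

    @click.command()
    @click.option("--ssl-cert", type=click.Path(exists=True, dir_okay=False))
    def strict(ssl_cert):
        print(ssl_cert)

    result = CliRunner().invoke(strict, ["--ssl-cert", "/no/such/file.pem"])
    # With exists=True, click rejects the missing path with a usage error before
    # the callback runs; with exists=False the same value would be passed through.
    assert result.exit_code != 0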
+ +import os +import subprocess +import sys +import tempfile +import time +from unittest import mock + +import psutil +import pytest +from click.testing import CliRunner + +from airflow import settings +from airflow.cli.commands import webserver +from airflow.cli.commands.webserver import GunicornMonitor +from airflow.utils.cli import setup_locations +from airflow.www import app as www_app +from tests.test_utils.config import conf_vars + + +class TestGunicornMonitor: + @classmethod + def setup_class(cls): + cls.monitor = GunicornMonitor( + gunicorn_master_pid=1, + num_workers_expected=4, + master_timeout=60, + worker_refresh_interval=60, + worker_refresh_batch_size=2, + reload_on_plugin_change=True, + ) + cls.monitor_patched_methods = [ + mock.patch.object(cls.monitor, '_generate_plugin_state', return_value={}), + mock.patch.object(cls.monitor, '_get_num_ready_workers_running', return_value=4), + mock.patch.object(cls.monitor, '_get_num_workers_running', return_value=4), + mock.patch.object(cls.monitor, '_spawn_new_workers', return_value=None), + mock.patch.object(cls.monitor, '_kill_old_workers', return_value=None), + mock.patch.object(cls.monitor, '_reload_gunicorn', return_value=None), + ] + + def setup_method(self): + for mock_ in self.monitor_patched_methods: + mock_.start() + + def teardown_class(self): + for mock_ in self.monitor_patched_methods: + mock_.stop() + + @mock.patch('airflow.cli.commands.webserver.sleep') + def test_should_wait_for_workers_to_start(self, mock_sleep): + self.monitor._get_num_ready_workers_running.return_value = 0 + self.monitor._get_num_workers_running.return_value = 4 + self.monitor._check_workers() + self.monitor._spawn_new_workers.assert_not_called() + self.monitor._kill_old_workers.assert_not_called() + self.monitor._reload_gunicorn.assert_not_called() + + @mock.patch('airflow.cli.commands.webserver.sleep') + def test_should_kill_excess_workers(self, mock_sleep): + self.monitor._get_num_ready_workers_running.return_value = 10 + self.monitor._get_num_workers_running.return_value = 10 + self.monitor._check_workers() + self.monitor._spawn_new_workers.assert_not_called() + self.monitor._kill_old_workers.assert_called_once_with(2) + self.monitor._reload_gunicorn.assert_not_called() + + @mock.patch('airflow.cli.commands.webserver.sleep') + def test_should_start_new_workers_when_missing(self, mock_sleep): + self.monitor._get_num_ready_workers_running.return_value = 3 + self.monitor._get_num_workers_running.return_value = 3 + self.monitor._check_workers() + # missing one worker, starting just 1 + self.monitor._spawn_new_workers.assert_called_once_with(1) + self.monitor._kill_old_workers.assert_not_called() + self.monitor._reload_gunicorn.assert_not_called() + + @mock.patch('airflow.cli.commands.webserver.sleep') + def test_should_start_new_batch_when_missing_many_workers(self, mock_sleep): + self.monitor._get_num_ready_workers_running.return_value = 1 + self.monitor._get_num_workers_running.return_value = 1 + self.monitor._check_workers() + # missing 3 workers, but starting single batch (2) + self.monitor._spawn_new_workers.assert_called_once_with(2) + self.monitor._kill_old_workers.assert_not_called() + self.monitor._reload_gunicorn.assert_not_called() + + @mock.patch('airflow.cli.commands.webserver.sleep') + def test_should_start_new_workers_when_refresh_interval_has_passed(self, mock_sleep): + self.monitor._last_refresh_time -= 200 + self.monitor._check_workers() + self.monitor._spawn_new_workers.assert_called_once_with(2) + 
self.monitor._kill_old_workers.assert_not_called() + self.monitor._reload_gunicorn.assert_not_called() + assert abs(self.monitor._last_refresh_time - time.monotonic()) < 5 + + @mock.patch('airflow.cli.commands.webserver.sleep') + def test_should_reload_when_plugin_has_been_changed(self, mock_sleep): + self.monitor._generate_plugin_state.return_value = {'AA': 12} + + self.monitor._check_workers() + + self.monitor._spawn_new_workers.assert_not_called() + self.monitor._kill_old_workers.assert_not_called() + self.monitor._reload_gunicorn.assert_not_called() + + self.monitor._generate_plugin_state.return_value = {'AA': 32} + + self.monitor._check_workers() + + self.monitor._spawn_new_workers.assert_not_called() + self.monitor._kill_old_workers.assert_not_called() + self.monitor._reload_gunicorn.assert_not_called() + + self.monitor._generate_plugin_state.return_value = {'AA': 32} + + self.monitor._check_workers() + + self.monitor._spawn_new_workers.assert_not_called() + self.monitor._kill_old_workers.assert_not_called() + self.monitor._reload_gunicorn.assert_called_once_with() + assert abs(self.monitor._last_refresh_time - time.monotonic()) < 5 + + +class TestGunicornMonitorGeneratePluginState: + @staticmethod + def _prepare_test_file(filepath: str, size: int): + os.makedirs(os.path.dirname(filepath), exist_ok=True) + with open(filepath, "w") as file: + file.write("A" * size) + file.flush() + + def test_should_detect_changes_in_directory(self): + with tempfile.TemporaryDirectory() as tempdir, mock.patch( + "airflow.cli.commands.webserver.settings.PLUGINS_FOLDER", tempdir + ): + self._prepare_test_file(f"{tempdir}/file1.txt", 100) + self._prepare_test_file(f"{tempdir}/nested/nested/nested/nested/file2.txt", 200) + self._prepare_test_file(f"{tempdir}/file3.txt", 300) + + monitor = GunicornMonitor( + gunicorn_master_pid=1, + num_workers_expected=4, + master_timeout=60, + worker_refresh_interval=60, + worker_refresh_batch_size=2, + reload_on_plugin_change=True, + ) + + # When the files have not changed, the result should be constant + state_a = monitor._generate_plugin_state() + state_b = monitor._generate_plugin_state() + + assert state_a == state_b + assert 3 == len(state_a) + + # Should detect new file + self._prepare_test_file(f"{tempdir}/file4.txt", 400) + + state_c = monitor._generate_plugin_state() + + assert state_b != state_c + assert 4 == len(state_c) + + # Should detect changes in files + self._prepare_test_file(f"{tempdir}/file4.txt", 450) + + state_d = monitor._generate_plugin_state() + + assert state_c != state_d + assert 4 == len(state_d) + + # Should support large files + self._prepare_test_file(f"{tempdir}/file4.txt", 4000000) + + state_d = monitor._generate_plugin_state() + + assert state_c != state_d + assert 4 == len(state_d) + + +class TestCLIGetNumReadyWorkersRunning: + @classmethod + def setup_class(cls): + cls.runner = CliRunner() + + cls.children = mock.MagicMock() + cls.child = mock.MagicMock() + cls.process = mock.MagicMock() + cls.monitor = GunicornMonitor( + gunicorn_master_pid=1, + num_workers_expected=4, + master_timeout=60, + worker_refresh_interval=60, + worker_refresh_batch_size=2, + reload_on_plugin_change=True, + ) + + def test_ready_prefix_on_cmdline(self): + self.child.cmdline.return_value = [settings.GUNICORN_WORKER_READY_PREFIX] + self.process.children.return_value = [self.child] + + with mock.patch('psutil.Process', return_value=self.process): + assert self.monitor._get_num_ready_workers_running() == 1 + + def test_ready_prefix_on_cmdline_no_children(self): + 
self.process.children.return_value = [] + + with mock.patch('psutil.Process', return_value=self.process): + assert self.monitor._get_num_ready_workers_running() == 0 + + def test_ready_prefix_on_cmdline_zombie(self): + self.child.cmdline.return_value = [] + self.process.children.return_value = [self.child] + + with mock.patch('psutil.Process', return_value=self.process): + assert self.monitor._get_num_ready_workers_running() == 0 + + def test_ready_prefix_on_cmdline_dead_process(self): + self.child.cmdline.side_effect = psutil.NoSuchProcess(11347) + self.process.children.return_value = [self.child] + + with mock.patch('psutil.Process', return_value=self.process): + assert self.monitor._get_num_ready_workers_running() == 0 + + +class TestCliWebServer: + @classmethod + def setup_class(cls): + cls.runner = CliRunner() + + @pytest.fixture(autouse=True) + def _cleanup(self): + self._check_processes() + self._clean_pidfiles() + + yield + + self._check_processes(ignore_running=True) + self._clean_pidfiles() + + def _check_processes(self, ignore_running=False): + # Confirm that webserver hasn't been launched. + # pgrep returns exit status 1 if no process matched. + # Use more specific regexps (^) to avoid matching pytest run when running specific method. + # For instance, we want to be able to do: pytest -k 'gunicorn' + exit_code_pgrep_webserver = subprocess.Popen(["pgrep", "-c", "-f", "airflow webserver"]).wait() + exit_code_pgrep_gunicorn = subprocess.Popen(["pgrep", "-c", "-f", "^gunicorn"]).wait() + if exit_code_pgrep_webserver != 1 or exit_code_pgrep_gunicorn != 1: + subprocess.Popen(["ps", "-ax"]).wait() + if exit_code_pgrep_webserver != 1: + subprocess.Popen(["pkill", "-9", "-f", "airflow webserver"]).wait() + if exit_code_pgrep_gunicorn != 1: + subprocess.Popen(["pkill", "-9", "-f", "^gunicorn"]).wait() + if not ignore_running: + raise AssertionError( + "Background processes are running that prevent the test from passing successfully." + ) + + def _clean_pidfiles(self): + pidfile_webserver = setup_locations("webserver")[0] + pidfile_monitor = setup_locations("webserver-monitor")[0] + if os.path.exists(pidfile_webserver): + os.remove(pidfile_webserver) + if os.path.exists(pidfile_monitor): + os.remove(pidfile_monitor) + + def _wait_pidfile(self, pidfile): + start_time = time.monotonic() + while True: + try: + with open(pidfile) as file: + return int(file.read()) + except Exception: + if start_time - time.monotonic() > 60: + raise + time.sleep(1) + + @pytest.mark.quarantined + def test_cli_webserver_background(self): + with tempfile.TemporaryDirectory(prefix="gunicorn") as tmpdir, mock.patch.dict( + "os.environ", + AIRFLOW__CORE__DAGS_FOLDER="/dev/null", + AIRFLOW__CORE__LOAD_EXAMPLES="False", + AIRFLOW__WEBSERVER__WORKERS="1", + ): + pidfile_webserver = f"{tmpdir}/pidflow-webserver.pid" + pidfile_monitor = f"{tmpdir}/pidflow-webserver-monitor.pid" + stdout = f"{tmpdir}/airflow-webserver.out" + stderr = f"{tmpdir}/airflow-webserver.err" + logfile = f"{tmpdir}/airflow-webserver.log" + try: + # Run webserver as daemon in background. Note that the wait method is not called. + + proc = subprocess.Popen( + [ + "airflow", + "webserver", + "--daemon", + "--pid", + pidfile_webserver, + "--stdout", + stdout, + "--stderr", + stderr, + "--log-file", + logfile, + ] + ) + assert proc.poll() is None + + pid_monitor = self._wait_pidfile(pidfile_monitor) + self._wait_pidfile(pidfile_webserver) + + # Assert that gunicorn and its monitor are launched. 
+ assert 0 == subprocess.Popen(["pgrep", "-f", "-c", "airflow webserver --daemon"]).wait() + assert 0 == subprocess.Popen(["pgrep", "-c", "-f", "gunicorn: master"]).wait() + + # Terminate monitor process. + proc = psutil.Process(pid_monitor) + proc.terminate() + assert proc.wait(120) in (0, None) + + self._check_processes() + except Exception: + # List all logs + subprocess.Popen(["ls", "-lah", tmpdir]).wait() + # Dump all logs + subprocess.Popen(["bash", "-c", f"ls {tmpdir}/* | xargs -n 1 -t cat"]).wait() + raise + + # Patch for causing webserver timeout + @mock.patch("airflow.cli.commands.webserver.GunicornMonitor._get_num_workers_running", return_value=0) + def test_cli_webserver_shutdown_when_gunicorn_master_is_killed(self, _): + # Shorten timeout so that this test doesn't take too long time + with conf_vars({('webserver', 'web_server_master_timeout'): '10'}): + response = self.runner.invoke(webserver.webserver) + + assert response.exit_code == 1 + + def test_cli_webserver_debug(self, app): + with mock.patch.object(www_app, 'create_app') as create_app, mock.patch.object(app, 'run') as app_run: + create_app.return_value = app + + self.runner.invoke(webserver.webserver, ['--debug']) + + app_run.assert_called_with( + debug=True, + use_reloader=False, + port=8080, + host='0.0.0.0', + ssl_context=None, + ) + + def test_cli_webserver_args(self): + with mock.patch("subprocess.Popen") as Popen, mock.patch.object(webserver, 'GunicornMonitor'): + response = self.runner.invoke( + webserver.webserver, ['--access-logformat', 'custom_log_format', '--pid', '/tmp/x.pid'] + ) + + assert response.exit_code == 0 + + Popen.assert_called_with( + [ + sys.executable, + '-m', + 'gunicorn', + '--workers', + '4', + '--worker-class', + 'sync', + '--timeout', + '120', + '--bind', + '0.0.0.0:8080', + '--name', + 'airflow-webserver', + '--pid', + '/tmp/x.pid', + '--config', + 'python:airflow.www.gunicorn_config', + '--access-logfile', + '-', + '--error-logfile', + '-', + '--access-logformat', + 'custom_log_format', + 'airflow.www.app:cached_app()', + ], + close_fds=True, + ) From dc98b37d778b4aaeac28ad3e854c31b8a1906dfb Mon Sep 17 00:00:00 2001 From: blag Date: Wed, 8 Jun 2022 15:25:04 -0600 Subject: [PATCH 30/34] Convert users command to use click --- airflow/cli/__main__.py | 1 + airflow/cli/commands/users.py | 421 +++++++++++++++++++++++++ tests/cli/commands/test_users.py | 515 +++++++++++++++++++++++++++++++ 3 files changed, 937 insertions(+) create mode 100644 airflow/cli/commands/users.py create mode 100644 tests/cli/commands/test_users.py diff --git a/airflow/cli/__main__.py b/airflow/cli/__main__.py index ac28d0fdd25ff..7b48af63864b4 100644 --- a/airflow/cli/__main__.py +++ b/airflow/cli/__main__.py @@ -26,6 +26,7 @@ from airflow.cli.commands import standalone # noqa: F401 from airflow.cli.commands import sync_perm # noqa: F401 from airflow.cli.commands import triggerer # noqa: F401 +from airflow.cli.commands import users # noqa: F401 from airflow.cli.commands import version # noqa: F401 from airflow.cli.commands import webserver # noqa: F401 diff --git a/airflow/cli/commands/users.py b/airflow/cli/commands/users.py new file mode 100644 index 0000000000000..0dd5da9efd9cb --- /dev/null +++ b/airflow/cli/commands/users.py @@ -0,0 +1,421 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""User sub-commands""" +import getpass +import json +import random +import re +import string +from typing import Any, Dict, List + +import rich_click as click +from marshmallow import Schema, fields, validate +from marshmallow.exceptions import ValidationError +from rich.console import Console + +from airflow.cli import airflow_cmd, click_output, click_verbose +from airflow.cli.simple_table import AirflowConsole +from airflow.utils import cli as cli_utils +from airflow.utils.cli import suppress_logs_and_warning_click_compatible +from airflow.www.app import cached_app + + +class UserSchema(Schema): + """user collection item schema""" + + id = fields.Int() + firstname = fields.Str(required=True) + lastname = fields.Str(required=True) + username = fields.Str(required=True) + email = fields.Email(required=True) + roles = fields.List(fields.Str, required=True, validate=validate.Length(min=1)) + + +click_email = click.option( + '-e', + '--email', + metavar="EMAIL", + help="Email of the user", +) +click_username = click.option( + '-u', + '--username', + metavar="USERNAME", + help="Username of the user", +) +click_role = click.option( + '-r', + '--role', + metavar="ROLE", + help=""" + Role of the user. + + Existing roles include: Admin, User, Op, Viewer, and Public. + """, +) + + +@airflow_cmd.group('users') +def users(): + """Commands for managing users""" + + +@users.command('list') +@click.pass_context +@click_output +@click_verbose +@suppress_logs_and_warning_click_compatible +def list_(ctx, output, verbose): + """Lists users at the command line""" + appbuilder = cached_app().appbuilder + users = appbuilder.sm.get_all_users() + fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles'] + + AirflowConsole().print_as( + data=users, output=output, mapper=lambda x: {f: x.__getattribute__(f) for f in fields} + ) + + +@users.command('create') +@click.pass_context +@click.option( + '-f', '--firstname', metavar='FIRSTNAME', required=True, type=str, help="First name of the user" +) +@click.option('-l', '--lastname', metavar='LASTNAME', required=True, type=str, help="Last name of the user") +@click.option('-e', '--email', metavar='EMAIL', required=True, type=str, help="Email of the user") +@click.option('-u', '--username', metavar='USERNAME', required=True, type=str, help="Username of the user") +@click.option( + '-p', + '--password', + metavar='PASSWORD', + is_flag=False, + flag_value=False, + help="Password of the user, required to create a user without --use-random-password", +) +@click.option( + '--use-random-password', + is_flag=True, + default=False, + help=( + "Do not prompt for password. Use random string instead. 
Required to create a user without --password" + ), +) +@click.option( + '-r', + '--role', + metavar='ROLE', + required=True, + help="Role of the user, pre-existing roles include Admin, User, Op, Viewer, and Public", +) +@cli_utils.action_cli(check_db=True) +def create(ctx, firstname, lastname, email, username, password, use_random_password, role): # noqa: D301 + """ + Creates new user in the DB + + \b + Example + To create a user with "Admin" role and username "admin": + + \b + airflow users create \\ + --username admin \\ + --firstname FIRST_NAME \\ + --lastname LAST_NAME \\ + --role Admin \\ + --email admin@example.org + """ + console = Console() + appbuilder = cached_app().appbuilder + role_ = appbuilder.sm.find_role(role) + if not role_: + valid_roles = appbuilder.sm.get_all_roles() + raise SystemExit(f'{role} is not a valid role. Valid roles are: {valid_roles}') + + if password and use_random_password: + raise SystemExit('You cannot specify both --password and --use-random-password') + + # Click's password_option isn't aware of the --use-random-password option, so we have to handle + # setting passwords manually + if use_random_password: + password_ = ''.join(random.choice(string.printable) for _ in range(16)) + elif password: + password_ = password + else: + password_ = getpass.getpass('Password:') + password_confirmation = getpass.getpass('Repeat for confirmation:') + if password_ != password_confirmation: + raise SystemExit('Passwords did not match') + + if appbuilder.sm.find_user(username): + console.print(f'{username} already exist in the db') + return + user = appbuilder.sm.add_user(username, firstname, lastname, email, role_, password_) + if user: + console.print(f'User "{username}" created with role "{role}"') + else: + raise SystemExit('Failed to create user') + + +def _find_user(username=None, email=None): + if not username and not email: + raise SystemExit('Missing args: must supply one of --username or --email') + + if username and email: + raise SystemExit('Conflicting args: must supply either --username or --email, but not both') + + appbuilder = cached_app().appbuilder + + user = appbuilder.sm.find_user(username=username, email=email) + if not user: + raise SystemExit(f'User "{username or email}" does not exist') + return user + + +@users.command('delete') +@click.pass_context +@click.option( + '-e', + '--email', + metavar='EMAIL', + help="Email of the user", +) +@click.option( + '-u', + '--username', + metavar='USERNAME', + help="Username of the user", +) +@cli_utils.action_cli +def delete(ctx, email, username): + """Deletes user from DB""" + user = _find_user(username=username, email=email) + + appbuilder = cached_app().appbuilder + + if appbuilder.sm.del_register_user(user): + print(f'User "{user.username}" deleted') + else: + raise SystemExit('Failed to delete user') + + +@users.command('add-role') +@click.pass_context +@click_email +@click_username +@click_role +@cli_utils.action_cli +def add_role(ctx, email, username, role): + """ + Grant a role to a user + + Exactly one of --email or --username must be specified. + """ + return users_manage_role(email, username, role, remove=False) + + +@users.command('remove-role') +@click.pass_context +@click_email +@click_username +@click_role +@cli_utils.action_cli +def remove_role(ctx, email, username, role): + """ + Revoke a role from a user + + Exactly one of --email or --username must be specified. 
+ """ + return users_manage_role(email, username, role, remove=True) + + +def users_manage_role(email, username, role, remove=False): + """Deletes or appends user roles""" + console = Console() + + user = _find_user(username=username, email=email) + + appbuilder = cached_app().appbuilder + + found_role = appbuilder.sm.find_role(role) + if not found_role: + valid_roles = appbuilder.sm.get_all_roles() + raise SystemExit(f'"{role}" is not a valid role. Valid roles are: {valid_roles}') + + if remove: + if found_role not in user.roles: + raise SystemExit(f'User "{user.username}" is not a member of role "{found_role}"') + + user.roles = [r for r in user.roles if r != found_role] + appbuilder.sm.update_user(user) + console.print(f'User "{user.username}" removed from role "{found_role}"') + else: + if found_role in user.roles: + raise SystemExit(f'User "{user.username}" is already a member of role "{found_role}"') + + user.roles.append(found_role) + appbuilder.sm.update_user(user) + console.print(f'User "{user.username}" added to role "{found_role}"') + + +@users.command('export') +@click.pass_context +@click.argument('FILEPATH', type=click.Path(exists=True)) +@cli_utils.action_cli +def export(ctx, filepath): # noqa: D301, D412 + """ + Exports all users to the json file + + Arguments: + + FILEPATH Export users from this JSON file. + + \b + Example format: + [ + { + "email": "foo@bar.org", + "firstname": "Jon", + "lastname": "Doe", + "roles": ["Public"], + "username": "jdoe" + } + ] + """ + console = Console() + + appbuilder = cached_app().appbuilder + all_users = appbuilder.sm.get_all_users() + fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles'] + + # In the User model the first and last name fields have underscores, + # but the corresponding parameters in the CLI don't + def remove_underscores(s): + return re.sub("_", "", s) + + users_ = [ + { + remove_underscores(field): user.__getattribute__(field) + if field != 'roles' + else [r.name for r in user.roles] + for field in fields + } + for user in all_users + ] + + with open(filepath, 'w') as file: + file.write(json.dumps(users_, sort_keys=True, indent=4)) + console.print(f"{len(users)} users successfully exported to {file.name}") + + +@users.command('import') +@click.pass_context +@click.argument('FILEPATH', type=click.Path(exists=True)) +@cli_utils.action_cli +def import_(ctx, filepath): # noqa: D301, D412 + """ + Imports users from a JSON file + + Arguments: + + FILEPATH Import users from this JSON file. + + \b + Example format: + [ + { + "email": "foo@bar.org", + "firstname": "Jon", + "lastname": "Doe", + "roles": ["Public"], + "username": "jdoe" + } + ] + """ + console = Console() + + users_list = None + try: + with open(filepath) as file: + users_list = json.loads(file.read()) + except ValueError as e: + raise SystemExit(f"File '{filepath}' is not valid JSON. Error: {e}") + + users_created, users_updated = _import_users(users_list, console=console) + if users_created: + console.print("Created the following users:") + for user in users_created: + console.print(f"\t{user}") + + if users_updated: + console.print("Updated the following users:") + for user in users_updated: + console.print(f"\t{user}") + + +def _import_users(users_list: List[Dict[str, Any]], console=None): + appbuilder = cached_app().appbuilder + users_created = [] + users_updated = [] + + try: + UserSchema(many=True).load(users_list) + except ValidationError as e: + msg = ["Error: Input file didn't pass validation. 
See below:"] + for row_num, failure in e.normalized_messages().items(): + msg.append(f'[Item {row_num}]') + for key, value in failure.items(): + msg.append(f'\t{key}: {value}') + raise SystemExit('\n'.join(msg)) + + for user in users_list: + + roles = [] + for rolename in user['roles']: + role = appbuilder.sm.find_role(rolename) + if not role: + valid_roles = appbuilder.sm.get_all_roles() + raise SystemExit(f'Error: "{rolename}" is not a valid role. Valid roles are: {valid_roles}') + + roles.append(role) + + existing_user = appbuilder.sm.find_user(email=user['email']) + if existing_user: + console.print(f"Found existing user with email '{user['email']}'") + if existing_user.username != user['username']: + raise SystemExit( + f"Error: Changing the username is not allowed - please delete and recreate the user with" + f" email {user['email']!r}" + ) + + existing_user.roles = roles + existing_user.first_name = user['firstname'] + existing_user.last_name = user['lastname'] + appbuilder.sm.update_user(existing_user) + users_updated.append(user['email']) + else: + console.print(f"Creating new user with email '{user['email']}'") + appbuilder.sm.add_user( + username=user['username'], + first_name=user['firstname'], + last_name=user['lastname'], + email=user['email'], + role=roles, + ) + + users_created.append(user['email']) + + return users_created, users_updated diff --git a/tests/cli/commands/test_users.py b/tests/cli/commands/test_users.py new file mode 100644 index 0000000000000..e5fd042ab00f0 --- /dev/null +++ b/tests/cli/commands/test_users.py @@ -0,0 +1,515 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
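The import path in users.py above validates the whole input list with the marshmallow UserSchema before touching the database, and exits with the per-item messages when validation fails. A small self-contained sketch of that validation step (the schema here simply mirrors the one defined above):

    from marshmallow import Schema, fields, validate
    from marshmallow.exceptions import ValidationError

    class UserSchema(Schema):
        id = fields.Int()
        firstname = fields.Str(required=True)
        lastname = fields.Str(required=True)
        username = fields.Str(required=True)
        email = fields.Email(required=True)
        roles = fields.List(fields.Str, required=True, validate=validate.Length(min=1))

    good = [{"firstname": "Jon", "lastname": "Doe", "username": "jdoe",
             "email": "foo@bar.org", "roles": ["Public"]}]
    bad = [{"firstname": "Jon", "lastname": "Doe", "username": "jdoe",
            "email": "foo@bar.org", "roles": []}]

    UserSchema(many=True).load(good)  # passes
    try:
        UserSchema(many=True).load(bad)
    except ValidationError as e:
        # e.g. {0: {'roles': ['Shorter than minimum length 1.']}}
        print(e.normalized_messages())

The test cases below assert against exactly these normalized messages.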
+ +import json +import os +import tempfile + +import pytest +from click.testing import CliRunner + +from airflow.cli.commands import users +from tests.test_utils.api_connexion_utils import delete_users + +TEST_USER1_EMAIL = 'test-user1@example.com' +TEST_USER2_EMAIL = 'test-user2@example.com' +TEST_USER3_EMAIL = 'test-user3@example.com' + + +def _does_user_belong_to_role(appbuilder, email, rolename): + user = appbuilder.sm.find_user(email=email) + role = appbuilder.sm.find_role(rolename) + if user and role: + return role in user.roles + + return False + + +class TestCliUsers: + @classmethod + def setup_class(cls): + cls.runner = CliRunner() + + @pytest.fixture(autouse=True) + def _set_attrs(self, app, dagbag, parser): + self.app = app + self.dagbag = dagbag + self.parser = parser + self.appbuilder = self.app.appbuilder + delete_users(app) + yield + delete_users(app) + + def test_cli_create_user_random_password(self): + response = self.runner.invoke( + users.create, + [ + '--username', + 'test1', + '--lastname', + 'doe', + '--firstname', + 'jon', + '--email', + 'jdoe@foo.com', + '--role', + 'Viewer', + '--use-random-password', + ], + ) + + assert response.exit_code == 0 + + def test_cli_create_user_supplied_password(self): + response = self.runner.invoke( + users.create, + [ + '--username', + 'test2', + '--lastname', + 'doe', + '--firstname', + 'jon', + '--email', + 'jdoe@apache.org', + '--role', + 'Viewer', + '--password', + 'test', + ], + ) + + assert response.exit_code == 0 + + def test_cli_create_user_typed_password(self): + response = self.runner.invoke( + users.create, + [ + '--username', + 'test2', + '--lastname', + 'doe', + '--firstname', + 'jon', + '--email', + 'jdoe@apache.org', + '--role', + 'Viewer', + '--password', + ], + input="testpw\ntestpw\n", + ) + + assert response.exit_code == 0 + + def test_cli_create_user_mistyped_confirm_password(self): + response = self.runner.invoke( + users.create, + [ + '--username', + 'test2', + '--lastname', + 'doe', + '--firstname', + 'jon', + '--email', + 'jdoe@apache.org', + '--role', + 'Viewer', + '--password', + ], + input="testpw\ntestp\n", + ) + + assert response.exit_code == 0 + + def test_cli_create_user_random_and_supplied_password(self): + response = self.runner.invoke( + users.create, + [ + '--username', + 'thisusershouldntexist', + '--lastname', + 'shouldntexist', + '--firstname', + 'thisuser', + '--email', + 'thisusershouldntexist@example.com', + '--role', + 'Viewer', + '--password', + 'test', + '--use-random-password', + ], + ) + + assert response.exit_code != 0 + assert "cannot specify both" in response.stdout + + def test_cli_delete_user(self): + response = self.runner.invoke( + users.create, + [ + '--username', + 'test3', + '--lastname', + 'doe', + '--firstname', + 'jon', + '--email', + 'jdoe@example.com', + '--role', + 'Viewer', + '--use-random-password', + ], + ) + + response = self.runner.invoke( + users.delete, + [ + '--username', + 'test3', + ], + ) + + assert response.exit_code == 0 + assert 'User "test3" deleted' in response.stdout + + def test_cli_delete_user_by_email(self): + self.runner.invoke( + users.create, + [ + '--username', + 'test4', + '--lastname', + 'doe', + '--firstname', + 'jon', + '--email', + 'jdoe2@example.com', + '--role', + 'Viewer', + '--use-random-password', + ], + ) + + response = self.runner.invoke( + users.delete, + [ + '--email', + 'jdoe2@example.com', + ], + ) + + assert response.exit_code == 0 + assert 'User "test4" deleted' in response.stdout + + @pytest.mark.parametrize( + 'args, 
raise_match', + [ + ( + [], + 'Missing args: must supply one of --username or --email', + ), + ( + [ + 'test_user_name99', + 'jdoe2@example.com', + ], + 'Conflicting args: must supply either --username or --email, but not both', + ), + ( + [ + 'test_user_name99', + ], + 'User "test_user_name99" does not exist', + ), + ( + [ + 'jode2@example.com', + ], + 'User "jode2@example.com" does not exist', + ), + ], + ) + def test_find_user_exceptions(self, args, raise_match): + with pytest.raises( + SystemExit, + match=raise_match, + ): + users._find_user(*args) + + def test_cli_list_users(self): + for i in range(0, 3): + self.runner.invoke( + users.create, + [ + '--username', + f'user{i}', + '--lastname', + 'doe', + '--firstname', + 'jon', + '--email', + f'jdoe+{i}@gmail.com', + '--role', + 'Viewer', + '--use-random-password', + ], + ) + + response = self.runner.invoke(users.list_) + + assert response.exit_code == 0 + + for i in range(0, 3): + assert f'user{i}' in response.stdout + + def test_cli_list_users_with_args(self): + response = self.runner.invoke(users.list_, ['--output', 'json']) + + assert response.exit_code == 0 + + def test_cli_import_users(self): + def assert_user_in_roles(email, roles): + for role in roles: + assert _does_user_belong_to_role(self.appbuilder, email, role) + + def assert_user_not_in_roles(email, roles): + for role in roles: + assert not _does_user_belong_to_role(self.appbuilder, email, role) + + assert_user_not_in_roles(TEST_USER1_EMAIL, ['Admin', 'Op']) + assert_user_not_in_roles(TEST_USER2_EMAIL, ['Public']) + users_ = [ + { + "username": "imported_user1", + "lastname": "doe1", + "firstname": "jon", + "email": TEST_USER1_EMAIL, + "roles": ["Admin", "Op"], + }, + { + "username": "imported_user2", + "lastname": "doe2", + "firstname": "jon", + "email": TEST_USER2_EMAIL, + "roles": ["Public"], + }, + ] + self._import_users_from_file(users_) + + assert_user_in_roles(TEST_USER1_EMAIL, ['Admin', 'Op']) + assert_user_in_roles(TEST_USER2_EMAIL, ['Public']) + + users_ = [ + { + "username": "imported_user1", + "lastname": "doe1", + "firstname": "jon", + "email": TEST_USER1_EMAIL, + "roles": ["Public"], + }, + { + "username": "imported_user2", + "lastname": "doe2", + "firstname": "jon", + "email": TEST_USER2_EMAIL, + "roles": ["Admin"], + }, + ] + self._import_users_from_file(users_) + + assert_user_not_in_roles(TEST_USER1_EMAIL, ['Admin', 'Op']) + assert_user_in_roles(TEST_USER1_EMAIL, ['Public']) + assert_user_not_in_roles(TEST_USER2_EMAIL, ['Public']) + assert_user_in_roles(TEST_USER2_EMAIL, ['Admin']) + + def test_cli_export_users(self): + user1 = { + "username": "imported_user1", + "lastname": "doe1", + "firstname": "jon", + "email": TEST_USER1_EMAIL, + "roles": ["Public"], + } + user2 = { + "username": "imported_user2", + "lastname": "doe2", + "firstname": "jon", + "email": TEST_USER2_EMAIL, + "roles": ["Admin"], + } + self._import_users_from_file([user1, user2]) + + users_filename = self._export_users_to_file() + with open(users_filename) as file: + retrieved_users = json.loads(file.read()) + os.remove(users_filename) + + # ensure that an export can be imported + self._import_users_from_file(retrieved_users) + + def find_by_username(username): + matches = [u for u in retrieved_users if u['username'] == username] + assert matches, ( + f"Couldn't find user with username {username} in: " + f"[{', '.join([u['username'] for u in retrieved_users])}]" + ) + matches[0].pop('id') # this key not required for import + return matches[0] + + assert find_by_username('imported_user1') 
== user1 + assert find_by_username('imported_user2') == user2 + + def _import_users_from_file(self, user_list): + json_file_content = json.dumps(user_list) + with tempfile.NamedTemporaryFile(delete=False) as f: + try: + f.write(json_file_content.encode()) + f.flush() + + response = self.runner.invoke(users.import_, [f.name]) + finally: + os.remove(f.name) + return response + + def _export_users_to_file(self): + with tempfile.NamedTemporaryFile(delete=False) as f: + self.runner.invoke(users.export, [f.name]) + return f.name + + @pytest.fixture() + def create_user_test4(self): + self.runner.invoke( + users.create, + [ + '--username', + 'test4', + '--lastname', + 'doe', + '--firstname', + 'jon', + '--email', + TEST_USER1_EMAIL, + '--role', + 'Viewer', + '--use-random-password', + ], + ) + + def test_cli_add_user_role(self, create_user_test4): + assert not _does_user_belong_to_role( + appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Op' + ), "User should not yet be a member of role 'Op'" + + response = self.runner.invoke(users.add_role, ['--username', 'test4', '--role', 'Op']) + + assert response.exit_code == 0 + assert 'User "test4" added to role "Op"' in response.stdout + + assert _does_user_belong_to_role( + appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Op' + ), "User should have been added to role 'Op'" + + def test_cli_remove_user_role(self, create_user_test4): + assert _does_user_belong_to_role( + appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Viewer' + ), "User should have been created with role 'Viewer'" + + response = self.runner.invoke(users.remove_role, ['--username', 'test4', '--role', 'Viewer']) + + assert response.exit_code == 0 + assert 'User "test4" removed from role "Viewer"' in response.stdout + + assert not _does_user_belong_to_role( + appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Viewer' + ), "User should have been removed from role 'Viewer'" + + @pytest.mark.parametrize( + "action, role, message", + [ + ["add", "Viewer", 'User "test4" is already a member of role "Viewer"'], + ["add", "Foo", '"Foo" is not a valid role. Valid roles are'], + ["remove", "Admin", 'User "test4" is not a member of role "Admin"'], + ["remove", "Foo", '"Foo" is not a valid role. Valid roles are'], + ], + ) + def test_cli_manage_roles_exceptions(self, create_user_test4, action, role, message): + args = ['--username', 'test4', '--role', role] + if action == 'add': + response = self.runner.invoke(users.add_role, args) + else: + response = self.runner.invoke(users.remove_role, args) + + assert response.exit_code != 0 + assert message in response.stdout + + @pytest.mark.parametrize( + "user, message", + [ + [ + { + "username": "imported_user1", + "lastname": "doe1", + "firstname": "john", + "email": TEST_USER1_EMAIL, + "roles": "This is not a list", + }, + "Error: Input file didn't pass validation. See below:\n" + "[Item 0]\n" + "\troles: ['Not a valid list.']", + ], + [ + { + "username": "imported_user2", + "lastname": "doe2", + "firstname": "jon", + "email": TEST_USER2_EMAIL, + "roles": [], + }, + "Error: Input file didn't pass validation. See below:\n" + "[Item 0]\n" + "\troles: ['Shorter than minimum length 1.']", + ], + [ + { + "username1": "imported_user3", + "lastname": "doe3", + "firstname": "jon", + "email": TEST_USER3_EMAIL, + "roles": ["Test"], + }, + "Error: Input file didn't pass validation. 
See below:\n" + "[Item 0]\n" + "\tusername: ['Missing data for required field.']\n" + "\tusername1: ['Unknown field.']", + ], + [ + "Wrong input", + "Error: Input file didn't pass validation. See below:\n" + "[Item 0]\n" + "\t_schema: ['Invalid input type.']", + ], + ], + ids=["Incorrect roles", "Empty roles", "Required field is missing", "Wrong input"], + ) + def test_cli_import_users_exceptions(self, user, message): + response = self._import_users_from_file([user]) + + assert response.exit_code != 0 + assert message in response.stdout From 771262b32c5f216aecc39a8ffb7dbee00e62fd2b Mon Sep 17 00:00:00 2001 From: blag Date: Wed, 8 Jun 2022 15:25:51 -0600 Subject: [PATCH 31/34] Add rich-click to dependencies --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index ca825be210f8d..1d6a4302f9e2e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -156,6 +156,7 @@ install_requires = python-nvd3>=0.15.0 python-slugify>=5.0 rich>=12.4.4 + rich-click>=1.3.1 setproctitle>=1.1.8 # SQL Alchemy 1.4.10 introduces a bug where for PyODBC driver UTCDateTime fields get wrongly converted # as string and fail to be converted back to datetime. It was supposed to be fixed in From 62e7b11e1d54e9cc0c761b412827c812c5752dc3 Mon Sep 17 00:00:00 2001 From: hankehly Date: Mon, 13 Jun 2022 09:55:17 +0900 Subject: [PATCH 32/34] Port jobs command and unit tests to click --- airflow/cli/__main__.py | 1 + airflow/cli/commands/jobs.py | 82 ++++++++++++++++++++ tests/cli/commands/test_jobs.py | 129 ++++++++++++++++++++++++++++++++ 3 files changed, 212 insertions(+) create mode 100644 airflow/cli/commands/jobs.py create mode 100644 tests/cli/commands/test_jobs.py diff --git a/airflow/cli/__main__.py b/airflow/cli/__main__.py index 7b48af63864b4..a559530b7af3d 100644 --- a/airflow/cli/__main__.py +++ b/airflow/cli/__main__.py @@ -22,6 +22,7 @@ from airflow.cli.commands import cheat_sheet # noqa: F401 from airflow.cli.commands import db # noqa: F401 from airflow.cli.commands import info # noqa: F401 +from airflow.cli.commands import jobs # noqa: F401 from airflow.cli.commands import scheduler # noqa: F401 from airflow.cli.commands import standalone # noqa: F401 from airflow.cli.commands import sync_perm # noqa: F401 diff --git a/airflow/cli/commands/jobs.py b/airflow/cli/commands/jobs.py new file mode 100644 index 0000000000000..0e178f92354b5 --- /dev/null +++ b/airflow/cli/commands/jobs.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
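Registering the new group only requires importing the module: the noqa: F401 import added to airflow/cli/__main__.py above is what causes the decorators below to attach jobs (and its check subcommand) to the root airflow_cmd group. Condensed into one self-contained sketch, with the real options, epilog text, and database session handling left out:

    import rich_click as click

    @click.group(context_settings={'help_option_names': ['-h', '--help']})
    def airflow_cmd():
        pass

    @airflow_cmd.group("jobs")
    def jobs():
        """Manage jobs"""

    @jobs.command("check")
    @click.option("--limit", default=1, type=click.IntRange(min=0))
    def check(limit):
        print(f"limit={limit}")

    if __name__ == '__main__':
        airflow_cmd()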
+ +from typing import List + +import rich_click as click + +from airflow.cli import airflow_cmd +from airflow.jobs.base_job import BaseJob +from airflow.utils.session import provide_session +from airflow.utils.state import State + + +@airflow_cmd.group("jobs") +def jobs(): + """Manage jobs""" + + +@jobs.command("check") +@click.option( + "--job-type", + type=click.Choice({"BackfillJob", "LocalTaskJob", "SchedulerJob", "TriggererJob"}), + help="The type of job(s) that will be checked. By default all job types are checked.", +) +@click.option( + "--hostname", metavar="HOSTNAME", default=None, help="The hostname of job(s) that will be checked." +) +@click.option( + "--limit", + default=1, + type=click.IntRange(min=0, max=None), + help="The number of recent jobs that will be checked. To disable limit, set 0.", +) +@click.option( + "--allow-multiple", + is_flag=True, + default=False, + help="If passed, this command will be successful even if multiple matching alive jobs are found.", +) +@provide_session +def check(job_type: str, hostname: str, limit: int, allow_multiple: bool, session=None): + """Checks if job(s) are still alive""" + if allow_multiple and not limit > 1: + raise SystemExit("To use option --allow-multiple, you must set the limit to a value greater than 1.") + query = ( + session.query(BaseJob) + .filter(BaseJob.state == State.RUNNING) + .order_by(BaseJob.latest_heartbeat.desc()) + ) + if job_type: + query = query.filter(BaseJob.job_type == job_type) + if hostname: + query = query.filter(BaseJob.hostname == hostname) + if limit > 0: + query = query.limit(limit) + + jobs: List[BaseJob] = query.all() + alive_jobs = [job for job in jobs if job.is_alive()] + + count_alive_jobs = len(alive_jobs) + if count_alive_jobs == 0: + raise SystemExit("No alive jobs found.") + if count_alive_jobs > 1 and not allow_multiple: + raise SystemExit(f"Found {count_alive_jobs} alive jobs. Expected only one.") + if count_alive_jobs == 1: + print("Found one alive job.") + else: + print(f"Found {count_alive_jobs} alive jobs.") diff --git a/tests/cli/commands/test_jobs.py b/tests/cli/commands/test_jobs.py new file mode 100644 index 0000000000000..eb66581ab7d92 --- /dev/null +++ b/tests/cli/commands/test_jobs.py @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
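The check command above boils down to a query for RUNNING BaseJob rows ordered by latest heartbeat, an is_alive() filter, and three exit conditions. Distilled into a pure function with a hypothetical name, used only to make the conditions easy to see (the real command raises SystemExit, which the tests below detect via result.exception):

    def summarize_alive_jobs(count_alive_jobs: int, allow_multiple: bool = False) -> str:
        # Mirrors the tail of the `jobs check` callback above.
        if count_alive_jobs == 0:
            raise SystemExit("No alive jobs found.")
        if count_alive_jobs > 1 and not allow_multiple:
            raise SystemExit(f"Found {count_alive_jobs} alive jobs. Expected only one.")
        if count_alive_jobs == 1:
            return "Found one alive job."
        return f"Found {count_alive_jobs} alive jobs."

    assert summarize_alive_jobs(1) == "Found one alive job."
    assert summarize_alive_jobs(3, allow_multiple=True) == "Found 3 alive jobs."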
+import unittest + +import pytest +from click.testing import CliRunner + +from airflow.cli.commands import jobs +from airflow.jobs.scheduler_job import SchedulerJob +from airflow.utils.session import create_session +from airflow.utils.state import State +from tests.test_utils.db import clear_db_jobs + + +class TestCliConfigList(unittest.TestCase): + def setUp(self) -> None: + clear_db_jobs() + self.scheduler_job = None + + def tearDown(self) -> None: + if self.scheduler_job and self.scheduler_job.processor_agent: + self.scheduler_job.processor_agent.end() + clear_db_jobs() + + def test_should_report_success_for_one_working_scheduler(self): + with create_session() as session: + self.scheduler_job = SchedulerJob() + self.scheduler_job.state = State.RUNNING + session.add(self.scheduler_job) + session.commit() + self.scheduler_job.heartbeat() + + runner = CliRunner() + result = runner.invoke(jobs.check, ["--job-type", "SchedulerJob"]) + self.assertIn("Found one alive job.", result.output) + + def test_should_report_success_for_one_working_scheduler_with_hostname(self): + with create_session() as session: + self.scheduler_job = SchedulerJob() + self.scheduler_job.state = State.RUNNING + self.scheduler_job.hostname = 'HOSTNAME' + session.add(self.scheduler_job) + session.commit() + self.scheduler_job.heartbeat() + + runner = CliRunner() + result = runner.invoke(jobs.check, ["--job-type", "SchedulerJob", "--hostname", "HOSTNAME"]) + self.assertIn("Found one alive job.", result.output) + + def test_should_report_success_for_ha_schedulers(self): + scheduler_jobs = [] + with create_session() as session: + for _ in range(3): + scheduler_job = SchedulerJob() + scheduler_job.state = State.RUNNING + session.add(scheduler_job) + scheduler_jobs.append(scheduler_job) + session.commit() + scheduler_job.heartbeat() + + runner = CliRunner() + result = runner.invoke( + jobs.check, ["--job-type", "SchedulerJob", "--limit", "100", "--allow-multiple"] + ) + self.assertIn("Found 3 alive jobs.", result.output) + for scheduler_job in scheduler_jobs: + if scheduler_job.processor_agent: + scheduler_job.processor_agent.end() + + def test_should_ignore_not_running_jobs(self): + scheduler_jobs = [] + with create_session() as session: + for _ in range(3): + scheduler_job = SchedulerJob() + scheduler_job.state = State.SHUTDOWN + session.add(scheduler_job) + scheduler_jobs.append(scheduler_job) + session.commit() + # No alive jobs found. + runner = CliRunner() + result = runner.invoke(jobs.check) + assert isinstance(result.exception, SystemExit) + assert "No alive jobs found." in result.output + for scheduler_job in scheduler_jobs: + if scheduler_job.processor_agent: + scheduler_job.processor_agent.end() + + def test_should_raise_exception_for_multiple_scheduler_on_one_host(self): + scheduler_jobs = [] + with create_session() as session: + for _ in range(3): + scheduler_job = SchedulerJob() + scheduler_job.state = State.RUNNING + scheduler_job.hostname = 'HOSTNAME' + session.add(scheduler_job) + session.commit() + scheduler_job.heartbeat() + + runner = CliRunner() + result = runner.invoke(jobs.check, ["--job-type", "SchedulerJob", "--limit", "100"]) + assert isinstance(result.exception, SystemExit) + assert "Found 3 alive jobs. Expected only one." 
in result.output + + for scheduler_job in scheduler_jobs: + if scheduler_job.processor_agent: + scheduler_job.processor_agent.end() + + def test_should_raise_exception_for_allow_multiple_and_limit_1(self): + runner = CliRunner() + result = runner.invoke(jobs.check, ["--allow-multiple"]) + assert isinstance(result.exception, SystemExit) + assert ( + "To use option --allow-multiple, you must set the limit to a value greater than 1." + in result.output + ) From 572ffecb94aedc0c4be8ae2890451dd27a8baf24 Mon Sep 17 00:00:00 2001 From: hankehly Date: Wed, 15 Jun 2022 08:52:00 +0900 Subject: [PATCH 33/34] Add jobs group command epilog to help text --- airflow/cli/commands/jobs.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/airflow/cli/commands/jobs.py b/airflow/cli/commands/jobs.py index 0e178f92354b5..069785aa9b364 100644 --- a/airflow/cli/commands/jobs.py +++ b/airflow/cli/commands/jobs.py @@ -27,7 +27,18 @@ @airflow_cmd.group("jobs") def jobs(): - """Manage jobs""" + """Manage jobs + + \b + Examples: + To check if the local scheduler is still working properly, run: + \b + $ airflow jobs check --job-type SchedulerJob --hostname "$(hostname)" + \b + To check if any scheduler is running when you are using high availability, run: + \b + $ airflow jobs check --job-type SchedulerJob --allow-multiple --limit 100 + """ @jobs.command("check") From 4d2502b4cd63de2b5f245aeee202bc6bc6cb9054 Mon Sep 17 00:00:00 2001 From: hankehly Date: Wed, 15 Jun 2022 08:56:47 +0900 Subject: [PATCH 34/34] Move jobs-check epilog text to appropriate function --- airflow/cli/commands/jobs.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/airflow/cli/commands/jobs.py b/airflow/cli/commands/jobs.py index 069785aa9b364..88a79e5a1bfd1 100644 --- a/airflow/cli/commands/jobs.py +++ b/airflow/cli/commands/jobs.py @@ -27,18 +27,7 @@ @airflow_cmd.group("jobs") def jobs(): - """Manage jobs - - \b - Examples: - To check if the local scheduler is still working properly, run: - \b - $ airflow jobs check --job-type SchedulerJob --hostname "$(hostname)" - \b - To check if any scheduler is running when you are using high availability, run: - \b - $ airflow jobs check --job-type SchedulerJob --allow-multiple --limit 100 - """ + """Manage jobs""" @jobs.command("check") @@ -64,7 +53,18 @@ def jobs(): ) @provide_session def check(job_type: str, hostname: str, limit: int, allow_multiple: bool, session=None): - """Checks if job(s) are still alive""" + """Checks if job(s) are still alive + + \b + examples: + To check if the local scheduler is still working properly, run: + \b + $ airflow jobs check --job-type SchedulerJob --hostname "$(hostname)" + \b + To check if any scheduler is running when you are using high availability, run: + \b + $ airflow jobs check --job-type SchedulerJob --allow-multiple --limit 100 + """ if allow_multiple and not limit > 1: raise SystemExit("To use option --allow-multiple, you must set the limit to a value greater than 1.") query = (