From b60a73fa748099a67d4e5edb3deae998ceeef83b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Fri, 15 Nov 2019 18:11:27 +0100 Subject: [PATCH 01/23] [AIRFLOW-YYY] Lazy load API Client --- airflow/api/client/__init__.py | 20 ++++++++++++++++++++ airflow/bin/cli.py | 16 +++++++++++----- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/airflow/api/client/__init__.py b/airflow/api/client/__init__.py index 114d189da14ab..f3ecb812aa9cb 100644 --- a/airflow/api/client/__init__.py +++ b/airflow/api/client/__init__.py @@ -16,3 +16,23 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +API Client that allows interact with Airflow API +""" +from importlib import import_module +from typing import Any + +from airflow import api, conf +from airflow.api.client.api_client import Client + + +def get_current_api_client() -> Client: + """ + Return current API Client depends on current Airflow configuration + """ + api_module = import_module(conf.get('cli', 'api_client')) # type: Any + api_client = api_module.Client( + api_base_url=conf.get('cli', 'endpoint_url'), + auth=api.API_AUTH.api_auth.CLIENT_AUTH + ) + return api_client diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 49202beedac5f..64bf3a8bce438 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -39,8 +39,6 @@ import time import traceback from argparse import RawTextHelpFormatter -from importlib import import_module -from typing import Any from urllib.parse import urlunparse import daemon @@ -51,6 +49,7 @@ import airflow from airflow import api, jobs, settings +from airflow.api.client import get_current_api_client from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowWebServerTimeout from airflow.executors import get_default_executor @@ -64,9 +63,6 @@ from airflow.www.app import cached_app, cached_appbuilder, create_app api.load_auth() -api_module = import_module(conf.get('cli', 'api_client')) # type: Any -api_client = api_module.Client(api_base_url=conf.get('cli', 'endpoint_url'), - auth=api.API_AUTH.api_auth.CLIENT_AUTH) LOG = LoggingMixin().log @@ -231,6 +227,7 @@ def trigger_dag(args): :param args: :return: """ + api_client = get_current_api_client() log = LoggingMixin().log try: message = api_client.trigger_dag(dag_id=args.dag_id, @@ -251,6 +248,7 @@ def delete_dag(args): :param args: :return: """ + api_client = get_current_api_client() log = LoggingMixin().log if args.yes or input( "This will drop all existing records related to the specified DAG. 
" @@ -272,6 +270,7 @@ def _tabulate_pools(pools, tablefmt="fancy_grid"): def pool_list(args): """Displays info of all the pools""" + api_client = get_current_api_client() log = LoggingMixin().log pools = api_client.get_pools() log.info(_tabulate_pools(pools=pools, tablefmt=args.output)) @@ -279,6 +278,7 @@ def pool_list(args): def pool_get(args): """Displays pool info by a given name""" + api_client = get_current_api_client() log = LoggingMixin().log pools = [api_client.get_pool(name=args.pool)] log.info(_tabulate_pools(pools=pools, tablefmt=args.output)) @@ -287,6 +287,7 @@ def pool_get(args): @cli_utils.action_logging def pool_set(args): """Creates new pool with a given name and slots""" + api_client = get_current_api_client() log = LoggingMixin().log pools = [api_client.create_pool(name=args.pool, slots=args.slots, @@ -297,6 +298,7 @@ def pool_set(args): @cli_utils.action_logging def pool_delete(args): """Deletes pool by a given name""" + api_client = get_current_api_client() log = LoggingMixin().log pools = [api_client.delete_pool(name=args.pool)] log.info(_tabulate_pools(pools=pools, tablefmt=args.output)) @@ -305,6 +307,7 @@ def pool_delete(args): @cli_utils.action_logging def pool_import(args): """Imports pools from the file""" + api_client = get_current_api_client() log = LoggingMixin().log if os.path.exists(args.file): pools = pool_import_helper(args.file) @@ -323,6 +326,8 @@ def pool_export(args): def pool_import_helper(filepath): """Helps import pools from the json file""" + api_client = get_current_api_client() + with open(filepath, 'r') as poolfile: data = poolfile.read() try: # pylint: disable=too-many-nested-blocks @@ -350,6 +355,7 @@ def pool_import_helper(filepath): def pool_export_helper(filepath): """Helps export all of the pools to the json file""" + api_client = get_current_api_client() pool_dict = {} pools = api_client.get_pools() for pool in pools: From 2a7c7d6286b4190dec17b3cc2bc2f085e291d684 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Fri, 15 Nov 2019 18:25:15 +0100 Subject: [PATCH 02/23] [AIRFLOW-YYY] Introduce order in CLI's function names --- airflow/bin/cli.py | 60 +++++++++++++-------------- tests/cli/test_cli.py | 72 ++++++++++++++++----------------- tests/jobs/test_backfill_job.py | 6 +-- 3 files changed, 69 insertions(+), 69 deletions(-) diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 64bf3a8bce438..41032837a34df 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -158,7 +158,7 @@ def get_dags(args): @cli_utils.action_logging -def backfill(args, dag=None): +def dag_backfill(args, dag=None): """Creates backfill job or dry run for a DAG""" logging.basicConfig( level=settings.LOGGING_LEVEL, @@ -220,7 +220,7 @@ def backfill(args, dag=None): @cli_utils.action_logging -def trigger_dag(args): +def dag_trigger(args): """ Creates a dag run for the specified dag @@ -241,7 +241,7 @@ def trigger_dag(args): @cli_utils.action_logging -def delete_dag(args): +def dag_delete(args): """ Deletes all DB records related to the specified dag @@ -454,13 +454,13 @@ def variable_export_helper(filepath): @cli_utils.action_logging -def pause(args): +def dag_pause(args): """Pauses a DAG""" set_is_paused(True, args) @cli_utils.action_logging -def unpause(args): +def dag_unpause(args): """Unpauses a DAG""" set_is_paused(False, args) @@ -474,7 +474,7 @@ def set_is_paused(is_paused, args): print("Dag: {}, paused: {}".format(args.dag_id, str(is_paused))) -def show_dag(args): +def dag_show(args): """Displays DAG or saves it's graphic 
representation to the file""" dag = get_dag(args) dot = render_dag(dag) @@ -554,7 +554,7 @@ def _run(args, dag, ti): @cli_utils.action_logging -def run(args, dag=None): +def task_run(args, dag=None): """Runs a single task instance""" if dag: args.dag_id = dag.dag_id @@ -658,7 +658,7 @@ def dag_state(args): @cli_utils.action_logging -def next_execution(args): +def dag_next_execution(args): """ Returns the next execution datetime of a DAG at the command line. >>> airflow dags next_execution tutorial @@ -694,7 +694,7 @@ def rotate_fernet_key(args): @cli_utils.action_logging -def list_dags(args): +def dag_list_dags(args): """Displays dags with or without stats at the command line""" dagbag = DagBag(process_subdir(args.subdir)) list_template = textwrap.dedent("""\n @@ -710,7 +710,7 @@ def list_dags(args): @cli_utils.action_logging -def list_tasks(args, dag=None): +def task_list(args, dag=None): """Lists the tasks within a DAG at the command line""" dag = dag or get_dag(args) if args.tree: @@ -721,7 +721,7 @@ def list_tasks(args, dag=None): @cli_utils.action_logging -def list_jobs(args, dag=None): +def dag_list_jobs(args, dag=None): """Lists latest n jobs""" queries = [] if dag: @@ -753,7 +753,7 @@ def list_jobs(args, dag=None): @cli_utils.action_logging -def test(args, dag=None): +def task_test(args, dag=None): """Tests task for a given dag_id""" # We want log outout from operators etc to show up here. Normally # airflow.task would redirect to a file, but here we want it to propagate @@ -786,7 +786,7 @@ def test(args, dag=None): @cli_utils.action_logging -def render(args): +def task_render(args): """Renders and displays templated fields for a given task""" dag = get_dag(args) task = dag.get_task(task_id=args.task_id) @@ -802,7 +802,7 @@ def render(args): @cli_utils.action_logging -def clear(args): +def task_clear(args): """Clears all task instances or only those matched by regex for a DAG(s)""" logging.basicConfig( level=settings.LOGGING_LEVEL, @@ -1659,7 +1659,7 @@ def roles_create(args): @cli_utils.action_logging -def list_dag_runs(args, dag=None): +def dag_list_dag_runs(args, dag=None): """Lists dag runs for a given DAG""" if dag: args.dag_id = dag.dag_id @@ -2255,13 +2255,13 @@ class CLIFactory: 'name': 'dags', 'subcommands': ( { - 'func': list_dags, + 'func': dag_list_dags, 'name': 'list', 'help': "List all the DAGs", 'args': ('subdir', 'report'), }, { - 'func': list_dag_runs, + 'func': dag_list_dag_runs, 'name': 'list_runs', 'help': "List dag runs given a DAG id. If state option is given, it will only " "search for all the dagruns with the given state. 
" @@ -2270,7 +2270,7 @@ class CLIFactory: 'args': ('dag_id', 'no_backfill', 'state'), }, { - 'func': list_jobs, + 'func': dag_list_jobs, 'name': 'list_jobs', 'help': "List the jobs", 'args': ('dag_id_opt', 'state', 'limit', 'output',), @@ -2282,43 +2282,43 @@ class CLIFactory: 'args': ('dag_id', 'execution_date', 'subdir'), }, { - 'func': next_execution, + 'func': dag_next_execution, 'name': 'next_execution', 'help': "Get the next execution datetime of a DAG.", 'args': ('dag_id', 'subdir'), }, { - 'func': pause, + 'func': dag_pause, 'name': 'pause', 'help': 'Pause a DAG', 'args': ('dag_id', 'subdir'), }, { - 'func': unpause, + 'func': dag_unpause, 'name': 'unpause', 'help': 'Resume a paused DAG', 'args': ('dag_id', 'subdir'), }, { - 'func': trigger_dag, + 'func': dag_trigger, 'name': 'trigger', 'help': 'Trigger a DAG run', 'args': ('dag_id', 'subdir', 'run_id', 'conf', 'exec_date'), }, { - 'func': delete_dag, + 'func': dag_delete, 'name': 'delete', 'help': "Delete all DB records related to the specified DAG", 'args': ('dag_id', 'yes'), }, { - 'func': show_dag, + 'func': dag_show, 'name': 'show', 'help': "Displays DAG's tasks with their dependencies", 'args': ('dag_id', 'subdir', 'save', 'imgcat',), }, { - 'func': backfill, + 'func': dag_backfill, 'name': 'backfill', 'help': "Run subsections of a DAG for a specified date range. " "If reset_dag_run option is used," @@ -2342,13 +2342,13 @@ class CLIFactory: 'name': 'tasks', 'subcommands': ( { - 'func': list_tasks, + 'func': task_list, 'name': 'list', 'help': "List the tasks within a DAG", 'args': ('dag_id', 'tree', 'subdir'), }, { - 'func': clear, + 'func': task_clear, 'name': 'clear', 'help': "Clear a set of task instance, as if they never ran", 'args': ( @@ -2373,13 +2373,13 @@ class CLIFactory: 'args': ('dag_id', 'task_id', 'execution_date', 'subdir'), }, { - 'func': render, + 'func': task_render, 'name': 'render', 'help': "Render a task instance's template(s)", 'args': ('dag_id', 'task_id', 'execution_date', 'subdir'), }, { - 'func': run, + 'func': task_run, 'name': 'run', 'help': "Run a single task instance", 'args': ( @@ -2389,7 +2389,7 @@ class CLIFactory: 'ignore_depends_on_past', 'ship_dag', 'pickle', 'job_id', 'interactive',), }, { - 'func': test, + 'func': task_test, 'name': 'test', 'help': ( "Test a task instance. 
This will run a task without checking for " diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 22d2f93dd29bd..2f36e7a5b3bba 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -35,7 +35,7 @@ import airflow.bin.cli as cli from airflow import DAG, AirflowException, models, settings -from airflow.bin.cli import get_dag, get_num_ready_workers_running, run +from airflow.bin.cli import get_dag, get_num_ready_workers_running, task_run from airflow.models import Connection, DagModel, Pool, TaskInstance, Variable from airflow.settings import Session from airflow.utils import db, timezone @@ -190,7 +190,7 @@ def test_local_run(self): reset(args.dag_id) with patch('argparse.Namespace', args) as mock_args: - run(mock_args) + task_run(mock_args) dag = get_dag(mock_args) task = dag.get_task(task_id=args.task_id) ti = TaskInstance(task, args.execution_date) @@ -211,7 +211,7 @@ def setUpClass(cls): @mock.patch("airflow.bin.cli.DAG.run") def test_backfill(self, mock_run): - cli.backfill(self.parser.parse_args([ + cli.dag_backfill(self.parser.parse_args([ 'dags', 'backfill', 'example_bash_operator', '-s', DEFAULT_DATE.isoformat()])) @@ -234,7 +234,7 @@ def test_backfill(self, mock_run): dag = self.dagbag.get_dag('example_bash_operator') with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: - cli.backfill(self.parser.parse_args([ + cli.dag_backfill(self.parser.parse_args([ 'dags', 'backfill', 'example_bash_operator', '-t', 'runme_0', '--dry_run', '-s', DEFAULT_DATE.isoformat()]), dag=dag) @@ -246,13 +246,13 @@ def test_backfill(self, mock_run): mock_run.assert_not_called() # Dry run shouldn't run the backfill - cli.backfill(self.parser.parse_args([ + cli.dag_backfill(self.parser.parse_args([ 'dags', 'backfill', 'example_bash_operator', '--dry_run', '-s', DEFAULT_DATE.isoformat()]), dag=dag) mock_run.assert_not_called() # Dry run shouldn't run the backfill - cli.backfill(self.parser.parse_args([ + cli.dag_backfill(self.parser.parse_args([ 'dags', 'backfill', 'example_bash_operator', '-l', '-s', DEFAULT_DATE.isoformat()]), dag=dag) @@ -276,7 +276,7 @@ def test_backfill(self, mock_run): def test_show_dag_print(self): temp_stdout = io.StringIO() with contextlib.redirect_stdout(temp_stdout): - cli.show_dag(self.parser.parse_args([ + cli.dag_show(self.parser.parse_args([ 'dags', 'show', 'example_bash_operator'])) out = temp_stdout.getvalue() self.assertIn("label=example_bash_operator", out) @@ -287,7 +287,7 @@ def test_show_dag_print(self): def test_show_dag_dave(self, mock_render_dag): temp_stdout = io.StringIO() with contextlib.redirect_stdout(temp_stdout): - cli.show_dag(self.parser.parse_args([ + cli.dag_show(self.parser.parse_args([ 'dags', 'show', 'example_bash_operator', '--save', 'awesome.png'] )) out = temp_stdout.getvalue() @@ -303,7 +303,7 @@ def test_show_dag_imgcat(self, mock_render_dag, mock_popen): mock_popen.return_value.communicate.return_value = (b"OUT", b"ERR") temp_stdout = io.StringIO() with contextlib.redirect_stdout(temp_stdout): - cli.show_dag(self.parser.parse_args([ + cli.dag_show(self.parser.parse_args([ 'dags', 'show', 'example_bash_operator', '--imgcat'] )) out = temp_stdout.getvalue() @@ -333,7 +333,7 @@ def test_cli_backfill_depends_on_past(self, mock_run): ] dag = self.dagbag.get_dag(dag_id) - cli.backfill(self.parser.parse_args(args), dag=dag) + cli.dag_backfill(self.parser.parse_args(args), dag=dag) mock_run.assert_called_once_with( start_date=run_date, @@ -373,7 +373,7 @@ def test_cli_backfill_depends_on_past_backwards(self, 
mock_run): ] dag = self.dagbag.get_dag(dag_id) - cli.backfill(self.parser.parse_args(args), dag=dag) + cli.dag_backfill(self.parser.parse_args(args), dag=dag) mock_run.assert_called_once_with( start_date=start_date, end_date=end_date, @@ -458,15 +458,15 @@ def reset_dr_db(dag_id): def test_cli_list_dags(self): args = self.parser.parse_args(['dags', 'list', '--report']) - cli.list_dags(args) + cli.dag_list_dags(args) def test_cli_list_dag_runs(self): - cli.trigger_dag(self.parser.parse_args([ + cli.dag_trigger(self.parser.parse_args([ 'dags', 'trigger', 'example_bash_operator', ])) args = self.parser.parse_args(['dags', 'list_runs', 'example_bash_operator', '--no_backfill']) - cli.list_dag_runs(args) + cli.dag_list_dag_runs(args) def test_cli_list_jobs_with_args(self): args = self.parser.parse_args(['dags', 'list_jobs', '--dag_id', @@ -474,26 +474,26 @@ def test_cli_list_jobs_with_args(self): '--state', 'success', '--limit', '100', '--output', 'tsv']) - cli.list_jobs(args) + cli.dag_list_jobs(args) def test_pause(self): args = self.parser.parse_args([ 'dags', 'pause', 'example_bash_operator']) - cli.pause(args) + cli.dag_pause(args) self.assertIn(self.dagbag.dags['example_bash_operator'].is_paused, [True, 1]) args = self.parser.parse_args([ 'dags', 'unpause', 'example_bash_operator']) - cli.unpause(args) + cli.dag_unpause(args) self.assertIn(self.dagbag.dags['example_bash_operator'].is_paused, [False, 0]) def test_trigger_dag(self): - cli.trigger_dag(self.parser.parse_args([ + cli.dag_trigger(self.parser.parse_args([ 'dags', 'trigger', 'example_bash_operator', '-c', '{"foo": "bar"}'])) self.assertRaises( ValueError, - cli.trigger_dag, + cli.dag_trigger, self.parser.parse_args([ 'dags', 'trigger', 'example_bash_operator', '--run_id', 'trigger_dag_xxx', @@ -506,12 +506,12 @@ def test_delete_dag(self): session = settings.Session() session.add(DM(dag_id=key)) session.commit() - cli.delete_dag(self.parser.parse_args([ + cli.dag_delete(self.parser.parse_args([ 'dags', 'delete', key, '--yes'])) self.assertEqual(session.query(DM).filter_by(dag_id=key).count(), 0) self.assertRaises( AirflowException, - cli.delete_dag, + cli.dag_delete, self.parser.parse_args([ 'dags', 'delete', 'does_not_exist_dag', @@ -527,13 +527,13 @@ def test_delete_dag_existing_file(self): with tempfile.NamedTemporaryFile() as f: session.add(DM(dag_id=key, fileloc=f.name)) session.commit() - cli.delete_dag(self.parser.parse_args([ + cli.dag_delete(self.parser.parse_args([ 'dags', 'delete', key, '--yes'])) self.assertEqual(session.query(DM).filter_by(dag_id=key).count(), 0) def test_cli_list_jobs(self): args = self.parser.parse_args(['dags', 'list_jobs']) - cli.list_jobs(args) + cli.dag_list_jobs(args) def test_dag_state(self): self.assertEqual(None, cli.dag_state(self.parser.parse_args([ @@ -882,11 +882,11 @@ def setUpClass(cls): def test_cli_list_tasks(self): for dag_id in self.dagbag.dags: args = self.parser.parse_args(['tasks', 'list', dag_id]) - cli.list_tasks(args) + cli.task_list(args) args = self.parser.parse_args([ 'tasks', 'list', 'example_bash_operator', '--tree']) - cli.list_tasks(args) + cli.task_list(args) def test_test(self): """Test the `airflow test` command""" @@ -900,7 +900,7 @@ def test_test(self): saved_stdout = sys.stdout try: sys.stdout = out = io.StringIO() - cli.test(args) + cli.task_test(args) output = out.getvalue() # Check that prints, and log messages, are shown @@ -927,7 +927,7 @@ def test_run_naive_taskinstance(self, mock_local_job): task0_id, naive_date.isoformat()] - 
cli.run(self.parser.parse_args(args0), dag=dag) + cli.task_run(self.parser.parse_args(args0), dag=dag) mock_local_job.assert_called_once_with( task_instance=mock.ANY, mark_success=False, @@ -940,23 +940,23 @@ def test_run_naive_taskinstance(self, mock_local_job): ) def test_cli_test(self): - cli.test(self.parser.parse_args([ + cli.task_test(self.parser.parse_args([ 'tasks', 'test', 'example_bash_operator', 'runme_0', DEFAULT_DATE.isoformat()])) - cli.test(self.parser.parse_args([ + cli.task_test(self.parser.parse_args([ 'tasks', 'test', 'example_bash_operator', 'runme_0', '--dry_run', DEFAULT_DATE.isoformat()])) def test_cli_test_with_params(self): - cli.test(self.parser.parse_args([ + cli.task_test(self.parser.parse_args([ 'tasks', 'test', 'example_passing_params_via_test_command', 'run_this', '-tp', '{"foo":"bar"}', DEFAULT_DATE.isoformat()])) - cli.test(self.parser.parse_args([ + cli.task_test(self.parser.parse_args([ 'tasks', 'test', 'example_passing_params_via_test_command', 'also_run_this', '-tp', '{"foo":"bar"}', DEFAULT_DATE.isoformat()])) def test_cli_run(self): - cli.run(self.parser.parse_args([ + cli.task_run(self.parser.parse_args([ 'tasks', 'run', 'example_bash_operator', 'runme_0', '-l', DEFAULT_DATE.isoformat()])) @@ -968,19 +968,19 @@ def test_task_state(self): def test_subdag_clear(self): args = self.parser.parse_args([ 'tasks', 'clear', 'example_subdag_operator', '--yes']) - cli.clear(args) + cli.task_clear(args) args = self.parser.parse_args([ 'tasks', 'clear', 'example_subdag_operator', '--yes', '--exclude_subdags']) - cli.clear(args) + cli.task_clear(args) def test_parentdag_downstream_clear(self): args = self.parser.parse_args([ 'tasks', 'clear', 'example_subdag_operator.section-1', '--yes']) - cli.clear(args) + cli.task_clear(args) args = self.parser.parse_args([ 'tasks', 'clear', 'example_subdag_operator.section-1', '--yes', '--exclude_parentdag']) - cli.clear(args) + cli.task_clear(args) def test_get_dags(self): dags = cli.get_dags(self.parser.parse_args(['tasks', 'clear', 'example_subdag_operator', diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index 671473532237c..36baeb38bc54e 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -827,7 +827,7 @@ def test_run_ignores_all_dependencies(self): dag_id, task0_id, DEFAULT_DATE.isoformat()] - cli.run(self.parser.parse_args(args0)) + cli.task_run(self.parser.parse_args(args0)) ti_dependent0 = TI( task=dag.get_task(task0_id), execution_date=DEFAULT_DATE) @@ -842,7 +842,7 @@ def test_run_ignores_all_dependencies(self): dag_id, task1_id, (DEFAULT_DATE + datetime.timedelta(days=1)).isoformat()] - cli.run(self.parser.parse_args(args1)) + cli.task_run(self.parser.parse_args(args1)) ti_dependency = TI( task=dag.get_task(task1_id), @@ -857,7 +857,7 @@ def test_run_ignores_all_dependencies(self): dag_id, task2_id, (DEFAULT_DATE + datetime.timedelta(days=1)).isoformat()] - cli.run(self.parser.parse_args(args2)) + cli.task_run(self.parser.parse_args(args2)) ti_dependent = TI( task=dag.get_task(task2_id), From c1d6b26b2a8b36bc39bcf5609ebb0cc66fdae2e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Fri, 15 Nov 2019 19:34:39 +0100 Subject: [PATCH 03/23] [AIRFLOW-YYY] Create cli package --- airflow/cli/__init__.py | 18 ++++++++++++++++++ airflow/cli/commands/__init__.py | 18 ++++++++++++++++++ docs/conf.py | 2 ++ 3 files changed, 38 insertions(+) create mode 100644 airflow/cli/__init__.py create mode 100644 airflow/cli/commands/__init__.py diff --git 
a/airflow/cli/__init__.py b/airflow/cli/__init__.py new file mode 100644 index 0000000000000..114d189da14ab --- /dev/null +++ b/airflow/cli/__init__.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/cli/commands/__init__.py b/airflow/cli/commands/__init__.py new file mode 100644 index 0000000000000..114d189da14ab --- /dev/null +++ b/airflow/cli/commands/__init__.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
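
Reviewer note on PATCH 01 above: moving the client construction out of module scope and into get_current_api_client() means the module named by the [cli] api_client option is imported only when a command actually calls the API, instead of on every airflow CLI invocation. A minimal standalone sketch of that lazy-import pattern follows; the config keys and the Client(api_base_url=..., auth=...) signature mirror the diff, while the module path and URL in the commented example are illustrative assumptions rather than Airflow defaults.

    from importlib import import_module

    def get_current_api_client(api_client_module, endpoint_url, auth=None):
        """Resolve and instantiate the configured API client on first use."""
        api_module = import_module(api_client_module)  # deferred import; nothing runs at CLI start-up
        return api_module.Client(api_base_url=endpoint_url, auth=auth)

    # Example (assumes Airflow is installed; local_client is the in-process client):
    # client = get_current_api_client('airflow.api.client.local_client',
    #                                 'http://localhost:8080')
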
diff --git a/docs/conf.py b/docs/conf.py index ef4a74154dce8..1bf37423f8e04 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -174,6 +174,8 @@ '_api/airflow/_vendor', '_api/airflow/api', '_api/airflow/bin', + '_api/airflow/cli', + '_api/airflow/cli/command', '_api/airflow/config_templates', '_api/airflow/configuration', '_api/airflow/contrib/auth', From 4083a5c4359c18eb1ae6a86ecd1797b9dc970e15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Fri, 15 Nov 2019 19:57:03 +0100 Subject: [PATCH 04/23] [AIRLFOW-YYY] Move user and roles command to seperate files --- airflow/bin/cli.py | 255 +-------------------- airflow/cli/commands/role_command.py | 44 ++++ airflow/cli/commands/user_command.py | 245 ++++++++++++++++++++ tests/cli/commands/__init__.py | 19 ++ tests/cli/commands/test_role_command.py | 93 ++++++++ tests/cli/commands/test_user_command.py | 254 +++++++++++++++++++++ tests/cli/test_cli.py | 290 ------------------------ 7 files changed, 665 insertions(+), 535 deletions(-) create mode 100644 airflow/cli/commands/role_command.py create mode 100644 airflow/cli/commands/user_command.py create mode 100644 tests/cli/commands/__init__.py create mode 100644 tests/cli/commands/test_role_command.py create mode 100644 tests/cli/commands/test_user_command.py diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 41032837a34df..ffc06d65cbe3e 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -21,17 +21,13 @@ import argparse import errno -import functools -import getpass import importlib import json import logging import os -import random import re import reprlib import signal -import string import subprocess import sys import textwrap @@ -50,6 +46,7 @@ import airflow from airflow import api, jobs, settings from airflow.api.client import get_current_api_client +from airflow.cli.commands import role_command, user_command from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowWebServerTimeout from airflow.executors import get_default_executor @@ -1426,238 +1423,6 @@ def kerberos(args): airflow.security.kerberos.run(principal=args.principal, keytab=args.keytab) -def users_list(args): - """Lists users at the command line""" - appbuilder = cached_appbuilder() - users = appbuilder.sm.get_all_users() - fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles'] - users = [[user.__getattribute__(field) for field in fields] for user in users] - msg = tabulate(users, [field.capitalize().replace('_', ' ') for field in fields], - tablefmt=args.output) - print(msg) - - -@cli_utils.action_logging -def users_create(args): - """Creates new user in the DB""" - appbuilder = cached_appbuilder() - role = appbuilder.sm.find_role(args.role) - if not role: - valid_roles = appbuilder.sm.get_all_roles() - raise SystemExit('{} is not a valid role. 
Valid roles are: {}'.format(args.role, valid_roles)) - - if args.use_random_password: - password = ''.join(random.choice(string.printable) for _ in range(16)) - elif args.password: - password = args.password - else: - password = getpass.getpass('Password:') - password_confirmation = getpass.getpass('Repeat for confirmation:') - if password != password_confirmation: - raise SystemExit('Passwords did not match!') - - if appbuilder.sm.find_user(args.username): - print('{} already exist in the db'.format(args.username)) - return - user = appbuilder.sm.add_user(args.username, args.firstname, args.lastname, - args.email, role, password) - if user: - print('{} user {} created.'.format(args.role, args.username)) - else: - raise SystemExit('Failed to create user.') - - -@cli_utils.action_logging -def users_delete(args): - """Deletes user from DB""" - appbuilder = cached_appbuilder() - - try: - user = next(u for u in appbuilder.sm.get_all_users() - if u.username == args.username) - except StopIteration: - raise SystemExit('{} is not a valid user.'.format(args.username)) - - if appbuilder.sm.del_register_user(user): - print('User {} deleted.'.format(args.username)) - else: - raise SystemExit('Failed to delete user.') - - -@cli_utils.action_logging -def users_manage_role(args, remove=False): - """Deletes or appends user roles""" - if not args.username and not args.email: - raise SystemExit('Missing args: must supply one of --username or --email') - - if args.username and args.email: - raise SystemExit('Conflicting args: must supply either --username' - ' or --email, but not both') - - appbuilder = cached_appbuilder() - user = (appbuilder.sm.find_user(username=args.username) or - appbuilder.sm.find_user(email=args.email)) - if not user: - raise SystemExit('User "{}" does not exist'.format( - args.username or args.email)) - - role = appbuilder.sm.find_role(args.role) - if not role: - valid_roles = appbuilder.sm.get_all_roles() - raise SystemExit('{} is not a valid role. 
Valid roles are: {}'.format(args.role, valid_roles)) - - if remove: - if role in user.roles: - user.roles = [r for r in user.roles if r != role] - appbuilder.sm.update_user(user) - print('User "{}" removed from role "{}".'.format( - user, - args.role)) - else: - raise SystemExit('User "{}" is not a member of role "{}".'.format( - user, - args.role)) - else: - if role in user.roles: - raise SystemExit('User "{}" is already a member of role "{}".'.format( - user, - args.role)) - else: - user.roles.append(role) - appbuilder.sm.update_user(user) - print('User "{}" added to role "{}".'.format( - user, - args.role)) - - -def users_export(args): - """Exports all users to the json file""" - appbuilder = cached_appbuilder() - users = appbuilder.sm.get_all_users() - fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles'] - - # In the User model the first and last name fields have underscores, - # but the corresponding parameters in the CLI don't - def remove_underscores(s): - return re.sub("_", "", s) - - users = [ - {remove_underscores(field): user.__getattribute__(field) - if field != 'roles' else [r.name for r in user.roles] - for field in fields} - for user in users - ] - - with open(args.export, 'w') as file: - file.write(json.dumps(users, sort_keys=True, indent=4)) - print("{} users successfully exported to {}".format(len(users), file.name)) - - -@cli_utils.action_logging -def users_import(args): - """Imports users from the json file""" - json_file = getattr(args, 'import') - if not os.path.exists(json_file): - print("File '{}' does not exist") - exit(1) - - users_list = None # pylint: disable=redefined-outer-name - try: - with open(json_file, 'r') as file: - users_list = json.loads(file.read()) - except ValueError as e: - print("File '{}' is not valid JSON. Error: {}".format(json_file, e)) - exit(1) - - users_created, users_updated = _import_users(users_list) - if users_created: - print("Created the following users:\n\t{}".format( - "\n\t".join(users_created))) - - if users_updated: - print("Updated the following users:\n\t{}".format( - "\n\t".join(users_updated))) - - -def _import_users(users_list): # pylint: disable=redefined-outer-name - appbuilder = cached_appbuilder() - users_created = [] - users_updated = [] - - for user in users_list: - roles = [] - for rolename in user['roles']: - role = appbuilder.sm.find_role(rolename) - if not role: - valid_roles = appbuilder.sm.get_all_roles() - print("Error: '{}' is not a valid role. 
Valid roles are: {}".format(rolename, valid_roles)) - exit(1) - else: - roles.append(role) - - required_fields = ['username', 'firstname', 'lastname', - 'email', 'roles'] - for field in required_fields: - if not user.get(field): - print("Error: '{}' is a required field, but was not " - "specified".format(field)) - exit(1) - - existing_user = appbuilder.sm.find_user(email=user['email']) - if existing_user: - print("Found existing user with email '{}'".format(user['email'])) - existing_user.roles = roles - existing_user.first_name = user['firstname'] - existing_user.last_name = user['lastname'] - - if existing_user.username != user['username']: - print("Error: Changing the username is not allowed - " - "please delete and recreate the user with " - "email '{}'".format(user['email'])) - exit(1) - - appbuilder.sm.update_user(existing_user) - users_updated.append(user['email']) - else: - print("Creating new user with email '{}'".format(user['email'])) - appbuilder.sm.add_user( - username=user['username'], - first_name=user['firstname'], - last_name=user['lastname'], - email=user['email'], - role=roles[0], # add_user() requires exactly 1 role - ) - - if len(roles) > 1: - new_user = appbuilder.sm.find_user(email=user['email']) - new_user.roles = roles - appbuilder.sm.update_user(new_user) - - users_created.append(user['email']) - - return users_created, users_updated - - -def roles_list(args): - """Lists all existing roles""" - appbuilder = cached_appbuilder() - roles = appbuilder.sm.get_all_roles() - print("Existing roles:\n") - role_names = sorted([[r.name] for r in roles]) - msg = tabulate(role_names, - headers=['Role'], - tablefmt=args.output) - print(msg) - - -@cli_utils.action_logging -def roles_create(args): - """Creates new empty role in DB""" - appbuilder = cached_appbuilder() - for role_name in args.role: - appbuilder.sm.add_role(role_name) - - @cli_utils.action_logging def dag_list_dag_runs(args, dag=None): """Lists dag runs for a given DAG""" @@ -2569,44 +2334,44 @@ class CLIFactory: 'name': 'users', 'subcommands': ( { - 'func': users_list, + 'func': user_command.users_list, 'name': 'list', 'help': 'List users', 'args': ('output',), }, { - 'func': users_create, + 'func': user_command.users_create, 'name': 'create', 'help': 'Create a user', 'args': ('role', 'username', 'email', 'firstname', 'lastname', 'password', 'use_random_password') }, { - 'func': users_delete, + 'func': user_command.users_delete, 'name': 'delete', 'help': 'Delete a user', 'args': ('username',), }, { - 'func': functools.partial(users_manage_role, remove=False), + 'func': user_command.add_role, 'name': 'add_role', 'help': 'Add role to a user', 'args': ('username_optional', 'email_optional', 'role'), }, { - 'func': functools.partial(users_manage_role, remove=True), + 'func': user_command.remove_role, 'name': 'remove_role', 'help': 'Remove role from a user', 'args': ('username_optional', 'email_optional', 'role'), }, { - 'func': users_import, + 'func': user_command.users_import, 'name': 'import', 'help': 'Import a user', 'args': ('user_import',), }, { - 'func': users_export, + 'func': user_command.users_export, 'name': 'export', 'help': 'Export a user', 'args': ('user_export',), @@ -2617,13 +2382,13 @@ class CLIFactory: 'name': 'roles', 'subcommands': ( { - 'func': roles_list, + 'func': role_command.roles_list, 'name': 'list', 'help': 'List roles', 'args': ('output',), }, { - 'func': roles_create, + 'func': role_command.roles_create, 'name': 'create', 'help': 'Create role', 'args': ('roles',), diff --git 
a/airflow/cli/commands/role_command.py b/airflow/cli/commands/role_command.py new file mode 100644 index 0000000000000..33ff313a68791 --- /dev/null +++ b/airflow/cli/commands/role_command.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""Roles sub-commands""" +from tabulate import tabulate + +from airflow.utils import cli as cli_utils +from airflow.www.app import cached_appbuilder + + +def roles_list(args): + """Lists all existing roles""" + appbuilder = cached_appbuilder() + roles = appbuilder.sm.get_all_roles() + print("Existing roles:\n") + role_names = sorted([[r.name] for r in roles]) + msg = tabulate(role_names, + headers=['Role'], + tablefmt=args.output) + print(msg) + + +@cli_utils.action_logging +def roles_create(args): + """Creates new empty role in DB""" + appbuilder = cached_appbuilder() + for role_name in args.role: + appbuilder.sm.add_role(role_name) diff --git a/airflow/cli/commands/user_command.py b/airflow/cli/commands/user_command.py new file mode 100644 index 0000000000000..b155d0ab34f22 --- /dev/null +++ b/airflow/cli/commands/user_command.py @@ -0,0 +1,245 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""User sub-commands""" +import functools +import getpass +import json +import os +import random +import re +import string + +from tabulate import tabulate + +from airflow.utils import cli as cli_utils +from airflow.www.app import cached_appbuilder + + +def users_list(args): + """Lists users at the command line""" + appbuilder = cached_appbuilder() + users = appbuilder.sm.get_all_users() + fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles'] + users = [[user.__getattribute__(field) for field in fields] for user in users] + msg = tabulate(users, [field.capitalize().replace('_', ' ') for field in fields], + tablefmt=args.output) + print(msg) + + +@cli_utils.action_logging +def users_create(args): + """Creates new user in the DB""" + appbuilder = cached_appbuilder() + role = appbuilder.sm.find_role(args.role) + if not role: + valid_roles = appbuilder.sm.get_all_roles() + raise SystemExit('{} is not a valid role. Valid roles are: {}'.format(args.role, valid_roles)) + + if args.use_random_password: + password = ''.join(random.choice(string.printable) for _ in range(16)) + elif args.password: + password = args.password + else: + password = getpass.getpass('Password:') + password_confirmation = getpass.getpass('Repeat for confirmation:') + if password != password_confirmation: + raise SystemExit('Passwords did not match!') + + if appbuilder.sm.find_user(args.username): + print('{} already exist in the db'.format(args.username)) + return + user = appbuilder.sm.add_user(args.username, args.firstname, args.lastname, + args.email, role, password) + if user: + print('{} user {} created.'.format(args.role, args.username)) + else: + raise SystemExit('Failed to create user.') + + +@cli_utils.action_logging +def users_delete(args): + """Deletes user from DB""" + appbuilder = cached_appbuilder() + + try: + user = next(u for u in appbuilder.sm.get_all_users() + if u.username == args.username) + except StopIteration: + raise SystemExit('{} is not a valid user.'.format(args.username)) + + if appbuilder.sm.del_register_user(user): + print('User {} deleted.'.format(args.username)) + else: + raise SystemExit('Failed to delete user.') + + +@cli_utils.action_logging +def users_manage_role(args, remove=False): + """Deletes or appends user roles""" + if not args.username and not args.email: + raise SystemExit('Missing args: must supply one of --username or --email') + + if args.username and args.email: + raise SystemExit('Conflicting args: must supply either --username' + ' or --email, but not both') + + appbuilder = cached_appbuilder() + user = (appbuilder.sm.find_user(username=args.username) or + appbuilder.sm.find_user(email=args.email)) + if not user: + raise SystemExit('User "{}" does not exist'.format( + args.username or args.email)) + + role = appbuilder.sm.find_role(args.role) + if not role: + valid_roles = appbuilder.sm.get_all_roles() + raise SystemExit('{} is not a valid role. 
Valid roles are: {}'.format(args.role, valid_roles)) + + if remove: + if role in user.roles: + user.roles = [r for r in user.roles if r != role] + appbuilder.sm.update_user(user) + print('User "{}" removed from role "{}".'.format( + user, + args.role)) + else: + raise SystemExit('User "{}" is not a member of role "{}".'.format( + user, + args.role)) + else: + if role in user.roles: + raise SystemExit('User "{}" is already a member of role "{}".'.format( + user, + args.role)) + else: + user.roles.append(role) + appbuilder.sm.update_user(user) + print('User "{}" added to role "{}".'.format( + user, + args.role)) + + +def users_export(args): + """Exports all users to the json file""" + appbuilder = cached_appbuilder() + users = appbuilder.sm.get_all_users() + fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles'] + + # In the User model the first and last name fields have underscores, + # but the corresponding parameters in the CLI don't + def remove_underscores(s): + return re.sub("_", "", s) + + users = [ + {remove_underscores(field): user.__getattribute__(field) + if field != 'roles' else [r.name for r in user.roles] + for field in fields} + for user in users + ] + + with open(args.export, 'w') as file: + file.write(json.dumps(users, sort_keys=True, indent=4)) + print("{} users successfully exported to {}".format(len(users), file.name)) + + +@cli_utils.action_logging +def users_import(args): + """Imports users from the json file""" + json_file = getattr(args, 'import') + if not os.path.exists(json_file): + print("File '{}' does not exist") + exit(1) + + users_list = None # pylint: disable=redefined-outer-name + try: + with open(json_file, 'r') as file: + users_list = json.loads(file.read()) + except ValueError as e: + print("File '{}' is not valid JSON. Error: {}".format(json_file, e)) + exit(1) + + users_created, users_updated = _import_users(users_list) + if users_created: + print("Created the following users:\n\t{}".format( + "\n\t".join(users_created))) + + if users_updated: + print("Updated the following users:\n\t{}".format( + "\n\t".join(users_updated))) + + +def _import_users(users_list): # pylint: disable=redefined-outer-name + appbuilder = cached_appbuilder() + users_created = [] + users_updated = [] + + for user in users_list: + roles = [] + for rolename in user['roles']: + role = appbuilder.sm.find_role(rolename) + if not role: + valid_roles = appbuilder.sm.get_all_roles() + print("Error: '{}' is not a valid role. 
Valid roles are: {}".format(rolename, valid_roles)) + exit(1) + else: + roles.append(role) + + required_fields = ['username', 'firstname', 'lastname', + 'email', 'roles'] + for field in required_fields: + if not user.get(field): + print("Error: '{}' is a required field, but was not " + "specified".format(field)) + exit(1) + + existing_user = appbuilder.sm.find_user(email=user['email']) + if existing_user: + print("Found existing user with email '{}'".format(user['email'])) + existing_user.roles = roles + existing_user.first_name = user['firstname'] + existing_user.last_name = user['lastname'] + + if existing_user.username != user['username']: + print("Error: Changing the username is not allowed - " + "please delete and recreate the user with " + "email '{}'".format(user['email'])) + exit(1) + + appbuilder.sm.update_user(existing_user) + users_updated.append(user['email']) + else: + print("Creating new user with email '{}'".format(user['email'])) + appbuilder.sm.add_user( + username=user['username'], + first_name=user['firstname'], + last_name=user['lastname'], + email=user['email'], + role=roles[0], # add_user() requires exactly 1 role + ) + + if len(roles) > 1: + new_user = appbuilder.sm.find_user(email=user['email']) + new_user.roles = roles + appbuilder.sm.update_user(new_user) + + users_created.append(user['email']) + + return users_created, users_updated + + +add_role = functools.partial(users_manage_role, remove=False) +remove_role = functools.partial(users_manage_role, remove=True) diff --git a/tests/cli/commands/__init__.py b/tests/cli/commands/__init__.py new file mode 100644 index 0000000000000..b7f8352944d3f --- /dev/null +++ b/tests/cli/commands/__init__.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/tests/cli/commands/test_role_command.py b/tests/cli/commands/test_role_command.py new file mode 100644 index 0000000000000..41e0b79a08d3e --- /dev/null +++ b/tests/cli/commands/test_role_command.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# +import io +import unittest +from unittest import mock + +from airflow import models +from airflow.bin import cli +from airflow.cli.commands import role_command +from airflow.settings import Session + +TEST_USER1_EMAIL = 'test-user1@example.com' +TEST_USER2_EMAIL = 'test-user2@example.com' + + +class TestCliRoles(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dagbag = models.DagBag(include_examples=True) + cls.parser = cli.CLIFactory.get_parser() + + def setUp(self): + from airflow.www import app as application + self.app, self.appbuilder = application.create_app(session=Session, testing=True) + self.clear_roles_and_roles() + + def tearDown(self): + self.clear_roles_and_roles() + + def clear_roles_and_roles(self): + for email in [TEST_USER1_EMAIL, TEST_USER2_EMAIL]: + test_user = self.appbuilder.sm.find_user(email=email) + if test_user: + self.appbuilder.sm.del_register_user(test_user) + for role_name in ['FakeTeamA', 'FakeTeamB']: + if self.appbuilder.sm.find_role(role_name): + self.appbuilder.sm.delete_role(role_name) + + def test_cli_create_roles(self): + self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamA')) + self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamB')) + + args = self.parser.parse_args([ + 'roles', 'create', 'FakeTeamA', 'FakeTeamB' + ]) + role_command.roles_create(args) + + self.assertIsNotNone(self.appbuilder.sm.find_role('FakeTeamA')) + self.assertIsNotNone(self.appbuilder.sm.find_role('FakeTeamB')) + + def test_cli_create_roles_is_reentrant(self): + self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamA')) + self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamB')) + + args = self.parser.parse_args([ + 'roles', 'create', 'FakeTeamA', 'FakeTeamB' + ]) + + role_command.roles_create(args) + + self.assertIsNotNone(self.appbuilder.sm.find_role('FakeTeamA')) + self.assertIsNotNone(self.appbuilder.sm.find_role('FakeTeamB')) + + def test_cli_list_roles(self): + self.appbuilder.sm.add_role('FakeTeamA') + self.appbuilder.sm.add_role('FakeTeamB') + + with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: + role_command.roles_list(self.parser.parse_args(['roles', 'list'])) + stdout = mock_stdout.getvalue() + + self.assertIn('FakeTeamA', stdout) + self.assertIn('FakeTeamB', stdout) + + def test_cli_list_roles_with_args(self): + role_command.roles_list(self.parser.parse_args(['roles', 'list', '--output', 'tsv'])) diff --git a/tests/cli/commands/test_user_command.py b/tests/cli/commands/test_user_command.py new file mode 100644 index 0000000000000..76e5f7c11aea3 --- /dev/null +++ b/tests/cli/commands/test_user_command.py @@ -0,0 +1,254 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import io +import json +import os +import tempfile +import unittest +from unittest import mock + +from airflow import models +from airflow.bin import cli +from airflow.cli.commands import user_command +from airflow.settings import Session + +TEST_USER1_EMAIL = 'test-user1@example.com' +TEST_USER2_EMAIL = 'test-user2@example.com' + + +def _does_user_belong_to_role(appbuilder, email, rolename): + user = appbuilder.sm.find_user(email=email) + role = appbuilder.sm.find_role(rolename) + if user and role: + return role in user.roles + + return False + + +class TestCliUsers(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dagbag = models.DagBag(include_examples=True) + cls.parser = cli.CLIFactory.get_parser() + + def setUp(self): + from airflow.www import app as application + self.app, self.appbuilder = application.create_app(session=Session, testing=True) + self.clear_roles_and_roles() + + def tearDown(self): + self.clear_roles_and_roles() + + def clear_roles_and_roles(self): + for email in [TEST_USER1_EMAIL, TEST_USER2_EMAIL]: + test_user = self.appbuilder.sm.find_user(email=email) + if test_user: + self.appbuilder.sm.del_register_user(test_user) + for role_name in ['FakeTeamA', 'FakeTeamB']: + if self.appbuilder.sm.find_role(role_name): + self.appbuilder.sm.delete_role(role_name) + + def test_cli_create_user_random_password(self): + args = self.parser.parse_args([ + 'users', 'create', '--username', 'test1', '--lastname', 'doe', + '--firstname', 'jon', + '--email', 'jdoe@foo.com', '--role', 'Viewer', '--use_random_password' + ]) + user_command.users_create(args) + + def test_cli_create_user_supplied_password(self): + args = self.parser.parse_args([ + 'users', 'create', '--username', 'test2', '--lastname', 'doe', + '--firstname', 'jon', + '--email', 'jdoe@apache.org', '--role', 'Viewer', '--password', 'test' + ]) + user_command.users_create(args) + + def test_cli_delete_user(self): + args = self.parser.parse_args([ + 'users', 'create', '--username', 'test3', '--lastname', 'doe', + '--firstname', 'jon', + '--email', 'jdoe@example.com', '--role', 'Viewer', '--use_random_password' + ]) + user_command.users_create(args) + args = self.parser.parse_args([ + 'users', 'delete', '--username', 'test3', + ]) + user_command.users_delete(args) + + def test_cli_list_users(self): + for i in range(0, 3): + args = self.parser.parse_args([ + 'users', 'create', '--username', 'user{}'.format(i), '--lastname', + 'doe', '--firstname', 'jon', + '--email', 'jdoe+{}@gmail.com'.format(i), '--role', 'Viewer', + '--use_random_password' + ]) + user_command.users_create(args) + with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: + user_command.users_list(self.parser.parse_args(['users', 'list'])) + stdout = mock_stdout.getvalue() + for i in range(0, 3): + self.assertIn('user{}'.format(i), stdout) + + def test_cli_list_users_with_args(self): + user_command.users_list(self.parser.parse_args(['users', 'list', '--output', 'tsv'])) + + def test_cli_import_users(self): + def assert_user_in_roles(email, roles): + for role in roles: + self.assertTrue(_does_user_belong_to_role(self.appbuilder, email, role)) + + def assert_user_not_in_roles(email, roles): + for role in roles: + self.assertFalse(_does_user_belong_to_role(self.appbuilder, email, role)) + + assert_user_not_in_roles(TEST_USER1_EMAIL, ['Admin', 'Op']) + assert_user_not_in_roles(TEST_USER2_EMAIL, ['Public']) + users = [ + { + "username": "imported_user1", "lastname": "doe1", + "firstname": "jon", "email": TEST_USER1_EMAIL, + "roles": 
["Admin", "Op"] + }, + { + "username": "imported_user2", "lastname": "doe2", + "firstname": "jon", "email": TEST_USER2_EMAIL, + "roles": ["Public"] + } + ] + self._import_users_from_file(users) + + assert_user_in_roles(TEST_USER1_EMAIL, ['Admin', 'Op']) + assert_user_in_roles(TEST_USER2_EMAIL, ['Public']) + + users = [ + { + "username": "imported_user1", "lastname": "doe1", + "firstname": "jon", "email": TEST_USER1_EMAIL, + "roles": ["Public"] + }, + { + "username": "imported_user2", "lastname": "doe2", + "firstname": "jon", "email": TEST_USER2_EMAIL, + "roles": ["Admin"] + } + ] + self._import_users_from_file(users) + + assert_user_not_in_roles(TEST_USER1_EMAIL, ['Admin', 'Op']) + assert_user_in_roles(TEST_USER1_EMAIL, ['Public']) + assert_user_not_in_roles(TEST_USER2_EMAIL, ['Public']) + assert_user_in_roles(TEST_USER2_EMAIL, ['Admin']) + + def test_cli_export_users(self): + user1 = {"username": "imported_user1", "lastname": "doe1", + "firstname": "jon", "email": TEST_USER1_EMAIL, + "roles": ["Public"]} + user2 = {"username": "imported_user2", "lastname": "doe2", + "firstname": "jon", "email": TEST_USER2_EMAIL, + "roles": ["Admin"]} + self._import_users_from_file([user1, user2]) + + users_filename = self._export_users_to_file() + with open(users_filename, mode='r') as file: + retrieved_users = json.loads(file.read()) + os.remove(users_filename) + + # ensure that an export can be imported + self._import_users_from_file(retrieved_users) + + def find_by_username(username): + matches = [u for u in retrieved_users if u['username'] == username] + if not matches: + self.fail("Couldn't find user with username {}".format(username)) + return None + else: + matches[0].pop('id') # this key not required for import + return matches[0] + + self.assertEqual(find_by_username('imported_user1'), user1) + self.assertEqual(find_by_username('imported_user2'), user2) + + def _import_users_from_file(self, user_list): + json_file_content = json.dumps(user_list) + f = tempfile.NamedTemporaryFile(delete=False) + try: + f.write(json_file_content.encode()) + f.flush() + + args = self.parser.parse_args([ + 'users', 'import', f.name + ]) + user_command.users_import(args) + finally: + os.remove(f.name) + + def _export_users_to_file(self): + f = tempfile.NamedTemporaryFile(delete=False) + args = self.parser.parse_args([ + 'users', 'export', f.name + ]) + user_command.users_export(args) + return f.name + + def test_cli_add_user_role(self): + args = self.parser.parse_args([ + 'users', 'create', '--username', 'test4', '--lastname', 'doe', + '--firstname', 'jon', + '--email', TEST_USER1_EMAIL, '--role', 'Viewer', '--use_random_password' + ]) + user_command.users_create(args) + + self.assertFalse( + _does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Op'), + "User should not yet be a member of role 'Op'" + ) + + args = self.parser.parse_args([ + 'users', 'add_role', '--username', 'test4', '--role', 'Op' + ]) + user_command.users_manage_role(args, remove=False) + + self.assertTrue( + _does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Op'), + "User should have been added to role 'Op'" + ) + + def test_cli_remove_user_role(self): + args = self.parser.parse_args([ + 'users', 'create', '--username', 'test4', '--lastname', 'doe', + '--firstname', 'jon', + '--email', TEST_USER1_EMAIL, '--role', 'Viewer', '--use_random_password' + ]) + user_command.users_create(args) + + self.assertTrue( + _does_user_belong_to_role(appbuilder=self.appbuilder, 
email=TEST_USER1_EMAIL, rolename='Viewer'), + "User should have been created with role 'Viewer'" + ) + + args = self.parser.parse_args([ + 'users', 'remove_role', '--username', 'test4', '--role', 'Viewer' + ]) + user_command.users_manage_role(args, remove=True) + + self.assertFalse( + _does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Viewer'), + "User should have been removed from role 'Viewer'" + ) diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 2f36e7a5b3bba..c401827b6bea8 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -52,8 +52,6 @@ TEST_DAG_FOLDER = os.path.join( os.path.dirname(dag_folder_path), 'dags') TEST_DAG_ID = 'unit_tests' -TEST_USER1_EMAIL = 'test-user1@example.com' -TEST_USER2_EMAIL = 'test-user2@example.com' def reset(dag_id): @@ -540,230 +538,6 @@ def test_dag_state(self): 'dags', 'state', 'example_bash_operator', DEFAULT_DATE.isoformat()]))) -def _does_user_belong_to_role(appbuilder, email, rolename): - user = appbuilder.sm.find_user(email=email) - role = appbuilder.sm.find_role(rolename) - if user and role: - return role in user.roles - - return False - - -class TestCliUsers(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.dagbag = models.DagBag(include_examples=True) - cls.parser = cli.CLIFactory.get_parser() - - def setUp(self): - from airflow.www import app as application - self.app, self.appbuilder = application.create_app(session=Session, testing=True) - self.clear_roles_and_roles() - - def tearDown(self): - self.clear_roles_and_roles() - - def clear_roles_and_roles(self): - for email in [TEST_USER1_EMAIL, TEST_USER2_EMAIL]: - test_user = self.appbuilder.sm.find_user(email=email) - if test_user: - self.appbuilder.sm.del_register_user(test_user) - for role_name in ['FakeTeamA', 'FakeTeamB']: - if self.appbuilder.sm.find_role(role_name): - self.appbuilder.sm.delete_role(role_name) - - def test_cli_create_user_random_password(self): - args = self.parser.parse_args([ - 'users', 'create', '--username', 'test1', '--lastname', 'doe', - '--firstname', 'jon', - '--email', 'jdoe@foo.com', '--role', 'Viewer', '--use_random_password' - ]) - cli.users_create(args) - - def test_cli_create_user_supplied_password(self): - args = self.parser.parse_args([ - 'users', 'create', '--username', 'test2', '--lastname', 'doe', - '--firstname', 'jon', - '--email', 'jdoe@apache.org', '--role', 'Viewer', '--password', 'test' - ]) - cli.users_create(args) - - def test_cli_delete_user(self): - args = self.parser.parse_args([ - 'users', 'create', '--username', 'test3', '--lastname', 'doe', - '--firstname', 'jon', - '--email', 'jdoe@example.com', '--role', 'Viewer', '--use_random_password' - ]) - cli.users_create(args) - args = self.parser.parse_args([ - 'users', 'delete', '--username', 'test3', - ]) - cli.users_delete(args) - - def test_cli_list_users(self): - for i in range(0, 3): - args = self.parser.parse_args([ - 'users', 'create', '--username', 'user{}'.format(i), '--lastname', - 'doe', '--firstname', 'jon', - '--email', 'jdoe+{}@gmail.com'.format(i), '--role', 'Viewer', - '--use_random_password' - ]) - cli.users_create(args) - with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: - cli.users_list(self.parser.parse_args(['users', 'list'])) - stdout = mock_stdout.getvalue() - for i in range(0, 3): - self.assertIn('user{}'.format(i), stdout) - - def test_cli_list_users_with_args(self): - cli.users_list(self.parser.parse_args(['users', 'list', - '--output', 'tsv'])) - - def 
test_cli_import_users(self): - def assert_user_in_roles(email, roles): - for role in roles: - self.assertTrue(_does_user_belong_to_role(self.appbuilder, email, role)) - - def assert_user_not_in_roles(email, roles): - for role in roles: - self.assertFalse(_does_user_belong_to_role(self.appbuilder, email, role)) - - assert_user_not_in_roles(TEST_USER1_EMAIL, ['Admin', 'Op']) - assert_user_not_in_roles(TEST_USER2_EMAIL, ['Public']) - users = [ - { - "username": "imported_user1", "lastname": "doe1", - "firstname": "jon", "email": TEST_USER1_EMAIL, - "roles": ["Admin", "Op"] - }, - { - "username": "imported_user2", "lastname": "doe2", - "firstname": "jon", "email": TEST_USER2_EMAIL, - "roles": ["Public"] - } - ] - self._import_users_from_file(users) - - assert_user_in_roles(TEST_USER1_EMAIL, ['Admin', 'Op']) - assert_user_in_roles(TEST_USER2_EMAIL, ['Public']) - - users = [ - { - "username": "imported_user1", "lastname": "doe1", - "firstname": "jon", "email": TEST_USER1_EMAIL, - "roles": ["Public"] - }, - { - "username": "imported_user2", "lastname": "doe2", - "firstname": "jon", "email": TEST_USER2_EMAIL, - "roles": ["Admin"] - } - ] - self._import_users_from_file(users) - - assert_user_not_in_roles(TEST_USER1_EMAIL, ['Admin', 'Op']) - assert_user_in_roles(TEST_USER1_EMAIL, ['Public']) - assert_user_not_in_roles(TEST_USER2_EMAIL, ['Public']) - assert_user_in_roles(TEST_USER2_EMAIL, ['Admin']) - - def test_cli_export_users(self): - user1 = {"username": "imported_user1", "lastname": "doe1", - "firstname": "jon", "email": TEST_USER1_EMAIL, - "roles": ["Public"]} - user2 = {"username": "imported_user2", "lastname": "doe2", - "firstname": "jon", "email": TEST_USER2_EMAIL, - "roles": ["Admin"]} - self._import_users_from_file([user1, user2]) - - users_filename = self._export_users_to_file() - with open(users_filename, mode='r') as file: - retrieved_users = json.loads(file.read()) - os.remove(users_filename) - - # ensure that an export can be imported - self._import_users_from_file(retrieved_users) - - def find_by_username(username): - matches = [u for u in retrieved_users if u['username'] == username] - if not matches: - self.fail("Couldn't find user with username {}".format(username)) - return None - else: - matches[0].pop('id') # this key not required for import - return matches[0] - - self.assertEqual(find_by_username('imported_user1'), user1) - self.assertEqual(find_by_username('imported_user2'), user2) - - def _import_users_from_file(self, user_list): - json_file_content = json.dumps(user_list) - f = tempfile.NamedTemporaryFile(delete=False) - try: - f.write(json_file_content.encode()) - f.flush() - - args = self.parser.parse_args([ - 'users', 'import', f.name - ]) - cli.users_import(args) - finally: - os.remove(f.name) - - def _export_users_to_file(self): - f = tempfile.NamedTemporaryFile(delete=False) - args = self.parser.parse_args([ - 'users', 'export', f.name - ]) - cli.users_export(args) - return f.name - - def test_cli_add_user_role(self): - args = self.parser.parse_args([ - 'users', 'create', '--username', 'test4', '--lastname', 'doe', - '--firstname', 'jon', - '--email', TEST_USER1_EMAIL, '--role', 'Viewer', '--use_random_password' - ]) - cli.users_create(args) - - self.assertFalse( - _does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Op'), - "User should not yet be a member of role 'Op'" - ) - - args = self.parser.parse_args([ - 'users', 'add_role', '--username', 'test4', '--role', 'Op' - ]) - cli.users_manage_role(args, remove=False) - - 
self.assertTrue( - _does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Op'), - "User should have been added to role 'Op'" - ) - - def test_cli_remove_user_role(self): - args = self.parser.parse_args([ - 'users', 'create', '--username', 'test4', '--lastname', 'doe', - '--firstname', 'jon', - '--email', TEST_USER1_EMAIL, '--role', 'Viewer', '--use_random_password' - ]) - cli.users_create(args) - - self.assertTrue( - _does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Viewer'), - "User should have been created with role 'Viewer'" - ) - - args = self.parser.parse_args([ - 'users', 'remove_role', '--username', 'test4', '--role', 'Viewer' - ]) - cli.users_manage_role(args, remove=True) - - self.assertFalse( - _does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Viewer'), - "User should have been removed from role 'Viewer'" - ) - - class TestCliSyncPerms(unittest.TestCase): @classmethod def setUpClass(cls): @@ -809,70 +583,6 @@ def expect_dagbag_contains(self, dags, dagbag_mock): dagbag_mock.return_value = dagbag -class TestCliRoles(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.dagbag = models.DagBag(include_examples=True) - cls.parser = cli.CLIFactory.get_parser() - - def setUp(self): - from airflow.www import app as application - self.app, self.appbuilder = application.create_app(session=Session, testing=True) - self.clear_roles_and_roles() - - def tearDown(self): - self.clear_roles_and_roles() - - def clear_roles_and_roles(self): - for email in [TEST_USER1_EMAIL, TEST_USER2_EMAIL]: - test_user = self.appbuilder.sm.find_user(email=email) - if test_user: - self.appbuilder.sm.del_register_user(test_user) - for role_name in ['FakeTeamA', 'FakeTeamB']: - if self.appbuilder.sm.find_role(role_name): - self.appbuilder.sm.delete_role(role_name) - - def test_cli_create_roles(self): - self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamA')) - self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamB')) - - args = self.parser.parse_args([ - 'roles', 'create', 'FakeTeamA', 'FakeTeamB' - ]) - cli.roles_create(args) - - self.assertIsNotNone(self.appbuilder.sm.find_role('FakeTeamA')) - self.assertIsNotNone(self.appbuilder.sm.find_role('FakeTeamB')) - - def test_cli_create_roles_is_reentrant(self): - self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamA')) - self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamB')) - - args = self.parser.parse_args([ - 'roles', 'create', 'FakeTeamA', 'FakeTeamB' - ]) - - cli.roles_create(args) - - self.assertIsNotNone(self.appbuilder.sm.find_role('FakeTeamA')) - self.assertIsNotNone(self.appbuilder.sm.find_role('FakeTeamB')) - - def test_cli_list_roles(self): - self.appbuilder.sm.add_role('FakeTeamA') - self.appbuilder.sm.add_role('FakeTeamB') - - with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: - cli.roles_list(self.parser.parse_args(['roles', 'list'])) - stdout = mock_stdout.getvalue() - - self.assertIn('FakeTeamA', stdout) - self.assertIn('FakeTeamB', stdout) - - def test_cli_list_roles_with_args(self): - cli.roles_list(self.parser.parse_args(['roles', 'list', - '--output', 'tsv'])) - - class TestCliTasks(unittest.TestCase): @classmethod def setUpClass(cls): From 0c13f4eb9af2ca99aeeb97ea6f230682818f1309 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Fri, 15 Nov 2019 20:17:19 +0100 Subject: [PATCH 05/23] [AIRLFOW-YYY] Move sync_perm command to seperate file --- airflow/bin/cli.py | 33 ++------- 
.../cli/commands/rotate_fernet_key_command.py | 30 ++++++++ airflow/cli/commands/sync_perm_command.py | 36 ++++++++++ tests/cli/commands/test_sync_perm_command.py | 71 +++++++++++++++++++ tests/cli/test_cli.py | 47 +----------- 5 files changed, 142 insertions(+), 75 deletions(-) create mode 100644 airflow/cli/commands/rotate_fernet_key_command.py create mode 100644 airflow/cli/commands/sync_perm_command.py create mode 100644 tests/cli/commands/test_sync_perm_command.py diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index ffc06d65cbe3e..dfcd36aaab2fc 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -46,7 +46,7 @@ import airflow from airflow import api, jobs, settings from airflow.api.client import get_current_api_client -from airflow.cli.commands import role_command, user_command +from airflow.cli.commands import role_command, rotate_fernet_key_command, sync_perm_command, user_command from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowWebServerTimeout from airflow.executors import get_default_executor @@ -57,7 +57,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin, redirect_stderr, redirect_stdout from airflow.utils.net import get_hostname from airflow.utils.timezone import parse as parsedate -from airflow.www.app import cached_app, cached_appbuilder, create_app +from airflow.www.app import cached_app, create_app api.load_auth() @@ -679,17 +679,6 @@ def dag_next_execution(args): print(None) -@cli_utils.action_logging -def rotate_fernet_key(args): - """Rotates all encrypted connection credentials and variables""" - with db.create_session() as session: - for conn in session.query(Connection).filter( - Connection.is_encrypted | Connection.is_extra_encrypted): - conn.rotate_fernet_key() - for var in session.query(Variable).filter(Variable.is_encrypted): - var.rotate_fernet_key() - - @cli_utils.action_logging def dag_list_dags(args): """Displays dags with or without stats at the command line""" @@ -1476,20 +1465,6 @@ def dag_list_dag_runs(args, dag=None): print(record) -@cli_utils.action_logging -def sync_perm(args): - """Updates permissions for existing roles and DAGs""" - appbuilder = cached_appbuilder() - print('Updating permission, view-menu for all existing roles') - appbuilder.sm.sync_roles() - print('Updating permission on all DAG views') - dags = DagBag().dags.values() - for dag in dags: - appbuilder.sm.sync_perm_for_dag( - dag.dag_id, - dag.access_control) - - class Arg: """Class to keep information about command line argument""" # pylint: disable=redefined-builtin @@ -2395,12 +2370,12 @@ class CLIFactory: }, ), }, { - 'func': sync_perm, + 'func': sync_perm_command.sync_perm, 'help': "Update permissions for existing roles and DAGs.", 'args': tuple(), }, { - 'func': rotate_fernet_key, + 'func': rotate_fernet_key_command.rotate_fernet_key, 'help': 'Rotate all encrypted connection credentials and variables; see ' 'https://airflow.readthedocs.io/en/stable/howto/secure-connections.html' '#rotating-encryption-keys.', diff --git a/airflow/cli/commands/rotate_fernet_key_command.py b/airflow/cli/commands/rotate_fernet_key_command.py new file mode 100644 index 0000000000000..303812adf942c --- /dev/null +++ b/airflow/cli/commands/rotate_fernet_key_command.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Rotate Fernet key command""" +from airflow.models import Connection, Variable +from airflow.utils import cli as cli_utils, db + + +@cli_utils.action_logging +def rotate_fernet_key(args): + """Rotates all encrypted connection credentials and variables""" + with db.create_session() as session: + for conn in session.query(Connection).filter( + Connection.is_encrypted | Connection.is_extra_encrypted): + conn.rotate_fernet_key() + for var in session.query(Variable).filter(Variable.is_encrypted): + var.rotate_fernet_key() diff --git a/airflow/cli/commands/sync_perm_command.py b/airflow/cli/commands/sync_perm_command.py new file mode 100644 index 0000000000000..aea591cd50f71 --- /dev/null +++ b/airflow/cli/commands/sync_perm_command.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Sync permission command""" +from airflow.models import DagBag +from airflow.utils import cli as cli_utils +from airflow.www.app import cached_appbuilder + + +@cli_utils.action_logging +def sync_perm(args): + """Updates permissions for existing roles and DAGs""" + appbuilder = cached_appbuilder() + print('Updating permission, view-menu for all existing roles') + appbuilder.sm.sync_roles() + print('Updating permission on all DAG views') + dags = DagBag().dags.values() + for dag in dags: + appbuilder.sm.sync_perm_for_dag( + dag.dag_id, + dag.access_control) diff --git a/tests/cli/commands/test_sync_perm_command.py b/tests/cli/commands/test_sync_perm_command.py new file mode 100644 index 0000000000000..86a64d1ad032c --- /dev/null +++ b/tests/cli/commands/test_sync_perm_command.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import unittest +from unittest import mock + +from airflow import DAG, models +from airflow.bin import cli +from airflow.cli.commands import sync_perm_command +from airflow.settings import Session + + +class TestCliSyncPerm(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dagbag = models.DagBag(include_examples=True) + cls.parser = cli.CLIFactory.get_parser() + + def setUp(self): + from airflow.www import app as application + self.app, self.appbuilder = application.create_app(session=Session, testing=True) + + @mock.patch("airflow.cli.commands.sync_perm_command.DagBag") + def test_cli_sync_perm(self, dagbag_mock): + self.expect_dagbag_contains([ + DAG('has_access_control', + access_control={ + 'Public': {'can_dag_read'} + }), + DAG('no_access_control') + ], dagbag_mock) + self.appbuilder.sm = mock.Mock() + + args = self.parser.parse_args([ + 'sync_perm' + ]) + sync_perm_command.sync_perm(args) + + assert self.appbuilder.sm.sync_roles.call_count == 1 + + self.assertEqual(2, + len(self.appbuilder.sm.sync_perm_for_dag.mock_calls)) + self.appbuilder.sm.sync_perm_for_dag.assert_any_call( + 'has_access_control', + {'Public': {'can_dag_read'}} + ) + self.appbuilder.sm.sync_perm_for_dag.assert_any_call( + 'no_access_control', + None, + ) + + def expect_dagbag_contains(self, dags, dagbag_mock): + dagbag = mock.Mock() + dagbag.dags = {dag.dag_id: dag for dag in dags} + dagbag_mock.return_value = dagbag diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index c401827b6bea8..e025e39a9b66a 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -34,7 +34,7 @@ import pytz import airflow.bin.cli as cli -from airflow import DAG, AirflowException, models, settings +from airflow import AirflowException, models, settings from airflow.bin.cli import get_dag, get_num_ready_workers_running, task_run from airflow.models import Connection, DagModel, Pool, TaskInstance, Variable from airflow.settings import Session @@ -538,51 +538,6 @@ def test_dag_state(self): 'dags', 'state', 'example_bash_operator', DEFAULT_DATE.isoformat()]))) -class TestCliSyncPerms(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.dagbag = models.DagBag(include_examples=True) - cls.parser = cli.CLIFactory.get_parser() - - def setUp(self): - from airflow.www import app as application - self.app, self.appbuilder = application.create_app(session=Session, testing=True) - - @mock.patch("airflow.bin.cli.DagBag") - def test_cli_sync_perm(self, dagbag_mock): - self.expect_dagbag_contains([ - DAG('has_access_control', - access_control={ - 'Public': {'can_dag_read'} - }), - DAG('no_access_control') - ], dagbag_mock) - self.appbuilder.sm = mock.Mock() - - args = self.parser.parse_args([ - 'sync_perm' - ]) - cli.sync_perm(args) - - assert self.appbuilder.sm.sync_roles.call_count == 1 - - self.assertEqual(2, - len(self.appbuilder.sm.sync_perm_for_dag.mock_calls)) - self.appbuilder.sm.sync_perm_for_dag.assert_any_call( - 'has_access_control', - {'Public': {'can_dag_read'}} - ) - self.appbuilder.sm.sync_perm_for_dag.assert_any_call( - 'no_access_control', - None, - ) - - def 
expect_dagbag_contains(self, dags, dagbag_mock): - dagbag = mock.Mock() - dagbag.dags = {dag.dag_id: dag for dag in dags} - dagbag_mock.return_value = dagbag - - class TestCliTasks(unittest.TestCase): @classmethod def setUpClass(cls): From 709216775b9e61d779875491f70b4dcd174ba7ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Fri, 15 Nov 2019 21:03:07 +0100 Subject: [PATCH 06/23] [AIRLFOW-YYY] Move task commands to separate file --- airflow/bin/cli.py | 290 ++--------------------- airflow/cli/commands/task_command.py | 263 +++++++++++++++++++++ airflow/utils/cli.py | 37 ++- tests/cli/commands/test_task_command.py | 293 ++++++++++++++++++++++++ tests/cli/test_cli.py | 157 +------------ tests/jobs/test_backfill_job.py | 54 ----- tests/utils/test_cli_util.py | 16 ++ 7 files changed, 623 insertions(+), 487 deletions(-) create mode 100644 airflow/cli/commands/task_command.py create mode 100644 tests/cli/commands/test_task_command.py diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index dfcd36aaab2fc..c790390113a63 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -21,11 +21,9 @@ import argparse import errno -import importlib import json import logging import os -import re import reprlib import signal import subprocess @@ -46,16 +44,16 @@ import airflow from airflow import api, jobs, settings from airflow.api.client import get_current_api_client -from airflow.cli.commands import role_command, rotate_fernet_key_command, sync_perm_command, user_command +from airflow.cli.commands import ( + role_command, rotate_fernet_key_command, sync_perm_command, task_command, user_command, +) from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowWebServerTimeout -from airflow.executors import get_default_executor -from airflow.models import DAG, Connection, DagBag, DagModel, DagPickle, DagRun, TaskInstance, Variable -from airflow.ti_deps.dep_context import SCHEDULER_QUEUED_DEPS, DepContext +from airflow.models import DAG, Connection, DagBag, DagModel, DagRun, TaskInstance, Variable from airflow.utils import cli as cli_utils, db +from airflow.utils.cli import get_dag, process_subdir from airflow.utils.dot_renderer import render_dag -from airflow.utils.log.logging_mixin import LoggingMixin, redirect_stderr, redirect_stdout -from airflow.utils.net import get_hostname +from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.timezone import parse as parsedate from airflow.www.app import cached_app, create_app @@ -122,38 +120,6 @@ def setup_locations(process, pid=None, stdout=None, stderr=None, log=None): return pid, stdout, stderr, log -def process_subdir(subdir): - """Expands path to absolute by replacing 'DAGS_FOLDER', '~', '.', etc.""" - if subdir: - subdir = subdir.replace('DAGS_FOLDER', DAGS_FOLDER) - subdir = os.path.abspath(os.path.expanduser(subdir)) - return subdir - - -def get_dag(args): - """Returns DAG of a given dag_id""" - dagbag = DagBag(process_subdir(args.subdir)) - if args.dag_id not in dagbag.dags: - raise AirflowException( - 'dag_id could not be found: {}. 
Either the dag did not exist or it failed to ' - 'parse.'.format(args.dag_id)) - return dagbag.dags[args.dag_id] - - -def get_dags(args): - """Returns DAG(s) matching a given regex or dag_id""" - if not args.dag_regex: - return [get_dag(args)] - dagbag = DagBag(process_subdir(args.subdir)) - matched_dags = [dag for dag in dagbag.dags.values() if re.search( - args.dag_id, dag.dag_id)] - if not matched_dags: - raise AirflowException( - 'dag_id could not be found with regex: {}. Either the dag did not exist ' - 'or it failed to parse.'.format(args.dag_id)) - return matched_dags - - @cli_utils.action_logging def dag_backfill(args, dag=None): """Creates backfill job or dry run for a DAG""" @@ -499,149 +465,6 @@ def dag_show(args): print(dot.source) -def _run(args, dag, ti): - if args.local: - run_job = jobs.LocalTaskJob( - task_instance=ti, - mark_success=args.mark_success, - pickle_id=args.pickle, - ignore_all_deps=args.ignore_all_dependencies, - ignore_depends_on_past=args.ignore_depends_on_past, - ignore_task_deps=args.ignore_dependencies, - ignore_ti_state=args.force, - pool=args.pool) - run_job.run() - elif args.raw: - ti._run_raw_task( # pylint: disable=protected-access - mark_success=args.mark_success, - job_id=args.job_id, - pool=args.pool, - ) - else: - pickle_id = None - if args.ship_dag: - try: - # Running remotely, so pickling the DAG - with db.create_session() as session: - pickle = DagPickle(dag) - session.add(pickle) - pickle_id = pickle.id - # TODO: This should be written to a log - print('Pickled dag {dag} as pickle_id: {pickle_id}'.format( - dag=dag, pickle_id=pickle_id)) - except Exception as e: - print('Could not pickle the DAG') - print(e) - raise e - - executor = get_default_executor() - executor.start() - print("Sending to executor.") - executor.queue_task_instance( - ti, - mark_success=args.mark_success, - pickle_id=pickle_id, - ignore_all_deps=args.ignore_all_dependencies, - ignore_depends_on_past=args.ignore_depends_on_past, - ignore_task_deps=args.ignore_dependencies, - ignore_ti_state=args.force, - pool=args.pool) - executor.heartbeat() - executor.end() - - -@cli_utils.action_logging -def task_run(args, dag=None): - """Runs a single task instance""" - if dag: - args.dag_id = dag.dag_id - - log = LoggingMixin().log - - # Load custom airflow config - if args.cfg_path: - with open(args.cfg_path, 'r') as conf_file: - conf_dict = json.load(conf_file) - - if os.path.exists(args.cfg_path): - os.remove(args.cfg_path) - - conf.read_dict(conf_dict, source=args.cfg_path) - settings.configure_vars() - - # IMPORTANT, have to use the NullPool, otherwise, each "run" command may leave - # behind multiple open sleeping connections while heartbeating, which could - # easily exceed the database connection limit when - # processing hundreds of simultaneous tasks. - settings.configure_orm(disable_connection_pool=True) - - if not args.pickle and not dag: - dag = get_dag(args) - elif not dag: - with db.create_session() as session: - log.info('Loading pickle id %s', args.pickle) - dag_pickle = session.query(DagPickle).filter(DagPickle.id == args.pickle).first() - if not dag_pickle: - raise AirflowException("Who hid the pickle!? 
[missing pickle]") - dag = dag_pickle.pickle - - task = dag.get_task(task_id=args.task_id) - ti = TaskInstance(task, args.execution_date) - ti.refresh_from_db() - - ti.init_run_context(raw=args.raw) - - hostname = get_hostname() - log.info("Running %s on host %s", ti, hostname) - - if args.interactive: - _run(args, dag, ti) - else: - with redirect_stdout(ti.log, logging.INFO), redirect_stderr(ti.log, logging.WARN): - _run(args, dag, ti) - logging.shutdown() - - -@cli_utils.action_logging -def task_failed_deps(args): - """ - Returns the unmet dependencies for a task instance from the perspective of the - scheduler (i.e. why a task instance doesn't get scheduled and then queued by the - scheduler, and then run by an executor). - >>> airflow tasks failed_deps tutorial sleep 2015-01-01 - Task instance dependencies not met: - Dagrun Running: Task instance's dagrun did not exist: Unknown reason - Trigger Rule: Task's trigger rule 'all_success' requires all upstream tasks - to have succeeded, but found 1 non-success(es). - """ - dag = get_dag(args) - task = dag.get_task(task_id=args.task_id) - ti = TaskInstance(task, args.execution_date) - - dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS) - failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context)) - # TODO, Do we want to print or log this - if failed_deps: - print("Task instance dependencies not met:") - for dep in failed_deps: - print("{}: {}".format(dep.dep_name, dep.reason)) - else: - print("Task instance dependencies are all met.") - - -@cli_utils.action_logging -def task_state(args): - """ - Returns the state of a TaskInstance at the command line. - >>> airflow tasks state tutorial sleep 2015-01-01 - success - """ - dag = get_dag(args) - task = dag.get_task(task_id=args.task_id) - ti = TaskInstance(task, args.execution_date) - print(ti.current_state()) - - @cli_utils.action_logging def dag_state(args): """ @@ -695,17 +518,6 @@ def dag_list_dags(args): print(dagbag.dagbag_report()) -@cli_utils.action_logging -def task_list(args, dag=None): - """Lists the tasks within a DAG at the command line""" - dag = dag or get_dag(args) - if args.tree: - dag.tree_view() - else: - tasks = sorted([t.task_id for t in dag.tasks]) - print("\n".join(sorted(tasks))) - - @cli_utils.action_logging def dag_list_jobs(args, dag=None): """Lists latest n jobs""" @@ -738,82 +550,6 @@ def dag_list_jobs(args, dag=None): print(msg) -@cli_utils.action_logging -def task_test(args, dag=None): - """Tests task for a given dag_id""" - # We want log outout from operators etc to show up here. Normally - # airflow.task would redirect to a file, but here we want it to propagate - # up to the normal airflow handler. 
- logging.getLogger('airflow.task').propagate = True - - dag = dag or get_dag(args) - - task = dag.get_task(task_id=args.task_id) - # Add CLI provided task_params to task.params - if args.task_params: - passed_in_params = json.loads(args.task_params) - task.params.update(passed_in_params) - ti = TaskInstance(task, args.execution_date) - - try: - if args.dry_run: - ti.dry_run() - else: - ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True) - except Exception: # pylint: disable=broad-except - if args.post_mortem: - try: - debugger = importlib.import_module("ipdb") - except ImportError: - debugger = importlib.import_module("pdb") - debugger.post_mortem() - else: - raise - - -@cli_utils.action_logging -def task_render(args): - """Renders and displays templated fields for a given task""" - dag = get_dag(args) - task = dag.get_task(task_id=args.task_id) - ti = TaskInstance(task, args.execution_date) - ti.render_templates() - for attr in task.__class__.template_fields: - print(textwrap.dedent("""\ - # ---------------------------------------------------------- - # property: {} - # ---------------------------------------------------------- - {} - """.format(attr, getattr(task, attr)))) - - -@cli_utils.action_logging -def task_clear(args): - """Clears all task instances or only those matched by regex for a DAG(s)""" - logging.basicConfig( - level=settings.LOGGING_LEVEL, - format=settings.SIMPLE_LOG_FORMAT) - dags = get_dags(args) - - if args.task_regex: - for idx, dag in enumerate(dags): - dags[idx] = dag.sub_dag( - task_regex=args.task_regex, - include_downstream=args.downstream, - include_upstream=args.upstream) - - DAG.clear_dags( - dags, - start_date=args.start_date, - end_date=args.end_date, - only_failed=args.only_failed, - only_running=args.only_running, - confirm_prompt=not args.yes, - include_subdags=not args.exclude_subdags, - include_parentdag=not args.exclude_parentdag, - ) - - def get_num_ready_workers_running(gunicorn_master_proc): """Returns number of ready Gunicorn workers by looking for READY_PREFIX in process name""" workers = psutil.Process(gunicorn_master_proc.pid).children() @@ -2082,13 +1818,13 @@ class CLIFactory: 'name': 'tasks', 'subcommands': ( { - 'func': task_list, + 'func': task_command.task_list, 'name': 'list', 'help': "List the tasks within a DAG", 'args': ('dag_id', 'tree', 'subdir'), }, { - 'func': task_clear, + 'func': task_command.task_clear, 'name': 'clear', 'help': "Clear a set of task instance, as if they never ran", 'args': ( @@ -2097,13 +1833,13 @@ class CLIFactory: 'only_running', 'exclude_subdags', 'exclude_parentdag', 'dag_regex'), }, { - 'func': task_state, + 'func': task_command.task_state, 'name': 'state', 'help': "Get the status of a task instance", 'args': ('dag_id', 'task_id', 'execution_date', 'subdir'), }, { - 'func': task_failed_deps, + 'func': task_command.task_failed_deps, 'name': 'failed_deps', 'help': ( "Returns the unmet dependencies for a task instance from the perspective " @@ -2113,13 +1849,13 @@ class CLIFactory: 'args': ('dag_id', 'task_id', 'execution_date', 'subdir'), }, { - 'func': task_render, + 'func': task_command.task_render, 'name': 'render', 'help': "Render a task instance's template(s)", 'args': ('dag_id', 'task_id', 'execution_date', 'subdir'), }, { - 'func': task_run, + 'func': task_command.task_run, 'name': 'run', 'help': "Run a single task instance", 'args': ( @@ -2129,7 +1865,7 @@ class CLIFactory: 'ignore_depends_on_past', 'ship_dag', 'pickle', 'job_id', 'interactive',), }, { - 'func': task_test, + 'func': 
task_command.task_test, 'name': 'test', 'help': ( "Test a task instance. This will run a task without checking for " diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py new file mode 100644 index 0000000000000..15c81a41a8af5 --- /dev/null +++ b/airflow/cli/commands/task_command.py @@ -0,0 +1,263 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Task sub-commands""" +import importlib +import json +import logging +import os +import textwrap + +from airflow import DAG, AirflowException, LoggingMixin, conf, jobs, settings +from airflow.executors import get_default_executor +from airflow.models import DagPickle, TaskInstance +from airflow.ti_deps.dep_context import SCHEDULER_QUEUED_DEPS, DepContext +from airflow.utils import cli as cli_utils, db +from airflow.utils.cli import get_dag, get_dags +from airflow.utils.log.logging_mixin import redirect_stderr, redirect_stdout +from airflow.utils.net import get_hostname + + +def _run(args, dag, ti): + if args.local: + run_job = jobs.LocalTaskJob( + task_instance=ti, + mark_success=args.mark_success, + pickle_id=args.pickle, + ignore_all_deps=args.ignore_all_dependencies, + ignore_depends_on_past=args.ignore_depends_on_past, + ignore_task_deps=args.ignore_dependencies, + ignore_ti_state=args.force, + pool=args.pool) + run_job.run() + elif args.raw: + ti._run_raw_task( # pylint: disable=protected-access + mark_success=args.mark_success, + job_id=args.job_id, + pool=args.pool, + ) + else: + pickle_id = None + if args.ship_dag: + try: + # Running remotely, so pickling the DAG + with db.create_session() as session: + pickle = DagPickle(dag) + session.add(pickle) + pickle_id = pickle.id + # TODO: This should be written to a log + print('Pickled dag {dag} as pickle_id: {pickle_id}'.format( + dag=dag, pickle_id=pickle_id)) + except Exception as e: + print('Could not pickle the DAG') + print(e) + raise e + + executor = get_default_executor() + executor.start() + print("Sending to executor.") + executor.queue_task_instance( + ti, + mark_success=args.mark_success, + pickle_id=pickle_id, + ignore_all_deps=args.ignore_all_dependencies, + ignore_depends_on_past=args.ignore_depends_on_past, + ignore_task_deps=args.ignore_dependencies, + ignore_ti_state=args.force, + pool=args.pool) + executor.heartbeat() + executor.end() + + +@cli_utils.action_logging +def task_run(args, dag=None): + """Runs a single task instance""" + if dag: + args.dag_id = dag.dag_id + + log = LoggingMixin().log + + # Load custom airflow config + if args.cfg_path: + with open(args.cfg_path, 'r') as conf_file: + conf_dict = json.load(conf_file) + + if os.path.exists(args.cfg_path): + os.remove(args.cfg_path) + + conf.read_dict(conf_dict, source=args.cfg_path) + settings.configure_vars() + + # 
IMPORTANT, have to use the NullPool, otherwise, each "run" command may leave + # behind multiple open sleeping connections while heartbeating, which could + # easily exceed the database connection limit when + # processing hundreds of simultaneous tasks. + settings.configure_orm(disable_connection_pool=True) + + if not args.pickle and not dag: + dag = get_dag(args) + elif not dag: + with db.create_session() as session: + log.info('Loading pickle id %s', args.pickle) + dag_pickle = session.query(DagPickle).filter(DagPickle.id == args.pickle).first() + if not dag_pickle: + raise AirflowException("Who hid the pickle!? [missing pickle]") + dag = dag_pickle.pickle + + task = dag.get_task(task_id=args.task_id) + ti = TaskInstance(task, args.execution_date) + ti.refresh_from_db() + + ti.init_run_context(raw=args.raw) + + hostname = get_hostname() + log.info("Running %s on host %s", ti, hostname) + + if args.interactive: + _run(args, dag, ti) + else: + with redirect_stdout(ti.log, logging.INFO), redirect_stderr(ti.log, logging.WARN): + _run(args, dag, ti) + logging.shutdown() + + +@cli_utils.action_logging +def task_failed_deps(args): + """ + Returns the unmet dependencies for a task instance from the perspective of the + scheduler (i.e. why a task instance doesn't get scheduled and then queued by the + scheduler, and then run by an executor). + >>> airflow tasks failed_deps tutorial sleep 2015-01-01 + Task instance dependencies not met: + Dagrun Running: Task instance's dagrun did not exist: Unknown reason + Trigger Rule: Task's trigger rule 'all_success' requires all upstream tasks + to have succeeded, but found 1 non-success(es). + """ + dag = get_dag(args) + task = dag.get_task(task_id=args.task_id) + ti = TaskInstance(task, args.execution_date) + + dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS) + failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context)) + # TODO, Do we want to print or log this + if failed_deps: + print("Task instance dependencies not met:") + for dep in failed_deps: + print("{}: {}".format(dep.dep_name, dep.reason)) + else: + print("Task instance dependencies are all met.") + + +@cli_utils.action_logging +def task_state(args): + """ + Returns the state of a TaskInstance at the command line. + >>> airflow tasks state tutorial sleep 2015-01-01 + success + """ + dag = get_dag(args) + task = dag.get_task(task_id=args.task_id) + ti = TaskInstance(task, args.execution_date) + print(ti.current_state()) + + +@cli_utils.action_logging +def task_list(args, dag=None): + """Lists the tasks within a DAG at the command line""" + dag = dag or get_dag(args) + if args.tree: + dag.tree_view() + else: + tasks = sorted([t.task_id for t in dag.tasks]) + print("\n".join(sorted(tasks))) + + +@cli_utils.action_logging +def task_test(args, dag=None): + """Tests task for a given dag_id""" + # We want log outout from operators etc to show up here. Normally + # airflow.task would redirect to a file, but here we want it to propagate + # up to the normal airflow handler. 
+ logging.getLogger('airflow.task').propagate = True + + dag = dag or get_dag(args) + + task = dag.get_task(task_id=args.task_id) + # Add CLI provided task_params to task.params + if args.task_params: + passed_in_params = json.loads(args.task_params) + task.params.update(passed_in_params) + ti = TaskInstance(task, args.execution_date) + + try: + if args.dry_run: + ti.dry_run() + else: + ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True) + except Exception: # pylint: disable=broad-except + if args.post_mortem: + try: + debugger = importlib.import_module("ipdb") + except ImportError: + debugger = importlib.import_module("pdb") + debugger.post_mortem() + else: + raise + + +@cli_utils.action_logging +def task_render(args): + """Renders and displays templated fields for a given task""" + dag = get_dag(args) + task = dag.get_task(task_id=args.task_id) + ti = TaskInstance(task, args.execution_date) + ti.render_templates() + for attr in task.__class__.template_fields: + print(textwrap.dedent("""\ + # ---------------------------------------------------------- + # property: {} + # ---------------------------------------------------------- + {} + """.format(attr, getattr(task, attr)))) + + +@cli_utils.action_logging +def task_clear(args): + """Clears all task instances or only those matched by regex for a DAG(s)""" + logging.basicConfig( + level=settings.LOGGING_LEVEL, + format=settings.SIMPLE_LOG_FORMAT) + dags = get_dags(args) + + if args.task_regex: + for idx, dag in enumerate(dags): + dags[idx] = dag.sub_dag( + task_regex=args.task_regex, + include_downstream=args.downstream, + include_upstream=args.upstream) + + DAG.clear_dags( + dags, + start_date=args.start_date, + end_date=args.end_date, + only_failed=args.only_failed, + only_running=args.only_running, + confirm_prompt=not args.yes, + include_subdags=not args.exclude_subdags, + include_parentdag=not args.exclude_parentdag, + ) diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py index d33ca806d6e92..39d9647d2abcc 100644 --- a/airflow/utils/cli.py +++ b/airflow/utils/cli.py @@ -24,12 +24,15 @@ import functools import getpass import json +import os +import re import socket import sys from argparse import Namespace from datetime import datetime -from airflow.models import Log +from airflow import AirflowException, settings +from airflow.models import DagBag, Log from airflow.utils import cli_action_loggers @@ -115,3 +118,35 @@ def _build_metrics(func_name, namespace): execution_date=metrics.get('execution_date')) metrics['log'] = log return metrics + + +def process_subdir(subdir): + """Expands path to absolute by replacing 'DAGS_FOLDER', '~', '.', etc.""" + if subdir: + subdir = subdir.replace('DAGS_FOLDER', settings.DAGS_FOLDER) + subdir = os.path.abspath(os.path.expanduser(subdir)) + return subdir + + +def get_dag(args): + """Returns DAG of a given dag_id""" + dagbag = DagBag(process_subdir(args.subdir)) + if args.dag_id not in dagbag.dags: + raise AirflowException( + 'dag_id could not be found: {}. Either the dag did not exist or it failed to ' + 'parse.'.format(args.dag_id)) + return dagbag.dags[args.dag_id] + + +def get_dags(args): + """Returns DAG(s) matching a given regex or dag_id""" + if not args.dag_regex: + return [get_dag(args)] + dagbag = DagBag(process_subdir(args.subdir)) + matched_dags = [dag for dag in dagbag.dags.values() if re.search( + args.dag_id, dag.dag_id)] + if not matched_dags: + raise AirflowException( + 'dag_id could not be found with regex: {}. 
Either the dag did not exist ' + 'or it failed to parse.'.format(args.dag_id)) + return matched_dags diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py new file mode 100644 index 0000000000000..4a37a05dc2f9a --- /dev/null +++ b/tests/cli/commands/test_task_command.py @@ -0,0 +1,293 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import io +import sys +import unittest +from argparse import Namespace +from datetime import datetime, timedelta +from unittest import mock +from unittest.mock import MagicMock + +from airflow import models +from airflow.bin import cli +from airflow.cli.commands import task_command +from airflow.models import DagBag, TaskInstance +from airflow.settings import Session +from airflow.utils import timezone +from airflow.utils.cli import get_dag +from airflow.utils.state import State +from tests.test_utils.db import clear_db_pools, clear_db_runs + +DEFAULT_DATE = timezone.make_aware(datetime(2016, 1, 1)) + + +def reset(dag_id): + session = Session() + tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id) + tis.delete() + session.commit() + session.close() + + +def create_mock_args( # pylint: disable=too-many-arguments + task_id, + dag_id, + subdir, + execution_date, + task_params=None, + dry_run=False, + queue=None, + pool=None, + priority_weight_total=None, + retries=0, + local=True, + mark_success=False, + ignore_all_dependencies=False, + ignore_depends_on_past=False, + ignore_dependencies=False, + force=False, + run_as_user=None, + executor_config=None, + cfg_path=None, + pickle=None, + raw=None, + interactive=None, +): + if executor_config is None: + executor_config = {} + args = MagicMock(spec=Namespace) + args.task_id = task_id + args.dag_id = dag_id + args.subdir = subdir + args.task_params = task_params + args.execution_date = execution_date + args.dry_run = dry_run + args.queue = queue + args.pool = pool + args.priority_weight_total = priority_weight_total + args.retries = retries + args.local = local + args.run_as_user = run_as_user + args.executor_config = executor_config + args.cfg_path = cfg_path + args.pickle = pickle + args.raw = raw + args.mark_success = mark_success + args.ignore_all_dependencies = ignore_all_dependencies + args.ignore_depends_on_past = ignore_depends_on_past + args.ignore_dependencies = ignore_dependencies + args.force = force + args.interactive = interactive + return args + + +class TestCliTasks(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dagbag = models.DagBag(include_examples=True) + cls.parser = cli.CLIFactory.get_parser() + + def test_cli_list_tasks(self): + for dag_id in self.dagbag.dags: + args = self.parser.parse_args(['tasks', 'list', dag_id]) + task_command.task_list(args) + + args = 
self.parser.parse_args([ + 'tasks', 'list', 'example_bash_operator', '--tree']) + task_command.task_list(args) + + def test_test(self): + """Test the `airflow test` command""" + args = create_mock_args( + task_id='print_the_context', + dag_id='example_python_operator', + subdir=None, + execution_date=timezone.parse('2018-01-01') + ) + + saved_stdout = sys.stdout + try: + sys.stdout = out = io.StringIO() + task_command.task_test(args) + + output = out.getvalue() + # Check that prints, and log messages, are shown + self.assertIn("'example_python_operator__print_the_context__20180101'", output) + finally: + sys.stdout = saved_stdout + + @mock.patch("airflow.cli.commands.task_command.jobs.LocalTaskJob") + def test_run_naive_taskinstance(self, mock_local_job): + """ + Test that we can run naive (non-localized) task instances + """ + naive_date = datetime(2016, 1, 1) + dag_id = 'test_run_ignores_all_dependencies' + + dag = self.dagbag.get_dag('test_run_ignores_all_dependencies') + + task0_id = 'test_run_dependent_task' + args0 = ['tasks', + 'run', + '-A', + '--local', + dag_id, + task0_id, + naive_date.isoformat()] + + task_command.task_run(self.parser.parse_args(args0), dag=dag) + mock_local_job.assert_called_once_with( + task_instance=mock.ANY, + mark_success=False, + ignore_all_deps=True, + ignore_depends_on_past=False, + ignore_task_deps=False, + ignore_ti_state=False, + pickle_id=None, + pool=None, + ) + + def test_cli_test(self): + task_command.task_test(self.parser.parse_args([ + 'tasks', 'test', 'example_bash_operator', 'runme_0', + DEFAULT_DATE.isoformat()])) + task_command.task_test(self.parser.parse_args([ + 'tasks', 'test', 'example_bash_operator', 'runme_0', '--dry_run', + DEFAULT_DATE.isoformat()])) + + def test_cli_test_with_params(self): + task_command.task_test(self.parser.parse_args([ + 'tasks', 'test', 'example_passing_params_via_test_command', 'run_this', + '-tp', '{"foo":"bar"}', DEFAULT_DATE.isoformat()])) + task_command.task_test(self.parser.parse_args([ + 'tasks', 'test', 'example_passing_params_via_test_command', 'also_run_this', + '-tp', '{"foo":"bar"}', DEFAULT_DATE.isoformat()])) + + def test_cli_run(self): + task_command.task_run(self.parser.parse_args([ + 'tasks', 'run', 'example_bash_operator', 'runme_0', '-l', + DEFAULT_DATE.isoformat()])) + + def test_task_state(self): + task_command.task_state(self.parser.parse_args([ + 'tasks', 'state', 'example_bash_operator', 'runme_0', + DEFAULT_DATE.isoformat()])) + + def test_subdag_clear(self): + args = self.parser.parse_args([ + 'tasks', 'clear', 'example_subdag_operator', '--yes']) + task_command.task_clear(args) + args = self.parser.parse_args([ + 'tasks', 'clear', 'example_subdag_operator', '--yes', '--exclude_subdags']) + task_command.task_clear(args) + + def test_parentdag_downstream_clear(self): + args = self.parser.parse_args([ + 'tasks', 'clear', 'example_subdag_operator.section-1', '--yes']) + task_command.task_clear(args) + args = self.parser.parse_args([ + 'tasks', 'clear', 'example_subdag_operator.section-1', '--yes', + '--exclude_parentdag']) + task_command.task_clear(args) + + def test_local_run(self): + args = create_mock_args( + task_id='print_the_context', + dag_id='example_python_operator', + subdir='/root/dags/example_python_operator.py', + interactive=True, + execution_date=timezone.parse('2018-04-27T08:39:51.298439+00:00') + ) + dag = get_dag(args) + reset(dag.dag_id) + + with mock.patch('argparse.Namespace', args) as mock_args: + task_command.task_run(mock_args) + task = 
dag.get_task(task_id=args.task_id) + ti = TaskInstance(task, args.execution_date) + ti.refresh_from_db() + state = ti.current_state() + self.assertEqual(state, State.SUCCESS) + + +class TestCliTaskBackfill(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dagbag = DagBag(include_examples=True) + + def setUp(self): + clear_db_runs() + clear_db_pools() + + self.parser = cli.CLIFactory.get_parser() + + def test_run_ignores_all_dependencies(self): + """ + Test that run respects ignore_all_dependencies + """ + dag_id = 'test_run_ignores_all_dependencies' + + dag = self.dagbag.get_dag('test_run_ignores_all_dependencies') + dag.clear() + + task0_id = 'test_run_dependent_task' + args0 = ['tasks', + 'run', + '-A', + dag_id, + task0_id, + DEFAULT_DATE.isoformat()] + task_command.task_run(self.parser.parse_args(args0)) + ti_dependent0 = TaskInstance( + task=dag.get_task(task0_id), + execution_date=DEFAULT_DATE) + + ti_dependent0.refresh_from_db() + self.assertEqual(ti_dependent0.state, State.FAILED) + + task1_id = 'test_run_dependency_task' + args1 = ['tasks', + 'run', + '-A', + dag_id, + task1_id, + (DEFAULT_DATE + timedelta(days=1)).isoformat()] + task_command.task_run(self.parser.parse_args(args1)) + + ti_dependency = TaskInstance( + task=dag.get_task(task1_id), + execution_date=DEFAULT_DATE + timedelta(days=1)) + ti_dependency.refresh_from_db() + self.assertEqual(ti_dependency.state, State.FAILED) + + task2_id = 'test_run_dependent_task' + args2 = ['tasks', + 'run', + '-A', + dag_id, + task2_id, + (DEFAULT_DATE + timedelta(days=1)).isoformat()] + task_command.task_run(self.parser.parse_args(args2)) + + ti_dependent = TaskInstance( + task=dag.get_task(task2_id), + execution_date=DEFAULT_DATE + timedelta(days=1)) + ti_dependent.refresh_from_db() + self.assertEqual(ti_dependent.state, State.SUCCESS) diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index e025e39a9b66a..17589c8858850 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -22,7 +22,6 @@ import os import re import subprocess -import sys import tempfile import unittest from argparse import Namespace @@ -35,8 +34,8 @@ import airflow.bin.cli as cli from airflow import AirflowException, models, settings -from airflow.bin.cli import get_dag, get_num_ready_workers_running, task_run -from airflow.models import Connection, DagModel, Pool, TaskInstance, Variable +from airflow.bin.cli import get_num_ready_workers_running +from airflow.models import Connection, DagModel, Pool, Variable from airflow.settings import Session from airflow.utils import db, timezone from airflow.utils.db import add_default_pool_if_not_exists @@ -54,14 +53,6 @@ TEST_DAG_ID = 'unit_tests' -def reset(dag_id): - session = Session() - tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id) - tis.delete() - session.commit() - session.close() - - def create_mock_args( # pylint: disable=too-many-arguments task_id, dag_id, @@ -176,29 +167,6 @@ def test_cli_webserver_debug(self): proc.terminate() proc.wait() - def test_local_run(self): - args = create_mock_args( - task_id='print_the_context', - dag_id='example_python_operator', - subdir='/root/dags/example_python_operator.py', - interactive=True, - execution_date=timezone.parse('2018-04-27T08:39:51.298439+00:00') - ) - - reset(args.dag_id) - - with patch('argparse.Namespace', args) as mock_args: - task_run(mock_args) - dag = get_dag(mock_args) - task = dag.get_task(task_id=args.task_id) - ti = TaskInstance(task, args.execution_date) - ti.refresh_from_db() - state = ti.current_state() - 
self.assertEqual(state, State.SUCCESS) - - def test_process_subdir_path_with_placeholder(self): - self.assertEqual(os.path.join(settings.DAGS_FOLDER, 'abc'), cli.process_subdir('DAGS_FOLDER/abc')) - class TestCliDags(unittest.TestCase): @@ -538,127 +506,6 @@ def test_dag_state(self): 'dags', 'state', 'example_bash_operator', DEFAULT_DATE.isoformat()]))) -class TestCliTasks(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.dagbag = models.DagBag(include_examples=True) - cls.parser = cli.CLIFactory.get_parser() - - def test_cli_list_tasks(self): - for dag_id in self.dagbag.dags: - args = self.parser.parse_args(['tasks', 'list', dag_id]) - cli.task_list(args) - - args = self.parser.parse_args([ - 'tasks', 'list', 'example_bash_operator', '--tree']) - cli.task_list(args) - - def test_test(self): - """Test the `airflow test` command""" - args = create_mock_args( - task_id='print_the_context', - dag_id='example_python_operator', - subdir=None, - execution_date=timezone.parse('2018-01-01') - ) - - saved_stdout = sys.stdout - try: - sys.stdout = out = io.StringIO() - cli.task_test(args) - - output = out.getvalue() - # Check that prints, and log messages, are shown - self.assertIn("'example_python_operator__print_the_context__20180101'", output) - finally: - sys.stdout = saved_stdout - - @mock.patch("airflow.bin.cli.jobs.LocalTaskJob") - def test_run_naive_taskinstance(self, mock_local_job): - """ - Test that we can run naive (non-localized) task instances - """ - naive_date = datetime(2016, 1, 1) - dag_id = 'test_run_ignores_all_dependencies' - - dag = self.dagbag.get_dag('test_run_ignores_all_dependencies') - - task0_id = 'test_run_dependent_task' - args0 = ['tasks', - 'run', - '-A', - '--local', - dag_id, - task0_id, - naive_date.isoformat()] - - cli.task_run(self.parser.parse_args(args0), dag=dag) - mock_local_job.assert_called_once_with( - task_instance=mock.ANY, - mark_success=False, - ignore_all_deps=True, - ignore_depends_on_past=False, - ignore_task_deps=False, - ignore_ti_state=False, - pickle_id=None, - pool=None, - ) - - def test_cli_test(self): - cli.task_test(self.parser.parse_args([ - 'tasks', 'test', 'example_bash_operator', 'runme_0', - DEFAULT_DATE.isoformat()])) - cli.task_test(self.parser.parse_args([ - 'tasks', 'test', 'example_bash_operator', 'runme_0', '--dry_run', - DEFAULT_DATE.isoformat()])) - - def test_cli_test_with_params(self): - cli.task_test(self.parser.parse_args([ - 'tasks', 'test', 'example_passing_params_via_test_command', 'run_this', - '-tp', '{"foo":"bar"}', DEFAULT_DATE.isoformat()])) - cli.task_test(self.parser.parse_args([ - 'tasks', 'test', 'example_passing_params_via_test_command', 'also_run_this', - '-tp', '{"foo":"bar"}', DEFAULT_DATE.isoformat()])) - - def test_cli_run(self): - cli.task_run(self.parser.parse_args([ - 'tasks', 'run', 'example_bash_operator', 'runme_0', '-l', - DEFAULT_DATE.isoformat()])) - - def test_task_state(self): - cli.task_state(self.parser.parse_args([ - 'tasks', 'state', 'example_bash_operator', 'runme_0', - DEFAULT_DATE.isoformat()])) - - def test_subdag_clear(self): - args = self.parser.parse_args([ - 'tasks', 'clear', 'example_subdag_operator', '--yes']) - cli.task_clear(args) - args = self.parser.parse_args([ - 'tasks', 'clear', 'example_subdag_operator', '--yes', '--exclude_subdags']) - cli.task_clear(args) - - def test_parentdag_downstream_clear(self): - args = self.parser.parse_args([ - 'tasks', 'clear', 'example_subdag_operator.section-1', '--yes']) - cli.task_clear(args) - args = self.parser.parse_args([ - 
'tasks', 'clear', 'example_subdag_operator.section-1', '--yes', - '--exclude_parentdag']) - cli.task_clear(args) - - def test_get_dags(self): - dags = cli.get_dags(self.parser.parse_args(['tasks', 'clear', 'example_subdag_operator', - '--yes'])) - self.assertEqual(len(dags), 1) - - dags = cli.get_dags(self.parser.parse_args(['tasks', 'clear', 'subdag', '-dx', '--yes'])) - self.assertGreater(len(dags), 1) - - with self.assertRaises(AirflowException): - cli.get_dags(self.parser.parse_args(['tasks', 'clear', 'foobar', '-dx', '--yes'])) - - class TestCliPools(unittest.TestCase): @classmethod def setUpClass(cls): diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index 36baeb38bc54e..739f4c78c24d7 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -811,60 +811,6 @@ def test_backfill_depends_on_past(self): ti.refresh_from_db() self.assertEqual(ti.state, State.SUCCESS) - def test_run_ignores_all_dependencies(self): - """ - Test that run respects ignore_all_dependencies - """ - dag_id = 'test_run_ignores_all_dependencies' - - dag = self.dagbag.get_dag('test_run_ignores_all_dependencies') - dag.clear() - - task0_id = 'test_run_dependent_task' - args0 = ['tasks', - 'run', - '-A', - dag_id, - task0_id, - DEFAULT_DATE.isoformat()] - cli.task_run(self.parser.parse_args(args0)) - ti_dependent0 = TI( - task=dag.get_task(task0_id), - execution_date=DEFAULT_DATE) - - ti_dependent0.refresh_from_db() - self.assertEqual(ti_dependent0.state, State.FAILED) - - task1_id = 'test_run_dependency_task' - args1 = ['tasks', - 'run', - '-A', - dag_id, - task1_id, - (DEFAULT_DATE + datetime.timedelta(days=1)).isoformat()] - cli.task_run(self.parser.parse_args(args1)) - - ti_dependency = TI( - task=dag.get_task(task1_id), - execution_date=DEFAULT_DATE + datetime.timedelta(days=1)) - ti_dependency.refresh_from_db() - self.assertEqual(ti_dependency.state, State.FAILED) - - task2_id = 'test_run_dependent_task' - args2 = ['tasks', - 'run', - '-A', - dag_id, - task2_id, - (DEFAULT_DATE + datetime.timedelta(days=1)).isoformat()] - cli.task_run(self.parser.parse_args(args2)) - - ti_dependent = TI( - task=dag.get_task(task2_id), - execution_date=DEFAULT_DATE + datetime.timedelta(days=1)) - ti_dependent.refresh_from_db() - self.assertEqual(ti_dependent.state, State.SUCCESS) - def test_backfill_depends_on_past_backwards(self): """ Test that CLI respects -B argument and raises on interaction with depends_on_past diff --git a/tests/utils/test_cli_util.py b/tests/utils/test_cli_util.py index 867827c0c40e6..f52072ed85284 100644 --- a/tests/utils/test_cli_util.py +++ b/tests/utils/test_cli_util.py @@ -24,6 +24,8 @@ from contextlib import contextmanager from datetime import datetime +from airflow import AirflowException, settings +from airflow.bin.cli import CLIFactory from airflow.utils import cli, cli_action_loggers @@ -71,6 +73,20 @@ def test_success_function(self): with fail_action_logger_callback(): success_func(Namespace()) + def test_process_subdir_path_with_placeholder(self): + self.assertEqual(os.path.join(settings.DAGS_FOLDER, 'abc'), cli.process_subdir('DAGS_FOLDER/abc')) + + def test_get_dags(self): + parser = CLIFactory.get_parser() + dags = cli.get_dags(parser.parse_args(['tasks', 'clear', 'example_subdag_operator', '--yes'])) + self.assertEqual(len(dags), 1) + + dags = cli.get_dags(parser.parse_args(['tasks', 'clear', 'subdag', '-dx', '--yes'])) + self.assertGreater(len(dags), 1) + + with self.assertRaises(AirflowException): + 
cli.get_dags(parser.parse_args(['tasks', 'clear', 'foobar', '-dx', '--yes'])) + @contextmanager def fail_action_logger_callback(): From 5e41d286e89f6ad3071b797863bbc29ec547c8aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Sun, 17 Nov 2019 06:08:00 +0100 Subject: [PATCH 07/23] [AIRLFOW-YYY] Move pool commands to separate file --- airflow/bin/cli.py | 117 ++------------------- airflow/cli/commands/pool_command.py | 132 ++++++++++++++++++++++++ tests/cli/commands/test_pool_command.py | 113 ++++++++++++++++++++ tests/cli/test_cli.py | 89 +--------------- 4 files changed, 253 insertions(+), 198 deletions(-) create mode 100644 airflow/cli/commands/pool_command.py create mode 100644 tests/cli/commands/test_pool_command.py diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index c790390113a63..f1e1d6d9405e8 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -45,7 +45,7 @@ from airflow import api, jobs, settings from airflow.api.client import get_current_api_client from airflow.cli.commands import ( - role_command, rotate_fernet_key_command, sync_perm_command, task_command, user_command, + pool_command, role_command, rotate_fernet_key_command, sync_perm_command, task_command, user_command, ) from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowWebServerTimeout @@ -226,109 +226,6 @@ def dag_delete(args): print("Bail.") -def _tabulate_pools(pools, tablefmt="fancy_grid"): - return "\n%s" % tabulate(pools, ['Pool', 'Slots', 'Description'], - tablefmt=tablefmt) - - -def pool_list(args): - """Displays info of all the pools""" - api_client = get_current_api_client() - log = LoggingMixin().log - pools = api_client.get_pools() - log.info(_tabulate_pools(pools=pools, tablefmt=args.output)) - - -def pool_get(args): - """Displays pool info by a given name""" - api_client = get_current_api_client() - log = LoggingMixin().log - pools = [api_client.get_pool(name=args.pool)] - log.info(_tabulate_pools(pools=pools, tablefmt=args.output)) - - -@cli_utils.action_logging -def pool_set(args): - """Creates new pool with a given name and slots""" - api_client = get_current_api_client() - log = LoggingMixin().log - pools = [api_client.create_pool(name=args.pool, - slots=args.slots, - description=args.description)] - log.info(_tabulate_pools(pools=pools, tablefmt=args.output)) - - -@cli_utils.action_logging -def pool_delete(args): - """Deletes pool by a given name""" - api_client = get_current_api_client() - log = LoggingMixin().log - pools = [api_client.delete_pool(name=args.pool)] - log.info(_tabulate_pools(pools=pools, tablefmt=args.output)) - - -@cli_utils.action_logging -def pool_import(args): - """Imports pools from the file""" - api_client = get_current_api_client() - log = LoggingMixin().log - if os.path.exists(args.file): - pools = pool_import_helper(args.file) - else: - print("Missing pools file.") - pools = api_client.get_pools() - log.info(_tabulate_pools(pools=pools, tablefmt=args.output)) - - -def pool_export(args): - """Exports all of the pools to the file""" - log = LoggingMixin().log - pools = pool_export_helper(args.file) - log.info(_tabulate_pools(pools=pools, tablefmt=args.output)) - - -def pool_import_helper(filepath): - """Helps import pools from the json file""" - api_client = get_current_api_client() - - with open(filepath, 'r') as poolfile: - data = poolfile.read() - try: # pylint: disable=too-many-nested-blocks - pools_json = json.loads(data) - except Exception as e: # pylint: disable=broad-except - print("Please check 
the validity of the json file: " + str(e)) - else: - try: - pools = [] - counter = 0 - for k, v in pools_json.items(): - if isinstance(v, dict) and len(v) == 2: - pools.append(api_client.create_pool(name=k, - slots=v["slots"], - description=v["description"])) - counter += 1 - else: - pass - except Exception: # pylint: disable=broad-except - pass - finally: - print("{} of {} pool(s) successfully updated.".format(counter, len(pools_json))) - return pools # pylint: disable=lost-exception - - -def pool_export_helper(filepath): - """Helps export all of the pools to the json file""" - api_client = get_current_api_client() - pool_dict = {} - pools = api_client.get_pools() - for pool in pools: - pool_dict[pool[0]] = {"slots": pool[1], "description": pool[2]} - with open(filepath, 'w') as poolfile: - poolfile.write(json.dumps(pool_dict, sort_keys=True, indent=4)) - print("{} pools successfully exported to {}".format(len(pool_dict), filepath)) - return pools - - def variables_list(args): """Displays all of the variables""" with db.create_session() as session: @@ -1880,37 +1777,37 @@ class CLIFactory: 'name': 'pools', 'subcommands': ( { - 'func': pool_list, + 'func': pool_command.pool_list, 'name': 'list', 'help': 'List pools', 'args': ('output',), }, { - 'func': pool_get, + 'func': pool_command.pool_get, 'name': 'get', 'help': 'Get pool size', 'args': ('pool_name', 'output',), }, { - 'func': pool_set, + 'func': pool_command.pool_set, 'name': 'set', 'help': 'Configure pool', 'args': ('pool_name', 'pool_slots', 'pool_description', 'output',), }, { - 'func': pool_delete, + 'func': pool_command.pool_delete, 'name': 'delete', 'help': 'Delete pool', 'args': ('pool_name', 'output',), }, { - 'func': pool_import, + 'func': pool_command.pool_import, 'name': 'import', 'help': 'Import pool', 'args': ('pool_import', 'output',), }, { - 'func': pool_export, + 'func': pool_command.pool_export, 'name': 'export', 'help': 'Export pool', 'args': ('pool_export', 'output',), diff --git a/airflow/cli/commands/pool_command.py b/airflow/cli/commands/pool_command.py new file mode 100644 index 0000000000000..235e2534d959f --- /dev/null +++ b/airflow/cli/commands/pool_command.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+""" +Pools sub-commands +""" +import json +import os + +from tabulate import tabulate + +from airflow import LoggingMixin +from airflow.api.client import get_current_api_client +from airflow.utils import cli as cli_utils + + +def _tabulate_pools(pools, tablefmt="fancy_grid"): + return "\n%s" % tabulate(pools, ['Pool', 'Slots', 'Description'], + tablefmt=tablefmt) + + +def pool_list(args): + """Displays info of all the pools""" + api_client = get_current_api_client() + log = LoggingMixin().log + pools = api_client.get_pools() + log.info(_tabulate_pools(pools=pools, tablefmt=args.output)) + + +def pool_get(args): + """Displays pool info by a given name""" + api_client = get_current_api_client() + log = LoggingMixin().log + pools = [api_client.get_pool(name=args.pool)] + log.info(_tabulate_pools(pools=pools, tablefmt=args.output)) + + +@cli_utils.action_logging +def pool_set(args): + """Creates new pool with a given name and slots""" + api_client = get_current_api_client() + log = LoggingMixin().log + pools = [api_client.create_pool(name=args.pool, + slots=args.slots, + description=args.description)] + log.info(_tabulate_pools(pools=pools, tablefmt=args.output)) + + +@cli_utils.action_logging +def pool_delete(args): + """Deletes pool by a given name""" + api_client = get_current_api_client() + log = LoggingMixin().log + pools = [api_client.delete_pool(name=args.pool)] + log.info(_tabulate_pools(pools=pools, tablefmt=args.output)) + + +@cli_utils.action_logging +def pool_import(args): + """Imports pools from the file""" + api_client = get_current_api_client() + log = LoggingMixin().log + if os.path.exists(args.file): + pools = pool_import_helper(args.file) + else: + print("Missing pools file.") + pools = api_client.get_pools() + log.info(_tabulate_pools(pools=pools, tablefmt=args.output)) + + +def pool_export(args): + """Exports all of the pools to the file""" + log = LoggingMixin().log + pools = pool_export_helper(args.file) + log.info(_tabulate_pools(pools=pools, tablefmt=args.output)) + + +def pool_import_helper(filepath): + """Helps import pools from the json file""" + api_client = get_current_api_client() + + with open(filepath, 'r') as poolfile: + data = poolfile.read() + try: # pylint: disable=too-many-nested-blocks + pools_json = json.loads(data) + except Exception as e: # pylint: disable=broad-except + print("Please check the validity of the json file: " + str(e)) + else: + try: + pools = [] + counter = 0 + for k, v in pools_json.items(): + if isinstance(v, dict) and len(v) == 2: + pools.append(api_client.create_pool(name=k, + slots=v["slots"], + description=v["description"])) + counter += 1 + else: + pass + except Exception: # pylint: disable=broad-except + pass + finally: + print("{} of {} pool(s) successfully updated.".format(counter, len(pools_json))) + return pools # pylint: disable=lost-exception + + +def pool_export_helper(filepath): + """Helps export all of the pools to the json file""" + api_client = get_current_api_client() + pool_dict = {} + pools = api_client.get_pools() + for pool in pools: + pool_dict[pool[0]] = {"slots": pool[1], "description": pool[2]} + with open(filepath, 'w') as poolfile: + poolfile.write(json.dumps(pool_dict, sort_keys=True, indent=4)) + print("{} pools successfully exported to {}".format(len(pool_dict), filepath)) + return pools diff --git a/tests/cli/commands/test_pool_command.py b/tests/cli/commands/test_pool_command.py new file mode 100644 index 0000000000000..4a1f0ed64e7d1 --- /dev/null +++ b/tests/cli/commands/test_pool_command.py @@ -0,0 
+1,113 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import json +import os +import unittest + +from airflow import models, settings +from airflow.bin import cli +from airflow.cli.commands import pool_command +from airflow.models import Pool +from airflow.settings import Session +from airflow.utils.db import add_default_pool_if_not_exists + + +class TestCliPools(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dagbag = models.DagBag(include_examples=True) + cls.parser = cli.CLIFactory.get_parser() + + def setUp(self): + super().setUp() + settings.configure_orm() + self.session = Session + self._cleanup() + + def tearDown(self): + self._cleanup() + + @staticmethod + def _cleanup(session=None): + if session is None: + session = Session() + session.query(Pool).filter(Pool.pool != Pool.DEFAULT_POOL_NAME).delete() + session.commit() + add_default_pool_if_not_exists() + session.close() + + def test_pool_list(self): + pool_command.pool_set(self.parser.parse_args(['pools', 'set', 'foo', '1', 'test'])) + with self.assertLogs(level='INFO') as cm: + pool_command.pool_list(self.parser.parse_args(['pools', 'list'])) + + stdout = cm.output + + self.assertIn('foo', stdout[0]) + + def test_pool_list_with_args(self): + pool_command.pool_list(self.parser.parse_args(['pools', 'list', '--output', 'tsv'])) + + def test_pool_create(self): + pool_command.pool_set(self.parser.parse_args(['pools', 'set', 'foo', '1', 'test'])) + self.assertEqual(self.session.query(Pool).count(), 2) + + def test_pool_get(self): + pool_command.pool_set(self.parser.parse_args(['pools', 'set', 'foo', '1', 'test'])) + pool_command.pool_get(self.parser.parse_args(['pools', 'get', 'foo'])) + + def test_pool_delete(self): + pool_command.pool_set(self.parser.parse_args(['pools', 'set', 'foo', '1', 'test'])) + pool_command.pool_delete(self.parser.parse_args(['pools', 'delete', 'foo'])) + self.assertEqual(self.session.query(Pool).count(), 1) + + def test_pool_import_export(self): + # Create two pools first + pool_config_input = { + "foo": { + "description": "foo_test", + "slots": 1 + }, + 'default_pool': { + 'description': 'Default pool', + 'slots': 128 + }, + "baz": { + "description": "baz_test", + "slots": 2 + } + } + with open('pools_import.json', mode='w') as file: + json.dump(pool_config_input, file) + + # Import json + pool_command.pool_import(self.parser.parse_args(['pools', 'import', 'pools_import.json'])) + + # Export json + pool_command.pool_export(self.parser.parse_args(['pools', 'export', 'pools_export.json'])) + + with open('pools_export.json', mode='r') as file: + pool_config_output = json.load(file) + self.assertEqual( + pool_config_input, + pool_config_output, + "Input and output pool files are not same") + os.remove('pools_import.json') + 
os.remove('pools_export.json') diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 17589c8858850..833793b1bd10c 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -18,7 +18,6 @@ # under the License. import contextlib import io -import json import os import re import subprocess @@ -35,10 +34,9 @@ import airflow.bin.cli as cli from airflow import AirflowException, models, settings from airflow.bin.cli import get_num_ready_workers_running -from airflow.models import Connection, DagModel, Pool, Variable +from airflow.models import Connection, DagModel, Variable from airflow.settings import Session from airflow.utils import db, timezone -from airflow.utils.db import add_default_pool_if_not_exists from airflow.utils.state import State from airflow.version import version from tests import conf_vars @@ -506,91 +504,6 @@ def test_dag_state(self): 'dags', 'state', 'example_bash_operator', DEFAULT_DATE.isoformat()]))) -class TestCliPools(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.dagbag = models.DagBag(include_examples=True) - cls.parser = cli.CLIFactory.get_parser() - - def setUp(self): - super().setUp() - settings.configure_orm() - self.session = Session - self._cleanup() - - def tearDown(self): - self._cleanup() - - @staticmethod - def _cleanup(session=None): - if session is None: - session = Session() - session.query(Pool).filter(Pool.pool != Pool.DEFAULT_POOL_NAME).delete() - session.commit() - add_default_pool_if_not_exists() - session.close() - - def test_pool_list(self): - cli.pool_set(self.parser.parse_args(['pools', 'set', 'foo', '1', 'test'])) - with self.assertLogs(level='INFO') as cm: - cli.pool_list(self.parser.parse_args(['pools', 'list'])) - - stdout = cm.output - - self.assertIn('foo', stdout[0]) - - def test_pool_list_with_args(self): - cli.pool_list(self.parser.parse_args(['pools', 'list', - '--output', 'tsv'])) - - def test_pool_create(self): - cli.pool_set(self.parser.parse_args(['pools', 'set', 'foo', '1', 'test'])) - self.assertEqual(self.session.query(Pool).count(), 2) - - def test_pool_get(self): - cli.pool_set(self.parser.parse_args(['pools', 'set', 'foo', '1', 'test'])) - cli.pool_get(self.parser.parse_args(['pools', 'get', 'foo'])) - - def test_pool_delete(self): - cli.pool_set(self.parser.parse_args(['pools', 'set', 'foo', '1', 'test'])) - cli.pool_delete(self.parser.parse_args(['pools', 'delete', 'foo'])) - self.assertEqual(self.session.query(Pool).count(), 1) - - def test_pool_import_export(self): - # Create two pools first - pool_config_input = { - "foo": { - "description": "foo_test", - "slots": 1 - }, - 'default_pool': { - 'description': 'Default pool', - 'slots': 128 - }, - "baz": { - "description": "baz_test", - "slots": 2 - } - } - with open('pools_import.json', mode='w') as file: - json.dump(pool_config_input, file) - - # Import json - cli.pool_import(self.parser.parse_args(['pools', 'import', 'pools_import.json'])) - - # Export json - cli.pool_export(self.parser.parse_args(['pools', 'export', 'pools_export.json'])) - - with open('pools_export.json', mode='r') as file: - pool_config_output = json.load(file) - self.assertEqual( - pool_config_input, - pool_config_output, - "Input and output pool files are not same") - os.remove('pools_import.json') - os.remove('pools_export.json') - - class TestCliVariables(unittest.TestCase): @classmethod def setUpClass(cls): From a6712c0f6f7bca8375e1111b9cd924ff43743a77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Sun, 17 Nov 2019 06:16:38 +0100 Subject: 
[PATCH 08/23] [AIRLFOW-YYY] Move variable commands to separate file --- airflow/bin/cli.py | 102 ++------------- airflow/cli/commands/variable_command.py | 111 ++++++++++++++++ tests/cli/commands/test_variable_command.py | 132 ++++++++++++++++++++ tests/cli/test_cli.py | 107 +--------------- 4 files changed, 252 insertions(+), 200 deletions(-) create mode 100644 airflow/cli/commands/variable_command.py create mode 100644 tests/cli/commands/test_variable_command.py diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index f1e1d6d9405e8..9761c7e3c8413 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -46,10 +46,11 @@ from airflow.api.client import get_current_api_client from airflow.cli.commands import ( pool_command, role_command, rotate_fernet_key_command, sync_perm_command, task_command, user_command, + variable_command, ) from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowWebServerTimeout -from airflow.models import DAG, Connection, DagBag, DagModel, DagRun, TaskInstance, Variable +from airflow.models import DAG, Connection, DagBag, DagModel, DagRun, TaskInstance from airflow.utils import cli as cli_utils, db from airflow.utils.cli import get_dag, process_subdir from airflow.utils.dot_renderer import render_dag @@ -226,93 +227,6 @@ def dag_delete(args): print("Bail.") -def variables_list(args): - """Displays all of the variables""" - with db.create_session() as session: - variables = session.query(Variable) - print("\n".join(var.key for var in variables)) - - -def variables_get(args): - """Displays variable by a given name""" - try: - var = Variable.get(args.key, - deserialize_json=args.json, - default_var=args.default) - print(var) - except ValueError as e: - print(e) - - -@cli_utils.action_logging -def variables_set(args): - """Creates new variable with a given name and value""" - Variable.set(args.key, args.value, serialize_json=args.json) - - -@cli_utils.action_logging -def variables_delete(args): - """Deletes variable by a given name""" - Variable.delete(args.key) - - -@cli_utils.action_logging -def variables_import(args): - """Imports variables from a given file""" - if os.path.exists(args.file): - import_helper(args.file) - else: - print("Missing variables file.") - - -def variables_export(args): - """Exports all of the variables to the file""" - variable_export_helper(args.file) - - -def import_helper(filepath): - """Helps import variables from the file""" - with open(filepath, 'r') as varfile: - data = varfile.read() - - try: - var_json = json.loads(data) - except Exception: # pylint: disable=broad-except - print("Invalid variables file.") - else: - suc_count = fail_count = 0 - for k, v in var_json.items(): - try: - Variable.set(k, v, serialize_json=not isinstance(v, str)) - except Exception as e: # pylint: disable=broad-except - print('Variable import failed: {}'.format(repr(e))) - fail_count += 1 - else: - suc_count += 1 - print("{} of {} variables successfully updated.".format(suc_count, len(var_json))) - if fail_count: - print("{} variable(s) failed to be updated.".format(fail_count)) - - -def variable_export_helper(filepath): - """Helps export all of the variables to the file""" - var_dict = {} - with db.create_session() as session: - qry = session.query(Variable).all() - - data = json.JSONDecoder() - for var in qry: - try: - val = data.decode(var.val) - except Exception: # pylint: disable=broad-except - val = var.val - var_dict[var.key] = val - - with open(filepath, 'w') as varfile: - varfile.write(json.dumps(var_dict, 
sort_keys=True, indent=4)) - print("{} variables successfully exported to {}".format(len(var_dict), filepath)) - - @cli_utils.action_logging def dag_pause(args): """Pauses a DAG""" @@ -1818,37 +1732,37 @@ class CLIFactory: 'name': 'variables', 'subcommands': ( { - 'func': variables_list, + 'func': variable_command.variables_list, 'name': 'list', 'help': 'List variables', 'args': (), }, { - 'func': variables_get, + 'func': variable_command.variables_get, 'name': 'get', 'help': 'Get variable', 'args': ('var', 'json', 'default'), }, { - 'func': variables_set, + 'func': variable_command.variables_set, 'name': 'set', 'help': 'Set variable', 'args': ('var', 'var_value', 'json'), }, { - 'func': variables_delete, + 'func': variable_command.variables_delete, 'name': 'delete', 'help': 'Delete variable', 'args': ('var',), }, { - 'func': variables_import, + 'func': variable_command.variables_import, 'name': 'import', 'help': 'Import variables', 'args': ('var_import',), }, { - 'func': variables_export, + 'func': variable_command.variables_export, 'name': 'export', 'help': 'Export variables', 'args': ('var_export',), diff --git a/airflow/cli/commands/variable_command.py b/airflow/cli/commands/variable_command.py new file mode 100644 index 0000000000000..f7b0aa9a9be40 --- /dev/null +++ b/airflow/cli/commands/variable_command.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Variable subcommands""" +import json +import os + +from airflow.models import Variable +from airflow.utils import cli as cli_utils, db + + +def variables_list(args): + """Displays all of the variables""" + with db.create_session() as session: + variables = session.query(Variable) + print("\n".join(var.key for var in variables)) + + +def variables_get(args): + """Displays variable by a given name""" + try: + var = Variable.get(args.key, + deserialize_json=args.json, + default_var=args.default) + print(var) + except ValueError as e: + print(e) + + +@cli_utils.action_logging +def variables_set(args): + """Creates new variable with a given name and value""" + Variable.set(args.key, args.value, serialize_json=args.json) + + +@cli_utils.action_logging +def variables_delete(args): + """Deletes variable by a given name""" + Variable.delete(args.key) + + +@cli_utils.action_logging +def variables_import(args): + """Imports variables from a given file""" + if os.path.exists(args.file): + _import_helper(args.file) + else: + print("Missing variables file.") + + +def variables_export(args): + """Exports all of the variables to the file""" + _variable_export_helper(args.file) + + +def _import_helper(filepath): + """Helps import variables from the file""" + with open(filepath, 'r') as varfile: + data = varfile.read() + + try: + var_json = json.loads(data) + except Exception: # pylint: disable=broad-except + print("Invalid variables file.") + else: + suc_count = fail_count = 0 + for k, v in var_json.items(): + try: + Variable.set(k, v, serialize_json=not isinstance(v, str)) + except Exception as e: # pylint: disable=broad-except + print('Variable import failed: {}'.format(repr(e))) + fail_count += 1 + else: + suc_count += 1 + print("{} of {} variables successfully updated.".format(suc_count, len(var_json))) + if fail_count: + print("{} variable(s) failed to be updated.".format(fail_count)) + + +def _variable_export_helper(filepath): + """Helps export all of the variables to the file""" + var_dict = {} + with db.create_session() as session: + qry = session.query(Variable).all() + + data = json.JSONDecoder() + for var in qry: + try: + val = data.decode(var.val) + except Exception: # pylint: disable=broad-except + val = var.val + var_dict[var.key] = val + + with open(filepath, 'w') as varfile: + varfile.write(json.dumps(var_dict, sort_keys=True, indent=4)) + print("{} variables successfully exported to {}".format(len(var_dict), filepath)) diff --git a/tests/cli/commands/test_variable_command.py b/tests/cli/commands/test_variable_command.py new file mode 100644 index 0000000000000..76733abfc1d51 --- /dev/null +++ b/tests/cli/commands/test_variable_command.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +import os +import unittest + +from airflow import models +from airflow.bin import cli +from airflow.cli.commands import variable_command +from airflow.models import Variable + +DEV_NULL = "/dev/null" + + +class TestCliVariables(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dagbag = models.DagBag(include_examples=True) + cls.parser = cli.CLIFactory.get_parser() + + def test_variables(self): + # Checks if all subcommands are properly received + variable_command.variables_set(self.parser.parse_args([ + 'variables', 'set', 'foo', '{"foo":"bar"}'])) + variable_command.variables_get(self.parser.parse_args([ + 'variables', 'get', 'foo'])) + variable_command.variables_get(self.parser.parse_args([ + 'variables', 'get', 'baz', '-d', 'bar'])) + variable_command.variables_list(self.parser.parse_args([ + 'variables', 'list'])) + variable_command.variables_delete(self.parser.parse_args([ + 'variables', 'delete', 'bar'])) + variable_command.variables_import(self.parser.parse_args([ + 'variables', 'import', DEV_NULL])) + variable_command.variables_export(self.parser.parse_args([ + 'variables', 'export', DEV_NULL])) + + variable_command.variables_set(self.parser.parse_args([ + 'variables', 'set', 'bar', 'original'])) + # First export + variable_command.variables_export(self.parser.parse_args([ + 'variables', 'export', 'variables1.json'])) + + first_exp = open('variables1.json', 'r') + + variable_command.variables_set(self.parser.parse_args([ + 'variables', 'set', 'bar', 'updated'])) + variable_command.variables_set(self.parser.parse_args([ + 'variables', 'set', 'foo', '{"foo":"oops"}'])) + variable_command.variables_delete(self.parser.parse_args([ + 'variables', 'delete', 'foo'])) + # First import + variable_command.variables_import(self.parser.parse_args([ + 'variables', 'import', 'variables1.json'])) + + self.assertEqual('original', Variable.get('bar')) + self.assertEqual('{\n "foo": "bar"\n}', Variable.get('foo')) + # Second export + variable_command.variables_export(self.parser.parse_args([ + 'variables', 'export', 'variables2.json'])) + + second_exp = open('variables2.json', 'r') + self.assertEqual(first_exp.read(), second_exp.read()) + second_exp.close() + first_exp.close() + # Second import + variable_command.variables_import(self.parser.parse_args([ + 'variables', 'import', 'variables2.json'])) + + self.assertEqual('original', Variable.get('bar')) + self.assertEqual('{\n "foo": "bar"\n}', Variable.get('foo')) + + # Set a dict + variable_command.variables_set(self.parser.parse_args([ + 'variables', 'set', 'dict', '{"foo": "oops"}'])) + # Set a list + variable_command.variables_set(self.parser.parse_args([ + 'variables', 'set', 'list', '["oops"]'])) + # Set str + variable_command.variables_set(self.parser.parse_args([ + 'variables', 'set', 'str', 'hello string'])) + # Set int + variable_command.variables_set(self.parser.parse_args([ + 'variables', 'set', 'int', '42'])) + # Set float + variable_command.variables_set(self.parser.parse_args([ + 'variables', 'set', 'float', '42.0'])) + # Set true + variable_command.variables_set(self.parser.parse_args([ + 'variables', 'set', 'true', 'true'])) + # Set false + variable_command.variables_set(self.parser.parse_args([ + 'variables', 'set', 'false', 'false'])) + # Set none + variable_command.variables_set(self.parser.parse_args([ + 'variables', 'set', 'null', 'null'])) + + # Export and then import + variable_command.variables_export(self.parser.parse_args([ + 'variables', 'export', 'variables3.json'])) + 
variable_command.variables_import(self.parser.parse_args([ + 'variables', 'import', 'variables3.json'])) + + # Assert value + self.assertEqual({'foo': 'oops'}, Variable.get('dict', deserialize_json=True)) + self.assertEqual(['oops'], Variable.get('list', deserialize_json=True)) + self.assertEqual('hello string', Variable.get('str')) # cannot json.loads(str) + self.assertEqual(42, Variable.get('int', deserialize_json=True)) + self.assertEqual(42.0, Variable.get('float', deserialize_json=True)) + self.assertEqual(True, Variable.get('true', deserialize_json=True)) + self.assertEqual(False, Variable.get('false', deserialize_json=True)) + self.assertEqual(None, Variable.get('null', deserialize_json=True)) + + os.remove('variables1.json') + os.remove('variables2.json') + os.remove('variables3.json') diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 833793b1bd10c..65349ab0dd566 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -34,7 +34,7 @@ import airflow.bin.cli as cli from airflow import AirflowException, models, settings from airflow.bin.cli import get_num_ready_workers_running -from airflow.models import Connection, DagModel, Variable +from airflow.models import Connection, DagModel from airflow.settings import Session from airflow.utils import db, timezone from airflow.utils.state import State @@ -44,7 +44,6 @@ dag_folder_path = '/'.join(os.path.realpath(__file__).split('/')[:-1]) -DEV_NULL = "/dev/null" DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1)) TEST_DAG_FOLDER = os.path.join( os.path.dirname(dag_folder_path), 'dags') @@ -504,110 +503,6 @@ def test_dag_state(self): 'dags', 'state', 'example_bash_operator', DEFAULT_DATE.isoformat()]))) -class TestCliVariables(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.dagbag = models.DagBag(include_examples=True) - cls.parser = cli.CLIFactory.get_parser() - - def test_variables(self): - # Checks if all subcommands are properly received - cli.variables_set(self.parser.parse_args([ - 'variables', 'set', 'foo', '{"foo":"bar"}'])) - cli.variables_get(self.parser.parse_args([ - 'variables', 'get', 'foo'])) - cli.variables_get(self.parser.parse_args([ - 'variables', 'get', 'baz', '-d', 'bar'])) - cli.variables_list(self.parser.parse_args([ - 'variables', 'list'])) - cli.variables_delete(self.parser.parse_args([ - 'variables', 'delete', 'bar'])) - cli.variables_import(self.parser.parse_args([ - 'variables', 'import', DEV_NULL])) - cli.variables_export(self.parser.parse_args([ - 'variables', 'export', DEV_NULL])) - - cli.variables_set(self.parser.parse_args([ - 'variables', 'set', 'bar', 'original'])) - # First export - cli.variables_export(self.parser.parse_args([ - 'variables', 'export', 'variables1.json'])) - - first_exp = open('variables1.json', 'r') - - cli.variables_set(self.parser.parse_args([ - 'variables', 'set', 'bar', 'updated'])) - cli.variables_set(self.parser.parse_args([ - 'variables', 'set', 'foo', '{"foo":"oops"}'])) - cli.variables_delete(self.parser.parse_args([ - 'variables', 'delete', 'foo'])) - # First import - cli.variables_import(self.parser.parse_args([ - 'variables', 'import', 'variables1.json'])) - - self.assertEqual('original', Variable.get('bar')) - self.assertEqual('{\n "foo": "bar"\n}', Variable.get('foo')) - # Second export - cli.variables_export(self.parser.parse_args([ - 'variables', 'export', 'variables2.json'])) - - second_exp = open('variables2.json', 'r') - self.assertEqual(first_exp.read(), second_exp.read()) - second_exp.close() - first_exp.close() - # Second 
import - cli.variables_import(self.parser.parse_args([ - 'variables', 'import', 'variables2.json'])) - - self.assertEqual('original', Variable.get('bar')) - self.assertEqual('{\n "foo": "bar"\n}', Variable.get('foo')) - - # Set a dict - cli.variables_set(self.parser.parse_args([ - 'variables', 'set', 'dict', '{"foo": "oops"}'])) - # Set a list - cli.variables_set(self.parser.parse_args([ - 'variables', 'set', 'list', '["oops"]'])) - # Set str - cli.variables_set(self.parser.parse_args([ - 'variables', 'set', 'str', 'hello string'])) - # Set int - cli.variables_set(self.parser.parse_args([ - 'variables', 'set', 'int', '42'])) - # Set float - cli.variables_set(self.parser.parse_args([ - 'variables', 'set', 'float', '42.0'])) - # Set true - cli.variables_set(self.parser.parse_args([ - 'variables', 'set', 'true', 'true'])) - # Set false - cli.variables_set(self.parser.parse_args([ - 'variables', 'set', 'false', 'false'])) - # Set none - cli.variables_set(self.parser.parse_args([ - 'variables', 'set', 'null', 'null'])) - - # Export and then import - cli.variables_export(self.parser.parse_args([ - 'variables', 'export', 'variables3.json'])) - cli.variables_import(self.parser.parse_args([ - 'variables', 'import', 'variables3.json'])) - - # Assert value - self.assertEqual({'foo': 'oops'}, Variable.get('dict', deserialize_json=True)) - self.assertEqual(['oops'], Variable.get('list', deserialize_json=True)) - self.assertEqual('hello string', Variable.get('str')) # cannot json.loads(str) - self.assertEqual(42, Variable.get('int', deserialize_json=True)) - self.assertEqual(42.0, Variable.get('float', deserialize_json=True)) - self.assertEqual(True, Variable.get('true', deserialize_json=True)) - self.assertEqual(False, Variable.get('false', deserialize_json=True)) - self.assertEqual(None, Variable.get('null', deserialize_json=True)) - - os.remove('variables1.json') - os.remove('variables2.json') - os.remove('variables3.json') - - class TestCliWebServer(unittest.TestCase): @classmethod def setUpClass(cls): From f8f5d148e34a22826153e8152d788edd8b973412 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Mon, 18 Nov 2019 13:29:49 +0100 Subject: [PATCH 09/23] fixup! 
[AIRFLOW-YYY] Move variable commands to separate file

---
 tests/cli/commands/test_variable_command.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/tests/cli/commands/test_variable_command.py b/tests/cli/commands/test_variable_command.py
index 76733abfc1d51..b65ba6b5cdc4e 100644
--- a/tests/cli/commands/test_variable_command.py
+++ b/tests/cli/commands/test_variable_command.py
@@ -25,8 +25,6 @@
 from airflow.cli.commands import variable_command
 from airflow.models import Variable
 
-DEV_NULL = "/dev/null"
-
 
 class TestCliVariables(unittest.TestCase):
     @classmethod
@@ -47,9 +45,9 @@ def test_variables(self):
         variable_command.variables_delete(self.parser.parse_args([
             'variables', 'delete', 'bar']))
         variable_command.variables_import(self.parser.parse_args([
-            'variables', 'import', DEV_NULL]))
+            'variables', 'import', os.devnull]))
         variable_command.variables_export(self.parser.parse_args([
-            'variables', 'export', DEV_NULL]))
+            'variables', 'export', os.devnull]))
 
         variable_command.variables_set(self.parser.parse_args([
             'variables', 'set', 'bar', 'original']))

From 52564a0436957d5a862a8d8ca589dcdbd1fb6ca8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kamil=20Bregu=C5=82a?=
Date: Sun, 17 Nov 2019 06:22:58 +0100
Subject: [PATCH 10/23] [AIRFLOW-YYY] Move db commands to separate file

---
 airflow/bin/cli.py                    | 35 +++------------------
 airflow/cli/commands/db_command.py    | 44 +++++++++++++++++++++++++++
 tests/cli/commands/test_db_command.py | 40 ++++++++++++++++++++++++
 tests/cli/test_cli.py                 | 18 -----------
 4 files changed, 89 insertions(+), 48 deletions(-)
 create mode 100644 airflow/cli/commands/db_command.py
 create mode 100644 tests/cli/commands/test_db_command.py

diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py
index 9761c7e3c8413..d5710a7c576da 100644
--- a/airflow/bin/cli.py
+++ b/airflow/bin/cli.py
@@ -45,8 +45,8 @@
 from airflow import api, jobs, settings
 from airflow.api.client import get_current_api_client
 from airflow.cli.commands import (
-    pool_command, role_command, rotate_fernet_key_command, sync_perm_command, task_command, user_command,
-    variable_command,
+    db_command, pool_command, role_command, rotate_fernet_key_command, sync_perm_command, task_command,
+    user_command, variable_command,
 )
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowWebServerTimeout
@@ -750,31 +750,6 @@ def worker(args):
             sub_proc.kill()
 
 
-def initdb(args):
-    """Initializes the metadata database"""
-    print("DB: " + repr(settings.engine.url))
-    db.initdb()
-    print("Done.")
-
-
-def resetdb(args):
-    """Resets the metadata database"""
-    print("DB: " + repr(settings.engine.url))
-    if args.yes or input("This will drop existing tables "
-                         "if they exist. Proceed? 
" - "(y/n)").upper() == "Y": - db.resetdb() - else: - print("Bail.") - - -@cli_utils.action_logging -def upgradedb(args): - """Upgrades the metadata database""" - print("DB: " + repr(settings.engine.url)) - db.upgradedb() - - @cli_utils.action_logging def version(args): """Displays Airflow version at the command line""" @@ -1775,19 +1750,19 @@ class CLIFactory: 'name': 'db', 'subcommands': ( { - 'func': initdb, + 'func': db_command.initdb, 'name': 'init', 'help': "Initialize the metadata database", 'args': (), }, { - 'func': resetdb, + 'func': db_command.resetdb, 'name': 'reset', 'help': "Burn down and rebuild the metadata database", 'args': ('yes',), }, { - 'func': upgradedb, + 'func': db_command.upgradedb, 'name': 'upgrade', 'help': "Upgrade the metadata database to latest version", 'args': tuple(), diff --git a/airflow/cli/commands/db_command.py b/airflow/cli/commands/db_command.py new file mode 100644 index 0000000000000..a307502fafd2c --- /dev/null +++ b/airflow/cli/commands/db_command.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Database sub-commands""" +from airflow import settings +from airflow.utils import cli as cli_utils, db + + +def initdb(args): + """Initializes the metadata database""" + print("DB: " + repr(settings.engine.url)) + db.initdb() + print("Done.") + + +def resetdb(args): + """Resets the metadata database""" + print("DB: " + repr(settings.engine.url)) + if args.yes or input("This will drop existing tables " + "if they exist. Proceed? " + "(y/n)").upper() == "Y": + db.resetdb() + else: + print("Bail.") + + +@cli_utils.action_logging +def upgradedb(args): + """Upgrades the metadata database""" + print("DB: " + repr(settings.engine.url)) + db.upgradedb() diff --git a/tests/cli/commands/test_db_command.py b/tests/cli/commands/test_db_command.py new file mode 100644 index 0000000000000..a0d6bc8400db7 --- /dev/null +++ b/tests/cli/commands/test_db_command.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+import unittest
+from unittest import mock
+
+from airflow.bin import cli
+from airflow.cli.commands import db_command
+
+
+class TestCliDb(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.parser = cli.CLIFactory.get_parser()
+
+    @mock.patch("airflow.cli.commands.db_command.db.initdb")
+    def test_cli_initdb(self, initdb_mock):
+        db_command.initdb(self.parser.parse_args(['db', 'init']))
+
+        initdb_mock.assert_called_once_with()
+
+    @mock.patch("airflow.cli.commands.db_command.db.resetdb")
+    def test_cli_resetdb(self, resetdb_mock):
+        db_command.resetdb(self.parser.parse_args(['db', 'reset', '--yes']))
+
+        resetdb_mock.assert_called_once_with()
diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py
index 65349ab0dd566..fbe36a783876a 100644
--- a/tests/cli/test_cli.py
+++ b/tests/cli/test_cli.py
@@ -593,24 +593,6 @@ def test_cli_webserver_shutdown_when_gunicorn_master_is_killed(self, _):
         self.assertEqual(e.exception.code, 1)
 
 
-class TestCliDb(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        cls.parser = cli.CLIFactory.get_parser()
-
-    @mock.patch("airflow.bin.cli.db.initdb")
-    def test_cli_initdb(self, initdb_mock):
-        cli.initdb(self.parser.parse_args(['db', 'init']))
-
-        initdb_mock.assert_called_once_with()
-
-    @mock.patch("airflow.bin.cli.db.resetdb")
-    def test_cli_resetdb(self, resetdb_mock):
-        cli.resetdb(self.parser.parse_args(['db', 'reset', '--yes']))
-
-        resetdb_mock.assert_called_once_with()
-
-
From 151bdbb19446ace1ccb2375d0a136c18d90b5ea0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kamil=20Bregu=C5=82a?=
Date: Sun, 17 Nov 2019 06:32:03 +0100
Subject: [PATCH 11/23] [AIRFLOW-YYY] Move connection commands to separate file

---
 airflow/bin/cli.py                            | 120 +---------
 airflow/cli/commands/connection_command.py    | 125 +++++++++++
 airflow/utils/cli.py                          |   4 +
 tests/cli/commands/test_connection_command.py | 211 ++++++++++++++++++
 tests/cli/test_cli.py                         | 187 +----------------
 5 files changed, 349 insertions(+), 298 deletions(-)
 create mode 100644 airflow/cli/commands/connection_command.py
 create mode 100644 tests/cli/commands/test_connection_command.py

diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py
index d5710a7c576da..550d222d10965 100644
--- a/airflow/bin/cli.py
+++ b/airflow/bin/cli.py
@@ -24,7 +24,6 @@
 import json
 import logging
 import os
-import reprlib
 import signal
 import subprocess
 import sys
@@ -33,26 +32,24 @@
 import time
 import traceback
 from argparse import RawTextHelpFormatter
-from urllib.parse import urlunparse
 
 import daemon
 import psutil
 from daemon.pidfile import TimeoutPIDLockFile
-from sqlalchemy.orm import exc
 from tabulate import tabulate, tabulate_formats
 
 import airflow
 from airflow import api, jobs, settings
 from airflow.api.client import get_current_api_client
 from airflow.cli.commands import (
-    db_command, pool_command, role_command, rotate_fernet_key_command, sync_perm_command, task_command,
-    user_command, variable_command,
+    connection_command, db_command, pool_command, role_command, rotate_fernet_key_command, sync_perm_command,
+    task_command, user_command, variable_command,
 )
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowWebServerTimeout
-from airflow.models import DAG, Connection, DagBag, DagModel, DagRun, TaskInstance
+from airflow.models import DAG, DagBag, DagModel, DagRun, TaskInstance
 from airflow.utils import cli as cli_utils, db
-from airflow.utils.cli import get_dag, process_subdir
+from 
airflow.utils.cli import alternative_conn_specs, get_dag, process_subdir from airflow.utils.dot_renderer import render_dag from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.timezone import parse as parsedate @@ -756,109 +753,6 @@ def version(args): print(settings.HEADER + " v" + airflow.__version__) -alternative_conn_specs = ['conn_type', 'conn_host', - 'conn_login', 'conn_password', 'conn_schema', 'conn_port'] - - -def connections_list(args): - """Lists all connections at the command line""" - with db.create_session() as session: - conns = session.query(Connection.conn_id, Connection.conn_type, - Connection.host, Connection.port, - Connection.is_encrypted, - Connection.is_extra_encrypted, - Connection.extra).all() - conns = [map(reprlib.repr, conn) for conn in conns] - msg = tabulate(conns, ['Conn Id', 'Conn Type', 'Host', 'Port', - 'Is Encrypted', 'Is Extra Encrypted', 'Extra'], - tablefmt=args.output) - print(msg) - - -@cli_utils.action_logging -def connections_add(args): - """Adds new connection""" - # Check that the conn_id and conn_uri args were passed to the command: - missing_args = list() - invalid_args = list() - if args.conn_uri: - for arg in alternative_conn_specs: - if getattr(args, arg) is not None: - invalid_args.append(arg) - elif not args.conn_type: - missing_args.append('conn_uri or conn_type') - if missing_args: - msg = ('The following args are required to add a connection:' + - ' {missing!r}'.format(missing=missing_args)) - raise SystemExit(msg) - if invalid_args: - msg = ('The following args are not compatible with the ' + - '--add flag and --conn_uri flag: {invalid!r}') - msg = msg.format(invalid=invalid_args) - raise SystemExit(msg) - - if args.conn_uri: - new_conn = Connection(conn_id=args.conn_id, uri=args.conn_uri) - else: - new_conn = Connection(conn_id=args.conn_id, - conn_type=args.conn_type, - host=args.conn_host, - login=args.conn_login, - password=args.conn_password, - schema=args.conn_schema, - port=args.conn_port) - if args.conn_extra is not None: - new_conn.set_extra(args.conn_extra) - - with db.create_session() as session: - if not (session.query(Connection) - .filter(Connection.conn_id == new_conn.conn_id).first()): - session.add(new_conn) - msg = '\n\tSuccessfully added `conn_id`={conn_id} : {uri}\n' - msg = msg.format(conn_id=new_conn.conn_id, - uri=args.conn_uri or - urlunparse((args.conn_type, - '{login}:{password}@{host}:{port}' - .format(login=args.conn_login or '', - password=args.conn_password or '', - host=args.conn_host or '', - port=args.conn_port or ''), - args.conn_schema or '', '', '', ''))) - print(msg) - else: - msg = '\n\tA connection with `conn_id`={conn_id} already exists\n' - msg = msg.format(conn_id=new_conn.conn_id) - print(msg) - - -@cli_utils.action_logging -def connections_delete(args): - """Deletes connection from DB""" - with db.create_session() as session: - try: - to_delete = (session - .query(Connection) - .filter(Connection.conn_id == args.conn_id) - .one()) - except exc.NoResultFound: - msg = '\n\tDid not find a connection with `conn_id`={conn_id}\n' - msg = msg.format(conn_id=args.conn_id) - print(msg) - return - except exc.MultipleResultsFound: - msg = ('\n\tFound more than one connection with ' + - '`conn_id`={conn_id}\n') - msg = msg.format(conn_id=args.conn_id) - print(msg) - return - else: - deleted_conn_id = to_delete.conn_id - session.delete(to_delete) - msg = '\n\tSuccessfully deleted `conn_id`={conn_id}\n' - msg = msg.format(conn_id=deleted_conn_id) - print(msg) - - 
@cli_utils.action_logging def flower(args): """Starts Flower, Celery monitoring tool""" @@ -1808,19 +1702,19 @@ class CLIFactory: 'name': 'connections', 'subcommands': ( { - 'func': connections_list, + 'func': connection_command.connections_list, 'name': 'list', 'help': 'List connections', 'args': ('output',), }, { - 'func': connections_add, + 'func': connection_command.connections_add, 'name': 'add', 'help': 'Add a connection', 'args': ('conn_id', 'conn_uri', 'conn_extra') + tuple(alternative_conn_specs), }, { - 'func': connections_delete, + 'func': connection_command.connections_delete, 'name': 'delete', 'help': 'Delete a connection', 'args': ('conn_id',), diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py new file mode 100644 index 0000000000000..d6097ac4d576c --- /dev/null +++ b/airflow/cli/commands/connection_command.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Connection sbu-commands""" +import reprlib +from urllib.parse import urlunparse + +from sqlalchemy.orm import exc +from tabulate import tabulate + +from airflow.models import Connection +from airflow.utils import cli as cli_utils, db +from airflow.utils.cli import alternative_conn_specs + + +def connections_list(args): + """Lists all connections at the command line""" + with db.create_session() as session: + conns = session.query(Connection.conn_id, Connection.conn_type, + Connection.host, Connection.port, + Connection.is_encrypted, + Connection.is_extra_encrypted, + Connection.extra).all() + conns = [map(reprlib.repr, conn) for conn in conns] + msg = tabulate(conns, ['Conn Id', 'Conn Type', 'Host', 'Port', + 'Is Encrypted', 'Is Extra Encrypted', 'Extra'], + tablefmt=args.output) + print(msg) + + +@cli_utils.action_logging +def connections_add(args): + """Adds new connection""" + # Check that the conn_id and conn_uri args were passed to the command: + missing_args = list() + invalid_args = list() + if args.conn_uri: + for arg in alternative_conn_specs: + if getattr(args, arg) is not None: + invalid_args.append(arg) + elif not args.conn_type: + missing_args.append('conn_uri or conn_type') + if missing_args: + msg = ('The following args are required to add a connection:' + + ' {missing!r}'.format(missing=missing_args)) + raise SystemExit(msg) + if invalid_args: + msg = ('The following args are not compatible with the ' + + '--add flag and --conn_uri flag: {invalid!r}') + msg = msg.format(invalid=invalid_args) + raise SystemExit(msg) + + if args.conn_uri: + new_conn = Connection(conn_id=args.conn_id, uri=args.conn_uri) + else: + new_conn = Connection(conn_id=args.conn_id, + conn_type=args.conn_type, + host=args.conn_host, + login=args.conn_login, + password=args.conn_password, + schema=args.conn_schema, + 
port=args.conn_port) + if args.conn_extra is not None: + new_conn.set_extra(args.conn_extra) + + with db.create_session() as session: + if not (session.query(Connection) + .filter(Connection.conn_id == new_conn.conn_id).first()): + session.add(new_conn) + msg = '\n\tSuccessfully added `conn_id`={conn_id} : {uri}\n' + msg = msg.format(conn_id=new_conn.conn_id, + uri=args.conn_uri or + urlunparse((args.conn_type, + '{login}:{password}@{host}:{port}' + .format(login=args.conn_login or '', + password=args.conn_password or '', + host=args.conn_host or '', + port=args.conn_port or ''), + args.conn_schema or '', '', '', ''))) + print(msg) + else: + msg = '\n\tA connection with `conn_id`={conn_id} already exists\n' + msg = msg.format(conn_id=new_conn.conn_id) + print(msg) + + +@cli_utils.action_logging +def connections_delete(args): + """Deletes connection from DB""" + with db.create_session() as session: + try: + to_delete = (session + .query(Connection) + .filter(Connection.conn_id == args.conn_id) + .one()) + except exc.NoResultFound: + msg = '\n\tDid not find a connection with `conn_id`={conn_id}\n' + msg = msg.format(conn_id=args.conn_id) + print(msg) + return + except exc.MultipleResultsFound: + msg = ('\n\tFound more than one connection with ' + + '`conn_id`={conn_id}\n') + msg = msg.format(conn_id=args.conn_id) + print(msg) + return + else: + deleted_conn_id = to_delete.conn_id + session.delete(to_delete) + msg = '\n\tSuccessfully deleted `conn_id`={conn_id}\n' + msg = msg.format(conn_id=deleted_conn_id) + print(msg) diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py index 39d9647d2abcc..f6a34ec4b3bb8 100644 --- a/airflow/utils/cli.py +++ b/airflow/utils/cli.py @@ -150,3 +150,7 @@ def get_dags(args): 'dag_id could not be found with regex: {}. Either the dag did not exist ' 'or it failed to parse.'.format(args.dag_id)) return matched_dags + + +alternative_conn_specs = ['conn_type', 'conn_host', + 'conn_login', 'conn_password', 'conn_schema', 'conn_port'] diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py new file mode 100644 index 0000000000000..75a68a04fdd87 --- /dev/null +++ b/tests/cli/commands/test_connection_command.py @@ -0,0 +1,211 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import io +import re +import subprocess +import tempfile +import unittest +from unittest import mock + +from airflow import settings +from airflow.bin import cli +from airflow.cli.commands import connection_command +from airflow.models import Connection +from airflow.utils import db + + +class TestCliConnections(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.parser = cli.CLIFactory.get_parser() + + def test_cli_connections_list(self): + with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: + connection_command.connections_list(self.parser.parse_args(['connections', 'list'])) + stdout = mock_stdout.getvalue() + conns = [[x.strip("'") for x in re.findall(r"'\w+'", line)[:2]] + for ii, line in enumerate(stdout.split('\n')) + if ii % 2 == 1] + conns = [conn for conn in conns if len(conn) > 0] + + # Assert that some of the connections are present in the output as + # expected: + self.assertIn(['aws_default', 'aws'], conns) + self.assertIn(['hive_cli_default', 'hive_cli'], conns) + self.assertIn(['emr_default', 'emr'], conns) + self.assertIn(['mssql_default', 'mssql'], conns) + self.assertIn(['mysql_default', 'mysql'], conns) + self.assertIn(['postgres_default', 'postgres'], conns) + self.assertIn(['wasb_default', 'wasb'], conns) + self.assertIn(['segment_default', 'segment'], conns) + + def test_cli_connections_list_with_args(self): + args = self.parser.parse_args(['connections', 'list', + '--output', 'tsv']) + connection_command.connections_list(args) + + def test_cli_connections_list_redirect(self): + cmd = ['airflow', 'connections', 'list'] + with tempfile.TemporaryFile() as file: + proc = subprocess.Popen(cmd, stdout=file) + proc.wait() + self.assertEqual(0, proc.returncode) + + def test_cli_connections_add_delete(self): + # TODO: We should not delete the entire database, but only reset the contents of the Connection table. 
+ db.resetdb() + # Add connections: + uri = 'postgresql://airflow:airflow@host:5432/airflow' + with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: + connection_command.connections_add(self.parser.parse_args( + ['connections', 'add', 'new1', + '--conn_uri=%s' % uri])) + connection_command.connections_add(self.parser.parse_args( + ['connections', 'add', 'new2', + '--conn_uri=%s' % uri])) + connection_command.connections_add(self.parser.parse_args( + ['connections', 'add', 'new3', + '--conn_uri=%s' % uri, '--conn_extra', "{'extra': 'yes'}"])) + connection_command.connections_add(self.parser.parse_args( + ['connections', 'add', 'new4', + '--conn_uri=%s' % uri, '--conn_extra', "{'extra': 'yes'}"])) + connection_command.connections_add(self.parser.parse_args( + ['connections', 'add', 'new5', + '--conn_type=hive_metastore', '--conn_login=airflow', + '--conn_password=airflow', '--conn_host=host', + '--conn_port=9083', '--conn_schema=airflow'])) + connection_command.connections_add(self.parser.parse_args( + ['connections', 'add', 'new6', + '--conn_uri', "", '--conn_type=google_cloud_platform', '--conn_extra', "{'extra': 'yes'}"])) + stdout = mock_stdout.getvalue() + + # Check addition stdout + lines = [l for l in stdout.split('\n') if len(l) > 0] + self.assertListEqual(lines, [ + ("\tSuccessfully added `conn_id`=new1 : " + + "postgresql://airflow:airflow@host:5432/airflow"), + ("\tSuccessfully added `conn_id`=new2 : " + + "postgresql://airflow:airflow@host:5432/airflow"), + ("\tSuccessfully added `conn_id`=new3 : " + + "postgresql://airflow:airflow@host:5432/airflow"), + ("\tSuccessfully added `conn_id`=new4 : " + + "postgresql://airflow:airflow@host:5432/airflow"), + ("\tSuccessfully added `conn_id`=new5 : " + + "hive_metastore://airflow:airflow@host:9083/airflow"), + ("\tSuccessfully added `conn_id`=new6 : " + + "google_cloud_platform://:@:") + ]) + + # Attempt to add duplicate + with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: + connection_command.connections_add(self.parser.parse_args( + ['connections', 'add', 'new1', + '--conn_uri=%s' % uri])) + stdout = mock_stdout.getvalue() + + # Check stdout for addition attempt + lines = [l for l in stdout.split('\n') if len(l) > 0] + self.assertListEqual(lines, [ + "\tA connection with `conn_id`=new1 already exists", + ]) + + # Attempt to add without providing conn_uri + with self.assertRaises(SystemExit) as exc: + connection_command.connections_add(self.parser.parse_args( + ['connections', 'add', 'new'])) + + self.assertEqual( + exc.exception.code, + "The following args are required to add a connection: ['conn_uri or conn_type']" + ) + + # Prepare to add connections + session = settings.Session() + extra = {'new1': None, + 'new2': None, + 'new3': "{'extra': 'yes'}", + 'new4': "{'extra': 'yes'}"} + + # Add connections + for index in range(1, 6): + conn_id = 'new%s' % index + result = (session + .query(Connection) + .filter(Connection.conn_id == conn_id) + .first()) + result = (result.conn_id, result.conn_type, result.host, + result.port, result.get_extra()) + if conn_id in ['new1', 'new2', 'new3', 'new4']: + self.assertEqual(result, (conn_id, 'postgres', 'host', 5432, + extra[conn_id])) + elif conn_id == 'new5': + self.assertEqual(result, (conn_id, 'hive_metastore', 'host', + 9083, None)) + elif conn_id == 'new6': + self.assertEqual(result, (conn_id, 'google_cloud_platform', + None, None, "{'extra': 'yes'}")) + + # Delete connections + with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: + 
connection_command.connections_delete(self.parser.parse_args( + ['connections', 'delete', 'new1'])) + connection_command.connections_delete(self.parser.parse_args( + ['connections', 'delete', 'new2'])) + connection_command.connections_delete(self.parser.parse_args( + ['connections', 'delete', 'new3'])) + connection_command.connections_delete(self.parser.parse_args( + ['connections', 'delete', 'new4'])) + connection_command.connections_delete(self.parser.parse_args( + ['connections', 'delete', 'new5'])) + connection_command.connections_delete(self.parser.parse_args( + ['connections', 'delete', 'new6'])) + stdout = mock_stdout.getvalue() + + # Check deletion stdout + lines = [l for l in stdout.split('\n') if len(l) > 0] + self.assertListEqual(lines, [ + "\tSuccessfully deleted `conn_id`=new1", + "\tSuccessfully deleted `conn_id`=new2", + "\tSuccessfully deleted `conn_id`=new3", + "\tSuccessfully deleted `conn_id`=new4", + "\tSuccessfully deleted `conn_id`=new5", + "\tSuccessfully deleted `conn_id`=new6" + ]) + + # Check deletions + for index in range(1, 7): + conn_id = 'new%s' % index + result = (session.query(Connection) + .filter(Connection.conn_id == conn_id) + .first()) + + self.assertTrue(result is None) + + # Attempt to delete a non-existing connection + with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: + connection_command.connections_delete(self.parser.parse_args( + ['connections', 'delete', 'fake'])) + stdout = mock_stdout.getvalue() + + # Check deletion attempt stdout + lines = [l for l in stdout.split('\n') if len(l) > 0] + self.assertListEqual(lines, [ + "\tDid not find a connection with `conn_id`=fake", + ]) + + session.close() diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index fbe36a783876a..e3f123c0e4643 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -19,7 +19,6 @@ import contextlib import io import os -import re import subprocess import tempfile import unittest @@ -34,9 +33,9 @@ import airflow.bin.cli as cli from airflow import AirflowException, models, settings from airflow.bin.cli import get_num_ready_workers_running -from airflow.models import Connection, DagModel +from airflow.models import DagModel from airflow.settings import Session -from airflow.utils import db, timezone +from airflow.utils import timezone from airflow.utils.state import State from airflow.version import version from tests import conf_vars @@ -593,188 +592,6 @@ def test_cli_webserver_shutdown_when_gunicorn_master_is_killed(self, _): self.assertEqual(e.exception.code, 1) -class TestCliConnections(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.parser = cli.CLIFactory.get_parser() - - def test_cli_connections_list(self): - with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: - cli.connections_list(self.parser.parse_args(['connections', 'list'])) - stdout = mock_stdout.getvalue() - conns = [[x.strip("'") for x in re.findall(r"'\w+'", line)[:2]] - for ii, line in enumerate(stdout.split('\n')) - if ii % 2 == 1] - conns = [conn for conn in conns if len(conn) > 0] - - # Assert that some of the connections are present in the output as - # expected: - self.assertIn(['aws_default', 'aws'], conns) - self.assertIn(['hive_cli_default', 'hive_cli'], conns) - self.assertIn(['emr_default', 'emr'], conns) - self.assertIn(['mssql_default', 'mssql'], conns) - self.assertIn(['mysql_default', 'mysql'], conns) - self.assertIn(['postgres_default', 'postgres'], conns) - self.assertIn(['wasb_default', 'wasb'], conns) - 
self.assertIn(['segment_default', 'segment'], conns) - - def test_cli_connections_list_with_args(self): - args = self.parser.parse_args(['connections', 'list', - '--output', 'tsv']) - cli.connections_list(args) - - def test_cli_connections_list_redirect(self): - cmd = ['airflow', 'connections', 'list'] - with tempfile.TemporaryFile() as file: - proc = subprocess.Popen(cmd, stdout=file) - proc.wait() - self.assertEqual(0, proc.returncode) - - def test_cli_connections_add_delete(self): - # TODO: We should not delete the entire database, but only reset the contents of the Connection table. - db.resetdb() - # Add connections: - uri = 'postgresql://airflow:airflow@host:5432/airflow' - with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: - cli.connections_add(self.parser.parse_args( - ['connections', 'add', 'new1', - '--conn_uri=%s' % uri])) - cli.connections_add(self.parser.parse_args( - ['connections', 'add', 'new2', - '--conn_uri=%s' % uri])) - cli.connections_add(self.parser.parse_args( - ['connections', 'add', 'new3', - '--conn_uri=%s' % uri, '--conn_extra', "{'extra': 'yes'}"])) - cli.connections_add(self.parser.parse_args( - ['connections', 'add', 'new4', - '--conn_uri=%s' % uri, '--conn_extra', "{'extra': 'yes'}"])) - cli.connections_add(self.parser.parse_args( - ['connections', 'add', 'new5', - '--conn_type=hive_metastore', '--conn_login=airflow', - '--conn_password=airflow', '--conn_host=host', - '--conn_port=9083', '--conn_schema=airflow'])) - cli.connections_add(self.parser.parse_args( - ['connections', 'add', 'new6', - '--conn_uri', "", '--conn_type=google_cloud_platform', '--conn_extra', "{'extra': 'yes'}"])) - stdout = mock_stdout.getvalue() - - # Check addition stdout - lines = [l for l in stdout.split('\n') if len(l) > 0] - self.assertListEqual(lines, [ - ("\tSuccessfully added `conn_id`=new1 : " + - "postgresql://airflow:airflow@host:5432/airflow"), - ("\tSuccessfully added `conn_id`=new2 : " + - "postgresql://airflow:airflow@host:5432/airflow"), - ("\tSuccessfully added `conn_id`=new3 : " + - "postgresql://airflow:airflow@host:5432/airflow"), - ("\tSuccessfully added `conn_id`=new4 : " + - "postgresql://airflow:airflow@host:5432/airflow"), - ("\tSuccessfully added `conn_id`=new5 : " + - "hive_metastore://airflow:airflow@host:9083/airflow"), - ("\tSuccessfully added `conn_id`=new6 : " + - "google_cloud_platform://:@:") - ]) - - # Attempt to add duplicate - with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: - cli.connections_add(self.parser.parse_args( - ['connections', 'add', 'new1', - '--conn_uri=%s' % uri])) - stdout = mock_stdout.getvalue() - - # Check stdout for addition attempt - lines = [l for l in stdout.split('\n') if len(l) > 0] - self.assertListEqual(lines, [ - "\tA connection with `conn_id`=new1 already exists", - ]) - - # Attempt to add without providing conn_uri - with self.assertRaises(SystemExit) as exc: - cli.connections_add(self.parser.parse_args( - ['connections', 'add', 'new'])) - - self.assertEqual( - exc.exception.code, - "The following args are required to add a connection: ['conn_uri or conn_type']" - ) - - # Prepare to add connections - session = settings.Session() - extra = {'new1': None, - 'new2': None, - 'new3': "{'extra': 'yes'}", - 'new4': "{'extra': 'yes'}"} - - # Add connections - for index in range(1, 6): - conn_id = 'new%s' % index - result = (session - .query(Connection) - .filter(Connection.conn_id == conn_id) - .first()) - result = (result.conn_id, result.conn_type, result.host, - result.port, 
result.get_extra()) - if conn_id in ['new1', 'new2', 'new3', 'new4']: - self.assertEqual(result, (conn_id, 'postgres', 'host', 5432, - extra[conn_id])) - elif conn_id == 'new5': - self.assertEqual(result, (conn_id, 'hive_metastore', 'host', - 9083, None)) - elif conn_id == 'new6': - self.assertEqual(result, (conn_id, 'google_cloud_platform', - None, None, "{'extra': 'yes'}")) - - # Delete connections - with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: - cli.connections_delete(self.parser.parse_args( - ['connections', 'delete', 'new1'])) - cli.connections_delete(self.parser.parse_args( - ['connections', 'delete', 'new2'])) - cli.connections_delete(self.parser.parse_args( - ['connections', 'delete', 'new3'])) - cli.connections_delete(self.parser.parse_args( - ['connections', 'delete', 'new4'])) - cli.connections_delete(self.parser.parse_args( - ['connections', 'delete', 'new5'])) - cli.connections_delete(self.parser.parse_args( - ['connections', 'delete', 'new6'])) - stdout = mock_stdout.getvalue() - - # Check deletion stdout - lines = [l for l in stdout.split('\n') if len(l) > 0] - self.assertListEqual(lines, [ - "\tSuccessfully deleted `conn_id`=new1", - "\tSuccessfully deleted `conn_id`=new2", - "\tSuccessfully deleted `conn_id`=new3", - "\tSuccessfully deleted `conn_id`=new4", - "\tSuccessfully deleted `conn_id`=new5", - "\tSuccessfully deleted `conn_id`=new6" - ]) - - # Check deletions - for index in range(1, 7): - conn_id = 'new%s' % index - result = (session.query(Connection) - .filter(Connection.conn_id == conn_id) - .first()) - - self.assertTrue(result is None) - - # Attempt to delete a non-existing connection - with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: - cli.connections_delete(self.parser.parse_args( - ['connections', 'delete', 'fake'])) - stdout = mock_stdout.getvalue() - - # Check deletion attempt stdout - lines = [l for l in stdout.split('\n') if len(l) > 0] - self.assertListEqual(lines, [ - "\tDid not find a connection with `conn_id`=fake", - ]) - - session.close() - - class TestCliVersion(unittest.TestCase): @classmethod def setUpClass(cls): From a3620b119b94fe2ef6eb2a2717a4aebb631f27cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Sun, 17 Nov 2019 06:40:11 +0100 Subject: [PATCH 12/23] [AIRLFOW-YYY] Move version command to separate file --- airflow/bin/cli.py | 11 ++----- airflow/cli/commands/version_command.py | 26 ++++++++++++++++ tests/cli/commands/test_version_command.py | 36 ++++++++++++++++++++++ tests/cli/test_cli.py | 13 -------- 4 files changed, 64 insertions(+), 22 deletions(-) create mode 100644 airflow/cli/commands/version_command.py create mode 100644 tests/cli/commands/test_version_command.py diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 550d222d10965..0bdb4f6485c1d 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -38,12 +38,11 @@ from daemon.pidfile import TimeoutPIDLockFile from tabulate import tabulate, tabulate_formats -import airflow from airflow import api, jobs, settings from airflow.api.client import get_current_api_client from airflow.cli.commands import ( connection_command, db_command, pool_command, role_command, rotate_fernet_key_command, sync_perm_command, - task_command, user_command, variable_command, + task_command, user_command, variable_command, version_command, ) from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowWebServerTimeout @@ -747,12 +746,6 @@ def worker(args): sub_proc.kill() 
-@cli_utils.action_logging -def version(args): - """Displays Airflow version at the command line""" - print(settings.HEADER + " v" + airflow.__version__) - - @cli_utils.action_logging def flower(args): """Starts Flower, Celery monitoring tool""" @@ -1694,7 +1687,7 @@ class CLIFactory: 'args': ('flower_hostname', 'flower_port', 'flower_conf', 'flower_url_prefix', 'flower_basic_auth', 'broker_api', 'pid', 'daemon', 'stdout', 'stderr', 'log_file'), }, { - 'func': version, + 'func': version_command.version, 'help': "Show the version", 'args': tuple(), }, { diff --git a/airflow/cli/commands/version_command.py b/airflow/cli/commands/version_command.py new file mode 100644 index 0000000000000..b55d4e5ffe3df --- /dev/null +++ b/airflow/cli/commands/version_command.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Version command""" +import airflow +from airflow import settings +from airflow.utils import cli as cli_utils + + +@cli_utils.action_logging +def version(args): + """Displays Airflow version at the command line""" + print(settings.HEADER + " v" + airflow.__version__) diff --git a/tests/cli/commands/test_version_command.py b/tests/cli/commands/test_version_command.py new file mode 100644 index 0000000000000..950b4b5814398 --- /dev/null +++ b/tests/cli/commands/test_version_command.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
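The tests that follow assert on a command's printed output by swapping sys.stdout for an in-memory buffer. A minimal sketch of that capture pattern, with a plain print standing in for the command under test (the version string here is made up for illustration):

    import io
    from unittest import mock

    # Replace sys.stdout with a StringIO for the duration of the block,
    # then inspect everything the code printed.
    with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
        print("Airflow v1.10.0")  # stand-in for version_command.version(args)
    assert "1.10.0" in mock_stdout.getvalue()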
+ +import io +import unittest +from unittest import mock + +import airflow.cli.commands.version_command +from airflow.bin import cli +from airflow.version import version + + +class TestCliVersion(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.parser = cli.CLIFactory.get_parser() + + def test_cli_version(self): + with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: + airflow.cli.commands.version_command.version(self.parser.parse_args(['version'])) + stdout = mock_stdout.getvalue() + self.assertIn(version, stdout) diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index e3f123c0e4643..786786ab2fddc 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -37,7 +37,6 @@ from airflow.settings import Session from airflow.utils import timezone from airflow.utils.state import State -from airflow.version import version from tests import conf_vars from tests.compat import mock @@ -590,15 +589,3 @@ def test_cli_webserver_shutdown_when_gunicorn_master_is_killed(self, _): with self.assertRaises(SystemExit) as e: cli.webserver(args) self.assertEqual(e.exception.code, 1) - - -class TestCliVersion(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.parser = cli.CLIFactory.get_parser() - - def test_cli_version(self): - with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: - cli.version(self.parser.parse_args(['version'])) - stdout = mock_stdout.getvalue() - self.assertIn(version, stdout) From 0b830994a25d8eb3b79bdced0c5aaedf6796a735 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Sun, 17 Nov 2019 06:54:03 +0100 Subject: [PATCH 13/23] [AIRLFOW-YYY] Move scheduler command to separate file --- airflow/bin/cli.py | 103 ++-------------------- airflow/cli/commands/scheduler_command.py | 64 ++++++++++++++ airflow/utils/cli.py | 56 ++++++++++++ tests/cli/test_cli.py | 9 +- 4 files changed, 131 insertions(+), 101 deletions(-) create mode 100644 airflow/cli/commands/scheduler_command.py diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 0bdb4f6485c1d..604adbe6e07f6 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -28,9 +28,7 @@ import subprocess import sys import textwrap -import threading import time -import traceback from argparse import RawTextHelpFormatter import daemon @@ -41,14 +39,16 @@ from airflow import api, jobs, settings from airflow.api.client import get_current_api_client from airflow.cli.commands import ( - connection_command, db_command, pool_command, role_command, rotate_fernet_key_command, sync_perm_command, - task_command, user_command, variable_command, version_command, + connection_command, db_command, pool_command, role_command, rotate_fernet_key_command, scheduler_command, + sync_perm_command, task_command, user_command, variable_command, version_command, ) from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowWebServerTimeout from airflow.models import DAG, DagBag, DagModel, DagRun, TaskInstance from airflow.utils import cli as cli_utils, db -from airflow.utils.cli import alternative_conn_specs, get_dag, process_subdir +from airflow.utils.cli import ( + alternative_conn_specs, get_dag, process_subdir, setup_locations, setup_logging, sigint_handler, +) from airflow.utils.dot_renderer import render_dag from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.timezone import parse as parsedate @@ -64,59 +64,6 @@ DAGS_FOLDER = '[AIRFLOW_HOME]/dags' -def sigint_handler(sig, frame): # pylint: disable=unused-argument 
- """ - Returns without error on SIGINT or SIGTERM signals in interactive command mode - e.g. CTRL+C or kill - """ - sys.exit(0) - - -def sigquit_handler(sig, frame): # pylint: disable=unused-argument - """ - Helps debug deadlocks by printing stacktraces when this gets a SIGQUIT - e.g. kill -s QUIT or CTRL+\ - """ - print("Dumping stack traces for all threads in PID {}".format(os.getpid())) - id_to_name = {th.ident: th.name for th in threading.enumerate()} - code = [] - for thread_id, stack in sys._current_frames().items(): # pylint: disable=protected-access - code.append("\n# Thread: {}({})" - .format(id_to_name.get(thread_id, ""), thread_id)) - for filename, line_number, name, line in traceback.extract_stack(stack): - code.append('File: "{}", line {}, in {}' - .format(filename, line_number, name)) - if line: - code.append(" {}".format(line.strip())) - print("\n".join(code)) - - -def setup_logging(filename): - """Creates log file handler for daemon process""" - root = logging.getLogger() - handler = logging.FileHandler(filename) - formatter = logging.Formatter(settings.SIMPLE_LOG_FORMAT) - handler.setFormatter(formatter) - root.addHandler(handler) - root.setLevel(settings.LOGGING_LEVEL) - - return handler.stream - - -def setup_locations(process, pid=None, stdout=None, stderr=None, log=None): - """Creates logging paths""" - if not stderr: - stderr = os.path.join(settings.AIRFLOW_HOME, 'airflow-{}.err'.format(process)) - if not stdout: - stdout = os.path.join(settings.AIRFLOW_HOME, 'airflow-{}.out'.format(process)) - if not log: - log = os.path.join(settings.AIRFLOW_HOME, 'airflow-{}.log'.format(process)) - if not pid: - pid = os.path.join(settings.AIRFLOW_HOME, 'airflow-{}.pid'.format(process)) - - return pid, stdout, stderr, log - - @cli_utils.action_logging def dag_backfill(args, dag=None): """Creates backfill job or dry run for a DAG""" @@ -623,44 +570,6 @@ def monitor_gunicorn(gunicorn_master_proc): monitor_gunicorn(gunicorn_master_proc) -@cli_utils.action_logging -def scheduler(args): - """Starts Airflow Scheduler""" - print(settings.HEADER) - job = jobs.SchedulerJob( - dag_id=args.dag_id, - subdir=process_subdir(args.subdir), - num_runs=args.num_runs, - do_pickle=args.do_pickle) - - if args.daemon: - pid, stdout, stderr, log_file = setup_locations("scheduler", - args.pid, - args.stdout, - args.stderr, - args.log_file) - handle = setup_logging(log_file) - stdout = open(stdout, 'w+') - stderr = open(stderr, 'w+') - - ctx = daemon.DaemonContext( - pidfile=TimeoutPIDLockFile(pid, -1), - files_preserve=[handle], - stdout=stdout, - stderr=stderr, - ) - with ctx: - job.run() - - stdout.close() - stderr.close() - else: - signal.signal(signal.SIGINT, sigint_handler) - signal.signal(signal.SIGTERM, sigint_handler) - signal.signal(signal.SIGQUIT, sigquit_handler) - job.run() - - @cli_utils.action_logging def serve_logs(args): """Serves logs generated by Worker""" @@ -1671,7 +1580,7 @@ class CLIFactory: 'pid', 'daemon', 'stdout', 'stderr', 'access_logfile', 'error_logfile', 'log_file', 'ssl_cert', 'ssl_key', 'debug'), }, { - 'func': scheduler, + 'func': scheduler_command.scheduler, 'help': "Start a scheduler instance", 'args': ('dag_id_opt', 'subdir', 'num_runs', 'do_pickle', 'pid', 'daemon', 'stdout', 'stderr', diff --git a/airflow/cli/commands/scheduler_command.py b/airflow/cli/commands/scheduler_command.py new file mode 100644 index 0000000000000..0ae49857576fa --- /dev/null +++ b/airflow/cli/commands/scheduler_command.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) 
under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Scheduler command""" +import signal + +import daemon +from daemon.pidfile import TimeoutPIDLockFile + +from airflow import jobs, settings +from airflow.utils import cli as cli_utils +from airflow.utils.cli import process_subdir, setup_locations, setup_logging, sigint_handler, sigquit_handler + + +@cli_utils.action_logging +def scheduler(args): + """Starts Airflow Scheduler""" + print(settings.HEADER) + job = jobs.SchedulerJob( + dag_id=args.dag_id, + subdir=process_subdir(args.subdir), + num_runs=args.num_runs, + do_pickle=args.do_pickle) + + if args.daemon: + pid, stdout, stderr, log_file = setup_locations("scheduler", + args.pid, + args.stdout, + args.stderr, + args.log_file) + handle = setup_logging(log_file) + stdout = open(stdout, 'w+') + stderr = open(stderr, 'w+') + + ctx = daemon.DaemonContext( + pidfile=TimeoutPIDLockFile(pid, -1), + files_preserve=[handle], + stdout=stdout, + stderr=stderr, + ) + with ctx: + job.run() + + stdout.close() + stderr.close() + else: + signal.signal(signal.SIGINT, sigint_handler) + signal.signal(signal.SIGTERM, sigint_handler) + signal.signal(signal.SIGQUIT, sigquit_handler) + job.run() diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py index f6a34ec4b3bb8..fb5aae6c9508b 100644 --- a/airflow/utils/cli.py +++ b/airflow/utils/cli.py @@ -24,10 +24,13 @@ import functools import getpass import json +import logging import os import re import socket import sys +import threading +import traceback from argparse import Namespace from datetime import datetime @@ -154,3 +157,56 @@ def get_dags(args): alternative_conn_specs = ['conn_type', 'conn_host', 'conn_login', 'conn_password', 'conn_schema', 'conn_port'] + + +def setup_locations(process, pid=None, stdout=None, stderr=None, log=None): + """Creates logging paths""" + if not stderr: + stderr = os.path.join(settings.AIRFLOW_HOME, 'airflow-{}.err'.format(process)) + if not stdout: + stdout = os.path.join(settings.AIRFLOW_HOME, 'airflow-{}.out'.format(process)) + if not log: + log = os.path.join(settings.AIRFLOW_HOME, 'airflow-{}.log'.format(process)) + if not pid: + pid = os.path.join(settings.AIRFLOW_HOME, 'airflow-{}.pid'.format(process)) + + return pid, stdout, stderr, log + + +def setup_logging(filename): + """Creates log file handler for daemon process""" + root = logging.getLogger() + handler = logging.FileHandler(filename) + formatter = logging.Formatter(settings.SIMPLE_LOG_FORMAT) + handler.setFormatter(formatter) + root.addHandler(handler) + root.setLevel(settings.LOGGING_LEVEL) + + return handler.stream + + +def sigint_handler(sig, frame): # pylint: disable=unused-argument + """ + Returns without error on SIGINT or SIGTERM signals in interactive command mode + e.g. 
CTRL+C or kill + """ + sys.exit(0) + + +def sigquit_handler(sig, frame): # pylint: disable=unused-argument + """ + Helps debug deadlocks by printing stacktraces when this gets a SIGQUIT + e.g. kill -s QUIT or CTRL+\ + """ + print("Dumping stack traces for all threads in PID {}".format(os.getpid())) + id_to_name = {th.ident: th.name for th in threading.enumerate()} + code = [] + for thread_id, stack in sys._current_frames().items(): # pylint: disable=protected-access + code.append("\n# Thread: {}({})" + .format(id_to_name.get(thread_id, ""), thread_id)) + for filename, line_number, name, line in traceback.extract_stack(stack): + code.append('File: "{}", line {}, in {}' + .format(filename, line_number, name)) + if line: + code.append(" {}".format(line.strip())) + print("\n".join(code)) diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 786786ab2fddc..8b86aa837d60b 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -36,6 +36,7 @@ from airflow.models import DagModel from airflow.settings import Session from airflow.utils import timezone +from airflow.utils.cli import setup_locations from airflow.utils.state import State from tests import conf_vars from tests.compat import mock @@ -524,8 +525,8 @@ def tearDown(self) -> None: self._check_processes() def _clean_pidfiles(self): - pidfile_webserver = cli.setup_locations("webserver")[0] - pidfile_monitor = cli.setup_locations("webserver-monitor")[0] + pidfile_webserver = setup_locations("webserver")[0] + pidfile_monitor = setup_locations("webserver-monitor")[0] if os.path.exists(pidfile_webserver): os.remove(pidfile_webserver) if os.path.exists(pidfile_monitor): @@ -562,8 +563,8 @@ def test_cli_webserver_foreground_with_pid(self): @unittest.skipIf("TRAVIS" in os.environ and bool(os.environ["TRAVIS"]), "Skipping test due to lack of required file permission") def test_cli_webserver_background(self): - pidfile_webserver = cli.setup_locations("webserver")[0] - pidfile_monitor = cli.setup_locations("webserver-monitor")[0] + pidfile_webserver = setup_locations("webserver")[0] + pidfile_monitor = setup_locations("webserver-monitor")[0] # Run webserver as daemon in background. Note that the wait method is not called. 
subprocess.Popen(["airflow", "webserver", "-D"]) From bca908820e860059e3610de6d233a53099304829 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Sun, 17 Nov 2019 07:01:46 +0100 Subject: [PATCH 14/23] [AIRLFOW-YYY] Move worker command to separate file --- airflow/bin/cli.py | 69 +------------- airflow/cli/commands/worker_command.py | 93 +++++++++++++++++++ .../test_worker_command.py} | 5 +- 3 files changed, 97 insertions(+), 70 deletions(-) create mode 100644 airflow/cli/commands/worker_command.py rename tests/cli/{test_worker_initialisation.py => commands/test_worker_command.py} (96%) diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 604adbe6e07f6..44d532e34a8da 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -40,7 +40,7 @@ from airflow.api.client import get_current_api_client from airflow.cli.commands import ( connection_command, db_command, pool_command, role_command, rotate_fernet_key_command, scheduler_command, - sync_perm_command, task_command, user_command, variable_command, version_command, + sync_perm_command, task_command, user_command, variable_command, version_command, worker_command, ) from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowWebServerTimeout @@ -590,71 +590,6 @@ def serve_logs(filename): # pylint: disable=unused-variable, redefined-outer-na flask_app.run(host='0.0.0.0', port=worker_log_server_port) -@cli_utils.action_logging -def worker(args): - """Starts Airflow Celery worker""" - env = os.environ.copy() - env['AIRFLOW_HOME'] = settings.AIRFLOW_HOME - - if not settings.validate_session(): - log = LoggingMixin().log - log.error("Worker exiting... database connection precheck failed! ") - sys.exit(1) - - # Celery worker - from airflow.executors.celery_executor import app as celery_app - from celery.bin import worker # pylint: disable=redefined-outer-name - - autoscale = args.autoscale - if autoscale is None and conf.has_option("celery", "worker_autoscale"): - autoscale = conf.get("celery", "worker_autoscale") - worker = worker.worker(app=celery_app) # pylint: disable=redefined-outer-name - options = { - 'optimization': 'fair', - 'O': 'fair', - 'queues': args.queues, - 'concurrency': args.concurrency, - 'autoscale': autoscale, - 'hostname': args.celery_hostname, - 'loglevel': conf.get('core', 'LOGGING_LEVEL'), - } - - if conf.has_option("celery", "pool"): - options["pool"] = conf.get("celery", "pool") - - if args.daemon: - pid, stdout, stderr, log_file = setup_locations("worker", - args.pid, - args.stdout, - args.stderr, - args.log_file) - handle = setup_logging(log_file) - stdout = open(stdout, 'w+') - stderr = open(stderr, 'w+') - - ctx = daemon.DaemonContext( - pidfile=TimeoutPIDLockFile(pid, -1), - files_preserve=[handle], - stdout=stdout, - stderr=stderr, - ) - with ctx: - sub_proc = subprocess.Popen(['airflow', 'serve_logs'], env=env, close_fds=True) - worker.run(**options) - sub_proc.kill() - - stdout.close() - stderr.close() - else: - signal.signal(signal.SIGINT, sigint_handler) - signal.signal(signal.SIGTERM, sigint_handler) - - sub_proc = subprocess.Popen(['airflow', 'serve_logs'], env=env, close_fds=True) - - worker.run(**options) - sub_proc.kill() - - @cli_utils.action_logging def flower(args): """Starts Flower, Celery monitoring tool""" @@ -1586,7 +1521,7 @@ class CLIFactory: 'do_pickle', 'pid', 'daemon', 'stdout', 'stderr', 'log_file'), }, { - 'func': worker, + 'func': worker_command.worker, 'help': "Start a Celery worker node", 'args': ('do_pickle', 'queues', 
'concurrency', 'celery_hostname', 'pid', 'daemon', 'stdout', 'stderr', 'log_file', 'autoscale'), diff --git a/airflow/cli/commands/worker_command.py b/airflow/cli/commands/worker_command.py new file mode 100644 index 0000000000000..9958034f6f598 --- /dev/null +++ b/airflow/cli/commands/worker_command.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Worker command""" +import os +import signal +import subprocess +import sys + +import daemon +from daemon.pidfile import TimeoutPIDLockFile + +from airflow import LoggingMixin, conf, settings +from airflow.utils import cli as cli_utils +from airflow.utils.cli import setup_locations, setup_logging, sigint_handler + + +@cli_utils.action_logging +def worker(args): + """Starts Airflow Celery worker""" + env = os.environ.copy() + env['AIRFLOW_HOME'] = settings.AIRFLOW_HOME + + if not settings.validate_session(): + log = LoggingMixin().log + log.error("Worker exiting... database connection precheck failed! ") + sys.exit(1) + + # Celery worker + from airflow.executors.celery_executor import app as celery_app + from celery.bin import worker # pylint: disable=redefined-outer-name + + autoscale = args.autoscale + if autoscale is None and conf.has_option("celery", "worker_autoscale"): + autoscale = conf.get("celery", "worker_autoscale") + worker = worker.worker(app=celery_app) # pylint: disable=redefined-outer-name + options = { + 'optimization': 'fair', + 'O': 'fair', + 'queues': args.queues, + 'concurrency': args.concurrency, + 'autoscale': autoscale, + 'hostname': args.celery_hostname, + 'loglevel': conf.get('core', 'LOGGING_LEVEL'), + } + + if conf.has_option("celery", "pool"): + options["pool"] = conf.get("celery", "pool") + + if args.daemon: + pid, stdout, stderr, log_file = setup_locations("worker", + args.pid, + args.stdout, + args.stderr, + args.log_file) + handle = setup_logging(log_file) + stdout = open(stdout, 'w+') + stderr = open(stderr, 'w+') + + ctx = daemon.DaemonContext( + pidfile=TimeoutPIDLockFile(pid, -1), + files_preserve=[handle], + stdout=stdout, + stderr=stderr, + ) + with ctx: + sub_proc = subprocess.Popen(['airflow', 'serve_logs'], env=env, close_fds=True) + worker.run(**options) + sub_proc.kill() + + stdout.close() + stderr.close() + else: + signal.signal(signal.SIGINT, sigint_handler) + signal.signal(signal.SIGTERM, sigint_handler) + + sub_proc = subprocess.Popen(['airflow', 'serve_logs'], env=env, close_fds=True) + + worker.run(**options) + sub_proc.kill() diff --git a/tests/cli/test_worker_initialisation.py b/tests/cli/commands/test_worker_command.py similarity index 96% rename from tests/cli/test_worker_initialisation.py rename to tests/cli/commands/test_worker_command.py index 0415081ac2af8..34de016e9bbae 100644 --- a/tests/cli/test_worker_initialisation.py +++ 
b/tests/cli/commands/test_worker_command.py @@ -16,14 +16,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - import unittest from argparse import Namespace import sqlalchemy import airflow -from airflow.bin import cli # noqa +from airflow.cli.commands import worker_command from tests.compat import mock, patch from tests.test_utils.config import conf_vars @@ -41,7 +40,7 @@ def test_error(self, mock_validate_session): mock_validate_session.return_value = False with self.assertRaises(SystemExit) as cm: # airflow.bin.cli.worker(mock_args) - cli.worker(mock_args) + worker_command.worker(mock_args) self.assertEqual(cm.exception.code, 1) @conf_vars({('core', 'worker_precheck'): 'False'}) From 4b95f5c5afb82f51debbe470cf618687f1478e4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Sun, 17 Nov 2019 07:11:00 +0100 Subject: [PATCH 15/23] [AIRLFOW-YYY] Move webserver command to separate file --- airflow/bin/cli.py | 283 +---------------- airflow/cli/commands/webserver_command.py | 302 +++++++++++++++++++ tests/cli/commands/test_webserver_command.py | 175 +++++++++++ tests/cli/test_cli.py | 150 +-------- 4 files changed, 483 insertions(+), 427 deletions(-) create mode 100644 airflow/cli/commands/webserver_command.py create mode 100644 tests/cli/commands/test_webserver_command.py diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 44d532e34a8da..3ab2d20a0c420 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -26,13 +26,10 @@ import os import signal import subprocess -import sys import textwrap -import time from argparse import RawTextHelpFormatter import daemon -import psutil from daemon.pidfile import TimeoutPIDLockFile from tabulate import tabulate, tabulate_formats @@ -40,24 +37,20 @@ from airflow.api.client import get_current_api_client from airflow.cli.commands import ( connection_command, db_command, pool_command, role_command, rotate_fernet_key_command, scheduler_command, - sync_perm_command, task_command, user_command, variable_command, version_command, worker_command, + sync_perm_command, task_command, user_command, variable_command, version_command, webserver_command, + worker_command, ) from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowWebServerTimeout +from airflow.exceptions import AirflowException from airflow.models import DAG, DagBag, DagModel, DagRun, TaskInstance from airflow.utils import cli as cli_utils, db -from airflow.utils.cli import ( - alternative_conn_specs, get_dag, process_subdir, setup_locations, setup_logging, sigint_handler, -) +from airflow.utils.cli import alternative_conn_specs, get_dag, process_subdir, setup_locations, sigint_handler from airflow.utils.dot_renderer import render_dag from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.timezone import parse as parsedate -from airflow.www.app import cached_app, create_app api.load_auth() -LOG = LoggingMixin().log - DAGS_FOLDER = settings.DAGS_FOLDER if "BUILDING_AIRFLOW_DOCS" in os.environ: @@ -304,272 +297,6 @@ def dag_list_jobs(args, dag=None): print(msg) -def get_num_ready_workers_running(gunicorn_master_proc): - """Returns number of ready Gunicorn workers by looking for READY_PREFIX in process name""" - workers = psutil.Process(gunicorn_master_proc.pid).children() - - def ready_prefix_on_cmdline(proc): - try: - cmdline = proc.cmdline() - if len(cmdline) > 0: # pylint: disable=len-as-condition - return 
settings.GUNICORN_WORKER_READY_PREFIX in cmdline[0] - except psutil.NoSuchProcess: - pass - return False - - ready_workers = [proc for proc in workers if ready_prefix_on_cmdline(proc)] - return len(ready_workers) - - -def get_num_workers_running(gunicorn_master_proc): - """Returns number of running Gunicorn workers processes""" - workers = psutil.Process(gunicorn_master_proc.pid).children() - return len(workers) - - -def restart_workers(gunicorn_master_proc, num_workers_expected, master_timeout): - """ - Runs forever, monitoring the child processes of @gunicorn_master_proc and - restarting workers occasionally. - Each iteration of the loop traverses one edge of this state transition - diagram, where each state (node) represents - [ num_ready_workers_running / num_workers_running ]. We expect most time to - be spent in [n / n]. `bs` is the setting webserver.worker_refresh_batch_size. - The horizontal transition at ? happens after the new worker parses all the - dags (so it could take a while!) - V ────────────────────────────────────────────────────────────────────────┐ - [n / n] ──TTIN──> [ [n, n+bs) / n + bs ] ────?───> [n + bs / n + bs] ──TTOU─┘ - ^ ^───────────────┘ - │ - │ ┌────────────────v - └──────┴────── [ [0, n) / n ] <─── start - We change the number of workers by sending TTIN and TTOU to the gunicorn - master process, which increases and decreases the number of child workers - respectively. Gunicorn guarantees that on TTOU workers are terminated - gracefully and that the oldest worker is terminated. - """ - - def wait_until_true(fn, timeout=0): - """ - Sleeps until fn is true - """ - start_time = time.time() - while not fn(): - if 0 < timeout <= time.time() - start_time: - raise AirflowWebServerTimeout( - "No response from gunicorn master within {0} seconds" - .format(timeout)) - time.sleep(0.1) - - def start_refresh(gunicorn_master_proc): - batch_size = conf.getint('webserver', 'worker_refresh_batch_size') - LOG.debug('%s doing a refresh of %s workers', state, batch_size) - sys.stdout.flush() - sys.stderr.flush() - - excess = 0 - for _ in range(batch_size): - gunicorn_master_proc.send_signal(signal.SIGTTIN) - excess += 1 - wait_until_true(lambda: num_workers_expected + excess == - get_num_workers_running(gunicorn_master_proc), - master_timeout) - - try: # pylint: disable=too-many-nested-blocks - wait_until_true(lambda: num_workers_expected == - get_num_workers_running(gunicorn_master_proc), - master_timeout) - while True: - num_workers_running = get_num_workers_running(gunicorn_master_proc) - num_ready_workers_running = \ - get_num_ready_workers_running(gunicorn_master_proc) - - state = '[{0} / {1}]'.format(num_ready_workers_running, num_workers_running) - - # Whenever some workers are not ready, wait until all workers are ready - if num_ready_workers_running < num_workers_running: - LOG.debug('%s some workers are starting up, waiting...', state) - sys.stdout.flush() - time.sleep(1) - - # Kill a worker gracefully by asking gunicorn to reduce number of workers - elif num_workers_running > num_workers_expected: - excess = num_workers_running - num_workers_expected - LOG.debug('%s killing %s workers', state, excess) - - for _ in range(excess): - gunicorn_master_proc.send_signal(signal.SIGTTOU) - excess -= 1 - wait_until_true(lambda: num_workers_expected + excess == - get_num_workers_running(gunicorn_master_proc), - master_timeout) - - # Start a new worker by asking gunicorn to increase number of workers - elif num_workers_running == num_workers_expected: - refresh_interval = 
conf.getint('webserver', 'worker_refresh_interval') - LOG.debug( - '%s sleeping for %ss starting doing a refresh...', - state, refresh_interval - ) - time.sleep(refresh_interval) - start_refresh(gunicorn_master_proc) - - else: - # num_ready_workers_running == num_workers_running < num_workers_expected - LOG.error(( - "%s some workers seem to have died and gunicorn" - "did not restart them as expected" - ), state) - time.sleep(10) - if len( - psutil.Process(gunicorn_master_proc.pid).children() - ) < num_workers_expected: - start_refresh(gunicorn_master_proc) - except (AirflowWebServerTimeout, OSError) as err: - LOG.error(err) - LOG.error("Shutting down webserver") - try: - gunicorn_master_proc.terminate() - gunicorn_master_proc.wait() - finally: - sys.exit(1) - - -@cli_utils.action_logging -def webserver(args): - """Starts Airflow Webserver""" - print(settings.HEADER) - - access_logfile = args.access_logfile or conf.get('webserver', 'access_logfile') - error_logfile = args.error_logfile or conf.get('webserver', 'error_logfile') - num_workers = args.workers or conf.get('webserver', 'workers') - worker_timeout = (args.worker_timeout or - conf.get('webserver', 'web_server_worker_timeout')) - ssl_cert = args.ssl_cert or conf.get('webserver', 'web_server_ssl_cert') - ssl_key = args.ssl_key or conf.get('webserver', 'web_server_ssl_key') - if not ssl_cert and ssl_key: - raise AirflowException( - 'An SSL certificate must also be provided for use with ' + ssl_key) - if ssl_cert and not ssl_key: - raise AirflowException( - 'An SSL key must also be provided for use with ' + ssl_cert) - - if args.debug: - print( - "Starting the web server on port {0} and host {1}.".format( - args.port, args.hostname)) - app, _ = create_app(None, testing=conf.getboolean('core', 'unit_test_mode')) - app.run(debug=True, use_reloader=not app.config['TESTING'], - port=args.port, host=args.hostname, - ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None) - else: - os.environ['SKIP_DAGS_PARSING'] = 'True' - app = cached_app(None) - pid, stdout, stderr, log_file = setup_locations( - "webserver", args.pid, args.stdout, args.stderr, args.log_file) - os.environ.pop('SKIP_DAGS_PARSING') - if args.daemon: - handle = setup_logging(log_file) - stdout = open(stdout, 'w+') - stderr = open(stderr, 'w+') - - print( - textwrap.dedent('''\ - Running the Gunicorn Server with: - Workers: {num_workers} {workerclass} - Host: {hostname}:{port} - Timeout: {worker_timeout} - Logfiles: {access_logfile} {error_logfile} - =================================================================\ - '''.format(num_workers=num_workers, workerclass=args.workerclass, - hostname=args.hostname, port=args.port, - worker_timeout=worker_timeout, access_logfile=access_logfile, - error_logfile=error_logfile))) - - run_args = [ - 'gunicorn', - '-w', str(num_workers), - '-k', str(args.workerclass), - '-t', str(worker_timeout), - '-b', args.hostname + ':' + str(args.port), - '-n', 'airflow-webserver', - '-p', str(pid), - '-c', 'python:airflow.www.gunicorn_config', - ] - - if args.access_logfile: - run_args += ['--access-logfile', str(args.access_logfile)] - - if args.error_logfile: - run_args += ['--error-logfile', str(args.error_logfile)] - - if args.daemon: - run_args += ['-D'] - - if ssl_cert: - run_args += ['--certfile', ssl_cert, '--keyfile', ssl_key] - - webserver_module = 'www' - run_args += ["airflow." 
+ webserver_module + ".app:cached_app()"] - - gunicorn_master_proc = None - - def kill_proc(dummy_signum, dummy_frame): # pylint: disable=unused-argument - gunicorn_master_proc.terminate() - gunicorn_master_proc.wait() - sys.exit(0) - - def monitor_gunicorn(gunicorn_master_proc): - # These run forever until SIG{INT, TERM, KILL, ...} signal is sent - if conf.getint('webserver', 'worker_refresh_interval') > 0: - master_timeout = conf.getint('webserver', 'web_server_master_timeout') - restart_workers(gunicorn_master_proc, num_workers, master_timeout) - else: - while gunicorn_master_proc.poll() is None: - time.sleep(1) - - sys.exit(gunicorn_master_proc.returncode) - - if args.daemon: - base, ext = os.path.splitext(pid) - ctx = daemon.DaemonContext( - pidfile=TimeoutPIDLockFile(base + "-monitor" + ext, -1), - files_preserve=[handle], - stdout=stdout, - stderr=stderr, - signal_map={ - signal.SIGINT: kill_proc, - signal.SIGTERM: kill_proc - }, - ) - with ctx: - subprocess.Popen(run_args, close_fds=True) - - # Reading pid file directly, since Popen#pid doesn't - # seem to return the right value with DaemonContext. - while True: - try: - with open(pid) as file: - gunicorn_master_proc_pid = int(file.read()) - break - except OSError: - LOG.debug("Waiting for gunicorn's pid file to be created.") - time.sleep(0.1) - - gunicorn_master_proc = psutil.Process(gunicorn_master_proc_pid) - monitor_gunicorn(gunicorn_master_proc) - - stdout.close() - stderr.close() - else: - gunicorn_master_proc = subprocess.Popen(run_args, close_fds=True) - - signal.signal(signal.SIGINT, kill_proc) - signal.signal(signal.SIGTERM, kill_proc) - - monitor_gunicorn(gunicorn_master_proc) - - @cli_utils.action_logging def serve_logs(args): """Serves logs generated by Worker""" @@ -1509,7 +1236,7 @@ class CLIFactory: 'help': "Serve logs generate by worker", 'args': tuple(), }, { - 'func': webserver, + 'func': webserver_command.webserver, 'help': "Start a Airflow webserver instance", 'args': ('port', 'workers', 'workerclass', 'worker_timeout', 'hostname', 'pid', 'daemon', 'stdout', 'stderr', 'access_logfile', diff --git a/airflow/cli/commands/webserver_command.py b/airflow/cli/commands/webserver_command.py new file mode 100644 index 0000000000000..b8fa122c7793e --- /dev/null +++ b/airflow/cli/commands/webserver_command.py @@ -0,0 +1,302 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
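The restart_workers routine in the module below drives gunicorn purely through POSIX signals: SIGTTIN asks the master to start one more worker, SIGTTOU asks it to retire the oldest one, and psutil is used to count the master's children. A rough sketch of that mechanism; the function name is made up for illustration and is not part of this patch:

    import signal
    import psutil

    def nudge_gunicorn(master_pid: int) -> int:
        """Ask a running gunicorn master (hypothetical pid) to add one worker,
        then retire the oldest one, and return the current worker count."""
        master = psutil.Process(master_pid)
        master.send_signal(signal.SIGTTIN)   # grow the worker pool by one
        master.send_signal(signal.SIGTTOU)   # gracefully retire the oldest worker
        return len(master.children())        # worker processes right now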
+ +"""Webserver command""" +import os +import signal +import subprocess +import sys +import textwrap +import time + +import daemon +import psutil +from daemon.pidfile import TimeoutPIDLockFile + +from airflow import AirflowException, LoggingMixin, conf, settings +from airflow.exceptions import AirflowWebServerTimeout +from airflow.utils import cli as cli_utils +from airflow.utils.cli import setup_locations, setup_logging +from airflow.www.app import cached_app, create_app + +LOG = LoggingMixin().log + + +def get_num_ready_workers_running(gunicorn_master_proc): + """Returns number of ready Gunicorn workers by looking for READY_PREFIX in process name""" + workers = psutil.Process(gunicorn_master_proc.pid).children() + + def ready_prefix_on_cmdline(proc): + try: + cmdline = proc.cmdline() + if len(cmdline) > 0: # pylint: disable=len-as-condition + return settings.GUNICORN_WORKER_READY_PREFIX in cmdline[0] + except psutil.NoSuchProcess: + pass + return False + + ready_workers = [proc for proc in workers if ready_prefix_on_cmdline(proc)] + return len(ready_workers) + + +def get_num_workers_running(gunicorn_master_proc): + """Returns number of running Gunicorn workers processes""" + workers = psutil.Process(gunicorn_master_proc.pid).children() + return len(workers) + + +def restart_workers(gunicorn_master_proc, num_workers_expected, master_timeout): + """ + Runs forever, monitoring the child processes of @gunicorn_master_proc and + restarting workers occasionally. + Each iteration of the loop traverses one edge of this state transition + diagram, where each state (node) represents + [ num_ready_workers_running / num_workers_running ]. We expect most time to + be spent in [n / n]. `bs` is the setting webserver.worker_refresh_batch_size. + The horizontal transition at ? happens after the new worker parses all the + dags (so it could take a while!) + V ────────────────────────────────────────────────────────────────────────┐ + [n / n] ──TTIN──> [ [n, n+bs) / n + bs ] ────?───> [n + bs / n + bs] ──TTOU─┘ + ^ ^───────────────┘ + │ + │ ┌────────────────v + └──────┴────── [ [0, n) / n ] <─── start + We change the number of workers by sending TTIN and TTOU to the gunicorn + master process, which increases and decreases the number of child workers + respectively. Gunicorn guarantees that on TTOU workers are terminated + gracefully and that the oldest worker is terminated. 
+ """ + + def wait_until_true(fn, timeout=0): + """ + Sleeps until fn is true + """ + start_time = time.time() + while not fn(): + if 0 < timeout <= time.time() - start_time: + raise AirflowWebServerTimeout( + "No response from gunicorn master within {0} seconds" + .format(timeout)) + time.sleep(0.1) + + def start_refresh(gunicorn_master_proc): + batch_size = conf.getint('webserver', 'worker_refresh_batch_size') + LOG.debug('%s doing a refresh of %s workers', state, batch_size) + sys.stdout.flush() + sys.stderr.flush() + + excess = 0 + for _ in range(batch_size): + gunicorn_master_proc.send_signal(signal.SIGTTIN) + excess += 1 + wait_until_true(lambda: num_workers_expected + excess == + get_num_workers_running(gunicorn_master_proc), + master_timeout) + + try: # pylint: disable=too-many-nested-blocks + wait_until_true(lambda: num_workers_expected == + get_num_workers_running(gunicorn_master_proc), + master_timeout) + while True: + num_workers_running = get_num_workers_running(gunicorn_master_proc) + num_ready_workers_running = \ + get_num_ready_workers_running(gunicorn_master_proc) + + state = '[{0} / {1}]'.format(num_ready_workers_running, num_workers_running) + + # Whenever some workers are not ready, wait until all workers are ready + if num_ready_workers_running < num_workers_running: + LOG.debug('%s some workers are starting up, waiting...', state) + sys.stdout.flush() + time.sleep(1) + + # Kill a worker gracefully by asking gunicorn to reduce number of workers + elif num_workers_running > num_workers_expected: + excess = num_workers_running - num_workers_expected + LOG.debug('%s killing %s workers', state, excess) + + for _ in range(excess): + gunicorn_master_proc.send_signal(signal.SIGTTOU) + excess -= 1 + wait_until_true(lambda: num_workers_expected + excess == + get_num_workers_running(gunicorn_master_proc), + master_timeout) + + # Start a new worker by asking gunicorn to increase number of workers + elif num_workers_running == num_workers_expected: + refresh_interval = conf.getint('webserver', 'worker_refresh_interval') + LOG.debug( + '%s sleeping for %ss starting doing a refresh...', + state, refresh_interval + ) + time.sleep(refresh_interval) + start_refresh(gunicorn_master_proc) + + else: + # num_ready_workers_running == num_workers_running < num_workers_expected + LOG.error(( + "%s some workers seem to have died and gunicorn" + "did not restart them as expected" + ), state) + time.sleep(10) + if len( + psutil.Process(gunicorn_master_proc.pid).children() + ) < num_workers_expected: + start_refresh(gunicorn_master_proc) + except (AirflowWebServerTimeout, OSError) as err: + LOG.error(err) + LOG.error("Shutting down webserver") + try: + gunicorn_master_proc.terminate() + gunicorn_master_proc.wait() + finally: + sys.exit(1) + + +@cli_utils.action_logging +def webserver(args): + """Starts Airflow Webserver""" + print(settings.HEADER) + + access_logfile = args.access_logfile or conf.get('webserver', 'access_logfile') + error_logfile = args.error_logfile or conf.get('webserver', 'error_logfile') + num_workers = args.workers or conf.get('webserver', 'workers') + worker_timeout = (args.worker_timeout or + conf.get('webserver', 'web_server_worker_timeout')) + ssl_cert = args.ssl_cert or conf.get('webserver', 'web_server_ssl_cert') + ssl_key = args.ssl_key or conf.get('webserver', 'web_server_ssl_key') + if not ssl_cert and ssl_key: + raise AirflowException( + 'An SSL certificate must also be provided for use with ' + ssl_key) + if ssl_cert and not ssl_key: + raise AirflowException( + 'An 
SSL key must also be provided for use with ' + ssl_cert) + + if args.debug: + print( + "Starting the web server on port {0} and host {1}.".format( + args.port, args.hostname)) + app, _ = create_app(None, testing=conf.getboolean('core', 'unit_test_mode')) + app.run(debug=True, use_reloader=not app.config['TESTING'], + port=args.port, host=args.hostname, + ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None) + else: + os.environ['SKIP_DAGS_PARSING'] = 'True' + app = cached_app(None) + pid, stdout, stderr, log_file = setup_locations( + "webserver", args.pid, args.stdout, args.stderr, args.log_file) + os.environ.pop('SKIP_DAGS_PARSING') + if args.daemon: + handle = setup_logging(log_file) + stdout = open(stdout, 'w+') + stderr = open(stderr, 'w+') + + print( + textwrap.dedent('''\ + Running the Gunicorn Server with: + Workers: {num_workers} {workerclass} + Host: {hostname}:{port} + Timeout: {worker_timeout} + Logfiles: {access_logfile} {error_logfile} + =================================================================\ + '''.format(num_workers=num_workers, workerclass=args.workerclass, + hostname=args.hostname, port=args.port, + worker_timeout=worker_timeout, access_logfile=access_logfile, + error_logfile=error_logfile))) + + run_args = [ + 'gunicorn', + '-w', str(num_workers), + '-k', str(args.workerclass), + '-t', str(worker_timeout), + '-b', args.hostname + ':' + str(args.port), + '-n', 'airflow-webserver', + '-p', str(pid), + '-c', 'python:airflow.www.gunicorn_config', + ] + + if args.access_logfile: + run_args += ['--access-logfile', str(args.access_logfile)] + + if args.error_logfile: + run_args += ['--error-logfile', str(args.error_logfile)] + + if args.daemon: + run_args += ['-D'] + + if ssl_cert: + run_args += ['--certfile', ssl_cert, '--keyfile', ssl_key] + + webserver_module = 'www' + run_args += ["airflow." + webserver_module + ".app:cached_app()"] + + gunicorn_master_proc = None + + def kill_proc(dummy_signum, dummy_frame): # pylint: disable=unused-argument + gunicorn_master_proc.terminate() + gunicorn_master_proc.wait() + sys.exit(0) + + def monitor_gunicorn(gunicorn_master_proc): + # These run forever until SIG{INT, TERM, KILL, ...} signal is sent + if conf.getint('webserver', 'worker_refresh_interval') > 0: + master_timeout = conf.getint('webserver', 'web_server_master_timeout') + restart_workers(gunicorn_master_proc, num_workers, master_timeout) + else: + while gunicorn_master_proc.poll() is None: + time.sleep(1) + + sys.exit(gunicorn_master_proc.returncode) + + if args.daemon: + base, ext = os.path.splitext(pid) + ctx = daemon.DaemonContext( + pidfile=TimeoutPIDLockFile(base + "-monitor" + ext, -1), + files_preserve=[handle], + stdout=stdout, + stderr=stderr, + signal_map={ + signal.SIGINT: kill_proc, + signal.SIGTERM: kill_proc + }, + ) + with ctx: + subprocess.Popen(run_args, close_fds=True) + + # Reading pid file directly, since Popen#pid doesn't + # seem to return the right value with DaemonContext. 
+ while True: + try: + with open(pid) as file: + gunicorn_master_proc_pid = int(file.read()) + break + except OSError: + LOG.debug("Waiting for gunicorn's pid file to be created.") + time.sleep(0.1) + + gunicorn_master_proc = psutil.Process(gunicorn_master_proc_pid) + monitor_gunicorn(gunicorn_master_proc) + + stdout.close() + stderr.close() + else: + gunicorn_master_proc = subprocess.Popen(run_args, close_fds=True) + + signal.signal(signal.SIGINT, kill_proc) + signal.signal(signal.SIGTERM, kill_proc) + + monitor_gunicorn(gunicorn_master_proc) diff --git a/tests/cli/commands/test_webserver_command.py b/tests/cli/commands/test_webserver_command.py new file mode 100644 index 0000000000000..d7ffbf997e761 --- /dev/null +++ b/tests/cli/commands/test_webserver_command.py @@ -0,0 +1,175 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import subprocess +import tempfile +import unittest +from time import sleep +from unittest import mock + +import psutil + +from airflow.bin import cli +from airflow.cli.commands import webserver_command +from airflow.cli.commands.webserver_command import get_num_ready_workers_running +from airflow.models import DagBag +from airflow.utils.cli import setup_locations +from tests import conf_vars, settings + + +class TestCLIGetNumReadyWorkersRunning(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.dagbag = DagBag(include_examples=True) + cls.parser = cli.CLIFactory.get_parser() + + def setUp(self): + self.gunicorn_master_proc = mock.Mock(pid=None) + self.children = mock.MagicMock() + self.child = mock.MagicMock() + self.process = mock.MagicMock() + + def test_ready_prefix_on_cmdline(self): + self.child.cmdline.return_value = [settings.GUNICORN_WORKER_READY_PREFIX] + self.process.children.return_value = [self.child] + + with mock.patch('psutil.Process', return_value=self.process): + self.assertEqual(get_num_ready_workers_running(self.gunicorn_master_proc), 1) + + def test_ready_prefix_on_cmdline_no_children(self): + self.process.children.return_value = [] + + with mock.patch('psutil.Process', return_value=self.process): + self.assertEqual(get_num_ready_workers_running(self.gunicorn_master_proc), 0) + + def test_ready_prefix_on_cmdline_zombie(self): + self.child.cmdline.return_value = [] + self.process.children.return_value = [self.child] + + with mock.patch('psutil.Process', return_value=self.process): + self.assertEqual(get_num_ready_workers_running(self.gunicorn_master_proc), 0) + + def test_ready_prefix_on_cmdline_dead_process(self): + self.child.cmdline.side_effect = psutil.NoSuchProcess(11347) + self.process.children.return_value = [self.child] + + with mock.patch('psutil.Process', return_value=self.process): + self.assertEqual(get_num_ready_workers_running(self.gunicorn_master_proc), 0) + + def 
test_cli_webserver_debug(self): + env = os.environ.copy() + proc = psutil.Popen(["airflow", "webserver", "-d"], env=env) + sleep(3) # wait for webserver to start + return_code = proc.poll() + self.assertEqual( + None, + return_code, + "webserver terminated with return code {} in debug mode".format(return_code)) + proc.terminate() + proc.wait() + + +class TestCliWebServer(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.parser = cli.CLIFactory.get_parser() + + def setUp(self) -> None: + self._check_processes() + self._clean_pidfiles() + + def _check_processes(self): + try: + # Confirm that webserver hasn't been launched. + # pgrep returns exit status 1 if no process matched. + self.assertEqual(1, subprocess.Popen(["pgrep", "-f", "-c", "airflow webserver"]).wait()) + self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "gunicorn"]).wait()) + except: # noqa: E722 + subprocess.Popen(["ps", "-ax"]).wait() + raise + + def tearDown(self) -> None: + self._check_processes() + + def _clean_pidfiles(self): + pidfile_webserver = setup_locations("webserver")[0] + pidfile_monitor = setup_locations("webserver-monitor")[0] + if os.path.exists(pidfile_webserver): + os.remove(pidfile_webserver) + if os.path.exists(pidfile_monitor): + os.remove(pidfile_monitor) + + def _wait_pidfile(self, pidfile): + while True: + try: + with open(pidfile) as file: + return int(file.read()) + except Exception: # pylint: disable=broad-except + sleep(1) + + def test_cli_webserver_foreground(self): + # Run webserver in foreground and terminate it. + proc = subprocess.Popen(["airflow", "webserver"]) + proc.terminate() + proc.wait() + + @unittest.skipIf("TRAVIS" in os.environ and bool(os.environ["TRAVIS"]), + "Skipping test due to lack of required file permission") + def test_cli_webserver_foreground_with_pid(self): + # Run webserver in foreground with --pid option + pidfile = tempfile.mkstemp()[1] + proc = subprocess.Popen(["airflow", "webserver", "--pid", pidfile]) + + # Check the file specified by --pid option exists + self._wait_pidfile(pidfile) + + # Terminate webserver + proc.terminate() + proc.wait() + + @unittest.skipIf("TRAVIS" in os.environ and bool(os.environ["TRAVIS"]), + "Skipping test due to lack of required file permission") + def test_cli_webserver_background(self): + pidfile_webserver = setup_locations("webserver")[0] + pidfile_monitor = setup_locations("webserver-monitor")[0] + + # Run webserver as daemon in background. Note that the wait method is not called. + subprocess.Popen(["airflow", "webserver", "-D"]) + + pid_monitor = self._wait_pidfile(pidfile_monitor) + self._wait_pidfile(pidfile_webserver) + + # Assert that gunicorn and its monitor are launched. + self.assertEqual(0, subprocess.Popen(["pgrep", "-f", "-c", "airflow webserver"]).wait()) + self.assertEqual(0, subprocess.Popen(["pgrep", "-c", "gunicorn"]).wait()) + + # Terminate monitor process. 
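The TestCLIGetNumReadyWorkersRunning cases above pin down the contract of get_num_ready_workers_running purely through a patched psutil.Process: a child counts as "ready" only when its command line starts with the configured ready prefix (GUNICORN_WORKER_READY_PREFIX), while empty command lines and vanished children count as zero. A rough sketch of an implementation consistent with those tests (illustrative only, not the module's actual code), plus one pytest-style check:

    from unittest import mock

    import psutil

    READY_PREFIX = "[ready] "  # hypothetical stand-in for settings.GUNICORN_WORKER_READY_PREFIX


    def count_ready_workers(gunicorn_master_proc) -> int:
        """Count master children whose process title carries the 'ready' prefix."""
        ready = 0
        for child in psutil.Process(gunicorn_master_proc.pid).children():
            try:
                cmdline = child.cmdline()
            except psutil.NoSuchProcess:
                continue  # the worker died between listing and inspection
            if cmdline and cmdline[0].startswith(READY_PREFIX):
                ready += 1
        return ready


    def test_counts_only_tagged_children():
        ready_child = mock.MagicMock()
        ready_child.cmdline.return_value = [READY_PREFIX]
        booting_child = mock.MagicMock()
        booting_child.cmdline.return_value = []
        master = mock.MagicMock()
        master.children.return_value = [ready_child, booting_child]
        with mock.patch("psutil.Process", return_value=master):
            assert count_ready_workers(mock.Mock(pid=1234)) == 1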
+ proc = psutil.Process(pid_monitor) + proc.terminate() + proc.wait() + + # Patch for causing webserver timeout + @mock.patch("airflow.cli.commands.webserver_command.get_num_workers_running", return_value=0) + def test_cli_webserver_shutdown_when_gunicorn_master_is_killed(self, _): + # Shorten timeout so that this test doesn't take too long time + args = self.parser.parse_args(['webserver']) + with conf_vars({('webserver', 'web_server_master_timeout'): '10'}): + with self.assertRaises(SystemExit) as e: + webserver_command.webserver(args) + self.assertEqual(e.exception.code, 1) diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 8b86aa837d60b..4060fea013c8b 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -24,21 +24,16 @@ import unittest from argparse import Namespace from datetime import datetime, time, timedelta -from time import sleep -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock -import psutil import pytz import airflow.bin.cli as cli from airflow import AirflowException, models, settings -from airflow.bin.cli import get_num_ready_workers_running from airflow.models import DagModel from airflow.settings import Session from airflow.utils import timezone -from airflow.utils.cli import setup_locations from airflow.utils.state import State -from tests import conf_vars from tests.compat import mock dag_folder_path = '/'.join(os.path.realpath(__file__).split('/')[:-1]) @@ -111,59 +106,6 @@ def create_mock_args( # pylint: disable=too-many-arguments ) -class TestCLI(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.dagbag = models.DagBag(include_examples=True) - cls.parser = cli.CLIFactory.get_parser() - - def setUp(self): - self.gunicorn_master_proc = Mock(pid=None) - self.children = MagicMock() - self.child = MagicMock() - self.process = MagicMock() - - def test_ready_prefix_on_cmdline(self): - self.child.cmdline.return_value = [settings.GUNICORN_WORKER_READY_PREFIX] - self.process.children.return_value = [self.child] - - with patch('psutil.Process', return_value=self.process): - self.assertEqual(get_num_ready_workers_running(self.gunicorn_master_proc), 1) - - def test_ready_prefix_on_cmdline_no_children(self): - self.process.children.return_value = [] - - with patch('psutil.Process', return_value=self.process): - self.assertEqual(get_num_ready_workers_running(self.gunicorn_master_proc), 0) - - def test_ready_prefix_on_cmdline_zombie(self): - self.child.cmdline.return_value = [] - self.process.children.return_value = [self.child] - - with patch('psutil.Process', return_value=self.process): - self.assertEqual(get_num_ready_workers_running(self.gunicorn_master_proc), 0) - - def test_ready_prefix_on_cmdline_dead_process(self): - self.child.cmdline.side_effect = psutil.NoSuchProcess(11347) - self.process.children.return_value = [self.child] - - with patch('psutil.Process', return_value=self.process): - self.assertEqual(get_num_ready_workers_running(self.gunicorn_master_proc), 0) - - def test_cli_webserver_debug(self): - env = os.environ.copy() - proc = psutil.Popen(["airflow", "webserver", "-d"], env=env) - sleep(3) # wait for webserver to start - return_code = proc.poll() - self.assertEqual( - None, - return_code, - "webserver terminated with return code {} in debug mode".format(return_code)) - proc.terminate() - proc.wait() - - class TestCliDags(unittest.TestCase): @classmethod @@ -500,93 +442,3 @@ def test_cli_list_jobs(self): def test_dag_state(self): self.assertEqual(None, 
cli.dag_state(self.parser.parse_args([ 'dags', 'state', 'example_bash_operator', DEFAULT_DATE.isoformat()]))) - - -class TestCliWebServer(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.parser = cli.CLIFactory.get_parser() - - def setUp(self) -> None: - self._check_processes() - self._clean_pidfiles() - - def _check_processes(self): - try: - # Confirm that webserver hasn't been launched. - # pgrep returns exit status 1 if no process matched. - self.assertEqual(1, subprocess.Popen(["pgrep", "-f", "-c", "airflow webserver"]).wait()) - self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "gunicorn"]).wait()) - except: # noqa: E722 - subprocess.Popen(["ps", "-ax"]).wait() - raise - - def tearDown(self) -> None: - self._check_processes() - - def _clean_pidfiles(self): - pidfile_webserver = setup_locations("webserver")[0] - pidfile_monitor = setup_locations("webserver-monitor")[0] - if os.path.exists(pidfile_webserver): - os.remove(pidfile_webserver) - if os.path.exists(pidfile_monitor): - os.remove(pidfile_monitor) - - def _wait_pidfile(self, pidfile): - while True: - try: - with open(pidfile) as file: - return int(file.read()) - except Exception: # pylint: disable=broad-except - sleep(1) - - def test_cli_webserver_foreground(self): - # Run webserver in foreground and terminate it. - proc = subprocess.Popen(["airflow", "webserver"]) - proc.terminate() - proc.wait() - - @unittest.skipIf("TRAVIS" in os.environ and bool(os.environ["TRAVIS"]), - "Skipping test due to lack of required file permission") - def test_cli_webserver_foreground_with_pid(self): - # Run webserver in foreground with --pid option - pidfile = tempfile.mkstemp()[1] - proc = subprocess.Popen(["airflow", "webserver", "--pid", pidfile]) - - # Check the file specified by --pid option exists - self._wait_pidfile(pidfile) - - # Terminate webserver - proc.terminate() - proc.wait() - - @unittest.skipIf("TRAVIS" in os.environ and bool(os.environ["TRAVIS"]), - "Skipping test due to lack of required file permission") - def test_cli_webserver_background(self): - pidfile_webserver = setup_locations("webserver")[0] - pidfile_monitor = setup_locations("webserver-monitor")[0] - - # Run webserver as daemon in background. Note that the wait method is not called. - subprocess.Popen(["airflow", "webserver", "-D"]) - - pid_monitor = self._wait_pidfile(pidfile_monitor) - self._wait_pidfile(pidfile_webserver) - - # Assert that gunicorn and its monitor are launched. - self.assertEqual(0, subprocess.Popen(["pgrep", "-f", "-c", "airflow webserver"]).wait()) - self.assertEqual(0, subprocess.Popen(["pgrep", "-c", "gunicorn"]).wait()) - - # Terminate monitor process. 
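Both the new test module and the class being removed here guard process hygiene the same way: pgrep's exit status is the assertion, since pgrep -c exits 0 when at least one process matches and 1 when nothing does. Outside unittest, the same guard can be written as a small helper (a sketch, not part of the patch):

    import subprocess


    def assert_no_process(pattern: str) -> None:
        """Fail if any running process matches the pgrep pattern.

        pgrep exits 0 when at least one process matched and 1 when none did,
        so an exit status of 1 is the "clean" outcome for a teardown check.
        """
        if subprocess.call(["pgrep", "-f", "-c", pattern]) == 0:
            raise AssertionError("unexpected process(es) still match {!r}".format(pattern))

Running the check in both setUp and tearDown means a webserver leaked by one test fails the next test immediately instead of silently poisoning the rest of the session.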
- proc = psutil.Process(pid_monitor) - proc.terminate() - proc.wait() - - # Patch for causing webserver timeout - @mock.patch("airflow.bin.cli.get_num_workers_running", return_value=0) - def test_cli_webserver_shutdown_when_gunicorn_master_is_killed(self, _): - # Shorten timeout so that this test doesn't take too long time - args = self.parser.parse_args(['webserver']) - with conf_vars({('webserver', 'web_server_master_timeout'): '10'}): - with self.assertRaises(SystemExit) as e: - cli.webserver(args) - self.assertEqual(e.exception.code, 1) From ba70cc7ee717cae24a8d07a5f755b971d0ecd502 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Sun, 17 Nov 2019 07:16:28 +0100 Subject: [PATCH 16/23] [AIRLFOW-YYY] Move dag commands to separate file --- airflow/bin/cli.py | 338 +----------------- airflow/cli/commands/dag_command.py | 326 +++++++++++++++++ .../test_dag_command.py} | 57 +-- 3 files changed, 373 insertions(+), 348 deletions(-) create mode 100644 airflow/cli/commands/dag_command.py rename tests/cli/{test_cli.py => commands/test_dag_command.py} (90%) diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 3ab2d20a0c420..1464450b9bc22 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -20,33 +20,24 @@ """Command-line interface""" import argparse -import errno -import json -import logging import os import signal -import subprocess import textwrap from argparse import RawTextHelpFormatter import daemon from daemon.pidfile import TimeoutPIDLockFile -from tabulate import tabulate, tabulate_formats +from tabulate import tabulate_formats -from airflow import api, jobs, settings -from airflow.api.client import get_current_api_client +from airflow import api, settings from airflow.cli.commands import ( - connection_command, db_command, pool_command, role_command, rotate_fernet_key_command, scheduler_command, - sync_perm_command, task_command, user_command, variable_command, version_command, webserver_command, - worker_command, + connection_command, dag_command, db_command, pool_command, role_command, rotate_fernet_key_command, + scheduler_command, sync_perm_command, task_command, user_command, variable_command, version_command, + webserver_command, worker_command, ) from airflow.configuration import conf -from airflow.exceptions import AirflowException -from airflow.models import DAG, DagBag, DagModel, DagRun, TaskInstance -from airflow.utils import cli as cli_utils, db -from airflow.utils.cli import alternative_conn_specs, get_dag, process_subdir, setup_locations, sigint_handler -from airflow.utils.dot_renderer import render_dag -from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils import cli as cli_utils +from airflow.utils.cli import alternative_conn_specs, setup_locations, sigint_handler from airflow.utils.timezone import parse as parsedate api.load_auth() @@ -57,246 +48,6 @@ DAGS_FOLDER = '[AIRFLOW_HOME]/dags' -@cli_utils.action_logging -def dag_backfill(args, dag=None): - """Creates backfill job or dry run for a DAG""" - logging.basicConfig( - level=settings.LOGGING_LEVEL, - format=settings.SIMPLE_LOG_FORMAT) - - signal.signal(signal.SIGTERM, sigint_handler) - - dag = dag or get_dag(args) - - if not args.start_date and not args.end_date: - raise AirflowException("Provide a start_date and/or end_date") - - # If only one date is passed, using same as start and end - args.end_date = args.end_date or args.start_date - args.start_date = args.start_date or args.end_date - - if args.task_regex: - dag = dag.sub_dag( - task_regex=args.task_regex, 
- include_upstream=not args.ignore_dependencies) - - run_conf = None - if args.conf: - run_conf = json.loads(args.conf) - - if args.dry_run: - print("Dry run of DAG {0} on {1}".format(args.dag_id, - args.start_date)) - for task in dag.tasks: - print("Task {0}".format(task.task_id)) - ti = TaskInstance(task, args.start_date) - ti.dry_run() - else: - if args.reset_dagruns: - DAG.clear_dags( - [dag], - start_date=args.start_date, - end_date=args.end_date, - confirm_prompt=not args.yes, - include_subdags=True, - ) - - dag.run( - start_date=args.start_date, - end_date=args.end_date, - mark_success=args.mark_success, - local=args.local, - donot_pickle=(args.donot_pickle or - conf.getboolean('core', 'donot_pickle')), - ignore_first_depends_on_past=args.ignore_first_depends_on_past, - ignore_task_deps=args.ignore_dependencies, - pool=args.pool, - delay_on_limit_secs=args.delay_on_limit, - verbose=args.verbose, - conf=run_conf, - rerun_failed_tasks=args.rerun_failed_tasks, - run_backwards=args.run_backwards - ) - - -@cli_utils.action_logging -def dag_trigger(args): - """ - Creates a dag run for the specified dag - - :param args: - :return: - """ - api_client = get_current_api_client() - log = LoggingMixin().log - try: - message = api_client.trigger_dag(dag_id=args.dag_id, - run_id=args.run_id, - conf=args.conf, - execution_date=args.exec_date) - except OSError as err: - log.error(err) - raise AirflowException(err) - log.info(message) - - -@cli_utils.action_logging -def dag_delete(args): - """ - Deletes all DB records related to the specified dag - - :param args: - :return: - """ - api_client = get_current_api_client() - log = LoggingMixin().log - if args.yes or input( - "This will drop all existing records related to the specified DAG. " - "Proceed? (y/n)").upper() == "Y": - try: - message = api_client.delete_dag(dag_id=args.dag_id) - except OSError as err: - log.error(err) - raise AirflowException(err) - log.info(message) - else: - print("Bail.") - - -@cli_utils.action_logging -def dag_pause(args): - """Pauses a DAG""" - set_is_paused(True, args) - - -@cli_utils.action_logging -def dag_unpause(args): - """Unpauses a DAG""" - set_is_paused(False, args) - - -def set_is_paused(is_paused, args): - """Sets is_paused for DAG by a given dag_id""" - DagModel.get_dagmodel(args.dag_id).set_is_paused( - is_paused=is_paused, - ) - - print("Dag: {}, paused: {}".format(args.dag_id, str(is_paused))) - - -def dag_show(args): - """Displays DAG or saves it's graphic representation to the file""" - dag = get_dag(args) - dot = render_dag(dag) - if args.save: - filename, _, fileformat = args.save.rpartition('.') - dot.render(filename=filename, format=fileformat, cleanup=True) - print("File {} saved".format(args.save)) - elif args.imgcat: - data = dot.pipe(format='png') - try: - proc = subprocess.Popen("imgcat", stdout=subprocess.PIPE, stdin=subprocess.PIPE) - except OSError as e: - if e.errno == errno.ENOENT: - raise AirflowException( - "Failed to execute. Make sure the imgcat executables are on your systems \'PATH\'" - ) - else: - raise - out, err = proc.communicate(data) - if out: - print(out.decode('utf-8')) - if err: - print(err.decode('utf-8')) - else: - print(dot.source) - - -@cli_utils.action_logging -def dag_state(args): - """ - Returns the state of a DagRun at the command line. 
- >>> airflow dags state tutorial 2015-01-01T00:00:00.000000 - running - """ - dag = get_dag(args) - dr = DagRun.find(dag.dag_id, execution_date=args.execution_date) - print(dr[0].state if len(dr) > 0 else None) # pylint: disable=len-as-condition - - -@cli_utils.action_logging -def dag_next_execution(args): - """ - Returns the next execution datetime of a DAG at the command line. - >>> airflow dags next_execution tutorial - 2018-08-31 10:38:00 - """ - dag = get_dag(args) - - if dag.is_paused: - print("[INFO] Please be reminded this DAG is PAUSED now.") - - if dag.latest_execution_date: - next_execution_dttm = dag.following_schedule(dag.latest_execution_date) - - if next_execution_dttm is None: - print("[WARN] No following schedule can be found. " + - "This DAG may have schedule interval '@once' or `None`.") - - print(next_execution_dttm) - else: - print("[WARN] Only applicable when there is execution record found for the DAG.") - print(None) - - -@cli_utils.action_logging -def dag_list_dags(args): - """Displays dags with or without stats at the command line""" - dagbag = DagBag(process_subdir(args.subdir)) - list_template = textwrap.dedent("""\n - ------------------------------------------------------------------- - DAGS - ------------------------------------------------------------------- - {dag_list} - """) - dag_list = "\n".join(sorted(dagbag.dags)) - print(list_template.format(dag_list=dag_list)) - if args.report: - print(dagbag.dagbag_report()) - - -@cli_utils.action_logging -def dag_list_jobs(args, dag=None): - """Lists latest n jobs""" - queries = [] - if dag: - args.dag_id = dag.dag_id - if args.dag_id: - dagbag = DagBag() - - if args.dag_id not in dagbag.dags: - error_message = "Dag id {} not found".format(args.dag_id) - raise AirflowException(error_message) - queries.append(jobs.BaseJob.dag_id == args.dag_id) - - if args.state: - queries.append(jobs.BaseJob.state == args.state) - - with db.create_session() as session: - all_jobs = (session - .query(jobs.BaseJob) - .filter(*queries) - .order_by(jobs.BaseJob.start_date.desc()) - .limit(args.limit) - .all()) - fields = ['dag_id', 'state', 'job_type', 'start_date', 'end_date'] - all_jobs = [[job.__getattribute__(field) for field in fields] for job in all_jobs] - msg = tabulate(all_jobs, - [field.capitalize().replace('_', ' ') for field in fields], - tablefmt=args.output) - print(msg) - - @cli_utils.action_logging def serve_logs(args): """Serves logs generated by Worker""" @@ -392,59 +143,6 @@ def kerberos(args): airflow.security.kerberos.run(principal=args.principal, keytab=args.keytab) -@cli_utils.action_logging -def dag_list_dag_runs(args, dag=None): - """Lists dag runs for a given DAG""" - if dag: - args.dag_id = dag.dag_id - - dagbag = DagBag() - - if args.dag_id not in dagbag.dags: - error_message = "Dag id {} not found".format(args.dag_id) - raise AirflowException(error_message) - - dag_runs = list() - state = args.state.lower() if args.state else None - for dag_run in DagRun.find(dag_id=args.dag_id, - state=state, - no_backfills=args.no_backfill): - dag_runs.append({ - 'id': dag_run.id, - 'run_id': dag_run.run_id, - 'state': dag_run.state, - 'dag_id': dag_run.dag_id, - 'execution_date': dag_run.execution_date.isoformat(), - 'start_date': ((dag_run.start_date or '') and - dag_run.start_date.isoformat()), - }) - if not dag_runs: - print('No dag runs for {dag_id}'.format(dag_id=args.dag_id)) - - header_template = textwrap.dedent("""\n - {line} - DAG RUNS - {line} - {dag_run_header} - """) - - dag_runs.sort(key=lambda x: 
x['execution_date'], reverse=True) - dag_run_header = '%-3s | %-20s | %-10s | %-20s | %-20s |' % ('id', - 'run_id', - 'state', - 'execution_date', - 'start_date') - print(header_template.format(dag_run_header=dag_run_header, - line='-' * 120)) - for dag_run in dag_runs: - record = '%-3s | %-20s | %-10s | %-20s | %-20s |' % (dag_run['id'], - dag_run['run_id'], - dag_run['state'], - dag_run['execution_date'], - dag_run['start_date']) - print(record) - - class Arg: """Class to keep information about command line argument""" # pylint: disable=redefined-builtin @@ -975,13 +673,13 @@ class CLIFactory: 'name': 'dags', 'subcommands': ( { - 'func': dag_list_dags, + 'func': dag_command.dag_list_dags, 'name': 'list', 'help': "List all the DAGs", 'args': ('subdir', 'report'), }, { - 'func': dag_list_dag_runs, + 'func': dag_command.dag_list_dag_runs, 'name': 'list_runs', 'help': "List dag runs given a DAG id. If state option is given, it will only " "search for all the dagruns with the given state. " @@ -990,55 +688,55 @@ class CLIFactory: 'args': ('dag_id', 'no_backfill', 'state'), }, { - 'func': dag_list_jobs, + 'func': dag_command.dag_list_jobs, 'name': 'list_jobs', 'help': "List the jobs", 'args': ('dag_id_opt', 'state', 'limit', 'output',), }, { - 'func': dag_state, + 'func': dag_command.dag_state, 'name': 'state', 'help': "Get the status of a dag run", 'args': ('dag_id', 'execution_date', 'subdir'), }, { - 'func': dag_next_execution, + 'func': dag_command.dag_next_execution, 'name': 'next_execution', 'help': "Get the next execution datetime of a DAG.", 'args': ('dag_id', 'subdir'), }, { - 'func': dag_pause, + 'func': dag_command.dag_pause, 'name': 'pause', 'help': 'Pause a DAG', 'args': ('dag_id', 'subdir'), }, { - 'func': dag_unpause, + 'func': dag_command.dag_unpause, 'name': 'unpause', 'help': 'Resume a paused DAG', 'args': ('dag_id', 'subdir'), }, { - 'func': dag_trigger, + 'func': dag_command.dag_trigger, 'name': 'trigger', 'help': 'Trigger a DAG run', 'args': ('dag_id', 'subdir', 'run_id', 'conf', 'exec_date'), }, { - 'func': dag_delete, + 'func': dag_command.dag_delete, 'name': 'delete', 'help': "Delete all DB records related to the specified DAG", 'args': ('dag_id', 'yes'), }, { - 'func': dag_show, + 'func': dag_command.dag_show, 'name': 'show', 'help': "Displays DAG's tasks with their dependencies", 'args': ('dag_id', 'subdir', 'save', 'imgcat',), }, { - 'func': dag_backfill, + 'func': dag_command.dag_backfill, 'name': 'backfill', 'help': "Run subsections of a DAG for a specified date range. " "If reset_dag_run option is used," diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py new file mode 100644 index 0000000000000..4ba249ab311d0 --- /dev/null +++ b/airflow/cli/commands/dag_command.py @@ -0,0 +1,326 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Dag sub-commands""" +import errno +import json +import logging +import signal +import subprocess +import textwrap + +from tabulate import tabulate + +from airflow import DAG, AirflowException, LoggingMixin, conf, jobs, settings +from airflow.api.client import get_current_api_client +from airflow.models import DagBag, DagModel, DagRun, TaskInstance +from airflow.utils import cli as cli_utils, db +from airflow.utils.cli import get_dag, process_subdir, sigint_handler +from airflow.utils.dot_renderer import render_dag + + +@cli_utils.action_logging +def dag_backfill(args, dag=None): + """Creates backfill job or dry run for a DAG""" + logging.basicConfig( + level=settings.LOGGING_LEVEL, + format=settings.SIMPLE_LOG_FORMAT) + + signal.signal(signal.SIGTERM, sigint_handler) + + dag = dag or get_dag(args) + + if not args.start_date and not args.end_date: + raise AirflowException("Provide a start_date and/or end_date") + + # If only one date is passed, using same as start and end + args.end_date = args.end_date or args.start_date + args.start_date = args.start_date or args.end_date + + if args.task_regex: + dag = dag.sub_dag( + task_regex=args.task_regex, + include_upstream=not args.ignore_dependencies) + + run_conf = None + if args.conf: + run_conf = json.loads(args.conf) + + if args.dry_run: + print("Dry run of DAG {0} on {1}".format(args.dag_id, + args.start_date)) + for task in dag.tasks: + print("Task {0}".format(task.task_id)) + ti = TaskInstance(task, args.start_date) + ti.dry_run() + else: + if args.reset_dagruns: + DAG.clear_dags( + [dag], + start_date=args.start_date, + end_date=args.end_date, + confirm_prompt=not args.yes, + include_subdags=True, + ) + + dag.run( + start_date=args.start_date, + end_date=args.end_date, + mark_success=args.mark_success, + local=args.local, + donot_pickle=(args.donot_pickle or + conf.getboolean('core', 'donot_pickle')), + ignore_first_depends_on_past=args.ignore_first_depends_on_past, + ignore_task_deps=args.ignore_dependencies, + pool=args.pool, + delay_on_limit_secs=args.delay_on_limit, + verbose=args.verbose, + conf=run_conf, + rerun_failed_tasks=args.rerun_failed_tasks, + run_backwards=args.run_backwards + ) + + +@cli_utils.action_logging +def dag_trigger(args): + """ + Creates a dag run for the specified dag + + :param args: + :return: + """ + api_client = get_current_api_client() + log = LoggingMixin().log + try: + message = api_client.trigger_dag(dag_id=args.dag_id, + run_id=args.run_id, + conf=args.conf, + execution_date=args.exec_date) + except OSError as err: + log.error(err) + raise AirflowException(err) + log.info(message) + + +@cli_utils.action_logging +def dag_delete(args): + """ + Deletes all DB records related to the specified dag + + :param args: + :return: + """ + api_client = get_current_api_client() + log = LoggingMixin().log + if args.yes or input( + "This will drop all existing records related to the specified DAG. " + "Proceed? 
(y/n)").upper() == "Y": + try: + message = api_client.delete_dag(dag_id=args.dag_id) + except OSError as err: + log.error(err) + raise AirflowException(err) + log.info(message) + else: + print("Bail.") + + +@cli_utils.action_logging +def dag_pause(args): + """Pauses a DAG""" + set_is_paused(True, args) + + +@cli_utils.action_logging +def dag_unpause(args): + """Unpauses a DAG""" + set_is_paused(False, args) + + +def set_is_paused(is_paused, args): + """Sets is_paused for DAG by a given dag_id""" + DagModel.get_dagmodel(args.dag_id).set_is_paused( + is_paused=is_paused, + ) + + print("Dag: {}, paused: {}".format(args.dag_id, str(is_paused))) + + +def dag_show(args): + """Displays DAG or saves it's graphic representation to the file""" + dag = get_dag(args) + dot = render_dag(dag) + if args.save: + filename, _, fileformat = args.save.rpartition('.') + dot.render(filename=filename, format=fileformat, cleanup=True) + print("File {} saved".format(args.save)) + elif args.imgcat: + data = dot.pipe(format='png') + try: + proc = subprocess.Popen("imgcat", stdout=subprocess.PIPE, stdin=subprocess.PIPE) + except OSError as e: + if e.errno == errno.ENOENT: + raise AirflowException( + "Failed to execute. Make sure the imgcat executables are on your systems \'PATH\'" + ) + else: + raise + out, err = proc.communicate(data) + if out: + print(out.decode('utf-8')) + if err: + print(err.decode('utf-8')) + else: + print(dot.source) + + +@cli_utils.action_logging +def dag_state(args): + """ + Returns the state of a DagRun at the command line. + >>> airflow dags state tutorial 2015-01-01T00:00:00.000000 + running + """ + dag = get_dag(args) + dr = DagRun.find(dag.dag_id, execution_date=args.execution_date) + print(dr[0].state if len(dr) > 0 else None) # pylint: disable=len-as-condition + + +@cli_utils.action_logging +def dag_next_execution(args): + """ + Returns the next execution datetime of a DAG at the command line. + >>> airflow dags next_execution tutorial + 2018-08-31 10:38:00 + """ + dag = get_dag(args) + + if dag.is_paused: + print("[INFO] Please be reminded this DAG is PAUSED now.") + + if dag.latest_execution_date: + next_execution_dttm = dag.following_schedule(dag.latest_execution_date) + + if next_execution_dttm is None: + print("[WARN] No following schedule can be found. 
" + + "This DAG may have schedule interval '@once' or `None`.") + + print(next_execution_dttm) + else: + print("[WARN] Only applicable when there is execution record found for the DAG.") + print(None) + + +@cli_utils.action_logging +def dag_list_dags(args): + """Displays dags with or without stats at the command line""" + dagbag = DagBag(process_subdir(args.subdir)) + list_template = textwrap.dedent("""\n + ------------------------------------------------------------------- + DAGS + ------------------------------------------------------------------- + {dag_list} + """) + dag_list = "\n".join(sorted(dagbag.dags)) + print(list_template.format(dag_list=dag_list)) + if args.report: + print(dagbag.dagbag_report()) + + +@cli_utils.action_logging +def dag_list_jobs(args, dag=None): + """Lists latest n jobs""" + queries = [] + if dag: + args.dag_id = dag.dag_id + if args.dag_id: + dagbag = DagBag() + + if args.dag_id not in dagbag.dags: + error_message = "Dag id {} not found".format(args.dag_id) + raise AirflowException(error_message) + queries.append(jobs.BaseJob.dag_id == args.dag_id) + + if args.state: + queries.append(jobs.BaseJob.state == args.state) + + with db.create_session() as session: + all_jobs = (session + .query(jobs.BaseJob) + .filter(*queries) + .order_by(jobs.BaseJob.start_date.desc()) + .limit(args.limit) + .all()) + fields = ['dag_id', 'state', 'job_type', 'start_date', 'end_date'] + all_jobs = [[job.__getattribute__(field) for field in fields] for job in all_jobs] + msg = tabulate(all_jobs, + [field.capitalize().replace('_', ' ') for field in fields], + tablefmt=args.output) + print(msg) + + +@cli_utils.action_logging +def dag_list_dag_runs(args, dag=None): + """Lists dag runs for a given DAG""" + if dag: + args.dag_id = dag.dag_id + + dagbag = DagBag() + + if args.dag_id not in dagbag.dags: + error_message = "Dag id {} not found".format(args.dag_id) + raise AirflowException(error_message) + + dag_runs = list() + state = args.state.lower() if args.state else None + for dag_run in DagRun.find(dag_id=args.dag_id, + state=state, + no_backfills=args.no_backfill): + dag_runs.append({ + 'id': dag_run.id, + 'run_id': dag_run.run_id, + 'state': dag_run.state, + 'dag_id': dag_run.dag_id, + 'execution_date': dag_run.execution_date.isoformat(), + 'start_date': ((dag_run.start_date or '') and + dag_run.start_date.isoformat()), + }) + if not dag_runs: + print('No dag runs for {dag_id}'.format(dag_id=args.dag_id)) + + header_template = textwrap.dedent("""\n + {line} + DAG RUNS + {line} + {dag_run_header} + """) + + dag_runs.sort(key=lambda x: x['execution_date'], reverse=True) + dag_run_header = '%-3s | %-20s | %-10s | %-20s | %-20s |' % ('id', + 'run_id', + 'state', + 'execution_date', + 'start_date') + print(header_template.format(dag_run_header=dag_run_header, + line='-' * 120)) + for dag_run in dag_runs: + record = '%-3s | %-20s | %-10s | %-20s | %-20s |' % (dag_run['id'], + dag_run['run_id'], + dag_run['state'], + dag_run['execution_date'], + dag_run['start_date']) + print(record) diff --git a/tests/cli/test_cli.py b/tests/cli/commands/test_dag_command.py similarity index 90% rename from tests/cli/test_cli.py rename to tests/cli/commands/test_dag_command.py index 4060fea013c8b..5dbb4a28e7cf2 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/commands/test_dag_command.py @@ -30,6 +30,7 @@ import airflow.bin.cli as cli from airflow import AirflowException, models, settings +from airflow.cli.commands import dag_command from airflow.models import DagModel from airflow.settings import 
Session from airflow.utils import timezone @@ -113,9 +114,9 @@ def setUpClass(cls): cls.dagbag = models.DagBag(include_examples=True) cls.parser = cli.CLIFactory.get_parser() - @mock.patch("airflow.bin.cli.DAG.run") + @mock.patch("airflow.cli.commands.dag_command.DAG.run") def test_backfill(self, mock_run): - cli.dag_backfill(self.parser.parse_args([ + dag_command.dag_backfill(self.parser.parse_args([ 'dags', 'backfill', 'example_bash_operator', '-s', DEFAULT_DATE.isoformat()])) @@ -138,7 +139,7 @@ def test_backfill(self, mock_run): dag = self.dagbag.get_dag('example_bash_operator') with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: - cli.dag_backfill(self.parser.parse_args([ + dag_command.dag_backfill(self.parser.parse_args([ 'dags', 'backfill', 'example_bash_operator', '-t', 'runme_0', '--dry_run', '-s', DEFAULT_DATE.isoformat()]), dag=dag) @@ -150,13 +151,13 @@ def test_backfill(self, mock_run): mock_run.assert_not_called() # Dry run shouldn't run the backfill - cli.dag_backfill(self.parser.parse_args([ + dag_command.dag_backfill(self.parser.parse_args([ 'dags', 'backfill', 'example_bash_operator', '--dry_run', '-s', DEFAULT_DATE.isoformat()]), dag=dag) mock_run.assert_not_called() # Dry run shouldn't run the backfill - cli.dag_backfill(self.parser.parse_args([ + dag_command.dag_backfill(self.parser.parse_args([ 'dags', 'backfill', 'example_bash_operator', '-l', '-s', DEFAULT_DATE.isoformat()]), dag=dag) @@ -180,18 +181,18 @@ def test_backfill(self, mock_run): def test_show_dag_print(self): temp_stdout = io.StringIO() with contextlib.redirect_stdout(temp_stdout): - cli.dag_show(self.parser.parse_args([ + dag_command.dag_show(self.parser.parse_args([ 'dags', 'show', 'example_bash_operator'])) out = temp_stdout.getvalue() self.assertIn("label=example_bash_operator", out) self.assertIn("graph [label=example_bash_operator labelloc=t rankdir=LR]", out) self.assertIn("runme_2 -> run_after_loop", out) - @mock.patch("airflow.bin.cli.render_dag") + @mock.patch("airflow.cli.commands.dag_command.render_dag") def test_show_dag_dave(self, mock_render_dag): temp_stdout = io.StringIO() with contextlib.redirect_stdout(temp_stdout): - cli.dag_show(self.parser.parse_args([ + dag_command.dag_show(self.parser.parse_args([ 'dags', 'show', 'example_bash_operator', '--save', 'awesome.png'] )) out = temp_stdout.getvalue() @@ -200,14 +201,14 @@ def test_show_dag_dave(self, mock_render_dag): ) self.assertIn("File awesome.png saved", out) - @mock.patch("airflow.bin.cli.subprocess.Popen") - @mock.patch("airflow.bin.cli.render_dag") + @mock.patch("airflow.cli.commands.dag_command.subprocess.Popen") + @mock.patch("airflow.cli.commands.dag_command.render_dag") def test_show_dag_imgcat(self, mock_render_dag, mock_popen): mock_render_dag.return_value.pipe.return_value = b"DOT_DATA" mock_popen.return_value.communicate.return_value = (b"OUT", b"ERR") temp_stdout = io.StringIO() with contextlib.redirect_stdout(temp_stdout): - cli.dag_show(self.parser.parse_args([ + dag_command.dag_show(self.parser.parse_args([ 'dags', 'show', 'example_bash_operator', '--imgcat'] )) out = temp_stdout.getvalue() @@ -216,7 +217,7 @@ def test_show_dag_imgcat(self, mock_render_dag, mock_popen): self.assertIn("OUT", out) self.assertIn("ERR", out) - @mock.patch("airflow.bin.cli.DAG.run") + @mock.patch("airflow.cli.commands.dag_command.DAG.run") def test_cli_backfill_depends_on_past(self, mock_run): """ Test that CLI respects -I argument @@ -237,7 +238,7 @@ def test_cli_backfill_depends_on_past(self, mock_run): ] dag = 
self.dagbag.get_dag(dag_id) - cli.dag_backfill(self.parser.parse_args(args), dag=dag) + dag_command.dag_backfill(self.parser.parse_args(args), dag=dag) mock_run.assert_called_once_with( start_date=run_date, @@ -255,7 +256,7 @@ def test_cli_backfill_depends_on_past(self, mock_run): verbose=False, ) - @mock.patch("airflow.bin.cli.DAG.run") + @mock.patch("airflow.cli.commands.dag_command.DAG.run") def test_cli_backfill_depends_on_past_backwards(self, mock_run): """ Test that CLI respects -B argument and raises on interaction with depends_on_past @@ -277,7 +278,7 @@ def test_cli_backfill_depends_on_past_backwards(self, mock_run): ] dag = self.dagbag.get_dag(dag_id) - cli.dag_backfill(self.parser.parse_args(args), dag=dag) + dag_command.dag_backfill(self.parser.parse_args(args), dag=dag) mock_run.assert_called_once_with( start_date=start_date, end_date=end_date, @@ -362,15 +363,15 @@ def reset_dr_db(dag_id): def test_cli_list_dags(self): args = self.parser.parse_args(['dags', 'list', '--report']) - cli.dag_list_dags(args) + dag_command.dag_list_dags(args) def test_cli_list_dag_runs(self): - cli.dag_trigger(self.parser.parse_args([ + dag_command.dag_trigger(self.parser.parse_args([ 'dags', 'trigger', 'example_bash_operator', ])) args = self.parser.parse_args(['dags', 'list_runs', 'example_bash_operator', '--no_backfill']) - cli.dag_list_dag_runs(args) + dag_command.dag_list_dag_runs(args) def test_cli_list_jobs_with_args(self): args = self.parser.parse_args(['dags', 'list_jobs', '--dag_id', @@ -378,26 +379,26 @@ def test_cli_list_jobs_with_args(self): '--state', 'success', '--limit', '100', '--output', 'tsv']) - cli.dag_list_jobs(args) + dag_command.dag_list_jobs(args) def test_pause(self): args = self.parser.parse_args([ 'dags', 'pause', 'example_bash_operator']) - cli.dag_pause(args) + dag_command.dag_pause(args) self.assertIn(self.dagbag.dags['example_bash_operator'].is_paused, [True, 1]) args = self.parser.parse_args([ 'dags', 'unpause', 'example_bash_operator']) - cli.dag_unpause(args) + dag_command.dag_unpause(args) self.assertIn(self.dagbag.dags['example_bash_operator'].is_paused, [False, 0]) def test_trigger_dag(self): - cli.dag_trigger(self.parser.parse_args([ + dag_command.dag_trigger(self.parser.parse_args([ 'dags', 'trigger', 'example_bash_operator', '-c', '{"foo": "bar"}'])) self.assertRaises( ValueError, - cli.dag_trigger, + dag_command.dag_trigger, self.parser.parse_args([ 'dags', 'trigger', 'example_bash_operator', '--run_id', 'trigger_dag_xxx', @@ -410,12 +411,12 @@ def test_delete_dag(self): session = settings.Session() session.add(DM(dag_id=key)) session.commit() - cli.dag_delete(self.parser.parse_args([ + dag_command.dag_delete(self.parser.parse_args([ 'dags', 'delete', key, '--yes'])) self.assertEqual(session.query(DM).filter_by(dag_id=key).count(), 0) self.assertRaises( AirflowException, - cli.dag_delete, + dag_command.dag_delete, self.parser.parse_args([ 'dags', 'delete', 'does_not_exist_dag', @@ -431,14 +432,14 @@ def test_delete_dag_existing_file(self): with tempfile.NamedTemporaryFile() as f: session.add(DM(dag_id=key, fileloc=f.name)) session.commit() - cli.dag_delete(self.parser.parse_args([ + dag_command.dag_delete(self.parser.parse_args([ 'dags', 'delete', key, '--yes'])) self.assertEqual(session.query(DM).filter_by(dag_id=key).count(), 0) def test_cli_list_jobs(self): args = self.parser.parse_args(['dags', 'list_jobs']) - cli.dag_list_jobs(args) + dag_command.dag_list_jobs(args) def test_dag_state(self): - self.assertEqual(None, 
cli.dag_state(self.parser.parse_args([ + self.assertEqual(None, dag_command.dag_state(self.parser.parse_args([ 'dags', 'state', 'example_bash_operator', DEFAULT_DATE.isoformat()]))) From 494491e500ae5a1da1902df06df73cf6a5188167 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Sun, 17 Nov 2019 07:20:10 +0100 Subject: [PATCH 17/23] [AIRLFOW-YYY] Move serve logs command to separate file --- airflow/bin/cli.py | 26 ++------------ airflow/cli/commands/serve_logs_command.py | 42 ++++++++++++++++++++++ 2 files changed, 45 insertions(+), 23 deletions(-) create mode 100644 airflow/cli/commands/serve_logs_command.py diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 1464450b9bc22..09299f47523b3 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -32,8 +32,8 @@ from airflow import api, settings from airflow.cli.commands import ( connection_command, dag_command, db_command, pool_command, role_command, rotate_fernet_key_command, - scheduler_command, sync_perm_command, task_command, user_command, variable_command, version_command, - webserver_command, worker_command, + scheduler_command, serve_logs_command, sync_perm_command, task_command, user_command, variable_command, + version_command, webserver_command, worker_command, ) from airflow.configuration import conf from airflow.utils import cli as cli_utils @@ -48,26 +48,6 @@ DAGS_FOLDER = '[AIRFLOW_HOME]/dags' -@cli_utils.action_logging -def serve_logs(args): - """Serves logs generated by Worker""" - print("Starting flask") - import flask - flask_app = flask.Flask(__name__) - - @flask_app.route('/log/') - def serve_logs(filename): # pylint: disable=unused-variable, redefined-outer-name - log = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER')) - return flask.send_from_directory( - log, - filename, - mimetype="application/json", - as_attachment=False) - - worker_log_server_port = int(conf.get('celery', 'WORKER_LOG_SERVER_PORT')) - flask_app.run(host='0.0.0.0', port=worker_log_server_port) - - @cli_utils.action_logging def flower(args): """Starts Flower, Celery monitoring tool""" @@ -930,7 +910,7 @@ class CLIFactory: 'args': ('principal', 'keytab', 'pid', 'daemon', 'stdout', 'stderr', 'log_file'), }, { - 'func': serve_logs, + 'func': serve_logs_command.serve_logs, 'help': "Serve logs generate by worker", 'args': tuple(), }, { diff --git a/airflow/cli/commands/serve_logs_command.py b/airflow/cli/commands/serve_logs_command.py new file mode 100644 index 0000000000000..86e29464124d8 --- /dev/null +++ b/airflow/cli/commands/serve_logs_command.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
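The module body that follows is essentially a one-route Flask app, and the interesting call is flask.send_from_directory, which joins the requested filename onto a base directory and rejects paths that escape it. A standalone sketch of that pattern (the directory and port here are illustrative, not the configured values):

    import flask

    app = flask.Flask(__name__)
    LOG_DIR = "/tmp/airflow-logs"  # stand-in for conf.get('core', 'BASE_LOG_FOLDER')


    @app.route('/log/<path:filename>')
    def serve_log(filename):
        # send_from_directory refuses '..' traversal, so exposing a <path:...>
        # converter here is safe.
        return flask.send_from_directory(LOG_DIR, filename,
                                         mimetype="application/json",
                                         as_attachment=False)


    if __name__ == "__main__":
        app.run(host="0.0.0.0", port=8793)  # example port only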
+ +"""Serve logs command""" +import os + +from airflow import conf +from airflow.utils import cli as cli_utils + + +@cli_utils.action_logging +def serve_logs(args): + """Serves logs generated by Worker""" + print("Starting flask") + import flask + flask_app = flask.Flask(__name__) + + @flask_app.route('/log/') + def serve_logs(filename): # pylint: disable=unused-variable, redefined-outer-name + log = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER')) + return flask.send_from_directory( + log, + filename, + mimetype="application/json", + as_attachment=False) + + worker_log_server_port = int(conf.get('celery', 'WORKER_LOG_SERVER_PORT')) + flask_app.run(host='0.0.0.0', port=worker_log_server_port) From 6da4362e7a86dc0645a073cf2491fa171b4ed744 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Sun, 17 Nov 2019 07:25:59 +0100 Subject: [PATCH 18/23] [AIRLFOW-YYY] Move flower command to separate file --- airflow/bin/cli.py | 58 ++------------------ airflow/cli/commands/flower_command.py | 75 ++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 53 deletions(-) create mode 100644 airflow/cli/commands/flower_command.py diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 09299f47523b3..810ab513d7316 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -21,7 +21,6 @@ import argparse import os -import signal import textwrap from argparse import RawTextHelpFormatter @@ -31,13 +30,13 @@ from airflow import api, settings from airflow.cli.commands import ( - connection_command, dag_command, db_command, pool_command, role_command, rotate_fernet_key_command, - scheduler_command, serve_logs_command, sync_perm_command, task_command, user_command, variable_command, - version_command, webserver_command, worker_command, + connection_command, dag_command, db_command, flower_command, pool_command, role_command, + rotate_fernet_key_command, scheduler_command, serve_logs_command, sync_perm_command, task_command, + user_command, variable_command, version_command, webserver_command, worker_command, ) from airflow.configuration import conf from airflow.utils import cli as cli_utils -from airflow.utils.cli import alternative_conn_specs, setup_locations, sigint_handler +from airflow.utils.cli import alternative_conn_specs, setup_locations from airflow.utils.timezone import parse as parsedate api.load_auth() @@ -48,53 +47,6 @@ DAGS_FOLDER = '[AIRFLOW_HOME]/dags' -@cli_utils.action_logging -def flower(args): - """Starts Flower, Celery monitoring tool""" - broka = conf.get('celery', 'BROKER_URL') - address = '--address={}'.format(args.hostname) - port = '--port={}'.format(args.port) - api = '' # pylint: disable=redefined-outer-name - if args.broker_api: - api = '--broker_api=' + args.broker_api - - url_prefix = '' - if args.url_prefix: - url_prefix = '--url-prefix=' + args.url_prefix - - basic_auth = '' - if args.basic_auth: - basic_auth = '--basic_auth=' + args.basic_auth - - flower_conf = '' - if args.flower_conf: - flower_conf = '--conf=' + args.flower_conf - - if args.daemon: - pid, stdout, stderr, _ = setup_locations("flower", args.pid, args.stdout, args.stderr, args.log_file) - stdout = open(stdout, 'w+') - stderr = open(stderr, 'w+') - - ctx = daemon.DaemonContext( - pidfile=TimeoutPIDLockFile(pid, -1), - stdout=stdout, - stderr=stderr, - ) - - with ctx: - os.execvp("flower", ['flower', '-b', - broka, address, port, api, flower_conf, url_prefix, basic_auth]) - - stdout.close() - stderr.close() - else: - signal.signal(signal.SIGINT, sigint_handler) - 
signal.signal(signal.SIGTERM, sigint_handler) - - os.execvp("flower", ['flower', '-b', - broka, address, port, api, flower_conf, url_prefix, basic_auth]) - - @cli_utils.action_logging def kerberos(args): """Start a kerberos ticket renewer""" @@ -931,7 +883,7 @@ class CLIFactory: 'args': ('do_pickle', 'queues', 'concurrency', 'celery_hostname', 'pid', 'daemon', 'stdout', 'stderr', 'log_file', 'autoscale'), }, { - 'func': flower, + 'func': flower_command.flower, 'help': "Start a Celery Flower", 'args': ('flower_hostname', 'flower_port', 'flower_conf', 'flower_url_prefix', 'flower_basic_auth', 'broker_api', 'pid', 'daemon', 'stdout', 'stderr', 'log_file'), diff --git a/airflow/cli/commands/flower_command.py b/airflow/cli/commands/flower_command.py new file mode 100644 index 0000000000000..da0eef5a8b36d --- /dev/null +++ b/airflow/cli/commands/flower_command.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Flower command""" +import os +import signal + +import daemon +from daemon.pidfile import TimeoutPIDLockFile + +from airflow import conf +from airflow.utils import cli as cli_utils +from airflow.utils.cli import setup_locations, sigint_handler + + +@cli_utils.action_logging +def flower(args): + """Starts Flower, Celery monitoring tool""" + broka = conf.get('celery', 'BROKER_URL') + address = '--address={}'.format(args.hostname) + port = '--port={}'.format(args.port) + api = '' # pylint: disable=redefined-outer-name + if args.broker_api: + api = '--broker_api=' + args.broker_api + + url_prefix = '' + if args.url_prefix: + url_prefix = '--url-prefix=' + args.url_prefix + + basic_auth = '' + if args.basic_auth: + basic_auth = '--basic_auth=' + args.basic_auth + + flower_conf = '' + if args.flower_conf: + flower_conf = '--conf=' + args.flower_conf + + if args.daemon: + pid, stdout, stderr, _ = setup_locations("flower", args.pid, args.stdout, args.stderr, args.log_file) + stdout = open(stdout, 'w+') + stderr = open(stderr, 'w+') + + ctx = daemon.DaemonContext( + pidfile=TimeoutPIDLockFile(pid, -1), + stdout=stdout, + stderr=stderr, + ) + + with ctx: + os.execvp("flower", ['flower', '-b', + broka, address, port, api, flower_conf, url_prefix, basic_auth]) + + stdout.close() + stderr.close() + else: + signal.signal(signal.SIGINT, sigint_handler) + signal.signal(signal.SIGTERM, sigint_handler) + + os.execvp("flower", ['flower', '-b', + broka, address, port, api, flower_conf, url_prefix, basic_auth]) From c7ab4faccf46c9d684609d6eec7b51ad4b831326 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Sun, 17 Nov 2019 07:31:03 +0100 Subject: [PATCH 19/23] [AIRLFOW-YYY] Move kerberos command to separate file --- airflow/bin/cli.py | 37 ++--------------- airflow/cli/commands/kerberos_command.py | 52 
++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 34 deletions(-) create mode 100644 airflow/cli/commands/kerberos_command.py diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 810ab513d7316..7d29e618d0321 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -24,19 +24,16 @@ import textwrap from argparse import RawTextHelpFormatter -import daemon -from daemon.pidfile import TimeoutPIDLockFile from tabulate import tabulate_formats from airflow import api, settings from airflow.cli.commands import ( - connection_command, dag_command, db_command, flower_command, pool_command, role_command, + connection_command, dag_command, db_command, flower_command, kerberos_command, pool_command, role_command, rotate_fernet_key_command, scheduler_command, serve_logs_command, sync_perm_command, task_command, user_command, variable_command, version_command, webserver_command, worker_command, ) from airflow.configuration import conf -from airflow.utils import cli as cli_utils -from airflow.utils.cli import alternative_conn_specs, setup_locations +from airflow.utils.cli import alternative_conn_specs from airflow.utils.timezone import parse as parsedate api.load_auth() @@ -47,34 +44,6 @@ DAGS_FOLDER = '[AIRFLOW_HOME]/dags' -@cli_utils.action_logging -def kerberos(args): - """Start a kerberos ticket renewer""" - print(settings.HEADER) - import airflow.security.kerberos # pylint: disable=redefined-outer-name - - if args.daemon: - pid, stdout, stderr, _ = setup_locations( - "kerberos", args.pid, args.stdout, args.stderr, args.log_file - ) - stdout = open(stdout, 'w+') - stderr = open(stderr, 'w+') - - ctx = daemon.DaemonContext( - pidfile=TimeoutPIDLockFile(pid, -1), - stdout=stdout, - stderr=stderr, - ) - - with ctx: - airflow.security.kerberos.run(principal=args.principal, keytab=args.keytab) - - stdout.close() - stderr.close() - else: - airflow.security.kerberos.run(principal=args.principal, keytab=args.keytab) - - class Arg: """Class to keep information about command line argument""" # pylint: disable=redefined-builtin @@ -857,7 +826,7 @@ class CLIFactory: }, ), }, { - 'func': kerberos, + 'func': kerberos_command.kerberos, 'help': "Start a kerberos ticket renewer", 'args': ('principal', 'keytab', 'pid', 'daemon', 'stdout', 'stderr', 'log_file'), diff --git a/airflow/cli/commands/kerberos_command.py b/airflow/cli/commands/kerberos_command.py new file mode 100644 index 0000000000000..a75c596183f35 --- /dev/null +++ b/airflow/cli/commands/kerberos_command.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
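kerberos_command below, flower_command above, and the webserver monitor all daemonize through the same few lines: resolve the pid/stdout/stderr/log paths with setup_locations, then run the real work inside python-daemon's DaemonContext guarded by a TimeoutPIDLockFile. Factored out, the shared shape is roughly this (names are illustrative, not Airflow's):

    import daemon
    from daemon.pidfile import TimeoutPIDLockFile


    def run_daemonized(pid_path, stdout_path, stderr_path, target, *args, **kwargs):
        """Run `target` detached from the terminal, with a pid file and redirected streams."""
        stdout = open(stdout_path, 'w+')
        stderr = open(stderr_path, 'w+')
        try:
            ctx = daemon.DaemonContext(
                # the pid lock file keeps a second daemonized instance from starting
                pidfile=TimeoutPIDLockFile(pid_path, -1),
                stdout=stdout,
                stderr=stderr,
            )
            with ctx:
                target(*args, **kwargs)
        finally:
            stdout.close()
            stderr.close()

DaemonContext performs the double fork, becomes a session leader and redirects the standard streams, which is why the foreground branches of these commands skip it entirely.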
+ +"""Kerberos command""" +import daemon +from daemon.pidfile import TimeoutPIDLockFile + +from airflow import settings +from airflow.security import kerberos as krb +from airflow.utils import cli as cli_utils +from airflow.utils.cli import setup_locations + + +@cli_utils.action_logging +def kerberos(args): + """Start a kerberos ticket renewer""" + print(settings.HEADER) + + if args.daemon: + pid, stdout, stderr, _ = setup_locations( + "kerberos", args.pid, args.stdout, args.stderr, args.log_file + ) + stdout = open(stdout, 'w+') + stderr = open(stderr, 'w+') + + ctx = daemon.DaemonContext( + pidfile=TimeoutPIDLockFile(pid, -1), + stdout=stdout, + stderr=stderr, + ) + + with ctx: + krb.run(principal=args.principal, keytab=args.keytab) + + stdout.close() + stderr.close() + else: + krb.run(principal=args.principal, keytab=args.keytab) From 4f9c2e6ccf5c1a234fb93d0694aaf5036b6e2643 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Sun, 17 Nov 2019 07:47:25 +0100 Subject: [PATCH 20/23] [AIRFLOW-YYY] Lazy load CLI commands --- airflow/bin/cli.py | 143 ++++++++++++++++++++++++++------------------- 1 file changed, 82 insertions(+), 61 deletions(-) diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 7d29e618d0321..094954382c96c 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -23,17 +23,14 @@ import os import textwrap from argparse import RawTextHelpFormatter +from typing import Callable from tabulate import tabulate_formats from airflow import api, settings -from airflow.cli.commands import ( - connection_command, dag_command, db_command, flower_command, kerberos_command, pool_command, role_command, - rotate_fernet_key_command, scheduler_command, serve_logs_command, sync_perm_command, task_command, - user_command, variable_command, version_command, webserver_command, worker_command, -) from airflow.configuration import conf from airflow.utils.cli import alternative_conn_specs +from airflow.utils.module_loading import import_string from airflow.utils.timezone import parse as parsedate api.load_auth() @@ -44,6 +41,19 @@ DAGS_FOLDER = '[AIRFLOW_HOME]/dags' +def lazy_load_command(import_path: str) -> Callable: + """Create a lazy loader for command""" + _, _, name = import_path.rpartition('.') + + def command(*args, **kwargs): + func = import_string(import_path) + return func(*args, **kwargs) + + command.__name__ = name # type: ignore + + return command + + class Arg: """Class to keep information about command line argument""" # pylint: disable=redefined-builtin @@ -574,13 +584,13 @@ class CLIFactory: 'name': 'dags', 'subcommands': ( { - 'func': dag_command.dag_list_dags, + 'func': lazy_load_command('airflow.cli.commands.dag_command.dag_list_dags'), 'name': 'list', 'help': "List all the DAGs", 'args': ('subdir', 'report'), }, { - 'func': dag_command.dag_list_dag_runs, + 'func': lazy_load_command('airflow.cli.commands.dag_command.dag_list_dag_runs'), 'name': 'list_runs', 'help': "List dag runs given a DAG id. If state option is given, it will only " "search for all the dagruns with the given state. 
" @@ -589,55 +599,55 @@ class CLIFactory: 'args': ('dag_id', 'no_backfill', 'state'), }, { - 'func': dag_command.dag_list_jobs, + 'func': lazy_load_command('airflow.cli.commands.dag_command.dag_list_jobs'), 'name': 'list_jobs', 'help': "List the jobs", 'args': ('dag_id_opt', 'state', 'limit', 'output',), }, { - 'func': dag_command.dag_state, + 'func': lazy_load_command('airflow.cli.commands.dag_command.dag_state'), 'name': 'state', 'help': "Get the status of a dag run", 'args': ('dag_id', 'execution_date', 'subdir'), }, { - 'func': dag_command.dag_next_execution, + 'func': lazy_load_command('airflow.cli.commands.dag_command.dag_next_execution'), 'name': 'next_execution', 'help': "Get the next execution datetime of a DAG.", 'args': ('dag_id', 'subdir'), }, { - 'func': dag_command.dag_pause, + 'func': lazy_load_command('airflow.cli.commands.dag_command.dag_pause'), 'name': 'pause', 'help': 'Pause a DAG', 'args': ('dag_id', 'subdir'), }, { - 'func': dag_command.dag_unpause, + 'func': lazy_load_command('airflow.cli.commands.dag_command.dag_unpause'), 'name': 'unpause', 'help': 'Resume a paused DAG', 'args': ('dag_id', 'subdir'), }, { - 'func': dag_command.dag_trigger, + 'func': lazy_load_command('airflow.cli.commands.dag_command.dag_trigger'), 'name': 'trigger', 'help': 'Trigger a DAG run', 'args': ('dag_id', 'subdir', 'run_id', 'conf', 'exec_date'), }, { - 'func': dag_command.dag_delete, + 'func': lazy_load_command('airflow.cli.commands.dag_command.dag_delete'), 'name': 'delete', 'help': "Delete all DB records related to the specified DAG", 'args': ('dag_id', 'yes'), }, { - 'func': dag_command.dag_show, + 'func': lazy_load_command('airflow.cli.commands.dag_command.dag_show'), 'name': 'show', 'help': "Displays DAG's tasks with their dependencies", 'args': ('dag_id', 'subdir', 'save', 'imgcat',), }, { - 'func': dag_command.dag_backfill, + 'func': lazy_load_command('airflow.cli.commands.dag_command.dag_backfill'), 'name': 'backfill', 'help': "Run subsections of a DAG for a specified date range. 
" "If reset_dag_run option is used," @@ -661,13 +671,13 @@ class CLIFactory: 'name': 'tasks', 'subcommands': ( { - 'func': task_command.task_list, + 'func': lazy_load_command('airflow.cli.commands.task_command.task_list'), 'name': 'list', 'help': "List the tasks within a DAG", 'args': ('dag_id', 'tree', 'subdir'), }, { - 'func': task_command.task_clear, + 'func': lazy_load_command('airflow.cli.commands.task_command.task_clear'), 'name': 'clear', 'help': "Clear a set of task instance, as if they never ran", 'args': ( @@ -676,13 +686,13 @@ class CLIFactory: 'only_running', 'exclude_subdags', 'exclude_parentdag', 'dag_regex'), }, { - 'func': task_command.task_state, + 'func': lazy_load_command('airflow.cli.commands.task_command.task_state'), 'name': 'state', 'help': "Get the status of a task instance", 'args': ('dag_id', 'task_id', 'execution_date', 'subdir'), }, { - 'func': task_command.task_failed_deps, + 'func': lazy_load_command('airflow.cli.commands.task_command.task_failed_deps'), 'name': 'failed_deps', 'help': ( "Returns the unmet dependencies for a task instance from the perspective " @@ -692,13 +702,13 @@ class CLIFactory: 'args': ('dag_id', 'task_id', 'execution_date', 'subdir'), }, { - 'func': task_command.task_render, + 'func': lazy_load_command('airflow.cli.commands.task_command.task_render'), 'name': 'render', 'help': "Render a task instance's template(s)", 'args': ('dag_id', 'task_id', 'execution_date', 'subdir'), }, { - 'func': task_command.task_run, + 'func': lazy_load_command('airflow.cli.commands.task_command.task_run'), 'name': 'run', 'help': "Run a single task instance", 'args': ( @@ -708,7 +718,7 @@ class CLIFactory: 'ignore_depends_on_past', 'ship_dag', 'pickle', 'job_id', 'interactive',), }, { - 'func': task_command.task_test, + 'func': lazy_load_command('airflow.cli.commands.task_command.task_test'), 'name': 'test', 'help': ( "Test a task instance. 
This will run a task without checking for " @@ -723,37 +733,37 @@ class CLIFactory: 'name': 'pools', 'subcommands': ( { - 'func': pool_command.pool_list, + 'func': lazy_load_command('airflow.cli.commands.pool_command.pool_list'), 'name': 'list', 'help': 'List pools', 'args': ('output',), }, { - 'func': pool_command.pool_get, + 'func': lazy_load_command('airflow.cli.commands.pool_command.pool_get'), 'name': 'get', 'help': 'Get pool size', 'args': ('pool_name', 'output',), }, { - 'func': pool_command.pool_set, + 'func': lazy_load_command('airflow.cli.commands.pool_command.pool_set'), 'name': 'set', 'help': 'Configure pool', 'args': ('pool_name', 'pool_slots', 'pool_description', 'output',), }, { - 'func': pool_command.pool_delete, + 'func': lazy_load_command('airflow.cli.commands.pool_command.pool_delete'), 'name': 'delete', 'help': 'Delete pool', 'args': ('pool_name', 'output',), }, { - 'func': pool_command.pool_import, + 'func': lazy_load_command('airflow.cli.commands.pool_command.pool_import'), 'name': 'import', 'help': 'Import pool', 'args': ('pool_import', 'output',), }, { - 'func': pool_command.pool_export, + 'func': lazy_load_command('airflow.cli.commands.pool_command.pool_export'), 'name': 'export', 'help': 'Export pool', 'args': ('pool_export', 'output',), @@ -764,37 +774,37 @@ class CLIFactory: 'name': 'variables', 'subcommands': ( { - 'func': variable_command.variables_list, + 'func': lazy_load_command('airflow.cli.commands.variable_command.variables_list'), 'name': 'list', 'help': 'List variables', 'args': (), }, { - 'func': variable_command.variables_get, + 'func': lazy_load_command('airflow.cli.commands.variable_command.variables_get'), 'name': 'get', 'help': 'Get variable', 'args': ('var', 'json', 'default'), }, { - 'func': variable_command.variables_set, + 'func': lazy_load_command('airflow.cli.commands.variable_command.variables_set'), 'name': 'set', 'help': 'Set variable', 'args': ('var', 'var_value', 'json'), }, { - 'func': variable_command.variables_delete, + 'func': lazy_load_command('airflow.cli.commands.variable_command.variables_delete'), 'name': 'delete', 'help': 'Delete variable', 'args': ('var',), }, { - 'func': variable_command.variables_import, + 'func': lazy_load_command('airflow.cli.commands.variable_command.variables_import'), 'name': 'import', 'help': 'Import variables', 'args': ('var_import',), }, { - 'func': variable_command.variables_export, + 'func': lazy_load_command('airflow.cli.commands.variable_command.variables_export'), 'name': 'export', 'help': 'Export variables', 'args': ('var_export',), @@ -807,57 +817,64 @@ class CLIFactory: 'name': 'db', 'subcommands': ( { - 'func': db_command.initdb, + 'func': lazy_load_command('airflow.cli.commands.db_command.initdb'), 'name': 'init', 'help': "Initialize the metadata database", 'args': (), }, { - 'func': db_command.resetdb, + 'func': lazy_load_command('airflow.cli.commands.db_command.resetdb'), 'name': 'reset', 'help': "Burn down and rebuild the metadata database", 'args': ('yes',), }, { - 'func': db_command.upgradedb, + 'func': lazy_load_command('airflow.cli.commands.db_command.upgradedb'), 'name': 'upgrade', 'help': "Upgrade the metadata database to latest version", 'args': tuple(), }, ), }, { - 'func': kerberos_command.kerberos, + 'name': 'kerberos', + 'func': lazy_load_command('airflow.cli.commands.kerberos_command.kerberos'), 'help': "Start a kerberos ticket renewer", 'args': ('principal', 'keytab', 'pid', 'daemon', 'stdout', 'stderr', 'log_file'), }, { - 'func': serve_logs_command.serve_logs, + 'name': 
'serve_logs', + 'func': lazy_load_command('airflow.cli.commands.serve_logs_command.serve_logs'), 'help': "Serve logs generate by worker", 'args': tuple(), }, { - 'func': webserver_command.webserver, + 'name': 'webserver', + 'func': lazy_load_command('airflow.cli.commands.webserver_command.webserver'), 'help': "Start a Airflow webserver instance", 'args': ('port', 'workers', 'workerclass', 'worker_timeout', 'hostname', 'pid', 'daemon', 'stdout', 'stderr', 'access_logfile', 'error_logfile', 'log_file', 'ssl_cert', 'ssl_key', 'debug'), }, { - 'func': scheduler_command.scheduler, + 'name': 'scheduler', + 'func': lazy_load_command('airflow.cli.commands.scheduler_command.scheduler'), 'help': "Start a scheduler instance", 'args': ('dag_id_opt', 'subdir', 'num_runs', 'do_pickle', 'pid', 'daemon', 'stdout', 'stderr', 'log_file'), }, { - 'func': worker_command.worker, + 'name': 'worker', + 'func': lazy_load_command('airflow.cli.commands.worker_command.worker'), 'help': "Start a Celery worker node", 'args': ('do_pickle', 'queues', 'concurrency', 'celery_hostname', 'pid', 'daemon', 'stdout', 'stderr', 'log_file', 'autoscale'), }, { - 'func': flower_command.flower, + 'name': 'flower', + 'func': lazy_load_command('airflow.cli.commands.flower_command.flower'), 'help': "Start a Celery Flower", 'args': ('flower_hostname', 'flower_port', 'flower_conf', 'flower_url_prefix', 'flower_basic_auth', 'broker_api', 'pid', 'daemon', 'stdout', 'stderr', 'log_file'), }, { - 'func': version_command.version, + 'name': 'version', + 'func': lazy_load_command('airflow.cli.commands.version_command.version'), 'help': "Show the version", 'args': tuple(), }, { @@ -865,19 +882,19 @@ class CLIFactory: 'name': 'connections', 'subcommands': ( { - 'func': connection_command.connections_list, + 'func': lazy_load_command('airflow.cli.commands.connection_command.connections_list'), 'name': 'list', 'help': 'List connections', 'args': ('output',), }, { - 'func': connection_command.connections_add, + 'func': lazy_load_command('airflow.cli.commands.connection_command.connections_add'), 'name': 'add', 'help': 'Add a connection', 'args': ('conn_id', 'conn_uri', 'conn_extra') + tuple(alternative_conn_specs), }, { - 'func': connection_command.connections_delete, + 'func': lazy_load_command('airflow.cli.commands.connection_command.connections_delete'), 'name': 'delete', 'help': 'Delete a connection', 'args': ('conn_id',), @@ -888,44 +905,44 @@ class CLIFactory: 'name': 'users', 'subcommands': ( { - 'func': user_command.users_list, + 'func': lazy_load_command('airflow.cli.commands.user_command.users_list'), 'name': 'list', 'help': 'List users', 'args': ('output',), }, { - 'func': user_command.users_create, + 'func': lazy_load_command('airflow.cli.commands.user_command.users_create'), 'name': 'create', 'help': 'Create a user', 'args': ('role', 'username', 'email', 'firstname', 'lastname', 'password', 'use_random_password') }, { - 'func': user_command.users_delete, + 'func': lazy_load_command('airflow.cli.commands.user_command.users_delete'), 'name': 'delete', 'help': 'Delete a user', 'args': ('username',), }, { - 'func': user_command.add_role, + 'func': lazy_load_command('airflow.cli.commands.user_command.add_role'), 'name': 'add_role', 'help': 'Add role to a user', 'args': ('username_optional', 'email_optional', 'role'), }, { - 'func': user_command.remove_role, + 'func': lazy_load_command('airflow.cli.commands.user_command.remove_role'), 'name': 'remove_role', 'help': 'Remove role from a user', 'args': ('username_optional', 'email_optional', 
'role'), }, { - 'func': user_command.users_import, + 'func': lazy_load_command('airflow.cli.commands.user_command.users_import'), 'name': 'import', 'help': 'Import a user', 'args': ('user_import',), }, { - 'func': user_command.users_export, + 'func': lazy_load_command('airflow.cli.commands.user_command.users_export'), 'name': 'export', 'help': 'Export a user', 'args': ('user_export',), @@ -936,32 +953,34 @@ class CLIFactory: 'name': 'roles', 'subcommands': ( { - 'func': role_command.roles_list, + 'func': lazy_load_command('airflow.cli.commands.role_command.roles_list'), 'name': 'list', 'help': 'List roles', 'args': ('output',), }, { - 'func': role_command.roles_create, + 'func': lazy_load_command('airflow.cli.commands.role_command.roles_create'), 'name': 'create', 'help': 'Create role', 'args': ('roles',), }, ), }, { - 'func': sync_perm_command.sync_perm, + 'name': 'sync_perm', + 'func': lazy_load_command('airflow.cli.commands.sync_perm_command.sync_perm'), 'help': "Update permissions for existing roles and DAGs.", 'args': tuple(), }, { - 'func': rotate_fernet_key_command.rotate_fernet_key, + 'name': 'rotate_fernet_key', + 'func': lazy_load_command('airflow.cli.commands.rotate_fernet_key_command.rotate_fernet_key'), 'help': 'Rotate all encrypted connection credentials and variables; see ' 'https://airflow.readthedocs.io/en/stable/howto/secure-connections.html' '#rotating-encryption-keys.', 'args': (), }, ) - subparsers_dict = {sp.get('name') or sp['func'].__name__: sp for sp in subparsers} + subparsers_dict = {sp.get('name') or sp['func'].__name__: sp for sp in subparsers} # type: ignore dag_subparsers = ( 'list_tasks', 'backfill', 'test', 'run', 'pause', 'unpause', 'list_dag_runs') @@ -982,7 +1001,9 @@ def get_parser(cls, dag_parser=False): @classmethod def _add_subcommand(cls, subparsers, sub): dag_parser = False - sub_proc = subparsers.add_parser(sub.get('name') or sub['func'].__name__, help=sub['help']) + sub_proc = subparsers.add_parser( + sub.get('name') or sub['func'].__name__, help=sub['help'] # type: ignore + ) sub_proc.formatter_class = RawTextHelpFormatter subcommands = sub.get('subcommands', []) From ebd350fa874569958ca9b94015ea54851af6ec71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Sun, 17 Nov 2019 09:34:44 +0100 Subject: [PATCH 21/23] [AIRFLOW-YYY] Fix migration --- airflow/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/airflow/__init__.py b/airflow/__init__.py index 93d48d1931271..037da89237a8f 100644 --- a/airflow/__init__.py +++ b/airflow/__init__.py @@ -37,7 +37,8 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models import DAG - +# Load SQLAlchemy models during package initialization +from airflow import jobs # noqa: F401 __version__ = version.version From 504086dd1cbc4cd9a576905cc1040d64538ae901 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Mon, 18 Nov 2019 13:21:28 +0100 Subject: [PATCH 22/23] fixup! 
[AIRFLOW-YYY] Fix migration --- airflow/__init__.py | 2 -- airflow/models/__init__.py | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/__init__.py b/airflow/__init__.py index 037da89237a8f..f0ce14b0626b3 100644 --- a/airflow/__init__.py +++ b/airflow/__init__.py @@ -37,8 +37,6 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models import DAG -# Load SQLAlchemy models during package initialization -from airflow import jobs # noqa: F401 __version__ = version.version diff --git a/airflow/models/__init__.py b/airflow/models/__init__.py index ad02fd41178bc..9630f41f820f5 100644 --- a/airflow/models/__init__.py +++ b/airflow/models/__init__.py @@ -17,6 +17,8 @@ # specific language governing permissions and limitations # under the License. """Airflow models""" +# Load SQLAlchemy models during package initialization +import airflow.jobs # noqa: F401 from airflow.models.base import ID_LEN, Base # noqa: F401 from airflow.models.baseoperator import BaseOperator # noqa: F401 from airflow.models.connection import Connection # noqa: F401 From b40e73e85a16ba0775f28e1f4b2742fa4d463e72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Mon, 18 Nov 2019 14:52:21 +0100 Subject: [PATCH 23/23] fixup! fixup! [AIRFLOW-YYY] Fix migration --- airflow/models/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/airflow/models/__init__.py b/airflow/models/__init__.py index 9630f41f820f5..6f78f72f1f408 100644 --- a/airflow/models/__init__.py +++ b/airflow/models/__init__.py @@ -17,8 +17,6 @@ # specific language governing permissions and limitations # under the License. """Airflow models""" -# Load SQLAlchemy models during package initialization -import airflow.jobs # noqa: F401 from airflow.models.base import ID_LEN, Base # noqa: F401 from airflow.models.baseoperator import BaseOperator # noqa: F401 from airflow.models.connection import Connection # noqa: F401 @@ -38,3 +36,7 @@ from airflow.models.taskreschedule import TaskReschedule # noqa: F401 from airflow.models.variable import Variable # noqa: F401 from airflow.models.xcom import XCOM_RETURN_KEY, XCom # noqa: F401 + +# Load SQLAlchemy models during package initialization +# Must be loaded after loading DAG model. +import airflow.jobs # noqa: F401 isort # isort:skip
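
A minimal sketch of the lazy-loading pattern introduced in [PATCH 20/23], reproduced here without an Airflow installation: importlib.import_module stands in for the airflow.utils.module_loading.import_string helper used by the patch, and the 'math.sqrt' import path is a stand-in chosen only so the example runs anywhere. The point it illustrates is that building the command callable is cheap, while the module behind it is imported only when the command is actually invoked, which is what keeps `airflow --help` from importing every command module up front.

# Minimal, self-contained sketch of the lazy-loading idea from [PATCH 20/23].
# importlib.import_module plays the role of airflow.utils.module_loading.import_string;
# 'math.sqrt' is a placeholder target used only so the demo runs without Airflow.
from importlib import import_module
from typing import Callable


def lazy_load_command(import_path: str) -> Callable:
    """Return a callable that imports `import_path` only when first invoked."""
    module_path, _, name = import_path.rpartition('.')

    def command(*args, **kwargs):
        # The potentially heavy import happens here, at call time,
        # not while the argument parser is being built.
        func = getattr(import_module(module_path), name)
        return func(*args, **kwargs)

    command.__name__ = name
    return command


if __name__ == '__main__':
    # Creating the wrapper does not import the target module yet.
    sqrt_cmd = lazy_load_command('math.sqrt')
    # The import (and the call) happen only when the command runs.
    print(sqrt_cmd(16))  # 4.0

The patch additionally adds explicit 'name' keys for the top-level subcommands (kerberos, serve_logs, webserver, scheduler, worker, flower, version, sync_perm, rotate_fernet_key), so the subparser names no longer depend solely on the wrapper's __name__ attribute.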