From 0904b516d3537e9ca52592972e6380ee6fb25125 Mon Sep 17 00:00:00 2001
From: Jarek Potiuk
Date: Mon, 18 Nov 2019 20:48:09 +0100
Subject: [PATCH] [AIRFLOW-6004] Untangle Executors class to avoid cyclic imports

Cyclic imports were detected seemingly at random by pylint checks when some
PRs were run in CI.

The detection was not deterministic because pylint usually uses as many
processes as there are available processors and splits the list of .py files
between them - depending on how the split is done, the pylint check might or
might not detect the cycle. The cycle is always detected when all files are
checked together.

In order to make the detection deterministic, all pylint and mypy errors were
resolved in the whole executors package and in dag_processor.

At the same time, plugin integration was moved out of the executors and out of
all the operators/hooks/sensors/macros packages, because it was also causing
cyclic dependencies and it is far easier to untangle those dependencies when
the initialisation of all plugins happens in plugins_manager.

Additionally, require_serial is set in the pre-commit configuration to make
sure that cycle detection is deterministic.
---
 .pre-commit-config.yaml | 4 +-
 BREEZE.rst | 4 +-
 CONTRIBUTING.rst | 4 +-
 UPDATING.md | 5 +
 airflow/__init__.py | 21 +-
 airflow/cli/commands/flower_command.py | 2 +-
 airflow/cli/commands/serve_logs_command.py | 2 +-
 airflow/cli/commands/task_command.py | 4 +-
 airflow/cli/commands/worker_command.py | 3 +-
 airflow/executors/__init__.py | 84 +----
 airflow/executors/base_executor.py | 171 +++++----
 airflow/executors/celery_executor.py | 152 +++++---
 airflow/executors/dask_executor.py | 51 +--
 airflow/executors/executor_loader.py | 85 +++++
 airflow/executors/kubernetes_executor.py | 131 ++++---
 airflow/executors/local_executor.py | 247 +++++++++----
 airflow/executors/sequential_executor.py | 22 +-
 airflow/hooks/__init__.py | 16 +-
 airflow/jobs/backfill_job.py | 10 +-
 airflow/jobs/base_job.py | 5 +-
 airflow/jobs/scheduler_job.py | 20 +-
 airflow/kubernetes/kube_client.py | 2 +-
 airflow/kubernetes/pod_generator.py | 9 +-
 airflow/kubernetes/pod_launcher.py | 2 +-
 airflow/macros/__init__.py | 9 -
 airflow/models/baseoperator.py | 2 +-
 airflow/models/dag.py | 7 +-
 airflow/models/dagbag.py | 10 +-
 airflow/models/kubernetes.py | 11 +-
 airflow/models/taskinstance.py | 150 ++++--
 airflow/operators/__init__.py | 12 +-
 airflow/operators/mssql_to_hive.py | 6 +-
 airflow/plugins_manager.py | 143 +++++--
 airflow/sensors/__init__.py | 12 +-
 airflow/settings.py | 13 +-
 airflow/task/task_runner/base_task_runner.py | 8 +-
 .../task/task_runner/standard_task_runner.py | 7 +-
 airflow/utils/dag_processing.py | 342 ++++--------
 airflow/utils/file.py | 112 ++++++
 airflow/www/views.py | 7 +-
 breeze-complete | 2 +-
 docs/howto/custom-operator.rst | 32 +-
 scripts/ci/pylint_todo.txt | 7 -
 tests/core.py | 10 +-
 tests/dags/test_subdag.py | 2 +-
 tests/executors/test_base_executor.py | 1 -
 tests/executors/test_celery_executor.py | 52 ++-
 tests/executors/test_dask_executor.py | 1 -
 tests/executors/test_executor.py | 1 -
 tests/executors/test_kubernetes_executor.py | 4 +-
 tests/executors/test_local_executor.py | 19 +-
 tests/executors/test_sequential_executor.py | 1 -
 tests/gcp/hooks/test_google_discovery_api.py | 5 +-
 tests/jobs/test_local_task_job.py | 2 +-
 tests/jobs/test_scheduler_job.py | 6 +-
 tests/models/test_dag.py | 2 +-
 tests/models/test_dagbag.py | 2 +-
 .../test_google_api_to_s3_transfer.py | 5 +-
 tests/operators/test_operators.py | 41 ++-
 tests/plugins/test_plugin.py | 10 +-
.../task_runner/test_standard_task_runner.py | 12 +- tests/test_configuration.py | 17 +- tests/utils/test_dag_processing.py | 6 +- tests/utils/test_email.py | 4 +- tests/www/test_views.py | 4 +- 65 files changed, 1221 insertions(+), 934 deletions(-) create mode 100644 airflow/executors/executor_loader.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e2c4c0452ca5e..22643454c519b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -264,12 +264,14 @@ repos: files: \.py$ exclude: ^tests/.*\.py$|^airflow/_vendor/.*$ pass_filenames: true - - id: pylint + require_serial: true # Pylint tests should be run in one chunk to detect all cycles + - id: pylint-tests name: Run pylint for tests language: system entry: "./scripts/ci/pre_commit_pylint_tests.sh" files: ^tests/.*\.py$ pass_filenames: true + require_serial: false # For tests, it's perfectly OK to run pylint in parallel - id: flake8 name: Run flake8 language: system diff --git a/BREEZE.rst b/BREEZE.rst index 18c0a391aff38..30bc49c7f3b06 100644 --- a/BREEZE.rst +++ b/BREEZE.rst @@ -872,7 +872,7 @@ This is the current syntax for `./breeze <./breeze>`_: -S, --static-check Run selected static checks for currently changed files. You should specify static check that you would like to run or 'all' to run all checks. One of - [ all all-but-pylint check-apache-license check-executables-have-shebangs check-hooks-apply check-merge-conflict check-xml debug-statements doctoc detect-private-key end-of-file-fixer flake8 forbid-tabs insert-license lint-dockerfile mixed-line-ending mypy pylint setup-order shellcheck]. + [ all all-but-pylint check-apache-license check-executables-have-shebangs check-hooks-apply check-merge-conflict check-xml debug-statements doctoc detect-private-key end-of-file-fixer flake8 forbid-tabs insert-license lint-dockerfile mixed-line-ending mypy pylint pylint-test setup-order shellcheck]. You can pass extra arguments including options to to the pre-commit framework as passed after --. For example: @@ -886,7 +886,7 @@ This is the current syntax for `./breeze <./breeze>`_: -F, --static-check-all-files Run selected static checks for all applicable files. You should specify static check that you would like to run or 'all' to run all checks. One of - [ all all-but-pylint check-apache-license check-executables-have-shebangs check-hooks-apply check-merge-conflict check-xml debug-statements doctoc detect-private-key end-of-file-fixer flake8 forbid-tabs insert-license lint-dockerfile mixed-line-ending mypy pylint setup-order shellcheck]. + [ all all-but-pylint check-apache-license check-executables-have-shebangs check-hooks-apply check-merge-conflict check-xml debug-statements doctoc detect-private-key end-of-file-fixer flake8 forbid-tabs insert-license lint-dockerfile mixed-line-ending mypy pylint pylint-test setup-order shellcheck]. You can pass extra arguments including options to the pre-commit framework as passed after --. For example: diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 605d8dc97ca1a..83d9fe9da8d1a 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -454,7 +454,9 @@ image built locally): ----------------------------------- ---------------------------------------------------------------- ------------ ``pydevd`` Check for accidentally commited pydevd statements. ----------------------------------- ---------------------------------------------------------------- ------------ -``pylint`` Runs pylint. * +``pylint`` Runs pylint for main code. 
* +----------------------------------- ---------------------------------------------------------------- ------------ +``pylint-tests`` Runs pylint for tests. * ----------------------------------- ---------------------------------------------------------------- ------------ ``python-no-log-warn`` Checks if there are no deprecate log warn. ----------------------------------- ---------------------------------------------------------------- ------------ diff --git a/UPDATING.md b/UPDATING.md index 61821f5fd40f1..f0bc3598055e0 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -41,6 +41,11 @@ assists users migrating to a new version. ## Airflow Master +### Removal of airflow.AirflowMacroPlugin class + +The class was there in airflow package but it has not been used (apparently since 2015). +It has been removed. + ### Changes to settings CONTEXT_MANAGER_DAG was removed from settings. It's role has been taken by `DagContext` in diff --git a/airflow/__init__.py b/airflow/__init__.py index f9c114ef1cb86..862c9bcdc51ab 100644 --- a/airflow/__init__.py +++ b/airflow/__init__.py @@ -44,23 +44,8 @@ settings.initialize() -login = None # type: Optional[Callable] +from airflow.plugins_manager import integrate_plugins -from airflow import executors -from airflow import hooks -from airflow import macros -from airflow import operators -from airflow import sensors +login: Optional[Callable] = None - -class AirflowMacroPlugin: - # pylint: disable=missing-docstring - def __init__(self, namespace): - self.namespace = namespace - - -operators._integrate_plugins() # pylint: disable=protected-access -sensors._integrate_plugins() # pylint: disable=protected-access -hooks._integrate_plugins() # pylint: disable=protected-access -executors._integrate_plugins() # pylint: disable=protected-access -macros._integrate_plugins() # pylint: disable=protected-access +integrate_plugins() diff --git a/airflow/cli/commands/flower_command.py b/airflow/cli/commands/flower_command.py index da0eef5a8b36d..c9f90a6445614 100644 --- a/airflow/cli/commands/flower_command.py +++ b/airflow/cli/commands/flower_command.py @@ -23,7 +23,7 @@ import daemon from daemon.pidfile import TimeoutPIDLockFile -from airflow import conf +from airflow.configuration import conf from airflow.utils import cli as cli_utils from airflow.utils.cli import setup_locations, sigint_handler diff --git a/airflow/cli/commands/serve_logs_command.py b/airflow/cli/commands/serve_logs_command.py index 86e29464124d8..db48a0f0b855b 100644 --- a/airflow/cli/commands/serve_logs_command.py +++ b/airflow/cli/commands/serve_logs_command.py @@ -18,7 +18,7 @@ """Serve logs command""" import os -from airflow import conf +from airflow.configuration import conf from airflow.utils import cli as cli_utils diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 327b51a63043e..4ea72bf60b7ef 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -25,7 +25,7 @@ from contextlib import redirect_stderr, redirect_stdout from airflow import DAG, AirflowException, conf, jobs, settings -from airflow.executors import get_default_executor +from airflow.executors.executor_loader import ExecutorLoader from airflow.models import DagPickle, TaskInstance from airflow.ti_deps.dep_context import SCHEDULER_QUEUED_DEPS, DepContext from airflow.utils import cli as cli_utils, db @@ -69,7 +69,7 @@ def _run(args, dag, ti): print(e) raise e - executor = get_default_executor() + executor = ExecutorLoader.get_default_executor() 
executor.start() print("Sending to executor.") executor.queue_task_instance( diff --git a/airflow/cli/commands/worker_command.py b/airflow/cli/commands/worker_command.py index bccc205a822c6..d6c57fdbe5297 100644 --- a/airflow/cli/commands/worker_command.py +++ b/airflow/cli/commands/worker_command.py @@ -23,7 +23,8 @@ import daemon from daemon.pidfile import TimeoutPIDLockFile -from airflow import conf, settings +from airflow import settings +from airflow.configuration import conf from airflow.utils import cli as cli_utils from airflow.utils.cli import setup_locations, setup_logging, sigint_handler diff --git a/airflow/executors/__init__.py b/airflow/executors/__init__.py index e6322d13ce8d6..21ee94b03b9b3 100644 --- a/airflow/executors/__init__.py +++ b/airflow/executors/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,84 +14,4 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# pylint: disable=missing-docstring - -import sys -from typing import Optional - -from airflow.configuration import conf -from airflow.exceptions import AirflowException -from airflow.executors.base_executor import BaseExecutor -from airflow.executors.local_executor import LocalExecutor -from airflow.executors.sequential_executor import SequentialExecutor -from airflow.utils.log.logging_mixin import LoggingMixin - -DEFAULT_EXECUTOR = None # type: Optional[BaseExecutor] - - -def _integrate_plugins(): - """Integrate plugins to the context.""" - from airflow.plugins_manager import executors_modules - for executors_module in executors_modules: - sys.modules[executors_module.__name__] = executors_module - globals()[executors_module._name] = executors_module # pylint: disable=protected-access - - -def get_default_executor(): - """Creates a new instance of the configured executor if none exists and returns it""" - global DEFAULT_EXECUTOR # pylint: disable=global-statement - - if DEFAULT_EXECUTOR is not None: - return DEFAULT_EXECUTOR - - executor_name = conf.get('core', 'EXECUTOR') - - DEFAULT_EXECUTOR = _get_executor(executor_name) - - log = LoggingMixin().log - log.info("Using executor %s", executor_name) - - return DEFAULT_EXECUTOR - - -class Executors: - LocalExecutor = "LocalExecutor" - SequentialExecutor = "SequentialExecutor" - CeleryExecutor = "CeleryExecutor" - DaskExecutor = "DaskExecutor" - KubernetesExecutor = "KubernetesExecutor" - - -def _get_executor(executor_name): - """ - Creates a new instance of the named executor. 
- In case the executor name is not know in airflow, - look for it in the plugins - """ - if executor_name == Executors.LocalExecutor: - return LocalExecutor() - elif executor_name == Executors.SequentialExecutor: - return SequentialExecutor() - elif executor_name == Executors.CeleryExecutor: - from airflow.executors.celery_executor import CeleryExecutor - return CeleryExecutor() - elif executor_name == Executors.DaskExecutor: - from airflow.executors.dask_executor import DaskExecutor - return DaskExecutor() - elif executor_name == Executors.KubernetesExecutor: - from airflow.executors.kubernetes_executor import KubernetesExecutor - return KubernetesExecutor() - else: - # Loading plugins - _integrate_plugins() - executor_path = executor_name.split('.') - if len(executor_path) != 2: - raise AirflowException( - "Executor {0} not supported: " - "please specify in format plugin_module.executor".format(executor_name)) - - if executor_path[0] in globals(): - return globals()[executor_path[0]].__dict__[executor_path[1]]() - else: - raise AirflowException("Executor {0} not supported.".format(executor_name)) +"""Executors.""" diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index 506713a741f2b..a0109b49d9f3d 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,67 +14,86 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +""" +Base executor - this is the base class for all the implemented executors. +""" from collections import OrderedDict +from typing import Any, Dict, List, Optional, Set, Tuple, Union -# To avoid circular imports -import airflow.utils.dag_processing -from airflow.configuration import conf +from airflow import LoggingMixin, conf +from airflow.models import TaskInstance +from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKeyType from airflow.stats import Stats -from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State -PARALLELISM = conf.getint('core', 'PARALLELISM') +PARALLELISM: int = conf.getint('core', 'PARALLELISM') +NOT_STARTED_MESSAGE = "The executor should be started first!" -class BaseExecutor(LoggingMixin): +# Command to execute - might be either string or list of strings +# with the same semantics as subprocess.Popen +CommandType = Union[str, List[str]] - def __init__(self, parallelism=PARALLELISM): - """ - Class to derive in order to interface with executor-type systems - like Celery, Yarn and the likes. - :param parallelism: how many jobs should run at one time. Set to - ``0`` for infinity - :type parallelism: int - """ - self.parallelism = parallelism - self.queued_tasks = OrderedDict() - self.running = {} - self.event_buffer = {} +# Task that is queued. It contains all the information that is +# needed to run the task. +# +# Tuple of: command, priority, queue name, SimpleTaskInstance +QueuedTaskInstanceType = Tuple[CommandType, int, Optional[str], SimpleTaskInstance] + + +class BaseExecutor(LoggingMixin): + """ + Class to derive in order to interface with executor-type systems + like Celery, Kubernetes, Local, Sequential and the likes. + + :param parallelism: how many jobs should run at one time. 
Set to + ``0`` for infinity + """ + def __init__(self, parallelism: int = PARALLELISM): + super().__init__() + self.parallelism: int = parallelism + self.queued_tasks: OrderedDict[TaskInstanceKeyType, QueuedTaskInstanceType] \ + = OrderedDict() + self.running: Set[TaskInstanceKeyType] = set() + self.event_buffer: Dict[TaskInstanceKeyType, Optional[str]] = {} def start(self): # pragma: no cover """ - Executors may need to get things started. For example LocalExecutor - starts N workers. + Executors may need to get things started. """ - def queue_command(self, simple_task_instance, command, priority=1, queue=None): - key = simple_task_instance.key - if key not in self.queued_tasks and key not in self.running: + def queue_command(self, + simple_task_instance: SimpleTaskInstance, + command: CommandType, + priority: int = 1, + queue: Optional[str] = None): + """Queues command to task""" + if simple_task_instance.key not in self.queued_tasks and simple_task_instance.key not in self.running: self.log.info("Adding to queue: %s", command) - self.queued_tasks[key] = (command, priority, queue, simple_task_instance) + self.queued_tasks[simple_task_instance.key] = (command, priority, queue, simple_task_instance) else: - self.log.info("could not queue task %s", key) + self.log.info("could not queue task %s", simple_task_instance.key) def queue_task_instance( self, - task_instance, - mark_success=False, - pickle_id=None, - ignore_all_deps=False, - ignore_depends_on_past=False, - ignore_task_deps=False, - ignore_ti_state=False, - pool=None, - cfg_path=None): + task_instance: TaskInstance, + mark_success: bool = False, + pickle_id: Optional[str] = None, + ignore_all_deps: bool = False, + ignore_depends_on_past: bool = False, + ignore_task_deps: bool = False, + ignore_ti_state: bool = False, + pool: Optional[str] = None, + cfg_path: Optional[str] = None): + """Queues task instance.""" pool = pool or task_instance.pool # TODO (edgarRd): AIRFLOW-1985: # cfg_path is needed to propagate the config values if using impersonation # (run_as_user), given that there are different code paths running tasks. # For a long term solution we need to address AIRFLOW-1986 - command = task_instance.command_as_list( + command_list_to_run = task_instance.command_as_list( local=True, mark_success=mark_success, ignore_all_deps=ignore_all_deps, @@ -87,29 +104,30 @@ def queue_task_instance( pickle_id=pickle_id, cfg_path=cfg_path) self.queue_command( - airflow.utils.dag_processing.SimpleTaskInstance(task_instance), - command, + SimpleTaskInstance(task_instance), + command_list_to_run, priority=task_instance.task.priority_weight_total, queue=task_instance.task.queue) - def has_task(self, task_instance): + def has_task(self, task_instance: TaskInstance) -> bool: """ - Checks if a task is either queued or running in this executor + Checks if a task is either queued or running in this executor. :param task_instance: TaskInstance :return: True if the task is known to this executor """ - if task_instance.key in self.queued_tasks or task_instance.key in self.running: - return True + return task_instance.key in self.queued_tasks or task_instance.key in self.running - def sync(self): + def sync(self) -> None: """ Sync will get called periodically by the heartbeat method. Executors should override this to perform gather statuses. """ - def heartbeat(self): - # Triggering new jobs + def heartbeat(self) -> None: + """ + Heartbeat sent to trigger new jobs. 
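[Reviewer note, not part of the patch] With the typed interface introduced above, a third-party executor now subclasses BaseExecutor against explicit aliases (CommandType, TaskInstanceKeyType). A minimal sketch; the NoOpExecutor name and its behaviour are made up purely for illustration:

from typing import Any, Optional

from airflow.executors.base_executor import BaseExecutor, CommandType
from airflow.models.taskinstance import TaskInstanceKeyType


class NoOpExecutor(BaseExecutor):
    """Toy executor that immediately marks every queued task as successful."""

    def execute_async(self,
                      key: TaskInstanceKeyType,
                      command: CommandType,
                      queue: Optional[str] = None,
                      executor_config: Optional[Any] = None) -> None:
        # Illustration only: a real executor would hand `command` to its backend here.
        self.log.info("Pretending to run %s", command)
        self.success(key)

    def end(self) -> None:
        # Nothing to wait for - work "finishes" as soon as it is queued.
        pass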
+ """ if not self.parallelism: open_slots = len(self.queued_tasks) else: @@ -132,47 +150,65 @@ def heartbeat(self): self.log.debug("Calling the %s sync method", self.__class__) self.sync() - def trigger_tasks(self, open_slots): + def trigger_tasks(self, open_slots: int) -> None: """ - Trigger tasks + Triggers tasks :param open_slots: Number of open slots - :return: """ sorted_queue = sorted( [(k, v) for k, v in self.queued_tasks.items()], key=lambda x: x[1][1], reverse=True) for _ in range(min((open_slots, len(self.queued_tasks)))): - key, (command, _, queue, simple_ti) = sorted_queue.pop(0) + key, (command, _, _, simple_ti) = sorted_queue.pop(0) self.queued_tasks.pop(key) - self.running[key] = command + self.running.add(key) self.execute_async(key=key, command=command, - queue=queue, + queue=None, executor_config=simple_ti.executor_config) - def change_state(self, key, state): + def change_state(self, key: TaskInstanceKeyType, state: str) -> None: + """ + Changes state of the task. + + :param key: Unique key for the task instance + :param state: State to set for the task. + """ self.log.debug("Changing state: %s", key) - self.running.pop(key, None) + try: + self.running.remove(key) + except KeyError: + self.log.debug('Could not find key: %s', str(key)) self.event_buffer[key] = state - def fail(self, key): + def fail(self, key: TaskInstanceKeyType) -> None: + """ + Set fail state for the event. + + :param key: Unique key for the task instance + """ self.change_state(key, State.FAILED) - def success(self, key): + def success(self, key: TaskInstanceKeyType) -> None: + """ + Set success state for the event. + + :param key: Unique key for the task instance + """ self.change_state(key, State.SUCCESS) - def get_event_buffer(self, dag_ids=None): + def get_event_buffer(self, dag_ids=None) -> Dict[TaskInstanceKeyType, Optional[str]]: """ Returns and flush the event buffer. In case dag_ids is specified it will only return and flush events for the given dag_ids. Otherwise - it returns and flushes all + it returns and flushes all events. :param dag_ids: to dag_ids to return events for, if None returns all :return: a dict of events """ - cleared_events = dict() + cleared_events: Dict[TaskInstanceKeyType, Optional[str]] = dict() if dag_ids is None: cleared_events = self.event_buffer self.event_buffer = dict() @@ -185,16 +221,21 @@ def get_event_buffer(self, dag_ids=None): return cleared_events def execute_async(self, - key, - command, - queue=None, - executor_config=None): # pragma: no cover + key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None) -> None: # pragma: no cover """ This method will execute the command asynchronously. + + :param key: Unique key for the task instance + :param command: Command to run + :param queue: name of the queue + :param executor_config: Configuration passed to the executor. """ raise NotImplementedError() - def end(self): # pragma: no cover + def end(self) -> None: # pragma: no cover """ This method is called when the caller is done submitting job and wants to wait synchronously for the job submitted previously to be diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index 83fc44b59f3f9..906531757d3ee 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file @@ -16,20 +15,23 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Celery executor.""" import math import os import subprocess import time import traceback from multiprocessing import Pool, cpu_count +from typing import Any, List, Optional, Tuple, Union -from celery import Celery, states as celery_states +from celery import Celery, Task, states as celery_states +from celery.result import AsyncResult from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.executors.base_executor import BaseExecutor +from airflow.executors.base_executor import BaseExecutor, CommandType, QueuedTaskInstanceType +from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKeyType, TaskInstanceStateType from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string from airflow.utils.timeout import timeout @@ -57,7 +59,8 @@ @app.task -def execute_command(command_to_exec): +def execute_command(command_to_exec: str) -> None: + """Executes command.""" log = LoggingMixin().log log.info("Executing command in Celery: %s", command_to_exec) env = os.environ.copy() @@ -67,7 +70,6 @@ def execute_command(command_to_exec): except subprocess.CalledProcessError as e: log.exception('execute_command encountered a CalledProcessError') log.error(e.output) - raise AirflowException('Celery command failed') @@ -81,12 +83,13 @@ class ExceptionWithTraceback: :type exception_traceback: str """ - def __init__(self, exception, exception_traceback): + def __init__(self, exception: Exception, exception_traceback: str): self.exception = exception self.traceback = exception_traceback -def fetch_celery_task_state(celery_task): +def fetch_celery_task_state(celery_task: Tuple[TaskInstanceKeyType, AsyncResult]) \ + -> Union[TaskInstanceStateType, ExceptionWithTraceback]: """ Fetch and return the state of the given celery task. The scope of this function is global so that it can be called by subprocesses in the pool. @@ -102,22 +105,27 @@ def fetch_celery_task_state(celery_task): with timeout(seconds=2): # Accessing state property of celery task will make actual network request # to get the current state of the task. 
- res = (celery_task[0], celery_task[1].state) - except Exception as e: + return celery_task[0], celery_task[1].state + except Exception as e: # pylint: disable=broad-except exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0], traceback.format_exc()) - res = ExceptionWithTraceback(e, exception_traceback) - return res + return ExceptionWithTraceback(e, exception_traceback) + +# Task instance that is sent over Celery queues +# TaskInstanceKeyType, SimpleTaskInstance, Command, queue_name, CallableTask +TaskInstanceInCelery = Tuple[TaskInstanceKeyType, SimpleTaskInstance, CommandType, Optional[str], Task] -def send_task_to_executor(task_tuple): - key, _, command, queue, task = task_tuple + +def send_task_to_executor(task_tuple: TaskInstanceInCelery) \ + -> Tuple[TaskInstanceKeyType, CommandType, Union[AsyncResult, ExceptionWithTraceback]]: + """Sends task to executor.""" + key, _, command, queue, task_to_run = task_tuple try: with timeout(seconds=2): - result = task.apply_async(args=[command], queue=queue) - except Exception as e: - exception_traceback = "Celery Task ID: {}\n{}".format(key, - traceback.format_exc()) + result = task_to_run.apply_async(args=[command], queue=queue) + except Exception as e: # pylint: disable=broad-except + exception_traceback = "Celery Task ID: {}\n{}".format(key, traceback.format_exc()) result = ExceptionWithTraceback(e, exception_traceback) return key, command, result @@ -148,13 +156,13 @@ def __init__(self): self.tasks = {} self.last_state = {} - def start(self): + def start(self) -> None: self.log.debug( 'Starting Celery Executor using %s processes for syncing', self._sync_parallelism ) - def _num_tasks_per_send_process(self, to_send_count): + def _num_tasks_per_send_process(self, to_send_count: int) -> int: """ How many Celery tasks should each worker process send. @@ -164,34 +172,29 @@ def _num_tasks_per_send_process(self, to_send_count): return max(1, int(math.ceil(1.0 * to_send_count / self._sync_parallelism))) - def _num_tasks_per_fetch_process(self): + def _num_tasks_per_fetch_process(self) -> int: """ How many Celery tasks should be sent to each worker process. 
:return: Number of tasks that should be used per process :rtype: int """ - return max(1, - int(math.ceil(1.0 * len(self.tasks) / self._sync_parallelism))) + return max(1, int(math.ceil(1.0 * len(self.tasks) / self._sync_parallelism))) - def trigger_tasks(self, open_slots): + def trigger_tasks(self, open_slots: int) -> None: """ Overwrite trigger_tasks function from BaseExecutor :param open_slots: Number of open slots :return: """ - sorted_queue = sorted( - [(k, v) for k, v in self.queued_tasks.items()], - key=lambda x: x[1][1], - reverse=True) + sorted_queue = self.order_queued_tasks_by_priority() - task_tuples_to_send = [] + task_tuples_to_send: List[TaskInstanceInCelery] = [] for _ in range(min((open_slots, len(self.queued_tasks)))): key, (command, _, queue, simple_ti) = sorted_queue.pop(0) - task_tuples_to_send.append((key, simple_ti, command, queue, - execute_command)) + task_tuples_to_send.append((key, simple_ti, command, queue, execute_command)) cached_celery_backend = None if task_tuples_to_send: @@ -202,7 +205,7 @@ def trigger_tasks(self, open_slots): cached_celery_backend = tasks[0].backend if task_tuples_to_send: - # Use chunking instead of a work queue to reduce context switching + # Use chunks instead of a work queue to reduce context switching # since tasks are roughly uniform in size chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send)) num_processes = min(len(task_tuples_to_send), self._sync_parallelism) @@ -227,11 +230,22 @@ def trigger_tasks(self, open_slots): # and expect scheduler loop to deal with it. self.queued_tasks.pop(key) result.backend = cached_celery_backend - self.running[key] = command + self.running.add(key) self.tasks[key] = result self.last_state[key] = celery_states.PENDING - def sync(self): + def order_queued_tasks_by_priority(self) -> List[Tuple[TaskInstanceKeyType, QueuedTaskInstanceType]]: + """ + Orders the queued tasks by priority. + + :return: List of tuples from the queued_tasks according to the priority. 
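[Reviewer note, not part of the patch] The sort key `x[1][1]` used below picks the priority field of the QueuedTaskInstanceType tuple (command, priority, queue, simple_task_instance), and `reverse=True` puts the highest priority first. A tiny illustration with made-up values - string stand-ins for the real TaskInstanceKeyType keys and None for the SimpleTaskInstance:

queued_tasks = {
    "ti_low": (["airflow", "run", "dag_a", "task_1"], 1, None, None),
    "ti_high": (["airflow", "run", "dag_a", "task_2"], 10, None, None),
}
ordered = sorted(queued_tasks.items(), key=lambda x: x[1][1], reverse=True)
assert ordered[0][0] == "ti_high"  # highest-priority task is popped first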
+ """ + return sorted( + [(k, v) for k, v in self.queued_tasks.items()], + key=lambda x: x[1][1], + reverse=True) + + def sync(self) -> None: num_processes = min(len(self.tasks), self._sync_parallelism) if num_processes == 0: self.log.debug("No task to query celery, skipping sync") @@ -243,7 +257,7 @@ def sync(self): # Recreate the process pool each sync in case processes in the pool die self._sync_pool = Pool(processes=num_processes) - # Use chunking instead of a work queue to reduce context switching since tasks are + # Use chunks instead of a work queue to reduce context switching since tasks are # roughly uniform in size chunksize = self._num_tasks_per_fetch_process() @@ -256,6 +270,12 @@ def sync(self): self._sync_pool.join() self.log.debug("Inquiries completed.") + self.update_task_states(task_keys_to_states) + + def update_task_states(self, + task_keys_to_states: List[Union[TaskInstanceStateType, + ExceptionWithTraceback]]) -> None: + """Updates states of the tasks.""" for key_and_state in task_keys_to_states: if isinstance(key_and_state, ExceptionWithTraceback): self.log.error( @@ -264,30 +284,44 @@ def sync(self): ) continue key, state = key_and_state - try: - if self.last_state[key] != state: - if state == celery_states.SUCCESS: - self.success(key) - del self.tasks[key] - del self.last_state[key] - elif state == celery_states.FAILURE: - self.fail(key) - del self.tasks[key] - del self.last_state[key] - elif state == celery_states.REVOKED: - self.fail(key) - del self.tasks[key] - del self.last_state[key] - else: - self.log.info("Unexpected state: %s", state) - self.last_state[key] = state - except Exception: - self.log.exception("Error syncing the Celery executor, ignoring it.") - - def end(self, synchronous=False): + self.update_task_state(key, state) + + def update_task_state(self, key: TaskInstanceKeyType, state: str) -> None: + """Updates state of a single task.""" + # noinspection PyBroadException + try: + if self.last_state[key] != state: + if state == celery_states.SUCCESS: + self.success(key) + del self.tasks[key] + del self.last_state[key] + elif state == celery_states.FAILURE: + self.fail(key) + del self.tasks[key] + del self.last_state[key] + elif state == celery_states.REVOKED: + self.fail(key) + del self.tasks[key] + del self.last_state[key] + else: + self.log.info("Unexpected state: %s", state) + self.last_state[key] = state + except Exception: # pylint: disable=broad-except + self.log.exception("Error syncing the Celery executor, ignoring it.") + + def end(self, synchronous: bool = False) -> None: if synchronous: - while any([ - task.state not in celery_states.READY_STATES - for task in self.tasks.values()]): + while any([task.state not in celery_states.READY_STATES for task in self.tasks.values()]): time.sleep(5) self.sync() + + def execute_async(self, + key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None): + """Do not allow async execution for Celery executor.""" + raise AirflowException("No Async execution for Celery executor.") + + def terminate(self): + pass diff --git a/airflow/executors/dask_executor.py b/airflow/executors/dask_executor.py index b355c4b6e7e86..01a8f703128cd 100644 --- a/airflow/executors/dask_executor.py +++ b/airflow/executors/dask_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -16,14 +15,16 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. - +"""Dask executor.""" import subprocess -import warnings +from typing import Any, Dict, Optional -import distributed +from distributed import Client, Future, as_completed +from distributed.security import Security from airflow.configuration import conf -from airflow.executors.base_executor import BaseExecutor +from airflow.executors.base_executor import NOT_STARTED_MESSAGE, BaseExecutor, CommandType +from airflow.models.taskinstance import TaskInstanceKeyType class DaskExecutor(BaseExecutor): @@ -31,21 +32,20 @@ class DaskExecutor(BaseExecutor): DaskExecutor submits tasks to a Dask Distributed cluster. """ def __init__(self, cluster_address=None): + super().__init__(parallelism=0) if cluster_address is None: cluster_address = conf.get('dask', 'cluster_address') - if not cluster_address: - raise ValueError( - 'Please provide a Dask cluster address in airflow.cfg') + assert cluster_address, 'Please provide a Dask cluster address in airflow.cfg' self.cluster_address = cluster_address # ssl / tls parameters self.tls_ca = conf.get('dask', 'tls_ca') self.tls_key = conf.get('dask', 'tls_key') self.tls_cert = conf.get('dask', 'tls_cert') - super().__init__(parallelism=0) + self.client: Optional[Client] = None + self.futures: Optional[Dict[Future, TaskInstanceKeyType]] = None - def start(self): + def start(self) -> None: if self.tls_ca or self.tls_key or self.tls_cert: - from distributed.security import Security security = Security( tls_client_key=self.tls_key, tls_client_cert=self.tls_cert, @@ -55,23 +55,25 @@ def start(self): else: security = None - self.client = distributed.Client(self.cluster_address, security=security) + self.client = Client(self.cluster_address, security=security) self.futures = {} - def execute_async(self, key, command, queue=None, executor_config=None): - if queue is not None: - warnings.warn( - 'DaskExecutor does not support queues. ' - 'All tasks will be run in the same cluster' - ) + def execute_async(self, + key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None) -> None: + assert self.futures, NOT_STARTED_MESSAGE def airflow_run(): return subprocess.check_call(command, close_fds=True) + assert self.client, "The Dask executor has not been started yet!" 
future = self.client.submit(airflow_run, pure=False) self.futures[future] = key - def _process_future(self, future): + def _process_future(self, future: Future) -> None: + assert self.futures, NOT_STARTED_MESSAGE if future.done(): key = self.futures[future] if future.exception(): @@ -84,15 +86,20 @@ def _process_future(self, future): self.success(key) self.futures.pop(future) - def sync(self): + def sync(self) -> None: + assert self.futures, NOT_STARTED_MESSAGE # make a copy so futures can be popped during iteration for future in self.futures.copy(): self._process_future(future) - def end(self): - for future in distributed.as_completed(self.futures.copy()): + def end(self) -> None: + assert self.client, NOT_STARTED_MESSAGE + assert self.futures, NOT_STARTED_MESSAGE + self.client.cancel(list(self.futures.keys())) + for future in as_completed(self.futures.copy()): self._process_future(future) def terminate(self): + assert self.futures, NOT_STARTED_MESSAGE self.client.cancel(self.futures.keys()) self.end() diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py new file mode 100644 index 0000000000000..dcf2e398ca957 --- /dev/null +++ b/airflow/executors/executor_loader.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""All executors.""" +from typing import Optional + +from airflow.executors.base_executor import BaseExecutor + + +class ExecutorLoader: + """ + Keeps constants for all the currently available executors. + """ + + LOCAL_EXECUTOR = "LocalExecutor" + SEQUENTIAL_EXECUTOR = "SequentialExecutor" + CELERY_EXECUTOR = "CeleryExecutor" + DASK_EXECUTOR = "DaskExecutor" + KUBERNETES_EXECUTOR = "KubernetesExecutor" + + _default_executor: Optional[BaseExecutor] = None + + @classmethod + def get_default_executor(cls) -> BaseExecutor: + """Creates a new instance of the configured executor if none exists and returns it""" + if cls._default_executor is not None: + return cls._default_executor + + from airflow.configuration import conf + executor_name = conf.get('core', 'EXECUTOR') + + cls._default_executor = ExecutorLoader._get_executor(executor_name) + + from airflow import LoggingMixin + log = LoggingMixin().log + log.info("Using executor %s", executor_name) + + return cls._default_executor + + @staticmethod + def _get_executor(executor_name: str) -> BaseExecutor: + """ + Creates a new instance of the named executor. 
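[Reviewer note, not part of the patch] For reference, call sites now obtain the default executor through the loader instead of the removed airflow.executors.get_default_executor(), exactly as in the task_command.py hunk earlier in this patch:

from airflow.executors.executor_loader import ExecutorLoader

executor = ExecutorLoader.get_default_executor()  # cached, built from the [core] executor setting
executor.start()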
+ In case the executor name is unknown in airflow, + look for it in the plugins + """ + if executor_name == ExecutorLoader.LOCAL_EXECUTOR: + from airflow.executors.local_executor import LocalExecutor + return LocalExecutor() + elif executor_name == ExecutorLoader.SEQUENTIAL_EXECUTOR: + from airflow.executors.sequential_executor import SequentialExecutor + return SequentialExecutor() + elif executor_name == ExecutorLoader.CELERY_EXECUTOR: + from airflow.executors.celery_executor import CeleryExecutor + return CeleryExecutor() + elif executor_name == ExecutorLoader.DASK_EXECUTOR: + from airflow.executors.dask_executor import DaskExecutor + return DaskExecutor() + elif executor_name == ExecutorLoader.KUBERNETES_EXECUTOR: + from airflow.executors.kubernetes_executor import KubernetesExecutor + return KubernetesExecutor() + else: + # Load plugins here for executors as at that time the plugins might not have been initialized yet + # TODO: verify the above and remove two lines below in case plugins are always initialized first + from airflow import plugins_manager + plugins_manager.integrate_executor_plugins() + executor_path = executor_name.split('.') + assert len(executor_path) == 2, f"Executor {executor_name} not supported: " \ + f"please specify in format plugin_module.executor" + + assert executor_path[0] in globals(), f"Executor {executor_name} not supported" + return globals()[executor_path[0]].__dict__[executor_path[1]]() diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index f6d61472eeaf1..0bc6db7e29086 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -16,28 +16,31 @@ # under the License. """Kubernetes executor""" import base64 +import datetime import hashlib import json import multiprocessing import re -from queue import Empty -from typing import Union +from queue import Empty, Queue # pylint: disable=unused-import +from typing import Any, Dict, Optional, Tuple, Union from uuid import uuid4 import kubernetes from dateutil import parser from kubernetes import client, watch +from kubernetes.client import Configuration from kubernetes.client.rest import ApiException from airflow import settings from airflow.configuration import conf from airflow.exceptions import AirflowConfigException, AirflowException -from airflow.executors.base_executor import BaseExecutor +from airflow.executors.base_executor import NOT_STARTED_MESSAGE, BaseExecutor, CommandType from airflow.kubernetes.kube_client import get_kube_client from airflow.kubernetes.pod_generator import PodGenerator from airflow.kubernetes.pod_launcher import PodLauncher from airflow.kubernetes.worker_configuration import WorkerConfiguration from airflow.models import KubeResourceVersion, KubeWorkerIdentifier, TaskInstance +from airflow.models.taskinstance import TaskInstanceKeyType from airflow.utils.db import create_session, provide_session from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State @@ -45,6 +48,15 @@ MAX_POD_ID_LEN = 253 MAX_LABEL_LEN = 63 +# TaskInstance key, command, configuration +KubernetesJobType = Tuple[TaskInstanceKeyType, CommandType, Any] + +# key, state, pod_id, resource_version +KubernetesResultsType = Tuple[TaskInstanceKeyType, Optional[str], str, str] + +# pod_id, state, labels, resource_version +KubernetesWatchType = Tuple[str, Optional[str], Dict[str, str], str] + class KubeConfig: # pylint: disable=too-many-instance-attributes """Configuration for Kubernetes""" @@ 
-241,7 +253,12 @@ def _validate(self): class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): """Watches for Kubernetes jobs""" - def __init__(self, namespace, watcher_queue, resource_version, worker_uuid, kube_config): + def __init__(self, + namespace: str, + watcher_queue: 'Queue[KubernetesWatchType]', + resource_version: Optional[str], + worker_uuid: Optional[str], + kube_config: Configuration): multiprocessing.Process.__init__(self) self.namespace = namespace self.worker_uuid = worker_uuid @@ -249,9 +266,10 @@ def __init__(self, namespace, watcher_queue, resource_version, worker_uuid, kube self.resource_version = resource_version self.kube_config = kube_config - def run(self): + def run(self) -> None: """Performs watching""" - kube_client = get_kube_client() + kube_client: client.CoreV1Api = get_kube_client() + assert self.worker_uuid, NOT_STARTED_MESSAGE while True: try: self.resource_version = self._run(kube_client, self.resource_version, @@ -263,7 +281,11 @@ def run(self): self.log.warning('Watch died gracefully, starting back up with: ' 'last resource_version: %s', self.resource_version) - def _run(self, kube_client, resource_version, worker_uuid, kube_config): + def _run(self, + kube_client: client.CoreV1Api, + resource_version: Optional[str], + worker_uuid: str, + kube_config: Any) -> Optional[str]: self.log.info( 'Event: and now my watch begins starting at resource_version: %s', resource_version @@ -277,7 +299,7 @@ def _run(self, kube_client, resource_version, worker_uuid, kube_config): for key, value in kube_config.kube_client_request_args.items(): kwargs[key] = value - last_resource_version = None + last_resource_version: Optional[str] = None for event in watcher.stream(kube_client.list_namespaced_pod, self.namespace, **kwargs): task = event['object'] @@ -295,7 +317,7 @@ def _run(self, kube_client, resource_version, worker_uuid, kube_config): return last_resource_version - def process_error(self, event): + def process_error(self, event: Any) -> str: """Process error response""" self.log.error( 'Encountered Error response from k8s list namespaced pod stream => %s', @@ -314,7 +336,7 @@ def process_error(self, event): (raw_object['reason'], raw_object['code'], raw_object['message']) ) - def process_status(self, pod_id, status, labels, resource_version): + def process_status(self, pod_id: str, status: str, labels: Dict[str, str], resource_version: str) -> None: """Process status response""" if status == 'Pending': self.log.info('Event: %s Pending', pod_id) @@ -335,7 +357,13 @@ def process_status(self, pod_id, status, labels, resource_version): class AirflowKubernetesScheduler(LoggingMixin): """Airflow Scheduler for Kubernetes""" - def __init__(self, kube_config, task_queue, result_queue, kube_client, worker_uuid): + def __init__(self, + kube_config: Any, + task_queue: 'Queue[KubernetesJobType]', + result_queue: 'Queue[KubernetesResultsType]', + kube_client: client.CoreV1Api, + worker_uuid: str): + super().__init__() self.log.debug("Creating Kubernetes executor") self.kube_config = kube_config self.task_queue = task_queue @@ -350,7 +378,7 @@ def __init__(self, kube_config, task_queue, result_queue, kube_client, worker_uu self.worker_uuid = worker_uuid self.kube_watcher = self._make_kube_watcher() - def _make_kube_watcher(self): + def _make_kube_watcher(self) -> KubernetesJobWatcher: resource_version = KubeResourceVersion.get_current_resource_version() watcher = KubernetesJobWatcher(self.namespace, self.watcher_queue, resource_version, self.worker_uuid, self.kube_config) 
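[Reviewer note, not part of the patch] Example values for the queue-item aliases defined above; the command, pod id and label values are made up, and the key tuple is (dag_id, task_id, execution_date, try_number) as produced by _labels_to_key:

import datetime

key = ("example_dag", "example_task", datetime.datetime(2019, 11, 18), 1)      # TaskInstanceKeyType
job = (key, ["airflow", "run", "example_dag", "example_task"], None)           # KubernetesJobType
result = (key, "success", "examplepod-1234abcd", "5678")                       # KubernetesResultsType
watch = ("examplepod-1234abcd", "success", {"dag_id": "example_dag"}, "5678")  # KubernetesWatchType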
@@ -366,7 +394,7 @@ def _health_check_kube_watcher(self): 'Process died for unknown reasons') self.kube_watcher = self._make_kube_watcher() - def run_next(self, next_job): + def run_next(self, next_job: KubernetesJobType) -> None: """ The run_next command will check the task_queue for any un-run jobs. It will then create a unique job-id, launch that job in the cluster, @@ -408,7 +436,7 @@ def delete_pod(self, pod_id: str) -> None: if e.status != 404: raise - def sync(self): + def sync(self) -> None: """ The sync function checks the status of all currently running kubernetes jobs. If a job is completed, its status is placed in the result queue to @@ -428,7 +456,7 @@ def sync(self): except Empty: break - def process_watcher_task(self, task): + def process_watcher_task(self, task: KubernetesWatchType) -> None: """Process the task by watcher.""" pod_id, state, labels, resource_version = task self.log.info( @@ -441,7 +469,7 @@ def process_watcher_task(self, task): self.result_queue.put((key, state, pod_id, resource_version)) @staticmethod - def _strip_unsafe_kubernetes_special_chars(string): + def _strip_unsafe_kubernetes_special_chars(string: str) -> str: """ Kubernetes only supports lowercase alphanumeric characters and "-" and "." in the pod name @@ -456,7 +484,7 @@ def _strip_unsafe_kubernetes_special_chars(string): return ''.join(ch.lower() for ind, ch in enumerate(string) if ch.isalnum()) @staticmethod - def _make_safe_pod_id(safe_dag_id, safe_task_id, safe_uuid): + def _make_safe_pod_id(safe_dag_id: str, safe_task_id: str, safe_uuid: str) -> str: """ Kubernetes pod names must be <= 253 chars and must pass the following regex for validation @@ -474,7 +502,7 @@ def _make_safe_pod_id(safe_dag_id, safe_task_id, safe_uuid): return safe_pod_id @staticmethod - def _make_safe_label_value(string): + def _make_safe_label_value(string: str) -> str: """ Valid label values must be 63 characters or less and must be empty or begin and end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_), @@ -493,7 +521,7 @@ def _make_safe_label_value(string): return safe_label @staticmethod - def _create_pod_id(dag_id, task_id): + def _create_pod_id(dag_id: str, task_id: str) -> str: safe_dag_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars( dag_id) safe_task_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars( @@ -504,7 +532,7 @@ def _create_pod_id(dag_id, task_id): safe_uuid) @staticmethod - def _label_safe_datestring_to_datetime(string): + def _label_safe_datestring_to_datetime(string: str) -> datetime.datetime: """ Kubernetes doesn't permit ":" in labels. 
ISO datetime format uses ":" but not "_", let's @@ -516,7 +544,7 @@ def _label_safe_datestring_to_datetime(string): return parser.parse(string.replace('_plus_', '+').replace("_", ":")) @staticmethod - def _datetime_to_label_safe_datestring(datetime_obj): + def _datetime_to_label_safe_datestring(datetime_obj: datetime.datetime) -> str: """ Kubernetes doesn't like ":" in labels, since ISO datetime format uses ":" but not "_" let's @@ -527,7 +555,7 @@ def _datetime_to_label_safe_datestring(datetime_obj): """ return datetime_obj.isoformat().replace(":", "_").replace('+', '_plus_') - def _labels_to_key(self, labels): + def _labels_to_key(self, labels: Dict[str, str]) -> Optional[TaskInstanceKeyType]: try_num = 1 try: try_num = int(labels.get('try_number', '1')) @@ -567,14 +595,14 @@ def _labels_to_key(self, labels): ) dag_id = task.dag_id task_id = task.task_id - return (dag_id, task_id, ex_time, try_num) + return dag_id, task_id, ex_time, try_num self.log.warning( 'Failed to find and match task details to a pod; labels: %s', labels ) return None - def _flush_watcher_queue(self): + def _flush_watcher_queue(self) -> None: self.log.debug('Executor shutting down, watcher_queue approx. size=%d', self.watcher_queue.qsize()) while True: try: @@ -585,7 +613,7 @@ def _flush_watcher_queue(self): except Empty: break - def terminate(self): + def terminate(self) -> None: """Terminates the watcher.""" self.log.debug("Terminating kube_watcher...") self.kube_watcher.terminate() @@ -601,18 +629,19 @@ def terminate(self): class KubernetesExecutor(BaseExecutor, LoggingMixin): """Executor for Kubernetes""" + def __init__(self): self.kube_config = KubeConfig() - self.task_queue = None - self.result_queue = None - self.kube_scheduler = None - self.kube_client = None - self.worker_uuid = None self._manager = multiprocessing.Manager() + self.task_queue: 'Queue[KubernetesJobType]' = self._manager.Queue() + self.result_queue: 'Queue[KubernetesResultsType]' = self._manager.Queue() + self.kube_scheduler: Optional[AirflowKubernetesScheduler] = None + self.kube_client: Optional[client.CoreV1Api] = None + self.worker_uuid: Optional[str] = None super().__init__(parallelism=self.kube_config.parallelism) @provide_session - def clear_not_launched_queued_tasks(self, session=None): + def clear_not_launched_queued_tasks(self, session=None) -> None: """ If the airflow scheduler restarts with pending "Queued" tasks, the tasks may or may not @@ -628,6 +657,7 @@ def clear_not_launched_queued_tasks(self, session=None): proper support for State.LAUNCHED """ + assert self.kube_client, NOT_STARTED_MESSAGE queued_tasks = session\ .query(TaskInstance)\ .filter(TaskInstance.state == State.QUEUED).all() @@ -667,7 +697,7 @@ def clear_not_launched_queued_tasks(self, session=None): TaskInstance.execution_date == task.execution_date ).update({TaskInstance.state: State.NONE}) - def _inject_secrets(self): + def _inject_secrets(self) -> None: def _create_or_update_secret(secret_name, secret_path): try: return self.kube_client.create_namespaced_secret( @@ -703,18 +733,17 @@ def _create_or_update_secret(secret_name, secret_path): for service_account in name_path_pair_list: _create_or_update_secret(service_account['name'], service_account['path']) - def start(self): + def start(self) -> None: """Starts the executor""" self.log.info('Start Kubernetes executor') self.worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid() + assert self.worker_uuid, "Could not get worker_uuid" self.log.debug('Start with worker_uuid: %s', 
self.worker_uuid) # always need to reset resource version since we don't know # when we last started, note for behavior below # https://github.com/kubernetes-client/python/blob/master/kubernetes/docs # /CoreV1Api.md#list_namespaced_pod KubeResourceVersion.reset_resource_version() - self.task_queue = self._manager.Queue() - self.result_queue = self._manager.Queue() self.kube_client = get_kube_client() self.kube_scheduler = AirflowKubernetesScheduler( self.kube_config, self.task_queue, self.result_queue, @@ -723,7 +752,11 @@ def start(self): self._inject_secrets() self.clear_not_launched_queued_tasks() - def execute_async(self, key, command, queue=None, executor_config=None): + def execute_async(self, + key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None) -> None: """Executes task asynchronously""" self.log.info( 'Add task %s with command %s with executor_config %s', @@ -731,14 +764,19 @@ def execute_async(self, key, command, queue=None, executor_config=None): ) kube_executor_config = PodGenerator.from_obj(executor_config) + assert self.task_queue, NOT_STARTED_MESSAGE self.task_queue.put((key, command, kube_executor_config)) - def sync(self): + def sync(self) -> None: """Synchronize task state.""" if self.running: self.log.debug('self.running: %s', self.running) if self.queued_tasks: self.log.debug('self.queued: %s', self.queued_tasks) + assert self.kube_scheduler, NOT_STARTED_MESSAGE + assert self.kube_config, NOT_STARTED_MESSAGE + assert self.result_queue, NOT_STARTED_MESSAGE + assert self.task_queue, NOT_STARTED_MESSAGE self.kube_scheduler.sync() last_resource_version = None @@ -778,18 +816,20 @@ def sync(self): break # pylint: enable=too-many-nested-blocks - def _change_state(self, key, state, pod_id: str) -> None: + def _change_state(self, key: TaskInstanceKeyType, state: Optional[str], pod_id: str) -> None: if state != State.RUNNING: if self.kube_config.delete_worker_pods: + assert self.kube_scheduler, NOT_STARTED_MESSAGE self.kube_scheduler.delete_pod(pod_id) self.log.info('Deleted pod: %s', str(key)) try: - self.running.pop(key) + self.running.remove(key) except KeyError: self.log.debug('Could not find key: %s', str(key)) self.event_buffer[key] = state - def _flush_task_queue(self): + def _flush_task_queue(self) -> None: + assert self.task_queue, NOT_STARTED_MESSAGE self.log.debug('Executor shutting down, task_queue approximate size=%d', self.task_queue.qsize()) while True: try: @@ -800,7 +840,8 @@ def _flush_task_queue(self): except Empty: break - def _flush_result_queue(self): + def _flush_result_queue(self) -> None: + assert self.result_queue, NOT_STARTED_MESSAGE self.log.debug('Executor shutting down, result_queue approximate size=%d', self.result_queue.qsize()) while True: # pylint: disable=too-many-nested-blocks try: @@ -820,8 +861,11 @@ def _flush_result_queue(self): except Empty: break - def end(self): + def end(self) -> None: """Called when the executor shuts down""" + assert self.task_queue, NOT_STARTED_MESSAGE + assert self.result_queue, NOT_STARTED_MESSAGE + assert self.kube_scheduler, NOT_STARTED_MESSAGE self.log.info('Shutting down Kubernetes executor') self.log.debug('Flushing task_queue...') self._flush_task_queue() @@ -833,3 +877,6 @@ def end(self): if self.kube_scheduler: self.kube_scheduler.terminate() self._manager.shutdown() + + def terminate(self): + """Terminate the executor is not doing anything.""" diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py index 
0086dcbf12a01..9173a8e61fb29 100644 --- a/airflow/executors/local_executor.py +++ b/airflow/executors/local_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -43,40 +42,44 @@ This option could lead to the unification of the executor implementations, running locally, into just one `LocalExecutor` with multiple modes. """ - -import multiprocessing import subprocess -from queue import Empty - -from airflow.executors.base_executor import BaseExecutor +from multiprocessing import Manager, Process +from multiprocessing.managers import SyncManager +from queue import Empty, Queue # pylint: disable=unused-import # noqa: F401 +from typing import Any, List, Optional, Tuple, Union # pylint: disable=unused-import # noqa: F401 + +from airflow import AirflowException +from airflow.executors.base_executor import NOT_STARTED_MESSAGE, PARALLELISM, BaseExecutor, CommandType +from airflow.models.taskinstance import ( # pylint: disable=unused-import # noqa: F401 + TaskInstanceKeyType, TaskInstanceStateType, +) from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State +# This is a work to be executed by a worker. +# It can Key and Command - but it can also be None, None which is actually a +# "Poison Pill" - worker seeing Poison Pill should take the pill and ... die instantly. +ExecutorWorkType = Tuple[Optional[TaskInstanceKeyType], Optional[CommandType]] -class LocalWorker(multiprocessing.Process, LoggingMixin): - """LocalWorker Process implementation to run airflow commands. Executes the given - command and puts the result into a result queue when done, terminating execution.""" +class LocalWorkerBase(Process, LoggingMixin): + """ + LocalWorkerBase implementation to run airflow commands. Executes the given + command and puts the result into a result queue when done, terminating execution. - def __init__(self, result_queue): - """ - :param result_queue: the queue to store result states tuples (key, State) - :type result_queue: multiprocessing.Queue - """ + :param result_queue: the queue to store result state + """ + def __init__(self, result_queue: 'Queue[TaskInstanceStateType]'): super().__init__() - self.daemon = True - self.result_queue = result_queue - self.key = None - self.command = None + self.daemon: bool = True + self.result_queue: 'Queue[TaskInstanceStateType]' = result_queue - def execute_work(self, key, command): + def execute_work(self, key: TaskInstanceKeyType, command: CommandType) -> None: """ Executes command received and stores result state in queue. - :param key: the key to identify the TI - :type key: tuple(dag_id, task_id, execution_date) + :param key: the key to identify the task instance :param command: the command to execute - :type command: str """ if key is None: return @@ -89,87 +92,146 @@ def execute_work(self, key, command): self.log.error("Failed to execute task %s.", str(e)) self.result_queue.put((key, state)) - def run(self): - self.execute_work(self.key, self.command) +class LocalWorker(LocalWorkerBase): + """ + Local worker that executes the task. -class QueuedLocalWorker(LocalWorker): + :param result_queue: queue where results of the tasks are put. 
+ :param key: key identifying task instance + :param command: Command to execute + """ + def __init__(self, + result_queue: 'Queue[TaskInstanceStateType]', + key: TaskInstanceKeyType, + command: CommandType): + super().__init__(result_queue) + self.key: TaskInstanceKeyType = key + self.command: CommandType = command - """LocalWorker implementation that is waiting for tasks from a queue and will - continue executing commands as they become available in the queue. It will terminate - execution once the poison token is found.""" + def run(self) -> None: + self.execute_work(key=self.key, command=self.command) - def __init__(self, task_queue, result_queue): + +class QueuedLocalWorker(LocalWorkerBase): + """ + LocalWorker implementation that is waiting for tasks from a queue and will + continue executing commands as they become available in the queue. + It will terminate execution once the poison token is found. + + :param task_queue: queue from which worker reads tasks + :param result_queue: queue where worker puts results after finishing tasks + """ + def __init__(self, + task_queue: 'Queue[ExecutorWorkType]', + result_queue: 'Queue[TaskInstanceStateType]'): super().__init__(result_queue=result_queue) self.task_queue = task_queue - def run(self): + def run(self) -> None: while True: key, command = self.task_queue.get() try: - if key is None: + if key is None or command is None: # Received poison pill, no more tasks to run break - self.execute_work(key, command) + self.execute_work(key=key, command=command) finally: self.task_queue.task_done() class LocalExecutor(BaseExecutor): """ - LocalExecutor executes tasks locally in parallel. It uses the - multiprocessing Python library and queues to parallelize the execution + LocalExecutor executes tasks locally in parallel. + It uses the multiprocessing Python library and queues to parallelize the execution of tasks. - """ - class _UnlimitedParallelism: - """Implements LocalExecutor with unlimited parallelism, starting one process - per each command to execute.""" + :param parallelism: how many parallel processes are run in the executor + """ + def __init__(self, parallelism: int = PARALLELISM): + super().__init__(parallelism=parallelism) + self.manager: Optional[SyncManager] = None + self.result_queue: Optional['Queue[TaskInstanceStateType]'] = None + self.workers: List[QueuedLocalWorker] = [] + self.workers_used: int = 0 + self.workers_active: int = 0 + self.impl: Optional[Union['LocalExecutor.UnlimitedParallelism', + 'LocalExecutor.LimitedParallelism']] = None + + class UnlimitedParallelism: + """ + Implements LocalExecutor with unlimited parallelism, starting one process + per each command to execute. - def __init__(self, executor): - """ - :param executor: the executor instance to implement. - :type executor: LocalExecutor - """ - self.executor = executor + :param executor: the executor instance to implement. 
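The ExecutorWorkType comment and QueuedLocalWorker.run above implement the classic poison-pill shutdown: a (None, None) item tells a worker to stop consuming. A stripped-down, stdlib-only sketch of the same idea; the names here are illustrative, not Airflow's:

    from multiprocessing import JoinableQueue, Process


    def worker(task_queue):
        """Consume (key, command) pairs until a (None, None) poison pill arrives."""
        while True:
            key, command = task_queue.get()
            try:
                if key is None or command is None:
                    break                      # poison pill: stop this worker
                print("running %s: %s" % (key, command))
            finally:
                task_queue.task_done()


    if __name__ == "__main__":
        queue = JoinableQueue()
        workers = [Process(target=worker, args=(queue,), daemon=True) for _ in range(2)]
        for process in workers:
            process.start()
        queue.put(("task-1", ["echo", "hello"]))
        for _ in workers:
            queue.put((None, None))            # one pill per worker
        queue.join()
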
+ """ + def __init__(self, executor: 'LocalExecutor'): + self.executor: 'LocalExecutor' = executor - def start(self): + def start(self) -> None: + """Starts the executor.""" self.executor.workers_used = 0 self.executor.workers_active = 0 - def execute_async(self, key, command): + # noinspection PyUnusedLocal + def execute_async(self, + key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None) -> None: \ + # pylint: disable=unused-argument # pragma: no cover """ - :param key: the key to identify the TI - :type key: tuple(dag_id, task_id, execution_date) + Executes task asynchronously. + + :param key: the key to identify the task instance :param command: the command to execute - :type command: str + :param queue: Name of the queue + :param executor_config: configuration for the executor """ - local_worker = LocalWorker(self.executor.result_queue) - local_worker.key = key - local_worker.command = command + assert self.executor.result_queue, NOT_STARTED_MESSAGE + local_worker = LocalWorker(self.executor.result_queue, key=key, command=command) self.executor.workers_used += 1 self.executor.workers_active += 1 local_worker.start() - def sync(self): + def sync(self) -> None: + """ + Sync will get called periodically by the heartbeat method. + """ + if not self.executor.result_queue: + raise AirflowException("Executor should be started first") while not self.executor.result_queue.empty(): results = self.executor.result_queue.get() self.executor.change_state(*results) self.executor.workers_active -= 1 - def end(self): + def end(self) -> None: + """ + This method is called when the caller is done submitting job and + wants to wait synchronously for the job submitted previously to be + all done. + """ while self.executor.workers_active > 0: self.executor.sync() - class _LimitedParallelism: - """Implements LocalExecutor with limited parallelism using a task queue to - coordinate work distribution.""" - - def __init__(self, executor): - self.executor = executor + class LimitedParallelism: + """ + Implements LocalExecutor with limited parallelism using a task queue to + coordinate work distribution. - def start(self): + :param executor: the executor instance to implement. + """ + def __init__(self, executor: 'LocalExecutor'): + self.executor: 'LocalExecutor' = executor + self.queue: Optional['Queue[ExecutorWorkType]'] = None + + def start(self) -> None: + """Starts limited parallelism implementation.""" + if not self.executor.manager: + raise AirflowException("Executor must be started!") self.queue = self.executor.manager.Queue() + if not self.executor.result_queue: + raise AirflowException("Executor must be started!") self.executor.workers = [ QueuedLocalWorker(self.queue, self.executor.result_queue) for _ in range(self.executor.parallelism) @@ -177,19 +239,31 @@ def start(self): self.executor.workers_used = len(self.executor.workers) - for w in self.executor.workers: - w.start() + for worker in self.executor.workers: + worker.start() - def execute_async(self, key, command): + # noinspection PyUnusedLocal + def execute_async(self, + key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None) -> None: \ + # pylint: disable=unused-argument # pragma: no cover """ - :param key: the key to identify the TI - :type key: tuple(dag_id, task_id, execution_date) + Executes task asynchronously. 
+ + :param key: the key to identify the task instance :param command: the command to execute - :type command: str - """ + :param queue: name of the queue + :param executor_config: configuration for the executor + """ + assert self.queue, NOT_STARTED_MESSAGE self.queue.put((key, command)) def sync(self): + """ + Sync will get called periodically by the heartbeat method. + """ while True: try: results = self.executor.result_queue.get_nowait() @@ -201,7 +275,7 @@ def sync(self): break def end(self): - # Sending poison pill to all worker + """Ends the executor. Sends the poison pill to all workers.""" for _ in self.executor.workers: self.queue.put((None, None)) @@ -209,23 +283,42 @@ def end(self): self.queue.join() self.executor.sync() - def start(self): - self.manager = multiprocessing.Manager() + def start(self) -> None: + """Starts the executor""" + self.manager = Manager() self.result_queue = self.manager.Queue() self.workers = [] self.workers_used = 0 self.workers_active = 0 - self.impl = (LocalExecutor._UnlimitedParallelism(self) if self.parallelism == 0 - else LocalExecutor._LimitedParallelism(self)) + self.impl = (LocalExecutor.UnlimitedParallelism(self) if self.parallelism == 0 + else LocalExecutor.LimitedParallelism(self)) self.impl.start() - def execute_async(self, key, command, queue=None, executor_config=None): - self.impl.execute_async(key=key, command=command) + def execute_async(self, key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None) -> None: + """Execute asynchronously.""" + assert self.impl, NOT_STARTED_MESSAGE + self.impl.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) - def sync(self): + def sync(self) -> None: + """ + Sync will get called periodically by the heartbeat method. + """ + assert self.impl, NOT_STARTED_MESSAGE self.impl.sync() - def end(self): + def end(self) -> None: + """ + Ends the executor. + :return: + """ + assert self.impl, NOT_STARTED_MESSAGE + assert self.manager, NOT_STARTED_MESSAGE self.impl.end() self.manager.shutdown() + + def terminate(self): + """Terminate the executor is not doing anything.""" diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py index bb9303c7ccef5..a0d58b7350c7f 100644 --- a/airflow/executors/sequential_executor.py +++ b/airflow/executors/sequential_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -16,10 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Sequential executor.""" import subprocess +from typing import Any, Optional -from airflow.executors.base_executor import BaseExecutor +from airflow.executors.base_executor import BaseExecutor, CommandType +from airflow.models.taskinstance import TaskInstanceKeyType from airflow.utils.state import State @@ -32,14 +33,19 @@ class SequentialExecutor(BaseExecutor): Since we want airflow to work out of the box, it defaults to this SequentialExecutor alongside sqlite as you first install it. 
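For reference, the LocalExecutor.start() hunk above picks the strategy from the parallelism value: 0 selects UnlimitedParallelism (a new process per task), any positive value selects LimitedParallelism with a fixed pool of QueuedLocalWorker processes. A usage sketch, assuming a working Airflow environment at this revision; normally the scheduler, not user code, drives these calls:

    from airflow.executors.local_executor import LocalExecutor

    unlimited = LocalExecutor(parallelism=0)   # -> LocalExecutor.UnlimitedParallelism
    limited = LocalExecutor(parallelism=8)     # -> LocalExecutor.LimitedParallelism

    limited.start()    # creates the Manager, the result queue and the worker pool
    # ... execute_async()/sync() are normally called from the scheduler's heartbeat ...
    limited.end()      # sends one (None, None) poison pill per worker, then shuts the manager down
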
""" + def __init__(self): super().__init__() self.commands_to_run = [] - def execute_async(self, key, command, queue=None, executor_config=None): - self.commands_to_run.append((key, command,)) + def execute_async(self, + key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None) -> None: + self.commands_to_run.append((key, command)) - def sync(self): + def sync(self) -> None: for key, command in self.commands_to_run: self.log.info("Executing command: %s", command) @@ -53,4 +59,8 @@ def sync(self): self.commands_to_run = [] def end(self): + """End the executor.""" self.heartbeat() + + def terminate(self): + """Terminate the executor is not doing anything.""" diff --git a/airflow/hooks/__init__.py b/airflow/hooks/__init__.py index 2020e16b02b78..48cdbdf0aa3e1 100644 --- a/airflow/hooks/__init__.py +++ b/airflow/hooks/__init__.py @@ -16,18 +16,4 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# pylint: disable=missing-docstring - -import sys - -# Imports the hooks dynamically while keeping the package API clean, -# abstracting the underlying modules - - -def _integrate_plugins(): - """Integrate plugins to the context""" - from airflow.plugins_manager import hooks_modules - for hooks_module in hooks_modules: - sys.modules[hooks_module.__name__] = hooks_module - globals()[hooks_module._name] = hooks_module # pylint: disable=protected-access +"""Hooks.""" diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py index e0360befcf795..e33237d51462b 100644 --- a/airflow/jobs/backfill_job.py +++ b/airflow/jobs/backfill_job.py @@ -24,11 +24,13 @@ from sqlalchemy.orm.session import Session, make_transient -from airflow import executors, models +from airflow import models from airflow.exceptions import ( AirflowException, DagConcurrencyLimitReached, NoAvailablePoolSlot, PoolNotFound, TaskConcurrencyLimitReached, ) +from airflow.executors.local_executor import LocalExecutor +from airflow.executors.sequential_executor import SequentialExecutor from airflow.jobs.base_job import BaseJob from airflow.models import DAG, DagPickle, DagRun from airflow.ti_deps.dep_context import BACKFILL_QUEUED_DEPS, DepContext @@ -487,8 +489,7 @@ def _per_task_process(task, key, ti, session=None): session.merge(ti) cfg_path = None - if executor.__class__ in (executors.LocalExecutor, - executors.SequentialExecutor): + if executor.__class__ in (LocalExecutor, SequentialExecutor): cfg_path = tmp_configuration_copy() executor.queue_task_instance( @@ -740,8 +741,7 @@ def _execute(self, session=None): # picklin' pickle_id = None - if not self.donot_pickle and self.executor.__class__ not in ( - executors.LocalExecutor, executors.SequentialExecutor): + if not self.donot_pickle and self.executor.__class__ not in (LocalExecutor, SequentialExecutor): pickle = DagPickle(self.dag) session.add(pickle) session.commit() diff --git a/airflow/jobs/base_job.py b/airflow/jobs/base_job.py index 874c0c6f058ce..b554b6ecb0b96 100644 --- a/airflow/jobs/base_job.py +++ b/airflow/jobs/base_job.py @@ -26,9 +26,10 @@ from sqlalchemy.exc import OperationalError from sqlalchemy.orm.session import make_transient -from airflow import executors, models +from airflow import models from airflow.configuration import conf from airflow.exceptions import AirflowException +from airflow.executors.executor_loader import ExecutorLoader from airflow.models.base import ID_LEN, Base from airflow.stats 
import Stats from airflow.utils import helpers, timezone @@ -79,7 +80,7 @@ def __init__( heartrate=None, *args, **kwargs): self.hostname = get_hostname() - self.executor = executor or executors.get_default_executor() + self.executor = executor or ExecutorLoader.get_default_executor() self.executor_class = executor.__class__.__name__ self.start_date = timezone.utcnow() self.latest_heartbeat = timezone.utcnow() diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 2640f1232efe5..518ccb1b75dc2 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -35,21 +35,24 @@ from sqlalchemy import and_, func, not_, or_ from sqlalchemy.orm.session import make_transient -from airflow import executors, models, settings +from airflow import models, settings from airflow.configuration import conf from airflow.exceptions import AirflowException +from airflow.executors.local_executor import LocalExecutor +from airflow.executors.sequential_executor import SequentialExecutor from airflow.jobs.base_job import BaseJob from airflow.models import DAG, DagRun, SlaMiss, errors +from airflow.models.taskinstance import SimpleTaskInstance from airflow.stats import Stats from airflow.ti_deps.dep_context import SCHEDULEABLE_STATES, SCHEDULED_DEPS, DepContext from airflow.ti_deps.deps.pool_slots_available_dep import STATES_TO_COUNT_AS_RUNNING from airflow.utils import asciiart, helpers, timezone from airflow.utils.dag_processing import ( - AbstractDagFileProcessor, DagFileProcessorAgent, SimpleDag, SimpleDagBag, SimpleTaskInstance, - list_py_file_paths, + AbstractDagFileProcessor, DagFileProcessorAgent, SimpleDag, SimpleDagBag, ) from airflow.utils.db import provide_session from airflow.utils.email import get_email_address_list, send_email +from airflow.utils.file import list_py_file_paths from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context from airflow.utils.state import State @@ -64,7 +67,7 @@ class DagFileProcessor(AbstractDagFileProcessor, LoggingMixin): :param dag_id_white_list: If specified, only look at these DAG ID's :type dag_id_white_list: list[unicode] :param zombies: zombie task instances to kill - :type zombies: list[airflow.utils.dag_processing.SimpleTaskInstance] + :type zombies: list[airflow.models.taskinstance.SimpleTaskInstance] """ # Counter that increments every time an instance of this class is created @@ -116,7 +119,7 @@ def _run_file_processor(result_channel, :param thread_name: the name to use for the process that is launched :type thread_name: unicode :param zombies: zombie task instances to kill - :type zombies: list[airflow.utils.dag_processing.SimpleTaskInstance] + :type zombies: list[airflow.models.taskinstance.SimpleTaskInstance] :return: the process that was launched :rtype: multiprocessing.Process """ @@ -1013,7 +1016,7 @@ def _change_state_for_executable_task_instances(self, task_instances, :type task_instances: list[airflow.models.TaskInstance] :param acceptable_states: Filters the TaskInstances updated to be in these states :type acceptable_states: Iterable[State] - :rtype: list[airflow.utils.dag_processing.SimpleTaskInstance] + :rtype: list[airflow.models.taskinstance.SimpleTaskInstance] """ if len(task_instances) == 0: session.commit() @@ -1282,8 +1285,7 @@ def _execute(self): # DAGs can be pickled for easier remote execution by some executors pickle_dags = False - if self.do_pickle and self.executor.__class__ not in \ - (executors.LocalExecutor, executors.SequentialExecutor): + if self.do_pickle and 
self.executor.__class__ not in (LocalExecutor, SequentialExecutor): pickle_dags = True self.log.info("Processing each file at most %s times", self.num_runs) @@ -1494,7 +1496,7 @@ def process_file(self, file_path, zombies, pickle_dags=False, session=None): :param file_path: the path to the Python file that should be executed :type file_path: unicode :param zombies: zombie task instances to kill. - :type zombies: list[airflow.utils.dag_processing.SimpleTaskInstance] + :type zombies: list[airflow.models.taskinstance.SimpleTaskInstance] :param pickle_dags: whether serialize the DAGs found in the file and save them to the db :type pickle_dags: bool diff --git a/airflow/kubernetes/kube_client.py b/airflow/kubernetes/kube_client.py index 2d734d9e1a111..b17e9a007ac09 100644 --- a/airflow/kubernetes/kube_client.py +++ b/airflow/kubernetes/kube_client.py @@ -68,7 +68,7 @@ def _get_client_with_patched_configuration(cfg: Optional[Configuration]) -> clie def get_kube_client(in_cluster: bool = conf.getboolean('kubernetes', 'in_cluster'), cluster_context: Optional[str] = None, - config_file: Optional[str] = None): + config_file: Optional[str] = None) -> client.CoreV1Api: """ Retrieves Kubernetes client diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py index f8d9b99d038a7..38c750085ca4a 100644 --- a/airflow/kubernetes/pod_generator.py +++ b/airflow/kubernetes/pod_generator.py @@ -26,12 +26,10 @@ import kubernetes.client.models as k8s -from airflow.executors import Executors - class PodDefaults: """ - Static defaults for the PodGenerator + Static defaults for Pods """ XCOM_MOUNT_PATH = '/airflow/xcom' SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar' @@ -227,8 +225,9 @@ def from_obj(obj) -> k8s.V1Pod: raise TypeError( 'Cannot convert a non-dictionary or non-PodGenerator ' 'object into a KubernetesExecutorConfig') - - namespaced = obj.get(Executors.KubernetesExecutor, {}) + # We do not want to extract constant here from ExecutorLoader because it is just + # A name in dictionary rather than executor selection mechanism and it causes cyclic import + namespaced = obj.get("KubernetesExecutor", {}) resources = namespaced.get('resources') diff --git a/airflow/kubernetes/pod_launcher.py b/airflow/kubernetes/pod_launcher.py index 2d076939b77c2..686a9176a4fe6 100644 --- a/airflow/kubernetes/pod_launcher.py +++ b/airflow/kubernetes/pod_launcher.py @@ -27,7 +27,7 @@ from kubernetes.stream import stream as kubernetes_stream from requests.exceptions import BaseHTTPError -from airflow import AirflowException +from airflow.exceptions import AirflowException from airflow.kubernetes.pod_generator import PodDefaults from airflow.settings import pod_mutation_hook from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/macros/__init__.py b/airflow/macros/__init__.py index 5582b3c3a3d26..6a9fa7e8fd3f6 100644 --- a/airflow/macros/__init__.py +++ b/airflow/macros/__init__.py @@ -85,12 +85,3 @@ def datetime_diff_for_humans(dt, since=None): import pendulum return pendulum.instance(dt).diff_for_humans(since) - - -def _integrate_plugins(): - """Integrate plugins to the context""" - import sys - from airflow.plugins_manager import macros_modules - for macros_module in macros_modules: - sys.modules[macros_module.__name__] = macros_module - globals()[macros_module._name] = macros_module # pylint: disable=protected-access diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 3a59cfb804564..07a06f683f6fb 100644 --- a/airflow/models/baseoperator.py +++ 
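The pod_generator hunk above looks the executor configuration up by the literal "KubernetesExecutor" key instead of going through the executors module, so task-level overrides keep the same dictionary shape. The shape below is illustrative only; the exact keys accepted under 'resources' depend on PodGenerator at this revision:

    # Passed as executor_config= to an operator; PodGenerator.from_obj() reads the
    # "KubernetesExecutor" entry as a plain dictionary key and then its 'resources'.
    executor_config = {
        "KubernetesExecutor": {
            "resources": {},   # resource overrides for the worker pod go here
        },
    }
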
b/airflow/models/baseoperator.py @@ -1090,7 +1090,7 @@ class BaseOperatorLink(metaclass=ABCMeta): Abstract base class that defines how we get an operator link. """ - operators = [] # type: List[Type[BaseOperator]] + operators: List[Type[BaseOperator]] = [] """ This property will be used by Airflow Plugins to find the Operators to which you want to assign this Operator Link diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 1abde55e6e0a4..17bd85920d846 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -38,7 +38,6 @@ from airflow.configuration import conf from airflow.dag.base_dag import BaseDag from airflow.exceptions import AirflowDagCycleException, AirflowException, DagNotFound, DuplicateTaskIdFound -from airflow.executors import LocalExecutor, get_default_executor from airflow.models.base import ID_LEN, Base from airflow.models.baseoperator import BaseOperator from airflow.models.dagbag import DagBag @@ -47,9 +46,9 @@ from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.settings import MIN_SERIALIZED_DAG_UPDATE_INTERVAL, STORE_SERIALIZED_DAGS from airflow.utils import timezone -from airflow.utils.dag_processing import correct_maybe_zipped from airflow.utils.dates import cron_presets, date_range as utils_date_range from airflow.utils.db import provide_session +from airflow.utils.file import correct_maybe_zipped from airflow.utils.helpers import validate_key from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.sqlalchemy import Interval, UtcDateTime @@ -1254,9 +1253,11 @@ def run( """ from airflow.jobs import BackfillJob if not executor and local: + from airflow.executors.local_executor import LocalExecutor executor = LocalExecutor() elif not executor: - executor = get_default_executor() + from airflow.executors.executor_loader import ExecutorLoader + executor = ExecutorLoader.get_default_executor() job = BackfillJob( self, start_date=start_date, diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index ec6e1a20ee4a0..7df03d9de7171 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -33,11 +33,10 @@ from airflow.configuration import conf from airflow.dag.base_dag import BaseDagBag from airflow.exceptions import AirflowDagCycleException -from airflow.executors import get_default_executor from airflow.stats import Stats from airflow.utils import timezone -from airflow.utils.dag_processing import correct_maybe_zipped, list_py_file_paths from airflow.utils.db import provide_session +from airflow.utils.file import correct_maybe_zipped from airflow.utils.helpers import pprinttable from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.timeout import timeout @@ -88,7 +87,8 @@ def __init__( # do not use default arg in signature, to fix import cycle on plugin load if executor is None: - executor = get_default_executor() + from airflow.executors.executor_loader import ExecutorLoader + executor = ExecutorLoader.get_default_executor() dag_folder = dag_folder or settings.DAGS_FOLDER self.dag_folder = dag_folder self.dags = {} @@ -317,9 +317,7 @@ def kill_zombies(self, zombies, session=None): had a heartbeat for too long, in the current DagBag. :param zombies: zombie task instances to kill. - :type zombies: airflow.utils.dag_processing.SimpleTaskInstance :param session: DB session. 
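The DAG.run() and DagBag hunks above break the import cycle by deferring the executor imports into the function body and routing the default through the new ExecutorLoader. The same selection logic, shown on its own as a sketch that assumes a configured Airflow environment:

    def _pick_executor(local: bool = False):
        """Mirrors the selection above; the imports stay inside the function so
        the calling module never imports airflow.executors at load time."""
        if local:
            from airflow.executors.local_executor import LocalExecutor
            return LocalExecutor()
        from airflow.executors.executor_loader import ExecutorLoader
        return ExecutorLoader.get_default_executor()
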
- :type session: sqlalchemy.orm.session.Session """ from airflow.models.taskinstance import TaskInstance # Avoid circular import @@ -406,8 +404,8 @@ def collect_dags( FileLoadStat = namedtuple( 'FileLoadStat', "file duration dag_num task_num dags") + from airflow.utils.file import correct_maybe_zipped, list_py_file_paths dag_folder = correct_maybe_zipped(dag_folder) - for filepath in list_py_file_paths(dag_folder, safe_mode=safe_mode, include_examples=include_examples): try: diff --git a/airflow/models/kubernetes.py b/airflow/models/kubernetes.py index 50c205f0e8358..6535e2389a395 100644 --- a/airflow/models/kubernetes.py +++ b/airflow/models/kubernetes.py @@ -20,6 +20,7 @@ import uuid from sqlalchemy import Boolean, Column, String, true as sqltrue +from sqlalchemy.orm import Session from airflow.models.base import Base from airflow.utils.db import provide_session @@ -32,13 +33,13 @@ class KubeResourceVersion(Base): @staticmethod @provide_session - def get_current_resource_version(session=None): + def get_current_resource_version(session: Session = None) -> str: (resource_version,) = session.query(KubeResourceVersion.resource_version).one() return resource_version @staticmethod @provide_session - def checkpoint_resource_version(resource_version, session=None): + def checkpoint_resource_version(resource_version, session: Session = None) -> None: if resource_version: session.query(KubeResourceVersion).update({ KubeResourceVersion.resource_version: resource_version @@ -47,7 +48,7 @@ def checkpoint_resource_version(resource_version, session=None): @staticmethod @provide_session - def reset_resource_version(session=None): + def reset_resource_version(session: Session = None) -> str: session.query(KubeResourceVersion).update({ KubeResourceVersion.resource_version: '0' }) @@ -62,7 +63,7 @@ class KubeWorkerIdentifier(Base): @staticmethod @provide_session - def get_or_create_current_kube_worker_uuid(session=None): + def get_or_create_current_kube_worker_uuid(session: Session = None) -> str: (worker_uuid,) = session.query(KubeWorkerIdentifier.worker_uuid).one() if worker_uuid == '': worker_uuid = str(uuid.uuid4()) @@ -71,7 +72,7 @@ def get_or_create_current_kube_worker_uuid(session=None): @staticmethod @provide_session - def checkpoint_kube_worker_uuid(worker_uuid, session=None): + def checkpoint_kube_worker_uuid(worker_uuid: str, session: Session = None) -> None: if worker_uuid: session.query(KubeWorkerIdentifier).update({ KubeWorkerIdentifier.worker_uuid: worker_uuid diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 4caef644882bd..716d51d1b7349 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -26,7 +25,7 @@ import signal import time from datetime import datetime, timedelta -from typing import Any, Dict, Iterable, Optional, Union +from typing import Any, Dict, Iterable, Optional, Tuple, Union from urllib.parse import quote import dill @@ -121,6 +120,11 @@ def clear_task_instances(tis, dr.start_date = timezone.utcnow() +# Key used to identify task instance +# Tuple of: dag_id, task_id, execution_date, try_number +TaskInstanceKeyType = Tuple[str, str, datetime, int] + + class TaskInstance(Base, LoggingMixin): """ Task instances store the state of a task instance. 
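TaskInstanceKeyType above pins down what a task-instance key is: a plain 4-tuple of dag_id, task_id, execution_date and try_number. A minimal illustration with made-up values:

    from datetime import datetime
    from typing import Tuple

    TaskInstanceKeyType = Tuple[str, str, datetime, int]   # same alias as introduced above

    key: TaskInstanceKeyType = ("example_dag", "extract", datetime(2019, 11, 18), 1)
    dag_id, task_id, execution_date, try_number = key
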
This table is the @@ -213,7 +217,7 @@ def try_number(self): Return the try number that this task number will be when it is actually run. - If the TI is currently running, this will match the column in the + If the TaskInstance is currently running, this will match the column in the database, in all other cases this will be incremented. """ # This is designed so that task logs end up in the right file. @@ -395,11 +399,10 @@ def current_state(self, session=None): we use and looking up the state becomes part of the session, otherwise a new session is used. """ - TI = TaskInstance - ti = session.query(TI).filter( - TI.dag_id == self.dag_id, - TI.task_id == self.task_id, - TI.execution_date == self.execution_date, + ti = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.execution_date == self.execution_date, ).all() if ti: state = ti[0].state @@ -418,7 +421,7 @@ def error(self, session=None): session.commit() @provide_session - def refresh_from_db(self, session=None, lock_for_update=False, refresh_executor_config=False): + def refresh_from_db(self, session=None, lock_for_update=False, refresh_executor_config=False) -> None: """ Refreshes the task instance from the database based on the primary key @@ -429,12 +432,11 @@ def refresh_from_db(self, session=None, lock_for_update=False, refresh_executor_ lock the TaskInstance (issuing a FOR UPDATE clause) until the session is committed. """ - TI = TaskInstance - qry = session.query(TI).filter( - TI.dag_id == self.dag_id, - TI.task_id == self.task_id, - TI.execution_date == self.execution_date) + qry = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.execution_date == self.execution_date) if lock_for_update: ti = qry.with_for_update().first() @@ -468,7 +470,7 @@ def clear_xcom_data(self, session=None): session.commit() @property - def key(self): + def key(self) -> TaskInstanceKeyType: """ Returns a tuple that identifies the task instance uniquely """ @@ -528,7 +530,7 @@ def _get_previous_ti( # LEGACY: most likely running from unit tests if not dr: - # Means that this TI is NOT being run from a DR, but from a catchup + # Means that this TaskInstance is NOT being run from a DR, but from a catchup previous_scheduled_date = dag.previous_schedule(self.execution_date) if not previous_scheduled_date: return None @@ -742,7 +744,7 @@ def _check_and_change_state_before_execution( :type ignore_all_deps: bool :param ignore_depends_on_past: Ignore depends_on_past DAG attribute :type ignore_depends_on_past: bool - :param ignore_task_deps: Don't check the dependencies of this TI's task + :param ignore_task_deps: Don't check the dependencies of this TaskInstance's task :type ignore_task_deps: bool :param ignore_ti_state: Disregards previous task instance state :type ignore_ti_state: bool @@ -1056,7 +1058,7 @@ def handle_failure(self, error, test_mode=False, context=None, session=None): # Let's go deeper try: - # Since this function is called only when the TI state is running, + # Since this function is called only when the TaskInstance state is running, # try_number contains the current try_number (not the next). We # only mark task instance as FAILED if the next task instance # try_number exceeds the max_tries. 
@@ -1379,12 +1381,11 @@ def xcom_pull( @provide_session def get_num_running_task_instances(self, session): - TI = TaskInstance # .count() is inefficient return session.query(func.count()).filter( - TI.dag_id == self.dag_id, - TI.task_id == self.task_id, - TI.state == State.RUNNING + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.state == State.RUNNING ).scalar() def init_run_context(self, raw=False): @@ -1393,3 +1394,108 @@ def init_run_context(self, raw=False): """ self.raw = raw self._set_context(self) + + +# State of the task instance. +# Stores string version of the task state. +TaskInstanceStateType = Tuple[TaskInstanceKeyType, str] + + +class SimpleTaskInstance: + """ + Simplified Task Instance. + + Used to send data between processes via Queues. + """ + def __init__(self, ti: TaskInstance): + self._dag_id: str = ti.dag_id + self._task_id: str = ti.task_id + self._execution_date: datetime = ti.execution_date + self._start_date: datetime = ti.start_date + self._end_date: datetime = ti.end_date + self._try_number: int = ti.try_number + self._state: str = ti.state + self._executor_config: Any = ti.executor_config + self._run_as_user: Optional[str] = None + if hasattr(ti, 'run_as_user'): + self._run_as_user = ti.run_as_user + self._pool: Optional[str] = None + if hasattr(ti, 'pool'): + self._pool = ti.pool + self._priority_weight: Optional[int] = None + if hasattr(ti, 'priority_weight'): + self._priority_weight = ti.priority_weight + self._queue: str = ti.queue + self._key = ti.key + + # pylint: disable=missing-docstring + @property + def dag_id(self) -> str: + return self._dag_id + + @property + def task_id(self) -> str: + return self._task_id + + @property + def execution_date(self) -> datetime: + return self._execution_date + + @property + def start_date(self) -> datetime: + return self._start_date + + @property + def end_date(self) -> datetime: + return self._end_date + + @property + def try_number(self) -> int: + return self._try_number + + @property + def state(self) -> str: + return self._state + + @property + def pool(self) -> Any: + return self._pool + + @property + def priority_weight(self) -> Optional[int]: + return self._priority_weight + + @property + def queue(self) -> str: + return self._queue + + @property + def key(self) -> TaskInstanceKeyType: + return self._key + + @property + def executor_config(self): + return self._executor_config + + @provide_session + def construct_task_instance(self, session=None, lock_for_update=False) -> TaskInstance: + """ + Construct a TaskInstance from the database based on the primary key + + :param session: DB session. + :param lock_for_update: if True, indicates that the database should + lock the TaskInstance (issuing a FOR UPDATE clause) until the + session is committed. + :return: the task instance constructed + """ + + qry = session.query(TaskInstance).filter( + TaskInstance.dag_id == self._dag_id, + TaskInstance.task_id == self._task_id, + TaskInstance.execution_date == self._execution_date) + + if lock_for_update: + ti = qry.with_for_update().first() + else: + ti = qry.first() + return ti diff --git a/airflow/operators/__init__.py b/airflow/operators/__init__.py index 257e0a7b53b50..99eb3bf2c3fe8 100644 --- a/airflow/operators/__init__.py +++ b/airflow/operators/__init__.py @@ -16,14 +16,4 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
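SimpleTaskInstance above is the picklable snapshot that crosses process boundaries (its previous home in airflow.utils.dag_processing is removed further down in this patch). A hedged usage sketch: `ti` is assumed to be an existing airflow.models.TaskInstance, and a configured metadata database is assumed for construct_task_instance():

    from queue import Queue

    from airflow.models.taskinstance import SimpleTaskInstance

    simple_ti = SimpleTaskInstance(ti)              # `ti` is an existing TaskInstance (assumed)
    queue = Queue()
    queue.put((simple_ti.key, simple_ti.state))     # only plain tuples/strings cross the queue

    # A consumer can rebuild the ORM object from the primary key when it needs it:
    full_ti = simple_ti.construct_task_instance(lock_for_update=False)
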
- -# pylint: disable=missing-docstring - - -def _integrate_plugins(): - """Integrate plugins to the context""" - import sys - from airflow.plugins_manager import operators_modules - for operators_module in operators_modules: - sys.modules[operators_module.__name__] = operators_module - globals()[operators_module._name] = operators_module # pylint: disable=protected-access +"""Operators.""" diff --git a/airflow/operators/mssql_to_hive.py b/airflow/operators/mssql_to_hive.py index a7901bfe60774..8dbfd2acff969 100644 --- a/airflow/operators/mssql_to_hive.py +++ b/airflow/operators/mssql_to_hive.py @@ -106,9 +106,9 @@ def type_map(cls, mssql_type: int) -> str: Maps MsSQL type to Hive type. """ map_dict = { - pymssql.BINARY.value: 'INT', - pymssql.DECIMAL.value: 'FLOAT', - pymssql.NUMBER.value: 'INT', + pymssql.BINARY.value: 'INT', # pylint: disable=c-extension-no-member + pymssql.DECIMAL.value: 'FLOAT', # pylint: disable=c-extension-no-member + pymssql.NUMBER.value: 'INT', # pylint: disable=c-extension-no-member } return map_dict.get(mssql_type, 'STRING') diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index 6dfa9dbe7b6c9..3b34bf0ddcdb3 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -16,17 +15,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -import imp +"""Manages all plugins.""" +# noinspection PyDeprecation +import imp # pylint: disable=deprecated-module import inspect import os import re -from typing import Any, List +import sys +from typing import Any, Callable, List, Optional import pkg_resources from airflow import settings -from airflow.models.baseoperator import BaseOperatorLink from airflow.utils.log.logging_mixin import LoggingMixin log = LoggingMixin().log @@ -35,28 +35,29 @@ class AirflowPluginException(Exception): - pass + """Exception when loading plugin.""" class AirflowPlugin: - name = None # type: str - operators = [] # type: List[Any] - sensors = [] # type: List[Any] - hooks = [] # type: List[Any] - executors = [] # type: List[Any] - macros = [] # type: List[Any] - admin_views = [] # type: List[Any] - flask_blueprints = [] # type: List[Any] - menu_links = [] # type: List[Any] - appbuilder_views = [] # type: List[Any] - appbuilder_menu_items = [] # type: List[Any] + """Class used to define AirflowPlugin.""" + name: Optional[str] = None + operators: List[Any] = [] + sensors: List[Any] = [] + hooks: List[Any] = [] + executors: List[Any] = [] + macros: List[Any] = [] + admin_views: List[Any] = [] + flask_blueprints: List[Any] = [] + menu_links: List[Any] = [] + appbuilder_views: List[Any] = [] + appbuilder_menu_items: List[Any] = [] # A function that validate the statsd stat name, apply changes # to the stat name if necessary and return the transformed stat name. # # The function should have the following signature: # def func_name(stat_name: str) -> str: - stat_name_handler = None # type: Any + stat_name_handler: Optional[Callable[[str], str]] = None # A list of global operator extra links that can redirect users to # external systems. These extra links will be available on the @@ -64,16 +65,17 @@ class AirflowPlugin: # # Note: the global operator extra link can be overridden at each # operator level. 
- global_operator_extra_links = [] # type: List[BaseOperatorLink] + global_operator_extra_links: List[Any] = [] # A list of operator extra links to override or add operator links # to existing Airflow Operators. # These extra links will be available on the task page in form of # buttons. - operator_extra_links = [] # type: List[BaseOperatorLink] + operator_extra_links: List[Any] = [] @classmethod def validate(cls): + """Validates that plugin has a name.""" if not cls.name: raise AirflowPluginException("Your plugin needs a name.") @@ -134,14 +136,13 @@ def is_valid_plugin(plugin_obj, existing_plugins): norm_pattern = re.compile(r'[/|.]') -if settings.PLUGINS_FOLDER is None: - raise AirflowPluginException("Plugins folder is not set") +assert settings.PLUGINS_FOLDER, "Plugins folder is not set" # Crawl through the plugins folder to find AirflowPlugin derivatives for root, dirs, files in os.walk(settings.PLUGINS_FOLDER, followlinks=True): for f in files: + filepath = os.path.join(root, f) try: - filepath = os.path.join(root, f) if not os.path.isfile(filepath): continue mod_name, file_ext = os.path.splitext( @@ -157,11 +158,11 @@ def is_valid_plugin(plugin_obj, existing_plugins): for obj in list(m.__dict__.values()): if is_valid_plugin(obj, plugins): plugins.append(obj) - - except Exception as e: + except Exception as e: # pylint: disable=broad-except log.exception(e) - log.error('Failed to import plugin %s', filepath) - import_errors[filepath] = str(e) + path = filepath or str(f) + log.error('Failed to import plugin %s', path) + import_errors[path] = str(e) plugins = load_entrypoint_plugins( pkg_resources.iter_entry_points('airflow.plugins'), @@ -169,14 +170,18 @@ def is_valid_plugin(plugin_obj, existing_plugins): ) -def make_module(name, objects): +# pylint: disable=protected-access +# noinspection Mypy,PyTypeHints +def make_module(name: str, objects: List[Any]): + """Creates new module.""" log.debug('Creating module %s', name) name = name.lower() module = imp.new_module(name) - module._name = name.split('.')[-1] - module._objects = objects + module._name = name.split('.')[-1] # type: ignore + module._objects = objects # type: ignore module.__dict__.update((o.__name__, o) for o in objects) return module +# pylint: enable=protected-access # Plugin components to integrate as modules @@ -187,26 +192,29 @@ def make_module(name, objects): macros_modules = [] # Plugin components to integrate directly -admin_views = [] # type: List[Any] -flask_blueprints = [] # type: List[Any] -menu_links = [] # type: List[Any] -flask_appbuilder_views = [] # type: List[Any] -flask_appbuilder_menu_links = [] # type: List[Any] -stat_name_handler = None # type: Any -global_operator_extra_links = [] # type: List[BaseOperatorLink] -operator_extra_links = [] # type: List[BaseOperatorLink] +admin_views: List[Any] = [] +flask_blueprints: List[Any] = [] +menu_links: List[Any] = [] +flask_appbuilder_views: List[Any] = [] +flask_appbuilder_menu_links: List[Any] = [] +stat_name_handler: Any = None +global_operator_extra_links: List[Any] = [] +operator_extra_links: List[Any] = [] stat_name_handlers = [] for p in plugins: + if not p.name: + raise AirflowPluginException("Plugin name is missing.") + plugin_name: str = p.name operators_modules.append( - make_module('airflow.operators.' + p.name, p.operators + p.sensors)) + make_module('airflow.operators.' + plugin_name, p.operators + p.sensors)) sensors_modules.append( - make_module('airflow.sensors.' + p.name, p.sensors) + make_module('airflow.sensors.' 
+ plugin_name, p.sensors) ) - hooks_modules.append(make_module('airflow.hooks.' + p.name, p.hooks)) + hooks_modules.append(make_module('airflow.hooks.' + plugin_name, p.hooks)) executors_modules.append( - make_module('airflow.executors.' + p.name, p.executors)) - macros_modules.append(make_module('airflow.macros.' + p.name, p.macros)) + make_module('airflow.executors.' + plugin_name, p.executors)) + macros_modules.append(make_module('airflow.macros.' + plugin_name, p.macros)) admin_views.extend(p.admin_views) menu_links.extend(p.menu_links) @@ -230,3 +238,52 @@ def make_module(name, objects): 'is not allowed.'.format(stat_name_handlers)) stat_name_handler = stat_name_handlers[0] if len(stat_name_handlers) == 1 else None + + +def integrate_operator_plugins() -> None: + """Integrate operators plugins to the context""" + for operators_module in operators_modules: + sys.modules[operators_module.__name__] = operators_module + # noinspection PyProtectedMember + globals()[operators_module._name] = operators_module # pylint: disable=protected-access + + +def integrate_sensor_plugins() -> None: + """Integrate sensor plugins to the context""" + for sensors_module in sensors_modules: + sys.modules[sensors_module.__name__] = sensors_module + # noinspection PyProtectedMember + globals()[sensors_module._name] = sensors_module # pylint: disable=protected-access + + +def integrate_hook_plugins() -> None: + """Integrate hook plugins to the context""" + for hooks_module in hooks_modules: + sys.modules[hooks_module.__name__] = hooks_module + # noinspection PyProtectedMember + globals()[hooks_module._name] = hooks_module # pylint: disable=protected-access + + +def integrate_executor_plugins() -> None: + """Integrate executor plugins to the context.""" + for executors_module in executors_modules: + sys.modules[executors_module.__name__] = executors_module + # noinspection PyProtectedMember + globals()[executors_module._name] = executors_module # pylint: disable=protected-access + + +def integrate_macro_plugins() -> None: + """Integrate macro plugins to the context""" + for macros_module in macros_modules: + sys.modules[macros_module.__name__] = macros_module + # noinspection PyProtectedMember + globals()[macros_module._name] = macros_module # pylint: disable=protected-access + + +def integrate_plugins() -> None: + """Integrates all types of plugins.""" + integrate_operator_plugins() + integrate_sensor_plugins() + integrate_hook_plugins() + integrate_executor_plugins() + integrate_macro_plugins() diff --git a/airflow/sensors/__init__.py b/airflow/sensors/__init__.py index 945cd00d3e613..bcdaf4c3c8da1 100644 --- a/airflow/sensors/__init__.py +++ b/airflow/sensors/__init__.py @@ -17,14 +17,4 @@ # specific language governing permissions and limitations # under the License. # - -# pylint: disable=missing-docstring - - -def _integrate_plugins(): - """Integrate plugins to the context""" - import sys - from airflow.plugins_manager import sensors_modules - for sensors_module in sensors_modules: - sys.modules[sensors_module.__name__] = sensors_module - globals()[sensors_module._name] = sensors_module # pylint: disable=protected-access +"""Sensors.""" diff --git a/airflow/settings.py b/airflow/settings.py index c97a83d344fe5..6e9a93af2b5c2 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
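With the _integrate_plugins() helpers gone from the operators/sensors/hooks/macros packages, plugin integration is now requested explicitly from airflow.plugins_manager using the functions added above. A sketch, assuming plugins are installed and PLUGINS_FOLDER is configured:

    from airflow import plugins_manager

    # Expose only the executors contributed by plugins (as airflow.executors.<plugin_name>):
    plugins_manager.integrate_executor_plugins()

    # Or integrate every plugin type in one call, as the new helper above does:
    plugins_manager.integrate_plugins()
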
See the NOTICE file @@ -68,13 +67,13 @@ LOG_FORMAT = conf.get('core', 'log_format') SIMPLE_LOG_FORMAT = conf.get('core', 'simple_log_format') -SQL_ALCHEMY_CONN = None # type: Optional[str] -DAGS_FOLDER = None # type: Optional[str] -PLUGINS_FOLDER = None # type: Optional[str] -LOGGING_CLASS_PATH = None # type: Optional[str] +SQL_ALCHEMY_CONN: Optional[str] = None +DAGS_FOLDER: Optional[str] = None +PLUGINS_FOLDER: Optional[str] = None +LOGGING_CLASS_PATH: Optional[str] = None -engine = None # type: Optional[Engine] -Session = None # type: Optional[SASession] +engine: Optional[Engine] = None +Session: Optional[SASession] = None # The JSON library to use for DAG Serialization and De-Serialization json = json diff --git a/airflow/task/task_runner/base_task_runner.py b/airflow/task/task_runner/base_task_runner.py index 8345c5d7b7422..83b6064dad7c8 100644 --- a/airflow/task/task_runner/base_task_runner.py +++ b/airflow/task/task_runner/base_task_runner.py @@ -16,13 +16,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Base task runner""" import getpass import os import subprocess import threading from airflow.configuration import conf +from airflow.exceptions import AirflowConfigException from airflow.utils.configuration import tmp_configuration_copy from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname @@ -51,7 +52,7 @@ def __init__(self, local_task_job): else: try: self.run_as_user = conf.get('core', 'default_impersonation') - except conf.AirflowConfigException: + except AirflowConfigException: self.run_as_user = None # Add sudo commands to change user if we need to. Needed to handle SubDagOperator @@ -104,7 +105,7 @@ def _read_task_logs(self, stream): line = stream.readline() if isinstance(line, bytes): line = line.decode('utf-8') - if len(line) == 0: + if not line: break self.log.info('Job %s: Subtask %s %s', self._task_instance.job_id, self._task_instance.task_id, @@ -124,6 +125,7 @@ def run_command(self, run_with=None): self.log.info("Running on host: %s", get_hostname()) self.log.info('Running: %s', full_cmd) + # pylint: disable=subprocess-popen-preexec-fn proc = subprocess.Popen( full_cmd, stdout=subprocess.PIPE, diff --git a/airflow/task/task_runner/standard_task_runner.py b/airflow/task/task_runner/standard_task_runner.py index c79d496ac34b0..9c5795169c316 100644 --- a/airflow/task/task_runner/standard_task_runner.py +++ b/airflow/task/task_runner/standard_task_runner.py @@ -16,7 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Standard task runner""" import psutil from airflow.task.task_runner.base_task_runner import BaseTaskRunner @@ -25,7 +25,7 @@ class StandardTaskRunner(BaseTaskRunner): """ - Runs the raw Airflow task by invoking through the Bash shell. + Standard runner for all tasks. 
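The base_task_runner hunk above fixes the exception lookup: AirflowConfigException lives in airflow.exceptions, not on the conf object. The corrected pattern in isolation, mirroring the code in the hunk:

    from airflow.configuration import conf
    from airflow.exceptions import AirflowConfigException   # not conf.AirflowConfigException

    try:
        run_as_user = conf.get('core', 'default_impersonation')
    except AirflowConfigException:
        run_as_user = None
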
""" def __init__(self, local_task_job): super().__init__(local_task_job) @@ -39,6 +39,3 @@ def return_code(self): def terminate(self): if self.process and psutil.pid_exists(self.process.pid): reap_process_group(self.process.pid, self.log) - - def on_finish(self): - super().on_finish() diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index 2058cbe093de3..f03fa2e3ebccb 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -16,37 +15,36 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Processes DAGs.""" import enum import importlib import logging import multiprocessing import os -import re import signal import sys import time -import zipfile from abc import ABCMeta, abstractmethod from datetime import datetime, timedelta from importlib import import_module -from typing import Iterable, NamedTuple, Optional +from typing import Any, Callable, Dict, KeysView, List, NamedTuple, Optional, Tuple import psutil -from setproctitle import setproctitle +from setproctitle import setproctitle # pylint: disable=no-name-in-module from sqlalchemy import or_ from tabulate import tabulate -# To avoid circular imports import airflow.models from airflow.configuration import conf from airflow.dag.base_dag import BaseDag, BaseDagBag from airflow.exceptions import AirflowException -from airflow.models import errors +from airflow.models import Connection, errors +from airflow.models.taskinstance import SimpleTaskInstance from airflow.settings import STORE_SERIALIZED_DAGS from airflow.stats import Stats from airflow.utils import timezone from airflow.utils.db import provide_session +from airflow.utils.file import list_py_file_paths from airflow.utils.helpers import reap_process_group from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State @@ -63,23 +61,23 @@ class SimpleDag(BaseDag): :type pickle_id: unicode """ - def __init__(self, dag, pickle_id=None): - self._dag_id = dag.dag_id - self._task_ids = [task.task_id for task in dag.tasks] - self._full_filepath = dag.full_filepath - self._is_paused = dag.is_paused - self._concurrency = dag.concurrency - self._pickle_id = pickle_id - self._task_special_args = {} + def __init__(self, dag, pickle_id: Optional[str] = None): + self._dag_id: str = dag.dag_id + self._task_ids: List[str] = [task.task_id for task in dag.tasks] + self._full_filepath: str = dag.full_filepath + self._is_paused: bool = dag.is_paused + self._concurrency: int = dag.concurrency + self._pickle_id: Optional[str] = pickle_id + self._task_special_args: Dict[str, Any] = {} for task in dag.tasks: special_args = {} if task.task_concurrency is not None: special_args['task_concurrency'] = task.task_concurrency - if len(special_args) > 0: + if special_args: self._task_special_args[task.task_id] = special_args @property - def dag_id(self): + def dag_id(self) -> str: """ :return: the DAG ID :rtype: unicode @@ -87,7 +85,7 @@ def dag_id(self): return self._dag_id @property - def task_ids(self): + def task_ids(self) -> List[str]: """ :return: A list of task IDs that are in this DAG :rtype: list[unicode] @@ -95,7 +93,7 @@ def task_ids(self): return self._task_ids @property - def full_filepath(self): + def full_filepath(self) -> str: """ :return: The absolute 
path to the file that contains this DAG's definition :rtype: unicode @@ -103,7 +101,7 @@ def full_filepath(self): return self._full_filepath @property - def concurrency(self): + def concurrency(self) -> int: """ :return: maximum number of tasks that can run simultaneously from this DAG :rtype: int @@ -111,7 +109,7 @@ def concurrency(self): return self._concurrency @property - def is_paused(self): + def is_paused(self) -> bool: """ :return: whether this DAG is paused or not :rtype: bool @@ -119,7 +117,7 @@ def is_paused(self): return self._is_paused @property - def pickle_id(self): + def pickle_id(self) -> Optional[str]: """ :return: The pickle ID for this DAG, if it has one. Otherwise None. :rtype: unicode @@ -127,140 +125,45 @@ def pickle_id(self): return self._pickle_id @property - def task_special_args(self): + def task_special_args(self) -> Dict[str, Any]: + """Special arguments of the task.""" return self._task_special_args - def get_task_special_arg(self, task_id, special_arg_name): + def get_task_special_arg(self, task_id: str, special_arg_name: str): + """Retrieve special arguments of the task.""" if task_id in self._task_special_args and special_arg_name in self._task_special_args[task_id]: return self._task_special_args[task_id][special_arg_name] else: return None -class SimpleTaskInstance: - def __init__(self, ti): - self._dag_id = ti.dag_id - self._task_id = ti.task_id - self._execution_date = ti.execution_date - self._start_date = ti.start_date - self._end_date = ti.end_date - self._try_number = ti.try_number - self._state = ti.state - self._executor_config = ti.executor_config - if hasattr(ti, 'run_as_user'): - self._run_as_user = ti.run_as_user - else: - self._run_as_user = None - if hasattr(ti, 'pool'): - self._pool = ti.pool - else: - self._pool = None - if hasattr(ti, 'priority_weight'): - self._priority_weight = ti.priority_weight - else: - self._priority_weight = None - self._queue = ti.queue - self._key = ti.key - - @property - def dag_id(self): - return self._dag_id - - @property - def task_id(self): - return self._task_id - - @property - def execution_date(self): - return self._execution_date - - @property - def start_date(self): - return self._start_date - - @property - def end_date(self): - return self._end_date - - @property - def try_number(self): - return self._try_number - - @property - def state(self): - return self._state - - @property - def pool(self): - return self._pool - - @property - def priority_weight(self): - return self._priority_weight - - @property - def queue(self): - return self._queue - - @property - def key(self): - return self._key - - @property - def executor_config(self): - return self._executor_config - - @provide_session - def construct_task_instance(self, session=None, lock_for_update=False): - """ - Construct a TaskInstance from the database based on the primary key - - :param session: DB session. - :param lock_for_update: if True, indicates that the database should - lock the TaskInstance (issuing a FOR UPDATE clause) until the - session is committed. - """ - TI = airflow.models.TaskInstance - - qry = session.query(TI).filter( - TI.dag_id == self._dag_id, - TI.task_id == self._task_id, - TI.execution_date == self._execution_date) - - if lock_for_update: - ti = qry.with_for_update().first() - else: - ti = qry.first() - return ti - - class SimpleDagBag(BaseDagBag): """ A collection of SimpleDag objects with some convenience methods. """ - def __init__(self, simple_dags): + def __init__(self, simple_dags: List[SimpleDag]): """ Constructor. 
:param simple_dags: SimpleDag objects that should be in this - :type list(airflow.utils.dag_processing.SimpleDagBag) + :type list(airflow.utils.dag_processing.SimpleDag) """ self.simple_dags = simple_dags - self.dag_id_to_simple_dag = {} + self.dag_id_to_simple_dag: Dict[str, SimpleDag] = {} for simple_dag in simple_dags: self.dag_id_to_simple_dag[simple_dag.dag_id] = simple_dag @property - def dag_ids(self): + def dag_ids(self) -> KeysView[str]: """ :return: IDs of all the DAGs in this :rtype: list[unicode] """ return self.dag_id_to_simple_dag.keys() - def get_dag(self, dag_id): + def get_dag(self, dag_id: str) -> SimpleDag: """ :param dag_id: DAG ID :type dag_id: unicode @@ -273,106 +176,6 @@ def get_dag(self, dag_id): return self.dag_id_to_simple_dag[dag_id] -def correct_maybe_zipped(fileloc): - """ - If the path contains a folder with a .zip suffix, then - the folder is treated as a zip archive and path to zip is returned. - """ - - _, archive, _ = re.search(r'((.*\.zip){})?(.*)'.format(re.escape(os.sep)), fileloc).groups() - if archive and zipfile.is_zipfile(archive): - return archive - else: - return fileloc - - -COMMENT_PATTERN = re.compile(r"\s*#.*") - - -def list_py_file_paths(directory, safe_mode=conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE', fallback=True), - include_examples=None): - """ - Traverse a directory and look for Python files. - - :param directory: the directory to traverse - :type directory: unicode - :param safe_mode: whether to use a heuristic to determine whether a file - contains Airflow DAG definitions. If not provided, use the - core.DAG_DISCOVERY_SAFE_MODE configuration setting. If not set, default - to safe. - :type safe_mode: bool - :param include_examples: include example DAGs - :type include_examples: bool - :return: a list of paths to Python files in the specified directory - :rtype: list[unicode] - """ - if include_examples is None: - include_examples = conf.getboolean('core', 'LOAD_EXAMPLES') - file_paths = [] - if directory is None: - return [] - elif os.path.isfile(directory): - return [directory] - elif os.path.isdir(directory): - patterns_by_dir = {} - for root, dirs, files in os.walk(directory, followlinks=True): - patterns = patterns_by_dir.get(root, []) - ignore_file = os.path.join(root, '.airflowignore') - if os.path.isfile(ignore_file): - with open(ignore_file, 'r') as file: - # If we have new patterns create a copy so we don't change - # the previous list (which would affect other subdirs) - lines_no_comments = [COMMENT_PATTERN.sub("", line) for line in file.read().split("\n")] - patterns += [re.compile(line) for line in lines_no_comments if line] - - # If we can ignore any subdirs entirely we should - fewer paths - # to walk is better. We have to modify the ``dirs`` array in - # place for this to affect os.walk - dirs[:] = [ - d - for d in dirs - if not any(p.search(os.path.join(root, d)) for p in patterns) - ] - - # We want patterns defined in a parent folder's .airflowignore to - # apply to subdirs too - for d in dirs: - patterns_by_dir[os.path.join(root, d)] = patterns - - for f in files: - try: - file_path = os.path.join(root, f) - if not os.path.isfile(file_path): - continue - _, file_ext = os.path.splitext(os.path.split(file_path)[-1]) - if file_ext != '.py' and not zipfile.is_zipfile(file_path): - continue - if any([re.findall(p, file_path) for p in patterns]): - continue - - # Heuristic that guesses whether a Python file contains an - # Airflow DAG definition. 
- might_contain_dag = True - if safe_mode and not zipfile.is_zipfile(file_path): - with open(file_path, 'rb') as fp: - content = fp.read() - might_contain_dag = all( - [s in content for s in (b'DAG', b'airflow')]) - - if not might_contain_dag: - continue - - file_paths.append(file_path) - except Exception: - log = LoggingMixin().log - log.exception("Error while examining %s", f) - if include_examples: - import airflow.example_dags - example_dag_folder = airflow.example_dags.__path__[0] - file_paths.extend(list_py_file_paths(example_dag_folder, safe_mode, False)) - return file_paths - - class AbstractDagFileProcessor(metaclass=ABCMeta): """ Processes a DAG file. See SchedulerJob.process_file() for more details. @@ -386,7 +189,7 @@ def start(self): raise NotImplementedError() @abstractmethod - def terminate(self, sigkill=False): + def terminate(self, sigkill: bool = False): """ Terminate (and then kill) the process launched to process the file """ @@ -394,7 +197,7 @@ def terminate(self, sigkill=False): @property @abstractmethod - def pid(self): + def pid(self) -> int: """ :return: the PID of the process launched to process the given file """ @@ -402,7 +205,7 @@ def pid(self): @property @abstractmethod - def exit_code(self): + def exit_code(self) -> int: """ After the process is finished, this can be called to get the return code :return: the exit code of the process @@ -412,7 +215,7 @@ def exit_code(self): @property @abstractmethod - def done(self): + def done(self) -> bool: """ Check if the process launched to process this file is done. :return: whether the process is finished running @@ -422,7 +225,7 @@ def done(self): @property @abstractmethod - def result(self): + def result(self) -> Tuple[List[SimpleDag], int]: """ A list of simple dags found, and the number of import errors @@ -451,7 +254,7 @@ def file_path(self): DagParsingStat = NamedTuple('DagParsingStat', [ - ('file_paths', Iterable[str]), + ('file_paths', List[str]), ('done', bool), ('all_files_processed', bool) ]) @@ -465,6 +268,7 @@ def file_path(self): class DagParsingSignal(enum.Enum): + """All signals sent to parser.""" AGENT_HEARTBEAT = 'agent_heartbeat' TERMINATE_MANAGER = 'terminate_manager' END_MANAGER = 'end_manager' @@ -563,6 +367,7 @@ def heartbeat(self): pass def wait_until_finished(self): + """Waits until DAG parsing is finished.""" while self._parent_signal_conn.poll(): try: result = self._parent_signal_conn.recv() @@ -660,6 +465,7 @@ def _sync_metadata(self, stat): self._done = stat.done self._all_files_processed = stat.all_files_processed + # pylint: disable=missing-docstring @property def file_paths(self): return self._file_paths @@ -696,7 +502,7 @@ def end(self): self._parent_signal_conn.close() -class DagFileProcessorManager(LoggingMixin): +class DagFileProcessorManager(LoggingMixin): # pylint: disable=too-many-instance-attributes """ Given a list of DAG definition files, this kicks off several processors in parallel to process them and put the results to a multiprocessing.Queue @@ -724,26 +530,27 @@ class DagFileProcessorManager(LoggingMixin): """ def __init__(self, - dag_directory, - file_paths, - max_runs, - processor_factory, - processor_timeout, - signal_conn, - async_mode=True): + dag_directory: str, + file_paths: List[str], + max_runs: int, + processor_factory: Callable[[str, List[Any]], AbstractDagFileProcessor], + processor_timeout: timedelta, + signal_conn: Connection, + async_mode: bool = True): self._file_paths = file_paths - self._file_path_queue = [] + self._file_path_queue: List[str] = [] 
self._dag_directory = dag_directory self._max_runs = max_runs self._processor_factory = processor_factory self._signal_conn = signal_conn self._async_mode = async_mode + self._parsing_start_time: Optional[datetime] = None self._parallelism = conf.getint('scheduler', 'max_threads') if 'sqlite' in conf.get('core', 'sql_alchemy_conn') and self._parallelism > 1: self.log.warning( - f"Because we cannot use more than 1 thread (max_threads = {self._parallelism}) " - "when using sqlite. So we set parallelism to 1." + "Because we cannot use more than 1 thread (max_threads = " + "%d ) when using sqlite. So we set parallelism to 1.", self._parallelism ) self._parallelism = 1 @@ -758,12 +565,12 @@ def __init__(self, self._zombie_threshold_secs = ( conf.getint('scheduler', 'scheduler_zombie_task_threshold')) # Map from file path to the processor - self._processors = {} + self._processors: Dict[str, AbstractDagFileProcessor] = {} self._heartbeat_count = 0 # Map from file path to stats about the file - self._file_stats = {} # type: dict(str, DagFileStat) + self._file_stats: Dict[str, DagFileStat] = {} self._last_zombie_query_time = None # Last time that the DAG dir was traversed to look for files @@ -772,7 +579,7 @@ def __init__(self, self.last_stat_print_time = timezone.datetime(2000, 1, 1) # TODO: Remove magic number self._zombie_query_interval = 10 - self._zombies = [] + self._zombies: List[SimpleTaskInstance] = [] # How long to wait before timing out a process to parse a DAG file self._processor_timeout = processor_timeout @@ -785,7 +592,7 @@ def __init__(self, signal.signal(signal.SIGINT, self._exit_gracefully) signal.signal(signal.SIGTERM, self._exit_gracefully) - def _exit_gracefully(self, signum, frame): + def _exit_gracefully(self, signum, frame): # pylint: disable=unused-argument """ Helper method to clean up DAG file processors to avoid leaving orphan processes. 
""" @@ -842,7 +649,7 @@ def start(self): continue self._refresh_dag_dir() - self._find_zombies() + self._find_zombies() # pylint: disable=no-value-for-parameter simple_dags = self.heartbeat() for simple_dag in simple_dags: @@ -899,10 +706,11 @@ def _refresh_dag_dir(self): self.log.info("There are %s files in %s", len(self._file_paths), self._dag_directory) self.set_file_paths(self._file_paths) + # noinspection PyBroadException try: self.log.debug("Removing old import errors") - self.clear_nonexistent_import_errors() - except Exception: + self.clear_nonexistent_import_errors() # pylint: disable=no-value-for-parameter + except Exception: # pylint: disable=broad-except self.log.exception("Error removing old import errors") if STORE_SERIALIZED_DAGS: @@ -915,8 +723,8 @@ def _print_stat(self): """ Occasionally print out stats about how fast the files are getting processed """ - if ((timezone.utcnow() - self.last_stat_print_time).total_seconds() > self.print_stats_interval): - if len(self._file_paths) > 0: + if (timezone.utcnow() - self.last_stat_print_time).total_seconds() > self.print_stats_interval: + if self._file_paths: self._log_file_processing_stats(self._file_paths) self.last_stat_print_time = timezone.utcnow() @@ -1014,10 +822,6 @@ def _log_file_processing_stats(self, known_file_paths): self.log.info(log_str) - @property - def file_paths(self): - return self._file_paths - def get_pid(self, file_path): """ :param file_path: the path to the file that's being processed @@ -1144,10 +948,8 @@ def collect_results(self): """ self._kill_timed_out_processors() - finished_processors = {} - """:type : dict[unicode, AbstractDagFileProcessor]""" - running_processors = {} - """:type : dict[unicode, AbstractDagFileProcessor]""" + finished_processors: Dict[str, AbstractDagFileProcessor] = {} + running_processors: Dict[str, AbstractDagFileProcessor] = {} for file_path, processor in self._processors.items(): if processor.done: @@ -1202,7 +1004,7 @@ def heartbeat(self): # Generate more file paths to process if we processed all the files # already. - if len(self._file_path_queue) == 0: + if not self._file_path_queue: self.emit_metrics() self._parsing_start_time = timezone.utcnow() @@ -1245,8 +1047,7 @@ def heartbeat(self): self._file_path_queue.extend(files_paths_to_queue) # Start more processors if we have enough slots and files to process - while (self._parallelism - len(self._processors) > 0 and - len(self._file_path_queue) > 0): + while self._parallelism - len(self._processors) > 0 and self._file_path_queue: file_path = self._file_path_queue.pop(0) processor = self._processor_factory(file_path, self._zombies) Stats.incr('dag_processing.processes') @@ -1270,7 +1071,7 @@ def _find_zombies(self, session): and update the current zombie list. """ now = timezone.utcnow() - zombies = [] + zombies: List[SimpleTaskInstance] = [] if not self._last_zombie_query_time or \ (now - self._last_zombie_query_time).total_seconds() > self._zombie_query_interval: # to avoid circular imports @@ -1313,7 +1114,7 @@ def _kill_timed_out_processors(self): self.log.info( "Processor for %s with PID %s started at %s has timed out, " "killing it.", - processor.file_path, processor.pid, processor.start_time.isoformat()) + file_path, processor.pid, processor.start_time.isoformat()) Stats.decr('dag_processing.processes') Stats.incr('dag_processing.processor_timeouts') # TODO: Remove ater Airflow 2.0 @@ -1348,7 +1149,7 @@ def end(self): them as orphaned. 
""" pids_to_kill = self.get_all_pids() - if len(pids_to_kill) > 0: + if pids_to_kill: # First try SIGTERM this_process = psutil.Process(os.getpid()) # Only check child processes to ensure that we don't have a case @@ -1372,7 +1173,7 @@ def end(self): # Then SIGKILL child_processes = [x for x in this_process.children(recursive=True) if x.is_running() and x.pid in pids_to_kill] - if len(child_processes) > 0: + if child_processes: self.log.info("SIGKILL processes that did not terminate gracefully") for child in child_processes: self.log.info("Killing child PID: %s", child.pid) @@ -1396,3 +1197,8 @@ def emit_metrics(self): # TODO: Remove before Airflow 2.0 Stats.gauge('collect_dags', parse_time) Stats.gauge('dagbag_import_errors', sum(stat.import_errors for stat in self._file_stats.values())) + + # pylint: disable=missing-docstring + @property + def file_paths(self): + return self._file_paths diff --git a/airflow/utils/file.py b/airflow/utils/file.py index 2d654df04e997..3c24bab5c7671 100644 --- a/airflow/utils/file.py +++ b/airflow/utils/file.py @@ -19,9 +19,14 @@ import errno import os +import re import shutil +import zipfile from contextlib import contextmanager from tempfile import mkdtemp +from typing import Dict, List, Optional, Pattern + +from airflow import LoggingMixin, conf @contextmanager @@ -56,3 +61,110 @@ def mkdirs(path, mode): raise finally: os.umask(o_umask) + + +def correct_maybe_zipped(fileloc): + """ + If the path contains a folder with a .zip suffix, then + the folder is treated as a zip archive and path to zip is returned. + """ + + _, archive, _ = re.search(r'((.*\.zip){})?(.*)'.format(re.escape(os.sep)), fileloc).groups() + if archive and zipfile.is_zipfile(archive): + return archive + else: + return fileloc + + +def list_py_file_paths(directory: str, + safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE', fallback=True), + include_examples: Optional[bool] = None): + """ + Traverse a directory and look for Python files. + + :param directory: the directory to traverse + :type directory: unicode + :param safe_mode: whether to use a heuristic to determine whether a file + contains Airflow DAG definitions. If not provided, use the + core.DAG_DISCOVERY_SAFE_MODE configuration setting. If not set, default + to safe. + :type safe_mode: bool + :param include_examples: include example DAGs + :type include_examples: bool + :return: a list of paths to Python files in the specified directory + :rtype: list[unicode] + """ + if include_examples is None: + include_examples = conf.getboolean('core', 'LOAD_EXAMPLES') + file_paths: List[str] = [] + if directory is None: + return [] + elif os.path.isfile(directory): + return [directory] + elif os.path.isdir(directory): + patterns_by_dir: Dict[str, List[Pattern[str]]] = {} + for root, dirs, files in os.walk(directory, followlinks=True): + patterns: List[Pattern[str]] = patterns_by_dir.get(root, []) + ignore_file = os.path.join(root, '.airflowignore') + if os.path.isfile(ignore_file): + with open(ignore_file, 'r') as file: + # If we have new patterns create a copy so we don't change + # the previous list (which would affect other subdirs) + lines_no_comments = [COMMENT_PATTERN.sub("", line) for line in file.read().split("\n")] + patterns += [re.compile(line) for line in lines_no_comments if line] + + # If we can ignore any subdirs entirely we should - fewer paths + # to walk is better. 
We have to modify the ``dirs`` array in + # place for this to affect os.walk + dirs[:] = [ + subdir + for subdir in dirs + if not any(p.search(os.path.join(root, subdir)) for p in patterns) + ] + + # We want patterns defined in a parent folder's .airflowignore to + # apply to subdirs too + for subdir in dirs: + patterns_by_dir[os.path.join(root, subdir)] = patterns + + find_dag_file_paths(file_paths, files, patterns, root, safe_mode) + if include_examples: + from airflow import example_dags + example_dag_folder = example_dags.__path__[0] # type: ignore + file_paths.extend(list_py_file_paths(example_dag_folder, safe_mode, False)) + return file_paths + + +def find_dag_file_paths(file_paths, files, patterns, root, safe_mode): + """Finds file paths of all DAG files.""" + for f in files: + # noinspection PyBroadException + try: + file_path = os.path.join(root, f) + if not os.path.isfile(file_path): + continue + _, file_ext = os.path.splitext(os.path.split(file_path)[-1]) + if file_ext != '.py' and not zipfile.is_zipfile(file_path): + continue + if any([re.findall(p, file_path) for p in patterns]): + continue + + if not might_contain_dag(file_path, safe_mode): + continue + + file_paths.append(file_path) + except Exception: # pylint: disable=broad-except + log = LoggingMixin().log + log.exception("Error while examining %s", f) + + +COMMENT_PATTERN = re.compile(r"\s*#.*") + + +def might_contain_dag(file_path, safe_mode): + """Heuristic that guesses whether a Python file contains an Airflow DAG definition.""" + if safe_mode and not zipfile.is_zipfile(file_path): + with open(file_path, 'rb') as dag_file: + content = dag_file.read() + return all([s in content for s in (b'DAG', b'airflow')]) + return True diff --git a/airflow/www/views.py b/airflow/www/views.py index 1340d34a052d9..5df3d67d0e6ef 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -50,6 +50,7 @@ set_dag_run_state_to_failed, set_dag_run_state_to_success, ) from airflow.configuration import AIRFLOW_CONFIG, conf +from airflow.executors.executor_loader import ExecutorLoader from airflow.models import Connection, DagModel, DagRun, Log, SlaMiss, TaskFail, XCom, errors from airflow.settings import STORE_SERIALIZED_DAGS from airflow.ti_deps.dep_context import RUNNING_DEPS, SCHEDULER_QUEUED_DEPS, DepContext @@ -147,11 +148,12 @@ def show_traceback(error): class AirflowBaseView(BaseView): + from airflow import macros route_base = '' # Make our macros available to our UI templates too. 
extra_args = { - 'macros': airflow.macros, + 'macros': macros, } def render_template(self, *args, **kwargs): @@ -807,8 +809,7 @@ def run(self): ignore_task_deps = request.form.get('ignore_task_deps') == "true" ignore_ti_state = request.form.get('ignore_ti_state') == "true" - from airflow.executors import get_default_executor - executor = get_default_executor() + executor = ExecutorLoader.get_default_executor() valid_celery_config = False valid_kubernetes_config = False diff --git a/breeze-complete b/breeze-complete index 36ea7f0e99b91..867542b532377 100644 --- a/breeze-complete +++ b/breeze-complete @@ -22,7 +22,7 @@ _BREEZE_ALLOWED_ENVS=" docker kubernetes " _BREEZE_ALLOWED_BACKENDS=" sqlite mysql postgres " _BREEZE_ALLOWED_KUBERNETES_VERSIONS=" v1.13.0 " _BREEZE_ALLOWED_KUBERNETES_MODES=" persistent_mode git_mode " -_BREEZE_ALLOWED_STATIC_CHECKS=" all all-but-pylint check-apache-license check-executables-have-shebangs check-hooks-apply check-merge-conflict check-xml debug-statements doctoc detect-private-key end-of-file-fixer flake8 forbid-tabs insert-license lint-dockerfile mixed-line-ending mypy pylint setup-order shellcheck" +_BREEZE_ALLOWED_STATIC_CHECKS=" all all-but-pylint check-apache-license check-executables-have-shebangs check-hooks-apply check-merge-conflict check-xml debug-statements doctoc detect-private-key end-of-file-fixer flake8 forbid-tabs insert-license lint-dockerfile mixed-line-ending mypy pylint pylint-test setup-order shellcheck" _BREEZE_DEFAULT_DOCKERHUB_USER="apache" _BREEZE_DEFAULT_DOCKERHUB_REPO="airflow" diff --git a/docs/howto/custom-operator.rst b/docs/howto/custom-operator.rst index 2d89b109935b1..7713468bd49b8 100644 --- a/docs/howto/custom-operator.rst +++ b/docs/howto/custom-operator.rst @@ -20,8 +20,8 @@ Creating a custom Operator ========================== -Airflow allows you to create new operators to suit the requirements of you or your team. -The extensibility is one of the many reasons which makes Apache Airflow powerful. +Airflow allows you to create new operators to suit the requirements of you or your team. +The extensibility is one of the many reasons which makes Apache Airflow powerful. You can create any operator you want by extending the :class:`airflow.models.baseoperator.BaseOperator` @@ -31,16 +31,16 @@ There are two methods that you need to override in a derived class: Use ``@apply_defaults`` decorator function to fill unspecified arguments with ``default_args``. You can specify the ``default_args`` in the dag file. See :ref:`Default args ` for more details. -* Execute - The code to execute when the runner calls the operator. The method contains the +* Execute - The code to execute when the runner calls the operator. The method contains the airflow context as a parameter that can be used to read config values. Let's implement an example ``HelloOperator`` in a new file ``hello_operator.py``: .. code:: python - + from airflow.models.baseoperator import BaseOperator from airflow.utils.decorators import apply_defaults - + class HelloOperator(BaseOperator): @apply_defaults @@ -60,7 +60,7 @@ Let's implement an example ``HelloOperator`` in a new file ``hello_operator.py`` For imports to work, you should place the file in a directory that is present in the ``PYTHONPATH`` env. Airflow adds ``dags/``, ``plugins/``, and ``config/`` directories - in the Airflow home to ``PYTHONPATH`` by default. e.g., In our example, + in the Airflow home to ``PYTHONPATH`` by default. e.g., In our example, the file is placed in the ``custom_operator`` directory. 
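For illustration, here is a minimal sketch of the import side, assuming ``hello_operator.py`` sits in a ``custom_operator/`` directory that is on ``PYTHONPATH`` (the ``name`` argument is assumed to be the constructor parameter of the example operator above):

.. code:: python

    from airflow.models import DAG
    from airflow.utils.dates import days_ago

    from custom_operator.hello_operator import HelloOperator

    with DAG(dag_id='custom_operator_example',
             schedule_interval=None,
             start_date=days_ago(1)) as dag:
        # Once the module is importable, the operator behaves like any built-in one.
        hello_task = HelloOperator(task_id='sample_task', name='foo_bar')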
You can now use the derived custom operator as follows: @@ -77,7 +77,7 @@ Hooks Hooks act as an interface to communicate with the external shared resources in a DAG. For example, multiple tasks in a DAG can require access to a MySQL database. Instead of creating a connection per task, you can retrieve a connection from the hook and utilize it. -Hook also helps to avoid storing connection auth parameters in a DAG. +Hook also helps to avoid storing connection auth parameters in a DAG. See :doc:`connection/index` for how to create and manage connections. Let's extend our previous example to fetch name from MySQL: @@ -107,9 +107,9 @@ Let's extend our previous example to fetch name from MySQL: print(message) return message -When the operator invokes the query on the hook object, a new connection gets created if it doesn't exist. +When the operator invokes the query on the hook object, a new connection gets created if it doesn't exist. The hook retrieves the auth parameters such as username and password from Airflow -backend and passes the params to the :py:func:`airflow.hooks.base_hook.BaseHook.get_connection`. +backend and passes the params to the :py:func:`airflow.hooks.base_hook.BaseHook.get_connection`. You should create hook only in the ``execute`` method or any method which is called from ``execute``. The constructor gets called whenever Airflow parses a DAG which happens frequently. The ``execute`` gets called only during a DAG run. @@ -117,7 +117,7 @@ The ``execute`` gets called only during a DAG run. User interface ^^^^^^^^^^^^^^^ Airflow also allows the developer to control how the operator shows up in the DAG UI. -Override ``ui_color`` to change the background color of the operator in UI. +Override ``ui_color`` to change the background color of the operator in UI. Override ``ui_fgcolor`` to change the color of the label. .. code:: python @@ -134,11 +134,11 @@ Airflow considers the field names present in ``template_fields`` for templating the operator. .. code:: python - + class HelloOperator(BaseOperator): - + template_fields = ['name'] - + @apply_defaults def __init__( self, @@ -166,14 +166,14 @@ In this example, Jinja looks for the ``name`` parameter and substitutes ``{{ tas The parameter can also contain a file name, for example, a bash script or a SQL file. You need to add the extension of your file in ``template_ext``. If a ``template_field`` contains a string ending with the extension mentioned in ``template_ext``, Jinja reads the content of the file and replace the templates -with actual value. Note that Jinja substitutes the operator attributes and not the args. +with actual value. Note that Jinja substitutes the operator attributes and not the args. .. 
code:: python class HelloOperator(BaseOperator): - + template_fields = ['guest_name'] - + @apply_defaults def __init__( self, diff --git a/scripts/ci/pylint_todo.txt b/scripts/ci/pylint_todo.txt index 4ab5fe486b258..c583756bd1fbf 100644 --- a/scripts/ci/pylint_todo.txt +++ b/scripts/ci/pylint_todo.txt @@ -103,11 +103,6 @@ ./airflow/contrib/sensors/sftp_sensor.py ./airflow/contrib/sensors/wasb_sensor.py ./airflow/contrib/sensors/weekday_sensor.py -./airflow/executors/base_executor.py -./airflow/executors/celery_executor.py -./airflow/executors/dask_executor.py -./airflow/executors/local_executor.py -./airflow/executors/sequential_executor.py ./airflow/hooks/dbapi_hook.py ./airflow/hooks/docker_hook.py ./airflow/hooks/druid_hook.py @@ -174,7 +169,6 @@ ./airflow/operators/s3_to_redshift_operator.py ./airflow/operators/slack_operator.py ./airflow/operators/sqlite_operator.py -./airflow/plugins_manager.py ./airflow/providers/amazon/aws/hooks/redshift.py ./airflow/providers/amazon/aws/operators/athena.py ./airflow/providers/amazon/aws/sensors/athena.py @@ -213,7 +207,6 @@ ./airflow/utils/cli_action_loggers.py ./airflow/utils/compression.py ./airflow/utils/configuration.py -./airflow/utils/dag_processing.py ./airflow/utils/dates.py ./airflow/utils/db.py ./airflow/utils/decorators.py diff --git a/tests/core.py b/tests/core.py index a323fa5378b66..edebbd84ae48d 100644 --- a/tests/core.py +++ b/tests/core.py @@ -31,10 +31,11 @@ from numpy.testing import assert_array_almost_equal from pendulum import utcnow -from airflow import DAG, configuration, exceptions, jobs, settings, utils -from airflow.configuration import AirflowConfigException, conf, run_command +from airflow import DAG, exceptions, jobs, settings, utils +from airflow.configuration import ( + DEFAULT_CONFIG, AirflowConfigException, conf, parameterized_config, run_command, +) from airflow.exceptions import AirflowException -from airflow.executors import SequentialExecutor from airflow.hooks.base_hook import BaseHook from airflow.hooks.sqlite_hook import SqliteHook from airflow.models import Connection, DagBag, DagRun, TaskFail, TaskInstance, Variable @@ -754,7 +755,7 @@ def test_variable_delete(self): def test_parameterized_config_gen(self): - cfg = configuration.parameterized_config(configuration.DEFAULT_CONFIG) + cfg = parameterized_config(DEFAULT_CONFIG) # making sure some basic building blocks are present: self.assertIn("[core]", cfg) @@ -870,6 +871,7 @@ def test_bad_trigger_rule(self): def test_terminate_task(self): """If a task instance's db state get deleted, it should fail""" + from airflow.executors.sequential_executor import SequentialExecutor TI = TaskInstance dag = self.dagbag.dags.get('test_utils') task = dag.task_dict.get('sleeps_forever') diff --git a/tests/dags/test_subdag.py b/tests/dags/test_subdag.py index 52bc9ec81c2b4..97cfb30548233 100644 --- a/tests/dags/test_subdag.py +++ b/tests/dags/test_subdag.py @@ -24,7 +24,7 @@ from datetime import datetime, timedelta -from airflow.models import DAG +from airflow.models.dag import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SubDagOperator diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py index 3509634edcbb3..4851f010cd219 100644 --- a/tests/executors/test_base_executor.py +++ b/tests/executors/test_base_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py index bf5f6b4714e02..3409ede79cdb7 100644 --- a/tests/executors/test_celery_executor.py +++ b/tests/executors/test_celery_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -17,6 +16,7 @@ # specific language governing permissions and limitations # under the License. import contextlib +import datetime import os import sys import unittest @@ -24,14 +24,19 @@ from unittest import mock # leave this it is used by the test worker +# noinspection PyUnresolvedReferences import celery.contrib.testing.tasks # noqa: F401 pylint: disable=ungrouped-imports from celery import Celery, states as celery_states from celery.contrib.testing.worker import start_worker from kombu.asynchronous import set_event_loop from parameterized import parameterized +from airflow import DAG from airflow.configuration import conf from airflow.executors import celery_executor +from airflow.models import TaskInstance +from airflow.models.taskinstance import SimpleTaskInstance +from airflow.operators.bash_operator import BashOperator from airflow.utils.state import State @@ -76,14 +81,17 @@ def test_celery_integration(self, broker_url): with start_worker(app=app, logfile=sys.stdout, loglevel='info'): success_command = ['true', 'some_parameter'] fail_command = ['false', 'some_parameter'] + execute_date = datetime.datetime.now() cached_celery_backend = celery_executor.execute_command.backend - task_tuples_to_send = [('success', 'fake_simple_ti', success_command, - celery_executor.celery_configuration['task_default_queue'], - celery_executor.execute_command), - ('fail', 'fake_simple_ti', fail_command, - celery_executor.celery_configuration['task_default_queue'], - celery_executor.execute_command)] + task_tuples_to_send = [ + (('success', 'fake_simple_ti', execute_date, 0), + None, success_command, celery_executor.celery_configuration['task_default_queue'], + celery_executor.execute_command), + (('fail', 'fake_simple_ti', execute_date, 0), + None, fail_command, celery_executor.celery_configuration['task_default_queue'], + celery_executor.execute_command) + ] chunksize = executor._num_tasks_per_send_process(len(task_tuples_to_send)) num_processes = min(len(task_tuples_to_send), executor._sync_parallelism) @@ -97,21 +105,21 @@ def test_celery_integration(self, broker_url): send_pool.close() send_pool.join() - for key, command, result in key_and_async_results: + for task_instance_key, _, result in key_and_async_results: # Only pops when enqueued successfully, otherwise keep it # and expect scheduler loop to deal with it. 
result.backend = cached_celery_backend - executor.running[key] = command - executor.tasks[key] = result - executor.last_state[key] = celery_states.PENDING + executor.running.add(task_instance_key) + executor.tasks[task_instance_key] = result + executor.last_state[task_instance_key] = celery_states.PENDING - executor.running['success'] = True - executor.running['fail'] = True + executor.running.add(('success', 'fake_simple_ti', execute_date, 0)) + executor.running.add(('fail', 'fake_simple_ti', execute_date, 0)) executor.end(synchronous=True) - self.assertTrue(executor.event_buffer['success'], State.SUCCESS) - self.assertTrue(executor.event_buffer['fail'], State.FAILED) + self.assertEqual(executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)], State.SUCCESS) + self.assertEqual(executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)], State.FAILED) self.assertNotIn('success', executor.tasks) self.assertNotIn('fail', executor.tasks) @@ -129,11 +137,19 @@ def fake_execute_command(): # fake_execute_command takes no arguments while execute_command takes 1, # which will cause TypeError when calling task.apply_async() executor = celery_executor.CeleryExecutor() - value_tuple = 'command', '_', 'queue', 'should_be_a_simple_ti' - executor.queued_tasks['key'] = value_tuple + task = BashOperator( + task_id="test", + bash_command="true", + dag=DAG(dag_id='id'), + start_date=datetime.datetime.now() + ) + value_tuple = 'command', 1, None, \ + SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.datetime.now())) + key = ('fail', 'fake_simple_ti', datetime.datetime.now(), 0) + executor.queued_tasks[key] = value_tuple executor.heartbeat() self.assertEqual(1, len(executor.queued_tasks)) - self.assertEqual(executor.queued_tasks['key'], value_tuple) + self.assertEqual(executor.queued_tasks[key], value_tuple) def test_exception_propagation(self): with self._prepare_app() as app: diff --git a/tests/executors/test_dask_executor.py b/tests/executors/test_dask_executor.py index b9d8c0105910f..7968da560b2d9 100644 --- a/tests/executors/test_dask_executor.py +++ b/tests/executors/test_dask_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file diff --git a/tests/executors/test_executor.py b/tests/executors/test_executor.py index f501e56d3f6ba..642ed4324009a 100644 --- a/tests/executors/test_executor.py +++ b/tests/executors/test_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index 0b107d093fba3..b099cc180ab57 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -199,7 +199,9 @@ def test_run_next_exception(self, mock_get_kube_client, mock_kubernetes_job_watc # Execute a task while the Api Throws errors try_number = 1 kubernetes_executor.execute_async(key=('dag', 'task', datetime.utcnow(), try_number), - command='command', executor_config={}) + queue=None, + command='command', + executor_config={}) kubernetes_executor.sync() kubernetes_executor.sync() diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py index dc07b251c1c73..1c27f44d8b010 100644 --- a/tests/executors/test_local_executor.py +++ b/tests/executors/test_local_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -16,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +import datetime import unittest from unittest import mock @@ -37,22 +36,26 @@ def execution_parallelism(self, parallelism=0): fail_command = ['false', 'some_parameter'] self.assertTrue(executor.result_queue.empty()) + execution_date = datetime.datetime.now() for i in range(self.TEST_SUCCESS_COMMANDS): - key, command = success_key.format(i), success_command - executor.running[key] = True + key_id, command = success_key.format(i), success_command + key = key_id, 'fake_ti', execution_date, 0 + executor.running.add(key) executor.execute_async(key=key, command=command) - executor.running['fail'] = True - executor.execute_async(key='fail', command=fail_command) + fail_key = 'fail', 'fake_ti', execution_date, 0 + executor.running.add(fail_key) + executor.execute_async(key=fail_key, command=fail_command) executor.end() # By that time Queues are already shutdown so we cannot check if they are empty self.assertEqual(len(executor.running), 0) for i in range(self.TEST_SUCCESS_COMMANDS): - key = success_key.format(i) + key_id = success_key.format(i) + key = key_id, 'fake_ti', execution_date, 0 self.assertEqual(executor.event_buffer[key], State.SUCCESS) - self.assertEqual(executor.event_buffer['fail'], State.FAILED) + self.assertEqual(executor.event_buffer[fail_key], State.FAILED) expected = self.TEST_SUCCESS_COMMANDS + 1 if parallelism == 0 else parallelism self.assertEqual(executor.workers_used, expected) diff --git a/tests/executors/test_sequential_executor.py b/tests/executors/test_sequential_executor.py index f4b1cf4871457..6100c72b9f252 100644 --- a/tests/executors/test_sequential_executor.py +++ b/tests/executors/test_sequential_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file diff --git a/tests/gcp/hooks/test_google_discovery_api.py b/tests/gcp/hooks/test_google_discovery_api.py index 6117656ca5162..e6d5ca3ca22d5 100644 --- a/tests/gcp/hooks/test_google_discovery_api.py +++ b/tests/gcp/hooks/test_google_discovery_api.py @@ -20,7 +20,8 @@ import unittest from unittest.mock import call, patch -from airflow import configuration, models +from airflow import models +from airflow.configuration import load_test_config from airflow.gcp.hooks.discovery_api import GoogleDiscoveryApiHook from airflow.utils import db @@ -28,7 +29,7 @@ class TestGoogleDiscoveryApiHook(unittest.TestCase): def setUp(self): - configuration.load_test_config() + load_test_config() db.merge_conn( models.Connection( diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index ce881bcf58cda..d8af907226a84 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -24,7 +24,7 @@ from airflow import AirflowException, models, settings from airflow.configuration import conf -from airflow.executors import SequentialExecutor +from airflow.executors.sequential_executor import SequentialExecutor from airflow.jobs import LocalTaskJob from airflow.models import DAG, TaskInstance as TI from airflow.operators.dummy_operator import DummyOperator diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index fa2789ebb508c..f54a613368d9c 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -31,15 +31,16 @@ import airflow.example_dags from airflow import AirflowException, models, settings from airflow.configuration import conf -from airflow.executors import BaseExecutor +from airflow.executors.base_executor import BaseExecutor from airflow.jobs import BackfillJob, SchedulerJob from airflow.models import DAG, DagBag, DagModel, DagRun, Pool, SlaMiss, TaskInstance as TI, errors from airflow.operators.bash_operator import BashOperator from airflow.operators.dummy_operator import DummyOperator from airflow.utils import timezone -from airflow.utils.dag_processing import SimpleDag, SimpleDagBag, list_py_file_paths +from airflow.utils.dag_processing import SimpleDag, SimpleDagBag from airflow.utils.dates import days_ago from airflow.utils.db import create_session, provide_session +from airflow.utils.file import list_py_file_paths from airflow.utils.state import State from tests.compat import MagicMock, Mock, PropertyMock, mock, patch from tests.core import TEST_DAG_FOLDER @@ -134,7 +135,6 @@ def test_no_orphan_process_will_be_left(self): scheduler.run() shutil.rmtree(empty_dir) - scheduler.executor.terminate() # Remove potential noise created by previous tests. 
current_children = set(current_process.children(recursive=True)) - set( old_children) diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 5a371bff49e67..8f7d60c069780 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -37,7 +37,7 @@ from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SubDagOperator from airflow.utils import timezone -from airflow.utils.dag_processing import list_py_file_paths +from airflow.utils.file import list_py_file_paths from airflow.utils.state import State from airflow.utils.weight_rule import WeightRule from tests.models import DEFAULT_DATE diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index 67dd9f16955a9..71776df7a13fc 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -30,7 +30,7 @@ from airflow import models from airflow.configuration import conf from airflow.models import DagBag, DagModel, TaskInstance as TI -from airflow.utils.dag_processing import SimpleTaskInstance +from airflow.models.taskinstance import SimpleTaskInstance from airflow.utils.db import create_session from airflow.utils.state import State from tests.models import DEFAULT_DATE, TEST_DAGS_FOLDER diff --git a/tests/operators/test_google_api_to_s3_transfer.py b/tests/operators/test_google_api_to_s3_transfer.py index 14cea11c01735..b942a2223ca54 100644 --- a/tests/operators/test_google_api_to_s3_transfer.py +++ b/tests/operators/test_google_api_to_s3_transfer.py @@ -20,7 +20,8 @@ import unittest from unittest.mock import Mock, patch -from airflow import configuration, models +from airflow import models +from airflow.configuration import load_test_config from airflow.models.xcom import MAX_XCOM_SIZE from airflow.operators.google_api_to_s3_transfer import GoogleApiToS3Transfer from airflow.utils import db @@ -29,7 +30,7 @@ class TestGoogleApiToS3Transfer(unittest.TestCase): def setUp(self): - configuration.load_test_config() + load_test_config() db.merge_conn( models.Connection( diff --git a/tests/operators/test_operators.py b/tests/operators/test_operators.py index e2612ccd1101c..163eb87bba9ae 100644 --- a/tests/operators/test_operators.py +++ b/tests/operators/test_operators.py @@ -22,7 +22,8 @@ from collections import OrderedDict from unittest import mock -from airflow import DAG, configuration, operators +from airflow import DAG, operators +from airflow.configuration import conf from airflow.utils import timezone DEFAULT_DATE = timezone.datetime(2015, 1, 1) @@ -47,7 +48,7 @@ def tearDown(self): for table in drop_tables: conn.execute("DROP TABLE IF EXISTS {}".format(table)) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_operator_test(self): sql = """ @@ -62,7 +63,7 @@ def test_mysql_operator_test(self): dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_operator_test_multi(self): sql = [ @@ -78,7 +79,7 @@ def test_mysql_operator_test_multi(self): ) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 
'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_hook_test_bulk_load(self): records = ("foo", "bar", "baz") @@ -102,7 +103,7 @@ def test_mysql_hook_test_bulk_load(self): results = tuple(result[0] for result in c.fetchall()) self.assertEqual(sorted(results), sorted(records)) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_hook_test_bulk_dump(self): from airflow.hooks.mysql_hook import MySqlHook @@ -115,7 +116,7 @@ def test_mysql_hook_test_bulk_dump(self): self.skipTest("Skip test_mysql_hook_test_bulk_load " "since file output is not permitted") - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") @mock.patch('airflow.hooks.mysql_hook.MySqlHook.get_conn') def test_mysql_hook_test_bulk_dump_mock(self, mock_get_conn): @@ -136,7 +137,7 @@ def test_mysql_hook_test_bulk_dump_mock(self, mock_get_conn): """.format(tmp_file=tmp_file, table=table) assertEqualIgnoreMultipleSpaces(self, mock_execute.call_args[0][0], query) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_to_mysql(self): sql = "SELECT * FROM INFORMATION_SCHEMA.TABLES LIMIT 100;" @@ -155,7 +156,7 @@ def test_mysql_to_mysql(self): dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_overwrite_schema(self): """ @@ -193,7 +194,7 @@ def tearDown(self): for t in tables_to_drop: cur.execute("DROP TABLE IF EXISTS {}".format(t)) - @unittest.skipUnless('postgres' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('postgres' in conf.get('core', 'sql_alchemy_conn'), "This is a Postgres test") def test_postgres_operator_test(self): sql = """ @@ -215,7 +216,7 @@ def test_postgres_operator_test(self): end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('postgres' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('postgres' in conf.get('core', 'sql_alchemy_conn'), "This is a Postgres test") def test_postgres_operator_test_multi(self): sql = [ @@ -228,7 +229,7 @@ def test_postgres_operator_test_multi(self): task_id='postgres_operator_test_multi', sql=sql, dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('postgres' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('postgres' in conf.get('core', 'sql_alchemy_conn'), "This is a Postgres test") def test_postgres_to_postgres(self): sql = "SELECT * FROM INFORMATION_SCHEMA.TABLES LIMIT 100;" @@ -247,7 +248,7 @@ def test_postgres_to_postgres(self): dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('postgres' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('postgres' in conf.get('core', 'sql_alchemy_conn'), "This is a Postgres test") def test_vacuum(self): """ @@ -263,7 +264,7 @@ def test_vacuum(self): autocommit=True) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - 
@unittest.skipUnless('postgres' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('postgres' in conf.get('core', 'sql_alchemy_conn'), "This is a Postgres test") def test_overwrite_schema(self): """ @@ -369,14 +370,14 @@ def tearDown(self): with MySqlHook().get_conn() as cur: cur.execute("DROP TABLE IF EXISTS baby_names CASCADE;") - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_clear(self): self.dag.clear( start_date=DEFAULT_DATE, end_date=timezone.utcnow()) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_to_hive(self): from airflow.operators.mysql_to_hive import MySqlToHiveTransfer @@ -391,7 +392,7 @@ def test_mysql_to_hive(self): dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_to_hive_partition(self): from airflow.operators.mysql_to_hive import MySqlToHiveTransfer @@ -408,7 +409,7 @@ def test_mysql_to_hive_partition(self): dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_to_hive_tblproperties(self): from airflow.operators.mysql_to_hive import MySqlToHiveTransfer @@ -424,7 +425,7 @@ def test_mysql_to_hive_tblproperties(self): dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") @mock.patch('airflow.hooks.hive_hooks.HiveCliHook.load_file') def test_mysql_to_hive_type_conversion(self, mock_load_file): @@ -469,7 +470,7 @@ def test_mysql_to_hive_type_conversion(self, mock_load_file): with m.get_conn() as c: c.execute("DROP TABLE IF EXISTS {}".format(mysql_table)) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_to_hive_verify_csv_special_char(self): mysql_table = 'test_mysql_to_hive' @@ -520,7 +521,7 @@ def test_mysql_to_hive_verify_csv_special_char(self): with m.get_conn() as c: c.execute("DROP TABLE IF EXISTS {}".format(mysql_table)) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_to_hive_verify_loaded_values(self): mysql_table = 'test_mysql_to_hive' diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py index 0e3f3fe07ce21..4ec1d7f295d75 100644 --- a/tests/plugins/test_plugin.py +++ b/tests/plugins/test_plugin.py @@ -83,9 +83,11 @@ def test(self): static_url_path='/static/test_plugin') -# Create a handler to validate statsd stat name -def stat_name_dummy_handler(stat_name): - return stat_name +class StatsClass: + @staticmethod + # Create a handler to validate statsd stat name + def 
stat_name_dummy_handler(stat_name: str) -> str: + return stat_name # Defining the plugin class @@ -99,7 +101,7 @@ class AirflowTestPlugin(AirflowPlugin): flask_blueprints = [bp] appbuilder_views = [v_appbuilder_package] appbuilder_menu_items = [appbuilder_mitem] - stat_name_handler = staticmethod(stat_name_dummy_handler) + stat_name_handler = StatsClass.stat_name_dummy_handler global_operator_extra_links = [ AirflowLink(), GithubLink(), diff --git a/tests/task/task_runner/test_standard_task_runner.py b/tests/task/task_runner/test_standard_task_runner.py index de7416c5be50c..e99ea2a84b92d 100644 --- a/tests/task/task_runner/test_standard_task_runner.py +++ b/tests/task/task_runner/test_standard_task_runner.py @@ -76,18 +76,18 @@ def test_start_and_terminate(self): pgid = os.getpgid(runner.process.pid) self.assertTrue(pgid) - procs = [] - for p in psutil.process_iter(): + processes = [] + for process in psutil.process_iter(): try: - if os.getpgid(p.pid) == pgid: - procs.append(p) + if os.getpgid(process.pid) == pgid: + processes.append(process) except OSError: pass runner.terminate() - for p in procs: - self.assertFalse(psutil.pid_exists(p.pid)) + for process in processes: + self.assertFalse(psutil.pid_exists(process.pid)) def test_on_kill(self): """ diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 79c4ca4c526ed..02127cead1ceb 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -24,7 +23,9 @@ from unittest import mock from airflow import configuration -from airflow.configuration import AirflowConfigParser, conf, parameterized_config +from airflow.configuration import ( + AirflowConfigParser, conf, expand_env_var, get_airflow_config, get_airflow_home, parameterized_config, +) @unittest.mock.patch.dict('os.environ', { @@ -42,13 +43,13 @@ def test_airflow_home_default(self): if 'AIRFLOW_HOME' in os.environ: del os.environ['AIRFLOW_HOME'] self.assertEqual( - configuration.get_airflow_home(), - configuration.expand_env_var('~/airflow')) + get_airflow_home(), + expand_env_var('~/airflow')) def test_airflow_home_override(self): with unittest.mock.patch.dict('os.environ', AIRFLOW_HOME='/path/to/airflow'): self.assertEqual( - configuration.get_airflow_home(), + get_airflow_home(), '/path/to/airflow') def test_airflow_config_default(self): @@ -56,13 +57,13 @@ def test_airflow_config_default(self): if 'AIRFLOW_CONFIG' in os.environ: del os.environ['AIRFLOW_CONFIG'] self.assertEqual( - configuration.get_airflow_config('/home/airflow'), - configuration.expand_env_var('/home/airflow/airflow.cfg')) + get_airflow_config('/home/airflow'), + expand_env_var('/home/airflow/airflow.cfg')) def test_airflow_config_override(self): with unittest.mock.patch.dict('os.environ', AIRFLOW_CONFIG='/path/to/airflow/airflow.cfg'): self.assertEqual( - configuration.get_airflow_config('/home//airflow'), + get_airflow_config('/home//airflow'), '/path/to/airflow/airflow.cfg') def test_case_sensitivity(self): diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py index c34810925d461..1ad39e21445a8 100644 --- a/tests/utils/test_dag_processing.py +++ b/tests/utils/test_dag_processing.py @@ -29,11 +29,11 @@ from airflow.configuration import conf from airflow.jobs import DagFileProcessor, LocalTaskJob as LJ from airflow.models import DagBag, TaskInstance as TI +from 
airflow.models.taskinstance import SimpleTaskInstance from airflow.utils import timezone -from airflow.utils.dag_processing import ( - DagFileProcessorAgent, DagFileProcessorManager, DagFileStat, SimpleTaskInstance, correct_maybe_zipped, -) +from airflow.utils.dag_processing import DagFileProcessorAgent, DagFileProcessorManager, DagFileStat from airflow.utils.db import create_session +from airflow.utils.file import correct_maybe_zipped from airflow.utils.state import State from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_runs diff --git a/tests/utils/test_email.py b/tests/utils/test_email.py index 8f28075739a81..bc98ccafddf07 100644 --- a/tests/utils/test_email.py +++ b/tests/utils/test_email.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -25,7 +24,8 @@ import mock -from airflow import conf, utils +from airflow import utils +from airflow.configuration import conf from airflow.utils.email import get_email_address_list from tests import conf_vars diff --git a/tests/www/test_views.py b/tests/www/test_views.py index e3841d6d35c7c..4f4f071df4ffe 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -592,7 +592,7 @@ def test_run(self): resp = self.client.post('run', data=form) self.check_content_in_response('', resp, resp_code=302) - @mock.patch('airflow.executors.get_default_executor') + @mock.patch('airflow.executors.executor_loader.ExecutorLoader.get_default_executor') def test_run_with_runnable_states(self, get_default_executor_function): executor = CeleryExecutor() executor.heartbeat = lambda: True @@ -622,7 +622,7 @@ def test_run_with_runnable_states(self, get_default_executor_function): .format(state) + "The task must be cleared in order to be run" self.assertFalse(re.search(msg, resp.get_data(as_text=True))) - @mock.patch('airflow.executors.get_default_executor') + @mock.patch('airflow.executors.executor_loader.ExecutorLoader.get_default_executor') def test_run_with_not_runnable_states(self, get_default_executor_function): get_default_executor_function.return_value = CeleryExecutor()
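A minimal, hypothetical sketch of the mocking pattern used in the test above; the patch target is the ``ExecutorLoader.get_default_executor`` entry point introduced by this change, while the test name and assertion are made up for illustration:

.. code:: python

    from unittest import mock

    from airflow.executors.executor_loader import ExecutorLoader
    from airflow.executors.sequential_executor import SequentialExecutor


    @mock.patch('airflow.executors.executor_loader.ExecutorLoader.get_default_executor')
    def test_default_executor_can_be_patched(mock_get_default_executor):
        # Any code path that resolves the executor through the loader now
        # receives a lightweight SequentialExecutor instead of the executor
        # configured in airflow.cfg.
        mock_get_default_executor.return_value = SequentialExecutor()
        assert isinstance(ExecutorLoader.get_default_executor(), SequentialExecutor)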