diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e2c4c0452ca5e..22643454c519b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -264,12 +264,14 @@ repos: files: \.py$ exclude: ^tests/.*\.py$|^airflow/_vendor/.*$ pass_filenames: true - - id: pylint + require_serial: true # Pylint checks on main code should be run in one chunk to detect all cycles + - id: pylint-tests name: Run pylint for tests language: system entry: "./scripts/ci/pre_commit_pylint_tests.sh" files: ^tests/.*\.py$ pass_filenames: true + require_serial: false # For tests, it's perfectly OK to run pylint in parallel - id: flake8 name: Run flake8 language: system diff --git a/BREEZE.rst b/BREEZE.rst index 18c0a391aff38..30bc49c7f3b06 100644 --- a/BREEZE.rst +++ b/BREEZE.rst @@ -872,7 +872,7 @@ This is the current syntax for `./breeze <./breeze>`_: -S, --static-check Run selected static checks for currently changed files. You should specify static check that you would like to run or 'all' to run all checks. One of - [ all all-but-pylint check-apache-license check-executables-have-shebangs check-hooks-apply check-merge-conflict check-xml debug-statements doctoc detect-private-key end-of-file-fixer flake8 forbid-tabs insert-license lint-dockerfile mixed-line-ending mypy pylint setup-order shellcheck]. + [ all all-but-pylint check-apache-license check-executables-have-shebangs check-hooks-apply check-merge-conflict check-xml debug-statements doctoc detect-private-key end-of-file-fixer flake8 forbid-tabs insert-license lint-dockerfile mixed-line-ending mypy pylint pylint-tests setup-order shellcheck]. You can pass extra arguments including options to to the pre-commit framework as passed after --. For example: @@ -886,7 +886,7 @@ This is the current syntax for `./breeze <./breeze>`_: -F, --static-check-all-files Run selected static checks for all applicable files. You should specify static check that you would like to run or 'all' to run all checks. One of - [ all all-but-pylint check-apache-license check-executables-have-shebangs check-hooks-apply check-merge-conflict check-xml debug-statements doctoc detect-private-key end-of-file-fixer flake8 forbid-tabs insert-license lint-dockerfile mixed-line-ending mypy pylint setup-order shellcheck]. + [ all all-but-pylint check-apache-license check-executables-have-shebangs check-hooks-apply check-merge-conflict check-xml debug-statements doctoc detect-private-key end-of-file-fixer flake8 forbid-tabs insert-license lint-dockerfile mixed-line-ending mypy pylint pylint-tests setup-order shellcheck]. You can pass extra arguments including options to the pre-commit framework as passed after --. For example: diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 605d8dc97ca1a..83d9fe9da8d1a 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -454,7 +454,9 @@ image built locally): ----------------------------------- ---------------------------------------------------------------- ------------ ``pydevd`` Check for accidentally commited pydevd statements. ----------------------------------- ---------------------------------------------------------------- ------------ -``pylint`` Runs pylint. * +``pylint`` Runs pylint for main code. * +----------------------------------- ---------------------------------------------------------------- ------------ +``pylint-tests`` Runs pylint for tests. * ----------------------------------- ---------------------------------------------------------------- ------------ ``python-no-log-warn`` Checks if there are no deprecate log warn. 
----------------------------------- ---------------------------------------------------------------- ------------ diff --git a/UPDATING.md b/UPDATING.md index 61821f5fd40f1..f0bc3598055e0 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -41,6 +41,11 @@ assists users migrating to a new version. ## Airflow Master +### Removal of airflow.AirflowMacroPlugin class + +The class was present in the airflow package, but it has apparently not been used since 2015. +It has been removed. + ### Changes to settings CONTEXT_MANAGER_DAG was removed from settings. It's role has been taken by `DagContext` in diff --git a/airflow/__init__.py b/airflow/__init__.py index f9c114ef1cb86..862c9bcdc51ab 100644 --- a/airflow/__init__.py +++ b/airflow/__init__.py @@ -44,23 +44,8 @@ settings.initialize() -login = None # type: Optional[Callable] +from airflow.plugins_manager import integrate_plugins -from airflow import executors -from airflow import hooks -from airflow import macros -from airflow import operators -from airflow import sensors +login: Optional[Callable] = None - -class AirflowMacroPlugin: - # pylint: disable=missing-docstring - def __init__(self, namespace): - self.namespace = namespace - - -operators._integrate_plugins() # pylint: disable=protected-access -sensors._integrate_plugins() # pylint: disable=protected-access -hooks._integrate_plugins() # pylint: disable=protected-access -executors._integrate_plugins() # pylint: disable=protected-access -macros._integrate_plugins() # pylint: disable=protected-access +integrate_plugins() diff --git a/airflow/cli/commands/flower_command.py b/airflow/cli/commands/flower_command.py index da0eef5a8b36d..c9f90a6445614 100644 --- a/airflow/cli/commands/flower_command.py +++ b/airflow/cli/commands/flower_command.py @@ -23,7 +23,7 @@ import daemon from daemon.pidfile import TimeoutPIDLockFile -from airflow import conf +from airflow.configuration import conf from airflow.utils import cli as cli_utils from airflow.utils.cli import setup_locations, sigint_handler diff --git a/airflow/cli/commands/serve_logs_command.py b/airflow/cli/commands/serve_logs_command.py index 86e29464124d8..db48a0f0b855b 100644 --- a/airflow/cli/commands/serve_logs_command.py +++ b/airflow/cli/commands/serve_logs_command.py @@ -18,7 +18,7 @@ """Serve logs command""" import os -from airflow import conf +from airflow.configuration import conf from airflow.utils import cli as cli_utils diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 327b51a63043e..4ea72bf60b7ef 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -25,7 +25,7 @@ from contextlib import redirect_stderr, redirect_stdout from airflow import DAG, AirflowException, conf, jobs, settings -from airflow.executors import get_default_executor +from airflow.executors.executor_loader import ExecutorLoader from airflow.models import DagPickle, TaskInstance from airflow.ti_deps.dep_context import SCHEDULER_QUEUED_DEPS, DepContext from airflow.utils import cli as cli_utils, db @@ -69,7 +69,7 @@ def _run(args, dag, ti): print(e) raise e - executor = get_default_executor() + executor = ExecutorLoader.get_default_executor() executor.start() print("Sending to executor.") executor.queue_task_instance( diff --git a/airflow/cli/commands/worker_command.py b/airflow/cli/commands/worker_command.py index bccc205a822c6..d6c57fdbe5297 100644 --- a/airflow/cli/commands/worker_command.py +++ b/airflow/cli/commands/worker_command.py @@ -23,7 +23,8 @@ import daemon from daemon.pidfile 
import TimeoutPIDLockFile -from airflow import conf, settings +from airflow import settings +from airflow.configuration import conf from airflow.utils import cli as cli_utils from airflow.utils.cli import setup_locations, setup_logging, sigint_handler diff --git a/airflow/executors/__init__.py b/airflow/executors/__init__.py index e6322d13ce8d6..21ee94b03b9b3 100644 --- a/airflow/executors/__init__.py +++ b/airflow/executors/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,84 +14,4 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# pylint: disable=missing-docstring - -import sys -from typing import Optional - -from airflow.configuration import conf -from airflow.exceptions import AirflowException -from airflow.executors.base_executor import BaseExecutor -from airflow.executors.local_executor import LocalExecutor -from airflow.executors.sequential_executor import SequentialExecutor -from airflow.utils.log.logging_mixin import LoggingMixin - -DEFAULT_EXECUTOR = None # type: Optional[BaseExecutor] - - -def _integrate_plugins(): - """Integrate plugins to the context.""" - from airflow.plugins_manager import executors_modules - for executors_module in executors_modules: - sys.modules[executors_module.__name__] = executors_module - globals()[executors_module._name] = executors_module # pylint: disable=protected-access - - -def get_default_executor(): - """Creates a new instance of the configured executor if none exists and returns it""" - global DEFAULT_EXECUTOR # pylint: disable=global-statement - - if DEFAULT_EXECUTOR is not None: - return DEFAULT_EXECUTOR - - executor_name = conf.get('core', 'EXECUTOR') - - DEFAULT_EXECUTOR = _get_executor(executor_name) - - log = LoggingMixin().log - log.info("Using executor %s", executor_name) - - return DEFAULT_EXECUTOR - - -class Executors: - LocalExecutor = "LocalExecutor" - SequentialExecutor = "SequentialExecutor" - CeleryExecutor = "CeleryExecutor" - DaskExecutor = "DaskExecutor" - KubernetesExecutor = "KubernetesExecutor" - - -def _get_executor(executor_name): - """ - Creates a new instance of the named executor. 
- In case the executor name is not know in airflow, - look for it in the plugins - """ - if executor_name == Executors.LocalExecutor: - return LocalExecutor() - elif executor_name == Executors.SequentialExecutor: - return SequentialExecutor() - elif executor_name == Executors.CeleryExecutor: - from airflow.executors.celery_executor import CeleryExecutor - return CeleryExecutor() - elif executor_name == Executors.DaskExecutor: - from airflow.executors.dask_executor import DaskExecutor - return DaskExecutor() - elif executor_name == Executors.KubernetesExecutor: - from airflow.executors.kubernetes_executor import KubernetesExecutor - return KubernetesExecutor() - else: - # Loading plugins - _integrate_plugins() - executor_path = executor_name.split('.') - if len(executor_path) != 2: - raise AirflowException( - "Executor {0} not supported: " - "please specify in format plugin_module.executor".format(executor_name)) - - if executor_path[0] in globals(): - return globals()[executor_path[0]].__dict__[executor_path[1]]() - else: - raise AirflowException("Executor {0} not supported.".format(executor_name)) +"""Executors.""" diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index 506713a741f2b..a0109b49d9f3d 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,67 +14,86 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +""" +Base executor - this is the base class for all the implemented executors. +""" from collections import OrderedDict +from typing import Any, Dict, List, Optional, Set, Tuple, Union -# To avoid circular imports -import airflow.utils.dag_processing -from airflow.configuration import conf +from airflow import LoggingMixin, conf +from airflow.models import TaskInstance +from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKeyType from airflow.stats import Stats -from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State -PARALLELISM = conf.getint('core', 'PARALLELISM') +PARALLELISM: int = conf.getint('core', 'PARALLELISM') +NOT_STARTED_MESSAGE = "The executor should be started first!" -class BaseExecutor(LoggingMixin): +# Command to execute - might be either string or list of strings +# with the same semantics as subprocess.Popen +CommandType = Union[str, List[str]] - def __init__(self, parallelism=PARALLELISM): - """ - Class to derive in order to interface with executor-type systems - like Celery, Yarn and the likes. - :param parallelism: how many jobs should run at one time. Set to - ``0`` for infinity - :type parallelism: int - """ - self.parallelism = parallelism - self.queued_tasks = OrderedDict() - self.running = {} - self.event_buffer = {} +# Task that is queued. It contains all the information that is +# needed to run the task. +# +# Tuple of: command, priority, queue name, SimpleTaskInstance +QueuedTaskInstanceType = Tuple[CommandType, int, Optional[str], SimpleTaskInstance] + + +class BaseExecutor(LoggingMixin): + """ + Class to derive in order to interface with executor-type systems + like Celery, Kubernetes, Local, Sequential and the likes. + + :param parallelism: how many jobs should run at one time. 
Set to + ``0`` for infinity + """ + def __init__(self, parallelism: int = PARALLELISM): + super().__init__() + self.parallelism: int = parallelism + self.queued_tasks: OrderedDict[TaskInstanceKeyType, QueuedTaskInstanceType] \ + = OrderedDict() + self.running: Set[TaskInstanceKeyType] = set() + self.event_buffer: Dict[TaskInstanceKeyType, Optional[str]] = {} def start(self): # pragma: no cover """ - Executors may need to get things started. For example LocalExecutor - starts N workers. + Executors may need to get things started. """ - def queue_command(self, simple_task_instance, command, priority=1, queue=None): - key = simple_task_instance.key - if key not in self.queued_tasks and key not in self.running: + def queue_command(self, + simple_task_instance: SimpleTaskInstance, + command: CommandType, + priority: int = 1, + queue: Optional[str] = None): + """Queues command to task""" + if simple_task_instance.key not in self.queued_tasks and simple_task_instance.key not in self.running: self.log.info("Adding to queue: %s", command) - self.queued_tasks[key] = (command, priority, queue, simple_task_instance) + self.queued_tasks[simple_task_instance.key] = (command, priority, queue, simple_task_instance) else: - self.log.info("could not queue task %s", key) + self.log.info("could not queue task %s", simple_task_instance.key) def queue_task_instance( self, - task_instance, - mark_success=False, - pickle_id=None, - ignore_all_deps=False, - ignore_depends_on_past=False, - ignore_task_deps=False, - ignore_ti_state=False, - pool=None, - cfg_path=None): + task_instance: TaskInstance, + mark_success: bool = False, + pickle_id: Optional[str] = None, + ignore_all_deps: bool = False, + ignore_depends_on_past: bool = False, + ignore_task_deps: bool = False, + ignore_ti_state: bool = False, + pool: Optional[str] = None, + cfg_path: Optional[str] = None): + """Queues task instance.""" pool = pool or task_instance.pool # TODO (edgarRd): AIRFLOW-1985: # cfg_path is needed to propagate the config values if using impersonation # (run_as_user), given that there are different code paths running tasks. # For a long term solution we need to address AIRFLOW-1986 - command = task_instance.command_as_list( + command_list_to_run = task_instance.command_as_list( local=True, mark_success=mark_success, ignore_all_deps=ignore_all_deps, @@ -87,29 +104,30 @@ def queue_task_instance( pickle_id=pickle_id, cfg_path=cfg_path) self.queue_command( - airflow.utils.dag_processing.SimpleTaskInstance(task_instance), - command, + SimpleTaskInstance(task_instance), + command_list_to_run, priority=task_instance.task.priority_weight_total, queue=task_instance.task.queue) - def has_task(self, task_instance): + def has_task(self, task_instance: TaskInstance) -> bool: """ - Checks if a task is either queued or running in this executor + Checks if a task is either queued or running in this executor. :param task_instance: TaskInstance :return: True if the task is known to this executor """ - if task_instance.key in self.queued_tasks or task_instance.key in self.running: - return True + return task_instance.key in self.queued_tasks or task_instance.key in self.running - def sync(self): + def sync(self) -> None: """ Sync will get called periodically by the heartbeat method. Executors should override this to perform gather statuses. """ - def heartbeat(self): - # Triggering new jobs + def heartbeat(self) -> None: + """ + Heartbeat sent to trigger new jobs. 
+ """ if not self.parallelism: open_slots = len(self.queued_tasks) else: @@ -132,47 +150,65 @@ def heartbeat(self): self.log.debug("Calling the %s sync method", self.__class__) self.sync() - def trigger_tasks(self, open_slots): + def trigger_tasks(self, open_slots: int) -> None: """ - Trigger tasks + Triggers tasks :param open_slots: Number of open slots - :return: """ sorted_queue = sorted( [(k, v) for k, v in self.queued_tasks.items()], key=lambda x: x[1][1], reverse=True) for _ in range(min((open_slots, len(self.queued_tasks)))): - key, (command, _, queue, simple_ti) = sorted_queue.pop(0) + key, (command, _, _, simple_ti) = sorted_queue.pop(0) self.queued_tasks.pop(key) - self.running[key] = command + self.running.add(key) self.execute_async(key=key, command=command, - queue=queue, + queue=None, executor_config=simple_ti.executor_config) - def change_state(self, key, state): + def change_state(self, key: TaskInstanceKeyType, state: str) -> None: + """ + Changes state of the task. + + :param key: Unique key for the task instance + :param state: State to set for the task. + """ self.log.debug("Changing state: %s", key) - self.running.pop(key, None) + try: + self.running.remove(key) + except KeyError: + self.log.debug('Could not find key: %s', str(key)) self.event_buffer[key] = state - def fail(self, key): + def fail(self, key: TaskInstanceKeyType) -> None: + """ + Set fail state for the event. + + :param key: Unique key for the task instance + """ self.change_state(key, State.FAILED) - def success(self, key): + def success(self, key: TaskInstanceKeyType) -> None: + """ + Set success state for the event. + + :param key: Unique key for the task instance + """ self.change_state(key, State.SUCCESS) - def get_event_buffer(self, dag_ids=None): + def get_event_buffer(self, dag_ids=None) -> Dict[TaskInstanceKeyType, Optional[str]]: """ Returns and flush the event buffer. In case dag_ids is specified it will only return and flush events for the given dag_ids. Otherwise - it returns and flushes all + it returns and flushes all events. :param dag_ids: to dag_ids to return events for, if None returns all :return: a dict of events """ - cleared_events = dict() + cleared_events: Dict[TaskInstanceKeyType, Optional[str]] = dict() if dag_ids is None: cleared_events = self.event_buffer self.event_buffer = dict() @@ -185,16 +221,21 @@ def get_event_buffer(self, dag_ids=None): return cleared_events def execute_async(self, - key, - command, - queue=None, - executor_config=None): # pragma: no cover + key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None) -> None: # pragma: no cover """ This method will execute the command asynchronously. + + :param key: Unique key for the task instance + :param command: Command to run + :param queue: name of the queue + :param executor_config: Configuration passed to the executor. """ raise NotImplementedError() - def end(self): # pragma: no cover + def end(self) -> None: # pragma: no cover """ This method is called when the caller is done submitting job and wants to wait synchronously for the job submitted previously to be diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index 83fc44b59f3f9..906531757d3ee 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file @@ -16,20 +15,23 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Celery executor.""" import math import os import subprocess import time import traceback from multiprocessing import Pool, cpu_count +from typing import Any, List, Optional, Tuple, Union -from celery import Celery, states as celery_states +from celery import Celery, Task, states as celery_states +from celery.result import AsyncResult from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.executors.base_executor import BaseExecutor +from airflow.executors.base_executor import BaseExecutor, CommandType, QueuedTaskInstanceType +from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKeyType, TaskInstanceStateType from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string from airflow.utils.timeout import timeout @@ -57,7 +59,8 @@ @app.task -def execute_command(command_to_exec): +def execute_command(command_to_exec: str) -> None: + """Executes command.""" log = LoggingMixin().log log.info("Executing command in Celery: %s", command_to_exec) env = os.environ.copy() @@ -67,7 +70,6 @@ def execute_command(command_to_exec): except subprocess.CalledProcessError as e: log.exception('execute_command encountered a CalledProcessError') log.error(e.output) - raise AirflowException('Celery command failed') @@ -81,12 +83,13 @@ class ExceptionWithTraceback: :type exception_traceback: str """ - def __init__(self, exception, exception_traceback): + def __init__(self, exception: Exception, exception_traceback: str): self.exception = exception self.traceback = exception_traceback -def fetch_celery_task_state(celery_task): +def fetch_celery_task_state(celery_task: Tuple[TaskInstanceKeyType, AsyncResult]) \ + -> Union[TaskInstanceStateType, ExceptionWithTraceback]: """ Fetch and return the state of the given celery task. The scope of this function is global so that it can be called by subprocesses in the pool. @@ -102,22 +105,27 @@ def fetch_celery_task_state(celery_task): with timeout(seconds=2): # Accessing state property of celery task will make actual network request # to get the current state of the task. 
- res = (celery_task[0], celery_task[1].state) - except Exception as e: + return celery_task[0], celery_task[1].state + except Exception as e: # pylint: disable=broad-except exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0], traceback.format_exc()) - res = ExceptionWithTraceback(e, exception_traceback) - return res + return ExceptionWithTraceback(e, exception_traceback) + +# Task instance that is sent over Celery queues +# TaskInstanceKeyType, SimpleTaskInstance, Command, queue_name, CallableTask +TaskInstanceInCelery = Tuple[TaskInstanceKeyType, SimpleTaskInstance, CommandType, Optional[str], Task] -def send_task_to_executor(task_tuple): - key, _, command, queue, task = task_tuple + +def send_task_to_executor(task_tuple: TaskInstanceInCelery) \ + -> Tuple[TaskInstanceKeyType, CommandType, Union[AsyncResult, ExceptionWithTraceback]]: + """Sends task to executor.""" + key, _, command, queue, task_to_run = task_tuple try: with timeout(seconds=2): - result = task.apply_async(args=[command], queue=queue) - except Exception as e: - exception_traceback = "Celery Task ID: {}\n{}".format(key, - traceback.format_exc()) + result = task_to_run.apply_async(args=[command], queue=queue) + except Exception as e: # pylint: disable=broad-except + exception_traceback = "Celery Task ID: {}\n{}".format(key, traceback.format_exc()) result = ExceptionWithTraceback(e, exception_traceback) return key, command, result @@ -148,13 +156,13 @@ def __init__(self): self.tasks = {} self.last_state = {} - def start(self): + def start(self) -> None: self.log.debug( 'Starting Celery Executor using %s processes for syncing', self._sync_parallelism ) - def _num_tasks_per_send_process(self, to_send_count): + def _num_tasks_per_send_process(self, to_send_count: int) -> int: """ How many Celery tasks should each worker process send. @@ -164,34 +172,29 @@ def _num_tasks_per_send_process(self, to_send_count): return max(1, int(math.ceil(1.0 * to_send_count / self._sync_parallelism))) - def _num_tasks_per_fetch_process(self): + def _num_tasks_per_fetch_process(self) -> int: """ How many Celery tasks should be sent to each worker process. 
:return: Number of tasks that should be used per process :rtype: int """ - return max(1, - int(math.ceil(1.0 * len(self.tasks) / self._sync_parallelism))) + return max(1, int(math.ceil(1.0 * len(self.tasks) / self._sync_parallelism))) - def trigger_tasks(self, open_slots): + def trigger_tasks(self, open_slots: int) -> None: """ Overwrite trigger_tasks function from BaseExecutor :param open_slots: Number of open slots :return: """ - sorted_queue = sorted( - [(k, v) for k, v in self.queued_tasks.items()], - key=lambda x: x[1][1], - reverse=True) + sorted_queue = self.order_queued_tasks_by_priority() - task_tuples_to_send = [] + task_tuples_to_send: List[TaskInstanceInCelery] = [] for _ in range(min((open_slots, len(self.queued_tasks)))): key, (command, _, queue, simple_ti) = sorted_queue.pop(0) - task_tuples_to_send.append((key, simple_ti, command, queue, - execute_command)) + task_tuples_to_send.append((key, simple_ti, command, queue, execute_command)) cached_celery_backend = None if task_tuples_to_send: @@ -202,7 +205,7 @@ def trigger_tasks(self, open_slots): cached_celery_backend = tasks[0].backend if task_tuples_to_send: - # Use chunking instead of a work queue to reduce context switching + # Use chunks instead of a work queue to reduce context switching # since tasks are roughly uniform in size chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send)) num_processes = min(len(task_tuples_to_send), self._sync_parallelism) @@ -227,11 +230,22 @@ def trigger_tasks(self, open_slots): # and expect scheduler loop to deal with it. self.queued_tasks.pop(key) result.backend = cached_celery_backend - self.running[key] = command + self.running.add(key) self.tasks[key] = result self.last_state[key] = celery_states.PENDING - def sync(self): + def order_queued_tasks_by_priority(self) -> List[Tuple[TaskInstanceKeyType, QueuedTaskInstanceType]]: + """ + Orders the queued tasks by priority. + + :return: List of tuples from the queued_tasks according to the priority. 
+ """ + return sorted( + [(k, v) for k, v in self.queued_tasks.items()], + key=lambda x: x[1][1], + reverse=True) + + def sync(self) -> None: num_processes = min(len(self.tasks), self._sync_parallelism) if num_processes == 0: self.log.debug("No task to query celery, skipping sync") @@ -243,7 +257,7 @@ def sync(self): # Recreate the process pool each sync in case processes in the pool die self._sync_pool = Pool(processes=num_processes) - # Use chunking instead of a work queue to reduce context switching since tasks are + # Use chunks instead of a work queue to reduce context switching since tasks are # roughly uniform in size chunksize = self._num_tasks_per_fetch_process() @@ -256,6 +270,12 @@ def sync(self): self._sync_pool.join() self.log.debug("Inquiries completed.") + self.update_task_states(task_keys_to_states) + + def update_task_states(self, + task_keys_to_states: List[Union[TaskInstanceStateType, + ExceptionWithTraceback]]) -> None: + """Updates states of the tasks.""" for key_and_state in task_keys_to_states: if isinstance(key_and_state, ExceptionWithTraceback): self.log.error( @@ -264,30 +284,44 @@ def sync(self): ) continue key, state = key_and_state - try: - if self.last_state[key] != state: - if state == celery_states.SUCCESS: - self.success(key) - del self.tasks[key] - del self.last_state[key] - elif state == celery_states.FAILURE: - self.fail(key) - del self.tasks[key] - del self.last_state[key] - elif state == celery_states.REVOKED: - self.fail(key) - del self.tasks[key] - del self.last_state[key] - else: - self.log.info("Unexpected state: %s", state) - self.last_state[key] = state - except Exception: - self.log.exception("Error syncing the Celery executor, ignoring it.") - - def end(self, synchronous=False): + self.update_task_state(key, state) + + def update_task_state(self, key: TaskInstanceKeyType, state: str) -> None: + """Updates state of a single task.""" + # noinspection PyBroadException + try: + if self.last_state[key] != state: + if state == celery_states.SUCCESS: + self.success(key) + del self.tasks[key] + del self.last_state[key] + elif state == celery_states.FAILURE: + self.fail(key) + del self.tasks[key] + del self.last_state[key] + elif state == celery_states.REVOKED: + self.fail(key) + del self.tasks[key] + del self.last_state[key] + else: + self.log.info("Unexpected state: %s", state) + self.last_state[key] = state + except Exception: # pylint: disable=broad-except + self.log.exception("Error syncing the Celery executor, ignoring it.") + + def end(self, synchronous: bool = False) -> None: if synchronous: - while any([ - task.state not in celery_states.READY_STATES - for task in self.tasks.values()]): + while any([task.state not in celery_states.READY_STATES for task in self.tasks.values()]): time.sleep(5) self.sync() + + def execute_async(self, + key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None): + """Do not allow async execution for Celery executor.""" + raise AirflowException("No Async execution for Celery executor.") + + def terminate(self): + pass diff --git a/airflow/executors/dask_executor.py b/airflow/executors/dask_executor.py index b355c4b6e7e86..01a8f703128cd 100644 --- a/airflow/executors/dask_executor.py +++ b/airflow/executors/dask_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -16,14 +15,16 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. - +"""Dask executor.""" import subprocess -import warnings +from typing import Any, Dict, Optional -import distributed +from distributed import Client, Future, as_completed +from distributed.security import Security from airflow.configuration import conf -from airflow.executors.base_executor import BaseExecutor +from airflow.executors.base_executor import NOT_STARTED_MESSAGE, BaseExecutor, CommandType +from airflow.models.taskinstance import TaskInstanceKeyType class DaskExecutor(BaseExecutor): @@ -31,21 +32,20 @@ class DaskExecutor(BaseExecutor): DaskExecutor submits tasks to a Dask Distributed cluster. """ def __init__(self, cluster_address=None): + super().__init__(parallelism=0) if cluster_address is None: cluster_address = conf.get('dask', 'cluster_address') - if not cluster_address: - raise ValueError( - 'Please provide a Dask cluster address in airflow.cfg') + assert cluster_address, 'Please provide a Dask cluster address in airflow.cfg' self.cluster_address = cluster_address # ssl / tls parameters self.tls_ca = conf.get('dask', 'tls_ca') self.tls_key = conf.get('dask', 'tls_key') self.tls_cert = conf.get('dask', 'tls_cert') - super().__init__(parallelism=0) + self.client: Optional[Client] = None + self.futures: Optional[Dict[Future, TaskInstanceKeyType]] = None - def start(self): + def start(self) -> None: if self.tls_ca or self.tls_key or self.tls_cert: - from distributed.security import Security security = Security( tls_client_key=self.tls_key, tls_client_cert=self.tls_cert, @@ -55,23 +55,25 @@ def start(self): else: security = None - self.client = distributed.Client(self.cluster_address, security=security) + self.client = Client(self.cluster_address, security=security) self.futures = {} - def execute_async(self, key, command, queue=None, executor_config=None): - if queue is not None: - warnings.warn( - 'DaskExecutor does not support queues. ' - 'All tasks will be run in the same cluster' - ) + def execute_async(self, + key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None) -> None: + assert self.futures, NOT_STARTED_MESSAGE def airflow_run(): return subprocess.check_call(command, close_fds=True) + assert self.client, "The Dask executor has not been started yet!" 
future = self.client.submit(airflow_run, pure=False) self.futures[future] = key - def _process_future(self, future): + def _process_future(self, future: Future) -> None: + assert self.futures, NOT_STARTED_MESSAGE if future.done(): key = self.futures[future] if future.exception(): @@ -84,15 +86,20 @@ def _process_future(self, future): self.success(key) self.futures.pop(future) - def sync(self): + def sync(self) -> None: + assert self.futures, NOT_STARTED_MESSAGE # make a copy so futures can be popped during iteration for future in self.futures.copy(): self._process_future(future) - def end(self): - for future in distributed.as_completed(self.futures.copy()): + def end(self) -> None: + assert self.client, NOT_STARTED_MESSAGE + assert self.futures, NOT_STARTED_MESSAGE + self.client.cancel(list(self.futures.keys())) + for future in as_completed(self.futures.copy()): self._process_future(future) def terminate(self): + assert self.futures, NOT_STARTED_MESSAGE self.client.cancel(self.futures.keys()) self.end() diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py new file mode 100644 index 0000000000000..dcf2e398ca957 --- /dev/null +++ b/airflow/executors/executor_loader.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""All executors.""" +from typing import Optional + +from airflow.executors.base_executor import BaseExecutor + + +class ExecutorLoader: + """ + Keeps constants for all the currently available executors. + """ + + LOCAL_EXECUTOR = "LocalExecutor" + SEQUENTIAL_EXECUTOR = "SequentialExecutor" + CELERY_EXECUTOR = "CeleryExecutor" + DASK_EXECUTOR = "DaskExecutor" + KUBERNETES_EXECUTOR = "KubernetesExecutor" + + _default_executor: Optional[BaseExecutor] = None + + @classmethod + def get_default_executor(cls) -> BaseExecutor: + """Creates a new instance of the configured executor if none exists and returns it""" + if cls._default_executor is not None: + return cls._default_executor + + from airflow.configuration import conf + executor_name = conf.get('core', 'EXECUTOR') + + cls._default_executor = ExecutorLoader._get_executor(executor_name) + + from airflow import LoggingMixin + log = LoggingMixin().log + log.info("Using executor %s", executor_name) + + return cls._default_executor + + @staticmethod + def _get_executor(executor_name: str) -> BaseExecutor: + """ + Creates a new instance of the named executor. 
+ In case the executor name is unknown in airflow, + look for it in the plugins + """ + if executor_name == ExecutorLoader.LOCAL_EXECUTOR: + from airflow.executors.local_executor import LocalExecutor + return LocalExecutor() + elif executor_name == ExecutorLoader.SEQUENTIAL_EXECUTOR: + from airflow.executors.sequential_executor import SequentialExecutor + return SequentialExecutor() + elif executor_name == ExecutorLoader.CELERY_EXECUTOR: + from airflow.executors.celery_executor import CeleryExecutor + return CeleryExecutor() + elif executor_name == ExecutorLoader.DASK_EXECUTOR: + from airflow.executors.dask_executor import DaskExecutor + return DaskExecutor() + elif executor_name == ExecutorLoader.KUBERNETES_EXECUTOR: + from airflow.executors.kubernetes_executor import KubernetesExecutor + return KubernetesExecutor() + else: + # Load plugins here for executors as at that time the plugins might not have been initialized yet + # TODO: verify the above and remove two lines below in case plugins are always initialized first + from airflow import plugins_manager + plugins_manager.integrate_executor_plugins() + executor_path = executor_name.split('.') + assert len(executor_path) == 2, f"Executor {executor_name} not supported: " \ + f"please specify in format plugin_module.executor" + + assert executor_path[0] in globals(), f"Executor {executor_name} not supported" + return globals()[executor_path[0]].__dict__[executor_path[1]]() diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index f6d61472eeaf1..0bc6db7e29086 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -16,28 +16,31 @@ # under the License. """Kubernetes executor""" import base64 +import datetime import hashlib import json import multiprocessing import re -from queue import Empty -from typing import Union +from queue import Empty, Queue # pylint: disable=unused-import +from typing import Any, Dict, Optional, Tuple, Union from uuid import uuid4 import kubernetes from dateutil import parser from kubernetes import client, watch +from kubernetes.client import Configuration from kubernetes.client.rest import ApiException from airflow import settings from airflow.configuration import conf from airflow.exceptions import AirflowConfigException, AirflowException -from airflow.executors.base_executor import BaseExecutor +from airflow.executors.base_executor import NOT_STARTED_MESSAGE, BaseExecutor, CommandType from airflow.kubernetes.kube_client import get_kube_client from airflow.kubernetes.pod_generator import PodGenerator from airflow.kubernetes.pod_launcher import PodLauncher from airflow.kubernetes.worker_configuration import WorkerConfiguration from airflow.models import KubeResourceVersion, KubeWorkerIdentifier, TaskInstance +from airflow.models.taskinstance import TaskInstanceKeyType from airflow.utils.db import create_session, provide_session from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State @@ -45,6 +48,15 @@ MAX_POD_ID_LEN = 253 MAX_LABEL_LEN = 63 +# TaskInstance key, command, configuration +KubernetesJobType = Tuple[TaskInstanceKeyType, CommandType, Any] + +# key, state, pod_id, resource_version +KubernetesResultsType = Tuple[TaskInstanceKeyType, Optional[str], str, str] + +# pod_id, state, labels, resource_version +KubernetesWatchType = Tuple[str, Optional[str], Dict[str, str], str] + class KubeConfig: # pylint: disable=too-many-instance-attributes """Configuration for Kubernetes""" @@ 
-241,7 +253,12 @@ def _validate(self): class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): """Watches for Kubernetes jobs""" - def __init__(self, namespace, watcher_queue, resource_version, worker_uuid, kube_config): + def __init__(self, + namespace: str, + watcher_queue: 'Queue[KubernetesWatchType]', + resource_version: Optional[str], + worker_uuid: Optional[str], + kube_config: Configuration): multiprocessing.Process.__init__(self) self.namespace = namespace self.worker_uuid = worker_uuid @@ -249,9 +266,10 @@ def __init__(self, namespace, watcher_queue, resource_version, worker_uuid, kube self.resource_version = resource_version self.kube_config = kube_config - def run(self): + def run(self) -> None: """Performs watching""" - kube_client = get_kube_client() + kube_client: client.CoreV1Api = get_kube_client() + assert self.worker_uuid, NOT_STARTED_MESSAGE while True: try: self.resource_version = self._run(kube_client, self.resource_version, @@ -263,7 +281,11 @@ def run(self): self.log.warning('Watch died gracefully, starting back up with: ' 'last resource_version: %s', self.resource_version) - def _run(self, kube_client, resource_version, worker_uuid, kube_config): + def _run(self, + kube_client: client.CoreV1Api, + resource_version: Optional[str], + worker_uuid: str, + kube_config: Any) -> Optional[str]: self.log.info( 'Event: and now my watch begins starting at resource_version: %s', resource_version @@ -277,7 +299,7 @@ def _run(self, kube_client, resource_version, worker_uuid, kube_config): for key, value in kube_config.kube_client_request_args.items(): kwargs[key] = value - last_resource_version = None + last_resource_version: Optional[str] = None for event in watcher.stream(kube_client.list_namespaced_pod, self.namespace, **kwargs): task = event['object'] @@ -295,7 +317,7 @@ def _run(self, kube_client, resource_version, worker_uuid, kube_config): return last_resource_version - def process_error(self, event): + def process_error(self, event: Any) -> str: """Process error response""" self.log.error( 'Encountered Error response from k8s list namespaced pod stream => %s', @@ -314,7 +336,7 @@ def process_error(self, event): (raw_object['reason'], raw_object['code'], raw_object['message']) ) - def process_status(self, pod_id, status, labels, resource_version): + def process_status(self, pod_id: str, status: str, labels: Dict[str, str], resource_version: str) -> None: """Process status response""" if status == 'Pending': self.log.info('Event: %s Pending', pod_id) @@ -335,7 +357,13 @@ def process_status(self, pod_id, status, labels, resource_version): class AirflowKubernetesScheduler(LoggingMixin): """Airflow Scheduler for Kubernetes""" - def __init__(self, kube_config, task_queue, result_queue, kube_client, worker_uuid): + def __init__(self, + kube_config: Any, + task_queue: 'Queue[KubernetesJobType]', + result_queue: 'Queue[KubernetesResultsType]', + kube_client: client.CoreV1Api, + worker_uuid: str): + super().__init__() self.log.debug("Creating Kubernetes executor") self.kube_config = kube_config self.task_queue = task_queue @@ -350,7 +378,7 @@ def __init__(self, kube_config, task_queue, result_queue, kube_client, worker_uu self.worker_uuid = worker_uuid self.kube_watcher = self._make_kube_watcher() - def _make_kube_watcher(self): + def _make_kube_watcher(self) -> KubernetesJobWatcher: resource_version = KubeResourceVersion.get_current_resource_version() watcher = KubernetesJobWatcher(self.namespace, self.watcher_queue, resource_version, self.worker_uuid, self.kube_config) 
@@ -366,7 +394,7 @@ def _health_check_kube_watcher(self): 'Process died for unknown reasons') self.kube_watcher = self._make_kube_watcher() - def run_next(self, next_job): + def run_next(self, next_job: KubernetesJobType) -> None: """ The run_next command will check the task_queue for any un-run jobs. It will then create a unique job-id, launch that job in the cluster, @@ -408,7 +436,7 @@ def delete_pod(self, pod_id: str) -> None: if e.status != 404: raise - def sync(self): + def sync(self) -> None: """ The sync function checks the status of all currently running kubernetes jobs. If a job is completed, its status is placed in the result queue to @@ -428,7 +456,7 @@ def sync(self): except Empty: break - def process_watcher_task(self, task): + def process_watcher_task(self, task: KubernetesWatchType) -> None: """Process the task by watcher.""" pod_id, state, labels, resource_version = task self.log.info( @@ -441,7 +469,7 @@ def process_watcher_task(self, task): self.result_queue.put((key, state, pod_id, resource_version)) @staticmethod - def _strip_unsafe_kubernetes_special_chars(string): + def _strip_unsafe_kubernetes_special_chars(string: str) -> str: """ Kubernetes only supports lowercase alphanumeric characters and "-" and "." in the pod name @@ -456,7 +484,7 @@ def _strip_unsafe_kubernetes_special_chars(string): return ''.join(ch.lower() for ind, ch in enumerate(string) if ch.isalnum()) @staticmethod - def _make_safe_pod_id(safe_dag_id, safe_task_id, safe_uuid): + def _make_safe_pod_id(safe_dag_id: str, safe_task_id: str, safe_uuid: str) -> str: """ Kubernetes pod names must be <= 253 chars and must pass the following regex for validation @@ -474,7 +502,7 @@ def _make_safe_pod_id(safe_dag_id, safe_task_id, safe_uuid): return safe_pod_id @staticmethod - def _make_safe_label_value(string): + def _make_safe_label_value(string: str) -> str: """ Valid label values must be 63 characters or less and must be empty or begin and end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_), @@ -493,7 +521,7 @@ def _make_safe_label_value(string): return safe_label @staticmethod - def _create_pod_id(dag_id, task_id): + def _create_pod_id(dag_id: str, task_id: str) -> str: safe_dag_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars( dag_id) safe_task_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars( @@ -504,7 +532,7 @@ def _create_pod_id(dag_id, task_id): safe_uuid) @staticmethod - def _label_safe_datestring_to_datetime(string): + def _label_safe_datestring_to_datetime(string: str) -> datetime.datetime: """ Kubernetes doesn't permit ":" in labels. 
ISO datetime format uses ":" but not "_", let's @@ -516,7 +544,7 @@ def _label_safe_datestring_to_datetime(string): return parser.parse(string.replace('_plus_', '+').replace("_", ":")) @staticmethod - def _datetime_to_label_safe_datestring(datetime_obj): + def _datetime_to_label_safe_datestring(datetime_obj: datetime.datetime) -> str: """ Kubernetes doesn't like ":" in labels, since ISO datetime format uses ":" but not "_" let's @@ -527,7 +555,7 @@ def _datetime_to_label_safe_datestring(datetime_obj): """ return datetime_obj.isoformat().replace(":", "_").replace('+', '_plus_') - def _labels_to_key(self, labels): + def _labels_to_key(self, labels: Dict[str, str]) -> Optional[TaskInstanceKeyType]: try_num = 1 try: try_num = int(labels.get('try_number', '1')) @@ -567,14 +595,14 @@ def _labels_to_key(self, labels): ) dag_id = task.dag_id task_id = task.task_id - return (dag_id, task_id, ex_time, try_num) + return dag_id, task_id, ex_time, try_num self.log.warning( 'Failed to find and match task details to a pod; labels: %s', labels ) return None - def _flush_watcher_queue(self): + def _flush_watcher_queue(self) -> None: self.log.debug('Executor shutting down, watcher_queue approx. size=%d', self.watcher_queue.qsize()) while True: try: @@ -585,7 +613,7 @@ def _flush_watcher_queue(self): except Empty: break - def terminate(self): + def terminate(self) -> None: """Terminates the watcher.""" self.log.debug("Terminating kube_watcher...") self.kube_watcher.terminate() @@ -601,18 +629,19 @@ def terminate(self): class KubernetesExecutor(BaseExecutor, LoggingMixin): """Executor for Kubernetes""" + def __init__(self): self.kube_config = KubeConfig() - self.task_queue = None - self.result_queue = None - self.kube_scheduler = None - self.kube_client = None - self.worker_uuid = None self._manager = multiprocessing.Manager() + self.task_queue: 'Queue[KubernetesJobType]' = self._manager.Queue() + self.result_queue: 'Queue[KubernetesResultsType]' = self._manager.Queue() + self.kube_scheduler: Optional[AirflowKubernetesScheduler] = None + self.kube_client: Optional[client.CoreV1Api] = None + self.worker_uuid: Optional[str] = None super().__init__(parallelism=self.kube_config.parallelism) @provide_session - def clear_not_launched_queued_tasks(self, session=None): + def clear_not_launched_queued_tasks(self, session=None) -> None: """ If the airflow scheduler restarts with pending "Queued" tasks, the tasks may or may not @@ -628,6 +657,7 @@ def clear_not_launched_queued_tasks(self, session=None): proper support for State.LAUNCHED """ + assert self.kube_client, NOT_STARTED_MESSAGE queued_tasks = session\ .query(TaskInstance)\ .filter(TaskInstance.state == State.QUEUED).all() @@ -667,7 +697,7 @@ def clear_not_launched_queued_tasks(self, session=None): TaskInstance.execution_date == task.execution_date ).update({TaskInstance.state: State.NONE}) - def _inject_secrets(self): + def _inject_secrets(self) -> None: def _create_or_update_secret(secret_name, secret_path): try: return self.kube_client.create_namespaced_secret( @@ -703,18 +733,17 @@ def _create_or_update_secret(secret_name, secret_path): for service_account in name_path_pair_list: _create_or_update_secret(service_account['name'], service_account['path']) - def start(self): + def start(self) -> None: """Starts the executor""" self.log.info('Start Kubernetes executor') self.worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid() + assert self.worker_uuid, "Could not get worker_uuid" self.log.debug('Start with worker_uuid: %s', 
self.worker_uuid) # always need to reset resource version since we don't know # when we last started, note for behavior below # https://github.com/kubernetes-client/python/blob/master/kubernetes/docs # /CoreV1Api.md#list_namespaced_pod KubeResourceVersion.reset_resource_version() - self.task_queue = self._manager.Queue() - self.result_queue = self._manager.Queue() self.kube_client = get_kube_client() self.kube_scheduler = AirflowKubernetesScheduler( self.kube_config, self.task_queue, self.result_queue, @@ -723,7 +752,11 @@ def start(self): self._inject_secrets() self.clear_not_launched_queued_tasks() - def execute_async(self, key, command, queue=None, executor_config=None): + def execute_async(self, + key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None) -> None: """Executes task asynchronously""" self.log.info( 'Add task %s with command %s with executor_config %s', @@ -731,14 +764,19 @@ def execute_async(self, key, command, queue=None, executor_config=None): ) kube_executor_config = PodGenerator.from_obj(executor_config) + assert self.task_queue, NOT_STARTED_MESSAGE self.task_queue.put((key, command, kube_executor_config)) - def sync(self): + def sync(self) -> None: """Synchronize task state.""" if self.running: self.log.debug('self.running: %s', self.running) if self.queued_tasks: self.log.debug('self.queued: %s', self.queued_tasks) + assert self.kube_scheduler, NOT_STARTED_MESSAGE + assert self.kube_config, NOT_STARTED_MESSAGE + assert self.result_queue, NOT_STARTED_MESSAGE + assert self.task_queue, NOT_STARTED_MESSAGE self.kube_scheduler.sync() last_resource_version = None @@ -778,18 +816,20 @@ def sync(self): break # pylint: enable=too-many-nested-blocks - def _change_state(self, key, state, pod_id: str) -> None: + def _change_state(self, key: TaskInstanceKeyType, state: Optional[str], pod_id: str) -> None: if state != State.RUNNING: if self.kube_config.delete_worker_pods: + assert self.kube_scheduler, NOT_STARTED_MESSAGE self.kube_scheduler.delete_pod(pod_id) self.log.info('Deleted pod: %s', str(key)) try: - self.running.pop(key) + self.running.remove(key) except KeyError: self.log.debug('Could not find key: %s', str(key)) self.event_buffer[key] = state - def _flush_task_queue(self): + def _flush_task_queue(self) -> None: + assert self.task_queue, NOT_STARTED_MESSAGE self.log.debug('Executor shutting down, task_queue approximate size=%d', self.task_queue.qsize()) while True: try: @@ -800,7 +840,8 @@ def _flush_task_queue(self): except Empty: break - def _flush_result_queue(self): + def _flush_result_queue(self) -> None: + assert self.result_queue, NOT_STARTED_MESSAGE self.log.debug('Executor shutting down, result_queue approximate size=%d', self.result_queue.qsize()) while True: # pylint: disable=too-many-nested-blocks try: @@ -820,8 +861,11 @@ def _flush_result_queue(self): except Empty: break - def end(self): + def end(self) -> None: """Called when the executor shuts down""" + assert self.task_queue, NOT_STARTED_MESSAGE + assert self.result_queue, NOT_STARTED_MESSAGE + assert self.kube_scheduler, NOT_STARTED_MESSAGE self.log.info('Shutting down Kubernetes executor') self.log.debug('Flushing task_queue...') self._flush_task_queue() @@ -833,3 +877,6 @@ def end(self): if self.kube_scheduler: self.kube_scheduler.terminate() self._manager.shutdown() + + def terminate(self): + """Terminate the executor is not doing anything.""" diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py index 
0086dcbf12a01..9173a8e61fb29 100644 --- a/airflow/executors/local_executor.py +++ b/airflow/executors/local_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -43,40 +42,44 @@ This option could lead to the unification of the executor implementations, running locally, into just one `LocalExecutor` with multiple modes. """ - -import multiprocessing import subprocess -from queue import Empty - -from airflow.executors.base_executor import BaseExecutor +from multiprocessing import Manager, Process +from multiprocessing.managers import SyncManager +from queue import Empty, Queue # pylint: disable=unused-import # noqa: F401 +from typing import Any, List, Optional, Tuple, Union # pylint: disable=unused-import # noqa: F401 + +from airflow import AirflowException +from airflow.executors.base_executor import NOT_STARTED_MESSAGE, PARALLELISM, BaseExecutor, CommandType +from airflow.models.taskinstance import ( # pylint: disable=unused-import # noqa: F401 + TaskInstanceKeyType, TaskInstanceStateType, +) from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State +# This is a work to be executed by a worker. +# It can Key and Command - but it can also be None, None which is actually a +# "Poison Pill" - worker seeing Poison Pill should take the pill and ... die instantly. +ExecutorWorkType = Tuple[Optional[TaskInstanceKeyType], Optional[CommandType]] -class LocalWorker(multiprocessing.Process, LoggingMixin): - """LocalWorker Process implementation to run airflow commands. Executes the given - command and puts the result into a result queue when done, terminating execution.""" +class LocalWorkerBase(Process, LoggingMixin): + """ + LocalWorkerBase implementation to run airflow commands. Executes the given + command and puts the result into a result queue when done, terminating execution. - def __init__(self, result_queue): - """ - :param result_queue: the queue to store result states tuples (key, State) - :type result_queue: multiprocessing.Queue - """ + :param result_queue: the queue to store result state + """ + def __init__(self, result_queue: 'Queue[TaskInstanceStateType]'): super().__init__() - self.daemon = True - self.result_queue = result_queue - self.key = None - self.command = None + self.daemon: bool = True + self.result_queue: 'Queue[TaskInstanceStateType]' = result_queue - def execute_work(self, key, command): + def execute_work(self, key: TaskInstanceKeyType, command: CommandType) -> None: """ Executes command received and stores result state in queue. - :param key: the key to identify the TI - :type key: tuple(dag_id, task_id, execution_date) + :param key: the key to identify the task instance :param command: the command to execute - :type command: str """ if key is None: return @@ -89,87 +92,146 @@ def execute_work(self, key, command): self.log.error("Failed to execute task %s.", str(e)) self.result_queue.put((key, state)) - def run(self): - self.execute_work(self.key, self.command) +class LocalWorker(LocalWorkerBase): + """ + Local worker that executes the task. -class QueuedLocalWorker(LocalWorker): + :param result_queue: queue where results of the tasks are put. 
+ :param key: key identifying task instance + :param command: Command to execute + """ + def __init__(self, + result_queue: 'Queue[TaskInstanceStateType]', + key: TaskInstanceKeyType, + command: CommandType): + super().__init__(result_queue) + self.key: TaskInstanceKeyType = key + self.command: CommandType = command - """LocalWorker implementation that is waiting for tasks from a queue and will - continue executing commands as they become available in the queue. It will terminate - execution once the poison token is found.""" + def run(self) -> None: + self.execute_work(key=self.key, command=self.command) - def __init__(self, task_queue, result_queue): + +class QueuedLocalWorker(LocalWorkerBase): + """ + LocalWorker implementation that is waiting for tasks from a queue and will + continue executing commands as they become available in the queue. + It will terminate execution once the poison token is found. + + :param task_queue: queue from which worker reads tasks + :param result_queue: queue where worker puts results after finishing tasks + """ + def __init__(self, + task_queue: 'Queue[ExecutorWorkType]', + result_queue: 'Queue[TaskInstanceStateType]'): super().__init__(result_queue=result_queue) self.task_queue = task_queue - def run(self): + def run(self) -> None: while True: key, command = self.task_queue.get() try: - if key is None: + if key is None or command is None: # Received poison pill, no more tasks to run break - self.execute_work(key, command) + self.execute_work(key=key, command=command) finally: self.task_queue.task_done() class LocalExecutor(BaseExecutor): """ - LocalExecutor executes tasks locally in parallel. It uses the - multiprocessing Python library and queues to parallelize the execution + LocalExecutor executes tasks locally in parallel. + It uses the multiprocessing Python library and queues to parallelize the execution of tasks. - """ - class _UnlimitedParallelism: - """Implements LocalExecutor with unlimited parallelism, starting one process - per each command to execute.""" + :param parallelism: how many parallel processes are run in the executor + """ + def __init__(self, parallelism: int = PARALLELISM): + super().__init__(parallelism=parallelism) + self.manager: Optional[SyncManager] = None + self.result_queue: Optional['Queue[TaskInstanceStateType]'] = None + self.workers: List[QueuedLocalWorker] = [] + self.workers_used: int = 0 + self.workers_active: int = 0 + self.impl: Optional[Union['LocalExecutor.UnlimitedParallelism', + 'LocalExecutor.LimitedParallelism']] = None + + class UnlimitedParallelism: + """ + Implements LocalExecutor with unlimited parallelism, starting one process + per each command to execute. - def __init__(self, executor): - """ - :param executor: the executor instance to implement. - :type executor: LocalExecutor - """ - self.executor = executor + :param executor: the executor instance to implement. 
+ """ + def __init__(self, executor: 'LocalExecutor'): + self.executor: 'LocalExecutor' = executor - def start(self): + def start(self) -> None: + """Starts the executor.""" self.executor.workers_used = 0 self.executor.workers_active = 0 - def execute_async(self, key, command): + # noinspection PyUnusedLocal + def execute_async(self, + key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None) -> None: \ + # pylint: disable=unused-argument # pragma: no cover """ - :param key: the key to identify the TI - :type key: tuple(dag_id, task_id, execution_date) + Executes task asynchronously. + + :param key: the key to identify the task instance :param command: the command to execute - :type command: str + :param queue: Name of the queue + :param executor_config: configuration for the executor """ - local_worker = LocalWorker(self.executor.result_queue) - local_worker.key = key - local_worker.command = command + assert self.executor.result_queue, NOT_STARTED_MESSAGE + local_worker = LocalWorker(self.executor.result_queue, key=key, command=command) self.executor.workers_used += 1 self.executor.workers_active += 1 local_worker.start() - def sync(self): + def sync(self) -> None: + """ + Sync will get called periodically by the heartbeat method. + """ + if not self.executor.result_queue: + raise AirflowException("Executor should be started first") while not self.executor.result_queue.empty(): results = self.executor.result_queue.get() self.executor.change_state(*results) self.executor.workers_active -= 1 - def end(self): + def end(self) -> None: + """ + This method is called when the caller is done submitting job and + wants to wait synchronously for the job submitted previously to be + all done. + """ while self.executor.workers_active > 0: self.executor.sync() - class _LimitedParallelism: - """Implements LocalExecutor with limited parallelism using a task queue to - coordinate work distribution.""" - - def __init__(self, executor): - self.executor = executor + class LimitedParallelism: + """ + Implements LocalExecutor with limited parallelism using a task queue to + coordinate work distribution. - def start(self): + :param executor: the executor instance to implement. + """ + def __init__(self, executor: 'LocalExecutor'): + self.executor: 'LocalExecutor' = executor + self.queue: Optional['Queue[ExecutorWorkType]'] = None + + def start(self) -> None: + """Starts limited parallelism implementation.""" + if not self.executor.manager: + raise AirflowException("Executor must be started!") self.queue = self.executor.manager.Queue() + if not self.executor.result_queue: + raise AirflowException("Executor must be started!") self.executor.workers = [ QueuedLocalWorker(self.queue, self.executor.result_queue) for _ in range(self.executor.parallelism) @@ -177,19 +239,31 @@ def start(self): self.executor.workers_used = len(self.executor.workers) - for w in self.executor.workers: - w.start() + for worker in self.executor.workers: + worker.start() - def execute_async(self, key, command): + # noinspection PyUnusedLocal + def execute_async(self, + key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None) -> None: \ + # pylint: disable=unused-argument # pragma: no cover """ - :param key: the key to identify the TI - :type key: tuple(dag_id, task_id, execution_date) + Executes task asynchronously. 
+ + :param key: the key to identify the task instance :param command: the command to execute - :type command: str - """ + :param queue: name of the queue + :param executor_config: configuration for the executor + """ + assert self.queue, NOT_STARTED_MESSAGE self.queue.put((key, command)) def sync(self): + """ + Sync will get called periodically by the heartbeat method. + """ while True: try: results = self.executor.result_queue.get_nowait() @@ -201,7 +275,7 @@ def sync(self): break def end(self): - # Sending poison pill to all worker + """Ends the executor. Sends the poison pill to all workers.""" for _ in self.executor.workers: self.queue.put((None, None)) @@ -209,23 +283,42 @@ def end(self): self.queue.join() self.executor.sync() - def start(self): - self.manager = multiprocessing.Manager() + def start(self) -> None: + """Starts the executor""" + self.manager = Manager() self.result_queue = self.manager.Queue() self.workers = [] self.workers_used = 0 self.workers_active = 0 - self.impl = (LocalExecutor._UnlimitedParallelism(self) if self.parallelism == 0 - else LocalExecutor._LimitedParallelism(self)) + self.impl = (LocalExecutor.UnlimitedParallelism(self) if self.parallelism == 0 + else LocalExecutor.LimitedParallelism(self)) self.impl.start() - def execute_async(self, key, command, queue=None, executor_config=None): - self.impl.execute_async(key=key, command=command) + def execute_async(self, key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None) -> None: + """Execute asynchronously.""" + assert self.impl, NOT_STARTED_MESSAGE + self.impl.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) - def sync(self): + def sync(self) -> None: + """ + Sync will get called periodically by the heartbeat method. + """ + assert self.impl, NOT_STARTED_MESSAGE self.impl.sync() - def end(self): + def end(self) -> None: + """ + Ends the executor. + :return: + """ + assert self.impl, NOT_STARTED_MESSAGE + assert self.manager, NOT_STARTED_MESSAGE self.impl.end() self.manager.shutdown() + + def terminate(self): + """Terminate the executor is not doing anything.""" diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py index bb9303c7ccef5..a0d58b7350c7f 100644 --- a/airflow/executors/sequential_executor.py +++ b/airflow/executors/sequential_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -16,10 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Sequential executor.""" import subprocess +from typing import Any, Optional -from airflow.executors.base_executor import BaseExecutor +from airflow.executors.base_executor import BaseExecutor, CommandType +from airflow.models.taskinstance import TaskInstanceKeyType from airflow.utils.state import State @@ -32,14 +33,19 @@ class SequentialExecutor(BaseExecutor): Since we want airflow to work out of the box, it defaults to this SequentialExecutor alongside sqlite as you first install it. 
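Both `LocalExecutor` and `SequentialExecutor` are driven through the same `BaseExecutor` lifecycle: `start()`, then `execute_async()` per task, periodic `sync()` (normally via the heartbeat), and finally `end()`. A hedged usage sketch of that lifecycle, assuming `CommandType` is an argv-style list in this tree; the key and command literals are illustrative only, not taken from the patch:

```python
from datetime import datetime

from airflow.executors.local_executor import LocalExecutor

executor = LocalExecutor(parallelism=2)  # parallelism=0 would select the unlimited implementation
executor.start()                         # creates the manager, result queue and workers

# TaskInstanceKeyType is (dag_id, task_id, execution_date, try_number).
key = ("example_dag", "example_task", datetime(2019, 1, 1), 1)
command = ["airflow", "run", "example_dag", "example_task", "2019-01-01T00:00:00"]  # illustrative

executor.execute_async(key=key, command=command)
executor.sync()   # normally called from heartbeat(); drains finished results
executor.end()    # sends poison pills, waits for workers, shuts the manager down
```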
""" + def __init__(self): super().__init__() self.commands_to_run = [] - def execute_async(self, key, command, queue=None, executor_config=None): - self.commands_to_run.append((key, command,)) + def execute_async(self, + key: TaskInstanceKeyType, + command: CommandType, + queue: Optional[str] = None, + executor_config: Optional[Any] = None) -> None: + self.commands_to_run.append((key, command)) - def sync(self): + def sync(self) -> None: for key, command in self.commands_to_run: self.log.info("Executing command: %s", command) @@ -53,4 +59,8 @@ def sync(self): self.commands_to_run = [] def end(self): + """End the executor.""" self.heartbeat() + + def terminate(self): + """Terminate the executor is not doing anything.""" diff --git a/airflow/hooks/__init__.py b/airflow/hooks/__init__.py index 2020e16b02b78..48cdbdf0aa3e1 100644 --- a/airflow/hooks/__init__.py +++ b/airflow/hooks/__init__.py @@ -16,18 +16,4 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# pylint: disable=missing-docstring - -import sys - -# Imports the hooks dynamically while keeping the package API clean, -# abstracting the underlying modules - - -def _integrate_plugins(): - """Integrate plugins to the context""" - from airflow.plugins_manager import hooks_modules - for hooks_module in hooks_modules: - sys.modules[hooks_module.__name__] = hooks_module - globals()[hooks_module._name] = hooks_module # pylint: disable=protected-access +"""Hooks.""" diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py index e0360befcf795..e33237d51462b 100644 --- a/airflow/jobs/backfill_job.py +++ b/airflow/jobs/backfill_job.py @@ -24,11 +24,13 @@ from sqlalchemy.orm.session import Session, make_transient -from airflow import executors, models +from airflow import models from airflow.exceptions import ( AirflowException, DagConcurrencyLimitReached, NoAvailablePoolSlot, PoolNotFound, TaskConcurrencyLimitReached, ) +from airflow.executors.local_executor import LocalExecutor +from airflow.executors.sequential_executor import SequentialExecutor from airflow.jobs.base_job import BaseJob from airflow.models import DAG, DagPickle, DagRun from airflow.ti_deps.dep_context import BACKFILL_QUEUED_DEPS, DepContext @@ -487,8 +489,7 @@ def _per_task_process(task, key, ti, session=None): session.merge(ti) cfg_path = None - if executor.__class__ in (executors.LocalExecutor, - executors.SequentialExecutor): + if executor.__class__ in (LocalExecutor, SequentialExecutor): cfg_path = tmp_configuration_copy() executor.queue_task_instance( @@ -740,8 +741,7 @@ def _execute(self, session=None): # picklin' pickle_id = None - if not self.donot_pickle and self.executor.__class__ not in ( - executors.LocalExecutor, executors.SequentialExecutor): + if not self.donot_pickle and self.executor.__class__ not in (LocalExecutor, SequentialExecutor): pickle = DagPickle(self.dag) session.add(pickle) session.commit() diff --git a/airflow/jobs/base_job.py b/airflow/jobs/base_job.py index 874c0c6f058ce..b554b6ecb0b96 100644 --- a/airflow/jobs/base_job.py +++ b/airflow/jobs/base_job.py @@ -26,9 +26,10 @@ from sqlalchemy.exc import OperationalError from sqlalchemy.orm.session import make_transient -from airflow import executors, models +from airflow import models from airflow.configuration import conf from airflow.exceptions import AirflowException +from airflow.executors.executor_loader import ExecutorLoader from airflow.models.base import ID_LEN, Base from airflow.stats 
import Stats from airflow.utils import helpers, timezone @@ -79,7 +80,7 @@ def __init__( heartrate=None, *args, **kwargs): self.hostname = get_hostname() - self.executor = executor or executors.get_default_executor() + self.executor = executor or ExecutorLoader.get_default_executor() self.executor_class = executor.__class__.__name__ self.start_date = timezone.utcnow() self.latest_heartbeat = timezone.utcnow() diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 2640f1232efe5..518ccb1b75dc2 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -35,21 +35,24 @@ from sqlalchemy import and_, func, not_, or_ from sqlalchemy.orm.session import make_transient -from airflow import executors, models, settings +from airflow import models, settings from airflow.configuration import conf from airflow.exceptions import AirflowException +from airflow.executors.local_executor import LocalExecutor +from airflow.executors.sequential_executor import SequentialExecutor from airflow.jobs.base_job import BaseJob from airflow.models import DAG, DagRun, SlaMiss, errors +from airflow.models.taskinstance import SimpleTaskInstance from airflow.stats import Stats from airflow.ti_deps.dep_context import SCHEDULEABLE_STATES, SCHEDULED_DEPS, DepContext from airflow.ti_deps.deps.pool_slots_available_dep import STATES_TO_COUNT_AS_RUNNING from airflow.utils import asciiart, helpers, timezone from airflow.utils.dag_processing import ( - AbstractDagFileProcessor, DagFileProcessorAgent, SimpleDag, SimpleDagBag, SimpleTaskInstance, - list_py_file_paths, + AbstractDagFileProcessor, DagFileProcessorAgent, SimpleDag, SimpleDagBag, ) from airflow.utils.db import provide_session from airflow.utils.email import get_email_address_list, send_email +from airflow.utils.file import list_py_file_paths from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context from airflow.utils.state import State @@ -64,7 +67,7 @@ class DagFileProcessor(AbstractDagFileProcessor, LoggingMixin): :param dag_id_white_list: If specified, only look at these DAG ID's :type dag_id_white_list: list[unicode] :param zombies: zombie task instances to kill - :type zombies: list[airflow.utils.dag_processing.SimpleTaskInstance] + :type zombies: list[airflow.models.taskinstance.SimpleTaskInstance] """ # Counter that increments every time an instance of this class is created @@ -116,7 +119,7 @@ def _run_file_processor(result_channel, :param thread_name: the name to use for the process that is launched :type thread_name: unicode :param zombies: zombie task instances to kill - :type zombies: list[airflow.utils.dag_processing.SimpleTaskInstance] + :type zombies: list[airflow.models.taskinstance.SimpleTaskInstance] :return: the process that was launched :rtype: multiprocessing.Process """ @@ -1013,7 +1016,7 @@ def _change_state_for_executable_task_instances(self, task_instances, :type task_instances: list[airflow.models.TaskInstance] :param acceptable_states: Filters the TaskInstances updated to be in these states :type acceptable_states: Iterable[State] - :rtype: list[airflow.utils.dag_processing.SimpleTaskInstance] + :rtype: list[airflow.models.taskinstance.SimpleTaskInstance] """ if len(task_instances) == 0: session.commit() @@ -1282,8 +1285,7 @@ def _execute(self): # DAGs can be pickled for easier remote execution by some executors pickle_dags = False - if self.do_pickle and self.executor.__class__ not in \ - (executors.LocalExecutor, executors.SequentialExecutor): + if self.do_pickle and 
self.executor.__class__ not in (LocalExecutor, SequentialExecutor): pickle_dags = True self.log.info("Processing each file at most %s times", self.num_runs) @@ -1494,7 +1496,7 @@ def process_file(self, file_path, zombies, pickle_dags=False, session=None): :param file_path: the path to the Python file that should be executed :type file_path: unicode :param zombies: zombie task instances to kill. - :type zombies: list[airflow.utils.dag_processing.SimpleTaskInstance] + :type zombies: list[airflow.models.taskinstance.SimpleTaskInstance] :param pickle_dags: whether serialize the DAGs found in the file and save them to the db :type pickle_dags: bool diff --git a/airflow/kubernetes/kube_client.py b/airflow/kubernetes/kube_client.py index 2d734d9e1a111..b17e9a007ac09 100644 --- a/airflow/kubernetes/kube_client.py +++ b/airflow/kubernetes/kube_client.py @@ -68,7 +68,7 @@ def _get_client_with_patched_configuration(cfg: Optional[Configuration]) -> clie def get_kube_client(in_cluster: bool = conf.getboolean('kubernetes', 'in_cluster'), cluster_context: Optional[str] = None, - config_file: Optional[str] = None): + config_file: Optional[str] = None) -> client.CoreV1Api: """ Retrieves Kubernetes client diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py index f8d9b99d038a7..38c750085ca4a 100644 --- a/airflow/kubernetes/pod_generator.py +++ b/airflow/kubernetes/pod_generator.py @@ -26,12 +26,10 @@ import kubernetes.client.models as k8s -from airflow.executors import Executors - class PodDefaults: """ - Static defaults for the PodGenerator + Static defaults for Pods """ XCOM_MOUNT_PATH = '/airflow/xcom' SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar' @@ -227,8 +225,9 @@ def from_obj(obj) -> k8s.V1Pod: raise TypeError( 'Cannot convert a non-dictionary or non-PodGenerator ' 'object into a KubernetesExecutorConfig') - - namespaced = obj.get(Executors.KubernetesExecutor, {}) + # We do not want to extract constant here from ExecutorLoader because it is just + # A name in dictionary rather than executor selection mechanism and it causes cyclic import + namespaced = obj.get("KubernetesExecutor", {}) resources = namespaced.get('resources') diff --git a/airflow/kubernetes/pod_launcher.py b/airflow/kubernetes/pod_launcher.py index 2d076939b77c2..686a9176a4fe6 100644 --- a/airflow/kubernetes/pod_launcher.py +++ b/airflow/kubernetes/pod_launcher.py @@ -27,7 +27,7 @@ from kubernetes.stream import stream as kubernetes_stream from requests.exceptions import BaseHTTPError -from airflow import AirflowException +from airflow.exceptions import AirflowException from airflow.kubernetes.pod_generator import PodDefaults from airflow.settings import pod_mutation_hook from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/macros/__init__.py b/airflow/macros/__init__.py index 5582b3c3a3d26..6a9fa7e8fd3f6 100644 --- a/airflow/macros/__init__.py +++ b/airflow/macros/__init__.py @@ -85,12 +85,3 @@ def datetime_diff_for_humans(dt, since=None): import pendulum return pendulum.instance(dt).diff_for_humans(since) - - -def _integrate_plugins(): - """Integrate plugins to the context""" - import sys - from airflow.plugins_manager import macros_modules - for macros_module in macros_modules: - sys.modules[macros_module.__name__] = macros_module - globals()[macros_module._name] = macros_module # pylint: disable=protected-access diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 3a59cfb804564..07a06f683f6fb 100644 --- a/airflow/models/baseoperator.py +++ 
b/airflow/models/baseoperator.py @@ -1090,7 +1090,7 @@ class BaseOperatorLink(metaclass=ABCMeta): Abstract base class that defines how we get an operator link. """ - operators = [] # type: List[Type[BaseOperator]] + operators: List[Type[BaseOperator]] = [] """ This property will be used by Airflow Plugins to find the Operators to which you want to assign this Operator Link diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 1abde55e6e0a4..17bd85920d846 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -38,7 +38,6 @@ from airflow.configuration import conf from airflow.dag.base_dag import BaseDag from airflow.exceptions import AirflowDagCycleException, AirflowException, DagNotFound, DuplicateTaskIdFound -from airflow.executors import LocalExecutor, get_default_executor from airflow.models.base import ID_LEN, Base from airflow.models.baseoperator import BaseOperator from airflow.models.dagbag import DagBag @@ -47,9 +46,9 @@ from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.settings import MIN_SERIALIZED_DAG_UPDATE_INTERVAL, STORE_SERIALIZED_DAGS from airflow.utils import timezone -from airflow.utils.dag_processing import correct_maybe_zipped from airflow.utils.dates import cron_presets, date_range as utils_date_range from airflow.utils.db import provide_session +from airflow.utils.file import correct_maybe_zipped from airflow.utils.helpers import validate_key from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.sqlalchemy import Interval, UtcDateTime @@ -1254,9 +1253,11 @@ def run( """ from airflow.jobs import BackfillJob if not executor and local: + from airflow.executors.local_executor import LocalExecutor executor = LocalExecutor() elif not executor: - executor = get_default_executor() + from airflow.executors.executor_loader import ExecutorLoader + executor = ExecutorLoader.get_default_executor() job = BackfillJob( self, start_date=start_date, diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index ec6e1a20ee4a0..7df03d9de7171 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -33,11 +33,10 @@ from airflow.configuration import conf from airflow.dag.base_dag import BaseDagBag from airflow.exceptions import AirflowDagCycleException -from airflow.executors import get_default_executor from airflow.stats import Stats from airflow.utils import timezone -from airflow.utils.dag_processing import correct_maybe_zipped, list_py_file_paths from airflow.utils.db import provide_session +from airflow.utils.file import correct_maybe_zipped from airflow.utils.helpers import pprinttable from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.timeout import timeout @@ -88,7 +87,8 @@ def __init__( # do not use default arg in signature, to fix import cycle on plugin load if executor is None: - executor = get_default_executor() + from airflow.executors.executor_loader import ExecutorLoader + executor = ExecutorLoader.get_default_executor() dag_folder = dag_folder or settings.DAGS_FOLDER self.dag_folder = dag_folder self.dags = {} @@ -317,9 +317,7 @@ def kill_zombies(self, zombies, session=None): had a heartbeat for too long, in the current DagBag. :param zombies: zombie task instances to kill. - :type zombies: airflow.utils.dag_processing.SimpleTaskInstance :param session: DB session. 
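`DAG.run()` and `DagBag.__init__()` now import the executor machinery inside the function body, which is the usual way to break an import cycle between `airflow.models` and the executors package. A generic sketch of the pattern; the helper name below is mine, not part of the patch:

```python
def resolve_executor(local: bool = False):
    """Pick an executor lazily so importing this module never imports the executors package."""
    if local:
        # Imported here rather than at module level to avoid a circular import at load time.
        from airflow.executors.local_executor import LocalExecutor
        return LocalExecutor()
    from airflow.executors.executor_loader import ExecutorLoader
    return ExecutorLoader.get_default_executor()
```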
- :type session: sqlalchemy.orm.session.Session """ from airflow.models.taskinstance import TaskInstance # Avoid circular import @@ -406,8 +404,8 @@ def collect_dags( FileLoadStat = namedtuple( 'FileLoadStat', "file duration dag_num task_num dags") + from airflow.utils.file import correct_maybe_zipped, list_py_file_paths dag_folder = correct_maybe_zipped(dag_folder) - for filepath in list_py_file_paths(dag_folder, safe_mode=safe_mode, include_examples=include_examples): try: diff --git a/airflow/models/kubernetes.py b/airflow/models/kubernetes.py index 50c205f0e8358..6535e2389a395 100644 --- a/airflow/models/kubernetes.py +++ b/airflow/models/kubernetes.py @@ -20,6 +20,7 @@ import uuid from sqlalchemy import Boolean, Column, String, true as sqltrue +from sqlalchemy.orm import Session from airflow.models.base import Base from airflow.utils.db import provide_session @@ -32,13 +33,13 @@ class KubeResourceVersion(Base): @staticmethod @provide_session - def get_current_resource_version(session=None): + def get_current_resource_version(session: Session = None) -> str: (resource_version,) = session.query(KubeResourceVersion.resource_version).one() return resource_version @staticmethod @provide_session - def checkpoint_resource_version(resource_version, session=None): + def checkpoint_resource_version(resource_version, session: Session = None) -> None: if resource_version: session.query(KubeResourceVersion).update({ KubeResourceVersion.resource_version: resource_version @@ -47,7 +48,7 @@ def checkpoint_resource_version(resource_version, session=None): @staticmethod @provide_session - def reset_resource_version(session=None): + def reset_resource_version(session: Session = None) -> str: session.query(KubeResourceVersion).update({ KubeResourceVersion.resource_version: '0' }) @@ -62,7 +63,7 @@ class KubeWorkerIdentifier(Base): @staticmethod @provide_session - def get_or_create_current_kube_worker_uuid(session=None): + def get_or_create_current_kube_worker_uuid(session: Session = None) -> str: (worker_uuid,) = session.query(KubeWorkerIdentifier.worker_uuid).one() if worker_uuid == '': worker_uuid = str(uuid.uuid4()) @@ -71,7 +72,7 @@ def get_or_create_current_kube_worker_uuid(session=None): @staticmethod @provide_session - def checkpoint_kube_worker_uuid(worker_uuid, session=None): + def checkpoint_kube_worker_uuid(worker_uuid: str, session: Session = None) -> None: if worker_uuid: session.query(KubeWorkerIdentifier).update({ KubeWorkerIdentifier.worker_uuid: worker_uuid diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 4caef644882bd..716d51d1b7349 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -26,7 +25,7 @@ import signal import time from datetime import datetime, timedelta -from typing import Any, Dict, Iterable, Optional, Union +from typing import Any, Dict, Iterable, Optional, Tuple, Union from urllib.parse import quote import dill @@ -121,6 +120,11 @@ def clear_task_instances(tis, dr.start_date = timezone.utcnow() +# Key used to identify task instance +# Tuple of: dag_id, task_id, execution_date, try_number +TaskInstanceKeyType = Tuple[str, str, datetime, int] + + class TaskInstance(Base, LoggingMixin): """ Task instances store the state of a task instance. 
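`TaskInstanceKeyType` is a plain tuple, so keys can be built, compared and unpacked without a `TaskInstance` in hand; that is what lets the executors and multiprocessing queues pass them around cheaply. A small illustration with made-up values:

```python
from datetime import datetime
from typing import Tuple

# Mirrors TaskInstanceKeyType: (dag_id, task_id, execution_date, try_number)
TaskInstanceKeyType = Tuple[str, str, datetime, int]

key: TaskInstanceKeyType = ("example_dag", "example_task", datetime(2019, 1, 1), 1)
dag_id, task_id, execution_date, try_number = key
print(f"{dag_id}.{task_id} @ {execution_date:%Y-%m-%d}, try {try_number}")
```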
This table is the @@ -213,7 +217,7 @@ def try_number(self): Return the try number that this task number will be when it is actually run. - If the TI is currently running, this will match the column in the + If the TaskInstance is currently running, this will match the column in the database, in all other cases this will be incremented. """ # This is designed so that task logs end up in the right file. @@ -395,11 +399,10 @@ def current_state(self, session=None): we use and looking up the state becomes part of the session, otherwise a new session is used. """ - TI = TaskInstance - ti = session.query(TI).filter( - TI.dag_id == self.dag_id, - TI.task_id == self.task_id, - TI.execution_date == self.execution_date, + ti = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.execution_date == self.execution_date, ).all() if ti: state = ti[0].state @@ -418,7 +421,7 @@ def error(self, session=None): session.commit() @provide_session - def refresh_from_db(self, session=None, lock_for_update=False, refresh_executor_config=False): + def refresh_from_db(self, session=None, lock_for_update=False, refresh_executor_config=False) -> None: """ Refreshes the task instance from the database based on the primary key @@ -429,12 +432,11 @@ def refresh_from_db(self, session=None, lock_for_update=False, refresh_executor_ lock the TaskInstance (issuing a FOR UPDATE clause) until the session is committed. """ - TI = TaskInstance - qry = session.query(TI).filter( - TI.dag_id == self.dag_id, - TI.task_id == self.task_id, - TI.execution_date == self.execution_date) + qry = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.execution_date == self.execution_date) if lock_for_update: ti = qry.with_for_update().first() @@ -468,7 +470,7 @@ def clear_xcom_data(self, session=None): session.commit() @property - def key(self): + def key(self) -> TaskInstanceKeyType: """ Returns a tuple that identifies the task instance uniquely """ @@ -528,7 +530,7 @@ def _get_previous_ti( # LEGACY: most likely running from unit tests if not dr: - # Means that this TI is NOT being run from a DR, but from a catchup + # Means that this TaskInstance is NOT being run from a DR, but from a catchup previous_scheduled_date = dag.previous_schedule(self.execution_date) if not previous_scheduled_date: return None @@ -742,7 +744,7 @@ def _check_and_change_state_before_execution( :type ignore_all_deps: bool :param ignore_depends_on_past: Ignore depends_on_past DAG attribute :type ignore_depends_on_past: bool - :param ignore_task_deps: Don't check the dependencies of this TI's task + :param ignore_task_deps: Don't check the dependencies of this TaskInstance's task :type ignore_task_deps: bool :param ignore_ti_state: Disregards previous task instance state :type ignore_ti_state: bool @@ -1056,7 +1058,7 @@ def handle_failure(self, error, test_mode=False, context=None, session=None): # Let's go deeper try: - # Since this function is called only when the TI state is running, + # Since this function is called only when the TaskInstance state is running, # try_number contains the current try_number (not the next). We # only mark task instance as FAILED if the next task instance # try_number exceeds the max_tries. 
@@ -1379,12 +1381,11 @@ def xcom_pull( @provide_session def get_num_running_task_instances(self, session): - TI = TaskInstance # .count() is inefficient return session.query(func.count()).filter( - TI.dag_id == self.dag_id, - TI.task_id == self.task_id, - TI.state == State.RUNNING + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.state == State.RUNNING ).scalar() def init_run_context(self, raw=False): @@ -1393,3 +1394,108 @@ def init_run_context(self, raw=False): """ self.raw = raw self._set_context(self) + + +# State of the task instance. +# Stores string version of the task state. +TaskInstanceStateType = Tuple[TaskInstanceKeyType, str] + + +class SimpleTaskInstance: + """ + Simplified Task Instance. + + Used to send data between processes via Queues. + """ + def __init__(self, ti: TaskInstance): + self._dag_id: str = ti.dag_id + self._task_id: str = ti.task_id + self._execution_date: datetime = ti.execution_date + self._start_date: datetime = ti.start_date + self._end_date: datetime = ti.end_date + self._try_number: int = ti.try_number + self._state: str = ti.state + self._executor_config: Any = ti.executor_config + self._run_as_user: Optional[str] = None + if hasattr(ti, 'run_as_user'): + self._run_as_user = ti.run_as_user + self._pool: Optional[str] = None + if hasattr(ti, 'pool'): + self._pool = ti.pool + self._priority_weight: Optional[int] = None + if hasattr(ti, 'priority_weight'): + self._priority_weight = ti.priority_weight + self._queue: str = ti.queue + self._key = ti.key + + # pylint: disable=missing-docstring + @property + def dag_id(self) -> str: + return self._dag_id + + @property + def task_id(self) -> str: + return self._task_id + + @property + def execution_date(self) -> datetime: + return self._execution_date + + @property + def start_date(self) -> datetime: + return self._start_date + + @property + def end_date(self) -> datetime: + return self._end_date + + @property + def try_number(self) -> int: + return self._try_number + + @property + def state(self) -> str: + return self._state + + @property + def pool(self) -> Any: + return self._pool + + @property + def priority_weight(self) -> Optional[int]: + return self._priority_weight + + @property + def queue(self) -> str: + return self._queue + + @property + def key(self) -> TaskInstanceKeyType: + return self._key + + @property + def executor_config(self): + return self._executor_config + + @provide_session + def construct_task_instance(self, session=None, lock_for_update=False) -> TaskInstance: + """ + Construct a TaskInstance from the database based on the primary key + + :param session: DB session. + :param lock_for_update: if True, indicates that the database should + lock the TaskInstance (issuing a FOR UPDATE clause) until the + session is committed. + :return: the task instance constructed + """ + + qry = session.query(TaskInstance).filter( + TaskInstance.dag_id == self._dag_id, + TaskInstance.task_id == self._task_id, + TaskInstance.execution_date == self._execution_date) + + if lock_for_update: + ti = qry.with_for_update().first() + else: + ti = qry.first() + return ti diff --git a/airflow/operators/__init__.py b/airflow/operators/__init__.py index 257e0a7b53b50..99eb3bf2c3fe8 100644 --- a/airflow/operators/__init__.py +++ b/airflow/operators/__init__.py @@ -16,14 +16,4 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
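`SimpleTaskInstance` is the pickle-friendly shape that travels across process boundaries; the full ORM row is rebuilt on the receiving side with `construct_task_instance()`. A hedged sketch of that round trip, assuming `ti` is a `TaskInstance` already loaded from the database (how it was loaded is outside this hunk):

```python
from airflow.models.taskinstance import SimpleTaskInstance


def ship_and_rehydrate(ti):
    """Shrink a TaskInstance for a queue, then rebuild the ORM object from the database."""
    simple_ti = SimpleTaskInstance(ti)       # plain attributes only, safe to pickle
    # ... put simple_ti on a multiprocessing queue and receive it in another process ...
    full_ti = simple_ti.construct_task_instance(lock_for_update=False)
    return simple_ti.key, full_ti
```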
- -# pylint: disable=missing-docstring - - -def _integrate_plugins(): - """Integrate plugins to the context""" - import sys - from airflow.plugins_manager import operators_modules - for operators_module in operators_modules: - sys.modules[operators_module.__name__] = operators_module - globals()[operators_module._name] = operators_module # pylint: disable=protected-access +"""Operators.""" diff --git a/airflow/operators/mssql_to_hive.py b/airflow/operators/mssql_to_hive.py index a7901bfe60774..8dbfd2acff969 100644 --- a/airflow/operators/mssql_to_hive.py +++ b/airflow/operators/mssql_to_hive.py @@ -106,9 +106,9 @@ def type_map(cls, mssql_type: int) -> str: Maps MsSQL type to Hive type. """ map_dict = { - pymssql.BINARY.value: 'INT', - pymssql.DECIMAL.value: 'FLOAT', - pymssql.NUMBER.value: 'INT', + pymssql.BINARY.value: 'INT', # pylint: disable=c-extension-no-member + pymssql.DECIMAL.value: 'FLOAT', # pylint: disable=c-extension-no-member + pymssql.NUMBER.value: 'INT', # pylint: disable=c-extension-no-member } return map_dict.get(mssql_type, 'STRING') diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index 6dfa9dbe7b6c9..3b34bf0ddcdb3 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -16,17 +15,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -import imp +"""Manages all plugins.""" +# noinspection PyDeprecation +import imp # pylint: disable=deprecated-module import inspect import os import re -from typing import Any, List +import sys +from typing import Any, Callable, List, Optional import pkg_resources from airflow import settings -from airflow.models.baseoperator import BaseOperatorLink from airflow.utils.log.logging_mixin import LoggingMixin log = LoggingMixin().log @@ -35,28 +35,29 @@ class AirflowPluginException(Exception): - pass + """Exception when loading plugin.""" class AirflowPlugin: - name = None # type: str - operators = [] # type: List[Any] - sensors = [] # type: List[Any] - hooks = [] # type: List[Any] - executors = [] # type: List[Any] - macros = [] # type: List[Any] - admin_views = [] # type: List[Any] - flask_blueprints = [] # type: List[Any] - menu_links = [] # type: List[Any] - appbuilder_views = [] # type: List[Any] - appbuilder_menu_items = [] # type: List[Any] + """Class used to define AirflowPlugin.""" + name: Optional[str] = None + operators: List[Any] = [] + sensors: List[Any] = [] + hooks: List[Any] = [] + executors: List[Any] = [] + macros: List[Any] = [] + admin_views: List[Any] = [] + flask_blueprints: List[Any] = [] + menu_links: List[Any] = [] + appbuilder_views: List[Any] = [] + appbuilder_menu_items: List[Any] = [] # A function that validate the statsd stat name, apply changes # to the stat name if necessary and return the transformed stat name. # # The function should have the following signature: # def func_name(stat_name: str) -> str: - stat_name_handler = None # type: Any + stat_name_handler: Optional[Callable[[str], str]] = None # A list of global operator extra links that can redirect users to # external systems. These extra links will be available on the @@ -64,16 +65,17 @@ class AirflowPlugin: # # Note: the global operator extra link can be overridden at each # operator level. 
- global_operator_extra_links = [] # type: List[BaseOperatorLink] + global_operator_extra_links: List[Any] = [] # A list of operator extra links to override or add operator links # to existing Airflow Operators. # These extra links will be available on the task page in form of # buttons. - operator_extra_links = [] # type: List[BaseOperatorLink] + operator_extra_links: List[Any] = [] @classmethod def validate(cls): + """Validates that plugin has a name.""" if not cls.name: raise AirflowPluginException("Your plugin needs a name.") @@ -134,14 +136,13 @@ def is_valid_plugin(plugin_obj, existing_plugins): norm_pattern = re.compile(r'[/|.]') -if settings.PLUGINS_FOLDER is None: - raise AirflowPluginException("Plugins folder is not set") +assert settings.PLUGINS_FOLDER, "Plugins folder is not set" # Crawl through the plugins folder to find AirflowPlugin derivatives for root, dirs, files in os.walk(settings.PLUGINS_FOLDER, followlinks=True): for f in files: + filepath = os.path.join(root, f) try: - filepath = os.path.join(root, f) if not os.path.isfile(filepath): continue mod_name, file_ext = os.path.splitext( @@ -157,11 +158,11 @@ def is_valid_plugin(plugin_obj, existing_plugins): for obj in list(m.__dict__.values()): if is_valid_plugin(obj, plugins): plugins.append(obj) - - except Exception as e: + except Exception as e: # pylint: disable=broad-except log.exception(e) - log.error('Failed to import plugin %s', filepath) - import_errors[filepath] = str(e) + path = filepath or str(f) + log.error('Failed to import plugin %s', path) + import_errors[path] = str(e) plugins = load_entrypoint_plugins( pkg_resources.iter_entry_points('airflow.plugins'), @@ -169,14 +170,18 @@ def is_valid_plugin(plugin_obj, existing_plugins): ) -def make_module(name, objects): +# pylint: disable=protected-access +# noinspection Mypy,PyTypeHints +def make_module(name: str, objects: List[Any]): + """Creates new module.""" log.debug('Creating module %s', name) name = name.lower() module = imp.new_module(name) - module._name = name.split('.')[-1] - module._objects = objects + module._name = name.split('.')[-1] # type: ignore + module._objects = objects # type: ignore module.__dict__.update((o.__name__, o) for o in objects) return module +# pylint: enable=protected-access # Plugin components to integrate as modules @@ -187,26 +192,29 @@ def make_module(name, objects): macros_modules = [] # Plugin components to integrate directly -admin_views = [] # type: List[Any] -flask_blueprints = [] # type: List[Any] -menu_links = [] # type: List[Any] -flask_appbuilder_views = [] # type: List[Any] -flask_appbuilder_menu_links = [] # type: List[Any] -stat_name_handler = None # type: Any -global_operator_extra_links = [] # type: List[BaseOperatorLink] -operator_extra_links = [] # type: List[BaseOperatorLink] +admin_views: List[Any] = [] +flask_blueprints: List[Any] = [] +menu_links: List[Any] = [] +flask_appbuilder_views: List[Any] = [] +flask_appbuilder_menu_links: List[Any] = [] +stat_name_handler: Any = None +global_operator_extra_links: List[Any] = [] +operator_extra_links: List[Any] = [] stat_name_handlers = [] for p in plugins: + if not p.name: + raise AirflowPluginException("Plugin name is missing.") + plugin_name: str = p.name operators_modules.append( - make_module('airflow.operators.' + p.name, p.operators + p.sensors)) + make_module('airflow.operators.' + plugin_name, p.operators + p.sensors)) sensors_modules.append( - make_module('airflow.sensors.' + p.name, p.sensors) + make_module('airflow.sensors.' 
+ plugin_name, p.sensors) ) - hooks_modules.append(make_module('airflow.hooks.' + p.name, p.hooks)) + hooks_modules.append(make_module('airflow.hooks.' + plugin_name, p.hooks)) executors_modules.append( - make_module('airflow.executors.' + p.name, p.executors)) - macros_modules.append(make_module('airflow.macros.' + p.name, p.macros)) + make_module('airflow.executors.' + plugin_name, p.executors)) + macros_modules.append(make_module('airflow.macros.' + plugin_name, p.macros)) admin_views.extend(p.admin_views) menu_links.extend(p.menu_links) @@ -230,3 +238,52 @@ def make_module(name, objects): 'is not allowed.'.format(stat_name_handlers)) stat_name_handler = stat_name_handlers[0] if len(stat_name_handlers) == 1 else None + + +def integrate_operator_plugins() -> None: + """Integrate operators plugins to the context""" + for operators_module in operators_modules: + sys.modules[operators_module.__name__] = operators_module + # noinspection PyProtectedMember + globals()[operators_module._name] = operators_module # pylint: disable=protected-access + + +def integrate_sensor_plugins() -> None: + """Integrate sensor plugins to the context""" + for sensors_module in sensors_modules: + sys.modules[sensors_module.__name__] = sensors_module + # noinspection PyProtectedMember + globals()[sensors_module._name] = sensors_module # pylint: disable=protected-access + + +def integrate_hook_plugins() -> None: + """Integrate hook plugins to the context""" + for hooks_module in hooks_modules: + sys.modules[hooks_module.__name__] = hooks_module + # noinspection PyProtectedMember + globals()[hooks_module._name] = hooks_module # pylint: disable=protected-access + + +def integrate_executor_plugins() -> None: + """Integrate executor plugins to the context.""" + for executors_module in executors_modules: + sys.modules[executors_module.__name__] = executors_module + # noinspection PyProtectedMember + globals()[executors_module._name] = executors_module # pylint: disable=protected-access + + +def integrate_macro_plugins() -> None: + """Integrate macro plugins to the context""" + for macros_module in macros_modules: + sys.modules[macros_module.__name__] = macros_module + # noinspection PyProtectedMember + globals()[macros_module._name] = macros_module # pylint: disable=protected-access + + +def integrate_plugins() -> None: + """Integrates all types of plugins.""" + integrate_operator_plugins() + integrate_sensor_plugins() + integrate_hook_plugins() + integrate_executor_plugins() + integrate_macro_plugins() diff --git a/airflow/sensors/__init__.py b/airflow/sensors/__init__.py index 945cd00d3e613..bcdaf4c3c8da1 100644 --- a/airflow/sensors/__init__.py +++ b/airflow/sensors/__init__.py @@ -17,14 +17,4 @@ # specific language governing permissions and limitations # under the License. # - -# pylint: disable=missing-docstring - - -def _integrate_plugins(): - """Integrate plugins to the context""" - import sys - from airflow.plugins_manager import sensors_modules - for sensors_module in sensors_modules: - sys.modules[sensors_module.__name__] = sensors_module - globals()[sensors_module._name] = sensors_module # pylint: disable=protected-access +"""Sensors.""" diff --git a/airflow/settings.py b/airflow/settings.py index c97a83d344fe5..6e9a93af2b5c2 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
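With the per-package `_integrate_plugins()` helpers gone, `airflow.plugins_manager.integrate_plugins()` (or one of the per-kind `integrate_*_plugins()` functions) is the single place that publishes the synthesized `airflow.<kind>.<plugin_name>` modules. A hedged sketch of what a plugin author relies on; the plugin and hook names are invented for illustration, and the plugin file must live in the plugins folder for the discovery crawl to find it:

```python
from airflow.plugins_manager import AirflowPlugin


class ExampleHook:
    """Hypothetical hook shipped by the plugin (stand-in for a real hook class)."""


class ExamplePlugin(AirflowPlugin):
    name = "example_plugin"
    hooks = [ExampleHook]

# After plugin discovery and integrate_plugins() (or integrate_hook_plugins()) has run,
# the hook is importable from the dynamically created module:
#   from airflow.hooks.example_plugin import ExampleHook
```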
See the NOTICE file @@ -68,13 +67,13 @@ LOG_FORMAT = conf.get('core', 'log_format') SIMPLE_LOG_FORMAT = conf.get('core', 'simple_log_format') -SQL_ALCHEMY_CONN = None # type: Optional[str] -DAGS_FOLDER = None # type: Optional[str] -PLUGINS_FOLDER = None # type: Optional[str] -LOGGING_CLASS_PATH = None # type: Optional[str] +SQL_ALCHEMY_CONN: Optional[str] = None +DAGS_FOLDER: Optional[str] = None +PLUGINS_FOLDER: Optional[str] = None +LOGGING_CLASS_PATH: Optional[str] = None -engine = None # type: Optional[Engine] -Session = None # type: Optional[SASession] +engine: Optional[Engine] = None +Session: Optional[SASession] = None # The JSON library to use for DAG Serialization and De-Serialization json = json diff --git a/airflow/task/task_runner/base_task_runner.py b/airflow/task/task_runner/base_task_runner.py index 8345c5d7b7422..83b6064dad7c8 100644 --- a/airflow/task/task_runner/base_task_runner.py +++ b/airflow/task/task_runner/base_task_runner.py @@ -16,13 +16,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Base task runner""" import getpass import os import subprocess import threading from airflow.configuration import conf +from airflow.exceptions import AirflowConfigException from airflow.utils.configuration import tmp_configuration_copy from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname @@ -51,7 +52,7 @@ def __init__(self, local_task_job): else: try: self.run_as_user = conf.get('core', 'default_impersonation') - except conf.AirflowConfigException: + except AirflowConfigException: self.run_as_user = None # Add sudo commands to change user if we need to. Needed to handle SubDagOperator @@ -104,7 +105,7 @@ def _read_task_logs(self, stream): line = stream.readline() if isinstance(line, bytes): line = line.decode('utf-8') - if len(line) == 0: + if not line: break self.log.info('Job %s: Subtask %s %s', self._task_instance.job_id, self._task_instance.task_id, @@ -124,6 +125,7 @@ def run_command(self, run_with=None): self.log.info("Running on host: %s", get_hostname()) self.log.info('Running: %s', full_cmd) + # pylint: disable=subprocess-popen-preexec-fn proc = subprocess.Popen( full_cmd, stdout=subprocess.PIPE, diff --git a/airflow/task/task_runner/standard_task_runner.py b/airflow/task/task_runner/standard_task_runner.py index c79d496ac34b0..9c5795169c316 100644 --- a/airflow/task/task_runner/standard_task_runner.py +++ b/airflow/task/task_runner/standard_task_runner.py @@ -16,7 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Standard task runner""" import psutil from airflow.task.task_runner.base_task_runner import BaseTaskRunner @@ -25,7 +25,7 @@ class StandardTaskRunner(BaseTaskRunner): """ - Runs the raw Airflow task by invoking through the Bash shell. + Standard runner for all tasks. 
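The task-runner fix is easy to miss but real: the old `except conf.AirflowConfigException` appears to reference an attribute the configuration object does not expose, so the exception now comes from `airflow.exceptions`. The corrected pattern, as a small standalone sketch:

```python
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException

try:
    run_as_user = conf.get('core', 'default_impersonation')
except AirflowConfigException:
    run_as_user = None  # option not set: fall back to running as the current user
```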
""" def __init__(self, local_task_job): super().__init__(local_task_job) @@ -39,6 +39,3 @@ def return_code(self): def terminate(self): if self.process and psutil.pid_exists(self.process.pid): reap_process_group(self.process.pid, self.log) - - def on_finish(self): - super().on_finish() diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index 2058cbe093de3..f03fa2e3ebccb 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -16,37 +15,36 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Processes DAGs.""" import enum import importlib import logging import multiprocessing import os -import re import signal import sys import time -import zipfile from abc import ABCMeta, abstractmethod from datetime import datetime, timedelta from importlib import import_module -from typing import Iterable, NamedTuple, Optional +from typing import Any, Callable, Dict, KeysView, List, NamedTuple, Optional, Tuple import psutil -from setproctitle import setproctitle +from setproctitle import setproctitle # pylint: disable=no-name-in-module from sqlalchemy import or_ from tabulate import tabulate -# To avoid circular imports import airflow.models from airflow.configuration import conf from airflow.dag.base_dag import BaseDag, BaseDagBag from airflow.exceptions import AirflowException -from airflow.models import errors +from airflow.models import Connection, errors +from airflow.models.taskinstance import SimpleTaskInstance from airflow.settings import STORE_SERIALIZED_DAGS from airflow.stats import Stats from airflow.utils import timezone from airflow.utils.db import provide_session +from airflow.utils.file import list_py_file_paths from airflow.utils.helpers import reap_process_group from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State @@ -63,23 +61,23 @@ class SimpleDag(BaseDag): :type pickle_id: unicode """ - def __init__(self, dag, pickle_id=None): - self._dag_id = dag.dag_id - self._task_ids = [task.task_id for task in dag.tasks] - self._full_filepath = dag.full_filepath - self._is_paused = dag.is_paused - self._concurrency = dag.concurrency - self._pickle_id = pickle_id - self._task_special_args = {} + def __init__(self, dag, pickle_id: Optional[str] = None): + self._dag_id: str = dag.dag_id + self._task_ids: List[str] = [task.task_id for task in dag.tasks] + self._full_filepath: str = dag.full_filepath + self._is_paused: bool = dag.is_paused + self._concurrency: int = dag.concurrency + self._pickle_id: Optional[str] = pickle_id + self._task_special_args: Dict[str, Any] = {} for task in dag.tasks: special_args = {} if task.task_concurrency is not None: special_args['task_concurrency'] = task.task_concurrency - if len(special_args) > 0: + if special_args: self._task_special_args[task.task_id] = special_args @property - def dag_id(self): + def dag_id(self) -> str: """ :return: the DAG ID :rtype: unicode @@ -87,7 +85,7 @@ def dag_id(self): return self._dag_id @property - def task_ids(self): + def task_ids(self) -> List[str]: """ :return: A list of task IDs that are in this DAG :rtype: list[unicode] @@ -95,7 +93,7 @@ def task_ids(self): return self._task_ids @property - def full_filepath(self): + def full_filepath(self) -> str: """ :return: The absolute 
path to the file that contains this DAG's definition :rtype: unicode @@ -103,7 +101,7 @@ def full_filepath(self): return self._full_filepath @property - def concurrency(self): + def concurrency(self) -> int: """ :return: maximum number of tasks that can run simultaneously from this DAG :rtype: int @@ -111,7 +109,7 @@ def concurrency(self): return self._concurrency @property - def is_paused(self): + def is_paused(self) -> bool: """ :return: whether this DAG is paused or not :rtype: bool @@ -119,7 +117,7 @@ def is_paused(self): return self._is_paused @property - def pickle_id(self): + def pickle_id(self) -> Optional[str]: """ :return: The pickle ID for this DAG, if it has one. Otherwise None. :rtype: unicode @@ -127,140 +125,45 @@ def pickle_id(self): return self._pickle_id @property - def task_special_args(self): + def task_special_args(self) -> Dict[str, Any]: + """Special arguments of the task.""" return self._task_special_args - def get_task_special_arg(self, task_id, special_arg_name): + def get_task_special_arg(self, task_id: str, special_arg_name: str): + """Retrieve special arguments of the task.""" if task_id in self._task_special_args and special_arg_name in self._task_special_args[task_id]: return self._task_special_args[task_id][special_arg_name] else: return None -class SimpleTaskInstance: - def __init__(self, ti): - self._dag_id = ti.dag_id - self._task_id = ti.task_id - self._execution_date = ti.execution_date - self._start_date = ti.start_date - self._end_date = ti.end_date - self._try_number = ti.try_number - self._state = ti.state - self._executor_config = ti.executor_config - if hasattr(ti, 'run_as_user'): - self._run_as_user = ti.run_as_user - else: - self._run_as_user = None - if hasattr(ti, 'pool'): - self._pool = ti.pool - else: - self._pool = None - if hasattr(ti, 'priority_weight'): - self._priority_weight = ti.priority_weight - else: - self._priority_weight = None - self._queue = ti.queue - self._key = ti.key - - @property - def dag_id(self): - return self._dag_id - - @property - def task_id(self): - return self._task_id - - @property - def execution_date(self): - return self._execution_date - - @property - def start_date(self): - return self._start_date - - @property - def end_date(self): - return self._end_date - - @property - def try_number(self): - return self._try_number - - @property - def state(self): - return self._state - - @property - def pool(self): - return self._pool - - @property - def priority_weight(self): - return self._priority_weight - - @property - def queue(self): - return self._queue - - @property - def key(self): - return self._key - - @property - def executor_config(self): - return self._executor_config - - @provide_session - def construct_task_instance(self, session=None, lock_for_update=False): - """ - Construct a TaskInstance from the database based on the primary key - - :param session: DB session. - :param lock_for_update: if True, indicates that the database should - lock the TaskInstance (issuing a FOR UPDATE clause) until the - session is committed. - """ - TI = airflow.models.TaskInstance - - qry = session.query(TI).filter( - TI.dag_id == self._dag_id, - TI.task_id == self._task_id, - TI.execution_date == self._execution_date) - - if lock_for_update: - ti = qry.with_for_update().first() - else: - ti = qry.first() - return ti - - class SimpleDagBag(BaseDagBag): """ A collection of SimpleDag objects with some convenience methods. """ - def __init__(self, simple_dags): + def __init__(self, simple_dags: List[SimpleDag]): """ Constructor. 
:param simple_dags: SimpleDag objects that should be in this - :type list(airflow.utils.dag_processing.SimpleDagBag) + :type list(airflow.utils.dag_processing.SimpleDag) """ self.simple_dags = simple_dags - self.dag_id_to_simple_dag = {} + self.dag_id_to_simple_dag: Dict[str, SimpleDag] = {} for simple_dag in simple_dags: self.dag_id_to_simple_dag[simple_dag.dag_id] = simple_dag @property - def dag_ids(self): + def dag_ids(self) -> KeysView[str]: """ :return: IDs of all the DAGs in this :rtype: list[unicode] """ return self.dag_id_to_simple_dag.keys() - def get_dag(self, dag_id): + def get_dag(self, dag_id: str) -> SimpleDag: """ :param dag_id: DAG ID :type dag_id: unicode @@ -273,106 +176,6 @@ def get_dag(self, dag_id): return self.dag_id_to_simple_dag[dag_id] -def correct_maybe_zipped(fileloc): - """ - If the path contains a folder with a .zip suffix, then - the folder is treated as a zip archive and path to zip is returned. - """ - - _, archive, _ = re.search(r'((.*\.zip){})?(.*)'.format(re.escape(os.sep)), fileloc).groups() - if archive and zipfile.is_zipfile(archive): - return archive - else: - return fileloc - - -COMMENT_PATTERN = re.compile(r"\s*#.*") - - -def list_py_file_paths(directory, safe_mode=conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE', fallback=True), - include_examples=None): - """ - Traverse a directory and look for Python files. - - :param directory: the directory to traverse - :type directory: unicode - :param safe_mode: whether to use a heuristic to determine whether a file - contains Airflow DAG definitions. If not provided, use the - core.DAG_DISCOVERY_SAFE_MODE configuration setting. If not set, default - to safe. - :type safe_mode: bool - :param include_examples: include example DAGs - :type include_examples: bool - :return: a list of paths to Python files in the specified directory - :rtype: list[unicode] - """ - if include_examples is None: - include_examples = conf.getboolean('core', 'LOAD_EXAMPLES') - file_paths = [] - if directory is None: - return [] - elif os.path.isfile(directory): - return [directory] - elif os.path.isdir(directory): - patterns_by_dir = {} - for root, dirs, files in os.walk(directory, followlinks=True): - patterns = patterns_by_dir.get(root, []) - ignore_file = os.path.join(root, '.airflowignore') - if os.path.isfile(ignore_file): - with open(ignore_file, 'r') as file: - # If we have new patterns create a copy so we don't change - # the previous list (which would affect other subdirs) - lines_no_comments = [COMMENT_PATTERN.sub("", line) for line in file.read().split("\n")] - patterns += [re.compile(line) for line in lines_no_comments if line] - - # If we can ignore any subdirs entirely we should - fewer paths - # to walk is better. We have to modify the ``dirs`` array in - # place for this to affect os.walk - dirs[:] = [ - d - for d in dirs - if not any(p.search(os.path.join(root, d)) for p in patterns) - ] - - # We want patterns defined in a parent folder's .airflowignore to - # apply to subdirs too - for d in dirs: - patterns_by_dir[os.path.join(root, d)] = patterns - - for f in files: - try: - file_path = os.path.join(root, f) - if not os.path.isfile(file_path): - continue - _, file_ext = os.path.splitext(os.path.split(file_path)[-1]) - if file_ext != '.py' and not zipfile.is_zipfile(file_path): - continue - if any([re.findall(p, file_path) for p in patterns]): - continue - - # Heuristic that guesses whether a Python file contains an - # Airflow DAG definition. 
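`correct_maybe_zipped()` and `list_py_file_paths()` are not deleted, only relocated: the import hunks above pull them from `airflow.utils.file`. A hedged usage sketch from the new location, assuming the signatures shown in the removed code carried over unchanged; the folder path is illustrative:

```python
from airflow.utils.file import correct_maybe_zipped, list_py_file_paths

dag_folder = correct_maybe_zipped("/opt/airflow/dags/example_bundle.zip")
for path in list_py_file_paths(dag_folder, safe_mode=True, include_examples=False):
    print("would parse", path)
```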
- might_contain_dag = True - if safe_mode and not zipfile.is_zipfile(file_path): - with open(file_path, 'rb') as fp: - content = fp.read() - might_contain_dag = all( - [s in content for s in (b'DAG', b'airflow')]) - - if not might_contain_dag: - continue - - file_paths.append(file_path) - except Exception: - log = LoggingMixin().log - log.exception("Error while examining %s", f) - if include_examples: - import airflow.example_dags - example_dag_folder = airflow.example_dags.__path__[0] - file_paths.extend(list_py_file_paths(example_dag_folder, safe_mode, False)) - return file_paths - - class AbstractDagFileProcessor(metaclass=ABCMeta): """ Processes a DAG file. See SchedulerJob.process_file() for more details. @@ -386,7 +189,7 @@ def start(self): raise NotImplementedError() @abstractmethod - def terminate(self, sigkill=False): + def terminate(self, sigkill: bool = False): """ Terminate (and then kill) the process launched to process the file """ @@ -394,7 +197,7 @@ def terminate(self, sigkill=False): @property @abstractmethod - def pid(self): + def pid(self) -> int: """ :return: the PID of the process launched to process the given file """ @@ -402,7 +205,7 @@ def pid(self): @property @abstractmethod - def exit_code(self): + def exit_code(self) -> int: """ After the process is finished, this can be called to get the return code :return: the exit code of the process @@ -412,7 +215,7 @@ def exit_code(self): @property @abstractmethod - def done(self): + def done(self) -> bool: """ Check if the process launched to process this file is done. :return: whether the process is finished running @@ -422,7 +225,7 @@ def done(self): @property @abstractmethod - def result(self): + def result(self) -> Tuple[List[SimpleDag], int]: """ A list of simple dags found, and the number of import errors @@ -451,7 +254,7 @@ def file_path(self): DagParsingStat = NamedTuple('DagParsingStat', [ - ('file_paths', Iterable[str]), + ('file_paths', List[str]), ('done', bool), ('all_files_processed', bool) ]) @@ -465,6 +268,7 @@ def file_path(self): class DagParsingSignal(enum.Enum): + """All signals sent to parser.""" AGENT_HEARTBEAT = 'agent_heartbeat' TERMINATE_MANAGER = 'terminate_manager' END_MANAGER = 'end_manager' @@ -563,6 +367,7 @@ def heartbeat(self): pass def wait_until_finished(self): + """Waits until DAG parsing is finished.""" while self._parent_signal_conn.poll(): try: result = self._parent_signal_conn.recv() @@ -660,6 +465,7 @@ def _sync_metadata(self, stat): self._done = stat.done self._all_files_processed = stat.all_files_processed + # pylint: disable=missing-docstring @property def file_paths(self): return self._file_paths @@ -696,7 +502,7 @@ def end(self): self._parent_signal_conn.close() -class DagFileProcessorManager(LoggingMixin): +class DagFileProcessorManager(LoggingMixin): # pylint: disable=too-many-instance-attributes """ Given a list of DAG definition files, this kicks off several processors in parallel to process them and put the results to a multiprocessing.Queue @@ -724,26 +530,27 @@ class DagFileProcessorManager(LoggingMixin): """ def __init__(self, - dag_directory, - file_paths, - max_runs, - processor_factory, - processor_timeout, - signal_conn, - async_mode=True): + dag_directory: str, + file_paths: List[str], + max_runs: int, + processor_factory: Callable[[str, List[Any]], AbstractDagFileProcessor], + processor_timeout: timedelta, + signal_conn: Connection, + async_mode: bool = True): self._file_paths = file_paths - self._file_path_queue = [] + self._file_path_queue: List[str] = [] 
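The safe-mode heuristic that moved out with `list_py_file_paths()` is worth restating, since it explains most "my DAG file never shows up" reports: unless the file is a zip archive, its raw bytes must contain both `DAG` and `airflow`. A standalone restatement of that check; the helper name is mine:

```python
import zipfile


def might_contain_dag(file_path: str, safe_mode: bool = True) -> bool:
    """Reproduce the heuristic: zip archives always pass, other files must mention both tokens."""
    if not safe_mode or zipfile.is_zipfile(file_path):
        return True
    with open(file_path, "rb") as file_handle:
        content = file_handle.read()
    return all(token in content for token in (b"DAG", b"airflow"))
```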
self._dag_directory = dag_directory self._max_runs = max_runs self._processor_factory = processor_factory self._signal_conn = signal_conn self._async_mode = async_mode + self._parsing_start_time: Optional[datetime] = None self._parallelism = conf.getint('scheduler', 'max_threads') if 'sqlite' in conf.get('core', 'sql_alchemy_conn') and self._parallelism > 1: self.log.warning( - f"Because we cannot use more than 1 thread (max_threads = {self._parallelism}) " - "when using sqlite. So we set parallelism to 1." + "Because we cannot use more than 1 thread (max_threads = " + "%d ) when using sqlite. So we set parallelism to 1.", self._parallelism ) self._parallelism = 1 @@ -758,12 +565,12 @@ def __init__(self, self._zombie_threshold_secs = ( conf.getint('scheduler', 'scheduler_zombie_task_threshold')) # Map from file path to the processor - self._processors = {} + self._processors: Dict[str, AbstractDagFileProcessor] = {} self._heartbeat_count = 0 # Map from file path to stats about the file - self._file_stats = {} # type: dict(str, DagFileStat) + self._file_stats: Dict[str, DagFileStat] = {} self._last_zombie_query_time = None # Last time that the DAG dir was traversed to look for files @@ -772,7 +579,7 @@ def __init__(self, self.last_stat_print_time = timezone.datetime(2000, 1, 1) # TODO: Remove magic number self._zombie_query_interval = 10 - self._zombies = [] + self._zombies: List[SimpleTaskInstance] = [] # How long to wait before timing out a process to parse a DAG file self._processor_timeout = processor_timeout @@ -785,7 +592,7 @@ def __init__(self, signal.signal(signal.SIGINT, self._exit_gracefully) signal.signal(signal.SIGTERM, self._exit_gracefully) - def _exit_gracefully(self, signum, frame): + def _exit_gracefully(self, signum, frame): # pylint: disable=unused-argument """ Helper method to clean up DAG file processors to avoid leaving orphan processes. 
""" @@ -842,7 +649,7 @@ def start(self): continue self._refresh_dag_dir() - self._find_zombies() + self._find_zombies() # pylint: disable=no-value-for-parameter simple_dags = self.heartbeat() for simple_dag in simple_dags: @@ -899,10 +706,11 @@ def _refresh_dag_dir(self): self.log.info("There are %s files in %s", len(self._file_paths), self._dag_directory) self.set_file_paths(self._file_paths) + # noinspection PyBroadException try: self.log.debug("Removing old import errors") - self.clear_nonexistent_import_errors() - except Exception: + self.clear_nonexistent_import_errors() # pylint: disable=no-value-for-parameter + except Exception: # pylint: disable=broad-except self.log.exception("Error removing old import errors") if STORE_SERIALIZED_DAGS: @@ -915,8 +723,8 @@ def _print_stat(self): """ Occasionally print out stats about how fast the files are getting processed """ - if ((timezone.utcnow() - self.last_stat_print_time).total_seconds() > self.print_stats_interval): - if len(self._file_paths) > 0: + if (timezone.utcnow() - self.last_stat_print_time).total_seconds() > self.print_stats_interval: + if self._file_paths: self._log_file_processing_stats(self._file_paths) self.last_stat_print_time = timezone.utcnow() @@ -1014,10 +822,6 @@ def _log_file_processing_stats(self, known_file_paths): self.log.info(log_str) - @property - def file_paths(self): - return self._file_paths - def get_pid(self, file_path): """ :param file_path: the path to the file that's being processed @@ -1144,10 +948,8 @@ def collect_results(self): """ self._kill_timed_out_processors() - finished_processors = {} - """:type : dict[unicode, AbstractDagFileProcessor]""" - running_processors = {} - """:type : dict[unicode, AbstractDagFileProcessor]""" + finished_processors: Dict[str, AbstractDagFileProcessor] = {} + running_processors: Dict[str, AbstractDagFileProcessor] = {} for file_path, processor in self._processors.items(): if processor.done: @@ -1202,7 +1004,7 @@ def heartbeat(self): # Generate more file paths to process if we processed all the files # already. - if len(self._file_path_queue) == 0: + if not self._file_path_queue: self.emit_metrics() self._parsing_start_time = timezone.utcnow() @@ -1245,8 +1047,7 @@ def heartbeat(self): self._file_path_queue.extend(files_paths_to_queue) # Start more processors if we have enough slots and files to process - while (self._parallelism - len(self._processors) > 0 and - len(self._file_path_queue) > 0): + while self._parallelism - len(self._processors) > 0 and self._file_path_queue: file_path = self._file_path_queue.pop(0) processor = self._processor_factory(file_path, self._zombies) Stats.incr('dag_processing.processes') @@ -1270,7 +1071,7 @@ def _find_zombies(self, session): and update the current zombie list. """ now = timezone.utcnow() - zombies = [] + zombies: List[SimpleTaskInstance] = [] if not self._last_zombie_query_time or \ (now - self._last_zombie_query_time).total_seconds() > self._zombie_query_interval: # to avoid circular imports @@ -1313,7 +1114,7 @@ def _kill_timed_out_processors(self): self.log.info( "Processor for %s with PID %s started at %s has timed out, " "killing it.", - processor.file_path, processor.pid, processor.start_time.isoformat()) + file_path, processor.pid, processor.start_time.isoformat()) Stats.decr('dag_processing.processes') Stats.incr('dag_processing.processor_timeouts') # TODO: Remove ater Airflow 2.0 @@ -1348,7 +1149,7 @@ def end(self): them as orphaned. 
""" pids_to_kill = self.get_all_pids() - if len(pids_to_kill) > 0: + if pids_to_kill: # First try SIGTERM this_process = psutil.Process(os.getpid()) # Only check child processes to ensure that we don't have a case @@ -1372,7 +1173,7 @@ def end(self): # Then SIGKILL child_processes = [x for x in this_process.children(recursive=True) if x.is_running() and x.pid in pids_to_kill] - if len(child_processes) > 0: + if child_processes: self.log.info("SIGKILL processes that did not terminate gracefully") for child in child_processes: self.log.info("Killing child PID: %s", child.pid) @@ -1396,3 +1197,8 @@ def emit_metrics(self): # TODO: Remove before Airflow 2.0 Stats.gauge('collect_dags', parse_time) Stats.gauge('dagbag_import_errors', sum(stat.import_errors for stat in self._file_stats.values())) + + # pylint: disable=missing-docstring + @property + def file_paths(self): + return self._file_paths diff --git a/airflow/utils/file.py b/airflow/utils/file.py index 2d654df04e997..3c24bab5c7671 100644 --- a/airflow/utils/file.py +++ b/airflow/utils/file.py @@ -19,9 +19,14 @@ import errno import os +import re import shutil +import zipfile from contextlib import contextmanager from tempfile import mkdtemp +from typing import Dict, List, Optional, Pattern + +from airflow import LoggingMixin, conf @contextmanager @@ -56,3 +61,110 @@ def mkdirs(path, mode): raise finally: os.umask(o_umask) + + +def correct_maybe_zipped(fileloc): + """ + If the path contains a folder with a .zip suffix, then + the folder is treated as a zip archive and path to zip is returned. + """ + + _, archive, _ = re.search(r'((.*\.zip){})?(.*)'.format(re.escape(os.sep)), fileloc).groups() + if archive and zipfile.is_zipfile(archive): + return archive + else: + return fileloc + + +def list_py_file_paths(directory: str, + safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE', fallback=True), + include_examples: Optional[bool] = None): + """ + Traverse a directory and look for Python files. + + :param directory: the directory to traverse + :type directory: unicode + :param safe_mode: whether to use a heuristic to determine whether a file + contains Airflow DAG definitions. If not provided, use the + core.DAG_DISCOVERY_SAFE_MODE configuration setting. If not set, default + to safe. + :type safe_mode: bool + :param include_examples: include example DAGs + :type include_examples: bool + :return: a list of paths to Python files in the specified directory + :rtype: list[unicode] + """ + if include_examples is None: + include_examples = conf.getboolean('core', 'LOAD_EXAMPLES') + file_paths: List[str] = [] + if directory is None: + return [] + elif os.path.isfile(directory): + return [directory] + elif os.path.isdir(directory): + patterns_by_dir: Dict[str, List[Pattern[str]]] = {} + for root, dirs, files in os.walk(directory, followlinks=True): + patterns: List[Pattern[str]] = patterns_by_dir.get(root, []) + ignore_file = os.path.join(root, '.airflowignore') + if os.path.isfile(ignore_file): + with open(ignore_file, 'r') as file: + # If we have new patterns create a copy so we don't change + # the previous list (which would affect other subdirs) + lines_no_comments = [COMMENT_PATTERN.sub("", line) for line in file.read().split("\n")] + patterns += [re.compile(line) for line in lines_no_comments if line] + + # If we can ignore any subdirs entirely we should - fewer paths + # to walk is better. 
We have to modify the ``dirs`` array in + # place for this to affect os.walk + dirs[:] = [ + subdir + for subdir in dirs + if not any(p.search(os.path.join(root, subdir)) for p in patterns) + ] + + # We want patterns defined in a parent folder's .airflowignore to + # apply to subdirs too + for subdir in dirs: + patterns_by_dir[os.path.join(root, subdir)] = patterns + + find_dag_file_paths(file_paths, files, patterns, root, safe_mode) + if include_examples: + from airflow import example_dags + example_dag_folder = example_dags.__path__[0] # type: ignore + file_paths.extend(list_py_file_paths(example_dag_folder, safe_mode, False)) + return file_paths + + +def find_dag_file_paths(file_paths, files, patterns, root, safe_mode): + """Finds file paths of all DAG files.""" + for f in files: + # noinspection PyBroadException + try: + file_path = os.path.join(root, f) + if not os.path.isfile(file_path): + continue + _, file_ext = os.path.splitext(os.path.split(file_path)[-1]) + if file_ext != '.py' and not zipfile.is_zipfile(file_path): + continue + if any([re.findall(p, file_path) for p in patterns]): + continue + + if not might_contain_dag(file_path, safe_mode): + continue + + file_paths.append(file_path) + except Exception: # pylint: disable=broad-except + log = LoggingMixin().log + log.exception("Error while examining %s", f) + + +COMMENT_PATTERN = re.compile(r"\s*#.*") + + +def might_contain_dag(file_path, safe_mode): + """Heuristic that guesses whether a Python file contains an Airflow DAG definition.""" + if safe_mode and not zipfile.is_zipfile(file_path): + with open(file_path, 'rb') as dag_file: + content = dag_file.read() + return all([s in content for s in (b'DAG', b'airflow')]) + return True diff --git a/airflow/www/views.py b/airflow/www/views.py index 1340d34a052d9..5df3d67d0e6ef 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -50,6 +50,7 @@ set_dag_run_state_to_failed, set_dag_run_state_to_success, ) from airflow.configuration import AIRFLOW_CONFIG, conf +from airflow.executors.executor_loader import ExecutorLoader from airflow.models import Connection, DagModel, DagRun, Log, SlaMiss, TaskFail, XCom, errors from airflow.settings import STORE_SERIALIZED_DAGS from airflow.ti_deps.dep_context import RUNNING_DEPS, SCHEDULER_QUEUED_DEPS, DepContext @@ -147,11 +148,12 @@ def show_traceback(error): class AirflowBaseView(BaseView): + from airflow import macros route_base = '' # Make our macros available to our UI templates too. 
extra_args = { - 'macros': airflow.macros, + 'macros': macros, } def render_template(self, *args, **kwargs): @@ -807,8 +809,7 @@ def run(self): ignore_task_deps = request.form.get('ignore_task_deps') == "true" ignore_ti_state = request.form.get('ignore_ti_state') == "true" - from airflow.executors import get_default_executor - executor = get_default_executor() + executor = ExecutorLoader.get_default_executor() valid_celery_config = False valid_kubernetes_config = False diff --git a/breeze-complete b/breeze-complete index 36ea7f0e99b91..867542b532377 100644 --- a/breeze-complete +++ b/breeze-complete @@ -22,7 +22,7 @@ _BREEZE_ALLOWED_ENVS=" docker kubernetes " _BREEZE_ALLOWED_BACKENDS=" sqlite mysql postgres " _BREEZE_ALLOWED_KUBERNETES_VERSIONS=" v1.13.0 " _BREEZE_ALLOWED_KUBERNETES_MODES=" persistent_mode git_mode " -_BREEZE_ALLOWED_STATIC_CHECKS=" all all-but-pylint check-apache-license check-executables-have-shebangs check-hooks-apply check-merge-conflict check-xml debug-statements doctoc detect-private-key end-of-file-fixer flake8 forbid-tabs insert-license lint-dockerfile mixed-line-ending mypy pylint setup-order shellcheck" +_BREEZE_ALLOWED_STATIC_CHECKS=" all all-but-pylint check-apache-license check-executables-have-shebangs check-hooks-apply check-merge-conflict check-xml debug-statements doctoc detect-private-key end-of-file-fixer flake8 forbid-tabs insert-license lint-dockerfile mixed-line-ending mypy pylint pylint-test setup-order shellcheck" _BREEZE_DEFAULT_DOCKERHUB_USER="apache" _BREEZE_DEFAULT_DOCKERHUB_REPO="airflow" diff --git a/docs/howto/custom-operator.rst b/docs/howto/custom-operator.rst index 2d89b109935b1..7713468bd49b8 100644 --- a/docs/howto/custom-operator.rst +++ b/docs/howto/custom-operator.rst @@ -20,8 +20,8 @@ Creating a custom Operator ========================== -Airflow allows you to create new operators to suit the requirements of you or your team. -The extensibility is one of the many reasons which makes Apache Airflow powerful. +Airflow allows you to create new operators to suit the requirements of you or your team. +The extensibility is one of the many reasons which makes Apache Airflow powerful. You can create any operator you want by extending the :class:`airflow.models.baseoperator.BaseOperator` @@ -31,16 +31,16 @@ There are two methods that you need to override in a derived class: Use ``@apply_defaults`` decorator function to fill unspecified arguments with ``default_args``. You can specify the ``default_args`` in the dag file. See :ref:`Default args ` for more details. -* Execute - The code to execute when the runner calls the operator. The method contains the +* Execute - The code to execute when the runner calls the operator. The method contains the airflow context as a parameter that can be used to read config values. Let's implement an example ``HelloOperator`` in a new file ``hello_operator.py``: .. code:: python - + from airflow.models.baseoperator import BaseOperator from airflow.utils.decorators import apply_defaults - + class HelloOperator(BaseOperator): @apply_defaults @@ -60,7 +60,7 @@ Let's implement an example ``HelloOperator`` in a new file ``hello_operator.py`` For imports to work, you should place the file in a directory that is present in the ``PYTHONPATH`` env. Airflow adds ``dags/``, ``plugins/``, and ``config/`` directories - in the Airflow home to ``PYTHONPATH`` by default. e.g., In our example, + in the Airflow home to ``PYTHONPATH`` by default. e.g., In our example, the file is placed in the ``custom_operator`` directory. 
You can now use the derived custom operator as follows: @@ -77,7 +77,7 @@ Hooks Hooks act as an interface to communicate with the external shared resources in a DAG. For example, multiple tasks in a DAG can require access to a MySQL database. Instead of creating a connection per task, you can retrieve a connection from the hook and utilize it. -Hook also helps to avoid storing connection auth parameters in a DAG. +Hook also helps to avoid storing connection auth parameters in a DAG. See :doc:`connection/index` for how to create and manage connections. Let's extend our previous example to fetch name from MySQL: @@ -107,9 +107,9 @@ Let's extend our previous example to fetch name from MySQL: print(message) return message -When the operator invokes the query on the hook object, a new connection gets created if it doesn't exist. +When the operator invokes the query on the hook object, a new connection gets created if it doesn't exist. The hook retrieves the auth parameters such as username and password from Airflow -backend and passes the params to the :py:func:`airflow.hooks.base_hook.BaseHook.get_connection`. +backend and passes the params to the :py:func:`airflow.hooks.base_hook.BaseHook.get_connection`. You should create hook only in the ``execute`` method or any method which is called from ``execute``. The constructor gets called whenever Airflow parses a DAG which happens frequently. The ``execute`` gets called only during a DAG run. @@ -117,7 +117,7 @@ The ``execute`` gets called only during a DAG run. User interface ^^^^^^^^^^^^^^^ Airflow also allows the developer to control how the operator shows up in the DAG UI. -Override ``ui_color`` to change the background color of the operator in UI. +Override ``ui_color`` to change the background color of the operator in UI. Override ``ui_fgcolor`` to change the color of the label. .. code:: python @@ -134,11 +134,11 @@ Airflow considers the field names present in ``template_fields`` for templating the operator. .. code:: python - + class HelloOperator(BaseOperator): - + template_fields = ['name'] - + @apply_defaults def __init__( self, @@ -166,14 +166,14 @@ In this example, Jinja looks for the ``name`` parameter and substitutes ``{{ tas The parameter can also contain a file name, for example, a bash script or a SQL file. You need to add the extension of your file in ``template_ext``. If a ``template_field`` contains a string ending with the extension mentioned in ``template_ext``, Jinja reads the content of the file and replace the templates -with actual value. Note that Jinja substitutes the operator attributes and not the args. +with actual value. Note that Jinja substitutes the operator attributes and not the args. .. 
code:: python class HelloOperator(BaseOperator): - + template_fields = ['guest_name'] - + @apply_defaults def __init__( self, diff --git a/scripts/ci/pylint_todo.txt b/scripts/ci/pylint_todo.txt index 4ab5fe486b258..c583756bd1fbf 100644 --- a/scripts/ci/pylint_todo.txt +++ b/scripts/ci/pylint_todo.txt @@ -103,11 +103,6 @@ ./airflow/contrib/sensors/sftp_sensor.py ./airflow/contrib/sensors/wasb_sensor.py ./airflow/contrib/sensors/weekday_sensor.py -./airflow/executors/base_executor.py -./airflow/executors/celery_executor.py -./airflow/executors/dask_executor.py -./airflow/executors/local_executor.py -./airflow/executors/sequential_executor.py ./airflow/hooks/dbapi_hook.py ./airflow/hooks/docker_hook.py ./airflow/hooks/druid_hook.py @@ -174,7 +169,6 @@ ./airflow/operators/s3_to_redshift_operator.py ./airflow/operators/slack_operator.py ./airflow/operators/sqlite_operator.py -./airflow/plugins_manager.py ./airflow/providers/amazon/aws/hooks/redshift.py ./airflow/providers/amazon/aws/operators/athena.py ./airflow/providers/amazon/aws/sensors/athena.py @@ -213,7 +207,6 @@ ./airflow/utils/cli_action_loggers.py ./airflow/utils/compression.py ./airflow/utils/configuration.py -./airflow/utils/dag_processing.py ./airflow/utils/dates.py ./airflow/utils/db.py ./airflow/utils/decorators.py diff --git a/tests/core.py b/tests/core.py index a323fa5378b66..edebbd84ae48d 100644 --- a/tests/core.py +++ b/tests/core.py @@ -31,10 +31,11 @@ from numpy.testing import assert_array_almost_equal from pendulum import utcnow -from airflow import DAG, configuration, exceptions, jobs, settings, utils -from airflow.configuration import AirflowConfigException, conf, run_command +from airflow import DAG, exceptions, jobs, settings, utils +from airflow.configuration import ( + DEFAULT_CONFIG, AirflowConfigException, conf, parameterized_config, run_command, +) from airflow.exceptions import AirflowException -from airflow.executors import SequentialExecutor from airflow.hooks.base_hook import BaseHook from airflow.hooks.sqlite_hook import SqliteHook from airflow.models import Connection, DagBag, DagRun, TaskFail, TaskInstance, Variable @@ -754,7 +755,7 @@ def test_variable_delete(self): def test_parameterized_config_gen(self): - cfg = configuration.parameterized_config(configuration.DEFAULT_CONFIG) + cfg = parameterized_config(DEFAULT_CONFIG) # making sure some basic building blocks are present: self.assertIn("[core]", cfg) @@ -870,6 +871,7 @@ def test_bad_trigger_rule(self): def test_terminate_task(self): """If a task instance's db state get deleted, it should fail""" + from airflow.executors.sequential_executor import SequentialExecutor TI = TaskInstance dag = self.dagbag.dags.get('test_utils') task = dag.task_dict.get('sleeps_forever') diff --git a/tests/dags/test_subdag.py b/tests/dags/test_subdag.py index 52bc9ec81c2b4..97cfb30548233 100644 --- a/tests/dags/test_subdag.py +++ b/tests/dags/test_subdag.py @@ -24,7 +24,7 @@ from datetime import datetime, timedelta -from airflow.models import DAG +from airflow.models.dag import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SubDagOperator diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py index 3509634edcbb3..4851f010cd219 100644 --- a/tests/executors/test_base_executor.py +++ b/tests/executors/test_base_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py index bf5f6b4714e02..3409ede79cdb7 100644 --- a/tests/executors/test_celery_executor.py +++ b/tests/executors/test_celery_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -17,6 +16,7 @@ # specific language governing permissions and limitations # under the License. import contextlib +import datetime import os import sys import unittest @@ -24,14 +24,19 @@ from unittest import mock # leave this it is used by the test worker +# noinspection PyUnresolvedReferences import celery.contrib.testing.tasks # noqa: F401 pylint: disable=ungrouped-imports from celery import Celery, states as celery_states from celery.contrib.testing.worker import start_worker from kombu.asynchronous import set_event_loop from parameterized import parameterized +from airflow import DAG from airflow.configuration import conf from airflow.executors import celery_executor +from airflow.models import TaskInstance +from airflow.models.taskinstance import SimpleTaskInstance +from airflow.operators.bash_operator import BashOperator from airflow.utils.state import State @@ -76,14 +81,17 @@ def test_celery_integration(self, broker_url): with start_worker(app=app, logfile=sys.stdout, loglevel='info'): success_command = ['true', 'some_parameter'] fail_command = ['false', 'some_parameter'] + execute_date = datetime.datetime.now() cached_celery_backend = celery_executor.execute_command.backend - task_tuples_to_send = [('success', 'fake_simple_ti', success_command, - celery_executor.celery_configuration['task_default_queue'], - celery_executor.execute_command), - ('fail', 'fake_simple_ti', fail_command, - celery_executor.celery_configuration['task_default_queue'], - celery_executor.execute_command)] + task_tuples_to_send = [ + (('success', 'fake_simple_ti', execute_date, 0), + None, success_command, celery_executor.celery_configuration['task_default_queue'], + celery_executor.execute_command), + (('fail', 'fake_simple_ti', execute_date, 0), + None, fail_command, celery_executor.celery_configuration['task_default_queue'], + celery_executor.execute_command) + ] chunksize = executor._num_tasks_per_send_process(len(task_tuples_to_send)) num_processes = min(len(task_tuples_to_send), executor._sync_parallelism) @@ -97,21 +105,21 @@ def test_celery_integration(self, broker_url): send_pool.close() send_pool.join() - for key, command, result in key_and_async_results: + for task_instance_key, _, result in key_and_async_results: # Only pops when enqueued successfully, otherwise keep it # and expect scheduler loop to deal with it. 
result.backend = cached_celery_backend - executor.running[key] = command - executor.tasks[key] = result - executor.last_state[key] = celery_states.PENDING + executor.running.add(task_instance_key) + executor.tasks[task_instance_key] = result + executor.last_state[task_instance_key] = celery_states.PENDING - executor.running['success'] = True - executor.running['fail'] = True + executor.running.add(('success', 'fake_simple_ti', execute_date, 0)) + executor.running.add(('fail', 'fake_simple_ti', execute_date, 0)) executor.end(synchronous=True) - self.assertTrue(executor.event_buffer['success'], State.SUCCESS) - self.assertTrue(executor.event_buffer['fail'], State.FAILED) + self.assertEqual(executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)], State.SUCCESS) + self.assertEqual(executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)], State.FAILED) self.assertNotIn('success', executor.tasks) self.assertNotIn('fail', executor.tasks) @@ -129,11 +137,19 @@ def fake_execute_command(): # fake_execute_command takes no arguments while execute_command takes 1, # which will cause TypeError when calling task.apply_async() executor = celery_executor.CeleryExecutor() - value_tuple = 'command', '_', 'queue', 'should_be_a_simple_ti' - executor.queued_tasks['key'] = value_tuple + task = BashOperator( + task_id="test", + bash_command="true", + dag=DAG(dag_id='id'), + start_date=datetime.datetime.now() + ) + value_tuple = 'command', 1, None, \ + SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.datetime.now())) + key = ('fail', 'fake_simple_ti', datetime.datetime.now(), 0) + executor.queued_tasks[key] = value_tuple executor.heartbeat() self.assertEqual(1, len(executor.queued_tasks)) - self.assertEqual(executor.queued_tasks['key'], value_tuple) + self.assertEqual(executor.queued_tasks[key], value_tuple) def test_exception_propagation(self): with self._prepare_app() as app: diff --git a/tests/executors/test_dask_executor.py b/tests/executors/test_dask_executor.py index b9d8c0105910f..7968da560b2d9 100644 --- a/tests/executors/test_dask_executor.py +++ b/tests/executors/test_dask_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file diff --git a/tests/executors/test_executor.py b/tests/executors/test_executor.py index f501e56d3f6ba..642ed4324009a 100644 --- a/tests/executors/test_executor.py +++ b/tests/executors/test_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index 0b107d093fba3..b099cc180ab57 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -199,7 +199,9 @@ def test_run_next_exception(self, mock_get_kube_client, mock_kubernetes_job_watc # Execute a task while the Api Throws errors try_number = 1 kubernetes_executor.execute_async(key=('dag', 'task', datetime.utcnow(), try_number), - command='command', executor_config={}) + queue=None, + command='command', + executor_config={}) kubernetes_executor.sync() kubernetes_executor.sync() diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py index dc07b251c1c73..1c27f44d8b010 100644 --- a/tests/executors/test_local_executor.py +++ b/tests/executors/test_local_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -16,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +import datetime import unittest from unittest import mock @@ -37,22 +36,26 @@ def execution_parallelism(self, parallelism=0): fail_command = ['false', 'some_parameter'] self.assertTrue(executor.result_queue.empty()) + execution_date = datetime.datetime.now() for i in range(self.TEST_SUCCESS_COMMANDS): - key, command = success_key.format(i), success_command - executor.running[key] = True + key_id, command = success_key.format(i), success_command + key = key_id, 'fake_ti', execution_date, 0 + executor.running.add(key) executor.execute_async(key=key, command=command) - executor.running['fail'] = True - executor.execute_async(key='fail', command=fail_command) + fail_key = 'fail', 'fake_ti', execution_date, 0 + executor.running.add(fail_key) + executor.execute_async(key=fail_key, command=fail_command) executor.end() # By that time Queues are already shutdown so we cannot check if they are empty self.assertEqual(len(executor.running), 0) for i in range(self.TEST_SUCCESS_COMMANDS): - key = success_key.format(i) + key_id = success_key.format(i) + key = key_id, 'fake_ti', execution_date, 0 self.assertEqual(executor.event_buffer[key], State.SUCCESS) - self.assertEqual(executor.event_buffer['fail'], State.FAILED) + self.assertEqual(executor.event_buffer[fail_key], State.FAILED) expected = self.TEST_SUCCESS_COMMANDS + 1 if parallelism == 0 else parallelism self.assertEqual(executor.workers_used, expected) diff --git a/tests/executors/test_sequential_executor.py b/tests/executors/test_sequential_executor.py index f4b1cf4871457..6100c72b9f252 100644 --- a/tests/executors/test_sequential_executor.py +++ b/tests/executors/test_sequential_executor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file diff --git a/tests/gcp/hooks/test_google_discovery_api.py b/tests/gcp/hooks/test_google_discovery_api.py index 6117656ca5162..e6d5ca3ca22d5 100644 --- a/tests/gcp/hooks/test_google_discovery_api.py +++ b/tests/gcp/hooks/test_google_discovery_api.py @@ -20,7 +20,8 @@ import unittest from unittest.mock import call, patch -from airflow import configuration, models +from airflow import models +from airflow.configuration import load_test_config from airflow.gcp.hooks.discovery_api import GoogleDiscoveryApiHook from airflow.utils import db @@ -28,7 +29,7 @@ class TestGoogleDiscoveryApiHook(unittest.TestCase): def setUp(self): - configuration.load_test_config() + load_test_config() db.merge_conn( models.Connection( diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index ce881bcf58cda..d8af907226a84 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -24,7 +24,7 @@ from airflow import AirflowException, models, settings from airflow.configuration import conf -from airflow.executors import SequentialExecutor +from airflow.executors.sequential_executor import SequentialExecutor from airflow.jobs import LocalTaskJob from airflow.models import DAG, TaskInstance as TI from airflow.operators.dummy_operator import DummyOperator diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index fa2789ebb508c..f54a613368d9c 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -31,15 +31,16 @@ import airflow.example_dags from airflow import AirflowException, models, settings from airflow.configuration import conf -from airflow.executors import BaseExecutor +from airflow.executors.base_executor import BaseExecutor from airflow.jobs import BackfillJob, SchedulerJob from airflow.models import DAG, DagBag, DagModel, DagRun, Pool, SlaMiss, TaskInstance as TI, errors from airflow.operators.bash_operator import BashOperator from airflow.operators.dummy_operator import DummyOperator from airflow.utils import timezone -from airflow.utils.dag_processing import SimpleDag, SimpleDagBag, list_py_file_paths +from airflow.utils.dag_processing import SimpleDag, SimpleDagBag from airflow.utils.dates import days_ago from airflow.utils.db import create_session, provide_session +from airflow.utils.file import list_py_file_paths from airflow.utils.state import State from tests.compat import MagicMock, Mock, PropertyMock, mock, patch from tests.core import TEST_DAG_FOLDER @@ -134,7 +135,6 @@ def test_no_orphan_process_will_be_left(self): scheduler.run() shutil.rmtree(empty_dir) - scheduler.executor.terminate() # Remove potential noise created by previous tests. 
current_children = set(current_process.children(recursive=True)) - set( old_children) diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 5a371bff49e67..8f7d60c069780 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -37,7 +37,7 @@ from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SubDagOperator from airflow.utils import timezone -from airflow.utils.dag_processing import list_py_file_paths +from airflow.utils.file import list_py_file_paths from airflow.utils.state import State from airflow.utils.weight_rule import WeightRule from tests.models import DEFAULT_DATE diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index 67dd9f16955a9..71776df7a13fc 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -30,7 +30,7 @@ from airflow import models from airflow.configuration import conf from airflow.models import DagBag, DagModel, TaskInstance as TI -from airflow.utils.dag_processing import SimpleTaskInstance +from airflow.models.taskinstance import SimpleTaskInstance from airflow.utils.db import create_session from airflow.utils.state import State from tests.models import DEFAULT_DATE, TEST_DAGS_FOLDER diff --git a/tests/operators/test_google_api_to_s3_transfer.py b/tests/operators/test_google_api_to_s3_transfer.py index 14cea11c01735..b942a2223ca54 100644 --- a/tests/operators/test_google_api_to_s3_transfer.py +++ b/tests/operators/test_google_api_to_s3_transfer.py @@ -20,7 +20,8 @@ import unittest from unittest.mock import Mock, patch -from airflow import configuration, models +from airflow import models +from airflow.configuration import load_test_config from airflow.models.xcom import MAX_XCOM_SIZE from airflow.operators.google_api_to_s3_transfer import GoogleApiToS3Transfer from airflow.utils import db @@ -29,7 +30,7 @@ class TestGoogleApiToS3Transfer(unittest.TestCase): def setUp(self): - configuration.load_test_config() + load_test_config() db.merge_conn( models.Connection( diff --git a/tests/operators/test_operators.py b/tests/operators/test_operators.py index e2612ccd1101c..163eb87bba9ae 100644 --- a/tests/operators/test_operators.py +++ b/tests/operators/test_operators.py @@ -22,7 +22,8 @@ from collections import OrderedDict from unittest import mock -from airflow import DAG, configuration, operators +from airflow import DAG, operators +from airflow.configuration import conf from airflow.utils import timezone DEFAULT_DATE = timezone.datetime(2015, 1, 1) @@ -47,7 +48,7 @@ def tearDown(self): for table in drop_tables: conn.execute("DROP TABLE IF EXISTS {}".format(table)) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_operator_test(self): sql = """ @@ -62,7 +63,7 @@ def test_mysql_operator_test(self): dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_operator_test_multi(self): sql = [ @@ -78,7 +79,7 @@ def test_mysql_operator_test_multi(self): ) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 
'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_hook_test_bulk_load(self): records = ("foo", "bar", "baz") @@ -102,7 +103,7 @@ def test_mysql_hook_test_bulk_load(self): results = tuple(result[0] for result in c.fetchall()) self.assertEqual(sorted(results), sorted(records)) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_hook_test_bulk_dump(self): from airflow.hooks.mysql_hook import MySqlHook @@ -115,7 +116,7 @@ def test_mysql_hook_test_bulk_dump(self): self.skipTest("Skip test_mysql_hook_test_bulk_load " "since file output is not permitted") - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") @mock.patch('airflow.hooks.mysql_hook.MySqlHook.get_conn') def test_mysql_hook_test_bulk_dump_mock(self, mock_get_conn): @@ -136,7 +137,7 @@ def test_mysql_hook_test_bulk_dump_mock(self, mock_get_conn): """.format(tmp_file=tmp_file, table=table) assertEqualIgnoreMultipleSpaces(self, mock_execute.call_args[0][0], query) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_to_mysql(self): sql = "SELECT * FROM INFORMATION_SCHEMA.TABLES LIMIT 100;" @@ -155,7 +156,7 @@ def test_mysql_to_mysql(self): dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_overwrite_schema(self): """ @@ -193,7 +194,7 @@ def tearDown(self): for t in tables_to_drop: cur.execute("DROP TABLE IF EXISTS {}".format(t)) - @unittest.skipUnless('postgres' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('postgres' in conf.get('core', 'sql_alchemy_conn'), "This is a Postgres test") def test_postgres_operator_test(self): sql = """ @@ -215,7 +216,7 @@ def test_postgres_operator_test(self): end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('postgres' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('postgres' in conf.get('core', 'sql_alchemy_conn'), "This is a Postgres test") def test_postgres_operator_test_multi(self): sql = [ @@ -228,7 +229,7 @@ def test_postgres_operator_test_multi(self): task_id='postgres_operator_test_multi', sql=sql, dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('postgres' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('postgres' in conf.get('core', 'sql_alchemy_conn'), "This is a Postgres test") def test_postgres_to_postgres(self): sql = "SELECT * FROM INFORMATION_SCHEMA.TABLES LIMIT 100;" @@ -247,7 +248,7 @@ def test_postgres_to_postgres(self): dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('postgres' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('postgres' in conf.get('core', 'sql_alchemy_conn'), "This is a Postgres test") def test_vacuum(self): """ @@ -263,7 +264,7 @@ def test_vacuum(self): autocommit=True) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - 
@unittest.skipUnless('postgres' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('postgres' in conf.get('core', 'sql_alchemy_conn'), "This is a Postgres test") def test_overwrite_schema(self): """ @@ -369,14 +370,14 @@ def tearDown(self): with MySqlHook().get_conn() as cur: cur.execute("DROP TABLE IF EXISTS baby_names CASCADE;") - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_clear(self): self.dag.clear( start_date=DEFAULT_DATE, end_date=timezone.utcnow()) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_to_hive(self): from airflow.operators.mysql_to_hive import MySqlToHiveTransfer @@ -391,7 +392,7 @@ def test_mysql_to_hive(self): dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_to_hive_partition(self): from airflow.operators.mysql_to_hive import MySqlToHiveTransfer @@ -408,7 +409,7 @@ def test_mysql_to_hive_partition(self): dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_to_hive_tblproperties(self): from airflow.operators.mysql_to_hive import MySqlToHiveTransfer @@ -424,7 +425,7 @@ def test_mysql_to_hive_tblproperties(self): dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") @mock.patch('airflow.hooks.hive_hooks.HiveCliHook.load_file') def test_mysql_to_hive_type_conversion(self, mock_load_file): @@ -469,7 +470,7 @@ def test_mysql_to_hive_type_conversion(self, mock_load_file): with m.get_conn() as c: c.execute("DROP TABLE IF EXISTS {}".format(mysql_table)) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_to_hive_verify_csv_special_char(self): mysql_table = 'test_mysql_to_hive' @@ -520,7 +521,7 @@ def test_mysql_to_hive_verify_csv_special_char(self): with m.get_conn() as c: c.execute("DROP TABLE IF EXISTS {}".format(mysql_table)) - @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), + @unittest.skipUnless('mysql' in conf.get('core', 'sql_alchemy_conn'), "This is a MySQL test") def test_mysql_to_hive_verify_loaded_values(self): mysql_table = 'test_mysql_to_hive' diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py index 0e3f3fe07ce21..4ec1d7f295d75 100644 --- a/tests/plugins/test_plugin.py +++ b/tests/plugins/test_plugin.py @@ -83,9 +83,11 @@ def test(self): static_url_path='/static/test_plugin') -# Create a handler to validate statsd stat name -def stat_name_dummy_handler(stat_name): - return stat_name +class StatsClass: + @staticmethod + # Create a handler to validate statsd stat name + def 
stat_name_dummy_handler(stat_name: str) -> str: + return stat_name # Defining the plugin class @@ -99,7 +101,7 @@ class AirflowTestPlugin(AirflowPlugin): flask_blueprints = [bp] appbuilder_views = [v_appbuilder_package] appbuilder_menu_items = [appbuilder_mitem] - stat_name_handler = staticmethod(stat_name_dummy_handler) + stat_name_handler = StatsClass.stat_name_dummy_handler global_operator_extra_links = [ AirflowLink(), GithubLink(), diff --git a/tests/task/task_runner/test_standard_task_runner.py b/tests/task/task_runner/test_standard_task_runner.py index de7416c5be50c..e99ea2a84b92d 100644 --- a/tests/task/task_runner/test_standard_task_runner.py +++ b/tests/task/task_runner/test_standard_task_runner.py @@ -76,18 +76,18 @@ def test_start_and_terminate(self): pgid = os.getpgid(runner.process.pid) self.assertTrue(pgid) - procs = [] - for p in psutil.process_iter(): + processes = [] + for process in psutil.process_iter(): try: - if os.getpgid(p.pid) == pgid: - procs.append(p) + if os.getpgid(process.pid) == pgid: + processes.append(process) except OSError: pass runner.terminate() - for p in procs: - self.assertFalse(psutil.pid_exists(p.pid)) + for process in processes: + self.assertFalse(psutil.pid_exists(process.pid)) def test_on_kill(self): """ diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 79c4ca4c526ed..02127cead1ceb 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -24,7 +23,9 @@ from unittest import mock from airflow import configuration -from airflow.configuration import AirflowConfigParser, conf, parameterized_config +from airflow.configuration import ( + AirflowConfigParser, conf, expand_env_var, get_airflow_config, get_airflow_home, parameterized_config, +) @unittest.mock.patch.dict('os.environ', { @@ -42,13 +43,13 @@ def test_airflow_home_default(self): if 'AIRFLOW_HOME' in os.environ: del os.environ['AIRFLOW_HOME'] self.assertEqual( - configuration.get_airflow_home(), - configuration.expand_env_var('~/airflow')) + get_airflow_home(), + expand_env_var('~/airflow')) def test_airflow_home_override(self): with unittest.mock.patch.dict('os.environ', AIRFLOW_HOME='/path/to/airflow'): self.assertEqual( - configuration.get_airflow_home(), + get_airflow_home(), '/path/to/airflow') def test_airflow_config_default(self): @@ -56,13 +57,13 @@ def test_airflow_config_default(self): if 'AIRFLOW_CONFIG' in os.environ: del os.environ['AIRFLOW_CONFIG'] self.assertEqual( - configuration.get_airflow_config('/home/airflow'), - configuration.expand_env_var('/home/airflow/airflow.cfg')) + get_airflow_config('/home/airflow'), + expand_env_var('/home/airflow/airflow.cfg')) def test_airflow_config_override(self): with unittest.mock.patch.dict('os.environ', AIRFLOW_CONFIG='/path/to/airflow/airflow.cfg'): self.assertEqual( - configuration.get_airflow_config('/home//airflow'), + get_airflow_config('/home//airflow'), '/path/to/airflow/airflow.cfg') def test_case_sensitivity(self): diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py index c34810925d461..1ad39e21445a8 100644 --- a/tests/utils/test_dag_processing.py +++ b/tests/utils/test_dag_processing.py @@ -29,11 +29,11 @@ from airflow.configuration import conf from airflow.jobs import DagFileProcessor, LocalTaskJob as LJ from airflow.models import DagBag, TaskInstance as TI +from 
airflow.models.taskinstance import SimpleTaskInstance from airflow.utils import timezone -from airflow.utils.dag_processing import ( - DagFileProcessorAgent, DagFileProcessorManager, DagFileStat, SimpleTaskInstance, correct_maybe_zipped, -) +from airflow.utils.dag_processing import DagFileProcessorAgent, DagFileProcessorManager, DagFileStat from airflow.utils.db import create_session +from airflow.utils.file import correct_maybe_zipped from airflow.utils.state import State from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_runs diff --git a/tests/utils/test_email.py b/tests/utils/test_email.py index 8f28075739a81..bc98ccafddf07 100644 --- a/tests/utils/test_email.py +++ b/tests/utils/test_email.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -25,7 +24,8 @@ import mock -from airflow import conf, utils +from airflow import utils +from airflow.configuration import conf from airflow.utils.email import get_email_address_list from tests import conf_vars diff --git a/tests/www/test_views.py b/tests/www/test_views.py index e3841d6d35c7c..4f4f071df4ffe 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -592,7 +592,7 @@ def test_run(self): resp = self.client.post('run', data=form) self.check_content_in_response('', resp, resp_code=302) - @mock.patch('airflow.executors.get_default_executor') + @mock.patch('airflow.executors.executor_loader.ExecutorLoader.get_default_executor') def test_run_with_runnable_states(self, get_default_executor_function): executor = CeleryExecutor() executor.heartbeat = lambda: True @@ -622,7 +622,7 @@ def test_run_with_runnable_states(self, get_default_executor_function): .format(state) + "The task must be cleared in order to be run" self.assertFalse(re.search(msg, resp.get_data(as_text=True))) - @mock.patch('airflow.executors.get_default_executor') + @mock.patch('airflow.executors.executor_loader.ExecutorLoader.get_default_executor') def test_run_with_not_runnable_states(self, get_default_executor_function): get_default_executor_function.return_value = CeleryExecutor()
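
A minimal usage sketch (not part of the patch), assuming only the import moves visible in the hunks above; the '/path/to/dags' directory is a placeholder:

# was: from airflow import conf
from airflow.configuration import conf
# was: from airflow.executors import get_default_executor
from airflow.executors.executor_loader import ExecutorLoader
# was: from airflow.utils.dag_processing import list_py_file_paths
from airflow.utils.file import list_py_file_paths
# also moved by this change:
#   airflow.utils.dag_processing.correct_maybe_zipped -> airflow.utils.file.correct_maybe_zipped
#   airflow.utils.dag_processing.SimpleTaskInstance   -> airflow.models.taskinstance.SimpleTaskInstance

# Resolve the configured executor through the new loader class.
executor = ExecutorLoader.get_default_executor()

# Discover DAG files with the helper relocated to airflow.utils.file.
safe_mode = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE', fallback=True)
dag_files = list_py_file_paths('/path/to/dags', safe_mode=safe_mode, include_examples=False)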