diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index c160dd9f6142b..f7a2a537a3fd9 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -17,7 +17,7 @@ # specific language governing permissions and limitations # under the License. # - +import datetime import logging import multiprocessing import os @@ -26,14 +26,15 @@ import threading import time from collections import defaultdict -from contextlib import redirect_stderr, redirect_stdout, suppress +from contextlib import ExitStack, redirect_stderr, redirect_stdout, suppress from datetime import timedelta from itertools import groupby -from typing import Any, Dict, List, Optional, Tuple +from multiprocessing.connection import Connection as MultiprocessingConnection +from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple from setproctitle import setproctitle from sqlalchemy import and_, func, not_, or_ -from sqlalchemy.orm.session import make_transient +from sqlalchemy.orm.session import Session, make_transient from airflow import models, settings from airflow.configuration import conf @@ -41,6 +42,7 @@ from airflow.executors.executor_loader import UNPICKLEABLE_EXECUTORS from airflow.jobs.base_job import BaseJob from airflow.models import DAG, DagModel, SlaMiss, errors +from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey from airflow.operators.dummy_operator import DummyOperator @@ -94,32 +96,33 @@ def __init__( self._failure_callback_requests = failure_callback_requests # The process that was launched to process the given . - self._process = None + self._process: Optional[multiprocessing.process.BaseProcess] = None # The result of Scheduler.process_file(file_path). - self._result = None + self._result: Optional[Tuple[List[SimpleDag], int]] = None # Whether the process is done running. self._done = False # When the process started. - self._start_time = None + self._start_time: Optional[datetime.datetime] = None # This ID is use to uniquely name the process / thread that's launched # by this processor instance self._instance_id = DagFileProcessorProcess.class_creation_counter - self._parent_channel = None - self._result_queue = None + self._parent_channel: Optional[MultiprocessingConnection] = None DagFileProcessorProcess.class_creation_counter += 1 @property - def file_path(self): + def file_path(self) -> str: return self._file_path @staticmethod - def _run_file_processor(result_channel, - file_path, - pickle_dags, - dag_ids, - thread_name, - failure_callback_requests): + def _run_file_processor( + result_channel: MultiprocessingConnection, + file_path: str, + pickle_dags: bool, + dag_ids: Optional[List[str]], + thread_name: str, + failure_callback_requests: List[FailureCallbackRequest] + ) -> None: """ Process the given file. 
@@ -141,16 +144,16 @@ def _run_file_processor(result_channel, :rtype: multiprocessing.Process """ # This helper runs in the newly created process - log = logging.getLogger("airflow.processor") + log: logging.Logger = logging.getLogger("airflow.processor") set_context(log, file_path) setproctitle("airflow scheduler - DagFileProcessor {}".format(file_path)) try: # redirect stdout/stderr to log - with redirect_stdout(StreamLogWriter(log, logging.INFO)),\ - redirect_stderr(StreamLogWriter(log, logging.WARN)): - + with ExitStack() as exit_stack: + exit_stack.enter_context(redirect_stdout(StreamLogWriter(log, logging.INFO))) # type: ignore + exit_stack.enter_context(redirect_stderr(StreamLogWriter(log, logging.WARN))) # type: ignore # Re-configure the ORM engine as there are issues with multiple processes settings.configure_orm() @@ -162,7 +165,7 @@ def _run_file_processor(result_channel, log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path) dag_file_processor = DagFileProcessor(dag_ids=dag_ids, log=log) - result = dag_file_processor.process_file( + result: Tuple[List[SimpleDag], int] = dag_file_processor.process_file( file_path=file_path, pickle_dags=pickle_dags, failure_callback_requests=failure_callback_requests, @@ -182,7 +185,7 @@ def _run_file_processor(result_channel, # tear it down manually here settings.dispose_orm() - def start(self): + def start(self) -> None: """ Launch the process and start processing the DAG. """ @@ -190,7 +193,7 @@ def start(self): context = multiprocessing.get_context(start_method) self._parent_channel, _child_channel = context.Pipe() - self._process = context.Process( + process = context.Process( target=type(self)._run_file_processor, args=( _child_channel, @@ -202,73 +205,77 @@ def start(self): ), name="DagFileProcessor{}-Process".format(self._instance_id) ) + self._process = process self._start_time = timezone.utcnow() - self._process.start() + process.start() - def kill(self): + def kill(self) -> None: """ Kill the process launched to process the file, and ensure consistent state. """ if self._process is None: raise AirflowException("Tried to kill before starting!") - # The queue will likely get corrupted, so remove the reference - self._result_queue = None self._kill_process() - def terminate(self, sigkill=False): + def terminate(self, sigkill: bool = False) -> None: """ Terminate (and then kill) the process launched to process the file. :param sigkill: whether to issue a SIGKILL if SIGTERM doesn't work. 
:type sigkill: bool """ - if self._process is None: + if self._process is None or self._parent_channel is None: raise AirflowException("Tried to call terminate before starting!") self._process.terminate() # Arbitrarily wait 5s for the process to die with suppress(TimeoutError): - self._process._popen.wait(5) # pylint: disable=protected-access + self._process._popen.wait(5) # type: ignore # pylint: disable=protected-access if sigkill: self._kill_process() self._parent_channel.close() - def _kill_process(self): - if self._process.is_alive(): + def _kill_process(self) -> None: + if self._process is None: + raise AirflowException("Tried to kill process before starting!") + + if self._process.is_alive() and self._process.pid: self.log.warning("Killing PID %s", self._process.pid) os.kill(self._process.pid, signal.SIGKILL) @property - def pid(self): + def pid(self) -> int: """ :return: the PID of the process launched to process the given file :rtype: int """ - if self._process is None: + if self._process is None or self._process.pid is None: raise AirflowException("Tried to get PID before starting!") return self._process.pid @property - def exit_code(self): + def exit_code(self) -> Optional[int]: """ After the process is finished, this can be called to get the return code :return: the exit code of the process :rtype: int """ + if self._process is None: + raise AirflowException("Tried to get exit code before starting!") if not self._done: raise AirflowException("Tried to call retcode before process was finished!") return self._process.exitcode @property - def done(self): + def done(self) -> bool: """ Check if the process launched to process this file is done. :return: whether the process is finished running :rtype: bool """ - if self._process is None: + if self._process is None or self._parent_channel is None: raise AirflowException("Tried to see if it's done before starting!") if self._done: @@ -295,17 +302,17 @@ def done(self): return False @property - def result(self): + def result(self) -> Optional[Tuple[List[SimpleDag], int]]: """ :return: result of running SchedulerJob.process_file() - :rtype: airflow.utils.dag_processing.SimpleDag + :rtype: Optional[Tuple[List[SimpleDag], int]] """ if not self.done: raise AirflowException("Tried to get the result before it's done!") return self._result @property - def start_time(self): + def start_time(self) -> datetime.datetime: """ :return: when this started to process the file :rtype: datetime @@ -342,15 +349,15 @@ class DagFileProcessor(LoggingMixin): :type log: logging.Logger """ - UNIT_TEST_MODE = conf.getboolean('core', 'UNIT_TEST_MODE') + UNIT_TEST_MODE: bool = conf.getboolean('core', 'UNIT_TEST_MODE') - def __init__(self, dag_ids, log): + def __init__(self, dag_ids: Optional[List[str]], log: logging.Logger): super().__init__() self.dag_ids = dag_ids self._log = log @provide_session - def manage_slas(self, dag: DAG, session=None): + def manage_slas(self, dag: DAG, session: Session = None) -> None: """ Finding all tasks that have SLAs defined, and sending alert emails where needed. New SLA misses are also recorded in the database. 
@@ -380,7 +387,7 @@ def manage_slas(self, dag: DAG, session=None): .group_by(TI.task_id).subquery('sq') ) - max_tis = session.query(TI).filter( + max_tis: List[TI] = session.query(TI).filter( TI.dag_id == dag.dag_id, TI.task_id == qry.c.task_id, TI.execution_date == qry.c.max_ti, @@ -404,7 +411,7 @@ def manage_slas(self, dag: DAG, session=None): dttm = dag.following_schedule(dttm) session.commit() - slas = ( + slas: List[SlaMiss] = ( session .query(SlaMiss) .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id) # noqa pylint: disable=singleton-comparison @@ -412,8 +419,8 @@ def manage_slas(self, dag: DAG, session=None): ) if slas: # pylint: disable=too-many-nested-blocks - sla_dates = [sla.execution_date for sla in slas] - qry = ( + sla_dates: List[datetime.datetime] = [sla.execution_date for sla in slas] + fetched_tis: List[TI] = ( session .query(TI) .filter( @@ -422,8 +429,8 @@ def manage_slas(self, dag: DAG, session=None): TI.dag_id == dag.dag_id ).all() ) - blocking_tis = [] - for ti in qry: + blocking_tis: List[TI] = [] + for ti in fetched_tis: if ti.task_id in dag.task_ids: ti.task = dag.get_task(ti.task_id) blocking_tis.append(ti) @@ -471,7 +478,7 @@ def manage_slas(self, dag: DAG, session=None): continue tasks_missed_sla.append(task) - emails = set() + emails: Set[str] = set() for task in tasks_missed_sla: if task.email: if isinstance(task.email, str): @@ -500,7 +507,7 @@ def manage_slas(self, dag: DAG, session=None): session.commit() @staticmethod - def update_import_errors(session, dagbag): + def update_import_errors(session: Session, dagbag: DagBag) -> None: """ For the DAGs in the given DagBag, record any associated import errors and clears errors for files that no longer have them. These are usually displayed through the @@ -509,7 +516,7 @@ def update_import_errors(session, dagbag): :param session: session for ORM operations :type session: sqlalchemy.orm.session.Session :param dagbag: DagBag containing DAGs with import errors - :type dagbag: airflow.models.DagBag + :type dagbag: airflow.DagBag """ # Clear the errors of the processed files for dagbag_file in dagbag.file_last_changed: @@ -531,8 +538,8 @@ def create_dag_run( self, dag: DAG, dag_runs: Optional[List[DagRun]] = None, - session=None, - ) -> None: + session: Session = None, + ) -> Optional[DagRun]: """ This method checks whether a new DagRun needs to be created for a DAG based on scheduling interval. 
@@ -542,6 +549,7 @@ def create_dag_run( if not dag.schedule_interval: return None + active_runs: List[DagRun] if dag_runs is None: active_runs = DagRun.find( dag_id=dag.dag_id, @@ -574,15 +582,14 @@ def create_dag_run( return None # this query should be replaced by find dagrun - qry = ( + last_scheduled_run: Optional[datetime.datetime] = ( session.query(func.max(DagRun.execution_date)) .filter_by(dag_id=dag.dag_id) .filter(or_( DagRun.external_trigger == False, # noqa: E712 pylint: disable=singleton-comparison DagRun.run_type == DagRunType.SCHEDULED.value - )) + )).scalar() ) - last_scheduled_run = qry.scalar() # don't schedule @once again if dag.schedule_interval == '@once' and last_scheduled_run: @@ -674,7 +681,7 @@ def create_dag_run( @provide_session def _process_task_instances( - self, dag: DAG, dag_runs: List[DagRun], session=None + self, dag: DAG, dag_runs: List[DagRun], session: Session = None ) -> List[TaskInstanceKey]: """ This method schedules the tasks for a single DAG by looking at the @@ -718,7 +725,7 @@ def _process_task_instances( return task_instances_list @provide_session - def _process_dags(self, dags: List[DAG], session=None): + def _process_dags(self, dags: List[DAG], session: Session = None) -> List[TaskInstanceKey]: """ Iterates over the dags and processes them. Processing includes: @@ -731,19 +738,21 @@ def _process_dags(self, dags: List[DAG], session=None): :rtype: list[TaskInstance] :return: A list of generated TaskInstance objects """ - check_slas = conf.getboolean('core', 'CHECK_SLAS', fallback=True) - use_job_schedule = conf.getboolean('scheduler', 'USE_JOB_SCHEDULE') + check_slas: bool = conf.getboolean('core', 'CHECK_SLAS', fallback=True) + use_job_schedule: bool = conf.getboolean('scheduler', 'USE_JOB_SCHEDULE') # pylint: disable=too-many-nested-blocks tis_out: List[TaskInstanceKey] = [] - dag_ids = [dag.dag_id for dag in dags] + dag_ids: List[str] = [dag.dag_id for dag in dags] dag_runs = DagRun.find(dag_id=dag_ids, state=State.RUNNING, session=session) # As per the docs of groupby (https://docs.python.org/3/library/itertools.html#itertools.groupby) # we need to use `list()` otherwise the result will be wrong/incomplete - dag_runs_by_dag_id = {k: list(v) for k, v in groupby(dag_runs, lambda d: d.dag_id)} + dag_runs_by_dag_id: Dict[str, List[DagRun]] = { + k: list(v) for k, v in groupby(dag_runs, lambda d: d.dag_id) + } for dag in dags: - dag_id = dag.dag_id + dag_id: str = dag.dag_id self.log.info("Processing %s", dag_id) dag_runs_for_dag = dag_runs_by_dag_id.get(dag_id) or [] @@ -781,7 +790,12 @@ def _find_dags_to_process(self, dags: List[DAG]) -> List[DAG]: return dags @provide_session - def execute_on_failure_callbacks(self, dagbag, failure_callback_requests, session=None): + def execute_on_failure_callbacks( + self, + dagbag: DagBag, + failure_callback_requests: List[FailureCallbackRequest], + session: Session = None + ) -> None: """ Execute on failure callbacks. These objects can come from SchedulerJob or from DagFileProcessorManager. @@ -809,7 +823,10 @@ def execute_on_failure_callbacks(self, dagbag, failure_callback_requests, sessio @provide_session def process_file( - self, file_path, failure_callback_requests, pickle_dags=False, session=None + self, + file_path: str, + failure_callback_requests: List[FailureCallbackRequest], + pickle_dags: bool = False, session: Session = None ) -> Tuple[List[SimpleDag], int]: """ Process a Python file containing Airflow DAGs. 
@@ -841,7 +858,7 @@ def process_file( self.log.info("Processing file %s for tasks to queue", file_path) try: - dagbag = models.DagBag(file_path, include_examples=False) + dagbag = DagBag(file_path, include_examples=False) except Exception: # pylint: disable=broad-except self.log.exception("Failed at reloading the DAG file %s", file_path) Stats.incr('dag_file_refresh_error', 1, 1) @@ -864,7 +881,9 @@ def process_file( paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids) - unpaused_dags = [dag for dag_id, dag in dagbag.dags.items() if dag_id not in paused_dag_ids] + unpaused_dags: List[DAG] = [ + dag for dag_id, dag in dagbag.dags.items() if dag_id not in paused_dag_ids + ] simple_dags = self._prepare_simple_dags(unpaused_dags, pickle_dags, session) @@ -885,16 +904,16 @@ def process_file( @provide_session def _schedule_task_instances( self, - dagbag: models.DagBag, + dagbag: DagBag, ti_keys_to_schedule: List[TaskInstanceKey], - session=None + session: Session = None ) -> None: """ Checks whether the tasks specified by `ti_keys_to_schedule` parameter can be scheduled and updates the information in the database, :param dagbag: DagBag - :type dagbag: models.DagBag + :type dagbag: DagBag :param ti_keys_to_schedule: List of task instance keys which can be scheduled. :type ti_keys_to_schedule: list """ @@ -908,7 +927,7 @@ def _schedule_task_instances( for ti in refreshed_tis: # Add task to task instance - dag = dagbag.dags[ti.dag_id] + dag: DAG = dagbag.dags[ti.dag_id] ti.task = dag.get_task(ti.task_id) # We check only deps needed to set TI to SCHEDULED state here. @@ -942,7 +961,9 @@ def _schedule_task_instances( session.commit() @provide_session - def _prepare_simple_dags(self, dags: List[DAG], pickle_dags: bool, session=None) -> List[SimpleDag]: + def _prepare_simple_dags( + self, dags: List[DAG], pickle_dags: bool, session: Session = None + ) -> List[SimpleDag]: """ Convert DAGS to SimpleDags. 
If necessary, it also Pickle the DAGs @@ -953,11 +974,10 @@ def _prepare_simple_dags(self, dags: List[DAG], pickle_dags: bool, session=None) :return: List of SimpleDag :rtype: List[airflow.utils.dag_processing.SimpleDag] """ - - simple_dags = [] + simple_dags: List[SimpleDag] = [] # Pickle the DAGs (if necessary) and put them into a SimpleDag for dag in dags: - pickle_id = dag.pickle(session).id if pickle_dags else None + pickle_id: Optional[int] = dag.pickle(session).id if pickle_dags else None simple_dags.append(SimpleDag(dag, pickle_id=pickle_id)) return simple_dags @@ -991,7 +1011,7 @@ class SchedulerJob(BaseJob): __mapper_args__ = { 'polymorphic_identity': 'SchedulerJob' } - heartrate = conf.getint('scheduler', 'SCHEDULER_HEARTBEAT_SEC') + heartrate: int = conf.getint('scheduler', 'SCHEDULER_HEARTBEAT_SEC') def __init__( self, @@ -1021,21 +1041,21 @@ def __init__( self._log = log # Check what SQL backend we use - sql_conn = conf.get('core', 'sql_alchemy_conn').lower() + sql_conn: str = conf.get('core', 'sql_alchemy_conn').lower() self.using_sqlite = sql_conn.startswith('sqlite') self.using_mysql = sql_conn.startswith('mysql') - self.max_tis_per_query = conf.getint('scheduler', 'max_tis_per_query') - self.processor_agent = None + self.max_tis_per_query: int = conf.getint('scheduler', 'max_tis_per_query') + self.processor_agent: Optional[DagFileProcessorAgent] = None - def register_exit_signals(self): + def register_exit_signals(self) -> None: """ Register signals that stop child processes """ signal.signal(signal.SIGINT, self._exit_gracefully) signal.signal(signal.SIGTERM, self._exit_gracefully) - def _exit_gracefully(self, signum, frame): # pylint: disable=unused-argument + def _exit_gracefully(self, signum, frame) -> None: # pylint: disable=unused-argument """ Helper method to clean up processor_agent to avoid leaving orphan processes. """ @@ -1044,7 +1064,7 @@ def _exit_gracefully(self, signum, frame): # pylint: disable=unused-argument self.processor_agent.end() sys.exit(os.EX_OK) - def is_alive(self, grace_multiplier=None): + def is_alive(self, grace_multiplier: Optional[float] = None) -> bool: """ Is this SchedulerJob alive? 
@@ -1059,7 +1079,7 @@ def is_alive(self, grace_multiplier=None): if grace_multiplier is not None: # Accept the same behaviour as superclass return super().is_alive(grace_multiplier=grace_multiplier) - scheduler_health_check_threshold = conf.getint('scheduler', 'scheduler_health_check_threshold') + scheduler_health_check_threshold: int = conf.getint('scheduler', 'scheduler_health_check_threshold') return ( self.state == State.RUNNING and (timezone.utcnow() - self.latest_heartbeat).total_seconds() < scheduler_health_check_threshold @@ -1069,9 +1089,9 @@ def is_alive(self, grace_multiplier=None): def _change_state_for_tis_without_dagrun( self, simple_dag_bag: SimpleDagBag, - old_states: List[State], - new_state: State, - session=None + old_states: List[str], + new_state: str, + session: Session = None ) -> None: """ For all DAG IDs in the SimpleDagBag, look for task instances in the @@ -1103,7 +1123,7 @@ def _change_state_for_tis_without_dagrun( # We need to do this for mysql as well because it can cause deadlocks # as discussed in https://issues.apache.org/jira/browse/AIRFLOW-2516 if self.using_sqlite or self.using_mysql: - tis_to_change = query.with_for_update().all() + tis_to_change: List[TI] = query.with_for_update().all() for ti in tis_to_change: ti.set_state(new_state, session=session) tis_changed += 1 @@ -1127,7 +1147,9 @@ def _change_state_for_tis_without_dagrun( Stats.gauge('scheduler.tasks.without_dagrun', tis_changed) @provide_session - def __get_concurrency_maps(self, states: List[State], session=None): + def __get_concurrency_maps( + self, states: List[str], session: Session = None + ) -> Tuple[DefaultDict[str, int], DefaultDict[Tuple[str, str], int]]: """ Get the concurrency maps. @@ -1135,16 +1157,16 @@ def __get_concurrency_maps(self, states: List[State], session=None): :type states: list[airflow.utils.state.State] :return: A map from (dag_id, task_id) to # of task instances and a map from (dag_id, task_id) to # of task instances in the given state list - :rtype: dict[tuple[str, str], int] + :rtype: tuple[dict[str, int], dict[tuple[str, str], int]] """ - ti_concurrency_query = ( + ti_concurrency_query: List[Tuple[str, str, int]] = ( session .query(TI.task_id, TI.dag_id, func.count('*')) .filter(TI.state.in_(states)) .group_by(TI.task_id, TI.dag_id) ).all() - dag_map: Dict[str, int] = defaultdict(int) - task_map: Dict[Tuple[str, str], int] = defaultdict(int) + dag_map: DefaultDict[str, int] = defaultdict(int) + task_map: DefaultDict[Tuple[str, str], int] = defaultdict(int) for result in ti_concurrency_query: task_id, dag_id, count = result dag_map[dag_id] += count @@ -1153,7 +1175,11 @@ def __get_concurrency_maps(self, states: List[State], session=None): # pylint: disable=too-many-locals,too-many-statements @provide_session - def _find_executable_task_instances(self, simple_dag_bag: SimpleDagBag, session=None): + def _find_executable_task_instances( + self, + simple_dag_bag: SimpleDagBag, + session: Session = None + ) -> List[TI]: """ Finds TIs that are ready for execution with respect to pool limits, dag concurrency, executor state, and priority. 
@@ -1168,7 +1194,7 @@ def _find_executable_task_instances(self, simple_dag_bag: SimpleDagBag, session= # Get all task instances associated with scheduled # DagRuns which are not backfilled, in the given states, # and the dag is not paused - task_instances_to_examine = ( + task_instances_to_examine: List[TI] = ( session .query(TI) .filter(TI.dag_id.in_(simple_dag_bag.dag_ids)) @@ -1196,13 +1222,15 @@ def _find_executable_task_instances(self, simple_dag_bag: SimpleDagBag, session= ) # Get the pool settings - pools = {p.pool: p for p in session.query(models.Pool).all()} + pools: Dict[str, models.Pool] = {p.pool: p for p in session.query(models.Pool).all()} - pool_to_task_instances = defaultdict(list) + pool_to_task_instances: DefaultDict[str, List[TI]] = defaultdict(list) for task_instance in task_instances_to_examine: pool_to_task_instances[task_instance.pool].append(task_instance) # dag_id to # of running tasks and (dag_id, task_id) to # of running tasks. + dag_concurrency_map: DefaultDict[str, int] + task_concurrency_map: DefaultDict[Tuple[str, str], int] dag_concurrency_map, task_concurrency_map = self.__get_concurrency_maps( states=list(EXECUTION_STATES), session=session) @@ -1266,7 +1294,7 @@ def _find_executable_task_instances(self, simple_dag_bag: SimpleDagBag, session= ) continue - task_concurrency_limit = simple_dag.get_task_special_arg( + task_concurrency_limit: Optional[int] = simple_dag.get_task_special_arg( task_instance.task_id, 'task_concurrency') if task_concurrency_limit is not None: @@ -1323,7 +1351,9 @@ def _find_executable_task_instances(self, simple_dag_bag: SimpleDagBag, session= return executable_tis @provide_session - def _change_state_for_executable_task_instances(self, task_instances: List[TI], session=None): + def _change_state_for_executable_task_instances( + self, task_instances: List[TI], session: Session = None + ) -> List[SimpleTaskInstance]: """ Changes the state of task instances in the list with one of the given states to QUEUED atomically, and returns the TIs changed in SimpleTaskInstance format. @@ -1336,7 +1366,7 @@ def _change_state_for_executable_task_instances(self, task_instances: List[TI], session.commit() return [] - tis_to_set_to_queued = ( + tis_to_set_to_queued: List[TI] = ( session .query(TI) .filter(TI.filter_for_tis(task_instances)) @@ -1366,8 +1396,11 @@ def _change_state_for_executable_task_instances(self, task_instances: List[TI], len(tis_to_set_to_queued), task_instance_str) return simple_task_instances - def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, - simple_task_instances): + def _enqueue_task_instances_with_queued_state( + self, + simple_dag_bag: SimpleDagBag, + simple_task_instances: List[SimpleTaskInstance] + ) -> None: """ Takes task_instances, which should have been set to queued, and enqueues them with the executor. @@ -1413,7 +1446,7 @@ def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, def _execute_task_instances( self, simple_dag_bag: SimpleDagBag, - session=None + session: Session = None ) -> int: """ Attempts to execute TaskInstances that should be executed by the scheduler. 
@@ -1431,7 +1464,7 @@ def _execute_task_instances( """ executable_tis = self._find_executable_task_instances(simple_dag_bag, session=session) - def query(result, items): + def query(result: int, items: List[TI]) -> int: simple_tis_with_state_changed = \ self._change_state_for_executable_task_instances(items, session=session) self._enqueue_task_instances_with_queued_state( @@ -1443,7 +1476,7 @@ def query(result, items): return helpers.reduce_in_chunks(query, executable_tis, 0, self.max_tis_per_query) @provide_session - def _change_state_for_tasks_failed_to_execute(self, session=None): + def _change_state_for_tasks_failed_to_execute(self, session: Session = None): """ If there are tasks left over in the executor, we set them back to SCHEDULED to avoid creating hanging tasks. @@ -1465,7 +1498,7 @@ def _change_state_for_tasks_failed_to_execute(self, session=None): for dag_id, task_id, execution_date, try_number in self.executor.queued_tasks.keys()]) ti_query = session.query(TI).filter(or_(*filter_for_ti_state_change)) - tis_to_set_to_scheduled = ti_query.with_for_update().all() + tis_to_set_to_scheduled: List[TI] = ti_query.with_for_update().all() if not tis_to_set_to_scheduled: return @@ -1482,17 +1515,20 @@ def _change_state_for_tasks_failed_to_execute(self, session=None): self.log.info("Set the following tasks to scheduled state:\n\t%s", task_instance_str) @provide_session - def _process_executor_events(self, simple_dag_bag, session=None): + def _process_executor_events(self, simple_dag_bag: SimpleDagBag, session: Session = None) -> None: """ Respond to executor events. """ - ti_primary_key_to_try_number_map = {} + if not self.processor_agent: + raise ValueError("Processor agent is not started.") + ti_primary_key_to_try_number_map: Dict[Tuple[str, str, datetime.datetime], int] = {} event_buffer = self.executor.get_event_buffer(simple_dag_bag.dag_ids) tis_with_right_state: List[TaskInstanceKey] = [] # Report execution for ti_key, value in event_buffer.items(): - state, info = value + state: str + state, _ = value # We create map (dag_id, task_id, execution_date) -> in-memory try_number ti_primary_key_to_try_number_map[ti_key.primary] = ti_key.try_number @@ -1510,11 +1546,10 @@ def _process_executor_events(self, simple_dag_bag, session=None): # Check state of finishes tasks filter_for_tis = TI.filter_for_tis(tis_with_right_state) - tis = session.query(TI).filter(filter_for_tis).all() + tis: List[TI] = session.query(TI).filter(filter_for_tis).all() for ti in tis: - ti_key: TaskInstanceKey = ti.key - try_number = ti_primary_key_to_try_number_map[ti_key.primary] - buffer_key = ti_key.with_try_number(try_number) + try_number = ti_primary_key_to_try_number_map[ti.key.primary] + buffer_key = ti.key.with_try_number(try_number) state, info = event_buffer.pop(buffer_key) # TODO: should we fail RUNNING as well, as we do in Backfills? @@ -1533,7 +1568,7 @@ def _process_executor_events(self, simple_dag_bag, session=None): f"task says its {ti.state}. (Info: {info}) Was the task killed externally?" ) - def _execute(self): + def _execute(self) -> None: self.log.info("Starting the scheduler") # DAGs can be pickled for easier remote execution by some executors @@ -1545,7 +1580,7 @@ def _execute(self): # so the scheduler job and DAG parser don't access the DB at the same time. 
async_mode = not self.using_sqlite - processor_timeout_seconds = conf.getint('core', 'dag_file_processor_timeout') + processor_timeout_seconds: int = conf.getint('core', 'dag_file_processor_timeout') processor_timeout = timedelta(seconds=processor_timeout_seconds) self.processor_agent = DagFileProcessorAgent( dag_directory=self.subdir, @@ -1587,7 +1622,7 @@ def _execute(self): self.executor.end() - settings.Session.remove() + settings.Session.remove() # type: ignore except Exception: # pylint: disable=broad-except self.log.exception("Exception when executing execute_helper") finally: @@ -1595,7 +1630,12 @@ def _execute(self): self.log.info("Exited execute loop") @staticmethod - def _create_dag_file_processor(file_path, failure_callback_requests, dag_ids, pickle_dags): + def _create_dag_file_processor( + file_path: str, + failure_callback_requests: List[FailureCallbackRequest], + dag_ids: Optional[List[str]], + pickle_dags: bool + ) -> DagFileProcessorProcess: """ Creates DagFileProcessorProcess instance. """ @@ -1606,7 +1646,7 @@ def _create_dag_file_processor(file_path, failure_callback_requests, dag_ids, pi failure_callback_requests=failure_callback_requests ) - def _run_scheduler_loop(self): + def _run_scheduler_loop(self) -> None: """ The actual scheduler loop. The main steps in the loop are: #. Harvest DAG parsing results through DagFileProcessorAgent @@ -1623,7 +1663,9 @@ def _run_scheduler_loop(self): :rtype: None """ - is_unit_test = conf.getboolean('core', 'unit_test_mode') + if not self.processor_agent: + raise ValueError("Processor agent is not started.") + is_unit_test: bool = conf.getboolean('core', 'unit_test_mode') # For the execute duration, parse and schedule DAGs while True: @@ -1683,7 +1725,7 @@ def _validate_and_run_task_instances(self, simple_dag_bag: SimpleDagBag) -> bool self._process_executor_events(simple_dag_bag) return True - def _process_and_execute_tasks(self, simple_dag_bag): + def _process_and_execute_tasks(self, simple_dag_bag: SimpleDagBag) -> None: # Handle cases where a DAG run state is set (perhaps manually) to # a non-running state. 
Handle task instances that belong to # DAG runs in those states @@ -1706,7 +1748,7 @@ def _process_and_execute_tasks(self, simple_dag_bag): self._execute_task_instances(simple_dag_bag) @provide_session - def _emit_pool_metrics(self, session=None) -> None: + def _emit_pool_metrics(self, session: Session = None) -> None: pools = models.Pool.slots_stats(session) for pool_name, slot_stats in pools.items(): Stats.gauge(f'pool.open_slots.{pool_name}', slot_stats["open"]) @@ -1714,5 +1756,5 @@ def _emit_pool_metrics(self, session=None) -> None: Stats.gauge(f'pool.running_slots.{pool_name}', slot_stats[State.RUNNING]) @provide_session - def heartbeat_callback(self, session=None): + def heartbeat_callback(self, session: Session = None) -> None: Stats.incr('scheduler_heartbeat', 1, 1) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index b1670aa242997..3b4c488d0bf4a 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -979,7 +979,7 @@ def resolve_template_files(self) -> None: self.prepare_template() @property - def upstream_list(self) -> List[str]: + def upstream_list(self) -> List["BaseOperator"]: """@property: list of tasks directly upstream""" return [self.dag.get_task(tid) for tid in self._upstream_task_ids] @@ -989,7 +989,7 @@ def upstream_task_ids(self) -> Set[str]: return self._upstream_task_ids @property - def downstream_list(self) -> List[str]: + def downstream_list(self) -> List["BaseOperator"]: """@property: list of tasks directly downstream""" return [self.dag.get_task(tid) for tid in self._downstream_task_ids] @@ -1123,7 +1123,7 @@ def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]: else: return self._downstream_task_ids - def get_direct_relatives(self, upstream: bool = False) -> List[str]: + def get_direct_relatives(self, upstream: bool = False) -> List["BaseOperator"]: """ Get list of the direct relatives to the current task, upstream or downstream. 
diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 753e606f26a4f..69684d37695e0 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -251,7 +251,7 @@ def __init__( self._dag_id = dag_id self._full_filepath = full_filepath if full_filepath else '' self._concurrency = concurrency - self._pickle_id = None + self._pickle_id: Optional[int] = None self._description = description # set file location to caller source path @@ -304,7 +304,7 @@ def __init__( self.dagrun_timeout = dagrun_timeout self.sla_miss_callback = sla_miss_callback if default_view in DEFAULT_VIEW_PRESETS: - self._default_view = default_view + self._default_view: str = default_view else: raise AirflowException(f'Invalid values of dag.default_view: only support ' f'{DEFAULT_VIEW_PRESETS}, but get {default_view}') @@ -507,27 +507,27 @@ def get_last_dagrun(self, session=None, include_externally_triggered=False): include_externally_triggered=include_externally_triggered) @property - def dag_id(self): + def dag_id(self) -> str: return self._dag_id @dag_id.setter - def dag_id(self, value): + def dag_id(self, value: str) -> None: self._dag_id = value @property - def full_filepath(self): + def full_filepath(self) -> str: return self._full_filepath @full_filepath.setter - def full_filepath(self, value): + def full_filepath(self, value) -> None: self._full_filepath = value @property - def concurrency(self): + def concurrency(self) -> int: return self._concurrency @concurrency.setter - def concurrency(self, value): + def concurrency(self, value: int): self._concurrency = value @property @@ -539,23 +539,23 @@ def access_control(self, value): self._access_control = value @property - def description(self): + def description(self) -> Optional[str]: return self._description @property - def default_view(self): + def default_view(self) -> str: return self._default_view @property - def pickle_id(self): + def pickle_id(self) -> Optional[int]: return self._pickle_id @pickle_id.setter - def pickle_id(self, value): + def pickle_id(self, value: int) -> None: self._pickle_id = value @property - def tasks(self): + def tasks(self) -> List[BaseOperator]: return list(self.task_dict.values()) @tasks.setter @@ -564,7 +564,7 @@ def tasks(self, val): 'DAG.tasks can not be modified. 
Use dag.add_task() instead.') @property - def task_ids(self): + def task_ids(self) -> List[str]: return list(self.task_dict.keys()) @property @@ -1264,10 +1264,10 @@ def sub_dag(self, task_regex, include_downstream=False, return dag - def has_task(self, task_id): + def has_task(self, task_id: str): return task_id in (t.task_id for t in self.tasks) - def get_task(self, task_id, include_subdags=False): + def get_task(self, task_id: str, include_subdags: bool = False) -> BaseOperator: if task_id in self.task_dict: return self.task_dict[task_id] if include_subdags: @@ -1291,7 +1291,7 @@ def pickle_info(self): return d @provide_session - def pickle(self, session=None): + def pickle(self, session=None) -> DagPickle: dag = session.query( DagModel).filter(DagModel.dag_id == self.dag_id).first() dp = None diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 2bab4d73730c4..cc7f6bd4733f6 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -25,7 +25,7 @@ import textwrap import zipfile from datetime import datetime, timedelta -from typing import List, NamedTuple +from typing import Dict, List, NamedTuple, Optional from croniter import CroniterBadCronError, CroniterBadDateError, CroniterNotAlphaError, croniter from tabulate import tabulate @@ -79,18 +79,20 @@ class DagBag(BaseDagBag, LoggingMixin): def __init__( self, - dag_folder=None, - include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'), - safe_mode=conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'), - store_serialized_dags=False, + dag_folder: Optional[str] = None, + include_examples: bool = conf.getboolean('core', 'LOAD_EXAMPLES'), + safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'), + store_serialized_dags: bool = False, ): + # Avoid circular import + from airflow.models.dag import DAG super().__init__() dag_folder = dag_folder or settings.DAGS_FOLDER self.dag_folder = dag_folder - self.dags = {} + self.dags: Dict[str, DAG] = {} # the file's last modified timestamp when we last read it - self.file_last_changed = {} - self.import_errors = {} + self.file_last_changed: Dict[str, datetime] = {} + self.import_errors: Dict[str, str] = {} self.has_logged = False self.store_serialized_dags = store_serialized_dags @@ -99,7 +101,7 @@ def __init__( include_examples=include_examples, safe_mode=safe_mode) - def size(self): + def size(self) -> int: """ :return: the amount of dags contained in this dagbag """ diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 0cc03d9855af4..7ef269b8c2e2f 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. 
from datetime import datetime -from typing import List, Optional, Tuple, Union, cast +from typing import Any, List, Optional, Tuple, Union, cast from sqlalchemy import ( Boolean, Column, DateTime, Index, Integer, PickleType, String, UniqueConstraint, and_, func, or_, @@ -66,8 +66,17 @@ class DagRun(Base, LoggingMixin): UniqueConstraint('dag_id', 'run_id'), ) - def __init__(self, dag_id=None, run_id=None, execution_date=None, start_date=None, external_trigger=None, - conf=None, state=None, run_type=None): + def __init__( + self, + dag_id: Optional[str] = None, + run_id: Optional[str] = None, + execution_date: Optional[datetime] = None, + start_date: Optional[datetime] = None, + external_trigger: Optional[bool] = None, + conf: Optional[Any] = None, + state: Optional[str] = None, + run_type: Optional[str] = None + ): self.dag_id = dag_id self.run_id = run_id self.execution_date = execution_date @@ -131,8 +140,9 @@ def find( no_backfills: Optional[bool] = False, run_type: Optional[DagRunType] = None, session: Session = None, - execution_start_date=None, execution_end_date=None - ): + execution_start_date: Optional[datetime] = None, + execution_end_date: Optional[datetime] = None + ) -> List["DagRun"]: """ Returns a set of dag runs for the given search criteria. @@ -281,7 +291,7 @@ def get_previous_scheduled_dagrun(self, session=None): ).first() @provide_session - def update_state(self, session=None): + def update_state(self, session=None) -> List[TI]: """ Determines the overall state of the DagRun based on the state of its TaskInstances. @@ -291,7 +301,7 @@ def update_state(self, session=None): """ dag = self.get_dag() - ready_tis = [] + ready_tis: List[TI] = [] tis = [ti for ti in self.get_task_instances(session=session, state=State.task_states + (State.SHUTDOWN,))] self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis)) diff --git a/airflow/models/pool.py b/airflow/models/pool.py index a9048887d0bc5..3819d6d135fda 100644 --- a/airflow/models/pool.py +++ b/airflow/models/pool.py @@ -181,7 +181,7 @@ def queued_slots(self, session: Session): ) or 0 @provide_session - def open_slots(self, session: Session): + def open_slots(self, session: Session) -> float: """ Get the number of slots open at the moment. 
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 30024dfe14aef..5fb174cd52681 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -339,7 +339,7 @@ def generate_command(dag_id: str, # pylint: disable=too-many-arguments ignore_task_deps: Optional[bool] = False, ignore_ti_state: Optional[bool] = False, local: Optional[bool] = False, - pickle_id: Optional[str] = None, + pickle_id: Optional[int] = None, file_path: Optional[str] = None, raw: Optional[bool] = False, job_id: Optional[str] = None, @@ -372,7 +372,7 @@ def generate_command(dag_id: str, # pylint: disable=too-many-arguments :type local: Optional[bool] :param pickle_id: If the DAG was serialized to the DB, the ID associated with the pickled DAG - :type pickle_id: Optional[str] + :type pickle_id: Optional[int] :param file_path: path to the file containing the DAG definition :type file_path: Optional[str] :param raw: raw mode (needs more details) @@ -391,7 +391,7 @@ def generate_command(dag_id: str, # pylint: disable=too-many-arguments if mark_success: cmd.extend(["--mark-success"]) if pickle_id: - cmd.extend(["--pickle", pickle_id]) + cmd.extend(["--pickle", str(pickle_id)]) if job_id: cmd.extend(["--job-id", str(job_id)]) if ignore_all_deps: @@ -573,7 +573,7 @@ def key(self) -> TaskInstanceKey: return TaskInstanceKey(self.dag_id, self.task_id, self.execution_date, self.try_number) @provide_session - def set_state(self, state, session=None, commit=True): + def set_state(self, state: str, session=None, commit: bool = True): """ Set TaskInstance state @@ -1779,9 +1779,7 @@ def __init__(self, ti: TaskInstance): self._run_as_user: Optional[str] = None if hasattr(ti, 'run_as_user'): self._run_as_user = ti.run_as_user - self._pool: Optional[str] = None - if hasattr(ti, 'pool'): - self._pool = ti.pool + self._pool: str = ti.pool self._priority_weight: Optional[int] = None if hasattr(ti, 'priority_weight'): self._priority_weight = ti.priority_weight @@ -1818,7 +1816,7 @@ def state(self) -> str: return self._state @property - def pool(self) -> Any: + def pool(self) -> str: return self._pool @property diff --git a/airflow/stats.py b/airflow/stats.py index ad5ef10b053bf..852bafeb8a88a 100644 --- a/airflow/stats.py +++ b/airflow/stats.py @@ -22,7 +22,7 @@ import string import textwrap from functools import wraps -from typing import Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional from airflow.configuration import conf from airflow.exceptions import AirflowConfigException, InvalidStatsNameException @@ -255,5 +255,8 @@ def get_constant_tags(self): return tags -class Stats(metaclass=_Stats): # noqa: D101 - pass +if TYPE_CHECKING: + Stats: StatsLogger +else: + class Stats(metaclass=_Stats): # noqa: D101 + pass diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py index a4e8a7d45d2eb..08afa1c21443e 100644 --- a/airflow/ti_deps/dep_context.py +++ b/airflow/ti_deps/dep_context.py @@ -65,13 +65,13 @@ class DepContext: def __init__( self, deps=None, - flag_upstream_failed=False, - ignore_all_deps=False, - ignore_depends_on_past=False, - ignore_in_retry_period=False, - ignore_in_reschedule_period=False, - ignore_task_deps=False, - ignore_ti_state=False, + flag_upstream_failed: bool = False, + ignore_all_deps: bool = False, + ignore_depends_on_past: bool = False, + ignore_in_retry_period: bool = False, + ignore_in_reschedule_period: bool = False, + ignore_task_deps: bool = False, + ignore_ti_state: bool = False, finished_tasks=None): self.deps = 
deps or set() self.flag_upstream_failed = flag_upstream_failed diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index 63f0d34a5b5ff..835d0e3315339 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -63,12 +63,12 @@ class SimpleDag(BaseDag): :type pickle_id: unicode """ - def __init__(self, dag, pickle_id: Optional[str] = None): + def __init__(self, dag, pickle_id: Optional[int] = None): self._dag_id: str = dag.dag_id self._task_ids: List[str] = [task.task_id for task in dag.tasks] self._full_filepath: str = dag.full_filepath self._concurrency: int = dag.concurrency - self._pickle_id: Optional[str] = pickle_id + self._pickle_id: Optional[int] = pickle_id self._task_special_args: Dict[str, Any] = {} for task in dag.tasks: special_args = {} @@ -110,7 +110,7 @@ def concurrency(self) -> int: return self._concurrency @property - def pickle_id(self) -> Optional[str]: # pylint: disable=invalid-overridden-method + def pickle_id(self) -> Optional[int]: # pylint: disable=invalid-overridden-method """ :return: The pickle ID for this DAG, if it has one. Otherwise None. :rtype: unicode @@ -175,7 +175,7 @@ class AbstractDagFileProcessorProcess(metaclass=ABCMeta): """ @abstractmethod - def start(self): + def start(self) -> None: """ Launch the process to process the file """ @@ -189,7 +189,7 @@ def terminate(self, sigkill: bool = False): raise NotImplementedError() @abstractmethod - def kill(self): + def kill(self) -> None: """ Kill the process launched to process the file, and ensure consistent state. """ @@ -205,7 +205,7 @@ def pid(self) -> int: @property @abstractmethod - def exit_code(self) -> int: + def exit_code(self) -> Optional[int]: """ After the process is finished, this can be called to get the return code :return: the exit code of the process @@ -225,18 +225,18 @@ def done(self) -> bool: @property @abstractmethod - def result(self) -> Tuple[List[SimpleDag], int]: + def result(self) -> Optional[Tuple[List[SimpleDag], int]]: """ A list of simple dags found, and the number of import errors - :return: result of running SchedulerJob.process_file() - :rtype: tuple[list[airflow.utils.dag_processing.SimpleDag], int] + :return: result of running SchedulerJob.process_file() if available. Otherwise, None + :rtype: Optional[Tuple[List[SimpleDag], int]] """ raise NotImplementedError() @property @abstractmethod - def start_time(self): + def start_time(self) -> datetime: """ :return: When this started to process the file :rtype: datetime @@ -245,7 +245,7 @@ def start_time(self): @property @abstractmethod - def file_path(self): + def file_path(self) -> str: """ :return: the path to the file that this is processing :rtype: unicode @@ -308,7 +308,9 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): :type max_runs: int :param processor_factory: function that creates processors for DAG definition files. 
Arguments are (dag_definition_path, log_file_path) - :type processor_factory: (str, str, list) -> (AbstractDagFileProcessorProcess) + :type processor_factory: ([str, List[FailureCallbackRequest], Optional[List[str]], bool]) -> ( + AbstractDagFileProcessorProcess + ) :param processor_timeout: How long to wait before timing out a DAG file processor :type processor_timeout: timedelta :param dag_ids: if specified, only schedule tasks with these DAG IDs @@ -319,17 +321,22 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): :type async_mode: bool """ - def __init__(self, - dag_directory, - max_runs, - processor_factory, - processor_timeout, - dag_ids, - pickle_dags, - async_mode): + def __init__( + self, + dag_directory: str, + max_runs: int, + processor_factory: Callable[ + [str, List[FailureCallbackRequest], Optional[List[str]], bool], + AbstractDagFileProcessorProcess + ], + processor_timeout: timedelta, + dag_ids: Optional[List[str]], + pickle_dags: bool, + async_mode: bool + ): super().__init__() - self._file_path_queue = [] - self._dag_directory = dag_directory + self._file_path_queue: List[str] = [] + self._dag_directory: str = dag_directory self._max_runs = max_runs self._processor_factory = processor_factory self._processor_timeout = processor_timeout @@ -337,17 +344,17 @@ def __init__(self, self._pickle_dags = pickle_dags self._async_mode = async_mode # Map from file path to the processor - self._processors = {} + self._processors: Dict[str, AbstractDagFileProcessorProcess] = {} # Pipe for communicating signals - self._process = None - self._done = False + self._process: Optional[multiprocessing.process.BaseProcess] = None + self._done: bool = False # Initialized as true so we do not deactivate w/o any actual DAG parsing. self._all_files_processed = True - self._parent_signal_conn = None - self._collected_dag_buffer = [] + self._parent_signal_conn: Optional[MultiprocessingConnection] = None + self._collected_dag_buffer: List = [] - def start(self): + def start(self) -> None: """ Launch DagFileProcessorManager processor and start DAG parsing loop in manager. """ @@ -355,7 +362,7 @@ def start(self): context = multiprocessing.get_context(mp_start_method) self._parent_signal_conn, child_signal_conn = context.Pipe() - self._process = context.Process( + process = context.Process( target=type(self)._run_processor_manager, args=( self._dag_directory, @@ -369,11 +376,13 @@ def start(self): self._async_mode ) ) - self._process.start() + self._process = process + + process.start() - self.log.info("Launched DagFileProcessorManager with pid: %s", self._process.pid) + self.log.info("Launched DagFileProcessorManager with pid: %s", process.pid) - def run_single_parsing_loop(self): + def run_single_parsing_loop(self) -> None: """ Should only be used when launched DAG file processor manager in sync mode. Send agent heartbeat signal to the manager, requesting that it runs one @@ -382,6 +391,8 @@ def run_single_parsing_loop(self): Call wait_until_finished to ensure that any launched processors have finished before continuing """ + if not self._parent_signal_conn or not self._process: + raise ValueError("Process not started.") if not self._process.is_alive(): return @@ -392,7 +403,9 @@ def run_single_parsing_loop(self): # when harvest_simple_dags calls _heartbeat_manager. 
pass - def send_callback_to_execute(self, full_filepath: str, task_instance: TaskInstance, msg: str): + def send_callback_to_execute( + self, full_filepath: str, task_instance: TaskInstance, msg: str + ) -> None: """ Sends information about the callback to be executed by DagFileProcessor. @@ -403,6 +416,8 @@ def send_callback_to_execute(self, full_filepath: str, task_instance: TaskInstan :param msg: Message sent in callback. :type msg: str """ + if not self._parent_signal_conn: + raise ValueError("Process not started.") try: request = FailureCallbackRequest( full_filepath=full_filepath, @@ -415,8 +430,10 @@ def send_callback_to_execute(self, full_filepath: str, task_instance: TaskInstan # when harvest_simple_dags calls _heartbeat_manager. pass - def wait_until_finished(self): + def wait_until_finished(self) -> None: """Waits until DAG parsing is finished.""" + if not self._parent_signal_conn: + raise ValueError("Process not started.") while self._parent_signal_conn.poll(timeout=None): try: result = self._parent_signal_conn.recv() @@ -429,14 +446,19 @@ def wait_until_finished(self): return @staticmethod - def _run_processor_manager(dag_directory, - max_runs, - processor_factory, - processor_timeout, - signal_conn, - dag_ids, - pickle_dags, - async_mode): + def _run_processor_manager( + dag_directory: str, + max_runs: int, + processor_factory: Callable[ + [str, List[FailureCallbackRequest]], + AbstractDagFileProcessorProcess + ], + processor_timeout: timedelta, + signal_conn: MultiprocessingConnection, + dag_ids: Optional[List[str]], + pickle_dags: bool, + async_mode: bool + ) -> None: # Make this process start as a new process group - that makes it easy # to kill all sub-process of this at the OS-level, rather than having @@ -452,7 +474,7 @@ def _run_processor_manager(dag_directory, os.environ['AIRFLOW__LOGGING__COLORED_CONSOLE_LOG'] = 'False' # Replicating the behavior of how logging module was loaded # in logging_config.py - importlib.reload(import_module(airflow.settings.LOGGING_CLASS_PATH.rsplit('.', 1)[0])) + importlib.reload(import_module(airflow.settings.LOGGING_CLASS_PATH.rsplit('.', 1)[0])) # type: ignore importlib.reload(airflow.settings) airflow.settings.initialize() del os.environ['CONFIG_PROCESSOR_MANAGER_LOGGER'] @@ -467,12 +489,14 @@ def _run_processor_manager(dag_directory, processor_manager.start() - def harvest_simple_dags(self): + def harvest_simple_dags(self) -> List[SimpleDag]: """ Harvest DAG parsing results from result queue and sync metadata from stat queue. :return: List of parsing result in SimpleDag format. """ + if not self._parent_signal_conn: + raise ValueError("Process not started.") # Receive any pending messages before checking if the process has exited. while self._parent_signal_conn.poll(timeout=0.01): try: @@ -499,6 +523,8 @@ def _heartbeat_manager(self): """ Heartbeat DAG file processor and restart it if we are not done. """ + if not self._parent_signal_conn: + raise ValueError("Process not started.") if self._process and not self._process.is_alive(): self._process.join(timeout=0) if not self.done: @@ -516,7 +542,7 @@ def _sync_metadata(self, stat): self._all_files_processed = stat.all_files_processed @property - def done(self): + def done(self) -> bool: """ Has DagFileProcessorManager ended? 
""" @@ -591,7 +617,7 @@ def __init__(self, ], processor_timeout: timedelta, signal_conn: MultiprocessingConnection, - dag_ids: List[str], + dag_ids: Optional[List[str]], pickle_dags: bool, async_mode: bool = True): super().__init__() diff --git a/airflow/utils/dates.py b/airflow/utils/dates.py index cf219677ffb9a..23b21cd722730 100644 --- a/airflow/utils/dates.py +++ b/airflow/utils/dates.py @@ -17,14 +17,14 @@ # under the License. from datetime import datetime, timedelta -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union from croniter import croniter from dateutil.relativedelta import relativedelta # noqa: F401 for doctest from airflow.utils import timezone -cron_presets = { +cron_presets: Dict[str, str] = { '@hourly': '0 * * * *', '@daily': '0 0 * * *', '@weekly': '0 0 * * 0', diff --git a/airflow/utils/email.py b/airflow/utils/email.py index 219c9b316f483..1ce95b8a9c986 100644 --- a/airflow/utils/email.py +++ b/airflow/utils/email.py @@ -32,17 +32,17 @@ log = logging.getLogger(__name__) -def send_email(to, subject, html_content, +def send_email(to: Union[List[str], Iterable[str]], subject: str, html_content: str, files=None, dryrun=False, cc=None, bcc=None, mime_subtype='mixed', mime_charset='utf-8', **kwargs): """ Send email using backend specified in EMAIL_BACKEND. """ backend = conf.getimport('email', 'EMAIL_BACKEND') - to = get_email_address_list(to) - to = ", ".join(to) + to_list = get_email_address_list(to) + to_comma_seperated = ", ".join(to_list) - return backend(to, subject, html_content, files=files, + return backend(to_comma_seperated, subject, html_content, files=files, dryrun=dryrun, cc=cc, bcc=bcc, mime_subtype=mime_subtype, mime_charset=mime_charset, **kwargs) diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py index 9858b3bcffd5a..f9872c1416b2b 100644 --- a/airflow/utils/helpers.py +++ b/airflow/utils/helpers.py @@ -21,7 +21,7 @@ from datetime import datetime from functools import reduce from itertools import filterfalse, tee -from typing import Any, Callable, Dict, Iterable, Optional +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, TypeVar from jinja2 import Template @@ -100,7 +100,11 @@ def as_tuple(obj): return tuple([obj]) -def chunks(items, chunk_size): +T = TypeVar('T') # pylint: disable=invalid-name +S = TypeVar('S') # pylint: disable=invalid-name + + +def chunks(items: List[T], chunk_size: int) -> Generator[List[T], None, None]: """ Yield successive chunks of a given size from a list of items """ @@ -110,7 +114,12 @@ def chunks(items, chunk_size): yield items[i:i + chunk_size] -def reduce_in_chunks(fn, iterable, initializer, chunk_size=0): +def reduce_in_chunks( + fn: Callable[[S, List[T]], S], + iterable: List[T], + initializer: S, + chunk_size: int = 0 +): """ Reduce the given list of items by splitting it into chunks of the given size and passing each chunk through the reducer @@ -122,7 +131,7 @@ def reduce_in_chunks(fn, iterable, initializer, chunk_size=0): return reduce(fn, chunks(iterable, chunk_size), initializer) -def as_flattened_list(iterable): +def as_flattened_list(iterable: Iterable[Iterable[T]]) -> List[T]: """ Return an iterable with one level flattened diff --git a/airflow/utils/mixins.py b/airflow/utils/mixins.py index 9bbc1fea9bfd7..ddf19bc42b598 100644 --- a/airflow/utils/mixins.py +++ b/airflow/utils/mixins.py @@ -25,7 +25,7 @@ class MultiprocessingStartMethodMixin: """ Convenience class to add support for different types of multiprocessing. 
""" - def _get_multiprocessing_start_method(self): + def _get_multiprocessing_start_method(self) -> str: """ Determine method of creating new processes by checking if the mp_start_method is set in configs, else, it uses the OS default. @@ -33,4 +33,7 @@ def _get_multiprocessing_start_method(self): if conf.has_option('core', 'mp_start_method'): return conf.get('core', 'mp_start_method') - return multiprocessing.get_start_method() + method = multiprocessing.get_start_method() + if not method: + raise ValueError("Failed to determine start method") + return method diff --git a/airflow/utils/timezone.py b/airflow/utils/timezone.py index 24fa2d15a90ef..b5c9ccb4adf56 100644 --- a/airflow/utils/timezone.py +++ b/airflow/utils/timezone.py @@ -49,7 +49,7 @@ def is_naive(value): return value.utcoffset() is None -def utcnow(): +def utcnow() -> dt.datetime: """ Get the current date and time in UTC @@ -59,13 +59,13 @@ def utcnow(): # pendulum utcnow() is not used as that sets a TimezoneInfo object # instead of a Timezone. This is not pickable and also creates issues # when using replace() - date = dt.datetime.utcnow() - date = date.replace(tzinfo=utc) + result = dt.datetime.utcnow() + result = result.replace(tzinfo=utc) - return date + return result -def utc_epoch(): +def utc_epoch() -> dt.datetime: """ Gets the epoch in the users timezone @@ -75,10 +75,10 @@ def utc_epoch(): # pendulum utcnow() is not used as that sets a TimezoneInfo object # instead of a Timezone. This is not pickable and also creates issues # when using replace() - date = dt.datetime(1970, 1, 1) - date = date.replace(tzinfo=utc) + result = dt.datetime(1970, 1, 1) + result = result.replace(tzinfo=utc) - return date + return result def convert_to_utc(value): diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 630d0117b2171..3dc70c7651bce 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -215,10 +215,9 @@ def test_scheduler_executor_overflow(self): session.merge(ti) # scheduler._process_dags(simple_dag_bag) - @mock.patch('airflow.models.DagBag', return_value=dagbag) - @mock.patch('airflow.models.DagBag.collect_dags') + @mock.patch('airflow.jobs.scheduler_job.DagBag', return_value=dagbag) @mock.patch('airflow.jobs.scheduler_job.SchedulerJob._change_state_for_tis_without_dagrun') - def do_schedule(mock_dagbag, mock_collect_dags, mock_change_state): + def do_schedule(mock_dagbag, mock_change_state): # Use a empty file since the above mock will return the # expected DAGs. Also specify only a single file so that it doesn't # try to schedule the above DAG repeatedly. @@ -2902,9 +2901,8 @@ def test_scheduler_reschedule(self): dagbag.bag_dag(dag=dag, root_dag=dag) - @mock.patch('airflow.models.DagBag', return_value=dagbag) - @mock.patch('airflow.models.DagBag.collect_dags') - def do_schedule(mock_dagbag, mock_collect_dags): + @mock.patch('airflow.jobs.scheduler_job.DagBag', return_value=dagbag) + def do_schedule(mock_dagbag): # Use a empty file since the above mock will return the # expected DAGs. Also specify only a single file so that it doesn't # try to schedule the above DAG repeatedly. 
@@ -2960,9 +2958,8 @@ def test_retry_still_in_executor(self): dagbag.bag_dag(dag=dag, root_dag=dag) - @mock.patch('airflow.models.DagBag', return_value=dagbag) - @mock.patch('airflow.models.DagBag.collect_dags') - def do_schedule(mock_dagbag, mock_collect_dags): + @mock.patch('airflow.jobs.scheduler_job.DagBag', return_value=dagbag) + def do_schedule(mock_dagbag): # Use a empty file since the above mock will return the # expected DAGs. Also specify only a single file so that it doesn't # try to schedule the above DAG repeatedly.