From 5cae2a6476180fd404955fa9e5e4b3b1e3c9772d Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Tue, 15 Sep 2020 15:16:51 +0100 Subject: [PATCH 01/70] Officially support running more than one scheduler concurrently. This PR implements scheduler HA as proposed in AIP-15. The high level design is as follows: - Move all scheduling decisions into SchedulerJob (requiring DAG serialization in the scheduler) - Use row-level locks to ensure schedulers don't stomp on each other (`SELECT ... FOR UPDATE`) - Use `SKIP LOCKED` for better performance when multiple schedulers are running. (Mysql < 8 and MariaDB don't support this) - Scheduling decisions are not tied to the parsing speed, but can operate just on the database *DagFileProcessorProcess*: Previously this component was responsible for more than just parsing the DAG files as it's name might imply. It also was responsible for creating DagRuns, and also making scheduling decisions of TIs, sending them from "None" to "scheduled" state. This commit changes it so that the DagFileProcessorProcess now will update the SerializedDAG row for this DAG, and make no scheduling decisions itself. To make the scheduler's job easier (so that it can make as many decisions as possible without having to load the possibly-large SerializedDAG row) we store/update some columns on the DagModel table: - `next_dagrun`: The execution_date of the next dag run that should be created (or None) - `next_dagrun_create_after`: The earliest point at which the next dag run can be created Pre-computing these values (and updating them every time the DAG is parsed) reduce the overall load on the DB as many decisions can be taken by selecting just these two columns/the small DagModel row. In case of max_active_runs, or `@once` these columns will be set to null, meaning "don't create any dag runs" *SchedulerJob* The SchedulerJob used to only queue/send tasks to the executor after they were parsed, and returned from the DagFileProcessorProcess. This PR breaks the link between parsing and enqueuing of tasks, instead of looking at DAGs as they are parsed, we now: - store a new datetime column, `last_scheduling_decision` on DagRun table, signifying when a scheduler last examined a DagRun - Each time around the loop the scheduler will get (and lock) the next _n_ DagRuns via `DagRun.next_dagruns_to_examine`, prioritising DagRuns which haven't been touched by a scheduler in the longest period - SimpleTaskInstance etc have been almost entirely removed now, as we use the serialized versions --- airflow/executors/base_executor.py | 33 +- airflow/jobs/scheduler_job.py | 800 ++++++------------ ..._add_scheduling_decision_to_dagrun_and_.py | 80 ++ airflow/models/dag.py | 192 ++++- airflow/models/dagbag.py | 19 +- airflow/models/dagrun.py | 60 +- airflow/models/pool.py | 18 +- airflow/models/serialized_dag.py | 19 +- airflow/models/taskinstance.py | 10 +- airflow/stats.py | 56 ++ airflow/utils/dag_processing.py | 59 +- 11 files changed, 708 insertions(+), 638 deletions(-) create mode 100644 airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index 73a002cabe93f..4140511b86e49 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -17,11 +17,12 @@ """ Base executor - this is the base class for all the implemented executors. """ +import sys from collections import OrderedDict -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple from airflow.configuration import conf -from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey +from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State @@ -39,8 +40,8 @@ # Task that is queued. It contains all the information that is # needed to run the task. # -# Tuple of: command, priority, queue name, SimpleTaskInstance -QueuedTaskInstanceType = Tuple[CommandType, int, Optional[str], Union[SimpleTaskInstance, TaskInstance]] +# Tuple of: command, priority, queue name, TaskInstance +QueuedTaskInstanceType = Tuple[CommandType, int, Optional[str], TaskInstance] # Event_buffer dict value type # Tuple of: state, info @@ -72,16 +73,16 @@ def start(self): # pragma: no cover """ def queue_command(self, - simple_task_instance: SimpleTaskInstance, + task_instance: TaskInstance, command: CommandType, priority: int = 1, queue: Optional[str] = None): """Queues command to task""" - if simple_task_instance.key not in self.queued_tasks and simple_task_instance.key not in self.running: + if task_instance.key not in self.queued_tasks and task_instance.key not in self.running: self.log.info("Adding to queue: %s", command) - self.queued_tasks[simple_task_instance.key] = (command, priority, queue, simple_task_instance) + self.queued_tasks[task_instance.key] = (command, priority, queue, task_instance) else: - self.log.error("could not queue task %s", simple_task_instance.key) + self.log.error("could not queue task %s", task_instance.key) def queue_task_instance( self, @@ -112,7 +113,7 @@ def queue_task_instance( pickle_id=pickle_id, cfg_path=cfg_path) self.queue_command( - SimpleTaskInstance(task_instance), + task_instance, command_list_to_run, priority=task_instance.task.priority_weight_total, queue=task_instance.task.queue) @@ -178,13 +179,13 @@ def trigger_tasks(self, open_slots: int) -> None: sorted_queue = self.order_queued_tasks_by_priority() for _ in range(min((open_slots, len(self.queued_tasks)))): - key, (command, _, _, simple_ti) = sorted_queue.pop(0) + key, (command, _, _, ti) = sorted_queue.pop(0) self.queued_tasks.pop(key) self.running.add(key) self.execute_async(key=key, command=command, queue=None, - executor_config=simple_ti.executor_config) + executor_config=ti.executor_config) def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None: """ @@ -282,6 +283,16 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance # Subclasses can do better! return tis + @property + def slots_available(self): + """ + Number of new tasks this executor instance can accept + """ + if self.parallelism: + return self.parallelism - len(self.running) - len(self.queued_tasks) + else: + return sys.maxsize + @staticmethod def validate_command(command: List[str]) -> None: """Check if the command to execute is airflow command""" diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index dd31f36b4aea3..cb34736fb3ed9 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -18,8 +18,10 @@ # under the License. # import datetime +import enum import logging import multiprocessing +import operator import os import signal import sys @@ -28,13 +30,13 @@ from collections import defaultdict from contextlib import ExitStack, redirect_stderr, redirect_stdout, suppress from datetime import timedelta -from itertools import groupby from multiprocessing.connection import Connection as MultiprocessingConnection -from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple +from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Union from setproctitle import setproctitle from sqlalchemy import and_, func, not_, or_ -from sqlalchemy.orm import load_only +from sqlalchemy.exc import OperationalError +from sqlalchemy.orm import load_only, selectinload from sqlalchemy.orm.session import Session, make_transient from airflow import models, settings @@ -45,21 +47,17 @@ from airflow.models import DAG, DagModel, SlaMiss, errors from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun -from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey -from airflow.operators.dummy_operator import DummyOperator -from airflow.serialization.serialized_objects import SerializedDAG +from airflow.models.taskinstance import TaskInstanceKey from airflow.stats import Stats -from airflow.ti_deps.dep_context import DepContext -from airflow.ti_deps.dependencies_deps import SCHEDULED_DEPS from airflow.ti_deps.dependencies_states import EXECUTION_STATES -from airflow.utils import helpers, timezone +from airflow.utils import timezone from airflow.utils.dag_processing import ( AbstractDagFileProcessorProcess, DagFileProcessorAgent, FailureCallbackRequest, SimpleDagBag, ) from airflow.utils.email import get_email_address_list, send_email from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context from airflow.utils.mixins import MultiprocessingStartMethodMixin -from airflow.utils.session import provide_session +from airflow.utils.session import create_session, provide_session from airflow.utils.sqlalchemy import skip_locked from airflow.utils.state import State from airflow.utils.types import DagRunType @@ -101,7 +99,7 @@ def __init__( # The process that was launched to process the given . self._process: Optional[multiprocessing.process.BaseProcess] = None # The result of Scheduler.process_file(file_path). - self._result: Optional[Tuple[List[dict], int]] = None + self._result: Optional[int] = None # Whether the process is done running. self._done = False # When the process started. @@ -178,7 +176,7 @@ def _run_file_processor( log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path) dag_file_processor = DagFileProcessor(dag_ids=dag_ids, log=log) - result: Tuple[List[dict], int] = dag_file_processor.process_file( + result: int = dag_file_processor.process_file( file_path=file_path, pickle_dags=pickle_dags, failure_callback_requests=failure_callback_requests, @@ -337,10 +335,10 @@ def done(self) -> bool: return False @property - def result(self) -> Optional[Tuple[List[dict], int]]: + def result(self) -> Optional[int]: """ :return: result of running SchedulerJob.process_file() - :rtype: Optional[Tuple[List[dict], int]] + :rtype: int or None """ if not self.done: raise AirflowException("Tried to get the result before it's done!") @@ -565,262 +563,6 @@ def update_import_errors(session: Session, dagbag: DagBag) -> None: stacktrace=stacktrace)) session.commit() - # pylint: disable=too-many-return-statements,too-many-branches - @provide_session - def create_dag_run( - self, - dag: DAG, - dag_runs: Optional[List[DagRun]] = None, - session: Session = None, - ) -> Optional[DagRun]: - """ - This method checks whether a new DagRun needs to be created - for a DAG based on scheduling interval. - Returns DagRun if one is scheduled. Otherwise returns None. - """ - # pylint: disable=too-many-nested-blocks - if not dag.schedule_interval: - return None - - active_runs: List[DagRun] - if dag_runs is None: - active_runs = DagRun.find( - dag_id=dag.dag_id, - state=State.RUNNING, - external_trigger=False, - session=session - ) - else: - active_runs = [ - dag_run - for dag_run in dag_runs - if not dag_run.external_trigger - ] - # return if already reached maximum active runs and no timeout setting - if len(active_runs) >= dag.max_active_runs and not dag.dagrun_timeout: - return None - timed_out_runs = 0 - for dr in active_runs: - if ( - dr.start_date and dag.dagrun_timeout and - dr.start_date < timezone.utcnow() - dag.dagrun_timeout - ): - dr.state = State.FAILED - dr.end_date = timezone.utcnow() - dag.handle_callback(dr, success=False, reason='dagrun_timeout', - session=session) - timed_out_runs += 1 - session.commit() - if len(active_runs) - timed_out_runs >= dag.max_active_runs: - return None - - # this query should be replaced by find dagrun - last_scheduled_run: Optional[datetime.datetime] = ( - session.query(func.max(DagRun.execution_date)) - .filter_by(dag_id=dag.dag_id) - .filter(or_( - DagRun.external_trigger == False, # noqa: E712 pylint: disable=singleton-comparison - DagRun.run_type == DagRunType.SCHEDULED.value - )).scalar() - ) - - # don't schedule @once again - if dag.schedule_interval == '@once' and last_scheduled_run: - return None - - # don't do scheduler catchup for dag's that don't have dag.catchup = True - if not (dag.catchup or dag.schedule_interval == '@once'): - # The logic is that we move start_date up until - # one period before, so that timezone.utcnow() is AFTER - # the period end, and the job can be created... - now = timezone.utcnow() - next_start = dag.following_schedule(now) - last_start = dag.previous_schedule(now) - if next_start <= now or isinstance(dag.schedule_interval, timedelta): - new_start = last_start - else: - new_start = dag.previous_schedule(last_start) - - if dag.start_date: - if new_start >= dag.start_date: - dag.start_date = new_start - else: - dag.start_date = new_start - - next_run_date = None - if not last_scheduled_run: - # First run - task_start_dates = [t.start_date for t in dag.tasks] - if task_start_dates: - next_run_date = dag.normalize_schedule(min(task_start_dates)) - self.log.debug( - "Next run date based on tasks %s", - next_run_date - ) - else: - next_run_date = dag.following_schedule(last_scheduled_run) - - # make sure backfills are also considered - last_run = dag.get_last_dagrun(session=session) - if last_run and next_run_date: - while next_run_date <= last_run.execution_date: - next_run_date = dag.following_schedule(next_run_date) - - # don't ever schedule prior to the dag's start_date - if dag.start_date: - next_run_date = (dag.start_date if not next_run_date - else max(next_run_date, dag.start_date)) - if next_run_date == dag.start_date: - next_run_date = dag.normalize_schedule(dag.start_date) - - self.log.debug( - "Dag start date: %s. Next run date: %s", - dag.start_date, next_run_date - ) - - # don't ever schedule in the future or if next_run_date is None - if not next_run_date or next_run_date > timezone.utcnow(): - return None - - # this structure is necessary to avoid a TypeError from concatenating - # NoneType - period_end = None - if dag.schedule_interval == '@once': - period_end = next_run_date - elif next_run_date: - period_end = dag.following_schedule(next_run_date) - - # Don't schedule a dag beyond its end_date (as specified by the dag param) - if next_run_date and dag.end_date and next_run_date > dag.end_date: - return None - - # Don't schedule a dag beyond its end_date (as specified by the task params) - # Get the min task end date, which may come from the dag.default_args - min_task_end_date = min([t.end_date for t in dag.tasks if t.end_date], default=None) - if next_run_date and min_task_end_date and next_run_date > min_task_end_date: - return None - - if next_run_date and period_end and period_end <= timezone.utcnow(): - next_run = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=next_run_date, - start_date=timezone.utcnow(), - state=State.RUNNING, - external_trigger=False - ) - return next_run - - return None - - @provide_session - def _process_task_instances( - self, dag: DAG, dag_runs: List[DagRun], session: Session = None - ) -> List[TaskInstanceKey]: - """ - This method schedules the tasks for a single DAG by looking at the - active DAG runs and adding task instances that should run to the - queue. - """ - # update the state of the previously active dag runs - active_dag_runs = 0 - task_instances_list = [] - for run in dag_runs: - self.log.info("Examining DAG run %s", run) - # don't consider runs that are executed in the future unless - # specified by config and schedule_interval is None - if run.execution_date > timezone.utcnow() and not dag.allow_future_exec_dates: - self.log.error( - "Execution date is in future: %s", - run.execution_date - ) - continue - - if active_dag_runs >= dag.max_active_runs: - self.log.info("Number of active dag runs reached max_active_run.") - break - - # skip backfill dagruns for now as long as they are not really scheduled - if run.is_backfill: - continue - - # todo: run.dag is transient but needs to be set - run.dag = dag # type: ignore - # todo: preferably the integrity check happens at dag collection time - run.verify_integrity(session=session) - ready_tis = run.update_state(session=session) - if run.state == State.RUNNING: - active_dag_runs += 1 - self.log.debug("Examining active DAG run: %s", run) - for ti in ready_tis: - self.log.debug('Queuing task: %s', ti) - task_instances_list.append(ti.key) - return task_instances_list - - @provide_session - def _process_dags(self, dags: List[DAG], session: Session = None) -> List[TaskInstanceKey]: - """ - Iterates over the dags and processes them. Processing includes: - - 1. Create appropriate DagRun(s) in the DB. - 2. Create appropriate TaskInstance(s) in the DB. - 3. Send emails for tasks that have missed SLAs (if CHECK_SLAS config enabled). - - :param dags: the DAGs from the DagBag to process - :type dags: List[airflow.models.DAG] - :rtype: list[TaskInstance] - :return: A list of generated TaskInstance objects - """ - check_slas: bool = conf.getboolean('core', 'CHECK_SLAS', fallback=True) - use_job_schedule: bool = conf.getboolean('scheduler', 'USE_JOB_SCHEDULE') - - # pylint: disable=too-many-nested-blocks - tis_out: List[TaskInstanceKey] = [] - dag_ids: List[str] = [dag.dag_id for dag in dags] - dag_runs = DagRun.find(dag_id=dag_ids, state=State.RUNNING, session=session) - # As per the docs of groupby (https://docs.python.org/3/library/itertools.html#itertools.groupby) - # we need to use `list()` otherwise the result will be wrong/incomplete - dag_runs_by_dag_id: Dict[str, List[DagRun]] = { - k: list(v) for k, v in groupby(dag_runs, lambda d: d.dag_id) - } - - for dag in dags: - dag_id: str = dag.dag_id - self.log.info("Processing %s", dag_id) - dag_runs_for_dag = dag_runs_by_dag_id.get(dag_id) or [] - - # Only creates DagRun for DAGs that are not subdag since - # DagRun of subdags are created when SubDagOperator executes. - if not dag.is_subdag and use_job_schedule: - dag_run = self.create_dag_run(dag, dag_runs=dag_runs_for_dag) - if dag_run: - dag_runs_for_dag.append(dag_run) - expected_start_date = dag.following_schedule(dag_run.execution_date) - if expected_start_date: - schedule_delay = dag_run.start_date - expected_start_date - Stats.timing( - 'dagrun.schedule_delay.{dag_id}'.format(dag_id=dag.dag_id), - schedule_delay) - self.log.info("Created %s", dag_run) - - if dag_runs_for_dag: - tis_out.extend(self._process_task_instances(dag, dag_runs_for_dag)) - if check_slas: - self.manage_slas(dag) - - return tis_out - - def _find_dags_to_process(self, dags: List[DAG]) -> List[DAG]: - """ - Find the DAGs that are not paused to process. - - :param dags: specified DAGs - :return: DAGs to process - """ - if self.dag_ids: - dags = [dag for dag in dags - if dag.dag_id in self.dag_ids] - return dags - @provide_session def execute_on_failure_callbacks( self, @@ -860,7 +602,7 @@ def process_file( failure_callback_requests: List[FailureCallbackRequest], pickle_dags: bool = False, session: Session = None - ) -> Tuple[List[dict], int]: + ) -> int: """ Process a Python file containing Airflow DAGs. @@ -886,9 +628,8 @@ def process_file( :type pickle_dags: bool :param session: Sqlalchemy ORM Session :type session: Session - :return: a tuple with list of SimpleDags made from the Dags found in the file and - count of import errors. - :rtype: Tuple[List[dict], int] + :return: count of import errors + :rtype: int """ self.log.info("Processing file %s for tasks to queue", file_path) @@ -897,36 +638,33 @@ def process_file( except Exception: # pylint: disable=broad-except self.log.exception("Failed at reloading the DAG file %s", file_path) Stats.incr('dag_file_refresh_error', 1, 1) - return [], 0 + return 0 if len(dagbag.dags) > 0: self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(), file_path) else: self.log.warning("No viable dags retrieved from %s", file_path) self.update_import_errors(session, dagbag) - return [], len(dagbag.import_errors) + return len(dagbag.import_errors) try: self.execute_on_failure_callbacks(dagbag, failure_callback_requests) except Exception: # pylint: disable=broad-except self.log.exception("Error executing failure callback!") - # Save individual DAGs in the ORM and update DagModel.last_scheduled_time + # Save individual DAGs in the ORM + dagbag.read_dags_from_db = True dagbag.sync_to_db() - paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids) - - unpaused_dags: List[DAG] = [ - dag for dag_id, dag in dagbag.dags.items() if dag_id not in paused_dag_ids - ] + if pickle_dags: + paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids) - serialized_dags = self._prepare_serialized_dags(unpaused_dags, pickle_dags, session) - - dags = self._find_dags_to_process(unpaused_dags) - - ti_keys_to_schedule = self._process_dags(dags, session) + unpaused_dags: List[DAG] = [ + dag for dag_id, dag in dagbag.dags.items() if dag_id not in paused_dag_ids + ] - self._schedule_task_instances(dagbag, ti_keys_to_schedule, session) + for dag in unpaused_dags: + dag.pickle(session) # Record import errors into the ORM try: @@ -934,85 +672,7 @@ def process_file( except Exception: # pylint: disable=broad-except self.log.exception("Error logging import errors!") - return serialized_dags, len(dagbag.import_errors) - - @provide_session - def _schedule_task_instances( - self, - dagbag: DagBag, - ti_keys_to_schedule: List[TaskInstanceKey], - session: Session = None - ) -> None: - """ - Checks whether the tasks specified by `ti_keys_to_schedule` parameter can be scheduled and - updates the information in the database, - - :param dagbag: DagBag - :type dagbag: DagBag - :param ti_keys_to_schedule: List of task instance keys which can be scheduled. - :type ti_keys_to_schedule: list - """ - # Refresh all task instances that will be scheduled - filter_for_tis = TI.filter_for_tis(ti_keys_to_schedule) - - refreshed_tis: List[TI] = [] - - if filter_for_tis is not None: - refreshed_tis = session.query(TI).filter(filter_for_tis).with_for_update().all() - - for ti in refreshed_tis: - # Add task to task instance - dag: DAG = dagbag.dags[ti.dag_id] - ti.task = dag.get_task(ti.task_id) - - # We check only deps needed to set TI to SCHEDULED state here. - # Deps needed to set TI to QUEUED state will be batch checked later - # by the scheduler for better performance. - dep_context = DepContext(deps=SCHEDULED_DEPS, ignore_task_deps=True) - - # Only schedule tasks that have their dependencies met, e.g. to avoid - # a task that recently got its state changed to RUNNING from somewhere - # other than the scheduler from getting its state overwritten. - if ti.are_dependencies_met( - dep_context=dep_context, - session=session, - verbose=True - ): - # Task starts out in the scheduled state. All tasks in the - # scheduled state will be sent to the executor - ti.state = State.SCHEDULED - # If the task is dummy, then mark it as done automatically - if isinstance(ti.task, DummyOperator) \ - and not ti.task.on_execute_callback \ - and not ti.task.on_success_callback: - ti.state = State.SUCCESS - ti.start_date = ti.end_date = timezone.utcnow() - ti.duration = 0 - - # Also save this task instance to the DB. - self.log.info("Creating / updating %s in ORM", ti) - session.merge(ti) - # commit batch - session.commit() - - @provide_session - def _prepare_serialized_dags( - self, dags: List[DAG], pickle_dags: bool, session: Session = None - ) -> List[dict]: - """ - Convert DAGS to SimpleDags. If necessary, it also Pickle the DAGs - - :param dags: List of DAGs - :return: List of SimpleDag - :rtype: List[dict] - """ - serialized_dags: List[dict] = [] - # Pickle the DAGs (if necessary) and put them into a SimpleDagBag - for dag in dags: - if pickle_dags: - dag.pickle(session) - serialized_dags.append(SerializedDAG.to_dict(dag)) - return serialized_dags + return len(dagbag.import_errors) class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes @@ -1046,6 +706,12 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes } heartrate: int = conf.getint('scheduler', 'SCHEDULER_HEARTBEAT_SEC') + # Singleton object pattern, PEP-484 style + class _NoLockObtained(enum.Enum): + token = 0 + + NO_LOCK_OBTAINED = _NoLockObtained.token + def __init__( self, dag_id: Optional[str] = None, @@ -1221,39 +887,54 @@ def __get_concurrency_maps( # pylint: disable=too-many-locals,too-many-statements @provide_session - def _find_executable_task_instances( + def _executable_task_instances_to_queued( self, - simple_dag_bag: SimpleDagBag, + max_tis: int, + dag_bag: DagBag, session: Session = None ) -> List[TI]: """ Finds TIs that are ready for execution with respect to pool limits, dag concurrency, executor state, and priority. - :param simple_dag_bag: TaskInstances associated with DAGs in the - simple_dag_bag will be fetched from the DB and executed - :type simple_dag_bag: airflow.utils.dag_processing.SimpleDagBag + :param max_tis: Maximum number of TIs to queue in this loop. + :type max_tis: int + :param dag_bag: TaskInstances associated with DAGs in the + _dag_bag will be fetched from the DB and executed + :type dag_bag: airflow.models.DagBag :return: list[airflow.models.TaskInstance] """ executable_tis: List[TI] = [] + # Get the pool settings. We get a lock on the pool rows, treating this as a "critical section" + # Throws an exception if lock cannot be obtained, rather than blocking + pools = models.Pool.slots_stats(with_for_update={'nowait': True}, session=session) + + # If the pools are full, there is no point doing anything! + max_tis = min(max_tis, sum(map(operator.itemgetter('open'), pools.values()))) + + if max_tis == 0: + self.log.debug("All pools are full!") + return executable_tis + # Get all task instances associated with scheduled # DagRuns which are not backfilled, in the given states, # and the dag is not paused task_instances_to_examine: List[TI] = ( session .query(TI) - .filter(TI.dag_id.in_(simple_dag_bag.dag_ids)) - .outerjoin( - DR, and_(DR.dag_id == TI.dag_id, DR.execution_date == TI.execution_date) - ) - .filter(or_(DR.run_id.is_(None), DR.run_type != DagRunType.BACKFILL_JOB.value)) - .outerjoin(DM, DM.dag_id == TI.dag_id) - .filter(or_(DM.dag_id.is_(None), not_(DM.is_paused))) + .join(TI.dag_run) + .filter(DR.run_type != DagRunType.BACKFILL_JOB.value) + .join(TI.dag_model) + .filter(not_(DM.is_paused)) .filter(TI.state == State.SCHEDULED) + .options(selectinload('dag_model')) + .limit(max_tis) + .with_for_update(**skip_locked(of=TI, session=session)) .all() ) - Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine)) + # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything. + # Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine)) if len(task_instances_to_examine) == 0: self.log.debug("No tasks to consider for execution.") @@ -1267,9 +948,6 @@ def _find_executable_task_instances( task_instance_str ) - # Get the pool settings - pools: Dict[str, models.Pool] = {p.pool: p for p in session.query(models.Pool).all()} - pool_to_task_instances: DefaultDict[str, List[models.Pool]] = defaultdict(list) for task_instance in task_instances_to_examine: pool_to_task_instances[task_instance.pool].append(task_instance) @@ -1296,7 +974,7 @@ def _find_executable_task_instances( ) continue - open_slots = pools[pool].open_slots(session=session) + open_slots = pools[pool]["open"] num_ready = len(task_instances) self.log.info( @@ -1324,10 +1002,9 @@ def _find_executable_task_instances( # Check to make sure that the task concurrency of the DAG hasn't been # reached. dag_id = task_instance.dag_id - serialized_dag = simple_dag_bag.get_dag(dag_id) current_dag_concurrency = dag_concurrency_map[dag_id] - dag_concurrency_limit = simple_dag_bag.get_dag(dag_id).concurrency + dag_concurrency_limit = task_instance.dag_model.concurrency self.log.info( "DAG %s has %s/%s running and queued tasks", dag_id, current_dag_concurrency, dag_concurrency_limit @@ -1341,27 +1018,23 @@ def _find_executable_task_instances( continue task_concurrency_limit: Optional[int] = None - if serialized_dag.has_task(task_instance.task_id): - task_concurrency_limit = serialized_dag.get_task( - task_instance.task_id).task_concurrency - - if task_concurrency_limit is not None: - current_task_concurrency = task_concurrency_map[ - (task_instance.dag_id, task_instance.task_id) - ] - - if current_task_concurrency >= task_concurrency_limit: - self.log.info("Not executing %s since the task concurrency for" - " this task has been reached.", task_instance) - continue - - if self.executor.has_task(task_instance): - self.log.debug( - "Not handling task %s as the executor reports it is running", - task_instance.key - ) - num_tasks_in_executor += 1 - continue + if task_instance.dag_model.has_task_concurrency_limits: + # Many dags don't have a task_concurrency, so where we can avoid loading the full + # serialized DAG the better. + serialized_dag = dag_bag.get_dag(dag_id) + if serialized_dag.has_task(task_instance.task_id): + task_concurrency_limit = serialized_dag.get_task( + task_instance.task_id).task_concurrency + + if task_concurrency_limit is not None: + current_task_concurrency = task_concurrency_map[ + (task_instance.dag_id, task_instance.task_id) + ] + + if current_task_concurrency >= task_concurrency_limit: + self.log.info("Not executing %s since the task concurrency for" + " this task has been reached.", task_instance) + continue if task_instance.pool_slots > open_slots: self.log.info("Not executing %s since it requires %s slots " @@ -1387,116 +1060,63 @@ def _find_executable_task_instances( [repr(x) for x in executable_tis]) self.log.info( "Setting the following tasks to queued state:\n\t%s", task_instance_str) - # so these dont expire on commit - for ti in executable_tis: - copy_dag_id = ti.dag_id - copy_execution_date = ti.execution_date - copy_task_id = ti.task_id - make_transient(ti) - ti.dag_id = copy_dag_id - ti.execution_date = copy_execution_date - ti.task_id = copy_task_id - return executable_tis - - @provide_session - def _change_state_for_executable_task_instances( - self, task_instances: List[TI], session: Session = None - ) -> List[SimpleTaskInstance]: - """ - Changes the state of task instances in the list with one of the given states - to QUEUED atomically, and returns the TIs changed in SimpleTaskInstance format. - - :param task_instances: TaskInstances to change the state of - :type task_instances: list[airflow.models.TaskInstance] - :rtype: list[airflow.models.taskinstance.SimpleTaskInstance] - """ - if len(task_instances) == 0: - session.commit() - return [] - - tis_to_set_to_queued: List[TI] = ( - session - .query(TI) - .filter(TI.filter_for_tis(task_instances)) - .filter(TI.state == State.SCHEDULED) - .with_for_update() - .all() - ) - - if len(tis_to_set_to_queued) == 0: - self.log.info("No tasks were able to have their state changed to queued.") - session.commit() - return [] # set TIs to queued state - filter_for_tis = TI.filter_for_tis(tis_to_set_to_queued) + filter_for_tis = TI.filter_for_tis(executable_tis) session.query(TI).filter(filter_for_tis).update( + # TODO[ha]: should we use func.now()? How does that work with DB timezone on mysql when it's not + # UTC? {TI.state: State.QUEUED, TI.queued_dttm: timezone.utcnow(), TI.queued_by_job_id: self.id}, synchronize_session=False ) - session.commit() - # Generate a list of SimpleTaskInstance for the use of queuing - # them in the executor. - simple_task_instances = [SimpleTaskInstance(ti) for ti in tis_to_set_to_queued] - - task_instance_str = "\n\t".join([repr(x) for x in tis_to_set_to_queued]) - self.log.info("Setting the following %s tasks to queued state:\n\t%s", - len(tis_to_set_to_queued), task_instance_str) - return simple_task_instances + for ti in executable_tis: + make_transient(ti) + return executable_tis def _enqueue_task_instances_with_queued_state( self, - simple_dag_bag: SimpleDagBag, - simple_task_instances: List[SimpleTaskInstance] + task_instances: List[TI] ) -> None: """ Takes task_instances, which should have been set to queued, and enqueues them with the executor. - :param simple_task_instances: TaskInstances to enqueue - :type simple_task_instances: list[SimpleTaskInstance] - :param simple_dag_bag: Should contains all of the task_instances' dags - :type simple_dag_bag: airflow.utils.dag_processing.SimpleDagBag + :param task_instances: TaskInstances to enqueue + :type task_instances: list[TaskInstance] """ # actually enqueue them - for simple_task_instance in simple_task_instances: - serialized_dag = simple_dag_bag.get_dag(simple_task_instance.dag_id) + for ti in task_instances: command = TI.generate_command( - simple_task_instance.dag_id, - simple_task_instance.task_id, - simple_task_instance.execution_date, + ti.dag_id, + ti.task_id, + ti.execution_date, local=True, mark_success=False, ignore_all_deps=False, ignore_depends_on_past=False, ignore_task_deps=False, ignore_ti_state=False, - pool=simple_task_instance.pool, - file_path=serialized_dag.full_filepath, - pickle_id=serialized_dag.pickle_id, + pool=ti.pool, + file_path=ti.dag_model.fileloc, + pickle_id=ti.dag_model.pickle_id, ) - priority = simple_task_instance.priority_weight - queue = simple_task_instance.queue + priority = ti.priority_weight + queue = ti.queue self.log.info( "Sending %s to executor with priority %s and queue %s", - simple_task_instance.key, priority, queue + ti.key, priority, queue ) self.executor.queue_command( - simple_task_instance, + ti, command, priority=priority, queue=queue, ) - @provide_session - def _execute_task_instances( - self, - simple_dag_bag: SimpleDagBag, - session: Session = None - ) -> int: + def _execute_task_instances(self, dag_bag: DagBag, session: Session) -> int: """ Attempts to execute TaskInstances that should be executed by the scheduler. @@ -1506,23 +1126,18 @@ def _execute_task_instances( 2. Change the state for the TIs above atomically. 3. Enqueue the TIs in the executor. - :param simple_dag_bag: TaskInstances associated with DAGs in the - simple_dag_bag will be fetched from the DB and executed - :type simple_dag_bag: airflow.utils.dag_processing.SimpleDagBag + :param dag_bag: TaskInstances associated with DAGs in the + dag_bag will be fetched from the DB and executed + :type dag_bag: airflow.models.DagBag + :param session: + :type session: sqlalchemy.orm.Session :return: Number of task instance with state changed. """ - executable_tis = self._find_executable_task_instances(simple_dag_bag, session=session) - - def query(result: int, items: List[TI]) -> int: - simple_tis_with_state_changed = \ - self._change_state_for_executable_task_instances(items, session=session) - self._enqueue_task_instances_with_queued_state( - simple_dag_bag, - simple_tis_with_state_changed) - session.commit() - return result + len(simple_tis_with_state_changed) + max_tis = min(self.max_tis_per_query, self.executor.slots_available) + queued_tis = self._executable_task_instances_to_queued(max_tis, dag_bag, session=session) - return helpers.reduce_in_chunks(query, executable_tis, 0, self.max_tis_per_query) + self._enqueue_task_instances_with_queued_state(queued_tis) + return len(queued_tis) @provide_session def _change_state_for_tasks_failed_to_execute(self, session: Session = None): @@ -1564,14 +1179,14 @@ def _change_state_for_tasks_failed_to_execute(self, session: Session = None): self.log.info("Set the following tasks to scheduled state:\n\t%s", task_instance_str) @provide_session - def _process_executor_events(self, simple_dag_bag: SimpleDagBag, session: Session = None) -> None: + def _process_executor_events(self, session: Session = None) -> int: """ Respond to executor events. """ if not self.processor_agent: raise ValueError("Processor agent is not started.") ti_primary_key_to_try_number_map: Dict[Tuple[str, str, datetime.datetime], int] = {} - event_buffer = self.executor.get_event_buffer(simple_dag_bag.dag_ids) + event_buffer = self.executor.get_event_buffer() tis_with_right_state: List[TaskInstanceKey] = [] # Report execution @@ -1591,11 +1206,11 @@ def _process_executor_events(self, simple_dag_bag: SimpleDagBag, session: Sessio # Return if no finished tasks if not tis_with_right_state: - return + return len(event_buffer) # Check state of finished tasks filter_for_tis = TI.filter_for_tis(tis_with_right_state) - tis: List[TI] = session.query(TI).filter(filter_for_tis).all() + tis: List[TI] = session.query(TI).filter(filter_for_tis).options(selectinload('dag_model')).all() for ti in tis: try_number = ti_primary_key_to_try_number_map[ti.key.primary] buffer_key = ti.key.with_try_number(try_number) @@ -1612,13 +1227,14 @@ def _process_executor_events(self, simple_dag_bag: SimpleDagBag, session: Sessio msg = "Executor reports task instance %s finished (%s) although the " \ "task says its %s. (Info: %s) Was the task killed externally?" self.log.error(msg, ti, state, ti.state, info) - serialized_dag = simple_dag_bag.get_dag(ti.dag_id) self.processor_agent.send_callback_to_execute( - full_filepath=serialized_dag.full_filepath, + full_filepath=ti.dag_model.full_filepath, task_instance=ti, msg=msg % (ti, state, ti.state, info), ) + return len(event_buffer) + def _execute(self) -> None: self.log.info("Starting the scheduler") @@ -1719,6 +1335,8 @@ def _run_scheduler_loop(self) -> None: raise ValueError("Processor agent is not started.") is_unit_test: bool = conf.getboolean('core', 'unit_test_mode') + dag_bag = DagBag() + # For the execute duration, parse and schedule DAGs while True: loop_start_time = time.time() @@ -1730,15 +1348,24 @@ def _run_scheduler_loop(self) -> None: self.log.debug("Waiting for processors to finish since we're using sqlite") self.processor_agent.wait_until_finished() - serialized_dags = self.processor_agent.harvest_serialized_dags() + with create_session() as session: + timer = Stats.timer('scheduler.critical_section_duration') + timer.start() + num_queued_tis = self._scheduler_loop_critical_section(dag_bag, session) - self.log.debug("Harvested %d SimpleDAGs", len(serialized_dags)) + if num_queued_tis is self.NO_LOCK_OBTAINED: + Stats.incr('scheduler.critical_section_lock_busy') + num_queued_tis = 0 + else: + # Make sure we only sent this metric if we obtained the lock, otherwise we'll skew the + # metric, way down + timer.stop(send=True) - # Send tasks for execution if available - simple_dag_bag = SimpleDagBag(serialized_dags) + self.executor.heartbeat() + session.expunge_all() + num_finished_events = self._process_executor_events(session=session) - if not self._validate_and_run_task_instances(simple_dag_bag=simple_dag_bag): - continue + self.processor_agent.heartbeat() # Heartbeat the scheduler periodically self.heartbeat(only_if_necessary=True) @@ -1749,7 +1376,10 @@ def _run_scheduler_loop(self) -> None: loop_duration = loop_end_time - loop_start_time self.log.debug("Ran scheduling loop in %.2f seconds", loop_duration) - if not is_unit_test: + if not is_unit_test and not num_queued_tis and not num_finished_events: + # If the scheduler is doing things, don't sleep. This means when there is work to do, the + # scheduler will run "as quick as possible", but when it's stopped, it can sleep, dropping CPU + # usage when "idle" time.sleep(self._processor_poll_interval) if self.processor_agent.done: @@ -1758,53 +1388,119 @@ def _run_scheduler_loop(self) -> None: ) break - def _validate_and_run_task_instances(self, simple_dag_bag: SimpleDagBag) -> bool: - if simple_dag_bag.serialized_dags: - try: - self._process_and_execute_tasks(simple_dag_bag) - except Exception as e: # pylint: disable=broad-except - self.log.error("Error queuing tasks") - self.log.exception(e) - return False - - # Call heartbeats - self.log.debug("Heartbeating the executor") - self.executor.heartbeat() - - self._change_state_for_tasks_failed_to_execute() - - # Process events from the executor - self._process_executor_events(simple_dag_bag) - return True - - def _process_and_execute_tasks(self, simple_dag_bag: SimpleDagBag) -> None: - # Handle cases where a DAG run state is set (perhaps manually) to - # a non-running state. Handle task instances that belong to - # DAG runs in those states - # If a task instance is up for retry but the corresponding DAG run - # isn't running, mark the task instance as FAILED so we don't try - # to re-run it. - self._change_state_for_tis_without_dagrun( - simple_dag_bag=simple_dag_bag, - old_states=[State.UP_FOR_RETRY], - new_state=State.FAILED - ) - # If a task instance is scheduled or queued or up for reschedule, - # but the corresponding DAG run isn't running, set the state to - # NONE so we don't try to re-run it. - self._change_state_for_tis_without_dagrun( - simple_dag_bag=simple_dag_bag, - old_states=[State.QUEUED, - State.SCHEDULED, - State.UP_FOR_RESCHEDULE, - State.SENSING], - new_state=State.NONE - ) - self._execute_task_instances(simple_dag_bag) + def _scheduler_loop_critical_section(self, dag_bag, session) -> Union[int, _NoLockObtained]: + """ + :return: Number of TIs enqueued in this iteration + :rtype: int + """ + try: + from sqlalchemy import event + expected_commit = False + + @event.listens_for(session.bind, 'commit') + def validate_commit(_): + nonlocal expected_commit + if expected_commit: + expected_commit = False + return + raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!") + + # Put a check in place to make sure we don't commit unexpectedly + + query = DagModel.dags_needing_dagruns(session) + for dag_model in query: + dag = dag_bag.get_dag(dag_model.dag_id, session=session) + next_run_date = dag_model.next_dagrun + dag.create_dagrun( + run_type=DagRunType.SCHEDULED.value, + execution_date=next_run_date, + start_date=timezone.utcnow(), + state=State.RUNNING, + external_trigger=False, + session=session + ) + + # Check max_active_runs, to see if we are now at the limit for this dag? + active_runs_of_dag = dict(session.query(DagRun.dag_id, func.count('*')).filter( + DagRun.dag_id == dag_model.dag_id, + DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable + DagRun.external_trigger.is_(False) + ).scalar()) + + # TODO[HA]: add back in dagrun.timeout + + if dag.max_active_runs and dag.max_active_runs >= active_runs_of_dag: + self.log.info( + "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs", + dag.dag_id, active_runs_of_dag, dag.max_active_runs + ) + dag_model.next_dagrun = None + dag_model.next_dagrun_create_after = None + else: + next_dagrun_info = dag.next_dagrun_info(next_run_date) + if next_dagrun_info: + dag_model.next_dagrun = next_dagrun_info['execution_date'] + dag_model.next_dagrun_create_after = next_dagrun_info['can_be_created_after'] + else: + dag_model.next_dagrun = None + dag_model.next_dagrun_create_after = None + + # TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in + # memory for larger dags? or expunge_all() + + # commit the session - Release the write lock on DagModel table. + expected_commit = True + session.commit() + # END: create dagruns + + # Tunable limit?, or select multiple dag runs for a single dag? + for dag_run in DagRun.next_dagruns_to_examine(session): + dag_run.dag = dag_bag.get_dag(dag_run.dag_id, session=session) + + # TODO[HA]: Run verify_integrity, but only if the serialized_dag has changed + + # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else? + schedulable_tis = dag_run.update_state(session=session) + # TODO[HA]: Don't return, update these form in update_state + session.query(TI).filter( + TI.dag_id == dag_run.dag_id, + TI.execution_date == dag_run.execution_date, + TI.task_id.in_(ti.task_id for ti in schedulable_tis) + ).update({TI.state: State.SCHEDULED}, synchronize_session=False) + + # TODO[HA]: Manage SLAs + + expected_commit = True + session.commit() + + # Without this, the session has an invalid view of the DB + session.expunge_all() + # END: schedule TIs + + # Find anything TIs in state SCHEDULED, try to QUEUE it (send it to the executor) + + # TODO[HA]: Do we need to call + # _change_state_for_tis_without_dagrun (2x) that we were before + # to tidy up manually tweaked TIs. Do we need to do it every + # time? + + return self._execute_task_instances(dag_bag, session=session) + + # End of loop, allowed/expected to commit + except OperationalError as e: + # Postgres: lock not available + if getattr(e.orig, 'pgcode') == '55P03': + # We could test if e.orig is an instance of psycopg2.errors.LockNotAvailable, but that + # involves importing it. This doesn't + self.log.debug("Critical section lock held by another Scheduler") + return self.NO_LOCK_OBTAINED + raise + finally: + event.remove(session.bind, 'commit', validate_commit) @provide_session def _emit_pool_metrics(self, session: Session = None) -> None: - pools = models.Pool.slots_stats(session) + pools = models.Pool.slots_stats(session=session) for pool_name, slot_stats in pools.items(): Stats.gauge(f'pool.open_slots.{pool_name}', slot_stats["open"]) Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats[State.QUEUED]) diff --git a/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py b/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py new file mode 100644 index 0000000000000..30a4c3a2d91f2 --- /dev/null +++ b/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Add scheduling_decision to DagRun and DAG + +Revision ID: 98271e7606e2 +Revises: e1a11ece99cc +Create Date: 2020-09-15 12:13:32.968148 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '98271e7606e2' +down_revision = 'e1a11ece99cc' +branch_labels = None +depends_on = None + + +def upgrade(): + """Apply Add scheduling_decision to DagRun and DAG""" + with op.batch_alter_table('dag_run', schema=None) as batch_op: + batch_op.add_column(sa.Column('last_scheduling_decision', sa.DateTime(timezone=True), nullable=True)) + batch_op.create_index('idx_last_scheduling_decision', ['last_scheduling_decision'], unique=False) + + with op.batch_alter_table('dag', schema=None) as batch_op: + batch_op.add_column(sa.Column('next_dagrun', sa.DateTime(timezone=True), nullable=True)) + batch_op.add_column(sa.Column('next_dagrun_create_after', sa.DateTime(timezone=True), nullable=True)) + # Create with nullable and no default, then ALTER to set values, to aviod table level lock + batch_op.add_column(sa.Column('concurrency', sa.Integer(), nullable=True)) + batch_op.add_column(sa.Column('has_task_concurrency_limits', sa.Boolean(), nullable=True)) + + batch_op.create_index('idx_next_dagrun_create_after', ['next_dagrun_create_after'], unique=False) + + try: + from airflow.configuration import conf + concurrency = conf.getint('core', 'dag_concurrency', fallback=16) + except: # noqa + concurrency = 16 + + # Set it to true here as it makes us take the slow/more complete path, and when it's next parsed by the + # DagParser it will get set to correct value. + op.execute( + "UPDATE dag SET concurrency={}, has_task_concurrency_limits=true where concurrency IS NULL".format( + concurrency + ) + ) + op.alter_column('dag', 'concurrency', type_=sa.Integer(), nullable=False) + op.alter_column('dag', 'has_task_concurrency_limits', type_=sa.Boolean(), nullable=False) + + +def downgrade(): + """Unapply Add scheduling_decision to DagRun and DAG""" + with op.batch_alter_table('dag_run', schema=None) as batch_op: + batch_op.drop_index('idx_last_scheduling_decision') + batch_op.drop_column('last_scheduling_decision') + + with op.batch_alter_table('dag', schema=None) as batch_op: + batch_op.drop_index('idx_next_dagrun_create_after') + batch_op.drop_column('next_dagrun_create_after') + batch_op.drop_column('next_dagrun') + batch_op.drop_column('concurrency') + batch_op.drop_column('has_task_concurrency_limits') diff --git a/airflow/models/dag.py b/airflow/models/dag.py index b3465476e81bf..283e8154e9f04 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -467,6 +467,99 @@ def previous_schedule(self, dttm): elif self.normalized_schedule_interval is not None: return timezone.convert_to_utc(dttm - self.normalized_schedule_interval) + def next_dagrun_info(self, date_last_automated_dagrun : Optional[pendulum.DateTime]): + """ + Get information about the next DagRun of this dag after ``date_last_automated_dagrun`` -- the + execution date, and the earliest it could be scheduled + + :param date_last_automated_dagrun: The max(execution_date) of existing + "automated" DagRuns for this dag (scheduled or backfill, but not + manual) + """ + next_execution_date = self.next_dagrun_after_date(date_last_automated_dagrun) + + if next_execution_date is None or self.schedule_interval in (None, '@once'): + return None + + return { + 'execution_date': next_execution_date, + 'can_be_created_after': self.following_schedule(next_execution_date) + } + + def next_dagrun_after_date(self, date_last_automated_dagrun: Optional[pendulum.DateTime]): + """ + Get the next execution date after the given ``date_last_automated_dagrun``, acording to + schedule_interval, start_date, end_date etc. This doesn't check max active run or any other + "concurrency" type limits, it only perofmrs calculations based on the varios date and interval fields + of this dag and it's tasks. + + :param date_last_automated_dagrun: The execution_date of the last scheduler or + backfill triggered run for this dag + :type date_last_automated_dagrun: pendulum.Pendulum + """ + if not self.schedule_interval: + return None + + # don't schedule @once again + if self.schedule_interval == '@once' and date_last_automated_dagrun: + return None + + # don't do scheduler catchup for dag's that don't have dag.catchup = True + if not (self.catchup or self.schedule_interval == '@once'): + # The logic is that we move start_date up until + # one period before, so that timezone.utcnow() is AFTER + # the period end, and the job can be created... + now = timezone.utcnow() + next_start = self.following_schedule(now) + last_start = self.previous_schedule(now) + if next_start <= now or isinstance(self.schedule_interval, timedelta): + new_start = last_start + else: + new_start = self.previous_schedule(last_start) + + if self.start_date: + if new_start >= self.start_date: + self.start_date = new_start + else: + self.start_date = new_start + + next_run_date = None + if not date_last_automated_dagrun: + # First run + task_start_dates = [t.start_date for t in self.tasks] + if task_start_dates: + next_run_date = self.normalize_schedule(min(task_start_dates)) + self.log.debug("Next run date based on tasks %s", next_run_date) + else: + next_run_date = self.following_schedule(date_last_automated_dagrun) + + if date_last_automated_dagrun and next_run_date: + while next_run_date <= date_last_automated_dagrun: + next_run_date = self.following_schedule(next_run_date) + + # don't ever schedule prior to the dag's start_date + if self.start_date: + next_run_date = self.start_date if not next_run_date else max(next_run_date, self.start_date) + if next_run_date == self.start_date: + next_run_date = self.normalize_schedule(self.start_date) + + self.log.debug( + "Dag start date: %s. Next run date: %s", + self.start_date, next_run_date + ) + + # Don't schedule a dag beyond its end_date (as specified by the dag param) + if next_run_date and self.end_date and next_run_date > self.end_date: + return None + + # Don't schedule a dag beyond its end_date (as specified by the task params) + # Get the min task end date, which may come from the dag.default_args + task_end_dates = [t.end_date for t in self.tasks if t.end_date] + if task_end_dates and next_run_date: + min_task_end_date = min(task_end_dates) + if next_run_date > min_task_end_date: + return None + def get_run_dates(self, start_date, end_date=None): """ Returns a list of dates between the interval received as parameter using this @@ -1561,8 +1654,6 @@ def create_dagrun(self, ) session.add(run) - session.commit() - run.dag = self # create the associated task instances @@ -1573,7 +1664,7 @@ def create_dagrun(self, @classmethod @provide_session - def bulk_sync_to_db(cls, dags: Collection["DAG"], sync_time=None, session=None): + def bulk_sync_to_db(cls, dags: Collection["DAG"], session=None): """ Save attributes about list of DAG to the DB. Note that this method can be called for both DAGs and SubDAGs. A SubDag is actually a @@ -1581,15 +1672,11 @@ def bulk_sync_to_db(cls, dags: Collection["DAG"], sync_time=None, session=None): :param dags: the DAG objects to save to the DB :type dags: List[airflow.models.dag.DAG] - :param sync_time: The time that the DAG should be marked as sync'ed - :type sync_time: datetime :return: None """ if not dags: return - if sync_time is None: - sync_time = timezone.utcnow() log.info("Sync %s DAGs", len(dags)) dag_by_ids = {dag.dag_id: dag for dag in dags} dag_ids = set(dag_by_ids.keys()) @@ -1615,6 +1702,21 @@ def bulk_sync_to_db(cls, dags: Collection["DAG"], sync_time=None, session=None): session.add(orm_dag) orm_dags.append(orm_dag) + # Get the latest dag run for each existing dag as a single query (avoid n+1 query) + most_recent_dag_runs = dict(session.query(DagRun.dag_id, func.max_(DagRun.execution_date)).filter( + DagRun.dag_id.in_(existing_dag_ids), + or_( + DagRun.run_type == DagRunType.BACKFILL_JOB.value, + DagRun.run_type == DagRunType.SCHEDULED.value, + ), + ).group_by(DagRun.dag_id).all()) + + num_active_runs = dict(session.query(DagRun.dag_id, func.count('*')).filter( + DagRun.dag_id.in_(existing_dag_ids), + DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable + DagRun.external_trigger.is_(False) + ).group_by(DagRun.dag_id).all()) + for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id): dag = dag_by_ids[orm_dag.dag_id] if dag.is_subdag: @@ -1627,10 +1729,33 @@ def bulk_sync_to_db(cls, dags: Collection["DAG"], sync_time=None, session=None): orm_dag.fileloc = dag.fileloc orm_dag.owners = dag.owner orm_dag.is_active = True - orm_dag.last_scheduler_run = sync_time orm_dag.default_view = dag.default_view orm_dag.description = dag.description orm_dag.schedule_interval = dag.schedule_interval + orm_dag.concurrency = dag.concurrency + orm_dag.has_task_concurrency_limits = any( + t.task_concurrency is not None for t in dag.tasks + ) + + next_dagrun_info = dag.next_dagrun_info(most_recent_dag_runs.get(dag.dag_id)) + if next_dagrun_info: + orm_dag.next_dagrun = next_dagrun_info['execution_date'] + orm_dag.next_dagrun_create_after = next_dagrun_info['can_be_created_after'] + else: + orm_dag.next_dagrun = None + orm_dag.next_dagrun_create_after = None + + active_runs_of_dag = num_active_runs.get(dag.dag_id, 0) + if dag.max_active_runs and active_runs_of_dag >= dag.max_active_runs: + # Since this happens every time the dag is parsed it would be quite spammy + log.debug( + "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs", + dag.dag_id, active_runs_of_dag, dag.max_active_runs + ) + orm_dag.next_dagrun_create_after = None + + log.info("Setting next_dagrun for %s to %s", dag.dag_id, orm_dag.next_dagrun) + for orm_tag in list(orm_dag.tags): if orm_tag.name not in orm_dag.tags: session.delete(orm_tag) @@ -1646,23 +1771,16 @@ def bulk_sync_to_db(cls, dags: Collection["DAG"], sync_time=None, session=None): if settings.STORE_DAG_CODE: DagCode.bulk_sync_to_db([dag.fileloc for dag in orm_dags]) - session.commit() - - for dag in dags: - cls.bulk_sync_to_db(dag.subdags, sync_time=sync_time, session=session) - @provide_session - def sync_to_db(self, sync_time=None, session=None): + def sync_to_db(self, session=None): """ Save attributes about this DAG to the DB. Note that this method can be called for both DAGs and SubDAGs. A SubDag is actually a SubDagOperator. - :param sync_time: The time that the DAG should be marked as sync'ed - :type sync_time: datetime :return: None """ - self.bulk_sync_to_db([self], sync_time, session) + self.bulk_sync_to_db([self], session) def get_default_view(self): """This is only there for backward compatible jinja2 templates""" @@ -1824,8 +1942,18 @@ class DagModel(Base): # Tags for view filter tags = relationship('DagTag', cascade='all,delete-orphan', backref=backref('dag')) + concurrency = Column(Integer, nullable=False) + + has_task_concurrency_limits = Column(Boolean, nullable=False) + + # The execution_date of the next dag run + next_dagrun = Column(UtcDateTime) + # Earliest time at which this ``next_dagrun`` can be created + next_dagrun_create_after = Column(UtcDateTime) + __table_args__ = ( Index('idx_root_dag_id', root_dag_id, unique=False), + Index('idx_next_dagrun_create_after', next_dagrun_create_after, unique=False), ) def __repr__(self): @@ -1939,6 +2067,36 @@ def deactivate_deleted_dags(cls, alive_dag_filelocs: List[str], session=None): session.rollback() raise + @classmethod + def dags_needing_dagruns(cls, session: Session): + """ + Return (and lock) a list of Dag objects that are due to create a new DagRun This will return a + resultset of rows that is row-level-locked with a "SELECT ... FOR UPDATE" query, you should ensure + that any scheduling decisions are made in a single transaction -- as soon as the transaction is + commited it will be unlocked. + """ + + # TODO[HA]: Bake this query, it is run _A lot_ + # TODO[HA]: Make this limit a tunable. We limit so that _one_ scheduler + # doesn't try to do all the creation of dag runs + return session.query(cls).filter( + cls.is_paused.is_(False), + cls.is_active.is_(True), + cls.next_dagrun_create_after <= func.now(), + ).order_by( + cls.next_dagrun_create_after + ).limit(10).with_for_update(of=cls, skip_locked=True) + + +STATICA_HACK = True +globals()['kcah_acitats'[::-1].upper()] = False +if STATICA_HACK: # pragma: no cover + # Let pylint know about these relationships, without introducing an import cycle + from sqlalchemy.orm import relationship + + from airflow.models.serialized_dag import SerializedDagModel + DagModel.serialized_dag = relationship(SerializedDagModel) + class DagContext: """ diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index ed709f916d948..ac2c253fe8522 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -30,6 +30,7 @@ from typing import Dict, List, NamedTuple, Optional from croniter import CroniterBadCronError, CroniterBadDateError, CroniterNotAlphaError, croniter +from sqlalchemy.orm import Session from tabulate import tabulate from airflow import settings @@ -42,6 +43,7 @@ from airflow.utils.dag_cycle_tester import test_cycle from airflow.utils.file import correct_maybe_zipped, list_py_file_paths, might_contain_dag from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.session import provide_session from airflow.utils.timeout import timeout @@ -144,7 +146,8 @@ def store_serialized_dags(self) -> bool: def dag_ids(self) -> List[str]: return list(self.dags.keys()) - def get_dag(self, dag_id): + @provide_session + def get_dag(self, dag_id, session: Session = None): """ Gets the DAG out of the dictionary, and refreshes it if expired @@ -159,7 +162,7 @@ def get_dag(self, dag_id): from airflow.models.serialized_dag import SerializedDagModel if dag_id not in self.dags: # Load from DB if not (yet) in the bag - self._add_dag_from_db(dag_id=dag_id) + self._add_dag_from_db(dag_id=dag_id, session=session) return self.dags.get(dag_id) # If DAG is in the DagBag, check the following @@ -173,7 +176,7 @@ def get_dag(self, dag_id): ): sd_last_updated_datetime = SerializedDagModel.get_last_updated_datetime(dag_id=dag_id) if sd_last_updated_datetime > self.dags_last_fetched[dag_id]: - self._add_dag_from_db(dag_id=dag_id) + self._add_dag_from_db(dag_id=dag_id, session=session) return self.dags.get(dag_id) @@ -183,7 +186,7 @@ def get_dag(self, dag_id): if dag_id in self.dags: dag = self.dags[dag_id] if dag.is_subdag: - root_dag_id = dag.parent_dag.dag_id + root_dag_id = dag.parent_dag.dag_id # type: ignore # If DAG Model is absent, we can't check last_expired property. Is the DAG not yet synchronized? orm_dag = DagModel.get_current(root_dag_id) @@ -192,7 +195,7 @@ def get_dag(self, dag_id): # If the dag corresponding to root_dag_id is absent or expired is_missing = root_dag_id not in self.dags - is_expired = (orm_dag.last_expired and dag.last_loaded < orm_dag.last_expired) + is_expired = (orm_dag.last_expired and dag and dag.last_loaded < orm_dag.last_expired) if is_missing or is_expired: # Reprocess source file found_dags = self.process_file( @@ -205,10 +208,10 @@ def get_dag(self, dag_id): del self.dags[dag_id] return self.dags.get(dag_id) - def _add_dag_from_db(self, dag_id: str): + def _add_dag_from_db(self, dag_id: str, session: Session): """Add DAG to DagBag from DB""" from airflow.models.serialized_dag import SerializedDagModel - row = SerializedDagModel.get(dag_id) + row = SerializedDagModel.get(dag_id, session) if not row: raise ValueError(f"DAG '{dag_id}' not found in serialized_dag table") @@ -525,6 +528,6 @@ def sync_to_db(self): DAG.bulk_sync_to_db(self.dags.values()) # Write Serialized DAGs to DB if DAG Serialization is turned on # Even though self.read_dags_from_db is False - if settings.STORE_SERIALIZED_DAGS: + if settings.STORE_SERIALIZED_DAGS or self.read_dags_from_db: self.log.debug("Calling the SerializedDagModel.bulk_sync_to_db method") SerializedDagModel.bulk_sync_to_db(self.dags.values()) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 429e98dbfa244..f5e2c359bb63e 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -26,6 +26,7 @@ from sqlalchemy.orm import backref, relationship, synonym from sqlalchemy.orm.session import Session +from airflow.configuration import conf as airflow_conf from airflow.exceptions import AirflowException from airflow.models.base import ID_LEN, Base from airflow.models.taskinstance import TaskInstance as TI @@ -36,7 +37,7 @@ from airflow.utils import timezone from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import provide_session -from airflow.utils.sqlalchemy import UtcDateTime +from airflow.utils.sqlalchemy import UtcDateTime, skip_locked from airflow.utils.state import State from airflow.utils.types import DagRunType @@ -59,6 +60,8 @@ class DagRun(Base, LoggingMixin): external_trigger = Column(Boolean, default=True) run_type = Column(String(50), nullable=False) conf = Column(PickleType) + # When a scheduler last attempted to schedule TIs for this DagRun + last_scheduling_decision = Column(UtcDateTime) dag = None @@ -66,6 +69,7 @@ class DagRun(Base, LoggingMixin): Index('dag_id_state', dag_id, _state), UniqueConstraint('dag_id', 'execution_date'), UniqueConstraint('dag_id', 'run_id'), + Index('idx_last_scheduling_decision', last_scheduling_decision), ) task_instances = relationship( @@ -75,6 +79,8 @@ class DagRun(Base, LoggingMixin): backref=backref('dag_run', uselist=False), ) + DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint('scheduler', 'max_dagruns_per_query', fallback=20) + def __init__( self, dag_id: Optional[str] = None, @@ -139,6 +145,43 @@ def refresh_from_db(self, session: Session = None): self.id = dr.id self.state = dr.state + @classmethod + def next_dagruns_to_examine( + cls, + session: Session, + max_number: Optional[int] = None, + ): + """ + Return the next DagRuns that the scheduler should attempt to schedule. + + This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE" + query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as + the transaction is commited it will be unlocked. + + :rtype: list[DagRun] + """ + from airflow.models.dag import DagModel + + if max_number is None: + max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE + + # TODO: Bake this query, it is run _A lot_ + query = session.query(cls).filter( + cls.state == State.RUNNING, + cls.run_type != DagRunType.BACKFILL_JOB.value + ).join( + DagModel, + DagModel.dag_id == cls.dag_id, + ).filter( + DagModel.is_paused.is_(False), + DagModel.is_active.is_(True), + ).order_by( + cls.last_scheduling_decision.nullsfirst(), + cls.execution_date, + ).limit(max_number).with_for_update(**skip_locked(of=cls, session=session)) + + return query + @staticmethod @provide_session def find( @@ -311,6 +354,10 @@ def update_state(self, session: Session = None) -> List[TI]: :return: ready_tis: the tis that can be scheduled in the current loop :rtype ready_tis: list[airflow.models.TaskInstance] """ + + start_dttm = timezone.utcnow() + self.last_scheduling_decision = start_dttm + dag = self.get_dag() ready_tis: List[TI] = [] tis = list(self.get_task_instances(session=session, state=State.task_states + (State.SHUTDOWN,))) @@ -318,7 +365,6 @@ def update_state(self, session: Session = None) -> List[TI]: for ti in tis: ti.task = dag.get_task(ti.task_id) - start_dttm = timezone.utcnow() unfinished_tasks = [t for t in tis if t.state in State.unfinished()] finished_tasks = [t for t in tis if t.state in State.finished() + [State.UPSTREAM_FAILED]] none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks) @@ -342,6 +388,9 @@ def update_state(self, session: Session = None) -> List[TI]: leaf_task_ids = {t.task_id for t in dag.leaves} leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids] + # TODO[ha]: These callbacks shouldn't run in the scheduler loop - check if Kamil changed this to run + # via the dag processor! + # if all roots finished and at least one failed, the run failed if not unfinished_tasks and any( leaf_ti.state in {State.FAILED, State.UPSTREAM_FAILED} for leaf_ti in leaf_tis @@ -372,10 +421,6 @@ def update_state(self, session: Session = None) -> List[TI]: self._emit_duration_stats_for_finished_state() - # todo: determine we want to use with_for_update to make sure to lock the run - session.merge(self) - session.commit() - return ready_tis def _get_ready_tis( @@ -492,12 +537,13 @@ def verify_integrity(self, session: Session = None): session.add(ti) try: - session.commit() + session.flush() except IntegrityError as err: self.log.info(str(err)) self.log.info('Hit IntegrityError while creating the TIs for ' f'{dag.dag_id} - {self.execution_date}.') self.log.info('Doing session rollback.') + # TODO[HA]: We probaly need to savepoint this so we can keep the transaction alive. session.rollback() @staticmethod diff --git a/airflow/models/pool.py b/airflow/models/pool.py index 1fbc3ad0aae23..ee001bd5824cd 100644 --- a/airflow/models/pool.py +++ b/airflow/models/pool.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict, Iterable, Optional, Tuple +from typing import Dict, Iterable, Optional, Tuple, Union from sqlalchemy import Column, Integer, String, Text, func from sqlalchemy.orm.session import Session @@ -81,7 +81,11 @@ def get_default_pool(session: Session = None): @staticmethod @provide_session - def slots_stats(session: Session = None) -> Dict[str, PoolStats]: + def slots_stats( + *, + with_for_update: Union[bool, Dict] = False, + session: Session = None, + ) -> Dict[str, PoolStats]: """ Get Pool stats (Number of Running, Queued, Open & Total tasks) @@ -91,7 +95,15 @@ def slots_stats(session: Session = None) -> Dict[str, PoolStats]: pools: Dict[str, PoolStats] = {} - pool_rows: Iterable[Tuple[str, int]] = session.query(Pool.pool, Pool.slots).all() + query = session.query(Pool.pool, Pool.slots) + + if with_for_update: + if isinstance(with_for_update, bool): + query = query.with_for_update() + else: + query = query.with_for_update(**with_for_update) + + pool_rows: Iterable[Tuple[str, int]] = query.all() for (pool_name, total_slots) in pool_rows: pools[pool_name] = PoolStats(total=total_slots, running=0, queued=0, open=0) diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py index 12e4c8c65cef7..965d1f0cf8582 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -25,12 +25,13 @@ import sqlalchemy_jsonfield from sqlalchemy import BigInteger, Column, Index, String, and_ -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, backref, relationship from sqlalchemy.sql import exists from airflow.models.base import ID_LEN, Base from airflow.models.dag import DAG, DagModel from airflow.models.dagcode import DagCode +from airflow.models.dagrun import DagRun from airflow.serialization.serialized_objects import SerializedDAG from airflow.settings import MIN_SERIALIZED_DAG_UPDATE_INTERVAL, json from airflow.utils import timezone @@ -73,6 +74,22 @@ class SerializedDagModel(Base): Index('idx_fileloc_hash', fileloc_hash, unique=False), ) + dag_runs = relationship( + DagRun, + primaryjoin=dag_id == DagRun.dag_id, + foreign_keys=dag_id, + backref=backref('serialized_dag', uselist=False, innerjoin=True), + ) + + dag_model = relationship( + DagModel, + primaryjoin=dag_id == DagModel.dag_id, # type: ignore + foreign_keys=dag_id, + uselist=False, + innerjoin=True, + backref=backref('serialized_dag', uselist=False, innerjoin=True), + ) + def __init__(self, dag: DAG): self.dag_id = dag.dag_id self.fileloc = dag.full_filepath diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 335b4dd4a573f..5d53e9fc4bc93 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -34,7 +34,7 @@ import pendulum from jinja2 import TemplateAssertionError, UndefinedError from sqlalchemy import Column, Float, Index, Integer, PickleType, String, and_, func, or_ -from sqlalchemy.orm import reconstructor +from sqlalchemy.orm import reconstructor, relationship from sqlalchemy.orm.session import Session from sqlalchemy.sql.elements import BooleanClauseList @@ -237,6 +237,14 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904 Index('ti_job_id', job_id), ) + dag_model = relationship( + "DagModel", + primaryjoin="TaskInstance.dag_id == DagModel.dag_id", + foreign_keys=dag_id, + uselist=False, + innerjoin=True, + ) + def __init__(self, task, execution_date: datetime, state: Optional[str] = None): super().__init__() self.dag_id = task.dag_id diff --git a/airflow/stats.py b/airflow/stats.py index 5913f765534f5..641f3e3a8e5bd 100644 --- a/airflow/stats.py +++ b/airflow/stats.py @@ -30,6 +30,23 @@ log = logging.getLogger(__name__) +class TimerProtocol(Protocol): + """Type protocol for StatsLogger.timer""" + def __enter__(self): + ... + + def __exit__(self, exc_type, exc_value, traceback): + ... + + def start(self): + """Start the timer""" + ... + + def stop(self, send=True): + """Stop, and (by default) submit the timer to statsd""" + ... + + class StatsLogger(Protocol): """This class is only used for TypeChecking (for IDEs, mypy, pylint, etc)""" @@ -49,6 +66,25 @@ def gauge(cls, stat: str, value: float, rate: int = 1, delta: bool = False) -> N def timing(cls, stat: str, dt) -> None: """Stats timing""" + @classmethod + def timer(cls, *args, **kwargs) -> TimerProtocol: + """Timer metric that can be cancelled""" + + +class DummyTimer: + """No-op timer""" + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + return self + + def start(self): + """Start the timer""" + + def stop(self, send=True): # pylint: disable=unused-argument + """Stop, and (by default) submit the timer to statsd""" + class DummyStatsLogger: """If no StatsLogger is configured, DummyStatsLogger is used as a fallback""" @@ -69,6 +105,11 @@ def gauge(cls, stat, value, rate=1, delta=False): def timing(cls, stat, dt): """Stats timing""" + @classmethod + def timer(cls, *args, **kwargs): + """Timer metric that can be cancelled""" + return DummyTimer() + # Only characters in the character set are considered valid # for the stat_name if stat_name_default_handler is used. @@ -171,6 +212,13 @@ def timing(self, stat, dt): return self.statsd.timing(stat, dt) return None + @validate_stat + def timer(self, stat, *args, **kwargs): + """Timer metric that can be cancelled""" + if self.allow_list_validator.test(stat): + return self.statsd.timer(stat, *args, **kwargs) + return DummyTimer() + class SafeDogStatsdLogger: """DogStatsd Logger""" @@ -211,6 +259,14 @@ def timing(self, stat, dt, tags=None): return self.dogstatsd.timing(metric=stat, value=dt, tags=tags) return None + @validate_stat + def timer(self, stat, *args, tags=None, **kwargs): + """Timer metric that can be cancelled""" + if self.allow_list_validator.test(stat): + tags = tags or [] + return self.dogstatsd.timer(stat, *args, tags=tags, **kwargs) + return DummyTimer() + class _Stats(type): instance: Optional[StatsLogger] = None diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index 9363c6b052c16..ee98c352f69c5 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -30,7 +30,7 @@ from datetime import datetime, timedelta from importlib import import_module from multiprocessing.connection import Connection as MultiprocessingConnection -from typing import Callable, Dict, KeysView, List, NamedTuple, Optional, Tuple +from typing import Any, Callable, Dict, KeysView, List, NamedTuple, Optional, Union, cast from setproctitle import setproctitle # pylint: disable=no-name-in-module from sqlalchemy import or_ @@ -149,12 +149,12 @@ def done(self) -> bool: @property @abstractmethod - def result(self) -> Optional[Tuple[List[dict], int]]: + def result(self) -> Optional[int]: """ A list of simple dags found, and the number of import errors :return: result of running SchedulerJob.process_file() if availlablle. Otherwise, none - :rtype: Optional[Tuple[List[dict], int]] + :rtype: Optional[int] """ raise NotImplementedError() @@ -280,7 +280,6 @@ def __init__( self._all_files_processed = True self._parent_signal_conn: Optional[MultiprocessingConnection] = None - self._collected_dag_buffer: List = [] self._last_parsing_stat_received_at: float = time.monotonic() @@ -421,11 +420,9 @@ def _run_processor_manager( processor_manager.start() - def harvest_serialized_dags(self) -> List[SerializedDAG]: + def heartbeat(self) -> None: """ - Harvest DAG parsing results from result queue and sync metadata from stat queue. - - :return: List of parsing result in SerializedDAG format. + Check if the DagFileProcessorManager process is alive, and process any pending messages """ if not self._parent_signal_conn: raise ValueError("Process not started.") @@ -436,20 +433,16 @@ def harvest_serialized_dags(self) -> List[SerializedDAG]: except (EOFError, ConnectionError): break self._process_message(result) - serialized_dags = self._collected_dag_buffer - self._collected_dag_buffer = [] # If it died unexpectedly restart the manager process self._heartbeat_manager() - return serialized_dags - def _process_message(self, message): self.log.debug("Received message of type %s", type(message).__name__) if isinstance(message, DagParsingStat): self._sync_metadata(message) else: - self._collected_dag_buffer.append(SerializedDAG.from_dict(message)) + raise RuntimeError(f"Unexpected message recieved of type {type(message).__name__}") def _heartbeat_manager(self): """ @@ -624,7 +617,9 @@ def __init__(self, self._log = logging.getLogger('airflow.processor_manager') - self.waitables = {self._signal_conn: self._signal_conn} + self.waitables: Dict[Any, Union[MultiprocessingConnection, AbstractDagFileProcessorProcess]] = { + self._signal_conn: self._signal_conn, + } def register_exit_signals(self): """ @@ -724,11 +719,9 @@ def _run_parsing_loop(self): if not processor: continue - serialized_dags = self._collect_results_from_processor(processor) + self._collect_results_from_processor(processor) self.waitables.pop(sentinel) self._processors.pop(processor.file_path) - for serialized_dag in serialized_dags: - self._signal_conn.send(serialized_dag) self._refresh_dag_dir() self._find_zombies() # pylint: disable=no-value-for-parameter @@ -756,9 +749,7 @@ def _run_parsing_loop(self): self.wait_until_finished() # Collect anything else that has finished, but don't kick off any more processors - serialized_dags = self.collect_results() - for serialized_dag in serialized_dags: - self._signal_conn.send(serialized_dag) + self.collect_results() self._print_stat() @@ -1039,22 +1030,23 @@ def wait_until_finished(self): while not processor.done: time.sleep(0.1) - def _collect_results_from_processor(self, processor): + def _collect_results_from_processor(self, processor) -> None: self.log.debug("Processor for %s finished", processor.file_path) Stats.decr('dag_processing.processes') last_finish_time = timezone.utcnow() if processor.result is not None: - dags, count_import_errors = processor.result + count_import_errors = processor.result else: self.log.error( "Processor for %s exited with return code %s.", processor.file_path, processor.exit_code ) - dags, count_import_errors = [], -1 + count_import_errors = -1 stat = DagFileStat( - num_dags=len(dags), + # TODO: Return number of dags, number of errors? + num_dags=0, import_errors=count_import_errors, last_finish_time=last_finish_time, last_duration=(last_finish_time - processor.start_time).total_seconds(), @@ -1062,26 +1054,19 @@ def _collect_results_from_processor(self, processor): ) self._file_stats[processor.file_path] = stat - return dags - - def collect_results(self): + def collect_results(self) -> None: """ Collect the result from any finished DAG processors - - :return: a list of dicts that were produced by processors that - have finished since the last time this was called - :rtype: list[dict] """ - # Collect all the DAGs that were found in the processed files - serialized_dags = [] - ready = multiprocessing.connection.wait(self.waitables.keys() - [self._signal_conn], timeout=0) for sentinel in ready: - processor = self.waitables[sentinel] + if sentinel is self._signal_conn: + continue + processor = cast(AbstractDagFileProcessorProcess, self.waitables[sentinel]) self.waitables.pop(processor.waitable_handle) self._processors.pop(processor.file_path) - serialized_dags += self._collect_results_from_processor(processor) + self._collect_results_from_processor(processor) self.log.debug("%s/%s DAG parsing processes running", len(self._processors), self._parallelism) @@ -1089,8 +1074,6 @@ def collect_results(self): self.log.debug("%s file paths queued for processing", len(self._file_path_queue)) - return serialized_dags - def start_new_processes(self): """ Start more processors if we have enough slots and files to process From 06c07befd8f42cdaec14c932888f68aef2d136a3 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 16 Sep 2020 12:41:20 +0100 Subject: [PATCH 02/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/jobs/scheduler_job.py | 90 ++++++++++++++++++----------- docs/logging-monitoring/metrics.rst | 15 +++-- 2 files changed, 66 insertions(+), 39 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index cb34736fb3ed9..a30c9812bee98 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -18,7 +18,6 @@ # under the License. # import datetime -import enum import logging import multiprocessing import operator @@ -31,7 +30,7 @@ from contextlib import ExitStack, redirect_stderr, redirect_stdout, suppress from datetime import timedelta from multiprocessing.connection import Connection as MultiprocessingConnection -from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Union +from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple from setproctitle import setproctitle from sqlalchemy import and_, func, not_, or_ @@ -706,12 +705,6 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes } heartrate: int = conf.getint('scheduler', 'SCHEDULER_HEARTBEAT_SEC') - # Singleton object pattern, PEP-484 style - class _NoLockObtained(enum.Enum): - token = 0 - - NO_LOCK_OBTAINED = _NoLockObtained.token - def __init__( self, dag_id: Optional[str] = None, @@ -1116,7 +1109,7 @@ def _enqueue_task_instances_with_queued_state( queue=queue, ) - def _execute_task_instances(self, dag_bag: DagBag, session: Session) -> int: + def _critical_section_execute_task_instances(self, dag_bag: DagBag, session: Session) -> int: """ Attempts to execute TaskInstances that should be executed by the scheduler. @@ -1126,6 +1119,12 @@ def _execute_task_instances(self, dag_bag: DagBag, session: Session) -> int: 2. Change the state for the TIs above atomically. 3. Enqueue the TIs in the executor. + HA note: This function is a "critical section" meaning that only a single executor process can execute + this function at the same time. This is achived by doing ``SELECT ... from pool FOR UPDATE``. For DBs + that support NOWAIT, a "blocked" scheduler will skip this and continue on with other tasks (creating + new DAG runs, progressing TIs from None to SCHEDULED etc.); DBs that don't support this (such as + MariaDB or MySQL 5.x) the other schedulers will wait for the lock before continuiung. + :param dag_bag: TaskInstances associated with DAGs in the dag_bag will be fetched from the DB and executed :type dag_bag: airflow.models.DagBag @@ -1349,17 +1348,7 @@ def _run_scheduler_loop(self) -> None: self.processor_agent.wait_until_finished() with create_session() as session: - timer = Stats.timer('scheduler.critical_section_duration') - timer.start() - num_queued_tis = self._scheduler_loop_critical_section(dag_bag, session) - - if num_queued_tis is self.NO_LOCK_OBTAINED: - Stats.incr('scheduler.critical_section_lock_busy') - num_queued_tis = 0 - else: - # Make sure we only sent this metric if we obtained the lock, otherwise we'll skew the - # metric, way down - timer.stop(send=True) + num_queued_tis = self._do_scheduling(dag_bag, session) self.executor.heartbeat() session.expunge_all() @@ -1388,8 +1377,22 @@ def _run_scheduler_loop(self) -> None: ) break - def _scheduler_loop_critical_section(self, dag_bag, session) -> Union[int, _NoLockObtained]: + def _do_scheduling(self, dag_bag, session) -> int: """ + This function is where the main scheduling decisions take places. It: + + - Creates any necessary DAG runs by examining the next_dagrun_create_after column of DagModel + + - Finds the "next n oldest" running DAG Runs to examine for scheduling (n=20 by default) and tries to + progress state (TIs to SCHEDULED, or DagRuns to SUCCESS/FAILURE etc) + + By "next oldest", we mean hasn't been examined/scheduled in the most time. + + - Then, via a Critical Section (locking the rows of the Pool model) we queue tasks, and then send them + to the executor. + + See docs of _critical_section_execute_task_instances for more. + :return: Number of TIs enqueued in this iteration :rtype: int """ @@ -1420,7 +1423,8 @@ def validate_commit(_): session=session ) - # Check max_active_runs, to see if we are now at the limit for this dag? + # Check max_active_runs, to see if we are _now_ at the limit for this dag? (we've just created + # one after all) active_runs_of_dag = dict(session.query(DagRun.dag_id, func.count('*')).filter( DagRun.dag_id == dag_model.dag_id, DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable @@ -1477,24 +1481,42 @@ def validate_commit(_): session.expunge_all() # END: schedule TIs - # Find anything TIs in state SCHEDULED, try to QUEUE it (send it to the executor) - # TODO[HA]: Do we need to call # _change_state_for_tis_without_dagrun (2x) that we were before # to tidy up manually tweaked TIs. Do we need to do it every # time? - return self._execute_task_instances(dag_bag, session=session) + try: + timer = Stats.timer('scheduler.critical_section_duration') + timer.start() - # End of loop, allowed/expected to commit - except OperationalError as e: - # Postgres: lock not available - if getattr(e.orig, 'pgcode') == '55P03': - # We could test if e.orig is an instance of psycopg2.errors.LockNotAvailable, but that - # involves importing it. This doesn't - self.log.debug("Critical section lock held by another Scheduler") - return self.NO_LOCK_OBTAINED - raise + # Find anything TIs in state SCHEDULED, try to QUEUE it (send it to the executor) + num_queued_tis = self._critical_section_execute_task_instances(dag_bag, session=session) + + # Make sure we only sent this metric if we obtained the lock, otherwise we'll skew the + # metric, way down + timer.stop(send=True) + except OperationalError as e: + timer.stop(send=False) + + # DB specific error codes: + # Postgres: 55P03 + # MySQL: 3572, 'Statement aborted because lock(s) could not be acquired immediately and NOWAIT + # is set.' + # MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction + # (when NOWAIT isn't available) + db_err_code = getattr(e.orig, 'pgcode', None) or e.orig.args[0] + + # We could test if e.orig is an instance of + # psycopg2.errors.LockNotAvailable/_mysql_exceptions.OperationalError, but that involves + # importing it. This doesn't + if db_err_code in ('55P03', 1205, 3572): + self.log.debug("Critical section lock held by another Scheduler") + Stats.incr('scheduler.critical_section_busy') + return 0 + raise + + return num_queued_tis finally: event.remove(session.bind, 'commit', validate_commit) diff --git a/docs/logging-monitoring/metrics.rst b/docs/logging-monitoring/metrics.rst index bdff897b654ba..227959ca6753e 100644 --- a/docs/logging-monitoring/metrics.rst +++ b/docs/logging-monitoring/metrics.rst @@ -89,6 +89,9 @@ Name Description ``scheduler.tasks.starving`` Number of tasks that cannot be scheduled because of no open slot in pool ``scheduler.orphaned_tasks.cleared`` Number of Orphaned tasks cleared by the Scheduler ``scheduler.orphaned_tasks.adopted`` Number of Orphaned tasks adopted by the Scheduler +``scheduler.critical_section_busy`` Count of times a scheduler process tried to get a lock on the critical + seciton (needed to send tasks to the executor) and found it locked by + another process. ``sla_email_notification_failure`` Number of failed SLA miss email notification attempts ``ti.start..`` Number of started task in a given dag. Similar to _start but for task ``ti.finish...`` Number of completed task in a given dag. Similar to _end but for task @@ -124,14 +127,16 @@ Name Description Timers ------ -=========================================== ================================================= +=========================================== ================================================================= Name Description -=========================================== ================================================= +=========================================== ================================================================= ``dagrun.dependency-check.`` Milliseconds taken to check DAG dependencies ``dag...duration`` Milliseconds taken to finish a task ``dag_processing.last_duration.`` Milliseconds taken to load the given DAG file ``dagrun.duration.success.`` Milliseconds taken for a DagRun to reach success state ``dagrun.duration.failed.`` Milliseconds taken for a DagRun to reach failed state -``dagrun.schedule_delay.`` Milliseconds of delay between the scheduled DagRun - start date and the actual DagRun start date -=========================================== ================================================= +``dagrun.schedule_delay.`` Milliseconds of delay between the scheduled DagRun start date and + the actual DagRun start date +``scheduler.critical_section_duration`` Millseconds spent in the critical section of scheduler loop -- + only a single scheduler can enter this loop at a time +=========================================== ================================================================= From 38b049c854c5caeeca835baa0320c29bdd820b6c Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 16 Sep 2020 12:51:05 +0100 Subject: [PATCH 03/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/models/dag.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 283e8154e9f04..9de3eedeecc1f 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -560,6 +560,8 @@ def next_dagrun_after_date(self, date_last_automated_dagrun: Optional[pendulum.D if next_run_date > min_task_end_date: return None + return next_run_date + def get_run_dates(self, start_date, end_date=None): """ Returns a list of dates between the interval received as parameter using this From f46abde416f1327f972f748e0895748a24f88887 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 16 Sep 2020 14:28:22 +0100 Subject: [PATCH 04/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/jobs/scheduler_job.py | 18 +++++++++--------- airflow/utils/dag_processing.py | 12 ++++++------ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index a30c9812bee98..821407d947188 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -98,7 +98,7 @@ def __init__( # The process that was launched to process the given . self._process: Optional[multiprocessing.process.BaseProcess] = None # The result of Scheduler.process_file(file_path). - self._result: Optional[int] = None + self._result: Optional[Tuple[int, int]] = None # Whether the process is done running. self._done = False # When the process started. @@ -175,7 +175,7 @@ def _run_file_processor( log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path) dag_file_processor = DagFileProcessor(dag_ids=dag_ids, log=log) - result: int = dag_file_processor.process_file( + result: Tuple[int, int] = dag_file_processor.process_file( file_path=file_path, pickle_dags=pickle_dags, failure_callback_requests=failure_callback_requests, @@ -334,7 +334,7 @@ def done(self) -> bool: return False @property - def result(self) -> Optional[int]: + def result(self) -> Optional[Tuple[int, int]]: """ :return: result of running SchedulerJob.process_file() :rtype: int or None @@ -601,7 +601,7 @@ def process_file( failure_callback_requests: List[FailureCallbackRequest], pickle_dags: bool = False, session: Session = None - ) -> int: + ) -> Tuple[int, int]: """ Process a Python file containing Airflow DAGs. @@ -627,8 +627,8 @@ def process_file( :type pickle_dags: bool :param session: Sqlalchemy ORM Session :type session: Session - :return: count of import errors - :rtype: int + :return: number of dags found, count of import errors + :rtype: Tuple[int, int] """ self.log.info("Processing file %s for tasks to queue", file_path) @@ -637,14 +637,14 @@ def process_file( except Exception: # pylint: disable=broad-except self.log.exception("Failed at reloading the DAG file %s", file_path) Stats.incr('dag_file_refresh_error', 1, 1) - return 0 + return 0, 0 if len(dagbag.dags) > 0: self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(), file_path) else: self.log.warning("No viable dags retrieved from %s", file_path) self.update_import_errors(session, dagbag) - return len(dagbag.import_errors) + return 0, len(dagbag.import_errors) try: self.execute_on_failure_callbacks(dagbag, failure_callback_requests) @@ -671,7 +671,7 @@ def process_file( except Exception: # pylint: disable=broad-except self.log.exception("Error logging import errors!") - return len(dagbag.import_errors) + return len(dagbag.dags), len(dagbag.import_errors) class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index ee98c352f69c5..ddc20687e3c2e 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -30,7 +30,7 @@ from datetime import datetime, timedelta from importlib import import_module from multiprocessing.connection import Connection as MultiprocessingConnection -from typing import Any, Callable, Dict, KeysView, List, NamedTuple, Optional, Union, cast +from typing import Any, Callable, Dict, KeysView, List, NamedTuple, Optional, Tuple, Union, cast from setproctitle import setproctitle # pylint: disable=no-name-in-module from sqlalchemy import or_ @@ -149,12 +149,12 @@ def done(self) -> bool: @property @abstractmethod - def result(self) -> Optional[int]: + def result(self) -> Optional[Tuple[int, int]]: """ A list of simple dags found, and the number of import errors :return: result of running SchedulerJob.process_file() if availlablle. Otherwise, none - :rtype: Optional[int] + :rtype: Optional[Tuple[int, int]] """ raise NotImplementedError() @@ -1036,17 +1036,17 @@ def _collect_results_from_processor(self, processor) -> None: last_finish_time = timezone.utcnow() if processor.result is not None: - count_import_errors = processor.result + num_dags, count_import_errors = processor.result else: self.log.error( "Processor for %s exited with return code %s.", processor.file_path, processor.exit_code ) count_import_errors = -1 + num_dags = 0 stat = DagFileStat( - # TODO: Return number of dags, number of errors? - num_dags=0, + num_dags=num_dags, import_errors=count_import_errors, last_finish_time=last_finish_time, last_duration=(last_finish_time - processor.start_time).total_seconds(), From ae757284bc64d2da4e3e07b50582f2907414e3b8 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 16 Sep 2020 15:46:36 +0100 Subject: [PATCH 05/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/jobs/scheduler_job.py | 8 ++++---- airflow/models/dagbag.py | 2 +- airflow/models/dagrun.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 821407d947188..4c9f413c23533 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -1415,7 +1415,7 @@ def validate_commit(_): dag = dag_bag.get_dag(dag_model.dag_id, session=session) next_run_date = dag_model.next_dagrun dag.create_dagrun( - run_type=DagRunType.SCHEDULED.value, + run_type=DagRunType.SCHEDULED, execution_date=next_run_date, start_date=timezone.utcnow(), state=State.RUNNING, @@ -1425,11 +1425,11 @@ def validate_commit(_): # Check max_active_runs, to see if we are _now_ at the limit for this dag? (we've just created # one after all) - active_runs_of_dag = dict(session.query(DagRun.dag_id, func.count('*')).filter( + active_runs_of_dag = session.query(func.count('*')).filter( DagRun.dag_id == dag_model.dag_id, DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable - DagRun.external_trigger.is_(False) - ).scalar()) + DagRun.external_trigger.is_(False), + ).scalar() # TODO[HA]: add back in dagrun.timeout diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index ac2c253fe8522..5f3d547626745 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -189,7 +189,7 @@ def get_dag(self, dag_id, session: Session = None): root_dag_id = dag.parent_dag.dag_id # type: ignore # If DAG Model is absent, we can't check last_expired property. Is the DAG not yet synchronized? - orm_dag = DagModel.get_current(root_dag_id) + orm_dag = DagModel.get_current(root_dag_id, session=session) if not orm_dag: return self.dags.get(dag_id) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index f5e2c359bb63e..c742bd4ce0945 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -176,7 +176,7 @@ def next_dagruns_to_examine( DagModel.is_paused.is_(False), DagModel.is_active.is_(True), ).order_by( - cls.last_scheduling_decision.nullsfirst(), + cls.last_scheduling_decision, cls.execution_date, ).limit(max_number).with_for_update(**skip_locked(of=cls, session=session)) From 6ffb76273a94472ff4f96d276c33371af17d5f8c Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 16 Sep 2020 16:29:39 +0100 Subject: [PATCH 06/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/models/dag.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 9de3eedeecc1f..cf542a852d64d 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -57,7 +57,7 @@ from airflow.utils.helpers import validate_key from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import provide_session -from airflow.utils.sqlalchemy import Interval, UtcDateTime +from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked from airflow.utils.state import State from airflow.utils.types import DagRunType @@ -2087,7 +2087,7 @@ def dags_needing_dagruns(cls, session: Session): cls.next_dagrun_create_after <= func.now(), ).order_by( cls.next_dagrun_create_after - ).limit(10).with_for_update(of=cls, skip_locked=True) + ).limit(10).with_for_update(**skip_locked(of=cls, session=session)) STATICA_HACK = True From d89827b02a72dbc8ff17a5b41a13ee693fabfd22 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 16 Sep 2020 17:12:07 +0100 Subject: [PATCH 07/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/jobs/scheduler_job.py | 4 ++-- airflow/utils/sqlalchemy.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 4c9f413c23533..4f27935be949d 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -57,7 +57,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context from airflow.utils.mixins import MultiprocessingStartMethodMixin from airflow.utils.session import create_session, provide_session -from airflow.utils.sqlalchemy import skip_locked +from airflow.utils.sqlalchemy import nowait, skip_locked from airflow.utils.state import State from airflow.utils.types import DagRunType @@ -901,7 +901,7 @@ def _executable_task_instances_to_queued( # Get the pool settings. We get a lock on the pool rows, treating this as a "critical section" # Throws an exception if lock cannot be obtained, rather than blocking - pools = models.Pool.slots_stats(with_for_update={'nowait': True}, session=session) + pools = models.Pool.slots_stats(with_for_update=nowait(session), session=session) # If the pools are full, there is no point doing anything! max_tis = min(max_tis, sum(map(operator.itemgetter('open'), pools.values()))) diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py index d5f8f13a5bbbe..88b1a1f0b333a 100644 --- a/airflow/utils/sqlalchemy.py +++ b/airflow/utils/sqlalchemy.py @@ -141,3 +141,24 @@ def skip_locked(session: Session) -> Dict[str, Any]: return {'skip_locked': True} else: return {} + + +def nowait(session: Session) -> Dict[str, Any]: + """ + Return kwargs for passing to `with_for_update()` suitable for the current DB engine version. + + We do this as we document the fact that on DB engines that don't support this construct, we do not + support/recommend running HA scheduler. If a user ignores this and tries anyway everything will still + work, just slightly slower in some circumstances. + + Specifically don't emit NOWAIT for MySQL < 8, or MariaDB, neither of which support this construct + + See https://jira.mariadb.org/browse/MDEV-13115 + """ + + dialect = session.bind.dialect + + if dialect.name != "mysql" or dialect.supports_for_update_of: + return {'nowait': True} + else: + return {} From 254aff1077caa474df14d28c97f564261c78d1ed Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 16 Sep 2020 17:20:34 +0100 Subject: [PATCH 08/70] fixup! Officially support running more than one scheduler concurrently. --- tests/utils/test_sqlalchemy.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py index d59bbc96d1343..ac8a52a776ef9 100644 --- a/tests/utils/test_sqlalchemy.py +++ b/tests/utils/test_sqlalchemy.py @@ -26,7 +26,7 @@ from airflow import settings from airflow.models import DAG from airflow.settings import Session -from airflow.utils.sqlalchemy import skip_locked +from airflow.utils.sqlalchemy import nowait, skip_locked from airflow.utils.state import State from airflow.utils.timezone import utcnow @@ -110,6 +110,18 @@ def test_skip_locked(self, dialect, supports_for_update_of, expected_return_valu session.bind.dialect.supports_for_update_of = supports_for_update_of self.assertEqual(skip_locked(session=session), expected_return_value) + @parameterized.expand([ + ("postgresql", True, {'nowait': True}, ), + ("mysql", False, {}, ), + ("mysql", True, {'nowait': True}, ), + ("sqlite", False, {'nowait': True, }, ), + ]) + def test_nowait(self, dialect, supports_for_update_of, expected_return_value): + session = mock.Mock() + session.bind.dialect.name = dialect + session.bind.dialect.supports_for_update_of = supports_for_update_of + self.assertEqual(nowait(session=session), expected_return_value) + def tearDown(self): self.session.close() settings.engine.dispose() From 8122862227a41b959d545d9028f83ec0e3ae1b6f Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 16 Sep 2020 19:08:23 +0100 Subject: [PATCH 09/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/jobs/scheduler_job.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 4f27935be949d..40d595fd4ccbd 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -1227,7 +1227,7 @@ def _process_executor_events(self, session: Session = None) -> int: "task says its %s. (Info: %s) Was the task killed externally?" self.log.error(msg, ti, state, ti.state, info) self.processor_agent.send_callback_to_execute( - full_filepath=ti.dag_model.full_filepath, + full_filepath=ti.dag_model.fileloc, task_instance=ti, msg=msg % (ti, state, ti.state, info), ) From 2d4fe383786ae8c0bf56ead54698c2d2fe25d7fe Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 16 Sep 2020 19:35:30 +0100 Subject: [PATCH 10/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/models/taskinstance.py | 18 ++++++++++++------ .../ti_deps/deps/not_previously_skipped_dep.py | 2 +- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 5d53e9fc4bc93..4c6f30a08cd7e 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1741,12 +1741,15 @@ def xcom_push( dag_id=self.dag_id, execution_date=execution_date or self.execution_date) + @provide_session def xcom_pull( # pylint: disable=inconsistent-return-statements - self, - task_ids: Optional[Union[str, Iterable[str]]] = None, - dag_id: Optional[str] = None, - key: str = XCOM_RETURN_KEY, - include_prior_dates: bool = False) -> Any: + self, + task_ids: Optional[Union[str, Iterable[str]]] = None, + dag_id: Optional[str] = None, + key: str = XCOM_RETURN_KEY, + include_prior_dates: bool = False, + session: Session = None + ) -> Any: """ Pull XComs that optionally meet certain criteria. @@ -1775,6 +1778,8 @@ def xcom_pull( # pylint: disable=inconsistent-return-statements execution_date are returned. If True, XComs from previous dates are returned as well. :type include_prior_dates: bool + :param session: Sqlalchemy ORM Session + :type session: Session """ if dag_id is None: dag_id = self.dag_id @@ -1784,7 +1789,8 @@ def xcom_pull( # pylint: disable=inconsistent-return-statements key=key, dag_ids=dag_id, task_ids=task_ids, - include_prior_dates=include_prior_dates + include_prior_dates=include_prior_dates, + session=session ).with_entities(XCom.value) # Since we're only fetching the values field, and not the diff --git a/airflow/ti_deps/deps/not_previously_skipped_dep.py b/airflow/ti_deps/deps/not_previously_skipped_dep.py index 409d73a89844c..4ecef93ad847b 100644 --- a/airflow/ti_deps/deps/not_previously_skipped_dep.py +++ b/airflow/ti_deps/deps/not_previously_skipped_dep.py @@ -51,7 +51,7 @@ def _get_dep_statuses( continue prev_result = ti.xcom_pull( - task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY + task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session ) if prev_result is None: From bfe9cad336090cd81f9dcee0db73b9df50493fb2 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 16 Sep 2020 20:43:27 +0100 Subject: [PATCH 11/70] fixup! Officially support running more than one scheduler concurrently. --- .../executors/celery_kubernetes_executor.py | 22 ++++++++++--------- airflow/models/dag.py | 6 ++--- airflow/models/dagrun.py | 2 +- docs/logging-monitoring/metrics.rst | 4 ++-- docs/spelling_wordlist.txt | 1 + 5 files changed, 19 insertions(+), 16 deletions(-) diff --git a/airflow/executors/celery_kubernetes_executor.py b/airflow/executors/celery_kubernetes_executor.py index 51c1e17368419..ef82c2060585f 100644 --- a/airflow/executors/celery_kubernetes_executor.py +++ b/airflow/executors/celery_kubernetes_executor.py @@ -62,17 +62,19 @@ def start(self) -> None: self.celery_executor.start() self.kubernetes_executor.start() - def queue_command(self, - simple_task_instance: SimpleTaskInstance, - command: CommandType, - priority: int = 1, - queue: Optional[str] = None): + def queue_command( + self, + task_instance: TaskInstance, + command: CommandType, + priority: int = 1, + queue: Optional[str] = None + ): """Queues command via celery or kubernetes executor""" - executor = self._router(simple_task_instance) - self.log.debug("Using executor: %s for %s", - executor.__class__.__name__, simple_task_instance.key - ) - executor.queue_command(simple_task_instance, command, priority, queue) + executor = self._router(task_instance) + self.log.debug( + "Using executor: %s for %s", executor.__class__.__name__, task_instance.key + ) + executor.queue_command(task_instance, command, priority, queue) def queue_task_instance( self, diff --git a/airflow/models/dag.py b/airflow/models/dag.py index cf542a852d64d..0b7dafe63f367 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -488,9 +488,9 @@ def next_dagrun_info(self, date_last_automated_dagrun : Optional[pendulum.DateTi def next_dagrun_after_date(self, date_last_automated_dagrun: Optional[pendulum.DateTime]): """ - Get the next execution date after the given ``date_last_automated_dagrun``, acording to + Get the next execution date after the given ``date_last_automated_dagrun``, according to schedule_interval, start_date, end_date etc. This doesn't check max active run or any other - "concurrency" type limits, it only perofmrs calculations based on the varios date and interval fields + "concurrency" type limits, it only performs calculations based on the various date and interval fields of this dag and it's tasks. :param date_last_automated_dagrun: The execution_date of the last scheduler or @@ -2075,7 +2075,7 @@ def dags_needing_dagruns(cls, session: Session): Return (and lock) a list of Dag objects that are due to create a new DagRun This will return a resultset of rows that is row-level-locked with a "SELECT ... FOR UPDATE" query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as the transaction is - commited it will be unlocked. + committed it will be unlocked. """ # TODO[HA]: Bake this query, it is run _A lot_ diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index c742bd4ce0945..f7e9cbab86fbb 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -156,7 +156,7 @@ def next_dagruns_to_examine( This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE" query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as - the transaction is commited it will be unlocked. + the transaction is committed it will be unlocked. :rtype: list[DagRun] """ diff --git a/docs/logging-monitoring/metrics.rst b/docs/logging-monitoring/metrics.rst index 227959ca6753e..2a3c9f4387d59 100644 --- a/docs/logging-monitoring/metrics.rst +++ b/docs/logging-monitoring/metrics.rst @@ -90,7 +90,7 @@ Name Description ``scheduler.orphaned_tasks.cleared`` Number of Orphaned tasks cleared by the Scheduler ``scheduler.orphaned_tasks.adopted`` Number of Orphaned tasks adopted by the Scheduler ``scheduler.critical_section_busy`` Count of times a scheduler process tried to get a lock on the critical - seciton (needed to send tasks to the executor) and found it locked by + section (needed to send tasks to the executor) and found it locked by another process. ``sla_email_notification_failure`` Number of failed SLA miss email notification attempts ``ti.start..`` Number of started task in a given dag. Similar to _start but for task @@ -137,6 +137,6 @@ Name Description ``dagrun.duration.failed.`` Milliseconds taken for a DagRun to reach failed state ``dagrun.schedule_delay.`` Milliseconds of delay between the scheduled DagRun start date and the actual DagRun start date -``scheduler.critical_section_duration`` Millseconds spent in the critical section of scheduler loop -- +``scheduler.critical_section_duration`` Milliseconds spent in the critical section of scheduler loop -- only a single scheduler can enter this loop at a time =========================================== ================================================================= diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 08e798d0d961d..9315dd3ebb94b 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1102,6 +1102,7 @@ reqs resetdb resourceVersion resumable +resultset rfc ricard rideable From 4ce98c7c12ef7a410d66ca4ddf1af58a5ddc10be Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 16 Sep 2020 21:21:13 +0100 Subject: [PATCH 12/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/models/dag.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 0b7dafe63f367..670710e907eaa 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1773,6 +1773,10 @@ def bulk_sync_to_db(cls, dags: Collection["DAG"], session=None): if settings.STORE_DAG_CODE: DagCode.bulk_sync_to_db([dag.fileloc for dag in orm_dags]) + # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller + # decide when to commit + session.flush() + @provide_session def sync_to_db(self, session=None): """ From ef766b1e601f9da1b5f5eb4af371c7b9b3d900ca Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 16 Sep 2020 21:48:40 +0100 Subject: [PATCH 13/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/jobs/scheduler_job.py | 5 ++++- airflow/models/dag.py | 2 +- airflow/models/dagrun.py | 2 +- tests/jobs/test_scheduler_job.py | 12 +++++++++--- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 40d595fd4ccbd..c352c108b59c1 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -923,7 +923,7 @@ def _executable_task_instances_to_queued( .filter(TI.state == State.SCHEDULED) .options(selectinload('dag_model')) .limit(max_tis) - .with_for_update(**skip_locked(of=TI, session=session)) + .with_for_update(of=TI, **skip_locked(session=session)) .all() ) # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything. @@ -1589,4 +1589,7 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None): self.log.info("Reset the following %s orphaned TaskInstances:\n\t%s", len(to_reset), task_instance_str) + # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller + # decide when to commit + session.flush() return len(to_reset) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 670710e907eaa..92507a4d3a956 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2091,7 +2091,7 @@ def dags_needing_dagruns(cls, session: Session): cls.next_dagrun_create_after <= func.now(), ).order_by( cls.next_dagrun_create_after - ).limit(10).with_for_update(**skip_locked(of=cls, session=session)) + ).limit(10).with_for_update(of=cls, **skip_locked(session=session)) STATICA_HACK = True diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index f7e9cbab86fbb..614bc3cb514bc 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -178,7 +178,7 @@ def next_dagruns_to_examine( ).order_by( cls.last_scheduling_decision, cls.execution_date, - ).limit(max_number).with_for_update(**skip_locked(of=cls, session=session)) + ).limit(max_number).with_for_update(of=cls, **skip_locked(session=session)) return query diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index cfbdbab21b244..9d62f2555432a 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -3579,7 +3579,6 @@ def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self): DummyOperator(task_id='task1', dag=dag) DummyOperator(task_id='task2', dag=dag) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler_job = SchedulerJob() session = settings.Session() scheduler_job.state = State.RUNNING @@ -3592,7 +3591,14 @@ def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self): session.add(old_job) session.flush() - dr1 = dag_file_processor.create_dag_run(dag, session=session) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + start_date=timezone.utcnow(), + state=State.RUNNING, + session=session + ) + ti1, ti2 = dr1.get_task_instances(session=session) dr1.state = State.RUNNING ti1.state = State.SCHEDULED @@ -3606,7 +3612,7 @@ def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self): session.flush() num_reset_tis = scheduler_job.adopt_or_reset_orphaned_tasks(session=session) - session.flush() + self.assertEqual(1, num_reset_tis) session.refresh(ti1) From 87fed6c7c174e682a892ada9001bc3502455757a Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 16 Sep 2020 21:54:12 +0100 Subject: [PATCH 14/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/models/dag.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 92507a4d3a956..1b745629628ab 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1777,6 +1777,9 @@ def bulk_sync_to_db(cls, dags: Collection["DAG"], session=None): # decide when to commit session.flush() + for dag in dags: + cls.bulk_sync_to_db(dag.subdags, session=session) + @provide_session def sync_to_db(self, session=None): """ From 98402c5e385ca4d6f65dcab8f0d58c85e4f92ef7 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Thu, 17 Sep 2020 10:20:07 +0100 Subject: [PATCH 15/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/jobs/scheduler_job.py | 117 ++++++++++++++++------------------ 1 file changed, 56 insertions(+), 61 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index c352c108b59c1..9f74c15113b99 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -740,6 +740,8 @@ def __init__( self.max_tis_per_query: int = conf.getint('scheduler', 'max_tis_per_query') self.processor_agent: Optional[DagFileProcessorAgent] = None + self.dagbag = DagBag(read_dags_from_db=True) + def register_exit_signals(self) -> None: """ Register signals that stop child processes @@ -880,21 +882,13 @@ def __get_concurrency_maps( # pylint: disable=too-many-locals,too-many-statements @provide_session - def _executable_task_instances_to_queued( - self, - max_tis: int, - dag_bag: DagBag, - session: Session = None - ) -> List[TI]: + def _executable_task_instances_to_queued(self, max_tis: int, session: Session = None) -> List[TI]: """ Finds TIs that are ready for execution with respect to pool limits, dag concurrency, executor state, and priority. :param max_tis: Maximum number of TIs to queue in this loop. :type max_tis: int - :param dag_bag: TaskInstances associated with DAGs in the - _dag_bag will be fetched from the DB and executed - :type dag_bag: airflow.models.DagBag :return: list[airflow.models.TaskInstance] """ executable_tis: List[TI] = [] @@ -1014,7 +1008,7 @@ def _executable_task_instances_to_queued( if task_instance.dag_model.has_task_concurrency_limits: # Many dags don't have a task_concurrency, so where we can avoid loading the full # serialized DAG the better. - serialized_dag = dag_bag.get_dag(dag_id) + serialized_dag = self.dagbag.get_dag(dag_id, session=session) if serialized_dag.has_task(task_instance.task_id): task_concurrency_limit = serialized_dag.get_task( task_instance.task_id).task_concurrency @@ -1109,7 +1103,7 @@ def _enqueue_task_instances_with_queued_state( queue=queue, ) - def _critical_section_execute_task_instances(self, dag_bag: DagBag, session: Session) -> int: + def _critical_section_execute_task_instances(self, session: Session) -> int: """ Attempts to execute TaskInstances that should be executed by the scheduler. @@ -1125,15 +1119,12 @@ def _critical_section_execute_task_instances(self, dag_bag: DagBag, session: Ses new DAG runs, progressing TIs from None to SCHEDULED etc.); DBs that don't support this (such as MariaDB or MySQL 5.x) the other schedulers will wait for the lock before continuiung. - :param dag_bag: TaskInstances associated with DAGs in the - dag_bag will be fetched from the DB and executed - :type dag_bag: airflow.models.DagBag :param session: :type session: sqlalchemy.orm.Session :return: Number of task instance with state changed. """ max_tis = min(self.max_tis_per_query, self.executor.slots_available) - queued_tis = self._executable_task_instances_to_queued(max_tis, dag_bag, session=session) + queued_tis = self._executable_task_instances_to_queued(max_tis, session=session) self._enqueue_task_instances_with_queued_state(queued_tis) return len(queued_tis) @@ -1334,8 +1325,6 @@ def _run_scheduler_loop(self) -> None: raise ValueError("Processor agent is not started.") is_unit_test: bool = conf.getboolean('core', 'unit_test_mode') - dag_bag = DagBag() - # For the execute duration, parse and schedule DAGs while True: loop_start_time = time.time() @@ -1348,7 +1337,7 @@ def _run_scheduler_loop(self) -> None: self.processor_agent.wait_until_finished() with create_session() as session: - num_queued_tis = self._do_scheduling(dag_bag, session) + num_queued_tis = self._do_scheduling(session) self.executor.heartbeat() session.expunge_all() @@ -1377,7 +1366,7 @@ def _run_scheduler_loop(self) -> None: ) break - def _do_scheduling(self, dag_bag, session) -> int: + def _do_scheduling(self, session) -> int: """ This function is where the main scheduling decisions take places. It: @@ -1412,54 +1401,16 @@ def validate_commit(_): query = DagModel.dags_needing_dagruns(session) for dag_model in query: - dag = dag_bag.get_dag(dag_model.dag_id, session=session) - next_run_date = dag_model.next_dagrun - dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=next_run_date, - start_date=timezone.utcnow(), - state=State.RUNNING, - external_trigger=False, - session=session - ) - - # Check max_active_runs, to see if we are _now_ at the limit for this dag? (we've just created - # one after all) - active_runs_of_dag = session.query(func.count('*')).filter( - DagRun.dag_id == dag_model.dag_id, - DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable - DagRun.external_trigger.is_(False), - ).scalar() - - # TODO[HA]: add back in dagrun.timeout - - if dag.max_active_runs and dag.max_active_runs >= active_runs_of_dag: - self.log.info( - "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs", - dag.dag_id, active_runs_of_dag, dag.max_active_runs - ) - dag_model.next_dagrun = None - dag_model.next_dagrun_create_after = None - else: - next_dagrun_info = dag.next_dagrun_info(next_run_date) - if next_dagrun_info: - dag_model.next_dagrun = next_dagrun_info['execution_date'] - dag_model.next_dagrun_create_after = next_dagrun_info['can_be_created_after'] - else: - dag_model.next_dagrun = None - dag_model.next_dagrun_create_after = None - - # TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in - # memory for larger dags? or expunge_all() + dag = self.dagbag.get_dag(dag_model.dag_id, session=session) + self._create_dag_run(dag_model, dag, session) # commit the session - Release the write lock on DagModel table. expected_commit = True session.commit() # END: create dagruns - # Tunable limit?, or select multiple dag runs for a single dag? for dag_run in DagRun.next_dagruns_to_examine(session): - dag_run.dag = dag_bag.get_dag(dag_run.dag_id, session=session) + dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) # TODO[HA]: Run verify_integrity, but only if the serialized_dag has changed @@ -1491,7 +1442,7 @@ def validate_commit(_): timer.start() # Find anything TIs in state SCHEDULED, try to QUEUE it (send it to the executor) - num_queued_tis = self._critical_section_execute_task_instances(dag_bag, session=session) + num_queued_tis = self._critical_section_execute_task_instances(session=session) # Make sure we only sent this metric if we obtained the lock, otherwise we'll skew the # metric, way down @@ -1520,6 +1471,50 @@ def validate_commit(_): finally: event.remove(session.bind, 'commit', validate_commit) + def _create_dag_run(self, dag_model, dag, session): + """ + Unconditionally create a DAG run for the given DAG, and update the dag_model's fields to control + if/when the next DAGRun should be created + """ + next_run_date = dag_model.next_dagrun + dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=next_run_date, + start_date=timezone.utcnow(), + state=State.RUNNING, + external_trigger=False, + session=session + ) + + # Check max_active_runs, to see if we are _now_ at the limit for this dag? (we've just created + # one after all) + active_runs_of_dag = session.query(func.count('*')).filter( + DagRun.dag_id == dag_model.dag_id, + DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable + DagRun.external_trigger.is_(False), + ).scalar() + + # TODO[HA]: add back in dagrun.timeout + + if dag.max_active_runs and dag.max_active_runs >= active_runs_of_dag: + self.log.info( + "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs", + dag.dag_id, active_runs_of_dag, dag.max_active_runs + ) + dag_model.next_dagrun = None + dag_model.next_dagrun_create_after = None + else: + next_dagrun_info = dag.next_dagrun_info(next_run_date) + if next_dagrun_info: + dag_model.next_dagrun = next_dagrun_info['execution_date'] + dag_model.next_dagrun_create_after = next_dagrun_info['can_be_created_after'] + else: + dag_model.next_dagrun = None + dag_model.next_dagrun_create_after = None + + # TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in + # memory for larger dags? or expunge_all() + @provide_session def _emit_pool_metrics(self, session: Session = None) -> None: pools = models.Pool.slots_stats(session=session) From 9a891e8dfa851096391419a160e3748966c794e2 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Fri, 18 Sep 2020 23:04:51 +0100 Subject: [PATCH 16/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/cli/cli_parser.py | 2 +- airflow/cli/commands/scheduler_command.py | 1 - airflow/jobs/backfill_job.py | 1 + airflow/jobs/scheduler_job.py | 112 +- airflow/models/dag.py | 29 +- airflow/models/dagbag.py | 12 +- airflow/models/dagrun.py | 8 +- airflow/settings.py | 2 + tests/jobs/test_scheduler_job.py | 1356 ++++++++------------- tests/models/test_dag.py | 508 +++++--- tests/models/test_dagrun.py | 44 +- tests/test_utils/mock_executor.py | 3 +- tests/utils/test_dag_processing.py | 50 +- 13 files changed, 1078 insertions(+), 1050 deletions(-) diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py index 1627900825273..ad526d5c224e2 100644 --- a/airflow/cli/cli_parser.py +++ b/airflow/cli/cli_parser.py @@ -1335,7 +1335,7 @@ class GroupCommand(NamedTuple): help="Start a scheduler instance", func=lazy_load_command('airflow.cli.commands.scheduler_command.scheduler'), args=( - ARG_DAG_ID_OPT, ARG_SUBDIR, ARG_NUM_RUNS, ARG_DO_PICKLE, ARG_PID, ARG_DAEMON, ARG_STDOUT, + ARG_SUBDIR, ARG_NUM_RUNS, ARG_DO_PICKLE, ARG_PID, ARG_DAEMON, ARG_STDOUT, ARG_STDERR, ARG_LOG_FILE ), ), diff --git a/airflow/cli/commands/scheduler_command.py b/airflow/cli/commands/scheduler_command.py index a7109239380fc..f0f019a58ac12 100644 --- a/airflow/cli/commands/scheduler_command.py +++ b/airflow/cli/commands/scheduler_command.py @@ -32,7 +32,6 @@ def scheduler(args): """Starts Airflow Scheduler""" print(settings.HEADER) job = SchedulerJob( - dag_id=args.dag_id, subdir=process_subdir(args.subdir), num_runs=args.num_runs, do_pickle=args.do_pickle) diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py index c60611915ad97..7d8271fa14711 100644 --- a/airflow/jobs/backfill_job.py +++ b/airflow/jobs/backfill_job.py @@ -630,6 +630,7 @@ def _per_task_process(task, key, ti, session=None): # pylint: disable=too-many- _dag_runs = ti_status.active_runs[:] for run in _dag_runs: run.update_state(session=session) + session.merge(run) if run.state in State.finished(): ti_status.finished_runs += 1 ti_status.active_runs.remove(run) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 9f74c15113b99..0a695fdee9d4a 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -18,6 +18,7 @@ # under the License. # import datetime +import itertools import logging import multiprocessing import operator @@ -707,20 +708,12 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes def __init__( self, - dag_id: Optional[str] = None, - dag_ids: Optional[List[str]] = None, subdir: str = settings.DAGS_FOLDER, num_runs: int = conf.getint('scheduler', 'num_runs'), processor_poll_interval: float = conf.getfloat('scheduler', 'processor_poll_interval'), do_pickle: bool = False, log: Any = None, *args, **kwargs): - # for BaseJob compatibility - self.dag_id = dag_id - self.dag_ids = [dag_id] if dag_id else [] - if dag_ids: - self.dag_ids.extend(dag_ids) - self.subdir = subdir self.num_runs = num_runs @@ -898,20 +891,24 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = pools = models.Pool.slots_stats(with_for_update=nowait(session), session=session) # If the pools are full, there is no point doing anything! - max_tis = min(max_tis, sum(map(operator.itemgetter('open'), pools.values()))) + # If _somehow_ the pool is overfull, don't let the limit go negative - it breaks SQL + pool_slots_free = max(0, sum(map(operator.itemgetter('open'), pools.values()))) - if max_tis == 0: + if pool_slots_free == 0: self.log.debug("All pools are full!") return executable_tis + max_tis = min(max_tis, pool_slots_free) + # Get all task instances associated with scheduled # DagRuns which are not backfilled, in the given states, # and the dag is not paused task_instances_to_examine: List[TI] = ( session .query(TI) - .join(TI.dag_run) - .filter(DR.run_type != DagRunType.BACKFILL_JOB.value) + .outerjoin(TI.dag_run) + .filter(or_(DR.run_id.is_(None), + DR.run_type != DagRunType.BACKFILL_JOB.value)) .join(TI.dag_model) .filter(not_(DM.is_paused)) .filter(TI.state == State.SCHEDULED) @@ -1244,7 +1241,7 @@ def _execute(self) -> None: max_runs=self.num_runs, processor_factory=type(self)._create_dag_file_processor, processor_timeout=processor_timeout, - dag_ids=self.dag_ids, + dag_ids=[], pickle_dags=pickle_dags, async_mode=async_mode, ) @@ -1325,8 +1322,7 @@ def _run_scheduler_loop(self) -> None: raise ValueError("Processor agent is not started.") is_unit_test: bool = conf.getboolean('core', 'unit_test_mode') - # For the execute duration, parse and schedule DAGs - while True: + for loop_count in itertools.count(start=1): loop_start_time = time.time() if self.using_sqlite: @@ -1360,9 +1356,10 @@ def _run_scheduler_loop(self) -> None: # usage when "idle" time.sleep(self._processor_poll_interval) - if self.processor_agent.done: + if self.num_runs > 0 and loop_count >= self.num_runs and self.processor_agent.done: self.log.info( - "Exiting scheduler loop as all files have been processed %d times", self.num_runs + "Exiting scheduler loop as requested number of runs (%d - got to %d) has been reached", + self.num_runs, loop_count, ) break @@ -1389,6 +1386,7 @@ def _do_scheduling(self, session) -> int: from sqlalchemy import event expected_commit = False + # Put a check in place to make sure we don't commit unexpectedly @event.listens_for(session.bind, 'commit') def validate_commit(_): nonlocal expected_commit @@ -1397,8 +1395,6 @@ def validate_commit(_): return raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!") - # Put a check in place to make sure we don't commit unexpectedly - query = DagModel.dags_needing_dagruns(session) for dag_model in query: dag = self.dagbag.get_dag(dag_model.dag_id, session=session) @@ -1410,20 +1406,7 @@ def validate_commit(_): # END: create dagruns for dag_run in DagRun.next_dagruns_to_examine(session): - dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) - - # TODO[HA]: Run verify_integrity, but only if the serialized_dag has changed - - # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else? - schedulable_tis = dag_run.update_state(session=session) - # TODO[HA]: Don't return, update these form in update_state - session.query(TI).filter( - TI.dag_id == dag_run.dag_id, - TI.execution_date == dag_run.execution_date, - TI.task_id.in_(ti.task_id for ti in schedulable_tis) - ).update({TI.state: State.SCHEDULED}, synchronize_session=False) - - # TODO[HA]: Manage SLAs + self._schedule_dag_run(dag_run, session) expected_commit = True session.commit() @@ -1438,6 +1421,11 @@ def validate_commit(_): # time? try: + if self.executor.slots_available <= 0: + # We know we can't do anything here, so don't even try! + self.log.debug("Executor full, skipping critical section") + return 0 + timer = Stats.timer('scheduler.critical_section_duration') timer.start() @@ -1471,7 +1459,7 @@ def validate_commit(_): finally: event.remove(session.bind, 'commit', validate_commit) - def _create_dag_run(self, dag_model, dag, session): + def _create_dag_run(self, dag_model: DagModel, dag: DAG, session: Session) -> None: """ Unconditionally create a DAG run for the given DAG, and update the dag_model's fields to control if/when the next DAGRun should be created @@ -1496,7 +1484,7 @@ def _create_dag_run(self, dag_model, dag, session): # TODO[HA]: add back in dagrun.timeout - if dag.max_active_runs and dag.max_active_runs >= active_runs_of_dag: + if dag.max_active_runs and active_runs_of_dag >= dag.max_active_runs: self.log.info( "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs", dag.dag_id, active_runs_of_dag, dag.max_active_runs @@ -1515,6 +1503,60 @@ def _create_dag_run(self, dag_model, dag, session): # TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in # memory for larger dags? or expunge_all() + def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: + """ + Make scheduling decisions about an individual dag run + + :return: Number of tasks scheduled + """ + dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) + + if not dag_run.dag: + self.log.error( + "Couldn't find dag %s in DagBag/DB!", dag_run.dag_id + ) + return 0 + + if dag_run.execution_date > timezone.utcnow() and not dag_run.dag.allow_future_exec_dates: + self.log.error( + "Execution date is in future: %s", + dag_run.execution_date + ) + return 0 + + # TODO: This query is probably horribly inefficient (though there is an + # index on (dag_id,state)). It is to deal with the case when a user + # clears more than max_active_runs older tasks -- we don't want the + # scheduler to suddenly go and start running tasks from all of the + # runs. (AIRFLOW-137/GH #1442) + # + # The longer term fix would be to have `clear` do this, and put DagRuns + # in to the queued state, then take DRs out of queued before creating + # any new ones + if dag_run.dag.max_active_runs: + currently_active_runs = session.query(func.count(TI.execution_date.distinct())).filter( + TI.dag_id == dag_run.dag_id, + TI.state.notin_(State.finished()) + ).scalar() + + if currently_active_runs >= dag_run.dag.max_active_runs: + return 0 + + # TODO[HA]: Run verify_integrity, but only if the serialized_dag has changed + + # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else? + schedulable_tis = dag_run.update_state(session=session) + # TODO[HA]: Don't return, update these from in update_state? + count = session.query(TI).filter( + TI.dag_id == dag_run.dag_id, + TI.execution_date == dag_run.execution_date, + TI.task_id.in_(ti.task_id for ti in schedulable_tis) + ).update({TI.state: State.SCHEDULED}, synchronize_session=False) + + # TODO[HA]: Manage SLAs + + return count + @provide_session def _emit_pool_metrics(self, session: Session = None) -> None: pools = models.Pool.slots_stats(session=session) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 1b745629628ab..fbf1c71638e86 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -476,14 +476,25 @@ def next_dagrun_info(self, date_last_automated_dagrun : Optional[pendulum.DateTi "automated" DagRuns for this dag (scheduled or backfill, but not manual) """ + if (self.schedule_interval == "@once" and date_last_automated_dagrun) or \ + self.schedule_interval is None: + # Manual trigger, or already created the run for @once, can short circuit + return None next_execution_date = self.next_dagrun_after_date(date_last_automated_dagrun) - if next_execution_date is None or self.schedule_interval in (None, '@once'): + if next_execution_date is None: return None + if self.schedule_interval == "@once": + # For "@once" it can be created "now" + return { + 'execution_date': next_execution_date, + 'can_be_created_after': next_execution_date, + } + return { 'execution_date': next_execution_date, - 'can_be_created_after': self.following_schedule(next_execution_date) + 'can_be_created_after': self.following_schedule(next_execution_date), } def next_dagrun_after_date(self, date_last_automated_dagrun: Optional[pendulum.DateTime]): @@ -704,10 +715,7 @@ def owner(self) -> str: @property def allow_future_exec_dates(self) -> bool: - return conf.getboolean( - 'scheduler', - 'allow_trigger_in_future', - fallback=False) and self.schedule_interval is None + return settings.ALLOW_FUTURE_EXEC_DATES and self.schedule_interval is None @provide_session def get_concurrency_reached(self, session=None) -> bool: @@ -1710,6 +1718,7 @@ def bulk_sync_to_db(cls, dags: Collection["DAG"], session=None): or_( DagRun.run_type == DagRunType.BACKFILL_JOB.value, DagRun.run_type == DagRunType.SCHEDULED.value, + DagRun.external_trigger.is_(True), ), ).group_by(DagRun.dag_id).all()) @@ -1965,6 +1974,14 @@ class DagModel(Base): Index('idx_next_dagrun_create_after', next_dagrun_create_after, unique=False), ) + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.concurrency is None: + self.concurrency = conf.getint('core', 'dag_concurrency'), + if self.has_task_concurrency_limits is None: + # Be safe -- this will be updated later once the DAG is parsed + self.has_task_concurrency_limits = True + def __repr__(self): return "".format(self=self) diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 5f3d547626745..3f548a84599e8 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -174,7 +174,10 @@ def get_dag(self, dag_id, session: Session = None): dag_id in self.dags_last_fetched and timezone.utcnow() > self.dags_last_fetched[dag_id] + min_serialized_dag_fetch_secs ): - sd_last_updated_datetime = SerializedDagModel.get_last_updated_datetime(dag_id=dag_id) + sd_last_updated_datetime = SerializedDagModel.get_last_updated_datetime( + dag_id=dag_id, + session=session, + ) if sd_last_updated_datetime > self.dags_last_fetched[dag_id]: self._add_dag_from_db(dag_id=dag_id, session=session) @@ -517,7 +520,8 @@ def dagbag_report(self): """) return report - def sync_to_db(self): + @provide_session + def sync_to_db(self, session: Optional[Session] = None): """ Save attributes about list of DAG to the DB. """ @@ -525,9 +529,9 @@ def sync_to_db(self): from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel self.log.debug("Calling the DAG.bulk_sync_to_db method") - DAG.bulk_sync_to_db(self.dags.values()) + DAG.bulk_sync_to_db(self.dags.values(), session=session) # Write Serialized DAGs to DB if DAG Serialization is turned on # Even though self.read_dags_from_db is False if settings.STORE_SERIALIZED_DAGS or self.read_dags_from_db: self.log.debug("Calling the SerializedDagModel.bulk_sync_to_db method") - SerializedDagModel.bulk_sync_to_db(self.dags.values()) + SerializedDagModel.bulk_sync_to_db(self.dags.values(), session=session) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 614bc3cb514bc..d60e326797e0a 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -26,6 +26,7 @@ from sqlalchemy.orm import backref, relationship, synonym from sqlalchemy.orm.session import Session +from airflow import settings from airflow.configuration import conf as airflow_conf from airflow.exceptions import AirflowException from airflow.models.base import ID_LEN, Base @@ -178,9 +179,12 @@ def next_dagruns_to_examine( ).order_by( cls.last_scheduling_decision, cls.execution_date, - ).limit(max_number).with_for_update(of=cls, **skip_locked(session=session)) + ) + + if not settings.ALLOW_FUTURE_EXEC_DATES: + query = query.filter(DagRun.execution_date <= func.now()) - return query + return query.limit(max_number).with_for_update(of=cls, **skip_locked(session=session)) @staticmethod @provide_session diff --git a/airflow/settings.py b/airflow/settings.py index 0fe1b93fa9c1d..d42a6251bbb5f 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -373,3 +373,5 @@ def initialize(): 'execute_tasks_new_python_interpreter', fallback=False, ) + +ALLOW_FUTURE_EXEC_DATES = conf.getboolean('scheduler', 'allow_trigger_in_future', fallback=False) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 9d62f2555432a..fdb19605fce89 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -29,7 +29,6 @@ import psutil import pytest import six -from freezegun import freeze_time from mock import MagicMock, patch from parameterized import parameterized @@ -125,8 +124,7 @@ def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timed dag.clear() dag.is_subdag = False with create_session() as session: - orm_dag = DagModel(dag_id=dag.dag_id) - orm_dag.is_paused = False + orm_dag = DagModel(dag_id=dag.dag_id, is_paused=False) session.merge(orm_dag) session.commit() return dag @@ -196,58 +194,6 @@ def test_dag_file_processor_sla_miss_callback_invalid_sla(self): dag_file_processor.manage_slas(dag=dag, session=session) sla_callback.assert_not_called() - def test_scheduler_executor_overflow(self): - """ - Test that tasks that are set back to scheduled and removed from the executor - queue in the case of an overflow. - """ - executor = MockExecutor(do_update=True, parallelism=3) - - with create_session() as session: - dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), - include_examples=False, - include_smart_sensor=False) - dag = self.create_test_dag() - dag.clear() - dagbag.bag_dag(dag=dag, root_dag=dag) - dag = self.create_test_dag() - dag.clear() - task = DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - tis = [] - for i in range(1, 10): - ti = TaskInstance(task, DEFAULT_DATE + timedelta(days=i)) - ti.state = State.SCHEDULED - tis.append(ti) - session.merge(ti) - - # scheduler._process_dags(simple_dag_bag) - @mock.patch('airflow.jobs.scheduler_job.DagBag', return_value=dagbag) - @mock.patch('airflow.jobs.scheduler_job.SchedulerJob._change_state_for_tis_without_dagrun') - def do_schedule(mock_dagbag, mock_change_state): - # Use a empty file since the above mock will return the - # expected DAGs. Also specify only a single file so that it doesn't - # try to schedule the above DAG repeatedly. - with conf_vars({('core', 'mp_start_method'): 'fork'}): - scheduler = SchedulerJob(num_runs=1, - executor=executor, - subdir=os.path.join(settings.DAGS_FOLDER, - "no_dags.py")) - scheduler.heartrate = 0 - scheduler.run() - - do_schedule() # pylint: disable=no-value-for-parameter - for ti in tis: - ti.refresh_from_db() - self.assertEqual(len(executor.queued_tasks), 0) - - successful_tasks = [ti for ti in tis if ti.state == State.SUCCESS] - scheduled_tasks = [ti for ti in tis if ti.state == State.SCHEDULED] - self.assertEqual(3, len(successful_tasks)) - self.assertEqual(6, len(scheduled_tasks)) - def test_dag_file_processor_sla_miss_callback_sent_notification(self): """ Test that the dag file processor does not call the sla_miss_callback when a @@ -420,70 +366,6 @@ def test_dag_file_processor_sla_miss_deleted_task(self): dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) dag_file_processor.manage_slas(dag=dag, session=session) - def test_dag_file_processor_dagrun_once(self): - """ - Test if the dag file proccessor does not create multiple dagruns - if a dag is scheduled with @once and a start_date - """ - dag = DAG( - 'test_scheduler_dagrun_once', - start_date=timezone.datetime(2015, 1, 1), - schedule_interval="@once") - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag.clear() - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNone(dr) - - @freeze_time(timezone.datetime(2020, 1, 5)) - def test_dag_file_processor_dagrun_with_timedelta_schedule_and_catchup_false(self): - """ - Test that the dag file processor does not create multiple dagruns - if a dag is scheduled with 'timedelta' and catchup=False - """ - dag = DAG( - 'test_scheduler_dagrun_once_with_timedelta_and_catchup_false', - start_date=timezone.datetime(2015, 1, 1), - schedule_interval=timedelta(days=1), - catchup=False) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - self.assertEqual(dr.execution_date, timezone.datetime(2020, 1, 4)) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNone(dr) - - @freeze_time(timezone.datetime(2020, 5, 4)) - def test_dag_file_processor_dagrun_with_timedelta_schedule_and_catchup_true(self): - """ - Test that the dag file processor creates multiple dagruns - if a dag is scheduled with 'timedelta' and catchup=True - """ - dag = DAG( - 'test_scheduler_dagrun_once_with_timedelta_and_catchup_true', - start_date=timezone.datetime(2020, 5, 1), - schedule_interval=timedelta(days=1), - catchup=True) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - self.assertEqual(dr.execution_date, timezone.datetime(2020, 5, 1)) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - self.assertEqual(dr.execution_date, timezone.datetime(2020, 5, 2)) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - self.assertEqual(dr.execution_date, timezone.datetime(2020, 5, 3)) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNone(dr) - @parameterized.expand([ [State.NONE, None, None], [State.UP_FOR_RETRY, timezone.utcnow() - datetime.timedelta(minutes=30), @@ -499,7 +381,7 @@ def test_dag_file_processor_process_task_instances(self, state, start_date, end_ dag = DAG( dag_id='test_scheduler_process_execute_task', start_date=DEFAULT_DATE) - dag_task1 = DummyOperator( + DummyOperator( task_id='dummy', dag=dag, owner='airflow') @@ -510,24 +392,27 @@ def test_dag_file_processor_process_task_instances(self, state, start_date, end_ dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + scheduler = SchedulerJob() + scheduler.dagbag.bag_dag(dag, root_dag=dag) dag.clear() - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + assert dr is not None with create_session() as session: - tis = dr.get_task_instances(session=session) - for ti in tis: - ti.state = state - ti.start_date = start_date - ti.end_date = end_date + ti = dr.get_task_instances(session=session)[0] + ti.state = state + ti.start_date = start_date + ti.end_date = end_date - mock_list = dag_file_processor._process_task_instances(dag, dag_runs=[dr]) + count = scheduler._schedule_dag_run(dr, session) + assert count == 1 - self.assertEqual( - [(dag.dag_id, dag_task1.task_id, DEFAULT_DATE, TRY_NUMBER)], - mock_list - ) + session.refresh(ti) + assert ti.state == State.SCHEDULED @parameterized.expand([ [State.NONE, None, None], @@ -546,7 +431,7 @@ def test_dag_file_processor_process_task_instances_with_task_concurrency( dag = DAG( dag_id='test_scheduler_process_execute_task_with_task_concurrency', start_date=DEFAULT_DATE) - dag_task1 = DummyOperator( + DummyOperator( task_id='dummy', task_concurrency=2, dag=dag, @@ -558,23 +443,27 @@ def test_dag_file_processor_process_task_instances_with_task_concurrency( dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + scheduler = SchedulerJob() + scheduler.dagbag.bag_dag(dag, root_dag=dag) dag.clear() - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + assert dr is not None with create_session() as session: - tis = dr.get_task_instances(session=session) - for ti in tis: - ti.state = state - ti.start_date = start_date - ti.end_date = end_date + ti = dr.get_task_instances(session=session)[0] + ti.state = state + ti.start_date = start_date + ti.end_date = end_date - ti_to_schedule = dag_file_processor._process_task_instances(dag, dag_runs=[dr]) + count = scheduler._schedule_dag_run(dr, session) + assert count == 1 - assert ti_to_schedule == [ - (dag.dag_id, dag_task1.task_id, DEFAULT_DATE, TRY_NUMBER), - ] + session.refresh(ti) + assert ti.state == State.SCHEDULED @parameterized.expand([ [State.NONE, None, None], @@ -595,11 +484,11 @@ def test_dag_file_processor_process_task_instances_depends_on_past(self, state, 'depends_on_past': True, }, ) - dag_task1 = DummyOperator( + DummyOperator( task_id='dummy1', dag=dag, owner='airflow') - dag_task2 = DummyOperator( + DummyOperator( task_id='dummy2', dag=dag, owner='airflow') @@ -610,10 +499,15 @@ def test_dag_file_processor_process_task_instances_depends_on_past(self, state, dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + scheduler = SchedulerJob() + scheduler.dagbag.bag_dag(dag, root_dag=dag) dag.clear() - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + assert dr is not None with create_session() as session: tis = dr.get_task_instances(session=session) @@ -622,13 +516,15 @@ def test_dag_file_processor_process_task_instances_depends_on_past(self, state, ti.start_date = start_date ti.end_date = end_date - ti_to_schedule = dag_file_processor._process_task_instances(dag, dag_runs=[dr]) + count = scheduler._schedule_dag_run(dr, session) + assert count == 2 - assert sorted(ti_to_schedule) == [ - (dag.dag_id, dag_task1.task_id, DEFAULT_DATE, TRY_NUMBER), - (dag.dag_id, dag_task2.task_id, DEFAULT_DATE, TRY_NUMBER), - ] + session.refresh(tis[0]) + session.refresh(tis[1]) + assert tis[0].state == State.SCHEDULED + assert tis[1].state == State.SCHEDULED + @pytest.mark.xfail(run=False, reason="TODO[HA]") def test_dag_file_processor_do_not_schedule_removed_task(self): dag = DAG( dag_id='test_scheduler_do_not_schedule_removed_task', @@ -663,81 +559,7 @@ def test_dag_file_processor_do_not_schedule_removed_task(self): self.assertEqual([], mock_list) - def test_dag_file_processor_do_not_schedule_too_early(self): - dag = DAG( - dag_id='test_scheduler_do_not_schedule_too_early', - start_date=timezone.datetime(2200, 1, 1)) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - session.close() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNone(dr) - - mock_list = dag_file_processor._process_task_instances(dag, dag_runs=[]) - self.assertEqual([], mock_list) - - def test_dag_file_processor_do_not_schedule_without_tasks(self): - dag = DAG( - dag_id='test_scheduler_do_not_schedule_without_tasks', - start_date=DEFAULT_DATE) - - with create_session() as session: - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear(session=session) - dag.start_date = None - dr = dag_file_processor.create_dag_run(dag, session=session) - self.assertIsNone(dr) - - def test_dag_file_processor_do_not_run_finished(self): - dag = DAG( - dag_id='test_scheduler_do_not_run_finished', - start_date=DEFAULT_DATE) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - - tis = dr.get_task_instances(session=session) - for ti in tis: - ti.state = State.SUCCESS - - session.commit() - session.close() - - mock_list = dag_file_processor._process_task_instances(dag, dag_runs=[dr]) - - self.assertEqual([], mock_list) - + @pytest.mark.xfail(run=False, reason="TODO[HA]") def test_dag_file_processor_add_new_task(self): """ Test if a task instance will be added if the dag is updated @@ -779,36 +601,7 @@ def test_dag_file_processor_add_new_task(self): tis = dr.get_task_instances() self.assertEqual(len(tis), 2) - def test_dag_file_processor_verify_max_active_runs(self): - """ - Test if a a dagrun will not be scheduled if max_dag_runs has been reached - """ - dag = DAG( - dag_id='test_scheduler_verify_max_active_runs', - start_date=DEFAULT_DATE) - dag.max_active_runs = 1 - - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - session.close() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNone(dr) - + @pytest.mark.xfail(run=False, reason="TODO[HA]") def test_dag_file_processor_fail_dagrun_timeout(self): """ Test if a a dagrun wil be set failed if timeout @@ -845,6 +638,7 @@ def test_dag_file_processor_fail_dagrun_timeout(self): dr.refresh_from_db(session=session) self.assertEqual(dr.state, State.FAILED) + @pytest.mark.xfail(run=False, reason="TODO[HA]") def test_dag_file_processor_verify_max_active_runs_and_dagrun_timeout(self): """ Test if a a dagrun will not be scheduled if max_dag_runs @@ -889,7 +683,7 @@ def test_dag_file_processor_verify_max_active_runs_and_dagrun_timeout(self): new_dr = dag_file_processor.create_dag_run(dag) self.assertIsNotNone(new_dr) - def test_dag_file_processor_max_active_runs_respected_after_clear(self): + def test_runs_respected_after_clear(self): """ Test if _process_task_instances only schedules ti's up to max_active_runs (related to issue AIRFLOW-137) @@ -899,7 +693,7 @@ def test_dag_file_processor_max_active_runs_respected_after_clear(self): start_date=DEFAULT_DATE) dag.max_active_runs = 3 - dag_task1 = DummyOperator( + DummyOperator( task_id='dummy', dag=dag, owner='airflow') @@ -911,15 +705,32 @@ def test_dag_file_processor_max_active_runs_respected_after_clear(self): session.close() dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + scheduler = SchedulerJob() + scheduler.dagbag.bag_dag(dag, root_dag=dag) dag.clear() + date = DEFAULT_DATE + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + date = dag.following_schedule(date) + dr2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + date = dag.following_schedule(date) + dr3 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + # First create up to 3 dagruns in RUNNING state. - dr1 = dag_file_processor.create_dag_run(dag) assert dr1 is not None - dr2 = dag_file_processor.create_dag_run(dag) assert dr2 is not None - dr3 = dag_file_processor.create_dag_run(dag) assert dr3 is not None assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 3 @@ -928,150 +739,15 @@ def test_dag_file_processor_max_active_runs_respected_after_clear(self): # and schedule them in, so we can check how many # tasks are put on the task_instances_list (should be one, not 3) - task_instances_list = dag_file_processor._process_task_instances(dag, dag_runs=[dr1, dr2, dr3]) - - self.assertEqual([(dag.dag_id, dag_task1.task_id, DEFAULT_DATE, TRY_NUMBER)], task_instances_list) - - def test_find_dags_to_run_includes_subdags(self): - dag = self.dagbag.get_dag('test_subdag_operator') - self.assertGreater(len(dag.subdags), 0) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dags = dag_file_processor._find_dags_to_process(self.dagbag.dags.values()) - - self.assertIn(dag, dags) - for subdag in dag.subdags: - self.assertIn(subdag, dags) - - def test_dag_catchup_option(self): - """ - Test to check that a DAG with catchup = False only schedules beginning now, not back to the start date - """ - - def setup_dag(dag_id, schedule_interval, start_date, catchup): - default_args = { - 'owner': 'airflow', - 'depends_on_past': False, - 'start_date': start_date - } - dag = DAG(dag_id, - schedule_interval=schedule_interval, - max_active_runs=1, - catchup=catchup, - default_args=default_args) - - op1 = DummyOperator(task_id='t1', dag=dag) - op2 = DummyOperator(task_id='t2', dag=dag) - op2.set_upstream(op1) - op3 = DummyOperator(task_id='t3', dag=dag) - op3.set_upstream(op2) - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - session.close() - - return SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - now = timezone.utcnow() - six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace( - minute=0, second=0, microsecond=0) - half_an_hour_ago = now - datetime.timedelta(minutes=30) - two_hours_ago = now - datetime.timedelta(hours=2) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - - dag1 = setup_dag(dag_id='dag_with_catchup', - schedule_interval='* * * * *', - start_date=six_hours_ago_to_the_hour, - catchup=True) - default_catchup = conf.getboolean('scheduler', 'catchup_by_default') - self.assertEqual(default_catchup, True) - self.assertEqual(dag1.catchup, True) - - dag2 = setup_dag(dag_id='dag_without_catchup_ten_minute', - schedule_interval='*/10 * * * *', - start_date=six_hours_ago_to_the_hour, - catchup=False) - dr = dag_file_processor.create_dag_run(dag2) - # We had better get a dag run - self.assertIsNotNone(dr) - # The DR should be scheduled in the last half an hour, not 6 hours ago - self.assertGreater(dr.execution_date, half_an_hour_ago) - # The DR should be scheduled BEFORE now - self.assertLess(dr.execution_date, timezone.utcnow()) - - dag3 = setup_dag(dag_id='dag_without_catchup_hourly', - schedule_interval='@hourly', - start_date=six_hours_ago_to_the_hour, - catchup=False) - dr = dag_file_processor.create_dag_run(dag3) - # We had better get a dag run - self.assertIsNotNone(dr) - # The DR should be scheduled in the last 2 hours, not 6 hours ago - self.assertGreater(dr.execution_date, two_hours_ago) - # The DR should be scheduled BEFORE now - self.assertLess(dr.execution_date, timezone.utcnow()) - - dag4 = setup_dag(dag_id='dag_without_catchup_once', - schedule_interval='@once', - start_date=six_hours_ago_to_the_hour, - catchup=False) - dr = dag_file_processor.create_dag_run(dag4) - self.assertIsNotNone(dr) - - def test_dag_file_processor_auto_align(self): - """ - Test if the schedule_interval will be auto aligned with the start_date - such that if the start_date coincides with the schedule the first - execution_date will be start_date, otherwise it will be start_date + - interval. - """ - dag = DAG( - dag_id='test_scheduler_auto_align_1', - start_date=timezone.datetime(2016, 1, 1, 10, 10, 0), - schedule_interval="4 5 * * *" - ) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - self.assertEqual(dr.execution_date, timezone.datetime(2016, 1, 2, 5, 4)) - - dag = DAG( - dag_id='test_scheduler_auto_align_2', - start_date=timezone.datetime(2016, 1, 1, 10, 10, 0), - schedule_interval="10 10 * * *" - ) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dag.clear() - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - self.assertEqual(dr.execution_date, timezone.datetime(2016, 1, 1, 10, 10)) - + with create_session() as session: + num_scheduled = scheduler._schedule_dag_run(dr1, session) + assert num_scheduled == 1 + num_scheduled = scheduler._schedule_dag_run(dr2, session) + assert num_scheduled == 0 + num_scheduled = scheduler._schedule_dag_run(dr3, session) + assert num_scheduled == 0 + + @pytest.mark.xfail(run=False, reason="TODO[HA]") def test_process_dags_not_create_dagrun_for_subdags(self): dag = self.dagbag.get_dag('test_subdag_operator') @@ -1154,34 +830,7 @@ def test_process_file_should_failure_callback(self): self.assertEqual("Callback fired", content) os.remove(callback_file.name) - def test_should_parse_only_unpaused_dags(self): - dag_file = os.path.join( - os.path.dirname(os.path.realpath(__file__)), '../dags/test_multiple_dags.py' - ) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dagbag = DagBag(dag_folder=dag_file, include_examples=False) - dagbag.sync_to_db() - with create_session() as session: - session.query(TaskInstance).delete() - ( - session.query(DagModel) - .filter(DagModel.dag_id == "test_multiple_dags__dag_1") - .update({DagModel.is_paused: True}, synchronize_session=False) - ) - - serialized_dags, import_errors_count = dag_file_processor.process_file( - file_path=dag_file, failure_callback_requests=[] - ) - - dags = [SerializedDAG.from_dict(serialized_dag) for serialized_dag in serialized_dags] - - with create_session() as session: - tis = session.query(TaskInstance).all() - - self.assertEqual(0, import_errors_count) - self.assertEqual(['test_multiple_dags__dag_2'], [dag.dag_id for dag in dags]) - self.assertEqual({'test_multiple_dags__dag_2'}, {ti.dag_id for ti in tis}) - + @pytest.mark.skip def test_should_mark_dummy_task_as_success(self): dag_file = os.path.join( os.path.dirname(os.path.realpath(__file__)), '../dags/test_only_dummy_tasks.py' @@ -1425,7 +1074,6 @@ def run_single_scheduler_loop_with_no_dags(self, dags_folder): """ scheduler = SchedulerJob( executor=self.null_exec, - dag_id='this_dag_doesnt_exist', # We don't want to actually run anything num_runs=1, subdir=os.path.join(dags_folder)) scheduler.heartrate = 0 @@ -1461,31 +1109,23 @@ def test_process_executor_events(self, mock_stats_incr): DummyOperator(dag=dag2, task_id=task_id_1) dag.fileloc = "/test_path1/" dag2.fileloc = "/test_path1/" - dagbag1 = self._make_simple_dag_bag([dag]) - dagbag2 = self._make_simple_dag_bag([dag2]) - scheduler = SchedulerJob() + executor = MockExecutor(do_update=False) + scheduler = SchedulerJob(executor=executor) + scheduler.processor_agent = mock.MagicMock() + session = settings.Session() + dag.sync_to_db(session=session) + dag2.sync_to_db(session=session) ti1 = TaskInstance(task1, DEFAULT_DATE) ti1.state = State.QUEUED session.merge(ti1) session.commit() - executor = MockExecutor(do_update=False) executor.event_buffer[ti1.key] = State.FAILED, None - scheduler.executor = executor - - scheduler.processor_agent = mock.MagicMock() - # dag bag does not contain dag_id - scheduler._process_executor_events(simple_dag_bag=dagbag2) - ti1.refresh_from_db() - self.assertEqual(ti1.state, State.QUEUED) - scheduler.processor_agent.send_callback_to_execute.assert_not_called() - - # dag bag does contain dag_id - scheduler._process_executor_events(simple_dag_bag=dagbag1) + scheduler._process_executor_events(session=session) ti1.refresh_from_db() self.assertEqual(ti1.state, State.QUEUED) scheduler.processor_agent.send_callback_to_execute.assert_called_once_with( @@ -1504,7 +1144,7 @@ def test_process_executor_events(self, mock_stats_incr): session.commit() executor.event_buffer[ti1.key] = State.SUCCESS, None - scheduler._process_executor_events(simple_dag_bag=dagbag1) + scheduler._process_executor_events(session=session) ti1.refresh_from_db() self.assertEqual(ti1.state, State.SUCCESS) scheduler.processor_agent.send_callback_to_execute.assert_not_called() @@ -1517,16 +1157,13 @@ def test_process_executor_events_uses_inmemory_try_number(self): task_id = "task_id" try_number = 42 - scheduler = SchedulerJob() executor = MagicMock() + scheduler = SchedulerJob(executor=executor) + scheduler.processor_agent = MagicMock() event_buffer = { TaskInstanceKey(dag_id, task_id, execution_date, try_number): (State.SUCCESS, None) } executor.get_event_buffer.return_value = event_buffer - scheduler.executor = executor - - processor_agent = MagicMock() - scheduler.processor_agent = processor_agent dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task = DummyOperator(dag=dag, task_id=task_id) @@ -1536,7 +1173,7 @@ def test_process_executor_events_uses_inmemory_try_number(self): ti.state = State.SUCCESS session.merge(ti) - scheduler._process_executor_events(simple_dag_bag=MagicMock()) + scheduler._process_executor_events() # Assert that the even_buffer is empty so the task was popped using right # task instance key self.assertEqual(event_buffer, {}) @@ -1548,27 +1185,33 @@ def test_execute_task_instances_is_paused_wont_execute(self): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task1 = DummyOperator(dag=dag, task_id=task_id_1) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag) + dagmodel = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + dr1 = dag.create_dagrun( + run_type=DagRunType.BACKFILL_JOB, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) ti1 = TaskInstance(task1, DEFAULT_DATE) ti1.state = State.SCHEDULED - dr1.state = State.RUNNING - dagmodel = DagModel() - dagmodel.dag_id = dag_id - dagmodel.is_paused = True session.merge(ti1) session.merge(dr1) session.add(dagmodel) - session.commit() + session.flush() - scheduler._execute_task_instances(dagbag) + scheduler._critical_section_execute_task_instances(session) + session.flush() ti1.refresh_from_db() self.assertEqual(State.SCHEDULED, ti1.state) + session.rollback() def test_execute_task_instances_no_dagrun_task_will_execute(self): """ @@ -1581,22 +1224,32 @@ def test_execute_task_instances_no_dagrun_task_will_execute(self): task1 = DummyOperator(dag=dag, task_id=task_id_1) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dag_file_processor.create_dag_run(dag) + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) ti1 = TaskInstance(task1, DEFAULT_DATE) ti1.state = State.SCHEDULED ti1.execution_date = ti1.execution_date + datetime.timedelta(days=1) session.merge(ti1) - session.commit() + session.flush() - scheduler._execute_task_instances(dagbag) + scheduler._critical_section_execute_task_instances(session) + session.flush() ti1.refresh_from_db() self.assertEqual(State.QUEUED, ti1.state) + session.rollback() def test_execute_task_instances_backfill_tasks_wont_execute(self): """ @@ -1608,26 +1261,36 @@ def test_execute_task_instances_backfill_tasks_wont_execute(self): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task1 = DummyOperator(dag=dag, task_id=task_id_1) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag) - dr1.run_type = DagRunType.BACKFILL_JOB.value + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr1 = dag.create_dagrun( + run_type=DagRunType.BACKFILL_JOB, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) ti1 = TaskInstance(task1, dr1.execution_date) ti1.refresh_from_db() ti1.state = State.SCHEDULED session.merge(ti1) session.merge(dr1) - session.commit() + session.flush() self.assertTrue(dr1.is_backfill) - scheduler._execute_task_instances(dagbag) + scheduler._critical_section_execute_task_instances(session) + session.flush() ti1.refresh_from_db() self.assertEqual(State.SCHEDULED, ti1.state) + session.rollback() def test_find_executable_task_instances_backfill_nodagrun(self): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_backfill_nodagrun' @@ -1635,15 +1298,27 @@ def test_find_executable_task_instances_backfill_nodagrun(self): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) task1 = DummyOperator(dag=dag, task_id=task_id_1) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag) - dr2 = dag_file_processor.create_dag_run(dag) - dr2.run_type = DagRunType.BACKFILL_JOB.value + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + dr2 = dag.create_dagrun( + run_type=DagRunType.BACKFILL_JOB, + execution_date=dag.following_schedule(dr1.execution_date), + state=State.RUNNING, + ) ti_no_dagrun = TaskInstance(task1, DEFAULT_DATE - datetime.timedelta(days=1)) ti_backfill = TaskInstance(task1, dr2.execution_date) @@ -1657,16 +1332,15 @@ def test_find_executable_task_instances_backfill_nodagrun(self): session.merge(ti_no_dagrun) session.merge(ti_backfill) session.merge(ti_with_dagrun) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(2, len(res)) res_keys = map(lambda x: x.key, res) self.assertIn(ti_no_dagrun.key, res_keys) self.assertIn(ti_with_dagrun.key, res_keys) + session.rollback() def test_find_executable_task_instances_pool(self): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_pool' @@ -1676,14 +1350,27 @@ def test_find_executable_task_instances_pool(self): task1 = DummyOperator(dag=dag, task_id=task_id_1, pool='a') task2 = DummyOperator(dag=dag, task_id=task_id_2, pool='b') dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag) - dr2 = dag_file_processor.create_dag_run(dag) + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + dr2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr1.execution_date), + state=State.RUNNING, + ) tis = ([ TaskInstance(task1, dr1.execution_date), @@ -1698,12 +1385,10 @@ def test_find_executable_task_instances_pool(self): pool2 = Pool(pool='b', slots=100, description='haha') session.add(pool) session.add(pool2) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) - session.commit() + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) + session.flush() self.assertEqual(3, len(res)) res_keys = [] for ti in res: @@ -1711,6 +1396,7 @@ def test_find_executable_task_instances_pool(self): self.assertIn(tis[0].key, res_keys) self.assertIn(tis[1].key, res_keys) self.assertIn(tis[3].key, res_keys) + session.rollback() def test_find_executable_task_instances_in_default_pool(self): set_default_pool_slots(1) @@ -1720,40 +1406,50 @@ def test_find_executable_task_instances_in_default_pool(self): op1 = DummyOperator(dag=dag, task_id='dummy1') op2 = DummyOperator(dag=dag, task_id='dummy2') dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) executor = MockExecutor(do_update=True) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob(executor=executor) - dr1 = dag_file_processor.create_dag_run(dag) - dr2 = dag_file_processor.create_dag_run(dag) + session = settings.Session() + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + dr2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr1.execution_date), + state=State.RUNNING, + ) ti1 = TaskInstance(task=op1, execution_date=dr1.execution_date) ti2 = TaskInstance(task=op2, execution_date=dr2.execution_date) ti1.state = State.SCHEDULED ti2.state = State.SCHEDULED - session = settings.Session() session.merge(ti1) session.merge(ti2) - session.commit() + session.flush() # Two tasks w/o pool up for execution and our default pool size is 1 - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(1, len(res)) ti2.state = State.RUNNING session.merge(ti2) - session.commit() + session.flush() # One task w/o pool up for execution and one task task running - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(0, len(res)) + session.rollback() session.close() def test_nonexistent_pool(self): @@ -1762,24 +1458,32 @@ def test_nonexistent_pool(self): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) task = DummyOperator(dag=dag, task_id=task_id, pool="this_pool_doesnt_exist") dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr = dag_file_processor.create_dag_run(dag) + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) ti = TaskInstance(task, dr.execution_date) ti.state = State.SCHEDULED session.merge(ti) session.commit() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) - session.commit() + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) + session.flush() self.assertEqual(0, len(res)) + session.rollback() def test_find_executable_task_instances_none(self): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_none' @@ -1787,18 +1491,28 @@ def test_find_executable_task_instances_none(self): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) DummyOperator(dag=dag, task_id=task_id_1) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dag_file_processor.create_dag_run(dag) - session.commit() + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + session.flush() - self.assertEqual(0, len(scheduler._find_executable_task_instances( - dagbag, + self.assertEqual(0, len(scheduler._executable_task_instances_to_queued( + max_tis=32, session=session))) + session.rollback() def test_find_executable_task_instances_concurrency(self): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_concurrency' @@ -1806,15 +1520,32 @@ def test_find_executable_task_instances_concurrency(self): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2) task1 = DummyOperator(dag=dag, task_id=task_id_1) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag) - dr2 = dag_file_processor.create_dag_run(dag) - dr3 = dag_file_processor.create_dag_run(dag) + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + dr2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr1.execution_date), + state=State.RUNNING, + ) + dr3 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr2.execution_date), + state=State.RUNNING, + ) ti1 = TaskInstance(task1, dr1.execution_date) ti2 = TaskInstance(task1, dr2.execution_date) @@ -1826,11 +1557,9 @@ def test_find_executable_task_instances_concurrency(self): session.merge(ti2) session.merge(ti3) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(1, len(res)) res_keys = map(lambda x: x.key, res) @@ -1838,13 +1567,12 @@ def test_find_executable_task_instances_concurrency(self): ti2.state = State.RUNNING session.merge(ti2) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(0, len(res)) + session.rollback() def test_find_executable_task_instances_concurrency_queued(self): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_concurrency_queued' @@ -1853,12 +1581,21 @@ def test_find_executable_task_instances_concurrency_queued(self): task2 = DummyOperator(dag=dag, task_id='dummy2') task3 = DummyOperator(dag=dag, task_id='dummy3') dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dag_run = dag_file_processor.create_dag_run(dag) + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dag_run = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) ti1 = TaskInstance(task1, dag_run.execution_date) ti2 = TaskInstance(task2, dag_run.execution_date) @@ -1871,15 +1608,15 @@ def test_find_executable_task_instances_concurrency_queued(self): session.merge(ti2) session.merge(ti3) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(1, len(res)) self.assertEqual(res[0].key, ti3.key) + session.rollback() + # TODO: This is a hack, I think I need to just remove the setting and have it on always def test_find_executable_task_instances_task_concurrency(self): # pylint: disable=too-many-statements dag_id = 'SchedulerJobTest.test_find_executable_task_instances_task_concurrency' task_id_1 = 'dummy' @@ -1887,17 +1624,29 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) task1 = DummyOperator(dag=dag, task_id=task_id_1, task_concurrency=2) task2 = DummyOperator(dag=dag, task_id=task_id_2) - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) executor = MockExecutor(do_update=True) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob(executor=executor) session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag) - dr2 = dag_file_processor.create_dag_run(dag) - dr3 = dag_file_processor.create_dag_run(dag) + scheduler.dagbag.bag_dag(dag, root_dag=dag) + scheduler.dagbag.sync_to_db(session=session) + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + dr2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr1.execution_date), + state=State.RUNNING, + ) + dr3 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr2.execution_date), + state=State.RUNNING, + ) ti1_1 = TaskInstance(task1, dr1.execution_date) ti2 = TaskInstance(task2, dr1.execution_date) @@ -1906,11 +1655,9 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab ti2.state = State.SCHEDULED session.merge(ti1_1) session.merge(ti2) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(2, len(res)) @@ -1921,11 +1668,9 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab session.merge(ti1_1) session.merge(ti2) session.merge(ti1_2) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(1, len(res)) @@ -1934,11 +1679,9 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab ti1_3.state = State.SCHEDULED session.merge(ti1_2) session.merge(ti1_3) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(0, len(res)) @@ -1948,11 +1691,9 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab session.merge(ti1_1) session.merge(ti1_2) session.merge(ti1_3) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(2, len(res)) @@ -1962,20 +1703,12 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab session.merge(ti1_1) session.merge(ti1_2) session.merge(ti1_3) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(1, len(res)) - - def test_change_state_for_executable_task_instances_no_tis(self): - scheduler = SchedulerJob() - session = settings.Session() - res = scheduler._change_state_for_executable_task_instances( - [], session) - self.assertEqual(0, len(res)) + session.rollback() def test_change_state_for_executable_task_instances_no_tis_with_state(self): dag_id = 'SchedulerJobTest.test_change_state_for__no_tis_with_state' @@ -1988,10 +1721,24 @@ def test_change_state_for_executable_task_instances_no_tis_with_state(self): scheduler = SchedulerJob() session = settings.Session() - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dr1 = dag_file_processor.create_dag_run(dag) - dr2 = dag_file_processor.create_dag_run(dag) - dr3 = dag_file_processor.create_dag_run(dag) + date = DEFAULT_DATE + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + date = dag.following_schedule(date) + dr2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + date = dag.following_schedule(date) + dr3 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) ti1 = TaskInstance(task1, dr1.execution_date) ti2 = TaskInstance(task1, dr2.execution_date) @@ -2003,57 +1750,47 @@ def test_change_state_for_executable_task_instances_no_tis_with_state(self): session.merge(ti2) session.merge(ti3) - session.commit() + session.flush() - res = scheduler._change_state_for_executable_task_instances( - [ti1, ti2, ti3], - session) + res = scheduler._executable_task_instances_to_queued(max_tis=100, session=session) self.assertEqual(0, len(res)) + session.rollback() + def test_enqueue_task_instances_with_queued_state(self): dag_id = 'SchedulerJobTest.test_enqueue_task_instances_with_queued_state' task_id_1 = 'dummy' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task1 = DummyOperator(dag=dag, task_id=task_id_1) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag) - + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) ti1 = TaskInstance(task1, dr1.execution_date) + ti1.dag_model = dag_model session.merge(ti1) - session.commit() + session.flush() with patch.object(BaseExecutor, 'queue_command') as mock_queue_command: - scheduler._enqueue_task_instances_with_queued_state(dagbag, [ti1]) + scheduler._enqueue_task_instances_with_queued_state([ti1]) assert mock_queue_command.called + session.rollback() - def test_execute_task_instances_nothing(self): - dag_id = 'SchedulerJobTest.test_execute_task_instances_nothing' - task_id_1 = 'dummy' - dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2) - task1 = DummyOperator(dag=dag, task_id=task_id_1) - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = SimpleDagBag([]) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - scheduler = SchedulerJob() - session = settings.Session() - - dr1 = dag_file_processor.create_dag_run(dag) - ti1 = TaskInstance(task1, dr1.execution_date) - ti1.state = State.SCHEDULED - session.merge(ti1) - session.commit() - - self.assertEqual(0, scheduler._execute_task_instances(dagbag)) - - def test_execute_task_instances(self): + def test_critical_section_execute_task_instances(self): dag_id = 'SchedulerJobTest.test_execute_task_instances' task_id_1 = 'dummy_task' task_id_2 = 'dummy_task_nonexistent_queue' @@ -2065,14 +1802,24 @@ def test_execute_task_instances(self): task1 = DummyOperator(dag=dag, task_id=task_id_1) task2 = DummyOperator(dag=dag, task_id=task_id_2) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() # create first dag run with 1 running and 1 queued - dr1 = dag_file_processor.create_dag_run(dag) + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + ti1 = TaskInstance(task1, dr1.execution_date) ti2 = TaskInstance(task2, dr1.execution_date) ti1.refresh_from_db() @@ -2081,7 +1828,7 @@ def test_execute_task_instances(self): ti2.state = State.RUNNING session.merge(ti1) session.merge(ti2) - session.commit() + session.flush() self.assertEqual(State.RUNNING, dr1.state) self.assertEqual( @@ -2092,7 +1839,11 @@ def test_execute_task_instances(self): ) # create second dag run - dr2 = dag_file_processor.create_dag_run(dag) + dr2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr1.execution_date), + state=State.RUNNING, + ) ti3 = TaskInstance(task1, dr2.execution_date) ti4 = TaskInstance(task2, dr2.execution_date) ti3.refresh_from_db() @@ -2102,11 +1853,11 @@ def test_execute_task_instances(self): ti4.state = State.SCHEDULED session.merge(ti3) session.merge(ti4) - session.commit() + session.flush() self.assertEqual(State.RUNNING, dr2.state) - res = scheduler._execute_task_instances(dagbag) + res = scheduler._critical_section_execute_task_instances(session) # check that concurrency is respected ti1.refresh_from_db() @@ -2136,16 +1887,26 @@ def test_execute_task_instances_limit(self): task1 = DummyOperator(dag=dag, task_id=task_id_1) task2 = DummyOperator(dag=dag, task_id=task_id_2) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() - scheduler.max_tis_per_query = 3 session = settings.Session() + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + date = dag.start_date tis = [] for _ in range(0, 4): - dr = dag_file_processor.create_dag_run(dag) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + date = dag.following_schedule(date) ti1 = TaskInstance(task1, dr.execution_date) ti2 = TaskInstance(task2, dr.execution_date) tis.append(ti1) @@ -2156,10 +1917,22 @@ def test_execute_task_instances_limit(self): ti2.state = State.SCHEDULED session.merge(ti1) session.merge(ti2) - session.commit() - res = scheduler._execute_task_instances(dagbag) - - self.assertEqual(8, res) + session.flush() + scheduler.max_tis_per_query = 2 + res = scheduler._critical_section_execute_task_instances(session) + self.assertEqual(2, res) + + scheduler.max_tis_per_query = 8 + with mock.patch.object(type(scheduler.executor), + 'slots_available', + new_callable=mock.PropertyMock) as mock_slots: + mock_slots.return_value = 2 + # Check that we don't "overfill" the executor + self.assertEqual(2, res) + res = scheduler._critical_section_execute_task_instances(session) + + res = scheduler._critical_section_execute_task_instances(session) + self.assertEqual(4, res) for ti in tis: ti.refresh_from_db() self.assertEqual(State.QUEUED, ti.state) @@ -2358,6 +2131,7 @@ def test_adopt_or_reset_orphaned_tasks(self): [State.SCHEDULED, State.NONE], [State.UP_FOR_RESCHEDULE, State.NONE], ]) + @pytest.mark.xfail(run=False, reason="TODO[HA]") def test_scheduler_loop_should_change_state_for_tis_without_dagrun(self, initial_task_state, expected_task_state): @@ -2390,8 +2164,6 @@ def test_scheduler_loop_should_change_state_for_tis_without_dagrun(self, executor.queued_tasks scheduler.executor = executor processor = mock.MagicMock() - processor.harvest_serialized_dags.return_value = [ - SerializedDAG.from_dict(SerializedDAG.to_dict(dag))] processor.done = True scheduler.processor_agent = processor @@ -2423,13 +2195,20 @@ def evaluate_dagrun( if run_kwargs is None: run_kwargs = {} - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) dag = self.dagbag.get_dag(dag_id) - dr = dag_file_processor.create_dag_run(dag) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.next_dagrun_after_date(None), + state=State.RUNNING, + ) if advance_execution_date: # run a second time to schedule a dagrun after the start_date - dr = dag_file_processor.create_dag_run(dag) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr.execution_date), + state=State.RUNNING, + ) ex_date = dr.execution_date for tid, state in expected_task_states.items(): @@ -2499,10 +2278,13 @@ def test_dagrun_root_fail_unfinished(self): """ # TODO: this should live in test_dagrun.py # Run both the failed and successful tasks - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) dag_id = 'test_dagrun_states_root_fail_unfinished' dag = self.dagbag.get_dag(dag_id) - dr = dag_file_processor.create_dag_run(dag) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) self.null_exec.mock_task_fail(dag_id, 'test_dagrun_fail', DEFAULT_DATE) with self.assertRaises(AirflowException): @@ -2523,8 +2305,10 @@ def test_dagrun_root_after_dagrun_unfinished(self): Noted: the DagRun state could be still in running state during CI. """ + clear_db_dags() dag_id = 'test_dagrun_states_root_future' dag = self.dagbag.get_dag(dag_id) + dag.sync_to_db() scheduler = SchedulerJob( dag_id, num_runs=1, @@ -2582,8 +2366,12 @@ def test_scheduler_start_date(self): dag.clear() self.assertGreater(dag.start_date, datetime.datetime.now(timezone.utc)) - scheduler = SchedulerJob(dag_id, - executor=self.null_exec, + # Deactivate other dags in this file + other_dag = self.dagbag.get_dag('test_task_start_date_scheduling') + other_dag.is_paused_upon_creation = True + other_dag.sync_to_db() + + scheduler = SchedulerJob(executor=self.null_exec, subdir=dag.fileloc, num_runs=1) scheduler.run() @@ -2636,10 +2424,16 @@ def test_scheduler_task_start_date(self): dag_id = 'test_task_start_date_scheduling' dag = self.dagbag.get_dag(dag_id) + dag.sync_to_db() dag.clear() - scheduler = SchedulerJob(dag_id, - executor=self.null_exec, - subdir=os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py'), + + # Deactivate other dags in this file + other_dag = self.dagbag.get_dag('test_start_date_scheduling') + other_dag.is_paused_upon_creation = True + other_dag.sync_to_db() + + scheduler = SchedulerJob(executor=self.null_exec, + subdir=dag.fileloc, num_runs=2) scheduler.run() @@ -2712,47 +2506,38 @@ def test_scheduler_verify_pool_full(self): owner='airflow', pool='test_scheduler_verify_pool_full') + dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), + include_examples=False, + read_dags_from_db=True) + dagbag.bag_dag(dag=dag, root_dag=dag) + dagbag.sync_to_db() + session = settings.Session() pool = Pool(pool='test_scheduler_verify_pool_full', slots=1) session.add(pool) - orm_dag = DagModel(dag_id=dag.dag_id) - orm_dag.is_paused = False - session.merge(orm_dag) - session.commit() + session.flush() dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob(executor=self.null_exec) # Create 2 dagruns, which will create 2 task instances. - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - self.assertEqual(dr.execution_date, DEFAULT_DATE) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - dag_runs = DagRun.find(dag_id="test_scheduler_verify_pool_full") - task_instances_list = dag_file_processor._process_task_instances(dag, dag_runs=dag_runs) - self.assertEqual(len(task_instances_list), 2) - dagbag = self._make_simple_dag_bag([dag]) - - # Recreated part of the scheduler here, to kick off tasks -> executor - for ti_key in task_instances_list: - task = dag.get_task(ti_key[1]) - ti = TaskInstance(task, ti_key[2]) - # Task starts out in the scheduled state. All tasks in the - # scheduled state will be sent to the executor - ti.state = State.SCHEDULED - - # Also save this task instance to the DB. - session.merge(ti) - session.commit() + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + scheduler._schedule_dag_run(dr, session) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr.execution_date), + state=State.RUNNING, + ) + scheduler._schedule_dag_run(dr, session) - self.assertEqual(len(scheduler.executor.queued_tasks), 0, "Check test pre-condition") - scheduler._execute_task_instances(dagbag, session=session) + task_instances_list = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) - self.assertEqual(len(scheduler.executor.queued_tasks), 1) + self.assertEqual(len(task_instances_list), 1) def test_scheduler_verify_pool_full_2_slots_per_task(self): """ @@ -2772,44 +2557,36 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self): pool_slots=2, ) + dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), + include_examples=False, + read_dags_from_db=True) + dagbag.bag_dag(dag=dag, root_dag=dag) + dagbag.sync_to_db() + session = settings.Session() pool = Pool(pool='test_scheduler_verify_pool_full_2_slots_per_task', slots=6) session.add(pool) - orm_dag = DagModel(dag_id=dag.dag_id) - orm_dag.is_paused = False - session.merge(orm_dag) session.commit() dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob(executor=self.null_exec) # Create 5 dagruns, which will create 5 task instances. + date = DEFAULT_DATE for _ in range(5): - dag_file_processor.create_dag_run(dag) - dag_runs = DagRun.find(dag_id="test_scheduler_verify_pool_full_2_slots_per_task") - task_instances_list = dag_file_processor._process_task_instances(dag, dag_runs=dag_runs) - self.assertEqual(len(task_instances_list), 5) - dagbag = self._make_simple_dag_bag([dag]) - - # Recreated part of the scheduler here, to kick off tasks -> executor - for ti_key in task_instances_list: - task = dag.get_task(ti_key[1]) - ti = TaskInstance(task, ti_key[2]) - # Task starts out in the scheduled state. All tasks in the - # scheduled state will be sent to the executor - ti.state = State.SCHEDULED - - # Also save this task instance to the DB. - session.merge(ti) - session.commit() + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + scheduler._schedule_dag_run(dr, session) + date = dag.following_schedule(date) - self.assertEqual(len(scheduler.executor.queued_tasks), 0, "Check test pre-condition") - scheduler._execute_task_instances(dagbag, session=session) + task_instances_list = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) # As tasks require 2 slots, only 3 can fit into 6 available - self.assertEqual(len(scheduler.executor.queued_tasks), 3) + self.assertEqual(len(task_instances_list), 3) def test_scheduler_verify_priority_and_slots(self): """ @@ -2850,42 +2627,32 @@ def test_scheduler_verify_priority_and_slots(self): priority_weight=1, ) + dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), + include_examples=False, + read_dags_from_db=True) + dagbag.bag_dag(dag=dag, root_dag=dag) + dagbag.sync_to_db() + session = settings.Session() pool = Pool(pool='test_scheduler_verify_priority_and_slots', slots=2) session.add(pool) - orm_dag = DagModel(dag_id=dag.dag_id) - orm_dag.is_paused = False - session.merge(orm_dag) session.commit() dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob(executor=self.null_exec) - dag_file_processor.create_dag_run(dag) - dag_runs = DagRun.find(dag_id="test_scheduler_verify_priority_and_slots") - task_instances_list = dag_file_processor._process_task_instances(dag, dag_runs=dag_runs) - self.assertEqual(len(task_instances_list), 3) - dagbag = self._make_simple_dag_bag([dag]) - - # Recreated part of the scheduler here, to kick off tasks -> executor - for ti_key in task_instances_list: - task = dag.get_task(ti_key[1]) - ti = TaskInstance(task, ti_key[2]) - # Task starts out in the scheduled state. All tasks in the - # scheduled state will be sent to the executor - ti.state = State.SCHEDULED - - # Also save this task instance to the DB. - session.merge(ti) - session.commit() + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + scheduler._schedule_dag_run(dr, session) - self.assertEqual(len(scheduler.executor.queued_tasks), 0, "Check test pre-condition") - scheduler._execute_task_instances(dagbag, session=session) + task_instances_list = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) # Only second and third - self.assertEqual(len(scheduler.executor.queued_tasks), 2) + self.assertEqual(len(task_instances_list), 2) ti0 = session.query(TaskInstance)\ .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t0').first() @@ -2899,69 +2666,13 @@ def test_scheduler_verify_priority_and_slots(self): .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t2').first() self.assertEqual(ti2.state, State.QUEUED) - def test_scheduler_reschedule(self): - """ - Checks if tasks that are not taken up by the executor - get rescheduled - """ - executor = MockExecutor(do_update=False) - dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py")) - dagbag.dags.clear() - - dag = DAG( - dag_id='test_scheduler_reschedule', - start_date=DEFAULT_DATE) - dummy_task = BashOperator( - task_id='dummy', - dag=dag, - owner='airflow', - bash_command='echo 1', - ) - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag.clear() - dag.is_subdag = False - - with create_session() as session: - orm_dag = DagModel(dag_id=dag.dag_id) - orm_dag.is_paused = False - session.merge(orm_dag) - - dagbag.bag_dag(dag=dag, root_dag=dag) - - @mock.patch('airflow.jobs.scheduler_job.DagBag', return_value=dagbag) - def do_schedule(mock_dagbag): - # Use a empty file since the above mock will return the - # expected DAGs. Also specify only a single file so that it doesn't - # try to schedule the above DAG repeatedly. - with conf_vars({('core', 'mp_start_method'): 'fork'}): - scheduler = SchedulerJob(num_runs=1, - executor=executor, - subdir=os.path.join(settings.DAGS_FOLDER, - "no_dags.py")) - scheduler.heartrate = 0 - scheduler.run() - - do_schedule() # pylint: disable=no-value-for-parameter - with create_session() as session: - ti = session.query(TaskInstance).filter(TaskInstance.dag_id == dag.dag_id, - TaskInstance.task_id == dummy_task.task_id).first() - self.assertEqual(0, len(executor.queued_tasks)) - self.assertEqual(State.SCHEDULED, ti.state) - - executor.do_update = True - do_schedule() # pylint: disable=no-value-for-parameter - self.assertEqual(0, len(executor.queued_tasks)) - ti.refresh_from_db() - self.assertEqual(State.SUCCESS, ti.state) - def test_retry_still_in_executor(self): """ Checks if the scheduler does not put a task in limbo, when a task is retried but is still present in the executor. """ executor = MockExecutor(do_update=False) - dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py")) + dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), include_examples=False) dagbag.dags.clear() dag = DAG( @@ -2984,6 +2695,7 @@ def test_retry_still_in_executor(self): session.merge(orm_dag) dagbag.bag_dag(dag=dag, root_dag=dag) + dagbag.sync_to_db() @mock.patch('airflow.jobs.scheduler_job.DagBag', return_value=dagbag) def do_schedule(mock_dagbag): @@ -3003,10 +2715,6 @@ def do_schedule(mock_dagbag): TaskInstance.task_id == 'test_retry_handling_op').first() ti.task = dag_task1 - # Nothing should be left in the queued_tasks as we don't do update in MockExecutor yet, - # and the queued_tasks will be cleared by scheduler job. - self.assertEqual(0, len(executor.queued_tasks)) - def run_with_error(ti, ignore_ti_state=False): try: ti.run(ignore_ti_state=ignore_ti_state) @@ -3028,13 +2736,6 @@ def run_with_error(ti, ignore_ti_state=False): ti.state = State.SCHEDULED session.merge(ti) - # do schedule - do_schedule() # pylint: disable=no-value-for-parameter - # MockExecutor is not aware of the TaskInstance since we don't do update yet - # and no trace of this TaskInstance will be left in the executor. - self.assertFalse(executor.has_task(ti)) - self.assertEqual(ti.state, State.SCHEDULED) - # To verify that task does get re-queued. executor.do_update = True do_schedule() # pylint: disable=no-value-for-parameter @@ -3064,34 +2765,6 @@ def test_retry_handling_job(self): self.assertEqual(ti.try_number, 2) self.assertEqual(ti.state, State.UP_FOR_RETRY) - def test_dag_with_system_exit(self): - """ - Test to check that a DAG with a system.exit() doesn't break the scheduler. - """ - - dag_id = 'exit_test_dag' - dag_ids = [dag_id] - dag_directory = os.path.join(settings.DAGS_FOLDER, "..", "dags_with_system_exit") - dag_file = os.path.join(dag_directory, 'b_test_scheduler_dags.py') - - dagbag = DagBag(dag_folder=dag_file) - for dag_id in dag_ids: - dag = dagbag.get_dag(dag_id) - dag.clear() - - scheduler = SchedulerJob(dag_ids=dag_ids, - executor=self.null_exec, - subdir=dag_directory, - num_runs=1) - scheduler.run() - with create_session() as session: - tis = session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all() - # Since this dag has no end date, and there's a chance that we'll - # start a and finish two dag parsing processes twice in one loop! - self.assertGreaterEqual( - len(tis), 1, - repr(tis)) - def test_dag_get_active_runs(self): """ Test to check that a DAG returns its active runs @@ -3128,10 +2801,13 @@ def test_dag_get_active_runs(self): session.commit() session.close() - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) dag1.clear() - dr = dag_file_processor.create_dag_run(dag1) + dr = dag1.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=start_date, + state=State.RUNNING, + ) # We had better get a dag run self.assertIsNotNone(dr) @@ -3178,6 +2854,7 @@ def test_add_unparseable_file_after_sched_start_creates_import_error(self): with open(unparseable_filename, 'w') as unparseable_file: unparseable_file.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) + print("Second run") self.run_single_scheduler_loop_with_no_dags(dags_folder) finally: shutil.rmtree(dags_folder) @@ -3459,15 +3136,17 @@ def test_adopt_or_reset_orphaned_tasks_external_triggered_dag(self): task_id = dag_id + '_task' DummyOperator(task_id=task_id, dag=dag) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag, session=session) + dr1 = dag.create_dagrun(run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + external_trigger=True, + session=session) ti = dr1.get_task_instances(session=session)[0] - dr1.state = State.RUNNING ti.state = State.SCHEDULED - dr1.external_trigger = True session.merge(ti) session.merge(dr1) session.commit() @@ -3481,17 +3160,18 @@ def test_adopt_or_reset_orphaned_tasks_backfill_dag(self): task_id = dag_id + '_task' DummyOperator(task_id=task_id, dag=dag) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() session.add(scheduler) session.flush() - dr1 = dag_file_processor.create_dag_run(dag, session=session) + dr1 = dag.create_dagrun(run_type=DagRunType.BACKFILL_JOB, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session) ti = dr1.get_task_instances(session=session)[0] ti.state = State.SCHEDULED - dr1.state = State.RUNNING - dr1.run_type = DagRunType.BACKFILL_JOB.value session.merge(ti) session.merge(dr1) session.flush() @@ -3528,14 +3208,16 @@ def test_reset_orphaned_tasks_no_orphans(self): task_id = dag_id + '_task' DummyOperator(task_id=task_id, dag=dag) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() session.add(scheduler) session.flush() - dr1 = dag_file_processor.create_dag_run(dag, session=session) - dr1.state = State.RUNNING + dr1 = dag.create_dagrun(run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session) tis = dr1.get_task_instances(session=session) tis[0].state = State.RUNNING tis[0].queued_by_job_id = scheduler.id @@ -3554,14 +3236,16 @@ def test_reset_orphaned_tasks_non_running_dagruns(self): task_id = dag_id + '_task' DummyOperator(task_id=task_id, dag=dag) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() session.add(scheduler) session.flush() - dr1 = dag_file_processor.create_dag_run(dag, session=session) - dr1.state = State.SUCCESS + dr1 = dag.create_dagrun(run_type=DagRunType.SCHEDULED, + state=State.SUCCESS, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session) tis = dr1.get_task_instances(session=session) self.assertEqual(1, len(tis)) tis[0].state = State.SCHEDULED @@ -3622,6 +3306,7 @@ def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self): session.rollback() +@pytest.mark.xfail(reason="Work out where this goes") def test_task_with_upstream_skip_process_task_instances(): """ Test if _process_task_instances puts a task instance into SKIPPED state if any of its @@ -3663,6 +3348,7 @@ def test_task_with_upstream_skip_process_task_instances(): assert tis[dummy3.task_id].state == State.SKIPPED +@pytest.mark.xfail(reason="Work why this didn't infinite loop before!") class TestSchedulerJobQueriesCount(unittest.TestCase): """ These tests are designed to detect changes in the number of queries for @@ -3703,7 +3389,8 @@ def test_execute_queries_count_with_harvested_dags(self, expected_query_count, d dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False) for i, dag in enumerate(dagbag.dags.values()): - dr = dag.create_dagrun(state=State.RUNNING, run_id=f"{DagRunType.MANUAL.value}__{i}") + dr = dag.create_dagrun(state=State.RUNNING, run_id=f"{DagRunType.MANUAL.value}__{i}", + execution_date=DEFAULT_DATE) for ti in dr.get_task_instances(): ti.set_state(state=State.SCHEDULED) @@ -3758,4 +3445,5 @@ def test_execute_queries_count_no_harvested_dags(self, expected_query_count, dag job.processor_agent = mock_agent with assert_queries_count(expected_query_count): - job._run_scheduler_loop() + with create_session() as session: + job._do_scheduling(session) diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index eac4f1ea5c688..16230e6adf1e8 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -24,6 +24,7 @@ import re import unittest from contextlib import redirect_stdout +from datetime import timedelta from tempfile import NamedTemporaryFile from typing import Optional from unittest import mock @@ -31,12 +32,12 @@ import pendulum from dateutil.relativedelta import relativedelta +from freezegun import freeze_time from parameterized import parameterized from airflow import models, settings from airflow.configuration import conf from airflow.exceptions import AirflowException, DuplicateTaskIdFound -from airflow.jobs.scheduler_job import DagFileProcessor from airflow.models import DAG, DagModel, DagRun, DagTag, TaskFail, TaskInstance as TI from airflow.models.baseoperator import BaseOperator from airflow.operators.bash import BashOperator @@ -53,6 +54,8 @@ from tests.test_utils.asserts import assert_queries_count from tests.test_utils.db import clear_db_dags, clear_db_runs +TEST_DATE = datetime_tz(2015, 1, 2, 0, 0) + class TestDag(unittest.TestCase): @@ -652,6 +655,25 @@ def test_following_previous_schedule_daily_dag_cet_to_cest(self): self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00") self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00") + def test_following_schedule_relativedelta(self): + """ + Tests following_schedule a dag with a relativedelta schedule_interval + """ + dag_id = "test_schedule_dag_relativedelta" + delta = relativedelta(hours=+1) + dag = DAG(dag_id=dag_id, + schedule_interval=delta) + dag.add_task(BaseOperator( + task_id="faketastic", + owner='Also fake', + start_date=TEST_DATE)) + + _next = dag.following_schedule(TEST_DATE) + self.assertEqual(_next.isoformat(), "2015-01-02T01:00:00+00:00") + + _next = dag.following_schedule(_next) + self.assertEqual(_next.isoformat(), "2015-01-02T02:00:00+00:00") + def test_dagtag_repr(self): clear_db_dags() dag = DAG('dag-test-dagtag', start_date=DEFAULT_DATE, tags=['tag-1', 'tag-2']) @@ -667,7 +689,7 @@ def test_bulk_sync_to_db(self): DAG(f'dag-bulk-sync-{i}', start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(0, 4) ] - with assert_queries_count(3): + with assert_queries_count(5): DAG.bulk_sync_to_db(dags) with create_session() as session: self.assertEqual( @@ -684,14 +706,14 @@ def test_bulk_sync_to_db(self): set(session.query(DagTag.dag_id, DagTag.name).all()) ) # Re-sync should do fewer queries - with assert_queries_count(2): + with assert_queries_count(3): DAG.bulk_sync_to_db(dags) - with assert_queries_count(2): + with assert_queries_count(3): DAG.bulk_sync_to_db(dags) # Adding tags for dag in dags: dag.tags.append("test-dag2") - with assert_queries_count(3): + with assert_queries_count(4): DAG.bulk_sync_to_db(dags) with create_session() as session: self.assertEqual( @@ -714,7 +736,7 @@ def test_bulk_sync_to_db(self): # Removing tags for dag in dags: dag.tags.remove("test-dag") - with assert_queries_count(3): + with assert_queries_count(4): DAG.bulk_sync_to_db(dags) with create_session() as session: self.assertEqual( @@ -731,8 +753,46 @@ def test_bulk_sync_to_db(self): set(session.query(DagTag.dag_id, DagTag.name).all()) ) - @patch('airflow.models.dag.timezone.utcnow') - def test_sync_to_db(self, mock_now): + def test_bulk_sync_to_db_max_active_runs(self): + """ + Test that DagModel.next_dagrun_create_after is set to NULL when the dag cannot be created due to max + active runs being hit. + """ + dag = DAG( + dag_id='test_scheduler_verify_max_active_runs', + start_date=DEFAULT_DATE) + dag.max_active_runs = 1 + + DummyOperator( + task_id='dummy', + dag=dag, + owner='airflow') + + session = settings.Session() + dag.clear() + DAG.bulk_sync_to_db([dag], session) + + model = session.query(DagModel).get((dag.dag_id,)) + + period_end = dag.following_schedule(DEFAULT_DATE) + assert model.next_dagrun == DEFAULT_DATE + assert model.next_dagrun_create_after == period_end + + dr = dag.create_dagrun( + state=State.RUNNING, + execution_date=model.next_dagrun, + run_type=DagRunType.SCHEDULED, + session=session, + ) + assert dr is not None + DAG.bulk_sync_to_db([dag]) + + model = session.query(DagModel).get((dag.dag_id,)) + assert model.next_dagrun == period_end + # We signle "at max active runs" by saying this run is never eligible to be created + assert model.next_dagrun_create_after is None + + def test_sync_to_db(self): dag = DAG( 'dag', start_date=DEFAULT_DATE, @@ -748,31 +808,25 @@ def test_sync_to_db(self, mock_now): owner='owner2', subdag=subdag ) - now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC')) - mock_now.return_value = now session = settings.Session() dag.sync_to_db(session=session) orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one() self.assertEqual(set(orm_dag.owners.split(', ')), {'owner1', 'owner2'}) - self.assertEqual(orm_dag.last_scheduler_run, now) self.assertTrue(orm_dag.is_active) self.assertIsNotNone(orm_dag.default_view) self.assertEqual(orm_dag.default_view, conf.get('webserver', 'dag_default_view').lower()) self.assertEqual(orm_dag.safe_dag_id, 'dag') - orm_subdag = session.query(DagModel).filter( - DagModel.dag_id == 'dag.subtask').one() + orm_subdag = session.query(DagModel).filter(DagModel.dag_id == 'dag.subtask').one() self.assertEqual(set(orm_subdag.owners.split(', ')), {'owner1', 'owner2'}) - self.assertEqual(orm_subdag.last_scheduler_run, now) self.assertTrue(orm_subdag.is_active) self.assertEqual(orm_subdag.safe_dag_id, 'dag__dot__subtask') self.assertEqual(orm_subdag.fileloc, orm_dag.fileloc) session.close() - @patch('airflow.models.dag.timezone.utcnow') - def test_sync_to_db_default_view(self, mock_now): + def test_sync_to_db_default_view(self): dag = DAG( 'dag', start_date=DEFAULT_DATE, @@ -788,8 +842,6 @@ def test_sync_to_db_default_view(self, mock_now): start_date=DEFAULT_DATE, ) ) - now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC')) - mock_now.return_value = now session = settings.Session() dag.sync_to_db(session=session) @@ -1038,65 +1090,25 @@ def test_schedule_dag_no_previous_runs(self): dag.add_task(BaseOperator( task_id="faketastic", owner='Also fake', - start_date=datetime_tz(2015, 1, 2, 0, 0))) + start_date=TEST_DATE)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag_run = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dag_run) - self.assertEqual(dag.dag_id, dag_run.dag_id) - self.assertIsNotNone(dag_run.run_id) - self.assertNotEqual('', dag_run.run_id) - self.assertEqual( - datetime_tz(2015, 1, 2, 0, 0), - dag_run.execution_date, - msg='dag_run.execution_date did not match expectation: {0}' - .format(dag_run.execution_date) + dag_run = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=TEST_DATE, + state=State.RUNNING, ) - self.assertEqual(State.RUNNING, dag_run.state) - self.assertFalse(dag_run.external_trigger) - dag.clear() - self._clean_up(dag_id) - - def test_schedule_dag_relativedelta(self): - """ - Tests scheduling a dag with a relativedelta schedule_interval - """ - dag_id = "test_schedule_dag_relativedelta" - delta = relativedelta(hours=+1) - dag = DAG(dag_id=dag_id, - schedule_interval=delta) - dag.add_task(BaseOperator( - task_id="faketastic", - owner='Also fake', - start_date=datetime_tz(2015, 1, 2, 0, 0))) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag_run = dag_file_processor.create_dag_run(dag) self.assertIsNotNone(dag_run) self.assertEqual(dag.dag_id, dag_run.dag_id) self.assertIsNotNone(dag_run.run_id) self.assertNotEqual('', dag_run.run_id) self.assertEqual( - datetime_tz(2015, 1, 2, 0, 0), + TEST_DATE, dag_run.execution_date, msg='dag_run.execution_date did not match expectation: {0}' .format(dag_run.execution_date) ) self.assertEqual(State.RUNNING, dag_run.state) self.assertFalse(dag_run.external_trigger) - dag_run2 = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dag_run2) - self.assertEqual(dag.dag_id, dag_run2.dag_id) - self.assertIsNotNone(dag_run2.run_id) - self.assertNotEqual('', dag_run2.run_id) - self.assertEqual( - datetime_tz(2015, 1, 2, 0, 0) + delta, - dag_run2.execution_date, - msg='dag_run2.execution_date did not match expectation: {0}' - .format(dag_run2.execution_date) - ) - self.assertEqual(State.RUNNING, dag_run2.state) - self.assertFalse(dag_run2.external_trigger) dag.clear() self._clean_up(dag_id) @@ -1113,13 +1125,13 @@ def test_dag_handle_callback_crash(self, mock_stats): # callback with invalid signature should not cause crashes on_success_callback=lambda: 1, on_failure_callback=mock_callback_with_exception) + when = TEST_DATE dag.add_task(BaseOperator( task_id="faketastic", owner='Also fake', - start_date=datetime_tz(2015, 1, 2, 0, 0))) + start_date=when)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag_run = dag_file_processor.create_dag_run(dag) + dag_run = dag.create_dagrun(State.RUNNING, when, run_type=DagRunType.MANUAL) # should not rause any exception dag.handle_callback(dag_run, success=False) dag.handle_callback(dag_run, success=True) @@ -1129,7 +1141,7 @@ def test_dag_handle_callback_crash(self, mock_stats): dag.clear() self._clean_up(dag_id) - def test_schedule_dag_fake_scheduled_previous(self): + def test_next_dagrun_after_fake_scheduled_previous(self): """ Test scheduling a dag where there is a prior DagRun which has the same run_id as the next run should have @@ -1144,24 +1156,19 @@ def test_schedule_dag_fake_scheduled_previous(self): owner='Also fake', start_date=DEFAULT_DATE)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) dag.create_dagrun(run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, state=State.SUCCESS, external_trigger=True) - dag_run = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dag_run) - self.assertEqual(dag.dag_id, dag_run.dag_id) - self.assertIsNotNone(dag_run.run_id) - self.assertNotEqual('', dag_run.run_id) - self.assertEqual( - DEFAULT_DATE + delta, - dag_run.execution_date, - msg='dag_run.execution_date did not match expectation: {0}' - .format(dag_run.execution_date) - ) - self.assertEqual(State.RUNNING, dag_run.state) - self.assertFalse(dag_run.external_trigger) + dag.sync_to_db() + with create_session() as session: + model = session.query(DagModel).get((dag.dag_id,)) + + # Even though there is a run for this date already, it is marked as manual/external, so we should + # create a scheduled one anyway! + assert model.next_dagrun == DEFAULT_DATE + assert model.next_dagrun_create_after == dag.following_schedule(DEFAULT_DATE) + self._clean_up(dag_id) def test_schedule_dag_once(self): @@ -1176,13 +1183,17 @@ def test_schedule_dag_once(self): dag.add_task(BaseOperator( task_id="faketastic", owner='Also fake', - start_date=datetime_tz(2015, 1, 2, 0, 0))) - dag_run = DagFileProcessor(dag_ids=[], log=mock.MagicMock()).create_dag_run(dag) - dag_run2 = DagFileProcessor(dag_ids=[], log=mock.MagicMock()).create_dag_run(dag) + start_date=TEST_DATE)) + dag.create_dagrun(run_type=DagRunType.SCHEDULED, + execution_date=TEST_DATE, + state=State.SUCCESS) - self.assertIsNotNone(dag_run) - self.assertIsNone(dag_run2) - dag.clear() + dag.sync_to_db() + with create_session() as session: + model = session.query(DagModel).get((dag.dag_id,)) + + assert model.next_dagrun is None + assert model.next_dagrun_create_after is None self._clean_up(dag_id) def test_fractional_seconds(self): @@ -1195,7 +1206,7 @@ def test_fractional_seconds(self): dag.add_task(BaseOperator( task_id="faketastic", owner='Also fake', - start_date=datetime_tz(2015, 1, 2, 0, 0))) + start_date=TEST_DATE)) start_date = timezone.utcnow() @@ -1215,77 +1226,6 @@ def test_fractional_seconds(self): "dag run start_date loses precision ") self._clean_up(dag_id) - def test_schedule_dag_start_end_dates(self): - """ - Tests that an attempt to schedule a task after the Dag's end_date - does not succeed. - """ - delta = datetime.timedelta(hours=1) - runs = 3 - start_date = DEFAULT_DATE - end_date = start_date + (runs - 1) * delta - dag_id = "test_schedule_dag_start_end_dates" - dag = DAG(dag_id=dag_id, - start_date=start_date, - end_date=end_date, - schedule_interval=delta) - dag.add_task(BaseOperator(task_id='faketastic', owner='Also fake')) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - # Create and schedule the dag runs - dag_runs = [] - for _ in range(runs): - dag_runs.append(dag_file_processor.create_dag_run(dag)) - - additional_dag_run = dag_file_processor.create_dag_run(dag) - - for dag_run in dag_runs: - self.assertIsNotNone(dag_run) - - self.assertIsNone(additional_dag_run) - self._clean_up(dag_id) - - def test_schedule_dag_no_end_date_up_to_today_only(self): - """ - Tests that a Dag created without an end_date can only be scheduled up - to and including the current datetime. - - For example, if today is 2016-01-01 and we are scheduling from a - start_date of 2015-01-01, only jobs up to, but not including - 2016-01-01 should be scheduled. - """ - session = settings.Session() - delta = datetime.timedelta(days=1) - now = pendulum.now('UTC') - start_date = now.subtract(weeks=1) - - runs = (now - start_date).days - dag_id = "test_schedule_dag_no_end_date_up_to_today_only" - dag = DAG(dag_id=dag_id, - start_date=start_date, - schedule_interval=delta) - dag.add_task(BaseOperator(task_id='faketastic', owner='Also fake')) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag_runs = [] - for _ in range(runs): - dag_run = dag_file_processor.create_dag_run(dag) - dag_runs.append(dag_run) - - # Mark the DagRun as complete - dag_run.state = State.SUCCESS - session.merge(dag_run) - session.commit() - - # Attempt to schedule an additional dag run (for 2016-01-01) - additional_dag_run = dag_file_processor.create_dag_run(dag) - - for dag_run in dag_runs: - self.assertIsNotNone(dag_run) - - self.assertIsNone(additional_dag_run) - self._clean_up(dag_id) - def test_pickling(self): test_dag_id = 'test_pickling' args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} @@ -1489,6 +1429,249 @@ def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]): self.assertEqual(task_instance.state, ti_state_end) self._clean_up(dag_id) + def test_next_dagrun_after_date_once(self): + dag = DAG( + 'test_scheduler_dagrun_once', + start_date=timezone.datetime(2015, 1, 1), + schedule_interval="@once") + + next_date = dag.next_dagrun_after_date(None) + + assert next_date == timezone.datetime(2015, 1, 1) + + next_date = dag.next_dagrun_after_date(next_date) + assert next_date is None + + def test_next_dagrun_after_date_start_end_dates(self): + """ + Tests that an attempt to schedule a task after the Dag's end_date + does not succeed. + """ + delta = datetime.timedelta(hours=1) + runs = 3 + start_date = DEFAULT_DATE + end_date = start_date + (runs - 1) * delta + dag_id = "test_schedule_dag_start_end_dates" + dag = DAG(dag_id=dag_id, + start_date=start_date, + end_date=end_date, + schedule_interval=delta) + dag.add_task(BaseOperator(task_id='faketastic', owner='Also fake')) + + # Create and schedule the dag runs + dates = [] + date = None + for _ in range(runs): + date = dag.next_dagrun_after_date(date) + dates.append(date) + + for date in dates: + assert date is not None + + assert dates[-1] == end_date + + assert dag.next_dagrun_after_date(date) is None + + def test_next_dagrun_after_date_catcup(self): + """ + Test to check that a DAG with catchup = False only schedules beginning now, not back to the start date + """ + + def make_dag(dag_id, schedule_interval, start_date, catchup): + default_args = { + 'owner': 'airflow', + 'depends_on_past': False, + } + dag = DAG(dag_id, + schedule_interval=schedule_interval, + start_date=start_date, + catchup=catchup, + default_args=default_args) + + op1 = DummyOperator(task_id='t1', dag=dag) + op2 = DummyOperator(task_id='t2', dag=dag) + op3 = DummyOperator(task_id='t3', dag=dag) + op1 >> op2 >> op3 + + return dag + + now = timezone.utcnow() + six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace( + minute=0, second=0, microsecond=0) + half_an_hour_ago = now - datetime.timedelta(minutes=30) + two_hours_ago = now - datetime.timedelta(hours=2) + + dag1 = make_dag(dag_id='dag_without_catchup_ten_minute', + schedule_interval='*/10 * * * *', + start_date=six_hours_ago_to_the_hour, + catchup=False) + next_date = dag1.next_dagrun_after_date(None) + # The DR should be scheduled in the last half an hour, not 6 hours ago + assert next_date > half_an_hour_ago + assert next_date < timezone.utcnow() + + dag2 = make_dag(dag_id='dag_without_catchup_hourly', + schedule_interval='@hourly', + start_date=six_hours_ago_to_the_hour, + catchup=False) + + next_date = dag2.next_dagrun_after_date(None) + # The DR should be scheduled in the last 2 hours, not 6 hours ago + assert next_date > two_hours_ago + # The DR should be scheduled BEFORE now + assert next_date < timezone.utcnow() + + dag3 = make_dag(dag_id='dag_without_catchup_once', + schedule_interval='@once', + start_date=six_hours_ago_to_the_hour, + catchup=False) + + next_date = dag3.next_dagrun_after_date(None) + # The DR should be scheduled in the last 2 hours, not 6 hours ago + assert next_date == six_hours_ago_to_the_hour + + @freeze_time(timezone.datetime(2020, 1, 5)) + def test_next_dagrun_after_date_timedelta_schedule_and_catchup_false(self): + """ + Test that the dag file processor does not create multiple dagruns + if a dag is scheduled with 'timedelta' and catchup=False + """ + dag = DAG( + 'test_scheduler_dagrun_once_with_timedelta_and_catchup_false', + start_date=timezone.datetime(2015, 1, 1), + schedule_interval=timedelta(days=1), + catchup=False) + + next_date = dag.next_dagrun_after_date(None) + assert next_date == timezone.datetime(2020, 1, 4) + + # The date to create is in the future, this is handled by "DagModel.dags_needing_dagruns" + next_date = dag.next_dagrun_after_date(next_date) + assert next_date == timezone.datetime(2020, 1, 5) + + @freeze_time(timezone.datetime(2020, 5, 4)) + def test_next_dagrun_after_date_timedelta_schedule_and_catchup_true(self): + """ + Test that the dag file processor creates multiple dagruns + if a dag is scheduled with 'timedelta' and catchup=True + """ + dag = DAG( + 'test_scheduler_dagrun_once_with_timedelta_and_catchup_true', + start_date=timezone.datetime(2020, 5, 1), + schedule_interval=timedelta(days=1), + catchup=True) + + next_date = dag.next_dagrun_after_date(None) + assert next_date == timezone.datetime(2020, 5, 1) + + next_date = dag.next_dagrun_after_date(next_date) + assert next_date == timezone.datetime(2020, 5, 2) + + next_date = dag.next_dagrun_after_date(next_date) + assert next_date == timezone.datetime(2020, 5, 3) + + # The date to create is in the future, this is handled by "DagModel.dags_needing_dagruns" + next_date = dag.next_dagrun_after_date(next_date) + assert next_date == timezone.datetime(2020, 5, 4) + + def test_next_dagrun_after_auto_align(self): + """ + Test if the schedule_interval will be auto aligned with the start_date + such that if the start_date coincides with the schedule the first + execution_date will be start_date, otherwise it will be start_date + + interval. + """ + dag = DAG( + dag_id='test_scheduler_auto_align_1', + start_date=timezone.datetime(2016, 1, 1, 10, 10, 0), + schedule_interval="4 5 * * *" + ) + DummyOperator( + task_id='dummy', + dag=dag, + owner='airflow') + + next_date = dag.next_dagrun_after_date(None) + assert next_date == timezone.datetime(2016, 1, 2, 5, 4) + + dag = DAG( + dag_id='test_scheduler_auto_align_2', + start_date=timezone.datetime(2016, 1, 1, 10, 10, 0), + schedule_interval="10 10 * * *" + ) + DummyOperator( + task_id='dummy', + dag=dag, + owner='airflow') + + next_date = dag.next_dagrun_after_date(None) + assert next_date == timezone.datetime(2016, 1, 1, 10, 10) + + +class TestDagModel: + + def test_dags_needing_dagruns_not_too_early(self): + dag = DAG( + dag_id='far_future_dag', + start_date=timezone.datetime(2200, 1, 1)) + DummyOperator( + task_id='dummy', + dag=dag, + owner='airflow') + + session = settings.Session() + orm_dag = DagModel( + dag_id=dag.dag_id, + concurrency=1, + has_task_concurrency_limits=False, + next_dagrun=dag.start_date, + next_dagrun_create_after=timezone.datetime(2200, 1, 2), + is_active=True, + ) + session.add(orm_dag) + session.flush() + + dag_models = DagModel.dags_needing_dagruns(session).all() + assert dag_models == [] + + session.rollback() + session.close() + + def test_dags_needing_dagruns_only_unpaused(self): + """ + We should never create dagruns for unpaused DAGs + """ + dag = DAG( + dag_id='test_dags', + start_date=DEFAULT_DATE) + DummyOperator( + task_id='dummy', + dag=dag, + owner='airflow') + + session = settings.Session() + orm_dag = DagModel( + dag_id=dag.dag_id, + has_task_concurrency_limits=False, + next_dagrun=dag.start_date, + next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE), + is_active=True, + ) + session.add(orm_dag) + session.flush() + + models = DagModel.dags_needing_dagruns(session).all() + assert models == [orm_dag] + + orm_dag.is_paused = True + session.flush() + + models = DagModel.dags_needing_dagruns(session).all() + assert models == [] + + session.rollback() + session.close() + class TestQueries(unittest.TestCase): @@ -1506,8 +1689,9 @@ def test_count_number_queries(self, tasks_count): dag = DAG('test_dagrun_query_count', start_date=DEFAULT_DATE) for i in range(tasks_count): DummyOperator(task_id=f'dummy_task_{i}', owner='test', dag=dag) - with assert_queries_count(3): + with assert_queries_count(2): dag.create_dagrun( run_id="test_dagrun_query_count", - state=State.RUNNING + state=State.RUNNING, + execution_date=TEST_DATE, ) diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 78a02de4f4c58..4753118fe1c86 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -23,7 +23,7 @@ from parameterized import parameterized from airflow import models, settings -from airflow.models import DAG, DagBag, TaskInstance as TI, clear_task_instances +from airflow.models import DAG, DagBag, DagModel, TaskInstance as TI, clear_task_instances from airflow.models.dagrun import DagRun from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python import ShortCircuitOperator @@ -662,3 +662,45 @@ def test_wait_for_downstream(self, prev_ti_state, is_ti_success): ti.set_state(State.QUEUED) ti.run() self.assertEqual(ti.state == State.SUCCESS, is_ti_success) + + def test_next_dagruns_to_examine_only_unpaused(self): + """ + Check that "next_dagruns_to_examine" ignores runs from paused/inactive DAGs + """ + + dag = DAG( + dag_id='test_dags', + start_date=DEFAULT_DATE) + DummyOperator( + task_id='dummy', + dag=dag, + owner='airflow') + + session = settings.Session() + orm_dag = DagModel( + dag_id=dag.dag_id, + has_task_concurrency_limits=False, + next_dagrun=dag.start_date, + next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE), + is_active=True, + ) + session.add(orm_dag) + session.flush() + dr = dag.create_dagrun(run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session) + + runs = DagRun.next_dagruns_to_examine(session).all() + + assert runs == [dr] + + orm_dag.is_paused = True + session.flush() + + runs = DagRun.next_dagruns_to_examine(session).all() + assert runs == [] + + session.rollback() + session.close() diff --git a/tests/test_utils/mock_executor.py b/tests/test_utils/mock_executor.py index 0143c95fae853..746caf4dfcb6a 100644 --- a/tests/test_utils/mock_executor.py +++ b/tests/test_utils/mock_executor.py @@ -67,10 +67,9 @@ def sort_by(item): open_slots = self.parallelism - len(self.running) sorted_queue = sorted(self.queued_tasks.items(), key=sort_by) for index in range(min((open_slots, len(sorted_queue)))): - (key, (_, _, _, simple_ti)) = sorted_queue[index] + (key, (_, _, _, ti)) = sorted_queue[index] self.queued_tasks.pop(key) state = self.mock_task_results[key] - ti = simple_ti.construct_task_instance(session=session, lock_for_update=True) ti.set_state(state, session=session) self.change_state(key, state) diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py index 9a51c761a8655..9cdce92426d42 100644 --- a/tests/utils/test_dag_processing.py +++ b/tests/utils/test_dag_processing.py @@ -25,10 +25,12 @@ from unittest import mock from unittest.mock import MagicMock, PropertyMock +import pytest + from airflow.configuration import conf from airflow.jobs.local_task_job import LocalTaskJob as LJ from airflow.jobs.scheduler_job import DagFileProcessorProcess -from airflow.models import DagBag, TaskInstance as TI +from airflow.models import DagBag, DagModel, TaskInstance as TI from airflow.models.taskinstance import SimpleTaskInstance from airflow.utils import timezone from airflow.utils.dag_processing import ( @@ -40,7 +42,7 @@ from airflow.utils.state import State from tests.test_logging_config import SETTINGS_FILE_VALID, settings_context from tests.test_utils.config import conf_vars -from tests.test_utils.db import clear_db_runs +from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags TEST_DAG_FOLDER = os.path.join( os.path.dirname(os.path.realpath(__file__)), os.pardir, 'dags') @@ -322,6 +324,50 @@ def test_kill_timed_out_processors_no_kill(self, mock_dag_file_processor, mock_p manager._kill_timed_out_processors() mock_dag_file_processor.kill.assert_not_called() + @conf_vars({('core', 'load_examples'): 'False'}) + @pytest.mark.execution_timeout(10) + def test_dag_with_system_exit(self): + """ + Test to check that a DAG with a system.exit() doesn't break the scheduler. + """ + + # We need to _actually_ parse the files here to test the behaviour. + # Right now the parsing code lives in SchedulerJob, even though it's + # called via utils.dag_processing. + from airflow.jobs.scheduler_job import SchedulerJob + + dag_id = 'exit_test_dag' + dag_directory = os.path.normpath(os.path.join(TEST_DAG_FOLDER, os.pardir, "dags_with_system_exit")) + + # Delete the one valid DAG/SerializedDAG, and check that it gets re-created + clear_db_dags() + clear_db_serialized_dags() + + child_pipe, parent_pipe = multiprocessing.Pipe() + + manager = DagFileProcessorManager( + dag_directory=dag_directory, + dag_ids=[], + max_runs=1, + processor_factory=SchedulerJob._create_dag_file_processor, + processor_timeout=timedelta(seconds=5), + signal_conn=child_pipe, + pickle_dags=False, + async_mode=True) + + manager._run_parsing_loop() + + while parent_pipe.poll(timeout=None): + result = parent_pipe.recv() + if isinstance(result, DagParsingStat) and result.done: + break + + # Three files in folder should be processed + assert len(result.file_paths) == 3 + + with create_session() as session: + assert session.query(DagModel).get(dag_id) is not None + class TestDagFileProcessorAgent(unittest.TestCase): def setUp(self): From 807819d142d583746fac8125fd3399954e6dcf0b Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 21 Sep 2020 23:01:41 +0100 Subject: [PATCH 17/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/jobs/scheduler_job.py | 48 +++--- airflow/models/dag.py | 36 ++--- tests/jobs/test_scheduler_job.py | 266 +++++++++++++++++-------------- 3 files changed, 187 insertions(+), 163 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 0a695fdee9d4a..e10e1b27badd5 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -1464,16 +1464,22 @@ def _create_dag_run(self, dag_model: DagModel, dag: DAG, session: Session) -> No Unconditionally create a DAG run for the given DAG, and update the dag_model's fields to control if/when the next DAGRun should be created """ - next_run_date = dag_model.next_dagrun dag.create_dagrun( run_type=DagRunType.SCHEDULED, - execution_date=next_run_date, + execution_date=dag_model.next_dagrun, start_date=timezone.utcnow(), state=State.RUNNING, external_trigger=False, session=session ) + self._update_dag_next_dagrun(dag_model, dag, session) + + # TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in + # memory for larger dags? or expunge_all() + + def _update_dag_next_dagrun(self, dag_model: DagModel, dag: DAG, session: Session) -> None: + # Check max_active_runs, to see if we are _now_ at the limit for this dag? (we've just created # one after all) active_runs_of_dag = session.query(func.count('*')).filter( @@ -1489,19 +1495,9 @@ def _create_dag_run(self, dag_model: DagModel, dag: DAG, session: Session) -> No "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs", dag.dag_id, active_runs_of_dag, dag.max_active_runs ) - dag_model.next_dagrun = None dag_model.next_dagrun_create_after = None else: - next_dagrun_info = dag.next_dagrun_info(next_run_date) - if next_dagrun_info: - dag_model.next_dagrun = next_dagrun_info['execution_date'] - dag_model.next_dagrun_create_after = next_dagrun_info['can_be_created_after'] - else: - dag_model.next_dagrun = None - dag_model.next_dagrun_create_after = None - - # TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in - # memory for larger dags? or expunge_all() + dag_model.next_dagrun, dag_model.next_dagrun_create_after = dag.next_dagrun_info(dag_model.next_dagrun) def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: """ @@ -1509,15 +1505,31 @@ def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: :return: Number of tasks scheduled """ - dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) + dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) - if not dag_run.dag: + if not dag: self.log.error( "Couldn't find dag %s in DagBag/DB!", dag_run.dag_id ) return 0 - if dag_run.execution_date > timezone.utcnow() and not dag_run.dag.allow_future_exec_dates: + if ( + dag_run.start_date and dag.dagrun_timeout and + dag_run.start_date < timezone.utcnow() - dag.dagrun_timeout + ): + dag_run.state = State.FAILED + dag_run.end_date = timezone.utcnow() + self.log.info("Run %s of %s has timed-out", dag_run.run_id, dag_run.dag_id) + session.flush() + + # Work out if we should allow creating a new DagRun now? + self._update_dag_next_dagrun(session.query(DagModel).get(dag_run.dag_id), dag, session) + + # TODO[HA] run `dag.handle_callback` via the DagFileProcessor + + return 0 + + if dag_run.execution_date > timezone.utcnow() and not dag.allow_future_exec_dates: self.log.error( "Execution date is in future: %s", dag_run.execution_date @@ -1533,13 +1545,13 @@ def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: # The longer term fix would be to have `clear` do this, and put DagRuns # in to the queued state, then take DRs out of queued before creating # any new ones - if dag_run.dag.max_active_runs: + if dag.max_active_runs: currently_active_runs = session.query(func.count(TI.execution_date.distinct())).filter( TI.dag_id == dag_run.dag_id, TI.state.notin_(State.finished()) ).scalar() - if currently_active_runs >= dag_run.dag.max_active_runs: + if currently_active_runs >= dag.max_active_runs: return 0 # TODO[HA]: Run verify_integrity, but only if the serialized_dag has changed diff --git a/airflow/models/dag.py b/airflow/models/dag.py index fbf1c71638e86..8413f16d42884 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -28,7 +28,7 @@ from collections import OrderedDict from datetime import datetime, timedelta from typing import ( - TYPE_CHECKING, Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union, cast, + TYPE_CHECKING, Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type, Union, cast, ) import jinja2 @@ -156,8 +156,7 @@ class DAG(BaseDag, LoggingMixin): :type max_active_runs: int :param dagrun_timeout: specify how long a DagRun should be up before timing out / failing, so that new DagRuns can be created. The timeout - is only enforced for scheduled DagRuns, and only once the - # of active DagRuns == max_active_runs. + is only enforced for scheduled DagRuns. :type dagrun_timeout: datetime.timedelta :param sla_miss_callback: specify a function to call when reporting SLA timeouts. @@ -467,7 +466,10 @@ def previous_schedule(self, dttm): elif self.normalized_schedule_interval is not None: return timezone.convert_to_utc(dttm - self.normalized_schedule_interval) - def next_dagrun_info(self, date_last_automated_dagrun : Optional[pendulum.DateTime]): + def next_dagrun_info( + self, + date_last_automated_dagrun : Optional[pendulum.DateTime], + ) -> Tuple[Optional[pendulum.DateTime], Optional[pendulum.DateTime]]: """ Get information about the next DagRun of this dag after ``date_last_automated_dagrun`` -- the execution date, and the earliest it could be scheduled @@ -479,23 +481,17 @@ def next_dagrun_info(self, date_last_automated_dagrun : Optional[pendulum.DateTi if (self.schedule_interval == "@once" and date_last_automated_dagrun) or \ self.schedule_interval is None: # Manual trigger, or already created the run for @once, can short circuit - return None + return (None, None) next_execution_date = self.next_dagrun_after_date(date_last_automated_dagrun) if next_execution_date is None: - return None + return (None, None) if self.schedule_interval == "@once": # For "@once" it can be created "now" - return { - 'execution_date': next_execution_date, - 'can_be_created_after': next_execution_date, - } + return (next_execution_date, next_execution_date) - return { - 'execution_date': next_execution_date, - 'can_be_created_after': self.following_schedule(next_execution_date), - } + return (next_execution_date, self.following_schedule(next_execution_date)) def next_dagrun_after_date(self, date_last_automated_dagrun: Optional[pendulum.DateTime]): """ @@ -1748,17 +1744,13 @@ def bulk_sync_to_db(cls, dags: Collection["DAG"], session=None): t.task_concurrency is not None for t in dag.tasks ) - next_dagrun_info = dag.next_dagrun_info(most_recent_dag_runs.get(dag.dag_id)) - if next_dagrun_info: - orm_dag.next_dagrun = next_dagrun_info['execution_date'] - orm_dag.next_dagrun_create_after = next_dagrun_info['can_be_created_after'] - else: - orm_dag.next_dagrun = None - orm_dag.next_dagrun_create_after = None + orm_dag.next_dagrun, orm_dag.next_dagrun_create_after = dag.next_dagrun_info( + most_recent_dag_runs.get(dag.dag_id), + ) active_runs_of_dag = num_active_runs.get(dag.dag_id, 0) if dag.max_active_runs and active_runs_of_dag >= dag.max_active_runs: - # Since this happens every time the dag is parsed it would be quite spammy + # Since this happens every time the dag is parsed it would be quite spammy at info log.debug( "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs", dag.dag_id, active_runs_of_dag, dag.max_active_runs diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index fdb19605fce89..95039e0f5353f 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -524,41 +524,6 @@ def test_dag_file_processor_process_task_instances_depends_on_past(self, state, assert tis[0].state == State.SCHEDULED assert tis[1].state == State.SCHEDULED - @pytest.mark.xfail(run=False, reason="TODO[HA]") - def test_dag_file_processor_do_not_schedule_removed_task(self): - dag = DAG( - dag_id='test_scheduler_do_not_schedule_removed_task', - start_date=DEFAULT_DATE) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - session.close() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - - dr = DagRun.find(run_id=dr.run_id)[0] - # Re-create the DAG, but remove the task - dag = DAG( - dag_id='test_scheduler_do_not_schedule_removed_task', - start_date=DEFAULT_DATE) - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - mock_list = dag_file_processor._process_task_instances(dag, dag_runs=[dr]) - - self.assertEqual([], mock_list) - @pytest.mark.xfail(run=False, reason="TODO[HA]") def test_dag_file_processor_add_new_task(self): """ @@ -601,88 +566,6 @@ def test_dag_file_processor_add_new_task(self): tis = dr.get_task_instances() self.assertEqual(len(tis), 2) - @pytest.mark.xfail(run=False, reason="TODO[HA]") - def test_dag_file_processor_fail_dagrun_timeout(self): - """ - Test if a a dagrun wil be set failed if timeout - """ - dag = DAG( - dag_id='test_scheduler_fail_dagrun_timeout', - start_date=DEFAULT_DATE) - dag.dagrun_timeout = datetime.timedelta(seconds=60) - - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - dr.start_date = timezone.utcnow() - datetime.timedelta(days=1) - session.merge(dr) - session.commit() - - dr2 = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr2) - - dr.refresh_from_db(session=session) - self.assertEqual(dr.state, State.FAILED) - - @pytest.mark.xfail(run=False, reason="TODO[HA]") - def test_dag_file_processor_verify_max_active_runs_and_dagrun_timeout(self): - """ - Test if a a dagrun will not be scheduled if max_dag_runs - has been reached and dagrun_timeout is not reached - - Test if a a dagrun will be scheduled if max_dag_runs has - been reached but dagrun_timeout is also reached - """ - dag = DAG( - dag_id='test_scheduler_verify_max_active_runs_and_dagrun_timeout', - start_date=DEFAULT_DATE) - dag.max_active_runs = 1 - dag.dagrun_timeout = datetime.timedelta(seconds=60) - - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - session.close() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - - # Should not be scheduled as DagRun has not timedout and max_active_runs is reached - new_dr = dag_file_processor.create_dag_run(dag) - self.assertIsNone(new_dr) - - # Should be scheduled as dagrun_timeout has passed - dr.start_date = timezone.utcnow() - datetime.timedelta(days=1) - session.merge(dr) - session.commit() - new_dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(new_dr) - def test_runs_respected_after_clear(self): """ Test if _process_task_instances only schedules ti's up to max_active_runs @@ -2177,6 +2060,147 @@ def test_scheduler_loop_should_change_state_for_tis_without_dagrun(self, self.assertEqual(ti.start_date, ti.end_date) self.assertIsNotNone(ti.duration) + def test_dagrun_timeout_verify_max_active_runs(self): + """ + Test if a a dagrun will not be scheduled if max_dag_runs + has been reached and dagrun_timeout is not reached + + Test if a a dagrun would be scheduled if max_dag_runs has + been reached but dagrun_timeout is also reached + """ + dag = DAG( + dag_id='test_scheduler_verify_max_active_runs_and_dagrun_timeout', + start_date=DEFAULT_DATE) + dag.max_active_runs = 1 + dag.dagrun_timeout = datetime.timedelta(seconds=60) + + DummyOperator( + task_id='dummy', + dag=dag, + owner='airflow') + + scheduler = SchedulerJob() + scheduler.dagbag.bag_dag(dag, root_dag=dag) + scheduler.dagbag.sync_to_db() + + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + assert orm_dag is not None + + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + scheduler = SchedulerJob() + scheduler._create_dag_run(orm_dag, dag, session) + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + # Should not be able to create a new dag run, as we are at max active runs + assert orm_dag.next_dagrun_create_after is None + # But we should record the date of _what run_ it would be + assert isinstance(orm_dag.next_dagrun, datetime.datetime) + + # Should be scheduled as dagrun_timeout has passed + dr.start_date = timezone.utcnow() - datetime.timedelta(days=1) + session.flush() + + scheduler._schedule_dag_run(dr, session) + session.flush() + + session.refresh(dr) + assert dr.state == State.FAILED + session.refresh(orm_dag) + assert isinstance(orm_dag.next_dagrun, datetime.datetime) + assert isinstance(orm_dag.next_dagrun_create_after, datetime.datetime) + + # TODO[HA] Verify dag failure callback request sent to file processor + + session.rollback() + session.close() + + def test_dagrun_timeout_fails_run(self): + """ + Test if a a dagrun wil be set failed if timeout, even without max_active_runs + """ + dag = DAG( + dag_id='test_scheduler_fail_dagrun_timeout', + start_date=DEFAULT_DATE) + dag.dagrun_timeout = datetime.timedelta(seconds=60) + + DummyOperator( + task_id='dummy', + dag=dag, + owner='airflow') + + scheduler = SchedulerJob() + scheduler.dagbag.bag_dag(dag, root_dag=dag) + scheduler.dagbag.sync_to_db() + + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + assert orm_dag is not None + + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + scheduler = SchedulerJob() + scheduler._create_dag_run(orm_dag, dag, session) + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + # Should be scheduled as dagrun_timeout has passed + dr.start_date = timezone.utcnow() - datetime.timedelta(days=1) + session.flush() + + scheduler._schedule_dag_run(dr, session) + session.flush() + + session.refresh(dr) + assert dr.state == State.FAILED + + # TODO[HA] Verify dag failure callback request sent to file processor + + session.rollback() + session.close() + + def test_do_not_schedule_removed_task(self): + dag = DAG( + dag_id='test_scheduler_do_not_schedule_removed_task', + start_date=DEFAULT_DATE) + DummyOperator( + task_id='dummy', + dag=dag, + owner='airflow') + + session = settings.Session() + dag.sync_to_db(session=session) + session.flush() + + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + dr = dag.create_dagrun( + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + session=session, + ) + self.assertIsNotNone(dr) + + # Re-create the DAG, but remove the task + dag = DAG( + dag_id='test_scheduler_do_not_schedule_removed_task', + start_date=DEFAULT_DATE) + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + scheduler = SchedulerJob() + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) + + self.assertEqual([], res) + session.rollback() + session.close() + @provide_session def evaluate_dagrun( self, @@ -2310,7 +2334,6 @@ def test_dagrun_root_after_dagrun_unfinished(self): dag = self.dagbag.get_dag(dag_id) dag.sync_to_db() scheduler = SchedulerJob( - dag_id, num_runs=1, executor=self.null_exec, subdir=dag.fileloc) @@ -2407,7 +2430,6 @@ def test_scheduler_start_date(self): scheduler = SchedulerJob(dag_id, executor=self.null_exec, - subdir=dag.fileloc, num_runs=1) scheduler.run() @@ -2455,8 +2477,7 @@ def test_scheduler_multiprocessing(self): dag = self.dagbag.get_dag(dag_id) dag.clear() - scheduler = SchedulerJob(dag_ids=dag_ids, - executor=self.null_exec, + scheduler = SchedulerJob(executor=self.null_exec, subdir=os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py'), num_runs=1) scheduler.run() @@ -2478,8 +2499,7 @@ def test_scheduler_multiprocessing_with_spawn_method(self): dag = self.dagbag.get_dag(dag_id) dag.clear() - scheduler = SchedulerJob(dag_ids=dag_ids, - executor=self.null_exec, + scheduler = SchedulerJob(executor=self.null_exec, subdir=os.path.join( TEST_DAG_FOLDER, 'test_scheduler_dags.py'), num_runs=1) From 2d5b0673447acdf7c1ec46f909b6c1b83f6959ce Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 18 Sep 2020 23:40:41 +0100 Subject: [PATCH 18/70] Move callbacks from Scheduler loop to DagProcessorProcess --- airflow/jobs/scheduler_job.py | 143 ++++++++++++++++++++--------- airflow/models/dagrun.py | 54 ++++++++--- airflow/settings.py | 3 + airflow/utils/callback_requests.py | 101 ++++++++++++++++++++ airflow/utils/dag_processing.py | 59 ++++++------ tests/jobs/test_scheduler_job.py | 59 +++++++++--- tests/models/test_dagrun.py | 85 +++++++++++++++++ tests/utils/test_dag_processing.py | 5 +- 8 files changed, 410 insertions(+), 99 deletions(-) create mode 100644 airflow/utils/callback_requests.py diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index e10e1b27badd5..c38b60678db70 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -47,13 +47,14 @@ from airflow.models import DAG, DagModel, SlaMiss, errors from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun -from airflow.models.taskinstance import TaskInstanceKey +from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey from airflow.stats import Stats from airflow.ti_deps.dependencies_states import EXECUTION_STATES from airflow.utils import timezone -from airflow.utils.dag_processing import ( - AbstractDagFileProcessorProcess, DagFileProcessorAgent, FailureCallbackRequest, SimpleDagBag, +from airflow.utils.callback_requests import ( + CallbackRequest, DagCallbackRequest, SlaCallbackRequest, TaskCallbackRequest, ) +from airflow.utils.dag_processing import AbstractDagFileProcessorProcess, DagFileProcessorAgent, SimpleDagBag from airflow.utils.email import get_email_address_list, send_email from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context from airflow.utils.mixins import MultiprocessingStartMethodMixin @@ -76,8 +77,8 @@ class DagFileProcessorProcess(AbstractDagFileProcessorProcess, LoggingMixin, Mul :type pickle_dags: bool :param dag_ids: If specified, only look at these DAG ID's :type dag_ids: List[str] - :param failure_callback_requests: failure callback to execute - :type failure_callback_requests: List[airflow.utils.dag_processing.FailureCallbackRequest] + :param callback_requests: failure callback to execute + :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest] """ # Counter that increments every time an instance of this class is created @@ -88,13 +89,13 @@ def __init__( file_path: str, pickle_dags: bool, dag_ids: Optional[List[str]], - failure_callback_requests: List[FailureCallbackRequest] + callback_requests: List[CallbackRequest], ): super().__init__() self._file_path = file_path self._pickle_dags = pickle_dags self._dag_ids = dag_ids - self._failure_callback_requests = failure_callback_requests + self._callback_requests = callback_requests # The process that was launched to process the given . self._process: Optional[multiprocessing.process.BaseProcess] = None @@ -123,7 +124,7 @@ def _run_file_processor( pickle_dags: bool, dag_ids: Optional[List[str]], thread_name: str, - failure_callback_requests: List[FailureCallbackRequest] + callback_requests: List[CallbackRequest], ) -> None: """ Process the given file. @@ -142,8 +143,8 @@ def _run_file_processor( :type dag_ids: list[str] :param thread_name: the name to use for the process that is launched :type thread_name: str - :param failure_callback_requests: failure callback to execute - :type failure_callback_requests: list[airflow.utils.dag_processing.FailureCallbackRequest] + :param callback_requests: failure callback to execute + :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest] :return: the process that was launched :rtype: multiprocessing.Process """ @@ -179,7 +180,7 @@ def _run_file_processor( result: Tuple[int, int] = dag_file_processor.process_file( file_path=file_path, pickle_dags=pickle_dags, - failure_callback_requests=failure_callback_requests, + callback_requests=callback_requests, ) result_channel.send(result) end_time = time.time() @@ -214,7 +215,7 @@ def start(self) -> None: self._pickle_dags, self._dag_ids, "DagFileProcessor{}".format(self._instance_id), - self._failure_callback_requests + self._callback_requests ), name="DagFileProcessor{}-Process".format(self._instance_id) ) @@ -564,42 +565,64 @@ def update_import_errors(session: Session, dagbag: DagBag) -> None: session.commit() @provide_session - def execute_on_failure_callbacks( + def execute_callbacks( self, dagbag: DagBag, - failure_callback_requests: List[FailureCallbackRequest], + callback_requests: List[CallbackRequest], session: Session = None ) -> None: """ Execute on failure callbacks. These objects can come from SchedulerJob or from DagFileProcessorManager. - :param failure_callback_requests: failure callbacks to execute - :type failure_callback_requests: List[airflow.utils.dag_processing.FailureCallbackRequest] + :param dagbag: Dag Bag of dags + :param callback_requests: failure callbacks to execute + :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest] :param session: DB session. """ - for request in failure_callback_requests: - simple_ti = request.simple_task_instance - if simple_ti.dag_id in dagbag.dags: - dag = dagbag.dags[simple_ti.dag_id] - if simple_ti.task_id in dag.task_ids: - task = dag.get_task(simple_ti.task_id) - ti = TI(task, simple_ti.execution_date) - # Get properties needed for failure handling from SimpleTaskInstance. - ti.start_date = simple_ti.start_date - ti.end_date = simple_ti.end_date - ti.try_number = simple_ti.try_number - ti.state = simple_ti.state - ti.test_mode = self.UNIT_TEST_MODE + for request in callback_requests: + if isinstance(request, TaskCallbackRequest): + self._execute_task_callbacks(dagbag, request) + elif isinstance(request, SlaCallbackRequest): + self.manage_slas(dagbag.dags.get(request.dag_id)) + elif isinstance(request, DagCallbackRequest): + self._execute_dag_callbacks(dagbag, request, session) + + session.commit() + + @provide_session + def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session): + dag = dagbag.dags[request.dag_id] + dag_run = dag.get_dagrun(execution_date=request.execution_date, session=session) + dag.handle_callback( + dagrun=dag_run, + success=not request.is_failure_callback, + reason=request.msg, + session=session + ) + + def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest): + simple_ti = request.simple_task_instance + if simple_ti.dag_id in dagbag.dags: + dag = dagbag.dags[simple_ti.dag_id] + if simple_ti.task_id in dag.task_ids: + task = dag.get_task(simple_ti.task_id) + ti = TI(task, simple_ti.execution_date) + # Get properties needed for failure handling from SimpleTaskInstance. + ti.start_date = simple_ti.start_date + ti.end_date = simple_ti.end_date + ti.try_number = simple_ti.try_number + ti.state = simple_ti.state + ti.test_mode = self.UNIT_TEST_MODE + if request.is_failure_callback: ti.handle_failure(request.msg, ti.test_mode, ti.get_template_context()) self.log.info('Executed failure callback for %s in state %s', ti, ti.state) - session.commit() @provide_session def process_file( self, file_path: str, - failure_callback_requests: List[FailureCallbackRequest], + callback_requests: List[CallbackRequest], pickle_dags: bool = False, session: Session = None ) -> Tuple[int, int]: @@ -621,8 +644,8 @@ def process_file( :param file_path: the path to the Python file that should be executed :type file_path: str - :param failure_callback_requests: failure callback to execute - :type failure_callback_requests: List[airflow.utils.dag_processing.FailureCallbackRequest] + :param callback_requests: failure callback to execute + :type callback_requests: List[airflow.utils.dag_processing.CallbackRequest] :param pickle_dags: whether serialize the DAGs found in the file and save them to the db :type pickle_dags: bool @@ -648,9 +671,9 @@ def process_file( return 0, len(dagbag.import_errors) try: - self.execute_on_failure_callbacks(dagbag, failure_callback_requests) + self.execute_callbacks(dagbag, callback_requests) except Exception: # pylint: disable=broad-except - self.log.exception("Error executing failure callback!") + self.log.exception("Error executing callback!") # Save individual DAGs in the ORM dagbag.read_dags_from_db = True @@ -1214,12 +1237,14 @@ def _process_executor_events(self, session: Session = None) -> int: msg = "Executor reports task instance %s finished (%s) although the " \ "task says its %s. (Info: %s) Was the task killed externally?" self.log.error(msg, ti, state, ti.state, info) - self.processor_agent.send_callback_to_execute( + request = TaskCallbackRequest( full_filepath=ti.dag_model.fileloc, - task_instance=ti, + simple_task_instance=SimpleTaskInstance(ti), msg=msg % (ti, state, ti.state, info), ) + self.processor_agent.send_callback_to_execute(request) + return len(event_buffer) def _execute(self) -> None: @@ -1287,7 +1312,7 @@ def _execute(self) -> None: @staticmethod def _create_dag_file_processor( file_path: str, - failure_callback_requests: List[FailureCallbackRequest], + callback_requests: List[CallbackRequest], dag_ids: Optional[List[str]], pickle_dags: bool ) -> DagFileProcessorProcess: @@ -1298,7 +1323,7 @@ def _create_dag_file_processor( file_path=file_path, pickle_dags=pickle_dags, dag_ids=dag_ids, - failure_callback_requests=failure_callback_requests + callback_requests=callback_requests ) def _run_scheduler_loop(self) -> None: @@ -1525,7 +1550,16 @@ def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: # Work out if we should allow creating a new DagRun now? self._update_dag_next_dagrun(session.query(DagModel).get(dag_run.dag_id), dag, session) - # TODO[HA] run `dag.handle_callback` via the DagFileProcessor + dag_run.callback = DagCallbackRequest( + full_filepath=dag.fileloc, + dag_id=dag.dag_id, + execution_date=dag_run.execution_date, + is_failure_callback=True, + msg='timed_out' + ) + + # Send SLA & DAG Success/Failure Callbacks to be executed + self._send_dag_callbacks_to_processor(dag_run) return 0 @@ -1557,7 +1591,7 @@ def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: # TODO[HA]: Run verify_integrity, but only if the serialized_dag has changed # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else? - schedulable_tis = dag_run.update_state(session=session) + schedulable_tis = dag_run.update_state(session=session, execute_callbacks=False) # TODO[HA]: Don't return, update these from in update_state? count = session.query(TI).filter( TI.dag_id == dag_run.dag_id, @@ -1565,10 +1599,33 @@ def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: TI.task_id.in_(ti.task_id for ti in schedulable_tis) ).update({TI.state: State.SCHEDULED}, synchronize_session=False) - # TODO[HA]: Manage SLAs - return count + def _send_dag_callbacks_to_processor(self, dag_run: DagRun): + if not self.processor_agent: + raise ValueError("Processor agent is not started.") + + dag = dag_run.get_dag() + self._manage_slas(dag) + if dag_run.callback: + self.processor_agent.send_callback_to_execute(dag_run.callback) + + def _manage_slas(self, dag: DAG): + if not settings.CHECK_SLAS: + return + + if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks): + self.log.debug("Skipping SLA check for %s because no tasks in DAG have SLAs", dag) + return + + if not self.processor_agent: + raise ValueError("Processor agent is not started.") + + self.processor_agent.send_sla_callback_request_to_execute( + full_filepath=dag.fileloc, + dag_id=dag.dag_id + ) + @provide_session def _emit_pool_metrics(self, session: Session = None) -> None: pools = models.Pool.slots_stats(session=session) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index d60e326797e0a..de338213ac3ec 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -19,7 +19,7 @@ from typing import Any, List, Optional, Tuple, Union from sqlalchemy import ( - Boolean, Column, DateTime, Index, Integer, PickleType, String, UniqueConstraint, and_, func, or_, + Boolean, Column, DateTime, Index, Integer, PickleType, String, UniqueConstraint, and_, func, or_, orm, ) from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.declarative import declared_attr @@ -35,7 +35,7 @@ from airflow.stats import Stats from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES -from airflow.utils import timezone +from airflow.utils import callback_requests, timezone from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import provide_session from airflow.utils.sqlalchemy import UtcDateTime, skip_locked @@ -101,8 +101,13 @@ def __init__( self.conf = conf or {} self.state = state self.run_type = run_type + self.callback: Optional[callback_requests.DagCallbackRequest] = None super().__init__() + @orm.reconstructor + def init_on_load(self): + self.callback: Optional[callback_requests.DagCallbackRequest] = None + def __repr__(self): return ( ' List[TI]: + def update_state(self, session: Session = None, execute_callbacks: bool = True) -> List[TI]: """ Determines the overall state of the DagRun based on the state of its TaskInstances. :param session: Sqlalchemy ORM Session :type session: Session + :param execute_callbacks: Should dag callbacks (success/failure, SLA etc) be invoked + directly (default: true) or recorded as a pending request in the ``callback`` property + :type execute_callbacks: bool :return: ready_tis: the tis that can be scheduled in the current loop :rtype ready_tis: list[airflow.models.TaskInstance] """ @@ -372,8 +380,7 @@ def update_state(self, session: Session = None) -> List[TI]: unfinished_tasks = [t for t in tis if t.state in State.unfinished()] finished_tasks = [t for t in tis if t.state in State.finished() + [State.UPSTREAM_FAILED]] none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks) - none_task_concurrency = all(t.task.task_concurrency is None - for t in unfinished_tasks) + none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks) if unfinished_tasks: scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES] self.log.debug( @@ -392,16 +399,22 @@ def update_state(self, session: Session = None) -> List[TI]: leaf_task_ids = {t.task_id for t in dag.leaves} leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids] - # TODO[ha]: These callbacks shouldn't run in the scheduler loop - check if Kamil changed this to run - # via the dag processor! - # if all roots finished and at least one failed, the run failed if not unfinished_tasks and any( leaf_ti.state in {State.FAILED, State.UPSTREAM_FAILED} for leaf_ti in leaf_tis ): self.log.error('Marking run %s failed', self) self.set_state(State.FAILED) - dag.handle_callback(self, success=False, reason='task_failure', session=session) + if execute_callbacks: + dag.handle_callback(self, success=False, reason='task_failure', session=session) + else: + self.callback = callback_requests.DagCallbackRequest( + full_filepath=dag.fileloc, + dag_id=self.dag_id, + execution_date=self.execution_date, + is_failure_callback=True, + msg='task_failure' + ) # if all leafs succeeded and no unfinished tasks, the run succeeded elif not unfinished_tasks and all( @@ -409,15 +422,32 @@ def update_state(self, session: Session = None) -> List[TI]: ): self.log.info('Marking run %s successful', self) self.set_state(State.SUCCESS) - dag.handle_callback(self, success=True, reason='success', session=session) + if execute_callbacks: + dag.handle_callback(self, success=True, reason='success', session=session) + else: + self.callback = callback_requests.DagCallbackRequest( + full_filepath=dag.fileloc, + dag_id=self.dag_id, + execution_date=self.execution_date, + is_failure_callback=False, + msg='success' + ) # if *all tasks* are deadlocked, the run failed elif (unfinished_tasks and none_depends_on_past and none_task_concurrency and not are_runnable_tasks): self.log.error('Deadlock; marking run %s failed', self) self.set_state(State.FAILED) - dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', - session=session) + if execute_callbacks: + dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', session=session) + else: + self.callback = callback_requests.DagCallbackRequest( + full_filepath=dag.fileloc, + dag_id=self.dag_id, + execution_date=self.execution_date, + is_failure_callback=True, + msg='all_tasks_deadlocked' + ) # finally, if the roots aren't done, the dag is still running else: diff --git a/airflow/settings.py b/airflow/settings.py index d42a6251bbb5f..0f9084852495f 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -375,3 +375,6 @@ def initialize(): ) ALLOW_FUTURE_EXEC_DATES = conf.getboolean('scheduler', 'allow_trigger_in_future', fallback=False) + +# Whether or not to check each dagrun against defined SLAs +CHECK_SLAS = conf.getboolean('core', 'check_slas', fallback=True) diff --git a/airflow/utils/callback_requests.py b/airflow/utils/callback_requests.py new file mode 100644 index 0000000000000..fe8017c721fb8 --- /dev/null +++ b/airflow/utils/callback_requests.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime +from typing import Optional + +from airflow.models.taskinstance import SimpleTaskInstance + + +class CallbackRequest: + """ + Base Class with information about the callback to be executed. + + :param full_filepath: File Path to use to run the callback + :param msg: Additional Message that can be used for logging + """ + + def __init__(self, full_filepath: str, msg: Optional[str] = None): + self.full_filepath = full_filepath + self.msg = msg + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return str(self.__dict__) + + +class TaskCallbackRequest(CallbackRequest): + """ + A Class with information about the success/failure TI callback to be executed. Currently, only failure + callbacks (when tasks are externally killed) and Zombies are run via DagFileProcessorProcess. + + :param full_filepath: File Path to use to run the callback + :param simple_task_instance: Simplified Task Instance representation + :param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback + :param msg: Additional Message that can be used for logging to determine failure/zombie + """ + + def __init__( + self, + full_filepath: str, + simple_task_instance: SimpleTaskInstance, + is_failure_callback: Optional[bool] = True, + msg: Optional[str] = None + ): + super().__init__(full_filepath=full_filepath, msg=msg) + self.simple_task_instance = simple_task_instance + self.is_failure_callback = is_failure_callback + + +class DagCallbackRequest(CallbackRequest): + """ + A Class with information about the success/failure DAG callback to be executed. + + :param full_filepath: File Path to use to run the callback + :param dag_id: DAG ID + :param execution_date: Execution Date for the DagRun + :param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback + :param msg: Additional Message that can be used for logging + """ + + def __init__( + self, + full_filepath: str, + dag_id: str, + execution_date: datetime, + is_failure_callback: Optional[bool] = True, + msg: Optional[str] = None + ): + super().__init__(full_filepath=full_filepath, msg=msg) + self.dag_id = dag_id + self.execution_date = execution_date + self.is_failure_callback = is_failure_callback + + +class SlaCallbackRequest(CallbackRequest): + """ + A class with information about the SLA callback to be executed. + + :param full_filepath: File Path to use to run the callback + :param dag_id: DAG ID + """ + + def __init__(self, full_filepath: str, dag_id: str): + super().__init__(full_filepath) + self.dag_id = dag_id diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index ddc20687e3c2e..7fccf5139c6a0 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -41,11 +41,12 @@ from airflow.dag.base_dag import BaseDagBag from airflow.exceptions import AirflowException from airflow.models import errors -from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance +from airflow.models.taskinstance import SimpleTaskInstance from airflow.serialization.serialized_objects import SerializedDAG from airflow.settings import STORE_DAG_CODE, STORE_SERIALIZED_DAGS from airflow.stats import Stats from airflow.utils import timezone +from airflow.utils.callback_requests import CallbackRequest, SlaCallbackRequest, TaskCallbackRequest from airflow.utils.file import list_py_file_paths from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.mixins import MultiprocessingStartMethodMixin @@ -211,14 +212,6 @@ class DagParsingSignal(enum.Enum): END_MANAGER = 'end_manager' -class FailureCallbackRequest(NamedTuple): - """A message with information about the callback to be executed.""" - - full_filepath: str - simple_task_instance: SimpleTaskInstance - msg: str - - class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): """ Agent for DAG file processing. It is responsible for all DAG parsing @@ -236,7 +229,7 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): :type max_runs: int :param processor_factory: function that creates processors for DAG definition files. Arguments are (dag_definition_path, log_file_path) - :type processor_factory: ([str, List[FailureCallbackRequest], Optional[List[str]], bool]) -> ( + :type processor_factory: ([str, List[CallbackRequest], Optional[List[str]], bool]) -> ( AbstractDagFileProcessorProcess ) :param processor_timeout: How long to wait before timing out a DAG file processor @@ -254,7 +247,7 @@ def __init__( dag_directory: str, max_runs: int, processor_factory: Callable[ - [str, List[FailureCallbackRequest], Optional[List[str]], bool], + [str, List[CallbackRequest], Optional[List[str]], bool], AbstractDagFileProcessorProcess ], processor_timeout: timedelta, @@ -333,27 +326,35 @@ def run_single_parsing_loop(self) -> None: # when harvest_serialized_dags calls _heartbeat_manager. pass - def send_callback_to_execute( - self, full_filepath: str, task_instance: TaskInstance, msg: str - ) -> None: + def send_callback_to_execute(self, request: CallbackRequest) -> None: """ Sends information about the callback to be executed by DagFileProcessor. + :param request: Callback request to be executed. + :type request: CallbackRequest + """ + if not self._parent_signal_conn: + raise ValueError("Process not started.") + try: + self._parent_signal_conn.send(request) + except ConnectionError: + # If this died cos of an error then we will noticed and restarted + # when harvest_serialized_dags calls _heartbeat_manager. + pass + + def send_sla_callback_request_to_execute(self, full_filepath: str, dag_id: str) -> None: + """ + Sends information about the SLA callback to be executed by DagFileProcessor. + :param full_filepath: DAG File path :type full_filepath: str - :param task_instance: Task Instance for which the callback is to be executed. - :type task_instance: airflow.models.taskinstance.TaskInstance - :param msg: Message sent in callback. - :type msg: str + :param dag_id: DAG ID + :type dag_id: str """ if not self._parent_signal_conn: raise ValueError("Process not started.") try: - request = FailureCallbackRequest( - full_filepath=full_filepath, - simple_task_instance=SimpleTaskInstance(task_instance), - msg=msg - ) + request = SlaCallbackRequest(full_filepath=full_filepath, dag_id=dag_id) self._parent_signal_conn.send(request) except ConnectionError: # If this died cos of an error then we will noticed and restarted @@ -381,7 +382,7 @@ def _run_processor_manager( dag_directory: str, max_runs: int, processor_factory: Callable[ - [str, List[FailureCallbackRequest]], + [str, List[CallbackRequest]], AbstractDagFileProcessorProcess ], processor_timeout: timedelta, @@ -550,7 +551,7 @@ def __init__(self, dag_directory: str, max_runs: int, processor_factory: Callable[ - [str, List[FailureCallbackRequest]], + [str, List[CallbackRequest]], AbstractDagFileProcessorProcess ], processor_timeout: timedelta, @@ -613,7 +614,7 @@ def __init__(self, self.dag_dir_list_interval = conf.getint('scheduler', 'dag_dir_list_interval') # Mapping file name and callbacks requests - self._callback_to_execute: Dict[str, List[FailureCallbackRequest]] = defaultdict(list) + self._callback_to_execute: Dict[str, List[CallbackRequest]] = defaultdict(list) self._log = logging.getLogger('airflow.processor_manager') @@ -692,7 +693,7 @@ def _run_parsing_loop(self): elif agent_signal == DagParsingSignal.AGENT_RUN_ONCE: # continue the loop to parse dags pass - elif isinstance(agent_signal, FailureCallbackRequest): + elif isinstance(agent_signal, CallbackRequest): self._add_callback_to_queue(agent_signal) else: raise ValueError(f"Invalid message {type(agent_signal)}") @@ -774,7 +775,7 @@ def _run_parsing_loop(self): else: poll_time = 0.0 - def _add_callback_to_queue(self, request: FailureCallbackRequest): + def _add_callback_to_queue(self, request: CallbackRequest): self._callback_to_execute[request.full_filepath].append(request) # Callback has a higher priority over DAG Run scheduling if request.full_filepath in self._file_path_queue: @@ -1179,7 +1180,7 @@ def _find_zombies(self, session): self._last_zombie_query_time = timezone.utcnow() for ti, file_loc in zombies: - request = FailureCallbackRequest( + request = TaskCallbackRequest( full_filepath=file_loc, simple_task_instance=SimpleTaskInstance(ti), msg="Detected as zombie", diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 95039e0f5353f..7339bb5d72150 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -47,7 +47,8 @@ from airflow.operators.dummy_operator import DummyOperator from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils import timezone -from airflow.utils.dag_processing import FailureCallbackRequest, SimpleDagBag +from airflow.utils.callback_requests import DagCallbackRequest, TaskCallbackRequest +from airflow.utils.dag_processing import SimpleDagBag from airflow.utils.dates import days_ago from airflow.utils.file import list_py_file_paths from airflow.utils.session import create_session, provide_session @@ -668,13 +669,13 @@ def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): session.commit() requests = [ - FailureCallbackRequest( + TaskCallbackRequest( full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message" ) ] - dag_file_processor.execute_on_failure_callbacks(dagbag, requests) + dag_file_processor.execute_callbacks(dagbag, requests) mock_ti_handle_failure.assert_called_once_with( "Message", conf.getboolean('core', 'unit_test_mode'), @@ -698,7 +699,7 @@ def test_process_file_should_failure_callback(self): session.commit() requests = [ - FailureCallbackRequest( + TaskCallbackRequest( full_filepath=dag.full_filepath, simple_task_instance=SimpleTaskInstance(ti), msg="Message" @@ -727,7 +728,7 @@ def test_should_mark_dummy_task_as_success(self): dagbag.sync_to_db() serialized_dags, import_errors_count = dag_file_processor.process_file( - file_path=dag_file, failure_callback_requests=[] + file_path=dag_file, callback_requests=[] ) dags = [SerializedDAG.from_dict(serialized_dag) for serialized_dag in serialized_dags] @@ -757,7 +758,7 @@ def test_should_mark_dummy_task_as_success(self): self.assertIsNone(duration) dag_file_processor.process_file( - file_path=dag_file, failure_callback_requests=[] + file_path=dag_file, callback_requests=[] ) with create_session() as session: tis = session.query(TaskInstance).all() @@ -980,8 +981,9 @@ def test_no_orphan_process_will_be_left(self): old_children) self.assertFalse(current_children) + @mock.patch('airflow.jobs.scheduler_job.TaskCallbackRequest') @mock.patch('airflow.jobs.scheduler_job.Stats.incr') - def test_process_executor_events(self, mock_stats_incr): + def test_process_executor_events(self, mock_stats_incr, mock_task_callback): dag_id = "test_process_executor_events" dag_id2 = "test_process_executor_events_2" task_id_1 = 'dummy_task' @@ -994,6 +996,8 @@ def test_process_executor_events(self, mock_stats_incr): dag2.fileloc = "/test_path1/" executor = MockExecutor(do_update=False) + task_callback = mock.MagicMock() + mock_task_callback.return_value = task_callback scheduler = SchedulerJob(executor=executor) scheduler.processor_agent = mock.MagicMock() @@ -1011,14 +1015,15 @@ def test_process_executor_events(self, mock_stats_incr): scheduler._process_executor_events(session=session) ti1.refresh_from_db() self.assertEqual(ti1.state, State.QUEUED) - scheduler.processor_agent.send_callback_to_execute.assert_called_once_with( + mock_task_callback.assert_called_once_with( full_filepath='/test_path1/', - task_instance=mock.ANY, + simple_task_instance=mock.ANY, msg='Executor reports task instance ' ' ' 'finished (failed) although the task says its queued. (Info: None) ' 'Was the task killed externally?' ) + scheduler.processor_agent.send_callback_to_execute.assert_called_once_with(task_callback) scheduler.processor_agent.reset_mock() # ti in success state @@ -2105,6 +2110,10 @@ def test_dagrun_timeout_verify_max_active_runs(self): dr.start_date = timezone.utcnow() - datetime.timedelta(days=1) session.flush() + # Mock that processor_agent is started + scheduler.processor_agent = mock.Mock() + scheduler.processor_agent.send_callback_to_execute = mock.Mock() + scheduler._schedule_dag_run(dr, session) session.flush() @@ -2114,14 +2123,24 @@ def test_dagrun_timeout_verify_max_active_runs(self): assert isinstance(orm_dag.next_dagrun, datetime.datetime) assert isinstance(orm_dag.next_dagrun_create_after, datetime.datetime) - # TODO[HA] Verify dag failure callback request sent to file processor + # Verify dag failure callback request is added to dagrun.callback + assert dr.callback == DagCallbackRequest( + full_filepath=dr.dag.fileloc, + dag_id=dr.dag_id, + is_failure_callback=True, + execution_date=dr.execution_date, + msg="timed_out" + ) + + # Verify dag failure callback request is sent to file processor + scheduler.processor_agent.send_callback_to_execute.assert_called_once_with(dr.callback) session.rollback() session.close() def test_dagrun_timeout_fails_run(self): """ - Test if a a dagrun wil be set failed if timeout, even without max_active_runs + Test if a a dagrun will be set failed if timeout, even without max_active_runs """ dag = DAG( dag_id='test_scheduler_fail_dagrun_timeout', @@ -2147,20 +2166,34 @@ def test_dagrun_timeout_fails_run(self): scheduler._create_dag_run(orm_dag, dag, session) drs = DagRun.find(dag_id=dag.dag_id, session=session) - assert len(drs) == 1 + assert len(drs) == 1 dr = drs[0] # Should be scheduled as dagrun_timeout has passed dr.start_date = timezone.utcnow() - datetime.timedelta(days=1) session.flush() + # Mock that processor_agent is started + scheduler.processor_agent = mock.Mock() + scheduler.processor_agent.send_callback_to_execute = mock.Mock() + scheduler._schedule_dag_run(dr, session) session.flush() session.refresh(dr) assert dr.state == State.FAILED - # TODO[HA] Verify dag failure callback request sent to file processor + # Verify dag failure callback request is added to dagrun.callback + assert dr.callback == DagCallbackRequest( + full_filepath=dr.dag.fileloc, + dag_id=dr.dag_id, + is_failure_callback=True, + execution_date=dr.execution_date, + msg="timed_out" + ) + + # Verify dag failure callback request is sent to file processor + scheduler.processor_agent.send_callback_to_execute.assert_called_once_with(dr.callback) session.rollback() session.close() diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 4753118fe1c86..7d9367b3bf67e 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -28,6 +28,7 @@ from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python import ShortCircuitOperator from airflow.utils import timezone +from airflow.utils.callback_requests import DagCallbackRequest from airflow.utils.state import State from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType @@ -329,6 +330,8 @@ def on_success_callable(context): task_states=initial_task_states) dag_run.update_state() self.assertEqual(State.SUCCESS, dag_run.state) + # Callbacks are not added until handle_callback = False is passed to dag_run.update_state() + self.assertIsNone(dag_run.callback) def test_dagrun_failure_callback(self): def on_failure_callable(context): @@ -360,6 +363,88 @@ def on_failure_callable(context): task_states=initial_task_states) dag_run.update_state() self.assertEqual(State.FAILED, dag_run.state) + # Callbacks are not added until handle_callback = False is passed to dag_run.update_state() + self.assertIsNone(dag_run.callback) + + def test_dagrun_update_state_with_handle_callback_success(self): + def on_success_callable(context): + self.assertEqual( + context['dag_run'].dag_id, + 'test_dagrun_update_state_with_handle_callback_success' + ) + + dag = DAG( + dag_id='test_dagrun_update_state_with_handle_callback_success', + start_date=datetime.datetime(2017, 1, 1), + on_success_callback=on_success_callable, + ) + dag_task1 = DummyOperator( + task_id='test_state_succeeded1', + dag=dag) + dag_task2 = DummyOperator( + task_id='test_state_succeeded2', + dag=dag) + dag_task1.set_downstream(dag_task2) + + initial_task_states = { + 'test_state_succeeded1': State.SUCCESS, + 'test_state_succeeded2': State.SUCCESS, + } + + dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states) + self.assertIsNone(dag_run.callback) + + dag_run.update_state(execute_callbacks=False) + self.assertEqual(State.SUCCESS, dag_run.state) + # Callbacks are not added until handle_callback = False is passed to dag_run.update_state() + + assert dag_run.callback == DagCallbackRequest( + full_filepath=dag_run.dag.fileloc, + dag_id="test_dagrun_update_state_with_handle_callback_success", + execution_date=dag_run.execution_date, + is_failure_callback=False, + msg="success" + ) + + def test_dagrun_update_state_with_handle_callback_failure(self): + def on_failure_callable(context): + self.assertEqual( + context['dag_run'].dag_id, + 'test_dagrun_update_state_with_handle_callback_failure' + ) + + dag = DAG( + dag_id='test_dagrun_update_state_with_handle_callback_failure', + start_date=datetime.datetime(2017, 1, 1), + on_failure_callback=on_failure_callable, + ) + dag_task1 = DummyOperator( + task_id='test_state_succeeded1', + dag=dag) + dag_task2 = DummyOperator( + task_id='test_state_failed2', + dag=dag) + dag_task1.set_downstream(dag_task2) + + initial_task_states = { + 'test_state_succeeded1': State.SUCCESS, + 'test_state_failed2': State.FAILED, + } + + dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states) + self.assertIsNone(dag_run.callback) + + dag_run.update_state(execute_callbacks=False) + self.assertEqual(State.FAILED, dag_run.state) + # Callbacks are not added until handle_callback = False is passed to dag_run.update_state() + + assert dag_run.callback == DagCallbackRequest( + full_filepath=dag_run.dag.fileloc, + dag_id="test_dagrun_update_state_with_handle_callback_failure", + execution_date=dag_run.execution_date, + is_failure_callback=True, + msg="task_failure" + ) def test_dagrun_set_state_end_date(self): session = settings.Session() diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py index 9cdce92426d42..37eff531fdc23 100644 --- a/tests/utils/test_dag_processing.py +++ b/tests/utils/test_dag_processing.py @@ -33,9 +33,9 @@ from airflow.models import DagBag, DagModel, TaskInstance as TI from airflow.models.taskinstance import SimpleTaskInstance from airflow.utils import timezone +from airflow.utils.callback_requests import TaskCallbackRequest from airflow.utils.dag_processing import ( DagFileProcessorAgent, DagFileProcessorManager, DagFileStat, DagParsingSignal, DagParsingStat, - FailureCallbackRequest, ) from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped from airflow.utils.session import create_session @@ -216,6 +216,7 @@ def test_find_zombies(self): self.assertEqual(1, len(requests)) self.assertEqual(requests[0].full_filepath, dag.full_filepath) self.assertEqual(requests[0].msg, "Detected as zombie") + self.assertEqual(requests[0].is_failure_callback, True) self.assertIsInstance(requests[0].simple_task_instance, SimpleTaskInstance) self.assertEqual(ti.dag_id, requests[0].simple_task_instance.dag_id) self.assertEqual(ti.task_id, requests[0].simple_task_instance.task_id) @@ -252,7 +253,7 @@ def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_p session.commit() fake_failure_callback_requests = [ - FailureCallbackRequest( + TaskCallbackRequest( full_filepath=dag.full_filepath, simple_task_instance=SimpleTaskInstance(ti), msg="Message" From 58b8514dfde5584812e3385610feb3fe03088031 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 23 Sep 2020 15:10:53 +0100 Subject: [PATCH 19/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/jobs/scheduler_job.py | 3 ++- airflow/models/dag.py | 4 ++-- airflow/models/dagrun.py | 3 +-- airflow/stats.py | 2 ++ airflow/utils/sqlalchemy.py | 1 - 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index c38b60678db70..aa903a68dd67a 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -1522,7 +1522,8 @@ def _update_dag_next_dagrun(self, dag_model: DagModel, dag: DAG, session: Sessio ) dag_model.next_dagrun_create_after = None else: - dag_model.next_dagrun, dag_model.next_dagrun_create_after = dag.next_dagrun_info(dag_model.next_dagrun) + dag_model.next_dagrun, dag_model.next_dagrun_create_after = \ + dag.next_dagrun_info(dag_model.next_dagrun) def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: """ diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 8413f16d42884..e3152a7cc11d3 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -28,7 +28,8 @@ from collections import OrderedDict from datetime import datetime, timedelta from typing import ( - TYPE_CHECKING, Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type, Union, cast, + TYPE_CHECKING, Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type, Union, + cast, ) import jinja2 @@ -2093,7 +2094,6 @@ def dags_needing_dagruns(cls, session: Session): that any scheduling decisions are made in a single transaction -- as soon as the transaction is committed it will be unlocked. """ - # TODO[HA]: Bake this query, it is run _A lot_ # TODO[HA]: Make this limit a tunable. We limit so that _one_ scheduler # doesn't try to do all the creation of dag runs diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index de338213ac3ec..2d89ad6e88eb6 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -164,7 +164,7 @@ def next_dagruns_to_examine( query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as the transaction is committed it will be unlocked. - :rtype: list[DagRun] + :rtype: list[airflow.models.DagRun] """ from airflow.models.dag import DagModel @@ -366,7 +366,6 @@ def update_state(self, session: Session = None, execute_callbacks: bool = True) :return: ready_tis: the tis that can be scheduled in the current loop :rtype ready_tis: list[airflow.models.TaskInstance] """ - start_dttm = timezone.utcnow() self.last_scheduling_decision = start_dttm diff --git a/airflow/stats.py b/airflow/stats.py index 641f3e3a8e5bd..3d4b0875c6276 100644 --- a/airflow/stats.py +++ b/airflow/stats.py @@ -32,6 +32,7 @@ class TimerProtocol(Protocol): """Type protocol for StatsLogger.timer""" + def __enter__(self): ... @@ -73,6 +74,7 @@ def timer(cls, *args, **kwargs) -> TimerProtocol: class DummyTimer: """No-op timer""" + def __enter__(self): return self diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py index 88b1a1f0b333a..3344c59f77525 100644 --- a/airflow/utils/sqlalchemy.py +++ b/airflow/utils/sqlalchemy.py @@ -155,7 +155,6 @@ def nowait(session: Session) -> Dict[str, Any]: See https://jira.mariadb.org/browse/MDEV-13115 """ - dialect = session.bind.dialect if dialect.name != "mysql" or dialect.supports_for_update_of: From 44c60e58f71c02743ca09d75f785e3da63a041e3 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 23 Sep 2020 17:24:34 +0100 Subject: [PATCH 20/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/jobs/backfill_job.py | 1 - airflow/jobs/scheduler_job.py | 22 +++++++++++++++++----- airflow/models/dagrun.py | 2 ++ tests/jobs/test_scheduler_job.py | 19 ++++++++++++++----- tests/models/test_dag.py | 5 +++++ tests/utils/test_dag_processing.py | 18 +++++++++++++++--- 6 files changed, 53 insertions(+), 14 deletions(-) diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py index 7d8271fa14711..c60611915ad97 100644 --- a/airflow/jobs/backfill_job.py +++ b/airflow/jobs/backfill_job.py @@ -630,7 +630,6 @@ def _per_task_process(task, key, ti, session=None): # pylint: disable=too-many- _dag_runs = ti_status.active_runs[:] for run in _dag_runs: run.update_state(session=session) - session.merge(run) if run.state in State.finished(): ti_status.finished_runs += 1 ti_status.active_runs.remove(run) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index aa903a68dd67a..84ec56ca921f7 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -713,9 +713,13 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes :param subdir: directory containing Python files with Airflow DAG definitions, or a specific path to a file :type subdir: str - :param num_runs: The number of times to try to schedule each DAG file. - -1 for unlimited times. + :param num_runs: The number of times to run the scheduling loop. If you + have a large number of DAG files this could complete before each file + has been parsed. -1 for unlimited times. :type num_runs: int + :param num_times_parse_dags: The number of times to try to parse each DAG file. + -1 for unlimited times. + :type num_times_parse_dags: int :param processor_poll_interval: The number of seconds to wait between polls of running processors :type processor_poll_interval: int @@ -733,6 +737,7 @@ def __init__( self, subdir: str = settings.DAGS_FOLDER, num_runs: int = conf.getint('scheduler', 'num_runs'), + num_times_parse_dags: int = -1, processor_poll_interval: float = conf.getfloat('scheduler', 'processor_poll_interval'), do_pickle: bool = False, log: Any = None, @@ -740,6 +745,7 @@ def __init__( self.subdir = subdir self.num_runs = num_runs + self.num_times_parse_dags = num_times_parse_dags self._processor_poll_interval = processor_poll_interval self.do_pickle = do_pickle @@ -1253,7 +1259,7 @@ def _execute(self) -> None: # DAGs can be pickled for easier remote execution by some executors pickle_dags = self.do_pickle and self.executor_class not in UNPICKLEABLE_EXECUTORS - self.log.info("Processing each file at most %s times", self.num_runs) + self.log.info("Processing each file at most %s times", self.num_times_parse_dags) # When using sqlite, we do not use async_mode # so the scheduler job and DAG parser don't access the DB at the same time. @@ -1263,7 +1269,7 @@ def _execute(self) -> None: processor_timeout = timedelta(seconds=processor_timeout_seconds) self.processor_agent = DagFileProcessorAgent( dag_directory=self.subdir, - max_runs=self.num_runs, + max_runs=self.num_times_parse_dags, processor_factory=type(self)._create_dag_file_processor, processor_timeout=processor_timeout, dag_ids=[], @@ -1381,12 +1387,18 @@ def _run_scheduler_loop(self) -> None: # usage when "idle" time.sleep(self._processor_poll_interval) - if self.num_runs > 0 and loop_count >= self.num_runs and self.processor_agent.done: + if self.num_runs > 0 and loop_count >= self.num_runs: self.log.info( "Exiting scheduler loop as requested number of runs (%d - got to %d) has been reached", self.num_runs, loop_count, ) break + if self.processor_agent.done: + self.log.info( + "Exiting scheduler loop as requested DAG parse count (%d) has been reached", + self.num_times_parse_dags, loop_count, + ) + break def _do_scheduling(self, session) -> int: """ diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 2d89ad6e88eb6..11e892ca7a0d9 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -454,6 +454,8 @@ def update_state(self, session: Session = None, execute_callbacks: bool = True) self._emit_duration_stats_for_finished_state() + session.merge(self) + return ready_tis def _get_ready_tis( diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 7339bb5d72150..64e0938228a12 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -958,7 +958,7 @@ def run_single_scheduler_loop_with_no_dags(self, dags_folder): """ scheduler = SchedulerJob( executor=self.null_exec, - num_runs=1, + num_times_parse_dags=1, subdir=os.path.join(dags_folder)) scheduler.heartrate = 0 scheduler.run() @@ -2477,15 +2477,18 @@ def test_scheduler_task_start_date(self): Test that the scheduler respects task start dates that are different from DAG start dates """ + dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), include_examples=False) dag_id = 'test_task_start_date_scheduling' dag = self.dagbag.get_dag(dag_id) - dag.sync_to_db() - dag.clear() + dag.is_paused_upon_creation = False + dagbag.bag_dag(dag=dag, root_dag=dag) - # Deactivate other dags in this file + # Deactivate other dags in this file so the scheduler doesn't waste time processing them other_dag = self.dagbag.get_dag('test_start_date_scheduling') other_dag.is_paused_upon_creation = True - other_dag.sync_to_db() + dagbag.bag_dag(dag=other_dag, root_dag=other_dag) + + dagbag.sync_to_db() scheduler = SchedulerJob(executor=self.null_exec, subdir=dag.fileloc, @@ -2900,6 +2903,12 @@ def test_add_unparseable_file_before_sched_start_creates_import_error(self): @conf_vars({("core", "dagbag_import_error_tracebacks"): "False"}) def test_add_unparseable_file_after_sched_start_creates_import_error(self): + """ + Check that new DAG files are picked up, and import errors recorded. + + This is more of an "integration" test as it checks SchedulerJob, DagFileProcessorManager and + DagFileProcessor + """ dags_folder = mkdtemp() try: unparseable_filename = os.path.join(dags_folder, TEMP_DAG_FILENAME) diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 16230e6adf1e8..ad4a04eb27c80 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -1184,10 +1184,15 @@ def test_schedule_dag_once(self): task_id="faketastic", owner='Also fake', start_date=TEST_DATE)) + + # Sync once to create the DagModel + dag.sync_to_db() + dag.create_dagrun(run_type=DagRunType.SCHEDULED, execution_date=TEST_DATE, state=State.SUCCESS) + # Then sync again after creating the dag run -- this should update next_dagrun dag.sync_to_db() with create_session() as session: model = session.query(DagModel).get((dag.dag_id,)) diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py index 37eff531fdc23..e0f296e5a862b 100644 --- a/tests/utils/test_dag_processing.py +++ b/tests/utils/test_dag_processing.py @@ -31,6 +31,7 @@ from airflow.jobs.local_task_job import LocalTaskJob as LJ from airflow.jobs.scheduler_job import DagFileProcessorProcess from airflow.models import DagBag, DagModel, TaskInstance as TI +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import SimpleTaskInstance from airflow.utils import timezone from airflow.utils.callback_requests import TaskCallbackRequest @@ -431,6 +432,9 @@ class path, thus when reloading logging module the airflow.processor_manager @conf_vars({('core', 'load_examples'): 'False'}) def test_parse_once(self): + clear_db_serialized_dags() + clear_db_dags() + test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py') async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn') processor_agent = DagFileProcessorAgent(test_dag_path, @@ -447,10 +451,18 @@ def test_parse_once(self): while not processor_agent.done: if not async_mode: processor_agent.wait_until_finished() - parsing_result.extend(processor_agent.harvest_serialized_dags()) + processor_agent.heartbeat() + + assert processor_agent.all_files_processed + assert processor_agent.done + + + with create_session() as session: + dag_ids = session.query(DagModel.dag_id).order_by("dag_id").all() + assert dag_ids == [('test_start_date_scheduling',), ('test_task_start_date_scheduling',)] - dag_ids = [result.dag_id for result in parsing_result] - self.assertEqual(dag_ids.count('test_start_date_scheduling'), 1) + dag_ids = session.query(SerializedDagModel.dag_id).order_by("dag_id").all() + assert dag_ids == [('test_start_date_scheduling',), ('test_task_start_date_scheduling',)] def test_launch_process(self): test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py') From c73b078baaf06a289e647d943f14399ceb257239 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Thu, 24 Sep 2020 11:32:34 +0100 Subject: [PATCH 21/70] fixup! Move callbacks from Scheduler loop to DagProcessorProcess --- tests/utils/test_dag_processing.py | 40 ++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py index e0f296e5a862b..6565c48a999ae 100644 --- a/tests/utils/test_dag_processing.py +++ b/tests/utils/test_dag_processing.py @@ -54,14 +54,14 @@ class FakeDagFileProcessorRunner(DagFileProcessorProcess): # This fake processor will return the zombies it received in constructor # as its processing result w/o actually parsing anything. - def __init__(self, file_path, pickle_dags, dag_ids, zombies): - super().__init__(file_path, pickle_dags, dag_ids, zombies) + def __init__(self, file_path, pickle_dags, dag_ids, callbacks): + super().__init__(file_path, pickle_dags, dag_ids, callbacks) # We need a "real" selectable handle for waitable_handle to work readable, writable = multiprocessing.Pipe(duplex=False) writable.send('abc') writable.close() self._waitable_handle = readable - self._result = zombies, 0 + self._result = 0, 0 def start(self): pass @@ -83,12 +83,12 @@ def result(self): return self._result @staticmethod - def _fake_dag_processor_factory(file_path, zombies, dag_ids, pickle_dags): + def _fake_dag_processor_factory(file_path, callbacks, dag_ids, pickle_dags): return FakeDagFileProcessorRunner( file_path, pickle_dags, dag_ids, - zombies + callbacks, ) @property @@ -253,7 +253,7 @@ def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_p ti.job_id = local_job.id session.commit() - fake_failure_callback_requests = [ + expected_failure_callback_requests = [ TaskCallbackRequest( full_filepath=dag.full_filepath, simple_task_instance=SimpleTaskInstance(ti), @@ -266,23 +266,38 @@ def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_p child_pipe, parent_pipe = multiprocessing.Pipe() async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn') + fake_processors = [] + + def fake_processor_factory(*args, **kwargs): + nonlocal fake_processors + processor = FakeDagFileProcessorRunner._fake_dag_processor_factory(*args, **kwargs) + fake_processors.append(processor) + return processor + manager = DagFileProcessorManager( dag_directory=test_dag_path, max_runs=1, - processor_factory=FakeDagFileProcessorRunner._fake_dag_processor_factory, + processor_factory=fake_processor_factory, processor_timeout=timedelta.max, signal_conn=child_pipe, dag_ids=[], pickle_dags=False, async_mode=async_mode) - parsing_result = self.run_processor_manager_one_loop(manager, parent_pipe) + self.run_processor_manager_one_loop(manager, parent_pipe) + + # Once for initial parse, and then again for the add_callback_to_queue + assert len(fake_processors) == 2 - self.assertEqual(len(fake_failure_callback_requests), len(parsing_result)) - self.assertEqual( - set(zombie.simple_task_instance.key for zombie in fake_failure_callback_requests), - set(result.simple_task_instance.key for result in parsing_result) + assert fake_processors[0]._file_path == test_dag_path + assert fake_processors[0]._callback_requests == [] + assert fake_processors[1]._file_path == test_dag_path + callback_requests = fake_processors[1]._callback_requests + assert ( + set(zombie.simple_task_instance.key for zombie in expected_failure_callback_requests) == + set(result.simple_task_instance.key for result in callback_requests) ) + child_pipe.close() parent_pipe.close() @@ -445,7 +460,6 @@ def test_parse_once(self): False, async_mode) processor_agent.start() - parsing_result = [] if not async_mode: processor_agent.run_single_parsing_loop() while not processor_agent.done: From c806d1f36496cf775bf862bd178dc881dd49d49e Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Thu, 24 Sep 2020 11:32:44 +0100 Subject: [PATCH 22/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/jobs/scheduler_job.py | 3 ++- tests/models/test_dag.py | 4 ++-- tests/utils/test_dag_processing.py | 1 - 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 84ec56ca921f7..733cfc3d6c25d 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -1395,7 +1395,8 @@ def _run_scheduler_loop(self) -> None: break if self.processor_agent.done: self.log.info( - "Exiting scheduler loop as requested DAG parse count (%d) has been reached", + "Exiting scheduler loop as requested DAG parse count (%d) has been reached after %d " + " scheduler loops", self.num_times_parse_dags, loop_count, ) break diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index ad4a04eb27c80..9345a8abe890c 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -1665,8 +1665,8 @@ def test_dags_needing_dagruns_only_unpaused(self): session.add(orm_dag) session.flush() - models = DagModel.dags_needing_dagruns(session).all() - assert models == [orm_dag] + needed = DagModel.dags_needing_dagruns(session).all() + assert needed == [orm_dag] orm_dag.is_paused = True session.flush() diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py index 6565c48a999ae..6e1138ebb35db 100644 --- a/tests/utils/test_dag_processing.py +++ b/tests/utils/test_dag_processing.py @@ -470,7 +470,6 @@ def test_parse_once(self): assert processor_agent.all_files_processed assert processor_agent.done - with create_session() as session: dag_ids = session.query(DagModel.dag_id).order_by("dag_id").all() assert dag_ids == [('test_start_date_scheduling',), ('test_task_start_date_scheduling',)] From 46f848f190881411bb30083c3ab49b6eb6e33166 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 24 Sep 2020 14:34:04 +0100 Subject: [PATCH 23/70] =?UTF-8?q?Don=E2=80=99t=20run=20verify=5Fintegrity?= =?UTF-8?q?=20if=20the=20Serialized=20DAG=20hasn=E2=80=99t=20changed?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit dag_run.verify_integrity is slow, and we don't want to call it every time, just when the dag structure changes (which we can know now thanks to DAG Serialization) --- airflow/jobs/scheduler_job.py | 25 ++- ..._add_scheduling_decision_to_dagrun_and_.py | 2 + airflow/models/dag.py | 24 ++- airflow/models/dagbag.py | 3 + airflow/models/dagrun.py | 7 +- airflow/models/serialized_dag.py | 15 ++ tests/jobs/test_scheduler_job.py | 155 +++++++++++++++--- 7 files changed, 190 insertions(+), 41 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 733cfc3d6c25d..9cdabb16eef0c 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -47,6 +47,7 @@ from airflow.models import DAG, DagModel, SlaMiss, errors from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey from airflow.stats import Stats from airflow.ti_deps.dependencies_states import EXECUTION_STATES @@ -1502,13 +1503,16 @@ def _create_dag_run(self, dag_model: DagModel, dag: DAG, session: Session) -> No Unconditionally create a DAG run for the given DAG, and update the dag_model's fields to control if/when the next DAGRun should be created """ + dag_hash = self.dagbag.dags_hash.get(dag.dag_id, None) + dag.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=dag_model.next_dagrun, start_date=timezone.utcnow(), state=State.RUNNING, external_trigger=False, - session=session + session=session, + dag_hash=dag_hash ) self._update_dag_next_dagrun(dag_model, dag, session) @@ -1602,8 +1606,7 @@ def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: if currently_active_runs >= dag.max_active_runs: return 0 - # TODO[HA]: Run verify_integrity, but only if the serialized_dag has changed - + self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session) # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else? schedulable_tis = dag_run.update_state(session=session, execute_callbacks=False) # TODO[HA]: Don't return, update these from in update_state? @@ -1615,6 +1618,22 @@ def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: return count + @provide_session + def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None): + """Only run DagRun.verify integrity if Serialized DAG has changed since it is slow""" + latest_version = SerializedDagModel.get_latest_version_hash(dag_run.dag_id, session=session) + if dag_run.dag_hash == latest_version: + self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id) + return + + dag_run.dag_hash = latest_version + + # Refresh the DAG + dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id) + + # Verify integrity also takes care of session.flush + dag_run.verify_integrity(session=session) + def _send_dag_callbacks_to_processor(self, dag_run: DagRun): if not self.processor_agent: raise ValueError("Processor agent is not started.") diff --git a/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py b/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py index 30a4c3a2d91f2..28693c92fb1a2 100644 --- a/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py +++ b/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py @@ -39,6 +39,7 @@ def upgrade(): with op.batch_alter_table('dag_run', schema=None) as batch_op: batch_op.add_column(sa.Column('last_scheduling_decision', sa.DateTime(timezone=True), nullable=True)) batch_op.create_index('idx_last_scheduling_decision', ['last_scheduling_decision'], unique=False) + batch_op.add_column(sa.Column('dag_hash', sa.String(32), nullable=True)) with op.batch_alter_table('dag', schema=None) as batch_op: batch_op.add_column(sa.Column('next_dagrun', sa.DateTime(timezone=True), nullable=True)) @@ -71,6 +72,7 @@ def downgrade(): with op.batch_alter_table('dag_run', schema=None) as batch_op: batch_op.drop_index('idx_last_scheduling_decision') batch_op.drop_column('last_scheduling_decision') + batch_op.drop_column('dag_hash') with op.batch_alter_table('dag', schema=None) as batch_op: batch_op.drop_index('idx_next_dagrun_create_after') diff --git a/airflow/models/dag.py b/airflow/models/dag.py index e3152a7cc11d3..c0cc63b10978f 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1606,15 +1606,18 @@ def cli(self): args.func(args, self) @provide_session - def create_dagrun(self, - state, - execution_date=None, - run_id=None, - start_date=None, - external_trigger=False, - conf=None, - run_type=None, - session=None): + def create_dagrun( + self, + state, + execution_date=None, + run_id=None, + start_date=None, + external_trigger=False, + conf=None, + run_type=None, + session=None, + dag_hash=None + ): """ Creates a dag run from this dag including the tasks associated with this dag. Returns the dag run. @@ -1635,6 +1638,8 @@ def create_dagrun(self, :type conf: dict :param session: database session :type session: sqlalchemy.orm.session.Session + :param dag_hash: Hash of Serialized DAG + :type dag_hash: str """ if run_id and not run_type: if not isinstance(run_id, str): @@ -1658,6 +1663,7 @@ def create_dagrun(self, conf=conf, state=state, run_type=run_type.value, + dag_hash=dag_hash ) session.add(run) diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 3f548a84599e8..538b70ad95fb8 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -117,6 +117,8 @@ def __init__( self.read_dags_from_db = read_dags_from_db # Only used by read_dags_from_db=True self.dags_last_fetched: Dict[str, datetime] = {} + # Only used by SchedulerJob to compare the dag_hash to identify change in DAGs + self.dags_hash: Dict[str, str] = {} self.dagbag_import_error_tracebacks = conf.getboolean('core', 'dagbag_import_error_tracebacks') self.dagbag_import_error_traceback_depth = conf.getint('core', 'dagbag_import_error_traceback_depth') @@ -223,6 +225,7 @@ def _add_dag_from_db(self, dag_id: str, session: Session): self.dags[subdag.dag_id] = subdag self.dags[dag.dag_id] = dag self.dags_last_fetched[dag.dag_id] = timezone.utcnow() + self.dags_hash[dag.dag_id] = row.dag_hash def process_file(self, filepath, only_if_updated=True, safe_mode=True): """ diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 11e892ca7a0d9..3ccc60192d0ab 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -63,6 +63,7 @@ class DagRun(Base, LoggingMixin): conf = Column(PickleType) # When a scheduler last attempted to schedule TIs for this DagRun last_scheduling_decision = Column(UtcDateTime) + dag_hash = Column(String(32)) dag = None @@ -91,7 +92,8 @@ def __init__( external_trigger: Optional[bool] = None, conf: Optional[Any] = None, state: Optional[str] = None, - run_type: Optional[str] = None + run_type: Optional[str] = None, + dag_hash: Optional[str] = None, ): self.dag_id = dag_id self.run_id = run_id @@ -101,6 +103,7 @@ def __init__( self.conf = conf or {} self.state = state self.run_type = run_type + self.dag_hash = dag_hash self.callback: Optional[callback_requests.DagCallbackRequest] = None super().__init__() @@ -578,7 +581,7 @@ def verify_integrity(self, session: Session = None): self.log.info('Hit IntegrityError while creating the TIs for ' f'{dag.dag_id} - {self.execution_date}.') self.log.info('Doing session rollback.') - # TODO[HA]: We probaly need to savepoint this so we can keep the transaction alive. + # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive. session.rollback() @staticmethod diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py index 965d1f0cf8582..e2174bfd97017 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -264,3 +264,18 @@ def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> date :type session: Session """ return session.query(cls.last_updated).filter(cls.dag_id == dag_id).scalar() + + @classmethod + @provide_session + def get_latest_version_hash(cls, dag_id: str, session: Session = None) -> str: + """ + Get the latest DAG version for a given DAG ID. + + :param dag_id: DAG ID + :type dag_id: str + :param session: ORM Session + :type session: Session + :return: DAG Hash + :rtype: str + """ + return session.query(cls.dag_hash).filter(cls.dag_id == dag_id).scalar() diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 64e0938228a12..a30e320c7d394 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -31,6 +31,7 @@ import six from mock import MagicMock, patch from parameterized import parameterized +from sqlalchemy import func import airflow.example_dags import airflow.smart_sensor_dags @@ -42,6 +43,7 @@ from airflow.jobs.scheduler_job import DagFileProcessor, SchedulerJob from airflow.models import DAG, DagBag, DagModel, Pool, SlaMiss, TaskInstance, errors from airflow.models.dagrun import DagRun +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey from airflow.operators.bash import BashOperator from airflow.operators.dummy_operator import DummyOperator @@ -57,8 +59,8 @@ from tests.test_utils.asserts import assert_queries_count from tests.test_utils.config import conf_vars, env_vars from tests.test_utils.db import ( - clear_db_dags, clear_db_errors, clear_db_jobs, clear_db_pools, clear_db_runs, clear_db_sla_miss, - set_default_pool_slots, + clear_db_dags, clear_db_errors, clear_db_jobs, clear_db_pools, clear_db_runs, clear_db_serialized_dags, + clear_db_sla_miss, set_default_pool_slots, ) from tests.test_utils.mock_executor import MockExecutor @@ -105,6 +107,7 @@ def clean_db(): clear_db_sla_miss() clear_db_errors() clear_db_jobs() + clear_db_serialized_dags() def setUp(self): self.clean_db() @@ -525,44 +528,42 @@ def test_dag_file_processor_process_task_instances_depends_on_past(self, state, assert tis[0].state == State.SCHEDULED assert tis[1].state == State.SCHEDULED - @pytest.mark.xfail(run=False, reason="TODO[HA]") - def test_dag_file_processor_add_new_task(self): + def test_scheduler_job_add_new_task(self): """ Test if a task instance will be added if the dag is updated """ - dag = DAG( - dag_id='test_scheduler_add_new_task', - start_date=DEFAULT_DATE) + dag = DAG(dag_id='test_scheduler_add_new_task', start_date=DEFAULT_DATE) + DummyOperator(task_id='dummy', dag=dag, owner='airflow') - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') + scheduler = SchedulerJob() + scheduler.dagbag.bag_dag(dag, root_dag=dag) + scheduler.dagbag.sync_to_db() session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - session.close() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + orm_dag = session.query(DagModel).get(dag.dag_id) + assert orm_dag is not None - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() + scheduler = SchedulerJob() + dag = scheduler.dagbag.get_dag('test_scheduler_add_new_task', session=session) + scheduler._create_dag_run(orm_dag, dag, session) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] tis = dr.get_task_instances() self.assertEqual(len(tis), 1) - DummyOperator( - task_id='dummy2', - dag=dag, - owner='airflow') - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + DummyOperator(task_id='dummy2', dag=dag, owner='airflow') + SerializedDagModel.write_dag(dag=dag) - dag_file_processor._process_task_instances(dag, dag_runs=[dr]) + scheduled_tis = scheduler._schedule_dag_run(dr, session) + session.flush() + assert scheduled_tis == 2 + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] tis = dr.get_task_instances() self.assertEqual(len(tis), 2) @@ -914,6 +915,7 @@ def setUp(self): clear_db_dags() clear_db_sla_miss() clear_db_errors() + clear_db_serialized_dags() # Speed up some tests by not running the tasks, just look at what we # enqueue! @@ -2722,6 +2724,105 @@ def test_scheduler_verify_priority_and_slots(self): .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t2').first() self.assertEqual(ti2.state, State.QUEUED) + def test_verify_integrity_if_dag_not_changed(self): + dag = DAG(dag_id='test_verify_integrity_if_dag_not_changed', start_date=DEFAULT_DATE) + DummyOperator(task_id='dummy', dag=dag, owner='airflow') + + scheduler = SchedulerJob() + scheduler.dagbag.bag_dag(dag, root_dag=dag) + scheduler.dagbag.sync_to_db() + + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + assert orm_dag is not None + + scheduler = SchedulerJob() + dag = scheduler.dagbag.get_dag('test_verify_integrity_if_dag_not_changed', session=session) + scheduler._create_dag_run(orm_dag, dag, session) + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + # Verify that DagRun.verify_integrity is not called + with mock.patch('airflow.jobs.scheduler_job.DagRun.verify_integrity') as mock_verify_integrity: + scheduled_tis = scheduler._schedule_dag_run(dr, session) + mock_verify_integrity.assert_not_called() + session.flush() + + assert scheduled_tis == 1 + + tis_count = session.query(func.count(TaskInstance.task_id)).filter( + TaskInstance.dag_id == dr.dag_id, + TaskInstance.execution_date == dr.execution_date, + TaskInstance.task_id == dr.dag.tasks[0].task_id, + TaskInstance.state == State.SCHEDULED + ).scalar() + assert tis_count == 1 + + latest_dag_version = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session) + assert dr.dag_hash == latest_dag_version + + session.rollback() + session.close() + + def test_verify_integrity_if_dag_changed(self): + dag = DAG(dag_id='test_verify_integrity_if_dag_changed', start_date=DEFAULT_DATE) + DummyOperator(task_id='dummy', dag=dag, owner='airflow') + + scheduler = SchedulerJob() + scheduler.dagbag.bag_dag(dag, root_dag=dag) + scheduler.dagbag.sync_to_db() + + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + assert orm_dag is not None + + scheduler = SchedulerJob() + dag = scheduler.dagbag.get_dag('test_verify_integrity_if_dag_changed', session=session) + scheduler._create_dag_run(orm_dag, dag, session) + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + dag_version_1 = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session) + assert dr.dag_hash == dag_version_1 + assert scheduler.dagbag.dags == {'test_verify_integrity_if_dag_changed': dag} + assert len(scheduler.dagbag.dags.get("test_verify_integrity_if_dag_changed").tasks) == 1 + + # Now let's say the DAG got updated (new task got added) + BashOperator(task_id='bash_task_1', dag=dag, bash_command='echo hi') + SerializedDagModel.write_dag(dag=dag) + + dag_version_2 = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session) + assert dag_version_2 != dag_version_1 + + scheduled_tis = scheduler._schedule_dag_run(dr, session) + session.flush() + + assert scheduled_tis == 2 + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + assert dr.dag_hash == dag_version_2 + assert scheduler.dagbag.dags == {'test_verify_integrity_if_dag_changed': dag} + assert len(scheduler.dagbag.dags.get("test_verify_integrity_if_dag_changed").tasks) == 2 + + tis_count = session.query(func.count(TaskInstance.task_id)).filter( + TaskInstance.dag_id == dr.dag_id, + TaskInstance.execution_date == dr.execution_date, + TaskInstance.state == State.SCHEDULED + ).scalar() + assert tis_count == 2 + + latest_dag_version = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session) + assert dr.dag_hash == latest_dag_version + + session.rollback() + session.close() + def test_retry_still_in_executor(self): """ Checks if the scheduler does not put a task in limbo, when a task is retried From 5a0a73f1704c6e76b0d18edd4f56ac2d70870190 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 24 Sep 2020 18:25:56 +0100 Subject: [PATCH 24/70] Do not Execute tasks with DummyOperators --- airflow/jobs/scheduler_job.py | 27 +++++++++++++++++++++++++- tests/jobs/test_scheduler_job.py | 33 +++++++++++++++++--------------- 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 9cdabb16eef0c..5abf6c74cfec3 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -1610,12 +1610,37 @@ def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else? schedulable_tis = dag_run.update_state(session=session, execute_callbacks=False) # TODO[HA]: Don't return, update these from in update_state? + + # Get list of TIs that do not need to executed, these are + # tasks using DummyOperator and without on_execute_callback / on_success_callback + dummy_tis = [ + ti for ti in schedulable_tis + if + ( + ti.task.task_type == "DummyOperator" + and not ti.task.on_execute_callback + and not ti.task.on_success_callback + ) + ] + count = session.query(TI).filter( TI.dag_id == dag_run.dag_id, TI.execution_date == dag_run.execution_date, - TI.task_id.in_(ti.task_id for ti in schedulable_tis) + TI.task_id.in_(ti.task_id for ti in schedulable_tis if ti not in dummy_tis) ).update({TI.state: State.SCHEDULED}, synchronize_session=False) + # Tasks using DummyOperator should not be executed, mark them as success + session.query(TI).filter( + TI.dag_id == dag_run.dag_id, + TI.execution_date == dag_run.execution_date, + TI.task_id.in_(ti.task_id for ti in dummy_tis) + ).update({ + TI.state: State.SUCCESS, + TI.start_date: timezone.utcnow(), + TI.end_date: timezone.utcnow(), + TI.duration: 0 + }, synchronize_session=False) + return count @provide_session diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index a30e320c7d394..05aafe3e65d78 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -715,29 +715,34 @@ def test_process_file_should_failure_callback(self): self.assertEqual("Callback fired", content) os.remove(callback_file.name) - @pytest.mark.skip def test_should_mark_dummy_task_as_success(self): dag_file = os.path.join( os.path.dirname(os.path.realpath(__file__)), '../dags/test_only_dummy_tasks.py' ) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - with create_session() as session: - session.query(TaskInstance).delete() - session.query(DagModel).delete() - dagbag = DagBag(dag_folder=dag_file, include_examples=False) - dagbag.sync_to_db() + # Write DAGs to dag and serialized_dag table + with mock.patch("airflow.models.dagbag.settings.STORE_SERIALIZED_DAGS", return_value=True): + dagbag = DagBag(dag_folder=dag_file, include_examples=False) + dagbag.sync_to_db() - serialized_dags, import_errors_count = dag_file_processor.process_file( - file_path=dag_file, callback_requests=[] - ) + scheduler_job = SchedulerJob() + dag = scheduler_job.dagbag.get_dag("test_only_dummy_tasks") - dags = [SerializedDAG.from_dict(serialized_dag) for serialized_dag in serialized_dags] + # Create DagRun + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + scheduler_job._create_dag_run(orm_dag, dag, session) + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + # Schedule TaskInstances + scheduler_job._schedule_dag_run(dr, session) with create_session() as session: tis = session.query(TaskInstance).all() - self.assertEqual(0, import_errors_count) + dags = scheduler_job.dagbag.dags.values() self.assertEqual(['test_only_dummy_tasks'], [dag.dag_id for dag in dags]) self.assertEqual(5, len(tis)) self.assertEqual({ @@ -758,9 +763,7 @@ def test_should_mark_dummy_task_as_success(self): self.assertIsNone(end_date) self.assertIsNone(duration) - dag_file_processor.process_file( - file_path=dag_file, callback_requests=[] - ) + scheduler_job._schedule_dag_run(dr, session) with create_session() as session: tis = session.query(TaskInstance).all() From 6e940886a79dff1e09a54777393836c7e207619a Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 25 Sep 2020 13:00:09 +0100 Subject: [PATCH 25/70] =?UTF-8?q?fixup!=20Don=E2=80=99t=20run=20verify=5Fi?= =?UTF-8?q?ntegrity=20if=20the=20Serialized=20DAG=20hasn=E2=80=99t=20chang?= =?UTF-8?q?ed?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/jobs/test_scheduler_job.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 05aafe3e65d78..ce0b53c234e66 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -918,7 +918,6 @@ def setUp(self): clear_db_dags() clear_db_sla_miss() clear_db_errors() - clear_db_serialized_dags() # Speed up some tests by not running the tasks, just look at what we # enqueue! From 48a9d5dceb82be6069633a0420f8d4ed14d8af02 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 25 Sep 2020 13:50:17 +0100 Subject: [PATCH 26/70] fixup! Do not Execute tasks with DummyOperators --- tests/jobs/test_scheduler_job.py | 34 ++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index ce0b53c234e66..f6516c342b37d 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -2560,11 +2560,13 @@ def test_scheduler_verify_pool_full(self): dag_id='test_scheduler_verify_pool_full', start_date=DEFAULT_DATE) - DummyOperator( + BashOperator( task_id='dummy', dag=dag, owner='airflow', - pool='test_scheduler_verify_pool_full') + pool='test_scheduler_verify_pool_full', + bash_command='echo hi', + ) dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), include_examples=False, @@ -2609,12 +2611,13 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self): dag_id='test_scheduler_verify_pool_full_2_slots_per_task', start_date=DEFAULT_DATE) - DummyOperator( + BashOperator( task_id='dummy', dag=dag, owner='airflow', pool='test_scheduler_verify_pool_full_2_slots_per_task', pool_slots=2, + bash_command='echo hi', ) dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), @@ -2660,31 +2663,34 @@ def test_scheduler_verify_priority_and_slots(self): start_date=DEFAULT_DATE) # Medium priority, not enough slots - DummyOperator( + BashOperator( task_id='test_scheduler_verify_priority_and_slots_t0', dag=dag, owner='airflow', pool='test_scheduler_verify_priority_and_slots', pool_slots=2, priority_weight=2, + bash_command='echo hi', ) # High priority, occupies first slot - DummyOperator( + BashOperator( task_id='test_scheduler_verify_priority_and_slots_t1', dag=dag, owner='airflow', pool='test_scheduler_verify_priority_and_slots', pool_slots=1, priority_weight=3, + bash_command='echo hi', ) # Low priority, occupies second slot - DummyOperator( + BashOperator( task_id='test_scheduler_verify_priority_and_slots_t2', dag=dag, owner='airflow', pool='test_scheduler_verify_priority_and_slots', pool_slots=1, priority_weight=1, + bash_command='echo hi', ) dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), @@ -2727,8 +2733,14 @@ def test_scheduler_verify_priority_and_slots(self): self.assertEqual(ti2.state, State.QUEUED) def test_verify_integrity_if_dag_not_changed(self): + # CleanUp + with create_session() as session: + session.query(SerializedDagModel).filter( + SerializedDagModel.dag_id == 'test_verify_integrity_if_dag_not_changed' + ).delete(synchronize_session=False) + dag = DAG(dag_id='test_verify_integrity_if_dag_not_changed', start_date=DEFAULT_DATE) - DummyOperator(task_id='dummy', dag=dag, owner='airflow') + BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo hi') scheduler = SchedulerJob() scheduler.dagbag.bag_dag(dag, root_dag=dag) @@ -2769,8 +2781,14 @@ def test_verify_integrity_if_dag_not_changed(self): session.close() def test_verify_integrity_if_dag_changed(self): + # CleanUp + with create_session() as session: + session.query(SerializedDagModel).filter( + SerializedDagModel.dag_id == 'test_verify_integrity_if_dag_changed' + ).delete(synchronize_session=False) + dag = DAG(dag_id='test_verify_integrity_if_dag_changed', start_date=DEFAULT_DATE) - DummyOperator(task_id='dummy', dag=dag, owner='airflow') + BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo hi') scheduler = SchedulerJob() scheduler.dagbag.bag_dag(dag, root_dag=dag) From c79c88e5dae6d36aacb93e7df051ed9eecba9aeb Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 25 Sep 2020 18:22:01 +0100 Subject: [PATCH 27/70] fixup! Officially support running more than one scheduler concurrently. --- tests/jobs/test_scheduler_job.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index f6516c342b37d..29f9d6e41d74a 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -2278,6 +2278,8 @@ def evaluate_dagrun( self.null_exec.mock_task_fail(dag_id, tid, ex_date) try: + dag = DagBag().get_dag(dag.dag_id) + assert not isinstance(dag, SerializedDAG) # This needs a _REAL_ dag, not the serialized version dag.run(start_date=ex_date, end_date=ex_date, executor=self.null_exec, **run_kwargs) except AirflowException: From d61496af7c37725827698b4cd915ca9ac8e8f962 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 25 Sep 2020 20:55:11 +0100 Subject: [PATCH 28/70] fixup! Do not Execute tasks with DummyOperators --- tests/jobs/test_scheduler_job.py | 34 +++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 29f9d6e41d74a..6eeb059f70fd1 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -385,10 +385,12 @@ def test_dag_file_processor_process_task_instances(self, state, start_date, end_ dag = DAG( dag_id='test_scheduler_process_execute_task', start_date=DEFAULT_DATE) - DummyOperator( + BashOperator( task_id='dummy', dag=dag, - owner='airflow') + owner='airflow', + bash_command='echo hi' + ) with create_session() as session: orm_dag = DagModel(dag_id=dag.dag_id) @@ -435,11 +437,13 @@ def test_dag_file_processor_process_task_instances_with_task_concurrency( dag = DAG( dag_id='test_scheduler_process_execute_task_with_task_concurrency', start_date=DEFAULT_DATE) - DummyOperator( + BashOperator( task_id='dummy', task_concurrency=2, dag=dag, - owner='airflow') + owner='airflow', + bash_command='echo Hi' + ) with create_session() as session: orm_dag = DagModel(dag_id=dag.dag_id) @@ -488,14 +492,18 @@ def test_dag_file_processor_process_task_instances_depends_on_past(self, state, 'depends_on_past': True, }, ) - DummyOperator( + BashOperator( task_id='dummy1', dag=dag, - owner='airflow') - DummyOperator( + owner='airflow', + bash_command='echo hi' + ) + BashOperator( task_id='dummy2', dag=dag, - owner='airflow') + owner='airflow', + bash_command='echo hi' + ) with create_session() as session: orm_dag = DagModel(dag_id=dag.dag_id) @@ -533,7 +541,7 @@ def test_scheduler_job_add_new_task(self): Test if a task instance will be added if the dag is updated """ dag = DAG(dag_id='test_scheduler_add_new_task', start_date=DEFAULT_DATE) - DummyOperator(task_id='dummy', dag=dag, owner='airflow') + BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo test') scheduler = SchedulerJob() scheduler.dagbag.bag_dag(dag, root_dag=dag) @@ -554,7 +562,7 @@ def test_scheduler_job_add_new_task(self): tis = dr.get_task_instances() self.assertEqual(len(tis), 1) - DummyOperator(task_id='dummy2', dag=dag, owner='airflow') + BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test') SerializedDagModel.write_dag(dag=dag) scheduled_tis = scheduler._schedule_dag_run(dr, session) @@ -578,10 +586,12 @@ def test_runs_respected_after_clear(self): start_date=DEFAULT_DATE) dag.max_active_runs = 3 - DummyOperator( + BashOperator( task_id='dummy', dag=dag, - owner='airflow') + owner='airflow', + bash_command='echo Hi' + ) session = settings.Session() orm_dag = DagModel(dag_id=dag.dag_id) From 291be0f39519bf4ac524d05214c799d531e53bdd Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 25 Sep 2020 21:27:15 +0100 Subject: [PATCH 29/70] fixup! Officially support running more than one scheduler concurrently. dag.create_dagrun was missing a session.flush due to which dagrun.verify_integrity failed with following ``` > if task.start_date > self.execution_date and not self.is_backfill: E TypeError: '>' not supported between instances of 'datetime.datetime' and 'NoneType' ``` --- airflow/models/dag.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index c0cc63b10978f..f55e92a87e790 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1666,6 +1666,7 @@ def create_dagrun( dag_hash=dag_hash ) session.add(run) + session.flush() run.dag = self From daad3f839e63b541c9fe50781bd3d75a4c1a4717 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 28 Sep 2020 11:43:27 +0100 Subject: [PATCH 30/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/models/dag.py | 2 +- tests/models/test_dag.py | 38 ++++++++++++++++++++ tests/test_utils/perf/perf_kit/python.py | 2 +- tests/test_utils/perf/perf_kit/sqlalchemy.py | 2 +- 4 files changed, 41 insertions(+), 3 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index f55e92a87e790..7dfdf0176789a 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -505,7 +505,7 @@ def next_dagrun_after_date(self, date_last_automated_dagrun: Optional[pendulum.D backfill triggered run for this dag :type date_last_automated_dagrun: pendulum.Pendulum """ - if not self.schedule_interval: + if not self.schedule_interval or self.is_subdag: return None # don't schedule @once again diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 9345a8abe890c..1f221b23c0017 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -1612,6 +1612,44 @@ def test_next_dagrun_after_auto_align(self): next_date = dag.next_dagrun_after_date(None) assert next_date == timezone.datetime(2016, 1, 1, 10, 10) + def test_next_dagrun_after_not_for_subdags(self): + """ + Test the subdags are never marked to have dagruns created, as they are + handled by the SubDagOperator, not the scheduler + """ + + def subdag(parent_dag_name, child_dag_name, args): + """ + Create a subdag. + """ + dag_subdag = DAG(dag_id='%s.%s' % (parent_dag_name, child_dag_name), schedule_interval="@daily", default_args=args) + + for i in range(2): + DummyOperator(task_id='%s-task-%s' % (child_dag_name, i + 1), dag=dag_subdag) + + return dag_subdag + + with DAG( + dag_id='test_subdag_operator', + start_date=datetime.datetime(2019, 1, 1), + max_active_runs=1, + schedule_interval=timedelta(minutes=1), + ) as dag: + section_1 = SubDagOperator( + task_id='section-1', + subdag=subdag(dag.dag_id, 'section-1', {'start_date': dag.start_date}), + ) + + subdag = section_1.subdag + # parent_dag and is_subdag was set by DagBag. We don't use DagBag, so this value is not set. + subdag.parent_dag = dag + subdag.is_subdag = True + + next_date = dag.next_dagrun_after_date(None) + assert next_date == timezone.datetime(2019, 1, 1, 0, 0) + + next_subdag_date = subdag.next_dagrun_after_date(None) + assert next_subdag_date is None, "SubDags should never have DagRuns created by the scheduler" class TestDagModel: diff --git a/tests/test_utils/perf/perf_kit/python.py b/tests/test_utils/perf/perf_kit/python.py index 3169e9c43ea93..7d92a497fe9d7 100644 --- a/tests/test_utils/perf/perf_kit/python.py +++ b/tests/test_utils/perf/perf_kit/python.py @@ -96,7 +96,7 @@ def case(): log = logging.getLogger(__name__) processor = DagFileProcessor(dag_ids=[], log=log) dag_file = os.path.join(os.path.dirname(airflow.__file__), "example_dags", "example_complex.py") - processor.process_file(file_path=dag_file, failure_callback_requests=[]) + processor.process_file(file_path=dag_file, callback_requests=[]) # Load modules case() diff --git a/tests/test_utils/perf/perf_kit/sqlalchemy.py b/tests/test_utils/perf/perf_kit/sqlalchemy.py index 06a305c709646..e5c7c359cec90 100644 --- a/tests/test_utils/perf/perf_kit/sqlalchemy.py +++ b/tests/test_utils/perf/perf_kit/sqlalchemy.py @@ -222,7 +222,7 @@ def case(): log = logging.getLogger(__name__) processor = DagFileProcessor(dag_ids=[], log=log) dag_file = os.path.join(os.path.dirname(__file__), os.path.pardir, "dags", "elastic_dag.py") - processor.process_file(file_path=dag_file, failure_callback_requests=[]) + processor.process_file(file_path=dag_file, callback_requests=[]) with trace_queries(), count_queries(): case() From 742d63337f4ec5b334b4082fb36684dad9e00e27 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 28 Sep 2020 13:48:03 +0100 Subject: [PATCH 31/70] fixup! Do not Execute tasks with DummyOperators --- airflow/jobs/scheduler_job.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 5abf6c74cfec3..2802414ccabb9 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -1630,16 +1630,17 @@ def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: ).update({TI.state: State.SCHEDULED}, synchronize_session=False) # Tasks using DummyOperator should not be executed, mark them as success - session.query(TI).filter( - TI.dag_id == dag_run.dag_id, - TI.execution_date == dag_run.execution_date, - TI.task_id.in_(ti.task_id for ti in dummy_tis) - ).update({ - TI.state: State.SUCCESS, - TI.start_date: timezone.utcnow(), - TI.end_date: timezone.utcnow(), - TI.duration: 0 - }, synchronize_session=False) + if dummy_tis: + session.query(TI).filter( + TI.dag_id == dag_run.dag_id, + TI.execution_date == dag_run.execution_date, + TI.task_id.in_(ti.task_id for ti in dummy_tis) + ).update({ + TI.state: State.SUCCESS, + TI.start_date: timezone.utcnow(), + TI.end_date: timezone.utcnow(), + TI.duration: 0 + }, synchronize_session=False) return count From 167e146edb09b951f67139db8f513cb009c42e00 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 28 Sep 2020 13:48:40 +0100 Subject: [PATCH 32/70] =?UTF-8?q?fixup!=20Don=E2=80=99t=20run=20verify=5Fi?= =?UTF-8?q?ntegrity=20if=20the=20Serialized=20DAG=20hasn=E2=80=99t=20chang?= =?UTF-8?q?ed?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- airflow/jobs/scheduler_job.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 2802414ccabb9..6d037ee0efa04 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -1655,7 +1655,7 @@ def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None): dag_run.dag_hash = latest_version # Refresh the DAG - dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id) + dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id, session=session) # Verify integrity also takes care of session.flush dag_run.verify_integrity(session=session) From 764aa20cecab2aeee1a974b454d6270c6f3c3308 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 28 Sep 2020 17:23:13 +0100 Subject: [PATCH 33/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/jobs/scheduler_job.py | 131 ++++++++++------- tests/jobs/test_scheduler_job.py | 241 +++++++++++++++---------------- 2 files changed, 197 insertions(+), 175 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 6d037ee0efa04..242be9094b0da 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -31,7 +31,7 @@ from contextlib import ExitStack, redirect_stderr, redirect_stdout, suppress from datetime import timedelta from multiprocessing.connection import Connection as MultiprocessingConnection -from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple +from typing import Any, DefaultDict, Dict, Iterable, List, Optional, Set, Tuple from setproctitle import setproctitle from sqlalchemy import and_, func, not_, or_ @@ -1435,17 +1435,38 @@ def validate_commit(_): raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!") query = DagModel.dags_needing_dagruns(session) - for dag_model in query: - dag = self.dagbag.get_dag(dag_model.dag_id, session=session) - self._create_dag_run(dag_model, dag, session) + self._create_dag_runs(query.all(), session) # commit the session - Release the write lock on DagModel table. expected_commit = True session.commit() # END: create dagruns - for dag_run in DagRun.next_dagruns_to_examine(session): - self._schedule_dag_run(dag_run, session) + dag_runs = DagRun.next_dagruns_to_examine(session) + + # Bulk fetch the currently active dag runs for the dags we are + # examining, rather than making one query per DagRun + + # TODO: This query is probably horribly inefficient (though there is an + # index on (dag_id,state)). It is to deal with the case when a user + # clears more than max_active_runs older tasks -- we don't want the + # scheduler to suddenly go and start running tasks from all of the + # runs. (AIRFLOW-137/GH #1442) + # + # The longer term fix would be to have `clear` do this, and put DagRuns + # in to the queued state, then take DRs out of queued before creating + # any new ones + # TODO[HA]: Why is this on TI, not on DagRun?? + currently_active_runs = dict(session.query( + TI.dag_id, + func.count(TI.execution_date.distinct()), + ).filter( + TI.dag_id.in_(list({dag_run.dag_id for dag_run in dag_runs})), + TI.state.notin_(State.finished()) + ).group_by(TI.dag_id).all()) + + for dag_run in dag_runs: + self._schedule_dag_run(dag_run, currently_active_runs.get(dag_run.dag_id, 0), session) expected_commit = True session.commit() @@ -1498,54 +1519,65 @@ def validate_commit(_): finally: event.remove(session.bind, 'commit', validate_commit) - def _create_dag_run(self, dag_model: DagModel, dag: DAG, session: Session) -> None: + def _create_dag_runs(self, dag_models: Iterable[DagModel], session: Session) -> None: """ Unconditionally create a DAG run for the given DAG, and update the dag_model's fields to control if/when the next DAGRun should be created """ - dag_hash = self.dagbag.dags_hash.get(dag.dag_id, None) - - dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=dag_model.next_dagrun, - start_date=timezone.utcnow(), - state=State.RUNNING, - external_trigger=False, - session=session, - dag_hash=dag_hash - ) + for dag_model in dag_models: + dag = self.dagbag.get_dag(dag_model.dag_id, session=session) + dag_hash = self.dagbag.dags_hash.get(dag.dag_id, None) + dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag_model.next_dagrun, + start_date=timezone.utcnow(), + state=State.RUNNING, + external_trigger=False, + session=session, + dag_hash=dag_hash + ) - self._update_dag_next_dagrun(dag_model, dag, session) + self._update_dag_next_dagruns(dag_models, session) # TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in # memory for larger dags? or expunge_all() - def _update_dag_next_dagrun(self, dag_model: DagModel, dag: DAG, session: Session) -> None: + def _update_dag_next_dagruns(self, dag_models: Iterable[DagModel], session: Session) -> None: + """ + Bulk update the next_dagrun and next_dagrun_create_after for all the dags. - # Check max_active_runs, to see if we are _now_ at the limit for this dag? (we've just created - # one after all) - active_runs_of_dag = session.query(func.count('*')).filter( - DagRun.dag_id == dag_model.dag_id, + We batch the select queries to get info about all the dags at once + """ + # Check max_active_runs, to see if we are _now_ at the limit for any of + # these dag? (we've just created a DagRun for them after all) + active_runs_of_dags = dict(session.query(DagRun.dag_id, func.count('*')).filter( + DagRun.dag_id.in_([o.dag_id for o in dag_models]), DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable DagRun.external_trigger.is_(False), - ).scalar() - - # TODO[HA]: add back in dagrun.timeout + ).group_by(DagRun.dag_id).all()) - if dag.max_active_runs and active_runs_of_dag >= dag.max_active_runs: - self.log.info( - "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs", - dag.dag_id, active_runs_of_dag, dag.max_active_runs - ) - dag_model.next_dagrun_create_after = None - else: - dag_model.next_dagrun, dag_model.next_dagrun_create_after = \ - dag.next_dagrun_info(dag_model.next_dagrun) + for dag_model in dag_models: + dag = self.dagbag.get_dag(dag_model.dag_id, session=session) + active_runs_of_dag = active_runs_of_dags.get(dag.dag_id, 0) + if dag.max_active_runs and active_runs_of_dag >= dag.max_active_runs: + self.log.info( + "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs", + dag.dag_id, active_runs_of_dag, dag.max_active_runs + ) + dag_model.next_dagrun_create_after = None + else: + dag_model.next_dagrun, dag_model.next_dagrun_create_after = \ + dag.next_dagrun_info(dag_model.next_dagrun) - def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: + def _schedule_dag_run(self, dag_run: DagRun, currently_active_runs: int, session: Session) -> int: """ Make scheduling decisions about an individual dag run + ``currently_active_runs`` is passed in so that a batch query can be + used to ask this for all dag runs in the batch, to avoid an n+1 query. + + :param dag_run: The DagRun to schedule + :param currently_active_runs: Number of currently active runs of this DAG :return: Number of tasks scheduled """ dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) @@ -1566,7 +1598,7 @@ def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: session.flush() # Work out if we should allow creating a new DagRun now? - self._update_dag_next_dagrun(session.query(DagModel).get(dag_run.dag_id), dag, session) + self._update_dag_next_dagruns([session.query(DagModel).get(dag_run.dag_id)], session) dag_run.callback = DagCallbackRequest( full_filepath=dag.fileloc, @@ -1588,22 +1620,13 @@ def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: ) return 0 - # TODO: This query is probably horribly inefficient (though there is an - # index on (dag_id,state)). It is to deal with the case when a user - # clears more than max_active_runs older tasks -- we don't want the - # scheduler to suddenly go and start running tasks from all of the - # runs. (AIRFLOW-137/GH #1442) - # - # The longer term fix would be to have `clear` do this, and put DagRuns - # in to the queued state, then take DRs out of queued before creating - # any new ones if dag.max_active_runs: - currently_active_runs = session.query(func.count(TI.execution_date.distinct())).filter( - TI.dag_id == dag_run.dag_id, - TI.state.notin_(State.finished()) - ).scalar() - if currently_active_runs >= dag.max_active_runs: + self.log.info( + "DAG %s already has %d active runs, not queuing any more tasks", + dag.dag_id, + currently_active_runs, + ) return 0 self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session) @@ -1623,6 +1646,10 @@ def _schedule_dag_run(self, dag_run: DagRun, session: Session) -> int: ) ] + # This will do one query per dag run. We "could" build up a complex + # query to update all the TIs across all the execution dates and dag + # IDs in a single query, but it turns out that can be _very very slow_ + # see #11147/commit ee90807ac for more details count = session.query(TI).filter( TI.dag_id == dag_run.dag_id, TI.execution_date == dag_run.execution_date, diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 6eeb059f70fd1..e377d7a869de7 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -414,7 +414,7 @@ def test_dag_file_processor_process_task_instances(self, state, start_date, end_ ti.start_date = start_date ti.end_date = end_date - count = scheduler._schedule_dag_run(dr, session) + count = scheduler._schedule_dag_run(dr, 0, session) assert count == 1 session.refresh(ti) @@ -467,7 +467,7 @@ def test_dag_file_processor_process_task_instances_with_task_concurrency( ti.start_date = start_date ti.end_date = end_date - count = scheduler._schedule_dag_run(dr, session) + count = scheduler._schedule_dag_run(dr, 0, session) assert count == 1 session.refresh(ti) @@ -528,7 +528,7 @@ def test_dag_file_processor_process_task_instances_depends_on_past(self, state, ti.start_date = start_date ti.end_date = end_date - count = scheduler._schedule_dag_run(dr, session) + count = scheduler._schedule_dag_run(dr, 0, session) assert count == 2 session.refresh(tis[0]) @@ -553,7 +553,7 @@ def test_scheduler_job_add_new_task(self): scheduler = SchedulerJob() dag = scheduler.dagbag.get_dag('test_scheduler_add_new_task', session=session) - scheduler._create_dag_run(orm_dag, dag, session) + scheduler._create_dag_runs([orm_dag], session) drs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(drs) == 1 @@ -565,7 +565,7 @@ def test_scheduler_job_add_new_task(self): BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test') SerializedDagModel.write_dag(dag=dag) - scheduled_tis = scheduler._schedule_dag_run(dr, session) + scheduled_tis = scheduler._schedule_dag_run(dr, 0, session) session.flush() assert scheduled_tis == 2 @@ -635,36 +635,13 @@ def test_runs_respected_after_clear(self): # and schedule them in, so we can check how many # tasks are put on the task_instances_list (should be one, not 3) with create_session() as session: - num_scheduled = scheduler._schedule_dag_run(dr1, session) + num_scheduled = scheduler._schedule_dag_run(dr1, 0, session) assert num_scheduled == 1 - num_scheduled = scheduler._schedule_dag_run(dr2, session) + num_scheduled = scheduler._schedule_dag_run(dr2, 1, session) assert num_scheduled == 0 - num_scheduled = scheduler._schedule_dag_run(dr3, session) + num_scheduled = scheduler._schedule_dag_run(dr3, 1, session) assert num_scheduled == 0 - @pytest.mark.xfail(run=False, reason="TODO[HA]") - def test_process_dags_not_create_dagrun_for_subdags(self): - dag = self.dagbag.get_dag('test_subdag_operator') - - scheduler = DagFileProcessor(dag_ids=[dag.dag_id], log=mock.MagicMock()) - scheduler._process_task_instances = mock.MagicMock() - scheduler.manage_slas = mock.MagicMock() - - scheduler._process_dags([dag] + dag.subdags) - - with create_session() as session: - sub_dagruns = ( - session.query(DagRun).filter(DagRun.dag_id == dag.subdags[0].dag_id).count() - ) - - self.assertEqual(0, sub_dagruns) - - parent_dagruns = ( - session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).count() - ) - - self.assertGreater(parent_dagruns, 0) - @patch.object(TaskInstance, 'handle_failure') def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): dagbag = DagBag(dag_folder="/dev/null", include_examples=True) @@ -741,14 +718,14 @@ def test_should_mark_dummy_task_as_success(self): # Create DagRun session = settings.Session() orm_dag = session.query(DagModel).get(dag.dag_id) - scheduler_job._create_dag_run(orm_dag, dag, session) + scheduler_job._create_dag_runs([orm_dag], session) drs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(drs) == 1 dr = drs[0] # Schedule TaskInstances - scheduler_job._schedule_dag_run(dr, session) + scheduler_job._schedule_dag_run(dr, 0, session) with create_session() as session: tis = session.query(TaskInstance).all() @@ -773,7 +750,7 @@ def test_should_mark_dummy_task_as_success(self): self.assertIsNone(end_date) self.assertIsNone(duration) - scheduler_job._schedule_dag_run(dr, session) + scheduler_job._schedule_dag_run(dr, 0, session) with create_session() as session: tis = session.query(TaskInstance).all() @@ -816,63 +793,6 @@ def setUp(self) -> None: clear_db_sla_miss() clear_db_errors() - @parameterized.expand( - [ - # pylint: disable=bad-whitespace - # expected, dag_count, task_count, start_ago, schedule_interval, shape - # One DAG with one task per DAG file - ([ 1, 1, 1, 1], 1, 1, "1d", "None", "no_structure"), # noqa - ([ 1, 1, 1, 1], 1, 1, "1d", "None", "linear"), # noqa - ([ 9, 5, 5, 5], 1, 1, "1d", "@once", "no_structure"), # noqa - ([ 9, 5, 5, 5], 1, 1, "1d", "@once", "linear"), # noqa - ([ 9, 12, 15, 18], 1, 1, "1d", "30m", "no_structure"), # noqa - ([ 9, 12, 15, 18], 1, 1, "1d", "30m", "linear"), # noqa - ([ 9, 12, 15, 18], 1, 1, "1d", "30m", "binary_tree"), # noqa - ([ 9, 12, 15, 18], 1, 1, "1d", "30m", "star"), # noqa - ([ 9, 12, 15, 18], 1, 1, "1d", "30m", "grid"), # noqa - # One DAG with five tasks per DAG file - ([ 1, 1, 1, 1], 1, 5, "1d", "None", "no_structure"), # noqa - ([ 1, 1, 1, 1], 1, 5, "1d", "None", "linear"), # noqa - ([ 9, 5, 5, 5], 1, 5, "1d", "@once", "no_structure"), # noqa - ([10, 6, 6, 6], 1, 5, "1d", "@once", "linear"), # noqa - ([ 9, 12, 15, 18], 1, 5, "1d", "30m", "no_structure"), # noqa - ([10, 14, 18, 22], 1, 5, "1d", "30m", "linear"), # noqa - ([10, 14, 18, 22], 1, 5, "1d", "30m", "binary_tree"), # noqa - ([10, 14, 18, 22], 1, 5, "1d", "30m", "star"), # noqa - ([10, 14, 18, 22], 1, 5, "1d", "30m", "grid"), # noqa - # 10 DAGs with 10 tasks per DAG file - ([ 1, 1, 1, 1], 10, 10, "1d", "None", "no_structure"), # noqa - ([ 1, 1, 1, 1], 10, 10, "1d", "None", "linear"), # noqa - ([81, 41, 41, 41], 10, 10, "1d", "@once", "no_structure"), # noqa - ([91, 51, 51, 51], 10, 10, "1d", "@once", "linear"), # noqa - ([81, 111, 111, 111], 10, 10, "1d", "30m", "no_structure"), # noqa - ([91, 131, 131, 131], 10, 10, "1d", "30m", "linear"), # noqa - ([91, 131, 131, 131], 10, 10, "1d", "30m", "binary_tree"), # noqa - ([91, 131, 131, 131], 10, 10, "1d", "30m", "star"), # noqa - ([91, 131, 131, 131], 10, 10, "1d", "30m", "grid"), # noqa - # pylint: enable=bad-whitespace - ] - ) - def test_process_dags_queries_count( - self, expected_query_counts, dag_count, task_count, start_ago, schedule_interval, shape - ): - with mock.patch.dict("os.environ", { - "PERF_DAGS_COUNT": str(dag_count), - "PERF_TASKS_COUNT": str(task_count), - "PERF_START_AGO": start_ago, - "PERF_SCHEDULE_INTERVAL": schedule_interval, - "PERF_SHAPE": shape, - }), conf_vars({ - ('scheduler', 'use_job_schedule'): 'True', - }): - dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, - include_examples=False, - include_smart_sensor=False) - processor = DagFileProcessor([], mock.MagicMock()) - for expected_query_count in expected_query_counts: - with assert_queries_count(expected_query_count): - processor._process_dags(dagbag.dags.values()) - @parameterized.expand( [ # pylint: disable=bad-whitespace @@ -2109,7 +2029,7 @@ def test_dagrun_timeout_verify_max_active_runs(self): dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) scheduler = SchedulerJob() - scheduler._create_dag_run(orm_dag, dag, session) + scheduler._create_dag_runs([orm_dag], session) drs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(drs) == 1 @@ -2128,7 +2048,7 @@ def test_dagrun_timeout_verify_max_active_runs(self): scheduler.processor_agent = mock.Mock() scheduler.processor_agent.send_callback_to_execute = mock.Mock() - scheduler._schedule_dag_run(dr, session) + scheduler._schedule_dag_run(dr, 0, session) session.flush() session.refresh(dr) @@ -2177,7 +2097,7 @@ def test_dagrun_timeout_fails_run(self): dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) scheduler = SchedulerJob() - scheduler._create_dag_run(orm_dag, dag, session) + scheduler._create_dag_runs([orm_dag], session) drs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(drs) == 1 @@ -2191,7 +2111,7 @@ def test_dagrun_timeout_fails_run(self): scheduler.processor_agent = mock.Mock() scheduler.processor_agent.send_callback_to_execute = mock.Mock() - scheduler._schedule_dag_run(dr, session) + scheduler._schedule_dag_run(dr, 0, session) session.flush() session.refresh(dr) @@ -2601,13 +2521,13 @@ def test_scheduler_verify_pool_full(self): execution_date=DEFAULT_DATE, state=State.RUNNING, ) - scheduler._schedule_dag_run(dr, session) + scheduler._schedule_dag_run(dr, 0, session) dr = dag.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=dag.following_schedule(dr.execution_date), state=State.RUNNING, ) - scheduler._schedule_dag_run(dr, session) + scheduler._schedule_dag_run(dr, 0, session) task_instances_list = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) @@ -2655,7 +2575,7 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self): execution_date=date, state=State.RUNNING, ) - scheduler._schedule_dag_run(dr, session) + scheduler._schedule_dag_run(dr, 0, session) date = dag.following_schedule(date) task_instances_list = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) @@ -2725,7 +2645,7 @@ def test_scheduler_verify_priority_and_slots(self): execution_date=DEFAULT_DATE, state=State.RUNNING, ) - scheduler._schedule_dag_run(dr, session) + scheduler._schedule_dag_run(dr, 0, session) task_instances_list = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) @@ -2764,7 +2684,7 @@ def test_verify_integrity_if_dag_not_changed(self): scheduler = SchedulerJob() dag = scheduler.dagbag.get_dag('test_verify_integrity_if_dag_not_changed', session=session) - scheduler._create_dag_run(orm_dag, dag, session) + scheduler._create_dag_runs([orm_dag], session) drs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(drs) == 1 @@ -2772,7 +2692,7 @@ def test_verify_integrity_if_dag_not_changed(self): # Verify that DagRun.verify_integrity is not called with mock.patch('airflow.jobs.scheduler_job.DagRun.verify_integrity') as mock_verify_integrity: - scheduled_tis = scheduler._schedule_dag_run(dr, session) + scheduled_tis = scheduler._schedule_dag_run(dr, 0, session) mock_verify_integrity.assert_not_called() session.flush() @@ -2812,7 +2732,7 @@ def test_verify_integrity_if_dag_changed(self): scheduler = SchedulerJob() dag = scheduler.dagbag.get_dag('test_verify_integrity_if_dag_changed', session=session) - scheduler._create_dag_run(orm_dag, dag, session) + scheduler._create_dag_runs([orm_dag], session) drs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(drs) == 1 @@ -2830,7 +2750,7 @@ def test_verify_integrity_if_dag_changed(self): dag_version_2 = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session) assert dag_version_2 != dag_version_1 - scheduled_tis = scheduler._schedule_dag_run(dr, session) + scheduled_tis = scheduler._schedule_dag_run(dr, 0, session) session.flush() assert scheduled_tis == 2 @@ -3518,7 +3438,7 @@ def test_task_with_upstream_skip_process_task_instances(): dummy3 = DummyOperator(task_id="dummy3") [dummy1, dummy2] >> dummy3 - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + # dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) dag.clear() dr = dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, @@ -3532,8 +3452,8 @@ def test_task_with_upstream_skip_process_task_instances(): tis[dummy2.task_id].state = State.SUCCESS assert tis[dummy3.task_id].state == State.NONE - dag_runs = DagRun.find(dag_id='test_task_with_upstream_skip_dag') - dag_file_processor._process_task_instances(dag, dag_runs=dag_runs) + # dag_runs = DagRun.find(dag_id='test_task_with_upstream_skip_dag') + # dag_file_processor._process_task_instances(dag, dag_runs=dag_runs) with create_session() as session: tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)} @@ -3543,7 +3463,6 @@ def test_task_with_upstream_skip_process_task_instances(): assert tis[dummy3.task_id].state == State.SKIPPED -@pytest.mark.xfail(reason="Work why this didn't infinite loop before!") class TestSchedulerJobQueriesCount(unittest.TestCase): """ These tests are designed to detect changes in the number of queries for @@ -3563,11 +3482,11 @@ def setUp(self) -> None: # pylint: disable=bad-whitespace # expected, dag_count, task_count # One DAG with one task per DAG file - (13, 1, 1), # noqa + (21, 1, 1), # noqa # One DAG with five tasks per DAG file - (17, 1, 5), # noqa + (21, 1, 5), # noqa # 10 DAGs with 10 tasks per DAG file - (46, 10, 10), # noqa + (93, 10, 10), # noqa ] ) def test_execute_queries_count_with_harvested_dags(self, expected_query_count, dag_count, task_count): @@ -3580,37 +3499,48 @@ def test_execute_queries_count_with_harvested_dags(self, expected_query_count, d }), conf_vars({ ('scheduler', 'use_job_schedule'): 'True', ('core', 'load_examples'): 'False', - }): - + ('core', 'store_serialized_dags'): 'True', + }), mock.patch.object(settings, 'STORE_SERIALIZED_DAGS', True): + dagruns = [] dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False) - for i, dag in enumerate(dagbag.dags.values()): - dr = dag.create_dagrun(state=State.RUNNING, run_id=f"{DagRunType.MANUAL.value}__{i}", - execution_date=DEFAULT_DATE) + dagbag.sync_to_db() + + dag_ids = dagbag.dag_ids + dagbag = DagBag(read_dags_from_db=True) + for i, dag_id in enumerate(dag_ids): + dag = dagbag.get_dag(dag_id) + dr = dag.create_dagrun( + state=State.RUNNING, + run_id=f"{DagRunType.MANUAL.value}__{i}", + dag_hash=dagbag.dags_hash[dag.dag_id], + ) + dagruns.append(dr) for ti in dr.get_task_instances(): ti.set_state(state=State.SCHEDULED) mock_agent = mock.MagicMock() - mock_agent.harvest_serialized_dags.return_value = [ - SerializedDAG.from_dict(SerializedDAG.to_dict(d)) for d in dagbag.dags.values()] - job = SchedulerJob(subdir=PERF_DAGS_FOLDER) - job.executor = MockExecutor() + job = SchedulerJob(subdir=PERF_DAGS_FOLDER, num_runs=1) + job.executor = MockExecutor(do_update=False) job.heartbeat = mock.MagicMock() job.processor_agent = mock_agent with assert_queries_count(expected_query_count): - job._run_scheduler_loop() + with mock.patch.object(DagRun, 'next_dagruns_to_examine') as mock_dagruns: + mock_dagruns.return_value = dagruns + + job._run_scheduler_loop() @parameterized.expand( [ # pylint: disable=bad-whitespace # expected, dag_count, task_count # One DAG with one task per DAG file - (2, 1, 1), # noqa + (8, 1, 1), # noqa # One DAG with five tasks per DAG file - (2, 1, 5), # noqa + (8, 1, 5), # noqa # 10 DAGs with 10 tasks per DAG file - (2, 10, 10), # noqa + (8, 10, 10), # noqa ] ) def test_execute_queries_count_no_harvested_dags(self, expected_query_count, dag_count, task_count): @@ -3642,3 +3572,68 @@ def test_execute_queries_count_no_harvested_dags(self, expected_query_count, dag with assert_queries_count(expected_query_count): with create_session() as session: job._do_scheduling(session) + + @parameterized.expand( + [ + # pylint: disable=bad-whitespace + # expected, dag_count, task_count, start_ago, schedule_interval, shape + # One DAG with one task per DAG file + ([ 8, 8, 8, 8], 1, 1, "1d", "None", "no_structure"), # noqa + ([ 8, 8, 8, 8], 1, 1, "1d", "None", "linear"), # noqa + ([20, 12, 12, 12], 1, 1, "1d", "@once", "no_structure"), # noqa + ([20, 12, 12, 12], 1, 1, "1d", "@once", "linear"), # noqa + ([20, 22, 25, 28], 1, 1, "1d", "30m", "no_structure"), # noqa + ([20, 22, 25, 28], 1, 1, "1d", "30m", "linear"), # noqa + ([20, 22, 25, 28], 1, 1, "1d", "30m", "binary_tree"), # noqa + ([20, 22, 25, 28], 1, 1, "1d", "30m", "star"), # noqa + ([20, 22, 25, 28], 1, 1, "1d", "30m", "grid"), # noqa + # One DAG with five tasks per DAG file + ([ 8, 8, 8, 8], 1, 5, "1d", "None", "no_structure"), # noqa + ([ 8, 8, 8, 8], 1, 5, "1d", "None", "linear"), # noqa + ([20, 12, 12, 12], 1, 5, "1d", "@once", "no_structure"), # noqa + ([21, 13, 13, 13], 1, 5, "1d", "@once", "linear"), # noqa + ([20, 22, 25, 28], 1, 5, "1d", "30m", "no_structure"), # noqa + ([21, 24, 28, 32], 1, 5, "1d", "30m", "linear"), # noqa + ([21, 24, 28, 32], 1, 5, "1d", "30m", "binary_tree"), # noqa + ([21, 24, 28, 32], 1, 5, "1d", "30m", "star"), # noqa + ([21, 24, 28, 32], 1, 5, "1d", "30m", "grid"), # noqa + # 10 DAGs with 10 tasks per DAG file + ([ 8, 8, 8, 8], 10, 10, "1d", "None", "no_structure"), # noqa + ([ 8, 8, 8, 8], 10, 10, "1d", "None", "linear"), # noqa + ([83, 36, 36, 36], 10, 10, "1d", "@once", "no_structure"), # noqa + ([93, 49, 49, 49], 10, 10, "1d", "@once", "linear"), # noqa + ([83, 97, 97, 97], 10, 10, "1d", "30m", "no_structure"), # noqa + ([93, 123, 120, 120], 10, 10, "1d", "30m", "linear"), # noqa + ([93, 117, 117, 117], 10, 10, "1d", "30m", "binary_tree"), # noqa + ([93, 117, 117, 117], 10, 10, "1d", "30m", "star"), # noqa + ([93, 117, 117, 117], 10, 10, "1d", "30m", "grid"), # noqa + # pylint: enable=bad-whitespace + ] + ) + def test_process_dags_queries_count( + self, expected_query_counts, dag_count, task_count, start_ago, schedule_interval, shape + ): + with mock.patch.dict("os.environ", { + "PERF_DAGS_COUNT": str(dag_count), + "PERF_TASKS_COUNT": str(task_count), + "PERF_START_AGO": start_ago, + "PERF_SCHEDULE_INTERVAL": schedule_interval, + "PERF_SHAPE": shape, + }), conf_vars({ + ('scheduler', 'use_job_schedule'): 'True', + ('core', 'store_serialized_dags'): 'True', + }), mock.patch.object(settings, 'STORE_SERIALIZED_DAGS', True): + + dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False) + dagbag.sync_to_db() + + mock_agent = mock.MagicMock() + + job = SchedulerJob(subdir=PERF_DAGS_FOLDER, num_runs=1) + job.executor = MockExecutor(do_update=False) + job.heartbeat = mock.MagicMock() + job.processor_agent = mock_agent + for expected_query_count in expected_query_counts: + with create_session() as session: + with assert_queries_count(expected_query_count): + job._do_scheduling(session) From 59dab940bb0bb2490dac14cacfe4b2834f78e1b5 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 28 Sep 2020 17:44:51 +0100 Subject: [PATCH 34/70] fixup! Officially support running more than one scheduler concurrently. --- tests/models/test_dag.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 1f221b23c0017..cc2b26c7ddaba 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -1622,7 +1622,9 @@ def subdag(parent_dag_name, child_dag_name, args): """ Create a subdag. """ - dag_subdag = DAG(dag_id='%s.%s' % (parent_dag_name, child_dag_name), schedule_interval="@daily", default_args=args) + dag_subdag = DAG(dag_id='%s.%s' % (parent_dag_name, child_dag_name), + schedule_interval="@daily", + default_args=args) for i in range(2): DummyOperator(task_id='%s-task-%s' % (child_dag_name, i + 1), dag=dag_subdag) @@ -1651,6 +1653,7 @@ def subdag(parent_dag_name, child_dag_name, args): next_subdag_date = subdag.next_dagrun_after_date(None) assert next_subdag_date is None, "SubDags should never have DagRuns created by the scheduler" + class TestDagModel: def test_dags_needing_dagruns_not_too_early(self): @@ -1709,8 +1712,8 @@ def test_dags_needing_dagruns_only_unpaused(self): orm_dag.is_paused = True session.flush() - models = DagModel.dags_needing_dagruns(session).all() - assert models == [] + dag_models = DagModel.dags_needing_dagruns(session).all() + assert dag_models == [] session.rollback() session.close() From d3364d1be38570b16c0849c6db44c276af402db3 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 28 Sep 2020 18:20:48 +0100 Subject: [PATCH 35/70] fixup! Officially support running more than one scheduler concurrently. --- tests/jobs/test_backfill_job.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index 184eec97a8539..3682e34548937 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -25,7 +25,7 @@ import pytest import sqlalchemy -from mock import Mock, patch +from mock import patch from parameterized import parameterized from airflow import settings @@ -35,7 +35,6 @@ TaskConcurrencyLimitReached, ) from airflow.jobs.backfill_job import BackfillJob -from airflow.jobs.scheduler_job import DagFileProcessor from airflow.models import DAG, DagBag, Pool, TaskInstance as TI from airflow.models.dagrun import DagRun from airflow.operators.dummy_operator import DummyOperator @@ -145,11 +144,12 @@ def test_trigger_controller_dag(self): target_dag = self.dagbag.get_dag('example_trigger_target_dag') target_dag.sync_to_db() - dag_file_processor = DagFileProcessor(dag_ids=[], log=Mock()) - task_instances_list = dag_file_processor._process_task_instances( - target_dag, - dag_runs=DagRun.find(dag_id='example_trigger_target_dag') - ) + # dag_file_processor = DagFileProcessor(dag_ids=[], log=Mock()) + task_instances_list = [] + # task_instances_list = dag_file_processor._process_task_instances( + # target_dag, + # dag_runs=DagRun.find(dag_id='example_trigger_target_dag') + # ) self.assertFalse(task_instances_list) job = BackfillJob( @@ -160,10 +160,11 @@ def test_trigger_controller_dag(self): ) job.run() - task_instances_list = dag_file_processor._process_task_instances( - target_dag, - dag_runs=DagRun.find(dag_id='example_trigger_target_dag') - ) + task_instances_list = [] + # task_instances_list = dag_file_processor._process_task_instances( + # target_dag, + # dag_runs=DagRun.find(dag_id='example_trigger_target_dag') + # ) self.assertTrue(task_instances_list) From 71ad5afa87dd07d59261afcc3d384cc96712afbc Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 28 Sep 2020 18:32:35 +0100 Subject: [PATCH 36/70] fixup! Officially support running more than one scheduler concurrently. --- tests/jobs/test_scheduler_job.py | 65 -------------------------------- 1 file changed, 65 deletions(-) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index e377d7a869de7..7257c90e52b15 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -774,71 +774,6 @@ def test_should_mark_dummy_task_as_success(self): self.assertIsNone(duration) -@pytest.mark.heisentests -class TestDagFileProcessorQueriesCount(unittest.TestCase): - """ - These tests are designed to detect changes in the number of queries for different DAG files. - - Each test has saved queries count in the table/spreadsheets. If you make a change that affected the number - of queries, please update the tables. - - These tests allow easy detection when a change is made that affects the performance of the - DagFileProcessor. - """ - - def setUp(self) -> None: - clear_db_runs() - clear_db_pools() - clear_db_dags() - clear_db_sla_miss() - clear_db_errors() - - @parameterized.expand( - [ - # pylint: disable=bad-whitespace - # expected, dag_count, task_count, start_ago, schedule_interval, shape - # One DAG with two tasks per DAG file - ([ 5, 5, 5, 5], 1, 1, "1d", "None", "no_structure"), # noqa - ([ 5, 5, 5, 5], 1, 1, "1d", "None", "linear"), # noqa - ([15, 9, 9, 9], 1, 1, "1d", "@once", "no_structure"), # noqa - ([15, 9, 9, 9], 1, 1, "1d", "@once", "linear"), # noqa - ([15, 18, 21, 24], 1, 1, "1d", "30m", "no_structure"), # noqa - ([15, 18, 21, 24], 1, 1, "1d", "30m", "linear"), # noqa - # One DAG with five tasks per DAG file - ([ 5, 5, 5, 5], 1, 5, "1d", "None", "no_structure"), # noqa - ([ 5, 5, 5, 5], 1, 5, "1d", "None", "linear"), # noqa - ([15, 9, 9, 9], 1, 5, "1d", "@once", "no_structure"), # noqa - ([16, 10, 10, 10], 1, 5, "1d", "@once", "linear"), # noqa - ([15, 18, 21, 24], 1, 5, "1d", "30m", "no_structure"), # noqa - ([16, 20, 24, 28], 1, 5, "1d", "30m", "linear"), # noqa - # 10 DAGs with 10 tasks per DAG file - ([ 5, 5, 5, 5], 10, 10, "1d", "None", "no_structure"), # noqa - ([ 5, 5, 5, 5], 10, 10, "1d", "None", "linear"), # noqa - ([87, 45, 45, 45], 10, 10, "1d", "@once", "no_structure"), # noqa - ([97, 55, 55, 55], 10, 10, "1d", "@once", "linear"), # noqa - ([87, 117, 117, 117], 10, 10, "1d", "30m", "no_structure"), # noqa - ([97, 137, 137, 137], 10, 10, "1d", "30m", "linear"), # noqa - # pylint: enable=bad-whitespace - ] - ) - def test_process_file_queries_count( - self, expected_query_counts, dag_count, task_count, start_ago, schedule_interval, shape - ): - with mock.patch.dict("os.environ", { - "PERF_DAGS_COUNT": str(dag_count), - "PERF_TASKS_COUNT": str(task_count), - "PERF_START_AGO": start_ago, - "PERF_SCHEDULE_INTERVAL": schedule_interval, - "PERF_SHAPE": shape, - }), conf_vars({ - ('scheduler', 'use_job_schedule'): 'True' - }): - processor = DagFileProcessor([], mock.MagicMock()) - for expected_query_count in expected_query_counts: - with assert_queries_count(expected_query_count): - processor.process_file(ELASTIC_DAG_FILE, []) - - @pytest.mark.usefixtures("disable_load_example") class TestSchedulerJob(unittest.TestCase): From b1c114cb8ab9c5f9c2f8eab2653e9d50b4e76791 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 28 Sep 2020 19:14:02 +0100 Subject: [PATCH 37/70] fixup! Officially support running more than one scheduler concurrently. --- .../98271e7606e2_add_scheduling_decision_to_dagrun_and_.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py b/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py index 28693c92fb1a2..e320169453d4a 100644 --- a/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py +++ b/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py @@ -63,8 +63,9 @@ def upgrade(): concurrency ) ) - op.alter_column('dag', 'concurrency', type_=sa.Integer(), nullable=False) - op.alter_column('dag', 'has_task_concurrency_limits', type_=sa.Boolean(), nullable=False) + with op.batch_alter_table('dag', schema=None) as batch_op: + batch_op.alter_column('concurrency', type_=sa.Integer(), nullable=False) + batch_op.alter_column('has_task_concurrency_limits', type_=sa.Boolean(), nullable=False) def downgrade(): From 5625737bb38bfa66c252bebd67ca5994211f9043 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Mon, 28 Sep 2020 23:57:28 +0100 Subject: [PATCH 38/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/jobs/scheduler_job.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 242be9094b0da..b9687dc28d8a9 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -1512,6 +1512,7 @@ def validate_commit(_): if db_err_code in ('55P03', 1205, 3572): self.log.debug("Critical section lock held by another Scheduler") Stats.incr('scheduler.critical_section_busy') + session.rollback() return 0 raise From 949665108f9327c024ac8834fe46aff02eaa5012 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Tue, 29 Sep 2020 08:06:56 +0100 Subject: [PATCH 39/70] fixup! Officially support running more than one scheduler concurrently. --- airflow/models/dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 7dfdf0176789a..9cb370ced9d14 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1977,7 +1977,7 @@ class DagModel(Base): def __init__(self, **kwargs): super().__init__(**kwargs) if self.concurrency is None: - self.concurrency = conf.getint('core', 'dag_concurrency'), + self.concurrency = conf.getint('core', 'dag_concurrency') if self.has_task_concurrency_limits is None: # Be safe -- this will be updated later once the DAG is parsed self.has_task_concurrency_limits = True From aa4b03ac9d168760f0cb4cad77c418fd047fb135 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Tue, 29 Sep 2020 12:45:06 +0100 Subject: [PATCH 40/70] fixup! Officially support running more than one scheduler concurrently. --- tests/www/test_views.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/www/test_views.py b/tests/www/test_views.py index 98f3bca1633c9..621de17fde2d6 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -52,7 +52,6 @@ from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.bash import BashOperator from airflow.operators.dummy_operator import DummyOperator -from airflow.settings import Session from airflow.ti_deps.dependencies_states import QUEUEABLE_STATES, RUNNABLE_STATES from airflow.utils import dates, timezone from airflow.utils.log.logging_mixin import ExternalLoggingMixin @@ -639,10 +638,10 @@ def test_view_uses_existing_dagbag(self, endpoint): self.check_content_in_response('example_bash_operator', resp) @parameterized.expand([ - ("hello\nworld", r'\"conf\":{\"abc\":\"hello\\nworld\"}}'), - ("hello'world", r'\"conf\":{\"abc\":\"hello\\u0027world\"}}'), - ("