diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index 903c27633931b..934d97c19b3cf 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -19,11 +19,12 @@ from builtins import range +# To avoid circular imports +import airflow.utils.dag_processing from airflow import configuration from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State - PARALLELISM = configuration.conf.getint('core', 'PARALLELISM') @@ -50,11 +51,11 @@ def start(self): # pragma: no cover """ pass - def queue_command(self, task_instance, command, priority=1, queue=None): - key = task_instance.key + def queue_command(self, simple_task_instance, command, priority=1, queue=None): + key = simple_task_instance.key if key not in self.queued_tasks and key not in self.running: self.log.info("Adding to queue: %s", command) - self.queued_tasks[key] = (command, priority, queue, task_instance) + self.queued_tasks[key] = (command, priority, queue, simple_task_instance) else: self.log.info("could not queue task {}".format(key)) @@ -86,7 +87,7 @@ def queue_task_instance( pickle_id=pickle_id, cfg_path=cfg_path) self.queue_command( - task_instance, + airflow.utils.dag_processing.SimpleTaskInstance(task_instance), command, priority=task_instance.task.priority_weight_total, queue=task_instance.task.queue) @@ -124,26 +125,13 @@ def heartbeat(self): key=lambda x: x[1][1], reverse=True) for i in range(min((open_slots, len(self.queued_tasks)))): - key, (command, _, queue, ti) = sorted_queue.pop(0) - # TODO(jlowin) without a way to know what Job ran which tasks, - # there is a danger that another Job started running a task - # that was also queued to this executor. This is the last chance - # to check if that happened. The most probable way is that a - # Scheduler tried to run a task that was originally queued by a - # Backfill. This fix reduces the probability of a collision but - # does NOT eliminate it. + key, (command, _, queue, simple_ti) = sorted_queue.pop(0) self.queued_tasks.pop(key) - ti.refresh_from_db() - if ti.state != State.RUNNING: - self.running[key] = command - self.execute_async(key=key, - command=command, - queue=queue, - executor_config=ti.executor_config) - else: - self.log.info( - 'Task is already running, not sending to ' - 'executor: {}'.format(key)) + self.running[key] = command + self.execute_async(key=key, + command=command, + queue=queue, + executor_config=simple_ti.executor_config) # Calling child class sync method self.log.debug("Calling the %s sync method", self.__class__) @@ -151,7 +139,7 @@ def heartbeat(self): def change_state(self, key, state): self.log.debug("Changing state: {}".format(key)) - self.running.pop(key) + self.running.pop(key, None) self.event_buffer[key] = state def fail(self, key): diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index 0de48b4d39ff1..0e71778ecc995 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -33,10 +33,13 @@ from airflow.executors.base_executor import BaseExecutor from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string +from airflow.utils.timeout import timeout # Make it constant for unit test. 
CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state' +CELERY_SEND_ERR_MSG_HEADER = 'Error sending Celery task' + ''' To start the celery worker, run the command: airflow worker @@ -55,12 +58,12 @@ @app.task -def execute_command(command): +def execute_command(command_to_exec): log = LoggingMixin().log - log.info("Executing command in Celery: %s", command) + log.info("Executing command in Celery: %s", command_to_exec) env = os.environ.copy() try: - subprocess.check_call(command, stderr=subprocess.STDOUT, + subprocess.check_call(command_to_exec, stderr=subprocess.STDOUT, close_fds=True, env=env) except subprocess.CalledProcessError as e: log.exception('execute_command encountered a CalledProcessError') @@ -95,9 +98,10 @@ def fetch_celery_task_state(celery_task): """ try: - # Accessing state property of celery task will make actual network request - # to get the current state of the task. - res = (celery_task[0], celery_task[1].state) + with timeout(seconds=2): + # Accessing state property of celery task will make actual network request + # to get the current state of the task. + res = (celery_task[0], celery_task[1].state) except Exception as e: exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0], traceback.format_exc()) @@ -105,6 +109,19 @@ def fetch_celery_task_state(celery_task): return res +def send_task_to_executor(task_tuple): + key, simple_ti, command, queue, task = task_tuple + try: + with timeout(seconds=2): + result = task.apply_async(args=[command], queue=queue) + except Exception as e: + exception_traceback = "Celery Task ID: {}\n{}".format(key, + traceback.format_exc()) + result = ExceptionWithTraceback(e, exception_traceback) + + return key, command, result + + class CeleryExecutor(BaseExecutor): """ CeleryExecutor is recommended for production use of Airflow. It allows @@ -135,16 +152,16 @@ def start(self): 'Starting Celery Executor using {} processes for syncing'.format( self._sync_parallelism)) - def execute_async(self, key, command, - queue=DEFAULT_CELERY_CONFIG['task_default_queue'], - executor_config=None): - self.log.info("[celery] queuing {key} through celery, " - "queue={queue}".format(**locals())) - self.tasks[key] = execute_command.apply_async( - args=[command], queue=queue) - self.last_state[key] = celery_states.PENDING + def _num_tasks_per_send_process(self, to_send_count): + """ + How many Celery tasks should each worker process send. + :return: Number of tasks that should be sent per process + :rtype: int + """ + return max(1, + int(math.ceil(1.0 * to_send_count / self._sync_parallelism))) - def _num_tasks_per_process(self): + def _num_tasks_per_fetch_process(self): """ How many Celery tasks should be sent to each worker process. 
         :return: Number of tasks that should be used per process
@@ -153,6 +170,71 @@ def _num_tasks_per_process(self):
         return max(1,
                    int(math.ceil(1.0 * len(self.tasks) / self._sync_parallelism)))
 
+    def heartbeat(self):
+        # Triggering new jobs
+        if not self.parallelism:
+            open_slots = len(self.queued_tasks)
+        else:
+            open_slots = self.parallelism - len(self.running)
+
+        self.log.debug("{} running task instances".format(len(self.running)))
+        self.log.debug("{} in queue".format(len(self.queued_tasks)))
+        self.log.debug("{} open slots".format(open_slots))
+
+        sorted_queue = sorted(
+            [(k, v) for k, v in self.queued_tasks.items()],
+            key=lambda x: x[1][1],
+            reverse=True)
+
+        task_tuples_to_send = []
+
+        for i in range(min((open_slots, len(self.queued_tasks)))):
+            key, (command, _, queue, simple_ti) = sorted_queue.pop(0)
+            task_tuples_to_send.append((key, simple_ti, command, queue,
+                                        execute_command))
+
+        cached_celery_backend = None
+        if task_tuples_to_send:
+            tasks = [t[4] for t in task_tuples_to_send]
+
+            # Celery state queries will get stuck if we do not use one shared
+            # backend for all tasks.
+            cached_celery_backend = tasks[0].backend
+
+        if task_tuples_to_send:
+            # Use chunking instead of a work queue to reduce context switching
+            # since tasks are roughly uniform in size
+            chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
+            num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
+
+            send_pool = Pool(processes=num_processes)
+            key_and_async_results = send_pool.map(
+                send_task_to_executor,
+                task_tuples_to_send,
+                chunksize=chunksize)
+
+            send_pool.close()
+            send_pool.join()
+            self.log.debug('Sent all tasks.')
+
+            for key, command, result in key_and_async_results:
+                if isinstance(result, ExceptionWithTraceback):
+                    self.log.error(
+                        CELERY_SEND_ERR_MSG_HEADER + ":{}\n{}\n".format(
+                            result.exception, result.traceback))
+                elif result is not None:
+                    # Only pop when enqueued successfully, otherwise keep it
+                    # and expect the scheduler loop to deal with it.
+ self.queued_tasks.pop(key) + result.backend = cached_celery_backend + self.running[key] = command + self.tasks[key] = result + self.last_state[key] = celery_states.PENDING + + # Calling child class sync method + self.log.debug("Calling the {} sync method".format(self.__class__)) + self.sync() + def sync(self): num_processes = min(len(self.tasks), self._sync_parallelism) if num_processes == 0: @@ -167,7 +249,7 @@ def sync(self): # Use chunking instead of a work queue to reduce context switching since tasks are # roughly uniform in size - chunksize = self._num_tasks_per_process() + chunksize = self._num_tasks_per_fetch_process() self.log.debug("Waiting for inquiries to complete...") task_keys_to_states = self._sync_pool.map( diff --git a/airflow/jobs.py b/airflow/jobs.py index 9e68fad79785c..ca124bf1f1af8 100644 --- a/airflow/jobs.py +++ b/airflow/jobs.py @@ -52,6 +52,7 @@ DagFileProcessorAgent, SimpleDag, SimpleDagBag, + SimpleTaskInstance, list_py_file_paths) from airflow.utils.db import create_session, provide_session from airflow.utils.email import get_email_address_list, send_email @@ -598,6 +599,7 @@ def __init__( 'run_duration') self.processor_agent = None + self._last_loop = False signal.signal(signal.SIGINT, self._exit_gracefully) signal.signal(signal.SIGTERM, self._exit_gracefully) @@ -1228,13 +1230,13 @@ def _change_state_for_executable_task_instances(self, task_instances, acceptable_states, session=None): """ Changes the state of task instances in the list with one of the given states - to QUEUED atomically, and returns the TIs changed. + to QUEUED atomically, and returns the TIs changed in SimpleTaskInstance format. :param task_instances: TaskInstances to change the state of :type task_instances: List[TaskInstance] :param acceptable_states: Filters the TaskInstances updated to be in these states :type acceptable_states: Iterable[State] - :return: List[TaskInstance] + :return: List[SimpleTaskInstance] """ if len(task_instances) == 0: session.commit() @@ -1276,81 +1278,57 @@ def _change_state_for_executable_task_instances(self, task_instances, else task_instance.queued_dttm) session.merge(task_instance) - # save which TIs we set before session expires them - filter_for_ti_enqueue = ([and_(TI.dag_id == ti.dag_id, - TI.task_id == ti.task_id, - TI.execution_date == ti.execution_date) - for ti in tis_to_set_to_queued]) - session.commit() - - # requery in batches since above was expired by commit + # Generate a list of SimpleTaskInstance for the use of queuing + # them in the executor. 
+ simple_task_instances = [SimpleTaskInstance(ti) for ti in + tis_to_set_to_queued] - def query(result, items): - tis_to_be_queued = ( - session - .query(TI) - .filter(or_(*items)) - .all()) - task_instance_str = "\n\t".join( - ["{}".format(x) for x in tis_to_be_queued]) - self.log.info("Setting the following {} tasks to queued state:\n\t{}" - .format(len(tis_to_be_queued), - task_instance_str)) - return result + tis_to_be_queued - - tis_to_be_queued = helpers.reduce_in_chunks(query, - filter_for_ti_enqueue, - [], - self.max_tis_per_query) + task_instance_str = "\n\t".join( + ["{}".format(x) for x in tis_to_set_to_queued]) - return tis_to_be_queued + session.commit() + self.logger.info("Setting the following {} tasks to queued state:\n\t{}" + .format(len(tis_to_set_to_queued), task_instance_str)) + return simple_task_instances - def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, task_instances): + def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, + simple_task_instances): """ Takes task_instances, which should have been set to queued, and enqueues them with the executor. - :param task_instances: TaskInstances to enqueue - :type task_instances: List[TaskInstance] + :param simple_task_instances: TaskInstances to enqueue + :type simple_task_instances: List[SimpleTaskInstance] :param simple_dag_bag: Should contains all of the task_instances' dags :type simple_dag_bag: SimpleDagBag """ TI = models.TaskInstance # actually enqueue them - for task_instance in task_instances: - simple_dag = simple_dag_bag.get_dag(task_instance.dag_id) + for simple_task_instance in simple_task_instances: + simple_dag = simple_dag_bag.get_dag(simple_task_instance.dag_id) command = TI.generate_command( - task_instance.dag_id, - task_instance.task_id, - task_instance.execution_date, + simple_task_instance.dag_id, + simple_task_instance.task_id, + simple_task_instance.execution_date, local=True, mark_success=False, ignore_all_deps=False, ignore_depends_on_past=False, ignore_task_deps=False, ignore_ti_state=False, - pool=task_instance.pool, + pool=simple_task_instance.pool, file_path=simple_dag.full_filepath, pickle_id=simple_dag.pickle_id) - priority = task_instance.priority_weight - queue = task_instance.queue + priority = simple_task_instance.priority_weight + queue = simple_task_instance.queue self.log.info( "Sending %s to executor with priority %s and queue %s", - task_instance.key, priority, queue + simple_task_instance.key, priority, queue ) - # save attributes so sqlalchemy doesnt expire them - copy_dag_id = task_instance.dag_id - copy_task_id = task_instance.task_id - copy_execution_date = task_instance.execution_date - make_transient(task_instance) - task_instance.dag_id = copy_dag_id - task_instance.task_id = copy_task_id - task_instance.execution_date = copy_execution_date - self.executor.queue_command( - task_instance, + simple_task_instance, command, priority=priority, queue=queue) @@ -1374,24 +1352,65 @@ def _execute_task_instances(self, :type simple_dag_bag: SimpleDagBag :param states: Execute TaskInstances in these states :type states: Tuple[State] - :return: None + :return: Number of task instance with state changed. 
""" executable_tis = self._find_executable_task_instances(simple_dag_bag, states, session=session) def query(result, items): - tis_with_state_changed = self._change_state_for_executable_task_instances( - items, - states, - session=session) + simple_tis_with_state_changed = \ + self._change_state_for_executable_task_instances(items, + states, + session=session) self._enqueue_task_instances_with_queued_state( simple_dag_bag, - tis_with_state_changed) + simple_tis_with_state_changed) session.commit() - return result + len(tis_with_state_changed) + return result + len(simple_tis_with_state_changed) return helpers.reduce_in_chunks(query, executable_tis, 0, self.max_tis_per_query) + @provide_session + def _change_state_for_tasks_failed_to_execute(self, session): + """ + If there are tasks left over in the executor, + we set them back to SCHEDULED to avoid creating hanging tasks. + :param session: + :return: + """ + if self.executor.queued_tasks: + TI = models.TaskInstance + filter_for_ti_state_change = ( + [and_( + TI.dag_id == dag_id, + TI.task_id == task_id, + TI.execution_date == execution_date, + # The TI.try_number will return raw try_number+1 since the + # ti is not running. And we need to -1 to match the DB record. + TI._try_number == try_number test_change_state_for_tasks_failed_to_execute- 1, + TI.state == State.QUEUED) + for dag_id, task_id, execution_date, try_number + in self.executor.queued_tasks.keys()]) + ti_query = (session.query(TI) + .filter(or_(*filter_for_ti_state_change))) + tis_to_set_to_scheduled = (ti_query + .with_for_update() + .all()) + if len(tis_to_set_to_scheduled) == 0: + session.commit() + return + + # set TIs to queued state + for task_instance in tis_to_set_to_scheduled: + task_instance.state = State.SCHEDULED + + task_instance_str = "\n\t".join( + ["{}".format(x) for x in tis_to_set_to_scheduled]) + + session.commit() + self.logger.info("Set the follow tasks to scheduled state:\n\t{}" + .format(task_instance_str)) + def _process_dags(self, dagbag, dags, tis_out): """ Iterates over the dags and processes them. Processing includes: @@ -1507,6 +1526,8 @@ def processor_factory(file_path, zombies): try: self._execute_helper() + except Exception: + self.log.exception("Exception when executing execute_helper") finally: self.processor_agent.end() self.log.info("Exited execute loop") @@ -1557,6 +1578,7 @@ def _execute_helper(self): self.log.info("Harvesting DAG parsing results") simple_dags = self.processor_agent.harvest_simple_dags() + self.log.debug("Harvested {} SimpleDAGs".format(len(simple_dags))) # Send tasks for execution if available simple_dag_bag = SimpleDagBag(simple_dags) @@ -1593,6 +1615,8 @@ def _execute_helper(self): self.log.debug("Heartbeating the executor") self.executor.heartbeat() + self._change_state_for_tasks_failed_to_execute() + # Process events from the executor self._process_executor_events(simple_dag_bag) @@ -1612,8 +1636,13 @@ def _execute_helper(self): self.log.debug("Sleeping for %.2f seconds", self._processor_poll_interval) time.sleep(self._processor_poll_interval) - # Exit early for a test mode + # Exit early for a test mode, run one additional scheduler loop + # to reduce the possibility that parsed DAG was put into the queue + # by the DAG manager but not yet received by DAG agent. 
if self.processor_agent.done: + self._last_loop = True + + if self._last_loop: self.log.info("Exiting scheduler loop as all files" " have been processed {} times".format(self.num_runs)) break diff --git a/airflow/models.py b/airflow/models.py index 9ab2348cc2ded..8870a2921e85c 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -22,12 +22,12 @@ from __future__ import print_function from __future__ import unicode_literals +import copy +from collections import defaultdict, namedtuple +from builtins import ImportError as BuiltinImportError, bytes, object, str from future.standard_library import install_aliases -from builtins import str, object, bytes, ImportError as BuiltinImportError -import copy -from collections import namedtuple, defaultdict try: # Fix Python > 3.7 deprecation from collections.abc import Hashable diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index 47f473e9aa3d3..62c4c91968384 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -146,6 +146,21 @@ def __init__(self, ti): self._end_date = ti.end_date self._try_number = ti.try_number self._state = ti.state + self._executor_config = ti.executor_config + if hasattr(ti, 'run_as_user'): + self._run_as_user = ti.run_as_user + else: + self._run_as_user = None + if hasattr(ti, 'pool'): + self._pool = ti.pool + else: + self._pool = None + if hasattr(ti, 'priority_weight'): + self._priority_weight = ti.priority_weight + else: + self._priority_weight = None + self._queue = ti.queue + self._key = ti.key @property def dag_id(self): @@ -175,6 +190,48 @@ def try_number(self): def state(self): return self._state + @property + def pool(self): + return self._pool + + @property + def priority_weight(self): + return self._priority_weight + + @property + def queue(self): + return self._queue + + @property + def key(self): + return self._key + + @property + def executor_config(self): + return self._executor_config + + @provide_session + def construct_task_instance(self, session=None, lock_for_update=False): + """ + Construct a TaskInstance from the database based on the primary key + :param session: DB session. + :param lock_for_update: if True, indicates that the database should + lock the TaskInstance (issuing a FOR UPDATE clause) until the + session is committed. + """ + TI = airflow.models.TaskInstance + + qry = session.query(TI).filter( + TI.dag_id == self._dag_id, + TI.task_id == self._task_id, + TI.execution_date == self._execution_date) + + if lock_for_update: + ti = qry.with_for_update().first() + else: + ti = qry.first() + return ti + class SimpleDagBag(BaseDagBag): """ @@ -566,11 +623,16 @@ def end(self): Terminate (and then kill) the manager process launched. 
:return: """ - if not self._process or not self._process.is_alive(): + if not self._process: self.log.warn('Ending without manager process.') return this_process = psutil.Process(os.getpid()) - manager_process = psutil.Process(self._process.pid) + try: + manager_process = psutil.Process(self._process.pid) + except psutil.NoSuchProcess: + self.log.info("Manager process not running.") + return + # First try SIGTERM if manager_process.is_running() \ and manager_process.pid in [x.pid for x in this_process.children()]: diff --git a/airflow/utils/timeout.py b/airflow/utils/timeout.py index a86b9d357b5c6..f64800587792b 100644 --- a/airflow/utils/timeout.py +++ b/airflow/utils/timeout.py @@ -23,6 +23,7 @@ from __future__ import unicode_literals import signal +import os from airflow.exceptions import AirflowTaskTimeout from airflow.utils.log.logging_mixin import LoggingMixin @@ -35,10 +36,10 @@ class timeout(LoggingMixin): def __init__(self, seconds=1, error_message='Timeout'): self.seconds = seconds - self.error_message = error_message + self.error_message = error_message + ', PID: ' + str(os.getpid()) def handle_timeout(self, signum, frame): - self.log.error("Process timed out") + self.log.error("Process timed out, PID: " + str(os.getpid())) raise AirflowTaskTimeout(self.error_message) def __enter__(self): diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py index 380201d30a045..5a7d6e984c334 100644 --- a/tests/executors/test_celery_executor.py +++ b/tests/executors/test_celery_executor.py @@ -18,12 +18,17 @@ # under the License. import sys import unittest +from multiprocessing import Pool + import mock from celery.contrib.testing.worker import start_worker -from airflow.executors.celery_executor import CeleryExecutor -from airflow.executors.celery_executor import app +from airflow.executors import celery_executor from airflow.executors.celery_executor import CELERY_FETCH_ERR_MSG_HEADER +from airflow.executors.celery_executor import (CeleryExecutor, celery_configuration, + send_task_to_executor, execute_command) +from airflow.executors.celery_executor import app +from celery import states as celery_states from airflow.utils.state import State from airflow import configuration @@ -40,16 +45,37 @@ def test_celery_integration(self): executor = CeleryExecutor() executor.start() with start_worker(app=app, logfile=sys.stdout, loglevel='debug'): - success_command = ['true', 'some_parameter'] fail_command = ['false', 'some_parameter'] - executor.execute_async(key='success', command=success_command) - # errors are propagated for some reason - try: - executor.execute_async(key='fail', command=fail_command) - except Exception: - pass + cached_celery_backend = execute_command.backend + task_tuples_to_send = [('success', 'fake_simple_ti', success_command, + celery_configuration['task_default_queue'], + execute_command), + ('fail', 'fake_simple_ti', fail_command, + celery_configuration['task_default_queue'], + execute_command)] + + chunksize = executor._num_tasks_per_send_process(len(task_tuples_to_send)) + num_processes = min(len(task_tuples_to_send), executor._sync_parallelism) + + send_pool = Pool(processes=num_processes) + key_and_async_results = send_pool.map( + send_task_to_executor, + task_tuples_to_send, + chunksize=chunksize) + + send_pool.close() + send_pool.join() + + for key, command, result in key_and_async_results: + # Only pops when enqueued successfully, otherwise keep it + # and expect scheduler loop to deal with it. 
+ result.backend = cached_celery_backend + executor.running[key] = command + executor.tasks[key] = result + executor.last_state[key] = celery_states.PENDING + executor.running['success'] = True executor.running['fail'] = True @@ -64,6 +90,21 @@ def test_celery_integration(self): self.assertNotIn('success', executor.last_state) self.assertNotIn('fail', executor.last_state) + def test_error_sending_task(self): + @app.task + def fake_execute_command(): + pass + + # fake_execute_command takes no arguments while execute_command takes 1, + # which will cause TypeError when calling task.apply_async() + celery_executor.execute_command = fake_execute_command + executor = CeleryExecutor() + value_tuple = 'command', '_', 'queue', 'should_be_a_simple_ti' + executor.queued_tasks['key'] = value_tuple + executor.heartbeat() + self.assertEquals(1, len(executor.queued_tasks)) + self.assertEquals(executor.queued_tasks['key'], value_tuple) + def test_exception_propagation(self): @app.task def fake_celery_task(): diff --git a/tests/executors/test_executor.py b/tests/executors/test_executor.py index aab66644b8a76..366ea8c967cce 100644 --- a/tests/executors/test_executor.py +++ b/tests/executors/test_executor.py @@ -46,7 +46,8 @@ def heartbeat(self): ti = self._running.pop() ti.set_state(State.SUCCESS, session) for key, val in list(self.queued_tasks.items()): - (command, priority, queue, ti) = val + (command, priority, queue, simple_ti) = val + ti = simple_ti.construct_task_instance() ti.set_state(State.RUNNING, session) self._running.append(ti) self.queued_tasks.pop(key) diff --git a/tests/test_jobs.py b/tests/test_jobs.py index af8ccc6c2e8c0..fb161c1b6931b 100644 --- a/tests/test_jobs.py +++ b/tests/test_jobs.py @@ -2032,6 +2032,54 @@ def test_change_state_for_tis_without_dagrun(self): ti2.refresh_from_db(session=session) self.assertEqual(ti2.state, State.SCHEDULED) + def test_change_state_for_tasks_failed_to_execute(self): + dag = DAG( + dag_id='dag_id', + start_date=DEFAULT_DATE) + + task = DummyOperator( + task_id='task_id', + dag=dag, + owner='airflow') + + # If there's no left over task in executor.queued_tasks, nothing happens + session = settings.Session() + scheduler_job = SchedulerJob() + mock_logger = mock.MagicMock() + test_executor = TestExecutor() + scheduler_job.executor = test_executor + scheduler_job._logger = mock_logger + scheduler_job._change_state_for_tasks_failed_to_execute() + mock_logger.info.assert_not_called() + + # Tasks failed to execute with QUEUED state will be set to SCHEDULED state. + session.query(TI).delete() + session.commit() + key = 'dag_id', 'task_id', DEFAULT_DATE, 1 + test_executor.queued_tasks[key] = 'value' + ti = TI(task, DEFAULT_DATE) + ti.state = State.QUEUED + session.merge(ti) + session.commit() + + scheduler_job._change_state_for_tasks_failed_to_execute() + + ti.refresh_from_db() + self.assertEquals(State.SCHEDULED, ti.state) + + # Tasks failed to execute with RUNNING state will not be set to SCHEDULED state. 
+        session.query(TI).delete()
+        session.commit()
+        ti.state = State.RUNNING
+
+        session.merge(ti)
+        session.commit()
+
+        scheduler_job._change_state_for_tasks_failed_to_execute()
+
+        ti.refresh_from_db()
+        self.assertEquals(State.RUNNING, ti.state)
+
     def test_execute_helper_reset_orphaned_tasks(self):
         session = settings.Session()
         dag = DAG(
@@ -2949,7 +2997,8 @@ def run_with_error(task):
             pass
 
         ti_tuple = six.next(six.itervalues(executor.queued_tasks))
-        (command, priority, queue, ti) = ti_tuple
+        (command, priority, queue, simple_ti) = ti_tuple
+        ti = simple_ti.construct_task_instance()
         ti.task = dag_task1
 
         self.assertEqual(ti.try_number, 1)
@@ -2970,15 +3019,21 @@ def run_with_error(task):
         # removing
         self.assertEqual(ti.state, State.SCHEDULED)
         # as scheduler will move state from SCHEDULED to QUEUED
 
-        # now the executor has cleared and it should be allowed the re-queue
+        # now the executor has cleared and it should be allowed to re-queue,
+        # but tasks that stay in executor.queued_tasks after executor.heartbeat()
+        # will be set back to the SCHEDULED state
         executor.queued_tasks.clear()
        do_schedule()
         ti.refresh_from_db()
-        self.assertEqual(ti.state, State.QUEUED)
-        # calling below again in order to ensure with try_number 2,
-        # scheduler doesn't put task in queue
+
+        self.assertEqual(ti.state, State.SCHEDULED)
+
+        # Verify that the task does get re-queued.
+        executor.queued_tasks.clear()
+        executor.do_update = True
         do_schedule()
-        self.assertEquals(1, len(executor.queued_tasks))
+        ti.refresh_from_db()
+        self.assertEqual(ti.state, State.RUNNING)
 
     @unittest.skipUnless("INTEGRATION" in os.environ,
                          "Can only run end to end")
     def test_retry_handling_job(self):
@@ -3023,8 +3078,8 @@ def test_scheduler_run_duration(self):
         logging.info("Test ran in %.2fs, expected %.2fs",
                      run_duration,
                      expected_run_duration)
-        # 5s to wait for child process to exit and 1s dummy sleep
-        # in scheduler loop to prevent excessive logs.
+        # 5s to wait for child process to exit, 1s dummy sleep in the scheduler
+        # loop to prevent excessive logs, and 1s for the last loop to finish.
         self.assertLess(run_duration - expected_run_duration, 6.0)
 
     def test_dag_with_system_exit(self):
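
Note on the send-side changes above: CeleryExecutor.heartbeat() now fans the apply_async calls out over a short-lived multiprocessing Pool (chunked via _num_tasks_per_send_process) and only pops a task from queued_tasks when the send succeeded, leaving failed sends for the next scheduler loop to retry. The standalone sketch below is not part of the patch; it uses illustrative names (send_task, send_in_parallel) and a plain function standing in for the Celery execute_command task, purely to show the same chunked-send-and-retry pattern.

import math
from multiprocessing import Pool


def send_task(task_tuple):
    # Stand-in for send_task_to_executor(): return (key, command, result) and
    # capture any exception instead of raising, so one bad send cannot break
    # the whole batch.
    key, command = task_tuple
    try:
        result = "async-result-for-{}".format(key)  # imagine task.apply_async(...) here
    except Exception as e:
        result = e
    return key, command, result


def send_in_parallel(queued_tasks, sync_parallelism=4):
    # Mirror of the new heartbeat logic: chunk the queued tasks over a
    # short-lived pool, then only pop the ones that were sent successfully.
    task_tuples = list(queued_tasks.items())
    if not task_tuples:
        return {}

    # Same chunk-size formula as _num_tasks_per_send_process().
    chunksize = max(1, int(math.ceil(1.0 * len(task_tuples) / sync_parallelism)))
    num_processes = min(len(task_tuples), sync_parallelism)

    send_pool = Pool(processes=num_processes)
    try:
        results = send_pool.map(send_task, task_tuples, chunksize=chunksize)
    finally:
        send_pool.close()
        send_pool.join()

    sent = {}
    for key, command, result in results:
        if isinstance(result, Exception):
            # Keep the task in queued_tasks; the next heartbeat retries it.
            continue
        queued_tasks.pop(key)
        sent[key] = (command, result)
    return sent


if __name__ == "__main__":
    queued = {("dag_id", "task_id", i): "airflow run dag_id task_id ..." for i in range(8)}
    print(send_in_parallel(queued))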