38 changes: 13 additions & 25 deletions airflow/executors/base_executor.py
@@ -19,11 +19,12 @@

from builtins import range

+# To avoid circular imports
+import airflow.utils.dag_processing
from airflow import configuration
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State


PARALLELISM = configuration.conf.getint('core', 'PARALLELISM')


@@ -50,11 +51,11 @@ def start(self):  # pragma: no cover
        """
        pass

-    def queue_command(self, task_instance, command, priority=1, queue=None):
-        key = task_instance.key
+    def queue_command(self, simple_task_instance, command, priority=1, queue=None):
+        key = simple_task_instance.key
        if key not in self.queued_tasks and key not in self.running:
            self.log.info("Adding to queue: %s", command)
-            self.queued_tasks[key] = (command, priority, queue, task_instance)
+            self.queued_tasks[key] = (command, priority, queue, simple_task_instance)
        else:
            self.log.info("could not queue task {}".format(key))

@@ -86,7 +87,7 @@ def queue_task_instance(
            pickle_id=pickle_id,
            cfg_path=cfg_path)
        self.queue_command(
-            task_instance,
+            airflow.utils.dag_processing.SimpleTaskInstance(task_instance),
            command,
            priority=task_instance.task.priority_weight_total,
            queue=task_instance.task.queue)
@@ -124,34 +125,21 @@ def heartbeat(self):
            key=lambda x: x[1][1],
            reverse=True)
        for i in range(min((open_slots, len(self.queued_tasks)))):
-            key, (command, _, queue, ti) = sorted_queue.pop(0)
-            # TODO(jlowin) without a way to know what Job ran which tasks,
-            # there is a danger that another Job started running a task
-            # that was also queued to this executor. This is the last chance
-            # to check if that happened. The most probable way is that a
-            # Scheduler tried to run a task that was originally queued by a
-            # Backfill. This fix reduces the probability of a collision but
-            # does NOT eliminate it.
+            key, (command, _, queue, simple_ti) = sorted_queue.pop(0)
            self.queued_tasks.pop(key)
-            ti.refresh_from_db()
-            if ti.state != State.RUNNING:
-                self.running[key] = command
-                self.execute_async(key=key,
-                                   command=command,
-                                   queue=queue,
-                                   executor_config=ti.executor_config)
-            else:
-                self.log.info(
-                    'Task is already running, not sending to '
-                    'executor: {}'.format(key))
+            self.running[key] = command
+            self.execute_async(key=key,
+                               command=command,
+                               queue=queue,
+                               executor_config=simple_ti.executor_config)

        # Calling child class sync method
        self.log.debug("Calling the %s sync method", self.__class__)
        self.sync()

    def change_state(self, key, state):
        self.log.debug("Changing state: {}".format(key))
-        self.running.pop(key)
+        self.running.pop(key, None)
        self.event_buffer[key] = state

    def fail(self, key):
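The net effect of this file's changes: the executor queue now carries a lightweight snapshot of the task instance instead of the ORM-backed object, and heartbeat() no longer refreshes each task from the database before dispatch (the old refresh_from_db()/State.RUNNING guard is gone, and change_state() now tolerates keys it never registered via pop(key, None)). A minimal sketch of the snapshot idea, using a hypothetical SnapshotTaskInstance; Airflow's real SimpleTaskInstance lives in airflow.utils.dag_processing and carries more fields:

# Hypothetical, simplified stand-in for SimpleTaskInstance: copy only the
# fields the executor needs, so no ORM object or DB session has to travel
# with the queued command.
class SnapshotTaskInstance(object):
    def __init__(self, ti):
        self.dag_id = ti.dag_id
        self.task_id = ti.task_id
        self.execution_date = ti.execution_date
        self.try_number = ti.try_number
        self.executor_config = ti.executor_config

    @property
    def key(self):
        # Mirrors the (dag_id, task_id, execution_date, try_number)-style
        # key used to index queued_tasks and running above.
        return self.dag_id, self.task_id, self.execution_date, self.try_number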
116 changes: 100 additions & 16 deletions airflow/executors/celery_executor.py
@@ -33,10 +33,13 @@
from airflow.executors.base_executor import BaseExecutor
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import import_string
+from airflow.utils.timeout import timeout

# Make it constant for unit test.
CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state'

+CELERY_SEND_ERR_MSG_HEADER = 'Error sending Celery task'
+
'''
To start the celery worker, run the command:
airflow worker
@@ -55,12 +58,12 @@


@app.task
-def execute_command(command):
+def execute_command(command_to_exec):
    log = LoggingMixin().log
-    log.info("Executing command in Celery: %s", command)
+    log.info("Executing command in Celery: %s", command_to_exec)
    env = os.environ.copy()
    try:
-        subprocess.check_call(command, stderr=subprocess.STDOUT,
+        subprocess.check_call(command_to_exec, stderr=subprocess.STDOUT,
                              close_fds=True, env=env)
    except subprocess.CalledProcessError as e:
        log.exception('execute_command encountered a CalledProcessError')
@@ -95,16 +98,30 @@ def fetch_celery_task_state(celery_task):
    """

    try:
-        # Accessing state property of celery task will make actual network request
-        # to get the current state of the task.
-        res = (celery_task[0], celery_task[1].state)
+        with timeout(seconds=2):
+            # Accessing state property of celery task will make actual network request
+            # to get the current state of the task.
+            res = (celery_task[0], celery_task[1].state)
    except Exception as e:
        exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0],
                                                              traceback.format_exc())
        res = ExceptionWithTraceback(e, exception_traceback)
    return res

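Both fetch_celery_task_state above and send_task_to_executor below wrap their broker round trip in timeout(seconds=2), so one hung connection cannot stall the whole worker pool. airflow.utils.timeout is essentially a SIGALRM-based context manager; a simplified sketch of that pattern (illustrative only, not the exact Airflow implementation, which raises its own exception type):

import signal

class timeout_guard(object):
    # Raise if the wrapped block runs longer than `seconds`. SIGALRM-based,
    # so it only works on Unix and in the main thread.
    def __init__(self, seconds=1):
        self.seconds = seconds

    def _handle_timeout(self, signum, frame):
        raise RuntimeError("timed out after {}s".format(self.seconds))

    def __enter__(self):
        signal.signal(signal.SIGALRM, self._handle_timeout)
        signal.alarm(self.seconds)

    def __exit__(self, exc_type, exc_value, tb):
        signal.alarm(0)  # always cancel the pending alarm on the way out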

+def send_task_to_executor(task_tuple):
+    key, simple_ti, command, queue, task = task_tuple
+    try:
+        with timeout(seconds=2):
+            result = task.apply_async(args=[command], queue=queue)
+    except Exception as e:
+        exception_traceback = "Celery Task ID: {}\n{}".format(key,
+                                                              traceback.format_exc())
+        result = ExceptionWithTraceback(e, exception_traceback)
+
+    return key, command, result

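send_task_to_executor is a module-level function rather than a method so that multiprocessing.Pool can pickle it and fan it out across worker processes; exceptions are returned as values (ExceptionWithTraceback) instead of raised, since an exception raised inside a pool worker would abort the whole map call in the parent. A usage sketch with a stand-in for the Celery task (hypothetical names; the real call site is the new heartbeat() below):

from multiprocessing import Pool

class FakeCeleryTask(object):
    # Stand-in for the execute_command Celery task: the real apply_async
    # publishes to the broker and returns an AsyncResult; this one just
    # echoes its inputs so the fan-out can run without a broker.
    def apply_async(self, args, queue):
        return (args, queue)

task_tuples = [
    (('example_dag', 'example_task', '2018-01-01T00:00:00', 1),  # key
     None,                                  # simple_ti (unused when sending)
     ['airflow', 'run', 'example_dag', 'example_task'],  # command
     'default',                             # queue
     FakeCeleryTask()),
]

if __name__ == '__main__':
    send_pool = Pool(processes=2)
    for key, command, result in send_pool.map(send_task_to_executor, task_tuples):
        print(key, result)
    send_pool.close()
    send_pool.join()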

class CeleryExecutor(BaseExecutor):
    """
    CeleryExecutor is recommended for production use of Airflow. It allows
@@ -135,24 +152,91 @@ def start(self):
            'Starting Celery Executor using {} processes for syncing'.format(
                self._sync_parallelism))

-    def execute_async(self, key, command,
-                      queue=DEFAULT_CELERY_CONFIG['task_default_queue'],
-                      executor_config=None):
-        self.log.info("[celery] queuing {key} through celery, "
-                      "queue={queue}".format(**locals()))
-        self.tasks[key] = execute_command.apply_async(
-            args=[command], queue=queue)
-        self.last_state[key] = celery_states.PENDING
+    def _num_tasks_per_send_process(self, to_send_count):
+        """
+        How many Celery tasks should each worker process send.
+
-    def _num_tasks_per_process(self):
+        :return: Number of tasks that should be sent per process
+        :rtype: int
+        """
+        return max(1,
+                   int(math.ceil(1.0 * to_send_count / self._sync_parallelism)))
+
+    def _num_tasks_per_fetch_process(self):
        """
        How many Celery tasks should be sent to each worker process.

        :return: Number of tasks that should be used per process
        :rtype: int
        """
        return max(1,
                   int(math.ceil(1.0 * len(self.tasks) / self._sync_parallelism)))

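Both helpers share the same arithmetic: spread N tasks over _sync_parallelism processes with a ceiling division, never dropping below one task per chunk. A standalone worked example of that formula:

import math

def tasks_per_process(task_count, sync_parallelism):
    # Same formula as the helpers above: ceiling division, floored at 1.
    return max(1, int(math.ceil(1.0 * task_count / sync_parallelism)))

assert tasks_per_process(40, 16) == 3  # 40 tasks over 16 processes
assert tasks_per_process(16, 16) == 1  # exactly one task each
assert tasks_per_process(3, 16) == 1   # never below one task per chunk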
+    def heartbeat(self):
+        # Triggering new jobs
+        if not self.parallelism:
+            open_slots = len(self.queued_tasks)
+        else:
+            open_slots = self.parallelism - len(self.running)
+
+        self.log.debug("{} running task instances".format(len(self.running)))
+        self.log.debug("{} in queue".format(len(self.queued_tasks)))
+        self.log.debug("{} open slots".format(open_slots))
+
+        sorted_queue = sorted(
+            [(k, v) for k, v in self.queued_tasks.items()],
+            key=lambda x: x[1][1],
+            reverse=True)
+
+        task_tuples_to_send = []
+
+        for i in range(min((open_slots, len(self.queued_tasks)))):
+            key, (command, _, queue, simple_ti) = sorted_queue.pop(0)
+            task_tuples_to_send.append((key, simple_ti, command, queue,
+                                        execute_command))
+
+        cached_celery_backend = None
+        if task_tuples_to_send:
+            tasks = [t[4] for t in task_tuples_to_send]
+
+            # Celery state queries will get stuck if we do not use one shared
+            # backend for all tasks.
+            cached_celery_backend = tasks[0].backend
+
+        if task_tuples_to_send:
+            # Use chunking instead of a work queue to reduce context switching
+            # since tasks are roughly uniform in size
+            chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
+            num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
+
+            send_pool = Pool(processes=num_processes)
+            key_and_async_results = send_pool.map(
+                send_task_to_executor,
+                task_tuples_to_send,
+                chunksize=chunksize)
+
+            send_pool.close()
+            send_pool.join()
+            self.log.debug('Sent all tasks.')
+
+            for key, command, result in key_and_async_results:
+                if isinstance(result, ExceptionWithTraceback):
+                    self.log.error(
+                        CELERY_SEND_ERR_MSG_HEADER + ":{}\n{}\n".format(
+                            result.exception, result.traceback))
+                elif result is not None:
+                    # Only pop when the task was sent successfully; otherwise
+                    # keep it and let the next scheduler loop deal with it.
+                    self.queued_tasks.pop(key)
+                    result.backend = cached_celery_backend
+                    self.running[key] = command
+                    self.tasks[key] = result
+                    self.last_state[key] = celery_states.PENDING
+
+        # Calling child class sync method
+        self.log.debug("Calling the {} sync method".format(self.__class__))
+        self.sync()

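Two details of the new heartbeat() worth spelling out. First, queued_tasks values are (command, priority, queue, simple_ti) tuples, so key=lambda x: x[1][1] with reverse=True drains the queue in descending priority_weight order; a toy illustration:

queued_tasks = {
    'low': ('cmd_low', 1, 'default', None),
    'high': ('cmd_high', 9, 'default', None),
    'mid': ('cmd_mid', 5, 'default', None),
}
sorted_queue = sorted(queued_tasks.items(), key=lambda x: x[1][1], reverse=True)
assert [k for k, _ in sorted_queue] == ['high', 'mid', 'low']

Second, every AsyncResult handed back by the send pool is re-pointed at the one cached_celery_backend, so later state queries share a single result backend instead of each result holding its own.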
    def sync(self):
        num_processes = min(len(self.tasks), self._sync_parallelism)
        if num_processes == 0:
@@ -167,7 +251,7 @@ def sync(self):

        # Use chunking instead of a work queue to reduce context switching since tasks are
        # roughly uniform in size
-        chunksize = self._num_tasks_per_process()
+        chunksize = self._num_tasks_per_fetch_process()

        self.log.debug("Waiting for inquiries to complete...")
        task_keys_to_states = self._sync_pool.map(
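The fetch side of sync() (truncated above) mirrors the send side: a short-lived Pool maps fetch_celery_task_state over the executor's (key, AsyncResult) pairs with the same chunking scheme. A condensed sketch of that shape, inferred from the visible fragment (error handling and state transitions are reduced to a skip here, where the real sync() logs failures and checks again on the next loop):

import math
from multiprocessing import Pool

def fetch_all_states(tasks, sync_parallelism):
    # tasks: dict mapping task key -> Celery AsyncResult.
    if not tasks:
        return {}
    num_processes = min(len(tasks), sync_parallelism)
    chunksize = max(1, int(math.ceil(1.0 * len(tasks) / sync_parallelism)))
    fetch_pool = Pool(processes=num_processes)
    results = fetch_pool.map(fetch_celery_task_state, list(tasks.items()),
                             chunksize=chunksize)
    fetch_pool.close()
    fetch_pool.join()
    states = {}
    for res in results:
        if isinstance(res, ExceptionWithTraceback):
            continue  # the real sync() logs these and retries next loop
        key, state = res
        states[key] = state
    return states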