4 changes: 2 additions & 2 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -25,8 +25,6 @@
from flask import Response

from airflow.api_connexion.types import APIResponse
from airflow.models import Trigger, Variable, XCom
from airflow.models.dagwarning import DagWarning
from airflow.serialization.serialized_objects import BaseSerialization

log = logging.getLogger(__name__)
@@ -36,7 +34,9 @@
def _initialize_map() -> dict[str, Callable]:
from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.dag_processing.processor import DagFileProcessor
from airflow.models import Trigger, Variable, XCom
from airflow.models.dag import DagModel
from airflow.models.dagwarning import DagWarning

functions: list[Callable] = [
DagFileProcessor.update_import_errors,
30 changes: 15 additions & 15 deletions airflow/cli/commands/dag_processor_command.py
@@ -26,29 +26,29 @@

from airflow import settings
from airflow.configuration import conf
from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.jobs.dag_processor_job_runner import DagProcessorJobRunner
from airflow.jobs.job import Job, run_job
from airflow.utils import cli as cli_utils
from airflow.utils.cli import setup_locations, setup_logging

log = logging.getLogger(__name__)


def _create_dag_processor_job(args: Any) -> Job:
def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner:
"""Creates DagFileProcessorProcess instance."""
from airflow.dag_processing.manager import DagFileProcessorManager

processor_timeout_seconds: int = conf.getint("core", "dag_file_processor_timeout")
processor_timeout = timedelta(seconds=processor_timeout_seconds)

processor = DagFileProcessorManager(
processor_timeout=processor_timeout,
dag_directory=args.subdir,
max_runs=args.num_runs,
dag_ids=[],
pickle_dags=args.do_pickle,
)
return Job(
job_runner=processor.job_runner,
return DagProcessorJobRunner(
job=Job(),
processor=DagFileProcessorManager(
processor_timeout=processor_timeout,
dag_directory=args.subdir,
max_runs=args.num_runs,
dag_ids=[],
pickle_dags=args.do_pickle,
),
)


@@ -62,7 +62,7 @@ def dag_processor(args):
if sql_conn.startswith("sqlite"):
raise SystemExit("Standalone DagProcessor is not supported when using sqlite.")

job = _create_dag_processor_job(args)
job_runner = _create_dag_processor_job_runner(args)

if args.daemon:
pid, stdout, stderr, log_file = setup_locations(
@@ -81,6 +81,6 @@ def dag_processor(args):
umask=int(settings.DAEMON_UMASK, 8),
)
with ctx:
run_job(job)
run_job(job=job_runner.job, execute_callable=job_runner._execute)
else:
run_job(job)
run_job(job=job_runner.job, execute_callable=job_runner._execute)
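The hunks above invert the old ownership: previously the CLI built a DagFileProcessorManager, wrapped it via processor.job_runner, and handed that to Job(job_runner=...); now the runner is constructed directly, owns a plain Job(), and callers pass job_runner._execute to run_job. A minimal standalone sketch of that inversion, using simplified stand-in classes rather than Airflow's real Job, BaseJobRunner, and run_job:

```python
# Standalone illustration of the Job/JobRunner inversion in this PR.
# Job, DemoJobRunner and run_job below are simplified stand-ins, not Airflow's classes.
from __future__ import annotations

from typing import Callable


class Job:
    """Tracks bookkeeping state (heartbeats, job_type); contains no business logic."""

    def __init__(self) -> None:
        self.job_type: str | None = None


class DemoJobRunner:
    """Owns its Job and implements the work in _execute(), like the new runners."""

    job_type = "DemoJob"

    def __init__(self, job: Job) -> None:
        self.job = job
        self.job.job_type = self.job_type

    def _execute(self) -> int | None:
        print("doing the actual work")
        return None


def run_job(job: Job, execute_callable: Callable[[], int | None]) -> int | None:
    """Stand-in for run_job: persist/heartbeat the job, then invoke the runner's work."""
    return execute_callable()


# Usage mirroring the CLI commands in this diff:
runner = DemoJobRunner(job=Job())
run_job(job=runner.job, execute_callable=runner._execute)
```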
18 changes: 7 additions & 11 deletions airflow/cli/commands/scheduler_command.py
@@ -35,26 +35,22 @@
from airflow.utils.scheduler_health import serve_health_check


def _run_scheduler_job(job: Job, *, skip_serve_logs: bool) -> None:
def _run_scheduler_job(job_runner: SchedulerJobRunner, *, skip_serve_logs: bool) -> None:
InternalApiConfig.force_database_direct_access()
enable_health_check = conf.getboolean("scheduler", "ENABLE_HEALTH_CHECK")
with _serve_logs(skip_serve_logs), _serve_health_check(enable_health_check):
run_job(job)
run_job(job=job_runner.job, execute_callable=job_runner._execute)


@cli_utils.action_cli
def scheduler(args):
"""Starts Airflow Scheduler."""
print(settings.HEADER)

job = Job(
job_runner=SchedulerJobRunner(
subdir=process_subdir(args.subdir),
num_runs=args.num_runs,
do_pickle=args.do_pickle,
)
job_runner = SchedulerJobRunner(
job=Job(), subdir=process_subdir(args.subdir), num_runs=args.num_runs, do_pickle=args.do_pickle
)
ExecutorLoader.validate_database_executor_compatibility(job.executor)
ExecutorLoader.validate_database_executor_compatibility(job_runner.job.executor)

if args.daemon:
pid, stdout, stderr, log_file = setup_locations(
Expand All @@ -73,12 +69,12 @@ def scheduler(args):
umask=int(settings.DAEMON_UMASK, 8),
)
with ctx:
_run_scheduler_job(job, skip_serve_logs=args.skip_serve_logs)
_run_scheduler_job(job_runner, skip_serve_logs=args.skip_serve_logs)
else:
signal.signal(signal.SIGINT, sigint_handler)
signal.signal(signal.SIGTERM, sigint_handler)
signal.signal(signal.SIGQUIT, sigquit_handler)
_run_scheduler_job(job, skip_serve_logs=args.skip_serve_logs)
_run_scheduler_job(job_runner, skip_serve_logs=args.skip_serve_logs)


@contextmanager
9 changes: 3 additions & 6 deletions airflow/cli/commands/task_command.py
@@ -248,7 +248,8 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) -> None:

def _run_task_by_local_task_job(args, ti: TaskInstance) -> TaskReturnCode | None:
"""Run LocalTaskJob, which monitors the raw task execution process."""
local_task_job_runner = LocalTaskJobRunner(
job_runner = LocalTaskJobRunner(
job=Job(dag_id=ti.dag_id),
task_instance=ti,
mark_success=args.mark_success,
pickle_id=args.pickle,
@@ -260,12 +261,8 @@ def _run_task_by_local_task_job(args, ti: TaskInstance) -> TaskReturnCode | None
pool=args.pool,
external_executor_id=_extract_external_executor_id(args),
)
local_task_job = Job(
job_runner=local_task_job_runner,
dag_id=ti.dag_id,
)
try:
ret = run_job(local_task_job)
ret = run_job(job=job_runner.job, execute_callable=job_runner._execute)
finally:
if args.shut_down_logging:
logging.shutdown()
7 changes: 3 additions & 4 deletions airflow/cli/commands/triggerer_command.py
@@ -55,8 +55,7 @@ def triggerer(args):
"""Starts Airflow Triggerer."""
settings.MASK_SECRETS_IN_LOGS = True
print(settings.HEADER)
triggerer_job_runner = TriggererJobRunner(capacity=args.capacity)
job = Job(job_runner=triggerer_job_runner)
triggerer_job_runner = TriggererJobRunner(job=Job(), capacity=args.capacity)

if args.daemon:
pid, stdout, stderr, log_file = setup_locations(
@@ -75,10 +74,10 @@
umask=int(settings.DAEMON_UMASK, 8),
)
with daemon_context, _serve_logs(args.skip_serve_logs):
run_job(job)
run_job(job=triggerer_job_runner.job, execute_callable=triggerer_job_runner._execute)
else:
signal.signal(signal.SIGINT, sigint_handler)
signal.signal(signal.SIGTERM, sigint_handler)
signal.signal(signal.SIGQUIT, sigquit_handler)
with _serve_logs(args.skip_serve_logs):
run_job(job)
run_job(job=triggerer_job_runner.job, execute_callable=triggerer_job_runner._execute)
22 changes: 3 additions & 19 deletions airflow/dag_processing/manager.py
@@ -35,7 +35,7 @@
from importlib import import_module
from multiprocessing.connection import Connection as MultiprocessingConnection
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple, cast
from typing import Any, Callable, NamedTuple, cast

from setproctitle import setproctitle
from sqlalchemy.orm import Session
@@ -46,7 +46,6 @@
from airflow.callbacks.callback_requests import CallbackRequest, SlaCallbackRequest
from airflow.configuration import conf
from airflow.dag_processing.processor import DagFileProcessorProcess
from airflow.jobs.job import perform_heartbeat
from airflow.models import errors
from airflow.models.dag import DagModel
from airflow.models.dagwarning import DagWarning
@@ -67,9 +66,6 @@
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import prohibit_commit, skip_locked, with_row_locks

if TYPE_CHECKING:
from airflow.jobs.dag_processor_job_runner import DagProcessorJobRunner


class DagParsingStat(NamedTuple):
"""Information on processing progress."""
@@ -385,8 +381,6 @@ def __init__(
signal_conn: MultiprocessingConnection | None = None,
async_mode: bool = True,
):
from airflow.jobs.dag_processor_job_runner import DagProcessorJobRunner

super().__init__()
# known files; this will be updated every `dag_dir_list_interval` and stuff added/removed accordingly
self._file_paths: list[str] = []
@@ -399,9 +393,6 @@ def __init__(
self._async_mode = async_mode
self._parsing_start_time: float | None = None
self._dag_directory = dag_directory
self._job_runner = DagProcessorJobRunner(
processor=self,
)
# Set the signal conn in to non-blocking mode, so that attempting to
# send when the buffer is full errors, rather than hangs for-ever
# attempting to send (this is to avoid deadlocks!)
@@ -465,10 +456,7 @@ def __init__(
if self._direct_scheduler_conn is not None
else {}
)

@property
def job_runner(self) -> DagProcessorJobRunner:
return self._job_runner
self.heartbeat: Callable[[], None] = lambda: None

def register_exit_signals(self):
"""Register signals that stop child processes."""
@@ -585,11 +573,7 @@ def _run_parsing_loop(self):
while True:
loop_start_time = time.monotonic()
ready = multiprocessing.connection.wait(self.waitables.keys(), timeout=poll_time)
# we cannot (for now) define job in _job_runner nicely due to circular references of
# job and job runner, so we have to use getattr, but we might address it in the future
# change when decoupling these two even more
if getattr(self._job_runner, "job", None) is not None:
perform_heartbeat(self._job_runner.job, only_if_necessary=False)
self.heartbeat()
if self._direct_scheduler_conn is not None and self._direct_scheduler_conn in ready:
agent_signal = self._direct_scheduler_conn.recv()

2 changes: 1 addition & 1 deletion airflow/dag_processing/processor.py
@@ -645,7 +645,7 @@ def execute_callbacks(
) -> None:
"""
Execute on failure callbacks.
These objects can come from SchedulerJob or from DagFileProcessorManager.
These objects can come from SchedulerJobRunner or from DagProcessorJobRunner.

:param dagbag: Dag Bag of dags
:param callback_requests: failure callbacks to execute
18 changes: 12 additions & 6 deletions airflow/jobs/backfill_job_runner.py
@@ -69,8 +69,6 @@ class BackfillJobRunner(BaseJobRunner, LoggingMixin):

STATES_COUNT_AS_RUNNING = (State.RUNNING, State.QUEUED)

job: Job # backfill_job can only run with Job class not the Pydantic serialized version

@attr.define
class _DagRunTaskStatus:
"""
@@ -110,6 +108,7 @@ class _DagRunTaskStatus:

def __init__(
self,
job: Job,
dag: DAG,
start_date=None,
end_date=None,
Expand All @@ -126,8 +125,6 @@ def __init__(
run_at_least_once=False,
continue_on_failures=False,
disable_retry=False,
*args,
**kwargs,
) -> None:
"""
Create a BackfillJobRunner.
Expand All @@ -151,6 +148,14 @@ def __init__(
:param args:
:param kwargs:
"""
super().__init__()
if job.job_type and job.job_type != self.job_type:
raise Exception(
f"The job is already assigned a different job_type: {job.job_type}."
f"This is a bug and should be reported."
)
self.job = job
self.job.job_type = self.job_type
self.dag = dag
self.dag_id = dag.dag_id
self.bf_start_date = start_date
Expand All @@ -168,7 +173,6 @@ def __init__(
self.run_at_least_once = run_at_least_once
self.continue_on_failures = continue_on_failures
self.disable_retry = disable_retry
super().__init__(*args, **kwargs)

def _update_counters(self, ti_status: _DagRunTaskStatus, session: Session) -> None:
"""
@@ -632,7 +636,9 @@ def _per_task_process(key, ti: TaskInstance, session):
except (NoAvailablePoolSlot, DagConcurrencyLimitReached, TaskConcurrencyLimitReached) as e:
self.log.debug(e)

perform_heartbeat(job=self.job, only_if_necessary=is_unit_test)
perform_heartbeat(
job=self.job, heartbeat_callback=self.heartbeat_callback, only_if_necessary=is_unit_test
)
# execute the tasks in the queue
executor.heartbeat()

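The BackfillJobRunner constructor now accepts its Job explicitly, claims it by setting job_type, and rejects a Job that a different runner type already claimed. A small standalone sketch of that guard, using simplified stand-ins rather than the Airflow classes:

```python
# Simplified version of the job_type guard added to the runner constructors in this PR.
from __future__ import annotations


class Job:
    def __init__(self) -> None:
        self.job_type: str | None = None


class GuardedJobRunner:
    job_type = "BackfillJob"  # each concrete runner declares its own type

    def __init__(self, job: Job) -> None:
        # Refuse a Job that another runner type has already claimed.
        if job.job_type and job.job_type != self.job_type:
            raise Exception(
                f"The job is already assigned a different job_type: {job.job_type}. "
                "This is a bug and should be reported."
            )
        self.job = job
        self.job.job_type = self.job_type


GuardedJobRunner(job=Job())     # fine: claims the fresh Job

stale = Job()
stale.job_type = "SchedulerJob"
# GuardedJobRunner(job=stale)   # would raise: the Job belongs to another runner type
```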
3 changes: 0 additions & 3 deletions airflow/jobs/base_job_runner.py
@@ -25,16 +25,13 @@
from sqlalchemy.orm import Session

from airflow.jobs.job import Job
from airflow.serialization.pydantic.job import JobPydantic


class BaseJobRunner:
"""Abstract class for job runners to derive from."""

job_type = "undefined"

job: Job | JobPydantic

def _execute(self) -> int | None:
"""
Executes the logic connected to the runner. This method should be
22 changes: 21 additions & 1 deletion airflow/jobs/dag_processor_job_runner.py
@@ -17,28 +17,48 @@

from __future__ import annotations

from typing import Any

from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.jobs.base_job_runner import BaseJobRunner
from airflow.jobs.job import Job, perform_heartbeat
from airflow.utils.log.logging_mixin import LoggingMixin


def empty_callback(_: Any) -> None:
pass


class DagProcessorJobRunner(BaseJobRunner, LoggingMixin):
"""
DagProcessorJobRunner is a job runner that runs a DagFileProcessorManager processor.

:param job: Job instance to use
:param processor: DagFileProcessorManager instance to use
"""

job_type = "DagProcessorJob"

def __init__(
self,
job: Job,
processor: DagFileProcessorManager,
*args,
**kwargs,
):
self.processor = processor
super().__init__(*args, **kwargs)
self.job = job
if job.job_type and job.job_type != self.job_type:
raise Exception(
f"The job is already assigned a different job_type: {job.job_type}."
f"This is a bug and should be reported."
)
self.processor = processor
self.processor.heartbeat = lambda: perform_heartbeat(
job=self.job,
heartbeat_callback=empty_callback,
only_if_necessary=False,
)

def _execute(self) -> int | None:
self.log.info("Starting the Dag Processor Job")
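With perform_heartbeat removed from manager.py, DagFileProcessorManager only invokes a self.heartbeat callable that defaults to a no-op, and DagProcessorJobRunner (above) overwrites it with a closure around perform_heartbeat for its own Job. A standalone sketch of that callback injection, with simplified stand-ins rather than the Airflow classes:

```python
# Simplified illustration of the heartbeat callback injection between
# DagFileProcessorManager and DagProcessorJobRunner in this PR.
from typing import Callable


class Manager:
    """Stand-in for DagFileProcessorManager: no longer knows about Job at all."""

    def __init__(self) -> None:
        # Default is a no-op, so the manager still works without a runner attached.
        self.heartbeat: Callable[[], None] = lambda: None

    def run_parsing_loop_iteration(self) -> None:
        self.heartbeat()  # replaces the old direct perform_heartbeat(...) call


class Runner:
    """Stand-in for DagProcessorJobRunner: wires its Job's heartbeat into the manager."""

    def __init__(self, processor: Manager) -> None:
        self.processor = processor
        self.processor.heartbeat = lambda: print("heartbeat recorded for the runner's Job")


manager = Manager()
Runner(processor=manager)
manager.run_parsing_loop_iteration()  # prints the heartbeat message
```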