diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index 5e3931c10d848..50e72c614c1e7 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -25,8 +25,6 @@ from flask import Response from airflow.api_connexion.types import APIResponse -from airflow.models import Trigger, Variable, XCom -from airflow.models.dagwarning import DagWarning from airflow.serialization.serialized_objects import BaseSerialization log = logging.getLogger(__name__) @@ -36,7 +34,9 @@ def _initialize_map() -> dict[str, Callable]: from airflow.dag_processing.manager import DagFileProcessorManager from airflow.dag_processing.processor import DagFileProcessor + from airflow.models import Trigger, Variable, XCom from airflow.models.dag import DagModel + from airflow.models.dagwarning import DagWarning functions: list[Callable] = [ DagFileProcessor.update_import_errors, diff --git a/airflow/cli/commands/dag_processor_command.py b/airflow/cli/commands/dag_processor_command.py index 8e87933fcbb22..70bebf285b5e1 100644 --- a/airflow/cli/commands/dag_processor_command.py +++ b/airflow/cli/commands/dag_processor_command.py @@ -26,6 +26,8 @@ from airflow import settings from airflow.configuration import conf +from airflow.dag_processing.manager import DagFileProcessorManager +from airflow.jobs.dag_processor_job_runner import DagProcessorJobRunner from airflow.jobs.job import Job, run_job from airflow.utils import cli as cli_utils from airflow.utils.cli import setup_locations, setup_logging @@ -33,22 +35,20 @@ log = logging.getLogger(__name__) -def _create_dag_processor_job(args: Any) -> Job: +def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner: """Creates DagFileProcessorProcess instance.""" - from airflow.dag_processing.manager import DagFileProcessorManager - processor_timeout_seconds: int = conf.getint("core", "dag_file_processor_timeout") processor_timeout = timedelta(seconds=processor_timeout_seconds) - processor = DagFileProcessorManager( - processor_timeout=processor_timeout, - dag_directory=args.subdir, - max_runs=args.num_runs, - dag_ids=[], - pickle_dags=args.do_pickle, - ) - return Job( - job_runner=processor.job_runner, + return DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + processor_timeout=processor_timeout, + dag_directory=args.subdir, + max_runs=args.num_runs, + dag_ids=[], + pickle_dags=args.do_pickle, + ), ) @@ -62,7 +62,7 @@ def dag_processor(args): if sql_conn.startswith("sqlite"): raise SystemExit("Standalone DagProcessor is not supported when using sqlite.") - job = _create_dag_processor_job(args) + job_runner = _create_dag_processor_job_runner(args) if args.daemon: pid, stdout, stderr, log_file = setup_locations( @@ -81,6 +81,6 @@ def dag_processor(args): umask=int(settings.DAEMON_UMASK, 8), ) with ctx: - run_job(job) + run_job(job=job_runner.job, execute_callable=job_runner._execute) else: - run_job(job) + run_job(job=job_runner.job, execute_callable=job_runner._execute) diff --git a/airflow/cli/commands/scheduler_command.py b/airflow/cli/commands/scheduler_command.py index 818212052b306..22f9a758088ad 100644 --- a/airflow/cli/commands/scheduler_command.py +++ b/airflow/cli/commands/scheduler_command.py @@ -35,11 +35,11 @@ from airflow.utils.scheduler_health import serve_health_check -def _run_scheduler_job(job: Job, *, skip_serve_logs: bool) -> None: +def _run_scheduler_job(job_runner: SchedulerJobRunner, *, 
skip_serve_logs: bool) -> None: InternalApiConfig.force_database_direct_access() enable_health_check = conf.getboolean("scheduler", "ENABLE_HEALTH_CHECK") with _serve_logs(skip_serve_logs), _serve_health_check(enable_health_check): - run_job(job) + run_job(job=job_runner.job, execute_callable=job_runner._execute) @cli_utils.action_cli @@ -47,14 +47,10 @@ def scheduler(args): """Starts Airflow Scheduler.""" print(settings.HEADER) - job = Job( - job_runner=SchedulerJobRunner( - subdir=process_subdir(args.subdir), - num_runs=args.num_runs, - do_pickle=args.do_pickle, - ) + job_runner = SchedulerJobRunner( + job=Job(), subdir=process_subdir(args.subdir), num_runs=args.num_runs, do_pickle=args.do_pickle ) - ExecutorLoader.validate_database_executor_compatibility(job.executor) + ExecutorLoader.validate_database_executor_compatibility(job_runner.job.executor) if args.daemon: pid, stdout, stderr, log_file = setup_locations( @@ -73,12 +69,12 @@ def scheduler(args): umask=int(settings.DAEMON_UMASK, 8), ) with ctx: - _run_scheduler_job(job, skip_serve_logs=args.skip_serve_logs) + _run_scheduler_job(job_runner, skip_serve_logs=args.skip_serve_logs) else: signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGTERM, sigint_handler) signal.signal(signal.SIGQUIT, sigquit_handler) - _run_scheduler_job(job, skip_serve_logs=args.skip_serve_logs) + _run_scheduler_job(job_runner, skip_serve_logs=args.skip_serve_logs) @contextmanager diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 3d6a57be713c0..3873a55a66c38 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -248,7 +248,8 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) -> None: def _run_task_by_local_task_job(args, ti: TaskInstance) -> TaskReturnCode | None: """Run LocalTaskJob, which monitors the raw task execution process.""" - local_task_job_runner = LocalTaskJobRunner( + job_runner = LocalTaskJobRunner( + job=Job(dag_id=ti.dag_id), task_instance=ti, mark_success=args.mark_success, pickle_id=args.pickle, @@ -260,12 +261,8 @@ def _run_task_by_local_task_job(args, ti: TaskInstance) -> TaskReturnCode | None pool=args.pool, external_executor_id=_extract_external_executor_id(args), ) - local_task_job = Job( - job_runner=local_task_job_runner, - dag_id=ti.dag_id, - ) try: - ret = run_job(local_task_job) + ret = run_job(job=job_runner.job, execute_callable=job_runner._execute) finally: if args.shut_down_logging: logging.shutdown() diff --git a/airflow/cli/commands/triggerer_command.py b/airflow/cli/commands/triggerer_command.py index 8ab4f344ea302..7bfd77bf8ee2f 100644 --- a/airflow/cli/commands/triggerer_command.py +++ b/airflow/cli/commands/triggerer_command.py @@ -55,8 +55,7 @@ def triggerer(args): """Starts Airflow Triggerer.""" settings.MASK_SECRETS_IN_LOGS = True print(settings.HEADER) - triggerer_job_runner = TriggererJobRunner(capacity=args.capacity) - job = Job(job_runner=triggerer_job_runner) + triggerer_job_runner = TriggererJobRunner(job=Job(), capacity=args.capacity) if args.daemon: pid, stdout, stderr, log_file = setup_locations( @@ -75,10 +74,10 @@ def triggerer(args): umask=int(settings.DAEMON_UMASK, 8), ) with daemon_context, _serve_logs(args.skip_serve_logs): - run_job(job) + run_job(job=triggerer_job_runner.job, execute_callable=triggerer_job_runner._execute) else: signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGTERM, sigint_handler) signal.signal(signal.SIGQUIT, sigquit_handler) with 
_serve_logs(args.skip_serve_logs): - run_job(job) + run_job(job=triggerer_job_runner.job, execute_callable=triggerer_job_runner._execute) diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index b45d833e715be..2cb16a3475780 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -35,7 +35,7 @@ from importlib import import_module from multiprocessing.connection import Connection as MultiprocessingConnection from pathlib import Path -from typing import TYPE_CHECKING, Any, NamedTuple, cast +from typing import Any, Callable, NamedTuple, cast from setproctitle import setproctitle from sqlalchemy.orm import Session @@ -46,7 +46,6 @@ from airflow.callbacks.callback_requests import CallbackRequest, SlaCallbackRequest from airflow.configuration import conf from airflow.dag_processing.processor import DagFileProcessorProcess -from airflow.jobs.job import perform_heartbeat from airflow.models import errors from airflow.models.dag import DagModel from airflow.models.dagwarning import DagWarning @@ -67,9 +66,6 @@ from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import prohibit_commit, skip_locked, with_row_locks -if TYPE_CHECKING: - from airflow.jobs.dag_processor_job_runner import DagProcessorJobRunner - class DagParsingStat(NamedTuple): """Information on processing progress.""" @@ -385,8 +381,6 @@ def __init__( signal_conn: MultiprocessingConnection | None = None, async_mode: bool = True, ): - from airflow.jobs.dag_processor_job_runner import DagProcessorJobRunner - super().__init__() # known files; this will be updated every `dag_dir_list_interval` and stuff added/removed accordingly self._file_paths: list[str] = [] @@ -399,9 +393,6 @@ def __init__( self._async_mode = async_mode self._parsing_start_time: float | None = None self._dag_directory = dag_directory - self._job_runner = DagProcessorJobRunner( - processor=self, - ) # Set the signal conn in to non-blocking mode, so that attempting to # send when the buffer is full errors, rather than hangs for-ever # attempting to send (this is to avoid deadlocks!) @@ -465,10 +456,7 @@ def __init__( if self._direct_scheduler_conn is not None else {} ) - - @property - def job_runner(self) -> DagProcessorJobRunner: - return self._job_runner + self.heartbeat: Callable[[], None] = lambda: None def register_exit_signals(self): """Register signals that stop child processes.""" @@ -585,11 +573,7 @@ def _run_parsing_loop(self): while True: loop_start_time = time.monotonic() ready = multiprocessing.connection.wait(self.waitables.keys(), timeout=poll_time) - # we cannot (for now) define job in _job_runner nicely due to circular references of - # job and job runner, so we have to use getattr, but we might address it in the future - # change when decoupling these two even more - if getattr(self._job_runner, "job", None) is not None: - perform_heartbeat(self._job_runner.job, only_if_necessary=False) + self.heartbeat() if self._direct_scheduler_conn is not None and self._direct_scheduler_conn in ready: agent_signal = self._direct_scheduler_conn.recv() diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 2ed17b65704ed..f1ba08e496aa0 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -645,7 +645,7 @@ def execute_callbacks( ) -> None: """ Execute on failure callbacks. - These objects can come from SchedulerJob or from DagFileProcessorManager. 
+ These objects can come from SchedulerJobRunner or from DagProcessorJobRunner. :param dagbag: Dag Bag of dags :param callback_requests: failure callbacks to execute diff --git a/airflow/jobs/backfill_job_runner.py b/airflow/jobs/backfill_job_runner.py index f4e0087a93b84..c99cae2d21a25 100644 --- a/airflow/jobs/backfill_job_runner.py +++ b/airflow/jobs/backfill_job_runner.py @@ -69,8 +69,6 @@ class BackfillJobRunner(BaseJobRunner, LoggingMixin): STATES_COUNT_AS_RUNNING = (State.RUNNING, State.QUEUED) - job: Job # backfill_job can only run with Job class not the Pydantic serialized version - @attr.define class _DagRunTaskStatus: """ @@ -110,6 +108,7 @@ class _DagRunTaskStatus: def __init__( self, + job: Job, dag: DAG, start_date=None, end_date=None, @@ -126,8 +125,6 @@ def __init__( run_at_least_once=False, continue_on_failures=False, disable_retry=False, - *args, - **kwargs, ) -> None: """ Create a BackfillJobRunner. @@ -151,6 +148,14 @@ def __init__( :param args: :param kwargs: """ + super().__init__() + if job.job_type and job.job_type != self.job_type: + raise Exception( + f"The job is already assigned a different job_type: {job.job_type}." + f"This is a bug and should be reported." + ) + self.job = job + self.job.job_type = self.job_type self.dag = dag self.dag_id = dag.dag_id self.bf_start_date = start_date @@ -168,7 +173,6 @@ def __init__( self.run_at_least_once = run_at_least_once self.continue_on_failures = continue_on_failures self.disable_retry = disable_retry - super().__init__(*args, **kwargs) def _update_counters(self, ti_status: _DagRunTaskStatus, session: Session) -> None: """ @@ -632,7 +636,9 @@ def _per_task_process(key, ti: TaskInstance, session): except (NoAvailablePoolSlot, DagConcurrencyLimitReached, TaskConcurrencyLimitReached) as e: self.log.debug(e) - perform_heartbeat(job=self.job, only_if_necessary=is_unit_test) + perform_heartbeat( + job=self.job, heartbeat_callback=self.heartbeat_callback, only_if_necessary=is_unit_test + ) # execute the tasks in the queue executor.heartbeat() diff --git a/airflow/jobs/base_job_runner.py b/airflow/jobs/base_job_runner.py index 096a929d1e6ef..da25da64e15e6 100644 --- a/airflow/jobs/base_job_runner.py +++ b/airflow/jobs/base_job_runner.py @@ -25,7 +25,6 @@ from sqlalchemy.orm import Session from airflow.jobs.job import Job - from airflow.serialization.pydantic.job import JobPydantic class BaseJobRunner: @@ -33,8 +32,6 @@ class BaseJobRunner: job_type = "undefined" - job: Job | JobPydantic - def _execute(self) -> int | None: """ Executes the logic connected to the runner. This method should be diff --git a/airflow/jobs/dag_processor_job_runner.py b/airflow/jobs/dag_processor_job_runner.py index a6a4548e40ec1..0fda9d937314c 100644 --- a/airflow/jobs/dag_processor_job_runner.py +++ b/airflow/jobs/dag_processor_job_runner.py @@ -17,15 +17,23 @@ from __future__ import annotations +from typing import Any + from airflow.dag_processing.manager import DagFileProcessorManager from airflow.jobs.base_job_runner import BaseJobRunner +from airflow.jobs.job import Job, perform_heartbeat from airflow.utils.log.logging_mixin import LoggingMixin +def empty_callback(_: Any) -> None: + pass + + class DagProcessorJobRunner(BaseJobRunner, LoggingMixin): """ DagProcessorJobRunner is a job runner that runs a DagFileProcessorManager processor. 
+ :param job: Job instance to use :param processor: DagFileProcessorManager instance to use """ @@ -33,12 +41,24 @@ class DagProcessorJobRunner(BaseJobRunner, LoggingMixin): def __init__( self, + job: Job, processor: DagFileProcessorManager, *args, **kwargs, ): - self.processor = processor super().__init__(*args, **kwargs) + self.job = job + if job.job_type and job.job_type != self.job_type: + raise Exception( + f"The job is already assigned a different job_type: {job.job_type}." + f"This is a bug and should be reported." + ) + self.processor = processor + self.processor.heartbeat = lambda: perform_heartbeat( + job=self.job, + heartbeat_callback=empty_callback, + only_if_necessary=False, + ) def _execute(self) -> int | None: self.log.info("Starting the Dag Processor Job") diff --git a/airflow/jobs/job.py b/airflow/jobs/job.py index c9b0f36c66d9f..9867efa9ecb51 100644 --- a/airflow/jobs/job.py +++ b/airflow/jobs/job.py @@ -18,18 +18,17 @@ from __future__ import annotations from time import sleep -from typing import NoReturn +from typing import Callable, NoReturn from sqlalchemy import Column, Index, Integer, String, case from sqlalchemy.exc import OperationalError -from sqlalchemy.orm import Session, backref, foreign, relationship -from sqlalchemy.orm.session import make_transient +from sqlalchemy.orm import backref, foreign, relationship +from sqlalchemy.orm.session import Session, make_transient from airflow.compat.functools import cached_property from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.executors.executor_loader import ExecutorLoader -from airflow.jobs.base_job_runner import BaseJobRunner from airflow.listeners.listener import get_listener_manager from airflow.models.base import ID_LEN, Base from airflow.serialization.pydantic.job import JobPydantic @@ -100,10 +99,9 @@ class Job(Base, LoggingMixin): heartrate = conf.getfloat("scheduler", "JOB_HEARTBEAT_SEC") - def __init__(self, job_runner: BaseJobRunner, executor=None, heartrate=None, **kwargs): + def __init__(self, executor=None, heartrate=None, **kwargs): # Save init parameters as DB fields self.hostname = get_hostname() - self.job_type = job_runner.job_type if executor: self.executor = executor self.executor_class = executor.__class__.__name__ @@ -116,8 +114,6 @@ def __init__(self, job_runner: BaseJobRunner, executor=None, heartrate=None, **k self.unixname = getuser() self.max_tis_per_query: int = conf.getint("scheduler", "max_tis_per_query") get_listener_manager().hook.on_starting(component=self) - self._job_runner = job_runner - self._job_runner.job = self super().__init__(**kwargs) @cached_property @@ -157,7 +153,9 @@ def on_kill(self): """Will be called when an external kill command is received.""" @provide_session - def heartbeat(self, session: Session = NEW_SESSION) -> None: + def heartbeat( + self, heartbeat_callback: Callable[[Session], None], session: Session = NEW_SESSION + ) -> None: """ Heartbeats update the job's entry in the database with a timestamp for the latest_heartbeat and allows for the job to be killed @@ -176,6 +174,8 @@ def heartbeat(self, session: Session = NEW_SESSION) -> None: heart rate. If you go over 60 seconds before calling it, it won't sleep at all. 
+ :param heartbeat_callback: Callback that will be run when the heartbeat is recorded in the Job + :param session to use for saving the job """ previous_heartbeat = self.latest_heartbeat @@ -206,7 +206,7 @@ def heartbeat(self, session: Session = NEW_SESSION) -> None: # At this point, the DB has updated. previous_heartbeat = self.latest_heartbeat - self.job_runner.heartbeat_callback(session=session) + heartbeat_callback(session) self.log.debug("[heartbeat]") except OperationalError: Stats.incr(convert_camel_to_snake(self.__class__.__name__) + "_heartbeat_failure", 1, 1) @@ -232,11 +232,6 @@ def complete_execution(self, session: Session = NEW_SESSION): session.commit() Stats.incr(self.__class__.__name__.lower() + "_end", 1, 1) - @property - def job_runner(self) -> BaseJobRunner: - """Returns the job runner instance.""" - return self._job_runner - @provide_session def most_recent_job(self, session: Session = NEW_SESSION) -> Job | None: """Returns the most recent job of this type, if any, based on last heartbeat received.""" @@ -267,7 +262,9 @@ def most_recent_job(job_type: str, session: Session = NEW_SESSION) -> Job | None @provide_session -def run_job(job: Job | JobPydantic, session: Session = NEW_SESSION) -> int | None: +def run_job( + job: Job | JobPydantic, execute_callable: Callable[[], int | None], session: Session = NEW_SESSION +) -> int | None: """ Runs the job. The Job is always an ORM object and setting the state is happening within the same DB session and the session is kept open throughout the whole execution @@ -278,15 +275,15 @@ def run_job(job: Job | JobPydantic, session: Session = NEW_SESSION) -> int | Non """ # The below assert is a temporary one, to make MyPy happy with partial AIP-44 work - we will remove it # once final AIP-44 changes are completed. - assert isinstance(job, Job), "Job should be ORM object not Pydantic one here (AIP-44 WIP)" + assert not isinstance(job, JobPydantic), "Job should be ORM object not Pydantic one here (AIP-44 WIP)" job.prepare_for_execution(session=session) try: - return execute_job(job) + return execute_job(job, execute_callable=execute_callable) finally: job.complete_execution(session=session) -def execute_job(job: Job | JobPydantic) -> int | None: +def execute_job(job: Job | JobPydantic, execute_callable: Callable[[], int | None]) -> int | None: """ Executes the job. @@ -304,12 +301,13 @@ def execute_job(job: Job | JobPydantic) -> int | None: not really matter, because except of running the heartbeat and state setting, the runner should not modify the job state. + :param execute_callable: callable to execute when running the job. + :meta private: """ ret = None try: - # This job_runner reference and type-ignore will be removed by further refactoring step - ret = job.job_runner._execute() # type:ignore[union-attr] + ret = execute_callable() # In case of max runs or max duration job.state = State.SUCCESS except SystemExit: @@ -321,21 +319,24 @@ def execute_job(job: Job | JobPydantic) -> int | None: return ret -def perform_heartbeat(job: Job | JobPydantic, only_if_necessary: bool) -> None: +def perform_heartbeat( + job: Job | JobPydantic, heartbeat_callback: Callable[[Session], None], only_if_necessary: bool +) -> None: """ Performs heartbeat for the Job passed to it,optionally checking if it is necessary. :param job: job to perform heartbeat for + :param heartbeat_callback: callback to run by the heartbeat :param only_if_necessary: only heartbeat if it is necessary (i.e. 
if there are things to run for triggerer for example) """ # The below assert is a temporary one, to make MyPy happy with partial AIP-44 work - we will remove it # once final AIP-44 changes are completed. - assert isinstance(job, Job), "Job should be ORM object not Pydantic one here (AIP-44 WIP)" + assert not isinstance(job, JobPydantic), "Job should be ORM object not Pydantic one here (AIP-44 WIP)" seconds_remaining: float = 0.0 if job.latest_heartbeat and job.heartrate: seconds_remaining = job.heartrate - (timezone.utcnow() - job.latest_heartbeat).total_seconds() if seconds_remaining > 0 and only_if_necessary: return with create_session() as session: - job.heartbeat(session=session) + job.heartbeat(heartbeat_callback=heartbeat_callback, session=session) diff --git a/airflow/jobs/local_task_job_runner.py b/airflow/jobs/local_task_job_runner.py index d5becaa243903..579586d04d004 100644 --- a/airflow/jobs/local_task_job_runner.py +++ b/airflow/jobs/local_task_job_runner.py @@ -25,8 +25,9 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.jobs.base_job_runner import BaseJobRunner -from airflow.jobs.job import perform_heartbeat +from airflow.jobs.job import Job, perform_heartbeat from airflow.models.taskinstance import TaskInstance, TaskReturnCode +from airflow.serialization.pydantic.job import JobPydantic from airflow.stats import Stats from airflow.utils import timezone from airflow.utils.log.file_task_handler import _set_task_deferred_context_var @@ -74,7 +75,8 @@ class LocalTaskJobRunner(BaseJobRunner, LoggingMixin): def __init__( self, - task_instance: TaskInstance, + job: Job | JobPydantic, + task_instance: TaskInstance, # TODO add TaskInstancePydantic ignore_all_deps: bool = False, ignore_depends_on_past: bool = False, wait_for_past_depends_before_skipping: bool = False, @@ -84,9 +86,16 @@ def __init__( pickle_id: int | None = None, pool: str | None = None, external_executor_id: str | None = None, - *args, - **kwargs, ): + BaseJobRunner.__init__(self) + LoggingMixin.__init__(self, context=task_instance) + if job.job_type and job.job_type != self.job_type: + raise Exception( + f"The job is already assigned a different job_type: {job.job_type}." + f"This is a bug and should be reported." + ) + self.job = job + self.job.job_type = self.job_type self.task_instance = task_instance self.ignore_all_deps = ignore_all_deps self.ignore_depends_on_past = ignore_depends_on_past @@ -97,15 +106,12 @@ def __init__( self.pickle_id = pickle_id self.mark_success = mark_success self.external_executor_id = external_executor_id - # terminating state is used so that a job don't try to # terminate multiple times self.terminating = False self._state_change_checks = 0 - super().__init__(*args, **kwargs) - def _execute(self) -> int | None: from airflow.task.task_runner import get_task_runner @@ -190,7 +196,9 @@ def sigusr2_debug_handler(signum, frame): self.handle_task_exit(return_code) return return_code - perform_heartbeat(self.job, only_if_necessary=False) + perform_heartbeat( + job=self.job, heartbeat_callback=self.heartbeat_callback, only_if_necessary=False + ) # If it's been too long since we've heartbeat, then it's possible that # the scheduler rescheduled this task, so kill launched processes. 
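Annotation (not part of the patch): the wiring change above repeats for every runner in this diff. A runner is no longer attached via Job(job_runner=...); instead it is constructed with an explicit job=Job(), validates/sets the job's job_type itself, and callers hand the runner's _execute to run_job(). Heartbeats follow the same inversion: the callback is passed explicitly instead of being looked up through job.job_runner. A minimal sketch of the new call pattern, based only on the signatures visible in this diff (it assumes an initialized Airflow environment and is illustrative, not a drop-in script):

    from airflow.jobs.job import Job, perform_heartbeat, run_job
    from airflow.jobs.scheduler_job_runner import SchedulerJobRunner

    # Construct the runner around a bare Job(); the runner records it as self.job
    # and stamps job.job_type (raising if a different job_type was already set).
    job_runner = SchedulerJobRunner(job=Job(), num_runs=1)

    # Old: run_job(job), which internally called job.job_runner._execute().
    # New: the execute callable is passed in, so Job no longer knows about runners.
    run_job(job=job_runner.job, execute_callable=job_runner._execute)

    # Heartbeating takes the callback as an argument now, rather than the Job
    # reaching back into its runner.
    perform_heartbeat(
        job=job_runner.job,
        heartbeat_callback=job_runner.heartbeat_callback,
        only_if_necessary=True,
    )

The same constructor guard appears in the Backfill, LocalTask, Scheduler, Triggerer and DagProcessor runners, and DagProcessorJobRunner wires the manager's heartbeat to perform_heartbeat through a lambda, since DagFileProcessorManager no longer imports the job machinery directly.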
diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index cff6a234d62c1..2373c974f7271 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -121,10 +121,9 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin): job_type = "SchedulerJob" heartrate: int = conf.getint("scheduler", "SCHEDULER_HEARTBEAT_SEC") - job: Job # scheduler can only run with Job class not the Pydantic serialized version - def __init__( self, + job: Job, subdir: str = settings.DAGS_FOLDER, num_runs: int = conf.getint("scheduler", "num_runs"), num_times_parse_dags: int = -1, @@ -132,11 +131,16 @@ def __init__( do_pickle: bool = False, log: logging.Logger | None = None, processor_poll_interval: float | None = None, - *args, - **kwargs, ): + super().__init__() + if job.job_type and job.job_type != self.job_type: + raise Exception( + f"The job is already assigned a different job_type: {job.job_type}." + f"This is a bug and should be reported." + ) + self.job = job + self.job.job_type = self.job_type self.subdir = subdir - self.num_runs = num_runs # In specific tests, we want to stop the parse loop after the _files_ have been parsed a certain # number of times. This is only to support testing, and isn't something a user is likely to want to @@ -157,7 +161,6 @@ def __init__( self._standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor") self._dag_stale_not_seen_duration = conf.getint("scheduler", "dag_stale_not_seen_duration") self.do_pickle = do_pickle - super().__init__(*args, **kwargs) if log: self._log = log @@ -915,7 +918,9 @@ def _run_scheduler_loop(self) -> None: self.processor_agent.heartbeat() # Heartbeat the scheduler periodically - perform_heartbeat(job=self.job, only_if_necessary=True) + perform_heartbeat( + job=self.job, heartbeat_callback=self.heartbeat_callback, only_if_necessary=True + ) # Run any pending timed events next_event = timers.run(blocking=False) diff --git a/airflow/jobs/triggerer_job_runner.py b/airflow/jobs/triggerer_job_runner.py index da5c1b969ebf1..fc2800dcd198c 100644 --- a/airflow/jobs/triggerer_job_runner.py +++ b/airflow/jobs/triggerer_job_runner.py @@ -34,8 +34,9 @@ from airflow.configuration import conf from airflow.jobs.base_job_runner import BaseJobRunner -from airflow.jobs.job import perform_heartbeat +from airflow.jobs.job import Job, perform_heartbeat from airflow.models.trigger import Trigger +from airflow.serialization.pydantic.job import JobPydantic from airflow.stats import Stats from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.typing_compat import TypedDict @@ -246,10 +247,19 @@ class TriggererJobRunner(BaseJobRunner, LoggingMixin): job_type = "TriggererJob" - def __init__(self, capacity=None, *args, **kwargs): - # Call superclass - super().__init__(*args, **kwargs) - + def __init__( + self, + job: Job | JobPydantic, + capacity=None, + ): + super().__init__() + if job.job_type and job.job_type != self.job_type: + raise Exception( + f"The job is already assigned a different job_type: {job.job_type}." + f"This is a bug and should be reported." 
+ ) + self.job = job + self.job.job_type = self.job_type if capacity is None: self.capacity = conf.getint("triggerer", "default_capacity", fallback=1000) elif isinstance(capacity, int) and capacity > 0: @@ -355,7 +365,7 @@ def _run_trigger_loop(self) -> None: self.handle_events() # Handle failed triggers self.handle_failed_triggers() - perform_heartbeat(self.job, only_if_necessary=True) + perform_heartbeat(self.job, heartbeat_callback=self.heartbeat_callback, only_if_necessary=True) # Collect stats self.emit_metrics() # Idle sleep diff --git a/airflow/models/dag.py b/airflow/models/dag.py index a7c8d0a8963fe..1d3c2538b8122 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2468,28 +2468,27 @@ def run( executor = ExecutorLoader.get_default_executor() from airflow.jobs.job import Job - job = Job( - job_runner=BackfillJobRunner( - self, - start_date=start_date, - end_date=end_date, - mark_success=mark_success, - donot_pickle=donot_pickle, - ignore_task_deps=ignore_task_deps, - ignore_first_depends_on_past=ignore_first_depends_on_past, - pool=pool, - delay_on_limit_secs=delay_on_limit_secs, - verbose=verbose, - conf=conf, - rerun_failed_tasks=rerun_failed_tasks, - run_backwards=run_backwards, - run_at_least_once=run_at_least_once, - continue_on_failures=continue_on_failures, - disable_retry=disable_retry, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=self, + start_date=start_date, + end_date=end_date, + mark_success=mark_success, + donot_pickle=donot_pickle, + ignore_task_deps=ignore_task_deps, + ignore_first_depends_on_past=ignore_first_depends_on_past, + pool=pool, + delay_on_limit_secs=delay_on_limit_secs, + verbose=verbose, + conf=conf, + rerun_failed_tasks=rerun_failed_tasks, + run_backwards=run_backwards, + run_at_least_once=run_at_least_once, + continue_on_failures=continue_on_failures, + disable_retry=disable_retry, ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) def cli(self): """Exposes a CLI specific to this DAG""" diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py index 1f8dce26dd265..c83d9f53ef87c 100644 --- a/airflow/serialization/enums.py +++ b/airflow/serialization/enums.py @@ -51,7 +51,7 @@ class DagAttributeTypes(str, Enum): XCOM_REF = "xcomref" DATASET = "dataset" SIMPLE_TASK_INSTANCE = "simple_task_instance" - BASE_JOB = "base_job" + BASE_JOB = "Job" TASK_INSTANCE = "task_instance" DAG_RUN = "dag_run" DATA_SET = "data_set" diff --git a/airflow/task/task_runner/__init__.py b/airflow/task/task_runner/__init__.py index fa57d52b5cc89..078b24a402201 100644 --- a/airflow/task/task_runner/__init__.py +++ b/airflow/task/task_runner/__init__.py @@ -59,5 +59,5 @@ def get_task_runner(local_task_job_runner: LocalTaskJobRunner) -> BaseTaskRunner f'The task runner could not be loaded. Please check "task_runner" key in "core" section. ' f'Current value: "{_TASK_RUNNER_NAME}".' 
) - task_runner = task_runner_class(local_task_job_runner.job) + task_runner = task_runner_class(local_task_job_runner) return task_runner diff --git a/airflow/task/task_runner/base_task_runner.py b/airflow/task/task_runner/base_task_runner.py index 570ce1b1b8181..64523b17c5ae2 100644 --- a/airflow/task/task_runner/base_task_runner.py +++ b/airflow/task/task_runner/base_task_runner.py @@ -21,9 +21,7 @@ import os import subprocess import threading -from typing import cast -from airflow.jobs.job import Job from airflow.jobs.local_task_job_runner import LocalTaskJobRunner from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager from airflow.utils.platform import IS_WINDOWS @@ -32,7 +30,6 @@ # ignored to avoid flake complaining on Linux from pwd import getpwnam # noqa - from airflow.configuration import conf from airflow.exceptions import AirflowConfigException from airflow.utils.configuration import tmp_configuration_copy @@ -49,19 +46,13 @@ class BaseTaskRunner(LoggingMixin): Invoke the `airflow tasks run` command with raw mode enabled in a subprocess. - :param base_job: The job associated with running the associated task instance. The job_runner for it - should be LocalTaskJobRunner + :param job_runner: The LocalTaskJobRunner associated with the task runner """ - def __init__(self, base_job: Job): - self.job_runner: LocalTaskJobRunner = cast(LocalTaskJobRunner, base_job.job_runner) - if not hasattr(self.job_runner, "task_instance"): - raise ValueError( - "BaseTaskRunner can only be used with LocalTaskJobRunner and " - "have task_instance field defined" - ) - super().__init__(self.job_runner.task_instance) - self._task_instance = self.job_runner.task_instance + def __init__(self, job_runner: LocalTaskJobRunner): + self.job_runner = job_runner + super().__init__(job_runner.task_instance) + self._task_instance = job_runner.task_instance popen_prepend = [] if self._task_instance.run_as_user: @@ -104,7 +95,7 @@ def __init__(self, base_job: Job): raw=True, pickle_id=self.job_runner.pickle_id, mark_success=self.job_runner.mark_success, - job_id=base_job.id, + job_id=self.job_runner.job.id, pool=self.job_runner.pool, cfg_path=cfg_path, ) diff --git a/airflow/task/task_runner/cgroup_task_runner.py b/airflow/task/task_runner/cgroup_task_runner.py index 8b6b03e5a559d..2ab011471377a 100644 --- a/airflow/task/task_runner/cgroup_task_runner.py +++ b/airflow/task/task_runner/cgroup_task_runner.py @@ -25,7 +25,7 @@ import psutil from cgroupspy import trees -from airflow.jobs.job import Job +from airflow.jobs.local_task_job_runner import LocalTaskJobRunner from airflow.task.task_runner.base_task_runner import BaseTaskRunner from airflow.utils.operator_resources import Resources from airflow.utils.platform import getuser @@ -62,8 +62,8 @@ class CgroupTaskRunner(BaseTaskRunner): airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/cpu/airflow/* * """ - def __init__(self, base_job: Job): - super().__init__(base_job=base_job) + def __init__(self, job_runner: LocalTaskJobRunner): + super().__init__(job_runner=job_runner) self.process = None self._finished_running = False self._cpu_shares = None diff --git a/airflow/task/task_runner/standard_task_runner.py b/airflow/task/task_runner/standard_task_runner.py index b4233d20ea7dd..9b506580e1a93 100644 --- a/airflow/task/task_runner/standard_task_runner.py +++ b/airflow/task/task_runner/standard_task_runner.py @@ -24,7 +24,7 @@ import psutil from setproctitle import setproctitle -from airflow.jobs.job import Job +from 
airflow.jobs.local_task_job_runner import LocalTaskJobRunner from airflow.models.taskinstance import TaskReturnCode from airflow.settings import CAN_FORK from airflow.task.task_runner.base_task_runner import BaseTaskRunner @@ -35,10 +35,10 @@ class StandardTaskRunner(BaseTaskRunner): """Standard runner for all tasks.""" - def __init__(self, base_job: Job): - super().__init__(base_job=base_job) + def __init__(self, job_runner: LocalTaskJobRunner): + super().__init__(job_runner=job_runner) self._rc = None - self.dag = self.job_runner.task_instance.task.dag + self.dag = self._task_instance.task.dag def start(self): if CAN_FORK and not self.run_as_user: diff --git a/dev/perf/scheduler_dag_execution_timing.py b/dev/perf/scheduler_dag_execution_timing.py index 5162cf1285303..613a929e9e405 100755 --- a/dev/perf/scheduler_dag_execution_timing.py +++ b/dev/perf/scheduler_dag_execution_timing.py @@ -92,7 +92,7 @@ def change_state(self, key, state, info=None): if not self.dags_to_watch: self.log.warning("STOPPING SCHEDULER -- all runs complete") - self.scheduler_job.job_runner.processor_agent._done = True + self.job_runner.processor_agent._done = True return self.log.warning( "WAITING ON %d RUNS", sum(map(attrgetter("waiting_for"), self.dags_to_watch.values())) @@ -119,7 +119,7 @@ class ShortCircuitExecutor(ShortCircuitExecutorMixin, executor_cls): Placeholder class that implements the inheritance hierarchy """ - scheduler_job = None + job_runner = None return ShortCircuitExecutor @@ -279,8 +279,9 @@ def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids): ShortCircuitExecutor = get_executor_under_test(executor_class) executor = ShortCircuitExecutor(dag_ids_to_watch=dag_ids, num_runs=num_runs) - scheduler_job = Job(job_runner=SchedulerJobRunner(dag_ids=dag_ids, do_pickle=False), executor=executor) - executor.scheduler_job = scheduler_job + scheduler_job = Job(executor=executor) + job_runner = SchedulerJobRunner(job=scheduler_job, dag_ids=dag_ids, do_pickle=False) + executor.job_runner = job_runner total_tasks = sum(len(dag.tasks) for dag in dags) @@ -293,7 +294,7 @@ def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids): # Need a lambda to refer to the _latest_ value for scheduler_job, not just # the initial one - code_to_test = lambda: run_job(scheduler_job) + code_to_test = lambda: run_job(job=job_runner.job, execute_callable=job_runner._execute) for count in range(repeat): gc.disable() @@ -310,9 +311,8 @@ def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids): reset_dag(dag, session) executor.reset(dag_ids) - scheduler_job = Job( - job_runner=SchedulerJobRunner(dag_ids=dag_ids, do_pickle=False), executor=executor - ) + scheduler_job = Job(executor=executor) + job_runner = SchedulerJobRunner(job=scheduler_job, dag_ids=dag_ids, do_pickle=False) executor.scheduler_job = scheduler_job print() diff --git a/dev/perf/sql_queries.py b/dev/perf/sql_queries.py index 74394190c1f2d..2bddc349ef15d 100644 --- a/dev/perf/sql_queries.py +++ b/dev/perf/sql_queries.py @@ -124,7 +124,8 @@ def run_scheduler_job(with_db_reset=False) -> None: if with_db_reset: reset_db() - run_job(Job(job_runner=SchedulerJobRunner(subdir=DAG_FOLDER, do_pickle=False, num_runs=3))) + job_runner = SchedulerJobRunner(job=Job(), subdir=DAG_FOLDER, do_pickle=False, num_runs=3) + run_job(job=job_runner.job, execute_callable=job_runner._execute) def is_query(line: str) -> bool: diff --git a/docs/apache-airflow/authoring-and-scheduling/dagfile-processing.rst 
b/docs/apache-airflow/authoring-and-scheduling/dagfile-processing.rst index 676fd7af78ad2..e492dd3940807 100644 --- a/docs/apache-airflow/authoring-and-scheduling/dagfile-processing.rst +++ b/docs/apache-airflow/authoring-and-scheduling/dagfile-processing.rst @@ -1,3 +1,4 @@ + .. Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information diff --git a/tests/api_connexion/endpoints/test_health_endpoint.py b/tests/api_connexion/endpoints/test_health_endpoint.py index 577ba33cec88e..efa552013857e 100644 --- a/tests/api_connexion/endpoints/test_health_endpoint.py +++ b/tests/api_connexion/endpoints/test_health_endpoint.py @@ -48,14 +48,9 @@ class TestGetHealth(TestHealthTestBase): @provide_session def test_healthy_scheduler_status(self, session): last_scheduler_heartbeat_for_testing_1 = timezone.utcnow() - session.add( - Job( - job_type="SchedulerJob", - state=State.RUNNING, - latest_heartbeat=last_scheduler_heartbeat_for_testing_1, - job_runner=SchedulerJobRunner(), - ) - ) + job = Job(state=State.RUNNING, latest_heartbeat=last_scheduler_heartbeat_for_testing_1) + SchedulerJobRunner(job=job) + session.add(job) session.commit() resp_json = self.client.get("/api/v1/health").json assert "healthy" == resp_json["metadatabase"]["status"] @@ -68,14 +63,9 @@ def test_healthy_scheduler_status(self, session): @provide_session def test_unhealthy_scheduler_is_slow(self, session): last_scheduler_heartbeat_for_testing_2 = timezone.utcnow() - timedelta(minutes=1) - session.add( - Job( - job_type="SchedulerJob", - state=State.RUNNING, - latest_heartbeat=last_scheduler_heartbeat_for_testing_2, - job_runner=SchedulerJobRunner(), - ) - ) + job = Job(state=State.RUNNING, latest_heartbeat=last_scheduler_heartbeat_for_testing_2) + SchedulerJobRunner(job=job) + session.add(job) session.commit() resp_json = self.client.get("/api/v1/health").json assert "healthy" == resp_json["metadatabase"]["status"] diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index ed3327bca4de2..ac2b4b694d7f3 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -232,7 +232,8 @@ def test_should_respond_200_with_task_state_in_deferred(self, session): )[0] ti.trigger = Trigger("none", {}) ti.trigger.created_date = now - ti.triggerer_job = Job(job_runner=TriggererJobRunner()) + ti.triggerer_job = Job() + TriggererJobRunner(job=ti.triggerer_job) ti.triggerer_job.state = "running" session.commit() response = self.client.get( diff --git a/tests/cli/commands/test_dag_processor_command.py b/tests/cli/commands/test_dag_processor_command.py index 0fbcfe37ea580..0fb9dc6374052 100644 --- a/tests/cli/commands/test_dag_processor_command.py +++ b/tests/cli/commands/test_dag_processor_command.py @@ -42,7 +42,7 @@ def setup_class(cls): ("core", "load_examples"): "False", } ) - @mock.patch("airflow.jobs.dag_processor_job_runner.DagProcessorJobRunner") + @mock.patch("airflow.cli.commands.dag_processor_command.DagProcessorJobRunner") @pytest.mark.skipif( conf.get_mandatory_value("database", "sql_alchemy_conn").lower().startswith("sqlite"), reason="Standalone Dag Processor doesn't support sqlite.", @@ -51,7 +51,7 @@ def test_start_job( self, mock_dag_job, ): - """Ensure that DagFileProcessorManager is started""" + """Ensure that DagProcessorJobRunner is started""" with 
conf_vars({("scheduler", "standalone_dag_processor"): "True"}): mock_dag_job.return_value.job_type = "DagProcessorJob" args = self.parser.parse_args(["dag-processor"]) diff --git a/tests/cli/commands/test_jobs_command.py b/tests/cli/commands/test_jobs_command.py index cf3a276414440..0ec89149040cd 100644 --- a/tests/cli/commands/test_jobs_command.py +++ b/tests/cli/commands/test_jobs_command.py @@ -38,19 +38,21 @@ def setup_class(cls): def setup_method(self) -> None: clear_db_jobs() self.scheduler_job = None + self.job_runner = None def teardown_method(self) -> None: - if self.scheduler_job and self.scheduler_job.job_runner.processor_agent: - self.scheduler_job.job_runner.processor_agent.end() + if self.job_runner and self.job_runner.processor_agent: + self.job_runner.processor_agent.end() clear_db_jobs() def test_should_report_success_for_one_working_scheduler(self): with create_session() as session: - self.scheduler_job = Job(job_runner=SchedulerJobRunner()) + self.scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=self.scheduler_job) self.scheduler_job.state = State.RUNNING session.add(self.scheduler_job) session.commit() - self.scheduler_job.heartbeat() + self.scheduler_job.heartbeat(heartbeat_callback=self.job_runner.heartbeat_callback) with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: jobs_command.check(self.parser.parse_args(["jobs", "check", "--job-type", "SchedulerJob"])) @@ -58,12 +60,13 @@ def test_should_report_success_for_one_working_scheduler(self): def test_should_report_success_for_one_working_scheduler_with_hostname(self): with create_session() as session: - self.scheduler_job = Job(job_runner=SchedulerJobRunner()) + self.scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=self.scheduler_job) self.scheduler_job.state = State.RUNNING self.scheduler_job.hostname = "HOSTNAME" session.add(self.scheduler_job) session.commit() - self.scheduler_job.heartbeat() + self.scheduler_job.heartbeat(heartbeat_callback=self.job_runner.heartbeat_callback) with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: jobs_command.check( @@ -75,52 +78,64 @@ def test_should_report_success_for_one_working_scheduler_with_hostname(self): def test_should_report_success_for_ha_schedulers(self): scheduler_jobs = [] + job_runners = [] with create_session() as session: for _ in range(3): - scheduler_job = Job(job_runner=SchedulerJobRunner()) + scheduler_job = Job() + job_runner = SchedulerJobRunner(job=scheduler_job) scheduler_job.state = State.RUNNING session.add(scheduler_job) scheduler_jobs.append(scheduler_job) + job_runners.append(job_runner) session.commit() - scheduler_job.heartbeat() - - with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: - jobs_command.check( - self.parser.parse_args( - ["jobs", "check", "--job-type", "SchedulerJob", "--limit", "100", "--allow-multiple"] + scheduler_job.heartbeat(heartbeat_callback=job_runner.heartbeat_callback) + try: + with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: + jobs_command.check( + self.parser.parse_args( + ["jobs", "check", "--job-type", "SchedulerJob", "--limit", "100", "--allow-multiple"] + ) ) - ) - assert "Found 3 alive jobs." in temp_stdout.getvalue() - for scheduler_job in scheduler_jobs: - if scheduler_job.job_runner.processor_agent: - scheduler_job.job_runner.processor_agent.end() + assert "Found 3 alive jobs." 
in temp_stdout.getvalue() + finally: + for job_runner in job_runners: + if job_runner.processor_agent: + job_runner.processor_agent.end() def test_should_ignore_not_running_jobs(self): scheduler_jobs = [] + job_runners = [] with create_session() as session: for _ in range(3): - scheduler_job = Job(job_runner=SchedulerJobRunner()) + scheduler_job = Job() + job_runner = SchedulerJobRunner(job=scheduler_job) scheduler_job.state = State.SHUTDOWN session.add(scheduler_job) scheduler_jobs.append(scheduler_job) + job_runners.append(job_runner) session.commit() # No alive jobs found. with pytest.raises(SystemExit, match=r"No alive jobs found."): jobs_command.check(self.parser.parse_args(["jobs", "check"])) - for scheduler_job in scheduler_jobs: - if scheduler_job.job_runner.processor_agent: - scheduler_job.job_runner.processor_agent.end() + for job_runner in job_runners: + if job_runner.processor_agent: + job_runner.processor_agent.end() def test_should_raise_exception_for_multiple_scheduler_on_one_host(self): scheduler_jobs = [] + job_runners = [] with create_session() as session: for _ in range(3): - scheduler_job = Job(job_runner=SchedulerJobRunner()) + scheduler_job = Job() + job_runner = SchedulerJobRunner(job=scheduler_job) + job_runner.job = scheduler_job scheduler_job.state = State.RUNNING scheduler_job.hostname = "HOSTNAME" session.add(scheduler_job) + scheduler_jobs.append(scheduler_job) + job_runners.append(job_runner) session.commit() - scheduler_job.heartbeat() + scheduler_job.heartbeat(heartbeat_callback=job_runner.heartbeat_callback) with pytest.raises(SystemExit, match=r"Found 3 alive jobs. Expected only one."): jobs_command.check( @@ -135,9 +150,9 @@ def test_should_raise_exception_for_multiple_scheduler_on_one_host(self): ] ) ) - for scheduler_job in scheduler_jobs: - if scheduler_job.job_runner.processor_agent: - scheduler_job.job_runner.processor_agent.end() + for job_runner in job_runners: + if job_runner.processor_agent: + job_runner.processor_agent.end() def test_should_raise_exception_for_allow_multiple_and_limit_1(self): with pytest.raises( diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py index 683e8594224b1..376faeda41df4 100644 --- a/tests/cli/commands/test_task_command.py +++ b/tests/cli/commands/test_task_command.py @@ -246,7 +246,7 @@ def test_cli_test_different_path(self, session): assert ti.xcom_pull(ti.task_id) == new_file_path.as_posix() @mock.patch("airflow.cli.commands.task_command.LocalTaskJobRunner") - def test_run_with_existing_dag_run_id(self, mock_local_job): + def test_run_with_existing_dag_run_id(self, mock_local_job_runner): """ Test that we can run with existing dag_run_id """ @@ -260,9 +260,10 @@ def test_run_with_existing_dag_run_id(self, mock_local_job): task0_id, self.run_id, ] - mock_local_job.return_value.job_type = "LocalTaskJob" + mock_local_job_runner.return_value.job_type = "LocalTaskJob" task_command.task_run(self.parser.parse_args(args0), dag=self.dag) - mock_local_job.assert_called_once_with( + mock_local_job_runner.assert_called_once_with( + job=mock.ANY, task_instance=mock.ANY, mark_success=False, ignore_all_deps=True, @@ -646,6 +647,7 @@ def test_external_executor_id_present_for_fork_run_task(self, mock_local_job): task_command.task_run(args) mock_local_job.assert_called_once_with( + job=mock.ANY, task_instance=mock.ANY, mark_success=False, pickle_id=None, @@ -667,6 +669,7 @@ def test_external_executor_id_present_for_process_run_task(self, mock_local_job) with mock.patch.dict(os.environ, 
{"external_executor_id": "12345FEDCBA"}): task_command.task_run(args) mock_local_job.assert_called_once_with( + job=mock.ANY, task_instance=mock.ANY, mark_success=False, pickle_id=None, diff --git a/tests/cli/commands/test_triggerer_command.py b/tests/cli/commands/test_triggerer_command.py index 38398a0e520ef..387437db3b3bb 100644 --- a/tests/cli/commands/test_triggerer_command.py +++ b/tests/cli/commands/test_triggerer_command.py @@ -45,4 +45,4 @@ def test_capacity_argument( triggerer_command.triggerer(args) mock_serve.return_value.__enter__.assert_called_once() mock_serve.return_value.__exit__.assert_called_once() - mock_triggerer_job_runner.assert_called_once_with(capacity=42) + mock_triggerer_job_runner.assert_called_once_with(job=mock.ANY, capacity=42) diff --git a/tests/core/test_impersonation_tests.py b/tests/core/test_impersonation_tests.py index daca39e4a8258..3827f898961b0 100644 --- a/tests/core/test_impersonation_tests.py +++ b/tests/core/test_impersonation_tests.py @@ -113,7 +113,9 @@ def run_backfill(self, dag_id, task_id): dag = self.dagbag.get_dag(dag_id) dag.clear() - run_job(Job(job_runner=BackfillJobRunner(dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE))) + job = Job() + job_runner = BackfillJobRunner(job=job, dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + run_job(job=job, execute_callable=job_runner._execute) run_id = DagRun.generate_run_id(DagRunType.BACKFILL_JOB, execution_date=DEFAULT_DATE) ti = TaskInstance(task=dag.get_task(task_id), run_id=run_id) ti.refresh_from_db() diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_job_runner.py similarity index 70% rename from tests/dag_processing/test_manager.py rename to tests/dag_processing/test_job_runner.py index a7d9328bbacfa..87d321f932214 100644 --- a/tests/dag_processing/test_manager.py +++ b/tests/dag_processing/test_job_runner.py @@ -49,6 +49,8 @@ DagParsingStat, ) from airflow.dag_processing.processor import DagFileProcessorProcess +from airflow.jobs.dag_processor_job_runner import DagProcessorJobRunner +from airflow.jobs.job import Job from airflow.models import DagBag, DagModel, DbCallbackRequest, errors from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel @@ -111,7 +113,7 @@ def waitable_handle(self): return self._waitable_handle -class TestDagFileProcessorManager: +class TestDagProcessorJobRunner: def setup_method(self): dictConfig(DEFAULT_LOGGING_CONFIG) clear_db_runs() @@ -126,13 +128,13 @@ def teardown_class(self): clear_db_callbacks() def run_processor_manager_one_loop(self, manager, parent_pipe): - if not manager._async_mode: + if not manager.processor._async_mode: parent_pipe.send(DagParsingSignal.AGENT_RUN_ONCE) results = [] while True: - manager._run_parsing_loop() + manager.processor._run_parsing_loop() while parent_pipe.poll(timeout=0.01): obj = parent_pipe.recv() @@ -153,14 +155,17 @@ def test_remove_file_clears_import_error(self, tmpdir): child_pipe, parent_pipe = multiprocessing.Pipe() async_mode = "sqlite" not in conf.get("database", "sql_alchemy_conn") - manager = DagFileProcessorManager( - dag_directory=tmpdir, - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=child_pipe, - dag_ids=[], - pickle_dags=False, - async_mode=async_mode, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory=tmpdir, + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=child_pipe, + dag_ids=[], + pickle_dags=False, + async_mode=async_mode, + 
), ) with create_session() as session: @@ -188,14 +193,17 @@ def test_max_runs_when_no_files(self): with TemporaryDirectory(prefix="empty-airflow-dags-") as dags_folder: async_mode = "sqlite" not in conf.get("database", "sql_alchemy_conn") - manager = DagFileProcessorManager( - dag_directory=dags_folder, - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=child_pipe, - dag_ids=[], - pickle_dags=False, - async_mode=async_mode, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory=dags_folder, + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=child_pipe, + dag_ids=[], + pickle_dags=False, + async_mode=async_mode, + ), ) self.run_processor_manager_one_loop(manager, parent_pipe) @@ -208,25 +216,28 @@ def test_start_new_processes_with_same_filepath(self): Test that when a processor already exist with a filepath, a new processor won't be created with that filepath. The filepath will just be removed from the list. """ - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory="directory", + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ), ) file_1 = "file_1.py" file_2 = "file_2.py" file_3 = "file_3.py" - manager._file_path_queue = collections.deque([file_1, file_2, file_3]) + manager.processor._file_path_queue = collections.deque([file_1, file_2, file_3]) # Mock that only one processor exists. This processor runs with 'file_1' - manager._processors[file_1] = MagicMock() + manager.processor._processors[file_1] = MagicMock() # Start New Processes - manager.start_new_processes() + manager.processor.start_new_processes() # Because of the config: '[scheduler] parsing_processes = 2' # verify that only one extra process is created @@ -234,51 +245,57 @@ def test_start_new_processes_with_same_filepath(self): # even though it is first in '_file_path_queue' # a new processor is created with 'file_2' and not 'file_1'. 
- assert file_1 in manager._processors.keys() - assert file_2 in manager._processors.keys() - assert collections.deque([file_3]) == manager._file_path_queue + assert file_1 in manager.processor._processors.keys() + assert file_2 in manager.processor._processors.keys() + assert collections.deque([file_3]) == manager.processor._file_path_queue def test_set_file_paths_when_processor_file_path_not_in_new_file_paths(self): - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory="directory", + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ), ) mock_processor = MagicMock() mock_processor.stop.side_effect = AttributeError("DagFileProcessor object has no attribute stop") mock_processor.terminate.side_effect = None - manager._processors["missing_file.txt"] = mock_processor - manager._file_stats["missing_file.txt"] = DagFileStat(0, 0, None, None, 0) + manager.processor._processors["missing_file.txt"] = mock_processor + manager.processor._file_stats["missing_file.txt"] = DagFileStat(0, 0, None, None, 0) - manager.set_file_paths(["abc.txt"]) - assert manager._processors == {} - assert "missing_file.txt" not in manager._file_stats + manager.processor.set_file_paths(["abc.txt"]) + assert manager.processor._processors == {} + assert "missing_file.txt" not in manager.processor._file_stats def test_set_file_paths_when_processor_file_path_is_in_new_file_paths(self): - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory="directory", + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ), ) mock_processor = MagicMock() mock_processor.stop.side_effect = AttributeError("DagFileProcessor object has no attribute stop") mock_processor.terminate.side_effect = None - manager._processors["abc.txt"] = mock_processor + manager.processor._processors["abc.txt"] = mock_processor - manager.set_file_paths(["abc.txt"]) - assert manager._processors == {"abc.txt": mock_processor} + manager.processor.set_file_paths(["abc.txt"]) + assert manager.processor._processors == {"abc.txt": mock_processor} @conf_vars({("scheduler", "file_parsing_sort_mode"): "alphabetical"}) @mock.patch("zipfile.is_zipfile", return_value=True) @@ -292,20 +309,23 @@ def test_file_paths_in_queue_sorted_alphabetically( dag_files = ["file_3.py", "file_2.py", "file_4.py", "file_1.py"] mock_find_path.return_value = dag_files - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory="directory", + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ), ) - manager.set_file_paths(dag_files) - assert manager._file_path_queue == collections.deque() - manager.prepare_file_path_queue() - 
assert manager._file_path_queue == collections.deque( + manager.processor.set_file_paths(dag_files) + assert manager.processor._file_path_queue == collections.deque() + manager.processor.prepare_file_path_queue() + assert manager.processor._file_path_queue == collections.deque( ["file_1.py", "file_2.py", "file_3.py", "file_4.py"] ) @@ -321,28 +341,31 @@ def test_file_paths_in_queue_sorted_random_seeded_by_host( dag_files = ["file_3.py", "file_2.py", "file_4.py", "file_1.py"] mock_find_path.return_value = dag_files - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory="directory", + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ), ) - manager.set_file_paths(dag_files) - assert manager._file_path_queue == collections.deque() - manager.prepare_file_path_queue() + manager.processor.set_file_paths(dag_files) + assert manager.processor._file_path_queue == collections.deque() + manager.processor.prepare_file_path_queue() expected_order = collections.deque(dag_files) random.Random(get_hostname()).shuffle(expected_order) - assert manager._file_path_queue == expected_order + assert manager.processor._file_path_queue == expected_order # Verify running it again produces same order - manager._file_paths = [] - manager.prepare_file_path_queue() - assert manager._file_path_queue == expected_order + manager.processor._file_paths = [] + manager.processor.prepare_file_path_queue() + assert manager.processor._file_path_queue == expected_order @pytest.fixture def change_platform_timezone(self, monkeypatch): @@ -383,20 +406,23 @@ def test_file_paths_in_queue_sorted_by_modified_time( mock_getmtime.side_effect = list(paths_with_mtime.values()) mock_find_path.return_value = dag_files - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory="directory", + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ), ) - manager.set_file_paths(dag_files) - assert manager._file_path_queue == collections.deque() - manager.prepare_file_path_queue() - assert manager._file_path_queue == collections.deque( + manager.processor.set_file_paths(dag_files) + assert manager.processor._file_path_queue == collections.deque() + manager.processor.prepare_file_path_queue() + assert manager.processor._file_path_queue == collections.deque( ["file_4.py", "file_1.py", "file_3.py", "file_2.py"] ) @@ -420,19 +446,22 @@ def test_file_paths_in_queue_excludes_missing_file( mock_getmtime.side_effect = [1.0, 2.0, FileNotFoundError()] mock_find_path.return_value = dag_files - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory="directory", + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=MagicMock(), + dag_ids=[], + 
pickle_dags=False, + async_mode=True, + ), ) - manager.set_file_paths(dag_files) - manager.prepare_file_path_queue() - assert manager._file_path_queue == collections.deque(["file_2.py", "file_3.py"]) + manager.processor.set_file_paths(dag_files) + manager.processor.prepare_file_path_queue() + assert manager.processor._file_path_queue == collections.deque(["file_2.py", "file_3.py"]) @conf_vars({("scheduler", "file_parsing_sort_mode"): "modified_time"}) @mock.patch("zipfile.is_zipfile", return_value=True) @@ -454,23 +483,28 @@ def test_add_new_file_to_parsing_queue( mock_getmtime.side_effect = [1.0, 2.0, 3.0] mock_find_path.return_value = dag_files - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory="directory", + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ), ) - manager.set_file_paths(dag_files) - manager.prepare_file_path_queue() - assert manager._file_path_queue == collections.deque(["file_3.py", "file_2.py", "file_1.py"]) + manager.processor.set_file_paths(dag_files) + manager.processor.prepare_file_path_queue() + assert manager.processor._file_path_queue == collections.deque( + ["file_3.py", "file_2.py", "file_1.py"] + ) - manager.set_file_paths(dag_files + ["file_4.py"]) - manager.add_new_file_path_to_queue() - assert manager._file_path_queue == collections.deque( + manager.processor.set_file_paths(dag_files + ["file_4.py"]) + manager.processor.add_new_file_path_to_queue() + assert manager.processor._file_path_queue == collections.deque( ["file_4.py", "file_3.py", "file_2.py", "file_1.py"] ) @@ -499,44 +533,47 @@ def test_recently_modified_file_is_parsed_with_mtime_mode( mock_getmtime.side_effect = [initial_file_1_mtime] mock_find_path.return_value = dag_files - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=3, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory="directory", + max_runs=3, + processor_timeout=timedelta(days=365), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ), ) # let's say the DAG was just parsed 10 seconds before the Freezed time last_finish_time = freezed_base_time - timedelta(seconds=10) - manager._file_stats = { + manager.processor._file_stats = { "file_1.py": DagFileStat(1, 0, last_finish_time, timedelta(seconds=1.0), 1), } with time_machine.travel(freezed_base_time): - manager.set_file_paths(dag_files) - assert manager._file_path_queue == collections.deque() + manager.processor.set_file_paths(dag_files) + assert manager.processor._file_path_queue == collections.deque() # File Path Queue will be empty as the "modified time" < "last finish time" - manager.prepare_file_path_queue() - assert manager._file_path_queue == collections.deque() + manager.processor.prepare_file_path_queue() + assert manager.processor._file_path_queue == collections.deque() # Simulate the DAG modification by using modified_time which is greater # than the last_parse_time but still less than now - min_file_process_interval file_1_new_mtime = freezed_base_time - timedelta(seconds=5) file_1_new_mtime_ts = 
file_1_new_mtime.timestamp() with time_machine.travel(freezed_base_time): - manager.set_file_paths(dag_files) - assert manager._file_path_queue == collections.deque() + manager.processor.set_file_paths(dag_files) + assert manager.processor._file_path_queue == collections.deque() # File Path Queue will be empty as the "modified time" < "last finish time" mock_getmtime.side_effect = [file_1_new_mtime_ts] - manager.prepare_file_path_queue() + manager.processor.prepare_file_path_queue() # Check that file is added to the queue even though file was just recently passed - assert manager._file_path_queue == collections.deque(["file_1.py"]) + assert manager.processor._file_path_queue == collections.deque(["file_1.py"]) assert last_finish_time < file_1_new_mtime assert ( - manager._file_process_interval - > (freezed_base_time - manager.get_last_finish_time("file_1.py")).total_seconds() + manager.processor._file_process_interval + > (freezed_base_time - manager.processor.get_last_finish_time("file_1.py")).total_seconds() ) def test_scan_stale_dags(self): @@ -544,14 +581,17 @@ def test_scan_stale_dags(self): Ensure that DAGs are marked inactive when the file is parsed but the DagModel.last_parsed_time is not updated. """ - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(minutes=10), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory="directory", + max_runs=1, + processor_timeout=timedelta(minutes=10), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ), ) test_dag_path = str(TEST_DAG_FOLDER / "test_example_bash_operator.py") @@ -572,8 +612,8 @@ def test_scan_stale_dags(self): last_duration=1, run_count=1, ) - manager._file_paths = [test_dag_path] - manager._file_stats[test_dag_path] = stat + manager.processor._file_paths = [test_dag_path] + manager.processor._file_stats[test_dag_path] = stat active_dag_count = ( session.query(func.count(DagModel.dag_id)) @@ -589,7 +629,7 @@ def test_scan_stale_dags(self): ) assert serialized_dag_count == 1 - manager._scan_stale_dags() + manager.processor._scan_stale_dags() active_dag_count = ( session.query(func.count(DagModel.dag_id)) @@ -617,14 +657,17 @@ def test_scan_stale_dags_standalone_mode(self): Ensure only dags from current dag_directory are updated """ dag_directory = "directory" - manager = DagFileProcessorManager( - dag_directory=dag_directory, - max_runs=1, - processor_timeout=timedelta(minutes=10), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory=dag_directory, + max_runs=1, + processor_timeout=timedelta(minutes=10), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ), ) test_dag_path = str(TEST_DAG_FOLDER / "test_example_bash_operator.py") @@ -651,13 +694,13 @@ def test_scan_stale_dags_standalone_mode(self): last_duration=1, run_count=1, ) - manager._file_paths = [test_dag_path] - manager._file_stats[test_dag_path] = stat + manager.processor._file_paths = [test_dag_path] + manager.processor._file_stats[test_dag_path] = stat active_dag_count = session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar() assert active_dag_count == 2 - manager._scan_stale_dags() + manager.processor._scan_stale_dags() active_dag_count = 
session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar() assert active_dag_count == 1 @@ -670,14 +713,17 @@ def test_scan_stale_dags_standalone_mode(self): def test_kill_timed_out_processors_kill(self, mock_kill, mock_pid, mock_waitable_handle): mock_pid.return_value = 1234 mock_waitable_handle.return_value = 3 - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(seconds=5), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory="directory", + max_runs=1, + processor_timeout=timedelta(seconds=5), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ), ) processor = DagFileProcessorProcess( @@ -688,26 +734,29 @@ def test_kill_timed_out_processors_kill(self, mock_kill, mock_pid, mock_waitable callback_requests=[], ) processor._start_time = timezone.make_aware(datetime.min) - manager._processors = {"abc.txt": processor} - manager.waitables[3] = processor - initial_waitables = len(manager.waitables) - manager._kill_timed_out_processors() + manager.processor._processors = {"abc.txt": processor} + manager.processor.waitables[3] = processor + initial_waitables = len(manager.processor.waitables) + manager.processor._kill_timed_out_processors() mock_kill.assert_called_once_with() - assert len(manager._processors) == 0 - assert len(manager.waitables) == initial_waitables - 1 + assert len(manager.processor._processors) == 0 + assert len(manager.processor.waitables) == initial_waitables - 1 @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.pid", new_callable=PropertyMock) @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess") def test_kill_timed_out_processors_no_kill(self, mock_dag_file_processor, mock_pid): mock_pid.return_value = 1234 - manager = DagFileProcessorManager( - dag_directory=TEST_DAG_FOLDER, - max_runs=1, - processor_timeout=timedelta(seconds=5), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory=TEST_DAG_FOLDER, + max_runs=1, + processor_timeout=timedelta(seconds=5), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ), ) processor = DagFileProcessorProcess( @@ -718,8 +767,8 @@ def test_kill_timed_out_processors_no_kill(self, mock_dag_file_processor, mock_p callback_requests=[], ) processor._start_time = timezone.make_aware(datetime.max) - manager._processors = {"abc.txt": processor} - manager._kill_timed_out_processors() + manager.processor._processors = {"abc.txt": processor} + manager.processor._kill_timed_out_processors() mock_dag_file_processor.kill.assert_not_called() @conf_vars({("core", "load_examples"): "False"}) @@ -738,17 +787,20 @@ def test_dag_with_system_exit(self): child_pipe, parent_pipe = multiprocessing.Pipe() - manager = DagFileProcessorManager( - dag_directory=dag_directory, - dag_ids=[], - max_runs=1, - processor_timeout=timedelta(seconds=5), - signal_conn=child_pipe, - pickle_dags=False, - async_mode=True, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory=dag_directory, + dag_ids=[], + max_runs=1, + processor_timeout=timedelta(seconds=5), + signal_conn=child_pipe, + pickle_dags=False, + async_mode=True, + ), ) - manager._run_parsing_loop() + manager.processor._run_parsing_loop() result 
= None while parent_pipe.poll(timeout=None): @@ -757,7 +809,7 @@ def test_dag_with_system_exit(self): break # Three files in folder should be processed - assert sum(stat.run_count for stat in manager._file_stats.values()) == 3 + assert sum(stat.run_count for stat in manager.processor._file_stats.values()) == 3 with create_session() as session: assert session.get(DagModel, dag_id) is not None @@ -790,7 +842,7 @@ def keep_pipe_full(pipe, exit_event): req = CallbackRequest(str(dag_filepath)) try: - logging.debug("Sending CallbackRequests %d", n + 1) + logging.info("Sending CallbackRequests %d", n + 1) pipe.send(req) except TypeError: # This is actually the error you get when the parent pipe @@ -834,7 +886,10 @@ def fake_processor_(*args, **kwargs): logging.info("Closing pipes") parent_pipe.close() child_pipe.close() + logging.info("Closed pipes") + logging.info("Joining thread") thread.join(timeout=1.0) + logging.info("Joined thread") @conf_vars({("core", "load_examples"): "False"}) @mock.patch("airflow.dag_processing.manager.Stats.timing") @@ -852,18 +907,21 @@ def test_send_file_processing_statsd_timing(self, statsd_timing_mock, tmpdir): child_pipe, parent_pipe = multiprocessing.Pipe() async_mode = "sqlite" not in conf.get("database", "sql_alchemy_conn") - manager = DagFileProcessorManager( - dag_directory=tmpdir, - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=child_pipe, - dag_ids=[], - pickle_dags=False, - async_mode=async_mode, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory=tmpdir, + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=child_pipe, + dag_ids=[], + pickle_dags=False, + async_mode=async_mode, + ), ) self.run_processor_manager_one_loop(manager, parent_pipe) - last_runtime = manager.get_last_runtime(manager.file_paths[0]) + last_runtime = manager.processor.get_last_runtime(manager.processor.file_paths[0]) child_pipe.close() parent_pipe.close() @@ -881,15 +939,18 @@ def test_send_file_processing_statsd_timing(self, statsd_timing_mock, tmpdir): ) def test_refresh_dags_dir_doesnt_delete_zipped_dags(self, tmpdir): - """Test DagFileProcessorManager._refresh_dag_dir method""" - manager = DagFileProcessorManager( - dag_directory=TEST_DAG_FOLDER, - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, + """Test DagProcessorJobRunner._refresh_dag_dir method""" + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory=TEST_DAG_FOLDER, + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ), ) dagbag = DagBag(dag_folder=tmpdir, include_examples=False) zipped_dag_path = os.path.join(TEST_DAGS_FOLDER, "test_zip.zip") @@ -897,8 +958,8 @@ def test_refresh_dags_dir_doesnt_delete_zipped_dags(self, tmpdir): dag = dagbag.get_dag("test_zip_dag") dag.sync_to_db() SerializedDagModel.write_dag(dag) - manager.last_dag_dir_refresh_time = timezone.utcnow() - timedelta(minutes=10) - manager._refresh_dag_dir() + manager.processor.last_dag_dir_refresh_time = timezone.utcnow() - timedelta(minutes=10) + manager.processor._refresh_dag_dir() # Assert dag not deleted in SDM assert SerializedDagModel.has_dag("test_zip_dag") # assert code not deleted @@ -911,7 +972,7 @@ def test_refresh_dags_dir_doesnt_delete_zipped_dags(self, tmpdir): } ) def test_fetch_callbacks_from_database(self, tmpdir): - """Test 
DagFileProcessorManager._fetch_callbacks method""" + """Test DagProcessorJobRunner._fetch_callbacks method""" dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" callback1 = DagCallbackRequest( @@ -940,14 +1001,17 @@ def test_fetch_callbacks_from_database(self, tmpdir): session.add(DbCallbackRequest(callback=callback3, priority_weight=9)) child_pipe, parent_pipe = multiprocessing.Pipe() - manager = DagFileProcessorManager( - dag_directory=str(tmpdir), - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=child_pipe, - dag_ids=[], - pickle_dags=False, - async_mode=False, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory=str(tmpdir), + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=child_pipe, + dag_ids=[], + pickle_dags=False, + async_mode=False, + ), ) with create_session() as session: @@ -961,7 +1025,7 @@ def test_fetch_callbacks_from_database(self, tmpdir): } ) def test_fetch_callbacks_for_current_dag_directory_only(self, tmpdir): - """Test DagFileProcessorManager._fetch_callbacks method""" + """Test DagProcessorJobRunner._fetch_callbacks method""" dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" callback1 = DagCallbackRequest( @@ -984,14 +1048,17 @@ def test_fetch_callbacks_for_current_dag_directory_only(self, tmpdir): session.add(DbCallbackRequest(callback=callback2, priority_weight=10)) child_pipe, parent_pipe = multiprocessing.Pipe() - manager = DagFileProcessorManager( - dag_directory=tmpdir, - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=child_pipe, - dag_ids=[], - pickle_dags=False, - async_mode=False, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory=tmpdir, + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=child_pipe, + dag_ids=[], + pickle_dags=False, + async_mode=False, + ), ) with create_session() as session: @@ -1006,7 +1073,7 @@ def test_fetch_callbacks_for_current_dag_directory_only(self, tmpdir): } ) def test_fetch_callbacks_from_database_max_per_loop(self, tmpdir): - """Test DagFileProcessorManager._fetch_callbacks method""" + """Test DagProcessorJobRunner._fetch_callbacks method""" dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" with create_session() as session: @@ -1021,14 +1088,17 @@ def test_fetch_callbacks_from_database_max_per_loop(self, tmpdir): session.add(DbCallbackRequest(callback=callback, priority_weight=i)) child_pipe, parent_pipe = multiprocessing.Pipe() - manager = DagFileProcessorManager( - dag_directory=str(tmpdir), - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=child_pipe, - dag_ids=[], - pickle_dags=False, - async_mode=False, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory=str(tmpdir), + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=child_pipe, + dag_ids=[], + pickle_dags=False, + async_mode=False, + ), ) with create_session() as session: @@ -1059,14 +1129,17 @@ def test_fetch_callbacks_from_database_not_standalone(self, tmpdir): session.add(DbCallbackRequest(callback=callback, priority_weight=10)) child_pipe, parent_pipe = multiprocessing.Pipe() - manager = DagFileProcessorManager( - dag_directory=tmpdir, - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=child_pipe, - dag_ids=[], - pickle_dags=False, - async_mode=False, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + 
dag_directory=tmpdir, + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=child_pipe, + dag_ids=[], + pickle_dags=False, + async_mode=False, + ), ) with create_session() as session: @@ -1079,14 +1152,17 @@ def test_fetch_callbacks_from_database_not_standalone(self, tmpdir): def test_callback_queue(self, tmpdir): # given - manager = DagFileProcessorManager( - dag_directory=TEST_DAG_FOLDER, - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory=TEST_DAG_FOLDER, + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ), ) dag1_req1 = DagCallbackRequest( @@ -1132,38 +1208,45 @@ def test_callback_queue(self, tmpdir): ) # when - manager._add_callback_to_queue(dag1_req1) - manager._add_callback_to_queue(dag1_sla1) - manager._add_callback_to_queue(dag2_req1) + manager.processor._add_callback_to_queue(dag1_req1) + manager.processor._add_callback_to_queue(dag1_sla1) + manager.processor._add_callback_to_queue(dag2_req1) # then - requests should be in manager's queue, with dag2 ahead of dag1 (because it was added last) - assert manager._file_path_queue == collections.deque( + assert manager.processor._file_path_queue == collections.deque( [dag2_req1.full_filepath, dag1_req1.full_filepath] ) - assert set(manager._callback_to_execute.keys()) == {dag1_req1.full_filepath, dag2_req1.full_filepath} - assert manager._callback_to_execute[dag1_req1.full_filepath] == [dag1_req1, dag1_sla1] - assert manager._callback_to_execute[dag2_req1.full_filepath] == [dag2_req1] + assert set(manager.processor._callback_to_execute.keys()) == { + dag1_req1.full_filepath, + dag2_req1.full_filepath, + } + assert manager.processor._callback_to_execute[dag1_req1.full_filepath] == [dag1_req1, dag1_sla1] + assert manager.processor._callback_to_execute[dag2_req1.full_filepath] == [dag2_req1] # when - manager._add_callback_to_queue(dag1_sla2) - manager._add_callback_to_queue(dag3_sla1) + manager.processor._add_callback_to_queue(dag1_sla2) + manager.processor._add_callback_to_queue(dag3_sla1) # then - since sla2 == sla1, should not have brought dag1 to the fore, and an SLA on dag3 doesn't # update the queue, although the callback is registered - assert manager._file_path_queue == collections.deque( + assert manager.processor._file_path_queue == collections.deque( [dag2_req1.full_filepath, dag1_req1.full_filepath] ) - assert manager._callback_to_execute[dag1_req1.full_filepath] == [dag1_req1, dag1_sla1] - assert manager._callback_to_execute[dag3_sla1.full_filepath] == [dag3_sla1] + assert manager.processor._callback_to_execute[dag1_req1.full_filepath] == [dag1_req1, dag1_sla1] + assert manager.processor._callback_to_execute[dag3_sla1.full_filepath] == [dag3_sla1] # when - manager._add_callback_to_queue(dag1_req2) + manager.processor._add_callback_to_queue(dag1_req2) # then - non-sla callback should have brought dag1 to the fore - assert manager._file_path_queue == collections.deque( + assert manager.processor._file_path_queue == collections.deque( [dag1_req1.full_filepath, dag2_req1.full_filepath] ) - assert manager._callback_to_execute[dag1_req1.full_filepath] == [dag1_req1, dag1_sla1, dag1_req2] + assert manager.processor._callback_to_execute[dag1_req1.full_filepath] == [ + dag1_req1, + dag1_sla1, + dag1_req2, + ] class TestDagFileProcessorAgent: diff --git 
a/tests/executors/test_dask_executor.py b/tests/executors/test_dask_executor.py index 9753de911ad31..f64f2a7dfdc72 100644 --- a/tests/executors/test_dask_executor.py +++ b/tests/executors/test_dask_executor.py @@ -109,15 +109,16 @@ def test_backfill_integration(self): dag = self.dagbag.get_dag("example_bash_operator") job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_first_depends_on_past=True, - ), executor=DaskExecutor(cluster_address=self.cluster.scheduler_address), ) - run_job(job) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, + ignore_first_depends_on_past=True, + ) + run_job(job=job, execute_callable=job_runner._execute) def teardown_method(self): self.cluster.close(timeout=5) diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index e464e41bc65bf..7001f6f1b917f 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -121,17 +121,16 @@ def test_unfinished_dag_runs_set_to_failed(self, dag_maker): dag = self._get_dummy_dag(dag_maker) dag_run = dag_maker.create_dagrun(state=None) - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=8), - ignore_first_depends_on_past=True, - ) + job = Job(executor=MockExecutor()) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=8), + ignore_first_depends_on_past=True, ) - job.job_runner._set_unfinished_dag_runs_to_failed([dag_run]) - + job_runner._set_unfinished_dag_runs_to_failed([dag_run]) dag_run.refresh_from_db() assert State.FAILED == dag_run.state @@ -143,16 +142,15 @@ def test_dag_run_with_finished_tasks_set_to_success(self, dag_maker): for ti in dag_run.get_task_instances(): ti.set_state(State.SUCCESS) - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=8), - ignore_first_depends_on_past=True, - ) + job = Job(executor=MockExecutor()) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=8), + ignore_first_depends_on_past=True, ) - - job.job_runner._set_unfinished_dag_runs_to_failed([dag_run]) + job_runner._set_unfinished_dag_runs_to_failed([dag_run]) dag_run.refresh_from_db() @@ -167,17 +165,20 @@ def test_trigger_controller_dag(self, session): target_dag_run = session.query(DagRun).filter(DagRun.dag_id == target_dag.dag_id).one_or_none() assert target_dag_run is None - job = Job( - job_runner=BackfillJobRunner( - dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_first_depends_on_past=True - ) + job = Job(executor=MockExecutor()) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, + ignore_first_depends_on_past=True, ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) dag_run = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).one_or_none() assert dag_run is not None - task_instances_list = job.job_runner._task_instances_for_dag_run(dag=dag, dag_run=dag_run) + task_instances_list = job_runner._task_instances_for_dag_run(dag=dag, dag_run=dag_run) assert task_instances_list @@ -188,17 +189,16 @@ def test_backfill_multi_dates(self): end_date = DEFAULT_DATE + datetime.timedelta(days=1) executor = MockExecutor(parallelism=16) - job = Job( - 
job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=end_date, - ignore_first_depends_on_past=True, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=end_date, + ignore_first_depends_on_past=True, ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) expected_execution_order = [ ("runme_0", DEFAULT_DATE), @@ -282,17 +282,16 @@ def test_backfill_examples(self, dag_id, expected_execution_order): logger.info("*** Running example DAG: %s", dag.dag_id) executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_first_depends_on_past=True, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, + ignore_first_depends_on_past=True, ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) assert [ ((dag_id, task_id, f"backfill__{DEFAULT_DATE.isoformat()}", 1, -1), (State.SUCCESS, None)) for task_id in expected_execution_order @@ -305,16 +304,15 @@ def test_backfill_conf(self, dag_maker): executor = MockExecutor() conf_ = json.loads("""{"key": "value"}""") - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - conf=conf_, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + conf=conf_, ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) # We ignore the first dag_run created by fixture dr = DagRun.find( @@ -335,16 +333,15 @@ def test_backfill_respect_max_active_tis_per_dag_limit(self, mock_log, dag_maker executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=7), - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=7), ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) assert len(executor.history) > 0 @@ -387,16 +384,14 @@ def test_backfill_respect_dag_concurrency_limit(self, mock_log, dag_maker): executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=7), - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=7), ) - - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) assert len(executor.history) > 0 @@ -442,16 +437,15 @@ def test_backfill_respect_default_pool_limit(self, mock_log, dag_maker): executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=7), - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=7), ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) assert len(executor.history) > 0 @@ -500,17 +494,16 @@ def 
test_backfill_pool_not_found(self, dag_maker): executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=7), - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=7), ) try: - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) except AirflowException: return @@ -535,16 +528,15 @@ def test_backfill_respect_pool_limit(self, mock_log, dag_maker): executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=7), - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=7), ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) assert len(executor.history) > 0 @@ -587,30 +579,28 @@ def test_backfill_run_rescheduled(self, dag_maker): executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) ti = TI(task=dag.get_task("test_backfill_run_rescheduled_task-1"), execution_date=DEFAULT_DATE) ti.refresh_from_db() ti.set_state(State.UP_FOR_RESCHEDULE) - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - rerun_failed_tasks=True, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + rerun_failed_tasks=True, ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) ti = TI(task=dag.get_task("test_backfill_run_rescheduled_task-1"), execution_date=DEFAULT_DATE) ti.refresh_from_db() assert ti.state == State.SUCCESS @@ -626,22 +616,21 @@ def test_backfill_override_conf(self, dag_maker): executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - conf={"a": 1}, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + conf={"a": 1}, ) with patch.object( - job.job_runner, + job_runner, "_task_instances_for_dag_run", - wraps=job.job_runner._task_instances_for_dag_run, + wraps=job_runner._task_instances_for_dag_run, ) as wrapped_task_instances_for_dag_run: - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) dr = wrapped_task_instances_for_dag_run.call_args_list[0][0][1] assert dr.conf == {"a": 1} @@ -658,17 +647,16 @@ def test_backfill_skip_active_scheduled_dagrun(self, dag_maker, caplog): executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + 
dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), ) with caplog.at_level(logging.ERROR, logger="airflow.jobs.backfill_job_runner.BackfillJob"): caplog.clear() - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) assert "Backfill cannot be created for DagRun" in caplog.messages[0] ti = TI( @@ -686,30 +674,28 @@ def test_backfill_rerun_failed_tasks(self, dag_maker): executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) ti = TI(task=dag.get_task("test_backfill_rerun_failed_task-1"), execution_date=DEFAULT_DATE) ti.refresh_from_db() ti.set_state(State.FAILED) - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - rerun_failed_tasks=True, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + rerun_failed_tasks=True, ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) ti = TI(task=dag.get_task("test_backfill_rerun_failed_task-1"), execution_date=DEFAULT_DATE) ti.refresh_from_db() assert ti.state == State.SUCCESS @@ -724,30 +710,28 @@ def test_backfill_rerun_upstream_failed_tasks(self, dag_maker): executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) ti = TI(task=dag.get_task("test_backfill_rerun_upstream_failed_task-1"), execution_date=DEFAULT_DATE) ti.refresh_from_db() ti.set_state(State.UPSTREAM_FAILED) - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - rerun_failed_tasks=True, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + rerun_failed_tasks=True, ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) ti = TI(task=dag.get_task("test_backfill_rerun_upstream_failed_task-1"), execution_date=DEFAULT_DATE) ti.refresh_from_db() assert ti.state == State.SUCCESS @@ -760,32 +744,30 @@ def test_backfill_rerun_failed_tasks_without_flag(self, dag_maker): executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) ti = TI(task=dag.get_task("test_backfill_rerun_failed_task-1"), execution_date=DEFAULT_DATE) ti.refresh_from_db() 
ti.set_state(State.FAILED) - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - rerun_failed_tasks=False, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), + rerun_failed_tasks=False, ) with pytest.raises(AirflowException): - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) def test_backfill_retry_intermittent_failed_task(self, dag_maker): with dag_maker( @@ -806,15 +788,14 @@ def test_backfill_retry_intermittent_failed_task(self, dag_maker): executor.mock_task_results[ TaskInstanceKey(dag.dag_id, task1.task_id, DEFAULT_DATE, try_number=2) ] = State.UP_FOR_RETRY - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) def test_backfill_retry_always_failed_task(self, dag_maker): with dag_maker( @@ -833,16 +814,15 @@ def test_backfill_retry_always_failed_task(self, dag_maker): TaskInstanceKey(dag.dag_id, task1.task_id, dr.run_id, try_number=1) ] = State.UP_FOR_RETRY executor.mock_task_fail(dag.dag_id, task1.task_id, dr.run_id, try_number=2) - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, ) with pytest.raises(BackfillUnfinished): - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) def test_backfill_ordered_concurrent_execute(self, dag_maker): @@ -864,15 +844,14 @@ def test_backfill_ordered_concurrent_execute(self, dag_maker): dag_maker.create_dagrun(run_id=runid0) executor = MockExecutor(parallelism=16) - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=2), - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=2), ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) runid1 = f"backfill__{(DEFAULT_DATE + datetime.timedelta(days=1)).isoformat()}" runid2 = f"backfill__{(DEFAULT_DATE + datetime.timedelta(days=2)).isoformat()}" @@ -908,16 +887,14 @@ def test_backfill_pooled_tasks(self): dag.clear() executor = MockExecutor(do_update=True) - job = Job( - job_runner=BackfillJobRunner(dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE), - executor=executor, - ) + job = Job(executor=executor) + job_runner = BackfillJobRunner(job=job, dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) # run with timeout because this creates an infinite loop if not # caught try: with timeout(seconds=5): - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) except AirflowTaskTimeout: pass ti = TI(task=dag.get_task("test_backfill_pooled_task"), execution_date=DEFAULT_DATE) @@ -932,17 +909,15 @@ def test_backfill_depends_on_past_works_independently_on_ignore_depends_on_past( dag.clear() run_date = DEFAULT_DATE + 
datetime.timedelta(days=5) - run_job( - Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=run_date, - end_date=run_date, - ignore_first_depends_on_past=ignore_depends_on_past, - ), - executor=MockExecutor(), - ) + job = Job(executor=MockExecutor()) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=run_date, + end_date=run_date, + ignore_first_depends_on_past=ignore_depends_on_past, ) + run_job(job=job, execute_callable=job_runner._execute) # ti should have succeeded ti = TI(dag.tasks[0], run_date) @@ -964,11 +939,9 @@ def test_backfill_depends_on_past_backwards(self): dag.clear() executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner(dag=dag, ignore_first_depends_on_past=True, **kwargs), - executor=executor, - ) - run_job(job) + job = Job(executor=executor) + job_runner = BackfillJobRunner(job=job, dag=dag, ignore_first_depends_on_past=True, **kwargs) + run_job(job=job, execute_callable=job_runner._execute) ti = TI(dag.get_task("test_dop_task"), end_date) ti.refresh_from_db() @@ -979,11 +952,9 @@ def test_backfill_depends_on_past_backwards(self): expected_msg = "You cannot backfill backwards because one or more tasks depend_on_past: test_dop_task" with pytest.raises(AirflowException, match=expected_msg): executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner(dag=dag, run_backwards=True, **kwargs), - executor=executor, - ) - run_job(job) + job = Job(executor=executor) + job_runner = BackfillJobRunner(job=job, dag=dag, run_backwards=True, **kwargs) + run_job(job=job, execute_callable=job_runner._execute) def test_cli_receives_delay_arg(self): """ @@ -1030,16 +1001,15 @@ def test_backfill_max_limit_check_within_limit(self, dag_maker): end_date = DEFAULT_DATE executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=start_date, - end_date=end_date, - donot_pickle=True, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=start_date, + end_date=end_date, + donot_pickle=True, ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) dagruns = DagRun.find(dag_id=dag.dag_id) assert 2 == len(dagruns) @@ -1055,17 +1025,16 @@ def test_backfill_notifies_dagrun_listener(self, dag_maker): end_date = DEFAULT_DATE executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=start_date, - end_date=end_date, - donot_pickle=True, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=start_date, + end_date=end_date, + donot_pickle=True, ) job.notification_threadpool = mock.MagicMock() - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) assert len(dag_listener.running) == 1 assert len(dag_listener.success) == 1 @@ -1110,15 +1079,16 @@ def run_backfill(cond): executor = MockExecutor() job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=start_date, - end_date=end_date, - donot_pickle=True, - ), executor=executor, ) - run_job(job) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=start_date, + end_date=end_date, + donot_pickle=True, + ) + run_job(job=job, execute_callable=job_runner._execute) backfill_job_thread = threading.Thread( target=run_backfill, name="run_backfill", args=(dag_run_created_cond,) @@ -1162,13 +1132,11 @@ def test_backfill_max_limit_check_no_count_existing(self, dag_maker): dag_maker.create_dagrun(state=None) executor = MockExecutor() - job = Job( - 
job_runner=BackfillJobRunner( - dag=dag, start_date=start_date, end_date=end_date, donot_pickle=True - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, dag=dag, start_date=start_date, end_date=end_date, donot_pickle=True ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) # BackfillJobRunner will run since the existing DagRun does not count for the max # active limit since it's within the backfill date range. @@ -1190,16 +1158,15 @@ def test_backfill_max_limit_check_complete_loop(self, dag_maker): # backfill job 3 times success_expected = 2 executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=start_date, - end_date=end_date, - donot_pickle=True, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=start_date, + end_date=end_date, + donot_pickle=True, ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) success_dagruns = len(DagRun.find(dag_id=dag.dag_id, state=State.SUCCESS)) running_dagruns = len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING)) @@ -1228,11 +1195,9 @@ def test_sub_set_subdag(self, dag_maker): sub_dag = dag.partial_subset( task_ids_or_regex="leave*", include_downstream=False, include_upstream=False ) - job = Job( - job_runner=BackfillJobRunner(dag=sub_dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE), - executor=executor, - ) - run_job(job) + job = Job(executor=executor) + job_runner = BackfillJobRunner(job=job, dag=sub_dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + run_job(job=job, execute_callable=job_runner._execute) for ti in dr.get_task_instances(): if ti.task_id == "leave1" or ti.task_id == "leave2": @@ -1275,12 +1240,10 @@ def test_backfill_fill_blanks(self, dag_maker): session.commit() session.close() - job = Job( - job_runner=BackfillJobRunner(dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE), - executor=executor, - ) + job = Job(executor=executor) + job_runner = BackfillJobRunner(job=job, dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) with pytest.raises(AirflowException, match="Some task instances failed"): - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) dr.refresh_from_db() @@ -1306,16 +1269,15 @@ def test_backfill_execute_subdag(self): start_date = timezone.utcnow() executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=subdag, - start_date=start_date, - end_date=start_date, - donot_pickle=True, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=subdag, + start_date=start_date, + end_date=start_date, + donot_pickle=True, ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) subdag_op_task.pre_execute(context={"execution_date": start_date}) subdag_op_task.execute(context={"execution_date": start_date}) @@ -1351,18 +1313,17 @@ def test_subdag_clear_parentdag_downstream_clear(self): subdag = subdag_op_task.subdag executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - donot_pickle=True, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, + donot_pickle=True, ) with timeout(seconds=30): - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) ti_subdag = TI(task=dag.get_task("daily_job"), 
execution_date=DEFAULT_DATE) ti_subdag.refresh_from_db() @@ -1405,14 +1366,13 @@ def test_backfill_execute_subdag_with_removed_task(self): session = settings.Session() executor = MockExecutor() - job = Job( - job_runner=BackfillJobRunner( - dag=subdag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - donot_pickle=True, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=subdag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, + donot_pickle=True, ) dr = DagRun( dag_id=subdag.dag_id, execution_date=DEFAULT_DATE, run_id="test", run_type=DagRunType.BACKFILL_JOB @@ -1428,7 +1388,7 @@ def test_backfill_execute_subdag_with_removed_task(self): session.commit() with timeout(seconds=30): - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) for task in subdag.tasks: instance = ( @@ -1452,8 +1412,8 @@ def test_update_counters(self, dag_maker, session): with dag_maker(dag_id="test_manage_executor_state", start_date=DEFAULT_DATE, session=session) as dag: task1 = EmptyOperator(task_id="dummy", owner="airflow") dr = dag_maker.create_dagrun(state=None) - job = Job(job_runner=BackfillJobRunner(dag=dag)) - + job = Job() + job_runner = BackfillJobRunner(job=job, dag=dag) ti = TI(task1, dr.execution_date) ti.refresh_from_db() @@ -1470,7 +1430,7 @@ def test_update_counters(self, dag_maker, session): ti_status.running[ti.key] = ti # Task is queued and marked as running ti._try_number += 1 # Try number is increased during ti.run() ti.set_state(State.SUCCESS, session) # Task finishes with success state - job.job_runner._update_counters(ti_status=ti_status, session=session) # Update counters + job_runner._update_counters(ti_status=ti_status, session=session) # Update counters assert len(ti_status.running) == 0 assert len(ti_status.succeeded) == 1 assert len(ti_status.skipped) == 0 @@ -1483,7 +1443,7 @@ def test_update_counters(self, dag_maker, session): ti_status.running[ti.key] = ti ti._try_number += 1 ti.set_state(State.SKIPPED, session) - job.job_runner._update_counters(ti_status=ti_status, session=session) + job_runner._update_counters(ti_status=ti_status, session=session) assert len(ti_status.running) == 0 assert len(ti_status.succeeded) == 0 assert len(ti_status.skipped) == 1 @@ -1496,7 +1456,7 @@ def test_update_counters(self, dag_maker, session): ti_status.running[ti.key] = ti ti._try_number += 1 ti.set_state(State.FAILED, session) - job.job_runner._update_counters(ti_status=ti_status, session=session) + job_runner._update_counters(ti_status=ti_status, session=session) assert len(ti_status.running) == 0 assert len(ti_status.succeeded) == 0 assert len(ti_status.skipped) == 0 @@ -1509,7 +1469,7 @@ def test_update_counters(self, dag_maker, session): ti_status.running[ti.key] = ti ti._try_number += 1 ti.set_state(State.UP_FOR_RETRY, session) - job.job_runner._update_counters(ti_status=ti_status, session=session) + job_runner._update_counters(ti_status=ti_status, session=session) assert len(ti_status.running) == 0 assert len(ti_status.succeeded) == 0 assert len(ti_status.skipped) == 0 @@ -1530,7 +1490,7 @@ def test_update_counters(self, dag_maker, session): ti._try_number += 1 # Try number is increased during ti.run() ti._try_number -= 1 # Task is being rescheduled, decrement try_number ti.set_state(State.UP_FOR_RESCHEDULE, session) # Task finishes with reschedule state - job.job_runner._update_counters(ti_status=ti_status, session=session) + job_runner._update_counters(ti_status=ti_status, session=session) assert 
len(ti_status.running) == 0 assert len(ti_status.succeeded) == 0 assert len(ti_status.skipped) == 0 @@ -1548,7 +1508,7 @@ def test_update_counters(self, dag_maker, session): session.merge(ti) session.commit() ti_status.running[ti.key] = ti - job.job_runner._update_counters(ti_status=ti_status, session=session) + job_runner._update_counters(ti_status=ti_status, session=session) assert len(ti_status.running) == 0 assert len(ti_status.succeeded) == 0 assert len(ti_status.skipped) == 0 @@ -1562,7 +1522,7 @@ def test_update_counters(self, dag_maker, session): # Deferred tasks are put into scheduled by the triggerer # Check that they are put into to_run ti_status.running[ti.key] = ti - job.job_runner._update_counters(ti_status=ti_status, session=session) + job_runner._update_counters(ti_status=ti_status, session=session) assert len(ti_status.running) == 0 assert len(ti_status.succeeded) == 0 assert len(ti_status.skipped) == 0 @@ -1575,7 +1535,7 @@ def test_update_counters(self, dag_maker, session): # to reschedule it, we should leave it in ti_status.running ti.set_state(State.DEFERRED) ti_status.running[ti.key] = ti - job.job_runner._update_counters(ti_status=ti_status, session=session) + job_runner._update_counters(ti_status=ti_status, session=session) assert len(ti_status.running) == 1 assert len(ti_status.succeeded) == 0 assert len(ti_status.skipped) == 0 @@ -1618,16 +1578,15 @@ def test_backfill_run_backwards(self): executor = MockExecutor(parallelism=16) - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=1), - run_backwards=True, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=1), + run_backwards=True, ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) session = settings.Session() tis = ( @@ -1659,8 +1618,8 @@ def test_reset_orphaned_tasks_with_orphans(self, dag_maker): tasks.append(task) session = settings.Session() - job = Job(job_runner=BackfillJobRunner(dag=dag)) - + job = Job() + job_runner = BackfillJobRunner(job=job, dag=dag) # create dagruns dr1 = dag_maker.create_dagrun(state=State.RUNNING) dr2 = dag.create_dagrun(run_id="test2", state=State.SUCCESS) @@ -1681,7 +1640,7 @@ def test_reset_orphaned_tasks_with_orphans(self, dag_maker): session.merge(ti2) session.commit() - assert 2 == job.job_runner.reset_state_for_orphaned_tasks() + assert 2 == job_runner.reset_state_for_orphaned_tasks() for ti in dr1_tis + dr2_tis: ti.refresh_from_db() @@ -1701,7 +1660,7 @@ def test_reset_orphaned_tasks_with_orphans(self, dag_maker): ti.state = state session.commit() - job.job_runner.reset_state_for_orphaned_tasks(filter_by_dag_run=dr1, session=session) + job_runner.reset_state_for_orphaned_tasks(filter_by_dag_run=dr1, session=session) # check same for dag_run version for state, ti in zip(states, dr2_tis): @@ -1719,7 +1678,8 @@ def test_reset_orphaned_tasks_specified_dagrun(self, session, dag_maker): ) as dag: EmptyOperator(task_id=task_id, dag=dag) - job = Job(job_runner=BackfillJobRunner(dag=dag)) + job = Job() + job_runner = BackfillJobRunner(job=job, dag=dag) # make two dagruns, only reset for one dr1 = dag_maker.create_dagrun(state=State.SUCCESS) dr2 = dag.create_dagrun(run_id="test2", state=State.RUNNING, session=session) @@ -1734,7 +1694,7 @@ def test_reset_orphaned_tasks_specified_dagrun(self, session, dag_maker): session.merge(dr2) session.flush() - 
num_reset_tis = job.job_runner.reset_state_for_orphaned_tasks(filter_by_dag_run=dr2, session=session) + num_reset_tis = job_runner.reset_state_for_orphaned_tasks(filter_by_dag_run=dr2, session=session) assert 1 == num_reset_tis ti1.refresh_from_db(session=session) ti2.refresh_from_db(session=session) @@ -1746,11 +1706,11 @@ def test_job_id_is_assigned_to_dag_run(self, dag_maker): with dag_maker(dag_id=dag_id, start_date=DEFAULT_DATE, schedule="@daily") as dag: EmptyOperator(task_id="dummy_task", dag=dag) - job = Job( - job_runner=BackfillJobRunner(dag=dag, start_date=timezone.utcnow() - datetime.timedelta(days=1)), - executor=MockExecutor(), + job = Job(executor=MockExecutor()) + job_runner = BackfillJobRunner( + job=job, dag=dag, start_date=timezone.utcnow() - datetime.timedelta(days=1) ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) dr: DagRun = dag.get_last_dagrun() assert dr.creating_job_id == job.id @@ -1761,16 +1721,15 @@ def test_backfill_has_job_id_int(self): executor = MockExecutor(parallelism=16) - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=1), - run_backwards=True, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=1), + run_backwards=True, ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) assert isinstance(executor.job_id, int) @pytest.mark.long_running @@ -1792,16 +1751,15 @@ def test_backfilling_dags(self, dag_id, executor_name, session): when = timezone.datetime(2022, 1, 1) - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=when, - end_date=when, - donot_pickle=True, - executor=ExecutorLoader.load_executor(executor_name), - ) + job = Job(executor=ExecutorLoader.load_executor(executor_name)) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=when, + end_date=when, + donot_pickle=True, ) - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) dr = DagRun.find(dag_id=dag.dag_id, execution_date=when, session=session)[0] assert dr @@ -1852,14 +1810,13 @@ def consumer(value): ti_status.active_runs.append(dr) ti_status.to_run = {ti.key: ti for ti in dr.task_instances} - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=dr.execution_date, - end_date=dr.execution_date, - donot_pickle=True, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=dr.execution_date, + end_date=dr.execution_date, + donot_pickle=True, ) executor_change_state = executor.change_state @@ -1880,7 +1837,7 @@ def on_change_state(key, state, info=None): executor_change_state(key, state, info) with patch.object(executor, "change_state", side_effect=on_change_state): - job.job_runner._process_backfill_task_instances( + job_runner._process_backfill_task_instances( ti_status=ti_status, executor=job.executor, start_date=dr.execution_date, @@ -1925,13 +1882,9 @@ def consumer(a, b): executor = MockExecutor() when = timezone.datetime(2022, 1, 1) - run_job( - Job( - job_runner=BackfillJobRunner(dag=dag, start_date=when, end_date=when, donot_pickle=True), - executor=executor, - ) - ) - + job = Job(executor=executor) + job_runner = BackfillJobRunner(job=job, dag=dag, start_date=when, end_date=when, donot_pickle=True) + run_job(job=job, execute_callable=job_runner._execute) (dr,) = DagRun.find(dag_id=dag.dag_id, 
execution_date=when, session=session) assert dr.state == DagRunState.FAILED @@ -1952,17 +1905,15 @@ def test_start_date_set_for_resetted_dagruns(self, dag_maker, session, caplog): session.merge(dr) session.flush() dag.clear() - run_job( - Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - donot_pickle=True, - ), - executor=MockExecutor(), - ) + job = Job(executor=MockExecutor()) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, + donot_pickle=True, ) + run_job(job=job, execute_callable=job_runner._execute) (dr,) = DagRun.find(dag_id=dag.dag_id, execution_date=DEFAULT_DATE, session=session) assert dr.start_date @@ -1984,17 +1935,16 @@ def test_task_instances_are_not_set_to_scheduled_when_dagrun_reset(self, dag_mak dag.clear() - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE + datetime.timedelta(days=1), - end_date=DEFAULT_DATE + datetime.timedelta(days=4), - donot_pickle=True, - ), - executor=MockExecutor(), + job = Job(executor=MockExecutor()) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE + datetime.timedelta(days=1), + end_date=DEFAULT_DATE + datetime.timedelta(days=4), + donot_pickle=True, ) for dr in DagRun.find(dag_id=dag.dag_id, session=session): - tasks_to_run = job.job_runner._task_instances_for_dag_run(dag, dr, session=session) + tasks_to_run = job_runner._task_instances_for_dag_run(dag, dr, session=session) states = [ti.state for _, ti in tasks_to_run.items()] assert TaskInstanceState.SCHEDULED in states assert State.NONE in states @@ -2026,17 +1976,16 @@ def test_backfill_disable_retry(self, dag_maker, disable_retry, try_number, exce TaskInstanceKey(dag.dag_id, task1.task_id, dag_run.run_id, try_number=2) ] = TaskInstanceState.FAILED - job = Job( - job_runner=BackfillJobRunner( - dag=dag, - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - disable_retry=disable_retry, - ), - executor=executor, + job = Job(executor=executor) + job_runner = BackfillJobRunner( + job=job, + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, + disable_retry=disable_retry, ) with pytest.raises(exception): - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) ti = dag_run.get_task_instance(task_id=task1.task_id) assert ti._try_number == try_number diff --git a/tests/jobs/test_base_job.py b/tests/jobs/test_base_job.py index f5504fb453d08..fb415e31a3865 100644 --- a/tests/jobs/test_base_job.py +++ b/tests/jobs/test_base_job.py @@ -37,8 +37,9 @@ class TestJob: def test_state_success(self): - job = Job(job_runner=MockJobRunner()) - run_job(job) + job = Job() + job_runner = MockJobRunner(job=job) + run_job(job=job, execute_callable=job_runner._execute) assert job.state == State.SUCCESS assert job.end_date is not None @@ -46,8 +47,9 @@ def test_state_success(self): def test_state_sysexit(self): import sys - job = Job(job_runner=MockJobRunner(lambda: sys.exit(0))) - run_job(job) + job = Job() + job_runner = MockJobRunner(job=job, func=lambda: sys.exit(0)) + run_job(job=job, execute_callable=job_runner._execute) assert job.state == State.SUCCESS assert job.end_date is not None @@ -56,8 +58,9 @@ def test_base_job_respects_plugin_hooks(self): import sys - job = Job(job_runner=MockJobRunner(lambda: sys.exit(0))) - run_job(job) + job = Job() + job_runner = MockJobRunner(job=job, func=lambda: sys.exit(0)) + run_job(job=job, execute_callable=job_runner._execute) assert job.state == State.SUCCESS assert 
job.end_date is not None @@ -68,8 +71,9 @@ def test_base_job_respects_plugin_lifecycle(self, dag_maker): """ get_listener_manager().add_listener(lifecycle_listener) - job = Job(job_runner=MockJobRunner(lambda: sys.exit(0))) - run_job(job) + job = Job() + job_runner = MockJobRunner(job=job, func=lambda: sys.exit(0)) + run_job(job=job, execute_callable=job_runner._execute) assert lifecycle_listener.started_component is job assert lifecycle_listener.stopped_component is job @@ -78,18 +82,21 @@ def test_state_failed(self): def abort(): raise RuntimeError("fail") - job = Job(job_runner=MockJobRunner(abort)) + job = Job() + job_runner = MockJobRunner(job=job, func=abort) with raises(RuntimeError): - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) assert job.state == State.FAILED assert job.end_date is not None def test_most_recent_job(self): with create_session() as session: - old_job = Job(job_runner=MockJobRunner(), heartrate=10) + old_job = Job(heartrate=10) + MockJobRunner(job=old_job) old_job.latest_heartbeat = old_job.latest_heartbeat - datetime.timedelta(seconds=20) - job = Job(job_runner=MockJobRunner(), heartrate=10) + job = Job(heartrate=10) + MockJobRunner(job=job) session.add(job) session.add(old_job) session.flush() @@ -101,13 +108,16 @@ def test_most_recent_job(self): def test_most_recent_job_running_precedence(self): with create_session() as session: - old_running_state_job = Job(job_runner=MockJobRunner(), heartrate=10) + old_running_state_job = Job(heartrate=10) + MockJobRunner(job=old_running_state_job) old_running_state_job.latest_heartbeat = timezone.utcnow() old_running_state_job.state = State.RUNNING - new_failed_state_job = Job(job_runner=MockJobRunner(), heartrate=10) + new_failed_state_job = Job(heartrate=10) + MockJobRunner(job=new_failed_state_job) new_failed_state_job.latest_heartbeat = timezone.utcnow() new_failed_state_job.state = State.FAILED - new_null_state_job = Job(job_runner=MockJobRunner(), heartrate=10) + new_null_state_job = Job(heartrate=10) + MockJobRunner(job=new_null_state_job) new_null_state_job.latest_heartbeat = timezone.utcnow() new_null_state_job.state = None session.add(old_running_state_job) @@ -120,7 +130,7 @@ def test_most_recent_job_running_precedence(self): session.rollback() def test_is_alive(self): - job = Job(job_runner=MockJobRunner(), heartrate=10, state=State.RUNNING) + job = Job(heartrate=10, state=State.RUNNING) assert job.is_alive() is True job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=20) @@ -145,12 +155,12 @@ def test_heartbeat_failed(self, mock_create_session): mock_session = Mock(spec_set=session, name="MockSession") mock_create_session.return_value.__enter__.return_value = mock_session - job = Job(job_runner=MockJobRunner(), heartrate=10, state=State.RUNNING) + job = Job(heartrate=10, state=State.RUNNING) job.latest_heartbeat = when mock_session.commit.side_effect = OperationalError("Force fail", {}, None) - job.heartbeat() + job.heartbeat(heartbeat_callback=lambda: None) assert job.latest_heartbeat == when, "attribute not updated when heartbeat fails" @@ -169,7 +179,8 @@ def test_essential_attr(self, mock_getuser, mock_hostname, mock_default_executor mock_getuser.return_value = "testuser" mock_default_executor.return_value = mock_sequential_executor - test_job = Job(job_runner=MockJobRunner(), heartrate=10, dag_id="example_dag", state=State.RUNNING) + test_job = Job(heartrate=10, dag_id="example_dag", state=State.RUNNING) + MockJobRunner(job=test_job) assert test_job.executor_class 
== "SequentialExecutor" assert test_job.heartrate == 10 assert test_job.dag_id == "example_dag" @@ -182,18 +193,16 @@ def test_essential_attr(self, mock_getuser, mock_hostname, mock_default_executor def test_heartbeat(self, frozen_sleep, monkeypatch): monkeypatch.setattr("airflow.jobs.job.sleep", frozen_sleep) with create_session() as session: - job = Job(job_runner=MockJobRunner(), heartrate=10) + job = Job(heartrate=10) job.latest_heartbeat = timezone.utcnow() session.add(job) session.commit() hb_callback = Mock() - job.job_runner.heartbeat_callback = hb_callback + job.heartbeat(heartbeat_callback=hb_callback) - job.heartbeat() - - hb_callback.assert_called_once_with(session=ANY) + hb_callback.assert_called_once_with(ANY) hb_callback.reset_mock() - perform_heartbeat(job=job, only_if_necessary=True) + perform_heartbeat(job=job, heartbeat_callback=hb_callback, only_if_necessary=True) assert hb_callback.called is False diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index 52ba119c1ceda..00ccb3e886c3a 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -113,12 +113,9 @@ def test_localtaskjob_essential_attr(self, dag_maker): ti = dr.get_task_instance(task_id=op1.task_id) - job1 = Job( - job_runner=LocalTaskJobRunner(task_instance=ti, ignore_ti_state=True), - dag_id=ti.dag_id, - executor=SequentialExecutor(), - ) + job1 = Job(dag_id=ti.dag_id, executor=SequentialExecutor()) + LocalTaskJobRunner(job=job1, task_instance=ti, ignore_ti_state=True) essential_attr = ["dag_id", "job_type", "start_date", "hostname"] check_result_1 = [hasattr(job1, attr) for attr in essential_attr] @@ -138,18 +135,15 @@ def test_localtaskjob_heartbeat(self, dag_maker): ti.hostname = "blablabla" session.commit() - job1 = Job( - job_runner=LocalTaskJobRunner(task_instance=ti, ignore_ti_state=True), - dag_id=ti.dag_id, - executor=SequentialExecutor(), - ) + job1 = Job(dag_id=ti.dag_id, executor=SequentialExecutor()) + job_runner = LocalTaskJobRunner(job=job1, task_instance=ti, ignore_ti_state=True) ti.task = op1 ti.refresh_from_task(op1) - job1.task_runner = StandardTaskRunner(job1) + job1.task_runner = StandardTaskRunner(job_runner) job1.task_runner.process = mock.Mock() - job1.job_runner.task_runner = job1.task_runner + job_runner.task_runner = job1.task_runner with pytest.raises(AirflowException): - job1.job_runner.heartbeat_callback() + job_runner.heartbeat_callback() job1.task_runner.process.pid = 1 ti.state = State.RUNNING @@ -160,11 +154,11 @@ def test_localtaskjob_heartbeat(self, dag_maker): assert ti.pid != os.getpid() assert not ti.run_as_user assert not job1.task_runner.run_as_user - job1.job_runner.heartbeat_callback(session=None) + job_runner.heartbeat_callback(session=None) job1.task_runner.process.pid = 2 with pytest.raises(AirflowException): - job1.job_runner.heartbeat_callback() + job_runner.heartbeat_callback() # Now, set the ti.pid to None and test that no error # is raised. 
@@ -174,7 +168,7 @@ def test_localtaskjob_heartbeat(self, dag_maker): assert ti.pid != job1.task_runner.process.pid assert not ti.run_as_user assert not job1.task_runner.run_as_user - job1.job_runner.heartbeat_callback() + job_runner.heartbeat_callback() @mock.patch("subprocess.check_call") @mock.patch("airflow.jobs.local_task_job_runner.psutil") @@ -189,21 +183,18 @@ def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock, _, dag_maker ti.hostname = get_hostname() session.commit() - job1 = Job( - job_runner=LocalTaskJobRunner(task_instance=ti, ignore_ti_state=True), - dag_id=ti.dag_id, - executor=SequentialExecutor(), - ) + job1 = Job(dag_id=ti.dag_id, executor=SequentialExecutor()) + job_runner = LocalTaskJobRunner(job=job1, task_instance=ti, ignore_ti_state=True) ti.task = op1 ti.refresh_from_task(op1) - job1.task_runner = StandardTaskRunner(job1) + job1.task_runner = StandardTaskRunner(job_runner) job1.task_runner.process = mock.Mock() job1.task_runner.process.pid = 2 - job1.job_runner.task_runner = job1.task_runner + job_runner.task_runner = job1.task_runner # Here, ti.pid is 2, the parent process of ti.pid is a mock(different). # And task_runner process is 2. Should fail with pytest.raises(AirflowException, match="PID of job runner does not match"): - job1.job_runner.heartbeat_callback() + job_runner.heartbeat_callback() job1.task_runner.process.pid = 1 # We make the parent process of ti.pid to equal the task_runner process id @@ -215,13 +206,13 @@ def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock, _, dag_maker assert ti.run_as_user session.merge(ti) session.commit() - job1.job_runner.heartbeat_callback(session=None) + job_runner.heartbeat_callback(session=None) # Here the task_runner process id is changed to 2 # while parent process of ti.pid is kept at 1, which is different job1.task_runner.process.pid = 2 with pytest.raises(AirflowException, match="PID of job runner does not match"): - job1.job_runner.heartbeat_callback() + job_runner.heartbeat_callback() # Here we set the ti.pid to None and test that no error is # raised @@ -231,7 +222,7 @@ def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock, _, dag_maker assert ti.run_as_user assert job1.task_runner.run_as_user == ti.run_as_user assert ti.pid != job1.task_runner.process.pid - job1.job_runner.heartbeat_callback() + job_runner.heartbeat_callback() @conf_vars({("core", "default_impersonation"): "testuser"}) @mock.patch("subprocess.check_call") @@ -247,21 +238,18 @@ def test_localtaskjob_heartbeat_with_default_impersonation(self, psutil_mock, _, ti.hostname = get_hostname() session.commit() - job1 = Job( - job_runner=LocalTaskJobRunner(task_instance=ti, ignore_ti_state=True), - dag_id=ti.dag_id, - executor=SequentialExecutor(), - ) + job1 = Job(dag_id=ti.dag_id, executor=SequentialExecutor()) + job_runner = LocalTaskJobRunner(job1, task_instance=ti, ignore_ti_state=True) ti.task = op1 ti.refresh_from_task(op1) - job1.task_runner = StandardTaskRunner(job1) + job1.task_runner = StandardTaskRunner(job_runner) job1.task_runner.process = mock.Mock() job1.task_runner.process.pid = 2 - job1.job_runner.task_runner = job1.task_runner + job_runner.task_runner = job1.task_runner # Here, ti.pid is 2, the parent process of ti.pid is a mock(different). # And task_runner process is 2. 
Should fail with pytest.raises(AirflowException, match="PID of job runner does not match"): - job1.job_runner.heartbeat_callback() + job_runner.heartbeat_callback() job1.task_runner.process.pid = 1 # We make the parent process of ti.pid to equal the task_runner process id @@ -273,13 +261,13 @@ def test_localtaskjob_heartbeat_with_default_impersonation(self, psutil_mock, _, assert job1.task_runner.run_as_user == "testuser" session.merge(ti) session.commit() - job1.job_runner.heartbeat_callback(session=None) + job_runner.heartbeat_callback(session=None) # Here the task_runner process id is changed to 2 # while parent process of ti.pid is kept at 1, which is different job1.task_runner.process.pid = 2 with pytest.raises(AirflowException, match="PID of job runner does not match"): - job1.job_runner.heartbeat_callback() + job_runner.heartbeat_callback() # Now, set the ti.pid to None and test that no error # is raised. @@ -289,7 +277,7 @@ def test_localtaskjob_heartbeat_with_default_impersonation(self, psutil_mock, _, assert job1.task_runner.run_as_user == "testuser" assert ti.run_as_user is None assert ti.pid != job1.task_runner.process.pid - job1.job_runner.heartbeat_callback() + job_runner.heartbeat_callback() def test_heartbeat_failed_fast(self): """ @@ -320,15 +308,12 @@ def test_heartbeat_failed_fast(self): ti.pid = 1 session.commit() - job = Job( - job_runner=LocalTaskJobRunner(task_instance=ti), - dag_id=ti.dag_id, - executor=MockExecutor(do_update=False), - ) + job = Job(dag_id=ti.dag_id, executor=MockExecutor(do_update=False)) + job_runner = LocalTaskJobRunner(job=job, task_instance=ti) job.heartrate = 2 heartbeat_records = [] - job.job_runner.heartbeat_callback = lambda session: heartbeat_records.append(job.latest_heartbeat) - run_job(job) + job_runner.heartbeat_callback = lambda session: heartbeat_records.append(job.latest_heartbeat) + run_job(job=job, execute_callable=job_runner._execute) assert len(heartbeat_records) > 2 for i in range(1, len(heartbeat_records)): time1 = heartbeat_records[i - 1] @@ -354,10 +339,11 @@ def test_mark_success_no_kill(self, caplog, get_test_dag, session): ti = dr.get_task_instance(task.task_id) ti.refresh_from_task(task) - job1 = Job(job_runner=LocalTaskJobRunner(task_instance=ti, ignore_ti_state=True), dag_id=ti.dag_id) + job1 = Job(dag_id=ti.dag_id) + job_runner = LocalTaskJobRunner(job=job1, task_instance=ti, ignore_ti_state=True) with timeout(30): - run_job(job1) + run_job(job=job1, execute_callable=job_runner._execute) ti.refresh_from_db() assert State.SUCCESS == ti.state assert ( @@ -389,13 +375,10 @@ def test_localtaskjob_double_trigger(self): ti_run = TaskInstance(task=task, run_id=dr.run_id) ti_run.refresh_from_db() - job1 = Job( - job_runner=LocalTaskJobRunner(task_instance=ti_run), - dag_id=ti_run.dag_id, - executor=SequentialExecutor(), - ) + job1 = Job(dag_id=ti_run.dag_id, executor=SequentialExecutor()) + job_runner = LocalTaskJobRunner(job=job1, task_instance=ti_run) with patch.object(StandardTaskRunner, "start", return_value=None) as mock_method: - run_job(job1) + run_job(job=job1, execute_callable=job_runner._execute) mock_method.assert_not_called() ti = dr.get_task_instance(task_id=task.task_id, session=session) @@ -412,17 +395,14 @@ def test_local_task_return_code_metric(self, mock_stats_incr, mock_return_code, ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti_run.refresh_from_db() - job1 = Job( - job_runner=LocalTaskJobRunner(task_instance=ti_run), - dag_id=ti_run.dag_id, - executor=SequentialExecutor(), - ) + job1 = 
Job(dag_id=ti_run.dag_id, executor=SequentialExecutor()) + job_runner = LocalTaskJobRunner(job=job1, task_instance=ti_run) job1.id = 95 mock_return_code.side_effect = [None, -9, None] with timeout(10): - run_job(job1) + run_job(job=job1, execute_callable=job_runner._execute) mock_stats_incr.assert_has_calls( [ @@ -437,12 +417,8 @@ def test_localtaskjob_maintain_heart_rate(self, mock_return_code, caplog, create ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti_run.refresh_from_db() - job1 = Job( - job_runner=LocalTaskJobRunner(task_instance=ti_run), - dag_id=ti_run.dag_id, - executor=SequentialExecutor(), - ) - + job1 = Job(dag_id=ti_run.dag_id, executor=SequentialExecutor()) + job_runner = LocalTaskJobRunner(job=job1, task_instance=ti_run) time_start = time.time() # this should make sure we only heartbeat once and exit at the second @@ -452,7 +428,7 @@ def test_localtaskjob_maintain_heart_rate(self, mock_return_code, caplog, create mock_return_code.side_effect = [None, 0, None] with timeout(10): - run_job(job1) + run_job(job=job1, execute_callable=job_runner._execute) assert mock_return_code.call_count == 3 time_end = time.time() @@ -483,18 +459,12 @@ def test_mark_failure_on_failure_callback(self, caplog, get_test_dag): ti = dr.get_task_instance(task.task_id) ti.refresh_from_task(task) - job1 = Job( - job_runner=LocalTaskJobRunner( - task_instance=ti, - ignore_ti_state=True, - ), - dag_id=ti.dag_id, - executor=SequentialExecutor(), - ) + job1 = Job(dag_id=ti.dag_id, executor=SequentialExecutor()) + job_runner = LocalTaskJobRunner(job=job1, task_instance=ti, ignore_ti_state=True) with timeout(30): # This should be _much_ shorter to run. # If you change this limit, make the timeout in the callable above bigger - run_job(job1) + run_job(job=job1, execute_callable=job_runner._execute) ti.refresh_from_db() assert ti.state == State.FAILED @@ -521,18 +491,12 @@ def test_dagrun_timeout_logged_in_task_logs(self, caplog, get_test_dag): ti = dr.get_task_instance(task.task_id) ti.refresh_from_task(task) - job1 = Job( - job_runner=LocalTaskJobRunner( - task_instance=ti, - ignore_ti_state=True, - ), - dag_id=ti.dag_id, - executor=SequentialExecutor(), - ) + job1 = Job(dag_id=ti.dag_id, executor=SequentialExecutor()) + job_runner = LocalTaskJobRunner(job=job1, task_instance=ti, ignore_ti_state=True) with timeout(30): # This should be _much_ shorter to run. 
# If you change this limit, make the timeout in the callable above bigger - run_job(job1) + run_job(job=job1, execute_callable=job_runner._execute) ti.refresh_from_db() assert ti.state == State.SKIPPED @@ -557,12 +521,9 @@ def test_failure_callback_called_by_airflow_run_raw_process(self, monkeypatch, t ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() - job1 = Job( - job_runner=LocalTaskJobRunner(task_instance=ti, ignore_ti_state=True), - executor=SequentialExecutor(), - dag_id=ti.dag_id, - ) - run_job(job1) + job1 = Job(executor=SequentialExecutor(), dag_id=ti.dag_id) + job_runner = LocalTaskJobRunner(job1, task_instance=ti, ignore_ti_state=True) + run_job(job=job1, execute_callable=job_runner._execute) ti.refresh_from_db() assert ti.state == State.FAILED # task exits with failure state @@ -592,17 +553,11 @@ def test_mark_success_on_success_callback(self, caplog, get_test_dag): ti = dr.get_task_instance(task.task_id) ti.refresh_from_task(task) - job = Job( - job_runner=LocalTaskJobRunner( - task_instance=ti, - ignore_ti_state=True, - ), - executor=SequentialExecutor(), - dag_id=ti.dag_id, - ) - + job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id) + job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True) with timeout(30): - run_job(job) # This should run fast because of the return_code=None + # This should run fast because of the return_code=None + run_job(job=job, execute_callable=job_runner._execute) ti.refresh_from_db() assert ( "State of this instance has been externally set to success. Terminating instance." in caplog.text @@ -670,15 +625,9 @@ def send_signal(ti, signal_sent, sig): thread.daemon = True thread.start() - job1 = Job( - job_runner=LocalTaskJobRunner( - task_instance=ti, - ignore_ti_state=True, - ), - dag_id=ti.dag_id, - executor=SequentialExecutor(), - ) - run_job(job1) + job1 = Job(dag_id=ti.dag_id, executor=SequentialExecutor()) + job_runner = LocalTaskJobRunner(job=job1, task_instance=ti, ignore_ti_state=True) + run_job(job=job1, execute_callable=job_runner._execute) ti.refresh_from_db() @@ -758,8 +707,9 @@ def test_fast_follow( "test_dagrun_fast_follow", ) - scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - scheduler_job.job_runner.dagbag.bag_dag(dag, root_dag=dag) + scheduler_job = Job() + scheduler_job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + scheduler_job_runner.dagbag.bag_dag(dag, root_dag=dag) dag_run = dag.create_dagrun(run_id="test_dagrun_fast_follow", state=State.RUNNING) @@ -772,36 +722,24 @@ def test_fast_follow( ti = TaskInstance(task=dag.get_task(task_ids_to_run[0]), execution_date=dag_run.execution_date) ti.refresh_from_db() - job1 = Job( - job_runner=LocalTaskJobRunner( - task_instance=ti, - ignore_ti_state=True, - ), - executor=SequentialExecutor(), - dag_id=ti.dag_id, - ) - job1.task_runner = StandardTaskRunner(job1) + job1 = Job(executor=SequentialExecutor(), dag_id=ti.dag_id) + job_runner = LocalTaskJobRunner(job=job1, task_instance=ti, ignore_ti_state=True) + job1.task_runner = StandardTaskRunner(job_runner) - run_job(job1) + run_job(job=job1, execute_callable=job_runner._execute) self.validate_ti_states(dag_run, first_run_state, error_message) if second_run_state: ti = TaskInstance( task=dag.get_task(task_ids_to_run[1]), execution_date=dag_run.execution_date ) ti.refresh_from_db() - job2 = Job( - job_runner=LocalTaskJobRunner( - task_instance=ti, - ignore_ti_state=True, - ), - dag_id=ti.dag_id, - executor=SequentialExecutor(), - ) - 
job2.task_runner = StandardTaskRunner(job2) - run_job(job2) + job2 = Job(dag_id=ti.dag_id, executor=SequentialExecutor()) + job_runner = LocalTaskJobRunner(job=job2, task_instance=ti, ignore_ti_state=True) + job2.task_runner = StandardTaskRunner(job_runner) + run_job(job2, execute_callable=job_runner._execute) self.validate_ti_states(dag_run, second_run_state, error_message) - if scheduler_job.job_runner.processor_agent: - scheduler_job.job_runner.processor_agent.end() + if scheduler_job_runner.processor_agent: + scheduler_job_runner.processor_agent.end() @conf_vars({("scheduler", "schedule_after_task_execution"): "True"}) def test_mini_scheduler_works_with_wait_for_upstream(self, caplog, get_test_dag): @@ -829,15 +767,12 @@ def test_mini_scheduler_works_with_wait_for_upstream(self, caplog, get_test_dag) session.merge(ti2_l) job1 = Job( - job_runner=LocalTaskJobRunner( - task_instance=ti2_k, - ignore_ti_state=True, - ), executor=SequentialExecutor(), dag_id=ti2_k.dag_id, ) - job1.task_runner = StandardTaskRunner(job1) - run_job(job1) + job_runner = LocalTaskJobRunner(job=job1, task_instance=ti2_k, ignore_ti_state=True) + job1.task_runner = StandardTaskRunner(job_runner) + run_job(job=job1, execute_callable=job_runner._execute) ti2_k.refresh_from_db() ti2_l.refresh_from_db() @@ -876,18 +811,12 @@ def task_function(ti): dag_run = dag_maker.create_dagrun() ti = TaskInstance(task=task, run_id=dag_run.run_id) ti.refresh_from_db() - job = Job( - job_runner=LocalTaskJobRunner( - task_instance=ti, - ignore_ti_state=True, - ), - executor=SequentialExecutor(), - dag_id=ti.dag_id, - ) + job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id) + job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True) settings.engine.dispose() with timeout(10): with pytest.raises(AirflowException, match=r"Segmentation Fault detected"): - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) assert SIGSEGV_MESSAGE in caplog.messages @@ -913,9 +842,10 @@ def test_number_of_queries_single_loop(mock_get_task_runner, dag_maker): ti = dr.task_instances[0] ti.refresh_from_task(task) - job = Job(job_runner=LocalTaskJobRunner(task_instance=ti), dag_id=ti.dag_id, executor=MockExecutor()) + job = Job(dag_id=ti.dag_id, executor=MockExecutor()) + job_runner = LocalTaskJobRunner(job=job, task_instance=ti) with assert_queries_count(18): - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) class TestSigtermOnRunner: @@ -1028,12 +958,6 @@ def task_function(): dag.create_dagrun(state=State.RUNNING, run_id=run_id, execution_date=execution_date) ti = TaskInstance(task=task, execution_date=execution_date) ti.refresh_from_db() - job = Job( - job_runner=LocalTaskJobRunner( - task_instance=ti, - ignore_ti_state=True, - ), - executor=SequentialExecutor(), - dag_id=ti.dag_id, - ) - run_job(job) + job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id) + job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True) + run_job(job=job, execute_callable=job_runner._execute) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 49ea41b840726..4870a2aa132e9 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -79,7 +79,6 @@ ) from tests.test_utils.mock_executor import MockExecutor from tests.test_utils.mock_operators import CustomOperator -from tests.utils.test_helpers import MockJobRunner from tests.utils.test_timezone import UTC ROOT_FOLDER = os.path.realpath( @@ -138,13 +137,13 @@ def clean_db(): 
@pytest.fixture(autouse=True) def per_test(self) -> Generator: self.clean_db() - self.scheduler_job = None + self.job_runner = None yield - if self.scheduler_job and self.scheduler_job.job_runner.processor_agent: - self.scheduler_job.job_runner.processor_agent.end() - self.scheduler_job = None + if self.job_runner and self.job_runner.processor_agent: + self.job_runner.processor_agent.end() + self.job_runner = None self.clean_db() @pytest.fixture(autouse=True) @@ -171,24 +170,25 @@ def set_instance_attrs(self, dagbag) -> Generator: ) def test_is_alive(self, configs): with conf_vars(configs): - self.scheduler_job = Job(job_runner=SchedulerJobRunner(None), heartrate=10, state=State.RUNNING) - assert self.scheduler_job.is_alive() + scheduler_job = Job(heartrate=10, state=State.RUNNING) + self.job_runner = SchedulerJobRunner(scheduler_job) + assert scheduler_job.is_alive() - self.scheduler_job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=20) - assert self.scheduler_job.is_alive() + scheduler_job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=20) + assert scheduler_job.is_alive() - self.scheduler_job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=31) - assert not self.scheduler_job.is_alive() + scheduler_job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=31) + assert not scheduler_job.is_alive() # test because .seconds was used before instead of total_seconds # internal repr of datetime is (days, seconds) - self.scheduler_job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(days=1) - assert not self.scheduler_job.is_alive() + scheduler_job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(days=1) + assert not scheduler_job.is_alive() - self.scheduler_job.state = State.SUCCESS - self.scheduler_job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=10) + scheduler_job.state = State.SUCCESS + scheduler_job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=10) assert ( - not self.scheduler_job.is_alive() + not scheduler_job.is_alive() ), "Completed jobs even with recent heartbeat should not be alive" def run_single_scheduler_loop_with_no_dags(self, dags_folder): @@ -200,23 +200,24 @@ def run_single_scheduler_loop_with_no_dags(self, dags_folder): :param dags_folder: the directory to traverse """ - self.scheduler_job = Job( - job_runner=SchedulerJobRunner( - executor=self.null_exec, num_times_parse_dags=1, subdir=os.path.join(dags_folder) - ) + scheduler_job = Job( + executor=self.null_exec, + num_times_parse_dags=1, + subdir=os.path.join(dags_folder), ) - self.scheduler_job.heartrate = 0 - run_job(self.scheduler_job) + self.job_runner = SchedulerJobRunner(scheduler_job) + scheduler_job.heartrate = 0 + run_job(scheduler_job, execute_callable=self.job_runner._execute) def test_no_orphan_process_will_be_left(self): empty_dir = mkdtemp() current_process = psutil.Process() old_children = current_process.children(recursive=True) - self.scheduler_job = Job( - job_runner=SchedulerJobRunner(subdir=empty_dir, num_runs=1), + scheduler_job = Job( executor=MockExecutor(do_update=False), ) - run_job(self.scheduler_job) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=empty_dir, num_runs=1) + run_job(scheduler_job, execute_callable=self.job_runner._execute) shutil.rmtree(empty_dir) # Remove potential noise created by previous tests. 
@@ -239,19 +240,20 @@ def test_process_executor_events(self, mock_stats_incr, mock_task_callback, dag_ executor = MockExecutor(do_update=False) task_callback = mock.MagicMock() mock_task_callback.return_value = task_callback - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=executor) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=executor) + self.job_runner = SchedulerJobRunner(scheduler_job) + self.job_runner.processor_agent = mock.MagicMock() ti1.state = State.QUEUED session.merge(ti1) session.commit() executor.event_buffer[ti1.key] = State.FAILED, None - self.scheduler_job.job_runner._process_executor_events(session=session) + self.job_runner._process_executor_events(session=session) ti1.refresh_from_db(session=session) assert ti1.state == State.FAILED - self.scheduler_job.executor.callback_sink.send.assert_not_called() - self.scheduler_job.job_runner.processor_agent.reset_mock() + scheduler_job.executor.callback_sink.send.assert_not_called() + self.job_runner.processor_agent.reset_mock() # ti in success state ti1.state = State.SUCCESS @@ -259,10 +261,10 @@ def test_process_executor_events(self, mock_stats_incr, mock_task_callback, dag_ session.commit() executor.event_buffer[ti1.key] = State.SUCCESS, None - self.scheduler_job.job_runner._process_executor_events(session=session) + self.job_runner._process_executor_events(session=session) ti1.refresh_from_db(session=session) assert ti1.state == State.SUCCESS - self.scheduler_job.executor.callback_sink.send.assert_not_called() + scheduler_job.executor.callback_sink.send.assert_not_called() mock_stats_incr.assert_has_calls( [ mock.call( @@ -287,8 +289,9 @@ def test_process_executor_events_with_no_callback(self, mock_stats_incr, mock_ta executor = MockExecutor(do_update=False) task_callback = mock.MagicMock() mock_task_callback.return_value = task_callback - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=executor) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=executor) + self.job_runner = SchedulerJobRunner(scheduler_job) + self.job_runner.processor_agent = mock.MagicMock() session = settings.Session() with dag_maker(dag_id=dag_id, fileloc="/test_path1/"): @@ -302,18 +305,19 @@ def test_process_executor_events_with_no_callback(self, mock_stats_incr, mock_ta executor = MockExecutor(do_update=False) task_callback = mock.MagicMock() mock_task_callback.return_value = task_callback - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=executor) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=executor) + self.job_runner = SchedulerJobRunner(scheduler_job) + self.job_runner.processor_agent = mock.MagicMock() ti1.state = State.QUEUED session.merge(ti1) session.commit() executor.event_buffer[ti1.key] = State.FAILED, None - self.scheduler_job.job_runner._process_executor_events(session=session) + self.job_runner._process_executor_events(session=session) ti1.refresh_from_db(session=session) assert ti1.state == State.UP_FOR_RETRY - self.scheduler_job.executor.callback_sink.send.assert_not_called() + scheduler_job.executor.callback_sink.send.assert_not_called() # ti in success state ti1.state = State.SUCCESS @@ -321,10 +325,10 @@ def test_process_executor_events_with_no_callback(self, mock_stats_incr, mock_ta session.commit() executor.event_buffer[ti1.key] = State.SUCCESS, None - self.scheduler_job.job_runner._process_executor_events(session=session) 
+ self.job_runner._process_executor_events(session=session) ti1.refresh_from_db(session=session) assert ti1.state == State.SUCCESS - self.scheduler_job.executor.callback_sink.send.assert_not_called() + scheduler_job.executor.callback_sink.send.assert_not_called() mock_stats_incr.assert_has_calls( [ mock.call( @@ -352,8 +356,9 @@ def test_process_executor_events_with_callback(self, mock_stats_incr, mock_task_ executor = MockExecutor(do_update=False) task_callback = mock.MagicMock() mock_task_callback.return_value = task_callback - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=executor) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=executor) + self.job_runner = SchedulerJobRunner(scheduler_job) + self.job_runner.processor_agent = mock.MagicMock() session = settings.Session() ti1.state = State.QUEUED @@ -362,7 +367,7 @@ def test_process_executor_events_with_callback(self, mock_stats_incr, mock_task_ executor.event_buffer[ti1.key] = State.FAILED, None - self.scheduler_job.job_runner._process_executor_events(session=session) + self.job_runner._process_executor_events(session=session) ti1.refresh_from_db() # The state will remain in queued here and # will be set to failed in dag parsing process @@ -376,8 +381,8 @@ def test_process_executor_events_with_callback(self, mock_stats_incr, mock_task_ "finished (failed) although the task says its queued. (Info: None) " "Was the task killed externally?", ) - self.scheduler_job.executor.callback_sink.send.assert_called_once_with(task_callback) - self.scheduler_job.executor.callback_sink.reset_mock() + scheduler_job.executor.callback_sink.send.assert_called_once_with(task_callback) + scheduler_job.executor.callback_sink.reset_mock() mock_stats_incr.assert_called_once_with( "scheduler.tasks.killed_externally", tags={ @@ -402,10 +407,11 @@ def test_process_executor_event_missing_dag(self, mock_stats_incr, mock_task_cal executor = MockExecutor(do_update=False) task_callback = mock.MagicMock() mock_task_callback.return_value = task_callback - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=executor) - self.scheduler_job.job_runner.dagbag = mock.MagicMock() - self.scheduler_job.job_runner.dagbag.get_dag.side_effect = Exception("failed") - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=executor) + self.job_runner = SchedulerJobRunner(scheduler_job) + self.job_runner.dagbag = mock.MagicMock() + self.job_runner.dagbag.get_dag.side_effect = Exception("failed") + self.job_runner.processor_agent = mock.MagicMock() session = settings.Session() ti1.state = State.QUEUED @@ -413,7 +419,7 @@ def test_process_executor_event_missing_dag(self, mock_stats_incr, mock_task_cal session.commit() executor.event_buffer[ti1.key] = State.FAILED, None - self.scheduler_job.job_runner._process_executor_events(session=session) + self.job_runner._process_executor_events(session=session) ti1.refresh_from_db() assert ti1.state == State.FAILED @@ -433,9 +439,10 @@ def test_process_executor_events_ti_requeued(self, mock_stats_incr, mock_task_ca executor = MockExecutor(do_update=False) task_callback = mock.MagicMock() mock_task_callback.return_value = task_callback - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=executor) - self.scheduler_job.id = 1 - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=executor) + self.job_runner = SchedulerJobRunner(scheduler_job) + scheduler_job.id = 1 +
self.job_runner.processor_agent = mock.MagicMock() # ti is queued with another try number - do not fail it ti1.state = State.QUEUED @@ -446,10 +453,10 @@ def test_process_executor_events_ti_requeued(self, mock_stats_incr, mock_task_ca executor.event_buffer[ti1.key.with_try_number(1)] = State.SUCCESS, None - self.scheduler_job.job_runner._process_executor_events(session=session) + self.job_runner._process_executor_events(session=session) ti1.refresh_from_db(session=session) assert ti1.state == State.QUEUED - self.scheduler_job.executor.callback_sink.send.assert_not_called() + scheduler_job.executor.callback_sink.send.assert_not_called() # ti is queued by another scheduler - do not fail it ti1.state = State.QUEUED @@ -459,10 +466,10 @@ def test_process_executor_events_ti_requeued(self, mock_stats_incr, mock_task_ca executor.event_buffer[ti1.key] = State.SUCCESS, None - self.scheduler_job.job_runner._process_executor_events(session=session) + self.job_runner._process_executor_events(session=session) ti1.refresh_from_db(session=session) assert ti1.state == State.QUEUED - self.scheduler_job.executor.callback_sink.send.assert_not_called() + scheduler_job.executor.callback_sink.send.assert_not_called() # ti is queued by this scheduler but it is handed back to the executor - do not fail it ti1.state = State.QUEUED @@ -473,10 +480,10 @@ def test_process_executor_events_ti_requeued(self, mock_stats_incr, mock_task_ca executor.event_buffer[ti1.key] = State.SUCCESS, None executor.has_task = mock.MagicMock(return_value=True) - self.scheduler_job.job_runner._process_executor_events(session=session) + self.job_runner._process_executor_events(session=session) ti1.refresh_from_db(session=session) assert ti1.state == State.QUEUED - self.scheduler_job.executor.callback_sink.send.assert_not_called() + scheduler_job.executor.callback_sink.send.assert_not_called() mock_stats_incr.assert_not_called() def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker): @@ -487,13 +494,13 @@ def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker) EmptyOperator(task_id=task_id_1) assert isinstance(dag, SerializedDAG) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) dr1 = dag_maker.create_dagrun(run_type=DagRunType.BACKFILL_JOB) (ti1,) = dr1.task_instances ti1.state = State.SCHEDULED - self.scheduler_job.job_runner._critical_section_enqueue_task_instances(session) + self.job_runner._critical_section_enqueue_task_instances(session) session.flush() ti1.refresh_from_db(session=session) assert State.SCHEDULED == ti1.state @@ -509,7 +516,8 @@ def test_execute_task_instances_backfill_tasks_wont_execute(self, dag_maker): with dag_maker(dag_id=dag_id): task1 = EmptyOperator(task_id=task_id_1) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() dr1 = dag_maker.create_dagrun(run_type=DagRunType.BACKFILL_JOB) @@ -522,7 +530,7 @@ def test_execute_task_instances_backfill_tasks_wont_execute(self, dag_maker): assert dr1.is_backfill - self.scheduler_job.job_runner._critical_section_enqueue_task_instances(session) + self.job_runner._critical_section_enqueue_task_instances(session) session.flush() ti1.refresh_from_db() assert State.SCHEDULED == ti1.state @@ -530,19 +538,19 @@ def 
test_execute_task_instances_backfill_tasks_wont_execute(self, dag_maker): @conf_vars({("scheduler", "standalone_dag_processor"): "False"}) def test_setup_callback_sink_not_standalone_dag_processor(self): - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull, num_runs=1)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) + self.job_runner._execute() - self.scheduler_job.job_runner._execute() - - assert isinstance(self.scheduler_job.executor.callback_sink, PipeCallbackSink) + assert isinstance(scheduler_job.executor.callback_sink, PipeCallbackSink) @conf_vars({("scheduler", "standalone_dag_processor"): "True"}) def test_setup_callback_sink_standalone_dag_processor(self): - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull, num_runs=1)) - - self.scheduler_job.job_runner._execute() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) + self.job_runner._execute() - assert isinstance(self.scheduler_job.executor.callback_sink, DatabaseCallbackSink) + assert isinstance(scheduler_job.executor.callback_sink, DatabaseCallbackSink) def test_find_executable_task_instances_backfill(self, dag_maker): dag_id = "SchedulerJobTest.test_find_executable_task_instances_backfill" @@ -550,7 +558,8 @@ def test_find_executable_task_instances_backfill(self, dag_maker): with dag_maker(dag_id=dag_id, max_active_tasks=16): task1 = EmptyOperator(task_id=task_id_1) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) @@ -567,7 +576,7 @@ def test_find_executable_task_instances_backfill(self, dag_maker): session.merge(ti_with_dagrun) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 1 == len(res) res_keys = map(lambda x: x.key, res) assert ti_with_dagrun.key in res_keys @@ -582,7 +591,8 @@ def test_find_executable_task_instances_pool(self, dag_maker): EmptyOperator(task_id=task_id_1, pool="a", priority_weight=2) EmptyOperator(task_id=task_id_2, pool="b", priority_weight=1) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED) @@ -603,7 +613,7 @@ def test_find_executable_task_instances_pool(self, dag_maker): session.add(pool2) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) session.flush() assert 3 == len(res) res_keys = [] @@ -635,7 +645,8 @@ def test_find_executable_task_instances_only_running_dagruns( EmptyOperator(task_id=task_id_1) EmptyOperator(task_id=task_id_2) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) dr = dag_maker.create_dagrun(state=state) @@ -644,7 +655,7 @@ def test_find_executable_task_instances_only_running_dagruns( ti.state 
= State.SCHEDULED session.merge(ti) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) session.flush() assert total_executed_ti == len(res) @@ -668,7 +679,8 @@ def test_find_executable_task_instances_order_execution_date(self, dag_maker): dr1 = session.merge(dr1, load=False) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) tis = dr1.task_instances + dr2.task_instances for ti in tis: @@ -676,7 +688,7 @@ def test_find_executable_task_instances_order_execution_date(self, dag_maker): session.merge(ti) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) session.flush() assert [ti.key for ti in res] == [tis[1].key] session.rollback() @@ -696,7 +708,8 @@ def test_find_executable_task_instances_order_priority(self, dag_maker): dr1 = session.merge(dr1, load=False) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) tis = dr1.task_instances + dr2.task_instances for ti in tis: @@ -704,7 +717,7 @@ def test_find_executable_task_instances_order_priority(self, dag_maker): session.merge(ti) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) session.flush() assert [ti.key for ti in res] == [tis[1].key] session.rollback() @@ -715,7 +728,8 @@ def test_find_executable_task_instances_order_priority_with_pools(self, dag_make even if different pools are involved. 
""" - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() dag_id = "SchedulerJobTest.test_find_executable_task_instances_order_priority_with_pools" @@ -740,7 +754,7 @@ def test_find_executable_task_instances_order_priority_with_pools(self, dag_make session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 2 == len(res) assert ti3.key == res[0].key @@ -762,7 +776,8 @@ def test_find_executable_task_instances_order_execution_date_and_priority(self, dr2 = dag_maker.create_dagrun(execution_date=DEFAULT_DATE + timedelta(hours=1)) dr1 = session.merge(dr1, load=False) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) tis = dr1.task_instances + dr2.task_instances for ti in tis: @@ -770,7 +785,7 @@ def test_find_executable_task_instances_order_execution_date_and_priority(self, session.merge(ti) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) session.flush() assert [ti.key for ti in res] == [tis[1].key] session.rollback() @@ -782,9 +797,9 @@ def test_find_executable_task_instances_in_default_pool(self, dag_maker): with dag_maker(dag_id=dag_id): op1 = EmptyOperator(task_id="dummy1") op2 = EmptyOperator(task_id="dummy2") + scheduler_job = Job(executor=MockExecutor()) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) - executor = MockExecutor(do_update=True) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=executor) session = settings.Session() dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) @@ -798,14 +813,14 @@ def test_find_executable_task_instances_in_default_pool(self, dag_maker): session.flush() # Two tasks w/o pool up for execution and our default pool size is 1 - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 1 == len(res) ti2.state = State.RUNNING session.flush() # One task w/o pool up for execution and one task running - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 0 == len(res) session.rollback() @@ -821,9 +836,11 @@ def test_queued_task_instances_fails_with_missing_dag(self, dag_maker, session): EmptyOperator(task_id=task_id_1) EmptyOperator(task_id=task_id_2) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.dagbag = mock.MagicMock() - self.scheduler_job.job_runner.dagbag.get_dag.return_value = None + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.dagbag = mock.MagicMock() + self.job_runner.dagbag.get_dag.return_value = None dr = dag_maker.create_dagrun(state=DagRunState.RUNNING) @@ -832,7 +849,7 @@ def test_queued_task_instances_fails_with_missing_dag(self, dag_maker, session): ti.state = 
State.SCHEDULED session.merge(ti) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) session.flush() assert 0 == len(res) tis = dr.get_task_instances(session=session) @@ -844,7 +861,8 @@ def test_nonexistent_pool(self, dag_maker): with dag_maker(dag_id=dag_id, max_active_tasks=16): EmptyOperator(task_id="dummy_wrong_pool", pool="this_pool_doesnt_exist") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() dr = dag_maker.create_dagrun() @@ -854,7 +872,7 @@ def test_nonexistent_pool(self, dag_maker): session.merge(ti) session.commit() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) session.flush() assert 0 == len(res) session.rollback() @@ -864,7 +882,8 @@ def test_infinite_pool(self, dag_maker): with dag_maker(dag_id=dag_id, concurrency=16): EmptyOperator(task_id="dummy", pool="infinite_pool") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() dr = dag_maker.create_dagrun() @@ -875,7 +894,7 @@ def test_infinite_pool(self, dag_maker): session.add(infinite_pool) session.commit() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) session.flush() assert 1 == len(res) session.rollback() @@ -886,7 +905,8 @@ def test_not_enough_pool_slots(self, caplog, dag_maker): EmptyOperator(task_id="cannot_run", pool="some_pool", pool_slots=4) EmptyOperator(task_id="can_run", pool="some_pool", pool_slots=1) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() dr = dag_maker.create_dagrun() ti = dr.task_instances[0] @@ -899,7 +919,7 @@ def test_not_enough_pool_slots(self, caplog, dag_maker): session.add(some_pool) session.commit() with caplog.at_level(logging.WARNING): - self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert ( "Not executing . 
" @@ -929,12 +949,11 @@ def test_find_executable_task_instances_none(self, dag_maker): with dag_maker(dag_id=dag_id, max_active_tasks=16): EmptyOperator(task_id=task_id_1) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() - assert 0 == len( - self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) - ) + assert 0 == len(self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)) session.rollback() def test_tis_for_queued_dagruns_are_not_run(self, dag_maker): @@ -948,7 +967,8 @@ def test_tis_for_queued_dagruns_are_not_run(self, dag_maker): task1 = EmptyOperator(task_id=task_id_1) dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.QUEUED) dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() ti1 = TaskInstance(task1, run_id=dr1.run_id) ti2 = TaskInstance(task1, run_id=dr2.run_id) @@ -957,7 +977,7 @@ def test_tis_for_queued_dagruns_are_not_run(self, dag_maker): session.merge(ti1) session.merge(ti2) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 1 == len(res) assert ti2.key == res[0].key @@ -972,7 +992,8 @@ def test_find_executable_task_instances_concurrency(self, dag_maker): with dag_maker(dag_id=dag_id, max_active_tasks=2, session=session): EmptyOperator(task_id="dummy") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED) @@ -990,7 +1011,7 @@ def test_find_executable_task_instances_concurrency(self, dag_maker): session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 1 == len(res) res_keys = map(lambda x: x.key, res) @@ -1000,7 +1021,7 @@ def test_find_executable_task_instances_concurrency(self, dag_maker): session.merge(ti2) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 0 == len(res) session.rollback() @@ -1012,7 +1033,8 @@ def test_find_executable_task_instances_concurrency_queued(self, dag_maker): task2 = EmptyOperator(task_id="dummy2") task3 = EmptyOperator(task_id="dummy3") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() dag_run = dag_maker.create_dagrun() @@ -1030,7 +1052,7 @@ def test_find_executable_task_instances_concurrency_queued(self, dag_maker): session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = 
self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 1 == len(res) assert res[0].key == ti3.key @@ -1046,7 +1068,9 @@ def test_find_executable_task_instances_max_active_tis_per_dag(self, dag_maker): task2 = EmptyOperator(task_id=task_id_2) executor = MockExecutor(do_update=True) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=executor) + + scheduler_job = Job(executor=executor) + self.job_runner = SchedulerJobRunner(job=scheduler_job) session = settings.Session() dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) @@ -1062,7 +1086,7 @@ def test_find_executable_task_instances_max_active_tis_per_dag(self, dag_maker): session.merge(ti2) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 2 == len(res) @@ -1075,7 +1099,7 @@ def test_find_executable_task_instances_max_active_tis_per_dag(self, dag_maker): session.merge(ti1_2) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 1 == len(res) @@ -1086,7 +1110,7 @@ def test_find_executable_task_instances_max_active_tis_per_dag(self, dag_maker): session.merge(ti1_3) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 0 == len(res) @@ -1098,7 +1122,7 @@ def test_find_executable_task_instances_max_active_tis_per_dag(self, dag_maker): session.merge(ti1_3) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 2 == len(res) @@ -1110,7 +1134,7 @@ def test_find_executable_task_instances_max_active_tis_per_dag(self, dag_maker): session.merge(ti1_3) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 1 == len(res) session.rollback() @@ -1121,7 +1145,8 @@ def test_change_state_for_executable_task_instances_no_tis_with_state(self, dag_ with dag_maker(dag_id=dag_id, max_active_tasks=2): task1 = EmptyOperator(task_id=task_id_1) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) @@ -1140,7 +1165,7 @@ def test_change_state_for_executable_task_instances_no_tis_with_state(self, dag_ session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=100, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=100, session=session) assert 0 == len(res) session.rollback() @@ -1148,7 +1173,8 @@ def test_change_state_for_executable_task_instances_no_tis_with_state(self, dag_ def test_find_executable_task_instances_not_enough_pool_slots_for_first(self, dag_maker): set_default_pool_slots(1) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = 
SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() dag_id = "SchedulerJobTest.test_find_executable_task_instances_not_enough_pool_slots_for_first" @@ -1166,14 +1192,15 @@ def test_find_executable_task_instances_not_enough_pool_slots_for_first(self, da # Schedule ti with lower priority, # because the one with higher priority is limited by a concurrency limit - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 1 == len(res) assert res[0].key == ti2.key session.rollback() def test_find_executable_task_instances_not_enough_dag_concurrency_for_first(self, dag_maker): - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() dag_id_1 = ( @@ -1202,14 +1229,15 @@ def test_find_executable_task_instances_not_enough_dag_concurrency_for_first(sel # Schedule ti with lower priority, # because the one with higher priority is limited by a concurrency limit - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) assert 1 == len(res) assert res[0].key == ti2.key session.rollback() def test_find_executable_task_instances_not_enough_task_concurrency_for_first(self, dag_maker): - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() dag_id = "SchedulerJobTest.test_find_executable_task_instances_not_enough_task_concurrency_for_first" @@ -1230,7 +1258,7 @@ def test_find_executable_task_instances_not_enough_task_concurrency_for_first(se # Schedule ti with lower priority, # because the one with higher priority is limited by a concurrency limit - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) assert 1 == len(res) assert res[0].key == ti1b.key @@ -1244,7 +1272,8 @@ def test_find_executable_task_instances_negative_open_pool_slots(self, dag_maker """ set_default_pool_slots(0) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() pool1 = Pool(pool="pool1", slots=1) @@ -1266,7 +1295,7 @@ def test_find_executable_task_instances_negative_open_pool_slots(self, dag_maker ti2.state = State.RUNNING session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) assert 1 == len(res) assert res[0].key == ti1.key @@ -1274,7 +1303,8 @@ def test_find_executable_task_instances_negative_open_pool_slots(self, dag_maker @mock.patch("airflow.jobs.scheduler_job_runner.Stats.gauge") def test_emit_pool_starving_tasks_metrics(self, mock_stats_gauge, dag_maker): - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() dag_id = 
"SchedulerJobTest.test_emit_pool_starving_tasks_metrics" @@ -1289,7 +1319,7 @@ def test_emit_pool_starving_tasks_metrics(self, mock_stats_gauge, dag_maker): set_default_pool_slots(1) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 0 == len(res) mock_stats_gauge.assert_has_calls( @@ -1304,7 +1334,7 @@ def test_emit_pool_starving_tasks_metrics(self, mock_stats_gauge, dag_maker): set_default_pool_slots(2) session.flush() - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert 1 == len(res) mock_stats_gauge.assert_has_calls( @@ -1325,13 +1355,14 @@ def test_enqueue_task_instances_with_queued_state(self, dag_maker, session): with dag_maker(dag_id=dag_id, start_date=DEFAULT_DATE, session=session): task1 = EmptyOperator(task_id=task_id_1) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) dr1 = dag_maker.create_dagrun() ti1 = dr1.get_task_instance(task1.task_id, session) with patch.object(BaseExecutor, "queue_command") as mock_queue_command: - self.scheduler_job.job_runner._enqueue_task_instances_with_queued_state([ti1], session=session) + self.job_runner._enqueue_task_instances_with_queued_state([ti1], session=session) assert mock_queue_command.called session.rollback() @@ -1345,7 +1376,8 @@ def test_enqueue_task_instances_sets_ti_state_to_None_if_dagrun_in_finish_state( with dag_maker(dag_id=dag_id, start_date=DEFAULT_DATE, session=session): task1 = EmptyOperator(task_id=task_id_1) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) dr1 = dag_maker.create_dagrun(state=state) ti = dr1.get_task_instance(task1.task_id, session) @@ -1354,7 +1386,7 @@ def test_enqueue_task_instances_sets_ti_state_to_None_if_dagrun_in_finish_state( session.commit() with patch.object(BaseExecutor, "queue_command") as mock_queue_command: - self.scheduler_job.job_runner._enqueue_task_instances_with_queued_state([ti], session=session) + self.job_runner._enqueue_task_instances_with_queued_state([ti], session=session) session.flush() ti.refresh_from_db(session=session) assert ti.state == State.NONE @@ -1373,7 +1405,8 @@ def test_critical_section_enqueue_task_instances(self, dag_maker): task1 = EmptyOperator(task_id=task_id_1) task2 = EmptyOperator(task_id=task_id_2) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) # create first dag run with 1 running and 1 queued @@ -1399,7 +1432,7 @@ def test_critical_section_enqueue_task_instances(self, dag_maker): assert State.RUNNING == dr2.state - res = self.scheduler_job.job_runner._critical_section_enqueue_task_instances(session) + res = self.job_runner._critical_section_enqueue_task_instances(session) # check that max_active_tasks is respected ti1.refresh_from_db() @@ -1427,7 +1460,8 @@ def test_execute_task_instances_limit(self, dag_maker): task1 = EmptyOperator(task_id=task_id_1) task2 = EmptyOperator(task_id=task_id_2) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) 
+ scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) def _create_dagruns(): dagrun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.RUNNING) @@ -1447,20 +1481,20 @@ def _create_dagruns(): ti1.state = State.SCHEDULED ti2.state = State.SCHEDULED session.flush() - self.scheduler_job.max_tis_per_query = 2 - res = self.scheduler_job.job_runner._critical_section_enqueue_task_instances(session) + scheduler_job.max_tis_per_query = 2 + res = self.job_runner._critical_section_enqueue_task_instances(session) assert 2 == res - self.scheduler_job.max_tis_per_query = 8 + scheduler_job.max_tis_per_query = 8 with mock.patch.object( - type(self.scheduler_job.executor), "slots_available", new_callable=mock.PropertyMock + type(scheduler_job.executor), "slots_available", new_callable=mock.PropertyMock ) as mock_slots: mock_slots.return_value = 2 # Check that we don't "overfill" the executor assert 2 == res - res = self.scheduler_job.job_runner._critical_section_enqueue_task_instances(session) + res = self.job_runner._critical_section_enqueue_task_instances(session) - res = self.scheduler_job.job_runner._critical_section_enqueue_task_instances(session) + res = self.job_runner._critical_section_enqueue_task_instances(session) assert 4 == res for ti in tis: ti.refresh_from_db() @@ -1478,7 +1512,8 @@ def test_execute_task_instances_unlimited(self, dag_maker): task1 = EmptyOperator(task_id=task_id_1) task2 = EmptyOperator(task_id=task_id_2) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) def _create_dagruns(): dagrun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.RUNNING) @@ -1497,10 +1532,10 @@ def _create_dagruns(): ti1.state = State.SCHEDULED ti2.state = State.SCHEDULED session.flush() - self.scheduler_job.max_tis_per_query = 0 - self.scheduler_job.executor = MagicMock(slots_available=36) + scheduler_job.max_tis_per_query = 0 + scheduler_job.executor = MagicMock(slots_available=36) - res = self.scheduler_job.job_runner._critical_section_enqueue_task_instances(session) + res = self.job_runner._critical_section_enqueue_task_instances(session) # 20 dag runs * 2 tasks each = 40, but limited by number of slots available assert res == 36 session.rollback() @@ -1526,10 +1561,10 @@ def test_adopt_or_reset_orphaned_tasks(self, dag_maker): processor = mock.MagicMock() - self.scheduler_job = Job(job_runner=SchedulerJobRunner(num_runs=0)) - self.scheduler_job.job_runner.processor_agent = processor - - self.scheduler_job.job_runner.adopt_or_reset_orphaned_tasks() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=0) + self.job_runner.processor_agent = processor + self.job_runner.adopt_or_reset_orphaned_tasks() ti = dr.get_task_instance(task_id=op1.task_id, session=session) assert ti.state == State.NONE @@ -1542,30 +1577,28 @@ def test_executor_end_called(self, mock_processor_agent): """ Test to make sure executor.end gets called with a successful scheduler loop run """ - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull, num_runs=1)) - self.scheduler_job.executor = mock.MagicMock(slots_available=8) - - run_job(self.scheduler_job) - - self.scheduler_job.executor.end.assert_called_once() - self.scheduler_job.job_runner.processor_agent.end.assert_called_once() + scheduler_job = Job(executor=mock.MagicMock(slots_available=8)) + self.job_runner = 
SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) + run_job(scheduler_job, execute_callable=self.job_runner._execute) + scheduler_job.executor.end.assert_called_once() + self.job_runner.processor_agent.end.assert_called_once() @mock.patch("airflow.dag_processing.manager.DagFileProcessorAgent") def test_cleanup_methods_all_called(self, mock_processor_agent): """ Test to make sure all cleanup methods are called when the scheduler loop has an exception """ - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull, num_runs=1)) - self.scheduler_job.executor = mock.MagicMock(slots_available=8) - self.scheduler_job.job_runner._run_scheduler_loop = mock.MagicMock(side_effect=Exception("oops")) + scheduler_job = Job(executor=mock.MagicMock(slots_available=8)) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) + self.job_runner._run_scheduler_loop = mock.MagicMock(side_effect=Exception("oops")) mock_processor_agent.return_value.end.side_effect = Exception("double oops") - self.scheduler_job.executor.end = mock.MagicMock(side_effect=Exception("triple oops")) + scheduler_job.executor.end = mock.MagicMock(side_effect=Exception("triple oops")) with pytest.raises(Exception): - run_job(self.scheduler_job) + run_job(scheduler_job, execute_callable=self.job_runner._execute) - self.scheduler_job.job_runner.processor_agent.end.assert_called_once() - self.scheduler_job.executor.end.assert_called_once() + self.job_runner.processor_agent.end.assert_called_once() + scheduler_job.executor.end.assert_called_once() mock_processor_agent.return_value.end.reset_mock(side_effect=True) def test_queued_dagruns_stops_creating_when_max_active_is_reached(self, dag_maker): @@ -1574,17 +1607,19 @@ def test_queued_dagruns_stops_creating_when_max_active_is_reached(self, dag_make EmptyOperator(task_id="mytask") session = settings.Session() - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = MockExecutor() + self.job_runner.processor_agent = mock.MagicMock() - self.scheduler_job.job_runner.dagbag = dag_maker.dagbag + self.job_runner.dagbag = dag_maker.dagbag session = settings.Session() orm_dag = session.get(DagModel, dag.dag_id) assert orm_dag is not None for _ in range(20): - self.scheduler_job.job_runner._create_dag_runs([orm_dag], session) + self.job_runner._create_dag_runs([orm_dag], session) drs = session.query(DagRun).all() assert len(drs) == 10 @@ -1594,7 +1629,7 @@ def test_queued_dagruns_stops_creating_when_max_active_is_reached(self, dag_make session.commit() assert session.query(DagRun.state).filter(DagRun.state == State.RUNNING).count() == 10 for _ in range(20): - self.scheduler_job.job_runner._create_dag_runs([orm_dag], session) + self.job_runner._create_dag_runs([orm_dag], session) assert session.query(DagRun).count() == 10 assert session.query(DagRun.state).filter(DagRun.state == State.RUNNING).count() == 10 assert session.query(DagRun.state).filter(DagRun.state == State.QUEUED).count() == 0 @@ -1604,9 +1639,11 @@ def test_runs_are_created_after_max_active_runs_was_reached(self, dag_maker, ses """ Test that when creating runs once max_active_runs is reached the runs does not stick """ - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - 
self.scheduler_job.executor = MockExecutor(do_update=True) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = MockExecutor(do_update=True) + self.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) with dag_maker(max_active_runs=1, session=session) as dag: # Need to use something that doesn't immediately get marked as success by the scheduler @@ -1619,7 +1656,7 @@ def test_runs_are_created_after_max_active_runs_was_reached(self, dag_maker, ses # Reach max_active_runs for _ in range(3): - self.scheduler_job.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) # Complete dagrun # Add dag_run back in to the session (_do_scheduling does an expunge_all) @@ -1629,7 +1666,7 @@ def test_runs_are_created_after_max_active_runs_was_reached(self, dag_maker, ses # create new run for _ in range(3): - self.scheduler_job.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) # Assert that new runs has created dag_runs = DagRun.find(dag_id=dag.dag_id, session=session) @@ -1652,15 +1689,17 @@ def test_dagrun_timeout_verify_max_active_runs(self, dag_maker): ) as dag: EmptyOperator(task_id="dummy") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.dagbag = dag_maker.dagbag + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.dagbag = dag_maker.dagbag session = settings.Session() orm_dag = session.get(DagModel, dag.dag_id) assert orm_dag is not None - self.scheduler_job.job_runner._create_dag_runs([orm_dag], session) - self.scheduler_job.job_runner._start_queued_dagruns(session) + self.job_runner._create_dag_runs([orm_dag], session) + self.job_runner._start_queued_dagruns(session) drs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(drs) == 1 @@ -1677,9 +1716,9 @@ def test_dagrun_timeout_verify_max_active_runs(self, dag_maker): session.flush() # Mock that processor_agent is started - self.scheduler_job.job_runner.processor_agent = mock.Mock() + self.job_runner.processor_agent = mock.Mock() - callback = self.scheduler_job.job_runner._schedule_dag_run(dr, session) + callback = self.job_runner._schedule_dag_run(dr, session) session.flush() session.refresh(dr) @@ -1720,13 +1759,15 @@ def test_dagrun_timeout_fails_run(self, dag_maker): dr = dag_maker.create_dagrun(start_date=timezone.utcnow() - datetime.timedelta(days=1)) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.dagbag = dag_maker.dagbag + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.dagbag = dag_maker.dagbag # Mock that processor_agent is started - self.scheduler_job.job_runner.processor_agent = mock.Mock() + self.job_runner.processor_agent = mock.Mock() - callback = self.scheduler_job.job_runner._schedule_dag_run(dr, session) + callback = self.job_runner._schedule_dag_run(dr, session) session.flush() session.refresh(dr) @@ -1762,14 +1803,16 @@ def test_dagrun_timeout_fails_run_and_update_next_dagrun(self, dag_maker): dr = dag_maker.create_dagrun(start_date=timezone.utcnow() - datetime.timedelta(days=1)) # check that next_dagrun is dr.execution_date dag_maker.dag_model.next_dagrun == dr.execution_date - self.scheduler_job = 
Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.dagbag = dag_maker.dagbag - self.scheduler_job.executor = MockExecutor() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.dagbag = dag_maker.dagbag + scheduler_job.executor = MockExecutor() # Mock that processor_agent is started - self.scheduler_job.job_runner.processor_agent = mock.Mock() + self.job_runner.processor_agent = mock.Mock() - self.scheduler_job.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) session.flush() session.refresh(dr) assert dr.state == State.FAILED @@ -1796,10 +1839,12 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_mak ) as dag: EmptyOperator(task_id="dummy") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - self.scheduler_job.job_runner.dagbag = dag_maker.dagbag - self.scheduler_job.job_runner.processor_agent = mock.Mock() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = MockExecutor() + self.job_runner.dagbag = dag_maker.dagbag + self.job_runner.processor_agent = mock.Mock() session = settings.Session() dr = dag_maker.create_dagrun() @@ -1808,7 +1853,7 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_mak ti.set_state(state, session) with mock.patch.object(settings, "USE_JOB_SCHEDULE", False): - self.scheduler_job.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) expected_callback = DagCallbackRequest( full_filepath=dag.fileloc, @@ -1820,7 +1865,7 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_mak ) # Verify dag failure callback request is sent to file processor - self.scheduler_job.executor.callback_sink.send.assert_called_once_with(expected_callback) + scheduler_job.executor.callback_sink.send.assert_called_once_with(expected_callback) session.rollback() session.close() @@ -1842,10 +1887,12 @@ def test_dagrun_plugins_are_notified(self, state, expected_callback_msg, dag_mak dag_listener.clear() get_listener_manager().add_listener(dag_listener) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - self.scheduler_job.job_runner.dagbag = dag_maker.dagbag - self.scheduler_job.job_runner.processor_agent = mock.Mock() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = MockExecutor() + self.job_runner.dagbag = dag_maker.dagbag + self.job_runner.processor_agent = mock.Mock() session = settings.Session() dr = dag_maker.create_dagrun() @@ -1854,7 +1901,7 @@ def test_dagrun_plugins_are_notified(self, state, expected_callback_msg, dag_mak ti.set_state(state, session) with mock.patch.object(settings, "USE_JOB_SCHEDULE", False): - self.scheduler_job.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) assert len(dag_listener.success) or len(dag_listener.failure) @@ -1872,16 +1919,18 @@ def test_dagrun_timeout_callbacks_are_stored_in_database(self, dag_maker, sessio ) as dag: EmptyOperator(task_id="empty") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - self.scheduler_job.executor.callback_sink = DatabaseCallbackSink() - self.scheduler_job.job_runner.dagbag = 
dag_maker.dagbag - self.scheduler_job.job_runner.processor_agent = mock.Mock() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = MockExecutor() + scheduler_job.executor.callback_sink = DatabaseCallbackSink() + self.job_runner.dagbag = dag_maker.dagbag + self.job_runner.processor_agent = mock.Mock() dr = dag_maker.create_dagrun(start_date=DEFAULT_DATE) with mock.patch.object(settings, "USE_JOB_SCHEDULE", False): - self.scheduler_job.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) callback = ( session.query(DbCallbackRequest) @@ -1909,10 +1958,12 @@ def test_dagrun_callbacks_commited_before_sent(self, dag_maker): with dag_maker(dag_id="test_dagrun_callbacks_commited_before_sent"): EmptyOperator(task_id="dummy") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.processor_agent = mock.Mock() - self.scheduler_job.job_runner._send_dag_callbacks_to_processor = mock.Mock() - self.scheduler_job.job_runner._schedule_dag_run = mock.Mock() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.processor_agent = mock.Mock() + self.job_runner._send_dag_callbacks_to_processor = mock.Mock() + self.job_runner._schedule_dag_run = mock.Mock() dr = dag_maker.create_dagrun() session = settings.Session() @@ -1932,15 +1983,15 @@ def mock_schedule_dag_run(*args, **kwargs): def mock_send_dag_callbacks_to_processor(*args, **kwargs): mock_guard.return_value.__enter__.return_value.commit.assert_called() - self.scheduler_job.job_runner._send_dag_callbacks_to_processor.side_effect = ( + self.job_runner._send_dag_callbacks_to_processor.side_effect = ( mock_send_dag_callbacks_to_processor ) - self.scheduler_job.job_runner._schedule_dag_run.side_effect = mock_schedule_dag_run + self.job_runner._schedule_dag_run.side_effect = mock_schedule_dag_run - self.scheduler_job.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) # Verify dag failure callback request is sent to file processor - self.scheduler_job.job_runner._send_dag_callbacks_to_processor.assert_called_once() + self.job_runner._send_dag_callbacks_to_processor.assert_called_once() # and mock_send_dag_callbacks_to_processor has asserted the callback was sent after a commit session.rollback() @@ -1956,9 +2007,11 @@ def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, sta ): BashOperator(task_id="test_task", bash_command="echo hi") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.processor_agent = mock.Mock() - self.scheduler_job.job_runner._send_dag_callbacks_to_processor = mock.Mock() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.processor_agent = mock.Mock() + self.job_runner._send_dag_callbacks_to_processor = mock.Mock() session = settings.Session() dr = dag_maker.create_dagrun() @@ -1966,11 +2019,11 @@ def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, sta ti.set_state(state, session) with mock.patch.object(settings, "USE_JOB_SCHEDULE", False): - self.scheduler_job.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) # Verify Callback is not set (i.e is None) when no callbacks are set on DAG - self.scheduler_job.job_runner._send_dag_callbacks_to_processor.assert_called_once() - call_args = 
self.scheduler_job.job_runner._send_dag_callbacks_to_processor.call_args[0] + self.job_runner._send_dag_callbacks_to_processor.assert_called_once() + call_args = self.job_runner._send_dag_callbacks_to_processor.call_args[0] assert call_args[0].dag_id == dr.dag_id assert call_args[1] is None @@ -1989,9 +2042,11 @@ def test_dagrun_callbacks_are_added_when_callbacks_are_defined(self, state, msg, ): BashOperator(task_id="test_task", bash_command="echo hi") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.processor_agent = mock.Mock() - self.scheduler_job.job_runner._send_dag_callbacks_to_processor = mock.Mock() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.processor_agent = mock.Mock() + self.job_runner._send_dag_callbacks_to_processor = mock.Mock() session = settings.Session() dr = dag_maker.create_dagrun() @@ -1999,11 +2054,11 @@ def test_dagrun_callbacks_are_added_when_callbacks_are_defined(self, state, msg, ti.set_state(state, session) with mock.patch.object(settings, "USE_JOB_SCHEDULE", False): - self.scheduler_job.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) # Verify Callback is set (i.e is None) when no callbacks are set on DAG - self.scheduler_job.job_runner._send_dag_callbacks_to_processor.assert_called_once() - call_args = self.scheduler_job.job_runner._send_dag_callbacks_to_processor.call_args[0] + self.job_runner._send_dag_callbacks_to_processor.assert_called_once() + call_args = self.job_runner._send_dag_callbacks_to_processor.call_args[0] assert call_args[0].dag_id == dr.dag_id assert call_args[1] is not None assert call_args[1].msg == msg @@ -2024,9 +2079,10 @@ def test_dagrun_notify_called_success(self, dag_maker): executor = MockExecutor(do_update=False) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=executor) - self.scheduler_job.job_runner.dagbag = dag_maker.dagbag - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=executor) + self.job_runner = SchedulerJobRunner(scheduler_job) + self.job_runner.dagbag = dag_maker.dagbag + self.job_runner.processor_agent = mock.MagicMock() session = settings.Session() dr = dag_maker.create_dagrun() @@ -2035,7 +2091,7 @@ def test_dagrun_notify_called_success(self, dag_maker): ti.set_state(State.SUCCESS, session) with mock.patch.object(settings, "USE_JOB_SCHEDULE", False): - self.scheduler_job.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) assert dag_listener.success[0].dag_id == dr.dag_id assert dag_listener.success[0].run_id == dr.run_id @@ -2064,8 +2120,10 @@ def test_do_not_schedule_removed_task(self, dag_maker): ): pass - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - res = self.scheduler_job.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) assert [] == res @@ -2212,11 +2270,10 @@ def test_dagrun_root_after_dagrun_unfinished(self): dag_id = "test_dagrun_states_root_future" dag = self.dagbag.get_dag(dag_id) dag.sync_to_db() - self.scheduler_job = Job( - job_runner=SchedulerJobRunner(num_runs=1, subdir=dag.fileloc), - executor=self.null_exec, - ) - run_job(self.scheduler_job) + + scheduler_job = Job(executor=self.null_exec) + 
self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1, subdir=dag.fileloc) + run_job(scheduler_job, execute_callable=self.job_runner._execute) first_run = DagRun.find(dag_id=dag_id, execution_date=DEFAULT_DATE)[0] ti_ids = [(ti.task_id, ti.state) for ti in first_run.get_task_instances()] @@ -2283,11 +2340,11 @@ def test_scheduler_start_date(self, configs): other_dag.is_paused_upon_creation = True other_dag.sync_to_db() - self.scheduler_job = Job( - job_runner=SchedulerJobRunner(subdir=dag.fileloc, num_runs=1), + scheduler_job = Job( executor=self.null_exec, ) - run_job(self.scheduler_job) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=dag.fileloc, num_runs=1) + run_job(scheduler_job, execute_callable=self.job_runner._execute) # zero tasks ran assert len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()) == 0 @@ -2299,11 +2356,11 @@ def test_scheduler_start_date(self, configs): # That behavior still exists, but now it will only do so if after the # start date bf_exec = MockExecutor() - backfill_job = Job( - BackfillJobRunner(dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE), - executor=bf_exec, + backfill_job = Job(executor=bf_exec) + job_runner = BackfillJobRunner( + job=backfill_job, dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE ) - run_job(backfill_job) + run_job(job=backfill_job, execute_callable=job_runner._execute) # one task ran assert len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()) == 1 @@ -2315,11 +2372,11 @@ def test_scheduler_start_date(self, configs): ] == bf_exec.sorted_tasks session.commit() - self.scheduler_job = Job( - job_runner=SchedulerJobRunner(dag.fileloc, num_runs=1), + scheduler_job = Job( executor=self.null_exec, ) - run_job(self.scheduler_job) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=dag.fileloc, num_runs=1) + run_job(scheduler_job, execute_callable=self.job_runner._execute) # still one task assert len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()) == 1 @@ -2354,11 +2411,11 @@ def test_scheduler_task_start_date(self, configs): dagbag.sync_to_db() - self.scheduler_job = Job( - job_runner=SchedulerJobRunner(subdir=dag.fileloc, num_runs=3), + scheduler_job = Job( executor=self.null_exec, ) - run_job(self.scheduler_job) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=dag.fileloc, num_runs=3) + run_job(scheduler_job, execute_callable=self.job_runner._execute) session = settings.Session() tiq = session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id) @@ -2386,14 +2443,15 @@ def test_scheduler_multiprocessing(self, configs): dag = self.dagbag.get_dag(dag_id) dag.clear() - self.scheduler_job = Job( - job_runner=SchedulerJobRunner( - subdir=os.path.join(TEST_DAG_FOLDER, "test_scheduler_dags.py"), - num_runs=1, - ), + scheduler_job = Job( executor=self.null_exec, ) - run_job(self.scheduler_job) + self.job_runner = SchedulerJobRunner( + job=scheduler_job, + subdir=os.path.join(TEST_DAG_FOLDER, "test_scheduler_dags.py"), + num_runs=1, + ) + run_job(scheduler_job, execute_callable=self.job_runner._execute) # zero tasks ran dag_id = "test_start_date_scheduling" @@ -2424,18 +2482,19 @@ def test_scheduler_verify_pool_full(self, dag_maker, configs): session.add(pool) session.flush() - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=self.null_exec) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=self.null_exec) + self.job_runner = 
SchedulerJobRunner(job=scheduler_job) + self.job_runner.processor_agent = mock.MagicMock() # Create 2 dagruns, which will create 2 task instances. dr = dag_maker.create_dagrun( run_type=DagRunType.SCHEDULED, ) - self.scheduler_job.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) dr = dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, state=State.RUNNING) - self.scheduler_job.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) session.flush() - task_instances_list = self.scheduler_job.job_runner._executable_task_instances_to_queued( + task_instances_list = self.job_runner._executable_task_instances_to_queued( max_tis=32, session=session ) @@ -2464,8 +2523,10 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self, dag_maker, session): session.add(pool) session.flush() - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=self.null_exec) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=self.null_exec) + self.job_runner = SchedulerJobRunner(job=scheduler_job) + + self.job_runner.processor_agent = mock.MagicMock() # Create 5 dagruns, which will create 5 task instances. def _create_dagruns(): @@ -2476,9 +2537,9 @@ def _create_dagruns(): yield dr for dr in _create_dagruns(): - self.scheduler_job.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) - task_instances_list = self.scheduler_job.job_runner._executable_task_instances_to_queued( + task_instances_list = self.job_runner._executable_task_instances_to_queued( max_tis=32, session=session ) @@ -2518,8 +2579,9 @@ def test_scheduler_keeps_scheduling_pool_full(self, dag_maker): session.add(pool_p2) session.flush() - scheduler = Job(job_runner=SchedulerJobRunner(), executor=self.null_exec) - scheduler.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=self.null_exec) + self.job_runner = SchedulerJobRunner(job=scheduler_job) + self.job_runner.processor_agent = mock.MagicMock() def _create_dagruns(dag: DAG): next_info = dag.next_dagrun_info(None) @@ -2539,12 +2601,12 @@ def _create_dagruns(dag: DAG): # To increase the chances the TIs from the "full" pool will get retrieved first, we schedule all # TIs from the first dag first. 
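# Editor's note: a minimal sketch, not part of the patch, of the test setup the hunks above
# migrate to. Ownership is inverted: SchedulerJobRunner now receives a plain Job, and the test
# doubles are split between the two objects. The MagicMock stand-ins below are assumptions for
# this sketch; the surrounding tests use MockExecutor and mock.MagicMock in the same positions.
import os
from unittest import mock

from airflow.jobs.job import Job
from airflow.jobs.scheduler_job_runner import SchedulerJobRunner

scheduler_job = Job()  # plain Job, no job_runner argument anymore
job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)  # the runner owns the job

scheduler_job.executor = mock.MagicMock(slots_available=8)  # executor and heartrate stay on the Job
scheduler_job.heartrate = 0
job_runner.processor_agent = mock.MagicMock()  # scheduler internals now live on the runner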
for dr in _create_dagruns(dag_d1): - scheduler.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) for dr in _create_dagruns(dag_d2): - scheduler.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) - scheduler.job_runner._executable_task_instances_to_queued(max_tis=2, session=session) - task_instances_list2 = scheduler.job_runner._executable_task_instances_to_queued( + self.job_runner._executable_task_instances_to_queued(max_tis=2, session=session) + task_instances_list2 = self.job_runner._executable_task_instances_to_queued( max_tis=2, session=session ) @@ -2593,8 +2655,10 @@ def test_scheduler_verify_priority_and_slots(self, dag_maker): session.add(pool) session.flush() - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=self.null_exec) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=self.null_exec) + self.job_runner = SchedulerJobRunner(job=scheduler_job) + + self.job_runner.processor_agent = mock.MagicMock() dr = dag_maker.create_dagrun() for ti in dr.task_instances: @@ -2602,7 +2666,7 @@ def test_scheduler_verify_priority_and_slots(self, dag_maker): session.merge(ti) session.flush() - task_instances_list = self.scheduler_job.job_runner._executable_task_instances_to_queued( + task_instances_list = self.job_runner._executable_task_instances_to_queued( max_tis=32, session=session ) @@ -2640,18 +2704,19 @@ def test_verify_integrity_if_dag_not_changed(self, dag_maker): with dag_maker(dag_id="test_verify_integrity_if_dag_not_changed") as dag: BashOperator(task_id="dummy", bash_command="echo hi") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() orm_dag = dag_maker.dag_model assert orm_dag is not None - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() - dag = self.scheduler_job.job_runner.dagbag.get_dag( - "test_verify_integrity_if_dag_not_changed", session=session - ) - self.scheduler_job.job_runner._create_dag_runs([orm_dag], session) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.processor_agent = mock.MagicMock() + dag = self.job_runner.dagbag.get_dag("test_verify_integrity_if_dag_not_changed", session=session) + self.job_runner._create_dag_runs([orm_dag], session) drs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(drs) == 1 @@ -2659,7 +2724,7 @@ def test_verify_integrity_if_dag_not_changed(self, dag_maker): # Verify that DagRun.verify_integrity is not called with mock.patch("airflow.jobs.scheduler_job_runner.DagRun.verify_integrity") as mock_verify_integrity: - self.scheduler_job.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) mock_verify_integrity.assert_not_called() session.flush() @@ -2691,18 +2756,19 @@ def test_verify_integrity_if_dag_changed(self, dag_maker): with dag_maker(dag_id="test_verify_integrity_if_dag_changed") as dag: BashOperator(task_id="dummy", bash_command="echo hi") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() orm_dag = dag_maker.dag_model assert orm_dag is not None - self.scheduler_job = 
Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() - dag = self.scheduler_job.job_runner.dagbag.get_dag( - "test_verify_integrity_if_dag_changed", session=session - ) - self.scheduler_job.job_runner._create_dag_runs([orm_dag], session) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.processor_agent = mock.MagicMock() + dag = self.job_runner.dagbag.get_dag("test_verify_integrity_if_dag_changed", session=session) + self.job_runner._create_dag_runs([orm_dag], session) drs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(drs) == 1 @@ -2710,11 +2776,8 @@ def test_verify_integrity_if_dag_changed(self, dag_maker): dag_version_1 = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session) assert dr.dag_hash == dag_version_1 - assert self.scheduler_job.job_runner.dagbag.dags == {"test_verify_integrity_if_dag_changed": dag} - assert ( - len(self.scheduler_job.job_runner.dagbag.dags.get("test_verify_integrity_if_dag_changed").tasks) - == 1 - ) + assert self.job_runner.dagbag.dags == {"test_verify_integrity_if_dag_changed": dag} + assert len(self.job_runner.dagbag.dags.get("test_verify_integrity_if_dag_changed").tasks) == 1 # Now let's say the DAG got updated (new task got added) BashOperator(task_id="bash_task_1", dag=dag, bash_command="echo hi") @@ -2723,18 +2786,15 @@ def test_verify_integrity_if_dag_changed(self, dag_maker): dag_version_2 = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session) assert dag_version_2 != dag_version_1 - self.scheduler_job.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) session.flush() drs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(drs) == 1 dr = drs[0] assert dr.dag_hash == dag_version_2 - assert self.scheduler_job.job_runner.dagbag.dags == {"test_verify_integrity_if_dag_changed": dag} - assert ( - len(self.scheduler_job.job_runner.dagbag.dags.get("test_verify_integrity_if_dag_changed").tasks) - == 2 - ) + assert self.job_runner.dagbag.dags == {"test_verify_integrity_if_dag_changed": dag} + assert len(self.job_runner.dagbag.dags.get("test_verify_integrity_if_dag_changed").tasks) == 2 tis_count = ( session.query(func.count(TaskInstance.task_id)) @@ -2763,18 +2823,19 @@ def test_verify_integrity_if_dag_disappeared(self, dag_maker, caplog): with dag_maker(dag_id="test_verify_integrity_if_dag_disappeared") as dag: BashOperator(task_id="dummy", bash_command="echo hi") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() orm_dag = dag_maker.dag_model assert orm_dag is not None - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() - dag = self.scheduler_job.job_runner.dagbag.get_dag( - "test_verify_integrity_if_dag_disappeared", session=session - ) - self.scheduler_job.job_runner._create_dag_runs([orm_dag], session) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.processor_agent = mock.MagicMock() + dag = self.job_runner.dagbag.get_dag("test_verify_integrity_if_dag_disappeared", session=session) + self.job_runner._create_dag_runs([orm_dag], session) dag_id = dag.dag_id drs = DagRun.find(dag_id=dag_id, session=session) 
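# Editor's note: an illustrative sketch, not part of the patch, of the dag-run helpers as the
# verify-integrity hunks above now call them: everything is reached on the runner rather than
# through scheduler_job.job_runner. The helper name and its orm_dag/session parameters are
# assumptions for this sketch; the tests obtain them from the dag_maker fixture and
# settings.Session().
import os
from unittest import mock

from airflow.jobs.job import Job
from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
from airflow.models.dagrun import DagRun


def create_and_schedule_runs(orm_dag, session):
    scheduler_job = Job()
    job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
    job_runner.processor_agent = mock.MagicMock()  # stands in for a started DagFileProcessorAgent

    job_runner._create_dag_runs([orm_dag], session)  # was scheduler_job.job_runner._create_dag_runs
    job_runner._start_queued_dagruns(session)
    drs = DagRun.find(dag_id=orm_dag.dag_id, session=session)
    # _schedule_dag_run returns the callback request for the run, if any
    return [job_runner._schedule_dag_run(dr, session) for dr in drs]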
assert len(drs) == 1 @@ -2782,23 +2843,16 @@ def test_verify_integrity_if_dag_disappeared(self, dag_maker, caplog): dag_version_1 = SerializedDagModel.get_latest_version_hash(dag_id, session=session) assert dr.dag_hash == dag_version_1 - assert self.scheduler_job.job_runner.dagbag.dags == {"test_verify_integrity_if_dag_disappeared": dag} - assert ( - len( - self.scheduler_job.job_runner.dagbag.dags.get( - "test_verify_integrity_if_dag_disappeared" - ).tasks - ) - == 1 - ) + assert self.job_runner.dagbag.dags == {"test_verify_integrity_if_dag_disappeared": dag} + assert len(self.job_runner.dagbag.dags.get("test_verify_integrity_if_dag_disappeared").tasks) == 1 SerializedDagModel.remove_dag(dag_id=dag_id) - dag = self.scheduler_job.job_runner.dagbag.dags[dag_id] - self.scheduler_job.job_runner.dagbag.dags = MagicMock() - self.scheduler_job.job_runner.dagbag.dags.get.side_effect = [dag, None] + dag = self.job_runner.dagbag.dags[dag_id] + self.job_runner.dagbag.dags = MagicMock() + self.job_runner.dagbag.dags.get.side_effect = [dag, None] session.flush() with caplog.at_level(logging.WARNING): - callback = self.scheduler_job.job_runner._schedule_dag_run(dr, session) + callback = self.job_runner._schedule_dag_run(dr, session) assert "The DAG disappeared before verifying integrity" in caplog.text assert callback is None @@ -2832,16 +2886,16 @@ def do_schedule(session): # Use a empty file since the above mock will return the # expected DAGs. Also specify only a single file so that it doesn't # try to schedule the above DAG repeatedly. - self.scheduler_job = Job( - job_runner=SchedulerJobRunner(num_runs=1, subdir=os.devnull), + scheduler_job = Job( executor=executor, ) - self.scheduler_job.job_runner.dagbag = dag_maker.dagbag - self.scheduler_job.heartrate = 0 + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1, subdir=os.devnull) + self.job_runner.dagbag = dag_maker.dagbag + scheduler_job.heartrate = 0 # Since the DAG is not in the directory watched by scheduler job, # it would've been marked as deleted and not being scheduled. 
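# Editor's note: an illustrative sketch, not part of the patch, of how a full scheduler pass is
# now driven, as in the run_job() call that follows: run_job() is handed the Job plus the
# runner's _execute callable instead of a Job that embeds its runner. num_runs=1 limits the
# scheduler loop to a single pass, matching the surrounding tests.
import os

from airflow.jobs.job import Job, run_job
from airflow.jobs.scheduler_job_runner import SchedulerJobRunner

scheduler_job = Job()
job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1)

# previously: run_job(scheduler_job), with the runner reached through scheduler_job.job_runner
run_job(scheduler_job, execute_callable=job_runner._execute)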
with mock.patch.object(DagModel, "deactivate_deleted_dags"): - run_job(self.scheduler_job) + run_job(scheduler_job, execute_callable=self.job_runner._execute) do_schedule() with create_session() as session: @@ -2893,9 +2947,9 @@ def test_retry_handling_job(self): dag_task1 = dag.get_task("test_retry_handling_op") dag.clear() - self.scheduler_job = Job(job_runner=SchedulerJobRunner(dag_id=dag.dag_id, num_runs=1)) - self.scheduler_job.heartrate = 0 - run_job(self.scheduler_job) + scheduler_job = Job(job_type=SchedulerJobRunner.job_type, heartrate=0) + self.job_runner = SchedulerJobRunner(job=scheduler_job, dag_id=dag.dag_id, num_runs=1) + run_job(scheduler_job, execute_callable=self.job_runner._execute) session = settings.Session() ti = ( @@ -2987,9 +3041,10 @@ def test_list_py_file_paths(self): def test_adopt_or_reset_orphaned_tasks_nothing(self): """Try with nothing.""" - self.scheduler_job = Job(job_runner=SchedulerJobRunner()) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job) session = settings.Session() - assert 0 == self.scheduler_job.job_runner.adopt_or_reset_orphaned_tasks(session=session) + assert 0 == self.job_runner.adopt_or_reset_orphaned_tasks(session=session) def test_adopt_or_reset_orphaned_tasks_external_triggered_dag(self, dag_maker): dag_id = "test_reset_orphaned_tasks_external_triggered_dag" @@ -2997,7 +3052,8 @@ def test_adopt_or_reset_orphaned_tasks_external_triggered_dag(self, dag_maker): task_id = dag_id + "_task" EmptyOperator(task_id=task_id) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() dr1 = dag_maker.create_dagrun(external_trigger=True) @@ -3007,7 +3063,7 @@ def test_adopt_or_reset_orphaned_tasks_external_triggered_dag(self, dag_maker): session.merge(dr1) session.commit() - num_reset_tis = self.scheduler_job.job_runner.adopt_or_reset_orphaned_tasks(session=session) + num_reset_tis = self.job_runner.adopt_or_reset_orphaned_tasks(session=session) assert 1 == num_reset_tis def test_adopt_or_reset_orphaned_tasks_backfill_dag(self, dag_maker): @@ -3016,9 +3072,10 @@ def test_adopt_or_reset_orphaned_tasks_backfill_dag(self, dag_maker): task_id = dag_id + "_task" EmptyOperator(task_id=task_id) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() - session.add(self.scheduler_job) + session.add(scheduler_job) session.flush() dr1 = dag_maker.create_dagrun(run_type=DagRunType.BACKFILL_JOB) @@ -3030,7 +3087,7 @@ def test_adopt_or_reset_orphaned_tasks_backfill_dag(self, dag_maker): session.flush() assert dr1.is_backfill - assert 0 == self.scheduler_job.job_runner.adopt_or_reset_orphaned_tasks(session=session) + assert 0 == self.job_runner.adopt_or_reset_orphaned_tasks(session=session) session.rollback() def test_reset_orphaned_tasks_no_orphans(self, dag_maker): @@ -3039,20 +3096,21 @@ def test_reset_orphaned_tasks_no_orphans(self, dag_maker): task_id = dag_id + "_task" EmptyOperator(task_id=task_id) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() - session.add(self.scheduler_job) + session.add(scheduler_job) session.flush() dr1 = dag_maker.create_dagrun() tis = 
dr1.get_task_instances(session=session) tis[0].state = State.RUNNING - tis[0].queued_by_job_id = self.scheduler_job.id + tis[0].queued_by_job_id = scheduler_job.id session.merge(dr1) session.merge(tis[0]) session.flush() - assert 0 == self.scheduler_job.job_runner.adopt_or_reset_orphaned_tasks(session=session) + assert 0 == self.job_runner.adopt_or_reset_orphaned_tasks(session=session) tis[0].refresh_from_db() assert State.RUNNING == tis[0].state @@ -3063,21 +3121,22 @@ def test_reset_orphaned_tasks_non_running_dagruns(self, dag_maker): task_id = dag_id + "_task" EmptyOperator(task_id=task_id) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() - session.add(self.scheduler_job) + session.add(scheduler_job) session.flush() dr1 = dag_maker.create_dagrun() tis = dr1.get_task_instances(session=session) assert 1 == len(tis) tis[0].state = State.SCHEDULED - tis[0].queued_by_job_id = self.scheduler_job.id + tis[0].queued_by_job_id = scheduler_job.id session.merge(dr1) session.merge(tis[0]) session.flush() - assert 0 == self.scheduler_job.job_runner.adopt_or_reset_orphaned_tasks(session=session) + assert 0 == self.job_runner.adopt_or_reset_orphaned_tasks(session=session) session.rollback() def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self, dag_maker): @@ -3086,13 +3145,15 @@ def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self, dag_maker): EmptyOperator(task_id="task1") EmptyOperator(task_id="task2") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() - self.scheduler_job.state = State.RUNNING - self.scheduler_job.latest_heartbeat = timezone.utcnow() - session.add(self.scheduler_job) + scheduler_job.state = State.RUNNING + scheduler_job.latest_heartbeat = timezone.utcnow() + session.add(scheduler_job) - old_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + old_job = Job() + old_job_runner = SchedulerJobRunner(job=old_job, subdir=os.devnull) old_job.state = State.RUNNING old_job.latest_heartbeat = timezone.utcnow() - timedelta(minutes=15) session.add(old_job) @@ -3112,11 +3173,11 @@ def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self, dag_maker): session.merge(ti1) ti2.state = State.QUEUED - ti2.queued_by_job_id = self.scheduler_job.id + ti2.queued_by_job_id = scheduler_job.id session.merge(ti2) session.flush() - num_reset_tis = self.scheduler_job.job_runner.adopt_or_reset_orphaned_tasks(session=session) + num_reset_tis = self.job_runner.adopt_or_reset_orphaned_tasks(session=session) assert 1 == num_reset_tis @@ -3125,32 +3186,35 @@ def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self, dag_maker): session.refresh(ti2) assert ti2.state == State.QUEUED session.rollback() - if old_job.job_runner.processor_agent: - old_job.job_runner.processor_agent.end() + if old_job_runner.processor_agent: + old_job_runner.processor_agent.end() def test_adopt_or_reset_orphaned_tasks_only_fails_scheduler_jobs(self, caplog): """Make sure we only set SchedulerJobs to failed, not all jobs""" session = settings.Session() - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.state = State.RUNNING - self.scheduler_job.latest_heartbeat = timezone.utcnow() - session.add(self.scheduler_job) + scheduler_job = Job() + 
self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.state = State.RUNNING + scheduler_job.latest_heartbeat = timezone.utcnow() + session.add(scheduler_job) session.flush() - old_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) + old_job = Job() + self.job_runner = SchedulerJobRunner(job=old_job, subdir=os.devnull) old_job.state = State.RUNNING old_job.latest_heartbeat = timezone.utcnow() - timedelta(minutes=15) session.add(old_job) session.flush() - old_task_job = Job(job_runner=MockJobRunner(), state=State.RUNNING) + old_task_job = Job(state=State.RUNNING) old_task_job.latest_heartbeat = timezone.utcnow() - timedelta(minutes=15) session.add(old_task_job) session.flush() with caplog.at_level("INFO", logger="airflow.jobs.scheduler_job_runner"): - self.scheduler_job.job_runner.adopt_or_reset_orphaned_tasks(session=session) + self.job_runner.adopt_or_reset_orphaned_tasks(session=session) session.expire_all() assert old_job.state == State.FAILED @@ -3164,11 +3228,11 @@ def test_send_sla_callbacks_to_processor_sla_disabled(self, dag_maker): EmptyOperator(task_id="task1") with patch.object(settings, "CHECK_SLAS", False): - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - - self.scheduler_job.job_runner._send_sla_callbacks_to_processor(dag) - self.scheduler_job.executor.callback_sink.send.assert_not_called() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + scheduler_job.executor = MockExecutor() + self.job_runner._send_sla_callbacks_to_processor(dag) + scheduler_job.executor.callback_sink.send.assert_not_called() def test_send_sla_callbacks_to_processor_sla_no_task_slas(self, dag_maker): """Test SLA Callbacks are not sent when no task SLAs are defined""" @@ -3177,11 +3241,11 @@ def test_send_sla_callbacks_to_processor_sla_no_task_slas(self, dag_maker): EmptyOperator(task_id="task1") with patch.object(settings, "CHECK_SLAS", True): - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - - self.scheduler_job.job_runner._send_sla_callbacks_to_processor(dag) - self.scheduler_job.executor.callback_sink.send.assert_not_called() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + scheduler_job.executor = MockExecutor() + self.job_runner._send_sla_callbacks_to_processor(dag) + scheduler_job.executor.callback_sink.send.assert_not_called() @pytest.mark.parametrize( "schedule", @@ -3202,17 +3266,16 @@ def test_send_sla_callbacks_to_processor_sla_with_task_slas(self, schedule, dag_ EmptyOperator(task_id="task1", sla=timedelta(seconds=60)) with patch.object(settings, "CHECK_SLAS", True): - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - - self.scheduler_job.job_runner._send_sla_callbacks_to_processor(dag) - + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + scheduler_job.executor = MockExecutor() + self.job_runner._send_sla_callbacks_to_processor(dag) expected_callback = SlaCallbackRequest( full_filepath=dag.fileloc, dag_id=dag.dag_id, processor_subdir=TEST_DAG_FOLDER, ) - self.scheduler_job.executor.callback_sink.send.assert_called_once_with(expected_callback) + scheduler_job.executor.callback_sink.send.assert_called_once_with(expected_callback) @pytest.mark.parametrize( "schedule", @@ 
-3228,11 +3291,11 @@ def test_send_sla_callbacks_to_processor_sla_dag_not_scheduled(self, schedule, d EmptyOperator(task_id="task1", sla=timedelta(seconds=5)) with patch.object(settings, "CHECK_SLAS", True): - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - - self.scheduler_job.job_runner._send_sla_callbacks_to_processor(dag) - self.scheduler_job.executor.callback_sink.send.assert_not_called() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + scheduler_job.executor = MockExecutor() + self.job_runner._send_sla_callbacks_to_processor(dag) + scheduler_job.executor.callback_sink.send.assert_not_called() def test_create_dag_runs(self, dag_maker): """ @@ -3247,18 +3310,20 @@ def test_create_dag_runs(self, dag_maker): dag_model = dag_maker.dag_model - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=self.null_exec) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=self.null_exec) + self.job_runner = SchedulerJobRunner(job=scheduler_job) + + self.job_runner.processor_agent = mock.MagicMock() with create_session() as session: - self.scheduler_job.job_runner._create_dag_runs([dag_model], session) + self.job_runner._create_dag_runs([dag_model], session) dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).first() # Assert dr state is queued assert dr.state == State.QUEUED assert dr.start_date is None - assert dag.get_last_dagrun().creating_job_id == self.scheduler_job.id + assert dag.get_last_dagrun().creating_job_id == scheduler_job.id @pytest.mark.need_serialized_dag def test_create_dag_runs_datasets(self, session, dag_maker): @@ -3324,11 +3389,13 @@ def test_create_dag_runs_datasets(self, session, dag_maker): ) session.flush() - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=self.null_exec) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=self.null_exec) + self.job_runner = SchedulerJobRunner(job=scheduler_job) + + self.job_runner.processor_agent = mock.MagicMock() with create_session() as session: - self.scheduler_job.job_runner._create_dagruns_for_dags(session, session) + self.job_runner._create_dagruns_for_dags(session, session) def dict_from_obj(obj): """Get dict of column attrs from SqlAlchemy object.""" @@ -3353,7 +3420,7 @@ def dict_from_obj(obj): # dag3 DDRQ record should be deleted since the dag run was triggered assert session.query(DatasetDagRunQueue).filter_by(target_dag_id=dag3.dag_id).one_or_none() is None - assert dag3.get_last_dagrun().creating_job_id == self.scheduler_job.id + assert dag3.get_last_dagrun().creating_job_id == scheduler_job.id @time_machine.travel(DEFAULT_DATE + datetime.timedelta(days=1, seconds=9), tick=False) @mock.patch("airflow.jobs.scheduler_job_runner.Stats.timing") @@ -3371,12 +3438,14 @@ def test_start_dagruns(self, stats_timing, dag_maker): dag_model = dag_maker.dag_model - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=self.null_exec) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=self.null_exec) + self.job_runner = SchedulerJobRunner(job=scheduler_job) + + self.job_runner.processor_agent = mock.MagicMock() with create_session() as session: - self.scheduler_job.job_runner._create_dag_runs([dag_model], session) - self.scheduler_job.job_runner._start_queued_dagruns(session) + 
self.job_runner._create_dag_runs([dag_model], session) + self.job_runner._start_queued_dagruns(session) dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).first() # Assert dr state is running @@ -3396,7 +3465,7 @@ def test_start_dagruns(self, stats_timing, dag_maker): ] ) - assert dag.get_last_dagrun().creating_job_id == self.scheduler_job.id + assert dag.get_last_dagrun().creating_job_id == scheduler_job.id def test_extra_operator_links_not_loaded_in_scheduler_loop(self, dag_maker): """ @@ -3414,13 +3483,15 @@ def test_extra_operator_links_not_loaded_in_scheduler_loop(self, dag_maker): assert custom_task.operator_extra_links session = settings.Session() - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=self.null_exec) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=self.null_exec) + self.job_runner = SchedulerJobRunner(job=scheduler_job) + + self.job_runner.processor_agent = mock.MagicMock() - self.scheduler_job.job_runner._start_queued_dagruns(session) + self.job_runner._start_queued_dagruns(session) session.flush() # Get serialized dag - s_dag_2 = self.scheduler_job.job_runner.dagbag.get_dag(dag.dag_id) + s_dag_2 = self.job_runner.dagbag.get_dag(dag.dag_id) custom_task = s_dag_2.task_dict["custom_task"] # Test that custom_task has no Operator Links (after de-serialization) in the Scheduling Loop assert not custom_task.operator_extra_links @@ -3435,8 +3506,9 @@ def test_scheduler_create_dag_runs_does_not_raise_error(self, caplog, dag_maker) task_id="dummy", ) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull), executor=self.null_exec) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=self.null_exec) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + self.job_runner.processor_agent = mock.MagicMock() caplog.set_level("FATAL") caplog.clear() @@ -3444,7 +3516,7 @@ def test_scheduler_create_dag_runs_does_not_raise_error(self, caplog, dag_maker) "ERROR", logger="airflow.jobs.scheduler_job_runner", ): - self.scheduler_job.job_runner._create_dag_runs([dag_maker.dag_model], session) + self.job_runner._create_dag_runs([dag_maker.dag_model], session) assert caplog.messages == [ "DAG 'test_scheduler_create_dag_runs_does_not_raise_error' not found in serialized_dag table", ] @@ -3470,13 +3542,15 @@ def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_mak assert dag_model.next_dagrun_data_interval_start == DEFAULT_DATE assert dag_model.next_dagrun_data_interval_end == DEFAULT_DATE + timedelta(minutes=1) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor(do_update=False) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = MockExecutor(do_update=False) + self.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) # Verify a DagRun is created with the correct dates # when Scheduler._do_scheduling is run in the Scheduler Loop - self.scheduler_job.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) dr1 = dag.get_dagrun(DEFAULT_DATE, session=session) assert dr1 is not None assert dr1.state == State.RUNNING @@ -3544,11 +3618,12 @@ def test_scheduler_create_dag_runs_check_existing_run(self, dag_maker): assert 
dag.get_last_dagrun(session) == dagrun - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull), executor=self.null_exec) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=self.null_exec) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + self.job_runner.processor_agent = mock.MagicMock() # Test that this does not raise any error - self.scheduler_job.job_runner._create_dag_runs([dag_model], session) + self.job_runner._create_dag_runs([dag_model], session) # Assert the next dagrun fields are set correctly to next execution date assert dag_model.next_dagrun_data_interval_start == DEFAULT_DATE + timedelta(days=1) @@ -3592,26 +3667,28 @@ def test_do_schedule_max_active_runs_dag_timed_out(self, dag_maker): session=session, ) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - self.scheduler_job.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = MockExecutor() + self.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) my_dag = session.get(DagModel, dag.dag_id) - self.scheduler_job.job_runner._create_dag_runs([my_dag], session) + self.job_runner._create_dag_runs([my_dag], session) # Run relevant part of scheduling again to assert run2 has been scheduled - self.scheduler_job.job_runner._schedule_dag_run(run1, session) + self.job_runner._schedule_dag_run(run1, session) run1 = session.merge(run1) session.refresh(run1) assert run1.state == State.FAILED assert run1_ti.state == State.SKIPPED session.flush() # Run relevant part of scheduling again to assert run2 has been scheduled - self.scheduler_job.job_runner._start_queued_dagruns(session) + self.job_runner._start_queued_dagruns(session) session.flush() run2 = session.merge(run2) session.refresh(run2) assert run2.state == State.RUNNING - self.scheduler_job.job_runner._schedule_dag_run(run2, session) + self.job_runner._schedule_dag_run(run2, session) run2_ti = run2.get_task_instance(task1.task_id, session) assert run2_ti.state == State.SCHEDULED @@ -3633,11 +3710,13 @@ def test_do_schedule_max_active_runs_task_removed(self, session, dag_maker): state=State.RUNNING, ) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor(do_update=False) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = MockExecutor(do_update=False) + self.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) - num_queued = self.scheduler_job.job_runner._do_scheduling(session) + num_queued = self.job_runner._do_scheduling(session) assert num_queued == 1 session.flush() @@ -3652,14 +3731,16 @@ def test_more_runs_are_not_created_when_max_active_runs_is_reached(self, dag_mak """ with dag_maker(max_active_runs=1): EmptyOperator(task_id="task") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor(do_update=False) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = 
MockExecutor(do_update=False) + self.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) session = settings.Session() assert session.query(DagRun).count() == 0 query, _ = DagModel.dags_needing_dagruns(session) dag_models = query.all() - self.scheduler_job.job_runner._create_dag_runs(dag_models, session) + self.job_runner._create_dag_runs(dag_models, session) dr = session.query(DagRun).one() dr.state == DagRunState.QUEUED assert session.query(DagRun).count() == 1 @@ -3668,7 +3749,7 @@ def test_more_runs_are_not_created_when_max_active_runs_is_reached(self, dag_mak # dags_needing_dagruns query should not return any value query, _ = DagModel.dags_needing_dagruns(session) assert len(query.all()) == 0 - self.scheduler_job.job_runner._create_dag_runs(dag_models, session) + self.job_runner._create_dag_runs(dag_models, session) assert session.query(DagRun).count() == 1 assert dag_maker.dag_model.next_dagrun_create_after is None assert dag_maker.dag_model.next_dagrun == DEFAULT_DATE @@ -3681,7 +3762,7 @@ def test_more_runs_are_not_created_when_max_active_runs_is_reached(self, dag_mak session.merge(dr) session.flush() # check that next_dagrun is set properly by Schedulerjob._update_dag_next_dagruns - self.scheduler_job.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) session.flush() query, _ = DagModel.dags_needing_dagruns(session) assert len(query.all()) == 1 @@ -3717,14 +3798,16 @@ def complete_one_dagrun(): # Need to use something that doesn't immediately get marked as success by the scheduler BashOperator(task_id="task", bash_command="true") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor(do_update=True) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = MockExecutor(do_update=True) + self.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) query, _ = DagModel.dags_needing_dagruns(session) query.all() for _ in range(3): - self.scheduler_job.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) model: DagModel = session.get(DagModel, dag.dag_id) @@ -3739,7 +3822,7 @@ def complete_one_dagrun(): assert DagRun.active_runs_of_dags(session=session) == {"test_dag": 3} for _ in range(5): - self.scheduler_job.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) complete_one_dagrun() expected_execution_dates = [datetime.datetime(2016, 1, d, tzinfo=timezone.utc) for d in range(1, 6)] @@ -3775,11 +3858,13 @@ def test_do_schedule_max_active_runs_and_manual_trigger(self, dag_maker): dag.sync_to_db(session=session) # Update the date fields - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor(do_update=False) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) - num_queued = self.scheduler_job.job_runner._do_scheduling(session) + scheduler_job.executor = MockExecutor(do_update=False) + self.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) + + num_queued = self.job_runner._do_scheduling(session) # Add it back in to the session so we can refresh it. 
(_do_scheduling does an expunge_all to reduce # memory) dag_run = session.merge(dag_run) @@ -3798,7 +3883,7 @@ def test_do_schedule_max_active_runs_and_manual_trigger(self, dag_maker): ) session.flush() - self.scheduler_job.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) # Assert that only 1 dagrun is active assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 1 @@ -3830,13 +3915,15 @@ def test_max_active_runs_in_a_dag_doesnt_stop_running_dagruns_in_otherdags(self, for _ in range(9): dr = dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, state=State.QUEUED) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor(do_update=False) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = MockExecutor(do_update=False) + self.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) - self.scheduler_job.job_runner._start_queued_dagruns(session) + self.job_runner._start_queued_dagruns(session) session.flush() - self.scheduler_job.job_runner._start_queued_dagruns(session) + self.job_runner._start_queued_dagruns(session) session.flush() dag1_running_count = ( @@ -3858,11 +3945,13 @@ def test_start_queued_dagruns_do_follow_execution_date_order(self, dag_maker): run_id=f"dagrun_{i}", run_type=DagRunType.SCHEDULED, state=State.QUEUED, execution_date=date ) date = dr.execution_date + timedelta(hours=1) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor(do_update=False) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = MockExecutor(do_update=False) + self.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) - self.scheduler_job.job_runner._start_queued_dagruns(session) + self.job_runner._start_queued_dagruns(session) session.flush() dr = DagRun.find(run_id="dagrun_0") ti = dr[0].get_task_instance(task_id="mytask", session=session) @@ -3874,7 +3963,7 @@ def test_start_queued_dagruns_do_follow_execution_date_order(self, dag_maker): session.merge(dr[0]) session.flush() assert dr[0].state == State.SUCCESS - self.scheduler_job.job_runner._start_queued_dagruns(session) + self.job_runner._start_queued_dagruns(session) session.flush() dr = DagRun.find(run_id="dagrun_1") assert len(session.query(DagRun).filter(DagRun.state == State.RUNNING).all()) == 1 @@ -3927,9 +4016,11 @@ def test_no_dagruns_would_stuck_in_running(self, dag_maker): dr = dag_maker.create_dagrun(run_id=f"dr2_run_{i+1}", state=State.QUEUED, execution_date=date) date = dr.execution_date + timedelta(hours=1) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor(do_update=False) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = MockExecutor(do_update=False) + self.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) ti = TaskInstance(task=task1, execution_date=DEFAULT_DATE) ti.refresh_from_db() @@ -3938,8 +4029,8 @@ def 
test_no_dagruns_would_stuck_in_running(self, dag_maker): session.flush() # Run the scheduler loop with mock.patch.object(settings, "USE_JOB_SCHEDULE", False): - self.scheduler_job.job_runner._do_scheduling(session) - self.scheduler_job.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) assert DagRun.find(run_id="dr1_run_1")[0].state == State.SUCCESS assert DagRun.find(run_id="dr1_run_2")[0].state == State.RUNNING @@ -3968,8 +4059,10 @@ def test_dag_file_processor_process_task_instances(self, state, start_date, end_ with dag_maker(dag_id="test_scheduler_process_execute_task"): BashOperator(task_id="dummy", bash_command="echo hi") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.processor_agent = mock.MagicMock() dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) assert dr is not None @@ -3980,7 +4073,7 @@ def test_dag_file_processor_process_task_instances(self, state, start_date, end_ ti.start_date = start_date ti.end_date = end_date - self.scheduler_job.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) assert session.query(TaskInstance).filter_by(state=State.SCHEDULED).count() == 1 session.refresh(ti) @@ -4012,8 +4105,10 @@ def test_dag_file_processor_process_task_instances_with_max_active_tis_per_dag( with dag_maker(dag_id="test_scheduler_process_execute_task_with_max_active_tis_per_dag"): BashOperator(task_id="dummy", max_active_tis_per_dag=2, bash_command="echo Hi") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.processor_agent = mock.MagicMock() dr = dag_maker.create_dagrun( run_type=DagRunType.SCHEDULED, @@ -4026,7 +4121,7 @@ def test_dag_file_processor_process_task_instances_with_max_active_tis_per_dag( ti.start_date = start_date ti.end_date = end_date - self.scheduler_job.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) assert session.query(TaskInstance).filter_by(state=State.SCHEDULED).count() == 1 session.refresh(ti) @@ -4064,8 +4159,10 @@ def test_dag_file_processor_process_task_instances_depends_on_past( BashOperator(task_id="dummy1", bash_command="echo hi") BashOperator(task_id="dummy2", bash_command="echo hi") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.processor_agent = mock.MagicMock() dr = dag_maker.create_dagrun( run_type=DagRunType.SCHEDULED, ) @@ -4078,7 +4175,7 @@ def test_dag_file_processor_process_task_instances_depends_on_past( ti.start_date = start_date ti.end_date = end_date - self.scheduler_job.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) assert session.query(TaskInstance).filter_by(state=State.SCHEDULED).count() == 2 session.refresh(tis[0]) @@ -4093,19 +4190,23 @@ def test_scheduler_job_add_new_task(self, dag_maker): with dag_maker(dag_id="test_scheduler_add_new_task") as dag: BashOperator(task_id="dummy", 
bash_command="echo test") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.dagbag = dag_maker.dagbag + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.dagbag = dag_maker.dagbag session = settings.Session() orm_dag = dag_maker.dag_model assert orm_dag is not None - if self.scheduler_job.job_runner.processor_agent: - self.scheduler_job.job_runner.processor_agent.end() - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() - dag = self.scheduler_job.job_runner.dagbag.get_dag("test_scheduler_add_new_task", session=session) - self.scheduler_job.job_runner._create_dag_runs([orm_dag], session) + if self.job_runner.processor_agent: + self.job_runner.processor_agent.end() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.processor_agent = mock.MagicMock() + dag = self.job_runner.dagbag.get_dag("test_scheduler_add_new_task", session=session) + self.job_runner._create_dag_runs([orm_dag], session) drs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(drs) == 1 @@ -4117,7 +4218,7 @@ def test_scheduler_job_add_new_task(self, dag_maker): BashOperator(task_id="dummy2", dag=dag, bash_command="echo test") SerializedDagModel.write_dag(dag=dag) - self.scheduler_job.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) assert session.query(TaskInstance).filter_by(state=State.SCHEDULED).count() == 2 session.flush() @@ -4139,8 +4240,10 @@ def test_runs_respected_after_clear(self, dag_maker): ) as dag: BashOperator(task_id="dummy", bash_command="echo Hi") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.processor_agent = mock.MagicMock() session = settings.Session() dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.QUEUED) @@ -4151,7 +4254,7 @@ def test_runs_respected_after_clear(self, dag_maker): assert len(DagRun.find(dag_id=dag.dag_id, state=State.QUEUED, session=session)) == 3 session = settings.Session() - self.scheduler_job.job_runner._start_queued_dagruns(session) + self.job_runner._start_queued_dagruns(session) session.flush() # Assert that only 1 dagrun is active assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 1 @@ -4192,8 +4295,10 @@ def test_timeout_triggers(self, dag_maker): session.flush() # Boot up the scheduler and make it check timeouts - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.check_trigger_timeouts(session=session) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.check_trigger_timeouts(session=session) # Make sure that TI1 is now scheduled to fail, and 2 wasn't touched session.refresh(ti1) @@ -4204,12 +4309,13 @@ def test_timeout_triggers(self, dag_maker): def test_find_zombies_nothing(self): executor = MockExecutor(do_update=False) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(), executor=executor) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=executor) + self.job_runner = 
SchedulerJobRunner(scheduler_job) + self.job_runner.processor_agent = mock.MagicMock() - self.scheduler_job.job_runner._find_zombies() + self.job_runner._find_zombies() - self.scheduler_job.executor.callback_sink.send.assert_not_called() + scheduler_job.executor.callback_sink.send.assert_not_called() def test_find_zombies(self, load_examples): dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False) @@ -4225,9 +4331,10 @@ def test_find_zombies(self, load_examples): session=session, ) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + scheduler_job.executor = MockExecutor() + self.job_runner.processor_agent = mock.MagicMock() # We will provision 2 tasks so we can check we only find zombies from this scheduler tasks_to_setup = ["branching", "run_this_first"] @@ -4237,7 +4344,8 @@ def test_find_zombies(self, load_examples): ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING) ti.queued_by_job_id = 999 - local_job = Job(job_runner=LocalTaskJobRunner(ti), dag_id=ti.dag_id) + local_job = Job(dag_id=ti.dag_id) + LocalTaskJobRunner(job=local_job, task_instance=ti) local_job.state = State.SHUTDOWN session.add(local_job) @@ -4249,16 +4357,16 @@ def test_find_zombies(self, load_examples): assert task.task_id == "run_this_first" # Make sure we have the task/ti we expect - ti.queued_by_job_id = self.scheduler_job.id + ti.queued_by_job_id = scheduler_job.id session.flush() - self.scheduler_job.job_runner._find_zombies() + self.job_runner._find_zombies() - self.scheduler_job.executor.callback_sink.send.assert_called_once() - requests = self.scheduler_job.executor.callback_sink.send.call_args[0] + scheduler_job.executor.callback_sink.send.assert_called_once() + requests = scheduler_job.executor.callback_sink.send.call_args[0] assert 1 == len(requests) assert requests[0].full_filepath == dag.fileloc - assert requests[0].msg == str(self.scheduler_job.job_runner._generate_zombie_message_details(ti)) + assert requests[0].msg == str(self.job_runner._generate_zombie_message_details(ti)) assert requests[0].is_failure_callback is True assert isinstance(requests[0].simple_task_instance, SimpleTaskInstance) assert ti.dag_id == requests[0].simple_task_instance.dag_id @@ -4288,9 +4396,9 @@ def test_zombie_message(self, load_examples): session=session, ) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=MockExecutor()) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + self.job_runner.processor_agent = mock.MagicMock() # We will provision 2 tasks so we can check we only find zombies from this scheduler tasks_to_setup = ["branching", "run_this_first"] @@ -4300,7 +4408,7 @@ def test_zombie_message(self, load_examples): ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING) ti.queued_by_job_id = 999 - local_job = Job(job_runner=LocalTaskJobRunner(ti), dag_id=ti.dag_id) + local_job = Job(dag_id=ti.dag_id) local_job.state = State.SHUTDOWN session.add(local_job) @@ -4312,10 +4420,10 @@ def test_zombie_message(self, load_examples): assert task.task_id == "run_this_first" # Make sure we have the task/ti we expect - ti.queued_by_job_id = self.scheduler_job.id + 
ti.queued_by_job_id = scheduler_job.id session.flush() - zombie_message = self.scheduler_job.job_runner._generate_zombie_message_details(ti) + zombie_message = self.job_runner._generate_zombie_message_details(ti) assert zombie_message == { "DAG Id": "example_branch_operator", "Task Id": "run_this_first", @@ -4326,7 +4434,7 @@ def test_zombie_message(self, load_examples): ti.map_index = 2 ti.external_executor_id = "abcdefg" - zombie_message = self.scheduler_job.job_runner._generate_zombie_message_details(ti) + zombie_message = self.job_runner._generate_zombie_message_details(ti) assert zombie_message == { "DAG Id": "example_branch_operator", "Task Id": "run_this_first", @@ -4359,7 +4467,9 @@ def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_proce task = dag.get_task(task_id="run_this_last") ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING) - local_job = Job(job_runner=LocalTaskJobRunner(ti), dag_id=ti.dag_id) + + local_job = Job(dag_id=ti.dag_id) + LocalTaskJobRunner(job=local_job, task_instance=ti) local_job.state = State.SHUTDOWN session.add(local_job) session.flush() @@ -4370,23 +4480,25 @@ def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_proce ti.job_id = local_job.id session.flush() - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = MockExecutor() + self.job_runner.processor_agent = mock.MagicMock() - self.scheduler_job.job_runner._find_zombies() + self.job_runner._find_zombies() - self.scheduler_job.executor.callback_sink.send.assert_called_once() + scheduler_job.executor.callback_sink.send.assert_called_once() expected_failure_callback_requests = [ TaskCallbackRequest( full_filepath=dag.fileloc, simple_task_instance=SimpleTaskInstance.from_ti(ti), processor_subdir=TEST_DAG_FOLDER, - msg=str(self.scheduler_job.job_runner._generate_zombie_message_details(ti)), + msg=str(self.job_runner._generate_zombie_message_details(ti)), ) ] - callback_requests = self.scheduler_job.executor.callback_sink.send.call_args[0] + callback_requests = scheduler_job.executor.callback_sink.send.call_args[0] assert len(callback_requests) == 1 assert {zombie.simple_task_instance.key for zombie in expected_failure_callback_requests} == { result.simple_task_instance.key for result in callback_requests @@ -4411,14 +4523,14 @@ def test_cleanup_stale_dags(self): session.flush() - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() + scheduler_job = Job(executor=MockExecutor()) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + self.job_runner.processor_agent = mock.MagicMock() active_dag_count = session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar() assert active_dag_count == 2 - self.scheduler_job.job_runner._cleanup_stale_dags(session) + self.job_runner._cleanup_stale_dags(session) session.flush() @@ -4483,7 +4595,7 @@ def watch_heartbeat(*args, **kwargs): heartbeat_spy = mock.patch.object(job_runner, "heartbeat", new=watch_heartbeat) with heartbeat_spy, set_state_spy, do_scheduling_spy, executor_events_spy: - run_job(job_runner.job) + run_job(job_runner.job, execute_callable=job_runner._execute) 
@pytest.mark.long_running @pytest.mark.parametrize("dag_id", ["test_mapped_classic", "test_mapped_taskflow"]) @@ -4505,8 +4617,8 @@ def test_mapped_dag(self, dag_id, session): executor = SequentialExecutor() - job = Job(job_runner=SchedulerJobRunner(subdir=dag.fileloc, executor=executor)) - + job = Job(executor=executor) + self.job_runner = SchedulerJobRunner(job=job, subdir=dag.fileloc) self.run_scheduler_until_dagrun_terminal(job) dr.refresh_from_db(session) @@ -4521,25 +4633,27 @@ def test_should_mark_empty_task_as_success(self): dagbag = DagBag(dag_folder=dag_file, include_examples=False, read_dags_from_db=False) dagbag.sync_to_db() - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner.processor_agent = mock.MagicMock() - dag = self.scheduler_job.job_runner.dagbag.get_dag("test_only_empty_tasks") + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.processor_agent = mock.MagicMock() + dag = self.job_runner.dagbag.get_dag("test_only_empty_tasks") # Create DagRun session = settings.Session() orm_dag = session.get(DagModel, dag.dag_id) - self.scheduler_job.job_runner._create_dag_runs([orm_dag], session) + self.job_runner._create_dag_runs([orm_dag], session) drs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(drs) == 1 dr = drs[0] # Schedule TaskInstances - self.scheduler_job.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) with create_session() as session: tis = session.query(TaskInstance).all() - dags = self.scheduler_job.job_runner.dagbag.dags.values() + dags = self.job_runner.dagbag.dags.values() assert ["test_only_empty_tasks"] == [dag.dag_id for dag in dags] assert 6 == len(tis) assert { @@ -4562,7 +4676,7 @@ def test_should_mark_empty_task_as_success(self): assert end_date is None assert duration is None - self.scheduler_job.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) with create_session() as session: tis = session.query(TaskInstance).all() @@ -4601,12 +4715,14 @@ def test_catchup_works_correctly(self, dag_maker): ) as dag: EmptyOperator(task_id="dummy") - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - self.scheduler_job.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) - self.scheduler_job.job_runner._create_dag_runs([dag_maker.dag_model], session) - self.scheduler_job.job_runner._start_queued_dagruns(session) + scheduler_job.executor = MockExecutor() + self.job_runner.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) + + self.job_runner._create_dag_runs([dag_maker.dag_model], session) + self.job_runner._start_queued_dagruns(session) # first dagrun execution date is DEFAULT_DATE 2016-01-01T00:00:00+00:00 dr = DagRun.find(execution_date=DEFAULT_DATE, session=session)[0] ti = dr.get_task_instance(task_id="dummy") @@ -4614,11 +4730,11 @@ def test_catchup_works_correctly(self, dag_maker): session.merge(ti) session.flush() - self.scheduler_job.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) session.flush() # Run the second time so _update_dag_next_dagrun will run - self.scheduler_job.job_runner._schedule_dag_run(dr, session) + self.job_runner._schedule_dag_run(dr, session) session.flush() dag.catchup = False @@ 
-4626,7 +4742,7 @@ def test_catchup_works_correctly(self, dag_maker): assert not dag.catchup dm = DagModel.get_dagmodel(dag.dag_id) - self.scheduler_job.job_runner._create_dag_runs([dm], session) + self.job_runner._create_dag_runs([dm], session) # Check catchup worked correctly by ensuring execution_date is quite new # Our dag is a daily dag @@ -4657,9 +4773,11 @@ def test_update_dagrun_state_for_paused_dag(self, dag_maker, session): assert scheduled_run.state == State.RUNNING - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - self.scheduler_job.job_runner._update_dag_run_state_for_paused_dags(session=session) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = MockExecutor() + self.job_runner._update_dag_run_state_for_paused_dags(session=session) session.flush() # TI still running, DagRun left in running @@ -4668,7 +4786,7 @@ def test_update_dagrun_state_for_paused_dag(self, dag_maker, session): prior_last_scheduling_decision = scheduled_run.last_scheduling_decision # Make sure we don't constantly try dagruns over and over - self.scheduler_job.job_runner._update_dag_run_state_for_paused_dags(session=session) + self.job_runner._update_dag_run_state_for_paused_dags(session=session) (scheduled_run,) = DagRun.find(dag_id=dag.dag_id, run_type=DagRunType.SCHEDULED, session=session) assert scheduled_run.state == State.RUNNING # last_scheduling_decision is bumped by update_state, so check that to determine if we tried again @@ -4676,7 +4794,7 @@ def test_update_dagrun_state_for_paused_dag(self, dag_maker, session): # Once the TI is in a terminal state though, DagRun goes to success ti.set_state(TaskInstanceState.SUCCESS) - self.scheduler_job.job_runner._update_dag_run_state_for_paused_dags(session=session) + self.job_runner._update_dag_run_state_for_paused_dags(session=session) (scheduled_run,) = DagRun.find(dag_id=dag.dag_id, run_type=DagRunType.SCHEDULED, session=session) assert scheduled_run.state == State.SUCCESS @@ -4699,9 +4817,11 @@ def test_update_dagrun_state_for_paused_dag_not_for_backfill(self, dag_maker, se assert backfill_run.state == State.RUNNING - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.executor = MockExecutor() - self.scheduler_job.job_runner._update_dag_run_state_for_paused_dags() + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + scheduler_job.executor = MockExecutor() + self.job_runner._update_dag_run_state_for_paused_dags() session.flush() (backfill_run,) = DagRun.find(dag_id=dag.dag_id, run_type=DagRunType.BACKFILL_JOB, session=session) @@ -4725,8 +4845,10 @@ def test_dataset_orphaning(self, dag_maker, session): with dag_maker(dag_id="datasets-1", schedule=[dataset1], session=session): BashOperator(task_id="task", bash_command="echo 1", outlets=[dataset3]) - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - self.scheduler_job.job_runner._orphan_unreferenced_datasets(session=session) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner._orphan_unreferenced_datasets(session=session) session.flush() # and find the orphans @@ -4774,8 +4896,9 @@ def test_schedule_dag_run_with_upstream_skip(dag_maker, session): # dag_runs = DagRun.find(dag_id='test_task_with_upstream_skip_dag') # dag_file_processor._process_task_instances(dag, 
dag_runs=dag_runs) - scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=os.devnull)) - scheduler_job.job_runner._schedule_dag_run(dr, session) + scheduler_job = Job() + job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + job_runner._schedule_dag_run(dr, session) session.flush() tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)} assert tis[dummy1.task_id].state == State.SKIPPED @@ -4809,9 +4932,8 @@ def per_test(self) -> Generator: yield - if self.scheduler_job and self.scheduler_job.job_runner.processor_agent: # type: ignore[attr-defined] - self.scheduler_job.job_runner.processor_agent.end() # type: ignore[attr-defined] - self.scheduler_job = None + if self.job_runner.processor_agent: # type: ignore[attr-defined] + self.job_runner.processor_agent.end() # type: ignore[attr-defined] self.clean_db() @pytest.mark.parametrize( @@ -4863,16 +4985,18 @@ def test_execute_queries_count_with_harvested_dags(self, expected_query_count, d mock_agent = mock.MagicMock() - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=PERF_DAGS_FOLDER, num_runs=1)) - self.scheduler_job.executor = MockExecutor(do_update=False) - self.scheduler_job.heartbeat = mock.MagicMock() - self.scheduler_job.job_runner.processor_agent = mock_agent + scheduler_job = Job( + executor=MockExecutor(do_update=False), + ) + scheduler_job.heartbeat = mock.MagicMock() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=PERF_DAGS_FOLDER, num_runs=1) + self.job_runner.processor_agent = mock_agent with assert_queries_count(expected_query_count, margin=15): with mock.patch.object(DagRun, "next_dagruns_to_examine") as mock_dagruns: mock_dagruns.return_value = dagruns - self.scheduler_job.job_runner._run_scheduler_loop() + self.job_runner._run_scheduler_loop() @pytest.mark.parametrize( "expected_query_counts, dag_count, task_count, start_ago, schedule_interval, shape", @@ -4938,10 +5062,10 @@ def test_process_dags_queries_count( mock_agent = mock.MagicMock() - self.scheduler_job = Job(job_runner=SchedulerJobRunner(subdir=PERF_DAGS_FOLDER, num_runs=1)) - self.scheduler_job.executor = MockExecutor(do_update=False) - self.scheduler_job.heartbeat = mock.MagicMock() - self.scheduler_job.job_runner.processor_agent = mock_agent + scheduler_job = Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False)) + scheduler_job.heartbeat = mock.MagicMock() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=PERF_DAGS_FOLDER, num_runs=1) + self.job_runner.processor_agent = mock_agent failures = [] # Collects assertion errors and report all of them at the end. 
message = "Expected {expected_count} query, but got {current_count} located at:" @@ -4949,7 +5073,7 @@ def test_process_dags_queries_count( with create_session() as session: try: with assert_queries_count(expected_query_count, message_fmt=message, margin=15): - self.scheduler_job.job_runner._do_scheduling(session) + self.job_runner._do_scheduling(session) except AssertionError as e: failures.append(str(e)) if failures: diff --git a/tests/jobs/test_triggerer_job.py b/tests/jobs/test_triggerer_job.py index d22a371a9f52b..0e19927f17911 100644 --- a/tests/jobs/test_triggerer_job.py +++ b/tests/jobs/test_triggerer_job.py @@ -125,23 +125,24 @@ def __init__(self, password, **kwargs): trigger = SuccessTrigger() op = SensitiveArgOperator(task_id="sensitive_arg_task", password="some_password") create_trigger_in_db(session, trigger, operator=op) - job = Job(job_runner=TriggererJobRunner()) - job.job_runner.load_triggers() + triggerer_job = Job() + triggerer_job_runner = TriggererJobRunner(triggerer_job) + triggerer_job_runner.load_triggers() # Now, start TriggerRunner up (and set it as a daemon thread during tests) - job.job_runner.daemon = True - job.job_runner.trigger_runner.start() + triggerer_job_runner.daemon = True + triggerer_job_runner.trigger_runner.start() try: # Wait for up to 3 seconds for it to fire and appear in the event queue for _ in range(30): - if job.job_runner.trigger_runner.events: - assert list(job.job_runner.trigger_runner.events) == [(1, TriggerEvent(True))] + if triggerer_job_runner.trigger_runner.events: + assert list(triggerer_job_runner.trigger_runner.events) == [(1, TriggerEvent(True))] break time.sleep(0.1) else: pytest.fail("TriggerRunner never sent the trigger event out") finally: # We always have to stop the runner - job.job_runner.trigger_runner.stop = True + triggerer_job_runner.trigger_runner.stop = True stdout = capsys.readouterr().out assert "test_dag/test_run/sensitive_arg_task/-1/1 (ID 1) starting" in stdout assert "some_password" not in stdout @@ -150,7 +151,7 @@ def __init__(self, password, **kwargs): def test_is_alive(): """Checks the heartbeat logic""" # Current time - triggerer_job = Job(job_runner=TriggererJobRunner(None), heartrate=10, state=State.RUNNING) + triggerer_job = Job(heartrate=10, state=State.RUNNING) assert triggerer_job.is_alive() # Slightly old, but still fresh @@ -170,15 +171,16 @@ def test_is_alive(): def test_is_needed(session): """Checks the triggerer-is-needed logic""" # No triggers, no need - triggerer_job = Job(job_runner=TriggererJobRunner(None), heartrate=10, state=State.RUNNING) - assert triggerer_job.job_runner.is_needed() is False + triggerer_job = Job(heartrate=10, state=State.RUNNING) + triggerer_job_runner = TriggererJobRunner(triggerer_job) + assert triggerer_job_runner.is_needed() is False # Add a trigger, it's needed trigger = TimeDeltaTrigger(datetime.timedelta(days=7)) trigger_orm = Trigger.from_object(trigger) trigger_orm.id = 1 session.add(trigger_orm) session.commit() - assert triggerer_job.job_runner.is_needed() is True + assert triggerer_job_runner.is_needed() is True def test_capacity_decode(): @@ -192,8 +194,9 @@ def test_capacity_decode(): None, ] for input_str in variants: - job = Job(job_runner=TriggererJobRunner(capacity=input_str)) - assert job.job_runner.capacity == input_str or job.job_runner.capacity == 1000 + job = Job() + job_runner = TriggererJobRunner(job, capacity=input_str) + assert job_runner.capacity == input_str or job_runner.capacity == 1000 # Negative cases variants = [ @@ -204,7 +207,8 @@ def 
test_capacity_decode(): ] for input_str in variants: with pytest.raises(ValueError): - TriggererJobRunner(capacity=input_str) + job = Job() + TriggererJobRunner(job=job, capacity=input_str) def test_trigger_lifecycle(session): @@ -217,18 +221,19 @@ def test_trigger_lifecycle(session): trigger = TimeDeltaTrigger(datetime.timedelta(days=7)) dag_model, run, trigger_orm, task_instance = create_trigger_in_db(session, trigger) # Make a TriggererJobRunner and have it retrieve DB tasks - job = Job(job_runner=TriggererJobRunner()) - job.job_runner.load_triggers() + job = Job() + job_runner = TriggererJobRunner(job) + job_runner.load_triggers() # Make sure it turned up in TriggerRunner's queue - assert [x for x, y in job.job_runner.trigger_runner.to_create] == [1] + assert [x for x, y in job_runner.trigger_runner.to_create] == [1] # Now, start TriggerRunner up (and set it as a daemon thread during tests) - job.job_runner.daemon = True - job.job_runner.trigger_runner.start() + job_runner.daemon = True + job_runner.trigger_runner.start() try: # Wait for up to 3 seconds for it to appear in the TriggerRunner's storage for _ in range(30): - if job.job_runner.trigger_runner.triggers: - assert list(job.job_runner.trigger_runner.triggers.keys()) == [1] + if job_runner.trigger_runner.triggers: + assert list(job_runner.trigger_runner.triggers.keys()) == [1] break time.sleep(0.1) else: @@ -237,17 +242,17 @@ def test_trigger_lifecycle(session): session.delete(trigger_orm) session.commit() # Re-load the triggers - job.job_runner.load_triggers() + job_runner.load_triggers() # Wait for up to 3 seconds for it to vanish from the TriggerRunner's storage for _ in range(30): - if not job.job_runner.trigger_runner.triggers: + if not job_runner.trigger_runner.triggers: break time.sleep(0.1) else: pytest.fail("TriggerRunner never deleted trigger") finally: # We always have to stop the runner - job.job_runner.trigger_runner.stop = True + job_runner.trigger_runner.stop = True def test_trigger_create_race_condition_18392(session, tmp_path): @@ -339,21 +344,22 @@ def handle_events(self): session.commit() - job = Job(job_runner=TriggererJob_()) - job.job_runner.trigger_runner = TriggerRunner_() - thread = Thread(target=job.job_runner._execute) + job = Job() + job_runner = TriggererJob_(job) + job_runner.trigger_runner = TriggerRunner_() + thread = Thread(target=job_runner._execute) thread.start() try: for _ in range(40): time.sleep(0.1) # ready to evaluate after 2 loops - if getattr(job.job_runner.trigger_runner, "loop_count", 0) >= 2: + if getattr(job_runner.trigger_runner, "loop_count", 0) >= 2: break else: pytest.fail("did not observe 2 loops in the runner thread") finally: - job.job_runner.trigger_runner.stop = True - job.job_runner.trigger_runner.join() + job_runner.trigger_runner.stop = True + job_runner.trigger_runner.join() thread.join() instances = path.read_text().splitlines() assert len(instances) == 1 @@ -372,10 +378,11 @@ def test_trigger_from_dead_triggerer(session): session.add(trigger_orm) session.commit() # Make a TriggererJobRunner and have it retrieve DB tasks - job = Job(job_runner=TriggererJobRunner()) - job.job_runner.load_triggers() + job = Job() + job_runner = TriggererJobRunner(job) + job_runner.load_triggers() # Make sure it turned up in TriggerRunner's queue - assert [x for x, y in job.job_runner.trigger_runner.to_create] == [1] + assert [x for x, y in job_runner.trigger_runner.to_create] == [1] def test_trigger_from_expired_triggerer(session): @@ -390,7 +397,7 @@ def 
test_trigger_from_expired_triggerer(session): trigger_orm.triggerer_id = 42 session.add(trigger_orm) # Use a TriggererJobRunner with an expired heartbeat - triggerer_job_orm = Job(job_runner=TriggererJobRunner()) + triggerer_job_orm = Job(TriggererJobRunner.job_type) triggerer_job_orm.id = 42 triggerer_job_orm.start_date = timezone.utcnow() - datetime.timedelta(hours=1) triggerer_job_orm.end_date = None @@ -398,10 +405,11 @@ def test_trigger_from_expired_triggerer(session): session.add(triggerer_job_orm) session.commit() # Make a TriggererJobRunner and have it retrieve DB tasks - job = Job(job_runner=TriggererJobRunner()) - job.job_runner.load_triggers() + job = Job(TriggererJobRunner.job_type) + job_runner = TriggererJobRunner(job) + job_runner.load_triggers() # Make sure it turned up in TriggerRunner's queue - assert [x for x, y in job.job_runner.trigger_runner.to_create] == [1] + assert [x for x, y in job_runner.trigger_runner.to_create] == [1] def test_trigger_firing(session): @@ -413,23 +421,24 @@ def test_trigger_firing(session): trigger = SuccessTrigger() create_trigger_in_db(session, trigger) # Make a TriggererJobRunner and have it retrieve DB tasks - job = Job(job_runner=TriggererJobRunner()) - job.job_runner.load_triggers() + job = Job() + job_runner = TriggererJobRunner(job) + job_runner.load_triggers() # Now, start TriggerRunner up (and set it as a daemon thread during tests) - job.job_runner.daemon = True - job.job_runner.trigger_runner.start() + job_runner.daemon = True + job_runner.trigger_runner.start() try: # Wait for up to 3 seconds for it to fire and appear in the event queue for _ in range(30): - if job.job_runner.trigger_runner.events: - assert list(job.job_runner.trigger_runner.events) == [(1, TriggerEvent(True))] + if job_runner.trigger_runner.events: + assert list(job_runner.trigger_runner.events) == [(1, TriggerEvent(True))] break time.sleep(0.1) else: pytest.fail("TriggerRunner never sent the trigger event out") finally: # We always have to stop the runner - job.job_runner.trigger_runner.stop = True + job_runner.trigger_runner.stop = True def test_trigger_failing(session): @@ -441,17 +450,18 @@ def test_trigger_failing(session): trigger = FailureTrigger() create_trigger_in_db(session, trigger) # Make a TriggererJobRunner and have it retrieve DB tasks - job = Job(job_runner=TriggererJobRunner()) - job.job_runner.load_triggers() + job = Job() + job_runner = TriggererJobRunner(job) + job_runner.load_triggers() # Now, start TriggerRunner up (and set it as a daemon thread during tests) - job.job_runner.daemon = True - job.job_runner.trigger_runner.start() + job_runner.daemon = True + job_runner.trigger_runner.start() try: # Wait for up to 3 seconds for it to fire and appear in the event queue for _ in range(30): - if job.job_runner.trigger_runner.failed_triggers: - assert len(job.job_runner.trigger_runner.failed_triggers) == 1 - trigger_id, exc = list(job.job_runner.trigger_runner.failed_triggers)[0] + if job_runner.trigger_runner.failed_triggers: + assert len(job_runner.trigger_runner.failed_triggers) == 1 + trigger_id, exc = list(job_runner.trigger_runner.failed_triggers)[0] assert trigger_id == 1 assert isinstance(exc, ValueError) assert exc.args[0] == "Deliberate trigger failure" @@ -461,7 +471,7 @@ def test_trigger_failing(session): pytest.fail("TriggerRunner never marked the trigger as failed") finally: # We always have to stop the runner - job.job_runner.trigger_runner.stop = True + job_runner.trigger_runner.stop = True def test_trigger_cleanup(session): @@ 
-506,14 +516,15 @@ def test_invalid_trigger(session, dag_maker): session.commit() # Make a TriggererJobRunner and have it retrieve DB tasks - job = Job(job_runner=TriggererJobRunner()) - job.job_runner.load_triggers() + job = Job() + job_runner = TriggererJobRunner(job) + job_runner.load_triggers() # Make sure it turned up in the failed queue - assert len(job.job_runner.trigger_runner.failed_triggers) == 1 + assert len(job_runner.trigger_runner.failed_triggers) == 1 # Run the failed trigger handler - job.job_runner.handle_failed_triggers() + job_runner.handle_failed_triggers() # Make sure it marked the task instance as failed (which is actually the # scheduled state with a payload to make it fail) @@ -530,7 +541,8 @@ def test_handler_config_respects_donot_wrap(mock_configure, should_wrap): from airflow.jobs import triggerer_job_runner triggerer_job_runner.DISABLE_WRAPPER = not should_wrap - TriggererJobRunner() + job = Job() + TriggererJobRunner(job=job) if should_wrap: mock_configure.assert_called() else: @@ -540,7 +552,8 @@ def test_handler_config_respects_donot_wrap(mock_configure, should_wrap): @patch("airflow.jobs.triggerer_job_runner.setup_queue_listener") def test_triggerer_job_always_creates_listener(mock_setup): mock_setup.assert_not_called() - TriggererJobRunner() + job = Job() + TriggererJobRunner(job=job) mock_setup.assert_called() diff --git a/tests/listeners/test_listeners.py b/tests/listeners/test_listeners.py index a1c5a784b7906..15f9782c382da 100644 --- a/tests/listeners/test_listeners.py +++ b/tests/listeners/test_listeners.py @@ -80,9 +80,10 @@ def test_multiple_listeners(create_task_instance, session=None): lm.add_listener(full_listener) lm.add_listener(lifecycle_listener) - job = Job(job_runner=MockJobRunner()) + job = Job() + job_runner = MockJobRunner(job=job) try: - run_job(job) + run_job(job=job, execute_callable=job_runner._execute) except NotImplementedError: pass # just for lifecycle diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py index 827312d9f23f6..ac5bba4606de2 100644 --- a/tests/models/test_trigger.py +++ b/tests/models/test_trigger.py @@ -140,14 +140,17 @@ def test_assign_unassigned(session, create_task_instance): """ Tests that unassigned triggers of all appropriate states are assigned. 
""" - finished_triggerer = Job(job_runner=TriggererJobRunner(None), heartrate=10, state=State.SUCCESS) + finished_triggerer = Job(heartrate=10, state=State.SUCCESS) + TriggererJobRunner(finished_triggerer) finished_triggerer.end_date = timezone.utcnow() - datetime.timedelta(hours=1) session.add(finished_triggerer) assert not finished_triggerer.is_alive() - healthy_triggerer = Job(job_runner=TriggererJobRunner(None), heartrate=10, state=State.RUNNING) + healthy_triggerer = Job(heartrate=10, state=State.RUNNING) + TriggererJobRunner(healthy_triggerer) session.add(healthy_triggerer) assert healthy_triggerer.is_alive() - new_triggerer = Job(job_runner=TriggererJobRunner(None), heartrate=10, state=State.RUNNING) + new_triggerer = Job(heartrate=10, state=State.RUNNING) + TriggererJobRunner(new_triggerer) session.add(new_triggerer) assert new_triggerer.is_alive() session.commit() diff --git a/tests/serialization/test_pydantic_models.py b/tests/serialization/test_pydantic_models.py index 8947afea35634..b5e471b044c41 100644 --- a/tests/serialization/test_pydantic_models.py +++ b/tests/serialization/test_pydantic_models.py @@ -76,7 +76,8 @@ def test_serializing_pydantic_dagrun(session, create_task_instance): def test_serializing_pydantic_local_task_job(session, create_task_instance): dag_id = "test-dag" ti = create_task_instance(dag_id=dag_id, session=session) - ltj = Job(job_runner=LocalTaskJobRunner(task_instance=ti), dag_id=ti.dag_id) + ltj = Job(dag_id=ti.dag_id) + LocalTaskJobRunner(job=ltj, task_instance=ti) ltj.state = State.RUNNING session.commit() pydantic_job = JobPydantic.from_orm(ltj) diff --git a/tests/task/task_runner/test_base_task_runner.py b/tests/task/task_runner/test_base_task_runner.py index fe5dcc099ba5c..90edf7e32086d 100644 --- a/tests/task/task_runner/test_base_task_runner.py +++ b/tests/task/task_runner/test_base_task_runner.py @@ -39,9 +39,9 @@ def test_config_copy_mode(tmp_configuration_copy, subprocess_call, dag_maker, im dr = dag_maker.create_dagrun() ti = dr.task_instances[0] - task_runner = LocalTaskJobRunner(ti) - job = Job(job_runner=task_runner, dag_id=ti.dag_id) - runner = BaseTaskRunner(job) + job = Job(dag_id=ti.dag_id) + job_runner = LocalTaskJobRunner(job=job, task_instance=ti) + runner = BaseTaskRunner(job_runner) # So we don't try to delete it -- cos the file won't exist del runner._cfg_path diff --git a/tests/task/task_runner/test_cgroup_task_runner.py b/tests/task/task_runner/test_cgroup_task_runner.py index 0d40f0c00282d..c3999e19b09a3 100644 --- a/tests/task/task_runner/test_cgroup_task_runner.py +++ b/tests/task/task_runner/test_cgroup_task_runner.py @@ -32,12 +32,13 @@ def test_cgroup_task_runner_super_calls(self, mock_super_on_finish, mock_super_i and when task finishes, CgroupTaskRunner.on_finish() calls super().on_finish() to delete the temp cfg file. 
""" - base_job = mock.Mock() - base_job.task_instance = mock.MagicMock() - base_job.task_instance.run_as_user = None - base_job.task_instance.command_as_list.return_value = ["sleep", "1000"] + Job = mock.Mock() + Job.job_type = None + Job.task_instance = mock.MagicMock() + Job.task_instance.run_as_user = None + Job.task_instance.command_as_list.return_value = ["sleep", "1000"] - runner = CgroupTaskRunner(base_job) + runner = CgroupTaskRunner(Job) assert mock_super_init.called runner.on_finish() diff --git a/tests/task/task_runner/test_standard_task_runner.py b/tests/task/task_runner/test_standard_task_runner.py index d642c5418cecf..7d81025b4d6f9 100644 --- a/tests/task/task_runner/test_standard_task_runner.py +++ b/tests/task/task_runner/test_standard_task_runner.py @@ -94,10 +94,11 @@ def setup_class(self): @patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file") def test_start_and_terminate(self, mock_init): mock_init.return_value = "/tmp/any" - base_job = mock.Mock() - base_job.task_instance = mock.MagicMock() - base_job.task_instance.run_as_user = None - base_job.task_instance.command_as_list.return_value = [ + Job = mock.Mock() + Job.job_type = None + Job.task_instance = mock.MagicMock() + Job.task_instance.run_as_user = None + Job.task_instance.command_as_list.return_value = [ "airflow", "tasks", "run", @@ -105,27 +106,26 @@ def test_start_and_terminate(self, mock_init): "task1", "2016-01-01", ] - base_job.job_runner = LocalTaskJobRunner(base_job.task_instance) - - runner = StandardTaskRunner(base_job) - runner.start() + job_runner = LocalTaskJobRunner(job=Job, task_instance=Job.task_instance) + task_runner = StandardTaskRunner(job_runner) + task_runner.start() # Wait until process sets its pgid to be equal to pid with timeout(seconds=1): while True: - runner_pgid = os.getpgid(runner.process.pid) - if runner_pgid == runner.process.pid: + runner_pgid = os.getpgid(task_runner.process.pid) + if runner_pgid == task_runner.process.pid: break time.sleep(0.01) assert runner_pgid > 0 assert runner_pgid != os.getpgid(0), "Task should be in a different process group to us" processes = list(self._procs_in_pgroup(runner_pgid)) - runner.terminate() + task_runner.terminate() for process in processes: assert not psutil.pid_exists(process.pid), f"{process} is still alive" - assert runner.return_code() is not None + assert task_runner.return_code() is not None def test_notifies_about_start_and_stop(self): path_listener_writer = "/tmp/test_notifies_about_start_and_stop" @@ -150,20 +150,21 @@ def test_notifies_about_start_and_stop(self): start_date=DEFAULT_DATE, ) ti = TaskInstance(task=task, run_id="test") - job1 = Job(job_runner=LocalTaskJobRunner(task_instance=ti, ignore_ti_state=True), dag_id=ti.dag_id) - runner = StandardTaskRunner(job1) - runner.start() + job = Job(dag_id=ti.dag_id) + job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True) + task_runner = StandardTaskRunner(job_runner) + task_runner.start() # Wait until process makes itself the leader of its own process group with timeout(seconds=1): while True: - runner_pgid = os.getpgid(runner.process.pid) - if runner_pgid == runner.process.pid: + runner_pgid = os.getpgid(task_runner.process.pid) + if runner_pgid == task_runner.process.pid: break time.sleep(0.01) # Wait till process finishes - assert runner.return_code(timeout=10) is not None + assert task_runner.return_code(timeout=10) is not None with open(path_listener_writer) as f: assert f.readline() == "on_starting\n" assert f.readline() == 
"on_task_instance_running\n" @@ -193,20 +194,21 @@ def test_notifies_about_fail(self): start_date=DEFAULT_DATE, ) ti = TaskInstance(task=task, run_id="test") - job1 = Job(job_runner=LocalTaskJobRunner(task_instance=ti, ignore_ti_state=True), dag_id=ti.dag_id) - runner = StandardTaskRunner(job1) - runner.start() + job = Job(dag_id=ti.dag_id) + job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True) + task_runner = StandardTaskRunner(job_runner) + task_runner.start() # Wait until process makes itself the leader of its own process group with timeout(seconds=1): while True: - runner_pgid = os.getpgid(runner.process.pid) - if runner_pgid == runner.process.pid: + runner_pgid = os.getpgid(task_runner.process.pid) + if runner_pgid == task_runner.process.pid: break time.sleep(0.01) # Wait till process finishes - assert runner.return_code(timeout=10) is not None + assert task_runner.return_code(timeout=10) is not None with open(path_listener_writer) as f: assert f.readline() == "on_starting\n" assert f.readline() == "on_task_instance_running\n" @@ -216,12 +218,13 @@ def test_notifies_about_fail(self): @patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file") def test_start_and_terminate_run_as_user(self, mock_init): mock_init.return_value = "/tmp/any" - base_job = mock.Mock() - base_job.task_instance = mock.MagicMock() - base_job.task_instance.task_id = "task_id" - base_job.task_instance.dag_id = "dag_id" - base_job.task_instance.run_as_user = getuser() - base_job.task_instance.command_as_list.return_value = [ + Job = mock.Mock() + Job.job_type = None + Job.task_instance = mock.MagicMock() + Job.task_instance.task_id = "task_id" + Job.task_instance.dag_id = "dag_id" + Job.task_instance.run_as_user = getuser() + Job.task_instance.command_as_list.return_value = [ "airflow", "tasks", "test", @@ -229,24 +232,25 @@ def test_start_and_terminate_run_as_user(self, mock_init): "task1", "2016-01-01", ] - base_job.job_runner = LocalTaskJobRunner(base_job.task_instance) - runner = StandardTaskRunner(base_job) - - runner.start() - time.sleep(0.5) + job_runner = LocalTaskJobRunner(job=Job, task_instance=Job.task_instance) + task_runner = StandardTaskRunner(job_runner) - pgid = os.getpgid(runner.process.pid) - assert pgid > 0 - assert pgid != os.getpgid(0), "Task should be in a different process group to us" + task_runner.start() + try: + time.sleep(0.5) - processes = list(self._procs_in_pgroup(pgid)) + pgid = os.getpgid(task_runner.process.pid) + assert pgid > 0 + assert pgid != os.getpgid(0), "Task should be in a different process group to us" - runner.terminate() + processes = list(self._procs_in_pgroup(pgid)) + finally: + task_runner.terminate() for process in processes: assert not psutil.pid_exists(process.pid), f"{process} is still alive" - assert runner.return_code() is not None + assert task_runner.return_code() is not None @propagate_task_logger() @patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file") @@ -257,12 +261,13 @@ def test_early_reap_exit(self, mock_init, caplog): -9 and a log message. 
""" mock_init.return_value = "/tmp/any" - base_job = mock.Mock() - base_job.task_instance = mock.MagicMock() - base_job.task_instance.task_id = "task_id" - base_job.task_instance.dag_id = "dag_id" - base_job.task_instance.run_as_user = getuser() - base_job.task_instance.command_as_list.return_value = [ + Job = mock.Mock() + Job.job_type = None + Job.task_instance = mock.MagicMock() + Job.task_instance.task_id = "task_id" + Job.task_instance.dag_id = "dag_id" + Job.task_instance.run_as_user = getuser() + Job.task_instance.command_as_list.return_value = [ "airflow", "tasks", "test", @@ -270,24 +275,24 @@ def test_early_reap_exit(self, mock_init, caplog): "task1", "2016-01-01", ] - base_job.job_runner = LocalTaskJobRunner(base_job.task_instance) + job_runner = LocalTaskJobRunner(job=Job, task_instance=Job.task_instance) # Kick off the runner - runner = StandardTaskRunner(base_job) - runner.start() + task_runner = StandardTaskRunner(job_runner) + task_runner.start() time.sleep(0.2) # Kill the child process externally from the runner # Note that we have to do this from ANOTHER process, as if we just # call os.kill here we're doing it from the parent process and it # won't be the same as an external kill in terms of OS tracking. - pgid = os.getpgid(runner.process.pid) + pgid = os.getpgid(task_runner.process.pid) os.system(f"kill -s KILL {pgid}") time.sleep(0.2) - runner.terminate() + task_runner.terminate() - assert runner.return_code() == -9 + assert task_runner.return_code() == -9 assert "running out of memory" in caplog.text def test_on_kill(self): @@ -319,14 +324,15 @@ def test_on_kill(self): start_date=DEFAULT_DATE, ) ti = TaskInstance(task=task, run_id="test") - job1 = Job(job_runner=LocalTaskJobRunner(task_instance=ti, ignore_ti_state=True), dag_id=ti.dag_id) - runner = StandardTaskRunner(job1) - runner.start() + job = Job(dag_id=ti.dag_id) + job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True) + task_runner = StandardTaskRunner(job_runner) + task_runner.start() with timeout(seconds=3): while True: - runner_pgid = os.getpgid(runner.process.pid) - if runner_pgid == runner.process.pid: + runner_pgid = os.getpgid(task_runner.process.pid) + if runner_pgid == task_runner.process.pid: break time.sleep(0.01) @@ -341,7 +347,7 @@ def test_on_kill(self): logging.info("Task started. 
         time.sleep(3)
         logging.info("Terminating processes %s belonging to %s group", processes, runner_pgid)
-        runner.terminate()
+        task_runner.terminate()
 
         logging.info("Waiting for the on kill killed file to appear")
         with timeout(seconds=4):
@@ -377,26 +383,27 @@ def test_parsing_context(self):
             start_date=DEFAULT_DATE,
         )
         ti = TaskInstance(task=task, run_id="test")
-        job1 = Job(job_runner=LocalTaskJobRunner(task_instance=ti, ignore_ti_state=True), dag_id=ti.dag_id)
-        runner = StandardTaskRunner(job1)
-        runner.start()
+        job = Job(dag_id=ti.dag_id)
+        job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True)
+        task_runner = StandardTaskRunner(job_runner)
+        task_runner.start()
 
         # Wait until process sets its pgid to be equal to pid
         with timeout(seconds=1):
             while True:
-                runner_pgid = os.getpgid(runner.process.pid)
-                if runner_pgid == runner.process.pid:
+                runner_pgid = os.getpgid(task_runner.process.pid)
+                if runner_pgid == task_runner.process.pid:
                     break
                 time.sleep(0.01)
 
         assert runner_pgid > 0
         assert runner_pgid != os.getpgid(0), "Task should be in a different process group to us"
         processes = list(self._procs_in_pgroup(runner_pgid))
-        psutil.wait_procs([runner.process])
+        psutil.wait_procs([task_runner.process])
 
         for process in processes:
             assert not psutil.pid_exists(process.pid), f"{process} is still alive"
-        assert runner.return_code() == 0
+        assert task_runner.return_code() == 0
 
         text = context_file.read_text()
         assert (
             text == "_AIRFLOW_PARSING_CONTEXT_DAG_ID=test_parsing_context\n"
diff --git a/tests/task/task_runner/test_task_runner.py b/tests/task/task_runner/test_task_runner.py
index e95d47c473022..1c9928f6b3a51 100644
--- a/tests/task/task_runner/test_task_runner.py
+++ b/tests/task/task_runner/test_task_runner.py
@@ -38,10 +38,10 @@ def test_should_support_core_task_runner(self, mock_subprocess):
         ti = mock.MagicMock(map_index=-1, run_as_user=None)
         ti.get_template_context.return_value = {"ti": ti}
         ti.get_dagrun.return_value.get_log_template.return_value.filename = "blah"
-        base_job = mock.MagicMock(task_instance=ti)
-        base_job.job_runner = LocalTaskJobRunner(ti)
-        base_job.job_runner.job = base_job
-        task_runner = get_task_runner(base_job.job_runner)
+        Job = mock.MagicMock(task_instance=ti)
+        Job.job_type = None
+        job_runner = LocalTaskJobRunner(job=Job, task_instance=ti)
+        task_runner = get_task_runner(job_runner)
 
         assert "StandardTaskRunner" == task_runner.__class__.__name__
 
@@ -50,13 +50,11 @@ def test_should_support_core_task_runner(self, mock_subprocess):
         "tests.task.task_runner.test_task_runner.custom_task_runner",
     )
     def test_should_support_custom_legacy_task_runner(self):
-        base_job = mock.MagicMock(
-            **{"task_instance.get_template_context.return_value": {"ti": mock.MagicMock()}}
-        )
+        mock.MagicMock(**{"task_instance.get_template_context.return_value": {"ti": mock.MagicMock()}})
         custom_task_runner.reset_mock()
-        task_runner = get_task_runner(base_job)
+        task_runner = get_task_runner(custom_task_runner)
 
-        custom_task_runner.assert_called_once_with(base_job.job)
+        custom_task_runner.assert_called_once_with(custom_task_runner)
         assert custom_task_runner.return_value == task_runner
diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py
index 0adbd43b210e4..2e2fb2db4abd5 100644
--- a/tests/utils/test_helpers.py
+++ b/tests/utils/test_helpers.py
@@ -24,6 +24,8 @@
 from airflow import AirflowException
 from airflow.jobs.base_job_runner import BaseJobRunner
+from airflow.jobs.job import Job
+from airflow.serialization.pydantic.job import JobPydantic
 from airflow.utils import helpers, timezone
 from airflow.utils.helpers import (
     at_most_one,
@@ -329,9 +331,11 @@ def test_prune_dict(self, mode, expected):
 class MockJobRunner(BaseJobRunner):
     job_type = "MockJob"
 
-    def __init__(self, func=None):
-        self.func = func
+    def __init__(self, job: Job | JobPydantic, func=None):
         super().__init__()
+        self.job = job
+        self.job.job_type = self.job_type
+        self.func = func
 
     def _execute(self):
         if self.func is not None:
diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py
index dc66c4a4e204f..87e655e5e95cd 100644
--- a/tests/utils/test_log_handlers.py
+++ b/tests/utils/test_log_handlers.py
@@ -403,7 +403,7 @@ def test_set_context_trigger(self, create_dummy_dag, dag_maker, is_a_trigger, se
         assert isinstance(ti, TaskInstance)
         if is_a_trigger:
             ti.is_trigger_log_context = True
-            job = Job(job_runner=TriggererJobRunner())
+            job = Job()
             t = Trigger("", {})
             t.triggerer_job = job
             ti.triggerer = t
@@ -475,7 +475,7 @@ def test_log_retrieval_valid_trigger(self, create_task_instance):
         )
         ti.hostname = "hostname"
         trigger = Trigger("", {})
-        job = Job(job_runner=TriggererJobRunner())
+        job = Job(TriggererJobRunner.job_type)
         job.id = 123
         trigger.triggerer_job = job
         ti.trigger = trigger
diff --git a/tests/www/views/test_views_base.py b/tests/www/views/test_views_base.py
index 0b4ee78c7d310..2ba8795cf1bbe 100644
--- a/tests/www/views/test_views_base.py
+++ b/tests/www/views/test_views_base.py
@@ -63,11 +63,10 @@ def heartbeat_healthy():
     # case-1: healthy scheduler status
     last_heartbeat = timezone.utcnow()
     job = Job(
-        job_type="SchedulerJob",
         state="running",
         latest_heartbeat=last_heartbeat,
-        job_runner=SchedulerJobRunner(),
     )
+    SchedulerJobRunner(job=job)
     with create_session() as session:
         session.add(job)
         yield "healthy", last_heartbeat.isoformat()
@@ -84,11 +83,10 @@ def heartbeat_too_slow():
     # case-2: unhealthy scheduler status - scenario 1 (SchedulerJob is running too slowly)
     last_heartbeat = timezone.utcnow() - datetime.timedelta(minutes=1)
     job = Job(
-        job_type="SchedulerJob",
         state="running",
         latest_heartbeat=last_heartbeat,
-        job_runner=SchedulerJobRunner(),
    )
+    SchedulerJobRunner(job=job)
     with create_session() as session:
         session.query(Job).filter(
             Job.job_type == "SchedulerJob",