From 46bd150863004726f6a3a34fc2eb187d8af58786 Mon Sep 17 00:00:00 2001 From: pierrejeambrun Date: Mon, 13 Feb 2023 22:04:01 +0100 Subject: [PATCH 1/2] Migrate TaskInstance.check_and_change_state_before_execution to rpc --- .../endpoints/rpc_api_endpoint.py | 2 + airflow/jobs/local_task_job.py | 3 +- airflow/models/taskinstance.py | 73 ++++++++++--------- tests/models/test_taskinstance.py | 8 +- 4 files changed, 47 insertions(+), 39 deletions(-) diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index 5456d6b9a040f..97fcf042951b6 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -27,6 +27,7 @@ from airflow.api_connexion.types import APIResponse from airflow.models import Trigger, Variable, XCom from airflow.models.dagwarning import DagWarning +from airflow.models.taskinstance import TaskInstance from airflow.serialization.serialized_objects import BaseSerialization log = logging.getLogger(__name__) @@ -46,6 +47,7 @@ def _initialize_map() -> dict[str, Callable]: DagModel.get_paused_dag_ids, DagFileProcessorManager.clear_nonexistent_import_errors, DagWarning.purge_inactive_dag_warnings, + TaskInstance.check_and_change_state_before_execution, XCom.get_value, XCom.get_one, XCom.get_many, diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index ed5ec5ffea236..f4123595d491f 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -138,7 +138,8 @@ def sigusr2_debug_handler(signum, frame): # This is not supported on Windows systems signal.signal(signal.SIGUSR2, sigusr2_debug_handler) - if not self.task_instance.check_and_change_state_before_execution( + if not TaskInstance.check_and_change_state_before_execution( + self.task_instance, mark_success=self.mark_success, ignore_all_deps=self.ignore_all_deps, ignore_depends_on_past=self.ignore_depends_on_past, diff --git 
a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index c74ff8b6b0567..4cd85bedb3d87 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -66,6 +66,7 @@ from sqlalchemy.sql.expression import ColumnOperators, case from airflow import settings +from airflow.api_internal.internal_api_call import internal_api_call from airflow.compat.functools import cache from airflow.configuration import conf from airflow.datasets import Dataset @@ -1207,9 +1208,11 @@ def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun: return dr + @staticmethod + @internal_api_call @provide_session def check_and_change_state_before_execution( - self, + ti: TaskInstance, verbose: bool = True, ignore_all_deps: bool = False, ignore_depends_on_past: bool = False, @@ -1228,6 +1231,7 @@ def check_and_change_state_before_execution( True if and only if state is set to RUNNING, which implies that task should be executed, in preparation for _run_raw_task + :param ti: the task instance :param verbose: whether to turn on more verbose logging :param ignore_all_deps: Ignore all of the non-critical dependencies, just runs :param ignore_depends_on_past: Ignore depends_on_past DAG attribute @@ -1242,20 +1246,20 @@ def check_and_change_state_before_execution( :param session: SQLAlchemy ORM Session :return: whether the state was changed to running or not """ - task = self.task - self.refresh_from_task(task, pool_override=pool) - self.test_mode = test_mode - self.refresh_from_db(session=session, lock_for_update=True) - self.job_id = job_id - self.hostname = get_hostname() - self.pid = None + task = ti.task + ti.refresh_from_task(task, pool_override=pool) + ti.test_mode = test_mode + ti.refresh_from_db(session=session, lock_for_update=True) + ti.job_id = job_id + ti.hostname = get_hostname() + ti.pid = None - if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS: + if not ignore_all_deps and not ignore_ti_state and ti.state == 
State.SUCCESS: Stats.incr( "previously_succeeded", 1, 1, - tags={"dag_id": self.dag_id, "run_id": self.run_id, "task_id": self.task_id}, + tags={"dag_id": ti.dag_id, "run_id": ti.run_id, "task_id": ti.task_id}, ) if not mark_success: @@ -1270,7 +1274,7 @@ def check_and_change_state_before_execution( ignore_task_deps=ignore_task_deps, description="non-requeueable deps", ) - if not self.are_dependencies_met( + if not ti.are_dependencies_met( dep_context=non_requeueable_dep_context, session=session, verbose=True ): session.commit() @@ -1282,11 +1286,11 @@ def check_and_change_state_before_execution( # Set the task start date. In case it was re-scheduled use the initial # start date that is recorded in task_reschedule table # If the task continues after being deferred (next_method is set), use the original start_date - self.start_date = self.start_date if self.next_method else timezone.utcnow() - if self.state == State.UP_FOR_RESCHEDULE: - task_reschedule: TR = TR.query_for_task_instance(self, session=session).first() + ti.start_date = ti.start_date if ti.next_method else timezone.utcnow() + if ti.state == State.UP_FOR_RESCHEDULE: + task_reschedule: TR = TR.query_for_task_instance(ti, session=session).first() if task_reschedule: - self.start_date = task_reschedule.start_date + ti.start_date = task_reschedule.start_date # Secondly we find non-runnable but requeueable tis. We reset its state. # This is because we might have hit concurrency limits, @@ -1300,33 +1304,33 @@ def check_and_change_state_before_execution( ignore_ti_state=ignore_ti_state, description="requeueable deps", ) - if not self.are_dependencies_met(dep_context=dep_context, session=session, verbose=True): - self.state = State.NONE - self.log.warning( + if not ti.are_dependencies_met(dep_context=dep_context, session=session, verbose=True): + ti.state = State.NONE + ti.log.warning( "Rescheduling due to concurrency limits reached " "at task runtime. Attempt %s of " "%s. 
State set to NONE.", - self.try_number, - self.max_tries + 1, + ti.try_number, + ti.max_tries + 1, ) - self.queued_dttm = timezone.utcnow() - session.merge(self) + ti.queued_dttm = timezone.utcnow() + session.merge(ti) session.commit() return False - if self.next_kwargs is not None: - self.log.info("Resuming after deferral") + if ti.next_kwargs is not None: + ti.log.info("Resuming after deferral") else: - self.log.info("Starting attempt %s of %s", self.try_number, self.max_tries + 1) - self._try_number += 1 + ti.log.info("Starting attempt %s of %s", ti.try_number, ti.max_tries + 1) + ti._try_number += 1 if not test_mode: - session.add(Log(State.RUNNING, self)) - self.state = State.RUNNING - self.external_executor_id = external_executor_id - self.end_date = None + session.add(Log(State.RUNNING, ti)) + ti.state = State.RUNNING + ti.external_executor_id = external_executor_id + ti.end_date = None if not test_mode: - session.merge(self).task = task + session.merge(ti).task = task session.commit() # Closing all pooled connections to prevent @@ -1334,9 +1338,9 @@ def check_and_change_state_before_execution( settings.engine.dispose() # type: ignore if verbose: if mark_success: - self.log.info("Marking success for %s on %s", self.task, self.execution_date) + ti.log.info("Marking success for %s on %s", ti.task, ti.execution_date) else: - self.log.info("Executing %s on %s", self.task, self.execution_date) + ti.log.info("Executing %s on %s", ti.task, ti.execution_date) return True def _date_or_empty(self, attr: str) -> str: @@ -1715,7 +1719,8 @@ def run( session: Session = NEW_SESSION, ) -> None: """Run TaskInstance""" - res = self.check_and_change_state_before_execution( + res = TaskInstance.check_and_change_state_before_execution( + self, verbose=verbose, ignore_all_deps=ignore_all_deps, ignore_depends_on_past=ignore_depends_on_past, diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index e540202e1fb07..7667f5fb38951 100644 --- 
a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1495,7 +1495,7 @@ def test_check_and_change_state_before_execution(self, create_task_instance): ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) assert ti_from_deserialized_task._try_number == 0 - assert ti_from_deserialized_task.check_and_change_state_before_execution() + assert TI.check_and_change_state_before_execution(ti_from_deserialized_task) # State should be running, and try_number column should be incremented assert ti_from_deserialized_task.state == State.RUNNING assert ti_from_deserialized_task._try_number == 1 @@ -1508,7 +1508,7 @@ def test_check_and_change_state_before_execution_dep_not_met(self, create_task_i serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag ti2 = TI(task=serialized_dag.get_task(task2.task_id), run_id=ti.run_id) - assert not ti2.check_and_change_state_before_execution() + assert not TI.check_and_change_state_before_execution(ti2) def test_check_and_change_state_before_execution_dep_not_met_already_running(self, create_task_instance): """return False if the task instance state is running""" @@ -1521,7 +1521,7 @@ def test_check_and_change_state_before_execution_dep_not_met_already_running(sel serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) - assert not ti_from_deserialized_task.check_and_change_state_before_execution() + assert not TI.check_and_change_state_before_execution(ti_from_deserialized_task) assert ti_from_deserialized_task.state == State.RUNNING def test_check_and_change_state_before_execution_dep_not_met_not_runnable_state( @@ -1537,7 +1537,7 @@ def test_check_and_change_state_before_execution_dep_not_met_not_runnable_state( serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) - assert 
not ti_from_deserialized_task.check_and_change_state_before_execution() + assert not TI.check_and_change_state_before_execution(ti_from_deserialized_task) assert ti_from_deserialized_task.state == State.FAILED def test_try_number(self, create_task_instance): From 298b5a0f4350f386f1dc0afc3291b0cef62101b0 Mon Sep 17 00:00:00 2001 From: pierrejeambrun Date: Fri, 17 Feb 2023 02:47:34 +0100 Subject: [PATCH 2/2] Update using PK + code review --- .../endpoints/rpc_api_endpoint.py | 1 + airflow/jobs/local_task_job.py | 11 +- airflow/models/taskinstance.py | 108 ++++++++++++++---- tests/jobs/test_local_task_job.py | 5 - tests/models/test_taskinstance.py | 53 ++++++--- 5 files changed, 131 insertions(+), 47 deletions(-) diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index 97fcf042951b6..dffc88333feb4 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -48,6 +48,7 @@ def _initialize_map() -> dict[str, Callable]: DagFileProcessorManager.clear_nonexistent_import_errors, DagWarning.purge_inactive_dag_warnings, TaskInstance.check_and_change_state_before_execution, + TaskInstance.retrieve_from_db, XCom.get_value, XCom.get_one, XCom.get_many, diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index f4123595d491f..cc7457cf66609 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -138,8 +138,12 @@ def sigusr2_debug_handler(signum, frame): # This is not supported on Windows systems signal.signal(signal.SIGUSR2, sigusr2_debug_handler) - if not TaskInstance.check_and_change_state_before_execution( - self.task_instance, + self.task_instance = TaskInstance.check_and_change_state_before_execution( + self.task_instance.dag_id, + self.task_instance.run_id, + self.task_instance.task_id, + self.task_instance.map_index, + self.task_instance.task, mark_success=self.mark_success, 
ignore_all_deps=self.ignore_all_deps, ignore_depends_on_past=self.ignore_depends_on_past, @@ -149,7 +153,8 @@ def sigusr2_debug_handler(signum, frame): job_id=self.id, pool=self.pool, external_executor_id=self.external_executor_id, - ): + ) + if not self.task_instance.state == State.RUNNING: self.log.info("Task is not able to be run") return None diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 4cd85bedb3d87..dd7959d6ab4f0 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -59,7 +59,7 @@ ) from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.mutable import MutableDict -from sqlalchemy.orm import reconstructor, relationship +from sqlalchemy.orm import make_transient, reconstructor, relationship from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value from sqlalchemy.orm.session import Session from sqlalchemy.sql.elements import BooleanClauseList @@ -788,6 +788,51 @@ def error(self, session: Session = NEW_SESSION) -> None: session.merge(self) session.commit() + @classmethod + @internal_api_call + @provide_session + def retrieve_from_db( + cls, + dag_id: str, + run_id: str, + task_id: str, + map_index: int, + session: Session = NEW_SESSION, + lock_for_update: bool = False, + ) -> TaskInstance | None: + """ + Retrieve the task instance from the database based on the primary key + + :param dag_id: The Dag ID + :param run_id: The Dag run ID + :param task_id: The Task ID + :param map_index: The map index + :param session: SQLAlchemy ORM Session + :param lock_for_update: if True, indicates that the database should + lock the TaskInstance (issuing a FOR UPDATE clause) until the + session is committed. + :return: The TaskInstance object retrieved from the database. 
+ """ + logger = cls.logger() + logger.debug( + f"Retrieving TaskInstance from DB with primary key: ({dag_id}, {task_id}, {run_id}, {map_index})" + ) + + qry = session.query(TaskInstance).filter( + TaskInstance.dag_id == dag_id, + TaskInstance.task_id == task_id, + TaskInstance.run_id == run_id, + TaskInstance.map_index == map_index, + ) + + if lock_for_update: + for attempt in run_with_db_retries(logger=logger): + with attempt: + ti: TaskInstance | None = qry.with_for_update().one_or_none() + else: + ti = qry.one_or_none() + return ti + @provide_session def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None: """ @@ -1121,7 +1166,6 @@ def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session dep_context = dep_context or DepContext() for dep in dep_context.deps | self.task.deps: for dep_status in dep.get_dep_statuses(self, session, dep_context): - self.log.debug( "%s dependency '%s' PASSED: %s, %s", self, @@ -1212,7 +1256,11 @@ def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun: @internal_api_call @provide_session def check_and_change_state_before_execution( - ti: TaskInstance, + dag_id: str, + run_id: str, + task_id: str, + map_index: int, + task: Operator, verbose: bool = True, ignore_all_deps: bool = False, ignore_depends_on_past: bool = False, @@ -1225,14 +1273,18 @@ def check_and_change_state_before_execution( pool: str | None = None, external_executor_id: str | None = None, session: Session = NEW_SESSION, - ) -> bool: - """ - Checks dependencies and then sets state to RUNNING if they are met. Returns - True if and only if state is set to RUNNING, which implies that task should be - executed, in preparation for _run_raw_task - - :param ti: the task instance - :param verbose: whether to turn on more verbose logging + ) -> TaskInstance: + """ + Retrieve the TI based on its primary keys. Checks dependencies and then sets state to RUNNING if they + are met. 
Returns an updated version of the retrieved TI. If state is set to RUNNING, it implies + that task should be executed, in preparation for _run_raw_task. + + :param dag_id: The Dag ID + :param run_id: The Dag run ID + :param task_id: The Task ID + :param map_index: The map index + :param task: The task object + :param verbose: Whether to turn on more verbose logging :param ignore_all_deps: Ignore all of the non-critical dependencies, just runs :param ignore_depends_on_past: Ignore depends_on_past DAG attribute :param wait_for_past_depends_before_skipping: Wait for past depends before mark the ti as skipped @@ -1241,15 +1293,22 @@ def check_and_change_state_before_execution( :param mark_success: Don't run the task, mark its state as success :param test_mode: Doesn't record success or failure in the DB :param job_id: Job (BackfillJob / LocalTaskJob / SchedulerJob) ID - :param pool: specifies the pool to use to run the task instance + :param pool: Specifies the pool to use to run the task instance :param external_executor_id: The identifier of the celery executor :param session: SQLAlchemy ORM Session :return: whether the state was changed to running or not """ - task = ti.task + ti = TaskInstance.retrieve_from_db( + dag_id, run_id, task_id, map_index, session=session, lock_for_update=True + ) + if ti is None: + ti = TaskInstance( + task=task, + run_id=run_id, + map_index=map_index, + ) ti.refresh_from_task(task, pool_override=pool) ti.test_mode = test_mode - ti.refresh_from_db(session=session, lock_for_update=True) ti.job_id = job_id ti.hostname = get_hostname() ti.pid = None @@ -1278,7 +1337,8 @@ def check_and_change_state_before_execution( dep_context=non_requeueable_dep_context, session=session, verbose=True ): session.commit() - return False + make_transient(ti) + return ti # For reporting purposes, we report based on 1-indexed, # not 0-indexed lists (i.e. 
Attempt 1 instead of @@ -1316,7 +1376,8 @@ def check_and_change_state_before_execution( ti.queued_dttm = timezone.utcnow() session.merge(ti) session.commit() - return False + make_transient(ti) + return ti if ti.next_kwargs is not None: ti.log.info("Resuming after deferral") @@ -1341,7 +1402,9 @@ def check_and_change_state_before_execution( ti.log.info("Marking success for %s on %s", ti.task, ti.execution_date) else: ti.log.info("Executing %s on %s", ti.task, ti.execution_date) - return True + + make_transient(ti) + return ti def _date_or_empty(self, attr: str) -> str: result: datetime | None = getattr(self, attr, None) @@ -1719,8 +1782,12 @@ def run( session: Session = NEW_SESSION, ) -> None: """Run TaskInstance""" - res = TaskInstance.check_and_change_state_before_execution( - self, + ti_before_execution = TaskInstance.check_and_change_state_before_execution( + self.dag_id, + self.run_id, + self.task_id, + self.map_index, + self.task, verbose=verbose, ignore_all_deps=ignore_all_deps, ignore_depends_on_past=ignore_depends_on_past, @@ -1733,7 +1800,8 @@ def run( pool=pool, session=session, ) - if not res: + self.state = ti_before_execution.state + if not self.state == State.RUNNING: return self._run_raw_task( diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index f008b522da076..d9594217033a7 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -279,7 +279,6 @@ def test_heartbeat_failed_fast(self): dag_id = "test_heartbeat_failed_fast" task_id = "test_heartbeat_failed_fast_op" with create_session() as session: - dag_id = "test_heartbeat_failed_fast" task_id = "test_heartbeat_failed_fast_op" dag = self.dagbag.get_dag(dag_id) @@ -341,7 +340,6 @@ def test_mark_success_no_kill(self, caplog, get_test_dag, session): ) def test_localtaskjob_double_trigger(self): - dag = self.dagbag.dags.get("test_localtaskjob_double_trigger") task = dag.get_task("test_localtaskjob_double_trigger_task") @@ -379,7 +377,6 
@@ def test_localtaskjob_double_trigger(self): @patch.object(StandardTaskRunner, "return_code") @mock.patch("airflow.jobs.scheduler_job.Stats.incr", autospec=True) def test_local_task_return_code_metric(self, mock_stats_incr, mock_return_code, create_dummy_dag): - _, task = create_dummy_dag("test_localtaskjob_code") ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE) @@ -400,7 +397,6 @@ def test_local_task_return_code_metric(self, mock_stats_incr, mock_return_code, @patch.object(StandardTaskRunner, "return_code") def test_localtaskjob_maintain_heart_rate(self, mock_return_code, caplog, create_dummy_dag): - _, task = create_dummy_dag("test_localtaskjob_double_trigger") ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE) @@ -685,7 +681,6 @@ def test_fast_follow( get_test_dag, ): with conf_vars(conf): - dag = get_test_dag( "test_dagrun_fast_follow", ) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 7667f5fb38951..2b60a8f7997e4 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -333,7 +333,7 @@ def test_not_requeue_non_requeueable_task_instance(self, dag_maker): assert ti.state == State.QUEUED dep_patch.return_value = TIDepStatus("mock_" + class_name, True, "mock") - for (dep_patch, method_patch) in patch_dict.values(): + for dep_patch, method_patch in patch_dict.values(): dep_patch.stop() def test_mark_non_runnable_task_as_success(self, create_task_instance): @@ -810,7 +810,6 @@ def func(): return done with dag_maker(dag_id="test_reschedule_handling") as dag: - task = PythonSensor.partial( task_id="test_reschedule_handling_sensor", mode="reschedule", @@ -1495,10 +1494,16 @@ def test_check_and_change_state_before_execution(self, create_task_instance): ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) assert ti_from_deserialized_task._try_number == 0 - assert TI.check_and_change_state_before_execution(ti_from_deserialized_task) + 
ti_before_execution = TI.check_and_change_state_before_execution( + ti_from_deserialized_task.dag_id, + ti_from_deserialized_task.run_id, + ti_from_deserialized_task.task_id, + ti_from_deserialized_task.map_index, + ti_from_deserialized_task.task, + ) # State should be running, and try_number column should be incremented - assert ti_from_deserialized_task.state == State.RUNNING - assert ti_from_deserialized_task._try_number == 1 + assert ti_before_execution.state == State.RUNNING + assert ti_before_execution._try_number == 1 def test_check_and_change_state_before_execution_dep_not_met(self, create_task_instance): ti = create_task_instance(dag_id="test_check_and_change_state_before_execution") @@ -1508,10 +1513,16 @@ def test_check_and_change_state_before_execution_dep_not_met(self, create_task_i serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag ti2 = TI(task=serialized_dag.get_task(task2.task_id), run_id=ti.run_id) - assert not TI.check_and_change_state_before_execution(ti2) + ti_before_execution = TI.check_and_change_state_before_execution( + ti2.dag_id, + ti2.run_id, + ti2.task_id, + ti2.map_index, + ti2.task, + ) + assert ti_before_execution.state != State.RUNNING def test_check_and_change_state_before_execution_dep_not_met_already_running(self, create_task_instance): - """return False if the task instance state is running""" ti = create_task_instance(dag_id="test_check_and_change_state_before_execution") with create_session() as _: ti.state = State.RUNNING @@ -1521,8 +1532,14 @@ def test_check_and_change_state_before_execution_dep_not_met_already_running(sel serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) - assert not TI.check_and_change_state_before_execution(ti_from_deserialized_task) - assert ti_from_deserialized_task.state == State.RUNNING + ti_before_execution = TI.check_and_change_state_before_execution( + 
ti_from_deserialized_task.dag_id, + ti_from_deserialized_task.run_id, + ti_from_deserialized_task.task_id, + ti_from_deserialized_task.map_index, + ti_from_deserialized_task.task, + ) + assert ti_before_execution.state == State.RUNNING def test_check_and_change_state_before_execution_dep_not_met_not_runnable_state( self, create_task_instance @@ -1537,8 +1554,14 @@ def test_check_and_change_state_before_execution_dep_not_met_not_runnable_state( serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) - assert not TI.check_and_change_state_before_execution(ti_from_deserialized_task) - assert ti_from_deserialized_task.state == State.FAILED + ti_before_execution = TI.check_and_change_state_before_execution( + ti_from_deserialized_task.dag_id, + ti_from_deserialized_task.run_id, + ti_from_deserialized_task.task_id, + ti_from_deserialized_task.map_index, + ti_from_deserialized_task.task, + ) + assert ti_before_execution.state == State.FAILED def test_try_number(self, create_task_instance): """ @@ -2003,7 +2026,6 @@ def get_test_ti(execution_date: pendulum.DateTime, state: str) -> TI: @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list) def test_previous_ti(self, schedule_interval, catchup, dag_maker) -> None: - scenario = [State.SUCCESS, State.FAILED, State.SUCCESS] ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario, dag_maker) @@ -2016,7 +2038,6 @@ def test_previous_ti(self, schedule_interval, catchup, dag_maker) -> None: @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list) def test_previous_ti_success(self, schedule_interval, catchup, dag_maker) -> None: - scenario = [State.FAILED, State.SUCCESS, State.FAILED, State.SUCCESS] ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario, dag_maker) @@ -2030,7 +2051,6 @@ def test_previous_ti_success(self, schedule_interval, 
catchup, dag_maker) -> Non @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list) def test_previous_execution_date_success(self, schedule_interval, catchup, dag_maker) -> None: - scenario = [State.FAILED, State.SUCCESS, State.FAILED, State.SUCCESS] ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario, dag_maker) @@ -2045,7 +2065,6 @@ def test_previous_execution_date_success(self, schedule_interval, catchup, dag_m @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list) def test_previous_start_date_success(self, schedule_interval, catchup, dag_maker) -> None: - scenario = [State.FAILED, State.SUCCESS, State.FAILED, State.SUCCESS] ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario, dag_maker) @@ -2332,7 +2351,6 @@ def on_execute_callable(context): assert context["dag_run"].dag_id == "test_dagrun_execute_callback" for i, callback_input in enumerate([[on_execute_callable], on_execute_callable]): - ti = create_task_instance( dag_id=f"test_execute_callback_{i}", on_execute_callback=callback_input, @@ -2369,7 +2387,6 @@ def on_finish_callable(context): completed = True for i, callback_input in enumerate([[on_finish_callable], on_finish_callable]): - ti = create_task_instance( dag_id=f"test_finish_callback_{i}", end_date=timezone.utcnow() + datetime.timedelta(days=10), @@ -2668,7 +2685,6 @@ def test_generate_command_specific_param(self): @provide_session def test_get_rendered_template_fields(self, dag_maker, session=None): - with dag_maker("test-dag", session=session) as dag: task = BashOperator(task_id="op1", bash_command="{{ task.task_id }}") dag.fileloc = TEST_DAGS_FOLDER / "test_get_k8s_pod_yaml.py" @@ -2842,7 +2858,6 @@ def test_refresh_from_db(self, create_task_instance): ), f"Key: {key} had different values. 
Make sure it loads it in the refresh refresh_from_db()" def test_operator_field_with_serialization(self, create_task_instance): - ti = create_task_instance() assert ti.task.task_type == "EmptyOperator" assert ti.task.operator_name == "EmptyOperator"