diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py index c0cf255cad89a..1c56aa42005a0 100644 --- a/airflow/sensors/base.py +++ b/airflow/sensors/base.py @@ -29,6 +29,7 @@ from sqlalchemy import select from airflow import settings +from airflow.api_internal.internal_api_call import internal_api_call from airflow.configuration import conf from airflow.exceptions import ( AirflowException, @@ -48,6 +49,8 @@ from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + from airflow.utils.context import Context # As documented in https://dev.mysql.com/doc/refman/5.7/en/datetime.html. @@ -80,6 +83,31 @@ def __bool__(self) -> bool: return self.is_done +@internal_api_call +@provide_session +def _orig_start_date( + dag_id: str, task_id: str, run_id: str, map_index: int, try_number: int, session: Session = NEW_SESSION +): + """ + Get the original start_date for a rescheduled task. + + :meta private: + """ + return session.scalar( + select(TaskReschedule) + .where( + TaskReschedule.dag_id == dag_id, + TaskReschedule.task_id == task_id, + TaskReschedule.run_id == run_id, + TaskReschedule.map_index == map_index, + TaskReschedule.try_number == try_number, + ) + .order_by(TaskReschedule.id.asc()) + .with_only_columns(TaskReschedule.start_date) + .limit(1) + ) + + class BaseSensorOperator(BaseOperator, SkipMixin): """ Sensor operators are derived from this class and inherit these attributes. @@ -211,8 +239,7 @@ def poke(self, context: Context) -> bool | PokeReturnValue: """Override when deriving this class.""" raise AirflowException("Override me.") - @provide_session - def execute(self, context: Context, session=NEW_SESSION) -> Any: + def execute(self, context: Context) -> Any: started_at: datetime.datetime | float if self.reschedule: @@ -222,19 +249,12 @@ def execute(self, context: Context, session=NEW_SESSION) -> Any: # If reschedule, use the start date of the first try (first try can be either the very # first execution of the task, or the first execution after the task was cleared.) first_try_number = max_tries - retries + 1 - - start_date = session.scalar( - select(TaskReschedule) - .where( - TaskReschedule.dag_id == ti.dag_id, - TaskReschedule.task_id == ti.task_id, - TaskReschedule.run_id == ti.run_id, - TaskReschedule.map_index == ti.map_index, - TaskReschedule.try_number == first_try_number, - ) - .order_by(TaskReschedule.id.asc()) - .with_only_columns(TaskReschedule.start_date) - .limit(1) + start_date = _orig_start_date( + dag_id=ti.dag_id, + task_id=ti.task_id, + run_id=ti.run_id, + map_index=ti.map_index, + try_number=first_try_number, ) if not start_date: start_date = timezone.utcnow()