diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py index 1c56aa42005a0..c0cf255cad89a 100644 --- a/airflow/sensors/base.py +++ b/airflow/sensors/base.py @@ -29,7 +29,6 @@ from sqlalchemy import select from airflow import settings -from airflow.api_internal.internal_api_call import internal_api_call from airflow.configuration import conf from airflow.exceptions import ( AirflowException, @@ -49,8 +48,6 @@ from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: - from sqlalchemy.orm.session import Session - from airflow.utils.context import Context # As documented in https://dev.mysql.com/doc/refman/5.7/en/datetime.html. @@ -83,31 +80,6 @@ def __bool__(self) -> bool: return self.is_done -@internal_api_call -@provide_session -def _orig_start_date( - dag_id: str, task_id: str, run_id: str, map_index: int, try_number: int, session: Session = NEW_SESSION -): - """ - Get the original start_date for a rescheduled task. - - :meta private: - """ - return session.scalar( - select(TaskReschedule) - .where( - TaskReschedule.dag_id == dag_id, - TaskReschedule.task_id == task_id, - TaskReschedule.run_id == run_id, - TaskReschedule.map_index == map_index, - TaskReschedule.try_number == try_number, - ) - .order_by(TaskReschedule.id.asc()) - .with_only_columns(TaskReschedule.start_date) - .limit(1) - ) - - class BaseSensorOperator(BaseOperator, SkipMixin): """ Sensor operators are derived from this class and inherit these attributes. @@ -239,7 +211,8 @@ def poke(self, context: Context) -> bool | PokeReturnValue: """Override when deriving this class.""" raise AirflowException("Override me.") - def execute(self, context: Context) -> Any: + @provide_session + def execute(self, context: Context, session=NEW_SESSION) -> Any: started_at: datetime.datetime | float if self.reschedule: @@ -249,12 +222,19 @@ def execute(self, context: Context) -> Any: # If reschedule, use the start date of the first try (first try can be either the very # first execution of the task, or the first execution after the task was cleared.) first_try_number = max_tries - retries + 1 - start_date = _orig_start_date( - dag_id=ti.dag_id, - task_id=ti.task_id, - run_id=ti.run_id, - map_index=ti.map_index, - try_number=first_try_number, + + start_date = session.scalar( + select(TaskReschedule) + .where( + TaskReschedule.dag_id == ti.dag_id, + TaskReschedule.task_id == ti.task_id, + TaskReschedule.run_id == ti.run_id, + TaskReschedule.map_index == ti.map_index, + TaskReschedule.try_number == first_try_number, + ) + .order_by(TaskReschedule.id.asc()) + .with_only_columns(TaskReschedule.start_date) + .limit(1) ) if not start_date: start_date = timezone.utcnow()