Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 35 additions & 15 deletions airflow/sensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from sqlalchemy import select

from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.exceptions import (
AirflowException,
Expand All @@ -48,6 +49,8 @@
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.utils.context import Context

# As documented in https://dev.mysql.com/doc/refman/5.7/en/datetime.html.
Expand Down Expand Up @@ -80,6 +83,31 @@ def __bool__(self) -> bool:
return self.is_done


@internal_api_call
@provide_session
def _orig_start_date(
dag_id: str, task_id: str, run_id: str, map_index: int, try_number: int, session: Session = NEW_SESSION
):
"""
Get the original start_date for a rescheduled task.

:meta private:
"""
return session.scalar(
select(TaskReschedule)
.where(
TaskReschedule.dag_id == dag_id,
TaskReschedule.task_id == task_id,
TaskReschedule.run_id == run_id,
TaskReschedule.map_index == map_index,
TaskReschedule.try_number == try_number,
)
.order_by(TaskReschedule.id.asc())
.with_only_columns(TaskReschedule.start_date)
.limit(1)
)


class BaseSensorOperator(BaseOperator, SkipMixin):
"""
Sensor operators are derived from this class and inherit these attributes.
Expand Down Expand Up @@ -211,8 +239,7 @@ def poke(self, context: Context) -> bool | PokeReturnValue:
"""Override when deriving this class."""
raise AirflowException("Override me.")

@provide_session
def execute(self, context: Context, session=NEW_SESSION) -> Any:
def execute(self, context: Context) -> Any:
started_at: datetime.datetime | float

if self.reschedule:
Expand All @@ -222,19 +249,12 @@ def execute(self, context: Context, session=NEW_SESSION) -> Any:
# If reschedule, use the start date of the first try (first try can be either the very
# first execution of the task, or the first execution after the task was cleared.)
first_try_number = max_tries - retries + 1

start_date = session.scalar(
select(TaskReschedule)
.where(
TaskReschedule.dag_id == ti.dag_id,
TaskReschedule.task_id == ti.task_id,
TaskReschedule.run_id == ti.run_id,
TaskReschedule.map_index == ti.map_index,
TaskReschedule.try_number == first_try_number,
)
.order_by(TaskReschedule.id.asc())
.with_only_columns(TaskReschedule.start_date)
.limit(1)
start_date = _orig_start_date(
dag_id=ti.dag_id,
task_id=ti.task_id,
run_id=ti.run_id,
map_index=ti.map_index,
try_number=first_try_number,
)
if not start_date:
start_date = timezone.utcnow()
Expand Down