Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 15 additions & 36 deletions airflow/sensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from sqlalchemy import select

from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.exceptions import (
AirflowException,
Expand All @@ -46,11 +45,9 @@
from airflow.models.taskreschedule import TaskReschedule
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
from airflow.utils import timezone
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.session import create_session

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.utils.context import Context

# As documented in https://dev.mysql.com/doc/refman/5.7/en/datetime.html.
Expand Down Expand Up @@ -83,31 +80,6 @@ def __bool__(self) -> bool:
return self.is_done


@internal_api_call
@provide_session
def _orig_start_date(
    dag_id: str, task_id: str, run_id: str, map_index: int, try_number: int, session: Session = NEW_SESSION
):
    """
    Fetch the start date stored on the earliest reschedule record of a task try.

    TaskReschedule rows are filtered on the given identifiers, ordered by ascending
    ``id``, and only the ``start_date`` column of the first row is selected.

    :param dag_id: value matched against ``TaskReschedule.dag_id``
    :param task_id: value matched against ``TaskReschedule.task_id``
    :param run_id: value matched against ``TaskReschedule.run_id``
    :param map_index: value matched against ``TaskReschedule.map_index``
    :param try_number: value matched against ``TaskReschedule.try_number``
    :param session: session the query is executed on; defaults to ``NEW_SESSION``
        (presumably swapped for a real session by ``provide_session`` -- that
        decorator is defined outside this view, confirm there)
    :return: the result of ``session.scalar`` for the query, i.e. the selected
        ``start_date`` of the first matching row, or ``None`` when no row matches

    :meta private:
    """
    # Every predicate has to hold at once; ``where`` ANDs its positional arguments.
    same_try = (
        TaskReschedule.dag_id == dag_id,
        TaskReschedule.task_id == task_id,
        TaskReschedule.run_id == run_id,
        TaskReschedule.map_index == map_index,
        TaskReschedule.try_number == try_number,
    )

    # Lowest id first, so the single row kept by ``limit(1)`` is the earliest record.
    earliest_first = select(TaskReschedule).where(*same_try).order_by(TaskReschedule.id.asc())

    # Narrow the select list to the one column needed before capping the result size.
    first_start_date = earliest_first.with_only_columns(TaskReschedule.start_date).limit(1)

    return session.scalar(first_start_date)


class BaseSensorOperator(BaseOperator, SkipMixin):
"""
Sensor operators are derived from this class and inherit these attributes.
Expand Down Expand Up @@ -249,13 +221,20 @@ def execute(self, context: Context) -> Any:
# If reschedule, use the start date of the first try (first try can be either the very
# first execution of the task, or the first execution after the task was cleared.)
first_try_number = max_tries - retries + 1
start_date = _orig_start_date(
dag_id=ti.dag_id,
task_id=ti.task_id,
run_id=ti.run_id,
map_index=ti.map_index,
try_number=first_try_number,
)
with create_session() as session:
start_date = session.scalar(
select(TaskReschedule)
.where(
TaskReschedule.dag_id == ti.dag_id,
TaskReschedule.task_id == ti.task_id,
TaskReschedule.run_id == ti.run_id,
TaskReschedule.map_index == ti.map_index,
TaskReschedule.try_number == first_try_number,
)
.order_by(TaskReschedule.id.asc())
.with_only_columns(TaskReschedule.start_date)
.limit(1)
)
if not start_date:
start_date = timezone.utcnow()
started_at = start_date
Expand Down