Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 15 additions & 35 deletions airflow/sensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from sqlalchemy import select

from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.exceptions import (
AirflowException,
Expand All @@ -49,8 +48,6 @@
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.utils.context import Context

# As documented in https://dev.mysql.com/doc/refman/5.7/en/datetime.html.
Expand Down Expand Up @@ -83,31 +80,6 @@ def __bool__(self) -> bool:
return self.is_done


@internal_api_call
@provide_session
def _orig_start_date(
dag_id: str, task_id: str, run_id: str, map_index: int, try_number: int, session: Session = NEW_SESSION
):
"""
Get the original start_date for a rescheduled task.

:meta private:
"""
return session.scalar(
select(TaskReschedule)
.where(
TaskReschedule.dag_id == dag_id,
TaskReschedule.task_id == task_id,
TaskReschedule.run_id == run_id,
TaskReschedule.map_index == map_index,
TaskReschedule.try_number == try_number,
)
.order_by(TaskReschedule.id.asc())
.with_only_columns(TaskReschedule.start_date)
.limit(1)
)


class BaseSensorOperator(BaseOperator, SkipMixin):
"""
Sensor operators are derived from this class and inherit these attributes.
Expand Down Expand Up @@ -239,7 +211,8 @@ def poke(self, context: Context) -> bool | PokeReturnValue:
"""Override when deriving this class."""
raise AirflowException("Override me.")

def execute(self, context: Context) -> Any:
@provide_session
def execute(self, context: Context, session=NEW_SESSION) -> Any:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adding additional arg session is causing lot of mypy failures in providers where BaseSensorOperator is being extended. Error's something like this:

providers/src/airflow/providers/microsoft/azure/sensors/data_factory.py:98: error:
Signature of "execute" incompatible with supertype "BaseSensorOperator" 
[override]
        def execute(self, context: Context) -> None:
        ^
providers/src/airflow/providers/microsoft/azure/sensors/data_factory.py:98: note:      Superclass:
providers/src/airflow/providers/microsoft/azure/sensors/data_factory.py:98: note:          def execute(context: Context, session: Any = ...) -> Any
providers/src/airflow/providers/microsoft/azure/sensors/data_factory.py:98: note:      Subclass:
providers/src/airflow/providers/microsoft/azure/sensors/data_factory.py:98: note:          def execute(self, context: Context) -> None

Should we just remove the decorator instead of updating execute method signature here?

Failing checks: https://github.com/apache/airflow/actions/runs/12101035549/job/33740515392?pr=44510

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reverted this in #44510

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uuups, oh was at-sleep while the error was on main. If we revert I assume we should not make this is a batch with another PR - should we not "just" revert the PR from me that caused it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just fixed it and it passed all tests. That should be good? I removed the decorator as well

started_at: datetime.datetime | float

if self.reschedule:
Expand All @@ -249,12 +222,19 @@ def execute(self, context: Context) -> Any:
# If reschedule, use the start date of the first try (first try can be either the very
# first execution of the task, or the first execution after the task was cleared.)
first_try_number = max_tries - retries + 1
start_date = _orig_start_date(
dag_id=ti.dag_id,
task_id=ti.task_id,
run_id=ti.run_id,
map_index=ti.map_index,
try_number=first_try_number,

start_date = session.scalar(
select(TaskReschedule)
.where(
TaskReschedule.dag_id == ti.dag_id,
TaskReschedule.task_id == ti.task_id,
TaskReschedule.run_id == ti.run_id,
TaskReschedule.map_index == ti.map_index,
TaskReschedule.try_number == first_try_number,
)
.order_by(TaskReschedule.id.asc())
.with_only_columns(TaskReschedule.start_date)
.limit(1)
)
if not start_date:
start_date = timezone.utcnow()
Expand Down