Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ repos:
name: Sync template context variable refs
language: python
entry: ./scripts/ci/pre_commit/template_context_key_sync.py
files: ^airflow/models/taskinstance\.py$|^airflow/utils/context\.pyi?$|^docs/apache-airflow/templates-ref\.rst$
files: ^airflow/models/taskinstance\.py$|^task_sdk/src/airflow/sdk/definitions/context\.py$|^docs/apache-airflow/templates-ref\.rst$
- id: check-base-operator-usage
language: pygrep
name: Check BaseOperator core imports
Expand Down
6 changes: 5 additions & 1 deletion airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ class TaskInstance(BaseModel):
dag_id: str
run_id: str
try_number: int
map_index: int | None = None
map_index: int = -1
hostname: str | None = None


class DagRun(BaseModel):
Expand All @@ -190,6 +191,9 @@ class TIRunContext(BaseModel):
dag_run: DagRun
"""DAG run information for the task instance."""

max_tries: int
"""Maximum number of tries for the task instance (from DB)."""

variables: Annotated[list[VariableResponse], Field(default_factory=list)]
"""Variables that can be accessed by the task instance."""

Expand Down
7 changes: 5 additions & 2 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,14 @@ def ti_run(
ti_id_str = str(task_instance_id)

old = (
select(TI.state, TI.dag_id, TI.run_id, TI.task_id, TI.map_index, TI.next_method)
select(TI.state, TI.dag_id, TI.run_id, TI.task_id, TI.map_index, TI.next_method, TI.max_tries)
.where(TI.id == ti_id_str)
.with_for_update()
)
try:
(previous_state, dag_id, run_id, task_id, map_index, next_method) = session.execute(old).one()
(previous_state, dag_id, run_id, task_id, map_index, next_method, max_tries) = session.execute(
old
).one()
except NoResultFound:
log.error("Task Instance %s not found", ti_id_str)
raise HTTPException(
Expand Down Expand Up @@ -165,6 +167,7 @@ def ti_run(

return TIRunContext(
dag_run=DagRun.model_validate(dr, from_attributes=True),
max_tries=max_tries,
# TODO: Add variables and connections that are needed (and has perms) for the task
variables=[],
connections=[],
Expand Down
2 changes: 1 addition & 1 deletion airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger

from airflow.sdk.definitions.context import Context
from airflow.typing_compat import Self
from airflow.utils.context import Context


def _parse_file_entrypoint():
Expand Down
2 changes: 1 addition & 1 deletion airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@
OperatorExpandKwargsArgument,
)
from airflow.models.mappedoperator import ValidationSource
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.dag import DAG
from airflow.utils.context import Context
from airflow.utils.task_group import TaskGroup


Expand Down
7 changes: 5 additions & 2 deletions airflow/decorators/bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@

import warnings
from collections.abc import Collection, Mapping, Sequence
from typing import Any, Callable, ClassVar
from typing import TYPE_CHECKING, Any, Callable, ClassVar

from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
from airflow.providers.standard.operators.bash import BashOperator
from airflow.utils.context import Context, context_merge
from airflow.utils.context import context_merge
from airflow.utils.operator_helpers import determine_kwargs
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from airflow.sdk.definitions.context import Context


class _BashDecoratedOperator(DecoratedOperator, BashOperator):
"""
Expand Down
2 changes: 1 addition & 1 deletion airflow/decorators/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from typing_extensions import TypeAlias

from airflow.models.baseoperator import TaskPreExecuteHook
from airflow.utils.context import Context
from airflow.sdk.definitions.context import Context

BoolConditionFunc: TypeAlias = Callable[[Context], bool]
MsgConditionFunc: TypeAlias = "Callable[[Context], tuple[bool, str | None]]"
Expand Down
2 changes: 1 addition & 1 deletion airflow/example_dags/example_dag_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from airflow.operators.email import EmailOperator

if TYPE_CHECKING:
from airflow.utils.context import Context
from airflow.sdk.definitions.context import Context


class GetRequestOperator(BaseOperator):
Expand Down
2 changes: 1 addition & 1 deletion airflow/example_dags/example_skip_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from airflow.utils.trigger_rule import TriggerRule

if TYPE_CHECKING:
from airflow.utils.context import Context
from airflow.sdk.definitions.context import Context


# Create some placeholder operators
Expand Down
4 changes: 2 additions & 2 deletions airflow/executors/workloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class TaskInstance(BaseModel):
dag_id: str
run_id: str
try_number: int
map_index: int | None = None
map_index: int = -1

pool_slots: int
queue: str
Expand All @@ -64,7 +64,7 @@ def key(self) -> TaskInstanceKey:
task_id=self.task_id,
run_id=self.run_id,
try_number=self.try_number,
map_index=-1 if self.map_index is None else self.map_index,
map_index=self.map_index,
)


Expand Down
2 changes: 1 addition & 1 deletion airflow/lineage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from airflow.utils.session import create_session

if TYPE_CHECKING:
from airflow.utils.context import Context
from airflow.sdk.definitions.context import Context

PIPELINE_OUTLETS = "pipeline_outlets"
PIPELINE_INLETS = "pipeline_inlets"
Expand Down
6 changes: 3 additions & 3 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import datetime
import inspect
from collections.abc import Iterable, Iterator, Mapping, Sequence
from collections.abc import Iterable, Iterator, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable

Expand All @@ -30,7 +30,7 @@
from airflow.exceptions import AirflowException
from airflow.models.expandinput import NotFullyPopulated
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator as TaskSDKAbstractOperator
from airflow.utils.context import Context
from airflow.sdk.definitions.context import Context
from airflow.utils.db import exists_query
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.setup_teardown import SetupTeardownContext
Expand Down Expand Up @@ -512,7 +512,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence

def render_template_fields(
self,
context: Mapping[str, Any],
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> None:
"""
Expand Down
3 changes: 2 additions & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@

# Keeping this file at all is a temp thing as we migrate the repo to the task sdk as the base, but to keep
# main working and useful for others to develop against we use the TaskSDK here but keep this file around
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.dag import DAG, BaseOperator as TaskSDKBaseOperator
from airflow.sdk.definitions.edges import EdgeModifier as TaskSDKEdgeModifier
from airflow.serialization.enums import DagAttributeTypes
Expand All @@ -89,7 +90,7 @@
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
from airflow.utils.context import Context, context_get_outlet_events
from airflow.utils.context import context_get_outlet_events
from airflow.utils.edgemodifier import EdgeModifier
from airflow.utils.operator_helpers import ExecutionCallableRunner
from airflow.utils.operator_resources import Resources
Expand Down
4 changes: 2 additions & 2 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@
from airflow.models.operator import Operator
from airflow.models.param import ParamsDict
from airflow.models.xcom_arg import XComArg
from airflow.sdk.definitions.context import Context
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.context import Context
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule
Expand Down Expand Up @@ -869,7 +869,7 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:

def render_template_fields(
self,
context: Mapping[str, Any],
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> None:
"""
Expand Down
11 changes: 6 additions & 5 deletions airflow/models/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from airflow.utils.types import NOTSET, ArgNotSet

if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions.dag import DAG

Expand Down Expand Up @@ -332,19 +331,21 @@ def deserialize(cls, data: dict, dags: dict) -> DagParam:
def process_params(
dag: DAG,
task: Operator,
dag_run: DagRun | None,
dagrun_conf: dict[str, Any] | None,
*,
suppress_exception: bool,
) -> dict[str, Any]:
"""Merge, validate params, and convert them into a simple dict."""
from airflow.configuration import conf

dagrun_conf = dagrun_conf or {}

params = ParamsDict(suppress_exception=suppress_exception)
with contextlib.suppress(AttributeError):
params.update(dag.params)
if task.params:
params.update(task.params)
if conf.getboolean("core", "dag_run_conf_overrides_params") and dag_run and dag_run.conf:
logger.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
params.update(dag_run.conf)
if conf.getboolean("core", "dag_run_conf_overrides_params") and dagrun_conf:
logger.debug("Updating task params (%s) with DagRun.conf (%s)", params, dagrun_conf)
params.update(dagrun_conf)
return params.validate()
28 changes: 14 additions & 14 deletions airflow/models/skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ class SkipMixin(LoggingMixin):

@staticmethod
def _set_state_to_skipped(
dag_run: DagRun,
dag_id: str,
run_id: str,
tasks: Sequence[str] | Sequence[tuple[str, int]],
session: Session,
) -> None:
Expand All @@ -71,8 +72,8 @@ def _set_state_to_skipped(
session.execute(
update(TaskInstance)
.where(
TaskInstance.dag_id == dag_run.dag_id,
TaskInstance.run_id == dag_run.run_id,
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == run_id,
tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(tasks),
)
.values(state=TaskInstanceState.SKIPPED, start_date=now, end_date=now)
Expand All @@ -82,8 +83,8 @@ def _set_state_to_skipped(
session.execute(
update(TaskInstance)
.where(
TaskInstance.dag_id == dag_run.dag_id,
TaskInstance.run_id == dag_run.run_id,
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == run_id,
TaskInstance.task_id.in_(tasks),
)
.values(state=TaskInstanceState.SKIPPED, start_date=now, end_date=now)
Expand All @@ -93,7 +94,8 @@ def _set_state_to_skipped(
@provide_session
def skip(
self,
dag_run: DagRun,
dag_id: str,
run_id: str,
tasks: Iterable[DAGNode],
map_index: int = -1,
session: Session = NEW_SESSION,
Expand All @@ -105,7 +107,8 @@ def skip(
so that NotPreviouslySkippedDep knows these tasks should be skipped when they
are cleared.

:param dag_run: the DagRun for which to set the tasks to skipped
:param dag_id: the dag_id of the dag run for which to set the tasks to skipped
:param run_id: the run_id of the dag run for which to set the tasks to skipped
:param tasks: tasks to skip (not task_ids)
:param session: db session to use
:param map_index: map_index of the current task instance
Expand All @@ -116,11 +119,8 @@ def skip(
if not task_list:
return

if dag_run is None:
raise ValueError("dag_run is required")

task_ids_list = [d.task_id for d in task_list]
SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
SkipMixin._set_state_to_skipped(dag_id, run_id, task_ids_list, session)
session.commit()

if task_id is not None:
Expand All @@ -130,8 +130,8 @@ def skip(
key=XCOM_SKIPMIXIN_KEY,
value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list},
task_id=task_id,
dag_id=dag_run.dag_id,
run_id=dag_run.run_id,
dag_id=dag_id,
run_id=run_id,
map_index=map_index,
session=session,
)
Expand Down Expand Up @@ -225,7 +225,7 @@ def skip_all_except(

follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
log.info("Skipping tasks %s", skip_tasks)
SkipMixin._set_state_to_skipped(dag_run, skip_tasks, session=session)
SkipMixin._set_state_to_skipped(dag_run.dag_id, dag_run.run_id, skip_tasks, session=session)
ti.xcom_push(
key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}, session=session
)
Loading