Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,14 @@ def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[in

# TODO: This initiates one API call for each XComArg. Would it be
# more efficient to do one single call and unpack the value here?

resolved = {
k: v.resolve(context) if _needs_run_time_resolution(v) else v for k, v in self.value.items()
}

all_lengths = self._get_map_lengths(resolved, upstream_map_indexes)
sized_resolved = {k: v for k, v in resolved.items() if isinstance(v, Sized)}

all_lengths = self._get_map_lengths(sized_resolved, upstream_map_indexes)

data = {k: self._expand_mapped_field(k, v, map_index, all_lengths) for k, v in resolved.items()}
literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()}
Expand Down
4 changes: 2 additions & 2 deletions task-sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe
class AssetAny(AssetBooleanCondition):
"""Use to combine assets schedule references in an "or" relationship."""

agg_func = any
agg_func = any # type: ignore[assignment]

def __or__(self, other: BaseAsset) -> BaseAsset:
if not isinstance(other, BaseAsset):
Expand All @@ -656,7 +656,7 @@ def as_expression(self) -> dict[str, Any]:
class AssetAll(AssetBooleanCondition):
"""Use to combine assets schedule references in an "and" relationship."""

agg_func = all
agg_func = all # type: ignore[assignment]

def __and__(self, other: BaseAsset) -> BaseAsset:
if not isinstance(other, BaseAsset):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@

import types
from collections.abc import Callable
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

from airflow.exceptions import AirflowException
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
from airflow.sdk.definitions.decorators.task_group import _TaskGroupFactory

if TYPE_CHECKING:
from airflow.sdk.bases.decorator import _TaskDecorator
from airflow.sdk.definitions.xcom_arg import XComArg

try:
Expand All @@ -52,7 +53,8 @@ def initialize_context(...):
func = python_task(func)
if isinstance(func, _TaskGroupFactory):
raise AirflowException("Task groups cannot be marked as setup or teardown.")
func.is_setup = True # type: ignore[attr-defined]
func = cast("_TaskDecorator", func)
func.is_setup = True # type: ignore[attr-defined] # TODO: Remove this once mypy is bump to 1.16.1
return func


Expand All @@ -76,6 +78,9 @@ def teardown(func: Callable) -> Callable:
func = python_task(func)
if isinstance(func, _TaskGroupFactory):
raise AirflowException("Task groups cannot be marked as setup or teardown.")
func = cast("_TaskDecorator", func)

# TODO: Remove below attr-defined once mypy is bump to 1.16.1
func.is_teardown = True # type: ignore[attr-defined]
func.on_failure_fail_dagrun = on_failure_fail_dagrun # type: ignore[attr-defined]
return func
Expand Down
8 changes: 7 additions & 1 deletion task-sdk/src/airflow/sdk/definitions/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from airflow.models.expandinput import SchedulerExpandInput
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput
from airflow.sdk.definitions._internal.mixins import DependencyMixin
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.edges import EdgeModifier
Expand Down Expand Up @@ -613,7 +614,12 @@ class MappedTaskGroup(TaskGroup):
a ``@task_group`` function instead.
"""

def __init__(self, *, expand_input: SchedulerExpandInput, **kwargs: Any) -> None:
def __init__(
self,
*,
expand_input: SchedulerExpandInput | DictOfListsExpandInput | ListOfDictsExpandInput,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self._expand_input = expand_input

Expand Down
10 changes: 4 additions & 6 deletions task-sdk/src/airflow/sdk/execution_time/secrets_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,10 @@
from enum import Enum
from functools import cache, cached_property
from re import Pattern
from typing import TYPE_CHECKING, Any, TextIO, TypeAlias, TypeVar
from typing import Any, TextIO, TypeAlias, TypeVar

from airflow import settings

if TYPE_CHECKING:
from typing import TypeGuard

V1EnvVar = TypeVar("V1EnvVar")
Redactable: TypeAlias = str | V1EnvVar | dict[Any, Any] | tuple[Any, ...] | list[Any]
Redacted: TypeAlias = Redactable | str
Expand Down Expand Up @@ -154,7 +151,8 @@ def _get_v1_env_var_type() -> type:
return V1EnvVar


def _is_v1_env_var(v: Any) -> TypeGuard[V1EnvVar]:
# TODO update return type to TypeGuard[V1EnvVar] once mypy 1.17.0 is available
def _is_v1_env_var(v: Any) -> bool:
return isinstance(v, _get_v1_env_var_type())


Expand Down Expand Up @@ -256,7 +254,7 @@ def _redact(self, item: Redactable, name: str | None, depth: int, max_depth: int
return to_return
if isinstance(item, Enum):
return self._redact(item=item.value, name=name, depth=depth, max_depth=max_depth)
if _is_v1_env_var(item):
if _is_v1_env_var(item) and hasattr(item, "to_dict"):
tmp: dict = item.to_dict() # type: ignore[attr-defined] # V1EnvVar has a to_dict method
if should_hide_value_for_key(tmp.get("name", "")) and "value" in tmp:
tmp["value"] = "***"
Expand Down
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def kill(
return

# Escalation sequence: SIGINT -> SIGTERM -> SIGKILL
escalation_path = [signal.SIGINT, signal.SIGTERM, signal.SIGKILL]
escalation_path: list[signal.Signals] = [signal.SIGINT, signal.SIGTERM, signal.SIGKILL]

if force and signal_to_send in escalation_path:
# Start from `signal_to_send` and escalate to the end of the escalation path
Expand Down