Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import inspect
from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, Union, overload

from sqlalchemy import func
from sqlalchemy import func, or_
from sqlalchemy.orm import Session

from airflow.exceptions import AirflowException, XComNotFound
Expand All @@ -33,6 +33,7 @@
from airflow.utils.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.state import State
from airflow.utils.types import NOTSET, ArgNotSet
from airflow.utils.xcom import XCOM_RETURN_KEY

Expand Down Expand Up @@ -309,11 +310,26 @@ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
return super().zip(*others, fillvalue=fillvalue)

def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCom

task = self.operator
if isinstance(task, MappedOperator):
unfinished_ti_count_query = session.query(func.count(TaskInstance.map_index)).filter(
TaskInstance.dag_id == task.dag_id,
TaskInstance.run_id == run_id,
TaskInstance.task_id == task.task_id,
# Special NULL treatment is needed because 'state' can be NULL.
# The "IN" part would produce "NULL NOT IN ..." and eventually
# "NULl = NULL", which is a big no-no in SQL.
or_(
TaskInstance.state.is_(None),
TaskInstance.state.in_(s.value for s in State.unfinished if s is not None),
),
)
if unfinished_ti_count_query.scalar():
return None # Not all of the expanded tis are done yet.
query = session.query(func.count(XCom.map_index)).filter(
XCom.dag_id == task.dag_id,
XCom.run_id == run_id,
Expand All @@ -332,7 +348,11 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
from airflow.models.taskinstance import TaskInstance

ti = context["ti"]
assert isinstance(ti, TaskInstance), "Wait for AIP-44 implementation to complete"

task_id = self.operator.task_id
map_indexes = ti.get_relevant_upstream_map_indexes(
self.operator,
Expand Down
41 changes: 41 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3956,3 +3956,44 @@ def last_task():
middle_ti = dr.get_task_instance(task_id="middle_task", map_index=i)
assert middle_ti.state == State.SCHEDULED
assert "3 downstream tasks scheduled from follow-on schedule" in caplog.text


def test_mini_scheduler_not_skip_mapped_downstream_until_all_upstreams_finish(dag_maker, session):
with dag_maker(session=session):

@task
def generate() -> list[list[int]]:
return []

@task
def a_sum(numbers: list[int]) -> int:
return sum(numbers)

@task
def b_double(summed: int) -> int:
return summed * 2

@task
def c_gather(result) -> None:
pass

static = EmptyOperator(task_id="static")

summed = a_sum.expand(numbers=generate())
doubled = b_double.expand(summed=summed)
static >> c_gather(doubled)

dr: DagRun = dag_maker.create_dagrun()
tis = {(ti.task_id, ti.map_index): ti for ti in dr.task_instances}

static_ti = tis[("static", -1)]
static_ti.run(session=session)
static_ti.schedule_downstream_tasks(session=session)
# No tasks should be skipped yet!
assert not dr.get_task_instances([TaskInstanceState.SKIPPED], session=session)

generate_ti = tis[("generate", -1)]
generate_ti.run(session=session)
generate_ti.schedule_downstream_tasks(session=session)
# Now downstreams can be skipped.
assert dr.get_task_instances([TaskInstanceState.SKIPPED], session=session)