59 changes: 50 additions & 9 deletions distributed/scheduler.py
@@ -1311,6 +1311,10 @@ class TaskState:
#: Cached hash of :attr:`~TaskState.client_key`
_hash: int

#: Cached while tasks are in `queued` or `no-worker`; set in
#: `transition_waiting_processing`, cleared in `_add_to_processing`
_rootish: bool | None

# Support for weakrefs to a class with __slots__
__weakref__: Any = None
__slots__ = tuple(__annotations__)
@@ -1352,6 +1356,7 @@ def __init__(self, key: str, run_spec: object, state: TaskStateState):
self.metadata = {}
self.annotations = {}
self.erred_on = set()
self._rootish = None
TaskState._instances.add(self)

def __hash__(self) -> int:
@@ -1511,10 +1516,14 @@ class SchedulerState:
#: All tasks currently known to the scheduler
tasks: dict[str, TaskState]

#: Tasks in the "queued" state, ordered by priority
#: Tasks in the "queued" state, ordered by priority.
#: They are all root-ish.
#: Always empty if `worker-saturation` is set to `inf`.
queued: HeapSet[TaskState]

#: Tasks in the "no-worker" state
#: Tasks in the "no-worker" state.
#: They may or may not have restrictions.
#: Only contains root-ish tasks if `worker-saturation` is set to `inf`.
unrunnable: set[TaskState]

#: Subset of tasks that exist in memory on more than one worker
@@ -2014,11 +2023,19 @@ def transition_no_worker_processing(self, key, stimulus_id):
assert not ts.actor, f"Actors can't be in `no-worker`: {ts}"
assert ts in self.unrunnable

if ws := self.decide_worker_non_rootish(ts):
decide_worker = (
self.decide_worker_rootish_queuing_disabled
if self.is_rootish(ts)
else self.decide_worker_non_rootish
)
if ws := decide_worker(ts):
self.unrunnable.discard(ts)
worker_msgs = _add_to_processing(self, ts, ws)
# If no worker, task just stays in `no-worker`

if self.validate and self.is_rootish(ts):
assert ws is not None

return recommendations, client_msgs, worker_msgs
except Exception as e:
logger.exception(e)
@@ -2052,8 +2069,8 @@ def decide_worker_rootish_queuing_disabled(
``no-worker``.
"""
if self.validate:
# See root-ish-ness note below in `decide_worker_rootish_queuing_enabled`
assert math.isinf(self.WORKER_SATURATION)
assert self.is_rootish(ts)

pool = self.idle.values() if self.idle else self.running
if not pool:
@@ -2113,11 +2130,6 @@ def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None:

"""
if self.validate:
# We don't `assert self.is_rootish(ts)` here, because that check is dependent on
# cluster size. It's possible a task looked root-ish when it was queued, but the
# cluster has since scaled up and it no longer does when coming out of the queue.
# If `is_rootish` changes to a static definition, then add that assertion here
# (and actually pass in the task).
assert not math.isinf(self.WORKER_SATURATION)

if not self.idle:
@@ -2154,6 +2166,9 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
``ts`` or there are no running workers, returns None, in which case the task
should be transitioned to ``no-worker``.
"""
if self.validate:
assert not self.is_rootish(ts)

if not self.running:
return None

@@ -2222,13 +2237,15 @@ def transition_waiting_processing(self, key, stimulus_id):
# NOTE: having two root-ish methods is temporary. When the feature flag is removed,
# there should only be one, which combines co-assignment and queuing.
# Eventually, special-casing root tasks might be removed entirely, with better heuristics.
ts._rootish = True # cached until `processing`
if math.isinf(self.WORKER_SATURATION):
if not (ws := self.decide_worker_rootish_queuing_disabled(ts)):
return {ts.key: "no-worker"}, {}, {}
else:
if not (ws := self.decide_worker_rootish_queuing_enabled()):
return {ts.key: "queued"}, {}, {}
else:
ts._rootish = False # cached until `processing`
if not (ws := self.decide_worker_non_rootish(ts)):
return {ts.key: "no-worker"}, {}, {}

@@ -2988,6 +3005,15 @@ def is_rootish(self, ts: TaskState) -> bool:
Root-ish tasks are part of a group that's much larger than the cluster,
and have few or no dependencies.
"""
# NOTE: the result of `is_rootish` is cached in `waiting->processing` and
# invalidated when entering `processing`. This is for the benefit of the
# `queued` and `no-worker` states. We cache `is_rootish` not for performance,
# but so it can't change if the `TaskGroup` or cluster size does. That avoids
# annoying edge cases where a task does/doesn't look root-ish when it goes
# into `queued` or `unrunnable`, but that's flipped when it comes out.
if (cached := ts._rootish) is not None:
return cached

if ts.resource_restrictions or ts.worker_restrictions or ts.host_restrictions:
return False
tg = ts.group
@@ -4892,6 +4918,7 @@ def validate_released(self, key):
assert not any([ts in dts.waiters for dts in ts.dependencies])
assert ts not in self.unrunnable
assert ts not in self.queued
assert ts._rootish is None, ts._rootish

def validate_waiting(self, key):
ts: TaskState = self.tasks[key]
@@ -4900,6 +4927,7 @@ def validate_waiting(self, key):
assert not ts.processing_on
assert ts not in self.unrunnable
assert ts not in self.queued
assert ts._rootish is None, ts._rootish
for dts in ts.dependencies:
# We are waiting on a dependency iff it's not stored
assert bool(dts.who_has) != (dts in ts.waiting_on)
@@ -4912,6 +4940,7 @@ def validate_queued(self, key):
assert not ts.waiting_on
assert not ts.who_has
assert not ts.processing_on
assert ts._rootish is True, ts._rootish
assert not (
ts.worker_restrictions or ts.host_restrictions or ts.resource_restrictions
)
@@ -4928,6 +4957,7 @@ def validate_processing(self, key):
assert ts in ws.processing
assert not ts.who_has
assert ts not in self.queued
assert ts._rootish is None, ts._rootish
for dts in ts.dependencies:
assert dts.who_has
assert ts in dts.waiters
@@ -4941,6 +4971,7 @@ def validate_memory(self, key):
assert not ts.waiting_on
assert ts not in self.unrunnable
assert ts not in self.queued
assert ts._rootish is None, ts._rootish
for dts in ts.dependents:
assert (dts in ts.waiters) == (
dts.state in ("waiting", "queued", "processing", "no-worker")
@@ -4955,6 +4986,7 @@ def validate_no_worker(self, key):
assert not ts.processing_on
assert not ts.who_has
assert ts not in self.queued
assert ts._rootish is not None, ts._rootish
for dts in ts.dependencies:
assert dts.who_has

@@ -4963,6 +4995,7 @@ def validate_erred(self, key):
assert ts.exception_blame
assert not ts.who_has
assert ts not in self.queued
assert ts._rootish is None, ts._rootish

def validate_key(self, key, ts: TaskState | None = None):
try:
@@ -7052,6 +7085,8 @@ def get_metadata(self, keys: list[str], default=no_default):
def set_restrictions(self, worker: dict[str, Collection[str] | str]):
for key, restrictions in worker.items():
ts = self.tasks[key]
if ts._rootish is not None:
Collaborator (author) commented:

I don't like that set_restrictions is a public API at all. Doesn't seem like something you should be able to do post-hoc.

Collaborator (author) commented:

This fails test_reschedule_concurrent_requests_deadlock, which sets restrictions on a processing task.

@gen_cluster(
client=True,
nthreads=[("", 1)] * 3,
config={
"distributed.scheduler.work-stealing-interval": 1_000_000,
},
)
async def test_reschedule_concurrent_requests_deadlock(c, s, *workers):
# https://github.com/dask/distributed/issues/5370
steal = s.extensions["stealing"]
w0 = workers[0]
ev = Event()
futs1 = c.map(
lambda _, ev: ev.wait(),
range(10),
ev=ev,
key=[f"f1-{ix}" for ix in range(10)],
workers=[w0.address],
allow_other_workers=True,
)
while not w0.active_keys:
await asyncio.sleep(0.01)
# ready is a heap but we don't need last, just not the next
victim_key = list(w0.active_keys)[0]
victim_ts = s.tasks[victim_key]
wsA = victim_ts.processing_on
other_workers = [ws for ws in s.workers.values() if ws != wsA]
wsB = other_workers[0]
wsC = other_workers[1]
steal.move_task_request(victim_ts, wsA, wsB)
s.set_restrictions(worker={victim_key: [wsB.address]})
s._reschedule(victim_key, stimulus_id="test")
assert wsB == victim_ts.processing_on
# move_task_request is not responsible for respecting worker restrictions
steal.move_task_request(victim_ts, wsB, wsC)
# Let tasks finish
await ev.set()
await c.gather(futs1)
assert victim_ts.who_has != {wsC}
msgs = steal.story(victim_ts)
msgs = [msg[:-1] for msg in msgs] # Remove random IDs
# There are three possible outcomes
expect1 = [
("stale-response", victim_key, "executing", wsA.address),
("already-computing", victim_key, "executing", wsB.address, wsC.address),
]
expect2 = [
("already-computing", victim_key, "executing", wsB.address, wsC.address),
("already-aborted", victim_key, "executing", wsA.address),
]
# This outcome appears only in ~2% of the runs
expect3 = [
("already-computing", victim_key, "executing", wsB.address, wsC.address),
("already-aborted", victim_key, "memory", wsA.address),
]
assert msgs in (expect1, expect2, expect3)

Member replied, quoting the comment above:

We can change it. First step is a deprecation warning
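
A minimal sketch of what that first step could look like, assuming we keep the method and only warn on use for now. The warning text, the FutureWarning category, and the suggested alternative are illustrative assumptions, not something this PR adds; the body mirrors the version in this diff.

    import warnings
    from collections.abc import Collection

    # Hypothetical deprecation shim for Scheduler.set_restrictions.
    # The warning is an assumption; the rest follows the code in this PR.
    def set_restrictions(self, worker: dict[str, Collection[str] | str]) -> None:
        warnings.warn(
            "Scheduler.set_restrictions is deprecated; pass worker restrictions "
            "when submitting tasks instead of setting them post-hoc.",
            FutureWarning,
            stacklevel=2,
        )
        for key, restrictions in worker.items():
            ts = self.tasks[key]
            if ts._rootish is not None:
                raise ValueError(f"cannot set restrictions on ready {ts}")
            if isinstance(restrictions, str):
                restrictions = {restrictions}
            ts.worker_restrictions = set(restrictions)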

raise ValueError(f"cannot set restrictions on ready {ts}")
if isinstance(restrictions, str):
restrictions = {restrictions}
ts.worker_restrictions = set(restrictions)
@@ -7774,6 +7809,7 @@ def _validate_ready(state: SchedulerState, ts: TaskState) -> None:
assert ts not in state.unrunnable
assert ts not in state.queued
assert all(dts.who_has for dts in ts.dependencies)
assert ts._rootish is not None, ts._rootish


def _add_to_processing(
@@ -7785,6 +7821,7 @@ def _add_to_processing(
assert ws in state.running, state.running
assert (o := state.workers.get(ws.address)) is ws, (ws, o)

ts._rootish = None
ws.add_to_processing(ts)
ts.processing_on = ws
ts.state = "processing"
@@ -7812,6 +7849,9 @@ def _exit_processing_common(
--------
Scheduler._set_duration_estimate
"""
if state.validate:
assert ts._rootish is None, ts._rootish

ws = ts.processing_on
assert ws
ts.processing_on = None
@@ -7913,6 +7953,7 @@ def _propagate_released(
recommendations: Recs,
) -> None:
ts.state = "released"
ts._rootish = None
key = ts.key

if ts.has_lost_dependencies:
79 changes: 79 additions & 0 deletions distributed/tests/test_scheduler.py
@@ -8,6 +8,7 @@
import pickle
import re
import sys
from contextlib import AsyncExitStack
from itertools import product
from textwrap import dedent
from time import sleep
@@ -481,6 +482,84 @@ async def test_queued_remove_add_worker(c, s, a, b):
await wait(fs)


@gen_cluster(
client=True,
nthreads=[("", 2)] * 2,
config={
"distributed.worker.memory.pause": False,
"distributed.worker.memory.target": False,
"distributed.worker.memory.spill": False,
"distributed.scheduler.work-stealing": False,
},
)
async def test_queued_rootish_changes_while_paused(c, s, a, b):
"Some tasks are root-ish, some aren't. So both `unrunnable` and `queued` contain non-restricted tasks."

root = c.submit(inc, 1, key="root")
await root

# manually pause the workers
a.status = Status.paused
b.status = Status.paused

await async_wait_for(lambda: not s.running, 5)

fs = [c.submit(inc, root, key=f"inc-{i}") for i in range(s.total_nthreads * 2 + 1)]
# ^ `c.submit` in a for-loop so the first tasks don't look root-ish (`TaskGroup` too
# small), then the last one does. So N-1 tasks will go to `no-worker`, and the last
# to `queued`. `is_rootish` is just messed up like that.

await async_wait_for(lambda: len(s.tasks) > len(fs), 5)

# un-pause
a.status = Status.running
b.status = Status.running
await async_wait_for(lambda: len(s.running) == len(s.workers), 5)

await c.gather(fs)


@gen_cluster(
client=True,
nthreads=[("", 1)],
config={"distributed.scheduler.work-stealing": False},
)
async def test_queued_rootish_changes_scale_up(c, s, a):
"Tasks are initially root-ish. After cluster scales, they don't meet the definition, but still are."

root = c.submit(inc, 1, key="root")

event = Event()
clog = c.submit(event.wait, key="clog")
await wait_for_state(clog.key, "processing", s)

fs = c.map(inc, [root] * 5, key=[f"inc-{i}" for i in range(5)])

await async_wait_for(lambda: len(s.tasks) > len(fs), 5)

if not s.is_rootish(s.tasks[fs[0].key]):
pytest.fail(
"Test assumptions have changed; task is not root-ish. Test may no longer be relevant."
)
if math.isfinite(s.WORKER_SATURATION):
assert s.queued

async with AsyncExitStack() as stack:
for _ in range(3):
await stack.enter_async_context(Worker(s.address, nthreads=2))

if not s.is_rootish(s.tasks[fs[0].key]):
pytest.fail(
"Test assumptions have changed; root-ish-ness has flipped. Test may no longer be relevant."
)

await event.set()
await clog

# Just verify it doesn't deadlock
await c.gather(fs)


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_secede_opens_slot(c, s, a):
first = Event()