diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 5342d693e91..c0a3308202d 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -1311,6 +1311,10 @@ class TaskState:
     #: Cached hash of :attr:`~TaskState.client_key`
     _hash: int
 
+    #: Cached while tasks are in `queued` or `no-worker`; set in
+    #: `transition_waiting_processing` and `_add_to_processing`
+    _rootish: bool | None
+
     # Support for weakrefs to a class with __slots__
     __weakref__: Any = None
 
     __slots__ = tuple(__annotations__)
@@ -1352,6 +1356,7 @@ def __init__(self, key: str, run_spec: object, state: TaskStateState):
         self.metadata = {}
         self.annotations = {}
         self.erred_on = set()
+        self._rootish = None
         TaskState._instances.add(self)
 
     def __hash__(self) -> int:
@@ -1511,10 +1516,14 @@ class SchedulerState:
     #: All tasks currently known to the scheduler
     tasks: dict[str, TaskState]
 
-    #: Tasks in the "queued" state, ordered by priority
+    #: Tasks in the "queued" state, ordered by priority.
+    #: They are all root-ish.
+    #: Always empty if `worker-saturation` is set to `inf`.
     queued: HeapSet[TaskState]
 
-    #: Tasks in the "no-worker" state
+    #: Tasks in the "no-worker" state.
+    #: They may or may not have restrictions.
+    #: Only contains root-ish tasks if `worker-saturation` is set to `inf`.
     unrunnable: set[TaskState]
 
     #: Subset of tasks that exist in memory on more than one worker
@@ -2014,11 +2023,19 @@ def transition_no_worker_processing(self, key, stimulus_id):
                 assert not ts.actor, f"Actors can't be in `no-worker`: {ts}"
                 assert ts in self.unrunnable
 
-            if ws := self.decide_worker_non_rootish(ts):
+            decide_worker = (
+                self.decide_worker_rootish_queuing_disabled
+                if self.is_rootish(ts)
+                else self.decide_worker_non_rootish
+            )
+            if ws := decide_worker(ts):
                 self.unrunnable.discard(ts)
                 worker_msgs = _add_to_processing(self, ts, ws)
             # If no worker, task just stays in `no-worker`
 
+            if self.validate and self.is_rootish(ts):
+                assert ws is not None
+
             return recommendations, client_msgs, worker_msgs
         except Exception as e:
             logger.exception(e)
@@ -2052,8 +2069,8 @@ def decide_worker_rootish_queuing_disabled(
         ``no-worker``.
         """
         if self.validate:
-            # See root-ish-ness note below in `decide_worker_rootish_queuing_enabled`
             assert math.isinf(self.WORKER_SATURATION)
+            assert self.is_rootish(ts)
 
         pool = self.idle.values() if self.idle else self.running
         if not pool:
@@ -2113,11 +2130,6 @@ def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None:
         """
         if self.validate:
-            # We don't `assert self.is_rootish(ts)` here, because that check is dependent on
-            # cluster size. It's possible a task looked root-ish when it was queued, but the
-            # cluster has since scaled up and it no longer does when coming out of the queue.
-            # If `is_rootish` changes to a static definition, then add that assertion here
-            # (and actually pass in the task).
             assert not math.isinf(self.WORKER_SATURATION)
 
         if not self.idle:
             return None
@@ -2154,6 +2166,9 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
         ``ts`` or there are no running workers, returns None, in which case the task
         should be transitioned to ``no-worker``.
         """
+        if self.validate:
+            assert not self.is_rootish(ts)
+
         if not self.running:
             return None
 
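The three `decide_worker_*` hunks above amount to a single dispatch on root-ish-ness plus the `worker-saturation` flag. A minimal standalone sketch of that routing, for orientation only (`state` is a duck-typed stand-in for `SchedulerState`; only the method names mirror the code above):

    import math

    def pick_decide_worker(state, ts):
        # Non-root-ish tasks take the restriction-aware path.
        if not state.is_rootish(ts):
            return state.decide_worker_non_rootish
        # Root-ish tasks: the worker-saturation flag selects between the
        # queuing-disabled (`inf`) and queuing-enabled code paths.
        if math.isinf(state.WORKER_SATURATION):
            return state.decide_worker_rootish_queuing_disabled
        return state.decide_worker_rootish_queuing_enabled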
@@ -2222,6 +2237,7 @@ def transition_waiting_processing(self, key, stimulus_id):
             # NOTE: having two root-ish methods is temporary. When the feature flag is removed,
             # there should only be one, which combines co-assignment and queuing.
             # Eventually, special-casing root tasks might be removed entirely, with better heuristics.
+            ts._rootish = True  # cached until `processing`
             if math.isinf(self.WORKER_SATURATION):
                 if not (ws := self.decide_worker_rootish_queuing_disabled(ts)):
                     return {ts.key: "no-worker"}, {}, {}
@@ -2229,6 +2245,7 @@
                 if not (ws := self.decide_worker_rootish_queuing_enabled()):
                     return {ts.key: "queued"}, {}, {}
         else:
+            ts._rootish = False  # cached until `processing`
             if not (ws := self.decide_worker_non_rootish(ts)):
                 return {ts.key: "no-worker"}, {}, {}
@@ -2988,6 +3005,15 @@ def is_rootish(self, ts: TaskState) -> bool:
         Root-ish tasks are part of a group that's much larger than the cluster,
         and have few or no dependencies.
         """
+        # NOTE: the result of `is_rootish` is cached in `waiting->processing`, and
+        # invalidated when entering `processing`. This is for the benefit of the
+        # `queued` and `no-worker` states. We cache `is_rootish` not for
+        # performance, but so it can't change if the `TaskGroup` or cluster size does.
+        # That avoids annoying edge cases where a task does/doesn't look root-ish when
+        # it goes into `queued` or `unrunnable`, but that's flipped when it comes out.
+        if (cached := ts._rootish) is not None:
+            return cached
+
         if ts.resource_restrictions or ts.worker_restrictions or ts.host_restrictions:
             return False
         tg = ts.group
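The NOTE above describes a compute-once, invalidate-on-exit contract. A self-contained sketch of that caching pattern, with illustrative names rather than the scheduler's actual API:

    class Task:
        def __init__(self) -> None:
            self._rootish: bool | None = None

    def leave_waiting(task: Task, looks_rootish: bool) -> None:
        # waiting->processing decision: freeze the classification before the
        # task can be parked in `queued` or `no-worker`
        task._rootish = looks_rootish

    def is_rootish(task: Task, recompute) -> bool:
        # while parked, the frozen answer wins and cannot flip with cluster size
        if task._rootish is not None:
            return task._rootish
        return recompute(task)  # dynamic answer outside that window

    def enter_processing(task: Task) -> None:
        # entering `processing` (or `released`) clears the cache
        task._rootish = None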
@@ -4892,6 +4918,7 @@ def validate_released(self, key):
         assert not any([ts in dts.waiters for dts in ts.dependencies])
         assert ts not in self.unrunnable
         assert ts not in self.queued
+        assert ts._rootish is None, ts._rootish
 
     def validate_waiting(self, key):
         ts: TaskState = self.tasks[key]
@@ -4900,6 +4927,7 @@ def validate_waiting(self, key):
         assert not ts.processing_on
         assert ts not in self.unrunnable
         assert ts not in self.queued
+        assert ts._rootish is None, ts._rootish
         for dts in ts.dependencies:
             # We are waiting on a dependency iff it's not stored
             assert bool(dts.who_has) != (dts in ts.waiting_on)
@@ -4912,6 +4940,7 @@ def validate_queued(self, key):
         assert not ts.waiting_on
         assert not ts.who_has
         assert not ts.processing_on
+        assert ts._rootish is True, ts._rootish
         assert not (
             ts.worker_restrictions or ts.host_restrictions or ts.resource_restrictions
         )
@@ -4928,6 +4957,7 @@ def validate_processing(self, key):
         assert ts in ws.processing
         assert not ts.who_has
         assert ts not in self.queued
+        assert ts._rootish is None, ts._rootish
         for dts in ts.dependencies:
             assert dts.who_has
             assert ts in dts.waiters
@@ -4941,6 +4971,7 @@ def validate_memory(self, key):
         assert not ts.waiting_on
         assert ts not in self.unrunnable
         assert ts not in self.queued
+        assert ts._rootish is None, ts._rootish
         for dts in ts.dependents:
             assert (dts in ts.waiters) == (
                 dts.state in ("waiting", "queued", "processing", "no-worker")
             )
@@ -4955,6 +4986,7 @@ def validate_no_worker(self, key):
         assert not ts.processing_on
         assert not ts.who_has
         assert ts not in self.queued
+        assert ts._rootish is not None, ts._rootish
         for dts in ts.dependencies:
             assert dts.who_has
@@ -4963,6 +4995,7 @@ def validate_erred(self, key):
         assert ts.exception_blame
         assert not ts.who_has
         assert ts not in self.queued
+        assert ts._rootish is None, ts._rootish
 
     def validate_key(self, key, ts: TaskState | None = None):
         try:
@@ -7052,6 +7085,8 @@ def get_metadata(self, keys: list[str], default=no_default):
     def set_restrictions(self, worker: dict[str, Collection[str] | str]):
         for key, restrictions in worker.items():
             ts = self.tasks[key]
+            if ts._rootish is not None:
+                raise ValueError(f"cannot set restrictions on ready {ts}")
             if isinstance(restrictions, str):
                 restrictions = {restrictions}
             ts.worker_restrictions = set(restrictions)
@@ -7774,6 +7809,7 @@ def _validate_ready(state: SchedulerState, ts: TaskState) -> None:
     assert ts not in state.unrunnable
     assert ts not in state.queued
     assert all(dts.who_has for dts in ts.dependencies)
+    assert ts._rootish is not None, ts._rootish
 
 
 def _add_to_processing(
@@ -7785,6 +7821,7 @@ def _add_to_processing(
         assert ws in state.running, state.running
         assert (o := state.workers.get(ws.address)) is ws, (ws, o)
 
+    ts._rootish = None
     ws.add_to_processing(ts)
     ts.processing_on = ws
     ts.state = "processing"
@@ -7812,6 +7849,9 @@ def _exit_processing_common(
     --------
     Scheduler._set_duration_estimate
     """
+    if state.validate:
+        assert ts._rootish is None, ts._rootish
+
     ws = ts.processing_on
     assert ws
     ts.processing_on = None
@@ -7913,6 +7953,7 @@ def _propagate_released(
     recommendations: Recs,
 ) -> None:
     ts.state = "released"
+    ts._rootish = None
     key = ts.key
 
     if ts.has_lost_dependencies:
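Taken together, the `validate_*` hunks above pin down a per-state invariant for the new cache. A compact restatement, derived from those asserts (illustrative only, not scheduler API): `queued` implies `True`, `no-worker` implies some cached value, and every other state implies the cache is cleared.

    EXPECTED_ROOTISH: dict[str, tuple[bool | None, ...]] = {
        "released": (None,),
        "waiting": (None,),
        "queued": (True,),
        "no-worker": (True, False),
        "processing": (None,),
        "memory": (None,),
        "erred": (None,),
    }

    def check_rootish_invariant(state: str, cached: bool | None) -> None:
        assert cached in EXPECTED_ROOTISH[state], (state, cached)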
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index f3e21e9d0af..b31a4f2a2fe 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -8,6 +8,7 @@
 import pickle
 import re
 import sys
+from contextlib import AsyncExitStack
 from itertools import product
 from textwrap import dedent
 from time import sleep
@@ -481,6 +482,84 @@ async def test_queued_remove_add_worker(c, s, a, b):
     await wait(fs)
 
 
+@gen_cluster(
+    client=True,
+    nthreads=[("", 2)] * 2,
+    config={
+        "distributed.worker.memory.pause": False,
+        "distributed.worker.memory.target": False,
+        "distributed.worker.memory.spill": False,
+        "distributed.scheduler.work-stealing": False,
+    },
+)
+async def test_queued_rootish_changes_while_paused(c, s, a, b):
+    "Some tasks are root-ish, some aren't. So both `unrunnable` and `queued` contain non-restricted tasks."
+
+    root = c.submit(inc, 1, key="root")
+    await root
+
+    # manually pause the workers
+    a.status = Status.paused
+    b.status = Status.paused
+
+    await async_wait_for(lambda: not s.running, 5)
+
+    fs = [c.submit(inc, root, key=f"inc-{i}") for i in range(s.total_nthreads * 2 + 1)]
+    # ^ `c.submit` in a for-loop so the first tasks don't look root-ish (`TaskGroup` too
+    # small), then the last one does. So N-1 tasks will go to `no-worker`, and the last
+    # to `queued`. `is_rootish` is just messed up like that.
+
+    await async_wait_for(lambda: len(s.tasks) > len(fs), 5)
+
+    # un-pause
+    a.status = Status.running
+    b.status = Status.running
+    await async_wait_for(lambda: len(s.running) == len(s.workers), 5)
+
+    await c.gather(fs)
+
+
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)],
+    config={"distributed.scheduler.work-stealing": False},
+)
+async def test_queued_rootish_changes_scale_up(c, s, a):
+    "Tasks are initially root-ish. After the cluster scales up, they no longer meet the definition, but are still treated as root-ish."
+
+    root = c.submit(inc, 1, key="root")
+
+    event = Event()
+    clog = c.submit(event.wait, key="clog")
+    await wait_for_state(clog.key, "processing", s)
+
+    fs = c.map(inc, [root] * 5, key=[f"inc-{i}" for i in range(5)])
+
+    await async_wait_for(lambda: len(s.tasks) > len(fs), 5)
+
+    if not s.is_rootish(s.tasks[fs[0].key]):
+        pytest.fail(
+            "Test assumptions have changed; task is not root-ish. Test may no longer be relevant."
+        )
+    if math.isfinite(s.WORKER_SATURATION):
+        assert s.queued
+
+    async with AsyncExitStack() as stack:
+        for _ in range(3):
+            await stack.enter_async_context(Worker(s.address, nthreads=2))
+
+        if not s.is_rootish(s.tasks[fs[0].key]):
+            pytest.fail(
+                "Test assumptions have changed; root-ish-ness has flipped. Test may no longer be relevant."
+            )
+
+        await event.set()
+        await clog
+
+        # Just verify it doesn't deadlock
+        await c.gather(fs)
+
+
 @gen_cluster(client=True, nthreads=[("", 1)])
 async def test_secede_opens_slot(c, s, a):
     first = Event()
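For context on the `WORKER_SATURATION` checks in the tests above: the flag is read from the `distributed.scheduler.worker-saturation` config value. A hedged usage sketch (defaults have varied across releases, so treat the concrete values as illustrative):

    import dask

    # Queuing enabled: finite saturation. Root-ish tasks beyond roughly
    # ceil(nthreads * saturation) per worker wait in the scheduler's `queued` heap.
    with dask.config.set({"distributed.scheduler.worker-saturation": 1.1}):
        ...

    # Queuing disabled: `inf` saturation. The scheduler takes the
    # `decide_worker_rootish_queuing_disabled` path, and root-ish tasks that
    # cannot be scheduled park in `no-worker` instead of `queued`.
    with dask.config.set({"distributed.scheduler.worker-saturation": "inf"}):
        ...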