diff --git a/distributed/collections.py b/distributed/collections.py
index c25001047f6..b79e0cd6c08 100644
--- a/distributed/collections.py
+++ b/distributed/collections.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import heapq
+import itertools
 import weakref
 from collections import OrderedDict, UserDict
 from collections.abc import Callable, Hashable, Iterator
@@ -56,6 +57,9 @@ def __repr__(self) -> str:
     def __contains__(self, value: object) -> bool:
         return value in self._data

+    def __bool__(self) -> bool:
+        return bool(self._data)
+
     def __len__(self) -> int:
         return len(self._data)

@@ -93,6 +97,41 @@ def pop(self) -> T:
         self._data.discard(value)
         return value

+    def popright(self) -> T:
+        "Remove and return one of the largest elements (not necessarily the largest)."
+        if not self._data:
+            raise KeyError("popright from an empty set")
+        while True:
+            _, _, vref = self._heap.pop()
+            value = vref()
+            if value is not None and value in self._data:
+                self._data.discard(value)
+                return value
+
+    def topk(self, k: int) -> Iterator[T]:
+        # TODO confirm big-O values here
+        "Iterator over the k smallest elements, in ascending order. This is O(k*logn) for k < n // 2, O(n*logn) otherwise."
+        k = min(k, len(self))
+        if k == 1:
+            yield self.peek()
+        elif k >= len(self) // 2:
+            yield from itertools.islice(self.sorted(), k)
+        else:
+            # FIXME though neat, with all the list mutation this is probably always slower than sorting in place.
+            elems: list[tuple[Any, int, weakref.ref[T]]] = []
+            try:
+                while len(elems) < k:
+                    elem = heapq.heappop(self._heap)
+                    value = elem[-1]()
+                    if value is not None and value in self._data:
+                        # NOTE: we're in a broken state during iteration, since the value exists
+                        # in the set but not the heap. As with all Python iterators, mutating
+                        # while iterating is undefined.
+                        elems.append(elem)
+                        yield value
+            finally:
+                # The popped elements are sorted and no larger than anything left in the heap,
+                # but prepending them shifts the indices of the remainder, which can break the
+                # heap invariant -- rebuild it explicitly.
+                self._heap = elems + self._heap
+                heapq.heapify(self._heap)
+
     def __iter__(self) -> Iterator[T]:
         """Iterate over all elements. This is a O(n) operation which returns the
         elements in pseudo-random order.
@@ -104,7 +143,8 @@ def sorted(self) -> Iterator[T]:
         elements in order, from smallest to largest according to the key and
         insertion order.
        """
-        for _, _, vref in sorted(self._heap):
+        self._heap.sort()  # A sorted list maintains the heap invariant
+        for _, _, vref in self._heap:
             value = vref()
             if value in self._data:
                 yield value
diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py
index 8e0bf453be9..c1935d12d57 100644
--- a/distributed/dashboard/components/scheduler.py
+++ b/distributed/dashboard/components/scheduler.py
@@ -2091,8 +2091,8 @@ def __init__(self, scheduler, **kwargs):

         node_colors = factor_cmap(
             "state",
-            factors=["waiting", "processing", "memory", "released", "erred"],
-            palette=["gray", "green", "red", "blue", "black"],
+            factors=["waiting", "queued", "processing", "memory", "released", "erred"],
+            palette=["gray", "yellow", "green", "red", "blue", "black"],
         )

         self.root = figure(title="Task Graph", **kwargs)
@@ -2980,7 +2980,7 @@ def __init__(self, scheduler, **kwargs):
         self.scheduler = scheduler

         data = progress_quads(
-            dict(all={}, memory={}, erred={}, released={}, processing={})
+            dict(all={}, memory={}, erred={}, released={}, processing={}, queued={})
         )

         self.source = ColumnDataSource(data=data)
@@ -3052,6 +3052,18 @@ def __init__(self, scheduler, **kwargs):
             fill_alpha=0.35,
             line_alpha=0,
         )
+        self.root.quad(
+            source=self.source,
+            top="top",
+            bottom="bottom",
+            left="processing-loc",
+            right="queued-loc",
+            fill_color="gray",
+            hatch_pattern="/",
+            hatch_color="white",
+            fill_alpha=0.35,
+            line_alpha=0,
+        )
         self.root.text(
             source=self.source,
             text="show-name",
@@ -3087,6 +3099,14 @@ def __init__(self, scheduler, **kwargs):
                     All:&nbsp;
                     @all
                 </div>
+                <div>
+                    Queued:&nbsp;
+                    @queued
+                </div>
+                <div>
+                    Processing:&nbsp;
+                    @processing
+                </div>
                 <div>
                     Memory:&nbsp;
                     @memory
@@ -3095,10 +3115,6 @@ def __init__(self, scheduler, **kwargs):
                     Erred:&nbsp;
                     @erred
                 </div>
-                <div>
-                    Ready:&nbsp;
-                    @processing
-                </div>
             """,
         )
         self.root.add_tools(hover)
@@ -3112,6 +3128,7 @@ def update(self):
             "released": {},
             "processing": {},
             "waiting": {},
+            "queued": {},
         }

         for tp in self.scheduler.task_prefixes.values():
@@ -3122,6 +3139,7 @@ def update(self):
             state["released"][tp.name] = active_states["released"]
             state["processing"][tp.name] = active_states["processing"]
             state["waiting"][tp.name] = active_states["waiting"]
+            state["queued"][tp.name] = active_states["queued"]

         state["all"] = {k: sum(v[k] for v in state.values()) for k in state["memory"]}

@@ -3134,7 +3152,7 @@ def update(self):

         totals = {
             k: sum(state[k].values())
-            for k in ["all", "memory", "erred", "released", "waiting"]
+            for k in ["all", "memory", "erred", "released", "waiting", "queued"]
         }
         totals["processing"] = totals["all"] - sum(
             v for k, v in totals.items() if k != "all"
@@ -3142,8 +3160,10 @@ def update(self):

         self.root.title.text = (
             "Progress -- total: %(all)s, "
-            "in-memory: %(memory)s, processing: %(processing)s, "
             "waiting: %(waiting)s, "
+            "queued: %(queued)s, "
+            "processing: %(processing)s, "
+            "in-memory: %(memory)s, "
             "erred: %(erred)s" % totals
         )
d["erred-loc"].append(el) d["processing-loc"].append(pl) + d["queued-loc"].append(ql) d["done"].append(done) return d diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index 8d73b7df145..aa42c08972a 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -117,6 +117,26 @@ properties: description: | How frequently to balance worker loads + worker-oversaturation: + type: float + description: | + Controls how many extra root tasks are sent to workers (like a `readahead`). + + `floor(worker-oversaturation * worker.nthreads)` _extra_ tasks are sent to the worker + beyond its thread count. If `.inf`, all runnable tasks are immediately sent to workers. + + Allowing oversaturation means a worker will start running a new root task as soon as + it completes the previous, even if there is a higher-priority downstream task to run. + This reduces worker idleness, by letting workers do something while waiting for further + instructions from the scheduler. + + This generally comes at the expense of increased memory usage. It leads to "wider" + (more breadth-first) execution of the graph. + + Compute-bound workloads benefit from oversaturation. Memory-bound workloads should + generally leave `worker-oversaturation` at 0, though 0.25-0.5 could slightly improve + performance if ample memory is available. + worker-ttl: type: - string diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 3cf75a23298..a37e1c34add 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -22,6 +22,7 @@ distributed: events-log-length: 100000 work-stealing: True # workers should steal tasks from each other work-stealing-interval: 100ms # Callback time for work stealing + worker-oversaturation: 0.0 # Send this fraction of nthreads extra root tasks to workers worker-ttl: "5 minutes" # like '60s'. Time to live for workers. They must heartbeat faster than this pickle: True # Is the scheduler allowed to deserialize arbitrary bytestrings preload: [] # Run custom modules with Scheduler diff --git a/distributed/http/templates/worker-table.html b/distributed/http/templates/worker-table.html index 87512ee3860..765d97133d7 100644 --- a/distributed/http/templates/worker-table.html +++ b/distributed/http/templates/worker-table.html @@ -6,6 +6,7 @@ Memory Memory use Occupancy + Queued Processing In-memory Services @@ -20,6 +21,7 @@ {{ format_bytes(ws.memory_limit) if ws.memory_limit is not None else "" }} {{ format_time(ws.occupancy) }} + {{ len(ws.queued) }} {{ len(ws.processing) }} {{ len(ws.has_what) }} {% if 'dashboard' in ws.services %} diff --git a/distributed/http/templates/worker.html b/distributed/http/templates/worker.html index 9c7608cb8c2..f5795248200 100644 --- a/distributed/http/templates/worker.html +++ b/distributed/http/templates/worker.html @@ -41,6 +41,21 @@

diff --git a/distributed/http/templates/worker-table.html b/distributed/http/templates/worker-table.html
index 87512ee3860..765d97133d7 100644
--- a/distributed/http/templates/worker-table.html
+++ b/distributed/http/templates/worker-table.html
@@ -6,6 +6,7 @@
         <th>Memory</th>
         <th>Memory use</th>
         <th>Occupancy</th>
+        <th>Queued</th>
         <th>Processing</th>
         <th>In-memory</th>
         <th>Services</th>
@@ -20,6 +21,7 @@
         <td>{{ format_bytes(ws.memory_limit) if ws.memory_limit is not None else "" }}</td>
         <td>{{ format_time(ws.occupancy) }}</td>
+        <td>{{ len(ws.queued) }}</td>
         <td>{{ len(ws.processing) }}</td>
         <td>{{ len(ws.has_what) }}</td>
         {% if 'dashboard' in ws.services %}
diff --git a/distributed/http/templates/worker.html b/distributed/http/templates/worker.html
index 9c7608cb8c2..f5795248200 100644
--- a/distributed/http/templates/worker.html
+++ b/distributed/http/templates/worker.html
@@ -41,6 +41,21 @@
     <h4>Processing</h4>
     <table>
     </table>
 {% end %}
+
+<h4>Queued</h4>
+<table>
+    <tr>
+        <th>Task</th>
+        <th>Priority</th>
+    </tr>
+    {% for ts in ws.queued.sorted() %}
+    <tr>
+        <td>{{ts.key}}</td>
+        <td>{{ts.priority}}</td>
+    </tr>
+    {% end %}
+</table>
+
 </div>
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index d5dccbc473e..ec9496db617 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -58,6 +58,7 @@
 from distributed._stories import scheduler_story
 from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
 from distributed.batched import BatchedSend
+from distributed.collections import HeapSet
 from distributed.comm import (
     Comm,
     CommClosedError,
@@ -129,7 +130,15 @@
     "stealing": WorkStealing,
 }

-ALL_TASK_STATES = {"released", "waiting", "no-worker", "processing", "erred", "memory"}
+ALL_TASK_STATES = {
+    "released",
+    "waiting",
+    "no-worker",
+    "queued",
+    "processing",
+    "erred",
+    "memory",
+}


 class ClientState:
@@ -440,6 +449,10 @@ class WorkerState:
     #: been running.
     executing: dict[TaskState, float]

+    #: Tasks queued to _potentially_ run on this worker in the future, ordered by priority.
+    #: The queuing is scheduler-side only; the worker is unaware of these tasks.
+    queued: HeapSet[TaskState]
+
     #: The available resources on this worker, e.g. ``{"GPU": 2}``.
     #: These are abstract quantities that constrain certain tasks from running at the
     #: same time on this worker.
@@ -494,6 +507,7 @@ def __init__(
         self.processing = {}
         self.long_running = set()
         self.executing = {}
+        self.queued = HeapSet(key=operator.attrgetter("priority"))
         self.resources = {}
         self.used_resources = {}
         self.extra = extra or {}
@@ -563,7 +577,8 @@ def __repr__(self) -> str:
             f"<WorkerState {self.address!r}, "
-            f"processing: {len(self.processing)}>"
+            f"processing: {len(self.processing)}, "
+            f"queued: {len(self.queued)}>"
         )

     def _repr_html_(self):
@@ -573,6 +588,7 @@ def _repr_html_(self):
             status=self.status.name,
             has_what=self.has_what,
             processing=self.processing,
+            queued=self.queued,
         )

     def identity(self) -> dict[str, Any]:
@@ -972,6 +988,10 @@ class TaskState:
     #: it. This attribute is kept in sync with :attr:`WorkerState.processing`.
     processing_on: WorkerState | None

+    #: If this task is in the "queued" state, which worker it is currently queued on.
+    #: This attribute is kept in sync with :attr:`WorkerState.queued`.
+    queued_on: WorkerState | None
+
     #: The number of times this task can automatically be retried in case of failure.
     #: If a task fails executing (the worker returns with an error), its :attr:`retries`
     #: attribute is checked. If it is equal to 0, the task is marked "erred". If it is
@@ -1065,7 +1085,7 @@ class TaskState:
     #: Cached hash of :attr:`~TaskState.client_key`
     _hash: int

-    __slots__ = tuple(__annotations__)  # type: ignore
+    __slots__ = tuple(__annotations__) + ("__weakref__",)  # type: ignore

     def __init__(self, key: str, run_spec: object):
         self.key = key
@@ -1088,6 +1108,7 @@ def __init__(self, key: str, run_spec: object):
         self.waiters = set()
         self.who_has = set()
         self.processing_on = None
+        self.queued_on = None
         self.has_lost_dependencies = False
         self.host_restrictions = None  # type: ignore
         self.worker_restrictions = None  # type: ignore
@@ -1273,6 +1294,7 @@ class SchedulerState:
         "MEMORY_REBALANCE_SENDER_MIN",
         "MEMORY_REBALANCE_RECIPIENT_MAX",
         "MEMORY_REBALANCE_HALF_GAP",
+        "WORKER_OVERSATURATION",
     }

     def __init__(
@@ -1341,6 +1363,9 @@ def __init__(
             dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap")
             / 2.0
         )
+        self.WORKER_OVERSATURATION = dask.config.get(
+            "distributed.scheduler.worker-oversaturation"
+        )
         self.transition_counter = 0
         self.transition_counter_max = transition_counter_max
@@ -1778,12 +1803,27 @@ def decide_worker(self, ts: TaskState) -> WorkerState | None:
         ):
             ws = tg.last_worker

-            if not (ws and tg.last_worker_tasks_left and ws.address in self.workers):
+            if not (
+                ws and tg.last_worker_tasks_left and self.workers.get(ws.address) is ws
+            ):
                 # Last-used worker is full or unknown; pick a new worker for the next few tasks
+
+                # We just pick the worker with the shortest queue (or, if queuing is disabled,
+                # the fewest processing tasks). We've already decided dependencies are unimportant,
+                # so we don't care to schedule near them.
+                backlog = operator.attrgetter(
+                    "processing" if math.isinf(self.WORKER_OVERSATURATION) else "queued"
+                )
                 ws = min(
-                    (self.idle or self.workers).values(),
-                    key=partial(self.worker_objective, ts),
+                    self.workers.values(), key=lambda ws: len(backlog(ws)) / ws.nthreads
                 )
+                if self.validate:
+                    assert ws is not tg.last_worker, (
+                        f"Colocation reused worker {ws} for {tg}, "
+                        f"idle: {list(self.idle.values())}, "
+                        f"workers: {list(self.workers.values())}"
+                    )
+
                 tg.last_worker_tasks_left = math.floor(
                     (len(tg) / self.total_nthreads) * ws.nthreads
                 )
@@ -1793,6 +1833,27 @@ def decide_worker(self, ts: TaskState) -> WorkerState | None:
             ws if tg.states["released"] + tg.states["waiting"] > 1 else None
         )
         tg.last_worker_tasks_left -= 1
+
+        # Queue the task if the worker is full, to avoid root task overproduction.
+        if worker_saturated(ws, self.WORKER_OVERSATURATION):
+            # TODO this should be a transition function instead.
+            # But how do we get the `ws` into it? Recommendations on the scheduler can't take arguments.
+
+            if self.validate:
+                assert not ts.queued_on, ts.queued_on
+                assert ts not in ws.queued
+
+            # TODO maintain a global queue of tasks as well, for newly arriving workers to use?
+            # QUESTION could `queued` be an OrderedSet instead of a HeapSet, giving us O(1)
+            # operations instead of O(logn)? The reasoning would be that we're always inserting
+            # elements in priority order anyway.
+            # This wouldn't work in the case that a batch of lower-priority root tasks becomes
+            # ready before a batch of higher-priority root tasks.
+            ws.queued.add(ts)
+            ts.queued_on = ws
+            ts.state = "queued"
+            return None
+
         return ws

     if ts.dependencies or valid_workers is not None:
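
[Review note, not part of the patch] On the QUESTION comment in `decide_worker` above: an `OrderedSet` would rely on tasks arriving in priority order, and the comment itself names the failure mode. A contrived sketch of that case, with hypothetical task names:

    # A low-priority batch of root tasks happens to become ready before a
    # higher-priority batch (e.g. its dependencies were already in memory).
    ready_order = [("low-3", 3), ("low-4", 4), ("hi-1", 1), ("hi-2", 2)]

    insertion_order = [name for name, _ in ready_order]  # what an OrderedSet would preserve
    heap_order = [name for name, _ in sorted(ready_order, key=lambda t: t[1])]

    assert insertion_order[0] == "low-3"  # an OrderedSet would dequeue a low-priority task first
    assert heap_order[0] == "hi-1"        # the HeapSet dequeues by priority regardless
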
@@ -1841,6 +1902,7 @@ def transition_waiting_processing(self, key, stimulus_id):
                 assert not ts.who_has
                 assert not ts.exception_blame
                 assert not ts.processing_on
+                assert not ts.queued_on
                 assert not ts.has_lost_dependencies
                 assert ts not in self.unrunnable
                 assert all(dts.who_has for dts in ts.dependencies)
@@ -1892,6 +1954,7 @@ def transition_waiting_memory(

             if self.validate:
                 assert not ts.processing_on
+                assert not ts.queued_on
                 assert ts.waiting_on
                 assert ts.state == "waiting"
@@ -1908,6 +1971,7 @@ def transition_waiting_memory(

             if self.validate:
                 assert not ts.processing_on
+                assert not ts.queued_on
                 assert not ts.waiting_on
                 assert ts.who_has
@@ -1999,7 +2063,10 @@ def transition_processing_memory(
             if nbytes is not None:
                 ts.set_nbytes(nbytes)

-            _remove_from_processing(self, ts)
+            # NOTE: recommendations for queued tasks are added first, so they'll be popped last,
+            # allowing higher-priority downstream tasks to be transitioned first.
+            # FIXME: this would be incorrect if queued tasks are user-annotated as higher priority.
+            _remove_from_processing(self, ts, recommendations)

             _add_to_memory(
                 self, ts, ws, recommendations, client_msgs, type=type, typename=typename
@@ -2007,7 +2074,18 @@ def transition_processing_memory(

             if self.validate:
                 assert not ts.processing_on
+                assert not ts.queued_on
                 assert not ts.waiting_on
+                processing_recs = {
+                    k: r for k, r in recommendations.items() if r == "processing"
+                }
+                assert list(processing_recs) == (
+                    sr := sorted(
+                        processing_recs,
+                        key=lambda k: self.tasks[k].priority,
+                        reverse=True,
+                    )
+                ), (list(processing_recs), sr)

             return recommendations, client_msgs, worker_msgs
         except Exception as e:
@@ -2030,6 +2108,7 @@ def transition_memory_released(self, key, stimulus_id, safe: bool = False):
             if self.validate:
                 assert not ts.waiting_on
                 assert not ts.processing_on
+                assert not ts.queued_on
                 if safe:
                     assert not ts.waiters
@@ -2191,6 +2270,7 @@ def transition_waiting_released(self, key, stimulus_id):
             if self.validate:
                 assert not ts.who_has
                 assert not ts.processing_on
+                assert not ts.queued_on

             dts: TaskState
             for dts in ts.dependencies:
@@ -2232,9 +2312,9 @@ def transition_processing_released(self, key, stimulus_id):
                 assert not ts.waiting_on
                 assert self.tasks[key].state == "processing"

-            w: str = _remove_from_processing(self, ts)
-            if w:
-                worker_msgs[w] = [
+            ws = _remove_from_processing(self, ts, recommendations)
+            if ws:
+                worker_msgs[ws] = [
                     {
                         "op": "free-keys",
                         "keys": [key],
@@ -2259,6 +2339,7 @@ def transition_processing_released(self, key, stimulus_id):

             if self.validate:
                 assert not ts.processing_on
+                assert not ts.queued_on

             return recommendations, client_msgs, worker_msgs
         except Exception as e:
@@ -2301,7 +2382,7 @@ def transition_processing_erred(
                 ws = ts.processing_on
                 ws.actors.remove(ts)

-            w = _remove_from_processing(self, ts)
+            w = _remove_from_processing(self, ts, recommendations)

             ts.erred_on.add(w or worker)  # type: ignore
             if exception is not None:
@@ -2350,6 +2431,7 @@ def transition_processing_erred(

             if self.validate:
                 assert not ts.processing_on
+                assert not ts.queued_on

             return recommendations, client_msgs, worker_msgs
         except Exception as e:
@@ -2390,6 +2472,108 @@ def transition_no_worker_released(self, key, stimulus_id):
                 pdb.set_trace()
             raise
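
[Review note, not part of the patch] The "popped last" NOTE in `transition_processing_memory` depends on the scheduler draining recommendations with `dict.popitem()`, which is LIFO for Python dicts. A minimal illustration of the intended effect, with hypothetical keys:

    recs: dict[str, str] = {}
    recs["queued-root-task"] = "processing"  # queued backfill, recommended first
    recs["downstream-task"] = "processing"   # higher-priority downstream, recommended after

    key, _ = recs.popitem()
    assert key == "downstream-task"  # handled first, despite being recommended last
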
+    def transition_queued_released(self, key, stimulus_id):
+        try:
+            ts: TaskState = self.tasks[key]
+            recommendations: dict = {}
+            client_msgs: dict = {}
+            worker_msgs: dict = {}
+
+            # TODO allow `remove_worker` to clear `queued_on` and `ws.queued` eagerly; it's more efficient.
+            ws = ts.queued_on
+            assert ws
+
+            if self.validate:
+                assert ts in ws.queued
+                assert not ts.processing_on
+
+            ws.queued.remove(ts)
+            ts.queued_on = None
+
+            # TODO copied from `transition_processing_released`; factor out into helper function
+            ts.state = "released"
+
+            if ts.has_lost_dependencies:
+                recommendations[key] = "forgotten"
+            elif ts.waiters or ts.who_wants:
+                # TODO rescheduling of queued root tasks may be poor.
+                recommendations[key] = "waiting"
+
+            if recommendations.get(key) != "waiting":
+                for dts in ts.dependencies:
+                    if dts.state != "released":
+                        dts.waiters.discard(ts)
+                        if not dts.waiters and not dts.who_wants:
+                            recommendations[dts.key] = "released"
+                ts.waiters.clear()
+
+            return recommendations, client_msgs, worker_msgs
+        except Exception as e:
+            logger.exception(e)
+            if LOG_PDB:
+                import pdb
+
+                pdb.set_trace()
+            raise
+
+    def transition_queued_processing(self, key, stimulus_id):
+        try:
+            ts: TaskState = self.tasks[key]
+            recommendations: dict = {}
+            client_msgs: dict = {}
+            worker_msgs: dict = {}
+
+            ws = ts.queued_on
+            assert ws
+            # TODO should this be a graceful transition to released? I think `remove_worker`
+            # makes it such that this should never happen.
+            assert (
+                self.workers[ws.address] is ws
+            ), f"Task {ts} queued on stale worker {ws}"
+
+            if self.validate:
+                assert not ts.actor, "Actors can't be queued"
+                assert ts in ws.queued
+                # Copied from `transition_waiting_processing`
+                assert not ts.processing_on
+                assert not ts.waiting_on
+                assert not ts.who_has
+                assert not ts.exception_blame
+                assert not ts.has_lost_dependencies
+                assert ts not in self.unrunnable
+                assert all(dts.who_has for dts in ts.dependencies)
+                # TODO other validation that this is still an appropriate worker?
+
+            # If more important tasks have filled the worker since this task was queued,
+            # it simply stays queued; otherwise, start processing it.
+            if not worker_saturated(ws, self.WORKER_OVERSATURATION):
+                ts.queued_on = None
+                ws.queued.remove(ts)
+                # TODO copied from `transition_waiting_processing`; factor out into helper function
+                self._set_duration_estimate(ts, ws)
+                ts.processing_on = ws
+                ts.state = "processing"
+                self.consume_resources(ts, ws)
+                self.check_idle_saturated(ws)
+                self.n_tasks += 1
+
+                if ts.actor:
+                    ws.actors.add(ts)
+
+                # logger.debug("Send job to worker: %s, %s", worker, key)
+
+                worker_msgs[ws.address] = [_task_to_msg(self, ts)]
+
+            return recommendations, client_msgs, worker_msgs
+        except Exception as e:
+            logger.exception(e)
+            if LOG_PDB:
+                import pdb
+
+                pdb.set_trace()
+            raise
+
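[Review note, not part of the patch] Taken together with the transition-table entries below, the new lifecycle is: a root task enters "queued" directly inside `decide_worker` (see the TODO there about making that a proper transition), and leaves via exactly two registered edges:

    waiting --(decide_worker, worker full)--> queued --> processing --> memory
                                                    \--> released  (cancelled / worker removed)
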
     def _remove_key(self, key):
         ts: TaskState = self.tasks.pop(key)
         assert ts.state == "forgotten"
@@ -2413,6 +2597,7 @@ def transition_memory_forgotten(self, key, stimulus_id):
             if self.validate:
                 assert ts.state == "memory"
                 assert not ts.processing_on
+                assert not ts.queued_on
                 assert not ts.waiting_on
                 if not ts.run_spec:
                     # It's ok to forget a pure data task
@@ -2455,6 +2640,7 @@ def transition_released_forgotten(self, key, stimulus_id):
                 assert ts.state in ("released", "erred")
                 assert not ts.who_has
                 assert not ts.processing_on
+                assert not ts.queued_on
                 assert not ts.waiting_on, (ts, ts.waiting_on)
                 if not ts.run_spec:
                     # It's ok to forget a pure data task
@@ -2495,6 +2681,8 @@ def transition_released_forgotten(self, key, stimulus_id):
         ("waiting", "released"): transition_waiting_released,
         ("waiting", "processing"): transition_waiting_processing,
         ("waiting", "memory"): transition_waiting_memory,
+        ("queued", "released"): transition_queued_released,
+        ("queued", "processing"): transition_queued_processing,
         ("processing", "released"): transition_processing_released,
         ("processing", "memory"): transition_processing_memory,
         ("processing", "erred"): transition_processing_erred,
@@ -2765,7 +2953,71 @@ def bulk_schedule_after_adding_worker(self, ws: WorkerState):
         ordering, so the recommendations are sorted by priority order here.
         """
         ts: TaskState
-        tasks = []
+        recommendations: dict[str, str] = {}
+
+        # Redistribute the tasks between all worker queues. We bubble tasks off the back of
+        # the most-queued worker onto the front of the least-queued, and repeat this until
+        # we've accumulated enough tasks to put onto the new worker. This maintains the
+        # co-assignment of each worker's queue, minimizing the fragmentation of neighboring tasks.
+        # Note this does not rebalance all workers. It just rebalances the busiest workers,
+        # stealing just enough tasks to fill up the new worker.
+        # NOTE: this is probably going to be pretty slow for lots of queued tasks and/or
+        # lots of workers. Also unclear if this is even a good load-balancing strategy.
+        # TODO this is optimized for the add-worker case. Generalize for remove-worker as well.
+        # That would probably look like rebalancing all workers though.
+        if not math.isinf(self.WORKER_OVERSATURATION):
+            workers_with_queues: list[WorkerState] = sorted(
+                (wss for wss in self.workers.values() if wss.queued and wss is not ws),
+                key=lambda wss: len(wss.queued),
+                reverse=True,
+            )
+            if workers_with_queues:
+                total_queued = sum(len(wss.queued) for wss in workers_with_queues)
+                target_qsize = int(total_queued / len(self.workers))
+                moveable_tasks_so_far = 0
+                last_q_tasks_to_move = 0
+                i = 0
+                # Go through workers with the largest queues until we've found enough workers to steal from
+                for i, wss in enumerate(workers_with_queues):
+                    n_extra_tasks = len(wss.queued) - target_qsize
+                    if n_extra_tasks <= 0:
+                        break
+                    moveable_tasks_so_far += n_extra_tasks
+                    if moveable_tasks_so_far >= target_qsize:
+                        last_q_tasks_to_move = n_extra_tasks - (
+                            moveable_tasks_so_far - target_qsize
+                        )
+                        break
+                if last_q_tasks_to_move:
+                    # Starting from the smallest, bubble tasks off the back of the queue and
+                    # onto the front of the next-largest. At the end, bubble tasks onto the
+                    # new worker's queue.
+                    while i >= 0:
+                        src = workers_with_queues[i]
+                        dest = workers_with_queues[i - 1] if i > 0 else ws
+                        for _ in range(last_q_tasks_to_move):
+                            # NOTE: `popright` is not exactly the largest element, but sorting
+                            # would be too expensive. It's good enough, and in the common case
+                            # the heap is sorted anyway (because elements are inserted in
+                            # sorted order by `decide_worker`).
+                            ts = src.queued.popright()
+                            ts.queued_on = dest
+                            dest.queued.add(ts)
+
+                        i -= 1
+                        last_q_tasks_to_move = target_qsize
+
+        if (
+            ws.queued
+            and (n := task_slots_available(ws, self.WORKER_OVERSATURATION)) > 0
+        ):
+            # NOTE: reverse priority order, since recommendations are processed in LIFO order
+            for ts in reversed(list(ws.queued.topk(n))):
+                if self.validate:
+                    assert ts.state == "queued"
+                    assert ts.queued_on is ws, (ts.queued_on, ws)
+                    assert ts.key not in recommendations, recommendations[ts.key]
+                recommendations[ts.key] = "processing"
+
+        tasks: list[TaskState] = []
         for ts in self.unrunnable:
             valid: set = self.valid_workers(ts)
             if valid is None or ws in valid:
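
[Review note, not part of the patch] The donor-selection arithmetic in the rebalancing loop above, reduced to plain integers as a sanity check (hypothetical queue sizes; not code from the patch):

    queue_sizes = [30, 20, 10]  # existing workers' queues, largest first
    n_workers = 4               # including the newly added, empty worker
    target = sum(queue_sizes) // n_workers  # ideal queue size per worker: 15

    moveable, donors = 0, []
    for size in queue_sizes:
        extra = size - target
        if extra <= 0:
            break  # remaining queues are already at or below target
        donors.append(size)
        moveable += extra
        if moveable >= target:
            break  # enough surplus found to fill the new worker

    assert target == 15
    assert donors == [30]  # here the largest queue alone can donate all 15 tasks
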
@@ -4253,6 +4505,10 @@ async def remove_worker(
                 else:  # pure data
                     recommendations[ts.key] = "forgotten"

+        for ts in ws.queued.sorted():
+            recommendations[ts.key] = "released"
+        # ws.queued.clear()  # TODO more performant
+
         self.transitions(recommendations, stimulus_id=stimulus_id)

         for plugin in list(self.plugins.values()):
@@ -4371,6 +4627,7 @@ def validate_released(self, key):
         assert not ts.waiting_on
         assert not ts.who_has
         assert not ts.processing_on
+        assert not ts.queued_on
         assert not any([ts in dts.waiters for dts in ts.dependencies])
         assert ts not in self.unrunnable
@@ -4379,12 +4636,27 @@ def validate_waiting(self, key):
         assert ts.waiting_on
         assert not ts.who_has
         assert not ts.processing_on
+        assert not ts.queued_on
         assert ts not in self.unrunnable
         for dts in ts.dependencies:
             # We are waiting on a dependency iff it's not stored
             assert bool(dts.who_has) != (dts in ts.waiting_on)
             assert ts in dts.waiters  # XXX even if dts._who_has?
+    def validate_queued(self, key):
+        ts: TaskState = self.tasks[key]
+        dts: TaskState
+        assert not ts.waiting_on
+        ws = ts.queued_on
+        assert ws
+        assert self.workers.get(ws.address) is ws, f"{ts} queued on stale worker {ws}"
+        assert ts in ws.queued
+        assert not ts.who_has
+        assert not ts.processing_on
+        for dts in ts.dependencies:
+            assert dts.who_has
+            assert ts in dts.waiters
+
     def validate_processing(self, key):
         ts: TaskState = self.tasks[key]
         dts: TaskState
@@ -4403,6 +4675,7 @@ def validate_memory(self, key):
         assert ts.who_has
         assert bool(ts in self.replicated_tasks) == (len(ts.who_has) > 1)
         assert not ts.processing_on
+        assert not ts.queued_on
         assert not ts.waiting_on
         assert ts not in self.unrunnable
         for dts in ts.dependents:
@@ -4417,6 +4690,7 @@ def validate_no_worker(self, key):
         assert not ts.waiting_on
         assert ts in self.unrunnable
         assert not ts.processing_on
+        assert not ts.queued_on
         assert not ts.who_has
         for dts in ts.dependencies:
             assert dts.who_has
@@ -7131,7 +7405,9 @@ def request_remove_replicas(
     )


-def _remove_from_processing(state: SchedulerState, ts: TaskState) -> str | None:
+def _remove_from_processing(
+    state: SchedulerState, ts: TaskState, recommendations: dict
+) -> str | None:
     """Remove *ts* from the set of processing tasks.

     See also
@@ -7157,6 +7433,19 @@ def _remove_from_processing(state: SchedulerState, ts: TaskState) -> str | None:
     state.check_idle_saturated(ws)
     state.release_resources(ts, ws)

+    # If a slot has opened up for a queued task, schedule it.
+    if ws.queued and not worker_saturated(ws, state.WORKER_OVERSATURATION):
+        # TODO peek or pop?
+        # What if multiple tasks complete on a worker in one transition cycle? Is that possible?
+        # TODO should we only be scheduling 1 task? Or N open threads? Is there a possible
+        # deadlock where tasks remain queued on a worker forever?
+        qts = ws.queued.peek()
+        if state.validate:
+            assert qts.state == "queued"
+            assert qts.queued_on is ws, (qts.queued_on, ws)
+            assert qts.key not in recommendations, recommendations[qts.key]
+        recommendations[qts.key] = "processing"
+
     return ws.address
@@ -7434,8 +7723,18 @@ def validate_task_state(ts: TaskState) -> None:
         assert dts.state != "forgotten"

     assert (ts.processing_on is not None) == (ts.state == "processing")
+    assert not (ts.processing_on and ts.queued_on), (ts.processing_on, ts.queued_on)
     assert bool(ts.who_has) == (ts.state == "memory"), (ts, ts.who_has, ts.state)

+    if ts.queued_on:
+        assert ts.state == "queued"
+        assert ts in ts.queued_on.queued
+
+    if ts.state == "queued":
+        assert ts.queued_on
+        assert not ts.processing_on
+        assert not ts.who_has
+
     if ts.state == "processing":
         assert all(dts.who_has for dts in ts.dependencies), (
             "task processing without all deps",
@@ -7443,6 +7742,7 @@ def validate_task_state(ts: TaskState) -> None:
             str(ts.dependencies),
         )
         assert not ts.waiting_on
+        assert not ts.queued_on

     if ts.who_has:
         assert ts.waiters or ts.who_wants, (
@@ -7530,6 +7830,19 @@ def heartbeat_interval(n: int) -> float:
         return n / 200 + 1


+def task_slots_available(ws: WorkerState, oversaturation_factor: float) -> int:
+    "Number of tasks that can be sent to this worker without oversaturating it"
+    assert not math.isinf(oversaturation_factor)
+    nthreads = ws.nthreads
+    return max(nthreads + int(oversaturation_factor * nthreads), 1) - len(ws.processing)
+
+
+def worker_saturated(ws: WorkerState, oversaturation_factor: float) -> bool:
+    "Whether the worker already holds as many tasks as it should be sent"
+    if math.isinf(oversaturation_factor):
+        return False
+    return task_slots_available(ws, oversaturation_factor) <= 0
+
+
 class KilledWorker(Exception):
     def __init__(self, task: str, last_worker: WorkerState):
         super().__init__(task, last_worker)
+ """ + + @delayed(pure=True) + def big_data(size: int) -> str: + return "x" * size + + roots = [ + big_data(parse_bytes("350 MiB"), dask_key_name=f"root-{i}") for i in range(16) + ] + passthrough = [delayed(slowidentity)(x) for x in roots] + memory_consumed = [delayed(len)(x) for x in passthrough] + reduction = [sum(sizes) for sizes in partition(4, memory_consumed)] + final = sum(reduction) + + await c.compute(final) + + +@pytest.mark.parametrize( + "oversaturation, expected_task_counts", + [ + (1.5, (5, 2)), + (1.0, (4, 2)), + (0.0, (2, 1)), + (-1.0, (1, 1)), + (float("inf"), (7, 3)) + # ^ depends on root task assignment logic; ok if changes, just needs to add up to 10 + ], +) +def test_oversaturation_factor(oversaturation, expected_task_counts: tuple[int, int]): + @gen_cluster( + client=True, + nthreads=[("", 2), ("", 1)], + config={ + "distributed.scheduler.worker-oversaturation": oversaturation, + }, + ) + async def _test_oversaturation_factor(c, s, a, b): + event = Event() + fs = c.map(lambda _: event.wait(), range(10)) + while a.state.executing_count < min( + a.nthreads, expected_task_counts[0] + ) or b.state.executing_count < min(b.nthreads, expected_task_counts[1]): + await asyncio.sleep(0.01) + + assert len(a.state.tasks) == expected_task_counts[0] + assert len(b.state.tasks) == expected_task_counts[1] + + await event.set() + await c.gather(fs) + + _test_oversaturation_factor() + + +@pytest.mark.parametrize( + "saturation_factor", + [ + 0.0, + 1.0, + pytest.param( + float("inf"), + marks=pytest.mark.skip("https://github.com/dask/distributed/issues/6597"), + ), + ], +) +@gen_cluster( + client=True, + nthreads=[("", 2), ("", 1)], +) +async def test_oversaturation_multiple_task_groups(c, s, a, b, saturation_factor): + s.WORKER_OVERSATURATION = saturation_factor + xs = [delayed(i, name=f"x-{i}") for i in range(9)] + ys = [delayed(i, name=f"y-{i}") for i in range(9)] + zs = [x + y for x, y in zip(xs, ys)] + + await c.gather(c.compute(zs)) + + assert not a.incoming_transfer_log, [l["keys"] for l in a.incoming_transfer_log] + assert not b.incoming_transfer_log, [l["keys"] for l in b.incoming_transfer_log] + assert len(a.tasks) == 18 + assert len(b.tasks) == 9 + + +@gen_cluster( + client=True, + nthreads=[("", 2)] * 2, + timeout=3600, # TODO remove + scheduler_kwargs=dict( # TODO remove + dashboard=True, + dashboard_address=":8787", + ), +) +async def test_queued_tasks_rebalance(c, s, a, b): + event = Event() + fs = c.map(lambda _: event.wait(), range(100)) + await c.gather(fs) + + @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) async def test_move_data_over_break_restrictions(client, s, a, b, c): [x] = await client.scatter([1], workers=b.address) diff --git a/distributed/widgets/templates/worker_state.html.j2 b/distributed/widgets/templates/worker_state.html.j2 index cd152080bfc..08629998e1f 100644 --- a/distributed/widgets/templates/worker_state.html.j2 +++ b/distributed/widgets/templates/worker_state.html.j2 @@ -3,3 +3,4 @@ status: {{ status }} memory: {{ has_what | length }} processing: {{ processing | length }} + queued: {{ queued | length }}