diff --git a/distributed/collections.py b/distributed/collections.py
index c25001047f6..b79e0cd6c08 100644
--- a/distributed/collections.py
+++ b/distributed/collections.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import heapq
+import itertools
import weakref
from collections import OrderedDict, UserDict
from collections.abc import Callable, Hashable, Iterator
@@ -56,6 +57,9 @@ def __repr__(self) -> str:
def __contains__(self, value: object) -> bool:
return value in self._data
+ def __bool__(self) -> bool:
+ return bool(self._data)
+
def __len__(self) -> int:
return len(self._data)
@@ -93,6 +97,41 @@ def pop(self) -> T:
self._data.discard(value)
return value
+ def popright(self) -> T:
+ "Remove and return one of the largest elements (not necessarily the largest)!"
+ if not self._data:
+ raise KeyError("popright from an empty set")
+ while True:
+ _, _, vref = self._heap.pop()
+ value = vref()
+ if value is not None and value in self._data:
+ self._data.discard(value)
+ return value
+
+ def topk(self, k: int) -> Iterator[T]:
+ "Iterator over the k smallest elements (the top of the heap), from smallest to largest according to the key and insertion order. This is O(k*logn) for k < n // 2, O(n*logn) otherwise."
+ # TODO confirm big-O values here
+ k = min(k, len(self))
+ if k == 1:
+ yield self.peek()
+ elif k >= len(self) // 2:
+ yield from itertools.islice(self.sorted(), k)
+ else:
+ # FIXME though neat, with all the list mutation this is probably always slower than sorting inplace.
+ elems: list[tuple[Any, int, weakref.ref[T]]] = []
+ try:
+ while len(elems) < k:
+ elem = heapq.heappop(self._heap)
+ value = elem[-1]()
+ if value is not None and value in self._data:
+ # NOTE: we're in a broken state during iteration, since the value exists
+ # in the set but not the heap. As with all Python iterators, mutating
+ # while iterating is undefined.
+ elems.append(elem)
+ yield value
+ finally:
+ # Prepending the k smallest entries to what's left of the heap does not,
+ # in general, restore the heap invariant, so re-heapify before further use.
+ self._heap = elems + self._heap
+ heapq.heapify(self._heap)
+
def __iter__(self) -> Iterator[T]:
"""Iterate over all elements. This is a O(n) operation which returns the
elements in pseudo-random order.
@@ -104,7 +143,8 @@ def sorted(self) -> Iterator[T]:
elements in order, from smallest to largest according to the key and insertion
order.
"""
- for _, _, vref in sorted(self._heap):
+ self._heap.sort() # A sorted list maintains the heap invariant
+ for _, _, vref in self._heap:
value = vref()
if value in self._data:
yield value
diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py
index 8e0bf453be9..c1935d12d57 100644
--- a/distributed/dashboard/components/scheduler.py
+++ b/distributed/dashboard/components/scheduler.py
@@ -2091,8 +2091,8 @@ def __init__(self, scheduler, **kwargs):
node_colors = factor_cmap(
"state",
- factors=["waiting", "processing", "memory", "released", "erred"],
- palette=["gray", "green", "red", "blue", "black"],
+ factors=["waiting", "queued", "processing", "memory", "released", "erred"],
+ palette=["gray", "yellow", "green", "red", "blue", "black"],
)
self.root = figure(title="Task Graph", **kwargs)
@@ -2980,7 +2980,7 @@ def __init__(self, scheduler, **kwargs):
self.scheduler = scheduler
data = progress_quads(
- dict(all={}, memory={}, erred={}, released={}, processing={})
+ dict(all={}, memory={}, erred={}, released={}, processing={}, queued={})
)
self.source = ColumnDataSource(data=data)
@@ -3052,6 +3052,18 @@ def __init__(self, scheduler, **kwargs):
fill_alpha=0.35,
line_alpha=0,
)
+ self.root.quad(
+ source=self.source,
+ top="top",
+ bottom="bottom",
+ left="processing-loc",
+ right="queued-loc",
+ fill_color="gray",
+ hatch_pattern="/",
+ hatch_color="white",
+ fill_alpha=0.35,
+ line_alpha=0,
+ )
self.root.text(
source=self.source,
text="show-name",
@@ -3087,6 +3099,14 @@ def __init__(self, scheduler, **kwargs):
All:
@all
+
+ Queued:
+ @queued
+
+
+ Processing:
+ @processing
+
Memory:
@memory
@@ -3095,10 +3115,6 @@ def __init__(self, scheduler, **kwargs):
Erred:
@erred
-
- Ready:
- @processing
-
""",
)
self.root.add_tools(hover)
@@ -3112,6 +3128,7 @@ def update(self):
"released": {},
"processing": {},
"waiting": {},
+ "queued": {},
}
for tp in self.scheduler.task_prefixes.values():
@@ -3122,6 +3139,7 @@ def update(self):
state["released"][tp.name] = active_states["released"]
state["processing"][tp.name] = active_states["processing"]
state["waiting"][tp.name] = active_states["waiting"]
+ state["queued"][tp.name] = active_states["queued"]
state["all"] = {k: sum(v[k] for v in state.values()) for k in state["memory"]}
@@ -3134,7 +3152,7 @@ def update(self):
totals = {
k: sum(state[k].values())
- for k in ["all", "memory", "erred", "released", "waiting"]
+ for k in ["all", "memory", "erred", "released", "waiting", "queued"]
}
totals["processing"] = totals["all"] - sum(
v for k, v in totals.items() if k != "all"
@@ -3142,8 +3160,10 @@ def update(self):
self.root.title.text = (
"Progress -- total: %(all)s, "
- "in-memory: %(memory)s, processing: %(processing)s, "
"waiting: %(waiting)s, "
+ "queued: %(queued)s, "
+ "processing: %(processing)s, "
+ "in-memory: %(memory)s, "
"erred: %(erred)s" % totals
)
diff --git a/distributed/diagnostics/progress_stream.py b/distributed/diagnostics/progress_stream.py
index 45f84b1c2cd..04aa7c7c596 100644
--- a/distributed/diagnostics/progress_stream.py
+++ b/distributed/diagnostics/progress_stream.py
@@ -64,23 +64,29 @@ def progress_quads(msg, nrows=8, ncols=3):
... 'memory': {'inc': 2, 'dec': 0, 'add': 1},
... 'erred': {'inc': 0, 'dec': 1, 'add': 0},
... 'released': {'inc': 1, 'dec': 0, 'add': 1},
- ... 'processing': {'inc': 1, 'dec': 0, 'add': 2}}
+ ... 'processing': {'inc': 1, 'dec': 0, 'add': 2},
+ ... 'queued': {'inc': 1, 'dec': 0, 'add': 2}}
>>> progress_quads(msg, nrows=2) # doctest: +SKIP
- {'name': ['inc', 'add', 'dec'],
- 'left': [0, 0, 1],
- 'right': [0.9, 0.9, 1.9],
- 'top': [0, -1, 0],
- 'bottom': [-.8, -1.8, -.8],
- 'released': [1, 1, 0],
- 'memory': [2, 1, 0],
- 'erred': [0, 0, 1],
- 'processing': [1, 0, 2],
- 'done': ['3 / 5', '2 / 4', '1 / 1'],
- 'released-loc': [.2/.9, .25 / 0.9, 1],
- 'memory-loc': [3 / 5 / .9, .5 / 0.9, 1],
- 'erred-loc': [3 / 5 / .9, .5 / 0.9, 1.9],
- 'processing-loc': [4 / 5, 1 / 1, 1]}}
+ {'all': [5, 4, 1],
+ 'memory': [2, 1, 0],
+ 'erred': [0, 0, 1],
+ 'released': [1, 1, 0],
+ 'processing': [1, 2, 0],
+ 'queued': [1, 2, 0],
+ 'name': ['inc', 'add', 'dec'],
+ 'show-name': ['inc', 'add', 'dec'],
+ 'left': [0, 0, 1],
+ 'right': [0.9, 0.9, 1.9],
+ 'top': [0, -1, 0],
+ 'bottom': [-0.8, -1.8, -0.8],
+ 'color': ['#45BF6F', '#2E6C8E', '#440154'],
+ 'released-loc': [0.18, 0.225, 1.0],
+ 'memory-loc': [0.54, 0.45, 1.0],
+ 'erred-loc': [0.54, 0.45, 1.9],
+ 'processing-loc': [0.72, 0.9, 1.9],
+ 'queued-loc': [0.9, 1.35, 1.9],
+ 'done': ['3 / 5', '2 / 4', '1 / 1']}
"""
width = 0.9
names = sorted(msg["all"], key=msg["all"].get, reverse=True)
@@ -100,19 +106,28 @@ def progress_quads(msg, nrows=8, ncols=3):
d["memory-loc"] = []
d["erred-loc"] = []
d["processing-loc"] = []
+ d["queued-loc"] = []
d["done"] = []
- for r, m, e, p, a, l in zip(
- d["released"], d["memory"], d["erred"], d["processing"], d["all"], d["left"]
+ for r, m, e, p, q, a, l in zip(
+ d["released"],
+ d["memory"],
+ d["erred"],
+ d["processing"],
+ d["queued"],
+ d["all"],
+ d["left"],
):
rl = width * r / a + l
ml = width * (r + m) / a + l
el = width * (r + m + e) / a + l
pl = width * (p + r + m + e) / a + l
+ ql = width * (p + r + m + e + q) / a + l
done = "%d / %d" % (r + m + e, a)
d["released-loc"].append(rl)
d["memory-loc"].append(ml)
d["erred-loc"].append(el)
d["processing-loc"].append(pl)
+ d["queued-loc"].append(ql)
d["done"].append(done)
return d
diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml
index 8d73b7df145..aa42c08972a 100644
--- a/distributed/distributed-schema.yaml
+++ b/distributed/distributed-schema.yaml
@@ -117,6 +117,26 @@ properties:
description: |
How frequently to balance worker loads
+ worker-oversaturation:
+ type: float
+ description: |
+ Controls how many extra root tasks are sent to workers (like a `readahead`).
+
+ `floor(worker-oversaturation * worker.nthreads)` _extra_ tasks are sent to the worker
+ beyond its thread count. If `.inf`, all runnable tasks are immediately sent to workers.
+
+ Allowing oversaturation means a worker will start running a new root task as soon as
+ it completes the previous, even if there is a higher-priority downstream task to run.
+ This reduces worker idleness, by letting workers do something while waiting for further
+ instructions from the scheduler.
+
+ This generally comes at the expense of increased memory usage. It leads to "wider"
+ (more breadth-first) execution of the graph.
+
+ Compute-bound workloads benefit from oversaturation. Memory-bound workloads should
+ generally leave `worker-oversaturation` at 0, though 0.25-0.5 could slightly improve
+ performance if ample memory is available.
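+
+ For example, with `worker-oversaturation: 0.5` on a 4-thread worker, up to
+ `4 + floor(0.5 * 4) = 6` tasks are sent to the worker before further root tasks
+ are held back in the scheduler-side queue.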
+
worker-ttl:
type:
- string
diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml
index 3cf75a23298..a37e1c34add 100644
--- a/distributed/distributed.yaml
+++ b/distributed/distributed.yaml
@@ -22,6 +22,7 @@ distributed:
events-log-length: 100000
work-stealing: True # workers should steal tasks from each other
work-stealing-interval: 100ms # Callback time for work stealing
+ worker-oversaturation: 0.0 # Send this fraction of nthreads extra root tasks to workers
worker-ttl: "5 minutes" # like '60s'. Time to live for workers. They must heartbeat faster than this
pickle: True # Is the scheduler allowed to deserialize arbitrary bytestrings
preload: [] # Run custom modules with Scheduler
diff --git a/distributed/http/templates/worker-table.html b/distributed/http/templates/worker-table.html
index 87512ee3860..765d97133d7 100644
--- a/distributed/http/templates/worker-table.html
+++ b/distributed/http/templates/worker-table.html
@@ -6,6 +6,7 @@
<th>Memory</th>
<th>Memory use</th>
<th>Occupancy</th>
+ <th>Queued</th>
<th>Processing</th>
<th>In-memory</th>
<th>Services</th>
@@ -20,6 +21,7 @@
<td>{{ format_bytes(ws.memory_limit) if ws.memory_limit is not None else "" }}</td>
<td></td>
<td>{{ format_time(ws.occupancy) }}</td>
+ <td>{{ len(ws.queued) }}</td>
<td>{{ len(ws.processing) }}</td>
<td>{{ len(ws.has_what) }}</td>
{% if 'dashboard' in ws.services %}
diff --git a/distributed/http/templates/worker.html b/distributed/http/templates/worker.html
index 9c7608cb8c2..f5795248200 100644
--- a/distributed/http/templates/worker.html
+++ b/distributed/http/templates/worker.html
@@ -41,6 +41,21 @@ Processing
{% end %}
+
+ <h2>Queued</h2>
+ <table>
+ <tr>
+ <th>Task</th>
+ <th>Priority</th>
+ </tr>
+ {% for ts in ws.queued.sorted() %}
+ <tr>
+ <td>{{ts.key}}</td>
+ <td>{{ts.priority }}</td>
+ </tr>
+ {% end %}
+ </table>
+
+
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index d5dccbc473e..ec9496db617 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -58,6 +58,7 @@
from distributed._stories import scheduler_story
from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
from distributed.batched import BatchedSend
+from distributed.collections import HeapSet
from distributed.comm import (
Comm,
CommClosedError,
@@ -129,7 +130,15 @@
"stealing": WorkStealing,
}
-ALL_TASK_STATES = {"released", "waiting", "no-worker", "processing", "erred", "memory"}
+ALL_TASK_STATES = {
+ "released",
+ "waiting",
+ "no-worker",
+ "queued",
+ "processing",
+ "erred",
+ "memory",
+}
class ClientState:
@@ -440,6 +449,10 @@ class WorkerState:
#: been running.
executing: dict[TaskState, float]
+ #: Tasks queued to _potentially_ run on this worker in the future, ordered by priority.
+ #: The queuing is scheduler-side only; the worker is unaware of these tasks.
+ queued: HeapSet[TaskState]
+
#: The available resources on this worker, e.g. ``{"GPU": 2}``.
#: These are abstract quantities that constrain certain tasks from running at the
#: same time on this worker.
@@ -494,6 +507,7 @@ def __init__(
self.processing = {}
self.long_running = set()
self.executing = {}
+ self.queued = HeapSet(key=operator.attrgetter("priority"))
self.resources = {}
self.used_resources = {}
self.extra = extra or {}
@@ -563,7 +577,8 @@ def __repr__(self) -> str:
f""
+ f"processing: {len(self.processing)}, "
+ f"queued: {len(self.queued)}>"
)
def _repr_html_(self):
@@ -573,6 +588,7 @@ def _repr_html_(self):
status=self.status.name,
has_what=self.has_what,
processing=self.processing,
+ queued=self.queued,
)
def identity(self) -> dict[str, Any]:
@@ -972,6 +988,10 @@ class TaskState:
#: it. This attribute is kept in sync with :attr:`WorkerState.processing`.
processing_on: WorkerState | None
+ #: If this task is in the "queued" state, the worker it is currently queued on.
+ #: This attribute is kept in sync with :attr:`WorkerState.queued`.
+ queued_on: WorkerState | None
+
#: The number of times this task can automatically be retried in case of failure.
#: If a task fails executing (the worker returns with an error), its :attr:`retries`
#: attribute is checked. If it is equal to 0, the task is marked "erred". If it is
@@ -1065,7 +1085,7 @@ class TaskState:
#: Cached hash of :attr:`~TaskState.client_key`
_hash: int
- __slots__ = tuple(__annotations__) # type: ignore
+ __slots__ = tuple(__annotations__) + ("__weakref__",) # type: ignore
def __init__(self, key: str, run_spec: object):
self.key = key
@@ -1088,6 +1108,7 @@ def __init__(self, key: str, run_spec: object):
self.waiters = set()
self.who_has = set()
self.processing_on = None
+ self.queued_on = None
self.has_lost_dependencies = False
self.host_restrictions = None # type: ignore
self.worker_restrictions = None # type: ignore
@@ -1273,6 +1294,7 @@ class SchedulerState:
"MEMORY_REBALANCE_SENDER_MIN",
"MEMORY_REBALANCE_RECIPIENT_MAX",
"MEMORY_REBALANCE_HALF_GAP",
+ "WORKER_OVERSATURATION",
}
def __init__(
@@ -1341,6 +1363,9 @@ def __init__(
dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap")
/ 2.0
)
+ self.WORKER_OVERSATURATION = dask.config.get(
+ "distributed.scheduler.worker-oversaturation"
+ )
self.transition_counter = 0
self.transition_counter_max = transition_counter_max
@@ -1778,12 +1803,27 @@ def decide_worker(self, ts: TaskState) -> WorkerState | None:
):
ws = tg.last_worker
- if not (ws and tg.last_worker_tasks_left and ws.address in self.workers):
+ if not (
+ ws and tg.last_worker_tasks_left and self.workers.get(ws.address) is ws
+ ):
# Last-used worker is full or unknown; pick a new worker for the next few tasks
+
+ # We just pick the worker with the shortest queue (or if queuing is disabled,
+ # the fewest processing tasks). We've already decided dependencies are unimportant,
+ # so we don't care to schedule near them.
+ backlog = operator.attrgetter(
+ "processing" if math.isinf(self.WORKER_OVERSATURATION) else "queued"
+ )
ws = min(
- (self.idle or self.workers).values(),
- key=partial(self.worker_objective, ts),
+ self.workers.values(), key=lambda ws: len(backlog(ws)) / ws.nthreads
)
+ if self.validate:
+ assert ws is not tg.last_worker, (
+ f"Colocation reused worker {ws} for {tg}, "
+ f"idle: {list(self.idle.values())}, "
+ f"workers: {list(self.workers.values())}"
+ )
+
tg.last_worker_tasks_left = math.floor(
(len(tg) / self.total_nthreads) * ws.nthreads
)
@@ -1793,6 +1833,27 @@ def decide_worker(self, ts: TaskState) -> WorkerState | None:
ws if tg.states["released"] + tg.states["waiting"] > 1 else None
)
tg.last_worker_tasks_left -= 1
+
+ # Queue if worker is full to avoid root task overproduction.
+ if worker_saturated(ws, self.WORKER_OVERSATURATION):
+ # TODO this should be a transition function instead.
+ # But how do we get the `ws` into it? Recommendations on the scheduler can't take arguments.
+
+ if self.validate:
+ assert not ts.queued_on, ts.queued_on
+ assert ts not in ws.queued
+
+ # TODO maintain global queue of tasks as well for newly arriving workers to use?
+ # QUESTION could `queued` be an OrderedSet instead of a HeapSet, giving us O(1)
+ # operations instead of O(logn)? Reasoning is that we're always inserting elements
+ # in priority order anyway.
+ # This wouldn't work in the case that a batch of lower-priority root tasks becomes
+ # ready before a batch of higher-priority root tasks.
+ ws.queued.add(ts)
+ ts.queued_on = ws
+ ts.state = "queued"
+ return None
+
return ws
if ts.dependencies or valid_workers is not None:
@@ -1841,6 +1902,7 @@ def transition_waiting_processing(self, key, stimulus_id):
assert not ts.who_has
assert not ts.exception_blame
assert not ts.processing_on
+ assert not ts.queued_on
assert not ts.has_lost_dependencies
assert ts not in self.unrunnable
assert all(dts.who_has for dts in ts.dependencies)
@@ -1892,6 +1954,7 @@ def transition_waiting_memory(
if self.validate:
assert not ts.processing_on
+ assert not ts.queued_on
assert ts.waiting_on
assert ts.state == "waiting"
@@ -1908,6 +1971,7 @@ def transition_waiting_memory(
if self.validate:
assert not ts.processing_on
+ assert not ts.queued_on
assert not ts.waiting_on
assert ts.who_has
@@ -1999,7 +2063,10 @@ def transition_processing_memory(
if nbytes is not None:
ts.set_nbytes(nbytes)
- _remove_from_processing(self, ts)
+ # NOTE: recommendations for queued tasks are added first, so they'll be popped last,
+ # allowing higher-priority downstream tasks to be transitioned first.
+ # FIXME: this would be incorrect if queued tasks are user-annotated as higher priority.
+ _remove_from_processing(self, ts, recommendations)
_add_to_memory(
self, ts, ws, recommendations, client_msgs, type=type, typename=typename
@@ -2007,7 +2074,18 @@ def transition_processing_memory(
if self.validate:
assert not ts.processing_on
+ assert not ts.queued_on
assert not ts.waiting_on
+ processing_recs = {
+ k: r for k, r in recommendations.items() if r == "processing"
+ }
+ assert list(processing_recs) == (
+ sr := sorted(
+ processing_recs,
+ key=lambda k: self.tasks[k].priority,
+ reverse=True,
+ )
+ ), (list(processing_recs), sr)
return recommendations, client_msgs, worker_msgs
except Exception as e:
@@ -2030,6 +2108,7 @@ def transition_memory_released(self, key, stimulus_id, safe: bool = False):
if self.validate:
assert not ts.waiting_on
assert not ts.processing_on
+ assert not ts.queued_on
if safe:
assert not ts.waiters
@@ -2191,6 +2270,7 @@ def transition_waiting_released(self, key, stimulus_id):
if self.validate:
assert not ts.who_has
assert not ts.processing_on
+ assert not ts.queued_on
dts: TaskState
for dts in ts.dependencies:
@@ -2232,9 +2312,9 @@ def transition_processing_released(self, key, stimulus_id):
assert not ts.waiting_on
assert self.tasks[key].state == "processing"
- w: str = _remove_from_processing(self, ts)
- if w:
- worker_msgs[w] = [
+ ws = _remove_from_processing(self, ts, recommendations)
+ if ws:
+ worker_msgs[ws] = [
{
"op": "free-keys",
"keys": [key],
@@ -2259,6 +2339,7 @@ def transition_processing_released(self, key, stimulus_id):
if self.validate:
assert not ts.processing_on
+ assert not ts.queued_on
return recommendations, client_msgs, worker_msgs
except Exception as e:
@@ -2301,7 +2382,7 @@ def transition_processing_erred(
ws = ts.processing_on
ws.actors.remove(ts)
- w = _remove_from_processing(self, ts)
+ w = _remove_from_processing(self, ts, recommendations)
ts.erred_on.add(w or worker) # type: ignore
if exception is not None:
@@ -2350,6 +2431,7 @@ def transition_processing_erred(
if self.validate:
assert not ts.processing_on
+ assert not ts.queued_on
return recommendations, client_msgs, worker_msgs
except Exception as e:
@@ -2390,6 +2472,108 @@ def transition_no_worker_released(self, key, stimulus_id):
pdb.set_trace()
raise
+ def transition_queued_released(self, key, stimulus_id):
+ try:
+ ts: TaskState = self.tasks[key]
+ recommendations: dict = {}
+ client_msgs: dict = {}
+ worker_msgs: dict = {}
+
+ # TODO allow `remove_worker` to clear `queued_on` and `ws.queued` eagerly; it's more efficient.
+ ws = ts.queued_on
+ assert ws
+
+ if self.validate:
+ assert ts in ws.queued
+ assert not ts.processing_on
+
+ ws.queued.remove(ts)
+ ts.queued_on = None
+
+ # TODO copied from `transition_processing_released`; factor out into helper function
+ ts.state = "released"
+
+ if ts.has_lost_dependencies:
+ recommendations[key] = "forgotten"
+ elif ts.waiters or ts.who_wants:
+ # TODO rescheduling of queued root tasks may be poor.
+ recommendations[key] = "waiting"
+
+ if recommendations.get(key) != "waiting":
+ for dts in ts.dependencies:
+ if dts.state != "released":
+ dts.waiters.discard(ts)
+ if not dts.waiters and not dts.who_wants:
+ recommendations[dts.key] = "released"
+ ts.waiters.clear()
+
+ return recommendations, client_msgs, worker_msgs
+ except Exception as e:
+ logger.exception(e)
+ if LOG_PDB:
+ import pdb
+
+ pdb.set_trace()
+ raise
+
+ def transition_queued_processing(self, key, stimulus_id):
+ try:
+ ts: TaskState = self.tasks[key]
+ recommendations: dict = {}
+ client_msgs: dict = {}
+ worker_msgs: dict = {}
+
+ ws = ts.queued_on
+ assert ws
+ # TODO should this be a graceful transition to released? I think `remove_worker`
+ # makes it such that this should never happen.
+ assert (
+ self.workers[ws.address] is ws
+ ), f"Task {ts} queued on stale worker {ws}"
+
+ if self.validate:
+ assert not ts.actor, "Actors can't be queued"
+ assert ts in ws.queued
+ # Copied from `transition_waiting_processing`
+ assert not ts.processing_on
+ assert not ts.waiting_on
+ assert not ts.who_has
+ assert not ts.exception_blame
+ assert not ts.has_lost_dependencies
+ assert ts not in self.unrunnable
+ assert all(dts.who_has for dts in ts.dependencies)
+
+ # TODO other validation that this is still an appropriate worker?
+
+ # If more important tasks got scheduled in the meantime and the worker is
+ # saturated again, the task simply remains queued.
+ if not worker_saturated(ws, self.WORKER_OVERSATURATION):
+
+ ts.queued_on = None
+ ws.queued.remove(ts)
+ # TODO Copied from `transition_waiting_processing`; factor out into helper function
+ self._set_duration_estimate(ts, ws)
+ ts.processing_on = ws
+ ts.state = "processing"
+ self.consume_resources(ts, ws)
+ self.check_idle_saturated(ws)
+ self.n_tasks += 1
+
+ if ts.actor:
+ ws.actors.add(ts)
+
+ # logger.debug("Send job to worker: %s, %s", worker, key)
+
+ worker_msgs[ws.address] = [_task_to_msg(self, ts)]
+
+ return recommendations, client_msgs, worker_msgs
+ except Exception as e:
+ logger.exception(e)
+ if LOG_PDB:
+ import pdb
+
+ pdb.set_trace()
+ raise
+
def _remove_key(self, key):
ts: TaskState = self.tasks.pop(key)
assert ts.state == "forgotten"
@@ -2413,6 +2597,7 @@ def transition_memory_forgotten(self, key, stimulus_id):
if self.validate:
assert ts.state == "memory"
assert not ts.processing_on
+ assert not ts.queued_on
assert not ts.waiting_on
if not ts.run_spec:
# It's ok to forget a pure data task
@@ -2455,6 +2640,7 @@ def transition_released_forgotten(self, key, stimulus_id):
assert ts.state in ("released", "erred")
assert not ts.who_has
assert not ts.processing_on
+ assert not ts.queued_on
assert not ts.waiting_on, (ts, ts.waiting_on)
if not ts.run_spec:
# It's ok to forget a pure data task
@@ -2495,6 +2681,8 @@ def transition_released_forgotten(self, key, stimulus_id):
("waiting", "released"): transition_waiting_released,
("waiting", "processing"): transition_waiting_processing,
("waiting", "memory"): transition_waiting_memory,
+ ("queued", "released"): transition_queued_released,
+ ("queued", "processing"): transition_queued_processing,
("processing", "released"): transition_processing_released,
("processing", "memory"): transition_processing_memory,
("processing", "erred"): transition_processing_erred,
@@ -2765,7 +2953,71 @@ def bulk_schedule_after_adding_worker(self, ws: WorkerState):
ordering, so the recommendations are sorted by priority order here.
"""
ts: TaskState
- tasks = []
+ recommendations: dict[str, str] = {}
+
+ # Redistribute the tasks between all worker queues. We bubble tasks off the back of the most-queued
+ # worker onto the front of the least-queued, and repeat this until we've accumulated enough tasks to
+ # put onto the new worker. This maintains the co-assignment of each worker's queue, minimizing the
+ # fragmentation of neighboring tasks.
+ # Note this does not rebalance all workers. It just rebalances the busiest workers, stealing just enough
+ # tasks to fill up the new worker.
+ # NOTE: this is probably going to be pretty slow for lots of queued tasks and/or lots of workers.
+ # Also unclear if this is even a good load-balancing strategy.
+ # TODO this is optimized for the add-worker case. Generalize for remove-worker as well.
+ # That would probably look like rebalancing all workers though.
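+ # Worked example (hypothetical sizes): queues of [10, 6, 2] plus a new worker
+ # joining (4 workers total, 18 queued tasks) give target_qsize = 18 // 4 = 4,
+ # so 4 tasks are bubbled off the 10-task queue onto the new worker `ws`.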
+ if not math.isinf(self.WORKER_OVERSATURATION):
+ workers_with_queues: list[WorkerState] = sorted(
+ (wss for wss in self.workers.values() if wss.queued and wss is not ws),
+ key=lambda wss: len(wss.queued),
+ reverse=True,
+ )
+ if workers_with_queues:
+ total_queued = sum(len(wss.queued) for wss in workers_with_queues)
+ target_qsize = int(total_queued / len(self.workers))
+ moveable_tasks_so_far = 0
+ last_q_tasks_to_move = 0
+ i = 0
+ # Go through workers with the largest queues until we've found enough workers to steal from
+ for i, wss in enumerate(workers_with_queues):
+ n_extra_tasks = len(wss.queued) - target_qsize
+ if n_extra_tasks <= 0:
+ break
+ moveable_tasks_so_far += n_extra_tasks
+ if moveable_tasks_so_far >= target_qsize:
+ last_q_tasks_to_move = n_extra_tasks - (
+ moveable_tasks_so_far - target_qsize
+ )
+ break
+ if last_q_tasks_to_move:
+ # Starting from the smallest, bubble tasks off the back of the queue and onto the front of the next-largest.
+ # At the end, bubble tasks onto the new worker's queue
+ while i >= 0:
+ src = workers_with_queues[i]
+ dest = workers_with_queues[i - 1] if i > 0 else ws
+ for _ in range(last_q_tasks_to_move):
+ # NOTE: `popright` does not return exactly the largest element, but sorting
+ # would be too expensive. It's good enough, and in the common case the heap is
+ # sorted anyway, because elements are inserted in sorted order by `decide_worker`.
+ ts = src.queued.popright()
+ ts.queued_on = dest
+ dest.queued.add(ts)
+
+ i -= 1
+ last_q_tasks_to_move = target_qsize
+
+ if (
+ ws.queued
+ and (n := task_slots_available(ws, self.WORKER_OVERSATURATION)) > 0
+ ):
+ # NOTE: reverse priority order, since recommendations are processed in LIFO order
+ for ts in reversed(list(ws.queued.topk(n))):
+ if self.validate:
+ assert ts.state == "queued"
+ assert ts.queued_on is ws, (ts.queued_on, ws)
+ assert ts.key not in recommendations, recommendations[ts.key]
+ recommendations[ts.key] = "processing"
+
+ tasks: list[TaskState] = []
for ts in self.unrunnable:
valid: set = self.valid_workers(ts)
if valid is None or ws in valid:
@@ -4253,6 +4505,10 @@ async def remove_worker(
else: # pure data
recommendations[ts.key] = "forgotten"
+ for ts in ws.queued.sorted():
+ recommendations[ts.key] = "released"
+ # ws.queued.clear() # TODO more performant
+
self.transitions(recommendations, stimulus_id=stimulus_id)
for plugin in list(self.plugins.values()):
@@ -4371,6 +4627,7 @@ def validate_released(self, key):
assert not ts.waiting_on
assert not ts.who_has
assert not ts.processing_on
+ assert not ts.queued_on
assert not any([ts in dts.waiters for dts in ts.dependencies])
assert ts not in self.unrunnable
@@ -4379,12 +4636,27 @@ def validate_waiting(self, key):
assert ts.waiting_on
assert not ts.who_has
assert not ts.processing_on
+ assert not ts.queued_on
assert ts not in self.unrunnable
for dts in ts.dependencies:
# We are waiting on a dependency iff it's not stored
assert bool(dts.who_has) != (dts in ts.waiting_on)
assert ts in dts.waiters # XXX even if dts._who_has?
+ def validate_queued(self, key):
+ ts: TaskState = self.tasks[key]
+ dts: TaskState
+ assert not ts.waiting_on
+ ws = ts.queued_on
+ assert ws
+ assert self.workers.get(ws.address) is ws, f"{ts} queued on stale worker {ws}"
+ assert ts in ws.queued
+ assert not ts.who_has
+ assert not ts.processing_on
+ for dts in ts.dependencies:
+ assert dts.who_has
+ assert ts in dts.waiters
+
def validate_processing(self, key):
ts: TaskState = self.tasks[key]
dts: TaskState
@@ -4403,6 +4675,7 @@ def validate_memory(self, key):
assert ts.who_has
assert bool(ts in self.replicated_tasks) == (len(ts.who_has) > 1)
assert not ts.processing_on
+ assert not ts.queued_on
assert not ts.waiting_on
assert ts not in self.unrunnable
for dts in ts.dependents:
@@ -4417,6 +4690,7 @@ def validate_no_worker(self, key):
assert not ts.waiting_on
assert ts in self.unrunnable
assert not ts.processing_on
+ assert not ts.queued_on
assert not ts.who_has
for dts in ts.dependencies:
assert dts.who_has
@@ -7131,7 +7405,9 @@ def request_remove_replicas(
)
-def _remove_from_processing(state: SchedulerState, ts: TaskState) -> str | None:
+def _remove_from_processing(
+ state: SchedulerState, ts: TaskState, recommendations: dict
+) -> str | None:
"""Remove *ts* from the set of processing tasks.
See also
@@ -7157,6 +7433,19 @@ def _remove_from_processing(state: SchedulerState, ts: TaskState) -> str | None:
state.check_idle_saturated(ws)
state.release_resources(ts, ws)
+ # If a slot has opened up for a queued task, schedule it.
+ if ws.queued and not worker_saturated(ws, state.WORKER_OVERSATURATION):
+ # TODO peek or pop?
+ # What if multiple tasks complete on a worker in one transition cycle? Is that possible?
+ # TODO should we only be scheduling 1 task? Or N open threads? Is there a possible deadlock
+ # where tasks remain queued on a worker forever?
+ qts = ws.queued.peek()
+ if state.validate:
+ assert qts.state == "queued"
+ assert qts.queued_on is ws, (qts.queued_on, ws)
+ assert qts.key not in recommendations, recommendations[qts.key]
+ recommendations[qts.key] = "processing"
+
return ws.address
@@ -7434,8 +7723,18 @@ def validate_task_state(ts: TaskState) -> None:
assert dts.state != "forgotten"
assert (ts.processing_on is not None) == (ts.state == "processing")
+ assert not (ts.processing_on and ts.queued_on), (ts.processing_on, ts.queued_on)
assert bool(ts.who_has) == (ts.state == "memory"), (ts, ts.who_has, ts.state)
+ if ts.queued_on:
+ assert ts.state == "queued"
+ assert ts in ts.queued_on.queued
+
+ if ts.state == "queued":
+ assert ts.queued_on
+ assert not ts.processing_on
+ assert not ts.who_has
+
if ts.state == "processing":
assert all(dts.who_has for dts in ts.dependencies), (
"task processing without all deps",
@@ -7443,6 +7742,7 @@ def validate_task_state(ts: TaskState) -> None:
str(ts.dependencies),
)
assert not ts.waiting_on
+ assert not ts.queued_on
if ts.who_has:
assert ts.waiters or ts.who_wants, (
@@ -7530,6 +7830,19 @@ def heartbeat_interval(n: int) -> float:
return n / 200 + 1
+def task_slots_available(ws: WorkerState, oversaturation_factor: float) -> int:
+ "Number of tasks that can be sent to this worker without oversaturating it"
+ assert not math.isinf(oversaturation_factor)
+ nthreads = ws.nthreads
+ return max(nthreads + int(oversaturation_factor * nthreads), 1) - len(ws.processing)
+
+
+def worker_saturated(ws: WorkerState, oversaturation_factor: float) -> bool:
+ if math.isinf(oversaturation_factor):
+ return False
+ return task_slots_available(ws, oversaturation_factor) <= 0
+
+
class KilledWorker(Exception):
def __init__(self, task: str, last_worker: WorkerState):
super().__init__(task, last_worker)
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 2d1d7a52c8b..09af6984076 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -16,12 +16,12 @@
import cloudpickle
import psutil
import pytest
-from tlz import concat, first, merge, valmap
+from tlz import concat, first, merge, partition, valmap
from tornado.ioloop import IOLoop
import dask
from dask import delayed
-from dask.utils import apply, parse_timedelta, stringify, tmpfile, typename
+from dask.utils import apply, parse_bytes, parse_timedelta, stringify, tmpfile, typename
from distributed import (
CancelledError,
@@ -54,6 +54,7 @@
raises_with_cause,
slowadd,
slowdec,
+ slowidentity,
slowinc,
tls_only_security,
varying,
@@ -245,6 +246,124 @@ def random(**kwargs):
test_decide_worker_coschedule_order_neighbors_()
+@pytest.mark.slow
+@gen_cluster(
+ client=True,
+ nthreads=[("", 2)] * 2,
+ worker_kwargs={"memory_limit": "1.0GiB"},
+ timeout=3600, # TODO remove
+ Worker=Nanny,
+ scheduler_kwargs=dict( # TODO remove
+ dashboard=True,
+ dashboard_address=":8787",
+ ),
+ config={
+ "distributed.worker.memory.target": False,
+ "distributed.worker.memory.spill": False,
+ "distributed.scheduler.work-stealing": False,
+ },
+)
+async def test_root_task_overproduction(c, s, *nannies):
+ """
+ Workload that would run out of memory and kill workers if >2 root tasks were
+ ever in memory at once on a worker.
+ """
+
+ @delayed(pure=True)
+ def big_data(size: int) -> str:
+ return "x" * size
+
+ roots = [
+ big_data(parse_bytes("350 MiB"), dask_key_name=f"root-{i}") for i in range(16)
+ ]
+ passthrough = [delayed(slowidentity)(x) for x in roots]
+ memory_consumed = [delayed(len)(x) for x in passthrough]
+ reduction = [sum(sizes) for sizes in partition(4, memory_consumed)]
+ final = sum(reduction)
+
+ await c.compute(final)
+
+
+@pytest.mark.parametrize(
+ "oversaturation, expected_task_counts",
+ [
+ (1.5, (5, 2)),
+ (1.0, (4, 2)),
+ (0.0, (2, 1)),
+ (-1.0, (1, 1)),
+ (float("inf"), (7, 3))
+ # ^ depends on root task assignment logic; ok if changes, just needs to add up to 10
+ ],
+)
+def test_oversaturation_factor(oversaturation, expected_task_counts: tuple[int, int]):
+ @gen_cluster(
+ client=True,
+ nthreads=[("", 2), ("", 1)],
+ config={
+ "distributed.scheduler.worker-oversaturation": oversaturation,
+ },
+ )
+ async def _test_oversaturation_factor(c, s, a, b):
+ event = Event()
+ fs = c.map(lambda _: event.wait(), range(10))
+ while a.state.executing_count < min(
+ a.nthreads, expected_task_counts[0]
+ ) or b.state.executing_count < min(b.nthreads, expected_task_counts[1]):
+ await asyncio.sleep(0.01)
+
+ assert len(a.state.tasks) == expected_task_counts[0]
+ assert len(b.state.tasks) == expected_task_counts[1]
+
+ await event.set()
+ await c.gather(fs)
+
+ _test_oversaturation_factor()
+
+
+@pytest.mark.parametrize(
+ "saturation_factor",
+ [
+ 0.0,
+ 1.0,
+ pytest.param(
+ float("inf"),
+ marks=pytest.mark.skip("https://github.com/dask/distributed/issues/6597"),
+ ),
+ ],
+)
+@gen_cluster(
+ client=True,
+ nthreads=[("", 2), ("", 1)],
+)
+async def test_oversaturation_multiple_task_groups(c, s, a, b, saturation_factor):
+ s.WORKER_OVERSATURATION = saturation_factor
+ xs = [delayed(i, name=f"x-{i}") for i in range(9)]
+ ys = [delayed(i, name=f"y-{i}") for i in range(9)]
+ zs = [x + y for x, y in zip(xs, ys)]
+
+ await c.gather(c.compute(zs))
+
+ assert not a.incoming_transfer_log, [l["keys"] for l in a.incoming_transfer_log]
+ assert not b.incoming_transfer_log, [l["keys"] for l in b.incoming_transfer_log]
+ assert len(a.tasks) == 18
+ assert len(b.tasks) == 9
+
+
+@gen_cluster(
+ client=True,
+ nthreads=[("", 2)] * 2,
+ timeout=3600, # TODO remove
+ scheduler_kwargs=dict( # TODO remove
+ dashboard=True,
+ dashboard_address=":8787",
+ ),
+)
+async def test_queued_tasks_rebalance(c, s, a, b):
+ event = Event()
+ fs = c.map(lambda _: event.wait(), range(100))
+ await event.set() # unblock the tasks; otherwise gather would never return
+ await c.gather(fs)
+
+
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3)
async def test_move_data_over_break_restrictions(client, s, a, b, c):
[x] = await client.scatter([1], workers=b.address)
diff --git a/distributed/widgets/templates/worker_state.html.j2 b/distributed/widgets/templates/worker_state.html.j2
index cd152080bfc..08629998e1f 100644
--- a/distributed/widgets/templates/worker_state.html.j2
+++ b/distributed/widgets/templates/worker_state.html.j2
@@ -3,3 +3,4 @@
status: {{ status }}
memory: {{ has_what | length }}
processing: {{ processing | length }}
+ queued: {{ queued | length }}