Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 48 additions & 32 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Iterable,
Iterator,
Mapping,
MutableSet,
Sequence,
Set,
)
Expand Down Expand Up @@ -66,7 +67,6 @@
from distributed._stories import scheduler_story
from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
from distributed.batched import BatchedSend
from distributed.collections import HeapSet
from distributed.comm import (
Comm,
CommClosedError,
Expand Down Expand Up @@ -1534,8 +1534,9 @@ class SchedulerState:
#: All tasks currently known to the scheduler
tasks: dict[str, TaskState]

#: Tasks in the "queued" state, ordered by priority
queued: HeapSet[TaskState]
#: Tasks in the "queued" state, ordered by priority.
#: A `SortedSet` (doesn't support annotations https://github.com/python/typeshed/issues/8574)
queued: MutableSet[TaskState]

#: Tasks in the "no-worker" state
unrunnable: set[TaskState]
Expand Down Expand Up @@ -1610,7 +1611,7 @@ def __init__(
resources: dict[str, dict[str, float]],
tasks: dict[str, TaskState],
unrunnable: set[TaskState],
queued: HeapSet[TaskState],
queued: SortedSet,
validate: bool,
plugins: Iterable[SchedulerPlugin] = (),
transition_counter_max: int | Literal[False] = False,
Expand Down Expand Up @@ -2219,8 +2220,9 @@ def transition_waiting_processing(self, key: str, stimulus_id: str) -> RecsMsgs:
if not (ws := self.decide_worker_rootish_queuing_disabled(ts)):
return {ts.key: "no-worker"}, {}, {}
else:
if not (ws := self.decide_worker_rootish_queuing_enabled()):
return {ts.key: "queued"}, {}, {}
# All rootish tasks go straight to `queued` first.
# `stimulus_queue_slots_maybe_opened` will then maybe pop some off later.
return {ts.key: "queued"}, {}, {}
else:
if not (ws := self.decide_worker_non_rootish(ts)):
return {ts.key: "no-worker"}, {}, {}
Expand Down Expand Up @@ -2651,7 +2653,6 @@ def transition_waiting_queued(self, key: str, stimulus_id: str) -> RecsMsgs:
ts = self.tasks[key]

if self.validate:
assert not self.idle_task_count, (ts, self.idle_task_count)
self._validate_ready(ts)

ts.state = "queued"
Expand Down Expand Up @@ -2683,21 +2684,25 @@ def transition_queued_released(self, key: str, stimulus_id: str) -> RecsMsgs:
self._propagate_released(ts, recommendations)
return recommendations, {}, {}

def transition_queued_processing(self, key: str, stimulus_id: str) -> RecsMsgs:
def transition_queued_processing(
self, key: str, stimulus_id: str, *, ws: WorkerState
) -> RecsMsgs:
# Never called as a recommendation, only directly via `transition("processing", ws)`.
# The `ws` argument is required.
ts = self.tasks[key]
recommendations: Recs = {}
worker_msgs: Msgs = {}

if self.validate:
assert not ts.actor, f"Actors can't be queued: {ts}"
assert ts in self.queued
assert not _worker_full(ws, self.WORKER_SATURATION), (
ws,
_task_slots_available(ws, self.WORKER_SATURATION),
)

if ws := self.decide_worker_rootish_queuing_enabled():
self.queued.discard(ts)
worker_msgs = self._add_to_processing(ts, ws)
# If no worker, task just stays `queued`
self.queued.discard(ts)
worker_msgs: Msgs = self._add_to_processing(ts, ws)

return recommendations, {}, worker_msgs
return {}, {}, worker_msgs

def _remove_key(self, key: str) -> None:
ts = self.tasks.pop(key)
Expand Down Expand Up @@ -3529,7 +3534,7 @@ def __init__(
self._last_client = None
self._last_time = 0
unrunnable = set()
queued: HeapSet[TaskState] = HeapSet(key=operator.attrgetter("priority"))
queued = SortedSet(key=operator.attrgetter("priority"))

self.datasets = {}

Expand Down Expand Up @@ -4580,6 +4585,7 @@ def update_graph(
logger.exception(e)

self.transitions(recommendations, stimulus_id)
self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)

for ts in touched_tasks:
if ts.state in ("memory", "erred"):
Expand Down Expand Up @@ -4608,23 +4614,33 @@ def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None:
"""
if not self.queued:
return
slots_available = sum(
_task_slots_available(ws, self.WORKER_SATURATION)
for ws in self.idle_task_count
)
if slots_available == 0:
return

recommendations: Recs = {}
for qts in self.queued.peekn(slots_available):
if self.validate:
assert qts.state == "queued", qts.state
assert not qts.processing_on, (qts, qts.processing_on)
assert not qts.waiting_on, (qts, qts.processing_on)
assert qts.who_wants or qts.waiters, qts
recommendations[qts.key] = "processing"

self.transitions(recommendations, stimulus_id)
submittable: list[tuple[str, WorkerState]] = []
for ws in self.idle_task_count:
ws_idx: int
ws_idx = self.workers.index(ws.address) # type: ignore
# TODO assumes all workers have the same number of threads
tasks_per_worker = math.ceil(len(self.queued) / len(self.workers))
q_idx = ws_idx * tasks_per_worker
slots = _task_slots_available(ws, self.WORKER_SATURATION)
n = min(slots, tasks_per_worker)

if q_idx >= len(self.queued):
# TODO should we always select from the back of the queue when there are
# more workers than needed? Will this lead to uneven task selection?
q_idx = len(self.queued) - n

for qts in self.queued[q_idx : q_idx + n]: # type: ignore
if self.validate:
assert qts.state == "queued", qts.state
assert not qts.processing_on, (qts, qts.processing_on)
assert not qts.waiting_on, (qts, qts.processing_on)
assert qts.who_wants or qts.waiters, qts
# Store in a list for later to avoid mutating `queued` while iterating
submittable.append((qts.key, ws))

for key, ws in submittable:
self.transition(key, "processing", ws=ws, stimulus_id=stimulus_id)

def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs):
"""Mark that a task has finished execution on a particular worker"""
Expand Down
35 changes: 17 additions & 18 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
import itertools
import json
import logging
import math
Expand Down Expand Up @@ -156,10 +155,7 @@ def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads):
@gen_cluster(
client=True,
nthreads=nthreads,
config={
"distributed.scheduler.work-stealing": False,
"distributed.scheduler.worker-saturation": float("inf"),
},
config={"distributed.scheduler.work-stealing": False},
)
async def test_decide_worker_coschedule_order_neighbors_(c, s, *workers):
r"""
Expand Down Expand Up @@ -188,6 +184,11 @@ async def test_decide_worker_coschedule_order_neighbors_(c, s, *workers):
"""
da = pytest.importorskip("dask.array")
np = pytest.importorskip("numpy")
no_queue = math.isinf(s.WORKER_SATURATION)
if not no_queue and len({w.state.nthreads for w in workers}) > 1:
pytest.skip(
"co-assignment + queuing is imbalanced for heterogeneous workers"
)

if ndeps == 0:
x = da.random.random((100, 100), chunks=(10, 10))
Expand Down Expand Up @@ -219,7 +220,9 @@ def random(**kwargs):
keys = {stringify(k) for k in keys}

# No more than 2 workers should have any keys
assert sum(any(k in w.data for k in keys) for w in workers) <= 2
assert sum(any(k in w.data for k in keys) for w in workers) <= (
2 if no_queue else 3
)

# What fraction of the keys for this row does each worker hold?
key_fractions = [
Expand All @@ -233,10 +236,10 @@ def random(**kwargs):

# There may be one or two rows that were poorly split across workers,
# but the vast majority of rows should only be on one worker.
assert np.mean(primary_worker_key_fractions) >= 0.9
assert np.median(primary_worker_key_fractions) == 1.0
assert np.mean(secondary_worker_key_fractions) <= 0.1
assert np.median(secondary_worker_key_fractions) == 0.0
assert np.mean(primary_worker_key_fractions) >= (0.9 if no_queue else 0.7)
assert np.median(primary_worker_key_fractions) >= (1.0 if no_queue else 0.9)
assert np.mean(secondary_worker_key_fractions) <= (0.1 if no_queue else 0.3)
assert np.median(secondary_worker_key_fractions) <= (0.0 if no_queue else 0.1)

# Check that there were few transfers
unexpected_transfers = []
Expand All @@ -254,7 +257,9 @@ def random(**kwargs):
# A transfer at the very end to move aggregated results is fine (necessary with
# unbalanced workers in fact), but generally there should be very very few
# transfers.
assert len(unexpected_transfers) <= 3, unexpected_transfers
assert len(unexpected_transfers) <= (
3 if no_queue else len(workers) + 1
), unexpected_transfers

test_decide_worker_coschedule_order_neighbors_()

Expand Down Expand Up @@ -423,10 +428,6 @@ async def test_queued_release_multiple_workers(c, s, *workers):
await async_wait_for(lambda: second_batch[0].key in s.tasks, 5)

# All of the second batch should be queued after the first batch
assert [ts.key for ts in s.queued.sorted()] == [
f.key
for f in itertools.chain(first_batch[s.total_nthreads :], second_batch)
]

# Cancel the first batch.
# Use `Client.close` instead of `del first_batch` because deleting futures sends cancellation
Expand All @@ -438,9 +439,7 @@ async def test_queued_release_multiple_workers(c, s, *workers):
await async_wait_for(lambda: len(s.tasks) == len(second_batch), 5)

# Second batch should move up the queue and start processing
assert len(s.queued) == len(second_batch) - s.total_nthreads, list(
s.queued.sorted()
)
assert len(s.queued) == len(second_batch) - s.total_nthreads, list(s.queued)

await event.set()
await c2.gather(second_batch)
Expand Down