104 changes: 63 additions & 41 deletions distributed/active_memory_manager.py
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING
@@ -10,7 +9,8 @@
import dask
from dask.utils import parse_timedelta

from .utils import import_term
from .metrics import time
from .utils import import_term, log_errors

if TYPE_CHECKING:
from .scheduler import Scheduler, TaskState, WorkerState
@@ -115,45 +115,67 @@ def run_once(self, comm=None) -> None:
"""Run all policies once and asynchronously (fire and forget) enact their
recommendations to replicate/drop keys
"""
# This should never fail since this is a synchronous method
assert not hasattr(self, "pending")

self.pending = defaultdict(lambda: (set(), set()))
self.workers_memory = {
w: w.memory.optimistic for w in self.scheduler.workers.values()
}
try:
# populate self.pending
self._run_policies()

drop_by_worker: defaultdict[str, set[str]] = defaultdict(set)
repl_by_worker: defaultdict[str, dict[str, list[str]]] = defaultdict(dict)

for ts, (pending_repl, pending_drop) in self.pending.items():
if not ts.who_has:
continue
who_has = [ws_snd.address for ws_snd in ts.who_has - pending_drop]

assert who_has # Never drop the last replica
for ws_rec in pending_repl:
assert ws_rec not in ts.who_has
repl_by_worker[ws_rec.address][ts.key] = who_has
for ws in pending_drop:
assert ws in ts.who_has
drop_by_worker[ws.address].add(ts.key)

# Fire-and-forget enact recommendations from policies
# This is temporary code, waiting for
# https://github.com/dask/distributed/pull/5046
for addr, who_has_map in repl_by_worker.items():
asyncio.create_task(self.scheduler.gather_on_worker(addr, who_has_map))
for addr, keys in drop_by_worker.items():
asyncio.create_task(self.scheduler.delete_worker_data(addr, keys))
# End temporary code

finally:
del self.workers_memory
del self.pending
with log_errors():
# This should never fail since this is a synchronous method
assert not hasattr(self, "pending")

self.pending = defaultdict(lambda: (set(), set()))
self.workers_memory = {
w: w.memory.optimistic for w in self.scheduler.workers.values()
}
try:
# populate self.pending
self._run_policies()

drop_by_worker: defaultdict[WorkerState, set[TaskState]] = defaultdict(
set
)
repl_by_worker: defaultdict[
WorkerState, dict[TaskState, set[WorkerState]]
] = defaultdict(dict)

for ts, (pending_repl, pending_drop) in self.pending.items():
if not ts.who_has:
continue
who_has = {ws_snd.address for ws_snd in ts.who_has - pending_drop}
assert who_has # Never drop the last replica
for ws_rec in pending_repl:
assert ws_rec not in ts.who_has
repl_by_worker[ws_rec][ts] = who_has
for ws in pending_drop:
assert ws in ts.who_has
drop_by_worker[ws].add(ts)

# Fire-and-forget enact recommendations from policies
stimulus_id = str(time())
for ws_rec, ts_to_who_has in repl_by_worker.items():
self.scheduler.stream_comms[ws_rec.address].send(
{
"op": "acquire-replicas",
"keys": [ts.key for ts in ts_to_who_has],
"stimulus_id": "acquire-replicas-" + stimulus_id,
"priorities": {ts.key: ts.priority for ts in ts_to_who_has},
"who_has": {ts.key: v for ts, v in ts_to_who_has.items()},
},
)

for ws, tss in drop_by_worker.items():
# The scheduler immediately forgets about the replica and suggests that
# the worker drop it. The worker may refuse, at which point it will
# send back an add-keys message to reinstate it.
for ts in tss:
self.scheduler.remove_replica(ts, ws)
self.scheduler.stream_comms[ws.address].send(
{
"op": "remove-replicas",
"keys": [ts.key for ts in tss],
"stimulus_id": "remove-replicas-" + stimulus_id,
}
)

finally:
del self.workers_memory
del self.pending

def _run_policies(self) -> None:
"""Sequentially run ActiveMemoryManagerPolicy.run() for all registered policies,
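
For context on how policies plug into the new run_once(): a policy's run() method is a generator of ("replicate" | "drop", TaskState, candidate workers) suggestions, which run_once() collects into self.pending and then enacts through the acquire-replicas / remove-replicas messages added above. Below is a minimal sketch of a custom policy, modeled on the DropEverything test policy further down in this PR; the class name and the hard-coded worker address are hypothetical placeholders.

# Minimal sketch of a custom AMM policy (not part of this diff). It mirrors
# the DropEverything test policy below; the class name and the worker address
# are hypothetical placeholders.
from distributed.active_memory_manager import ActiveMemoryManagerPolicy


class DropFromWorker(ActiveMemoryManagerPolicy):
    """Suggest dropping every replica held by one specific worker."""

    address = "tcp://127.0.0.1:12345"  # hypothetical address

    def run(self):
        # run() yields (command, TaskState, candidate WorkerStates) suggestions;
        # run_once() aggregates them and messages the affected workers.
        for ts in self.manager.scheduler.tasks.values():
            for ws in ts.who_has:
                if ws.address == self.address:
                    yield "drop", ts, {ws}

As the tests below do, such a policy can be registered through the distributed.scheduler.active-memory-manager.policies config list and triggered either on the configured interval or manually via s.extensions["amm"].run_once().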
119 changes: 78 additions & 41 deletions distributed/tests/test_active_memory_manager.py
@@ -209,7 +209,6 @@ async def test_drop_with_waiter(c, s, a, b):
assert not y2.done()


@pytest.mark.xfail(reason="distributed#5265")
@gen_cluster(client=True, config=NO_AMM_START)
async def test_double_drop(c, s, a, b):
"""An AMM drop policy runs once to drop one of the two replicas of a key.
@@ -329,46 +328,6 @@ async def test_drop_with_bad_candidates(c, s, a, b):
assert s.tasks["x"].who_has == {ws0, ws1}


class DropEverything(ActiveMemoryManagerPolicy):
"""Inanely suggest to drop every single key in the cluster"""

def run(self):
for ts in self.manager.scheduler.tasks.values():
# Instead of yielding ("drop", ts, None) for each worker, which would result
# in semi-predictable output about which replica survives, randomly choose a
# different survivor at each AMM run.
candidates = list(ts.who_has)
random.shuffle(candidates)
for ws in candidates:
yield "drop", ts, {ws}


@pytest.mark.xfail(reason="distributed#5046, distributed#5265")
@pytest.mark.slow
@gen_cluster(
client=True,
nthreads=[("", 1)] * 8,
Worker=Nanny,
config={
"distributed.scheduler.active-memory-manager.start": True,
"distributed.scheduler.active-memory-manager.interval": 0.1,
"distributed.scheduler.active-memory-manager.policies": [
{"class": "distributed.tests.test_active_memory_manager.DropEverything"},
],
},
)
async def test_drop_stress(c, s, *nannies):
"""A policy which suggests dropping everything won't break a running computation,
but only slow it down.
"""
import dask.array as da

rng = da.random.RandomState(0)
a = rng.random((20, 20), chunks=(1, 1))
b = (a @ a.T).sum().round(3)
assert await c.compute(b) == 2134.398


@gen_cluster(nthreads=[("", 1)] * 4, client=True, config=demo_config("replicate", n=2))
async def test_replicate(c, s, *workers):
futures = await c.scatter({"x": 123})
@@ -496,3 +455,81 @@ async def test_ReduceReplicas(c, s, *workers):
s.extensions["amm"].run_once()
while len(s.tasks["x"].who_has) > 1:
await asyncio.sleep(0.01)


class DropEverything(ActiveMemoryManagerPolicy):
"""Inanely suggest to drop every single key in the cluster"""

def __init__(self):
self.i = 0

def run(self):
for ts in self.manager.scheduler.tasks.values():
# Instead of yielding ("drop", ts, None) for each worker, which would result
# in semi-predictable output about which replica survives, randomly choose a
# different survivor at each AMM run.
candidates = list(ts.who_has)
random.shuffle(candidates)
for ws in candidates:
yield "drop", ts, {ws}

# Stop running after ~2s
self.i += 1
if self.i == 20:
self.manager.policies.remove(self)


async def _tensordot_stress(c):
da = pytest.importorskip("dask.array")

rng = da.random.RandomState(0)
a = rng.random((20, 20), chunks=(1, 1))
b = (a @ a.T).sum().round(3)
assert await c.compute(b) == 2134.398


@pytest.mark.slow
@pytest.mark.xfail(reason="https://github.com/dask/distributed/issues/5371")
@gen_cluster(
client=True,
nthreads=[("", 1)] * 4,
Worker=Nanny,
config={
"distributed.scheduler.active-memory-manager.start": True,
"distributed.scheduler.active-memory-manager.interval": 0.1,
"distributed.scheduler.active-memory-manager.policies": [
{"class": "distributed.tests.test_active_memory_manager.DropEverything"},
],
},
timeout=120,
)
async def test_drop_stress(c, s, *nannies):
"""A policy which suggests dropping everything won't break a running computation,
but only slow it down.

See also: test_ReduceReplicas_stress
"""
await _tensordot_stress(c)


@pytest.mark.slow
@pytest.mark.xfail(reason="https://github.com/dask/distributed/issues/5371")
@gen_cluster(
client=True,
nthreads=[("", 1)] * 4,
Worker=Nanny,
config={
"distributed.scheduler.active-memory-manager.start": True,
"distributed.scheduler.active-memory-manager.interval": 0.1,
"distributed.scheduler.active-memory-manager.policies": [
{"class": "distributed.active_memory_manager.ReduceReplicas"},
],
},
timeout=120,
)
async def test_ReduceReplicas_stress(c, s, *nannies):
"""Running ReduceReplicas compulsively won't break a running computation. Unlike
test_drop_stress above, this test does not stop running after a few seconds - the
policy must not disrupt the computation too much.
"""
await _tensordot_stress(c)
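
For reference, enabling the AMM with the ReduceReplicas policy outside the test suite comes down to setting the same config keys used by the @gen_cluster decorators above before the scheduler starts. A sketch, reusing the tests' values (the 0.1s interval and the bare Client() call are illustrative, not recommendations):

# Sketch: enable the AMM + ReduceReplicas via dask config, using the same
# config keys exercised by the stress tests above.
import dask
from distributed import Client

dask.config.set(
    {
        "distributed.scheduler.active-memory-manager.start": True,
        "distributed.scheduler.active-memory-manager.interval": 0.1,
        "distributed.scheduler.active-memory-manager.policies": [
            {"class": "distributed.active_memory_manager.ReduceReplicas"},
        ],
    }
)

client = Client()  # config must be in place before the scheduler is created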