From d8c9586e8e73f29c4d81893d8670e322048f2d81 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Mon, 26 Sep 2022 17:07:41 +0100
Subject: [PATCH 1/2] AMM/measure

---
 distributed/active_memory_manager.py        |  36 +++-
 distributed/distributed-schema.yaml         |  10 +-
 distributed/distributed.yaml                |  10 +-
 .../tests/test_active_memory_manager.py     | 177 ++++++++----------
 docs/source/active_memory_manager.rst       |   2 +
 5 files changed, 133 insertions(+), 102 deletions(-)

diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py
index 37fb22b9483..c6a07574fd4 100644
--- a/distributed/active_memory_manager.py
+++ b/distributed/active_memory_manager.py
@@ -46,14 +46,20 @@ class ActiveMemoryManagerExtension:
     ``distributed.scheduler.active-memory-manager``.
     """
 
+    #: Back-reference to the scheduler holding this extension
     scheduler: Scheduler
+    #: All active policies
     policies: set[ActiveMemoryManagerPolicy]
+    #: Memory measure to use. Must be one of the attributes or properties of
+    #: :class:`distributed.scheduler.MemoryState`.
+    measure: str
+    #: Run automatically every this many seconds
     interval: float
-
-    # These attributes only exist within the scope of self.run()
-    # Current memory (in bytes) allocated on each worker, plus/minus pending actions
+    #: Current memory (in bytes) allocated on each worker, plus/minus pending actions.
+    #: This attribute only exists within the scope of self.run().
     workers_memory: dict[WorkerState, int]
-    # Pending replications and deletions for each task
+    #: Pending replications and deletions for each task.
+    #: This attribute only exists within the scope of self.run().
     pending: dict[TaskState, tuple[set[WorkerState], set[WorkerState]]]
 
     def __init__(
@@ -63,6 +69,7 @@ def __init__(
         # away on the fly a specialized manager, separate from the main one.
         policies: set[ActiveMemoryManagerPolicy] | None = None,
         *,
+        measure: str | None = None,
         register: bool = True,
         start: bool | None = None,
         interval: float | None = None,
     ) -> None:
@@ -83,6 +90,23 @@ def __init__(
         for policy in policies:
             self.add_policy(policy)
 
+        if not measure:
+            measure = dask.config.get(
+                "distributed.scheduler.active-memory-manager.measure"
+            )
+        mem = scheduler.memory
+        measure_domain = {
+            name
+            for name in dir(mem)
+            if not name.startswith("_") and isinstance(getattr(mem, name), int)
+        }
+        if not isinstance(measure, str) or measure not in measure_domain:
+            raise ValueError(
+                "distributed.scheduler.active-memory-manager.measure "
+                "must be one of " + ", ".join(sorted(measure_domain))
+            )
+        self.measure = measure
+
         if register:
             scheduler.extensions["amm"] = self
             scheduler.handlers["amm_handler"] = self.amm_handler
@@ -92,6 +116,7 @@ def __init__(
                 dask.config.get("distributed.scheduler.active-memory-manager.interval")
             )
         self.interval = interval
+
         if start is None:
             start = dask.config.get("distributed.scheduler.active-memory-manager.start")
         if start:
@@ -140,8 +165,9 @@ def run_once(self) -> None:
         assert not hasattr(self, "pending")
         self.pending = {}
 
+        measure = self.measure
         self.workers_memory = {
-            w: w.memory.optimistic for w in self.scheduler.workers.values()
+            ws: getattr(ws.memory, measure) for ws in self.scheduler.workers.values()
         }
         try:
             # populate self.pending
diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml
index 1e6daf54253..f282c01eb5d 100644
--- a/distributed/distributed-schema.yaml
+++ b/distributed/distributed-schema.yaml
@@ -277,7 +277,7 @@ properties:
 
           active-memory-manager:
             type: object
-            required: [start, interval, policies]
+            required: [start, interval, measure, policies]
             additionalProperties: false
             properties:
               start:
@@ -287,6 +287,14 @@ properties:
                 type: string
                 description:
                   Time expression, e.g. "2s". Run the AMM cycle every <interval>.
+              measure:
+                enum:
+                  - process
+                  - optimistic
+                  - managed
+                  - managed_in_memory
+                description:
+                  One of the attributes of distributed.scheduler.MemoryState
               policies:
                 type: array
                 items:
diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml
index d3a93a7a651..fc767e32edb 100644
--- a/distributed/distributed.yaml
+++ b/distributed/distributed.yaml
@@ -67,11 +67,17 @@ distributed:
       # you'll have to either manually start it with client.amm.start() or run it once
       # with client.amm.run_once().
       start: false
+      # Once started, run the AMM cycle every <interval>
       interval: 2s
+
+      # Memory measure to use. Must be one of the attributes of
+      # distributed.scheduler.MemoryState.
+      measure: optimistic
+
+      # Policies that should be executed at every cycle. Any additional keys in each
+      # object are passed as keyword arguments to the policy constructor.
       policies:
-      # Policies that should be executed at every cycle. Any additional keys in each
-      # object are passed as keyword arguments to the policy constructor.
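
For illustration, not part of either patch: a minimal sketch of driving the new
``measure`` setting by hand. It assumes an in-process ``LocalCluster``/``Client``
pair (neither appears in this diff) and uses ``client.amm.run_once()``, the manual
trigger mentioned in the distributed.yaml comment above:

    import dask
    from distributed import Client, LocalCluster

    # Config must be set before the Scheduler is created, as the AMM extension
    # reads it in __init__.
    with dask.config.set(
        {
            "distributed.scheduler.active-memory-manager.start": False,
            "distributed.scheduler.active-memory-manager.measure": "managed",
            "distributed.scheduler.active-memory-manager.policies": [
                {"class": "distributed.active_memory_manager.ReduceReplicas"},
            ],
        }
    ):
        cluster = LocalCluster(n_workers=2)
        client = Client(cluster)
        # start is False, so the AMM is not running periodically;
        # trigger a single cycle on demand instead.
        client.amm.run_once()
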
- class: distributed.active_memory_manager.ReduceReplicas worker: diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py index 4c3c513adcd..d18c7894531 100644 --- a/distributed/tests/test_active_memory_manager.py +++ b/distributed/tests/test_active_memory_manager.py @@ -3,14 +3,16 @@ import asyncio import logging import random +import warnings from collections.abc import Iterator from contextlib import contextmanager -from time import sleep from typing import Any, Literal import pytest -from distributed import Event, Lock, Nanny, wait +import dask.config + +from distributed import Event, Lock, Scheduler, wait from distributed.active_memory_manager import ( ActiveMemoryManagerExtension, ActiveMemoryManagerPolicy, @@ -18,14 +20,17 @@ ) from distributed.core import Status from distributed.utils_test import ( + BlockedGatherDep, assert_story, captured_logger, gen_cluster, + gen_test, inc, lock_inc, slowinc, wait_for_state, ) +from distributed.worker_state_machine import AcquireReplicasEvent NO_AMM_START = {"distributed.scheduler.active-memory-manager.start": False} @@ -87,11 +92,13 @@ def demo_config( candidates: list[int] | None = None, start: bool = False, interval: float = 0.1, + measure: str = "managed", ) -> dict[str, Any]: """Create a dask config for AMM with DemoPolicy""" return { "distributed.scheduler.active-memory-manager.start": start, "distributed.scheduler.active-memory-manager.interval": interval, + "distributed.scheduler.active-memory-manager.measure": measure, "distributed.scheduler.active-memory-manager.policies": [ { "class": "distributed.tests.test_active_memory_manager.DemoPolicy", @@ -349,25 +356,15 @@ async def test_double_drop_stress(c, s, a, b): assert len(s.tasks["x"].who_has) == 1 -@pytest.mark.slow -@gen_cluster( - nthreads=[("", 1)] * 4, - Worker=Nanny, - client=True, - worker_kwargs={"memory_limit": "2 GiB"}, - config=demo_config("drop", n=1), -) -async def test_drop_from_worker_with_least_free_memory(c, s, *nannies): - a1, a2, a3, a4 = s.workers.keys() +@gen_cluster(nthreads=[("", 1)] * 4, client=True, config=demo_config("drop", n=1)) +async def test_drop_from_worker_with_least_free_memory(c, s, *workers): ws1, ws2, ws3, ws4 = s.workers.values() futures = await c.scatter({"x": 1}, broadcast=True) assert s.tasks["x"].who_has == {ws1, ws2, ws3, ws4} - # Allocate enough RAM to be safely more than unmanaged memory - clog = c.submit(lambda: "x" * 2**29, workers=[a3]) # 512 MiB - # await wait(clog) is not enough; we need to wait for the heartbeats - while ws3.memory.optimistic < 2**29: - await asyncio.sleep(0.01) + clog = c.submit(lambda: "x" * 100, workers=[ws3.address]) + await wait(clog) + s.extensions["amm"].run_once() while s.tasks["x"].who_has != {ws1, ws2, ws4}: @@ -612,27 +609,14 @@ async def test_double_replicate_stress(c, s, a, b): await asyncio.sleep(0.01) -@pytest.mark.slow -@gen_cluster( - nthreads=[("", 1)] * 4, - Worker=Nanny, - client=True, - worker_kwargs={"memory_limit": "2 GiB"}, - config=demo_config("replicate", n=1), -) -async def test_replicate_to_worker_with_most_free_memory(c, s, *nannies): - a1, a2, a3, a4 = s.workers.keys() +@gen_cluster(nthreads=[("", 1)] * 4, client=True, config=demo_config("replicate", n=1)) +async def test_replicate_to_worker_with_most_free_memory(c, s, *workers): ws1, ws2, ws3, ws4 = s.workers.values() - futures = await c.scatter({"x": 1}, workers=[a1]) + x = await c.scatter({"x": 1}, workers=[ws1.address]) + clogs = await c.scatter([2, 3], workers=[ws2.address, 
ws4.address]) + assert s.tasks["x"].who_has == {ws1} - # Allocate enough RAM to be safely more than unmanaged memory - clog2 = c.submit(lambda: "x" * 2**29, workers=[a2]) # 512 MiB - clog4 = c.submit(lambda: "x" * 2**29, workers=[a4]) # 512 MiB - # await wait(clog) is not enough; we need to wait for the heartbeats - for ws in (ws2, ws4): - while ws.memory.optimistic < 2**29: - await asyncio.sleep(0.01) s.extensions["amm"].run_once() while s.tasks["x"].who_has != {ws1, ws3}: @@ -701,6 +685,17 @@ async def test_replicate_avoids_paused_workers_2(c, s, a, b): assert "x" not in b.data +@gen_test() +async def test_bad_measure(): + with dask.config.set( + {"distributed.scheduler.active-memory-manager.measure": "notexist"} + ): + with pytest.raises(ValueError) as e: + await Scheduler(dashboard_address=":0") + + assert "measure must be one of " in str(e.value) + + @gen_cluster( nthreads=[("", 1)] * 4, client=True, @@ -789,20 +784,19 @@ async def test_RetireWorker_no_remove(c, s, a, b): assert not s.extensions["amm"].policies -@pytest.mark.slow @pytest.mark.parametrize("use_ReduceReplicas", [False, True]) @gen_cluster( client=True, - Worker=Nanny, config={ "distributed.scheduler.active-memory-manager.start": True, "distributed.scheduler.active-memory-manager.interval": 0.1, + "distributed.scheduler.active-memory-manager.measure": "managed", "distributed.scheduler.active-memory-manager.policies": [ {"class": "distributed.active_memory_manager.ReduceReplicas"}, ], }, ) -async def test_RetireWorker_with_ReduceReplicas(c, s, *nannies, use_ReduceReplicas): +async def test_RetireWorker_with_ReduceReplicas(c, s, *workers, use_ReduceReplicas): """RetireWorker and ReduceReplicas work well with each other. If ReduceReplicas is enabled, @@ -823,12 +817,12 @@ async def test_RetireWorker_with_ReduceReplicas(c, s, *nannies, use_ReduceReplic if not use_ReduceReplicas: s.extensions["amm"].policies.clear() - x = c.submit(lambda: "x" * 2**26, key="x", workers=[ws_a.address]) # 64 MiB - y = c.submit(lambda: "y" * 2**26, key="y", workers=[ws_a.address]) # 64 MiB + x = c.submit(lambda: "x", key="x", workers=[ws_a.address]) + y = c.submit(lambda: "y", key="y", workers=[ws_a.address]) z = c.submit(lambda x: None, x, key="z", workers=[ws_b.address]) # copy x to ws_b # Make sure that the worker NOT being retired has the most RAM usage to test that # it is not being picked first since there's a retiring worker. - w = c.submit(lambda: "w" * 2**28, key="w", workers=[ws_b.address]) # 256 MiB + w = c.submit(lambda: "w" * 100, key="w", workers=[ws_b.address]) await wait([x, y, z, w]) await c.retire_workers([ws_a.address], remove=False) @@ -960,8 +954,9 @@ async def test_RetireWorker_new_keys_arrive_after_all_keys_moved_away(c, s, a, b while not len(ws_b.has_what) == len(xs): await asyncio.sleep(0) - # `_track_retire_worker` _should_ now be sleeping for 0.5s, because there were >=200 keys on A. - # In this test, everything from the beginning of the transfers needs to happen within 0.5s. + # `_track_retire_worker` _should_ now be sleeping for 0.5s, because there were >=200 + # keys on A. In this test, everything from the beginning of the transfers needs to + # happen within 0.5s. # Simulate the policy running again. 
Because the default 2s AMM interval is longer # than the 0.5s wait, what we're about to trigger is unlikely, but still possible @@ -1008,52 +1003,52 @@ async def test_RetireWorker_new_keys_arrive_after_all_keys_moved_away(c, s, a, b await extra.result() -# FIXME can't drop runtime of this test below 10s; see distributed#5585 -@pytest.mark.slow @gen_cluster( client=True, - Worker=Nanny, - nthreads=[("", 1)] * 3, config={ "distributed.scheduler.worker-ttl": "500ms", "distributed.scheduler.active-memory-manager.start": True, - "distributed.scheduler.active-memory-manager.interval": 0.1, + "distributed.scheduler.active-memory-manager.interval": 0.05, + "distributed.scheduler.active-memory-manager.measure": "managed", "distributed.scheduler.active-memory-manager.policies": [], }, ) -async def test_RetireWorker_faulty_recipient(c, s, *nannies): - """RetireWorker requests to replicate a key onto a unresponsive worker. +async def test_RetireWorker_faulty_recipient(c, s, w1, w2): + """RetireWorker requests to replicate a key onto an unresponsive worker. The AMM will iterate multiple times, repeating the command, until eventually the scheduler declares the worker dead and removes it from the pool; at that point the AMM will choose another valid worker and complete the job. """ - # ws1 is being retired - # ws2 has the lowest RAM usage and is chosen as a recipient, but is unresponsive - ws1, ws2, ws3 = s.workers.values() - f = c.submit(lambda: "x", key="x", workers=[ws1.address]) - await wait(f) - assert s.tasks["x"].who_has == {ws1} + # w1 is being retired + # w3 has the lowest RAM usage and is chosen as a recipient, but is unresponsive + + x = c.submit(lambda: 123, key="x", workers=[w1.address]) + await wait(x) + # Fill w2 with dummy data so that it's got the highest memory usage + clutter = await c.scatter(456, workers=[w2.address]) + + async with BlockedGatherDep(s.address) as w3: + await c.wait_for_workers(3) + + retire_fut = asyncio.create_task(c.retire_workers([w1.address])) + # w3 is chosen as the recipient for x, because it's got the lowest memory usage + await w3.in_gather_dep.wait() + + # AMM unfruitfully sends to w3 a new {op: acquire-replicas} message every 0.05s + while ( + sum(isinstance(ev, AcquireReplicasEvent) for ev in w3.state.stimulus_log) + < 3 + ): + await asyncio.sleep(0.01) - # Fill ws3 with 200 MB of managed memory - # We're using plenty to make sure it's safely more than the unmanaged memory of ws2 - clutter = c.map(lambda i: "x" * 4_000_000, range(50), workers=[ws3.address]) - await wait([f] + clutter) - while ws3.memory.process < 200_000_000: - # Wait for heartbeat - await asyncio.sleep(0.01) - assert ws2.memory.process < ws3.memory.process + assert not retire_fut.done() - # Make ws2 unresponsive - clog_fut = asyncio.create_task(c.run(sleep, 3600, workers=[ws2.address])) - await asyncio.sleep(0.2) - assert ws2.address in s.workers + # w3 has been shut down. At this point, AMM switches to w2. 
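+        # (It is the scheduler's worker-ttl of 500ms, set in this test's config, that
+        # declares the blocked w3 dead: w3 never answers the gather_dep request, so
+        # once it is removed the AMM retries with w2 and the retirement can finish.)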
+ await retire_fut - await c.retire_workers([ws1.address]) - assert ws1.address not in s.workers - # The AMM tried over and over to send the data to ws2, until it was declared dead - assert ws2.address not in s.workers - assert s.tasks["x"].who_has == {ws3} - clog_fut.cancel() + assert w1.address not in s.workers + assert w3.address not in s.workers + assert dict(w2.data) == {"x": 123, clutter.key: 456} class DropEverything(ActiveMemoryManagerPolicy): @@ -1082,20 +1077,21 @@ async def tensordot_stress(c): da = pytest.importorskip("dask.array") rng = da.random.RandomState(0) - a = rng.random((20, 20), chunks=(1, 1)) - b = (a @ a.T).sum().round(3) - assert await c.compute(b) == 2134.398 + a = rng.random((10, 10), chunks=(1, 1)) + # dask.array.core.PerformanceWarning: Increasing number of chunks by factor of 10 + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + b = (a @ a.T).sum().round(3) + assert await c.compute(b) == 245.394 @pytest.mark.slow -@pytest.mark.avoid_ci(reason="distributed#5371") @gen_cluster( client=True, nthreads=[("", 1)] * 4, - Worker=Nanny, config=NO_AMM_START, ) -async def test_noamm_stress(c, s, *nannies): +async def test_noamm_stress(c, s, *workers): """Test the tensordot_stress helper without AMM. This is to figure out if a stability issue is AMM-specific or not. """ @@ -1103,20 +1099,19 @@ async def test_noamm_stress(c, s, *nannies): @pytest.mark.slow -@pytest.mark.avoid_ci(reason="distributed#5371") @gen_cluster( client=True, nthreads=[("", 1)] * 4, - Worker=Nanny, config={ "distributed.scheduler.active-memory-manager.start": True, "distributed.scheduler.active-memory-manager.interval": 0.1, + "distributed.scheduler.active-memory-manager.measure": "managed", "distributed.scheduler.active-memory-manager.policies": [ {"class": "distributed.tests.test_active_memory_manager.DropEverything"}, ], }, ) -async def test_drop_stress(c, s, *nannies): +async def test_drop_stress(c, s, *workers): """A policy which suggests dropping everything won't break a running computation, but only slow it down. @@ -1126,20 +1121,19 @@ async def test_drop_stress(c, s, *nannies): @pytest.mark.slow -@pytest.mark.avoid_ci(reason="distributed#5371") @gen_cluster( client=True, nthreads=[("", 1)] * 4, - Worker=Nanny, config={ "distributed.scheduler.active-memory-manager.start": True, "distributed.scheduler.active-memory-manager.interval": 0.1, + "distributed.scheduler.active-memory-manager.measure": "managed", "distributed.scheduler.active-memory-manager.policies": [ {"class": "distributed.active_memory_manager.ReduceReplicas"}, ], }, ) -async def test_ReduceReplicas_stress(c, s, *nannies): +async def test_ReduceReplicas_stress(c, s, *workers): """Running ReduceReplicas compulsively won't break a running computation. Unlike test_drop_stress above, this test does not stop running after a few seconds - the policy must not disrupt the computation too much. @@ -1148,19 +1142,14 @@ async def test_ReduceReplicas_stress(c, s, *nannies): @pytest.mark.slow -@pytest.mark.avoid_ci(reason="distributed#5371") @pytest.mark.parametrize("use_ReduceReplicas", [False, True]) @gen_cluster( client=True, nthreads=[("", 1)] * 10, - Worker=Nanny, config={ "distributed.scheduler.active-memory-manager.start": True, - # If interval is too low, then the AMM will rerun while tasks have not yet have - # the time to migrate. This is OK if it happens occasionally, but if this - # setting is too aggressive the cluster will get flooded with repeated comm - # requests. 
- "distributed.scheduler.active-memory-manager.interval": 2.0, + "distributed.scheduler.active-memory-manager.interval": 0.1, + "distributed.scheduler.active-memory-manager.measure": "managed", "distributed.scheduler.active-memory-manager.policies": [ {"class": "distributed.active_memory_manager.ReduceReplicas"}, ], @@ -1168,7 +1157,7 @@ async def test_ReduceReplicas_stress(c, s, *nannies): scheduler_kwargs={"transition_counter_max": 500_000}, worker_kwargs={"transition_counter_max": 500_000}, ) -async def test_RetireWorker_stress(c, s, *nannies, use_ReduceReplicas): +async def test_RetireWorker_stress(c, s, *workers, use_ReduceReplicas): """It is safe to retire the best part of a cluster in the middle of a computation""" if not use_ReduceReplicas: s.extensions["amm"].policies.clear() diff --git a/docs/source/active_memory_manager.rst b/docs/source/active_memory_manager.rst index 10eb9c92903..cf5416b33ac 100644 --- a/docs/source/active_memory_manager.rst +++ b/docs/source/active_memory_manager.rst @@ -36,6 +36,7 @@ The AMM can be enabled through the :doc:`Dask configuration file active-memory-manager: start: true interval: 2s + measure: optimistic The above is the recommended setup and will run all enabled *AMM policies* (see below) every two seconds. Alternatively, you can manually start/stop the AMM from the @@ -79,6 +80,7 @@ Individual policies are enabled, disabled, and configured through the Dask confi active-memory-manager: start: true interval: 2s + measure: optimistic policies: - class: distributed.active_memory_manager.ReduceReplicas - class: my_package.MyPolicy From e667a9bb76464b0fa550f7e7d336b433d5eb8f65 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 28 Sep 2022 12:55:32 +0100 Subject: [PATCH 2/2] Update distributed/tests/test_active_memory_manager.py --- distributed/tests/test_active_memory_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py index d18c7894531..c07765f41e5 100644 --- a/distributed/tests/test_active_memory_manager.py +++ b/distributed/tests/test_active_memory_manager.py @@ -1025,6 +1025,7 @@ async def test_RetireWorker_faulty_recipient(c, s, w1, w2): x = c.submit(lambda: 123, key="x", workers=[w1.address]) await wait(x) # Fill w2 with dummy data so that it's got the highest memory usage + # among the workers that are not being retired (w2 and w3). clutter = await c.scatter(456, workers=[w2.address]) async with BlockedGatherDep(s.address) as w3: