diff --git a/distributed/nanny.py b/distributed/nanny.py
index 55c7838d8f3..65a2d303785 100644
--- a/distributed/nanny.py
+++ b/distributed/nanny.py
@@ -13,7 +13,7 @@
 from inspect import isawaitable
 from queue import Empty
 from time import sleep as sync_sleep
-from typing import ClassVar
+from typing import TYPE_CHECKING, ClassVar, Literal
 
 import psutil
 from tornado import gen
@@ -45,6 +45,9 @@
 )
 from .worker import Worker, parse_memory_limit, run
 
+if TYPE_CHECKING:
+    from .diagnostics.plugin import NannyPlugin
+
 logger = logging.getLogger(__name__)
 
 
@@ -94,6 +97,7 @@ def __init__(
         services=None,
         name=None,
         memory_limit="auto",
+        memory_terminate_fraction: float | Literal[False] | None = None,
         reconnect=True,
         validate=False,
         quiet=False,
@@ -203,8 +207,10 @@ def __init__(
         self.worker_kwargs = worker_kwargs
         self.contact_address = contact_address
 
-        self.memory_terminate_fraction = dask.config.get(
-            "distributed.worker.memory.terminate"
+        self.memory_terminate_fraction = (
+            memory_terminate_fraction
+            if memory_terminate_fraction is not None
+            else dask.config.get("distributed.worker.memory.terminate")
         )
 
         self.services = services
@@ -231,7 +237,7 @@ def __init__(
             "plugin_remove": self.plugin_remove,
         }
 
-        self.plugins = {}
+        self.plugins: dict[str, NannyPlugin] = {}
 
         super().__init__(
             handlers=handlers, io_loop=self.loop, connection_args=self.connection_args
diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py
index db5cc866b5c..97ee33104e5 100755
--- a/distributed/tests/test_utils_test.py
+++ b/distributed/tests/test_utils_test.py
@@ -15,7 +15,7 @@
 
 from distributed import Client, Nanny, Scheduler, Worker, config, default_client
 from distributed.compatibility import WINDOWS
-from distributed.core import Server, rpc
+from distributed.core import Server, Status, rpc
 from distributed.metrics import time
 from distributed.utils import mp_context
 from distributed.utils_test import (
@@ -28,6 +28,7 @@
     gen_cluster,
     gen_test,
     inc,
+    mock_rss,
     new_config,
     tls_only_security,
 )
@@ -607,3 +608,27 @@ def test_start_failure_scheduler():
     with pytest.raises(TypeError):
         with cluster(scheduler_kwargs={"foo": "bar"}):
             return
+
+
+@gen_cluster(
+    client=True,
+    worker_kwargs={"heartbeat_interval": "10ms", "memory_monitor_interval": "10ms"},
+)
+async def test_mock_rss(c, s, a, b):
+    # Test that it affects the readings sent to the Scheduler
+    mock_rss(a, 2e6)
+    while s.workers[a.address].memory.process != 2_000_000:
+        await asyncio.sleep(0.01)
+
+    # Test that the instance has been mocked, not the class
+    assert s.workers[b.address].memory.process > 10e6
+
+    # Test that it's compatible with Client.run and can be used with Nannies
+    await c.run(mock_rss, nbytes=3e6, workers=[b.address])
+    while s.workers[b.address].memory.process != 3_000_000:
+        await asyncio.sleep(0.01)
+
+    # Test that it affects Worker.memory_monitor
+    mock_rss(a, 100e9)
+    while s.workers[a.address].status != Status.paused:
+        await asyncio.sleep(0.01)
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index fc024f3b98a..f19acb724c1 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -27,6 +27,7 @@
 import distributed
 from distributed import (
     Client,
+    Event,
     Nanny,
     Reschedule,
     default_client,
@@ -53,6 +54,7 @@
     gen_cluster,
     gen_test,
     inc,
+    mock_rss,
     mul,
     nodebug,
     slowinc,
@@ -1287,6 +1289,7 @@ async def test_spill_constrained(c, s, w):
     nthreads=[("", 1)],
     client=True,
     worker_kwargs=dict(
+        memory_limit="1000 MB",
         memory_monitor_interval="10ms",
         memory_target_fraction=False,
         memory_spill_fraction=0.7,
@@ -1298,157 +1301,74 @@ async def test_spill_spill_threshold(c, s, a):
     Test that the spill threshold uses the process memory and not the managed memory
     reported by sizeof(), which may be inaccurate.
     """
-    # Reach 'spill' threshold after 400MB of managed data. We need to be generous in
-    # order to avoid flakiness due to fluctuations in unmanaged memory.
-    # FIXME https://github.com/dask/distributed/issues/5367
-    # This works just by luck for the purpose of the spill and pause thresholds,
-    # and does NOT work for the target threshold.
-    memory = psutil.Process().memory_info().rss
-    a.memory_limit = (memory + 300e6) / 0.7
-
-    class UnderReport:
-        """100 MB process memory, 10 bytes reported managed memory"""
-
-        def __init__(self, *args):
-            self.data = "x" * int(100e6)
-
-        def __sizeof__(self):
-            return 10
-
-        def __reduce__(self):
-            """Speed up test by writing very little to disk when spilling"""
-            return UnderReport, ()
-
-    futures = c.map(UnderReport, range(8))
-
+    x = c.submit(inc, 0, key="x")
+    mock_rss(a, 800e6)
     while not a.data.disk:
         await asyncio.sleep(0.01)
-
-
-async def assert_not_everything_is_spilled(w: Worker) -> None:
-    start = time()
-    while time() < start + 0.5:
-        assert w.data
-        if not w.data.memory:  # type: ignore
-            # The hysteresis system fails on Windows and MacOSX because process memory
-            # is very slow to shrink down after calls to PyFree. As a result,
-            # Worker.memory_monitor will continue spilling until there's nothing left.
-            # Nothing we can do about this short of finding either a way to change this
-            # behaviour at OS level or a better measure of allocated memory.
-            assert not LINUX, "All data was spilled to disk"
-            raise pytest.xfail("https://github.com/dask/distributed/issues/5840")
-        await asyncio.sleep(0)
+    assert await x == 1
 
 
 @requires_zict
-@gen_cluster(
-    nthreads=[("", 1)],
-    client=True,
-    worker_kwargs=dict(
-        # FIXME https://github.com/dask/distributed/issues/5367
-        # Can't reconfigure the absolute target threshold after the worker
-        # started, so we're setting it here to something extremely small and then
-        # increasing the memory_limit dynamically below in order to test the
-        # spill threshold.
-        memory_limit=1,
-        memory_monitor_interval="10ms",
-        memory_target_fraction=False,
-        memory_spill_fraction=0.7,
-        memory_pause_fraction=False,
-    ),
+@pytest.mark.parametrize(
+    "memory_target_fraction,managed,expect_spilled",
+    [
+        # no target -> no hysteresis
+        # Over-report managed memory to test that the automated LRU eviction based on
+        # target is never triggered
+        (False, 10e9, 1),
+        # Under-report managed memory, so that we reach the spill threshold for process
+        # memory without first reaching the target threshold for managed memory
+        # target == spill -> no hysteresis
+        (0.7, 0, 1),
+        # target < spill -> hysteresis from spill to target
+        (0.4, 0, 7),
+    ],
 )
-async def test_spill_no_target_threshold(c, s, a):
-    """Test that you can enable the spill threshold while leaving the target threshold
-    to False
+@gen_cluster(nthreads=[], client=True)
+async def test_spill_hysteresis(c, s, memory_target_fraction, managed, expect_spilled):
+    """
+    1. Test that you can enable the spill threshold while leaving the target threshold
+       to False
+    2. Test the hysteresis system where, once you reach the spill threshold, the worker
+       won't stop spilling until the target threshold is reached
     """
-    memory = psutil.Process().memory_info().rss
-    a.memory_limit = (memory + 300e6) / 0.7  # 300 MB before we start spilling
-
-    class OverReport:
-        """Configurable process memory, 10 GB reported managed memory"""
-
-        def __init__(self, size):
-            self.data = "x" * size
 
+    class C:
         def __sizeof__(self):
-            return int(10e9)
-
-        def __reduce__(self):
-            """Speed up test by writing very little to disk when spilling"""
-            return OverReport, (len(self.data),)
-
-    f1 = c.submit(OverReport, 0, key="f1")
-    await wait(f1)
-    assert set(a.data.memory) == {"f1"}
-
-    # 800 MB. Use large chunks to stimulate timely release of process memory.
-    futures = c.map(OverReport, range(int(100e6), int(100e6) + 8))
-
-    while not a.data.disk:
-        await asyncio.sleep(0.01)
-    assert "f1" in a.data.disk
-
-    # Spilling normally starts at the spill threshold and stops at the target threshold.
-    # In this special case, it stops as soon as the process memory goes below the spill
-    # threshold, e.g. without a hysteresis cycle. Test that we didn't instead dump the
-    # whole data to disk (memory_limit * target = 0)
-    await assert_not_everything_is_spilled(a)
-
+            return managed
 
-@pytest.mark.slow
-@requires_zict
-@gen_cluster(
-    nthreads=[("", 1)],
-    client=True,
-    worker_kwargs=dict(
-        memory_limit="1 GiB",  # See FIXME note in previous test
+    async with Worker(
+        s.address,
+        memory_limit="1000 MB",
         memory_monitor_interval="10ms",
-        memory_target_fraction=0.4,
+        memory_target_fraction=memory_target_fraction,
         memory_spill_fraction=0.7,
         memory_pause_fraction=False,
-    ),
-)
-async def test_spill_hysteresis(c, s, a):
-    memory = psutil.Process().memory_info().rss
-    a.memory_limit = (memory + 1e9) / 0.7  # Start spilling after 1 GB
-
-    # Under-report managed memory, so that we reach the spill threshold for process
-    # memory without first reaching the target threshold for managed memory
-    class UnderReport:
-        def __init__(self):
-            self.data = "x" * int(100e6)  # 100 MB
-
-        def __sizeof__(self):
-            return 1
-
-        def __reduce__(self):
-            """Speed up test by writing very little to disk when spilling"""
-            return UnderReport, ()
-
-    max_in_memory = 0
-    futures = []
-    while not a.data.disk:
-        futures.append(c.submit(UnderReport, pure=False))
-        max_in_memory = max(max_in_memory, len(a.data.memory))
+    ) as a:
+        # Add 500MB (reported) process memory. Spilling must not happen.
+        futures = [c.submit(C, pure=False) for _ in range(10)]
+        mock_rss(a, 500e6)
         await wait(futures)
-        await asyncio.sleep(0.05)
-        max_in_memory = max(max_in_memory, len(a.data.memory))
-
-    # If there were no hysteresis, we would lose exactly 1 key.
-    # Note that, for this test to be meaningful, memory must shrink down readily when
-    # we deallocate Python objects. This is not always the case on Windows and MacOSX;
-    # on Linux we set MALLOC_TRIM to help in that regard.
-    # To verify that this test is useful, set target=spill and watch it fail.
-    while len(a.data.memory) > max_in_memory - 3:
-        await asyncio.sleep(0.01)
-    await assert_not_everything_is_spilled(a)
+        await asyncio.sleep(0.1)
+        assert not a.data.disk
+
+        # Add another 250MB unmanaged memory. This must trigger the spilling.
+        mock_rss(a, 750e6)
+        # Wait until spilling starts. Then, wait until it stops.
+        prev_n = 0
+        while not a.data.disk or len(a.data.disk) > prev_n:
+            prev_n = len(a.data.disk)
+            mock_rss(a, 250e6 + 50e6 * len(a.data.memory))
+            await asyncio.sleep(0)
+
+        assert len(a.data.disk) == expect_spilled
 
 
-@pytest.mark.slow
 @gen_cluster(
-    nthreads=[("", 1)],
+    nthreads=[("", 2)],
     client=True,
     worker_kwargs=dict(
+        memory_limit="1000 MB",
         memory_monitor_interval="10ms",
         memory_target_fraction=False,
         memory_spill_fraction=False,
@@ -1456,35 +1376,62 @@ def __reduce__(self):
     ),
 )
 async def test_pause_executor(c, s, a):
-    # See notes in test_spill_spill_threshold
-    memory = psutil.Process().memory_info().rss
-    a.memory_limit = (memory + 160e6) / 0.8  # Pause after 200 MB
+    def f(ev_f):
+        ev_f.wait()
 
-    # Note: it's crucial to have a very large single chunk of memory that gets descoped
-    # all at once in order to instigate release of process memory.
-    # Read: https://github.com/dask/distributed/issues/5840
-    def f():
-        # Add 400 MB unmanaged memory
-        x = "x" * int(400e6)
+    def g(ev_g1, ev_g2):
+        ev_g1.wait()
+        # Add 900 MB unmanaged memory
         w = get_worker()
-        while w.status != Status.paused:
-            sleep(0.01)
+        mock_rss(w, 900e6)
+        ev_g2.wait()
+        mock_rss(w, 0)
+
+    ev_f = Event()
+    ev_g1 = Event()
+    ev_g2 = Event()
+
+    # Tasks that are running when the worker pauses
+    x = c.submit(f, ev_f, key="x")
+    y = c.submit(g, ev_g1, ev_g2, key="y")
+    while a.executing_count != 2:
+        await asyncio.sleep(0.01)
 
-    with captured_logger(logging.getLogger("distributed.worker")) as logger:
-        future = c.submit(f, key="x")
-        futures = c.map(slowinc, range(30), delay=0.1)
+    # Task that is queued on the worker when the worker pauses
+    z = c.submit(inc, 0, key="z")
+    while "z" not in a.tasks:
+        await asyncio.sleep(0.01)
 
-        while a.status != Status.paused:
+    with captured_logger(logging.getLogger("distributed.worker")) as logger:
+        # Hog the worker with 900MB memory
+        await ev_g1.set()
+        while s.workers[a.address].status != Status.paused:
             await asyncio.sleep(0.01)
         assert "Pausing worker" in logger.getvalue()
-        assert sum(f.status == "finished" for f in futures) < 4
 
-        while a.status != Status.running:
+        # Task that is queued on the scheduler when the worker pauses.
+        # It is not sent to the worker
+        w = c.submit(inc, 0, key="w")
+        while "w" not in s.tasks or s.tasks["w"].state != "no-worker":
             await asyncio.sleep(0.01)
+
+        # Unlock a slot on the worker. It won't be used.
+        await ev_f.set()
+        await x
+        await asyncio.sleep(0.05)
+
+        assert a.executing_count == 1
+        assert len(a.ready) == 1
+        assert a.tasks["z"].state == "ready"
+        assert "w" not in a.tasks
+
+        # Release the memory
+        await ev_g2.set()
+        await wait([y, z, w])
+
+        assert a.status == Status.running
         assert "Resuming worker" in logger.getvalue()
-        await wait(futures)
 
 
 @gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": "50 ms"})
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index f6625e74180..d56b8e2b85c 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -20,13 +20,14 @@
 import threading
 import uuid
 import weakref
-from collections import defaultdict
+from collections import defaultdict, namedtuple
 from collections.abc import Callable
 from contextlib import contextmanager, nullcontext, suppress
 from glob import glob
 from itertools import count
 from time import sleep
 from typing import Any, Literal
+from unittest.mock import MagicMock
 
 from distributed.compatibility import MACOS
 from distributed.scheduler import Scheduler
@@ -1965,3 +1966,20 @@ def has_pytestmark(test_func: Callable, name: str) -> bool:
     """
     marks = getattr(test_func, "pytestmark", [])
     return any(mark.name == name for mark in marks)
+
+
+# Variant of psutil._pslinux.pmem, psutil._psosx.pmem, psutil._pswindows.pmem
+pmem = namedtuple("pmem", "rss")
+
+
+def mock_rss(dask_worker: Worker, nbytes: float) -> None:
+    """Mock all the process memory readings on a worker. Does not impact other workers.
+
+    Usage:
+
+    When using Workers:
+    >>> mock_rss(a, 100e6)
+    When using Nannies:
+    >>> await client.run(mock_rss, nbytes=100e6, workers=[a.worker_address])
+    """
+    dask_worker.monitor.proc.memory_info = MagicMock(return_value=pmem(int(nbytes)))
diff --git a/distributed/worker.py b/distributed/worker.py
index 3195fe38e0c..6d33c21de2d 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -3743,7 +3743,6 @@ def check_pause(memory):
                 "Worker is at %.0f%% memory usage. Start spilling data to disk.",
                 frac * 100,
             )
-            start = time()
             # Implement hysteresis cycle where spilling starts at the spill threshold
             # and stops at the target threshold. Normally that here the target threshold
             # defines process memory, whereas normally it defines reported managed
@@ -3768,18 +3767,14 @@ def check_pause(memory):
                     break
                 weight = self.data.evict()
                 if weight == -1:
-                    # Failed to evict: disk full, spill size limit exceeded, or pickle error
+                    # Failed to evict:
+                    # disk full, spill size limit exceeded, or pickle error
                     break
                 total += weight
                 count += 1
-                # If the current buffer is filled with a lot of small values,
-                # evicting one at a time is very slow and the worker might
-                # generate new data faster than it is able to evict. Therefore,
-                # only pass on control if we spent at least 0.5s evicting
-                if time() - start > 0.5:
-                    await asyncio.sleep(0)
-                    start = time()
+                await asyncio.sleep(0)
+
                 memory = proc.memory_info().rss
             if total > need and memory > target:
                 # Issue a GC to ensure that the evicted data is actually
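
Reviewer note (not part of the patch): the sketch below shows how the new mock_rss helper is meant to be driven from a test, mirroring the usage already exercised by test_mock_rss and documented in the helper's docstring. The fixture names c, s, a, b follow the gen_cluster convention used throughout these tests; the test name is illustrative only, and the pause behaviour assumes the default distributed.worker.memory.pause fraction.

    import asyncio

    from distributed.core import Status
    from distributed.utils_test import gen_cluster, mock_rss


    @gen_cluster(client=True, worker_kwargs={"memory_monitor_interval": "10ms"})
    async def test_mock_rss_sketch(c, s, a, b):
        # Patch the RSS reading of worker `a` only; worker `b` keeps real readings
        mock_rss(a, 100e9)
        # With the default pause fraction, the memory monitor should pause `a`
        while s.workers[a.address].status != Status.paused:
            await asyncio.sleep(0.01)

        # For workers running behind a Nanny, apply the mock inside the worker process
        await c.run(mock_rss, nbytes=100e6, workers=[b.address])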
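
Similarly, the nanny.py hunk adds a memory_terminate_fraction keyword that, when not None, takes precedence over the distributed.worker.memory.terminate config value. A minimal sketch of exercising it, assuming the usual gen_cluster fixtures; the test name and the choice of False (which disables the terminate check) are illustrative, not part of the diff.

    from distributed import Nanny
    from distributed.utils_test import gen_cluster


    @gen_cluster(client=True, nthreads=[])
    async def test_nanny_terminate_override_sketch(c, s):
        # Passing the kwarg bypasses the dask config lookup;
        # False disables the terminate check for this nanny only
        async with Nanny(s.address, nthreads=1, memory_terminate_fraction=False) as n:
            assert n.memory_terminate_fraction is False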