diff --git a/distributed/client.py b/distributed/client.py
index e63a3173502..115aa1a6be3 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -1469,8 +1469,9 @@ def _handle_restart(self):
         for state in self.futures.values():
             state.cancel()
         self.futures.clear()
-        with suppress(AttributeError):
-            self._restart_event.set()
+        self.generation += 1
+        with self._refcount_lock:
+            self.refcount.clear()

     def _handle_error(self, exception=None):
         logger.warning("Scheduler exception:")
@@ -3319,32 +3320,46 @@ def persist(
         else:
             return result

-    async def _restart(self, timeout=no_default):
+    async def _restart(self, timeout=no_default, wait_for_workers=True):
         if timeout == no_default:
-            timeout = self._timeout * 2
+            timeout = self._timeout * 4
         if timeout is not None:
             timeout = parse_timedelta(timeout, "s")

-        self._send_to_scheduler({"op": "restart", "timeout": timeout})
-        self._restart_event = asyncio.Event()
-        try:
-            await asyncio.wait_for(self._restart_event.wait(), timeout)
-        except TimeoutError:
-            logger.error("Restart timed out after %.2f seconds", timeout)
+        await self.scheduler.restart(timeout=timeout, wait_for_workers=wait_for_workers)
+        return self

-        self.generation += 1
-        with self._refcount_lock:
-            self.refcount.clear()
+    def restart(self, timeout=no_default, wait_for_workers=True):
+        """
+        Restart all workers. Reset local state. Optionally wait for workers to return.

-        return self
+        Workers without nannies are shut down, hoping an external deployment system
+        will restart them. Therefore, if not using nannies and your deployment system
+        does not automatically restart workers, ``restart`` will just shut down all
+        workers, then time out!

-    def restart(self, **kwargs):
-        """Restart the distributed network
+        After `restart`, all connected workers are new, regardless of whether `TimeoutError`
+        was raised. Any workers that failed to shut down in time are removed, and
+        may or may not shut down on their own in the future.

-        This kills all active work, deletes all data on the network, and
-        restarts the worker processes.
+        Parameters
+        ----------
+        timeout:
+            How long to wait for workers to shut down and come back, if `wait_for_workers`
+            is True, otherwise just how long to wait for workers to shut down.
+            Raises `asyncio.TimeoutError` if this is exceeded.
+        wait_for_workers:
+            Whether to wait for all workers to reconnect, or just for them to shut down
+            (default True). Use ``restart(wait_for_workers=False)`` combined with
+            `Client.wait_for_workers` for granular control over how many workers to
+            wait for.
+
+        See also
+        --------
+        Scheduler.restart
         """
-        return self.sync(self._restart, **kwargs)
+        return self.sync(
+            self._restart, timeout=timeout, wait_for_workers=wait_for_workers
+        )

     async def _upload_large_file(self, local_filename, remote_filename=None):
         if remote_filename is None:
diff --git a/distributed/nanny.py b/distributed/nanny.py
index 2eae45bc3d8..7818efee85d 100644
--- a/distributed/nanny.py
+++ b/distributed/nanny.py
@@ -379,7 +379,7 @@ async def kill(self, timeout=2):
             informed
         """
         if self.process is None:
-            return "OK"
+            return

         deadline = time() + timeout
         await self.process.kill(timeout=0.8 * (deadline - time()))
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index ab5fb003ed8..775732bb8bd 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -73,7 +73,7 @@
 from distributed.event import EventExtension
 from distributed.http import get_handlers
 from distributed.lock import LockExtension
-from distributed.metrics import time
+from distributed.metrics import monotonic, time
 from distributed.multi_lock import MultiLockExtension
 from distributed.node import ServerNode
 from distributed.proctitle import setproctitle
@@ -3077,7 +3077,6 @@ def __init__(
             "client-releases-keys": self.client_releases_keys,
             "heartbeat-client": self.client_heartbeat,
             "close-client": self.remove_client,
-            "restart": self.restart,
             "subscribe-topic": self.subscribe_topic,
             "unsubscribe-topic": self.unsubscribe_topic,
         }
@@ -3114,6 +3113,7 @@ def __init__(
             "rebalance": self.rebalance,
             "replicate": self.replicate,
             "run_function": self.run_function,
+            "restart": self.restart,
             "update_data": self.update_data,
             "set_resources": self.add_resources,
             "retire_workers": self.retire_workers,
@@ -5112,12 +5112,37 @@ def clear_task_state(self):
             collection.clear()

     @log_errors
-    async def restart(self, client=None, timeout=30):
-        """Restart all workers. Reset local state."""
+    async def restart(self, client=None, timeout=30, wait_for_workers=True):
+        """
+        Restart all workers. Reset local state. Optionally wait for workers to return.
+
+        Workers without nannies are shut down, hoping an external deployment system
+        will restart them. Therefore, if not using nannies and your deployment system
+        does not automatically restart workers, ``restart`` will just shut down all
+        workers, then time out!
+
+        After `restart`, all connected workers are new, regardless of whether `TimeoutError`
+        was raised. Any workers that failed to shut down in time are removed, and
+        may or may not shut down on their own in the future.
+
+        Parameters
+        ----------
+        timeout:
+            How long to wait for workers to shut down and come back, if `wait_for_workers`
+            is True, otherwise just how long to wait for workers to shut down.
+            Raises `asyncio.TimeoutError` if this is exceeded.
+        wait_for_workers:
+            Whether to wait for all workers to reconnect, or just for them to shut down
+            (default True). Use ``restart(wait_for_workers=False)`` combined with
+            `Client.wait_for_workers` for granular control over how many workers to
+            wait for.
+
+        See also
+        --------
+        Client.restart
+        """
         stimulus_id = f"restart-{time()}"
-        n_workers = len(self.workers)
-
-        logger.info("Send lost future signal to clients")
+        logger.info("Releasing all requested keys")
         for cs in self.clients.values():
             self.client_releases_keys(
                 keys=[ts.key for ts in cs.wants_what],
@@ -5125,10 +5150,21 @@ async def restart(self, client=None, timeout=30):
                 stimulus_id=stimulus_id,
             )

+        self.clear_task_state()
+        self.erred_tasks.clear()
+        self.computations.clear()
+        self.report({"op": "restart"})
+
+        for plugin in list(self.plugins.values()):
+            try:
+                plugin.restart(self)
+            except Exception as e:
+                logger.exception(e)
+
+        n_workers = len(self.workers)
         nanny_workers = {
             addr: ws.nanny for addr, ws in self.workers.items() if ws.nanny
         }
-
         # Close non-Nanny workers. We have no way to restart them, so we just let them go,
         # and assume a deployment system is going to restart them for us.
         await asyncio.gather(
@@ -5139,59 +5175,76 @@ async def restart(self, client=None, timeout=30):
             )
         )

-        self.clear_task_state()
-
-        for plugin in list(self.plugins.values()):
-            try:
-                plugin.restart(self)
-            except Exception as e:
-                logger.exception(e)
-
         logger.debug("Send kill signal to nannies: %s", nanny_workers)
         async with contextlib.AsyncExitStack() as stack:
-            nannies = [
-                await stack.enter_async_context(
-                    rpc(nanny_address, connection_args=self.connection_args)
+            nannies = await asyncio.gather(
+                *(
+                    stack.enter_async_context(
+                        rpc(nanny_address, connection_args=self.connection_args)
+                    )
+                    for nanny_address in nanny_workers.values()
                 )
-                for nanny_address in nanny_workers.values()
-            ]
+            )

-            try:
-                resps = await asyncio.wait_for(
-                    asyncio.gather(
-                        *(
-                            nanny.restart(close=True, timeout=timeout * 0.8)
-                            for nanny in nannies
-                        )
-                    ),
-                    timeout,
-                )
-                # NOTE: the `WorkerState` entries for these workers will be removed
-                # naturally when they disconnect from the scheduler.
-            except TimeoutError:
-                logger.error(
-                    "Nannies didn't report back restarted within "
-                    "timeout. Continuing with restart process"
-                )
-            else:
-                if not all(resp == "OK" for resp in resps):
-                    logger.error(
-                        "Not all workers responded positively: %s",
-                        resps,
-                        exc_info=True,
+            start = monotonic()
+            resps = await asyncio.gather(
+                *(
+                    asyncio.wait_for(
+                        # FIXME does not raise if the process fails to shut down,
+                        # see https://github.com/dask/distributed/pull/6427/files#r894917424
+                        # NOTE: Nanny will automatically restart worker process when it's killed
+                        nanny.kill(timeout=timeout),
+                        timeout,
                     )
+                    for nanny in nannies
+                ),
+                return_exceptions=True,
+            )
+            # NOTE: the `WorkerState` entries for these workers will be removed
+            # naturally when they disconnect from the scheduler.

-        self.clear_task_state()
+            # Remove any workers that failed to shut down, so we can guarantee
+            # that after `restart`, there are no old workers around.
+            bad_nannies = [
+                addr for addr, resp in zip(nanny_workers, resps) if resp is not None
+            ]
+            if bad_nannies:
+                await asyncio.gather(
+                    *(
+                        self.remove_worker(addr, stimulus_id=stimulus_id)
+                        for addr in bad_nannies
+                    )
+                )

-        self.erred_tasks.clear()
-        self.computations.clear()
+                raise TimeoutError(
+                    f"{len(bad_nannies)}/{len(nannies)} nanny worker(s) did not shut down within {timeout}s"
+                )

         self.log_event([client, "all"], {"action": "restart", "client": client})

-        start = time()
-        while time() < start + 10 and len(self.workers) < n_workers:
-            await asyncio.sleep(0.01)
-        self.report({"op": "restart"})
+        if wait_for_workers:
+            while monotonic() < start + timeout:
+                # NOTE: if new (unrelated) workers join while we're waiting, we may return before
+                # our shut-down workers have come back up. That's fine; workers are interchangeable.
+                if len(self.workers) >= n_workers:
+                    return
+                await asyncio.sleep(0.2)
+            else:
+                msg = (
+                    f"Waited for {n_workers} worker(s) to reconnect after restarting, "
+                    f"but after {timeout}s, only {len(self.workers)} have returned. "
+                    "Consider a longer timeout, or `wait_for_workers=False`."
+                )
+
+                if (n_nanny := len(nanny_workers)) < n_workers:
+                    msg += (
+                        f" The {n_workers - n_nanny} worker(s) not using Nannies were just shut "
+                        "down instead of restarted (restart is only possible with Nannies). If "
+                        "your deployment system does not automatically re-launch terminated "
+                        "processes, then those workers will never come back, and `Client.restart` "
+                        "will always time out. Do not use `Client.restart` in that case."
+                    )
+                raise TimeoutError(msg) from None

     async def broadcast(
         self,
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index c8c849eea79..3625caa163c 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -92,7 +92,6 @@
     double,
     gen_cluster,
     gen_test,
-    geninc,
     get_cert,
     inc,
     map_varying,
@@ -2662,13 +2661,13 @@ def func(x, y=10):

 @gen_cluster(client=True)
 async def test_run_coroutine(c, s, a, b):
-    results = await c.run(geninc, 1, delay=0.05)
+    results = await c.run(asyncinc, 1, delay=0.05)
     assert results == {a.address: 2, b.address: 2}

-    results = await c.run(geninc, 1, delay=0.05, workers=[a.address])
+    results = await c.run(asyncinc, 1, delay=0.05, workers=[a.address])
     assert results == {a.address: 2}

-    results = await c.run(geninc, 1, workers=[])
+    results = await c.run(asyncinc, 1, workers=[])
     assert results == {}

     with pytest.raises(RuntimeError, match="hello"):
@@ -2679,14 +2678,14 @@ async def test_run_coroutine(c, s, a, b):


 def test_run_coroutine_sync(c, s, a, b):
-    result = c.run(geninc, 2, delay=0.01)
+    result = c.run(asyncinc, 2, delay=0.01)
     assert result == {a["address"]: 3, b["address"]: 3}

-    result = c.run(geninc, 2, workers=[a["address"]])
+    result = c.run(asyncinc, 2, workers=[a["address"]])
     assert result == {a["address"]: 3}

     t1 = time()
-    result = c.run(geninc, 2, delay=10, wait=False)
+    result = c.run(asyncinc, 2, delay=10, wait=False)
     t2 = time()
     assert result is None
     assert t2 - t1 <= 1.0
@@ -3498,13 +3497,17 @@ def block(ev):


 @pytest.mark.slow
-@gen_cluster(Worker=Nanny, client=True, timeout=60)
+@gen_cluster(client=True)
 async def test_Client_clears_references_after_restart(c, s, a, b):
     x = c.submit(inc, 1)
     assert x.key in c.refcount
+    assert x.key in c.futures
+
+    with pytest.raises(TimeoutError):
+        await c.restart(timeout=5)

-    await c.restart()
     assert x.key not in c.refcount
+    assert not c.futures

     key = x.key
     del x
@@ -3513,14 +3516,6 @@ async def test_Client_clears_references_after_restart(c, s, a, b):
     assert key not in c.refcount


-@gen_cluster(Worker=Nanny, client=True)
-async def test_restart_timeout_is_logged(c, s, a, b):
-    with captured_logger(logging.getLogger("distributed.client")) as logger:
-        await c.restart(timeout="0.5s")
-    text = logger.getvalue()
-    assert "Restart timed out after 0.50 seconds" in text
-
-
 def test_get_stops_work_after_error(c):
     with pytest.raises(RuntimeError):
         c.get({"x": (throws, 1), "y": (sleep, 1.5)}, ["x", "y"])
diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py
index 9aac7d6f4b7..754e2916c0c 100644
--- a/distributed/tests/test_failed_workers.py
+++ b/distributed/tests/test_failed_workers.py
@@ -19,6 +19,7 @@
 from distributed.utils import CancelledError, sync
 from distributed.utils_test import (
     BlockedGatherDep,
+    async_wait_for,
     captured_logger,
     cluster,
     div,
@@ -238,6 +239,23 @@ async def test_multiple_clients_restart(s, a, b):
         await asyncio.sleep(0.01)
         assert time() < start + 5

+    assert not c1.futures
+    assert not c2.futures
+
+    # Ensure both clients still work after restart.
+    # Reusing a previous key has no effect.
+    x2 = c1.submit(inc, 1, key=x.key)
+    y2 = c2.submit(inc, 2, key=y.key)
+
+    assert x2._generation != x._generation
+    assert y2._generation != y._generation
+
+    assert await x2 == 2
+    assert await y2 == 3
+
+    del x2, y2
+    await async_wait_for(lambda: not s.tasks, timeout=5)
+
     await c1.close()
     await c2.close()

diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 9ea5a0bc7c1..a0ee5823dda 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -629,68 +629,112 @@ async def test_restart(c, s, a, b):
     assert not s.tasks


-@gen_cluster(client=True, Worker=Nanny, timeout=60)
-async def test_restart_some_nannies_some_not(c, s, a, b):
-    original_procs = {a.process.process, b.process.process}
+@pytest.mark.slow
+@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)] * 5)
+async def test_restart_waits_for_new_workers(c, s, *workers):
+    original_procs = {n.process.process for n in workers}
     original_workers = dict(s.workers)

-    async with Worker(s.address, nthreads=1) as w:
-        await c.wait_for_workers(3)
-        # Halfway through `Scheduler.restart`, only the non-Nanny workers should be removed.
-        # Nanny-based workers should be kept around so we can call their `restart` RPC.
-        class ValidateRestartPlugin(SchedulerPlugin):
-            error: Exception | None
-
-            def restart(self, scheduler: Scheduler) -> None:
-                try:
-                    assert scheduler.workers.keys() == {
-                        a.worker_address,
-                        b.worker_address,
-                    }
-                    assert all(ws.nanny for ws in scheduler.workers.values())
-                except Exception as e:
-                    # `Scheduler.restart` swallows exceptions within plugins
-                    self.error = e
-                    raise
-                else:
-                    self.error = None
-
-        plugin = ValidateRestartPlugin()
-        s.add_plugin(plugin)
-        await s.restart()
-
-        if plugin.error:
-            raise plugin.error
-
-        assert w.status == Status.closed
+    await c.restart()
+    assert len(s.workers) == len(original_workers)
+    for w in workers:
+        assert w.address not in s.workers

-    assert len(s.workers) == 2
-    # Confirm they restarted
-    # NOTE: == for `psutil.Process` compares PID and creation time
-    new_procs = {a.process.process, b.process.process}
-    assert new_procs != original_procs
-    # The workers should have new addresses
-    assert s.workers.keys().isdisjoint(original_workers.keys())
-    # The old WorkerState instances should be replaced
-    assert set(s.workers.values()).isdisjoint(original_workers.values())
+    # Confirm they restarted
+    # NOTE: == for `psutil.Process` compares PID and creation time
+    new_procs = {n.process.process for n in workers}
+    assert new_procs != original_procs
+    # The workers should have new addresses
+    assert s.workers.keys().isdisjoint(original_workers.keys())
+    # The old WorkerState instances should be replaced
+    assert set(s.workers.values()).isdisjoint(original_workers.values())


-class SlowRestartNanny(Nanny):
+class SlowKillNanny(Nanny):
     def __init__(self, *args, **kwargs):
-        self.restart_proceed = asyncio.Event()
-        self.restart_called = asyncio.Event()
+        self.kill_proceed = asyncio.Event()
+        self.kill_called = asyncio.Event()
         super().__init__(*args, **kwargs)

-    async def restart(self, **kwargs):
-        self.restart_called.set()
-        await self.restart_proceed.wait()
-        return await super().restart(**kwargs)
+    async def kill(self, *, timeout):
+        self.kill_called.set()
+        print("kill called")
+        await asyncio.wait_for(self.kill_proceed.wait(), timeout)
+        print("kill proceed")
+        return await super().kill(timeout=timeout)
+
+
+@gen_cluster(client=True, Worker=SlowKillNanny, nthreads=[("", 1)] * 2)
+async def test_restart_nanny_timeout_exceeded(c, s, a, b):
+    f = c.submit(div, 1, 0)
+    fr = c.submit(inc, 1, resources={"FOO": 1})
+    await wait(f)
+    assert s.erred_tasks
+    assert s.computations
+    assert s.unrunnable
+    assert s.tasks
+
+    with pytest.raises(
+        TimeoutError, match=r"2/2 nanny worker\(s\) did not shut down within 1s"
+    ):
+        await c.restart(timeout="1s")
+    assert a.kill_called.is_set()
+    assert b.kill_called.is_set()
+
+    assert not s.workers
+    assert not s.erred_tasks
+    assert not s.computations
+    assert not s.unrunnable
+    assert not s.tasks
+
+    assert not c.futures
+    assert f.status == "cancelled"
+    assert fr.status == "cancelled"
+
+
+@gen_cluster(client=True, nthreads=[("", 1)] * 2)
+async def test_restart_not_all_workers_return(c, s, a, b):
+    with pytest.raises(TimeoutError, match="Waited for 2 worker"):
+        await c.restart(timeout="1s")
+
+    assert not s.workers
+    assert a.status in (Status.closed, Status.closing)
+    assert b.status in (Status.closed, Status.closing)
+
+
+@gen_cluster(client=True, nthreads=[("", 1)] * 2)
+async def test_restart_no_wait_for_workers(c, s, a, b):
+    await c.restart(timeout="1s", wait_for_workers=False)
+
+    assert not s.workers
+    # Workers are not immediately closed because of https://github.com/dask/distributed/issues/6390
+    # (the message is still waiting in the BatchedSend)
+    await a.finished()
+    await b.finished()
+
+
+@pytest.mark.slow
+@gen_cluster(client=True, Worker=Nanny)
+async def test_restart_some_nannies_some_not(c, s, a, b):
+    original_addrs = set(s.workers)
+    async with Worker(s.address, nthreads=1) as w:
+        await c.wait_for_workers(3)
+
+        # FIXME how to make this not always take 20s if the nannies do restart quickly?
+        with pytest.raises(TimeoutError, match=r"The 1 worker\(s\) not using Nannies"):
+            await c.restart(timeout="20s")
+
+        assert w.status == Status.closed
+
+    assert len(s.workers) == 2
+    assert set(s.workers).isdisjoint(original_addrs)
+    assert w.address not in s.workers


 @gen_cluster(
     client=True,
     nthreads=[("", 1)],
-    Worker=SlowRestartNanny,
+    Worker=SlowKillNanny,
     worker_kwargs={"heartbeat_interval": "1ms"},
 )
 async def test_restart_heartbeat_before_closing(c, s, n):
@@ -701,13 +745,13 @@ async def test_restart_heartbeat_before_closing(c, s, n):
     prev_workers = dict(s.workers)

     restart_task = asyncio.create_task(s.restart())
-    await n.restart_called.wait()
+    await n.kill_called.wait()
     await asyncio.sleep(0.5)  # significantly longer than the heartbeat interval

     # WorkerState should not be removed yet, because the worker hasn't been told to close
     assert s.workers

-    n.restart_proceed.set()
+    n.kill_proceed.set()
     # Wait until the worker has left (possibly until it's come back too)
     while s.workers == prev_workers:
         await asyncio.sleep(0.01)
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index 4994ea8c464..fadc6002376 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -430,11 +430,6 @@ def apply(func, *args, **kwargs):
     return apply, list(map(varying, itemslists))


-async def geninc(x, delay=0.02):
-    await asyncio.sleep(delay)
-    return x + 1
-
-
 async def asyncinc(x, delay=0.02):
     await asyncio.sleep(delay)
     return x + 1
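For context, a minimal sketch of how the reworked restart API above is intended to be used from client code. The scheduler address, worker count, and timeout values below are illustrative placeholders, not values taken from this changeset.

from distributed import Client

client = Client("tcp://scheduler-host:8786")  # placeholder address

# Default behaviour: kill every worker, clear scheduler and client state, then
# block until the same number of workers has reconnected; raises TimeoutError
# if the workers do not shut down (and, with nannies, come back) in time.
client.restart(timeout=120)

# Granular control: only wait for the old workers to shut down, then decide
# separately how many workers to wait for before submitting new work.
client.restart(wait_for_workers=False)
client.wait_for_workers(n_workers=4)

The second pattern matches the docstring's suggestion of combining ``restart(wait_for_workers=False)`` with `Client.wait_for_workers`, which is useful when a deployment system relaunches workers on its own schedule.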