diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index d4a84c184b5..b03332bf9f5 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -5091,19 +5091,19 @@ async def restart(self, client=None, timeout=30):
                 stimulus_id=stimulus_id,
             )
 
-        nannies = {addr: ws.nanny for addr, ws in self.workers.items()}
+        nanny_workers = {
+            addr: ws.nanny for addr, ws in self.workers.items() if ws.nanny
+        }
 
-        for addr in list(self.workers):
-            try:
-                # Ask the worker to close if it doesn't have a nanny,
-                # otherwise the nanny will kill it anyway
-                await self.remove_worker(
-                    address=addr, close=addr not in nannies, stimulus_id=stimulus_id
-                )
-            except Exception:
-                logger.info(
-                    "Exception while restarting. This is normal", exc_info=True
-                )
+        # Close non-Nanny workers. We have no way to restart them, so we just let them go,
+        # and assume a deployment system is going to restart them for us.
+        await asyncio.gather(
+            *(
+                self.remove_worker(address=addr, stimulus_id=stimulus_id)
+                for addr in self.workers
+                if addr not in nanny_workers
+            )
+        )
 
         self.clear_task_state()
 
@@ -5113,21 +5113,27 @@ async def restart(self, client=None, timeout=30):
             except Exception as e:
                 logger.exception(e)
 
-        logger.debug("Send kill signal to nannies: %s", nannies)
+        logger.debug("Send kill signal to nannies: %s", nanny_workers)
         async with contextlib.AsyncExitStack() as stack:
             nannies = [
                 await stack.enter_async_context(
                     rpc(nanny_address, connection_args=self.connection_args)
                 )
-                for nanny_address in nannies.values()
-                if nanny_address is not None
+                for nanny_address in nanny_workers.values()
             ]
-            resps = All(
-                [nanny.restart(close=True, timeout=timeout * 0.8) for nanny in nannies]
-            )
             try:
-                resps = await asyncio.wait_for(resps, timeout)
+                resps = await asyncio.wait_for(
+                    asyncio.gather(
+                        *(
+                            nanny.restart(close=True, timeout=timeout * 0.8)
+                            for nanny in nannies
+                        )
+                    ),
+                    timeout,
+                )
+                # NOTE: the `WorkerState` entries for these workers will be removed
+                # naturally when they disconnect from the scheduler.
             except TimeoutError:
                 logger.error(
                     "Nannies didn't report back restarted within "
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 90cd3f2c9b8..dd63e86f2f0 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import logging
@@ -625,6 +627,93 @@ async def test_restart(c, s, a, b):
     assert not s.tasks
 
 
+@gen_cluster(client=True, Worker=Nanny, timeout=60)
+async def test_restart_some_nannies_some_not(c, s, a, b):
+    original_procs = {a.process.process, b.process.process}
+    original_workers = dict(s.workers)
+    async with Worker(s.address, nthreads=1) as w:
+        await c.wait_for_workers(3)
+
+        # Halfway through `Scheduler.restart`, only the non-Nanny workers should be removed.
+        # Nanny-based workers should be kept around so we can call their `restart` RPC.
+        class ValidateRestartPlugin(SchedulerPlugin):
+            error: Exception | None
+
+            def restart(self, scheduler: Scheduler) -> None:
+                try:
+                    assert scheduler.workers.keys() == {
+                        a.worker_address,
+                        b.worker_address,
+                    }
+                    assert all(ws.nanny for ws in scheduler.workers.values())
+                except Exception as e:
+                    # `Scheduler.restart` swallows exceptions within plugins
+                    self.error = e
+                    raise
+                else:
+                    self.error = None
+
+        plugin = ValidateRestartPlugin()
+        s.add_plugin(plugin)
+        await s.restart()
+
+        if plugin.error:
+            raise plugin.error
+
+        assert w.status == Status.closed
+
+    assert len(s.workers) == 2
+    # Confirm they restarted
+    # NOTE: == for `psutil.Process` compares PID and creation time
+    new_procs = {a.process.process, b.process.process}
+    assert new_procs != original_procs
+    # The workers should have new addresses
+    assert s.workers.keys().isdisjoint(original_workers.keys())
+    # The old WorkerState instances should be replaced
+    assert set(s.workers.values()).isdisjoint(original_workers.values())
+
+
+class SlowRestartNanny(Nanny):
+    def __init__(self, *args, **kwargs):
+        self.restart_proceed = asyncio.Event()
+        self.restart_called = asyncio.Event()
+        super().__init__(*args, **kwargs)
+
+    async def restart(self, **kwargs):
+        self.restart_called.set()
+        await self.restart_proceed.wait()
+        return await super().restart(**kwargs)
+
+
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)],
+    Worker=SlowRestartNanny,
+    worker_kwargs={"heartbeat_interval": "1ms"},
+)
+async def test_restart_heartbeat_before_closing(c, s: Scheduler, n: SlowRestartNanny):
+    """
+    Ensure that if workers heartbeat in the middle of `Scheduler.restart`, they don't close themselves.
+    https://github.com/dask/distributed/issues/6494
+    """
+    prev_workers = dict(s.workers)
+    restart_task = asyncio.create_task(s.restart())
+
+    await n.restart_called.wait()
+    await asyncio.sleep(0.5)  # significantly longer than the heartbeat interval
+
+    # WorkerState should not be removed yet, because the worker hasn't been told to close
+    assert s.workers
+
+    n.restart_proceed.set()
+    # Wait until the worker has left (possibly until it's come back too)
+    while s.workers == prev_workers:
+        await asyncio.sleep(0.01)
+
+    await restart_task
+    await c.wait_for_workers(1)
+
+
 @gen_cluster()
 async def test_broadcast(s, a, b):
     result = await s.broadcast(msg={"op": "ping"})
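
A note on the control-flow shape above. The rewritten `restart` replaces the old sequential remove-and-log loop with a concurrent fan-out: non-Nanny workers are removed in parallel via `asyncio.gather`, and all surviving nannies are asked to restart at once under a single outer deadline (`asyncio.wait_for(..., timeout)`), with each individual RPC given a shorter per-call deadline (`timeout * 0.8`) so a nanny can fail and report back before the scheduler gives up on the whole batch. Here is a minimal standalone sketch of that pattern — toy names, not the dask API:

```python
from __future__ import annotations

import asyncio


async def restart_nanny(name: str, timeout: float) -> str:
    # Stand-in for the real `nanny.restart(close=True, timeout=...)` RPC call.
    await asyncio.sleep(0.1)  # pretend the worker process is being replaced
    return f"{name}: OK"


async def restart_all(n: int = 3, timeout: float = 30) -> list[str]:
    try:
        return await asyncio.wait_for(
            asyncio.gather(
                # The sub-timeout is deliberately shorter than the outer
                # deadline, so individual failures surface before wait_for fires.
                *(restart_nanny(f"nanny-{i}", timeout * 0.8) for i in range(n))
            ),
            timeout,
        )
    except asyncio.TimeoutError:
        return []  # the real scheduler logs an error here and carries on


print(asyncio.run(restart_all()))  # ['nanny-0: OK', 'nanny-1: OK', 'nanny-2: OK']
```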
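
And why keeping the Nanny-backed `WorkerState` entries alive matters: per https://github.com/dask/distributed/issues/6494, a worker whose state the scheduler had already dropped would get a "missing" reply to its next heartbeat and shut itself down, racing the very nanny that was about to restart it. A toy model of that failure mode — illustrative names, not the real scheduler API:

```python
from __future__ import annotations


def heartbeat_response(known_addresses: set[str], address: str) -> dict[str, str]:
    # A worker that receives {"status": "missing"} concludes the scheduler has
    # forgotten it and closes itself -- exactly what must not happen while its
    # nanny restart is still in flight.
    if address not in known_addresses:
        return {"status": "missing"}
    return {"status": "OK"}


known = {"tcp://10.0.0.1:40000"}
assert heartbeat_response(known, "tcp://10.0.0.1:40000") == {"status": "OK"}
assert heartbeat_response(known, "tcp://10.0.0.2:40001") == {"status": "missing"}
```

This is the window `test_restart_heartbeat_before_closing` exercises: `SlowRestartNanny` parks its `restart` on an event while 1ms heartbeats keep arriving, and the test asserts `s.workers` stays populated until the worker is actually told to close.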