From 9e31787aeb69cfa50795d600586a0a6e57cc3c8e Mon Sep 17 00:00:00 2001
From: Gabe Joseph
Date: Fri, 3 Jun 2022 12:50:24 -0600
Subject: [PATCH 1/8] Fix `Scheduler.restart` logic

`Scheduler.restart` used to remove every worker without closing it. This
was bad practice (#6390), as well as incorrect: the intent was evidently
to remove only the non-Nanny workers. Nanny workers are then restarted
via the `restart` RPC to the Nanny, not to the worker.
---
 distributed/scheduler.py            | 49 +++++++++++++++++-----------
 distributed/tests/test_scheduler.py | 50 +++++++++++++++++++++++++++++
 2 files changed, 80 insertions(+), 19 deletions(-)

diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index d4a84c184b5..ac2113bf785 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -5091,19 +5091,24 @@ async def restart(self, client=None, timeout=30):
                 stimulus_id=stimulus_id,
             )
 
-        nannies = {addr: ws.nanny for addr, ws in self.workers.items()}
+        nanny_workers: dict[str, str] = {
+            addr: ws.nanny for addr, ws in self.workers.items() if ws.nanny
+        }
 
-        for addr in list(self.workers):
-            try:
-                # Ask the worker to close if it doesn't have a nanny,
-                # otherwise the nanny will kill it anyway
-                await self.remove_worker(
-                    address=addr, close=addr not in nannies, stimulus_id=stimulus_id
-                )
-            except Exception:
-                logger.info(
-                    "Exception while restarting. This is normal", exc_info=True
-                )
+        # Close non-Nanny workers. We have no way to restart them, so we just let them go,
+        # and assume a deployment system is going to restart them for us.
+        close_results = await asyncio.gather(
+            *(
+                self.remove_worker(address=addr, stimulus_id=stimulus_id)
+                for addr in self.workers
+                if addr not in nanny_workers
+            ),
+            return_exceptions=True,
+        )
+        for r in close_results:
+            if isinstance(r, Exception):
+                # TODO this is probably not, in fact, normal.
+                logger.info("Exception while restarting. This is normal.", exc_info=r)
 
         self.clear_task_state()
 
@@ -5113,21 +5118,27 @@ async def restart(self, client=None, timeout=30):
         except Exception as e:
             logger.exception(e)
 
-        logger.debug("Send kill signal to nannies: %s", nannies)
+        logger.debug("Send kill signal to nannies: %s", nanny_workers)
         async with contextlib.AsyncExitStack() as stack:
             nannies = [
                 await stack.enter_async_context(
                     rpc(nanny_address, connection_args=self.connection_args)
                 )
-                for nanny_address in nannies.values()
-                if nanny_address is not None
+                for nanny_address in nanny_workers.values()
             ]
-            resps = All(
-                [nanny.restart(close=True, timeout=timeout * 0.8) for nanny in nannies]
-            )
             try:
-                resps = await asyncio.wait_for(resps, timeout)
+                resps = await asyncio.wait_for(
+                    asyncio.gather(
+                        *(
+                            nanny.restart(close=True, timeout=timeout * 0.8)
+                            for nanny in nannies
+                        )
+                    ),
+                    timeout,
+                )
+                # NOTE: the `WorkerState` entries for these workers will be removed
+                # naturally when they disconnect from the scheduler.
             except TimeoutError:
                 logger.error(
                     "Nannies didn't report back restarted within "

diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 90cd3f2c9b8..06ad5768ada 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -1,7 +1,10 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import logging
 import operator
+import os
 import pickle
 import re
 import sys
@@ -625,6 +628,53 @@ async def test_restart(c, s, a, b):
     assert not s.tasks
 
 
+@gen_cluster(client=True, Worker=Nanny, timeout=60)
+async def test_restart_some_nannies_some_not(
+    c: Client, s: Scheduler, a: Nanny, b: Nanny
+):
+    original_pids = await c.run(os.getpid)
+    original_workers = dict(s.workers)
+    async with Worker(s.address, nthreads=1) as w:
+        await c.wait_for_workers(3)
+
+        # Halfway through `Scheduler.restart`, only the non-Nanny workers should be removed.
+        # Nanny-based workers should be kept around so we can call their `restart` RPC.
+        class ValidateRestartPlugin(SchedulerPlugin):
+            error: Exception | None
+
+            def restart(self, scheduler: Scheduler) -> None:
+                try:
+                    assert scheduler.workers.keys() == {
+                        a.worker_address,
+                        b.worker_address,
+                    }
+                    assert all(ws.nanny for ws in scheduler.workers.values())
+                except Exception as e:
+                    # `Scheduler.restart` swallows exceptions within plugins
+                    self.error = e
+                    raise
+                else:
+                    self.error = None
+
+        plugin = ValidateRestartPlugin()
+        s.add_plugin(plugin)
+        await s.restart()
+
+        if plugin.error:
+            raise plugin.error
+
+        assert w.status == Status.closed
+
+        assert len(s.workers) == 2
+        # Confirm they restarted
+        new_pids = await c.run(os.getpid)
+        assert new_pids != original_pids
+        # The workers should have new addresses
+        assert s.workers.keys().isdisjoint(original_workers.keys())
+        # The old WorkerState instances should be replaced
+        assert set(s.workers.values()).isdisjoint(original_workers.values())
+
+
 @gen_cluster()
 async def test_broadcast(s, a, b):
     result = await s.broadcast(msg={"op": "ping"})

From bd74d2b2b9a4862f0d8e2a1e33ea3028c5fb4466 Mon Sep 17 00:00:00 2001
From: Gabe Joseph
Date: Fri, 3 Jun 2022 14:47:09 -0600
Subject: [PATCH 2/8] Explicit test for the restart problem too

---
 distributed/tests/test_scheduler.py | 41 +++++++++++++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 06ad5768ada..8d638ac1639 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -675,6 +675,47 @@ def restart(self, scheduler: Scheduler) -> None:
     assert set(s.workers.values()).isdisjoint(original_workers.values())
 
 
+class SlowRestartNanny(Nanny):
+    def __init__(self, *args, **kwargs):
+        self.restart_proceed = asyncio.Event()
+        self.restart_called = asyncio.Event()
+        super().__init__(*args, **kwargs)
+
+    async def restart(self, **kwargs):
+        self.restart_called.set()
+        await self.restart_proceed.wait()
+        return await super().restart(**kwargs)
+
+
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)],
+    Worker=SlowRestartNanny,
+    worker_kwargs={"heartbeat_interval": "1ms"},
+)
+async def test_restart_heartbeat_before_closing(c, s: Scheduler, n: SlowRestartNanny):
+    """
+    Ensure that if workers heartbeat in the middle of `Scheduler.restart`, they don't close themselves.
+
+    https://github.com/dask/distributed/issues/6494
+    """
+    prev_workers = dict(s.workers)
+    restart_task = asyncio.create_task(s.restart())
+
+    await n.restart_called.wait()
+    await asyncio.sleep(0.5)  # significantly longer than the heartbeat interval
+
+    # WorkerState should not be removed yet, because the worker hasn't been told to close
+    assert s.workers
+
+    n.restart_proceed.set()
+    # Wait until the worker has left (possibly until it's come back too)
+    while s.workers == prev_workers:
+        await asyncio.sleep(0.01)
+
+    await restart_task
+    await c.wait_for_workers(1)
+
+
 @gen_cluster()
 async def test_broadcast(s, a, b):
     result = await s.broadcast(msg={"op": "ping"})

From 735d9825155241ce9f9beec910f277b4810f738c Mon Sep 17 00:00:00 2001
From: Gabe Joseph
Date: Fri, 3 Jun 2022 14:59:04 -0600
Subject: [PATCH 3/8] improve PID sensitivity

---
 distributed/tests/test_scheduler.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 8d638ac1639..d60b7de10d3 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -632,7 +632,7 @@ async def test_restart(c, s, a, b):
 async def test_restart_some_nannies_some_not(
     c: Client, s: Scheduler, a: Nanny, b: Nanny
 ):
-    original_pids = await c.run(os.getpid)
+    original_pids = set((await c.run(os.getpid)).values())
     original_workers = dict(s.workers)
     async with Worker(s.address, nthreads=1) as w:
         await c.wait_for_workers(3)
@@ -667,7 +667,7 @@ def restart(self, scheduler: Scheduler) -> None:
 
         assert len(s.workers) == 2
         # Confirm they restarted
-        new_pids = await c.run(os.getpid)
+        new_pids = set((await c.run(os.getpid)).values())
         assert new_pids != original_pids
         # The workers should have new addresses

From 7fd7f0e3383b20da6ce4a0fe5fd3e5563b235b8e Mon Sep 17 00:00:00 2001
From: Gabe Joseph
Date: Wed, 8 Jun 2022 09:23:14 -0600
Subject: [PATCH 4/8] Apply suggestions from code review

Co-authored-by: crusaderky
---
 distributed/scheduler.py            | 2 +-
 distributed/tests/test_scheduler.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index ac2113bf785..4093c27a31e 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -5091,7 +5091,7 @@ async def restart(self, client=None, timeout=30):
             stimulus_id=stimulus_id,
         )
 
-        nanny_workers: dict[str, str] = {
+        nanny_workers = {
             addr: ws.nanny for addr, ws in self.workers.items() if ws.nanny
         }
 
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index d60b7de10d3..965aae104e8 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -630,9 +630,10 @@ async def test_restart(c, s, a, b):
 
 @gen_cluster(client=True, Worker=Nanny, timeout=60)
 async def test_restart_some_nannies_some_not(
-    c: Client, s: Scheduler, a: Nanny, b: Nanny
+    c, s, a, b
 ):
-    original_pids = set((await c.run(os.getpid)).values())
+    original_pids = {a.process.process.pid, b.process.process.pid}
+    assert all(original_pids)
     original_workers = dict(s.workers)
     async with Worker(s.address, nthreads=1) as w:
         await c.wait_for_workers(3)

From a033eedc975477fa923bf71de6f0c94a03fe2829 Mon Sep 17 00:00:00 2001
From: Gabe Joseph
Date: Wed, 8 Jun 2022 09:26:56 -0600
Subject: [PATCH 5/8] Better process equality check

---
 distributed/tests/test_scheduler.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 965aae104e8..dd63e86f2f0 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -4,7 +4,6 @@
 import json
 import logging
 import operator
-import os
 import pickle
 import re
 import sys
@@ -629,11 +628,8 @@ async def test_restart(c, s, a, b):
 
 @gen_cluster(client=True, Worker=Nanny, timeout=60)
-async def test_restart_some_nannies_some_not(
-    c, s, a, b
-):
+async def test_restart_some_nannies_some_not(c, s, a, b):
-    original_pids = {a.process.process.pid, b.process.process.pid}
-    assert all(original_pids)
+    original_procs = {a.process.process, b.process.process}
     original_workers = dict(s.workers)
     async with Worker(s.address, nthreads=1) as w:
         await c.wait_for_workers(3)
@@ -668,8 +664,9 @@ def restart(self, scheduler: Scheduler) -> None:
 
         assert len(s.workers) == 2
         # Confirm they restarted
-        new_pids = set((await c.run(os.getpid)).values())
-        assert new_pids != original_pids
+        # NOTE: == for `psutil.Process` compares PID and creation time
+        new_procs = {a.process.process, b.process.process}
+        assert new_procs != original_procs
         # The workers should have new addresses
         assert s.workers.keys().isdisjoint(original_workers.keys())
         # The old WorkerState instances should be replaced

From d037f3736b92571e505a2c393b95f6f32b4b97e8 Mon Sep 17 00:00:00 2001
From: Gabe Joseph
Date: Wed, 8 Jun 2022 09:29:44 -0600
Subject: [PATCH 6/8] Don't suppress errors from `remove_worker`

---
 distributed/scheduler.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 4093c27a31e..b03332bf9f5 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -5097,18 +5097,13 @@ async def restart(self, client=None, timeout=30):
 
         # Close non-Nanny workers. We have no way to restart them, so we just let them go,
         # and assume a deployment system is going to restart them for us.
-        close_results = await asyncio.gather(
+        await asyncio.gather(
             *(
                 self.remove_worker(address=addr, stimulus_id=stimulus_id)
                 for addr in self.workers
                 if addr not in nanny_workers
-            ),
-            return_exceptions=True,
+            )
         )
-        for r in close_results:
-            if isinstance(r, Exception):
-                # TODO this is probably not, in fact, normal.
-                logger.info("Exception while restarting. This is normal.", exc_info=r)
 
         self.clear_task_state()

From 7d90e2ad1628ce0e86bcd3eb582afecd8a577c58 Mon Sep 17 00:00:00 2001
From: Gabe Joseph
Date: Wed, 8 Jun 2022 15:53:35 -0600
Subject: [PATCH 7/8] don't heartbeat when closing

TODO make this a separate PR and add tests. Just want to see if it helps CI.
---
 distributed/worker.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/distributed/worker.py b/distributed/worker.py
index a8f4b6ba942..f7cbf1bc497 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1215,6 +1215,9 @@ async def heartbeat(self):
         if self.heartbeat_active:
             logger.debug("Heartbeat skipped: channel busy")
             return
+        if self.status in {Status.closing, Status.closed, Status.failed}:
+            logger.debug(f"Heartbeat skipped: {self.status=}")
+            return
         self.heartbeat_active = True
         logger.debug("Heartbeat: %s", self.address)
         try:

From 5cccba65fea4e03c4f405bb48b1107949e6d55f6 Mon Sep 17 00:00:00 2001
From: Gabe Joseph
Date: Wed, 8 Jun 2022 16:50:18 -0600
Subject: [PATCH 8/8] Revert "don't heartbeat when closing"

This reverts commit 7d90e2ad1628ce0e86bcd3eb582afecd8a577c58.
---
 distributed/worker.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/distributed/worker.py b/distributed/worker.py
index f7cbf1bc497..a8f4b6ba942 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1215,9 +1215,6 @@ async def heartbeat(self):
         if self.heartbeat_active:
             logger.debug("Heartbeat skipped: channel busy")
             return
-        if self.status in {Status.closing, Status.closed, Status.failed}:
-            logger.debug(f"Heartbeat skipped: {self.status=}")
-            return
         self.heartbeat_active = True
         logger.debug("Heartbeat: %s", self.address)
         try:
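
After patches 1 and 6, `Scheduler.restart` boils down to a fan-out with a single
deadline: close non-Nanny workers concurrently and let failures propagate, then ask
every Nanny to restart its worker under one `asyncio.wait_for`. The sketch below
distills that shape outside the scheduler; it is a minimal illustration, not the
distributed API — `remove_worker`, `restart_nanny`, and the sample addresses are
hypothetical stand-ins.

    import asyncio


    async def remove_worker(addr: str) -> None:
        # Stand-in for closing a worker we have no way to restart.
        await asyncio.sleep(0.1)


    async def restart_nanny(nanny_addr: str) -> None:
        # Stand-in for the Nanny restart RPC: kill and respawn the worker process.
        await asyncio.sleep(0.2)


    async def restart(workers: dict[str, str | None], timeout: float = 30.0) -> None:
        # Workers with a Nanny get restarted; the rest are closed (patch 1).
        nanny_workers = {addr: n for addr, n in workers.items() if n}

        # Fan out the closes. With no return_exceptions=True (patch 6), a failing
        # close propagates instead of being logged and swallowed.
        await asyncio.gather(
            *(remove_worker(addr) for addr in workers if addr not in nanny_workers)
        )

        # One deadline bounds all Nanny restarts, mirroring the
        # asyncio.wait_for(asyncio.gather(...)) structure in the patched method.
        try:
            await asyncio.wait_for(
                asyncio.gather(*(restart_nanny(n) for n in nanny_workers.values())),
                timeout,
            )
        except asyncio.TimeoutError:
            print("Nannies didn't report back restarted within timeout")


    asyncio.run(restart({"w1": "n1", "w2": None, "w3": "n3"}, timeout=1.0))

Running this closes w2 first, then restarts n1 and n3 within the one-second
deadline. Note the trade-off patch 6 opts into: an exception while closing a
non-Nanny worker now aborts the whole restart rather than being reported as
"normal".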