From e8737b822e862d45042fe352132b99b35fce677c Mon Sep 17 00:00:00 2001
From: fjetter
Date: Tue, 15 Nov 2022 19:40:04 +0100
Subject: [PATCH 1/2] Ensure Server.close cannot run concurrently

---
 distributed/core.py                 | 68 +++++++++++++++++------------
 distributed/nanny.py                | 38 +++++-----------
 distributed/scheduler.py            | 22 +---------
 distributed/tests/test_nanny.py     | 25 ++++++++++-
 distributed/worker.py               | 19 ++------
 distributed/worker_state_machine.py |  2 +-
 6 files changed, 81 insertions(+), 93 deletions(-)

diff --git a/distributed/core.py b/distributed/core.py
index 96dfe7b68a9..da7a055d971 100644
--- a/distributed/core.py
+++ b/distributed/core.py
@@ -415,7 +415,7 @@ def set_thread_ident():
             self.thread_id = threading.get_ident()
 
         self.io_loop.add_callback(set_thread_ident)
-        self._startup_lock = asyncio.Lock()
+        self._startstop_lock = asyncio.Lock()
         self.__startup_exc: Exception | None = None
 
         self.rpc = ConnectionPool(
@@ -465,7 +465,7 @@ async def start_unsafe(self):
 
     @final
     async def start(self):
-        async with self._startup_lock:
+        async with self._startstop_lock:
             if self.status == Status.failed:
                 assert self.__startup_exc is not None
                 raise self.__startup_exc
@@ -866,33 +866,47 @@ async def handle_stream(self, comm, extra=None):
             await comm.close()
             assert comm.closed()
 
-    async def close(self, timeout=None):
-        try:
-            for pc in self.periodic_callbacks.values():
-                pc.stop()
-
-            if not self.__stopped:
-                self.__stopped = True
-                _stops = set()
-                for listener in self.listeners:
-                    future = listener.stop()
-                    if inspect.isawaitable(future):
-                        warnings.warn(
-                            f"{type(listener)} is using an asynchronous `stop` method. "
-                            "Support for asynchronous `Listener.stop` will be removed in a future version",
-                            PendingDeprecationWarning,
-                        )
-                        _stops.add(future)
-                if _stops:
-                    await asyncio.gather(*_stops)
+    async def close_unsafe(
+        self, timeout: float | None, reason: str | None, **kwargs: Any
+    ) -> None:
+        """Attempt to close the server. This is not idempotent and not protected
+        against concurrent closing attempts.
 
-            # TODO: Deal with exceptions
-            await self._ongoing_background_tasks.stop()
+        This is intended to be overwritten or called by subclasses. For a safe
+        close, please use ``Server.close`` instead.
+        """
 
-            await self.rpc.close()
-            await asyncio.gather(*[comm.close() for comm in list(self._comms)])
-        finally:
-            self._event_finished.set()
+    @final
+    async def close(
+        self, timeout: float | None = None, reason: str | None = None, **kwargs: Any
+    ) -> None:
+        async with self._startstop_lock:
+            if self.status in (Status.closed, Status.closing, Status.failed):
+                return None
+
+            self.status = Status.closing
+            logger.info(
+                "Closing %r at %r. Reason: %s",
+                type(self).__name__,
+                self.address_safe,
+                reason,
+            )
+            try:
+                for pc in self.periodic_callbacks.values():
+                    pc.stop()
+                self.periodic_callbacks.clear()
+                self.stop()
+
+                await asyncio.wait_for(self.close_unsafe(**kwargs), timeout)
+
+                # TODO: Deal with exceptions
+                await self._ongoing_background_tasks.stop()
+
+                await asyncio.gather(*[comm.close() for comm in list(self._comms)])
+                await self.rpc.close()
+            finally:
+                self._event_finished.set()
+                # TODO: This might break the worker
+                self.status = Status.closed
 
 
 def pingpong(comm):
diff --git a/distributed/nanny.py b/distributed/nanny.py
index 9e7be25c3ce..7089dda65fd 100644
--- a/distributed/nanny.py
+++ b/distributed/nanny.py
@@ -16,8 +16,7 @@
 from collections.abc import Collection
 from inspect import isawaitable
 from queue import Empty
-from time import sleep as sync_sleep
-from typing import TYPE_CHECKING, Callable, ClassVar, Literal
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal
 
 from toolz import merge
 from tornado.ioloop import IOLoop
@@ -570,22 +569,15 @@ def close_gracefully(self, reason: str = "nanny-close-gracefully") -> None:
             "Closing Nanny gracefully at %r. Reason: %s", self.address_safe, reason
         )
 
-    async def close(
-        self, timeout: float = 5, reason: str = "nanny-close"
-    ) -> Literal["OK"]:
+    async def close_unsafe(
+        self,
+        timeout: float | None = 5,
+        reason: str | None = "nanny-close",
+        **kwargs: Any,
+    ) -> None:
         """
         Close the worker process, stop all comms.
         """
-        if self.status == Status.closing:
-            await self.finished()
-            assert self.status == Status.closed
-
-        if self.status == Status.closed:
-            return "OK"
-
-        self.status = Status.closing
-        logger.info("Closing Nanny at %r. Reason: %s", self.address_safe, reason)
-
         for preload in self.preloads:
             await preload.teardown()
 
@@ -600,14 +592,13 @@ async def close(
         self.stop()
         try:
             if self.process is not None:
+                assert timeout is not None
+                assert reason is not None
                 await self.kill(timeout=timeout, reason=reason)
         except Exception:
             logger.exception("Error in Nanny killing Worker subprocess")
         self.process = None
         await self.rpc.close()
-        self.status = Status.closed
-        await super().close()
-        return "OK"
 
     async def _log_event(self, topic, msg):
         await self.scheduler.log_event(
@@ -817,8 +808,8 @@ async def kill(
                 "reason": reason,
             }
         )
-        await asyncio.sleep(0)  # otherwise we get broken pipe errors
         queue.close()
+        queue.join_thread()
         del queue
 
         try:
@@ -943,14 +934,7 @@ async def run() -> None:
                 logger.exception(f"Failed to {failure_type} worker")
                 init_result_q.put({"uid": uid, "exception": e})
                 init_result_q.close()
-                # If we hit an exception here we need to wait for a least
-                # one interval for the outside to pick up this message.
-                # Otherwise we arrive in a race condition where the process
-                # cleanup wipes the queue before the exception can be
-                # properly handled. See also
-                # WorkerProcess._wait_until_connected (the 3 is for good
-                # measure)
-                sync_sleep(cls._init_msg_interval * 3)
+                init_result_q.join_thread()
 
         with contextlib.ExitStack() as stack:
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 0dc40a67bac..d0c37ed2bb2 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -3889,21 +3889,13 @@ def del_scheduler_file():
         setproctitle(f"dask scheduler [{self.address}]")
         return self
 
-    async def close(self, fast=None, close_workers=None):
+    async def close_unsafe(self, timeout: float | None, reason: str | None, **kwargs):
         """Send cleanup signal to all coroutines then wait until finished
 
         See Also
         --------
         Scheduler.cleanup
         """
-        if fast is not None or close_workers is not None:
-            warnings.warn(
-                "The 'fast' and 'close_workers' parameters in Scheduler.close have no effect and will be removed in a future version of distributed.",
-                FutureWarning,
-            )
-        if self.status in (Status.closing, Status.closed):
-            await self.finished()
-            return
 
         async def log_errors(func):
             try:
@@ -3915,8 +3907,6 @@ async def log_errors(func):
             *[log_errors(plugin.before_close) for plugin in list(self.plugins.values())]
         )
 
-        self.status = Status.closing
-        logger.info("Scheduler closing...")
         setproctitle("dask scheduler [closing]")
 
@@ -3930,10 +3920,6 @@ async def log_errors(func):
             *[log_errors(plugin.close) for plugin in list(self.plugins.values())]
         )
 
-        for pc in self.periodic_callbacks.values():
-            pc.stop()
-        self.periodic_callbacks.clear()
-
         self.stop_services()
 
         for ext in self.extensions.values():
@@ -3961,12 +3947,6 @@ async def log_errors(func):
         for comm in self.client_comms.values():
             comm.abort()
 
-        await self.rpc.close()
-
-        self.status = Status.closed
-        self.stop()
-        await super().close()
-
         setproctitle("dask scheduler [closed]")
         disable_gc_diagnosis()
diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py
index 0fe602c14aa..e5517fb6a5d 100644
--- a/distributed/tests/test_nanny.py
+++ b/distributed/tests/test_nanny.py
@@ -26,7 +26,7 @@
 from distributed import Nanny, Scheduler, Worker, profile, rpc, wait, worker
 from distributed.compatibility import LINUX, WINDOWS
-from distributed.core import CommClosedError, Status
+from distributed.core import CommClosedError, ConnectionPool, Status
 from distributed.diagnostics import SchedulerPlugin
 from distributed.metrics import time
 from distributed.protocol.pickle import dumps
@@ -760,3 +760,26 @@ async def test_worker_inherits_temp_config(c, s):
     async with Nanny(s.address):
         out = await c.submit(lambda: dask.config.get("test123"))
         assert out == 123
+
+
+@pytest.mark.parametrize("api", ["restart", "kill"])
+@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
+async def test_restart_stress_nanny_only(c, s, a, api):
+    async def keep_killing():
+        pool = await ConnectionPool()
+        try:
+            rpc = pool(a.address)
+            for _ in range(2):
+                try:
+                    meth = getattr(rpc, api)
+                    await meth(reason="test-trigger")
+                except OSError:
+                    break
+
+                await asyncio.sleep(0.1)
+        finally:
+            await pool.close()
+
+    kill_tasks = [asyncio.create_task(keep_killing()) for _ in range(2)]
+    await asyncio.gather(*kill_tasks)
+    assert a.status == Status.running
diff --git a/distributed/worker.py b/distributed/worker.py
index 3a6e76fea46..c61c2030be3 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1456,12 +1456,13 @@ async def start_unsafe(self):
         return self
 
     @log_errors
-    async def close(  # type: ignore
+    async def close_unsafe(  # type: ignore
        self,
        timeout: float = 30,
        executor_wait: bool = True,
        nanny: bool = True,
        reason: str = "worker-close",
+        **kwargs,
     ) -> str | None:
         """Close the worker
 
@@ -1488,14 +1489,6 @@ async def close(  # type: ignore
         # is the other way round. If an external caller wants to close
         # nanny+worker, the nanny must be notified first. ==> Remove kwarg
         # nanny, see also Scheduler.retire_workers
-        if self.status in (Status.closed, Status.closing, Status.failed):
-            logging.debug(
-                "Attempted to close worker that is already %s. Reason: %s",
-                self.status,
-                reason,
-            )
-            await self.finished()
-            return None
 
         if self.status == Status.init:
             # If the worker is still in startup/init and is started by a nanny,
@@ -1526,7 +1519,7 @@ async def close(  # type: ignore
             pc.stop()
 
         # Cancel async instructions
-        await BaseWorker.close(self, timeout=timeout)
+        await BaseWorker.close_unsafe(self, timeout=timeout)
 
         for preload in self.preloads:
             try:
@@ -1626,12 +1619,6 @@ def _close(executor, wait):
                 executor=executor, wait=executor_wait
             )  # Just run it directly
 
-        self.stop()
-        await self.rpc.close()
-
-        self.status = Status.closed
-        await ServerNode.close(self)
-
         setproctitle("dask worker [closed]")
         return "OK"
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index fc074ac7bea..2e11666b81d 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -3666,7 +3666,7 @@ def handle_stimulus(self, *stims: StateMachineEvent) -> None:
             self._async_instructions.add(task)
             task.add_done_callback(self._handle_stimulus_from_task)
 
-    async def close(self, timeout: float = 30) -> None:
+    async def close_unsafe(self, timeout: float = 30) -> None:
         """Cancel all asynchronous instructions"""
         if not self._async_instructions:
             return
From d214271ea2650de1fd746dbc91c83db98687f713 Mon Sep 17 00:00:00 2001
From: fjetter
Date: Wed, 16 Nov 2022 08:57:11 +0100
Subject: [PATCH 2/2] Fix signatures

---
 distributed/core.py                 | 2 +-
 distributed/scheduler.py            | 4 +++-
 distributed/worker_state_machine.py | 4 +++-
 3 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/distributed/core.py b/distributed/core.py
index da7a055d971..20eef1d7752 100644
--- a/distributed/core.py
+++ b/distributed/core.py
@@ -867,7 +867,7 @@ async def handle_stream(self, comm, extra=None):
         assert comm.closed()
 
     async def close_unsafe(
-        self, timeout: float | None, reason: str | None, **kwargs: Any
+        self, timeout: float | None = None, reason: str | None = None, **kwargs: Any
     ) -> None:
         """Attempt to close the server. This is not idempotent and not protected
         against concurrent closing attempts.
diff --git a/distributed/scheduler.py b/distributed/scheduler.py index d0c37ed2bb2..4a3d0e9b644 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3889,7 +3889,9 @@ def del_scheduler_file(): setproctitle(f"dask scheduler [{self.address}]") return self - async def close_unsafe(self, timeout: float | None, reason: str | None, **kwargs): + async def close_unsafe( + self, timeout: float | None = None, reason: str | None = None, **kwargs + ): """Send cleanup signal to all coroutines then wait until finished See Also diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 2e11666b81d..22d59c2e654 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -3666,7 +3666,9 @@ def handle_stimulus(self, *stims: StateMachineEvent) -> None: self._async_instructions.add(task) task.add_done_callback(self._handle_stimulus_from_task) - async def close_unsafe(self, timeout: float = 30) -> None: + async def close_unsafe( + self, timeout: float | None = 30, reason: str | None = None + ) -> None: """Cancel all asynchronous instructions""" if not self._async_instructions: return
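
Continuing the MiniServer sketch above (hypothetical names; the real Nanny, Worker,
and Scheduler overrides in these patches carry far more teardown logic), a subclass
only implements close_unsafe, and repeated or concurrent close() calls collapse into
a single teardown. This loosely mirrors the concurrent kill/restart calls issued by
the new test_restart_stress_nanny_only, which additionally relies on the nanny
restarting the worker afterwards:

    class MiniWorker(MiniServer):
        """Hypothetical subclass; stands in for Worker/Nanny in the sketch."""

        def __init__(self) -> None:
            super().__init__()
            self.teardown_count = 0

        async def close_unsafe(self, **kwargs: Any) -> None:
            # Runs at most once, no matter how many close() calls race each other.
            self.teardown_count += 1
            await asyncio.sleep(0.1)  # stand-in for killing subprocesses, closing comms


    async def main() -> None:
        w = MiniWorker()
        # Hammer close() from many tasks at once, as the stress test does via RPC.
        await asyncio.gather(*(w.close(timeout=5) for _ in range(10)))
        assert w.status is Status.closed
        assert w.teardown_count == 1  # the lock plus status check made close() idempotent


    asyncio.run(main())
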