diff --git a/distributed/core.py b/distributed/core.py
index 9aaf535191c..cb33463ede5 100644
--- a/distributed/core.py
+++ b/distributed/core.py
@@ -136,7 +136,6 @@ def __init__(
         connection_args=None,
         timeout=None,
         io_loop=None,
-        **kwargs,
     ):
         self.handlers = {
             "identity": self.identity,
@@ -238,8 +237,6 @@ def set_thread_ident():
 
         self.__stopped = False
 
-        super().__init__(**kwargs)
-
     @property
     def status(self):
         return self._status
diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py
index b5306b37818..b35183b3489 100644
--- a/distributed/deploy/tests/test_local.py
+++ b/distributed/deploy/tests/test_local.py
@@ -1063,3 +1063,21 @@ async def test_cluster_names():
 
         async with LocalCluster(processes=False, asynchronous=True) as unnamed_cluster2:
             assert unnamed_cluster2 != unnamed_cluster
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("nanny", [True, False])
+async def test_local_cluster_redundant_kwarg(nanny):
+    with pytest.raises(TypeError, match="unexpected keyword argument"):
+        # Extra arguments are forwarded to the worker class. Depending on
+        # whether we use the nanny or not, the error treatment is quite
+        # different, so we assert that an exception is raised either way.
+        async with await LocalCluster(
+            typo_kwarg="foo", processes=nanny, n_workers=1
+        ) as cluster:
+
+            # This will never work but is a reliable way to block without
+            # hard-coding any sleep values
+            async with Client(cluster) as c:
+                f = c.submit(sleep, 0)
+                await f
diff --git a/distributed/nanny.py b/distributed/nanny.py
index a275055029d..8cd28ed4189 100644
--- a/distributed/nanny.py
+++ b/distributed/nanny.py
@@ -9,6 +9,7 @@
 import weakref
 from contextlib import suppress
 from multiprocessing.queues import Empty
+from time import sleep as sync_sleep
 
 import psutil
 from tornado import gen
@@ -362,7 +363,6 @@ async def instantiate(self, comm=None) -> Status:
                 config=self.config,
             )
 
-        self.auto_restart = True
         if self.death_timeout:
             try:
                 result = await asyncio.wait_for(
@@ -378,7 +378,11 @@
                 raise
 
         else:
-            result = await self.process.start()
+            try:
+                result = await self.process.start()
+            except Exception:
+                await self.close()
+                raise
         return result
 
     async def restart(self, comm=None, timeout=2, executor_wait=True):
@@ -414,9 +418,10 @@ def memory_monitor(self):
         """Track worker's memory. Restart if it goes above terminate fraction"""
         if self.status != Status.running:
             return
+        if self.process is None or self.process.process is None:
+            return None
         process = self.process.process
-        if process is None:
-            return
+
         try:
             proc = self._psutil_process
             memory = proc.memory_info().rss
@@ -519,6 +524,9 @@ async def close(self, comm=None, timeout=5, report=None):
 
 
 class WorkerProcess:
+    # Interval (in seconds) at which to poll the init msg queue
+    _init_msg_interval = 0.05
+
     def __init__(
         self,
         worker_kwargs,
@@ -584,9 +592,14 @@ async def start(self) -> Status:
         except OSError:
             logger.exception("Nanny failed to start process", exc_info=True)
             self.process.terminate()
-            return
-
-        msg = await self._wait_until_connected(uid)
+            self.status = Status.failed
+            return self.status
+        try:
+            msg = await self._wait_until_connected(uid)
+        except Exception:
+            self.status = Status.failed
+            self.process.terminate()
+            raise
         if not msg:
             return self.status
         self.worker_address = msg["address"]
@@ -683,14 +696,15 @@ async def kill(self, timeout=2, executor_wait=True):
             logger.error("Failed to kill worker process: %s", e)
 
     async def _wait_until_connected(self, uid):
-        delay = 0.05
         while True:
             if self.status != Status.starting:
                 return
+            # This is a multiprocessing queue and we'd block the event loop if
+            # we simply called get()
             try:
                 msg = self.init_result_q.get_nowait()
             except Empty:
-                await asyncio.sleep(delay)
+                await asyncio.sleep(self._init_msg_interval)
                 continue
 
             if msg["uid"] != uid:  # ensure that we didn't cross queues
@@ -700,7 +714,6 @@ def _run(
                 logger.error(
                     "Failed while trying to start worker process: %s", msg["exception"]
                 )
-                await self.process.join()
                 raise msg["exception"]
             else:
                 return msg
@@ -718,88 +731,108 @@ def _run(
         config,
         Worker,
     ):  # pragma: no cover
-        os.environ.update(env)
-        dask.config.set(config)
         try:
-            from dask.multiprocessing import initialize_worker_process
-        except ImportError:  # old Dask version
-            pass
-        else:
-            initialize_worker_process()
+            os.environ.update(env)
+            dask.config.set(config)
+            try:
+                from dask.multiprocessing import initialize_worker_process
+            except ImportError:  # old Dask version
+                pass
+            else:
+                initialize_worker_process()
 
-        if silence_logs:
-            logger.setLevel(silence_logs)
+            if silence_logs:
+                logger.setLevel(silence_logs)
 
-        IOLoop.clear_instance()
-        loop = IOLoop()
-        loop.make_current()
-        worker = Worker(**worker_kwargs)
+            IOLoop.clear_instance()
+            loop = IOLoop()
+            loop.make_current()
+            worker = Worker(**worker_kwargs)
 
-        async def do_stop(timeout=5, executor_wait=True):
-            try:
-                await worker.close(
-                    report=True,
-                    nanny=False,
-                    safe=True,  # TODO: Graceful or not?
-                    executor_wait=executor_wait,
-                    timeout=timeout,
-                )
-            finally:
-                loop.stop()
-
-        def watch_stop_q():
-            """
-            Wait for an incoming stop message and then stop the
-            worker cleanly.
-            """
-            while True:
-                try:
-                    msg = child_stop_q.get(timeout=1000)
-                except Empty:
-                    pass
-                else:
-                    child_stop_q.close()
-                    assert msg.pop("op") == "stop"
-                    loop.add_callback(do_stop, **msg)
-                    break
-
-        t = threading.Thread(target=watch_stop_q, name="Nanny stop queue watch")
-        t.daemon = True
-        t.start()
-
-        async def run():
-            """
-            Try to start worker and inform parent of outcome.
-            """
-            try:
-                await worker
-            except Exception as e:
-                logger.exception("Failed to start worker")
-                init_result_q.put({"uid": uid, "exception": e})
-                init_result_q.close()
-            else:
+            async def do_stop(timeout=5, executor_wait=True):
                 try:
-                    assert worker.address
-                except ValueError:
-                    pass
-                else:
-                    init_result_q.put(
-                        {
-                            "address": worker.address,
-                            "dir": worker.local_directory,
-                            "uid": uid,
-                        }
+                    await worker.close(
+                        report=True,
+                        nanny=False,
+                        safe=True,  # TODO: Graceful or not?
+                        executor_wait=executor_wait,
+                        timeout=timeout,
                     )
+                finally:
+                    loop.stop()
+
+            def watch_stop_q():
+                """
+                Wait for an incoming stop message and then stop the
+                worker cleanly.
+                """
+                while True:
+                    try:
+                        msg = child_stop_q.get(timeout=1000)
+                    except Empty:
+                        pass
+                    else:
+                        child_stop_q.close()
+                        assert msg.pop("op") == "stop"
+                        loop.add_callback(do_stop, **msg)
+                        break
+
+            t = threading.Thread(target=watch_stop_q, name="Nanny stop queue watch")
+            t.daemon = True
+            t.start()
+
+            async def run():
+                """
+                Try to start worker and inform parent of outcome.
+                """
+                try:
+                    await worker
+                except Exception as e:
+                    logger.exception("Failed to start worker")
+                    init_result_q.put({"uid": uid, "exception": e})
                     init_result_q.close()
-                    await worker.finished()
-                    logger.info("Worker closed")
-
-        try:
-            loop.run_sync(run)
-        except (TimeoutError, gen.TimeoutError):
-            # Loop was stopped before wait_until_closed() returned, ignore
-            pass
-        except KeyboardInterrupt:
-            # At this point the loop is not running thus we have to run
-            # do_stop() explicitly.
-            loop.run_sync(do_stop)
+                    # If we hit an exception here we need to wait for at least
+                    # one interval for the outside to pick up this message.
+                    # Otherwise we arrive in a race condition where the process
+                    # cleanup wipes the queue before the exception can be
+                    # properly handled. See also
+                    # WorkerProcess._wait_until_connected (the 2 is for good
+                    # measure)
+                    sync_sleep(cls._init_msg_interval * 2)
+                else:
+                    try:
+                        assert worker.address
+                    except ValueError:
+                        pass
+                    else:
+                        init_result_q.put(
+                            {
+                                "address": worker.address,
+                                "dir": worker.local_directory,
+                                "uid": uid,
+                            }
+                        )
+                        init_result_q.close()
+                        await worker.finished()
+                        logger.info("Worker closed")
+
+        except Exception as e:
+            logger.exception("Failed to initialize Worker")
+            init_result_q.put({"uid": uid, "exception": e})
+            init_result_q.close()
+            # If we hit an exception here we need to wait for at least one
+            # interval for the outside to pick up this message. Otherwise we
+            # arrive in a race condition where the process cleanup wipes the
+            # queue before the exception can be properly handled. See also
+            # WorkerProcess._wait_until_connected (the 2 is for good measure)
+            sync_sleep(cls._init_msg_interval * 2)
+        else:
+            try:
+                loop.run_sync(run)
+            except (TimeoutError, gen.TimeoutError):
+                # Loop was stopped before wait_until_closed() returned, ignore
                pass
+            except KeyboardInterrupt:
+                # At this point the loop is not running thus we have to run
+                # do_stop() explicitly.
+                loop.run_sync(do_stop)
diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py
index 398b933c02c..013daab9e06 100644
--- a/distributed/tests/test_core.py
+++ b/distributed/tests/test_core.py
@@ -880,3 +880,9 @@ async def sleep(comm=None):
 
     # weakref set/dict should be cleaned up
     assert not len(server._ongoing_coroutines)
+
+
+@pytest.mark.asyncio
+async def test_server_redundant_kwarg():
+    with pytest.raises(TypeError, match="unexpected keyword argument"):
+        await Server({}, typo_kwarg="foo")
diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py
index 779cbdc4a4a..1203fb3bbed 100644
--- a/distributed/tests/test_nanny.py
+++ b/distributed/tests/test_nanny.py
@@ -14,7 +14,6 @@
 import dask
 
 from distributed import Client, Nanny, Scheduler, Worker, rpc, wait, worker
-from distributed.compatibility import MACOS
 from distributed.core import CommClosedError, Status
 from distributed.diagnostics import SchedulerPlugin
 from distributed.metrics import time
@@ -565,10 +564,19 @@ async def start(self):
         raise StartException("broken")
 
 
-@pytest.mark.flaky(reruns=10, reruns_delay=5, condition=MACOS)
 @pytest.mark.asyncio
 async def test_worker_start_exception(cleanup):
     # make sure this raises the right Exception:
     with pytest.raises(StartException):
         async with Nanny("tcp://localhost:1", worker_class=BrokenWorker) as n:
             await n.start()
+
+
+@pytest.mark.asyncio
+async def test_failure_during_worker_initialization(cleanup):
+    with captured_logger(logger="distributed.nanny", level=logging.WARNING) as logs:
+        async with Scheduler() as s:
+            with pytest.raises(Exception):
+                async with Nanny(s.address, foo="bar") as n:
+                    await n
+    assert "Restarting worker" not in logs.getvalue()
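
Note (reviewer sketch, not part of the patch): the snippet below simply mirrors test_server_redundant_kwarg above. With **kwargs removed from Server.__init__, an unknown keyword argument now fails fast with a clear TypeError when the server object is constructed; for a Nanny-managed worker the same kind of typo previously tended to surface only later, e.g. as repeated "Restarting worker" log messages, which test_failure_during_worker_initialization now guards against.

    # Minimal sketch, assuming the distributed package from this branch is importable.
    from distributed.core import Server

    try:
        Server({}, typo_kwarg="foo")  # signature no longer accepts arbitrary kwargs
    except TypeError as e:
        print(e)  # e.g. "__init__() got an unexpected keyword argument 'typo_kwarg'"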