From 2f6085749ec9cbfbfc41a8f346fcde08a330bfc3 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 22 Apr 2021 15:07:02 +0200 Subject: [PATCH] Close workers if an exception occurs during transition --- distributed/tests/test_worker.py | 49 +++++++++++++++++++++++++++++- distributed/worker.py | 52 ++++++++++++++++++++++++-------- 2 files changed, 87 insertions(+), 14 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 4681c5c36de..2d3c4885ef7 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -54,7 +54,14 @@ s, slowinc, ) -from distributed.worker import Worker, error_message, logger, parse_memory_limit, weight +from distributed.worker import ( + TaskState, + Worker, + error_message, + logger, + parse_memory_limit, + weight, +) @pytest.mark.asyncio @@ -1792,3 +1799,43 @@ async def test_story(c, s, w): def test_weight_deprecated(): with pytest.warns(DeprecationWarning): weight("foo", "bar") + + +@gen_cluster(nthreads=[("127.0.0.1", 1)]) +async def test_worker_closes_if_transition_raises_Worker(s, a): + """Ensure that a worker closes if an exception is raised during + transitioning a task since we do not support transactions and transition + methods are usually not atomic.""" + with captured_logger("distributed.worker", level=logging.CRITICAL) as log: + # we do not care what kind of exception is raised here + with pytest.raises(Exception): + a.transition(TaskState("key"), "nope") + while a.status == Status.running: + await asyncio.sleep(0.005) + assert a.status in [Status.closing_gracefully, Status.closing, Status.closed] + log = log.getvalue() + assert "Caught exception in attempt to transition" in log + + +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)], Worker=Nanny) +async def test_worker_closes_if_transition_raises_Nanny(c, s, a): + """Ensure that a worker closes if an exception is raised during + transitioning a task since we do not support transactions and transition + 
methods are usually not atomic. + If a Nanny is available, restart. + """ + orig_process = a.process + + # This also should log, see test_worker_closes_if_transition_raises_Worker + # but capturing logs is more difficult with Nanny + def transition_key(dask_worker): + dask_worker.transition(TaskState("key"), "nope") + + with pytest.raises(Exception): + await c.run(transition_key) + while orig_process.is_alive(): + await asyncio.sleep(0.005) + + # A new worker should start and finish this + fut = c.submit(inc, 1) + assert await fut.result() == 2 diff --git a/distributed/worker.py b/distributed/worker.py index 1acd395fc12..80f35368f6b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1290,7 +1290,13 @@ async def close_gracefully(self, restart=None): logger.info("Closing worker gracefully: %s", self.address) self.status = Status.closing_gracefully - await self.scheduler.retire_workers(workers=[self.address], remove=False) + await self.scheduler.retire_workers( + workers=[self.address], + remove=False, + # Save one RPC call since it is unnecessary for the scheduler to + # close this worker since we're doing it ourselves below + close_workers=False, + ) await self.close(safe=True, nanny=not restart) async def terminate(self, comm=None, report=True, **kwargs): @@ -1592,18 +1598,38 @@ def add_task( raise def transition(self, ts, finish, **kwargs): - if ts is None: - return - start = ts.state - if start == finish: - return - func = self._transitions[start, finish] - state = func(ts, **kwargs) - self.log.append((ts.key, start, state or finish)) - ts.state = state or finish - if self.validate: - 
self.validate_task(ts) + self._notify_plugins("transition", ts.key, start, state or finish, **kwargs) + except Exception as exc: + # We cannot perform something like a transaction rollback. Therefore + # we should be unforgiving and close if we suspect something is + # wrong. + logger.critical( + "Caught exception in attempt to transition %s to %s. We can no " + "longer guarantee a valid state of this Worker and need close. " + "Please file a bug report at https://github.com/dask/distributed/issues " + "with logs leading up to this event as you see appropriate.", + ts, + finish, + exc_info=exc, + ) + # The graceful removal of a worker will not increase any suspicious + # counter since it is removed with flag "safe". We might need to + # reconsider this behaviour in case a given task is transitioned + # wrongfully all the time + self.loop.add_callback(self.close_gracefully, restart=True) + raise def transition_new_waiting(self, ts): try: