def transition_no_worker_memory(
    self, key, nbytes=None, type=None, typename: str = None, worker=None
):
    """Move the task ``key`` directly from ``no-worker`` to ``memory``.

    Handler for the ``("no-worker", "memory")`` entry of the scheduler's
    transition table.  ``test_worker_reconnect_task_memory_with_resources``
    expects this transition to show up in ``transition_log`` after a worker
    is removed from the scheduler with ``close=False`` and then keeps
    sending heartbeats.

    Parameters
    ----------
    key
        Key of the task; looked up in ``self._tasks``.
    nbytes
        When not ``None``, handed to ``ts.set_nbytes``.
    type, typename
        Forwarded unchanged to ``_add_to_memory``.
    worker
        Key into ``self._workers_dv`` selecting the worker that holds the
        result (presumably the worker address -- confirm against callers).
        Despite the ``None`` default, the value must resolve to an entry of
        ``self._workers_dv``.

    Returns
    -------
    tuple
        ``(recommendations, client_msgs, worker_msgs)``.  ``worker_msgs`` is
        never written to here and is returned empty; the other two dicts are
        handed to ``_add_to_memory`` (presumably filled in there -- confirm).

    Raises
    ------
    Exception
        Anything raised inside the body (including a failed worker/task
        lookup or ``self._unrunnable.remove``) is logged with
        ``logger.exception`` and re-raised.  If ``LOG_PDB`` is truthy a pdb
        session is opened before re-raising.
    """
    try:
        # The lookups live inside the ``try`` so that an unknown worker or
        # key is logged like any other failure.
        ws: WorkerState = self._workers_dv[worker]
        ts: TaskState = self._tasks[key]
        recommendations: dict = {}
        client_msgs: dict = {}
        worker_msgs: dict = {}

        # Consistency checks, only run when validation is switched on.
        if self._validate:
            assert not ts._processing_on
            assert not ts._waiting_on
            assert ts._state == "no-worker"

        # NOTE(review): ``_unrunnable`` presumably holds the tasks that are
        # currently in the "no-worker" state -- confirm.
        self._unrunnable.remove(ts)

        if nbytes is not None:
            ts.set_nbytes(nbytes)

        self.check_idle_saturated(ws)

        _add_to_memory(
            self, ts, ws, recommendations, client_msgs, type=type, typename=typename
        )
        ts.state = "memory"

        return recommendations, client_msgs, worker_msgs
    except Exception as e:
        logger.exception(e)
        if LOG_PDB:
            import pdb

            pdb.set_trace()
        raise
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_worker_reconnect_task_memory_with_resources(c, s, a):
    """Check that a ``no-worker`` -> ``memory`` transition gets recorded.

    A second worker ``b`` is started with resource ``{"A": 1}`` and the
    mapped tasks are submitted with that same resource requirement.  Once
    ``b`` has begun working, it is removed from the scheduler with
    ``close=False`` and then heartbeats by hand until the final result is
    done.  The scheduler's transition log must then contain at least one
    ``("no-worker", "memory")`` step.
    """
    async with Worker(s.address, resources={"A": 1}) as b:
        # Heartbeats are issued manually further down.
        b.periodic_callbacks["heartbeat"].stop()

        increments = c.map(inc, range(10), resources={"A": 1})
        total = c.submit(sum, increments)

        # Spin until b is executing something or already holds some data.
        while not (b.executing_count or b.data):
            await asyncio.sleep(0.001)

        await s.remove_worker(address=b.address, close=False)

        # Keep heartbeating from b until the final future has finished.
        while not total.done():
            await b.heartbeat()

        await total

        observed_steps = {
            (start, finish) for (_, start, finish, _, _) in s.transition_log
        }
        assert ("no-worker", "memory") in observed_steps