diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index df115c50466..ab4acc97876 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -2100,16 +2100,20 @@ def decide_worker_rootish_queuing_disabled(
 
         tg = ts.group
         lws = tg.last_worker
-        if not (
-            lws and tg.last_worker_tasks_left and self.workers.get(lws.address) is lws
+        if (
+            lws
+            and tg.last_worker_tasks_left
+            and lws.status == Status.running
+            and self.workers.get(lws.address) is lws
         ):
-            # Last-used worker is full or unknown; pick a new worker for the next few tasks
+            ws = lws
+        else:
+            # Last-used worker is full, unknown, retiring, or paused;
+            # pick a new worker for the next few tasks
             ws = min(pool, key=partial(self.worker_objective, ts))
             tg.last_worker_tasks_left = math.floor(
                 (len(tg) / self.total_nthreads) * ws.nthreads
             )
-        else:
-            ws = lws
 
         # Record `last_worker`, or clear it on the final task
         tg.last_worker = (
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 3ed53391dd9..6e84db227d9 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -43,6 +43,7 @@
 from distributed.utils import TimeoutError
 from distributed.utils_test import (
     NO_AMM,
+    BlockedGatherDep,
     BrokenComm,
     async_wait_for,
     captured_logger,
@@ -253,6 +254,76 @@ def random(**kwargs):
     test_decide_worker_coschedule_order_neighbors_()
 
 
+@pytest.mark.skipif(
+    math.isfinite(dask.config.get("distributed.scheduler.worker-saturation")),
+    reason="Not relevant with queuing on; see https://github.com/dask/distributed/issues/7204",
+)
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)],
+    config={"distributed.scheduler.work-stealing": False},
+)
+async def test_decide_worker_rootish_while_last_worker_is_retiring(c, s, a):
+    """https://github.com/dask/distributed/issues/7063"""
+    # Put a task in memory on the worker to be retired and prevent the other from
+    # acquiring a replica. This will cause a to be stuck in closing_gracefully later on,
+    # until we set b.block_gather_dep.
+    m = (await c.scatter({"m": 1}, workers=[a.address]))["m"]
+
+    evx = [Event() for _ in range(3)]
+    evy = Event()
+
+    async with BlockedGatherDep(s.address, nthreads=1) as b:
+        xs = [
+            c.submit(lambda ev: ev.wait(), evx[0], key="x-0", workers=[a.address]),
+            c.submit(lambda ev: ev.wait(), evx[1], key="x-1", workers=[a.address]),
+            c.submit(lambda ev: ev.wait(), evx[2], key="x-2", workers=[b.address]),
+        ]
+        ys = [
+            c.submit(lambda x, ev: ev.wait(), xs[0], evy, key="y-0"),
+            c.submit(lambda x, ev: ev.wait(), xs[0], evy, key="y-1"),
+            c.submit(lambda x, ev: ev.wait(), xs[1], evy, key="y-2"),
+            c.submit(lambda x, ev: ev.wait(), xs[2], evy, key="y-3"),
+            c.submit(lambda x, ev: ev.wait(), xs[2], evy, key="y-4"),
+            c.submit(lambda x, ev: ev.wait(), xs[2], evy, key="y-5"),
+        ]
+
+        while a.state.executing_count != 1 or b.state.executing_count != 1:
+            await asyncio.sleep(0.01)
+
+        # - y-2 has no restrictions
+        # - TaskGroup(y) has more than 4 tasks (total_nthreads * 2)
+        # - TaskGroup(y) has less than 5 dependency groups
+        # - TaskGroup(y) has less than 5 dependency tasks
+        assert s.is_rootish(s.tasks["y-2"])
+
+        await evx[0].set()
+        await wait_for_state("y-0", "processing", s)
+        await wait_for_state("y-1", "processing", s)
+        assert s.tasks["y-2"].group.last_worker == s.workers[a.address]
+        assert s.tasks["y-2"].group.last_worker_tasks_left == 1
+
+        # Take a out of the running pool, but without removing it from the cluster
+        # completely
+        retire_task = asyncio.create_task(c.retire_workers([a.address]))
+        # Wait until AMM sends AcquireReplicasEvent to b to move away m
+        await b.in_gather_dep.wait()
+        assert s.workers[a.address].status == Status.closing_gracefully
+
+        # Transition y-2 to processing. Normally, it would be scheduled on a, but it's
+        # not a running worker, so we must choose b
+        await evx[1].set()
+        await wait_for_state("y-2", "processing", s)
+        await wait_for_state("y-2", "waiting", b)  # x-1 is in memory on a
+
+        # Cleanup
+        b.block_gather_dep.set()
+        await evx[2].set()
+        await evy.set()
+        await retire_task
+        await wait(xs + ys)
+
+
 @pytest.mark.slow
 @gen_cluster(
     nthreads=[("", 2)] * 4,
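
Note: the behavioural change in the scheduler.py hunk can be illustrated with a minimal, self-contained sketch. The names `Status`, `last_worker`, and `last_worker_tasks_left` mirror the scheduler attributes touched above; the `Worker` / `TaskGroup` dataclasses and the `pick_worker` helper below are simplified stand-ins rather than the real `SchedulerState` API, and the `worker_objective` minimization is replaced by a trivial "first running worker" choice.

from __future__ import annotations

import math
from dataclasses import dataclass
from enum import Enum


class Status(Enum):
    running = "running"
    paused = "paused"
    closing_gracefully = "closing_gracefully"


@dataclass
class Worker:
    address: str
    nthreads: int = 1
    status: Status = Status.running


@dataclass
class TaskGroup:
    ntasks: int
    last_worker: Worker | None = None
    last_worker_tasks_left: int = 0


def pick_worker(
    tg: TaskGroup, workers: dict[str, Worker], total_nthreads: int
) -> Worker:
    """Reuse the task group's last worker only while it is still running."""
    lws = tg.last_worker
    if (
        lws
        and tg.last_worker_tasks_left
        and lws.status == Status.running  # the check added by the patch
        and workers.get(lws.address) is lws
    ):
        ws = lws
    else:
        # Last-used worker is full, unknown, retiring, or paused; pick a new one.
        # (The real scheduler minimizes worker_objective; this sketch simply takes
        # the first running worker.)
        ws = next(w for w in workers.values() if w.status == Status.running)
        tg.last_worker_tasks_left = math.floor(tg.ntasks / total_nthreads * ws.nthreads)

    tg.last_worker = ws
    tg.last_worker_tasks_left -= 1
    return ws


a, b = Worker("a"), Worker("b")
workers = {w.address: w for w in (a, b)}
tg = TaskGroup(ntasks=6, last_worker=a, last_worker_tasks_left=1)

a.status = Status.closing_gracefully  # `a` is being retired, as in the test above
assert pick_worker(tg, workers, total_nthreads=2) is b  # pre-patch logic returned `a`

Without the `lws.status == Status.running` check, the final call would keep handing work to the retiring worker `a`, which is exactly the scenario exercised by `test_decide_worker_rootish_while_last_worker_is_retiring` above.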