diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index d0e31aa131f..98a4174775d 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -3288,20 +3288,7 @@ def bulk_schedule_after_adding_worker(self, ws: WorkerState) -> Recs:
 
         Returns priority-ordered recommendations.
         """
-        maybe_runnable: list[TaskState] = []
-        # Schedule any queued tasks onto the new worker
-        if not math.isinf(self.WORKER_SATURATION) and self.queued:
-            for qts in reversed(
-                list(
-                    self.queued.peekn(_task_slots_available(ws, self.WORKER_SATURATION))
-                )
-            ):
-                if self.validate:
-                    assert qts.state == "queued"
-                    assert not qts.processing_on
-                    assert not qts.waiting_on
-
-                maybe_runnable.append(qts)
+        maybe_runnable = list(_next_queued_tasks_for_worker(self, ws))[::-1]
 
         # Schedule any restricted tasks onto the new worker, if the worker can run them
         for ts in self.unrunnable:
@@ -5338,6 +5325,14 @@ def handle_long_running(
         ws.add_to_long_running(ts)
         self.check_idle_saturated(ws)
 
+        recommendations = {
+            qts.key: "processing" for qts in _next_queued_tasks_for_worker(self, ws)
+        }
+        if self.validate:
+            assert len(recommendations) <= 1, (ws, recommendations)
+
+        self.transitions(recommendations, stimulus_id)
+
     def handle_worker_status_change(
         self, status: str | Status, worker: str | WorkerState, stimulus_id: str
     ) -> None:
@@ -7886,21 +7881,32 @@ def _exit_processing_common(
     state.check_idle_saturated(ws)
     state.release_resources(ts, ws)
 
-    # If a slot has opened up for a queued task, schedule it.
-    if state.queued and not _worker_full(ws, state.WORKER_SATURATION):
-        qts = state.queued.peek()
+    for qts in _next_queued_tasks_for_worker(state, ws):
         if state.validate:
-            assert qts.state == "queued", qts.state
             assert qts.key not in recommendations, recommendations[qts.key]
-
-        # NOTE: we don't need to schedule more than one task at once here. Since this is
-        # called each time 1 task completes, multiple tasks must complete for multiple
-        # slots to open up.
         recommendations[qts.key] = "processing"
 
     return ws
 
 
+def _next_queued_tasks_for_worker(
+    state: SchedulerState, ws: WorkerState
+) -> Iterator[TaskState]:
+    """Queued tasks to run, in priority order, on all open slots on a worker"""
+    if not state.queued or ws.status != Status.running:
+        return
+
+    # NOTE: this is called most frequently because a single task has completed, so there
+    # are <= 1 task slots available on the worker.
+    # `peekn` has fast paths for the cases N<=0 and N==1.
+    for qts in state.queued.peekn(_task_slots_available(ws, state.WORKER_SATURATION)):
+        if state.validate:
+            assert qts.state == "queued", qts.state
+            assert not qts.processing_on
+            assert not qts.waiting_on
+        yield qts
+
+
 def _add_to_memory(
     state: SchedulerState,
     ts: TaskState,
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 6e84db227d9..db231c317dd 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -64,7 +64,7 @@
     varying,
     wait_for_state,
 )
-from distributed.worker import dumps_function, dumps_task, get_worker
+from distributed.worker import dumps_function, dumps_task, get_worker, secede
 
 pytestmark = pytest.mark.ci1
 
@@ -479,6 +479,26 @@ async def test_queued_remove_add_worker(c, s, a, b):
     await wait(fs)
 
 
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_secede_opens_slot(c, s, a):
+    first = Event()
+    second = Event()
+
+    def func(first, second):
+        first.wait()
+        secede()
+        second.wait()
+
+    fs = c.map(func, [first] * 5, [second] * 5)
+    await async_wait_for(lambda: a.state.executing, timeout=5)
+
+    await first.set()
+    await async_wait_for(lambda: len(a.state.long_running) == len(fs), timeout=5)
+
+    await second.set()
+    await c.gather(fs)
+
+
 @pytest.mark.parametrize(
     "saturation_config, expected_task_counts",
     [
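For context on what the new `handle_long_running` hook enables: when a task calls `secede()`, it gives up its thread slot, and with this change the scheduler immediately promotes a queued task into that freed slot via `_next_queued_tasks_for_worker` instead of waiting for some other event (such as a task completing) to trigger scheduling. Below is a minimal client-side sketch of that pattern, not part of this patch; the `poll_external` function and the cluster sizing are illustrative assumptions, while `Client`, `LocalCluster`, and `secede` are the real public dask.distributed API.

from distributed import Client, LocalCluster, secede
import time

def poll_external(i):
    # Secede from the worker's thread pool: this task no longer counts
    # against the worker's thread limit, so when scheduler-side queuing is
    # active a queued task can be started on the slot we just gave up.
    secede()
    time.sleep(1)  # stand-in for waiting on an external service
    return i

if __name__ == "__main__":
    # One worker with one thread, so extra root tasks sit in the scheduler
    # queue until a slot opens up.
    with LocalCluster(n_workers=1, threads_per_worker=1) as cluster:
        with Client(cluster) as client:
            futures = client.map(poll_external, range(5))
            print(client.gather(futures))

This mirrors the new test_secede_opens_slot test above: all five tasks end up long-running on the single-threaded worker because each secede frees the slot for the next queued task.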