diff --git a/distributed/scheduler.py b/distributed/scheduler.py index c13f5292538..25d4629db17 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3977,15 +3977,20 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task finished %s, %s", key, worker) + recommendations: dict = {} + worker_msgs: dict = {} + client_msgs: dict = {} + ts: TaskState = parent._tasks.get(key) if ts is None: - return {} + return recommendations, worker_msgs, client_msgs + ws: WorkerState = parent._workers_dv[worker] ts._metadata.update(kwargs["metadata"]) - recommendations: dict if ts._state == "processing": - recommendations = self.transition(key, "memory", worker=worker, **kwargs) + a: tuple = self._transition(key, "memory", worker=worker, **kwargs) + recommendations, worker_msgs, client_msgs = a if ts._state == "memory": assert ws in ts._who_has @@ -3999,10 +4004,9 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): ts._who_has, ) if ws not in ts._who_has: - self.worker_send(worker, {"op": "release-task", "key": key}) - recommendations = {} + worker_msgs[worker] = {"op": "release-task", "key": key} - return recommendations + return recommendations, worker_msgs, client_msgs def stimulus_task_erred( self, key=None, worker=None, exception=None, traceback=None, **kwargs @@ -4587,8 +4591,46 @@ def handle_task_finished(self, key=None, worker=None, **msg): if worker not in parent._workers_dv: return validate_key(key) - r = self.stimulus_task_finished(key=key, worker=worker, **msg) - self.transitions(r) + + a: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg) + recommendations, worker_msgs, client_msgs = a + + keys: set = set() + recommendations = recommendations.copy() + worker_msgs: dict = {} + client_msgs: dict = {} + msgs: list + new_msgs: list + new: tuple + new_recs: dict + new_wmsgs: dict + new_cmsgs: dict + while recommendations: + key, finish = recommendations.popitem() + keys.add(key) + + new = self._transition(key, finish) + new_recs, new_wmsgs, new_cmsgs = new + + recommendations.update(new_recs) + for w, new_msgs in new_wmsgs.items(): + msgs = worker_msgs.get(w) + if msgs is not None: + msgs.extend(new_msgs) + else: + worker_msgs[w] = new_msgs + for c, new_msgs in new_cmsgs.items(): + msgs = client_msgs.get(c) + if msgs is not None: + msgs.extend(new_msgs) + else: + client_msgs[c] = new_msgs + + self.send_all(client_msgs, worker_msgs) + + if parent._validate: + for key in keys: + self.validate_key(key) def handle_task_erred(self, key=None, **msg): r = self.stimulus_task_erred(key=key, **msg)