diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 911bf4fba3d..0ddd630b71f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1769,8 +1769,8 @@ def transition_released_waiting(self, key): ts: TaskState = self._tasks[key] dts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert ts._run_spec @@ -1781,7 +1781,7 @@ def transition_released_waiting(self, key): if ts._has_lost_dependencies: recommendations[key] = "forgotten" - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs ts.state = "waiting" @@ -1790,7 +1790,7 @@ def transition_released_waiting(self, key): if dts._exception_blame: ts._exception_blame = dts._exception_blame recommendations[key] = "erred" - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs for dts in ts._dependencies: dep = dts._key @@ -1810,7 +1810,7 @@ def transition_released_waiting(self, key): self._unrunnable.add(ts) ts.state = "no-worker" - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -1824,8 +1824,8 @@ def transition_no_worker_waiting(self, key): ts: TaskState = self._tasks[key] dts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert ts in self._unrunnable @@ -1837,7 +1837,7 @@ def transition_no_worker_waiting(self, key): if ts._has_lost_dependencies: recommendations[key] = "forgotten" - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs for dts in ts._dependencies: dep = dts._key @@ -1857,7 +1857,7 @@ def transition_no_worker_waiting(self, key): self._unrunnable.add(ts) ts.state = "no-worker" - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -1934,8 +1934,8 @@ def transition_waiting_processing(self, key): ts: TaskState = self._tasks[key] dts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert not ts._waiting_on @@ -1948,7 +1948,7 @@ def transition_waiting_processing(self, key): ws: WorkerState = self.decide_worker(ts) if ws is None: - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs worker = ws._address duration_estimate = self.set_duration_estimate(ts, ws) @@ -1967,7 +1967,7 @@ def transition_waiting_processing(self, key): worker_msgs[worker] = [_task_to_msg(self, ts)] - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -1983,8 +1983,8 @@ def transition_waiting_memory( ws: WorkerState = self._workers_dv[worker] ts: TaskState = self._tasks[key] recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert not ts._processing_on @@ -2007,7 +2007,7 @@ def transition_waiting_memory( assert not ts._waiting_on assert ts._who_has - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2029,8 +2029,8 @@ def transition_processing_memory( ws: WorkerState wws: WorkerState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} try: ts: TaskState = self._tasks[key] assert worker @@ -2048,7 +2048,7 @@ def transition_processing_memory( ws = self._workers_dv.get(worker) if ws is None: recommendations[key] = "released" - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs if ws != ts._processing_on: # someone else has this task logger.info( @@ -2058,7 +2058,7 @@ def transition_processing_memory( ws, key, ) - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs has_compute_startstop: bool = False compute_start: double @@ -2124,7 +2124,7 @@ def transition_processing_memory( assert not ts._processing_on assert not ts._waiting_on - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2139,8 +2139,8 @@ def transition_memory_released(self, key, safe: bint = False): ts: TaskState = self._tasks[key] dts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert not ts._waiting_on @@ -2157,8 +2157,8 @@ def transition_memory_released(self, key, safe: bint = False): recommendations[ts._key] = "erred" return ( recommendations, - worker_msgs, client_msgs, + worker_msgs, ) # don't try to recreate for dts in ts._waiters: @@ -2199,7 +2199,7 @@ def transition_memory_released(self, key, safe: bint = False): if self._validate: assert not ts._waiting_on - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2214,8 +2214,8 @@ def transition_released_erred(self, key): dts: TaskState failing_ts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: with log_errors(pdb=LOG_PDB): @@ -2244,7 +2244,7 @@ def transition_released_erred(self, key): ts.state = "erred" # TODO: waiting data? - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2258,8 +2258,8 @@ def transition_erred_released(self, key): ts: TaskState = self._tasks[key] dts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: with log_errors(pdb=LOG_PDB): @@ -2284,7 +2284,7 @@ def transition_erred_released(self, key): ts.state = "released" - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2297,8 +2297,8 @@ def transition_waiting_released(self, key): try: ts: TaskState = self._tasks[key] recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert not ts._who_has @@ -2322,7 +2322,7 @@ def transition_waiting_released(self, key): else: ts._waiters.clear() - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2336,8 +2336,8 @@ def transition_processing_released(self, key): ts: TaskState = self._tasks[key] dts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert ts._processing_on @@ -2368,7 +2368,7 @@ def transition_processing_released(self, key): if self._validate: assert not ts._processing_on - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2386,8 +2386,8 @@ def transition_processing_erred( dts: TaskState failing_ts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert cause or ts._exception_blame @@ -2447,7 +2447,7 @@ def transition_processing_erred( if self._validate: assert not ts._processing_on - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2461,8 +2461,8 @@ def transition_no_worker_released(self, key): ts: TaskState = self._tasks[key] dts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert self._tasks[key].state == "no-worker" @@ -2477,7 +2477,7 @@ def transition_no_worker_released(self, key): ts._waiters.clear() - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2504,8 +2504,8 @@ def transition_memory_forgotten(self, key): try: ts: TaskState = self._tasks[key] recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert ts._state == "memory" @@ -2532,7 +2532,7 @@ def transition_memory_forgotten(self, key): client_msgs = _task_to_client_msgs(self, ts) self.remove_key(key) - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2545,8 +2545,8 @@ def transition_released_forgotten(self, key): try: ts: TaskState = self._tasks[key] recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert ts._state in ("released", "erred") @@ -2570,7 +2570,7 @@ def transition_released_forgotten(self, key): client_msgs = _task_to_client_msgs(self, ts) self.remove_key(key) - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -3120,7 +3120,7 @@ def __init__( "stop_task_metadata": self.stop_task_metadata, } - self._transitions = { + self._transitions_table = { ("released", "waiting"): self.transition_released_waiting, ("waiting", "released"): self.transition_waiting_released, ("waiting", "processing"): self.transition_waiting_processing, @@ -3564,28 +3564,34 @@ async def add_worker( except Exception as e: logger.exception(e) - recommendations: dict + recommendations: dict = {} + client_msgs: dict = {} + worker_msgs: dict = {} if nbytes: for key in nbytes: ts: TaskState = parent._tasks.get(key) if ts is not None and ts._state in ("processing", "waiting"): - recommendations = self.transition( + t: tuple = self._transition( key, "memory", worker=address, nbytes=nbytes[key], typename=types[key], ) - self.transitions(recommendations) + recommendations, client_msgs, worker_msgs = t + self._transitions(recommendations, client_msgs, worker_msgs) + recommendations = {} - recommendations = {} for ts in list(parent._unrunnable): valid: set = self.valid_workers(ts) if valid is None or ws in valid: recommendations[ts._key] = "waiting" if recommendations: - self.transitions(recommendations) + self._transitions(recommendations, client_msgs, worker_msgs) + recommendations = {} + + self.send_all(client_msgs, worker_msgs) self.log_event(address, {"action": "add-worker"}) self.log_event("all", {"action": "add-worker", "worker": address}) @@ -3977,15 +3983,19 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task finished %s, %s", key, worker) + recommendations: dict = {} + client_msgs: dict = {} + worker_msgs: dict = {} + ts: TaskState = parent._tasks.get(key) if ts is None: - return {} + return recommendations, client_msgs, worker_msgs ws: WorkerState = parent._workers_dv[worker] ts._metadata.update(kwargs["metadata"]) - recommendations: dict if ts._state == "processing": - recommendations = self.transition(key, "memory", worker=worker, **kwargs) + r: tuple = self._transition(key, "memory", worker=worker, **kwargs) + recommendations, client_msgs, worker_msgs = r if ts._state == "memory": assert ws in ts._who_has @@ -3999,10 +4009,9 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): ts._who_has, ) if ws not in ts._who_has: - self.worker_send(worker, {"op": "release-task", "key": key}) - recommendations = {} + worker_msgs[worker] = [{"op": "release-task", "key": key}] - return recommendations + return recommendations, client_msgs, worker_msgs def stimulus_task_erred( self, key=None, worker=None, exception=None, traceback=None, **kwargs @@ -4011,18 +4020,22 @@ def stimulus_task_erred( parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task erred %s, %s", key, worker) + recommendations: dict = {} + client_msgs: dict = {} + worker_msgs: dict = {} + ts: TaskState = parent._tasks.get(key) if ts is None: - return {} + return recommendations, client_msgs, worker_msgs - recommendations: dict if ts._state == "processing": - retries = ts._retries + retries: Py_ssize_t = ts._retries + r: tuple if retries > 0: ts._retries = retries - 1 - recommendations = self.transition(key, "waiting") + r = self._transition(key, "waiting") else: - recommendations = self.transition( + r = self._transition( key, "erred", cause=key, @@ -4031,10 +4044,9 @@ def stimulus_task_erred( worker=worker, **kwargs, ) - else: - recommendations = {} + recommendations, client_msgs, worker_msgs = r - return recommendations + return recommendations, client_msgs, worker_msgs def stimulus_missing_data( self, cause=None, key=None, worker=None, ensure=True, **kwargs @@ -4044,13 +4056,15 @@ def stimulus_missing_data( with log_errors(): logger.debug("Stimulus missing data %s, %s", key, worker) + recommendations: dict = {} + client_msgs: dict = {} + worker_msgs: dict = {} + ts: TaskState = parent._tasks.get(key) if ts is None or ts._state == "memory": - return {} + return recommendations, client_msgs, worker_msgs cts: TaskState = parent._tasks.get(cause) - recommendations: dict = {} - if cts is not None and cts._state == "memory": # couldn't find this ws: WorkerState cts_nbytes: Py_ssize_t = cts.get_nbytes() @@ -4063,12 +4077,13 @@ def stimulus_missing_data( if key: recommendations[key] = "released" - self.transitions(recommendations) + self._transitions(recommendations, client_msgs, worker_msgs) + recommendations = {} if parent._validate: assert cause not in self.who_has - return {} + return recommendations, client_msgs, worker_msgs def stimulus_retry(self, comm=None, keys=None, client=None): parent: SchedulerState = cast(SchedulerState, self) @@ -4588,12 +4603,27 @@ def handle_task_finished(self, key=None, worker=None, **msg): if worker not in parent._workers_dv: return validate_key(key) - r = self.stimulus_task_finished(key=key, worker=worker, **msg) - self.transitions(r) + + recommendations: dict + client_msgs: dict + worker_msgs: dict + + r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg) + recommendations, client_msgs, worker_msgs = r + self._transitions(recommendations, client_msgs, worker_msgs) + + self.send_all(client_msgs, worker_msgs) def handle_task_erred(self, key=None, **msg): - r = self.stimulus_task_erred(key=key, **msg) - self.transitions(r) + recommendations: dict + client_msgs: dict + worker_msgs: dict + + r: tuple = self.stimulus_task_erred(key=key, **msg) + recommendations, client_msgs, worker_msgs = r + self._transitions(recommendations, client_msgs, worker_msgs) + + self.send_all(client_msgs, worker_msgs) def handle_release_data(self, key=None, worker=None, client=None, **msg): parent: SchedulerState = cast(SchedulerState, self) @@ -4603,8 +4633,16 @@ def handle_release_data(self, key=None, worker=None, client=None, **msg): ws: WorkerState = parent._workers_dv[worker] if ts._processing_on != ws: return - r = self.stimulus_missing_data(key=key, ensure=False, **msg) - self.transitions(r) + + recommendations: dict + client_msgs: dict + worker_msgs: dict + + r: tuple = self.stimulus_missing_data(key=key, ensure=False, **msg) + recommendations, client_msgs, worker_msgs = r + self._transitions(recommendations, client_msgs, worker_msgs) + + self.send_all(client_msgs, worker_msgs) def handle_missing_data(self, key=None, errant_worker=None, **kwargs): parent: SchedulerState = cast(SchedulerState, self) @@ -4742,17 +4780,10 @@ def client_send(self, client, msg): def send_all(self, client_msgs: dict, worker_msgs: dict): """Send messages to client and workers""" - stream_comms: dict = self.stream_comms client_comms: dict = self.client_comms + stream_comms: dict = self.stream_comms msgs: list - for worker, msgs in worker_msgs.items(): - try: - w = stream_comms[worker] - w.send(*msgs) - except (CommClosedError, AttributeError): - self.loop.add_callback(self.remove_worker, address=worker) - for client, msgs in client_msgs.items(): c = client_comms.get(client) if c is None: @@ -4763,6 +4794,13 @@ def send_all(self, client_msgs: dict, worker_msgs: dict): if self.status == Status.running: logger.critical("Tried writing to closed comm: %s", msgs) + for worker, msgs in worker_msgs.items(): + try: + w = stream_comms[worker] + w.send(*msgs) + except (CommClosedError, AttributeError): + self.loop.add_callback(self.remove_worker, address=worker) + ############################ # Less common interactions # ############################ @@ -4855,6 +4893,10 @@ async def gather(self, comm=None, keys=None, serializers=None): for worker in missing_workers ] ) + + recommendations: dict + client_msgs: dict = {} + worker_msgs: dict = {} for key, workers in missing_keys.items(): # Task may already be gone if it was held by a # `missing_worker` @@ -4867,13 +4909,15 @@ async def gather(self, comm=None, keys=None, serializers=None): if not workers or ts is None: continue ts_nbytes: Py_ssize_t = ts.get_nbytes() + recommendations: dict = {key: "released"} for worker in workers: ws = parent._workers_dv.get(worker) if ws is not None and ts in ws._has_what: ws._has_what.remove(ts) ts._who_has.remove(ws) ws._nbytes -= ts_nbytes - self.transitions({key: "released"}) + self._transitions(recommendations, client_msgs, worker_msgs) + self.send_all(client_msgs, worker_msgs) self.log_event("all", {"action": "gather", "count": len(keys)}) return result @@ -5915,64 +5959,64 @@ def _transition(self, key, finish: str, *args, **kwargs): ts = parent._tasks.get(key) if ts is None: - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs start = ts._state if start == finish: - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs if self.plugins: dependents = set(ts._dependents) dependencies = set(ts._dependencies) start_finish = (start, finish) - func = self._transitions.get(start_finish) + func = self._transitions_table.get(start_finish) if func is not None: a: tuple = func(key, *args, **kwargs) - recommendations, worker_msgs, client_msgs = a + recommendations, client_msgs, worker_msgs = a elif "released" not in start_finish: - func = self._transitions["released", finish] + func = self._transitions_table["released", finish] assert not args and not kwargs a_recs: dict - a_wmsgs: dict a_cmsgs: dict + a_wmsgs: dict a: tuple = self._transition(key, "released") - a_recs, a_wmsgs, a_cmsgs = a + a_recs, a_cmsgs, a_wmsgs = a v = a_recs.get(key) if v is not None: - func = self._transitions["released", v] + func = self._transitions_table["released", v] b_recs: dict - b_wmsgs: dict b_cmsgs: dict + b_wmsgs: dict b: tuple = func(key) - b_recs, b_wmsgs, b_cmsgs = b + b_recs, b_cmsgs, b_wmsgs = b recommendations.update(a_recs) - for w, new_msgs in a_wmsgs.items(): - msgs = worker_msgs.get(w) - if msgs is not None: - msgs.extend(new_msgs) - else: - worker_msgs[w] = new_msgs for c, new_msgs in a_cmsgs.items(): msgs = client_msgs.get(c) if msgs is not None: msgs.extend(new_msgs) else: client_msgs[c] = new_msgs - - recommendations.update(b_recs) - for w, new_msgs in b_wmsgs.items(): + for w, new_msgs in a_wmsgs.items(): msgs = worker_msgs.get(w) if msgs is not None: msgs.extend(new_msgs) else: worker_msgs[w] = new_msgs + + recommendations.update(b_recs) for c, new_msgs in b_cmsgs.items(): msgs = client_msgs.get(c) if msgs is not None: msgs.extend(new_msgs) else: client_msgs[c] = new_msgs + for w, new_msgs in b_wmsgs.items(): + msgs = worker_msgs.get(w) + if msgs is not None: + msgs.extend(new_msgs) + else: + worker_msgs[w] = new_msgs start = "released" else: @@ -6015,7 +6059,7 @@ def _transition(self, key, finish: str, *args, **kwargs): ts._prefix._groups.remove(tg) del parent._task_groups[tg._name] - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception("Error transitioning %r from %r to %r", key, start, finish) if LOG_PDB: @@ -6044,11 +6088,11 @@ def transition(self, key, finish: str, *args, **kwargs): worker_msgs: dict client_msgs: dict a: tuple = self._transition(key, finish, *args, **kwargs) - recommendations, worker_msgs, client_msgs = a + recommendations, client_msgs, worker_msgs = a self.send_all(client_msgs, worker_msgs) return recommendations - def transitions(self, recommendations: dict): + def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: dict): """Process transitions until none are left This includes feedback from previous transitions and continues until we @@ -6057,41 +6101,48 @@ def transitions(self, recommendations: dict): parent: SchedulerState = cast(SchedulerState, self) keys: set = set() recommendations = recommendations.copy() - worker_msgs: dict = {} - client_msgs: dict = {} msgs: list new_msgs: list new: tuple new_recs: dict - new_wmsgs: dict new_cmsgs: dict + new_wmsgs: dict while recommendations: key, finish = recommendations.popitem() keys.add(key) new = self._transition(key, finish) - new_recs, new_wmsgs, new_cmsgs = new + new_recs, new_cmsgs, new_wmsgs = new recommendations.update(new_recs) - for w, new_msgs in new_wmsgs.items(): - msgs = worker_msgs.get(w) - if msgs is not None: - msgs.extend(new_msgs) - else: - worker_msgs[w] = new_msgs for c, new_msgs in new_cmsgs.items(): msgs = client_msgs.get(c) if msgs is not None: msgs.extend(new_msgs) else: client_msgs[c] = new_msgs - - self.send_all(client_msgs, worker_msgs) + for w, new_msgs in new_wmsgs.items(): + msgs = worker_msgs.get(w) + if msgs is not None: + msgs.extend(new_msgs) + else: + worker_msgs[w] = new_msgs if parent._validate: for key in keys: self.validate_key(key) + def transitions(self, recommendations: dict): + """Process transitions until none are left + + This includes feedback from previous transitions and continues until we + reach a steady state + """ + client_msgs: dict = {} + worker_msgs: dict = {} + self._transitions(recommendations, client_msgs, worker_msgs) + self.send_all(client_msgs, worker_msgs) + def story(self, *keys): """ Get all transitions that touch one of the input keys """ keys = {key.key if isinstance(key, TaskState) else key for key in keys}