diff --git a/distributed/batched.py b/distributed/batched.py index 313aab67b56..89e99719e9c 100644 --- a/distributed/batched.py +++ b/distributed/batched.py @@ -127,7 +127,7 @@ def _background_send(self): self.stopped.set() self.abort() - def send(self, msg): + def send(self, *msgs): """Schedule a message for sending to the other side This completes quickly and synchronously @@ -135,8 +135,8 @@ def send(self, msg): if self.comm is not None and self.comm.closed(): raise CommClosedError - self.message_count += 1 - self.buffer.append(msg) + self.message_count += len(msgs) + self.buffer.extend(msgs) # Avoid spurious wakeups if possible if self.next_deadline is None: self.waker.set() diff --git a/distributed/scheduler.py b/distributed/scheduler.py index bbd8f01aeda..c9d17a7cbaa 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -182,7 +182,10 @@ def nogil(func): EventExtension, ] -ALL_TASK_STATES = {"released", "waiting", "no-worker", "processing", "erred", "memory"} +ALL_TASK_STATES = declare( + set, {"released", "waiting", "no-worker", "processing", "erred", "memory"} +) +globals()["ALL_TASK_STATES"] = ALL_TASK_STATES @final @@ -1961,7 +1964,7 @@ def transition_waiting_processing(self, key): # logger.debug("Send job to worker: %s, %s", worker, key) - worker_msgs[worker] = _task_to_msg(self, ts) + worker_msgs[worker] = [_task_to_msg(self, ts)] return {}, worker_msgs, client_msgs except Exception as e: @@ -2168,11 +2171,13 @@ def transition_memory_released(self, key, safe: bint = False): ws._has_what.remove(ts) ws._nbytes -= ts.get_nbytes() ts._group._nbytes_in_memory -= ts.get_nbytes() - worker_msgs[ws._address] = { - "op": "delete-data", - "keys": [key], - "report": False, - } + worker_msgs[ws._address] = [ + { + "op": "delete-data", + "keys": [key], + "report": False, + } + ] ts._who_has.clear() @@ -2181,7 +2186,7 @@ def transition_memory_released(self, key, safe: bint = False): report_msg = {"op": "lost-data", "key": key} cs: ClientState for cs in ts._who_wants: - client_msgs[cs._client_key] = report_msg + client_msgs[cs._client_key] = [report_msg] if not ts._run_spec: # pure data recommendations[key] = "forgotten" @@ -2234,7 +2239,7 @@ def transition_released_erred(self, key): } cs: ClientState for cs in ts._who_wants: - client_msgs[cs._client_key] = report_msg + client_msgs[cs._client_key] = [report_msg] ts.state = "erred" @@ -2276,7 +2281,7 @@ def transition_erred_released(self, key): report_msg = {"op": "task-retried", "key": key} cs: ClientState for cs in ts._who_wants: - client_msgs[cs._client_key] = report_msg + client_msgs[cs._client_key] = [report_msg] ts.state = "released" @@ -2343,7 +2348,7 @@ def transition_processing_released(self, key): w: str = _remove_from_processing(self, ts) if w: - worker_msgs[w] = {"op": "release-task", "key": key} + worker_msgs[w] = [{"op": "release-task", "key": key}] ts.state = "released" @@ -2432,7 +2437,7 @@ def transition_processing_erred( } cs: ClientState for cs in ts._who_wants: - client_msgs[cs._client_key] = report_msg + client_msgs[cs._client_key] = [report_msg] cs = self._clients["fire-and-forget"] if ts in cs._wants_what: @@ -4706,6 +4711,29 @@ def client_send(self, client, msg): if self.status == Status.running: logger.critical("Tried writing to closed comm: %s", msg) + def send_all(self, client_msgs: dict, worker_msgs: dict): + """Send messages to client and workers""" + stream_comms: dict = self.stream_comms + client_comms: dict = self.client_comms + msgs: list + + for worker, msgs in worker_msgs.items(): + try: + w = stream_comms[worker] + w.send(*msgs) + except (CommClosedError, AttributeError): + self.loop.add_callback(self.remove_worker, address=worker) + + for client, msgs in client_msgs.items(): + c = client_comms.get(client) + if c is None: + continue + try: + c.send(*msgs) + except CommClosedError: + if self.status == Status.running: + logger.critical("Tried writing to closed comm: %s", msgs) + ############################ # Less common interactions # ############################ @@ -5814,12 +5842,12 @@ async def register_worker_plugin(self, comm, plugin, name=None): # State Transitions # ##################### - def transition(self, key, finish, *args, **kwargs): + def _transition(self, key, finish: str, *args, **kwargs): """Transition a key from its current state to the finish state Examples -------- - >>> self.transition('x', 'waiting') + >>> self._transition('x', 'waiting') {'x': 'processing'} Returns @@ -5832,47 +5860,85 @@ def transition(self, key, finish, *args, **kwargs): """ parent: SchedulerState = cast(SchedulerState, self) ts: TaskState + start: str + start_finish: tuple + finish2: str + recommendations: dict worker_msgs: dict client_msgs: dict + msgs: list + new_msgs: list + dependents: set + dependencies: set try: - try: - ts = parent._tasks[key] - except KeyError: - return {} + recommendations = {} + worker_msgs = {} + client_msgs = {} + + ts = parent._tasks.get(key) + if ts is None: + return recommendations, worker_msgs, client_msgs start = ts._state if start == finish: - return {} + return recommendations, worker_msgs, client_msgs if self.plugins: dependents = set(ts._dependents) dependencies = set(ts._dependencies) - recommendations: dict = {} - worker_msgs = {} - client_msgs = {} - if (start, finish) in self._transitions: - func = self._transitions[start, finish] - recommendations, worker_msgs, client_msgs = func(key, *args, **kwargs) - elif "released" not in (start, finish): + start_finish = (start, finish) + func = self._transitions.get(start_finish) + if func is not None: + a: tuple = func(key, *args, **kwargs) + recommendations, worker_msgs, client_msgs = a + elif "released" not in start_finish: func = self._transitions["released", finish] assert not args and not kwargs - a = self.transition(key, "released") - if key in a: - func = self._transitions["released", a[key]] - b, worker_msgs, client_msgs = func(key) - a = a.copy() - a.update(b) - recommendations = a + a_recs: dict + a_wmsgs: dict + a_cmsgs: dict + a: tuple = self._transition(key, "released") + a_recs, a_wmsgs, a_cmsgs = a + v = a_recs.get(key) + if v is not None: + func = self._transitions["released", v] + b_recs: dict + b_wmsgs: dict + b_cmsgs: dict + b: tuple = func(key) + b_recs, b_wmsgs, b_cmsgs = b + + recommendations.update(a_recs) + for w, new_msgs in a_wmsgs.items(): + msgs = worker_msgs.get(w) + if msgs is not None: + msgs.extend(new_msgs) + else: + worker_msgs[w] = new_msgs + for c, new_msgs in a_cmsgs.items(): + msgs = client_msgs.get(c) + if msgs is not None: + msgs.extend(new_msgs) + else: + client_msgs[c] = new_msgs + + recommendations.update(b_recs) + for w, new_msgs in b_wmsgs.items(): + msgs = worker_msgs.get(w) + if msgs is not None: + msgs.extend(new_msgs) + else: + worker_msgs[w] = new_msgs + for c, new_msgs in b_cmsgs.items(): + msgs = client_msgs.get(c) + if msgs is not None: + msgs.extend(new_msgs) + else: + client_msgs[c] = new_msgs + start = "released" else: - raise RuntimeError( - "Impossible transition from %r to %r" % (start, finish) - ) - - for worker, msg in worker_msgs.items(): - self.worker_send(worker, msg) - for client, msg in client_msgs.items(): - self.client_send(client, msg) + raise RuntimeError("Impossible transition from %r to %r" % start_finish) finish2 = ts._state self.transition_log.append((key, start, finish2, recommendations, time())) @@ -5888,11 +5954,8 @@ def transition(self, key, finish, *args, **kwargs): if self.plugins: # Temporarily put back forgotten key for plugin to retrieve it if ts._state == "forgotten": - try: - ts._dependents = dependents - ts._dependencies = dependencies - except KeyError: - pass + ts._dependents = dependents + ts._dependencies = dependencies parent._tasks[ts._key] = ts for plugin in list(self.plugins): try: @@ -5905,11 +5968,16 @@ def transition(self, key, finish, *args, **kwargs): tg: TaskGroup = ts._group if ts._state == "forgotten" and tg._name in parent._task_groups: # Remove TaskGroup if all tasks are in the forgotten state - if not any([tg._states.get(s) for s in ALL_TASK_STATES]): + all_forgotten: bint = True + for s in ALL_TASK_STATES: + if tg._states.get(s): + all_forgotten = False + break + if all_forgotten: ts._prefix._groups.remove(tg) del parent._task_groups[tg._name] - return recommendations + return recommendations, worker_msgs, client_msgs except Exception as e: logger.exception("Error transitioning %r from %r to %r", key, start, finish) if LOG_PDB: @@ -5918,6 +5986,30 @@ def transition(self, key, finish, *args, **kwargs): pdb.set_trace() raise + def transition(self, key, finish: str, *args, **kwargs): + """Transition a key from its current state to the finish state + + Examples + -------- + >>> self.transition('x', 'waiting') + {'x': 'processing'} + + Returns + ------- + Dictionary of recommendations for future transitions + + See Also + -------- + Scheduler.transitions: transitive version of this function + """ + recommendations: dict + worker_msgs: dict + client_msgs: dict + a: tuple = self._transition(key, finish, *args, **kwargs) + recommendations, worker_msgs, client_msgs = a + self.send_all(client_msgs, worker_msgs) + return recommendations + def transitions(self, recommendations: dict): """Process transitions until none are left @@ -5925,13 +6017,38 @@ def transitions(self, recommendations: dict): reach a steady state """ parent: SchedulerState = cast(SchedulerState, self) - keys = set() + keys: set = set() recommendations = recommendations.copy() + worker_msgs: dict = {} + client_msgs: dict = {} + msgs: list + new_msgs: list + new: tuple + new_recs: dict + new_wmsgs: dict + new_cmsgs: dict while recommendations: key, finish = recommendations.popitem() keys.add(key) - new = self.transition(key, finish) - recommendations.update(new) + + new = self._transition(key, finish) + new_recs, new_wmsgs, new_cmsgs = new + + recommendations.update(new_recs) + for w, new_msgs in new_wmsgs.items(): + msgs = worker_msgs.get(w) + if msgs is not None: + msgs.extend(new_msgs) + else: + worker_msgs[w] = new_msgs + for c, new_msgs in new_cmsgs.items(): + msgs = client_msgs.get(c) + if msgs is not None: + msgs.extend(new_msgs) + else: + client_msgs[c] = new_msgs + + self.send_all(client_msgs, worker_msgs) if parent._validate: for key in keys: @@ -6513,7 +6630,7 @@ def _add_to_memory( report_msg["type"] = type for cs in ts._who_wants: - client_msgs[cs._client_key] = report_msg + client_msgs[cs._client_key] = [report_msg] ts.state = "memory" ts._type = typename @@ -6567,7 +6684,7 @@ def _propagate_forgotten( ws._nbytes -= ts.get_nbytes() w: str = ws._address if w in state._workers_dv: # in case worker has died - worker_msgs[w] = {"op": "delete-data", "keys": [key], "report": False} + worker_msgs[w] = [{"op": "delete-data", "keys": [key], "report": False}] ts._who_has.clear() @@ -6674,7 +6791,7 @@ def _task_to_client_msgs(state: SchedulerState, ts: TaskState) -> dict: client_msgs: dict = {} for k in client_keys: - client_msgs[k] = report_msg + client_msgs[k] = [report_msg] return client_msgs