From 8c5e9e010a986ce87d5225855d55e8c4d10ff040 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 22 Feb 2021 18:18:19 -0800 Subject: [PATCH 01/13] Annotate `retries` as `Py_ssize_t` --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 911bf4fba3d..6a6934d2e37 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4017,7 +4017,7 @@ def stimulus_task_erred( recommendations: dict if ts._state == "processing": - retries = ts._retries + retries: Py_ssize_t = ts._retries if retries > 0: ts._retries = retries - 1 recommendations = self.transition(key, "waiting") From 97b50d2ce0e279c71f02f76b4e0d39085b267e86 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 22 Feb 2021 18:18:20 -0800 Subject: [PATCH 02/13] Assign and use `recommendations` throughout --- distributed/scheduler.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 6a6934d2e37..8947e54699a 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4011,11 +4011,12 @@ def stimulus_task_erred( parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task erred %s, %s", key, worker) + recommendations: dict = {} + ts: TaskState = parent._tasks.get(key) if ts is None: - return {} + return recommendations - recommendations: dict if ts._state == "processing": retries: Py_ssize_t = ts._retries if retries > 0: @@ -4031,8 +4032,6 @@ def stimulus_task_erred( worker=worker, **kwargs, ) - else: - recommendations = {} return recommendations From f483aba638ff5cb5fa49db990c4f25212b403850 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 22 Feb 2021 18:18:20 -0800 Subject: [PATCH 03/13] Reuse `recommendations` throughout --- distributed/scheduler.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8947e54699a..146f0d1556f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4043,13 +4043,13 @@ def stimulus_missing_data( with log_errors(): logger.debug("Stimulus missing data %s, %s", key, worker) + recommendations: dict = {} + ts: TaskState = parent._tasks.get(key) if ts is None or ts._state == "memory": - return {} + return recommendations cts: TaskState = parent._tasks.get(cause) - recommendations: dict = {} - if cts is not None and cts._state == "memory": # couldn't find this ws: WorkerState cts_nbytes: Py_ssize_t = cts.get_nbytes() @@ -4063,11 +4063,12 @@ def stimulus_missing_data( recommendations[key] = "released" self.transitions(recommendations) + recommendations = {} if parent._validate: assert cause not in self.who_has - return {} + return recommendations def stimulus_retry(self, comm=None, keys=None, client=None): parent: SchedulerState = cast(SchedulerState, self) From 1142bd84a3bd0b2c490b09cba781b7f2accc0ab1 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 22 Feb 2021 18:18:21 -0800 Subject: [PATCH 04/13] Validate before sending messages in `transitions` --- distributed/scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 146f0d1556f..dfef97113c0 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -6086,12 +6086,12 @@ def transitions(self, recommendations: dict): else: client_msgs[c] = new_msgs - self.send_all(client_msgs, worker_msgs) - if parent._validate: for key in keys: 
self.validate_key(key) + self.send_all(client_msgs, worker_msgs) + def story(self, *keys): """ Get all transitions that touch one of the input keys """ keys = {key.key if isinstance(key, TaskState) else key for key in keys} From fe46e7709e77354ec9ac4d380b87ea57cca464bf Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 22 Feb 2021 18:18:21 -0800 Subject: [PATCH 05/13] Assign `recommendations` an empty `dict` to start Also clear it out after running `transitions` as that should have handled all of them. --- distributed/scheduler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index dfef97113c0..a248bd7cf49 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3564,7 +3564,7 @@ async def add_worker( except Exception as e: logger.exception(e) - recommendations: dict + recommendations: dict = {} if nbytes: for key in nbytes: ts: TaskState = parent._tasks.get(key) @@ -3577,8 +3577,8 @@ async def add_worker( typename=types[key], ) self.transitions(recommendations) + recommendations = {} - recommendations = {} for ts in list(parent._unrunnable): valid: set = self.valid_workers(ts) if valid is None or ws in valid: @@ -3586,6 +3586,7 @@ async def add_worker( if recommendations: self.transitions(recommendations) + recommendations = {} self.log_event(address, {"action": "add-worker"}) self.log_event("all", {"action": "add-worker", "worker": address}) From 5ef1e8613d79816f4385db9a381ca792055c8bfa Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 22 Feb 2021 18:18:22 -0800 Subject: [PATCH 06/13] Rename `_transitions` to `_transitions_table` --- distributed/scheduler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a248bd7cf49..dc36c48c966 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3120,7 +3120,7 @@ def __init__( "stop_task_metadata": self.stop_task_metadata, } - self._transitions = { + self._transitions_table = { ("released", "waiting"): self.transition_released_waiting, ("waiting", "released"): self.transition_waiting_released, ("waiting", "processing"): self.transition_waiting_processing, @@ -5926,12 +5926,12 @@ def _transition(self, key, finish: str, *args, **kwargs): dependencies = set(ts._dependencies) start_finish = (start, finish) - func = self._transitions.get(start_finish) + func = self._transitions_table.get(start_finish) if func is not None: a: tuple = func(key, *args, **kwargs) recommendations, worker_msgs, client_msgs = a elif "released" not in start_finish: - func = self._transitions["released", finish] + func = self._transitions_table["released", finish] assert not args and not kwargs a_recs: dict a_wmsgs: dict @@ -5940,7 +5940,7 @@ def _transition(self, key, finish: str, *args, **kwargs): a_recs, a_wmsgs, a_cmsgs = a v = a_recs.get(key) if v is not None: - func = self._transitions["released", v] + func = self._transitions_table["released", v] b_recs: dict b_wmsgs: dict b_cmsgs: dict From 6cce8355c4ae8524d44c0b718b4f43fe0f5ea211 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 22 Feb 2021 18:18:22 -0800 Subject: [PATCH 07/13] Swap `client_msg` and `worker_msg` order --- distributed/scheduler.py | 142 +++++++++++++++++++-------------------- 1 file changed, 71 insertions(+), 71 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index dc36c48c966..54522565ea2 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1769,8 
+1769,8 @@ def transition_released_waiting(self, key): ts: TaskState = self._tasks[key] dts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert ts._run_spec @@ -1781,7 +1781,7 @@ def transition_released_waiting(self, key): if ts._has_lost_dependencies: recommendations[key] = "forgotten" - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs ts.state = "waiting" @@ -1790,7 +1790,7 @@ def transition_released_waiting(self, key): if dts._exception_blame: ts._exception_blame = dts._exception_blame recommendations[key] = "erred" - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs for dts in ts._dependencies: dep = dts._key @@ -1810,7 +1810,7 @@ def transition_released_waiting(self, key): self._unrunnable.add(ts) ts.state = "no-worker" - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -1824,8 +1824,8 @@ def transition_no_worker_waiting(self, key): ts: TaskState = self._tasks[key] dts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert ts in self._unrunnable @@ -1837,7 +1837,7 @@ def transition_no_worker_waiting(self, key): if ts._has_lost_dependencies: recommendations[key] = "forgotten" - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs for dts in ts._dependencies: dep = dts._key @@ -1857,7 +1857,7 @@ def transition_no_worker_waiting(self, key): self._unrunnable.add(ts) ts.state = "no-worker" - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -1934,8 +1934,8 @@ def transition_waiting_processing(self, key): ts: TaskState = self._tasks[key] dts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert not ts._waiting_on @@ -1948,7 +1948,7 @@ def transition_waiting_processing(self, key): ws: WorkerState = self.decide_worker(ts) if ws is None: - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs worker = ws._address duration_estimate = self.set_duration_estimate(ts, ws) @@ -1967,7 +1967,7 @@ def transition_waiting_processing(self, key): worker_msgs[worker] = [_task_to_msg(self, ts)] - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -1983,8 +1983,8 @@ def transition_waiting_memory( ws: WorkerState = self._workers_dv[worker] ts: TaskState = self._tasks[key] recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert not ts._processing_on @@ -2007,7 +2007,7 @@ def transition_waiting_memory( assert not ts._waiting_on assert ts._who_has - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2029,8 +2029,8 @@ def transition_processing_memory( ws: WorkerState wws: WorkerState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} try: ts: TaskState = self._tasks[key] assert worker @@ -2048,7 +2048,7 @@ def 
transition_processing_memory( ws = self._workers_dv.get(worker) if ws is None: recommendations[key] = "released" - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs if ws != ts._processing_on: # someone else has this task logger.info( @@ -2058,7 +2058,7 @@ def transition_processing_memory( ws, key, ) - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs has_compute_startstop: bool = False compute_start: double @@ -2124,7 +2124,7 @@ def transition_processing_memory( assert not ts._processing_on assert not ts._waiting_on - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2139,8 +2139,8 @@ def transition_memory_released(self, key, safe: bint = False): ts: TaskState = self._tasks[key] dts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert not ts._waiting_on @@ -2157,8 +2157,8 @@ def transition_memory_released(self, key, safe: bint = False): recommendations[ts._key] = "erred" return ( recommendations, - worker_msgs, client_msgs, + worker_msgs, ) # don't try to recreate for dts in ts._waiters: @@ -2199,7 +2199,7 @@ def transition_memory_released(self, key, safe: bint = False): if self._validate: assert not ts._waiting_on - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2214,8 +2214,8 @@ def transition_released_erred(self, key): dts: TaskState failing_ts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: with log_errors(pdb=LOG_PDB): @@ -2244,7 +2244,7 @@ def transition_released_erred(self, key): ts.state = "erred" # TODO: waiting data? 
- return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2258,8 +2258,8 @@ def transition_erred_released(self, key): ts: TaskState = self._tasks[key] dts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: with log_errors(pdb=LOG_PDB): @@ -2284,7 +2284,7 @@ def transition_erred_released(self, key): ts.state = "released" - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2297,8 +2297,8 @@ def transition_waiting_released(self, key): try: ts: TaskState = self._tasks[key] recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert not ts._who_has @@ -2322,7 +2322,7 @@ def transition_waiting_released(self, key): else: ts._waiters.clear() - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2336,8 +2336,8 @@ def transition_processing_released(self, key): ts: TaskState = self._tasks[key] dts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert ts._processing_on @@ -2368,7 +2368,7 @@ def transition_processing_released(self, key): if self._validate: assert not ts._processing_on - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2386,8 +2386,8 @@ def transition_processing_erred( dts: TaskState failing_ts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert cause or ts._exception_blame @@ -2447,7 +2447,7 @@ def transition_processing_erred( if self._validate: assert not ts._processing_on - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2461,8 +2461,8 @@ def transition_no_worker_released(self, key): ts: TaskState = self._tasks[key] dts: TaskState recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert self._tasks[key].state == "no-worker" @@ -2477,7 +2477,7 @@ def transition_no_worker_released(self, key): ts._waiters.clear() - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2504,8 +2504,8 @@ def transition_memory_forgotten(self, key): try: ts: TaskState = self._tasks[key] recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert ts._state == "memory" @@ -2532,7 +2532,7 @@ def transition_memory_forgotten(self, key): client_msgs = _task_to_client_msgs(self, ts) self.remove_key(key) - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2545,8 +2545,8 @@ def transition_released_forgotten(self, key): try: ts: TaskState = self._tasks[key] recommendations: dict = {} - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} if self._validate: assert ts._state in ("released", "erred") @@ -2570,7 +2570,7 @@ 
def transition_released_forgotten(self, key): client_msgs = _task_to_client_msgs(self, ts) self.remove_key(key) - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -3986,7 +3986,7 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): recommendations: dict if ts._state == "processing": - recommendations = self.transition(key, "memory", worker=worker, **kwargs) + recommendations = self._transition(key, "memory", worker=worker, **kwargs) if ts._state == "memory": assert ws in ts._who_has @@ -4743,17 +4743,10 @@ def client_send(self, client, msg): def send_all(self, client_msgs: dict, worker_msgs: dict): """Send messages to client and workers""" - stream_comms: dict = self.stream_comms client_comms: dict = self.client_comms + stream_comms: dict = self.stream_comms msgs: list - for worker, msgs in worker_msgs.items(): - try: - w = stream_comms[worker] - w.send(*msgs) - except (CommClosedError, AttributeError): - self.loop.add_callback(self.remove_worker, address=worker) - for client, msgs in client_msgs.items(): c = client_comms.get(client) if c is None: @@ -4764,6 +4757,13 @@ def send_all(self, client_msgs: dict, worker_msgs: dict): if self.status == Status.running: logger.critical("Tried writing to closed comm: %s", msgs) + for worker, msgs in worker_msgs.items(): + try: + w = stream_comms[worker] + w.send(*msgs) + except (CommClosedError, AttributeError): + self.loop.add_callback(self.remove_worker, address=worker) + ############################ # Less common interactions # ############################ @@ -5916,10 +5916,10 @@ def _transition(self, key, finish: str, *args, **kwargs): ts = parent._tasks.get(key) if ts is None: - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs start = ts._state if start == finish: - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs if self.plugins: dependents = set(ts._dependents) @@ -5929,51 +5929,51 @@ def _transition(self, key, finish: str, *args, **kwargs): func = self._transitions_table.get(start_finish) if func is not None: a: tuple = func(key, *args, **kwargs) - recommendations, worker_msgs, client_msgs = a + recommendations, client_msgs, worker_msgs = a elif "released" not in start_finish: func = self._transitions_table["released", finish] assert not args and not kwargs a_recs: dict - a_wmsgs: dict a_cmsgs: dict + a_wmsgs: dict a: tuple = self._transition(key, "released") - a_recs, a_wmsgs, a_cmsgs = a + a_recs, a_cmsgs, a_wmsgs = a v = a_recs.get(key) if v is not None: func = self._transitions_table["released", v] b_recs: dict - b_wmsgs: dict b_cmsgs: dict + b_wmsgs: dict b: tuple = func(key) - b_recs, b_wmsgs, b_cmsgs = b + b_recs, b_cmsgs, b_wmsgs = b recommendations.update(a_recs) - for w, new_msgs in a_wmsgs.items(): - msgs = worker_msgs.get(w) - if msgs is not None: - msgs.extend(new_msgs) - else: - worker_msgs[w] = new_msgs for c, new_msgs in a_cmsgs.items(): msgs = client_msgs.get(c) if msgs is not None: msgs.extend(new_msgs) else: client_msgs[c] = new_msgs - - recommendations.update(b_recs) - for w, new_msgs in b_wmsgs.items(): + for w, new_msgs in a_wmsgs.items(): msgs = worker_msgs.get(w) if msgs is not None: msgs.extend(new_msgs) else: worker_msgs[w] = new_msgs + + recommendations.update(b_recs) for c, new_msgs in b_cmsgs.items(): msgs = client_msgs.get(c) if msgs is not None: msgs.extend(new_msgs) else: client_msgs[c] 
= new_msgs + for w, new_msgs in b_wmsgs.items(): + msgs = worker_msgs.get(w) + if msgs is not None: + msgs.extend(new_msgs) + else: + worker_msgs[w] = new_msgs start = "released" else: @@ -6016,7 +6016,7 @@ def _transition(self, key, finish: str, *args, **kwargs): ts._prefix._groups.remove(tg) del parent._task_groups[tg._name] - return recommendations, worker_msgs, client_msgs + return recommendations, client_msgs, worker_msgs except Exception as e: logger.exception("Error transitioning %r from %r to %r", key, start, finish) if LOG_PDB: @@ -6045,7 +6045,7 @@ def transition(self, key, finish: str, *args, **kwargs): worker_msgs: dict client_msgs: dict a: tuple = self._transition(key, finish, *args, **kwargs) - recommendations, worker_msgs, client_msgs = a + recommendations, client_msgs, worker_msgs = a self.send_all(client_msgs, worker_msgs) return recommendations @@ -6058,34 +6058,34 @@ def transitions(self, recommendations: dict): parent: SchedulerState = cast(SchedulerState, self) keys: set = set() recommendations = recommendations.copy() - worker_msgs: dict = {} client_msgs: dict = {} + worker_msgs: dict = {} msgs: list new_msgs: list new: tuple new_recs: dict - new_wmsgs: dict new_cmsgs: dict + new_wmsgs: dict while recommendations: key, finish = recommendations.popitem() keys.add(key) new = self._transition(key, finish) - new_recs, new_wmsgs, new_cmsgs = new + new_recs, new_cmsgs, new_wmsgs = new recommendations.update(new_recs) - for w, new_msgs in new_wmsgs.items(): - msgs = worker_msgs.get(w) - if msgs is not None: - msgs.extend(new_msgs) - else: - worker_msgs[w] = new_msgs for c, new_msgs in new_cmsgs.items(): msgs = client_msgs.get(c) if msgs is not None: msgs.extend(new_msgs) else: client_msgs[c] = new_msgs + for w, new_msgs in new_wmsgs.items(): + msgs = worker_msgs.get(w) + if msgs is not None: + msgs.extend(new_msgs) + else: + worker_msgs[w] = new_msgs if parent._validate: for key in keys: From 1de03931ba0bef2149cb065b5072227a725fd317 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 22 Feb 2021 18:18:23 -0800 Subject: [PATCH 08/13] Factor out `send_all` from `_transitions` --- distributed/scheduler.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 54522565ea2..b95834f1e81 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -6049,7 +6049,7 @@ def transition(self, key, finish: str, *args, **kwargs): self.send_all(client_msgs, worker_msgs) return recommendations - def transitions(self, recommendations: dict): + def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: dict): """Process transitions until none are left This includes feedback from previous transitions and continues until we @@ -6058,8 +6058,6 @@ def transitions(self, recommendations: dict): parent: SchedulerState = cast(SchedulerState, self) keys: set = set() recommendations = recommendations.copy() - client_msgs: dict = {} - worker_msgs: dict = {} msgs: list new_msgs: list new: tuple @@ -6091,6 +6089,15 @@ def transitions(self, recommendations: dict): for key in keys: self.validate_key(key) + def transitions(self, recommendations: dict): + """Process transitions until none are left + + This includes feedback from previous transitions and continues until we + reach a steady state + """ + client_msgs: dict = {} + worker_msgs: dict = {} + self._transitions(recommendations, client_msgs, worker_msgs) self.send_all(client_msgs, worker_msgs) def story(self, *keys): From 
4103d41200ca708910add3a02182e678e5935634 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 22 Feb 2021 18:18:24 -0800 Subject: [PATCH 09/13] Batch communication from `handle_task_finished` Track messages to send from transition calls in `handle_task_finished` and `stimulus_task_finished` to allow sending all of these in one go. Should allow the transition work to happen with fewer interruptions and send more data at once. --- distributed/scheduler.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b95834f1e81..30ee0549e65 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3978,15 +3978,19 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task finished %s, %s", key, worker) + recommendations: dict = {} + client_msgs: dict = {} + worker_msgs: dict = {} + ts: TaskState = parent._tasks.get(key) if ts is None: - return {} + return recommendations, client_msgs, worker_msgs ws: WorkerState = parent._workers_dv[worker] ts._metadata.update(kwargs["metadata"]) - recommendations: dict if ts._state == "processing": - recommendations = self._transition(key, "memory", worker=worker, **kwargs) + r: tuple = self._transition(key, "memory", worker=worker, **kwargs) + recommendations, client_msgs, worker_msgs = r if ts._state == "memory": assert ws in ts._who_has @@ -4000,10 +4004,9 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): ts._who_has, ) if ws not in ts._who_has: - self.worker_send(worker, {"op": "release-task", "key": key}) - recommendations = {} + worker_msgs[worker] = [{"op": "release-task", "key": key}] - return recommendations + return recommendations, client_msgs, worker_msgs def stimulus_task_erred( self, key=None, worker=None, exception=None, traceback=None, **kwargs @@ -4589,8 +4592,16 @@ def handle_task_finished(self, key=None, worker=None, **msg): if worker not in parent._workers_dv: return validate_key(key) - r = self.stimulus_task_finished(key=key, worker=worker, **msg) - self.transitions(r) + + recommendations: dict + client_msgs: dict + worker_msgs: dict + + r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg) + recommendations, client_msgs, worker_msgs = r + self._transitions(recommendations, client_msgs, worker_msgs) + + self.send_all(client_msgs, worker_msgs) def handle_task_erred(self, key=None, **msg): r = self.stimulus_task_erred(key=key, **msg) From 71eb8f6b678ba1475ff519dd15567740b3c181aa Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 22 Feb 2021 18:18:24 -0800 Subject: [PATCH 10/13] Batch all messages from `add_worker` --- distributed/scheduler.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 30ee0549e65..99f52367021 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3565,18 +3565,21 @@ async def add_worker( logger.exception(e) recommendations: dict = {} + client_msgs: dict = {} + worker_msgs: dict = {} if nbytes: for key in nbytes: ts: TaskState = parent._tasks.get(key) if ts is not None and ts._state in ("processing", "waiting"): - recommendations = self.transition( + t: tuple = self._transition( key, "memory", worker=address, nbytes=nbytes[key], typename=types[key], ) - self.transitions(recommendations) + recommendations, client_msgs, worker_msgs = t + self._transitions(recommendations, 
client_msgs, worker_msgs) recommendations = {} for ts in list(parent._unrunnable): @@ -3585,9 +3588,11 @@ async def add_worker( recommendations[ts._key] = "waiting" if recommendations: - self.transitions(recommendations) + self._transitions(recommendations, client_msgs, worker_msgs) recommendations = {} + self.send_all(client_msgs, worker_msgs) + self.log_event(address, {"action": "add-worker"}) self.log_event("all", {"action": "add-worker", "worker": address}) logger.info("Register worker %s", ws) From cd07cc58e0c74091f256feec3f0e41b6322172a2 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 22 Feb 2021 18:18:25 -0800 Subject: [PATCH 11/13] Batch all messages from `handle_task_erred` --- distributed/scheduler.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 99f52367021..a3d6f9c4921 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4021,18 +4021,21 @@ def stimulus_task_erred( logger.debug("Stimulus task erred %s, %s", key, worker) recommendations: dict = {} + client_msgs: dict = {} + worker_msgs: dict = {} ts: TaskState = parent._tasks.get(key) if ts is None: - return recommendations + return recommendations, client_msgs, worker_msgs if ts._state == "processing": retries: Py_ssize_t = ts._retries + r: tuple if retries > 0: ts._retries = retries - 1 - recommendations = self.transition(key, "waiting") + r = self._transition(key, "waiting") else: - recommendations = self.transition( + r = self._transition( key, "erred", cause=key, @@ -4041,8 +4044,9 @@ def stimulus_task_erred( worker=worker, **kwargs, ) + recommendations, client_msgs, worker_msgs = r - return recommendations + return recommendations, client_msgs, worker_msgs def stimulus_missing_data( self, cause=None, key=None, worker=None, ensure=True, **kwargs @@ -4609,8 +4613,15 @@ def handle_task_finished(self, key=None, worker=None, **msg): self.send_all(client_msgs, worker_msgs) def handle_task_erred(self, key=None, **msg): - r = self.stimulus_task_erred(key=key, **msg) - self.transitions(r) + recommendations: dict + client_msgs: dict + worker_msgs: dict + + r: tuple = self.stimulus_task_erred(key=key, **msg) + recommendations, client_msgs, worker_msgs = r + self._transitions(recommendations, client_msgs, worker_msgs) + + self.send_all(client_msgs, worker_msgs) def handle_release_data(self, key=None, worker=None, client=None, **msg): parent: SchedulerState = cast(SchedulerState, self) From 55f5440d30d210b933c6ba9126ca767142127ffa Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 22 Feb 2021 18:18:25 -0800 Subject: [PATCH 12/13] Batch all messages from `handle_release_data` --- distributed/scheduler.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a3d6f9c4921..452c7f054a3 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4057,10 +4057,12 @@ def stimulus_missing_data( logger.debug("Stimulus missing data %s, %s", key, worker) recommendations: dict = {} + client_msgs: dict = {} + worker_msgs: dict = {} ts: TaskState = parent._tasks.get(key) if ts is None or ts._state == "memory": - return recommendations + return recommendations, client_msgs, worker_msgs cts: TaskState = parent._tasks.get(cause) if cts is not None and cts._state == "memory": # couldn't find this @@ -4075,13 +4077,13 @@ def stimulus_missing_data( if key: recommendations[key] = "released" - 
self.transitions(recommendations) + self._transitions(recommendations, client_msgs, worker_msgs) recommendations = {} if parent._validate: assert cause not in self.who_has - return recommendations + return recommendations, client_msgs, worker_msgs def stimulus_retry(self, comm=None, keys=None, client=None): parent: SchedulerState = cast(SchedulerState, self) @@ -4631,8 +4633,16 @@ def handle_release_data(self, key=None, worker=None, client=None, **msg): ws: WorkerState = parent._workers_dv[worker] if ts._processing_on != ws: return - r = self.stimulus_missing_data(key=key, ensure=False, **msg) - self.transitions(r) + + recommendations: dict + client_msgs: dict + worker_msgs: dict + + r: tuple = self.stimulus_missing_data(key=key, ensure=False, **msg) + recommendations, client_msgs, worker_msgs = r + self._transitions(recommendations, client_msgs, worker_msgs) + + self.send_all(client_msgs, worker_msgs) def handle_missing_data(self, key=None, errant_worker=None, **kwargs): parent: SchedulerState = cast(SchedulerState, self) From 93ad1ed119e81eb1436554af3b8fbe318982210e Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 22 Feb 2021 18:18:26 -0800 Subject: [PATCH 13/13] Batch all messages from `gather` --- distributed/scheduler.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 452c7f054a3..0ddd630b71f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4893,6 +4893,10 @@ async def gather(self, comm=None, keys=None, serializers=None): for worker in missing_workers ] ) + + recommendations: dict + client_msgs: dict = {} + worker_msgs: dict = {} for key, workers in missing_keys.items(): # Task may already be gone if it was held by a # `missing_worker` @@ -4905,13 +4909,15 @@ async def gather(self, comm=None, keys=None, serializers=None): if not workers or ts is None: continue ts_nbytes: Py_ssize_t = ts.get_nbytes() + recommendations: dict = {key: "released"} for worker in workers: ws = parent._workers_dv.get(worker) if ws is not None and ts in ws._has_what: ws._has_what.remove(ts) ts._who_has.remove(ws) ws._nbytes -= ts_nbytes - self.transitions({key: "released"}) + self._transitions(recommendations, client_msgs, worker_msgs) + self.send_all(client_msgs, worker_msgs) self.log_event("all", {"action": "gather", "count": len(keys)}) return result
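
The pattern this patch series converges on can be sketched in isolation: every transition helper returns a (recommendations, client_msgs, worker_msgs) triple, a driver loop drains the recommendations to a fixed point while merging the per-recipient message lists into caller-owned dicts, and the calling handler performs a single batched send at the very end. The following is a minimal, self-contained sketch of that flow, not distributed/scheduler.py itself; the names toy_transition, merge_msgs, drain_transitions, and handle_task_finished_example are hypothetical stand-ins for _transition, _transitions, and the handlers touched above.

    # Simplified sketch of the batching pattern in this patch series.
    # NOT the real scheduler code: toy_transition / drain_transitions are
    # invented stand-ins for _transition / _transitions.

    def toy_transition(key, finish):
        """Return (recommendations, client_msgs, worker_msgs) for one transition."""
        recommendations: dict = {}
        client_msgs: dict = {}
        worker_msgs: dict = {}
        if finish == "memory":
            # Pretend finishing a task releases one dependency and notifies both sides.
            recommendations["dep-of-" + key] = "released"
            client_msgs["client-1"] = [{"op": "key-in-memory", "key": key}]
            worker_msgs["worker-1"] = [{"op": "compute-task", "key": "next-after-" + key}]
        return recommendations, client_msgs, worker_msgs


    def merge_msgs(into: dict, new: dict):
        """Extend per-recipient message lists instead of overwriting them."""
        for recipient, new_msgs in new.items():
            msgs = into.get(recipient)
            if msgs is not None:
                msgs.extend(new_msgs)
            else:
                into[recipient] = new_msgs


    def drain_transitions(recommendations: dict, client_msgs: dict, worker_msgs: dict):
        """Process recommendations to a fixed point, only accumulating messages."""
        recommendations = recommendations.copy()
        while recommendations:
            key, finish = recommendations.popitem()
            new_recs, new_cmsgs, new_wmsgs = toy_transition(key, finish)
            recommendations.update(new_recs)
            merge_msgs(client_msgs, new_cmsgs)
            merge_msgs(worker_msgs, new_wmsgs)


    def handle_task_finished_example(key):
        """Accumulate all messages first, then send them in one batched call."""
        client_msgs: dict = {}
        worker_msgs: dict = {}
        drain_transitions({key: "memory"}, client_msgs, worker_msgs)
        # Stand-in for `self.send_all(client_msgs, worker_msgs)`: one send at the end.
        print("to clients:", client_msgs)
        print("to workers:", worker_msgs)


    if __name__ == "__main__":
        handle_task_finished_example("x")

As the message of patch 09 states, the point of this shape is that the transition work runs without interleaved comm writes, and each handler flushes everything to clients and workers in one go.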