From 6dbb36f3a4ac50eb11f834466834d33c098b281d Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:47:30 -0800 Subject: [PATCH 01/38] Move `worker_send` into transition functions --- distributed/scheduler.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 77a9be2f6dc..ac9366184b0 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4598,7 +4598,7 @@ async def register_worker_plugin(self, comm, plugin, name=None): # State Transitions # ##################### - def _remove_from_processing(self, ts: TaskState, send_worker_msg=None): + def _remove_from_processing(self, ts: TaskState) -> str: """ Remove *ts* from the set of processing tasks. """ @@ -4616,8 +4616,9 @@ def _remove_from_processing(self, ts: TaskState, send_worker_msg=None): ws._occupancy -= duration self.check_idle_saturated(ws) self.release_resources(ts, ws) - if send_worker_msg: - self.worker_send(w, send_worker_msg) + return w + else: + return None def _add_to_memory( self, @@ -5222,9 +5223,9 @@ def transition_processing_released(self, key): assert not ts._waiting_on assert self.tasks[key].state == "processing" - self._remove_from_processing( - ts, send_worker_msg={"op": "release-task", "key": key} - ) + w: str = self._remove_from_processing(ts) + if w: + self.worker_send(w, {"op": "release-task", "key": key}) ts.state = "released" From 1bf1faa40410b0cfdb9fc685a8cc3854ad75d315 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:47:31 -0800 Subject: [PATCH 02/38] Refactor `_task_to_msg` from `send_task_to_worker` Provides a way for callers to simply construct the message if they are not wanting to send it yet. --- distributed/scheduler.py | 63 ++++++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ac9366184b0..65580abc7b4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3292,42 +3292,47 @@ def remove_client_from_events(): ) self.loop.call_later(cleanup_delay, remove_client_from_events) - def send_task_to_worker(self, worker, ts: TaskState, duration=None): - """ Send a single computational task to a worker """ - try: - ws: WorkerState - dts: TaskState + def _task_to_msg(self, ts: TaskState, duration=None) -> dict: + """ Convert a single computational task to a message """ + ws: WorkerState + dts: TaskState - if duration is None: - duration = self.get_task_duration(ts) + if duration is None: + duration = self.get_task_duration(ts) - msg: dict = { - "op": "compute-task", - "key": ts._key, - "priority": ts._priority, - "duration": duration, + msg: dict = { + "op": "compute-task", + "key": ts._key, + "priority": ts._priority, + "duration": duration, + } + if ts._resource_restrictions: + msg["resource_restrictions"] = ts._resource_restrictions + if ts._actor: + msg["actor"] = True + + deps: set = ts._dependencies + if deps: + msg["who_has"] = { + dts._key: [ws._address for ws in dts._who_has] for dts in deps } - if ts._resource_restrictions: - msg["resource_restrictions"] = ts._resource_restrictions - if ts._actor: - msg["actor"] = True + msg["nbytes"] = {dts._key: dts._nbytes for dts in deps} - deps: set = ts._dependencies - if deps: - msg["who_has"] = { - dts._key: [ws._address for ws in dts._who_has] for dts in deps - } - msg["nbytes"] = {dts._key: dts._nbytes for dts in deps} + if self.validate: + assert all(msg["who_has"].values()) - if self.validate: - assert 
all(msg["who_has"].values()) + task = ts._run_spec + if type(task) is dict: + msg.update(task) + else: + msg["task"] = task - task = ts._run_spec - if type(task) is dict: - msg.update(task) - else: - msg["task"] = task + return msg + def send_task_to_worker(self, worker, ts: TaskState, duration=None): + """ Send a single computational task to a worker """ + try: + msg: dict = self._task_to_msg(ts, duration) self.worker_send(worker, msg) except Exception as e: logger.exception(e) From 9f44f0609fc260d011e39ab71adf856d9d7dfe6f Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:47:32 -0800 Subject: [PATCH 03/38] Move `report` out of `_add_to_memory` --- distributed/scheduler.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 65580abc7b4..27ad4bd547f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4630,6 +4630,7 @@ def _add_to_memory( ts: TaskState, ws: WorkerState, recommendations: dict, + report_msg: dict, type=None, typename=None, **kwargs, @@ -4666,10 +4667,10 @@ def _add_to_memory( if not ts._waiters and not ts._who_wants: recommendations[ts._key] = "released" else: - msg: dict = {"op": "key-in-memory", "key": ts._key} + report_msg["op"] = "key-in-memory" + report_msg["key"] = ts._key if type is not None: - msg["type"] = type - self.report(msg) + report_msg["type"] = type ts.state = "memory" ts._type = typename @@ -4907,8 +4908,11 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): self.check_idle_saturated(ws) recommendations: dict = {} + report_msg: dict = {} - self._add_to_memory(ts, ws, recommendations, **kwargs) + self._add_to_memory(ts, ws, recommendations, report_msg, **kwargs) + if report_msg: + self.report(report_msg) if self.validate: assert not ts._processing_on @@ -5018,10 +5022,15 @@ def transition_processing_memory( ts.set_nbytes(nbytes) recommendations: dict = {} + report_msg: dict = {} self._remove_from_processing(ts) - self._add_to_memory(ts, ws, recommendations, type=type, typename=typename) + self._add_to_memory( + ts, ws, recommendations, report_msg, type=type, typename=typename + ) + if report_msg: + self.report(report_msg) if self.validate: assert not ts._processing_on From 68cf243a5a4d5bf16da2386a7a175b9c642e2aec Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:47:33 -0800 Subject: [PATCH 04/38] Refactor out `_client_releases_keys` This provides us a way to effectively call `client_releases_keys` from other transitions without starting a new transition of its own. 
--- distributed/scheduler.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 27ad4bd547f..401041b0829 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3000,22 +3000,20 @@ def client_desires_keys(self, keys=None, client=None): if ts._state in ("memory", "erred"): self.report_on_key(ts=ts, client=client) - def client_releases_keys(self, keys=None, client=None): + def _client_releases_keys(self, keys: list, cs: ClientState, recommendations: dict): """ Remove keys from client desired list """ - logger.debug("Client %s releases keys: %s", client, keys) - cs: ClientState = self.clients[client] + logger.debug("Client %s releases keys: %s", cs._client_key, keys) ts: TaskState - tasks2 = set() - for key in list(keys): + tasks2: set = set() + for key in keys: ts = self.tasks.get(key) if ts is not None and ts in cs._wants_what: cs._wants_what.remove(ts) - s = ts._who_wants + s: set = ts._who_wants s.remove(cs) if not s: tasks2.add(ts) - recommendations: dict = {} for ts in tasks2: if not ts._dependents: # No live dependents, can forget @@ -3023,6 +3021,15 @@ def client_releases_keys(self, keys=None, client=None): elif ts._state != "erred" and not ts._waiters: recommendations[ts._key] = "released" + def client_releases_keys(self, keys=None, client=None): + """ Remove keys from client desired list """ + + if not isinstance(keys, list): + keys = list(keys) + cs: ClientState = self.clients[client] + recommendations: dict = {} + + self._client_releases_keys(keys=keys, cs=cs, recommendations=recommendations) self.transitions(recommendations) def client_heartbeat(self, client=None): From 3c9d3ac468ea2691756c77a0f5c0dd9827e2afe7 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:47:33 -0800 Subject: [PATCH 05/38] Collect client recs in `_add_to_memory` --- distributed/scheduler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 401041b0829..113b78294d9 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4685,7 +4685,11 @@ def _add_to_memory( cs: ClientState = self.clients["fire-and-forget"] if ts in cs._wants_what: - self.client_releases_keys(client="fire-and-forget", keys=[ts._key]) + self._client_releases_keys( + client="fire-and-forget", + keys=[ts._key], + recommendations=recommendations, + ) def transition_released_waiting(self, key): try: From 1332985cc4954f13870a188e97e82c64a76d0b88 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:47:34 -0800 Subject: [PATCH 06/38] Use `_client_releases_keys` in transitions --- distributed/scheduler.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 113b78294d9..3bfc3f35442 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4686,7 +4686,7 @@ def _add_to_memory( cs: ClientState = self.clients["fire-and-forget"] if ts in cs._wants_what: self._client_releases_keys( - client="fire-and-forget", + cs=cs, keys=[ts._key], recommendations=recommendations, ) @@ -5341,7 +5341,11 @@ def transition_processing_erred( cs: ClientState = self.clients["fire-and-forget"] if ts in cs._wants_what: - self.client_releases_keys(client="fire-and-forget", keys=[key]) + self._client_releases_keys( + cs=cs, + keys=[key], + recommendations=recommendations, + ) if self.validate: assert not ts._processing_on From 
fe31d61875e42440fa9013ab560f188415153186 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:47:35 -0800 Subject: [PATCH 07/38] Refactor out `_task_to_report_msg` Separates out the code needed to build a message for `report` based on the `TaskState` in question from the actual call to `self.report`. --- distributed/scheduler.py | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 3bfc3f35442..bc736b320d4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4316,6 +4316,24 @@ def update_data( if client: self.client_desires_keys(keys=list(who_has), client=client) + def _task_to_report_msg(self, ts: TaskState) -> dict: + if ts is None: + return {"op": "cancelled-key", "key": ts._key} + elif ts._state == "forgotten": + return {"op": "cancelled-key", "key": ts._key} + elif ts._state == "memory": + return {"op": "key-in-memory", "key": ts._key} + elif ts._state == "erred": + failing_ts: TaskState = ts._exception_blame + return { + "op": "task-erred", + "key": ts._key, + "exception": failing_ts._exception, + "traceback": failing_ts._traceback, + } + else: + return None + def report_on_key(self, key: str = None, ts: TaskState = None, client: str = None): if ts is None: tasks: dict = self.tasks @@ -4326,24 +4344,9 @@ def report_on_key(self, key: str = None, ts: TaskState = None, client: str = Non assert False, (key, ts) return - if ts is None: - self.report({"op": "cancelled-key", "key": key}, client=client) - elif ts._state == "forgotten": - self.report({"op": "cancelled-key", "key": key}, ts=ts, client=client) - elif ts._state == "memory": - self.report({"op": "key-in-memory", "key": key}, ts=ts, client=client) - elif ts._state == "erred": - failing_ts: TaskState = ts._exception_blame - self.report( - { - "op": "task-erred", - "key": key, - "exception": failing_ts._exception, - "traceback": failing_ts._traceback, - }, - ts=ts, - client=client, - ) + report_msg: dict = self._task_to_report_msg(ts) + if report_msg is not None: + self.report(report_msg, ts=ts, client=client) async def feed( self, comm, function=None, setup=None, teardown=None, interval="1s", **kwargs From 6b533a1257a03133256c7728014d0e070e5ba3b0 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:47:36 -0800 Subject: [PATCH 08/38] Collect and send worker messages from transitions --- distributed/scheduler.py | 85 +++++++++++++++++++++++++--------------- 1 file changed, 54 insertions(+), 31 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index bc736b320d4..15b5e2c731d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4700,6 +4700,7 @@ def transition_released_waiting(self, key): workers: dict = cast(dict, self.workers) ts: TaskState = tasks[key] dts: TaskState + worker_msgs: dict = {} if self.validate: assert ts._run_spec @@ -4709,7 +4710,7 @@ def transition_released_waiting(self, key): assert not any([dts._state == "forgotten" for dts in ts._dependencies]) if ts._has_lost_dependencies: - return {key: "forgotten"} + return {key: "forgotten"}, worker_msgs ts.state = "waiting" @@ -4720,7 +4721,7 @@ def transition_released_waiting(self, key): if dts._exception_blame: ts._exception_blame = dts._exception_blame recommendations[key] = "erred" - return recommendations + return recommendations, worker_msgs for dts in ts._dependencies: dep = dts._key @@ -4740,7 +4741,7 @@ def transition_released_waiting(self, key): 
self.unrunnable.add(ts) ts.state = "no-worker" - return recommendations + return recommendations, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -4755,6 +4756,7 @@ def transition_no_worker_waiting(self, key): workers: dict = cast(dict, self.workers) ts: TaskState = tasks[key] dts: TaskState + worker_msgs: dict = {} if self.validate: assert ts in self.unrunnable @@ -4765,7 +4767,7 @@ def transition_no_worker_waiting(self, key): self.unrunnable.remove(ts) if ts._has_lost_dependencies: - return {key: "forgotten"} + return {key: "forgotten"}, worker_msgs recommendations: dict = {} @@ -4787,7 +4789,7 @@ def transition_no_worker_waiting(self, key): self.unrunnable.add(ts) ts.state = "no-worker" - return recommendations + return recommendations, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -4862,6 +4864,7 @@ def transition_waiting_processing(self, key): tasks: dict = self.tasks ts: TaskState = tasks[key] dts: TaskState + worker_msgs: dict = {} if self.validate: assert not ts._waiting_on @@ -4874,7 +4877,7 @@ def transition_waiting_processing(self, key): ws: WorkerState = self.decide_worker(ts) if ws is None: - return {} + return {}, worker_msgs worker = ws._address duration_estimate = self.set_duration_estimate(ts, ws) @@ -4893,7 +4896,7 @@ def transition_waiting_processing(self, key): self.send_task_to_worker(worker, ts) - return {} + return {}, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -4908,6 +4911,7 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): ws: WorkerState = workers[worker] tasks: dict = self.tasks ts: TaskState = tasks[key] + worker_msgs: dict = {} if self.validate: assert not ts._processing_on @@ -4933,7 +4937,7 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): assert not ts._waiting_on assert ts._who_has - return recommendations + return recommendations, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -4954,6 +4958,7 @@ def transition_processing_memory( ): ws: WorkerState wws: WorkerState + worker_msgs: dict = {} try: tasks: dict = self.tasks ts: TaskState = tasks[key] @@ -4972,7 +4977,7 @@ def transition_processing_memory( workers: dict = cast(dict, self.workers) ws = workers.get(worker) if ws is None: - return {key: "released"} + return {key: "released"}, worker_msgs if ws != ts._processing_on: # someone else has this task logger.info( @@ -4982,7 +4987,7 @@ def transition_processing_memory( ws, key, ) - return {} + return {}, worker_msgs if startstops: L = list() @@ -5050,7 +5055,7 @@ def transition_processing_memory( assert not ts._processing_on assert not ts._waiting_on - return recommendations + return recommendations, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5065,6 +5070,7 @@ def transition_memory_released(self, key, safe=False): tasks: dict = self.tasks ts: TaskState = tasks[key] dts: TaskState + worker_msgs: dict = {} if self.validate: assert not ts._waiting_on @@ -5078,7 +5084,7 @@ def transition_memory_released(self, key, safe=False): if ts._who_wants: ts._exception_blame = ts ts._exception = "Worker holding Actor was lost" - return {ts._key: "erred"} # don't try to recreate + return {ts._key: "erred"}, worker_msgs # don't try to recreate recommendations: dict = {} @@ -5093,9 +5099,12 @@ def transition_memory_released(self, key, safe=False): ws._has_what.remove(ts) ws._nbytes -= ts.get_nbytes() ts._group._nbytes_in_memory -= ts.get_nbytes() - self.worker_send( - ws._address, {"op": 
"delete-data", "keys": [key], "report": False} - ) + worker_msgs[ws._address] = { + "op": "delete-data", + "keys": [key], + "report": False, + } + ts._who_has.clear() ts.state = "released" @@ -5112,7 +5121,7 @@ def transition_memory_released(self, key, safe=False): if self.validate: assert not ts._waiting_on - return recommendations + return recommendations, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5127,6 +5136,7 @@ def transition_released_erred(self, key): ts: TaskState = tasks[key] dts: TaskState failing_ts: TaskState + worker_msgs: dict = {} if self.validate: with log_errors(pdb=LOG_PDB): @@ -5156,7 +5166,7 @@ def transition_released_erred(self, key): ts.state = "erred" # TODO: waiting data? - return recommendations + return recommendations, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5170,6 +5180,7 @@ def transition_erred_released(self, key): tasks: dict = self.tasks ts: TaskState = tasks[key] dts: TaskState + worker_msgs: dict = {} if self.validate: with log_errors(pdb=LOG_PDB): @@ -5192,7 +5203,7 @@ def transition_erred_released(self, key): self.report({"op": "task-retried", "key": key}) ts.state = "released" - return recommendations + return recommendations, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5205,6 +5216,7 @@ def transition_waiting_released(self, key): try: tasks: dict = self.tasks ts: TaskState = tasks[key] + worker_msgs: dict = {} if self.validate: assert not ts._who_has @@ -5230,7 +5242,7 @@ def transition_waiting_released(self, key): else: ts._waiters.clear() - return recommendations + return recommendations, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5244,6 +5256,7 @@ def transition_processing_released(self, key): tasks: dict = self.tasks ts: TaskState = tasks[key] dts: TaskState + worker_msgs: dict = {} if self.validate: assert ts._processing_on @@ -5253,7 +5266,7 @@ def transition_processing_released(self, key): w: str = self._remove_from_processing(ts) if w: - self.worker_send(w, {"op": "release-task", "key": key}) + worker_msgs[w] = {"op": "release-task", "key": key} ts.state = "released" @@ -5276,7 +5289,7 @@ def transition_processing_released(self, key): if self.validate: assert not ts._processing_on - return recommendations + return recommendations, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5294,6 +5307,7 @@ def transition_processing_erred( ts: TaskState = tasks[key] dts: TaskState failing_ts: TaskState + worker_msgs: dict = {} if self.validate: assert cause or ts._exception_blame @@ -5353,7 +5367,7 @@ def transition_processing_erred( if self.validate: assert not ts._processing_on - return recommendations + return recommendations, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5367,6 +5381,7 @@ def transition_no_worker_released(self, key): tasks: dict = self.tasks ts: TaskState = tasks[key] dts: TaskState + worker_msgs: dict = {} if self.validate: assert self.tasks[key].state == "no-worker" @@ -5381,7 +5396,7 @@ def transition_no_worker_released(self, key): ts._waiters.clear() - return {} + return {}, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5404,6 +5419,7 @@ def remove_key(self, key): self.task_metadata.pop(key, None) def _propagate_forgotten(self, ts: TaskState, recommendations: dict): + worker_msgs: dict = {} workers: dict = cast(dict, self.workers) ts.state = "forgotten" key: str = ts._key @@ -5438,17 +5454,18 @@ def _propagate_forgotten(self, ts: TaskState, 
recommendations: dict): ws._nbytes -= ts.get_nbytes() w: str = ws._address if w in workers: # in case worker has died - self.worker_send( - w, {"op": "delete-data", "keys": [key], "report": False} - ) + worker_msgs[w] = {"op": "delete-data", "keys": [key], "report": False} ts._who_has.clear() + return worker_msgs + def transition_memory_forgotten(self, key): tasks: dict ws: WorkerState try: tasks = self.tasks ts: TaskState = tasks[key] + worker_msgs: dict = {} if self.validate: assert ts._state == "memory" @@ -5472,12 +5489,12 @@ def transition_memory_forgotten(self, key): for ws in ts._who_has: ws._actors.discard(ts) - self._propagate_forgotten(ts, recommendations) + worker_msgs = self._propagate_forgotten(ts, recommendations) self.report_on_key(ts=ts) self.remove_key(key) - return recommendations + return recommendations, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5490,6 +5507,7 @@ def transition_released_forgotten(self, key): try: tasks: dict = self.tasks ts: TaskState = tasks[key] + worker_msgs: dict = {} if self.validate: assert ts._state in ("released", "erred") @@ -5514,7 +5532,7 @@ def transition_released_forgotten(self, key): self.report_on_key(ts=ts) self.remove_key(key) - return recommendations + return recommendations, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5540,6 +5558,7 @@ def transition(self, key, finish, *args, **kwargs): Scheduler.transitions: transitive version of this function """ ts: TaskState + worker_msgs: dict try: try: ts = self.tasks[key] @@ -5554,16 +5573,17 @@ def transition(self, key, finish, *args, **kwargs): dependencies = set(ts._dependencies) recommendations: dict = {} + worker_msgs = {} if (start, finish) in self._transitions: func = self._transitions[start, finish] - recommendations = func(key, *args, **kwargs) + recommendations, worker_msgs = func(key, *args, **kwargs) elif "released" not in (start, finish): func = self._transitions["released", finish] assert not args and not kwargs a = self.transition(key, "released") if key in a: func = self._transitions["released", a[key]] - b = func(key) + b, worker_msgs = func(key) a = a.copy() a.update(b) recommendations = a @@ -5573,6 +5593,9 @@ def transition(self, key, finish, *args, **kwargs): "Impossible transition from %r to %r" % (start, finish) ) + for worker, msg in worker_msgs.items(): + self.worker_send(worker, msg) + finish2 = ts._state self.transition_log.append((key, start, finish2, recommendations, time())) if self.validate: From 2abe46cfea41bbab9caa708d3129b2aba6a425dc Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:47:37 -0800 Subject: [PATCH 09/38] Handle `report` in `transition` --- distributed/scheduler.py | 108 ++++++++++++++++++++++----------------- 1 file changed, 61 insertions(+), 47 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 15b5e2c731d..a9f3bcb4218 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4701,6 +4701,7 @@ def transition_released_waiting(self, key): ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} + report_msg: dict = {} if self.validate: assert ts._run_spec @@ -4710,7 +4711,7 @@ def transition_released_waiting(self, key): assert not any([dts._state == "forgotten" for dts in ts._dependencies]) if ts._has_lost_dependencies: - return {key: "forgotten"}, worker_msgs + return {key: "forgotten"}, worker_msgs, report_msg ts.state = "waiting" @@ -4721,7 +4722,7 @@ def transition_released_waiting(self, key): if 
dts._exception_blame: ts._exception_blame = dts._exception_blame recommendations[key] = "erred" - return recommendations, worker_msgs + return recommendations, worker_msgs, report_msg for dts in ts._dependencies: dep = dts._key @@ -4741,7 +4742,7 @@ def transition_released_waiting(self, key): self.unrunnable.add(ts) ts.state = "no-worker" - return recommendations, worker_msgs + return recommendations, worker_msgs, report_msg except Exception as e: logger.exception(e) if LOG_PDB: @@ -4757,6 +4758,7 @@ def transition_no_worker_waiting(self, key): ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} + report_msg: dict = {} if self.validate: assert ts in self.unrunnable @@ -4767,7 +4769,7 @@ def transition_no_worker_waiting(self, key): self.unrunnable.remove(ts) if ts._has_lost_dependencies: - return {key: "forgotten"}, worker_msgs + return {key: "forgotten"}, worker_msgs, report_msg recommendations: dict = {} @@ -4789,7 +4791,7 @@ def transition_no_worker_waiting(self, key): self.unrunnable.add(ts) ts.state = "no-worker" - return recommendations, worker_msgs + return recommendations, worker_msgs, report_msg except Exception as e: logger.exception(e) if LOG_PDB: @@ -4865,6 +4867,7 @@ def transition_waiting_processing(self, key): ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} + report_msg: dict = {} if self.validate: assert not ts._waiting_on @@ -4877,7 +4880,7 @@ def transition_waiting_processing(self, key): ws: WorkerState = self.decide_worker(ts) if ws is None: - return {}, worker_msgs + return {}, worker_msgs, report_msg worker = ws._address duration_estimate = self.set_duration_estimate(ts, ws) @@ -4896,7 +4899,7 @@ def transition_waiting_processing(self, key): self.send_task_to_worker(worker, ts) - return {}, worker_msgs + return {}, worker_msgs, report_msg except Exception as e: logger.exception(e) if LOG_PDB: @@ -4912,6 +4915,7 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): tasks: dict = self.tasks ts: TaskState = tasks[key] worker_msgs: dict = {} + report_msg: dict = {} if self.validate: assert not ts._processing_on @@ -4929,15 +4933,13 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): report_msg: dict = {} self._add_to_memory(ts, ws, recommendations, report_msg, **kwargs) - if report_msg: - self.report(report_msg) if self.validate: assert not ts._processing_on assert not ts._waiting_on assert ts._who_has - return recommendations, worker_msgs + return recommendations, worker_msgs, report_msg except Exception as e: logger.exception(e) if LOG_PDB: @@ -4959,6 +4961,7 @@ def transition_processing_memory( ws: WorkerState wws: WorkerState worker_msgs: dict = {} + report_msg: dict = {} try: tasks: dict = self.tasks ts: TaskState = tasks[key] @@ -4977,7 +4980,7 @@ def transition_processing_memory( workers: dict = cast(dict, self.workers) ws = workers.get(worker) if ws is None: - return {key: "released"}, worker_msgs + return {key: "released"}, worker_msgs, report_msg if ws != ts._processing_on: # someone else has this task logger.info( @@ -4987,7 +4990,7 @@ def transition_processing_memory( ws, key, ) - return {}, worker_msgs + return {}, worker_msgs, report_msg if startstops: L = list() @@ -5048,14 +5051,12 @@ def transition_processing_memory( self._add_to_memory( ts, ws, recommendations, report_msg, type=type, typename=typename ) - if report_msg: - self.report(report_msg) if self.validate: assert not ts._processing_on assert not ts._waiting_on - return recommendations, worker_msgs + return 
recommendations, worker_msgs, report_msg except Exception as e: logger.exception(e) if LOG_PDB: @@ -5071,6 +5072,7 @@ def transition_memory_released(self, key, safe=False): ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} + report_msg: dict = {} if self.validate: assert not ts._waiting_on @@ -5084,7 +5086,11 @@ def transition_memory_released(self, key, safe=False): if ts._who_wants: ts._exception_blame = ts ts._exception = "Worker holding Actor was lost" - return {ts._key: "erred"}, worker_msgs # don't try to recreate + return ( + {ts._key: "erred"}, + worker_msgs, + report_msg, + ) # don't try to recreate recommendations: dict = {} @@ -5109,7 +5115,7 @@ def transition_memory_released(self, key, safe=False): ts.state = "released" - self.report({"op": "lost-data", "key": key}) + report_msg = {"op": "lost-data", "key": key} if not ts._run_spec: # pure data recommendations[key] = "forgotten" @@ -5121,7 +5127,7 @@ def transition_memory_released(self, key, safe=False): if self.validate: assert not ts._waiting_on - return recommendations, worker_msgs + return recommendations, worker_msgs, report_msg except Exception as e: logger.exception(e) if LOG_PDB: @@ -5137,6 +5143,7 @@ def transition_released_erred(self, key): dts: TaskState failing_ts: TaskState worker_msgs: dict = {} + report_msg: dict = {} if self.validate: with log_errors(pdb=LOG_PDB): @@ -5154,19 +5161,17 @@ def transition_released_erred(self, key): if not dts._who_has: recommendations[dts._key] = "erred" - self.report( - { - "op": "task-erred", - "key": key, - "exception": failing_ts._exception, - "traceback": failing_ts._traceback, - } - ) + report_msg = { + "op": "task-erred", + "key": key, + "exception": failing_ts._exception, + "traceback": failing_ts._traceback, + } ts.state = "erred" # TODO: waiting data? 
- return recommendations, worker_msgs + return recommendations, worker_msgs, report_msg except Exception as e: logger.exception(e) if LOG_PDB: @@ -5181,6 +5186,7 @@ def transition_erred_released(self, key): ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} + report_msg: dict = {} if self.validate: with log_errors(pdb=LOG_PDB): @@ -5200,10 +5206,10 @@ def transition_erred_released(self, key): if dts._state == "erred": recommendations[dts._key] = "waiting" - self.report({"op": "task-retried", "key": key}) + report_msg = {"op": "task-retried", "key": key} ts.state = "released" - return recommendations, worker_msgs + return recommendations, worker_msgs, report_msg except Exception as e: logger.exception(e) if LOG_PDB: @@ -5217,6 +5223,7 @@ def transition_waiting_released(self, key): tasks: dict = self.tasks ts: TaskState = tasks[key] worker_msgs: dict = {} + report_msg: dict = {} if self.validate: assert not ts._who_has @@ -5242,7 +5249,7 @@ def transition_waiting_released(self, key): else: ts._waiters.clear() - return recommendations, worker_msgs + return recommendations, worker_msgs, report_msg except Exception as e: logger.exception(e) if LOG_PDB: @@ -5257,6 +5264,7 @@ def transition_processing_released(self, key): ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} + report_msg: dict = {} if self.validate: assert ts._processing_on @@ -5289,7 +5297,7 @@ def transition_processing_released(self, key): if self.validate: assert not ts._processing_on - return recommendations, worker_msgs + return recommendations, worker_msgs, report_msg except Exception as e: logger.exception(e) if LOG_PDB: @@ -5308,6 +5316,7 @@ def transition_processing_erred( dts: TaskState failing_ts: TaskState worker_msgs: dict = {} + report_msg: dict = {} if self.validate: assert cause or ts._exception_blame @@ -5347,14 +5356,12 @@ def transition_processing_erred( ts.state = "erred" - self.report( - { - "op": "task-erred", - "key": key, - "exception": failing_ts._exception, - "traceback": failing_ts._traceback, - } - ) + report_msg = { + "op": "task-erred", + "key": key, + "exception": failing_ts._exception, + "traceback": failing_ts._traceback, + } cs: ClientState = self.clients["fire-and-forget"] if ts in cs._wants_what: @@ -5367,7 +5374,7 @@ def transition_processing_erred( if self.validate: assert not ts._processing_on - return recommendations, worker_msgs + return recommendations, worker_msgs, report_msg except Exception as e: logger.exception(e) if LOG_PDB: @@ -5382,6 +5389,7 @@ def transition_no_worker_released(self, key): ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} + report_msg: dict = {} if self.validate: assert self.tasks[key].state == "no-worker" @@ -5396,7 +5404,7 @@ def transition_no_worker_released(self, key): ts._waiters.clear() - return {}, worker_msgs + return {}, worker_msgs, report_msg except Exception as e: logger.exception(e) if LOG_PDB: @@ -5466,6 +5474,7 @@ def transition_memory_forgotten(self, key): tasks = self.tasks ts: TaskState = tasks[key] worker_msgs: dict = {} + report_msg: dict = {} if self.validate: assert ts._state == "memory" @@ -5491,10 +5500,10 @@ def transition_memory_forgotten(self, key): worker_msgs = self._propagate_forgotten(ts, recommendations) - self.report_on_key(ts=ts) + report_msg = self._task_to_report_msg(ts) self.remove_key(key) - return recommendations, worker_msgs + return recommendations, worker_msgs, report_msg except Exception as e: logger.exception(e) if LOG_PDB: @@ -5508,6 +5517,7 @@ def 
transition_released_forgotten(self, key): tasks: dict = self.tasks ts: TaskState = tasks[key] worker_msgs: dict = {} + report_msg: dict = {} if self.validate: assert ts._state in ("released", "erred") @@ -5529,10 +5539,10 @@ def transition_released_forgotten(self, key): recommendations: dict = {} self._propagate_forgotten(ts, recommendations) - self.report_on_key(ts=ts) + report_msg = self._task_to_report_msg(ts) self.remove_key(key) - return recommendations, worker_msgs + return recommendations, worker_msgs, report_msg except Exception as e: logger.exception(e) if LOG_PDB: @@ -5559,6 +5569,7 @@ def transition(self, key, finish, *args, **kwargs): """ ts: TaskState worker_msgs: dict + report_msg: dict try: try: ts = self.tasks[key] @@ -5574,16 +5585,17 @@ def transition(self, key, finish, *args, **kwargs): recommendations: dict = {} worker_msgs = {} + report_msg = {} if (start, finish) in self._transitions: func = self._transitions[start, finish] - recommendations, worker_msgs = func(key, *args, **kwargs) + recommendations, worker_msgs, report_msg = func(key, *args, **kwargs) elif "released" not in (start, finish): func = self._transitions["released", finish] assert not args and not kwargs a = self.transition(key, "released") if key in a: func = self._transitions["released", a[key]] - b, worker_msgs = func(key) + b, worker_msgs, report_msg = func(key) a = a.copy() a.update(b) recommendations = a @@ -5595,6 +5607,8 @@ def transition(self, key, finish, *args, **kwargs): for worker, msg in worker_msgs.items(): self.worker_send(worker, msg) + if report_msg: + self.report(report_msg, ts=ts) finish2 = ts._state self.transition_log.append((key, start, finish2, recommendations, time())) From bc13fbe49e699f7ea8533776da35004495b36fab Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:47:38 -0800 Subject: [PATCH 10/38] Add method to send a message to a specific client --- distributed/scheduler.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a9f3bcb4218..78d789d896f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3490,6 +3490,18 @@ def worker_send(self, worker, msg): except (CommClosedError, AttributeError): self.loop.add_callback(self.remove_worker, address=worker) + def client_send(self, client, msg): + """Send message to client""" + client_comms: dict = self.client_comms + c = client_comms.get(client) + if c is None: + return + try: + c.send(msg) + except CommClosedError: + if self.status == Status.running: + logger.critical("Tried writing to closed comm: %s", msg) + ############################ # Less common interactions # ############################ From 0138c53995f3ac0a713cb42eda8e2c4550586893 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:47:39 -0800 Subject: [PATCH 11/38] Add `_task_to_client_msgs` This converts a `TaskState` into a `dict` of messages with the keys being the Clients to notify and the message being the report message. Allows us to think of messages simply in terms of the message and where it needs to be delivered without needing to know anything about the `TaskState` it came from or the `ClientState`s involved. 
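
As a rough sketch of the intended consumption (using the `client_send` method
added in the previous patch and the dispatch loop introduced in the next one;
the client keys shown are made up for illustration):

    # e.g. for a key that two clients want:
    #   {"client-abc": {"op": "key-in-memory", "key": "x"},
    #    "client-def": {"op": "key-in-memory", "key": "x"}}
    client_msgs: dict = self._task_to_client_msgs(ts)
    for client, msg in client_msgs.items():
        self.client_send(client, msg)
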
--- distributed/scheduler.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 78d789d896f..938fc203e69 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4346,6 +4346,25 @@ def _task_to_report_msg(self, ts: TaskState) -> dict: else: return None + def _task_to_client_msgs(self, ts: TaskState) -> dict: + cs: ClientState + clients: dict = self.clients + client_keys: list + if ts is None: + # Notify all clients + client_keys = list(clients) + else: + # Notify clients interested in key + client_keys = [cs._client_key for cs in ts._who_wants] + + report_msg: dict = self._task_to_report_msg(ts) + + client_msgs: dict = {} + for k in client_keys: + client_msgs[k] = report_msg + + return client_msgs + def report_on_key(self, key: str = None, ts: TaskState = None, client: str = None): if ts is None: tasks: dict = self.tasks From b0ffcf2c1e52503a84cc143261ce47befb068528 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:47:39 -0800 Subject: [PATCH 12/38] Replace `report_msg` with `client_msgs` Instead of collecting a message to pass to `report` and letting the relevant Clients be collected from the `TaskState` information later, go ahead and collect that immediately while handling that `TaskState`. These Clients then form the keys of `client_msgs` where the message contains what was in `report_msg`. This allows us to keep all the `TaskState` work contained to where it is relevant and can be handled efficiently. Then the messaging out to Clients only needs be concerned with the messages and where they go without needing to worry about what they pertain to. --- distributed/scheduler.py | 118 ++++++++++++++++++++++----------------- 1 file changed, 68 insertions(+), 50 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 938fc203e69..02d058fec98 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4671,7 +4671,7 @@ def _add_to_memory( ts: TaskState, ws: WorkerState, recommendations: dict, - report_msg: dict, + client_msgs: dict, type=None, typename=None, **kwargs, @@ -4705,6 +4705,8 @@ def _add_to_memory( if not s and not dts._who_wants: recommendations[dts._key] = "released" + report_msg: dict = {} + cs: ClientState if not ts._waiters and not ts._who_wants: recommendations[ts._key] = "released" else: @@ -4713,11 +4715,14 @@ def _add_to_memory( if type is not None: report_msg["type"] = type + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg + ts.state = "memory" ts._type = typename ts._group._types.add(typename) - cs: ClientState = self.clients["fire-and-forget"] + cs = self.clients["fire-and-forget"] if ts in cs._wants_what: self._client_releases_keys( cs=cs, @@ -4732,7 +4737,7 @@ def transition_released_waiting(self, key): ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} - report_msg: dict = {} + client_msgs: dict = {} if self.validate: assert ts._run_spec @@ -4742,7 +4747,7 @@ def transition_released_waiting(self, key): assert not any([dts._state == "forgotten" for dts in ts._dependencies]) if ts._has_lost_dependencies: - return {key: "forgotten"}, worker_msgs, report_msg + return {key: "forgotten"}, worker_msgs, client_msgs ts.state = "waiting" @@ -4753,7 +4758,7 @@ def transition_released_waiting(self, key): if dts._exception_blame: ts._exception_blame = dts._exception_blame recommendations[key] = "erred" - return recommendations, worker_msgs, report_msg + return recommendations, worker_msgs, 
client_msgs for dts in ts._dependencies: dep = dts._key @@ -4773,7 +4778,7 @@ def transition_released_waiting(self, key): self.unrunnable.add(ts) ts.state = "no-worker" - return recommendations, worker_msgs, report_msg + return recommendations, worker_msgs, client_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -4789,7 +4794,7 @@ def transition_no_worker_waiting(self, key): ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} - report_msg: dict = {} + client_msgs: dict = {} if self.validate: assert ts in self.unrunnable @@ -4800,7 +4805,7 @@ def transition_no_worker_waiting(self, key): self.unrunnable.remove(ts) if ts._has_lost_dependencies: - return {key: "forgotten"}, worker_msgs, report_msg + return {key: "forgotten"}, worker_msgs, client_msgs recommendations: dict = {} @@ -4822,7 +4827,7 @@ def transition_no_worker_waiting(self, key): self.unrunnable.add(ts) ts.state = "no-worker" - return recommendations, worker_msgs, report_msg + return recommendations, worker_msgs, client_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -4898,7 +4903,7 @@ def transition_waiting_processing(self, key): ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} - report_msg: dict = {} + client_msgs: dict = {} if self.validate: assert not ts._waiting_on @@ -4911,7 +4916,7 @@ def transition_waiting_processing(self, key): ws: WorkerState = self.decide_worker(ts) if ws is None: - return {}, worker_msgs, report_msg + return {}, worker_msgs, client_msgs worker = ws._address duration_estimate = self.set_duration_estimate(ts, ws) @@ -4930,7 +4935,7 @@ def transition_waiting_processing(self, key): self.send_task_to_worker(worker, ts) - return {}, worker_msgs, report_msg + return {}, worker_msgs, client_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -4946,7 +4951,7 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): tasks: dict = self.tasks ts: TaskState = tasks[key] worker_msgs: dict = {} - report_msg: dict = {} + client_msgs: dict = {} if self.validate: assert not ts._processing_on @@ -4961,16 +4966,16 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): self.check_idle_saturated(ws) recommendations: dict = {} - report_msg: dict = {} + client_msgs: dict = {} - self._add_to_memory(ts, ws, recommendations, report_msg, **kwargs) + self._add_to_memory(ts, ws, recommendations, client_msgs, **kwargs) if self.validate: assert not ts._processing_on assert not ts._waiting_on assert ts._who_has - return recommendations, worker_msgs, report_msg + return recommendations, worker_msgs, client_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -4992,7 +4997,7 @@ def transition_processing_memory( ws: WorkerState wws: WorkerState worker_msgs: dict = {} - report_msg: dict = {} + client_msgs: dict = {} try: tasks: dict = self.tasks ts: TaskState = tasks[key] @@ -5011,7 +5016,7 @@ def transition_processing_memory( workers: dict = cast(dict, self.workers) ws = workers.get(worker) if ws is None: - return {key: "released"}, worker_msgs, report_msg + return {key: "released"}, worker_msgs, client_msgs if ws != ts._processing_on: # someone else has this task logger.info( @@ -5021,7 +5026,7 @@ def transition_processing_memory( ws, key, ) - return {}, worker_msgs, report_msg + return {}, worker_msgs, client_msgs if startstops: L = list() @@ -5075,19 +5080,19 @@ def transition_processing_memory( ts.set_nbytes(nbytes) recommendations: dict = {} - report_msg: dict = {} + client_msgs: dict = {} 
self._remove_from_processing(ts) self._add_to_memory( - ts, ws, recommendations, report_msg, type=type, typename=typename + ts, ws, recommendations, client_msgs, type=type, typename=typename ) if self.validate: assert not ts._processing_on assert not ts._waiting_on - return recommendations, worker_msgs, report_msg + return recommendations, worker_msgs, client_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5103,7 +5108,7 @@ def transition_memory_released(self, key, safe=False): ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} - report_msg: dict = {} + client_msgs: dict = {} if self.validate: assert not ts._waiting_on @@ -5120,7 +5125,7 @@ def transition_memory_released(self, key, safe=False): return ( {ts._key: "erred"}, worker_msgs, - report_msg, + client_msgs, ) # don't try to recreate recommendations: dict = {} @@ -5147,6 +5152,9 @@ def transition_memory_released(self, key, safe=False): ts.state = "released" report_msg = {"op": "lost-data", "key": key} + cs: ClientState + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg if not ts._run_spec: # pure data recommendations[key] = "forgotten" @@ -5158,7 +5166,7 @@ def transition_memory_released(self, key, safe=False): if self.validate: assert not ts._waiting_on - return recommendations, worker_msgs, report_msg + return recommendations, worker_msgs, client_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5174,7 +5182,7 @@ def transition_released_erred(self, key): dts: TaskState failing_ts: TaskState worker_msgs: dict = {} - report_msg: dict = {} + client_msgs: dict = {} if self.validate: with log_errors(pdb=LOG_PDB): @@ -5198,11 +5206,14 @@ def transition_released_erred(self, key): "exception": failing_ts._exception, "traceback": failing_ts._traceback, } + cs: ClientState + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg ts.state = "erred" # TODO: waiting data? 
- return recommendations, worker_msgs, report_msg + return recommendations, worker_msgs, client_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5217,7 +5228,7 @@ def transition_erred_released(self, key): ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} - report_msg: dict = {} + client_msgs: dict = {} if self.validate: with log_errors(pdb=LOG_PDB): @@ -5238,9 +5249,13 @@ def transition_erred_released(self, key): recommendations[dts._key] = "waiting" report_msg = {"op": "task-retried", "key": key} + cs: ClientState + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg + ts.state = "released" - return recommendations, worker_msgs, report_msg + return recommendations, worker_msgs, client_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5254,7 +5269,7 @@ def transition_waiting_released(self, key): tasks: dict = self.tasks ts: TaskState = tasks[key] worker_msgs: dict = {} - report_msg: dict = {} + client_msgs: dict = {} if self.validate: assert not ts._who_has @@ -5280,7 +5295,7 @@ def transition_waiting_released(self, key): else: ts._waiters.clear() - return recommendations, worker_msgs, report_msg + return recommendations, worker_msgs, client_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5295,7 +5310,7 @@ def transition_processing_released(self, key): ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} - report_msg: dict = {} + client_msgs: dict = {} if self.validate: assert ts._processing_on @@ -5328,7 +5343,7 @@ def transition_processing_released(self, key): if self.validate: assert not ts._processing_on - return recommendations, worker_msgs, report_msg + return recommendations, worker_msgs, client_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5347,7 +5362,7 @@ def transition_processing_erred( dts: TaskState failing_ts: TaskState worker_msgs: dict = {} - report_msg: dict = {} + client_msgs: dict = {} if self.validate: assert cause or ts._exception_blame @@ -5393,8 +5408,11 @@ def transition_processing_erred( "exception": failing_ts._exception, "traceback": failing_ts._traceback, } + cs: ClientState + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg - cs: ClientState = self.clients["fire-and-forget"] + cs = self.clients["fire-and-forget"] if ts in cs._wants_what: self._client_releases_keys( cs=cs, @@ -5405,7 +5423,7 @@ def transition_processing_erred( if self.validate: assert not ts._processing_on - return recommendations, worker_msgs, report_msg + return recommendations, worker_msgs, client_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5420,7 +5438,7 @@ def transition_no_worker_released(self, key): ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} - report_msg: dict = {} + client_msgs: dict = {} if self.validate: assert self.tasks[key].state == "no-worker" @@ -5435,7 +5453,7 @@ def transition_no_worker_released(self, key): ts._waiters.clear() - return {}, worker_msgs, report_msg + return {}, worker_msgs, client_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5505,7 +5523,7 @@ def transition_memory_forgotten(self, key): tasks = self.tasks ts: TaskState = tasks[key] worker_msgs: dict = {} - report_msg: dict = {} + client_msgs: dict = {} if self.validate: assert ts._state == "memory" @@ -5531,10 +5549,10 @@ def transition_memory_forgotten(self, key): worker_msgs = self._propagate_forgotten(ts, recommendations) - report_msg = self._task_to_report_msg(ts) + client_msgs = self._task_to_client_msgs(ts) 
self.remove_key(key) - return recommendations, worker_msgs, report_msg + return recommendations, worker_msgs, client_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5548,7 +5566,7 @@ def transition_released_forgotten(self, key): tasks: dict = self.tasks ts: TaskState = tasks[key] worker_msgs: dict = {} - report_msg: dict = {} + client_msgs: dict = {} if self.validate: assert ts._state in ("released", "erred") @@ -5570,10 +5588,10 @@ def transition_released_forgotten(self, key): recommendations: dict = {} self._propagate_forgotten(ts, recommendations) - report_msg = self._task_to_report_msg(ts) + client_msgs = self._task_to_client_msgs(ts) self.remove_key(key) - return recommendations, worker_msgs, report_msg + return recommendations, worker_msgs, client_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -5600,7 +5618,7 @@ def transition(self, key, finish, *args, **kwargs): """ ts: TaskState worker_msgs: dict - report_msg: dict + client_msgs: dict try: try: ts = self.tasks[key] @@ -5616,17 +5634,17 @@ def transition(self, key, finish, *args, **kwargs): recommendations: dict = {} worker_msgs = {} - report_msg = {} + client_msgs = {} if (start, finish) in self._transitions: func = self._transitions[start, finish] - recommendations, worker_msgs, report_msg = func(key, *args, **kwargs) + recommendations, worker_msgs, client_msgs = func(key, *args, **kwargs) elif "released" not in (start, finish): func = self._transitions["released", finish] assert not args and not kwargs a = self.transition(key, "released") if key in a: func = self._transitions["released", a[key]] - b, worker_msgs, report_msg = func(key) + b, worker_msgs, client_msgs = func(key) a = a.copy() a.update(b) recommendations = a @@ -5638,8 +5656,8 @@ def transition(self, key, finish, *args, **kwargs): for worker, msg in worker_msgs.items(): self.worker_send(worker, msg) - if report_msg: - self.report(report_msg, ts=ts) + for client, msg in client_msgs.items(): + self.client_send(client, msg) finish2 = ts._state self.transition_log.append((key, start, finish2, recommendations, time())) From fe33d4afc581ebf6ed729f5c77835ba24d82ccc6 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:50:09 -0800 Subject: [PATCH 13/38] Create empty `SchedulerState` class This class should ultimately handle everything related to task graph state and performing transitions currently done in the `Scheduler`. This is intended to be a base class of the `Scheduler`, which we will use to manage related state. As this is a separate class that does not inherit from other Python classes, it should be a good target for Cythonization. This should allow us to more thoroughly optimize these components of the `Scheduler`. Afterwards the `Scheduler` should ideally be left with communication part, `async` methods, user facing APIs, etc., used to interact with other elements that are not really targets of the Cythonization effort. 
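
In outline (condensed from the diff below and from the next patch in the
series, which moves the `transition*` methods), the intended layering looks
roughly like:

    @cclass
    class SchedulerState:
        # task-graph state and transition methods only; no comms and no
        # event loop, so it is a clean target for Cythonization
        ...

    class Scheduler(SchedulerState, ServerNode):
        # communication, async handlers, and user-facing APIs stay here,
        # calling into the state/transition machinery they inherit
        ...
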
--- distributed/core.py | 3 +++ distributed/scheduler.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/distributed/core.py b/distributed/core.py index 1d98241bb1f..d9d20dd0992 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -135,6 +135,7 @@ def __init__( connection_args=None, timeout=None, io_loop=None, + **kwargs, ): self.handlers = { "identity": self.identity, @@ -230,6 +231,8 @@ def set_thread_ident(): self.__stopped = False + super().__init__(**kwargs) + @property def status(self): return self._status diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 02d058fec98..b279e8d6338 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1526,7 +1526,34 @@ def _task_key_or_none(task): return task.key if task is not None else None -class Scheduler(ServerNode): +@cclass +class SchedulerState: + """Underlying task state of dynamic scheduler + + Tracks the current state of workers, data, and computations. + + Handles transitions between different task states. Notifies the + Scheduler of changes by messaging passing through Queues, which the + Scheduler listens to responds accordingly. + + All events are handled quickly, in linear time with respect to their + input (which is often of constant size) and generally within a + millisecond. Additionally when Cythonized, this can be faster still. + To accomplish this the scheduler tracks a lot of state. Every + operation maintains the consistency of this state. + + Users typically do not interact with ``Transitions`` directly. Instead + users interact with the ``Client``, which in turn engages the + ``Scheduler`` affecting different transitions here under-the-hood. In + the background ``Worker``s also engage with the ``Scheduler`` + affecting these state transitions as well. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +class Scheduler(SchedulerState, ServerNode): """Dynamic distributed task scheduler The scheduler tracks the current state of workers, data, and computations. From 59605779a679399e96063760aa124990de334db2 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:50:10 -0800 Subject: [PATCH 14/38] Move `transition*` methods into `SchedulerState` Grabs all of the `transition` methods and methods they call and moves them to `SchedulerState`. --- distributed/scheduler.py | 7200 +++++++++++++++++++------------------- 1 file changed, 3600 insertions(+), 3600 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b279e8d6338..63d337cad1b 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1552,4080 +1552,4219 @@ class SchedulerState: def __init__(self, **kwargs): super().__init__(**kwargs) + def _remove_from_processing(self, ts: TaskState) -> str: + """ + Remove *ts* from the set of processing tasks. + """ + workers: dict = cast(dict, self.workers) + ws: WorkerState = ts._processing_on + ts._processing_on = None + w: str = ws._address + if w in workers: # may have been removed + duration = ws._processing.pop(ts) + if not ws._processing: + self.total_occupancy -= ws._occupancy + ws._occupancy = 0 + else: + self.total_occupancy -= duration + ws._occupancy -= duration + self.check_idle_saturated(ws) + self.release_resources(ts, ws) + return w + else: + return None -class Scheduler(SchedulerState, ServerNode): - """Dynamic distributed task scheduler - - The scheduler tracks the current state of workers, data, and computations. 
- The scheduler listens for events and responds by controlling workers - appropriately. It continuously tries to use the workers to execute an ever - growing dask graph. + def _add_to_memory( + self, + ts: TaskState, + ws: WorkerState, + recommendations: dict, + client_msgs: dict, + type=None, + typename=None, + **kwargs, + ): + """ + Add *ts* to the set of in-memory tasks. + """ + if self.validate: + assert ts not in ws._has_what - All events are handled quickly, in linear time with respect to their input - (which is often of constant size) and generally within a millisecond. To - accomplish this the scheduler tracks a lot of state. Every operation - maintains the consistency of this state. + ts._who_has.add(ws) + ws._has_what.add(ts) + ws._nbytes += ts.get_nbytes() - The scheduler communicates with the outside world through Comm objects. - It maintains a consistent and valid view of the world even when listening - to several clients at once. + deps: list = list(ts._dependents) + if len(deps) > 1: + deps.sort(key=operator.attrgetter("priority"), reverse=True) - A Scheduler is typically started either with the ``dask-scheduler`` - executable:: + dts: TaskState + s: set + for dts in deps: + s = dts._waiting_on + if ts in s: + s.discard(ts) + if not s: # new task ready to run + recommendations[dts._key] = "processing" - $ dask-scheduler - Scheduler started at 127.0.0.1:8786 + for dts in ts._dependencies: + s = dts._waiters + s.discard(ts) + if not s and not dts._who_wants: + recommendations[dts._key] = "released" - Or within a LocalCluster a Client starts up without connection - information:: + report_msg: dict = {} + cs: ClientState + if not ts._waiters and not ts._who_wants: + recommendations[ts._key] = "released" + else: + report_msg["op"] = "key-in-memory" + report_msg["key"] = ts._key + if type is not None: + report_msg["type"] = type - >>> c = Client() # doctest: +SKIP - >>> c.cluster.scheduler # doctest: +SKIP - Scheduler(...) + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg - Users typically do not interact with the scheduler directly but rather with - the client object ``Client``. + ts.state = "memory" + ts._type = typename + ts._group._types.add(typename) - **State** + cs = self.clients["fire-and-forget"] + if ts in cs._wants_what: + self._client_releases_keys( + cs=cs, + keys=[ts._key], + recommendations=recommendations, + ) - The scheduler contains the following state variables. Each variable is - listed along with what it stores and a brief description. 
+ def transition_released_waiting(self, key): + try: + tasks: dict = self.tasks + workers: dict = cast(dict, self.workers) + ts: TaskState = tasks[key] + dts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - * **tasks:** ``{task key: TaskState}`` - Tasks currently known to the scheduler - * **unrunnable:** ``{TaskState}`` - Tasks in the "no-worker" state + if self.validate: + assert ts._run_spec + assert not ts._waiting_on + assert not ts._who_has + assert not ts._processing_on + assert not any([dts._state == "forgotten" for dts in ts._dependencies]) - * **workers:** ``{worker key: WorkerState}`` - Workers currently connected to the scheduler - * **idle:** ``{WorkerState}``: - Set of workers that are not fully utilized - * **saturated:** ``{WorkerState}``: - Set of workers that are not over-utilized + if ts._has_lost_dependencies: + return {key: "forgotten"}, worker_msgs, client_msgs - * **host_info:** ``{hostname: dict}``: - Information about each worker host + ts.state = "waiting" - * **clients:** ``{client key: ClientState}`` - Clients currently connected to the scheduler + recommendations: dict = {} - * **services:** ``{str: port}``: - Other services running on this scheduler, like Bokeh - * **loop:** ``IOLoop``: - The running Tornado IOLoop - * **client_comms:** ``{client key: Comm}`` - For each client, a Comm object used to receive task requests and - report task status updates. - * **stream_comms:** ``{worker key: Comm}`` - For each worker, a Comm object from which we both accept stimuli and - report results - * **task_duration:** ``{key-prefix: time}`` - Time we expect certain functions to take, e.g. ``{'sum': 0.25}`` - """ + dts: TaskState + for dts in ts._dependencies: + if dts._exception_blame: + ts._exception_blame = dts._exception_blame + recommendations[key] = "erred" + return recommendations, worker_msgs, client_msgs - default_port = 8786 - _instances = weakref.WeakSet() + for dts in ts._dependencies: + dep = dts._key + if not dts._who_has: + ts._waiting_on.add(dts) + if dts._state == "released": + recommendations[dep] = "waiting" + else: + dts._waiters.add(ts) - def __init__( - self, - loop=None, - delete_interval="500ms", - synchronize_worker_interval="60s", - services=None, - service_kwargs=None, - allowed_failures=None, - extensions=None, - validate=None, - scheduler_file=None, - security=None, - worker_ttl=None, - idle_timeout=None, - interface=None, - host=None, - port=0, - protocol=None, - dashboard_address=None, - dashboard=None, - http_prefix="/", - preload=None, - preload_argv=(), - plugins=(), - **kwargs, - ): - self._setup_logging(logger) + ts._waiters = {dts for dts in ts._dependents if dts._state == "waiting"} - # Attributes - if allowed_failures is None: - allowed_failures = dask.config.get("distributed.scheduler.allowed-failures") - self.allowed_failures = allowed_failures - if validate is None: - validate = dask.config.get("distributed.scheduler.validate") - self.validate = validate - self.proc = psutil.Process() - self.delete_interval = parse_timedelta(delete_interval, default="ms") - self.synchronize_worker_interval = parse_timedelta( - synchronize_worker_interval, default="ms" - ) - self.digests = None - self.service_specs = services or {} - self.service_kwargs = service_kwargs or {} - self.services = {} - self.scheduler_file = scheduler_file - worker_ttl = worker_ttl or dask.config.get("distributed.scheduler.worker-ttl") - self.worker_ttl = parse_timedelta(worker_ttl) if worker_ttl else None - idle_timeout = idle_timeout or dask.config.get( 
- "distributed.scheduler.idle-timeout" - ) - if idle_timeout: - self.idle_timeout = parse_timedelta(idle_timeout) - else: - self.idle_timeout = None - self.idle_since = time() - self.time_started = self.idle_since # compatibility for dask-gateway - self._lock = asyncio.Lock() - self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth")) - self.bandwidth_workers = defaultdict(float) - self.bandwidth_types = defaultdict(float) + if not ts._waiting_on: + if workers: + recommendations[key] = "processing" + else: + self.unrunnable.add(ts) + ts.state = "no-worker" - if not preload: - preload = dask.config.get("distributed.scheduler.preload") - if not preload_argv: - preload_argv = dask.config.get("distributed.scheduler.preload-argv") - self.preloads = preloading.process_preloads(self, preload, preload_argv) + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - if isinstance(security, dict): - security = Security(**security) - self.security = security or Security() - assert isinstance(self.security, Security) - self.connection_args = self.security.get_connection_args("scheduler") - self.connection_args["handshake_overrides"] = { # common denominator - "pickle-protocol": 4 - } + pdb.set_trace() + raise - self._start_address = addresses_from_user_args( - host=host, - port=port, - interface=interface, - protocol=protocol, - security=security, - default_port=self.default_port, - ) + def transition_no_worker_waiting(self, key): + try: + tasks: dict = self.tasks + workers: dict = cast(dict, self.workers) + ts: TaskState = tasks[key] + dts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - http_server_modules = dask.config.get("distributed.scheduler.http.routes") - show_dashboard = dashboard or (dashboard is None and dashboard_address) - missing_bokeh = False - # install vanilla route if show_dashboard but bokeh is not installed - if show_dashboard: - try: - import distributed.dashboard.scheduler - except ImportError: - missing_bokeh = True - http_server_modules.append("distributed.http.scheduler.missing_bokeh") - routes = get_handlers( - server=self, modules=http_server_modules, prefix=http_prefix - ) - self.start_http_server(routes, dashboard_address, default_port=8787) - if show_dashboard and not missing_bokeh: - distributed.dashboard.scheduler.connect( - self.http_application, self.http_server, self, prefix=http_prefix - ) + if self.validate: + assert ts in self.unrunnable + assert not ts._waiting_on + assert not ts._who_has + assert not ts._processing_on - # Communication state - self.loop = loop or IOLoop.current() - self.client_comms = dict() - self.stream_comms = dict() - self._worker_coroutines = [] - self._ipython_kernel = None + self.unrunnable.remove(ts) - # Task state - self.tasks = dict() - self.task_groups = dict() - self.task_prefixes = dict() - for old_attr, new_attr, wrap in [ - ("priority", "priority", None), - ("dependencies", "dependencies", _legacy_task_key_set), - ("dependents", "dependents", _legacy_task_key_set), - ("retries", "retries", None), - ]: - func = operator.attrgetter(new_attr) - if wrap is not None: - func = compose(wrap, func) - setattr(self, old_attr, _StateLegacyMapping(self.tasks, func)) + if ts._has_lost_dependencies: + return {key: "forgotten"}, worker_msgs, client_msgs - for old_attr, new_attr, wrap in [ - ("nbytes", "nbytes", None), - ("who_wants", "who_wants", _legacy_client_key_set), - ("who_has", "who_has", _legacy_worker_key_set), - ("waiting", 
"waiting_on", _legacy_task_key_set), - ("waiting_data", "waiters", _legacy_task_key_set), - ("rprocessing", "processing_on", None), - ("host_restrictions", "host_restrictions", None), - ("worker_restrictions", "worker_restrictions", None), - ("resource_restrictions", "resource_restrictions", None), - ("suspicious_tasks", "suspicious", None), - ("exceptions", "exception", None), - ("tracebacks", "traceback", None), - ("exceptions_blame", "exception_blame", _task_key_or_none), - ]: - func = operator.attrgetter(new_attr) - if wrap is not None: - func = compose(wrap, func) - setattr(self, old_attr, _OptionalStateLegacyMapping(self.tasks, func)) + recommendations: dict = {} - for old_attr, new_attr, wrap in [ - ("loose_restrictions", "loose_restrictions", None) - ]: - func = operator.attrgetter(new_attr) - if wrap is not None: - func = compose(wrap, func) - setattr(self, old_attr, _StateLegacySet(self.tasks, func)) + for dts in ts._dependencies: + dep = dts._key + if not dts._who_has: + ts._waiting_on.add(dts) + if dts._state == "released": + recommendations[dep] = "waiting" + else: + dts._waiters.add(ts) - self.generation = 0 - self._last_client = None - self._last_time = 0 - self.unrunnable = set() + ts.state = "waiting" - self.n_tasks = 0 - self.task_metadata = dict() - self.datasets = dict() + if not ts._waiting_on: + if workers: + recommendations[key] = "processing" + else: + self.unrunnable.add(ts) + ts.state = "no-worker" - # Prefix-keyed containers - self.unknown_durations = defaultdict(set) + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - # Client state - self.clients = dict() - for old_attr, new_attr, wrap in [ - ("wants_what", "wants_what", _legacy_task_key_set) - ]: - func = operator.attrgetter(new_attr) - if wrap is not None: - func = compose(wrap, func) - setattr(self, old_attr, _StateLegacyMapping(self.clients, func)) - self.clients["fire-and-forget"] = ClientState("fire-and-forget") + pdb.set_trace() + raise - # Worker state - self.workers = sortedcontainers.SortedDict() - for old_attr, new_attr, wrap in [ - ("nthreads", "nthreads", None), - ("worker_bytes", "nbytes", None), - ("worker_resources", "resources", None), - ("used_resources", "used_resources", None), - ("occupancy", "occupancy", None), - ("worker_info", "metrics", None), - ("processing", "processing", _legacy_task_key_dict), - ("has_what", "has_what", _legacy_task_key_set), - ]: - func = operator.attrgetter(new_attr) - if wrap is not None: - func = compose(wrap, func) - setattr(self, old_attr, _StateLegacyMapping(self.workers, func)) + def decide_worker(self, ts: TaskState) -> WorkerState: + """ + Decide on a worker for task *ts*. Return a WorkerState. 
+ """ + workers: dict = cast(dict, self.workers) + ws: WorkerState = None + valid_workers: set = self.valid_workers(ts) - self.idle = sortedcontainers.SortedDict() - self.saturated = set() + if ( + valid_workers is not None + and not valid_workers + and not ts._loose_restrictions + and workers + ): + self.unrunnable.add(ts) + ts.state = "no-worker" + return ws - self.total_nthreads = 0 - self.total_occupancy = 0 - self.host_info = defaultdict(dict) - self.resources = defaultdict(dict) - self.aliases = dict() + if ts._dependencies or valid_workers is not None: + ws = decide_worker( + ts, + workers.values(), + valid_workers, + partial(self.worker_objective, ts), + ) + else: + worker_pool = self.idle or self.workers + worker_pool_dv = cast(dict, worker_pool) + n_workers: Py_ssize_t = len(worker_pool_dv) + if n_workers < 20: # smart but linear in small case + ws = min(worker_pool.values(), key=operator.attrgetter("occupancy")) + else: # dumb but fast in large case + n_tasks: Py_ssize_t = self.n_tasks + ws = worker_pool.values()[n_tasks % n_workers] - self._task_state_collections = [self.unrunnable] + if self.validate: + assert ws is None or isinstance(ws, WorkerState), ( + type(ws), + ws, + ) + assert ws._address in workers - self._worker_collections = [ - self.workers, - self.host_info, - self.resources, - self.aliases, - ] + return ws - self.extensions = {} - self.plugins = list(plugins) - self.transition_log = deque( - maxlen=dask.config.get("distributed.scheduler.transition-log-length") - ) - self.log = deque( - maxlen=dask.config.get("distributed.scheduler.transition-log-length") - ) - self.events = defaultdict(lambda: deque(maxlen=100000)) - self.event_counts = defaultdict(int) - self.worker_plugins = [] + def set_duration_estimate(self, ts: TaskState, ws: WorkerState): + """Estimate task duration using worker state and task state. - worker_handlers = { - "task-finished": self.handle_task_finished, - "task-erred": self.handle_task_erred, - "release": self.handle_release_data, - "release-worker-data": self.release_worker_data, - "add-keys": self.add_keys, - "missing-data": self.handle_missing_data, - "long-running": self.handle_long_running, - "reschedule": self.reschedule, - "keep-alive": lambda *args, **kwargs: None, - "log-event": self.log_worker_event, - } + If a task takes longer than twice the current average duration we + estimate the task duration to be 2x current-runtime, otherwise we set it + to be the average duration. 
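        For example, assuming no communication cost and an average duration
        of 1 s: a task that has already been executing for 3 s gets an
        estimate of ``2 * 3 = 6`` seconds, while one executing for only
        1.5 s keeps the 1 s average as its estimate.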
+ """ + duration: double = self.get_task_duration(ts) + comm: double = self.get_comm_cost(ts, ws) + total_duration: double = duration + comm + if ts in ws._executing: + exec_time: double = ws._executing[ts] + if exec_time > 2 * duration: + total_duration = 2 * exec_time + ws._processing[ts] = total_duration + return total_duration - client_handlers = { - "update-graph": self.update_graph, - "update-graph-hlg": self.update_graph_hlg, - "client-desires-keys": self.client_desires_keys, - "update-data": self.update_data, - "report-key": self.report_on_key, - "client-releases-keys": self.client_releases_keys, - "heartbeat-client": self.client_heartbeat, - "close-client": self.remove_client, - "restart": self.restart, - } + def transition_waiting_processing(self, key): + try: + tasks: dict = self.tasks + ts: TaskState = tasks[key] + dts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - self.handlers = { - "register-client": self.add_client, - "scatter": self.scatter, - "register-worker": self.add_worker, - "unregister": self.remove_worker, - "gather": self.gather, - "cancel": self.stimulus_cancel, - "retry": self.stimulus_retry, - "feed": self.feed, - "terminate": self.close, - "broadcast": self.broadcast, - "proxy": self.proxy, - "ncores": self.get_ncores, - "has_what": self.get_has_what, - "who_has": self.get_who_has, - "processing": self.get_processing, - "call_stack": self.get_call_stack, - "profile": self.get_profile, - "performance_report": self.performance_report, - "get_logs": self.get_logs, - "logs": self.get_logs, - "worker_logs": self.get_worker_logs, - "log_event": self.log_worker_event, - "events": self.get_events, - "nbytes": self.get_nbytes, - "versions": self.versions, - "add_keys": self.add_keys, - "rebalance": self.rebalance, - "replicate": self.replicate, - "start_ipython": self.start_ipython, - "run_function": self.run_function, - "update_data": self.update_data, - "set_resources": self.add_resources, - "retire_workers": self.retire_workers, - "get_metadata": self.get_metadata, - "set_metadata": self.set_metadata, - "heartbeat_worker": self.heartbeat_worker, - "get_task_status": self.get_task_status, - "get_task_stream": self.get_task_stream, - "register_worker_plugin": self.register_worker_plugin, - "adaptive_target": self.adaptive_target, - "workers_to_close": self.workers_to_close, - "subscribe_worker_status": self.subscribe_worker_status, - "start_task_metadata": self.start_task_metadata, - "stop_task_metadata": self.stop_task_metadata, - } - - self._transitions = { - ("released", "waiting"): self.transition_released_waiting, - ("waiting", "released"): self.transition_waiting_released, - ("waiting", "processing"): self.transition_waiting_processing, - ("waiting", "memory"): self.transition_waiting_memory, - ("processing", "released"): self.transition_processing_released, - ("processing", "memory"): self.transition_processing_memory, - ("processing", "erred"): self.transition_processing_erred, - ("no-worker", "released"): self.transition_no_worker_released, - ("no-worker", "waiting"): self.transition_no_worker_waiting, - ("released", "forgotten"): self.transition_released_forgotten, - ("memory", "forgotten"): self.transition_memory_forgotten, - ("erred", "forgotten"): self.transition_released_forgotten, - ("erred", "released"): self.transition_erred_released, - ("memory", "released"): self.transition_memory_released, - ("released", "erred"): self.transition_released_erred, - } - - connection_limit = get_fileno_limit() / 2 - - super().__init__( - 
handlers=self.handlers, - stream_handlers=merge(worker_handlers, client_handlers), - io_loop=self.loop, - connection_limit=connection_limit, - deserialize=False, - connection_args=self.connection_args, - **kwargs, - ) + if self.validate: + assert not ts._waiting_on + assert not ts._who_has + assert not ts._exception_blame + assert not ts._processing_on + assert not ts._has_lost_dependencies + assert ts not in self.unrunnable + assert all([dts._who_has for dts in ts._dependencies]) - if self.worker_ttl: - pc = PeriodicCallback(self.check_worker_ttl, self.worker_ttl) - self.periodic_callbacks["worker-ttl"] = pc + ws: WorkerState = self.decide_worker(ts) + if ws is None: + return {}, worker_msgs, client_msgs + worker = ws._address - if self.idle_timeout: - pc = PeriodicCallback(self.check_idle, self.idle_timeout / 4) - self.periodic_callbacks["idle-timeout"] = pc + duration_estimate = self.set_duration_estimate(ts, ws) + ts._processing_on = ws + ws._occupancy += duration_estimate + self.total_occupancy += duration_estimate + ts.state = "processing" + self.consume_resources(ts, ws) + self.check_idle_saturated(ws) + self.n_tasks += 1 - if extensions is None: - extensions = list(DEFAULT_EXTENSIONS) - if dask.config.get("distributed.scheduler.work-stealing"): - extensions.append(WorkStealing) - for ext in extensions: - ext(self) + if ts._actor: + ws._actors.add(ts) - setproctitle("dask-scheduler [not started]") - Scheduler._instances.add(self) - self.rpc.allow_offload = False - self.status = Status.undefined + # logger.debug("Send job to worker: %s, %s", worker, key) - ################## - # Administration # - ################## + worker_msgs[worker] = self._task_to_msg(ts) - def __repr__(self): - return '' % ( - self.address, - len(self.workers), - self.total_nthreads, - ) + return {}, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - def identity(self, comm=None): - """ Basic information about ourselves and our cluster """ - d = { - "type": type(self).__name__, - "id": str(self.id), - "address": self.address, - "services": {key: v.port for (key, v) in self.services.items()}, - "workers": { - worker.address: worker.identity() for worker in self.workers.values() - }, - } - return d + pdb.set_trace() + raise - def get_worker_service_addr(self, worker, service_name, protocol=False): - """ - Get the (host, port) address of the named service on the *worker*. - Returns None if the service doesn't exist. 
+ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): + try: + workers: dict = cast(dict, self.workers) + ws: WorkerState = workers[worker] + tasks: dict = self.tasks + ts: TaskState = tasks[key] + worker_msgs: dict = {} + client_msgs: dict = {} - Parameters - ---------- - worker : address - service_name : str - Common services include 'bokeh' and 'nanny' - protocol : boolean - Whether or not to include a full address with protocol (True) - or just a (host, port) pair - """ - ws: WorkerState = self.workers[worker] - port = ws._services.get(service_name) - if port is None: - return None - elif protocol: - return "%(protocol)s://%(host)s:%(port)d" % { - "protocol": ws._address.split("://")[0], - "host": ws.host, - "port": port, - } - else: - return ws.host, port + if self.validate: + assert not ts._processing_on + assert ts._waiting_on + assert ts._state == "waiting" - async def start(self): - """ Clear out old state and restart all running coroutines """ - await super().start() - assert self.status != Status.running + ts._waiting_on.clear() - enable_gc_diagnosis() + if nbytes is not None: + ts.set_nbytes(nbytes) - self.clear_task_state() + self.check_idle_saturated(ws) - with suppress(AttributeError): - for c in self._worker_coroutines: - c.cancel() + recommendations: dict = {} + client_msgs: dict = {} - for addr in self._start_address: - await self.listen( - addr, - allow_offload=False, - handshake_overrides={"pickle-protocol": 4, "compression": None}, - **self.security.get_listen_args("scheduler"), - ) - self.ip = get_address_host(self.listen_address) - listen_ip = self.ip + self._add_to_memory(ts, ws, recommendations, client_msgs, **kwargs) - if listen_ip == "0.0.0.0": - listen_ip = "" + if self.validate: + assert not ts._processing_on + assert not ts._waiting_on + assert ts._who_has - if self.address.startswith("inproc://"): - listen_ip = "localhost" + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - # Services listen on all addresses - self.start_services(listen_ip) + pdb.set_trace() + raise - for listener in self.listeners: - logger.info(" Scheduler at: %25s", listener.contact_address) - for k, v in self.services.items(): - logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port)) + def transition_processing_memory( + self, + key, + nbytes=None, + type=None, + typename=None, + worker=None, + startstops=None, + **kwargs, + ): + ws: WorkerState + wws: WorkerState + worker_msgs: dict = {} + client_msgs: dict = {} + try: + tasks: dict = self.tasks + ts: TaskState = tasks[key] + assert worker + assert isinstance(worker, str) - self.loop.add_callback(self.reevaluate_occupancy) + if self.validate: + assert ts._processing_on + ws = ts._processing_on + assert ts in ws._processing + assert not ts._waiting_on + assert not ts._who_has, (ts, ts._who_has) + assert not ts._exception_blame + assert ts._state == "processing" - if self.scheduler_file: - with open(self.scheduler_file, "w") as f: - json.dump(self.identity(), f, indent=2) + workers: dict = cast(dict, self.workers) + ws = workers.get(worker) + if ws is None: + return {key: "released"}, worker_msgs, client_msgs - fn = self.scheduler_file # remove file when we close the process - - def del_scheduler_file(): - if os.path.exists(fn): - os.remove(fn) + if ws != ts._processing_on: # someone else has this task + logger.info( + "Unexpected worker completed task, likely due to" + " work stealing. 
Expected: %s, Got: %s, Key: %s", + ts._processing_on, + ws, + key, + ) + return {}, worker_msgs, client_msgs - weakref.finalize(self, del_scheduler_file) + if startstops: + L = list() + for startstop in startstops: + stop = startstop["stop"] + start = startstop["start"] + action = startstop["action"] + if action == "compute": + L.append((start, stop)) - for preload in self.preloads: - await preload.start() + # record timings of all actions -- a cheaper way of + # getting timing info compared with get_task_stream() + ts._prefix._all_durations[action] += stop - start - await asyncio.gather(*[plugin.start(self) for plugin in self.plugins]) + if len(L) > 0: + compute_start, compute_stop = L[0] + else: # This is very rare + compute_start = compute_stop = None + else: + compute_start = compute_stop = None - self.start_periodic_callbacks() + ############################# + # Update Timing Information # + ############################# + if compute_start and ws._processing.get(ts, True): + # Update average task duration for worker + old_duration = ts._prefix._duration_average + new_duration = compute_stop - compute_start + if old_duration < 0: + avg_duration = new_duration + else: + avg_duration = 0.5 * old_duration + 0.5 * new_duration - setproctitle("dask-scheduler [%s]" % (self.address,)) - return self + ts._prefix._duration_average = avg_duration + ts._group._duration += new_duration - async def close(self, comm=None, fast=False, close_workers=False): - """Send cleanup signal to all coroutines then wait until finished + tts: TaskState + for tts in self.unknown_durations.pop(ts._prefix._name, ()): + if tts._processing_on: + wws = tts._processing_on + old = wws._processing[tts] + comm = self.get_comm_cost(tts, wws) + wws._processing[tts] = avg_duration + comm + wws._occupancy += avg_duration + comm - old + self.total_occupancy += avg_duration + comm - old - See Also - -------- - Scheduler.cleanup - """ - if self.status in (Status.closing, Status.closed, Status.closing_gracefully): - await self.finished() - return - self.status = Status.closing + ############################ + # Update State Information # + ############################ + if nbytes is not None: + ts.set_nbytes(nbytes) - logger.info("Scheduler closing...") - setproctitle("dask-scheduler [closing]") + recommendations: dict = {} + client_msgs: dict = {} - for preload in self.preloads: - await preload.teardown() + self._remove_from_processing(ts) - if close_workers: - await self.broadcast(msg={"op": "close_gracefully"}, nanny=True) - for worker in self.workers: - self.worker_send(worker, {"op": "close"}) - for i in range(20): # wait a second for send signals to clear - if self.workers: - await asyncio.sleep(0.05) - else: - break + self._add_to_memory( + ts, ws, recommendations, client_msgs, type=type, typename=typename + ) - await asyncio.gather(*[plugin.close() for plugin in self.plugins]) + if self.validate: + assert not ts._processing_on + assert not ts._waiting_on - for pc in self.periodic_callbacks.values(): - pc.stop() - self.periodic_callbacks.clear() + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - self.stop_services() + pdb.set_trace() + raise - for ext in self.extensions.values(): - with suppress(AttributeError): - ext.teardown() - logger.info("Scheduler closing all comms") + def transition_memory_released(self, key, safe=False): + ws: WorkerState + try: + tasks: dict = self.tasks + ts: TaskState = tasks[key] + dts: TaskState + worker_msgs: dict = 
{} + client_msgs: dict = {} - futures = [] - for w, comm in list(self.stream_comms.items()): - if not comm.closed(): - comm.send({"op": "close", "report": False}) - comm.send({"op": "close-stream"}) - with suppress(AttributeError): - futures.append(comm.close()) + if self.validate: + assert not ts._waiting_on + assert not ts._processing_on + if safe: + assert not ts._waiters - for future in futures: # TODO: do all at once - await future + if ts._actor: + for ws in ts._who_has: + ws._actors.discard(ts) + if ts._who_wants: + ts._exception_blame = ts + ts._exception = "Worker holding Actor was lost" + return ( + {ts._key: "erred"}, + worker_msgs, + client_msgs, + ) # don't try to recreate - for comm in self.client_comms.values(): - comm.abort() + recommendations: dict = {} - await self.rpc.close() + for dts in ts._waiters: + if dts._state in ("no-worker", "processing"): + recommendations[dts._key] = "waiting" + elif dts._state == "waiting": + dts._waiting_on.add(ts) - self.status = Status.closed - self.stop() - await super().close() + # XXX factor this out? + for ws in ts._who_has: + ws._has_what.remove(ts) + ws._nbytes -= ts.get_nbytes() + ts._group._nbytes_in_memory -= ts.get_nbytes() + worker_msgs[ws._address] = { + "op": "delete-data", + "keys": [key], + "report": False, + } - setproctitle("dask-scheduler [closed]") - disable_gc_diagnosis() + ts._who_has.clear() - async def close_worker(self, comm=None, worker=None, safe=None): - """Remove a worker from the cluster + ts.state = "released" - This both removes the worker from our local state and also sends a - signal to the worker to shut down. This works regardless of whether or - not the worker has a nanny process restarting it - """ - logger.info("Closing worker %s", worker) - with log_errors(): - self.log_event(worker, {"action": "close-worker"}) - ws: WorkerState = self.workers[worker] - nanny_addr = ws._nanny - address = nanny_addr or worker + report_msg = {"op": "lost-data", "key": key} + cs: ClientState + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg - self.worker_send(worker, {"op": "close", "report": False}) - await self.remove_worker(address=worker, safe=safe) + if not ts._run_spec: # pure data + recommendations[key] = "forgotten" + elif ts._has_lost_dependencies: + recommendations[key] = "forgotten" + elif ts._who_wants or ts._waiters: + recommendations[key] = "waiting" - ########### - # Stimuli # - ########### + if self.validate: + assert not ts._waiting_on - def heartbeat_worker( - self, - comm=None, - address=None, - resolve_address=True, - now=None, - resources=None, - host_info=None, - metrics=None, - executing=None, - ): - address = self.coerce_address(address, resolve_address) - address = normalize_address(address) - if address not in self.workers: - return {"status": "missing"} + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - host = get_address_host(address) - local_now = time() - now = now or time() - assert metrics - host_info = host_info or {} + pdb.set_trace() + raise - self.host_info[host]["last-seen"] = local_now - frac = 1 / len(self.workers) - self.bandwidth = ( - self.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac - ) - for other, (bw, count) in metrics["bandwidth"]["workers"].items(): - if (address, other) not in self.bandwidth_workers: - self.bandwidth_workers[address, other] = bw / count - else: - alpha = (1 - frac) ** count - self.bandwidth_workers[address, other] = self.bandwidth_workers[ - 
address, other - ] * alpha + bw * (1 - alpha) - for typ, (bw, count) in metrics["bandwidth"]["types"].items(): - if typ not in self.bandwidth_types: - self.bandwidth_types[typ] = bw / count - else: - alpha = (1 - frac) ** count - self.bandwidth_types[typ] = self.bandwidth_types[typ] * alpha + bw * ( - 1 - alpha - ) + def transition_released_erred(self, key): + try: + tasks: dict = self.tasks + ts: TaskState = tasks[key] + dts: TaskState + failing_ts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - ws: WorkerState = self.workers[address] + if self.validate: + with log_errors(pdb=LOG_PDB): + assert ts._exception_blame + assert not ts._who_has + assert not ts._waiting_on + assert not ts._waiters - ws._last_seen = time() + recommendations: dict = {} - if executing is not None: - ws._executing = { - self.tasks[key]: duration for key, duration in executing.items() + failing_ts = ts._exception_blame + + for dts in ts._dependents: + dts._exception_blame = failing_ts + if not dts._who_has: + recommendations[dts._key] = "erred" + + report_msg = { + "op": "task-erred", + "key": key, + "exception": failing_ts._exception, + "traceback": failing_ts._traceback, } + cs: ClientState + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg - if metrics: - ws._metrics = metrics + ts.state = "erred" - if host_info: - self.host_info[host].update(host_info) + # TODO: waiting data? + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - delay = time() - now - ws._time_delay = delay + pdb.set_trace() + raise - if resources: - self.add_resources(worker=address, resources=resources) + def transition_erred_released(self, key): + try: + tasks: dict = self.tasks + ts: TaskState = tasks[key] + dts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - self.log_event(address, merge({"action": "heartbeat"}, metrics)) + if self.validate: + with log_errors(pdb=LOG_PDB): + assert all([dts._state != "erred" for dts in ts._dependencies]) + assert ts._exception_blame + assert not ts._who_has + assert not ts._waiting_on + assert not ts._waiters - return { - "status": "OK", - "time": time(), - "heartbeat-interval": heartbeat_interval(len(self.workers)), - } + recommendations: dict = {} - async def add_worker( - self, - comm=None, - address=None, - keys=(), - nthreads=None, - name=None, - resolve_address=True, - nbytes=None, - types=None, - now=None, - resources=None, - host_info=None, - memory_limit=None, - metrics=None, - pid=0, - services=None, - local_directory=None, - versions=None, - nanny=None, - extra=None, - ): - """ Add a new worker to the cluster """ - with log_errors(): - address = self.coerce_address(address, resolve_address) - address = normalize_address(address) - host = get_address_host(address) + ts._exception = None + ts._exception_blame = None + ts._traceback = None - ws: WorkerState = self.workers.get(address) - if ws is not None: - raise ValueError("Worker already exists %s" % ws) + for dts in ts._dependents: + if dts._state == "erred": + recommendations[dts._key] = "waiting" - if name in self.aliases: - logger.warning( - "Worker tried to connect with a duplicate name: %s", name - ) - msg = { - "status": "error", - "message": "name taken, %s" % name, - "time": time(), - } - if comm: - await comm.write(msg) - return + report_msg = {"op": "task-retried", "key": key} + cs: ClientState + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg - self.workers[address] = ws = WorkerState( - 
address=address, - pid=pid, - nthreads=nthreads, - memory_limit=memory_limit or 0, - name=name, - local_directory=local_directory, - services=services, - versions=versions, - nanny=nanny, - extra=extra, - ) + ts.state = "released" - if "addresses" not in self.host_info[host]: - self.host_info[host].update({"addresses": set(), "nthreads": 0}) + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - self.host_info[host]["addresses"].add(address) - self.host_info[host]["nthreads"] += nthreads + pdb.set_trace() + raise - self.total_nthreads += nthreads - self.aliases[name] = address + def transition_waiting_released(self, key): + try: + tasks: dict = self.tasks + ts: TaskState = tasks[key] + worker_msgs: dict = {} + client_msgs: dict = {} - response = self.heartbeat_worker( - address=address, - resolve_address=resolve_address, - now=now, - resources=resources, - host_info=host_info, - metrics=metrics, - ) + if self.validate: + assert not ts._who_has + assert not ts._processing_on - # Do not need to adjust self.total_occupancy as self.occupancy[ws] cannot exist before this. - self.check_idle_saturated(ws) + recommendations: dict = {} - # for key in keys: # TODO - # self.mark_key_in_memory(key, [address]) + dts: TaskState + for dts in ts._dependencies: + s = dts._waiters + if ts in s: + s.discard(ts) + if not s and not dts._who_wants: + recommendations[dts._key] = "released" + ts._waiting_on.clear() - self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop) + ts.state = "released" - if ws._nthreads > len(ws._processing): - self.idle[ws._address] = ws + if ts._has_lost_dependencies: + recommendations[key] = "forgotten" + elif not ts._exception_blame and (ts._who_wants or ts._waiters): + recommendations[key] = "waiting" + else: + ts._waiters.clear() - for plugin in self.plugins[:]: - try: - result = plugin.add_worker(scheduler=self, worker=address) - if inspect.isawaitable(result): - await result - except Exception as e: - logger.exception(e) + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - recommendations: dict - if nbytes: - for key in nbytes: - tasks: dict = self.tasks - ts: TaskState = tasks.get(key) - if ts is not None and ts._state in ("processing", "waiting"): - recommendations = self.transition( - key, - "memory", - worker=address, - nbytes=nbytes[key], - typename=types[key], - ) - self.transitions(recommendations) + pdb.set_trace() + raise - recommendations = {} - for ts in list(self.unrunnable): - valid: set = self.valid_workers(ts) - if valid is None or ws in valid: - recommendations[ts._key] = "waiting" + def transition_processing_released(self, key): + try: + tasks: dict = self.tasks + ts: TaskState = tasks[key] + dts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - if recommendations: - self.transitions(recommendations) + if self.validate: + assert ts._processing_on + assert not ts._who_has + assert not ts._waiting_on + assert self.tasks[key].state == "processing" - self.log_event(address, {"action": "add-worker"}) - self.log_event("all", {"action": "add-worker", "worker": address}) - logger.info("Register worker %s", ws) + w: str = self._remove_from_processing(ts) + if w: + worker_msgs[w] = {"op": "release-task", "key": key} - msg = { - "status": "OK", - "time": time(), - "heartbeat-interval": heartbeat_interval(len(self.workers)), - "worker-plugins": self.worker_plugins, - } + ts.state = "released" - cs: 
ClientState - version_warning = version_module.error_message( - version_module.get_versions(), - merge( - {w: ws._versions for w, ws in self.workers.items()}, - {c: cs._versions for c, cs in self.clients.items() if cs._versions}, - ), - versions, - client_name="This Worker", - ) - msg.update(version_warning) + recommendations: dict = {} - if comm: - await comm.write(msg) - await self.handle_worker(comm=comm, worker=address) - - def update_graph_hlg( - self, - client=None, - hlg=None, - keys=None, - dependencies=None, - restrictions=None, - priority=None, - loose_restrictions=None, - resources=None, - submitting_task=None, - retries=None, - user_priority=0, - actors=None, - fifo_timeout=0, - ): + if ts._has_lost_dependencies: + recommendations[key] = "forgotten" + elif ts._waiters or ts._who_wants: + recommendations[key] = "waiting" - dsk, dependencies, annotations = highlevelgraph_unpack(hlg) + if recommendations.get(key) != "waiting": + for dts in ts._dependencies: + if dts._state != "released": + s = dts._waiters + s.discard(ts) + if not s and not dts._who_wants: + recommendations[dts._key] = "released" + ts._waiters.clear() - # Remove any self-dependencies (happens on test_publish_bag() and others) - for k, v in dependencies.items(): - deps = set(v) - if k in deps: - deps.remove(k) - dependencies[k] = deps + if self.validate: + assert not ts._processing_on - if priority is None: - # Removing all non-local keys before calling order() - dsk_keys = set(dsk) # intersection() of sets is much faster than dict_keys - stripped_deps = { - k: v.intersection(dsk_keys) - for k, v in dependencies.items() - if k in dsk_keys - } - priority = dask.order.order(dsk, dependencies=stripped_deps) + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - return self.update_graph( - client, - dsk, - keys, - dependencies, - restrictions, - priority, - loose_restrictions, - resources, - submitting_task, - retries, - user_priority, - actors, - fifo_timeout, - annotations, - ) + pdb.set_trace() + raise - def update_graph( - self, - client=None, - tasks=None, - keys=None, - dependencies=None, - restrictions=None, - priority=None, - loose_restrictions=None, - resources=None, - submitting_task=None, - retries=None, - user_priority=0, - actors=None, - fifo_timeout=0, - annotations=None, + def transition_processing_erred( + self, key, cause=None, exception=None, traceback=None, **kwargs ): - """ - Add new computations to the internal dask graph - - This happens whenever the Client calls submit, map, get, or compute. 
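        For instance, a client computing ``x = inc(1)`` and ``y = add(x, 10)``
        arrives here (conceptually; real keys and task specs are tokenized and
        serialized) as two tasks with ``keys == {"y"}`` and
        ``dependencies == {"y": {"x"}}``, after which both tasks are
        recommended towards the ``"waiting"`` state.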
- """ - start = time() - fifo_timeout = parse_timedelta(fifo_timeout) - keys = set(keys) - if len(tasks) > 1: - self.log_event( - ["all", client], {"action": "update_graph", "count": len(tasks)} - ) + ws: WorkerState + try: + tasks: dict = self.tasks + ts: TaskState = tasks[key] + dts: TaskState + failing_ts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - # Remove aliases - for k in list(tasks): - if tasks[k] is k: - del tasks[k] + if self.validate: + assert cause or ts._exception_blame + assert ts._processing_on + assert not ts._who_has + assert not ts._waiting_on - dependencies = dependencies or {} + if ts._actor: + ws = ts._processing_on + ws._actors.remove(ts) - n = 0 - while len(tasks) != n: # walk through new tasks, cancel any bad deps - n = len(tasks) - for k, deps in list(dependencies.items()): - if any( - dep not in self.tasks and dep not in tasks for dep in deps - ): # bad key - logger.info("User asked for computation on lost data, %s", k) - del tasks[k] - del dependencies[k] - if k in keys: - keys.remove(k) - self.report({"op": "cancelled-key", "key": k}, client=client) - self.client_releases_keys(keys=[k], client=client) + self._remove_from_processing(ts) - # Avoid computation that is already finished - ts: TaskState - already_in_memory = set() # tasks that are already done - for k, v in dependencies.items(): - if v and k in self.tasks: - ts = self.tasks[k] - if ts._state in ("memory", "erred"): - already_in_memory.add(k) + if exception is not None: + ts._exception = exception + if traceback is not None: + ts._traceback = traceback + if cause is not None: + failing_ts = self.tasks[cause] + ts._exception_blame = failing_ts + else: + failing_ts = ts._exception_blame - dts: TaskState - if already_in_memory: - dependents = dask.core.reverse_dict(dependencies) - stack = list(already_in_memory) - done = set(already_in_memory) - while stack: # remove unnecessary dependencies - key = stack.pop() - ts = self.tasks[key] - try: - deps = dependencies[key] - except KeyError: - deps = self.dependencies[key] - for dep in deps: - if dep in dependents: - child_deps = dependents[dep] - else: - child_deps = self.dependencies[dep] - if all(d in done for d in child_deps): - if dep in self.tasks and dep not in done: - done.add(dep) - stack.append(dep) + recommendations: dict = {} - for d in done: - tasks.pop(d, None) - dependencies.pop(d, None) + for dts in ts._dependents: + dts._exception_blame = failing_ts + recommendations[dts._key] = "erred" - # Get or create task states - stack = list(keys) - touched_keys = set() - touched_tasks = [] - while stack: - k = stack.pop() - if k in touched_keys: - continue - # XXX Have a method get_task_state(self, k) ? - ts = self.tasks.get(k) - if ts is None: - ts = self.new_task(k, tasks.get(k), "released") - elif not ts._run_spec: - ts._run_spec = tasks.get(k) + for dts in ts._dependencies: + s = dts._waiters + s.discard(ts) + if not s and not dts._who_wants: + recommendations[dts._key] = "released" - touched_keys.add(k) - touched_tasks.append(ts) - stack.extend(dependencies.get(k, ())) + ts._waiters.clear() # do anything with this? 
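            # A sketch of the result shape at this point (illustrative values):
            # recommendations is typically {dependent_key: "erred", ...} plus any
            # newly unneeded dependencies mapped to "released", and each client
            # in ts._who_wants then receives one "task-erred" message carrying
            # the failing task's exception and traceback.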
- self.client_desires_keys(keys=keys, client=client) + ts.state = "erred" - # Add dependencies - for key, deps in dependencies.items(): - ts = self.tasks.get(key) - if ts is None or ts._dependencies: - continue - for dep in deps: - dts = self.tasks[dep] - ts.add_dependency(dts) + report_msg = { + "op": "task-erred", + "key": key, + "exception": failing_ts._exception, + "traceback": failing_ts._traceback, + } + cs: ClientState + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg - # Compute priorities - if isinstance(user_priority, Number): - user_priority = {k: user_priority for k in tasks} + cs = self.clients["fire-and-forget"] + if ts in cs._wants_what: + self._client_releases_keys( + cs=cs, + keys=[key], + recommendations=recommendations, + ) - annotations = annotations or {} - restrictions = restrictions or {} - loose_restrictions = loose_restrictions or [] - resources = resources or {} - retries = retries or {} + if self.validate: + assert not ts._processing_on - # Override existing taxonomy with per task annotations - if annotations: - if "priority" in annotations: - user_priority.update(annotations["priority"]) + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - if "workers" in annotations: - restrictions.update(annotations["workers"]) + pdb.set_trace() + raise - if "allow_other_workers" in annotations: - loose_restrictions.extend( - k for k, v in annotations["allow_other_workers"].items() if v - ) - - if "retries" in annotations: - retries.update(annotations["retries"]) + def transition_no_worker_released(self, key): + try: + tasks: dict = self.tasks + ts: TaskState = tasks[key] + dts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - if "resources" in annotations: - resources.update(annotations["resources"]) + if self.validate: + assert self.tasks[key].state == "no-worker" + assert not ts._who_has + assert not ts._waiting_on - for a, kv in annotations.items(): - for k, v in kv.items(): - ts = self.tasks[k] - ts._annotations[a] = v + self.unrunnable.remove(ts) + ts.state = "released" - # Add actors - if actors is True: - actors = list(keys) - for actor in actors or []: - ts = self.tasks[actor] - ts._actor = True + for dts in ts._dependencies: + dts._waiters.discard(ts) - priority = priority or dask.order.order( - tasks - ) # TODO: define order wrt old graph + ts._waiters.clear() - if submitting_task: # sub-tasks get better priority than parent tasks - ts = self.tasks.get(submitting_task) - if ts is not None: - generation = ts._priority[0] - 0.01 - else: # super-task already cleaned up - generation = self.generation - elif self._last_time + fifo_timeout < start: - self.generation += 1 # older graph generations take precedence - generation = self.generation - self._last_time = start - else: - generation = self.generation + return {}, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - for key in set(priority) & touched_keys: - ts = self.tasks[key] - if ts._priority is None: - ts._priority = (-(user_priority.get(key, 0)), generation, priority[key]) + pdb.set_trace() + raise - # Ensure all runnables have a priority - runnables = [ts for ts in touched_tasks if ts._run_spec] - for ts in runnables: - if ts._priority is None and ts._run_spec: - ts._priority = (self.generation, 0) + def _propagate_forgotten(self, ts: TaskState, recommendations: dict): + worker_msgs: dict = {} + workers: dict = cast(dict, self.workers) + ts.state = 
"forgotten" + key: str = ts._key + dts: TaskState + for dts in ts._dependents: + dts._has_lost_dependencies = True + dts._dependencies.remove(ts) + dts._waiting_on.discard(ts) + if dts._state not in ("memory", "erred"): + # Cannot compute task anymore + recommendations[dts._key] = "forgotten" + ts._dependents.clear() + ts._waiters.clear() - if restrictions: - # *restrictions* is a dict keying task ids to lists of - # restriction specifications (either worker names or addresses) - for k, v in restrictions.items(): - if v is None: - continue - ts = self.tasks.get(k) - if ts is None: - continue - ts._host_restrictions = set() - ts._worker_restrictions = set() - for w in v: - try: - w = self.coerce_address(w) - except ValueError: - # Not a valid address, but perhaps it's a hostname - ts._host_restrictions.add(w) - else: - ts._worker_restrictions.add(w) + for dts in ts._dependencies: + dts._dependents.remove(ts) + s: set = dts._waiters + s.discard(ts) + if not dts._dependents and not dts._who_wants: + # Task not needed anymore + assert dts is not ts + recommendations[dts._key] = "forgotten" + ts._dependencies.clear() + ts._waiting_on.clear() - if loose_restrictions: - for k in loose_restrictions: - ts = self.tasks[k] - ts._loose_restrictions = True + if ts._who_has: + ts._group._nbytes_in_memory -= ts.get_nbytes() - if resources: - for k, v in resources.items(): - if v is None: - continue - assert isinstance(v, dict) - ts = self.tasks.get(k) - if ts is None: - continue - ts._resource_restrictions = v + ws: WorkerState + for ws in ts._who_has: + ws._has_what.remove(ts) + ws._nbytes -= ts.get_nbytes() + w: str = ws._address + if w in workers: # in case worker has died + worker_msgs[w] = {"op": "delete-data", "keys": [key], "report": False} + ts._who_has.clear() - if retries: - for k, v in retries.items(): - assert isinstance(v, int) - ts = self.tasks.get(k) - if ts is None: - continue - ts._retries = v + return worker_msgs - # Compute recommendations - recommendations: dict = {} + def transition_memory_forgotten(self, key): + tasks: dict + ws: WorkerState + try: + tasks = self.tasks + ts: TaskState = tasks[key] + worker_msgs: dict = {} + client_msgs: dict = {} - for ts in sorted(runnables, key=operator.attrgetter("priority"), reverse=True): - if ts._state == "released" and ts._run_spec: - recommendations[ts._key] = "waiting" + if self.validate: + assert ts._state == "memory" + assert not ts._processing_on + assert not ts._waiting_on + if not ts._run_spec: + # It's ok to forget a pure data task + pass + elif ts._has_lost_dependencies: + # It's ok to forget a task with forgotten dependencies + pass + elif not ts._who_wants and not ts._waiters and not ts._dependents: + # It's ok to forget a task that nobody needs + pass + else: + assert 0, (ts,) - for ts in touched_tasks: - for dts in ts._dependencies: - if dts._exception_blame: - ts._exception_blame = dts._exception_blame - recommendations[ts._key] = "erred" - break + recommendations: dict = {} - for plugin in self.plugins[:]: - try: - plugin.update_graph( - self, - client=client, - tasks=tasks, - keys=keys, - restrictions=restrictions or {}, - dependencies=dependencies, - priority=priority, - loose_restrictions=loose_restrictions, - resources=resources, - annotations=annotations, - ) - except Exception as e: - logger.exception(e) + if ts._actor: + for ws in ts._who_has: + ws._actors.discard(ts) - self.transitions(recommendations) + worker_msgs = self._propagate_forgotten(ts, recommendations) - for ts in touched_tasks: - if ts._state in ("memory", 
"erred"): - self.report_on_key(ts=ts, client=client) + client_msgs = self._task_to_client_msgs(ts) + self.remove_key(key) - end = time() - if self.digests is not None: - self.digests["update-graph-duration"].add(end - start) + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - # TODO: balance workers + pdb.set_trace() + raise - def new_task(self, key, spec, state): - """ Create a new task, and associated states """ - ts: TaskState = TaskState(key, spec) - tp: TaskPrefix - tg: TaskGroup - ts._state = state - prefix_key = key_split(key) + def transition_released_forgotten(self, key): try: - tp = self.task_prefixes[prefix_key] - except KeyError: - self.task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key) - ts._prefix = tp + tasks: dict = self.tasks + ts: TaskState = tasks[key] + worker_msgs: dict = {} + client_msgs: dict = {} - group_key = ts._group_key - try: - tg = self.task_groups[group_key] - except KeyError: - self.task_groups[group_key] = tg = TaskGroup(group_key) - tg._prefix = tp - tp._groups.append(tg) - tg.add(ts) - self.tasks[key] = ts - return ts + if self.validate: + assert ts._state in ("released", "erred") + assert not ts._who_has + assert not ts._processing_on + assert not ts._waiting_on, (ts, ts._waiting_on) + if not ts._run_spec: + # It's ok to forget a pure data task + pass + elif ts._has_lost_dependencies: + # It's ok to forget a task with forgotten dependencies + pass + elif not ts._who_wants and not ts._waiters and not ts._dependents: + # It's ok to forget a task that nobody needs + pass + else: + assert 0, (ts,) - def stimulus_task_finished(self, key=None, worker=None, **kwargs): - """ Mark that a task has finished execution on a particular worker """ - logger.debug("Stimulus task finished %s, %s", key, worker) - - tasks: dict = self.tasks - ts: TaskState = tasks.get(key) - if ts is None: - return {} - workers: dict = cast(dict, self.workers) - ws: WorkerState = workers[worker] - ts._metadata.update(kwargs["metadata"]) - - recommendations: dict - if ts._state == "processing": - recommendations = self.transition(key, "memory", worker=worker, **kwargs) - - if ts._state == "memory": - assert ws in ts._who_has - else: - logger.debug( - "Received already computed task, worker: %s, state: %s" - ", key: %s, who_has: %s", - worker, - ts._state, - key, - ts._who_has, - ) - if ws not in ts._who_has: - self.worker_send(worker, {"op": "release-task", "key": key}) - recommendations = {} + recommendations: dict = {} + self._propagate_forgotten(ts, recommendations) - return recommendations + client_msgs = self._task_to_client_msgs(ts) + self.remove_key(key) - def stimulus_task_erred( - self, key=None, worker=None, exception=None, traceback=None, **kwargs - ): - """ Mark that a task has erred on a particular worker """ - logger.debug("Stimulus task erred %s, %s", key, worker) + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - ts: TaskState = self.tasks.get(key) - if ts is None: - return {} + pdb.set_trace() + raise - recommendations: dict - if ts._state == "processing": - retries = ts._retries - if retries > 0: - ts._retries = retries - 1 - recommendations = self.transition(key, "waiting") - else: - recommendations = self.transition( - key, - "erred", - cause=key, - exception=exception, - traceback=traceback, - worker=worker, - **kwargs, - ) - else: - recommendations = {} + def check_idle_saturated(self, ws: WorkerState, occ: double = 
-1.0): + """Update the status of the idle and saturated state - return recommendations + The scheduler keeps track of workers that are .. - def stimulus_missing_data( - self, cause=None, key=None, worker=None, ensure=True, **kwargs - ): - """ Mark that certain keys have gone missing. Recover. """ - with log_errors(): - logger.debug("Stimulus missing data %s, %s", key, worker) + - Saturated: have enough work to stay busy + - Idle: do not have enough work to stay busy - ts: TaskState = self.tasks.get(key) - if ts is None or ts._state == "memory": - return {} - cts: TaskState = self.tasks.get(cause) + They are considered saturated if they both have enough tasks to occupy + all of their threads, and if the expected runtime of those tasks is + large enough. - recommendations: dict = {} + This is useful for load balancing and adaptivity. + """ + total_nthreads: Py_ssize_t = self.total_nthreads + if total_nthreads == 0 or ws.status == Status.closed: + return + if occ < 0: + occ = ws._occupancy - if cts is not None and cts._state == "memory": # couldn't find this - ws: WorkerState - for ws in cts._who_has: # TODO: this behavior is extreme - ws._has_what.remove(cts) - ws._nbytes -= cts.get_nbytes() - cts._who_has.clear() - recommendations[cause] = "released" + nc: Py_ssize_t = ws._nthreads + p: Py_ssize_t = len(ws._processing) + total_occupancy: double = self.total_occupancy + avg: double = total_occupancy / total_nthreads - if key: - recommendations[key] = "released" + idle = self.idle + saturated: set = self.saturated + if p < nc or occ < nc * avg / 2: + idle[ws._address] = ws + saturated.discard(ws) + else: + idle.pop(ws._address, None) - self.transitions(recommendations) + if p > nc: + pending: double = occ * (p - nc) / (p * nc) + if 0.4 < pending > 1.9 * avg: + saturated.add(ws) + return - if self.validate: - assert cause not in self.who_has + saturated.discard(ws) - return {} + def _client_releases_keys(self, keys: list, cs: ClientState, recommendations: dict): + """ Remove keys from client desired list """ + logger.debug("Client %s releases keys: %s", cs._client_key, keys) + ts: TaskState + tasks2: set = set() + for key in keys: + ts = self.tasks.get(key) + if ts is not None and ts in cs._wants_what: + cs._wants_what.remove(ts) + s: set = ts._who_wants + s.remove(cs) + if not s: + tasks2.add(ts) - def stimulus_retry(self, comm=None, keys=None, client=None): - logger.info("Client %s requests to retry %d keys", client, len(keys)) - if client: - self.log_event(client, {"action": "retry", "count": len(keys)}) + for ts in tasks2: + if not ts._dependents: + # No live dependents, can forget + recommendations[ts._key] = "forgotten" + elif ts._state != "erred" and not ts._waiters: + recommendations[ts._key] = "released" - stack = list(keys) - seen = set() - roots = [] - ts: TaskState + def _task_to_msg(self, ts: TaskState, duration=None) -> dict: + """ Convert a single computational task to a message """ + ws: WorkerState dts: TaskState - while stack: - key = stack.pop() - seen.add(key) - ts = self.tasks[key] - erred_deps = [dts._key for dts in ts._dependencies if dts._state == "erred"] - if erred_deps: - stack.extend(erred_deps) - else: - roots.append(key) - recommendations: dict = {key: "waiting" for key in roots} - self.transitions(recommendations) + if duration is None: + duration = self.get_task_duration(ts) - if self.validate: - for key in seen: - assert not self.tasks[key].exception_blame + msg: dict = { + "op": "compute-task", + "key": ts._key, + "priority": ts._priority, + "duration": 
duration, + } + if ts._resource_restrictions: + msg["resource_restrictions"] = ts._resource_restrictions + if ts._actor: + msg["actor"] = True - return tuple(seen) + deps: set = ts._dependencies + if deps: + msg["who_has"] = { + dts._key: [ws._address for ws in dts._who_has] for dts in deps + } + msg["nbytes"] = {dts._key: dts._nbytes for dts in deps} - async def remove_worker(self, comm=None, address=None, safe=False, close=True): - """ - Remove worker from cluster + if self.validate: + assert all(msg["who_has"].values()) - We do this when a worker reports that it plans to leave or when it - appears to be unresponsive. This may send its tasks back to a released - state. - """ - with log_errors(): - if self.status == Status.closed: - return + task = ts._run_spec + if type(task) is dict: + msg.update(task) + else: + msg["task"] = task - address = self.coerce_address(address) + return msg - if address not in self.workers: - return "already-removed" + def _task_to_report_msg(self, ts: TaskState) -> dict: + if ts is None: + return {"op": "cancelled-key", "key": ts._key} + elif ts._state == "forgotten": + return {"op": "cancelled-key", "key": ts._key} + elif ts._state == "memory": + return {"op": "key-in-memory", "key": ts._key} + elif ts._state == "erred": + failing_ts: TaskState = ts._exception_blame + return { + "op": "task-erred", + "key": ts._key, + "exception": failing_ts._exception, + "traceback": failing_ts._traceback, + } + else: + return None - host = get_address_host(address) + def _task_to_client_msgs(self, ts: TaskState) -> dict: + cs: ClientState + clients: dict = self.clients + client_keys: list + if ts is None: + # Notify all clients + client_keys = list(clients) + else: + # Notify clients interested in key + client_keys = [cs._client_key for cs in ts._who_wants] - ws: WorkerState = self.workers[address] + report_msg: dict = self._task_to_report_msg(ts) - self.log_event( - ["all", address], - { - "action": "remove-worker", - "worker": address, - "processing-tasks": dict(ws._processing), - }, - ) - logger.info("Remove worker %s", ws) - if close: - with suppress(AttributeError, CommClosedError): - self.stream_comms[address].send({"op": "close", "report": False}) + client_msgs: dict = {} + for k in client_keys: + client_msgs[k] = report_msg - self.remove_resources(address) + return client_msgs - self.host_info[host]["nthreads"] -= ws._nthreads - self.host_info[host]["addresses"].remove(address) - self.total_nthreads -= ws._nthreads - - if not self.host_info[host]["addresses"]: - del self.host_info[host] + def _reevaluate_occupancy_worker(self, ws: WorkerState): + """ See reevaluate_occupancy """ + old: double = ws._occupancy + new: double = 0 + diff: double + ts: TaskState + est: double + for ts in ws._processing: + est = self.set_duration_estimate(ts, ws) + new += est - self.rpc.remove(address) - del self.stream_comms[address] - del self.aliases[ws._name] - self.idle.pop(ws._address, None) - self.saturated.discard(ws) - del self.workers[address] - ws.status = Status.closed - self.total_occupancy -= ws._occupancy + ws._occupancy = new + diff = new - old + self.total_occupancy += diff + self.check_idle_saturated(ws) - recommendations: dict = {} + # significant increase in duration + if new > old * 1.3: + steal = self.extensions.get("stealing") + if steal is not None: + for ts in ws._processing: + steal.remove_key_from_stealable(ts) + steal.put_key_in_stealable(ts) - ts: TaskState - for ts in list(ws._processing): - k = ts._key - recommendations[k] = "released" - if not safe: - 
ts._suspicious += 1 - ts._prefix._suspicious += 1 - if ts._suspicious > self.allowed_failures: - del recommendations[k] - e = pickle.dumps( - KilledWorker(task=k, last_worker=ws.clean()), protocol=4 - ) - r = self.transition(k, "erred", exception=e, cause=k) - recommendations.update(r) - logger.info( - "Task %s marked as failed because %d workers died" - " while trying to run it", - ts._key, - self.allowed_failures, - ) + def get_comm_cost(self, ts: TaskState, ws: WorkerState): + """ + Get the estimated communication cost (in s.) to compute the task + on the given worker. + """ + dts: TaskState + deps: set = ts._dependencies - ws._has_what + nbytes: Py_ssize_t = 0 + bandwidth: double = self.bandwidth + for dts in deps: + nbytes += dts._nbytes + return nbytes / bandwidth - for ts in ws._has_what: - ts._who_has.remove(ws) - if not ts._who_has: - if ts._run_spec: - recommendations[ts._key] = "released" - else: # pure data - recommendations[ts._key] = "forgotten" - ws._has_what.clear() + def get_task_duration(self, ts: TaskState, default: double = -1): + """ + Get the estimated computation cost of the given task + (not including any communication cost). + """ + duration: double = ts._prefix._duration_average + if duration < 0: + s: set = self.unknown_durations[ts._prefix._name] + s.add(ts) + if default < 0: + duration = UNKNOWN_TASK_DURATION + else: + duration = default - self.transitions(recommendations) + return duration - for plugin in self.plugins[:]: - try: - result = plugin.remove_worker(scheduler=self, worker=address) - if inspect.isawaitable(result): - await result - except Exception as e: - logger.exception(e) + def valid_workers(self, ts: TaskState) -> set: + """Return set of currently valid workers for key - if not self.workers: - logger.info("Lost all workers") + If all workers are valid then this returns ``None``. + This checks tracks the following state: - for w in self.workers: - self.bandwidth_workers.pop((address, w), None) - self.bandwidth_workers.pop((w, address), None) + * worker_restrictions + * host_restrictions + * resource_restrictions + """ + workers: dict = cast(dict, self.workers) + s: set = None - def remove_worker_from_events(): - # If the worker isn't registered anymore after the delay, remove from events - if address not in self.workers and address in self.events: - del self.events[address] + if ts._worker_restrictions: + s = {w for w in ts._worker_restrictions if w in workers} - cleanup_delay = parse_timedelta( - dask.config.get("distributed.scheduler.events-cleanup-delay") - ) - self.loop.call_later(cleanup_delay, remove_worker_from_events) - logger.debug("Removed worker %s", ws) + if ts._host_restrictions: + # Resolve the alias here rather than early, for the worker + # may not be connected when host_restrictions is populated + hr: list = [self.coerce_hostname(h) for h in ts._host_restrictions] + # XXX need HostState? 
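+            # Each remaining hostname is expanded to the set of worker
+            # addresses currently registered on that host; hostnames with no
+            # connected workers are skipped rather than treated as an error.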
+ sl: list = [ + self.host_info[h]["addresses"] for h in hr if h in self.host_info + ] + ss: set = set.union(*sl) if sl else set() + if s is None: + s = ss + else: + s |= ss - return "OK" + if ts._resource_restrictions: + dw: dict = { + resource: { + w + for w, supplied in self.resources[resource].items() + if supplied >= required + } + for resource, required in ts._resource_restrictions.items() + } - def stimulus_cancel(self, comm, keys=None, client=None, force=False): - """ Stop execution on a list of keys """ - logger.info("Client %s requests to cancel %d keys", client, len(keys)) - if client: - self.log_event( - client, {"action": "cancel", "count": len(keys), "force": force} - ) - for key in keys: - self.cancel_key(key, client, force=force) + ww: set = set.intersection(*dw.values()) + if s is None: + s = ww + else: + s &= ww - def cancel_key(self, key, client, retries=5, force=False): - """ Cancel a particular key and all dependents """ - # TODO: this should be converted to use the transition mechanism - ts: TaskState = self.tasks.get(key) - dts: TaskState - try: - cs: ClientState = self.clients[client] - except KeyError: - return - if ts is None or not ts._who_wants: # no key yet, lets try again in a moment - if retries: - self.loop.call_later( - 0.2, lambda: self.cancel_key(key, client, retries - 1) - ) - return - if force or ts._who_wants == {cs}: # no one else wants this key - for dts in list(ts._dependents): - self.cancel_key(dts._key, client, force=force) - logger.info("Scheduler cancels key %s. Force=%s", key, force) - self.report({"op": "cancelled-key", "key": key}) - clients = list(ts._who_wants) if force else [cs] - for cs in clients: - self.client_releases_keys(keys=[key], client=cs._client_key) + if s is not None: + s = {workers[w] for w in s} - def client_desires_keys(self, keys=None, client=None): - cs: ClientState = self.clients.get(client) - if cs is None: - # For publish, queues etc. - self.clients[client] = cs = ClientState(client) - ts: TaskState - for k in keys: - ts = self.tasks.get(k) - if ts is None: - # For publish, queues etc. - ts = self.new_task(k, None, "released") - ts._who_wants.add(cs) - cs._wants_what.add(ts) + return s - if ts._state in ("memory", "erred"): - self.report_on_key(ts=ts, client=client) + def worker_objective(self, ts: TaskState, ws: WorkerState): + """ + Objective function to determine which worker should get the task - def _client_releases_keys(self, keys: list, cs: ClientState, recommendations: dict): - """ Remove keys from client desired list """ - logger.debug("Client %s releases keys: %s", cs._client_key, keys) - ts: TaskState - tasks2: set = set() - for key in keys: - ts = self.tasks.get(key) - if ts is not None and ts in cs._wants_what: - cs._wants_what.remove(ts) - s: set = ts._who_wants - s.remove(cs) - if not s: - tasks2.add(ts) + Minimize expected start time. If a tie then break with data storage. 
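+
+        Concretely, the expected start time is the worker's current occupancy
+        divided by its thread count, plus the time to transfer any missing
+        dependencies at the measured bandwidth.  For example, a worker with
+        10s of queued work on 2 threads and 100 MB of dependencies to fetch at
+        100 MB/s scores 5s + 1s = 6s.  Actor tasks compare the number of
+        actors already placed on the worker first; ties are broken by the
+        total bytes stored on the worker.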
+ """ + dts: TaskState + nbytes: Py_ssize_t + comm_bytes: Py_ssize_t = 0 + for dts in ts._dependencies: + if ws not in dts._who_has: + nbytes = dts.get_nbytes() + comm_bytes += nbytes - for ts in tasks2: - if not ts._dependents: - # No live dependents, can forget - recommendations[ts._key] = "forgotten" - elif ts._state != "erred" and not ts._waiters: - recommendations[ts._key] = "released" + bandwidth: double = self.bandwidth + stack_time: double = ws._occupancy / ws._nthreads + start_time: double = stack_time + comm_bytes / bandwidth - def client_releases_keys(self, keys=None, client=None): - """ Remove keys from client desired list """ + if ts._actor: + return (len(ws._actors), start_time, ws._nbytes) + else: + return (start_time, ws._nbytes) - if not isinstance(keys, list): - keys = list(keys) - cs: ClientState = self.clients[client] - recommendations: dict = {} - self._client_releases_keys(keys=keys, cs=cs, recommendations=recommendations) - self.transitions(recommendations) +class Scheduler(SchedulerState, ServerNode): + """Dynamic distributed task scheduler - def client_heartbeat(self, client=None): - """ Handle heartbeats from Client """ - cs: ClientState = self.clients[client] - cs._last_seen = time() + The scheduler tracks the current state of workers, data, and computations. + The scheduler listens for events and responds by controlling workers + appropriately. It continuously tries to use the workers to execute an ever + growing dask graph. - ################### - # Task Validation # - ################### + All events are handled quickly, in linear time with respect to their input + (which is often of constant size) and generally within a millisecond. To + accomplish this the scheduler tracks a lot of state. Every operation + maintains the consistency of this state. - def validate_released(self, key): - ts: TaskState = self.tasks[key] - dts: TaskState - assert ts._state == "released" - assert not ts._waiters - assert not ts._waiting_on - assert not ts._who_has - assert not ts._processing_on - assert not any([ts in dts._waiters for dts in ts._dependencies]) - assert ts not in self.unrunnable + The scheduler communicates with the outside world through Comm objects. + It maintains a consistent and valid view of the world even when listening + to several clients at once. - def validate_waiting(self, key): - ts: TaskState = self.tasks[key] - dts: TaskState - assert ts._waiting_on - assert not ts._who_has - assert not ts._processing_on - assert ts not in self.unrunnable - for dts in ts._dependencies: - # We are waiting on a dependency iff it's not stored - assert (not not dts._who_has) != (dts in ts._waiting_on) - assert ts in dts._waiters # XXX even if dts._who_has? 
+ A Scheduler is typically started either with the ``dask-scheduler`` + executable:: - def validate_processing(self, key): - ts: TaskState = self.tasks[key] - dts: TaskState - assert not ts._waiting_on - ws: WorkerState = ts._processing_on - assert ws - assert ts in ws._processing - assert not ts._who_has - for dts in ts._dependencies: - assert dts._who_has - assert ts in dts._waiters + $ dask-scheduler + Scheduler started at 127.0.0.1:8786 - def validate_memory(self, key): - ts: TaskState = self.tasks[key] - dts: TaskState - assert ts._who_has - assert not ts._processing_on - assert not ts._waiting_on - assert ts not in self.unrunnable - for dts in ts._dependents: - assert (dts in ts._waiters) == (dts._state in ("waiting", "processing")) - assert ts not in dts._waiting_on + Or within a LocalCluster a Client starts up without connection + information:: - def validate_no_worker(self, key): - ts: TaskState = self.tasks[key] - dts: TaskState - assert ts in self.unrunnable - assert not ts._waiting_on - assert ts in self.unrunnable - assert not ts._processing_on - assert not ts._who_has - for dts in ts._dependencies: - assert dts._who_has + >>> c = Client() # doctest: +SKIP + >>> c.cluster.scheduler # doctest: +SKIP + Scheduler(...) - def validate_erred(self, key): - ts: TaskState = self.tasks[key] - assert ts._exception_blame - assert not ts._who_has + Users typically do not interact with the scheduler directly but rather with + the client object ``Client``. - def validate_key(self, key, ts: TaskState = None): - try: - if ts is None: - ts = self.tasks.get(key) - if ts is None: - logger.debug("Key lost: %s", key) - else: - ts.validate() - try: - func = getattr(self, "validate_" + ts._state.replace("-", "_")) - except AttributeError: - logger.error( - "self.validate_%s not found", ts._state.replace("-", "_") - ) - else: - func(key) - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + **State** - pdb.set_trace() - raise + The scheduler contains the following state variables. Each variable is + listed along with what it stores and a brief description. - def validate_state(self, allow_overlap=False): - validate_state(self.tasks, self.workers, self.clients) + * **tasks:** ``{task key: TaskState}`` + Tasks currently known to the scheduler + * **unrunnable:** ``{TaskState}`` + Tasks in the "no-worker" state - if not (set(self.workers) == set(self.stream_comms)): - raise ValueError("Workers not the same in all collections") + * **workers:** ``{worker key: WorkerState}`` + Workers currently connected to the scheduler + * **idle:** ``{WorkerState}``: + Set of workers that are not fully utilized + * **saturated:** ``{WorkerState}``: + Set of workers that are not over-utilized - ws: WorkerState - for w, ws in self.workers.items(): - assert isinstance(w, str), (type(w), w) - assert isinstance(ws, WorkerState), (type(ws), ws) - assert ws._address == w - if not ws._processing: - assert not ws._occupancy - assert ws._address in cast(dict, self.idle) + * **host_info:** ``{hostname: dict}``: + Information about each worker host - ts: TaskState - for k, ts in self.tasks.items(): - assert isinstance(ts, TaskState), (type(ts), ts) - assert ts._key == k - self.validate_key(k, ts) + * **clients:** ``{client key: ClientState}`` + Clients currently connected to the scheduler - c: str - cs: ClientState - for c, cs in self.clients.items(): - # client=None is often used in tests... 
- assert c is None or type(c) == str, (type(c), c) - assert type(cs) == ClientState, (type(cs), cs) - assert cs._client_key == c + * **services:** ``{str: port}``: + Other services running on this scheduler, like Bokeh + * **loop:** ``IOLoop``: + The running Tornado IOLoop + * **client_comms:** ``{client key: Comm}`` + For each client, a Comm object used to receive task requests and + report task status updates. + * **stream_comms:** ``{worker key: Comm}`` + For each worker, a Comm object from which we both accept stimuli and + report results + * **task_duration:** ``{key-prefix: time}`` + Time we expect certain functions to take, e.g. ``{'sum': 0.25}`` + """ - a = {w: ws._nbytes for w, ws in self.workers.items()} - b = { - w: sum(ts.get_nbytes() for ts in ws._has_what) - for w, ws in self.workers.items() - } - assert a == b, (a, b) + default_port = 8786 + _instances = weakref.WeakSet() - actual_total_occupancy = 0 - for worker, ws in self.workers.items(): - assert abs(sum(ws._processing.values()) - ws._occupancy) < 1e-8 - actual_total_occupancy += ws._occupancy + def __init__( + self, + loop=None, + delete_interval="500ms", + synchronize_worker_interval="60s", + services=None, + service_kwargs=None, + allowed_failures=None, + extensions=None, + validate=None, + scheduler_file=None, + security=None, + worker_ttl=None, + idle_timeout=None, + interface=None, + host=None, + port=0, + protocol=None, + dashboard_address=None, + dashboard=None, + http_prefix="/", + preload=None, + preload_argv=(), + plugins=(), + **kwargs, + ): + self._setup_logging(logger) - assert abs(actual_total_occupancy - self.total_occupancy) < 1e-8, ( - actual_total_occupancy, - self.total_occupancy, + # Attributes + if allowed_failures is None: + allowed_failures = dask.config.get("distributed.scheduler.allowed-failures") + self.allowed_failures = allowed_failures + if validate is None: + validate = dask.config.get("distributed.scheduler.validate") + self.validate = validate + self.proc = psutil.Process() + self.delete_interval = parse_timedelta(delete_interval, default="ms") + self.synchronize_worker_interval = parse_timedelta( + synchronize_worker_interval, default="ms" ) + self.digests = None + self.service_specs = services or {} + self.service_kwargs = service_kwargs or {} + self.services = {} + self.scheduler_file = scheduler_file + worker_ttl = worker_ttl or dask.config.get("distributed.scheduler.worker-ttl") + self.worker_ttl = parse_timedelta(worker_ttl) if worker_ttl else None + idle_timeout = idle_timeout or dask.config.get( + "distributed.scheduler.idle-timeout" + ) + if idle_timeout: + self.idle_timeout = parse_timedelta(idle_timeout) + else: + self.idle_timeout = None + self.idle_since = time() + self.time_started = self.idle_since # compatibility for dask-gateway + self._lock = asyncio.Lock() + self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth")) + self.bandwidth_workers = defaultdict(float) + self.bandwidth_types = defaultdict(float) - ################### - # Manage Messages # - ################### - - def report(self, msg: dict, ts: TaskState = None, client: str = None): - """ - Publish updates to all listening Queues and Comms + if not preload: + preload = dask.config.get("distributed.scheduler.preload") + if not preload_argv: + preload_argv = dask.config.get("distributed.scheduler.preload-argv") + self.preloads = preloading.process_preloads(self, preload, preload_argv) - If the message contains a key then we only send the message to those - comms that care about the key. 
- """ - if ts is None: - msg_key = msg.get("key") - if msg_key is not None: - tasks: dict = self.tasks - ts = tasks.get(msg_key) + if isinstance(security, dict): + security = Security(**security) + self.security = security or Security() + assert isinstance(self.security, Security) + self.connection_args = self.security.get_connection_args("scheduler") + self.connection_args["handshake_overrides"] = { # common denominator + "pickle-protocol": 4 + } - cs: ClientState - client_comms: dict = self.client_comms - client_keys: list - if ts is None: - # Notify all clients - client_keys = list(client_comms) - elif client is None: - # Notify clients interested in key - client_keys = [cs._client_key for cs in ts._who_wants] - else: - # Notify clients interested in key (including `client`) - client_keys = [ - cs._client_key for cs in ts._who_wants if cs._client_key != client - ] - client_keys.append(client) - - k: str - for k in client_keys: - c = client_comms.get(k) - if c is None: - continue - try: - c.send(msg) - # logger.debug("Scheduler sends message to client %s", msg) - except CommClosedError: - if self.status == Status.running: - logger.critical("Tried writing to closed comm: %s", msg) - - async def add_client(self, comm, client=None, versions=None): - """Add client to network - - We listen to all future messages from this Comm. - """ - assert client is not None - comm.name = "Scheduler->Client" - logger.info("Receive client connection: %s", client) - self.log_event(["all", client], {"action": "add-client", "client": client}) - self.clients[client] = ClientState(client, versions=versions) + self._start_address = addresses_from_user_args( + host=host, + port=port, + interface=interface, + protocol=protocol, + security=security, + default_port=self.default_port, + ) - for plugin in self.plugins[:]: + http_server_modules = dask.config.get("distributed.scheduler.http.routes") + show_dashboard = dashboard or (dashboard is None and dashboard_address) + missing_bokeh = False + # install vanilla route if show_dashboard but bokeh is not installed + if show_dashboard: try: - plugin.add_client(scheduler=self, client=client) - except Exception as e: - logger.exception(e) - - try: - bcomm = BatchedSend(interval="2ms", loop=self.loop) - bcomm.start(comm) - self.client_comms[client] = bcomm - msg = {"op": "stream-start"} - ws: WorkerState - version_warning = version_module.error_message( - version_module.get_versions(), - {w: ws._versions for w, ws in self.workers.items()}, - versions, + import distributed.dashboard.scheduler + except ImportError: + missing_bokeh = True + http_server_modules.append("distributed.http.scheduler.missing_bokeh") + routes = get_handlers( + server=self, modules=http_server_modules, prefix=http_prefix + ) + self.start_http_server(routes, dashboard_address, default_port=8787) + if show_dashboard and not missing_bokeh: + distributed.dashboard.scheduler.connect( + self.http_application, self.http_server, self, prefix=http_prefix ) - msg.update(version_warning) - bcomm.send(msg) - try: - await self.handle_stream(comm=comm, extra={"client": client}) - finally: - self.remove_client(client=client) - logger.debug("Finished handling client %s", client) - finally: - if not comm.closed(): - self.client_comms[client].send({"op": "stream-closed"}) - try: - if not shutting_down(): - await self.client_comms[client].close() - del self.client_comms[client] - if self.status == Status.running: - logger.info("Close client connection: %s", client) - except TypeError: # comm becomes None during GC - 
pass + # Communication state + self.loop = loop or IOLoop.current() + self.client_comms = dict() + self.stream_comms = dict() + self._worker_coroutines = [] + self._ipython_kernel = None - def remove_client(self, client=None): - """ Remove client from network """ - if self.status == Status.running: - logger.info("Remove client %s", client) - self.log_event(["all", client], {"action": "remove-client", "client": client}) - try: - cs: ClientState = self.clients[client] - except KeyError: - # XXX is this a legitimate condition? - pass - else: - ts: TaskState - self.client_releases_keys( - keys=[ts._key for ts in cs._wants_what], client=cs._client_key - ) - del self.clients[client] + # Task state + self.tasks = dict() + self.task_groups = dict() + self.task_prefixes = dict() + for old_attr, new_attr, wrap in [ + ("priority", "priority", None), + ("dependencies", "dependencies", _legacy_task_key_set), + ("dependents", "dependents", _legacy_task_key_set), + ("retries", "retries", None), + ]: + func = operator.attrgetter(new_attr) + if wrap is not None: + func = compose(wrap, func) + setattr(self, old_attr, _StateLegacyMapping(self.tasks, func)) - for plugin in self.plugins[:]: - try: - plugin.remove_client(scheduler=self, client=client) - except Exception as e: - logger.exception(e) + for old_attr, new_attr, wrap in [ + ("nbytes", "nbytes", None), + ("who_wants", "who_wants", _legacy_client_key_set), + ("who_has", "who_has", _legacy_worker_key_set), + ("waiting", "waiting_on", _legacy_task_key_set), + ("waiting_data", "waiters", _legacy_task_key_set), + ("rprocessing", "processing_on", None), + ("host_restrictions", "host_restrictions", None), + ("worker_restrictions", "worker_restrictions", None), + ("resource_restrictions", "resource_restrictions", None), + ("suspicious_tasks", "suspicious", None), + ("exceptions", "exception", None), + ("tracebacks", "traceback", None), + ("exceptions_blame", "exception_blame", _task_key_or_none), + ]: + func = operator.attrgetter(new_attr) + if wrap is not None: + func = compose(wrap, func) + setattr(self, old_attr, _OptionalStateLegacyMapping(self.tasks, func)) - def remove_client_from_events(): - # If the client isn't registered anymore after the delay, remove from events - if client not in self.clients and client in self.events: - del self.events[client] + for old_attr, new_attr, wrap in [ + ("loose_restrictions", "loose_restrictions", None) + ]: + func = operator.attrgetter(new_attr) + if wrap is not None: + func = compose(wrap, func) + setattr(self, old_attr, _StateLegacySet(self.tasks, func)) - cleanup_delay = parse_timedelta( - dask.config.get("distributed.scheduler.events-cleanup-delay") - ) - self.loop.call_later(cleanup_delay, remove_client_from_events) + self.generation = 0 + self._last_client = None + self._last_time = 0 + self.unrunnable = set() - def _task_to_msg(self, ts: TaskState, duration=None) -> dict: - """ Convert a single computational task to a message """ - ws: WorkerState - dts: TaskState + self.n_tasks = 0 + self.task_metadata = dict() + self.datasets = dict() - if duration is None: - duration = self.get_task_duration(ts) + # Prefix-keyed containers + self.unknown_durations = defaultdict(set) - msg: dict = { - "op": "compute-task", - "key": ts._key, - "priority": ts._priority, - "duration": duration, - } - if ts._resource_restrictions: - msg["resource_restrictions"] = ts._resource_restrictions - if ts._actor: - msg["actor"] = True + # Client state + self.clients = dict() + for old_attr, new_attr, wrap in [ + ("wants_what", 
"wants_what", _legacy_task_key_set) + ]: + func = operator.attrgetter(new_attr) + if wrap is not None: + func = compose(wrap, func) + setattr(self, old_attr, _StateLegacyMapping(self.clients, func)) + self.clients["fire-and-forget"] = ClientState("fire-and-forget") - deps: set = ts._dependencies - if deps: - msg["who_has"] = { - dts._key: [ws._address for ws in dts._who_has] for dts in deps - } - msg["nbytes"] = {dts._key: dts._nbytes for dts in deps} + # Worker state + self.workers = sortedcontainers.SortedDict() + for old_attr, new_attr, wrap in [ + ("nthreads", "nthreads", None), + ("worker_bytes", "nbytes", None), + ("worker_resources", "resources", None), + ("used_resources", "used_resources", None), + ("occupancy", "occupancy", None), + ("worker_info", "metrics", None), + ("processing", "processing", _legacy_task_key_dict), + ("has_what", "has_what", _legacy_task_key_set), + ]: + func = operator.attrgetter(new_attr) + if wrap is not None: + func = compose(wrap, func) + setattr(self, old_attr, _StateLegacyMapping(self.workers, func)) - if self.validate: - assert all(msg["who_has"].values()) + self.idle = sortedcontainers.SortedDict() + self.saturated = set() - task = ts._run_spec - if type(task) is dict: - msg.update(task) - else: - msg["task"] = task + self.total_nthreads = 0 + self.total_occupancy = 0 + self.host_info = defaultdict(dict) + self.resources = defaultdict(dict) + self.aliases = dict() - return msg + self._task_state_collections = [self.unrunnable] - def send_task_to_worker(self, worker, ts: TaskState, duration=None): - """ Send a single computational task to a worker """ - try: - msg: dict = self._task_to_msg(ts, duration) - self.worker_send(worker, msg) - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + self._worker_collections = [ + self.workers, + self.host_info, + self.resources, + self.aliases, + ] - pdb.set_trace() - raise + self.extensions = {} + self.plugins = list(plugins) + self.transition_log = deque( + maxlen=dask.config.get("distributed.scheduler.transition-log-length") + ) + self.log = deque( + maxlen=dask.config.get("distributed.scheduler.transition-log-length") + ) + self.events = defaultdict(lambda: deque(maxlen=100000)) + self.event_counts = defaultdict(int) + self.worker_plugins = [] - def handle_uncaught_error(self, **msg): - logger.exception(clean_exception(**msg)[1]) + worker_handlers = { + "task-finished": self.handle_task_finished, + "task-erred": self.handle_task_erred, + "release": self.handle_release_data, + "release-worker-data": self.release_worker_data, + "add-keys": self.add_keys, + "missing-data": self.handle_missing_data, + "long-running": self.handle_long_running, + "reschedule": self.reschedule, + "keep-alive": lambda *args, **kwargs: None, + "log-event": self.log_worker_event, + } - def handle_task_finished(self, key=None, worker=None, **msg): - if worker not in self.workers: - return - validate_key(key) - r = self.stimulus_task_finished(key=key, worker=worker, **msg) - self.transitions(r) + client_handlers = { + "update-graph": self.update_graph, + "update-graph-hlg": self.update_graph_hlg, + "client-desires-keys": self.client_desires_keys, + "update-data": self.update_data, + "report-key": self.report_on_key, + "client-releases-keys": self.client_releases_keys, + "heartbeat-client": self.client_heartbeat, + "close-client": self.remove_client, + "restart": self.restart, + } - def handle_task_erred(self, key=None, **msg): - r = self.stimulus_task_erred(key=key, **msg) - self.transitions(r) + self.handlers = { 
+ "register-client": self.add_client, + "scatter": self.scatter, + "register-worker": self.add_worker, + "unregister": self.remove_worker, + "gather": self.gather, + "cancel": self.stimulus_cancel, + "retry": self.stimulus_retry, + "feed": self.feed, + "terminate": self.close, + "broadcast": self.broadcast, + "proxy": self.proxy, + "ncores": self.get_ncores, + "has_what": self.get_has_what, + "who_has": self.get_who_has, + "processing": self.get_processing, + "call_stack": self.get_call_stack, + "profile": self.get_profile, + "performance_report": self.performance_report, + "get_logs": self.get_logs, + "logs": self.get_logs, + "worker_logs": self.get_worker_logs, + "log_event": self.log_worker_event, + "events": self.get_events, + "nbytes": self.get_nbytes, + "versions": self.versions, + "add_keys": self.add_keys, + "rebalance": self.rebalance, + "replicate": self.replicate, + "start_ipython": self.start_ipython, + "run_function": self.run_function, + "update_data": self.update_data, + "set_resources": self.add_resources, + "retire_workers": self.retire_workers, + "get_metadata": self.get_metadata, + "set_metadata": self.set_metadata, + "heartbeat_worker": self.heartbeat_worker, + "get_task_status": self.get_task_status, + "get_task_stream": self.get_task_stream, + "register_worker_plugin": self.register_worker_plugin, + "adaptive_target": self.adaptive_target, + "workers_to_close": self.workers_to_close, + "subscribe_worker_status": self.subscribe_worker_status, + "start_task_metadata": self.start_task_metadata, + "stop_task_metadata": self.stop_task_metadata, + } - def handle_release_data(self, key=None, worker=None, client=None, **msg): - ts: TaskState = self.tasks.get(key) - if ts is None: - return - ws: WorkerState = self.workers[worker] - if ts._processing_on != ws: - return - r = self.stimulus_missing_data(key=key, ensure=False, **msg) - self.transitions(r) + self._transitions = { + ("released", "waiting"): self.transition_released_waiting, + ("waiting", "released"): self.transition_waiting_released, + ("waiting", "processing"): self.transition_waiting_processing, + ("waiting", "memory"): self.transition_waiting_memory, + ("processing", "released"): self.transition_processing_released, + ("processing", "memory"): self.transition_processing_memory, + ("processing", "erred"): self.transition_processing_erred, + ("no-worker", "released"): self.transition_no_worker_released, + ("no-worker", "waiting"): self.transition_no_worker_waiting, + ("released", "forgotten"): self.transition_released_forgotten, + ("memory", "forgotten"): self.transition_memory_forgotten, + ("erred", "forgotten"): self.transition_released_forgotten, + ("erred", "released"): self.transition_erred_released, + ("memory", "released"): self.transition_memory_released, + ("released", "erred"): self.transition_released_erred, + } - def handle_missing_data(self, key=None, errant_worker=None, **kwargs): - logger.debug("handle missing data key=%s worker=%s", key, errant_worker) - self.log.append(("missing", key, errant_worker)) + connection_limit = get_fileno_limit() / 2 - ts: TaskState = self.tasks.get(key) - if ts is None or not ts._who_has: - return - if errant_worker in self.workers: - ws: WorkerState = self.workers[errant_worker] - if ws in ts._who_has: - ts._who_has.remove(ws) - ws._has_what.remove(ts) - ws._nbytes -= ts.get_nbytes() - if not ts._who_has: - if ts._run_spec: - self.transitions({key: "released"}) - else: - self.transitions({key: "forgotten"}) + super().__init__( + handlers=self.handlers, + 
stream_handlers=merge(worker_handlers, client_handlers),
+            io_loop=self.loop,
+            connection_limit=connection_limit,
+            deserialize=False,
+            connection_args=self.connection_args,
+            **kwargs,
+        )

-    def release_worker_data(self, comm=None, keys=None, worker=None):
-        ws: WorkerState = self.workers[worker]
-        tasks = {self.tasks[k] for k in keys}
-        removed_tasks = tasks & ws._has_what
-        ws._has_what -= removed_tasks
+        if self.worker_ttl:
+            pc = PeriodicCallback(self.check_worker_ttl, self.worker_ttl)
+            self.periodic_callbacks["worker-ttl"] = pc

-        ts: TaskState
-        recommendations: dict = {}
-        for ts in removed_tasks:
-            ws._nbytes -= ts.get_nbytes()
-            wh = ts._who_has
-            wh.remove(ws)
-            if not wh:
-                recommendations[ts._key] = "released"
-        if recommendations:
-            self.transitions(recommendations)
+        if self.idle_timeout:
+            pc = PeriodicCallback(self.check_idle, self.idle_timeout / 4)
+            self.periodic_callbacks["idle-timeout"] = pc

-    def handle_long_running(self, key=None, worker=None, compute_duration=None):
-        """A task has seceded from the thread pool
+        if extensions is None:
+            extensions = list(DEFAULT_EXTENSIONS)
+        if dask.config.get("distributed.scheduler.work-stealing"):
+            extensions.append(WorkStealing)
+        for ext in extensions:
+            ext(self)

-        We stop the task from being stolen in the future, and change task
-        duration accounting as if the task has stopped.
-        """
-        ts: TaskState = self.tasks[key]
-        if "stealing" in self.extensions:
-            self.extensions["stealing"].remove_key_from_stealable(ts)
+        setproctitle("dask-scheduler [not started]")
+        Scheduler._instances.add(self)
+        self.rpc.allow_offload = False
+        self.status = Status.undefined

-        ws: WorkerState = ts._processing_on
-        if ws is None:
-            logger.debug("Received long-running signal from duplicate task. Ignoring.")
-            return
+    ##################
+    # Administration #
+    ##################

-        if compute_duration:
-            old_duration = ts._prefix._duration_average
-            new_duration = compute_duration
-            if old_duration < 0:
-                avg_duration = new_duration
-            else:
-                avg_duration = 0.5 * old_duration + 0.5 * new_duration
+    def __repr__(self):
+        return '<Scheduler: "%s" processes: %d cores: %d>' % (
+            self.address,
+            len(self.workers),
+            self.total_nthreads,
+        )

-            ts._prefix._duration_average = avg_duration
+    def identity(self, comm=None):
+        """ Basic information about ourselves and our cluster """
+        d = {
+            "type": type(self).__name__,
+            "id": str(self.id),
+            "address": self.address,
+            "services": {key: v.port for (key, v) in self.services.items()},
+            "workers": {
+                worker.address: worker.identity() for worker in self.workers.values()
+            },
+        }
+        return d

-        ws._occupancy -= ws._processing[ts]
-        self.total_occupancy -= ws._processing[ts]
-        ws._processing[ts] = 0
-        self.check_idle_saturated(ws)
+    def get_worker_service_addr(self, worker, service_name, protocol=False):
+        """
+        Get the (host, port) address of the named service on the *worker*.
+        Returns None if the service doesn't exist.
- async def handle_worker(self, comm=None, worker=None): + Parameters + ---------- + worker : address + service_name : str + Common services include 'bokeh' and 'nanny' + protocol : boolean + Whether or not to include a full address with protocol (True) + or just a (host, port) pair """ - Listen to responses from a single worker + ws: WorkerState = self.workers[worker] + port = ws._services.get(service_name) + if port is None: + return None + elif protocol: + return "%(protocol)s://%(host)s:%(port)d" % { + "protocol": ws._address.split("://")[0], + "host": ws.host, + "port": port, + } + else: + return ws.host, port - This is the main loop for scheduler-worker interaction + async def start(self): + """ Clear out old state and restart all running coroutines """ + await super().start() + assert self.status != Status.running - See Also - -------- - Scheduler.handle_client: Equivalent coroutine for clients - """ - comm.name = "Scheduler connection to worker" - worker_comm = self.stream_comms[worker] - worker_comm.start(comm) - logger.info("Starting worker compute stream, %s", worker) - try: - await self.handle_stream(comm=comm, extra={"worker": worker}) - finally: - if worker in self.stream_comms: - worker_comm.abort() - await self.remove_worker(address=worker) + enable_gc_diagnosis() - def add_plugin(self, plugin=None, idempotent=False, **kwargs): - """ - Add external plugin to scheduler + self.clear_task_state() - See https://distributed.readthedocs.io/en/latest/plugins.html - """ - if isinstance(plugin, type): - plugin = plugin(self, **kwargs) + with suppress(AttributeError): + for c in self._worker_coroutines: + c.cancel() - if idempotent and any(isinstance(p, type(plugin)) for p in self.plugins): - return + for addr in self._start_address: + await self.listen( + addr, + allow_offload=False, + handshake_overrides={"pickle-protocol": 4, "compression": None}, + **self.security.get_listen_args("scheduler"), + ) + self.ip = get_address_host(self.listen_address) + listen_ip = self.ip - self.plugins.append(plugin) + if listen_ip == "0.0.0.0": + listen_ip = "" - def remove_plugin(self, plugin): - """ Remove external plugin from scheduler """ - self.plugins.remove(plugin) + if self.address.startswith("inproc://"): + listen_ip = "localhost" - def worker_send(self, worker, msg): - """Send message to worker + # Services listen on all addresses + self.start_services(listen_ip) - This also handles connection failures by adding a callback to remove - the worker on the next cycle. 
- """ - stream_comms: dict = self.stream_comms - try: - stream_comms[worker].send(msg) - except (CommClosedError, AttributeError): - self.loop.add_callback(self.remove_worker, address=worker) + for listener in self.listeners: + logger.info(" Scheduler at: %25s", listener.contact_address) + for k, v in self.services.items(): + logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port)) - def client_send(self, client, msg): - """Send message to client""" - client_comms: dict = self.client_comms - c = client_comms.get(client) - if c is None: - return - try: - c.send(msg) - except CommClosedError: - if self.status == Status.running: - logger.critical("Tried writing to closed comm: %s", msg) + self.loop.add_callback(self.reevaluate_occupancy) - ############################ - # Less common interactions # - ############################ + if self.scheduler_file: + with open(self.scheduler_file, "w") as f: + json.dump(self.identity(), f, indent=2) - async def scatter( - self, - comm=None, - data=None, - workers=None, - client=None, - broadcast=False, - timeout=2, - ): - """Send data out to workers + fn = self.scheduler_file # remove file when we close the process - See also - -------- - Scheduler.broadcast: - """ - start = time() - while not self.workers: - await asyncio.sleep(0.2) - if time() > start + timeout: - raise TimeoutError("No workers found") + def del_scheduler_file(): + if os.path.exists(fn): + os.remove(fn) - if workers is None: - ws: WorkerState - nthreads = {w: ws._nthreads for w, ws in self.workers.items()} - else: - workers = [self.coerce_address(w) for w in workers] - nthreads = {w: self.workers[w].nthreads for w in workers} + weakref.finalize(self, del_scheduler_file) - assert isinstance(data, dict) + for preload in self.preloads: + await preload.start() - keys, who_has, nbytes = await scatter_to_workers( - nthreads, data, rpc=self.rpc, report=False - ) + await asyncio.gather(*[plugin.start(self) for plugin in self.plugins]) - self.update_data(who_has=who_has, nbytes=nbytes, client=client) + self.start_periodic_callbacks() - if broadcast: - if broadcast == True: # noqa: E712 - n = len(nthreads) - else: - n = broadcast - await self.replicate(keys=keys, workers=workers, n=n) + setproctitle("dask-scheduler [%s]" % (self.address,)) + return self - self.log_event( - [client, "all"], {"action": "scatter", "client": client, "count": len(data)} - ) - return keys + async def close(self, comm=None, fast=False, close_workers=False): + """Send cleanup signal to all coroutines then wait until finished - async def gather(self, comm=None, keys=None, serializers=None): - """ Collect data in from workers """ - ws: WorkerState - keys = list(keys) - who_has = {} - for key in keys: - ts: TaskState = self.tasks.get(key) - if ts is not None: - who_has[key] = [ws._address for ws in ts._who_has] - else: - who_has[key] = [] + See Also + -------- + Scheduler.cleanup + """ + if self.status in (Status.closing, Status.closed, Status.closing_gracefully): + await self.finished() + return + self.status = Status.closing - data, missing_keys, missing_workers = await gather_from_workers( - who_has, rpc=self.rpc, close=False, serializers=serializers - ) - if not missing_keys: - result = {"status": "OK", "data": data} - else: - missing_states = [ - (self.tasks[key].state if key in self.tasks else None) - for key in missing_keys - ] - logger.exception( - "Couldn't gather keys %s state: %s workers: %s", - missing_keys, - missing_states, - missing_workers, - ) - result = {"status": "error", "keys": missing_keys} - 
with log_errors(): - # Remove suspicious workers from the scheduler but allow them to - # reconnect. - await asyncio.gather( - *[ - self.remove_worker(address=worker, close=False) - for worker in missing_workers - ] - ) - for key, workers in missing_keys.items(): - # Task may already be gone if it was held by a - # `missing_worker` - ts: TaskState = self.tasks.get(key) - logger.exception( - "Workers don't have promised key: %s, %s", - str(workers), - str(key), - ) - if not workers or ts is None: - continue - for worker in workers: - ws = self.workers.get(worker) - if ws is not None and ts in ws._has_what: - ws._has_what.remove(ts) - ts._who_has.remove(ws) - ws._nbytes -= ts.get_nbytes() - self.transitions({key: "released"}) + logger.info("Scheduler closing...") + setproctitle("dask-scheduler [closing]") - self.log_event("all", {"action": "gather", "count": len(keys)}) - return result + for preload in self.preloads: + await preload.teardown() - def clear_task_state(self): - # XXX what about nested state such as ClientState.wants_what - # (see also fire-and-forget...) - logger.info("Clear task state") - for collection in self._task_state_collections: - collection.clear() + if close_workers: + await self.broadcast(msg={"op": "close_gracefully"}, nanny=True) + for worker in self.workers: + self.worker_send(worker, {"op": "close"}) + for i in range(20): # wait a second for send signals to clear + if self.workers: + await asyncio.sleep(0.05) + else: + break - async def restart(self, client=None, timeout=3): - """ Restart all workers. Reset local state. """ - with log_errors(): + await asyncio.gather(*[plugin.close() for plugin in self.plugins]) - n_workers = len(self.workers) + for pc in self.periodic_callbacks.values(): + pc.stop() + self.periodic_callbacks.clear() - logger.info("Send lost future signal to clients") - cs: ClientState - ts: TaskState - for cs in self.clients.values(): - self.client_releases_keys( - keys=[ts._key for ts in cs._wants_what], client=cs._client_key - ) + self.stop_services() - ws: WorkerState - nannies = {addr: ws._nanny for addr, ws in self.workers.items()} + for ext in self.extensions.values(): + with suppress(AttributeError): + ext.teardown() + logger.info("Scheduler closing all comms") - for addr in list(self.workers): - try: - # Ask the worker to close if it doesn't have a nanny, - # otherwise the nanny will kill it anyway - await self.remove_worker(address=addr, close=addr not in nannies) - except Exception as e: - logger.info( - "Exception while restarting. 
This is normal", exc_info=True - ) + futures = [] + for w, comm in list(self.stream_comms.items()): + if not comm.closed(): + comm.send({"op": "close", "report": False}) + comm.send({"op": "close-stream"}) + with suppress(AttributeError): + futures.append(comm.close()) - self.clear_task_state() + for future in futures: # TODO: do all at once + await future - for plugin in self.plugins[:]: - try: - plugin.restart(self) - except Exception as e: - logger.exception(e) + for comm in self.client_comms.values(): + comm.abort() - logger.debug("Send kill signal to nannies: %s", nannies) + await self.rpc.close() - nannies = [ - rpc(nanny_address, connection_args=self.connection_args) - for nanny_address in nannies.values() - if nanny_address is not None - ] + self.status = Status.closed + self.stop() + await super().close() - resps = All( - [ - nanny.restart( - close=True, timeout=timeout * 0.8, executor_wait=False - ) - for nanny in nannies - ] - ) - try: - resps = await asyncio.wait_for(resps, timeout) - except TimeoutError: - logger.error( - "Nannies didn't report back restarted within " - "timeout. Continuuing with restart process" - ) - else: - if not all(resp == "OK" for resp in resps): - logger.error( - "Not all workers responded positively: %s", resps, exc_info=True - ) - finally: - await asyncio.gather(*[nanny.close_rpc() for nanny in nannies]) + setproctitle("dask-scheduler [closed]") + disable_gc_diagnosis() - self.clear_task_state() + async def close_worker(self, comm=None, worker=None, safe=None): + """Remove a worker from the cluster - with suppress(AttributeError): - for c in self._worker_coroutines: - c.cancel() + This both removes the worker from our local state and also sends a + signal to the worker to shut down. This works regardless of whether or + not the worker has a nanny process restarting it + """ + logger.info("Closing worker %s", worker) + with log_errors(): + self.log_event(worker, {"action": "close-worker"}) + ws: WorkerState = self.workers[worker] + nanny_addr = ws._nanny + address = nanny_addr or worker - self.log_event([client, "all"], {"action": "restart", "client": client}) - start = time() - while time() < start + 10 and len(self.workers) < n_workers: - await asyncio.sleep(0.01) + self.worker_send(worker, {"op": "close", "report": False}) + await self.remove_worker(address=worker, safe=safe) - self.report({"op": "restart"}) + ########### + # Stimuli # + ########### - async def broadcast( + def heartbeat_worker( self, comm=None, - msg=None, - workers=None, - hosts=None, - nanny=False, - serializers=None, + address=None, + resolve_address=True, + now=None, + resources=None, + host_info=None, + metrics=None, + executing=None, ): - """ Broadcast message to workers, return all results """ - if workers is None or workers is True: - if hosts is None: - workers = list(self.workers) - else: - workers = [] - if hosts is not None: - for host in hosts: - if host in self.host_info: - workers.extend(self.host_info[host]["addresses"]) - # TODO replace with worker_list - - if nanny: - addresses = [self.workers[w].nanny for w in workers] - else: - addresses = workers + address = self.coerce_address(address, resolve_address) + address = normalize_address(address) + if address not in self.workers: + return {"status": "missing"} - async def send_message(addr): - comm = await self.rpc.connect(addr) - comm.name = "Scheduler Broadcast" - try: - resp = await send_recv(comm, close=True, serializers=serializers, **msg) - finally: - self.rpc.reuse(addr, comm) - return resp + host = 
get_address_host(address) + local_now = time() + now = now or time() + assert metrics + host_info = host_info or {} - results = await All( - [send_message(address) for address in addresses if address is not None] + self.host_info[host]["last-seen"] = local_now + frac = 1 / len(self.workers) + self.bandwidth = ( + self.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac ) + for other, (bw, count) in metrics["bandwidth"]["workers"].items(): + if (address, other) not in self.bandwidth_workers: + self.bandwidth_workers[address, other] = bw / count + else: + alpha = (1 - frac) ** count + self.bandwidth_workers[address, other] = self.bandwidth_workers[ + address, other + ] * alpha + bw * (1 - alpha) + for typ, (bw, count) in metrics["bandwidth"]["types"].items(): + if typ not in self.bandwidth_types: + self.bandwidth_types[typ] = bw / count + else: + alpha = (1 - frac) ** count + self.bandwidth_types[typ] = self.bandwidth_types[typ] * alpha + bw * ( + 1 - alpha + ) - return dict(zip(workers, results)) + ws: WorkerState = self.workers[address] - async def proxy(self, comm=None, msg=None, worker=None, serializers=None): - """ Proxy a communication through the scheduler to some other worker """ - d = await self.broadcast( - comm=comm, msg=msg, workers=[worker], serializers=serializers - ) - return d[worker] + ws._last_seen = time() - async def _delete_worker_data(self, worker_address, keys): - """Delete data from a worker and update the corresponding worker/task states + if executing is not None: + ws._executing = { + self.tasks[key]: duration for key, duration in executing.items() + } - Parameters - ---------- - worker_address: str - Worker address to delete keys from - keys: List[str] - List of keys to delete on the specified worker - """ - await retry_operation( - self.rpc(addr=worker_address).delete_data, keys=list(keys), report=False - ) + if metrics: + ws._metrics = metrics - ws: WorkerState = self.workers[worker_address] - ts: TaskState - tasks: set = {self.tasks[key] for key in keys} - ws._has_what -= tasks - for ts in tasks: - ts._who_has.remove(ws) - ws._nbytes -= ts.get_nbytes() - self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys}) + if host_info: + self.host_info[host].update(host_info) - async def rebalance(self, comm=None, keys=None, workers=None): - """Rebalance keys so that each worker stores roughly equal bytes + delay = time() - now + ws._time_delay = delay - **Policy** + if resources: + self.add_resources(worker=address, resources=resources) - This orders the workers by what fraction of bytes of the existing keys - they have. It walks down this list from most-to-least. At each worker - it sends the largest results it can find and sends them to the least - occupied worker until either the sender or the recipient are at the - average expected load. 
- """ - ts: TaskState - with log_errors(): - async with self._lock: - if keys: - tasks = {self.tasks[k] for k in keys} - missing_data = [ts._key for ts in tasks if not ts._who_has] - if missing_data: - return {"status": "missing-data", "keys": missing_data} - else: - tasks = set(self.tasks.values()) + self.log_event(address, merge({"action": "heartbeat"}, metrics)) - if workers: - workers = {self.workers[w] for w in workers} - workers_by_task = {ts: ts._who_has & workers for ts in tasks} - else: - workers = set(self.workers.values()) - workers_by_task = {ts: ts._who_has for ts in tasks} + return { + "status": "OK", + "time": time(), + "heartbeat-interval": heartbeat_interval(len(self.workers)), + } - ws: WorkerState - tasks_by_worker = {ws: set() for ws in workers} + async def add_worker( + self, + comm=None, + address=None, + keys=(), + nthreads=None, + name=None, + resolve_address=True, + nbytes=None, + types=None, + now=None, + resources=None, + host_info=None, + memory_limit=None, + metrics=None, + pid=0, + services=None, + local_directory=None, + versions=None, + nanny=None, + extra=None, + ): + """ Add a new worker to the cluster """ + with log_errors(): + address = self.coerce_address(address, resolve_address) + address = normalize_address(address) + host = get_address_host(address) - for k, v in workers_by_task.items(): - for vv in v: - tasks_by_worker[vv].add(k) + ws: WorkerState = self.workers.get(address) + if ws is not None: + raise ValueError("Worker already exists %s" % ws) - worker_bytes = { - ws: sum(ts.get_nbytes() for ts in v) - for ws, v in tasks_by_worker.items() + if name in self.aliases: + logger.warning( + "Worker tried to connect with a duplicate name: %s", name + ) + msg = { + "status": "error", + "message": "name taken, %s" % name, + "time": time(), } + if comm: + await comm.write(msg) + return - avg = sum(worker_bytes.values()) / len(worker_bytes) + self.workers[address] = ws = WorkerState( + address=address, + pid=pid, + nthreads=nthreads, + memory_limit=memory_limit or 0, + name=name, + local_directory=local_directory, + services=services, + versions=versions, + nanny=nanny, + extra=extra, + ) - sorted_workers = list( - map(first, sorted(worker_bytes.items(), key=second, reverse=True)) - ) + if "addresses" not in self.host_info[host]: + self.host_info[host].update({"addresses": set(), "nthreads": 0}) - recipients = iter(reversed(sorted_workers)) - recipient = next(recipients) - msgs = [] # (sender, recipient, key) - for sender in sorted_workers[: len(workers) // 2]: - sender_keys = { - ts: ts.get_nbytes() for ts in tasks_by_worker[sender] - } - sender_keys = iter( - sorted(sender_keys.items(), key=second, reverse=True) - ) + self.host_info[host]["addresses"].add(address) + self.host_info[host]["nthreads"] += nthreads - try: - while worker_bytes[sender] > avg: - while ( - worker_bytes[recipient] < avg - and worker_bytes[sender] > avg - ): - ts, nb = next(sender_keys) - if ts not in tasks_by_worker[recipient]: - tasks_by_worker[recipient].add(ts) - # tasks_by_worker[sender].remove(ts) - msgs.append((sender, recipient, ts)) - worker_bytes[sender] -= nb - worker_bytes[recipient] += nb - if worker_bytes[sender] > avg: - recipient = next(recipients) - except StopIteration: - break + self.total_nthreads += nthreads + self.aliases[name] = address - to_recipients = defaultdict(lambda: defaultdict(list)) - to_senders = defaultdict(list) - for sender, recipient, ts in msgs: - to_recipients[recipient.address][ts._key].append(sender.address) - 
to_senders[sender.address].append(ts._key) + response = self.heartbeat_worker( + address=address, + resolve_address=resolve_address, + now=now, + resources=resources, + host_info=host_info, + metrics=metrics, + ) - result = await asyncio.gather( - *( - retry_operation(self.rpc(addr=r).gather, who_has=v) - for r, v in to_recipients.items() - ) - ) - for r, v in to_recipients.items(): - self.log_event(r, {"action": "rebalance", "who_has": v}) + # Do not need to adjust self.total_occupancy as self.occupancy[ws] cannot exist before this. + self.check_idle_saturated(ws) - self.log_event( - "all", - { - "action": "rebalance", - "total-keys": len(tasks), - "senders": valmap(len, to_senders), - "recipients": valmap(len, to_recipients), - "moved_keys": len(msgs), - }, - ) + # for key in keys: # TODO + # self.mark_key_in_memory(key, [address]) - if not all(r["status"] == "OK" for r in result): - return { - "status": "missing-data", - "keys": tuple( - concat( - r["keys"].keys() - for r in result - if r["status"] == "missing-data" - ) - ), - } + self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop) - for sender, recipient, ts in msgs: - assert ts._state == "memory" - ts._who_has.add(recipient) - recipient.has_what.add(ts) - recipient.nbytes += ts.get_nbytes() - self.log.append( - ( - "rebalance", - ts._key, - time(), - sender.address, - recipient.address, + if ws._nthreads > len(ws._processing): + self.idle[ws._address] = ws + + for plugin in self.plugins[:]: + try: + result = plugin.add_worker(scheduler=self, worker=address) + if inspect.isawaitable(result): + await result + except Exception as e: + logger.exception(e) + + recommendations: dict + if nbytes: + for key in nbytes: + tasks: dict = self.tasks + ts: TaskState = tasks.get(key) + if ts is not None and ts._state in ("processing", "waiting"): + recommendations = self.transition( + key, + "memory", + worker=address, + nbytes=nbytes[key], + typename=types[key], ) - ) + self.transitions(recommendations) - await asyncio.gather( - *(self._delete_worker_data(r, v) for r, v in to_senders.items()) - ) + recommendations = {} + for ts in list(self.unrunnable): + valid: set = self.valid_workers(ts) + if valid is None or ws in valid: + recommendations[ts._key] = "waiting" - return {"status": "OK"} + if recommendations: + self.transitions(recommendations) - async def replicate( + self.log_event(address, {"action": "add-worker"}) + self.log_event("all", {"action": "add-worker", "worker": address}) + logger.info("Register worker %s", ws) + + msg = { + "status": "OK", + "time": time(), + "heartbeat-interval": heartbeat_interval(len(self.workers)), + "worker-plugins": self.worker_plugins, + } + + cs: ClientState + version_warning = version_module.error_message( + version_module.get_versions(), + merge( + {w: ws._versions for w, ws in self.workers.items()}, + {c: cs._versions for c, cs in self.clients.items() if cs._versions}, + ), + versions, + client_name="This Worker", + ) + msg.update(version_warning) + + if comm: + await comm.write(msg) + await self.handle_worker(comm=comm, worker=address) + + def update_graph_hlg( self, - comm=None, + client=None, + hlg=None, keys=None, - n=None, - workers=None, - branching_factor=2, - delete=True, - lock=True, + dependencies=None, + restrictions=None, + priority=None, + loose_restrictions=None, + resources=None, + submitting_task=None, + retries=None, + user_priority=0, + actors=None, + fifo_timeout=0, ): - """Replicate data throughout cluster - This performs a tree copy of the data throughout the 
network - individually on each piece of data. + dsk, dependencies, annotations = highlevelgraph_unpack(hlg) - Parameters - ---------- - keys: Iterable - list of keys to replicate - n: int - Number of replications we expect to see within the cluster - branching_factor: int, optional - The number of workers that can copy data in each generation. - The larger the branching factor, the more data we copy in - a single step, but the more a given worker risks being - swamped by data requests. + # Remove any self-dependencies (happens on test_publish_bag() and others) + for k, v in dependencies.items(): + deps = set(v) + if k in deps: + deps.remove(k) + dependencies[k] = deps - See also - -------- - Scheduler.rebalance - """ - ws: WorkerState - wws: WorkerState - ts: TaskState + if priority is None: + # Removing all non-local keys before calling order() + dsk_keys = set(dsk) # intersection() of sets is much faster than dict_keys + stripped_deps = { + k: v.intersection(dsk_keys) + for k, v in dependencies.items() + if k in dsk_keys + } + priority = dask.order.order(dsk, dependencies=stripped_deps) - assert branching_factor > 0 - async with self._lock if lock else empty_context: - workers = {self.workers[w] for w in self.workers_list(workers)} - if n is None: - n = len(workers) - else: - n = min(n, len(workers)) - if n == 0: - raise ValueError("Can not use replicate to delete data") + return self.update_graph( + client, + dsk, + keys, + dependencies, + restrictions, + priority, + loose_restrictions, + resources, + submitting_task, + retries, + user_priority, + actors, + fifo_timeout, + annotations, + ) - tasks = {self.tasks[k] for k in keys} - missing_data = [ts._key for ts in tasks if not ts._who_has] - if missing_data: - return {"status": "missing-data", "keys": missing_data} + def update_graph( + self, + client=None, + tasks=None, + keys=None, + dependencies=None, + restrictions=None, + priority=None, + loose_restrictions=None, + resources=None, + submitting_task=None, + retries=None, + user_priority=0, + actors=None, + fifo_timeout=0, + annotations=None, + ): + """ + Add new computations to the internal dask graph - # Delete extraneous data - if delete: - del_worker_tasks = defaultdict(set) - for ts in tasks: - del_candidates = ts._who_has & workers - if len(del_candidates) > n: - for ws in random.sample( - del_candidates, len(del_candidates) - n - ): - del_worker_tasks[ws].add(ts) + This happens whenever the Client calls submit, map, get, or compute. 
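+
+        Missing dependencies cause the affected keys to be cancelled, work
+        that is already in memory is pruned, and the remaining keys get
+        ``TaskState`` objects created for them, with priorities and any
+        per-task annotations (priority, workers, allow_other_workers,
+        retries) applied.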
+ """ + start = time() + fifo_timeout = parse_timedelta(fifo_timeout) + keys = set(keys) + if len(tasks) > 1: + self.log_event( + ["all", client], {"action": "update_graph", "count": len(tasks)} + ) - await asyncio.gather( - *[ - self._delete_worker_data(ws._address, [t.key for t in tasks]) - for ws, tasks in del_worker_tasks.items() - ] - ) + # Remove aliases + for k in list(tasks): + if tasks[k] is k: + del tasks[k] - # Copy not-yet-filled data - while tasks: - gathers = defaultdict(dict) - for ts in list(tasks): - if ts._state == "forgotten": - # task is no longer needed by any client or dependant task - tasks.remove(ts) - continue - n_missing = n - len(ts._who_has & workers) - if n_missing <= 0: - # Already replicated enough - tasks.remove(ts) - continue + dependencies = dependencies or {} - count = min(n_missing, branching_factor * len(ts._who_has)) - assert count > 0 + n = 0 + while len(tasks) != n: # walk through new tasks, cancel any bad deps + n = len(tasks) + for k, deps in list(dependencies.items()): + if any( + dep not in self.tasks and dep not in tasks for dep in deps + ): # bad key + logger.info("User asked for computation on lost data, %s", k) + del tasks[k] + del dependencies[k] + if k in keys: + keys.remove(k) + self.report({"op": "cancelled-key", "key": k}, client=client) + self.client_releases_keys(keys=[k], client=client) - for ws in random.sample(workers - ts._who_has, count): - gathers[ws._address][ts._key] = [ - wws._address for wws in ts._who_has - ] + # Avoid computation that is already finished + ts: TaskState + already_in_memory = set() # tasks that are already done + for k, v in dependencies.items(): + if v and k in self.tasks: + ts = self.tasks[k] + if ts._state in ("memory", "erred"): + already_in_memory.add(k) - results = await asyncio.gather( - *( - retry_operation(self.rpc(addr=w).gather, who_has=who_has) - for w, who_has in gathers.items() - ) - ) - for w, v in zip(gathers, results): - if v["status"] == "OK": - self.add_keys(worker=w, keys=list(gathers[w])) + dts: TaskState + if already_in_memory: + dependents = dask.core.reverse_dict(dependencies) + stack = list(already_in_memory) + done = set(already_in_memory) + while stack: # remove unnecessary dependencies + key = stack.pop() + ts = self.tasks[key] + try: + deps = dependencies[key] + except KeyError: + deps = self.dependencies[key] + for dep in deps: + if dep in dependents: + child_deps = dependents[dep] else: - logger.warning("Communication failed during replication: %s", v) + child_deps = self.dependencies[dep] + if all(d in done for d in child_deps): + if dep in self.tasks and dep not in done: + done.add(dep) + stack.append(dep) - self.log_event(w, {"action": "replicate-add", "keys": gathers[w]}) + for d in done: + tasks.pop(d, None) + dependencies.pop(d, None) - self.log_event( - "all", - { - "action": "replicate", - "workers": list(workers), - "key-count": len(keys), - "branching-factor": branching_factor, - }, - ) + # Get or create task states + stack = list(keys) + touched_keys = set() + touched_tasks = [] + while stack: + k = stack.pop() + if k in touched_keys: + continue + # XXX Have a method get_task_state(self, k) ? 
+ ts = self.tasks.get(k) + if ts is None: + ts = self.new_task(k, tasks.get(k), "released") + elif not ts._run_spec: + ts._run_spec = tasks.get(k) - def workers_to_close( - self, - comm=None, - memory_ratio=None, - n=None, - key=None, - minimum=None, - target=None, - attribute="address", - ): - """ - Find workers that we can close with low cost + touched_keys.add(k) + touched_tasks.append(ts) + stack.extend(dependencies.get(k, ())) - This returns a list of workers that are good candidates to retire. - These workers are not running anything and are storing - relatively little data relative to their peers. If all workers are - idle then we still maintain enough workers to have enough RAM to store - our data, with a comfortable buffer. + self.client_desires_keys(keys=keys, client=client) - This is for use with systems like ``distributed.deploy.adaptive``. + # Add dependencies + for key, deps in dependencies.items(): + ts = self.tasks.get(key) + if ts is None or ts._dependencies: + continue + for dep in deps: + dts = self.tasks[dep] + ts.add_dependency(dts) - Parameters - ---------- - memory_factor: Number - Amount of extra space we want to have for our stored data. - Defaults two 2, or that we want to have twice as much memory as we - currently have data. - n: int - Number of workers to close - minimum: int - Minimum number of workers to keep around - key: Callable(WorkerState) - An optional callable mapping a WorkerState object to a group - affiliation. Groups will be closed together. This is useful when - closing workers must be done collectively, such as by hostname. - target: int - Target number of workers to have after we close - attribute : str - The attribute of the WorkerState object to return, like "address" - or "name". Defaults to "address". + # Compute priorities + if isinstance(user_priority, Number): + user_priority = {k: user_priority for k in tasks} - Examples - -------- - >>> scheduler.workers_to_close() - ['tcp://192.168.0.1:1234', 'tcp://192.168.0.2:1234'] + annotations = annotations or {} + restrictions = restrictions or {} + loose_restrictions = loose_restrictions or [] + resources = resources or {} + retries = retries or {} - Group workers by hostname prior to closing + # Override existing taxonomy with per task annotations + if annotations: + if "priority" in annotations: + user_priority.update(annotations["priority"]) - >>> scheduler.workers_to_close(key=lambda ws: ws.host) - ['tcp://192.168.0.1:1234', 'tcp://192.168.0.1:4567'] + if "workers" in annotations: + restrictions.update(annotations["workers"]) - Remove two workers + if "allow_other_workers" in annotations: + loose_restrictions.extend( + k for k, v in annotations["allow_other_workers"].items() if v + ) - >>> scheduler.workers_to_close(n=2) + if "retries" in annotations: + retries.update(annotations["retries"]) - Keep enough workers to have twice as much memory as we we need. 
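# --- Illustrative aside (plain dicts, hypothetical keys; not part of the
# patch): how the per-task annotations handled around this point override the
# graph-wide defaults, and how the composite priority tuple is then built --
# user priority is negated so that higher values sort (and therefore run)
# earlier, followed by the submission generation and the within-graph order.
user_priority = {"x": 0, "y": 0}
retries = {}
annotations = {"priority": {"x": 10}, "retries": {"y": 3}}

user_priority.update(annotations.get("priority", {}))
retries.update(annotations.get("retries", {}))

generation, graph_order = 7, {"x": 1, "y": 0}
priority = {k: (-user_priority[k], generation, graph_order[k]) for k in ("x", "y")}
assert priority["x"] < priority["y"]   # the annotated priority makes "x" run first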
+ if "resources" in annotations: + resources.update(annotations["resources"]) - >>> scheduler.workers_to_close(memory_ratio=2) + for a, kv in annotations.items(): + for k, v in kv.items(): + ts = self.tasks[k] + ts._annotations[a] = v - Returns - ------- - to_close: list of worker addresses that are OK to close + # Add actors + if actors is True: + actors = list(keys) + for actor in actors or []: + ts = self.tasks[actor] + ts._actor = True - See Also - -------- - Scheduler.retire_workers - """ - if target is not None and n is None: - n = len(self.workers) - target - if n is not None: - if n < 0: - n = 0 - target = len(self.workers) - n + priority = priority or dask.order.order( + tasks + ) # TODO: define order wrt old graph - if n is None and memory_ratio is None: - memory_ratio = 2 - - ws: WorkerState - with log_errors(): - if not n and all([ws._processing for ws in self.workers.values()]): - return [] + if submitting_task: # sub-tasks get better priority than parent tasks + ts = self.tasks.get(submitting_task) + if ts is not None: + generation = ts._priority[0] - 0.01 + else: # super-task already cleaned up + generation = self.generation + elif self._last_time + fifo_timeout < start: + self.generation += 1 # older graph generations take precedence + generation = self.generation + self._last_time = start + else: + generation = self.generation - if key is None: - key = operator.attrgetter("address") - if isinstance(key, bytes) and dask.config.get( - "distributed.scheduler.pickle" - ): - key = pickle.loads(key) + for key in set(priority) & touched_keys: + ts = self.tasks[key] + if ts._priority is None: + ts._priority = (-(user_priority.get(key, 0)), generation, priority[key]) - groups = groupby(key, self.workers.values()) + # Ensure all runnables have a priority + runnables = [ts for ts in touched_tasks if ts._run_spec] + for ts in runnables: + if ts._priority is None and ts._run_spec: + ts._priority = (self.generation, 0) - limit_bytes = { - k: sum([ws._memory_limit for ws in v]) for k, v in groups.items() - } - group_bytes = {k: sum([ws._nbytes for ws in v]) for k, v in groups.items()} + if restrictions: + # *restrictions* is a dict keying task ids to lists of + # restriction specifications (either worker names or addresses) + for k, v in restrictions.items(): + if v is None: + continue + ts = self.tasks.get(k) + if ts is None: + continue + ts._host_restrictions = set() + ts._worker_restrictions = set() + for w in v: + try: + w = self.coerce_address(w) + except ValueError: + # Not a valid address, but perhaps it's a hostname + ts._host_restrictions.add(w) + else: + ts._worker_restrictions.add(w) - limit = sum(limit_bytes.values()) - total = sum(group_bytes.values()) + if loose_restrictions: + for k in loose_restrictions: + ts = self.tasks[k] + ts._loose_restrictions = True - def _key(group): - wws: WorkerState - is_idle = not any([wws._processing for wws in groups[group]]) - bytes = -group_bytes[group] - return (is_idle, bytes) + if resources: + for k, v in resources.items(): + if v is None: + continue + assert isinstance(v, dict) + ts = self.tasks.get(k) + if ts is None: + continue + ts._resource_restrictions = v - idle = sorted(groups, key=_key) + if retries: + for k, v in retries.items(): + assert isinstance(v, int) + ts = self.tasks.get(k) + if ts is None: + continue + ts._retries = v - to_close = [] - n_remain = len(self.workers) + # Compute recommendations + recommendations: dict = {} - while idle: - group = idle.pop() - if n is None and any([ws._processing for ws in 
groups[group]]): - break + for ts in sorted(runnables, key=operator.attrgetter("priority"), reverse=True): + if ts._state == "released" and ts._run_spec: + recommendations[ts._key] = "waiting" - if minimum and n_remain - len(groups[group]) < minimum: + for ts in touched_tasks: + for dts in ts._dependencies: + if dts._exception_blame: + ts._exception_blame = dts._exception_blame + recommendations[ts._key] = "erred" break - limit -= limit_bytes[group] + for plugin in self.plugins[:]: + try: + plugin.update_graph( + self, + client=client, + tasks=tasks, + keys=keys, + restrictions=restrictions or {}, + dependencies=dependencies, + priority=priority, + loose_restrictions=loose_restrictions, + resources=resources, + annotations=annotations, + ) + except Exception as e: + logger.exception(e) - if (n is not None and n_remain - len(groups[group]) >= target) or ( - memory_ratio is not None and limit >= memory_ratio * total - ): - to_close.append(group) - n_remain -= len(groups[group]) + self.transitions(recommendations) - else: - break + for ts in touched_tasks: + if ts._state in ("memory", "erred"): + self.report_on_key(ts=ts, client=client) - result = [getattr(ws, attribute) for g in to_close for ws in groups[g]] - if result: - logger.debug("Suggest closing workers: %s", result) + end = time() + if self.digests is not None: + self.digests["update-graph-duration"].add(end - start) - return result + # TODO: balance workers - async def retire_workers( - self, - comm=None, - workers=None, - remove=True, - close_workers=False, - names=None, - lock=True, - **kwargs, - ) -> dict: - """Gracefully retire workers from cluster + def new_task(self, key, spec, state): + """ Create a new task, and associated states """ + ts: TaskState = TaskState(key, spec) + tp: TaskPrefix + tg: TaskGroup + ts._state = state + prefix_key = key_split(key) + try: + tp = self.task_prefixes[prefix_key] + except KeyError: + self.task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key) + ts._prefix = tp - Parameters - ---------- - workers: list (optional) - List of worker addresses to retire. - If not provided we call ``workers_to_close`` which finds a good set - workers_names: list (optional) - List of worker names to retire. - remove: bool (defaults to True) - Whether or not to remove the worker metadata immediately or else - wait for the worker to contact us - close_workers: bool (defaults to False) - Whether or not to actually close the worker explicitly from here. - Otherwise we expect some external job scheduler to finish off the - worker. - **kwargs: dict - Extra options to pass to workers_to_close to determine which - workers we should drop + group_key = ts._group_key + try: + tg = self.task_groups[group_key] + except KeyError: + self.task_groups[group_key] = tg = TaskGroup(group_key) + tg._prefix = tp + tp._groups.append(tg) + tg.add(ts) + self.tasks[key] = ts + return ts - Returns - ------- - Dictionary mapping worker ID/address to dictionary of information about - that worker for each retired worker. 
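# --- Illustrative aside (standalone toy data, not the scheduler API): the
# candidate ordering used by workers_to_close above -- workers are grouped
# (here by host), and groups that are idle and hold the least data are the
# first to be proposed for closing.
from collections import defaultdict

workers = [
    {"host": "a", "processing": 0, "nbytes": 100},
    {"host": "a", "processing": 3, "nbytes": 500},
    {"host": "b", "processing": 0, "nbytes": 50},
]
groups = defaultdict(list)
for w in workers:
    groups[w["host"]].append(w)

def close_order(host):
    members = groups[host]
    is_idle = not any(w["processing"] for w in members)
    return (is_idle, -sum(w["nbytes"] for w in members))

# Groups at the end of the sorted list are popped (i.e. closed) first.
print(sorted(groups, key=close_order))  # ['a', 'b'] -> host "b" is retired first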
+ def stimulus_task_finished(self, key=None, worker=None, **kwargs): + """ Mark that a task has finished execution on a particular worker """ + logger.debug("Stimulus task finished %s, %s", key, worker) - See Also - -------- - Scheduler.workers_to_close - """ - ws: WorkerState - ts: TaskState - with log_errors(): - async with self._lock if lock else empty_context: - if names is not None: - if names: - logger.info("Retire worker names %s", names) - names = set(map(str, names)) - workers = [ - ws._address - for ws in self.workers.values() - if str(ws._name) in names - ] - if workers is None: - while True: - try: - workers = self.workers_to_close(**kwargs) - if workers: - workers = await self.retire_workers( - workers=workers, - remove=remove, - close_workers=close_workers, - lock=False, - ) - return workers - else: - return {} - except KeyError: # keys left during replicate - pass - workers = {self.workers[w] for w in workers if w in self.workers} - if not workers: - return {} - logger.info("Retire workers %s", workers) + tasks: dict = self.tasks + ts: TaskState = tasks.get(key) + if ts is None: + return {} + workers: dict = cast(dict, self.workers) + ws: WorkerState = workers[worker] + ts._metadata.update(kwargs["metadata"]) - # Keys orphaned by retiring those workers - keys = set.union(*[w.has_what for w in workers]) - keys = {ts._key for ts in keys if ts._who_has.issubset(workers)} + recommendations: dict + if ts._state == "processing": + recommendations = self.transition(key, "memory", worker=worker, **kwargs) - other_workers = set(self.workers.values()) - workers - if keys: - if other_workers: - logger.info("Moving %d keys to other workers", len(keys)) - await self.replicate( - keys=keys, - workers=[ws._address for ws in other_workers], - n=1, - delete=False, - lock=False, - ) - else: - return {} - - worker_keys = {ws._address: ws.identity() for ws in workers} - if close_workers and worker_keys: - await asyncio.gather( - *[self.close_worker(worker=w, safe=True) for w in worker_keys] - ) - if remove: - await asyncio.gather( - *[self.remove_worker(address=w, safe=True) for w in worker_keys] - ) + if ts._state == "memory": + assert ws in ts._who_has + else: + logger.debug( + "Received already computed task, worker: %s, state: %s" + ", key: %s, who_has: %s", + worker, + ts._state, + key, + ts._who_has, + ) + if ws not in ts._who_has: + self.worker_send(worker, {"op": "release-task", "key": key}) + recommendations = {} - self.log_event( - "all", - { - "action": "retire-workers", - "workers": worker_keys, - "moved-keys": len(keys), - }, - ) - self.log_event(list(worker_keys), {"action": "retired"}) + return recommendations - return worker_keys + def stimulus_task_erred( + self, key=None, worker=None, exception=None, traceback=None, **kwargs + ): + """ Mark that a task has erred on a particular worker """ + logger.debug("Stimulus task erred %s, %s", key, worker) - def add_keys(self, comm=None, worker=None, keys=()): - """ - Learn that a worker has certain keys + ts: TaskState = self.tasks.get(key) + if ts is None: + return {} - This should not be used in practice and is mostly here for legacy - reasons. However, it is sent by workers from time to time. 
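# --- Illustrative aside (a pure function, not the TaskState machinery): the
# retry bookkeeping applied by stimulus_task_erred, whose body follows below --
# a remaining retry sends the task back to "waiting", otherwise it is marked
# "erred"; reports about tasks that are no longer processing are ignored.
def next_state_on_error(state, retries):
    if state != "processing":
        return state, retries           # stale report from a worker; ignore it
    if retries > 0:
        return "waiting", retries - 1   # consume one retry and reschedule
    return "erred", retries

assert next_state_on_error("processing", 2) == ("waiting", 1)
assert next_state_on_error("processing", 0) == ("erred", 0)
assert next_state_on_error("memory", 0) == ("memory", 0)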
- """ - if worker not in self.workers: - return "not found" - ws: WorkerState = self.workers[worker] - for key in keys: - ts: TaskState = self.tasks.get(key) - if ts is not None and ts._state == "memory": - if ts not in ws._has_what: - ws._nbytes += ts.get_nbytes() - ws._has_what.add(ts) - ts._who_has.add(ws) + recommendations: dict + if ts._state == "processing": + retries = ts._retries + if retries > 0: + ts._retries = retries - 1 + recommendations = self.transition(key, "waiting") else: - self.worker_send( - worker, {"op": "delete-data", "keys": [key], "report": False} + recommendations = self.transition( + key, + "erred", + cause=key, + exception=exception, + traceback=traceback, + worker=worker, + **kwargs, ) + else: + recommendations = {} - return "OK" + return recommendations - def update_data( - self, comm=None, who_has=None, nbytes=None, client=None, serializers=None + def stimulus_missing_data( + self, cause=None, key=None, worker=None, ensure=True, **kwargs ): - """ - Learn that new data has entered the network from an external source - - See Also - -------- - Scheduler.mark_key_in_memory - """ + """ Mark that certain keys have gone missing. Recover. """ with log_errors(): - who_has = { - k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items() - } - logger.debug("Update data %s", who_has) + logger.debug("Stimulus missing data %s, %s", key, worker) - for key, workers in who_has.items(): - ts: TaskState = self.tasks.get(key) - if ts is None: - ts: TaskState = self.new_task(key, None, "memory") - ts.state = "memory" - if key in nbytes: - ts.set_nbytes(nbytes[key]) - for w in workers: - ws: WorkerState = self.workers[w] - if ts not in ws._has_what: - ws._nbytes += ts.get_nbytes() - ws._has_what.add(ts) - ts._who_has.add(ws) - self.report( - {"op": "key-in-memory", "key": key, "workers": list(workers)} - ) + ts: TaskState = self.tasks.get(key) + if ts is None or ts._state == "memory": + return {} + cts: TaskState = self.tasks.get(cause) - if client: - self.client_desires_keys(keys=list(who_has), client=client) + recommendations: dict = {} - def _task_to_report_msg(self, ts: TaskState) -> dict: - if ts is None: - return {"op": "cancelled-key", "key": ts._key} - elif ts._state == "forgotten": - return {"op": "cancelled-key", "key": ts._key} - elif ts._state == "memory": - return {"op": "key-in-memory", "key": ts._key} - elif ts._state == "erred": - failing_ts: TaskState = ts._exception_blame - return { - "op": "task-erred", - "key": ts._key, - "exception": failing_ts._exception, - "traceback": failing_ts._traceback, - } - else: - return None + if cts is not None and cts._state == "memory": # couldn't find this + ws: WorkerState + for ws in cts._who_has: # TODO: this behavior is extreme + ws._has_what.remove(cts) + ws._nbytes -= cts.get_nbytes() + cts._who_has.clear() + recommendations[cause] = "released" - def _task_to_client_msgs(self, ts: TaskState) -> dict: - cs: ClientState - clients: dict = self.clients - client_keys: list - if ts is None: - # Notify all clients - client_keys = list(clients) - else: - # Notify clients interested in key - client_keys = [cs._client_key for cs in ts._who_wants] + if key: + recommendations[key] = "released" - report_msg: dict = self._task_to_report_msg(ts) + self.transitions(recommendations) - client_msgs: dict = {} - for k in client_keys: - client_msgs[k] = report_msg + if self.validate: + assert cause not in self.who_has - return client_msgs + return {} - def report_on_key(self, key: str = None, ts: TaskState = None, client: str = 
None): - if ts is None: - tasks: dict = self.tasks - ts = tasks.get(key) - elif key is None: - key = ts._key - else: - assert False, (key, ts) - return + def stimulus_retry(self, comm=None, keys=None, client=None): + logger.info("Client %s requests to retry %d keys", client, len(keys)) + if client: + self.log_event(client, {"action": "retry", "count": len(keys)}) - report_msg: dict = self._task_to_report_msg(ts) - if report_msg is not None: - self.report(report_msg, ts=ts, client=client) + stack = list(keys) + seen = set() + roots = [] + ts: TaskState + dts: TaskState + while stack: + key = stack.pop() + seen.add(key) + ts = self.tasks[key] + erred_deps = [dts._key for dts in ts._dependencies if dts._state == "erred"] + if erred_deps: + stack.extend(erred_deps) + else: + roots.append(key) - async def feed( - self, comm, function=None, setup=None, teardown=None, interval="1s", **kwargs - ): - """ - Provides a data Comm to external requester + recommendations: dict = {key: "waiting" for key in roots} + self.transitions(recommendations) - Caution: this runs arbitrary Python code on the scheduler. This should - eventually be phased out. It is mostly used by diagnostics. + if self.validate: + for key in seen: + assert not self.tasks[key].exception_blame + + return tuple(seen) + + async def remove_worker(self, comm=None, address=None, safe=False, close=True): """ - if not dask.config.get("distributed.scheduler.pickle"): - logger.warn( - "Tried to call 'feed' route with custom functions, but " - "pickle is disallowed. Set the 'distributed.scheduler.pickle'" - "config value to True to use the 'feed' route (this is mostly " - "commonly used with progress bars)" - ) - return + Remove worker from cluster - interval = parse_timedelta(interval) + We do this when a worker reports that it plans to leave or when it + appears to be unresponsive. This may send its tasks back to a released + state. 
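# --- Illustrative aside (plain dicts instead of TaskState; not part of the
# patch): how the retry walk in stimulus_retry above descends through erred
# dependencies and only resubmits the deepest erred tasks ("roots");
# everything above them is recomputed naturally once the roots succeed.
def find_retry_roots(keys, deps, state):
    stack, roots = list(keys), []
    while stack:
        key = stack.pop()
        erred_deps = [d for d in deps.get(key, ()) if state[d] == "erred"]
        if erred_deps:
            stack.extend(erred_deps)    # keep digging toward the original failure
        else:
            roots.append(key)           # nothing below this failed: resubmit here
    return roots

state = {"sum": "erred", "load-1": "erred", "load-2": "memory"}
deps = {"sum": ["load-1", "load-2"], "load-1": [], "load-2": []}
print(find_retry_roots(["sum"], deps, state))  # ['load-1']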
+ """ with log_errors(): - if function: - function = pickle.loads(function) - if setup: - setup = pickle.loads(setup) - if teardown: - teardown = pickle.loads(teardown) - state = setup(self) if setup else None - if inspect.isawaitable(state): - state = await state - try: - while self.status == Status.running: - if state is None: - response = function(self) - else: - response = function(self, state) - await comm.write(response) - await asyncio.sleep(interval) - except (EnvironmentError, CommClosedError): - pass - finally: - if teardown: - teardown(self, state) + if self.status == Status.closed: + return - def log_worker_event(self, worker=None, topic=None, msg=None): - self.log_event(topic, msg) + address = self.coerce_address(address) - def subscribe_worker_status(self, comm=None): - WorkerStatusPlugin(self, comm) - ident = self.identity() - for v in ident["workers"].values(): - del v["metrics"] - del v["last_seen"] - return ident + if address not in self.workers: + return "already-removed" - def get_processing(self, comm=None, workers=None): - ws: WorkerState - ts: TaskState - if workers is not None: - workers = set(map(self.coerce_address, workers)) - return {w: [ts._key for ts in self.workers[w].processing] for w in workers} - else: - return { - w: [ts._key for ts in ws._processing] for w, ws in self.workers.items() - } + host = get_address_host(address) - def get_who_has(self, comm=None, keys=None): - ws: WorkerState - ts: TaskState - if keys is not None: - return { - k: [ws._address for ws in self.tasks[k].who_has] - if k in self.tasks - else [] - for k in keys - } - else: - return { - key: [ws._address for ws in ts._who_has] - for key, ts in self.tasks.items() - } + ws: WorkerState = self.workers[address] - def get_has_what(self, comm=None, workers=None): - ws: WorkerState + self.log_event( + ["all", address], + { + "action": "remove-worker", + "worker": address, + "processing-tasks": dict(ws._processing), + }, + ) + logger.info("Remove worker %s", ws) + if close: + with suppress(AttributeError, CommClosedError): + self.stream_comms[address].send({"op": "close", "report": False}) + + self.remove_resources(address) + + self.host_info[host]["nthreads"] -= ws._nthreads + self.host_info[host]["addresses"].remove(address) + self.total_nthreads -= ws._nthreads + + if not self.host_info[host]["addresses"]: + del self.host_info[host] + + self.rpc.remove(address) + del self.stream_comms[address] + del self.aliases[ws._name] + self.idle.pop(ws._address, None) + self.saturated.discard(ws) + del self.workers[address] + ws.status = Status.closed + self.total_occupancy -= ws._occupancy + + recommendations: dict = {} + + ts: TaskState + for ts in list(ws._processing): + k = ts._key + recommendations[k] = "released" + if not safe: + ts._suspicious += 1 + ts._prefix._suspicious += 1 + if ts._suspicious > self.allowed_failures: + del recommendations[k] + e = pickle.dumps( + KilledWorker(task=k, last_worker=ws.clean()), protocol=4 + ) + r = self.transition(k, "erred", exception=e, cause=k) + recommendations.update(r) + logger.info( + "Task %s marked as failed because %d workers died" + " while trying to run it", + ts._key, + self.allowed_failures, + ) + + for ts in ws._has_what: + ts._who_has.remove(ws) + if not ts._who_has: + if ts._run_spec: + recommendations[ts._key] = "released" + else: # pure data + recommendations[ts._key] = "forgotten" + ws._has_what.clear() + + self.transitions(recommendations) + + for plugin in self.plugins[:]: + try: + result = plugin.remove_worker(scheduler=self, 
worker=address) + if inspect.isawaitable(result): + await result + except Exception as e: + logger.exception(e) + + if not self.workers: + logger.info("Lost all workers") + + for w in self.workers: + self.bandwidth_workers.pop((address, w), None) + self.bandwidth_workers.pop((w, address), None) + + def remove_worker_from_events(): + # If the worker isn't registered anymore after the delay, remove from events + if address not in self.workers and address in self.events: + del self.events[address] + + cleanup_delay = parse_timedelta( + dask.config.get("distributed.scheduler.events-cleanup-delay") + ) + self.loop.call_later(cleanup_delay, remove_worker_from_events) + logger.debug("Removed worker %s", ws) + + return "OK" + + def stimulus_cancel(self, comm, keys=None, client=None, force=False): + """ Stop execution on a list of keys """ + logger.info("Client %s requests to cancel %d keys", client, len(keys)) + if client: + self.log_event( + client, {"action": "cancel", "count": len(keys), "force": force} + ) + for key in keys: + self.cancel_key(key, client, force=force) + + def cancel_key(self, key, client, retries=5, force=False): + """ Cancel a particular key and all dependents """ + # TODO: this should be converted to use the transition mechanism + ts: TaskState = self.tasks.get(key) + dts: TaskState + try: + cs: ClientState = self.clients[client] + except KeyError: + return + if ts is None or not ts._who_wants: # no key yet, lets try again in a moment + if retries: + self.loop.call_later( + 0.2, lambda: self.cancel_key(key, client, retries - 1) + ) + return + if force or ts._who_wants == {cs}: # no one else wants this key + for dts in list(ts._dependents): + self.cancel_key(dts._key, client, force=force) + logger.info("Scheduler cancels key %s. Force=%s", key, force) + self.report({"op": "cancelled-key", "key": key}) + clients = list(ts._who_wants) if force else [cs] + for cs in clients: + self.client_releases_keys(keys=[key], client=cs._client_key) + + def client_desires_keys(self, keys=None, client=None): + cs: ClientState = self.clients.get(client) + if cs is None: + # For publish, queues etc. + self.clients[client] = cs = ClientState(client) ts: TaskState - if workers is not None: - workers = map(self.coerce_address, workers) - return { - w: [ts._key for ts in self.workers[w].has_what] - if w in self.workers - else [] - for w in workers - } - else: - return { - w: [ts._key for ts in ws._has_what] for w, ws in self.workers.items() - } + for k in keys: + ts = self.tasks.get(k) + if ts is None: + # For publish, queues etc. 
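# --- Illustrative aside (hypothetical counter values, not the scheduler
# internals): the failure accounting in remove_worker above.  Each time a
# worker dies while a task is processing on it, that task's "suspicious"
# counter grows; once it exceeds the configured allowed-failures threshold the
# task is marked erred with a KilledWorker exception instead of rescheduled.
ALLOWED_FAILURES = 3   # placeholder for the configured allowed-failures value

def recommend_after_worker_death(suspicious):
    suspicious += 1
    return ("erred" if suspicious > ALLOWED_FAILURES else "released"), suspicious

assert recommend_after_worker_death(0) == ("released", 1)  # rescheduled elsewhere
assert recommend_after_worker_death(3) == ("erred", 4)     # too many dead workers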
+ ts = self.new_task(k, None, "released") + ts._who_wants.add(cs) + cs._wants_what.add(ts) - def get_ncores(self, comm=None, workers=None): - ws: WorkerState - if workers is not None: - workers = map(self.coerce_address, workers) - return {w: self.workers[w].nthreads for w in workers if w in self.workers} - else: - return {w: ws._nthreads for w, ws in self.workers.items()} + if ts._state in ("memory", "erred"): + self.report_on_key(ts=ts, client=client) + + def client_releases_keys(self, keys=None, client=None): + """ Remove keys from client desired list """ + + if not isinstance(keys, list): + keys = list(keys) + cs: ClientState = self.clients[client] + recommendations: dict = {} + + self._client_releases_keys(keys=keys, cs=cs, recommendations=recommendations) + self.transitions(recommendations) + + def client_heartbeat(self, client=None): + """ Handle heartbeats from Client """ + cs: ClientState = self.clients[client] + cs._last_seen = time() + + ################### + # Task Validation # + ################### + + def validate_released(self, key): + ts: TaskState = self.tasks[key] + dts: TaskState + assert ts._state == "released" + assert not ts._waiters + assert not ts._waiting_on + assert not ts._who_has + assert not ts._processing_on + assert not any([ts in dts._waiters for dts in ts._dependencies]) + assert ts not in self.unrunnable + + def validate_waiting(self, key): + ts: TaskState = self.tasks[key] + dts: TaskState + assert ts._waiting_on + assert not ts._who_has + assert not ts._processing_on + assert ts not in self.unrunnable + for dts in ts._dependencies: + # We are waiting on a dependency iff it's not stored + assert (not not dts._who_has) != (dts in ts._waiting_on) + assert ts in dts._waiters # XXX even if dts._who_has? + + def validate_processing(self, key): + ts: TaskState = self.tasks[key] + dts: TaskState + assert not ts._waiting_on + ws: WorkerState = ts._processing_on + assert ws + assert ts in ws._processing + assert not ts._who_has + for dts in ts._dependencies: + assert dts._who_has + assert ts in dts._waiters + + def validate_memory(self, key): + ts: TaskState = self.tasks[key] + dts: TaskState + assert ts._who_has + assert not ts._processing_on + assert not ts._waiting_on + assert ts not in self.unrunnable + for dts in ts._dependents: + assert (dts in ts._waiters) == (dts._state in ("waiting", "processing")) + assert ts not in dts._waiting_on + + def validate_no_worker(self, key): + ts: TaskState = self.tasks[key] + dts: TaskState + assert ts in self.unrunnable + assert not ts._waiting_on + assert ts in self.unrunnable + assert not ts._processing_on + assert not ts._who_has + for dts in ts._dependencies: + assert dts._who_has + + def validate_erred(self, key): + ts: TaskState = self.tasks[key] + assert ts._exception_blame + assert not ts._who_has + + def validate_key(self, key, ts: TaskState = None): + try: + if ts is None: + ts = self.tasks.get(key) + if ts is None: + logger.debug("Key lost: %s", key) + else: + ts.validate() + try: + func = getattr(self, "validate_" + ts._state.replace("-", "_")) + except AttributeError: + logger.error( + "self.validate_%s not found", ts._state.replace("-", "_") + ) + else: + func(key) + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - async def get_call_stack(self, comm=None, keys=None): - ts: TaskState - dts: TaskState - if keys is not None: - stack = list(keys) - processing = set() - while stack: - key = stack.pop() - ts = self.tasks[key] - if ts._state == "waiting": - stack.extend([dts._key for dts in 
ts._dependencies]) - elif ts._state == "processing": - processing.add(ts) + pdb.set_trace() + raise - workers = defaultdict(list) - for ts in processing: - if ts._processing_on: - workers[ts._processing_on.address].append(ts._key) - else: - workers = {w: None for w in self.workers} + def validate_state(self, allow_overlap=False): + validate_state(self.tasks, self.workers, self.clients) - if not workers: - return {} + if not (set(self.workers) == set(self.stream_comms)): + raise ValueError("Workers not the same in all collections") - results = await asyncio.gather( - *(self.rpc(w).call_stack(keys=v) for w, v in workers.items()) - ) - response = {w: r for w, r in zip(workers, results) if r} - return response + ws: WorkerState + for w, ws in self.workers.items(): + assert isinstance(w, str), (type(w), w) + assert isinstance(ws, WorkerState), (type(ws), ws) + assert ws._address == w + if not ws._processing: + assert not ws._occupancy + assert ws._address in cast(dict, self.idle) - def get_nbytes(self, comm=None, keys=None, summary=True): ts: TaskState - with log_errors(): - if keys is not None: - result = {k: self.tasks[k].nbytes for k in keys} - else: - result = { - k: ts._nbytes for k, ts in self.tasks.items() if ts._nbytes >= 0 - } - - if summary: - out = defaultdict(lambda: 0) - for k, v in result.items(): - out[key_split(k)] += v - result = dict(out) + for k, ts in self.tasks.items(): + assert isinstance(ts, TaskState), (type(ts), ts) + assert ts._key == k + self.validate_key(k, ts) - return result + c: str + cs: ClientState + for c, cs in self.clients.items(): + # client=None is often used in tests... + assert c is None or type(c) == str, (type(c), c) + assert type(cs) == ClientState, (type(cs), cs) + assert cs._client_key == c - def get_comm_cost(self, ts: TaskState, ws: WorkerState): - """ - Get the estimated communication cost (in s.) to compute the task - on the given worker. - """ - dts: TaskState - deps: set = ts._dependencies - ws._has_what - nbytes: Py_ssize_t = 0 - bandwidth: double = self.bandwidth - for dts in deps: - nbytes += dts._nbytes - return nbytes / bandwidth + a = {w: ws._nbytes for w, ws in self.workers.items()} + b = { + w: sum(ts.get_nbytes() for ts in ws._has_what) + for w, ws in self.workers.items() + } + assert a == b, (a, b) - def get_task_duration(self, ts: TaskState, default: double = -1): - """ - Get the estimated computation cost of the given task - (not including any communication cost). 
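# --- Illustrative aside (plain numbers instead of TaskState/WorkerState): the
# communication-cost estimate computed by get_comm_cost above -- only the
# dependencies the candidate worker does not already hold contribute their
# byte size, divided by the measured bandwidth.
def comm_cost(dep_nbytes, on_worker, bandwidth):
    missing = set(dep_nbytes) - on_worker
    return sum(dep_nbytes[k] for k in missing) / bandwidth

# 200 MB still has to move at 100 MB/s -> roughly two seconds of transfer.
print(comm_cost({"x": 100e6, "y": 100e6, "z": 50e6}, {"z"}, bandwidth=100e6))  # 2.0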
- """ - duration: double = ts._prefix._duration_average - if duration < 0: - s: set = self.unknown_durations[ts._prefix._name] - s.add(ts) - if default < 0: - duration = UNKNOWN_TASK_DURATION - else: - duration = default + actual_total_occupancy = 0 + for worker, ws in self.workers.items(): + assert abs(sum(ws._processing.values()) - ws._occupancy) < 1e-8 + actual_total_occupancy += ws._occupancy - return duration + assert abs(actual_total_occupancy - self.total_occupancy) < 1e-8, ( + actual_total_occupancy, + self.total_occupancy, + ) - def run_function(self, stream, function, args=(), kwargs={}, wait=True): - """Run a function within this process + ################### + # Manage Messages # + ################### - See Also - -------- - Client.run_on_scheduler: + def report(self, msg: dict, ts: TaskState = None, client: str = None): """ - from .worker import run - - self.log_event("all", {"action": "run-function", "function": function}) - return run(self, stream, function=function, args=args, kwargs=kwargs, wait=wait) + Publish updates to all listening Queues and Comms - def set_metadata(self, comm=None, keys=None, value=None): - try: - metadata = self.task_metadata - for key in keys[:-1]: - if key not in metadata or not isinstance(metadata[key], (dict, list)): - metadata[key] = dict() - metadata = metadata[key] - metadata[keys[-1]] = value - except Exception as e: - import pdb + If the message contains a key then we only send the message to those + comms that care about the key. + """ + if ts is None: + msg_key = msg.get("key") + if msg_key is not None: + tasks: dict = self.tasks + ts = tasks.get(msg_key) - pdb.set_trace() + cs: ClientState + client_comms: dict = self.client_comms + client_keys: list + if ts is None: + # Notify all clients + client_keys = list(client_comms) + elif client is None: + # Notify clients interested in key + client_keys = [cs._client_key for cs in ts._who_wants] + else: + # Notify clients interested in key (including `client`) + client_keys = [ + cs._client_key for cs in ts._who_wants if cs._client_key != client + ] + client_keys.append(client) - def get_metadata(self, comm=None, keys=None, default=no_default): - metadata = self.task_metadata - for key in keys[:-1]: - metadata = metadata[key] - try: - return metadata[keys[-1]] - except KeyError: - if default != no_default: - return default - else: - raise + k: str + for k in client_keys: + c = client_comms.get(k) + if c is None: + continue + try: + c.send(msg) + # logger.debug("Scheduler sends message to client %s", msg) + except CommClosedError: + if self.status == Status.running: + logger.critical("Tried writing to closed comm: %s", msg) - def get_task_status(self, comm=None, keys=None): - return { - key: (self.tasks[key].state if key in self.tasks else None) for key in keys - } + async def add_client(self, comm, client=None, versions=None): + """Add client to network - def get_task_stream(self, comm=None, start=None, stop=None, count=None): - from distributed.diagnostics.task_stream import TaskStreamPlugin + We listen to all future messages from this Comm. 
+ """ + assert client is not None + comm.name = "Scheduler->Client" + logger.info("Receive client connection: %s", client) + self.log_event(["all", client], {"action": "add-client", "client": client}) + self.clients[client] = ClientState(client, versions=versions) - self.add_plugin(TaskStreamPlugin, idempotent=True) - tsp = [p for p in self.plugins if isinstance(p, TaskStreamPlugin)][0] - return tsp.collect(start=start, stop=stop, count=count) + for plugin in self.plugins[:]: + try: + plugin.add_client(scheduler=self, client=client) + except Exception as e: + logger.exception(e) - def start_task_metadata(self, comm=None, name=None): - plugin = CollectTaskMetaDataPlugin(scheduler=self, name=name) + try: + bcomm = BatchedSend(interval="2ms", loop=self.loop) + bcomm.start(comm) + self.client_comms[client] = bcomm + msg = {"op": "stream-start"} + ws: WorkerState + version_warning = version_module.error_message( + version_module.get_versions(), + {w: ws._versions for w, ws in self.workers.items()}, + versions, + ) + msg.update(version_warning) + bcomm.send(msg) - self.add_plugin(plugin) + try: + await self.handle_stream(comm=comm, extra={"client": client}) + finally: + self.remove_client(client=client) + logger.debug("Finished handling client %s", client) + finally: + if not comm.closed(): + self.client_comms[client].send({"op": "stream-closed"}) + try: + if not shutting_down(): + await self.client_comms[client].close() + del self.client_comms[client] + if self.status == Status.running: + logger.info("Close client connection: %s", client) + except TypeError: # comm becomes None during GC + pass - def stop_task_metadata(self, comm=None, name=None): - plugins = [ - p - for p in self.plugins - if isinstance(p, CollectTaskMetaDataPlugin) and p.name == name - ] - if len(plugins) != 1: - raise ValueError( - "Expected to find exactly one CollectTaskMetaDataPlugin " - f"with name {name} but found {len(plugins)}." + def remove_client(self, client=None): + """ Remove client from network """ + if self.status == Status.running: + logger.info("Remove client %s", client) + self.log_event(["all", client], {"action": "remove-client", "client": client}) + try: + cs: ClientState = self.clients[client] + except KeyError: + # XXX is this a legitimate condition? 
+ pass + else: + ts: TaskState + self.client_releases_keys( + keys=[ts._key for ts in cs._wants_what], client=cs._client_key ) + del self.clients[client] - plugin = plugins[0] - self.remove_plugin(plugin) - return {"metadata": plugin.metadata, "state": plugin.state} + for plugin in self.plugins[:]: + try: + plugin.remove_client(scheduler=self, client=client) + except Exception as e: + logger.exception(e) - async def register_worker_plugin(self, comm, plugin, name=None): - """ Registers a setup function, and call it on every worker """ - self.worker_plugins.append({"plugin": plugin, "name": name}) + def remove_client_from_events(): + # If the client isn't registered anymore after the delay, remove from events + if client not in self.clients and client in self.events: + del self.events[client] - responses = await self.broadcast( - msg=dict(op="plugin-add", plugin=plugin, name=name) + cleanup_delay = parse_timedelta( + dask.config.get("distributed.scheduler.events-cleanup-delay") ) - return responses + self.loop.call_later(cleanup_delay, remove_client_from_events) - ##################### - # State Transitions # - ##################### + def send_task_to_worker(self, worker, ts: TaskState, duration=None): + """ Send a single computational task to a worker """ + try: + msg: dict = self._task_to_msg(ts, duration) + self.worker_send(worker, msg) + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - def _remove_from_processing(self, ts: TaskState) -> str: - """ - Remove *ts* from the set of processing tasks. - """ - workers: dict = cast(dict, self.workers) - ws: WorkerState = ts._processing_on - ts._processing_on = None - w: str = ws._address - if w in workers: # may have been removed - duration = ws._processing.pop(ts) - if not ws._processing: - self.total_occupancy -= ws._occupancy - ws._occupancy = 0 - else: - self.total_occupancy -= duration - ws._occupancy -= duration - self.check_idle_saturated(ws) - self.release_resources(ts, ws) - return w - else: - return None + pdb.set_trace() + raise - def _add_to_memory( - self, - ts: TaskState, - ws: WorkerState, - recommendations: dict, - client_msgs: dict, - type=None, - typename=None, - **kwargs, - ): - """ - Add *ts* to the set of in-memory tasks. 
- """ - if self.validate: - assert ts not in ws._has_what + def handle_uncaught_error(self, **msg): + logger.exception(clean_exception(**msg)[1]) - ts._who_has.add(ws) - ws._has_what.add(ts) - ws._nbytes += ts.get_nbytes() + def handle_task_finished(self, key=None, worker=None, **msg): + if worker not in self.workers: + return + validate_key(key) + r = self.stimulus_task_finished(key=key, worker=worker, **msg) + self.transitions(r) - deps: list = list(ts._dependents) - if len(deps) > 1: - deps.sort(key=operator.attrgetter("priority"), reverse=True) + def handle_task_erred(self, key=None, **msg): + r = self.stimulus_task_erred(key=key, **msg) + self.transitions(r) - dts: TaskState - s: set - for dts in deps: - s = dts._waiting_on - if ts in s: - s.discard(ts) - if not s: # new task ready to run - recommendations[dts._key] = "processing" + def handle_release_data(self, key=None, worker=None, client=None, **msg): + ts: TaskState = self.tasks.get(key) + if ts is None: + return + ws: WorkerState = self.workers[worker] + if ts._processing_on != ws: + return + r = self.stimulus_missing_data(key=key, ensure=False, **msg) + self.transitions(r) - for dts in ts._dependencies: - s = dts._waiters - s.discard(ts) - if not s and not dts._who_wants: - recommendations[dts._key] = "released" + def handle_missing_data(self, key=None, errant_worker=None, **kwargs): + logger.debug("handle missing data key=%s worker=%s", key, errant_worker) + self.log.append(("missing", key, errant_worker)) - report_msg: dict = {} - cs: ClientState - if not ts._waiters and not ts._who_wants: - recommendations[ts._key] = "released" - else: - report_msg["op"] = "key-in-memory" - report_msg["key"] = ts._key - if type is not None: - report_msg["type"] = type + ts: TaskState = self.tasks.get(key) + if ts is None or not ts._who_has: + return + if errant_worker in self.workers: + ws: WorkerState = self.workers[errant_worker] + if ws in ts._who_has: + ts._who_has.remove(ws) + ws._has_what.remove(ts) + ws._nbytes -= ts.get_nbytes() + if not ts._who_has: + if ts._run_spec: + self.transitions({key: "released"}) + else: + self.transitions({key: "forgotten"}) - for cs in ts._who_wants: - client_msgs[cs._client_key] = report_msg + def release_worker_data(self, comm=None, keys=None, worker=None): + ws: WorkerState = self.workers[worker] + tasks = {self.tasks[k] for k in keys} + removed_tasks = tasks & ws._has_what + ws._has_what -= removed_tasks - ts.state = "memory" - ts._type = typename - ts._group._types.add(typename) + ts: TaskState + recommendations: dict = {} + for ts in removed_tasks: + ws._nbytes -= ts.get_nbytes() + wh = ts._who_has + wh.remove(ws) + if not wh: + recommendations[ts._key] = "released" + if recommendations: + self.transitions(recommendations) - cs = self.clients["fire-and-forget"] - if ts in cs._wants_what: - self._client_releases_keys( - cs=cs, - keys=[ts._key], - recommendations=recommendations, - ) + def handle_long_running(self, key=None, worker=None, compute_duration=None): + """A task has seceded from the thread pool - def transition_released_waiting(self, key): - try: - tasks: dict = self.tasks - workers: dict = cast(dict, self.workers) - ts: TaskState = tasks[key] - dts: TaskState - worker_msgs: dict = {} - client_msgs: dict = {} + We stop the task from being stolen in the future, and change task + duration accounting as if the task has stopped. 
+ """ + ts: TaskState = self.tasks[key] + if "stealing" in self.extensions: + self.extensions["stealing"].remove_key_from_stealable(ts) - if self.validate: - assert ts._run_spec - assert not ts._waiting_on - assert not ts._who_has - assert not ts._processing_on - assert not any([dts._state == "forgotten" for dts in ts._dependencies]) + ws: WorkerState = ts._processing_on + if ws is None: + logger.debug("Received long-running signal from duplicate task. Ignoring.") + return - if ts._has_lost_dependencies: - return {key: "forgotten"}, worker_msgs, client_msgs + if compute_duration: + old_duration = ts._prefix._duration_average + new_duration = compute_duration + if old_duration < 0: + avg_duration = new_duration + else: + avg_duration = 0.5 * old_duration + 0.5 * new_duration - ts.state = "waiting" + ts._prefix._duration_average = avg_duration - recommendations: dict = {} + ws._occupancy -= ws._processing[ts] + self.total_occupancy -= ws._processing[ts] + ws._processing[ts] = 0 + self.check_idle_saturated(ws) - dts: TaskState - for dts in ts._dependencies: - if dts._exception_blame: - ts._exception_blame = dts._exception_blame - recommendations[key] = "erred" - return recommendations, worker_msgs, client_msgs + async def handle_worker(self, comm=None, worker=None): + """ + Listen to responses from a single worker - for dts in ts._dependencies: - dep = dts._key - if not dts._who_has: - ts._waiting_on.add(dts) - if dts._state == "released": - recommendations[dep] = "waiting" - else: - dts._waiters.add(ts) + This is the main loop for scheduler-worker interaction + + See Also + -------- + Scheduler.handle_client: Equivalent coroutine for clients + """ + comm.name = "Scheduler connection to worker" + worker_comm = self.stream_comms[worker] + worker_comm.start(comm) + logger.info("Starting worker compute stream, %s", worker) + try: + await self.handle_stream(comm=comm, extra={"worker": worker}) + finally: + if worker in self.stream_comms: + worker_comm.abort() + await self.remove_worker(address=worker) + + def add_plugin(self, plugin=None, idempotent=False, **kwargs): + """ + Add external plugin to scheduler + + See https://distributed.readthedocs.io/en/latest/plugins.html + """ + if isinstance(plugin, type): + plugin = plugin(self, **kwargs) - ts._waiters = {dts for dts in ts._dependents if dts._state == "waiting"} + if idempotent and any(isinstance(p, type(plugin)) for p in self.plugins): + return - if not ts._waiting_on: - if workers: - recommendations[key] = "processing" - else: - self.unrunnable.add(ts) - ts.state = "no-worker" + self.plugins.append(plugin) - return recommendations, worker_msgs, client_msgs - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + def remove_plugin(self, plugin): + """ Remove external plugin from scheduler """ + self.plugins.remove(plugin) - pdb.set_trace() - raise + def worker_send(self, worker, msg): + """Send message to worker - def transition_no_worker_waiting(self, key): + This also handles connection failures by adding a callback to remove + the worker on the next cycle. 
+ """ + stream_comms: dict = self.stream_comms try: - tasks: dict = self.tasks - workers: dict = cast(dict, self.workers) - ts: TaskState = tasks[key] - dts: TaskState - worker_msgs: dict = {} - client_msgs: dict = {} + stream_comms[worker].send(msg) + except (CommClosedError, AttributeError): + self.loop.add_callback(self.remove_worker, address=worker) - if self.validate: - assert ts in self.unrunnable - assert not ts._waiting_on - assert not ts._who_has - assert not ts._processing_on + def client_send(self, client, msg): + """Send message to client""" + client_comms: dict = self.client_comms + c = client_comms.get(client) + if c is None: + return + try: + c.send(msg) + except CommClosedError: + if self.status == Status.running: + logger.critical("Tried writing to closed comm: %s", msg) - self.unrunnable.remove(ts) + ############################ + # Less common interactions # + ############################ - if ts._has_lost_dependencies: - return {key: "forgotten"}, worker_msgs, client_msgs + async def scatter( + self, + comm=None, + data=None, + workers=None, + client=None, + broadcast=False, + timeout=2, + ): + """Send data out to workers - recommendations: dict = {} + See also + -------- + Scheduler.broadcast: + """ + start = time() + while not self.workers: + await asyncio.sleep(0.2) + if time() > start + timeout: + raise TimeoutError("No workers found") - for dts in ts._dependencies: - dep = dts._key - if not dts._who_has: - ts._waiting_on.add(dts) - if dts._state == "released": - recommendations[dep] = "waiting" - else: - dts._waiters.add(ts) + if workers is None: + ws: WorkerState + nthreads = {w: ws._nthreads for w, ws in self.workers.items()} + else: + workers = [self.coerce_address(w) for w in workers] + nthreads = {w: self.workers[w].nthreads for w in workers} - ts.state = "waiting" + assert isinstance(data, dict) - if not ts._waiting_on: - if workers: - recommendations[key] = "processing" - else: - self.unrunnable.add(ts) - ts.state = "no-worker" + keys, who_has, nbytes = await scatter_to_workers( + nthreads, data, rpc=self.rpc, report=False + ) - return recommendations, worker_msgs, client_msgs - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + self.update_data(who_has=who_has, nbytes=nbytes, client=client) - pdb.set_trace() - raise + if broadcast: + if broadcast == True: # noqa: E712 + n = len(nthreads) + else: + n = broadcast + await self.replicate(keys=keys, workers=workers, n=n) - def decide_worker(self, ts: TaskState) -> WorkerState: - """ - Decide on a worker for task *ts*. Return a WorkerState. 
- """ - workers: dict = cast(dict, self.workers) - ws: WorkerState = None - valid_workers: set = self.valid_workers(ts) + self.log_event( + [client, "all"], {"action": "scatter", "client": client, "count": len(data)} + ) + return keys - if ( - valid_workers is not None - and not valid_workers - and not ts._loose_restrictions - and workers - ): - self.unrunnable.add(ts) - ts.state = "no-worker" - return ws + async def gather(self, comm=None, keys=None, serializers=None): + """ Collect data in from workers """ + ws: WorkerState + keys = list(keys) + who_has = {} + for key in keys: + ts: TaskState = self.tasks.get(key) + if ts is not None: + who_has[key] = [ws._address for ws in ts._who_has] + else: + who_has[key] = [] - if ts._dependencies or valid_workers is not None: - ws = decide_worker( - ts, - workers.values(), - valid_workers, - partial(self.worker_objective, ts), - ) + data, missing_keys, missing_workers = await gather_from_workers( + who_has, rpc=self.rpc, close=False, serializers=serializers + ) + if not missing_keys: + result = {"status": "OK", "data": data} else: - worker_pool = self.idle or self.workers - worker_pool_dv = cast(dict, worker_pool) - n_workers: Py_ssize_t = len(worker_pool_dv) - if n_workers < 20: # smart but linear in small case - ws = min(worker_pool.values(), key=operator.attrgetter("occupancy")) - else: # dumb but fast in large case - n_tasks: Py_ssize_t = self.n_tasks - ws = worker_pool.values()[n_tasks % n_workers] - - if self.validate: - assert ws is None or isinstance(ws, WorkerState), ( - type(ws), - ws, + missing_states = [ + (self.tasks[key].state if key in self.tasks else None) + for key in missing_keys + ] + logger.exception( + "Couldn't gather keys %s state: %s workers: %s", + missing_keys, + missing_states, + missing_workers, ) - assert ws._address in workers + result = {"status": "error", "keys": missing_keys} + with log_errors(): + # Remove suspicious workers from the scheduler but allow them to + # reconnect. + await asyncio.gather( + *[ + self.remove_worker(address=worker, close=False) + for worker in missing_workers + ] + ) + for key, workers in missing_keys.items(): + # Task may already be gone if it was held by a + # `missing_worker` + ts: TaskState = self.tasks.get(key) + logger.exception( + "Workers don't have promised key: %s, %s", + str(workers), + str(key), + ) + if not workers or ts is None: + continue + for worker in workers: + ws = self.workers.get(worker) + if ws is not None and ts in ws._has_what: + ws._has_what.remove(ts) + ts._who_has.remove(ws) + ws._nbytes -= ts.get_nbytes() + self.transitions({key: "released"}) - return ws + self.log_event("all", {"action": "gather", "count": len(keys)}) + return result - def set_duration_estimate(self, ts: TaskState, ws: WorkerState): - """Estimate task duration using worker state and task state. + def clear_task_state(self): + # XXX what about nested state such as ClientState.wants_what + # (see also fire-and-forget...) + logger.info("Clear task state") + for collection in self._task_state_collections: + collection.clear() - If a task takes longer than twice the current average duration we - estimate the task duration to be 2x current-runtime, otherwise we set it - to be the average duration. 
- """ - duration: double = self.get_task_duration(ts) - comm: double = self.get_comm_cost(ts, ws) - total_duration: double = duration + comm - if ts in ws._executing: - exec_time: double = ws._executing[ts] - if exec_time > 2 * duration: - total_duration = 2 * exec_time - ws._processing[ts] = total_duration - return total_duration + async def restart(self, client=None, timeout=3): + """ Restart all workers. Reset local state. """ + with log_errors(): + + n_workers = len(self.workers) + + logger.info("Send lost future signal to clients") + cs: ClientState + ts: TaskState + for cs in self.clients.values(): + self.client_releases_keys( + keys=[ts._key for ts in cs._wants_what], client=cs._client_key + ) + + ws: WorkerState + nannies = {addr: ws._nanny for addr, ws in self.workers.items()} + + for addr in list(self.workers): + try: + # Ask the worker to close if it doesn't have a nanny, + # otherwise the nanny will kill it anyway + await self.remove_worker(address=addr, close=addr not in nannies) + except Exception as e: + logger.info( + "Exception while restarting. This is normal", exc_info=True + ) - def transition_waiting_processing(self, key): - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - dts: TaskState - worker_msgs: dict = {} - client_msgs: dict = {} + self.clear_task_state() - if self.validate: - assert not ts._waiting_on - assert not ts._who_has - assert not ts._exception_blame - assert not ts._processing_on - assert not ts._has_lost_dependencies - assert ts not in self.unrunnable - assert all([dts._who_has for dts in ts._dependencies]) + for plugin in self.plugins[:]: + try: + plugin.restart(self) + except Exception as e: + logger.exception(e) - ws: WorkerState = self.decide_worker(ts) - if ws is None: - return {}, worker_msgs, client_msgs - worker = ws._address + logger.debug("Send kill signal to nannies: %s", nannies) - duration_estimate = self.set_duration_estimate(ts, ws) - ts._processing_on = ws - ws._occupancy += duration_estimate - self.total_occupancy += duration_estimate - ts.state = "processing" - self.consume_resources(ts, ws) - self.check_idle_saturated(ws) - self.n_tasks += 1 + nannies = [ + rpc(nanny_address, connection_args=self.connection_args) + for nanny_address in nannies.values() + if nanny_address is not None + ] - if ts._actor: - ws._actors.add(ts) + resps = All( + [ + nanny.restart( + close=True, timeout=timeout * 0.8, executor_wait=False + ) + for nanny in nannies + ] + ) + try: + resps = await asyncio.wait_for(resps, timeout) + except TimeoutError: + logger.error( + "Nannies didn't report back restarted within " + "timeout. 
Continuuing with restart process" + ) + else: + if not all(resp == "OK" for resp in resps): + logger.error( + "Not all workers responded positively: %s", resps, exc_info=True + ) + finally: + await asyncio.gather(*[nanny.close_rpc() for nanny in nannies]) - # logger.debug("Send job to worker: %s, %s", worker, key) + self.clear_task_state() - self.send_task_to_worker(worker, ts) + with suppress(AttributeError): + for c in self._worker_coroutines: + c.cancel() - return {}, worker_msgs, client_msgs - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + self.log_event([client, "all"], {"action": "restart", "client": client}) + start = time() + while time() < start + 10 and len(self.workers) < n_workers: + await asyncio.sleep(0.01) - pdb.set_trace() - raise + self.report({"op": "restart"}) - def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): - try: - workers: dict = cast(dict, self.workers) - ws: WorkerState = workers[worker] - tasks: dict = self.tasks - ts: TaskState = tasks[key] - worker_msgs: dict = {} - client_msgs: dict = {} + async def broadcast( + self, + comm=None, + msg=None, + workers=None, + hosts=None, + nanny=False, + serializers=None, + ): + """ Broadcast message to workers, return all results """ + if workers is None or workers is True: + if hosts is None: + workers = list(self.workers) + else: + workers = [] + if hosts is not None: + for host in hosts: + if host in self.host_info: + workers.extend(self.host_info[host]["addresses"]) + # TODO replace with worker_list - if self.validate: - assert not ts._processing_on - assert ts._waiting_on - assert ts._state == "waiting" + if nanny: + addresses = [self.workers[w].nanny for w in workers] + else: + addresses = workers - ts._waiting_on.clear() + async def send_message(addr): + comm = await self.rpc.connect(addr) + comm.name = "Scheduler Broadcast" + try: + resp = await send_recv(comm, close=True, serializers=serializers, **msg) + finally: + self.rpc.reuse(addr, comm) + return resp - if nbytes is not None: - ts.set_nbytes(nbytes) + results = await All( + [send_message(address) for address in addresses if address is not None] + ) - self.check_idle_saturated(ws) + return dict(zip(workers, results)) - recommendations: dict = {} - client_msgs: dict = {} + async def proxy(self, comm=None, msg=None, worker=None, serializers=None): + """ Proxy a communication through the scheduler to some other worker """ + d = await self.broadcast( + comm=comm, msg=msg, workers=[worker], serializers=serializers + ) + return d[worker] - self._add_to_memory(ts, ws, recommendations, client_msgs, **kwargs) + async def _delete_worker_data(self, worker_address, keys): + """Delete data from a worker and update the corresponding worker/task states - if self.validate: - assert not ts._processing_on - assert not ts._waiting_on - assert ts._who_has + Parameters + ---------- + worker_address: str + Worker address to delete keys from + keys: List[str] + List of keys to delete on the specified worker + """ + await retry_operation( + self.rpc(addr=worker_address).delete_data, keys=list(keys), report=False + ) - return recommendations, worker_msgs, client_msgs - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + ws: WorkerState = self.workers[worker_address] + ts: TaskState + tasks: set = {self.tasks[key] for key in keys} + ws._has_what -= tasks + for ts in tasks: + ts._who_has.remove(ws) + ws._nbytes -= ts.get_nbytes() + self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys}) - 
pdb.set_trace() - raise + async def rebalance(self, comm=None, keys=None, workers=None): + """Rebalance keys so that each worker stores roughly equal bytes - def transition_processing_memory( - self, - key, - nbytes=None, - type=None, - typename=None, - worker=None, - startstops=None, - **kwargs, - ): - ws: WorkerState - wws: WorkerState - worker_msgs: dict = {} - client_msgs: dict = {} - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - assert worker - assert isinstance(worker, str) + **Policy** - if self.validate: - assert ts._processing_on - ws = ts._processing_on - assert ts in ws._processing - assert not ts._waiting_on - assert not ts._who_has, (ts, ts._who_has) - assert not ts._exception_blame - assert ts._state == "processing" + This orders the workers by what fraction of bytes of the existing keys + they have. It walks down this list from most-to-least. At each worker + it sends the largest results it can find and sends them to the least + occupied worker until either the sender or the recipient are at the + average expected load. + """ + ts: TaskState + with log_errors(): + async with self._lock: + if keys: + tasks = {self.tasks[k] for k in keys} + missing_data = [ts._key for ts in tasks if not ts._who_has] + if missing_data: + return {"status": "missing-data", "keys": missing_data} + else: + tasks = set(self.tasks.values()) - workers: dict = cast(dict, self.workers) - ws = workers.get(worker) - if ws is None: - return {key: "released"}, worker_msgs, client_msgs + if workers: + workers = {self.workers[w] for w in workers} + workers_by_task = {ts: ts._who_has & workers for ts in tasks} + else: + workers = set(self.workers.values()) + workers_by_task = {ts: ts._who_has for ts in tasks} - if ws != ts._processing_on: # someone else has this task - logger.info( - "Unexpected worker completed task, likely due to" - " work stealing. 
Expected: %s, Got: %s, Key: %s", - ts._processing_on, - ws, - key, - ) - return {}, worker_msgs, client_msgs + ws: WorkerState + tasks_by_worker = {ws: set() for ws in workers} - if startstops: - L = list() - for startstop in startstops: - stop = startstop["stop"] - start = startstop["start"] - action = startstop["action"] - if action == "compute": - L.append((start, stop)) + for k, v in workers_by_task.items(): + for vv in v: + tasks_by_worker[vv].add(k) - # record timings of all actions -- a cheaper way of - # getting timing info compared with get_task_stream() - ts._prefix._all_durations[action] += stop - start + worker_bytes = { + ws: sum(ts.get_nbytes() for ts in v) + for ws, v in tasks_by_worker.items() + } - if len(L) > 0: - compute_start, compute_stop = L[0] - else: # This is very rare - compute_start = compute_stop = None - else: - compute_start = compute_stop = None + avg = sum(worker_bytes.values()) / len(worker_bytes) - ############################# - # Update Timing Information # - ############################# - if compute_start and ws._processing.get(ts, True): - # Update average task duration for worker - old_duration = ts._prefix._duration_average - new_duration = compute_stop - compute_start - if old_duration < 0: - avg_duration = new_duration - else: - avg_duration = 0.5 * old_duration + 0.5 * new_duration + sorted_workers = list( + map(first, sorted(worker_bytes.items(), key=second, reverse=True)) + ) - ts._prefix._duration_average = avg_duration - ts._group._duration += new_duration + recipients = iter(reversed(sorted_workers)) + recipient = next(recipients) + msgs = [] # (sender, recipient, key) + for sender in sorted_workers[: len(workers) // 2]: + sender_keys = { + ts: ts.get_nbytes() for ts in tasks_by_worker[sender] + } + sender_keys = iter( + sorted(sender_keys.items(), key=second, reverse=True) + ) - tts: TaskState - for tts in self.unknown_durations.pop(ts._prefix._name, ()): - if tts._processing_on: - wws = tts._processing_on - old = wws._processing[tts] - comm = self.get_comm_cost(tts, wws) - wws._processing[tts] = avg_duration + comm - wws._occupancy += avg_duration + comm - old - self.total_occupancy += avg_duration + comm - old + try: + while worker_bytes[sender] > avg: + while ( + worker_bytes[recipient] < avg + and worker_bytes[sender] > avg + ): + ts, nb = next(sender_keys) + if ts not in tasks_by_worker[recipient]: + tasks_by_worker[recipient].add(ts) + # tasks_by_worker[sender].remove(ts) + msgs.append((sender, recipient, ts)) + worker_bytes[sender] -= nb + worker_bytes[recipient] += nb + if worker_bytes[sender] > avg: + recipient = next(recipients) + except StopIteration: + break - ############################ - # Update State Information # - ############################ - if nbytes is not None: - ts.set_nbytes(nbytes) + to_recipients = defaultdict(lambda: defaultdict(list)) + to_senders = defaultdict(list) + for sender, recipient, ts in msgs: + to_recipients[recipient.address][ts._key].append(sender.address) + to_senders[sender.address].append(ts._key) - recommendations: dict = {} - client_msgs: dict = {} + result = await asyncio.gather( + *( + retry_operation(self.rpc(addr=r).gather, who_has=v) + for r, v in to_recipients.items() + ) + ) + for r, v in to_recipients.items(): + self.log_event(r, {"action": "rebalance", "who_has": v}) - self._remove_from_processing(ts) + self.log_event( + "all", + { + "action": "rebalance", + "total-keys": len(tasks), + "senders": valmap(len, to_senders), + "recipients": valmap(len, to_recipients), + "moved_keys": 
len(msgs), + }, + ) - self._add_to_memory( - ts, ws, recommendations, client_msgs, type=type, typename=typename - ) + if not all(r["status"] == "OK" for r in result): + return { + "status": "missing-data", + "keys": tuple( + concat( + r["keys"].keys() + for r in result + if r["status"] == "missing-data" + ) + ), + } - if self.validate: - assert not ts._processing_on - assert not ts._waiting_on + for sender, recipient, ts in msgs: + assert ts._state == "memory" + ts._who_has.add(recipient) + recipient.has_what.add(ts) + recipient.nbytes += ts.get_nbytes() + self.log.append( + ( + "rebalance", + ts._key, + time(), + sender.address, + recipient.address, + ) + ) - return recommendations, worker_msgs, client_msgs - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + await asyncio.gather( + *(self._delete_worker_data(r, v) for r, v in to_senders.items()) + ) - pdb.set_trace() - raise + return {"status": "OK"} - def transition_memory_released(self, key, safe=False): - ws: WorkerState - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - dts: TaskState - worker_msgs: dict = {} - client_msgs: dict = {} + async def replicate( + self, + comm=None, + keys=None, + n=None, + workers=None, + branching_factor=2, + delete=True, + lock=True, + ): + """Replicate data throughout cluster - if self.validate: - assert not ts._waiting_on - assert not ts._processing_on - if safe: - assert not ts._waiters + This performs a tree copy of the data throughout the network + individually on each piece of data. - if ts._actor: - for ws in ts._who_has: - ws._actors.discard(ts) - if ts._who_wants: - ts._exception_blame = ts - ts._exception = "Worker holding Actor was lost" - return ( - {ts._key: "erred"}, - worker_msgs, - client_msgs, - ) # don't try to recreate + Parameters + ---------- + keys: Iterable + list of keys to replicate + n: int + Number of replications we expect to see within the cluster + branching_factor: int, optional + The number of workers that can copy data in each generation. + The larger the branching factor, the more data we copy in + a single step, but the more a given worker risks being + swamped by data requests. - recommendations: dict = {} + See also + -------- + Scheduler.rebalance + """ + ws: WorkerState + wws: WorkerState + ts: TaskState - for dts in ts._waiters: - if dts._state in ("no-worker", "processing"): - recommendations[dts._key] = "waiting" - elif dts._state == "waiting": - dts._waiting_on.add(ts) + assert branching_factor > 0 + async with self._lock if lock else empty_context: + workers = {self.workers[w] for w in self.workers_list(workers)} + if n is None: + n = len(workers) + else: + n = min(n, len(workers)) + if n == 0: + raise ValueError("Can not use replicate to delete data") - # XXX factor this out? 
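# A minimal sketch of the tree copy that ``replicate`` describes above, assuming
# only that each generation at most ``branching_factor`` times the current
# holders can produce a new copy; ``n_replicas`` and ``initial_holders`` are
# illustrative parameters, not scheduler state.
def replication_rounds(n_replicas: int, branching_factor: int = 2, initial_holders: int = 1) -> int:
    """Roughly how many copy generations are needed to reach ``n_replicas``."""
    assert branching_factor > 0 and initial_holders > 0
    holders, rounds = initial_holders, 0
    while holders < n_replicas:
        # Each generation adds at most branching_factor * holders new copies,
        # mirroring the bounded ``count`` used when picking gather targets.
        holders += branching_factor * holders
        rounds += 1
    return rounds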
- for ws in ts._who_has: - ws._has_what.remove(ts) - ws._nbytes -= ts.get_nbytes() - ts._group._nbytes_in_memory -= ts.get_nbytes() - worker_msgs[ws._address] = { - "op": "delete-data", - "keys": [key], - "report": False, - } + tasks = {self.tasks[k] for k in keys} + missing_data = [ts._key for ts in tasks if not ts._who_has] + if missing_data: + return {"status": "missing-data", "keys": missing_data} - ts._who_has.clear() + # Delete extraneous data + if delete: + del_worker_tasks = defaultdict(set) + for ts in tasks: + del_candidates = ts._who_has & workers + if len(del_candidates) > n: + for ws in random.sample( + del_candidates, len(del_candidates) - n + ): + del_worker_tasks[ws].add(ts) - ts.state = "released" + await asyncio.gather( + *[ + self._delete_worker_data(ws._address, [t.key for t in tasks]) + for ws, tasks in del_worker_tasks.items() + ] + ) - report_msg = {"op": "lost-data", "key": key} - cs: ClientState - for cs in ts._who_wants: - client_msgs[cs._client_key] = report_msg + # Copy not-yet-filled data + while tasks: + gathers = defaultdict(dict) + for ts in list(tasks): + if ts._state == "forgotten": + # task is no longer needed by any client or dependant task + tasks.remove(ts) + continue + n_missing = n - len(ts._who_has & workers) + if n_missing <= 0: + # Already replicated enough + tasks.remove(ts) + continue - if not ts._run_spec: # pure data - recommendations[key] = "forgotten" - elif ts._has_lost_dependencies: - recommendations[key] = "forgotten" - elif ts._who_wants or ts._waiters: - recommendations[key] = "waiting" + count = min(n_missing, branching_factor * len(ts._who_has)) + assert count > 0 - if self.validate: - assert not ts._waiting_on + for ws in random.sample(workers - ts._who_has, count): + gathers[ws._address][ts._key] = [ + wws._address for wws in ts._who_has + ] - return recommendations, worker_msgs, client_msgs - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + results = await asyncio.gather( + *( + retry_operation(self.rpc(addr=w).gather, who_has=who_has) + for w, who_has in gathers.items() + ) + ) + for w, v in zip(gathers, results): + if v["status"] == "OK": + self.add_keys(worker=w, keys=list(gathers[w])) + else: + logger.warning("Communication failed during replication: %s", v) - pdb.set_trace() - raise + self.log_event(w, {"action": "replicate-add", "keys": gathers[w]}) - def transition_released_erred(self, key): - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - dts: TaskState - failing_ts: TaskState - worker_msgs: dict = {} - client_msgs: dict = {} + self.log_event( + "all", + { + "action": "replicate", + "workers": list(workers), + "key-count": len(keys), + "branching-factor": branching_factor, + }, + ) - if self.validate: - with log_errors(pdb=LOG_PDB): - assert ts._exception_blame - assert not ts._who_has - assert not ts._waiting_on - assert not ts._waiters + def workers_to_close( + self, + comm=None, + memory_ratio=None, + n=None, + key=None, + minimum=None, + target=None, + attribute="address", + ): + """ + Find workers that we can close with low cost - recommendations: dict = {} + This returns a list of workers that are good candidates to retire. + These workers are not running anything and are storing + relatively little data relative to their peers. If all workers are + idle then we still maintain enough workers to have enough RAM to store + our data, with a comfortable buffer. - failing_ts = ts._exception_blame + This is for use with systems like ``distributed.deploy.adaptive``. 
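        A rough sketch of that retirement criterion, using illustrative
        stand-ins rather than ``WorkerState`` objects (each worker is a tuple
        of address, processing count, stored bytes, and memory limit;
        ``memory_ratio`` is the buffer factor, assumed to default to 2):

            def pick_retirees(workers, memory_ratio=2):
                total_data = sum(nbytes for _, _, nbytes, _ in workers)
                remaining_limit = sum(limit for _, _, _, limit in workers)
                # Idle workers storing the least data are cheapest to close.
                idle = sorted((w for w in workers if w[1] == 0), key=lambda w: w[2])
                to_close = []
                for addr, _, _, limit in idle:
                    # Keep enough total memory to hold memory_ratio x the data.
                    if remaining_limit - limit < memory_ratio * total_data:
                        break
                    remaining_limit -= limit
                    to_close.append(addr)
                return to_close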
- for dts in ts._dependents: - dts._exception_blame = failing_ts - if not dts._who_has: - recommendations[dts._key] = "erred" + Parameters + ---------- + memory_factor: Number + Amount of extra space we want to have for our stored data. + Defaults two 2, or that we want to have twice as much memory as we + currently have data. + n: int + Number of workers to close + minimum: int + Minimum number of workers to keep around + key: Callable(WorkerState) + An optional callable mapping a WorkerState object to a group + affiliation. Groups will be closed together. This is useful when + closing workers must be done collectively, such as by hostname. + target: int + Target number of workers to have after we close + attribute : str + The attribute of the WorkerState object to return, like "address" + or "name". Defaults to "address". - report_msg = { - "op": "task-erred", - "key": key, - "exception": failing_ts._exception, - "traceback": failing_ts._traceback, - } - cs: ClientState - for cs in ts._who_wants: - client_msgs[cs._client_key] = report_msg + Examples + -------- + >>> scheduler.workers_to_close() + ['tcp://192.168.0.1:1234', 'tcp://192.168.0.2:1234'] - ts.state = "erred" + Group workers by hostname prior to closing - # TODO: waiting data? - return recommendations, worker_msgs, client_msgs - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + >>> scheduler.workers_to_close(key=lambda ws: ws.host) + ['tcp://192.168.0.1:1234', 'tcp://192.168.0.1:4567'] - pdb.set_trace() - raise + Remove two workers - def transition_erred_released(self, key): - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - dts: TaskState - worker_msgs: dict = {} - client_msgs: dict = {} + >>> scheduler.workers_to_close(n=2) - if self.validate: - with log_errors(pdb=LOG_PDB): - assert all([dts._state != "erred" for dts in ts._dependencies]) - assert ts._exception_blame - assert not ts._who_has - assert not ts._waiting_on - assert not ts._waiters + Keep enough workers to have twice as much memory as we we need. 
- recommendations: dict = {} + >>> scheduler.workers_to_close(memory_ratio=2) - ts._exception = None - ts._exception_blame = None - ts._traceback = None + Returns + ------- + to_close: list of worker addresses that are OK to close - for dts in ts._dependents: - if dts._state == "erred": - recommendations[dts._key] = "waiting" + See Also + -------- + Scheduler.retire_workers + """ + if target is not None and n is None: + n = len(self.workers) - target + if n is not None: + if n < 0: + n = 0 + target = len(self.workers) - n - report_msg = {"op": "task-retried", "key": key} - cs: ClientState - for cs in ts._who_wants: - client_msgs[cs._client_key] = report_msg + if n is None and memory_ratio is None: + memory_ratio = 2 - ts.state = "released" + ws: WorkerState + with log_errors(): + if not n and all([ws._processing for ws in self.workers.values()]): + return [] - return recommendations, worker_msgs, client_msgs - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + if key is None: + key = operator.attrgetter("address") + if isinstance(key, bytes) and dask.config.get( + "distributed.scheduler.pickle" + ): + key = pickle.loads(key) - pdb.set_trace() - raise + groups = groupby(key, self.workers.values()) - def transition_waiting_released(self, key): - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - worker_msgs: dict = {} - client_msgs: dict = {} + limit_bytes = { + k: sum([ws._memory_limit for ws in v]) for k, v in groups.items() + } + group_bytes = {k: sum([ws._nbytes for ws in v]) for k, v in groups.items()} - if self.validate: - assert not ts._who_has - assert not ts._processing_on + limit = sum(limit_bytes.values()) + total = sum(group_bytes.values()) - recommendations: dict = {} + def _key(group): + wws: WorkerState + is_idle = not any([wws._processing for wws in groups[group]]) + bytes = -group_bytes[group] + return (is_idle, bytes) - dts: TaskState - for dts in ts._dependencies: - s = dts._waiters - if ts in s: - s.discard(ts) - if not s and not dts._who_wants: - recommendations[dts._key] = "released" - ts._waiting_on.clear() + idle = sorted(groups, key=_key) - ts.state = "released" + to_close = [] + n_remain = len(self.workers) - if ts._has_lost_dependencies: - recommendations[key] = "forgotten" - elif not ts._exception_blame and (ts._who_wants or ts._waiters): - recommendations[key] = "waiting" - else: - ts._waiters.clear() + while idle: + group = idle.pop() + if n is None and any([ws._processing for ws in groups[group]]): + break - return recommendations, worker_msgs, client_msgs - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + if minimum and n_remain - len(groups[group]) < minimum: + break - pdb.set_trace() - raise + limit -= limit_bytes[group] - def transition_processing_released(self, key): - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - dts: TaskState - worker_msgs: dict = {} - client_msgs: dict = {} + if (n is not None and n_remain - len(groups[group]) >= target) or ( + memory_ratio is not None and limit >= memory_ratio * total + ): + to_close.append(group) + n_remain -= len(groups[group]) - if self.validate: - assert ts._processing_on - assert not ts._who_has - assert not ts._waiting_on - assert self.tasks[key].state == "processing" + else: + break - w: str = self._remove_from_processing(ts) - if w: - worker_msgs[w] = {"op": "release-task", "key": key} + result = [getattr(ws, attribute) for g in to_close for ws in groups[g]] + if result: + logger.debug("Suggest closing workers: %s", result) - ts.state 
= "released" + return result - recommendations: dict = {} + async def retire_workers( + self, + comm=None, + workers=None, + remove=True, + close_workers=False, + names=None, + lock=True, + **kwargs, + ) -> dict: + """Gracefully retire workers from cluster - if ts._has_lost_dependencies: - recommendations[key] = "forgotten" - elif ts._waiters or ts._who_wants: - recommendations[key] = "waiting" + Parameters + ---------- + workers: list (optional) + List of worker addresses to retire. + If not provided we call ``workers_to_close`` which finds a good set + workers_names: list (optional) + List of worker names to retire. + remove: bool (defaults to True) + Whether or not to remove the worker metadata immediately or else + wait for the worker to contact us + close_workers: bool (defaults to False) + Whether or not to actually close the worker explicitly from here. + Otherwise we expect some external job scheduler to finish off the + worker. + **kwargs: dict + Extra options to pass to workers_to_close to determine which + workers we should drop - if recommendations.get(key) != "waiting": - for dts in ts._dependencies: - if dts._state != "released": - s = dts._waiters - s.discard(ts) - if not s and not dts._who_wants: - recommendations[dts._key] = "released" - ts._waiters.clear() + Returns + ------- + Dictionary mapping worker ID/address to dictionary of information about + that worker for each retired worker. - if self.validate: - assert not ts._processing_on + See Also + -------- + Scheduler.workers_to_close + """ + ws: WorkerState + ts: TaskState + with log_errors(): + async with self._lock if lock else empty_context: + if names is not None: + if names: + logger.info("Retire worker names %s", names) + names = set(map(str, names)) + workers = [ + ws._address + for ws in self.workers.values() + if str(ws._name) in names + ] + if workers is None: + while True: + try: + workers = self.workers_to_close(**kwargs) + if workers: + workers = await self.retire_workers( + workers=workers, + remove=remove, + close_workers=close_workers, + lock=False, + ) + return workers + else: + return {} + except KeyError: # keys left during replicate + pass + workers = {self.workers[w] for w in workers if w in self.workers} + if not workers: + return {} + logger.info("Retire workers %s", workers) - return recommendations, worker_msgs, client_msgs - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + # Keys orphaned by retiring those workers + keys = set.union(*[w.has_what for w in workers]) + keys = {ts._key for ts in keys if ts._who_has.issubset(workers)} - pdb.set_trace() - raise + other_workers = set(self.workers.values()) - workers + if keys: + if other_workers: + logger.info("Moving %d keys to other workers", len(keys)) + await self.replicate( + keys=keys, + workers=[ws._address for ws in other_workers], + n=1, + delete=False, + lock=False, + ) + else: + return {} - def transition_processing_erred( - self, key, cause=None, exception=None, traceback=None, **kwargs - ): - ws: WorkerState - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - dts: TaskState - failing_ts: TaskState - worker_msgs: dict = {} - client_msgs: dict = {} + worker_keys = {ws._address: ws.identity() for ws in workers} + if close_workers and worker_keys: + await asyncio.gather( + *[self.close_worker(worker=w, safe=True) for w in worker_keys] + ) + if remove: + await asyncio.gather( + *[self.remove_worker(address=w, safe=True) for w in worker_keys] + ) - if self.validate: - assert cause or ts._exception_blame - 
assert ts._processing_on - assert not ts._who_has - assert not ts._waiting_on + self.log_event( + "all", + { + "action": "retire-workers", + "workers": worker_keys, + "moved-keys": len(keys), + }, + ) + self.log_event(list(worker_keys), {"action": "retired"}) - if ts._actor: - ws = ts._processing_on - ws._actors.remove(ts) + return worker_keys - self._remove_from_processing(ts) + def add_keys(self, comm=None, worker=None, keys=()): + """ + Learn that a worker has certain keys - if exception is not None: - ts._exception = exception - if traceback is not None: - ts._traceback = traceback - if cause is not None: - failing_ts = self.tasks[cause] - ts._exception_blame = failing_ts + This should not be used in practice and is mostly here for legacy + reasons. However, it is sent by workers from time to time. + """ + if worker not in self.workers: + return "not found" + ws: WorkerState = self.workers[worker] + for key in keys: + ts: TaskState = self.tasks.get(key) + if ts is not None and ts._state == "memory": + if ts not in ws._has_what: + ws._nbytes += ts.get_nbytes() + ws._has_what.add(ts) + ts._who_has.add(ws) else: - failing_ts = ts._exception_blame + self.worker_send( + worker, {"op": "delete-data", "keys": [key], "report": False} + ) - recommendations: dict = {} + return "OK" - for dts in ts._dependents: - dts._exception_blame = failing_ts - recommendations[dts._key] = "erred" + def update_data( + self, comm=None, who_has=None, nbytes=None, client=None, serializers=None + ): + """ + Learn that new data has entered the network from an external source - for dts in ts._dependencies: - s = dts._waiters - s.discard(ts) - if not s and not dts._who_wants: - recommendations[dts._key] = "released" + See Also + -------- + Scheduler.mark_key_in_memory + """ + with log_errors(): + who_has = { + k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items() + } + logger.debug("Update data %s", who_has) - ts._waiters.clear() # do anything with this? + for key, workers in who_has.items(): + ts: TaskState = self.tasks.get(key) + if ts is None: + ts: TaskState = self.new_task(key, None, "memory") + ts.state = "memory" + if key in nbytes: + ts.set_nbytes(nbytes[key]) + for w in workers: + ws: WorkerState = self.workers[w] + if ts not in ws._has_what: + ws._nbytes += ts.get_nbytes() + ws._has_what.add(ts) + ts._who_has.add(ws) + self.report( + {"op": "key-in-memory", "key": key, "workers": list(workers)} + ) - ts.state = "erred" + if client: + self.client_desires_keys(keys=list(who_has), client=client) - report_msg = { - "op": "task-erred", - "key": key, - "exception": failing_ts._exception, - "traceback": failing_ts._traceback, - } - cs: ClientState - for cs in ts._who_wants: - client_msgs[cs._client_key] = report_msg + def report_on_key(self, key: str = None, ts: TaskState = None, client: str = None): + if ts is None: + tasks: dict = self.tasks + ts = tasks.get(key) + elif key is None: + key = ts._key + else: + assert False, (key, ts) + return + + report_msg: dict = self._task_to_report_msg(ts) + if report_msg is not None: + self.report(report_msg, ts=ts, client=client) + + async def feed( + self, comm, function=None, setup=None, teardown=None, interval="1s", **kwargs + ): + """ + Provides a data Comm to external requester - cs = self.clients["fire-and-forget"] - if ts in cs._wants_what: - self._client_releases_keys( - cs=cs, - keys=[key], - recommendations=recommendations, - ) + Caution: this runs arbitrary Python code on the scheduler. This should + eventually be phased out. 
It is mostly used by diagnostics. + """ + if not dask.config.get("distributed.scheduler.pickle"): + logger.warn( + "Tried to call 'feed' route with custom functions, but " + "pickle is disallowed. Set the 'distributed.scheduler.pickle'" + "config value to True to use the 'feed' route (this is mostly " + "commonly used with progress bars)" + ) + return - if self.validate: - assert not ts._processing_on + interval = parse_timedelta(interval) + with log_errors(): + if function: + function = pickle.loads(function) + if setup: + setup = pickle.loads(setup) + if teardown: + teardown = pickle.loads(teardown) + state = setup(self) if setup else None + if inspect.isawaitable(state): + state = await state + try: + while self.status == Status.running: + if state is None: + response = function(self) + else: + response = function(self, state) + await comm.write(response) + await asyncio.sleep(interval) + except (EnvironmentError, CommClosedError): + pass + finally: + if teardown: + teardown(self, state) - return recommendations, worker_msgs, client_msgs - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + def log_worker_event(self, worker=None, topic=None, msg=None): + self.log_event(topic, msg) - pdb.set_trace() - raise + def subscribe_worker_status(self, comm=None): + WorkerStatusPlugin(self, comm) + ident = self.identity() + for v in ident["workers"].values(): + del v["metrics"] + del v["last_seen"] + return ident - def transition_no_worker_released(self, key): - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - dts: TaskState - worker_msgs: dict = {} - client_msgs: dict = {} + def get_processing(self, comm=None, workers=None): + ws: WorkerState + ts: TaskState + if workers is not None: + workers = set(map(self.coerce_address, workers)) + return {w: [ts._key for ts in self.workers[w].processing] for w in workers} + else: + return { + w: [ts._key for ts in ws._processing] for w, ws in self.workers.items() + } - if self.validate: - assert self.tasks[key].state == "no-worker" - assert not ts._who_has - assert not ts._waiting_on + def get_who_has(self, comm=None, keys=None): + ws: WorkerState + ts: TaskState + if keys is not None: + return { + k: [ws._address for ws in self.tasks[k].who_has] + if k in self.tasks + else [] + for k in keys + } + else: + return { + key: [ws._address for ws in ts._who_has] + for key, ts in self.tasks.items() + } - self.unrunnable.remove(ts) - ts.state = "released" + def get_has_what(self, comm=None, workers=None): + ws: WorkerState + ts: TaskState + if workers is not None: + workers = map(self.coerce_address, workers) + return { + w: [ts._key for ts in self.workers[w].has_what] + if w in self.workers + else [] + for w in workers + } + else: + return { + w: [ts._key for ts in ws._has_what] for w, ws in self.workers.items() + } - for dts in ts._dependencies: - dts._waiters.discard(ts) + def get_ncores(self, comm=None, workers=None): + ws: WorkerState + if workers is not None: + workers = map(self.coerce_address, workers) + return {w: self.workers[w].nthreads for w in workers if w in self.workers} + else: + return {w: ws._nthreads for w, ws in self.workers.items()} - ts._waiters.clear() + async def get_call_stack(self, comm=None, keys=None): + ts: TaskState + dts: TaskState + if keys is not None: + stack = list(keys) + processing = set() + while stack: + key = stack.pop() + ts = self.tasks[key] + if ts._state == "waiting": + stack.extend([dts._key for dts in ts._dependencies]) + elif ts._state == "processing": + processing.add(ts) - return 
{}, worker_msgs, client_msgs - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + workers = defaultdict(list) + for ts in processing: + if ts._processing_on: + workers[ts._processing_on.address].append(ts._key) + else: + workers = {w: None for w in self.workers} - pdb.set_trace() - raise + if not workers: + return {} - def remove_key(self, key): - tasks: dict = self.tasks - ts: TaskState = tasks.pop(key) - assert ts._state == "forgotten" - self.unrunnable.discard(ts) - cs: ClientState - for cs in ts._who_wants: - cs._wants_what.remove(ts) - ts._who_wants.clear() - ts._processing_on = None - ts._exception_blame = ts._exception = ts._traceback = None - self.task_metadata.pop(key, None) + results = await asyncio.gather( + *(self.rpc(w).call_stack(keys=v) for w, v in workers.items()) + ) + response = {w: r for w, r in zip(workers, results) if r} + return response - def _propagate_forgotten(self, ts: TaskState, recommendations: dict): - worker_msgs: dict = {} - workers: dict = cast(dict, self.workers) - ts.state = "forgotten" - key: str = ts._key - dts: TaskState - for dts in ts._dependents: - dts._has_lost_dependencies = True - dts._dependencies.remove(ts) - dts._waiting_on.discard(ts) - if dts._state not in ("memory", "erred"): - # Cannot compute task anymore - recommendations[dts._key] = "forgotten" - ts._dependents.clear() - ts._waiters.clear() + def get_nbytes(self, comm=None, keys=None, summary=True): + ts: TaskState + with log_errors(): + if keys is not None: + result = {k: self.tasks[k].nbytes for k in keys} + else: + result = { + k: ts._nbytes for k, ts in self.tasks.items() if ts._nbytes >= 0 + } - for dts in ts._dependencies: - dts._dependents.remove(ts) - s: set = dts._waiters - s.discard(ts) - if not dts._dependents and not dts._who_wants: - # Task not needed anymore - assert dts is not ts - recommendations[dts._key] = "forgotten" - ts._dependencies.clear() - ts._waiting_on.clear() + if summary: + out = defaultdict(lambda: 0) + for k, v in result.items(): + out[key_split(k)] += v + result = dict(out) - if ts._who_has: - ts._group._nbytes_in_memory -= ts.get_nbytes() + return result - ws: WorkerState - for ws in ts._who_has: - ws._has_what.remove(ts) - ws._nbytes -= ts.get_nbytes() - w: str = ws._address - if w in workers: # in case worker has died - worker_msgs[w] = {"op": "delete-data", "keys": [key], "report": False} - ts._who_has.clear() + def run_function(self, stream, function, args=(), kwargs={}, wait=True): + """Run a function within this process - return worker_msgs + See Also + -------- + Client.run_on_scheduler: + """ + from .worker import run - def transition_memory_forgotten(self, key): - tasks: dict - ws: WorkerState - try: - tasks = self.tasks - ts: TaskState = tasks[key] - worker_msgs: dict = {} - client_msgs: dict = {} + self.log_event("all", {"action": "run-function", "function": function}) + return run(self, stream, function=function, args=args, kwargs=kwargs, wait=wait) - if self.validate: - assert ts._state == "memory" - assert not ts._processing_on - assert not ts._waiting_on - if not ts._run_spec: - # It's ok to forget a pure data task - pass - elif ts._has_lost_dependencies: - # It's ok to forget a task with forgotten dependencies - pass - elif not ts._who_wants and not ts._waiters and not ts._dependents: - # It's ok to forget a task that nobody needs - pass - else: - assert 0, (ts,) + def set_metadata(self, comm=None, keys=None, value=None): + try: + metadata = self.task_metadata + for key in keys[:-1]: + if key not in metadata or not 
isinstance(metadata[key], (dict, list)): + metadata[key] = dict() + metadata = metadata[key] + metadata[keys[-1]] = value + except Exception as e: + import pdb - recommendations: dict = {} + pdb.set_trace() - if ts._actor: - for ws in ts._who_has: - ws._actors.discard(ts) + def get_metadata(self, comm=None, keys=None, default=no_default): + metadata = self.task_metadata + for key in keys[:-1]: + metadata = metadata[key] + try: + return metadata[keys[-1]] + except KeyError: + if default != no_default: + return default + else: + raise - worker_msgs = self._propagate_forgotten(ts, recommendations) + def get_task_status(self, comm=None, keys=None): + return { + key: (self.tasks[key].state if key in self.tasks else None) for key in keys + } - client_msgs = self._task_to_client_msgs(ts) - self.remove_key(key) + def get_task_stream(self, comm=None, start=None, stop=None, count=None): + from distributed.diagnostics.task_stream import TaskStreamPlugin - return recommendations, worker_msgs, client_msgs - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + self.add_plugin(TaskStreamPlugin, idempotent=True) + tsp = [p for p in self.plugins if isinstance(p, TaskStreamPlugin)][0] + return tsp.collect(start=start, stop=stop, count=count) - pdb.set_trace() - raise + def start_task_metadata(self, comm=None, name=None): + plugin = CollectTaskMetaDataPlugin(scheduler=self, name=name) - def transition_released_forgotten(self, key): - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - worker_msgs: dict = {} - client_msgs: dict = {} + self.add_plugin(plugin) - if self.validate: - assert ts._state in ("released", "erred") - assert not ts._who_has - assert not ts._processing_on - assert not ts._waiting_on, (ts, ts._waiting_on) - if not ts._run_spec: - # It's ok to forget a pure data task - pass - elif ts._has_lost_dependencies: - # It's ok to forget a task with forgotten dependencies - pass - elif not ts._who_wants and not ts._waiters and not ts._dependents: - # It's ok to forget a task that nobody needs - pass - else: - assert 0, (ts,) + def stop_task_metadata(self, comm=None, name=None): + plugins = [ + p + for p in self.plugins + if isinstance(p, CollectTaskMetaDataPlugin) and p.name == name + ] + if len(plugins) != 1: + raise ValueError( + "Expected to find exactly one CollectTaskMetaDataPlugin " + f"with name {name} but found {len(plugins)}." 
+ ) - recommendations: dict = {} - self._propagate_forgotten(ts, recommendations) + plugin = plugins[0] + self.remove_plugin(plugin) + return {"metadata": plugin.metadata, "state": plugin.state} - client_msgs = self._task_to_client_msgs(ts) - self.remove_key(key) + async def register_worker_plugin(self, comm, plugin, name=None): + """ Registers a setup function, and call it on every worker """ + self.worker_plugins.append({"plugin": plugin, "name": name}) - return recommendations, worker_msgs, client_msgs - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + responses = await self.broadcast( + msg=dict(op="plugin-add", plugin=plugin, name=name) + ) + return responses - pdb.set_trace() - raise + ##################### + # State Transitions # + ##################### + + def remove_key(self, key): + tasks: dict = self.tasks + ts: TaskState = tasks.pop(key) + assert ts._state == "forgotten" + self.unrunnable.discard(ts) + cs: ClientState + for cs in ts._who_wants: + cs._wants_what.remove(ts) + ts._who_wants.clear() + ts._processing_on = None + ts._exception_blame = ts._exception = ts._traceback = None + self.task_metadata.pop(key, None) def transition(self, key, finish, *args, **kwargs): """Transition a key from its current state to the finish state @@ -5782,98 +5921,6 @@ def reschedule(self, key=None, worker=None): # Assigning Tasks to Workers # ############################## - def check_idle_saturated(self, ws: WorkerState, occ: double = -1.0): - """Update the status of the idle and saturated state - - The scheduler keeps track of workers that are .. - - - Saturated: have enough work to stay busy - - Idle: do not have enough work to stay busy - - They are considered saturated if they both have enough tasks to occupy - all of their threads, and if the expected runtime of those tasks is - large enough. - - This is useful for load balancing and adaptivity. - """ - total_nthreads: Py_ssize_t = self.total_nthreads - if total_nthreads == 0 or ws.status == Status.closed: - return - if occ < 0: - occ = ws._occupancy - - nc: Py_ssize_t = ws._nthreads - p: Py_ssize_t = len(ws._processing) - total_occupancy: double = self.total_occupancy - avg: double = total_occupancy / total_nthreads - - idle = self.idle - saturated: set = self.saturated - if p < nc or occ < nc * avg / 2: - idle[ws._address] = ws - saturated.discard(ws) - else: - idle.pop(ws._address, None) - - if p > nc: - pending: double = occ * (p - nc) / (p * nc) - if 0.4 < pending > 1.9 * avg: - saturated.add(ws) - return - - saturated.discard(ws) - - def valid_workers(self, ts: TaskState) -> set: - """Return set of currently valid workers for key - - If all workers are valid then this returns ``None``. - This checks tracks the following state: - - * worker_restrictions - * host_restrictions - * resource_restrictions - """ - workers: dict = cast(dict, self.workers) - s: set = None - - if ts._worker_restrictions: - s = {w for w in ts._worker_restrictions if w in workers} - - if ts._host_restrictions: - # Resolve the alias here rather than early, for the worker - # may not be connected when host_restrictions is populated - hr: list = [self.coerce_hostname(h) for h in ts._host_restrictions] - # XXX need HostState? 
- sl: list = [ - self.host_info[h]["addresses"] for h in hr if h in self.host_info - ] - ss: set = set.union(*sl) if sl else set() - if s is None: - s = ss - else: - s |= ss - - if ts._resource_restrictions: - dw: dict = { - resource: { - w - for w, supplied in self.resources[resource].items() - if supplied >= required - } - for resource, required in ts._resource_restrictions.items() - } - - ww: set = set.intersection(*dw.values()) - if s is None: - s = ww - else: - s &= ww - - if s is not None: - s = {workers[w] for w in s} - - return s - def consume_resources(self, ts: TaskState, ws: WorkerState): if ts._resource_restrictions: for r, required in ts._resource_restrictions.items(): @@ -5965,29 +6012,6 @@ def start_ipython(self, comm=None): ) return self._ipython_kernel.get_connection_info() - def worker_objective(self, ts: TaskState, ws: WorkerState): - """ - Objective function to determine which worker should get the task - - Minimize expected start time. If a tie then break with data storage. - """ - dts: TaskState - nbytes: Py_ssize_t - comm_bytes: Py_ssize_t = 0 - for dts in ts._dependencies: - if ws not in dts._who_has: - nbytes = dts.get_nbytes() - comm_bytes += nbytes - - bandwidth: double = self.bandwidth - stack_time: double = ws._occupancy / ws._nthreads - start_time: double = stack_time + comm_bytes / bandwidth - - if ts._actor: - return (len(ws._actors), start_time, ws._nbytes) - else: - return (start_time, ws._nbytes) - async def get_profile( self, comm=None, @@ -6283,30 +6307,6 @@ def reevaluate_occupancy(self, worker_index=0): logger.error("Error in reevaluate occupancy", exc_info=True) raise - def _reevaluate_occupancy_worker(self, ws: WorkerState): - """ See reevaluate_occupancy """ - old: double = ws._occupancy - new: double = 0 - diff: double - ts: TaskState - est: double - for ts in ws._processing: - est = self.set_duration_estimate(ts, ws) - new += est - - ws._occupancy = new - diff = new - old - self.total_occupancy += diff - self.check_idle_saturated(ws) - - # significant increase in duration - if new > old * 1.3: - steal = self.extensions.get("stealing") - if steal is not None: - for ts in ws._processing: - steal.remove_key_from_stealable(ts) - steal.put_key_in_stealable(ts) - async def check_worker_ttl(self): ws: WorkerState now = time() From 3a02cea137db9206db376803c9d5b64a0ff03d37 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:50:11 -0800 Subject: [PATCH 15/38] Add attributes for `SchedulerState` Provides annotations for some attributes used by `SchedulerState`. Also makes sure to provide an `_` prefix for these as they are accessible in Cython only. --- distributed/scheduler.py | 42 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 63d337cad1b..87c13a86d96 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1547,8 +1547,50 @@ class SchedulerState: ``Scheduler`` affecting different transitions here under-the-hood. In the background ``Worker``s also engage with the ``Scheduler`` affecting these state transitions as well. + + **State** + + The ``Transitions`` object contains the following state variables. + Each variable is listed along with what it stores and a brief + description. 
+ + * **tasks:** ``{task key: TaskState}`` + Tasks currently known to the scheduler + * **unrunnable:** ``{TaskState}`` + Tasks in the "no-worker" state + + * **workers:** ``{worker key: WorkerState}`` + Workers currently connected to the scheduler + * **idle:** ``{WorkerState}``: + Set of workers that are not fully utilized + * **saturated:** ``{WorkerState}``: + Set of workers that are not over-utilized + + * **clients:** ``{client key: ClientState}`` + Clients currently connected to the scheduler + + * **task_duration:** ``{key-prefix: time}`` + Time we expect certain functions to take, e.g. ``{'sum': 0.25}`` """ + _bandwidth: double + _clients: dict + _extensions: dict + _host_info: object + _idle: object + _idle_dv: dict + _n_tasks: Py_ssize_t + _resources: object + _saturated: set + _tasks: dict + _total_nthreads: Py_ssize_t + _total_occupancy: double + _unknown_durations: object + _unrunnable: set + _validate: bint + _workers: object + _workers_dv: dict + def __init__(self, **kwargs): super().__init__(**kwargs) From ac4fe97c679ca1d925608385b46893fdd56dee0e Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:50:12 -0800 Subject: [PATCH 16/38] Initialize attributes in `SchedulerState` --- distributed/scheduler.py | 50 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 87c13a86d96..6b5c4712ba7 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1591,7 +1591,55 @@ class SchedulerState: _workers: object _workers_dv: dict - def __init__(self, **kwargs): + def __init__( + self, + clients: dict = None, + workers=None, + host_info=None, + resources=None, + tasks: dict = None, + unrunnable: set = None, + validate: bint = False, + **kwargs, + ): + self._bandwidth = parse_bytes( + dask.config.get("distributed.scheduler.bandwidth") + ) + if clients is not None: + self._clients = clients + else: + self._clients = dict() + self._clients["fire-and-forget"] = ClientState("fire-and-forget") + self._extensions = dict() + if host_info is not None: + self._host_info = host_info + else: + self._host_info = defaultdict(dict) + self._idle = sortedcontainers.SortedDict() + self._idle_dv: dict = cast(dict, self._idle) + self._n_tasks = 0 + if resources is not None: + self._resources = resources + else: + self._resources = defaultdict(dict) + self._saturated = set() + if tasks is not None: + self._tasks = tasks + else: + self._tasks = dict() + self._total_nthreads = 0 + self._total_occupancy = 0 + self._unknown_durations = defaultdict(set) + if unrunnable is not None: + self._unrunnable = unrunnable + else: + self._unrunnable = set() + self._validate = validate + if workers is not None: + self._workers = workers + else: + self._workers = sortedcontainers.SortedDict() + self._workers_dv: dict = cast(dict, self._workers) super().__init__(**kwargs) def _remove_from_processing(self, ts: TaskState) -> str: From f39caefde564b11f4de3ddb1f14b6aa86e88fc7a Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:50:13 -0800 Subject: [PATCH 17/38] Pass arguments to `super` class --- distributed/scheduler.py | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 6b5c4712ba7..b479016a8ae 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2956,7 +2956,6 @@ def __init__( self.allowed_failures = allowed_failures if validate is None: 
validate = dask.config.get("distributed.scheduler.validate") - self.validate = validate self.proc = psutil.Process() self.delete_interval = parse_timedelta(delete_interval, default="ms") self.synchronize_worker_interval = parse_timedelta( @@ -3034,7 +3033,7 @@ def __init__( self._ipython_kernel = None # Task state - self.tasks = dict() + tasks = dict() self.task_groups = dict() self.task_prefixes = dict() for old_attr, new_attr, wrap in [ @@ -3046,7 +3045,7 @@ def __init__( func = operator.attrgetter(new_attr) if wrap is not None: func = compose(wrap, func) - setattr(self, old_attr, _StateLegacyMapping(self.tasks, func)) + setattr(self, old_attr, _StateLegacyMapping(tasks, func)) for old_attr, new_attr, wrap in [ ("nbytes", "nbytes", None), @@ -3066,7 +3065,7 @@ def __init__( func = operator.attrgetter(new_attr) if wrap is not None: func = compose(wrap, func) - setattr(self, old_attr, _OptionalStateLegacyMapping(self.tasks, func)) + setattr(self, old_attr, _OptionalStateLegacyMapping(tasks, func)) for old_attr, new_attr, wrap in [ ("loose_restrictions", "loose_restrictions", None) @@ -3074,12 +3073,12 @@ def __init__( func = operator.attrgetter(new_attr) if wrap is not None: func = compose(wrap, func) - setattr(self, old_attr, _StateLegacySet(self.tasks, func)) + setattr(self, old_attr, _StateLegacySet(tasks, func)) self.generation = 0 self._last_client = None self._last_time = 0 - self.unrunnable = set() + unrunnable = set() self.n_tasks = 0 self.task_metadata = dict() @@ -3089,18 +3088,17 @@ def __init__( self.unknown_durations = defaultdict(set) # Client state - self.clients = dict() + clients = dict() for old_attr, new_attr, wrap in [ ("wants_what", "wants_what", _legacy_task_key_set) ]: func = operator.attrgetter(new_attr) if wrap is not None: func = compose(wrap, func) - setattr(self, old_attr, _StateLegacyMapping(self.clients, func)) - self.clients["fire-and-forget"] = ClientState("fire-and-forget") + setattr(self, old_attr, _StateLegacyMapping(clients, func)) # Worker state - self.workers = sortedcontainers.SortedDict() + workers = sortedcontainers.SortedDict() for old_attr, new_attr, wrap in [ ("nthreads", "nthreads", None), ("worker_bytes", "nbytes", None), @@ -3114,23 +3112,23 @@ def __init__( func = operator.attrgetter(new_attr) if wrap is not None: func = compose(wrap, func) - setattr(self, old_attr, _StateLegacyMapping(self.workers, func)) + setattr(self, old_attr, _StateLegacyMapping(workers, func)) self.idle = sortedcontainers.SortedDict() self.saturated = set() self.total_nthreads = 0 self.total_occupancy = 0 - self.host_info = defaultdict(dict) - self.resources = defaultdict(dict) + host_info = defaultdict(dict) + resources = defaultdict(dict) self.aliases = dict() - self._task_state_collections = [self.unrunnable] + self._task_state_collections = [unrunnable] self._worker_collections = [ - self.workers, - self.host_info, - self.resources, + workers, + host_info, + resources, self.aliases, ] @@ -3245,6 +3243,13 @@ def __init__( connection_limit=connection_limit, deserialize=False, connection_args=self.connection_args, + clients=clients, + workers=workers, + host_info=host_info, + resources=resources, + tasks=tasks, + unrunnable=unrunnable, + validate=validate, **kwargs, ) From 6379cd17b426c8111bd578589d5b7a1066865c86 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:50:13 -0800 Subject: [PATCH 18/38] Use `SchedulerState` attributes Now that we have typed attributes available through Cython. Use them throughout the scheduler to improve performance. 
Also drop any duplicate attribute assignments in `Scheduler`'s `__init__`. --- distributed/scheduler.py | 576 +++++++++++++++++++-------------------- 1 file changed, 286 insertions(+), 290 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b479016a8ae..e5c95968596 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1646,17 +1646,17 @@ def _remove_from_processing(self, ts: TaskState) -> str: """ Remove *ts* from the set of processing tasks. """ - workers: dict = cast(dict, self.workers) + workers: dict = cast(dict, self._workers) ws: WorkerState = ts._processing_on ts._processing_on = None w: str = ws._address if w in workers: # may have been removed duration = ws._processing.pop(ts) if not ws._processing: - self.total_occupancy -= ws._occupancy + self._total_occupancy -= ws._occupancy ws._occupancy = 0 else: - self.total_occupancy -= duration + self._total_occupancy -= duration ws._occupancy -= duration self.check_idle_saturated(ws) self.release_resources(ts, ws) @@ -1677,7 +1677,7 @@ def _add_to_memory( """ Add *ts* to the set of in-memory tasks. """ - if self.validate: + if self._validate: assert ts not in ws._has_what ts._who_has.add(ws) @@ -1720,7 +1720,7 @@ def _add_to_memory( ts._type = typename ts._group._types.add(typename) - cs = self.clients["fire-and-forget"] + cs = self._clients["fire-and-forget"] if ts in cs._wants_what: self._client_releases_keys( cs=cs, @@ -1730,14 +1730,14 @@ def _add_to_memory( def transition_released_waiting(self, key): try: - tasks: dict = self.tasks - workers: dict = cast(dict, self.workers) + tasks: dict = self._tasks + workers: dict = cast(dict, self._workers) ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} client_msgs: dict = {} - if self.validate: + if self._validate: assert ts._run_spec assert not ts._waiting_on assert not ts._who_has @@ -1773,7 +1773,7 @@ def transition_released_waiting(self, key): if workers: recommendations[key] = "processing" else: - self.unrunnable.add(ts) + self._unrunnable.add(ts) ts.state = "no-worker" return recommendations, worker_msgs, client_msgs @@ -1787,20 +1787,20 @@ def transition_released_waiting(self, key): def transition_no_worker_waiting(self, key): try: - tasks: dict = self.tasks - workers: dict = cast(dict, self.workers) + tasks: dict = self._tasks + workers: dict = cast(dict, self._workers) ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} client_msgs: dict = {} - if self.validate: - assert ts in self.unrunnable + if self._validate: + assert ts in self._unrunnable assert not ts._waiting_on assert not ts._who_has assert not ts._processing_on - self.unrunnable.remove(ts) + self._unrunnable.remove(ts) if ts._has_lost_dependencies: return {key: "forgotten"}, worker_msgs, client_msgs @@ -1822,7 +1822,7 @@ def transition_no_worker_waiting(self, key): if workers: recommendations[key] = "processing" else: - self.unrunnable.add(ts) + self._unrunnable.add(ts) ts.state = "no-worker" return recommendations, worker_msgs, client_msgs @@ -1838,7 +1838,7 @@ def decide_worker(self, ts: TaskState) -> WorkerState: """ Decide on a worker for task *ts*. Return a WorkerState. 
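        For a task with no dependencies and no restrictions, the fallback
        below boils down to the following sketch (``pool`` is an ordered
        mapping of address to a worker exposing an ``occupancy`` attribute and
        ``n_tasks`` is a running counter; both are illustrative stand-ins for
        scheduler state):

            def pick_fallback_worker(pool: dict, n_tasks: int):
                if len(pool) < 20:
                    # Small pool: exact argmin over occupancy is cheap.
                    return min(pool.values(), key=lambda ws: ws.occupancy)
                # Large pool: O(1) pseudo round-robin indexing.
                return list(pool.values())[n_tasks % len(pool)]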
""" - workers: dict = cast(dict, self.workers) + workers: dict = cast(dict, self._workers) ws: WorkerState = None valid_workers: set = self.valid_workers(ts) @@ -1848,7 +1848,7 @@ def decide_worker(self, ts: TaskState) -> WorkerState: and not ts._loose_restrictions and workers ): - self.unrunnable.add(ts) + self._unrunnable.add(ts) ts.state = "no-worker" return ws @@ -1860,16 +1860,16 @@ def decide_worker(self, ts: TaskState) -> WorkerState: partial(self.worker_objective, ts), ) else: - worker_pool = self.idle or self.workers + worker_pool = self._idle or self._workers worker_pool_dv = cast(dict, worker_pool) n_workers: Py_ssize_t = len(worker_pool_dv) if n_workers < 20: # smart but linear in small case ws = min(worker_pool.values(), key=operator.attrgetter("occupancy")) else: # dumb but fast in large case - n_tasks: Py_ssize_t = self.n_tasks + n_tasks: Py_ssize_t = self._n_tasks ws = worker_pool.values()[n_tasks % n_workers] - if self.validate: + if self._validate: assert ws is None or isinstance(ws, WorkerState), ( type(ws), ws, @@ -1897,19 +1897,19 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState): def transition_waiting_processing(self, key): try: - tasks: dict = self.tasks + tasks: dict = self._tasks ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} client_msgs: dict = {} - if self.validate: + if self._validate: assert not ts._waiting_on assert not ts._who_has assert not ts._exception_blame assert not ts._processing_on assert not ts._has_lost_dependencies - assert ts not in self.unrunnable + assert ts not in self._unrunnable assert all([dts._who_has for dts in ts._dependencies]) ws: WorkerState = self.decide_worker(ts) @@ -1920,11 +1920,11 @@ def transition_waiting_processing(self, key): duration_estimate = self.set_duration_estimate(ts, ws) ts._processing_on = ws ws._occupancy += duration_estimate - self.total_occupancy += duration_estimate + self._total_occupancy += duration_estimate ts.state = "processing" self.consume_resources(ts, ws) self.check_idle_saturated(ws) - self.n_tasks += 1 + self._n_tasks += 1 if ts._actor: ws._actors.add(ts) @@ -1944,14 +1944,14 @@ def transition_waiting_processing(self, key): def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): try: - workers: dict = cast(dict, self.workers) + workers: dict = cast(dict, self._workers) ws: WorkerState = workers[worker] - tasks: dict = self.tasks + tasks: dict = self._tasks ts: TaskState = tasks[key] worker_msgs: dict = {} client_msgs: dict = {} - if self.validate: + if self._validate: assert not ts._processing_on assert ts._waiting_on assert ts._state == "waiting" @@ -1968,7 +1968,7 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): self._add_to_memory(ts, ws, recommendations, client_msgs, **kwargs) - if self.validate: + if self._validate: assert not ts._processing_on assert not ts._waiting_on assert ts._who_has @@ -1997,12 +1997,12 @@ def transition_processing_memory( worker_msgs: dict = {} client_msgs: dict = {} try: - tasks: dict = self.tasks + tasks: dict = self._tasks ts: TaskState = tasks[key] assert worker assert isinstance(worker, str) - if self.validate: + if self._validate: assert ts._processing_on ws = ts._processing_on assert ts in ws._processing @@ -2011,7 +2011,7 @@ def transition_processing_memory( assert not ts._exception_blame assert ts._state == "processing" - workers: dict = cast(dict, self.workers) + workers: dict = cast(dict, self._workers) ws = workers.get(worker) if ws is None: return {key: "released"}, 
worker_msgs, client_msgs @@ -2062,14 +2062,14 @@ def transition_processing_memory( ts._group._duration += new_duration tts: TaskState - for tts in self.unknown_durations.pop(ts._prefix._name, ()): + for tts in self._unknown_durations.pop(ts._prefix._name, ()): if tts._processing_on: wws = tts._processing_on old = wws._processing[tts] comm = self.get_comm_cost(tts, wws) wws._processing[tts] = avg_duration + comm wws._occupancy += avg_duration + comm - old - self.total_occupancy += avg_duration + comm - old + self._total_occupancy += avg_duration + comm - old ############################ # Update State Information # @@ -2086,7 +2086,7 @@ def transition_processing_memory( ts, ws, recommendations, client_msgs, type=type, typename=typename ) - if self.validate: + if self._validate: assert not ts._processing_on assert not ts._waiting_on @@ -2102,13 +2102,13 @@ def transition_processing_memory( def transition_memory_released(self, key, safe=False): ws: WorkerState try: - tasks: dict = self.tasks + tasks: dict = self._tasks ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} client_msgs: dict = {} - if self.validate: + if self._validate: assert not ts._waiting_on assert not ts._processing_on if safe: @@ -2161,7 +2161,7 @@ def transition_memory_released(self, key, safe=False): elif ts._who_wants or ts._waiters: recommendations[key] = "waiting" - if self.validate: + if self._validate: assert not ts._waiting_on return recommendations, worker_msgs, client_msgs @@ -2175,14 +2175,14 @@ def transition_memory_released(self, key, safe=False): def transition_released_erred(self, key): try: - tasks: dict = self.tasks + tasks: dict = self._tasks ts: TaskState = tasks[key] dts: TaskState failing_ts: TaskState worker_msgs: dict = {} client_msgs: dict = {} - if self.validate: + if self._validate: with log_errors(pdb=LOG_PDB): assert ts._exception_blame assert not ts._who_has @@ -2222,13 +2222,13 @@ def transition_released_erred(self, key): def transition_erred_released(self, key): try: - tasks: dict = self.tasks + tasks: dict = self._tasks ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} client_msgs: dict = {} - if self.validate: + if self._validate: with log_errors(pdb=LOG_PDB): assert all([dts._state != "erred" for dts in ts._dependencies]) assert ts._exception_blame @@ -2264,12 +2264,12 @@ def transition_erred_released(self, key): def transition_waiting_released(self, key): try: - tasks: dict = self.tasks + tasks: dict = self._tasks ts: TaskState = tasks[key] worker_msgs: dict = {} client_msgs: dict = {} - if self.validate: + if self._validate: assert not ts._who_has assert not ts._processing_on @@ -2304,17 +2304,17 @@ def transition_waiting_released(self, key): def transition_processing_released(self, key): try: - tasks: dict = self.tasks + tasks: dict = self._tasks ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} client_msgs: dict = {} - if self.validate: + if self._validate: assert ts._processing_on assert not ts._who_has assert not ts._waiting_on - assert self.tasks[key].state == "processing" + assert self._tasks[key].state == "processing" w: str = self._remove_from_processing(ts) if w: @@ -2338,7 +2338,7 @@ def transition_processing_released(self, key): recommendations[dts._key] = "released" ts._waiters.clear() - if self.validate: + if self._validate: assert not ts._processing_on return recommendations, worker_msgs, client_msgs @@ -2355,14 +2355,14 @@ def transition_processing_erred( ): ws: WorkerState try: - tasks: dict = self.tasks + tasks: dict = 
self._tasks ts: TaskState = tasks[key] dts: TaskState failing_ts: TaskState worker_msgs: dict = {} client_msgs: dict = {} - if self.validate: + if self._validate: assert cause or ts._exception_blame assert ts._processing_on assert not ts._who_has @@ -2379,7 +2379,7 @@ def transition_processing_erred( if traceback is not None: ts._traceback = traceback if cause is not None: - failing_ts = self.tasks[cause] + failing_ts = self._tasks[cause] ts._exception_blame = failing_ts else: failing_ts = ts._exception_blame @@ -2410,7 +2410,7 @@ def transition_processing_erred( for cs in ts._who_wants: client_msgs[cs._client_key] = report_msg - cs = self.clients["fire-and-forget"] + cs = self._clients["fire-and-forget"] if ts in cs._wants_what: self._client_releases_keys( cs=cs, @@ -2418,7 +2418,7 @@ def transition_processing_erred( recommendations=recommendations, ) - if self.validate: + if self._validate: assert not ts._processing_on return recommendations, worker_msgs, client_msgs @@ -2432,18 +2432,18 @@ def transition_processing_erred( def transition_no_worker_released(self, key): try: - tasks: dict = self.tasks + tasks: dict = self._tasks ts: TaskState = tasks[key] dts: TaskState worker_msgs: dict = {} client_msgs: dict = {} - if self.validate: - assert self.tasks[key].state == "no-worker" + if self._validate: + assert self._tasks[key].state == "no-worker" assert not ts._who_has assert not ts._waiting_on - self.unrunnable.remove(ts) + self._unrunnable.remove(ts) ts.state = "released" for dts in ts._dependencies: @@ -2462,7 +2462,7 @@ def transition_no_worker_released(self, key): def _propagate_forgotten(self, ts: TaskState, recommendations: dict): worker_msgs: dict = {} - workers: dict = cast(dict, self.workers) + workers: dict = cast(dict, self._workers) ts.state = "forgotten" key: str = ts._key dts: TaskState @@ -2505,12 +2505,12 @@ def transition_memory_forgotten(self, key): tasks: dict ws: WorkerState try: - tasks = self.tasks + tasks = self._tasks ts: TaskState = tasks[key] worker_msgs: dict = {} client_msgs: dict = {} - if self.validate: + if self._validate: assert ts._state == "memory" assert not ts._processing_on assert not ts._waiting_on @@ -2548,12 +2548,12 @@ def transition_memory_forgotten(self, key): def transition_released_forgotten(self, key): try: - tasks: dict = self.tasks + tasks: dict = self._tasks ts: TaskState = tasks[key] worker_msgs: dict = {} client_msgs: dict = {} - if self.validate: + if self._validate: assert ts._state in ("released", "erred") assert not ts._who_has assert not ts._processing_on @@ -2599,7 +2599,7 @@ def check_idle_saturated(self, ws: WorkerState, occ: double = -1.0): This is useful for load balancing and adaptivity. 
""" - total_nthreads: Py_ssize_t = self.total_nthreads + total_nthreads: Py_ssize_t = self._total_nthreads if total_nthreads == 0 or ws.status == Status.closed: return if occ < 0: @@ -2607,11 +2607,11 @@ def check_idle_saturated(self, ws: WorkerState, occ: double = -1.0): nc: Py_ssize_t = ws._nthreads p: Py_ssize_t = len(ws._processing) - total_occupancy: double = self.total_occupancy + total_occupancy: double = self._total_occupancy avg: double = total_occupancy / total_nthreads - idle = self.idle - saturated: set = self.saturated + idle = self._idle + saturated: set = self._saturated if p < nc or occ < nc * avg / 2: idle[ws._address] = ws saturated.discard(ws) @@ -2632,7 +2632,7 @@ def _client_releases_keys(self, keys: list, cs: ClientState, recommendations: di ts: TaskState tasks2: set = set() for key in keys: - ts = self.tasks.get(key) + ts = self._tasks.get(key) if ts is not None and ts in cs._wants_what: cs._wants_what.remove(ts) s: set = ts._who_wants @@ -2673,7 +2673,7 @@ def _task_to_msg(self, ts: TaskState, duration=None) -> dict: } msg["nbytes"] = {dts._key: dts._nbytes for dts in deps} - if self.validate: + if self._validate: assert all(msg["who_has"].values()) task = ts._run_spec @@ -2704,7 +2704,7 @@ def _task_to_report_msg(self, ts: TaskState) -> dict: def _task_to_client_msgs(self, ts: TaskState) -> dict: cs: ClientState - clients: dict = self.clients + clients: dict = self._clients client_keys: list if ts is None: # Notify all clients @@ -2734,12 +2734,12 @@ def _reevaluate_occupancy_worker(self, ws: WorkerState): ws._occupancy = new diff = new - old - self.total_occupancy += diff + self._total_occupancy += diff self.check_idle_saturated(ws) # significant increase in duration if new > old * 1.3: - steal = self.extensions.get("stealing") + steal = self._extensions.get("stealing") if steal is not None: for ts in ws._processing: steal.remove_key_from_stealable(ts) @@ -2753,7 +2753,7 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState): dts: TaskState deps: set = ts._dependencies - ws._has_what nbytes: Py_ssize_t = 0 - bandwidth: double = self.bandwidth + bandwidth: double = self._bandwidth for dts in deps: nbytes += dts._nbytes return nbytes / bandwidth @@ -2765,7 +2765,7 @@ def get_task_duration(self, ts: TaskState, default: double = -1): """ duration: double = ts._prefix._duration_average if duration < 0: - s: set = self.unknown_durations[ts._prefix._name] + s: set = self._unknown_durations[ts._prefix._name] s.add(ts) if default < 0: duration = UNKNOWN_TASK_DURATION @@ -2784,7 +2784,7 @@ def valid_workers(self, ts: TaskState) -> set: * host_restrictions * resource_restrictions """ - workers: dict = cast(dict, self.workers) + workers: dict = cast(dict, self._workers) s: set = None if ts._worker_restrictions: @@ -2796,7 +2796,7 @@ def valid_workers(self, ts: TaskState) -> set: hr: list = [self.coerce_hostname(h) for h in ts._host_restrictions] # XXX need HostState? 
sl: list = [ - self.host_info[h]["addresses"] for h in hr if h in self.host_info + self._host_info[h]["addresses"] for h in hr if h in self._host_info ] ss: set = set.union(*sl) if sl else set() if s is None: @@ -2808,7 +2808,7 @@ def valid_workers(self, ts: TaskState) -> set: dw: dict = { resource: { w - for w, supplied in self.resources[resource].items() + for w, supplied in self._resources[resource].items() if supplied >= required } for resource, required in ts._resource_restrictions.items() @@ -2839,7 +2839,7 @@ def worker_objective(self, ts: TaskState, ws: WorkerState): nbytes = dts.get_nbytes() comm_bytes += nbytes - bandwidth: double = self.bandwidth + bandwidth: double = self._bandwidth stack_time: double = ws._occupancy / ws._nthreads start_time: double = stack_time + comm_bytes / bandwidth @@ -2978,7 +2978,6 @@ def __init__( self.idle_since = time() self.time_started = self.idle_since # compatibility for dask-gateway self._lock = asyncio.Lock() - self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth")) self.bandwidth_workers = defaultdict(float) self.bandwidth_types = defaultdict(float) @@ -3080,12 +3079,10 @@ def __init__( self._last_time = 0 unrunnable = set() - self.n_tasks = 0 self.task_metadata = dict() self.datasets = dict() # Prefix-keyed containers - self.unknown_durations = defaultdict(set) # Client state clients = dict() @@ -3114,11 +3111,6 @@ def __init__( func = compose(wrap, func) setattr(self, old_attr, _StateLegacyMapping(workers, func)) - self.idle = sortedcontainers.SortedDict() - self.saturated = set() - - self.total_nthreads = 0 - self.total_occupancy = 0 host_info = defaultdict(dict) resources = defaultdict(dict) self.aliases = dict() @@ -3132,7 +3124,6 @@ def __init__( self.aliases, ] - self.extensions = {} self.plugins = list(plugins) self.transition_log = deque( maxlen=dask.config.get("distributed.scheduler.transition-log-length") @@ -3280,8 +3271,8 @@ def __init__( def __repr__(self): return '' % ( self.address, - len(self.workers), - self.total_nthreads, + len(self._workers), + self._total_nthreads, ) def identity(self, comm=None): @@ -3292,7 +3283,7 @@ def identity(self, comm=None): "address": self.address, "services": {key: v.port for (key, v) in self.services.items()}, "workers": { - worker.address: worker.identity() for worker in self.workers.values() + worker.address: worker.identity() for worker in self._workers.values() }, } return d @@ -3311,7 +3302,7 @@ def get_worker_service_addr(self, worker, service_name, protocol=False): Whether or not to include a full address with protocol (True) or just a (host, port) pair """ - ws: WorkerState = self.workers[worker] + ws: WorkerState = self._workers[worker] port = ws._services.get(service_name) if port is None: return None @@ -3405,10 +3396,10 @@ async def close(self, comm=None, fast=False, close_workers=False): if close_workers: await self.broadcast(msg={"op": "close_gracefully"}, nanny=True) - for worker in self.workers: + for worker in self._workers: self.worker_send(worker, {"op": "close"}) for i in range(20): # wait a second for send signals to clear - if self.workers: + if self._workers: await asyncio.sleep(0.05) else: break @@ -3421,7 +3412,7 @@ async def close(self, comm=None, fast=False, close_workers=False): self.stop_services() - for ext in self.extensions.values(): + for ext in self._extensions.values(): with suppress(AttributeError): ext.teardown() logger.info("Scheduler closing all comms") @@ -3459,7 +3450,7 @@ async def close_worker(self, comm=None, worker=None, 
safe=None): logger.info("Closing worker %s", worker) with log_errors(): self.log_event(worker, {"action": "close-worker"}) - ws: WorkerState = self.workers[worker] + ws: WorkerState = self._workers[worker] nanny_addr = ws._nanny address = nanny_addr or worker @@ -3483,7 +3474,7 @@ def heartbeat_worker( ): address = self.coerce_address(address, resolve_address) address = normalize_address(address) - if address not in self.workers: + if address not in self._workers: return {"status": "missing"} host = get_address_host(address) @@ -3492,10 +3483,10 @@ def heartbeat_worker( assert metrics host_info = host_info or {} - self.host_info[host]["last-seen"] = local_now - frac = 1 / len(self.workers) - self.bandwidth = ( - self.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac + self._host_info[host]["last-seen"] = local_now + frac = 1 / len(self._workers) + self._bandwidth = ( + self._bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac ) for other, (bw, count) in metrics["bandwidth"]["workers"].items(): if (address, other) not in self.bandwidth_workers: @@ -3514,7 +3505,7 @@ def heartbeat_worker( 1 - alpha ) - ws: WorkerState = self.workers[address] + ws: WorkerState = self._workers[address] ws._last_seen = time() @@ -3527,7 +3518,7 @@ def heartbeat_worker( ws._metrics = metrics if host_info: - self.host_info[host].update(host_info) + self._host_info[host].update(host_info) delay = time() - now ws._time_delay = delay @@ -3540,7 +3531,7 @@ def heartbeat_worker( return { "status": "OK", "time": time(), - "heartbeat-interval": heartbeat_interval(len(self.workers)), + "heartbeat-interval": heartbeat_interval(len(self._workers)), } async def add_worker( @@ -3571,7 +3562,7 @@ async def add_worker( address = normalize_address(address) host = get_address_host(address) - ws: WorkerState = self.workers.get(address) + ws: WorkerState = self._workers.get(address) if ws is not None: raise ValueError("Worker already exists %s" % ws) @@ -3588,7 +3579,7 @@ async def add_worker( await comm.write(msg) return - self.workers[address] = ws = WorkerState( + self._workers[address] = ws = WorkerState( address=address, pid=pid, nthreads=nthreads, @@ -3601,13 +3592,13 @@ async def add_worker( extra=extra, ) - if "addresses" not in self.host_info[host]: - self.host_info[host].update({"addresses": set(), "nthreads": 0}) + if "addresses" not in self._host_info[host]: + self._host_info[host].update({"addresses": set(), "nthreads": 0}) - self.host_info[host]["addresses"].add(address) - self.host_info[host]["nthreads"] += nthreads + self._host_info[host]["addresses"].add(address) + self._host_info[host]["nthreads"] += nthreads - self.total_nthreads += nthreads + self._total_nthreads += nthreads self.aliases[name] = address response = self.heartbeat_worker( @@ -3619,7 +3610,7 @@ async def add_worker( metrics=metrics, ) - # Do not need to adjust self.total_occupancy as self.occupancy[ws] cannot exist before this. + # Do not need to adjust self._total_occupancy as self.occupancy[ws] cannot exist before this. 
self.check_idle_saturated(ws) # for key in keys: # TODO @@ -3628,7 +3619,7 @@ async def add_worker( self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop) if ws._nthreads > len(ws._processing): - self.idle[ws._address] = ws + self._idle[ws._address] = ws for plugin in self.plugins[:]: try: @@ -3641,7 +3632,7 @@ async def add_worker( recommendations: dict if nbytes: for key in nbytes: - tasks: dict = self.tasks + tasks: dict = self._tasks ts: TaskState = tasks.get(key) if ts is not None and ts._state in ("processing", "waiting"): recommendations = self.transition( @@ -3654,7 +3645,7 @@ async def add_worker( self.transitions(recommendations) recommendations = {} - for ts in list(self.unrunnable): + for ts in list(self._unrunnable): valid: set = self.valid_workers(ts) if valid is None or ws in valid: recommendations[ts._key] = "waiting" @@ -3669,7 +3660,7 @@ async def add_worker( msg = { "status": "OK", "time": time(), - "heartbeat-interval": heartbeat_interval(len(self.workers)), + "heartbeat-interval": heartbeat_interval(len(self._workers)), "worker-plugins": self.worker_plugins, } @@ -3677,8 +3668,12 @@ async def add_worker( version_warning = version_module.error_message( version_module.get_versions(), merge( - {w: ws._versions for w, ws in self.workers.items()}, - {c: cs._versions for c, cs in self.clients.items() if cs._versions}, + {w: ws._versions for w, ws in self._workers.items()}, + { + c: cs._versions + for c, cs in self._clients.items() + if cs._versions + }, ), versions, client_name="This Worker", @@ -3784,7 +3779,7 @@ def update_graph( n = len(tasks) for k, deps in list(dependencies.items()): if any( - dep not in self.tasks and dep not in tasks for dep in deps + dep not in self._tasks and dep not in tasks for dep in deps ): # bad key logger.info("User asked for computation on lost data, %s", k) del tasks[k] @@ -3798,8 +3793,8 @@ def update_graph( ts: TaskState already_in_memory = set() # tasks that are already done for k, v in dependencies.items(): - if v and k in self.tasks: - ts = self.tasks[k] + if v and k in self._tasks: + ts = self._tasks[k] if ts._state in ("memory", "erred"): already_in_memory.add(k) @@ -3810,7 +3805,7 @@ def update_graph( done = set(already_in_memory) while stack: # remove unnecessary dependencies key = stack.pop() - ts = self.tasks[key] + ts = self._tasks[key] try: deps = dependencies[key] except KeyError: @@ -3821,7 +3816,7 @@ def update_graph( else: child_deps = self.dependencies[dep] if all(d in done for d in child_deps): - if dep in self.tasks and dep not in done: + if dep in self._tasks and dep not in done: done.add(dep) stack.append(dep) @@ -3838,7 +3833,7 @@ def update_graph( if k in touched_keys: continue # XXX Have a method get_task_state(self, k) ? 
- ts = self.tasks.get(k) + ts = self._tasks.get(k) if ts is None: ts = self.new_task(k, tasks.get(k), "released") elif not ts._run_spec: @@ -3852,11 +3847,11 @@ def update_graph( # Add dependencies for key, deps in dependencies.items(): - ts = self.tasks.get(key) + ts = self._tasks.get(key) if ts is None or ts._dependencies: continue for dep in deps: - dts = self.tasks[dep] + dts = self._tasks[dep] ts.add_dependency(dts) # Compute priorities @@ -3890,14 +3885,14 @@ def update_graph( for a, kv in annotations.items(): for k, v in kv.items(): - ts = self.tasks[k] + ts = self._tasks[k] ts._annotations[a] = v # Add actors if actors is True: actors = list(keys) for actor in actors or []: - ts = self.tasks[actor] + ts = self._tasks[actor] ts._actor = True priority = priority or dask.order.order( @@ -3905,7 +3900,7 @@ def update_graph( ) # TODO: define order wrt old graph if submitting_task: # sub-tasks get better priority than parent tasks - ts = self.tasks.get(submitting_task) + ts = self._tasks.get(submitting_task) if ts is not None: generation = ts._priority[0] - 0.01 else: # super-task already cleaned up @@ -3918,7 +3913,7 @@ def update_graph( generation = self.generation for key in set(priority) & touched_keys: - ts = self.tasks[key] + ts = self._tasks[key] if ts._priority is None: ts._priority = (-(user_priority.get(key, 0)), generation, priority[key]) @@ -3934,7 +3929,7 @@ def update_graph( for k, v in restrictions.items(): if v is None: continue - ts = self.tasks.get(k) + ts = self._tasks.get(k) if ts is None: continue ts._host_restrictions = set() @@ -3950,7 +3945,7 @@ def update_graph( if loose_restrictions: for k in loose_restrictions: - ts = self.tasks[k] + ts = self._tasks[k] ts._loose_restrictions = True if resources: @@ -3958,7 +3953,7 @@ def update_graph( if v is None: continue assert isinstance(v, dict) - ts = self.tasks.get(k) + ts = self._tasks.get(k) if ts is None: continue ts._resource_restrictions = v @@ -3966,7 +3961,7 @@ def update_graph( if retries: for k, v in retries.items(): assert isinstance(v, int) - ts = self.tasks.get(k) + ts = self._tasks.get(k) if ts is None: continue ts._retries = v @@ -4035,18 +4030,18 @@ def new_task(self, key, spec, state): tg._prefix = tp tp._groups.append(tg) tg.add(ts) - self.tasks[key] = ts + self._tasks[key] = ts return ts def stimulus_task_finished(self, key=None, worker=None, **kwargs): """ Mark that a task has finished execution on a particular worker """ logger.debug("Stimulus task finished %s, %s", key, worker) - tasks: dict = self.tasks + tasks: dict = self._tasks ts: TaskState = tasks.get(key) if ts is None: return {} - workers: dict = cast(dict, self.workers) + workers: dict = cast(dict, self._workers) ws: WorkerState = workers[worker] ts._metadata.update(kwargs["metadata"]) @@ -4077,7 +4072,7 @@ def stimulus_task_erred( """ Mark that a task has erred on a particular worker """ logger.debug("Stimulus task erred %s, %s", key, worker) - ts: TaskState = self.tasks.get(key) + ts: TaskState = self._tasks.get(key) if ts is None: return {} @@ -4109,10 +4104,10 @@ def stimulus_missing_data( with log_errors(): logger.debug("Stimulus missing data %s, %s", key, worker) - ts: TaskState = self.tasks.get(key) + ts: TaskState = self._tasks.get(key) if ts is None or ts._state == "memory": return {} - cts: TaskState = self.tasks.get(cause) + cts: TaskState = self._tasks.get(cause) recommendations: dict = {} @@ -4129,7 +4124,7 @@ def stimulus_missing_data( self.transitions(recommendations) - if self.validate: + if self._validate: assert cause not 
in self.who_has return {} @@ -4147,7 +4142,7 @@ def stimulus_retry(self, comm=None, keys=None, client=None): while stack: key = stack.pop() seen.add(key) - ts = self.tasks[key] + ts = self._tasks[key] erred_deps = [dts._key for dts in ts._dependencies if dts._state == "erred"] if erred_deps: stack.extend(erred_deps) @@ -4157,9 +4152,9 @@ def stimulus_retry(self, comm=None, keys=None, client=None): recommendations: dict = {key: "waiting" for key in roots} self.transitions(recommendations) - if self.validate: + if self._validate: for key in seen: - assert not self.tasks[key].exception_blame + assert not self._tasks[key].exception_blame return tuple(seen) @@ -4177,12 +4172,12 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): address = self.coerce_address(address) - if address not in self.workers: + if address not in self._workers: return "already-removed" host = get_address_host(address) - ws: WorkerState = self.workers[address] + ws: WorkerState = self._workers[address] self.log_event( ["all", address], @@ -4199,21 +4194,21 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): self.remove_resources(address) - self.host_info[host]["nthreads"] -= ws._nthreads - self.host_info[host]["addresses"].remove(address) - self.total_nthreads -= ws._nthreads + self._host_info[host]["nthreads"] -= ws._nthreads + self._host_info[host]["addresses"].remove(address) + self._total_nthreads -= ws._nthreads - if not self.host_info[host]["addresses"]: - del self.host_info[host] + if not self._host_info[host]["addresses"]: + del self._host_info[host] self.rpc.remove(address) del self.stream_comms[address] del self.aliases[ws._name] - self.idle.pop(ws._address, None) - self.saturated.discard(ws) - del self.workers[address] + self._idle.pop(ws._address, None) + self._saturated.discard(ws) + del self._workers[address] ws.status = Status.closed - self.total_occupancy -= ws._occupancy + self._total_occupancy -= ws._occupancy recommendations: dict = {} @@ -4257,16 +4252,16 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): except Exception as e: logger.exception(e) - if not self.workers: + if not self._workers: logger.info("Lost all workers") - for w in self.workers: + for w in self._workers: self.bandwidth_workers.pop((address, w), None) self.bandwidth_workers.pop((w, address), None) def remove_worker_from_events(): # If the worker isn't registered anymore after the delay, remove from events - if address not in self.workers and address in self.events: + if address not in self._workers and address in self.events: del self.events[address] cleanup_delay = parse_timedelta( @@ -4290,10 +4285,10 @@ def stimulus_cancel(self, comm, keys=None, client=None, force=False): def cancel_key(self, key, client, retries=5, force=False): """ Cancel a particular key and all dependents """ # TODO: this should be converted to use the transition mechanism - ts: TaskState = self.tasks.get(key) + ts: TaskState = self._tasks.get(key) dts: TaskState try: - cs: ClientState = self.clients[client] + cs: ClientState = self._clients[client] except KeyError: return if ts is None or not ts._who_wants: # no key yet, lets try again in a moment @@ -4312,13 +4307,13 @@ def cancel_key(self, key, client, retries=5, force=False): self.client_releases_keys(keys=[key], client=cs._client_key) def client_desires_keys(self, keys=None, client=None): - cs: ClientState = self.clients.get(client) + cs: ClientState = self._clients.get(client) if cs is None: # For publish, 
queues etc. - self.clients[client] = cs = ClientState(client) + self._clients[client] = cs = ClientState(client) ts: TaskState for k in keys: - ts = self.tasks.get(k) + ts = self._tasks.get(k) if ts is None: # For publish, queues etc. ts = self.new_task(k, None, "released") @@ -4341,7 +4336,7 @@ def client_releases_keys(self, keys=None, client=None): def client_heartbeat(self, client=None): """ Handle heartbeats from Client """ - cs: ClientState = self.clients[client] + cs: ClientState = self._clients[client] cs._last_seen = time() ################### @@ -4349,7 +4344,7 @@ def client_heartbeat(self, client=None): ################### def validate_released(self, key): - ts: TaskState = self.tasks[key] + ts: TaskState = self._tasks[key] dts: TaskState assert ts._state == "released" assert not ts._waiters @@ -4357,22 +4352,22 @@ def validate_released(self, key): assert not ts._who_has assert not ts._processing_on assert not any([ts in dts._waiters for dts in ts._dependencies]) - assert ts not in self.unrunnable + assert ts not in self._unrunnable def validate_waiting(self, key): - ts: TaskState = self.tasks[key] + ts: TaskState = self._tasks[key] dts: TaskState assert ts._waiting_on assert not ts._who_has assert not ts._processing_on - assert ts not in self.unrunnable + assert ts not in self._unrunnable for dts in ts._dependencies: # We are waiting on a dependency iff it's not stored assert (not not dts._who_has) != (dts in ts._waiting_on) assert ts in dts._waiters # XXX even if dts._who_has? def validate_processing(self, key): - ts: TaskState = self.tasks[key] + ts: TaskState = self._tasks[key] dts: TaskState assert not ts._waiting_on ws: WorkerState = ts._processing_on @@ -4384,36 +4379,36 @@ def validate_processing(self, key): assert ts in dts._waiters def validate_memory(self, key): - ts: TaskState = self.tasks[key] + ts: TaskState = self._tasks[key] dts: TaskState assert ts._who_has assert not ts._processing_on assert not ts._waiting_on - assert ts not in self.unrunnable + assert ts not in self._unrunnable for dts in ts._dependents: assert (dts in ts._waiters) == (dts._state in ("waiting", "processing")) assert ts not in dts._waiting_on def validate_no_worker(self, key): - ts: TaskState = self.tasks[key] + ts: TaskState = self._tasks[key] dts: TaskState - assert ts in self.unrunnable + assert ts in self._unrunnable assert not ts._waiting_on - assert ts in self.unrunnable + assert ts in self._unrunnable assert not ts._processing_on assert not ts._who_has for dts in ts._dependencies: assert dts._who_has def validate_erred(self, key): - ts: TaskState = self.tasks[key] + ts: TaskState = self._tasks[key] assert ts._exception_blame assert not ts._who_has def validate_key(self, key, ts: TaskState = None): try: if ts is None: - ts = self.tasks.get(key) + ts = self._tasks.get(key) if ts is None: logger.debug("Key lost: %s", key) else: @@ -4435,49 +4430,49 @@ def validate_key(self, key, ts: TaskState = None): raise def validate_state(self, allow_overlap=False): - validate_state(self.tasks, self.workers, self.clients) + validate_state(self._tasks, self._workers, self._clients) - if not (set(self.workers) == set(self.stream_comms)): + if not (set(self._workers) == set(self.stream_comms)): raise ValueError("Workers not the same in all collections") ws: WorkerState - for w, ws in self.workers.items(): + for w, ws in self._workers.items(): assert isinstance(w, str), (type(w), w) assert isinstance(ws, WorkerState), (type(ws), ws) assert ws._address == w if not ws._processing: assert not ws._occupancy - 
assert ws._address in cast(dict, self.idle) + assert ws._address in cast(dict, self._idle) ts: TaskState - for k, ts in self.tasks.items(): + for k, ts in self._tasks.items(): assert isinstance(ts, TaskState), (type(ts), ts) assert ts._key == k self.validate_key(k, ts) c: str cs: ClientState - for c, cs in self.clients.items(): + for c, cs in self._clients.items(): # client=None is often used in tests... assert c is None or type(c) == str, (type(c), c) assert type(cs) == ClientState, (type(cs), cs) assert cs._client_key == c - a = {w: ws._nbytes for w, ws in self.workers.items()} + a = {w: ws._nbytes for w, ws in self._workers.items()} b = { w: sum(ts.get_nbytes() for ts in ws._has_what) - for w, ws in self.workers.items() + for w, ws in self._workers.items() } assert a == b, (a, b) actual_total_occupancy = 0 - for worker, ws in self.workers.items(): + for worker, ws in self._workers.items(): assert abs(sum(ws._processing.values()) - ws._occupancy) < 1e-8 actual_total_occupancy += ws._occupancy - assert abs(actual_total_occupancy - self.total_occupancy) < 1e-8, ( + assert abs(actual_total_occupancy - self._total_occupancy) < 1e-8, ( actual_total_occupancy, - self.total_occupancy, + self._total_occupancy, ) ################### @@ -4494,7 +4489,7 @@ def report(self, msg: dict, ts: TaskState = None, client: str = None): if ts is None: msg_key = msg.get("key") if msg_key is not None: - tasks: dict = self.tasks + tasks: dict = self._tasks ts = tasks.get(msg_key) cs: ClientState @@ -4534,7 +4529,7 @@ async def add_client(self, comm, client=None, versions=None): comm.name = "Scheduler->Client" logger.info("Receive client connection: %s", client) self.log_event(["all", client], {"action": "add-client", "client": client}) - self.clients[client] = ClientState(client, versions=versions) + self._clients[client] = ClientState(client, versions=versions) for plugin in self.plugins[:]: try: @@ -4550,7 +4545,7 @@ async def add_client(self, comm, client=None, versions=None): ws: WorkerState version_warning = version_module.error_message( version_module.get_versions(), - {w: ws._versions for w, ws in self.workers.items()}, + {w: ws._versions for w, ws in self._workers.items()}, versions, ) msg.update(version_warning) @@ -4579,7 +4574,7 @@ def remove_client(self, client=None): logger.info("Remove client %s", client) self.log_event(["all", client], {"action": "remove-client", "client": client}) try: - cs: ClientState = self.clients[client] + cs: ClientState = self._clients[client] except KeyError: # XXX is this a legitimate condition? 
pass @@ -4588,7 +4583,7 @@ def remove_client(self, client=None): self.client_releases_keys( keys=[ts._key for ts in cs._wants_what], client=cs._client_key ) - del self.clients[client] + del self._clients[client] for plugin in self.plugins[:]: try: @@ -4598,7 +4593,7 @@ def remove_client(self, client=None): def remove_client_from_events(): # If the client isn't registered anymore after the delay, remove from events - if client not in self.clients and client in self.events: + if client not in self._clients and client in self.events: del self.events[client] cleanup_delay = parse_timedelta( @@ -4623,7 +4618,7 @@ def handle_uncaught_error(self, **msg): logger.exception(clean_exception(**msg)[1]) def handle_task_finished(self, key=None, worker=None, **msg): - if worker not in self.workers: + if worker not in self._workers: return validate_key(key) r = self.stimulus_task_finished(key=key, worker=worker, **msg) @@ -4634,10 +4629,10 @@ def handle_task_erred(self, key=None, **msg): self.transitions(r) def handle_release_data(self, key=None, worker=None, client=None, **msg): - ts: TaskState = self.tasks.get(key) + ts: TaskState = self._tasks.get(key) if ts is None: return - ws: WorkerState = self.workers[worker] + ws: WorkerState = self._workers[worker] if ts._processing_on != ws: return r = self.stimulus_missing_data(key=key, ensure=False, **msg) @@ -4647,11 +4642,11 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): logger.debug("handle missing data key=%s worker=%s", key, errant_worker) self.log.append(("missing", key, errant_worker)) - ts: TaskState = self.tasks.get(key) + ts: TaskState = self._tasks.get(key) if ts is None or not ts._who_has: return - if errant_worker in self.workers: - ws: WorkerState = self.workers[errant_worker] + if errant_worker in self._workers: + ws: WorkerState = self._workers[errant_worker] if ws in ts._who_has: ts._who_has.remove(ws) ws._has_what.remove(ts) @@ -4663,8 +4658,8 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): self.transitions({key: "forgotten"}) def release_worker_data(self, comm=None, keys=None, worker=None): - ws: WorkerState = self.workers[worker] - tasks = {self.tasks[k] for k in keys} + ws: WorkerState = self._workers[worker] + tasks = {self._tasks[k] for k in keys} removed_tasks = tasks & ws._has_what ws._has_what -= removed_tasks @@ -4685,9 +4680,9 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): We stop the task from being stolen in the future, and change task duration accounting as if the task has stopped. 
""" - ts: TaskState = self.tasks[key] - if "stealing" in self.extensions: - self.extensions["stealing"].remove_key_from_stealable(ts) + ts: TaskState = self._tasks[key] + if "stealing" in self._extensions: + self._extensions["stealing"].remove_key_from_stealable(ts) ws: WorkerState = ts._processing_on if ws is None: @@ -4705,7 +4700,7 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): ts._prefix._duration_average = avg_duration ws._occupancy -= ws._processing[ts] - self.total_occupancy -= ws._processing[ts] + self._total_occupancy -= ws._processing[ts] ws._processing[ts] = 0 self.check_idle_saturated(ws) @@ -4792,17 +4787,17 @@ async def scatter( Scheduler.broadcast: """ start = time() - while not self.workers: + while not self._workers: await asyncio.sleep(0.2) if time() > start + timeout: raise TimeoutError("No workers found") if workers is None: ws: WorkerState - nthreads = {w: ws._nthreads for w, ws in self.workers.items()} + nthreads = {w: ws._nthreads for w, ws in self._workers.items()} else: workers = [self.coerce_address(w) for w in workers] - nthreads = {w: self.workers[w].nthreads for w in workers} + nthreads = {w: self._workers[w].nthreads for w in workers} assert isinstance(data, dict) @@ -4830,7 +4825,7 @@ async def gather(self, comm=None, keys=None, serializers=None): keys = list(keys) who_has = {} for key in keys: - ts: TaskState = self.tasks.get(key) + ts: TaskState = self._tasks.get(key) if ts is not None: who_has[key] = [ws._address for ws in ts._who_has] else: @@ -4843,7 +4838,7 @@ async def gather(self, comm=None, keys=None, serializers=None): result = {"status": "OK", "data": data} else: missing_states = [ - (self.tasks[key].state if key in self.tasks else None) + (self._tasks[key].state if key in self._tasks else None) for key in missing_keys ] logger.exception( @@ -4865,7 +4860,7 @@ async def gather(self, comm=None, keys=None, serializers=None): for key, workers in missing_keys.items(): # Task may already be gone if it was held by a # `missing_worker` - ts: TaskState = self.tasks.get(key) + ts: TaskState = self._tasks.get(key) logger.exception( "Workers don't have promised key: %s, %s", str(workers), @@ -4874,7 +4869,7 @@ async def gather(self, comm=None, keys=None, serializers=None): if not workers or ts is None: continue for worker in workers: - ws = self.workers.get(worker) + ws = self._workers.get(worker) if ws is not None and ts in ws._has_what: ws._has_what.remove(ts) ts._who_has.remove(ws) @@ -4895,20 +4890,20 @@ async def restart(self, client=None, timeout=3): """ Restart all workers. Reset local state. 
""" with log_errors(): - n_workers = len(self.workers) + n_workers = len(self._workers) logger.info("Send lost future signal to clients") cs: ClientState ts: TaskState - for cs in self.clients.values(): + for cs in self._clients.values(): self.client_releases_keys( keys=[ts._key for ts in cs._wants_what], client=cs._client_key ) ws: WorkerState - nannies = {addr: ws._nanny for addr, ws in self.workers.items()} + nannies = {addr: ws._nanny for addr, ws in self._workers.items()} - for addr in list(self.workers): + for addr in list(self._workers): try: # Ask the worker to close if it doesn't have a nanny, # otherwise the nanny will kill it anyway @@ -4965,7 +4960,7 @@ async def restart(self, client=None, timeout=3): self.log_event([client, "all"], {"action": "restart", "client": client}) start = time() - while time() < start + 10 and len(self.workers) < n_workers: + while time() < start + 10 and len(self._workers) < n_workers: await asyncio.sleep(0.01) self.report({"op": "restart"}) @@ -4982,17 +4977,17 @@ async def broadcast( """ Broadcast message to workers, return all results """ if workers is None or workers is True: if hosts is None: - workers = list(self.workers) + workers = list(self._workers) else: workers = [] if hosts is not None: for host in hosts: - if host in self.host_info: - workers.extend(self.host_info[host]["addresses"]) + if host in self._host_info: + workers.extend(self._host_info[host]["addresses"]) # TODO replace with worker_list if nanny: - addresses = [self.workers[w].nanny for w in workers] + addresses = [self._workers[w].nanny for w in workers] else: addresses = workers @@ -5032,9 +5027,9 @@ async def _delete_worker_data(self, worker_address, keys): self.rpc(addr=worker_address).delete_data, keys=list(keys), report=False ) - ws: WorkerState = self.workers[worker_address] + ws: WorkerState = self._workers[worker_address] ts: TaskState - tasks: set = {self.tasks[key] for key in keys} + tasks: set = {self._tasks[key] for key in keys} ws._has_what -= tasks for ts in tasks: ts._who_has.remove(ws) @@ -5056,18 +5051,18 @@ async def rebalance(self, comm=None, keys=None, workers=None): with log_errors(): async with self._lock: if keys: - tasks = {self.tasks[k] for k in keys} + tasks = {self._tasks[k] for k in keys} missing_data = [ts._key for ts in tasks if not ts._who_has] if missing_data: return {"status": "missing-data", "keys": missing_data} else: - tasks = set(self.tasks.values()) + tasks = set(self._tasks.values()) if workers: - workers = {self.workers[w] for w in workers} + workers = {self._workers[w] for w in workers} workers_by_task = {ts: ts._who_has & workers for ts in tasks} else: - workers = set(self.workers.values()) + workers = set(self._workers.values()) workers_by_task = {ts: ts._who_has for ts in tasks} ws: WorkerState @@ -5213,7 +5208,7 @@ async def replicate( assert branching_factor > 0 async with self._lock if lock else empty_context: - workers = {self.workers[w] for w in self.workers_list(workers)} + workers = {self._workers[w] for w in self.workers_list(workers)} if n is None: n = len(workers) else: @@ -5221,7 +5216,7 @@ async def replicate( if n == 0: raise ValueError("Can not use replicate to delete data") - tasks = {self.tasks[k] for k in keys} + tasks = {self._tasks[k] for k in keys} missing_data = [ts._key for ts in tasks if not ts._who_has] if missing_data: return {"status": "missing-data", "keys": missing_data} @@ -5358,18 +5353,18 @@ def workers_to_close( Scheduler.retire_workers """ if target is not None and n is None: - n = len(self.workers) - 
target + n = len(self._workers) - target if n is not None: if n < 0: n = 0 - target = len(self.workers) - n + target = len(self._workers) - n if n is None and memory_ratio is None: memory_ratio = 2 ws: WorkerState with log_errors(): - if not n and all([ws._processing for ws in self.workers.values()]): + if not n and all([ws._processing for ws in self._workers.values()]): return [] if key is None: @@ -5379,7 +5374,7 @@ def workers_to_close( ): key = pickle.loads(key) - groups = groupby(key, self.workers.values()) + groups = groupby(key, self._workers.values()) limit_bytes = { k: sum([ws._memory_limit for ws in v]) for k, v in groups.items() @@ -5398,7 +5393,7 @@ def _key(group): idle = sorted(groups, key=_key) to_close = [] - n_remain = len(self.workers) + n_remain = len(self._workers) while idle: group = idle.pop() @@ -5474,7 +5469,7 @@ async def retire_workers( names = set(map(str, names)) workers = [ ws._address - for ws in self.workers.values() + for ws in self._workers.values() if str(ws._name) in names ] if workers is None: @@ -5493,7 +5488,7 @@ async def retire_workers( return {} except KeyError: # keys left during replicate pass - workers = {self.workers[w] for w in workers if w in self.workers} + workers = {self._workers[w] for w in workers if w in self._workers} if not workers: return {} logger.info("Retire workers %s", workers) @@ -5502,7 +5497,7 @@ async def retire_workers( keys = set.union(*[w.has_what for w in workers]) keys = {ts._key for ts in keys if ts._who_has.issubset(workers)} - other_workers = set(self.workers.values()) - workers + other_workers = set(self._workers.values()) - workers if keys: if other_workers: logger.info("Moving %d keys to other workers", len(keys)) @@ -5545,11 +5540,11 @@ def add_keys(self, comm=None, worker=None, keys=()): This should not be used in practice and is mostly here for legacy reasons. However, it is sent by workers from time to time. 
""" - if worker not in self.workers: + if worker not in self._workers: return "not found" - ws: WorkerState = self.workers[worker] + ws: WorkerState = self._workers[worker] for key in keys: - ts: TaskState = self.tasks.get(key) + ts: TaskState = self._tasks.get(key) if ts is not None and ts._state == "memory": if ts not in ws._has_what: ws._nbytes += ts.get_nbytes() @@ -5579,14 +5574,14 @@ def update_data( logger.debug("Update data %s", who_has) for key, workers in who_has.items(): - ts: TaskState = self.tasks.get(key) + ts: TaskState = self._tasks.get(key) if ts is None: ts: TaskState = self.new_task(key, None, "memory") ts.state = "memory" if key in nbytes: ts.set_nbytes(nbytes[key]) for w in workers: - ws: WorkerState = self.workers[w] + ws: WorkerState = self._workers[w] if ts not in ws._has_what: ws._nbytes += ts.get_nbytes() ws._has_what.add(ts) @@ -5600,7 +5595,7 @@ def update_data( def report_on_key(self, key: str = None, ts: TaskState = None, client: str = None): if ts is None: - tasks: dict = self.tasks + tasks: dict = self._tasks ts = tasks.get(key) elif key is None: key = ts._key @@ -5671,10 +5666,10 @@ def get_processing(self, comm=None, workers=None): ts: TaskState if workers is not None: workers = set(map(self.coerce_address, workers)) - return {w: [ts._key for ts in self.workers[w].processing] for w in workers} + return {w: [ts._key for ts in self._workers[w].processing] for w in workers} else: return { - w: [ts._key for ts in ws._processing] for w, ws in self.workers.items() + w: [ts._key for ts in ws._processing] for w, ws in self._workers.items() } def get_who_has(self, comm=None, keys=None): @@ -5682,15 +5677,15 @@ def get_who_has(self, comm=None, keys=None): ts: TaskState if keys is not None: return { - k: [ws._address for ws in self.tasks[k].who_has] - if k in self.tasks + k: [ws._address for ws in self._tasks[k].who_has] + if k in self._tasks else [] for k in keys } else: return { key: [ws._address for ws in ts._who_has] - for key, ts in self.tasks.items() + for key, ts in self._tasks.items() } def get_has_what(self, comm=None, workers=None): @@ -5699,23 +5694,23 @@ def get_has_what(self, comm=None, workers=None): if workers is not None: workers = map(self.coerce_address, workers) return { - w: [ts._key for ts in self.workers[w].has_what] - if w in self.workers + w: [ts._key for ts in self._workers[w].has_what] + if w in self._workers else [] for w in workers } else: return { - w: [ts._key for ts in ws._has_what] for w, ws in self.workers.items() + w: [ts._key for ts in ws._has_what] for w, ws in self._workers.items() } def get_ncores(self, comm=None, workers=None): ws: WorkerState if workers is not None: workers = map(self.coerce_address, workers) - return {w: self.workers[w].nthreads for w in workers if w in self.workers} + return {w: self._workers[w].nthreads for w in workers if w in self._workers} else: - return {w: ws._nthreads for w, ws in self.workers.items()} + return {w: ws._nthreads for w, ws in self._workers.items()} async def get_call_stack(self, comm=None, keys=None): ts: TaskState @@ -5725,7 +5720,7 @@ async def get_call_stack(self, comm=None, keys=None): processing = set() while stack: key = stack.pop() - ts = self.tasks[key] + ts = self._tasks[key] if ts._state == "waiting": stack.extend([dts._key for dts in ts._dependencies]) elif ts._state == "processing": @@ -5736,7 +5731,7 @@ async def get_call_stack(self, comm=None, keys=None): if ts._processing_on: workers[ts._processing_on.address].append(ts._key) else: - workers = {w: None for w in 
self.workers} + workers = {w: None for w in self._workers} if not workers: return {} @@ -5751,10 +5746,10 @@ def get_nbytes(self, comm=None, keys=None, summary=True): ts: TaskState with log_errors(): if keys is not None: - result = {k: self.tasks[k].nbytes for k in keys} + result = {k: self._tasks[k].nbytes for k in keys} else: result = { - k: ts._nbytes for k, ts in self.tasks.items() if ts._nbytes >= 0 + k: ts._nbytes for k, ts in self._tasks.items() if ts._nbytes >= 0 } if summary: @@ -5804,7 +5799,8 @@ def get_metadata(self, comm=None, keys=None, default=no_default): def get_task_status(self, comm=None, keys=None): return { - key: (self.tasks[key].state if key in self.tasks else None) for key in keys + key: (self._tasks[key].state if key in self._tasks else None) + for key in keys } def get_task_stream(self, comm=None, start=None, stop=None, count=None): @@ -5849,10 +5845,10 @@ async def register_worker_plugin(self, comm, plugin, name=None): ##################### def remove_key(self, key): - tasks: dict = self.tasks + tasks: dict = self._tasks ts: TaskState = tasks.pop(key) assert ts._state == "forgotten" - self.unrunnable.discard(ts) + self._unrunnable.discard(ts) cs: ClientState for cs in ts._who_wants: cs._wants_what.remove(ts) @@ -5882,7 +5878,7 @@ def transition(self, key, finish, *args, **kwargs): client_msgs: dict try: try: - ts = self.tasks[key] + ts = self._tasks[key] except KeyError: return {} start = ts._state @@ -5922,7 +5918,7 @@ def transition(self, key, finish, *args, **kwargs): finish2 = ts._state self.transition_log.append((key, start, finish2, recommendations, time())) - if self.validate: + if self._validate: logger.debug( "Transitioned %r %s->%s (actual: %s). Consequence: %s", key, @@ -5939,14 +5935,14 @@ def transition(self, key, finish, *args, **kwargs): ts._dependencies = dependencies except KeyError: pass - self.tasks[ts._key] = ts + self._tasks[ts._key] = ts for plugin in list(self.plugins): try: plugin.transition(key, start, finish2, *args, **kwargs) except Exception: logger.info("Plugin failed with exception", exc_info=True) if ts._state == "forgotten": - del self.tasks[ts._key] + del self._tasks[ts._key] if ts._state == "forgotten" and ts._group._name in self.task_groups: # Remove TaskGroup if all tasks are in the forgotten state @@ -5978,7 +5974,7 @@ def transitions(self, recommendations: dict): new = self.transition(key, finish) recommendations.update(new) - if self.validate: + if self._validate: for key in keys: self.validate_key(key) @@ -5999,7 +5995,7 @@ def reschedule(self, key=None, worker=None): """ ts: TaskState try: - ts = self.tasks[key] + ts = self._tasks[key] except KeyError: logger.warning( "Attempting to reschedule task {}, which was not " @@ -6031,19 +6027,19 @@ def release_resources(self, ts: TaskState, ws: WorkerState): ##################### def add_resources(self, comm=None, worker=None, resources=None): - ws: WorkerState = self.workers[worker] + ws: WorkerState = self._workers[worker] if resources: ws._resources.update(resources) ws._used_resources = {} for resource, quantity in ws._resources.items(): ws._used_resources[resource] = 0 - self.resources[resource][worker] = quantity + self._resources[resource][worker] = quantity return "OK" def remove_resources(self, worker): - ws: WorkerState = self.workers[worker] + ws: WorkerState = self._workers[worker] for resource, quantity in ws._resources.items(): - del self.resources[resource][worker] + del self._resources[resource][worker] def coerce_address(self, addr, resolve=True): """ @@ -6072,7 
+6068,7 @@ def coerce_hostname(self, host): Coerce the hostname of a worker. """ if host in self.aliases: - return self.workers[self.aliases[host]].host + return self._workers[self.aliases[host]].host else: return host @@ -6084,14 +6080,14 @@ def workers_list(self, workers): Returns a list of all worker addresses that match """ if workers is None: - return list(self.workers) + return list(self._workers) out = set() for w in workers: if ":" in w: out.add(w) else: - out.update({ww for ww in self.workers if w in ww}) # TODO: quadratic + out.update({ww for ww in self._workers if w in ww}) # TODO: quadratic return list(out) def start_ipython(self, comm=None): @@ -6119,9 +6115,9 @@ async def get_profile( key=None, ): if workers is None: - workers = self.workers + workers = self._workers else: - workers = set(self.workers) & set(workers) + workers = set(self._workers) & set(workers) if scheduler: return profile.get_profile(self.io_loop.profile, start=start, stop=stop) @@ -6157,9 +6153,9 @@ async def get_profile_metadata( dt = parse_timedelta(dt, default="ms") if workers is None: - workers = self.workers + workers = self._workers else: - workers = set(self.workers) & set(workers) + workers = set(self._workers) & set(workers) results = await asyncio.gather( *(self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers), return_exceptions=True, @@ -6277,10 +6273,10 @@ def profile_to_figure(state): ntasks=total_tasks, tasks_timings=tasks_timings, address=self.address, - nworkers=len(self.workers), - threads=sum([ws._nthreads for ws in self.workers.values()]), + nworkers=len(self._workers), + threads=sum([ws._nthreads for ws in self._workers.values()]), memory=format_bytes( - sum([ws._memory_limit for ws in self.workers.values()]) + sum([ws._memory_limit for ws in self._workers.values()]) ), code=code, dask_version=dask.__version__, @@ -6378,7 +6374,7 @@ def reevaluate_occupancy(self, worker_index=0): next_time = timedelta(seconds=DELAY) if self.proc.cpu_percent() < 50: - workers = list(self.workers.values()) + workers = list(self._workers.values()) for i in range(len(workers)): ws: WorkerState = workers[worker_index % len(workers)] worker_index += 1 @@ -6405,9 +6401,9 @@ def reevaluate_occupancy(self, worker_index=0): async def check_worker_ttl(self): ws: WorkerState now = time() - for ws in self.workers.values(): + for ws in self._workers.values(): if (ws._last_seen < now - self.worker_ttl) and ( - ws._last_seen < now - 10 * heartbeat_interval(len(self.workers)) + ws._last_seen < now - 10 * heartbeat_interval(len(self._workers)) ): logger.warning( "Worker failed to heartbeat within %s seconds. 
Closing: %s", @@ -6418,7 +6414,7 @@ async def check_worker_ttl(self): def check_idle(self): ws: WorkerState - if any([ws._processing for ws in self.workers.values()]) or self.unrunnable: + if any([ws._processing for ws in self._workers.values()]) or self._unrunnable: self.idle_since = None return elif not self.idle_since: @@ -6453,13 +6449,13 @@ def adaptive_target(self, comm=None, target_duration=None): # CPU cpu = math.ceil( - self.total_occupancy / target_duration + self._total_occupancy / target_duration ) # TODO: threads per worker # Avoid a few long tasks from asking for many cores ws: WorkerState tasks_processing = 0 - for ws in self.workers.values(): + for ws in self._workers.values(): tasks_processing += len(ws._processing) if tasks_processing > cpu: @@ -6467,25 +6463,25 @@ def adaptive_target(self, comm=None, target_duration=None): else: cpu = min(tasks_processing, cpu) - if self.unrunnable and not self.workers: + if self._unrunnable and not self._workers: cpu = max(1, cpu) # Memory - limit_bytes = {addr: ws._memory_limit for addr, ws in self.workers.items()} - worker_bytes = [ws._nbytes for ws in self.workers.values()] + limit_bytes = {addr: ws._memory_limit for addr, ws in self._workers.items()} + worker_bytes = [ws._nbytes for ws in self._workers.values()] limit = sum(limit_bytes.values()) total = sum(worker_bytes) if total > 0.6 * limit: - memory = 2 * len(self.workers) + memory = 2 * len(self._workers) else: memory = 0 target = max(memory, cpu) - if target >= len(self.workers): + if target >= len(self._workers): return target else: # Scale down? to_close = self.workers_to_close() - return len(self.workers) - len(to_close) + return len(self._workers) - len(to_close) @cfunc From 980fd675f7fadd7bdf8555ca572ff78e49f45c7b Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:50:14 -0800 Subject: [PATCH 19/38] Use `cast` to access parent class attributes Apparently Cython does not allow C-style access of attributes from Python subclasses of extension types. However we can work around this by `cast`ing `self` as the parent class type. To avoid muddying the code too much, do this once per method and assign the result to `parent`. Then use `parent` for all of the parent attribute access needed. Should fix this issue while still allowing fast C-style attribute access in the `Scheduler` subclass. 
--- distributed/scheduler.py | 485 +++++++++++++++++++++++---------------- 1 file changed, 281 insertions(+), 204 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index e5c95968596..f3c89ceb572 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3269,21 +3269,23 @@ def __init__( ################## def __repr__(self): + parent: SchedulerState = cast(SchedulerState, self) return '' % ( self.address, - len(self._workers), - self._total_nthreads, + len(parent._workers), + parent._total_nthreads, ) def identity(self, comm=None): """ Basic information about ourselves and our cluster """ + parent: SchedulerState = cast(SchedulerState, self) d = { "type": type(self).__name__, "id": str(self.id), "address": self.address, "services": {key: v.port for (key, v) in self.services.items()}, "workers": { - worker.address: worker.identity() for worker in self._workers.values() + worker.address: worker.identity() for worker in parent._workers.values() }, } return d @@ -3302,7 +3304,8 @@ def get_worker_service_addr(self, worker, service_name, protocol=False): Whether or not to include a full address with protocol (True) or just a (host, port) pair """ - ws: WorkerState = self._workers[worker] + parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState = parent._workers[worker] port = ws._services.get(service_name) if port is None: return None @@ -3383,6 +3386,7 @@ async def close(self, comm=None, fast=False, close_workers=False): -------- Scheduler.cleanup """ + parent: SchedulerState = cast(SchedulerState, self) if self.status in (Status.closing, Status.closed, Status.closing_gracefully): await self.finished() return @@ -3396,10 +3400,10 @@ async def close(self, comm=None, fast=False, close_workers=False): if close_workers: await self.broadcast(msg={"op": "close_gracefully"}, nanny=True) - for worker in self._workers: + for worker in parent._workers: self.worker_send(worker, {"op": "close"}) for i in range(20): # wait a second for send signals to clear - if self._workers: + if parent._workers: await asyncio.sleep(0.05) else: break @@ -3412,7 +3416,7 @@ async def close(self, comm=None, fast=False, close_workers=False): self.stop_services() - for ext in self._extensions.values(): + for ext in parent._extensions.values(): with suppress(AttributeError): ext.teardown() logger.info("Scheduler closing all comms") @@ -3447,10 +3451,11 @@ async def close_worker(self, comm=None, worker=None, safe=None): signal to the worker to shut down. 
This works regardless of whether or not the worker has a nanny process restarting it """ + parent: SchedulerState = cast(SchedulerState, self) logger.info("Closing worker %s", worker) with log_errors(): self.log_event(worker, {"action": "close-worker"}) - ws: WorkerState = self._workers[worker] + ws: WorkerState = parent._workers[worker] nanny_addr = ws._nanny address = nanny_addr or worker @@ -3472,9 +3477,10 @@ def heartbeat_worker( metrics=None, executing=None, ): + parent: SchedulerState = cast(SchedulerState, self) address = self.coerce_address(address, resolve_address) address = normalize_address(address) - if address not in self._workers: + if address not in parent._workers: return {"status": "missing"} host = get_address_host(address) @@ -3483,10 +3489,10 @@ def heartbeat_worker( assert metrics host_info = host_info or {} - self._host_info[host]["last-seen"] = local_now - frac = 1 / len(self._workers) - self._bandwidth = ( - self._bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac + parent._host_info[host]["last-seen"] = local_now + frac = 1 / len(parent._workers) + parent._bandwidth = ( + parent._bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac ) for other, (bw, count) in metrics["bandwidth"]["workers"].items(): if (address, other) not in self.bandwidth_workers: @@ -3505,7 +3511,7 @@ def heartbeat_worker( 1 - alpha ) - ws: WorkerState = self._workers[address] + ws: WorkerState = parent._workers[address] ws._last_seen = time() @@ -3518,7 +3524,7 @@ def heartbeat_worker( ws._metrics = metrics if host_info: - self._host_info[host].update(host_info) + parent._host_info[host].update(host_info) delay = time() - now ws._time_delay = delay @@ -3531,7 +3537,7 @@ def heartbeat_worker( return { "status": "OK", "time": time(), - "heartbeat-interval": heartbeat_interval(len(self._workers)), + "heartbeat-interval": heartbeat_interval(len(parent._workers)), } async def add_worker( @@ -3557,12 +3563,13 @@ async def add_worker( extra=None, ): """ Add a new worker to the cluster """ + parent: SchedulerState = cast(SchedulerState, self) with log_errors(): address = self.coerce_address(address, resolve_address) address = normalize_address(address) host = get_address_host(address) - ws: WorkerState = self._workers.get(address) + ws: WorkerState = parent._workers.get(address) if ws is not None: raise ValueError("Worker already exists %s" % ws) @@ -3579,7 +3586,7 @@ async def add_worker( await comm.write(msg) return - self._workers[address] = ws = WorkerState( + parent._workers[address] = ws = WorkerState( address=address, pid=pid, nthreads=nthreads, @@ -3592,13 +3599,13 @@ async def add_worker( extra=extra, ) - if "addresses" not in self._host_info[host]: - self._host_info[host].update({"addresses": set(), "nthreads": 0}) + if "addresses" not in parent._host_info[host]: + parent._host_info[host].update({"addresses": set(), "nthreads": 0}) - self._host_info[host]["addresses"].add(address) - self._host_info[host]["nthreads"] += nthreads + parent._host_info[host]["addresses"].add(address) + parent._host_info[host]["nthreads"] += nthreads - self._total_nthreads += nthreads + parent._total_nthreads += nthreads self.aliases[name] = address response = self.heartbeat_worker( @@ -3610,7 +3617,7 @@ async def add_worker( metrics=metrics, ) - # Do not need to adjust self._total_occupancy as self.occupancy[ws] cannot exist before this. + # Do not need to adjust parent._total_occupancy as self.occupancy[ws] cannot exist before this. 
self.check_idle_saturated(ws) # for key in keys: # TODO @@ -3619,7 +3626,7 @@ async def add_worker( self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop) if ws._nthreads > len(ws._processing): - self._idle[ws._address] = ws + parent._idle[ws._address] = ws for plugin in self.plugins[:]: try: @@ -3632,7 +3639,7 @@ async def add_worker( recommendations: dict if nbytes: for key in nbytes: - tasks: dict = self._tasks + tasks: dict = parent._tasks ts: TaskState = tasks.get(key) if ts is not None and ts._state in ("processing", "waiting"): recommendations = self.transition( @@ -3645,7 +3652,7 @@ async def add_worker( self.transitions(recommendations) recommendations = {} - for ts in list(self._unrunnable): + for ts in list(parent._unrunnable): valid: set = self.valid_workers(ts) if valid is None or ws in valid: recommendations[ts._key] = "waiting" @@ -3660,7 +3667,7 @@ async def add_worker( msg = { "status": "OK", "time": time(), - "heartbeat-interval": heartbeat_interval(len(self._workers)), + "heartbeat-interval": heartbeat_interval(len(parent._workers)), "worker-plugins": self.worker_plugins, } @@ -3668,10 +3675,10 @@ async def add_worker( version_warning = version_module.error_message( version_module.get_versions(), merge( - {w: ws._versions for w, ws in self._workers.items()}, + {w: ws._versions for w, ws in parent._workers.items()}, { c: cs._versions - for c, cs in self._clients.items() + for c, cs in parent._clients.items() if cs._versions }, ), @@ -3759,6 +3766,7 @@ def update_graph( This happens whenever the Client calls submit, map, get, or compute. """ + parent: SchedulerState = cast(SchedulerState, self) start = time() fifo_timeout = parse_timedelta(fifo_timeout) keys = set(keys) @@ -3779,7 +3787,7 @@ def update_graph( n = len(tasks) for k, deps in list(dependencies.items()): if any( - dep not in self._tasks and dep not in tasks for dep in deps + dep not in parent._tasks and dep not in tasks for dep in deps ): # bad key logger.info("User asked for computation on lost data, %s", k) del tasks[k] @@ -3793,8 +3801,8 @@ def update_graph( ts: TaskState already_in_memory = set() # tasks that are already done for k, v in dependencies.items(): - if v and k in self._tasks: - ts = self._tasks[k] + if v and k in parent._tasks: + ts = parent._tasks[k] if ts._state in ("memory", "erred"): already_in_memory.add(k) @@ -3805,7 +3813,7 @@ def update_graph( done = set(already_in_memory) while stack: # remove unnecessary dependencies key = stack.pop() - ts = self._tasks[key] + ts = parent._tasks[key] try: deps = dependencies[key] except KeyError: @@ -3816,7 +3824,7 @@ def update_graph( else: child_deps = self.dependencies[dep] if all(d in done for d in child_deps): - if dep in self._tasks and dep not in done: + if dep in parent._tasks and dep not in done: done.add(dep) stack.append(dep) @@ -3833,7 +3841,7 @@ def update_graph( if k in touched_keys: continue # XXX Have a method get_task_state(self, k) ? 
- ts = self._tasks.get(k) + ts = parent._tasks.get(k) if ts is None: ts = self.new_task(k, tasks.get(k), "released") elif not ts._run_spec: @@ -3847,11 +3855,11 @@ def update_graph( # Add dependencies for key, deps in dependencies.items(): - ts = self._tasks.get(key) + ts = parent._tasks.get(key) if ts is None or ts._dependencies: continue for dep in deps: - dts = self._tasks[dep] + dts = parent._tasks[dep] ts.add_dependency(dts) # Compute priorities @@ -3885,14 +3893,14 @@ def update_graph( for a, kv in annotations.items(): for k, v in kv.items(): - ts = self._tasks[k] + ts = parent._tasks[k] ts._annotations[a] = v # Add actors if actors is True: actors = list(keys) for actor in actors or []: - ts = self._tasks[actor] + ts = parent._tasks[actor] ts._actor = True priority = priority or dask.order.order( @@ -3900,7 +3908,7 @@ def update_graph( ) # TODO: define order wrt old graph if submitting_task: # sub-tasks get better priority than parent tasks - ts = self._tasks.get(submitting_task) + ts = parent._tasks.get(submitting_task) if ts is not None: generation = ts._priority[0] - 0.01 else: # super-task already cleaned up @@ -3913,7 +3921,7 @@ def update_graph( generation = self.generation for key in set(priority) & touched_keys: - ts = self._tasks[key] + ts = parent._tasks[key] if ts._priority is None: ts._priority = (-(user_priority.get(key, 0)), generation, priority[key]) @@ -3929,7 +3937,7 @@ def update_graph( for k, v in restrictions.items(): if v is None: continue - ts = self._tasks.get(k) + ts = parent._tasks.get(k) if ts is None: continue ts._host_restrictions = set() @@ -3945,7 +3953,7 @@ def update_graph( if loose_restrictions: for k in loose_restrictions: - ts = self._tasks[k] + ts = parent._tasks[k] ts._loose_restrictions = True if resources: @@ -3953,7 +3961,7 @@ def update_graph( if v is None: continue assert isinstance(v, dict) - ts = self._tasks.get(k) + ts = parent._tasks.get(k) if ts is None: continue ts._resource_restrictions = v @@ -3961,7 +3969,7 @@ def update_graph( if retries: for k, v in retries.items(): assert isinstance(v, int) - ts = self._tasks.get(k) + ts = parent._tasks.get(k) if ts is None: continue ts._retries = v @@ -4011,6 +4019,7 @@ def update_graph( def new_task(self, key, spec, state): """ Create a new task, and associated states """ + parent: SchedulerState = cast(SchedulerState, self) ts: TaskState = TaskState(key, spec) tp: TaskPrefix tg: TaskGroup @@ -4030,18 +4039,19 @@ def new_task(self, key, spec, state): tg._prefix = tp tp._groups.append(tg) tg.add(ts) - self._tasks[key] = ts + parent._tasks[key] = ts return ts def stimulus_task_finished(self, key=None, worker=None, **kwargs): """ Mark that a task has finished execution on a particular worker """ + parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task finished %s, %s", key, worker) - tasks: dict = self._tasks + tasks: dict = parent._tasks ts: TaskState = tasks.get(key) if ts is None: return {} - workers: dict = cast(dict, self._workers) + workers: dict = cast(dict, parent._workers) ws: WorkerState = workers[worker] ts._metadata.update(kwargs["metadata"]) @@ -4070,9 +4080,10 @@ def stimulus_task_erred( self, key=None, worker=None, exception=None, traceback=None, **kwargs ): """ Mark that a task has erred on a particular worker """ + parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task erred %s, %s", key, worker) - ts: TaskState = self._tasks.get(key) + ts: TaskState = parent._tasks.get(key) if ts is None: return {} @@ -4101,13 +4112,14 @@ def 
stimulus_missing_data( self, cause=None, key=None, worker=None, ensure=True, **kwargs ): """ Mark that certain keys have gone missing. Recover. """ + parent: SchedulerState = cast(SchedulerState, self) with log_errors(): logger.debug("Stimulus missing data %s, %s", key, worker) - ts: TaskState = self._tasks.get(key) + ts: TaskState = parent._tasks.get(key) if ts is None or ts._state == "memory": return {} - cts: TaskState = self._tasks.get(cause) + cts: TaskState = parent._tasks.get(cause) recommendations: dict = {} @@ -4124,12 +4136,13 @@ def stimulus_missing_data( self.transitions(recommendations) - if self._validate: + if parent._validate: assert cause not in self.who_has return {} def stimulus_retry(self, comm=None, keys=None, client=None): + parent: SchedulerState = cast(SchedulerState, self) logger.info("Client %s requests to retry %d keys", client, len(keys)) if client: self.log_event(client, {"action": "retry", "count": len(keys)}) @@ -4142,7 +4155,7 @@ def stimulus_retry(self, comm=None, keys=None, client=None): while stack: key = stack.pop() seen.add(key) - ts = self._tasks[key] + ts = parent._tasks[key] erred_deps = [dts._key for dts in ts._dependencies if dts._state == "erred"] if erred_deps: stack.extend(erred_deps) @@ -4152,9 +4165,9 @@ def stimulus_retry(self, comm=None, keys=None, client=None): recommendations: dict = {key: "waiting" for key in roots} self.transitions(recommendations) - if self._validate: + if parent._validate: for key in seen: - assert not self._tasks[key].exception_blame + assert not parent._tasks[key].exception_blame return tuple(seen) @@ -4166,18 +4179,19 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): appears to be unresponsive. This may send its tasks back to a released state. 
""" + parent: SchedulerState = cast(SchedulerState, self) with log_errors(): if self.status == Status.closed: return address = self.coerce_address(address) - if address not in self._workers: + if address not in parent._workers: return "already-removed" host = get_address_host(address) - ws: WorkerState = self._workers[address] + ws: WorkerState = parent._workers[address] self.log_event( ["all", address], @@ -4194,21 +4208,21 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): self.remove_resources(address) - self._host_info[host]["nthreads"] -= ws._nthreads - self._host_info[host]["addresses"].remove(address) - self._total_nthreads -= ws._nthreads + parent._host_info[host]["nthreads"] -= ws._nthreads + parent._host_info[host]["addresses"].remove(address) + parent._total_nthreads -= ws._nthreads - if not self._host_info[host]["addresses"]: - del self._host_info[host] + if not parent._host_info[host]["addresses"]: + del parent._host_info[host] self.rpc.remove(address) del self.stream_comms[address] del self.aliases[ws._name] - self._idle.pop(ws._address, None) - self._saturated.discard(ws) - del self._workers[address] + parent._idle.pop(ws._address, None) + parent._saturated.discard(ws) + del parent._workers[address] ws.status = Status.closed - self._total_occupancy -= ws._occupancy + parent._total_occupancy -= ws._occupancy recommendations: dict = {} @@ -4252,16 +4266,16 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): except Exception as e: logger.exception(e) - if not self._workers: + if not parent._workers: logger.info("Lost all workers") - for w in self._workers: + for w in parent._workers: self.bandwidth_workers.pop((address, w), None) self.bandwidth_workers.pop((w, address), None) def remove_worker_from_events(): # If the worker isn't registered anymore after the delay, remove from events - if address not in self._workers and address in self.events: + if address not in parent._workers and address in self.events: del self.events[address] cleanup_delay = parse_timedelta( @@ -4285,10 +4299,11 @@ def stimulus_cancel(self, comm, keys=None, client=None, force=False): def cancel_key(self, key, client, retries=5, force=False): """ Cancel a particular key and all dependents """ # TODO: this should be converted to use the transition mechanism - ts: TaskState = self._tasks.get(key) + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks.get(key) dts: TaskState try: - cs: ClientState = self._clients[client] + cs: ClientState = parent._clients[client] except KeyError: return if ts is None or not ts._who_wants: # no key yet, lets try again in a moment @@ -4307,13 +4322,14 @@ def cancel_key(self, key, client, retries=5, force=False): self.client_releases_keys(keys=[key], client=cs._client_key) def client_desires_keys(self, keys=None, client=None): - cs: ClientState = self._clients.get(client) + parent: SchedulerState = cast(SchedulerState, self) + cs: ClientState = parent._clients.get(client) if cs is None: # For publish, queues etc. - self._clients[client] = cs = ClientState(client) + parent._clients[client] = cs = ClientState(client) ts: TaskState for k in keys: - ts = self._tasks.get(k) + ts = parent._tasks.get(k) if ts is None: # For publish, queues etc. 
ts = self.new_task(k, None, "released") @@ -4326,9 +4342,10 @@ def client_desires_keys(self, keys=None, client=None): def client_releases_keys(self, keys=None, client=None): """ Remove keys from client desired list """ + parent: SchedulerState = cast(SchedulerState, self) if not isinstance(keys, list): keys = list(keys) - cs: ClientState = self.clients[client] + cs: ClientState = parent._clients[client] recommendations: dict = {} self._client_releases_keys(keys=keys, cs=cs, recommendations=recommendations) @@ -4336,7 +4353,8 @@ def client_releases_keys(self, keys=None, client=None): def client_heartbeat(self, client=None): """ Handle heartbeats from Client """ - cs: ClientState = self._clients[client] + parent: SchedulerState = cast(SchedulerState, self) + cs: ClientState = parent._clients[client] cs._last_seen = time() ################### @@ -4344,7 +4362,8 @@ def client_heartbeat(self, client=None): ################### def validate_released(self, key): - ts: TaskState = self._tasks[key] + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks[key] dts: TaskState assert ts._state == "released" assert not ts._waiters @@ -4352,22 +4371,24 @@ def validate_released(self, key): assert not ts._who_has assert not ts._processing_on assert not any([ts in dts._waiters for dts in ts._dependencies]) - assert ts not in self._unrunnable + assert ts not in parent._unrunnable def validate_waiting(self, key): - ts: TaskState = self._tasks[key] + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks[key] dts: TaskState assert ts._waiting_on assert not ts._who_has assert not ts._processing_on - assert ts not in self._unrunnable + assert ts not in parent._unrunnable for dts in ts._dependencies: # We are waiting on a dependency iff it's not stored assert (not not dts._who_has) != (dts in ts._waiting_on) assert ts in dts._waiters # XXX even if dts._who_has? 
def validate_processing(self, key): - ts: TaskState = self._tasks[key] + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks[key] dts: TaskState assert not ts._waiting_on ws: WorkerState = ts._processing_on @@ -4379,36 +4400,40 @@ def validate_processing(self, key): assert ts in dts._waiters def validate_memory(self, key): - ts: TaskState = self._tasks[key] + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks[key] dts: TaskState assert ts._who_has assert not ts._processing_on assert not ts._waiting_on - assert ts not in self._unrunnable + assert ts not in parent._unrunnable for dts in ts._dependents: assert (dts in ts._waiters) == (dts._state in ("waiting", "processing")) assert ts not in dts._waiting_on def validate_no_worker(self, key): - ts: TaskState = self._tasks[key] + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks[key] dts: TaskState - assert ts in self._unrunnable + assert ts in parent._unrunnable assert not ts._waiting_on - assert ts in self._unrunnable + assert ts in parent._unrunnable assert not ts._processing_on assert not ts._who_has for dts in ts._dependencies: assert dts._who_has def validate_erred(self, key): - ts: TaskState = self._tasks[key] + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks[key] assert ts._exception_blame assert not ts._who_has def validate_key(self, key, ts: TaskState = None): + parent: SchedulerState = cast(SchedulerState, self) try: if ts is None: - ts = self._tasks.get(key) + ts = parent._tasks.get(key) if ts is None: logger.debug("Key lost: %s", key) else: @@ -4430,49 +4455,50 @@ def validate_key(self, key, ts: TaskState = None): raise def validate_state(self, allow_overlap=False): - validate_state(self._tasks, self._workers, self._clients) + parent: SchedulerState = cast(SchedulerState, self) + validate_state(parent._tasks, parent._workers, parent._clients) - if not (set(self._workers) == set(self.stream_comms)): + if not (set(parent._workers) == set(self.stream_comms)): raise ValueError("Workers not the same in all collections") ws: WorkerState - for w, ws in self._workers.items(): + for w, ws in parent._workers.items(): assert isinstance(w, str), (type(w), w) assert isinstance(ws, WorkerState), (type(ws), ws) assert ws._address == w if not ws._processing: assert not ws._occupancy - assert ws._address in cast(dict, self._idle) + assert ws._address in cast(dict, parent._idle) ts: TaskState - for k, ts in self._tasks.items(): + for k, ts in parent._tasks.items(): assert isinstance(ts, TaskState), (type(ts), ts) assert ts._key == k self.validate_key(k, ts) c: str cs: ClientState - for c, cs in self._clients.items(): + for c, cs in parent._clients.items(): # client=None is often used in tests... 
assert c is None or type(c) == str, (type(c), c) assert type(cs) == ClientState, (type(cs), cs) assert cs._client_key == c - a = {w: ws._nbytes for w, ws in self._workers.items()} + a = {w: ws._nbytes for w, ws in parent._workers.items()} b = { w: sum(ts.get_nbytes() for ts in ws._has_what) - for w, ws in self._workers.items() + for w, ws in parent._workers.items() } assert a == b, (a, b) actual_total_occupancy = 0 - for worker, ws in self._workers.items(): + for worker, ws in parent._workers.items(): assert abs(sum(ws._processing.values()) - ws._occupancy) < 1e-8 actual_total_occupancy += ws._occupancy - assert abs(actual_total_occupancy - self._total_occupancy) < 1e-8, ( + assert abs(actual_total_occupancy - parent._total_occupancy) < 1e-8, ( actual_total_occupancy, - self._total_occupancy, + parent._total_occupancy, ) ################### @@ -4486,10 +4512,11 @@ def report(self, msg: dict, ts: TaskState = None, client: str = None): If the message contains a key then we only send the message to those comms that care about the key. """ + parent: SchedulerState = cast(SchedulerState, self) if ts is None: msg_key = msg.get("key") if msg_key is not None: - tasks: dict = self._tasks + tasks: dict = parent._tasks ts = tasks.get(msg_key) cs: ClientState @@ -4525,11 +4552,12 @@ async def add_client(self, comm, client=None, versions=None): We listen to all future messages from this Comm. """ + parent: SchedulerState = cast(SchedulerState, self) assert client is not None comm.name = "Scheduler->Client" logger.info("Receive client connection: %s", client) self.log_event(["all", client], {"action": "add-client", "client": client}) - self._clients[client] = ClientState(client, versions=versions) + parent._clients[client] = ClientState(client, versions=versions) for plugin in self.plugins[:]: try: @@ -4545,7 +4573,7 @@ async def add_client(self, comm, client=None, versions=None): ws: WorkerState version_warning = version_module.error_message( version_module.get_versions(), - {w: ws._versions for w, ws in self._workers.items()}, + {w: ws._versions for w, ws in parent._workers.items()}, versions, ) msg.update(version_warning) @@ -4570,11 +4598,12 @@ async def add_client(self, comm, client=None, versions=None): def remove_client(self, client=None): """ Remove client from network """ + parent: SchedulerState = cast(SchedulerState, self) if self.status == Status.running: logger.info("Remove client %s", client) self.log_event(["all", client], {"action": "remove-client", "client": client}) try: - cs: ClientState = self._clients[client] + cs: ClientState = parent._clients[client] except KeyError: # XXX is this a legitimate condition? 
pass @@ -4583,7 +4612,7 @@ def remove_client(self, client=None): self.client_releases_keys( keys=[ts._key for ts in cs._wants_what], client=cs._client_key ) - del self._clients[client] + del parent._clients[client] for plugin in self.plugins[:]: try: @@ -4593,7 +4622,7 @@ def remove_client(self, client=None): def remove_client_from_events(): # If the client isn't registered anymore after the delay, remove from events - if client not in self._clients and client in self.events: + if client not in parent._clients and client in self.events: del self.events[client] cleanup_delay = parse_timedelta( @@ -4603,8 +4632,9 @@ def remove_client_from_events(): def send_task_to_worker(self, worker, ts: TaskState, duration=None): """ Send a single computational task to a worker """ + parent: SchedulerState = cast(SchedulerState, self) try: - msg: dict = self._task_to_msg(ts, duration) + msg: dict = parent._task_to_msg(ts, duration) self.worker_send(worker, msg) except Exception as e: logger.exception(e) @@ -4618,7 +4648,8 @@ def handle_uncaught_error(self, **msg): logger.exception(clean_exception(**msg)[1]) def handle_task_finished(self, key=None, worker=None, **msg): - if worker not in self._workers: + parent: SchedulerState = cast(SchedulerState, self) + if worker not in parent._workers: return validate_key(key) r = self.stimulus_task_finished(key=key, worker=worker, **msg) @@ -4629,24 +4660,26 @@ def handle_task_erred(self, key=None, **msg): self.transitions(r) def handle_release_data(self, key=None, worker=None, client=None, **msg): - ts: TaskState = self._tasks.get(key) + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks.get(key) if ts is None: return - ws: WorkerState = self._workers[worker] + ws: WorkerState = parent._workers[worker] if ts._processing_on != ws: return r = self.stimulus_missing_data(key=key, ensure=False, **msg) self.transitions(r) def handle_missing_data(self, key=None, errant_worker=None, **kwargs): + parent: SchedulerState = cast(SchedulerState, self) logger.debug("handle missing data key=%s worker=%s", key, errant_worker) self.log.append(("missing", key, errant_worker)) - ts: TaskState = self._tasks.get(key) + ts: TaskState = parent._tasks.get(key) if ts is None or not ts._who_has: return - if errant_worker in self._workers: - ws: WorkerState = self._workers[errant_worker] + if errant_worker in parent._workers: + ws: WorkerState = parent._workers[errant_worker] if ws in ts._who_has: ts._who_has.remove(ws) ws._has_what.remove(ts) @@ -4658,8 +4691,9 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): self.transitions({key: "forgotten"}) def release_worker_data(self, comm=None, keys=None, worker=None): - ws: WorkerState = self._workers[worker] - tasks = {self._tasks[k] for k in keys} + parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState = parent._workers[worker] + tasks = {parent._tasks[k] for k in keys} removed_tasks = tasks & ws._has_what ws._has_what -= removed_tasks @@ -4680,7 +4714,8 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): We stop the task from being stolen in the future, and change task duration accounting as if the task has stopped. 
""" - ts: TaskState = self._tasks[key] + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks[key] if "stealing" in self._extensions: self._extensions["stealing"].remove_key_from_stealable(ts) @@ -4700,7 +4735,7 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): ts._prefix._duration_average = avg_duration ws._occupancy -= ws._processing[ts] - self._total_occupancy -= ws._processing[ts] + parent._total_occupancy -= ws._processing[ts] ws._processing[ts] = 0 self.check_idle_saturated(ws) @@ -4786,18 +4821,19 @@ async def scatter( -------- Scheduler.broadcast: """ + parent: SchedulerState = cast(SchedulerState, self) start = time() - while not self._workers: + while not parent._workers: await asyncio.sleep(0.2) if time() > start + timeout: raise TimeoutError("No workers found") if workers is None: ws: WorkerState - nthreads = {w: ws._nthreads for w, ws in self._workers.items()} + nthreads = {w: ws._nthreads for w, ws in parent._workers.items()} else: workers = [self.coerce_address(w) for w in workers] - nthreads = {w: self._workers[w].nthreads for w in workers} + nthreads = {w: parent._workers[w].nthreads for w in workers} assert isinstance(data, dict) @@ -4821,11 +4857,12 @@ async def scatter( async def gather(self, comm=None, keys=None, serializers=None): """ Collect data in from workers """ + parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState keys = list(keys) who_has = {} for key in keys: - ts: TaskState = self._tasks.get(key) + ts: TaskState = parent._tasks.get(key) if ts is not None: who_has[key] = [ws._address for ws in ts._who_has] else: @@ -4838,7 +4875,7 @@ async def gather(self, comm=None, keys=None, serializers=None): result = {"status": "OK", "data": data} else: missing_states = [ - (self._tasks[key].state if key in self._tasks else None) + (parent._tasks[key].state if key in parent._tasks else None) for key in missing_keys ] logger.exception( @@ -4860,7 +4897,7 @@ async def gather(self, comm=None, keys=None, serializers=None): for key, workers in missing_keys.items(): # Task may already be gone if it was held by a # `missing_worker` - ts: TaskState = self._tasks.get(key) + ts: TaskState = parent._tasks.get(key) logger.exception( "Workers don't have promised key: %s, %s", str(workers), @@ -4869,7 +4906,7 @@ async def gather(self, comm=None, keys=None, serializers=None): if not workers or ts is None: continue for worker in workers: - ws = self._workers.get(worker) + ws = parent._workers.get(worker) if ws is not None and ts in ws._has_what: ws._has_what.remove(ts) ts._who_has.remove(ws) @@ -4888,22 +4925,23 @@ def clear_task_state(self): async def restart(self, client=None, timeout=3): """ Restart all workers. Reset local state. 
""" + parent: SchedulerState = cast(SchedulerState, self) with log_errors(): - n_workers = len(self._workers) + n_workers = len(parent._workers) logger.info("Send lost future signal to clients") cs: ClientState ts: TaskState - for cs in self._clients.values(): + for cs in parent._clients.values(): self.client_releases_keys( keys=[ts._key for ts in cs._wants_what], client=cs._client_key ) ws: WorkerState - nannies = {addr: ws._nanny for addr, ws in self._workers.items()} + nannies = {addr: ws._nanny for addr, ws in parent._workers.items()} - for addr in list(self._workers): + for addr in list(parent._workers): try: # Ask the worker to close if it doesn't have a nanny, # otherwise the nanny will kill it anyway @@ -4960,7 +4998,7 @@ async def restart(self, client=None, timeout=3): self.log_event([client, "all"], {"action": "restart", "client": client}) start = time() - while time() < start + 10 and len(self._workers) < n_workers: + while time() < start + 10 and len(parent._workers) < n_workers: await asyncio.sleep(0.01) self.report({"op": "restart"}) @@ -4975,19 +5013,20 @@ async def broadcast( serializers=None, ): """ Broadcast message to workers, return all results """ + parent: SchedulerState = cast(SchedulerState, self) if workers is None or workers is True: if hosts is None: - workers = list(self._workers) + workers = list(parent._workers) else: workers = [] if hosts is not None: for host in hosts: - if host in self._host_info: - workers.extend(self._host_info[host]["addresses"]) + if host in parent._host_info: + workers.extend(parent._host_info[host]["addresses"]) # TODO replace with worker_list if nanny: - addresses = [self._workers[w].nanny for w in workers] + addresses = [parent._workers[w].nanny for w in workers] else: addresses = workers @@ -5023,13 +5062,14 @@ async def _delete_worker_data(self, worker_address, keys): keys: List[str] List of keys to delete on the specified worker """ + parent: SchedulerState = cast(SchedulerState, self) await retry_operation( self.rpc(addr=worker_address).delete_data, keys=list(keys), report=False ) - ws: WorkerState = self._workers[worker_address] + ws: WorkerState = parent._workers[worker_address] ts: TaskState - tasks: set = {self._tasks[key] for key in keys} + tasks: set = {parent._tasks[key] for key in keys} ws._has_what -= tasks for ts in tasks: ts._who_has.remove(ws) @@ -5047,22 +5087,23 @@ async def rebalance(self, comm=None, keys=None, workers=None): occupied worker until either the sender or the recipient are at the average expected load. 
""" + parent: SchedulerState = cast(SchedulerState, self) ts: TaskState with log_errors(): async with self._lock: if keys: - tasks = {self._tasks[k] for k in keys} + tasks = {parent._tasks[k] for k in keys} missing_data = [ts._key for ts in tasks if not ts._who_has] if missing_data: return {"status": "missing-data", "keys": missing_data} else: - tasks = set(self._tasks.values()) + tasks = set(parent._tasks.values()) if workers: - workers = {self._workers[w] for w in workers} + workers = {parent._workers[w] for w in workers} workers_by_task = {ts: ts._who_has & workers for ts in tasks} else: - workers = set(self._workers.values()) + workers = set(parent._workers.values()) workers_by_task = {ts: ts._who_has for ts in tasks} ws: WorkerState @@ -5202,13 +5243,14 @@ async def replicate( -------- Scheduler.rebalance """ + parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState wws: WorkerState ts: TaskState assert branching_factor > 0 async with self._lock if lock else empty_context: - workers = {self._workers[w] for w in self.workers_list(workers)} + workers = {parent._workers[w] for w in self.workers_list(workers)} if n is None: n = len(workers) else: @@ -5216,7 +5258,7 @@ async def replicate( if n == 0: raise ValueError("Can not use replicate to delete data") - tasks = {self._tasks[k] for k in keys} + tasks = {parent._tasks[k] for k in keys} missing_data = [ts._key for ts in tasks if not ts._who_has] if missing_data: return {"status": "missing-data", "keys": missing_data} @@ -5352,19 +5394,20 @@ def workers_to_close( -------- Scheduler.retire_workers """ + parent: SchedulerState = cast(SchedulerState, self) if target is not None and n is None: - n = len(self._workers) - target + n = len(parent._workers) - target if n is not None: if n < 0: n = 0 - target = len(self._workers) - n + target = len(parent._workers) - n if n is None and memory_ratio is None: memory_ratio = 2 ws: WorkerState with log_errors(): - if not n and all([ws._processing for ws in self._workers.values()]): + if not n and all([ws._processing for ws in parent._workers.values()]): return [] if key is None: @@ -5374,7 +5417,7 @@ def workers_to_close( ): key = pickle.loads(key) - groups = groupby(key, self._workers.values()) + groups = groupby(key, parent._workers.values()) limit_bytes = { k: sum([ws._memory_limit for ws in v]) for k, v in groups.items() @@ -5393,7 +5436,7 @@ def _key(group): idle = sorted(groups, key=_key) to_close = [] - n_remain = len(self._workers) + n_remain = len(parent._workers) while idle: group = idle.pop() @@ -5459,6 +5502,7 @@ async def retire_workers( -------- Scheduler.workers_to_close """ + parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState ts: TaskState with log_errors(): @@ -5469,7 +5513,7 @@ async def retire_workers( names = set(map(str, names)) workers = [ ws._address - for ws in self._workers.values() + for ws in parent._workers.values() if str(ws._name) in names ] if workers is None: @@ -5488,7 +5532,7 @@ async def retire_workers( return {} except KeyError: # keys left during replicate pass - workers = {self._workers[w] for w in workers if w in self._workers} + workers = {parent._workers[w] for w in workers if w in parent._workers} if not workers: return {} logger.info("Retire workers %s", workers) @@ -5497,7 +5541,7 @@ async def retire_workers( keys = set.union(*[w.has_what for w in workers]) keys = {ts._key for ts in keys if ts._who_has.issubset(workers)} - other_workers = set(self._workers.values()) - workers + other_workers = set(parent._workers.values()) - 
workers if keys: if other_workers: logger.info("Moving %d keys to other workers", len(keys)) @@ -5540,11 +5584,12 @@ def add_keys(self, comm=None, worker=None, keys=()): This should not be used in practice and is mostly here for legacy reasons. However, it is sent by workers from time to time. """ - if worker not in self._workers: + parent: SchedulerState = cast(SchedulerState, self) + if worker not in parent._workers: return "not found" - ws: WorkerState = self._workers[worker] + ws: WorkerState = parent._workers[worker] for key in keys: - ts: TaskState = self._tasks.get(key) + ts: TaskState = parent._tasks.get(key) if ts is not None and ts._state == "memory": if ts not in ws._has_what: ws._nbytes += ts.get_nbytes() @@ -5567,6 +5612,7 @@ def update_data( -------- Scheduler.mark_key_in_memory """ + parent: SchedulerState = cast(SchedulerState, self) with log_errors(): who_has = { k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items() @@ -5574,14 +5620,14 @@ def update_data( logger.debug("Update data %s", who_has) for key, workers in who_has.items(): - ts: TaskState = self._tasks.get(key) + ts: TaskState = parent._tasks.get(key) if ts is None: ts: TaskState = self.new_task(key, None, "memory") ts.state = "memory" if key in nbytes: ts.set_nbytes(nbytes[key]) for w in workers: - ws: WorkerState = self._workers[w] + ws: WorkerState = parent._workers[w] if ts not in ws._has_what: ws._nbytes += ts.get_nbytes() ws._has_what.add(ts) @@ -5594,8 +5640,9 @@ def update_data( self.client_desires_keys(keys=list(who_has), client=client) def report_on_key(self, key: str = None, ts: TaskState = None, client: str = None): + parent: SchedulerState = cast(SchedulerState, self) if ts is None: - tasks: dict = self._tasks + tasks: dict = parent._tasks ts = tasks.get(key) elif key is None: key = ts._key @@ -5662,57 +5709,67 @@ def subscribe_worker_status(self, comm=None): return ident def get_processing(self, comm=None, workers=None): + parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState ts: TaskState if workers is not None: workers = set(map(self.coerce_address, workers)) - return {w: [ts._key for ts in self._workers[w].processing] for w in workers} + return { + w: [ts._key for ts in parent._workers[w].processing] for w in workers + } else: return { - w: [ts._key for ts in ws._processing] for w, ws in self._workers.items() + w: [ts._key for ts in ws._processing] + for w, ws in parent._workers.items() } def get_who_has(self, comm=None, keys=None): + parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState ts: TaskState if keys is not None: return { - k: [ws._address for ws in self._tasks[k].who_has] - if k in self._tasks + k: [ws._address for ws in parent._tasks[k].who_has] + if k in parent._tasks else [] for k in keys } else: return { key: [ws._address for ws in ts._who_has] - for key, ts in self._tasks.items() + for key, ts in parent._tasks.items() } def get_has_what(self, comm=None, workers=None): + parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState ts: TaskState if workers is not None: workers = map(self.coerce_address, workers) return { - w: [ts._key for ts in self._workers[w].has_what] - if w in self._workers + w: [ts._key for ts in parent._workers[w].has_what] + if w in parent._workers else [] for w in workers } else: return { - w: [ts._key for ts in ws._has_what] for w, ws in self._workers.items() + w: [ts._key for ts in ws._has_what] for w, ws in parent._workers.items() } def get_ncores(self, comm=None, workers=None): + parent: SchedulerState = 
cast(SchedulerState, self) ws: WorkerState if workers is not None: workers = map(self.coerce_address, workers) - return {w: self._workers[w].nthreads for w in workers if w in self._workers} + return { + w: parent._workers[w].nthreads for w in workers if w in parent._workers + } else: - return {w: ws._nthreads for w, ws in self._workers.items()} + return {w: ws._nthreads for w, ws in parent._workers.items()} async def get_call_stack(self, comm=None, keys=None): + parent: SchedulerState = cast(SchedulerState, self) ts: TaskState dts: TaskState if keys is not None: @@ -5720,7 +5777,7 @@ async def get_call_stack(self, comm=None, keys=None): processing = set() while stack: key = stack.pop() - ts = self._tasks[key] + ts = parent._tasks[key] if ts._state == "waiting": stack.extend([dts._key for dts in ts._dependencies]) elif ts._state == "processing": @@ -5731,7 +5788,7 @@ async def get_call_stack(self, comm=None, keys=None): if ts._processing_on: workers[ts._processing_on.address].append(ts._key) else: - workers = {w: None for w in self._workers} + workers = {w: None for w in parent._workers} if not workers: return {} @@ -5743,13 +5800,14 @@ async def get_call_stack(self, comm=None, keys=None): return response def get_nbytes(self, comm=None, keys=None, summary=True): + parent: SchedulerState = cast(SchedulerState, self) ts: TaskState with log_errors(): if keys is not None: - result = {k: self._tasks[k].nbytes for k in keys} + result = {k: parent._tasks[k].nbytes for k in keys} else: result = { - k: ts._nbytes for k, ts in self._tasks.items() if ts._nbytes >= 0 + k: ts._nbytes for k, ts in parent._tasks.items() if ts._nbytes >= 0 } if summary: @@ -5798,8 +5856,9 @@ def get_metadata(self, comm=None, keys=None, default=no_default): raise def get_task_status(self, comm=None, keys=None): + parent: SchedulerState = cast(SchedulerState, self) return { - key: (self._tasks[key].state if key in self._tasks else None) + key: (parent._tasks[key].state if key in parent._tasks else None) for key in keys } @@ -5845,10 +5904,11 @@ async def register_worker_plugin(self, comm, plugin, name=None): ##################### def remove_key(self, key): - tasks: dict = self._tasks + parent: SchedulerState = cast(SchedulerState, self) + tasks: dict = parent._tasks ts: TaskState = tasks.pop(key) assert ts._state == "forgotten" - self._unrunnable.discard(ts) + parent._unrunnable.discard(ts) cs: ClientState for cs in ts._who_wants: cs._wants_what.remove(ts) @@ -5873,12 +5933,13 @@ def transition(self, key, finish, *args, **kwargs): -------- Scheduler.transitions: transitive version of this function """ + parent: SchedulerState = cast(SchedulerState, self) ts: TaskState worker_msgs: dict client_msgs: dict try: try: - ts = self._tasks[key] + ts = parent._tasks[key] except KeyError: return {} start = ts._state @@ -5918,7 +5979,7 @@ def transition(self, key, finish, *args, **kwargs): finish2 = ts._state self.transition_log.append((key, start, finish2, recommendations, time())) - if self._validate: + if parent._validate: logger.debug( "Transitioned %r %s->%s (actual: %s). 
Consequence: %s", key, @@ -5935,14 +5996,14 @@ def transition(self, key, finish, *args, **kwargs): ts._dependencies = dependencies except KeyError: pass - self._tasks[ts._key] = ts + parent._tasks[ts._key] = ts for plugin in list(self.plugins): try: plugin.transition(key, start, finish2, *args, **kwargs) except Exception: logger.info("Plugin failed with exception", exc_info=True) if ts._state == "forgotten": - del self._tasks[ts._key] + del parent._tasks[ts._key] if ts._state == "forgotten" and ts._group._name in self.task_groups: # Remove TaskGroup if all tasks are in the forgotten state @@ -5966,6 +6027,7 @@ def transitions(self, recommendations: dict): This includes feedback from previous transitions and continues until we reach a steady state """ + parent: SchedulerState = cast(SchedulerState, self) keys = set() recommendations = recommendations.copy() while recommendations: @@ -5974,7 +6036,7 @@ def transitions(self, recommendations: dict): new = self.transition(key, finish) recommendations.update(new) - if self._validate: + if parent._validate: for key in keys: self.validate_key(key) @@ -5993,9 +6055,10 @@ def reschedule(self, key=None, worker=None): Things may have shifted and this task may now be better suited to run elsewhere """ + parent: SchedulerState = cast(SchedulerState, self) ts: TaskState try: - ts = self._tasks[key] + ts = parent._tasks[key] except KeyError: logger.warning( "Attempting to reschedule task {}, which was not " @@ -6027,19 +6090,21 @@ def release_resources(self, ts: TaskState, ws: WorkerState): ##################### def add_resources(self, comm=None, worker=None, resources=None): - ws: WorkerState = self._workers[worker] + parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState = parent._workers[worker] if resources: ws._resources.update(resources) ws._used_resources = {} for resource, quantity in ws._resources.items(): ws._used_resources[resource] = 0 - self._resources[resource][worker] = quantity + parent._resources[resource][worker] = quantity return "OK" def remove_resources(self, worker): - ws: WorkerState = self._workers[worker] + parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState = parent._workers[worker] for resource, quantity in ws._resources.items(): - del self._resources[resource][worker] + del parent._resources[resource][worker] def coerce_address(self, addr, resolve=True): """ @@ -6067,8 +6132,9 @@ def coerce_hostname(self, host): """ Coerce the hostname of a worker. """ + parent: SchedulerState = cast(SchedulerState, self) if host in self.aliases: - return self._workers[self.aliases[host]].host + return parent._workers[self.aliases[host]].host else: return host @@ -6079,15 +6145,16 @@ def workers_list(self, workers): Takes a list of worker addresses or hostnames. 
Returns a list of all worker addresses that match """ + parent: SchedulerState = cast(SchedulerState, self) if workers is None: - return list(self._workers) + return list(parent._workers) out = set() for w in workers: if ":" in w: out.add(w) else: - out.update({ww for ww in self._workers if w in ww}) # TODO: quadratic + out.update({ww for ww in parent._workers if w in ww}) # TODO: quadratic return list(out) def start_ipython(self, comm=None): @@ -6114,10 +6181,11 @@ async def get_profile( stop=None, key=None, ): + parent: SchedulerState = cast(SchedulerState, self) if workers is None: - workers = self._workers + workers = parent._workers else: - workers = set(self._workers) & set(workers) + workers = set(parent._workers) & set(workers) if scheduler: return profile.get_profile(self.io_loop.profile, start=start, stop=stop) @@ -6147,15 +6215,16 @@ async def get_profile_metadata( stop=None, profile_cycle_interval=None, ): + parent: SchedulerState = cast(SchedulerState, self) dt = profile_cycle_interval or dask.config.get( "distributed.worker.profile.cycle" ) dt = parse_timedelta(dt, default="ms") if workers is None: - workers = self._workers + workers = parent._workers else: - workers = set(self._workers) & set(workers) + workers = set(parent._workers) & set(workers) results = await asyncio.gather( *(self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers), return_exceptions=True, @@ -6189,6 +6258,7 @@ async def get_profile_metadata( return {"counts": counts, "keys": keys} async def performance_report(self, comm=None, start=None, code=""): + parent: SchedulerState = cast(SchedulerState, self) stop = time() # Profiles compute, scheduler, workers = await asyncio.gather( @@ -6273,10 +6343,10 @@ def profile_to_figure(state): ntasks=total_tasks, tasks_timings=tasks_timings, address=self.address, - nworkers=len(self._workers), - threads=sum([ws._nthreads for ws in self._workers.values()]), + nworkers=len(parent._workers), + threads=sum([ws._nthreads for ws in parent._workers.values()]), memory=format_bytes( - sum([ws._memory_limit for ws in self._workers.values()]) + sum([ws._memory_limit for ws in parent._workers.values()]) ), code=code, dask_version=dask.__version__, @@ -6365,6 +6435,7 @@ def reevaluate_occupancy(self, worker_index=0): lets us avoid this fringe optimization when we have better things to think about. """ + parent: SchedulerState = cast(SchedulerState, self) DELAY = 0.1 try: if self.status == Status.closed: @@ -6374,7 +6445,7 @@ def reevaluate_occupancy(self, worker_index=0): next_time = timedelta(seconds=DELAY) if self.proc.cpu_percent() < 50: - workers = list(self._workers.values()) + workers = list(parent._workers.values()) for i in range(len(workers)): ws: WorkerState = workers[worker_index % len(workers)] worker_index += 1 @@ -6399,11 +6470,12 @@ def reevaluate_occupancy(self, worker_index=0): raise async def check_worker_ttl(self): + parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState now = time() - for ws in self._workers.values(): + for ws in parent._workers.values(): if (ws._last_seen < now - self.worker_ttl) and ( - ws._last_seen < now - 10 * heartbeat_interval(len(self._workers)) + ws._last_seen < now - 10 * heartbeat_interval(len(parent._workers)) ): logger.warning( "Worker failed to heartbeat within %s seconds. 
Closing: %s", @@ -6413,8 +6485,12 @@ async def check_worker_ttl(self): await self.remove_worker(address=ws._address) def check_idle(self): + parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState - if any([ws._processing for ws in self._workers.values()]) or self._unrunnable: + if ( + any([ws._processing for ws in parent._workers.values()]) + or parent._unrunnable + ): self.idle_since = None return elif not self.idle_since: @@ -6443,19 +6519,20 @@ def adaptive_target(self, comm=None, target_duration=None): -------- distributed.deploy.Adaptive """ + parent: SchedulerState = cast(SchedulerState, self) if target_duration is None: target_duration = dask.config.get("distributed.adaptive.target-duration") target_duration = parse_timedelta(target_duration) # CPU cpu = math.ceil( - self._total_occupancy / target_duration + parent._total_occupancy / target_duration ) # TODO: threads per worker # Avoid a few long tasks from asking for many cores ws: WorkerState tasks_processing = 0 - for ws in self._workers.values(): + for ws in parent._workers.values(): tasks_processing += len(ws._processing) if tasks_processing > cpu: @@ -6463,25 +6540,25 @@ def adaptive_target(self, comm=None, target_duration=None): else: cpu = min(tasks_processing, cpu) - if self._unrunnable and not self._workers: + if parent._unrunnable and not parent._workers: cpu = max(1, cpu) # Memory - limit_bytes = {addr: ws._memory_limit for addr, ws in self._workers.items()} - worker_bytes = [ws._nbytes for ws in self._workers.values()] + limit_bytes = {addr: ws._memory_limit for addr, ws in parent._workers.items()} + worker_bytes = [ws._nbytes for ws in parent._workers.values()] limit = sum(limit_bytes.values()) total = sum(worker_bytes) if total > 0.6 * limit: - memory = 2 * len(self._workers) + memory = 2 * len(parent._workers) else: memory = 0 target = max(memory, cpu) - if target >= len(self._workers): + if target >= len(parent._workers): return target else: # Scale down? to_close = self.workers_to_close() - return len(self._workers) - len(to_close) + return len(parent._workers) - len(to_close) @cfunc From 8dc806da243116320e746aa783d3436bff2edc07 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:50:15 -0800 Subject: [PATCH 20/38] Drop no longer needed `cast`s & local assignments --- distributed/scheduler.py | 84 ++++++++++++++-------------------------- 1 file changed, 29 insertions(+), 55 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f3c89ceb572..5e5472ad19f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1646,11 +1646,10 @@ def _remove_from_processing(self, ts: TaskState) -> str: """ Remove *ts* from the set of processing tasks. 
""" - workers: dict = cast(dict, self._workers) ws: WorkerState = ts._processing_on ts._processing_on = None w: str = ws._address - if w in workers: # may have been removed + if w in self._workers_dv: # may have been removed duration = ws._processing.pop(ts) if not ws._processing: self._total_occupancy -= ws._occupancy @@ -1730,9 +1729,7 @@ def _add_to_memory( def transition_released_waiting(self, key): try: - tasks: dict = self._tasks - workers: dict = cast(dict, self._workers) - ts: TaskState = tasks[key] + ts: TaskState = self._tasks[key] dts: TaskState worker_msgs: dict = {} client_msgs: dict = {} @@ -1770,7 +1767,7 @@ def transition_released_waiting(self, key): ts._waiters = {dts for dts in ts._dependents if dts._state == "waiting"} if not ts._waiting_on: - if workers: + if self._workers_dv: recommendations[key] = "processing" else: self._unrunnable.add(ts) @@ -1787,9 +1784,7 @@ def transition_released_waiting(self, key): def transition_no_worker_waiting(self, key): try: - tasks: dict = self._tasks - workers: dict = cast(dict, self._workers) - ts: TaskState = tasks[key] + ts: TaskState = self._tasks[key] dts: TaskState worker_msgs: dict = {} client_msgs: dict = {} @@ -1819,7 +1814,7 @@ def transition_no_worker_waiting(self, key): ts.state = "waiting" if not ts._waiting_on: - if workers: + if self._workers_dv: recommendations[key] = "processing" else: self._unrunnable.add(ts) @@ -1838,7 +1833,6 @@ def decide_worker(self, ts: TaskState) -> WorkerState: """ Decide on a worker for task *ts*. Return a WorkerState. """ - workers: dict = cast(dict, self._workers) ws: WorkerState = None valid_workers: set = self.valid_workers(ts) @@ -1846,7 +1840,7 @@ def decide_worker(self, ts: TaskState) -> WorkerState: valid_workers is not None and not valid_workers and not ts._loose_restrictions - and workers + and self._workers_dv ): self._unrunnable.add(ts) ts.state = "no-worker" @@ -1855,7 +1849,7 @@ def decide_worker(self, ts: TaskState) -> WorkerState: if ts._dependencies or valid_workers is not None: ws = decide_worker( ts, - workers.values(), + self._workers_dv.values(), valid_workers, partial(self.worker_objective, ts), ) @@ -1866,15 +1860,14 @@ def decide_worker(self, ts: TaskState) -> WorkerState: if n_workers < 20: # smart but linear in small case ws = min(worker_pool.values(), key=operator.attrgetter("occupancy")) else: # dumb but fast in large case - n_tasks: Py_ssize_t = self._n_tasks - ws = worker_pool.values()[n_tasks % n_workers] + ws = worker_pool.values()[self._n_tasks % n_workers] if self._validate: assert ws is None or isinstance(ws, WorkerState), ( type(ws), ws, ) - assert ws._address in workers + assert ws._address in self._workers_dv return ws @@ -1897,8 +1890,7 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState): def transition_waiting_processing(self, key): try: - tasks: dict = self._tasks - ts: TaskState = tasks[key] + ts: TaskState = self._tasks[key] dts: TaskState worker_msgs: dict = {} client_msgs: dict = {} @@ -1944,10 +1936,8 @@ def transition_waiting_processing(self, key): def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): try: - workers: dict = cast(dict, self._workers) - ws: WorkerState = workers[worker] - tasks: dict = self._tasks - ts: TaskState = tasks[key] + ws: WorkerState = self._workers_dv[worker] + ts: TaskState = self._tasks[key] worker_msgs: dict = {} client_msgs: dict = {} @@ -1997,8 +1987,7 @@ def transition_processing_memory( worker_msgs: dict = {} client_msgs: dict = {} try: - tasks: dict = self._tasks - ts: 
TaskState = tasks[key] + ts: TaskState = self._tasks[key] assert worker assert isinstance(worker, str) @@ -2011,8 +2000,7 @@ def transition_processing_memory( assert not ts._exception_blame assert ts._state == "processing" - workers: dict = cast(dict, self._workers) - ws = workers.get(worker) + ws = self._workers_dv.get(worker) if ws is None: return {key: "released"}, worker_msgs, client_msgs @@ -2102,8 +2090,7 @@ def transition_processing_memory( def transition_memory_released(self, key, safe=False): ws: WorkerState try: - tasks: dict = self._tasks - ts: TaskState = tasks[key] + ts: TaskState = self._tasks[key] dts: TaskState worker_msgs: dict = {} client_msgs: dict = {} @@ -2175,8 +2162,7 @@ def transition_memory_released(self, key, safe=False): def transition_released_erred(self, key): try: - tasks: dict = self._tasks - ts: TaskState = tasks[key] + ts: TaskState = self._tasks[key] dts: TaskState failing_ts: TaskState worker_msgs: dict = {} @@ -2222,8 +2208,7 @@ def transition_released_erred(self, key): def transition_erred_released(self, key): try: - tasks: dict = self._tasks - ts: TaskState = tasks[key] + ts: TaskState = self._tasks[key] dts: TaskState worker_msgs: dict = {} client_msgs: dict = {} @@ -2264,8 +2249,7 @@ def transition_erred_released(self, key): def transition_waiting_released(self, key): try: - tasks: dict = self._tasks - ts: TaskState = tasks[key] + ts: TaskState = self._tasks[key] worker_msgs: dict = {} client_msgs: dict = {} @@ -2304,8 +2288,7 @@ def transition_waiting_released(self, key): def transition_processing_released(self, key): try: - tasks: dict = self._tasks - ts: TaskState = tasks[key] + ts: TaskState = self._tasks[key] dts: TaskState worker_msgs: dict = {} client_msgs: dict = {} @@ -2355,8 +2338,7 @@ def transition_processing_erred( ): ws: WorkerState try: - tasks: dict = self._tasks - ts: TaskState = tasks[key] + ts: TaskState = self._tasks[key] dts: TaskState failing_ts: TaskState worker_msgs: dict = {} @@ -2432,8 +2414,7 @@ def transition_processing_erred( def transition_no_worker_released(self, key): try: - tasks: dict = self._tasks - ts: TaskState = tasks[key] + ts: TaskState = self._tasks[key] dts: TaskState worker_msgs: dict = {} client_msgs: dict = {} @@ -2462,7 +2443,6 @@ def transition_no_worker_released(self, key): def _propagate_forgotten(self, ts: TaskState, recommendations: dict): worker_msgs: dict = {} - workers: dict = cast(dict, self._workers) ts.state = "forgotten" key: str = ts._key dts: TaskState @@ -2495,18 +2475,16 @@ def _propagate_forgotten(self, ts: TaskState, recommendations: dict): ws._has_what.remove(ts) ws._nbytes -= ts.get_nbytes() w: str = ws._address - if w in workers: # in case worker has died + if w in self._workers_dv: # in case worker has died worker_msgs[w] = {"op": "delete-data", "keys": [key], "report": False} ts._who_has.clear() return worker_msgs def transition_memory_forgotten(self, key): - tasks: dict ws: WorkerState try: - tasks = self._tasks - ts: TaskState = tasks[key] + ts: TaskState = self._tasks[key] worker_msgs: dict = {} client_msgs: dict = {} @@ -2548,8 +2526,7 @@ def transition_memory_forgotten(self, key): def transition_released_forgotten(self, key): try: - tasks: dict = self._tasks - ts: TaskState = tasks[key] + ts: TaskState = self._tasks[key] worker_msgs: dict = {} client_msgs: dict = {} @@ -2704,11 +2681,10 @@ def _task_to_report_msg(self, ts: TaskState) -> dict: def _task_to_client_msgs(self, ts: TaskState) -> dict: cs: ClientState - clients: dict = self._clients client_keys: list if ts is None: # 
Notify all clients - client_keys = list(clients) + client_keys = list(self._clients) else: # Notify clients interested in key client_keys = [cs._client_key for cs in ts._who_wants] @@ -2784,11 +2760,10 @@ def valid_workers(self, ts: TaskState) -> set: * host_restrictions * resource_restrictions """ - workers: dict = cast(dict, self._workers) s: set = None if ts._worker_restrictions: - s = {w for w in ts._worker_restrictions if w in workers} + s = {w for w in ts._worker_restrictions if w in self._workers_dv} if ts._host_restrictions: # Resolve the alias here rather than early, for the worker @@ -2821,7 +2796,7 @@ def valid_workers(self, ts: TaskState) -> set: s &= ww if s is not None: - s = {workers[w] for w in s} + s = {self._workers_dv[w] for w in s} return s @@ -4051,8 +4026,7 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): ts: TaskState = tasks.get(key) if ts is None: return {} - workers: dict = cast(dict, parent._workers) - ws: WorkerState = workers[worker] + ws: WorkerState = parent._workers_dv[worker] ts._metadata.update(kwargs["metadata"]) recommendations: dict @@ -4468,7 +4442,7 @@ def validate_state(self, allow_overlap=False): assert ws._address == w if not ws._processing: assert not ws._occupancy - assert ws._address in cast(dict, parent._idle) + assert ws._address in parent._idle_dv ts: TaskState for k, ts in parent._tasks.items(): From f3b07f2f177a693b4d4404ae7abff161fb61b070 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:50:16 -0800 Subject: [PATCH 21/38] Use `dict` views onto `SortedDict` where possible As Cython will leverage the Python C API on `dict` objects, it is faster to use a `dict`. This works well with `SortedDict` as it is a subclass of `dict`. However we have to keep in mind the `dict` will not have the same order as the `SortedDict`. Also the `dict` won't handle modifications correctly (unlike the `SortedDict`). So we need to be mindful of how the object is being used before using the `dict` view. That said, this works well in a lot of the code if order doesn't matter as happens when things are placed in new `dict`s, `set`s, etc. or when ordered objects like `list` end up being used in a way where order is irrelevant like `sum`ming or where the iteration order is matched by everything else. In cases where it is not obvious the `dict` can be used, we simply skip it and stick with the `SortedDict`. 
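
To make the intent above concrete, a minimal sketch of the "dict view" idea follows. It is not code from this patch: the `workers_dv` alias name and the worker addresses are illustrative assumptions, and the ordering/mutation caveats in the comments describe behaviour under Cython's C-level dict handling, as explained above.

    # Sketch only: a `dict`-typed alias onto a SortedDict (requires sortedcontainers).
    from sortedcontainers import SortedDict

    workers = SortedDict()                    # e.g. address -> worker state
    workers["tcp://10.0.0.2:1234"] = "ws-b"
    workers["tcp://10.0.0.1:1234"] = "ws-a"

    # SortedDict subclasses dict, so the same object can be referenced through a
    # variable annotated as plain `dict`.  Under Cython, that annotation lets
    # lookups and membership tests go through the fast C API for dicts.
    workers_dv: dict = workers                # same object, `dict`-typed alias

    # Order-insensitive uses are safe through the view:
    assert "tcp://10.0.0.1:1234" in workers_dv
    assert len(workers_dv) == 2

    # Order-sensitive code keeps using the SortedDict itself: Cython's C-level
    # dict iteration would see the underlying hash-table order rather than the
    # sorted order, and C-level writes would bypass the sorted index.  (In pure
    # Python the alias still dispatches to SortedDict's own methods.)
    assert list(workers) == sorted(workers)
    workers["tcp://10.0.0.3:1234"] = "ws-c"   # mutate via the SortedDict, not the view
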
--- distributed/scheduler.py | 143 ++++++++++++++++++++------------------- 1 file changed, 75 insertions(+), 68 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 5e5472ad19f..aa36c4dcf67 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4160,12 +4160,12 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): address = self.coerce_address(address) - if address not in parent._workers: + if address not in parent._workers_dv: return "already-removed" host = get_address_host(address) - ws: WorkerState = parent._workers[address] + ws: WorkerState = parent._workers_dv[address] self.log_event( ["all", address], @@ -4240,16 +4240,16 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): except Exception as e: logger.exception(e) - if not parent._workers: + if not parent._workers_dv: logger.info("Lost all workers") - for w in parent._workers: + for w in parent._workers_dv: self.bandwidth_workers.pop((address, w), None) self.bandwidth_workers.pop((w, address), None) def remove_worker_from_events(): # If the worker isn't registered anymore after the delay, remove from events - if address not in parent._workers and address in self.events: + if address not in parent._workers_dv and address in self.events: del self.events[address] cleanup_delay = parse_timedelta( @@ -4432,11 +4432,11 @@ def validate_state(self, allow_overlap=False): parent: SchedulerState = cast(SchedulerState, self) validate_state(parent._tasks, parent._workers, parent._clients) - if not (set(parent._workers) == set(self.stream_comms)): + if not (set(parent._workers_dv) == set(self.stream_comms)): raise ValueError("Workers not the same in all collections") ws: WorkerState - for w, ws in parent._workers.items(): + for w, ws in parent._workers_dv.items(): assert isinstance(w, str), (type(w), w) assert isinstance(ws, WorkerState), (type(ws), ws) assert ws._address == w @@ -4458,15 +4458,15 @@ def validate_state(self, allow_overlap=False): assert type(cs) == ClientState, (type(cs), cs) assert cs._client_key == c - a = {w: ws._nbytes for w, ws in parent._workers.items()} + a = {w: ws._nbytes for w, ws in parent._workers_dv.items()} b = { w: sum(ts.get_nbytes() for ts in ws._has_what) - for w, ws in parent._workers.items() + for w, ws in parent._workers_dv.items() } assert a == b, (a, b) actual_total_occupancy = 0 - for worker, ws in parent._workers.items(): + for worker, ws in parent._workers_dv.items(): assert abs(sum(ws._processing.values()) - ws._occupancy) < 1e-8 actual_total_occupancy += ws._occupancy @@ -4547,7 +4547,7 @@ async def add_client(self, comm, client=None, versions=None): ws: WorkerState version_warning = version_module.error_message( version_module.get_versions(), - {w: ws._versions for w, ws in parent._workers.items()}, + {w: ws._versions for w, ws in parent._workers_dv.items()}, versions, ) msg.update(version_warning) @@ -4623,7 +4623,7 @@ def handle_uncaught_error(self, **msg): def handle_task_finished(self, key=None, worker=None, **msg): parent: SchedulerState = cast(SchedulerState, self) - if worker not in parent._workers: + if worker not in parent._workers_dv: return validate_key(key) r = self.stimulus_task_finished(key=key, worker=worker, **msg) @@ -4638,7 +4638,7 @@ def handle_release_data(self, key=None, worker=None, client=None, **msg): ts: TaskState = parent._tasks.get(key) if ts is None: return - ws: WorkerState = parent._workers[worker] + ws: WorkerState = parent._workers_dv[worker] if 
ts._processing_on != ws: return r = self.stimulus_missing_data(key=key, ensure=False, **msg) @@ -4652,8 +4652,8 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): ts: TaskState = parent._tasks.get(key) if ts is None or not ts._who_has: return - if errant_worker in parent._workers: - ws: WorkerState = parent._workers[errant_worker] + if errant_worker in parent._workers_dv: + ws: WorkerState = parent._workers_dv[errant_worker] if ws in ts._who_has: ts._who_has.remove(ws) ws._has_what.remove(ts) @@ -4666,7 +4666,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): def release_worker_data(self, comm=None, keys=None, worker=None): parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers[worker] + ws: WorkerState = parent._workers_dv[worker] tasks = {parent._tasks[k] for k in keys} removed_tasks = tasks & ws._has_what ws._has_what -= removed_tasks @@ -4797,17 +4797,17 @@ async def scatter( """ parent: SchedulerState = cast(SchedulerState, self) start = time() - while not parent._workers: + while not parent._workers_dv: await asyncio.sleep(0.2) if time() > start + timeout: raise TimeoutError("No workers found") if workers is None: ws: WorkerState - nthreads = {w: ws._nthreads for w, ws in parent._workers.items()} + nthreads = {w: ws._nthreads for w, ws in parent._workers_dv.items()} else: workers = [self.coerce_address(w) for w in workers] - nthreads = {w: parent._workers[w].nthreads for w in workers} + nthreads = {w: parent._workers_dv[w].nthreads for w in workers} assert isinstance(data, dict) @@ -4880,7 +4880,7 @@ async def gather(self, comm=None, keys=None, serializers=None): if not workers or ts is None: continue for worker in workers: - ws = parent._workers.get(worker) + ws = parent._workers_dv.get(worker) if ws is not None and ts in ws._has_what: ws._has_what.remove(ts) ts._who_has.remove(ws) @@ -4902,7 +4902,7 @@ async def restart(self, client=None, timeout=3): parent: SchedulerState = cast(SchedulerState, self) with log_errors(): - n_workers = len(parent._workers) + n_workers = len(parent._workers_dv) logger.info("Send lost future signal to clients") cs: ClientState @@ -4913,9 +4913,9 @@ async def restart(self, client=None, timeout=3): ) ws: WorkerState - nannies = {addr: ws._nanny for addr, ws in parent._workers.items()} + nannies = {addr: ws._nanny for addr, ws in parent._workers_dv.items()} - for addr in list(parent._workers): + for addr in list(parent._workers_dv): try: # Ask the worker to close if it doesn't have a nanny, # otherwise the nanny will kill it anyway @@ -4972,7 +4972,7 @@ async def restart(self, client=None, timeout=3): self.log_event([client, "all"], {"action": "restart", "client": client}) start = time() - while time() < start + 10 and len(parent._workers) < n_workers: + while time() < start + 10 and len(parent._workers_dv) < n_workers: await asyncio.sleep(0.01) self.report({"op": "restart"}) @@ -4990,7 +4990,7 @@ async def broadcast( parent: SchedulerState = cast(SchedulerState, self) if workers is None or workers is True: if hosts is None: - workers = list(parent._workers) + workers = list(parent._workers_dv) else: workers = [] if hosts is not None: @@ -5000,7 +5000,7 @@ async def broadcast( # TODO replace with worker_list if nanny: - addresses = [parent._workers[w].nanny for w in workers] + addresses = [parent._workers_dv[w].nanny for w in workers] else: addresses = workers @@ -5041,7 +5041,7 @@ async def _delete_worker_data(self, worker_address, keys): 
self.rpc(addr=worker_address).delete_data, keys=list(keys), report=False ) - ws: WorkerState = parent._workers[worker_address] + ws: WorkerState = parent._workers_dv[worker_address] ts: TaskState tasks: set = {parent._tasks[key] for key in keys} ws._has_what -= tasks @@ -5074,10 +5074,10 @@ async def rebalance(self, comm=None, keys=None, workers=None): tasks = set(parent._tasks.values()) if workers: - workers = {parent._workers[w] for w in workers} + workers = {parent._workers_dv[w] for w in workers} workers_by_task = {ts: ts._who_has & workers for ts in tasks} else: - workers = set(parent._workers.values()) + workers = set(parent._workers_dv.values()) workers_by_task = {ts: ts._who_has for ts in tasks} ws: WorkerState @@ -5224,7 +5224,7 @@ async def replicate( assert branching_factor > 0 async with self._lock if lock else empty_context: - workers = {parent._workers[w] for w in self.workers_list(workers)} + workers = {parent._workers_dv[w] for w in self.workers_list(workers)} if n is None: n = len(workers) else: @@ -5370,18 +5370,18 @@ def workers_to_close( """ parent: SchedulerState = cast(SchedulerState, self) if target is not None and n is None: - n = len(parent._workers) - target + n = len(parent._workers_dv) - target if n is not None: if n < 0: n = 0 - target = len(parent._workers) - n + target = len(parent._workers_dv) - n if n is None and memory_ratio is None: memory_ratio = 2 ws: WorkerState with log_errors(): - if not n and all([ws._processing for ws in parent._workers.values()]): + if not n and all([ws._processing for ws in parent._workers_dv.values()]): return [] if key is None: @@ -5410,7 +5410,7 @@ def _key(group): idle = sorted(groups, key=_key) to_close = [] - n_remain = len(parent._workers) + n_remain = len(parent._workers_dv) while idle: group = idle.pop() @@ -5487,7 +5487,7 @@ async def retire_workers( names = set(map(str, names)) workers = [ ws._address - for ws in parent._workers.values() + for ws in parent._workers_dv.values() if str(ws._name) in names ] if workers is None: @@ -5506,7 +5506,9 @@ async def retire_workers( return {} except KeyError: # keys left during replicate pass - workers = {parent._workers[w] for w in workers if w in parent._workers} + workers = { + parent._workers_dv[w] for w in workers if w in parent._workers_dv + } if not workers: return {} logger.info("Retire workers %s", workers) @@ -5515,7 +5517,7 @@ async def retire_workers( keys = set.union(*[w.has_what for w in workers]) keys = {ts._key for ts in keys if ts._who_has.issubset(workers)} - other_workers = set(parent._workers.values()) - workers + other_workers = set(parent._workers_dv.values()) - workers if keys: if other_workers: logger.info("Moving %d keys to other workers", len(keys)) @@ -5559,9 +5561,9 @@ def add_keys(self, comm=None, worker=None, keys=()): reasons. However, it is sent by workers from time to time. 
""" parent: SchedulerState = cast(SchedulerState, self) - if worker not in parent._workers: + if worker not in parent._workers_dv: return "not found" - ws: WorkerState = parent._workers[worker] + ws: WorkerState = parent._workers_dv[worker] for key in keys: ts: TaskState = parent._tasks.get(key) if ts is not None and ts._state == "memory": @@ -5601,7 +5603,7 @@ def update_data( if key in nbytes: ts.set_nbytes(nbytes[key]) for w in workers: - ws: WorkerState = parent._workers[w] + ws: WorkerState = parent._workers_dv[w] if ts not in ws._has_what: ws._nbytes += ts.get_nbytes() ws._has_what.add(ts) @@ -5689,12 +5691,12 @@ def get_processing(self, comm=None, workers=None): if workers is not None: workers = set(map(self.coerce_address, workers)) return { - w: [ts._key for ts in parent._workers[w].processing] for w in workers + w: [ts._key for ts in parent._workers_dv[w].processing] for w in workers } else: return { w: [ts._key for ts in ws._processing] - for w, ws in parent._workers.items() + for w, ws in parent._workers_dv.items() } def get_who_has(self, comm=None, keys=None): @@ -5721,14 +5723,15 @@ def get_has_what(self, comm=None, workers=None): if workers is not None: workers = map(self.coerce_address, workers) return { - w: [ts._key for ts in parent._workers[w].has_what] - if w in parent._workers + w: [ts._key for ts in parent._workers_dv[w].has_what] + if w in parent._workers_dv else [] for w in workers } else: return { - w: [ts._key for ts in ws._has_what] for w, ws in parent._workers.items() + w: [ts._key for ts in ws._has_what] + for w, ws in parent._workers_dv.items() } def get_ncores(self, comm=None, workers=None): @@ -5737,10 +5740,12 @@ def get_ncores(self, comm=None, workers=None): if workers is not None: workers = map(self.coerce_address, workers) return { - w: parent._workers[w].nthreads for w in workers if w in parent._workers + w: parent._workers_dv[w].nthreads + for w in workers + if w in parent._workers_dv } else: - return {w: ws._nthreads for w, ws in parent._workers.items()} + return {w: ws._nthreads for w, ws in parent._workers_dv.items()} async def get_call_stack(self, comm=None, keys=None): parent: SchedulerState = cast(SchedulerState, self) @@ -5762,7 +5767,7 @@ async def get_call_stack(self, comm=None, keys=None): if ts._processing_on: workers[ts._processing_on.address].append(ts._key) else: - workers = {w: None for w in parent._workers} + workers = {w: None for w in parent._workers_dv} if not workers: return {} @@ -6065,7 +6070,7 @@ def release_resources(self, ts: TaskState, ws: WorkerState): def add_resources(self, comm=None, worker=None, resources=None): parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers[worker] + ws: WorkerState = parent._workers_dv[worker] if resources: ws._resources.update(resources) ws._used_resources = {} @@ -6076,7 +6081,7 @@ def add_resources(self, comm=None, worker=None, resources=None): def remove_resources(self, worker): parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers[worker] + ws: WorkerState = parent._workers_dv[worker] for resource, quantity in ws._resources.items(): del parent._resources[resource][worker] @@ -6108,7 +6113,7 @@ def coerce_hostname(self, host): """ parent: SchedulerState = cast(SchedulerState, self) if host in self.aliases: - return parent._workers[self.aliases[host]].host + return parent._workers_dv[self.aliases[host]].host else: return host @@ -6157,9 +6162,9 @@ async def get_profile( ): parent: SchedulerState = cast(SchedulerState, self) if 
workers is None: - workers = parent._workers + workers = parent._workers_dv else: - workers = set(parent._workers) & set(workers) + workers = set(parent._workers_dv) & set(workers) if scheduler: return profile.get_profile(self.io_loop.profile, start=start, stop=stop) @@ -6196,9 +6201,9 @@ async def get_profile_metadata( dt = parse_timedelta(dt, default="ms") if workers is None: - workers = parent._workers + workers = parent._workers_dv else: - workers = set(parent._workers) & set(workers) + workers = set(parent._workers_dv) & set(workers) results = await asyncio.gather( *(self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers), return_exceptions=True, @@ -6317,10 +6322,10 @@ def profile_to_figure(state): ntasks=total_tasks, tasks_timings=tasks_timings, address=self.address, - nworkers=len(parent._workers), - threads=sum([ws._nthreads for ws in parent._workers.values()]), + nworkers=len(parent._workers_dv), + threads=sum([ws._nthreads for ws in parent._workers_dv.values()]), memory=format_bytes( - sum([ws._memory_limit for ws in parent._workers.values()]) + sum([ws._memory_limit for ws in parent._workers_dv.values()]) ), code=code, dask_version=dask.__version__, @@ -6447,9 +6452,9 @@ async def check_worker_ttl(self): parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState now = time() - for ws in parent._workers.values(): + for ws in parent._workers_dv.values(): if (ws._last_seen < now - self.worker_ttl) and ( - ws._last_seen < now - 10 * heartbeat_interval(len(parent._workers)) + ws._last_seen < now - 10 * heartbeat_interval(len(parent._workers_dv)) ): logger.warning( "Worker failed to heartbeat within %s seconds. Closing: %s", @@ -6462,7 +6467,7 @@ def check_idle(self): parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState if ( - any([ws._processing for ws in parent._workers.values()]) + any([ws._processing for ws in parent._workers_dv.values()]) or parent._unrunnable ): self.idle_since = None @@ -6506,7 +6511,7 @@ def adaptive_target(self, comm=None, target_duration=None): # Avoid a few long tasks from asking for many cores ws: WorkerState tasks_processing = 0 - for ws in parent._workers.values(): + for ws in parent._workers_dv.values(): tasks_processing += len(ws._processing) if tasks_processing > cpu: @@ -6514,25 +6519,27 @@ def adaptive_target(self, comm=None, target_duration=None): else: cpu = min(tasks_processing, cpu) - if parent._unrunnable and not parent._workers: + if parent._unrunnable and not parent._workers_dv: cpu = max(1, cpu) # Memory - limit_bytes = {addr: ws._memory_limit for addr, ws in parent._workers.items()} - worker_bytes = [ws._nbytes for ws in parent._workers.values()] + limit_bytes = { + addr: ws._memory_limit for addr, ws in parent._workers_dv.items() + } + worker_bytes = [ws._nbytes for ws in parent._workers_dv.values()] limit = sum(limit_bytes.values()) total = sum(worker_bytes) if total > 0.6 * limit: - memory = 2 * len(parent._workers) + memory = 2 * len(parent._workers_dv) else: memory = 0 target = max(memory, cpu) - if target >= len(parent._workers): + if target >= len(parent._workers_dv): return target else: # Scale down? 
to_close = self.workers_to_close() - return len(parent._workers) - len(to_close) + return len(parent._workers_dv) - len(to_close) @cfunc From 38c1c35ed6ebc50f9734f6c7393a2c33c2db21b7 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 12:50:17 -0800 Subject: [PATCH 22/38] Add `@property`s for `SchedulerState` attributes Ensures these Cython attributes can be retrieved by Python code. --- distributed/scheduler.py | 68 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index aa36c4dcf67..d2b0dc8eab2 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1642,6 +1642,74 @@ def __init__( self._workers_dv: dict = cast(dict, self._workers) super().__init__(**kwargs) + @property + def bandwidth(self): + return self._bandwidth + + @property + def clients(self): + return self._clients + + @property + def extensions(self): + return self._extensions + + @property + def host_info(self): + return self._host_info + + @property + def idle(self): + return self._idle + + @property + def n_tasks(self): + return self._n_tasks + + @property + def resources(self): + return self._resources + + @property + def saturated(self): + return self._saturated + + @property + def tasks(self): + return self._tasks + + @property + def total_nthreads(self): + return self._total_nthreads + + @property + def total_occupancy(self): + return self._total_occupancy + + @total_occupancy.setter + def total_occupancy(self, v: double): + self._total_occupancy = v + + @property + def unknown_durations(self): + return self._unknown_durations + + @property + def unrunnable(self): + return self._unrunnable + + @property + def validate(self): + return self._validate + + @validate.setter + def validate(self, v: bint): + self._validate = v + + @property + def workers(self): + return self._workers + def _remove_from_processing(self, ts: TaskState) -> str: """ Remove *ts* from the set of processing tasks. From 8fe83880ad820c8f1bb153fba9fa8d03daa24eac Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 13:01:45 -0800 Subject: [PATCH 23/38] Take `worker_msgs` arg in `_propagate_forgotten` Matches the behavior around `recommendations`. Also simplifies the handling of `worker_msgs` (it's created only once). 
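A small sketch of the calling convention this adopts (not part of the patch; the helper name and worker address are made up): the caller creates `worker_msgs` once and helpers fill it in place, the same way `recommendations` is threaded through, instead of each helper returning its own dict to be merged.

    def _propagate_forgotten_sketch(key: str, recommendations: dict, worker_msgs: dict):
        # mutate the caller-owned dicts in place; nothing is returned
        recommendations[key] = "forgotten"
        worker_msgs["tcp://127.0.0.1:40000"] = {
            "op": "delete-data",
            "keys": [key],
            "report": False,
        }

    recommendations: dict = {}
    worker_msgs: dict = {}
    _propagate_forgotten_sketch("x-1", recommendations, worker_msgs)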
--- distributed/scheduler.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index d2b0dc8eab2..8c001dd18e4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2509,8 +2509,9 @@ def transition_no_worker_released(self, key): pdb.set_trace() raise - def _propagate_forgotten(self, ts: TaskState, recommendations: dict): - worker_msgs: dict = {} + def _propagate_forgotten( + self, ts: TaskState, recommendations: dict, worker_msgs: dict + ): ts.state = "forgotten" key: str = ts._key dts: TaskState @@ -2547,8 +2548,6 @@ def _propagate_forgotten(self, ts: TaskState, recommendations: dict): worker_msgs[w] = {"op": "delete-data", "keys": [key], "report": False} ts._who_has.clear() - return worker_msgs - def transition_memory_forgotten(self, key): ws: WorkerState try: @@ -2578,7 +2577,7 @@ def transition_memory_forgotten(self, key): for ws in ts._who_has: ws._actors.discard(ts) - worker_msgs = self._propagate_forgotten(ts, recommendations) + self._propagate_forgotten(ts, recommendations, worker_msgs) client_msgs = self._task_to_client_msgs(ts) self.remove_key(key) @@ -2616,7 +2615,7 @@ def transition_released_forgotten(self, key): assert 0, (ts,) recommendations: dict = {} - self._propagate_forgotten(ts, recommendations) + self._propagate_forgotten(ts, recommendations, worker_msgs) client_msgs = self._task_to_client_msgs(ts) self.remove_key(key) From 7b5dec3b3102d7aa3624a5c32022b045e3808f8b Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 17 Dec 2020 19:00:49 -0800 Subject: [PATCH 24/38] Use `parent` for private methods As these will become `cfunc` decorated functions (only accessible from the C API), we will need to `cast` to `SchedulerState` to get access to them, which is what we do here in preparation for that. 
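A rough sketch of the pattern (not part of the patch; the class names are stand-ins): once a helper becomes `@cfunc`, compiled callers need a variable typed as the extension class to reach it, hence `parent: SchedulerState = cast(SchedulerState, self)`. Uncompiled, the decorators and `cast` degrade to no-ops, so the same code still runs as plain Python.

    import cython
    from cython import cast, cclass, cfunc

    @cclass
    class StateSketch:
        @cfunc
        def _helper(self) -> cython.int:   # C-level method; not reachable by plain attribute lookup when compiled
            return 1

    class ServerSketch(StateSketch):
        def handler(self):
            parent: StateSketch = cast(StateSketch, self)
            return parent._helper()        # resolved through the typed `parent`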
--- distributed/scheduler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8c001dd18e4..ded9d7c3f7e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4389,7 +4389,7 @@ def client_releases_keys(self, keys=None, client=None): cs: ClientState = parent._clients[client] recommendations: dict = {} - self._client_releases_keys(keys=keys, cs=cs, recommendations=recommendations) + parent._client_releases_keys(keys=keys, cs=cs, recommendations=recommendations) self.transitions(recommendations) def client_heartbeat(self, client=None): @@ -5693,7 +5693,7 @@ def report_on_key(self, key: str = None, ts: TaskState = None, client: str = Non assert False, (key, ts) return - report_msg: dict = self._task_to_report_msg(ts) + report_msg: dict = parent._task_to_report_msg(ts) if report_msg is not None: self.report(report_msg, ts=ts, client=client) @@ -6498,7 +6498,7 @@ def reevaluate_occupancy(self, worker_index=0): try: if ws is None or not ws._processing: continue - self._reevaluate_occupancy_worker(ws) + parent._reevaluate_occupancy_worker(ws) finally: del ws # lose ref From 8a9a1de5a627ccf1e76b290229b55b4973eea7a2 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 13:14:05 -0800 Subject: [PATCH 25/38] Annotate function arguments and return types --- distributed/scheduler.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ded9d7c3f7e..9f9e8507b45 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1738,7 +1738,7 @@ def _add_to_memory( recommendations: dict, client_msgs: dict, type=None, - typename=None, + typename: str = None, **kwargs, ): """ @@ -1939,7 +1939,7 @@ def decide_worker(self, ts: TaskState) -> WorkerState: return ws - def set_duration_estimate(self, ts: TaskState, ws: WorkerState): + def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> double: """Estimate task duration using worker state and task state. If a task takes longer than twice the current average duration we @@ -2045,7 +2045,7 @@ def transition_processing_memory( key, nbytes=None, type=None, - typename=None, + typename: str = None, worker=None, startstops=None, **kwargs, @@ -2155,7 +2155,7 @@ def transition_processing_memory( pdb.set_trace() raise - def transition_memory_released(self, key, safe=False): + def transition_memory_released(self, key, safe: bint = False): ws: WorkerState try: ts: TaskState = self._tasks[key] @@ -2788,7 +2788,7 @@ def _reevaluate_occupancy_worker(self, ws: WorkerState): steal.remove_key_from_stealable(ts) steal.put_key_in_stealable(ts) - def get_comm_cost(self, ts: TaskState, ws: WorkerState): + def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> double: """ Get the estimated communication cost (in s.) to compute the task on the given worker. @@ -2801,7 +2801,7 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState): nbytes += dts._nbytes return nbytes / bandwidth - def get_task_duration(self, ts: TaskState, default: double = -1): + def get_task_duration(self, ts: TaskState, default: double = -1) -> double: """ Get the estimated computation cost of the given task (not including any communication cost). 
@@ -2867,7 +2867,7 @@ def valid_workers(self, ts: TaskState) -> set: return s - def worker_objective(self, ts: TaskState, ws: WorkerState): + def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple: """ Objective function to determine which worker should get the task From 2114ea8d498b1a12c0ad72fdb6f450d4359f7860 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 17 Dec 2020 18:51:40 -0800 Subject: [PATCH 26/38] Use `cfunc` on private `SchedulerState` methods --- distributed/scheduler.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9f9e8507b45..4f6faf9ace4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1710,6 +1710,8 @@ def validate(self, v: bint): def workers(self): return self._workers + @cfunc + @exceptval(check=False) def _remove_from_processing(self, ts: TaskState) -> str: """ Remove *ts* from the set of processing tasks. @@ -2509,6 +2511,8 @@ def transition_no_worker_released(self, key): pdb.set_trace() raise + @cfunc + @exceptval(check=False) def _propagate_forgotten( self, ts: TaskState, recommendations: dict, worker_msgs: dict ): @@ -2670,6 +2674,8 @@ def check_idle_saturated(self, ws: WorkerState, occ: double = -1.0): saturated.discard(ws) + @cfunc + @exceptval(check=False) def _client_releases_keys(self, keys: list, cs: ClientState, recommendations: dict): """ Remove keys from client desired list """ logger.debug("Client %s releases keys: %s", cs._client_key, keys) @@ -2691,6 +2697,8 @@ def _client_releases_keys(self, keys: list, cs: ClientState, recommendations: di elif ts._state != "erred" and not ts._waiters: recommendations[ts._key] = "released" + @cfunc + @exceptval(check=False) def _task_to_msg(self, ts: TaskState, duration=None) -> dict: """ Convert a single computational task to a message """ ws: WorkerState @@ -2728,6 +2736,8 @@ def _task_to_msg(self, ts: TaskState, duration=None) -> dict: return msg + @cfunc + @exceptval(check=False) def _task_to_report_msg(self, ts: TaskState) -> dict: if ts is None: return {"op": "cancelled-key", "key": ts._key} @@ -2746,6 +2756,8 @@ def _task_to_report_msg(self, ts: TaskState) -> dict: else: return None + @cfunc + @exceptval(check=False) def _task_to_client_msgs(self, ts: TaskState) -> dict: cs: ClientState client_keys: list @@ -2764,6 +2776,8 @@ def _task_to_client_msgs(self, ts: TaskState) -> dict: return client_msgs + @cfunc + @exceptval(check=False) def _reevaluate_occupancy_worker(self, ws: WorkerState): """ See reevaluate_occupancy """ old: double = ws._occupancy From 096b66171a2508ab18ec06e867e589d554b09cbd Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 16 Dec 2020 13:18:48 -0800 Subject: [PATCH 27/38] Decorate `SchedulerState` methods with `@ccall` Should allow faster C API calls to these methods in addition to their existing Python APIs. Also disable exception checking when Python objects are `return`ed since these already can and are checked for an exception. Note that `@ccall` is not permitted on functions taking `**kwargs` so those have been skipped. 
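A brief sketch of what `@ccall` plus `@exceptval(check=False)` mean in Cython's pure-Python mode (not part of the patch; the class and method are illustrative stand-ins, and the code also runs uncompiled):

    import cython
    from cython import ccall, cclass, exceptval

    @cclass
    class StateSketch:
        _occupancy: dict

        def __init__(self):
            self._occupancy = {"tcp://a:1234": 0.5, "tcp://b:1234": 0.0}

        @ccall                    # cpdef: adds a fast C entry point while keeping the Python method
        @exceptval(check=False)   # object returns already signal errors, so skip the extra check
        def busy_workers(self) -> set:
            return {w for w, occ in self._occupancy.items() if occ > 0.0}

    assert StateSketch().busy_workers() == {"tcp://a:1234"}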
--- distributed/scheduler.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 4f6faf9ace4..851701eceff 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1899,6 +1899,8 @@ def transition_no_worker_waiting(self, key): pdb.set_trace() raise + @ccall + @exceptval(check=False) def decide_worker(self, ts: TaskState) -> WorkerState: """ Decide on a worker for task *ts*. Return a WorkerState. @@ -1941,6 +1943,7 @@ def decide_worker(self, ts: TaskState) -> WorkerState: return ws + @ccall def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> double: """Estimate task duration using worker state and task state. @@ -2633,6 +2636,8 @@ def transition_released_forgotten(self, key): pdb.set_trace() raise + @ccall + @exceptval(check=False) def check_idle_saturated(self, ws: WorkerState, occ: double = -1.0): """Update the status of the idle and saturated state @@ -2802,6 +2807,7 @@ def _reevaluate_occupancy_worker(self, ws: WorkerState): steal.remove_key_from_stealable(ts) steal.put_key_in_stealable(ts) + @ccall def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> double: """ Get the estimated communication cost (in s.) to compute the task @@ -2815,6 +2821,7 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> double: nbytes += dts._nbytes return nbytes / bandwidth + @ccall def get_task_duration(self, ts: TaskState, default: double = -1) -> double: """ Get the estimated computation cost of the given task @@ -2831,6 +2838,8 @@ def get_task_duration(self, ts: TaskState, default: double = -1) -> double: return duration + @ccall + @exceptval(check=False) def valid_workers(self, ts: TaskState) -> set: """Return set of currently valid workers for key @@ -2881,6 +2890,8 @@ def valid_workers(self, ts: TaskState) -> set: return s + @ccall + @exceptval(check=False) def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple: """ Objective function to determine which worker should get the task From bbf156c63954f8118caf7b4c731b91ebfb45dcb3 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Fri, 8 Jan 2021 13:49:35 -0800 Subject: [PATCH 28/38] Add optional args to `transition_waiting_memory` --- distributed/scheduler.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 851701eceff..6c9e5bcf872 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2007,7 +2007,9 @@ def transition_waiting_processing(self, key): pdb.set_trace() raise - def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): + def transition_waiting_memory( + self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs + ): try: ws: WorkerState = self._workers_dv[worker] ts: TaskState = self._tasks[key] @@ -2029,7 +2031,9 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): recommendations: dict = {} client_msgs: dict = {} - self._add_to_memory(ts, ws, recommendations, client_msgs, **kwargs) + self._add_to_memory( + ts, ws, recommendations, client_msgs, type=type, typename=typename + ) if self._validate: assert not ts._processing_on From 5fdc87134c7391c102a3f2e138549c1c5bb0d9bc Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Fri, 8 Jan 2021 13:02:03 -0800 Subject: [PATCH 29/38] Drop `**kwargs` from `_add_to_memory` --- distributed/scheduler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 
6c9e5bcf872..53572cd0216 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1741,7 +1741,6 @@ def _add_to_memory( client_msgs: dict, type=None, typename: str = None, - **kwargs, ): """ Add *ts* to the set of in-memory tasks. From 2c76c6ab13799c4ad7d18828faac3bdb568078a1 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Fri, 8 Jan 2021 13:03:49 -0800 Subject: [PATCH 30/38] Use `cfunc` to decorate `_add_to_memory` --- distributed/scheduler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 53572cd0216..b4d0b6eab86 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1733,6 +1733,8 @@ def _remove_from_processing(self, ts: TaskState) -> str: else: return None + @cfunc + @exceptval(check=False) def _add_to_memory( self, ts: TaskState, From 7fb522869de399f6a7ec28f116a2e0781f6c6f58 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Fri, 8 Jan 2021 12:45:37 -0800 Subject: [PATCH 31/38] Make `SchedulerState` private methods functions Gets them out of the `SchedulerState` virtual table so they can be used like normal C functions. Should allow inlining and other optimizations that apply. --- distributed/scheduler.py | 560 ++++++++++++++++++++------------------- 1 file changed, 286 insertions(+), 274 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b4d0b6eab86..6efe6ea0e1b 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1710,94 +1710,6 @@ def validate(self, v: bint): def workers(self): return self._workers - @cfunc - @exceptval(check=False) - def _remove_from_processing(self, ts: TaskState) -> str: - """ - Remove *ts* from the set of processing tasks. - """ - ws: WorkerState = ts._processing_on - ts._processing_on = None - w: str = ws._address - if w in self._workers_dv: # may have been removed - duration = ws._processing.pop(ts) - if not ws._processing: - self._total_occupancy -= ws._occupancy - ws._occupancy = 0 - else: - self._total_occupancy -= duration - ws._occupancy -= duration - self.check_idle_saturated(ws) - self.release_resources(ts, ws) - return w - else: - return None - - @cfunc - @exceptval(check=False) - def _add_to_memory( - self, - ts: TaskState, - ws: WorkerState, - recommendations: dict, - client_msgs: dict, - type=None, - typename: str = None, - ): - """ - Add *ts* to the set of in-memory tasks. 
- """ - if self._validate: - assert ts not in ws._has_what - - ts._who_has.add(ws) - ws._has_what.add(ts) - ws._nbytes += ts.get_nbytes() - - deps: list = list(ts._dependents) - if len(deps) > 1: - deps.sort(key=operator.attrgetter("priority"), reverse=True) - - dts: TaskState - s: set - for dts in deps: - s = dts._waiting_on - if ts in s: - s.discard(ts) - if not s: # new task ready to run - recommendations[dts._key] = "processing" - - for dts in ts._dependencies: - s = dts._waiters - s.discard(ts) - if not s and not dts._who_wants: - recommendations[dts._key] = "released" - - report_msg: dict = {} - cs: ClientState - if not ts._waiters and not ts._who_wants: - recommendations[ts._key] = "released" - else: - report_msg["op"] = "key-in-memory" - report_msg["key"] = ts._key - if type is not None: - report_msg["type"] = type - - for cs in ts._who_wants: - client_msgs[cs._client_key] = report_msg - - ts.state = "memory" - ts._type = typename - ts._group._types.add(typename) - - cs = self._clients["fire-and-forget"] - if ts in cs._wants_what: - self._client_releases_keys( - cs=cs, - keys=[ts._key], - recommendations=recommendations, - ) - def transition_released_waiting(self, key): try: ts: TaskState = self._tasks[key] @@ -1997,7 +1909,7 @@ def transition_waiting_processing(self, key): # logger.debug("Send job to worker: %s, %s", worker, key) - worker_msgs[worker] = self._task_to_msg(ts) + worker_msgs[worker] = _task_to_msg(self, ts) return {}, worker_msgs, client_msgs except Exception as e: @@ -2032,8 +1944,8 @@ def transition_waiting_memory( recommendations: dict = {} client_msgs: dict = {} - self._add_to_memory( - ts, ws, recommendations, client_msgs, type=type, typename=typename + _add_to_memory( + self, ts, ws, recommendations, client_msgs, type=type, typename=typename ) if self._validate: @@ -2146,10 +2058,10 @@ def transition_processing_memory( recommendations: dict = {} client_msgs: dict = {} - self._remove_from_processing(ts) + _remove_from_processing(self, ts) - self._add_to_memory( - ts, ws, recommendations, client_msgs, type=type, typename=typename + _add_to_memory( + self, ts, ws, recommendations, client_msgs, type=type, typename=typename ) if self._validate: @@ -2377,7 +2289,7 @@ def transition_processing_released(self, key): assert not ts._waiting_on assert self._tasks[key].state == "processing" - w: str = self._remove_from_processing(ts) + w: str = _remove_from_processing(self, ts) if w: worker_msgs[w] = {"op": "release-task", "key": key} @@ -2432,7 +2344,7 @@ def transition_processing_erred( ws = ts._processing_on ws._actors.remove(ts) - self._remove_from_processing(ts) + _remove_from_processing(self, ts) if exception is not None: ts._exception = exception @@ -2472,7 +2384,8 @@ def transition_processing_erred( cs = self._clients["fire-and-forget"] if ts in cs._wants_what: - self._client_releases_keys( + _client_releases_keys( + self, cs=cs, keys=[key], recommendations=recommendations, @@ -2519,47 +2432,6 @@ def transition_no_worker_released(self, key): pdb.set_trace() raise - @cfunc - @exceptval(check=False) - def _propagate_forgotten( - self, ts: TaskState, recommendations: dict, worker_msgs: dict - ): - ts.state = "forgotten" - key: str = ts._key - dts: TaskState - for dts in ts._dependents: - dts._has_lost_dependencies = True - dts._dependencies.remove(ts) - dts._waiting_on.discard(ts) - if dts._state not in ("memory", "erred"): - # Cannot compute task anymore - recommendations[dts._key] = "forgotten" - ts._dependents.clear() - ts._waiters.clear() - - for dts in 
ts._dependencies: - dts._dependents.remove(ts) - s: set = dts._waiters - s.discard(ts) - if not dts._dependents and not dts._who_wants: - # Task not needed anymore - assert dts is not ts - recommendations[dts._key] = "forgotten" - ts._dependencies.clear() - ts._waiting_on.clear() - - if ts._who_has: - ts._group._nbytes_in_memory -= ts.get_nbytes() - - ws: WorkerState - for ws in ts._who_has: - ws._has_what.remove(ts) - ws._nbytes -= ts.get_nbytes() - w: str = ws._address - if w in self._workers_dv: # in case worker has died - worker_msgs[w] = {"op": "delete-data", "keys": [key], "report": False} - ts._who_has.clear() - def transition_memory_forgotten(self, key): ws: WorkerState try: @@ -2589,9 +2461,9 @@ def transition_memory_forgotten(self, key): for ws in ts._who_has: ws._actors.discard(ts) - self._propagate_forgotten(ts, recommendations, worker_msgs) + _propagate_forgotten(self, ts, recommendations, worker_msgs) - client_msgs = self._task_to_client_msgs(ts) + client_msgs = _task_to_client_msgs(self, ts) self.remove_key(key) return recommendations, worker_msgs, client_msgs @@ -2627,9 +2499,9 @@ def transition_released_forgotten(self, key): assert 0, (ts,) recommendations: dict = {} - self._propagate_forgotten(ts, recommendations, worker_msgs) + _propagate_forgotten(self, ts, recommendations, worker_msgs) - client_msgs = self._task_to_client_msgs(ts) + client_msgs = _task_to_client_msgs(self, ts) self.remove_key(key) return recommendations, worker_msgs, client_msgs @@ -2684,134 +2556,6 @@ def check_idle_saturated(self, ws: WorkerState, occ: double = -1.0): saturated.discard(ws) - @cfunc - @exceptval(check=False) - def _client_releases_keys(self, keys: list, cs: ClientState, recommendations: dict): - """ Remove keys from client desired list """ - logger.debug("Client %s releases keys: %s", cs._client_key, keys) - ts: TaskState - tasks2: set = set() - for key in keys: - ts = self._tasks.get(key) - if ts is not None and ts in cs._wants_what: - cs._wants_what.remove(ts) - s: set = ts._who_wants - s.remove(cs) - if not s: - tasks2.add(ts) - - for ts in tasks2: - if not ts._dependents: - # No live dependents, can forget - recommendations[ts._key] = "forgotten" - elif ts._state != "erred" and not ts._waiters: - recommendations[ts._key] = "released" - - @cfunc - @exceptval(check=False) - def _task_to_msg(self, ts: TaskState, duration=None) -> dict: - """ Convert a single computational task to a message """ - ws: WorkerState - dts: TaskState - - if duration is None: - duration = self.get_task_duration(ts) - - msg: dict = { - "op": "compute-task", - "key": ts._key, - "priority": ts._priority, - "duration": duration, - } - if ts._resource_restrictions: - msg["resource_restrictions"] = ts._resource_restrictions - if ts._actor: - msg["actor"] = True - - deps: set = ts._dependencies - if deps: - msg["who_has"] = { - dts._key: [ws._address for ws in dts._who_has] for dts in deps - } - msg["nbytes"] = {dts._key: dts._nbytes for dts in deps} - - if self._validate: - assert all(msg["who_has"].values()) - - task = ts._run_spec - if type(task) is dict: - msg.update(task) - else: - msg["task"] = task - - return msg - - @cfunc - @exceptval(check=False) - def _task_to_report_msg(self, ts: TaskState) -> dict: - if ts is None: - return {"op": "cancelled-key", "key": ts._key} - elif ts._state == "forgotten": - return {"op": "cancelled-key", "key": ts._key} - elif ts._state == "memory": - return {"op": "key-in-memory", "key": ts._key} - elif ts._state == "erred": - failing_ts: TaskState = ts._exception_blame - 
return { - "op": "task-erred", - "key": ts._key, - "exception": failing_ts._exception, - "traceback": failing_ts._traceback, - } - else: - return None - - @cfunc - @exceptval(check=False) - def _task_to_client_msgs(self, ts: TaskState) -> dict: - cs: ClientState - client_keys: list - if ts is None: - # Notify all clients - client_keys = list(self._clients) - else: - # Notify clients interested in key - client_keys = [cs._client_key for cs in ts._who_wants] - - report_msg: dict = self._task_to_report_msg(ts) - - client_msgs: dict = {} - for k in client_keys: - client_msgs[k] = report_msg - - return client_msgs - - @cfunc - @exceptval(check=False) - def _reevaluate_occupancy_worker(self, ws: WorkerState): - """ See reevaluate_occupancy """ - old: double = ws._occupancy - new: double = 0 - diff: double - ts: TaskState - est: double - for ts in ws._processing: - est = self.set_duration_estimate(ts, ws) - new += est - - ws._occupancy = new - diff = new - old - self._total_occupancy += diff - self.check_idle_saturated(ws) - - # significant increase in duration - if new > old * 1.3: - steal = self._extensions.get("stealing") - if steal is not None: - for ts in ws._processing: - steal.remove_key_from_stealable(ts) - steal.put_key_in_stealable(ts) - @ccall def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> double: """ @@ -4419,7 +4163,7 @@ def client_releases_keys(self, keys=None, client=None): cs: ClientState = parent._clients[client] recommendations: dict = {} - parent._client_releases_keys(keys=keys, cs=cs, recommendations=recommendations) + _client_releases_keys(parent, keys=keys, cs=cs, recommendations=recommendations) self.transitions(recommendations) def client_heartbeat(self, client=None): @@ -4705,7 +4449,7 @@ def send_task_to_worker(self, worker, ts: TaskState, duration=None): """ Send a single computational task to a worker """ parent: SchedulerState = cast(SchedulerState, self) try: - msg: dict = parent._task_to_msg(ts, duration) + msg: dict = _task_to_msg(parent, ts, duration) self.worker_send(worker, msg) except Exception as e: logger.exception(e) @@ -5723,7 +5467,7 @@ def report_on_key(self, key: str = None, ts: TaskState = None, client: str = Non assert False, (key, ts) return - report_msg: dict = parent._task_to_report_msg(ts) + report_msg: dict = _task_to_report_msg(parent, ts) if report_msg is not None: self.report(report_msg, ts=ts, client=client) @@ -6528,7 +6272,7 @@ def reevaluate_occupancy(self, worker_index=0): try: if ws is None or not ws._processing: continue - parent._reevaluate_occupancy_worker(ws) + _reevaluate_occupancy_worker(parent, ws) finally: del ws # lose ref @@ -6639,6 +6383,274 @@ def adaptive_target(self, comm=None, target_duration=None): return len(parent._workers_dv) - len(to_close) +@cfunc +@exceptval(check=False) +def _remove_from_processing(state: SchedulerState, ts: TaskState) -> str: + """ + Remove *ts* from the set of processing tasks. 
+ """ + ws: WorkerState = ts._processing_on + ts._processing_on = None + w: str = ws._address + if w in state._workers_dv: # may have been removed + duration = ws._processing.pop(ts) + if not ws._processing: + state._total_occupancy -= ws._occupancy + ws._occupancy = 0 + else: + state._total_occupancy -= duration + ws._occupancy -= duration + state.check_idle_saturated(ws) + state.release_resources(ts, ws) + return w + else: + return None + + +@cfunc +@exceptval(check=False) +def _add_to_memory( + state: SchedulerState, + ts: TaskState, + ws: WorkerState, + recommendations: dict, + client_msgs: dict, + type=None, + typename: str = None, +): + """ + Add *ts* to the set of in-memory tasks. + """ + if state._validate: + assert ts not in ws._has_what + + ts._who_has.add(ws) + ws._has_what.add(ts) + ws._nbytes += ts.get_nbytes() + + deps: list = list(ts._dependents) + if len(deps) > 1: + deps.sort(key=operator.attrgetter("priority"), reverse=True) + + dts: TaskState + s: set + for dts in deps: + s = dts._waiting_on + if ts in s: + s.discard(ts) + if not s: # new task ready to run + recommendations[dts._key] = "processing" + + for dts in ts._dependencies: + s = dts._waiters + s.discard(ts) + if not s and not dts._who_wants: + recommendations[dts._key] = "released" + + report_msg: dict = {} + cs: ClientState + if not ts._waiters and not ts._who_wants: + recommendations[ts._key] = "released" + else: + report_msg["op"] = "key-in-memory" + report_msg["key"] = ts._key + if type is not None: + report_msg["type"] = type + + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg + + ts.state = "memory" + ts._type = typename + ts._group._types.add(typename) + + cs = state._clients["fire-and-forget"] + if ts in cs._wants_what: + _client_releases_keys( + state, + cs=cs, + keys=[ts._key], + recommendations=recommendations, + ) + + +@cfunc +@exceptval(check=False) +def _propagate_forgotten( + state: SchedulerState, ts: TaskState, recommendations: dict, worker_msgs: dict +): + ts.state = "forgotten" + key: str = ts._key + dts: TaskState + for dts in ts._dependents: + dts._has_lost_dependencies = True + dts._dependencies.remove(ts) + dts._waiting_on.discard(ts) + if dts._state not in ("memory", "erred"): + # Cannot compute task anymore + recommendations[dts._key] = "forgotten" + ts._dependents.clear() + ts._waiters.clear() + + for dts in ts._dependencies: + dts._dependents.remove(ts) + s: set = dts._waiters + s.discard(ts) + if not dts._dependents and not dts._who_wants: + # Task not needed anymore + assert dts is not ts + recommendations[dts._key] = "forgotten" + ts._dependencies.clear() + ts._waiting_on.clear() + + if ts._who_has: + ts._group._nbytes_in_memory -= ts.get_nbytes() + + ws: WorkerState + for ws in ts._who_has: + ws._has_what.remove(ts) + ws._nbytes -= ts.get_nbytes() + w: str = ws._address + if w in state._workers_dv: # in case worker has died + worker_msgs[w] = {"op": "delete-data", "keys": [key], "report": False} + ts._who_has.clear() + + +@cfunc +@exceptval(check=False) +def _client_releases_keys( + state: SchedulerState, keys: list, cs: ClientState, recommendations: dict +): + """ Remove keys from client desired list """ + logger.debug("Client %s releases keys: %s", cs._client_key, keys) + ts: TaskState + tasks2: set = set() + for key in keys: + ts = state._tasks.get(key) + if ts is not None and ts in cs._wants_what: + cs._wants_what.remove(ts) + s: set = ts._who_wants + s.remove(cs) + if not s: + tasks2.add(ts) + + for ts in tasks2: + if not ts._dependents: + # No live 
dependents, can forget + recommendations[ts._key] = "forgotten" + elif ts._state != "erred" and not ts._waiters: + recommendations[ts._key] = "released" + + +@cfunc +@exceptval(check=False) +def _task_to_msg(state: SchedulerState, ts: TaskState, duration=None) -> dict: + """ Convert a single computational task to a message """ + ws: WorkerState + dts: TaskState + + if duration is None: + duration = state.get_task_duration(ts) + + msg: dict = { + "op": "compute-task", + "key": ts._key, + "priority": ts._priority, + "duration": duration, + } + if ts._resource_restrictions: + msg["resource_restrictions"] = ts._resource_restrictions + if ts._actor: + msg["actor"] = True + + deps: set = ts._dependencies + if deps: + msg["who_has"] = { + dts._key: [ws._address for ws in dts._who_has] for dts in deps + } + msg["nbytes"] = {dts._key: dts._nbytes for dts in deps} + + if state._validate: + assert all(msg["who_has"].values()) + + task = ts._run_spec + if type(task) is dict: + msg.update(task) + else: + msg["task"] = task + + return msg + + +@cfunc +@exceptval(check=False) +def _task_to_report_msg(state: SchedulerState, ts: TaskState) -> dict: + if ts is None: + return {"op": "cancelled-key", "key": ts._key} + elif ts._state == "forgotten": + return {"op": "cancelled-key", "key": ts._key} + elif ts._state == "memory": + return {"op": "key-in-memory", "key": ts._key} + elif ts._state == "erred": + failing_ts: TaskState = ts._exception_blame + return { + "op": "task-erred", + "key": ts._key, + "exception": failing_ts._exception, + "traceback": failing_ts._traceback, + } + else: + return None + + +@cfunc +@exceptval(check=False) +def _task_to_client_msgs(state: SchedulerState, ts: TaskState) -> dict: + cs: ClientState + client_keys: list + if ts is None: + # Notify all clients + client_keys = list(state._clients) + else: + # Notify clients interested in key + client_keys = [cs._client_key for cs in ts._who_wants] + + report_msg: dict = _task_to_report_msg(state, ts) + + client_msgs: dict = {} + for k in client_keys: + client_msgs[k] = report_msg + + return client_msgs + + +@cfunc +@exceptval(check=False) +def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState): + """ See reevaluate_occupancy """ + old: double = ws._occupancy + new: double = 0 + diff: double + ts: TaskState + est: double + for ts in ws._processing: + est = state.set_duration_estimate(ts, ws) + new += est + + ws._occupancy = new + diff = new - old + state._total_occupancy += diff + state.check_idle_saturated(ws) + + # significant increase in duration + if new > old * 1.3: + steal = state._extensions.get("stealing") + if steal is not None: + for ts in ws._processing: + steal.remove_key_from_stealable(ts) + steal.put_key_in_stealable(ts) + + @cfunc @exceptval(check=False) def decide_worker( From c7f2078db46ccf23050eb308c6b59d3638c424c7 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Fri, 8 Jan 2021 18:33:24 -0800 Subject: [PATCH 32/38] Collect `@property`s in `__pdict__` As `@property`s are not placed in `__dict__`, but we need these in some kind of `dict` for the dashboard. Collect them in `__pdict__`. This should allow us to easily access them and place them in the dashboard. 
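A tiny sketch of the issue `__pdict__` works around (not part of the patch; the class and values are made up): `@property` attributes live on the class, so they never show up in an instance's `__dict__`, which is what the dashboard templates are fed; the link below discusses the same behavior.

    class StateSketch:
        def __init__(self):
            self._bandwidth = 100_000_000

        @property
        def bandwidth(self):
            return self._bandwidth

        @property
        def __pdict__(self):
            # hand-collected property values, mergeable with __dict__ for templates
            return {"bandwidth": self._bandwidth}

    s = StateSketch()
    assert "bandwidth" not in s.__dict__                       # only "_bandwidth" is stored there
    assert {**s.__dict__, **s.__pdict__}["bandwidth"] == 100_000_000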
https://stackoverflow.com/q/47432613/3877089 --- distributed/scheduler.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 6efe6ea0e1b..4b453aec420 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1710,6 +1710,26 @@ def validate(self, v: bint): def workers(self): return self._workers + @property + def __pdict__(self): + return { + "bandwidth": self._bandwidth, + "resources": self._resources, + "saturated": self._saturated, + "unrunnable": self._unrunnable, + "n_tasks": self._n_tasks, + "unknown_durations": self._unknown_durations, + "validate": self._validate, + "tasks": self._tasks, + "total_nthreads": self._total_nthreads, + "total_occupancy": self._total_occupancy, + "extensions": self._extensions, + "clients": self._clients, + "workers": self._workers, + "idle": self._idle, + "host_info": self._host_info, + } + def transition_released_waiting(self, key): try: ts: TaskState = self._tasks[key] From 2212edc11b424c17f2f656eb73b2ca206742705a Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Fri, 8 Jan 2021 18:38:49 -0800 Subject: [PATCH 33/38] Use `__pdict__` with `__dict__` in dashboard Should give us the `@property`s as well. --- distributed/http/scheduler/info.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/distributed/http/scheduler/info.py b/distributed/http/scheduler/info.py index 6e5a222dd23..96199faba38 100644 --- a/distributed/http/scheduler/info.py +++ b/distributed/http/scheduler/info.py @@ -33,7 +33,13 @@ def get(self): "workers.html", title="Workers", scheduler=self.server, - **merge(self.server.__dict__, ns, self.extra, rel_path_statics), + **merge( + self.server.__dict__, + self.server.__pdict__, + ns, + self.extra, + rel_path_statics, + ), ) @@ -49,7 +55,13 @@ def get(self, worker): title="Worker: " + worker, scheduler=self.server, Worker=worker, - **merge(self.server.__dict__, ns, self.extra, rel_path_statics), + **merge( + self.server.__dict__, + self.server.__pdict__, + ns, + self.extra, + rel_path_statics, + ), ) @@ -65,7 +77,13 @@ def get(self, task): title="Task: " + task, Task=task, scheduler=self.server, - **merge(self.server.__dict__, ns, self.extra, rel_path_statics), + **merge( + self.server.__dict__, + self.server.__pdict__, + ns, + self.extra, + rel_path_statics, + ), ) From bef52404443ccb03fb0c138c1ccfaa21e310372a Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 20 Jan 2021 13:58:18 -0800 Subject: [PATCH 34/38] Refactor `consume_resources` & `release_resources` --- distributed/scheduler.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a10b428a69c..921a8a28f09 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2534,6 +2534,10 @@ def transition_released_forgotten(self, key): pdb.set_trace() raise + ############################## + # Assigning Tasks to Workers # + ############################## + @ccall @exceptval(check=False) def check_idle_saturated(self, ws: WorkerState, occ: double = -1.0): @@ -2660,6 +2664,18 @@ def valid_workers(self, ts: TaskState) -> set: return s + @ccall + def consume_resources(self, ts: TaskState, ws: WorkerState): + if ts._resource_restrictions: + for r, required in ts._resource_restrictions.items(): + ws._used_resources[r] += required + + @ccall + def release_resources(self, ts: TaskState, ws: WorkerState): + if ts._resource_restrictions: + for r, 
required in ts._resource_restrictions.items(): + ws._used_resources[r] -= required + @ccall @exceptval(check=False) def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple: @@ -5913,20 +5929,6 @@ def reschedule(self, key=None, worker=None): return self.transitions({key: "released"}) - ############################## - # Assigning Tasks to Workers # - ############################## - - def consume_resources(self, ts: TaskState, ws: WorkerState): - if ts._resource_restrictions: - for r, required in ts._resource_restrictions.items(): - ws._used_resources[r] += required - - def release_resources(self, ts: TaskState, ws: WorkerState): - if ts._resource_restrictions: - for r, required in ts._resource_restrictions.items(): - ws._used_resources[r] -= required - ##################### # Utility functions # ##################### From caf3c1a9e3c9eb22fc04a4654d49204b88857d9e Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 20 Jan 2021 14:33:42 -0800 Subject: [PATCH 35/38] Refactor `aliases` into `SchedulerState` --- distributed/scheduler.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 921a8a28f09..3811210dec0 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1574,6 +1574,7 @@ class SchedulerState: Time we expect certain functions to take, e.g. ``{'sum': 0.25}`` """ + _aliases: dict _bandwidth: double _clients: dict _extensions: dict @@ -1594,6 +1595,7 @@ class SchedulerState: def __init__( self, + aliases: dict = None, clients: dict = None, workers=None, host_info=None, @@ -1603,6 +1605,10 @@ def __init__( validate: bint = False, **kwargs, ): + if aliases is not None: + self._aliases = aliases + else: + self._aliases = dict() self._bandwidth = parse_bytes( dask.config.get("distributed.scheduler.bandwidth") ) @@ -1643,6 +1649,10 @@ def __init__( self._workers_dv: dict = cast(dict, self._workers) super().__init__(**kwargs) + @property + def aliases(self): + return self._aliases + @property def bandwidth(self): return self._bandwidth @@ -2966,7 +2976,7 @@ def __init__( host_info = defaultdict(dict) resources = defaultdict(dict) - self.aliases = dict() + aliases = dict() self._task_state_collections = [unrunnable] @@ -2974,7 +2984,7 @@ def __init__( workers, host_info, resources, - self.aliases, + aliases, ] self.plugins = list(plugins) @@ -3081,6 +3091,7 @@ def __init__( connection_limit = get_fileno_limit() / 2 super().__init__( + aliases=aliases, handlers=self.handlers, stream_handlers=merge(worker_handlers, client_handlers), io_loop=self.loop, @@ -3427,7 +3438,7 @@ async def add_worker( if ws is not None: raise ValueError("Worker already exists %s" % ws) - if name in self.aliases: + if name in parent._aliases: logger.warning( "Worker tried to connect with a duplicate name: %s", name ) @@ -3460,7 +3471,7 @@ async def add_worker( parent._host_info[host]["nthreads"] += nthreads parent._total_nthreads += nthreads - self.aliases[name] = address + parent._aliases[name] = address response = self.heartbeat_worker( address=address, @@ -4070,7 +4081,7 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): self.rpc.remove(address) del self.stream_comms[address] - del self.aliases[ws._name] + del parent._aliases[ws._name] parent._idle.pop(ws._address, None) parent._saturated.discard(ws) del parent._workers[address] @@ -5958,8 +5969,9 @@ def coerce_address(self, addr, resolve=True): Handles strings, tuples, or aliases. 
""" # XXX how many address-parsing routines do we have? - if addr in self.aliases: - addr = self.aliases[addr] + parent: SchedulerState = cast(SchedulerState, self) + if addr in parent._aliases: + addr = parent._aliases[addr] if isinstance(addr, tuple): addr = unparse_host_port(*addr) if not isinstance(addr, str): From ab5becae11373b693877d7477b7d695bbf1ae895 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 20 Jan 2021 14:34:46 -0800 Subject: [PATCH 36/38] Refactor `coerce_hostname` into `SchedulerState` --- distributed/scheduler.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 3811210dec0..38b3d86fee7 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2686,6 +2686,16 @@ def release_resources(self, ts: TaskState, ws: WorkerState): for r, required in ts._resource_restrictions.items(): ws._used_resources[r] -= required + @ccall + def coerce_hostname(self, host): + """ + Coerce the hostname of a worker. + """ + if host in self._aliases: + return self._workers_dv[self._aliases[host]].host + else: + return host + @ccall @exceptval(check=False) def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple: @@ -5984,16 +5994,6 @@ def coerce_address(self, addr, resolve=True): return addr - def coerce_hostname(self, host): - """ - Coerce the hostname of a worker. - """ - parent: SchedulerState = cast(SchedulerState, self) - if host in self.aliases: - return parent._workers_dv[self.aliases[host]].host - else: - return host - def workers_list(self, workers): """ List of qualifying workers From f93ecad9c26d40945c4183664736a8b1f2c9f17b Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 20 Jan 2021 14:39:52 -0800 Subject: [PATCH 37/38] Refactor `task_metadata` into `SchedulerState` --- distributed/scheduler.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 38b3d86fee7..109d211a801 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1585,6 +1585,7 @@ class SchedulerState: _resources: object _saturated: set _tasks: dict + _task_metadata: dict _total_nthreads: Py_ssize_t _total_occupancy: double _unknown_durations: object @@ -1634,6 +1635,7 @@ def __init__( self._tasks = tasks else: self._tasks = dict() + self._task_metadata = dict() self._total_nthreads = 0 self._total_occupancy = 0 self._unknown_durations = defaultdict(set) @@ -1689,6 +1691,10 @@ def saturated(self): def tasks(self): return self._tasks + @property + def task_metadata(self): + return self._task_metadata + @property def total_nthreads(self): return self._total_nthreads @@ -2952,7 +2958,6 @@ def __init__( self._last_time = 0 unrunnable = set() - self.task_metadata = dict() self.datasets = dict() # Prefix-keyed containers @@ -5710,8 +5715,9 @@ def run_function(self, stream, function, args=(), kwargs={}, wait=True): return run(self, stream, function=function, args=args, kwargs=kwargs, wait=wait) def set_metadata(self, comm=None, keys=None, value=None): + parent: SchedulerState = cast(SchedulerState, self) try: - metadata = self.task_metadata + metadata = parent._task_metadata for key in keys[:-1]: if key not in metadata or not isinstance(metadata[key], (dict, list)): metadata[key] = dict() @@ -5723,7 +5729,8 @@ def set_metadata(self, comm=None, keys=None, value=None): pdb.set_trace() def get_metadata(self, comm=None, keys=None, default=no_default): - metadata = self.task_metadata + parent: 
SchedulerState = cast(SchedulerState, self) + metadata = parent._task_metadata for key in keys[:-1]: metadata = metadata[key] try: From 9275830f284168e643ad8c25d3220ad81a287ff9 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 20 Jan 2021 14:40:28 -0800 Subject: [PATCH 38/38] Refactor `remove_key` into `SchedulerState` --- distributed/scheduler.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 109d211a801..cd6ac0eedce 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2469,6 +2469,19 @@ def transition_no_worker_released(self, key): pdb.set_trace() raise + @ccall + def remove_key(self, key): + ts: TaskState = self._tasks.pop(key) + assert ts._state == "forgotten" + self._unrunnable.discard(ts) + cs: ClientState + for cs in ts._who_wants: + cs._wants_what.remove(ts) + ts._who_wants.clear() + ts._processing_on = None + ts._exception_blame = ts._exception = ts._traceback = None + self._task_metadata.pop(key, None) + def transition_memory_forgotten(self, key): ws: WorkerState try: @@ -5789,20 +5802,6 @@ async def register_worker_plugin(self, comm, plugin, name=None): # State Transitions # ##################### - def remove_key(self, key): - parent: SchedulerState = cast(SchedulerState, self) - tasks: dict = parent._tasks - ts: TaskState = tasks.pop(key) - assert ts._state == "forgotten" - parent._unrunnable.discard(ts) - cs: ClientState - for cs in ts._who_wants: - cs._wants_what.remove(ts) - ts._who_wants.clear() - ts._processing_on = None - ts._exception_blame = ts._exception = ts._traceback = None - self.task_metadata.pop(key, None) - def transition(self, key, finish, *args, **kwargs): """Transition a key from its current state to the finish state