From 5431f2711e36e00834f3082908f7eaa8269b708c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 16 Sep 2021 18:19:54 +0100 Subject: [PATCH 1/5] AMM ReduceReplicas to iterate only on replicated tasks (#5297) --- distributed/active_memory_manager.py | 31 ++-- distributed/scheduler.py | 211 +++++++++++---------------- 2 files changed, 105 insertions(+), 137 deletions(-) diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index 4ed2daf4113..77e67ca4319 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -189,9 +189,14 @@ def _find_recipient( pending_repl: set[WorkerState], ) -> Optional[WorkerState]: """Choose a worker to acquire a new replica of an in-memory task among a set of - candidates. If candidates is None, default to all workers in the cluster that do - not hold a replica yet. The worker with the lowest memory usage (downstream of - pending replications and drops) will be returned. + candidates. If candidates is None, default to all workers in the cluster. + Regardless, workers that either already hold a replica or are scheduled to + receive one at the end of this AMM iteration are not considered. + + Returns + ------- + The worker with the lowest memory usage (downstream of pending replications and + drops), or None if no eligible candidates are available. """ if ts.state != "memory": return None @@ -210,9 +215,15 @@ def _find_dropper( pending_drop: set[WorkerState], ) -> Optional[WorkerState]: """Choose a worker to drop its replica of an in-memory task among a set of - candidates. If candidates is None, default to all workers in the cluster that - hold a replica. The worker with the highest memory usage (downstream of pending - replications and drops) will be returned. + candidates. If candidates is None, default to all workers in the cluster. + Regardless, workers that either do not hold a replica or are already scheduled + to drop theirs at the end of this AMM iteration are not considered. + This method also ensures that a key will not lose its last replica. + + Returns + ------- + The worker with the highest memory usage (downstream of pending replications and + drops), or None if no eligible candidates are available. """ if len(ts.who_has) - len(pending_drop) < 2: return None @@ -283,13 +294,7 @@ class ReduceReplicas(ActiveMemoryManagerPolicy): """ def run(self): - # TODO this is O(n) to the total number of in-memory tasks on the cluster; it - # could be made faster by automatically attaching it to a TaskState when it - # goes above one replica and detaching it when it drops below two. 
- for ts in self.manager.scheduler.tasks.values(): - if len(ts.who_has) < 2: - continue - + for ts in self.manager.scheduler.replicated_tasks: desired_replicas = 1 # TODO have a marker on TaskState # If a dependent task has not been assigned to a worker yet, err on the side diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 706e85c3b24..04496b26e05 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1859,6 +1859,7 @@ class SchedulerState: _task_groups: dict _task_prefixes: dict _task_metadata: dict + _replicated_tasks: set _total_nthreads: Py_ssize_t _total_occupancy: double _transitions_table: dict @@ -1918,6 +1919,9 @@ def __init__( self._tasks = tasks else: self._tasks = dict() + self._replicated_tasks = { + ts for ts in self._tasks.values() if len(ts._who_has) > 1 + } self._computations = deque( maxlen=dask.config.get("distributed.diagnostics.computations.max-history") ) @@ -2035,6 +2039,10 @@ def task_prefixes(self): def task_metadata(self): return self._task_metadata + @property + def replicated_tasks(self): + return self._replicated_tasks + @property def total_nthreads(self): return self._total_nthreads @@ -2820,18 +2828,14 @@ def transition_memory_released(self, key, safe: bint = False): dts._waiting_on.add(ts) # XXX factor this out? - ts_nbytes: Py_ssize_t = ts.get_nbytes() worker_msg = { "op": "free-keys", "keys": [key], "reason": f"Memory->Released {key}", } for ws in ts._who_has: - del ws._has_what[ts] - ws._nbytes -= ts_nbytes worker_msgs[ws._address] = [worker_msg] - - ts._who_has.clear() + self.remove_all_replicas(ts) ts.state = "released" @@ -3425,6 +3429,40 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple: else: return (start_time, ws._nbytes) + @ccall + def add_replica(self, ts: TaskState, ws: WorkerState): + """Note that a worker holds a replica of a task with state='memory'""" + if self._validate: + assert ws not in ts._who_has + assert ts not in ws._has_what + + ws._nbytes += ts.get_nbytes() + ws._has_what[ts] = None + ts._who_has.add(ws) + if len(ts._who_has) == 2: + self._replicated_tasks.add(ts) + + @ccall + def remove_replica(self, ts: TaskState, ws: WorkerState): + """Note that a worker no longer holds a replica of a task""" + ws._nbytes -= ts.get_nbytes() + del ws._has_what[ts] + ts._who_has.remove(ws) + if len(ts._who_has) == 1: + self._replicated_tasks.remove(ts) + + @ccall + def remove_all_replicas(self, ts: TaskState): + """Remove all replicas of a task from all workers""" + ws: WorkerState + nbytes: Py_ssize_t = ts.get_nbytes() + for ws in ts._who_has: + ws._nbytes -= nbytes + del ws._has_what[ts] + if len(ts._who_has) > 1: + self._replicated_tasks.remove(ts) + ts._who_has.clear() + class Scheduler(SchedulerState, ServerNode): """Dynamic distributed task scheduler @@ -4736,70 +4774,23 @@ def stimulus_task_erred( parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task erred %s, %s", key, worker) - recommendations: dict = {} - client_msgs: dict = {} - worker_msgs: dict = {} - ts: TaskState = parent._tasks.get(key) - if ts is None: - return recommendations, client_msgs, worker_msgs - - if ts._state == "processing": - retries: Py_ssize_t = ts._retries - r: tuple - if retries > 0: - ts._retries = retries - 1 - r = parent._transition(key, "waiting") - else: - r = parent._transition( - key, - "erred", - cause=key, - exception=exception, - traceback=traceback, - worker=worker, - **kwargs, - ) - recommendations, client_msgs, worker_msgs = r - - return recommendations, client_msgs, 
worker_msgs - - def stimulus_missing_data( - self, cause=None, key=None, worker=None, ensure=True, **kwargs - ): - """Mark that certain keys have gone missing. Recover.""" - parent: SchedulerState = cast(SchedulerState, self) - with log_errors(): - logger.debug("Stimulus missing data %s, %s", key, worker) + if ts is None or ts._state != "processing": + return {}, {}, {} - recommendations: dict = {} - client_msgs: dict = {} - worker_msgs: dict = {} - - ts: TaskState = parent._tasks.get(key) - if ts is None or ts._state == "memory": - return recommendations, client_msgs, worker_msgs - cts: TaskState = parent._tasks.get(cause) - - if cts is not None and cts._state == "memory": # couldn't find this - ws: WorkerState - cts_nbytes: Py_ssize_t = cts.get_nbytes() - for ws in cts._who_has: # TODO: this behavior is extreme - del ws._has_what[ts] - ws._nbytes -= cts_nbytes - cts._who_has.clear() - recommendations[cause] = "released" - - if key: - recommendations[key] = "released" - - parent._transitions(recommendations, client_msgs, worker_msgs) - recommendations = {} - - if parent._validate: - assert cause not in self.who_has - - return recommendations, client_msgs, worker_msgs + if ts._retries > 0: + ts._retries -= 1 + return parent._transition(key, "waiting") + else: + return parent._transition( + key, + "erred", + cause=key, + exception=exception, + traceback=traceback, + worker=worker, + **kwargs, + ) def stimulus_retry(self, comm=None, keys=None, client=None): parent: SchedulerState = cast(SchedulerState, self) @@ -4914,14 +4905,13 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): self.allowed_failures, ) - for ts in ws._has_what: - ts._who_has.remove(ws) + for ts in list(ws._has_what): + parent.remove_replica(ts, ws) if not ts._who_has: if ts._run_spec: recommendations[ts._key] = "released" else: # pure data recommendations[ts._key] = "forgotten" - ws._has_what.clear() self.transitions(recommendations) @@ -5071,6 +5061,7 @@ def validate_memory(self, key): ts: TaskState = parent._tasks[key] dts: TaskState assert ts._who_has + assert bool(ts in parent._replicated_tasks) == (len(ts._who_has) > 1) assert not ts._processing_on assert not ts._waiting_on assert ts not in parent._unrunnable @@ -5141,8 +5132,13 @@ def validate_state(self, allow_overlap=False): for k, ts in parent._tasks.items(): assert isinstance(ts, TaskState), (type(ts), ts) assert ts._key == k + assert bool(ts in parent._replicated_tasks) == (len(ts._who_has) > 1) self.validate_key(k, ts) + for ts in parent._replicated_tasks: + assert ts._state == "memory" + assert ts._key in parent._tasks + c: str cs: ClientState for c, cs in parent._clients.items(): @@ -5343,24 +5339,14 @@ def handle_task_erred(self, key=None, **msg): self.send_all(client_msgs, worker_msgs) - def handle_release_data(self, key=None, worker=None, client=None, **msg): + def handle_release_data(self, key=None, worker=None, **kwargs): parent: SchedulerState = cast(SchedulerState, self) ts: TaskState = parent._tasks.get(key) - if ts is None: + if ts is None or ts._state == "memory": return ws: WorkerState = parent._workers_dv.get(worker) - if ws is None or ts._processing_on != ws: - return - - recommendations: dict - client_msgs: dict - worker_msgs: dict - - r: tuple = self.stimulus_missing_data(key=key, ensure=False, **msg) - recommendations, client_msgs, worker_msgs = r - parent._transitions(recommendations, client_msgs, worker_msgs) - - self.send_all(client_msgs, worker_msgs) + if ws is not None and ts._processing_on == ws: + 
parent._transitions({key: "released"}, {}, {}) def handle_missing_data(self, key=None, errant_worker=None, **kwargs): parent: SchedulerState = cast(SchedulerState, self) @@ -5372,9 +5358,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): return ws: WorkerState = parent._workers_dv.get(errant_worker) if ws is not None and ws in ts._who_has: - ts._who_has.remove(ws) - del ws._has_what[ts] - ws._nbytes -= ts.get_nbytes() + parent.remove_replica(ts, ws) if not ts._who_has: if ts._run_spec: self.transitions({key: "released"}) @@ -5392,11 +5376,8 @@ def release_worker_data(self, comm=None, keys=None, worker=None): ts: TaskState recommendations: dict = {} for ts in removed_tasks: - del ws._has_what[ts] - ws._nbytes -= ts.get_nbytes() - wh: set = ts._who_has - wh.remove(ws) - if not wh: + parent.remove_replica(ts, ws) + if not ts._who_has: recommendations[ts._key] = "released" if recommendations: self.transitions(recommendations) @@ -5716,14 +5697,11 @@ async def gather(self, comm=None, keys=None, serializers=None): ) if not workers or ts is None: continue - ts_nbytes: Py_ssize_t = ts.get_nbytes() recommendations: dict = {key: "released"} for worker in workers: ws = parent._workers_dv.get(worker) - if ws is not None and ts in ws._has_what: - del ws._has_what[ts] - ts._who_has.remove(ws) - ws._nbytes -= ts_nbytes + if ws is not None and ws in ts._who_has: + parent.remove_replica(ts, ws) parent._transitions( recommendations, client_msgs, worker_msgs ) @@ -5922,10 +5900,8 @@ async def gather_on_worker( if ts is None or ts._state != "memory": logger.warning(f"Key lost during replication: {key}") continue - if ts not in ws._has_what: - ws._nbytes += ts.get_nbytes() - ws._has_what[ts] = None - ts._who_has.add(ws) + if ws not in ts._who_has: + parent.add_replica(ts, ws) return keys_failed @@ -5962,11 +5938,9 @@ async def delete_worker_data(self, worker_address: str, keys: "list[str]") -> No for key in keys: ts: TaskState = parent._tasks.get(key) - if ts is not None and ts in ws._has_what: + if ts is not None and ws in ts._who_has: assert ts._state == "memory" - del ws._has_what[ts] - ts._who_has.remove(ws) - ws._nbytes -= ts.get_nbytes() + parent.remove_replica(ts, ws) if not ts._who_has: # Last copy deleted self.transitions({key: "released"}) @@ -6714,10 +6688,8 @@ def add_keys(self, comm=None, worker=None, keys=()): for key in keys: ts: TaskState = parent._tasks.get(key) if ts is not None and ts._state == "memory": - if ts not in ws._has_what: - ws._nbytes += ts.get_nbytes() - ws._has_what[ts] = None - ts._who_has.add(ws) + if ws not in ts._who_has: + parent.add_replica(ts, ws) else: superfluous_data.append(key) if superfluous_data: @@ -6759,17 +6731,14 @@ def update_data( if ts is None: ts: TaskState = parent.new_task(key, None, "memory") ts.state = "memory" - ts_nbytes: Py_ssize_t = nbytes.get(key, -1) + ts_nbytes = nbytes.get(key, -1) if ts_nbytes >= 0: ts.set_nbytes(ts_nbytes) - else: - ts_nbytes = ts.get_nbytes() + for w in workers: ws: WorkerState = parent._workers_dv[w] - if ts not in ws._has_what: - ws._nbytes += ts_nbytes - ws._has_what[ts] = None - ts._who_has.add(ws) + if ws not in ts._who_has: + parent.add_replica(ts, ws) self.report( {"op": "key-in-memory", "key": key, "workers": list(workers)} ) @@ -7736,9 +7705,7 @@ def _add_to_memory( if state._validate: assert ts not in ws._has_what - ts._who_has.add(ws) - ws._has_what[ts] = None - ws._nbytes += ts.get_nbytes() + state.add_replica(ts, ws) deps: list = list(ts._dependents) if len(deps) > 1: @@ -7814,12 +7781,8 
@@ def _propagate_forgotten( ts._dependencies.clear() ts._waiting_on.clear() - ts_nbytes: Py_ssize_t = ts.get_nbytes() - ws: WorkerState for ws in ts._who_has: - del ws._has_what[ts] - ws._nbytes -= ts_nbytes w: str = ws._address if w in state._workers_dv: # in case worker has died worker_msgs[w] = [ @@ -7829,7 +7792,7 @@ def _propagate_forgotten( "reason": f"propagate-forgotten {ts.key}", } ] - ts._who_has.clear() + state.remove_all_replicas(ts) @cfunc From ebc2011d625d474166f116615a0a45421828fb0f Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 22 Sep 2021 13:53:27 +0100 Subject: [PATCH 2/5] Split out #5340 --- distributed/active_memory_manager.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index 77e67ca4319..57978d6654f 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -189,14 +189,9 @@ def _find_recipient( pending_repl: set[WorkerState], ) -> Optional[WorkerState]: """Choose a worker to acquire a new replica of an in-memory task among a set of - candidates. If candidates is None, default to all workers in the cluster. - Regardless, workers that either already hold a replica or are scheduled to - receive one at the end of this AMM iteration are not considered. - - Returns - ------- - The worker with the lowest memory usage (downstream of pending replications and - drops), or None if no eligible candidates are available. + candidates. If candidates is None, default to all workers in the cluster that do + not hold a replica yet. The worker with the lowest memory usage (downstream of + pending replications and drops) will be returned. """ if ts.state != "memory": return None @@ -215,15 +210,9 @@ def _find_dropper( pending_drop: set[WorkerState], ) -> Optional[WorkerState]: """Choose a worker to drop its replica of an in-memory task among a set of - candidates. If candidates is None, default to all workers in the cluster. - Regardless, workers that either do not hold a replica or are already scheduled - to drop theirs at the end of this AMM iteration are not considered. - This method also ensures that a key will not lose its last replica. - - Returns - ------- - The worker with the highest memory usage (downstream of pending replications and - drops), or None if no eligible candidates are available. + candidates. If candidates is None, default to all workers in the cluster that + hold a replica. The worker with the highest memory usage (downstream of pending + replications and drops) will be returned. 
""" if len(ts.who_has) - len(pending_drop) < 2: return None From 3c2bcfd66a49bb7e092f9840c2438d9343bbdd94 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 22 Sep 2021 13:20:54 +0200 Subject: [PATCH 3/5] Reintroduce client and worker send in handle_release_data --- distributed/scheduler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 04496b26e05..99218e73aac 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5346,7 +5346,11 @@ def handle_release_data(self, key=None, worker=None, **kwargs): return ws: WorkerState = parent._workers_dv.get(worker) if ws is not None and ts._processing_on == ws: - parent._transitions({key: "released"}, {}, {}) + client_msgs = {} + worker_msgs = {} + # Note: The msgs dicts are filled inplace + parent._transitions({key: "released"}, client_msgs, worker_msgs) + self.send_all(client_msgs, worker_msgs) def handle_missing_data(self, key=None, errant_worker=None, **kwargs): parent: SchedulerState = cast(SchedulerState, self) From 781b80c7dd95def016f4991a0ef7ff8e402813e6 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 27 Sep 2021 20:42:23 +0100 Subject: [PATCH 4/5] clean up old handlers --- distributed/scheduler.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index e9477a9ee04..4d43430f119 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3758,7 +3758,6 @@ def __init__( worker_handlers = { "task-finished": self.handle_task_finished, "task-erred": self.handle_task_erred, - "release": self.handle_release_data, "release-worker-data": self.release_worker_data, "add-keys": self.add_keys, "missing-data": self.handle_missing_data, @@ -5342,19 +5341,6 @@ def handle_task_erred(self, key=None, **msg): self.send_all(client_msgs, worker_msgs) - def handle_release_data(self, key=None, worker=None, **kwargs): - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks.get(key) - if ts is None or ts._state == "memory": - return - ws: WorkerState = parent._workers_dv.get(worker) - if ws is not None and ts._processing_on == ws: - client_msgs = {} - worker_msgs = {} - # Note: The msgs dicts are filled inplace - parent._transitions({key: "released"}, client_msgs, worker_msgs) - self.send_all(client_msgs, worker_msgs) - def handle_missing_data(self, key=None, errant_worker=None, **kwargs): parent: SchedulerState = cast(SchedulerState, self) logger.debug("handle missing data key=%s worker=%s", key, errant_worker) From f15f0865fcafa554794492a9818ece2643cec419 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 27 Sep 2021 20:47:37 +0100 Subject: [PATCH 5/5] Split out WSMR cleanup --- distributed/scheduler.py | 97 +++++++++++++++++++++++++++++++++------- 1 file changed, 82 insertions(+), 15 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 4d43430f119..6a542686fde 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3758,6 +3758,7 @@ def __init__( worker_handlers = { "task-finished": self.handle_task_finished, "task-erred": self.handle_task_erred, + "release": self.handle_release_data, "release-worker-data": self.release_worker_data, "add-keys": self.add_keys, "missing-data": self.handle_missing_data, @@ -4776,23 +4777,70 @@ def stimulus_task_erred( parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task erred %s, %s", key, worker) + recommendations: dict = {} + client_msgs: dict = {} + 
worker_msgs: dict = {} + ts: TaskState = parent._tasks.get(key) - if ts is None or ts._state != "processing": - return {}, {}, {} + if ts is None: + return recommendations, client_msgs, worker_msgs - if ts._retries > 0: - ts._retries -= 1 - return parent._transition(key, "waiting") - else: - return parent._transition( - key, - "erred", - cause=key, - exception=exception, - traceback=traceback, - worker=worker, - **kwargs, - ) + if ts._state == "processing": + retries: Py_ssize_t = ts._retries + r: tuple + if retries > 0: + ts._retries = retries - 1 + r = parent._transition(key, "waiting") + else: + r = parent._transition( + key, + "erred", + cause=key, + exception=exception, + traceback=traceback, + worker=worker, + **kwargs, + ) + recommendations, client_msgs, worker_msgs = r + + return recommendations, client_msgs, worker_msgs + + def stimulus_missing_data( + self, cause=None, key=None, worker=None, ensure=True, **kwargs + ): + """Mark that certain keys have gone missing. Recover.""" + parent: SchedulerState = cast(SchedulerState, self) + with log_errors(): + logger.debug("Stimulus missing data %s, %s", key, worker) + + recommendations: dict = {} + client_msgs: dict = {} + worker_msgs: dict = {} + + ts: TaskState = parent._tasks.get(key) + if ts is None or ts._state == "memory": + return recommendations, client_msgs, worker_msgs + cts: TaskState = parent._tasks.get(cause) + + if cts is not None and cts._state == "memory": # couldn't find this + ws: WorkerState + cts_nbytes: Py_ssize_t = cts.get_nbytes() + for ws in cts._who_has: # TODO: this behavior is extreme + del ws._has_what[ts] + ws._nbytes -= cts_nbytes + cts._who_has.clear() + recommendations[cause] = "released" + + if key: + recommendations[key] = "released" + + parent._transitions(recommendations, client_msgs, worker_msgs) + recommendations = {} + + if parent._validate: + assert cause not in self.who_has + + return recommendations, client_msgs, worker_msgs def stimulus_retry(self, comm=None, keys=None, client=None): parent: SchedulerState = cast(SchedulerState, self) @@ -5341,6 +5389,25 @@ def handle_task_erred(self, key=None, **msg): self.send_all(client_msgs, worker_msgs) + def handle_release_data(self, key=None, worker=None, client=None, **msg): + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks.get(key) + if ts is None: + return + ws: WorkerState = parent._workers_dv.get(worker) + if ws is None or ts._processing_on != ws: + return + + recommendations: dict + client_msgs: dict + worker_msgs: dict + + r: tuple = self.stimulus_missing_data(key=key, ensure=False, **msg) + recommendations, client_msgs, worker_msgs = r + parent._transitions(recommendations, client_msgs, worker_msgs) + + self.send_all(client_msgs, worker_msgs) + def handle_missing_data(self, key=None, errant_worker=None, **kwargs): parent: SchedulerState = cast(SchedulerState, self) logger.debug("handle missing data key=%s worker=%s", key, errant_worker)
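# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patches above): the bookkeeping that
# patch 1 introduces. SchedulerState keeps a `replicated_tasks` set containing
# exactly the tasks with more than one replica, maintained by the new
# add_replica / remove_replica / remove_all_replicas helpers, so that
# ReduceReplicas.run can iterate over that set instead of scanning every
# in-memory task on the cluster. The classes below are simplified plain-Python
# stand-ins for the real Cython-annotated TaskState / WorkerState /
# SchedulerState; only the fields needed for the invariant are modelled.
# ---------------------------------------------------------------------------
from __future__ import annotations


class TaskState:
    def __init__(self, key: str, nbytes: int = 0):
        self.key = key
        self.nbytes = nbytes
        self.who_has: set[WorkerState] = set()  # workers holding a replica


class WorkerState:
    def __init__(self, address: str):
        self.address = address
        self.has_what: dict[TaskState, None] = {}  # insertion-ordered "set"
        self.nbytes = 0


class SchedulerState:
    def __init__(self):
        self.tasks: dict[str, TaskState] = {}
        # Invariant: ts in replicated_tasks  <=>  len(ts.who_has) > 1
        self.replicated_tasks: set[TaskState] = set()

    def add_replica(self, ts: TaskState, ws: WorkerState) -> None:
        ws.nbytes += ts.nbytes
        ws.has_what[ts] = None
        ts.who_has.add(ws)
        if len(ts.who_has) == 2:  # just crossed 1 -> 2 replicas
            self.replicated_tasks.add(ts)

    def remove_replica(self, ts: TaskState, ws: WorkerState) -> None:
        ws.nbytes -= ts.nbytes
        del ws.has_what[ts]
        ts.who_has.remove(ws)
        if len(ts.who_has) == 1:  # just dropped back to a single replica
            self.replicated_tasks.remove(ts)

    def remove_all_replicas(self, ts: TaskState) -> None:
        for ws in ts.who_has:
            ws.nbytes -= ts.nbytes
            del ws.has_what[ts]
        if len(ts.who_has) > 1:
            self.replicated_tasks.remove(ts)
        ts.who_has.clear()
# With this in place, ReduceReplicas only needs
# `for ts in scheduler.replicated_tasks: ...`, which is O(replicated keys)
# per AMM iteration rather than O(all tasks).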
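# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patches above): the selection rule that
# the _find_recipient / _find_dropper docstrings describe — filter out
# ineligible workers, then pick the least-loaded one to receive a replica or
# the most-loaded one to drop it. `memory` is a hypothetical callable mapping
# a worker to its projected memory usage downstream of pending replications
# and drops; the real measure is not shown in these hunks. `all_workers` is
# likewise a stand-in for "all workers in the cluster".
# ---------------------------------------------------------------------------
def find_recipient(ts, candidates, pending_repl, all_workers, memory):
    # Only in-memory tasks can gain replicas.
    if ts.state != "memory":
        return None
    if candidates is None:
        candidates = set(all_workers)
    # Exclude workers that already hold a replica or are already scheduled
    # to receive one in this AMM iteration.
    candidates = candidates - ts.who_has - pending_repl
    if not candidates:
        return None
    return min(candidates, key=memory)  # least-loaded worker receives


def find_dropper(ts, candidates, pending_drop, memory):
    # Never drop the last replica of a key.
    if len(ts.who_has) - len(pending_drop) < 2:
        return None
    if candidates is None:
        candidates = set(ts.who_has)
    # Only workers that actually hold a replica and are not already scheduled
    # to drop theirs are eligible.
    candidates = (candidates & ts.who_has) - pending_drop
    if not candidates:
        return None
    return max(candidates, key=memory)  # most-loaded worker drops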