diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py
index c2bbe7ccd3c..a86914cafcf 100644
--- a/distributed/active_memory_manager.py
+++ b/distributed/active_memory_manager.py
@@ -303,13 +303,7 @@ class ReduceReplicas(ActiveMemoryManagerPolicy):
     """
 
     def run(self):
-        # TODO this is O(n) to the total number of in-memory tasks on the cluster; it
-        # could be made faster by automatically attaching it to a TaskState when it
-        # goes above one replica and detaching it when it drops below two.
-        for ts in self.manager.scheduler.tasks.values():
-            if len(ts.who_has) < 2:
-                continue
-
+        for ts in self.manager.scheduler.replicated_tasks:
             desired_replicas = 1  # TODO have a marker on TaskState
 
             # If a dependent task has not been assigned to a worker yet, err on the side
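The hunk above replaces the O(n) scan over every in-memory task with the scheduler-maintained `replicated_tasks` set, which is exactly the optimization the deleted TODO asked for. As a minimal sketch of the invariant the policy now relies on (plain dicts and sets with hypothetical names, not the actual dask API), the incrementally maintained index must always equal the scan it replaces:

    def check_replicated_index(who_has: dict, replicated: set) -> None:
        # The maintained index must equal the O(n) filter it replaces.
        expected = {key for key, workers in who_has.items() if len(workers) > 1}
        assert replicated == expected

    # "x" has two replicas, "y" has one, so only "x" is indexed.
    check_replicated_index({"x": {"w1", "w2"}, "y": {"w1"}}, {"x"})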
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index c538056e7dc..6a542686fde 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -1858,6 +1858,7 @@ class SchedulerState:
     _task_groups: dict
     _task_prefixes: dict
     _task_metadata: dict
+    _replicated_tasks: set
     _total_nthreads: Py_ssize_t
     _total_occupancy: double
     _transitions_table: dict
@@ -1917,6 +1918,9 @@ def __init__(
             self._tasks = tasks
         else:
             self._tasks = dict()
+        self._replicated_tasks = {
+            ts for ts in self._tasks.values() if len(ts._who_has) > 1
+        }
         self._computations = deque(
             maxlen=dask.config.get("distributed.diagnostics.computations.max-history")
         )
@@ -2034,6 +2038,10 @@ def task_prefixes(self):
     def task_metadata(self):
         return self._task_metadata
 
+    @property
+    def replicated_tasks(self):
+        return self._replicated_tasks
+
     @property
     def total_nthreads(self):
         return self._total_nthreads
@@ -2819,18 +2827,14 @@ def transition_memory_released(self, key, safe: bint = False):
                     dts._waiting_on.add(ts)
 
             # XXX factor this out?
-            ts_nbytes: Py_ssize_t = ts.get_nbytes()
             worker_msg = {
                 "op": "free-keys",
                 "keys": [key],
                 "reason": f"Memory->Released {key}",
             }
             for ws in ts._who_has:
-                del ws._has_what[ts]
-                ws._nbytes -= ts_nbytes
                 worker_msgs[ws._address] = [worker_msg]
-
-            ts._who_has.clear()
+            self.remove_all_replicas(ts)
 
             ts.state = "released"
 
@@ -3428,6 +3432,40 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple:
         else:
             return (start_time, ws._nbytes)
 
+    @ccall
+    def add_replica(self, ts: TaskState, ws: WorkerState):
+        """Note that a worker holds a replica of a task with state='memory'"""
+        if self._validate:
+            assert ws not in ts._who_has
+            assert ts not in ws._has_what
+
+        ws._nbytes += ts.get_nbytes()
+        ws._has_what[ts] = None
+        ts._who_has.add(ws)
+        if len(ts._who_has) == 2:
+            self._replicated_tasks.add(ts)
+
+    @ccall
+    def remove_replica(self, ts: TaskState, ws: WorkerState):
+        """Note that a worker no longer holds a replica of a task"""
+        ws._nbytes -= ts.get_nbytes()
+        del ws._has_what[ts]
+        ts._who_has.remove(ws)
+        if len(ts._who_has) == 1:
+            self._replicated_tasks.remove(ts)
+
+    @ccall
+    def remove_all_replicas(self, ts: TaskState):
+        """Remove all replicas of a task from all workers"""
+        ws: WorkerState
+        nbytes: Py_ssize_t = ts.get_nbytes()
+        for ws in ts._who_has:
+            ws._nbytes -= nbytes
+            del ws._has_what[ts]
+        if len(ts._who_has) > 1:
+            self._replicated_tasks.remove(ts)
+        ts._who_has.clear()
+
 
 class Scheduler(SchedulerState, ServerNode):
     """Dynamic distributed task scheduler
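The three helpers above centralize all `who_has` / `has_what` / `nbytes` bookkeeping and keep `_replicated_tasks` consistent as a side effect: a task enters the set exactly when its replica count crosses 1 -> 2 in `add_replica`, and leaves it when the count drops back to 1 in `remove_replica` (or when `remove_all_replicas` wipes every copy). A self-contained toy model of that threshold logic, using illustrative names rather than the real scheduler classes:

    class ReplicaIndex:
        def __init__(self):
            self.who_has = {}        # task key -> set of workers holding it
            self.replicated = set()  # keys with two or more replicas

        def add_replica(self, key, worker):
            workers = self.who_has.setdefault(key, set())
            workers.add(worker)
            if len(workers) == 2:    # replica count just crossed 1 -> 2
                self.replicated.add(key)

        def remove_replica(self, key, worker):
            workers = self.who_has[key]
            workers.remove(worker)
            if len(workers) == 1:    # replica count just dropped 2 -> 1
                self.replicated.remove(key)

    idx = ReplicaIndex()
    idx.add_replica("x", "w1")
    idx.add_replica("x", "w2")
    assert idx.replicated == {"x"}
    idx.remove_replica("x", "w2")
    assert idx.replicated == set()

Because each helper touches both sides of the relationship at once, callers can no longer update `ts._who_has` without also updating `ws._has_what`, `ws._nbytes`, and the replicated index; that is the class of drift the new `validate_*` assertions below guard against.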
@@ -4917,14 +4955,13 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True):
                     self.allowed_failures,
                 )
 
-            for ts in ws._has_what:
-                ts._who_has.remove(ws)
+            for ts in list(ws._has_what):
+                parent.remove_replica(ts, ws)
                 if not ts._who_has:
                     if ts._run_spec:
                         recommendations[ts._key] = "released"
                     else:  # pure data
                         recommendations[ts._key] = "forgotten"
-            ws._has_what.clear()
 
             self.transitions(recommendations)
@@ -5074,6 +5111,7 @@ def validate_memory(self, key):
        ts: TaskState = parent._tasks[key]
        dts: TaskState
        assert ts._who_has
+       assert bool(ts in parent._replicated_tasks) == (len(ts._who_has) > 1)
        assert not ts._processing_on
        assert not ts._waiting_on
        assert ts not in parent._unrunnable
@@ -5144,8 +5182,13 @@ def validate_state(self, allow_overlap=False):
         for k, ts in parent._tasks.items():
             assert isinstance(ts, TaskState), (type(ts), ts)
             assert ts._key == k
+            assert bool(ts in parent._replicated_tasks) == (len(ts._who_has) > 1)
             self.validate_key(k, ts)
 
+        for ts in parent._replicated_tasks:
+            assert ts._state == "memory"
+            assert ts._key in parent._tasks
+
         c: str
         cs: ClientState
         for c, cs in parent._clients.items():
@@ -5375,9 +5418,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
            return
        ws: WorkerState = parent._workers_dv.get(errant_worker)
        if ws is not None and ws in ts._who_has:
-           ts._who_has.remove(ws)
-           del ws._has_what[ts]
-           ws._nbytes -= ts.get_nbytes()
+           parent.remove_replica(ts, ws)
        if not ts._who_has:
            if ts._run_spec:
                self.transitions({key: "released"})
@@ -5391,12 +5432,9 @@ def release_worker_data(self, comm=None, key=None, worker=None):
        if not ws or not ts:
            return
        recommendations: dict = {}
-       if ts in ws._has_what:
-           del ws._has_what[ts]
-           ws._nbytes -= ts.get_nbytes()
-           wh: set = ts._who_has
-           wh.remove(ws)
-           if not wh:
+       if ws in ts._who_has:
+           parent.remove_replica(ts, ws)
+           if not ts._who_has:
                recommendations[ts._key] = "released"
        if recommendations:
            self.transitions(recommendations)
@@ -5716,14 +5754,11 @@ async def gather(self, comm=None, keys=None, serializers=None):
                )
                if not workers or ts is None:
                    continue
-               ts_nbytes: Py_ssize_t = ts.get_nbytes()
                recommendations: dict = {key: "released"}
                for worker in workers:
                    ws = parent._workers_dv.get(worker)
-                   if ws is not None and ts in ws._has_what:
-                       del ws._has_what[ts]
-                       ts._who_has.remove(ws)
-                       ws._nbytes -= ts_nbytes
+                   if ws is not None and ws in ts._who_has:
+                       parent.remove_replica(ts, ws)
                parent._transitions(
                    recommendations, client_msgs, worker_msgs
                )
@@ -5922,10 +5957,8 @@ async def gather_on_worker(
            if ts is None or ts._state != "memory":
                logger.warning(f"Key lost during replication: {key}")
                continue
-           if ts not in ws._has_what:
-               ws._nbytes += ts.get_nbytes()
-               ws._has_what[ts] = None
-               ts._who_has.add(ws)
+           if ws not in ts._who_has:
+               parent.add_replica(ts, ws)
 
        return keys_failed
@@ -5962,11 +5995,9 @@ async def delete_worker_data(self, worker_address: str, keys: "list[str]") -> None:
 
        for key in keys:
            ts: TaskState = parent._tasks.get(key)
-           if ts is not None and ts in ws._has_what:
+           if ts is not None and ws in ts._who_has:
                assert ts._state == "memory"
-               del ws._has_what[ts]
-               ts._who_has.remove(ws)
-               ws._nbytes -= ts.get_nbytes()
+               parent.remove_replica(ts, ws)
                if not ts._who_has:
                    # Last copy deleted
                    self.transitions({key: "released"})
@@ -6714,10 +6745,8 @@ def add_keys(self, comm=None, worker=None, keys=()):
        for key in keys:
            ts: TaskState = parent._tasks.get(key)
            if ts is not None and ts._state == "memory":
-               if ts not in ws._has_what:
-                   ws._nbytes += ts.get_nbytes()
-                   ws._has_what[ts] = None
-                   ts._who_has.add(ws)
+               if ws not in ts._who_has:
+                   parent.add_replica(ts, ws)
            else:
                redundant_replicas.append(key)
@@ -6760,17 +6789,14 @@ def update_data(
            if ts is None:
                ts: TaskState = parent.new_task(key, None, "memory")
            ts.state = "memory"
-           ts_nbytes: Py_ssize_t = nbytes.get(key, -1)
+           ts_nbytes = nbytes.get(key, -1)
            if ts_nbytes >= 0:
                ts.set_nbytes(ts_nbytes)
-           else:
-               ts_nbytes = ts.get_nbytes()
+
            for w in workers:
                ws: WorkerState = parent._workers_dv[w]
-               if ts not in ws._has_what:
-                   ws._nbytes += ts_nbytes
-                   ws._has_what[ts] = None
-                   ts._who_has.add(ws)
+               if ws not in ts._who_has:
+                   parent.add_replica(ts, ws)
            self.report(
                {"op": "key-in-memory", "key": key, "workers": list(workers)}
            )
@@ -7737,9 +7763,7 @@ def _add_to_memory(
    if state._validate:
        assert ts not in ws._has_what
 
-   ts._who_has.add(ws)
-   ws._has_what[ts] = None
-   ws._nbytes += ts.get_nbytes()
+   state.add_replica(ts, ws)
 
    deps: list = list(ts._dependents)
    if len(deps) > 1:
@@ -7815,12 +7839,8 @@ def _propagate_forgotten(
    ts._dependencies.clear()
    ts._waiting_on.clear()
 
-   ts_nbytes: Py_ssize_t = ts.get_nbytes()
-
    ws: WorkerState
    for ws in ts._who_has:
-       del ws._has_what[ts]
-       ws._nbytes -= ts_nbytes
        w: str = ws._address
        if w in state._workers_dv:  # in case worker has died
            worker_msgs[w] = [
@@ -7830,7 +7850,7 @@ def _propagate_forgotten(
                {
                    "op": "free-keys",
                    "keys": [ts._key],
                    "reason": f"propagate-forgotten {ts.key}",
                }
            ]
-   ts._who_has.clear()
+   state.remove_all_replicas(ts)
 
 
 @cfunc
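One subtlety worth noting in the `remove_worker` hunk: the loop now iterates over `list(ws._has_what)` because `remove_replica` deletes entries from `ws._has_what` while the loop runs, and mutating a dict during iteration raises `RuntimeError`. A minimal illustration (generic Python behaviour, not dask-specific):

    has_what = {"x": None, "y": None}
    for key in list(has_what):  # snapshot the keys first
        del has_what[key]       # safe: we iterate the copy, not the dict
    assert has_what == {}

    # Iterating the dict directly while deleting would fail:
    # for key in has_what: del has_what[key]  -> RuntimeError

This also removes the need for the old trailing `ws._has_what.clear()`, since every entry has already been deleted one replica at a time.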