From 6184928b331faf405c9b77cfdde079246235b4fe Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Fri, 18 Feb 2022 22:51:37 -0800 Subject: [PATCH 01/10] Extract SchedulerState from Scheduler --- distributed/active_memory_manager.py | 4 +- distributed/scheduler.py | 1134 ++++++++++++++------------ distributed/stealing.py | 4 +- distributed/tests/test_client.py | 30 +- distributed/tests/test_scheduler.py | 50 +- 5 files changed, 658 insertions(+), 564 deletions(-) diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index 4a616095908..14980fb0533 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -222,10 +222,10 @@ def log_reject(msg: str) -> None: return None if candidates is None: - candidates = self.scheduler.running.copy() + candidates = self.scheduler.state.running.copy() else: # Don't modify orig_candidates - candidates = candidates & self.scheduler.running + candidates = candidates & self.scheduler.state.running if not candidates: log_reject("no running candidates") return None diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 59c1d4a41cb..0dd88edbbeb 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -29,7 +29,6 @@ from functools import partial from numbers import Number from typing import ClassVar, Literal -from typing import cast as pep484_cast import psutil from sortedcontainers import SortedDict, SortedSet @@ -1923,6 +1922,7 @@ def _task_key_or_none(task: TaskState): return task._key if task is not None else None +@final @cclass class SchedulerState: """Underlying task state of dynamic scheduler @@ -1991,6 +1991,7 @@ class SchedulerState: _replicated_tasks: set _total_nthreads: Py_ssize_t _total_occupancy: double + _transition_log: deque _transitions_table: dict _unknown_durations: dict _unrunnable: set @@ -2001,12 +2002,12 @@ class SchedulerState: _plugins: dict # dict[str, SchedulerPlugin] # Variables from dask.config, cached by __init__ for performance - UNKNOWN_TASK_DURATION: double - MEMORY_RECENT_TO_OLD_TIME: double - MEMORY_REBALANCE_MEASURE: str - MEMORY_REBALANCE_SENDER_MIN: double - MEMORY_REBALANCE_RECIPIENT_MAX: double - MEMORY_REBALANCE_HALF_GAP: double + _UNKNOWN_TASK_DURATION: double + _MEMORY_RECENT_TO_OLD_TIME: double + _MEMORY_REBALANCE_MEASURE: str + _MEMORY_REBALANCE_SENDER_MIN: double + _MEMORY_REBALANCE_RECIPIENT_MAX: double + _MEMORY_REBALANCE_HALF_GAP: double def __init__( self, @@ -2047,6 +2048,9 @@ def __init__( self._task_metadata = {} self._total_nthreads = 0 self._total_occupancy = 0 + self._transition_log = deque( + maxlen=dask.config.get("distributed.scheduler.transition-log-length") + ) self._transitions_table = { ("released", "waiting"): self.transition_released_waiting, ("waiting", "released"): self.transition_waiting_released, @@ -2076,22 +2080,22 @@ def __init__( self._plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins} # Variables from dask.config, cached by __init__ for performance - self.UNKNOWN_TASK_DURATION = parse_timedelta( + self._UNKNOWN_TASK_DURATION = parse_timedelta( dask.config.get("distributed.scheduler.unknown-task-duration") ) - self.MEMORY_RECENT_TO_OLD_TIME = parse_timedelta( + self._MEMORY_RECENT_TO_OLD_TIME = parse_timedelta( dask.config.get("distributed.worker.memory.recent-to-old-time") ) - self.MEMORY_REBALANCE_MEASURE = dask.config.get( + self._MEMORY_REBALANCE_MEASURE = dask.config.get( "distributed.worker.memory.rebalance.measure" ) - self.MEMORY_REBALANCE_SENDER_MIN = 
dask.config.get( + self._MEMORY_REBALANCE_SENDER_MIN = dask.config.get( "distributed.worker.memory.rebalance.sender-min" ) - self.MEMORY_REBALANCE_RECIPIENT_MAX = dask.config.get( + self._MEMORY_REBALANCE_RECIPIENT_MAX = dask.config.get( "distributed.worker.memory.rebalance.recipient-max" ) - self.MEMORY_REBALANCE_HALF_GAP = ( + self._MEMORY_REBALANCE_HALF_GAP = ( dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap") / 2.0 ) @@ -2108,6 +2112,10 @@ def aliases(self): def bandwidth(self): return self._bandwidth + @bandwidth.setter + def bandwidth(self, v: double): + self._bandwidth = v + @property def clients(self): return self._clients @@ -2168,6 +2176,10 @@ def replicated_tasks(self): def total_nthreads(self): return self._total_nthreads + @total_nthreads.setter + def total_nthreads(self, v: Py_ssize_t): + self._total_nthreads = v + @property def total_occupancy(self): return self._total_occupancy @@ -2176,6 +2188,10 @@ def total_occupancy(self): def total_occupancy(self, v: double): self._total_occupancy = v + @property + def transition_log(self): + return self._transition_log + @property def transition_counter(self): return self._transition_counter @@ -2204,6 +2220,34 @@ def workers(self): def plugins(self) -> "dict[str, SchedulerPlugin]": return self._plugins + @plugins.setter + def plugins(self, val): + self._plugins = val + + @property + def UNKNOWN_TASK_DURATION(self): + return self._UNKNOWN_TASK_DURATION + + @property + def MEMORY_RECENT_TO_OLD_TIME(self): + return self._MEMORY_RECENT_TO_OLD_TIME + + @property + def MEMORY_REBALANCE_MEASURE(self): + return self._MEMORY_REBALANCE_MEASURE + + @property + def MEMORY_REBALANCE_SENDER_MIN(self): + return self._MEMORY_REBALANCE_SENDER_MIN + + @property + def MEMORY_REBALANCE_RECIPIENT_MAX(self): + return self._MEMORY_REBALANCE_RECIPIENT_MAX + + @property + def MEMORY_REBALANCE_HALF_GAP(self): + return self._MEMORY_REBALANCE_HALF_GAP + @property def memory(self) -> MemoryState: return MemoryState.sum(*(w.memory for w in self.workers.values())) @@ -2264,12 +2308,13 @@ def new_task( # State Transitions # ##################### - def _transition(self, key, finish: str, *args, **kwargs): + @ccall + def transition(self, key, finish: str, args: tuple = None, kwargs: dict = None): """Transition a key from its current state to the finish state - Examples + Examples.stimulus_task_finished -------- - >>> self._transition('x', 'waiting') + >>> self.transition('x', 'waiting') {'x': 'processing'} Returns @@ -2280,7 +2325,6 @@ def _transition(self, key, finish: str, *args, **kwargs): -------- Scheduler.transitions : transitive version of this function """ - parent: SchedulerState = cast(SchedulerState, self) ts: TaskState start: str start_finish: tuple @@ -2292,12 +2336,14 @@ def _transition(self, key, finish: str, *args, **kwargs): new_msgs: list dependents: set dependencies: set + args = args or () + kwargs = kwargs or {} try: recommendations = {} worker_msgs = {} client_msgs = {} - ts = parent._tasks.get(key) # type: ignore + ts = self._tasks.get(key) # type: ignore if ts is None: return recommendations, client_msgs, worker_msgs start = ts._state @@ -2318,7 +2364,7 @@ def _transition(self, key, finish: str, *args, **kwargs): a_recs: dict a_cmsgs: dict a_wmsgs: dict - a: tuple = self._transition(key, "released") + a: tuple = self.transition(key, "released") a_recs, a_cmsgs, a_wmsgs = a v = a_recs.get(key, finish) @@ -2362,12 +2408,8 @@ def _transition(self, key, finish: str, *args, **kwargs): raise RuntimeError("Impossible 
transition from %r to %r" % start_finish) finish2 = ts._state - # FIXME downcast antipattern - scheduler = pep484_cast(Scheduler, self) - scheduler.transition_log.append( - (key, start, finish2, recommendations, time()) - ) - if parent._validate: + self._transition_log.append((key, start, finish2, recommendations, time())) + if self._validate: logger.debug( "Transitioned %r %s->%s (actual: %s). Consequence: %s", key, @@ -2381,17 +2423,17 @@ def _transition(self, key, finish: str, *args, **kwargs): if ts._state == "forgotten": ts._dependents = dependents ts._dependencies = dependencies - parent._tasks[ts._key] = ts + self._tasks[ts._key] = ts for plugin in list(self.plugins.values()): try: plugin.transition(key, start, finish2, *args, **kwargs) except Exception: logger.info("Plugin failed with exception", exc_info=True) if ts._state == "forgotten": - del parent._tasks[ts._key] + del self._tasks[ts._key] tg: TaskGroup = ts._group - if ts._state == "forgotten" and tg._name in parent._task_groups: + if ts._state == "forgotten" and tg._name in self._task_groups: # Remove TaskGroup if all tasks are in the forgotten state all_forgotten: bint = True for s in ALL_TASK_STATES: @@ -2400,7 +2442,7 @@ def _transition(self, key, finish: str, *args, **kwargs): break if all_forgotten: ts._prefix._groups.remove(tg) - del parent._task_groups[tg._name] + del self._task_groups[tg._name] return recommendations, client_msgs, worker_msgs except Exception: @@ -2411,7 +2453,8 @@ def _transition(self, key, finish: str, *args, **kwargs): pdb.set_trace() raise - def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: dict): + @ccall + def transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: dict): """Process transitions until none are left This includes feedback from previous transitions and continues until we @@ -2429,7 +2472,7 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di key, finish = recommendations.popitem() keys.add(key) - new = self._transition(key, finish) + new = self.transition(key, finish) new_recs, new_cmsgs, new_wmsgs = new recommendations.update(new_recs) @@ -2447,10 +2490,8 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di worker_msgs[w] = new_msgs if self._validate: - # FIXME downcast antipattern - scheduler = pep484_cast(Scheduler, self) for key in keys: - scheduler.validate_key(key) + self.validate_key(key) def transition_released_waiting(self, key): try: @@ -3437,7 +3478,7 @@ def get_task_duration(self, ts: TaskState) -> double: if s is None: self._unknown_durations[ts._prefix._name] = s = set() s.add(ts) - return self.UNKNOWN_TASK_DURATION + return self._UNKNOWN_TASK_DURATION @ccall @exceptval(check=False) @@ -3587,7 +3628,7 @@ def remove_all_replicas(self, ts: TaskState): @ccall @exceptval(check=False) - def _reevaluate_occupancy_worker(self, ws: WorkerState): + def reevaluate_occupancy_worker(self, ws: WorkerState): """See reevaluate_occupancy""" ts: TaskState old = ws._occupancy @@ -3602,8 +3643,145 @@ def _reevaluate_occupancy_worker(self, ws: WorkerState): for ts in ws._processing: steal.recalculate_cost(ts) + ################### + # Task Validation # + ################### + # TODO: could all be @ccall, but called rarely + + def validate_released(self, key): + ts: TaskState = self._tasks[key] + dts: TaskState + assert ts._state == "released" + assert not ts._waiters + assert not ts._waiting_on + assert not ts._who_has + assert not ts._processing_on + assert not any([ts in 
dts._waiters for dts in ts._dependencies]) + assert ts not in self._unrunnable -class Scheduler(SchedulerState, ServerNode): + def validate_waiting(self, key): + ts: TaskState = self._tasks[key] + dts: TaskState + assert ts._waiting_on + assert not ts._who_has + assert not ts._processing_on + assert ts not in self._unrunnable + for dts in ts._dependencies: + # We are waiting on a dependency iff it's not stored + assert (not not dts._who_has) != (dts in ts._waiting_on) + assert ts in dts._waiters # XXX even if dts._who_has? + + def validate_processing(self, key): + ts: TaskState = self._tasks[key] + dts: TaskState + assert not ts._waiting_on + ws: WorkerState = ts._processing_on + assert ws + assert ts in ws._processing + assert not ts._who_has + for dts in ts._dependencies: + assert dts._who_has + assert ts in dts._waiters + + def validate_memory(self, key): + ts: TaskState = self._tasks[key] + dts: TaskState + assert ts._who_has + assert bool(ts in self._replicated_tasks) == (len(ts._who_has) > 1) + assert not ts._processing_on + assert not ts._waiting_on + assert ts not in self._unrunnable + for dts in ts._dependents: + assert (dts in ts._waiters) == (dts._state in ("waiting", "processing")) + assert ts not in dts._waiting_on + + def validate_no_worker(self, key): + ts: TaskState = self._tasks[key] + dts: TaskState + assert ts in self._unrunnable + assert not ts._waiting_on + assert not ts._processing_on + assert not ts._who_has + for dts in ts._dependencies: + assert dts._who_has + + def validate_erred(self, key): + ts: TaskState = self._tasks[key] + assert ts._exception_blame + assert not ts._who_has + + def validate_key(self, key, ts: TaskState = None): + try: + if ts is None: + ts = self._tasks.get(key) + if ts is None: + logger.debug("Key lost: %s", key) + else: + ts.validate() + try: + func = getattr(self, "validate_" + ts._state.replace("-", "_")) + except AttributeError: + logger.error( + "self.validate_%s not found", ts._state.replace("-", "_") + ) + else: + func(key) + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb + + pdb.set_trace() + raise + + def validate_state(self, allow_overlap=False): + validate_state(self._tasks, self._workers, self._clients) + + # if not (set(self.state.workers) == set(self.stream_comms)): + # raise ValueError("Workers not the same in all collections") + + ws: WorkerState + for w, ws in self._workers_dv.items(): + assert isinstance(w, str), (type(w), w) + assert isinstance(ws, WorkerState), (type(ws), ws) + assert ws._address == w + if not ws._processing: + assert not ws._occupancy + assert ws._address in self._idle_dv + + ts: TaskState + for k, ts in self._tasks.items(): + assert isinstance(ts, TaskState), (type(ts), ts) + assert ts._key == k + self.validate_key(k, ts) + + c: str + cs: ClientState + for c, cs in self._clients.items(): + # client=None is often used in tests... 
+ assert c is None or type(c) == str, (type(c), c) + assert type(cs) == ClientState, (type(cs), cs) + assert cs._client_key == c + + a = {w: ws._nbytes for w, ws in self._workers_dv.items()} + b = { + w: sum(ts.get_nbytes() for ts in ws._has_what) + for w, ws in self._workers_dv.items() + } + assert a == b, (a, b) + + actual_total_occupancy = 0 + for worker, ws in self._workers_dv.items(): + assert abs(sum(ws._processing.values()) - ws._occupancy) < 1e-8 + actual_total_occupancy += ws._occupancy + + assert abs(actual_total_occupancy - self._total_occupancy) < 1e-8, ( + actual_total_occupancy, + self._total_occupancy, + ) + + +class Scheduler(ServerNode): """Dynamic distributed task scheduler The scheduler tracks the current state of workers, data, and computations. @@ -3865,18 +4043,15 @@ def __init__( resources = {} aliases = {} - self._task_state_collections = [unrunnable] + self._task_state_collections: list = [unrunnable] - self._worker_collections = [ + self._worker_collections: list = [ workers, host_info, resources, aliases, ] - self.transition_log = deque( - maxlen=dask.config.get("distributed.scheduler.transition-log-length") - ) self.log = deque( maxlen=dask.config.get("distributed.scheduler.transition-log-length") ) @@ -3890,7 +4065,7 @@ def __init__( self.worker_plugins = {} self.nanny_plugins = {} - worker_handlers = { + worker_handlers: dict = { "task-finished": self.handle_task_finished, "task-erred": self.handle_task_erred, "release-worker-data": self.release_worker_data, @@ -3903,7 +4078,7 @@ def __init__( "worker-status-change": self.handle_worker_status_change, } - client_handlers = { + client_handlers: dict = { "update-graph": self.update_graph, "update-graph-hlg": self.update_graph_hlg, "client-desires-keys": self.client_desires_keys, @@ -3974,8 +4149,7 @@ def __init__( connection_limit = get_fileno_limit() / 2 - super().__init__( - # Arguments to SchedulerState + self.state: SchedulerState = SchedulerState( aliases=aliases, clients=clients, workers=workers, @@ -3985,7 +4159,9 @@ def __init__( unrunnable=unrunnable, validate=validate, plugins=plugins, - # Arguments to ServerNode + ) + + super().__init__( handlers=self.handlers, stream_handlers=merge(worker_handlers, client_handlers), io_loop=self.loop, @@ -4015,31 +4191,136 @@ def __init__( self.rpc.allow_offload = False self.status = Status.undefined + #################### + # state properties # + #################### + + @property + def aliases(self): + return self.state.aliases + + @property + def bandwidth(self): + return self.state.bandwidth + + @property + def clients(self): + return self.state.clients + + @property + def extensions(self): + return self.state.extensions + + @property + def host_info(self): + return self.state.host_info + + @property + def idle(self): + return self.state.idle + + @property + def n_tasks(self): + return self.state.n_tasks + + @property + def plugins(self): + return self.state.plugins + + @plugins.setter + def plugins(self, val): + self.state.plugins = val + + @property + def resources(self): + return self.state.resources + + @property + def saturated(self): + return self.state.saturated + + @property + def tasks(self): + return self.state.tasks + + @property + def task_groups(self): + return self.state.task_groups + + @property + def task_prefixes(self): + return self.state.task_prefixes + + @property + def task_metadata(self): + return self.state.task_metadata + + @property + def total_nthreads(self): + return self.state.total_nthreads + + @property + def total_occupancy(self): + 
return self.state.total_occupancy + + @total_occupancy.setter + def total_occupancy(self, v: double): + self.state.total_occupancy = v + + @property + def transition_counter(self): + return self.state.transition_counter + + @property + def unknown_durations(self): + return self.state.unknown_durations + + @property + def unrunnable(self): + return self.state.unrunnable + + @property + def validate(self): + return self.state.validate + + @validate.setter + def validate(self, v: bint): + self.state.validate = v + + @property + def workers(self): + return self.state.workers + + @property + def memory(self) -> MemoryState: + return MemoryState.sum(*(w.memory for w in self.state.workers.values())) + + @property + def __pdict__(self): + return self.state.__pdict__ + ################## # Administration # ################## def __repr__(self): - parent: SchedulerState = cast(SchedulerState, self) return ( f"" + f"workers: {len(self.state.workers)}, " + f"cores: {self.state.total_nthreads}, " + f"tasks: {len(self.state.tasks)}>" ) def _repr_html_(self): - parent: SchedulerState = cast(SchedulerState, self) return get_template("scheduler.html.j2").render( address=self.address, - workers=parent._workers_dv, - threads=parent._total_nthreads, - tasks=parent._tasks, + workers=self.state.workers, + threads=self.state.total_nthreads, + tasks=self.state.tasks, ) def identity(self, comm=None): """Basic information about ourselves and our cluster""" - parent: SchedulerState = cast(SchedulerState, self) d = { "type": type(self).__name__, "id": str(self.id), @@ -4048,7 +4329,7 @@ def identity(self, comm=None): "started": self.time_started, "workers": { worker.address: worker.identity() - for worker in parent._workers_dv.values() + for worker in self.state.workers.values() }, } return d @@ -4067,7 +4348,7 @@ def _to_dict( """ info = super()._to_dict(exclude=exclude) extra = { - "transition_log": self.transition_log, + "transition_log": self.state.transition_log, "log": self.log, "tasks": self.tasks, "task_groups": self.task_groups, @@ -4096,8 +4377,7 @@ def get_worker_service_addr(self, worker, service_name, protocol=False): Whether or not to include a full address with protocol (True) or just a (host, port) pair """ - parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv[worker] + ws: WorkerState = self.state.workers[worker] port = ws._services.get(service_name) if port is None: return None @@ -4180,7 +4460,6 @@ async def close(self, comm=None, fast=False, close_workers=False): -------- Scheduler.cleanup """ - parent: SchedulerState = cast(SchedulerState, self) if self.status in (Status.closing, Status.closed): await self.finished() return @@ -4194,13 +4473,13 @@ async def close(self, comm=None, fast=False, close_workers=False): if close_workers: await self.broadcast(msg={"op": "close_gracefully"}, nanny=True) - for worker in parent._workers_dv: + for worker in self.state.workers: # Report would require the worker to unregister with the # currently closing scheduler. 
This is not necessary and might # delay shutdown of the worker unnecessarily self.worker_send(worker, {"op": "close", "report": False}) for i in range(20): # wait a second for send signals to clear - if parent._workers_dv: + if self.state.workers: await asyncio.sleep(0.05) else: break @@ -4215,7 +4494,7 @@ async def close(self, comm=None, fast=False, close_workers=False): self.stop_services() - for ext in parent._extensions.values(): + for ext in self.state.extensions.values(): with suppress(AttributeError): ext.teardown() logger.info("Scheduler closing all comms") @@ -4273,10 +4552,9 @@ def heartbeat_worker( metrics: dict, executing: dict = None, ): - parent: SchedulerState = cast(SchedulerState, self) address = self.coerce_address(address, resolve_address) address = normalize_address(address) - ws: WorkerState = parent._workers_dv.get(address) # type: ignore + ws: WorkerState = self.state.workers.get(address) # type: ignore if ws is None: return {"status": "missing"} @@ -4284,12 +4562,12 @@ def heartbeat_worker( local_now = time() host_info = host_info or {} - dh: dict = parent._host_info.setdefault(host, {}) + dh: dict = self.state.host_info.setdefault(host, {}) dh["last-seen"] = local_now - frac = 1 / len(parent._workers_dv) - parent._bandwidth = ( - parent._bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac + frac = 1 / len(self.state.workers) + self.state.bandwidth = ( + self.state.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac ) for other, (bw, count) in metrics["bandwidth"]["workers"].items(): if (address, other) not in self.bandwidth_workers: @@ -4311,16 +4589,16 @@ def heartbeat_worker( ws._last_seen = local_now if executing is not None: ws._executing = { - parent._tasks[key]: duration - for key, duration in executing.items() - if key in parent._tasks + self.state.tasks[key]: duration for key, duration in executing.items() } ws._metrics = metrics # Calculate RSS - dask keys, separating "old" and "new" usage # See MemoryState for details - max_memory_unmanaged_old_hist_age = local_now - parent.MEMORY_RECENT_TO_OLD_TIME + max_memory_unmanaged_old_hist_age = ( + local_now - self.state.MEMORY_RECENT_TO_OLD_TIME + ) memory_unmanaged_old = ws._memory_unmanaged_old while ws._memory_other_history: timestamp, size = ws._memory_other_history[0] @@ -4350,7 +4628,7 @@ def heartbeat_worker( ws._memory_unmanaged_old = size if host_info: - dh = parent._host_info.setdefault(host, {}) + dh = self.state.host_info.setdefault(host, {}) dh.update(host_info) if now: @@ -4364,7 +4642,7 @@ def heartbeat_worker( return { "status": "OK", "time": local_now, - "heartbeat-interval": heartbeat_interval(len(parent._workers_dv)), + "heartbeat-interval": heartbeat_interval(len(self.state.workers)), } async def add_worker( @@ -4392,16 +4670,15 @@ async def add_worker( extra=None, ): """Add a new worker to the cluster""" - parent: SchedulerState = cast(SchedulerState, self) with log_errors(): address = self.coerce_address(address, resolve_address) address = normalize_address(address) host = get_address_host(address) - if address in parent._workers_dv: + if address in self.state.workers: raise ValueError("Worker already exists %s" % address) - if name in parent._aliases: + if name in self.state.aliases: logger.warning( "Worker tried to connect with a duplicate name: %s", name ) @@ -4418,7 +4695,7 @@ async def add_worker( self.log_event("all", {"action": "add-worker", "worker": address}) ws: WorkerState - parent._workers[address] = ws = WorkerState( + self.state.workers[address] = ws = 
WorkerState( address=address, status=Status.lookup[status], # type: ignore pid=pid, @@ -4432,11 +4709,11 @@ async def add_worker( extra=extra, ) if ws._status == Status.running: - parent._running.add(ws) + self.state.running.add(ws) - dh: dict = parent._host_info.get(host) # type: ignore + dh: dict = self.state.host_info.get(host) # type: ignore if dh is None: - parent._host_info[host] = dh = {} + self.state.host_info[host] = dh = {} dh_addresses: set = dh.get("addresses") # type: ignore if dh_addresses is None: @@ -4446,8 +4723,8 @@ async def add_worker( dh_addresses.add(address) dh["nthreads"] += nthreads - parent._total_nthreads += nthreads - parent._aliases[name] = address + self.state.total_nthreads += nthreads + self.state.aliases[name] = address self.heartbeat_worker( address=address, @@ -4458,9 +4735,9 @@ async def add_worker( metrics=metrics, ) - # Do not need to adjust parent._total_occupancy as self.occupancy[ws] cannot + # Do not need to adjust self.state.total_occupancy as self.occupancy[ws] cannot # exist before this. - self.check_idle_saturated(ws) + self.state.check_idle_saturated(ws) # for key in keys: # TODO # self.mark_key_in_memory(key, [address]) @@ -4468,7 +4745,7 @@ async def add_worker( self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop) if ws._nthreads > len(ws._processing): - parent._idle[ws._address] = ws + self.state.idle[ws._address] = ws for plugin in list(self.plugins.values()): try: @@ -4485,20 +4762,22 @@ async def add_worker( assert isinstance(nbytes, dict) already_released_keys = [] for key in nbytes: - ts: TaskState = parent._tasks.get(key) # type: ignore + ts: TaskState = self.state.tasks.get(key) # type: ignore if ts is not None and ts.state != "released": if ts.state == "memory": self.add_keys(worker=address, keys=[key]) else: - t: tuple = parent._transition( + t: tuple = self.state.transition( key, "memory", - worker=address, - nbytes=nbytes[key], - typename=types[key], + kwargs=dict( + worker=address, + nbytes=nbytes[key], + typename=types[key], + ), ) recommendations, client_msgs, worker_msgs = t - parent._transitions( + self.state.transitions( recommendations, client_msgs, worker_msgs ) recommendations = {} @@ -4516,13 +4795,13 @@ async def add_worker( ) if ws._status == Status.running: - for ts in parent._unrunnable: - valid: set = self.valid_workers(ts) + for ts in self.state.unrunnable: + valid: set = self.state.valid_workers(ts) if valid is None or ws in valid: recommendations[ts._key] = "waiting" if recommendations: - parent._transitions(recommendations, client_msgs, worker_msgs) + self.state.transitions(recommendations, client_msgs, worker_msgs) self.send_all(client_msgs, worker_msgs) @@ -4531,7 +4810,7 @@ async def add_worker( msg = { "status": "OK", "time": time(), - "heartbeat-interval": heartbeat_interval(len(parent._workers_dv)), + "heartbeat-interval": heartbeat_interval(len(self.state.workers)), "worker-plugins": self.worker_plugins, } @@ -4539,10 +4818,10 @@ async def add_worker( version_warning = version_module.error_message( version_module.get_versions(), merge( - {w: ws._versions for w, ws in parent._workers_dv.items()}, + {w: ws._versions for w, ws in self.state.workers.items()}, { c: cs._versions - for c, cs in parent._clients.items() + for c, cs in self.state.clients.items() if cs._versions }, ), @@ -4563,6 +4842,15 @@ async def add_nanny(self, comm): } return msg + def get_task_duration(self, ts: TaskState) -> double: + return self.state.get_task_duration(ts) + + def get_comm_cost(self, *args, **kwargs): + 
return self.state.get_comm_cost(*args, **kwargs) + + def check_idle_saturated(self, *args, **kwargs): + return self.state.check_idle_saturated(*args, **kwargs) + def update_graph_hlg( self, client=None, @@ -4643,7 +4931,6 @@ def update_graph( This happens whenever the Client calls submit, map, get, or compute. """ - parent: SchedulerState = cast(SchedulerState, self) start = time() fifo_timeout = parse_timedelta(fifo_timeout) keys = set(keys) @@ -4659,12 +4946,12 @@ def update_graph( dependencies = dependencies or {} - if parent._total_occupancy > 1e-9 and parent._computations: + if self.state.total_occupancy > 1e-9 and self.state.computations: # Still working on something. Assign new tasks to same computation - computation = cast(Computation, parent._computations[-1]) + computation = cast(Computation, self.state.computations[-1]) else: computation = Computation() - parent._computations.append(computation) + self.state.computations.append(computation) if code and code not in computation._code: # add new code blocks computation._code.add(code) @@ -4674,7 +4961,7 @@ def update_graph( n = len(tasks) for k, deps in list(dependencies.items()): if any( - dep not in parent._tasks and dep not in tasks for dep in deps + dep not in self.state.tasks and dep not in tasks for dep in deps ): # bad key logger.info("User asked for computation on lost data, %s", k) del tasks[k] @@ -4688,8 +4975,8 @@ def update_graph( ts: TaskState already_in_memory = set() # tasks that are already done for k, v in dependencies.items(): - if v and k in parent._tasks: - ts = parent._tasks[k] + if v and k in self.state.tasks: + ts = self.state.tasks[k] if ts._state in ("memory", "erred"): already_in_memory.add(k) @@ -4700,7 +4987,7 @@ def update_graph( done = set(already_in_memory) while stack: # remove unnecessary dependencies key = stack.pop() - ts = parent._tasks[key] + ts = self.state.tasks[key] try: deps = dependencies[key] except KeyError: @@ -4711,7 +4998,7 @@ def update_graph( else: child_deps = self.dependencies[dep] if all(d in done for d in child_deps): - if dep in parent._tasks and dep not in done: + if dep in self.state.tasks and dep not in done: done.add(dep) stack.append(dep) @@ -4728,9 +5015,9 @@ def update_graph( if k in touched_keys: continue # XXX Have a method get_task_state(self, k) ? - ts = parent._tasks.get(k) + ts = self.state.tasks.get(k) if ts is None: - ts = parent.new_task( + ts = self.state.new_task( k, tasks.get(k), "released", computation=computation ) elif not ts._run_spec: @@ -4744,11 +5031,11 @@ def update_graph( # Add dependencies for key, deps in dependencies.items(): - ts = parent._tasks.get(key) + ts = self.state.tasks.get(key) if ts is None or ts._dependencies: continue for dep in deps: - dts = parent._tasks[dep] + dts = self.state.tasks[dep] ts.add_dependency(dts) # Compute priorities @@ -4784,7 +5071,7 @@ def update_graph( for k, v in kv.items(): # Tasks might have been culled, in which case # we have nothing to annotate. 
- ts = parent._tasks.get(k) + ts = self.state.tasks.get(k) if ts is not None: ts._annotations[a] = v @@ -4792,7 +5079,7 @@ def update_graph( if actors is True: actors = list(keys) for actor in actors or []: - ts = parent._tasks[actor] + ts = self.state.tasks[actor] ts._actor = True priority = priority or dask.order.order( @@ -4800,7 +5087,7 @@ def update_graph( ) # TODO: define order wrt old graph if submitting_task: # sub-tasks get better priority than parent tasks - ts = parent._tasks.get(submitting_task) + ts = self.state.tasks.get(submitting_task) if ts is not None: generation = ts._priority[0] - 0.01 else: # super-task already cleaned up @@ -4813,7 +5100,7 @@ def update_graph( generation = self.generation for key in set(priority) & touched_keys: - ts = parent._tasks[key] + ts = self.state.tasks[key] if ts._priority is None: ts._priority = (-(user_priority.get(key, 0)), generation, priority[key]) @@ -4829,7 +5116,7 @@ def update_graph( for k, v in restrictions.items(): if v is None: continue - ts = parent._tasks.get(k) + ts = self.state.tasks.get(k) if ts is None: continue ts._host_restrictions = set() @@ -4848,7 +5135,7 @@ def update_graph( if loose_restrictions: for k in loose_restrictions: - ts = parent._tasks[k] + ts = self.state.tasks[k] ts._loose_restrictions = True if resources: @@ -4856,7 +5143,7 @@ def update_graph( if v is None: continue assert isinstance(v, dict) - ts = parent._tasks.get(k) + ts = self.state.tasks.get(k) if ts is None: continue ts._resource_restrictions = v @@ -4864,7 +5151,7 @@ def update_graph( if retries: for k, v in retries.items(): assert isinstance(v, int) - ts = parent._tasks.get(k) + ts = self.state.tasks.get(k) if ts is None: continue ts._retries = v @@ -4914,15 +5201,14 @@ def update_graph( def stimulus_task_finished(self, key=None, worker=None, **kwargs): """Mark that a task has finished execution on a particular worker""" - parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task finished %s, %s", key, worker) recommendations: dict = {} client_msgs: dict = {} worker_msgs: dict = {} - ws: WorkerState = parent._workers_dv[worker] - ts: TaskState = parent._tasks.get(key) + ws: WorkerState = self.state.workers[worker] + ts: TaskState = self.state.tasks.get(key) if ts is None or ts._state == "released": logger.debug( "Received already computed task, worker: %s, state: %s" @@ -4943,7 +5229,11 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): self.add_keys(worker=worker, keys=[key]) else: ts._metadata.update(kwargs["metadata"]) - r: tuple = parent._transition(key, "memory", worker=worker, **kwargs) + r: tuple = self.state.transition( + key, + "memory", + kwargs=dict(worker=worker, **kwargs), + ) recommendations, client_msgs, worker_msgs = r if ts._state == "memory": @@ -4954,29 +5244,31 @@ def stimulus_task_erred( self, key=None, worker=None, exception=None, traceback=None, **kwargs ): """Mark that a task has erred on a particular worker""" - parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task erred %s, %s", key, worker) - ts: TaskState = parent._tasks.get(key) + ts: TaskState = self.state.tasks.get(key) if ts is None or ts._state != "processing": return {}, {}, {} - if ts._retries > 0: - ts._retries -= 1 - return parent._transition(key, "waiting") + retries: Py_ssize_t = ts._retries + if retries > 0: + retries -= 1 + ts._retries = retries + return self.state.transition(key, "waiting") else: - return parent._transition( + return self.state.transition( key, "erred", - cause=key, - 
exception=exception, - traceback=traceback, - worker=worker, - **kwargs, + kwargs=dict( + cause=key, + exception=exception, + traceback=traceback, + worker=worker, + **kwargs, + ), ) def stimulus_retry(self, comm=None, keys=None, client=None): - parent: SchedulerState = cast(SchedulerState, self) logger.info("Client %s requests to retry %d keys", client, len(keys)) if client: self.log_event(client, {"action": "retry", "count": len(keys)}) @@ -4989,7 +5281,7 @@ def stimulus_retry(self, comm=None, keys=None, client=None): while stack: key = stack.pop() seen.add(key) - ts = parent._tasks[key] + ts = self.state.tasks[key] erred_deps = [dts._key for dts in ts._dependencies if dts._state == "erred"] if erred_deps: stack.extend(erred_deps) @@ -4999,9 +5291,9 @@ def stimulus_retry(self, comm=None, keys=None, client=None): recommendations: dict = {key: "waiting" for key in roots} self.transitions(recommendations) - if parent._validate: + if self.state.validate: for key in seen: - assert not parent._tasks[key].exception_blame + assert not self.state.tasks[key].exception_blame return tuple(seen) @@ -5013,19 +5305,18 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): appears to be unresponsive. This may send its tasks back to a released state. """ - parent: SchedulerState = cast(SchedulerState, self) with log_errors(): if self.status == Status.closed: return address = self.coerce_address(address) - if address not in parent._workers_dv: + if address not in self.state.workers: return "already-removed" host = get_address_host(address) - ws: WorkerState = parent._workers_dv[address] + ws: WorkerState = self.state.workers[address] event_msg = { "action": "remove-worker", @@ -5042,23 +5333,23 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): self.remove_resources(address) - dh: dict = parent._host_info[host] + dh: dict = self.state.host_info[host] dh_addresses: set = dh["addresses"] dh_addresses.remove(address) dh["nthreads"] -= ws._nthreads - parent._total_nthreads -= ws._nthreads + self.state.total_nthreads -= ws._nthreads if not dh_addresses: - del parent._host_info[host] + del self.state.host_info[host] self.rpc.remove(address) del self.stream_comms[address] - del parent._aliases[ws._name] - parent._idle.pop(ws._address, None) - parent._saturated.discard(ws) - del parent._workers[address] + del self.state.aliases[ws._name] + self.state.idle.pop(ws._address, None) + self.state.saturated.discard(ws) + del self.state.workers[address] ws.status = Status.closed - parent._running.discard(ws) - parent._total_occupancy -= ws._occupancy + self.state.running.discard(ws) + self.state.total_occupancy -= ws._occupancy recommendations: dict = {} @@ -5084,7 +5375,7 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): ) for ts in list(ws._has_what): - parent.remove_replica(ts, ws) + self.state.remove_replica(ts, ws) if not ts._who_has: if ts._run_spec: recommendations[ts._key] = "released" @@ -5101,16 +5392,16 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): except Exception as e: logger.exception(e) - if not parent._workers_dv: + if not self.state.workers: logger.info("Lost all workers") - for w in parent._workers_dv: + for w in self.state.workers: self.bandwidth_workers.pop((address, w), None) self.bandwidth_workers.pop((w, address), None) def remove_worker_from_events(): # If the worker isn't registered anymore after the delay, remove from events - if address not in parent._workers_dv and 
address in self.events: + if address not in self.state.workers and address in self.events: del self.events[address] cleanup_delay = parse_timedelta( @@ -5134,11 +5425,10 @@ def stimulus_cancel(self, comm, keys=None, client=None, force=False): def cancel_key(self, key, client, retries=5, force=False): """Cancel a particular key and all dependents""" # TODO: this should be converted to use the transition mechanism - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks.get(key) + ts: TaskState = self.state.tasks.get(key) dts: TaskState try: - cs: ClientState = parent._clients[client] + cs: ClientState = self.state.clients[client] except KeyError: return if ts is None or not ts._who_wants: # no key yet, lets try again in a moment @@ -5157,17 +5447,16 @@ def cancel_key(self, key, client, retries=5, force=False): self.client_releases_keys(keys=[key], client=cs._client_key) def client_desires_keys(self, keys=None, client=None): - parent: SchedulerState = cast(SchedulerState, self) - cs: ClientState = parent._clients.get(client) + cs: ClientState = self.state.clients.get(client) if cs is None: # For publish, queues etc. - parent._clients[client] = cs = ClientState(client) + self.state.clients[client] = cs = ClientState(client) ts: TaskState for k in keys: - ts = parent._tasks.get(k) + ts = self.state.tasks.get(k) if ts is None: # For publish, queues etc. - ts = parent.new_task(k, None, "released") + ts = self.state.new_task(k, None, "released") ts._who_wants.add(cs) cs._wants_what.add(ts) @@ -5177,175 +5466,23 @@ def client_desires_keys(self, keys=None, client=None): def client_releases_keys(self, keys=None, client=None): """Remove keys from client desired list""" - parent: SchedulerState = cast(SchedulerState, self) if not isinstance(keys, list): keys = list(keys) - cs: ClientState = parent._clients[client] + cs: ClientState = self.state.clients[client] recommendations: dict = {} - _client_releases_keys(parent, keys=keys, cs=cs, recommendations=recommendations) + _client_releases_keys( + self.state, keys=keys, cs=cs, recommendations=recommendations + ) self.transitions(recommendations) def client_heartbeat(self, client=None): """Handle heartbeats from Client""" - parent: SchedulerState = cast(SchedulerState, self) - cs: ClientState = parent._clients[client] + cs: ClientState = self.state.clients[client] cs._last_seen = time() - ################### - # Task Validation # - ################### - - def validate_released(self, key): - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks[key] - dts: TaskState - assert ts._state == "released" - assert not ts._waiters - assert not ts._waiting_on - assert not ts._who_has - assert not ts._processing_on - assert not any([ts in dts._waiters for dts in ts._dependencies]) - assert ts not in parent._unrunnable - - def validate_waiting(self, key): - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks[key] - dts: TaskState - assert ts._waiting_on - assert not ts._who_has - assert not ts._processing_on - assert ts not in parent._unrunnable - for dts in ts._dependencies: - # We are waiting on a dependency iff it's not stored - assert bool(dts._who_has) != (dts in ts._waiting_on) - assert ts in dts._waiters # XXX even if dts._who_has? 
- - def validate_processing(self, key): - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks[key] - dts: TaskState - assert not ts._waiting_on - ws: WorkerState = ts._processing_on - assert ws - assert ts in ws._processing - assert not ts._who_has - for dts in ts._dependencies: - assert dts._who_has - assert ts in dts._waiters - - def validate_memory(self, key): - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks[key] - dts: TaskState - assert ts._who_has - assert bool(ts in parent._replicated_tasks) == (len(ts._who_has) > 1) - assert not ts._processing_on - assert not ts._waiting_on - assert ts not in parent._unrunnable - for dts in ts._dependents: - assert (dts in ts._waiters) == (dts._state in ("waiting", "processing")) - assert ts not in dts._waiting_on - - def validate_no_worker(self, key): - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks[key] - dts: TaskState - assert ts in parent._unrunnable - assert not ts._waiting_on - assert ts in parent._unrunnable - assert not ts._processing_on - assert not ts._who_has - for dts in ts._dependencies: - assert dts._who_has - - def validate_erred(self, key): - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks[key] - assert ts._exception_blame - assert not ts._who_has - - def validate_key(self, key, ts: TaskState = None): - parent: SchedulerState = cast(SchedulerState, self) - try: - if ts is None: - ts = parent._tasks.get(key) - if ts is None: - logger.debug("Key lost: %s", key) - else: - ts.validate() - try: - func = getattr(self, "validate_" + ts._state.replace("-", "_")) - except AttributeError: - logger.error( - "self.validate_%s not found", ts._state.replace("-", "_") - ) - else: - func(key) - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb - - pdb.set_trace() - raise - - def validate_state(self, allow_overlap=False): - parent: SchedulerState = cast(SchedulerState, self) - validate_state(parent._tasks, parent._workers, parent._clients) - - if not (set(parent._workers_dv) == set(self.stream_comms)): - raise ValueError("Workers not the same in all collections") - - ws: WorkerState - for w, ws in parent._workers_dv.items(): - assert isinstance(w, str), (type(w), w) - assert isinstance(ws, WorkerState), (type(ws), ws) - assert ws._address == w - if not ws._processing: - assert not ws._occupancy - assert ws._address in parent._idle_dv - assert (ws._status == Status.running) == (ws in parent._running) - - for ws in parent._running: - assert ws._status == Status.running - assert ws._address in parent._workers_dv - - ts: TaskState - for k, ts in parent._tasks.items(): - assert isinstance(ts, TaskState), (type(ts), ts) - assert ts._key == k - assert bool(ts in parent._replicated_tasks) == (len(ts._who_has) > 1) - self.validate_key(k, ts) - - for ts in parent._replicated_tasks: - assert ts._state == "memory" - assert ts._key in parent._tasks - - c: str - cs: ClientState - for c, cs in parent._clients.items(): - # client=None is often used in tests... 
- assert c is None or type(c) == str, (type(c), c) - assert type(cs) == ClientState, (type(cs), cs) - assert cs._client_key == c - - a = {w: ws._nbytes for w, ws in parent._workers_dv.items()} - b = { - w: sum(ts.get_nbytes() for ts in ws._has_what) - for w, ws in parent._workers_dv.items() - } - assert a == b, (a, b) - - actual_total_occupancy = 0 - for worker, ws in parent._workers_dv.items(): - assert abs(sum(ws._processing.values()) - ws._occupancy) < 1e-8 - actual_total_occupancy += ws._occupancy - - assert abs(actual_total_occupancy - parent._total_occupancy) < 1e-8, ( - actual_total_occupancy, - parent._total_occupancy, - ) + def validate_state(self): + self.state.validate_state() ################### # Manage Messages # @@ -5358,11 +5495,10 @@ def report(self, msg: dict, ts: TaskState = None, client: str = None): If the message contains a key then we only send the message to those comms that care about the key. """ - parent: SchedulerState = cast(SchedulerState, self) if ts is None: msg_key = msg.get("key") if msg_key is not None: - tasks: dict = parent._tasks + tasks: dict = self.state.tasks ts = tasks.get(msg_key) cs: ClientState @@ -5400,12 +5536,11 @@ async def add_client(self, comm, client=None, versions=None): We listen to all future messages from this Comm. """ - parent: SchedulerState = cast(SchedulerState, self) assert client is not None comm.name = "Scheduler->Client" logger.info("Receive client connection: %s", client) self.log_event(["all", client], {"action": "add-client", "client": client}) - parent._clients[client] = ClientState(client, versions=versions) + self.state.clients[client] = ClientState(client, versions=versions) for plugin in list(self.plugins.values()): try: @@ -5421,7 +5556,7 @@ async def add_client(self, comm, client=None, versions=None): ws: WorkerState version_warning = version_module.error_message( version_module.get_versions(), - {w: ws._versions for w, ws in parent._workers_dv.items()}, + {w: ws._versions for w, ws in self.state.workers.items()}, versions, ) msg.update(version_warning) @@ -5446,12 +5581,11 @@ async def add_client(self, comm, client=None, versions=None): def remove_client(self, client=None): """Remove client from network""" - parent: SchedulerState = cast(SchedulerState, self) if self.status == Status.running: logger.info("Remove client %s", client) self.log_event(["all", client], {"action": "remove-client", "client": client}) try: - cs: ClientState = parent._clients[client] + cs: ClientState = self.state.clients[client] except KeyError: # XXX is this a legitimate condition? 
pass @@ -5460,7 +5594,7 @@ def remove_client(self, client=None): self.client_releases_keys( keys=[ts._key for ts in cs._wants_what], client=cs._client_key ) - del parent._clients[client] + del self.state.clients[client] for plugin in list(self.plugins.values()): try: @@ -5470,7 +5604,7 @@ def remove_client(self, client=None): def remove_client_from_events(): # If the client isn't registered anymore after the delay, remove from events - if client not in parent._clients and client in self.events: + if client not in self.state.clients and client in self.events: del self.events[client] cleanup_delay = parse_timedelta( @@ -5480,9 +5614,8 @@ def remove_client_from_events(): def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1): """Send a single computational task to a worker""" - parent: SchedulerState = cast(SchedulerState, self) try: - msg: dict = _task_to_msg(parent, ts, duration) + msg: dict = _task_to_msg(self.state, ts, duration) self.worker_send(worker, msg) except Exception as e: logger.exception(e) @@ -5496,8 +5629,7 @@ def handle_uncaught_error(self, **msg): logger.exception(clean_exception(**msg)[1]) def handle_task_finished(self, key=None, worker=None, **msg): - parent: SchedulerState = cast(SchedulerState, self) - if worker not in parent._workers_dv: + if worker not in self.state.workers: return validate_key(key) @@ -5507,18 +5639,17 @@ def handle_task_finished(self, key=None, worker=None, **msg): r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg) recommendations, client_msgs, worker_msgs = r - parent._transitions(recommendations, client_msgs, worker_msgs) + self.state.transitions(recommendations, client_msgs, worker_msgs) self.send_all(client_msgs, worker_msgs) def handle_task_erred(self, key=None, **msg): - parent: SchedulerState = cast(SchedulerState, self) recommendations: dict client_msgs: dict worker_msgs: dict r: tuple = self.stimulus_task_erred(key=key, **msg) recommendations, client_msgs, worker_msgs = r - parent._transitions(recommendations, client_msgs, worker_msgs) + self.state.transitions(recommendations, client_msgs, worker_msgs) self.send_all(client_msgs, worker_msgs) @@ -5539,16 +5670,15 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): errant_worker : str, optional Address of the worker supposed to hold a replica, by default None """ - parent: SchedulerState = cast(SchedulerState, self) logger.debug("handle missing data key=%s worker=%s", key, errant_worker) self.log_event(errant_worker, {"action": "missing-data", "key": key}) - ts: TaskState = parent._tasks.get(key) + ts: TaskState = self.state.tasks.get(key) if ts is None: return - ws: WorkerState = parent._workers_dv.get(errant_worker) + ws: WorkerState = self.state.workers.get(errant_worker) if ws is not None and ws in ts._who_has: - parent.remove_replica(ts, ws) + self.state.remove_replica(ts, ws) if ts.state == "memory" and not ts._who_has: if ts._run_spec: self.transitions({key: "released"}) @@ -5556,14 +5686,13 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): self.transitions({key: "forgotten"}) def release_worker_data(self, comm=None, key=None, worker=None): - parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv.get(worker) - ts: TaskState = parent._tasks.get(key) + ws: WorkerState = self.state.workers.get(worker) + ts: TaskState = self.state.tasks.get(key) if not ws or not ts: return recommendations: dict = {} if ws in ts._who_has: - parent.remove_replica(ts, ws) + 
self.state.remove_replica(ts, ws) if not ts._who_has: recommendations[ts._key] = "released" if recommendations: @@ -5575,12 +5704,11 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): We stop the task from being stolen in the future, and change task duration accounting as if the task has stopped. """ - parent: SchedulerState = cast(SchedulerState, self) - if key not in parent._tasks: + if key not in self.state.tasks: logger.debug("Skipping long_running since key %s was already released", key) return - ts: TaskState = parent._tasks[key] - steal = parent._extensions.get("stealing") + ts: TaskState = self.state.tasks[key] + steal = self.state.extensions.get("stealing") if steal is not None: steal.remove_key_from_stealable(ts) @@ -5602,18 +5730,17 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): occ: double = ws._processing[ts] ws._occupancy -= occ - parent._total_occupancy -= occ + self.state.total_occupancy -= occ # Cannot remove from processing since we're using this for things like # idleness detection. Idle workers are typically targeted for # downscaling but we should not downscale workers with long running # tasks ws._processing[ts] = 0 ws._long_running.add(ts) - self.check_idle_saturated(ws) + self.state.check_idle_saturated(ws) def handle_worker_status_change(self, status: str, worker: str) -> None: - parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv.get(worker) # type: ignore + ws: WorkerState = self.state.workers.get(worker) # type: ignore if not ws: return prev_status = ws._status @@ -5631,22 +5758,22 @@ def handle_worker_status_change(self, status: str, worker: str) -> None: ) if ws._status == Status.running: - parent._running.add(ws) + self.state.running.add(ws) recs = {} ts: TaskState - for ts in parent._unrunnable: - valid: set = self.valid_workers(ts) + for ts in self.state.unrunnable: + valid: set = self.state.valid_workers(ts) if valid is None or ws in valid: recs[ts._key] = "waiting" if recs: client_msgs: dict = {} worker_msgs: dict = {} - parent._transitions(recs, client_msgs, worker_msgs) + self.state.transitions(recs, client_msgs, worker_msgs) self.send_all(client_msgs, worker_msgs) else: - parent._running.discard(ws) + self.state.running.discard(ws) async def handle_worker(self, comm=None, worker=None): """ @@ -5873,16 +6000,15 @@ async def scatter( -------- Scheduler.broadcast: """ - parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState start = time() while True: if workers is None: - wss = parent._running + wss = self.state.running else: workers = [self.coerce_address(w) for w in workers] - wss = {parent._workers_dv[w] for w in workers} + wss = {self.state.workers[w] for w in workers} wss = {ws for ws in wss if ws._status == Status.running} if wss: @@ -5911,13 +6037,12 @@ async def scatter( return keys async def gather(self, comm=None, keys=None, serializers=None): - """Collect data from workers to the scheduler""" - parent: SchedulerState = cast(SchedulerState, self) + """Collect data in from workers""" ws: WorkerState keys = list(keys) who_has = {} for key in keys: - ts: TaskState = parent._tasks.get(key) + ts: TaskState = self.state.tasks.get(key) if ts is not None: who_has[key] = [ws._address for ws in ts._who_has] else: @@ -5930,7 +6055,7 @@ async def gather(self, comm=None, keys=None, serializers=None): result = {"status": "OK", "data": data} else: missing_states = [ - (parent._tasks[key].state if key in parent._tasks else None) + 
(self.state.tasks[key].state if key in self.state.tasks else None) for key in missing_keys ] logger.exception( @@ -5955,7 +6080,7 @@ async def gather(self, comm=None, keys=None, serializers=None): for key, workers in missing_keys.items(): # Task may already be gone if it was held by a # `missing_worker` - ts: TaskState = parent._tasks.get(key) + ts: TaskState = self.state.tasks.get(key) logger.exception( "Workers don't have promised key: %s, %s", str(workers), @@ -5965,10 +6090,10 @@ async def gather(self, comm=None, keys=None, serializers=None): continue recommendations: dict = {key: "released"} for worker in workers: - ws = parent._workers_dv.get(worker) + ws = self.state.workers.get(worker) if ws is not None and ws in ts._who_has: - parent.remove_replica(ts, ws) - parent._transitions( + self.state.remove_replica(ts, ws) + self.state.transitions( recommendations, client_msgs, worker_msgs ) self.send_all(client_msgs, worker_msgs) @@ -5985,23 +6110,22 @@ def clear_task_state(self): async def restart(self, client=None, timeout=30): """Restart all workers. Reset local state.""" - parent: SchedulerState = cast(SchedulerState, self) with log_errors(): - n_workers = len(parent._workers_dv) + n_workers = len(self.state.workers) logger.info("Send lost future signal to clients") cs: ClientState ts: TaskState - for cs in parent._clients.values(): + for cs in self.state.clients.values(): self.client_releases_keys( keys=[ts._key for ts in cs._wants_what], client=cs._client_key ) ws: WorkerState - nannies = {addr: ws._nanny for addr, ws in parent._workers_dv.items()} + nannies = {addr: ws._nanny for addr, ws in self.state.workers.items()} - for addr in list(parent._workers_dv): + for addr in list(self.state.workers): try: # Ask the worker to close if it doesn't have a nanny, # otherwise the nanny will kill it anyway @@ -6058,7 +6182,7 @@ async def restart(self, client=None, timeout=30): self.log_event([client, "all"], {"action": "restart", "client": client}) start = time() - while time() < start + 10 and len(parent._workers_dv) < n_workers: + while time() < start + 10 and len(self.state.workers) < n_workers: await asyncio.sleep(0.01) self.report({"op": "restart"}) @@ -6075,7 +6199,6 @@ async def broadcast( on_error: "Literal['raise', 'return', 'return_pickle', 'ignore']" = "raise", ) -> dict: # dict[str, Any] """Broadcast message to workers, return all results""" - parent: SchedulerState = cast(SchedulerState, self) if workers is True: warnings.warn( "workers=True is deprecated; pass workers=None or omit instead", @@ -6084,18 +6207,18 @@ async def broadcast( workers = None if workers is None: if hosts is None: - workers = list(parent._workers_dv) + workers = list(self.state.workers) else: workers = [] if hosts is not None: for host in hosts: - dh: dict = parent._host_info.get(host) # type: ignore + dh: dict = self.state.host_info.get(host) # type: ignore if dh is not None: workers.extend(dh["addresses"]) # TODO replace with worker_list if nanny: - addresses = [parent._workers_dv[w].nanny for w in workers] + addresses = [self.state.workers[w].nanny for w in workers] else: addresses = workers @@ -6171,8 +6294,7 @@ async def gather_on_worker( ) return set(who_has) - parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv.get(worker_address) # type: ignore + ws: WorkerState = self.state.workers.get(worker_address) # type: ignore if ws is None: logger.warning(f"Worker {worker_address} lost during replication") @@ -6190,12 +6312,12 @@ async def gather_on_worker( raise 
ValueError(f"Unexpected message from {worker_address}: {result}") for key in keys_ok: - ts: TaskState = parent._tasks.get(key) # type: ignore + ts: TaskState = self.state.tasks.get(key) # type: ignore if ts is None or ts._state != "memory": logger.warning(f"Key lost during replication: {key}") continue if ws not in ts._who_has: - parent.add_replica(ts, ws) + self.state.add_replica(ts, ws) return keys_failed @@ -6211,8 +6333,6 @@ async def delete_worker_data( keys: list[str] List of keys to delete on the specified worker """ - parent: SchedulerState = cast(SchedulerState, self) - try: await retry_operation( self.rpc(addr=worker_address).free_keys, @@ -6228,15 +6348,15 @@ async def delete_worker_data( ) return - ws: WorkerState = parent._workers_dv.get(worker_address) # type: ignore + ws: WorkerState = self.state.workers.get(worker_address) # type: ignore if ws is None: return for key in keys: - ts: TaskState = parent._tasks.get(key) # type: ignore + ts: TaskState = self.state.tasks.get(key) # type: ignore if ts is not None and ws in ts._who_has: assert ts._state == "memory" - parent.remove_replica(ts, ws) + self.state.remove_replica(ts, ws) if not ts._who_has: # Last copy deleted self.transitions({key: "released"}) @@ -6314,14 +6434,12 @@ async def rebalance( All other workers will be ignored. The mean cluster occupancy will be calculated only using the allowed workers. """ - parent: SchedulerState = cast(SchedulerState, self) - with log_errors(): wss: "Collection[WorkerState]" if workers is not None: - wss = [parent._workers_dv[w] for w in workers] + wss = [self.state.workers[w] for w in workers] else: - wss = parent._workers_dv.values() + wss = self.state.workers.values() if not wss: return {"status": "OK"} @@ -6333,7 +6451,7 @@ async def rebalance( missing_data = [ k for k in keys - if k not in parent._tasks or not parent._tasks[k].who_has + if k not in self.state.tasks or not self.state.tasks[k].who_has ] if missing_data: return {"status": "partial-fail", "keys": missing_data} @@ -6380,7 +6498,6 @@ def _rebalance_find_msgs( - recipient worker - task to be transferred """ - parent: SchedulerState = cast(SchedulerState, self) ts: TaskState ws: WorkerState @@ -6413,15 +6530,18 @@ def _rebalance_find_msgs( # (distributed.worker.memory.recent-to-old-time). # This lets us ignore temporary spikes caused by task heap usage. 
memory_by_worker = [ - (ws, getattr(ws.memory, parent.MEMORY_REBALANCE_MEASURE)) for ws in workers + (ws, getattr(ws.memory, self.state.MEMORY_REBALANCE_MEASURE)) + for ws in workers ] mean_memory = sum(m for _, m in memory_by_worker) // len(memory_by_worker) for ws, ws_memory in memory_by_worker: if ws.memory_limit: - half_gap = int(parent.MEMORY_REBALANCE_HALF_GAP * ws.memory_limit) - sender_min = parent.MEMORY_REBALANCE_SENDER_MIN * ws.memory_limit - recipient_max = parent.MEMORY_REBALANCE_RECIPIENT_MAX * ws.memory_limit + half_gap = int(self.state.MEMORY_REBALANCE_HALF_GAP * ws.memory_limit) + sender_min = self.state.MEMORY_REBALANCE_SENDER_MIN * ws.memory_limit + recipient_max = ( + self.state.MEMORY_REBALANCE_RECIPIENT_MAX * ws.memory_limit + ) else: half_gap = 0 sender_min = 0.0 @@ -6639,7 +6759,6 @@ async def replicate( -------- Scheduler.rebalance """ - parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState wws: WorkerState ts: TaskState @@ -6647,10 +6766,10 @@ async def replicate( assert branching_factor > 0 async with self._lock if lock else empty_context: if workers is not None: - workers = {parent._workers_dv[w] for w in self.workers_list(workers)} + workers = {self.state.workers[w] for w in self.workers_list(workers)} workers = {ws for ws in workers if ws._status == Status.running} else: - workers = parent._running + workers = self.state.running if n is None: n = len(workers) @@ -6659,7 +6778,7 @@ async def replicate( if n == 0: raise ValueError("Can not use replicate to delete data") - tasks = {parent._tasks[k] for k in keys} + tasks = {self.state.tasks[k] for k in keys} missing_data = [ts._key for ts in tasks if not ts._who_has] if missing_data: return {"status": "partial-fail", "keys": missing_data} @@ -6792,20 +6911,19 @@ def workers_to_close( -------- Scheduler.retire_workers """ - parent: SchedulerState = cast(SchedulerState, self) if target is not None and n is None: - n = len(parent._workers_dv) - target + n = len(self.state.workers) - target if n is not None: if n < 0: n = 0 - target = len(parent._workers_dv) - n + target = len(self.state.workers) - n if n is None and memory_ratio is None: memory_ratio = 2 ws: WorkerState with log_errors(): - if not n and all([ws._processing for ws in parent._workers_dv.values()]): + if not n and all([ws._processing for ws in self.state.workers.values()]): return [] if key is None: @@ -6815,7 +6933,7 @@ def workers_to_close( ): key = pickle.loads(key) - groups = groupby(key, parent._workers.values()) + groups = groupby(key, self.state.workers.values()) limit_bytes = { k: sum([ws._memory_limit for ws in v]) for k, v in groups.items() @@ -6834,7 +6952,7 @@ def _key(group): idle = sorted(groups, key=_key) to_close = [] - n_remain = len(parent._workers_dv) + n_remain = len(self.state.workers) while idle: group = idle.pop() @@ -6902,7 +7020,6 @@ async def retire_workers( -------- Scheduler.workers_to_close """ - parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState ts: TaskState with log_errors(): @@ -6922,18 +7039,18 @@ async def retire_workers( names_set = {str(name) for name in names} wss = { ws - for ws in parent._workers_dv.values() + for ws in self.state.workers.values() if str(ws._name) in names_set } elif workers is not None: wss = { - parent._workers_dv[address] + self.state.workers[address] for address in workers - if address in parent._workers_dv + if address in self.state.workers } else: wss = { - parent._workers_dv[address] + self.state.workers[address] for address in self.workers_to_close(**kwargs) 
} if not wss: @@ -6959,7 +7076,7 @@ async def retire_workers( # the same on the scheduler to prevent race conditions. prev_status = ws.status ws.status = Status.closing_gracefully - self.running.discard(ws) + self.state.running.discard(ws) self.stream_comms[ws.address].send( {"op": "worker-status-change", "status": ws.status.name} ) @@ -6998,8 +7115,6 @@ async def _track_retire_worker( close_workers: bool, remove: bool, ) -> tuple: # tuple[str | None, dict] - parent: SchedulerState = cast(SchedulerState, self) - while not policy.done(): if policy.no_recipients: # Abort retirement. This time we don't need to worry about race @@ -7019,7 +7134,7 @@ async def _track_retire_worker( "All unique keys on worker %s have been replicated elsewhere", ws._address ) - if close_workers and ws._address in parent._workers_dv: + if close_workers and ws._address in self.state.workers: await self.close_worker(worker=ws._address, safe=True) if remove: await self.remove_worker(address=ws._address, safe=True) @@ -7034,16 +7149,15 @@ def add_keys(self, comm=None, worker=None, keys=(), stimulus_id=None): This should not be used in practice and is mostly here for legacy reasons. However, it is sent by workers from time to time. """ - parent: SchedulerState = cast(SchedulerState, self) - if worker not in parent._workers_dv: + if worker not in self.state.workers: return "not found" - ws: WorkerState = parent._workers_dv[worker] + ws: WorkerState = self.state.workers[worker] redundant_replicas = [] for key in keys: - ts: TaskState = parent._tasks.get(key) + ts: TaskState = self.state.tasks.get(key) if ts is not None and ts._state == "memory": if ws not in ts._who_has: - parent.add_replica(ts, ws) + self.state.add_replica(ts, ws) else: redundant_replicas.append(key) @@ -7077,7 +7191,6 @@ def update_data( -------- Scheduler.mark_key_in_memory """ - parent: SchedulerState = cast(SchedulerState, self) with log_errors(): who_has = { k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items() @@ -7085,18 +7198,18 @@ def update_data( logger.debug("Update data %s", who_has) for key, workers in who_has.items(): - ts: TaskState = parent._tasks.get(key) # type: ignore + ts: TaskState = self.state.tasks.get(key) # type: ignore if ts is None: - ts = parent.new_task(key, None, "memory") + ts = self.state.new_task(key, None, "memory") ts.state = "memory" ts_nbytes = nbytes.get(key, -1) if ts_nbytes >= 0: ts.set_nbytes(ts_nbytes) for w in workers: - ws: WorkerState = parent._workers_dv[w] + ws: WorkerState = self.state.workers[w] if ws not in ts._who_has: - parent.add_replica(ts, ws) + self.state.add_replica(ts, ws) self.report( {"op": "key-in-memory", "key": key, "workers": list(workers)} ) @@ -7105,9 +7218,8 @@ def update_data( self.client_desires_keys(keys=list(who_has), client=client) def report_on_key(self, key: str = None, ts: TaskState = None, client: str = None): - parent: SchedulerState = cast(SchedulerState, self) if ts is None: - ts = parent._tasks.get(key) + ts = self.state.tasks.get(key) elif key is None: key = ts._key else: @@ -7118,7 +7230,7 @@ def report_on_key(self, key: str = None, ts: TaskState = None, client: str = Non if ts is None: report_msg = {"op": "cancelled-key", "key": key} else: - report_msg = _task_to_report_msg(parent, ts) + report_msg = _task_to_report_msg(self.state, ts) if report_msg is not None: self.report(report_msg, ts=ts, client=client) @@ -7177,79 +7289,73 @@ def subscribe_worker_status(self, comm=None): return ident def get_processing(self, comm=None, workers=None): - parent: 
SchedulerState = cast(SchedulerState, self) ws: WorkerState ts: TaskState if workers is not None: workers = set(map(self.coerce_address, workers)) return { - w: [ts._key for ts in parent._workers_dv[w].processing] for w in workers + w: [ts._key for ts in self.state.workers[w].processing] for w in workers } else: return { w: [ts._key for ts in ws._processing] - for w, ws in parent._workers_dv.items() + for w, ws in self.state.workers.items() } def get_who_has(self, comm=None, keys=None): - parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState ts: TaskState if keys is not None: return { - k: [ws._address for ws in parent._tasks[k].who_has] - if k in parent._tasks + k: [ws._address for ws in self.state.tasks[k].who_has] + if k in self.state.tasks else [] for k in keys } else: return { key: [ws._address for ws in ts._who_has] - for key, ts in parent._tasks.items() + for key, ts in self.state.tasks.items() } def get_has_what(self, comm=None, workers=None): - parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState ts: TaskState if workers is not None: workers = map(self.coerce_address, workers) return { - w: [ts._key for ts in parent._workers_dv[w].has_what] - if w in parent._workers_dv + w: [ts._key for ts in self.state.workers[w].has_what] + if w in self.state.workers else [] for w in workers } else: return { w: [ts._key for ts in ws.has_what] - for w, ws in parent._workers_dv.items() + for w, ws in self.state.workers.items() } def get_ncores(self, comm=None, workers=None): - parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState if workers is not None: workers = map(self.coerce_address, workers) return { - w: parent._workers_dv[w].nthreads + w: self.state.workers[w].nthreads for w in workers - if w in parent._workers_dv + if w in self.state.workers } else: - return {w: ws._nthreads for w, ws in parent._workers_dv.items()} + return {w: ws._nthreads for w, ws in self.state.workers.items()} def get_ncores_running(self, comm=None, workers=None): - parent: SchedulerState = cast(SchedulerState, self) ncores = self.get_ncores(workers=workers) return { w: n for w, n in ncores.items() - if parent._workers_dv[w].status == Status.running + if self.state.workers[w].status == Status.running } async def get_call_stack(self, comm=None, keys=None): - parent: SchedulerState = cast(SchedulerState, self) ts: TaskState dts: TaskState if keys is not None: @@ -7257,7 +7363,7 @@ async def get_call_stack(self, comm=None, keys=None): processing = set() while stack: key = stack.pop() - ts = parent._tasks[key] + ts = self.state.tasks[key] if ts._state == "waiting": stack.extend([dts._key for dts in ts._dependencies]) elif ts._state == "processing": @@ -7268,7 +7374,7 @@ async def get_call_stack(self, comm=None, keys=None): if ts._processing_on: workers[ts._processing_on.address].append(ts._key) else: - workers = {w: None for w in parent._workers_dv} + workers = {w: None for w in self.state.workers} if not workers: return {} @@ -7280,14 +7386,15 @@ async def get_call_stack(self, comm=None, keys=None): return response def get_nbytes(self, comm=None, keys=None, summary=True): - parent: SchedulerState = cast(SchedulerState, self) ts: TaskState with log_errors(): if keys is not None: - result = {k: parent._tasks[k].nbytes for k in keys} + result = {k: self.state.tasks[k].nbytes for k in keys} else: result = { - k: ts._nbytes for k, ts in parent._tasks.items() if ts._nbytes >= 0 + k: ts._nbytes + for k, ts in self.state.tasks.items() + if ts._nbytes >= 0 } if summary: @@ -7318,8 +7425,7 
@@ def run_function(self, stream, function, args=(), kwargs={}, wait=True): return run(self, stream, function=function, args=args, kwargs=kwargs, wait=wait) def set_metadata(self, comm=None, keys=None, value=None): - parent: SchedulerState = cast(SchedulerState, self) - metadata = parent._task_metadata + metadata = self.state.task_metadata for key in keys[:-1]: if key not in metadata or not isinstance(metadata[key], (dict, list)): metadata[key] = {} @@ -7327,8 +7433,7 @@ def set_metadata(self, comm=None, keys=None, value=None): metadata[keys[-1]] = value def get_metadata(self, comm=None, keys=None, default=no_default): - parent: SchedulerState = cast(SchedulerState, self) - metadata = parent._task_metadata + metadata = self.state.task_metadata for key in keys[:-1]: metadata = metadata[key] try: @@ -7368,9 +7473,8 @@ def get_task_prefix_states(self, comm=None): return state def get_task_status(self, comm=None, keys=None): - parent: SchedulerState = cast(SchedulerState, self) return { - key: (parent._tasks[key].state if key in parent._tasks else None) + key: (self.state.tasks[key].state if key in self.state.tasks else None) for key in keys } @@ -7461,11 +7565,10 @@ def transition(self, key, finish: str, *args, **kwargs): -------- Scheduler.transitions: transitive version of this function """ - parent: SchedulerState = cast(SchedulerState, self) recommendations: dict worker_msgs: dict client_msgs: dict - a: tuple = parent._transition(key, finish, *args, **kwargs) + a: tuple = self.state.transition(key, finish, args=args, kwargs=kwargs) recommendations, client_msgs, worker_msgs = a self.send_all(client_msgs, worker_msgs) return recommendations @@ -7476,17 +7579,18 @@ def transitions(self, recommendations: dict): This includes feedback from previous transitions and continues until we reach a steady state """ - parent: SchedulerState = cast(SchedulerState, self) client_msgs: dict = {} worker_msgs: dict = {} - parent._transitions(recommendations, client_msgs, worker_msgs) + self.state.transitions(recommendations, client_msgs, worker_msgs) self.send_all(client_msgs, worker_msgs) def story(self, *keys): """Get all transitions that touch one of the input keys""" keys = {key.key if isinstance(key, TaskState) else key for key in keys} return [ - t for t in self.transition_log if t[0] in keys or keys.intersection(t[3]) + t + for t in self.state.transition_log + if t[0] in keys or keys.intersection(t[3]) ] transition_story = story @@ -7497,10 +7601,9 @@ def reschedule(self, key=None, worker=None): Things may have shifted and this task may now be better suited to run elsewhere """ - parent: SchedulerState = cast(SchedulerState, self) ts: TaskState try: - ts = parent._tasks[key] + ts = self.state.tasks[key] except KeyError: logger.warning( "Attempting to reschedule task {}, which was not " @@ -7518,26 +7621,24 @@ def reschedule(self, key=None, worker=None): ##################### def add_resources(self, comm=None, worker=None, resources=None): - parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv[worker] + ws: WorkerState = self.state.workers[worker] if resources: ws._resources.update(resources) ws._used_resources = {} for resource, quantity in ws._resources.items(): ws._used_resources[resource] = 0 - dr: dict = parent._resources.get(resource, None) + dr: dict = self.state.resources.get(resource, None) if dr is None: - parent._resources[resource] = dr = {} + self.state.resources[resource] = dr = {} dr[worker] = quantity return "OK" def remove_resources(self, worker): 
- parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv[worker] + ws: WorkerState = self.state.workers[worker] for resource, quantity in ws._resources.items(): - dr: dict = parent._resources.get(resource, None) + dr: dict = self.state.resources.get(resource, None) if dr is None: - parent._resources[resource] = dr = {} + self.state.resources[resource] = dr = {} del dr[worker] def coerce_address(self, addr, resolve=True): @@ -7548,9 +7649,8 @@ def coerce_address(self, addr, resolve=True): Handles strings, tuples, or aliases. """ # XXX how many address-parsing routines do we have? - parent: SchedulerState = cast(SchedulerState, self) - if addr in parent._aliases: - addr = parent._aliases[addr] + if addr in self.state.aliases: + addr = self.state.aliases[addr] if isinstance(addr, tuple): addr = unparse_host_port(*addr) if not isinstance(addr, str): @@ -7570,16 +7670,17 @@ def workers_list(self, workers): Takes a list of worker addresses or hostnames. Returns a list of all worker addresses that match """ - parent: SchedulerState = cast(SchedulerState, self) if workers is None: - return list(parent._workers) + return list(self.state.workers) out = set() for w in workers: if ":" in w: out.add(w) else: - out.update({ww for ww in parent._workers if w in ww}) # TODO: quadratic + out.update( + {ww for ww in self.state.workers if w in ww} + ) # TODO: quadratic return list(out) def start_ipython(self, comm=None): @@ -7606,11 +7707,10 @@ async def get_profile( stop=None, key=None, ): - parent: SchedulerState = cast(SchedulerState, self) if workers is None: - workers = parent._workers_dv + workers = self.state.workers else: - workers = set(parent._workers_dv) & set(workers) + workers = set(self.state.workers) & set(workers) if scheduler: return profile.get_profile(self.io_loop.profile, start=start, stop=stop) @@ -7640,16 +7740,15 @@ async def get_profile_metadata( stop=None, profile_cycle_interval=None, ): - parent: SchedulerState = cast(SchedulerState, self) dt = profile_cycle_interval or dask.config.get( "distributed.worker.profile.cycle" ) dt = parse_timedelta(dt, default="ms") if workers is None: - workers = parent._workers_dv + workers = self.state.workers else: - workers = set(parent._workers_dv) & set(workers) + workers = set(self.state.workers) & set(workers) results = await asyncio.gather( *(self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers), return_exceptions=True, @@ -7685,7 +7784,6 @@ async def get_profile_metadata( async def performance_report( self, comm=None, start=None, last_count=None, code="", mode=None ): - parent: SchedulerState = cast(SchedulerState, self) stop = time() # Profiles compute, scheduler, workers = await asyncio.gather( @@ -7783,10 +7881,10 @@ def profile_to_figure(state): ntasks=total_tasks, tasks_timings=tasks_timings, address=self.address, - nworkers=len(parent._workers_dv), - threads=sum([ws._nthreads for ws in parent._workers_dv.values()]), + nworkers=len(self.state.workers), + threads=sum([ws._nthreads for ws in self.state.workers.values()]), memory=format_bytes( - sum([ws._memory_limit for ws in parent._workers_dv.values()]) + sum([ws._memory_limit for ws in self.state.workers.values()]) ), code=code, dask_version=dask.__version__, @@ -7892,16 +7990,15 @@ def get_events(self, comm=None, topic=None): return valmap(tuple, self.events) async def get_worker_monitor_info(self, recent=False, starts=None): - parent: SchedulerState = cast(SchedulerState, self) if starts is None: starts = {} results = await 
asyncio.gather( *( self.rpc(w).get_monitor_info(recent=recent, start=starts.get(w, 0)) - for w in parent._workers_dv + for w in self.state.workers ) ) - return dict(zip(parent._workers_dv, results)) + return dict(zip(self.state.workers, results)) ########### # Cleanup # @@ -7922,7 +8019,6 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): lets us avoid this fringe optimization when we have better things to think about. """ - parent: SchedulerState = cast(SchedulerState, self) try: if self.status == Status.closed: return @@ -7930,7 +8026,7 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): next_time = timedelta(seconds=0.1) if self.proc.cpu_percent() < 50: - workers: list = list(parent._workers.values()) + workers: list = list(self.state.workers.values()) nworkers: Py_ssize_t = len(workers) i: Py_ssize_t for i in range(nworkers): @@ -7939,7 +8035,7 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): try: if ws is None or not ws._processing: continue - parent._reevaluate_occupancy_worker(ws) + self.state.reevaluate_occupancy_worker(ws) finally: del ws # lose ref @@ -7957,12 +8053,11 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): raise async def check_worker_ttl(self): - parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState now = time() - for ws in parent._workers_dv.values(): + for ws in self.state.workers.values(): if (ws._last_seen < now - self.worker_ttl) and ( - ws._last_seen < now - 10 * heartbeat_interval(len(parent._workers_dv)) + ws._last_seen < now - 10 * heartbeat_interval(len(self.state.workers)) ): logger.warning( "Worker failed to heartbeat within %s seconds. Closing: %s", @@ -7972,11 +8067,10 @@ async def check_worker_ttl(self): await self.remove_worker(address=ws._address) def check_idle(self): - parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState if ( - any([ws._processing for ws in parent._workers_dv.values()]) - or parent._unrunnable + any([ws._processing for ws in self.state.workers.values()]) + or self.state.unrunnable ): self.idle_since = None return @@ -8006,20 +8100,19 @@ def adaptive_target(self, comm=None, target_duration=None): -------- distributed.deploy.Adaptive """ - parent: SchedulerState = cast(SchedulerState, self) if target_duration is None: target_duration = dask.config.get("distributed.adaptive.target-duration") target_duration = parse_timedelta(target_duration) # CPU cpu = math.ceil( - parent._total_occupancy / target_duration + self.state.total_occupancy / target_duration ) # TODO: threads per worker # Avoid a few long tasks from asking for many cores ws: WorkerState tasks_processing = 0 - for ws in parent._workers_dv.values(): + for ws in self.state.workers.values(): tasks_processing += len(ws._processing) if tasks_processing > cpu: @@ -8027,35 +8120,34 @@ def adaptive_target(self, comm=None, target_duration=None): else: cpu = min(tasks_processing, cpu) - if parent._unrunnable and not parent._workers_dv: + if self.state.unrunnable and not self.state.workers: cpu = max(1, cpu) # add more workers if more than 60% of memory is used - limit = sum([ws._memory_limit for ws in parent._workers_dv.values()]) - used = sum([ws._nbytes for ws in parent._workers_dv.values()]) + limit = sum([ws._memory_limit for ws in self.state.workers.values()]) + used = sum([ws._nbytes for ws in self.state.workers.values()]) memory = 0 if used > 0.6 * limit and limit > 0: - memory = 2 * len(parent._workers_dv) + memory = 2 * len(self.state.workers) target = max(memory, cpu) - if target >= 
len(parent._workers_dv): + if target >= len(self.state.workers): return target else: # Scale down? to_close = self.workers_to_close() - return len(parent._workers_dv) - len(to_close) + return len(self.state.workers) - len(to_close) def request_acquire_replicas(self, addr: str, keys: list, *, stimulus_id: str): """Asynchronously ask a worker to acquire a replica of the listed keys from other workers. This is a fire-and-forget operation which offers no feedback for success or failure, and is intended for housekeeping and not for computation. """ - parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState ts: TaskState who_has = {} for key in keys: - ts = parent._tasks[key] + ts = self.state.tasks[key] who_has[key] = {ws._address for ws in ts._who_has} self.stream_comms[addr].send( @@ -8078,19 +8170,18 @@ def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str): to re-add itself to who_has. If the worker agrees to discard the task, there is no feedback. """ - parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv[addr] + ws: WorkerState = self.state.workers[addr] validate = self.validate # The scheduler immediately forgets about the replica and suggests the worker to # drop it. The worker may refuse, at which point it will send back an add-keys # message to reinstate it. for key in keys: - ts: TaskState = parent._tasks[key] + ts: TaskState = self.state.tasks[key] if validate: # Do not destroy the last copy assert len(ts._who_has) > 1 - self.remove_replica(ts, ws) + self.state.remove_replica(ts, ws) self.stream_comms[addr].send( { @@ -8389,6 +8480,7 @@ def decide_worker( return ws +@ccall def validate_task_state(ts: TaskState): """ Validate the given TaskState. @@ -8445,7 +8537,7 @@ def validate_task_state(ts: TaskState): assert dts._state != "forgotten" assert (ts._processing_on is not None) == (ts._state == "processing") - assert bool(ts._who_has) == (ts._state == "memory"), (ts, ts._who_has, ts._state) + assert (not not ts._who_has) == (ts._state == "memory"), (ts, ts._who_has) if ts._state == "processing": assert all([dts._who_has for dts in ts._dependencies]), ( @@ -8490,6 +8582,7 @@ def validate_task_state(ts: TaskState): assert ts in ts._processing_on.actors +@ccall def validate_worker_state(ws: WorkerState): ts: TaskState for ts in ws._has_what: @@ -8504,6 +8597,7 @@ def validate_worker_state(ws: WorkerState): assert ts._state in ("memory", "processing") +@ccall def validate_state(tasks, workers, clients): """ Validate a current runtime state diff --git a/distributed/stealing.py b/distributed/stealing.py index 71ce8d6c0ea..19bf3abe362 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -301,8 +301,8 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None): _log_msg = [key, state, victim.address, thief.address, stimulus_id] if ts.state != "processing": - self.scheduler._reevaluate_occupancy_worker(thief) - self.scheduler._reevaluate_occupancy_worker(victim) + self.scheduler.state.reevaluate_occupancy_worker(thief) + self.scheduler.state.reevaluate_occupancy_worker(victim) elif ( state in _WORKER_STATE_UNDEFINED or state in _WORKER_STATE_CONFIRM diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 4612d9179d6..31d1d95a212 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3872,7 +3872,7 @@ async def test_idempotence(s, a, b): # Submit x = c.submit(inc, 1) await x - log = list(s.transition_log) + log = 
list(s.state.transition_log) len_single_submit = len(log) # see last assert @@ -3880,29 +3880,29 @@ async def test_idempotence(s, a, b): assert x.key == y.key await y await asyncio.sleep(0.1) - log2 = list(s.transition_log) + log2 = list(s.state.transition_log) assert log == log2 # Error a = c.submit(div, 1, 0) await wait(a) assert a.status == "error" - log = list(s.transition_log) + log = list(s.state.transition_log) b = f.submit(div, 1, 0) assert a.key == b.key await wait(b) await asyncio.sleep(0.1) - log2 = list(s.transition_log) + log2 = list(s.state.transition_log) assert log == log2 - s.transition_log.clear() + s.state.transition_log.clear() # Simultaneous Submit d = c.submit(inc, 2) e = c.submit(inc, 2) await wait([d, e]) - assert len(s.transition_log) == len_single_submit + assert len(s.state.transition_log) == len_single_submit await c.close() await f.close() @@ -4566,7 +4566,7 @@ def test_auto_normalize_collection_sync(c): def assert_no_data_loss(scheduler): - for key, start, finish, recommendations, _ in scheduler.transition_log: + for key, start, finish, recommendations, _ in scheduler.state.transition_log: if start == "memory" and finish == "released": for k, v in recommendations.items(): assert not (k == key and v == "waiting") @@ -5557,7 +5557,7 @@ def fib(x): future = c.submit(fib, 8) result = await future assert result == 21 - assert len(s.transition_log) > 50 + assert len(s.state.transition_log) > 50 @gen_cluster(client=True) @@ -6873,7 +6873,7 @@ def test_computation_object_code_dask_compute(client): test_function_code = inspect.getsource(test_computation_object_code_dask_compute) def fetch_comp_code(dask_scheduler): - computations = list(dask_scheduler.computations) + computations = list(dask_scheduler.state.computations) assert len(computations) == 1 comp = computations[0] assert len(comp.code) == 1 @@ -6913,7 +6913,7 @@ async def test_computation_object_code_dask_persist(c, s, a, b): test_function_code = inspect.getsource( test_computation_object_code_dask_persist.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] assert len(comp.code) == 1 @@ -6933,7 +6933,7 @@ def func(x): test_function_code = inspect.getsource( test_computation_object_code_client_submit_simple.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] @@ -6954,7 +6954,7 @@ def func(x): test_function_code = inspect.getsource( test_computation_object_code_client_submit_list_comp.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] @@ -6976,7 +6976,7 @@ def func(x): test_function_code = inspect.getsource( test_computation_object_code_client_submit_dict_comp.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] @@ -6996,7 +6996,7 @@ async def test_computation_object_code_client_map(c, s, a, b): test_function_code = inspect.getsource( test_computation_object_code_client_map.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] assert len(comp.code) == 1 @@ -7014,7 +7014,7 @@ async def test_computation_object_code_client_compute(c, s, a, b): test_function_code = inspect.getsource( test_computation_object_code_client_compute.__wrapped__ ) - 
computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] assert len(comp.code) == 1 diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 6d5081ef0bd..57e7abe6bb7 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -734,11 +734,11 @@ async def test_coerce_address(s): assert s.coerce_address(123) == b.address assert s.coerce_address("charlie") == c.address - assert s.coerce_hostname("127.0.0.1") == "127.0.0.1" - assert s.coerce_hostname("alice") == a.ip - assert s.coerce_hostname(123) == b.ip - assert s.coerce_hostname("charlie") == c.ip - assert s.coerce_hostname("jimmy") == "jimmy" + assert s.state.coerce_hostname("127.0.0.1") == "127.0.0.1" + assert s.state.coerce_hostname("alice") == a.ip + assert s.state.coerce_hostname(123) == b.ip + assert s.state.coerce_hostname("charlie") == c.ip + assert s.state.coerce_hostname("jimmy") == "jimmy" assert s.coerce_address("zzzt:8000", resolve=False) == "tcp://zzzt:8000" await asyncio.gather(a.close(), b.close(), c.close()) @@ -796,11 +796,11 @@ async def test_story(c, s, a, b): f = c.persist(y) await wait([f]) - assert s.transition_log + assert s.state.transition_log story = s.story(x.key) - assert all(line in s.transition_log for line in story) - assert len(story) < len(s.transition_log) + assert all(line in s.state.transition_log for line in story) + assert len(story) < len(s.state.transition_log) assert all(x.key == line[0] or x.key in line[-2] for line in story) assert len(s.story(x.key, y.key)) > len(story) @@ -1561,13 +1561,13 @@ async def test_dont_recompute_if_persisted(c, s, a, b): yy = y.persist() await wait(yy) - old = list(s.transition_log) + old = list(s.state.transition_log) yyy = y.persist() await wait(yyy) await asyncio.sleep(0.100) - assert list(s.transition_log) == old + assert list(s.state.transition_log) == old @gen_cluster(client=True) @@ -1598,12 +1598,12 @@ async def test_dont_recompute_if_persisted_3(c, s, a, b): ww = w.persist() await wait(ww) - old = list(s.transition_log) + old = list(s.state.transition_log) www = w.persist() await wait(www) await asyncio.sleep(0.100) - assert list(s.transition_log) == old + assert list(s.state.transition_log) == old @gen_cluster(client=True) @@ -1650,13 +1650,13 @@ async def test_dont_recompute_if_erred(c, s, a, b): yy = y.persist() await wait(yy) - old = list(s.transition_log) + old = list(s.state.transition_log) yyy = y.persist() await wait(yyy) await asyncio.sleep(0.100) - assert list(s.transition_log) == old + assert list(s.state.transition_log) == old @gen_cluster() @@ -1918,7 +1918,7 @@ async def test_get_task_duration(c, s, a, b): await future assert 10 < s.task_prefixes["inc"].duration_average < 100 - ts_pref1 = s.new_task("inc-abcdefab", None, "released") + ts_pref1 = s.state.new_task("inc-abcdefab", None, "released") assert 10 < s.get_task_duration(ts_pref1) < 100 # make sure get_task_duration adds TaskStates to unknown dict @@ -3172,16 +3172,16 @@ async def test_computations(c, s, a, b): z = (x - 2).persist() await z - assert len(s.computations) == 2 - assert "add" in str(s.computations[0].groups) - assert "sub" in str(s.computations[1].groups) - assert "sub" not in str(s.computations[0].groups) + assert len(s.state.computations) == 2 + assert "add" in str(s.state.computations[0].groups) + assert "sub" in str(s.state.computations[1].groups) + assert "sub" not in str(s.state.computations[0].groups) - assert 
isinstance(repr(s.computations[1]), str) + assert isinstance(repr(s.state.computations[1]), str) - assert s.computations[1].stop == max(tg.stop for tg in s.task_groups.values()) + assert s.state.computations[1].stop == max(tg.stop for tg in s.task_groups.values()) - assert s.computations[0].states["memory"] == y.npartitions + assert s.state.computations[0].states["memory"] == y.npartitions @gen_cluster(client=True) @@ -3190,7 +3190,7 @@ async def test_computations_futures(c, s, a, b): total = c.submit(sum, futures) await total - [computation] = s.computations + [computation] = s.state.computations assert "sum" in str(computation.groups) assert "inc" in str(computation.groups) @@ -3246,7 +3246,7 @@ async def test_worker_reconnect_task_memory(c, s, a): await res assert ("no-worker", "memory") in { - (start, finish) for (_, start, finish, _, _) in s.transition_log + (start, finish) for (_, start, finish, _, _) in s.state.transition_log } @@ -3270,7 +3270,7 @@ async def test_worker_reconnect_task_memory_with_resources(c, s, a): await res assert ("no-worker", "memory") in { - (start, finish) for (_, start, finish, _, _) in s.transition_log + (start, finish) for (_, start, finish, _, _) in s.state.transition_log } From 359915da39ffbf30a3396b15538d1a204d36d891 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 23 Feb 2022 14:45:55 -0800 Subject: [PATCH 02/10] Use `bool` instead of `not not` --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8ad41503b52..6e91f2dbb40 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -8533,7 +8533,7 @@ def validate_task_state(ts: TaskState): assert dts._state != "forgotten" assert (ts._processing_on is not None) == (ts._state == "processing") - assert (not not ts._who_has) == (ts._state == "memory"), (ts, ts._who_has) + assert bool(ts._who_has) == (ts._state == "memory"), (ts, ts._who_has) if ts._state == "processing": assert all([dts._who_has for dts in ts._dependencies]), ( From 47ee1269815b9209f4d3116a8b0db9ff254caad0 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 23 Feb 2022 15:10:06 -0800 Subject: [PATCH 03/10] More test fixes --- distributed/tests/test_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 31d1d95a212..a6d69f39127 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5233,10 +5233,10 @@ def long_running(lock): ts = s.tasks[f.key] ws = s.workers[a.address] - s.set_duration_estimate(ts, ws) + s.state.set_duration_estimate(ts, ws) assert s.workers[a.address].occupancy == 0 - s.reevaluate_occupancy(0) + s.state.reevaluate_occupancy(0) assert s.workers[a.address].occupancy == 0 await l.release() @@ -6893,7 +6893,7 @@ def test_computation_object_code_not_available(client): result = np.where(ddf.a > 4) def fetch_comp_code(dask_scheduler): - computations = list(dask_scheduler.computations) + computations = list(dask_scheduler.state.computations) assert len(computations) == 1 comp = computations[0] assert len(comp.code) == 1 From 6849761bc14557f5be8623d69530478ada3bf00b Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Fri, 25 Feb 2022 17:29:23 -0800 Subject: [PATCH 04/10] Fix `test_long_running_not_in_occupancy` --- distributed/tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py 
index a6d69f39127..7ae368858f2 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5236,7 +5236,7 @@ def long_running(lock): s.state.set_duration_estimate(ts, ws) assert s.workers[a.address].occupancy == 0 - s.state.reevaluate_occupancy(0) + s.reevaluate_occupancy(0) assert s.workers[a.address].occupancy == 0 await l.release() From e798f9608fddb00478fae130e590565b93f323a8 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 28 Feb 2022 12:53:49 -0800 Subject: [PATCH 05/10] Add back missing `if` --- distributed/scheduler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 6e91f2dbb40..3240e02d99b 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4589,7 +4589,9 @@ def heartbeat_worker( ws._last_seen = local_now if executing is not None: ws._executing = { - self.state.tasks[key]: duration for key, duration in executing.items() + self.state.tasks[key]: duration + for key, duration in executing.items() + if key in self.state.tasks } ws._metrics = metrics From 99d91907a6600677481f973761e89c005fd56f8c Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 28 Feb 2022 14:28:30 -0800 Subject: [PATCH 06/10] Grab `transition_log` from `SchedulerState` --- distributed/tests/test_worker_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index 4477b365efb..e0fb8a868d3 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -37,7 +37,7 @@ def func(x): assert xx == 10 + 1 + (10 + 1) * 2 assert yy == 20 + 1 + (20 + 1) * 2 - assert len(s.transition_log) > 10 + assert len(s.state.transition_log) > 10 assert len([id for id in s.wants_what if id.lower().startswith("client")]) == 1 From a060153d6aef08d5998659dbfa86e041213c6757 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 28 Feb 2022 14:33:51 -0800 Subject: [PATCH 07/10] Grab `replicated_tasks` from `SchedulerState` --- distributed/active_memory_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index 14980fb0533..53a1b447405 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -467,7 +467,7 @@ def run(self): nkeys = 0 ndrop = 0 - for ts in self.manager.scheduler.replicated_tasks: + for ts in self.manager.scheduler.state.replicated_tasks: desired_replicas = 1 # TODO have a marker on TaskState # If a dependent task has not been assigned to a worker yet, err on the side From d2ba1206709b10710d61dacf7d59ab6d07232304 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 28 Feb 2022 14:34:39 -0800 Subject: [PATCH 08/10] Grab `running` from `SchedulerState` --- distributed/active_memory_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index 53a1b447405..1fc2cbe30c4 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -591,7 +591,7 @@ def run(self): drop_ws = (yield "drop", ts, {ws}) if drop_ws: continue # Use case 1 or 2 - if ts.who_has & self.manager.scheduler.running: + if ts.who_has & self.manager.scheduler.state.running: continue # Use case 3 or 4 # Use case 5 From 62f57ba97513c5084fcd7960062e2132ba229025 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 28 Feb 2022 16:29:26 
-0800 Subject: [PATCH 09/10] Move cached `MEMORY_*` config vars to Scheduler --- distributed/scheduler.py | 82 ++++++++++++---------------------------- 1 file changed, 25 insertions(+), 57 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 3240e02d99b..26e0132b0ac 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2001,13 +2001,8 @@ class SchedulerState: _transition_counter: Py_ssize_t _plugins: dict # dict[str, SchedulerPlugin] - # Variables from dask.config, cached by __init__ for performance + # Pulled from dask.config for efficient usage _UNKNOWN_TASK_DURATION: double - _MEMORY_RECENT_TO_OLD_TIME: double - _MEMORY_REBALANCE_MEASURE: str - _MEMORY_REBALANCE_SENDER_MIN: double - _MEMORY_REBALANCE_RECIPIENT_MAX: double - _MEMORY_REBALANCE_HALF_GAP: double def __init__( self, @@ -2079,26 +2074,10 @@ def __init__( } self._plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins} - # Variables from dask.config, cached by __init__ for performance + # Pulled from dask.config for efficient usage self._UNKNOWN_TASK_DURATION = parse_timedelta( dask.config.get("distributed.scheduler.unknown-task-duration") ) - self._MEMORY_RECENT_TO_OLD_TIME = parse_timedelta( - dask.config.get("distributed.worker.memory.recent-to-old-time") - ) - self._MEMORY_REBALANCE_MEASURE = dask.config.get( - "distributed.worker.memory.rebalance.measure" - ) - self._MEMORY_REBALANCE_SENDER_MIN = dask.config.get( - "distributed.worker.memory.rebalance.sender-min" - ) - self._MEMORY_REBALANCE_RECIPIENT_MAX = dask.config.get( - "distributed.worker.memory.rebalance.recipient-max" - ) - self._MEMORY_REBALANCE_HALF_GAP = ( - dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap") - / 2.0 - ) self._transition_counter = 0 # Call Server.__init__() @@ -2224,30 +2203,6 @@ def plugins(self) -> "dict[str, SchedulerPlugin]": def plugins(self, val): self._plugins = val - @property - def UNKNOWN_TASK_DURATION(self): - return self._UNKNOWN_TASK_DURATION - - @property - def MEMORY_RECENT_TO_OLD_TIME(self): - return self._MEMORY_RECENT_TO_OLD_TIME - - @property - def MEMORY_REBALANCE_MEASURE(self): - return self._MEMORY_REBALANCE_MEASURE - - @property - def MEMORY_REBALANCE_SENDER_MIN(self): - return self._MEMORY_REBALANCE_SENDER_MIN - - @property - def MEMORY_REBALANCE_RECIPIENT_MAX(self): - return self._MEMORY_REBALANCE_RECIPIENT_MAX - - @property - def MEMORY_REBALANCE_HALF_GAP(self): - return self._MEMORY_REBALANCE_HALF_GAP - @property def memory(self) -> MemoryState: return MemoryState.sum(*(w.memory for w in self.workers.values())) @@ -4149,6 +4104,24 @@ def __init__( connection_limit = get_fileno_limit() / 2 + # Variables from dask.config, cached by __init__ for performance + self.MEMORY_RECENT_TO_OLD_TIME = parse_timedelta( + dask.config.get("distributed.worker.memory.recent-to-old-time") + ) + self.MEMORY_REBALANCE_MEASURE = dask.config.get( + "distributed.worker.memory.rebalance.measure" + ) + self.MEMORY_REBALANCE_SENDER_MIN = dask.config.get( + "distributed.worker.memory.rebalance.sender-min" + ) + self.MEMORY_REBALANCE_RECIPIENT_MAX = dask.config.get( + "distributed.worker.memory.rebalance.recipient-max" + ) + self.MEMORY_REBALANCE_HALF_GAP = ( + dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap") + / 2.0 + ) + self.state: SchedulerState = SchedulerState( aliases=aliases, clients=clients, @@ -4598,9 +4571,7 @@ def heartbeat_worker( # Calculate RSS - dask keys, separating "old" and "new" usage # See MemoryState for 
details - max_memory_unmanaged_old_hist_age = ( - local_now - self.state.MEMORY_RECENT_TO_OLD_TIME - ) + max_memory_unmanaged_old_hist_age = local_now - self.MEMORY_RECENT_TO_OLD_TIME memory_unmanaged_old = ws._memory_unmanaged_old while ws._memory_other_history: timestamp, size = ws._memory_other_history[0] @@ -6530,18 +6501,15 @@ def _rebalance_find_msgs( # (distributed.worker.memory.recent-to-old-time). # This lets us ignore temporary spikes caused by task heap usage. memory_by_worker = [ - (ws, getattr(ws.memory, self.state.MEMORY_REBALANCE_MEASURE)) - for ws in workers + (ws, getattr(ws.memory, self.MEMORY_REBALANCE_MEASURE)) for ws in workers ] mean_memory = sum(m for _, m in memory_by_worker) // len(memory_by_worker) for ws, ws_memory in memory_by_worker: if ws.memory_limit: - half_gap = int(self.state.MEMORY_REBALANCE_HALF_GAP * ws.memory_limit) - sender_min = self.state.MEMORY_REBALANCE_SENDER_MIN * ws.memory_limit - recipient_max = ( - self.state.MEMORY_REBALANCE_RECIPIENT_MAX * ws.memory_limit - ) + half_gap = int(self.MEMORY_REBALANCE_HALF_GAP * ws.memory_limit) + sender_min = self.MEMORY_REBALANCE_SENDER_MIN * ws.memory_limit + recipient_max = self.MEMORY_REBALANCE_RECIPIENT_MAX * ws.memory_limit else: half_gap = 0 sender_min = 0.0 From 4f025665a029f4d4d1b804e3dd237341ee4722cb Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Tue, 1 Mar 2022 02:11:11 -0800 Subject: [PATCH 10/10] Convert functions on `SchedulerState` to methods --- distributed/scheduler.py | 503 +++++++++++++++++++-------------------- 1 file changed, 248 insertions(+), 255 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 26e0132b0ac..3564ab72450 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2572,8 +2572,8 @@ def transition_no_worker_memory( self.check_idle_saturated(ws) - _add_to_memory( - self, ts, ws, recommendations, client_msgs, type=type, typename=typename + self._add_to_memory( + ts, ws, recommendations, client_msgs, type=type, typename=typename ) ts.state = "memory" @@ -2586,6 +2586,235 @@ def transition_no_worker_memory( pdb.set_trace() raise + # Utility methods related to transitions + + @cfunc + @exceptval(check=False) + def _remove_from_processing(self, ts: TaskState) -> str: # -> str | None + """ + Remove *ts* from the set of processing tasks. + + See also ``Scheduler.set_duration_estimate`` + """ + ws: WorkerState = ts._processing_on + ts._processing_on = None # type: ignore + w: str = ws._address + + if w not in self._workers_dv: # may have been removed + return None # type: ignore + + duration: double = ws._processing.pop(ts) + if not ws._processing: + self._total_occupancy -= ws._occupancy + ws._occupancy = 0 + else: + self._total_occupancy -= duration + ws._occupancy -= duration + + self.check_idle_saturated(ws) + self.release_resources(ts, ws) + + return w + + @cfunc + @exceptval(check=False) + def _add_to_memory( + self, + ts: TaskState, + ws: WorkerState, + recommendations: dict, + client_msgs: dict, + type=None, + typename: str = None, + ): + """ + Add *ts* to the set of in-memory tasks. 
+ """ + if self._validate: + assert ts not in ws._has_what + + self.add_replica(ts, ws) + + deps: list = list(ts._dependents) + if len(deps) > 1: + deps.sort(key=operator.attrgetter("priority"), reverse=True) + + dts: TaskState + s: set + for dts in deps: + s = dts._waiting_on + if ts in s: + s.discard(ts) + if not s: # new task ready to run + recommendations[dts._key] = "processing" + + for dts in ts._dependencies: + s = dts._waiters + s.discard(ts) + if not s and not dts._who_wants: + recommendations[dts._key] = "released" + + report_msg: dict = {} + cs: ClientState + if not ts._waiters and not ts._who_wants: + recommendations[ts._key] = "released" + else: + report_msg["op"] = "key-in-memory" + report_msg["key"] = ts._key + if type is not None: + report_msg["type"] = type + + for cs in ts._who_wants: + client_msgs[cs._client_key] = [report_msg] + + ts.state = "memory" + ts._type = typename # type: ignore + ts._group._types.add(typename) + + cs = self._clients["fire-and-forget"] + if ts in cs._wants_what: + self._client_releases_keys( + cs=cs, + keys=[ts._key], + recommendations=recommendations, + ) + + @cfunc + @exceptval(check=False) + def _propagate_forgotten( + self, ts: TaskState, recommendations: dict, worker_msgs: dict + ): + ts.state = "forgotten" + key: str = ts._key + dts: TaskState + for dts in ts._dependents: + dts._has_lost_dependencies = True + dts._dependencies.remove(ts) + dts._waiting_on.discard(ts) + if dts._state not in ("memory", "erred"): + # Cannot compute task anymore + recommendations[dts._key] = "forgotten" + ts._dependents.clear() + ts._waiters.clear() + + for dts in ts._dependencies: + dts._dependents.remove(ts) + dts._waiters.discard(ts) + if not dts._dependents and not dts._who_wants: + # Task not needed anymore + assert dts is not ts + recommendations[dts._key] = "forgotten" + ts._dependencies.clear() + ts._waiting_on.clear() + + ws: WorkerState + for ws in ts._who_has: + w: str = ws._address + if w in self._workers_dv: # in case worker has died + worker_msgs[w] = [ + { + "op": "free-keys", + "keys": [key], + "stimulus_id": f"propagate-forgotten-{time()}", + } + ] + self.remove_all_replicas(ts) + + @ccall + @exceptval(check=False) + def _client_releases_keys(self, keys: list, cs: ClientState, recommendations: dict): + """Remove keys from client desired list""" + logger.debug("Client %s releases keys: %s", cs._client_key, keys) + ts: TaskState + for key in keys: + ts = self._tasks.get(key) # type: ignore + if ts is not None and ts in cs._wants_what: + cs._wants_what.remove(ts) + ts._who_wants.remove(cs) + if not ts._who_wants: + if not ts._dependents: + # No live dependents, can forget + recommendations[ts._key] = "forgotten" + elif ts._state != "erred" and not ts._waiters: + recommendations[ts._key] = "released" + + # Message building functions for communication + + @ccall + @exceptval(check=False) + def _task_to_msg(self, ts: TaskState, duration: double = -1) -> dict: + """Convert a single computational task to a message""" + ws: WorkerState + dts: TaskState + + # FIXME: The duration attribute is not used on worker. 
We could safe ourselves the time to compute and submit this + if duration < 0: + duration = self.get_task_duration(ts) + + msg: dict = { + "op": "compute-task", + "key": ts._key, + "priority": ts._priority, + "duration": duration, + "stimulus_id": f"compute-task-{time()}", + "who_has": {}, + } + if ts._resource_restrictions: + msg["resource_restrictions"] = ts._resource_restrictions + if ts._actor: + msg["actor"] = True + + deps: set = ts._dependencies + if deps: + msg["who_has"] = { + dts._key: [ws._address for ws in dts._who_has] for dts in deps + } + msg["nbytes"] = {dts._key: dts._nbytes for dts in deps} + + if self._validate: + assert all(msg["who_has"].values()) + + task = ts._run_spec + if type(task) is dict: + msg.update(task) + else: + msg["task"] = task + + if ts._annotations: + msg["annotations"] = ts._annotations + + return msg + + @ccall + @exceptval(check=False) + def _task_to_report_msg(self, ts: TaskState) -> dict: # -> dict | None + if ts._state == "forgotten": + return {"op": "cancelled-key", "key": ts._key} + elif ts._state == "memory": + return {"op": "key-in-memory", "key": ts._key} + elif ts._state == "erred": + failing_ts: TaskState = ts._exception_blame + return { + "op": "task-erred", + "key": ts._key, + "exception": failing_ts._exception, + "traceback": failing_ts._traceback, + } + else: + return None # type: ignore + + @cfunc + @exceptval(check=False) + def _task_to_client_msgs(self, ts: TaskState) -> dict: + if ts._who_wants: + report_msg: dict = self._task_to_report_msg(ts) + if report_msg is not None: + cs: ClientState + return {cs._client_key: [report_msg] for cs in ts._who_wants} + return {} + + # Other transition related functions + @ccall @exceptval(check=False) def decide_worker(self, ts: TaskState) -> WorkerState: # -> WorkerState | None @@ -2747,7 +2976,7 @@ def transition_waiting_processing(self, key): # logger.debug("Send job to worker: %s, %s", worker, key) - worker_msgs[worker] = [_task_to_msg(self, ts)] + worker_msgs[worker] = [self._task_to_msg(ts)] return recommendations, client_msgs, worker_msgs except Exception as e: @@ -2780,8 +3009,8 @@ def transition_waiting_memory( self.check_idle_saturated(ws) - _add_to_memory( - self, ts, ws, recommendations, client_msgs, type=type, typename=typename + self._add_to_memory( + ts, ws, recommendations, client_msgs, type=type, typename=typename ) if self._validate: @@ -2875,10 +3104,10 @@ def transition_processing_memory( if nbytes is not None: ts.set_nbytes(nbytes) - _remove_from_processing(self, ts) + self._remove_from_processing(ts) - _add_to_memory( - self, ts, ws, recommendations, client_msgs, type=type, typename=typename + self._add_to_memory( + ts, ws, recommendations, client_msgs, type=type, typename=typename ) if self._validate: @@ -3108,7 +3337,7 @@ def transition_processing_released(self, key): assert not ts._waiting_on assert self._tasks[key].state == "processing" - w: str = _remove_from_processing(self, ts) + w: str = self._remove_from_processing(ts) if w: worker_msgs[w] = [ { @@ -3175,7 +3404,7 @@ def transition_processing_erred( ws = ts._processing_on ws._actors.remove(ts) - w = _remove_from_processing(self, ts) + w = self._remove_from_processing(ts) ts._erred_on.add(w or worker) if exception is not None: @@ -3215,8 +3444,7 @@ def transition_processing_erred( cs = self._clients["fire-and-forget"] if ts in cs._wants_what: - _client_releases_keys( - self, + self._client_releases_keys( cs=cs, keys=[key], recommendations=recommendations, @@ -3305,9 +3533,9 @@ def transition_memory_forgotten(self, 
key): for ws in ts._who_has: ws._actors.discard(ts) - _propagate_forgotten(self, ts, recommendations, worker_msgs) + self._propagate_forgotten(ts, recommendations, worker_msgs) - client_msgs = _task_to_client_msgs(self, ts) + client_msgs = self._task_to_client_msgs(ts) self.remove_key(key) return recommendations, client_msgs, worker_msgs @@ -3343,9 +3571,9 @@ def transition_released_forgotten(self, key): else: assert 0, (ts,) - _propagate_forgotten(self, ts, recommendations, worker_msgs) + self._propagate_forgotten(ts, recommendations, worker_msgs) - client_msgs = _task_to_client_msgs(self, ts) + client_msgs = self._task_to_client_msgs(ts) self.remove_key(key) return recommendations, client_msgs, worker_msgs @@ -5444,8 +5672,8 @@ def client_releases_keys(self, keys=None, client=None): cs: ClientState = self.state.clients[client] recommendations: dict = {} - _client_releases_keys( - self.state, keys=keys, cs=cs, recommendations=recommendations + self.state._client_releases_keys( + keys=keys, cs=cs, recommendations=recommendations ) self.transitions(recommendations) @@ -5588,7 +5816,7 @@ def remove_client_from_events(): def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1): """Send a single computational task to a worker""" try: - msg: dict = _task_to_msg(self.state, ts, duration) + msg: dict = self.state._task_to_msg(ts, duration) self.worker_send(worker, msg) except Exception as e: logger.exception(e) @@ -7196,7 +7424,7 @@ def report_on_key(self, key: str = None, ts: TaskState = None, client: str = Non if ts is None: report_msg = {"op": "cancelled-key", "key": key} else: - report_msg = _task_to_report_msg(self.state, ts) + report_msg = self.state._task_to_report_msg(ts) if report_msg is not None: self.report(report_msg, ts=ts, client=client) @@ -8158,241 +8386,6 @@ def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str): ) -@cfunc -@exceptval(check=False) -def _remove_from_processing( - state: SchedulerState, ts: TaskState -) -> str: # -> str | None - """ - Remove *ts* from the set of processing tasks. - - See also ``Scheduler.set_duration_estimate`` - """ - ws: WorkerState = ts._processing_on - ts._processing_on = None # type: ignore - w: str = ws._address - - if w not in state._workers_dv: # may have been removed - return None # type: ignore - - duration: double = ws._processing.pop(ts) - if not ws._processing: - state._total_occupancy -= ws._occupancy - ws._occupancy = 0 - else: - state._total_occupancy -= duration - ws._occupancy -= duration - - state.check_idle_saturated(ws) - state.release_resources(ts, ws) - - return w - - -@cfunc -@exceptval(check=False) -def _add_to_memory( - state: SchedulerState, - ts: TaskState, - ws: WorkerState, - recommendations: dict, - client_msgs: dict, - type=None, - typename: str = None, -): - """ - Add *ts* to the set of in-memory tasks. 
- """ - if state._validate: - assert ts not in ws._has_what - - state.add_replica(ts, ws) - - deps: list = list(ts._dependents) - if len(deps) > 1: - deps.sort(key=operator.attrgetter("priority"), reverse=True) - - dts: TaskState - s: set - for dts in deps: - s = dts._waiting_on - if ts in s: - s.discard(ts) - if not s: # new task ready to run - recommendations[dts._key] = "processing" - - for dts in ts._dependencies: - s = dts._waiters - s.discard(ts) - if not s and not dts._who_wants: - recommendations[dts._key] = "released" - - report_msg: dict = {} - cs: ClientState - if not ts._waiters and not ts._who_wants: - recommendations[ts._key] = "released" - else: - report_msg["op"] = "key-in-memory" - report_msg["key"] = ts._key - if type is not None: - report_msg["type"] = type - - for cs in ts._who_wants: - client_msgs[cs._client_key] = [report_msg] - - ts.state = "memory" - ts._type = typename # type: ignore - ts._group._types.add(typename) - - cs = state._clients["fire-and-forget"] - if ts in cs._wants_what: - _client_releases_keys( - state, - cs=cs, - keys=[ts._key], - recommendations=recommendations, - ) - - -@cfunc -@exceptval(check=False) -def _propagate_forgotten( - state: SchedulerState, ts: TaskState, recommendations: dict, worker_msgs: dict -): - ts.state = "forgotten" - key: str = ts._key - dts: TaskState - for dts in ts._dependents: - dts._has_lost_dependencies = True - dts._dependencies.remove(ts) - dts._waiting_on.discard(ts) - if dts._state not in ("memory", "erred"): - # Cannot compute task anymore - recommendations[dts._key] = "forgotten" - ts._dependents.clear() - ts._waiters.clear() - - for dts in ts._dependencies: - dts._dependents.remove(ts) - dts._waiters.discard(ts) - if not dts._dependents and not dts._who_wants: - # Task not needed anymore - assert dts is not ts - recommendations[dts._key] = "forgotten" - ts._dependencies.clear() - ts._waiting_on.clear() - - ws: WorkerState - for ws in ts._who_has: - w: str = ws._address - if w in state._workers_dv: # in case worker has died - worker_msgs[w] = [ - { - "op": "free-keys", - "keys": [key], - "stimulus_id": f"propagate-forgotten-{time()}", - } - ] - state.remove_all_replicas(ts) - - -@cfunc -@exceptval(check=False) -def _client_releases_keys( - state: SchedulerState, keys: list, cs: ClientState, recommendations: dict -): - """Remove keys from client desired list""" - logger.debug("Client %s releases keys: %s", cs._client_key, keys) - ts: TaskState - for key in keys: - ts = state._tasks.get(key) # type: ignore - if ts is not None and ts in cs._wants_what: - cs._wants_what.remove(ts) - ts._who_wants.remove(cs) - if not ts._who_wants: - if not ts._dependents: - # No live dependents, can forget - recommendations[ts._key] = "forgotten" - elif ts._state != "erred" and not ts._waiters: - recommendations[ts._key] = "released" - - -@cfunc -@exceptval(check=False) -def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> dict: - """Convert a single computational task to a message""" - ws: WorkerState - dts: TaskState - - # FIXME: The duration attribute is not used on worker. 
We could safe ourselves the time to compute and submit this - if duration < 0: - duration = state.get_task_duration(ts) - - msg: dict = { - "op": "compute-task", - "key": ts._key, - "priority": ts._priority, - "duration": duration, - "stimulus_id": f"compute-task-{time()}", - "who_has": {}, - } - if ts._resource_restrictions: - msg["resource_restrictions"] = ts._resource_restrictions - if ts._actor: - msg["actor"] = True - - deps: set = ts._dependencies - if deps: - msg["who_has"] = { - dts._key: [ws._address for ws in dts._who_has] for dts in deps - } - msg["nbytes"] = {dts._key: dts._nbytes for dts in deps} - - if state._validate: - assert all(msg["who_has"].values()) - - task = ts._run_spec - if type(task) is dict: - msg.update(task) - else: - msg["task"] = task - - if ts._annotations: - msg["annotations"] = ts._annotations - - return msg - - -@cfunc -@exceptval(check=False) -def _task_to_report_msg(state: SchedulerState, ts: TaskState) -> dict: # -> dict | None - if ts._state == "forgotten": - return {"op": "cancelled-key", "key": ts._key} - elif ts._state == "memory": - return {"op": "key-in-memory", "key": ts._key} - elif ts._state == "erred": - failing_ts: TaskState = ts._exception_blame - return { - "op": "task-erred", - "key": ts._key, - "exception": failing_ts._exception, - "traceback": failing_ts._traceback, - } - else: - return None # type: ignore - - -@cfunc -@exceptval(check=False) -def _task_to_client_msgs(state: SchedulerState, ts: TaskState) -> dict: - if ts._who_wants: - report_msg: dict = _task_to_report_msg(state, ts) - if report_msg is not None: - cs: ClientState - return {cs._client_key: [report_msg] for cs in ts._who_wants} - return {} - - @cfunc @exceptval(check=False) def decide_worker(
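
The removed `_task_to_report_msg` / `_task_to_client_msgs` pair above (now methods on `SchedulerState`) reduces to a small state-to-message dispatch plus a per-client fan-out. Here is a self-contained sketch of that logic, using plain values and invented client keys rather than `TaskState`/`ClientState` objects.

from typing import Optional

def report_msg_for(state: str, key: str, exception=None, traceback=None) -> Optional[dict]:
    # Same per-state dispatch as _task_to_report_msg; other states report nothing.
    if state == "forgotten":
        return {"op": "cancelled-key", "key": key}
    if state == "memory":
        return {"op": "key-in-memory", "key": key}
    if state == "erred":
        return {"op": "task-erred", "key": key, "exception": exception, "traceback": traceback}
    return None

def client_msgs_for(report_msg: Optional[dict], client_keys: set) -> dict:
    # Fan the single report message out to every client that wants the key,
    # as _task_to_client_msgs does with ts._who_wants.
    if report_msg is None or not client_keys:
        return {}
    return {ck: [report_msg] for ck in client_keys}

assert report_msg_for("processing", "x") is None
assert client_msgs_for(report_msg_for("memory", "x"), {"client-1", "client-2"}) == {
    "client-1": [{"op": "key-in-memory", "key": "x"}],
    "client-2": [{"op": "key-in-memory", "key": "x"}],
}
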