diff --git a/distributed/diagnostics/progress.py b/distributed/diagnostics/progress.py index eaad51a2747..eeb3c8a2817 100644 --- a/distributed/diagnostics/progress.py +++ b/distributed/diagnostics/progress.py @@ -250,7 +250,7 @@ def __init__(self, scheduler): prefix = ts.prefix.name self.all[prefix].add(key) self.state[ts.state][prefix].add(key) - if ts.nbytes is not None: + if ts.nbytes >= 0: self.nbytes[prefix] += ts.nbytes scheduler.add_plugin(self) @@ -264,11 +264,11 @@ def transition(self, key, start, finish, *args, **kwargs): except KeyError: # TODO: remove me once we have a new or clean state pass - if start == "memory": + if start == "memory" and ts.nbytes >= 0: # XXX why not respect DEFAULT_DATA_SIZE? - self.nbytes[prefix] -= ts.nbytes or 0 - if finish == "memory": - self.nbytes[prefix] += ts.nbytes or 0 + self.nbytes[prefix] -= ts.nbytes + if finish == "memory" and ts.nbytes >= 0: + self.nbytes[prefix] += ts.nbytes if finish != "forgotten": self.state[finish][prefix].add(key) @@ -304,7 +304,7 @@ def __init__(self, scheduler): self.create(key, k) self.keys[k].add(key) self.groups[k][ts.state] += 1 - if ts.state == "memory" and ts.nbytes is not None: + if ts.state == "memory" and ts.nbytes >= 0: self.nbytes[k] += ts.nbytes scheduler.add_plugin(self) @@ -347,9 +347,9 @@ def transition(self, key, start, finish, *args, **kwargs): for dep in self.dependencies.pop(k): self.dependents[key_split_group(dep)].remove(k) - if start == "memory" and ts.nbytes is not None: + if start == "memory" and ts.nbytes >= 0: self.nbytes[k] -= ts.nbytes - if finish == "memory" and ts.nbytes is not None: + if finish == "memory" and ts.nbytes >= 0: self.nbytes[k] += ts.nbytes def restart(self, scheduler): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 46a38c1ac8f..ef7adb2b1ed 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -511,7 +511,8 @@ def clean(self): nanny=self._nanny, extra=self._extra, ) - ws._processing = {ts.key: cost for ts, cost in self._processing.items()} + ts: TaskState + ws._processing = {ts._key: cost for ts, cost in self._processing.items()} return ws def __repr__(self): @@ -545,6 +546,7 @@ def ncores(self): return self._nthreads +@cclass class TaskPrefix: """Collection tracking all tasks within a group @@ -575,43 +577,68 @@ class TaskPrefix: TaskGroup """ - def __init__(self, name): - self.name = name - self.groups = [] + _name: str + _all_durations: object + _duration_average: double + _suspicious: Py_ssize_t + _groups: list + + def __init__(self, name: str): + self._name = name + self._groups = [] # store timings for each prefix-action - self.all_durations = defaultdict(float) + self._all_durations = defaultdict(float) - if self.name in dask.config.get("distributed.scheduler.default-task-durations"): - self.duration_average = parse_timedelta( - dask.config.get("distributed.scheduler.default-task-durations")[ - self.name - ] - ) + task_durations = dask.config.get("distributed.scheduler.default-task-durations") + if self._name in task_durations: + self._duration_average = parse_timedelta(task_durations[self._name]) else: - self.duration_average = None - self.suspicious = 0 + self._duration_average = -1 + self._suspicious = 0 + + @property + def name(self): + return self._name + + @property + def all_durations(self): + return self._all_durations + + @property + def duration_average(self): + return self._duration_average + + @property + def suspicious(self): + return self._suspicious + + @property + def groups(self): + return self._groups 
@property def states(self): - return merge_with(sum, [g.states for g in self.groups]) + tg: TaskGroup + return merge_with(sum, [tg._states for tg in self._groups]) @property def active(self): + tg: TaskGroup return [ - g - for g in self.groups - if any(v != 0 for k, v in g.states.items() if k != "forgotten") + tg + for tg in self._groups + if any(v != 0 for k, v in tg._states.items() if k != "forgotten") ] @property def active_states(self): - return merge_with(sum, [g.states for g in self.active]) + return merge_with(sum, [tg._states for tg in self.active]) def __repr__(self): return ( "<" - + self.name + + self._name + ": " + ", ".join( "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v @@ -621,24 +648,29 @@ def __repr__(self): @property def nbytes_in_memory(self): - return sum(tg.nbytes_in_memory for tg in self.groups) + tg: TaskGroup + return sum([tg._nbytes_in_memory for tg in self._groups]) @property def nbytes_total(self): - return sum(tg.nbytes_total for tg in self.groups) + tg: TaskGroup + return sum([tg._nbytes_total for tg in self._groups]) def __len__(self): - return sum(map(len, self.groups)) + return sum(map(len, self._groups)) @property def duration(self): - return sum(tg.duration for tg in self.groups) + tg: TaskGroup + return sum([tg._duration for tg in self._groups]) @property def types(self): - return set().union(*[tg.types for tg in self.groups]) + tg: TaskGroup + return set().union(*[tg._types for tg in self._groups]) +@cclass class TaskGroup: """Collection tracking all tasks within a group @@ -680,35 +712,79 @@ class TaskGroup: TaskPrefix """ - def __init__(self, name): - self.name = name - self.states = {state: 0 for state in ALL_TASK_STATES} - self.states["forgotten"] = 0 - self.dependencies = set() - self.nbytes_total = 0 - self.nbytes_in_memory = 0 - self.duration = 0 - self.types = set() - - def add(self, ts): - self.states[ts.state] += 1 - ts.group = self + _name: str + _prefix: TaskPrefix + _states: dict + _dependencies: set + _nbytes_total: Py_ssize_t + _nbytes_in_memory: Py_ssize_t + _duration: double + _types: set + + def __init__(self, name: str): + self._name = name + self._prefix = None + self._states = {state: 0 for state in ALL_TASK_STATES} + self._states["forgotten"] = 0 + self._dependencies = set() + self._nbytes_total = 0 + self._nbytes_in_memory = 0 + self._duration = 0 + self._types = set() + + @property + def name(self): + return self._name + + @property + def prefix(self): + return self._prefix + + @property + def states(self): + return self._states + + @property + def dependencies(self): + return self._dependencies + + @property + def nbytes_total(self): + return self._nbytes_total + + @property + def nbytes_in_memory(self): + return self._nbytes_in_memory + + @property + def duration(self): + return self._duration + + @property + def types(self): + return self._types + + def add(self, o): + ts: TaskState = o + self._states[ts.state] += 1 + ts._group = self def __repr__(self): return ( "<" - + (self.name or "no-group") + + (self._name or "no-group") + ": " + ", ".join( - "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v + "%s: %d" % (k, v) for (k, v) in sorted(self._states.items()) if v ) + ">" ) def __len__(self): - return sum(self.states.values()) + return sum(self._states.values()) +@cclass class TaskState: """ A simple object holding information about a task. 
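The pattern applied to TaskPrefix and TaskGroup above, and to TaskState below, is in the style of Cython's "pure Python" extension types: decorate the class, declare typed private fields, and expose read-only properties so that external callers keep the old attribute names while scheduler-internal code reads the `_`-prefixed fields directly. Because a C-typed field such as Py_ssize_t cannot hold None, "unset" becomes the sentinel -1, which is why the progress plugins now test ts.nbytes >= 0 rather than ts.nbytes is not None. A minimal sketch of the idiom, for illustration only (it is not the scheduler's actual compatibility shim, and assumes the cython package is importable; uncompiled, it runs as ordinary Python):

import cython

@cython.cclass
class SizedThing:
    _name: str
    _nbytes: cython.Py_ssize_t  # a C integer when compiled, so it cannot store None

    def __init__(self, name: str):
        self._name = name
        self._nbytes = -1  # sentinel: size not reported yet

    @property
    def nbytes(self):
        # read-only view under the old public attribute name
        return self._nbytes

    def get_nbytes(self) -> int:
        # same shape as TaskState.get_nbytes: fall back to a default when unknown
        default_data_size = 1000  # hypothetical default, not the scheduler's configured value
        return self._nbytes if self._nbytes >= 0 else default_data_size

When the module is not compiled, the decorator and the type annotations are inert, so the same source keeps working as plain Python.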
@@ -950,137 +1026,296 @@ class TaskState: Task annotations """ + _key: str + _hash: Py_hash_t + _prefix: TaskPrefix + _run_spec: object + _priority: tuple + _state: str + _dependencies: set + _dependents: set + _has_lost_dependencies: bool + _waiting_on: set + _waiters: set + _who_wants: set + _who_has: set + _processing_on: WorkerState + _retries: Py_ssize_t + _nbytes: Py_ssize_t + _type: str + _exception: object + _traceback: object + _exception_blame: object + _suspicious: Py_ssize_t + _host_restrictions: set + _worker_restrictions: set + _resource_restrictions: dict + _loose_restrictions: bool + _metadata: dict + _annotations: dict + _actor: bool + _group: TaskGroup + _group_key: str + __slots__ = ( # === General description === - "actor", + "_actor", # Key name - "key", + "_key", # Hash of the key name "_hash", # Key prefix (see key_split()) - "prefix", + "_prefix", # How to run the task (None if pure data) - "run_spec", + "_run_spec", # Alive dependents and dependencies - "dependencies", - "dependents", + "_dependencies", + "_dependents", # Compute priority - "priority", + "_priority", # Restrictions - "host_restrictions", - "worker_restrictions", # not WorkerStates but addresses - "resource_restrictions", - "loose_restrictions", + "_host_restrictions", + "_worker_restrictions", # not WorkerStates but addresses + "_resource_restrictions", + "_loose_restrictions", # === Task state === "_state", # Whether some dependencies were forgotten - "has_lost_dependencies", + "_has_lost_dependencies", # If in 'waiting' state, which tasks need to complete # before we can run - "waiting_on", + "_waiting_on", # If in 'waiting' or 'processing' state, which tasks needs us # to complete before they can run - "waiters", + "_waiters", # In in 'processing' state, which worker we are processing on - "processing_on", + "_processing_on", # If in 'memory' state, Which workers have us - "who_has", + "_who_has", # Which clients want us - "who_wants", - "exception", - "traceback", - "exception_blame", - "suspicious", - "retries", - "nbytes", - "type", - "group_key", - "group", - "metadata", - "annotations", + "_who_wants", + "_exception", + "_traceback", + "_exception_blame", + "_suspicious", + "_retries", + "_nbytes", + "_type", + "_group_key", + "_group", + "_metadata", + "_annotations", ) - def __init__(self, key, run_spec): - self.key = key + def __init__(self, key: str, run_spec: object): + self._key = key self._hash = hash(key) - self.run_spec = run_spec + self._run_spec = run_spec self._state = None - self.exception = self.traceback = self.exception_blame = None - self.suspicious = self.retries = 0 - self.nbytes = None - self.priority = None - self.who_wants = set() - self.dependencies = set() - self.dependents = set() - self.waiting_on = set() - self.waiters = set() - self.who_has = set() - self.processing_on = None - self.has_lost_dependencies = False - self.host_restrictions = None - self.worker_restrictions = None - self.resource_restrictions = None - self.loose_restrictions = False - self.actor = None - self.type = None - self.group_key = key_split_group(key) - self.group = None - self.metadata = {} - self.annotations = {} + self._exception = self._traceback = self._exception_blame = None + self._suspicious = self._retries = 0 + self._nbytes = -1 + self._priority = None + self._who_wants = set() + self._dependencies = set() + self._dependents = set() + self._waiting_on = set() + self._waiters = set() + self._who_has = set() + self._processing_on = None + self._has_lost_dependencies = False + 
self._host_restrictions = None + self._worker_restrictions = None + self._resource_restrictions = None + self._loose_restrictions = False + self._actor = None + self._type = None + self._group_key = key_split_group(key) + self._group = None + self._metadata = {} + self._annotations = {} def __hash__(self): return self._hash def __eq__(self, other): - return type(self) == type(other) and self.key == other.key + typ_self: type = type(self) + typ_other: type = type(other) + if typ_self == typ_other: + other_ts: TaskState = other + return self._key == other_ts._key + else: + return False @property - def state(self) -> str: - return self._state + def key(self): + return self._key @property - def prefix_key(self): - return self.prefix.name + def prefix(self): + return self._prefix + + @property + def run_spec(self): + return self._run_spec + + @property + def priority(self): + return self._priority + + @property + def state(self) -> str: + return self._state @state.setter def state(self, value: str): - self.group.states[self._state] -= 1 - self.group.states[value] += 1 + self._group._states[self._state] -= 1 + self._group._states[value] += 1 self._state = value + @property + def dependencies(self): + return self._dependencies + + @property + def dependents(self): + return self._dependents + + @property + def has_lost_dependencies(self): + return self._has_lost_dependencies + + @property + def waiting_on(self): + return self._waiting_on + + @property + def waiters(self): + return self._waiters + + @property + def who_wants(self): + return self._who_wants + + @property + def who_has(self): + return self._who_has + + @property + def processing_on(self): + return self._processing_on + + @processing_on.setter + def processing_on(self, v: WorkerState): + self._processing_on = v + + @property + def retries(self): + return self._retries + + @property + def nbytes(self): + return self._nbytes + + @nbytes.setter + def nbytes(self, v: Py_ssize_t): + self._nbytes = v + + @property + def type(self): + return self._type + + @property + def exception(self): + return self._exception + + @property + def traceback(self): + return self._traceback + + @property + def exception_blame(self): + return self._exception_blame + + @property + def suspicious(self): + return self._suspicious + + @property + def host_restrictions(self): + return self._host_restrictions + + @property + def worker_restrictions(self): + return self._worker_restrictions + + @property + def resource_restrictions(self): + return self._resource_restrictions + + @property + def loose_restrictions(self): + return self._loose_restrictions + + @property + def metadata(self): + return self._metadata + + @property + def annotations(self): + return self._annotations + + @property + def actor(self): + return self._actor + + @property + def group(self): + return self._group + + @property + def group_key(self): + return self._group_key + + @property + def prefix_key(self): + return self._prefix._name + def add_dependency(self, other: "TaskState"): """ Add another task as a dependency of this task """ - self.dependencies.add(other) - self.group.dependencies.add(other.group) - other.dependents.add(self) + self._dependencies.add(other) + self._group._dependencies.add(other._group) + other._dependents.add(self) def get_nbytes(self) -> int: - nbytes = self.nbytes - return nbytes if nbytes is not None else DEFAULT_DATA_SIZE - - def set_nbytes(self, nbytes: int): - old_nbytes = self.nbytes - diff = nbytes - (old_nbytes or 0) - self.group.nbytes_total += diff - 
self.group.nbytes_in_memory += diff + nbytes = self._nbytes + return nbytes if nbytes >= 0 else DEFAULT_DATA_SIZE + + def set_nbytes(self, nbytes: Py_ssize_t): + diff: Py_ssize_t = nbytes + old_nbytes: Py_ssize_t = self._nbytes + if old_nbytes >= 0: + diff -= old_nbytes + self._group._nbytes_total += diff + self._group._nbytes_in_memory += diff ws: WorkerState - for ws in self.who_has: + for ws in self._who_has: ws._nbytes += diff - self.nbytes = nbytes + self._nbytes = nbytes def __repr__(self): - return "<Task %r %s>" % (self.key, self.state) + return "<Task %r %s>" % (self._key, self._state) def validate(self): try: - for cs in self.who_wants: - assert isinstance(cs, ClientState), (repr(cs), self.who_wants) - for ws in self.who_has: - assert isinstance(ws, WorkerState), (repr(ws), self.who_has) - for ts in self.dependencies: - assert isinstance(ts, TaskState), (repr(ts), self.dependencies) - for ts in self.dependents: - assert isinstance(ts, TaskState), (repr(ts), self.dependents) + for cs in self._who_wants: + assert isinstance(cs, ClientState), (repr(cs), self._who_wants) + for ws in self._who_has: + assert isinstance(ws, WorkerState), (repr(ws), self._who_has) + for ts in self._dependencies: + assert isinstance(ts, TaskState), (repr(ts), self._dependencies) + for ts in self._dependents: + assert isinstance(ts, TaskState), (repr(ts), self._dependents) validate_task_state(self) except Exception as e: logger.exception(e) @@ -1168,7 +1403,8 @@ def _legacy_task_key_set(tasks): """ Transform a set of task states into a set of task keys. """ - return {ts.key for ts in tasks} + ts: TaskState + return {ts._key for ts in tasks} def _legacy_client_key_set(clients): @@ -1191,7 +1427,8 @@ def _legacy_task_key_dict(task_dict): """ Transform a dict of {task state: value} into a dict of {task key: value}.
""" - return {ts.key: value for ts, value in task_dict.items()} + ts: TaskState + return {ts._key: value for ts, value in task_dict.items()} def _task_key_or_none(task): @@ -1975,7 +2212,7 @@ async def add_worker( if nbytes: for key in nbytes: - ts = self.tasks.get(key) + ts: TaskState = self.tasks.get(key) if ts is not None and ts.state in ("processing", "waiting"): recommendations = self.transition( key, @@ -1990,7 +2227,7 @@ async def add_worker( for ts in list(self.unrunnable): valid = self.valid_workers(ts) if valid is True or ws in valid: - recommendations[ts.key] = "waiting" + recommendations[ts._key] = "waiting" if recommendations: self.transitions(recommendations) @@ -2128,11 +2365,15 @@ def update_graph( self.client_releases_keys(keys=[k], client=client) # Avoid computation that is already finished + ts: TaskState already_in_memory = set() # tasks that are already done for k, v in dependencies.items(): - if v and k in self.tasks and self.tasks[k].state in ("memory", "erred"): - already_in_memory.add(k) + if v and k in self.tasks: + ts = self.tasks[k] + if ts.state in ("memory", "erred"): + already_in_memory.add(k) + dts: TaskState if already_in_memory: dependents = dask.core.reverse_dict(dependencies) stack = list(already_in_memory) @@ -2170,8 +2411,8 @@ def update_graph( ts = self.tasks.get(k) if ts is None: ts = self.new_task(k, tasks.get(k), "released") - elif not ts.run_spec: - ts.run_spec = tasks.get(k) + elif not ts._run_spec: + ts._run_spec = tasks.get(k) touched_keys.add(k) touched_tasks.append(ts) @@ -2182,7 +2423,7 @@ def update_graph( # Add dependencies for key, deps in dependencies.items(): ts = self.tasks.get(key) - if ts is None or ts.dependencies: + if ts is None or ts._dependencies: continue for dep in deps: dts = self.tasks[dep] @@ -2219,13 +2460,15 @@ def update_graph( for a, kv in annotations.items(): for k, v in kv.items(): - self.tasks[k].annotations[a] = v + ts = self.tasks[k] + ts._annotations[a] = v # Add actors if actors is True: actors = list(keys) for actor in actors or []: - self.tasks[actor].actor = True + ts = self.tasks[actor] + ts._actor = True priority = priority or dask.order.order( tasks @@ -2234,7 +2477,7 @@ def update_graph( if submitting_task: # sub-tasks get better priority than parent tasks ts = self.tasks.get(submitting_task) if ts is not None: - generation = ts.priority[0] - 0.01 + generation = ts._priority[0] - 0.01 else: # super-task already cleaned up generation = self.generation elif self._last_time + fifo_timeout < start: @@ -2246,14 +2489,14 @@ def update_graph( for key in set(priority) & touched_keys: ts = self.tasks[key] - if ts.priority is None: - ts.priority = (-(user_priority.get(key, 0)), generation, priority[key]) + if ts._priority is None: + ts._priority = (-(user_priority.get(key, 0)), generation, priority[key]) # Ensure all runnables have a priority - runnables = [ts for ts in touched_tasks if ts.run_spec] + runnables = [ts for ts in touched_tasks if ts._run_spec] for ts in runnables: - if ts.priority is None and ts.run_spec: - ts.priority = (self.generation, 0) + if ts._priority is None and ts._run_spec: + ts._priority = (self.generation, 0) if restrictions: # *restrictions* is a dict keying task ids to lists of @@ -2264,21 +2507,21 @@ def update_graph( ts = self.tasks.get(k) if ts is None: continue - ts.host_restrictions = set() - ts.worker_restrictions = set() + ts._host_restrictions = set() + ts._worker_restrictions = set() for w in v: try: w = self.coerce_address(w) except ValueError: # Not a valid address, but 
perhaps it's a hostname - ts.host_restrictions.add(w) + ts._host_restrictions.add(w) else: - ts.worker_restrictions.add(w) + ts._worker_restrictions.add(w) if loose_restrictions: for k in loose_restrictions: ts = self.tasks[k] - ts.loose_restrictions = True + ts._loose_restrictions = True if resources: for k, v in resources.items(): @@ -2288,7 +2531,7 @@ def update_graph( ts = self.tasks.get(k) if ts is None: continue - ts.resource_restrictions = v + ts._resource_restrictions = v if retries: for k, v in retries.items(): @@ -2296,20 +2539,20 @@ def update_graph( ts = self.tasks.get(k) if ts is None: continue - ts.retries = v + ts._retries = v # Compute recommendations recommendations = {} for ts in sorted(runnables, key=operator.attrgetter("priority"), reverse=True): - if ts.state == "released" and ts.run_spec: - recommendations[ts.key] = "waiting" + if ts.state == "released" and ts._run_spec: + recommendations[ts._key] = "waiting" for ts in touched_tasks: - for dts in ts.dependencies: - if dts.exception_blame: - ts.exception_blame = dts.exception_blame - recommendations[ts.key] = "erred" + for dts in ts._dependencies: + if dts._exception_blame: + ts._exception_blame = dts._exception_blame + recommendations[ts._key] = "erred" break for plugin in self.plugins[:]: @@ -2333,7 +2576,7 @@ def update_graph( for ts in touched_tasks: if ts.state in ("memory", "erred"): - self.report_on_key(ts.key, client=client) + self.report_on_key(ts._key, client=client) end = time() if self.digests is not None: @@ -2343,22 +2586,24 @@ def update_graph( def new_task(self, key, spec, state): """ Create a new task, and associated states """ - ts = TaskState(key, spec) + ts: TaskState = TaskState(key, spec) + tp: TaskPrefix + tg: TaskGroup ts._state = state prefix_key = key_split(key) try: tp = self.task_prefixes[prefix_key] except KeyError: - tp = self.task_prefixes[prefix_key] = TaskPrefix(prefix_key) - ts.prefix = tp + self.task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key) + ts._prefix = tp - group_key = ts.group_key + group_key = ts._group_key try: tg = self.task_groups[group_key] except KeyError: - tg = self.task_groups[group_key] = TaskGroup(group_key) - tg.prefix = tp - tp.groups.append(tg) + self.task_groups[group_key] = tg = TaskGroup(group_key) + tg._prefix = tp + tp._groups.append(tg) tg.add(ts) self.tasks[key] = ts return ts @@ -2367,17 +2612,17 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): """ Mark that a task has finished execution on a particular worker """ logger.debug("Stimulus task finished %s, %s", key, worker) - ts = self.tasks.get(key) + ts: TaskState = self.tasks.get(key) if ts is None: return {} ws: WorkerState = self.workers[worker] - ts.metadata.update(kwargs["metadata"]) + ts._metadata.update(kwargs["metadata"]) if ts.state == "processing": recommendations = self.transition(key, "memory", worker=worker, **kwargs) if ts.state == "memory": - assert ws in ts.who_has + assert ws in ts._who_has else: logger.debug( "Received already computed task, worker: %s, state: %s" @@ -2385,9 +2630,9 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): worker, ts.state, key, - ts.who_has, + ts._who_has, ) - if ws not in ts.who_has: + if ws not in ts._who_has: self.worker_send(worker, {"op": "release-task", "key": key}) recommendations = {} @@ -2399,14 +2644,14 @@ def stimulus_task_erred( """ Mark that a task has erred on a particular worker """ logger.debug("Stimulus task erred %s, %s", key, worker) - ts = self.tasks.get(key) + ts: TaskState = 
self.tasks.get(key) if ts is None: return {} if ts.state == "processing": - retries = ts.retries + retries = ts._retries if retries > 0: - ts.retries = retries - 1 + ts._retries = retries - 1 recommendations = self.transition(key, "waiting") else: recommendations = self.transition( @@ -2430,19 +2675,19 @@ def stimulus_missing_data( with log_errors(): logger.debug("Stimulus missing data %s, %s", key, worker) - ts = self.tasks.get(key) + ts: TaskState = self.tasks.get(key) if ts is None or ts.state == "memory": return {} - cts = self.tasks.get(cause) + cts: TaskState = self.tasks.get(cause) recommendations = {} if cts is not None and cts.state == "memory": # couldn't find this ws: WorkerState - for ws in cts.who_has: # TODO: this behavior is extreme + for ws in cts._who_has: # TODO: this behavior is extreme ws._has_what.remove(cts) ws._nbytes -= cts.get_nbytes() - cts.who_has.clear() + cts._who_has.clear() recommendations[cause] = "released" if key: @@ -2463,12 +2708,13 @@ def stimulus_retry(self, comm=None, keys=None, client=None): stack = list(keys) seen = set() roots = [] + ts: TaskState + dts: TaskState while stack: key = stack.pop() seen.add(key) - erred_deps = [ - dts.key for dts in self.tasks[key].dependencies if dts.state == "erred" - ] + ts = self.tasks[key] + erred_deps = [dts._key for dts in ts._dependencies if dts.state == "erred"] if erred_deps: stack.extend(erred_deps) else: @@ -2537,13 +2783,14 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): recommendations = {} + ts: TaskState for ts in list(ws._processing): - k = ts.key + k = ts._key recommendations[k] = "released" if not safe: - ts.suspicious += 1 - ts.prefix.suspicious += 1 - if ts.suspicious > self.allowed_failures: + ts._suspicious += 1 + ts._prefix._suspicious += 1 + if ts._suspicious > self.allowed_failures: del recommendations[k] e = pickle.dumps( KilledWorker(task=k, last_worker=ws.clean()), protocol=4 @@ -2553,17 +2800,17 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): logger.info( "Task %s marked as failed because %d workers died" " while trying to run it", - ts.key, + ts._key, self.allowed_failures, ) for ts in ws._has_what: - ts.who_has.remove(ws) - if not ts.who_has: - if ts.run_spec: - recommendations[ts.key] = "released" + ts._who_has.remove(ws) + if not ts._who_has: + if ts._run_spec: + recommendations[ts._key] = "released" else: # pure data - recommendations[ts.key] = "forgotten" + recommendations[ts._key] = "forgotten" ws._has_what.clear() self.transitions(recommendations) @@ -2609,23 +2856,24 @@ def stimulus_cancel(self, comm, keys=None, client=None, force=False): def cancel_key(self, key, client, retries=5, force=False): """ Cancel a particular key and all dependents """ # TODO: this should be converted to use the transition mechanism - ts = self.tasks.get(key) + ts: TaskState = self.tasks.get(key) + dts: TaskState try: cs: ClientState = self.clients[client] except KeyError: return - if ts is None or not ts.who_wants: # no key yet, lets try again in a moment + if ts is None or not ts._who_wants: # no key yet, lets try again in a moment if retries: self.loop.call_later( 0.2, lambda: self.cancel_key(key, client, retries - 1) ) return - if force or ts.who_wants == {cs}: # no one else wants this key - for dts in list(ts.dependents): - self.cancel_key(dts.key, client, force=force) + if force or ts._who_wants == {cs}: # no one else wants this key + for dts in list(ts._dependents): + self.cancel_key(dts._key, client, force=force) 
logger.info("Scheduler cancels key %s. Force=%s", key, force) self.report({"op": "cancelled-key", "key": key}) - clients = list(ts.who_wants) if force else [cs] + clients = list(ts._who_wants) if force else [cs] for cs in clients: self.client_releases_keys(keys=[key], client=cs._client_key) @@ -2634,12 +2882,13 @@ def client_desires_keys(self, keys=None, client=None): if cs is None: # For publish, queues etc. self.clients[client] = cs = ClientState(client) + ts: TaskState for k in keys: ts = self.tasks.get(k) if ts is None: # For publish, queues etc. ts = self.new_task(k, None, "released") - ts.who_wants.add(cs) + ts._who_wants.add(cs) cs._wants_what.add(ts) if ts.state in ("memory", "erred"): @@ -2649,23 +2898,24 @@ def client_releases_keys(self, keys=None, client=None): """ Remove keys from client desired list """ logger.debug("Client %s releases keys: %s", client, keys) cs: ClientState = self.clients[client] + ts: TaskState tasks2 = set() for key in list(keys): ts = self.tasks.get(key) if ts is not None and ts in cs._wants_what: cs._wants_what.remove(ts) - s = ts.who_wants + s = ts._who_wants s.remove(cs) if not s: tasks2.add(ts) recommendations = {} for ts in tasks2: - if not ts.dependents: + if not ts._dependents: # No live dependents, can forget - recommendations[ts.key] = "forgotten" - elif ts.state != "erred" and not ts.waiters: - recommendations[ts.key] = "released" + recommendations[ts._key] = "forgotten" + elif ts.state != "erred" and not ts._waiters: + recommendations[ts._key] = "released" self.transitions(recommendations) @@ -2679,63 +2929,68 @@ def client_heartbeat(self, client=None): ################### def validate_released(self, key): - ts = self.tasks[key] + ts: TaskState = self.tasks[key] + dts: TaskState assert ts.state == "released" - assert not ts.waiters - assert not ts.waiting_on - assert not ts.who_has - assert not ts.processing_on - assert not any(ts in dts.waiters for dts in ts.dependencies) + assert not ts._waiters + assert not ts._waiting_on + assert not ts._who_has + assert not ts._processing_on + assert not any([ts in dts._waiters for dts in ts._dependencies]) assert ts not in self.unrunnable def validate_waiting(self, key): - ts = self.tasks[key] - assert ts.waiting_on - assert not ts.who_has - assert not ts.processing_on + ts: TaskState = self.tasks[key] + dts: TaskState + assert ts._waiting_on + assert not ts._who_has + assert not ts._processing_on assert ts not in self.unrunnable - for dts in ts.dependencies: + for dts in ts._dependencies: # We are waiting on a dependency iff it's not stored - assert bool(dts.who_has) + (dts in ts.waiting_on) == 1 - assert ts in dts.waiters # XXX even if dts.who_has? + assert bool(dts._who_has) + (dts in ts._waiting_on) == 1 + assert ts in dts._waiters # XXX even if dts._who_has? 
def validate_processing(self, key): - ts = self.tasks[key] - assert not ts.waiting_on - ws: WorkerState = ts.processing_on + ts: TaskState = self.tasks[key] + dts: TaskState + assert not ts._waiting_on + ws: WorkerState = ts._processing_on assert ws assert ts in ws._processing - assert not ts.who_has - for dts in ts.dependencies: - assert dts.who_has - assert ts in dts.waiters + assert not ts._who_has + for dts in ts._dependencies: + assert dts._who_has + assert ts in dts._waiters def validate_memory(self, key): - ts = self.tasks[key] - assert ts.who_has - assert not ts.processing_on - assert not ts.waiting_on + ts: TaskState = self.tasks[key] + dts: TaskState + assert ts._who_has + assert not ts._processing_on + assert not ts._waiting_on assert ts not in self.unrunnable - for dts in ts.dependents: - assert (dts in ts.waiters) == (dts.state in ("waiting", "processing")) - assert ts not in dts.waiting_on + for dts in ts._dependents: + assert (dts in ts._waiters) == (dts.state in ("waiting", "processing")) + assert ts not in dts._waiting_on def validate_no_worker(self, key): - ts = self.tasks[key] + ts: TaskState = self.tasks[key] + dts: TaskState assert ts in self.unrunnable - assert not ts.waiting_on + assert not ts._waiting_on assert ts in self.unrunnable - assert not ts.processing_on - assert not ts.who_has - for dts in ts.dependencies: - assert dts.who_has + assert not ts._processing_on + assert not ts._who_has + for dts in ts._dependencies: + assert dts._who_has def validate_erred(self, key): - ts = self.tasks[key] - assert ts.exception_blame - assert not ts.who_has + ts: TaskState = self.tasks[key] + assert ts._exception_blame + assert not ts._who_has - def validate_key(self, key, ts=None): + def validate_key(self, key, ts: TaskState = None): try: if ts is None: ts = self.tasks.get(key) @@ -2774,9 +3029,10 @@ def validate_state(self, allow_overlap=False): assert not ws._occupancy assert ws in self.idle + ts: TaskState for k, ts in self.tasks.items(): assert isinstance(ts, TaskState), (type(ts), ts) - assert ts.key == k + assert ts._key == k self.validate_key(k, ts) c: str @@ -2808,7 +3064,7 @@ def validate_state(self, allow_overlap=False): # Manage Messages # ################### - def report(self, msg, ts=None, client=None): + def report(self, msg, ts: TaskState = None, client=None): """ Publish updates to all listening Queues and Comms @@ -2824,11 +3080,11 @@ def report(self, msg, ts=None, client=None): client_keys = list(self.client_comms) elif client is None: # Notify clients interested in key - client_keys = [cs._client_key for cs in ts.who_wants] + client_keys = [cs._client_key for cs in ts._who_wants] else: # Notify clients interested in key (including `client`) client_keys = [ - cs._client_key for cs in ts.who_wants if cs._client_key != client + cs._client_key for cs in ts._who_wants if cs._client_key != client ] client_keys.append(client) @@ -2903,8 +3159,9 @@ def remove_client(self, client=None): # XXX is this a legitimate condition? 
pass else: + ts: TaskState self.client_releases_keys( - keys=[ts.key for ts in cs._wants_what], client=cs._client_key + keys=[ts._key for ts in cs._wants_what], client=cs._client_key ) del self.clients[client] @@ -2927,31 +3184,32 @@ def remove_client_from_events(): def send_task_to_worker(self, worker, key): """ Send a single computational task to a worker """ try: - ts = self.tasks[key] + ts: TaskState = self.tasks[key] + dts: TaskState msg = { "op": "compute-task", "key": key, - "priority": ts.priority, + "priority": ts._priority, "duration": self.get_task_duration(ts), } - if ts.resource_restrictions: - msg["resource_restrictions"] = ts.resource_restrictions - if ts.actor: + if ts._resource_restrictions: + msg["resource_restrictions"] = ts._resource_restrictions + if ts._actor: msg["actor"] = True - deps = ts.dependencies + deps = ts._dependencies if deps: ws: WorkerState msg["who_has"] = { - dep.key: [ws._address for ws in dep.who_has] for dep in deps + dts._key: [ws._address for ws in dts._who_has] for dts in deps } - msg["nbytes"] = {dep.key: dep.nbytes for dep in deps} + msg["nbytes"] = {dts._key: dts._nbytes for dts in deps} if self.validate and deps: assert all(msg["who_has"].values()) - task = ts.run_spec + task = ts._run_spec if type(task) is dict: msg.update(task) else: @@ -2981,11 +3239,11 @@ def handle_task_erred(self, key=None, **msg): self.transitions(r) def handle_release_data(self, key=None, worker=None, client=None, **msg): - ts = self.tasks.get(key) + ts: TaskState = self.tasks.get(key) if ts is None: return ws: WorkerState = self.workers[worker] - if ts.processing_on != ws: + if ts._processing_on != ws: return r = self.stimulus_missing_data(key=key, ensure=False, **msg) self.transitions(r) @@ -2994,17 +3252,17 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): logger.debug("handle missing data key=%s worker=%s", key, errant_worker) self.log.append(("missing", key, errant_worker)) - ts = self.tasks.get(key) - if ts is None or not ts.who_has: + ts: TaskState = self.tasks.get(key) + if ts is None or not ts._who_has: return if errant_worker in self.workers: ws: WorkerState = self.workers[errant_worker] - if ws in ts.who_has: - ts.who_has.remove(ws) + if ws in ts._who_has: + ts._who_has.remove(ws) ws._has_what.remove(ts) ws._nbytes -= ts.get_nbytes() - if not ts.who_has: - if ts.run_spec: + if not ts._who_has: + if ts._run_spec: self.transitions({key: "released"}) else: self.transitions({key: "forgotten"}) @@ -3015,13 +3273,14 @@ def release_worker_data(self, comm=None, keys=None, worker=None): removed_tasks = tasks & ws._has_what ws._has_what -= removed_tasks + ts: TaskState recommendations = {} for ts in removed_tasks: ws._nbytes -= ts.get_nbytes() - wh = ts.who_has + wh = ts._who_has wh.remove(ws) if not wh: - recommendations[ts.key] = "released" + recommendations[ts._key] = "released" if recommendations: self.transitions(recommendations) @@ -3031,24 +3290,24 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): We stop the task from being stolen in the future, and change task duration accounting as if the task has stopped. """ - ts = self.tasks[key] + ts: TaskState = self.tasks[key] if "stealing" in self.extensions: self.extensions["stealing"].remove_key_from_stealable(ts) - ws: WorkerState = ts.processing_on + ws: WorkerState = ts._processing_on if ws is None: logger.debug("Received long-running signal from duplicate task. 
Ignoring.") return if compute_duration: - old_duration = ts.prefix.duration_average or 0 + old_duration = ts._prefix._duration_average new_duration = compute_duration - if not old_duration: + if old_duration < 0: avg_duration = new_duration else: avg_duration = 0.5 * old_duration + 0.5 * new_duration - ts.prefix.duration_average = avg_duration + ts._prefix._duration_average = avg_duration ws._occupancy -= ws._processing[ts] self.total_occupancy -= ws._processing[ts] @@ -3163,9 +3422,9 @@ async def gather(self, comm=None, keys=None, serializers=None): keys = list(keys) who_has = {} for key in keys: - ts = self.tasks.get(key) + ts: TaskState = self.tasks.get(key) if ts is not None: - who_has[key] = [ws._address for ws in ts.who_has] + who_has[key] = [ws._address for ws in ts._who_has] else: who_has[key] = [] @@ -3198,7 +3457,7 @@ async def gather(self, comm=None, keys=None, serializers=None): for key, workers in missing_keys.items(): # Task may already be gone if it was held by a # `missing_worker` - ts = self.tasks.get(key) + ts: TaskState = self.tasks.get(key) logger.exception( "Workers don't have promised key: %s, %s", str(workers), @@ -3210,7 +3469,7 @@ async def gather(self, comm=None, keys=None, serializers=None): ws = self.workers.get(worker) if ws is not None and ts in ws._has_what: ws._has_what.remove(ts) - ts.who_has.remove(ws) + ts._who_has.remove(ws) ws._nbytes -= ts.get_nbytes() self.transitions({key: "released"}) @@ -3232,9 +3491,10 @@ async def restart(self, client=None, timeout=3): logger.info("Send lost future signal to clients") cs: ClientState + ts: TaskState for cs in self.clients.values(): self.client_releases_keys( - keys=[ts.key for ts in cs._wants_what], client=cs._client_key + keys=[ts._key for ts in cs._wants_what], client=cs._client_key ) ws: WorkerState @@ -3365,10 +3625,11 @@ async def _delete_worker_data(self, worker_address, keys): ) ws: WorkerState = self.workers[worker_address] + ts: TaskState tasks = {self.tasks[key] for key in keys} ws._has_what -= tasks for ts in tasks: - ts.who_has.remove(ws) + ts._who_has.remove(ws) ws._nbytes -= ts.get_nbytes() self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys}) @@ -3383,11 +3644,12 @@ async def rebalance(self, comm=None, keys=None, workers=None): occupied worker until either the sender or the recipient are at the average expected load. 
""" + ts: TaskState with log_errors(): async with self._lock: if keys: tasks = {self.tasks[k] for k in keys} - missing_data = [ts.key for ts in tasks if not ts.who_has] + missing_data = [ts._key for ts in tasks if not ts._who_has] if missing_data: return {"status": "missing-data", "keys": missing_data} else: @@ -3395,10 +3657,10 @@ async def rebalance(self, comm=None, keys=None, workers=None): if workers: workers = {self.workers[w] for w in workers} - workers_by_task = {ts: ts.who_has & workers for ts in tasks} + workers_by_task = {ts: ts._who_has & workers for ts in tasks} else: workers = set(self.workers.values()) - workers_by_task = {ts: ts.who_has for ts in tasks} + workers_by_task = {ts: ts._who_has for ts in tasks} ws: WorkerState tasks_by_worker = {ws: set() for ws in workers} @@ -3450,8 +3712,8 @@ async def rebalance(self, comm=None, keys=None, workers=None): to_recipients = defaultdict(lambda: defaultdict(list)) to_senders = defaultdict(list) for sender, recipient, ts in msgs: - to_recipients[recipient.address][ts.key].append(sender.address) - to_senders[sender.address].append(ts.key) + to_recipients[recipient.address][ts._key].append(sender.address) + to_senders[sender.address].append(ts._key) result = await asyncio.gather( *( @@ -3487,11 +3749,17 @@ async def rebalance(self, comm=None, keys=None, workers=None): for sender, recipient, ts in msgs: assert ts.state == "memory" - ts.who_has.add(recipient) + ts._who_has.add(recipient) recipient.has_what.add(ts) recipient.nbytes += ts.get_nbytes() self.log.append( - ("rebalance", ts.key, time(), sender.address, recipient.address) + ( + "rebalance", + ts._key, + time(), + sender.address, + recipient.address, + ) ) await asyncio.gather( @@ -3533,6 +3801,7 @@ async def replicate( """ ws: WorkerState wws: WorkerState + ts: TaskState assert branching_factor > 0 async with self._lock if lock else empty_context: @@ -3545,7 +3814,7 @@ async def replicate( raise ValueError("Can not use replicate to delete data") tasks = {self.tasks[k] for k in keys} - missing_data = [ts.key for ts in tasks if not ts.who_has] + missing_data = [ts._key for ts in tasks if not ts._who_has] if missing_data: return {"status": "missing-data", "keys": missing_data} @@ -3553,7 +3822,7 @@ async def replicate( if delete: del_worker_tasks = defaultdict(set) for ts in tasks: - del_candidates = ts.who_has & workers + del_candidates = ts._who_has & workers if len(del_candidates) > n: for ws in random.sample( del_candidates, len(del_candidates) - n @@ -3575,18 +3844,18 @@ async def replicate( # task is no longer needed by any client or dependant task tasks.remove(ts) continue - n_missing = n - len(ts.who_has & workers) + n_missing = n - len(ts._who_has & workers) if n_missing <= 0: # Already replicated enough tasks.remove(ts) continue - count = min(n_missing, branching_factor * len(ts.who_has)) + count = min(n_missing, branching_factor * len(ts._who_has)) assert count > 0 - for ws in random.sample(workers - ts.who_has, count): - gathers[ws._address][ts.key] = [ - wws._address for wws in ts.who_has + for ws in random.sample(workers - ts._who_has, count): + gathers[ws._address][ts._key] = [ + wws._address for wws in ts._who_has ] results = await asyncio.gather( @@ -3788,6 +4057,7 @@ async def retire_workers( Scheduler.workers_to_close """ ws: WorkerState + ts: TaskState with log_errors(): async with self._lock if lock else empty_context: if names is not None: @@ -3820,7 +4090,7 @@ async def retire_workers( # Keys orphaned by retiring those workers keys = set.union(*[w.has_what 
for w in workers]) - keys = {ts.key for ts in keys if ts.who_has.issubset(workers)} + keys = {ts._key for ts in keys if ts._who_has.issubset(workers)} other_workers = set(self.workers.values()) - workers if keys: @@ -3869,12 +4139,12 @@ def add_keys(self, comm=None, worker=None, keys=()): return "not found" ws: WorkerState = self.workers[worker] for key in keys: - ts = self.tasks.get(key) + ts: TaskState = self.tasks.get(key) if ts is not None and ts.state == "memory": if ts not in ws._has_what: ws._nbytes += ts.get_nbytes() ws._has_what.add(ts) - ts.who_has.add(ws) + ts._who_has.add(ws) else: self.worker_send( worker, {"op": "delete-data", "keys": [key], "report": False} @@ -3899,9 +4169,9 @@ def update_data( logger.debug("Update data %s", who_has) for key, workers in who_has.items(): - ts = self.tasks.get(key) + ts: TaskState = self.tasks.get(key) if ts is None: - ts = self.new_task(key, None, "memory") + ts: TaskState = self.new_task(key, None, "memory") ts.state = "memory" if key in nbytes: ts.set_nbytes(nbytes[key]) @@ -3910,7 +4180,7 @@ def update_data( if ts not in ws._has_what: ws._nbytes += ts.get_nbytes() ws._has_what.add(ts) - ts.who_has.add(ws) + ts._who_has.add(ws) self.report( {"op": "key-in-memory", "key": key, "workers": list(workers)} ) @@ -3918,7 +4188,7 @@ def update_data( if client: self.client_desires_keys(keys=list(who_has), client=client) - def report_on_key(self, key=None, ts=None, client=None): + def report_on_key(self, key=None, ts: TaskState = None, client=None): assert (key is None) + (ts is None) == 1, (key, ts) if ts is None: try: @@ -3927,19 +4197,19 @@ def report_on_key(self, key=None, ts=None, client=None): self.report({"op": "cancelled-key", "key": key}, client=client) return else: - key = ts.key + key = ts._key if ts.state == "forgotten": self.report({"op": "cancelled-key", "key": key}, ts=ts, client=client) elif ts.state == "memory": self.report({"op": "key-in-memory", "key": key}, ts=ts, client=client) elif ts.state == "erred": - failing_ts = ts.exception_blame + failing_ts: TaskState = ts._exception_blame self.report( { "op": "task-erred", "key": key, - "exception": failing_ts.exception, - "traceback": failing_ts.traceback, + "exception": failing_ts._exception, + "traceback": failing_ts._traceback, }, ts=ts, client=client, @@ -4001,16 +4271,18 @@ def subscribe_worker_status(self, comm=None): def get_processing(self, comm=None, workers=None): ws: WorkerState + ts: TaskState if workers is not None: workers = set(map(self.coerce_address, workers)) - return {w: [ts.key for ts in self.workers[w].processing] for w in workers} + return {w: [ts._key for ts in self.workers[w].processing] for w in workers} else: return { - w: [ts.key for ts in ws._processing] for w, ws in self.workers.items() + w: [ts._key for ts in ws._processing] for w, ws in self.workers.items() } def get_who_has(self, comm=None, keys=None): ws: WorkerState + ts: TaskState if keys is not None: return { k: [ws._address for ws in self.tasks[k].who_has] @@ -4020,23 +4292,24 @@ def get_who_has(self, comm=None, keys=None): } else: return { - key: [ws._address for ws in ts.who_has] + key: [ws._address for ws in ts._who_has] for key, ts in self.tasks.items() } def get_has_what(self, comm=None, workers=None): ws: WorkerState + ts: TaskState if workers is not None: workers = map(self.coerce_address, workers) return { - w: [ts.key for ts in self.workers[w].has_what] + w: [ts._key for ts in self.workers[w].has_what] if w in self.workers else [] for w in workers } else: return { - w: [ts.key for ts in 
ws._has_what] for w, ws in self.workers.items() + w: [ts._key for ts in ws._has_what] for w, ws in self.workers.items() } def get_ncores(self, comm=None, workers=None): @@ -4048,6 +4321,8 @@ def get_ncores(self, comm=None, workers=None): return {w: ws._nthreads for w, ws in self.workers.items()} async def get_call_stack(self, comm=None, keys=None): + ts: TaskState + dts: TaskState if keys is not None: stack = list(keys) processing = set() @@ -4055,14 +4330,14 @@ async def get_call_stack(self, comm=None, keys=None): key = stack.pop() ts = self.tasks[key] if ts.state == "waiting": - stack.extend(dts.key for dts in ts.dependencies) + stack.extend([dts._key for dts in ts._dependencies]) elif ts.state == "processing": processing.add(ts) workers = defaultdict(list) for ts in processing: - if ts.processing_on: - workers[ts.processing_on.address].append(ts.key) + if ts._processing_on: + workers[ts._processing_on.address].append(ts._key) else: workers = {w: None for w in self.workers} @@ -4076,14 +4351,13 @@ async def get_call_stack(self, comm=None, keys=None): return response def get_nbytes(self, comm=None, keys=None, summary=True): + ts: TaskState with log_errors(): if keys is not None: result = {k: self.tasks[k].nbytes for k in keys} else: result = { - k: ts.nbytes - for k, ts in self.tasks.items() - if ts.nbytes is not None + k: ts._nbytes for k, ts in self.tasks.items() if ts._nbytes >= 0 } if summary: @@ -4094,23 +4368,25 @@ def get_nbytes(self, comm=None, keys=None, summary=True): return result - def get_comm_cost(self, ts, ws: WorkerState): + def get_comm_cost(self, ts: TaskState, ws: WorkerState): """ Get the estimated communication cost (in s.) to compute the task on the given worker. """ + dts: TaskState return ( - sum(dts.nbytes for dts in ts.dependencies - ws._has_what) / self.bandwidth + sum([dts._nbytes for dts in ts._dependencies - ws._has_what]) + / self.bandwidth ) - def get_task_duration(self, ts, default=None): + def get_task_duration(self, ts: TaskState, default=None): """ Get the estimated computation cost of the given task (not including any communication cost). """ - duration = ts.prefix.duration_average - if duration is None: - self.unknown_durations[ts.prefix.name].add(ts) + duration = ts._prefix._duration_average + if duration < 0: + self.unknown_durations[ts._prefix._name].add(ts) if default is None: default = parse_timedelta( dask.config.get("distributed.scheduler.unknown-task-duration") @@ -4165,8 +4441,8 @@ def get_task_stream(self, comm=None, start=None, stop=None, count=None): from distributed.diagnostics.task_stream import TaskStreamPlugin self.add_plugin(TaskStreamPlugin, idempotent=True) - ts = [p for p in self.plugins if isinstance(p, TaskStreamPlugin)][0] - return ts.collect(start=start, stop=stop, count=count) + tsp = [p for p in self.plugins if isinstance(p, TaskStreamPlugin)][0] + return tsp.collect(start=start, stop=stop, count=count) def start_task_metadata(self, comm=None, name=None): plugin = CollectTaskMetaDataPlugin(scheduler=self, name=name) @@ -4202,12 +4478,12 @@ async def register_worker_plugin(self, comm, plugin, name=None): # State Transitions # ##################### - def _remove_from_processing(self, ts, send_worker_msg=None): + def _remove_from_processing(self, ts: TaskState, send_worker_msg=None): """ Remove *ts* from the set of processing tasks. 
""" - ws: WorkerState = ts.processing_on - ts.processing_on = None + ws: WorkerState = ts._processing_on + ts._processing_on = None w = ws._address if w in self.workers: # may have been removed duration = ws._processing.pop(ts) @@ -4223,7 +4499,13 @@ def _remove_from_processing(self, ts, send_worker_msg=None): self.worker_send(w, send_worker_msg) def _add_to_memory( - self, ts, ws: WorkerState, recommendations, type=None, typename=None, **kwargs + self, + ts: TaskState, + ws: WorkerState, + recommendations, + type=None, + typename=None, + **kwargs, ): """ Add *ts* to the set of in-memory tasks. @@ -4231,78 +4513,81 @@ def _add_to_memory( if self.validate: assert ts not in ws._has_what - ts.who_has.add(ws) + ts._who_has.add(ws) ws._has_what.add(ts) ws._nbytes += ts.get_nbytes() - deps = ts.dependents + deps = ts._dependents if len(deps) > 1: deps = sorted(deps, key=operator.attrgetter("priority"), reverse=True) + dts: TaskState for dts in deps: - s = dts.waiting_on + s = dts._waiting_on if ts in s: s.discard(ts) if not s: # new task ready to run - recommendations[dts.key] = "processing" + recommendations[dts._key] = "processing" - for dts in ts.dependencies: - s = dts.waiters + for dts in ts._dependencies: + s = dts._waiters s.discard(ts) - if not s and not dts.who_wants: - recommendations[dts.key] = "released" + if not s and not dts._who_wants: + recommendations[dts._key] = "released" - if not ts.waiters and not ts.who_wants: - recommendations[ts.key] = "released" + if not ts._waiters and not ts._who_wants: + recommendations[ts._key] = "released" else: - msg = {"op": "key-in-memory", "key": ts.key} + msg = {"op": "key-in-memory", "key": ts._key} if type is not None: msg["type"] = type self.report(msg) ts.state = "memory" - ts.type = typename - ts.group.types.add(typename) + ts._type = typename + ts._group._types.add(typename) cs: ClientState = self.clients["fire-and-forget"] if ts in cs._wants_what: - self.client_releases_keys(client="fire-and-forget", keys=[ts.key]) + self.client_releases_keys(client="fire-and-forget", keys=[ts._key]) def transition_released_waiting(self, key): try: - ts = self.tasks[key] + ts: TaskState = self.tasks[key] + dts: TaskState if self.validate: - assert ts.run_spec - assert not ts.waiting_on - assert not ts.who_has - assert not ts.processing_on - assert not any(dts.state == "forgotten" for dts in ts.dependencies) + assert ts._run_spec + assert not ts._waiting_on + assert not ts._who_has + assert not ts._processing_on + assert not any([dts.state == "forgotten" for dts in ts._dependencies]) - if ts.has_lost_dependencies: + if ts._has_lost_dependencies: return {key: "forgotten"} ts.state = "waiting" recommendations = {} - for dts in ts.dependencies: - if dts.exception_blame: - ts.exception_blame = dts.exception_blame + dts: TaskState + for dts in ts._dependencies: + if dts._exception_blame: + ts._exception_blame = dts._exception_blame recommendations[key] = "erred" return recommendations - for dts in ts.dependencies: - dep = dts.key - if not dts.who_has: - ts.waiting_on.add(dts) + for dts in ts._dependencies: + dep = dts._key + if not dts._who_has: + ts._waiting_on.add(dts) if dts.state == "released": recommendations[dep] = "waiting" else: - dts.waiters.add(ts) + dts._waiters.add(ts) - ts.waiters = {dts for dts in ts.dependents if dts.state == "waiting"} + ts._waiters = {dts for dts in ts._dependents if dts.state == "waiting"} - if not ts.waiting_on: + if not ts._waiting_on: if self.workers: recommendations[key] = "processing" else: @@ -4320,33 +4605,34 @@ def 
transition_released_waiting(self, key): def transition_no_worker_waiting(self, key): try: - ts = self.tasks[key] + ts: TaskState = self.tasks[key] + dts: TaskState if self.validate: assert ts in self.unrunnable - assert not ts.waiting_on - assert not ts.who_has - assert not ts.processing_on + assert not ts._waiting_on + assert not ts._who_has + assert not ts._processing_on self.unrunnable.remove(ts) - if ts.has_lost_dependencies: + if ts._has_lost_dependencies: return {key: "forgotten"} recommendations = {} - for dts in ts.dependencies: - dep = dts.key - if not dts.who_has: - ts.waiting_on.add(dts) + for dts in ts._dependencies: + dep = dts._key + if not dts._who_has: + ts._waiting_on.add(dts) if dts.state == "released": recommendations[dep] = "waiting" else: - dts.waiters.add(ts) + dts._waiters.add(ts) ts.state = "waiting" - if not ts.waiting_on: + if not ts._waiting_on: if self.workers: recommendations[key] = "processing" else: @@ -4362,18 +4648,18 @@ def transition_no_worker_waiting(self, key): pdb.set_trace() raise - def decide_worker(self, ts): + def decide_worker(self, ts: TaskState): """ Decide on a worker for task *ts*. Return a WorkerState. """ valid_workers = self.valid_workers(ts) - if not valid_workers and not ts.loose_restrictions and self.workers: + if not valid_workers and not ts._loose_restrictions and self.workers: self.unrunnable.add(ts) ts.state = "no-worker" return None - if ts.dependencies or valid_workers is not True: + if ts._dependencies or valid_workers is not True: worker = decide_worker( ts, self.workers.values(), @@ -4404,16 +4690,17 @@ def decide_worker(self, ts): def transition_waiting_processing(self, key): try: - ts = self.tasks[key] + ts: TaskState = self.tasks[key] + dts: TaskState if self.validate: - assert not ts.waiting_on - assert not ts.who_has - assert not ts.exception_blame - assert not ts.processing_on - assert not ts.has_lost_dependencies + assert not ts._waiting_on + assert not ts._who_has + assert not ts._exception_blame + assert not ts._processing_on + assert not ts._has_lost_dependencies assert ts not in self.unrunnable - assert all(dts.who_has for dts in ts.dependencies) + assert all([dts._who_has for dts in ts._dependencies]) ws: WorkerState = self.decide_worker(ts) if ws is None: @@ -4424,7 +4711,7 @@ def transition_waiting_processing(self, key): comm = self.get_comm_cost(ts, ws) ws._processing[ts] = duration + comm - ts.processing_on = ws + ts._processing_on = ws ws._occupancy += duration + comm self.total_occupancy += duration + comm ts.state = "processing" @@ -4432,7 +4719,7 @@ def transition_waiting_processing(self, key): self.check_idle_saturated(ws) self.n_tasks += 1 - if ts.actor: + if ts._actor: ws._actors.add(ts) # logger.debug("Send job to worker: %s, %s", worker, key) @@ -4451,14 +4738,14 @@ def transition_waiting_processing(self, key): def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): try: ws: WorkerState = self.workers[worker] - ts = self.tasks[key] + ts: TaskState = self.tasks[key] if self.validate: - assert not ts.processing_on - assert ts.waiting_on + assert not ts._processing_on + assert ts._waiting_on assert ts.state == "waiting" - ts.waiting_on.clear() + ts._waiting_on.clear() if nbytes is not None: ts.set_nbytes(nbytes) @@ -4470,9 +4757,9 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): self._add_to_memory(ts, ws, recommendations, **kwargs) if self.validate: - assert not ts.processing_on - assert not ts.waiting_on - assert ts.who_has + assert not 
ts._processing_on + assert not ts._waiting_on + assert ts._who_has return recommendations except Exception as e: @@ -4496,28 +4783,28 @@ def transition_processing_memory( ws: WorkerState wws: WorkerState try: - ts = self.tasks[key] + ts: TaskState = self.tasks[key] assert worker assert isinstance(worker, str) if self.validate: - assert ts.processing_on - ws = ts.processing_on + assert ts._processing_on + ws = ts._processing_on assert ts in ws._processing - assert not ts.waiting_on - assert not ts.who_has, (ts, ts.who_has) - assert not ts.exception_blame + assert not ts._waiting_on + assert not ts._who_has, (ts, ts._who_has) + assert not ts._exception_blame assert ts.state == "processing" ws = self.workers.get(worker) if ws is None: return {key: "released"} - if ws != ts.processing_on: # someone else has this task + if ws != ts._processing_on: # someone else has this task logger.info( "Unexpected worker completed task, likely due to" " work stealing. Expected: %s, Got: %s, Key: %s", - ts.processing_on, + ts._processing_on, ws, key, ) @@ -4534,7 +4821,7 @@ def transition_processing_memory( # record timings of all actions -- a cheaper way of # getting timing info compared with get_task_stream() - ts.prefix.all_durations[action] += stop - start + ts._prefix._all_durations[action] += stop - start if len(L) > 0: compute_start, compute_stop = L[0] @@ -4548,19 +4835,20 @@ def transition_processing_memory( ############################# if compute_start and ws._processing.get(ts, True): # Update average task duration for worker - old_duration = ts.prefix.duration_average or 0 + old_duration = ts._prefix._duration_average new_duration = compute_stop - compute_start - if not old_duration: + if old_duration < 0: avg_duration = new_duration else: avg_duration = 0.5 * old_duration + 0.5 * new_duration - ts.prefix.duration_average = avg_duration - ts.group.duration += new_duration + ts._prefix._duration_average = avg_duration + ts._group._duration += new_duration - for tts in self.unknown_durations.pop(ts.prefix.name, ()): - if tts.processing_on: - wws = tts.processing_on + tts: TaskState + for tts in self.unknown_durations.pop(ts._prefix._name, ()): + if tts._processing_on: + wws = tts._processing_on old = wws._processing[tts] comm = self.get_comm_cost(tts, wws) wws._processing[tts] = avg_duration + comm @@ -4580,8 +4868,8 @@ def transition_processing_memory( self._add_to_memory(ts, ws, recommendations, type=type, typename=typename) if self.validate: - assert not ts.processing_on - assert not ts.waiting_on + assert not ts._processing_on + assert not ts._waiting_on return recommendations except Exception as e: @@ -4595,53 +4883,54 @@ def transition_processing_memory( def transition_memory_released(self, key, safe=False): ws: WorkerState try: - ts = self.tasks[key] + ts: TaskState = self.tasks[key] + dts: TaskState if self.validate: - assert not ts.waiting_on - assert not ts.processing_on + assert not ts._waiting_on + assert not ts._processing_on if safe: - assert not ts.waiters + assert not ts._waiters - if ts.actor: - for ws in ts.who_has: + if ts._actor: + for ws in ts._who_has: ws._actors.discard(ts) - if ts.who_wants: - ts.exception_blame = ts - ts.exception = "Worker holding Actor was lost" - return {ts.key: "erred"} # don't try to recreate + if ts._who_wants: + ts._exception_blame = ts + ts._exception = "Worker holding Actor was lost" + return {ts._key: "erred"} # don't try to recreate recommendations = {} - for dts in ts.waiters: + for dts in ts._waiters: if dts.state in ("no-worker", 
"processing"): - recommendations[dts.key] = "waiting" + recommendations[dts._key] = "waiting" elif dts.state == "waiting": - dts.waiting_on.add(ts) + dts._waiting_on.add(ts) # XXX factor this out? - for ws in ts.who_has: + for ws in ts._who_has: ws._has_what.remove(ts) ws._nbytes -= ts.get_nbytes() - ts.group.nbytes_in_memory -= ts.get_nbytes() + ts._group._nbytes_in_memory -= ts.get_nbytes() self.worker_send( ws._address, {"op": "delete-data", "keys": [key], "report": False} ) - ts.who_has.clear() + ts._who_has.clear() ts.state = "released" self.report({"op": "lost-data", "key": key}) - if not ts.run_spec: # pure data + if not ts._run_spec: # pure data recommendations[key] = "forgotten" - elif ts.has_lost_dependencies: + elif ts._has_lost_dependencies: recommendations[key] = "forgotten" - elif ts.who_wants or ts.waiters: + elif ts._who_wants or ts._waiters: recommendations[key] = "waiting" if self.validate: - assert not ts.waiting_on + assert not ts._waiting_on return recommendations except Exception as e: @@ -4654,30 +4943,32 @@ def transition_memory_released(self, key, safe=False): def transition_released_erred(self, key): try: - ts = self.tasks[key] + ts: TaskState = self.tasks[key] + dts: TaskState + failing_ts: TaskState if self.validate: with log_errors(pdb=LOG_PDB): - assert ts.exception_blame - assert not ts.who_has - assert not ts.waiting_on - assert not ts.waiters + assert ts._exception_blame + assert not ts._who_has + assert not ts._waiting_on + assert not ts._waiters recommendations = {} - failing_ts = ts.exception_blame + failing_ts = ts._exception_blame - for dts in ts.dependents: - dts.exception_blame = failing_ts - if not dts.who_has: - recommendations[dts.key] = "erred" + for dts in ts._dependents: + dts._exception_blame = failing_ts + if not dts._who_has: + recommendations[dts._key] = "erred" self.report( { "op": "task-erred", "key": key, - "exception": failing_ts.exception, - "traceback": failing_ts.traceback, + "exception": failing_ts._exception, + "traceback": failing_ts._traceback, } ) @@ -4695,25 +4986,26 @@ def transition_released_erred(self, key): def transition_erred_released(self, key): try: - ts = self.tasks[key] + ts: TaskState = self.tasks[key] + dts: TaskState if self.validate: with log_errors(pdb=LOG_PDB): - assert all(dts.state != "erred" for dts in ts.dependencies) - assert ts.exception_blame - assert not ts.who_has - assert not ts.waiting_on - assert not ts.waiters + assert all([dts.state != "erred" for dts in ts._dependencies]) + assert ts._exception_blame + assert not ts._who_has + assert not ts._waiting_on + assert not ts._waiters recommendations = {} - ts.exception = None - ts.exception_blame = None - ts.traceback = None + ts._exception = None + ts._exception_blame = None + ts._traceback = None - for dep in ts.dependents: - if dep.state == "erred": - recommendations[dep.key] = "waiting" + for dts in ts._dependents: + if dts.state == "erred": + recommendations[dts._key] = "waiting" self.report({"op": "task-retried", "key": key}) ts.state = "released" @@ -4729,30 +5021,31 @@ def transition_erred_released(self, key): def transition_waiting_released(self, key): try: - ts = self.tasks[key] + ts: TaskState = self.tasks[key] if self.validate: - assert not ts.who_has - assert not ts.processing_on + assert not ts._who_has + assert not ts._processing_on recommendations = {} - for dts in ts.dependencies: - s = dts.waiters + dts: TaskState + for dts in ts._dependencies: + s = dts._waiters if ts in s: s.discard(ts) - if not s and not dts.who_wants: - 
@@ -4729,30 +5021,31 @@ def transition_erred_released(self, key):
     def transition_waiting_released(self, key):
         try:
-            ts = self.tasks[key]
+            ts: TaskState = self.tasks[key]
             if self.validate:
-                assert not ts.who_has
-                assert not ts.processing_on
+                assert not ts._who_has
+                assert not ts._processing_on
             recommendations = {}
-            for dts in ts.dependencies:
-                s = dts.waiters
+            dts: TaskState
+            for dts in ts._dependencies:
+                s = dts._waiters
                 if ts in s:
                     s.discard(ts)
-                if not s and not dts.who_wants:
-                    recommendations[dts.key] = "released"
-            ts.waiting_on.clear()
+                if not s and not dts._who_wants:
+                    recommendations[dts._key] = "released"
+            ts._waiting_on.clear()
             ts.state = "released"
-            if ts.has_lost_dependencies:
+            if ts._has_lost_dependencies:
                 recommendations[key] = "forgotten"
-            elif not ts.exception_blame and (ts.who_wants or ts.waiters):
+            elif not ts._exception_blame and (ts._who_wants or ts._waiters):
                 recommendations[key] = "waiting"
             else:
-                ts.waiters.clear()
+                ts._waiters.clear()
             return recommendations
         except Exception as e:
@@ -4765,12 +5058,13 @@ def transition_waiting_released(self, key):
     def transition_processing_released(self, key):
         try:
-            ts = self.tasks[key]
+            ts: TaskState = self.tasks[key]
+            dts: TaskState
             if self.validate:
-                assert ts.processing_on
-                assert not ts.who_has
-                assert not ts.waiting_on
+                assert ts._processing_on
+                assert not ts._who_has
+                assert not ts._waiting_on
                 assert self.tasks[key].state == "processing"
             self._remove_from_processing(
@@ -4781,22 +5075,22 @@ def transition_processing_released(self, key):
             recommendations = {}
-            if ts.has_lost_dependencies:
+            if ts._has_lost_dependencies:
                 recommendations[key] = "forgotten"
-            elif ts.waiters or ts.who_wants:
+            elif ts._waiters or ts._who_wants:
                 recommendations[key] = "waiting"
             if recommendations.get(key) != "waiting":
-                for dts in ts.dependencies:
+                for dts in ts._dependencies:
                     if dts.state != "released":
-                        s = dts.waiters
+                        s = dts._waiters
                         s.discard(ts)
-                        if not s and not dts.who_wants:
-                            recommendations[dts.key] = "released"
-                ts.waiters.clear()
+                        if not s and not dts._who_wants:
+                            recommendations[dts._key] = "released"
+                ts._waiters.clear()
             if self.validate:
-                assert not ts.processing_on
+                assert not ts._processing_on
             return recommendations
         except Exception as e:
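Each `transition_*` method above returns a `recommendations` dict mapping keys to the state they should move to next, and the scheduler drains such dicts transitively, with each applied transition possibly adding more entries. A simplified sketch of that driver loop (an assumed shape for illustration, not the actual `Scheduler.transitions` code):

```python
def drain(transition, recommendations):
    """Apply recommended transitions until none are left."""
    recommendations = dict(recommendations)
    while recommendations:
        key, finish = recommendations.popitem()
        more = transition(key, finish)  # a transition may recommend further moves
        recommendations.update(more)


# tiny demo: "a" -> released -> forgotten
def fake_transition(key, finish):
    return {key: "forgotten"} if finish == "released" else {}


drain(fake_transition, {"a": "released"})
```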
ts.state = "erred" @@ -4856,8 +5152,8 @@ def transition_processing_erred( { "op": "task-erred", "key": key, - "exception": failing_ts.exception, - "traceback": failing_ts.traceback, + "exception": failing_ts._exception, + "traceback": failing_ts._traceback, } ) @@ -4866,7 +5162,7 @@ def transition_processing_erred( self.client_releases_keys(client="fire-and-forget", keys=[key]) if self.validate: - assert not ts.processing_on + assert not ts._processing_on return recommendations except Exception as e: @@ -4879,20 +5175,21 @@ def transition_processing_erred( def transition_no_worker_released(self, key): try: - ts = self.tasks[key] + ts: TaskState = self.tasks[key] + dts: TaskState if self.validate: assert self.tasks[key].state == "no-worker" - assert not ts.who_has - assert not ts.waiting_on + assert not ts._who_has + assert not ts._waiting_on self.unrunnable.remove(ts) ts.state = "released" - for dts in ts.dependencies: - dts.waiters.discard(ts) + for dts in ts._dependencies: + dts._waiters.discard(ts) - ts.waiters.clear() + ts._waiters.clear() return {} except Exception as e: @@ -4904,48 +5201,49 @@ def transition_no_worker_released(self, key): raise def remove_key(self, key): - ts = self.tasks.pop(key) + ts: TaskState = self.tasks.pop(key) assert ts.state == "forgotten" self.unrunnable.discard(ts) cs: ClientState - for cs in ts.who_wants: + for cs in ts._who_wants: cs._wants_what.remove(ts) - ts.who_wants.clear() - ts.processing_on = None - ts.exception_blame = ts.exception = ts.traceback = None + ts._who_wants.clear() + ts._processing_on = None + ts._exception_blame = ts._exception = ts._traceback = None if key in self.task_metadata: del self.task_metadata[key] - def _propagate_forgotten(self, ts, recommendations): + def _propagate_forgotten(self, ts: TaskState, recommendations): ts.state = "forgotten" - key = ts.key - for dts in ts.dependents: - dts.has_lost_dependencies = True - dts.dependencies.remove(ts) - dts.waiting_on.discard(ts) + key = ts._key + dts: TaskState + for dts in ts._dependents: + dts._has_lost_dependencies = True + dts._dependencies.remove(ts) + dts._waiting_on.discard(ts) if dts.state not in ("memory", "erred"): # Cannot compute task anymore - recommendations[dts.key] = "forgotten" - ts.dependents.clear() - ts.waiters.clear() + recommendations[dts._key] = "forgotten" + ts._dependents.clear() + ts._waiters.clear() - for dts in ts.dependencies: - dts.dependents.remove(ts) - s = dts.waiters + for dts in ts._dependencies: + dts._dependents.remove(ts) + s = dts._waiters s.discard(ts) - if not dts.dependents and not dts.who_wants: + if not dts._dependents and not dts._who_wants: # Task not needed anymore assert dts is not ts - recommendations[dts.key] = "forgotten" - ts.dependencies.clear() - ts.waiting_on.clear() + recommendations[dts._key] = "forgotten" + ts._dependencies.clear() + ts._waiting_on.clear() - if ts.who_has: - ts.group.nbytes_in_memory -= ts.get_nbytes() + if ts._who_has: + ts._group._nbytes_in_memory -= ts.get_nbytes() ws: WorkerState - for ws in ts.who_has: + for ws in ts._who_has: ws._has_what.remove(ts) ws._nbytes -= ts.get_nbytes() w = ws._address @@ -4953,24 +5251,24 @@ def _propagate_forgotten(self, ts, recommendations): self.worker_send( w, {"op": "delete-data", "keys": [key], "report": False} ) - ts.who_has.clear() + ts._who_has.clear() def transition_memory_forgotten(self, key): ws: WorkerState try: - ts = self.tasks[key] + ts: TaskState = self.tasks[key] if self.validate: assert ts.state == "memory" - assert not ts.processing_on - assert not 
@@ -4953,24 +5251,24 @@ def _propagate_forgotten(self, ts, recommendations):
                 self.worker_send(
                     w, {"op": "delete-data", "keys": [key], "report": False}
                 )
-        ts.who_has.clear()
+        ts._who_has.clear()
     def transition_memory_forgotten(self, key):
         ws: WorkerState
         try:
-            ts = self.tasks[key]
+            ts: TaskState = self.tasks[key]
             if self.validate:
                 assert ts.state == "memory"
-                assert not ts.processing_on
-                assert not ts.waiting_on
-            if not ts.run_spec:
+                assert not ts._processing_on
+                assert not ts._waiting_on
+            if not ts._run_spec:
                 # It's ok to forget a pure data task
                 pass
-            elif ts.has_lost_dependencies:
+            elif ts._has_lost_dependencies:
                 # It's ok to forget a task with forgotten dependencies
                 pass
-            elif not ts.who_wants and not ts.waiters and not ts.dependents:
+            elif not ts._who_wants and not ts._waiters and not ts._dependents:
                 # It's ok to forget a task that nobody needs
                 pass
             else:
@@ -4978,8 +5276,8 @@ def transition_memory_forgotten(self, key):
             recommendations = {}
-            if ts.actor:
-                for ws in ts.who_has:
+            if ts._actor:
+                for ws in ts._who_has:
                     ws._actors.discard(ts)
             self._propagate_forgotten(ts, recommendations)
@@ -4998,20 +5296,20 @@ def transition_memory_forgotten(self, key):
     def transition_released_forgotten(self, key):
         try:
-            ts = self.tasks[key]
+            ts: TaskState = self.tasks[key]
             if self.validate:
                 assert ts.state in ("released", "erred")
-                assert not ts.who_has
-                assert not ts.processing_on
-                assert not ts.waiting_on, (ts, ts.waiting_on)
-            if not ts.run_spec:
+                assert not ts._who_has
+                assert not ts._processing_on
+                assert not ts._waiting_on, (ts, ts._waiting_on)
+            if not ts._run_spec:
                 # It's ok to forget a pure data task
                 pass
-            elif ts.has_lost_dependencies:
+            elif ts._has_lost_dependencies:
                 # It's ok to forget a task with forgotten dependencies
                 pass
-            elif not ts.who_wants and not ts.waiters and not ts.dependents:
+            elif not ts._who_wants and not ts._waiters and not ts._dependents:
                 # It's ok to forget a task that nobody needs
                 pass
             else:
@@ -5048,6 +5346,7 @@ def transition(self, key, finish, *args, **kwargs):
         --------
         Scheduler.transitions: transitive version of this function
         """
+        ts: TaskState
         try:
             try:
                 ts = self.tasks[key]
@@ -5058,8 +5357,8 @@ def transition(self, key, finish, *args, **kwargs):
                 return {}
             if self.plugins:
-                dependents = set(ts.dependents)
-                dependencies = set(ts.dependencies)
+                dependents = set(ts._dependents)
+                dependencies = set(ts._dependencies)
             if (start, finish) in self._transitions:
                 func = self._transitions[start, finish]
@@ -5095,25 +5394,25 @@ def transition(self, key, finish, *args, **kwargs):
                 # Temporarily put back forgotten key for plugin to retrieve it
                 if ts.state == "forgotten":
                     try:
-                        ts.dependents = dependents
-                        ts.dependencies = dependencies
+                        ts._dependents = dependents
+                        ts._dependencies = dependencies
                     except KeyError:
                         pass
-                    self.tasks[ts.key] = ts
+                    self.tasks[ts._key] = ts
                 for plugin in list(self.plugins):
                     try:
                         plugin.transition(key, start, finish2, *args, **kwargs)
                     except Exception:
                         logger.info("Plugin failed with exception", exc_info=True)
                 if ts.state == "forgotten":
-                    del self.tasks[ts.key]
+                    del self.tasks[ts._key]
-            if ts.state == "forgotten" and ts.group.name in self.task_groups:
+            if ts.state == "forgotten" and ts._group._name in self.task_groups:
                 # Remove TaskGroup if all tasks are in the forgotten state
-                tg = ts.group
-                if not any(tg.states.get(s) for s in ALL_TASK_STATES):
-                    ts.prefix.groups.remove(tg)
-                    del self.task_groups[tg.name]
+                tg: TaskGroup = ts._group
+                if not any(tg._states.get(s) for s in ALL_TASK_STATES):
+                    ts._prefix._groups.remove(tg)
+                    del self.task_groups[tg._name]
             return recommendations
         except Exception as e:
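The `transition` hunks above show every state change being forwarded to registered plugins, with a forgotten task temporarily put back so the plugin can still look it up. A minimal plugin using that hook; the counting logic is just an example, not something from the patch:

```python
from distributed.diagnostics.plugin import SchedulerPlugin


class TransitionCounter(SchedulerPlugin):
    """Count how many transitions end in each state (illustrative)."""

    def __init__(self):
        self.counts = {}

    def transition(self, key, start, finish, *args, **kwargs):
        # called from Scheduler.transition() for every task state change
        self.counts[finish] = self.counts.get(finish, 0) + 1


# registration sketch: scheduler.add_plugin(TransitionCounter())
```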
@@ -5157,6 +5456,7 @@ def reschedule(self, key=None, worker=None):
         Things may have shifted and this task may now be better suited to run
         elsewhere
         """
+        ts: TaskState
         try:
             ts = self.tasks[key]
         except KeyError:
@@ -5167,7 +5467,7 @@ def reschedule(self, key=None, worker=None):
             return
         if ts.state != "processing":
             return
-        if worker and ts.processing_on.address != worker:
+        if worker and ts._processing_on.address != worker:
             return
         self.transitions({key: "released"})
@@ -5214,7 +5514,7 @@ def check_idle_saturated(self, ws: WorkerState, occ: double = -1.0):
         else:
             saturated.discard(ws)
-    def valid_workers(self, ts):
+    def valid_workers(self, ts: TaskState):
         """Return set of currently valid workers for key
         If all workers are valid then this returns ``True``.
@@ -5226,13 +5526,13 @@ def valid_workers(self, ts):
         """
         s = True
-        if ts.worker_restrictions:
-            s = {w for w in ts.worker_restrictions if w in self.workers}
+        if ts._worker_restrictions:
+            s = {w for w in ts._worker_restrictions if w in self.workers}
-        if ts.host_restrictions:
+        if ts._host_restrictions:
             # Resolve the alias here rather than early, for the worker
             # may not be connected when host_restrictions is populated
-            hr = [self.coerce_hostname(h) for h in ts.host_restrictions]
+            hr = [self.coerce_hostname(h) for h in ts._host_restrictions]
             # XXX need HostState?
             ss = [self.host_info[h]["addresses"] for h in hr if h in self.host_info]
             ss = set.union(*ss) if ss else set()
@@ -5241,14 +5541,14 @@ def valid_workers(self, ts):
             else:
                 s |= ss
-        if ts.resource_restrictions:
+        if ts._resource_restrictions:
             w = {
                 resource: {
                     w
                     for w, supplied in self.resources[resource].items()
                     if supplied >= required
                 }
-                for resource, required in ts.resource_restrictions.items()
+                for resource, required in ts._resource_restrictions.items()
             }
             ww = set.intersection(*w.values())
@@ -5263,14 +5563,14 @@ def valid_workers(self, ts):
         else:
             return {self.workers[w] for w in s}
-    def consume_resources(self, ts, ws: WorkerState):
-        if ts.resource_restrictions:
-            for r, required in ts.resource_restrictions.items():
+    def consume_resources(self, ts: TaskState, ws: WorkerState):
+        if ts._resource_restrictions:
+            for r, required in ts._resource_restrictions.items():
                 ws._used_resources[r] += required
-    def release_resources(self, ts, ws: WorkerState):
-        if ts.resource_restrictions:
-            for r, required in ts.resource_restrictions.items():
+    def release_resources(self, ts: TaskState, ws: WorkerState):
+        if ts._resource_restrictions:
+            for r, required in ts._resource_restrictions.items():
                 ws._used_resources[r] -= required
     #####################
@@ -5354,19 +5654,20 @@ def start_ipython(self, comm=None):
             )
         return self._ipython_kernel.get_connection_info()
-    def worker_objective(self, ts, ws: WorkerState):
+    def worker_objective(self, ts: TaskState, ws: WorkerState):
         """
         Objective function to determine which worker should get the task
         Minimize expected start time. If a tie then break with data storage.
         """
+        dts: TaskState
         comm_bytes = sum(
-            [dts.get_nbytes() for dts in ts.dependencies if ws not in dts.who_has]
+            [dts.get_nbytes() for dts in ts._dependencies if ws not in dts._who_has]
         )
         stack_time = ws._occupancy / ws._nthreads
         start_time = comm_bytes / self.bandwidth + stack_time
-        if ts.actor:
+        if ts._actor:
             return (len(ws._actors), start_time, ws._nbytes)
         else:
             return (start_time, ws._nbytes)
@@ -5775,7 +6076,7 @@ def adaptive_target(self, comm=None, target_duration=None):
         return len(self.workers) - len(to_close)
-def decide_worker(ts, all_workers, valid_workers, objective):
+def decide_worker(ts: TaskState, all_workers, valid_workers, objective):
     """
     Decide which worker should take task *ts*.
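`worker_objective` above ranks workers by estimated start time: bytes of dependencies the worker does not yet hold divided by bandwidth, plus its queued occupancy per thread, with stored bytes as the tie-breaker. The same arithmetic in isolation (all numbers made up):

```python
def estimated_start_time(missing_dep_sizes, bandwidth, occupancy, nthreads):
    comm_bytes = sum(missing_dep_sizes)  # data that must be transferred first
    stack_time = occupancy / nthreads    # work already queued on the worker
    return comm_bytes / bandwidth + stack_time


# 100 MB to move at 100 MB/s, plus 8 s of queued work spread over 4 threads:
print(estimated_start_time([100e6], 100e6, 8.0, 4))  # 1.0 + 2.0 = 3.0
```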
""" - deps = ts.dependencies - assert all(dts.who_has for dts in deps) - if ts.actor: + dts: TaskState + deps = ts._dependencies + assert all([dts._who_has for dts in deps]) + if ts._actor: candidates = set(all_workers) else: ws: WorkerState - candidates = {ws for dts in deps for ws in dts.who_has} + candidates = {ws for dts in deps for ws in dts._who_has} if valid_workers is True: if not candidates: candidates = set(all_workers) @@ -5806,7 +6108,7 @@ def decide_worker(ts, all_workers, valid_workers, objective): if not candidates: candidates = valid_workers if not candidates: - if ts.loose_restrictions: + if ts._loose_restrictions: return decide_worker(ts, all_workers, True, objective) else: return None @@ -5819,82 +6121,83 @@ def decide_worker(ts, all_workers, valid_workers, objective): return min(candidates, key=objective) -def validate_task_state(ts): +def validate_task_state(ts: TaskState): """ Validate the given TaskState. """ ws: WorkerState + dts: TaskState assert ts.state in ALL_TASK_STATES or ts.state == "forgotten", ts - if ts.waiting_on: - assert ts.waiting_on.issubset(ts.dependencies), ( + if ts._waiting_on: + assert ts._waiting_on.issubset(ts._dependencies), ( "waiting not subset of dependencies", - str(ts.waiting_on), - str(ts.dependencies), + str(ts._waiting_on), + str(ts._dependencies), ) - if ts.waiters: - assert ts.waiters.issubset(ts.dependents), ( + if ts._waiters: + assert ts._waiters.issubset(ts._dependents), ( "waiters not subset of dependents", - str(ts.waiters), - str(ts.dependents), + str(ts._waiters), + str(ts._dependents), ) - for dts in ts.waiting_on: - assert not dts.who_has, ("waiting on in-memory dep", str(ts), str(dts)) + for dts in ts._waiting_on: + assert not dts._who_has, ("waiting on in-memory dep", str(ts), str(dts)) assert dts.state != "released", ("waiting on released dep", str(ts), str(dts)) - for dts in ts.dependencies: - assert ts in dts.dependents, ( + for dts in ts._dependencies: + assert ts in dts._dependents, ( "not in dependency's dependents", str(ts), str(dts), - str(dts.dependents), + str(dts._dependents), ) if ts.state in ("waiting", "processing"): - assert dts in ts.waiting_on or dts.who_has, ( + assert dts in ts._waiting_on or dts._who_has, ( "dep missing", str(ts), str(dts), ) assert dts.state != "forgotten" - for dts in ts.waiters: + for dts in ts._waiters: assert dts.state in ("waiting", "processing"), ( "waiter not in play", str(ts), str(dts), ) - for dts in ts.dependents: - assert ts in dts.dependencies, ( + for dts in ts._dependents: + assert ts in dts._dependencies, ( "not in dependent's dependencies", str(ts), str(dts), - str(dts.dependencies), + str(dts._dependencies), ) assert dts.state != "forgotten" - assert (ts.processing_on is not None) == (ts.state == "processing") - assert bool(ts.who_has) == (ts.state == "memory"), (ts, ts.who_has) + assert (ts._processing_on is not None) == (ts.state == "processing") + assert bool(ts._who_has) == (ts.state == "memory"), (ts, ts._who_has) if ts.state == "processing": - assert all(dts.who_has for dts in ts.dependencies), ( + assert all([dts._who_has for dts in ts._dependencies]), ( "task processing without all deps", str(ts), - str(ts.dependencies), + str(ts._dependencies), ) - assert not ts.waiting_on + assert not ts._waiting_on - if ts.who_has: - assert ts.waiters or ts.who_wants, ( + if ts._who_has: + assert ts._waiters or ts._who_wants, ( "unneeded task in memory", str(ts), - str(ts.who_has), + str(ts._who_has), ) - if ts.run_spec: # was computed - assert ts.type - assert 
@@ -5819,82 +6121,83 @@
     return min(candidates, key=objective)
-def validate_task_state(ts):
+def validate_task_state(ts: TaskState):
     """
     Validate the given TaskState.
     """
     ws: WorkerState
+    dts: TaskState
     assert ts.state in ALL_TASK_STATES or ts.state == "forgotten", ts
-    if ts.waiting_on:
-        assert ts.waiting_on.issubset(ts.dependencies), (
+    if ts._waiting_on:
+        assert ts._waiting_on.issubset(ts._dependencies), (
             "waiting not subset of dependencies",
-            str(ts.waiting_on),
-            str(ts.dependencies),
+            str(ts._waiting_on),
+            str(ts._dependencies),
         )
-    if ts.waiters:
-        assert ts.waiters.issubset(ts.dependents), (
+    if ts._waiters:
+        assert ts._waiters.issubset(ts._dependents), (
             "waiters not subset of dependents",
-            str(ts.waiters),
-            str(ts.dependents),
+            str(ts._waiters),
+            str(ts._dependents),
         )
-    for dts in ts.waiting_on:
-        assert not dts.who_has, ("waiting on in-memory dep", str(ts), str(dts))
+    for dts in ts._waiting_on:
+        assert not dts._who_has, ("waiting on in-memory dep", str(ts), str(dts))
         assert dts.state != "released", ("waiting on released dep", str(ts), str(dts))
-    for dts in ts.dependencies:
-        assert ts in dts.dependents, (
+    for dts in ts._dependencies:
+        assert ts in dts._dependents, (
             "not in dependency's dependents",
             str(ts),
             str(dts),
-            str(dts.dependents),
+            str(dts._dependents),
         )
         if ts.state in ("waiting", "processing"):
-            assert dts in ts.waiting_on or dts.who_has, (
+            assert dts in ts._waiting_on or dts._who_has, (
                 "dep missing",
                 str(ts),
                 str(dts),
             )
         assert dts.state != "forgotten"
-    for dts in ts.waiters:
+    for dts in ts._waiters:
         assert dts.state in ("waiting", "processing"), (
             "waiter not in play",
             str(ts),
             str(dts),
         )
-    for dts in ts.dependents:
-        assert ts in dts.dependencies, (
+    for dts in ts._dependents:
+        assert ts in dts._dependencies, (
            "not in dependent's dependencies",
            str(ts),
            str(dts),
-            str(dts.dependencies),
+            str(dts._dependencies),
        )
        assert dts.state != "forgotten"
-    assert (ts.processing_on is not None) == (ts.state == "processing")
-    assert bool(ts.who_has) == (ts.state == "memory"), (ts, ts.who_has)
+    assert (ts._processing_on is not None) == (ts.state == "processing")
+    assert bool(ts._who_has) == (ts.state == "memory"), (ts, ts._who_has)
     if ts.state == "processing":
-        assert all(dts.who_has for dts in ts.dependencies), (
+        assert all([dts._who_has for dts in ts._dependencies]), (
             "task processing without all deps",
             str(ts),
-            str(ts.dependencies),
+            str(ts._dependencies),
         )
-        assert not ts.waiting_on
+        assert not ts._waiting_on
-    if ts.who_has:
-        assert ts.waiters or ts.who_wants, (
+    if ts._who_has:
+        assert ts._waiters or ts._who_wants, (
             "unneeded task in memory",
             str(ts),
-            str(ts.who_has),
+            str(ts._who_has),
         )
-    if ts.run_spec:  # was computed
-        assert ts.type
-        assert isinstance(ts.type, str)
-        assert not any(ts in dts.waiting_on for dts in ts.dependents)
-        for ws in ts.who_has:
+    if ts._run_spec:  # was computed
+        assert ts._type
+        assert isinstance(ts._type, str)
+        assert not any([ts in dts._waiting_on for dts in ts._dependents])
+        for ws in ts._who_has:
             assert ts in ws._has_what, (
                 "not in who_has' has_what",
                 str(ts),
@@ -5902,9 +6205,9 @@ def validate_task_state(ts):
                 str(ws._has_what),
             )
-    if ts.who_wants:
+    if ts._who_wants:
         cs: ClientState
-        for cs in ts.who_wants:
+        for cs in ts._who_wants:
             assert ts in cs._wants_what, (
                 "not in who_wants' wants_what",
                 str(ts),
@@ -5912,20 +6215,21 @@ def validate_task_state(ts):
                 str(cs._wants_what),
             )
-    if ts.actor:
+    if ts._actor:
         if ts.state == "memory":
-            assert sum([ts in ws._actors for ws in ts.who_has]) == 1
+            assert sum([ts in ws._actors for ws in ts._who_has]) == 1
         if ts.state == "processing":
-            assert ts in ts.processing_on.actors
+            assert ts in ts._processing_on.actors
 def validate_worker_state(ws: WorkerState):
+    ts: TaskState
     for ts in ws._has_what:
-        assert ws in ts.who_has, (
+        assert ws in ts._who_has, (
             "not in has_what' who_has",
             str(ws),
             str(ts),
-            str(ts.who_has),
+            str(ts._who_has),
         )
     for ts in ws._actors:
@@ -5939,6 +6243,7 @@ def validate_state(tasks, workers, clients):
     This performs a sequence of checks on the entire graph, running in about
     linear time. This raises assert errors if anything doesn't check out.
     """
+    ts: TaskState
     for ts in tasks.values():
         validate_task_state(ts)
@@ -5949,11 +6254,11 @@ def validate_state(tasks, workers, clients):
     cs: ClientState
     for cs in clients.values():
         for ts in cs._wants_what:
-            assert cs in ts.who_wants, (
+            assert cs in ts._who_wants, (
                 "not in wants_what' who_wants",
                 str(cs),
                 str(ts),
-                str(ts.who_wants),
+                str(ts._who_wants),
             )
@@ -6029,8 +6334,8 @@ def update_graph(self, scheduler, dsk=None, keys=None, restrictions=None, **kwar
     def transition(self, key, start, finish, *args, **kwargs):
         if finish == "memory" or finish == "erred":
-            ts = self.scheduler.tasks.get(key)
-            if ts is not None and ts.key in self.keys:
-                self.metadata[key] = ts.metadata
+            ts: TaskState = self.scheduler.tasks.get(key)
+            if ts is not None and ts._key in self.keys:
+                self.metadata[key] = ts._metadata
             self.state[key] = finish
             self.keys.discard(key)
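`validate_worker_state` and `validate_state` above enforce that both sides of the bookkeeping mirror each other: every worker in a task's `who_has` lists the task in its `has_what`, and vice versa. The same invariant expressed over plain dictionaries of sets, rather than the real `TaskState`/`WorkerState` objects:

```python
def check_symmetry(who_has, has_what):
    for key, workers in who_has.items():
        for w in workers:
            assert key in has_what[w], f"{w} missing {key}"
    for w, keys in has_what.items():
        for key in keys:
            assert w in who_has[key], f"{key} not on {w}"


check_symmetry({"x": {"w1"}}, {"w1": {"x"}})  # passes; any asymmetry raises
```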