diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 5a60eac5f92..f2126a80e41 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1585,6 +1585,8 @@ class SchedulerState: _resources: object _saturated: set _tasks: dict + _task_groups: dict + _task_prefixes: dict _task_metadata: dict _total_nthreads: Py_ssize_t _total_occupancy: double @@ -1635,6 +1637,8 @@ def __init__( self._tasks = tasks else: self._tasks = dict() + self._task_groups = dict() + self._task_prefixes = dict() self._task_metadata = dict() self._total_nthreads = 0 self._total_occupancy = 0 @@ -1691,6 +1695,14 @@ def saturated(self): def tasks(self): return self._tasks + @property + def task_groups(self): + return self._task_groups + + @property + def task_prefixes(self): + return self._task_prefixes + @property def task_metadata(self): return self._task_metadata @@ -1738,6 +1750,8 @@ def __pdict__(self): "unknown_durations": self._unknown_durations, "validate": self._validate, "tasks": self._tasks, + "task_groups": self._task_groups, + "task_prefixes": self._task_prefixes, "total_nthreads": self._total_nthreads, "total_occupancy": self._total_occupancy, "extensions": self._extensions, @@ -2926,8 +2940,6 @@ def __init__( # Task state tasks = dict() - self.task_groups = dict() - self.task_prefixes = dict() for old_attr, new_attr, wrap in [ ("priority", "priority", None), ("dependencies", "dependencies", _legacy_task_key_set), @@ -3919,17 +3931,15 @@ def new_task(self, key, spec, state): tg: TaskGroup ts._state = state prefix_key = key_split(key) - try: - tp = self.task_prefixes[prefix_key] - except KeyError: - self.task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key) + tp = parent._task_prefixes.get(prefix_key) + if tp is None: + parent._task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key) ts._prefix = tp group_key = ts._group_key - try: - tg = self.task_groups[group_key] - except KeyError: - self.task_groups[group_key] = tg = TaskGroup(group_key) + tg = parent._task_groups.get(group_key) + if tg is None: + parent._task_groups[group_key] = tg = TaskGroup(group_key) tg._prefix = tp tp._groups.append(tg) tg.add(ts) @@ -5891,12 +5901,12 @@ def transition(self, key, finish, *args, **kwargs): if ts._state == "forgotten": del parent._tasks[ts._key] - if ts._state == "forgotten" and ts._group._name in self.task_groups: + tg: TaskGroup = ts._group + if ts._state == "forgotten" and tg._name in parent._task_groups: # Remove TaskGroup if all tasks are in the forgotten state - tg: TaskGroup = ts._group if not any([tg._states.get(s) for s in ALL_TASK_STATES]): ts._prefix._groups.remove(tg) - del self.task_groups[tg._name] + del parent._task_groups[tg._name] return recommendations except Exception as e: