Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,8 @@ class SchedulerState:
_resources: object
_saturated: set
_tasks: dict
_task_groups: dict
_task_prefixes: dict
_task_metadata: dict
_total_nthreads: Py_ssize_t
_total_occupancy: double
Expand Down Expand Up @@ -1635,6 +1637,8 @@ def __init__(
self._tasks = tasks
else:
self._tasks = dict()
self._task_groups = dict()
self._task_prefixes = dict()
self._task_metadata = dict()
self._total_nthreads = 0
self._total_occupancy = 0
Expand Down Expand Up @@ -1691,6 +1695,14 @@ def saturated(self):
def tasks(self):
return self._tasks

@property
def task_groups(self):
return self._task_groups

@property
def task_prefixes(self):
return self._task_prefixes

@property
def task_metadata(self):
return self._task_metadata
Expand Down Expand Up @@ -1738,6 +1750,8 @@ def __pdict__(self):
"unknown_durations": self._unknown_durations,
"validate": self._validate,
"tasks": self._tasks,
"task_groups": self._task_groups,
"task_prefixes": self._task_prefixes,
"total_nthreads": self._total_nthreads,
"total_occupancy": self._total_occupancy,
"extensions": self._extensions,
Expand Down Expand Up @@ -2926,8 +2940,6 @@ def __init__(

# Task state
tasks = dict()
self.task_groups = dict()
self.task_prefixes = dict()
for old_attr, new_attr, wrap in [
("priority", "priority", None),
("dependencies", "dependencies", _legacy_task_key_set),
Expand Down Expand Up @@ -3919,17 +3931,15 @@ def new_task(self, key, spec, state):
tg: TaskGroup
ts._state = state
prefix_key = key_split(key)
try:
tp = self.task_prefixes[prefix_key]
except KeyError:
self.task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key)
tp = parent._task_prefixes.get(prefix_key)
if tp is None:
parent._task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key)
ts._prefix = tp

group_key = ts._group_key
try:
tg = self.task_groups[group_key]
except KeyError:
self.task_groups[group_key] = tg = TaskGroup(group_key)
tg = parent._task_groups.get(group_key)
if tg is None:
parent._task_groups[group_key] = tg = TaskGroup(group_key)
tg._prefix = tp
tp._groups.append(tg)
tg.add(ts)
Expand Down Expand Up @@ -5891,12 +5901,12 @@ def transition(self, key, finish, *args, **kwargs):
if ts._state == "forgotten":
del parent._tasks[ts._key]

if ts._state == "forgotten" and ts._group._name in self.task_groups:
tg: TaskGroup = ts._group
if ts._state == "forgotten" and tg._name in parent._task_groups:
# Remove TaskGroup if all tasks are in the forgotten state
tg: TaskGroup = ts._group
if not any([tg._states.get(s) for s in ALL_TASK_STATES]):
ts._prefix._groups.remove(tg)
del self.task_groups[tg._name]
del parent._task_groups[tg._name]

return recommendations
except Exception as e:
Expand Down