diff --git a/distributed/collections.py b/distributed/collections.py
index 4aef7d555e9..992b4582ebf 100644
--- a/distributed/collections.py
+++ b/distributed/collections.py
@@ -1,12 +1,18 @@
from __future__ import annotations
+import dataclasses
import heapq
import itertools
import weakref
from collections import OrderedDict, UserDict
from collections.abc import Callable, Hashable, Iterator
-from typing import MutableSet # TODO move to collections.abc (requires Python >=3.9)
-from typing import Any, TypeVar, cast
+from typing import ( # TODO move to collections.abc (requires Python >=3.9)
+ Any,
+ Container,
+ MutableSet,
+ TypeVar,
+ cast,
+)
T = TypeVar("T", bound=Hashable)
@@ -199,3 +205,54 @@ def clear(self) -> None:
self._data.clear()
self._heap.clear()
self._sorted = True
+
+
+# NOTE: used by the scheduler and by work stealing; if work stealing is ever
+# removed, this could be moved to `scheduler.py`.
+@dataclasses.dataclass
+class Occupancy:
+ cpu: float
+ network: float
+
+ def __add__(self, other: Any) -> Occupancy:
+ if isinstance(other, type(self)):
+ return type(self)(self.cpu + other.cpu, self.network + other.network)
+ return NotImplemented
+
+ def __iadd__(self, other: Any) -> Occupancy:
+ if isinstance(other, type(self)):
+ self.cpu += other.cpu
+ self.network += other.network
+ return self
+ return NotImplemented
+
+ def __sub__(self, other: Any) -> Occupancy:
+ if isinstance(other, type(self)):
+ return type(self)(self.cpu - other.cpu, self.network - other.network)
+ return NotImplemented
+
+ def __isub__(self, other: Any) -> Occupancy:
+ if isinstance(other, type(self)):
+ self.cpu -= other.cpu
+ self.network -= other.network
+ return self
+ return NotImplemented
+
+ def __bool__(self) -> bool:
+ return self.cpu != 0 or self.network != 0
+
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, type(self)):
+ return self.cpu == other.cpu and self.network == other.network
+ return NotImplemented
+
+ def clear(self) -> None:
+ self.cpu = 0.0
+ self.network = 0.0
+
+ def _to_dict(self, *, exclude: Container[str] = ()) -> dict[str, float]:
+ return {"cpu": self.cpu, "network": self.network}
+
+ @property
+ def total(self) -> float:
+ return self.cpu + self.network
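
For orientation (an aside, not part of the patch): a minimal sketch of how the `Occupancy` arithmetic above is meant to compose, with a plain dict standing in for `WorkerState.processing`.

```python
# Hypothetical usage sketch of Occupancy; `processing` is a stand-in dict,
# not the real WorkerState.processing.
from distributed.collections import Occupancy

processing = {
    "task-a": Occupancy(cpu=0.5, network=0.25),
    "task-b": Occupancy(cpu=2.0, network=0.0),
}

# Sum per-task costs into a worker-level total; `start=` keeps the result an Occupancy.
worker_occupancy = sum(processing.values(), start=Occupancy(0, 0))
assert worker_occupancy == Occupancy(2.5, 0.25)
assert worker_occupancy.total == 2.75

# Re-estimating a task applies only the delta, mirroring _set_duration_estimate.
old = processing["task-a"]
processing["task-a"] = new = Occupancy(cpu=1.0, network=0.25)
worker_occupancy += new - old
assert worker_occupancy == Occupancy(3.0, 0.25)

# clear() zeroes in place, as used when a worker's processing set empties.
worker_occupancy.clear()
assert not worker_occupancy
```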
diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py
index d89483472da..cb8fee8d1eb 100644
--- a/distributed/dashboard/components/scheduler.py
+++ b/distributed/dashboard/components/scheduler.py
@@ -120,13 +120,14 @@ def __init__(self, scheduler, **kwargs):
self.scheduler = scheduler
self.source = ColumnDataSource(
{
- "occupancy": [0, 0],
- "worker": ["a", "b"],
- "x": [0.0, 0.1],
- "y": [1, 2],
- "ms": [1, 2],
- "color": ["red", "blue"],
- "escaped_worker": ["a", "b"],
+ "occupancy_network": [],
+ "occupancy_cpu": [],
+ "occupancy_network_ms": [],
+ "occupancy_cpu_ms": [],
+ "worker": [],
+ "y": [],
+ "color": [],
+ "escaped_worker": [],
}
)
@@ -139,10 +140,14 @@ def __init__(self, scheduler, **kwargs):
min_border_bottom=50,
**kwargs,
)
- rect = self.root.rect(
- source=self.source, x="x", width="ms", y="y", height=0.9, color="color"
+ self.root.hbar_stack(
+ ["occupancy_network_ms", "occupancy_cpu_ms"],
+ source=self.source,
+ y="y",
+ height=0.9,
+ fill_alpha=[0.8, 1.0],
+ color="color",
)
- rect.nonselection_glyph = None
self.root.xaxis.minor_tick_line_alpha = 0
self.root.yaxis.visible = False
@@ -153,7 +158,9 @@ def __init__(self, scheduler, **kwargs):
tap = TapTool(callback=OpenURL(url="./info/worker/@escaped_worker.html"))
hover = HoverTool()
- hover.tooltips = "@worker : @occupancy s."
+ hover.tooltips = (
+ "@worker : network: @occupancy_network s, cpu: @occupancy_cpu s."
+ )
hover.point_policy = "follow_mouse"
self.root.add_tools(hover, tap)
@@ -163,10 +170,14 @@ def update(self):
workers = self.scheduler.workers.values()
y = list(range(len(workers)))
- occupancy = [ws.occupancy for ws in workers]
- ms = [occ * 1000 for occ in occupancy]
- x = [occ / 500 for occ in occupancy]
- total = sum(occupancy)
+        # List comprehensions (rather than zip-unpacking) keep this safe when
+        # there are no workers yet.
+        occupancy_network = np.array([ws.occupancy.network for ws in workers])
+        occupancy_cpu = np.array([ws.occupancy.cpu for ws in workers])
+        total_network = occupancy_network.sum()
+        total_cpu = occupancy_cpu.sum()
+        total = total_network + total_cpu
color = []
for ws in workers:
if ws in self.scheduler.idle:
@@ -178,20 +189,22 @@ def update(self):
if total:
self.root.title.text = (
- f"Occupancy -- total time: {format_time(total)} "
- f"wall time: {format_time(total / self.scheduler.total_nthreads)}"
+ f"Total time: {format_time(total)}, "
+ f"wall time: {format_time(total / self.scheduler.total_nthreads)}, "
+ f"{total_network / total:.0%} network"
)
else:
self.root.title.text = "Occupancy"
- if occupancy:
+ if workers:
result = {
- "occupancy": occupancy,
+ "occupancy_network": occupancy_network,
+ "occupancy_cpu": occupancy_cpu,
+ "occupancy_network_ms": occupancy_network * 1000,
+ "occupancy_cpu_ms": occupancy_cpu * 1000,
"worker": [ws.address for ws in workers],
- "ms": ms,
"color": color,
"escaped_worker": [escape.url_escape(ws.address) for ws in workers],
- "x": x,
"y": y,
}
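
As an aside (not part of the patch), a minimal standalone sketch of the `hbar_stack` pattern the chart above now uses; the column names and data below are made up for illustration.

```python
# Illustrative stacked horizontal bars, analogous to the occupancy chart above.
from bokeh.models import ColumnDataSource
from bokeh.plotting import figure

source = ColumnDataSource(
    {
        "y": [0, 1, 2],
        "network_ms": [120.0, 40.0, 0.0],
        "cpu_ms": [800.0, 950.0, 300.0],
        "color": ["blue", "blue", "red"],
    }
)

p = figure(height=200, title="Occupancy per worker (ms)")
# One horizontal bar per worker; the two stackers are drawn end to end, so the
# total bar length is network_ms + cpu_ms.
p.hbar_stack(
    ["network_ms", "cpu_ms"],
    y="y",
    height=0.9,
    source=source,
    color="color",
    fill_alpha=[0.8, 1.0],
)
```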
diff --git a/distributed/http/templates/worker-table.html b/distributed/http/templates/worker-table.html
index 87512ee3860..52eeb9c3f87 100644
--- a/distributed/http/templates/worker-table.html
+++ b/distributed/http/templates/worker-table.html
@@ -5,7 +5,8 @@
<th>Cores</th>
<th>Memory</th>
<th>Memory use</th>
-<th>Occupancy</th>
+<th>Network Occupancy</th>
+<th>CPU Occupancy</th>
<th>Processing</th>
<th>In-memory</th>
<th>Services</th>
@@ -19,7 +20,8 @@
<td>{{ ws.nthreads }}</td>
<td>{{ format_bytes(ws.memory_limit) if ws.memory_limit is not None else "" }}</td>
|
-<td>{{ format_time(ws.occupancy) }}</td>
+<td>{{ format_time(ws.occupancy.network) }}</td>
+<td>{{ format_time(ws.occupancy.cpu) }}</td>
<td>{{ len(ws.processing) }}</td>
<td>{{ len(ws.has_what) }}</td>
{% if 'dashboard' in ws.services %}
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index c7aeba991e5..b1de0f20cb9 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -65,7 +65,7 @@
from distributed._stories import scheduler_story
from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
from distributed.batched import BatchedSend
-from distributed.collections import HeapSet
+from distributed.collections import HeapSet, Occupancy
from distributed.comm import (
Comm,
CommClosedError,
@@ -424,10 +424,10 @@ class WorkerState:
#: (i.e. the tasks in this worker's :attr:`~WorkerState.has_what`).
nbytes: int
- #: The total expected runtime, in seconds, of all tasks currently processing on this
- #: worker. This is the sum of all the costs in this worker's
+ #: The total expected cost, in seconds, of all tasks currently processing on this
+ #: worker. This is the sum of all the Occupancies in this worker's
# :attr:`~WorkerState.processing` dictionary.
- occupancy: float
+ occupancy: Occupancy
#: Worker memory unknown to the worker, in bytes, which has been there for more than
#: 30 seconds. See :class:`MemoryState`.
@@ -455,12 +455,12 @@ class WorkerState:
_has_what: dict[TaskState, None]
#: A dictionary of tasks that have been submitted to this worker. Each task state is
- #: associated with the expected cost in seconds of running that task, summing both
- #: the task's expected computation time and the expected communication time of its
- #: result.
+    #: associated with the expected cost in seconds of running that task, split into
+    #: the task's expected computation time and the expected serial communication time
+    #: of its dependencies.
#:
- #: If a task is already executing on the worker and the excecution time is twice the
- #: learned average TaskGroup duration, this will be set to twice the current
+ #: If a task is already executing on the worker and the execution time is twice the
+ #: learned average TaskGroup duration, the `cpu` time will be set to twice the current
#: executing time. If the task is unknown, the default task duration is used instead
#: of the TaskGroup average.
#:
@@ -470,13 +470,13 @@ class WorkerState:
#:
#: All the tasks here are in the "processing" state.
#: This attribute is kept in sync with :attr:`TaskState.processing_on`.
- processing: dict[TaskState, float]
+ processing: dict[TaskState, Occupancy]
#: Running tasks that invoked :func:`distributed.secede`
long_running: set[TaskState]
#: A dictionary of tasks that are currently being run on this worker.
- #: Each task state is asssociated with the duration in seconds which the task has
+ #: Each task state is associated with the duration in seconds which the task has
#: been running.
executing: dict[TaskState, float]
@@ -527,7 +527,7 @@ def __init__(
self.status = status
self._hash = hash(self.server_id)
self.nbytes = 0
- self.occupancy = 0
+ self.occupancy = Occupancy(0.0, 0.0)
self._memory_unmanaged_old = 0
self._memory_unmanaged_history = deque()
self.metrics = {}
@@ -1387,7 +1387,7 @@ def __init__(
self.task_prefixes: dict[str, TaskPrefix] = {}
self.task_metadata = {} # type: ignore
self.total_nthreads = 0
- self.total_occupancy = 0.0
+ self.total_occupancy = Occupancy(0.0, 0.0)
self.unknown_durations: dict[str, set[TaskState]] = {}
self.queued = queued
self.unrunnable = unrunnable
@@ -1957,9 +1957,9 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
wp_vals = worker_pool.values()
n_workers: int = len(wp_vals)
if n_workers < 20: # smart but linear in small case
- ws = min(wp_vals, key=operator.attrgetter("occupancy"))
+ ws = min(wp_vals, key=lambda ws: ws.occupancy.total)
assert ws
- if ws.occupancy == 0:
+ if not ws.occupancy:
# special case to use round-robin; linear search
# for next worker with zero occupancy (or just
# land back where we started).
@@ -1968,7 +1968,7 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
i: int
for i in range(n_workers):
wp_i = wp_vals[(i + start) % n_workers]
- if wp_i.occupancy == 0:
+ if not wp_i.occupancy:
ws = wp_i
break
else: # dumb but fast in large case
@@ -2801,19 +2801,17 @@ def _set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> None:
if ts in ws.long_running:
return
- exec_time: float = ws.executing.get(ts, 0)
- duration: float = self.get_task_duration(ts)
- total_duration: float
- if exec_time > 2 * duration:
- total_duration = 2 * exec_time
- else:
- comm: float = self.get_comm_cost(ts, ws)
- total_duration = duration + comm
+ exec_time = ws.executing.get(ts, 0.0)
+ cpu = self.get_task_duration(ts)
+ if exec_time > 2 * cpu:
+ cpu = 2 * exec_time
+ network = self.get_comm_cost(ts, ws)
- old = ws.processing.get(ts, 0)
- ws.processing[ts] = total_duration
- self.total_occupancy += total_duration - old
- ws.occupancy += total_duration - old
+ old = ws.processing.get(ts, Occupancy(0, 0))
+ ws.processing[ts] = new = Occupancy(cpu, network)
+ delta = new - old
+ self.total_occupancy += delta
+ ws.occupancy += delta
def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0):
"""Update the status of the idle and saturated state
@@ -2841,11 +2839,11 @@ def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0):
if self.total_nthreads == 0 or ws.status == Status.closed:
return
if occ < 0:
- occ = ws.occupancy
+ occ = ws.occupancy.total
nc: int = ws.nthreads
p: int = len(ws.processing)
- avg: float = self.total_occupancy / self.total_nthreads
+ avg: float = self.total_occupancy.total / self.total_nthreads
idle = self.idle
saturated = self.saturated
@@ -3003,7 +3001,8 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple:
nbytes = dts.get_nbytes()
comm_bytes += nbytes
- stack_time: float = ws.occupancy / ws.nthreads
+ # FIXME use `occupancy.cpu` https://github.com/dask/distributed/issues/7003
+ stack_time: float = ws.occupancy.total / ws.nthreads
start_time: float = stack_time + comm_bytes / self.bandwidth
if ts.actor:
@@ -3045,7 +3044,7 @@ def remove_all_replicas(self, ts: TaskState):
def _reevaluate_occupancy_worker(self, ws: WorkerState):
"""See reevaluate_occupancy"""
ts: TaskState
- old = ws.occupancy
+ old = ws.occupancy.total
for ts in ws.processing:
self._set_duration_estimate(ts, ws)
@@ -3053,7 +3052,8 @@ def _reevaluate_occupancy_worker(self, ws: WorkerState):
steal = self.extensions.get("stealing")
if steal is None:
return
- if ws.occupancy > old * 1.3 or old > ws.occupancy * 1.3:
+ current = ws.occupancy.total
+ if current > old * 1.3 or old > current * 1.3:
for ts in ws.processing:
steal.recalculate_cost(ts)
@@ -4142,7 +4142,7 @@ def update_graph(
dependencies = dependencies or {}
- if self.total_occupancy > 1e-9 and self.computations:
+ if self.total_occupancy.total > 1e-9 and self.computations:
# Still working on something. Assign new tasks to same computation
computation = self.computations[-1]
else:
@@ -4879,19 +4879,26 @@ def validate_state(self, allow_overlap: bool = False) -> None:
}
assert a == b, (a, b)
- actual_total_occupancy = 0
+ actual_total_occupancy = Occupancy(0, 0)
for worker, ws in self.workers.items():
ws_processing_total = sum(
- cost for ts, cost in ws.processing.items() if ts not in ws.long_running
+ (
+ cost
+ for ts, cost in ws.processing.items()
+ if ts not in ws.long_running
+ ),
+ start=Occupancy(0, 0),
)
- assert abs(ws_processing_total - ws.occupancy) < 1e-8, (
+ delta = ws_processing_total - ws.occupancy
+ assert abs(delta.cpu) < 1e-8 and abs(delta.network) < 1e-8, (
worker,
ws_processing_total,
ws.occupancy,
)
actual_total_occupancy += ws.occupancy
- assert abs(actual_total_occupancy - self.total_occupancy) < 1e-8, (
+ delta = actual_total_occupancy - self.total_occupancy
+ assert abs(delta.cpu) < 1e-8 and abs(delta.network) < 1e-8, (
actual_total_occupancy,
self.total_occupancy,
)
@@ -5088,7 +5095,7 @@ def handle_long_running(
if key not in self.tasks:
logger.debug("Skipping long_running since key %s was already released", key)
return
- ts = self.tasks[key]
+ ts: TaskState = self.tasks[key]
steal = self.extensions.get("stealing")
if steal is not None:
steal.remove_key_from_stealable(ts)
@@ -5112,7 +5119,7 @@ def handle_long_running(
# idleness detection. Idle workers are typically targeted for
# downscaling but we should not downscale workers with long running
# tasks
- ws.processing[ts] = 0
+ ws.processing[ts].clear()
ws.long_running.add(ts)
self.check_idle_saturated(ws)
@@ -7543,10 +7550,12 @@ def adaptive_target(self, target_duration=None):
# CPU
# TODO consider any user-specified default task durations for queued tasks
- queued_occupancy = len(self.queued) * self.UNKNOWN_TASK_DURATION
+ queued_occupancy: float = len(self.queued) * self.UNKNOWN_TASK_DURATION
+ # TODO: threads per worker
+ # TODO don't include network occupancy?
cpu = math.ceil(
- (self.total_occupancy + queued_occupancy) / target_duration
- ) # TODO: threads per worker
+ (self.total_occupancy.total + queued_occupancy) / target_duration
+ )
# Avoid a few long tasks from asking for many cores
tasks_ready = len(self.queued)
@@ -7693,7 +7702,7 @@ def _exit_processing_common(
ws.long_running.discard(ts)
if not ws.processing:
state.total_occupancy -= ws.occupancy
- ws.occupancy = 0
+ ws.occupancy.clear()
else:
state.total_occupancy -= duration
ws.occupancy -= duration
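
As a side note on the `_set_duration_estimate` change above: a small self-contained sketch of the cpu/network split and the 2x execution-time clamp, with the scheduler's helpers replaced by plain arguments (names here are illustrative only).

```python
# Illustrative sketch of the new estimate split; get_task_duration / get_comm_cost
# are replaced by plain arguments for the example.
from distributed.collections import Occupancy


def estimate(avg_duration: float, exec_time: float, comm_cost: float) -> Occupancy:
    cpu = avg_duration
    if exec_time > 2 * cpu:
        # The task has already run far longer than the learned average: trust the clock.
        cpu = 2 * exec_time
    return Occupancy(cpu=cpu, network=comm_cost)


assert estimate(avg_duration=1.0, exec_time=0.5, comm_cost=0.2) == Occupancy(1.0, 0.2)
# Long-running task: the cpu estimate is bumped to twice the elapsed execution time.
assert estimate(avg_duration=1.0, exec_time=5.0, comm_cost=0.2) == Occupancy(10.0, 0.2)
```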
diff --git a/distributed/stealing.py b/distributed/stealing.py
index b9858f4656a..94d2abdea75 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -14,6 +14,7 @@
import dask
from dask.utils import parse_timedelta
+from distributed.collections import Occupancy
from distributed.comm.addressing import get_address_host
from distributed.core import CommClosedError, Status
from distributed.diagnostics.plugin import SchedulerPlugin
@@ -56,8 +57,8 @@
class InFlightInfo(TypedDict):
victim: WorkerState
thief: WorkerState
- victim_duration: float
- thief_duration: float
+ victim_duration: Occupancy
+ thief_duration: Occupancy
stimulus_id: str
@@ -78,7 +79,7 @@ class WorkStealing(SchedulerPlugin):
# { task state: }
in_flight: dict[TaskState, InFlightInfo]
# { worker state: occupancy }
- in_flight_occupancy: defaultdict[WorkerState, float]
+ in_flight_occupancy: defaultdict[WorkerState, Occupancy]
_in_flight_event: asyncio.Event
_request_counter: int
@@ -100,7 +101,7 @@ def __init__(self, scheduler: Scheduler):
self.scheduler.events["stealing"] = deque(maxlen=100000)
self.count = 0
self.in_flight = {}
- self.in_flight_occupancy = defaultdict(lambda: 0)
+ self.in_flight_occupancy = defaultdict(lambda: Occupancy(0, 0))
self._in_flight_event = asyncio.Event()
self._request_counter = 0
self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm
@@ -205,7 +206,7 @@ def remove_key_from_stealable(self, ts):
except KeyError:
pass
- def steal_time_ratio(self, ts):
+ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, None]:
"""The compute to communication time ratio of a key
Returns
@@ -219,16 +220,18 @@ def steal_time_ratio(self, ts):
return None, None
if not ts.dependencies: # no dependencies fast path
- return 0, 0
+ return 0.0, 0
ws = ts.processing_on
+ assert ws
compute_time = ws.processing[ts]
- if compute_time < 0.005: # 5ms, just give up
+ if compute_time.total < 0.005: # 5ms, just give up
return None, None
nbytes = ts.get_nbytes_deps()
- transfer_time = nbytes / self.scheduler.bandwidth + LATENCY
- cost_multiplier = transfer_time / compute_time
+ transfer_time: float = nbytes / self.scheduler.bandwidth + LATENCY
+ # FIXME don't use `compute_time.total` https://github.com/dask/distributed/issues/7003
+ cost_multiplier = transfer_time / compute_time.total
if cost_multiplier > 100:
return None, None
@@ -262,9 +265,10 @@ def move_task_request(
victim_duration = victim.processing[ts]
- thief_duration = self.scheduler.get_task_duration(
- ts
- ) + self.scheduler.get_comm_cost(ts, thief)
+ thief_duration = Occupancy(
+ self.scheduler.get_task_duration(ts),
+ self.scheduler.get_comm_cost(ts, thief),
+ )
self.scheduler.stream_comms[victim.address].send(
{"op": "steal-request", "key": key, "stimulus_id": stimulus_id}
@@ -354,7 +358,7 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
self.scheduler.total_occupancy -= duration
if not victim.processing:
self.scheduler.total_occupancy -= victim.occupancy
- victim.occupancy = 0
+ victim.occupancy.clear()
thief.processing[ts] = d["thief_duration"]
thief.occupancy += d["thief_duration"]
self.scheduler.total_occupancy += d["thief_duration"]
@@ -379,27 +383,32 @@ def balance(self):
s = self.scheduler
def combined_occupancy(ws: WorkerState) -> float:
- return ws.occupancy + self.in_flight_occupancy[ws]
+ return ws.occupancy.total + self.in_flight_occupancy[ws].total
def maybe_move_task(
level: int,
ts: TaskState,
victim: WorkerState,
thief: WorkerState,
- duration: float,
+ duration: Occupancy,
cost_multiplier: float,
) -> None:
+ # TODO calculate separately for cpu vs network?
occ_thief = combined_occupancy(thief)
occ_victim = combined_occupancy(victim)
- if occ_thief + cost_multiplier * duration <= occ_victim - duration / 2:
+ duration_total = duration.total
+ if (
+ occ_thief + cost_multiplier * duration_total
+ <= occ_victim - duration_total / 2
+ ):
self.move_task_request(ts, victim, thief)
log.append(
(
start,
level,
ts.key,
- duration,
+ duration_total,
victim.address,
occ_victim,
thief.address,
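
For context (not part of the patch), a small sketch of the steal decision in `maybe_move_task` above, using the combined `Occupancy.total`; all numbers are made up.

```python
# Illustrative steal decision: move the task only if the thief would still be
# comfortably less loaded than the victim afterwards.
from distributed.collections import Occupancy

duration = Occupancy(cpu=2.0, network=0.5)  # estimated cost of the candidate task
cost_multiplier = 1.5                       # roughly transfer_time / compute_time
occ_thief = 1.0                             # combined occupancy of the idle worker
occ_victim = 12.0                           # combined occupancy of the saturated worker

duration_total = duration.total
should_steal = (
    occ_thief + cost_multiplier * duration_total
    <= occ_victim - duration_total / 2
)
assert should_steal  # 4.75 <= 10.75
```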
diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py
index 3336fc8481d..e140d05d863 100644
--- a/distributed/tests/test_cancelled_state.py
+++ b/distributed/tests/test_cancelled_state.py
@@ -1034,7 +1034,7 @@ def f(ev1, ev2, ev3, ev4):
await ev1.wait()
ts = a.state.tasks["x"]
assert ts.state == "executing"
- assert sum(ws.processing.values()) > 0
+ assert any(ws.processing.values())
x.release()
await wait_for_state("x", "cancelled", a)
@@ -1050,7 +1050,7 @@ def f(ev1, ev2, ev3, ev4):
# Test that the scheduler receives a delayed {op: long-running}
assert ws.processing
- while sum(ws.processing.values()):
+ while any(ws.processing.values()):
await asyncio.sleep(0.1)
assert ws.processing
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 3121dcb6bdd..848c74a5278 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -70,6 +70,7 @@
wait,
)
from distributed.cluster_dump import load_cluster_dump
+from distributed.collections import Occupancy
from distributed.comm import CommClosedError
from distributed.compatibility import LINUX, WINDOWS
from distributed.core import Status
@@ -5109,8 +5110,11 @@ def long_running(lock, entered):
await entered.wait()
ts = s.tasks[f.key]
ws = s.workers[a.address]
- assert ws.occupancy == parse_timedelta(
- dask.config.get("distributed.scheduler.unknown-task-duration")
+ assert ws.occupancy == Occupancy(
+ cpu=parse_timedelta(
+ dask.config.get("distributed.scheduler.unknown-task-duration"),
+ ),
+ network=0,
)
while ws.occupancy:
@@ -5118,12 +5122,12 @@ def long_running(lock, entered):
await a.heartbeat()
s._set_duration_estimate(ts, ws)
- assert s.workers[a.address].occupancy == 0
- assert s.total_occupancy == 0
- assert ws.occupancy == 0
+ assert s.workers[a.address].occupancy == Occupancy(0, 0)
+ assert s.total_occupancy == Occupancy(0, 0)
+ assert ws.occupancy == Occupancy(0, 0)
s._ongoing_background_tasks.call_soon(s.reevaluate_occupancy, 0)
- assert s.workers[a.address].occupancy == 0
+ assert s.workers[a.address].occupancy == Occupancy(0, 0)
await l.release()
with (
@@ -5133,8 +5137,8 @@ def long_running(lock, entered):
):
await f
- assert s.total_occupancy == 0
- assert ws.occupancy == 0
+ assert s.total_occupancy == Occupancy(0, 0)
+ assert ws.occupancy == Occupancy(0, 0)
assert not ws.long_running
@@ -5176,14 +5180,16 @@ def long_running(lock, entered):
if ordinary_task:
# Should be exactly 0.5 but if for whatever reason this test runs slow,
# some approximation may kick in increasing this number
- assert s.total_occupancy >= 0.5
- assert ws.occupancy >= 0.5
+ assert not s.total_occupancy.network
+ assert not ws.occupancy.network
+ assert s.total_occupancy.cpu >= 0.5
+ assert ws.occupancy.cpu >= 0.5
await l2.release()
await f2
# In the end, everything should be reset
- assert s.total_occupancy == 0
- assert ws.occupancy == 0
+ assert not s.total_occupancy
+ assert not ws.occupancy
assert not ws.long_running
diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py
index 066cf147a33..3b1ca3b164d 100644
--- a/distributed/tests/test_collections.py
+++ b/distributed/tests/test_collections.py
@@ -7,7 +7,7 @@
import pytest
-from distributed.collections import LRU, HeapSet
+from distributed.collections import LRU, HeapSet, Occupancy
def test_lru():
@@ -339,3 +339,38 @@ def test_heapset_sort_duplicate():
heap.add(c1)
assert list(heap.sorted()) == [c1, c2]
+
+
+def test_occupancy():
+ ozero = Occupancy(0, 0)
+ assert not ozero
+ assert not ozero.total
+ assert ozero == ozero
+
+ o1_0 = Occupancy(1, 0)
+ assert o1_0
+ assert o1_0.total == 1
+ assert ozero != o1_0
+ assert o1_0 + ozero == o1_0
+
+ o0_1 = Occupancy(0, 1)
+ o1_1 = o0_1 + o1_0
+ assert o1_1.total == 2
+ assert o1_1 == o1_0 + o0_1
+
+ assert o1_1 - o0_1 == o1_0
+
+ mut = Occupancy(0, 0)
+ mut += ozero
+ assert not mut
+ mut += o1_0
+ assert mut == o1_0
+ mut += o1_0
+ assert mut == Occupancy(2, 0)
+
+ mut -= o0_1
+ assert mut == Occupancy(2, -1)
+ assert mut.total == 1
+
+ mut.clear()
+ assert mut == ozero
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 35fe85cdc7c..9a768743e03 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -1423,12 +1423,13 @@ async def test_learn_occupancy(c, s, a, b):
await asyncio.sleep(0.01)
nproc = sum(ts.state == "processing" for ts in s.tasks.values())
- assert nproc * 0.1 < s.total_occupancy < nproc * 0.4
+ assert not s.total_occupancy.network
+ assert nproc * 0.1 < s.total_occupancy.cpu < nproc * 0.4
for w in [a, b]:
ws = s.workers[w.address]
- occ = ws.occupancy
+ assert not ws.occupancy.network
proc = len(ws.processing)
- assert proc * 0.1 < occ < proc * 0.4
+ assert proc * 0.1 < ws.occupancy.cpu < proc * 0.4
@pytest.mark.slow
@@ -1440,7 +1441,8 @@ async def test_learn_occupancy_2(c, s, a, b):
await asyncio.sleep(0.01)
nproc = sum(ts.state == "processing" for ts in s.tasks.values())
- assert nproc * 0.1 < s.total_occupancy < nproc * 0.4
+ assert not s.total_occupancy.network
+ assert nproc * 0.1 < s.total_occupancy.cpu < nproc * 0.4
@gen_cluster(client=True)
@@ -1448,14 +1450,14 @@ async def test_occupancy_cleardown(c, s, a, b):
s.validate = False
# Inject excess values in s.occupancy
- s.workers[a.address].occupancy = 2
- s.total_occupancy += 2
+ s.workers[a.address].occupancy.cpu += 2
+ s.total_occupancy.cpu += 2
futures = c.map(slowinc, range(100), delay=0.01)
await wait(futures)
# Verify that occupancy values have been zeroed out
- assert abs(s.total_occupancy) < 0.01
- assert all(ws.occupancy == 0 for ws in s.workers.values())
+ assert abs(s.total_occupancy.total) < 0.01
+ assert all(not ws.occupancy for ws in s.workers.values())
@nodebug
@@ -1492,11 +1494,13 @@ async def test_learn_occupancy_multiple_workers(c, s, a, b):
await wait(x)
- assert not any(v == 0.5 for w in s.workers.values() for v in w.processing.values())
+ assert not any(
+ occ.cpu == 0.5 for w in s.workers.values() for occ in w.processing.values()
+ )
@gen_cluster(client=True)
-async def test_include_communication_in_occupancy(c, s, a, b):
+async def test_occupancy_network(c, s, a, b):
await c.submit(slowadd, 1, 2, delay=0)
x = c.submit(operator.mul, b"0", int(s.bandwidth), workers=a.address)
y = c.submit(operator.mul, b"1", int(s.bandwidth * 1.5), workers=b.address)
@@ -1507,7 +1511,9 @@ async def test_include_communication_in_occupancy(c, s, a, b):
ts = s.tasks[z.key]
assert ts.processing_on == s.workers[b.address]
- assert s.workers[b.address].processing[ts] > 1
+ occ = s.workers[b.address].processing[ts]
+ assert occ.network >= 1
+ assert occ.cpu > 0
await wait(z)
del z
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 6f3874bc1b2..de90c68fb7f 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -1128,7 +1128,7 @@ def block(x, event):
del futs1
- assert all(v == 0 for v in steal.in_flight_occupancy.values())
+ assert all(not v for v in steal.in_flight_occupancy.values())
@gen_cluster(