4 changes: 1 addition & 3 deletions .pre-commit-config.yaml
@@ -10,9 +10,7 @@ repos:
       - id: isort
         language_version: python3
   - repo: https://github.com/asottile/pyupgrade
-    # Do not upgrade: there's a bug in Cython that causes sum(... for ...) to fail;
-    # it needs sum([... for ...])
-    rev: v2.13.0
+    rev: v2.32.0
     hooks:
       - id: pyupgrade
         args:
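The deleted comment had pinned pyupgrade to v2.13.0, apparently because parts of the scheduler could be compiled with Cython, which choked on sum() over a generator expression. With that constraint gone, the bump to v2.32.0 lets the hook rewrite sum([... for ...]) into the generator form, which accounts for most of the Python churn in the files below. A minimal sketch of the rewrite, with illustrative values only:

    # Before: the list comprehension materializes an intermediate list
    sizes = [1024, 2048, 4096]
    total = sum([n * 2 for n in sizes])

    # After: the generator expression feeds sum() lazily, one item at a time
    total = sum(n * 2 for n in sizes)

Both forms compute the same result; the generator simply avoids allocating the throwaway list.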
2 changes: 1 addition & 1 deletion distributed/client.py
@@ -52,7 +52,7 @@
 from tornado.ioloop import PeriodicCallback

 from distributed import cluster_dump, preloading
-from distributed import versions as version_module  # type: ignore
+from distributed import versions as version_module
 from distributed.batched import BatchedSend
 from distributed.cfexecutor import ClientExecutor
 from distributed.core import (
13 changes: 9 additions & 4 deletions distributed/comm/asyncio_tcp.py
@@ -6,6 +6,7 @@
 import os
 import socket
 import struct
+import sys
 import weakref
 from itertools import islice
 from typing import Any
@@ -776,10 +777,14 @@ class _ZeroCopyWriter:
 # (which would be very large), and set a limit on the number of buffers to
 # pass to sendmsg.
 if hasattr(socket.socket, "sendmsg"):
-    try:
-        SENDMSG_MAX_COUNT = os.sysconf("SC_IOV_MAX")  # type: ignore
-    except Exception:
-        SENDMSG_MAX_COUNT = 16  # Should be supported on all systems
+    # Note: can't use WINDOWS constant as it upsets mypy
+    if sys.platform == "win32":
+        SENDMSG_MAX_COUNT = 16  # No os.sysconf available
+    else:
+        try:
+            SENDMSG_MAX_COUNT = os.sysconf("SC_IOV_MAX")
+        except Exception:
+            SENDMSG_MAX_COUNT = 16  # Should be supported on all systems
 else:
     SENDMSG_MAX_COUNT = 1  # sendmsg not supported, use send instead
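The restructure is for mypy's benefit: typeshed only exposes os.sysconf on POSIX, so the call must sit in a branch where sys.platform has been narrowed away from "win32". SC_IOV_MAX itself is the OS cap on how many iovec buffers a single sendmsg() call may carry. For context, a rough sketch of how such a cap gets consumed, batching buffers into sendmsg-sized groups (iter_batches is a hypothetical helper, not code from this PR):

    from itertools import islice

    def iter_batches(buffers, max_count):
        # Yield slices of at most max_count buffers -- the unit that can
        # be handed to socket.sendmsg() as a single iovec array.
        it = iter(buffers)
        while batch := list(islice(it, max_count)):
            yield batch

    # e.g. with SENDMSG_MAX_COUNT == 4:
    assert list(iter_batches(range(10), 4)) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]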
2 changes: 1 addition & 1 deletion distributed/comm/registry.py
@@ -15,7 +15,7 @@ def __call__(self, **kwargs: str) -> Iterable[importlib.metadata.EntryPoint]:
 if sys.version_info >= (3, 10):
     # py3.10 importlib.metadata type annotations are not in mypy yet
     # https://github.com/python/typeshed/pull/7331
-    _entry_points: _EntryPoints = importlib.metadata.entry_points  # type: ignore[assignment]
+    _entry_points: _EntryPoints = importlib.metadata.entry_points
 else:

     def _entry_points(
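The ignore was only needed while typeshed lacked the py3.10 entry_points() signature (see the linked typeshed PR in the comment above). For context, a minimal sketch of the version split this shim papers over; load_backends is a hypothetical helper, not part of the module:

    import importlib.metadata
    import sys

    def load_backends(group: str) -> list:
        if sys.version_info >= (3, 10):
            # 3.10+: entry_points() accepts keyword filters and returns a selection
            eps = importlib.metadata.entry_points(group=group)
        else:
            # older: entry_points() returns a {group: [EntryPoint, ...]} mapping
            eps = importlib.metadata.entry_points().get(group, [])
        return [ep.load() for ep in eps]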
2 changes: 1 addition & 1 deletion distributed/comm/ucx.py
@@ -40,7 +40,7 @@
     except ImportError:
         pass
 else:
-    ucp = None  # type: ignore
+    ucp = None

 device_array = None
 pre_existing_cuda_context = False
2 changes: 1 addition & 1 deletion distributed/compatibility.py
@@ -9,7 +9,7 @@

 LINUX = sys.platform == "linux"
 MACOS = sys.platform == "darwin"
-WINDOWS = sys.platform.startswith("win")
+WINDOWS = sys.platform == "win32"


 if sys.version_info >= (3, 9):
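This one-liner is the heart of the PR: mypy special-cases direct comparisons on sys.platform for platform-dependent type checking, but that knowledge does not propagate through a module-level constant such as WINDOWS, which is why the files in this diff inline the check and note that the constant "upsets mypy". The definition itself is also tightened to the canonical equality form (CPython reports "win32" on both 32- and 64-bit Windows, so behavior is unchanged). A small illustration of why the constant cannot replace the inline check:

    import sys

    WINDOWS = sys.platform == "win32"

    if sys.platform == "win32":
        # mypy narrows the platform here, so Windows-only modules type-check
        import msvcrt  # noqa: F401
    else:
        # ...and POSIX-only modules are accepted in this branch
        import resource  # noqa: F401

    if WINDOWS:
        # mypy sees only `if <some bool>:` -- no narrowing happens, so code in
        # this branch is type-checked as if it could run on every platform
        pass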
4 changes: 2 additions & 2 deletions distributed/dashboard/components/scheduler.py
@@ -2794,7 +2794,7 @@ def _get_timeseries(self, restrict_to_existing=False):
         back = None
         # Remove any periods of zero compute at the front or back of the timeseries
         if len(self.plugin.compute):
-            agg = sum([np.array(v[front:]) for v in self.plugin.compute.values()])
+            agg = sum(np.array(v[front:]) for v in self.plugin.compute.values())
             front2 = len(agg) - len(np.trim_zeros(agg, trim="f"))
             front += front2
             back = len(np.trim_zeros(agg, trim="b")) - len(agg) or None
@@ -3192,7 +3192,7 @@ def update(self):
             "names": ["Scheduler", "Workers"],
             "values": [
                 s._tick_interval_observed,
-                sum([w.metrics["event_loop_interval"] for w in s.workers.values()])
+                sum(w.metrics["event_loop_interval"] for w in s.workers.values())
                 / (len(s.workers) or 1),
             ],
         }
2 changes: 1 addition & 1 deletion distributed/http/utils.py
@@ -38,7 +38,7 @@ def get_handlers(server, modules: list[str], prefix="/"):
     _routes = []
     for module_name in modules:
         module = importlib.import_module(module_name)
-        _routes.extend(module.routes)  # type: ignore
+        _routes.extend(module.routes)

     routes = []
6 changes: 3 additions & 3 deletions distributed/pytest_resourceleaks.py
@@ -55,7 +55,6 @@ def test1():
 import psutil
 import pytest

-from distributed.compatibility import WINDOWS
 from distributed.metrics import time
@@ -155,10 +154,11 @@ def format(self, before: int, after: int) -> str:

 class FDChecker(ResourceChecker, name="fds"):
     def measure(self) -> int:
-        if WINDOWS:
+        # Note: can't use WINDOWS constant as it upsets mypy
+        if sys.platform == "win32":
             # Don't use num_handles(); you'll get tens of thousands of reported leaks
             return 0
-        return psutil.Process().num_fds()  # type: ignore
+        return psutil.Process().num_fds()

     def has_leak(self, before: int, after: int) -> bool:
         return after > before
68 changes: 32 additions & 36 deletions distributed/scheduler.py
@@ -744,14 +744,14 @@ def __repr__(self) -> str:

     @property
     def nbytes_total(self) -> int:
-        return sum([tg.nbytes_total for tg in self.groups])
+        return sum(tg.nbytes_total for tg in self.groups)

     def __len__(self) -> int:
         return sum(map(len, self.groups))

     @property
     def duration(self) -> float:
-        return sum([tg.duration for tg in self.groups])
+        return sum(tg.duration for tg in self.groups)

     @property
     def types(self) -> set[str]:
@@ -1400,7 +1400,9 @@ def new_task(
     # State Transitions #
     #####################

-    def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs):
+    def _transition(
+        self, key: str, finish: str, stimulus_id: str, *args, **kwargs
+    ) -> tuple[dict, dict, dict]:
         """Transition a key from its current state to the finish state

         Examples
@@ -1432,9 +1434,9 @@ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs):
         if self.transition_counter_max:
             assert self.transition_counter < self.transition_counter_max

-        recommendations = {}  # type: ignore
-        worker_msgs = {}  # type: ignore
-        client_msgs = {}  # type: ignore
+        recommendations: dict = {}
+        worker_msgs: dict = {}
+        client_msgs: dict = {}

         if self.plugins:
             dependents = set(ts.dependents)
@@ -1444,47 +1446,41 @@ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs):
             if func is not None:
                 recommendations, client_msgs, worker_msgs = func(
                     self, key, stimulus_id, *args, **kwargs
-                )  # type: ignore
+                )

             elif "released" not in (start, finish):
                 assert not args and not kwargs, (args, kwargs, start, finish)
-                a_recs: dict
-                a_cmsgs: dict
-                a_wmsgs: dict
-                a: tuple = self._transition(key, "released", stimulus_id)
-                a_recs, a_cmsgs, a_wmsgs = a
+                a_recs, a_cmsgs, a_wmsgs = self._transition(
+                    key, "released", stimulus_id
+                )

                 v = a_recs.get(key, finish)
                 func = self._TRANSITIONS_TABLE["released", v]
-                b_recs: dict
-                b_cmsgs: dict
-                b_wmsgs: dict
-                b: tuple = func(self, key, stimulus_id)  # type: ignore
-                b_recs, b_cmsgs, b_wmsgs = b
+                b_recs, b_cmsgs, b_wmsgs = func(self, key, stimulus_id)

                 recommendations.update(a_recs)
                 for c, new_msgs in a_cmsgs.items():
-                    msgs = client_msgs.get(c)  # type: ignore
+                    msgs = client_msgs.get(c)
                     if msgs is not None:
                         msgs.extend(new_msgs)
                     else:
                         client_msgs[c] = new_msgs
                 for w, new_msgs in a_wmsgs.items():
-                    msgs = worker_msgs.get(w)  # type: ignore
+                    msgs = worker_msgs.get(w)
                     if msgs is not None:
                         msgs.extend(new_msgs)
                     else:
                         worker_msgs[w] = new_msgs

                 recommendations.update(b_recs)
                 for c, new_msgs in b_cmsgs.items():
-                    msgs = client_msgs.get(c)  # type: ignore
+                    msgs = client_msgs.get(c)
                     if msgs is not None:
                         msgs.extend(new_msgs)
                     else:
                         client_msgs[c] = new_msgs
                 for w, new_msgs in b_wmsgs.items():
-                    msgs = worker_msgs.get(w)  # type: ignore
+                    msgs = worker_msgs.get(w)
                     if msgs is not None:
                         msgs.extend(new_msgs)
                     else:
@@ -1953,7 +1949,7 @@ def transition_processing_memory(
             assert not ts.exception_blame
             assert ts.state == "processing"

-        ws = self.workers.get(worker)  # type: ignore
+        ws = self.workers.get(worker)
         if ws is None:
             recommendations[key] = "released"
             return recommendations, client_msgs, worker_msgs
@@ -2280,7 +2276,7 @@ def transition_processing_erred(
         traceback=None,
         exception_text: str = None,
         traceback_text: str = None,
-        worker: str = None,  # type: ignore
+        worker: str = None,
         **kwargs,
     ):
         ws: WorkerState
@@ -3455,7 +3451,7 @@
     ) -> dict[str, Any]:
         address = self.coerce_address(address, resolve_address)
         address = normalize_address(address)
-        ws: WorkerState = self.workers.get(address)  # type: ignore
+        ws = self.workers.get(address)
         if ws is None:
             return {"status": "missing"}
@@ -4778,7 +4774,7 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None):
     def handle_worker_status_change(
         self, status: str, worker: str, stimulus_id: str
     ) -> None:
-        ws: WorkerState = self.workers.get(worker)  # type: ignore
+        ws = self.workers.get(worker)
         if not ws:
             return
         prev_status = ws.status
@@ -5290,9 +5286,9 @@ async def gather_on_worker(
             )
             return set(who_has)

-        ws: WorkerState = self.workers.get(worker_address)  # type: ignore
+        ws = self.workers.get(worker_address)

-        if ws is None:
+        if not ws:
             logger.warning(f"Worker {worker_address} lost during replication")
             return set(who_has)
         elif result["status"] == "OK":
@@ -5344,8 +5340,8 @@ async def delete_worker_data(
             )
             return

-        ws: WorkerState = self.workers.get(worker_address)  # type: ignore
-        if ws is None:
+        ws = self.workers.get(worker_address)
+        if not ws:
             return

         for key in keys:
@@ -5917,9 +5913,9 @@ def workers_to_close(
         groups = groupby(key, self.workers.values())

         limit_bytes = {
-            k: sum([ws.memory_limit for ws in v]) for k, v in groups.items()
+            k: sum(ws.memory_limit for ws in v) for k, v in groups.items()
         }
-        group_bytes = {k: sum([ws.nbytes for ws in v]) for k, v in groups.items()}
+        group_bytes = {k: sum(ws.nbytes for ws in v) for k, v in groups.items()}

         limit = sum(limit_bytes.values())
         total = sum(group_bytes.values())
@@ -6871,8 +6867,8 @@ def profile_to_figure(state):
             tasks_timings=tasks_timings,
             address=self.address,
             nworkers=len(self.workers),
-            threads=sum([ws.nthreads for ws in self.workers.values()]),
-            memory=format_bytes(sum([ws.memory_limit for ws in self.workers.values()])),
+            threads=sum(ws.nthreads for ws in self.workers.values()),
+            memory=format_bytes(sum(ws.memory_limit for ws in self.workers.values())),
             code=code,
             dask_version=dask.__version__,
             distributed_version=distributed.__version__,
@@ -7106,8 +7102,8 @@ def adaptive_target(self, target_duration=None):
         cpu = max(1, cpu)

         # add more workers if more than 60% of memory is used
-        limit = sum([ws.memory_limit for ws in self.workers.values()])
-        used = sum([ws.nbytes for ws in self.workers.values()])
+        limit = sum(ws.memory_limit for ws in self.workers.values())
+        used = sum(ws.nbytes for ws in self.workers.values())
         memory = 0
         if used > 0.6 * limit and limit > 0:
             memory = 2 * len(self.workers)
@@ -7519,7 +7515,7 @@ def validate_task_state(ts: TaskState) -> None:

     if ts.actor:
         if ts.state == "memory":
-            assert sum([ts in ws.actors for ws in ts.who_has]) == 1
+            assert sum(ts in ws.actors for ws in ts.who_has) == 1
         if ts.state == "processing":
             assert ts.processing_on
             assert ts in ts.processing_on.actors
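Most of the scheduler churn reduces to two patterns: empty containers gain an explicit annotation at the point of assignment instead of a # type: ignore, and dict.get() results are left for mypy to infer as Optional rather than being force-annotated with the non-optional type. A compressed sketch of both, using a stand-in WorkerState and a made-up address:

    class WorkerState:
        nthreads: int = 1

    workers: dict[str, WorkerState] = {}

    # Annotating the empty dict tells mypy its type without an ignore
    recommendations: dict = {}

    # .get() is inferred as "WorkerState | None"; the old code pinned it to
    # WorkerState and silenced the mismatch with a type: ignore
    ws = workers.get("tcp://10.0.0.1:8786")
    if ws is None:
        recommendations["x"] = "released"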
2 changes: 1 addition & 1 deletion distributed/shuffle/shuffle_extension.py
@@ -197,7 +197,7 @@ def get_output_partition(self, i: int) -> pd.DataFrame:
             self.output_partitions_left > 0
         ), f"No outputs remaining, but requested output partition {i} on {self.worker.address}."

-        sync(self.worker.loop, self.multi_file.flush)  # type: ignore
+        sync(self.worker.loop, self.multi_file.flush)
         try:
             df = self.multi_file.read(i)
             with self.time("cpu"):
2 changes: 1 addition & 1 deletion distributed/spill.py
@@ -31,7 +31,7 @@ class SpilledSize(NamedTuple):
     def __add__(self, other: SpilledSize) -> SpilledSize:  # type: ignore
         return SpilledSize(self.memory + other.memory, self.disk + other.disk)

-    def __sub__(self, other: SpilledSize) -> SpilledSize:  # type: ignore
+    def __sub__(self, other: SpilledSize) -> SpilledSize:
         return SpilledSize(self.memory - other.memory, self.disk - other.disk)
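The asymmetry here is deliberate: tuple already defines __add__ (sequence concatenation), so overriding it on a NamedTuple with an incompatible signature still needs the ignore, while tuple has no __sub__ for the new method to clash with. Hence only the second ignore could be dropped. A minimal reproduction with a stand-in Size class:

    from typing import NamedTuple

    class Size(NamedTuple):
        memory: int
        disk: int

        # Conflicts with tuple.__add__'s signature, so mypy still needs the ignore
        def __add__(self, other: "Size") -> "Size":  # type: ignore
            return Size(self.memory + other.memory, self.disk + other.disk)

        # tuple has no __sub__, so this override type-checks cleanly
        def __sub__(self, other: "Size") -> "Size":
            return Size(self.memory - other.memory, self.disk - other.disk)

    assert Size(3, 4) - Size(1, 1) == Size(2, 3)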
18 changes: 10 additions & 8 deletions distributed/system.py
@@ -17,6 +17,7 @@ def memory_limit() -> int:
     limit = psutil.virtual_memory().total

     # Check cgroups if available
+    # Note: can't use LINUX and WINDOWS constants as they upset mypy
     if sys.platform == "linux":
         try:
             with open("/sys/fs/cgroup/memory/memory.limit_in_bytes") as f:
@@ -27,14 +28,15 @@
             pass

     # Check rlimit if available
-    try:
-        import resource
-
-        hard_limit = resource.getrlimit(resource.RLIMIT_RSS)[1]  # type: ignore
-        if hard_limit > 0:
-            limit = min(limit, hard_limit)
-    except (ImportError, OSError):
-        pass
+    if sys.platform != "win32":
+        try:
+            import resource
+
+            hard_limit = resource.getrlimit(resource.RLIMIT_RSS)[1]
+            if hard_limit > 0:
+                limit = min(limit, hard_limit)
+        except (ImportError, OSError):
+            pass

     return limit
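Guards like this only pay off if mypy is actually run as if on each OS; its --platform flag checks the code against another platform's stubs without needing that machine, so the POSIX-only resource module is rejected unless the import sits inside a narrowed branch. A plausible CI step, not taken from this PR:

    # Hypothetical helper: run mypy once per platform so both sides of
    # every sys.platform branch get type-checked.
    import subprocess

    for platform in ("linux", "win32"):
        subprocess.run(["mypy", "--platform", platform, "distributed"], check=True)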
10 changes: 9 additions & 1 deletion distributed/tests/test_utils_test.py
@@ -3,6 +3,7 @@
 import pathlib
 import signal
 import socket
+import sys
 import threading
 from contextlib import contextmanager
 from time import sleep
@@ -563,9 +564,16 @@ async def test_dump_cluster_unresponsive_remote_worker(c, s, a, b, tmpdir):
     clog_fut.cancel()


+# Note: can't use WINDOWS constant as it upsets mypy
+if sys.platform == "win32":
+    TERM_SIGNALS = (signal.SIGTERM, signal.SIGINT)
+else:
+    TERM_SIGNALS = (signal.SIGTERM, signal.SIGHUP, signal.SIGINT)
+
+
 def garbage_process(barrier, ignore_sigterm: bool = False, t: float = 3600) -> None:
     if ignore_sigterm:
-        for signum in (signal.SIGTERM, signal.SIGHUP, signal.SIGINT):  # type: ignore
+        for signum in TERM_SIGNALS:
             signal.signal(signum, signal.SIG_IGN)
     barrier.wait()
     sleep(t)
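The module-level TERM_SIGNALS split exists because signal.SIGHUP is absent on Windows, at runtime and in typeshed alike, so referencing it unconditionally needed the old inline ignore. A quick probe of what the current interpreter provides:

    import signal

    available = {name for name in ("SIGTERM", "SIGHUP", "SIGINT") if hasattr(signal, name)}
    print(available)  # on Windows: {'SIGTERM', 'SIGINT'} -- no SIGHUP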