66 changes: 64 additions & 2 deletions distributed/scheduler.py
@@ -921,6 +921,16 @@ class TaskGroup:

The result types of this TaskGroup

.. attribute:: last_worker: WorkerState

The worker most recently assigned a task from this group, or None if the group
has not been identified as root-like by `SchedulerState.decide_worker`.

.. attribute:: last_worker_tasks_left: int

If `last_worker` is not None, the number of times that worker should be assigned
subsequent tasks until a new worker is chosen.

See also
--------
TaskPrefix
@@ -936,6 +946,8 @@ class TaskGroup:
_start: double
_stop: double
_all_durations: object
_last_worker: WorkerState
_last_worker_tasks_left: Py_ssize_t

def __init__(self, name: str):
self._name = name
@@ -949,6 +961,8 @@ def __init__(self, name: str):
self._start = 0.0
self._stop = 0.0
self._all_durations = defaultdict(float)
self._last_worker = None
self._last_worker_tasks_left = 0

@property
def name(self):
@@ -990,6 +1004,14 @@ def start(self):
def stop(self):
return self._stop

@property
def last_worker(self):
return self._last_worker

@property
def last_worker_tasks_left(self):
return self._last_worker_tasks_left

@ccall
def add(self, o):
ts: TaskState = o
@@ -2309,21 +2331,60 @@ def transition_no_worker_waiting(self, key):
@exceptval(check=False)
def decide_worker(self, ts: TaskState) -> WorkerState:
"""
Decide on a worker for task *ts*. Return a WorkerState.
Decide on a worker for task *ts*. Return a WorkerState.

If it's a root or root-like task, we place it with its relatives to
reduce future data transfer.

If it has dependencies or restrictions, we use
`decide_worker_from_deps_and_restrictions`.

Otherwise, we pick the least occupied worker, or pick from all workers
in a round-robin fashion.
"""
if not self._workers_dv:
return None

ws: WorkerState = None
group: TaskGroup = ts._group
valid_workers: set = self.valid_workers(ts)

if (
valid_workers is not None
and not valid_workers
and not ts._loose_restrictions
and self._workers_dv
):
self._unrunnable.add(ts)
ts.state = "no-worker"
return ws

# Group is larger than cluster with few dependencies? Minimize future data transfers.
if (
valid_workers is None
and len(group) > self._total_nthreads * 2
and sum(map(len, group._dependencies)) < 5
Comment on lines +2363 to +2365
Member:

I'm wondering if it is possible to construct a test which helps us with this cutoff condition. One of the motivators for this heuristic is that it likely acts as a good binary classifier for graphs, where most have either very small or very large numbers here. AFAIU, this is an unproven assumption.

I'm interested in a test for this particular boundary condition for two reasons:

  • It would help us identify regressions when this boundary is accidentally moved.
  • How good is our classifier? Might it be 40% better if this value were 10? A well-written test could help in analyzing this.

I won't push hard on this if it proves too difficult or others disagree on the value. I'm just having a hard time with heuristics if I can change them without tests breaking.

For the current test I can do any of the following without the test breaking:

  • Remove len(group) > self._total_nthreads * 2 entirely.
  • Increase the boundary for the total thread count, e.g. len(group) > self._total_nthreads * 10 (it breaks eventually if pushed further).
  • Increase the boundary for dependencies, e.g. sum(map(len, group._dependencies)) < 100 (increased further, it breaks).

Member:

In 90% of cases the number here is 0 (like da.random.random) or 1 (like da.from_zarr). I can imagine but can't actually come up with cases where this might be 2 (like da.from_zarr(zarr_array, parameter=some_dask_thing)).

I get the aversion to magic numbers. This one feels pretty safe to me though.
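
To make the boundary under discussion concrete, here is a minimal sketch (not part of this PR) of the classifier; `looks_root_ish`, the cluster sizes, and the group sizes below are all hypothetical, and the real check additionally requires that the task has no worker restrictions (`valid_workers is None`):

    def looks_root_ish(group_size: int, total_nthreads: int, dep_group_sizes: list) -> bool:
        # Mirrors `len(group) > self._total_nthreads * 2
        #          and sum(map(len, group._dependencies)) < 5`
        return group_size > total_nthreads * 2 and sum(dep_group_sizes) < 5

    # A wide da.random.random-style layer on a cluster with 10 threads total (0 deps):
    assert looks_root_ish(group_size=100, total_nthreads=10, dep_group_sizes=[])
    # A da.from_zarr-style layer whose chunks share one trivial dependency (1 dep):
    assert looks_root_ish(group_size=100, total_nthreads=10, dep_group_sizes=[1])
    # A group no larger than 2x the cluster is not treated as root-like:
    assert not looks_root_ish(group_size=15, total_nthreads=10, dep_group_sizes=[])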

):
ws: WorkerState = group._last_worker

if not (
ws and group._last_worker_tasks_left and ws._address in self._workers_dv
):
# Last-used worker is full or unknown; pick a new worker for the next few tasks
ws = min(
(self._idle_dv or self._workers_dv).values(),
key=partial(self.worker_objective, ts),
)
group._last_worker_tasks_left = math.floor(
(len(group) / self._total_nthreads) * ws._nthreads
)

# Record `last_worker`, or clear it on the final task
group._last_worker = (
ws if group.states["released"] + group.states["waiting"] > 1 else None
)
group._last_worker_tasks_left -= 1
return ws

if ts._dependencies or valid_workers is not None:
ws = decide_worker(
ts,
@@ -2332,6 +2393,7 @@ def decide_worker(self, ts: TaskState) -> WorkerState:
partial(self.worker_objective, ts),
)
else:
# Fastpath when there are no related tasks or restrictions
Collaborator Author:

I'm realizing that this codepath will now only rarely be triggered (when there are 0 deps, but also the TaskGroup is small). Do we need to add this round-robining into our selection of a new worker for root-ish tasks? (Since we know we'll be running the tasks on every worker, I'm not sure it matters much that we may always start with the same one in an idle cluster.)

Member:

Possibly. More broadly, this is probably a good reminder that, while we run this on some larger example computations, we should also remember to look at some profiles of the scheduler to see if/how things have changed.

Member:

@gjoseph92 did you run into issues with this yet? I'm curious, have you tried using many workers? (for some sensible definition of many)

Member:

@gjoseph92 did this come up in profiling? this seems like the only pending comment. I'd like to get this in if possible

Collaborator Author:

It did not come up in profiling, and I haven't run into any issues with it. I feel pretty confident that round-robining is irrelevant when we're running TaskGroups larger than the cluster. I mostly brought it up because this branch is now pretty long and complicated for a codepath that we'll almost never go down. But maybe that's okay.

have you tried using many workers?

I haven't tried pangeo-style workloads with >30 workers, but I have tried my standard shuffle-profile with this, which prompted 91aee92, which I need to look into a little more.

Member:

My hope is that we can get this in today or tomorrow. Is that hope achievable? If not, do you have a sense for what a reasonable deadline would be?

Collaborator Author:

Today.

Member:

I shall prepare the happy dance.

worker_pool = self._idle or self._workers
worker_pool_dv = cast(dict, worker_pool)
wp_vals = worker_pool.values()
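Before moving on to the tests, here is a minimal sketch (not from the diff; all numbers are hypothetical) of how the batch size stored in `_last_worker_tasks_left` above works out:

    import math

    len_group = 100      # tasks in the root-like TaskGroup (hypothetical)
    total_nthreads = 10  # threads across the whole cluster (hypothetical)
    ws_nthreads = 3      # threads on the worker that was just chosen (hypothetical)

    # Mirrors `math.floor((len(group) / self._total_nthreads) * ws._nthreads)`
    batch = math.floor((len_group / total_nthreads) * ws_nthreads)
    assert batch == 30
    # The chosen worker receives roughly its fair share of the group's tasks;
    # once the counter is exhausted, `decide_worker` picks a new least-occupied worker.
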
110 changes: 109 additions & 1 deletion distributed/tests/test_scheduler.py
@@ -17,7 +17,7 @@

import dask
from dask import delayed
from dask.utils import apply, parse_timedelta
from dask.utils import apply, parse_timedelta, stringify

from distributed import Client, Nanny, Worker, fire_and_forget, wait
from distributed.comm import Comm
@@ -126,6 +126,114 @@ async def test_decide_worker_with_restrictions(client, s, a, b, c):
assert x.key in a.data or x.key in b.data


@pytest.mark.parametrize("ndeps", [0, 1, 4])
@pytest.mark.parametrize(
"nthreads",
[
[("127.0.0.1", 1)] * 5,
[("127.0.0.1", 3), ("127.0.0.1", 2), ("127.0.0.1", 1)],
],
)
def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads):
@gen_cluster(
client=True,
nthreads=nthreads,
config={"distributed.scheduler.work-stealing": False},
)
async def test(c, s, *workers):
r"""
Ensure that sibling root tasks are scheduled to the same node, reducing future data transfer.

We generate a wide layer of "root" tasks (random NumPy arrays). All of those tasks share 0-5
trivial dependencies. The ``ndeps=0`` and ``ndeps=1`` cases are most common in real-world use
(``ndeps=1`` is basically ``da.from_array(..., inline_array=False)`` or ``da.from_zarr``).
The graph is structured like this (though the number of tasks and workers is different):

|-W1-| |-W2-| |-W3-| |-W4-| < ---- ideal task scheduling

q r s t < --- `sum-aggregate-`
/ \ / \ / \ / \
i j k l m n o p < --- `sum-`
| | | | | | | |
a b c d e f g h < --- `random-`
\ \ \ | | / / /
TRIVIAL * 0..5

Neighboring `random-` tasks should be scheduled on the same worker. We test that generally,
only one worker holds each row of the array, that the `random-` tasks are never transferred,
and that there are few transfers overall.
"""
da = pytest.importorskip("dask.array")
np = pytest.importorskip("numpy")

if ndeps == 0:
x = da.random.random((100, 100), chunks=(10, 10))
else:

def random(**kwargs):
assert len(kwargs) == ndeps
return np.random.random((10, 10))

trivial_deps = {f"k{i}": delayed(object()) for i in range(ndeps)}

# TODO is there a simpler (non-blockwise) way to make this sort of graph?
Member:

TODO still relevant?

x = da.blockwise(
random,
"yx",
new_axes={"y": (10,) * 10, "x": (10,) * 10},
dtype=float,
**trivial_deps,
)

xx, xsum = dask.persist(x, x.sum(axis=1, split_every=20))
await xsum

# Check that each chunk-row of the array is (mostly) stored on the same worker
primary_worker_key_fractions = []
secondary_worker_key_fractions = []
for i, keys in enumerate(x.__dask_keys__()):
# Iterate along rows of the array.
keys = set(stringify(k) for k in keys)

# No more than 2 workers should have any keys
assert sum(any(k in w.data for k in keys) for w in workers) <= 2

# What fraction of the keys for this row does each worker hold?
key_fractions = [
len(set(w.data).intersection(keys)) / len(keys) for w in workers
]
key_fractions.sort()
# Primary worker: holds the highest percentage of keys
# Secondary worker: holds the second highest percentage of keys
primary_worker_key_fractions.append(key_fractions[-1])
secondary_worker_key_fractions.append(key_fractions[-2])

# There may be one or two rows that were poorly split across workers,
# but the vast majority of rows should only be on one worker.
assert np.mean(primary_worker_key_fractions) >= 0.9
assert np.median(primary_worker_key_fractions) == 1.0
assert np.mean(secondary_worker_key_fractions) <= 0.1
assert np.median(secondary_worker_key_fractions) == 0.0

# Check that there were few transfers
unexpected_transfers = []
for worker in workers:
for log in worker.incoming_transfer_log:
keys = log["keys"]
# The root-ish tasks should never be transferred
assert not any(k.startswith("random") for k in keys), keys
# `object-` keys (the trivial deps of the root random tasks) should be transferred
if any(not k.startswith("object") for k in keys):
# But not many other things should be
unexpected_transfers.append(list(keys))

# A transfer at the very end to move aggregated results is fine (necessary with unbalanced workers in fact),
# but generally there should be very very few transfers.
assert len(unexpected_transfers) <= 3, unexpected_transfers

test()


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3)
async def test_move_data_over_break_restrictions(client, s, a, b, c):
[x] = await client.scatter([1], workers=b.address)
13 changes: 5 additions & 8 deletions distributed/tests/test_steal.py
@@ -146,21 +146,18 @@ async def test_steal_related_tasks(e, s, a, b, c):

@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10, timeout=1000)
async def test_dont_steal_fast_tasks_compute_time(c, s, *workers):
np = pytest.importorskip("numpy")
x = c.submit(np.random.random, 10000000, workers=workers[0].address)

def do_nothing(x, y=None):
pass

# execute and measure runtime once
await wait(c.submit(do_nothing, 1))
xs = c.map(do_nothing, range(10), workers=workers[0].address)
await wait(xs)

futures = c.map(do_nothing, range(1000), y=x)
futures = c.map(do_nothing, range(1000), y=xs)

await wait(futures)

assert len(s.who_has[x.key]) == 1
assert len(s.has_what[workers[0].address]) == 1001
assert len(set.union(*(s.who_has[x.key] for x in xs))) == 1
assert len(s.has_what[workers[0].address]) == len(xs) + len(futures)


@gen_cluster(client=True)
4 changes: 2 additions & 2 deletions distributed/tests/test_worker.py
@@ -1590,12 +1590,12 @@ async def test_lifetime(cleanup):
async with Scheduler() as s:
async with Worker(s.address) as a, Worker(s.address, lifetime="1 seconds") as b:
async with Client(s.address, asynchronous=True) as c:
futures = c.map(slowinc, range(200), delay=0.1)
futures = c.map(slowinc, range(200), delay=0.1, workers=[b.address])
await asyncio.sleep(1.5)
assert b.status != Status.running
await b.finished()

assert set(b.data).issubset(a.data) # successfully moved data over
assert set(b.data) == set(a.data) # successfully moved data over


@gen_cluster(client=True, worker_kwargs={"lifetime": "10s", "lifetime_stagger": "2s"})