diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 243f86d577d..e60a63be8e0 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -921,6 +921,16 @@ class TaskGroup:
        The result types of this TaskGroup
 
+    .. attribute:: last_worker: WorkerState
+
+       The worker most recently assigned a task from this group, or None when the group
+       is not identified as root-like by `SchedulerState.decide_worker`.
+
+    .. attribute:: last_worker_tasks_left: int
+
+       If `last_worker` is not None, the number of additional tasks from this group that
+       should be assigned to that worker before a new worker is chosen.
+
     See also
     --------
     TaskPrefix
@@ -936,6 +946,8 @@ class TaskGroup:
     _start: double
     _stop: double
     _all_durations: object
+    _last_worker: WorkerState
+    _last_worker_tasks_left: Py_ssize_t
 
     def __init__(self, name: str):
         self._name = name
@@ -949,6 +961,8 @@ def __init__(self, name: str):
         self._start = 0.0
         self._stop = 0.0
         self._all_durations = defaultdict(float)
+        self._last_worker = None
+        self._last_worker_tasks_left = 0
 
     @property
     def name(self):
@@ -990,6 +1004,14 @@ def start(self):
     def stop(self):
         return self._stop
 
+    @property
+    def last_worker(self):
+        return self._last_worker
+
+    @property
+    def last_worker_tasks_left(self):
+        return self._last_worker_tasks_left
+
     @ccall
     def add(self, o):
         ts: TaskState = o
@@ -2309,21 +2331,60 @@ def transition_no_worker_waiting(self, key):
     @exceptval(check=False)
     def decide_worker(self, ts: TaskState) -> WorkerState:
         """
-        Decide on a worker for task *ts*.  Return a WorkerState.
+        Decide on a worker for task *ts*. Return a WorkerState.
+
+        If it's a root or root-like task, we place it with its relatives to
+        reduce future data transfer.
+
+        If it has dependencies or restrictions, we use
+        `decide_worker_from_deps_and_restrictions`.
+
+        Otherwise, we pick the least occupied worker, or pick from all workers
+        in a round-robin fashion.
         """
+        if not self._workers_dv:
+            return None
+
         ws: WorkerState = None
+        group: TaskGroup = ts._group
         valid_workers: set = self.valid_workers(ts)
 
         if (
             valid_workers is not None
             and not valid_workers
             and not ts._loose_restrictions
-            and self._workers_dv
         ):
             self._unrunnable.add(ts)
             ts.state = "no-worker"
             return ws
 
+        # Group is larger than cluster with few dependencies? Minimize future data transfers.
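+        # (Descriptive note: the task is treated as "root-like" here when it has no worker
+        # restrictions, its group holds more than twice as many tasks as the cluster has
+        # threads, and the whole group depends on fewer than five tasks in total, so
+        # placement is not constrained by data locality.)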
+        if (
+            valid_workers is None
+            and len(group) > self._total_nthreads * 2
+            and sum(map(len, group._dependencies)) < 5
+        ):
+            ws: WorkerState = group._last_worker
+
+            if not (
+                ws and group._last_worker_tasks_left and ws._address in self._workers_dv
+            ):
+                # Last-used worker is full or unknown; pick a new worker for the next few tasks
+                ws = min(
+                    (self._idle_dv or self._workers_dv).values(),
+                    key=partial(self.worker_objective, ts),
+                )
+                group._last_worker_tasks_left = math.floor(
+                    (len(group) / self._total_nthreads) * ws._nthreads
+                )
+
+            # Record `last_worker`, or clear it on the final task
+            group._last_worker = (
+                ws if group.states["released"] + group.states["waiting"] > 1 else None
+            )
+            group._last_worker_tasks_left -= 1
+            return ws
+
         if ts._dependencies or valid_workers is not None:
             ws = decide_worker(
                 ts,
@@ -2332,6 +2393,7 @@ def decide_worker(self, ts: TaskState) -> WorkerState:
                 partial(self.worker_objective, ts),
             )
         else:
+            # Fastpath when there are no related tasks or restrictions
             worker_pool = self._idle or self._workers
             worker_pool_dv = cast(dict, worker_pool)
             wp_vals = worker_pool.values()
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 8d38c69e10b..1608a7d67b0 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -17,7 +17,7 @@ import dask
 from dask import delayed
-from dask.utils import apply, parse_timedelta
+from dask.utils import apply, parse_timedelta, stringify
 from distributed import Client, Nanny, Worker, fire_and_forget, wait
 from distributed.comm import Comm
@@ -126,6 +126,114 @@ async def test_decide_worker_with_restrictions(client, s, a, b, c):
     assert x.key in a.data or x.key in b.data
 
 
+@pytest.mark.parametrize("ndeps", [0, 1, 4])
+@pytest.mark.parametrize(
+    "nthreads",
+    [
+        [("127.0.0.1", 1)] * 5,
+        [("127.0.0.1", 3), ("127.0.0.1", 2), ("127.0.0.1", 1)],
+    ],
+)
+def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads):
+    @gen_cluster(
+        client=True,
+        nthreads=nthreads,
+        config={"distributed.scheduler.work-stealing": False},
+    )
+    async def test(c, s, *workers):
+        r"""
+        Ensure that sibling root tasks are scheduled to the same node, reducing future data transfer.
+
+        We generate a wide layer of "root" tasks (random NumPy arrays). All of those tasks share 0-5
+        trivial dependencies. The ``ndeps=0`` and ``ndeps=1`` cases are most common in real-world use
+        (``ndeps=1`` is basically ``da.from_array(..., inline_array=False)`` or ``da.from_zarr``).
+        The graph is structured like this (though the number of tasks and workers is different):
+
+        |-W1-|  |-W2-|  |-W3-|  |-W4-|   < ---- ideal task scheduling
+
+          q       r       s       t      < --- `sum-aggregate-`
+         / \     / \     / \     / \
+        i   j   k   l   m   n   o   p    < --- `sum-`
+        |   |   |   |   |   |   |   |
+        a   b   c   d   e   f   g   h    < --- `random-`
+         \   \   \   |   |   /   /   /
+               TRIVIAL * 0..5
+
+        Neighboring `random-` tasks should be scheduled on the same worker. We test that, generally,
+        only one worker holds each row of the array, that the `random-` tasks are never transferred,
+        and that there are few transfers overall.
+        """
+        da = pytest.importorskip("dask.array")
+        np = pytest.importorskip("numpy")
+
+        if ndeps == 0:
+            x = da.random.random((100, 100), chunks=(10, 10))
+        else:
+
+            def random(**kwargs):
+                assert len(kwargs) == ndeps
+                return np.random.random((10, 10))
+
+            trivial_deps = {f"k{i}": delayed(object()) for i in range(ndeps)}
+
+            # TODO is there a simpler (non-blockwise) way to make this sort of graph?
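+            # (Descriptive note: `da.blockwise` with only `new_axes` and no array inputs
+            # builds a 10x10 grid of chunks, each computed by `random(**trivial_deps)`,
+            # so every root chunk shares the same `ndeps` trivial delayed dependencies,
+            # roughly the graph shape `da.from_array` on a shared source would produce.)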
+            x = da.blockwise(
+                random,
+                "yx",
+                new_axes={"y": (10,) * 10, "x": (10,) * 10},
+                dtype=float,
+                **trivial_deps,
+            )
+
+        xx, xsum = dask.persist(x, x.sum(axis=1, split_every=20))
+        await xsum
+
+        # Check that each chunk-row of the array is (mostly) stored on the same worker
+        primary_worker_key_fractions = []
+        secondary_worker_key_fractions = []
+        for i, keys in enumerate(x.__dask_keys__()):
+            # Iterate along rows of the array.
+            keys = set(stringify(k) for k in keys)
+
+            # No more than 2 workers should have any keys
+            assert sum(any(k in w.data for k in keys) for w in workers) <= 2
+
+            # What fraction of the keys for this row does each worker hold?
+            key_fractions = [
+                len(set(w.data).intersection(keys)) / len(keys) for w in workers
+            ]
+            key_fractions.sort()
+            # Primary worker: holds the highest fraction of keys
+            # Secondary worker: holds the second-highest fraction of keys
+            primary_worker_key_fractions.append(key_fractions[-1])
+            secondary_worker_key_fractions.append(key_fractions[-2])
+
+        # There may be one or two rows that were poorly split across workers,
+        # but the vast majority of rows should be on only one worker.
+        assert np.mean(primary_worker_key_fractions) >= 0.9
+        assert np.median(primary_worker_key_fractions) == 1.0
+        assert np.mean(secondary_worker_key_fractions) <= 0.1
+        assert np.median(secondary_worker_key_fractions) == 0.0
+
+        # Check that there were few transfers
+        unexpected_transfers = []
+        for worker in workers:
+            for log in worker.incoming_transfer_log:
+                keys = log["keys"]
+                # The root-ish `random-` tasks should never be transferred
+                assert not any(k.startswith("random") for k in keys), keys
+                # `object-` keys (the trivial deps of the root `random-` tasks) may be transferred,
+                if any(not k.startswith("object") for k in keys):
+                    # but not much else should be
+                    unexpected_transfers.append(list(keys))
+
+        # A transfer at the very end to move aggregated results is fine (in fact necessary
+        # with unbalanced workers), but generally there should be very few transfers.
+        assert len(unexpected_transfers) <= 3, unexpected_transfers
+
+    test()
+
+
 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3)
 async def test_move_data_over_break_restrictions(client, s, a, b, c):
     [x] = await client.scatter([1], workers=b.address)
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index c47b3127b0f..7ca8f29bcc4 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -146,21 +146,18 @@ async def test_steal_related_tasks(e, s, a, b, c):
 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10, timeout=1000)
 async def test_dont_steal_fast_tasks_compute_time(c, s, *workers):
-    np = pytest.importorskip("numpy")
-    x = c.submit(np.random.random, 10000000, workers=workers[0].address)
-
     def do_nothing(x, y=None):
         pass
 
-    # execute and measure runtime once
-    await wait(c.submit(do_nothing, 1))
+    xs = c.map(do_nothing, range(10), workers=workers[0].address)
+    await wait(xs)
 
-    futures = c.map(do_nothing, range(1000), y=x)
+    futures = c.map(do_nothing, range(1000), y=xs)
     await wait(futures)
 
-    assert len(s.who_has[x.key]) == 1
-    assert len(s.has_what[workers[0].address]) == 1001
+    assert len(set.union(*(s.who_has[x.key] for x in xs))) == 1
+    assert len(s.has_what[workers[0].address]) == len(xs) + len(futures)
 
 
 @gen_cluster(client=True)
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 903241f7225..c38b8b76363 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -1590,12 +1590,12 @@ async def test_lifetime(cleanup):
     async with Scheduler() as s:
         async with Worker(s.address) as a, Worker(s.address, lifetime="1 seconds") as b:
             async with Client(s.address, asynchronous=True) as c:
-                futures = c.map(slowinc, range(200), delay=0.1)
+                futures = c.map(slowinc, range(200), delay=0.1, workers=[b.address])
                 await asyncio.sleep(1.5)
                 assert b.status != Status.running
                 await b.finished()
 
-                assert set(b.data).issubset(a.data)  # successfully moved data over
+                assert set(b.data) == set(a.data)  # successfully moved data over
 
 
 @gen_cluster(client=True, worker_kwargs={"lifetime": "10s", "lifetime_stagger": "2s"})
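
For reviewers who want to see the co-assignment batching in isolation, here is a rough standalone sketch (not part of the diff) of the arithmetic `decide_worker` now applies to root-like groups: once a worker is chosen, it keeps receiving tasks from the group for roughly floor(len(group) / total_nthreads * ws.nthreads) consecutive assignments before a new worker is picked. The names `WorkerStub` and `assign_root_tasks` below are hypothetical, for illustration only; the real code tracks this state on the `TaskGroup` and chooses the next worker via `worker_objective` (least busy), not round-robin.

    import itertools
    import math

    # Hypothetical stand-in for WorkerState, only to exercise the arithmetic.
    class WorkerStub:
        def __init__(self, name, nthreads):
            self.name = name
            self.nthreads = nthreads

    def assign_root_tasks(n_tasks, workers):
        """Co-assign sibling root tasks: once a worker is chosen, keep assigning to it
        for floor(n_tasks / total_nthreads * worker.nthreads) consecutive tasks."""
        total_nthreads = sum(w.nthreads for w in workers)
        next_worker = itertools.cycle(workers)  # the scheduler picks the least-busy worker instead
        assignments = []
        last_worker = None
        tasks_left = 0
        for _ in range(n_tasks):
            if last_worker is None or tasks_left <= 0:
                last_worker = next(next_worker)
                tasks_left = math.floor(n_tasks / total_nthreads * last_worker.nthreads)
            assignments.append(last_worker.name)
            tasks_left -= 1
        return assignments

    print(assign_root_tasks(12, [WorkerStub("w1", 3), WorkerStub("w2", 2), WorkerStub("w3", 1)]))
    # ['w1', 'w1', 'w1', 'w1', 'w1', 'w1', 'w2', 'w2', 'w2', 'w2', 'w3', 'w3']

With workers of 3, 2, and 1 threads and 12 root tasks, this yields contiguous runs of 6, 4, and 2 tasks, the proportional split the heuristic aims for and the same grouping the new test's ideal-scheduling diagram assumes.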