From e9928dcfa688331b90a4c7cfbfcaf4cff519e750 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 3 Nov 2022 16:15:31 +0100 Subject: [PATCH 01/17] Refactor assert_balanced --- distributed/tests/test_steal.py | 62 ++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index eb66ec7ec12..dc0d5469f14 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -672,9 +672,9 @@ def block(*args, event, **kwargs): counter = itertools.count() - futures = [] - for w, ts in zip(workers, inp): - for t in sorted(ts, reverse=True): + futures_per_worker = defaultdict(list) + for w, tasks in zip(workers, inp): + for t in sorted(tasks, reverse=True): if t: [dat] = await c.scatter( [gen_nbytes(int(t * s.bandwidth))], workers=w.address @@ -692,33 +692,45 @@ def block(*args, event, **kwargs): pure=False, priority=-i, ) - futures.append(f) + futures_per_worker[w].append(f) - while len([ts for ts in s.tasks.values() if ts.processing_on]) < len(futures): - await asyncio.sleep(0.001) + # Make sure all tasks are scheduled on the workers + # We are relying on the futures not to be rootish (and thus not to remain in the + # scheduler-side queue) because they have worker restrictions + wait_for_states = [] + for w, fs in futures_per_worker.items(): + for i, f in enumerate(fs): + # Make sure the first task is executing, all others are ready + state = "executing" if i == 0 else "ready" + wait_for_states.append(wait_for_state(f.key, state, w)) + await asyncio.gather(*wait_for_states) - try: - for _ in range(10): - steal.balance() - await steal.stop() + for w, ts in zip(workers, inp): + assert len(w.state.executing) == min(1, len(ts)) + assert len(w.state.ready) == max(0, len(ts) - 1) - result = [ - sorted( - (int(key_split(ts.key)) for ts in s.workers[w.address].processing), - reverse=True, - ) - for w in workers - ] + # Balance twice since stealing might attempt to steal the already executing task + # on first try and will need a second try to correct its mistake + for _ in range(2): + steal.balance() + await steal.stop() - result2 = sorted(result, reverse=True) - expected2 = sorted(expected, reverse=True) + await ev.set() + await c.gather([f for fs in futures_per_worker.values() for f in fs]) - if result2 == expected2: - # Release the threadpools - return - finally: - await ev.set() - raise Exception(f"Expected: {expected2}; got: {result2}") + result = [ + sorted( + # Exclude input data encoded with ``SizeOf`` + (int(key_split(t)) for t in w.data.keys() if not t.startswith("SizeOf")), + reverse=True, + ) + for w in workers + ] + + result2 = sorted(result, reverse=True) + expected2 = sorted(expected, reverse=True) + + assert result2 == expected2 @pytest.mark.parametrize( From 4dfaee8e329619d58db446a02b80e4cdb088eaf6 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 3 Nov 2022 16:26:41 +0100 Subject: [PATCH 02/17] Add explanation --- distributed/tests/test_steal.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index dc0d5469f14..904c890900c 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -713,6 +713,7 @@ def block(*args, event, **kwargs): # on first try and will need a second try to correct its mistake for _ in range(2): steal.balance() + # steal.stop() ensures that all in-flight stealing requests have been resolved await steal.stop() await ev.set() From 
21a6c91a6f1dc7bf9ae7b61108f2a9ca3db9779c Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 3 Nov 2022 16:27:27 +0100 Subject: [PATCH 03/17] Minor --- distributed/tests/test_steal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 904c890900c..debb6a321b9 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -721,7 +721,7 @@ def block(*args, event, **kwargs): result = [ sorted( - # Exclude input data encoded with ``SizeOf`` + # The name of input data starts with ``SizeOf`` (int(key_split(t)) for t in w.data.keys() if not t.startswith("SizeOf")), reverse=True, ) From 9bc0afc12886061972af2bd82dcab9acb8cf83b3 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 3 Nov 2022 17:34:04 +0100 Subject: [PATCH 04/17] Simplify _run_dependency_balance_test --- distributed/tests/test_steal.py | 98 +++++++++++++-------------------- 1 file changed, 38 insertions(+), 60 deletions(-) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index debb6a321b9..0b949d8a198 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -705,10 +705,6 @@ def block(*args, event, **kwargs): wait_for_states.append(wait_for_state(f.key, state, w)) await asyncio.gather(*wait_for_states) - for w, ts in zip(workers, inp): - assert len(w.state.executing) == min(1, len(ts)) - assert len(w.state.ready) == max(0, len(ts) - 1) - # Balance twice since stealing might attempt to steal the already executing task # on first try and will need a second try to correct its mistake for _ in range(2): @@ -1641,6 +1637,7 @@ async def _run( NO_AMM, config or {}, { + # FIXME "distributed.scheduler.unknown-task-duration": "1s", }, ), @@ -1690,7 +1687,7 @@ async def _dependency_balance_test_permutation( dependencies, permutated_dependency_placement, c, s, workers ) - ev, futures = await _place_tasks( + ev, futures_per_worker = await _place_tasks( permutated_task_placement, permutated_dependency_placement, dependency_futures, @@ -1705,22 +1702,20 @@ async def _dependency_balance_test_permutation( for ws in s.workers.values(): s.check_idle_saturated(ws) - try: - for _ in range(20): - steal.balance() - await steal.stop() + # Balance twice since stealing might attempt to steal the already executing task + # on first try and will need a second try to correct its mistake + for _ in range(2): + steal.balance() + # steal.stop() ensures that all in-flight stealing requests have been resolved + await steal.stop() - permutated_actual_placement = _get_task_placement(s, workers) - actual_placement = [permutated_actual_placement[i] for i in inverse] + await ev.set() + await c.gather([f for fs in futures_per_worker.values() for f in fs]) - if correct_placement_fn(actual_placement): - return - finally: - # Release the threadpools - await ev.set() - await c.gather(futures) + permutated_actual_placement = _get_task_placement(s, workers) + actual_placement = [permutated_actual_placement[i] for i in inverse] - raise AssertionError(actual_placement, permutation) + assert correct_placement_fn(actual_placement), (actual_placement, permutation) async def _place_dependencies( @@ -1755,30 +1750,20 @@ async def _place_dependencies( futures = {} for name, multiplier in dependencies.items(): + key = f"dep-{name}" worker_addresses = dependencies_to_workers[name] futs = await c.scatter( - {name: gen_nbytes(int(multiplier * s.bandwidth))}, + {key: gen_nbytes(int(multiplier * s.bandwidth))}, 
workers=worker_addresses, broadcast=True, ) - futures[name] = futs[name] + futures[name] = futs[key] await c.gather(futures.values()) - _assert_dependency_placement(placement, workers) - return futures -def _assert_dependency_placement(expected, workers): - """Assert that dependencies are placed on the workers as expected.""" - actual = [] - for worker in workers: - actual.append(list(worker.state.tasks.keys())) - - assert actual == expected - - async def _place_tasks( placement: list[list[list[str]]], dependency_placement: list[list[str]], @@ -1786,7 +1771,7 @@ async def _place_tasks( c: Client, s: Scheduler, workers: Sequence[Worker], -) -> tuple[Event, list[Future]]: +) -> tuple[Event, dict[Worker, list[Future]]]: """Places the tasks on the workers as specified. Parameters @@ -1815,36 +1800,36 @@ def block(*args, event, **kwargs): event.wait() counter = itertools.count() - futures = [] - for worker_idx, tasks in enumerate(placement): + futures_per_worker = defaultdict(list) + for worker, tasks in zip(workers, placement): for dependencies in tasks: i = next(counter) dep_key = "".join(sorted(dependencies)) - key = f"{dep_key}-{i}" + key = f"task-{dep_key}-{i}" f = c.submit( block, [dependency_futures[dependency] for dependency in dependencies], event=ev, key=key, - workers=workers[worker_idx].address, + workers=worker.address, allow_other_workers=True, pure=False, priority=-i, ) - futures.append(f) - - while len([ts for ts in s.tasks.values() if ts.processing_on]) < len(futures): - await asyncio.sleep(0.001) + futures_per_worker[worker].append(f) - while any( - len(w.state.tasks) < (len(tasks) + len(dependencies)) - for w, dependencies, tasks in zip(workers, dependency_placement, placement) - ): - await asyncio.sleep(0.001) - - assert_task_placement(placement, s, workers) + # Make sure all tasks are scheduled on the workers + # We are relying on the futures not to be rootish (and thus not to remain in the + # scheduler-side queue) because they have worker restrictions + wait_for_states = [] + for w, fs in futures_per_worker.items(): + for i, f in enumerate(fs): + # Make sure the first task is executing, all others are ready + state = "executing" if i == 0 else "ready" + wait_for_states.append(wait_for_state(f.key, state, w)) + await asyncio.gather(*wait_for_states) - return ev, futures + return ev, futures_per_worker def _get_task_placement( @@ -1854,27 +1839,20 @@ def _get_task_placement( actual = [] for w in workers: actual.append( - [list(key_split(ts.key)) for ts in s.workers[w.address].processing] + [ + list(key_split(key[5:])) # Remove "task-" prefix + for key in w.data.keys() + if key.startswith("task-") + ] ) return _deterministic_placement(actual) -def _equal_placement(left, right): - """Return True IFF the two input placements are equal.""" - return _deterministic_placement(left) == _deterministic_placement(right) - - def _deterministic_placement(placement): """Return a deterministic ordering of the tasks or dependencies on each worker.""" return [sorted(placed) for placed in placement] -def assert_task_placement(expected, s, workers): - """Assert that tasks are placed on the workers as expected.""" - actual = _get_task_placement(s, workers) - assert _equal_placement(actual, expected) - - # Reproducer from https://github.com/dask/distributed/issues/6573 @gen_cluster( client=True, From ee13aa1f1e0a66ef00424fa71320074c58a06711 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 3 Nov 2022 17:46:58 +0100 Subject: [PATCH 05/17] Configure default durations instead of unknown 
durations --- distributed/tests/test_steal.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 0b949d8a198..6c8ef4e9589 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1630,6 +1630,12 @@ async def _run( **kwargs, ) + default_task_durations = { + compose_task_prefix(deps): "1s" + for tasks in task_placement + for deps in tasks + } + gen_cluster( client=True, nthreads=[("", 1)] * len(task_placement), @@ -1637,13 +1643,17 @@ async def _run( NO_AMM, config or {}, { - # FIXME - "distributed.scheduler.unknown-task-duration": "1s", + "distributed.scheduler.default-task-durations": default_task_durations, }, ), )(_run)() +def compose_task_prefix(dependencies: list[str]) -> str: + dep_key = "".join(sorted(dependencies)) + return f"task-{dep_key}" + + async def _dependency_balance_test_permutation( dependencies: Mapping[str, int], dependency_placement: list[list[str]], @@ -1804,8 +1814,7 @@ def block(*args, event, **kwargs): for worker, tasks in zip(workers, placement): for dependencies in tasks: i = next(counter) - dep_key = "".join(sorted(dependencies)) - key = f"task-{dep_key}-{i}" + key = f"{compose_task_prefix(dependencies)}-{i}" f = c.submit( block, [dependency_futures[dependency] for dependency in dependencies], From a777cd480e877ed790467a339a28df6016ca6b7c Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 3 Nov 2022 20:00:22 +0100 Subject: [PATCH 06/17] Trigger CI From 3a4f2d565d9c007c44ca88803640f7bc8bb64248 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 3 Nov 2022 20:30:52 +0100 Subject: [PATCH 07/17] Deterministic ordering for stealable --- distributed/stealing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/distributed/stealing.py b/distributed/stealing.py index b3a36c40f2b..3847ed1f514 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -64,7 +64,7 @@ class InFlightInfo(TypedDict): class WorkStealing(SchedulerPlugin): scheduler: Scheduler # {worker: ({ task states for level 0}, ..., {task states for level 14})} - stealable: dict[str, tuple[set[TaskState], ...]] + stealable: dict[str, tuple[dict[TaskState, None], ...]] # { task state: (worker, level) } key_stealable: dict[TaskState, tuple[str, int]] # (multiplier for level 0, ... 
multiplier for level 14)
@@ -153,7 +153,7 @@ def log(self, msg: Any) -> None:
         return self.scheduler.log_event("stealing", msg)
 
     def add_worker(self, scheduler: Any = None, worker: Any = None) -> None:
-        self.stealable[worker] = tuple(set() for _ in range(15))
+        self.stealable[worker] = tuple({} for _ in range(15))
 
     def remove_worker(self, scheduler: Scheduler, worker: str) -> None:
         del self.stealable[worker]
@@ -218,7 +218,7 @@ def put_key_in_stealable(self, ts: TaskState) -> None:
         assert ts.processing_on
         ws = ts.processing_on
         worker = ws.address
-        self.stealable[worker][level].add(ts)
+        self.stealable[worker][level][ts] = None
         self.key_stealable[ts] = (worker, level)
 
     def remove_key_from_stealable(self, ts: TaskState) -> None:
@@ -228,7 +228,7 @@ def remove_key_from_stealable(self, ts: TaskState) -> None:
 
         worker, level = result
         try:
-            self.stealable[worker][level].remove(ts)
+            del self.stealable[worker][level][ts]
         except KeyError:
             pass
 
@@ -440,13 +440,13 @@ def balance(self) -> None:
                         ts not in self.key_stealable
                         or ts.processing_on is not victim
                     ):
-                        stealable.discard(ts)
+                        del stealable[ts]
                         continue
                     i += 1
                     if not (thief := _get_thief(s, ts, potential_thieves)):
                         continue
                     if ts not in victim.processing:
-                        stealable.discard(ts)
+                        del stealable[ts]
                         continue
 
                     occ_thief = self._combined_occupancy(thief)
@@ -483,7 +483,7 @@ def balance(self) -> None:
                                 thief, occ_thief, nproc_thief
                             ):
                                 potential_thieves.discard(thief)
-                            stealable.discard(ts)
+                            del stealable[ts]
                             self.scheduler.check_idle_saturated(
                                 victim, occ=self._combined_occupancy(victim)
                             )

From 89ba73d5e195378192af0ba3e8833e853262200a Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 4 Nov 2022 08:06:57 +0100
Subject: [PATCH 08/17] Retries

---
 distributed/tests/test_steal.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 6c8ef4e9589..25559ed9ecf 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -705,9 +705,9 @@ def block(*args, event, **kwargs):
             wait_for_states.append(wait_for_state(f.key, state, w))
     await asyncio.gather(*wait_for_states)
 
-    # Balance twice since stealing might attempt to steal the already executing task
-    # on first try and will need a second try to correct its mistake
-    for _ in range(2):
+    # Balance several since stealing might attempt to steal the already executing task
+    # for each saturated worker and will need a chance to correct its mistake
+    for _ in workers:
         steal.balance()
         # steal.stop() ensures that all in-flight stealing requests have been resolved
         await steal.stop()
@@ -1712,9 +1712,9 @@ async def _dependency_balance_test_permutation(
     for ws in s.workers.values():
         s.check_idle_saturated(ws)
 
-    # Balance twice since stealing might attempt to steal the already executing task
-    # on first try and will need a second try to correct its mistake
-    for _ in range(2):
+    # Balance several times since stealing might attempt to steal the already executing task
+    # for each saturated worker and will need a chance to correct its mistake
+    for _ in workers:
         steal.balance()
         # steal.stop() ensures that all in-flight stealing requests have been resolved
         await steal.stop()
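
[Note: PATCH 07 above switched the per-worker ``stealable`` collections from
``set`` to ``dict`` because a ``dict`` with ``None`` values behaves like an
insertion-ordered set (order is guaranteed since Python 3.7), whereas ``set``
iteration order is hash-dependent; this makes the order in which ``balance()``
visits stealable tasks deterministic. PATCH 09 below reverts that change in
favor of the original ``set``-based bookkeeping. A minimal, self-contained
sketch of the idea (plain Python; the ``task-*`` keys are invented for
illustration):

    # A dict with None values acts as an insertion-ordered set.
    ordered: dict[str, None] = {}
    for key in ["task-c", "task-a", "task-b"]:
        ordered[key] = None          # counterpart of set.add()
    ordered.pop("task-a", None)      # counterpart of set.discard()
    assert list(ordered) == ["task-c", "task-b"]  # deterministic order

    # By contrast, the iteration order of a set depends on hashing:
    unordered = {"task-c", "task-a", "task-b"}
]
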
From ba83af220be788853258d521ebc8ee0c4ac6c85c Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 4 Nov 2022 09:14:42 +0100
Subject: [PATCH 09/17] Revert "Deterministic ordering for stealable"

This reverts commit 3a4f2d565d9c007c44ca88803640f7bc8bb64248.

---
 distributed/stealing.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/distributed/stealing.py b/distributed/stealing.py
index 3847ed1f514..b3a36c40f2b 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -64,7 +64,7 @@ class InFlightInfo(TypedDict):
 class WorkStealing(SchedulerPlugin):
     scheduler: Scheduler
     # {worker: ({ task states for level 0}, ..., {task states for level 14})}
-    stealable: dict[str, tuple[dict[TaskState, None], ...]]
+    stealable: dict[str, tuple[set[TaskState], ...]]
     # { task state: (worker, level) }
     key_stealable: dict[TaskState, tuple[str, int]]
     # (multiplier for level 0, ... multiplier for level 14)
@@ -153,7 +153,7 @@ def log(self, msg: Any) -> None:
         return self.scheduler.log_event("stealing", msg)
 
     def add_worker(self, scheduler: Any = None, worker: Any = None) -> None:
-        self.stealable[worker] = tuple({} for _ in range(15))
+        self.stealable[worker] = tuple(set() for _ in range(15))
 
     def remove_worker(self, scheduler: Scheduler, worker: str) -> None:
         del self.stealable[worker]
@@ -218,7 +218,7 @@ def put_key_in_stealable(self, ts: TaskState) -> None:
         assert ts.processing_on
         ws = ts.processing_on
         worker = ws.address
-        self.stealable[worker][level][ts] = None
+        self.stealable[worker][level].add(ts)
         self.key_stealable[ts] = (worker, level)
 
     def remove_key_from_stealable(self, ts: TaskState) -> None:
@@ -228,7 +228,7 @@ def remove_key_from_stealable(self, ts: TaskState) -> None:
 
         worker, level = result
         try:
-            del self.stealable[worker][level][ts]
+            self.stealable[worker][level].remove(ts)
         except KeyError:
             pass
 
@@ -440,13 +440,13 @@ def balance(self) -> None:
                         ts not in self.key_stealable
                         or ts.processing_on is not victim
                     ):
-                        del stealable[ts]
+                        stealable.discard(ts)
                         continue
                     i += 1
                     if not (thief := _get_thief(s, ts, potential_thieves)):
                         continue
                     if ts not in victim.processing:
-                        del stealable[ts]
+                        stealable.discard(ts)
                         continue
 
                     occ_thief = self._combined_occupancy(thief)
@@ -483,7 +483,7 @@ def balance(self) -> None:
                                 thief, occ_thief, nproc_thief
                             ):
                                 potential_thieves.discard(thief)
-                            del stealable[ts]
+                            stealable.discard(ts)
                             self.scheduler.check_idle_saturated(
                                 victim, occ=self._combined_occupancy(victim)
                             )

From ee4ae2a0b6bb20401d38aaae385c68c7f8988ae6 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 4 Nov 2022 09:24:17 +0100
Subject: [PATCH 10/17] Minor reordering and FIXMEs

---
 distributed/stealing.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/distributed/stealing.py b/distributed/stealing.py
index b3a36c40f2b..4fbd5fe7c24 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -439,15 +439,14 @@ def balance(self) -> None:
                     if (
                         ts not in self.key_stealable
                         or ts.processing_on is not victim
+                        or ts not in victim.processing
                     ):
+                        # FIXME: Instead of discarding here, clean up stealable properly
                         stealable.discard(ts)
                         continue
                     i += 1
                     if not (thief := _get_thief(s, ts, potential_thieves)):
                         continue
-                    if ts not in victim.processing:
-                        stealable.discard(ts)
-                        continue
 
                     occ_thief = self._combined_occupancy(thief)
@@ -483,6 +482,9 @@ def balance(self) -> None:
                                 thief, occ_thief, nproc_thief
                             ):
                                 potential_thieves.discard(thief)
+                            # FIXME: move_task_request already implements some logic
+                            # for removing ts from stealable. If we made sure to
+                            # properly clean up, we would not need this
                             stealable.discard(ts)
                             self.scheduler.check_idle_saturated(
                                 victim, occ=self._combined_occupancy(victim)
                             )
From 0801df96a41155937c6282f9655cd73bb3e61a3f Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Tue, 8 Nov 2022 14:33:16 +0100
Subject: [PATCH 11/17] Update distributed/tests/test_steal.py

Co-authored-by: crusaderky

---
 distributed/tests/test_steal.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 25559ed9ecf..35374185732 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -705,7 +705,7 @@ def block(*args, event, **kwargs):
             wait_for_states.append(wait_for_state(f.key, state, w))
     await asyncio.gather(*wait_for_states)
 
-    # Balance several since stealing might attempt to steal the already executing task
+    # Balance several times since stealing might attempt to steal the already executing task
     # for each saturated worker and will need a chance to correct its mistake
     for _ in workers:
         steal.balance()

From a5c5f1e03dc6f1b8a5e4dd5386e72277d9c0882c Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Thu, 10 Nov 2022 15:42:08 +0100
Subject: [PATCH 12/17] Trigger CI

From 548deed1678031abb0160f5d0bb27242742d2d3c Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Mon, 14 Nov 2022 12:27:12 +0000
Subject: [PATCH 13/17] Fix CI

---
 distributed/tests/test_steal.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index dc974f22e47..bd14c048983 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -48,6 +48,7 @@
     slowadd,
     slowidentity,
     slowinc,
+    wait_for_state,
 )
 from distributed.worker_state_machine import (
     ExecuteSuccessEvent,
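
[Note: PATCH 14 below relaxes the per-task expectations: instead of asserting
that the first task on each worker is executing and all others are ready, it
waits until every task is either ready or executing and at least one task per
worker is executing. The asyncio pattern it uses - gathering a group of
"all must hold" conditions together with an
``asyncio.wait(..., return_when=FIRST_COMPLETED)`` over a group of
"any one suffices" conditions - is sketched below in isolation (plain asyncio;
``reaches_state`` is an invented stand-in for distributed's ``wait_for_state``
test helper):

    import asyncio

    async def reaches_state(key: str, state: str) -> None:
        # Stand-in: pretend the task reaches the desired state shortly.
        await asyncio.sleep(0.01)

    async def main() -> None:
        # All tasks must become ready-or-executing ...
        all_of = [reaches_state(f"task-{i}", "ready-or-executing") for i in range(3)]
        # ... and at least one task must start executing.
        any_of = [
            asyncio.create_task(reaches_state(f"task-{i}", "executing"))
            for i in range(3)
        ]
        await asyncio.gather(
            *all_of,
            asyncio.wait(any_of, return_when=asyncio.FIRST_COMPLETED),
        )
        # Let the remaining "executing" waiters finish cleanly.
        await asyncio.gather(*any_of)

    asyncio.run(main())
]
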
From ea54c9834e1643b38917fda3b9a02c58179d7c97 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Mon, 14 Nov 2022 15:02:46 +0100
Subject: [PATCH 14/17] Fix test_balance_expensive_tasks

---
 distributed/tests/test_steal.py | 29 +++++++++++++++++++----------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index bd14c048983..25cbd995e80 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -10,7 +10,7 @@
 from collections import defaultdict
 from operator import mul
 from time import sleep
-from typing import Callable, Iterable, Mapping, Sequence
+from typing import Callable, Coroutine, Iterable, Mapping, Sequence
 
 import numpy as np
 import pytest
@@ -1384,9 +1384,9 @@ def func(*args):
 @pytest.mark.parametrize(
     "cost, ntasks, expect_steal",
     [
-        pytest.param(10, 5, False, id="not enough work to steal"),
-        pytest.param(10, 10, True, id="enough work to steal"),
-        pytest.param(20, 10, False, id="not enough work for increased cost"),
+        pytest.param(10, 10, False, id="not enough work to steal"),
+        pytest.param(10, 12, True, id="enough work to steal"),
+        pytest.param(20, 12, False, id="not enough work for increased cost"),
     ],
 )
 def test_balance_expensive_tasks(cost, ntasks, expect_steal):
@@ -1817,13 +1817,22 @@ def block(*args, event, **kwargs):
 
     # Make sure all tasks are scheduled on the workers
     # We are relying on the futures not to be rootish (and thus not to remain in the
    # scheduler-side queue) because they have worker restrictions
-    wait_for_states = []
+    waits_for_state: list[Coroutine] = []
     for w, fs in futures_per_worker.items():
-        for i, f in enumerate(fs):
-            # Make sure the first task is executing, all others are ready
-            state = "executing" if i == 0 else "ready"
-            wait_for_states.append(wait_for_state(f.key, state, w))
-    await asyncio.gather(*wait_for_states)
+        waits_for_executing_state = []
+        for f in fs:
+            # Every task should be either ready or executing
+            waits_for_state.append(wait_for_state(f.key, ["executing", "ready"], w))
+            waits_for_executing_state.append(
+                asyncio.create_task(wait_for_state(f.key, "executing", w))
+            )
+        # Ensure that each worker has started executing a task
+        waits_for_state.append(
+            asyncio.wait(waits_for_executing_state, return_when="FIRST_COMPLETED")
+        )
+    await asyncio.gather(
+        *waits_for_state,
+    )
 
     return ev, futures_per_worker

From 3f53fda84e2871ce3c30bd60ecb30707a5af7150 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Mon, 14 Nov 2022 17:51:16 +0100
Subject: [PATCH 15/17] Avoid heartbeats

---
 distributed/tests/test_steal.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 25cbd995e80..2e0b0f31ceb 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -1633,6 +1633,8 @@ async def _run(
             "distributed.scheduler.default-task-durations": default_task_durations,
         },
     ),
+    # Avoid heartbeats since comm costs are sensitive to bandwidth updates
+    worker_kwargs={"heartbeat_interval": "100s"},
 )(_run)()
 
 

From 2365f7ac7f637217ae23727e2d6cc146160cff48 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 25 Nov 2022 17:35:12 +0100
Subject: [PATCH 16/17] Fix test

---
 distributed/tests/test_steal.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 2e0b0f31ceb..86966e101ee 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -1390,9 +1390,9 @@ def func(*args):
     ],
 )
 def test_balance_expensive_tasks(cost, ntasks, expect_steal):
-    dependencies = {"a": cost, "b": cost}
-    dependency_placement = [["a"], ["b"]]
-    task_placement = [[["a", "b"]] * ntasks, []]
+    dependencies = {"a": cost}
+    dependency_placement = [["a"], []]
+    task_placement = [[["a"]] * ntasks, []]
 
     def _correct_placement(actual):
         actual_task_counts = [len(placed) for placed in actual]

From 7bcb4d125be8859ecc2dbaa3f5af8143a5c987c0 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Mon, 28 Nov 2022 12:03:27 +0100
Subject: [PATCH 17/17] Add assertion to avoid mistakes in the future

---
 distributed/tests/test_steal.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 86966e101ee..b8f899d74a8 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -1800,8 +1800,17 @@ def block(*args, event, **kwargs):
     counter = itertools.count()
 
     futures_per_worker = defaultdict(list)
-    for worker, tasks in zip(workers, placement):
+    for worker, tasks, placed_dependencies in zip(
+        workers, placement, dependency_placement
+    ):
         for dependencies in tasks:
+            for dependency in dependencies:
+                assert dependency in placed_dependencies, (
+                    f"Dependency {dependency} of task {dependencies} not found "
+                    f"on worker {worker}. Make sure that workers already hold all "
+                    "dependencies of their tasks to avoid transfers that would skew "
+                    "bandwidth measurements"
+                )
             i = next(counter)
             key = f"{compose_task_prefix(dependencies)}-{i}"
             f = c.submit(