diff --git a/distributed/stealing.py b/distributed/stealing.py
index 8332c60d45c..11e5fb61e82 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -439,15 +439,14 @@ def balance(self) -> None:
                     if (
                         ts not in self.key_stealable
                         or ts.processing_on is not victim
+                        or ts not in victim.processing
                     ):
+                        # FIXME: Instead of discarding here, clean up stealable properly
                         stealable.discard(ts)
                         continue
                     i += 1
                     if not (thief := _get_thief(s, ts, potential_thieves)):
                         continue
-                    if ts not in victim.processing:
-                        stealable.discard(ts)
-                        continue
 
                     occ_thief = self._combined_occupancy(thief)
                     occ_victim = self._combined_occupancy(victim)
@@ -483,6 +482,9 @@
                                 thief, occ_thief, nproc_thief
                             ):
                                 potential_thieves.discard(thief)
+                                # FIXME: move_task_request already implements some logic
+                                # for removing ts from stealable. If we made sure to
+                                # properly clean up, we would not need this
                                 stealable.discard(ts)
                     self.scheduler.check_idle_saturated(
                         victim, occ=self._combined_occupancy(victim)
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index bf726d1bcf9..39e8fda5268 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -10,7 +10,7 @@
 from collections import defaultdict
 from operator import mul
 from time import sleep
-from typing import Callable, Iterable, Mapping, Sequence
+from typing import Callable, Coroutine, Iterable, Mapping, Sequence
 
 import pytest
 from tlz import merge, sliding_window
@@ -47,6 +47,7 @@
     slowadd,
     slowidentity,
     slowinc,
+    wait_for_state,
 )
 from distributed.worker_state_machine import (
     DigestMetric,
@@ -672,9 +673,9 @@ def block(*args, event, **kwargs):
 
     counter = itertools.count()
 
-    futures = []
-    for w, ts in zip(workers, inp):
-        for t in sorted(ts, reverse=True):
+    futures_per_worker = defaultdict(list)
+    for w, tasks in zip(workers, inp):
+        for t in sorted(tasks, reverse=True):
             if t:
                 [dat] = await c.scatter(
                     [gen_nbytes(int(t * s.bandwidth))], workers=w.address
@@ -692,33 +693,42 @@
                     pure=False,
                     priority=-i,
                 )
-                futures.append(f)
-
-    while len([ts for ts in s.tasks.values() if ts.processing_on]) < len(futures):
-        await asyncio.sleep(0.001)
+                futures_per_worker[w].append(f)
+
+    # Make sure all tasks are scheduled on the workers
+    # We are relying on the futures not to be rootish (and thus not to remain in the
+    # scheduler-side queue) because they have worker restrictions
+    wait_for_states = []
+    for w, fs in futures_per_worker.items():
+        for i, f in enumerate(fs):
+            # Make sure the first task is executing, all others are ready
+            state = "executing" if i == 0 else "ready"
+            wait_for_states.append(wait_for_state(f.key, state, w))
+    await asyncio.gather(*wait_for_states)
+
+    # Balance several times since stealing might attempt to steal the already
+    # executing task for each saturated worker and will need a chance to correct
+    # its mistake
+    for _ in workers:
+        steal.balance()
+        # steal.stop() ensures that all in-flight stealing requests have been resolved
+        await steal.stop()
 
-    try:
-        for _ in range(10):
-            steal.balance()
-            await steal.stop()
+    await ev.set()
+    await c.gather([f for fs in futures_per_worker.values() for f in fs])
 
-            result = [
-                sorted(
-                    (int(key_split(ts.key)) for ts in s.workers[w.address].processing),
-                    reverse=True,
-                )
-                for w in workers
-            ]
+    result = [
+        sorted(
+            # The name of input data starts with ``SizeOf``
+            (int(key_split(t)) for t in w.data.keys() if not t.startswith("SizeOf")),
+            reverse=True,
+        )
+        for w in workers
+    ]
 
-            result2 = sorted(result, reverse=True)
-            expected2 = sorted(expected, reverse=True)
+    result2 = sorted(result, reverse=True)
+    expected2 = sorted(expected, reverse=True)
 
-            if result2 == expected2:
-                # Release the threadpools
-                return
-    finally:
-        await ev.set()
-
-    raise Exception(f"Expected: {expected2}; got: {result2}")
+    assert result2 == expected2
 
 
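Note on the hunk above: with one thread per worker (as these tests configure via `nthreads`), only one task per worker can be in `executing` state; every other pinned task must have reached the worker's `ready` heap before `steal.balance()` runs. A minimal sketch of that waiting pattern, assuming the `futures_per_worker` mapping built above and `wait_for_state` from `distributed.utils_test` (the helper name `wait_until_placed` is hypothetical):

```python
import asyncio

from distributed.utils_test import wait_for_state


async def wait_until_placed(futures_per_worker):
    # One waiter per future: the first (highest-priority) task on each
    # single-threaded worker must be executing, all others merely ready.
    waits = [
        wait_for_state(f.key, "executing" if i == 0 else "ready", w)
        for w, fs in futures_per_worker.items()
        for i, f in enumerate(fs)
    ]
    await asyncio.gather(*waits)
```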
@@ -1379,15 +1389,15 @@ def func(*args):
 @pytest.mark.parametrize(
     "cost, ntasks, expect_steal",
     [
-        pytest.param(10, 5, False, id="not enough work to steal"),
-        pytest.param(10, 10, True, id="enough work to steal"),
-        pytest.param(20, 10, False, id="not enough work for increased cost"),
+        pytest.param(10, 10, False, id="not enough work to steal"),
+        pytest.param(10, 12, True, id="enough work to steal"),
+        pytest.param(20, 12, False, id="not enough work for increased cost"),
     ],
 )
 def test_balance_expensive_tasks(cost, ntasks, expect_steal):
-    dependencies = {"a": cost, "b": cost}
-    dependency_placement = [["a"], ["b"]]
-    task_placement = [[["a", "b"]] * ntasks, []]
+    dependencies = {"a": cost}
+    dependency_placement = [["a"], []]
+    task_placement = [[["a"]] * ntasks, []]
 
     def _correct_placement(actual):
         actual_task_counts = [len(placed) for placed in actual]
@@ -1612,6 +1622,12 @@ async def _run(
         **kwargs,
     )
 
+    default_task_durations = {
+        compose_task_prefix(deps): "1s"
+        for tasks in task_placement
+        for deps in tasks
+    }
+
     gen_cluster(
         client=True,
         nthreads=[("", 1)] * len(task_placement),
         config=merge(
             NO_AMM,
             config or {},
             {
-                "distributed.scheduler.unknown-task-duration": "1s",
+                "distributed.scheduler.default-task-durations": default_task_durations,
             },
         ),
+        # Avoid heartbeats since comm costs are sensitive to bandwidth updates
+        worker_kwargs={"heartbeat_interval": "100s"},
     )(_run)()
 
 
+def compose_task_prefix(dependencies: list[str]) -> str:
+    dep_key = "".join(sorted(dependencies))
+    return f"task-{dep_key}"
+
+
 async def _dependency_balance_test_permutation(
     dependencies: Mapping[str, int],
     dependency_placement: list[list[str]],
@@ -1668,7 +1691,7 @@ async def _dependency_balance_test_permutation(
         dependencies, permutated_dependency_placement, c, s, workers
     )
 
-    ev, futures = await _place_tasks(
+    ev, futures_per_worker = await _place_tasks(
         permutated_task_placement,
         permutated_dependency_placement,
         dependency_futures,
@@ -1683,22 +1706,20 @@
     for ws in s.workers.values():
         s.check_idle_saturated(ws)
 
-    try:
-        for _ in range(20):
-            steal.balance()
-            await steal.stop()
+    # Balance several times since stealing might attempt to steal the already
+    # executing task for each saturated worker and will need a chance to correct
+    # its mistake
+    for _ in workers:
+        steal.balance()
+        # steal.stop() ensures that all in-flight stealing requests have been resolved
+        await steal.stop()
 
-            permutated_actual_placement = _get_task_placement(s, workers)
-            actual_placement = [permutated_actual_placement[i] for i in inverse]
+    await ev.set()
+    await c.gather([f for fs in futures_per_worker.values() for f in fs])
 
-            if correct_placement_fn(actual_placement):
-                return
-    finally:
-        # Release the threadpools
-        await ev.set()
-        await c.gather(futures)
+    permutated_actual_placement = _get_task_placement(s, workers)
+    actual_placement = [permutated_actual_placement[i] for i in inverse]
 
-    raise AssertionError(actual_placement, permutation)
+    assert correct_placement_fn(actual_placement), (actual_placement, permutation)
 
 
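For context on the `default-task-durations` switch above: the scheduler groups tasks by key prefix, and every key built by `_place_tasks` for the same dependency set shares the prefix returned by `compose_task_prefix`, so seeding `distributed.scheduler.default-task-durations` gives stealing a deterministic 1s estimate instead of a measured one. A small self-contained illustration (the placement values are hypothetical):

```python
def compose_task_prefix(dependencies: list[str]) -> str:
    dep_key = "".join(sorted(dependencies))
    return f"task-{dep_key}"


# Two task groups on worker 0, none on worker 1 (hypothetical placement)
task_placement = [[["a", "b"], ["a"]], []]
default_task_durations = {
    compose_task_prefix(deps): "1s" for tasks in task_placement for deps in tasks
}
assert default_task_durations == {"task-ab": "1s", "task-a": "1s"}
# Keys like "task-ab-0" and "task-ab-7" share the prefix "task-ab", so the
# scheduler starts out assuming a 1s runtime for all of them.
```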
@@ -1733,30 +1754,20 @@ async def _place_dependencies(
     futures = {}
     for name, multiplier in dependencies.items():
+        key = f"dep-{name}"
         worker_addresses = dependencies_to_workers[name]
         futs = await c.scatter(
-            {name: gen_nbytes(int(multiplier * s.bandwidth))},
+            {key: gen_nbytes(int(multiplier * s.bandwidth))},
             workers=worker_addresses,
             broadcast=True,
         )
-        futures[name] = futs[name]
+        futures[name] = futs[key]
 
     await c.gather(futures.values())
 
-    _assert_dependency_placement(placement, workers)
-
     return futures
 
 
-def _assert_dependency_placement(expected, workers):
-    """Assert that dependencies are placed on the workers as expected."""
-    actual = []
-    for worker in workers:
-        actual.append(list(worker.state.tasks.keys()))
-
-    assert actual == expected
-
-
 async def _place_tasks(
     placement: list[list[list[str]]],
     dependency_placement: list[list[str]],
@@ -1764,7 +1775,7 @@ async def _place_tasks(
     c: Client,
     s: Scheduler,
     workers: Sequence[Worker],
-) -> tuple[Event, list[Future]]:
+) -> tuple[Event, dict[Worker, list[Future]]]:
     """Places the tasks on the workers as specified.
 
     Parameters
@@ -1793,36 +1804,53 @@ def block(*args, event, **kwargs):
         event.wait()
 
     counter = itertools.count()
-    futures = []
-    for worker_idx, tasks in enumerate(placement):
+    futures_per_worker = defaultdict(list)
+    for worker, tasks, placed_dependencies in zip(
+        workers, placement, dependency_placement
+    ):
         for dependencies in tasks:
+            for dependency in dependencies:
+                assert dependency in placed_dependencies, (
+                    f"Dependency {dependency} of task {dependencies} not found "
+                    f"on worker {worker}. Make sure that workers already hold all "
+                    "dependencies of their tasks to avoid transfers and skewing "
+                    "bandwidth measurements"
+                )
             i = next(counter)
-            dep_key = "".join(sorted(dependencies))
-            key = f"{dep_key}-{i}"
+            key = f"{compose_task_prefix(dependencies)}-{i}"
             f = c.submit(
                 block,
                 [dependency_futures[dependency] for dependency in dependencies],
                 event=ev,
                 key=key,
-                workers=workers[worker_idx].address,
+                workers=worker.address,
                 allow_other_workers=True,
                 pure=False,
                 priority=-i,
             )
-            futures.append(f)
-
-    while len([ts for ts in s.tasks.values() if ts.processing_on]) < len(futures):
-        await asyncio.sleep(0.001)
-
-    while any(
-        len(w.state.tasks) < (len(tasks) + len(dependencies))
-        for w, dependencies, tasks in zip(workers, dependency_placement, placement)
-    ):
-        await asyncio.sleep(0.001)
-
-    assert_task_placement(placement, s, workers)
+            futures_per_worker[worker].append(f)
+
+    # Make sure all tasks are scheduled on the workers
+    # We are relying on the futures not to be rootish (and thus not to remain in the
+    # scheduler-side queue) because they have worker restrictions
+    waits_for_state: list[Coroutine] = []
+    for w, fs in futures_per_worker.items():
+        waits_for_executing_state = []
+        for f in fs:
+            # Every task should be either ready or executing
+            waits_for_state.append(wait_for_state(f.key, ["executing", "ready"], w))
+            waits_for_executing_state.append(
+                asyncio.create_task(wait_for_state(f.key, "executing", w))
+            )
+        # Ensure that each worker has started executing a task
+        waits_for_state.append(
+            asyncio.wait(waits_for_executing_state, return_when="FIRST_COMPLETED")
+        )
+    await asyncio.gather(
+        *waits_for_state,
+    )
 
-    return ev, futures
+    return ev, futures_per_worker
 
 
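A note on the key prefixes introduced above ("dep-" for scattered data, "task-" for submitted tasks): once the futures have been gathered, `worker.data` holds both the dependencies and the task results, and `_get_task_placement` (next hunk) recovers the placement purely by filtering on the prefix. Roughly like this, with hypothetical keys; the real helper uses `key_split` where this sketch inlines an `rsplit`:

```python
# worker.data keys after the test ran (hypothetical)
data_keys = ["dep-a", "dep-b", "task-ab-0", "task-a-3"]

placement = sorted(
    list(key[5:].rsplit("-", 1)[0])  # "task-ab-0" -> "ab" -> ["a", "b"]
    for key in data_keys
    if key.startswith("task-")
)
assert placement == [["a"], ["a", "b"]]
```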
@@ -1832,27 +1860,20 @@ def _get_task_placement(
     actual = []
     for w in workers:
         actual.append(
-            [list(key_split(ts.key)) for ts in s.workers[w.address].processing]
+            [
+                list(key_split(key[5:]))  # Remove "task-" prefix
+                for key in w.data.keys()
+                if key.startswith("task-")
+            ]
         )
     return _deterministic_placement(actual)
 
 
-def _equal_placement(left, right):
-    """Return True IFF the two input placements are equal."""
-    return _deterministic_placement(left) == _deterministic_placement(right)
-
-
 def _deterministic_placement(placement):
     """Return a deterministic ordering of the tasks or dependencies on each worker."""
     return [sorted(placed) for placed in placement]
 
 
-def assert_task_placement(expected, s, workers):
-    """Assert that tasks are placed on the workers as expected."""
-    actual = _get_task_placement(s, workers)
-    assert _equal_placement(actual, expected)
-
-
 # Reproducer from https://github.com/dask/distributed/issues/6573
 @gen_cluster(
     client=True,
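Taken together, the rewritten helpers drive work stealing deterministically instead of polling until the placement happens to converge. A condensed sketch of that pattern (the helper name `drive_stealing` is hypothetical; `s` is the `Scheduler` and `workers` the test's workers, as in `gen_cluster` tests):

```python
async def drive_stealing(s, workers):
    steal = s.extensions["stealing"]
    # Recompute the idle/saturated classification before balancing
    for ws in s.workers.values():
        s.check_idle_saturated(ws)
    # One balance round per worker: a round may first try to steal the
    # already-executing task and needs a follow-up round to correct that.
    for _ in workers:
        steal.balance()
        # Wait until all in-flight steal requests are confirmed or rejected
        await steal.stop()
```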