Fix test_balance_expensive_tasks and improve helper functions in test_steal.py
#7253
Changes from all commits
@@ -10,7 +10,7 @@
 from collections import defaultdict
 from operator import mul
 from time import sleep
-from typing import Callable, Iterable, Mapping, Sequence
+from typing import Callable, Coroutine, Iterable, Mapping, Sequence

 import pytest
 from tlz import merge, sliding_window
@@ -47,6 +47,7 @@
     slowadd,
     slowidentity,
     slowinc,
+    wait_for_state,
 )
 from distributed.worker_state_machine import (
     DigestMetric,
@@ -672,9 +673,9 @@ def block(*args, event, **kwargs):
     counter = itertools.count()

-    futures = []
-    for w, ts in zip(workers, inp):
-        for t in sorted(ts, reverse=True):
+    futures_per_worker = defaultdict(list)
+    for w, tasks in zip(workers, inp):
+        for t in sorted(tasks, reverse=True):
             if t:
                 [dat] = await c.scatter(
                     [gen_nbytes(int(t * s.bandwidth))], workers=w.address
@@ -692,33 +693,42 @@ def block(*args, event, **kwargs):
                 pure=False,
                 priority=-i,
             )
-            futures.append(f)
-
-    while len([ts for ts in s.tasks.values() if ts.processing_on]) < len(futures):
-        await asyncio.sleep(0.001)
-
-    try:
-        for _ in range(10):
-            steal.balance()
-            await steal.stop()
-
-            result = [
-                sorted(
-                    (int(key_split(ts.key)) for ts in s.workers[w.address].processing),
-                    reverse=True,
-                )
-                for w in workers
-            ]
-
-            result2 = sorted(result, reverse=True)
-            expected2 = sorted(expected, reverse=True)
-
-            if result2 == expected2:
-                # Release the threadpools
-                return
-    finally:
-        await ev.set()
-
-    raise Exception(f"Expected: {expected2}; got: {result2}")
+            futures_per_worker[w].append(f)
+
+    # Make sure all tasks are scheduled on the workers
+    # We are relying on the futures not to be rootish (and thus not to remain in the
+    # scheduler-side queue) because they have worker restrictions
+    wait_for_states = []
+    for w, fs in futures_per_worker.items():
+        for i, f in enumerate(fs):
+            # Make sure the first task is executing, all others are ready
+            state = "executing" if i == 0 else "ready"
+            wait_for_states.append(wait_for_state(f.key, state, w))
+    await asyncio.gather(*wait_for_states)
+
+    # Balance several times since stealing might attempt to steal the already executing task
+    # for each saturated worker and will need a chance to correct its mistake
+    for _ in workers:
+        steal.balance()
+        # steal.stop() ensures that all in-flight stealing requests have been resolved
+        await steal.stop()
+
+    await ev.set()
+    await c.gather([f for fs in futures_per_worker.values() for f in fs])
+
+    result = [
+        sorted(
+            # The name of input data starts with ``SizeOf``
+            (int(key_split(t)) for t in w.data.keys() if not t.startswith("SizeOf")),
+            reverse=True,
+        )
+        for w in workers
+    ]
+
+    result2 = sorted(result, reverse=True)
+    expected2 = sorted(expected, reverse=True)
+
+    assert result2 == expected2


 @pytest.mark.parametrize(
@@ -1379,15 +1389,15 @@ def func(*args):
 @pytest.mark.parametrize(
     "cost, ntasks, expect_steal",
     [
-        pytest.param(10, 5, False, id="not enough work to steal"),
-        pytest.param(10, 10, True, id="enough work to steal"),
-        pytest.param(20, 10, False, id="not enough work for increased cost"),
+        pytest.param(10, 10, False, id="not enough work to steal"),
+        pytest.param(10, 12, True, id="enough work to steal"),
+        pytest.param(20, 12, False, id="not enough work for increased cost"),
     ],
 )
 def test_balance_expensive_tasks(cost, ntasks, expect_steal):
-    dependencies = {"a": cost, "b": cost}
-    dependency_placement = [["a"], ["b"]]
-    task_placement = [[["a", "b"]] * ntasks, []]
+    dependencies = {"a": cost}
+    dependency_placement = [["a"], []]
+    task_placement = [[["a"]] * ntasks, []]

     def _correct_placement(actual):
         actual_task_counts = [len(placed) for placed in actual]
@@ -1612,19 +1622,32 @@ async def _run(
             **kwargs,
         )

+    default_task_durations = {
+        compose_task_prefix(deps): "1s"
+        for tasks in task_placement
+        for deps in tasks
+    }
+
     gen_cluster(
         client=True,
         nthreads=[("", 1)] * len(task_placement),
         config=merge(
             NO_AMM,
             config or {},
+            {
+                "distributed.scheduler.unknown-task-duration": "1s",
+                "distributed.scheduler.default-task-durations": default_task_durations,
+            },
         ),
         # Avoid heartbeats since comm costs are sensitive to bandwidth updates
         worker_kwargs={"heartbeat_interval": "100s"},
     )(_run)()


+def compose_task_prefix(dependencies: list[str]) -> str:
+    dep_key = "".join(sorted(dependencies))
+    return f"task-{dep_key}"
+
+
 async def _dependency_balance_test_permutation(
     dependencies: Mapping[str, int],
     dependency_placement: list[list[str]],
@@ -1668,7 +1691,7 @@ async def _dependency_balance_test_permutation(
         dependencies, permutated_dependency_placement, c, s, workers
     )

-    ev, futures = await _place_tasks(
+    ev, futures_per_worker = await _place_tasks(
         permutated_task_placement,
         permutated_dependency_placement,
         dependency_futures,
@@ -1683,22 +1706,20 @@ async def _dependency_balance_test_permutation(
     for ws in s.workers.values():
         s.check_idle_saturated(ws)

-    try:
-        for _ in range(20):
-            steal.balance()
-            await steal.stop()
-
-            permutated_actual_placement = _get_task_placement(s, workers)
-            actual_placement = [permutated_actual_placement[i] for i in inverse]
-
-            if correct_placement_fn(actual_placement):
-                return
-    finally:
-        # Release the threadpools
-        await ev.set()
-        await c.gather(futures)
-
-    raise AssertionError(actual_placement, permutation)
+    # Balance several times since stealing might attempt to steal the already executing task
+    # for each saturated worker and will need a chance to correct its mistake
+    for _ in workers:
+        steal.balance()
+        # steal.stop() ensures that all in-flight stealing requests have been resolved
+        await steal.stop()
+
+    await ev.set()
+    await c.gather([f for fs in futures_per_worker.values() for f in fs])
+
+    permutated_actual_placement = _get_task_placement(s, workers)
+    actual_placement = [permutated_actual_placement[i] for i in inverse]
+
+    assert correct_placement_fn(actual_placement), (actual_placement, permutation)


 async def _place_dependencies(
@@ -1733,38 +1754,28 @@ async def _place_dependencies(

     futures = {}
     for name, multiplier in dependencies.items():
+        key = f"dep-{name}"
         worker_addresses = dependencies_to_workers[name]
         futs = await c.scatter(
-            {name: gen_nbytes(int(multiplier * s.bandwidth))},
+            {key: gen_nbytes(int(multiplier * s.bandwidth))},
             workers=worker_addresses,
             broadcast=True,
         )
-        futures[name] = futs[name]
+        futures[name] = futs[key]

     await c.gather(futures.values())

-    _assert_dependency_placement(placement, workers)
-
     return futures


-def _assert_dependency_placement(expected, workers):
-    """Assert that dependencies are placed on the workers as expected."""
-    actual = []
-    for worker in workers:
-        actual.append(list(worker.state.tasks.keys()))
-
-    assert actual == expected
-
-
 async def _place_tasks(
     placement: list[list[list[str]]],
     dependency_placement: list[list[str]],
     dependency_futures: Mapping[str, Future],
     c: Client,
     s: Scheduler,
     workers: Sequence[Worker],
-) -> tuple[Event, list[Future]]:
+) -> tuple[Event, dict[Worker, list[Future]]]:
     """Places the tasks on the workers as specified.

     Parameters
@@ -1793,36 +1804,53 @@ def block(*args, event, **kwargs):
         event.wait()

     counter = itertools.count()
-    futures = []
-    for worker_idx, tasks in enumerate(placement):
+    futures_per_worker = defaultdict(list)
+    for worker, tasks, placed_dependencies in zip(
+        workers, placement, dependency_placement
+    ):
         for dependencies in tasks:
+            for dependency in dependencies:
+                assert dependency in placed_dependencies, (
+                    f"Dependency {dependency} of task {dependencies} not found "
+                    f"on worker {worker}. Make sure that workers already hold all "
+                    "dependencies of their tasks to avoid transfers and skewing "
+                    "bandwidth measurements"
+                )
             i = next(counter)
-            dep_key = "".join(sorted(dependencies))
-            key = f"{dep_key}-{i}"
+            key = f"{compose_task_prefix(dependencies)}-{i}"
             f = c.submit(
                 block,
                 [dependency_futures[dependency] for dependency in dependencies],
                 event=ev,
                 key=key,
-                workers=workers[worker_idx].address,
+                workers=worker.address,
                 allow_other_workers=True,
                 pure=False,
                 priority=-i,
             )
-            futures.append(f)
-
-    while len([ts for ts in s.tasks.values() if ts.processing_on]) < len(futures):
-        await asyncio.sleep(0.001)
-
-    while any(
-        len(w.state.tasks) < (len(tasks) + len(dependencies))
-        for w, dependencies, tasks in zip(workers, dependency_placement, placement)
-    ):
-        await asyncio.sleep(0.001)
-
-    assert_task_placement(placement, s, workers)
+            futures_per_worker[worker].append(f)
+
+    # Make sure all tasks are scheduled on the workers
+    # We are relying on the futures not to be rootish (and thus not to remain in the
+    # scheduler-side queue) because they have worker restrictions
+    waits_for_state: list[Coroutine] = []
+    for w, fs in futures_per_worker.items():
+        waits_for_executing_state = []
+        for f in fs:
+            # Every task should be either ready or executing
+            waits_for_state.append(wait_for_state(f.key, ["executing", "ready"], w))
+            waits_for_executing_state.append(
+                asyncio.create_task(wait_for_state(f.key, "executing", w))
+            )
+        # Ensure that each worker has started executing a task
+        waits_for_state.append(
+            asyncio.wait(waits_for_executing_state, return_when="FIRST_COMPLETED")
+        )
+    await asyncio.gather(
+        *waits_for_state,
+    )

-    return ev, futures
+    return ev, futures_per_worker


 def _get_task_placement(
@@ -1832,27 +1860,20 @@ def _get_task_placement(
     actual = []
     for w in workers:
         actual.append(
-            [list(key_split(ts.key)) for ts in s.workers[w.address].processing]
+            [
+                list(key_split(key[5:]))  # Remove "task-" prefix
+                for key in w.data.keys()
+                if key.startswith("task-")
+            ]
Collaborator: Could you explain the reasoning behind this change?

Member (author): The tests now check where the tasks have eventually been executed after stealing took place. This is the behavior we want to test here, and it relies less on internals than the previous version, which relied on checking the scheduler's worker state after stealing but while the tasks were still processing. (A sketch of this approach follows the hunk below.)
         )
     return _deterministic_placement(actual)


-def _equal_placement(left, right):
-    """Return True IFF the two input placements are equal."""
-    return _deterministic_placement(left) == _deterministic_placement(right)
-
-
 def _deterministic_placement(placement):
     """Return a deterministic ordering of the tasks or dependencies on each worker."""
     return [sorted(placed) for placed in placement]


-def assert_task_placement(expected, s, workers):
-    """Assert that tasks are placed on the workers as expected."""
-    actual = _get_task_placement(s, workers)
-    assert _equal_placement(actual, expected)
-
-
 # Reproducer from https://github.com/dask/distributed/issues/6573
 @gen_cluster(
     client=True,
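To illustrate the reasoning from the review thread above, here is a minimal, self-contained sketch. The helper names `_prefix` and `placement_from_worker_data` and the sample keys are hypothetical; only the `task-`/`dep-` key prefixes come from the diff. The idea is to read the final placement off each worker's `data` mapping once the futures have completed, rather than off the scheduler's `processing` sets.

```python
def _prefix(key: str) -> str:
    # Simplified stand-in for dask's key_split: "a-0" -> "a"
    return key.split("-")[0]


def placement_from_worker_data(worker_data_keys: list[list[str]]) -> list[list[str]]:
    """Derive which tasks ended up on each worker from the keys stored in each
    worker's data mapping once execution has finished."""
    return [
        sorted(
            _prefix(key[len("task-"):])  # drop the "task-" key prefix
            for key in keys
            if key.startswith("task-")  # ignore scattered dependency data
        )
        for keys in worker_data_keys
    ]


# Hypothetical contents of two workers' data after stealing and execution finished
print(placement_from_worker_data([["task-a-0", "task-a-1", "dep-a"], ["task-a-2"]]))
# [['a', 'a'], ['a']]
```

Asserting on completed results this way avoids racing against the scheduler's in-flight bookkeeping, per the reply above.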
Reviewer: Why is this change needed?

Author: This was changed to avoid relying on stealing tasks of unknown duration. Instead, we now provide default durations for all relevant tasks (see the sketch below). This concern was initially raised in #7243 (comment).
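As a rough sketch of that idea (the `build_config` helper below is hypothetical; `compose_task_prefix`, `merge`, and the two scheduler config keys are the ones used in the diff above): every task prefix that will appear in the test is given an explicit default duration, so stealing decisions no longer depend on the scheduler's heuristics for tasks of unknown duration.

```python
from __future__ import annotations

from tlz import merge


def compose_task_prefix(dependencies: list[str]) -> str:
    # Mirrors the helper added in the diff: task keys look like "task-<sorted deps>-<i>"
    dep_key = "".join(sorted(dependencies))
    return f"task-{dep_key}"


def build_config(task_placement: list[list[list[str]]], extra: dict | None = None) -> dict:
    """Assemble a scheduler config that assigns a known duration to every task prefix."""
    default_task_durations = {
        compose_task_prefix(deps): "1s"
        for tasks in task_placement  # one entry per worker
        for deps in tasks  # one entry per task, given as its list of dependencies
    }
    return merge(
        extra or {},
        {
            "distributed.scheduler.unknown-task-duration": "1s",
            "distributed.scheduler.default-task-durations": default_task_durations,
        },
    )


# Example: the first worker holds ten "a"-dependent tasks, the second holds none
print(build_config([[["a"]] * 10, []]))
# {'distributed.scheduler.unknown-task-duration': '1s',
#  'distributed.scheduler.default-task-durations': {'task-a': '1s'}}
```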