8 changes: 5 additions & 3 deletions distributed/stealing.py
@@ -439,15 +439,14 @@ def balance(self) -> None:
if (
ts not in self.key_stealable
or ts.processing_on is not victim
or ts not in victim.processing
):
# FIXME: Instead of discarding here, clean up stealable properly
stealable.discard(ts)
continue
i += 1
if not (thief := _get_thief(s, ts, potential_thieves)):
continue
if ts not in victim.processing:
stealable.discard(ts)
continue

occ_thief = self._combined_occupancy(thief)
occ_victim = self._combined_occupancy(victim)
@@ -483,6 +482,9 @@ def balance(self) -> None:
thief, occ_thief, nproc_thief
):
potential_thieves.discard(thief)
# FIXME: move_task_request already implements some logic
# for removing ts from stealable. If we made sure to
# properly clean up, we would not need this
stealable.discard(ts)
self.scheduler.check_idle_saturated(
victim, occ=self._combined_occupancy(victim)
205 changes: 113 additions & 92 deletions distributed/tests/test_steal.py
@@ -10,7 +10,7 @@
from collections import defaultdict
from operator import mul
from time import sleep
from typing import Callable, Iterable, Mapping, Sequence
from typing import Callable, Coroutine, Iterable, Mapping, Sequence

import pytest
from tlz import merge, sliding_window
@@ -47,6 +47,7 @@
slowadd,
slowidentity,
slowinc,
wait_for_state,
)
from distributed.worker_state_machine import (
DigestMetric,
@@ -672,9 +673,9 @@ def block(*args, event, **kwargs):

counter = itertools.count()

futures = []
for w, ts in zip(workers, inp):
for t in sorted(ts, reverse=True):
futures_per_worker = defaultdict(list)
for w, tasks in zip(workers, inp):
for t in sorted(tasks, reverse=True):
if t:
[dat] = await c.scatter(
[gen_nbytes(int(t * s.bandwidth))], workers=w.address
@@ -692,33 +693,42 @@ def block(*args, event, **kwargs):
pure=False,
priority=-i,
)
futures.append(f)

while len([ts for ts in s.tasks.values() if ts.processing_on]) < len(futures):
await asyncio.sleep(0.001)
futures_per_worker[w].append(f)

# Make sure all tasks are scheduled on the workers
# We are relying on the futures not to be rootish (and thus not to remain in the
# scheduler-side queue) because they have worker restrictions
wait_for_states = []
for w, fs in futures_per_worker.items():
for i, f in enumerate(fs):
# Make sure the first task is executing, all others are ready
state = "executing" if i == 0 else "ready"
wait_for_states.append(wait_for_state(f.key, state, w))
await asyncio.gather(*wait_for_states)

# Balance several times since stealing might attempt to steal the already executing task
# for each saturated worker and will need a chance to correct its mistake
for _ in workers:
steal.balance()
# steal.stop() ensures that all in-flight stealing requests have been resolved
await steal.stop()

try:
for _ in range(10):
steal.balance()
await steal.stop()
await ev.set()
await c.gather([f for fs in futures_per_worker.values() for f in fs])

result = [
sorted(
(int(key_split(ts.key)) for ts in s.workers[w.address].processing),
reverse=True,
)
for w in workers
]
result = [
sorted(
# The name of input data starts with ``SizeOf``
(int(key_split(t)) for t in w.data.keys() if not t.startswith("SizeOf")),
reverse=True,
)
for w in workers
]

result2 = sorted(result, reverse=True)
expected2 = sorted(expected, reverse=True)
result2 = sorted(result, reverse=True)
expected2 = sorted(expected, reverse=True)

if result2 == expected2:
# Release the threadpools
return
finally:
await ev.set()
raise Exception(f"Expected: {expected2}; got: {result2}")
assert result2 == expected2


@pytest.mark.parametrize(
@@ -1379,15 +1389,15 @@ def func(*args):
@pytest.mark.parametrize(
"cost, ntasks, expect_steal",
[
pytest.param(10, 5, False, id="not enough work to steal"),
pytest.param(10, 10, True, id="enough work to steal"),
pytest.param(20, 10, False, id="not enough work for increased cost"),
pytest.param(10, 10, False, id="not enough work to steal"),
pytest.param(10, 12, True, id="enough work to steal"),
pytest.param(20, 12, False, id="not enough work for increased cost"),
],
)
def test_balance_expensive_tasks(cost, ntasks, expect_steal):
dependencies = {"a": cost, "b": cost}
dependency_placement = [["a"], ["b"]]
task_placement = [[["a", "b"]] * ntasks, []]
dependencies = {"a": cost}
dependency_placement = [["a"], []]
task_placement = [[["a"]] * ntasks, []]

def _correct_placement(actual):
actual_task_counts = [len(placed) for placed in actual]
@@ -1612,19 +1622,32 @@ async def _run(
**kwargs,
)

default_task_durations = {
compose_task_prefix(deps): "1s"
for tasks in task_placement
for deps in tasks
}

gen_cluster(
client=True,
nthreads=[("", 1)] * len(task_placement),
config=merge(
NO_AMM,
config or {},
{
"distributed.scheduler.unknown-task-duration": "1s",
"distributed.scheduler.default-task-durations": default_task_durations,
Collaborator: Why is this change needed?

Member Author: This was changed to avoid relying on stealing tasks of unknown duration. Instead, we now provide default durations for all relevant tasks. This concern was initially voiced here:

"The test implicitly relies on tasks of unknown duration to be stolen (#5572). It should be changed not to rely on this specific use case."

#7243 (comment)
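For illustration only, not part of the PR: a minimal sketch of how the per-prefix defaults line up with the task keys the test generates. The sample `task_placement` value is made up, and the `dask.utils.key_split` import path is an assumption.

```python
from dask.utils import key_split  # assumed import path for dask's key-prefix helper


def compose_task_prefix(dependencies):
    # Same idea as the helper added in this diff: tasks depending on
    # {"a", "b"} share the key prefix "task-ab".
    return "task-" + "".join(sorted(dependencies))


# Hypothetical placement: three tasks depending on "a" and "b" on worker 0, none on worker 1.
task_placement = [[["a", "b"]] * 3, []]

default_task_durations = {
    compose_task_prefix(deps): "1s"
    for tasks in task_placement
    for deps in tasks
}
assert default_task_durations == {"task-ab": "1s"}

# A concrete key produced by _place_tasks looks like "task-ab-0"; key_split maps it
# back onto the configured prefix, so the stealing heuristics see a known 1s duration
# instead of falling back to unknown-task-duration.
assert key_split("task-ab-0") == "task-ab"
```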

},
),
# Avoid heartbeats since comm costs are sensitive to bandwidth updates
worker_kwargs={"heartbeat_interval": "100s"},
)(_run)()


def compose_task_prefix(dependencies: list[str]) -> str:
dep_key = "".join(sorted(dependencies))
return f"task-{dep_key}"


async def _dependency_balance_test_permutation(
dependencies: Mapping[str, int],
dependency_placement: list[list[str]],
@@ -1668,7 +1691,7 @@ async def _dependency_balance_test_permutation(
dependencies, permutated_dependency_placement, c, s, workers
)

ev, futures = await _place_tasks(
ev, futures_per_worker = await _place_tasks(
permutated_task_placement,
permutated_dependency_placement,
dependency_futures,
@@ -1683,22 +1706,20 @@
for ws in s.workers.values():
s.check_idle_saturated(ws)

try:
for _ in range(20):
steal.balance()
await steal.stop()
# Balance several times since stealing might attempt to steal the already executing task
# for each saturated worker and will need a chance to correct its mistake
for _ in workers:
steal.balance()
# steal.stop() ensures that all in-flight stealing requests have been resolved
await steal.stop()

permutated_actual_placement = _get_task_placement(s, workers)
actual_placement = [permutated_actual_placement[i] for i in inverse]
await ev.set()
await c.gather([f for fs in futures_per_worker.values() for f in fs])

if correct_placement_fn(actual_placement):
return
finally:
# Release the threadpools
await ev.set()
await c.gather(futures)
permutated_actual_placement = _get_task_placement(s, workers)
actual_placement = [permutated_actual_placement[i] for i in inverse]

raise AssertionError(actual_placement, permutation)
assert correct_placement_fn(actual_placement), (actual_placement, permutation)


async def _place_dependencies(
@@ -1733,38 +1754,28 @@ async def _place_dependencies(

futures = {}
for name, multiplier in dependencies.items():
key = f"dep-{name}"
worker_addresses = dependencies_to_workers[name]
futs = await c.scatter(
{name: gen_nbytes(int(multiplier * s.bandwidth))},
{key: gen_nbytes(int(multiplier * s.bandwidth))},
workers=worker_addresses,
broadcast=True,
)
futures[name] = futs[name]
futures[name] = futs[key]

await c.gather(futures.values())

_assert_dependency_placement(placement, workers)

return futures


def _assert_dependency_placement(expected, workers):
"""Assert that dependencies are placed on the workers as expected."""
actual = []
for worker in workers:
actual.append(list(worker.state.tasks.keys()))

assert actual == expected


async def _place_tasks(
placement: list[list[list[str]]],
dependency_placement: list[list[str]],
dependency_futures: Mapping[str, Future],
c: Client,
s: Scheduler,
workers: Sequence[Worker],
) -> tuple[Event, list[Future]]:
) -> tuple[Event, dict[Worker, list[Future]]]:
"""Places the tasks on the workers as specified.

Parameters
@@ -1793,36 +1804,53 @@ def block(*args, event, **kwargs):
event.wait()

counter = itertools.count()
futures = []
for worker_idx, tasks in enumerate(placement):
futures_per_worker = defaultdict(list)
for worker, tasks, placed_dependencies in zip(
workers, placement, dependency_placement
):
for dependencies in tasks:
for dependency in dependencies:
assert dependency in placed_dependencies, (
f"Dependency {dependency} of task {dependencies} not found "
"on worker {worker}. Make sure that workers already hold all "
"dependencies of their tasks to avoid transfers and skewing "
"bandwidth measurements"
)
i = next(counter)
dep_key = "".join(sorted(dependencies))
key = f"{dep_key}-{i}"
key = f"{compose_task_prefix(dependencies)}-{i}"
f = c.submit(
block,
[dependency_futures[dependency] for dependency in dependencies],
event=ev,
key=key,
workers=workers[worker_idx].address,
workers=worker.address,
allow_other_workers=True,
pure=False,
priority=-i,
)
futures.append(f)

while len([ts for ts in s.tasks.values() if ts.processing_on]) < len(futures):
await asyncio.sleep(0.001)

while any(
len(w.state.tasks) < (len(tasks) + len(dependencies))
for w, dependencies, tasks in zip(workers, dependency_placement, placement)
):
await asyncio.sleep(0.001)

assert_task_placement(placement, s, workers)
futures_per_worker[worker].append(f)

# Make sure all tasks are scheduled on the workers
# We are relying on the futures not to be rootish (and thus not to remain in the
# scheduler-side queue) because they have worker restrictions
waits_for_state: list[Coroutine] = []
for w, fs in futures_per_worker.items():
waits_for_executing_state = []
for f in fs:
# Every task should be either ready or executing
waits_for_state.append(wait_for_state(f.key, ["executing", "ready"], w))
waits_for_executing_state.append(
asyncio.create_task(wait_for_state(f.key, "executing", w))
)
# Ensure that each worker has started executing a task
waits_for_state.append(
asyncio.wait(waits_for_executing_state, return_when="FIRST_COMPLETED")
)
await asyncio.gather(
*waits_for_state,
)

return ev, futures
return ev, futures_per_worker


def _get_task_placement(
Expand All @@ -1832,27 +1860,20 @@ def _get_task_placement(
actual = []
for w in workers:
actual.append(
[list(key_split(ts.key)) for ts in s.workers[w.address].processing]
[
list(key_split(key[5:])) # Remove "task-" prefix
for key in w.data.keys()
if key.startswith("task-")
]
Collaborator: Could you explain the reasoning behind this change?

Member Author: The tests now check where the tasks were eventually executed after stealing took place. This is the behavior we want to test here, and it relies less on internals than the previous version, which checked the scheduler's worker state after stealing but while tasks were still processing.
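For illustration only, not part of the PR: a small sketch of recovering the final placement from the workers' locally stored results rather than from the scheduler's processing state, following the key layout used in this diff ("task-<deps>-<i>" for tasks, "dep-<name>" for scattered dependencies). The example keys are made up, and the `dask.utils.key_split` import path is an assumption.

```python
from dask.utils import key_split  # assumed import path for dask's key-prefix helper


def placement_from_worker_data(worker_data_keys):
    """Given one list of data keys per worker (e.g. from w.data.keys()),
    return the dependency sets of the tasks that finished on each worker."""
    placement = []
    for keys in worker_data_keys:
        placement.append(
            sorted(
                list(key_split(key[len("task-"):]))  # "task-ab-3" -> "ab" -> ["a", "b"]
                for key in keys
                if key.startswith("task-")  # ignore scattered "dep-..." payloads
            )
        )
    return placement


# Hypothetical outcome: worker 0 kept two "ab" tasks, worker 1 stole one.
keys_per_worker = [
    ["dep-a", "task-ab-0", "task-ab-1"],
    ["dep-b", "task-ab-2"],
]
assert placement_from_worker_data(keys_per_worker) == [
    [["a", "b"], ["a", "b"]],
    [["a", "b"]],
]
```

Reading `w.data` after `c.gather` observes the stable end state, whereas `s.workers[...].processing` can still change while stealing requests are in flight.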

)
return _deterministic_placement(actual)


def _equal_placement(left, right):
"""Return True IFF the two input placements are equal."""
return _deterministic_placement(left) == _deterministic_placement(right)


def _deterministic_placement(placement):
"""Return a deterministic ordering of the tasks or dependencies on each worker."""
return [sorted(placed) for placed in placement]


def assert_task_placement(expected, s, workers):
"""Assert that tasks are placed on the workers as expected."""
actual = _get_task_placement(s, workers)
assert _equal_placement(actual, expected)


# Reproducer from https://github.com/dask/distributed/issues/6573
@gen_cluster(
client=True,