Merged
48 changes: 32 additions & 16 deletions distributed/scheduler.py
@@ -2695,7 +2695,13 @@ def transition_processing_memory(
ws,
key,
)
return recommendations, client_msgs, worker_msgs
worker_msgs[ts._processing_on.address] = [
{
"op": "cancel-compute",
"key": key,
"reason": "Finished on different worker",
}
]

has_compute_startstop: bool = False
compute_start: double
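A minimal sketch (not part of the diff) of the round trip introduced here: the scheduler queues a "cancel-compute" message for the worker that is still processing the key, and the worker-side handler added later in this PR consumes it.

msg = {
    "op": "cancel-compute",
    "key": "reducer",  # hypothetical task key
    "reason": "Finished on different worker",
}
# On the worker, stream_handlers["cancel-compute"] dispatches this message to
# Worker.cancel_compute(key, reason), which releases the key only while it is
# still in state "waiting" or "ready".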
@@ -4229,19 +4235,25 @@ async def add_worker(
client_msgs: dict = {}
worker_msgs: dict = {}
if nbytes:
assert isinstance(nbytes, dict)
for key in nbytes:
ts: TaskState = parent._tasks.get(key)
if ts is not None and ts._state in ("processing", "waiting"):
t: tuple = parent._transition(
key,
"memory",
worker=address,
nbytes=nbytes[key],
typename=types[key],
)
recommendations, client_msgs, worker_msgs = t
parent._transitions(recommendations, client_msgs, worker_msgs)
recommendations = {}
if ts is not None:
if ts.state == "memory":
self.add_keys(worker=address, keys=[key])
else:
t: tuple = parent._transition(
key,
"memory",
worker=address,
nbytes=nbytes[key],
typename=types[key],
)
recommendations, client_msgs, worker_msgs = t
parent._transitions(
recommendations, client_msgs, worker_msgs
)
recommendations = {}

for ts in list(parent._unrunnable):
valid: set = self.valid_workers(ts)
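For context, a hypothetical registration payload (addresses and sizes are made up) that a reconnecting worker sends and that the `for key in nbytes:` loop above consumes; after the worker.py change further down, `nbytes` only covers keys that are already in memory on that worker.

register_msg = {
    "address": "tcp://127.0.0.1:40331",  # assumed example address
    "keys": ["x", "reducer"],
    "nbytes": {"x": 56, "reducer": 112},  # in-memory tasks only
    "types": {"x": "builtins.int", "reducer": "builtins.int"},
}
# For each key the scheduler already tracks: a task that is in "memory" just
# gets the extra replica recorded via add_keys; anything else is transitioned
# to "memory" on the reconnected worker.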
@@ -4646,10 +4658,15 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs):
ts: TaskState = parent._tasks.get(key)
if ts is None:
return recommendations, client_msgs, worker_msgs

if ts.state == "memory":
self.add_keys(worker=worker, keys=[key])
return recommendations, client_msgs, worker_msgs

ws: WorkerState = parent._workers_dv[worker]
ts._metadata.update(kwargs["metadata"])

if ts._state == "processing":
if ts._state != "released":
r: tuple = parent._transition(key, "memory", worker=worker, **kwargs)
recommendations, client_msgs, worker_msgs = r
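# Annotation (not code from the diff) paraphrasing the branches above:
#   ts is None              -> ignore; the task has already been forgotten
#   ts.state == "memory"    -> add_keys(worker, [key]); record the extra replica
#   ts.state == "released"  -> no transition; drop the stale report
#   any other state         -> transition the key to "memory" on that worker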

@@ -5567,12 +5584,11 @@ async def gather(self, comm=None, keys=None, serializers=None):
# Remove suspicious workers from the scheduler but allow them to
# reconnect.
await asyncio.gather(
*[
*(
self.remove_worker(address=worker, close=False)
for worker in missing_workers
]
)
)

recommendations: dict
client_msgs: dict = {}
worker_msgs: dict = {}
111 changes: 72 additions & 39 deletions distributed/tests/test_scheduler.py
@@ -43,7 +43,7 @@
tls_only_security,
varying,
)
from distributed.worker import dumps_function, dumps_task
from distributed.worker import dumps_function, dumps_task, get_worker

if sys.version_info < (3, 8):
try:
@@ -2180,38 +2180,67 @@ async def test_gather_no_workers(c, s, a, b):
assert list(res["keys"]) == ["x"]


@pytest.mark.slow
@pytest.mark.parametrize("reschedule_different_worker", [True, False])
@pytest.mark.parametrize("swap_data_insert_order", [True, False])
@gen_cluster(client=True, client_kwargs={"direct_to_workers": False})
async def test_gather_allow_worker_reconnect(c, s, a, b):
async def test_gather_allow_worker_reconnect(
c, s, a, b, reschedule_different_worker, swap_data_insert_order
):
"""
Test that client resubmissions allow failed workers to reconnect and reuse
their results. The failure scenario is a connection issue during result
gathering.
Upon connection failure, the worker is flagged as suspicious and removed
from the scheduler. If the worker is healthy and reconnects, we want to use
its results instead of recomputing them.

See also distributed.tests.test_worker.py::test_worker_reconnects_mid_compute
"""
# GH3246
already_calculated = []

import time

def inc_slow(x):
# Once the graph below is rescheduled, this computation runs again. We
# need to sleep for at least 0.5 seconds to give the worker a chance to
# reconnect (heartbeat timing). In slow CI situations, the actual
# reconnect might take a bit longer, so wait even more.
if x in already_calculated:
time.sleep(2)
already_calculated.append(x)
if reschedule_different_worker:
from distributed.diagnostics.plugin import SchedulerPlugin

class SwitchRestrictions(SchedulerPlugin):
def __init__(self, scheduler):
self.scheduler = scheduler

def transition(self, key, start, finish, **kwargs):
if key in ("reducer", "final") and finish == "memory":
self.scheduler.tasks[key]._worker_restrictions = {b.address}

plugin = SwitchRestrictions(s)
s.add_plugin(plugin)

from distributed import Lock

b_address = b.address

def inc_slow(x, lock):
w = get_worker()
if w.address == b_address:
with lock:
return x + 1
return x + 1

x = c.submit(inc_slow, 1)
y = c.submit(inc_slow, 2)
lock = Lock()

await lock.acquire()

x = c.submit(inc_slow, 1, lock, workers=[a.address], allow_other_workers=True)

def reducer(*args):
return get_worker().address

def reducer(x, y):
return x + y
def finalizer(addr):
if swap_data_insert_order:
w = get_worker()
new_data = {k: w.data[k] for k in list(w.data.keys())[::-1]}
w.data = new_data
return addr

z = c.submit(reducer, x, y)
z = c.submit(reducer, x, key="reducer", workers=[a.address])
fin = c.submit(finalizer, z, key="final", workers=[a.address])

s.rpc = await FlakyConnectionPool(failing_connections=1)

@@ -2225,9 +2254,31 @@ def reducer(x, y):
) as client_logger:
# Gather using the client (as an ordinary user would)
# Upon a missing key, the client will reschedule the computations
res = await c.gather(z)
res = None
while not res:
try:
# This reduces test runtime by about a second since we're
# depending on a worker heartbeat for a reconnect.
res = await asyncio.wait_for(fin, 0.1)
except asyncio.TimeoutError:
await a.heartbeat()

# Ensure that we're actually reusing the result
assert res == a.address
await lock.release()

while not all(all(ts.state == "memory" for ts in w.tasks.values()) for w in [a, b]):
await asyncio.sleep(0.01)

assert res == 5
assert z.key in a.tasks
assert z.key not in b.tasks
assert b.executed_count == 1
for w in [a, b]:
assert x.key in w.tasks
assert w.tasks[x.key].state == "memory"
while not len(s.tasks[x.key].who_has) == 2:
await asyncio.sleep(0.01)
assert len(s.tasks[z.key].who_has) == 1

sched_logger = sched_logger.getvalue()
client_logger = client_logger.getvalue()
@@ -2243,24 +2294,6 @@ def reducer(x, y):
# is rather an artifact and not the intention
assert "Workers don't have promised key" in sched_logger

# Once the worker reconnects, it will also submit the keys it holds such
# that the scheduler again knows about the result.
# The final reduce step should then be reused from the reconnected worker
# instead of being recomputed.
transitions_to_processing = [
(key, start, timestamp)
for key, start, finish, recommendations, timestamp in s.transition_log
if finish == "processing" and "reducer" in key
]
assert len(transitions_to_processing) == 1

finish_processing_transitions = 0
for transition in s.transition_log:
key, start, finish, recommendations, timestamp = transition
if "reducer" in key and finish == "processing":
finish_processing_transitions += 1
assert finish_processing_transitions == 1


@gen_cluster(client=True)
async def test_too_many_groups(c, s, a, b):
65 changes: 65 additions & 0 deletions distributed/tests/test_worker.py
@@ -2370,6 +2370,71 @@ async def test_hold_on_to_replicas(c, s, *workers):
await asyncio.sleep(0.01)


@gen_cluster(client=True)
async def test_worker_reconnects_mid_compute(c, s, a, b):
"""
This test ensures that if a worker disconnects while computing a result,
the scheduler will still accept the result.

There is also an edge case tested which ensures that the reconnect is
successful if a task is currently executing, see
https://github.com/dask/distributed/issues/5078

See also distributed.tests.test_scheduler.py::test_gather_allow_worker_reconnect
"""
with captured_logger("distributed.scheduler") as s_logs:
# Let's put one task in memory to ensure the reconnect has tasks in
# different states
f1 = c.submit(inc, 1, workers=[a.address], allow_other_workers=True)
await f1
a_address = a.address
a.periodic_callbacks["heartbeat"].stop()
await a.heartbeat()
a.heartbeat_active = True

from distributed import Lock

def fast_on_a(lock):
w = get_worker()
import time

if w.address != a_address:
lock.acquire()
else:
time.sleep(1)

lock = Lock()
# We want to be sure that A is the only one computing this result
async with lock:

f2 = c.submit(
fast_on_a, lock, workers=[a.address], allow_other_workers=True
)

while f2.key not in a.tasks:
await asyncio.sleep(0.01)

await s.stream_comms[a.address].close()

assert len(s.workers) == 1
a.heartbeat_active = False
await a.heartbeat()
assert len(s.workers) == 2
# Since B would block on the lock, this result is guaranteed to originate from A
await f2

assert "Unexpected worker completed task" in s_logs.getvalue()

while not len(s.tasks[f2.key].who_has) == 2:
await asyncio.sleep(0.001)

# Ensure that all keys have been properly registered and will also be
# cleaned up nicely.
del f1, f2

while any(w.tasks for w in [a, b]):
await asyncio.sleep(0.001)


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_forget_dependents_after_release(c, s, a):

26 changes: 25 additions & 1 deletion distributed/worker.py
@@ -699,6 +699,7 @@ def __init__(
stream_handlers = {
"close": self.close,
"compute-task": self.add_task,
"cancel-compute": self.cancel_compute,
"free-keys": self.handle_free_keys,
"superfluous-data": self.handle_superfluous_data,
"steal-request": self.steal_request,
@@ -901,7 +902,14 @@ async def _register_with_scheduler(self):
keys=list(self.data),
nthreads=self.nthreads,
name=self.name,
nbytes={ts.key: ts.get_nbytes() for ts in self.tasks.values()},
nbytes={
ts.key: ts.get_nbytes()
for ts in self.tasks.values()
# Only tasks that are in memory have a meaningful
# nbytes; for anything else this would merely report
# the default value
if ts.state == "memory"
},
types={k: typename(v) for k, v in self.data.items()},
now=time(),
resources=self.total_resources,
@@ -1544,6 +1552,22 @@ async def set_resources(self, **resources):
# Task Management #
###################

def cancel_compute(self, key, reason):
"""
Cancel a task on a best-effort basis. This is only possible while the task
is in state `waiting` or `ready`; otherwise nothing happens.
"""
ts = self.tasks.get(key)
if ts and ts.state in ("waiting", "ready"):
self.log.append((key, "cancel-compute", reason))
ts.scheduler_holds_ref = False
# None of the dependents of ts should be in state "processing" on the
# scheduler side at this point and therefore none of them should have
# been assigned to a worker yet.
assert not ts.dependents
self.release_key(key, reason=reason, report=False)
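# Hypothetical usage sketch (not part of this diff): the handler above is
# reached through the "cancel-compute" stream message registered in
# __init__, which is roughly equivalent to calling
#
#     worker.cancel_compute(key="x", reason="Finished on different worker")
#
# If "x" is still waiting/ready the key is released and the cancellation is
# recorded in self.log; if it is already executing or in memory the call is
# a no-op.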

def add_task(
self,
key,