diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 75c6fe259be..b3d281dc4cb 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -2695,7 +2695,13 @@ def transition_processing_memory(
                     ws,
                     key,
                 )
-                return recommendations, client_msgs, worker_msgs
+                worker_msgs[ts._processing_on.address] = [
+                    {
+                        "op": "cancel-compute",
+                        "key": key,
+                        "reason": "Finished on different worker",
+                    }
+                ]

             has_compute_startstop: bool = False
             compute_start: double
@@ -4229,19 +4235,25 @@ async def add_worker(
         client_msgs: dict = {}
         worker_msgs: dict = {}
         if nbytes:
+            assert isinstance(nbytes, dict)
             for key in nbytes:
                 ts: TaskState = parent._tasks.get(key)
-                if ts is not None and ts._state in ("processing", "waiting"):
-                    t: tuple = parent._transition(
-                        key,
-                        "memory",
-                        worker=address,
-                        nbytes=nbytes[key],
-                        typename=types[key],
-                    )
-                    recommendations, client_msgs, worker_msgs = t
-                    parent._transitions(recommendations, client_msgs, worker_msgs)
-                    recommendations = {}
+                if ts is not None:
+                    if ts.state == "memory":
+                        self.add_keys(worker=address, keys=[key])
+                    else:
+                        t: tuple = parent._transition(
+                            key,
+                            "memory",
+                            worker=address,
+                            nbytes=nbytes[key],
+                            typename=types[key],
+                        )
+                        recommendations, client_msgs, worker_msgs = t
+                        parent._transitions(
+                            recommendations, client_msgs, worker_msgs
+                        )
+                        recommendations = {}

         for ts in list(parent._unrunnable):
             valid: set = self.valid_workers(ts)
@@ -4646,10 +4658,15 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs):
         ts: TaskState = parent._tasks.get(key)
         if ts is None:
             return recommendations, client_msgs, worker_msgs
+
+        if ts.state == "memory":
+            self.add_keys(worker=worker, keys=[key])
+            return recommendations, client_msgs, worker_msgs
+
         ws: WorkerState = parent._workers_dv[worker]
         ts._metadata.update(kwargs["metadata"])

-        if ts._state == "processing":
+        if ts._state != "released":
             r: tuple = parent._transition(key, "memory", worker=worker, **kwargs)
             recommendations, client_msgs, worker_msgs = r

@@ -5567,12 +5584,11 @@ async def gather(self, comm=None, keys=None, serializers=None):
             # Remove suspicious workers from the scheduler but allow them to
             # reconnect.
             await asyncio.gather(
-                *[
+                *(
                     self.remove_worker(address=worker, close=False)
                     for worker in missing_workers
-                ]
+                )
             )
-            recommendations: dict
             client_msgs: dict = {}
             worker_msgs: dict = {}

diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index cb848075d59..9cf0eb160f1 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -43,7 +43,7 @@
     tls_only_security,
     varying,
 )
-from distributed.worker import dumps_function, dumps_task
+from distributed.worker import dumps_function, dumps_task, get_worker

 if sys.version_info < (3, 8):
     try:
@@ -2180,8 +2180,13 @@ async def test_gather_no_workers(c, s, a, b):
     assert list(res["keys"]) == ["x"]


+@pytest.mark.slow
+@pytest.mark.parametrize("reschedule_different_worker", [True, False])
+@pytest.mark.parametrize("swap_data_insert_order", [True, False])
 @gen_cluster(client=True, client_kwargs={"direct_to_workers": False})
-async def test_gather_allow_worker_reconnect(c, s, a, b):
+async def test_gather_allow_worker_reconnect(
+    c, s, a, b, reschedule_different_worker, swap_data_insert_order
+):
     """
     Test that client resubmissions allow failed workers to reconnect and
     re-use their results. Failure scenario would be a connection issue during result
@@ -2189,29 +2194,53 @@
     Upon connection failure, the worker is flagged as suspicious and removed
     from the scheduler. If the worker is healthy and reconnencts we want to
     use its results instead of recomputing them.
+
+    See also distributed.tests.test_worker.py::test_worker_reconnects_mid_compute
     """
     # GH3246
-    already_calculated = []
-
-    import time
-
-    def inc_slow(x):
-        # Once the graph below is rescheduled this computation runs again. We
-        # need to sleep for at least 0.5 seconds to give the worker a chance to
-        # reconnect (Heartbeat timing). In slow CI situations, the actual
-        # reconnect might take a bit longer, therefore wait more
-        if x in already_calculated:
-            time.sleep(2)
-        already_calculated.append(x)
+    if reschedule_different_worker:
+        from distributed.diagnostics.plugin import SchedulerPlugin
+
+        class SwitchRestrictions(SchedulerPlugin):
+            def __init__(self, scheduler):
+                self.scheduler = scheduler
+
+            def transition(self, key, start, finish, **kwargs):
+                if key in ("reducer", "final") and finish == "memory":
+                    self.scheduler.tasks[key]._worker_restrictions = {b.address}
+
+        plugin = SwitchRestrictions(s)
+        s.add_plugin(plugin)
+
+    from distributed import Lock
+
+    b_address = b.address
+
+    def inc_slow(x, lock):
+        w = get_worker()
+        if w.address == b_address:
+            with lock:
+                return x + 1
         return x + 1

-    x = c.submit(inc_slow, 1)
-    y = c.submit(inc_slow, 2)
+    lock = Lock()
+
+    await lock.acquire()
+
+    x = c.submit(inc_slow, 1, lock, workers=[a.address], allow_other_workers=True)
+
+    def reducer(*args):
+        return get_worker().address

-    def reducer(x, y):
-        return x + y
+    def finalizer(addr):
+        if swap_data_insert_order:
+            w = get_worker()
+            new_data = {k: w.data[k] for k in list(w.data.keys())[::-1]}
+            w.data = new_data
+        return addr

-    z = c.submit(reducer, x, y)
+    z = c.submit(reducer, x, key="reducer", workers=[a.address])
+    fin = c.submit(finalizer, z, key="final", workers=[a.address])

     s.rpc = await FlakyConnectionPool(failing_connections=1)

@@ -2225,9 +2254,31 @@ def reducer(x, y):
     ) as client_logger:
         # Gather using the client (as an ordinary user would)
         # Upon a missing key, the client will reschedule the computations
-        res = await c.gather(z)
+        res = None
+        while not res:
+            try:
+                # This reduces test runtime by about a second since we're
+                # depending on a worker heartbeat for a reconnect.
+                res = await asyncio.wait_for(fin, 0.1)
+            except asyncio.TimeoutError:
+                await a.heartbeat()
+
+        # Ensure that we're actually reusing the result
+        assert res == a.address
+        await lock.release()
+
+        while not all(all(ts.state == "memory" for ts in w.tasks.values()) for w in [a, b]):
+            await asyncio.sleep(0.01)

-        assert res == 5
+        assert z.key in a.tasks
+        assert z.key not in b.tasks
+        assert b.executed_count == 1
+        for w in [a, b]:
+            assert x.key in w.tasks
+            assert w.tasks[x.key].state == "memory"
+        while not len(s.tasks[x.key].who_has) == 2:
+            await asyncio.sleep(0.01)
+        assert len(s.tasks[z.key].who_has) == 1

     sched_logger = sched_logger.getvalue()
     client_logger = client_logger.getvalue()
@@ -2243,24 +2294,6 @@ def reducer(x, y):
     # is rather an artifact and not the intention
     assert "Workers don't have promised key" in sched_logger

-    # Once the worker reconnects, it will also submit the keys it holds such
-    # that the scheduler again knows about the result.
-    # The final reduce step should then be used from the re-connected worker
-    # instead of recomputing it.
-    transitions_to_processing = [
-        (key, start, timestamp)
-        for key, start, finish, recommendations, timestamp in s.transition_log
-        if finish == "processing" and "reducer" in key
-    ]
-    assert len(transitions_to_processing) == 1
-
-    finish_processing_transitions = 0
-    for transition in s.transition_log:
-        key, start, finish, recommendations, timestamp = transition
-        if "reducer" in key and finish == "processing":
-            finish_processing_transitions += 1
-    assert finish_processing_transitions == 1
-

 @gen_cluster(client=True)
 async def test_too_many_groups(c, s, a, b):
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index ca2b5a70d13..4c9cc83d45b 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -2370,6 +2370,71 @@ async def test_hold_on_to_replicas(c, s, *workers):
         await asyncio.sleep(0.01)


+@gen_cluster(client=True)
+async def test_worker_reconnects_mid_compute(c, s, a, b):
+    """
+    This test ensure that if a worker disconnects while computing a result, the scheduler will still accept the result.
+
+    There is also an edge case tested which ensures that the reconnect is
+    successful if a task is currently executing, see
+    https://github.com/dask/distributed/issues/5078
+
+    See also distributed.tests.test_scheduler.py::test_gather_allow_worker_reconnect
+    """
+    with captured_logger("distributed.scheduler") as s_logs:
+        # Let's put one task in memory to ensure the reconnect has tasks in
+        # different states
+        f1 = c.submit(inc, 1, workers=[a.address], allow_other_workers=True)
+        await f1
+        a_address = a.address
+        a.periodic_callbacks["heartbeat"].stop()
+        await a.heartbeat()
+        a.heartbeat_active = True
+
+        from distributed import Lock
+
+        def fast_on_a(lock):
+            w = get_worker()
+            import time
+
+            if w.address != a_address:
+                lock.acquire()
+            else:
+                time.sleep(1)
+
+        lock = Lock()
+        # We want to be sure that A is the only one computing this result
+        async with lock:
+
+            f2 = c.submit(
+                fast_on_a, lock, workers=[a.address], allow_other_workers=True
+            )
+
+            while f2.key not in a.tasks:
+                await asyncio.sleep(0.01)
+
+            await s.stream_comms[a.address].close()
+
+            assert len(s.workers) == 1
+            a.heartbeat_active = False
+            await a.heartbeat()
+            assert len(s.workers) == 2
+            # Since B is locked, this is ensured to originate from A
+            await f2
+
+    assert "Unexpected worker completed task" in s_logs.getvalue()
+
+    while not len(s.tasks[f2.key].who_has) == 2:
+        await asyncio.sleep(0.001)
+
+    # Ensure that all keys have been properly registered and will also be
+    # cleaned up nicely.
+    del f1, f2
+
+    while any(w.tasks for w in [a, b]):
+        await asyncio.sleep(0.001)
+
+
 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
 async def test_forget_dependents_after_release(c, s, a):
diff --git a/distributed/worker.py b/distributed/worker.py
index 806699845f2..c182a84249d 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -699,6 +699,7 @@ def __init__(
         stream_handlers = {
             "close": self.close,
             "compute-task": self.add_task,
+            "cancel-compute": self.cancel_compute,
             "free-keys": self.handle_free_keys,
             "superfluous-data": self.handle_superfluous_data,
             "steal-request": self.steal_request,
@@ -901,7 +902,14 @@ async def _register_with_scheduler(self):
                     keys=list(self.data),
                     nthreads=self.nthreads,
                     name=self.name,
-                    nbytes={ts.key: ts.get_nbytes() for ts in self.tasks.values()},
+                    nbytes={
+                        ts.key: ts.get_nbytes()
+                        for ts in self.tasks.values()
+                        # Only if the task is in memory this is a sensible
+                        # result since otherwise it simply submits the
+                        # default value
+                        if ts.state == "memory"
+                    },
                     types={k: typename(v) for k, v in self.data.items()},
                     now=time(),
                     resources=self.total_resources,
@@ -1544,6 +1552,22 @@ async def set_resources(self, **resources):
     # Task Management #
     ###################

+    def cancel_compute(self, key, reason):
+        """
+        Cancel a task on a best effort basis. This is only possible while a task
+        is in state `waiting` or `ready`.
+        Nothing will happen otherwise.
+        """
+        ts = self.tasks.get(key)
+        if ts and ts.state in ("waiting", "ready"):
+            self.log.append((key, "cancel-compute", reason))
+            ts.scheduler_holds_ref = False
+            # All possible dependents of TS should not be in state Processing on
+            # scheduler side and therefore should not be assigned to a worker,
+            # yet.
+            assert not ts.dependents
+            self.release_key(key, reason=reason, report=False)
+
     def add_task(
         self,
         key,