From 701e4b8621e4c4b80a8d5f5313f0f1f8de005701 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 21 Jul 2021 19:09:01 +0200 Subject: [PATCH 1/3] Ensure worker reconnect registers existing tasks properly --- distributed/scheduler.py | 34 +++++++++++------ distributed/tests/test_worker.py | 63 ++++++++++++++++++++++++++++++++ distributed/worker.py | 9 ++++- 3 files changed, 93 insertions(+), 13 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9c25d2c85b3..2bde29f95d5 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2571,7 +2571,6 @@ def transition_processing_memory( ws, key, ) - return recommendations, client_msgs, worker_msgs has_compute_startstop: bool = False compute_start: double @@ -4106,17 +4105,25 @@ async def add_worker( if nbytes: for key in nbytes: ts: TaskState = parent._tasks.get(key) - if ts is not None and ts._state in ("processing", "waiting"): - t: tuple = parent._transition( - key, - "memory", - worker=address, - nbytes=nbytes[key], - typename=types[key], - ) - recommendations, client_msgs, worker_msgs = t - parent._transitions(recommendations, client_msgs, worker_msgs) - recommendations = {} + if ts is not None: + if ts._state in ("processing", "waiting"): + t: tuple = parent._transition( + key, + "memory", + worker=address, + nbytes=nbytes[key], + typename=types[key], + ) + recommendations, client_msgs, worker_msgs = t + parent._transitions( + recommendations, client_msgs, worker_msgs + ) + recommendations = {} + else: + self.add_keys( + worker=address, + keys=[key], + ) for ts in list(parent._unrunnable): valid: set = self.valid_workers(ts) @@ -5122,6 +5129,9 @@ def handle_task_finished(self, key=None, worker=None, **msg): return validate_key(key) + if key in self.tasks and self.tasks[key].state == "memory": + self.add_keys(worker=worker, keys=[key]) + recommendations: dict client_msgs: dict worker_msgs: dict diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index e5e877a2a98..31fb90cba9a 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2368,3 +2368,66 @@ async def test_hold_on_to_replicas(c, s, *workers): while len(workers[2].tasks) > 1: await asyncio.sleep(0.01) + + +@gen_cluster(client=True) +async def test_worker_reconnects_mid_compute(c, s, a, b): + """ + This test ensure that if a worker disconnects while computing a result, the scheduler will still accept the result. + + There is also an edge case tested which ensures that the reconnect is + successful if a task is currently executing, see + https://github.com/dask/distributed/issues/5078 + """ + with captured_logger("distributed.scheduler") as s_logs: + # Let's put one task in memory to ensure the reconnect has tasks in + # different states + f1 = c.submit(inc, 1, workers=[a.address], allow_other_workers=True) + await f1 + a_address = a.address + a.periodic_callbacks["heartbeat"].stop() + await a.heartbeat() + a.heartbeat_active = True + + from distributed import Lock + + def fast_on_a(lock): + w = get_worker() + import time + + if w.address != a_address: + lock.acquire() + else: + time.sleep(1) + + lock = Lock() + # We want to be sure that A is the only one computing this result + async with lock: + + f2 = c.submit( + fast_on_a, lock, workers=[a.address], allow_other_workers=True + ) + + while f2.key not in a.tasks: + await asyncio.sleep(0.01) + + await s.stream_comms[a.address].close() + + assert len(s.workers) == 1 + a.heartbeat_active = False + await a.heartbeat() + assert len(s.workers) == 2 + # Since B is locked, this is ensured to originate from A + await f2 + + assert "Unexpected worker completed task" in s_logs.getvalue() + + while not len(s.tasks[f2.key].who_has) == 2: + await asyncio.sleep(0.001) + + # Ensure that all keys have been properly registered and will also be + # cleaned up nicely. + del f1, f2 + + while any(w.tasks for w in [a, b]): + await asyncio.sleep(0.001) diff --git a/distributed/worker.py b/distributed/worker.py index 95d6a116e63..9c959b3fc4a 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -904,7 +904,14 @@ async def _register_with_scheduler(self): keys=list(self.data), nthreads=self.nthreads, name=self.name, - nbytes={ts.key: ts.get_nbytes() for ts in self.tasks.values()}, + nbytes={ + ts.key: ts.get_nbytes() + for ts in self.tasks.values() + # Only if the task is in memory this is a sensible + # result since otherwise it simply submits the + # default value + if ts.state == "memory" + }, types={k: typename(v) for k, v in self.data.items()}, now=time(), resources=self.total_resources, From 3b7b46f0640472e225b24454becb9bd62cce91a5 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 22 Jul 2021 19:31:17 +0200 Subject: [PATCH 2/3] handle resheduled tasks --- distributed/scheduler.py | 40 +++++------ distributed/tests/test_scheduler.py | 107 ++++++++++++++++++---------- distributed/tests/test_worker.py | 2 + distributed/worker.py | 34 ++++++--- 4 files changed, 114 insertions(+), 69 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 2bde29f95d5..b71a8959daa 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2571,6 +2571,13 @@ def transition_processing_memory( ws, key, ) + worker_msgs[ts._processing_on.address] = [ + { + "op": "cancel-compute", + "key": key, + "reason": "Finished on different worker", + } + ] has_compute_startstop: bool = False compute_start: double @@ -4103,27 +4110,17 @@ async def add_worker( client_msgs: dict = {} worker_msgs: dict = {} if nbytes: + assert isinstance(nbytes, dict) for key in nbytes: ts: TaskState = parent._tasks.get(key) if ts is not None: - if ts._state in ("processing", "waiting"): - t: tuple = parent._transition( - key, - "memory", - worker=address, - nbytes=nbytes[key], - typename=types[key], - ) - recommendations, client_msgs, worker_msgs = t - parent._transitions( - recommendations, client_msgs, worker_msgs - ) - recommendations = {} - else: - self.add_keys( - worker=address, - keys=[key], - ) + self.handle_task_finished( + key=key, + worker=address, + nbytes=nbytes[key], + typename=types[key], + metadata=ts.metadata, + ) for ts in list(parent._unrunnable): valid: set = self.valid_workers(ts) @@ -4516,7 +4513,7 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): ws: WorkerState = parent._workers_dv[worker] ts._metadata.update(kwargs["metadata"]) - if ts._state == "processing": + if ts._state != "released": r: tuple = parent._transition(key, "memory", worker=worker, **kwargs) recommendations, client_msgs, worker_msgs = r @@ -5434,12 +5431,11 @@ async def gather(self, comm=None, keys=None, serializers=None): # Remove suspicious workers from the scheduler but allow them to # reconnect. await asyncio.gather( - *[ + *( self.remove_worker(address=worker, close=False) for worker in missing_workers - ] + ) ) - recommendations: dict client_msgs: dict = {} worker_msgs: dict = {} diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 6a151091787..da54849b413 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -43,7 +43,7 @@ tls_only_security, varying, ) -from distributed.worker import dumps_function, dumps_task +from distributed.worker import dumps_function, dumps_task, get_worker if sys.version_info < (3, 8): try: @@ -2182,8 +2182,13 @@ async def test_gather_no_workers(c, s, a, b): assert list(res["keys"]) == ["x"] +@pytest.mark.slow +@pytest.mark.parametrize("reschedule_different_worker", [True, False]) +@pytest.mark.parametrize("swap_data_insert_order", [True, False]) @gen_cluster(client=True, client_kwargs={"direct_to_workers": False}) -async def test_gather_allow_worker_reconnect(c, s, a, b): +async def test_gather_allow_worker_reconnect( + c, s, a, b, reschedule_different_worker, swap_data_insert_order +): """ Test that client resubmissions allow failed workers to reconnect and re-use their results. Failure scenario would be a connection issue during result @@ -2191,29 +2196,53 @@ async def test_gather_allow_worker_reconnect(c, s, a, b): Upon connection failure, the worker is flagged as suspicious and removed from the scheduler. If the worker is healthy and reconnencts we want to use its results instead of recomputing them. + + See also distributed.tests.test_worker.py::test_worker_reconnects_mid_compute """ # GH3246 - already_calculated = [] - - import time - - def inc_slow(x): - # Once the graph below is rescheduled this computation runs again. We - # need to sleep for at least 0.5 seconds to give the worker a chance to - # reconnect (Heartbeat timing). In slow CI situations, the actual - # reconnect might take a bit longer, therefore wait more - if x in already_calculated: - time.sleep(2) - already_calculated.append(x) + if reschedule_different_worker: + from distributed.diagnostics.plugin import SchedulerPlugin + + class SwitchRestrictions(SchedulerPlugin): + def __init__(self, scheduler): + self.scheduler = scheduler + + def transition(self, key, start, finish, **kwargs): + if key in ("reducer", "final") and finish == "memory": + self.scheduler.tasks[key]._worker_restrictions = {b.address} + + plugin = SwitchRestrictions(s) + s.add_plugin(plugin) + + from distributed import Lock + + b_address = b.address + + def inc_slow(x, lock): + w = get_worker() + if w.address == b_address: + with lock: + return x + 1 return x + 1 - x = c.submit(inc_slow, 1) - y = c.submit(inc_slow, 2) + lock = Lock() + + await lock.acquire() + + x = c.submit(inc_slow, 1, lock, workers=[a.address], allow_other_workers=True) - def reducer(x, y): - return x + y + def reducer(*args): + return get_worker().address - z = c.submit(reducer, x, y) + def finalizer(addr): + if swap_data_insert_order: + w = get_worker() + new_data = {k: w.data[k] for k in list(w.data.keys())[::-1]} + w.data = new_data + return addr + + z = c.submit(reducer, x, key="reducer", workers=[a.address]) + fin = c.submit(finalizer, z, key="final", workers=[a.address]) s.rpc = await FlakyConnectionPool(failing_connections=1) @@ -2227,9 +2256,27 @@ def reducer(x, y): ) as client_logger: # Gather using the client (as an ordinary user would) # Upon a missing key, the client will reschedule the computations - res = await c.gather(z) + res = None + while not res: + try: + # This reduces test runtime by about a second since we're + # depending on a worker heartbeat for a reconnect. + res = await asyncio.wait_for(fin, 0.1) + except asyncio.TimeoutError: + await a.heartbeat() + + # Ensure that we're actually reusing the result + assert res == a.address + await lock.release() + + while any(w.executing_count for w in [a, b]): + await asyncio.sleep(0.01) - assert res == 5 + assert z.key in a.tasks + assert z.key not in b.tasks + assert b.executed_count == 1 + assert len(s.tasks[z.key].who_has) == 1 + assert len(s.tasks[x.key].who_has) == 2 sched_logger = sched_logger.getvalue() client_logger = client_logger.getvalue() @@ -2245,24 +2292,6 @@ def reducer(x, y): # is rather an artifact and not the intention assert "Workers don't have promised key" in sched_logger - # Once the worker reconnects, it will also submit the keys it holds such - # that the scheduler again knows about the result. - # The final reduce step should then be used from the re-connected worker - # instead of recomputing it. - transitions_to_processing = [ - (key, start, timestamp) - for key, start, finish, recommendations, timestamp in s.transition_log - if finish == "processing" and "reducer" in key - ] - assert len(transitions_to_processing) == 1 - - finish_processing_transitions = 0 - for transition in s.transition_log: - key, start, finish, recommendations, timestamp = transition - if "reducer" in key and finish == "processing": - finish_processing_transitions += 1 - assert finish_processing_transitions == 1 - @gen_cluster(client=True) async def test_too_many_groups(c, s, a, b): diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 31fb90cba9a..3b49b7bbb22 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2378,6 +2378,8 @@ async def test_worker_reconnects_mid_compute(c, s, a, b): There is also an edge case tested which ensures that the reconnect is successful if a task is currently executing, see https://github.com/dask/distributed/issues/5078 + + See also distributed.tests.test_scheduler.py::test_gather_allow_worker_reconnect """ with captured_logger("distributed.scheduler") as s_logs: # Let's put one task in memory to ensure the reconnect has tasks in diff --git a/distributed/worker.py b/distributed/worker.py index 9c959b3fc4a..aef011d520d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -16,7 +16,7 @@ from datetime import timedelta from inspect import isawaitable from pickle import PicklingError -from typing import Dict, Iterable, Optional +from typing import Dict, Hashable, Iterable, Optional from tlz import first, keymap, merge, pluck # noqa: F401 from tornado.ioloop import IOLoop, PeriodicCallback @@ -702,6 +702,7 @@ def __init__( stream_handlers = { "close": self.close, "compute-task": self.add_task, + "cancel-compute": self.cancel_compute, "free-keys": self.handle_free_keys, "superfluous-data": self.handle_superfluous_data, "steal-request": self.steal_request, @@ -1554,6 +1555,22 @@ async def set_resources(self, **resources): # Task Management # ################### + def cancel_compute(self, key, reason): + """ + Cancel a task on a best effort basis. This is only possible while a task + is in state `waiting` or `ready`. + Nothing will happen otherwise. + """ + ts = self.tasks.get(key) + if ts and ts.state in ("waiting", "ready"): + self.log.append((key, "cancel-compute", reason)) + ts.scheduler_holds_ref = False + # All possible dependents of TS should not be in state Processing on + # scheduler side and therefore should not be assigned to a worker, + # yet. + assert not ts.dependents + self.release_key(key, reason=reason, report=False) + def add_task( self, key, @@ -1992,11 +2009,6 @@ def transition_executing_done(self, ts, value=no_value, report=True): for d in ts.dependents: d.waiting_for_data.add(ts.key) - # Don't release the dependency keys, but do remove them from `dependents` - for dependency in ts.dependencies: - dependency.dependents.discard(ts) - ts.dependencies.clear() - if report and self.batched_stream and self.status == Status.running: self.send_task_state_to_scheduler(ts) else: @@ -2613,7 +2625,7 @@ def steal_request(self, key): def release_key( self, - key: str, + key: Hashable, cause: Optional[TaskState] = None, reason: Optional[str] = None, report: bool = True, @@ -2621,7 +2633,7 @@ def release_key( try: if self.validate: - assert isinstance(key, str) + assert not isinstance(key, TaskState) ts = self.tasks.get(key, None) # If the scheduler holds a reference which is usually the # case when it instructed the task to be computed here or if @@ -2659,6 +2671,12 @@ def release_key( for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity + for d in ts.dependencies: + d.dependents.discard(ts) + + if not d.dependents and d.state in ("flight", "fetch"): + self.release_key(d.key, reason="Dependent released") + if report: # Inform the scheduler of keys which will have gone missing # We are releasing them before they have completed From 0eb2094ab6f999f007a3f79efd8e89c0a0ab0571 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 23 Jul 2021 11:53:04 +0200 Subject: [PATCH 3/3] Test more robust --- distributed/tests/test_scheduler.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index da54849b413..5a42159d9fd 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -2269,14 +2269,18 @@ def finalizer(addr): assert res == a.address await lock.release() - while any(w.executing_count for w in [a, b]): + while not all(all(ts.state == "memory" for ts in w.tasks.values()) for w in [a, b]): await asyncio.sleep(0.01) assert z.key in a.tasks assert z.key not in b.tasks assert b.executed_count == 1 + for w in [a, b]: + assert x.key in w.tasks + assert w.tasks[x.key].state == "memory" + while not len(s.tasks[x.key].who_has) == 2: + await asyncio.sleep(0.01) assert len(s.tasks[z.key].who_has) == 1 - assert len(s.tasks[x.key].who_has) == 2 sched_logger = sched_logger.getvalue() client_logger = client_logger.getvalue()