From 05c3929dc0e84d1a73aca741b5614f6566088438 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 29 Apr 2022 15:24:46 +0200 Subject: [PATCH 1/5] Fix recommendation ordering in transition_memory_released --- distributed/scheduler.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index d0e77a55412..507edf742f7 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2199,12 +2199,6 @@ def transition_memory_released(self, key, stimulus_id, safe: bool = False): worker_msgs, ) # don't try to recreate - for dts in ts.waiters: - if dts.state in ("no-worker", "processing"): - recommendations[dts.key] = "waiting" - elif dts.state == "waiting": - dts.waiting_on.add(ts) - # XXX factor this out? worker_msg = { "op": "free-keys", @@ -2229,6 +2223,12 @@ def transition_memory_released(self, key, stimulus_id, safe: bool = False): elif ts.who_wants or ts.waiters: recommendations[key] = "waiting" + for dts in ts.waiters: + if dts.state in ("no-worker", "processing"): + recommendations[dts.key] = "waiting" + elif dts.state == "waiting": + dts.waiting_on.add(ts) + if self.validate: assert not ts.waiting_on From a03c5c6cccd0f30647b8484c878734b686b420f0 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 29 Apr 2022 16:08:20 +0200 Subject: [PATCH 2/5] fixup! Fix recommendation ordering in transition_memory_released --- distributed/tests/test_worker.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 9346da09e95..6a084204d75 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2325,7 +2325,7 @@ def raise_exc(*args): h.key: "memory", res.key: "error", } - await assert_task_states_on_worker(expected_states_B, b) + await assert_task_states_on_worker(expected_states_B, b), b.tasks g.release() @@ -3157,7 +3157,10 @@ async def test_task_flight_compute_oserror(c, s, a, b): # inc is lost and needs to be recomputed. Therefore, sum is released ("free-keys", ("f1",)), ("f1", "release-key"), - ("f1", "waiting", "released", "released", {"f1": "forgotten"}), + # The recommendations here are hard to predict. Whatever key is + # currently scheduled to be fetched, if any, will be recommended to be + # released. + ("f1", "waiting", "released", "released", lambda msg: msg["f1"] == "forgotten"), ("f1", "released", "forgotten", "forgotten", {}), # Now, we actually compute the task *once*. This must not cycle back ("f1", "compute-task"), From e095b79e46499b7c1353fdc6cf0dfb5c3efabf7b Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 29 Apr 2022 16:09:50 +0200 Subject: [PATCH 3/5] Ensure intermediate release transitions will not forget tasks --- .../tests/test_worker_state_machine.py | 47 +++++++++++++++++++ distributed/worker.py | 28 +++++++---- 2 files changed, 66 insertions(+), 9 deletions(-) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index feb99ee995a..febbef6c1c5 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -1,9 +1,11 @@ +import asyncio from itertools import chain import pytest from distributed.protocol.serialize import Serialize from distributed.utils import recursive_to_dict +from distributed.utils_test import gen_cluster, inc from distributed.worker_state_machine import ( ExecuteFailureEvent, ExecuteSuccessEvent, @@ -19,6 +21,11 @@ ) +async def wait_for_state(key, state, dask_worker): + while key not in dask_worker.tasks or dask_worker.tasks[key].state != state: + await asyncio.sleep(0.005) + + def test_TaskState_get_nbytes(): assert TaskState("x", nbytes=123).get_nbytes() == 123 # Default to distributed.scheduler.default-data-size @@ -236,3 +243,43 @@ def test_executefailure_to_dict(): assert ev3.traceback is None assert ev3.exception_text == "exc text" assert ev3.traceback_text == "tb text" + + +@gen_cluster(client=True) +async def test_fetch_to_compute(c, s, a, b): + # Block ensure_communicating to ensure we indeed know that the task is in + # fetch and doesn't leave it accidentally + old_out_connections, b.total_out_connections = b.total_out_connections, 0 + old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0 + + f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True) + f2 = c.submit(inc, f1, workers=[b.address], key="f2") + + await wait_for_state(f1.key, "fetch", b) + await a.close() + + b.total_out_connections = old_out_connections + b.comm_threshold_bytes = old_comm_threshold + + await f2 + + +@gen_cluster(client=True) +async def test_fetch_via_amm_to_compute(c, s, a, b): + # Block ensure_communicating to ensure we indeed know that the task is in + # fetch and doesn't leave it accidentally + old_out_connections, b.total_out_connections = b.total_out_connections, 0 + old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0 + + f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True) + + await f1 + s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test") + + await wait_for_state(f1.key, "fetch", b) + await a.close() + + b.total_out_connections = old_out_connections + b.comm_threshold_bytes = old_comm_threshold + + await f1 diff --git a/distributed/worker.py b/distributed/worker.py index 8dd920d2f88..a4bedb789c3 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2636,18 +2636,28 @@ def _transition( recs, instructions = self._transition( ts, "released", stimulus_id=stimulus_id ) - v = recs.get(ts, (finish, *args)) v_state: str v_args: list | tuple - if isinstance(v, tuple): - v_state, *v_args = v - else: - v_state, v_args = v, () - b_recs, b_instructions = self._transition( - ts, v_state, *v_args, stimulus_id=stimulus_id + while v := recs.pop(ts, None): + if isinstance(v, tuple): + v_state, *v_args = v + else: + v_state, v_args = v, () + if v_state == "forgotten": + # We do not want to forget. The purpose of this + # transition path is to get to `finish` + continue + b_recs, b_instructions = self._transition( + ts, v_state, *v_args, stimulus_id=stimulus_id + ) + recs.update(b_recs) + instructions += b_instructions + + c_recs, c_instructions = self._transition( + ts, finish, *args, stimulus_id=stimulus_id ) - recs.update(b_recs) - instructions += b_instructions + recs.update(c_recs) + instructions += c_instructions except InvalidTransition: self.log_event( "invalid-worker-transition", From d332e6a573fcfd0196c782cde4650f0db515c351 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 29 Apr 2022 18:06:01 +0200 Subject: [PATCH 4/5] Allow cancelled missing transition (to released) --- distributed/worker.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index a4bedb789c3..a63150f604e 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -602,6 +602,7 @@ def __init__( ("cancelled", "resumed"): self.transition_cancelled_resumed, ("cancelled", "fetch"): self.transition_cancelled_fetch, ("cancelled", "released"): self.transition_cancelled_released, + ("cancelled", "missing"): self.transition_cancelled_released, ("cancelled", "waiting"): self.transition_cancelled_waiting, ("cancelled", "forgotten"): self.transition_cancelled_forgotten, ("cancelled", "memory"): self.transition_cancelled_memory, @@ -2836,8 +2837,11 @@ def stimulus_story( def ensure_communicating(self) -> None: if self.status != Status.running: return + if not hasattr(self, "_stim_counter"): + self._stim_counter = 0 + self._stim_counter += 1 - stimulus_id = f"ensure-communicating-{time()}" + stimulus_id = f"ensure-communicating-{self._stim_counter}" skipped_worker_in_flight_or_busy = [] while self.data_needed and ( @@ -3194,7 +3198,12 @@ async def gather_dep( for d in has_what: ts = self.tasks[d] ts.who_has.remove(worker) - if not ts.who_has and ts.state not in ("released", "memory"): + if not ts.who_has and ts.state in ( + "fetch", + "flight", + "resumed", + "cancelled", + ): recommendations[ts] = "missing" self.log.append( ("missing-who-has", worker, ts.key, stimulus_id, time()) From 9682b97803c1b53e62d33b2abe13e14ba9e67e30 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 29 Apr 2022 18:34:07 +0200 Subject: [PATCH 5/5] Rework missing transitions --- distributed/worker.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index a63150f604e..49ccc6e0772 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -622,7 +622,7 @@ def __init__( ("executing", "released"): self.transition_executing_released, ("executing", "rescheduled"): self.transition_executing_rescheduled, ("fetch", "flight"): self.transition_fetch_flight, - ("fetch", "missing"): self.transition_fetch_missing, + ("fetch", "missing"): self.transition_generic_to_missing, ("fetch", "released"): self.transition_generic_released, ("flight", "error"): self.transition_flight_error, ("flight", "fetch"): self.transition_flight_fetch, @@ -642,6 +642,7 @@ def __init__( ("ready", "released"): self.transition_generic_released, ("released", "error"): self.transition_generic_error, ("released", "fetch"): self.transition_released_fetch, + ("released", "missing"): self.transition_released_fetch, ("released", "forgotten"): self.transition_released_forgotten, ("released", "memory"): self.transition_released_memory, ("released", "waiting"): self.transition_released_waiting, @@ -2018,6 +2019,7 @@ def transition_missing_fetch( if self.validate: assert ts.state == "missing" assert ts.priority is not None + assert ts.who_has self._missing_dep_flight.discard(ts) ts.state = "fetch" @@ -2046,7 +2048,7 @@ def transition_flight_missing( ts.done = False return {}, [] - def transition_fetch_missing( + def transition_generic_to_missing( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: ts.state = "missing" @@ -2060,6 +2062,8 @@ def transition_released_fetch( if self.validate: assert ts.state == "released" assert ts.priority is not None + if not ts.who_has: + return {ts: "missing"}, [] ts.state = "fetch" ts.done = False self.data_needed.push(ts) @@ -3261,10 +3265,7 @@ async def gather_dep( "stimulus_id": stimulus_id, } ) - if ts.who_has: - recommendations[ts] = "fetch" - elif ts.state not in ("released", "memory"): - recommendations[ts] = "missing" + recommendations[ts] = "fetch" del data, response self.transitions(recommendations, stimulus_id=stimulus_id)