From c2a3eadac315fc43607959f526ca52b19dc05431 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 1 Apr 2022 12:18:51 +0200 Subject: [PATCH 01/29] test_scheduler.py succeeds --- distributed/core.py | 4 ++ distributed/scheduler.py | 85 +++++++++++++++++++++++++---- distributed/tests/test_scheduler.py | 6 +- 3 files changed, 82 insertions(+), 13 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index da48132e9e0..e89eea16559 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -12,6 +12,7 @@ from collections import defaultdict from collections.abc import Container from contextlib import suppress +from contextvars import ContextVar from enum import Enum from functools import partial from typing import Callable, ClassVar @@ -114,6 +115,9 @@ def _expects_comm(func: Callable) -> bool: return False +SERVER_STIMULUS_ID: ContextVar = ContextVar("stimulus_id") + + class Server: """Dask Distributed Server diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 4882bf81919..6a6d7a5258c 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -65,7 +65,7 @@ unparse_host_port, ) from distributed.comm.addressing import addresses_from_user_args -from distributed.core import Status, clean_exception, rpc, send_recv +from distributed.core import SERVER_STIMULUS_ID, Status, clean_exception, rpc, send_recv from distributed.diagnostics.memory_sampler import MemorySamplerExtension from distributed.diagnostics.plugin import SchedulerPlugin, _get_plugin_name from distributed.event import EventExtension @@ -2370,11 +2370,17 @@ def _transition(self, key, finish: str, *args, **kwargs): else: raise RuntimeError("Impossible transition from %r to %r" % start_finish) + try: + stimulus_id = SERVER_STIMULUS_ID.get() + except LookupError: + if self._validate: + raise + finish2 = ts._state # FIXME downcast antipattern scheduler = pep484_cast(Scheduler, self) scheduler.transition_log.append( - (key, start, finish2, recommendations, time()) + (key, start, finish2, recommendations, stimulus_id, time()) ) if parent._validate: logger.debug( @@ -4459,9 +4465,14 @@ async def add_worker( versions=None, nanny=None, extra=None, + stimulus_id=None, ): """Add a new worker to the cluster""" parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"add-worker-{time()}") with log_errors(): address = self.coerce_address(address, resolve_address) address = normalize_address(address) @@ -4648,7 +4659,13 @@ def update_graph_hlg( actors=None, fifo_timeout=0, code=None, + stimulus_id=None, ): + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"update-graph-hlg-{time()}") + unpacked_graph = HighLevelGraph.__dask_distributed_unpack__(hlg) dsk = unpacked_graph["dsk"] dependencies = unpacked_graph["deps"] @@ -4687,6 +4704,7 @@ def update_graph_hlg( fifo_timeout, annotations, code=code, + stimulus_id=SERVER_STIMULUS_ID.get(), ) def update_graph( @@ -4706,6 +4724,7 @@ def update_graph( fifo_timeout=0, annotations=None, code=None, + stimulus_id=None, ): """ Add new computations to the internal dask graph @@ -4713,6 +4732,10 @@ def update_graph( This happens whenever the Client calls submit, map, get, or compute. 
""" parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"update-graph-hlg-{time()}") start = time() fifo_timeout = parse_timedelta(fifo_timeout) keys = set(keys) @@ -5074,7 +5097,7 @@ def stimulus_retry(self, keys, client=None): return tuple(seen) - async def remove_worker(self, address, safe=False, close=True): + async def remove_worker(self, address, safe=False, close=True, stimulus_id=None): """ Remove worker from cluster @@ -5083,6 +5106,12 @@ async def remove_worker(self, address, safe=False, close=True): state. """ parent: SchedulerState = cast(SchedulerState, self) + + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"remove-worker-{time()}") + with log_errors(): if self.status == Status.closed: return @@ -5200,10 +5229,14 @@ def stimulus_cancel(self, comm, keys=None, client=None, force=False): for key in keys: self.cancel_key(key, client, force=force) - def cancel_key(self, key, client, retries=5, force=False): + def cancel_key(self, key, client, retries=5, force=False, stimulus_id=None): """Cancel a particular key and all dependents""" # TODO: this should be converted to use the transition mechanism parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"cancel-key-{time()}") ts: TaskState = parent._tasks.get(key) dts: TaskState try: @@ -5225,8 +5258,13 @@ def cancel_key(self, key, client, retries=5, force=False): for cs in clients: self.client_releases_keys(keys=[key], client=cs._client_key) - def client_desires_keys(self, keys=None, client=None): + def client_desires_keys(self, keys=None, client=None, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"client-desires-keys-{time()}") + cs: ClientState = parent._clients.get(client) if cs is None: # For publish, queues etc. @@ -5243,10 +5281,15 @@ def client_desires_keys(self, keys=None, client=None): if ts._state in ("memory", "erred"): self.report_on_key(ts=ts, client=client) - def client_releases_keys(self, keys=None, client=None): + def client_releases_keys(self, keys=None, client=None, stimulus_id=None): """Remove keys from client desired list""" parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"client-releases-keys-{time()}") + if not isinstance(keys, list): keys = list(keys) cs: ClientState = parent._clients[client] @@ -5464,12 +5507,18 @@ def report(self, msg: dict, ts: TaskState = None, client: str = None): "Closed comm %r while trying to write %s", c, msg, exc_info=True ) - async def add_client(self, comm: Comm, client: str, versions: dict) -> None: + async def add_client( + self, comm: Comm, client: str, versions: dict, stimulus_id=None + ) -> None: """Add client to network We listen to all future messages from this Comm. 
""" parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"add-client-{time()}") assert client is not None comm.name = "Scheduler->Client" logger.info("Receive client connection: %s", client) @@ -5513,9 +5562,13 @@ async def add_client(self, comm: Comm, client: str, versions: dict) -> None: except TypeError: # comm becomes None during GC pass - def remove_client(self, client: str) -> None: + def remove_client(self, client: str, stimulus_id=None) -> None: """Remove client from network""" parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"remove-client-{time()}") if self.status == Status.running: logger.info("Remove client %s", client) self.log_event(["all", client], {"action": "remove-client", "client": client}) @@ -6267,7 +6320,10 @@ async def gather_on_worker( return keys_failed async def delete_worker_data( - self, worker_address: str, keys: "Collection[str]" + self, + worker_address: str, + keys: "Collection[str]", + stimulus_id=None, ) -> None: """Delete data from a worker and update the corresponding worker/task states @@ -6279,6 +6335,10 @@ async def delete_worker_data( List of keys to delete on the specified worker """ parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"delete-worker-data-{time()}") try: await retry_operation( @@ -7609,13 +7669,18 @@ def story(self, *keys): transition_story = story - def reschedule(self, key=None, worker=None): + def reschedule(self, key=None, worker=None, stimulus_id=None): """Reschedule a task Things may have shifted and this task may now be better suited to run elsewhere """ parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"reschedule-{time()}") + ts: TaskState try: ts = parent._tasks[key] diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 2750868dc01..e2c055d1655 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -810,7 +810,7 @@ async def test_story(c, s, a, b): story = s.story(x.key) assert all(line in s.transition_log for line in story) assert len(story) < len(s.transition_log) - assert all(x.key == line[0] or x.key in line[-2] for line in story) + assert all(x.key == line[0] or x.key in line[3] for line in story) assert len(s.story(x.key, y.key)) > len(story) @@ -3253,7 +3253,7 @@ async def test_worker_reconnect_task_memory(c, s, a): await res assert ("no-worker", "memory") in { - (start, finish) for (_, start, finish, _, _) in s.transition_log + (start, finish) for (_, start, finish, _, _, _) in s.transition_log } @@ -3277,7 +3277,7 @@ async def test_worker_reconnect_task_memory_with_resources(c, s, a): await res assert ("no-worker", "memory") in { - (start, finish) for (_, start, finish, _, _) in s.transition_log + (start, finish) for (_, start, finish, _, _, _) in s.transition_log } From 5d10407f113416e86bd5fcf389d819e9c79a71ef Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 1 Apr 2022 13:35:03 +0200 Subject: [PATCH 02/29] Working test_worker.py and test_client.py --- distributed/scheduler.py | 18 ++++++++++++++++-- distributed/tests/test_client.py | 2 +- distributed/tests/test_worker.py | 2 +- distributed/utils_test.py | 20 +++++++++++++++++++- 4 files changed, 
37 insertions(+), 5 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 6a6d7a5258c..2ea0f8a8913 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2375,6 +2375,8 @@ def _transition(self, key, finish: str, *args, **kwargs): except LookupError: if self._validate: raise + else: + stimulus_id = "" finish2 = ts._state # FIXME downcast antipattern @@ -5067,8 +5069,13 @@ def stimulus_task_erred( **kwargs, ) - def stimulus_retry(self, keys, client=None): + def stimulus_retry(self, keys, client=None, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"stimulus-retry-{time()}") + logger.info("Client %s requests to retry %d keys", client, len(keys)) if client: self.log_event(client, {"action": "retry", "count": len(keys)}) @@ -5219,8 +5226,15 @@ def remove_worker_from_events(): return "OK" - def stimulus_cancel(self, comm, keys=None, client=None, force=False): + def stimulus_cancel( + self, comm, keys=None, client=None, force=False, stimulus_id=None + ): """Stop execution on a list of keys""" + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"add-worker-{time()}") + logger.info("Client %s requests to cancel %d keys", client, len(keys)) if client: self.log_event( diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index c8b594c8467..ed4c334027a 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -4576,7 +4576,7 @@ def test_auto_normalize_collection_sync(c): def assert_no_data_loss(scheduler): - for key, start, finish, recommendations, _ in scheduler.transition_log: + for key, start, finish, recommendations, _, _ in scheduler.transition_log: if start == "memory" and finish == "released": for k, v in recommendations.items(): assert not (k == key and v == "waiting") diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index d3a0ce8b28f..eabf8d46147 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2605,7 +2605,7 @@ def sink(a, b, *args): @gen_cluster(client=True) -async def test_gather_dep_exception_one_task_2(c, s, a, b): +async def test_gather_dep_exception_one_task_2(c, s, a, b, set_stimulus): """Ensure an exception in a single task does not tear down an entire batch of gather_dep The below triggers an fetch->memory transition diff --git a/distributed/utils_test.py b/distributed/utils_test.py index a8e2652d120..608f8203530 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -50,7 +50,14 @@ from distributed.comm.tcp import TCP, BaseTCPConnector from distributed.compatibility import WINDOWS from distributed.config import initialize_logging -from distributed.core import CommClosedError, ConnectionPool, Status, connect, rpc +from distributed.core import ( + SERVER_STIMULUS_ID, + CommClosedError, + ConnectionPool, + Status, + connect, + rpc, +) from distributed.deploy import SpecCluster from distributed.diagnostics.plugin import WorkerPlugin from distributed.metrics import time @@ -117,6 +124,17 @@ def invalid_python_script(tmpdir_factory): return local_file +@pytest.fixture +def set_stimulus(request): + stimulus_id = f"{request.function.__name__.replace('_', '-')}-{time()}" + + try: + token = SERVER_STIMULUS_ID.set(stimulus_id) + yield + finally: + SERVER_STIMULUS_ID.reset(token) + + async def cleanup_global_workers(): for worker in 
Worker._instances: await worker.close(report=False, executor_wait=False) From e30be1bb3c04ebce0c72858f3a30465afb0fe47f Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 1 Apr 2022 13:35:30 +0200 Subject: [PATCH 03/29] Support transition_log in http output --- distributed/http/templates/task.html | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/distributed/http/templates/task.html b/distributed/http/templates/task.html index 0b5c10695e0..075b25e2264 100644 --- a/distributed/http/templates/task.html +++ b/distributed/http/templates/task.html @@ -118,16 +118,18 @@

Transition Log

Key
Start
Finish
+ Stimulus ID
Recommended Key
Recommended Action


- {% for key, start, finish, recommendations, transition_time in scheduler.story(Task) %}
+ {% for key, start, finish, recommendations, stimulus_id, transition_time in scheduler.story(Task) %}

{{ fromtimestamp(transition_time) }}
{{key}}
{{ start }}
{{ finish }}
+ {{ stimulus_id }}



@@ -137,6 +139,7 @@

Transition Log

+ {{key2}} {{ rec }} From 308568e1862ea61280bbbd16f3b930446217bb1d Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 1 Apr 2022 13:50:13 +0200 Subject: [PATCH 04/29] Rename assert_worker_story assert_story --- distributed/tests/test_cluster_dump.py | 4 +-- distributed/tests/test_utils_test.py | 48 +++++++++++++------------- distributed/tests/test_worker.py | 18 +++++----- distributed/utils_test.py | 4 +-- 4 files changed, 37 insertions(+), 37 deletions(-) diff --git a/distributed/tests/test_cluster_dump.py b/distributed/tests/test_cluster_dump.py index 963d7563725..3050c46ebd1 100644 --- a/distributed/tests/test_cluster_dump.py +++ b/distributed/tests/test_cluster_dump.py @@ -8,7 +8,7 @@ import distributed from distributed.cluster_dump import DumpArtefact, _tuple_to_list, write_state -from distributed.utils_test import assert_worker_story, gen_cluster, gen_test, inc +from distributed.utils_test import assert_story, gen_cluster, gen_test, inc @pytest.mark.parametrize( @@ -144,7 +144,7 @@ def _expected_story(task_key): assert len(story) == len(fut_keys) for k, task_story in story.items(): - assert_worker_story( + assert_story( task_story, [ (k, "compute-task"), diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index dc481aae17f..c0d377acc01 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -21,7 +21,7 @@ from distributed.utils_test import ( _LockedCommPool, _UnhashableCallable, - assert_worker_story, + assert_story, check_process_leak, cluster, dump_cluster_state, @@ -406,7 +406,7 @@ async def inner_test(c, s, a, b): assert "workers" in state -def test_assert_worker_story(): +def test_assert_story(): now = time() story = [ ("foo", "id1", now - 600), @@ -414,38 +414,38 @@ def test_assert_worker_story(): ("baz", {1: 2}, "id2", now), ] # strict=False - assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})]) - assert_worker_story(story, []) - assert_worker_story(story, [("foo",)]) - assert_worker_story(story, [("foo",), ("bar",)]) - assert_worker_story(story, [("baz", lambda d: d[1] == 2)]) + assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})]) + assert_story(story, []) + assert_story(story, [("foo",)]) + assert_story(story, [("foo",), ("bar",)]) + assert_story(story, [("baz", lambda d: d[1] == 2)]) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo", "nomatch")]) + assert_story(story, [("foo", "nomatch")]) with pytest.raises(AssertionError): - assert_worker_story(story, [("baz",)]) + assert_story(story, [("baz",)]) with pytest.raises(AssertionError): - assert_worker_story(story, [("baz", {1: 3})]) + assert_story(story, [("baz", {1: 3})]) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)]) + assert_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)]) with pytest.raises(AssertionError): - assert_worker_story(story, [("baz", lambda d: d[1] == 3)]) + assert_story(story, [("baz", lambda d: d[1] == 3)]) with pytest.raises(KeyError): # Faulty lambda - assert_worker_story(story, [("baz", lambda d: d[2] == 1)]) - assert_worker_story([], []) - assert_worker_story([("foo", "id1", now)], [("foo",)]) + assert_story(story, [("baz", lambda d: d[2] == 1)]) + assert_story([], []) + assert_story([("foo", "id1", now)], [("foo",)]) with pytest.raises(AssertionError): - assert_worker_story([], [("foo",)]) + assert_story([], [("foo",)]) # strict=True - assert_worker_story([], [], strict=True) - 
assert_worker_story([("foo", "id1", now)], [("foo",)]) - assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True) + assert_story([], [], strict=True) + assert_story([("foo", "id1", now)], [("foo",)]) + assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo",), ("bar",)], strict=True) + assert_story(story, [("foo",), ("bar",)], strict=True) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo",), ("baz", {1: 2})], strict=True) + assert_story(story, [("foo",), ("baz", {1: 2})], strict=True) with pytest.raises(AssertionError): - assert_worker_story(story, [], strict=True) + assert_story(story, [], strict=True) @pytest.mark.parametrize( @@ -466,11 +466,11 @@ def test_assert_worker_story(): ), ], ) -def test_assert_worker_story_malformed_story(story_factory): +def test_assert_story_malformed_story(story_factory): # defer the calls to time() to when the test runs rather than collection story = story_factory() with pytest.raises(AssertionError, match="Malformed story event"): - assert_worker_story(story, []) + assert_story(story, []) @gen_cluster() diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index eabf8d46147..18a303f2f0f 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -45,7 +45,7 @@ from distributed.utils_test import ( TaskStateMetadataPlugin, _LockedCommPool, - assert_worker_story, + assert_story, captured_logger, dec, div, @@ -1403,7 +1403,7 @@ def assert_amm_transfer_story(key: str, w_from: Worker, w_to: Worker) -> None: """Test that an in-memory key was transferred from worker w_from to worker w_to by the Active Memory Manager and it was not recalculated on w_to """ - assert_worker_story( + assert_story( w_to.story(key), [ (key, "ensure-task-exists", "released"), @@ -1698,7 +1698,7 @@ async def test_story_with_deps(c, s, a, b): (key, "put-in-memory"), (key, "executing", "memory", "memory", {}), ] - assert_worker_story(story, expected, strict=True) + assert_story(story, expected, strict=True) story = b.story(dep.key) stimulus_ids = {ev[-2] for ev in story} @@ -1713,7 +1713,7 @@ async def test_story_with_deps(c, s, a, b): (dep.key, "put-in-memory"), (dep.key, "flight", "memory", "memory", {res.key: "ready"}), ] - assert_worker_story(story, expected, strict=True) + assert_story(story, expected, strict=True) @gen_cluster(client=True) @@ -2667,7 +2667,7 @@ async def test_acquire_replicas_same_channel(c, s, a, b): # same communication channel for fut in (futA, futB): - assert_worker_story( + assert_story( b.story(fut.key), [ ("gather-dependencies", a.address, {fut.key}), @@ -2726,7 +2726,7 @@ def __getstate__(self): assert await y == 123 story = await c.run(lambda dask_worker: dask_worker.story("x")) - assert_worker_story( + assert_story( story[b], [ ("x", "ensure-task-exists", "released"), @@ -2903,7 +2903,7 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers): await f2 - assert_worker_story(a.story(f1.key), [(f1.key, "missing-dep")]) + assert_story(a.story(f1.key), [(f1.key, "missing-dep")]) assert a.tasks[f1.key].suspicious_count == 0 assert s.tasks[f1.key].suspicious == 0 @@ -2983,7 +2983,7 @@ async def test_missing_released_zombie_tasks_2(c, s, a, b): while b.tasks: await asyncio.sleep(0.01) - assert_worker_story( + assert_story( b.story(ts), [("f1", "missing", "released", "released", {"f1": "forgotten"})], ) @@ -3095,7 +3095,7 @@ async def 
test_task_flight_compute_oserror(c, s, a, b): (f1.key, "put-in-memory"), (f1.key, "executing", "memory", "memory", {}), ] - assert_worker_story(sum_story, expected_sum_story, strict=True) + assert_story(sum_story, expected_sum_story, strict=True) @gen_cluster(client=True) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 608f8203530..f84cbac8556 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1919,7 +1919,7 @@ def xfail_ssl_issue5601(): raise -def assert_worker_story( +def assert_story( story: list[tuple], expect: list[tuple], *, strict: bool = False ) -> None: """Test the output of ``Worker.story`` @@ -1985,7 +1985,7 @@ def assert_worker_story( break except StopIteration: raise AssertionError( - f"assert_worker_story({strict=}) failed\n" + f"assert_story({strict=}) failed\n" f"story:\n{_format_story(story)}\n" f"expect:\n{_format_story(expect)}" ) from None From b017a129915e471b72d02753653b49aae417f3a1 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 1 Apr 2022 14:19:10 +0200 Subject: [PATCH 05/29] If possible, defer to STIMULUS_ID when sending messages --- distributed/scheduler.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 2ea0f8a8913..b98e11a4675 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2376,7 +2376,7 @@ def _transition(self, key, finish: str, *args, **kwargs): if self._validate: raise else: - stimulus_id = "" + stimulus_id = "" finish2 = ts._state # FIXME downcast antipattern @@ -2386,11 +2386,12 @@ def _transition(self, key, finish: str, *args, **kwargs): ) if parent._validate: logger.debug( - "Transitioned %r %s->%s (actual: %s). Consequence: %s", + "Transitioned %r %s->%s (actual: %s) from %s. 
Consequence: %s", key, start, finish2, ts._state, + stimulus_id, dict(recommendations), ) if self.plugins: @@ -2865,7 +2866,9 @@ def transition_processing_memory( { "op": "cancel-compute", "key": key, - "stimulus_id": f"processing-memory-{time()}", + "stimulus_id": SERVER_STIMULUS_ID.get( + f"processing-memory-{time()}" + ), } ] @@ -2953,7 +2956,7 @@ def transition_memory_released(self, key, safe: bint = False): worker_msg = { "op": "free-keys", "keys": [key], - "stimulus_id": f"memory-released-{time()}", + "stimulus_id": SERVER_STIMULUS_ID.get(f"memory-released-{time()}"), } for ws in ts._who_has: worker_msgs[ws._address] = [worker_msg] @@ -3056,7 +3059,7 @@ def transition_erred_released(self, key): w_msg = { "op": "free-keys", "keys": [key], - "stimulus_id": f"erred-released-{time()}", + "stimulus_id": SERVER_STIMULUS_ID.get(f"erred-released-{time()}"), } for ws_addr in ts._erred_on: worker_msgs[ws_addr] = [w_msg] @@ -3135,7 +3138,9 @@ def transition_processing_released(self, key): { "op": "free-keys", "keys": [key], - "stimulus_id": f"processing-released-{time()}", + "stimulus_id": SERVER_STIMULUS_ID.get( + f"processing-released-{time()}" + ), } ] @@ -4593,7 +4598,9 @@ async def add_worker( { "op": "remove-replicas", "keys": already_released_keys, - "stimulus_id": f"reconnect-already-released-{time()}", + "stimulus_id": SERVER_STIMULUS_ID.get( + f"reconnect-already-released-{time()}" + ), } ) @@ -5030,7 +5037,9 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): { "op": "free-keys", "keys": [key], - "stimulus_id": f"already-released-or-forgotten-{time()}", + "stimulus_id": SERVER_STIMULUS_ID.get( + f"already-released-or-forgotten-{time()}" + ), } ] elif ts._state == "memory": @@ -7189,14 +7198,14 @@ def add_keys(self, worker=None, keys=(), stimulus_id=None): redundant_replicas.append(key) if redundant_replicas: - if not stimulus_id: - stimulus_id = f"redundant-replicas-{time()}" self.worker_send( worker, { "op": "remove-replicas", "keys": redundant_replicas, - "stimulus_id": stimulus_id, + "stimulus_id": SERVER_STIMULUS_ID.get( + stimulus_id or f"redundant-replicas-{time()}" + ), }, ) @@ -8431,7 +8440,9 @@ def _propagate_forgotten( { "op": "free-keys", "keys": [key], - "stimulus_id": f"propagate-forgotten-{time()}", + "stimulus_id": SERVER_STIMULUS_ID.get( + f"propagate-forgotten-{time()}" + ), } ] state.remove_all_replicas(ts) @@ -8474,7 +8485,7 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> "key": ts._key, "priority": ts._priority, "duration": duration, - "stimulus_id": f"compute-task-{time()}", + "stimulus_id": SERVER_STIMULUS_ID.get(f"compute-task-{time()}"), "who_has": {}, } if ts._resource_restrictions: From dd6efcb7b5d545341e21b0c0b54521acef0ac063 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 1 Apr 2022 15:24:27 +0200 Subject: [PATCH 06/29] Support passing stimulus_id in Scheduler handlers --- distributed/scheduler.py | 109 ++++++++++++++++++++++++++++++++++----- 1 file changed, 95 insertions(+), 14 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b98e11a4675..9e781eab9f1 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5013,9 +5013,13 @@ def update_graph( # TODO: balance workers - def stimulus_task_finished(self, key=None, worker=None, **kwargs): + def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs): """Mark that a task has finished execution on a particular worker""" parent: SchedulerState = cast(SchedulerState, self) + try: 
+ SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"stimulus-task-finished-{time()}") logger.debug("Stimulus task finished %s, %s", key, worker) recommendations: dict = {} @@ -5054,11 +5058,21 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): return recommendations, client_msgs, worker_msgs def stimulus_task_erred( - self, key=None, worker=None, exception=None, traceback=None, **kwargs + self, + key=None, + worker=None, + exception=None, + traceback=None, + stimulus_id=None, + **kwargs, ): """Mark that a task has erred on a particular worker""" parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task erred %s, %s", key, worker) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"stimulus-task-erred-{time()}") ts: TaskState = parent._tasks.get(key) if ts is None or ts._state != "processing": @@ -5640,7 +5654,12 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1): def handle_uncaught_error(self, **msg): logger.exception(clean_exception(**msg)[1]) - def handle_task_finished(self, key=None, worker=None, **msg): + def handle_task_finished(self, key=None, worker=None, stimulus_id=None, **msg): + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"handle-task-finished-{time()}") + parent: SchedulerState = cast(SchedulerState, self) if worker not in parent._workers_dv: return @@ -5656,8 +5675,12 @@ def handle_task_finished(self, key=None, worker=None, **msg): self.send_all(client_msgs, worker_msgs) - def handle_task_erred(self, key=None, **msg): + def handle_task_erred(self, key=None, stimulus_id=None, **msg): parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"handle-task-erred-{time()}") recommendations: dict client_msgs: dict worker_msgs: dict @@ -5700,8 +5723,13 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): else: self.transitions({key: "forgotten"}) - def release_worker_data(self, key, worker): + def release_worker_data(self, key, worker, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"release-worker-data-{time()}") + ws: WorkerState = parent._workers_dv.get(worker) ts: TaskState = parent._tasks.get(key) if not ws or not ts: @@ -5756,8 +5784,17 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): ws._long_running.add(ts) self.check_idle_saturated(ws) - def handle_worker_status_change(self, status: str, worker: str) -> None: + def handle_worker_status_change( + self, status: str, worker: str, stimulus_id=None + ) -> None: parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set( + stimulus_id or f"handle-worker-status-change-{time()}" + ) + ws: WorkerState = parent._workers_dv.get(worker) # type: ignore if not ws: return @@ -5793,7 +5830,7 @@ def handle_worker_status_change(self, status: str, worker: str) -> None: else: parent._running.discard(ws) - async def handle_worker(self, comm=None, worker=None): + async def handle_worker(self, comm=None, worker=None, stimulus_id=None): """ Listen to responses from a single worker @@ -5803,6 +5840,11 @@ async def handle_worker(self, comm=None, worker=None): -------- Scheduler.handle_client: Equivalent coroutine for 
clients """ + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"handle-worker-{time()}") + comm.name = "Scheduler connection to worker" worker_comm = self.stream_comms[worker] worker_comm.start(comm) @@ -6009,6 +6051,7 @@ async def scatter( client=None, broadcast=False, timeout=2, + stimulus_id=None, ): """Send data out to workers @@ -6017,6 +6060,11 @@ async def scatter( Scheduler.broadcast: """ parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"scatter-{time()}") + ws: WorkerState start = time() @@ -6053,9 +6101,14 @@ async def scatter( ) return keys - async def gather(self, keys, serializers=None): + async def gather(self, keys, serializers=None, stimulus_id=None): """Collect data from workers to the scheduler""" parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"gather-{time()}") + ws: WorkerState keys = list(keys) who_has = {} @@ -6126,9 +6179,14 @@ def clear_task_state(self): for collection in self._task_state_collections: collection.clear() - async def restart(self, client=None, timeout=30): + async def restart(self, client=None, timeout=30, stimulus_id=None): """Restart all workers. Reset local state.""" parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"restart-{time()}") + with log_errors(): n_workers = len(parent._workers_dv) @@ -6398,6 +6456,7 @@ async def rebalance( comm=None, keys: "Iterable[Hashable]" = None, workers: "Iterable[str]" = None, + stimulus_id=None, ) -> dict: """Rebalance keys so that each worker ends up with roughly the same process memory (managed+unmanaged). @@ -6463,8 +6522,14 @@ async def rebalance( allowlist of workers addresses to be considered as senders or recipients. All other workers will be ignored. The mean cluster occupancy will be calculated only using the allowed workers. + stimulus_id: str, optional + Stimulus ID that caused this function call """ parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"rebalance-{time()}") with log_errors(): wss: "Collection[WorkerState]" @@ -6767,6 +6832,7 @@ async def replicate( branching_factor=2, delete=True, lock=True, + stimulus_id=None, ): """Replicate data throughout cluster @@ -6784,12 +6850,19 @@ async def replicate( The larger the branching factor, the more data we copy in a single step, but the more a given worker risks being swamped by data requests. + stimulus_id: str, optional + Stimulus ID that caused this function call. See also -------- Scheduler.rebalance """ parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"replicate-{time()}") + ws: WorkerState wws: WorkerState ts: TaskState @@ -7019,6 +7092,7 @@ async def retire_workers( names: "list | None" = None, close_workers: bool = False, remove: bool = True, + stimulus_id=None, **kwargs, ) -> dict: """Gracefully retire workers from cluster @@ -7039,6 +7113,8 @@ async def retire_workers( remove: bool (defaults to True) Whether or not to remove the worker metadata immediately or else wait for the worker to contact us + stimulus_id: str, optional + Stimulus ID that caused this function call. 
**kwargs: dict Extra options to pass to workers_to_close to determine which workers we should drop @@ -7053,6 +7129,11 @@ async def retire_workers( Scheduler.workers_to_close """ parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"retire-workers-{time()}") + ws: WorkerState ts: TaskState with log_errors(): @@ -7212,11 +7293,7 @@ def add_keys(self, worker=None, keys=(), stimulus_id=None): return "OK" def update_data( - self, - *, - who_has: dict, - nbytes: dict, - client=None, + self, *, who_has: dict, nbytes: dict, client=None, stimulus_id=None ): """ Learn that new data has entered the network from an external source @@ -7226,6 +7303,10 @@ def update_data( Scheduler.mark_key_in_memory """ parent: SchedulerState = cast(SchedulerState, self) + try: + SERVER_STIMULUS_ID.get() + except LookupError: + SERVER_STIMULUS_ID.set(stimulus_id or f"update-data-{time()}") with log_errors(): who_has = { k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items() From 40f87f7a0e423ce69d181b1f5bedacb925b7355e Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 1 Apr 2022 15:40:00 +0200 Subject: [PATCH 07/29] Transmit stimulus_id's from client --- distributed/client.py | 50 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index d406ca6333a..e6b6f070673 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1373,7 +1373,7 @@ def _dec_ref(self, key): del self.refcount[key] self._release_key(key) - def _release_key(self, key): + def _release_key(self, key, stimulus_id=None): """Release key from distributed memory""" logger.debug("Release key %s", key) st = self.futures.pop(key, None) @@ -1381,7 +1381,12 @@ def _release_key(self, key): st.cancel() if self.status != "closed": self._send_to_scheduler( - {"op": "client-releases-keys", "keys": [key], "client": self.id} + { + "op": "client-releases-keys", + "keys": [key], + "client": self.id, + "stimulus_id": stimulus_id, + } ) async def _handle_report(self): @@ -1527,8 +1532,10 @@ async def _close(self, fast=False): ): await self.scheduler_comm.close() + stimulus_id = f"client-close-{time()}" + for key in list(self.futures): - self._release_key(key=key) + self._release_key(key=key, stimulus_id=stimulus_id) if self._start_arg is None: with suppress(AttributeError): @@ -2111,7 +2118,11 @@ async def _gather_remote(self, direct, local_worker): response["data"].update(data2) else: # ask scheduler to gather data for us - response = await retry_operation(self.scheduler.gather, keys=keys) + response = await retry_operation( + self.scheduler.gather, + keys=keys, + stimulus_id=f"client-gather-remote-{time()}", + ) return response @@ -2197,6 +2208,8 @@ async def _scatter( d = await self._scatter(keymap(stringify, data), workers, broadcast) return {k: d[stringify(k)] for k in data} + stimulus_id = f"client-scatter-{time()}" + if isinstance(data, type(range(0))): data = list(data) input_type = type(data) @@ -2238,6 +2251,7 @@ async def _scatter( who_has={key: [local_worker.address] for key in data}, nbytes=valmap(sizeof, data), client=self.id, + stimulus_id=stimulus_id, ) else: @@ -2260,7 +2274,10 @@ async def _scatter( ) await self.scheduler.update_data( - who_has=who_has, nbytes=nbytes, client=self.id + who_has=who_has, + nbytes=nbytes, + client=self.id, + stimulus_id=stimulus_id, ) else: await self.scheduler.scatter( @@ -2269,6 +2286,7 @@ async def 
_scatter( client=self.id, broadcast=broadcast, timeout=timeout, + stimulus_id=stimulus_id, ) out = {k: Future(k, self, inform=False) for k in data} @@ -2392,7 +2410,12 @@ def scatter( async def _cancel(self, futures, force=False): keys = list({stringify(f.key) for f in futures_of(futures)}) - await self.scheduler.cancel(keys=keys, client=self.id, force=force) + await self.scheduler.cancel( + keys=keys, + client=self.id, + force=force, + stimulus_id=f"client-cancel-{time()}", + ) for k in keys: st = self.futures.pop(k, None) if st is not None: @@ -2419,7 +2442,9 @@ def cancel(self, futures, asynchronous=None, force=False): async def _retry(self, futures): keys = list({stringify(f.key) for f in futures_of(futures)}) - response = await self.scheduler.retry(keys=keys, client=self.id) + response = await self.scheduler.retry( + keys=keys, client=self.id, stimulus_id=f"client-retry-{time()}" + ) for key in response: st = self.futures[key] st.retry() @@ -3421,7 +3446,9 @@ async def _rebalance(self, futures=None, workers=None): keys = list({stringify(f.key) for f in self.futures_of(futures)}) else: keys = None - result = await self.scheduler.rebalance(keys=keys, workers=workers) + result = await self.scheduler.rebalance( + keys=keys, workers=workers, stimulus_id=f"client-rebalance-{time()}" + ) if result["status"] == "partial-fail": raise KeyError(f"Could not rebalance keys: {result['keys']}") assert result["status"] == "OK", result @@ -3456,7 +3483,11 @@ async def _replicate(self, futures, n=None, workers=None, branching_factor=2): await _wait(futures) keys = {stringify(f.key) for f in futures} await self.scheduler.replicate( - keys=list(keys), n=n, workers=workers, branching_factor=branching_factor + keys=list(keys), + n=n, + workers=workers, + branching_factor=branching_factor, + stimulus_id=f"client-replicate-{time()}", ) def replicate(self, futures, n=None, workers=None, branching_factor=2, **kwargs): @@ -4174,6 +4205,7 @@ def retire_workers( self.scheduler.retire_workers, workers=workers, close_workers=close_workers, + stimulus_id=f"client-retire-workers-{time()}", **kwargs, ) From 52697de676660bed99201b438155e4b8d016b450 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 1 Apr 2022 16:12:47 +0200 Subject: [PATCH 08/29] Generate new stimulus_id on completion/failure of Worker.execute --- distributed/tests/test_worker.py | 10 +++++++++- distributed/worker.py | 5 +++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 18a303f2f0f..999ccf6b238 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -4,6 +4,7 @@ import importlib import logging import os +import re import sys import threading import traceback @@ -1688,7 +1689,14 @@ async def test_story_with_deps(c, s, a, b): # Story now includes randomized stimulus_ids and timestamps. 
stimulus_ids = {ev[-2] for ev in story} - assert len(stimulus_ids) == 3, stimulus_ids + # task-finished (new_stimulus_id) from Worker.execute is added + assert {sid[: re.search(r"\d", sid).start()] for sid in stimulus_ids} == { + "ensure-computing-", + "add-worker-", + "task-finished-", + "ensure-communicating-", + } + # This is a simple transition log expected = [ (key, "compute-task"), diff --git a/distributed/worker.py b/distributed/worker.py index c607bb5afa1..b706a10b135 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3490,6 +3490,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> None: assert ts, self.story(key) ts.done = True result["key"] = ts.key + result["stimulus-id"] = new_stimulus_id = f"{result['op']}-{time()}" value = result.pop("result", None) ts.startstops.append( {"action": "compute", "start": result["start"], "stop": result["stop"]} @@ -3526,7 +3527,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> None: result["traceback_text"], ) - self.transitions(recommendations, stimulus_id=stimulus_id) + self.transitions(recommendations, stimulus_id=new_stimulus_id) logger.debug("Send compute response to scheduler: %s, %s", ts.key, result) @@ -3545,7 +3546,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> None: ts, "error", **emsg, - stimulus_id=stimulus_id, + stimulus_id=f"task-erred-{time()}", ) finally: self.ensure_computing() From 17f069aa478b7046655e0ce68cc64ab029c9d3fd Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 4 Apr 2022 09:55:42 +0200 Subject: [PATCH 09/29] Use decorator to manage stimulus injection --- distributed/core.py | 4 - distributed/scheduler.py | 353 ++++++++---------- distributed/tests/test_scheduler.py | 38 ++ distributed/tests/test_worker.py | 6 +- .../tests/test_worker_state_machine.py | 8 +- distributed/utils_comm.py | 78 +++- distributed/utils_test.py | 20 +- distributed/worker.py | 39 +- distributed/worker_state_machine.py | 11 +- 9 files changed, 316 insertions(+), 241 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index e89eea16559..da48132e9e0 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -12,7 +12,6 @@ from collections import defaultdict from collections.abc import Container from contextlib import suppress -from contextvars import ContextVar from enum import Enum from functools import partial from typing import Callable, ClassVar @@ -115,9 +114,6 @@ def _expects_comm(func: Callable) -> bool: return False -SERVER_STIMULUS_ID: ContextVar = ContextVar("stimulus_id") - - class Server: """Dask Distributed Server diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9e781eab9f1..2610e81947a 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -25,6 +25,7 @@ Set, ) from contextlib import suppress +from contextvars import ContextVar from datetime import timedelta from functools import partial from numbers import Number @@ -65,7 +66,7 @@ unparse_host_port, ) from distributed.comm.addressing import addresses_from_user_args -from distributed.core import SERVER_STIMULUS_ID, Status, clean_exception, rpc, send_recv +from distributed.core import Status, clean_exception, rpc, send_recv from distributed.diagnostics.memory_sampler import MemorySamplerExtension from distributed.diagnostics.plugin import SchedulerPlugin, _get_plugin_name from distributed.event import EventExtension @@ -100,6 +101,7 @@ gather_from_workers, retry_operation, scatter_to_workers, + stimulus_handler, ) from distributed.utils_perf import 
disable_gc_diagnosis, enable_gc_diagnosis from distributed.variable import VariableExtension @@ -2370,17 +2372,18 @@ def _transition(self, key, finish: str, *args, **kwargs): else: raise RuntimeError("Impossible transition from %r to %r" % start_finish) + # FIXME downcast antipattern + scheduler = pep484_cast(Scheduler, self) + try: - stimulus_id = SERVER_STIMULUS_ID.get() + stimulus_id = scheduler.STIMULUS_ID.get() except LookupError: if self._validate: - raise + raise LookupError(scheduler.STIMULUS_ID.name) else: stimulus_id = "" finish2 = ts._state - # FIXME downcast antipattern - scheduler = pep484_cast(Scheduler, self) scheduler.transition_log.append( (key, start, finish2, recommendations, stimulus_id, time()) ) @@ -2769,7 +2772,9 @@ def transition_waiting_processing(self, key): # logger.debug("Send job to worker: %s, %s", worker, key) - worker_msgs[worker] = [_task_to_msg(self, ts)] + worker_msgs[worker] = [ + _task_to_msg(self, ts, self.STIMULUS_ID.get(f"compute-task-{time()}")) + ] return recommendations, client_msgs, worker_msgs except Exception as e: @@ -2862,11 +2867,15 @@ def transition_processing_memory( ws, key, ) + + # FIXME downcast antipattern + scheduler = pep484_cast(Scheduler, self) + worker_msgs[ts._processing_on.address] = [ { "op": "cancel-compute", "key": key, - "stimulus_id": SERVER_STIMULUS_ID.get( + "stimulus_id": scheduler.STIMULUS_ID.get( f"processing-memory-{time()}" ), } @@ -2952,11 +2961,15 @@ def transition_memory_released(self, key, safe: bint = False): elif dts._state == "waiting": dts._waiting_on.add(ts) + # FIXME downcast antipattern + scheduler = pep484_cast(Scheduler, self) + # XXX factor this out? + worker_msg = { "op": "free-keys", "keys": [key], - "stimulus_id": SERVER_STIMULUS_ID.get(f"memory-released-{time()}"), + "stimulus_id": scheduler.STIMULUS_ID.get(f"memory-released-{time()}"), } for ws in ts._who_has: worker_msgs[ws._address] = [worker_msg] @@ -3059,7 +3072,7 @@ def transition_erred_released(self, key): w_msg = { "op": "free-keys", "keys": [key], - "stimulus_id": SERVER_STIMULUS_ID.get(f"erred-released-{time()}"), + "stimulus_id": self.STIMULUS_ID.get(f"erred-released-{time()}"), } for ws_addr in ts._erred_on: worker_msgs[ws_addr] = [w_msg] @@ -3138,7 +3151,7 @@ def transition_processing_released(self, key): { "op": "free-keys", "keys": [key], - "stimulus_id": SERVER_STIMULUS_ID.get( + "stimulus_id": self.STIMULUS_ID.get( f"processing-released-{time()}" ), } @@ -3331,7 +3344,13 @@ def transition_memory_forgotten(self, key): for ws in ts._who_has: ws._actors.discard(ts) - _propagate_forgotten(self, ts, recommendations, worker_msgs) + _propagate_forgotten( + self, + ts, + recommendations, + worker_msgs, + self.STIMULUS_ID.get(f"propagate-forgotten-{time()}"), + ) client_msgs = _task_to_client_msgs(self, ts) self.remove_key(key) @@ -3369,7 +3388,13 @@ def transition_released_forgotten(self, key): else: assert 0, (ts,) - _propagate_forgotten(self, ts, recommendations, worker_msgs) + _propagate_forgotten( + self, + ts, + recommendations, + worker_msgs, + self.STIMULUS_ID.get(f"propagate-forgotten-{time()}"), + ) client_msgs = _task_to_client_msgs(self, ts) self.remove_key(key) @@ -3997,6 +4022,8 @@ def __init__( "benchmark_hardware": self.benchmark_hardware, } + self.STIMULUS_ID = ContextVar(f"stimulus_id-{uuid.uuid4().hex}") + connection_limit = get_fileno_limit() / 2 super().__init__( @@ -4472,14 +4499,9 @@ async def add_worker( versions=None, nanny=None, extra=None, - stimulus_id=None, ): """Add a new worker to the cluster""" parent: 
SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"add-worker-{time()}") with log_errors(): address = self.coerce_address(address, resolve_address) address = normalize_address(address) @@ -4568,50 +4590,53 @@ async def add_worker( recommendations: dict = {} client_msgs: dict = {} worker_msgs: dict = {} - if nbytes: - assert isinstance(nbytes, dict) - already_released_keys = [] - for key in nbytes: - ts: TaskState = parent._tasks.get(key) # type: ignore - if ts is not None and ts.state != "released": - if ts.state == "memory": - self.add_keys(worker=address, keys=[key]) + token = self.STIMULUS_ID.set(f"add-worker-{time()}") + try: + if nbytes: + assert isinstance(nbytes, dict) + already_released_keys = [] + for key in nbytes: + ts: TaskState = parent._tasks.get(key) # type: ignore + if ts is not None and ts.state != "released": + if ts.state == "memory": + self.add_keys(worker=address, keys=[key]) + else: + t: tuple = parent._transition( + key, + "memory", + worker=address, + nbytes=nbytes[key], + typename=types[key], + ) + recommendations, client_msgs, worker_msgs = t + parent._transitions( + recommendations, client_msgs, worker_msgs + ) + recommendations = {} else: - t: tuple = parent._transition( - key, - "memory", - worker=address, - nbytes=nbytes[key], - typename=types[key], - ) - recommendations, client_msgs, worker_msgs = t - parent._transitions( - recommendations, client_msgs, worker_msgs + already_released_keys.append(key) + if already_released_keys: + if address not in worker_msgs: + worker_msgs[address] = [] + worker_msgs[address].append( + { + "op": "remove-replicas", + "keys": already_released_keys, + "stimulus_id": f"reconnect-already-released-{time()}", + } ) - recommendations = {} - else: - already_released_keys.append(key) - if already_released_keys: - if address not in worker_msgs: - worker_msgs[address] = [] - worker_msgs[address].append( - { - "op": "remove-replicas", - "keys": already_released_keys, - "stimulus_id": SERVER_STIMULUS_ID.get( - f"reconnect-already-released-{time()}" - ), - } - ) - if ws._status == Status.running: - for ts in parent._unrunnable: - valid: set = self.valid_workers(ts) - if valid is None or ws in valid: - recommendations[ts._key] = "waiting" + if ws._status == Status.running: + for ts in parent._unrunnable: + valid: set = self.valid_workers(ts) + if valid is None or ws in valid: + recommendations[ts._key] = "waiting" - if recommendations: - parent._transitions(recommendations, client_msgs, worker_msgs) + if recommendations: + parent._transitions(recommendations, client_msgs, worker_msgs) + + finally: + self.STIMULUS_ID.reset(token) self.send_all(client_msgs, worker_msgs) @@ -4652,6 +4677,7 @@ async def add_nanny(self, comm): } return msg + @stimulus_handler def update_graph_hlg( self, client=None, @@ -4670,11 +4696,6 @@ def update_graph_hlg( code=None, stimulus_id=None, ): - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"update-graph-hlg-{time()}") - unpacked_graph = HighLevelGraph.__dask_distributed_unpack__(hlg) dsk = unpacked_graph["dsk"] dependencies = unpacked_graph["deps"] @@ -4713,9 +4734,10 @@ def update_graph_hlg( fifo_timeout, annotations, code=code, - stimulus_id=SERVER_STIMULUS_ID.get(), + stimulus_id=self.STIMULUS_ID.get(), ) + @stimulus_handler def update_graph( self, client=None, @@ -4741,10 +4763,6 @@ def update_graph( This happens whenever the Client calls submit, map, get, or compute. 
""" parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"update-graph-hlg-{time()}") start = time() fifo_timeout = parse_timedelta(fifo_timeout) keys = set(keys) @@ -5013,13 +5031,10 @@ def update_graph( # TODO: balance workers + @stimulus_handler def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs): """Mark that a task has finished execution on a particular worker""" parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"stimulus-task-finished-{time()}") logger.debug("Stimulus task finished %s, %s", key, worker) recommendations: dict = {} @@ -5041,7 +5056,7 @@ def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwar { "op": "free-keys", "keys": [key], - "stimulus_id": SERVER_STIMULUS_ID.get( + "stimulus_id": self.STIMULUS_ID.get( f"already-released-or-forgotten-{time()}" ), } @@ -5057,6 +5072,7 @@ def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwar assert ws in ts._who_has return recommendations, client_msgs, worker_msgs + @stimulus_handler def stimulus_task_erred( self, key=None, @@ -5069,10 +5085,6 @@ def stimulus_task_erred( """Mark that a task has erred on a particular worker""" parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task erred %s, %s", key, worker) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"stimulus-task-erred-{time()}") ts: TaskState = parent._tasks.get(key) if ts is None or ts._state != "processing": @@ -5092,13 +5104,9 @@ def stimulus_task_erred( **kwargs, ) + @stimulus_handler def stimulus_retry(self, keys, client=None, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"stimulus-retry-{time()}") - logger.info("Client %s requests to retry %d keys", client, len(keys)) if client: self.log_event(client, {"action": "retry", "count": len(keys)}) @@ -5127,6 +5135,7 @@ def stimulus_retry(self, keys, client=None, stimulus_id=None): return tuple(seen) + @stimulus_handler async def remove_worker(self, address, safe=False, close=True, stimulus_id=None): """ Remove worker from cluster @@ -5136,12 +5145,6 @@ async def remove_worker(self, address, safe=False, close=True, stimulus_id=None) state. 
""" parent: SchedulerState = cast(SchedulerState, self) - - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"remove-worker-{time()}") - with log_errors(): if self.status == Status.closed: return @@ -5249,15 +5252,11 @@ def remove_worker_from_events(): return "OK" + @stimulus_handler def stimulus_cancel( self, comm, keys=None, client=None, force=False, stimulus_id=None ): """Stop execution on a list of keys""" - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"add-worker-{time()}") - logger.info("Client %s requests to cancel %d keys", client, len(keys)) if client: self.log_event( @@ -5266,14 +5265,11 @@ def stimulus_cancel( for key in keys: self.cancel_key(key, client, force=force) + @stimulus_handler def cancel_key(self, key, client, retries=5, force=False, stimulus_id=None): """Cancel a particular key and all dependents""" # TODO: this should be converted to use the transition mechanism parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"cancel-key-{time()}") ts: TaskState = parent._tasks.get(key) dts: TaskState try: @@ -5295,13 +5291,9 @@ def cancel_key(self, key, client, retries=5, force=False, stimulus_id=None): for cs in clients: self.client_releases_keys(keys=[key], client=cs._client_key) + @stimulus_handler def client_desires_keys(self, keys=None, client=None, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"client-desires-keys-{time()}") - cs: ClientState = parent._clients.get(client) if cs is None: # For publish, queues etc. @@ -5318,15 +5310,11 @@ def client_desires_keys(self, keys=None, client=None, stimulus_id=None): if ts._state in ("memory", "erred"): self.report_on_key(ts=ts, client=client) + @stimulus_handler def client_releases_keys(self, keys=None, client=None, stimulus_id=None): """Remove keys from client desired list""" parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"client-releases-keys-{time()}") - if not isinstance(keys, list): keys = list(keys) cs: ClientState = parent._clients[client] @@ -5544,18 +5532,22 @@ def report(self, msg: dict, ts: TaskState = None, client: str = None): "Closed comm %r while trying to write %s", c, msg, exc_info=True ) - async def add_client( - self, comm: Comm, client: str, versions: dict, stimulus_id=None - ) -> None: + async def add_client(self, comm: Comm, client: str, versions: dict) -> None: """Add client to network We listen to all future messages from this Comm. 
""" parent: SchedulerState = cast(SchedulerState, self) try: - SERVER_STIMULUS_ID.get() + stimulus_id = self.STIMULUS_ID.get() except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"add-client-{time()}") + pass + else: + if self._validate: + raise RuntimeError( + f"STIMULUS_ID {stimulus_id} set in Scheduler.add_client" + ) + assert client is not None comm.name = "Scheduler->Client" logger.info("Receive client connection: %s", client) @@ -5599,13 +5591,9 @@ async def add_client( except TypeError: # comm becomes None during GC pass - def remove_client(self, client: str, stimulus_id=None) -> None: + def remove_client(self, client: str) -> None: """Remove client from network""" parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"remove-client-{time()}") if self.status == Status.running: logger.info("Remove client %s", client) self.log_event(["all", client], {"action": "remove-client", "client": client}) @@ -5641,7 +5629,9 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1): """Send a single computational task to a worker""" parent: SchedulerState = cast(SchedulerState, self) try: - msg: dict = _task_to_msg(parent, ts, duration) + msg: dict = _task_to_msg( + parent, ts, self.STIMULUS_ID.get(f"compute-task-{time()}"), duration + ) self.worker_send(worker, msg) except Exception as e: logger.exception(e) @@ -5654,12 +5644,8 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1): def handle_uncaught_error(self, **msg): logger.exception(clean_exception(**msg)[1]) + @stimulus_handler def handle_task_finished(self, key=None, worker=None, stimulus_id=None, **msg): - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"handle-task-finished-{time()}") - parent: SchedulerState = cast(SchedulerState, self) if worker not in parent._workers_dv: return @@ -5675,12 +5661,9 @@ def handle_task_finished(self, key=None, worker=None, stimulus_id=None, **msg): self.send_all(client_msgs, worker_msgs) + @stimulus_handler def handle_task_erred(self, key=None, stimulus_id=None, **msg): parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"handle-task-erred-{time()}") recommendations: dict client_msgs: dict worker_msgs: dict @@ -5690,7 +5673,10 @@ def handle_task_erred(self, key=None, stimulus_id=None, **msg): self.send_all(client_msgs, worker_msgs) - def handle_missing_data(self, key=None, errant_worker=None, **kwargs): + @stimulus_handler + def handle_missing_data( + self, key=None, errant_worker=None, stimulus_id=None, **kwargs + ): """Signal that `errant_worker` does not hold `key` This may either indicate that `errant_worker` is dead or that we may be @@ -5723,13 +5709,9 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): else: self.transitions({key: "forgotten"}) + @stimulus_handler def release_worker_data(self, key, worker, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"release-worker-data-{time()}") - ws: WorkerState = parent._workers_dv.get(worker) ts: TaskState = parent._tasks.get(key) if not ws or not ts: @@ -5742,7 +5724,10 @@ def release_worker_data(self, key, worker, stimulus_id=None): if recommendations: self.transitions(recommendations) - def handle_long_running(self, 
key=None, worker=None, compute_duration=None): + @stimulus_handler + def handle_long_running( + self, key=None, worker=None, compute_duration=None, stimulus_id=None + ): """A task has seceded from the thread pool We stop the task from being stolen in the future, and change task @@ -5784,17 +5769,11 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): ws._long_running.add(ts) self.check_idle_saturated(ws) + @stimulus_handler def handle_worker_status_change( self, status: str, worker: str, stimulus_id=None ) -> None: parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set( - stimulus_id or f"handle-worker-status-change-{time()}" - ) - ws: WorkerState = parent._workers_dv.get(worker) # type: ignore if not ws: return @@ -5830,7 +5809,7 @@ def handle_worker_status_change( else: parent._running.discard(ws) - async def handle_worker(self, comm=None, worker=None, stimulus_id=None): + async def handle_worker(self, comm=None, worker=None): """ Listen to responses from a single worker @@ -5841,9 +5820,14 @@ async def handle_worker(self, comm=None, worker=None, stimulus_id=None): Scheduler.handle_client: Equivalent coroutine for clients """ try: - SERVER_STIMULUS_ID.get() + stimulus_id = self.STIMULUS_ID.get() except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"handle-worker-{time()}") + pass + else: + if self._validate: + raise RuntimeError( + f"STIMULUS_ID {stimulus_id} set in Scheduler.add_worker" + ) comm.name = "Scheduler connection to worker" worker_comm = self.stream_comms[worker] @@ -6043,6 +6027,7 @@ def send_all(self, client_msgs: dict, worker_msgs: dict): # Less common interactions # ############################ + @stimulus_handler async def scatter( self, comm=None, @@ -6060,11 +6045,6 @@ async def scatter( Scheduler.broadcast: """ parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"scatter-{time()}") - ws: WorkerState start = time() @@ -6101,14 +6081,10 @@ async def scatter( ) return keys + @stimulus_handler async def gather(self, keys, serializers=None, stimulus_id=None): """Collect data from workers to the scheduler""" parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"gather-{time()}") - ws: WorkerState keys = list(keys) who_has = {} @@ -6179,14 +6155,10 @@ def clear_task_state(self): for collection in self._task_state_collections: collection.clear() + @stimulus_handler async def restart(self, client=None, timeout=30, stimulus_id=None): """Restart all workers. 
Reset local state.""" parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"restart-{time()}") - with log_errors(): n_workers = len(parent._workers_dv) @@ -6400,6 +6372,7 @@ async def gather_on_worker( return keys_failed + @stimulus_handler async def delete_worker_data( self, worker_address: str, @@ -6416,10 +6389,6 @@ async def delete_worker_data( List of keys to delete on the specified worker """ parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"delete-worker-data-{time()}") try: await retry_operation( @@ -6451,6 +6420,7 @@ async def delete_worker_data( self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys}) + @stimulus_handler async def rebalance( self, comm=None, @@ -6526,11 +6496,6 @@ async def rebalance( Stimulus ID that caused this function call """ parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"rebalance-{time()}") - with log_errors(): wss: "Collection[WorkerState]" if workers is not None: @@ -6823,6 +6788,7 @@ async def _rebalance_move_data( else: return {"status": "OK"} + @stimulus_handler async def replicate( self, comm=None, @@ -6858,11 +6824,6 @@ async def replicate( Scheduler.rebalance """ parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"replicate-{time()}") - ws: WorkerState wws: WorkerState ts: TaskState @@ -7084,6 +7045,7 @@ def _key(group): return result + @stimulus_handler async def retire_workers( self, comm=None, @@ -7129,11 +7091,6 @@ async def retire_workers( Scheduler.workers_to_close """ parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"retire-workers-{time()}") - ws: WorkerState ts: TaskState with log_errors(): @@ -7192,7 +7149,11 @@ async def retire_workers( ws.status = Status.closing_gracefully self.running.discard(ws) self.stream_comms[ws.address].send( - {"op": "worker-status-change", "status": ws.status.name} + { + "op": "worker-status-change", + "status": ws.status.name, + "stimulus_id": self.STIMULUS_ID.get(), + } ) coros.append( @@ -7237,7 +7198,11 @@ async def _track_retire_worker( # conditions and we can wait for a scheduler->worker->scheduler # round-trip. 
self.stream_comms[ws.address].send( - {"op": "worker-status-change", "status": prev_status.name} + { + "op": "worker-status-change", + "status": prev_status.name, + "stimulus_id": self.STIMULUS_ID.get(), + } ) return None, {} @@ -7258,6 +7223,7 @@ async def _track_retire_worker( logger.info("Retired worker %s", ws._address) return ws._address, ws.identity() + @stimulus_handler def add_keys(self, worker=None, keys=(), stimulus_id=None): """ Learn that a worker has certain keys @@ -7284,7 +7250,7 @@ def add_keys(self, worker=None, keys=(), stimulus_id=None): { "op": "remove-replicas", "keys": redundant_replicas, - "stimulus_id": SERVER_STIMULUS_ID.get( + "stimulus_id": self.STIMULUS_ID.get( stimulus_id or f"redundant-replicas-{time()}" ), }, @@ -7292,6 +7258,7 @@ def add_keys(self, worker=None, keys=(), stimulus_id=None): return "OK" + @stimulus_handler def update_data( self, *, who_has: dict, nbytes: dict, client=None, stimulus_id=None ): @@ -7303,10 +7270,6 @@ def update_data( Scheduler.mark_key_in_memory """ parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"update-data-{time()}") with log_errors(): who_has = { k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items() @@ -7773,6 +7736,7 @@ def story(self, *keys): transition_story = story + @stimulus_handler def reschedule(self, key=None, worker=None, stimulus_id=None): """Reschedule a task @@ -7780,11 +7744,6 @@ def reschedule(self, key=None, worker=None, stimulus_id=None): elsewhere """ parent: SchedulerState = cast(SchedulerState, self) - try: - SERVER_STIMULUS_ID.get() - except LookupError: - SERVER_STIMULUS_ID.set(stimulus_id or f"reschedule-{time()}") - ts: TaskState try: ts = parent._tasks[key] @@ -8488,7 +8447,11 @@ def _add_to_memory( @cfunc @exceptval(check=False) def _propagate_forgotten( - state: SchedulerState, ts: TaskState, recommendations: dict, worker_msgs: dict + state: SchedulerState, + ts: TaskState, + recommendations: dict, + worker_msgs: dict, + stimulus_id: str, ): ts.state = "forgotten" key: str = ts._key @@ -8521,9 +8484,7 @@ def _propagate_forgotten( { "op": "free-keys", "keys": [key], - "stimulus_id": SERVER_STIMULUS_ID.get( - f"propagate-forgotten-{time()}" - ), + "stimulus_id": stimulus_id, } ] state.remove_all_replicas(ts) @@ -8552,7 +8513,9 @@ def _client_releases_keys( @cfunc @exceptval(check=False) -def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> dict: +def _task_to_msg( + state: SchedulerState, ts: TaskState, stimulus_id: str, duration: double = -1 +) -> dict: """Convert a single computational task to a message""" ws: WorkerState dts: TaskState @@ -8566,7 +8529,7 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> "key": ts._key, "priority": ts._priority, "duration": duration, - "stimulus_id": SERVER_STIMULUS_ID.get(f"compute-task-{time()}"), + "stimulus_id": stimulus_id, "who_has": {}, } if ts._resource_restrictions: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index e2c055d1655..e815b106fc7 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -37,6 +37,7 @@ from distributed.utils import TimeoutError from distributed.utils_test import ( BrokenComm, + assert_story, captured_logger, cluster, dec, @@ -3538,3 +3539,40 @@ async def test_repr(s, a): repr(ws_b) == f"" ) + + +@gen_cluster(client=True) +async def test_stimuli(c, s, a, b): + f = c.submit(inc, 1) 
+ key = f.key + + await f + await c.close() + + assert_story( + s.story(key), + [ + (key, "released", "waiting", {key: "processing"}), + (key, "waiting", "processing", {}), + (key, "processing", "memory", {}), + ( + key, + "memory", + "forgotten", + {}, + ), + ], + ) + + stimuli = [ + "update-graph-hlg", + "update-graph-hlg", + "task-finished", + "client-releases-keys", + ] + + stories = s.story(key) + assert len(stories) == len(stimuli) + + for stimulus_id, story in zip(stimuli, stories): + assert story[-2].startswith(stimulus_id), (story[-2], stimulus_id) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 999ccf6b238..b5a34fba45b 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1689,12 +1689,10 @@ async def test_story_with_deps(c, s, a, b): # Story now includes randomized stimulus_ids and timestamps. stimulus_ids = {ev[-2] for ev in story} - # task-finished (new_stimulus_id) from Worker.execute is added assert {sid[: re.search(r"\d", sid).start()] for sid in stimulus_ids} == { "ensure-computing-", - "add-worker-", - "task-finished-", "ensure-communicating-", + "task-finished-", } # This is a simple transition log @@ -2613,7 +2611,7 @@ def sink(a, b, *args): @gen_cluster(client=True) -async def test_gather_dep_exception_one_task_2(c, s, a, b, set_stimulus): +async def test_gather_dep_exception_one_task_2(c, s, a, b): """Ensure an exception in a single task does not tear down an entire batch of gather_dep The below triggers an fetch->memory transition diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 2c97b217aab..bfc03448ae3 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -90,5 +90,9 @@ def test_sendmsg_slots(cls): def test_sendmsg_to_dict(): # Arbitrary sample class - smsg = ReleaseWorkerDataMsg(key="x") - assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"} + smsg = ReleaseWorkerDataMsg(key="x", stimulus_id="test") + assert smsg.to_dict() == { + "op": "release-worker-data", + "key": "x", + "stimulus_id": "test", + } diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 9d2b6f7794f..acfc32627ec 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -1,22 +1,96 @@ import asyncio +import inspect import logging import random from collections import defaultdict -from functools import partial +from functools import partial, wraps from itertools import cycle from tlz import concat, drop, groupby, merge import dask.config from dask.optimization import SubgraphCallable -from dask.utils import parse_timedelta, stringify +from dask.utils import funcname, parse_timedelta, stringify from distributed.core import rpc +from distributed.metrics import time from distributed.utils import All logger = logging.getLogger(__name__) +def stimulus_handler(*args): + """Decorator controlling injection into RPC Handlers + + RPC Handler functions are entrypoints into the distributed Scheduler. + These entrypoints may receive stimuli from external entities such + as workers in the ``stimulus_id`` kwarg or the functinos + may generate the stimuli themselves. + A further complication is that RPC Handlers may call other RPC handler + functions. + + Therefore, this decorator exists to simplify the setting of + the Scheduler STIMULUS_ID and encapsulates the following logic + + 1. If the STIMULUS_ID is already set, stimuli from other sources + are ignored. + 2. 
If a ``stimulus_id`` kwargs is supplied by an external entity + such as a worker, the STIMULUS_ID is set to this value. + 3. Otherwise, the STIMULUS_ID is from the function name and + current time. + """ + + def decorator(fn): + name = funcname(fn).replace("_", "-") + params = list(inspect.signature(fn).parameters.values()) + if params[0].name != "self": + raise ValueError(f"{fn} must be a method") + + if not inspect.iscoroutinefunction(fn): + + @wraps(fn) + def wrapper(*args, **kw): + STIMULUS_ID = args[0].STIMULUS_ID + + try: + STIMULUS_ID.get() + except LookupError: + stimulus_id = kw.get("stimulus_id", None) or f"{name}-{time()}" + token = STIMULUS_ID.set(stimulus_id) + else: + token = None + + try: + return fn(*args, **kw) + finally: + if token: + STIMULUS_ID.reset(token) + + else: + + @wraps(fn) + async def wrapper(*args, **kw): + STIMULUS_ID = args[0].STIMULUS_ID + + try: + STIMULUS_ID.get() + except LookupError: + stimulus_id = kw.get("stimulus_id", None) or f"{name}-{time()}" + token = STIMULUS_ID.set(stimulus_id) + else: + token = None + + try: + return await fn(*args, **kw) + finally: + if token: + STIMULUS_ID.reset(token) + + return wrapper + + return decorator(args[0]) if args and callable(args[0]) else decorator + + async def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None): """Gather data directly from peers diff --git a/distributed/utils_test.py b/distributed/utils_test.py index f84cbac8556..7ab8cf2ff48 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -50,14 +50,7 @@ from distributed.comm.tcp import TCP, BaseTCPConnector from distributed.compatibility import WINDOWS from distributed.config import initialize_logging -from distributed.core import ( - SERVER_STIMULUS_ID, - CommClosedError, - ConnectionPool, - Status, - connect, - rpc, -) +from distributed.core import CommClosedError, ConnectionPool, Status, connect, rpc from distributed.deploy import SpecCluster from distributed.diagnostics.plugin import WorkerPlugin from distributed.metrics import time @@ -124,17 +117,6 @@ def invalid_python_script(tmpdir_factory): return local_file -@pytest.fixture -def set_stimulus(request): - stimulus_id = f"{request.function.__name__.replace('_', '-')}-{time()}" - - try: - token = SERVER_STIMULUS_ID.set(stimulus_id) - yield - finally: - SERVER_STIMULUS_ID.reset(token) - - async def cleanup_global_workers(): for worker in Worker._instances: await worker.close(report=False, executor_wait=False) diff --git a/distributed/worker.py b/distributed/worker.py index b706a10b135..7c6fc8e6fea 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -924,14 +924,18 @@ def status(self, value): self.ensure_computing() self.ensure_communicating() - def _send_worker_status_change(self) -> None: + def _send_worker_status_change(self, stimulus_id=None) -> None: if ( self.batched_stream and self.batched_stream.comm and not self.batched_stream.comm.closed() ): self.batched_stream.send( - {"op": "worker-status-change", "status": self._status.name} + { + "op": "worker-status-change", + "status": self._status.name, + "stimulus_id": stimulus_id or f"worker-status-changed-{time()}", + } ) elif self._status != Status.closed: self.loop.call_later(0.05, self._send_worker_status_change) @@ -1897,7 +1901,7 @@ def handle_compute_task( pass elif ts.state == "memory": recommendations[ts] = "memory" - instructions.append(self._get_task_finished_msg(ts)) + instructions.append(self._get_task_finished_msg(ts, stimulus_id)) elif ts.state in { "released", "fetch", @@ 
-2025,7 +2029,7 @@ def transition_memory_released( recs, instructions = self.transition_generic_released( ts, stimulus_id=stimulus_id ) - instructions.append(ReleaseWorkerDataMsg(ts.key)) + instructions.append(ReleaseWorkerDataMsg(ts.key, stimulus_id)) return recs, instructions def transition_waiting_constrained( @@ -2048,7 +2052,7 @@ def transition_long_running_rescheduled( self, ts: TaskState, *, stimulus_id: str ) -> tuple[Recs, Instructions]: recs: Recs = {ts: "released"} - smsg = RescheduleMsg(key=ts.key, worker=self.address) + smsg = RescheduleMsg(key=ts.key, worker=self.address, stimulus_id=stimulus_id) return recs, [smsg] def transition_executing_rescheduled( @@ -2059,7 +2063,7 @@ def transition_executing_rescheduled( self._executing.discard(ts) recs: Recs = {ts: "released"} - smsg = RescheduleMsg(key=ts.key, worker=self.address) + smsg = RescheduleMsg(key=ts.key, worker=self.address, stimulus_id=stimulus_id) return recs, [smsg] def transition_waiting_ready( @@ -2136,6 +2140,7 @@ def transition_generic_error( traceback_text=ts.traceback_text, thread=self.threads.get(ts.key), startstops=ts.startstops, + stimulus_id=stimulus_id, ) return {}, [smsg] @@ -2320,7 +2325,7 @@ def transition_generic_memory( return recs, [] if self.validate: assert ts.key in self.data or ts.key in self.actors - smsg = self._get_task_finished_msg(ts) + smsg = self._get_task_finished_msg(ts, stimulus_id) return recs, [smsg] def transition_executing_memory( @@ -2439,7 +2444,9 @@ def transition_executing_long_running( ts.state = "long-running" self._executing.discard(ts) self.long_running.add(ts.key) - smsg = LongRunningMsg(key=ts.key, compute_duration=compute_duration) + smsg = LongRunningMsg( + key=ts.key, compute_duration=compute_duration, stimulus_id=stimulus_id + ) self.io_loop.add_callback(self.ensure_computing) return {}, [smsg] @@ -2700,7 +2707,9 @@ def ensure_communicating(self) -> None: for el in skipped_worker_in_flight: self.data_needed.push(el) - def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg: + def _get_task_finished_msg( + self, ts: TaskState, stimulus_id: str + ) -> TaskFinishedMsg: if ts.key not in self.data and ts.key not in self.actors: raise RuntimeError(f"Task {ts} not ready") typ = ts.type @@ -2726,6 +2735,7 @@ def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg: metadata=ts.metadata, thread=self.threads.get(ts.key), startstops=ts.startstops, + stimulus_id=stimulus_id, ) def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs: @@ -3029,7 +3039,12 @@ async def gather_dep( self.has_what[worker].discard(ts.key) self.log.append((d, "missing-dep", stimulus_id, time())) self.batched_stream.send( - {"op": "missing-data", "errant_worker": worker, "key": d} + { + "op": "missing-data", + "errant_worker": worker, + "key": d, + "stimulus_id": stimulus_id, + } ) recommendations[ts] = "fetch" if ts.who_has else "missing" del data, response @@ -3134,7 +3149,7 @@ def handle_steal_request(self, key: str, stimulus_id: str) -> None: # `transition_constrained_executing` self.transition(ts, "released", stimulus_id=stimulus_id) - def handle_worker_status_change(self, status: str) -> None: + def handle_worker_status_change(self, status: str, stimulus_id: str) -> None: new_status = Status.lookup[status] # type: ignore if ( @@ -3145,7 +3160,7 @@ def handle_worker_status_change(self, status: str) -> None: "Invalid Worker.status transition: %s -> %s", self._status, new_status ) # Reiterate the current status to the scheduler to restore sync - 
self._send_worker_status_change() + self._send_worker_status_change(stimulus_id) else: # Update status and send confirmation to the Scheduler (see status.setter) self.status = new_status diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index ce8d200ba7a..0a28637c0e2 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -286,6 +286,7 @@ class TaskFinishedMsg(SendMessageToScheduler): metadata: dict thread: int | None startstops: list[StartStop] + stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore def to_dict(self) -> dict[str, Any]: @@ -305,6 +306,7 @@ class TaskErredMsg(SendMessageToScheduler): traceback_text: str thread: int | None startstops: list[StartStop] + stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore def to_dict(self) -> dict[str, Any]: @@ -317,8 +319,9 @@ def to_dict(self) -> dict[str, Any]: class ReleaseWorkerDataMsg(SendMessageToScheduler): op = "release-worker-data" - __slots__ = ("key",) + __slots__ = ("key", "stimulus_id") key: str + stimulus_id: str @dataclass @@ -326,18 +329,20 @@ class RescheduleMsg(SendMessageToScheduler): op = "reschedule" # Not to be confused with the distributed.Reschedule Exception - __slots__ = ("key", "worker") + __slots__ = ("key", "worker", "stimulus_id") key: str worker: str + stimulus_id: str @dataclass class LongRunningMsg(SendMessageToScheduler): op = "long-running" - __slots__ = ("key", "compute_duration") + __slots__ = ("key", "compute_duration", "stimulus_id") key: str compute_duration: float + stimulus_id: str @dataclass From ef4e29c850e7ea4e5720b74e5c0ad620f3e31253 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 4 Apr 2022 11:50:18 +0200 Subject: [PATCH 10/29] Enable github tmate --- .github/workflows/tests.yaml | 6 +++--- distributed/scheduler.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index e5b74d5b49a..4111715cf25 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -142,9 +142,9 @@ jobs: python continuous_integration/scripts/parse_stdout.py < reports/stdout > reports/pytest.xml fi - # - name: Debug with tmate on failure - # if: ${{ failure() }} - # uses: mxschmitt/action-tmate@v3 + - name: Debug with tmate on failure + if: ${{ failure() }} + uses: mxschmitt/action-tmate@v3 - name: Coverage uses: codecov/codecov-action@v1 diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 524a060ea91..a3e243465a1 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2379,6 +2379,7 @@ def _transition(self, key, finish: str, *args, **kwargs): stimulus_id = scheduler.STIMULUS_ID.get() except LookupError: if self._validate: + # Can't pickle ContextVars raise LookupError(scheduler.STIMULUS_ID.name) else: stimulus_id = "" From 08a68122e44b59f3c50780066220af984695fc59 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 4 Apr 2022 12:11:06 +0200 Subject: [PATCH 11/29] Target specific test case --- .github/workflows/tests.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 4111715cf25..9228e20ff8c 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -128,6 +128,7 @@ jobs: --leaks=fds,processes,threads \ --junitxml reports/pytest.xml -o junit_suite_name=$TEST_ID \ --cov=distributed --cov-report=xml \ + -k test_remove_worker \ | tee reports/stdout - name: Generate junit XML report in case of 
pytest-timeout From 1ce45f0d155baf851713bf5601a17730ff2db7c4 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 4 Apr 2022 12:59:53 +0200 Subject: [PATCH 12/29] bump --- .github/workflows/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 9228e20ff8c..4fbacfbe542 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -128,7 +128,7 @@ jobs: --leaks=fds,processes,threads \ --junitxml reports/pytest.xml -o junit_suite_name=$TEST_ID \ --cov=distributed --cov-report=xml \ - -k test_remove_worker \ + -k test_remove_worker -s -v \ | tee reports/stdout - name: Generate junit XML report in case of pytest-timeout From 222b24de2979640ffaccbb4d0f0de7e046c4e1b3 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 4 Apr 2022 13:54:43 +0200 Subject: [PATCH 13/29] Explicitly specify sync/async stimulus_handler inspect.iscoroutinefunction doesn't recognise cythonised async functions --- distributed/scheduler.py | 52 +++++++++++++++++++-------------------- distributed/utils_comm.py | 12 +++++++-- 2 files changed, 36 insertions(+), 28 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a3e243465a1..6f52fafc83e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4676,7 +4676,7 @@ async def add_nanny(self, comm): } return msg - @stimulus_handler + @stimulus_handler(sync=True) def update_graph_hlg( self, client=None, @@ -4736,7 +4736,7 @@ def update_graph_hlg( stimulus_id=self.STIMULUS_ID.get(), ) - @stimulus_handler + @stimulus_handler(sync=True) def update_graph( self, client=None, @@ -5030,7 +5030,7 @@ def update_graph( # TODO: balance workers - @stimulus_handler + @stimulus_handler(sync=True) def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs): """Mark that a task has finished execution on a particular worker""" parent: SchedulerState = cast(SchedulerState, self) @@ -5071,7 +5071,7 @@ def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwar assert ws in ts._who_has return recommendations, client_msgs, worker_msgs - @stimulus_handler + @stimulus_handler(sync=True) def stimulus_task_erred( self, key=None, @@ -5103,7 +5103,7 @@ def stimulus_task_erred( **kwargs, ) - @stimulus_handler + @stimulus_handler(sync=True) def stimulus_retry(self, keys, client=None, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) logger.info("Client %s requests to retry %d keys", client, len(keys)) @@ -5134,7 +5134,7 @@ def stimulus_retry(self, keys, client=None, stimulus_id=None): return tuple(seen) - @stimulus_handler + @stimulus_handler(sync=False) async def remove_worker(self, address, safe=False, close=True, stimulus_id=None): """ Remove worker from cluster @@ -5251,7 +5251,7 @@ def remove_worker_from_events(): return "OK" - @stimulus_handler + @stimulus_handler(sync=True) def stimulus_cancel( self, comm, keys=None, client=None, force=False, stimulus_id=None ): @@ -5264,7 +5264,7 @@ def stimulus_cancel( for key in keys: self.cancel_key(key, client, force=force) - @stimulus_handler + @stimulus_handler(sync=True) def cancel_key(self, key, client, retries=5, force=False, stimulus_id=None): """Cancel a particular key and all dependents""" # TODO: this should be converted to use the transition mechanism @@ -5290,7 +5290,7 @@ def cancel_key(self, key, client, retries=5, force=False, stimulus_id=None): for cs in clients: self.client_releases_keys(keys=[key], client=cs._client_key) - 
@stimulus_handler + @stimulus_handler(sync=True) def client_desires_keys(self, keys=None, client=None, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) cs: ClientState = parent._clients.get(client) @@ -5309,7 +5309,7 @@ def client_desires_keys(self, keys=None, client=None, stimulus_id=None): if ts._state in ("memory", "erred"): self.report_on_key(ts=ts, client=client) - @stimulus_handler + @stimulus_handler(sync=True) def client_releases_keys(self, keys=None, client=None, stimulus_id=None): """Remove keys from client desired list""" @@ -5643,7 +5643,7 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1): def handle_uncaught_error(self, **msg): logger.exception(clean_exception(**msg)[1]) - @stimulus_handler + @stimulus_handler(sync=True) def handle_task_finished(self, key=None, worker=None, stimulus_id=None, **msg): parent: SchedulerState = cast(SchedulerState, self) if worker not in parent._workers_dv: @@ -5660,7 +5660,7 @@ def handle_task_finished(self, key=None, worker=None, stimulus_id=None, **msg): self.send_all(client_msgs, worker_msgs) - @stimulus_handler + @stimulus_handler(sync=True) def handle_task_erred(self, key=None, stimulus_id=None, **msg): parent: SchedulerState = cast(SchedulerState, self) recommendations: dict @@ -5672,7 +5672,7 @@ def handle_task_erred(self, key=None, stimulus_id=None, **msg): self.send_all(client_msgs, worker_msgs) - @stimulus_handler + @stimulus_handler(sync=True) def handle_missing_data( self, key=None, errant_worker=None, stimulus_id=None, **kwargs ): @@ -5708,7 +5708,7 @@ def handle_missing_data( else: self.transitions({key: "forgotten"}) - @stimulus_handler + @stimulus_handler(sync=True) def release_worker_data(self, key, worker, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState = parent._workers_dv.get(worker) @@ -5723,7 +5723,7 @@ def release_worker_data(self, key, worker, stimulus_id=None): if recommendations: self.transitions(recommendations) - @stimulus_handler + @stimulus_handler(sync=True) def handle_long_running( self, key=None, worker=None, compute_duration=None, stimulus_id=None ): @@ -5768,7 +5768,7 @@ def handle_long_running( ws._long_running.add(ts) self.check_idle_saturated(ws) - @stimulus_handler + @stimulus_handler(sync=True) def handle_worker_status_change( self, status: str, worker: str, stimulus_id=None ) -> None: @@ -6026,7 +6026,7 @@ def send_all(self, client_msgs: dict, worker_msgs: dict): # Less common interactions # ############################ - @stimulus_handler + @stimulus_handler(sync=False) async def scatter( self, comm=None, @@ -6080,7 +6080,7 @@ async def scatter( ) return keys - @stimulus_handler + @stimulus_handler(sync=False) async def gather(self, keys, serializers=None, stimulus_id=None): """Collect data from workers to the scheduler""" parent: SchedulerState = cast(SchedulerState, self) @@ -6154,7 +6154,7 @@ def clear_task_state(self): for collection in self._task_state_collections: collection.clear() - @stimulus_handler + @stimulus_handler(sync=False) async def restart(self, client=None, timeout=30, stimulus_id=None): """Restart all workers. 
Reset local state.""" parent: SchedulerState = cast(SchedulerState, self) @@ -6371,7 +6371,7 @@ async def gather_on_worker( return keys_failed - @stimulus_handler + @stimulus_handler(sync=False) async def delete_worker_data( self, worker_address: str, @@ -6419,7 +6419,7 @@ async def delete_worker_data( self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys}) - @stimulus_handler + @stimulus_handler(sync=False) async def rebalance( self, comm=None, @@ -6787,7 +6787,7 @@ async def _rebalance_move_data( else: return {"status": "OK"} - @stimulus_handler + @stimulus_handler(sync=False) async def replicate( self, comm=None, @@ -7044,7 +7044,7 @@ def _key(group): return result - @stimulus_handler + @stimulus_handler(sync=False) async def retire_workers( self, comm=None, @@ -7222,7 +7222,7 @@ async def _track_retire_worker( logger.info("Retired worker %s", ws._address) return ws._address, ws.identity() - @stimulus_handler + @stimulus_handler(sync=True) def add_keys(self, worker=None, keys=(), stimulus_id=None): """ Learn that a worker has certain keys @@ -7257,7 +7257,7 @@ def add_keys(self, worker=None, keys=(), stimulus_id=None): return "OK" - @stimulus_handler + @stimulus_handler(sync=True) def update_data( self, *, who_has: dict, nbytes: dict, client=None, stimulus_id=None ): @@ -7735,7 +7735,7 @@ def story(self, *keys): transition_story = story - @stimulus_handler + @stimulus_handler(sync=True) def reschedule(self, key=None, worker=None, stimulus_id=None): """Reschedule a task diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index acfc32627ec..38a65719139 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -16,10 +16,15 @@ from distributed.metrics import time from distributed.utils import All +try: + from cython import compiled +except ImportError: + compiled = False + logger = logging.getLogger(__name__) -def stimulus_handler(*args): +def stimulus_handler(*args, sync=True): """Decorator controlling injection into RPC Handlers RPC Handler functions are entrypoints into the distributed Scheduler. 
@@ -46,7 +51,10 @@ def decorator(fn): if params[0].name != "self": raise ValueError(f"{fn} must be a method") - if not inspect.iscoroutinefunction(fn): + if not compiled: + assert sync is not inspect.iscoroutinefunction(fn) + + if sync: @wraps(fn) def wrapper(*args, **kw): From 55e5216fe5842d29abe8f2ef6e3d559dd06f7f82 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 4 Apr 2022 14:13:53 +0200 Subject: [PATCH 14/29] Assert with is_coroutine_function --- distributed/utils_comm.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 38a65719139..47863df7879 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -14,12 +14,7 @@ from distributed.core import rpc from distributed.metrics import time -from distributed.utils import All - -try: - from cython import compiled -except ImportError: - compiled = False +from distributed.utils import All, is_coroutine_function logger = logging.getLogger(__name__) @@ -51,8 +46,7 @@ def decorator(fn): if params[0].name != "self": raise ValueError(f"{fn} must be a method") - if not compiled: - assert sync is not inspect.iscoroutinefunction(fn) + assert sync is not is_coroutine_function(fn) if sync: From 22e3300c251b5228ae8ada91855098ab6be80771 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 4 Apr 2022 14:29:58 +0200 Subject: [PATCH 15/29] Document sync parameter --- distributed/utils_comm.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 47863df7879..9e0c022d9d2 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -14,7 +14,7 @@ from distributed.core import rpc from distributed.metrics import time -from distributed.utils import All, is_coroutine_function +from distributed.utils import All logger = logging.getLogger(__name__) @@ -38,6 +38,17 @@ def stimulus_handler(*args, sync=True): such as a worker, the STIMULUS_ID is set to this value. 3. Otherwise, the STIMULUS_ID is from the function name and current time. + + Parameters + ---------- + *args : tuple + If the decorator is called without keyword arguments it will + be assumed that the decorated function is in ``args[0]``. + Otherwise should be empty if call with keyword arguments. + sync : bool + Indicates whether function is sync or async. + Necessary to distinguish between sync and async stimulus handlers + in a cython environment. https://bugs.python.org/issue38225 """ def decorator(fn): @@ -46,8 +57,6 @@ def decorator(fn): if params[0].name != "self": raise ValueError(f"{fn} must be a method") - assert sync is not is_coroutine_function(fn) - if sync: @wraps(fn) From 9c42c3fcb55e517100767f1a45e568a6c33459ce Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 4 Apr 2022 14:39:14 +0200 Subject: [PATCH 16/29] Revert "bump" This reverts commit 1ce45f0d155baf851713bf5601a17730ff2db7c4. 
--- .github/workflows/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 4fbacfbe542..9228e20ff8c 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -128,7 +128,7 @@ jobs: --leaks=fds,processes,threads \ --junitxml reports/pytest.xml -o junit_suite_name=$TEST_ID \ --cov=distributed --cov-report=xml \ - -k test_remove_worker -s -v \ + -k test_remove_worker \ | tee reports/stdout - name: Generate junit XML report in case of pytest-timeout From a517fcb560b9630223d913cd9133f5dfe6fb074b Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 4 Apr 2022 14:39:36 +0200 Subject: [PATCH 17/29] Revert "Target specific test case" This reverts commit 08a68122e44b59f3c50780066220af984695fc59. --- .github/workflows/tests.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 9228e20ff8c..4111715cf25 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -128,7 +128,6 @@ jobs: --leaks=fds,processes,threads \ --junitxml reports/pytest.xml -o junit_suite_name=$TEST_ID \ --cov=distributed --cov-report=xml \ - -k test_remove_worker \ | tee reports/stdout - name: Generate junit XML report in case of pytest-timeout From e29ad73dcebef3b0b25ea0e928fd9f55ff00659f Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 4 Apr 2022 14:39:39 +0200 Subject: [PATCH 18/29] Revert "Enable github tmate" This reverts commit ef4e29c850e7ea4e5720b74e5c0ad620f3e31253. --- .github/workflows/tests.yaml | 6 +++--- distributed/scheduler.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 4111715cf25..e5b74d5b49a 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -142,9 +142,9 @@ jobs: python continuous_integration/scripts/parse_stdout.py < reports/stdout > reports/pytest.xml fi - - name: Debug with tmate on failure - if: ${{ failure() }} - uses: mxschmitt/action-tmate@v3 + # - name: Debug with tmate on failure + # if: ${{ failure() }} + # uses: mxschmitt/action-tmate@v3 - name: Coverage uses: codecov/codecov-action@v1 diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 6f52fafc83e..c87861dbc18 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2379,7 +2379,6 @@ def _transition(self, key, finish: str, *args, **kwargs): stimulus_id = scheduler.STIMULUS_ID.get() except LookupError: if self._validate: - # Can't pickle ContextVars raise LookupError(scheduler.STIMULUS_ID.name) else: stimulus_id = "" From 8458969dbdaca3fdf46788cfde3734e8d1111ea9 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 4 Apr 2022 16:09:30 +0200 Subject: [PATCH 19/29] comments --- distributed/scheduler.py | 2 +- distributed/utils_comm.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index c87861dbc18..6ecbf99941d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5824,7 +5824,7 @@ async def handle_worker(self, comm=None, worker=None): else: if self._validate: raise RuntimeError( - f"STIMULUS_ID {stimulus_id} set in Scheduler.add_worker" + f"STIMULUS_ID {stimulus_id} set in Scheduler.handle_worker" ) comm.name = "Scheduler connection to worker" diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 9e0c022d9d2..b0334ac5f4f 100644 --- a/distributed/utils_comm.py +++ 
b/distributed/utils_comm.py @@ -24,19 +24,19 @@ def stimulus_handler(*args, sync=True): RPC Handler functions are entrypoints into the distributed Scheduler. These entrypoints may receive stimuli from external entities such - as workers in the ``stimulus_id`` kwarg or the functinos - may generate the stimuli themselves. + as workers in the ``stimulus_id`` kwarg or they + may generate stimuli themselves. A further complication is that RPC Handlers may call other RPC handler functions. - Therefore, this decorator exists to simplify the setting of + This decorator exists to simplify the setting of the Scheduler STIMULUS_ID and encapsulates the following logic 1. If the STIMULUS_ID is already set, stimuli from other sources are ignored. 2. If a ``stimulus_id`` kwargs is supplied by an external entity such as a worker, the STIMULUS_ID is set to this value. - 3. Otherwise, the STIMULUS_ID is from the function name and + 3. Otherwise, the STIMULUS_ID is generated from the function name and current time. Parameters From caf9a1d8b721d33a66007807ce31c58cf57b022a Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 4 Apr 2022 17:17:11 +0200 Subject: [PATCH 20/29] Template stimulus_id var in dashboard --- distributed/http/templates/task.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/http/templates/task.html b/distributed/http/templates/task.html index 075b25e2264..f10aaad5602 100644 --- a/distributed/http/templates/task.html +++ b/distributed/http/templates/task.html @@ -129,7 +129,7 @@
Transition Log
{{key}} {{ start }} {{ finish }} - stimulus_id + {{ stimulus_id }} From a57d9c1ebcfdadf5466a524b51c9c6d0609fb8d8 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 4 Apr 2022 17:21:31 +0200 Subject: [PATCH 21/29] Pass stimulus_id to Client._decref --- distributed/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 0d3510f3108..1f0a17285eb 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1371,9 +1371,9 @@ def _dec_ref(self, key): self.refcount[key] -= 1 if self.refcount[key] == 0: del self.refcount[key] - self._release_key(key) + self._release_key(key, f"client-release-key-{time()}") - def _release_key(self, key, stimulus_id=None): + def _release_key(self, key, stimulus_id: str): """Release key from distributed memory""" logger.debug("Release key %s", key) st = self.futures.pop(key, None) From b311bef56fa7fdcd8c4c4e44ce13b8b212fb3ba1 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 4 Apr 2022 18:12:54 +0200 Subject: [PATCH 22/29] stimulus_handler changes --- distributed/tests/test_scheduler.py | 6 +++--- distributed/utils_comm.py | 18 +++++++++--------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index e815b106fc7..43d1ef38f52 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -3565,10 +3565,10 @@ async def test_stimuli(c, s, a, b): ) stimuli = [ - "update-graph-hlg", - "update-graph-hlg", + "update_graph_hlg", + "update_graph_hlg", "task-finished", - "client-releases-keys", + "client_releases_keys", ] stories = s.story(key) diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index b0334ac5f4f..d16c4540594 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -def stimulus_handler(*args, sync=True): +def stimulus_handler(*args, sync: bool = True): """Decorator controlling injection into RPC Handlers RPC Handler functions are entrypoints into the distributed Scheduler. 
@@ -52,16 +52,16 @@ def stimulus_handler(*args, sync=True): """ def decorator(fn): - name = funcname(fn).replace("_", "-") + name = funcname(fn) params = list(inspect.signature(fn).parameters.values()) if params[0].name != "self": - raise ValueError(f"{fn} must be a method") + raise ValueError(f"{fn} must be a method") # pragma: nocover if sync: @wraps(fn) - def wrapper(*args, **kw): - STIMULUS_ID = args[0].STIMULUS_ID + def wrapper(self, *args, **kw): + STIMULUS_ID = self.STIMULUS_ID try: STIMULUS_ID.get() @@ -72,7 +72,7 @@ def wrapper(*args, **kw): token = None try: - return fn(*args, **kw) + return fn(self, *args, **kw) finally: if token: STIMULUS_ID.reset(token) @@ -80,8 +80,8 @@ def wrapper(*args, **kw): else: @wraps(fn) - async def wrapper(*args, **kw): - STIMULUS_ID = args[0].STIMULUS_ID + async def wrapper(self, *args, **kw): + STIMULUS_ID = self.STIMULUS_ID try: STIMULUS_ID.get() @@ -92,7 +92,7 @@ async def wrapper(*args, **kw): token = None try: - return await fn(*args, **kw) + return await fn(self, *args, **kw) finally: if token: STIMULUS_ID.reset(token) From 1cf403242f9290ab65384704f91906a62ea1709c Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 4 Apr 2022 18:18:39 +0200 Subject: [PATCH 23/29] worker changes --- distributed/worker.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index e23bb3cc8c2..5a7fa57b829 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -916,12 +916,12 @@ def status(self, value): """ prev_status = self.status ServerNode.status.__set__(self, value) - self._send_worker_status_change() + self._send_worker_status_change(f"worker-status-change-{time()}") if prev_status == Status.paused and value == Status.running: self.ensure_computing() self.ensure_communicating() - def _send_worker_status_change(self, stimulus_id=None) -> None: + def _send_worker_status_change(self, stimulus_id: str) -> None: if ( self.batched_stream and self.batched_stream.comm @@ -931,11 +931,11 @@ def _send_worker_status_change(self, stimulus_id=None) -> None: { "op": "worker-status-change", "status": self._status.name, - "stimulus_id": stimulus_id or f"worker-status-changed-{time()}", + "stimulus_id": stimulus_id, } ) elif self._status != Status.closed: - self.loop.call_later(0.05, self._send_worker_status_change) + self.loop.call_later(0.05, self._send_worker_status_change, stimulus_id) async def get_metrics(self) -> dict: try: @@ -1885,7 +1885,9 @@ def handle_compute_task( pass elif ts.state == "memory": recommendations[ts] = "memory" - instructions.append(self._get_task_finished_msg(ts, stimulus_id)) + instructions.append( + self._get_task_finished_msg(ts, stimulus_id=stimulus_id) + ) elif ts.state in { "released", "fetch", @@ -2013,7 +2015,7 @@ def transition_memory_released( recs, instructions = self.transition_generic_released( ts, stimulus_id=stimulus_id ) - instructions.append(ReleaseWorkerDataMsg(ts.key, stimulus_id)) + instructions.append(ReleaseWorkerDataMsg(key=ts.key, stimulus_id=stimulus_id)) return recs, instructions def transition_waiting_constrained( @@ -2309,7 +2311,7 @@ def transition_generic_memory( return recs, [] if self.validate: assert ts.key in self.data or ts.key in self.actors - smsg = self._get_task_finished_msg(ts, stimulus_id) + smsg = self._get_task_finished_msg(ts, stimulus_id=stimulus_id) return recs, [smsg] def transition_executing_memory( From 443375421285bf6af96ca472da5e01a76b6e20d6 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Tue, 5 Apr 2022 
12:29:22 +0200 Subject: [PATCH 24/29] Use a contextmanager instead of a decorator --- distributed/scheduler.py | 1532 ++++++++++++++------------- distributed/tests/test_scheduler.py | 6 +- distributed/utils_comm.py | 89 +- 3 files changed, 793 insertions(+), 834 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 6ecbf99941d..83ed59c9dfd 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -24,7 +24,7 @@ Mapping, Set, ) -from contextlib import suppress +from contextlib import contextmanager, suppress from contextvars import ContextVar from datetime import timedelta from functools import partial @@ -101,7 +101,6 @@ gather_from_workers, retry_operation, scatter_to_workers, - stimulus_handler, ) from distributed.utils_perf import disable_gc_diagnosis, enable_gc_diagnosis from distributed.variable import VariableExtension @@ -2375,19 +2374,16 @@ def _transition(self, key, finish: str, *args, **kwargs): # FIXME downcast antipattern scheduler = pep484_cast(Scheduler, self) - try: - stimulus_id = scheduler.STIMULUS_ID.get() - except LookupError: - if self._validate: - raise LookupError(scheduler.STIMULUS_ID.name) - else: - stimulus_id = "" + stimulus_id = scheduler.STIMULUS_ID.get(Scheduler.STIMULUS_ID_NOT_SET) finish2 = ts._state scheduler.transition_log.append( (key, start, finish2, recommendations, stimulus_id, time()) ) if parent._validate: + if stimulus_id == Scheduler.STIMULUS_ID_NOT_SET: + raise LookupError(scheduler.STIMULUS_ID.name) + logger.debug( "Transitioned %r %s->%s (actual: %s) from %s. Consequence: %s", key, @@ -3720,6 +3716,7 @@ class Scheduler(SchedulerState, ServerNode): Time we expect certain functions to take, e.g. ``{'sum': 0.25}`` """ + STIMULUS_ID_NOT_SET = "" default_port = 8786 _instances: "ClassVar[weakref.WeakSet[Scheduler]]" = weakref.WeakSet() @@ -4089,6 +4086,32 @@ def _repr_html_(self): tasks=parent._tasks, ) + @contextmanager + def stimulus_id(self, name: str): + """Context manager for setting the Scheduler stimulus_id + + If the stimulus_id has already been set further up the call stack, + this has no effect. + + Parameters + ---------- + name : str + The name of the stimulus. 
+ """ + try: + stimulus_id = self.STIMULUS_ID.get() + except LookupError: + token = self.STIMULUS_ID.set(name) + stimulus_id = name + else: + token = None + + try: + yield stimulus_id + finally: + if token: + self.STIMULUS_ID.reset(token) + def identity(self): """Basic information about ourselves and our cluster""" parent: SchedulerState = cast(SchedulerState, self) @@ -4588,8 +4611,8 @@ async def add_worker( recommendations: dict = {} client_msgs: dict = {} worker_msgs: dict = {} - token = self.STIMULUS_ID.set(f"add-worker-{time()}") - try: + + with self.stimulus_id(f"add-worker-{time()}"): if nbytes: assert isinstance(nbytes, dict) already_released_keys = [] @@ -4633,9 +4656,6 @@ async def add_worker( if recommendations: parent._transitions(recommendations, client_msgs, worker_msgs) - finally: - self.STIMULUS_ID.reset(token) - self.send_all(client_msgs, worker_msgs) logger.info("Register worker %s", ws) @@ -4675,7 +4695,6 @@ async def add_nanny(self, comm): } return msg - @stimulus_handler(sync=True) def update_graph_hlg( self, client=None, @@ -4716,26 +4735,26 @@ def update_graph_hlg( } priority = dask.order.order(dsk, dependencies=stripped_deps) - return self.update_graph( - client, - dsk, - keys, - dependencies, - restrictions, - priority, - loose_restrictions, - resources, - submitting_task, - retries, - user_priority, - actors, - fifo_timeout, - annotations, - code=code, - stimulus_id=self.STIMULUS_ID.get(), - ) + with self.stimulus_id(stimulus_id or f"update-graph-hlg-{time()}"): + return self.update_graph( + client, + dsk, + keys, + dependencies, + restrictions, + priority, + loose_restrictions, + resources, + submitting_task, + retries, + user_priority, + actors, + fifo_timeout, + annotations, + code=code, + stimulus_id=self.STIMULUS_ID.get(), + ) - @stimulus_handler(sync=True) def update_graph( self, client=None, @@ -4760,317 +4779,322 @@ def update_graph( This happens whenever the Client calls submit, map, get, or compute. """ - parent: SchedulerState = cast(SchedulerState, self) - start = time() - fifo_timeout = parse_timedelta(fifo_timeout) - keys = set(keys) - if len(tasks) > 1: - self.log_event( - ["all", client], {"action": "update_graph", "count": len(tasks)} - ) - - # Remove aliases - for k in list(tasks): - if tasks[k] is k: - del tasks[k] + with self.stimulus_id(stimulus_id or f"update-graph-{time()}"): + parent: SchedulerState = cast(SchedulerState, self) + start = time() + fifo_timeout = parse_timedelta(fifo_timeout) + keys = set(keys) + if len(tasks) > 1: + self.log_event( + ["all", client], {"action": "update_graph", "count": len(tasks)} + ) - dependencies = dependencies or {} + # Remove aliases + for k in list(tasks): + if tasks[k] is k: + del tasks[k] - if parent._total_occupancy > 1e-9 and parent._computations: - # Still working on something. Assign new tasks to same computation - computation = cast(Computation, parent._computations[-1]) - else: - computation = Computation() - parent._computations.append(computation) + dependencies = dependencies or {} - if code and code not in computation._code: # add new code blocks - computation._code.add(code) + if parent._total_occupancy > 1e-9 and parent._computations: + # Still working on something. 
Assign new tasks to same computation + computation = cast(Computation, parent._computations[-1]) + else: + computation = Computation() + parent._computations.append(computation) + + if code and code not in computation._code: # add new code blocks + computation._code.add(code) + + n = 0 + while len(tasks) != n: # walk through new tasks, cancel any bad deps + n = len(tasks) + for k, deps in list(dependencies.items()): + if any( + dep not in parent._tasks and dep not in tasks for dep in deps + ): # bad key + logger.info("User asked for computation on lost data, %s", k) + del tasks[k] + del dependencies[k] + if k in keys: + keys.remove(k) + self.report({"op": "cancelled-key", "key": k}, client=client) + self.client_releases_keys(keys=[k], client=client) + + # Avoid computation that is already finished + ts: TaskState + already_in_memory = set() # tasks that are already done + for k, v in dependencies.items(): + if v and k in parent._tasks: + ts = parent._tasks[k] + if ts._state in ("memory", "erred"): + already_in_memory.add(k) - n = 0 - while len(tasks) != n: # walk through new tasks, cancel any bad deps - n = len(tasks) - for k, deps in list(dependencies.items()): - if any( - dep not in parent._tasks and dep not in tasks for dep in deps - ): # bad key - logger.info("User asked for computation on lost data, %s", k) - del tasks[k] - del dependencies[k] - if k in keys: - keys.remove(k) - self.report({"op": "cancelled-key", "key": k}, client=client) - self.client_releases_keys(keys=[k], client=client) + dts: TaskState + if already_in_memory: + dependents = dask.core.reverse_dict(dependencies) + stack = list(already_in_memory) + done = set(already_in_memory) + while stack: # remove unnecessary dependencies + key = stack.pop() + ts = parent._tasks[key] + try: + deps = dependencies[key] + except KeyError: + deps = self.dependencies[key] + for dep in deps: + if dep in dependents: + child_deps = dependents[dep] + else: + child_deps = self.dependencies[dep] + if all(d in done for d in child_deps): + if dep in parent._tasks and dep not in done: + done.add(dep) + stack.append(dep) - # Avoid computation that is already finished - ts: TaskState - already_in_memory = set() # tasks that are already done - for k, v in dependencies.items(): - if v and k in parent._tasks: - ts = parent._tasks[k] - if ts._state in ("memory", "erred"): - already_in_memory.add(k) + for d in done: + tasks.pop(d, None) + dependencies.pop(d, None) - dts: TaskState - if already_in_memory: - dependents = dask.core.reverse_dict(dependencies) - stack = list(already_in_memory) - done = set(already_in_memory) - while stack: # remove unnecessary dependencies - key = stack.pop() - ts = parent._tasks[key] - try: - deps = dependencies[key] - except KeyError: - deps = self.dependencies[key] - for dep in deps: - if dep in dependents: - child_deps = dependents[dep] - else: - child_deps = self.dependencies[dep] - if all(d in done for d in child_deps): - if dep in parent._tasks and dep not in done: - done.add(dep) - stack.append(dep) - - for d in done: - tasks.pop(d, None) - dependencies.pop(d, None) - - # Get or create task states - stack = list(keys) - touched_keys = set() - touched_tasks = [] - while stack: - k = stack.pop() - if k in touched_keys: - continue - # XXX Have a method get_task_state(self, k) ? 
- ts = parent._tasks.get(k) - if ts is None: - ts = parent.new_task( - k, tasks.get(k), "released", computation=computation - ) - elif not ts._run_spec: - ts._run_spec = tasks.get(k) + # Get or create task states + stack = list(keys) + touched_keys = set() + touched_tasks = [] + while stack: + k = stack.pop() + if k in touched_keys: + continue + # XXX Have a method get_task_state(self, k) ? + ts = parent._tasks.get(k) + if ts is None: + ts = parent.new_task( + k, tasks.get(k), "released", computation=computation + ) + elif not ts._run_spec: + ts._run_spec = tasks.get(k) - touched_keys.add(k) - touched_tasks.append(ts) - stack.extend(dependencies.get(k, ())) + touched_keys.add(k) + touched_tasks.append(ts) + stack.extend(dependencies.get(k, ())) - self.client_desires_keys(keys=keys, client=client) + self.client_desires_keys(keys=keys, client=client) - # Add dependencies - for key, deps in dependencies.items(): - ts = parent._tasks.get(key) - if ts is None or ts._dependencies: - continue - for dep in deps: - dts = parent._tasks[dep] - ts.add_dependency(dts) - - # Compute priorities - if isinstance(user_priority, Number): - user_priority = {k: user_priority for k in tasks} - - annotations = annotations or {} - restrictions = restrictions or {} - loose_restrictions = loose_restrictions or [] - resources = resources or {} - retries = retries or {} - - # Override existing taxonomy with per task annotations - if annotations: - if "priority" in annotations: - user_priority.update(annotations["priority"]) - - if "workers" in annotations: - restrictions.update(annotations["workers"]) - - if "allow_other_workers" in annotations: - loose_restrictions.extend( - k for k, v in annotations["allow_other_workers"].items() if v - ) + # Add dependencies + for key, deps in dependencies.items(): + ts = parent._tasks.get(key) + if ts is None or ts._dependencies: + continue + for dep in deps: + dts = parent._tasks[dep] + ts.add_dependency(dts) + + # Compute priorities + if isinstance(user_priority, Number): + user_priority = {k: user_priority for k in tasks} + + annotations = annotations or {} + restrictions = restrictions or {} + loose_restrictions = loose_restrictions or [] + resources = resources or {} + retries = retries or {} + + # Override existing taxonomy with per task annotations + if annotations: + if "priority" in annotations: + user_priority.update(annotations["priority"]) + + if "workers" in annotations: + restrictions.update(annotations["workers"]) + + if "allow_other_workers" in annotations: + loose_restrictions.extend( + k for k, v in annotations["allow_other_workers"].items() if v + ) - if "retries" in annotations: - retries.update(annotations["retries"]) + if "retries" in annotations: + retries.update(annotations["retries"]) + + if "resources" in annotations: + resources.update(annotations["resources"]) + + for a, kv in annotations.items(): + for k, v in kv.items(): + # Tasks might have been culled, in which case + # we have nothing to annotate. 
+ ts = parent._tasks.get(k) + if ts is not None: + ts._annotations[a] = v + + # Add actors + if actors is True: + actors = list(keys) + for actor in actors or []: + ts = parent._tasks[actor] + ts._actor = True + + priority = priority or dask.order.order( + tasks + ) # TODO: define order wrt old graph + + if submitting_task: # sub-tasks get better priority than parent tasks + ts = parent._tasks.get(submitting_task) + if ts is not None: + generation = ts._priority[0] - 0.01 + else: # super-task already cleaned up + generation = self.generation + elif self._last_time + fifo_timeout < start: + self.generation += 1 # older graph generations take precedence + generation = self.generation + self._last_time = start + else: + generation = self.generation - if "resources" in annotations: - resources.update(annotations["resources"]) + for key in set(priority) & touched_keys: + ts = parent._tasks[key] + if ts._priority is None: + ts._priority = ( + -(user_priority.get(key, 0)), + generation, + priority[key], + ) - for a, kv in annotations.items(): - for k, v in kv.items(): - # Tasks might have been culled, in which case - # we have nothing to annotate. + # Ensure all runnables have a priority + runnables = [ts for ts in touched_tasks if ts._run_spec] + for ts in runnables: + if ts._priority is None and ts._run_spec: + ts._priority = (self.generation, 0) + + if restrictions: + # *restrictions* is a dict keying task ids to lists of + # restriction specifications (either worker names or addresses) + for k, v in restrictions.items(): + if v is None: + continue ts = parent._tasks.get(k) - if ts is not None: - ts._annotations[a] = v - - # Add actors - if actors is True: - actors = list(keys) - for actor in actors or []: - ts = parent._tasks[actor] - ts._actor = True - - priority = priority or dask.order.order( - tasks - ) # TODO: define order wrt old graph - - if submitting_task: # sub-tasks get better priority than parent tasks - ts = parent._tasks.get(submitting_task) - if ts is not None: - generation = ts._priority[0] - 0.01 - else: # super-task already cleaned up - generation = self.generation - elif self._last_time + fifo_timeout < start: - self.generation += 1 # older graph generations take precedence - generation = self.generation - self._last_time = start - else: - generation = self.generation - - for key in set(priority) & touched_keys: - ts = parent._tasks[key] - if ts._priority is None: - ts._priority = (-(user_priority.get(key, 0)), generation, priority[key]) - - # Ensure all runnables have a priority - runnables = [ts for ts in touched_tasks if ts._run_spec] - for ts in runnables: - if ts._priority is None and ts._run_spec: - ts._priority = (self.generation, 0) - - if restrictions: - # *restrictions* is a dict keying task ids to lists of - # restriction specifications (either worker names or addresses) - for k, v in restrictions.items(): - if v is None: - continue - ts = parent._tasks.get(k) - if ts is None: - continue - ts._host_restrictions = set() - ts._worker_restrictions = set() - # Make sure `v` is a collection and not a single worker name / address - if not isinstance(v, (list, tuple, set)): - v = [v] - for w in v: - try: - w = self.coerce_address(w) - except ValueError: - # Not a valid address, but perhaps it's a hostname - ts._host_restrictions.add(w) - else: - ts._worker_restrictions.add(w) + if ts is None: + continue + ts._host_restrictions = set() + ts._worker_restrictions = set() + # Make sure `v` is a collection and not a single worker name / address + if not isinstance(v, (list, 
tuple, set)): + v = [v] + for w in v: + try: + w = self.coerce_address(w) + except ValueError: + # Not a valid address, but perhaps it's a hostname + ts._host_restrictions.add(w) + else: + ts._worker_restrictions.add(w) - if loose_restrictions: - for k in loose_restrictions: - ts = parent._tasks[k] - ts._loose_restrictions = True + if loose_restrictions: + for k in loose_restrictions: + ts = parent._tasks[k] + ts._loose_restrictions = True - if resources: - for k, v in resources.items(): - if v is None: - continue - assert isinstance(v, dict) - ts = parent._tasks.get(k) - if ts is None: - continue - ts._resource_restrictions = v + if resources: + for k, v in resources.items(): + if v is None: + continue + assert isinstance(v, dict) + ts = parent._tasks.get(k) + if ts is None: + continue + ts._resource_restrictions = v - if retries: - for k, v in retries.items(): - assert isinstance(v, int) - ts = parent._tasks.get(k) - if ts is None: - continue - ts._retries = v + if retries: + for k, v in retries.items(): + assert isinstance(v, int) + ts = parent._tasks.get(k) + if ts is None: + continue + ts._retries = v - # Compute recommendations - recommendations: dict = {} + # Compute recommendations + recommendations: dict = {} - for ts in sorted(runnables, key=operator.attrgetter("priority"), reverse=True): - if ts._state == "released" and ts._run_spec: - recommendations[ts._key] = "waiting" + for ts in sorted( + runnables, key=operator.attrgetter("priority"), reverse=True + ): + if ts._state == "released" and ts._run_spec: + recommendations[ts._key] = "waiting" - for ts in touched_tasks: - for dts in ts._dependencies: - if dts._exception_blame: - ts._exception_blame = dts._exception_blame - recommendations[ts._key] = "erred" - break + for ts in touched_tasks: + for dts in ts._dependencies: + if dts._exception_blame: + ts._exception_blame = dts._exception_blame + recommendations[ts._key] = "erred" + break - for plugin in list(self.plugins.values()): - try: - plugin.update_graph( - self, - client=client, - tasks=tasks, - keys=keys, - restrictions=restrictions or {}, - dependencies=dependencies, - priority=priority, - loose_restrictions=loose_restrictions, - resources=resources, - annotations=annotations, - ) - except Exception as e: - logger.exception(e) + for plugin in list(self.plugins.values()): + try: + plugin.update_graph( + self, + client=client, + tasks=tasks, + keys=keys, + restrictions=restrictions or {}, + dependencies=dependencies, + priority=priority, + loose_restrictions=loose_restrictions, + resources=resources, + annotations=annotations, + ) + except Exception as e: + logger.exception(e) - self.transitions(recommendations) + self.transitions(recommendations) - for ts in touched_tasks: - if ts._state in ("memory", "erred"): - self.report_on_key(ts=ts, client=client) + for ts in touched_tasks: + if ts._state in ("memory", "erred"): + self.report_on_key(ts=ts, client=client) - end = time() - if self.digests is not None: - self.digests["update-graph-duration"].add(end - start) + end = time() + if self.digests is not None: + self.digests["update-graph-duration"].add(end - start) - # TODO: balance workers + # TODO: balance workers - @stimulus_handler(sync=True) def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs): """Mark that a task has finished execution on a particular worker""" parent: SchedulerState = cast(SchedulerState, self) - logger.debug("Stimulus task finished %s, %s", key, worker) - recommendations: dict = {} - client_msgs: dict = {} - worker_msgs: 
dict = {} + with self.stimulus_id(stimulus_id or f"already-released-or-forgotten-{time()}"): + logger.debug("Stimulus task finished %s, %s", key, worker) - ws: WorkerState = parent._workers_dv[worker] - ts: TaskState = parent._tasks.get(key) - if ts is None or ts._state == "released": - logger.debug( - "Received already computed task, worker: %s, state: %s" - ", key: %s, who_has: %s", - worker, - ts._state if ts else "forgotten", - key, - ts._who_has if ts else {}, - ) - worker_msgs[worker] = [ - { - "op": "free-keys", - "keys": [key], - "stimulus_id": self.STIMULUS_ID.get( - f"already-released-or-forgotten-{time()}" - ), - } - ] - elif ts._state == "memory": - self.add_keys(worker=worker, keys=[key]) - else: - ts._metadata.update(kwargs["metadata"]) - r: tuple = parent._transition(key, "memory", worker=worker, **kwargs) - recommendations, client_msgs, worker_msgs = r + recommendations: dict = {} + client_msgs: dict = {} + worker_msgs: dict = {} + + ws: WorkerState = parent._workers_dv[worker] + ts: TaskState = parent._tasks.get(key) + if ts is None or ts._state == "released": + logger.debug( + "Received already computed task, worker: %s, state: %s" + ", key: %s, who_has: %s", + worker, + ts._state if ts else "forgotten", + key, + ts._who_has if ts else {}, + ) + worker_msgs[worker] = [ + { + "op": "free-keys", + "keys": [key], + "stimulus_id": self.STIMULUS_ID.get(), + } + ] + elif ts._state == "memory": + self.add_keys(worker=worker, keys=[key]) + else: + ts._metadata.update(kwargs["metadata"]) + r: tuple = parent._transition(key, "memory", worker=worker, **kwargs) + recommendations, client_msgs, worker_msgs = r - if ts._state == "memory": - assert ws in ts._who_has - return recommendations, client_msgs, worker_msgs + if ts._state == "memory": + assert ws in ts._who_has + return recommendations, client_msgs, worker_msgs - @stimulus_handler(sync=True) def stimulus_task_erred( self, key=None, @@ -5082,58 +5106,62 @@ def stimulus_task_erred( ): """Mark that a task has erred on a particular worker""" parent: SchedulerState = cast(SchedulerState, self) - logger.debug("Stimulus task erred %s, %s", key, worker) - ts: TaskState = parent._tasks.get(key) - if ts is None or ts._state != "processing": - return {}, {}, {} + with self.stimulus_id(stimulus_id): + logger.debug("Stimulus task erred %s, %s", key, worker) - if ts._retries > 0: - ts._retries -= 1 - return parent._transition(key, "waiting") - else: - return parent._transition( - key, - "erred", - cause=key, - exception=exception, - traceback=traceback, - worker=worker, - **kwargs, - ) + ts: TaskState = parent._tasks.get(key) + if ts is None or ts._state != "processing": + return {}, {}, {} + + if ts._retries > 0: + ts._retries -= 1 + return parent._transition(key, "waiting") + else: + return parent._transition( + key, + "erred", + cause=key, + exception=exception, + traceback=traceback, + worker=worker, + **kwargs, + ) - @stimulus_handler(sync=True) def stimulus_retry(self, keys, client=None, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) - logger.info("Client %s requests to retry %d keys", client, len(keys)) - if client: - self.log_event(client, {"action": "retry", "count": len(keys)}) - stack = list(keys) - seen = set() - roots = [] - ts: TaskState - dts: TaskState - while stack: - key = stack.pop() - seen.add(key) - ts = parent._tasks[key] - erred_deps = [dts._key for dts in ts._dependencies if dts._state == "erred"] - if erred_deps: - stack.extend(erred_deps) - else: - roots.append(key) + with 
self.stimulus_id(stimulus_id): + logger.info("Client %s requests to retry %d keys", client, len(keys)) + if client: + self.log_event(client, {"action": "retry", "count": len(keys)}) + + stack = list(keys) + seen = set() + roots = [] + ts: TaskState + dts: TaskState + while stack: + key = stack.pop() + seen.add(key) + ts = parent._tasks[key] + erred_deps = [ + dts._key for dts in ts._dependencies if dts._state == "erred" + ] + if erred_deps: + stack.extend(erred_deps) + else: + roots.append(key) - recommendations: dict = {key: "waiting" for key in roots} - self.transitions(recommendations) + recommendations: dict = {key: "waiting" for key in roots} + self.transitions(recommendations) - if parent._validate: - for key in seen: - assert not parent._tasks[key].exception_blame + if parent._validate: + for key in seen: + assert not parent._tasks[key].exception_blame - return tuple(seen) + return tuple(seen) - @stimulus_handler(sync=False) async def remove_worker(self, address, safe=False, close=True, stimulus_id=None): """ Remove worker from cluster @@ -5143,7 +5171,7 @@ async def remove_worker(self, address, safe=False, close=True, stimulus_id=None) state. """ parent: SchedulerState = cast(SchedulerState, self) - with log_errors(): + with log_errors(), self.stimulus_id(stimulus_id or f"remove-worker-{time()}"): if self.status == Status.closed: return @@ -5250,76 +5278,80 @@ def remove_worker_from_events(): return "OK" - @stimulus_handler(sync=True) def stimulus_cancel( self, comm, keys=None, client=None, force=False, stimulus_id=None ): """Stop execution on a list of keys""" - logger.info("Client %s requests to cancel %d keys", client, len(keys)) - if client: - self.log_event( - client, {"action": "cancel", "count": len(keys), "force": force} - ) - for key in keys: - self.cancel_key(key, client, force=force) + with self.stimulus_id(stimulus_id): + logger.info("Client %s requests to cancel %d keys", client, len(keys)) + if client: + self.log_event( + client, {"action": "cancel", "count": len(keys), "force": force} + ) + for key in keys: + self.cancel_key(key, client, force=force) - @stimulus_handler(sync=True) def cancel_key(self, key, client, retries=5, force=False, stimulus_id=None): """Cancel a particular key and all dependents""" # TODO: this should be converted to use the transition mechanism parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks.get(key) - dts: TaskState - try: - cs: ClientState = parent._clients[client] - except KeyError: - return - if ts is None or not ts._who_wants: # no key yet, lets try again in a moment - if retries: - self.loop.call_later( - 0.2, lambda: self.cancel_key(key, client, retries - 1) - ) - return - if force or ts._who_wants == {cs}: # no one else wants this key - for dts in list(ts._dependents): - self.cancel_key(dts._key, client, force=force) - logger.info("Scheduler cancels key %s. 
Force=%s", key, force) - self.report({"op": "cancelled-key", "key": key}) - clients = list(ts._who_wants) if force else [cs] - for cs in clients: - self.client_releases_keys(keys=[key], client=cs._client_key) - - @stimulus_handler(sync=True) + with self.stimulus_id(stimulus_id or f"cancel-key-{time()}"): + ts: TaskState = parent._tasks.get(key) + dts: TaskState + try: + cs: ClientState = parent._clients[client] + except KeyError: + return + if ( + ts is None or not ts._who_wants + ): # no key yet, lets try again in a moment + if retries: + self.loop.call_later( + 0.2, lambda: self.cancel_key(key, client, retries - 1) + ) + return + if force or ts._who_wants == {cs}: # no one else wants this key + for dts in list(ts._dependents): + self.cancel_key(dts._key, client, force=force) + logger.info("Scheduler cancels key %s. Force=%s", key, force) + self.report({"op": "cancelled-key", "key": key}) + clients = list(ts._who_wants) if force else [cs] + for cs in clients: + self.client_releases_keys(keys=[key], client=cs._client_key) + def client_desires_keys(self, keys=None, client=None, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) - cs: ClientState = parent._clients.get(client) - if cs is None: - # For publish, queues etc. - parent._clients[client] = cs = ClientState(client) - ts: TaskState - for k in keys: - ts = parent._tasks.get(k) - if ts is None: + with self.stimulus_id(stimulus_id or f"client-desires-keys-{time()}"): + cs: ClientState = parent._clients.get(client) + if cs is None: # For publish, queues etc. - ts = parent.new_task(k, None, "released") - ts._who_wants.add(cs) - cs._wants_what.add(ts) + parent._clients[client] = cs = ClientState(client) + ts: TaskState + for k in keys: + ts = parent._tasks.get(k) + if ts is None: + # For publish, queues etc. 
+ ts = parent.new_task(k, None, "released") + ts._who_wants.add(cs) + cs._wants_what.add(ts) - if ts._state in ("memory", "erred"): - self.report_on_key(ts=ts, client=client) + if ts._state in ("memory", "erred"): + self.report_on_key(ts=ts, client=client) - @stimulus_handler(sync=True) def client_releases_keys(self, keys=None, client=None, stimulus_id=None): """Remove keys from client desired list""" parent: SchedulerState = cast(SchedulerState, self) - if not isinstance(keys, list): - keys = list(keys) - cs: ClientState = parent._clients[client] - recommendations: dict = {} + with self.stimulus_id(stimulus_id or f"client-releases-keys-{time()}"): + if not isinstance(keys, list): + keys = list(keys) + cs: ClientState = parent._clients[client] + recommendations: dict = {} - _client_releases_keys(parent, keys=keys, cs=cs, recommendations=recommendations) - self.transitions(recommendations) + _client_releases_keys( + parent, keys=keys, cs=cs, recommendations=recommendations + ) + self.transitions(recommendations) def client_heartbeat(self, client=None): """Handle heartbeats from Client""" @@ -5642,36 +5674,37 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1): def handle_uncaught_error(self, **msg): logger.exception(clean_exception(**msg)[1]) - @stimulus_handler(sync=True) def handle_task_finished(self, key=None, worker=None, stimulus_id=None, **msg): parent: SchedulerState = cast(SchedulerState, self) - if worker not in parent._workers_dv: - return - validate_key(key) - recommendations: dict - client_msgs: dict - worker_msgs: dict + with self.stimulus_id(stimulus_id): + if worker not in parent._workers_dv: + return + validate_key(key) - r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg) - recommendations, client_msgs, worker_msgs = r - parent._transitions(recommendations, client_msgs, worker_msgs) + recommendations: dict + client_msgs: dict + worker_msgs: dict - self.send_all(client_msgs, worker_msgs) + r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg) + recommendations, client_msgs, worker_msgs = r + parent._transitions(recommendations, client_msgs, worker_msgs) + + self.send_all(client_msgs, worker_msgs) - @stimulus_handler(sync=True) def handle_task_erred(self, key=None, stimulus_id=None, **msg): parent: SchedulerState = cast(SchedulerState, self) recommendations: dict client_msgs: dict worker_msgs: dict - r: tuple = self.stimulus_task_erred(key=key, **msg) - recommendations, client_msgs, worker_msgs = r - parent._transitions(recommendations, client_msgs, worker_msgs) - self.send_all(client_msgs, worker_msgs) + with self.stimulus_id(stimulus_id): + r: tuple = self.stimulus_task_erred(key=key, **msg) + recommendations, client_msgs, worker_msgs = r + parent._transitions(recommendations, client_msgs, worker_msgs) + + self.send_all(client_msgs, worker_msgs) - @stimulus_handler(sync=True) def handle_missing_data( self, key=None, errant_worker=None, stimulus_id=None, **kwargs ): @@ -5690,39 +5723,42 @@ def handle_missing_data( Task key that could not be found, by default None errant_worker : str, optional Address of the worker supposed to hold a replica, by default None + stimulus_id : str, optional + Stimulus ID that generated this function call. 
""" parent: SchedulerState = cast(SchedulerState, self) logger.debug("handle missing data key=%s worker=%s", key, errant_worker) - self.log_event(errant_worker, {"action": "missing-data", "key": key}) - ts: TaskState = parent._tasks.get(key) - if ts is None: - return - ws: WorkerState = parent._workers_dv.get(errant_worker) - if ws is not None and ws in ts._who_has: - parent.remove_replica(ts, ws) - if ts.state == "memory" and not ts._who_has: - if ts._run_spec: - self.transitions({key: "released"}) - else: - self.transitions({key: "forgotten"}) + with self.stimulus_id(stimulus_id): + self.log_event(errant_worker, {"action": "missing-data", "key": key}) + ts: TaskState = parent._tasks.get(key) + if ts is None: + return + ws: WorkerState = parent._workers_dv.get(errant_worker) + + if ws is not None and ws in ts._who_has: + parent.remove_replica(ts, ws) + if ts.state == "memory" and not ts._who_has: + if ts._run_spec: + self.transitions({key: "released"}) + else: + self.transitions({key: "forgotten"}) - @stimulus_handler(sync=True) def release_worker_data(self, key, worker, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv.get(worker) - ts: TaskState = parent._tasks.get(key) - if not ws or not ts: - return - recommendations: dict = {} - if ws in ts._who_has: - parent.remove_replica(ts, ws) - if not ts._who_has: - recommendations[ts._key] = "released" - if recommendations: - self.transitions(recommendations) + with self.stimulus_id(stimulus_id or f"release-worker-data-{time()}"): + ws: WorkerState = parent._workers_dv.get(worker) + ts: TaskState = parent._tasks.get(key) + if not ws or not ts: + return + recommendations: dict = {} + if ws in ts._who_has: + parent.remove_replica(ts, ws) + if not ts._who_has: + recommendations[ts._key] = "released" + if recommendations: + self.transitions(recommendations) - @stimulus_handler(sync=True) def handle_long_running( self, key=None, worker=None, compute_duration=None, stimulus_id=None ): @@ -5732,80 +5768,87 @@ def handle_long_running( duration accounting as if the task has stopped. """ parent: SchedulerState = cast(SchedulerState, self) - if key not in parent._tasks: - logger.debug("Skipping long_running since key %s was already released", key) - return - ts: TaskState = parent._tasks[key] - steal = parent._extensions.get("stealing") - if steal is not None: - steal.remove_key_from_stealable(ts) - ws: WorkerState = ts._processing_on - if ws is None: - logger.debug("Received long-running signal from duplicate task. Ignoring.") - return + with self.stimulus_id(stimulus_id): + if key not in parent._tasks: + logger.debug( + "Skipping long_running since key %s was already released", key + ) + return + ts: TaskState = parent._tasks[key] + steal = parent._extensions.get("stealing") + if steal is not None: + steal.remove_key_from_stealable(ts) - if compute_duration: - old_duration: double = ts._prefix._duration_average - new_duration: double = compute_duration - avg_duration: double - if old_duration < 0: - avg_duration = new_duration - else: - avg_duration = 0.5 * old_duration + 0.5 * new_duration - - ts._prefix._duration_average = avg_duration - - occ: double = ws._processing[ts] - ws._occupancy -= occ - parent._total_occupancy -= occ - # Cannot remove from processing since we're using this for things like - # idleness detection. 
Idle workers are typically targeted for - # downscaling but we should not downscale workers with long running - # tasks - ws._processing[ts] = 0 - ws._long_running.add(ts) - self.check_idle_saturated(ws) + ws: WorkerState = ts._processing_on + if ws is None: + logger.debug( + "Received long-running signal from duplicate task. Ignoring." + ) + return + + if compute_duration: + old_duration: double = ts._prefix._duration_average + new_duration: double = compute_duration + avg_duration: double + if old_duration < 0: + avg_duration = new_duration + else: + avg_duration = 0.5 * old_duration + 0.5 * new_duration + + ts._prefix._duration_average = avg_duration + + occ: double = ws._processing[ts] + ws._occupancy -= occ + parent._total_occupancy -= occ + # Cannot remove from processing since we're using this for things like + # idleness detection. Idle workers are typically targeted for + # downscaling but we should not downscale workers with long running + # tasks + ws._processing[ts] = 0 + ws._long_running.add(ts) + self.check_idle_saturated(ws) - @stimulus_handler(sync=True) def handle_worker_status_change( - self, status: str, worker: str, stimulus_id=None + self, status: str, worker: str, stimulus_id: str ) -> None: parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv.get(worker) # type: ignore - if not ws: - return - prev_status = ws._status - ws._status = Status.lookup[status] # type: ignore - if ws._status == prev_status: - return - self.log_event( - ws._address, - { - "action": "worker-status-change", - "prev-status": prev_status.name, - "status": status, - }, - ) + with self.stimulus_id(stimulus_id): + ws: WorkerState = parent._workers_dv.get(worker) # type: ignore + if not ws: + return + prev_status = ws._status + ws._status = Status.lookup[status] # type: ignore + if ws._status == prev_status: + return - if ws._status == Status.running: - parent._running.add(ws) + self.log_event( + ws._address, + { + "action": "worker-status-change", + "prev-status": prev_status.name, + "status": status, + }, + ) - recs = {} - ts: TaskState - for ts in parent._unrunnable: - valid: set = self.valid_workers(ts) - if valid is None or ws in valid: - recs[ts._key] = "waiting" - if recs: - client_msgs: dict = {} - worker_msgs: dict = {} - parent._transitions(recs, client_msgs, worker_msgs) - self.send_all(client_msgs, worker_msgs) + if ws._status == Status.running: + parent._running.add(ws) - else: - parent._running.discard(ws) + recs = {} + ts: TaskState + for ts in parent._unrunnable: + valid: set = self.valid_workers(ts) + if valid is None or ws in valid: + recs[ts._key] = "waiting" + if recs: + client_msgs: dict = {} + worker_msgs: dict = {} + parent._transitions(recs, client_msgs, worker_msgs) + self.send_all(client_msgs, worker_msgs) + + else: + parent._running.discard(ws) async def handle_worker(self, comm=None, worker=None): """ @@ -6025,7 +6068,6 @@ def send_all(self, client_msgs: dict, worker_msgs: dict): # Less common interactions # ############################ - @stimulus_handler(sync=False) async def scatter( self, comm=None, @@ -6045,106 +6087,109 @@ async def scatter( parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState - start = time() - while True: - if workers is None: - wss = parent._running - else: - workers = [self.coerce_address(w) for w in workers] - wss = {parent._workers_dv[w] for w in workers} - wss = {ws for ws in wss if ws._status == Status.running} + with self.stimulus_id(stimulus_id or f"scatter-{time()}"): + start = time() + 
while True: + if workers is None: + wss = parent._running + else: + workers = [self.coerce_address(w) for w in workers] + wss = {parent._workers_dv[w] for w in workers} + wss = {ws for ws in wss if ws._status == Status.running} - if wss: - break - if time() > start + timeout: - raise TimeoutError("No valid workers found") - await asyncio.sleep(0.1) + if wss: + break + if time() > start + timeout: + raise TimeoutError("No valid workers found") + await asyncio.sleep(0.1) - nthreads = {ws._address: ws.nthreads for ws in wss} + nthreads = {ws._address: ws.nthreads for ws in wss} - assert isinstance(data, dict) + assert isinstance(data, dict) - keys, who_has, nbytes = await scatter_to_workers( - nthreads, data, rpc=self.rpc, report=False - ) + keys, who_has, nbytes = await scatter_to_workers( + nthreads, data, rpc=self.rpc, report=False + ) - self.update_data(who_has=who_has, nbytes=nbytes, client=client) + self.update_data(who_has=who_has, nbytes=nbytes, client=client) - if broadcast: - n = len(nthreads) if broadcast is True else broadcast - await self.replicate(keys=keys, workers=workers, n=n) + if broadcast: + n = len(nthreads) if broadcast is True else broadcast + await self.replicate(keys=keys, workers=workers, n=n) - self.log_event( - [client, "all"], {"action": "scatter", "client": client, "count": len(data)} - ) - return keys + self.log_event( + [client, "all"], + {"action": "scatter", "client": client, "count": len(data)}, + ) + return keys - @stimulus_handler(sync=False) async def gather(self, keys, serializers=None, stimulus_id=None): """Collect data from workers to the scheduler""" parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState - keys = list(keys) - who_has = {} - for key in keys: - ts: TaskState = parent._tasks.get(key) - if ts is not None: - who_has[key] = [ws._address for ws in ts._who_has] - else: - who_has[key] = [] - data, missing_keys, missing_workers = await gather_from_workers( - who_has, rpc=self.rpc, close=False, serializers=serializers - ) - if not missing_keys: - result = {"status": "OK", "data": data} - else: - missing_states = [ - (parent._tasks[key].state if key in parent._tasks else None) - for key in missing_keys - ] - logger.exception( - "Couldn't gather keys %s state: %s workers: %s", - missing_keys, - missing_states, - missing_workers, + with self.stimulus_id(stimulus_id or f"gather-{time()}"): + keys = list(keys) + who_has = {} + for key in keys: + ts: TaskState = parent._tasks.get(key) + if ts is not None: + who_has[key] = [ws._address for ws in ts._who_has] + else: + who_has[key] = [] + + data, missing_keys, missing_workers = await gather_from_workers( + who_has, rpc=self.rpc, close=False, serializers=serializers ) - result = {"status": "error", "keys": missing_keys} - with log_errors(): - # Remove suspicious workers from the scheduler but allow them to - # reconnect. 
- await asyncio.gather( - *( - self.remove_worker(address=worker, close=False) - for worker in missing_workers - ) + if not missing_keys: + result = {"status": "OK", "data": data} + else: + missing_states = [ + (parent._tasks[key].state if key in parent._tasks else None) + for key in missing_keys + ] + logger.exception( + "Couldn't gather keys %s state: %s workers: %s", + missing_keys, + missing_states, + missing_workers, ) - recommendations: dict - client_msgs: dict = {} - worker_msgs: dict = {} - for key, workers in missing_keys.items(): - # Task may already be gone if it was held by a - # `missing_worker` - ts: TaskState = parent._tasks.get(key) - logger.exception( - "Workers don't have promised key: %s, %s", - str(workers), - str(key), + result = {"status": "error", "keys": missing_keys} + with log_errors(): + # Remove suspicious workers from the scheduler but allow them to + # reconnect. + await asyncio.gather( + *( + self.remove_worker(address=worker, close=False) + for worker in missing_workers + ) ) - if not workers or ts is None: - continue - recommendations: dict = {key: "released"} - for worker in workers: - ws = parent._workers_dv.get(worker) - if ws is not None and ws in ts._who_has: - parent.remove_replica(ts, ws) - parent._transitions( - recommendations, client_msgs, worker_msgs - ) - self.send_all(client_msgs, worker_msgs) + recommendations: dict + client_msgs: dict = {} + worker_msgs: dict = {} + for key, workers in missing_keys.items(): + # Task may already be gone if it was held by a + # `missing_worker` + ts: TaskState = parent._tasks.get(key) + logger.exception( + "Workers don't have promised key: %s, %s", + str(workers), + str(key), + ) + if not workers or ts is None: + continue + recommendations: dict = {key: "released"} + for worker in workers: + ws = parent._workers_dv.get(worker) + if ws is not None and ws in ts._who_has: + parent.remove_replica(ts, ws) + parent._transitions( + recommendations, client_msgs, worker_msgs + ) + self.send_all(client_msgs, worker_msgs) - self.log_event("all", {"action": "gather", "count": len(keys)}) - return result + self.log_event("all", {"action": "gather", "count": len(keys)}) + return result def clear_task_state(self): # XXX what about nested state such as ClientState.wants_what @@ -6153,11 +6198,10 @@ def clear_task_state(self): for collection in self._task_state_collections: collection.clear() - @stimulus_handler(sync=False) async def restart(self, client=None, timeout=30, stimulus_id=None): """Restart all workers. Reset local state.""" parent: SchedulerState = cast(SchedulerState, self) - with log_errors(): + with log_errors(), self.stimulus_id(stimulus_id or f"restart-{time()}"): n_workers = len(parent._workers_dv) @@ -6370,7 +6414,6 @@ async def gather_on_worker( return keys_failed - @stimulus_handler(sync=False) async def delete_worker_data( self, worker_address: str, @@ -6388,37 +6431,38 @@ async def delete_worker_data( """ parent: SchedulerState = cast(SchedulerState, self) - try: - await retry_operation( - self.rpc(addr=worker_address).free_keys, - keys=list(keys), - stimulus_id=f"delete-data-{time()}", - ) - except OSError as e: - # This can happen e.g. 
if the worker is going through controlled shutdown; - # it doesn't necessarily mean that it went unexpectedly missing - logger.warning( - f"Communication with worker {worker_address} failed during " - f"replication: {e.__class__.__name__}: {e}" - ) - return + with self.stimulus_id(stimulus_id or f"delete-worker-data-{time()}"): - ws: WorkerState = parent._workers_dv.get(worker_address) # type: ignore - if ws is None: - return + try: + await retry_operation( + self.rpc(addr=worker_address).free_keys, + keys=list(keys), + stimulus_id=f"delete-data-{time()}", + ) + except OSError as e: + # This can happen e.g. if the worker is going through controlled shutdown; + # it doesn't necessarily mean that it went unexpectedly missing + logger.warning( + f"Communication with worker {worker_address} failed during " + f"replication: {e.__class__.__name__}: {e}" + ) + return - for key in keys: - ts: TaskState = parent._tasks.get(key) # type: ignore - if ts is not None and ws in ts._who_has: - assert ts._state == "memory" - parent.remove_replica(ts, ws) - if not ts._who_has: - # Last copy deleted - self.transitions({key: "released"}) + ws: WorkerState = parent._workers_dv.get(worker_address) # type: ignore + if ws is None: + return - self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys}) + for key in keys: + ts: TaskState = parent._tasks.get(key) # type: ignore + if ts is not None and ws in ts._who_has: + assert ts._state == "memory" + parent.remove_replica(ts, ws) + if not ts._who_has: + # Last copy deleted + self.transitions({key: "released"}) + + self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys}) - @stimulus_handler(sync=False) async def rebalance( self, comm=None, @@ -6494,7 +6538,7 @@ async def rebalance( Stimulus ID that caused this function call """ parent: SchedulerState = cast(SchedulerState, self) - with log_errors(): + with log_errors(), self.stimulus_id(stimulus_id or f"rebalance-{time()}"): wss: "Collection[WorkerState]" if workers is not None: wss = [parent._workers_dv[w] for w in workers] @@ -6786,7 +6830,6 @@ async def _rebalance_move_data( else: return {"status": "OK"} - @stimulus_handler(sync=False) async def replicate( self, comm=None, @@ -6826,86 +6869,89 @@ async def replicate( wws: WorkerState ts: TaskState - assert branching_factor > 0 - async with self._lock if lock else empty_context: - if workers is not None: - workers = {parent._workers_dv[w] for w in self.workers_list(workers)} - workers = {ws for ws in workers if ws._status == Status.running} - else: - workers = parent._running - - if n is None: - n = len(workers) - else: - n = min(n, len(workers)) - if n == 0: - raise ValueError("Can not use replicate to delete data") - - tasks = {parent._tasks[k] for k in keys} - missing_data = [ts._key for ts in tasks if not ts._who_has] - if missing_data: - return {"status": "partial-fail", "keys": missing_data} - - # Delete extraneous data - if delete: - del_worker_tasks = defaultdict(set) - for ts in tasks: - del_candidates = tuple(ts._who_has & workers) - if len(del_candidates) > n: - for ws in random.sample( - del_candidates, len(del_candidates) - n - ): - del_worker_tasks[ws].add(ts) - - # Note: this never raises exceptions - await asyncio.gather( - *[ - self.delete_worker_data(ws._address, [t.key for t in tasks]) - for ws, tasks in del_worker_tasks.items() - ] - ) + with self.stimulus_id(stimulus_id or f"replicate-{time()}"): + assert branching_factor > 0 + async with self._lock if lock else empty_context: + if workers is not None: + 
workers = { + parent._workers_dv[w] for w in self.workers_list(workers) + } + workers = {ws for ws in workers if ws._status == Status.running} + else: + workers = parent._running - # Copy not-yet-filled data - while tasks: - gathers = defaultdict(dict) - for ts in list(tasks): - if ts._state == "forgotten": - # task is no longer needed by any client or dependant task - tasks.remove(ts) - continue - n_missing = n - len(ts._who_has & workers) - if n_missing <= 0: - # Already replicated enough - tasks.remove(ts) - continue + if n is None: + n = len(workers) + else: + n = min(n, len(workers)) + if n == 0: + raise ValueError("Can not use replicate to delete data") - count = min(n_missing, branching_factor * len(ts._who_has)) - assert count > 0 + tasks = {parent._tasks[k] for k in keys} + missing_data = [ts._key for ts in tasks if not ts._who_has] + if missing_data: + return {"status": "partial-fail", "keys": missing_data} - for ws in random.sample(tuple(workers - ts._who_has), count): - gathers[ws._address][ts._key] = [ - wws._address for wws in ts._who_has + # Delete extraneous data + if delete: + del_worker_tasks = defaultdict(set) + for ts in tasks: + del_candidates = tuple(ts._who_has & workers) + if len(del_candidates) > n: + for ws in random.sample( + del_candidates, len(del_candidates) - n + ): + del_worker_tasks[ws].add(ts) + + # Note: this never raises exceptions + await asyncio.gather( + *[ + self.delete_worker_data(ws._address, [t.key for t in tasks]) + for ws, tasks in del_worker_tasks.items() ] + ) - await asyncio.gather( - *( - # Note: this never raises exceptions - self.gather_on_worker(w, who_has) - for w, who_has in gathers.items() + # Copy not-yet-filled data + while tasks: + gathers = defaultdict(dict) + for ts in list(tasks): + if ts._state == "forgotten": + # task is no longer needed by any client or dependant task + tasks.remove(ts) + continue + n_missing = n - len(ts._who_has & workers) + if n_missing <= 0: + # Already replicated enough + tasks.remove(ts) + continue + + count = min(n_missing, branching_factor * len(ts._who_has)) + assert count > 0 + + for ws in random.sample(tuple(workers - ts._who_has), count): + gathers[ws._address][ts._key] = [ + wws._address for wws in ts._who_has + ] + + await asyncio.gather( + *( + # Note: this never raises exceptions + self.gather_on_worker(w, who_has) + for w, who_has in gathers.items() + ) ) - ) - for r, v in gathers.items(): - self.log_event(r, {"action": "replicate-add", "who_has": v}) + for r, v in gathers.items(): + self.log_event(r, {"action": "replicate-add", "who_has": v}) - self.log_event( - "all", - { - "action": "replicate", - "workers": list(workers), - "key-count": len(keys), - "branching-factor": branching_factor, - }, - ) + self.log_event( + "all", + { + "action": "replicate", + "workers": list(workers), + "key-count": len(keys), + "branching-factor": branching_factor, + }, + ) def workers_to_close( self, @@ -7043,7 +7089,6 @@ def _key(group): return result - @stimulus_handler(sync=False) async def retire_workers( self, comm=None, @@ -7091,7 +7136,7 @@ async def retire_workers( parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState ts: TaskState - with log_errors(): + with log_errors(), self.stimulus_id(stimulus_id or f"retire-workers-{time()}"): # This lock makes retire_workers, rebalance, and replicate mutually # exclusive and will no longer be necessary once rebalance and replicate are # migrated to the Active Memory Manager. 
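Note on the pattern applied in the hunks above: each public handler wraps its body in a stimulus-id context so that the outermost handler on the call stack establishes the id and nested handler calls reuse it rather than minting their own. The snippet below is a minimal, self-contained sketch of that set-if-unset behaviour using only the standard library; the names (`default_stimulus`, `_STIMULUS`) are illustrative and are not part of distributed's API.

    from contextlib import contextmanager
    from contextvars import ContextVar

    _STIMULUS: ContextVar[str] = ContextVar("stimulus_id")  # illustrative, not the project's variable

    @contextmanager
    def default_stimulus(name: str):
        # Only set `name` if nothing further up the call stack has set an id yet.
        try:
            _STIMULUS.get()
            token = None              # already set: keep the outer id
        except LookupError:
            token = _STIMULUS.set(name)
        try:
            yield _STIMULUS.get()
        finally:
            if token is not None:
                _STIMULUS.reset(token)

    with default_stimulus("retire-workers-1"):
        with default_stimulus("replicate-2"):      # nested handler: its id is ignored
            assert _STIMULUS.get() == "retire-workers-1"
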
@@ -7221,7 +7266,6 @@ async def _track_retire_worker( logger.info("Retired worker %s", ws._address) return ws._address, ws.identity() - @stimulus_handler(sync=True) def add_keys(self, worker=None, keys=(), stimulus_id=None): """ Learn that a worker has certain keys @@ -7230,33 +7274,33 @@ def add_keys(self, worker=None, keys=(), stimulus_id=None): reasons. However, it is sent by workers from time to time. """ parent: SchedulerState = cast(SchedulerState, self) - if worker not in parent._workers_dv: - return "not found" - ws: WorkerState = parent._workers_dv[worker] - redundant_replicas = [] - for key in keys: - ts: TaskState = parent._tasks.get(key) - if ts is not None and ts._state == "memory": - if ws not in ts._who_has: - parent.add_replica(ts, ws) - else: - redundant_replicas.append(key) + with self.stimulus_id(stimulus_id or f"add-keys-{time()}"): + if worker not in parent._workers_dv: + return "not found" + ws: WorkerState = parent._workers_dv[worker] + redundant_replicas = [] + for key in keys: + ts: TaskState = parent._tasks.get(key) + if ts is not None and ts._state == "memory": + if ws not in ts._who_has: + parent.add_replica(ts, ws) + else: + redundant_replicas.append(key) - if redundant_replicas: - self.worker_send( - worker, - { - "op": "remove-replicas", - "keys": redundant_replicas, - "stimulus_id": self.STIMULUS_ID.get( - stimulus_id or f"redundant-replicas-{time()}" - ), - }, - ) + if redundant_replicas: + self.worker_send( + worker, + { + "op": "remove-replicas", + "keys": redundant_replicas, + "stimulus_id": self.STIMULUS_ID.get( + stimulus_id or f"redundant-replicas-{time()}" + ), + }, + ) - return "OK" + return "OK" - @stimulus_handler(sync=True) def update_data( self, *, who_has: dict, nbytes: dict, client=None, stimulus_id=None ): @@ -7268,7 +7312,7 @@ def update_data( Scheduler.mark_key_in_memory """ parent: SchedulerState = cast(SchedulerState, self) - with log_errors(): + with log_errors(), self.stimulus_id(stimulus_id or f"update-data-{time()}"): who_has = { k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items() } @@ -7734,7 +7778,6 @@ def story(self, *keys): transition_story = story - @stimulus_handler(sync=True) def reschedule(self, key=None, worker=None, stimulus_id=None): """Reschedule a task @@ -7743,19 +7786,20 @@ def reschedule(self, key=None, worker=None, stimulus_id=None): """ parent: SchedulerState = cast(SchedulerState, self) ts: TaskState - try: - ts = parent._tasks[key] - except KeyError: - logger.warning( - "Attempting to reschedule task {}, which was not " - "found on the scheduler. Aborting reschedule.".format(key) - ) - return - if ts._state != "processing": - return - if worker and ts._processing_on.address != worker: - return - self.transitions({key: "released"}) + with self.stimulus_id(stimulus_id or f"reschedule-{key}"): + try: + ts = parent._tasks[key] + except KeyError: + logger.warning( + "Attempting to reschedule task {}, which was not " + "found on the scheduler. 
Aborting reschedule.".format(key) + ) + return + if ts._state != "processing": + return + if worker and ts._processing_on.address != worker: + return + self.transitions({key: "released"}) ##################### # Utility functions # diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 43d1ef38f52..e815b106fc7 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -3565,10 +3565,10 @@ async def test_stimuli(c, s, a, b): ) stimuli = [ - "update_graph_hlg", - "update_graph_hlg", + "update-graph-hlg", + "update-graph-hlg", "task-finished", - "client_releases_keys", + "client-releases-keys", ] stories = s.story(key) diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index d16c4540594..9d2b6f7794f 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -1,107 +1,22 @@ import asyncio -import inspect import logging import random from collections import defaultdict -from functools import partial, wraps +from functools import partial from itertools import cycle from tlz import concat, drop, groupby, merge import dask.config from dask.optimization import SubgraphCallable -from dask.utils import funcname, parse_timedelta, stringify +from dask.utils import parse_timedelta, stringify from distributed.core import rpc -from distributed.metrics import time from distributed.utils import All logger = logging.getLogger(__name__) -def stimulus_handler(*args, sync: bool = True): - """Decorator controlling injection into RPC Handlers - - RPC Handler functions are entrypoints into the distributed Scheduler. - These entrypoints may receive stimuli from external entities such - as workers in the ``stimulus_id`` kwarg or they - may generate stimuli themselves. - A further complication is that RPC Handlers may call other RPC handler - functions. - - This decorator exists to simplify the setting of - the Scheduler STIMULUS_ID and encapsulates the following logic - - 1. If the STIMULUS_ID is already set, stimuli from other sources - are ignored. - 2. If a ``stimulus_id`` kwargs is supplied by an external entity - such as a worker, the STIMULUS_ID is set to this value. - 3. Otherwise, the STIMULUS_ID is generated from the function name and - current time. - - Parameters - ---------- - *args : tuple - If the decorator is called without keyword arguments it will - be assumed that the decorated function is in ``args[0]``. - Otherwise should be empty if call with keyword arguments. - sync : bool - Indicates whether function is sync or async. - Necessary to distinguish between sync and async stimulus handlers - in a cython environment. 
https://bugs.python.org/issue38225 - """ - - def decorator(fn): - name = funcname(fn) - params = list(inspect.signature(fn).parameters.values()) - if params[0].name != "self": - raise ValueError(f"{fn} must be a method") # pragma: nocover - - if sync: - - @wraps(fn) - def wrapper(self, *args, **kw): - STIMULUS_ID = self.STIMULUS_ID - - try: - STIMULUS_ID.get() - except LookupError: - stimulus_id = kw.get("stimulus_id", None) or f"{name}-{time()}" - token = STIMULUS_ID.set(stimulus_id) - else: - token = None - - try: - return fn(self, *args, **kw) - finally: - if token: - STIMULUS_ID.reset(token) - - else: - - @wraps(fn) - async def wrapper(self, *args, **kw): - STIMULUS_ID = self.STIMULUS_ID - - try: - STIMULUS_ID.get() - except LookupError: - stimulus_id = kw.get("stimulus_id", None) or f"{name}-{time()}" - token = STIMULUS_ID.set(stimulus_id) - else: - token = None - - try: - return await fn(self, *args, **kw) - finally: - if token: - STIMULUS_ID.reset(token) - - return wrapper - - return decorator(args[0]) if args and callable(args[0]) else decorator - - async def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None): """Gather data directly from peers From 046b49aee33c0f8247c7d0e7d2b6ba4efd18b316 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Tue, 5 Apr 2022 13:20:19 +0200 Subject: [PATCH 25/29] Remove default stimulus_id's throughout the scheduler --- distributed/scheduler.py | 52 ++++++++++++++------------------ distributed/stealing.py | 3 +- distributed/tests/test_worker.py | 4 ++- 3 files changed, 27 insertions(+), 32 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 83ed59c9dfd..1675bb05662 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2768,9 +2768,7 @@ def transition_waiting_processing(self, key): # logger.debug("Send job to worker: %s, %s", worker, key) - worker_msgs[worker] = [ - _task_to_msg(self, ts, self.STIMULUS_ID.get(f"compute-task-{time()}")) - ] + worker_msgs[worker] = [_task_to_msg(self, ts, self.STIMULUS_ID.get())] return recommendations, client_msgs, worker_msgs except Exception as e: @@ -2871,9 +2869,7 @@ def transition_processing_memory( { "op": "cancel-compute", "key": key, - "stimulus_id": scheduler.STIMULUS_ID.get( - f"processing-memory-{time()}" - ), + "stimulus_id": scheduler.STIMULUS_ID.get(), } ] @@ -2965,7 +2961,7 @@ def transition_memory_released(self, key, safe: bint = False): worker_msg = { "op": "free-keys", "keys": [key], - "stimulus_id": scheduler.STIMULUS_ID.get(f"memory-released-{time()}"), + "stimulus_id": scheduler.STIMULUS_ID.get(), } for ws in ts._who_has: worker_msgs[ws._address] = [worker_msg] @@ -3068,7 +3064,7 @@ def transition_erred_released(self, key): w_msg = { "op": "free-keys", "keys": [key], - "stimulus_id": self.STIMULUS_ID.get(f"erred-released-{time()}"), + "stimulus_id": self.STIMULUS_ID.get(), } for ws_addr in ts._erred_on: worker_msgs[ws_addr] = [w_msg] @@ -3147,9 +3143,7 @@ def transition_processing_released(self, key): { "op": "free-keys", "keys": [key], - "stimulus_id": self.STIMULUS_ID.get( - f"processing-released-{time()}" - ), + "stimulus_id": self.STIMULUS_ID.get(), } ] @@ -3345,7 +3339,7 @@ def transition_memory_forgotten(self, key): ts, recommendations, worker_msgs, - self.STIMULUS_ID.get(f"propagate-forgotten-{time()}"), + self.STIMULUS_ID.get(), ) client_msgs = _task_to_client_msgs(self, ts) @@ -3389,7 +3383,7 @@ def transition_released_forgotten(self, key): ts, recommendations, worker_msgs, - 
self.STIMULUS_ID.get(f"propagate-forgotten-{time()}"), + self.STIMULUS_ID.get(), ) client_msgs = _task_to_client_msgs(self, ts) @@ -5568,12 +5562,13 @@ async def add_client(self, comm: Comm, client: str, versions: dict) -> None: We listen to all future messages from this Comm. """ parent: SchedulerState = cast(SchedulerState, self) - try: - stimulus_id = self.STIMULUS_ID.get() - except LookupError: - pass - else: - if self._validate: + + if self._validate: + try: + stimulus_id = self.STIMULUS_ID.get() + except LookupError: + pass + else: raise RuntimeError( f"STIMULUS_ID {stimulus_id} set in Scheduler.add_client" ) @@ -5659,9 +5654,7 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1): """Send a single computational task to a worker""" parent: SchedulerState = cast(SchedulerState, self) try: - msg: dict = _task_to_msg( - parent, ts, self.STIMULUS_ID.get(f"compute-task-{time()}"), duration - ) + msg: dict = _task_to_msg(parent, ts, self.STIMULUS_ID.get(), duration) self.worker_send(worker, msg) except Exception as e: logger.exception(e) @@ -5860,12 +5853,12 @@ async def handle_worker(self, comm=None, worker=None): -------- Scheduler.handle_client: Equivalent coroutine for clients """ - try: - stimulus_id = self.STIMULUS_ID.get() - except LookupError: - pass - else: - if self._validate: + if self._validate: + try: + stimulus_id = self.STIMULUS_ID.get() + except LookupError: + pass + else: raise RuntimeError( f"STIMULUS_ID {stimulus_id} set in Scheduler.handle_worker" ) @@ -6418,7 +6411,6 @@ async def delete_worker_data( self, worker_address: str, keys: "Collection[str]", - stimulus_id=None, ) -> None: """Delete data from a worker and update the corresponding worker/task states @@ -6431,7 +6423,7 @@ async def delete_worker_data( """ parent: SchedulerState = cast(SchedulerState, self) - with self.stimulus_id(stimulus_id or f"delete-worker-data-{time()}"): + with self.stimulus_id(f"delete-worker-data-{time()}"): try: await retry_operation( diff --git a/distributed/stealing.py b/distributed/stealing.py index 54ef0098c63..41a44670b0f 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -333,7 +333,8 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None): self.scheduler.total_occupancy += d["thief_duration"] self.put_key_in_stealable(ts) - self.scheduler.send_task_to_worker(thief.address, ts) + with self.scheduler.stimulus_id(stimulus_id): + self.scheduler.send_task_to_worker(thief.address, ts) self.log(("confirm", *_log_msg)) else: raise ValueError(f"Unexpected task state: {state}") diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 451426f358c..dbe58f5dd7e 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2547,7 +2547,9 @@ def __call__(self, *args, **kwargs): ts = s.tasks[fut.key] a.handle_steal_request(fut.key, stimulus_id="test") - stealing_ext.scheduler.send_task_to_worker(b.address, ts) + + with s.stimulus_id("test"): + stealing_ext.scheduler.send_task_to_worker(b.address, ts) fut2 = c.submit(inc, fut, workers=[a.address]) fut3 = c.submit(inc, fut2, workers=[a.address]) From 9241c7261aa6cdee3e5bdfd2bcee1acb06908a93 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Tue, 5 Apr 2022 13:26:55 +0200 Subject: [PATCH 26/29] RuntimeError -> AssertionError --- distributed/scheduler.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 1675bb05662..5a489474318 100644 --- 
a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5566,13 +5566,13 @@ async def add_client(self, comm: Comm, client: str, versions: dict) -> None: if self._validate: try: stimulus_id = self.STIMULUS_ID.get() - except LookupError: - pass - else: - raise RuntimeError( + raise AssertionError( f"STIMULUS_ID {stimulus_id} set in Scheduler.add_client" ) + except LookupError: + pass + assert client is not None comm.name = "Scheduler->Client" logger.info("Receive client connection: %s", client) @@ -5856,12 +5856,11 @@ async def handle_worker(self, comm=None, worker=None): if self._validate: try: stimulus_id = self.STIMULUS_ID.get() + raise AssertionError( + f"STIMULUS_ID {stimulus_id} set in Scheduler.add_client" + ) except LookupError: pass - else: - raise RuntimeError( - f"STIMULUS_ID {stimulus_id} set in Scheduler.handle_worker" - ) comm.name = "Scheduler connection to worker" worker_comm = self.stream_comms[worker] From 789080d42220ef6e95d5bbc7a44dbfa35824eee4 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 5 Apr 2022 15:25:24 +0200 Subject: [PATCH 27/29] Move STIMULUS_ID ctxvar to utils --- distributed/client.py | 50 +- distributed/scheduler.py | 283 ++--- distributed/stealing.py | 4 +- distributed/tests/test_failed_workers.py | 5 +- distributed/tests/test_scheduler.py | 4 +- distributed/tests/test_worker.py | 4 +- .../tests/test_worker_state_machine.py | 1 + distributed/utils.py | 34 +- distributed/worker.py | 1064 +++++++++-------- distributed/worker_client.py | 15 +- distributed/worker_state_machine.py | 49 +- 11 files changed, 799 insertions(+), 714 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index af974ded4d9..081aaf661c4 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -81,6 +81,7 @@ from distributed.sizeof import sizeof from distributed.threadpoolexecutor import rejoin from distributed.utils import ( + STIMULUS_ID, All, Any, CancelledError, @@ -93,6 +94,7 @@ import_term, log_errors, no_default, + set_default_stimulus, sync, thread_state, ) @@ -472,6 +474,7 @@ def __setstate__(self, state): "tasks": {}, "keys": [stringify(self.key)], "client": c.id, + "stimulus_id": f"client-update-graph-{time()}", } ) @@ -1367,13 +1370,13 @@ def _inc_ref(self, key): self.refcount[key] += 1 def _dec_ref(self, key): - with self._refcount_lock: + with self._refcount_lock, set_default_stimulus(f"client-dec-ref-{time()}"): self.refcount[key] -= 1 if self.refcount[key] == 0: del self.refcount[key] - self._release_key(key, f"client-release-key-{time()}") + self._release_key(key) - def _release_key(self, key, stimulus_id: str): + def _release_key(self, key): """Release key from distributed memory""" logger.debug("Release key %s", key) st = self.futures.pop(key, None) @@ -1385,7 +1388,7 @@ def _release_key(self, key, stimulus_id: str): "op": "client-releases-keys", "keys": [key], "client": self.id, - "stimulus_id": stimulus_id, + "stimulus_id": STIMULUS_ID.get(), } ) @@ -1532,33 +1535,33 @@ async def _close(self, fast=False): ): await self.scheduler_comm.close() - stimulus_id = f"client-close-{time()}" + with set_default_stimulus(f"client-close-{time()}"): - for key in list(self.futures): - self._release_key(key=key, stimulus_id=stimulus_id) + for key in list(self.futures): + self._release_key(key=key) - if self._start_arg is None: - with suppress(AttributeError): - await self.cluster.close() + if self._start_arg is None: + with suppress(AttributeError): + await self.cluster.close() - await self.rpc.close() + await self.rpc.close() - self.status = 
"closed" + self.status = "closed" - if _get_global_client() is self: - _set_global_client(None) + if _get_global_client() is self: + _set_global_client(None) - if ( - handle_report_task is not None - and handle_report_task is not current_task - ): - with suppress(TimeoutError, asyncio.CancelledError): - await asyncio.wait_for(handle_report_task, 0 if fast else 2) + if ( + handle_report_task is not None + and handle_report_task is not current_task + ): + with suppress(TimeoutError, asyncio.CancelledError): + await asyncio.wait_for(handle_report_task, 0 if fast else 2) - with suppress(AttributeError): - await self.scheduler.close_rpc() + with suppress(AttributeError): + await self.scheduler.close_rpc() - self.scheduler = None + self.scheduler = None self.status = "closed" @@ -2947,6 +2950,7 @@ def _graph_to_futures( "fifo_timeout": fifo_timeout, "actors": actors, "code": self._get_computation_code(), + "stimulus_id": f"client-update-graph-hlg-{time()}", } ) return futures diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 5a489474318..7cd12a0b9d4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -24,8 +24,7 @@ Mapping, Set, ) -from contextlib import contextmanager, suppress -from contextvars import ContextVar +from contextlib import suppress from datetime import timedelta from functools import partial from numbers import Number @@ -86,6 +85,7 @@ from distributed.stealing import WorkStealing from distributed.stories import scheduler_story from distributed.utils import ( + STIMULUS_ID, All, TimeoutError, empty_context, @@ -95,6 +95,7 @@ log_errors, no_default, recursive_to_dict, + set_default_stimulus, validate_key, ) from distributed.utils_comm import ( @@ -2374,7 +2375,7 @@ def _transition(self, key, finish: str, *args, **kwargs): # FIXME downcast antipattern scheduler = pep484_cast(Scheduler, self) - stimulus_id = scheduler.STIMULUS_ID.get(Scheduler.STIMULUS_ID_NOT_SET) + stimulus_id = STIMULUS_ID.get(Scheduler.STIMULUS_ID_NOT_SET) finish2 = ts._state scheduler.transition_log.append( @@ -2382,7 +2383,7 @@ def _transition(self, key, finish: str, *args, **kwargs): ) if parent._validate: if stimulus_id == Scheduler.STIMULUS_ID_NOT_SET: - raise LookupError(scheduler.STIMULUS_ID.name) + raise LookupError(STIMULUS_ID.name) logger.debug( "Transitioned %r %s->%s (actual: %s) from %s. Consequence: %s", @@ -2768,7 +2769,7 @@ def transition_waiting_processing(self, key): # logger.debug("Send job to worker: %s, %s", worker, key) - worker_msgs[worker] = [_task_to_msg(self, ts, self.STIMULUS_ID.get())] + worker_msgs[worker] = [_task_to_msg(self, ts)] return recommendations, client_msgs, worker_msgs except Exception as e: @@ -2862,14 +2863,11 @@ def transition_processing_memory( key, ) - # FIXME downcast antipattern - scheduler = pep484_cast(Scheduler, self) - worker_msgs[ts._processing_on.address] = [ { "op": "cancel-compute", "key": key, - "stimulus_id": scheduler.STIMULUS_ID.get(), + "stimulus_id": STIMULUS_ID.get(), } ] @@ -2953,15 +2951,12 @@ def transition_memory_released(self, key, safe: bint = False): elif dts._state == "waiting": dts._waiting_on.add(ts) - # FIXME downcast antipattern - scheduler = pep484_cast(Scheduler, self) - # XXX factor this out? 
worker_msg = { "op": "free-keys", "keys": [key], - "stimulus_id": scheduler.STIMULUS_ID.get(), + "stimulus_id": STIMULUS_ID.get(), } for ws in ts._who_has: worker_msgs[ws._address] = [worker_msg] @@ -3064,7 +3059,7 @@ def transition_erred_released(self, key): w_msg = { "op": "free-keys", "keys": [key], - "stimulus_id": self.STIMULUS_ID.get(), + "stimulus_id": STIMULUS_ID.get(), } for ws_addr in ts._erred_on: worker_msgs[ws_addr] = [w_msg] @@ -3143,7 +3138,7 @@ def transition_processing_released(self, key): { "op": "free-keys", "keys": [key], - "stimulus_id": self.STIMULUS_ID.get(), + "stimulus_id": STIMULUS_ID.get(), } ] @@ -3339,7 +3334,7 @@ def transition_memory_forgotten(self, key): ts, recommendations, worker_msgs, - self.STIMULUS_ID.get(), + STIMULUS_ID.get(), ) client_msgs = _task_to_client_msgs(self, ts) @@ -3383,7 +3378,7 @@ def transition_released_forgotten(self, key): ts, recommendations, worker_msgs, - self.STIMULUS_ID.get(), + STIMULUS_ID.get(), ) client_msgs = _task_to_client_msgs(self, ts) @@ -4011,8 +4006,6 @@ def __init__( "benchmark_hardware": self.benchmark_hardware, } - self.STIMULUS_ID = ContextVar(f"stimulus_id-{uuid.uuid4().hex}") - connection_limit = get_fileno_limit() / 2 super().__init__( @@ -4080,32 +4073,6 @@ def _repr_html_(self): tasks=parent._tasks, ) - @contextmanager - def stimulus_id(self, name: str): - """Context manager for setting the Scheduler stimulus_id - - If the stimulus_id has already been set further up the call stack, - this has no effect. - - Parameters - ---------- - name : str - The name of the stimulus. - """ - try: - stimulus_id = self.STIMULUS_ID.get() - except LookupError: - token = self.STIMULUS_ID.set(name) - stimulus_id = name - else: - token = None - - try: - yield stimulus_id - finally: - if token: - self.STIMULUS_ID.reset(token) - def identity(self): """Basic information about ourselves and our cluster""" parent: SchedulerState = cast(SchedulerState, self) @@ -4605,8 +4572,7 @@ async def add_worker( recommendations: dict = {} client_msgs: dict = {} worker_msgs: dict = {} - - with self.stimulus_id(f"add-worker-{time()}"): + with set_default_stimulus(f"add-worker-{time()}"): if nbytes: assert isinstance(nbytes, dict) already_released_keys = [] @@ -4707,29 +4673,31 @@ def update_graph_hlg( code=None, stimulus_id=None, ): - unpacked_graph = HighLevelGraph.__dask_distributed_unpack__(hlg) - dsk = unpacked_graph["dsk"] - dependencies = unpacked_graph["deps"] - annotations = unpacked_graph["annotations"] - - # Remove any self-dependencies (happens on test_publish_bag() and others) - for k, v in dependencies.items(): - deps = set(v) - if k in deps: - deps.remove(k) - dependencies[k] = deps - - if priority is None: - # Removing all non-local keys before calling order() - dsk_keys = set(dsk) # intersection() of sets is much faster than dict_keys - stripped_deps = { - k: v.intersection(dsk_keys) - for k, v in dependencies.items() - if k in dsk_keys - } - priority = dask.order.order(dsk, dependencies=stripped_deps) + with set_default_stimulus(stimulus_id): + unpacked_graph = HighLevelGraph.__dask_distributed_unpack__(hlg) + dsk = unpacked_graph["dsk"] + dependencies = unpacked_graph["deps"] + annotations = unpacked_graph["annotations"] + + # Remove any self-dependencies (happens on test_publish_bag() and others) + for k, v in dependencies.items(): + deps = set(v) + if k in deps: + deps.remove(k) + dependencies[k] = deps + + if priority is None: + # Removing all non-local keys before calling order() + dsk_keys = set( + dsk + ) # intersection() of 
sets is much faster than dict_keys + stripped_deps = { + k: v.intersection(dsk_keys) + for k, v in dependencies.items() + if k in dsk_keys + } + priority = dask.order.order(dsk, dependencies=stripped_deps) - with self.stimulus_id(stimulus_id or f"update-graph-hlg-{time()}"): return self.update_graph( client, dsk, @@ -4746,7 +4714,7 @@ def update_graph_hlg( fifo_timeout, annotations, code=code, - stimulus_id=self.STIMULUS_ID.get(), + stimulus_id=STIMULUS_ID.get(), ) def update_graph( @@ -4773,7 +4741,7 @@ def update_graph( This happens whenever the Client calls submit, map, get, or compute. """ - with self.stimulus_id(stimulus_id or f"update-graph-{time()}"): + with set_default_stimulus(stimulus_id): parent: SchedulerState = cast(SchedulerState, self) start = time() fifo_timeout = parse_timedelta(fifo_timeout) @@ -5049,45 +5017,44 @@ def update_graph( # TODO: balance workers - def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs): + def stimulus_task_finished(self, key=None, worker=None, **kwargs): """Mark that a task has finished execution on a particular worker""" parent: SchedulerState = cast(SchedulerState, self) - with self.stimulus_id(stimulus_id or f"already-released-or-forgotten-{time()}"): - logger.debug("Stimulus task finished %s, %s", key, worker) + logger.debug("Stimulus task finished %s, %s", key, worker) - recommendations: dict = {} - client_msgs: dict = {} - worker_msgs: dict = {} + recommendations: dict = {} + client_msgs: dict = {} + worker_msgs: dict = {} - ws: WorkerState = parent._workers_dv[worker] - ts: TaskState = parent._tasks.get(key) - if ts is None or ts._state == "released": - logger.debug( - "Received already computed task, worker: %s, state: %s" - ", key: %s, who_has: %s", - worker, - ts._state if ts else "forgotten", - key, - ts._who_has if ts else {}, - ) - worker_msgs[worker] = [ - { - "op": "free-keys", - "keys": [key], - "stimulus_id": self.STIMULUS_ID.get(), - } - ] - elif ts._state == "memory": - self.add_keys(worker=worker, keys=[key]) - else: - ts._metadata.update(kwargs["metadata"]) - r: tuple = parent._transition(key, "memory", worker=worker, **kwargs) - recommendations, client_msgs, worker_msgs = r + ws: WorkerState = parent._workers_dv[worker] + ts: TaskState = parent._tasks.get(key) + if ts is None or ts._state == "released": + logger.debug( + "Received already computed task, worker: %s, state: %s" + ", key: %s, who_has: %s", + worker, + ts._state if ts else "forgotten", + key, + ts._who_has if ts else {}, + ) + worker_msgs[worker] = [ + { + "op": "free-keys", + "keys": [key], + "stimulus_id": STIMULUS_ID.get(), + } + ] + elif ts._state == "memory": + self.add_keys(worker=worker, keys=[key]) + else: + ts._metadata.update(kwargs["metadata"]) + r: tuple = parent._transition(key, "memory", worker=worker, **kwargs) + recommendations, client_msgs, worker_msgs = r - if ts._state == "memory": - assert ws in ts._who_has - return recommendations, client_msgs, worker_msgs + if ts._state == "memory": + assert ws in ts._who_has + return recommendations, client_msgs, worker_msgs def stimulus_task_erred( self, @@ -5095,37 +5062,35 @@ def stimulus_task_erred( worker=None, exception=None, traceback=None, - stimulus_id=None, **kwargs, ): """Mark that a task has erred on a particular worker""" parent: SchedulerState = cast(SchedulerState, self) - with self.stimulus_id(stimulus_id): - logger.debug("Stimulus task erred %s, %s", key, worker) + logger.debug("Stimulus task erred %s, %s", key, worker) - ts: TaskState = parent._tasks.get(key) 
- if ts is None or ts._state != "processing": - return {}, {}, {} + ts: TaskState = parent._tasks.get(key) + if ts is None or ts._state != "processing": + return {}, {}, {} - if ts._retries > 0: - ts._retries -= 1 - return parent._transition(key, "waiting") - else: - return parent._transition( - key, - "erred", - cause=key, - exception=exception, - traceback=traceback, - worker=worker, - **kwargs, - ) + if ts._retries > 0: + ts._retries -= 1 + return parent._transition(key, "waiting") + else: + return parent._transition( + key, + "erred", + cause=key, + exception=exception, + traceback=traceback, + worker=worker, + **kwargs, + ) def stimulus_retry(self, keys, client=None, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) - with self.stimulus_id(stimulus_id): + with set_default_stimulus(stimulus_id): logger.info("Client %s requests to retry %d keys", client, len(keys)) if client: self.log_event(client, {"action": "retry", "count": len(keys)}) @@ -5165,7 +5130,9 @@ async def remove_worker(self, address, safe=False, close=True, stimulus_id=None) state. """ parent: SchedulerState = cast(SchedulerState, self) - with log_errors(), self.stimulus_id(stimulus_id or f"remove-worker-{time()}"): + with log_errors(), set_default_stimulus( + stimulus_id or f"remove-worker-{time()}" + ): if self.status == Status.closed: return @@ -5276,7 +5243,7 @@ def stimulus_cancel( self, comm, keys=None, client=None, force=False, stimulus_id=None ): """Stop execution on a list of keys""" - with self.stimulus_id(stimulus_id): + with set_default_stimulus(stimulus_id): logger.info("Client %s requests to cancel %d keys", client, len(keys)) if client: self.log_event( @@ -5289,7 +5256,7 @@ def cancel_key(self, key, client, retries=5, force=False, stimulus_id=None): """Cancel a particular key and all dependents""" # TODO: this should be converted to use the transition mechanism parent: SchedulerState = cast(SchedulerState, self) - with self.stimulus_id(stimulus_id or f"cancel-key-{time()}"): + with set_default_stimulus(stimulus_id or f"cancel-key-{time()}"): ts: TaskState = parent._tasks.get(key) dts: TaskState try: @@ -5315,7 +5282,7 @@ def cancel_key(self, key, client, retries=5, force=False, stimulus_id=None): def client_desires_keys(self, keys=None, client=None, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) - with self.stimulus_id(stimulus_id or f"client-desires-keys-{time()}"): + with set_default_stimulus(stimulus_id or f"client-desires-keys-{time()}"): cs: ClientState = parent._clients.get(client) if cs is None: # For publish, queues etc. 
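
The scheduler handlers above all follow the same pattern: the body is wrapped in set_default_stimulus(stimulus_id or f"<op>-{time()}"), so an id supplied by the caller (or already bound further up the stack) takes precedence over the locally generated fallback. Below is a minimal, self-contained sketch of that behaviour. STIMULUS_ID and set_default_stimulus are re-declared here only so the snippet runs on its own (the real helpers live in distributed/utils.py in this patch), and client_releases_keys is an illustrative stand-in for the handlers, not the real Scheduler method.

    from contextlib import contextmanager
    from contextvars import ContextVar
    from time import time

    STIMULUS_ID: ContextVar[str] = ContextVar("STIMULUS_ID")

    @contextmanager
    def set_default_stimulus(name):
        # First caller on the stack wins; nested calls keep the outer id.
        try:
            stimulus_id = STIMULUS_ID.get()
            token = None
        except LookupError:
            token = STIMULUS_ID.set(name)
            stimulus_id = name
        try:
            yield stimulus_id
        finally:
            if token:
                STIMULUS_ID.reset(token)

    def client_releases_keys(stimulus_id=None):
        # Mirrors the handler pattern above: fall back to a generated id.
        with set_default_stimulus(stimulus_id or f"client-releases-keys-{time()}"):
            return STIMULUS_ID.get()

    with set_default_stimulus("update-graph-1234"):
        assert client_releases_keys() == "update-graph-1234"  # outer id preserved

    assert client_releases_keys().startswith("client-releases-keys-")  # fallback id
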
@@ -5336,7 +5303,7 @@ def client_releases_keys(self, keys=None, client=None, stimulus_id=None): """Remove keys from client desired list""" parent: SchedulerState = cast(SchedulerState, self) - with self.stimulus_id(stimulus_id or f"client-releases-keys-{time()}"): + with set_default_stimulus(stimulus_id or f"client-releases-keys-{time()}"): if not isinstance(keys, list): keys = list(keys) cs: ClientState = parent._clients[client] @@ -5565,7 +5532,7 @@ async def add_client(self, comm: Comm, client: str, versions: dict) -> None: if self._validate: try: - stimulus_id = self.STIMULUS_ID.get() + stimulus_id = STIMULUS_ID.get() raise AssertionError( f"STIMULUS_ID {stimulus_id} set in Scheduler.add_client" ) @@ -5654,7 +5621,7 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1): """Send a single computational task to a worker""" parent: SchedulerState = cast(SchedulerState, self) try: - msg: dict = _task_to_msg(parent, ts, self.STIMULUS_ID.get(), duration) + msg: dict = _task_to_msg(parent, ts, duration) self.worker_send(worker, msg) except Exception as e: logger.exception(e) @@ -5670,7 +5637,7 @@ def handle_uncaught_error(self, **msg): def handle_task_finished(self, key=None, worker=None, stimulus_id=None, **msg): parent: SchedulerState = cast(SchedulerState, self) - with self.stimulus_id(stimulus_id): + with set_default_stimulus(stimulus_id): if worker not in parent._workers_dv: return validate_key(key) @@ -5691,7 +5658,7 @@ def handle_task_erred(self, key=None, stimulus_id=None, **msg): client_msgs: dict worker_msgs: dict - with self.stimulus_id(stimulus_id): + with set_default_stimulus(stimulus_id): r: tuple = self.stimulus_task_erred(key=key, **msg) recommendations, client_msgs, worker_msgs = r parent._transitions(recommendations, client_msgs, worker_msgs) @@ -5722,7 +5689,7 @@ def handle_missing_data( parent: SchedulerState = cast(SchedulerState, self) logger.debug("handle missing data key=%s worker=%s", key, errant_worker) - with self.stimulus_id(stimulus_id): + with set_default_stimulus(stimulus_id): self.log_event(errant_worker, {"action": "missing-data", "key": key}) ts: TaskState = parent._tasks.get(key) if ts is None: @@ -5739,7 +5706,7 @@ def handle_missing_data( def release_worker_data(self, key, worker, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) - with self.stimulus_id(stimulus_id or f"release-worker-data-{time()}"): + with set_default_stimulus(stimulus_id or f"release-worker-data-{time()}"): ws: WorkerState = parent._workers_dv.get(worker) ts: TaskState = parent._tasks.get(key) if not ws or not ts: @@ -5762,7 +5729,7 @@ def handle_long_running( """ parent: SchedulerState = cast(SchedulerState, self) - with self.stimulus_id(stimulus_id): + with set_default_stimulus(stimulus_id): if key not in parent._tasks: logger.debug( "Skipping long_running since key %s was already released", key @@ -5807,7 +5774,7 @@ def handle_worker_status_change( ) -> None: parent: SchedulerState = cast(SchedulerState, self) - with self.stimulus_id(stimulus_id): + with set_default_stimulus(stimulus_id): ws: WorkerState = parent._workers_dv.get(worker) # type: ignore if not ws: return @@ -5853,14 +5820,6 @@ async def handle_worker(self, comm=None, worker=None): -------- Scheduler.handle_client: Equivalent coroutine for clients """ - if self._validate: - try: - stimulus_id = self.STIMULUS_ID.get() - raise AssertionError( - f"STIMULUS_ID {stimulus_id} set in Scheduler.add_client" - ) - except LookupError: - pass comm.name = "Scheduler connection to worker" 
worker_comm = self.stream_comms[worker] @@ -6079,7 +6038,7 @@ async def scatter( parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState - with self.stimulus_id(stimulus_id or f"scatter-{time()}"): + with set_default_stimulus(stimulus_id or f"scatter-{time()}"): start = time() while True: if workers is None: @@ -6120,7 +6079,7 @@ async def gather(self, keys, serializers=None, stimulus_id=None): parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState - with self.stimulus_id(stimulus_id or f"gather-{time()}"): + with set_default_stimulus(stimulus_id or f"gather-{time()}"): keys = list(keys) who_has = {} for key in keys: @@ -6193,7 +6152,7 @@ def clear_task_state(self): async def restart(self, client=None, timeout=30, stimulus_id=None): """Restart all workers. Reset local state.""" parent: SchedulerState = cast(SchedulerState, self) - with log_errors(), self.stimulus_id(stimulus_id or f"restart-{time()}"): + with log_errors(), set_default_stimulus(stimulus_id or f"restart-{time()}"): n_workers = len(parent._workers_dv) @@ -6422,7 +6381,7 @@ async def delete_worker_data( """ parent: SchedulerState = cast(SchedulerState, self) - with self.stimulus_id(f"delete-worker-data-{time()}"): + with set_default_stimulus(f"delete-worker-data-{time()}"): try: await retry_operation( @@ -6529,7 +6488,7 @@ async def rebalance( Stimulus ID that caused this function call """ parent: SchedulerState = cast(SchedulerState, self) - with log_errors(), self.stimulus_id(stimulus_id or f"rebalance-{time()}"): + with log_errors(), set_default_stimulus(stimulus_id or f"rebalance-{time()}"): wss: "Collection[WorkerState]" if workers is not None: wss = [parent._workers_dv[w] for w in workers] @@ -6860,7 +6819,7 @@ async def replicate( wws: WorkerState ts: TaskState - with self.stimulus_id(stimulus_id or f"replicate-{time()}"): + with set_default_stimulus(stimulus_id or f"replicate-{time()}"): assert branching_factor > 0 async with self._lock if lock else empty_context: if workers is not None: @@ -7127,7 +7086,9 @@ async def retire_workers( parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState ts: TaskState - with log_errors(), self.stimulus_id(stimulus_id or f"retire-workers-{time()}"): + with log_errors(), set_default_stimulus( + stimulus_id or f"retire-workers-{time()}" + ): # This lock makes retire_workers, rebalance, and replicate mutually # exclusive and will no longer be necessary once rebalance and replicate are # migrated to the Active Memory Manager. @@ -7186,7 +7147,7 @@ async def retire_workers( { "op": "worker-status-change", "status": ws.status.name, - "stimulus_id": self.STIMULUS_ID.get(), + "stimulus_id": STIMULUS_ID.get(), } ) @@ -7235,7 +7196,7 @@ async def _track_retire_worker( { "op": "worker-status-change", "status": prev_status.name, - "stimulus_id": self.STIMULUS_ID.get(), + "stimulus_id": STIMULUS_ID.get(), } ) return None, {} @@ -7265,7 +7226,7 @@ def add_keys(self, worker=None, keys=(), stimulus_id=None): reasons. However, it is sent by workers from time to time. 
""" parent: SchedulerState = cast(SchedulerState, self) - with self.stimulus_id(stimulus_id or f"add-keys-{time()}"): + with set_default_stimulus(stimulus_id or f"add-keys-{time()}"): if worker not in parent._workers_dv: return "not found" ws: WorkerState = parent._workers_dv[worker] @@ -7284,7 +7245,7 @@ def add_keys(self, worker=None, keys=(), stimulus_id=None): { "op": "remove-replicas", "keys": redundant_replicas, - "stimulus_id": self.STIMULUS_ID.get( + "stimulus_id": STIMULUS_ID.get( stimulus_id or f"redundant-replicas-{time()}" ), }, @@ -7303,7 +7264,7 @@ def update_data( Scheduler.mark_key_in_memory """ parent: SchedulerState = cast(SchedulerState, self) - with log_errors(), self.stimulus_id(stimulus_id or f"update-data-{time()}"): + with log_errors(), set_default_stimulus(stimulus_id or f"update-data-{time()}"): who_has = { k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items() } @@ -7777,7 +7738,7 @@ def reschedule(self, key=None, worker=None, stimulus_id=None): """ parent: SchedulerState = cast(SchedulerState, self) ts: TaskState - with self.stimulus_id(stimulus_id or f"reschedule-{key}"): + with set_default_stimulus(stimulus_id or f"reschedule-{key}"): try: ts = parent._tasks[key] except KeyError: @@ -8533,9 +8494,7 @@ def _client_releases_keys( @cfunc @exceptval(check=False) -def _task_to_msg( - state: SchedulerState, ts: TaskState, stimulus_id: str, duration: double = -1 -) -> dict: +def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> dict: """Convert a single computational task to a message""" ws: WorkerState dts: TaskState @@ -8549,7 +8508,7 @@ def _task_to_msg( "key": ts._key, "priority": ts._priority, "duration": duration, - "stimulus_id": stimulus_id, + "stimulus_id": STIMULUS_ID.get(), "who_has": {}, } if ts._resource_restrictions: diff --git a/distributed/stealing.py b/distributed/stealing.py index 41a44670b0f..27a2e850ab7 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -17,7 +17,7 @@ from distributed.comm.addressing import get_address_host from distributed.core import CommClosedError, Status from distributed.diagnostics.plugin import SchedulerPlugin -from distributed.utils import log_errors, recursive_to_dict +from distributed.utils import log_errors, recursive_to_dict, set_default_stimulus # Stealing requires multiple network bounces and if successful also task # submission which may include code serialization. 
Therefore, be very @@ -333,7 +333,7 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None): self.scheduler.total_occupancy += d["thief_duration"] self.put_key_in_stealable(ts) - with self.scheduler.stimulus_id(stimulus_id): + with set_default_stimulus(stimulus_id): self.scheduler.send_task_to_worker(thief.address, ts) self.log(("confirm", *_log_msg)) else: diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 0555ab20232..dfbe6a1a0cf 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -15,7 +15,7 @@ from distributed.compatibility import MACOS from distributed.metrics import time from distributed.scheduler import COMPILED -from distributed.utils import CancelledError, sync +from distributed.utils import CancelledError, set_default_stimulus, sync from distributed.utils_test import ( captured_logger, cluster, @@ -492,7 +492,8 @@ async def test_forget_data_not_supposed_to_have(s, a, b): ts = TaskState("key", state="flight") a.tasks["key"] = ts recommendations = {ts: ("memory", 123)} - a.transitions(recommendations, stimulus_id="test") + with set_default_stimulus("test"): + a.transitions(recommendations) assert a.data while a.data: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index e815b106fc7..be6e1d11dca 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -3565,8 +3565,8 @@ async def test_stimuli(c, s, a, b): ) stimuli = [ - "update-graph-hlg", - "update-graph-hlg", + "client-update-graph-hlg", + "client-update-graph-hlg", "task-finished", "client-releases-keys", ] diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index dbe58f5dd7e..eae0e9831ff 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -42,7 +42,7 @@ from distributed.metrics import time from distributed.protocol import pickle from distributed.scheduler import Scheduler -from distributed.utils import TimeoutError +from distributed.utils import TimeoutError, set_default_stimulus from distributed.utils_test import ( TaskStateMetadataPlugin, _LockedCommPool, @@ -2548,7 +2548,7 @@ def __call__(self, *args, **kwargs): ts = s.tasks[fut.key] a.handle_steal_request(fut.key, stimulus_id="test") - with s.stimulus_id("test"): + with set_default_stimulus("test"): stealing_ext.scheduler.send_task_to_worker(b.address, ts) fut2 = c.submit(inc, fut, workers=[a.address]) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index f03966ba656..db89ad3db32 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -86,6 +86,7 @@ def test_unique_task_heap(): assert repr(heap) == "" +@pytest.mark.xfail("slots not compatible with defaultfactory") @pytest.mark.parametrize( "cls", chain( diff --git a/distributed/utils.py b/distributed/utils.py index cd0d93d08dd..4fe311e5bd2 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -23,13 +23,13 @@ from collections.abc import Collection, Container, KeysView, ValuesView from concurrent.futures import CancelledError, ThreadPoolExecutor # noqa: F401 from contextlib import contextmanager, suppress -from contextvars import ContextVar +from contextvars import ContextVar, Token from hashlib import md5 from importlib.util import cache_from_source from time import sleep from types import ModuleType from typing import Any as AnyType 
-from typing import ClassVar +from typing import ClassVar, Iterator import click import tblib.pickling_support @@ -67,6 +67,36 @@ no_default = "__no_default__" +STIMULUS_ID: ContextVar[str] = ContextVar("STIMULUS_ID") + + +@contextmanager +def set_default_stimulus(name: str) -> Iterator[str]: + """Context manager for setting the Scheduler stimulus_id + + If the stimulus_id has already been set further up the call stack, + this has no effect. + + Parameters + ---------- + name : str + The name of the stimulus. + """ + token: Token[str] | None + try: + stimulus_id = STIMULUS_ID.get() + except LookupError: + token = STIMULUS_ID.set(name) + stimulus_id = name + else: + token = None + + try: + yield stimulus_id + finally: + if token: + STIMULUS_ID.reset(token) + def _initialize_mp_context(): method = dask.config.get("distributed.worker.multiprocessing-method") diff --git a/distributed/worker.py b/distributed/worker.py index 47dd2d1b184..46d87fcea44 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -79,6 +79,7 @@ from distributed.threadpoolexecutor import secede as tpe_secede from distributed.utils import ( LRU, + STIMULUS_ID, TimeoutError, _maybe_complex, get_ip, @@ -92,6 +93,7 @@ offload, parse_ports, recursive_to_dict, + set_default_stimulus, silence_logging, thread_state, warn_on_duration, @@ -919,12 +921,13 @@ def status(self, value): """ prev_status = self.status ServerNode.status.__set__(self, value) - self._send_worker_status_change(f"worker-status-change-{time()}") - if prev_status == Status.paused and value == Status.running: - self.ensure_computing() - self.ensure_communicating() + with set_default_stimulus(f"init-{time()}"): + self._send_worker_status_change() + if prev_status == Status.paused and value == Status.running: + self.ensure_computing() + self.ensure_communicating() - def _send_worker_status_change(self, stimulus_id: str) -> None: + def _send_worker_status_change(self) -> None: if ( self.batched_stream and self.batched_stream.comm @@ -934,11 +937,11 @@ def _send_worker_status_change(self, stimulus_id: str) -> None: { "op": "worker-status-change", "status": self._status.name, - "stimulus_id": stimulus_id, + "stimulus_id": STIMULUS_ID.get(), } ) elif self._status != Status.closed: - self.loop.call_later(0.05, self._send_worker_status_change, stimulus_id) + self.loop.call_later(0.05, self._send_worker_status_change) async def get_metrics(self) -> dict: try: @@ -1661,8 +1664,7 @@ def update_data( report: bool = True, stimulus_id: str = None, ) -> dict[str, Any]: - if stimulus_id is None: - stimulus_id = f"update-data-{time()}" + STIMULUS_ID.set(stimulus_id or f"update-data-{time()}") recommendations: Recs = {} instructions: Instructions = [] for key, value in data.items(): @@ -1673,19 +1675,19 @@ def update_data( self.tasks[key] = ts = TaskState(key) try: - recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) + recs = self._put_key_in_memory(ts, value) except Exception as e: msg = error_message(e) recommendations = {ts: tuple(msg.values())} else: recommendations.update(recs) - self.log.append((key, "receive-from-scatter", stimulus_id, time())) + self.log.append((key, "receive-from-scatter", STIMULUS_ID.get(), time())) if report: - instructions.append(AddKeysMsg(keys=list(data), stimulus_id=stimulus_id)) + instructions.append(AddKeysMsg(keys=list(data))) - self.transitions(recommendations, stimulus_id=stimulus_id) + self.transitions(recommendations) self._handle_instructions(instructions) return {"nbytes": {k: sizeof(v) for k, v in data.items()}, 
"status": "OK"} @@ -1700,14 +1702,15 @@ def handle_free_keys(self, keys: list[str], stimulus_id: str) -> None: still decide to hold on to the data and task since it is required by an upstream dependency. """ - self.log.append(("free-keys", keys, stimulus_id, time())) - recommendations: Recs = {} - for key in keys: - ts = self.tasks.get(key) - if ts: - recommendations[ts] = "released" + with set_default_stimulus(stimulus_id): + self.log.append(("free-keys", keys, STIMULUS_ID.get(), time())) + recommendations: Recs = {} + for key in keys: + ts = self.tasks.get(key) + if ts: + recommendations[ts] = "released" - self.transitions(recommendations, stimulus_id=stimulus_id) + self.transitions(recommendations) def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str: """Stream handler notifying the worker that it might be holding unreferenced, @@ -1728,28 +1731,32 @@ def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str: For stronger guarantees, see handler free_keys """ - self.log.append(("remove-replicas", keys, stimulus_id, time())) - recommendations: Recs = {} - rejected = [] - for key in keys: - ts = self.tasks.get(key) - if ts is None or ts.state != "memory": - continue - if not ts.is_protected(): + with set_default_stimulus(stimulus_id): + self.log.append(("remove-replicas", keys, STIMULUS_ID.get(), time())) + recommendations: Recs = {} + + rejected = [] + for key in keys: + ts = self.tasks.get(key) + if ts is None or ts.state != "memory": + continue + if not ts.is_protected(): + self.log.append( + (ts.key, "remove-replica-confirmed", STIMULUS_ID.get(), time()) + ) + recommendations[ts] = "released" + else: + rejected.append(key) + + if rejected: self.log.append( - (ts.key, "remove-replica-confirmed", stimulus_id, time()) + ("remove-replica-rejected", rejected, STIMULUS_ID.get(), time()) ) - recommendations[ts] = "released" - else: - rejected.append(key) - - if rejected: - self.log.append(("remove-replica-rejected", rejected, stimulus_id, time())) - smsg = AddKeysMsg(keys=rejected, stimulus_id=stimulus_id) - self._handle_instructions([smsg]) + smsg = AddKeysMsg(keys=rejected) + self._handle_instructions([smsg]) - self.transitions(recommendations, stimulus_id=stimulus_id) + self.transitions(recommendations) return "OK" @@ -1777,7 +1784,7 @@ def handle_cancel_compute(self, key: str, stimulus_id: str) -> None: is in state `waiting` or `ready`. Nothing will happen otherwise. """ - self.handle_stimulus(CancelComputeEvent(key=key, stimulus_id=stimulus_id)) + self.handle_stimulus(CancelComputeEvent(key=key)) def handle_acquire_replicas( self, @@ -1786,36 +1793,38 @@ def handle_acquire_replicas( who_has: dict[str, Collection[str]], stimulus_id: str, ) -> None: - recommendations: Recs = {} - for key in keys: - ts = self.ensure_task_exists( - key=key, - # Transfer this data after all dependency tasks of computations with - # default or explicitly high (>0) user priority and before all - # computations with low priority (<0). Note that the priority= parameter - # of compute() is multiplied by -1 before it reaches TaskState.priority. - priority=(1,), - stimulus_id=stimulus_id, - ) - if ts.state != "memory": - recommendations[ts] = "fetch" + with set_default_stimulus(stimulus_id): + recommendations: Recs = {} + for key in keys: + ts = self.ensure_task_exists( + key=key, + # Transfer this data after all dependency tasks of computations with + # default or explicitly high (>0) user priority and before all + # computations with low priority (<0). 
Note that the priority= parameter + # of compute() is multiplied by -1 before it reaches TaskState.priority. + priority=(1,), + ) + if ts.state != "memory": + recommendations[ts] = "fetch" - self.update_who_has(who_has) - self.transitions(recommendations, stimulus_id=stimulus_id) + self.update_who_has(who_has) + self.transitions(recommendations) - def ensure_task_exists( - self, key: str, *, priority: tuple[int, ...], stimulus_id: str - ) -> TaskState: + def ensure_task_exists(self, key: str, *, priority: tuple[int, ...]) -> TaskState: try: ts = self.tasks[key] - logger.debug("Data task %s already known (stimulus_id=%s)", ts, stimulus_id) + logger.debug( + "Data task %s already known (stimulus_id=%s)", ts, STIMULUS_ID.get() + ) except KeyError: self.tasks[key] = ts = TaskState(key) if not ts.priority: assert priority ts.priority = priority - self.log.append((key, "ensure-task-exists", ts.state, stimulus_id, time())) + self.log.append( + (key, "ensure-task-exists", ts.state, STIMULUS_ID.get(), time()) + ) return ts def handle_compute_task( @@ -1835,77 +1844,78 @@ def handle_compute_task( annotations: dict | None = None, stimulus_id: str, ) -> None: - self.log.append((key, "compute-task", stimulus_id, time())) - try: - ts = self.tasks[key] - logger.debug( - "Asked to compute an already known task %s", - {"task": ts, "stimulus_id": stimulus_id}, - ) - except KeyError: - self.tasks[key] = ts = TaskState(key) - - ts.run_spec = SerializedTask(function, args, kwargs, task) + with set_default_stimulus(stimulus_id): + self.log.append((key, "compute-task", STIMULUS_ID.get(), time())) + try: + ts = self.tasks[key] + logger.debug( + "Asked to compute an already known task %s", + {"task": ts, "stimulus_id": STIMULUS_ID.get()}, + ) + except KeyError: + self.tasks[key] = ts = TaskState(key) - assert isinstance(priority, tuple) - priority = priority + (self.generation,) - self.generation -= 1 + ts.run_spec = SerializedTask(function, args, kwargs, task) - if actor: - self.actors[ts.key] = None + assert isinstance(priority, tuple) + priority = priority + (self.generation,) + self.generation -= 1 - ts.exception = None - ts.traceback = None - ts.exception_text = "" - ts.traceback_text = "" - ts.priority = priority - ts.duration = duration - if resource_restrictions: - ts.resource_restrictions = resource_restrictions - ts.annotations = annotations + if actor: + self.actors[ts.key] = None - recommendations: Recs = {} - instructions: Instructions = [] - for dependency in who_has: - dep_ts = self.ensure_task_exists( - key=dependency, - priority=priority, - stimulus_id=stimulus_id, - ) + ts.exception = None + ts.traceback = None + ts.exception_text = "" + ts.traceback_text = "" + ts.priority = priority + ts.duration = duration + if resource_restrictions: + ts.resource_restrictions = resource_restrictions + ts.annotations = annotations + + recommendations: Recs = {} + instructions: Instructions = [] + for dependency in who_has: + dep_ts = self.ensure_task_exists( + key=dependency, + priority=priority, + ) - # link up to child / parents - ts.dependencies.add(dep_ts) - dep_ts.dependents.add(ts) + # link up to child / parents + ts.dependencies.add(dep_ts) + dep_ts.dependents.add(ts) - if ts.state in READY | {"executing", "waiting", "resumed"}: - pass - elif ts.state == "memory": - recommendations[ts] = "memory" - instructions.append( - self._get_task_finished_msg(ts, stimulus_id=stimulus_id) - ) - elif ts.state in { - "released", - "fetch", - "flight", - "missing", - "cancelled", - "error", - }: - recommendations[ts] 
= "waiting" - else: # pragma: no cover - raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}") + if ts.state in READY | {"executing", "waiting", "resumed"}: + pass + elif ts.state == "memory": + recommendations[ts] = "memory" + instructions.append(self._get_task_finished_msg(ts)) + elif ts.state in { + "released", + "fetch", + "flight", + "missing", + "cancelled", + "error", + }: + recommendations[ts] = "waiting" + else: # pragma: no cover + raise RuntimeError( + f"Unexpected task state encountered {ts} {STIMULUS_ID.get()}" + ) - self._handle_instructions(instructions) - self.update_who_has(who_has) - self.transitions(recommendations, stimulus_id=stimulus_id) + self._handle_instructions(instructions) + self.update_who_has(who_has) + self.transitions(recommendations) - if nbytes is not None: - for key, value in nbytes.items(): - self.tasks[key].nbytes = value + if nbytes is not None: + for key, value in nbytes.items(): + self.tasks[key].nbytes = value def transition_missing_fetch( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: if self.validate: assert ts.state == "missing" @@ -1918,17 +1928,17 @@ def transition_missing_fetch( return {}, [] def transition_missing_released( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: self._missing_dep_flight.discard(ts) - recs, instructions = self.transition_generic_released( - ts, stimulus_id=stimulus_id - ) + recs, instructions = self.transition_generic_released(ts) assert ts.key in self.tasks return recs, instructions def transition_flight_missing( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: assert ts.done ts.state = "missing" @@ -1937,7 +1947,8 @@ def transition_flight_missing( return {}, [] def transition_released_fetch( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: if self.validate: assert ts.state == "released" @@ -1950,9 +1961,10 @@ def transition_released_fetch( return {}, [] def transition_generic_released( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: - self.release_key(ts.key, stimulus_id=stimulus_id) + self.release_key(ts.key) recs: Recs = {} for dependency in ts.dependencies: if ( @@ -1967,7 +1979,8 @@ def transition_generic_released( return recs, [] def transition_released_waiting( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: if self.validate: assert ts.state == "released" @@ -1993,7 +2006,9 @@ def transition_released_waiting( return recommendations, [] def transition_fetch_flight( - self, ts: TaskState, worker, *, stimulus_id: str + self, + ts: TaskState, + worker, ) -> RecsInstrs: if self.validate: assert ts.state == "fetch" @@ -2006,16 +2021,16 @@ def transition_fetch_flight( return {}, [] def transition_memory_released( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: - recs, instructions = self.transition_generic_released( - ts, stimulus_id=stimulus_id - ) - instructions.append(ReleaseWorkerDataMsg(key=ts.key, stimulus_id=stimulus_id)) + recs, instructions = self.transition_generic_released(ts) + instructions.append(ReleaseWorkerDataMsg(key=ts.key)) return recs, instructions def transition_waiting_constrained( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: if self.validate: assert ts.state == "waiting" @@ -2031,25 +2046,28 @@ def transition_waiting_constrained( return {}, [] def 
transition_long_running_rescheduled( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: recs: Recs = {ts: "released"} - smsg = RescheduleMsg(key=ts.key, worker=self.address, stimulus_id=stimulus_id) + smsg = RescheduleMsg(key=ts.key, worker=self.address) return recs, [smsg] def transition_executing_rescheduled( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity self._executing.discard(ts) recs: Recs = {ts: "released"} - smsg = RescheduleMsg(key=ts.key, worker=self.address, stimulus_id=stimulus_id) + smsg = RescheduleMsg(key=ts.key, worker=self.address) return recs, [smsg] def transition_waiting_ready( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: if self.validate: assert ts.state == "waiting" @@ -2072,8 +2090,6 @@ def transition_cancelled_error( traceback: Serialize | None, exception_text: str, traceback_text: str, - *, - stimulus_id: str, ) -> RecsInstrs: recs: Recs = {} instructions: Instructions = [] @@ -2084,7 +2100,6 @@ def transition_cancelled_error( traceback, exception_text, traceback_text, - stimulus_id=stimulus_id, ) elif ts._previous == "flight": recs, instructions = self.transition_flight_error( @@ -2093,7 +2108,6 @@ def transition_cancelled_error( traceback, exception_text, traceback_text, - stimulus_id=stimulus_id, ) if ts._next: recs[ts] = ts._next @@ -2106,8 +2120,6 @@ def transition_generic_error( traceback: Serialize | None, exception_text: str, traceback_text: str, - *, - stimulus_id: str, ) -> RecsInstrs: ts.exception = exception ts.traceback = traceback @@ -2122,7 +2134,6 @@ def transition_generic_error( traceback_text=traceback_text, thread=self.threads.get(ts.key), startstops=ts.startstops, - stimulus_id=stimulus_id, ) return {}, [smsg] @@ -2134,8 +2145,6 @@ def transition_executing_error( traceback: Serialize | None, exception_text: str, traceback_text: str, - *, - stimulus_id: str, ) -> RecsInstrs: for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity @@ -2146,11 +2155,12 @@ def transition_executing_error( traceback, exception_text, traceback_text, - stimulus_id=stimulus_id, ) def _transition_from_resumed( - self, ts: TaskState, finish: TaskStateState, *, stimulus_id: str + self, + ts: TaskState, + finish: TaskStateState, ) -> RecsInstrs: """`resumed` is an intermediate degenerate state which splits further up into two states depending on what the last signal / next state is @@ -2178,9 +2188,7 @@ def _transition_from_resumed( # if the next state is already intended to be waiting or if the # coro/thread is still running (ts.done==False), this is a noop if ts._next != finish: - recs, instructions = self.transition_generic_released( - ts, stimulus_id=stimulus_id - ) + recs, instructions = self.transition_generic_released(ts) assert next_state recs[ts] = next_state else: @@ -2188,29 +2196,35 @@ def _transition_from_resumed( return recs, instructions def transition_resumed_fetch( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: """ See Worker._transition_from_resumed """ - return self._transition_from_resumed(ts, "fetch", stimulus_id=stimulus_id) + return self._transition_from_resumed(ts, "fetch") def transition_resumed_missing( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: """ See Worker._transition_from_resumed """ - return 
self._transition_from_resumed(ts, "missing", stimulus_id=stimulus_id) + return self._transition_from_resumed(ts, "missing") - def transition_resumed_waiting(self, ts: TaskState, *, stimulus_id: str): + def transition_resumed_waiting( + self, + ts: TaskState, + ): """ See Worker._transition_from_resumed """ - return self._transition_from_resumed(ts, "waiting", stimulus_id=stimulus_id) + return self._transition_from_resumed(ts, "waiting") def transition_cancelled_fetch( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: if ts.done: return {ts: "released"}, [] @@ -2222,14 +2236,17 @@ def transition_cancelled_fetch( return {ts: ("resumed", "fetch")}, [] def transition_cancelled_resumed( - self, ts: TaskState, next: TaskStateState, *, stimulus_id: str + self, + ts: TaskState, + next: TaskStateState, ) -> RecsInstrs: ts._next = next ts.state = "resumed" return {}, [] def transition_cancelled_waiting( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: if ts.done: return {ts: "released"}, [] @@ -2241,7 +2258,8 @@ def transition_cancelled_waiting( return {ts: ("resumed", "waiting")}, [] def transition_cancelled_forgotten( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: ts._next = "forgotten" if not ts.done: @@ -2249,7 +2267,8 @@ def transition_cancelled_forgotten( return {ts: "released"}, [] def transition_cancelled_released( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: if not ts.done: ts._next = "released" @@ -2261,15 +2280,14 @@ def transition_cancelled_released( for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity - recs, instructions = self.transition_generic_released( - ts, stimulus_id=stimulus_id - ) + recs, instructions = self.transition_generic_released(ts) if next_state != "released": recs[ts] = next_state return recs, instructions def transition_executing_released( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: ts._previous = ts.state ts._next = "released" @@ -2279,13 +2297,17 @@ def transition_executing_released( return {}, [] def transition_long_running_memory( - self, ts: TaskState, value=no_value, *, stimulus_id: str + self, + ts: TaskState, + value=no_value, ) -> RecsInstrs: self.executed_count += 1 - return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id) + return self.transition_generic_memory(ts, value=value) def transition_generic_memory( - self, ts: TaskState, value=no_value, *, stimulus_id: str + self, + ts: TaskState, + value=no_value, ) -> RecsInstrs: if value is no_value and ts.key not in self.data: raise RuntimeError( @@ -2300,18 +2322,20 @@ def transition_generic_memory( self._in_flight_tasks.discard(ts) ts.coming_from = None try: - recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) + recs = self._put_key_in_memory(ts, value) except Exception as e: msg = error_message(e) recs = {ts: tuple(msg.values())} return recs, [] if self.validate: assert ts.key in self.data or ts.key in self.actors - smsg = self._get_task_finished_msg(ts, stimulus_id=stimulus_id) + smsg = self._get_task_finished_msg(ts) return recs, [smsg] def transition_executing_memory( - self, ts: TaskState, value=no_value, *, stimulus_id: str + self, + ts: TaskState, + value=no_value, ) -> RecsInstrs: if self.validate: assert ts.state == "executing" or ts.key in self.long_running @@ -2320,10 +2344,11 @@ def transition_executing_memory( 
self._executing.discard(ts) self.executed_count += 1 - return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id) + return self.transition_generic_memory(ts, value=value) def transition_constrained_executing( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: if self.validate: assert not ts.waiting_for_data @@ -2337,11 +2362,12 @@ def transition_constrained_executing( self.available_resources[resource] -= quantity ts.state = "executing" self._executing.add(ts) - instr = Execute(key=ts.key, stimulus_id=stimulus_id) + instr = Execute(key=ts.key) return {}, [instr] def transition_ready_executing( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: if self.validate: assert not ts.waiting_for_data @@ -2355,10 +2381,13 @@ def transition_ready_executing( ts.state = "executing" self._executing.add(ts) - instr = Execute(key=ts.key, stimulus_id=stimulus_id) + instr = Execute(key=ts.key) return {}, [instr] - def transition_flight_fetch(self, ts: TaskState, *, stimulus_id: str) -> RecsInstrs: + def transition_flight_fetch( + self, + ts: TaskState, + ) -> RecsInstrs: # If this transition is called after the flight coroutine has finished, # we can reset the task and transition to fetch again. If it is not yet # finished, this should be a no-op @@ -2384,8 +2413,6 @@ def transition_flight_error( traceback: Serialize | None, exception_text: str, traceback_text: str, - *, - stimulus_id: str, ) -> RecsInstrs: self._in_flight_tasks.discard(ts) ts.coming_from = None @@ -2395,16 +2422,16 @@ def transition_flight_error( traceback, exception_text, traceback_text, - stimulus_id=stimulus_id, ) def transition_flight_released( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: if ts.done: # FIXME: Is this even possible? Would an assert instead be more # sensible? 
- return self.transition_generic_released(ts, stimulus_id=stimulus_id) + return self.transition_generic_released(ts) else: ts._previous = "flight" ts._next = "released" @@ -2413,51 +2440,58 @@ def transition_flight_released( return {}, [] def transition_cancelled_memory( - self, ts: TaskState, value, *, stimulus_id: str + self, + ts: TaskState, + value, ) -> RecsInstrs: assert ts._next return {ts: ts._next}, [] def transition_executing_long_running( - self, ts: TaskState, compute_duration: float, *, stimulus_id: str + self, + ts: TaskState, + compute_duration: float, ) -> RecsInstrs: ts.state = "long-running" self._executing.discard(ts) self.long_running.add(ts.key) - smsg = LongRunningMsg( - key=ts.key, compute_duration=compute_duration, stimulus_id=stimulus_id - ) + smsg = LongRunningMsg(key=ts.key, compute_duration=compute_duration) self.io_loop.add_callback(self.ensure_computing) return {}, [smsg] def transition_released_memory( - self, ts: TaskState, value, *, stimulus_id: str + self, + ts: TaskState, + value, ) -> RecsInstrs: try: - recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) + recs = self._put_key_in_memory(ts, value) except Exception as e: msg = error_message(e) recs = {ts: tuple(msg.values())} return recs, [] - smsg = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id) + smsg = AddKeysMsg(keys=[ts.key]) return recs, [smsg] def transition_flight_memory( - self, ts: TaskState, value, *, stimulus_id: str + self, + ts: TaskState, + value, ) -> RecsInstrs: self._in_flight_tasks.discard(ts) ts.coming_from = None try: - recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) + recs = self._put_key_in_memory(ts, value) except Exception as e: msg = error_message(e) recs = {ts: tuple(msg.values())} return recs, [] - smsg = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id) + smsg = AddKeysMsg(keys=[ts.key]) return recs, [smsg] def transition_released_forgotten( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> RecsInstrs: recommendations: Recs = {} # Dependents _should_ be released by the scheduler before this @@ -2474,7 +2508,7 @@ def transition_released_forgotten( return recommendations, [] def _transition( - self, ts: TaskState, finish: str | tuple, *args, stimulus_id: str, **kwargs + self, ts: TaskState, finish: str | tuple, *args, **kwargs ) -> RecsInstrs: if isinstance(finish, tuple): # the concatenated transition path might need to access the tuple @@ -2489,15 +2523,13 @@ def _transition( if func is not None: self._transition_counter += 1 - recs, instructions = func(ts, *args, stimulus_id=stimulus_id, **kwargs) + recs, instructions = func(ts, *args, **kwargs) self._notify_plugins("transition", ts.key, start, finish, **kwargs) elif "released" not in (start, finish): # start -> "released" -> finish try: - recs, instructions = self._transition( - ts, "released", stimulus_id=stimulus_id - ) + recs, instructions = self._transition(ts, "released") v = recs.get(ts, (finish, *args)) v_state: str v_args: list | tuple @@ -2505,9 +2537,7 @@ def _transition( v_state, *v_args = v else: v_state, v_args = v, () - b_recs, b_instructions = self._transition( - ts, v_state, *v_args, stimulus_id=stimulus_id - ) + b_recs, b_instructions = self._transition(ts, v_state, *v_args) recs.update(b_recs) instructions += b_instructions except InvalidTransition: @@ -2532,15 +2562,13 @@ def _transition( ts.state, # new recommendations {ts.key: new for ts, new in recs.items()}, - stimulus_id, + STIMULUS_ID.get(), time(), ) ) return recs, instructions - def 
transition( - self, ts: TaskState, finish: str, *, stimulus_id: str, **kwargs - ) -> None: + def transition(self, ts: TaskState, finish: str, **kwargs) -> None: """Transition a key from its current state to the finish state Examples @@ -2556,13 +2584,14 @@ def transition( -------- Scheduler.transitions: transitive version of this function """ - recs, instructions = self._transition( - ts, finish, stimulus_id=stimulus_id, **kwargs - ) + recs, instructions = self._transition(ts, finish, **kwargs) self._handle_instructions(instructions) - self.transitions(recs, stimulus_id=stimulus_id) + self.transitions(recs) - def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: + def transitions( + self, + recommendations: Recs, + ) -> None: """Process transitions until none are left This includes feedback from previous transitions and continues until we @@ -2575,9 +2604,7 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: while remaining_recs: ts, finish = remaining_recs.popitem() tasks.add(ts) - a_recs, a_instructions = self._transition( - ts, finish, stimulus_id=stimulus_id - ) + a_recs, a_instructions = self._transition(ts, finish) remaining_recs.update(a_recs) instructions += a_instructions @@ -2597,12 +2624,16 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: def handle_stimulus(self, stim: StateMachineEvent) -> None: with log_errors(): - # self.stimulus_history.append(stim) # TODO - recs, instructions = self.handle_event(stim) - self.transitions(recs, stimulus_id=stim.stimulus_id) - self._handle_instructions(instructions) - self.ensure_computing() - self.ensure_communicating() + token = STIMULUS_ID.set(stim.stimulus_id) # type: ignore + try: + # self.stimulus_history.append(stim) # TODO + recs, instructions = self.handle_event(stim) + self.transitions(recs) + self._handle_instructions(instructions) + self.ensure_computing() + self.ensure_communicating() + finally: + STIMULUS_ID.reset(token) def _handle_stimulus_from_future( self, future: asyncio.Future[StateMachineEvent | None] @@ -2617,25 +2648,27 @@ def _handle_instructions(self, instructions: Instructions) -> None: # TODO this method is temporary. 
# See final design: https://github.com/dask/distributed/issues/5894 for inst in instructions: - if isinstance(inst, SendMessageToScheduler): - self.batched_stream.send(inst.to_dict()) - elif isinstance(inst, Execute): - coro = self.execute(inst.key, stimulus_id=inst.stimulus_id) - task = asyncio.create_task(coro) - # TODO track task (at the moment it's fire-and-forget) - task.add_done_callback(self._handle_stimulus_from_future) - else: - raise TypeError(inst) # pragma: nocover - def maybe_transition_long_running( - self, ts: TaskState, *, compute_duration: float, stimulus_id: str - ): + token = STIMULUS_ID.set(inst.stimulus_id) # type: ignore + try: + if isinstance(inst, SendMessageToScheduler): + self.batched_stream.send(inst.to_dict()) + elif isinstance(inst, Execute): + coro = self.execute(inst.key) + task = asyncio.create_task(coro) + # TODO track task (at the moment it's fire-and-forget) + task.add_done_callback(self._handle_stimulus_from_future) + else: + raise TypeError(inst) # pragma: nocover + finally: + STIMULUS_ID.reset(token) + + def maybe_transition_long_running(self, ts: TaskState, *, compute_duration: float): if ts.state == "executing": self.transition( ts, "long-running", compute_duration=compute_duration, - stimulus_id=stimulus_id, ) assert ts.state == "long-running" @@ -2653,66 +2686,71 @@ def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]: return worker_story(keys, self.log) def ensure_communicating(self) -> None: - stimulus_id = f"ensure-communicating-{time()}" - skipped_worker_in_flight = [] + with set_default_stimulus(f"ensure-communicating-{time()}"): + skipped_worker_in_flight = [] - while self.data_needed and ( - len(self.in_flight_workers) < self.total_out_connections - or self.comm_nbytes < self.comm_threshold_bytes - ): - logger.debug( - "Ensure communicating. Pending: %d. Connections: %d/%d", - len(self.data_needed), - len(self.in_flight_workers), - self.total_out_connections, - ) + while self.data_needed and ( + len(self.in_flight_workers) < self.total_out_connections + or self.comm_nbytes < self.comm_threshold_bytes + ): + logger.debug( + "Ensure communicating. Pending: %d. 
Connections: %d/%d", + len(self.data_needed), + len(self.in_flight_workers), + self.total_out_connections, + ) - ts = self.data_needed.pop() + ts = self.data_needed.pop() - if ts.state != "fetch": - continue + if ts.state != "fetch": + continue - workers = [w for w in ts.who_has if w not in self.in_flight_workers] - if not workers: - assert ts.priority is not None - skipped_worker_in_flight.append(ts) - continue + workers = [w for w in ts.who_has if w not in self.in_flight_workers] + if not workers: + assert ts.priority is not None + skipped_worker_in_flight.append(ts) + continue - host = get_address_host(self.address) - local = [w for w in workers if get_address_host(w) == host] - if local: - worker = random.choice(local) - else: - worker = random.choice(list(workers)) - assert worker != self.address + host = get_address_host(self.address) + local = [w for w in workers if get_address_host(w) == host] + if local: + worker = random.choice(local) + else: + worker = random.choice(list(workers)) + assert worker != self.address - to_gather, total_nbytes = self.select_keys_for_gather(worker, ts.key) + to_gather, total_nbytes = self.select_keys_for_gather(worker, ts.key) - self.log.append( - ("gather-dependencies", worker, to_gather, stimulus_id, time()) - ) + self.log.append( + ( + "gather-dependencies", + worker, + to_gather, + STIMULUS_ID.get(), + time(), + ) + ) - self.comm_nbytes += total_nbytes - self.in_flight_workers[worker] = to_gather - recommendations: Recs = { - self.tasks[d]: ("flight", worker) for d in to_gather - } - self.transitions(recommendations, stimulus_id=stimulus_id) - - self.loop.add_callback( - self.gather_dep, - worker=worker, - to_gather=to_gather, - total_nbytes=total_nbytes, - stimulus_id=stimulus_id, - ) + self.comm_nbytes += total_nbytes + self.in_flight_workers[worker] = to_gather + recommendations: Recs = { + self.tasks[d]: ("flight", worker) for d in to_gather + } + self.transitions(recommendations) - for el in skipped_worker_in_flight: - self.data_needed.push(el) + self.loop.add_callback( + self.gather_dep, + worker=worker, + to_gather=to_gather, + total_nbytes=total_nbytes, + # add_callback is not ctx sensitive + stimulus_id=STIMULUS_ID.get(), + ) - def _get_task_finished_msg( - self, ts: TaskState, stimulus_id: str - ) -> TaskFinishedMsg: + for el in skipped_worker_in_flight: + self.data_needed.push(el) + + def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg: if ts.key not in self.data and ts.key not in self.actors: raise RuntimeError(f"Task {ts} not ready") typ = ts.type @@ -2738,10 +2776,13 @@ def _get_task_finished_msg( metadata=ts.metadata, thread=self.threads.get(ts.key), startstops=ts.startstops, - stimulus_id=stimulus_id, ) - def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs: + def _put_key_in_memory( + self, + ts: TaskState, + value, + ) -> Recs: """ Put a key into memory and set data related task state attributes. On success, generate recommendations for dependents. 
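
On the worker side the same context variable replaces the stimulus_id parameters that used to be threaded through every transition and helper: handle_stimulus and _handle_instructions bind STIMULUS_ID once per event, and deep helpers such as _put_key_in_memory (next hunk) read it with STIMULUS_ID.get() when writing log entries. A minimal sketch of that token set/reset pattern, using illustrative names rather than the real Worker API:

    from contextvars import ContextVar
    from time import time

    STIMULUS_ID: ContextVar[str] = ContextVar("STIMULUS_ID")
    log = []

    def put_key_in_memory(key):
        # Deep helper: no stimulus_id parameter, it reads the context var instead.
        log.append((key, "put-in-memory", STIMULUS_ID.get(), time()))

    def handle_stimulus(stimulus_id, keys):
        token = STIMULUS_ID.set(stimulus_id)
        try:
            for key in keys:
                put_key_in_memory(key)
        finally:
            STIMULUS_ID.reset(token)  # do not leak the id into unrelated events

    handle_stimulus("compute-task-1234", ["x", "y"])
    assert all(entry[2] == "compute-task-1234" for entry in log)
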
@@ -2791,7 +2832,7 @@ def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs: self.waiting_for_data_count -= 1 recommendations[dep] = "ready" - self.log.append((ts.key, "put-in-memory", stimulus_id, time())) + self.log.append((ts.key, "put-in-memory", STIMULUS_ID.get(), time())) return recommendations def select_keys_for_gather(self, worker, dep): @@ -2924,7 +2965,6 @@ async def gather_dep( worker: str, to_gather: Iterable[str], total_nbytes: int, - *, stimulus_id: str, ) -> None: """Gather dependencies for a task from a worker who has them @@ -2940,133 +2980,168 @@ async def gather_dep( total_nbytes : int Total number of bytes for all the dependencies in to_gather combined """ - if self.status not in Status.ANY_RUNNING: # type: ignore - return - - recommendations: Recs = {} - with log_errors(): - response = {} - to_gather_keys: set[str] = set() - cancelled_keys: set[str] = set() - try: - to_gather_keys, cancelled_keys, cause = self._filter_deps_for_fetch( - to_gather - ) + with set_default_stimulus(stimulus_id): + if self.status not in Status.ANY_RUNNING: # type: ignore + return - if not to_gather_keys: - self.log.append( - ("nothing-to-gather", worker, to_gather, stimulus_id, time()) + recommendations: Recs = {} + with log_errors(): + response = {} + to_gather_keys: set[str] = set() + cancelled_keys: set[str] = set() + try: + to_gather_keys, cancelled_keys, cause = self._filter_deps_for_fetch( + to_gather ) - return - assert cause - # Keep namespace clean since this func is long and has many - # dep*, *ts* variables - del to_gather - - self.log.append( - ("request-dep", worker, to_gather_keys, stimulus_id, time()) - ) - logger.debug( - "Request %d keys for task %s from %s", - len(to_gather_keys), - cause, - worker, - ) - - start = time() - response = await get_data_from_worker( - self.rpc, to_gather_keys, worker, who=self.address - ) - stop = time() - if response["status"] == "busy": - return - - self._update_metrics_received_data( - start=start, - stop=stop, - data=response["data"], - cause=cause, - worker=worker, - ) - self.log.append( - ("receive-dep", worker, set(response["data"]), stimulus_id, time()) - ) + if not to_gather_keys: + self.log.append( + ( + "nothing-to-gather", + worker, + to_gather, + STIMULUS_ID.get(), + time(), + ) + ) + return - except OSError: - logger.exception("Worker stream died during communication: %s", worker) - has_what = self.has_what.pop(worker) - self.pending_data_per_worker.pop(worker) - self.log.append( - ("receive-dep-failed", worker, has_what, stimulus_id, time()) - ) - for d in has_what: - ts = self.tasks[d] - ts.who_has.remove(worker) + assert cause + # Keep namespace clean since this func is long and has many + # dep*, *ts* variables + del to_gather - except Exception as e: - logger.exception(e) - if self.batched_stream and LOG_PDB: - import pdb - - pdb.set_trace() - msg = error_message(e) - for k in self.in_flight_workers[worker]: - ts = self.tasks[k] - recommendations[ts] = tuple(msg.values()) - raise - finally: - self.comm_nbytes -= total_nbytes - busy = response.get("status", "") == "busy" - data = response.get("data", {}) + self.log.append( + ( + "request-dep", + worker, + to_gather_keys, + STIMULUS_ID.get(), + time(), + ) + ) + logger.debug( + "Request %d keys for task %s from %s", + len(to_gather_keys), + cause, + worker, + ) - if busy: + start = time() + response = await get_data_from_worker( + self.rpc, to_gather_keys, worker, who=self.address + ) + stop = time() + if response["status"] == "busy": + return + + 
self._update_metrics_received_data( + start=start, + stop=stop, + data=response["data"], + cause=cause, + worker=worker, + ) self.log.append( - ("busy-gather", worker, to_gather_keys, stimulus_id, time()) + ( + "receive-dep", + worker, + set(response["data"]), + STIMULUS_ID.get(), + time(), + ) ) - for d in self.in_flight_workers.pop(worker): - ts = self.tasks[d] - ts.done = True - if d in cancelled_keys: - if ts.state == "cancelled": - recommendations[ts] = "released" - else: - recommendations[ts] = "fetch" - elif d in data: - recommendations[ts] = ("memory", data[d]) - elif busy: - recommendations[ts] = "fetch" - elif ts not in recommendations: - ts.who_has.discard(worker) - self.has_what[worker].discard(ts.key) - self.log.append((d, "missing-dep", stimulus_id, time())) - self.batched_stream.send( - { - "op": "missing-data", - "errant_worker": worker, - "key": d, - "stimulus_id": stimulus_id, - } + except OSError: + logger.exception( + "Worker stream died during communication: %s", worker + ) + has_what = self.has_what.pop(worker) + self.pending_data_per_worker.pop(worker) + self.log.append( + ( + "receive-dep-failed", + worker, + has_what, + STIMULUS_ID.get(), + time(), ) - recommendations[ts] = "fetch" if ts.who_has else "missing" - del data, response - self.transitions(recommendations, stimulus_id=stimulus_id) - self.ensure_computing() + ) + for d in has_what: + ts = self.tasks[d] + ts.who_has.remove(worker) - if not busy: - self.repetitively_busy = 0 - else: - # Exponential backoff to avoid hammering scheduler/worker - self.repetitively_busy += 1 - await asyncio.sleep(0.100 * 1.5**self.repetitively_busy) + except Exception as e: + logger.exception(e) + if self.batched_stream and LOG_PDB: + import pdb - await self.query_who_has(*to_gather_keys) + pdb.set_trace() + msg = error_message(e) + for k in self.in_flight_workers[worker]: + ts = self.tasks[k] + recommendations[ts] = tuple(msg.values()) + raise + finally: + self.comm_nbytes -= total_nbytes + busy = response.get("status", "") == "busy" + data = response.get("data", {}) + + if busy: + self.log.append( + ( + "busy-gather", + worker, + to_gather_keys, + STIMULUS_ID.get(), + time(), + ) + ) - self.ensure_communicating() + for d in self.in_flight_workers.pop(worker): + ts = self.tasks[d] + ts.done = True + if d in cancelled_keys: + if ts.state == "cancelled": + recommendations[ts] = "released" + else: + recommendations[ts] = "fetch" + elif d in data: + recommendations[ts] = ("memory", data[d]) + elif busy: + recommendations[ts] = "fetch" + elif ts not in recommendations: + ts.who_has.discard(worker) + self.has_what[worker].discard(ts.key) + self.log.append( + (d, "missing-dep", STIMULUS_ID.get(), time()) + ) + self.batched_stream.send( + { + "op": "missing-data", + "errant_worker": worker, + "key": d, + "stimulus_id": STIMULUS_ID.get(), + } + ) + recommendations[ts] = "fetch" if ts.who_has else "missing" + del data, response + self.transitions(recommendations) + self.ensure_computing() + + if not busy: + self.repetitively_busy = 0 + else: + # Exponential backoff to avoid hammering scheduler/worker + self.repetitively_busy += 1 + await asyncio.sleep(0.100 * 1.5**self.repetitively_busy) + + await self.query_who_has(*to_gather_keys) + + self.ensure_communicating() async def find_missing(self) -> None: - with log_errors(): + with log_errors(), set_default_stimulus(f"find-missing-{time()}"): if not self._missing_dep_flight: return try: @@ -3074,7 +3149,6 @@ async def find_missing(self) -> None: for ts in self._missing_dep_flight: assert 
not ts.who_has - stimulus_id = f"find-missing-{time()}" who_has = await retry_operation( self.scheduler.who_has, keys=[ts.key for ts in self._missing_dep_flight], @@ -3085,7 +3159,7 @@ async def find_missing(self) -> None: for ts in self._missing_dep_flight: if ts.who_has: recommendations[ts] = "fetch" - self.transitions(recommendations, stimulus_id=stimulus_id) + self.transitions(recommendations) finally: # This is quite arbitrary but the heartbeat has scaling implemented @@ -3134,49 +3208,52 @@ def handle_steal_request(self, key: str, stimulus_id: str) -> None: # There may be a race condition between stealing and releasing a task. # In this case the self.tasks is already cleared. The `None` will be # registered as `already-computing` on the other end - ts = self.tasks.get(key) - state = ts.state if ts is not None else None + with set_default_stimulus(stimulus_id): + ts = self.tasks.get(key) + state = ts.state if ts is not None else None - response = { - "op": "steal-response", - "key": key, - "state": state, - "stimulus_id": stimulus_id, - } - self.batched_stream.send(response) + response = { + "op": "steal-response", + "key": key, + "state": state, + "stimulus_id": stimulus_id, + } + self.batched_stream.send(response) - if state in READY | {"waiting"}: - assert ts - # If task is marked as "constrained" we haven't yet assigned it an - # `available_resources` to run on, that happens in - # `transition_constrained_executing` - self.transition(ts, "released", stimulus_id=stimulus_id) + if state in READY | {"waiting"}: + assert ts + # If task is marked as "constrained" we haven't yet assigned it an + # `available_resources` to run on, that happens in + # `transition_constrained_executing` + self.transition(ts, "released") def handle_worker_status_change(self, status: str, stimulus_id: str) -> None: - new_status = Status.lookup[status] # type: ignore + with set_default_stimulus(stimulus_id): + new_status = Status.lookup[status] # type: ignore - if ( - new_status == Status.closing_gracefully - and self._status not in Status.ANY_RUNNING # type: ignore - ): - logger.error( - "Invalid Worker.status transition: %s -> %s", self._status, new_status - ) - # Reiterate the current status to the scheduler to restore sync - self._send_worker_status_change(stimulus_id) - else: - # Update status and send confirmation to the Scheduler (see status.setter) - self.status = new_status + if ( + new_status == Status.closing_gracefully + and self._status not in Status.ANY_RUNNING # type: ignore + ): + logger.error( + "Invalid Worker.status transition: %s -> %s", + self._status, + new_status, + ) + # Reiterate the current status to the scheduler to restore sync + self._send_worker_status_change() + else: + # Update status and send confirmation to the Scheduler (see status.setter) + self.status = new_status def release_key( self, key: str, cause: TaskState | None = None, report: bool = True, - *, - stimulus_id: str, ) -> None: try: + stimulus_id = STIMULUS_ID.get() if self.validate: assert not isinstance(key, TaskState) ts = self.tasks[key] @@ -3355,7 +3432,8 @@ def meets_resource_constraints(self, key: str) -> bool: return True async def _maybe_deserialize_task( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, ) -> tuple[Callable, tuple, dict[str, Any]] | None: if ts.run_spec is None: return None @@ -3375,14 +3453,13 @@ async def _maybe_deserialize_task( return function, args, kwargs except Exception as e: logger.error("Could not deserialize task", exc_info=True) - self.log.append((ts.key, 
"deserialize-error", stimulus_id, time())) + self.log.append((ts.key, "deserialize-error", STIMULUS_ID.get(), time())) emsg = error_message(e) del emsg["status"] # type: ignore self.transition( ts, "error", **emsg, - stimulus_id=stimulus_id, ) raise @@ -3390,7 +3467,7 @@ def ensure_computing(self) -> None: if self.status in (Status.paused, Status.closing_gracefully): return try: - stimulus_id = f"ensure-computing-{time()}" + STIMULUS_ID.set(f"ensure-computing-{time()}") while self.constrained and self.executing_count < self.nthreads: key = self.constrained[0] ts = self.tasks.get(key, None) @@ -3399,7 +3476,7 @@ def ensure_computing(self) -> None: continue if self.meets_resource_constraints(key): self.constrained.popleft() - self.transition(ts, "executing", stimulus_id=stimulus_id) + self.transition(ts, "executing") else: break while self.ready and self.executing_count < self.nthreads: @@ -3412,9 +3489,9 @@ def ensure_computing(self) -> None: # continue through the heap continue elif ts.key in self.data: - self.transition(ts, "memory", stimulus_id=stimulus_id) + self.transition(ts, "memory") elif ts.state in READY: - self.transition(ts, "executing", stimulus_id=stimulus_id) + self.transition(ts, "executing") except Exception as e: # pragma: no cover logger.exception(e) if LOG_PDB: @@ -3423,7 +3500,10 @@ def ensure_computing(self) -> None: pdb.set_trace() raise - async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None: + async def execute( + self, + key: str, + ) -> StateMachineEvent | None: if self.status in {Status.closing, Status.closed, Status.closing_gracefully}: return None ts = self.tasks.get(key) @@ -3434,7 +3514,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No "Trying to execute task %s which is not in executing state anymore", ts, ) - return AlreadyCancelledEvent(key=ts.key, stimulus_id=stimulus_id) + return AlreadyCancelledEvent(key=ts.key) try: if self.validate: @@ -3443,7 +3523,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No assert ts.run_spec is not None function, args, kwargs = await self._maybe_deserialize_task( # type: ignore - ts, stimulus_id=stimulus_id + ts ) args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs) @@ -3497,46 +3577,49 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No self.threads[key] = result["thread"] - if result["op"] == "task-finished": - if self.digests is not None: - self.digests["task-duration"].add(result["stop"] - result["start"]) - new_stimulus_id = f"{result['op']}-{time()}" - return ExecuteSuccessEvent( + token = STIMULUS_ID.set(f"{result['op']}-{time()}") + try: + if result["op"] == "task-finished": + if self.digests is not None: + self.digests["task-duration"].add( + result["stop"] - result["start"] + ) + return ExecuteSuccessEvent( + key=key, + value=result["result"], + start=result["start"], + stop=result["stop"], + nbytes=result["nbytes"], + type=result["type"], + ) + if isinstance(result["actual-exception"], Reschedule): + return RescheduleEvent(key=ts.key) + + logger.warning( + "Compute Failed\n" + "Key: %s\n" + "Function: %s\n" + "args: %s\n" + "kwargs: %s\n" + "Exception: %r\n", + key, + str(funcname(function))[:1000], + convert_args_to_str(args2, max_len=1000), + convert_kwargs_to_str(kwargs2, max_len=1000), + result["exception_text"], + ) + return ExecuteFailureEvent( key=key, - value=result["result"], start=result["start"], stop=result["stop"], - nbytes=result["nbytes"], - 
type=result["type"], - stimulus_id=new_stimulus_id, + exception=result["exception"], + traceback=result["traceback"], + exception_text=result["exception_text"], + traceback_text=result["traceback_text"], ) - - if isinstance(result["actual-exception"], Reschedule): - return RescheduleEvent(key=ts.key, stimulus_id=stimulus_id) - - logger.warning( - "Compute Failed\n" - "Key: %s\n" - "Function: %s\n" - "args: %s\n" - "kwargs: %s\n" - "Exception: %r\n", - key, - str(funcname(function))[:1000], - convert_args_to_str(args2, max_len=1000), - convert_kwargs_to_str(kwargs2, max_len=1000), - result["exception_text"], - ) - return ExecuteFailureEvent( - key=key, - start=result["start"], - stop=result["stop"], - exception=result["exception"], - traceback=result["traceback"], - exception_text=result["exception_text"], - traceback_text=result["traceback_text"], - stimulus_id=f"task-erred-{time()}", - ) + except Exception: + STIMULUS_ID.reset(token) + raise except Exception as exc: logger.error("Exception during execution of task %s.", key, exc_info=True) @@ -3563,7 +3646,7 @@ def _(self, ev: CancelComputeEvent) -> RecsInstrs: if not ts or ts.state not in READY | {"waiting"}: return {}, [] - self.log.append((ev.key, "cancel-compute", ev.stimulus_id, time())) + self.log.append((ev.key, "cancel-compute", STIMULUS_ID.get(), time())) # All possible dependents of ts should not be in state Processing on # scheduler side and therefore should not be assigned to a worker, yet. assert not ts.dependents @@ -4190,12 +4273,15 @@ def secede(): worker = get_worker() tpe_secede() # have this thread secede from the thread pool duration = time() - thread_state.start_time - worker.loop.add_callback( - worker.maybe_transition_long_running, - worker.tasks[thread_state.key], - compute_duration=duration, - stimulus_id=f"secede-{thread_state.key}-{time()}", - ) + token = STIMULUS_ID.set(f"secede-{thread_state.key}-{time()}") + try: + worker.loop.add_callback( + worker.maybe_transition_long_running, + worker.tasks[thread_state.key], + compute_duration=duration, + ) + finally: + STIMULUS_ID.reset(token) class Reschedule(Exception): diff --git a/distributed/worker_client.py b/distributed/worker_client.py index 5a775b38191..52a77a58890 100644 --- a/distributed/worker_client.py +++ b/distributed/worker_client.py @@ -5,6 +5,7 @@ from distributed.metrics import time from distributed.threadpoolexecutor import rejoin, secede +from distributed.utils import set_default_stimulus from distributed.worker import get_client, get_worker, thread_state @@ -53,13 +54,13 @@ def worker_client(timeout=None, separate_thread=True): if separate_thread: duration = time() - thread_state.start_time secede() # have this thread secede from the thread pool - worker.loop.add_callback( - worker.transition, - worker.tasks[thread_state.key], - "long-running", - stimulus_id=f"worker-client-secede-{time()}", - compute_duration=duration, - ) + with set_default_stimulus(f"worker-client-secede-{time()}"): + worker.loop.add_callback( + worker.transition, + worker.tasks[thread_state.key], + "long-running", + compute_duration=duration, + ) yield client diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 966be0a3527..1b468bafe5e 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -12,7 +12,7 @@ from dask.utils import parse_bytes from distributed.protocol.serialize import Serialize -from distributed.utils import recursive_to_dict +from distributed.utils import STIMULUS_ID, recursive_to_dict 
if TYPE_CHECKING: # TODO move to typing and get out of TYPE_CHECKING (requires Python >=3.10) @@ -69,6 +69,10 @@ class InvalidTransition(Exception): pass +def stimulus_id_factory() -> str: + return STIMULUS_ID.get() + + @lru_cache def _default_data_size() -> int: return parse_bytes(dask.config.get("distributed.scheduler.default-data-size")) @@ -260,15 +264,14 @@ class Instruction: @dataclass class Execute(Instruction): - __slots__ = ("key", "stimulus_id") key: str - stimulus_id: str + stimulus_id: str = field(default_factory=stimulus_id_factory) class SendMessageToScheduler(Instruction): - __slots__ = () #: Matches a key in Scheduler.stream_handlers op: ClassVar[str] + stimulus_id: str = field(default_factory=stimulus_id_factory) def to_dict(self) -> dict[str, Any]: """Convert object to dict so that it can be serialized with msgpack""" @@ -288,8 +291,8 @@ class TaskFinishedMsg(SendMessageToScheduler): metadata: dict thread: int | None startstops: list[StartStop] - stimulus_id: str - __slots__ = tuple(__annotations__) # type: ignore + stimulus_id: str = field(default_factory=stimulus_id_factory) + # __slots__ = tuple(__annotations__) # type: ignore def to_dict(self) -> dict[str, Any]: d = super().to_dict() @@ -308,8 +311,7 @@ class TaskErredMsg(SendMessageToScheduler): traceback_text: str thread: int | None startstops: list[StartStop] - stimulus_id: str - __slots__ = tuple(__annotations__) # type: ignore + stimulus_id: str = field(default_factory=stimulus_id_factory) def to_dict(self) -> dict[str, Any]: d = super().to_dict() @@ -321,9 +323,8 @@ def to_dict(self) -> dict[str, Any]: class ReleaseWorkerDataMsg(SendMessageToScheduler): op = "release-worker-data" - __slots__ = ("key", "stimulus_id") key: str - stimulus_id: str + stimulus_id: str = field(default_factory=stimulus_id_factory) # Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception @@ -332,35 +333,34 @@ class RescheduleMsg(SendMessageToScheduler): op = "reschedule" # Not to be confused with the distributed.Reschedule Exception - __slots__ = ("key", "worker", "stimulus_id") + __slots__ = ("key", "worker") key: str worker: str - stimulus_id: str + stimulus_id: str = field(default_factory=stimulus_id_factory) @dataclass class LongRunningMsg(SendMessageToScheduler): op = "long-running" - __slots__ = ("key", "compute_duration", "stimulus_id") + __slots__ = ("key", "compute_duration") key: str compute_duration: float - stimulus_id: str + stimulus_id: str = field(default_factory=stimulus_id_factory) @dataclass class AddKeysMsg(SendMessageToScheduler): op = "add-keys" - __slots__ = ("keys", "stimulus_id") + __slots__ = "keys" keys: list[str] - stimulus_id: str -@dataclass +@dataclass() class StateMachineEvent: - __slots__ = ("stimulus_id",) - stimulus_id: str + key: str + # stimulus_id: str @dataclass @@ -371,8 +371,8 @@ class ExecuteSuccessEvent(StateMachineEvent): stop: float nbytes: int type: type | None - stimulus_id: str - __slots__ = tuple(__annotations__) # type: ignore + stimulus_id: str = field(default_factory=stimulus_id_factory) + # __slots__ = tuple(__annotations__) # type: ignore @dataclass @@ -384,20 +384,22 @@ class ExecuteFailureEvent(StateMachineEvent): traceback: Serialize | None exception_text: str traceback_text: str - stimulus_id: str - __slots__ = tuple(__annotations__) # type: ignore + stimulus_id: str = field(default_factory=stimulus_id_factory) + # __slots__ = tuple(__annotations__) # type: ignore @dataclass class CancelComputeEvent(StateMachineEvent): __slots__ = ("key",) key: str + 
stimulus_id: str = field(default_factory=stimulus_id_factory) @dataclass class AlreadyCancelledEvent(StateMachineEvent): __slots__ = ("key",) key: str + stimulus_id: str = field(default_factory=stimulus_id_factory) # Not to be confused with RescheduleMsg above or the distributed.Reschedule Exception @@ -405,6 +407,7 @@ class AlreadyCancelledEvent(StateMachineEvent): class RescheduleEvent(StateMachineEvent): __slots__ = ("key",) key: str + stimulus_id: str = field(default_factory=stimulus_id_factory) if TYPE_CHECKING: From 93f8aa75329b987d83e27c6d32566253464115aa Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 5 Apr 2022 15:32:22 +0200 Subject: [PATCH 28/29] use ctx.run for gather_dep --- distributed/worker.py | 278 +++++++++++++++++++++--------------------- 1 file changed, 137 insertions(+), 141 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index 46d87fcea44..ceeb70f7477 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -25,6 +25,7 @@ ) from concurrent.futures import Executor from contextlib import suppress +from contextvars import copy_context from datetime import timedelta from inspect import isawaitable from pickle import PicklingError @@ -2737,14 +2738,15 @@ def ensure_communicating(self) -> None: self.tasks[d]: ("flight", worker) for d in to_gather } self.transitions(recommendations) - + ctx = copy_context() self.loop.add_callback( - self.gather_dep, - worker=worker, - to_gather=to_gather, - total_nbytes=total_nbytes, - # add_callback is not ctx sensitive - stimulus_id=STIMULUS_ID.get(), + ctx.run, + functools.partial( + self.gather_dep, + worker=worker, + to_gather=to_gather, + total_nbytes=total_nbytes, + ), ) for el in skipped_worker_in_flight: @@ -2965,7 +2967,6 @@ async def gather_dep( worker: str, to_gather: Iterable[str], total_nbytes: int, - stimulus_id: str, ) -> None: """Gather dependencies for a task from a worker who has them @@ -2980,165 +2981,160 @@ async def gather_dep( total_nbytes : int Total number of bytes for all the dependencies in to_gather combined """ - with set_default_stimulus(stimulus_id): - if self.status not in Status.ANY_RUNNING: # type: ignore - return - - recommendations: Recs = {} - with log_errors(): - response = {} - to_gather_keys: set[str] = set() - cancelled_keys: set[str] = set() - try: - to_gather_keys, cancelled_keys, cause = self._filter_deps_for_fetch( - to_gather - ) - - if not to_gather_keys: - self.log.append( - ( - "nothing-to-gather", - worker, - to_gather, - STIMULUS_ID.get(), - time(), - ) - ) - return + if self.status not in Status.ANY_RUNNING: # type: ignore + return - assert cause - # Keep namespace clean since this func is long and has many - # dep*, *ts* variables - del to_gather + recommendations: Recs = {} + with log_errors(): + response = {} + to_gather_keys: set[str] = set() + cancelled_keys: set[str] = set() + try: + to_gather_keys, cancelled_keys, cause = self._filter_deps_for_fetch( + to_gather + ) + if not to_gather_keys: self.log.append( ( - "request-dep", + "nothing-to-gather", worker, - to_gather_keys, + to_gather, STIMULUS_ID.get(), time(), ) ) - logger.debug( - "Request %d keys for task %s from %s", - len(to_gather_keys), - cause, + return + + assert cause + # Keep namespace clean since this func is long and has many + # dep*, *ts* variables + del to_gather + + self.log.append( + ( + "request-dep", worker, + to_gather_keys, + STIMULUS_ID.get(), + time(), ) + ) + logger.debug( + "Request %d keys for task %s from %s", + len(to_gather_keys), + cause, + worker, + ) - start = 
time() - response = await get_data_from_worker( - self.rpc, to_gather_keys, worker, who=self.address + start = time() + response = await get_data_from_worker( + self.rpc, to_gather_keys, worker, who=self.address + ) + stop = time() + if response["status"] == "busy": + return + + self._update_metrics_received_data( + start=start, + stop=stop, + data=response["data"], + cause=cause, + worker=worker, + ) + self.log.append( + ( + "receive-dep", + worker, + set(response["data"]), + STIMULUS_ID.get(), + time(), ) - stop = time() - if response["status"] == "busy": - return - - self._update_metrics_received_data( - start=start, - stop=stop, - data=response["data"], - cause=cause, - worker=worker, + ) + + except OSError: + logger.exception("Worker stream died during communication: %s", worker) + has_what = self.has_what.pop(worker) + self.pending_data_per_worker.pop(worker) + self.log.append( + ( + "receive-dep-failed", + worker, + has_what, + STIMULUS_ID.get(), + time(), ) + ) + for d in has_what: + ts = self.tasks[d] + ts.who_has.remove(worker) + + except Exception as e: + logger.exception(e) + if self.batched_stream and LOG_PDB: + import pdb + + pdb.set_trace() + msg = error_message(e) + for k in self.in_flight_workers[worker]: + ts = self.tasks[k] + recommendations[ts] = tuple(msg.values()) + raise + finally: + self.comm_nbytes -= total_nbytes + busy = response.get("status", "") == "busy" + data = response.get("data", {}) + + if busy: self.log.append( ( - "receive-dep", + "busy-gather", worker, - set(response["data"]), + to_gather_keys, STIMULUS_ID.get(), time(), ) ) - except OSError: - logger.exception( - "Worker stream died during communication: %s", worker - ) - has_what = self.has_what.pop(worker) - self.pending_data_per_worker.pop(worker) - self.log.append( - ( - "receive-dep-failed", - worker, - has_what, - STIMULUS_ID.get(), - time(), + for d in self.in_flight_workers.pop(worker): + ts = self.tasks[d] + ts.done = True + if d in cancelled_keys: + if ts.state == "cancelled": + recommendations[ts] = "released" + else: + recommendations[ts] = "fetch" + elif d in data: + recommendations[ts] = ("memory", data[d]) + elif busy: + recommendations[ts] = "fetch" + elif ts not in recommendations: + ts.who_has.discard(worker) + self.has_what[worker].discard(ts.key) + self.log.append((d, "missing-dep", STIMULUS_ID.get(), time())) + self.batched_stream.send( + { + "op": "missing-data", + "errant_worker": worker, + "key": d, + "stimulus_id": STIMULUS_ID.get(), + } ) - ) - for d in has_what: - ts = self.tasks[d] - ts.who_has.remove(worker) + recommendations[ts] = "fetch" if ts.who_has else "missing" + del data, response + self.transitions(recommendations) + self.ensure_computing() - except Exception as e: - logger.exception(e) - if self.batched_stream and LOG_PDB: - import pdb + if not busy: + self.repetitively_busy = 0 + else: + # Exponential backoff to avoid hammering scheduler/worker + self.repetitively_busy += 1 + await asyncio.sleep(0.100 * 1.5**self.repetitively_busy) - pdb.set_trace() - msg = error_message(e) - for k in self.in_flight_workers[worker]: - ts = self.tasks[k] - recommendations[ts] = tuple(msg.values()) - raise - finally: - self.comm_nbytes -= total_nbytes - busy = response.get("status", "") == "busy" - data = response.get("data", {}) - - if busy: - self.log.append( - ( - "busy-gather", - worker, - to_gather_keys, - STIMULUS_ID.get(), - time(), - ) - ) + await self.query_who_has(*to_gather_keys) - for d in self.in_flight_workers.pop(worker): - ts = self.tasks[d] - ts.done = True - if d 
in cancelled_keys: - if ts.state == "cancelled": - recommendations[ts] = "released" - else: - recommendations[ts] = "fetch" - elif d in data: - recommendations[ts] = ("memory", data[d]) - elif busy: - recommendations[ts] = "fetch" - elif ts not in recommendations: - ts.who_has.discard(worker) - self.has_what[worker].discard(ts.key) - self.log.append( - (d, "missing-dep", STIMULUS_ID.get(), time()) - ) - self.batched_stream.send( - { - "op": "missing-data", - "errant_worker": worker, - "key": d, - "stimulus_id": STIMULUS_ID.get(), - } - ) - recommendations[ts] = "fetch" if ts.who_has else "missing" - del data, response - self.transitions(recommendations) - self.ensure_computing() - - if not busy: - self.repetitively_busy = 0 - else: - # Exponential backoff to avoid hammering scheduler/worker - self.repetitively_busy += 1 - await asyncio.sleep(0.100 * 1.5**self.repetitively_busy) - - await self.query_who_has(*to_gather_keys) - - self.ensure_communicating() + self.ensure_communicating() async def find_missing(self) -> None: with log_errors(), set_default_stimulus(f"find-missing-{time()}"): From c780413dc8db414278dced38a4f857eeb2f9a289 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 5 Apr 2022 16:27:13 +0200 Subject: [PATCH 29/29] add reason kwarg --- distributed/tests/test_worker_state_machine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index db89ad3db32..214f6dba251 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -86,7 +86,7 @@ def test_unique_task_heap(): assert repr(heap) == "" -@pytest.mark.xfail("slots not compatible with defaultfactory") +@pytest.mark.xfail(reason="slots not compatible with defaultfactory") @pytest.mark.parametrize( "cls", chain(