diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index de3bf2423b9..0646d7cfc9d 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -7,6 +7,7 @@
 import sys
 import threading
 import traceback
+import weakref
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from concurrent.futures.process import BrokenProcessPool
 from numbers import Number
@@ -39,6 +40,7 @@
 from distributed.diagnostics import nvml
 from distributed.diagnostics.plugin import PipInstall
 from distributed.metrics import time
+from distributed.profile import wait_profiler
 from distributed.protocol import pickle
 from distributed.scheduler import Scheduler
 from distributed.utils import TimeoutError
@@ -1715,6 +1717,35 @@ async def test_story_with_deps(c, s, a, b):
     assert_worker_story(story, expected, strict=True)
 
 
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_stimulus_story(c, s, a):
+    class C:
+        pass
+
+    f = c.submit(C, key="f")  # Test that substrings aren't matched by story()
+    f2 = c.submit(inc, 2, key="f2")
+    f3 = c.submit(inc, 3, key="f3")
+    await wait([f, f2, f3])
+
+    # Test that ExecuteSuccessEvent.value is not stored in the event log
+    assert isinstance(a.data["f"], C)
+    ref = weakref.ref(a.data["f"])
+    del f
+    while "f" in a.data:
+        await asyncio.sleep(0.01)
+    wait_profiler()
+    assert ref() is None
+
+    story = a.stimulus_story("f", "f2")
+    assert {ev.key for ev in story} == {"f", "f2"}
+    assert {ev.type for ev in story} == {C, int}
+
+    prev_handled = story[0].handled
+    for ev in story[1:]:
+        assert ev.handled >= prev_handled
+        prev_handled = ev.handled
+
+
 @gen_cluster(client=True)
 async def test_gather_dep_one_worker_always_busy(c, s, a, b):
     # Ensure that both dependencies for H are on another worker than H itself.
@@ -3303,6 +3334,7 @@ async def test_Worker__to_dict(c, s, a):
         "in_flight_tasks",
         "in_flight_workers",
         "log",
+        "stimulus_log",
         "tasks",
         "logs",
         "config",
diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py
index 87828a17cf7..d8337ace8e5 100644
--- a/distributed/tests/test_worker_state_machine.py
+++ b/distributed/tests/test_worker_state_machine.py
@@ -2,10 +2,14 @@
 
 import pytest
 
+from distributed.protocol.serialize import Serialize
 from distributed.utils import recursive_to_dict
 from distributed.worker_state_machine import (
+    ExecuteFailureEvent,
+    ExecuteSuccessEvent,
     Instruction,
     ReleaseWorkerDataMsg,
+    RescheduleEvent,
     RescheduleMsg,
     SendMessageToScheduler,
     StateMachineEvent,
@@ -101,7 +105,9 @@ def test_slots(cls):
     params = [
         k
         for k in dir(cls)
-        if not k.startswith("_") and k != "op" and not callable(getattr(cls, k))
+        if not k.startswith("_")
+        and k not in ("op", "handled")
+        and not callable(getattr(cls, k))
     ]
     inst = cls(**dict.fromkeys(params))
     assert not hasattr(inst, "__dict__")
@@ -133,3 +139,96 @@ def test_merge_recs_instructions():
     )
     with pytest.raises(ValueError):
         merge_recs_instructions(({x: "memory"}, []), ({x: "released"}, []))
+
+
+def test_event_to_dict():
+    ev = RescheduleEvent(stimulus_id="test", key="x")
+    ev2 = ev.to_loggable(handled=11.22)
+    assert ev2 == ev
+    d = recursive_to_dict(ev2)
+    assert d == {
+        "cls": "RescheduleEvent",
+        "stimulus_id": "test",
+        "handled": 11.22,
+        "key": "x",
+    }
+    ev3 = StateMachineEvent.from_dict(d)
+    assert ev3 == ev
+
+
+def test_executesuccess_to_dict():
+    """The potentially very large ExecuteSuccessEvent.value is not stored in the log"""
+    ev = ExecuteSuccessEvent(
+        stimulus_id="test",
+        key="x",
+        value=123,
+        start=123.4,
+        stop=456.7,
+        nbytes=890,
+        type=int,
+    )
+    ev2 = ev.to_loggable(handled=11.22)
+    assert ev2.value is None
+    assert ev.value == 123
+    d = recursive_to_dict(ev2)
+    assert d == {
+        "cls": "ExecuteSuccessEvent",
+        "stimulus_id": "test",
+        "handled": 11.22,
+        "key": "x",
+        "value": None,
+        "nbytes": 890,
+        "start": 123.4,
+        "stop": 456.7,
+        "type": "<class 'int'>",
+    }
+    ev3 = StateMachineEvent.from_dict(d)
+    assert isinstance(ev3, ExecuteSuccessEvent)
+    assert ev3.stimulus_id == "test"
+    assert ev3.handled == 11.22
+    assert ev3.key == "x"
+    assert ev3.value is None
+    assert ev3.start == 123.4
+    assert ev3.stop == 456.7
+    assert ev3.nbytes == 890
+    assert ev3.type is None
+
+
+def test_executefailure_to_dict():
+    ev = ExecuteFailureEvent(
+        stimulus_id="test",
+        key="x",
+        start=123.4,
+        stop=456.7,
+        exception=Serialize(ValueError("foo")),
+        traceback=Serialize("lose me"),
+        exception_text="exc text",
+        traceback_text="tb text",
+    )
+    ev2 = ev.to_loggable(handled=11.22)
+    assert ev2 == ev
+    d = recursive_to_dict(ev2)
+    assert d == {
+        "cls": "ExecuteFailureEvent",
+        "stimulus_id": "test",
+        "handled": 11.22,
+        "key": "x",
+        "start": 123.4,
+        "stop": 456.7,
+        "exception": "<Serialize: foo>",
+        "traceback": "<Serialize: lose me>",
+        "exception_text": "exc text",
+        "traceback_text": "tb text",
+    }
+    ev3 = StateMachineEvent.from_dict(d)
+    assert isinstance(ev3, ExecuteFailureEvent)
+    assert ev3.stimulus_id == "test"
+    assert ev3.handled == 11.22
+    assert ev3.key == "x"
+    assert ev3.start == 123.4
+    assert ev3.stop == 456.7
+    assert isinstance(ev3.exception, Serialize)
+    assert isinstance(ev3.exception.data, Exception)
+    assert ev3.traceback is None
+    assert ev3.exception_text == "exc text"
+    assert ev3.traceback_text == "tb text"
diff --git a/distributed/worker.py b/distributed/worker.py
index 7c5bc61ca15..e4241ccb3be 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -364,6 +364,7 @@ class Worker(ServerNode):
     executed_count: int
     long_running: set[str]
     log: deque[tuple]  # [(..., stimulus_id: str | None, timestamp: float), ...]
+    stimulus_log: deque[StateMachineEvent]
     incoming_transfer_log: deque[dict[str, Any]]
     outgoing_transfer_log: deque[dict[str, Any]]
     target_message_size: int
@@ -519,7 +520,8 @@ def __init__(
 
         self.target_message_size = int(50e6)  # 50 MB
 
-        self.log = deque(maxlen=100000)
+        self.log = deque(maxlen=100_000)
+        self.stimulus_log = deque(maxlen=10_000)
         if validate is None:
             validate = dask.config.get("distributed.scheduler.validate")
         self.validate = validate
@@ -1020,6 +1022,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
             "in_flight_tasks": self.in_flight_tasks,
             "in_flight_workers": self.in_flight_workers,
             "log": self.log,
+            "stimulus_log": self.stimulus_log,
             "tasks": self.tasks,
             "logs": self.get_logs(),
             "config": dask.config.config,
@@ -2598,7 +2601,7 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None:
 
     def handle_stimulus(self, stim: StateMachineEvent) -> None:
         with log_errors():
-            # self.stimulus_history.append(stim)  # TODO
+            self.stimulus_log.append(stim.to_loggable(handled=time()))
             recs, instructions = self.handle_event(stim)
             self.transitions(recs, stimulus_id=stim.stimulus_id)
             self._handle_instructions(instructions)
@@ -2648,9 +2651,17 @@ def stateof(self, key: str) -> dict[str, Any]:
         }
 
     def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]:
+        """Return all transitions involving one or more tasks"""
        keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks}
         return worker_story(keys, self.log)
 
+    def stimulus_story(
+        self, *keys_or_tasks: str | TaskState
+    ) -> list[StateMachineEvent]:
+        """Return all state machine events involving one or more tasks"""
+        keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks}
+        return [ev for ev in self.stimulus_log if getattr(ev, "key", None) in keys]
+
     def ensure_communicating(self) -> None:
         stimulus_id = f"ensure-communicating-{time()}"
         skipped_worker_in_flight = []
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index a21c2acb301..5e993fe5041 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -3,6 +3,7 @@
 import heapq
 import sys
 from collections.abc import Callable, Container, Iterator
+from copy import copy
 from dataclasses import dataclass, field
 from functools import lru_cache
 from typing import Collection  # TODO move to collections.abc (requires Python >=3.9)
@@ -353,8 +354,61 @@ class AddKeysMsg(SendMessageToScheduler):
 
 @dataclass
 class StateMachineEvent:
-    __slots__ = ("stimulus_id",)
+    __slots__ = ("stimulus_id", "handled")
     stimulus_id: str
+    #: timestamp of when the event was handled by the worker
+    # TODO Switch to @dataclass(slots=True), uncomment the line below, and remove the
+    # __new__ method (requires Python >=3.10)
+    # handled: float | None = field(init=False, default=None)
+    _classes: ClassVar[dict[str, type[StateMachineEvent]]] = {}
+
+    def __new__(cls, *args, **kwargs):
+        self = object.__new__(cls)
+        self.handled = None
+        return self
+
+    def __init_subclass__(cls):
+        StateMachineEvent._classes[cls.__name__] = cls
+
+    def to_loggable(self, *, handled: float) -> StateMachineEvent:
+        """Produce a variant version of self that is small enough to be stored in memory
+        in the medium term and contains meaningful information for debugging
+        """
+        self.handled = handled
+        return self
+
+    def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
+        """Dictionary representation for debugging purposes.
+
+        See also
+        --------
+        distributed.utils.recursive_to_dict
+        """
+        info = {
+            "cls": type(self).__name__,
+            "stimulus_id": self.stimulus_id,
+            "handled": self.handled,
+        }
+        info.update({k: getattr(self, k) for k in self.__annotations__})
+        info = {k: v for k, v in info.items() if k not in exclude}
+        return recursive_to_dict(info, exclude=exclude)
+
+    @staticmethod
+    def from_dict(d: dict) -> StateMachineEvent:
+        """Convert the output of ``recursive_to_dict`` back into the original object.
+        The output object is meaningful for the purpose of rebuilding the state machine,
+        but not necessarily identical to the original.
+        """
+        kwargs = d.copy()
+        cls = StateMachineEvent._classes[kwargs.pop("cls")]
+        handled = kwargs.pop("handled")
+        inst = cls(**kwargs)
+        inst.handled = handled
+        inst._after_from_dict()
+        return inst
+
+    def _after_from_dict(self) -> None:
+        """Optional post-processing after an instance is created by ``from_dict``"""
 
 
 @dataclass
@@ -372,6 +426,16 @@ class ExecuteSuccessEvent(StateMachineEvent):
     type: type | None
     __slots__ = tuple(__annotations__)  # type: ignore
 
+    def to_loggable(self, *, handled: float) -> StateMachineEvent:
+        out = copy(self)
+        out.handled = handled
+        out.value = None
+        return out
+
+    def _after_from_dict(self) -> None:
+        self.value = None
+        self.type = None
+
 
 @dataclass
 class ExecuteFailureEvent(StateMachineEvent):
@@ -384,6 +448,10 @@ class ExecuteFailureEvent(StateMachineEvent):
     traceback_text: str
     __slots__ = tuple(__annotations__)  # type: ignore
 
+    def _after_from_dict(self) -> None:
+        self.exception = Serialize(Exception())
+        self.traceback = None
+
 
 @dataclass
 class CancelComputeEvent(StateMachineEvent):
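
Note (not part of the patch): a rough usage sketch of the `Worker.stimulus_log` / `Worker.stimulus_story` API added in `distributed/worker.py` above, mirroring `test_stimulus_story`. The test name, the key `"x"`, and the print format are illustrative only.

```python
from distributed import wait
from distributed.utils_test import gen_cluster, inc


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_inspect_stimulus_log(c, s, a):
    fut = c.submit(inc, 1, key="x")
    await wait(fut)

    # Worker.stimulus_log keeps the most recent StateMachineEvents (maxlen=10_000);
    # stimulus_story() filters them down to the events touching the given keys.
    for ev in a.stimulus_story("x"):
        print(type(ev).__name__, ev.stimulus_id, ev.handled)
```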
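Note (not part of the patch): a minimal sketch of the serialization round-trip introduced in `distributed/worker_state_machine.py`, mirroring `test_event_to_dict` above. `to_loggable` stamps the `handled` timestamp, `recursive_to_dict` produces a JSON-friendly dict (e.g. for a cluster dump), and `StateMachineEvent.from_dict` rebuilds an equivalent event.

```python
from distributed.utils import recursive_to_dict
from distributed.worker_state_machine import RescheduleEvent, StateMachineEvent

ev = RescheduleEvent(stimulus_id="example", key="x")
logged = ev.to_loggable(handled=123.0)  # record when the worker handled the event
payload = recursive_to_dict(logged)     # plain dict, safe to dump to JSON/YAML
rebuilt = StateMachineEvent.from_dict(payload)

assert isinstance(rebuilt, RescheduleEvent)
assert rebuilt.key == "x"
assert rebuilt.handled == 123.0
```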