Merged
32 changes: 32 additions & 0 deletions distributed/tests/test_worker.py
@@ -7,6 +7,7 @@
import sys
import threading
import traceback
import weakref
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from numbers import Number
@@ -39,6 +40,7 @@
from distributed.diagnostics import nvml
from distributed.diagnostics.plugin import PipInstall
from distributed.metrics import time
from distributed.profile import wait_profiler
from distributed.protocol import pickle
from distributed.scheduler import Scheduler
from distributed.utils import TimeoutError
@@ -1715,6 +1717,35 @@ async def test_story_with_deps(c, s, a, b):
assert_worker_story(story, expected, strict=True)


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_stimulus_story(c, s, a):
class C:
pass

f = c.submit(C, key="f") # Test that substrings aren't matched by story()
f2 = c.submit(inc, 2, key="f2")
f3 = c.submit(inc, 3, key="f3")
await wait([f, f2, f3])

# Test that ExecuteSuccessEvent.value is not stored in the event log
assert isinstance(a.data["f"], C)
ref = weakref.ref(a.data["f"])
del f
while "f" in a.data:
await asyncio.sleep(0.01)
wait_profiler()
assert ref() is None

story = a.stimulus_story("f", "f2")
assert {ev.key for ev in story} == {"f", "f2"}
assert {ev.type for ev in story} == {C, int}

prev_handled = story[0].handled
for ev in story[1:]:
assert ev.handled >= prev_handled
prev_handled = ev.handled


@gen_cluster(client=True)
async def test_gather_dep_one_worker_always_busy(c, s, a, b):
# Ensure that both dependencies for H are on another worker than H itself.
@@ -3303,6 +3334,7 @@ async def test_Worker__to_dict(c, s, a):
"in_flight_tasks",
"in_flight_workers",
"log",
"stimulus_log",
"tasks",
"logs",
"config",
101 changes: 100 additions & 1 deletion distributed/tests/test_worker_state_machine.py
@@ -2,10 +2,14 @@

import pytest

from distributed.protocol.serialize import Serialize
from distributed.utils import recursive_to_dict
from distributed.worker_state_machine import (
ExecuteFailureEvent,
ExecuteSuccessEvent,
Instruction,
ReleaseWorkerDataMsg,
RescheduleEvent,
RescheduleMsg,
SendMessageToScheduler,
StateMachineEvent,
@@ -101,7 +105,9 @@ def test_slots(cls):
params = [
k
for k in dir(cls)
if not k.startswith("_") and k != "op" and not callable(getattr(cls, k))
if not k.startswith("_")
and k not in ("op", "handled")
and not callable(getattr(cls, k))
]
inst = cls(**dict.fromkeys(params))
assert not hasattr(inst, "__dict__")
@@ -133,3 +139,96 @@ def test_merge_recs_instructions():
)
with pytest.raises(ValueError):
merge_recs_instructions(({x: "memory"}, []), ({x: "released"}, []))


def test_event_to_dict():
ev = RescheduleEvent(stimulus_id="test", key="x")
ev2 = ev.to_loggable(handled=11.22)
assert ev2 == ev
d = recursive_to_dict(ev2)
assert d == {
"cls": "RescheduleEvent",
"stimulus_id": "test",
"handled": 11.22,
"key": "x",
}
ev3 = StateMachineEvent.from_dict(d)
assert ev3 == ev


def test_executesuccess_to_dict():
"""The potentially very large ExecuteSuccessEvent.value is not stored in the log"""
ev = ExecuteSuccessEvent(
stimulus_id="test",
key="x",
value=123,
start=123.4,
stop=456.7,
nbytes=890,
type=int,
)
ev2 = ev.to_loggable(handled=11.22)
assert ev2.value is None
assert ev.value == 123
d = recursive_to_dict(ev2)
assert d == {
"cls": "ExecuteSuccessEvent",
"stimulus_id": "test",
"handled": 11.22,
"key": "x",
"value": None,
"nbytes": 890,
"start": 123.4,
"stop": 456.7,
"type": "<class 'int'>",
}
ev3 = StateMachineEvent.from_dict(d)
assert isinstance(ev3, ExecuteSuccessEvent)
assert ev3.stimulus_id == "test"
assert ev3.handled == 11.22
assert ev3.key == "x"
assert ev3.value is None
assert ev3.start == 123.4
assert ev3.stop == 456.7
assert ev3.nbytes == 890
assert ev3.type is None


def test_executefailure_to_dict():
ev = ExecuteFailureEvent(
stimulus_id="test",
key="x",
start=123.4,
stop=456.7,
exception=Serialize(ValueError("foo")),
traceback=Serialize("lose me"),
exception_text="exc text",
traceback_text="tb text",
)
ev2 = ev.to_loggable(handled=11.22)
assert ev2 == ev
d = recursive_to_dict(ev2)
assert d == {
"cls": "ExecuteFailureEvent",
"stimulus_id": "test",
"handled": 11.22,
"key": "x",
"start": 123.4,
"stop": 456.7,
"exception": "<Serialize: foo>",
"traceback": "<Serialize: lose me>",
"exception_text": "exc text",
"traceback_text": "tb text",
}
ev3 = StateMachineEvent.from_dict(d)
assert isinstance(ev3, ExecuteFailureEvent)
assert ev3.stimulus_id == "test"
assert ev3.handled == 11.22
assert ev3.key == "x"
assert ev3.start == 123.4
assert ev3.stop == 456.7
assert isinstance(ev3.exception, Serialize)
assert isinstance(ev3.exception.data, Exception)
assert ev3.traceback is None
assert ev3.exception_text == "exc text"
assert ev3.traceback_text == "tb text"
15 changes: 13 additions & 2 deletions distributed/worker.py
@@ -364,6 +364,7 @@ class Worker(ServerNode):
executed_count: int
long_running: set[str]
log: deque[tuple] # [(..., stimulus_id: str | None, timestamp: float), ...]
stimulus_log: deque[StateMachineEvent]
incoming_transfer_log: deque[dict[str, Any]]
outgoing_transfer_log: deque[dict[str, Any]]
target_message_size: int
@@ -519,7 +520,8 @@ def __init__(

self.target_message_size = int(50e6) # 50 MB

self.log = deque(maxlen=100000)
self.log = deque(maxlen=100_000)
self.stimulus_log = deque(maxlen=10_000)
if validate is None:
validate = dask.config.get("distributed.scheduler.validate")
self.validate = validate
@@ -1020,6 +1022,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
"in_flight_tasks": self.in_flight_tasks,
"in_flight_workers": self.in_flight_workers,
"log": self.log,
"stimulus_log": self.stimulus_log,
"tasks": self.tasks,
"logs": self.get_logs(),
"config": dask.config.config,
@@ -2598,7 +2601,7 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None:

def handle_stimulus(self, stim: StateMachineEvent) -> None:
with log_errors():
# self.stimulus_history.append(stim) # TODO
self.stimulus_log.append(stim.to_loggable(handled=time()))
recs, instructions = self.handle_event(stim)
self.transitions(recs, stimulus_id=stim.stimulus_id)
self._handle_instructions(instructions)
@@ -2648,9 +2651,17 @@ def stateof(self, key: str) -> dict[str, Any]:
}

def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]:
"""Return all transitions involving one or more tasks"""
keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks}
return worker_story(keys, self.log)

def stimulus_story(
self, *keys_or_tasks: str | TaskState
) -> list[StateMachineEvent]:
"""Return all state machine events involving one or more tasks"""
keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks}
return [ev for ev in self.stimulus_log if getattr(ev, "key", None) in keys]
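A brief hedged usage sketch of the new method (not part of the diff; w is a hypothetical live Worker that has already executed tasks "x" and "y"): stimulus_story filters stimulus_log down to the events carrying one of the requested keys, such as the ExecuteSuccessEvent recorded when a task finished.

for ev in w.stimulus_story("x", "y"):
    # Each entry is a StateMachineEvent subclass with a key and a handled timestamp
    print(type(ev).__name__, ev.key, ev.handled)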

def ensure_communicating(self) -> None:
stimulus_id = f"ensure-communicating-{time()}"
skipped_worker_in_flight = []
70 changes: 69 additions & 1 deletion distributed/worker_state_machine.py
@@ -3,6 +3,7 @@
import heapq
import sys
from collections.abc import Callable, Container, Iterator
from copy import copy
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Collection # TODO move to collections.abc (requires Python >=3.9)
@@ -353,8 +354,61 @@ class AddKeysMsg(SendMessageToScheduler):

@dataclass
class StateMachineEvent:
__slots__ = ("stimulus_id",)
__slots__ = ("stimulus_id", "handled")
stimulus_id: str
#: timestamp of when the event was handled by the worker
# TODO Switch to @dataclass(slots=True), uncomment the line below, and remove the
# __new__ method (requires Python >=3.10)
# handled: float | None = field(init=False, default=None)
Review comment (Member): I guess this is the reason for the __new__ method containing the self.handled = None assignment.

Reply (Collaborator, author): Yes, clarified in the comment.
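A minimal standalone sketch of the workaround under discussion (my illustration, not part of the diff; the Event class below is hypothetical): a dataclass that declares __slots__ by hand cannot give handled a class-level default, because the default would clash with the slot descriptor of the same name, so the default is assigned in __new__ until @dataclass(slots=True) becomes available on Python >= 3.10.

from dataclasses import dataclass


@dataclass
class Event:
    __slots__ = ("stimulus_id", "handled")
    stimulus_id: str

    def __new__(cls, *args, **kwargs):
        # Assign the default here; a class-level "handled = None" would conflict
        # with the slot of the same name and break the dataclass machinery.
        self = object.__new__(cls)
        self.handled = None
        return self


ev = Event(stimulus_id="test")
assert ev.handled is None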

_classes: ClassVar[dict[str, type[StateMachineEvent]]] = {}

def __new__(cls, *args, **kwargs):
self = object.__new__(cls)
self.handled = None
return self

def __init_subclass__(cls):
StateMachineEvent._classes[cls.__name__] = cls

def to_loggable(self, *, handled: float) -> StateMachineEvent:
"""Produce a variant version of self that is small enough to be stored in memory
in the medium term and contains meaningful information for debugging
"""
self.handled = handled
return self

def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
Review comment (Member): This dictionary conversion seems necessary because stimulus_log: deque[StateMachineEvent] has been added to Worker and thus must be supported by Worker._to_dict. (A round-trip sketch follows from_dict below.)

"""Dictionary representation for debugging purposes.

See also
--------
distributed.utils.recursive_to_dict
"""
info = {
"cls": type(self).__name__,
"stimulus_id": self.stimulus_id,
"handled": self.handled,
}
info.update({k: getattr(self, k) for k in self.__annotations__})
info = {k: v for k, v in info.items() if k not in exclude}
return recursive_to_dict(info, exclude=exclude)

@staticmethod
def from_dict(d: dict) -> StateMachineEvent:
"""Convert the output of ``recursive_to_dict`` back into the original object.
The output object is meaningful for the purpose of rebuilding the state machine,
but not necessarily identical to the original.
"""
kwargs = d.copy()
cls = StateMachineEvent._classes[kwargs.pop("cls")]
handled = kwargs.pop("handled")
inst = cls(**kwargs)
inst.handled = handled
inst._after_from_dict()
return inst
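As referenced in the review comment on _to_dict above, a hedged round-trip sketch (it mirrors test_event_to_dict in this PR): a logged event can be dumped through recursive_to_dict for Worker._to_dict and later rebuilt with from_dict.

from distributed.utils import recursive_to_dict
from distributed.worker_state_machine import RescheduleEvent, StateMachineEvent

ev = RescheduleEvent(stimulus_id="test", key="x")
# to_loggable() stamps the handled timestamp; recursive_to_dict() picks up _to_dict()
d = recursive_to_dict(ev.to_loggable(handled=11.22))
assert StateMachineEvent.from_dict(d) == ev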

def _after_from_dict(self) -> None:
"""Optional post-processing after an instance is created by ``from_dict``"""


@dataclass
@@ -372,6 +426,16 @@ class ExecuteSuccessEvent(StateMachineEvent):
type: type | None
__slots__ = tuple(__annotations__) # type: ignore

def to_loggable(self, *, handled: float) -> StateMachineEvent:
out = copy(self)
out.handled = handled
out.value = None
Review comment (Member): I understand the execution result is discarded because of the potentially large size of the result, and possibly the complexity of serialising/deserialising the result?

Reply (Collaborator, author): Not discarding it would cause worker.stimulus_log to become effectively a copy of worker.data, except that it never loses any data!

Reply (Member): Indeed!

return out
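A hedged sketch of the behaviour discussed in the thread above (it mirrors test_executesuccess_to_dict): to_loggable() returns a copy with value stripped, so stimulus_log never duplicates the potentially large task output kept in Worker.data, while the live event still carries it.

from distributed.worker_state_machine import ExecuteSuccessEvent

ev = ExecuteSuccessEvent(
    stimulus_id="test", key="x", value=123, start=123.4, stop=456.7, nbytes=890, type=int
)
logged = ev.to_loggable(handled=11.22)
assert ev.value == 123       # the event that drives the state machine keeps the result
assert logged.value is None  # the logged copy does not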

def _after_from_dict(self) -> None:
self.value = None
self.type = None
Review comment (Member): I guess the execution result type is discarded here because it's merely a string representation at this point and one would have to deal with serialising/deserialising types.

In any case, I think reconstructing the result of execution is non-trivial. How does this impact the replayability of events on the Worker (out of interest)?

Reply (Collaborator, author): These fields that are discarded on a serialization round-trip should be inconsequential for the purpose of rebuilding the state.
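A hedged follow-up sketch to the exchange above (mirroring the assertions in test_executesuccess_to_dict): after a recursive_to_dict / from_dict round trip the result payload is not reconstructed; value and type simply come back as None.

from distributed.utils import recursive_to_dict
from distributed.worker_state_machine import ExecuteSuccessEvent, StateMachineEvent

ev = ExecuteSuccessEvent(
    stimulus_id="test", key="x", value=123, start=123.4, stop=456.7, nbytes=890, type=int
)
ev2 = StateMachineEvent.from_dict(recursive_to_dict(ev.to_loggable(handled=11.22)))
assert ev2.value is None and ev2.type is None  # key, timings and nbytes survive instead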



@dataclass
class ExecuteFailureEvent(StateMachineEvent):
@@ -384,6 +448,10 @@ class ExecuteFailureEvent(StateMachineEvent):
traceback_text: str
__slots__ = tuple(__annotations__) # type: ignore

def _after_from_dict(self) -> None:
self.exception = Serialize(Exception())
self.traceback = None


@dataclass
class CancelComputeEvent(StateMachineEvent):