diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index b2ca44b4752..42503001dca 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -7,7 +7,7 @@ from enum import Enum from functools import partial from types import ModuleType -from typing import Any, Literal +from typing import Any, Generic, Literal, TypeVar import msgpack @@ -28,6 +28,8 @@ ) from distributed.utils import ensure_memoryview, has_keyword +T = TypeVar("T") + dask_serialize = dask.utils.Dispatch("dask_serialize") dask_deserialize = dask.utils.Dispatch("dask_deserialize") @@ -561,7 +563,7 @@ def __ne__(self, other): return not (self == other) -class ToPickle: +class ToPickle(Generic[T]): """Mark an object that should be pickled Both the scheduler and workers with automatically unpickle this @@ -572,19 +574,18 @@ class ToPickle: to False, the scheduler will raise an exception instead. """ - def __init__(self, data): + data: T + + def __init__(self, data: T): self.data = data - def __repr__(self): + def __repr__(self) -> str: return "" % str(self.data) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and other.data == self.data - def __ne__(self, other): - return not (self == other) - - def __hash__(self): + def __hash__(self) -> int: return hash(self.data)