Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/vector/_backends/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,15 @@ def __eq__(self, other: typing.Any) -> typing.Any:
def __ne__(self, other: typing.Any) -> typing.Any:
return numpy.not_equal(self, other) # type: ignore

def __reduce__(self) -> typing.Union[str, typing.Tuple[typing.Any, ...]]:
pickled_state = super().__reduce__()
new_state = (*pickled_state[2], self.__dict__)
return pickled_state[0], pickled_state[1], new_state

def __setstate__(self, state: typing.Any) -> None:
self.__dict__.update(state[-1])
super().__setstate__(state[0:-1]) # type: ignore
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reminder to myself: we should enable --show-error-codes on mypy and require the error codes here, so we know why these are being ignored. I can do that after the PR.


def __array_ufunc__(
self,
ufunc: typing.Any,
Expand Down Expand Up @@ -665,6 +674,9 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> "VectorNumpy2D":
return array.view(cls)

def __array_finalize__(self, obj: typing.Any) -> None:
if obj is None:
return

if _has(self, ("x", "y")):
self._azimuthal_type = AzimuthalNumpyXY
elif _has(self, ("rho", "phi")):
Expand Down Expand Up @@ -840,6 +852,9 @@ class MomentumNumpy2D(PlanarMomentum, VectorNumpy2D): # type: ignore
dtype: "numpy.dtype[typing.Any]"

def __array_finalize__(self, obj: typing.Any) -> None:
if obj is None:
return

self.dtype.names = tuple(
_repr_momentum_to_generic.get(x, x) for x in (self.dtype.names or ())
)
Expand Down Expand Up @@ -882,6 +897,9 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> "VectorNumpy3D":
return array.view(cls)

def __array_finalize__(self, obj: typing.Any) -> None:
if obj is None:
return

if _has(self, ("x", "y")):
self._azimuthal_type = AzimuthalNumpyXY
elif _has(self, ("rho", "phi")):
Expand Down Expand Up @@ -1076,6 +1094,9 @@ class MomentumNumpy3D(SpatialMomentum, VectorNumpy3D): # type: ignore
dtype: "numpy.dtype[typing.Any]"

def __array_finalize__(self, obj: typing.Any) -> None:
if obj is None:
return

self.dtype.names = tuple(
_repr_momentum_to_generic.get(x, x) for x in (self.dtype.names or ())
)
Expand Down Expand Up @@ -1132,6 +1153,9 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> "VectorNumpy4D":
return array.view(cls)

def __array_finalize__(self, obj: typing.Any) -> None:
if obj is None:
return

if _has(self, ("x", "y")):
self._azimuthal_type = AzimuthalNumpyXY
elif _has(self, ("rho", "phi")):
Expand Down Expand Up @@ -1349,6 +1373,9 @@ class MomentumNumpy4D(LorentzMomentum, VectorNumpy4D): # type: ignore
dtype: "numpy.dtype[typing.Any]"

def __array_finalize__(self, obj: typing.Any) -> None:
if obj is None:
return

self.dtype.names = tuple(
_repr_momentum_to_generic.get(x, x) for x in (self.dtype.names or ())
)
Expand Down
97 changes: 97 additions & 0 deletions tests/backends/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# or https://github.com/scikit-hep/vector for details.

import math
import pickle

import numpy

Expand All @@ -29,3 +30,99 @@ def test_rhophi():
assert numpy.allclose(array.y, [0, 1, 4])
assert numpy.allclose(array.rho, [0, 1, 5])
assert numpy.allclose(array.phi, [10, math.atan2(1, 0), math.atan2(4, 3)])


def test_pickle_vector_numpy_2d():
array = vector._backends.numpy_.VectorNumpy2D(
[(0, 0), (0, 1), (3, 4)], dtype=[("x", numpy.float64), ("y", numpy.float64)]
)

array_pickled = pickle.dumps(array)
array_new = pickle.loads(array_pickled)

assert numpy.allclose(array_new.x, array.x)
assert numpy.allclose(array_new.y, array.y)


def test_pickle_momentum_numpy_2d():
array = vector._backends.numpy_.MomentumNumpy2D(
[(0, 0), (0, 1), (3, 4)], dtype=[("rho", numpy.float64), ("phi", numpy.float64)]
)

array_pickled = pickle.dumps(array)
array_new = pickle.loads(array_pickled)

assert numpy.allclose(array_new.rho, array.rho)
assert numpy.allclose(array_new.phi, array.phi)


def test_pickle_vector_numpy_3d():
array = vector._backends.numpy_.VectorNumpy3D(
[(0, 0, 0), (0, 1, 1), (3, 4, 5)],
dtype=[("x", numpy.float64), ("y", numpy.float64), ("z", numpy.float64)],
)

array_pickled = pickle.dumps(array)
array_new = pickle.loads(array_pickled)

assert numpy.allclose(array_new.x, array.x)
assert numpy.allclose(array_new.y, array.y)
assert numpy.allclose(array_new.z, array.z)


def test_pickle_momentum_numpy_3d():
array = vector._backends.numpy_.MomentumNumpy3D(
[(0, 0, 0), (0, 1, 1), (3, 4, 5)],
dtype=[
("rho", numpy.float64),
("phi", numpy.float64),
("theta", numpy.float64),
],
)

array_pickled = pickle.dumps(array)
array_new = pickle.loads(array_pickled)

assert numpy.allclose(array_new.rho, array.rho)
assert numpy.allclose(array_new.phi, array.phi)
assert numpy.allclose(array_new.theta, array.theta)


def test_pickle_vector_numpy_4d():
array = vector._backends.numpy_.VectorNumpy4D(
[(0, 0, 0, 0), (0, 1, 1, 1), (3, 4, 5, 6)],
dtype=[
("x", numpy.float64),
("y", numpy.float64),
("z", numpy.float64),
("t", numpy.float64),
],
)

array_pickled = pickle.dumps(array)
array_new = pickle.loads(array_pickled)

assert numpy.allclose(array_new.x, array.x)
assert numpy.allclose(array_new.y, array.y)
assert numpy.allclose(array_new.z, array.z)
assert numpy.allclose(array_new.t, array.t)


def test_pickle_momentum_numpy_4d():
array = vector._backends.numpy_.MomentumNumpy4D(
[(0, 0, 0, 0), (0, 1, 1, 1), (3, 4, 5, 6)],
dtype=[
("rho", numpy.float64),
("phi", numpy.float64),
("theta", numpy.float64),
("tau", numpy.float64),
],
)

array_pickled = pickle.dumps(array)
array_new = pickle.loads(array_pickled)

assert numpy.allclose(array_new.rho, array.rho)
assert numpy.allclose(array_new.phi, array.phi)
assert numpy.allclose(array_new.theta, array.theta)
assert numpy.allclose(array_new.tau, array.tau)