Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions src/vector/_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# or https://github.com/scikit-hep/vector for details.

import typing
from contextlib import suppress

from vector._typeutils import (
BoolCollection,
Expand Down Expand Up @@ -2533,6 +2534,16 @@ def _from_signature(
]


def _get_handler_index(obj: VectorProtocol) -> int:
"""Returns the index of the first valid handler checking the list of parent classes"""
for cls in type(obj).__mro__:
with suppress(ValueError):
return _handler_priority.index(cls.__module__)
raise AssertionError(
f"Could not find a valid handler for {obj}! This should not happen."
)


def _handler_of(*objects: VectorProtocol) -> VectorProtocol:
"""
Determines which vector should wrap the output of a dispatched function.
Expand All @@ -2544,13 +2555,12 @@ def _handler_of(*objects: VectorProtocol) -> VectorProtocol:
"""
handler = None
for obj in objects:
if isinstance(obj, Vector):
if handler is None:
handler = obj
elif _handler_priority.index(
type(obj).__module__
) > _handler_priority.index(type(handler).__module__):
handler = obj
if not isinstance(obj, Vector):
continue
Comment thread
jpivarski marked this conversation as resolved.
if handler is None:
handler = obj
elif _get_handler_index(obj) > _get_handler_index(handler):
handler = obj

assert handler is not None
return handler
Expand Down
17 changes: 17 additions & 0 deletions tests/test_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2019-2021, Jonas Eschle, Jim Pivarski, Eduardo Rodrigues, and Henry Schreiner.
#
# Distributed under the 3-clause BSD license, see accompanying file LICENSE
# or https://github.com/scikit-hep/vector for details.

import vector


class CustomVector(vector.VectorObject4D):
pass


def test_handler_of():
object_a = CustomVector.from_xyzt(0.0, 0.0, 0.0, 0.0)
object_b = CustomVector.from_xyzt(1.0, 1.0, 1.0, 1.0)
protocol = vector._methods._handler_of(object_a, object_b)
assert protocol == object_a