Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 102 additions & 6 deletions src/vector/_backends/awkward_.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,34 @@ def from_fields(cls, array: ak.Array) -> "AzimuthalAwkward":
return AzimuthalAwkwardRhoPhi(array["rho"], array["phi"])
else:
raise ValueError(
f"array does not have azimuthal coordinates (x/y/rho/phi): {', '.join(fields)}"
"array does not have azimuthal coordinates (x, y or rho, phi): "
f"{', '.join(fields)}"
)

@classmethod
def from_momentum_fields(cls, array: ak.Array) -> "AzimuthalAwkward":
"""
Create a :doc:`vector._backends.awkward_.AzimuthalAwkwardXY` or a
:doc:`vector._backends.awkward_.AzimuthalAwkwardRhoPhi`, depending on
the fields in ``array``, allowing momentum synonyms.
"""
fields = ak.fields(array)
if "x" in fields and "y" in fields:
return AzimuthalAwkwardXY(array["x"], array["y"])
elif "x" in fields and "py" in fields:
return AzimuthalAwkwardXY(array["x"], array["py"])
elif "px" in fields and "y" in fields:
return AzimuthalAwkwardXY(array["px"], array["y"])
elif "px" in fields and "py" in fields:
return AzimuthalAwkwardXY(array["px"], array["py"])
elif "rho" in fields and "phi" in fields:
return AzimuthalAwkwardRhoPhi(array["rho"], array["phi"])
elif "pt" in fields and "phi" in fields:
return AzimuthalAwkwardRhoPhi(array["pt"], array["phi"])
else:
raise ValueError(
"array does not have azimuthal coordinates (x/px, y/py or rho/pt, phi): "
f"{', '.join(fields)}"
)


Expand All @@ -123,7 +150,31 @@ def from_fields(cls, array: ak.Array) -> "LongitudinalAwkward":
return LongitudinalAwkwardEta(array["eta"])
else:
raise ValueError(
f"array does not have longitudinal coordinates (z/theta/eta): {', '.join(fields)}"
"array does not have longitudinal coordinates (z or theta or eta): "
f"{', '.join(fields)}"
)

@classmethod
def from_momentum_fields(cls, array: ak.Array) -> "LongitudinalAwkward":
"""
Create a :doc:`vector._backends.awkward_.LongitudinalAwkwardZ`, a
:doc:`vector._backends.awkward_.LongitudinalAwkwardTheta`, or a
:doc:`vector._backends.awkward_.LongitudinalAwkwardEta`, depending on
the fields in ``array``, allowing momentum synonyms.
"""
fields = ak.fields(array)
if "z" in fields:
return LongitudinalAwkwardZ(array["z"])
elif "pz" in fields:
return LongitudinalAwkwardZ(array["pz"])
elif "theta" in fields:
return LongitudinalAwkwardTheta(array["theta"])
elif "eta" in fields:
return LongitudinalAwkwardEta(array["eta"])
else:
raise ValueError(
"array does not have longitudinal coordinates (z/pz or theta or eta): "
f"{', '.join(fields)}"
)


Expand All @@ -142,7 +193,34 @@ def from_fields(cls, array: ak.Array) -> "TemporalAwkward":
return TemporalAwkwardTau(array["tau"])
else:
raise ValueError(
f"array does not have temporal coordinates (t/tau): {', '.join(fields)}"
"array does not have temporal coordinates (t or tau): "
f"{', '.join(fields)}"
)

@classmethod
def from_momentum_fields(cls, array: ak.Array) -> "TemporalAwkward":
"""
Create a :doc:`vector._backends.awkward_.TemporalT` or a
:doc:`vector._backends.awkward_.TemporalTau`, depending on
the fields in ``array``, allowing momentum synonyms.
"""
fields = ak.fields(array)
if "t" in fields:
return TemporalAwkwardT(array["t"])
elif "E" in fields:
return TemporalAwkwardT(array["E"])
elif "energy" in fields:
return TemporalAwkwardT(array["energy"])
elif "tau" in fields:
return TemporalAwkwardTau(array["tau"])
elif "M" in fields:
return TemporalAwkwardTau(array["M"])
elif "mass" in fields:
return TemporalAwkwardTau(array["mass"])
else:
raise ValueError(
"array does not have temporal coordinates (t/E/energy or tau/M/mass): "
f"{', '.join(fields)}"
)


Expand Down Expand Up @@ -544,7 +622,9 @@ def azimuthal(self) -> AzimuthalAwkward:


class MomentumAwkward2D(PlanarMomentum, VectorAwkward2D):
pass
@property
def azimuthal(self) -> AzimuthalAwkward:
return AzimuthalAwkward.from_momentum_fields(self)


class VectorAwkward3D(VectorAwkward, Spatial, Vector3D):
Expand All @@ -558,7 +638,13 @@ def longitudinal(self) -> LongitudinalAwkward:


class MomentumAwkward3D(SpatialMomentum, VectorAwkward3D):
pass
@property
def azimuthal(self) -> AzimuthalAwkward:
return AzimuthalAwkward.from_momentum_fields(self)

@property
def longitudinal(self) -> LongitudinalAwkward:
return LongitudinalAwkward.from_momentum_fields(self)


class VectorAwkward4D(VectorAwkward, Lorentz, Vector4D):
Expand All @@ -576,7 +662,17 @@ def temporal(self) -> TemporalAwkward:


class MomentumAwkward4D(LorentzMomentum, VectorAwkward4D):
pass
@property
def azimuthal(self) -> AzimuthalAwkward:
return AzimuthalAwkward.from_momentum_fields(self)

@property
def longitudinal(self) -> LongitudinalAwkward:
return LongitudinalAwkward.from_momentum_fields(self)

@property
def temporal(self) -> TemporalAwkward:
return TemporalAwkward.from_momentum_fields(self)


# ak.Array and ak.Record subclasses ###########################################
Expand Down