diff --git a/src/pals/kinds/ACKicker.py b/src/pals/kinds/ACKicker.py new file mode 100644 index 0000000..4c6628e --- /dev/null +++ b/src/pals/kinds/ACKicker.py @@ -0,0 +1,14 @@ +from typing import Literal + +from .ThickElement import ThickElement +from ._warnings import under_construction + + +@under_construction("ACKicker") +class ACKicker(ThickElement): + """Time varying kicker element""" + + # Discriminator field + kind: Literal["ACKicker"] = "ACKicker" + + # Note: ACKickerP parameter group not yet implemented diff --git a/src/pals/kinds/BaseElement.py b/src/pals/kinds/BaseElement.py index 577e619..d12799e 100644 --- a/src/pals/kinds/BaseElement.py +++ b/src/pals/kinds/BaseElement.py @@ -1,5 +1,15 @@ from pydantic import BaseModel, ConfigDict -from typing import Literal +from typing import Literal, Optional + +from ..parameters import ( + ApertureParameters, + BodyShiftParameters, + FloorParameters, + MetaParameters, + ReferenceParameters, + ReferenceChangeParameters, + TrackingParameters, +) class BaseElement(BaseModel): @@ -15,8 +25,19 @@ class BaseElement(BaseModel): # element name name: str + # Common parameter groups (optional for all elements) + ApertureP: Optional[ApertureParameters] = None + BodyShiftP: Optional[BodyShiftParameters] = None + FloorP: Optional[FloorParameters] = None + MetaP: Optional[MetaParameters] = None + ReferenceP: Optional[ReferenceParameters] = None + ReferenceChangeP: Optional[ReferenceChangeParameters] = None + TrackingP: Optional[TrackingParameters] = None + def model_dump(self, *args, **kwargs): """This makes sure the element name property is moved out and up to a one-key dictionary""" + # Exclude None values from serialization + kwargs.setdefault("exclude_none", True) elem_dict = super().model_dump(*args, **kwargs) name = elem_dict.pop("name", None) if name is None: diff --git a/src/pals/kinds/BeamBeam.py b/src/pals/kinds/BeamBeam.py new file mode 100644 index 0000000..048ac29 --- /dev/null +++ b/src/pals/kinds/BeamBeam.py @@ -0,0 +1,16 @@ +from typing import Literal, Optional + +from .BaseElement import BaseElement +from ..parameters import BeamBeamParameters +from ._warnings import under_construction + + +@under_construction("BeamBeam") +class BeamBeam(BaseElement): + """Element for simulating colliding beams""" + + # Discriminator field + kind: Literal["BeamBeam"] = "BeamBeam" + + # Beam-beam-specific parameters + BeamBeamP: Optional[BeamBeamParameters] = None diff --git a/src/pals/kinds/BeamLine.py b/src/pals/kinds/BeamLine.py index 433df35..8d87e48 100644 --- a/src/pals/kinds/BeamLine.py +++ b/src/pals/kinds/BeamLine.py @@ -3,8 +3,38 @@ from .BaseElement import BaseElement from .ThickElement import ThickElement + +from .ACKicker import ACKicker +from .BeamBeam import BeamBeam +from .BeginningEle import BeginningEle +from .Converter import Converter +from .CrabCavity import CrabCavity from .Drift import Drift +from .EGun import EGun +from .Feedback import Feedback +from .Fiducial import Fiducial +from .FloorShift import FloorShift +from .Foil import Foil +from .Fork import Fork +from .Girder import Girder +from .Instrument import Instrument +from .Kicker import Kicker +from .Marker import Marker +from .Mask import Mask +from .Match import Match +from .Multipole import Multipole +from .NullEle import NullEle +from .Octupole import Octupole +from .Patch import Patch from .Quadrupole import Quadrupole +from .RBend import RBend +from .RFCavity import RFCavity +from .SBend import SBend +from .Sextupole import Sextupole +from .Solenoid import Solenoid +from .Taylor import Taylor +from .UnionEle import UnionEle +from .Wiggler import Wiggler class BeamLine(BaseElement): @@ -19,11 +49,42 @@ class BeamLine(BaseElement): line: List[ Annotated[ Union[ + # Base classes (for testing compatibility) BaseElement, ThickElement, + # User-Facing element kinds + "BeamLine", + ACKicker, + BeamBeam, + BeginningEle, + Converter, + CrabCavity, Drift, + EGun, + Feedback, + Fiducial, + FloorShift, + Foil, + Fork, + Girder, + Instrument, + Kicker, + Marker, + Mask, + Match, + Multipole, + NullEle, + Octupole, + Patch, Quadrupole, - "BeamLine", + RBend, + RFCavity, + SBend, + Sextupole, + Solenoid, + Taylor, + UnionEle, + Wiggler, ], Field(discriminator="kind"), ] diff --git a/src/pals/kinds/BeginningEle.py b/src/pals/kinds/BeginningEle.py new file mode 100644 index 0000000..0af00a7 --- /dev/null +++ b/src/pals/kinds/BeginningEle.py @@ -0,0 +1,12 @@ +from typing import Literal + +from .BaseElement import BaseElement +from ._warnings import under_construction + + +@under_construction("BeginningEle") +class BeginningEle(BaseElement): + """Initial element at start of a branch""" + + # Discriminator field + kind: Literal["BeginningEle"] = "BeginningEle" diff --git a/src/pals/kinds/Converter.py b/src/pals/kinds/Converter.py new file mode 100644 index 0000000..cdeee8d --- /dev/null +++ b/src/pals/kinds/Converter.py @@ -0,0 +1,17 @@ +from typing import Literal, Optional + +from .BaseElement import BaseElement +from ..parameters import ElectricMultipoleParameters, MagneticMultipoleParameters +from ._warnings import under_construction + + +@under_construction("Converter") +class Converter(BaseElement): + """Target to produce new species. EG: Positron converter""" + + # Discriminator field + kind: Literal["Converter"] = "Converter" + + # Converter-specific parameters + ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None + MagneticMultipoleP: Optional[MagneticMultipoleParameters] = None diff --git a/src/pals/kinds/CrabCavity.py b/src/pals/kinds/CrabCavity.py new file mode 100644 index 0000000..70cc334 --- /dev/null +++ b/src/pals/kinds/CrabCavity.py @@ -0,0 +1,17 @@ +from typing import Literal, Optional + +from .ThickElement import ThickElement +from ..parameters import ElectricMultipoleParameters, MagneticMultipoleParameters +from ._warnings import under_construction + + +@under_construction("CrabCavity") +class CrabCavity(ThickElement): + """RF crab cavity""" + + # Discriminator field + kind: Literal["CrabCavity"] = "CrabCavity" + + # CrabCavity-specific parameters + ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None + MagneticMultipoleP: Optional[MagneticMultipoleParameters] = None diff --git a/src/pals/kinds/Drift.py b/src/pals/kinds/Drift.py index 18bcfa9..e38c6f5 100644 --- a/src/pals/kinds/Drift.py +++ b/src/pals/kinds/Drift.py @@ -4,7 +4,7 @@ class Drift(ThickElement): - """A field free region""" + """Field free region""" # Discriminator field kind: Literal["Drift"] = "Drift" diff --git a/src/pals/kinds/EGun.py b/src/pals/kinds/EGun.py new file mode 100644 index 0000000..5dd719e --- /dev/null +++ b/src/pals/kinds/EGun.py @@ -0,0 +1,17 @@ +from typing import Literal, Optional + +from .ThickElement import ThickElement +from ..parameters import ElectricMultipoleParameters, MagneticMultipoleParameters +from ._warnings import under_construction + + +@under_construction("EGun") +class EGun(ThickElement): + """Electron gun""" + + # Discriminator field + kind: Literal["EGun"] = "EGun" + + # EGun-specific parameters + ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None + MagneticMultipoleP: Optional[MagneticMultipoleParameters] = None diff --git a/src/pals/kinds/Feedback.py b/src/pals/kinds/Feedback.py new file mode 100644 index 0000000..e6db64a --- /dev/null +++ b/src/pals/kinds/Feedback.py @@ -0,0 +1,12 @@ +from typing import Literal + +from .BaseElement import BaseElement +from ._warnings import under_construction + + +@under_construction("Feedback") +class Feedback(BaseElement): + """Element used to simulate a feedback circuit""" + + # Discriminator field + kind: Literal["Feedback"] = "Feedback" diff --git a/src/pals/kinds/Fiducial.py b/src/pals/kinds/Fiducial.py new file mode 100644 index 0000000..cd107e2 --- /dev/null +++ b/src/pals/kinds/Fiducial.py @@ -0,0 +1,12 @@ +from typing import Literal + +from .BaseElement import BaseElement +from ._warnings import under_construction + + +@under_construction("Fiducial") +class Fiducial(BaseElement): + """Global coordinate system fiducial point""" + + # Discriminator field + kind: Literal["Fiducial"] = "Fiducial" diff --git a/src/pals/kinds/FloorShift.py b/src/pals/kinds/FloorShift.py new file mode 100644 index 0000000..fafac44 --- /dev/null +++ b/src/pals/kinds/FloorShift.py @@ -0,0 +1,14 @@ +from typing import Literal, Optional + +from .BaseElement import BaseElement +from ..parameters import FloorShiftParameters + + +class FloorShift(BaseElement): + """Global coordinates shift element""" + + # Discriminator field + kind: Literal["FloorShift"] = "FloorShift" + + # Floor shift-specific parameters + FloorShiftP: Optional[FloorShiftParameters] = None diff --git a/src/pals/kinds/Foil.py b/src/pals/kinds/Foil.py new file mode 100644 index 0000000..56fc756 --- /dev/null +++ b/src/pals/kinds/Foil.py @@ -0,0 +1,12 @@ +from typing import Literal + +from .BaseElement import BaseElement +from ._warnings import under_construction + + +@under_construction("Foil") +class Foil(BaseElement): + """Material that can strip electrons from a particle. Will also cause energy loss and diffusion""" + + # Discriminator field + kind: Literal["Foil"] = "Foil" diff --git a/src/pals/kinds/Fork.py b/src/pals/kinds/Fork.py new file mode 100644 index 0000000..e81f48c --- /dev/null +++ b/src/pals/kinds/Fork.py @@ -0,0 +1,16 @@ +from typing import Literal, Optional + +from .BaseElement import BaseElement +from ..parameters import ForkParameters +from ._warnings import under_construction + + +@under_construction("Fork") +class Fork(BaseElement): + """Element used to connect lattice branches together""" + + # Discriminator field + kind: Literal["Fork"] = "Fork" + + # Fork-specific parameters + ForkP: Optional[ForkParameters] = None diff --git a/src/pals/kinds/Girder.py b/src/pals/kinds/Girder.py new file mode 100644 index 0000000..8b3ea91 --- /dev/null +++ b/src/pals/kinds/Girder.py @@ -0,0 +1,12 @@ +from typing import Literal + +from .BaseElement import BaseElement +from ._warnings import under_construction + + +@under_construction("Girder") +class Girder(BaseElement): + """Element to support in space a group of other elements""" + + # Discriminator field + kind: Literal["Girder"] = "Girder" diff --git a/src/pals/kinds/Instrument.py b/src/pals/kinds/Instrument.py new file mode 100644 index 0000000..d7ac343 --- /dev/null +++ b/src/pals/kinds/Instrument.py @@ -0,0 +1,17 @@ +from typing import Literal, Optional + +from .ThickElement import ThickElement +from ..parameters import ElectricMultipoleParameters, MagneticMultipoleParameters +from ._warnings import under_construction + + +@under_construction("Instrument") +class Instrument(ThickElement): + """Measurement element""" + + # Discriminator field + kind: Literal["Instrument"] = "Instrument" + + # Instrument-specific parameters + ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None + MagneticMultipoleP: Optional[MagneticMultipoleParameters] = None diff --git a/src/pals/kinds/Kicker.py b/src/pals/kinds/Kicker.py new file mode 100644 index 0000000..5ee5f3a --- /dev/null +++ b/src/pals/kinds/Kicker.py @@ -0,0 +1,17 @@ +from typing import Literal, Optional + +from .ThickElement import ThickElement +from ..parameters import ElectricMultipoleParameters, MagneticMultipoleParameters +from ._warnings import under_construction + + +@under_construction("Kicker") +class Kicker(ThickElement): + """Particle kicker element""" + + # Discriminator field + kind: Literal["Kicker"] = "Kicker" + + # Kicker-specific parameters + ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None + MagneticMultipoleP: Optional[MagneticMultipoleParameters] = None diff --git a/src/pals/kinds/Marker.py b/src/pals/kinds/Marker.py new file mode 100644 index 0000000..b3c9fcd --- /dev/null +++ b/src/pals/kinds/Marker.py @@ -0,0 +1,12 @@ +from typing import Literal + +from .BaseElement import BaseElement +from ._warnings import under_construction + + +@under_construction("Marker") +class Marker(BaseElement): + """Zero length element to mark a particular position""" + + # Discriminator field + kind: Literal["Marker"] = "Marker" diff --git a/src/pals/kinds/Mask.py b/src/pals/kinds/Mask.py new file mode 100644 index 0000000..4222682 --- /dev/null +++ b/src/pals/kinds/Mask.py @@ -0,0 +1,17 @@ +from typing import Literal, Optional + +from .ThickElement import ThickElement +from ..parameters import ElectricMultipoleParameters, MagneticMultipoleParameters +from ._warnings import under_construction + + +@under_construction("Mask") +class Mask(ThickElement): + """Collimation element""" + + # Discriminator field + kind: Literal["Mask"] = "Mask" + + # Mask-specific parameters + ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None + MagneticMultipoleP: Optional[MagneticMultipoleParameters] = None diff --git a/src/pals/kinds/Match.py b/src/pals/kinds/Match.py new file mode 100644 index 0000000..c3a1b8c --- /dev/null +++ b/src/pals/kinds/Match.py @@ -0,0 +1,12 @@ +from typing import Literal + +from .BaseElement import BaseElement +from ._warnings import under_construction + + +@under_construction("Match") +class Match(BaseElement): + """Orbit, Twiss, and dispersion matching element""" + + # Discriminator field + kind: Literal["Match"] = "Match" diff --git a/src/pals/kinds/Multipole.py b/src/pals/kinds/Multipole.py new file mode 100644 index 0000000..1fac2f6 --- /dev/null +++ b/src/pals/kinds/Multipole.py @@ -0,0 +1,17 @@ +from typing import Literal, Optional + +from .ThickElement import ThickElement +from ..parameters import ElectricMultipoleParameters, MagneticMultipoleParameters +from ._warnings import under_construction + + +@under_construction("Multipole") +class Multipole(ThickElement): + """Multipole element""" + + # Discriminator field + kind: Literal["Multipole"] = "Multipole" + + # Multipole-specific parameters + ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None + MagneticMultipoleP: Optional[MagneticMultipoleParameters] = None diff --git a/src/pals/kinds/NullEle.py b/src/pals/kinds/NullEle.py new file mode 100644 index 0000000..7a34480 --- /dev/null +++ b/src/pals/kinds/NullEle.py @@ -0,0 +1,12 @@ +from typing import Literal + +from .BaseElement import BaseElement +from ._warnings import under_construction + + +@under_construction("NullEle") +class NullEle(BaseElement): + """Placeholder element used for bookkeeping""" + + # Discriminator field + kind: Literal["NullEle"] = "NullEle" diff --git a/src/pals/kinds/Octupole.py b/src/pals/kinds/Octupole.py new file mode 100644 index 0000000..3126d6d --- /dev/null +++ b/src/pals/kinds/Octupole.py @@ -0,0 +1,17 @@ +from typing import Literal, Optional + +from .ThickElement import ThickElement +from ..parameters import ElectricMultipoleParameters, MagneticMultipoleParameters +from ._warnings import under_construction + + +@under_construction("Octupole") +class Octupole(ThickElement): + """Octupole element""" + + # Discriminator field + kind: Literal["Octupole"] = "Octupole" + + # Octupole-specific parameters + ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None + MagneticMultipoleP: Optional[MagneticMultipoleParameters] = None diff --git a/src/pals/kinds/Patch.py b/src/pals/kinds/Patch.py new file mode 100644 index 0000000..900a016 --- /dev/null +++ b/src/pals/kinds/Patch.py @@ -0,0 +1,16 @@ +from typing import Literal, Optional + +from .ThickElement import ThickElement +from ..parameters import PatchParameters +from ._warnings import under_construction + + +@under_construction("Patch") +class Patch(ThickElement): + """Crooked drift used to shift the reference curve""" + + # Discriminator field + kind: Literal["Patch"] = "Patch" + + # Patch-specific parameters + PatchP: Optional[PatchParameters] = None diff --git a/src/pals/kinds/Quadrupole.py b/src/pals/kinds/Quadrupole.py index 5702369..a83b671 100644 --- a/src/pals/kinds/Quadrupole.py +++ b/src/pals/kinds/Quadrupole.py @@ -1,14 +1,15 @@ -from typing import Literal +from typing import Literal, Optional from .ThickElement import ThickElement -from ..parameters import MagneticMultipoleParameters +from ..parameters import MagneticMultipoleParameters, ElectricMultipoleParameters class Quadrupole(ThickElement): - """A quadrupole element""" + """Quadrupole element""" # Discriminator field kind: Literal["Quadrupole"] = "Quadrupole" - # Magnetic multipole parameters + # Octupole-specific parameters + ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None MagneticMultipoleP: MagneticMultipoleParameters diff --git a/src/pals/kinds/RBend.py b/src/pals/kinds/RBend.py new file mode 100644 index 0000000..b017119 --- /dev/null +++ b/src/pals/kinds/RBend.py @@ -0,0 +1,21 @@ +from typing import Literal, Optional + +from .ThickElement import ThickElement +from ..parameters import ( + BendParameters, + ElectricMultipoleParameters, + MagneticMultipoleParameters, +) + + +class RBend(ThickElement): + """A rectangular bend element""" + + # Discriminator field + kind: Literal["RBend"] = "RBend" + + # Bend-specific parameters + BendP: Optional[BendParameters] = None + + ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None + MagneticMultipoleP: Optional[MagneticMultipoleParameters] = None diff --git a/src/pals/kinds/RFCavity.py b/src/pals/kinds/RFCavity.py new file mode 100644 index 0000000..76a21f2 --- /dev/null +++ b/src/pals/kinds/RFCavity.py @@ -0,0 +1,25 @@ +from typing import Literal, Optional + +from .ThickElement import ThickElement +from ..parameters import ( + RFParameters, + SolenoidParameters, + ElectricMultipoleParameters, + MagneticMultipoleParameters, +) +from ._warnings import under_construction + + +@under_construction("RFCavity") +class RFCavity(ThickElement): + """RF cavity element""" + + # Discriminator field + kind: Literal["RFCavity"] = "RFCavity" + + # RF-specific parameters + RFP: Optional[RFParameters] = None + SolenoidP: Optional[SolenoidParameters] = None + + ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None + MagneticMultipoleP: Optional[MagneticMultipoleParameters] = None diff --git a/src/pals/kinds/SBend.py b/src/pals/kinds/SBend.py new file mode 100644 index 0000000..98a9eb6 --- /dev/null +++ b/src/pals/kinds/SBend.py @@ -0,0 +1,21 @@ +from typing import Literal, Optional + +from .ThickElement import ThickElement +from ..parameters import ( + BendParameters, + ElectricMultipoleParameters, + MagneticMultipoleParameters, +) + + +class SBend(ThickElement): + """A sector bend element""" + + # Discriminator field + kind: Literal["SBend"] = "SBend" + + # Bend-specific parameters + BendP: Optional[BendParameters] = None + + ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None + MagneticMultipoleP: Optional[MagneticMultipoleParameters] = None diff --git a/src/pals/kinds/Sextupole.py b/src/pals/kinds/Sextupole.py new file mode 100644 index 0000000..a09c43b --- /dev/null +++ b/src/pals/kinds/Sextupole.py @@ -0,0 +1,17 @@ +from typing import Literal, Optional + +from .ThickElement import ThickElement +from ..parameters import ElectricMultipoleParameters, MagneticMultipoleParameters +from ._warnings import under_construction + + +@under_construction("Sextupole") +class Sextupole(ThickElement): + """Sextupole element""" + + # Discriminator field + kind: Literal["Sextupole"] = "Sextupole" + + # Sextupole-specific parameters + ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None + MagneticMultipoleP: Optional[MagneticMultipoleParameters] = None diff --git a/src/pals/kinds/Solenoid.py b/src/pals/kinds/Solenoid.py new file mode 100644 index 0000000..939db64 --- /dev/null +++ b/src/pals/kinds/Solenoid.py @@ -0,0 +1,23 @@ +from typing import Literal, Optional + +from .ThickElement import ThickElement +from ..parameters import ( + SolenoidParameters, + ElectricMultipoleParameters, + MagneticMultipoleParameters, +) +from ._warnings import under_construction + + +@under_construction("Solenoid") +class Solenoid(ThickElement): + """Solenoid element""" + + # Discriminator field + kind: Literal["Solenoid"] = "Solenoid" + + # Solenoid-specific parameters + SolenoidP: Optional[SolenoidParameters] = None + + ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None + MagneticMultipoleP: Optional[MagneticMultipoleParameters] = None diff --git a/src/pals/kinds/Taylor.py b/src/pals/kinds/Taylor.py new file mode 100644 index 0000000..77f4d12 --- /dev/null +++ b/src/pals/kinds/Taylor.py @@ -0,0 +1,12 @@ +from typing import Literal + +from .BaseElement import BaseElement +from ._warnings import under_construction + + +@under_construction("Taylor") +class Taylor(BaseElement): + """Taylor map element""" + + # Discriminator field + kind: Literal["Taylor"] = "Taylor" diff --git a/src/pals/kinds/UnionEle.py b/src/pals/kinds/UnionEle.py new file mode 100644 index 0000000..1bfa0d0 --- /dev/null +++ b/src/pals/kinds/UnionEle.py @@ -0,0 +1,14 @@ +from typing import List, Literal + +from .BaseElement import BaseElement + + +class UnionEle(BaseElement): + """Union element for overlapping elements""" + + # Discriminator field + kind: Literal["UnionEle"] = "UnionEle" + + # Elements in the union + # Note: https://github.com/campa-consortium/pals/issues/89 + elements: List[BaseElement] = [] diff --git a/src/pals/kinds/Wiggler.py b/src/pals/kinds/Wiggler.py new file mode 100644 index 0000000..a9e1510 --- /dev/null +++ b/src/pals/kinds/Wiggler.py @@ -0,0 +1,17 @@ +from typing import Literal, Optional + +from .ThickElement import ThickElement +from ..parameters import ElectricMultipoleParameters, MagneticMultipoleParameters +from ._warnings import under_construction + + +@under_construction("Wiggler") +class Wiggler(ThickElement): + """Wiggler element""" + + # Discriminator field + kind: Literal["Wiggler"] = "Wiggler" + + # Wiggler-specific parameters + ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None + MagneticMultipoleP: Optional[MagneticMultipoleParameters] = None diff --git a/src/pals/kinds/__init__.py b/src/pals/kinds/__init__.py index 8b8b26b..0ff0f75 100644 --- a/src/pals/kinds/__init__.py +++ b/src/pals/kinds/__init__.py @@ -2,8 +2,37 @@ simpler import statements like `from pals import Drift`. """ +from .ACKicker import ACKicker # noqa: F401 from .BaseElement import BaseElement # noqa: F401 +from .BeamBeam import BeamBeam # noqa: F401 from .BeamLine import BeamLine # noqa: F401 +from .BeginningEle import BeginningEle # noqa: F401 +from .Converter import Converter # noqa: F401 +from .CrabCavity import CrabCavity # noqa: F401 from .Drift import Drift # noqa: F401 +from .EGun import EGun # noqa: F401 +from .Feedback import Feedback # noqa: F401 +from .Fiducial import Fiducial # noqa: F401 +from .FloorShift import FloorShift # noqa: F401 +from .Foil import Foil # noqa: F401 +from .Fork import Fork # noqa: F401 +from .Girder import Girder # noqa: F401 +from .Instrument import Instrument # noqa: F401 +from .Kicker import Kicker # noqa: F401 +from .Marker import Marker # noqa: F401 +from .Mask import Mask # noqa: F401 +from .Match import Match # noqa: F401 +from .Multipole import Multipole # noqa: F401 +from .NullEle import NullEle # noqa: F401 +from .Octupole import Octupole # noqa: F401 +from .Patch import Patch # noqa: F401 from .Quadrupole import Quadrupole # noqa: F401 +from .RBend import RBend # noqa: F401 +from .RFCavity import RFCavity # noqa: F401 +from .SBend import SBend # noqa: F401 +from .Sextupole import Sextupole # noqa: F401 +from .Solenoid import Solenoid # noqa: F401 +from .Taylor import Taylor # noqa: F401 from .ThickElement import ThickElement # noqa: F401 +from .UnionEle import UnionEle # noqa: F401 +from .Wiggler import Wiggler # noqa: F401 diff --git a/src/pals/kinds/_warnings.py b/src/pals/kinds/_warnings.py new file mode 100644 index 0000000..056f20d --- /dev/null +++ b/src/pals/kinds/_warnings.py @@ -0,0 +1,63 @@ +""" +Utility module for handling warnings for under construction elements. +""" + +import warnings +from typing import TypeVar + +T = TypeVar("T", bound=type) + + +def under_construction(element_name: str = None): + """ + Compact decorator to mark an element as under construction. + + Usage: + @under_construction("ElementName") + class MyElement(BaseElement): + pass + + Args: + element_name: Optional custom name for the element. If not provided, + uses the class name. + """ + + def decorator(cls: T) -> T: + # Store original __init__ method + original_init = cls.__init__ + + def new_init(self, *args, **kwargs): + # Call original __init__ first + original_init(self, *args, **kwargs) + # Issue warning after initialization + name = element_name or cls.__name__ + warnings.warn( + f"The {name} element is marked as 'Under Construction' in the PALS standard. " + f"Please refer to the PALS documentation for current status and limitations.", + UserWarning, + stacklevel=3, + ) + + # Replace __init__ method + cls.__init__ = new_init + + # Add warning to class docstring if not already present + if ( + not hasattr(cls, "__doc__") + or not cls.__doc__ + or "UNDER CONSTRUCTION" not in cls.__doc__.upper() + ): + original_doc = cls.__doc__ or "" + warning_doc = f""" +**UNDER CONSTRUCTION**: This element is marked as 'Under Construction' in the PALS standard. + +{original_doc.strip()} + +**Warning**: This element implementation may be incomplete or subject to change. +Please refer to the PALS documentation for the current status and any limitations. +""" + cls.__doc__ = warning_doc.strip() + + return cls + + return decorator diff --git a/src/pals/parameters/ApertureParameters.py b/src/pals/parameters/ApertureParameters.py new file mode 100644 index 0000000..f7477ef --- /dev/null +++ b/src/pals/parameters/ApertureParameters.py @@ -0,0 +1,27 @@ +from typing import Literal +from pydantic import BaseModel, Field, field_validator + + +class ApertureParameters(BaseModel): + """Aperture parameters""" + + @field_validator("x_limits", "y_limits") + @classmethod + def validate_limits(cls, v): + """Validate that limits are None or that min < max""" + if v[0] is not None and v[1] is not None and v[0] >= v[1]: + raise ValueError("Lower limit must be less than upper limit") + return v + + x_limits: list[float | None, float | None] = Field(default=[None, None]) + y_limits: list[float | None, float | None] = Field(default=[None, None]) + shape: Literal["RECTANGULAR", "ELLIPTICAL", "VERTICES", "CUSTOM_SHAPE"] = ( + "RECTANGULAR" + ) + location: Literal[ + "ENTRANCE_END", "CENTER", "EXIT_END", "BOTH_ENDS", "NOWHERE", "EVERYWHERE" + ] = "ENTRANCE_END" + material: str = "" + thickness: float = Field(default=0.0, ge=0.0) + aperture_shifts_with_body: bool = False + aperture_active: bool = True diff --git a/src/pals/parameters/BeamBeamParameters.py b/src/pals/parameters/BeamBeamParameters.py new file mode 100644 index 0000000..ba530c7 --- /dev/null +++ b/src/pals/parameters/BeamBeamParameters.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + + +class BeamBeamParameters(BaseModel): + """Beam-beam parameters""" + + # Parameters will be added when construction is complete diff --git a/src/pals/parameters/BendParameters.py b/src/pals/parameters/BendParameters.py new file mode 100644 index 0000000..4c1359e --- /dev/null +++ b/src/pals/parameters/BendParameters.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel + + +class BendParameters(BaseModel): + """Bend parameters""" + + rho_ref: float = 0.0 # [radian] Reference bend angle + bend_field_ref: float = 0.0 # [T] Reference bend field + e1: float = 0.0 # [radian] Entrance end pole face rotation with respect to a sector geometry + e2: float = 0.0 # [radian] Exit end pole face rotation with respect to a rectangular geometry + e1_rect: float = 0.0 # [radian] Entrance end pole face rotation with respect to a rectangular geometry + e2_rect: float = 0.0 # [radian] Exit end pole face rotation with respect to a rectangular geometry + edge_int1: float = 0.0 # [T*m] Entrance end fringe field integral + edge_int2: float = 0.0 # [T*m] Exit end fringe field integral + g_ref: float = 0.0 # [1/m] Reference bend strength = 1/radius_ref + h1: float = 0.0 # [TODO] Entrance end pole face curvature + h2: float = 0.0 # [TODO] Exit end pole face curvature + L_chord: float = 0.0 # [m] Chord length + L_sagitta: float = 0.0 # [m] Sagitta length (output parameter) + tilt_ref: float = 0.0 # [radian] Reference tilt diff --git a/src/pals/parameters/BodyShiftParameters.py b/src/pals/parameters/BodyShiftParameters.py new file mode 100644 index 0000000..9624a50 --- /dev/null +++ b/src/pals/parameters/BodyShiftParameters.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel + + +class BodyShiftParameters(BaseModel): + """Body shift parameters""" + + x_offset: float = 0.0 + y_offset: float = 0.0 + z_offset: float = 0.0 + x_rot: float = 0.0 + y_rot: float = 0.0 + z_rot: float = 0.0 diff --git a/src/pals/parameters/ElectricMultipoleParameters.py b/src/pals/parameters/ElectricMultipoleParameters.py new file mode 100644 index 0000000..52fe74f --- /dev/null +++ b/src/pals/parameters/ElectricMultipoleParameters.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel, ConfigDict + + +class ElectricMultipoleParameters(BaseModel): + """Electric multipole parameters""" + + # Allow arbitrary fields (TODO: remove this) + model_config = ConfigDict(extra="allow") + + # TODO: add ElectricMultipoleParameters in a follow-up RP + # https://pals-project.readthedocs.io/en/latest/element-parameters.html#electricmultipolep-electric-multipole-parameters diff --git a/src/pals/parameters/FloorParameters.py b/src/pals/parameters/FloorParameters.py new file mode 100644 index 0000000..74f80b8 --- /dev/null +++ b/src/pals/parameters/FloorParameters.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + + +class FloorParameters(BaseModel): + """Floor position and orientation parameters""" + + # Under construction diff --git a/src/pals/parameters/FloorShiftParameters.py b/src/pals/parameters/FloorShiftParameters.py new file mode 100644 index 0000000..e9bdfae --- /dev/null +++ b/src/pals/parameters/FloorShiftParameters.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel + + +class FloorShiftParameters(BaseModel): + """Floor shift parameters""" + + x_offset: float = 0.0 + y_offset: float = 0.0 + z_offset: float = 0.0 + t_offset: float = 0.0 + x_rot: float = 0.0 + y_rot: float = 0.0 + z_rot: float = 0.0 diff --git a/src/pals/parameters/ForkParameters.py b/src/pals/parameters/ForkParameters.py new file mode 100644 index 0000000..0738943 --- /dev/null +++ b/src/pals/parameters/ForkParameters.py @@ -0,0 +1,11 @@ +from typing import Literal +from pydantic import BaseModel + + +class ForkParameters(BaseModel): + """Fork parameters""" + + to_line: str = "" + to_ele: str = "" + direction: Literal["FORWARDS", "BACKWARDS"] = "FORWARDS" + propagate_reference: bool = True diff --git a/src/pals/parameters/MetaParameters.py b/src/pals/parameters/MetaParameters.py new file mode 100644 index 0000000..ff79506 --- /dev/null +++ b/src/pals/parameters/MetaParameters.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + + +class MetaParameters(BaseModel): + """Meta parameters""" + + alias: str = "" + ID: str = "" + label: str = "" + description: str = "" diff --git a/src/pals/parameters/PatchParameters.py b/src/pals/parameters/PatchParameters.py new file mode 100644 index 0000000..c9e946e --- /dev/null +++ b/src/pals/parameters/PatchParameters.py @@ -0,0 +1,16 @@ +from typing import Literal +from pydantic import BaseModel + + +class PatchParameters(BaseModel): + """Patch parameters""" + + x_offset: float = 0.0 + y_offset: float = 0.0 + z_offset: float = 0.0 + x_rot: float = 0.0 + y_rot: float = 0.0 + z_rot: float = 0.0 + flexible: bool = False + ref_coords: Literal["entrance_end", "exit_end"] = "exit_end" + user_sets_length: bool = False diff --git a/src/pals/parameters/RFParameters.py b/src/pals/parameters/RFParameters.py new file mode 100644 index 0000000..55c03c6 --- /dev/null +++ b/src/pals/parameters/RFParameters.py @@ -0,0 +1,14 @@ +from pydantic import BaseModel, Field + + +class RFParameters(BaseModel): + """RF parameters""" + + frequency: float = 0.0 # [Hz] RF frequency + harmon: int = 0 # [unitless] RF frequency harmonic number + voltage: float = 0.0 # [V] RF voltage + gradient: float = 0.0 # [V/m] RF gradient + phase: float = 0.0 # [unitless] RF phase in 0 to 2*pi + multipass_phase: float = 0.0 # [unitless] RF Phase added to multipass elements + cavity_type: str = "STANDING_WAVE" # [string] Cavity type + n_cell: int = Field(default=1, gt=0) # [unitless] Number of cavity cells diff --git a/src/pals/parameters/ReferenceChangeParameters.py b/src/pals/parameters/ReferenceChangeParameters.py new file mode 100644 index 0000000..b7bcde8 --- /dev/null +++ b/src/pals/parameters/ReferenceChangeParameters.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + + +class ReferenceChangeParameters(BaseModel): + """Reference energy change and/or reference time correction parameters""" + + dE_ref: float = 0.0 # Change in reference energy + extra_dtime_ref: float = 0.0 # Reference time deviation from nominal diff --git a/src/pals/parameters/ReferenceParameters.py b/src/pals/parameters/ReferenceParameters.py new file mode 100644 index 0000000..c1e84a3 --- /dev/null +++ b/src/pals/parameters/ReferenceParameters.py @@ -0,0 +1,15 @@ +from typing import Literal +from pydantic import BaseModel + + +class ReferenceParameters(BaseModel): + """Reference parameters""" + + species_ref: str = "" + pc_ref: float = 0.0 # [momentum*c] Reference momentum times speed of light + E_tot_ref: float = 0.0 # [eV] Reference total energy + time_ref: float = 0.0 # [s] Reference time + location: str = "" # Where reference parameters are evaluated + location: Literal[ + "UPSTREAM_END", "DOWNSTREAM_END" + ] # TODO: undefined default in PALS? diff --git a/src/pals/parameters/SolenoidParameters.py b/src/pals/parameters/SolenoidParameters.py new file mode 100644 index 0000000..2768b32 --- /dev/null +++ b/src/pals/parameters/SolenoidParameters.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + + +class SolenoidParameters(BaseModel): + """Solenoid parameters""" + + Ksol: float = 0.0 # Normalized solenoid strength + Bsol: float = 0.0 # Solenoid field diff --git a/src/pals/parameters/TrackingParameters.py b/src/pals/parameters/TrackingParameters.py new file mode 100644 index 0000000..e26139f --- /dev/null +++ b/src/pals/parameters/TrackingParameters.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + + +class TrackingParameters(BaseModel): + """Tracking parameters""" + + # Parameters will be added when construction is complete diff --git a/src/pals/parameters/__init__.py b/src/pals/parameters/__init__.py index 9a4d433..6334a7c 100644 --- a/src/pals/parameters/__init__.py +++ b/src/pals/parameters/__init__.py @@ -2,4 +2,19 @@ simpler import statements like `from pals import MagneticMultipoleParameters`. """ +from .ApertureParameters import ApertureParameters # noqa: F401 +from .BeamBeamParameters import BeamBeamParameters # noqa: F401 +from .BendParameters import BendParameters # noqa: F401 +from .BodyShiftParameters import BodyShiftParameters # noqa: F401 +from .ElectricMultipoleParameters import ElectricMultipoleParameters # noqa: F401 +from .FloorParameters import FloorParameters # noqa: F401 +from .FloorShiftParameters import FloorShiftParameters # noqa: F401 +from .ForkParameters import ForkParameters # noqa: F401 from .MagneticMultipoleParameters import MagneticMultipoleParameters # noqa: F401 +from .MetaParameters import MetaParameters # noqa: F401 +from .PatchParameters import PatchParameters # noqa: F401 +from .ReferenceChangeParameters import ReferenceChangeParameters # noqa: F401 +from .ReferenceParameters import ReferenceParameters # noqa: F401 +from .RFParameters import RFParameters # noqa: F401 +from .SolenoidParameters import SolenoidParameters # noqa: F401 +from .TrackingParameters import TrackingParameters # noqa: F401 diff --git a/tests/test_elements.py b/tests/test_elements.py index 940bf1c..a115484 100644 --- a/tests/test_elements.py +++ b/tests/test_elements.py @@ -1,17 +1,13 @@ +import pytest from pydantic import ValidationError -from pals import MagneticMultipoleParameters -from pals import BaseElement -from pals import ThickElement -from pals import Drift -from pals import Quadrupole -from pals import BeamLine +import pals def test_BaseElement(): # Create one base element with custom name element_name = "base_element" - element = BaseElement(name=element_name) + element = pals.BaseElement(name=element_name) assert element.name == element_name @@ -19,7 +15,7 @@ def test_ThickElement(): # Create one thick element with custom name and length element_name = "thick_element" element_length = 1.0 - element = ThickElement( + element = pals.ThickElement( name=element_name, length=element_length, ) @@ -28,20 +24,15 @@ def test_ThickElement(): # Try to assign negative length and # detect validation error without breaking pytest element_length = -1.0 - passed = True - try: + with pytest.raises(ValidationError): element.length = element_length - except ValidationError as e: - print(e) - passed = False - assert not passed def test_Drift(): # Create one drift element with custom name and length element_name = "drift_element" element_length = 1.0 - element = Drift( + element = pals.Drift( name=element_name, length=element_length, ) @@ -50,13 +41,8 @@ def test_Drift(): # Try to assign negative length and # detect validation error without breaking pytest element_length = -1.0 - passed = True - try: + with pytest.raises(ValidationError): element.length = element_length - except ValidationError as e: - print(e) - passed = False - assert not passed def test_Quadrupole(): @@ -71,7 +57,7 @@ def test_Quadrupole(): element_magnetic_multipole_Bs2 = 2.2 element_magnetic_multipole_tilt1 = 3.1 element_magnetic_multipole_tilt2 = 3.2 - element_magnetic_multipole = MagneticMultipoleParameters( + element_magnetic_multipole = pals.MagneticMultipoleParameters( Bn1=element_magnetic_multipole_Bn1, Bs1=element_magnetic_multipole_Bs1, tilt1=element_magnetic_multipole_tilt1, @@ -79,7 +65,7 @@ def test_Quadrupole(): Bs2=element_magnetic_multipole_Bs2, tilt2=element_magnetic_multipole_tilt2, ) - element = Quadrupole( + element = pals.Quadrupole( name=element_name, length=element_length, MagneticMultipoleP=element_magnetic_multipole, @@ -99,16 +85,384 @@ def test_Quadrupole(): def test_BeamLine(): # Create first line with one base element - element1 = BaseElement(name="element1") - line1 = BeamLine(name="line1", line=[element1]) + element1 = pals.Marker(name="element1") + line1 = pals.BeamLine(name="line1", line=[element1]) assert line1.line == [element1] # Extend first line with one thick element - element2 = ThickElement(name="element2", length=2.0) + element2 = pals.Drift(name="element2", length=2.0) line1.line.extend([element2]) assert line1.line == [element1, element2] # Create second line with one drift element - element3 = Drift(name="element3", length=3.0) - line2 = BeamLine(name="line2", line=[element3]) + element3 = pals.Drift(name="element3", length=3.0) + line2 = pals.BeamLine(name="line2", line=[element3]) # Extend first line with second line line1.line.extend(line2.line) assert line1.line == [element1, element2, element3] + + +def test_Marker(): + """Test Marker element""" + element = pals.Marker(name="marker1") + assert element.name == "marker1" + assert element.kind == "Marker" + + +def test_Sextupole(): + """Test Sextupole element""" + element = pals.Sextupole( + name="sext1", + length=0.5, + MagneticMultipoleP=pals.MagneticMultipoleParameters(Bn2=1.0), + ApertureP=pals.ApertureParameters(x_limits=[-0.1, 0.1]), + ) + assert element.name == "sext1" + assert element.length == 0.5 + assert element.kind == "Sextupole" + assert element.MagneticMultipoleP.Bn2 == 1.0 + assert element.ApertureP.x_limits == [-0.1, 0.1] + + +def test_Octupole(): + """Test Octupole element""" + element = pals.Octupole( + name="oct1", + length=0.3, + ElectricMultipoleP=pals.ElectricMultipoleParameters(En3=0.5), + MetaP=pals.MetaParameters(alias="octupole_test"), + ) + assert element.name == "oct1" + assert element.length == 0.3 + assert element.kind == "Octupole" + assert element.ElectricMultipoleP.En3 == 0.5 + assert element.MetaP.alias == "octupole_test" + + +def test_Multipole(): + """Test Multipole element""" + element = pals.Multipole( + name="mult1", + length=0.4, + MagneticMultipoleP=pals.MagneticMultipoleParameters(Bn1=2.0, Bn2=1.5), + BodyShiftP=pals.BodyShiftParameters(x_offset=0.01), + ) + assert element.name == "mult1" + assert element.length == 0.4 + assert element.kind == "Multipole" + assert element.MagneticMultipoleP.Bn1 == 2.0 + assert element.BodyShiftP.x_offset == 0.01 + + +def test_RBend(): + """Test RBend element""" + bend_params = pals.BendParameters(rho_ref=1.0, bend_field_ref=2.0) + element = pals.RBend( + name="rbend1", + length=1.0, + BendP=bend_params, + ApertureP=pals.ApertureParameters(x_limits=[-0.2, 0.2]), + MetaP=pals.MetaParameters(description="Test bend"), + ) + assert element.name == "rbend1" + assert element.length == 1.0 + assert element.kind == "RBend" + assert element.BendP.rho_ref == 1.0 + assert element.ApertureP.x_limits == [-0.2, 0.2] + assert element.MetaP.description == "Test bend" + + +def test_SBend(): + """Test SBend element""" + bend_params = pals.BendParameters(rho_ref=1.5, bend_field_ref=3.0) + element = pals.SBend( + name="sbend1", + length=1.2, + BendP=bend_params, + ReferenceP=pals.ReferenceParameters(species_ref="proton"), + ) + assert element.name == "sbend1" + assert element.length == 1.2 + assert element.kind == "SBend" + assert element.BendP.rho_ref == 1.5 + assert element.ReferenceP.species_ref == "proton" + + +def test_Solenoid(): + """Test Solenoid element""" + sol_params = pals.SolenoidParameters(Ksol=0.1, Bsol=0.2) + element = pals.Solenoid( + name="sol1", + length=0.8, + SolenoidP=sol_params, + ) + assert element.name == "sol1" + assert element.length == 0.8 + assert element.kind == "Solenoid" + assert element.SolenoidP.Ksol == 0.1 + + +def test_RFCavity(): + """Test RFCavity element""" + rf_params = pals.RFParameters(frequency=1e9, voltage=1e6) + element = pals.RFCavity( + name="rf1", + length=0.5, + RFP=rf_params, + SolenoidP=pals.SolenoidParameters(Ksol=0.05), + ) + assert element.name == "rf1" + assert element.length == 0.5 + assert element.kind == "RFCavity" + assert element.RFP.frequency == 1e9 + assert element.SolenoidP.Ksol == 0.05 + + +def test_Patch(): + """Test Patch element""" + patch_params = pals.PatchParameters(x_offset=0.1, y_offset=0.2) + element = pals.Patch( + name="patch1", + length=0.3, + PatchP=patch_params, + ReferenceChangeP=pals.ReferenceChangeParameters(dE_ref=1e6), + ) + assert element.name == "patch1" + assert element.length == 0.3 + assert element.kind == "Patch" + assert element.PatchP.x_offset == 0.1 + assert element.ReferenceChangeP.dE_ref == 1e6 + + +def test_FloorShift(): + """Test FloorShift element""" + floor_params = pals.FloorShiftParameters(x_offset=0.5, z_offset=1.0) + element = pals.FloorShift( + name="floor1", + FloorShiftP=floor_params, + MetaP=pals.MetaParameters(alias="floor_test"), + ) + assert element.name == "floor1" + assert element.kind == "FloorShift" + assert element.FloorShiftP.x_offset == 0.5 + assert element.MetaP.alias == "floor_test" + + +def test_Fork(): + """Test Fork element""" + fork_params = pals.ForkParameters(to_line="line1", direction="FORWARDS") + element = pals.Fork( + name="fork1", + ForkP=fork_params, + ReferenceP=pals.ReferenceParameters(species_ref="electron"), + ) + assert element.name == "fork1" + assert element.kind == "Fork" + assert element.ForkP.to_line == "line1" + assert element.ReferenceP.species_ref == "electron" + + +def test_BeamBeam(): + """Test BeamBeam element""" + bb_params = pals.BeamBeamParameters() + element = pals.BeamBeam( + name="bb1", + BeamBeamP=bb_params, + ApertureP=pals.ApertureParameters(x_limits=[-0.05, 0.05]), + ) + assert element.name == "bb1" + assert element.kind == "BeamBeam" + assert element.ApertureP.x_limits == [-0.05, 0.05] + + +def test_BeginningEle(): + """Test BeginningEle element""" + element = pals.BeginningEle( + name="begin1", MetaP=pals.MetaParameters(description="Start of lattice") + ) + assert element.name == "begin1" + assert element.kind == "BeginningEle" + assert element.MetaP.description == "Start of lattice" + + +def test_Fiducial(): + """Test Fiducial element""" + element = pals.Fiducial( + name="fid1", ReferenceP=pals.ReferenceParameters(species_ref="proton") + ) + assert element.name == "fid1" + assert element.kind == "Fiducial" + assert element.ReferenceP.species_ref == "proton" + + +def test_NullEle(): + """Test NullEle element""" + element = pals.NullEle(name="null1") + assert element.name == "null1" + assert element.kind == "NullEle" + + +def test_Kicker(): + """Test Kicker element""" + element = pals.Kicker( + name="kick1", + length=0.2, + MagneticMultipoleP=pals.MagneticMultipoleParameters(Bn1=0.5), + ElectricMultipoleP=pals.ElectricMultipoleParameters(En1=0.3), + ) + assert element.name == "kick1" + assert element.length == 0.2 + assert element.kind == "Kicker" + assert element.MagneticMultipoleP.Bn1 == 0.5 + assert element.ElectricMultipoleP.En1 == 0.3 + + +def test_ACKicker(): + """Test ACKicker element""" + element = pals.ACKicker(name="ackick1", length=0.15) + assert element.name == "ackick1" + assert element.length == 0.15 + assert element.kind == "ACKicker" + + +def test_CrabCavity(): + """Test CrabCavity element""" + element = pals.CrabCavity( + name="crab1", + length=0.25, + MagneticMultipoleP=pals.MagneticMultipoleParameters(Bn1=0.8), + ElectricMultipoleP=pals.ElectricMultipoleParameters(En1=0.4), + ) + assert element.name == "crab1" + assert element.length == 0.25 + assert element.kind == "CrabCavity" + assert element.MagneticMultipoleP.Bn1 == 0.8 + assert element.ElectricMultipoleP.En1 == 0.4 + + +def test_EGun(): + """Test EGun element""" + element = pals.EGun( + name="egun1", + length=0.1, + MagneticMultipoleP=pals.MagneticMultipoleParameters(Bn1=1.2), + ElectricMultipoleP=pals.ElectricMultipoleParameters(En1=0.6), + ) + assert element.name == "egun1" + assert element.length == 0.1 + assert element.kind == "EGun" + assert element.MagneticMultipoleP.Bn1 == 1.2 + assert element.ElectricMultipoleP.En1 == 0.6 + + +def test_Feedback(): + """Test Feedback element""" + element = pals.Feedback( + name="fb1", MetaP=pals.MetaParameters(alias="feedback_test") + ) + assert element.name == "fb1" + assert element.kind == "Feedback" + assert element.MetaP.alias == "feedback_test" + + +def test_Girder(): + """Test Girder element""" + element = pals.Girder(name="girder1") + assert element.name == "girder1" + assert element.kind == "Girder" + + +def test_Instrument(): + """Test Instrument element""" + element = pals.Instrument( + name="inst1", + length=0.05, + MagneticMultipoleP=pals.MagneticMultipoleParameters(Bn1=0.2), + ElectricMultipoleP=pals.ElectricMultipoleParameters(En1=0.1), + ) + assert element.name == "inst1" + assert element.length == 0.05 + assert element.kind == "Instrument" + assert element.MagneticMultipoleP.Bn1 == 0.2 + assert element.ElectricMultipoleP.En1 == 0.1 + + +def test_Mask(): + """Test Mask element""" + element = pals.Mask( + name="mask1", + length=0.02, + MagneticMultipoleP=pals.MagneticMultipoleParameters(Bn1=0.15), + ElectricMultipoleP=pals.ElectricMultipoleParameters(En1=0.08), + ) + assert element.name == "mask1" + assert element.length == 0.02 + assert element.kind == "Mask" + assert element.MagneticMultipoleP.Bn1 == 0.15 + assert element.ElectricMultipoleP.En1 == 0.08 + + +def test_Match(): + """Test Match element""" + element = pals.Match( + name="match1", BodyShiftP=pals.BodyShiftParameters(x_offset=0.01, y_rot=0.02) + ) + assert element.name == "match1" + assert element.kind == "Match" + assert element.BodyShiftP.x_offset == 0.01 + assert element.BodyShiftP.y_rot == 0.02 + + +def test_Taylor(): + """Test Taylor element""" + element = pals.Taylor( + name="taylor1", + ReferenceChangeP=pals.ReferenceChangeParameters( + dE_ref=1e6, extra_dtime_ref=1e-9 + ), + ) + assert element.name == "taylor1" + assert element.kind == "Taylor" + assert element.ReferenceChangeP.dE_ref == 1e6 + assert element.ReferenceChangeP.extra_dtime_ref == 1e-9 + + +def test_Wiggler(): + """Test Wiggler element""" + element = pals.Wiggler( + name="wig1", + length=2.0, + MagneticMultipoleP=pals.MagneticMultipoleParameters(Bn1=0.5), + ElectricMultipoleP=pals.ElectricMultipoleParameters(En1=0.3), + ) + assert element.name == "wig1" + assert element.length == 2.0 + assert element.kind == "Wiggler" + assert element.MagneticMultipoleP.Bn1 == 0.5 + assert element.ElectricMultipoleP.En1 == 0.3 + + +def test_Converter(): + """Test Converter element""" + element = pals.Converter( + name="conv1", + MagneticMultipoleP=pals.MagneticMultipoleParameters(Bn1=0.4), + ElectricMultipoleP=pals.ElectricMultipoleParameters(En1=0.2), + ) + assert element.name == "conv1" + assert element.kind == "Converter" + assert element.MagneticMultipoleP.Bn1 == 0.4 + assert element.ElectricMultipoleP.En1 == 0.2 + + +def test_Foil(): + """Test Foil element""" + element = pals.Foil(name="foil1") + assert element.name == "foil1" + assert element.kind == "Foil" + + +def test_UnionEle(): + """Test UnionEle element""" + element = pals.UnionEle(name="union1", elements=[]) + assert element.name == "union1" + assert element.kind == "UnionEle" + assert element.elements == [] diff --git a/tests/test_parameters.py b/tests/test_parameters.py new file mode 100644 index 0000000..21eb8f9 --- /dev/null +++ b/tests/test_parameters.py @@ -0,0 +1,104 @@ +import pytest +from pydantic import ValidationError + +from pals import ( + ApertureParameters, + BeamBeamParameters, + BendParameters, + BodyShiftParameters, + FloorShiftParameters, + ForkParameters, + MagneticMultipoleParameters, + MetaParameters, + PatchParameters, + ReferenceChangeParameters, + ReferenceParameters, + RFParameters, + SolenoidParameters, + # TrackingParameters, # not yet tested +) + + +def test_ParameterClasses(): + """Test parameter classes""" + # Test ApertureParameters + aperture = ApertureParameters(x_limits=[-0.1, 0.1], y_limits=[-0.05, 0.05]) + assert aperture.x_limits == [-0.1, 0.1] + + with pytest.raises(ValidationError): + _ = ApertureParameters( + x_limits=[-0.1, 0.1], y_limits=[-0.05, 0.05, 0.1], shape="wrong" + ) + + # Test BodyShiftParameters + body_shift = BodyShiftParameters(x_offset=0.01, y_rot=0.02) + assert body_shift.x_offset == 0.01 + + # Test MetaParameters + meta = MetaParameters(alias="test", description="test element") + assert meta.alias == "test" + + # Test ElectricMultipoleParameters (TODO) + # emp = ElectricMultipoleParameters(En1=1.0, Es1=0.5) + # assert emp.En1 == 1.0 + + # Test MagneticMultipoleParameters + mmp = MagneticMultipoleParameters(Bn1=1.0, Bs1=0.5) + assert mmp.Bn1 == 1.0 + assert mmp.Bs1 == 0.5 + + # catch typos + with pytest.raises(ValidationError): + _ = MagneticMultipoleParameters(Bm1=1.0, Bs1=0.5) + with pytest.raises(ValidationError): + _ = MagneticMultipoleParameters(Bn1=1.0, Bv1=0.5) + with pytest.raises(ValidationError): + _ = MagneticMultipoleParameters(Bn01=1.0, Bs01=0.5) + + # Test SolenoidParameters + sol = SolenoidParameters(Ksol=0.1, Bsol=0.2) + assert sol.Ksol == 0.1 + + # Test RFParameters + rf = RFParameters(frequency=1e9, voltage=1e6) + assert rf.frequency == 1e9 + + with pytest.raises(ValidationError): + _ = RFParameters(frequency=1e9, voltage=1e6, n_cell=0) + with pytest.raises(ValidationError): + _ = RFParameters(frequency=1e9, voltage=1e6, n_cell=-1) + + # Test BendParameters + bend = BendParameters(rho_ref=1.0, bend_field_ref=2.0) + assert bend.rho_ref == 1.0 + + # Test PatchParameters + patch = PatchParameters(x_offset=0.1, flexible=True) + assert patch.x_offset == 0.1 + + # Test FloorShiftParameters + floor = FloorShiftParameters(x_offset=0.5, z_offset=1.0) + assert floor.x_offset == 0.5 + + # Test ForkParameters + fork = ForkParameters(to_line="line1", direction="FORWARDS") + assert fork.to_line == "line1" + + # Test ReferenceParameters + ref = ReferenceParameters(species_ref="electron", pc_ref=1e6) + assert ref.species_ref == "electron" + + # TODO: Test TrackingParameters + # tracking = TrackingParameters(...) + # assert tracking.i.. + + # TODO: Test FloorParameters + + # Test ReferenceChangeParameters + ref_change = ReferenceChangeParameters(extra_dtime_ref=1e6, dE_ref=1e-9) + assert ref_change.extra_dtime_ref == 1e6 + assert ref_change.dE_ref == 1e-9 + + # Test BeamBeamParameters + beambeam = BeamBeamParameters() + assert beambeam is not None diff --git a/tests/test_serialization.py b/tests/test_serialization.py index eb718cd..b4b7181 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -2,18 +2,16 @@ import os import yaml -from pals import BaseElement -from pals import ThickElement -from pals import BeamLine +import pals def test_yaml(): # Create one base element - element1 = BaseElement(name="element1") + element1 = pals.Marker(name="element1") # Create one thick element - element2 = ThickElement(name="element2", length=2.0) + element2 = pals.Drift(name="element2", length=2.0) # Create line with both elements - line = BeamLine(name="line", line=[element1, element2]) + line = pals.BeamLine(name="line", line=[element1, element2]) # Serialize the BeamLine object to YAML yaml_data = yaml.dump(line.model_dump(), default_flow_style=False) print(f"\n{yaml_data}") @@ -25,7 +23,7 @@ def test_yaml(): with open(test_file, "r") as file: yaml_data = yaml.safe_load(file) # Parse the YAML data back into a BeamLine object - loaded_line = BeamLine(**yaml_data) + loaded_line = pals.BeamLine(**yaml_data) # Remove the test file os.remove(test_file) # Validate loaded BeamLine object @@ -34,11 +32,11 @@ def test_yaml(): def test_json(): # Create one base element - element1 = BaseElement(name="element1") + element1 = pals.Marker(name="element1") # Create one thick element - element2 = ThickElement(name="element2", length=2.0) + element2 = pals.Drift(name="element2", length=2.0) # Create line with both elements - line = BeamLine(name="line", line=[element1, element2]) + line = pals.BeamLine(name="line", line=[element1, element2]) # Serialize the BeamLine object to JSON json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2) print(f"\n{json_data}") @@ -50,8 +48,276 @@ def test_json(): with open(test_file, "r") as file: json_data = json.loads(file.read()) # Parse the JSON data back into a BeamLine object - loaded_line = BeamLine(**json_data) + loaded_line = pals.BeamLine(**json_data) # Remove the test file os.remove(test_file) # Validate loaded BeamLine object assert line == loaded_line + + +def test_comprehensive_lattice(): + """Test a comprehensive lattice using every PALS element at least once""" + + # Create elements in alphabetical order for easy maintenance + # ACKicker + ackicker = pals.ACKicker(name="ackicker1", length=0.1) + + # BeamBeam + beambeam = pals.BeamBeam(name="beambeam1", BeamBeamP=pals.BeamBeamParameters()) + + # BeginningEle + beginning = pals.BeginningEle(name="beginning1") + + # Converter + converter = pals.Converter(name="converter1") + + # CrabCavity + crabcavity = pals.CrabCavity(name="crabcavity1", length=0.2) + + # Drift + drift = pals.Drift(name="drift1", length=0.5) + + # EGun + egun = pals.EGun(name="egun1", length=0.15) + + # Feedback + feedback = pals.Feedback(name="feedback1") + + # Fiducial + fiducial = pals.Fiducial(name="fiducial1") + + # FloorShift + floorshift = pals.FloorShift( + name="floorshift1", FloorShiftP=pals.FloorShiftParameters(x_offset=0.1) + ) + + # Foil + foil = pals.Foil(name="foil1") + + # Fork + fork = pals.Fork(name="fork1", ForkP=pals.ForkParameters(to_line="line1")) + + # Girder + girder = pals.Girder(name="girder1") + + # Instrument + instrument = pals.Instrument(name="instrument1", length=0.05) + + # Kicker + kicker = pals.Kicker(name="kicker1", length=0.1) + + # Marker + marker = pals.Marker(name="marker1") + + # Mask + mask = pals.Mask(name="mask1", length=0.02) + + # Match + match = pals.Match(name="match1") + + # Multipole + multipole = pals.Multipole(name="multipole1", length=0.3) + + # NullEle + nullele = pals.NullEle(name="nullele1") + + # Octupole + octupole = pals.Octupole( + name="octupole1", + length=0.25, + ElectricMultipoleP=pals.ElectricMultipoleParameters(En3=0.5), + MetaP=pals.MetaParameters(alias="octupole_test"), + ) + + # Patch + patch = pals.Patch( + name="patch1", length=0.4, PatchP=pals.PatchParameters(x_offset=0.05) + ) + + # Quadrupole + quadrupole = pals.Quadrupole( + name="quadrupole1", + length=0.8, + MagneticMultipoleP=pals.MagneticMultipoleParameters(Bn1=1.0), + ) + + # RBend + rbend = pals.RBend( + name="rbend1", + length=1.0, + BendP=pals.BendParameters(rho_ref=2.0), + ApertureP=pals.ApertureParameters(x_limits=[-0.2, 0.2]), + ) + + # RFCavity + rfcavity = pals.RFCavity( + name="rfcavity1", + length=0.3, + RFP=pals.RFParameters(frequency=1e9), + SolenoidP=pals.SolenoidParameters(Ksol=0.05), + ) + + # SBend + sbend = pals.SBend( + name="sbend1", length=1.2, BendP=pals.BendParameters(rho_ref=1.5) + ) + + # Sextupole + sextupole = pals.Sextupole( + name="sextupole1", + length=0.2, + MagneticMultipoleP=pals.MagneticMultipoleParameters(Bn2=1.0), + ApertureP=pals.ApertureParameters(x_limits=[-0.1, 0.1]), + ) + + # Solenoid + solenoid = pals.Solenoid( + name="solenoid1", length=0.6, SolenoidP=pals.SolenoidParameters(Ksol=0.1) + ) + + # Taylor + taylor = pals.Taylor(name="taylor1") + + # UnionEle + unionele = pals.UnionEle(name="unionele1", elements=[]) + + # Wiggler + wiggler = pals.Wiggler(name="wiggler1", length=2.0) + + # Create comprehensive lattice + lattice = pals.BeamLine( + name="comprehensive_lattice", + line=[ + beginning, # Start with beginning element + fiducial, # Global coordinate reference + marker, # Mark position + drift, # Field-free region + quadrupole, # Focusing element + sextupole, # Chromatic correction + octupole, # Higher order correction + multipole, # General multipole + rbend, # Rectangular bend + sbend, # Sector bend + solenoid, # Longitudinal focusing + rfcavity, # RF acceleration + crabcavity, # RF crab cavity + kicker, # Transverse kick + ackicker, # AC kicker + patch, # Coordinate transformation + floorshift, # Global coordinate shift + instrument, # Measurement device + mask, # Collimation + match, # Matching element + egun, # Electron source + converter, # Species conversion + foil, # Electron stripping + beambeam, # Colliding beams + feedback, # Feedback system + girder, # Support structure + fork, # Branch connection + taylor, # Taylor map + unionele, # Overlapping elements + wiggler, # Undulator + nullele, # Placeholder + ], + ) + + # Test serialization to YAML + yaml_data = yaml.dump(lattice.model_dump(), default_flow_style=False) + print(f"\nComprehensive lattice YAML:\n{yaml_data}") + + # Write to temporary file + yaml_file = "comprehensive_lattice.yaml" + with open(yaml_file, "w") as file: + file.write(yaml_data) + + # Read back from file + with open(yaml_file, "r") as file: + loaded_yaml_data = yaml.safe_load(file) + + # Deserialize back to Python object using Pydantic model logic + loaded_lattice = pals.BeamLine(**loaded_yaml_data) + + # Verify the loaded lattice has the correct structure and parameter groups + assert len(loaded_lattice.line) == 31 # Should have 31 elements + + # Verify specific elements with parameter groups are correctly loaded + sextupole_loaded = None + octupole_loaded = None + rbend_loaded = None + rfcavity_loaded = None + + for elem in loaded_lattice.line: + if elem.name == "sextupole1": + sextupole_loaded = elem + elif elem.name == "octupole1": + octupole_loaded = elem + elif elem.name == "rbend1": + rbend_loaded = elem + elif elem.name == "rfcavity1": + rfcavity_loaded = elem + + # Test that parameter groups are correctly deserialized + assert sextupole_loaded.MagneticMultipoleP.Bn2 == 1.0 + assert sextupole_loaded.ApertureP.x_limits == [-0.1, 0.1] + + assert octupole_loaded.ElectricMultipoleP.En3 == 0.5 + assert octupole_loaded.MetaP.alias == "octupole_test" + + assert rbend_loaded.BendP.rho_ref == 2.0 + assert rbend_loaded.ApertureP.x_limits == [-0.2, 0.2] + + assert rfcavity_loaded.RFP.frequency == 1e9 + assert rfcavity_loaded.SolenoidP.Ksol == 0.05 + + # Test serialization to JSON + json_data = json.dumps(lattice.model_dump(), sort_keys=True, indent=2) + print(f"\nComprehensive lattice JSON:\n{json_data}") + + # Write to temporary file + json_file = "comprehensive_lattice.json" + with open(json_file, "w") as file: + file.write(json_data) + + # Read back from file + with open(json_file, "r") as file: + loaded_json_data = json.loads(file.read()) + + # Deserialize back to Python object using Pydantic model logic + loaded_lattice_json = pals.BeamLine(**loaded_json_data) + + # Verify the loaded lattice has the correct structure and parameter groups + assert len(loaded_lattice_json.line) == 31 # Should have 31 elements + + # Verify specific elements with parameter groups are correctly loaded + sextupole_loaded_json = None + octupole_loaded_json = None + rbend_loaded_json = None + rfcavity_loaded_json = None + + for elem in loaded_lattice_json.line: + if elem.name == "sextupole1": + sextupole_loaded_json = elem + elif elem.name == "octupole1": + octupole_loaded_json = elem + elif elem.name == "rbend1": + rbend_loaded_json = elem + elif elem.name == "rfcavity1": + rfcavity_loaded_json = elem + + # Test that parameter groups are correctly deserialized + assert sextupole_loaded_json.MagneticMultipoleP.Bn2 == 1.0 + assert sextupole_loaded_json.ApertureP.x_limits == [-0.1, 0.1] + + assert octupole_loaded_json.ElectricMultipoleP.En3 == 0.5 + assert octupole_loaded_json.MetaP.alias == "octupole_test" + + assert rbend_loaded_json.BendP.rho_ref == 2.0 + assert rbend_loaded_json.ApertureP.x_limits == [-0.2, 0.2] + + assert rfcavity_loaded_json.RFP.frequency == 1e9 + assert rfcavity_loaded_json.SolenoidP.Ksol == 0.05 + + # Clean up temporary files + os.remove(yaml_file) + os.remove(json_file)