diff --git a/py/orbit/bunch_utils/__init__.py b/py/orbit/bunch_utils/__init__.py
index 2263c33e..88b43d5f 100644
--- a/py/orbit/bunch_utils/__init__.py
+++ b/py/orbit/bunch_utils/__init__.py
@@ -7,5 +7,12 @@
 from orbit.bunch_utils.particleidnumber import ParticleIdNumber
 
+# This guards against missing numpy.
+# Should be improved with some meaningful (and MPI friendly?) warning printed out.
+try:
+    from orbit.bunch_utils.serialize import collect_bunch, save_bunch, load_bunch
+except ImportError:
+    pass
+
 
 __all__ = []
 __all__.append("addParticleIdNumbers")
diff --git a/py/orbit/bunch_utils/meson.build b/py/orbit/bunch_utils/meson.build
index fc1531cc..256d6f5e 100644
--- a/py/orbit/bunch_utils/meson.build
+++ b/py/orbit/bunch_utils/meson.build
@@ -3,11 +3,12 @@
 py_sources = files([
     '__init__.py',
-    'particleidnumber.py'
+    'particleidnumber.py',
+    'serialize.py',
 ])
 
 python.install_sources(
     py_sources,
     subdir: 'orbit/bunch_utils',
     # pure: true,
-)
+)
diff --git a/py/orbit/bunch_utils/serialize.py b/py/orbit/bunch_utils/serialize.py
new file mode 100644
index 00000000..1a8a0e3b
--- /dev/null
+++ b/py/orbit/bunch_utils/serialize.py
@@ -0,0 +1,317 @@
+import os
+import pathlib
+from typing import Any, Protocol, TypedDict
+
+import numpy as np
+from numpy.typing import NDArray
+
+from orbit.core import orbit_mpi
+from orbit.core.bunch import Bunch
+
+
+class SyncPartDict(TypedDict):
+    coords: NDArray[np.float64]
+    kin_energy: np.float64
+    momentum: np.float64
+    beta: np.float64
+    gamma: np.float64
+    time: np.float64
+
+
+class BunchDict(TypedDict):
+    coords: NDArray[np.float64]
+    sync_part: SyncPartDict
+    attributes: dict[str, np.float64 | np.int32]
+
+
+class FileHandler(Protocol):
+    """Protocol for file handlers to read/write bunch data."""
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None: ...
+
+    def read(self) -> BunchDict: ...
+
+    def write(self, bunch: BunchDict) -> None: ...
+
+
+class NumPyHandler:
+    """Handler implementing the FileHandler protocol for NumPy binary files.
+    This handler will create two files in the directory passed to the constructor:
+    - coords.npy: A memory-mapped NumPy array containing the bunch coordinates.
+    - attributes.npz: A NumPy archive containing data related to the synchronous particle and other bunch attributes.
+    """
+
+    _coords_fname = "coords.npy"
+    _attributes_fname = "attributes.npz"
+
+    def __init__(self, dir_name: str | pathlib.Path):
+        if isinstance(dir_name, str):
+            dir_name = pathlib.Path(dir_name)
+        self._dir_name = dir_name
+        self._coords_path = dir_name / self._coords_fname
+        self._attributes_path = dir_name / self._attributes_fname
+
+    def read(self) -> BunchDict:
+        if not self._coords_path.exists() or not self._attributes_path.exists():
+            raise FileNotFoundError(
+                f"Required files not found in directory: {self._dir_name}"
+            )
+
+        coords = np.load(self._coords_path, mmap_mode="r")
+
+        attr_data = np.load(self._attributes_path, allow_pickle=True)
+
+        sync_part = attr_data["sync_part"].item()
+        attributes = attr_data["attributes"].item()
+
+        return BunchDict(coords=coords, sync_part=sync_part, attributes=attributes)
+
+    def write(self, bunch: BunchDict) -> None:
+        self._dir_name.mkdir(parents=True, exist_ok=True)
+        np.save(self._coords_path, bunch["coords"])
+        np.savez(
+            self._attributes_path,
+            sync_part=bunch["sync_part"],
+            attributes=bunch["attributes"],
+        )
+
+
+def collect_bunch(
+    bunch: Bunch, output_dir: str | pathlib.Path = "/tmp", return_memmap: bool = True
+) -> BunchDict | None:
+    """Collects attributes from a PyOrbit Bunch across all MPI ranks and returns it as a dictionary.
+    Parameters
+    ----------
+    bunch : Bunch
+        The PyOrbit::Bunch object from which to collect attributes.
+    output_dir : str | pathlib.Path, optional
+        The directory to use for temporary storage of the bunch coordinates on each MPI rank.
+        Defaults to "/tmp".
+        Note: take care that the temporary files are created in a directory where all MPI ranks have write access.
+    return_memmap : bool, optional
+        Return the bunch coordinates as a memory-mapped NumPy array, otherwise the
+        entire array is copied into RAM and returned as normal NDArray. Default is True.
+    Returns
+    -------
+    BunchDict | None
+        A dictionary containing the collected bunch attributes. Returns None if not on the root MPI rank or if the global bunch size is 0.
+        BunchDict structure:
+        {
+            "coords": NDArray[np.float64] of shape (N, 6) where N is the total number of macroparticles,
+                and the 6 columns correspond to [x, xp, y, yp, z, dE] in units of [m, rad, m, rad, m, GeV], respectively.
+            "sync_part": {
+                "coords": NDArray[np.float64] of shape (3,),
+                "kin_energy": np.float64,
+                "momentum": np.float64,
+                "beta": np.float64,
+                "gamma": np.float64,
+                "time": np.float64
+            },
+            "attributes": {
+                <attribute_name>: <attribute_value>,
+                ...
+            }
+        }
+    Raises
+    ------
+    FileNotFoundError
+        If the temporary files created by non-root MPI ranks could not be found by the root rank during
+        the collection process.
+ """ + + global_size = bunch.getSizeGlobal() + + if global_size == 0: + return None + + mpi_comm = bunch.getMPIComm() + mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm) + + coords_shape = (bunch.getSizeGlobal(), 6) + + local_rows = bunch.getSize() + + if isinstance(output_dir, str): + output_dir = pathlib.Path(output_dir) + + fname = output_dir / f"collect_bunch_tmpfile_{mpi_rank}.dat" + + local_shape = (local_rows, coords_shape[1]) + dtype = np.float64 + coords_memmap = np.memmap(fname, dtype=dtype, mode="w+", shape=local_shape) + + for i in range(local_rows): + coords_memmap[i, :] = ( + bunch.x(i), + bunch.xp(i), + bunch.y(i), + bunch.yp(i), + bunch.z(i), + bunch.dE(i), + ) + + coords_memmap.flush() + + bunch_dict: BunchDict = {"coords": None, "sync_part": {}, "attributes": {}} + + if mpi_rank == 0: + sync_part = bunch.getSyncParticle() + + bunch_dict["sync_part"] |= { + "coords": np.array(sync_part.pVector()), + "kin_energy": np.float64(sync_part.kinEnergy()), + "momentum": np.float64(sync_part.momentum()), + "beta": np.float64(sync_part.beta()), + "gamma": np.float64(sync_part.gamma()), + "time": np.float64(sync_part.time()), + } + + for attr in bunch.bunchAttrDoubleNames(): + bunch_dict["attributes"][attr] = np.float64(bunch.bunchAttrDouble(attr)) + + for attr in bunch.bunchAttrIntNames(): + bunch_dict["attributes"][attr] = np.int32(bunch.bunchAttrInt(attr)) + + orbit_mpi.MPI_Barrier(mpi_comm) + + if mpi_rank != 0: + return None + + coords_memmap = np.memmap(fname, dtype=dtype, mode="r+", shape=coords_shape) + + start_row = local_rows + + for r in range(1, orbit_mpi.MPI_Comm_size(mpi_comm)): + src_fname = output_dir / f"collect_bunch_tmpfile_{r}.dat" + + if not os.path.exists(src_fname): + raise FileNotFoundError( + f"Expected temporary file '{src_fname}' not found. Something went wrong." 
+            )
+
+        src_memmap = np.memmap(src_fname, dtype=dtype, mode="r")
+        src_memmap = src_memmap.reshape((-1, coords_shape[1]))
+
+        stop_row = start_row + src_memmap.shape[0]
+
+        coords_memmap[start_row:stop_row, :] = src_memmap[:, :]
+        coords_memmap.flush()
+
+        del src_memmap
+        os.remove(src_fname)
+        start_row = stop_row
+
+    bunch_dict["coords"] = coords_memmap if return_memmap else np.array(coords_memmap)
+
+    return bunch_dict
+
+
+def save_bunch(
+    bunch: Bunch | BunchDict,
+    output_dir: str | pathlib.Path = "bunch_data/",
+    Handler: type[FileHandler] = NumPyHandler,
+) -> None:
+    """Saves the collected bunch attributes to a specified directory.
+    Parameters
+    ----------
+    bunch : Bunch | BunchDict
+        The PyOrbit::Bunch object or the dictionary containing the collected bunch attributes.
+    output_dir : str | pathlib.Path, optional
+        The directory where the bunch data files will be saved. Default is "bunch_data/".
+    Handler : FileHandler, optional
+        The file handler class to use for writing the bunch data. Default is NumPyHandler.
+    Returns
+    -------
+    None
+    Raises
+    ------
+    ValueError
+        If the provided `bunch` is neither a Bunch instance nor a BunchDict.
+    """
+
+    if isinstance(bunch, Bunch):
+        mpi_comm = bunch.getMPIComm()
+        bunch = collect_bunch(bunch)
+    else:
+        mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
+
+    mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)
+
+    if mpi_rank != 0 or bunch is None:
+        return
+
+    if bunch["coords"].shape[0] == 0:
+        print("No particles in the bunch to save.")
+        return
+
+    if isinstance(output_dir, str):
+        output_dir = pathlib.Path(output_dir)
+
+    handler = Handler(output_dir)
+    handler.write(bunch)
+
+
+def load_bunch(
+    input_dir: str | pathlib.Path, Handler: type[FileHandler] = NumPyHandler
+) -> tuple[Bunch, BunchDict]:
+    """Loads the bunch attributes from a specified directory containing NumPy binary files.
+    Parameters
+    ----------
+    input_dir : str | pathlib.Path
+        The directory from which to load the bunch data files.
+    Handler : FileHandler, optional
+        The file handler class to use for reading the bunch data. Default is NumPyHandler.
+        See `orbit.bunch_utils.serialize` for available handlers.
+    Returns
+    -------
+    tuple[Bunch, BunchDict]
+        The reconstructed Bunch and the dictionary containing the loaded bunch attributes.
+    Raises
+    ------
+    FileNotFoundError
+        If the required files are not found in the specified directory.
+    TypeError
+        If an attribute in the loaded bunch has an unsupported type.
+    """
+    mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
+    mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)
+    mpi_size = orbit_mpi.MPI_Comm_size(mpi_comm)
+
+    handler = Handler(input_dir)
+
+    bunch_dict = handler.read()
+
+    coords = bunch_dict["coords"]
+
+    global_size = coords.shape[0]
+
+    local_size = global_size // mpi_size
+    remainder = global_size % mpi_size
+    if mpi_rank < remainder:
+        local_size += 1
+        start_row = mpi_rank * local_size
+    else:
+        start_row = mpi_rank * local_size + remainder
+    stop_row = start_row + local_size
+
+    local_coords = coords[start_row:stop_row, :]
+
+    bunch = Bunch()
+
+    for i in range(local_size):
+        bunch.addParticle(*local_coords[i, :])
+
+    for attr, value in bunch_dict["attributes"].items():
+        if np.issubdtype(value, np.floating):
+            bunch.bunchAttrDouble(attr, value)
+        elif np.issubdtype(value, np.integer):
+            bunch.bunchAttrInt(attr, value)
+        else:
+            raise TypeError(f"Unsupported attribute type for '{attr}': {type(value)}")
+
+    sync_part_obj = bunch.getSyncParticle()
+    sync_part_obj.rVector(tuple(bunch_dict["sync_part"]["coords"]))  # NOTE(review): "coords" is saved from pVector() in collect_bunch -- confirm rVector vs pVector
+    sync_part_obj.kinEnergy(bunch_dict["sync_part"]["kin_energy"])
+    sync_part_obj.time(bunch_dict["sync_part"]["time"])
+
+    return bunch, bunch_dict
diff --git a/tests/py/orbit/bunch_utils/test_serialize.py b/tests/py/orbit/bunch_utils/test_serialize.py
new file mode 100644
index 00000000..f5af2a05
--- /dev/null
+++ b/tests/py/orbit/bunch_utils/test_serialize.py
@@ -0,0 +1,78 @@
+from orbit.core.bunch import Bunch
+from orbit.bunch_generators import GaussDist3D
+from orbit.bunch_utils import collect_bunch
+
+from pytest import fixture
+
+
+@fixture
+def bunch():
+    bunch = Bunch()
+    bunch.mass(0.939294)
+    bunch.charge(-1.0)
+    bunch.getSyncParticle().kinEnergy(0.0025)
+    gauss_dist = GaussDist3D()
+    for i in range(10):
+        x, xp, y, yp, z, dE = gauss_dist.getCoordinates()
+        bunch.addParticle(x, xp, y, yp, z, dE)
+    bunch.macroSize(10)
+    return bunch
+
+
+def test_collect_bunch(bunch):
+    d = collect_bunch(bunch, return_memmap=False)
+
+    n_particles = bunch.getSize()
+
+    toplevel_keys = {"coords", "sync_part", "attributes"}
+
+    attribute_keys = {"charge", "classical_radius", "mass", "macro_size"}
+    sync_part_keys = {"coords", "kin_energy", "momentum", "beta", "gamma", "time"}
+
+    x, xp, y, yp, z, dE = [], [], [], [], [], []
+    for i in range(n_particles):
+        x.append(bunch.x(i))
+        xp.append(bunch.px(i))
+        y.append(bunch.y(i))
+        yp.append(bunch.py(i))
+        z.append(bunch.z(i))
+        dE.append(bunch.dE(i))
+
+    assert set(d.keys()) == toplevel_keys
+    assert set(d["sync_part"].keys()) == sync_part_keys
+    assert set(d["attributes"].keys()) == attribute_keys
+
+    assert d["coords"].shape == (n_particles, 6)
+
+    assert d["attributes"]["charge"] == bunch.bunchAttrDouble("charge")
+    assert d["attributes"]["classical_radius"] == bunch.bunchAttrDouble(
+        "classical_radius"
+    )
+    assert d["attributes"]["mass"] == bunch.bunchAttrDouble("mass")
+    assert d["attributes"]["macro_size"] == bunch.bunchAttrDouble("macro_size")
+
+    sync_part = bunch.getSyncParticle()
+    assert (d["sync_part"]["coords"] == sync_part.pVector()).all()
+    assert d["sync_part"]["kin_energy"] == sync_part.kinEnergy()
+    assert d["sync_part"]["momentum"] == sync_part.momentum()
+    assert d["sync_part"]["beta"] == sync_part.beta()
+    assert d["sync_part"]["gamma"] == sync_part.gamma()
+    assert d["sync_part"]["time"] == sync_part.time()
+
+
+def test_collect_empty_bunch():
+    bunch = Bunch()
+    d = collect_bunch(bunch)
+    assert d is None
+
+
+def test_collect_arbitrary_bunch_attr(bunch):
bunch.bunchAttrDouble("arbitrary_dbl_attr", 42.0) + bunch.bunchAttrInt("arbitrary_int_attr", 42) + + d = collect_bunch(bunch) + + assert "arbitrary_dbl_attr" in d["attributes"] + assert d["attributes"]["arbitrary_dbl_attr"] == 42.0 + assert "arbitrary_int_attr" in d["attributes"] + assert d["attributes"]["arbitrary_int_attr"] == 42