diff --git a/pyaml/configuration/fileloader.py b/pyaml/configuration/fileloader.py index 085af3cd..6df60c1d 100644 --- a/pyaml/configuration/fileloader.py +++ b/pyaml/configuration/fileloader.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) accepted_suffixes = [".yaml", ".yml", ".json"] - +FILE_PREFIX = "file:" ROOT = {"path": Path.cwd().resolve()} @@ -82,7 +82,7 @@ def load( # Expand condition -def hasToExpand(value): +def hasToLoad(value): return isinstance(value, str) and any( value.endswith(suffix) for suffix in accepted_suffixes ) @@ -105,8 +105,13 @@ def __init__(self, filename: str, parent_path_stack: list[Path]): def expand_dict(self, d: dict): for key, value in d.items(): try: - if hasToExpand(value): - d[key] = load(value, self.files_stack, self.use_fast_loader) + if hasToLoad(value): + if value.startswith(FILE_PREFIX): + # remove prefix + stripped_value = value[len(FILE_PREFIX) :] + d[key] = str(get_root_folder() / Path(stripped_value)) + else: + d[key] = load(value, self.files_stack, self.use_fast_loader) else: self.expand(value) except PyAMLConfigCyclingException as pyaml_ex: @@ -127,7 +132,7 @@ def expand_dict(self, d: dict): # Recursively expand a list def expand_list(self, l: list): for idx, value in enumerate(l): - if hasToExpand(value): + if hasToLoad(value): l[idx] = load(value, self.files_stack) else: self.expand(value) diff --git a/pyaml/tuning_tools/orbit.py b/pyaml/tuning_tools/orbit.py index 20ea94dd..8c8f8785 100644 --- a/pyaml/tuning_tools/orbit.py +++ b/pyaml/tuning_tools/orbit.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Literal, Optional, Union try: from typing import Self # Python 3.11+ @@ -8,7 +8,7 @@ from typing_extensions import Self # Python 3.10 and earlier -from pydantic import BaseModel, ConfigDict +from pydantic import ConfigDict if TYPE_CHECKING: from ..common.element_holder import ElementHolder @@ -16,6 +16,8 @@ from ..arrays.magnet_array import MagnetArray from ..common.element import Element, ElementConfigModel from ..common.exception import PyAMLException +from ..configuration.factory import Factory +from ..configuration.fileloader import get_path, load from ..external.pySC.pySC import ResponseMatrix as pySC_ResponseMatrix from ..external.pySC.pySC.apps import orbit_correction from ..external.pySC_interface import pySCInterface @@ -34,21 +36,31 @@ class ConfigModel(ElementConfigModel): hcorr_array_name: str vcorr_array_name: str singular_values: int - response_matrix: Optional[ResponseMatrix] + response_matrix: Union[str, ResponseMatrix] class Orbit(Element): def __init__(self, cfg: ConfigModel): super().__init__(cfg.name) self._cfg = cfg - self.bpm_array_name = cfg.bpm_array_name self.hcorr_array_name = cfg.hcorr_array_name self.vcorr_array_name = cfg.vcorr_array_name self.singular_values = cfg.singular_values - self.response_matrix = pySC_ResponseMatrix.model_validate( - cfg.response_matrix._cfg.model_dump() - ) + + if type(cfg.response_matrix) is str: + response_matrix_filename = cfg.response_matrix + # assigns self.response_matrix + if Path(response_matrix_filename).exists(): + self.load_response_matrix(response_matrix_filename) + else: + logger.warning(f"{response_matrix_filename} does not exist.") + self.response_matrix = None + else: + self.response_matrix = pySC_ResponseMatrix.model_validate( + cfg.response_matrix._cfg.model_dump() + ) + self._hcorr: MagnetArray = None self._vcorr: MagnetArray = None self._hvcorr: MagnetArray = None @@ -59,6 +71,9 @@ def correct( gain: float = 1.0, plane: Optional[Literal["H", "V"]] = None, ): + if self.response_matrix is None: + raise PyAMLException(f"{self.get_name()} does not have a response_matrix.") + interface = pySCInterface( element_holder=self._peer, bpm_array_name=self.bpm_array_name, @@ -130,3 +145,10 @@ def attach(self, peer: "ElementHolder") -> Self: obj = self.__class__(self._cfg) obj._peer = peer return obj + + def load_response_matrix(self, filename: str) -> None: + path = Path(filename) + config_dict = load(str(path.resolve())) + rm = Factory.depth_first_build(config_dict, ignore_external=False) + self.response_matrix = pySC_ResponseMatrix.model_validate(rm._cfg.model_dump()) + return None diff --git a/tests/config/EBSOrbit.yaml b/tests/config/EBSOrbit.yaml index 10d95ba6..2699a112 100644 --- a/tests/config/EBSOrbit.yaml +++ b/tests/config/EBSOrbit.yaml @@ -934,7 +934,7 @@ devices: vcorr_array_name: VCorr name: DEFAULT_ORBIT_CORRECTION singular_values: 162 - response_matrix: ideal_orm_disp.json + response_matrix: file:ideal_orm_disp.json - type: pyaml.rf.rf_plant name: RF masterclock: