Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions pyaml/configuration/fileloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
logger = logging.getLogger(__name__)

accepted_suffixes = [".yaml", ".yml", ".json"]

FILE_PREFIX = "file:"

ROOT = {"path": Path.cwd().resolve()}

Expand Down Expand Up @@ -82,7 +82,7 @@ def load(


# Expand condition
def hasToExpand(value):
def hasToLoad(value):
return isinstance(value, str) and any(
value.endswith(suffix) for suffix in accepted_suffixes
)
Expand All @@ -105,8 +105,13 @@ def __init__(self, filename: str, parent_path_stack: list[Path]):
def expand_dict(self, d: dict):
for key, value in d.items():
try:
if hasToExpand(value):
d[key] = load(value, self.files_stack, self.use_fast_loader)
if hasToLoad(value):
if value.startswith(FILE_PREFIX):
# remove prefix
stripped_value = value[len(FILE_PREFIX) :]
d[key] = str(get_root_folder() / Path(stripped_value))
else:
d[key] = load(value, self.files_stack, self.use_fast_loader)
else:
self.expand(value)
except PyAMLConfigCyclingException as pyaml_ex:
Expand All @@ -127,7 +132,7 @@ def expand_dict(self, d: dict):
# Recursively expand a list
def expand_list(self, l: list):
for idx, value in enumerate(l):
if hasToExpand(value):
if hasToLoad(value):
l[idx] = load(value, self.files_stack)
else:
self.expand(value)
Expand Down
36 changes: 29 additions & 7 deletions pyaml/tuning_tools/orbit.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Literal, Optional, Union

try:
from typing import Self # Python 3.11+
except ImportError:
from typing_extensions import Self # Python 3.10 and earlier


from pydantic import BaseModel, ConfigDict
from pydantic import ConfigDict

if TYPE_CHECKING:
from ..common.element_holder import ElementHolder
# from ..external.pySC.pySC import ResponseMatrix as pySC_ResponseMatrix
from ..arrays.magnet_array import MagnetArray
from ..common.element import Element, ElementConfigModel
from ..common.exception import PyAMLException
from ..configuration.factory import Factory
from ..configuration.fileloader import get_path, load
from ..external.pySC.pySC import ResponseMatrix as pySC_ResponseMatrix
from ..external.pySC.pySC.apps import orbit_correction
from ..external.pySC_interface import pySCInterface
Expand All @@ -34,21 +36,31 @@ class ConfigModel(ElementConfigModel):
hcorr_array_name: str
vcorr_array_name: str
singular_values: int
response_matrix: Optional[ResponseMatrix]
response_matrix: Union[str, ResponseMatrix]


class Orbit(Element):
def __init__(self, cfg: ConfigModel):
super().__init__(cfg.name)
self._cfg = cfg

self.bpm_array_name = cfg.bpm_array_name
self.hcorr_array_name = cfg.hcorr_array_name
self.vcorr_array_name = cfg.vcorr_array_name
self.singular_values = cfg.singular_values
self.response_matrix = pySC_ResponseMatrix.model_validate(
cfg.response_matrix._cfg.model_dump()
)

if type(cfg.response_matrix) is str:
response_matrix_filename = cfg.response_matrix
# assigns self.response_matrix
if Path(response_matrix_filename).exists():
self.load_response_matrix(response_matrix_filename)
else:
logger.warning(f"{response_matrix_filename} does not exist.")
self.response_matrix = None
else:
self.response_matrix = pySC_ResponseMatrix.model_validate(
cfg.response_matrix._cfg.model_dump()
)

self._hcorr: MagnetArray = None
self._vcorr: MagnetArray = None
self._hvcorr: MagnetArray = None
Expand All @@ -59,6 +71,9 @@ def correct(
gain: float = 1.0,
plane: Optional[Literal["H", "V"]] = None,
):
if self.response_matrix is None:
raise PyAMLException(f"{self.get_name()} does not have a response_matrix.")

interface = pySCInterface(
element_holder=self._peer,
bpm_array_name=self.bpm_array_name,
Expand Down Expand Up @@ -130,3 +145,10 @@ def attach(self, peer: "ElementHolder") -> Self:
obj = self.__class__(self._cfg)
obj._peer = peer
return obj

def load_response_matrix(self, filename: str) -> None:
path = Path(filename)
config_dict = load(str(path.resolve()))
rm = Factory.depth_first_build(config_dict, ignore_external=False)
self.response_matrix = pySC_ResponseMatrix.model_validate(rm._cfg.model_dump())
return None
2 changes: 1 addition & 1 deletion tests/config/EBSOrbit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,7 @@ devices:
vcorr_array_name: VCorr
name: DEFAULT_ORBIT_CORRECTION
singular_values: 162
response_matrix: ideal_orm_disp.json
response_matrix: file:ideal_orm_disp.json
- type: pyaml.rf.rf_plant
name: RF
masterclock:
Expand Down
Loading