diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index 2a3adf3859..03cb440bca 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -39,7 +39,6 @@ requirements: - pineappl >=0.7.3 - eko >=0.14.2 - fiatlux - - frozendict # needed for caching of data loading - sphinx >=5.0.2,<6 # documentation. Needs pinning temporarily due to markdown - recommonmark - sphinx_rtd_theme >0.5 diff --git a/pyproject.toml b/pyproject.toml index 083545744f..970d3d5b1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,6 @@ pandas = "*" numpy = "*" validobj = "*" prompt_toolkit = "*" -frozendict = "*" # validphys: needed for caching of data loading # Reportengine needs to be installed from git reportengine = { git = "https://github.com/NNPDF/reportengine" } # Fit diff --git a/validphys2/src/validphys/config.py b/validphys2/src/validphys/config.py index c746995017..669e95553f 100644 --- a/validphys2/src/validphys/config.py +++ b/validphys2/src/validphys/config.py @@ -8,7 +8,6 @@ import numbers import pathlib -from frozendict import frozendict import pandas as pd from nnpdf_data import legacy_to_new_map @@ -41,7 +40,15 @@ from validphys.paramfits.config import ParamfitsConfig from validphys.plotoptions.core import get_info import validphys.scalevariations -from validphys.utils import freeze_args +from validphys.filters import (FilterDefaults, + AddedFilterRule, + FilterRule, + default_filter_settings_input, + Rule, + RuleProcessingError, + default_filter_rules_input + ) + log = logging.getLogger(__name__) @@ -1258,10 +1265,10 @@ def load_default_default_filter_rules(self, spec): ) def parse_filter_rules(self, filter_rules: (list, type(None))): - """A list of filter rules. See https://docs.nnpdf.science/vp/filters.html - for details on the syntax""" + """A tuple of FilterRule objects. Rules are immutable after parsing. + See https://docs.nnpdf.science/vp/filters.html for details on the syntax""" log.warning("Overwriting filter rules") - return filter_rules + return tuple(FilterRule(**rule) for rule in filter_rules) if filter_rules else None def parse_default_filter_rules_recorded_spec_(self, spec): """This function is a hacky fix for parsing the recorded spec @@ -1271,12 +1278,12 @@ def parse_default_filter_rules_recorded_spec_(self, spec): return spec def parse_added_filter_rules(self, rules: (list, type(None)) = None): - return rules + """ + Returns a tuple of AddedFilterRule objects. Rules are immutable after parsing. + AddedFilterRule objects inherit from FilterRule objects. + """ + return tuple(AddedFilterRule(**rule) for rule in rules) if rules else None - # Every parallel replica triggers a series of calls to this function, - # which should not happen since the rules are identical among replicas. - # E.g for NNPDF4.0 with 2 parallel replicas 693 calls, 3 parallel replicas 1001 calls... - @freeze_args @functools.lru_cache def produce_rules( self, @@ -1286,10 +1293,9 @@ def produce_rules( default_filter_rules=None, filter_rules=None, default_filter_rules_recorded_spec_=None, - added_filter_rules: (list, type(None)) = None, + added_filter_rules: (tuple, type(None)) = None, ): """Produce filter rules based on the user defined input and defaults.""" - from validphys.filters import Rule, RuleProcessingError, default_filter_rules_input theory_parameters = theoryid.get_description() @@ -1317,8 +1323,7 @@ def produce_rules( if added_filter_rules: for i, rule in enumerate(added_filter_rules): - if not isinstance(rule, (dict, frozendict)): - raise ConfigError(f"added rule {i} is not a dict") + try: rule_list.append( Rule( @@ -1331,7 +1336,7 @@ def produce_rules( except RuleProcessingError as e: raise ConfigError(f"Error processing added rule {i}: {e}") from e - return rule_list + return tuple(rule_list) @configparser.record_from_defaults def parse_default_filter_settings(self, spec: (str, type(None))): @@ -1359,10 +1364,25 @@ def load_default_default_filter_settings(self, spec): def parse_filter_defaults(self, filter_defaults: (dict, type(None))): """A mapping containing the default kinematic limits to be used when filtering data (when using internal cuts). - Currently these limits are ``q2min`` and ``w2min``. + Currently these limits are ``q2min``, ``w2min``, and ``maxTau``. + + Parameters + ---------- + filter_defaults: dict, None + A mapping containing the default kinematic limits to be used when + filtering data (when using internal cuts). + Currently these limits are ``q2min``, ``w2min``, and ``maxTau``. + + Returns + ------- + FilterDefaults + A hashable object containing the default kinematic limits to be used when + filtering data (when using internal cuts). + Currently these limits are ``q2min``, ``w2min``, and ``maxTau``. """ log.warning("Overwriting filter defaults") - return filter_defaults + parsed_filter_defaults = FilterDefaults(**filter_defaults) + return parsed_filter_defaults def produce_defaults( self, @@ -1370,34 +1390,40 @@ def produce_defaults( w2min=None, maxTau=None, default_filter_settings=None, - filter_defaults={}, + filter_defaults=None, default_filter_settings_recorded_spec_=None, ): """Produce default values for filters taking into account the values of ``q2min``, ``w2min`` and ``maxTau`` defined at namespace level and those inside a ``filter_defaults`` mapping. + + Within this function the hashable type FilterDefaults is turned into + a dictionary so as to allow for overwriting of the values of q2min, w2min and maxTau. + The dictionary is then turned back into a FilterDefaults object. """ - from validphys.filters import default_filter_settings_input + if filter_defaults is None: + filter_defaults = {} + + if isinstance(filter_defaults, FilterDefaults): + filter_defaults = filter_defaults.to_dict() if q2min is not None and "q2min" in filter_defaults and q2min != filter_defaults["q2min"]: raise ConfigError("q2min defined multiple times with different values") + if w2min is not None and "w2min" in filter_defaults and w2min != filter_defaults["w2min"]: raise ConfigError("w2min defined multiple times with different values") - if ( - maxTau is not None - and "maxTau" in filter_defaults - and maxTau != filter_defaults["maxTau"] - ): + if maxTau is not None and filter_defaults.get("maxTau", maxTau) != maxTau: raise ConfigError("maxTau defined multiple times with different values") if default_filter_settings_recorded_spec_ is not None: - filter_defaults = default_filter_settings_recorded_spec_[default_filter_settings] + filter_defaults = FilterDefaults(**default_filter_settings_recorded_spec_[default_filter_settings]) # If we find recorded specs return immediately and don't read q2min and w2min # from runcard return filter_defaults elif not filter_defaults: - filter_defaults = default_filter_settings_input() + # if filter_defaults have not been set, load the defaults with default_filter_settings_input + filter_defaults = default_filter_settings_input().to_dict() defaults_loaded = True else: defaults_loaded = False @@ -1413,7 +1439,9 @@ def produce_defaults( if maxTau is not None and defaults_loaded: log.warning("Using maxTau from runcard") filter_defaults["maxTau"] = maxTau - + + # Turn the dictionary back into a hashable FilterDefaults object + filter_defaults = FilterDefaults(**filter_defaults) return filter_defaults def produce_data(self, data_input, *, group_name="data"): diff --git a/validphys2/src/validphys/filters.py b/validphys2/src/validphys/filters.py index cded762eab..e4b937c894 100644 --- a/validphys2/src/validphys/filters.py +++ b/validphys2/src/validphys/filters.py @@ -7,6 +7,8 @@ from importlib.resources import read_text import logging import re +import dataclasses +from typing import Union import numpy as np @@ -14,7 +16,7 @@ from reportengine.compat import yaml import validphys.cuts from validphys.process_options import PROCESSES -from validphys.utils import freeze_args, generate_path_filtered_data +from validphys.utils import generate_path_filtered_data log = logging.getLogger(__name__) @@ -99,18 +101,63 @@ class FatalRuleError(Exception): """Exception raised when a rule application failed at runtime.""" +@dataclasses.dataclass(frozen=True) +class FilterDefaults: + """ + Dataclass carrying default values for filters (cuts) taking into + account the values of ``q2min``, ``w2min`` and ``maxTau``. + """ + q2min: float = None + w2min: float = None + maxTau: float = None + + def to_dict(self): + return dataclasses.asdict(self) + + +@dataclasses.dataclass(frozen=True) +class FilterRule: + """ + Dataclass which carries the filter rule information. + """ + dataset: str = None + process_type: str = None + rule: str = None + reason: str = None + local_variables: Mapping[str, Union[str, float]] = None + PTO: str = None + FNS: str = None + IC: str = None + + def to_dict(self): + rule_dict = dataclasses.asdict(self) + filtered_dict = {k: v for k, v in rule_dict.items() if v is not None} + return filtered_dict + + +@dataclasses.dataclass(frozen=True) +class AddedFilterRule(FilterRule): + """ + Dataclass which carries extra filter rule that is added to the + default rule. + """ + pass + + def default_filter_settings_input(): - """Return a dictionary with the default hardcoded filter settings. + """Return a FilterDefaults dataclass with the default hardcoded filter settings. These are defined in ``defaults.yaml`` in the ``validphys.cuts`` module. """ - return yaml.safe_load(read_text(validphys.cuts, "defaults.yaml")) + return FilterDefaults(**yaml.safe_load(read_text(validphys.cuts, "defaults.yaml"))) def default_filter_rules_input(): - """Return a dictionary with the input settings. + """ + Return a tuple of FilterRule objects. These are defined in ``filters.yaml`` in the ``validphys.cuts`` module. """ - return yaml.safe_load(read_text(validphys.cuts, "filters.yaml")) + list_rules = yaml.safe_load(read_text(validphys.cuts, "filters.yaml")) + return tuple(FilterRule(**rule) for rule in list_rules) def check_nonnegative(var: str): @@ -457,10 +504,19 @@ class Rule: numpy_functions = {"sqrt": np.sqrt, "log": np.log, "fabs": np.fabs} - def __init__(self, initial_data: dict, *, defaults: dict, theory_parameters: dict, loader=None): + def __init__(self, initial_data: FilterRule, *, defaults: dict, theory_parameters: dict, loader=None): self.dataset = None self.process_type = None self._local_variables_code = {} + + # For compatibility with legacy code that passed a dictionary + if isinstance(initial_data, FilterRule): + initial_data = initial_data.to_dict() + elif isinstance(initial_data, Mapping): + initial_data = dict(initial_data) + else: + raise RuleProcessingError("Expecting initial_data to be an instance of a FilterRule dataclass.") + for key in initial_data: setattr(self, key, initial_data[key]) @@ -511,7 +567,7 @@ def __init__(self, initial_data: dict, *, defaults: dict, theory_parameters: dic self.rule_string = self.rule self.defaults = defaults self.theory_params = theory_parameters - ns = {*self.numpy_functions, *self.defaults, *self.variables, "idat", "central_value"} + ns = {*self.numpy_functions, *self.defaults.to_dict().keys(), *self.variables, "idat", "central_value"} for k, v in self.local_variables.items(): try: self._local_variables_code[k] = lcode = compile( @@ -592,7 +648,7 @@ def __call__(self, dataset, idat): return eval( self.rule, self.numpy_functions, - {**{"idat": idat, "central_value": central_value}, **self.defaults, **ns}, + {**{"idat": idat, "central_value": central_value}, **self.defaults.to_dict(), **ns}, ) except Exception as e: # pragma: no cover raise FatalRuleError(f"Error when applying rule {self.rule_string!r}: {e}") from e @@ -628,7 +684,6 @@ def _make_point_namespace(self, dataset, idat) -> dict: return ns -@freeze_args @functools.lru_cache def get_cuts_for_dataset(commondata, rules) -> list: """Function to generate a list containing the index diff --git a/validphys2/src/validphys/loader.py b/validphys2/src/validphys/loader.py index f4ec530108..88017eed68 100644 --- a/validphys2/src/validphys/loader.py +++ b/validphys2/src/validphys/loader.py @@ -675,10 +675,10 @@ def check_default_filter_rules(self, theoryid, defaults=None): th_params = theoryid.get_description() if defaults is None: defaults = default_filter_settings_input() - return [ + return tuple( Rule(inp, defaults=defaults, theory_parameters=th_params, loader=self) for inp in default_filter_rules_input() - ] + ) def _check_theory_old_or_new(self, theoryid, commondata, cfac): """Given a theory and a commondata and a theory load the right fktable diff --git a/validphys2/src/validphys/tests/test_filter_rules.py b/validphys2/src/validphys/tests/test_filter_rules.py index 05ee29827d..09055fa821 100644 --- a/validphys2/src/validphys/tests/test_filter_rules.py +++ b/validphys2/src/validphys/tests/test_filter_rules.py @@ -103,7 +103,7 @@ def test_good_rules(): dsnames = ['ATLAS_1JET_8TEV_R06_PTY', 'NMC_NC_NOTFIXED_DW_EM-F2'] for dsname in dsnames: ds = l.check_dataset( - dsname, cuts='internal', rules=rules, theoryid=THEORYID, variant="legacy" + dsname, cuts='internal', rules=tuple(rules), theoryid=THEORYID, variant="legacy" ) assert ds.cuts.load() is not None @@ -120,7 +120,7 @@ def test_added_rules(): { "speclabel": "fewer data", "added_filter_rules": [ - {"dataset": "ATLAS_1JET_8TEV_R06_PTY", "rule": "pT < 1000", "reson": "pt cut"} + {"dataset": "ATLAS_1JET_8TEV_R06_PTY", "rule": "pT < 1000", "reason": "pt cut"} ], }, { diff --git a/validphys2/src/validphys/utils.py b/validphys2/src/validphys/utils.py index 64e0e6d5de..1c77fd5b5d 100644 --- a/validphys2/src/validphys/utils.py +++ b/validphys2/src/validphys/utils.py @@ -1,42 +1,11 @@ import contextlib -import functools import pathlib import shutil import tempfile -from typing import Any, Hashable, Mapping, Sequence -from frozendict import frozendict import numpy as np -def make_hashable(obj: Any): - # So that we don't infinitely recurse since frozenset and tuples - # are Sequences. - if isinstance(obj, Hashable): - return obj - elif isinstance(obj, Mapping): - return frozendict(obj) - elif isinstance(obj, Sequence): - return tuple([make_hashable(i) for i in obj]) - else: - raise ValueError("Object is not hashable") - - -def freeze_args(func): - """Transform mutable dictionary - Into immutable - Useful to be compatible with cache - """ - - @functools.wraps(func) - def wrapped(*args, **kwargs): - args = tuple([make_hashable(arg) for arg in args]) - kwargs = {k: make_hashable(v) for k, v in kwargs.items()} - return func(*args, **kwargs) - - return wrapped - - def generate_path_filtered_data(fit_path, setname): """Utility to ensure that both the loader and tools like setupfit utilize the same convention to generate the names of generated pseudodata"""