From 941ee06ab17608fcb4a1603fb54dadcf6b3fb061 Mon Sep 17 00:00:00 2001 From: Ken Kroenlein Date: Tue, 2 Apr 2024 10:36:48 -0600 Subject: [PATCH 1/2] Add custom UnitRegistry to handle scaling factors more cleanly --- gemd/__version__.py | 2 +- gemd/units/impl.py | 125 +++++++++++++++++++++++++------------ requirements.txt | 1 + setup.py | 3 +- tests/units/test_parser.py | 17 ++++- 5 files changed, 103 insertions(+), 45 deletions(-) diff --git a/gemd/__version__.py b/gemd/__version__.py index 8c0d5d5b..9aa3f903 100644 --- a/gemd/__version__.py +++ b/gemd/__version__.py @@ -1 +1 @@ -__version__ = "2.0.0" +__version__ = "2.1.0" diff --git a/gemd/units/impl.py b/gemd/units/impl.py index 94a2c011..1cbdc691 100644 --- a/gemd/units/impl.py +++ b/gemd/units/impl.py @@ -1,4 +1,5 @@ """Implementation of units.""" +from deprecation import deprecated import functools from importlib.resources import read_text import os @@ -7,7 +8,7 @@ from tempfile import TemporaryDirectory from typing import Union, List, Tuple, Generator, Any -from pint import UnitRegistry, Unit, register_unit_format +from pint import UnitRegistry, register_unit_format try: # Pint 0.23 migrated the location of this method, and augmented it from pint.pint_eval import tokenizer except ImportError: # pragma: no cover @@ -131,6 +132,7 @@ def _scaling_identify_factors( """ todo = [] for block in blocks: + # Note: while Python does not recognize ^ as exponentiation, pint does i_exp = next((i for i, t in enumerate(block) if t.string in {"**", "^"}), len(block)) i_name = next((i for i, t in enumerate(block) if t.type == NAME), None) numbers = [(i, t.string) for i, t in enumerate(block) if t.type == NUMBER and i < i_exp] @@ -168,10 +170,14 @@ def _scaling_store_and_mangle(input_string: str, todo: List[Tuple[str, str, str] """ for scaled_term, number_string, unit_string in todo: regex = rf"(? str: return _scaling_store_and_mangle(input_string, todo) -_REGISTRY: UnitRegistry = None # global requires it be defined in this scope +def _unmangle_scaling(input_string: str) -> str: + """Convert mangled scaling values into a pint-compatible expression.""" + number_re = r'\b_(_)?(\d+)(_\d+)?([eE]_?\d+)?(_(?=[a-zA-Z]))?' + while match := re.search(number_re, input_string): + replacement = '' if match.group(1) is None else '-' + replacement += match.group(2) + replacement += '' if match.group(3) is None else match.group(3).replace('_', '.') + replacement += '' if match.group(4) is None else match.group(4).replace('_', '-') + replacement += '' if match.group(5) is None else match.group(5).replace('_', ' ') + input_string = input_string.replace(match.group(0), replacement) + return input_string + + +try: # pragma: no cover + # Pint 0.23 modified the preferred way to derive a custom class + # https://pint.readthedocs.io/en/0.23/advanced/custom-registry-class.html + from pint.registry import GenericUnitRegistry + from typing_extensions import TypeAlias + + class _ScaleFactorUnit(UnitRegistry.Unit): + """Child class of Units for generating units w/ clean scaling factors.""" + + def __format__(self, format_spec): + result = super().__format__(format_spec) + return _unmangle_scaling(result) + + class _ScaleFactorQuantity(UnitRegistry.Quantity): + """Child class of Quantity for generating units w/ clean scaling factors.""" + + pass + + class _ScaleFactorRegistry(GenericUnitRegistry[_ScaleFactorQuantity, _ScaleFactorUnit]): + """UnitRegistry class that uses _GemdUnits.""" + + Quantity: TypeAlias = _ScaleFactorQuantity + Unit: TypeAlias = _ScaleFactorUnit + +except ImportError: # pragma: no cover + # https://pint.readthedocs.io/en/0.21/advanced/custom-registry-class.html + class _ScaleFactorUnit(UnitRegistry.Unit): + """Child class of Units for generating units w/ clean scaling factors.""" + + def __format__(self, format_spec): + result = super().__format__(format_spec) + return _unmangle_scaling(result) + + class _ScaleFactorRegistry(UnitRegistry): + """UnitRegistry class that uses _GemdUnits.""" + + _unit_class = _ScaleFactorUnit + +_REGISTRY: _ScaleFactorRegistry = None # global requires it be defined in this scope @functools.lru_cache(maxsize=1024 * 1024) @@ -244,38 +301,23 @@ def convert_units(value: float, starting_unit: str, final_unit: str) -> float: @register_unit_format("clean") +@deprecated(deprecated_in="2.1.0", removed_in="3.0.0", details="Scaling factor clean-up ") def _format_clean(unit, registry, **options): - """Formatter that turns scaling-factor-units into numbers again.""" - numerator = [] - denominator = [] - for u, p in unit.items(): - if re.match(r"_[\d_]+", u): - # Munged scaling factor; grab symbol, which is the prettier - u = registry.get_symbol(u) - - if p == 1: - numerator.append(u) - elif p > 0: - numerator.append(f"{u} ** {p}") - elif p == -1: - denominator.append(u) - elif p < 0: - denominator.append(f"{u} ** {-p}") - - if len(numerator) == 0: - numerator = ["1"] - - if len(denominator) > 0: - return " / ".join((" * ".join(numerator), " / ".join(denominator))) - else: - return " * ".join(numerator) + """ + DEPRECATED Formatter that turns scaling-factor-units into numbers again. + + Responsibility for this piece of clean-up has been shifted to a custom class. + + """ + from pint.formatting import _FORMATTERS + return _FORMATTERS["D"](unit, registry, **options) @functools.lru_cache(maxsize=1024) -def parse_units(units: Union[str, Unit, None], +def parse_units(units: Union[str, _ScaleFactorUnit, None], *, return_unit: bool = False - ) -> Union[str, Unit, None]: + ) -> Union[str, _ScaleFactorUnit, None]: """ Parse a string or Unit into a standard string representation of the unit. @@ -298,19 +340,20 @@ def parse_units(units: Union[str, Unit, None], else: return None elif isinstance(units, str): - parsed = _REGISTRY.parse_units(units) + # SPT-1311 Protect against leaked mangled strings + parsed = _REGISTRY.parse_units(_unmangle_scaling(units)) if return_unit: return parsed else: - return f"{parsed:clean}" - elif isinstance(units, Unit): + return f"{parsed}" + elif isinstance(units, _ScaleFactorUnit): return units else: raise UndefinedUnitError("Units must be given as a recognized unit string or Units object") @functools.lru_cache(maxsize=1024) -def get_base_units(units: Union[str, Unit]) -> Tuple[Unit, float, float]: +def get_base_units(units: Union[str, _ScaleFactorUnit]) -> Tuple[_ScaleFactorUnit, float, float]: """ Get the base units and conversion factors for the given unit. @@ -358,13 +401,13 @@ def change_definitions_file(filename: str = None): path = Path(target) os.chdir(path.parent) # Need to re-verify path because of some slippiness around tmp on MacOS - _REGISTRY = UnitRegistry(filename=Path.cwd() / path.name, - preprocessors=[_space_after_minus_preprocessor, - _scientific_notation_preprocessor, - _scaling_preprocessor - ], - autoconvert_offset_to_baseunit=True - ) + _REGISTRY = _ScaleFactorRegistry(filename=Path.cwd() / path.name, + preprocessors=[_space_after_minus_preprocessor, + _scientific_notation_preprocessor, + _scaling_preprocessor + ], + autoconvert_offset_to_baseunit=True + ) finally: os.chdir(current_dir) diff --git a/requirements.txt b/requirements.txt index 5c53781c..5bf007ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ pint==0.20 deprecation==2.1.0 +typing-extensions==4.8.0 diff --git a/setup.py b/setup.py index 4954131f..5fd4a733 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,8 @@ }, install_requires=[ "pint>=0.20,<0.24", - "deprecation>=2.1.0,<3" + "deprecation>=2.1.0,<3", + "typing_extensions>=4.8,<5" ], extras_require={ "tests": [ diff --git a/tests/units/test_parser.py b/tests/units/test_parser.py index 88c7e33a..85696b2e 100644 --- a/tests/units/test_parser.py +++ b/tests/units/test_parser.py @@ -1,4 +1,5 @@ from contextlib import contextmanager +from deprecation import DeprecatedWarning from importlib.resources import read_binary import re from pint import UnitRegistry @@ -30,7 +31,10 @@ def test_parse_expected(return_unit): "g / -+-25e-1 m", # Weird but fine "ug / - -250 mL", # Spaces between unaries is acceptable "1 / 10**5 degC", # Spaces between unaries is acceptable - "m ** - 1" # Pint < 0.21 throws DefinitionSyntaxError + "1 / 10_000 degC", # Spaces between unaries is acceptable + "m ** - 1", # Pint < 0.21 throws DefinitionSyntaxError + "gram / _10_minute", # Stringified Unit object SPT-1311 + "gram / __1_2e_3minute", # Stringified Unit object SPT-1311 ] for unit in expected: parsed = parse_units(unit, return_unit=return_unit) @@ -205,8 +209,17 @@ def test_exponents(): def test__scientific_notation_preprocessor(): """Verify that numbers are converted into scientific notation.""" - assert "1e2 kg" in parse_units("F* 10 ** 2 kg") + assert "1e2 kilogram" in parse_units("F* 10 ** 2 kg") + assert "1e2 kg" in f'{parse_units("F* 10 ** 2 kg", return_unit=True):~}' assert "1e-5" in parse_units("F* mm*10**-5") assert "1e" not in parse_units("F* kg * 10 cm") assert "-3.07e2" in parse_units("F* -3.07 * 10 ** 2") assert "11e2" in parse_units("F* 11*10^2") + + +def test_deprecation(): + """Make sure deprecated things warn correctly.""" + megapascals = parse_units("MPa", return_unit=True) + with pytest.warns(DeprecatedWarning): + stringified = f"{megapascals:clean}" + assert megapascals == parse_units(stringified, return_unit=True) From 4c38728b67255172a424ed0aa19e1d3376a1bf1f Mon Sep 17 00:00:00 2001 From: Ken Kroenlein Date: Tue, 2 Apr 2024 11:28:59 -0600 Subject: [PATCH 2/2] Maintain support for base pint Unit objects --- gemd/units/impl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gemd/units/impl.py b/gemd/units/impl.py index 1cbdc691..2b0830a3 100644 --- a/gemd/units/impl.py +++ b/gemd/units/impl.py @@ -314,10 +314,10 @@ def _format_clean(unit, registry, **options): @functools.lru_cache(maxsize=1024) -def parse_units(units: Union[str, _ScaleFactorUnit, None], +def parse_units(units: Union[str, UnitRegistry.Unit, None], *, return_unit: bool = False - ) -> Union[str, _ScaleFactorUnit, None]: + ) -> Union[str, UnitRegistry.Unit, None]: """ Parse a string or Unit into a standard string representation of the unit. @@ -346,14 +346,14 @@ def parse_units(units: Union[str, _ScaleFactorUnit, None], return parsed else: return f"{parsed}" - elif isinstance(units, _ScaleFactorUnit): + elif isinstance(units, UnitRegistry.Unit): return units else: raise UndefinedUnitError("Units must be given as a recognized unit string or Units object") @functools.lru_cache(maxsize=1024) -def get_base_units(units: Union[str, _ScaleFactorUnit]) -> Tuple[_ScaleFactorUnit, float, float]: +def get_base_units(units: Union[str, UnitRegistry.Unit]) -> Tuple[UnitRegistry.Unit, float, float]: """ Get the base units and conversion factors for the given unit.