From 28b602d9a05b04d17a06dc326f2e27e6d9b89d2d Mon Sep 17 00:00:00 2001 From: domna Date: Fri, 27 Jan 2023 20:46:05 +0100 Subject: [PATCH] Changes to additive dispersion instead of Unsummable --- src/elli/dispersions/base_dispersion.py | 86 +++++++++---------- src/elli/dispersions/cauchy.py | 6 +- src/elli/dispersions/cauchy_custom.py | 6 +- src/elli/dispersions/cody_lorentz.py | 4 +- .../dispersions/constant_refractive_index.py | 4 +- src/elli/dispersions/drude_energy.py | 4 +- src/elli/dispersions/drude_resistivity.py | 4 +- src/elli/dispersions/epsilon_inf.py | 4 +- src/elli/dispersions/gaussian.py | 4 +- src/elli/dispersions/lorentz_energy.py | 4 +- src/elli/dispersions/lorentz_lambda.py | 4 +- src/elli/dispersions/poles.py | 4 +- src/elli/dispersions/polynomial.py | 4 +- src/elli/dispersions/sellmeier.py | 4 +- src/elli/dispersions/sellmeier_custom.py | 4 +- src/elli/dispersions/tanguy.py | 4 +- src/elli/dispersions/tauc_lorentz.py | 4 +- tests/test_unsummable_dispersion.py | 8 +- 18 files changed, 79 insertions(+), 83 deletions(-) diff --git a/src/elli/dispersions/base_dispersion.py b/src/elli/dispersions/base_dispersion.py index bc9c8191..c631fe67 100644 --- a/src/elli/dispersions/base_dispersion.py +++ b/src/elli/dispersions/base_dispersion.py @@ -73,6 +73,21 @@ def __init__(self, *args, **kwargs): self.single_params_template, *args, **kwargs ) + summation_error_message = "" + + def __radd__(self, other: Union[int, float, "Dispersion"]) -> "Dispersion": + return self.__add__(other) + + def __add__( + self, other: Union[int, float, "AdditiveDispersion"] + ) -> "AdditiveDispersion": + raise TypeError( + f"unsupported operand type(s) for +: {type(self)} and {type(other)}." + f"\n{self.summation_error_message}" + if self.summation_error_message + else "" + ) + @abstractmethod def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray: """Calculates the dielectric function in a given wavelength window. @@ -101,26 +116,6 @@ def add(self, *args, **kwargs) -> "Dispersion": return self - def __radd__(self, other: Union[int, float, "Dispersion"]) -> "Dispersion": - """Add up the dielectric function of multiple models""" - return self.__add__(other) - - def __add__(self, other: Union[int, float, "Dispersion"]) -> "Dispersion": - """Add up the dielectric function of multiple models""" - if isinstance(other, UnsummableDispersion): - other.__add__(self) - - if isinstance(other, DispersionSum): - return other.__add__(self) - - if isinstance(other, (int, float)): - return DispersionSum(self, dispersions.EpsilonInf(eps=other)) - - if not isinstance(other, Dispersion): - raise TypeError(f"Invalid type {type(other)} added to dispersion") - - return DispersionSum(self, other) - def get_dielectric(self, lbda: npt.ArrayLike) -> npt.NDArray: """Returns the dielectric constant for wavelength 'lbda' default unit (nm) in the convention ε1 + iε2.""" @@ -205,17 +200,31 @@ def _dict_to_str(dic): ) -class UnsummableDispersion(Dispersion): - """This denotes a dispersion which is not summable""" +class AdditiveDispersion(Dispersion): + """An additive dispersion""" - @property - @abstractmethod - def summation_error_message(self): - """The message being displayed when someone tries - to perform an addition with this dispersion.""" + def __add__(self, other: Union[int, float, "Dispersion"]) -> "DispersionSum": + """Add up the dielectric function of multiple models""" + if not isinstance(other, Dispersion) and not isinstance(other, (int, float)): + raise TypeError( + f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" + ) - def __add__(self, _: Union[int, float, "Dispersion"]) -> "Dispersion": - raise ValueError(self.summation_error_message) + if isinstance(self, DispersionSum) and isinstance(other, DispersionSum): + self.dispersions += other.dispersions # pylint: disable=no-member + return self + + if isinstance(self, AdditiveDispersion) and isinstance(other, DispersionSum): + other.dispersions.append(self) + return other + + if isinstance(other, (int, float)): + return DispersionSum(self, dispersions.EpsilonInf(eps=other)) + + if not isinstance(other, AdditiveDispersion): + other.__add__(self) + + return DispersionSum(self, other) class DispersionFactory: @@ -240,31 +249,16 @@ def get_dispersion(identifier: str, *args, **kwargs) -> Dispersion: raise ValueError(f"No such dispersion: {identifier}") -class DispersionSum(Dispersion): +class DispersionSum(AdditiveDispersion): """Represents a sum of two dispersions""" single_params_template = {} rep_params_template = {} - def __init__(self, *disps: Dispersion) -> None: + def __init__(self, *disps: AdditiveDispersion) -> None: super().__init__() self.dispersions = disps - def __add__(self, other: Union[int, float, "Dispersion"]) -> "Dispersion": - if isinstance(other, UnsummableDispersion): - other.__add__(self) - - if isinstance(other, DispersionSum): - self.dispersions += other.dispersions - return self - - if isinstance(other, (int, float)): - self.dispersions.append(dispersions.EpsilonInf(eps=other)) - return self - - self.dispersions.append(other) - return self - def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray: dielectric_function = sum( disp.dielectric_function(lbda) for disp in self.dispersions diff --git a/src/elli/dispersions/cauchy.py b/src/elli/dispersions/cauchy.py index 0438cc7c..f73e5c88 100644 --- a/src/elli/dispersions/cauchy.py +++ b/src/elli/dispersions/cauchy.py @@ -2,10 +2,10 @@ """Cauchy dispersion.""" import numpy.typing as npt -from .base_dispersion import UnsummableDispersion +from .base_dispersion import Dispersion -class Cauchy(UnsummableDispersion): +class Cauchy(Dispersion): r"""Cauchy dispersion. Single parameters: @@ -28,7 +28,7 @@ class Cauchy(UnsummableDispersion): """ summation_error_message = ( - "The cauchy dispersion cannot be added to other dispersions. " + "The Cauchy dispersion cannot be added to other dispersions. " "Try the Poles or Lorentz model instead." ) diff --git a/src/elli/dispersions/cauchy_custom.py b/src/elli/dispersions/cauchy_custom.py index 4dc68672..e9bc6e0f 100644 --- a/src/elli/dispersions/cauchy_custom.py +++ b/src/elli/dispersions/cauchy_custom.py @@ -2,10 +2,10 @@ """Cauchy dispersion with custom exponents.""" import numpy.typing as npt -from .base_dispersion import UnsummableDispersion +from .base_dispersion import Dispersion -class CauchyCustomExponent(UnsummableDispersion): +class CauchyCustomExponent(Dispersion): r"""Cauchy dispersion with custom exponents. Single parameters: @@ -22,7 +22,7 @@ class CauchyCustomExponent(UnsummableDispersion): """ summation_error_message = ( - "The cauchy dispersion cannot be added to other dispersions. " + "The Cauchy dispersion cannot be added to other dispersions. " "Try the Poles or Lorentz model instead." ) diff --git a/src/elli/dispersions/cody_lorentz.py b/src/elli/dispersions/cody_lorentz.py index 46376f4e..4d2f0c1f 100644 --- a/src/elli/dispersions/cody_lorentz.py +++ b/src/elli/dispersions/cody_lorentz.py @@ -6,11 +6,11 @@ from scipy.interpolate import interp1d from ..utils import conversion_wavelength_energy -from .base_dispersion import Dispersion +from .base_dispersion import AdditiveDispersion from ..kkr import im2re_reciprocal -class CodyLorentz(Dispersion): +class CodyLorentz(AdditiveDispersion): """Tauc-Lorentz dispersion law. Model by Ferlauto et al. Single parameters: diff --git a/src/elli/dispersions/constant_refractive_index.py b/src/elli/dispersions/constant_refractive_index.py index ea1b93f5..e7c1ca3a 100644 --- a/src/elli/dispersions/constant_refractive_index.py +++ b/src/elli/dispersions/constant_refractive_index.py @@ -2,10 +2,10 @@ """Constant refractive index.""" import numpy.typing as npt -from .base_dispersion import UnsummableDispersion +from .base_dispersion import Dispersion -class ConstantRefractiveIndex(UnsummableDispersion): +class ConstantRefractiveIndex(Dispersion): r"""Constant refractive index. Single parameters: diff --git a/src/elli/dispersions/drude_energy.py b/src/elli/dispersions/drude_energy.py index 498d180d..c1cc76bb 100644 --- a/src/elli/dispersions/drude_energy.py +++ b/src/elli/dispersions/drude_energy.py @@ -3,10 +3,10 @@ import numpy.typing as npt from ..utils import conversion_wavelength_energy -from .base_dispersion import Dispersion +from .base_dispersion import AdditiveDispersion -class DrudeEnergy(Dispersion): +class DrudeEnergy(AdditiveDispersion): r"""Drude dispersion model with parameters in units of energy. Drude models in the literature typically contain an additional epsilon infinity value. Use `EpsilonInf` to add this parameter or simply add a number, e.g. DrudeEnergy() + 2, where diff --git a/src/elli/dispersions/drude_resistivity.py b/src/elli/dispersions/drude_resistivity.py index d28bc179..dcf6f5e4 100644 --- a/src/elli/dispersions/drude_resistivity.py +++ b/src/elli/dispersions/drude_resistivity.py @@ -5,10 +5,10 @@ import scipy.constants as sc from ..utils import conversion_wavelength_energy -from .base_dispersion import Dispersion +from .base_dispersion import AdditiveDispersion -class DrudeResistivity(Dispersion): +class DrudeResistivity(AdditiveDispersion): r"""Drude dispersion model with resistivity based parameters. Drude models in the literature typically contain an additional epsilon infinity value. Use `EpsilonInf` to add this parameter or simply do DrudeEnergy() + eps_inf. diff --git a/src/elli/dispersions/epsilon_inf.py b/src/elli/dispersions/epsilon_inf.py index c33eecef..4997819b 100644 --- a/src/elli/dispersions/epsilon_inf.py +++ b/src/elli/dispersions/epsilon_inf.py @@ -2,10 +2,10 @@ """Constant epsilon infinity.""" import numpy.typing as npt -from .base_dispersion import Dispersion +from .base_dispersion import AdditiveDispersion -class EpsilonInf(Dispersion): +class EpsilonInf(AdditiveDispersion): r"""Constant epsilon infinity. Single parameters: diff --git a/src/elli/dispersions/gaussian.py b/src/elli/dispersions/gaussian.py index 73b5b71a..f4a2ef69 100644 --- a/src/elli/dispersions/gaussian.py +++ b/src/elli/dispersions/gaussian.py @@ -8,10 +8,10 @@ from scipy.special import dawsn from ..utils import conversion_wavelength_energy -from .base_dispersion import Dispersion +from .base_dispersion import AdditiveDispersion -class Gaussian(Dispersion): +class Gaussian(AdditiveDispersion): r"""Dispersion law with gaussian oscillators. Single parameters: diff --git a/src/elli/dispersions/lorentz_energy.py b/src/elli/dispersions/lorentz_energy.py index 29eff9c6..fcf60c29 100644 --- a/src/elli/dispersions/lorentz_energy.py +++ b/src/elli/dispersions/lorentz_energy.py @@ -3,10 +3,10 @@ import numpy.typing as npt from ..utils import conversion_wavelength_energy -from .base_dispersion import Dispersion +from .base_dispersion import AdditiveDispersion -class LorentzEnergy(Dispersion): +class LorentzEnergy(AdditiveDispersion): r"""Lorentz dispersion law with parameters in units of energy. Single parameters: diff --git a/src/elli/dispersions/lorentz_lambda.py b/src/elli/dispersions/lorentz_lambda.py index 767a347e..7e14f949 100644 --- a/src/elli/dispersions/lorentz_lambda.py +++ b/src/elli/dispersions/lorentz_lambda.py @@ -2,10 +2,10 @@ """Lorentz dispersion law with parameters in units of wavelengths.""" import numpy.typing as npt -from .base_dispersion import Dispersion +from .base_dispersion import AdditiveDispersion -class LorentzLambda(Dispersion): +class LorentzLambda(AdditiveDispersion): r"""Lorentz dispersion law with parameters in units of wavelengths. Single parameters: diff --git a/src/elli/dispersions/poles.py b/src/elli/dispersions/poles.py index a4b278d7..bde6dbfa 100644 --- a/src/elli/dispersions/poles.py +++ b/src/elli/dispersions/poles.py @@ -3,10 +3,10 @@ import numpy.typing as npt from ..utils import conversion_wavelength_energy -from .base_dispersion import Dispersion +from .base_dispersion import AdditiveDispersion -class Poles(Dispersion): +class Poles(AdditiveDispersion): r"""Dispersion law for an UV and IR pole, i.e. Lorentz oscillators outside the fitting spectral range and zero broadening. diff --git a/src/elli/dispersions/polynomial.py b/src/elli/dispersions/polynomial.py index 16b6eb72..56120b4a 100644 --- a/src/elli/dispersions/polynomial.py +++ b/src/elli/dispersions/polynomial.py @@ -2,10 +2,10 @@ """Polynomial dispersion.""" import numpy.typing as npt -from .base_dispersion import Dispersion +from .base_dispersion import AdditiveDispersion -class Polynomial(Dispersion): +class Polynomial(AdditiveDispersion): r"""Polynomial expression for the dielectric function. Single parameters: diff --git a/src/elli/dispersions/sellmeier.py b/src/elli/dispersions/sellmeier.py index e95c798c..55a75a50 100644 --- a/src/elli/dispersions/sellmeier.py +++ b/src/elli/dispersions/sellmeier.py @@ -2,10 +2,10 @@ """Sellmeier dispersion.""" import numpy.typing as npt -from .base_dispersion import Dispersion +from .base_dispersion import AdditiveDispersion -class Sellmeier(Dispersion): +class Sellmeier(AdditiveDispersion): r"""Sellmeier dispersion. Single parameters: diff --git a/src/elli/dispersions/sellmeier_custom.py b/src/elli/dispersions/sellmeier_custom.py index 0698b9a1..bcd9530b 100644 --- a/src/elli/dispersions/sellmeier_custom.py +++ b/src/elli/dispersions/sellmeier_custom.py @@ -2,10 +2,10 @@ """Sellmeier dispersion.""" import numpy.typing as npt -from .base_dispersion import Dispersion +from .base_dispersion import AdditiveDispersion -class SellmeierCustomExponent(Dispersion): +class SellmeierCustomExponent(AdditiveDispersion): r"""Sellmeier dispersion with custom exponents. Single parameters: diff --git a/src/elli/dispersions/tanguy.py b/src/elli/dispersions/tanguy.py index d09ec445..bc14a6d6 100644 --- a/src/elli/dispersions/tanguy.py +++ b/src/elli/dispersions/tanguy.py @@ -8,10 +8,10 @@ from scipy.special import digamma, gamma from ..utils import conversion_wavelength_energy -from .base_dispersion import Dispersion +from .base_dispersion import AdditiveDispersion -class Tanguy(Dispersion): +class Tanguy(AdditiveDispersion): r"""Fractional dimensional Tanguy model. This model is an analytical expression of Wannier excitons, including bound and unbound states. diff --git a/src/elli/dispersions/tauc_lorentz.py b/src/elli/dispersions/tauc_lorentz.py index 68f68a83..5d6f4f84 100644 --- a/src/elli/dispersions/tauc_lorentz.py +++ b/src/elli/dispersions/tauc_lorentz.py @@ -5,10 +5,10 @@ from numpy.lib.scimath import sqrt from ..utils import conversion_wavelength_energy -from .base_dispersion import Dispersion +from .base_dispersion import AdditiveDispersion -class TaucLorentz(Dispersion): +class TaucLorentz(AdditiveDispersion): """Tauc-Lorentz dispersion law. Model by Jellison and Modine. Single parameters: diff --git a/tests/test_unsummable_dispersion.py b/tests/test_unsummable_dispersion.py index 997da2b9..25a9a78e 100644 --- a/tests/test_unsummable_dispersion.py +++ b/tests/test_unsummable_dispersion.py @@ -6,13 +6,15 @@ def test_fail_on_adding_cauchy(): """Test whether the kkr reproduces the analytical expression of Tauc-Lorentz""" cauchy_err_str = ( - "The cauchy dispersion cannot be added to other dispersions. " + "unsupported operand type(s) for +: " + " and ." + "\nThe Cauchy dispersion cannot be added to other dispersions. " "Try the Poles or Lorentz model instead." ) - with pytest.raises(ValueError) as sum_err: + with pytest.raises(TypeError) as sum_err: _ = Cauchy() + Cauchy() assert cauchy_err_str in str(sum_err.value) - with pytest.raises(ValueError): + with pytest.raises(TypeError): _ = 1 + Cauchy()