Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 40 additions & 46 deletions src/elli/dispersions/base_dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,21 @@ def __init__(self, *args, **kwargs):
self.single_params_template, *args, **kwargs
)

summation_error_message = ""

def __radd__(self, other: Union[int, float, "Dispersion"]) -> "Dispersion":
return self.__add__(other)

def __add__(
self, other: Union[int, float, "AdditiveDispersion"]
) -> "AdditiveDispersion":
raise TypeError(
f"unsupported operand type(s) for +: {type(self)} and {type(other)}."
f"\n{self.summation_error_message}"
if self.summation_error_message
else ""
)

@abstractmethod
def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray:
"""Calculates the dielectric function in a given wavelength window.
Expand Down Expand Up @@ -101,26 +116,6 @@ def add(self, *args, **kwargs) -> "Dispersion":

return self

def __radd__(self, other: Union[int, float, "Dispersion"]) -> "Dispersion":
"""Add up the dielectric function of multiple models"""
return self.__add__(other)

def __add__(self, other: Union[int, float, "Dispersion"]) -> "Dispersion":
"""Add up the dielectric function of multiple models"""
if isinstance(other, UnsummableDispersion):
other.__add__(self)

if isinstance(other, DispersionSum):
return other.__add__(self)

if isinstance(other, (int, float)):
return DispersionSum(self, dispersions.EpsilonInf(eps=other))

if not isinstance(other, Dispersion):
raise TypeError(f"Invalid type {type(other)} added to dispersion")

return DispersionSum(self, other)

def get_dielectric(self, lbda: npt.ArrayLike) -> npt.NDArray:
"""Returns the dielectric constant for wavelength 'lbda' default unit (nm)
in the convention ε1 + iε2."""
Expand Down Expand Up @@ -205,17 +200,31 @@ def _dict_to_str(dic):
)


class UnsummableDispersion(Dispersion):
"""This denotes a dispersion which is not summable"""
class AdditiveDispersion(Dispersion):
"""An additive dispersion"""

@property
@abstractmethod
def summation_error_message(self):
"""The message being displayed when someone tries
to perform an addition with this dispersion."""
def __add__(self, other: Union[int, float, "Dispersion"]) -> "DispersionSum":
"""Add up the dielectric function of multiple models"""
if not isinstance(other, Dispersion) and not isinstance(other, (int, float)):
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)

def __add__(self, _: Union[int, float, "Dispersion"]) -> "Dispersion":
raise ValueError(self.summation_error_message)
if isinstance(self, DispersionSum) and isinstance(other, DispersionSum):
self.dispersions += other.dispersions # pylint: disable=no-member
return self

if isinstance(self, AdditiveDispersion) and isinstance(other, DispersionSum):
other.dispersions.append(self)
return other

if isinstance(other, (int, float)):
return DispersionSum(self, dispersions.EpsilonInf(eps=other))

if not isinstance(other, AdditiveDispersion):
other.__add__(self)

return DispersionSum(self, other)


class DispersionFactory:
Expand All @@ -240,31 +249,16 @@ def get_dispersion(identifier: str, *args, **kwargs) -> Dispersion:
raise ValueError(f"No such dispersion: {identifier}")


class DispersionSum(Dispersion):
class DispersionSum(AdditiveDispersion):
"""Represents a sum of two dispersions"""

single_params_template = {}
rep_params_template = {}

def __init__(self, *disps: Dispersion) -> None:
def __init__(self, *disps: AdditiveDispersion) -> None:
super().__init__()
self.dispersions = disps

def __add__(self, other: Union[int, float, "Dispersion"]) -> "Dispersion":
if isinstance(other, UnsummableDispersion):
other.__add__(self)

if isinstance(other, DispersionSum):
self.dispersions += other.dispersions
return self

if isinstance(other, (int, float)):
self.dispersions.append(dispersions.EpsilonInf(eps=other))
return self

self.dispersions.append(other)
return self

def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray:
dielectric_function = sum(
disp.dielectric_function(lbda) for disp in self.dispersions
Expand Down
6 changes: 3 additions & 3 deletions src/elli/dispersions/cauchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
"""Cauchy dispersion."""
import numpy.typing as npt

from .base_dispersion import UnsummableDispersion
from .base_dispersion import Dispersion


class Cauchy(UnsummableDispersion):
class Cauchy(Dispersion):
r"""Cauchy dispersion.

Single parameters:
Expand All @@ -28,7 +28,7 @@ class Cauchy(UnsummableDispersion):
"""

summation_error_message = (
"The cauchy dispersion cannot be added to other dispersions. "
"The Cauchy dispersion cannot be added to other dispersions. "
"Try the Poles or Lorentz model instead."
)

Expand Down
6 changes: 3 additions & 3 deletions src/elli/dispersions/cauchy_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
"""Cauchy dispersion with custom exponents."""
import numpy.typing as npt

from .base_dispersion import UnsummableDispersion
from .base_dispersion import Dispersion


class CauchyCustomExponent(UnsummableDispersion):
class CauchyCustomExponent(Dispersion):
r"""Cauchy dispersion with custom exponents.

Single parameters:
Expand All @@ -22,7 +22,7 @@ class CauchyCustomExponent(UnsummableDispersion):
"""

summation_error_message = (
"The cauchy dispersion cannot be added to other dispersions. "
"The Cauchy dispersion cannot be added to other dispersions. "
"Try the Poles or Lorentz model instead."
)

Expand Down
4 changes: 2 additions & 2 deletions src/elli/dispersions/cody_lorentz.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from scipy.interpolate import interp1d

from ..utils import conversion_wavelength_energy
from .base_dispersion import Dispersion
from .base_dispersion import AdditiveDispersion
from ..kkr import im2re_reciprocal


class CodyLorentz(Dispersion):
class CodyLorentz(AdditiveDispersion):
"""Tauc-Lorentz dispersion law. Model by Ferlauto et al.

Single parameters:
Expand Down
4 changes: 2 additions & 2 deletions src/elli/dispersions/constant_refractive_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
"""Constant refractive index."""
import numpy.typing as npt

from .base_dispersion import UnsummableDispersion
from .base_dispersion import Dispersion


class ConstantRefractiveIndex(UnsummableDispersion):
class ConstantRefractiveIndex(Dispersion):
r"""Constant refractive index.

Single parameters:
Expand Down
4 changes: 2 additions & 2 deletions src/elli/dispersions/drude_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import numpy.typing as npt

from ..utils import conversion_wavelength_energy
from .base_dispersion import Dispersion
from .base_dispersion import AdditiveDispersion


class DrudeEnergy(Dispersion):
class DrudeEnergy(AdditiveDispersion):
r"""Drude dispersion model with parameters in units of energy.
Drude models in the literature typically contain an additional epsilon infinity value.
Use `EpsilonInf` to add this parameter or simply add a number, e.g. DrudeEnergy() + 2, where
Expand Down
4 changes: 2 additions & 2 deletions src/elli/dispersions/drude_resistivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import scipy.constants as sc

from ..utils import conversion_wavelength_energy
from .base_dispersion import Dispersion
from .base_dispersion import AdditiveDispersion


class DrudeResistivity(Dispersion):
class DrudeResistivity(AdditiveDispersion):
r"""Drude dispersion model with resistivity based parameters.
Drude models in the literature typically contain an additional epsilon infinity value.
Use `EpsilonInf` to add this parameter or simply do DrudeEnergy() + eps_inf.
Expand Down
4 changes: 2 additions & 2 deletions src/elli/dispersions/epsilon_inf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
"""Constant epsilon infinity."""
import numpy.typing as npt

from .base_dispersion import Dispersion
from .base_dispersion import AdditiveDispersion


class EpsilonInf(Dispersion):
class EpsilonInf(AdditiveDispersion):
r"""Constant epsilon infinity.

Single parameters:
Expand Down
4 changes: 2 additions & 2 deletions src/elli/dispersions/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from scipy.special import dawsn

from ..utils import conversion_wavelength_energy
from .base_dispersion import Dispersion
from .base_dispersion import AdditiveDispersion


class Gaussian(Dispersion):
class Gaussian(AdditiveDispersion):
r"""Dispersion law with gaussian oscillators.

Single parameters:
Expand Down
4 changes: 2 additions & 2 deletions src/elli/dispersions/lorentz_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import numpy.typing as npt

from ..utils import conversion_wavelength_energy
from .base_dispersion import Dispersion
from .base_dispersion import AdditiveDispersion


class LorentzEnergy(Dispersion):
class LorentzEnergy(AdditiveDispersion):
r"""Lorentz dispersion law with parameters in units of energy.

Single parameters:
Expand Down
4 changes: 2 additions & 2 deletions src/elli/dispersions/lorentz_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
"""Lorentz dispersion law with parameters in units of wavelengths."""
import numpy.typing as npt

from .base_dispersion import Dispersion
from .base_dispersion import AdditiveDispersion


class LorentzLambda(Dispersion):
class LorentzLambda(AdditiveDispersion):
r"""Lorentz dispersion law with parameters in units of wavelengths.

Single parameters:
Expand Down
4 changes: 2 additions & 2 deletions src/elli/dispersions/poles.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import numpy.typing as npt

from ..utils import conversion_wavelength_energy
from .base_dispersion import Dispersion
from .base_dispersion import AdditiveDispersion


class Poles(Dispersion):
class Poles(AdditiveDispersion):
r"""Dispersion law for an UV and IR pole,
i.e. Lorentz oscillators outside the fitting spectral range and zero broadening.

Expand Down
4 changes: 2 additions & 2 deletions src/elli/dispersions/polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
"""Polynomial dispersion."""
import numpy.typing as npt

from .base_dispersion import Dispersion
from .base_dispersion import AdditiveDispersion


class Polynomial(Dispersion):
class Polynomial(AdditiveDispersion):
r"""Polynomial expression for the dielectric function.

Single parameters:
Expand Down
4 changes: 2 additions & 2 deletions src/elli/dispersions/sellmeier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
"""Sellmeier dispersion."""
import numpy.typing as npt

from .base_dispersion import Dispersion
from .base_dispersion import AdditiveDispersion


class Sellmeier(Dispersion):
class Sellmeier(AdditiveDispersion):
r"""Sellmeier dispersion.

Single parameters:
Expand Down
4 changes: 2 additions & 2 deletions src/elli/dispersions/sellmeier_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
"""Sellmeier dispersion."""
import numpy.typing as npt

from .base_dispersion import Dispersion
from .base_dispersion import AdditiveDispersion


class SellmeierCustomExponent(Dispersion):
class SellmeierCustomExponent(AdditiveDispersion):
r"""Sellmeier dispersion with custom exponents.

Single parameters:
Expand Down
4 changes: 2 additions & 2 deletions src/elli/dispersions/tanguy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from scipy.special import digamma, gamma

from ..utils import conversion_wavelength_energy
from .base_dispersion import Dispersion
from .base_dispersion import AdditiveDispersion


class Tanguy(Dispersion):
class Tanguy(AdditiveDispersion):
r"""Fractional dimensional Tanguy model.
This model is an analytical expression of Wannier excitons, including
bound and unbound states.
Expand Down
4 changes: 2 additions & 2 deletions src/elli/dispersions/tauc_lorentz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from numpy.lib.scimath import sqrt

from ..utils import conversion_wavelength_energy
from .base_dispersion import Dispersion
from .base_dispersion import AdditiveDispersion


class TaucLorentz(Dispersion):
class TaucLorentz(AdditiveDispersion):
"""Tauc-Lorentz dispersion law. Model by Jellison and Modine.

Single parameters:
Expand Down
8 changes: 5 additions & 3 deletions tests/test_unsummable_dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
def test_fail_on_adding_cauchy():
"""Test whether the kkr reproduces the analytical expression of Tauc-Lorentz"""
cauchy_err_str = (
"The cauchy dispersion cannot be added to other dispersions. "
"unsupported operand type(s) for +: "
"<class 'elli.dispersions.cauchy.Cauchy'> and <class 'elli.dispersions.cauchy.Cauchy'>."
"\nThe Cauchy dispersion cannot be added to other dispersions. "
"Try the Poles or Lorentz model instead."
)
with pytest.raises(ValueError) as sum_err:
with pytest.raises(TypeError) as sum_err:
_ = Cauchy() + Cauchy()

assert cauchy_err_str in str(sum_err.value)

with pytest.raises(ValueError):
with pytest.raises(TypeError):
_ = 1 + Cauchy()