Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install -r develop.txt
python -m pip install -r docs/requirements.txt
python -m pip install astropy scikit-image scikit-learn
python -m pip install astropy "scikit-image<0.20" scikit-learn
python -m pip install tensorflow>=2.4.1
python -m pip install twine
python -m pip install .
Expand Down Expand Up @@ -108,7 +108,7 @@ jobs:
python --version
python -m pip install --upgrade pip
python -m pip install -r develop.txt
python -m pip install astropy scikit-image scikit-learn
python -m pip install astropy "scikit-image<0.20" scikit-learn
python -m pip install .

- name: Run Tests
Expand Down
46 changes: 22 additions & 24 deletions modopt/math/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@
import numpy as np

try:
from packaging import version
from astropy import __version__ as astropy_version
from astropy.convolution import Gaussian2DKernel
except ImportError: # pragma: no cover
import_astropy = False
else:
import_astropy = True


def gaussian_kernel(data_shape, sigma, norm='max'):
def gaussian_kernel(data_shape, sigma, norm="max"):
"""Gaussian kernel.

This method produces a Gaussian kerenal of a specified size and dispersion.
Expand All @@ -29,9 +31,8 @@ def gaussian_kernel(data_shape, sigma, norm='max'):
Desiered shape of the kernel
sigma : float
Standard deviation of the kernel
norm : {'max', 'sum', 'none'}, optional
Normalisation of the kerenl (options are ``'max'``, ``'sum'`` or
``'none'``, default is ``'max'``)
norm : {'max', 'sum'}, optional
Normalisation of the kerenl (options are ``'max'`` or ``'sum'``, default is ``'max'``)

Returns
-------
Expand Down Expand Up @@ -60,22 +61,22 @@ def gaussian_kernel(data_shape, sigma, norm='max'):

"""
if not import_astropy: # pragma: no cover
raise ImportError('Astropy package not found.')
raise ImportError("Astropy package not found.")

if norm not in {'max', 'sum', 'none'}:
if norm not in {"max", "sum"}:
raise ValueError('Invalid norm, options are "max", "sum" or "none".')

kernel = np.array(
Gaussian2DKernel(sigma, x_size=data_shape[1], y_size=data_shape[0]),
)

if norm == 'max':
if norm == "max":
return kernel / np.max(kernel)

elif norm == 'sum':
elif version.parse(astropy_version) < version.parse("5.2"):
return kernel / np.sum(kernel)

elif norm == 'none':
else:
return kernel


Expand Down Expand Up @@ -147,7 +148,7 @@ def mse(data1, data2):
return np.mean((data1 - data2) ** 2)


def psnr(data1, data2, method='starck', max_pix=255):
def psnr(data1, data2, method="starck", max_pix=255):
r"""Peak Signal-to-Noise Ratio.

This method calculates the Peak Signal-to-Noise Ratio between two data
Expand Down Expand Up @@ -202,23 +203,21 @@ def psnr(data1, data2, method='starck', max_pix=255):
10\log_{10}(\mathrm{MSE}))

"""
if method == 'starck':
return (
20 * np.log10(
(data1.shape[0] * np.abs(np.max(data1) - np.min(data1)))
/ np.linalg.norm(data1 - data2),
)
if method == "starck":
return 20 * np.log10(
(data1.shape[0] * np.abs(np.max(data1) - np.min(data1)))
/ np.linalg.norm(data1 - data2),
)

elif method == 'wiki':
return (20 * np.log10(max_pix) - 10 * np.log10(mse(data1, data2)))
elif method == "wiki":
return 20 * np.log10(max_pix) - 10 * np.log10(mse(data1, data2))

raise ValueError(
'Invalid PSNR method. Options are "starck" and "wiki"',
)


def psnr_stack(data1, data2, metric=np.mean, method='starck'):
def psnr_stack(data1, data2, metric=np.mean, method="starck"):
"""Peak Signa-to-Noise for stack of images.

This method calculates the PSNRs for two stacks of 2D arrays.
Expand Down Expand Up @@ -261,12 +260,11 @@ def psnr_stack(data1, data2, metric=np.mean, method='starck'):

"""
if data1.ndim != 3 or data2.ndim != 3:
raise ValueError('Input data must be a 3D np.ndarray')
raise ValueError("Input data must be a 3D np.ndarray")

return metric([
psnr(i_elem, j_elem, method=method)
for i_elem, j_elem in zip(data1, data2)
])
return metric(
[psnr(i_elem, j_elem, method=method) for i_elem, j_elem in zip(data1, data2)]
)


def sigma_mad(input_data):
Expand Down
10 changes: 0 additions & 10 deletions modopt/tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,16 +284,6 @@ class TestStats:
]
),
),
(
"none",
np.array(
[
[0.05854983, 0.09653235, 0.05854983],
[0.09653235, 0.15915494, 0.09653235],
[0.05854983, 0.09653235, 0.05854983],
]
),
),
failparam("fail", None, raises=ValueError),
],
)
Expand Down