From b7ec90ed1f2dd295f2a4122c6f544ccbe8e006aa Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Fri, 25 Nov 2022 11:42:41 +0100 Subject: [PATCH 01/33] add MatrixOperator. --- modopt/opt/linear.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/modopt/opt/linear.py b/modopt/opt/linear.py index 3807253b..3519024a 100644 --- a/modopt/opt/linear.py +++ b/modopt/opt/linear.py @@ -11,6 +11,7 @@ import numpy as np from modopt.base.types import check_callable, check_float +from modopt.base.backend import get_array_module from modopt.signal.wavelet import filter_convolve_stack @@ -80,6 +81,24 @@ def __init__(self): self.adj_op = self.op +class MatrixOperator(LinearParent): + """ + Matrix Operator class. + + This transform an array into a suitable linear operator. + """ + + def __init__(self, array): + self.op = lambda x: array @ x + xp = get_array_module(array) + + if xp.any(xp.iscomplex(array)): + + self.adj_op = lambda x: array.T.conjugate() @ x + else: + self.adj_op = lambda x: array.T @ x + + class WaveletConvolve(LinearParent): """Wavelet Convolution Class. From 26385aa95e08b6ad7c61cdc3dd95431ff2c6b5f1 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 29 Nov 2022 15:46:01 +0100 Subject: [PATCH 02/33] move base test to pytest. --- modopt/tests/test_base.py | 428 ++++++++++++++------------------------ 1 file changed, 157 insertions(+), 271 deletions(-) diff --git a/modopt/tests/test_base.py b/modopt/tests/test_base.py index 873a4506..2ebbbeb5 100644 --- a/modopt/tests/test_base.py +++ b/modopt/tests/test_base.py @@ -1,192 +1,135 @@ -# -*- coding: utf-8 -*- - -"""UNIT TESTS FOR BASE. - -This module contains unit tests for the modopt.base module. - -:Author: Samuel Farrens - """ +Test for base module. -from builtins import range -from unittest import TestCase, skipIf - +:Author: Pierre-Antoine Comby Date: Tue, 29 Nov 2022 15:46:12 +0100 Subject: [PATCH 03/33] [fixme] remove flake8 and emoji config. 
--- setup.cfg | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 87496ced..810644e0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -83,8 +83,6 @@ testpaths = modopt addopts = --verbose - --emoji - --flake8 --cov=modopt --cov-report=term --cov-report=xml From 70016f9ee886289fa6264d571c92544a4ae84672 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Wed, 30 Nov 2022 10:46:07 +0100 Subject: [PATCH 04/33] rewrite test_math module using pytest. --- modopt/tests/test_math.py | 683 +++++++++++++++----------------------- 1 file changed, 267 insertions(+), 416 deletions(-) diff --git a/modopt/tests/test_math.py b/modopt/tests/test_math.py index ba175ae6..5d1fa743 100644 --- a/modopt/tests/test_math.py +++ b/modopt/tests/test_math.py @@ -1,215 +1,186 @@ -# -*- coding: utf-8 -*- - """UNIT TESTS FOR MATH. This module contains unit tests for the modopt.math module. +:Author: Pierre-Antoine Comby :Author: Samuel Farrens - """ - -from unittest import TestCase, skipIf, skipUnless +import pytest import numpy as np import numpy.testing as npt + from modopt.math import convolve, matrix, metrics, stats try: import astropy except ImportError: # pragma: no cover - import_astropy = False + ASTROPY_AVAILABLE = False else: # pragma: no cover - import_astropy = True + ASTROPY_AVAILABLE = True try: from skimage.metrics import structural_similarity as compare_ssim except ImportError: # pragma: no cover - import_skimage = False + SKIMAGE_AVAILABLE = False else: - import_skimage = True - - -class ConvolveTestCase(TestCase): - """Test case for convolve module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(18).reshape(2, 3, 3) - self.data2 = self.data1 + 1 - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_convolve_astropy(self): - """Test convolve using astropy.""" - npt.assert_allclose( - 
convolve.convolve(self.data1[0], self.data2[0], method='astropy'), - np.array([ - [210.0, 201.0, 210.0], - [129.0, 120.0, 129.0], - [210.0, 201.0, 210.0], - ]), - err_msg='Incorrect convolution: astropy', - ) - - npt.assert_raises( - ValueError, - convolve.convolve, - self.data1[0], - self.data2, - ) - - npt.assert_raises( - ValueError, - convolve.convolve, - self.data1[0], - self.data2[0], - method='bla', - ) - - def test_convolve_scipy(self): - """Test convolve using scipy.""" - npt.assert_allclose( - convolve.convolve(self.data1[0], self.data2[0], method='scipy'), - np.array([ + SKIMAGE_AVAILABLE = True + + +class TestConvolve: + """Test convolve functions.""" + + array233 = np.arange(18).reshape((2, 3, 3)) + array233_1 = array233 + 1 + result_astropy = np.array( + [ + [210.0, 201.0, 210.0], + [129.0, 120.0, 129.0], + [210.0, 201.0, 210.0], + ] + ) + result_scipy = np.array( + [ + [ [14.0, 35.0, 38.0], [57.0, 120.0, 111.0], [110.0, 197.0, 158.0], - ]), - err_msg='Incorrect convolution: scipy', - ) - - def test_convolve_stack(self): - """Test convolve_stack.""" - npt.assert_allclose( - convolve.convolve_stack(self.data1, self.data2), - np.array([ - [ - [14.0, 35.0, 38.0], - [57.0, 120.0, 111.0], - [110.0, 197.0, 158.0], - ], - [ - [518.0, 845.0, 614.0], - [975.0, 1578.0, 1137.0], - [830.0, 1331.0, 950.0], - ], - ]), - err_msg='Incorrect convolution: stack', - ) - - def test_convolve_stack_rot(self): - """Test convolve_stack rotated.""" + ], + [ + [518.0, 845.0, 614.0], + [975.0, 1578.0, 1137.0], + [830.0, 1331.0, 950.0], + ], + ] + ) + + result_rot_kernel = np.array( + [ + [ + [66.0, 115.0, 82.0], + [153.0, 240.0, 159.0], + [90.0, 133.0, 82.0], + ], + [ + [714.0, 1087.0, 730.0], + [1125.0, 1698.0, 1131.0], + [738.0, 1105.0, 730.0], + ], + ] + ) + + @pytest.mark.parametrize( + ("input_data", "kernel", "method", "result"), + [ + pytest.param( + array233[0], + array233_1[0], + "astropy", + result_astropy, + marks=pytest.mark.skipif(not ASTROPY_AVAILABLE, 
reason="astropy not available"), + ), + pytest.param( + array233[0], + array233_1, + "astropy", + result_astropy, + marks=pytest.mark.xfail(raises=ValueError), + ), + pytest.param( + array233[0], + array233_1[0], + "fail!", + result_astropy, + marks=pytest.mark.xfail(raises=ValueError), + ), + (array233[0], array233_1[0], "scipy", result_scipy[0]), + ], + ) + def test_convolve(self, input_data, kernel, method, result): + """Test convolve function.""" + npt.assert_allclose(convolve.convolve(input_data, kernel, method), result) + + @pytest.mark.parametrize( + ("result", "rot_kernel"), + [ + (result_scipy, False), + (result_rot_kernel, True), + ], + ) + def test_convolve_stack(self, result, rot_kernel): + """Test convolve stack function.""" npt.assert_allclose( - convolve.convolve_stack(self.data1, self.data2, rot_kernel=True), - np.array([ - [ - [66.0, 115.0, 82.0], - [153.0, 240.0, 159.0], - [90.0, 133.0, 82.0], - ], - [ - [714.0, 1087.0, 730.0], - [1125.0, 1698.0, 1131.0], - [738.0, 1105.0, 730.0], - ], - ]), - err_msg='Incorrect convolution: stack rot', + convolve.convolve_stack( + self.array233, self.array233_1, rot_kernel=rot_kernel + ), + result, ) -class MatrixTestCase(TestCase): - """Test case for matrix module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3) - self.data2 = np.arange(3) - self.data3 = np.arange(6).reshape(2, 3) - np.random.seed(1) - self.pmInstance1 = matrix.PowerMethod( - lambda x_val: x_val.dot(x_val.T), - self.data1.shape, - verbose=True, - ) - np.random.seed(1) - self.pmInstance2 = matrix.PowerMethod( - lambda x_val: x_val.dot(x_val.T), - self.data1.shape, - auto_run=False, - verbose=True, - ) - self.pmInstance2.get_spec_rad(max_iter=1) - self.gram_schmidt_out = ( - np.array([ +class TestMatrix: + array3 = np.arange(3) + array33 = np.arange(9).reshape((3, 3)) + array23 = np.arange(6).reshape((2, 3)) + gram_schmidt_out = ( + np.array( + [ [0, 1.0, 2.0], [3.0, 1.2, -6e-1], 
[-1.77635684e-15, 0, 0], - ]), - np.array([ + ] + ), + np.array( + [ [0, 0.4472136, 0.89442719], [0.91287093, 0.36514837, -0.18257419], [-1.0, 0, 0], - ]), - ) + ] + ), + ) - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.pmInstance1 = None - self.pmInstance2 = None - self.gram_schmidt_out = None - - def test_gram_schmidt_orthonormal(self): - """Test gram_schmidt with orthonormal output.""" - npt.assert_allclose( - matrix.gram_schmidt(self.data1), - self.gram_schmidt_out[1], - err_msg='Incorrect Gram-Schmidt: orthonormal', - ) - - npt.assert_raises( - ValueError, - matrix.gram_schmidt, - self.data1, - return_opt='bla', - ) - - def test_gram_schmidt_orthogonal(self): - """Test gram_schmidt with orthogonal output.""" - npt.assert_allclose( - matrix.gram_schmidt(self.data1, return_opt='orthogonal'), - self.gram_schmidt_out[0], - err_msg='Incorrect Gram-Schmidt: orthogonal', + @pytest.fixture + def pm_instance(self, request): + np.random.seed(1) + """Power Method instance.""" + pm = matrix.PowerMethod( + lambda x_val: x_val.dot(x_val.T), + self.array33.shape, + auto_run=request.param, + verbose=True, ) - - def test_gram_schmidt_both(self): - """Test gram_schmidt with both outputs.""" + if not request.param: + pm.get_spec_rad(max_iter=1) + return pm + + @pytest.mark.parametrize( + ("return_opt", "output"), + [ + ("orthonormal", gram_schmidt_out[1]), + ("orthogonal", gram_schmidt_out[0]), + ("both", gram_schmidt_out), + pytest.param( + "fail!", gram_schmidt_out, marks=pytest.mark.xfail(raises=ValueError) + ), + ], + ) + def test_gram_schmidt(self, return_opt, output): + """Test gram schmidt.""" npt.assert_allclose( - matrix.gram_schmidt(self.data1, return_opt='both'), - self.gram_schmidt_out, - err_msg='Incorrect Gram-Schmidt: both', + matrix.gram_schmidt(self.array33, return_opt=return_opt), output ) def test_nuclear_norm(self): - """Test nuclear_norm.""" + """Test nuclear norm.""" 
npt.assert_almost_equal( - matrix.nuclear_norm(self.data1), + matrix.nuclear_norm(self.array33), 15.49193338482967, - err_msg='Incorrect nuclear norm', ) def test_project(self): """Test project.""" npt.assert_array_equal( - matrix.project(self.data2, self.data2 + 3), + matrix.project(self.array3, self.array3 + 3), np.array([0, 2.8, 5.6]), - err_msg='Incorrect projection', ) def test_rot_matrix(self): @@ -217,280 +188,160 @@ def test_rot_matrix(self): npt.assert_allclose( matrix.rot_matrix(np.pi / 6), np.array([[0.8660254, -0.5], [0.5, 0.8660254]]), - err_msg='Incorrect rotation matrix', ) def test_rotate(self): """Test rotate.""" npt.assert_array_equal( - matrix.rotate(self.data1, np.pi / 2), + matrix.rotate(self.array33, np.pi / 2), np.array([[2, 5, 8], [1, 4, 7], [0, 3, 6]]), - err_msg='Incorrect rotation', - ) - - npt.assert_raises(ValueError, matrix.rotate, self.data3, np.pi / 2) - - def test_powermethod_converged(self): - """Test PowerMethod converged.""" - npt.assert_almost_equal( - self.pmInstance1.spec_rad, - 1.0, - err_msg='Incorrect spectral radius: converged', ) - npt.assert_almost_equal( - self.pmInstance1.inv_spec_rad, - 1.0, - err_msg='Incorrect inverse spectral radius: converged', - ) - - def test_powermethod_unconverged(self): - """Test PowerMethod unconverged.""" - npt.assert_almost_equal( - self.pmInstance2.spec_rad, - 0.8675467477372257, - err_msg='Incorrect spectral radius: unconverged', - ) - - npt.assert_almost_equal( - self.pmInstance2.inv_spec_rad, - 1.152675636913221, - err_msg='Incorrect inverse spectral radius: unconverged', - ) - - -class MetricsTestCase(TestCase): - """Test case for metrics module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(49).reshape(7, 7) - self.mask = np.ones(self.data1.shape) - self.ssim_res = 0.8963363560519094 - self.ssim_mask_res = 0.805154442543846 - self.snr_res = 10.134554256920536 - self.psnr_res = 14.860761791850397 - self.mse_res = 0.03265305507330247 - 
self.nrmse_res = 0.31136678840022625 - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.mask = None - self.ssim_res = None - self.ssim_mask_res = None - self.psnr_res = None - self.mse_res = None - self.nrmse_res = None - - @skipIf(import_skimage, 'skimage is installed.') # pragma: no cover - def test_ssim_skimage_error(self): - """Test ssim skimage error.""" - npt.assert_raises(ImportError, metrics.ssim, self.data1, self.data1) - - @skipUnless(import_skimage, 'skimage not installed.') # pragma: no cover - def test_ssim(self): + npt.assert_raises(ValueError, matrix.rotate, self.array23, np.pi / 2) + + @pytest.mark.parametrize( + ("pm_instance", "value"), + [(True, 1.0), (False, 0.8675467477372257)], + indirect=["pm_instance"], + ) + def test_power_method(self, pm_instance, value): + """Test power method.""" + npt.assert_almost_equal(pm_instance.spec_rad, value) + npt.assert_almost_equal(pm_instance.inv_spec_rad, 1 / value) + + +class TestMetrics: + """Test metrics module.""" + + data1 = np.arange(49).reshape(7, 7) + mask = np.ones(data1.shape) + ssim_res = 0.8963363560519094 + ssim_mask_res = 0.805154442543846 + snr_res = 10.134554256920536 + psnr_res = 14.860761791850397 + mse_res = 0.03265305507330247 + nrmse_res = 0.31136678840022625 + + @pytest.mark.skipif(not SKIMAGE_AVAILABLE, reason="skimage not installed") + @pytest.mark.parametrize( + ("data1", "data2", "result", "mask"), + [ + (data1, data1**2, ssim_res, None), + (data1, data1**2, ssim_mask_res, mask), + pytest.param( + data1, data1, None, 1, marks=pytest.mark.xfail(raises=ValueError) + ), + ], + ) + def test_ssim(self, data1, data2, result, mask): """Test ssim.""" - npt.assert_almost_equal( - metrics.ssim(self.data1, self.data1 ** 2), - self.ssim_res, - err_msg='Incorrect SSIM result', - ) - - npt.assert_almost_equal( - metrics.ssim(self.data1, self.data1 ** 2, mask=self.mask), - self.ssim_mask_res, - err_msg='Incorrect SSIM result', - ) + 
npt.assert_almost_equal(metrics.ssim(data1, data2, mask=mask), result) - npt.assert_raises( - ValueError, - metrics.ssim, - self.data1, - self.data1, - mask=1, - ) + @pytest.mark.skipif(SKIMAGE_AVAILABLE, reason="skimage installed") + def test_ssim_fail(self): + """Test ssim.""" + npt.assert_raises(ImportError, metrics.ssim, self.data1, self.data1) - def test_snr(self): + @pytest.mark.parametrize( + ("metric", "data", "result", "mask"), + [ + (metrics.snr, data1, snr_res, None), + (metrics.snr, data1, snr_res, mask), + (metrics.psnr, data1, psnr_res, None), + (metrics.psnr, data1, psnr_res, mask), + (metrics.mse, data1, mse_res, None), + (metrics.mse, data1, mse_res, mask), + (metrics.nrmse, data1, nrmse_res, None), + (metrics.nrmse, data1, nrmse_res, mask), + ], + ) + def test_metric(self, metric, data, result, mask): """Test snr.""" - npt.assert_almost_equal( - metrics.snr(self.data1, self.data1 ** 2), - self.snr_res, - err_msg='Incorrect SNR result', - ) - - npt.assert_almost_equal( - metrics.snr(self.data1, self.data1 ** 2, mask=self.mask), - self.snr_res, - err_msg='Incorrect SNR result', - ) - - def test_psnr(self): - """Test psnr.""" - npt.assert_almost_equal( - metrics.psnr(self.data1, self.data1 ** 2), - self.psnr_res, - err_msg='Incorrect PSNR result', - ) - - npt.assert_almost_equal( - metrics.psnr(self.data1, self.data1 ** 2, mask=self.mask), - self.psnr_res, - err_msg='Incorrect PSNR result', - ) - - def test_mse(self): - """Test mse.""" - npt.assert_almost_equal( - metrics.mse(self.data1, self.data1 ** 2), - self.mse_res, - err_msg='Incorrect MSE result', - ) - - npt.assert_almost_equal( - metrics.mse(self.data1, self.data1 ** 2, mask=self.mask), - self.mse_res, - err_msg='Incorrect MSE result', - ) - - def test_nrmse(self): - """Test nrmse.""" - npt.assert_almost_equal( - metrics.nrmse(self.data1, self.data1 ** 2), - self.nrmse_res, - err_msg='Incorrect NRMSE result', - ) - - npt.assert_almost_equal( - metrics.nrmse(self.data1, self.data1 ** 2, 
mask=self.mask), - self.nrmse_res, - err_msg='Incorrect NRMSE result', - ) - - -class StatsTestCase(TestCase): - """Test case for stats module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3) - self.data2 = np.arange(18).reshape(2, 3, 3) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - - @skipIf(import_astropy, 'Astropy is installed.') # pragma: no cover - def test_gaussian_kernel_astropy_error(self): - """Test gaussian_kernel astropy error.""" - npt.assert_raises( - ImportError, - stats.gaussian_kernel, - self.data1.shape, - 1, - ) - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_gaussian_kernel_max(self): - """Test gaussian_kernel with max norm.""" - npt.assert_allclose( - stats.gaussian_kernel(self.data1.shape, 1), - np.array([ - [0.36787944, 0.60653066, 0.36787944], - [0.60653066, 1.0, 0.60653066], - [0.36787944, 0.60653066, 0.36787944], - ]), - err_msg='Incorrect gaussian kernel: max norm', - ) - - npt.assert_raises( - ValueError, - stats.gaussian_kernel, - self.data1.shape, - 1, - norm='bla', - ) - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_gaussian_kernel_sum(self): - """Test gaussian_kernel with sum norm.""" - npt.assert_allclose( - stats.gaussian_kernel(self.data1.shape, 1, norm='sum'), - np.array([ - [0.07511361, 0.1238414, 0.07511361], - [0.1238414, 0.20417996, 0.1238414], - [0.07511361, 0.1238414, 0.07511361], - ]), - err_msg='Incorrect gaussian kernel: sum norm', - ) - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_gaussian_kernel_none(self): - """Test gaussian_kernel with no norm.""" - npt.assert_allclose( - stats.gaussian_kernel(self.data1.shape, 1, norm='none'), - np.array([ - [0.05854983, 0.09653235, 0.05854983], - [0.09653235, 0.15915494, 0.09653235], - [0.05854983, 0.09653235, 0.05854983], - ]), - err_msg='Incorrect gaussian kernel: sum 
norm', - ) + npt.assert_almost_equal(metric(data, data**2, mask=mask), result) + + +class TestStats: + """Test stats module.""" + + array33 = np.arange(9).reshape(3, 3) + array233 = np.arange(18).reshape(2, 3, 3) + + @pytest.mark.skipif(not ASTROPY_AVAILABLE, reason="astropy not installed") + @pytest.mark.parametrize( + ("norm", "result"), + [ + ( + "max", + np.array( + [ + [0.36787944, 0.60653066, 0.36787944], + [0.60653066, 1.0, 0.60653066], + [0.36787944, 0.60653066, 0.36787944], + ] + ), + ), + ( + "sum", + np.array( + [ + [0.07511361, 0.1238414, 0.07511361], + [0.1238414, 0.20417996, 0.1238414], + [0.07511361, 0.1238414, 0.07511361], + ] + ), + ), + ( + "none", + np.array( + [ + [0.05854983, 0.09653235, 0.05854983], + [0.09653235, 0.15915494, 0.09653235], + [0.05854983, 0.09653235, 0.05854983], + ] + ), + ), + pytest.param("fail", None, marks=pytest.mark.xfail(raises=ValueError)), + ], + ) + def test_gaussian_kernel(self, norm, result): + """Test gaussian kernel.""" + npt.assert_allclose(stats.gaussian_kernel(self.array33.shape, 1), result) + + @pytest.mark.skipif(ASTROPY_AVAILABLE, reason="astropy installed") + def test_import_astropy(self): + """Test missing astropy.""" + npt.assert_raises(ImportError, stats.gaussian_kernel, self.array33.shape, 1) def test_mad(self): """Test mad.""" - npt.assert_equal( - stats.mad(self.data1), - 2.0, - err_msg='Incorrect median absolute deviation', - ) - - def test_mse(self): - """Test mse.""" - npt.assert_equal( - stats.mse(self.data1, self.data1 + 2), - 4.0, - err_msg='Incorrect mean squared error', - ) + npt.assert_equal(stats.mad(self.array33), 2.0) - def test_psnr_starck(self): - """Test psnr.""" + def test_sigma_mad(self): + """Test sigma_mad.""" npt.assert_almost_equal( - stats.psnr(self.data1, self.data1 + 2), - 12.041199826559248, - err_msg='Incorrect PSNR: starck', - ) - - npt.assert_raises( - ValueError, - stats.psnr, - self.data1, - self.data1, - method='bla', + stats.sigma_mad(self.array33), + 
2.9651999999999998, ) - def test_psnr_wiki(self): - """Test psnr wiki method.""" - npt.assert_almost_equal( - stats.psnr(self.data1, self.data1 + 2, method='wiki'), - 42.110203695399477, - err_msg='Incorrect PSNR: wiki', - ) + @pytest.mark.parametrize( + ("data1", "data2", "method", "result"), + [ + (array33, array33 + 2, "starck", 12.041199826559248), + pytest.param( + array33, array33, "fail", 0 , marks=pytest.mark.xfail(raises=ValueError) + ), + (array33, array33 + 2, "wiki", 42.110203695399477), + ], + ) + def test_psnr(self, data1, data2, method, result): + """Test psnr.""" + npt.assert_almost_equal(stats.psnr(data1, data2, method=method), result) def test_psnr_stack(self): """Test psnr stack.""" npt.assert_almost_equal( - stats.psnr_stack(self.data2, self.data2 + 2), + stats.psnr_stack(self.array233, self.array233 + 2), 12.041199826559248, - err_msg='Incorrect PSNR stack', ) - npt.assert_raises(ValueError, stats.psnr_stack, self.data1, self.data1) - - def test_sigma_mad(self): - """Test sigma_mad.""" - npt.assert_almost_equal( - stats.sigma_mad(self.data1), - 2.9651999999999998, - err_msg='Incorrect sigma from MAD', - ) + npt.assert_raises(ValueError, stats.psnr_stack, self.array33, self.array33) From 33a4ff098b372d61a652ae4f4f4295acd580aabd Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Wed, 30 Nov 2022 11:53:01 +0100 Subject: [PATCH 05/33] use fail/skipparam helper function. 
--- modopt/tests/test_helpers/__init__.py | 2 ++ modopt/tests/test_helpers/utils.py | 11 ++++++++ modopt/tests/test_math.py | 38 +++++++++------------------ setup.cfg | 1 + 4 files changed, 27 insertions(+), 25 deletions(-) create mode 100644 modopt/tests/test_helpers/__init__.py create mode 100644 modopt/tests/test_helpers/utils.py diff --git a/modopt/tests/test_helpers/__init__.py b/modopt/tests/test_helpers/__init__.py new file mode 100644 index 00000000..89a177ae --- /dev/null +++ b/modopt/tests/test_helpers/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +from .utils import failparam, skipparam diff --git a/modopt/tests/test_helpers/utils.py b/modopt/tests/test_helpers/utils.py new file mode 100644 index 00000000..bcd544b2 --- /dev/null +++ b/modopt/tests/test_helpers/utils.py @@ -0,0 +1,11 @@ +import pytest + + +def failparam(*args, raises=ValueError): + """Return a pytest parametrization that should raise an error.""" + return pytest.param(*args, marks=pytest.mark.raises(exception=raises)) + + +def skipparam(*args, cond=True, reason=""): + """Return a pytest parametrization that should raise an error.""" + return pytest.param(*args, marks=pytest.mark.skipif(cond, reason=reason)) diff --git a/modopt/tests/test_math.py b/modopt/tests/test_math.py index 5d1fa743..86e5c629 100644 --- a/modopt/tests/test_math.py +++ b/modopt/tests/test_math.py @@ -6,6 +6,7 @@ :Author: Samuel Farrens """ import pytest +from test_helpers import failparam, skipparam import numpy as np import numpy.testing as npt @@ -72,26 +73,19 @@ class TestConvolve: @pytest.mark.parametrize( ("input_data", "kernel", "method", "result"), [ - pytest.param( + skipparam( array233[0], array233_1[0], "astropy", result_astropy, - marks=pytest.mark.skipif(not ASTROPY_AVAILABLE, reason="astropy not available"), + cond=not ASTROPY_AVAILABLE, + reason="astropy not available", ), - pytest.param( - array233[0], - array233_1, - "astropy", - result_astropy, - marks=pytest.mark.xfail(raises=ValueError), + 
failparam( + array233[0], array233_1, "astropy", result_astropy, raises=ValueError ), - pytest.param( - array233[0], - array233_1[0], - "fail!", - result_astropy, - marks=pytest.mark.xfail(raises=ValueError), + failparam( + array233[0], array233_1[0], "fail!", result_astropy, raises=ValueError ), (array233[0], array233_1[0], "scipy", result_scipy[0]), ], @@ -140,8 +134,8 @@ class TestMatrix: @pytest.fixture def pm_instance(self, request): - np.random.seed(1) """Power Method instance.""" + np.random.seed(1) pm = matrix.PowerMethod( lambda x_val: x_val.dot(x_val.T), self.array33.shape, @@ -158,9 +152,7 @@ def pm_instance(self, request): ("orthonormal", gram_schmidt_out[1]), ("orthogonal", gram_schmidt_out[0]), ("both", gram_schmidt_out), - pytest.param( - "fail!", gram_schmidt_out, marks=pytest.mark.xfail(raises=ValueError) - ), + failparam("fail!", gram_schmidt_out, raises=ValueError), ], ) def test_gram_schmidt(self, return_opt, output): @@ -228,9 +220,7 @@ class TestMetrics: [ (data1, data1**2, ssim_res, None), (data1, data1**2, ssim_mask_res, mask), - pytest.param( - data1, data1, None, 1, marks=pytest.mark.xfail(raises=ValueError) - ), + failparam(data1, data1, None, 1, raises=ValueError), ], ) def test_ssim(self, data1, data2, result, mask): @@ -300,7 +290,7 @@ class TestStats: ] ), ), - pytest.param("fail", None, marks=pytest.mark.xfail(raises=ValueError)), + failparam("fail", None, raises=ValueError), ], ) def test_gaussian_kernel(self, norm, result): @@ -327,9 +317,7 @@ def test_sigma_mad(self): ("data1", "data2", "method", "result"), [ (array33, array33 + 2, "starck", 12.041199826559248), - pytest.param( - array33, array33, "fail", 0 , marks=pytest.mark.xfail(raises=ValueError) - ), + failparam(array33, array33, "fail", 0, raises=ValueError), (array33, array33 + 2, "wiki", 42.110203695399477), ], ) diff --git a/setup.cfg b/setup.cfg index 810644e0..f8991e7e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -79,6 +79,7 @@ max-string-usages = 20 max-raises = 5 
[tool:pytest] +norecursedirs=tests/test_helpers testpaths = modopt addopts = From 111262399253ecfcce618f17a7c8a5d213b17f53 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Wed, 30 Nov 2022 12:21:58 +0100 Subject: [PATCH 06/33] generalize usage of failparam --- develop.txt | 1 + modopt/tests/test_base.py | 14 +++++--------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/develop.txt b/develop.txt index 7397e15c..5105696e 100644 --- a/develop.txt +++ b/develop.txt @@ -2,6 +2,7 @@ coverage>=5.5 flake8>=4 nose>=1.3.7 pytest>=6.2.2 +pytest-raises>=0.10 pytest-cov>=2.11.1 pytest-pep8>=1.0.6 pytest-emoji>=0.2.0 diff --git a/modopt/tests/test_base.py b/modopt/tests/test_base.py index 2ebbbeb5..4551e05d 100644 --- a/modopt/tests/test_base.py +++ b/modopt/tests/test_base.py @@ -6,6 +6,7 @@ import numpy as np import numpy.testing as npt import pytest +from test_helpers import failparam, skipparam from modopt.base import backend, np_adjust, transform, types from modopt.base.backend import LIBRARIES @@ -48,7 +49,7 @@ def test_rotate_stack(self): 1, [1, 1], np.array([1, 1]), - pytest.param("1", marks=pytest.mark.xfail(raises=ValueError)), + failparam("1", raises=ValueError), ], ) def test_pad2d(self, padding): @@ -155,7 +156,7 @@ class TestType: (1, 1.0), (data_list, data_flt), (data_int, data_flt), - pytest.param("1.0", 1.0, marks=pytest.mark.xfail(raises=TypeError)), + failparam("1.0", 1.0, raises=TypeError), ], ) def test_check_float(self, data, checked): @@ -169,7 +170,7 @@ def test_check_float(self, data, checked): (1, 1), (data_list, data_int), (data_flt, data_int), - pytest.param("1", None, marks=pytest.mark.xfail(raises=TypeError)), + failparam("1", None, raises=TypeError), ], ) def test_check_int(self, data, checked): @@ -192,12 +193,7 @@ def test_check_npndarray(self, data, dtype): @pytest.mark.parametrize( "backend_name", [ - pytest.param( - name, - marks=pytest.mark.skipif( - LIBRARIES[name] is None, reason=f"{name} not installed" - ), - ) + 
skipparam(name, cond=LIBRARIES[name] is None, reason=f"{name} not installed") for name in LIBRARIES ], ) From 2add34d3921af7f9c6c3b39718e082fb5a7c43c4 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Wed, 30 Nov 2022 12:22:18 +0100 Subject: [PATCH 07/33] refactor test_signal. --- modopt/tests/test_signal.py | 466 ++++++------------------------------ 1 file changed, 78 insertions(+), 388 deletions(-) diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py index 7490b98c..61fdec85 100644 --- a/modopt/tests/test_signal.py +++ b/modopt/tests/test_signal.py @@ -1,414 +1,104 @@ -# -*- coding: utf-8 -*- - """UNIT TESTS FOR SIGNAL. This module contains unit tests for the modopt.signal module. :Author: Samuel Farrens - +:Author: Pierre-Antoine Comby """ -from unittest import TestCase - import numpy as np import numpy.testing as npt +import pytest + +from test_helpers import failparam from modopt.signal import filter, noise, positivity, svd, validation, wavelet -class FilterTestCase(TestCase): - """Test case for filter module.""" +@pytest.mark.parametrize( + ("norm", "result"), [(True, 0.24197072451914337), (False, 0.60653065971263342)] +) +def test_gaussian_filter(norm, result): + """Test gaussian filter.""" + npt.assert_almost_equal(filter.gaussian_filter(1, 1, norm=norm), result) - def test_guassian_filter(self): - """Test guassian_filter.""" - npt.assert_almost_equal( - filter.gaussian_filter(1, 1), - 0.24197072451914337, - err_msg='Incorrect Gaussian filter', - ) - npt.assert_almost_equal( - filter.gaussian_filter(1, 1, norm=False), - 0.60653065971263342, - err_msg='Incorrect Gaussian filter', - ) +def test_mex_hat(): + """Test mex_hat.""" + npt.assert_almost_equal( + filter.mex_hat(2, 1), + -0.35213905225713371, + ) - def test_mex_hat(self): - """Test mex_hat.""" - npt.assert_almost_equal( - filter.mex_hat(2, 1), - -0.35213905225713371, - err_msg='Incorrect Mexican hat filter', - ) - def test_mex_hat_dir(self): - """Test mex_hat_dir.""" - 
npt.assert_almost_equal( - filter.mex_hat_dir(1, 2, 1), - 0.17606952612856686, - err_msg='Incorrect directional Mexican hat filter', - ) +def test_mex_hat_dir(): + """Test mex_hat_dir.""" + npt.assert_almost_equal( + filter.mex_hat_dir(1, 2, 1), + 0.17606952612856686, + ) -class NoiseTestCase(TestCase): - """Test case for noise module.""" +class TestNoise: + """Test noise module.""" - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = np.array( - [[0, 2.0, 2.0], [4.0, 5.0, 10], [11.0, 15.0, 18.0]], - ) - self.data3 = np.array([ + data1 = np.arange(9).reshape(3, 3).astype(float) + data2 = np.array( + [[0, 2.0, 2.0], [4.0, 5.0, 10], [11.0, 15.0, 18.0]], + ) + data3 = np.array( + [ [1.62434536, 0.38824359, 1.47182825], [1.92703138, 4.86540763, 2.6984613], [7.74481176, 6.2387931, 8.3190391], - ]) - self.data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]]) - self.data5 = np.array( - [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]], - ) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - - def test_add_noise_poisson(self): - """Test add_noise with Poisson noise.""" - np.random.seed(1) - npt.assert_array_equal( - noise.add_noise(self.data1, noise_type='poisson'), - self.data2, - err_msg='Incorrect noise: Poisson', - ) - - npt.assert_raises( - ValueError, - noise.add_noise, - self.data1, - noise_type='bla', - ) - - npt.assert_raises(ValueError, noise.add_noise, self.data1, (1, 1)) - - def test_add_noise_gaussian(self): - """Test add_noise with Gaussian noise.""" - np.random.seed(1) - npt.assert_almost_equal( - noise.add_noise(self.data1), - self.data3, - err_msg='Incorrect noise: Gaussian', - ) - + ] + ) + data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]]) + data5 = np.array( + [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]], + ) + + @pytest.mark.parametrize( + ("data", "noise_type", "sigma", 
"data_noise"), + [ + (data1, "poisson", 1, data2), + (data1, "gauss", 1, data3), + (data1, "gauss", (1, 1, 1), data3), + failparam(data1, "fail", 1, data1), + ], + ) + def test_add_noise(self, data, noise_type, sigma, data_noise): + """Test add_noise.""" np.random.seed(1) npt.assert_almost_equal( - noise.add_noise(self.data1, sigma=(1, 1, 1)), - self.data3, - err_msg='Incorrect noise: Gaussian', - ) - - def test_thresh_hard(self): - """Test thresh with hard threshold.""" - npt.assert_array_equal( - noise.thresh(self.data1, 5), - self.data4, - err_msg='Incorrect threshold: hard', - ) - - npt.assert_raises( - ValueError, - noise.thresh, - self.data1, - 5, - threshold_type='bla', - ) - - def test_thresh_soft(self): - """Test thresh with soft threshold.""" - npt.assert_array_equal( - noise.thresh(self.data1, 5, threshold_type='soft'), - self.data5, - err_msg='Incorrect threshold: soft', - ) - - -class PositivityTestCase(TestCase): - """Test case for positivity module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3) - 5 - self.data2 = np.array([[0, 0, 0], [0, 0, 0], [1, 2, 3]]) - self.data3 = np.array( - [np.arange(5) - 3, np.arange(4) - 2], - dtype=object, - ) - self.data4 = np.array( - [np.array([0, 0, 0, 0, 1]), np.array([0, 0, 0, 1])], - dtype=object, - ) - self.pos_dtype_obj = positivity.positive(self.data3) - self.err = 'Incorrect positivity' - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - - def test_positivity(self): - """Test positivity.""" - npt.assert_equal(positivity.positive(-1), 0, err_msg=self.err) - - npt.assert_equal( - positivity.positive(-1.0), - -float(0), - err_msg=self.err, - ) - - npt.assert_equal( - positivity.positive(self.data1), - self.data2, - err_msg=self.err, - ) - - for expected, output in zip(self.data4, self.pos_dtype_obj): - print(expected, output) - npt.assert_array_equal(expected, output, err_msg=self.err) - - 
npt.assert_raises(TypeError, positivity.positive, '-1') - - -class SVDTestCase(TestCase): - """Test case for svd module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(18).reshape(9, 2).astype(float) - self.data2 = np.arange(32).reshape(16, 2).astype(float) - self.data3 = np.array( - [ - np.array([ - [-0.01744594, -0.61438865], - [-0.08435304, -0.50397984], - [-0.15126014, -0.39357102], - [-0.21816724, -0.28316221], - [-0.28507434, -0.17275339], - [-0.35198144, -0.06234457], - [-0.41888854, 0.04806424], - [-0.48579564, 0.15847306], - [-0.55270274, 0.26888188], - ]), - np.array([42.23492742, 1.10041151]), - np.array([ - [-0.67608034, -0.73682791], - [0.73682791, -0.67608034], - ]), - ], - dtype=object, - ) - self.data4 = np.array([ - [-1.05426832e-16, 1.0], - [2.0, 3.0], - [4.0, 5.0], - [6.0, 7.0], - [8.0, 9.0], - [1.0e1, 1.1e1], - [1.2e1, 1.3e1], - [1.4e1, 1.5e1], - [1.6e1, 1.7e1], - ]) - self.data5 = np.array([ - [0.49815487, 0.54291537], - [2.40863386, 2.62505584], - [4.31911286, 4.70719631], - [6.22959185, 6.78933678], - [8.14007085, 8.87147725], - [10.05054985, 10.95361772], - [11.96102884, 13.03575819], - [13.87150784, 15.11789866], - [15.78198684, 17.20003913], - ]) - self.svd = svd.calculate_svd(self.data1) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.svd = None - - def test_find_n_pc(self): - """Test find_n_pc.""" - npt.assert_equal( - svd.find_n_pc(svd.svd(self.data2)[0]), - 2, - err_msg='Incorrect number of principal components.', - ) - - npt.assert_raises(ValueError, svd.find_n_pc, np.arange(3)) - - def test_calculate_svd(self): - """Test calculate_svd.""" - npt.assert_almost_equal( - self.svd[0], - np.array(self.data3)[0], - err_msg='Incorrect SVD calculation: U', - ) - - npt.assert_almost_equal( - self.svd[1], - np.array(self.data3)[1], - err_msg='Incorrect SVD calculation: S', - ) - - npt.assert_almost_equal( - 
self.svd[2], - np.array(self.data3)[2], - err_msg='Incorrect SVD calculation: V', - ) - - def test_svd_thresh(self): - """Test svd_thresh.""" - npt.assert_almost_equal( - svd.svd_thresh(self.data1), - self.data4, - err_msg='Incorrect SVD tresholding', - ) - - npt.assert_almost_equal( - svd.svd_thresh(self.data1, n_pc=1), - self.data5, - err_msg='Incorrect SVD tresholding', - ) - - npt.assert_almost_equal( - svd.svd_thresh(self.data1, n_pc='all'), - self.data1, - err_msg='Incorrect SVD tresholding', - ) - - npt.assert_raises(TypeError, svd.svd_thresh, 1) - - npt.assert_raises(ValueError, svd.svd_thresh, self.data1, n_pc='bla') - - def test_svd_thresh_coef(self): - """Test svd_thresh_coef.""" - npt.assert_almost_equal( - svd.svd_thresh_coef(self.data1, lambda x_val: x_val, 0), - self.data1, - err_msg='Incorrect SVD coefficient tresholding', - ) - - npt.assert_raises(TypeError, svd.svd_thresh_coef, self.data1, 0, 0) - - -class ValidationTestCase(TestCase): - """Test case for validation module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - - def test_transpose_test(self): - """Test transpose_test.""" - np.random.seed(2) - npt.assert_equal( - validation.transpose_test( - lambda x_val, y_val: x_val.dot(y_val), - lambda x_val, y_val: x_val.dot(y_val.T), - self.data1.shape, - x_args=self.data1, + noise.add_noise(data, sigma=sigma, noise_type=noise_type), data_noise + ) + + @pytest.mark.parametrize( + ("threshold_type", "result"), + [("hard", data4), ("soft", data5), failparam("fail", None, raises=ValueError)], + ) + def test_thresh(self, threshold_type, result): + """Test threshold.""" + npt.assert_array_equal(noise.thresh(self.data1, 5, threshold_type=threshold_type), result) + + @pytest.mark.parametrize( + ("value", "expected"), + [ + (-1.0, -float(0)), + (-1, 0), + (data1 - 5, data5), + ( + np.array([np.arange(3) - 1, 
np.arange(2) - 1], dtype=object), + np.array([np.array([0, 0, 1]),np.array([0, 0])], dtype=object), ), - None, - ) - - npt.assert_raises( - TypeError, - validation.transpose_test, - 0, - 0, - self.data1.shape, - x_args=self.data1, - ) - - -class WaveletTestCase(TestCase): - """Test case for wavelet module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = np.arange(36).reshape(4, 3, 3).astype(float) - self.data3 = np.array([ - [ - [6.0, 20, 26.0], - [36.0, 84.0, 84.0], - [90, 164.0, 134.0], - ], - [ - [78.0, 155.0, 134.0], - [225.0, 408.0, 327.0], - [270, 461.0, 350], - ], - [ - [150, 290, 242.0], - [414.0, 732.0, 570], - [450, 758.0, 566.0], - ], - [ - [222.0, 425.0, 350], - [603.0, 1056.0, 813.0], - [630, 1055.0, 782.0], - ], - ]) - - self.data4 = np.array([ - [6496.0, 9796.0, 6544.0], - [9924.0, 14910, 9924.0], - [6544.0, 9796.0, 6496.0], - ]) - - self.data5 = np.array([ - [[0, 1.0, 4.0], [3.0, 10, 13.0], [6.0, 19.0, 22.0]], - [[3.0, 10, 13.0], [24.0, 46.0, 40], [45.0, 82.0, 67.0]], - [[6.0, 19.0, 22.0], [45.0, 82.0, 67.0], [84.0, 145.0, 112.0]], - ]) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - - def test_filter_convolve(self): - """Test filter_convolve.""" - npt.assert_almost_equal( - wavelet.filter_convolve(self.data1, self.data2), - self.data3, - err_msg='Inccorect filter comvolution.', - ) - - npt.assert_almost_equal( - wavelet.filter_convolve(self.data2, self.data2, filter_rot=True), - self.data4, - err_msg='Inccorect filter comvolution.', - ) - - def test_filter_convolve_stack(self): - """Test filter_convolve_stack.""" - npt.assert_almost_equal( - wavelet.filter_convolve_stack(self.data1, self.data1), - self.data5, - err_msg='Inccorect filter stack comvolution.', - ) + failparam("-1", None, raises=TypeError), + ], + ) + def test_positive(self, value, expected): 
+ """Test positive.""" + if isinstance(value, np.ndarray) and value.dtype == 'O': + for v, e in zip(positivity.positive(value), expected): + npt.assert_array_equal(v, e) + else: + npt.assert_array_equal(positivity.positive(value), expected) From b33c0d8dd7f0a016e33d1728b0244c52def62aff Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Wed, 30 Nov 2022 16:04:12 +0100 Subject: [PATCH 08/33] refactor test_signal, the end. --- modopt/tests/test_signal.py | 220 +++++++++++++++++++++++++++++++++++- 1 file changed, 217 insertions(+), 3 deletions(-) diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py index 61fdec85..c46e1339 100644 --- a/modopt/tests/test_signal.py +++ b/modopt/tests/test_signal.py @@ -80,7 +80,9 @@ def test_add_noise(self, data, noise_type, sigma, data_noise): ) def test_thresh(self, threshold_type, result): """Test threshold.""" - npt.assert_array_equal(noise.thresh(self.data1, 5, threshold_type=threshold_type), result) + npt.assert_array_equal( + noise.thresh(self.data1, 5, threshold_type=threshold_type), result + ) @pytest.mark.parametrize( ("value", "expected"), @@ -90,15 +92,227 @@ def test_thresh(self, threshold_type, result): (data1 - 5, data5), ( np.array([np.arange(3) - 1, np.arange(2) - 1], dtype=object), - np.array([np.array([0, 0, 1]),np.array([0, 0])], dtype=object), + np.array([np.array([0, 0, 1]), np.array([0, 0])], dtype=object), ), failparam("-1", None, raises=TypeError), ], ) def test_positive(self, value, expected): """Test positive.""" - if isinstance(value, np.ndarray) and value.dtype == 'O': + if isinstance(value, np.ndarray) and value.dtype == "O": for v, e in zip(positivity.positive(value), expected): npt.assert_array_equal(v, e) else: npt.assert_array_equal(positivity.positive(value), expected) + + +class TestSVD: + """Test for svd module.""" + + @pytest.fixture + def data(self): + """Initialize test data.""" + data1 = np.arange(18).reshape(9, 2).astype(float) + data2 = np.arange(32).reshape(16, 
2).astype(float) + data3 = np.array( + [ + np.array( + [ + [-0.01744594, -0.61438865], + [-0.08435304, -0.50397984], + [-0.15126014, -0.39357102], + [-0.21816724, -0.28316221], + [-0.28507434, -0.17275339], + [-0.35198144, -0.06234457], + [-0.41888854, 0.04806424], + [-0.48579564, 0.15847306], + [-0.55270274, 0.26888188], + ] + ), + np.array([42.23492742, 1.10041151]), + np.array( + [ + [-0.67608034, -0.73682791], + [0.73682791, -0.67608034], + ] + ), + ], + dtype=object, + ) + data4 = np.array( + [ + [-1.05426832e-16, 1.0], + [2.0, 3.0], + [4.0, 5.0], + [6.0, 7.0], + [8.0, 9.0], + [1.0e1, 1.1e1], + [1.2e1, 1.3e1], + [1.4e1, 1.5e1], + [1.6e1, 1.7e1], + ] + ) + + data5 = np.array( + [ + [0.49815487, 0.54291537], + [2.40863386, 2.62505584], + [4.31911286, 4.70719631], + [6.22959185, 6.78933678], + [8.14007085, 8.87147725], + [10.05054985, 10.95361772], + [11.96102884, 13.03575819], + [13.87150784, 15.11789866], + [15.78198684, 17.20003913], + ] + ) + return (data1, data2, data3, data4, data5) + + @pytest.fixture + def svd0(self, data): + """Compute SVD for data[0].""" + return svd.calculate_svd(data[0]) + + def test_find_n_pc(self, data): + """Test find_n_pc.""" + npt.assert_equal( + svd.find_n_pc(svd.svd(data[1])[0]), + 2, + err_msg="Incorrect number of principal components.", + ) + + def test_n_pc_fail_non_square(self): + """Test find_n_pc.""" + npt.assert_raises(ValueError, svd.find_n_pc, np.arange(3)) + + def test_calculate_svd(self, data, svd0): + """Test calculate_svd.""" + errors = [] + for i, name in enumerate("USV"): + try: + npt.assert_almost_equal(svd0[i], data[2][i]) + except AssertionError: + errors.append(name) + if errors: + raise AssertionError("Incorrect SVD calculation for: " + ", ".join(errors)) + + @pytest.mark.parametrize( + ("n_pc", "idx_res"), [(None, 3), (1, 4), ("all", 0), failparam("fail", 1)] + ) + def test_svd_thresh(self, data, n_pc, idx_res): + """Test svd_tresh.""" + npt.assert_almost_equal( + svd.svd_thresh(data[0], n_pc=n_pc), + 
data[idx_res], + ) + + def test_svd_tresh_invalid_type(self): + """Test svd_tresh failure.""" + npt.assert_raises(TypeError, svd.svd_thresh, 1) + + @pytest.mark.parametrize("operator", [lambda x: x, failparam(0, raises=TypeError)]) + def test_svd_thresh_coef(self, data, operator): + """Test svd_tresh_coef.""" + npt.assert_almost_equal( + svd.svd_thresh_coef(data[0], operator, 0), + data[0], + err_msg="Incorrect SVD coefficient tresholding", + ) + + # TODO test_svd_thresh_coef_fast + + +# TODO: is this module really necessary ? +# It is not use anywhere, not in modopt, nor pysap. +class TestValidation: + """Test validation Module.""" + + array33 = np.arange(9).reshape(3, 3) + + def test_transpose_test(self): + """Test transpose_test.""" + np.random.seed(2) + npt.assert_equal( + validation.transpose_test( + lambda x_val, y_val: x_val.dot(y_val), + lambda x_val, y_val: x_val.dot(y_val.T), + self.array33.shape, + x_args=self.array33, + ), + None, + ) + + +class TestWavelet: + """Test Wavelet Module.""" + + @pytest.fixture + def data(self): + """Set test parameter values.""" + data1 = np.arange(9).reshape(3, 3).astype(float) + data2 = np.arange(36).reshape(4, 3, 3).astype(float) + data3 = np.array( + [ + [ + [6.0, 20, 26.0], + [36.0, 84.0, 84.0], + [90, 164.0, 134.0], + ], + [ + [78.0, 155.0, 134.0], + [225.0, 408.0, 327.0], + [270, 461.0, 350], + ], + [ + [150, 290, 242.0], + [414.0, 732.0, 570], + [450, 758.0, 566.0], + ], + [ + [222.0, 425.0, 350], + [603.0, 1056.0, 813.0], + [630, 1055.0, 782.0], + ], + ] + ) + + data4 = np.array( + [ + [6496.0, 9796.0, 6544.0], + [9924.0, 14910, 9924.0], + [6544.0, 9796.0, 6496.0], + ] + ) + + data5 = np.array( + [ + [[0, 1.0, 4.0], [3.0, 10, 13.0], [6.0, 19.0, 22.0]], + [[3.0, 10, 13.0], [24.0, 46.0, 40], [45.0, 82.0, 67.0]], + [[6.0, 19.0, 22.0], [45.0, 82.0, 67.0], [84.0, 145.0, 112.0]], + ] + ) + return (data1, data2, data3, data4, data5) + + @pytest.mark.parametrize( + ("idx_data", "idx_filter", "idx_res", "filter_rot"), + [ + 
(0, 1, 2, False), (1, 1, 3, True) + ] + ) + def test_filter_convolve(self, data, idx_data, idx_filter, idx_res, filter_rot): + """Test filter_convolve.""" + npt.assert_almost_equal( + wavelet.filter_convolve( + data[idx_data], data[idx_filter], filter_rot=filter_rot + ), + data[idx_res], + err_msg="Inccorect filter comvolution.", + ) + + def test_filter_convolve_stack(self, data): + """Test filter_convolve_stack.""" + npt.assert_almost_equal( + wavelet.filter_convolve_stack(data[0], data[0]), + data[4], + err_msg="Inccorect filter stack comvolution.", + ) From 2299e3ff61f666e72f1e422b76606e9c8957873b Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Wed, 30 Nov 2022 16:17:25 +0100 Subject: [PATCH 09/33] lint --- modopt/tests/test_signal.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py index c46e1339..269b1052 100644 --- a/modopt/tests/test_signal.py +++ b/modopt/tests/test_signal.py @@ -9,7 +9,6 @@ import numpy as np import numpy.testing as npt import pytest - from test_helpers import failparam from modopt.signal import filter, noise, positivity, svd, validation, wavelet @@ -295,9 +294,7 @@ def data(self): @pytest.mark.parametrize( ("idx_data", "idx_filter", "idx_res", "filter_rot"), - [ - (0, 1, 2, False), (1, 1, 3, True) - ] + [(0, 1, 2, False), (1, 1, 3, True)], ) def test_filter_convolve(self, data, idx_data, idx_filter, idx_res, filter_rot): """Test filter_convolve.""" From e04b5efe6e8671e2807f85e0ab4575323bcd9191 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 6 Dec 2022 11:56:29 +0100 Subject: [PATCH 10/33] fix missing parameter. 
--- modopt/tests/test_math.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modopt/tests/test_math.py b/modopt/tests/test_math.py index 86e5c629..ee2cabfd 100644 --- a/modopt/tests/test_math.py +++ b/modopt/tests/test_math.py @@ -295,7 +295,9 @@ class TestStats: ) def test_gaussian_kernel(self, norm, result): """Test gaussian kernel.""" - npt.assert_allclose(stats.gaussian_kernel(self.array33.shape, 1), result) + npt.assert_allclose( + stats.gaussian_kernel(self.array33.shape, 1, norm=norm), result + ) @pytest.mark.skipif(ASTROPY_AVAILABLE, reason="astropy installed") def test_import_astropy(self): From bc4c06ca58058487675a006493088681d3efae23 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 6 Dec 2022 11:56:46 +0100 Subject: [PATCH 11/33] add dummy object test helper. --- modopt/tests/test_helpers/__init__.py | 2 +- modopt/tests/test_helpers/utils.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/modopt/tests/test_helpers/__init__.py b/modopt/tests/test_helpers/__init__.py index 89a177ae..eadbc7bc 100644 --- a/modopt/tests/test_helpers/__init__.py +++ b/modopt/tests/test_helpers/__init__.py @@ -1,2 +1,2 @@ #!/usr/bin/env python3 -from .utils import failparam, skipparam +from .utils import failparam, skipparam, Dummy diff --git a/modopt/tests/test_helpers/utils.py b/modopt/tests/test_helpers/utils.py index bcd544b2..ad3795d5 100644 --- a/modopt/tests/test_helpers/utils.py +++ b/modopt/tests/test_helpers/utils.py @@ -9,3 +9,6 @@ def failparam(*args, raises=ValueError): def skipparam(*args, cond=True, reason=""): """Return a pytest parametrization that should raise an error.""" return pytest.param(*args, marks=pytest.mark.skipif(cond, reason=reason)) + +class Dummy: + pass From 6ed61ba7339b3e223babbe3167c2a60b91dcf8a8 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 6 Dec 2022 11:57:03 +0100 Subject: [PATCH 12/33] rewrite test for cost and gradients. 
--- modopt/tests/test_opt.py | 1161 ++++---------------------------------- 1 file changed, 115 insertions(+), 1046 deletions(-) diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index d5547783..00ece76c 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -1,1071 +1,140 @@ -# -*- coding: utf-8 -*- - """UNIT TESTS FOR OPT. -This module contains unit tests for the modopt.opt module. +This module contains tests for the modopt.opt module. :Author: Samuel Farrens - +:Author: Pierre-Antoine Comby """ -from builtins import zip -from unittest import TestCase, skipIf, skipUnless - import numpy as np import numpy.testing as npt +import pytest +from pytest_cases import parametrize, parametrize_with_cases, case -from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight +from modopt.opt import cost, gradient, linear, proximity, reweight +from test_helpers import Dummy + +SKLEARN_AVAILABLE = True try: import sklearn -except ImportError: # pragma: no cover - import_sklearn = False -else: - import_sklearn = True +except ImportError: + SKLEARN_AVAILABLE = False # Basic functions to be used as operators or as dummy functions func_identity = lambda x_val: x_val func_double = lambda x_val: x_val * 2 -func_sq = lambda x_val: x_val ** 2 -func_cube = lambda x_val: x_val ** 3 - - -class Dummy(object): - """Dummy class for tests.""" - - pass - - -class AlgorithmTestCase(TestCase): - """Test case for algorithms module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = self.data1 + np.random.randn(*self.data1.shape) * 1e-6 - self.data3 = np.arange(9).reshape(3, 3).astype(float) + 1 - - grad_inst = gradient.GradBasic( - self.data1, - func_identity, - func_identity, - ) - - prox_inst = proximity.Positivity() - prox_dual_inst = proximity.IdentityProx() - linear_inst = linear.Identity() - reweight_inst = reweight.cwbReweight(self.data3) - cost_inst = 
cost.costObj([grad_inst, prox_inst, prox_dual_inst]) - self.setup = algorithms.SetUp() - self.max_iter = 20 - - self.fb_all_iter = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=None, - auto_iterate=False, - beta_update=func_identity, - ) - self.fb_all_iter.iterate(self.max_iter) - - self.fb1 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - ) - - self.fb2 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - lambda_update=None, - ) - - self.fb3 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - a_cd=3, - ) - - self.fb4 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - r_lazy=3, - p_lazy=0.7, - q_lazy=0.7, - ) - - self.fb5 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='adaptive', - xi_restart=0.9, - ) - - self.fb6 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='greedy', - xi_restart=0.9, - min_beta=1.0, - s_greedy=1.1, - ) - - self.gfb_all_iter = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=None, - auto_iterate=False, - gamma_update=func_identity, - beta_update=func_identity, - ) - self.gfb_all_iter.iterate(self.max_iter) - - self.gfb1 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - gamma_update=func_identity, - lambda_update=func_identity, - ) - - self.gfb2 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - ) - - self.gfb3 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - step_size=2, - ) - - self.condat_all_iter = 
algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - cost=None, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - auto_iterate=False, - ) - self.condat_all_iter.iterate(self.max_iter) - - self.condat1 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - ) - - self.condat2 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - linear=linear_inst, - cost=cost_inst, - reweight=reweight_inst, - ) - - self.condat3 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - linear=Dummy(), - cost=cost_inst, - auto_iterate=False, - ) - - self.pogm_all_iter = algorithms.POGM( - u=self.data1, - x=self.data1, - y=self.data1, - z=self.data1, - grad=grad_inst, - prox=prox_inst, - auto_iterate=False, - cost=None, - ) - self.pogm_all_iter.iterate(self.max_iter) - - self.pogm1 = algorithms.POGM( - u=self.data1, - x=self.data1, - y=self.data1, - z=self.data1, - grad=grad_inst, - prox=prox_inst, - ) - - self.dummy = Dummy() - self.dummy.cost = func_identity - self.setup._check_operator(self.dummy.cost) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.setup = None - self.fb_all_iter = None - self.fb1 = None - self.fb2 = None - self.gfb_all_iter = None - self.gfb1 = None - self.gfb2 = None - self.condat_all_iter = None - self.condat1 = None - self.condat2 = None - self.condat3 = None - self.pogm1 = None - self.pogm_all_iter = None - self.dummy = None - - def test_set_up(self): - """Test set_up.""" - npt.assert_raises(TypeError, self.setup._check_input_data, 1) - - npt.assert_raises(TypeError, self.setup._check_param, 1) - - npt.assert_raises(TypeError, 
self.setup._check_param_update, 1) - - def test_all_iter(self): - """Test if all opt run for all iterations.""" - opts = [ - self.fb_all_iter, - self.gfb_all_iter, - self.condat_all_iter, - self.pogm_all_iter, - ] - for opt in opts: - npt.assert_equal(opt.idx, self.max_iter - 1) - - def test_forward_backward(self): - """Test forward_backward.""" - npt.assert_array_equal( - self.fb1.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb2.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb3.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb4.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb5.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb6.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - def test_gen_forward_backward(self): - """Test gen_forward_backward.""" - npt.assert_array_equal( - self.gfb1.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb2.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb3.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_equal( - self.gfb3.step_size, - 2, - err_msg='Incorrect step size.', - ) - - npt.assert_raises( - TypeError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=1, - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[1], - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5, 0.5], - ) - - npt.assert_raises( - ValueError, - 
algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5], - ) - - def test_condat(self): - """Test gen_condat.""" - npt.assert_almost_equal( - self.condat1.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) - - npt.assert_almost_equal( - self.condat2.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) - - def test_pogm(self): - """Test pogm.""" - npt.assert_almost_equal( - self.pogm1.x_final, - self.data1, - err_msg='Incorrect POGM result.', - ) - - -class CostTestCase(TestCase): - """Test case for cost module.""" - - def setUp(self): - """Set test parameter values.""" - dummy_inst1 = Dummy() - dummy_inst1.cost = func_sq - dummy_inst2 = Dummy() - dummy_inst2.cost = func_cube - - self.inst1 = cost.costObj([dummy_inst1, dummy_inst2]) - self.inst2 = cost.costObj([dummy_inst1, dummy_inst2], cost_interval=2) - # Test that by default cost of False if interval is None - self.inst_none = cost.costObj( - [dummy_inst1, dummy_inst2], - cost_interval=None, - ) - for _ in range(2): - self.inst1.get_cost(2) - for _ in range(6): - self.inst2.get_cost(2) - self.inst_none.get_cost(2) - self.dummy = Dummy() - - def tearDown(self): - """Unset test parameter values.""" - self.inst = None - - def test_cost_object(self): - """Test cost_object.""" - npt.assert_equal( - self.inst1.get_cost(2), - False, - err_msg='Incorrect cost test result.', - ) - npt.assert_equal( - self.inst1.get_cost(2), - True, - err_msg='Incorrect cost test result.', - ) - npt.assert_equal( - self.inst_none.get_cost(2), - False, - err_msg='Incorrect cost test result.', - ) - - npt.assert_equal(self.inst1.cost, 12, err_msg='Incorrect cost value.') - - npt.assert_equal(self.inst2.cost, 12, err_msg='Incorrect cost value.') - - npt.assert_raises(TypeError, cost.costObj, 1) - - npt.assert_raises(ValueError, cost.costObj, [self.dummy, self.dummy]) - - -class GradientTestCase(TestCase): - """Test case for gradient module.""" - - def setUp(self): - """Set test 
parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.gp = gradient.GradParent( - self.data1, - func_sq, - func_cube, - func_identity, - lambda input_val: 1.0, - data_type=np.floating, - ) - self.gp.grad = self.gp.get_grad(self.data1) - self.gb = gradient.GradBasic( - self.data1, - func_sq, - func_cube, - ) - self.gb.get_grad(self.data1) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.gp = None - self.gb = None - - def test_grad_parent_operators(self): - """Test GradParent.""" - npt.assert_array_equal( - self.gp.op(self.data1), - np.array([[0, 1.0, 4.0], [9.0, 16.0, 25.0], [36.0, 49.0, 64.0]]), - err_msg='Incorrect gradient operation.', - ) - - npt.assert_array_equal( - self.gp.trans_op(self.data1), - np.array( - [[0, 1.0, 8.0], [27.0, 64.0, 125.0], [216.0, 343.0, 512.0]], - ), - err_msg='Incorrect gradient transpose operation.', - ) - - npt.assert_array_equal( - self.gp.trans_op_op(self.data1), - np.array([ - [0, 1.0, 6.40000000e1], - [7.29000000e2, 4.09600000e3, 1.56250000e4], - [4.66560000e4, 1.17649000e5, 2.62144000e5], - ]), - err_msg='Incorrect gradient transpose operation operation.', - ) - - npt.assert_equal( - self.gp.cost(self.data1), - 1.0, - err_msg='Incorrect cost.', - ) - - npt.assert_raises( - TypeError, - gradient.GradParent, - 1, - func_sq, - func_cube, - ) - - def test_grad_basic_gradient(self): - """Test GradBasic.""" - npt.assert_array_equal( - self.gb.grad, - np.array([ +func_sq = lambda x_val: x_val**2 +func_cube = lambda x_val: x_val**3 + + +@case(tags="cost") +@parametrize( + ("cost_interval", "n_calls", "converged"), + [(1, 1, False), (1, 2, True), (2, 5, False), (None, 6, False)], +) +def case_cost_op(cost_interval, n_calls, converged): + """Case function for costs.""" + dummy_inst1 = Dummy() + dummy_inst1.cost = func_sq + dummy_inst2 = Dummy() + dummy_inst2.cost = func_cube + + cost_obj = cost.costObj([dummy_inst1, dummy_inst2], cost_interval=cost_interval) + + for _ in 
range(n_calls + 1): + cost_obj.get_cost(2) + return cost_obj, converged + + +@parametrize_with_cases("cost_obj, converged", cases=".", has_tag="cost") +def test_costs(cost_obj, converged): + """Test cost.""" + npt.assert_equal(cost_obj.get_cost(2), converged) + if cost_obj._cost_interval: + npt.assert_equal(cost_obj.cost, 12) + + +def test_raise_cost(): + """Test error raising for cost.""" + npt.assert_raises(TypeError, cost.costObj, 1) + npt.assert_raises(ValueError, cost.costObj, [Dummy(), Dummy()]) + + +@case(tags="grad") +@parametrize(call=("op", "trans_op", "trans_op_op")) +def case_grad_parent(call): + """Case for gradient parent.""" + input_data = np.arange(9).reshape(3, 3) + callables = { + "op": func_sq, + "trans_op": func_cube, + "get_grad": func_identity, + "cost": lambda input_val: 1.0, + } + + grad_op = gradient.GradParent( + input_data, + **callables, + data_type=np.floating, + ) + if call != "trans_op_op": + result = callables[call](input_data) + else: + result = callables["trans_op"](callables["op"](input_data)) + + grad_call = getattr(grad_op, call)(input_data) + return grad_call, result + + +@parametrize_with_cases("grad_values, result", cases=".", has_tag="grad") +def test_grad_op(grad_values, result): + """Test Gradient operator.""" + npt.assert_equal(grad_values, result) + + +@pytest.fixture +def grad_basic(): + """Case for GradBasic.""" + input_data = np.arange(9).reshape(3, 3) + grad_op = gradient.GradBasic( + input_data, + func_sq, + func_cube, + verbose=True, + ) + grad_op.get_grad(input_data) + return grad_op + + +def test_grad_basic(grad_basic): + """Test grad basic.""" + npt.assert_array_equal( + grad_basic.grad, + np.array( + [ [0, 0, 8.0], [2.16000000e2, 1.72800000e3, 8.0e3], [2.70000000e4, 7.40880000e4, 1.75616000e5], - ]), - err_msg='Incorrect gradient.', - ) - - -class LinearTestCase(TestCase): - """Test case for linear module.""" - - def setUp(self): - """Set test parameter values.""" - self.parent = linear.LinearParent( - func_sq, 
- func_cube, - ) - self.ident = linear.Identity() - filters = np.arange(8).reshape(2, 2, 2).astype(float) - self.wave = linear.WaveletConvolve(filters) - self.combo = linear.LinearCombo([self.parent, self.parent]) - self.combo_weight = linear.LinearCombo( - [self.parent, self.parent], - [1.0, 1.0], - ) - self.data1 = np.arange(18).reshape(2, 3, 3).astype(float) - self.data2 = np.arange(4).reshape(1, 2, 2).astype(float) - self.data3 = np.arange(8).reshape(1, 2, 2, 2).astype(float) - self.data4 = np.array([[[[0, 0], [0, 4.0]], [[0, 4.0], [8.0, 28.0]]]]) - self.data5 = np.array([[[28.0, 62.0], [68.0, 140.0]]]) - self.dummy = Dummy() - - def tearDown(self): - """Unset test parameter values.""" - self.parent = None - self.ident = None - self.combo = None - self.combo_weight = None - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - self.dummy = None - - def test_linear_parent(self): - """Test LinearParent.""" - npt.assert_equal( - self.parent.op(2), - 4, - err_msg='Incorrect linear parent operation.', - ) - - npt.assert_equal( - self.parent.adj_op(2), - 8, - err_msg='Incorrect linear parent adjoint operation.', - ) - - npt.assert_raises(TypeError, linear.LinearParent, 0, 0) - - def test_identity(self): - """Test Identity.""" - npt.assert_equal( - self.ident.op(1.0), - 1.0, - err_msg='Incorrect identity operation.', - ) - - npt.assert_equal( - self.ident.adj_op(1.0), - 1.0, - err_msg='Incorrect identity adjoint operation.', - ) - - def test_wavelet_convolve(self): - """Test WaveletConvolve.""" - npt.assert_almost_equal( - self.wave.op(self.data2), - self.data4, - err_msg='Incorrect wavelet convolution operation.', - ) - - npt.assert_almost_equal( - self.wave.adj_op(self.data3), - self.data5, - err_msg='Incorrect wavelet convolution adjoint operation.', - ) - - def test_linear_combo(self): - """Test LinearCombo.""" - npt.assert_equal( - self.combo.op(2), - np.array([4, 4]).astype(object), - err_msg='Incorrect combined 
linear operation', - ) - - npt.assert_equal( - self.combo.adj_op([2, 2]), - 8.0, - err_msg='Incorrect combined linear adjoint operation', - ) - - npt.assert_raises(TypeError, linear.LinearCombo, self.parent) - - npt.assert_raises(ValueError, linear.LinearCombo, []) - - npt.assert_raises(ValueError, linear.LinearCombo, [self.dummy]) - - self.dummy.op = func_identity - - npt.assert_raises(ValueError, linear.LinearCombo, [self.dummy]) - - def test_linear_combo_weight(self): - """Test LinearCombo with weight .""" - npt.assert_equal( - self.combo_weight.op(2), - np.array([4, 4]).astype(object), - err_msg='Incorrect combined linear operation', - ) - - npt.assert_equal( - self.combo_weight.adj_op([2, 2]), - 16.0, - err_msg='Incorrect combined linear adjoint operation', - ) - - npt.assert_raises( - ValueError, - linear.LinearCombo, - [self.parent, self.parent], - [1.0], - ) - - npt.assert_raises( - TypeError, - linear.LinearCombo, - [self.parent, self.parent], - ['1', '1'], - ) - - -class ProximityTestCase(TestCase): - """Test case for proximity module.""" - - def setUp(self): - """Set test parameter values.""" - self.parent = proximity.ProximityParent( - func_sq, - func_double, - ) - self.identity = proximity.IdentityProx() - self.positivity = proximity.Positivity() - weights = np.ones(9).reshape(3, 3).astype(float) * 3 - self.sparsethresh = proximity.SparseThreshold( - linear.Identity(), - weights, - ) - self.lowrank = proximity.LowRankMatrix(10.0, thresh_type='hard') - self.lowrank_rank = proximity.LowRankMatrix( - 10.0, - initial_rank=1, - thresh_type='hard', - ) - self.lowrank_ngole = proximity.LowRankMatrix( - 10.0, - lowr_type='ngole', - operator=func_double, - ) - self.linear_comp = proximity.LinearCompositionProx( - linear_op=linear.Identity(), - prox_op=self.sparsethresh, - ) - self.combo = proximity.ProximityCombo([self.identity, self.positivity]) - if import_sklearn: - self.owl = proximity.OrderedWeightedL1Norm(weights.flatten()) - self.ridge = 
proximity.Ridge(linear.Identity(), weights) - self.elasticnet_alpha0 = proximity.ElasticNet( - linear.Identity(), - alpha=0, - beta=weights, - ) - self.elasticnet_beta0 = proximity.ElasticNet( - linear.Identity(), - alpha=weights, - beta=0, - ) - self.one_support = proximity.KSupportNorm(beta=0.2, k_value=1) - self.five_support_norm = proximity.KSupportNorm(beta=3, k_value=5) - self.d_support = proximity.KSupportNorm(beta=3.0 * 2, k_value=19) - self.group_lasso = proximity.GroupLASSO( - weights=np.tile(weights, (4, 1, 1)), - ) - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = np.array([[-0, -0, -0], [0, 1.0, 2.0], [3.0, 4.0, 5.0]]) - self.data3 = np.arange(18).reshape(2, 3, 3).astype(float) - self.data4 = np.array([ - [ - [2.73843189, 3.14594066, 3.55344943], - [3.9609582, 4.36846698, 4.77597575], - [5.18348452, 5.59099329, 5.99850206], - ], - [ - [8.07085295, 9.2718846, 10.47291625], - [11.67394789, 12.87497954, 14.07601119], - [15.27704284, 16.47807449, 17.67910614], - ], - ]) - self.data5 = np.array([ - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - [ - [4.00795282, 4.60438026, 5.2008077], - [5.79723515, 6.39366259, 6.99009003], - [7.58651747, 8.18294492, 8.77937236], - ], - ]) - self.data6 = self.data3 * -1 - self.data7 = self.combo.op(self.data6) - self.data8 = np.empty(2, dtype=np.ndarray) - self.data8[0] = np.array( - [[-0, -1.0, -2.0], [-3.0, -4.0, -5.0], [-6.0, -7.0, -8.0]], - ) - self.data8[1] = np.array( - [[-0, -0, -0], [-0, -0, -0], [-0, -0, -0]], - ) - self.data9 = self.data1 * (1 + 1j) - self.data10 = self.data9 / (2 * 3 + 1) - self.data11 = np.asarray( - [[0, 0, 0], [0, 1.0, 1.25], [1.5, 1.75, 2.0]], - ) - self.random_data = 3 * np.random.random( - self.group_lasso.weights[0].shape, - ) - self.random_data_tile = np.tile( - self.random_data, - (self.group_lasso.weights.shape[0], 1, 1), - ) - self.gl_result_data = 2 * self.random_data_tile - 3 - self.gl_result_data = np.array( - (self.gl_result_data * (self.gl_result_data > 
0).astype('int')) - / 2, - ) - - self.dummy = Dummy() - - def tearDown(self): - """Unset test parameter values.""" - self.parent = None - self.identity = None - self.positivity = None - self.sparsethresh = None - self.lowrank = None - self.lowrank_rank = None - self.lowrank_ngole = None - self.combo = None - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - self.data6 = None - self.data7 = None - self.data8 = None - self.dummy = None - self.random_data = None - self.random_data_tile = None - self.gl_result_data = None - - def test_proximity_parent(self): - """Test ProximityParent.""" - npt.assert_equal( - self.parent.op(3), - 9, - err_msg='Inccoret proximity parent operation.', - ) - - npt.assert_equal( - self.parent.cost(3), - 6, - err_msg='Incorrect proximity parent cost.', - ) - - def test_identity(self): - """Test IdentityProx.""" - npt.assert_equal( - self.identity.op(3), - 3, - err_msg='Incorrect proximity identity operation.', - ) - - npt.assert_equal( - self.identity.cost(3), - 0, - err_msg='Incorrect proximity identity cost.', - ) - - def test_positivity(self): - """Test Positivity.""" - npt.assert_equal( - self.positivity.op(-3), - 0, - err_msg='Incorrect proximity positivity operation.', - ) - - npt.assert_equal( - self.positivity.cost(-3, verbose=True), - 0, - err_msg='Incorrect proximity positivity cost.', - ) - - def test_sparse_threshold(self): - """Test SparseThreshold.""" - npt.assert_array_equal( - self.sparsethresh.op(self.data1), - self.data2, - err_msg='Incorrect sparse threshold operation.', - ) - - npt.assert_equal( - self.sparsethresh.cost(self.data1, verbose=True), - 108.0, - err_msg='Incorrect sparse threshold cost.', - ) - - def test_low_rank_matrix(self): - """Test LowRankMatrix.""" - npt.assert_almost_equal( - self.lowrank.op(self.data3), - self.data4, - err_msg='Incorrect low rank operation: standard', - ) - - npt.assert_almost_equal( - self.lowrank_rank.op(self.data3), - self.data4, - 
err_msg='Incorrect low rank operation: standard with rank', - ) - npt.assert_almost_equal( - self.lowrank_ngole.op(self.data3), - self.data5, - err_msg='Incorrect low rank operation: ngole', - ) - - npt.assert_almost_equal( - self.lowrank.cost(self.data3, verbose=True), - 469.39132942464983, - err_msg='Incorrect low rank cost.', - ) - - def test_linear_comp_prox(self): - """Test LinearCompositionProx.""" - npt.assert_array_equal( - self.linear_comp.op(self.data1), - self.data2, - err_msg='Incorrect sparse threshold operation.', - ) - - npt.assert_equal( - self.linear_comp.cost(self.data1, verbose=True), - 108.0, - err_msg='Incorrect sparse threshold cost.', - ) - - def test_proximity_combo(self): - """Test ProximityCombo.""" - for data7, data8 in zip(self.data7, self.data8): - npt.assert_array_equal( - data7, - data8, - err_msg='Incorrect combined operation', - ) - - npt.assert_equal( - self.combo.cost(self.data6), - 0, - err_msg='Incorrect combined cost.', - ) - - npt.assert_raises(TypeError, proximity.ProximityCombo, 1) - - npt.assert_raises(ValueError, proximity.ProximityCombo, []) - - npt.assert_raises(ValueError, proximity.ProximityCombo, [self.dummy]) - - self.dummy.op = func_identity - - npt.assert_raises(ValueError, proximity.ProximityCombo, [self.dummy]) - - @skipIf(import_sklearn, 'sklearn is installed.') # pragma: no cover - def test_owl_sklearn_error(self): - """Test OrderedWeightedL1Norm with Scikit-Learn.""" - npt.assert_raises(ImportError, proximity.OrderedWeightedL1Norm, 1) - - @skipUnless(import_sklearn, 'sklearn not installed.') # pragma: no cover - def test_sparse_owl(self): - """Test OrderedWeightedL1Norm.""" - npt.assert_array_equal( - self.owl.op(self.data1.flatten()), - self.data2.flatten(), - err_msg='Incorrect sparse threshold operation.', - ) - - npt.assert_equal( - self.owl.cost(self.data1.flatten(), verbose=True), - 108.0, - err_msg='Incorrect sparse threshold cost.', - ) - - npt.assert_raises( - ValueError, - 
proximity.OrderedWeightedL1Norm, - np.arange(10), - ) - - def test_ridge(self): - """Test Ridge.""" - npt.assert_array_equal( - self.ridge.op(self.data9), - self.data10, - err_msg='Incorect shrinkage operation.', - ) - - npt.assert_equal( - self.ridge.cost(self.data9, verbose=True), - 408.0 * 3.0, - err_msg='Incorect shrinkage cost.', - ) - - def test_elastic_net_alpha0(self): - """Test ElasticNet.""" - npt.assert_array_equal( - self.elasticnet_alpha0.op(self.data1), - self.data2, - err_msg='Incorect sparse threshold operation ElasticNet class.', - ) - - npt.assert_equal( - self.elasticnet_alpha0.cost(self.data1), - 108.0, - err_msg='Incorect shrinkage cost in ElasticNet class.', - ) - - def test_elastic_net_beta0(self): - """Test ElasticNet with beta=0.""" - npt.assert_array_equal( - self.elasticnet_beta0.op(self.data9), - self.data10, - err_msg='Incorect ridge operation ElasticNet class.', - ) - - npt.assert_equal( - self.elasticnet_beta0.cost(self.data9, verbose=True), - 408.0 * 3.0, - err_msg='Incorect shrinkage cost in ElasticNet class.', - ) - - def test_one_support_norm(self): - """Test KSupportNorm with k=1.""" - npt.assert_allclose( - self.one_support.op(self.data1.flatten()), - self.data2.flatten(), - err_msg='Incorect sparse threshold operation for 1-support norm', - rtol=1e-6, - ) - - npt.assert_equal( - self.one_support.cost(self.data1.flatten(), verbose=True), - 259.2, - err_msg='Incorect sparse threshold cost.', - ) - - npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) - - def test_five_support_norm(self): - """Test KSupportNorm with k=5.""" - npt.assert_allclose( - self.five_support_norm.op(self.data1.flatten()), - self.data11.flatten(), - err_msg='Incorect sparse Ksupport norm operation', - rtol=1e-6, - ) - - npt.assert_equal( - self.five_support_norm.cost(self.data1.flatten(), verbose=True), - 684.0, - err_msg='Incorrect 5-support norm cost.', - ) - - npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) - - def 
test_d_support_norm(self): - """Test KSupportNorm with k=19.""" - npt.assert_allclose( - self.d_support.op(self.data9.flatten()), - self.data10.flatten(), - err_msg='Incorect shrinkage operation for d-support norm', - rtol=1e-6, - ) - - npt.assert_almost_equal( - self.d_support.cost(self.data9.flatten(), verbose=True), - 408.0 * 3.0, - err_msg='Incorrect shrinkage cost for d-support norm.', - ) - - npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) - - def test_group_lasso(self): - """Test GroupLASSO.""" - npt.assert_allclose( - self.group_lasso.op(self.random_data_tile), - self.gl_result_data, - ) - npt.assert_equal( - self.group_lasso.cost(self.random_data_tile), - np.sum(6 * self.random_data_tile), - ) - # Check that for 0 weights operator doesnt change result - self.group_lasso.weights = np.zeros_like(self.group_lasso.weights) - npt.assert_equal( - self.group_lasso.op(self.random_data_tile), - self.random_data_tile, - ) - npt.assert_equal(self.group_lasso.cost(self.random_data_tile), 0) - + ] + ), + err_msg="Incorrect gradient.", + ) -class ReweightTestCase(TestCase): - """Test case for reweight module.""" - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) + 1 - self.data2 = np.array( - [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0], [3.5, 4.0, 4.5]], - ) - self.rw = reweight.cwbReweight(self.data1) - self.rw.reweight(self.data1) +def test_grad_basic_cost(grad_basic): + """Test grad_basic cost.""" + npt.assert_almost_equal(grad_basic.cost(np.arange(9).reshape(3,3)), 3192.0) - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.rw = None - def test_cwbreweight(self): - """Test cwbReweight.""" - npt.assert_array_equal( - self.rw.weights, - self.data2, - err_msg='Incorrect CWB re-weighting.', - ) - npt.assert_raises(ValueError, self.rw.reweight, self.data1[0]) +def test_grad_op_raises(): + """Test raise error.""" + npt.assert_raises( + TypeError, + 
gradient.GradParent, + 1, + func_sq, + func_cube, + ) From d4786e85efa18fdf84ee4a29aa938c429155efea Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 6 Dec 2022 11:57:26 +0100 Subject: [PATCH 13/33] show missing lines in coverage reports --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index f8991e7e..13f9dc73 100644 --- a/setup.cfg +++ b/setup.cfg @@ -85,7 +85,7 @@ testpaths = addopts = --verbose --cov=modopt - --cov-report=term + --cov-report=term-missing --cov-report=xml --junitxml=pytest.xml --pydocstyle From bee267b7df3a862bbe7eff34efb75981042ea0d7 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 6 Dec 2022 17:59:01 +0100 Subject: [PATCH 14/33] rewrite of proximity operators testing. --- modopt/tests/test_opt.py | 374 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 371 insertions(+), 3 deletions(-) diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index 00ece76c..cd183e76 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -9,7 +9,7 @@ import numpy as np import numpy.testing as npt import pytest -from pytest_cases import parametrize, parametrize_with_cases, case +from pytest_cases import parametrize, parametrize_with_cases, case, fixture, fixture_ref from modopt.opt import cost, gradient, linear, proximity, reweight @@ -125,8 +125,7 @@ def test_grad_basic(grad_basic): def test_grad_basic_cost(grad_basic): """Test grad_basic cost.""" - npt.assert_almost_equal(grad_basic.cost(np.arange(9).reshape(3,3)), 3192.0) - + npt.assert_almost_equal(grad_basic.cost(np.arange(9).reshape(3, 3)), 3192.0) def test_grad_op_raises(): @@ -138,3 +137,372 @@ def test_grad_op_raises(): func_sq, func_cube, ) + + +############# +# LINEAR OP # +############# + + +@case(tags="linear") +def case_linear_identity(): + """Case linear operator identity.""" + linop = linear.Identity() + + data_op, data_adj_op, res_op, res_adj_op = 1, 1, 1, 1 + + return linop, data_op, 
data_adj_op, res_op, res_adj_op + + +@case(tags="linear") +def case_linear_wavelet(): + """Case linear operator wavelet.""" + linop = linear.WaveletConvolve(filters=np.arange(8).reshape(2, 2, 2).astype(float)) + data_op = np.arange(4).reshape(1, 2, 2).astype(float) + data_adj_op = np.arange(8).reshape(1, 2, 2, 2).astype(float) + res_op = np.array([[[[0, 0], [0, 4.0]], [[0, 4.0], [8.0, 28.0]]]]) + res_adj_op = np.array([[[28.0, 62.0], [68.0, 140.0]]]) + + return linop, data_op, data_adj_op, res_op, res_adj_op + + +@case(tags="linear") +def case_linear_combo(): + """Case linear operator combo.""" + parent = linear.LinearParent( + func_sq, + func_cube, + ) + linop = linear.LinearCombo([parent, parent], [1.0, 1.0]) + + data_op, data_adj_op, res_op, res_adj_op = ( + 2, + np.array([2, 2]), + np.array([4, 4]), + 16.0, + ) + + return linop, data_op, data_adj_op, res_op, res_adj_op + + +@case(tags="linear") +def case_linear_combo_weight(): + """Case linear operator combo with weights.""" + parent = linear.LinearParent( + func_sq, + func_cube, + ) + linop = linear.LinearCombo([parent, parent], [1.0, 1.0]) + + data_op, data_adj_op, res_op, res_adj_op = ( + 2, + np.array([2, 2]), + np.array([4, 4]), + 16.0, + ) + + return linop, data_op, data_adj_op, res_op, res_adj_op + + +@fixture +@parametrize_with_cases( + "linop, data_op, data_adj_op, res_op, res_adj_op", cases=".", has_tag="linear" +) +def lin_adj_op(linop, data_op, data_adj_op, res_op, res_adj_op): + """Get adj_op relative data.""" + return linop.adj_op, data_adj_op, res_adj_op + + +@fixture +@parametrize_with_cases( + "linop, data_op, data_adj_op, res_op, res_adj_op", cases=".", has_tag="linear" +) +def lin_op(linop, data_op, data_adj_op, res_op, res_adj_op): + """Get op relative data.""" + return linop.op, data_op, res_op + + +@parametrize( + ("action", "data", "result"), [fixture_ref(lin_op), fixture_ref(lin_adj_op)] +) +def test_linear_operator(action, data, result): + """Test linear operator.""" + 
npt.assert_almost_equal(action(data), result) + + +dummy_with_op = Dummy() +dummy_with_op.op = lambda x: x + + +@pytest.mark.parametrize( + ("args", "error"), + [ + ([linear.LinearParent(func_sq, func_cube)], TypeError), + ([[]], ValueError), + ([[Dummy()]], ValueError), + ([[dummy_with_op]], ValueError), + ([[]], ValueError), + ([[linear.LinearParent(func_sq, func_cube)] * 2, [1.0]], ValueError), + ([[linear.LinearParent(func_sq, func_cube)] * 2, ["1", "1"]], TypeError), + ], +) +def test_linear_combo_errors(args, error): + """Test linear combo_errors.""" + npt.assert_raises(error, linear.LinearCombo, *args) + + +############# +# Proximity # +############# + + +class ProxCases: + """Class containing all proximal operator cases. + + Each case should return 4 parameters: + 1. The proximal operator + 2. test input data + 3. Expected result data + 4. Expected cost value. + """ + + weights = np.ones(9).reshape(3, 3).astype(float) * 3 + array33 = np.arange(9).reshape(3, 3).astype(float) + array33_st = np.array([[-0, -0, -0], [0, 1.0, 2.0], [3.0, 4.0, 5.0]]) + array33_st2 = array33_st * -1 + + array33_support = np.asarray([[0, 0, 0], [0, 1.0, 1.25], [1.5, 1.75, 2.0]]) + + array233 = np.arange(18).reshape(2, 3, 3).astype(float) + array233_2 = np.array( + [ + [ + [2.73843189, 3.14594066, 3.55344943], + [3.9609582, 4.36846698, 4.77597575], + [5.18348452, 5.59099329, 5.99850206], + ], + [ + [8.07085295, 9.2718846, 10.47291625], + [11.67394789, 12.87497954, 14.07601119], + [15.27704284, 16.47807449, 17.67910614], + ], + ] + ) + array233_3 = np.array( + [ + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [ + [4.00795282, 4.60438026, 5.2008077], + [5.79723515, 6.39366259, 6.99009003], + [7.58651747, 8.18294492, 8.77937236], + ], + ] + ) + + def case_prox_parent(self): + """Case prox parent.""" + return ( + proximity.ProximityParent( + func_sq, + func_double, + ), + 3, + 9, + 6, + ) + + def case_prox_identity(self): + """Case prox identity.""" + return proximity.IdentityProx(), 3, 3, 0 + + 
def case_prox_positivity(self): + """Case prox positivity.""" + return proximity.Positivity(), -3, 0, 0 + + def case_prox_sparsethresh(self): + """Case prox sparsethreshosld.""" + return ( + proximity.SparseThreshold(linear.Identity(), weights=self.weights), + self.array33, + self.array33_st, + 108, + ) + + @parametrize( + "lowr_type, initial_rank, operator, result, cost", + [ + ("standard", None, None, array233_2, 469.3913294246498), + ("standard", 1, None, array233_2, 469.3913294246498), + ("ngole", None, func_double, array233_3, 469.3913294246498), + ], + ) + def case_prox_lowrank(self, lowr_type, initial_rank, operator, result, cost): + """Case prox lowrank.""" + return ( + proximity.LowRankMatrix( + 10, + lowr_type=lowr_type, + initial_rank=initial_rank, + operator=operator, + thresh_type="hard" if lowr_type == "standard" else "soft", + ), + self.array233, + result, + cost, + ) + + def case_prox_linear_comp(self): + """Case prox linear comp.""" + return ( + proximity.LinearCompositionProx( + linear_op=linear.Identity(), prox_op=self.case_prox_sparsethresh()[0] + ), + self.array33, + self.array33_st, + 108, + ) + + def case_prox_ridge(self): + """Case prox ridge.""" + return ( + proximity.Ridge(linear.Identity(), self.weights), + self.array33 * (1 + 1j), + self.array33 * (1 + 1j) / 7, + 1224, + ) + + @parametrize("alpha, beta", [(0, weights), (weights, 0)]) + def case_prox_elasticnet(self, alpha, beta): + """Case prox elastic net.""" + if np.all(alpha == 0): + data = self.case_prox_sparsethresh()[1:] + else: + data = self.case_prox_ridge()[1:] + return (proximity.ElasticNet(linear.Identity(), alpha, beta), *data) + + @parametrize( + "beta, k_value, data, result, cost", + [ + (0.2, 1, array33.flatten(), array33_st.flatten(), 259.2), + (3, 5, array33.flatten(), array33_support.flatten(), 684.0), + ( + 6.0, + 9, + array33.flatten() * (1 + 1j), + array33.flatten() * (1 + 1j) / 7, + 1224, + ), + ], + ) + def case_prox_Ksupport(self, beta, k_value, data, result, 
cost): + """Case prox Ksupport norm.""" + return (proximity.KSupportNorm(beta=beta, k_value=k_value), data, result, cost) + + @parametrize(use_weights=[True, False]) + def case_prox_grouplasso(self, use_weights): + """Case GroupLasso proximity.""" + if use_weights: + weights = np.tile(self.weights, (4, 1, 1)) + else: + weights = np.tile(np.zeros((3, 3)), (4, 1, 1)) + + random_data = 3 * np.random.random(weights[0].shape) + random_data_tile = np.tile(random_data, (weights.shape[0], 1, 1)) + if use_weights: + gl_result_data = 2 * random_data_tile - 3 + gl_result_data = ( + np.array(gl_result_data * (gl_result_data > 0).astype("int")) / 2 + ) + cost = np.sum(random_data_tile) * 6 + else: + gl_result_data = random_data_tile + cost = 0 + return ( + proximity.GroupLASSO( + weights=weights, + ), + random_data_tile, + gl_result_data, + cost, + ) + + @pytest.mark.skipif(not SKLEARN_AVAILABLE, reason="sklearn not available.") + def case_prox_owl(self): + """Case prox for owl.""" + return ( + proximity.OrderedWeightedL1Norm(self.weights.flatten()), + self.array33.flatten(), + self.array33_st.flatten(), + 108.0, + ) + + +@parametrize_with_cases("operator, input_data, op_result, cost_result", cases=ProxCases) +def test_prox_op(operator, input_data, op_result, cost_result): + """Test proximity operator op.""" + npt.assert_almost_equal(operator.op(input_data), op_result) + + +@parametrize_with_cases("operator, input_data, op_result, cost_result", cases=ProxCases) +def test_prox_cost(operator, input_data, op_result, cost_result): + """Test proximity operator cost.""" + npt.assert_almost_equal(operator.cost(input_data, verbose=True), cost_result) + + +@parametrize( + "arg, error", + [ + (1, TypeError), + ([], ValueError), + ([Dummy()], ValueError), + ([dummy_with_op], ValueError), + ], +) +def test_error_prox_combo(arg, error): + """Test errors for proximity combo.""" + npt.assert_raises(error, proximity.ProximityCombo, arg) + + +@pytest.mark.skipif(SKLEARN_AVAILABLE, 
reason="sklearn is installed") +def test_fail_sklearn(): + """Test fail OWL wit sklearn.""" + npt.assert_raises(ImportError, proximity.OrderedWeightedL1Norm, 1) + + +def test_fail_owl(): + """Test fail owl.""" + npt.assert_raises( + ValueError, + proximity.OrderedWeightedL1Norm, + np.arange(10), + ) + + npt.assert_raises( + ValueError, + proximity.OrderedWeightedL1Norm, + -np.arange(10), + ) + + +def test_fail_Ksupport_norm(): + """Test fail for Ksupport norm.""" + npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) + + +def test_reweight(): + """Test for reweight module.""" + data1 = np.arange(9).reshape(3, 3).astype(float) + 1 + data2 = np.array( + [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0], [3.5, 4.0, 4.5]], + ) + + rw = reweight.cwbReweight(data1) + rw.reweight(data1) + + npt.assert_array_equal( + rw.weights, + data2, + err_msg="Incorrect CWB re-weighting.", + ) + + npt.assert_raises(ValueError, rw.reweight, data1[0]) From 611a0641b62f28c7eb46a8a7a11a1732c50f3c4c Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 6 Dec 2022 18:13:11 +0100 Subject: [PATCH 15/33] add fail low rank method. 
--- modopt/tests/test_opt.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index cd183e76..a74830f0 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -484,6 +484,12 @@ def test_fail_owl(): ) +def test_fail_lowrank(): + """Test fail for lowrnk.""" + prox_op = proximity.LowRankMatrix(10, lowr_type="fail") + npt.assert_raises(ValueError, prox_op.op, 0) + + def test_fail_Ksupport_norm(): """Test fail for Ksupport norm.""" npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) From c128c811bb5fb3208cc705066aba0fcf773b430d Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 6 Dec 2022 19:02:21 +0100 Subject: [PATCH 16/33] add cases for algorithms test --- modopt/tests/test_algorithms.py | 571 ++++++++------------------------ 1 file changed, 137 insertions(+), 434 deletions(-) diff --git a/modopt/tests/test_algorithms.py b/modopt/tests/test_algorithms.py index 7ff96a8b..f5a688a7 100644 --- a/modopt/tests/test_algorithms.py +++ b/modopt/tests/test_algorithms.py @@ -1,470 +1,173 @@ # -*- coding: utf-8 -*- -"""UNIT TESTS FOR OPT.ALGORITHMS. +"""UNIT TESTS FOR Algorithms. -This module contains unit tests for the modopt.opt.algorithms module. +This module contains unit tests for the modopt.opt module. 
:Author: Samuel Farrens - +:Author: Pierre-Antoine Comby """ -from unittest import TestCase - import numpy as np import numpy.testing as npt - from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight - -# Basic functions to be used as operators or as dummy functions -func_identity = lambda x_val: x_val -func_double = lambda x_val: x_val * 2 -func_sq = lambda x_val: x_val ** 2 -func_cube = lambda x_val: x_val ** 3 - - -class Dummy(object): - """Dummy class for tests.""" - - pass - - -class AlgorithmTestCase(TestCase): - """Test case for algorithms module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = self.data1 + np.random.randn(*self.data1.shape) * 1e-6 - self.data3 = np.arange(9).reshape(3, 3).astype(float) + 1 - - grad_inst = gradient.GradBasic( - self.data1, - func_identity, - func_identity, - ) - +from pytest_cases import (case, fixture, fixture_ref, parametrize, + parametrize_with_cases) + +from test_helpers import Dummy + +SKLEARN_AVAILABLE = True +try: + import sklearn +except ImportError: + SKLEARN_AVAILABLE = False + + +idty = lambda x_val: x_val + + +class AlgoCases: + """Cases for algorithms.""" + + data1 = np.arange(9).reshape(3, 3).astype(float) + data2 = data1 + np.random.randn(*data1.shape) * 1e-6 + data3 = np.arange(9).reshape(3, 3).astype(float) + 1 + max_iter = 20 + + @parametrize( + kwargs=[ + {"beta_update": idty, "auto_iterate": False, "cost": None}, + {"beta_update": idty}, + {"cost": None, "lambda_update": None}, + {"beta_update": idty, "a_cd": 3}, + {"beta_update": idty, "r_lazy": 3, "p_lazy": 0.7, "q_lazy": 0.7}, + {"restart_strategy": "adaptive", "xi_restart": 0.9}, + { + "restart_strategy": "greedy", + "xi_restart": 0.9, + "min_beta": 1.0, + "s_greedy": 1.1, + }, + ] + ) + def case_forward_backward(self, kwargs): + """Forward Backward case.""" + algo = algorithms.ForwardBackward( + grad=gradient.GradBasic(self.data1, idty, idty), + 
prox=proximity.Positivity(), + **kwargs, + ) + if kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) + return algo + + @parametrize( + kwargs=[ + { + "cost": None, + "auto_iterate": False, + "gamma_update": idty, + "beta_update": idty, + }, + {"gamma_update": idty, "lambda_update": idty}, + {"cost": True}, + {"cost": True, "step_size": 2}, + ] + ) + def case_gen_forward_backward(self, kwargs): + """General FB setup.""" + grad_inst = gradient.GradBasic(self.data1, idty, idty) prox_inst = proximity.Positivity() prox_dual_inst = proximity.IdentityProx() - linear_inst = linear.Identity() - reweight_inst = reweight.cwbReweight(self.data3) - cost_inst = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) - self.setup = algorithms.SetUp() - self.max_iter = 20 - - self.fb_all_iter = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=None, - auto_iterate=False, - beta_update=func_identity, - ) - self.fb_all_iter.iterate(self.max_iter) - - self.fb1 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - ) - - self.fb2 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - lambda_update=None, - ) - - self.fb3 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - a_cd=3, - ) - - self.fb4 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - r_lazy=3, - p_lazy=0.7, - q_lazy=0.7, - ) - - self.fb5 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='adaptive', - xi_restart=0.9, - ) - - self.fb6 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='greedy', - xi_restart=0.9, - min_beta=1.0, - s_greedy=1.1, - ) - - self.gfb_all_iter = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, 
prox_dual_inst], - cost=None, - auto_iterate=False, - gamma_update=func_identity, - beta_update=func_identity, - ) - self.gfb_all_iter.iterate(self.max_iter) - - self.gfb1 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - gamma_update=func_identity, - lambda_update=func_identity, - ) - - self.gfb2 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - ) - - self.gfb3 = algorithms.GenForwardBackward( - self.data1, + if kwargs.get("cost", None) is True: + kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) + algo = algorithms.GenForwardBackward( grad=grad_inst, prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - step_size=2, - ) - - self.condat_all_iter = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - cost=None, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - auto_iterate=False, - ) - self.condat_all_iter.iterate(self.max_iter) - - self.condat1 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - ) - - self.condat2 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - linear=linear_inst, - cost=cost_inst, - reweight=reweight_inst, - ) + **kwargs, + ) + if kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) + return algo + + @parametrize( + kwargs=[ + { + "sigma_dual": idty, + "tau_update": idty, + "rho_update": idty, + "auto_iterate": False, + }, + { + "sigma_dual": idty, + "tau_update": idty, + "rho_update": idty, + }, + { + "linear": linear.Identity(), + "cost": True, + "reweight": reweight.cwbReweight(data3), + }, + ] + ) + def case_condat(self, kwargs): + """Condat Vu Algorithm 
setup.""" + grad_inst = gradient.GradBasic(self.data1, idty, idty) + prox_inst = proximity.Positivity() + prox_dual_inst = proximity.IdentityProx() + if kwargs.get("cost", None) is True: + kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) - self.condat3 = algorithms.Condat( + algo = algorithms.Condat( self.data1, self.data2, grad=grad_inst, prox=prox_inst, prox_dual=prox_dual_inst, - linear=Dummy(), - cost=cost_inst, - auto_iterate=False, + **kwargs, ) + if kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) + return algo - self.pogm_all_iter = algorithms.POGM( + @parametrize(kwargs=[{"auto_iterate": False, "cost": None}, {}]) + def case_pogm(self, kwargs): + """POGM setup.""" + grad_inst = gradient.GradBasic(self.data1, idty, idty) + prox_inst = proximity.Positivity() + algo = algorithms.POGM( u=self.data1, x=self.data1, y=self.data1, z=self.data1, grad=grad_inst, prox=prox_inst, - auto_iterate=False, - cost=None, + **kwargs, ) - self.pogm_all_iter.iterate(self.max_iter) - self.pogm1 = algorithms.POGM( - u=self.data1, - x=self.data1, - y=self.data1, - z=self.data1, - grad=grad_inst, - prox=prox_inst, - ) + if kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) - self.vanilla_grad = algorithms.VanillaGenericGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.ada_grad = algorithms.AdaGenericGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.adam_grad = algorithms.ADAMGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.momentum_grad = algorithms.MomentumGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.rms_grad = algorithms.RMSpropGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.saga_grad = algorithms.SAGAOptGradOpt( + @parametrize( + GradDescent=[ + algorithms.VanillaGenericGradOpt, + algorithms.AdaGenericGradOpt, + 
algorithms.ADAMGradOpt, + algorithms.MomentumGradOpt, + algorithms.RMSpropGradOpt, + algorithms.SAGAOptGradOpt, + ] + ) + def case_grad(self, GradDescent): + """Gradient Descent algorithm test.""" + grad_inst = gradient.GradBasic(self.data1, idty, idty) + prox_inst = proximity.Positivity() + cost_inst = cost.costObj([grad_inst, prox_inst]) + + algo = GradDescent( self.data1, grad=grad_inst, prox=prox_inst, cost=cost_inst, ) - - self.dummy = Dummy() - self.dummy.cost = func_identity - self.setup._check_operator(self.dummy.cost) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.setup = None - self.fb_all_iter = None - self.fb1 = None - self.fb2 = None - self.gfb_all_iter = None - self.gfb1 = None - self.gfb2 = None - self.condat_all_iter = None - self.condat1 = None - self.condat2 = None - self.condat3 = None - self.pogm1 = None - self.pogm_all_iter = None - self.dummy = None - - def test_set_up(self): - """Test set_up.""" - npt.assert_raises(TypeError, self.setup._check_input_data, 1) - - npt.assert_raises(TypeError, self.setup._check_param, 1) - - npt.assert_raises(TypeError, self.setup._check_param_update, 1) - - def test_all_iter(self): - """Test if all opt run for all iterations.""" - opts = [ - self.fb_all_iter, - self.gfb_all_iter, - self.condat_all_iter, - self.pogm_all_iter, - ] - for opt in opts: - npt.assert_equal(opt.idx, self.max_iter - 1) - - def test_forward_backward(self): - """Test forward_backward.""" - npt.assert_array_equal( - self.fb1.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb2.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb3.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb4.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb5.x_final, - 
self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb6.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - def test_gen_forward_backward(self): - """Test gen_forward_backward.""" - npt.assert_array_equal( - self.gfb1.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb2.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb3.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_equal( - self.gfb3.step_size, - 2, - err_msg='Incorrect step size.', - ) - - npt.assert_raises( - TypeError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=1, - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[1], - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5, 0.5], - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5], - ) - - def test_condat(self): - """Test gen_condat.""" - npt.assert_almost_equal( - self.condat1.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) - - npt.assert_almost_equal( - self.condat2.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) - - def test_pogm(self): - """Test pogm.""" - npt.assert_almost_equal( - self.pogm1.x_final, - self.data1, - err_msg='Incorrect POGM result.', - ) - - def test_ada_grad(self): - """Test ADA Gradient Descent.""" - self.ada_grad.iterate() - npt.assert_almost_equal( - self.ada_grad.x_final, - self.data1, - err_msg='Incorrect ADAGrad results.', - ) - - def test_adam_grad(self): - """Test ADAM Gradient Descent.""" - self.adam_grad.iterate() - npt.assert_almost_equal( - self.adam_grad.x_final, - 
self.data1, - err_msg='Incorrect ADAMGrad results.', - ) - - def test_momemtum_grad(self): - """Test Momemtum Gradient Descent.""" - self.momentum_grad.iterate() - npt.assert_almost_equal( - self.momentum_grad.x_final, - self.data1, - err_msg='Incorrect MomentumGrad results.', - ) - - def test_rmsprop_grad(self): - """Test RMSProp Gradient Descent.""" - self.rms_grad.iterate() - npt.assert_almost_equal( - self.rms_grad.x_final, - self.data1, - err_msg='Incorrect RMSPropGrad results.', - ) - - def test_saga_grad(self): - """Test SAGA Descent.""" - self.saga_grad.iterate() - npt.assert_almost_equal( - self.saga_grad.x_final, - self.data1, - err_msg='Incorrect SAGA Grad results.', - ) - - def test_vanilla_grad(self): - """Test Vanilla Gradient Descent.""" - self.vanilla_grad.iterate() - npt.assert_almost_equal( - self.vanilla_grad.x_final, - self.data1, - err_msg='Incorrect VanillaGrad results.', - ) + return algo From 34ff1a89787266098980bf8e868b0dc616b13cef Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Wed, 7 Dec 2022 11:34:06 +0100 Subject: [PATCH 17/33] add algorithm test. 
--- modopt/tests/test_algorithms.py | 107 +++++++++++++++++++++++--------- 1 file changed, 79 insertions(+), 28 deletions(-) diff --git a/modopt/tests/test_algorithms.py b/modopt/tests/test_algorithms.py index f5a688a7..64fbaa0c 100644 --- a/modopt/tests/test_algorithms.py +++ b/modopt/tests/test_algorithms.py @@ -10,9 +10,16 @@ import numpy as np import numpy.testing as npt +import pytest from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight -from pytest_cases import (case, fixture, fixture_ref, parametrize, - parametrize_with_cases) +from pytest_cases import ( + case, + fixture, + fixture_ref, + lazy_value, + parametrize, + parametrize_with_cases, +) from test_helpers import Dummy @@ -23,7 +30,35 @@ SKLEARN_AVAILABLE = False -idty = lambda x_val: x_val +@fixture +def idty(): + """Identity function.""" + return lambda x: x + + +@fixture +def reweight_op(): + """Reweight operator.""" + data3 = np.arange(9).reshape(3, 3).astype(float) + 1 + return reweight.cwbReweight(data3) + + +def build_kwargs(kwargs): + """Build the kwargs for each algorithm, replacing placeholder by true value. + + Direct parametrization somehow is not working with pytest-xdist and pytest-cases. + """ + update_value = { + "idty": lambda x: x, + "lin_idty": linear.Identity(), + "reweight_op": reweight.cwbReweight( + np.arange(9).reshape(3, 3).astype(float) + 1 + ), + } + # update the value of the dict is possible. 
+ for key in kwargs: + kwargs[key] = update_value.get(kwargs[key], kwargs[key]) + return kwargs class AlgoCases: @@ -31,16 +66,15 @@ class AlgoCases: data1 = np.arange(9).reshape(3, 3).astype(float) data2 = data1 + np.random.randn(*data1.shape) * 1e-6 - data3 = np.arange(9).reshape(3, 3).astype(float) + 1 max_iter = 20 @parametrize( kwargs=[ - {"beta_update": idty, "auto_iterate": False, "cost": None}, - {"beta_update": idty}, + {"beta_update": "idty", "auto_iterate": False, "cost": None}, + {"beta_update": "idty"}, {"cost": None, "lambda_update": None}, - {"beta_update": idty, "a_cd": 3}, - {"beta_update": idty, "r_lazy": 3, "p_lazy": 0.7, "q_lazy": 0.7}, + {"beta_update": "idty", "a_cd": 3}, + {"beta_update": "idty", "r_lazy": 3, "p_lazy": 0.7, "q_lazy": 0.7}, {"restart_strategy": "adaptive", "xi_restart": 0.9}, { "restart_strategy": "greedy", @@ -50,68 +84,73 @@ class AlgoCases: }, ] ) - def case_forward_backward(self, kwargs): + def case_forward_backward(self, kwargs, idty): """Forward Backward case.""" + kwargs = build_kwargs(kwargs) algo = algorithms.ForwardBackward( + self.data1, grad=gradient.GradBasic(self.data1, idty, idty), prox=proximity.Positivity(), **kwargs, ) if kwargs.get("auto_iterate", None) is False: algo.iterate(self.max_iter) - return algo + return algo, kwargs @parametrize( kwargs=[ { "cost": None, "auto_iterate": False, - "gamma_update": idty, - "beta_update": idty, + "gamma_update": "idty", + "beta_update": "idty", }, - {"gamma_update": idty, "lambda_update": idty}, + {"gamma_update": "idty", "lambda_update": "idty"}, {"cost": True}, {"cost": True, "step_size": 2}, ] ) - def case_gen_forward_backward(self, kwargs): + def case_gen_forward_backward(self, kwargs, idty): """General FB setup.""" + kwargs = build_kwargs(kwargs) grad_inst = gradient.GradBasic(self.data1, idty, idty) prox_inst = proximity.Positivity() prox_dual_inst = proximity.IdentityProx() if kwargs.get("cost", None) is True: kwargs["cost"] = cost.costObj([grad_inst, prox_inst, 
prox_dual_inst]) algo = algorithms.GenForwardBackward( + self.data1, grad=grad_inst, prox_list=[prox_inst, prox_dual_inst], **kwargs, ) if kwargs.get("auto_iterate", None) is False: algo.iterate(self.max_iter) - return algo + return algo, kwargs @parametrize( kwargs=[ { - "sigma_dual": idty, - "tau_update": idty, - "rho_update": idty, + "sigma_dual": "idty", + "tau_update": "idty", + "rho_update": "idty", "auto_iterate": False, }, { - "sigma_dual": idty, - "tau_update": idty, - "rho_update": idty, + "sigma_dual": "idty", + "tau_update": "idty", + "rho_update": "idty", }, { - "linear": linear.Identity(), + "linear": "lin_idty", "cost": True, - "reweight": reweight.cwbReweight(data3), + "reweight": "reweight_op", }, ] ) - def case_condat(self, kwargs): + def case_condat(self, kwargs, idty): """Condat Vu Algorithm setup.""" + kwargs = build_kwargs(kwargs) grad_inst = gradient.GradBasic(self.data1, idty, idty) prox_inst = proximity.Positivity() prox_dual_inst = proximity.IdentityProx() @@ -128,10 +167,10 @@ def case_condat(self, kwargs): ) if kwargs.get("auto_iterate", None) is False: algo.iterate(self.max_iter) - return algo + return algo, kwargs @parametrize(kwargs=[{"auto_iterate": False, "cost": None}, {}]) - def case_pogm(self, kwargs): + def case_pogm(self, kwargs, idty): """POGM setup.""" grad_inst = gradient.GradBasic(self.data1, idty, idty) prox_inst = proximity.Positivity() @@ -147,6 +186,7 @@ def case_pogm(self, kwargs): if kwargs.get("auto_iterate", None) is False: algo.iterate(self.max_iter) + return algo, kwargs @parametrize( GradDescent=[ @@ -158,7 +198,7 @@ def case_pogm(self, kwargs): algorithms.SAGAOptGradOpt, ] ) - def case_grad(self, GradDescent): + def case_grad(self, GradDescent, idty): """Gradient Descent algorithm test.""" grad_inst = gradient.GradBasic(self.data1, idty, idty) prox_inst = proximity.Positivity() @@ -170,4 +210,15 @@ def case_grad(self, GradDescent): prox=prox_inst, cost=cost_inst, ) - return algo + algo.iterate() + return algo, 
{} + + +@parametrize_with_cases("algo, kwargs", cases=AlgoCases) +def test_algo(algo, kwargs): + """Test algorithms.""" + if kwargs.get("auto_iterate") is False: + # algo already run + npt.assert_almost_equal(algo.idx, AlgoCases.max_iter - 1) + else: + npt.assert_almost_equal(algo.x_final, AlgoCases.data1) From c16c24fafad77bd2fe7b61a1c841c57ebf299297 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Wed, 7 Dec 2022 12:39:57 +0100 Subject: [PATCH 18/33] add pytest-cases and pytest-xdists support. --- .github/workflows/ci-build.yml | 2 +- develop.txt | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 3ffcb6f4..b85afd20 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -61,7 +61,7 @@ jobs: shell: bash -l {0} run: | export PATH=/usr/share/miniconda/bin:$PATH - python setup.py test + pytest -n 2 - name: Save Test Results if: always() diff --git a/develop.txt b/develop.txt index 5105696e..44522571 100644 --- a/develop.txt +++ b/develop.txt @@ -3,6 +3,8 @@ flake8>=4 nose>=1.3.7 pytest>=6.2.2 pytest-raises>=0.10 +pytest-cases>= 3.6 +pytest-xdist>= 3.0.1 pytest-cov>=2.11.1 pytest-pep8>=1.0.6 pytest-emoji>=0.2.0 From 38f065a638d9c6f2869f5bd94eefd9748619dc56 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 8 Dec 2022 13:48:54 +0100 Subject: [PATCH 19/33] add support for testing metrics. --- modopt/tests/test_algorithms.py | 84 +++++++++++++++++++++------------ 1 file changed, 54 insertions(+), 30 deletions(-) diff --git a/modopt/tests/test_algorithms.py b/modopt/tests/test_algorithms.py index 64fbaa0c..0216ea3c 100644 --- a/modopt/tests/test_algorithms.py +++ b/modopt/tests/test_algorithms.py @@ -43,10 +43,12 @@ def reweight_op(): return reweight.cwbReweight(data3) -def build_kwargs(kwargs): - """Build the kwargs for each algorithm, replacing placeholder by true value. 
+def build_kwargs(kwargs, use_metrics): + """Build the kwargs for each algorithm, replacing placeholders by true values. - Direct parametrization somehow is not working with pytest-xdist and pytest-cases. + This function has to be call for each test, as direct parameterization somehow + is not working with pytest-xdist and pytest-cases. + It also adds dummy metric measurement to validate the metric api. """ update_value = { "idty": lambda x: x, @@ -55,12 +57,27 @@ def build_kwargs(kwargs): np.arange(9).reshape(3, 3).astype(float) + 1 ), } + new_kwargs = dict() + print(kwargs) # update the value of the dict is possible. for key in kwargs: - kwargs[key] = update_value.get(kwargs[key], kwargs[key]) - return kwargs + new_kwargs[key] = update_value.get(kwargs[key], kwargs[key]) + if use_metrics: + new_kwargs["linear"] = linear.Identity() + new_kwargs["metrics"] = { + "diff": { + "metric": lambda test, ref: np.sum(test - ref), + "mapping": {"x_new": "test"}, + "cst_kwargs": {"ref": np.arange(9).reshape((3, 3))}, + "early_stopping": False, + } + } + return new_kwargs + + +@parametrize(use_metrics=[True, False]) class AlgoCases: """Cases for algorithms.""" @@ -84,18 +101,18 @@ class AlgoCases: }, ] ) - def case_forward_backward(self, kwargs, idty): + def case_forward_backward(self, kwargs, idty, use_metrics): """Forward Backward case.""" - kwargs = build_kwargs(kwargs) + update_kwargs = build_kwargs(kwargs, use_metrics) algo = algorithms.ForwardBackward( self.data1, grad=gradient.GradBasic(self.data1, idty, idty), prox=proximity.Positivity(), - **kwargs, + **update_kwargs, ) - if kwargs.get("auto_iterate", None) is False: + if update_kwargs.get("auto_iterate", None) is False: algo.iterate(self.max_iter) - return algo, kwargs + return algo, update_kwargs @parametrize( kwargs=[ @@ -110,23 +127,23 @@ def case_forward_backward(self, kwargs, idty): {"cost": True, "step_size": 2}, ] ) - def case_gen_forward_backward(self, kwargs, idty): + def case_gen_forward_backward(self, 
kwargs, use_metrics, idty): """General FB setup.""" - kwargs = build_kwargs(kwargs) + update_kwargs = build_kwargs(kwargs, use_metrics) grad_inst = gradient.GradBasic(self.data1, idty, idty) prox_inst = proximity.Positivity() prox_dual_inst = proximity.IdentityProx() - if kwargs.get("cost", None) is True: - kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) + if update_kwargs.get("cost", None) is True: + update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) algo = algorithms.GenForwardBackward( self.data1, grad=grad_inst, prox_list=[prox_inst, prox_dual_inst], - **kwargs, + **update_kwargs, ) - if kwargs.get("auto_iterate", None) is False: + if update_kwargs.get("auto_iterate", None) is False: algo.iterate(self.max_iter) - return algo, kwargs + return algo, update_kwargs @parametrize( kwargs=[ @@ -148,14 +165,14 @@ def case_gen_forward_backward(self, kwargs, idty): }, ] ) - def case_condat(self, kwargs, idty): + def case_condat(self, kwargs, use_metrics, idty): """Condat Vu Algorithm setup.""" - kwargs = build_kwargs(kwargs) + update_kwargs = build_kwargs(kwargs, use_metrics) grad_inst = gradient.GradBasic(self.data1, idty, idty) prox_inst = proximity.Positivity() prox_dual_inst = proximity.IdentityProx() - if kwargs.get("cost", None) is True: - kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) + if update_kwargs.get("cost", None) is True: + update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) algo = algorithms.Condat( self.data1, @@ -163,15 +180,16 @@ def case_condat(self, kwargs, idty): grad=grad_inst, prox=prox_inst, prox_dual=prox_dual_inst, - **kwargs, + **update_kwargs, ) - if kwargs.get("auto_iterate", None) is False: + if update_kwargs.get("auto_iterate", None) is False: algo.iterate(self.max_iter) - return algo, kwargs + return algo, update_kwargs @parametrize(kwargs=[{"auto_iterate": False, "cost": None}, {}]) - def case_pogm(self, kwargs, idty): + def case_pogm(self, 
kwargs, use_metrics, idty): """POGM setup.""" + update_kwargs = build_kwargs(kwargs, use_metrics) grad_inst = gradient.GradBasic(self.data1, idty, idty) prox_inst = proximity.Positivity() algo = algorithms.POGM( @@ -181,12 +199,12 @@ def case_pogm(self, kwargs, idty): z=self.data1, grad=grad_inst, prox=prox_inst, - **kwargs, + **update_kwargs, ) - if kwargs.get("auto_iterate", None) is False: + if update_kwargs.get("auto_iterate", None) is False: algo.iterate(self.max_iter) - return algo, kwargs + return algo, update_kwargs @parametrize( GradDescent=[ @@ -198,8 +216,9 @@ def case_pogm(self, kwargs, idty): algorithms.SAGAOptGradOpt, ] ) - def case_grad(self, GradDescent, idty): + def case_grad(self, GradDescent, use_metrics, idty): """Gradient Descent algorithm test.""" + update_kwargs = build_kwargs({}, use_metrics) grad_inst = gradient.GradBasic(self.data1, idty, idty) prox_inst = proximity.Positivity() cost_inst = cost.costObj([grad_inst, prox_inst]) @@ -209,9 +228,10 @@ def case_grad(self, GradDescent, idty): grad=grad_inst, prox=prox_inst, cost=cost_inst, + **update_kwargs, ) algo.iterate() - return algo, {} + return algo, update_kwargs @parametrize_with_cases("algo, kwargs", cases=AlgoCases) @@ -222,3 +242,7 @@ def test_algo(algo, kwargs): npt.assert_almost_equal(algo.idx, AlgoCases.max_iter - 1) else: npt.assert_almost_equal(algo.x_final, AlgoCases.data1) + + if kwargs.get("metrics"): + print(algo.metrics) + npt.assert_almost_equal(algo.metrics["diff"]["values"][-1], 0, 3) From cf76d87ff3fb38ba1c4faf01485707c9d864bee1 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 8 Dec 2022 13:49:22 +0100 Subject: [PATCH 20/33] improve base module coverage. --- modopt/tests/test_base.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/modopt/tests/test_base.py b/modopt/tests/test_base.py index 4551e05d..29aad6bd 100644 --- a/modopt/tests/test_base.py +++ b/modopt/tests/test_base.py @@ -1,7 +1,9 @@ """ Test for base module. 
-:Author: Pierre-Antoine Comby + Pierre-Antoine Comby """ import numpy as np import numpy.testing as npt @@ -175,7 +177,7 @@ def test_check_float(self, data, checked): ) def test_check_int(self, data, checked): """Test check int.""" - npt.assert_array_equal(types.check_float(data), checked) + npt.assert_array_equal(types.check_int(data), checked) @pytest.mark.parametrize( ("data", "dtype"), [(data_flt, np.integer), (data_int, np.floating)] @@ -189,6 +191,10 @@ def test_check_npndarray(self, data, dtype): dtype=dtype, ) + def test_check_callable(self): + """Test callable.""" + npt.assert_raises(TypeError, types.check_callable, 1) + @pytest.mark.parametrize( "backend_name", From 7eb524eaa764f79e2b060829a6917bbe45ac5450 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 8 Dec 2022 13:49:34 +0100 Subject: [PATCH 21/33] test for wrong mask in metric module. --- modopt/tests/test_math.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modopt/tests/test_math.py b/modopt/tests/test_math.py index ee2cabfd..d2248440 100644 --- a/modopt/tests/test_math.py +++ b/modopt/tests/test_math.py @@ -243,6 +243,7 @@ def test_ssim_fail(self): (metrics.mse, data1, mse_res, mask), (metrics.nrmse, data1, nrmse_res, None), (metrics.nrmse, data1, nrmse_res, mask), + failparam(metrics.snr, data1, snr_res, "maskfail", raises=ValueError), ], ) def test_metric(self, metric, data, result, mask): From bf5b8ea9d167320e1df82f931d01a4565bfc99c7 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 8 Dec 2022 13:49:53 +0100 Subject: [PATCH 22/33] add docstring. 
--- modopt/tests/test_math.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modopt/tests/test_math.py b/modopt/tests/test_math.py index d2248440..993dc5f3 100644 --- a/modopt/tests/test_math.py +++ b/modopt/tests/test_math.py @@ -112,6 +112,8 @@ def test_convolve_stack(self, result, rot_kernel): class TestMatrix: + """Test matrix module.""" + array3 = np.arange(3) array33 = np.arange(9).reshape((3, 3)) array23 = np.arange(6).reshape((2, 3)) From c9aadbae80c46f8c89b62bd35fadbfc2d090d161 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 8 Dec 2022 13:53:48 +0100 Subject: [PATCH 23/33] update email adress and authors field. --- modopt/tests/test_algorithms.py | 5 +++-- modopt/tests/test_base.py | 2 +- modopt/tests/test_math.py | 5 +++-- modopt/tests/test_opt.py | 5 +++-- modopt/tests/test_signal.py | 5 +++-- 5 files changed, 13 insertions(+), 9 deletions(-) diff --git a/modopt/tests/test_algorithms.py b/modopt/tests/test_algorithms.py index 0216ea3c..73091acd 100644 --- a/modopt/tests/test_algorithms.py +++ b/modopt/tests/test_algorithms.py @@ -4,8 +4,9 @@ This module contains unit tests for the modopt.opt module. -:Author: Samuel Farrens -:Author: Pierre-Antoine Comby +:Authors: + Samuel Farrens + Pierre-Antoine Comby """ import numpy as np diff --git a/modopt/tests/test_base.py b/modopt/tests/test_base.py index 29aad6bd..6dddb71f 100644 --- a/modopt/tests/test_base.py +++ b/modopt/tests/test_base.py @@ -1,7 +1,7 @@ """ Test for base module. -:Author: +:Authors: Samuel Farrens Pierre-Antoine Comby """ diff --git a/modopt/tests/test_math.py b/modopt/tests/test_math.py index 993dc5f3..ea8f99ff 100644 --- a/modopt/tests/test_math.py +++ b/modopt/tests/test_math.py @@ -2,8 +2,9 @@ This module contains unit tests for the modopt.math module. 
-:Author: Pierre-Antoine Comby -:Author: Samuel Farrens +:Authors: + Samuel Farrens + Pierre-Antoine Comby """ import pytest from test_helpers import failparam, skipparam diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index a74830f0..a675b805 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -2,8 +2,9 @@ This module contains tests for the modopt.opt module. -:Author: Samuel Farrens -:Author: Pierre-Antoine Comby +:Authors: + Samuel Farrens + Pierre-Antoine Comby """ import numpy as np diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py index 269b1052..213f2102 100644 --- a/modopt/tests/test_signal.py +++ b/modopt/tests/test_signal.py @@ -2,8 +2,9 @@ This module contains unit tests for the modopt.signal module. -:Author: Samuel Farrens -:Author: Pierre-Antoine Comby +:Authors: + Samuel Farrens + Pierre-Antoine Comby """ import numpy as np From f8aade1a2b45698349aeb8b285f8c5305ef17c63 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 8 Dec 2022 14:22:09 +0100 Subject: [PATCH 24/33] 100% coverage for transform module. 
--- modopt/tests/test_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modopt/tests/test_base.py b/modopt/tests/test_base.py index 6dddb71f..ce528aa2 100644 --- a/modopt/tests/test_base.py +++ b/modopt/tests/test_base.py @@ -114,6 +114,7 @@ class TestTransforms: ("func", "indata", "layout", "outdata"), [ (transform.cube2map, cube, layout, map), + failparam(transform.cube2map, np.eye(2), layout, map, raises=ValueError), (transform.map2cube, map, layout, cube), (transform.map2matrix, map, layout, matrix), (transform.matrix2map, matrix, matrix.shape, map), From 9fb10ed63289fc85f76fa113d8ff8348cbacb757 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 8 Dec 2022 14:22:32 +0100 Subject: [PATCH 25/33] move linear operator to class --- modopt/tests/test_opt.py | 97 +++++++++++++++++++--------------------- 1 file changed, 45 insertions(+), 52 deletions(-) diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index a675b805..cb403ec9 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -145,69 +145,62 @@ def test_grad_op_raises(): ############# -@case(tags="linear") -def case_linear_identity(): - """Case linear operator identity.""" - linop = linear.Identity() +class LinearCases: + """Linear operator cases.""" - data_op, data_adj_op, res_op, res_adj_op = 1, 1, 1, 1 + def case_linear_identity(self): + """Case linear operator identity.""" + linop = linear.Identity() - return linop, data_op, data_adj_op, res_op, res_adj_op + data_op, data_adj_op, res_op, res_adj_op = 1, 1, 1, 1 + return linop, data_op, data_adj_op, res_op, res_adj_op -@case(tags="linear") -def case_linear_wavelet(): - """Case linear operator wavelet.""" - linop = linear.WaveletConvolve(filters=np.arange(8).reshape(2, 2, 2).astype(float)) - data_op = np.arange(4).reshape(1, 2, 2).astype(float) - data_adj_op = np.arange(8).reshape(1, 2, 2, 2).astype(float) - res_op = np.array([[[[0, 0], [0, 4.0]], [[0, 4.0], [8.0, 28.0]]]]) - res_adj_op = np.array([[[28.0, 
62.0], [68.0, 140.0]]]) - - return linop, data_op, data_adj_op, res_op, res_adj_op - - -@case(tags="linear") -def case_linear_combo(): - """Case linear operator combo.""" - parent = linear.LinearParent( - func_sq, - func_cube, - ) - linop = linear.LinearCombo([parent, parent], [1.0, 1.0]) - - data_op, data_adj_op, res_op, res_adj_op = ( - 2, - np.array([2, 2]), - np.array([4, 4]), - 16.0, - ) - - return linop, data_op, data_adj_op, res_op, res_adj_op + def case_linear_wavelet(self): + """Case linear operator wavelet.""" + linop = linear.WaveletConvolve( + filters=np.arange(8).reshape(2, 2, 2).astype(float) + ) + data_op = np.arange(4).reshape(1, 2, 2).astype(float) + data_adj_op = np.arange(8).reshape(1, 2, 2, 2).astype(float) + res_op = np.array([[[[0, 0], [0, 4.0]], [[0, 4.0], [8.0, 28.0]]]]) + res_adj_op = np.array([[[28.0, 62.0], [68.0, 140.0]]]) + + return linop, data_op, data_adj_op, res_op, res_adj_op + + @parametrize(weights=[[1.0, 1.0], None]) + def case_linear_combo(self, weights): + """Case linear operator combo with weights.""" + parent = linear.LinearParent( + func_sq, + func_cube, + ) + linop = linear.LinearCombo([parent, parent], weights) + data_op, data_adj_op, res_op, res_adj_op = ( + 2, + np.array([2, 2]), + np.array([4, 4]), + 8.0 * (2 if weights else 1), + ) -@case(tags="linear") -def case_linear_combo_weight(): - """Case linear operator combo with weights.""" - parent = linear.LinearParent( - func_sq, - func_cube, - ) - linop = linear.LinearCombo([parent, parent], [1.0, 1.0]) + return linop, data_op, data_adj_op, res_op, res_adj_op - data_op, data_adj_op, res_op, res_adj_op = ( - 2, - np.array([2, 2]), - np.array([4, 4]), - 16.0, - ) + @parametrize(factor=[1, 1 + 1j]) + def case_linear_matrix(self, factor): + """Case linear operator from matrix.""" + linop = linear.MatrixOperator(np.eye(5) * factor) + data_op = np.arange(5) + data_adj_op = np.arange(5) + res_op = np.arange(5) * factor + res_adj_op = np.arange(5) * np.conjugate(factor) - return 
linop, data_op, data_adj_op, res_op, res_adj_op + return linop, data_op, data_adj_op, res_op, res_adj_op @fixture @parametrize_with_cases( - "linop, data_op, data_adj_op, res_op, res_adj_op", cases=".", has_tag="linear" + "linop, data_op, data_adj_op, res_op, res_adj_op", cases=LinearCases ) def lin_adj_op(linop, data_op, data_adj_op, res_op, res_adj_op): """Get adj_op relative data.""" @@ -216,7 +209,7 @@ def lin_adj_op(linop, data_op, data_adj_op, res_op, res_adj_op): @fixture @parametrize_with_cases( - "linop, data_op, data_adj_op, res_op, res_adj_op", cases=".", has_tag="linear" + "linop, data_op, data_adj_op, res_op, res_adj_op", cases=LinearCases ) def lin_op(linop, data_op, data_adj_op, res_op, res_adj_op): """Get op relative data.""" From 67fea29421c3306e7808065569507a205a356500 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 8 Dec 2022 14:22:45 +0100 Subject: [PATCH 26/33] update docstring. --- modopt/opt/linear.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modopt/opt/linear.py b/modopt/opt/linear.py index 3519024a..83241625 100644 --- a/modopt/opt/linear.py +++ b/modopt/opt/linear.py @@ -85,7 +85,7 @@ class MatrixOperator(LinearParent): """ Matrix Operator class. - This transform an array into a suitable linear operator. + This class transforms an array into a suitable linear operator. """ def __init__(self, array): @@ -93,7 +93,6 @@ def __init__(self, array): xp = get_array_module(array) if xp.any(xp.iscomplex(array)): - self.adj_op = lambda x: array.T.conjugate() @ x else: self.adj_op = lambda x: array.T @ x From b3d019de7616eb7ae84ef29cfcce0c9be55f6383 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 8 Dec 2022 14:23:23 +0100 Subject: [PATCH 27/33] paramet(e)rization. 
--- modopt/tests/test_helpers/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modopt/tests/test_helpers/utils.py b/modopt/tests/test_helpers/utils.py index ad3795d5..f882a3f4 100644 --- a/modopt/tests/test_helpers/utils.py +++ b/modopt/tests/test_helpers/utils.py @@ -2,12 +2,12 @@ def failparam(*args, raises=ValueError): - """Return a pytest parametrization that should raise an error.""" + """Return a pytest parameterization that should raise an error.""" return pytest.param(*args, marks=pytest.mark.raises(exception=raises)) def skipparam(*args, cond=True, reason=""): - """Return a pytest parametrization that should raise an error.""" + """Return a pytest parameterization that should raise an error.""" return pytest.param(*args, marks=pytest.mark.skipif(cond, reason=reason)) class Dummy: From d355fdc6e4c7d60cf4ba395dc8f4cba9cc5534c3 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 8 Dec 2022 14:24:26 +0100 Subject: [PATCH 28/33] update docstring. 
--- modopt/tests/test_base.py | 2 +- modopt/tests/test_math.py | 4 ++-- modopt/tests/test_opt.py | 10 +++++----- modopt/tests/test_signal.py | 10 +++++----- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/modopt/tests/test_base.py b/modopt/tests/test_base.py index ce528aa2..e32ff94b 100644 --- a/modopt/tests/test_base.py +++ b/modopt/tests/test_base.py @@ -87,7 +87,7 @@ def test_ftr(self): ) def test_ftl(self): - """Test ftl.""" + """Test fancy transpose left.""" npt.assert_array_equal( np_adjust.ftl(self.array233), np.array( diff --git a/modopt/tests/test_math.py b/modopt/tests/test_math.py index ea8f99ff..e44011c9 100644 --- a/modopt/tests/test_math.py +++ b/modopt/tests/test_math.py @@ -298,7 +298,7 @@ class TestStats: ], ) def test_gaussian_kernel(self, norm, result): - """Test gaussian kernel.""" + """Test Gaussian kernel.""" npt.assert_allclose( stats.gaussian_kernel(self.array33.shape, 1, norm=norm), result ) @@ -328,7 +328,7 @@ def test_sigma_mad(self): ], ) def test_psnr(self, data1, data2, method, result): - """Test psnr.""" + """Test PSNR.""" npt.assert_almost_equal(stats.psnr(data1, data2, method=method), result) def test_psnr_stack(self): diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index cb403ec9..6e15d98c 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -389,7 +389,7 @@ def case_prox_elasticnet(self, alpha, beta): ], ) def case_prox_Ksupport(self, beta, k_value, data, result, cost): - """Case prox Ksupport norm.""" + """Case prox K-support norm.""" return (proximity.KSupportNorm(beta=beta, k_value=k_value), data, result, cost) @parametrize(use_weights=[True, False]) @@ -422,7 +422,7 @@ def case_prox_grouplasso(self, use_weights): @pytest.mark.skipif(not SKLEARN_AVAILABLE, reason="sklearn not available.") def case_prox_owl(self): - """Case prox for owl.""" + """Case prox for Ordered Weighted L1 Norm.""" return ( proximity.OrderedWeightedL1Norm(self.weights.flatten()), self.array33.flatten(), 
@@ -464,7 +464,7 @@ def test_fail_sklearn(): def test_fail_owl(): - """Test fail owl.""" + """Test errors for Ordered Weighted L1 Norm.""" npt.assert_raises( ValueError, proximity.OrderedWeightedL1Norm, @@ -479,13 +479,13 @@ def test_fail_owl(): def test_fail_lowrank(): - """Test fail for lowrnk.""" + """Test fail for lowrank.""" prox_op = proximity.LowRankMatrix(10, lowr_type="fail") npt.assert_raises(ValueError, prox_op.op, 0) def test_fail_Ksupport_norm(): - """Test fail for Ksupport norm.""" + """Test fail for K-support norm.""" npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py index 213f2102..188ed5a7 100644 --- a/modopt/tests/test_signal.py +++ b/modopt/tests/test_signal.py @@ -24,7 +24,7 @@ def test_gaussian_filter(norm, result): def test_mex_hat(): - """Test mex_hat.""" + """Test mexican hat filter.""" npt.assert_almost_equal( filter.mex_hat(2, 1), -0.35213905225713371, @@ -32,7 +32,7 @@ def test_mex_hat(): def test_mex_hat_dir(): - """Test mex_hat_dir.""" + """Test directional mexican hat filter.""" npt.assert_almost_equal( filter.mex_hat_dir(1, 2, 1), 0.17606952612856686, @@ -170,11 +170,11 @@ def data(self): @pytest.fixture def svd0(self, data): - """Compute SVD for data[0].""" + """Compute SVD of first data sample.""" return svd.calculate_svd(data[0]) def test_find_n_pc(self, data): - """Test find_n_pc.""" + """Test find number of principal component.""" npt.assert_equal( svd.find_n_pc(svd.svd(data[1])[0]), 2, @@ -223,7 +223,7 @@ def test_svd_thresh_coef(self, data, operator): # TODO: is this module really necessary ? -# It is not use anywhere, not in modopt, nor pysap. +# It is not use anywhere, neither in modopt, nor pysap. class TestValidation: """Test validation Module.""" From 2ba52e373604488915598bc4715f2f40ab38777a Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 2 Jan 2023 17:26:36 +0100 Subject: [PATCH 29/33] improve test_helper module. 
--- modopt/tests/test_helpers/__init__.py | 1 - modopt/tests/test_helpers/utils.py | 9 ++++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/modopt/tests/test_helpers/__init__.py b/modopt/tests/test_helpers/__init__.py index eadbc7bc..3886b877 100644 --- a/modopt/tests/test_helpers/__init__.py +++ b/modopt/tests/test_helpers/__init__.py @@ -1,2 +1 @@ -#!/usr/bin/env python3 from .utils import failparam, skipparam, Dummy diff --git a/modopt/tests/test_helpers/utils.py b/modopt/tests/test_helpers/utils.py index f882a3f4..d9347a2f 100644 --- a/modopt/tests/test_helpers/utils.py +++ b/modopt/tests/test_helpers/utils.py @@ -1,3 +1,9 @@ +""" +Some helper functions for the test parametrization. +They should be used inside a ``@pytest.mark.parametrize`` call. + +:Author: Pierre-Antoine Comby +""" import pytest @@ -7,8 +13,9 @@ def failparam(*args, raises=ValueError): def skipparam(*args, cond=True, reason=""): - """Return a pytest parameterization that should raise an error.""" + """Return a pytest parameterization that should be skipped if ``cond`` is true.""" return pytest.param(*args, marks=pytest.mark.skipif(cond, reason=reason)) + class Dummy: pass From 9d5868d76f9017c0831a576fc670e4ea33f872b7 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 2 Jan 2023 17:41:35 +0100 Subject: [PATCH 30/33] raises should be specified for each failparam call. 
--- modopt/tests/test_helpers/utils.py | 4 +++- modopt/tests/test_signal.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/modopt/tests/test_helpers/utils.py b/modopt/tests/test_helpers/utils.py index d9347a2f..d8227640 100644 --- a/modopt/tests/test_helpers/utils.py +++ b/modopt/tests/test_helpers/utils.py @@ -7,8 +7,10 @@ import pytest -def failparam(*args, raises=ValueError): +def failparam(*args, raises=None): """Return a pytest parameterization that should raise an error.""" + if not issubclass(raises, Exception): + raise ValueError("raises should be an expected Exception.") return pytest.param(*args, marks=pytest.mark.raises(exception=raises)) diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py index 188ed5a7..31b8b6bb 100644 --- a/modopt/tests/test_signal.py +++ b/modopt/tests/test_signal.py @@ -64,7 +64,7 @@ class TestNoise: (data1, "poisson", 1, data2), (data1, "gauss", 1, data3), (data1, "gauss", (1, 1, 1), data3), - failparam(data1, "fail", 1, data1), + failparam(data1, "fail", 1, data1, raises=ValueError), ], ) def test_add_noise(self, data, noise_type, sigma, data_noise): @@ -197,7 +197,8 @@ def test_calculate_svd(self, data, svd0): raise AssertionError("Incorrect SVD calculation for: " + ", ".join(errors)) @pytest.mark.parametrize( - ("n_pc", "idx_res"), [(None, 3), (1, 4), ("all", 0), failparam("fail", 1)] + ("n_pc", "idx_res"), + [(None, 3), (1, 4), ("all", 0), failparam("fail", 1, raises=ValueError)], ) def test_svd_thresh(self, data, n_pc, idx_res): """Test svd_tresh.""" From c17d3c89d8738b985b1b75f6a8ad8d60a2b0a058 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 2 Jan 2023 17:46:10 +0100 Subject: [PATCH 31/33] encapsulate module's test in classes. 
--- modopt/tests/test_signal.py | 50 +++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py index 31b8b6bb..202e541b 100644 --- a/modopt/tests/test_signal.py +++ b/modopt/tests/test_signal.py @@ -15,28 +15,30 @@ from modopt.signal import filter, noise, positivity, svd, validation, wavelet -@pytest.mark.parametrize( - ("norm", "result"), [(True, 0.24197072451914337), (False, 0.60653065971263342)] -) -def test_gaussian_filter(norm, result): - """Test gaussian filter.""" - npt.assert_almost_equal(filter.gaussian_filter(1, 1, norm=norm), result) - - -def test_mex_hat(): - """Test mexican hat filter.""" - npt.assert_almost_equal( - filter.mex_hat(2, 1), - -0.35213905225713371, +class TestFilter: + """Test filter module""" + @pytest.mark.parametrize( + ("norm", "result"), [(True, 0.24197072451914337), (False, 0.60653065971263342)] ) + def test_gaussian_filter(self, norm, result): + """Test gaussian filter.""" + npt.assert_almost_equal(filter.gaussian_filter(1, 1, norm=norm), result) -def test_mex_hat_dir(): - """Test directional mexican hat filter.""" - npt.assert_almost_equal( - filter.mex_hat_dir(1, 2, 1), - 0.17606952612856686, - ) + def test_mex_hat(self): + """Test mexican hat filter.""" + npt.assert_almost_equal( + filter.mex_hat(2, 1), + -0.35213905225713371, + ) + + + def test_mex_hat_dir(self): + """Test directional mexican hat filter.""" + npt.assert_almost_equal( + filter.mex_hat_dir(1, 2, 1), + 0.17606952612856686, + ) class TestNoise: @@ -84,6 +86,13 @@ def test_thresh(self, threshold_type, result): noise.thresh(self.data1, 5, threshold_type=threshold_type), result ) +class TestPositivity: + """Test positivity module.""" + data1 = np.arange(9).reshape(3, 3).astype(float) + data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]]) + data5 = np.array( + [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]], + ) @pytest.mark.parametrize( ("value", "expected"), [ @@ 
-222,9 +231,6 @@ def test_svd_thresh_coef(self, data, operator): # TODO test_svd_thresh_coef_fast - -# TODO: is this module really necessary ? -# It is not use anywhere, neither in modopt, nor pysap. class TestValidation: """Test validation Module.""" From 27f689f9dedefbf5b8d1c5be08007cc5b9c6e9d5 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 3 Jan 2023 14:05:50 +0100 Subject: [PATCH 32/33] skip test if sklearn is not installed. --- modopt/tests/test_opt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index 6e15d98c..0e45ffb8 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -459,10 +459,11 @@ def test_error_prox_combo(arg, error): @pytest.mark.skipif(SKLEARN_AVAILABLE, reason="sklearn is installed") def test_fail_sklearn(): - """Test fail OWL wit sklearn.""" + """Test fail OWL with sklearn.""" npt.assert_raises(ImportError, proximity.OrderedWeightedL1Norm, 1) +@pytest.mark.skipif(not SKLEARN_AVAILABLE, reason="sklearn is not installed.") def test_fail_owl(): """Test errors for Ordered Weighted L1 Norm.""" npt.assert_raises( From 707135461d3cfbcca7792545c1cb530753518a81 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 3 Jan 2023 14:35:25 +0100 Subject: [PATCH 33/33] pin pydocstyle --- develop.txt | 1 + setup.cfg | 1 + 2 files changed, 2 insertions(+) diff --git a/develop.txt b/develop.txt index f80da576..6ff665eb 100644 --- a/develop.txt +++ b/develop.txt @@ -5,6 +5,7 @@ pytest-cases>= 3.6 pytest-xdist>= 3.0.1 pytest-cov>=2.11.1 pytest-emoji>=0.2.0 +pydocstyle==6.1.1 pytest-pydocstyle>=2.2.0 black isort diff --git a/setup.cfg b/setup.cfg index 13f9dc73..8d8e821b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -92,3 +92,4 @@ addopts = [pydocstyle] convention=numpy +add-ignore=D107