diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 2279abef..cdfff840 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -52,7 +52,7 @@ jobs: shell: bash -l {0} run: | export PATH=/usr/share/miniconda/bin:$PATH - python setup.py test + pytest -n 2 - name: Save Test Results if: always() diff --git a/develop.txt b/develop.txt index 25857153..6ff665eb 100644 --- a/develop.txt +++ b/develop.txt @@ -1,7 +1,11 @@ coverage>=5.5 pytest>=6.2.2 +pytest-raises>=0.10 +pytest-cases>= 3.6 +pytest-xdist>= 3.0.1 pytest-cov>=2.11.1 pytest-emoji>=0.2.0 +pydocstyle==6.1.1 pytest-pydocstyle>=2.2.0 black isort diff --git a/modopt/opt/linear.py b/modopt/opt/linear.py index 3807253b..83241625 100644 --- a/modopt/opt/linear.py +++ b/modopt/opt/linear.py @@ -11,6 +11,7 @@ import numpy as np from modopt.base.types import check_callable, check_float +from modopt.base.backend import get_array_module from modopt.signal.wavelet import filter_convolve_stack @@ -80,6 +81,23 @@ def __init__(self): self.adj_op = self.op +class MatrixOperator(LinearParent): + """ + Matrix Operator class. + + This class transforms an array into a suitable linear operator. + """ + + def __init__(self, array): + self.op = lambda x: array @ x + xp = get_array_module(array) + + if xp.any(xp.iscomplex(array)): + self.adj_op = lambda x: array.T.conjugate() @ x + else: + self.adj_op = lambda x: array.T @ x + + class WaveletConvolve(LinearParent): """Wavelet Convolution Class. diff --git a/modopt/tests/test_algorithms.py b/modopt/tests/test_algorithms.py index 7ff96a8b..73091acd 100644 --- a/modopt/tests/test_algorithms.py +++ b/modopt/tests/test_algorithms.py @@ -1,470 +1,249 @@ # -*- coding: utf-8 -*- -"""UNIT TESTS FOR OPT.ALGORITHMS. +"""UNIT TESTS FOR Algorithms. -This module contains unit tests for the modopt.opt.algorithms module. - -:Author: Samuel Farrens +This module contains unit tests for the modopt.opt module. +:Authors: + Samuel Farrens + Pierre-Antoine Comby """ -from unittest import TestCase - import numpy as np import numpy.testing as npt - +import pytest from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight - -# Basic functions to be used as operators or as dummy functions -func_identity = lambda x_val: x_val -func_double = lambda x_val: x_val * 2 -func_sq = lambda x_val: x_val ** 2 -func_cube = lambda x_val: x_val ** 3 - - -class Dummy(object): - """Dummy class for tests.""" - - pass - - -class AlgorithmTestCase(TestCase): - """Test case for algorithms module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = self.data1 + np.random.randn(*self.data1.shape) * 1e-6 - self.data3 = np.arange(9).reshape(3, 3).astype(float) + 1 - - grad_inst = gradient.GradBasic( - self.data1, - func_identity, - func_identity, - ) - +from pytest_cases import ( + case, + fixture, + fixture_ref, + lazy_value, + parametrize, + parametrize_with_cases, +) + +from test_helpers import Dummy + +SKLEARN_AVAILABLE = True +try: + import sklearn +except ImportError: + SKLEARN_AVAILABLE = False + + +@fixture +def idty(): + """Identity function.""" + return lambda x: x + + +@fixture +def reweight_op(): + """Reweight operator.""" + data3 = np.arange(9).reshape(3, 3).astype(float) + 1 + return reweight.cwbReweight(data3) + + +def build_kwargs(kwargs, use_metrics): + """Build the kwargs for each algorithm, replacing placeholders by true values. + + This function has to be call for each test, as direct parameterization somehow + is not working with pytest-xdist and pytest-cases. + It also adds dummy metric measurement to validate the metric api. + """ + update_value = { + "idty": lambda x: x, + "lin_idty": linear.Identity(), + "reweight_op": reweight.cwbReweight( + np.arange(9).reshape(3, 3).astype(float) + 1 + ), + } + new_kwargs = dict() + print(kwargs) + # update the value of the dict is possible. + for key in kwargs: + new_kwargs[key] = update_value.get(kwargs[key], kwargs[key]) + + if use_metrics: + new_kwargs["linear"] = linear.Identity() + new_kwargs["metrics"] = { + "diff": { + "metric": lambda test, ref: np.sum(test - ref), + "mapping": {"x_new": "test"}, + "cst_kwargs": {"ref": np.arange(9).reshape((3, 3))}, + "early_stopping": False, + } + } + + return new_kwargs + + +@parametrize(use_metrics=[True, False]) +class AlgoCases: + """Cases for algorithms.""" + + data1 = np.arange(9).reshape(3, 3).astype(float) + data2 = data1 + np.random.randn(*data1.shape) * 1e-6 + max_iter = 20 + + @parametrize( + kwargs=[ + {"beta_update": "idty", "auto_iterate": False, "cost": None}, + {"beta_update": "idty"}, + {"cost": None, "lambda_update": None}, + {"beta_update": "idty", "a_cd": 3}, + {"beta_update": "idty", "r_lazy": 3, "p_lazy": 0.7, "q_lazy": 0.7}, + {"restart_strategy": "adaptive", "xi_restart": 0.9}, + { + "restart_strategy": "greedy", + "xi_restart": 0.9, + "min_beta": 1.0, + "s_greedy": 1.1, + }, + ] + ) + def case_forward_backward(self, kwargs, idty, use_metrics): + """Forward Backward case.""" + update_kwargs = build_kwargs(kwargs, use_metrics) + algo = algorithms.ForwardBackward( + self.data1, + grad=gradient.GradBasic(self.data1, idty, idty), + prox=proximity.Positivity(), + **update_kwargs, + ) + if update_kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) + return algo, update_kwargs + + @parametrize( + kwargs=[ + { + "cost": None, + "auto_iterate": False, + "gamma_update": "idty", + "beta_update": "idty", + }, + {"gamma_update": "idty", "lambda_update": "idty"}, + {"cost": True}, + {"cost": True, "step_size": 2}, + ] + ) + def case_gen_forward_backward(self, kwargs, use_metrics, idty): + """General FB setup.""" + update_kwargs = build_kwargs(kwargs, use_metrics) + grad_inst = gradient.GradBasic(self.data1, idty, idty) prox_inst = proximity.Positivity() prox_dual_inst = proximity.IdentityProx() - linear_inst = linear.Identity() - reweight_inst = reweight.cwbReweight(self.data3) - cost_inst = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) - self.setup = algorithms.SetUp() - self.max_iter = 20 - - self.fb_all_iter = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=None, - auto_iterate=False, - beta_update=func_identity, - ) - self.fb_all_iter.iterate(self.max_iter) - - self.fb1 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - ) - - self.fb2 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - lambda_update=None, - ) - - self.fb3 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - a_cd=3, - ) - - self.fb4 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - r_lazy=3, - p_lazy=0.7, - q_lazy=0.7, - ) - - self.fb5 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='adaptive', - xi_restart=0.9, - ) - - self.fb6 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='greedy', - xi_restart=0.9, - min_beta=1.0, - s_greedy=1.1, - ) - - self.gfb_all_iter = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=None, - auto_iterate=False, - gamma_update=func_identity, - beta_update=func_identity, - ) - self.gfb_all_iter.iterate(self.max_iter) - - self.gfb1 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - gamma_update=func_identity, - lambda_update=func_identity, - ) - - self.gfb2 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - ) - - self.gfb3 = algorithms.GenForwardBackward( + if update_kwargs.get("cost", None) is True: + update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) + algo = algorithms.GenForwardBackward( self.data1, grad=grad_inst, prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - step_size=2, - ) - - self.condat_all_iter = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - cost=None, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - auto_iterate=False, - ) - self.condat_all_iter.iterate(self.max_iter) - - self.condat1 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - ) - - self.condat2 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - linear=linear_inst, - cost=cost_inst, - reweight=reweight_inst, - ) + **update_kwargs, + ) + if update_kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) + return algo, update_kwargs + + @parametrize( + kwargs=[ + { + "sigma_dual": "idty", + "tau_update": "idty", + "rho_update": "idty", + "auto_iterate": False, + }, + { + "sigma_dual": "idty", + "tau_update": "idty", + "rho_update": "idty", + }, + { + "linear": "lin_idty", + "cost": True, + "reweight": "reweight_op", + }, + ] + ) + def case_condat(self, kwargs, use_metrics, idty): + """Condat Vu Algorithm setup.""" + update_kwargs = build_kwargs(kwargs, use_metrics) + grad_inst = gradient.GradBasic(self.data1, idty, idty) + prox_inst = proximity.Positivity() + prox_dual_inst = proximity.IdentityProx() + if update_kwargs.get("cost", None) is True: + update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) - self.condat3 = algorithms.Condat( + algo = algorithms.Condat( self.data1, self.data2, grad=grad_inst, prox=prox_inst, prox_dual=prox_dual_inst, - linear=Dummy(), - cost=cost_inst, - auto_iterate=False, + **update_kwargs, ) + if update_kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) + return algo, update_kwargs - self.pogm_all_iter = algorithms.POGM( + @parametrize(kwargs=[{"auto_iterate": False, "cost": None}, {}]) + def case_pogm(self, kwargs, use_metrics, idty): + """POGM setup.""" + update_kwargs = build_kwargs(kwargs, use_metrics) + grad_inst = gradient.GradBasic(self.data1, idty, idty) + prox_inst = proximity.Positivity() + algo = algorithms.POGM( u=self.data1, x=self.data1, y=self.data1, z=self.data1, grad=grad_inst, prox=prox_inst, - auto_iterate=False, - cost=None, + **update_kwargs, ) - self.pogm_all_iter.iterate(self.max_iter) - self.pogm1 = algorithms.POGM( - u=self.data1, - x=self.data1, - y=self.data1, - z=self.data1, - grad=grad_inst, - prox=prox_inst, - ) + if update_kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) + return algo, update_kwargs - self.vanilla_grad = algorithms.VanillaGenericGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.ada_grad = algorithms.AdaGenericGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.adam_grad = algorithms.ADAMGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.momentum_grad = algorithms.MomentumGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.rms_grad = algorithms.RMSpropGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.saga_grad = algorithms.SAGAOptGradOpt( + @parametrize( + GradDescent=[ + algorithms.VanillaGenericGradOpt, + algorithms.AdaGenericGradOpt, + algorithms.ADAMGradOpt, + algorithms.MomentumGradOpt, + algorithms.RMSpropGradOpt, + algorithms.SAGAOptGradOpt, + ] + ) + def case_grad(self, GradDescent, use_metrics, idty): + """Gradient Descent algorithm test.""" + update_kwargs = build_kwargs({}, use_metrics) + grad_inst = gradient.GradBasic(self.data1, idty, idty) + prox_inst = proximity.Positivity() + cost_inst = cost.costObj([grad_inst, prox_inst]) + + algo = GradDescent( self.data1, grad=grad_inst, prox=prox_inst, cost=cost_inst, + **update_kwargs, ) + algo.iterate() + return algo, update_kwargs - self.dummy = Dummy() - self.dummy.cost = func_identity - self.setup._check_operator(self.dummy.cost) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.setup = None - self.fb_all_iter = None - self.fb1 = None - self.fb2 = None - self.gfb_all_iter = None - self.gfb1 = None - self.gfb2 = None - self.condat_all_iter = None - self.condat1 = None - self.condat2 = None - self.condat3 = None - self.pogm1 = None - self.pogm_all_iter = None - self.dummy = None - - def test_set_up(self): - """Test set_up.""" - npt.assert_raises(TypeError, self.setup._check_input_data, 1) - - npt.assert_raises(TypeError, self.setup._check_param, 1) - - npt.assert_raises(TypeError, self.setup._check_param_update, 1) - - def test_all_iter(self): - """Test if all opt run for all iterations.""" - opts = [ - self.fb_all_iter, - self.gfb_all_iter, - self.condat_all_iter, - self.pogm_all_iter, - ] - for opt in opts: - npt.assert_equal(opt.idx, self.max_iter - 1) - - def test_forward_backward(self): - """Test forward_backward.""" - npt.assert_array_equal( - self.fb1.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb2.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb3.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb4.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb5.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb6.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - def test_gen_forward_backward(self): - """Test gen_forward_backward.""" - npt.assert_array_equal( - self.gfb1.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb2.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb3.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_equal( - self.gfb3.step_size, - 2, - err_msg='Incorrect step size.', - ) - npt.assert_raises( - TypeError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=1, - ) +@parametrize_with_cases("algo, kwargs", cases=AlgoCases) +def test_algo(algo, kwargs): + """Test algorithms.""" + if kwargs.get("auto_iterate") is False: + # algo already run + npt.assert_almost_equal(algo.idx, AlgoCases.max_iter - 1) + else: + npt.assert_almost_equal(algo.x_final, AlgoCases.data1) - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[1], - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5, 0.5], - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5], - ) - - def test_condat(self): - """Test gen_condat.""" - npt.assert_almost_equal( - self.condat1.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) - - npt.assert_almost_equal( - self.condat2.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) - - def test_pogm(self): - """Test pogm.""" - npt.assert_almost_equal( - self.pogm1.x_final, - self.data1, - err_msg='Incorrect POGM result.', - ) - - def test_ada_grad(self): - """Test ADA Gradient Descent.""" - self.ada_grad.iterate() - npt.assert_almost_equal( - self.ada_grad.x_final, - self.data1, - err_msg='Incorrect ADAGrad results.', - ) - - def test_adam_grad(self): - """Test ADAM Gradient Descent.""" - self.adam_grad.iterate() - npt.assert_almost_equal( - self.adam_grad.x_final, - self.data1, - err_msg='Incorrect ADAMGrad results.', - ) - - def test_momemtum_grad(self): - """Test Momemtum Gradient Descent.""" - self.momentum_grad.iterate() - npt.assert_almost_equal( - self.momentum_grad.x_final, - self.data1, - err_msg='Incorrect MomentumGrad results.', - ) - - def test_rmsprop_grad(self): - """Test RMSProp Gradient Descent.""" - self.rms_grad.iterate() - npt.assert_almost_equal( - self.rms_grad.x_final, - self.data1, - err_msg='Incorrect RMSPropGrad results.', - ) - - def test_saga_grad(self): - """Test SAGA Descent.""" - self.saga_grad.iterate() - npt.assert_almost_equal( - self.saga_grad.x_final, - self.data1, - err_msg='Incorrect SAGA Grad results.', - ) - - def test_vanilla_grad(self): - """Test Vanilla Gradient Descent.""" - self.vanilla_grad.iterate() - npt.assert_almost_equal( - self.vanilla_grad.x_final, - self.data1, - err_msg='Incorrect VanillaGrad results.', - ) + if kwargs.get("metrics"): + print(algo.metrics) + npt.assert_almost_equal(algo.metrics["diff"]["values"][-1], 0, 3) diff --git a/modopt/tests/test_base.py b/modopt/tests/test_base.py index 873a4506..e32ff94b 100644 --- a/modopt/tests/test_base.py +++ b/modopt/tests/test_base.py @@ -1,192 +1,139 @@ -# -*- coding: utf-8 -*- - -"""UNIT TESTS FOR BASE. - -This module contains unit tests for the modopt.base module. - -:Author: Samuel Farrens - """ +Test for base module. -from builtins import range -from unittest import TestCase, skipIf - +:Authors: + Samuel Farrens + Pierre-Antoine Comby +""" import numpy as np import numpy.testing as npt +import pytest +from test_helpers import failparam, skipparam -from modopt.base import np_adjust, transform, types -from modopt.base.backend import (LIBRARIES, change_backend, get_array_module, - get_backend) +from modopt.base import backend, np_adjust, transform, types +from modopt.base.backend import LIBRARIES -class NPAdjustTestCase(TestCase): - """Test case for np_adjust module.""" +class TestNpAdjust: + """Test for npadjust.""" - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape((3, 3)) - self.data2 = np.arange(18).reshape((2, 3, 3)) - self.data3 = np.array([ + array33 = np.arange(9).reshape((3, 3)) + array233 = np.arange(18).reshape((2, 3, 3)) + arraypad = np.array( + [ [0, 0, 0, 0, 0], [0, 0, 1, 2, 0], [0, 3, 4, 5, 0], [0, 6, 7, 8, 0], [0, 0, 0, 0, 0], - ]) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None + ] + ) def test_rotate(self): """Test rotate.""" npt.assert_array_equal( - np_adjust.rotate(self.data1), - np.array([[8, 7, 6], [5, 4, 3], [2, 1, 0]]), - err_msg='Incorrect rotation', + np_adjust.rotate(self.array33), + np.rot90(np.rot90(self.array33)), + err_msg="Incorrect rotation.", ) def test_rotate_stack(self): """Test rotate_stack.""" npt.assert_array_equal( - np_adjust.rotate_stack(self.data2), - np.array([ - [[8, 7, 6], [5, 4, 3], [2, 1, 0]], - [[17, 16, 15], [14, 13, 12], [11, 10, 9]], - ]), - err_msg='Incorrect stack rotation', + np_adjust.rotate_stack(self.array233), + np.rot90(self.array233, k=2, axes=(1, 2)), + err_msg="Incorrect stack rotation.", ) - def test_pad2d(self): + @pytest.mark.parametrize( + "padding", + [ + 1, + [1, 1], + np.array([1, 1]), + failparam("1", raises=ValueError), + ], + ) + def test_pad2d(self, padding): """Test pad2d.""" - npt.assert_array_equal( - np_adjust.pad2d(self.data1, (1, 1)), - self.data3, - err_msg='Incorrect padding', - ) - - npt.assert_array_equal( - np_adjust.pad2d(self.data1, 1), - self.data3, - err_msg='Incorrect padding', - ) - - npt.assert_array_equal( - np_adjust.pad2d(self.data1, np.array([1, 1])), - self.data3, - err_msg='Incorrect padding', - ) - - npt.assert_raises(ValueError, np_adjust.pad2d, self.data1, '1') + npt.assert_equal(np_adjust.pad2d(self.array33, padding), self.arraypad) def test_fancy_transpose(self): - """Test fancy_transpose.""" + """Test fancy transpose.""" npt.assert_array_equal( - np_adjust.fancy_transpose(self.data2), - np.array([ - [[0, 3, 6], [9, 12, 15]], - [[1, 4, 7], [10, 13, 16]], - [[2, 5, 8], [11, 14, 17]], - ]), - err_msg='Incorrect fancy transpose', + np_adjust.fancy_transpose(self.array233), + np.array( + [ + [[0, 3, 6], [9, 12, 15]], + [[1, 4, 7], [10, 13, 16]], + [[2, 5, 8], [11, 14, 17]], + ] + ), + err_msg="Incorrect fancy transpose", ) def test_ftr(self): """Test ftr.""" npt.assert_array_equal( - np_adjust.ftr(self.data2), - np.array([ - [[0, 3, 6], [9, 12, 15]], - [[1, 4, 7], [10, 13, 16]], - [[2, 5, 8], [11, 14, 17]], - ]), - err_msg='Incorrect fancy transpose: ftr', + np_adjust.ftr(self.array233), + np.array( + [ + [[0, 3, 6], [9, 12, 15]], + [[1, 4, 7], [10, 13, 16]], + [[2, 5, 8], [11, 14, 17]], + ] + ), + err_msg="Incorrect fancy transpose: ftr", ) def test_ftl(self): - """Test ftl.""" - npt.assert_array_equal( - np_adjust.ftl(self.data2), - np.array([ - [[0, 9], [1, 10], [2, 11]], - [[3, 12], [4, 13], [5, 14]], - [[6, 15], [7, 16], [8, 17]], - ]), - err_msg='Incorrect fancy transpose: ftl', - ) - - -class TransformTestCase(TestCase): - """Test case for transform module.""" - - def setUp(self): - """Set test parameter values.""" - self.cube = np.arange(16).reshape((4, 2, 2)) - self.map = np.array( - [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]], - ) - self.matrix = np.array( - [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]], - ) - self.layout = (2, 2) - - def tearDown(self): - """Unset test parameter values.""" - self.cube = None - self.map = None - self.layout = None - - def test_cube2map(self): + """Test fancy transpose left.""" + npt.assert_array_equal( + np_adjust.ftl(self.array233), + np.array( + [ + [[0, 9], [1, 10], [2, 11]], + [[3, 12], [4, 13], [5, 14]], + [[6, 15], [7, 16], [8, 17]], + ] + ), + err_msg="Incorrect fancy transpose: ftl", + ) + + +class TestTransforms: + """Test for the transform module.""" + + cube = np.arange(16).reshape((4, 2, 2)) + map = np.array([[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]) + matrix = np.array([[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]) + layout = (2, 2) + fail_layout = (3, 3) + + @pytest.mark.parametrize( + ("func", "indata", "layout", "outdata"), + [ + (transform.cube2map, cube, layout, map), + failparam(transform.cube2map, np.eye(2), layout, map, raises=ValueError), + (transform.map2cube, map, layout, cube), + (transform.map2matrix, map, layout, matrix), + (transform.matrix2map, matrix, matrix.shape, map), + ], + ) + def test_map(self, func, indata, layout, outdata): """Test cube2map.""" npt.assert_array_equal( - transform.cube2map(self.cube, self.layout), - self.map, - err_msg='Incorrect transformation: cube2map', - ) - - npt.assert_raises( - ValueError, - transform.cube2map, - self.map, - self.layout, - ) - - npt.assert_raises(ValueError, transform.cube2map, self.cube, (3, 3)) - - def test_map2cube(self): - """Test map2cube.""" - npt.assert_array_equal( - transform.map2cube(self.map, self.layout), - self.cube, - err_msg='Incorrect transformation: map2cube', - ) - - npt.assert_raises(ValueError, transform.map2cube, self.map, (3, 3)) - - def test_map2matrix(self): - """Test map2matrix.""" - npt.assert_array_equal( - transform.map2matrix(self.map, self.layout), - self.matrix, - err_msg='Incorrect transformation: map2matrix', - ) - - def test_matrix2map(self): - """Test matrix2map.""" - npt.assert_array_equal( - transform.matrix2map(self.matrix, self.map.shape), - self.map, - err_msg='Incorrect transformation: matrix2map', + func(indata, layout), + outdata, ) + if func.__name__ != "map2matrix": + npt.assert_raises(ValueError, func, indata, self.fail_layout) def test_cube2matrix(self): """Test cube2matrix.""" npt.assert_array_equal( transform.cube2matrix(self.cube), self.matrix, - err_msg='Incorrect transformation: cube2matrix', ) def test_matrix2cube(self): @@ -194,136 +141,78 @@ def test_matrix2cube(self): npt.assert_array_equal( transform.matrix2cube(self.matrix, self.cube[0].shape), self.cube, - err_msg='Incorrect transformation: matrix2cube', - ) - - -class TypesTestCase(TestCase): - """Test case for types module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = list(range(5)) - self.data2 = np.arange(5) - self.data3 = np.arange(5).astype(float) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - - def test_check_float(self): - """Test check_float.""" - npt.assert_array_equal( - types.check_float(1.0), - 1.0, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_float(1), - 1.0, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_float(self.data1), - self.data3, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_float(self.data2), - self.data3, - err_msg='Float check failed', - ) - - npt.assert_raises(TypeError, types.check_float, '1') - - def test_check_int(self): - """Test check_int.""" - npt.assert_array_equal( - types.check_int(1), - 1, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_int(1.0), - 1, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_int(self.data1), - self.data2, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_int(self.data3), - self.data2, - err_msg='Int check failed', - ) - - npt.assert_raises(TypeError, types.check_int, '1') - - def test_check_npndarray(self): + err_msg="Incorrect transformation: matrix2cube", + ) + + +class TestType: + """Test for type module.""" + + data_list = list(range(5)) + data_int = np.arange(5) + data_flt = np.arange(5).astype(float) + + @pytest.mark.parametrize( + ("data", "checked"), + [ + (1.0, 1.0), + (1, 1.0), + (data_list, data_flt), + (data_int, data_flt), + failparam("1.0", 1.0, raises=TypeError), + ], + ) + def test_check_float(self, data, checked): + """Test check float.""" + npt.assert_array_equal(types.check_float(data), checked) + + @pytest.mark.parametrize( + ("data", "checked"), + [ + (1.0, 1), + (1, 1), + (data_list, data_int), + (data_flt, data_int), + failparam("1", None, raises=TypeError), + ], + ) + def test_check_int(self, data, checked): + """Test check int.""" + npt.assert_array_equal(types.check_int(data), checked) + + @pytest.mark.parametrize( + ("data", "dtype"), [(data_flt, np.integer), (data_int, np.floating)] + ) + def test_check_npndarray(self, data, dtype): """Test check_npndarray.""" npt.assert_raises( TypeError, types.check_npndarray, - self.data3, - dtype=np.integer, - ) - - -class TestBackend(TestCase): - """Test the backend codes.""" - - def setUp(self): - """Set test parameter values.""" - self.input = np.array([10, 10]) - - @skipIf(LIBRARIES['tensorflow'] is None, 'tensorflow library not installed') - def test_tf_backend(self): - """Test tensorflow backend.""" - xp, backend = get_backend('tensorflow') - if backend != 'tensorflow' or xp != LIBRARIES['tensorflow']: - raise AssertionError('tensorflow get_backend fails!') - tf_input = change_backend(self.input, 'tensorflow') - if ( - get_array_module(LIBRARIES['tensorflow'].ones(1)) != LIBRARIES['tensorflow'] - or get_array_module(tf_input) != LIBRARIES['tensorflow'] - ): - raise AssertionError('tensorflow backend fails!') - - @skipIf(LIBRARIES['cupy'] is None, 'cupy library not installed') - def test_cp_backend(self): - """Test cupy backend.""" - xp, backend = get_backend('cupy') - if backend != 'cupy' or xp != LIBRARIES['cupy']: - raise AssertionError('cupy get_backend fails!') - cp_input = change_backend(self.input, 'cupy') - if ( - get_array_module(LIBRARIES['cupy'].ones(1)) != LIBRARIES['cupy'] - or get_array_module(cp_input) != LIBRARIES['cupy'] - ): - raise AssertionError('cupy backend fails!') - - def test_np_backend(self): - """Test numpy backend.""" - xp, backend = get_backend('numpy') - if backend != 'numpy' or xp != LIBRARIES['numpy']: - raise AssertionError('numpy get_backend fails!') - np_input = change_backend(self.input, 'numpy') - if ( - get_array_module(LIBRARIES['numpy'].ones(1)) != LIBRARIES['numpy'] - or get_array_module(np_input) != LIBRARIES['numpy'] - ): - raise AssertionError('numpy backend fails!') - - def tearDown(self): - """Tear Down of objects.""" - self.input = None + data, + dtype=dtype, + ) + + def test_check_callable(self): + """Test callable.""" + npt.assert_raises(TypeError, types.check_callable, 1) + + +@pytest.mark.parametrize( + "backend_name", + [ + skipparam(name, cond=LIBRARIES[name] is None, reason=f"{name} not installed") + for name in LIBRARIES + ], +) +def test_tf_backend(backend_name): + """Test Modopt computational backends.""" + xp, checked_backend_name = backend.get_backend(backend_name) + if checked_backend_name != backend_name or xp != LIBRARIES[backend_name]: + raise AssertionError(f"{backend_name} get_backend fails!") + xp_input = backend.change_backend(np.array([10, 10]), backend_name) + if ( + backend.get_array_module(LIBRARIES[backend_name].ones(1)) + != backend.LIBRARIES[backend_name] + or backend.get_array_module(xp_input) != LIBRARIES[backend_name] + ): + raise AssertionError(f"{backend_name} backend fails!") diff --git a/modopt/tests/test_helpers/__init__.py b/modopt/tests/test_helpers/__init__.py new file mode 100644 index 00000000..3886b877 --- /dev/null +++ b/modopt/tests/test_helpers/__init__.py @@ -0,0 +1 @@ +from .utils import failparam, skipparam, Dummy diff --git a/modopt/tests/test_helpers/utils.py b/modopt/tests/test_helpers/utils.py new file mode 100644 index 00000000..d8227640 --- /dev/null +++ b/modopt/tests/test_helpers/utils.py @@ -0,0 +1,23 @@ +""" +Some helper functions for the test parametrization. +They should be used inside ``@pytest.mark.parametrize`` call. + +:Author: Pierre-Antoine Comby +""" +import pytest + + +def failparam(*args, raises=None): + """Return a pytest parameterization that should raise an error.""" + if not issubclass(raises, Exception): + raise ValueError("raises should be an expected Exception.") + return pytest.param(*args, marks=pytest.mark.raises(exception=raises)) + + +def skipparam(*args, cond=True, reason=""): + """Return a pytest parameterization that should be skip if cond is valid.""" + return pytest.param(*args, marks=pytest.mark.skipif(cond, reason=reason)) + + +class Dummy: + pass diff --git a/modopt/tests/test_math.py b/modopt/tests/test_math.py index ba175ae6..e44011c9 100644 --- a/modopt/tests/test_math.py +++ b/modopt/tests/test_math.py @@ -1,215 +1,181 @@ -# -*- coding: utf-8 -*- - """UNIT TESTS FOR MATH. This module contains unit tests for the modopt.math module. -:Author: Samuel Farrens - +:Authors: + Samuel Farrens + Pierre-Antoine Comby """ - -from unittest import TestCase, skipIf, skipUnless +import pytest +from test_helpers import failparam, skipparam import numpy as np import numpy.testing as npt + from modopt.math import convolve, matrix, metrics, stats try: import astropy except ImportError: # pragma: no cover - import_astropy = False + ASTROPY_AVAILABLE = False else: # pragma: no cover - import_astropy = True + ASTROPY_AVAILABLE = True try: from skimage.metrics import structural_similarity as compare_ssim except ImportError: # pragma: no cover - import_skimage = False + SKIMAGE_AVAILABLE = False else: - import_skimage = True - - -class ConvolveTestCase(TestCase): - """Test case for convolve module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(18).reshape(2, 3, 3) - self.data2 = self.data1 + 1 - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_convolve_astropy(self): - """Test convolve using astropy.""" - npt.assert_allclose( - convolve.convolve(self.data1[0], self.data2[0], method='astropy'), - np.array([ - [210.0, 201.0, 210.0], - [129.0, 120.0, 129.0], - [210.0, 201.0, 210.0], - ]), - err_msg='Incorrect convolution: astropy', - ) - - npt.assert_raises( - ValueError, - convolve.convolve, - self.data1[0], - self.data2, - ) - - npt.assert_raises( - ValueError, - convolve.convolve, - self.data1[0], - self.data2[0], - method='bla', - ) - - def test_convolve_scipy(self): - """Test convolve using scipy.""" - npt.assert_allclose( - convolve.convolve(self.data1[0], self.data2[0], method='scipy'), - np.array([ + SKIMAGE_AVAILABLE = True + + +class TestConvolve: + """Test convolve functions.""" + + array233 = np.arange(18).reshape((2, 3, 3)) + array233_1 = array233 + 1 + result_astropy = np.array( + [ + [210.0, 201.0, 210.0], + [129.0, 120.0, 129.0], + [210.0, 201.0, 210.0], + ] + ) + result_scipy = np.array( + [ + [ [14.0, 35.0, 38.0], [57.0, 120.0, 111.0], [110.0, 197.0, 158.0], - ]), - err_msg='Incorrect convolution: scipy', - ) - - def test_convolve_stack(self): - """Test convolve_stack.""" + ], + [ + [518.0, 845.0, 614.0], + [975.0, 1578.0, 1137.0], + [830.0, 1331.0, 950.0], + ], + ] + ) + + result_rot_kernel = np.array( + [ + [ + [66.0, 115.0, 82.0], + [153.0, 240.0, 159.0], + [90.0, 133.0, 82.0], + ], + [ + [714.0, 1087.0, 730.0], + [1125.0, 1698.0, 1131.0], + [738.0, 1105.0, 730.0], + ], + ] + ) + + @pytest.mark.parametrize( + ("input_data", "kernel", "method", "result"), + [ + skipparam( + array233[0], + array233_1[0], + "astropy", + result_astropy, + cond=not ASTROPY_AVAILABLE, + reason="astropy not available", + ), + failparam( + array233[0], array233_1, "astropy", result_astropy, raises=ValueError + ), + failparam( + array233[0], array233_1[0], "fail!", result_astropy, raises=ValueError + ), + (array233[0], array233_1[0], "scipy", result_scipy[0]), + ], + ) + def test_convolve(self, input_data, kernel, method, result): + """Test convolve function.""" + npt.assert_allclose(convolve.convolve(input_data, kernel, method), result) + + @pytest.mark.parametrize( + ("result", "rot_kernel"), + [ + (result_scipy, False), + (result_rot_kernel, True), + ], + ) + def test_convolve_stack(self, result, rot_kernel): + """Test convolve stack function.""" npt.assert_allclose( - convolve.convolve_stack(self.data1, self.data2), - np.array([ - [ - [14.0, 35.0, 38.0], - [57.0, 120.0, 111.0], - [110.0, 197.0, 158.0], - ], - [ - [518.0, 845.0, 614.0], - [975.0, 1578.0, 1137.0], - [830.0, 1331.0, 950.0], - ], - ]), - err_msg='Incorrect convolution: stack', + convolve.convolve_stack( + self.array233, self.array233_1, rot_kernel=rot_kernel + ), + result, ) - def test_convolve_stack_rot(self): - """Test convolve_stack rotated.""" - npt.assert_allclose( - convolve.convolve_stack(self.data1, self.data2, rot_kernel=True), - np.array([ - [ - [66.0, 115.0, 82.0], - [153.0, 240.0, 159.0], - [90.0, 133.0, 82.0], - ], - [ - [714.0, 1087.0, 730.0], - [1125.0, 1698.0, 1131.0], - [738.0, 1105.0, 730.0], - ], - ]), - err_msg='Incorrect convolution: stack rot', - ) +class TestMatrix: + """Test matrix module.""" -class MatrixTestCase(TestCase): - """Test case for matrix module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3) - self.data2 = np.arange(3) - self.data3 = np.arange(6).reshape(2, 3) - np.random.seed(1) - self.pmInstance1 = matrix.PowerMethod( - lambda x_val: x_val.dot(x_val.T), - self.data1.shape, - verbose=True, - ) - np.random.seed(1) - self.pmInstance2 = matrix.PowerMethod( - lambda x_val: x_val.dot(x_val.T), - self.data1.shape, - auto_run=False, - verbose=True, - ) - self.pmInstance2.get_spec_rad(max_iter=1) - self.gram_schmidt_out = ( - np.array([ + array3 = np.arange(3) + array33 = np.arange(9).reshape((3, 3)) + array23 = np.arange(6).reshape((2, 3)) + gram_schmidt_out = ( + np.array( + [ [0, 1.0, 2.0], [3.0, 1.2, -6e-1], [-1.77635684e-15, 0, 0], - ]), - np.array([ + ] + ), + np.array( + [ [0, 0.4472136, 0.89442719], [0.91287093, 0.36514837, -0.18257419], [-1.0, 0, 0], - ]), - ) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.pmInstance1 = None - self.pmInstance2 = None - self.gram_schmidt_out = None - - def test_gram_schmidt_orthonormal(self): - """Test gram_schmidt with orthonormal output.""" - npt.assert_allclose( - matrix.gram_schmidt(self.data1), - self.gram_schmidt_out[1], - err_msg='Incorrect Gram-Schmidt: orthonormal', - ) + ] + ), + ) - npt.assert_raises( - ValueError, - matrix.gram_schmidt, - self.data1, - return_opt='bla', - ) - - def test_gram_schmidt_orthogonal(self): - """Test gram_schmidt with orthogonal output.""" - npt.assert_allclose( - matrix.gram_schmidt(self.data1, return_opt='orthogonal'), - self.gram_schmidt_out[0], - err_msg='Incorrect Gram-Schmidt: orthogonal', + @pytest.fixture + def pm_instance(self, request): + """Power Method instance.""" + np.random.seed(1) + pm = matrix.PowerMethod( + lambda x_val: x_val.dot(x_val.T), + self.array33.shape, + auto_run=request.param, + verbose=True, ) - - def test_gram_schmidt_both(self): - """Test gram_schmidt with both outputs.""" + if not request.param: + pm.get_spec_rad(max_iter=1) + return pm + + @pytest.mark.parametrize( + ("return_opt", "output"), + [ + ("orthonormal", gram_schmidt_out[1]), + ("orthogonal", gram_schmidt_out[0]), + ("both", gram_schmidt_out), + failparam("fail!", gram_schmidt_out, raises=ValueError), + ], + ) + def test_gram_schmidt(self, return_opt, output): + """Test gram schmidt.""" npt.assert_allclose( - matrix.gram_schmidt(self.data1, return_opt='both'), - self.gram_schmidt_out, - err_msg='Incorrect Gram-Schmidt: both', + matrix.gram_schmidt(self.array33, return_opt=return_opt), output ) def test_nuclear_norm(self): - """Test nuclear_norm.""" + """Test nuclear norm.""" npt.assert_almost_equal( - matrix.nuclear_norm(self.data1), + matrix.nuclear_norm(self.array33), 15.49193338482967, - err_msg='Incorrect nuclear norm', ) def test_project(self): """Test project.""" npt.assert_array_equal( - matrix.project(self.data2, self.data2 + 3), + matrix.project(self.array3, self.array3 + 3), np.array([0, 2.8, 5.6]), - err_msg='Incorrect projection', ) def test_rot_matrix(self): @@ -217,280 +183,159 @@ def test_rot_matrix(self): npt.assert_allclose( matrix.rot_matrix(np.pi / 6), np.array([[0.8660254, -0.5], [0.5, 0.8660254]]), - err_msg='Incorrect rotation matrix', ) def test_rotate(self): """Test rotate.""" npt.assert_array_equal( - matrix.rotate(self.data1, np.pi / 2), + matrix.rotate(self.array33, np.pi / 2), np.array([[2, 5, 8], [1, 4, 7], [0, 3, 6]]), - err_msg='Incorrect rotation', - ) - - npt.assert_raises(ValueError, matrix.rotate, self.data3, np.pi / 2) - - def test_powermethod_converged(self): - """Test PowerMethod converged.""" - npt.assert_almost_equal( - self.pmInstance1.spec_rad, - 1.0, - err_msg='Incorrect spectral radius: converged', ) - npt.assert_almost_equal( - self.pmInstance1.inv_spec_rad, - 1.0, - err_msg='Incorrect inverse spectral radius: converged', - ) - - def test_powermethod_unconverged(self): - """Test PowerMethod unconverged.""" - npt.assert_almost_equal( - self.pmInstance2.spec_rad, - 0.8675467477372257, - err_msg='Incorrect spectral radius: unconverged', - ) - - npt.assert_almost_equal( - self.pmInstance2.inv_spec_rad, - 1.152675636913221, - err_msg='Incorrect inverse spectral radius: unconverged', - ) - - -class MetricsTestCase(TestCase): - """Test case for metrics module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(49).reshape(7, 7) - self.mask = np.ones(self.data1.shape) - self.ssim_res = 0.8963363560519094 - self.ssim_mask_res = 0.805154442543846 - self.snr_res = 10.134554256920536 - self.psnr_res = 14.860761791850397 - self.mse_res = 0.03265305507330247 - self.nrmse_res = 0.31136678840022625 - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.mask = None - self.ssim_res = None - self.ssim_mask_res = None - self.psnr_res = None - self.mse_res = None - self.nrmse_res = None - - @skipIf(import_skimage, 'skimage is installed.') # pragma: no cover - def test_ssim_skimage_error(self): - """Test ssim skimage error.""" - npt.assert_raises(ImportError, metrics.ssim, self.data1, self.data1) - - @skipUnless(import_skimage, 'skimage not installed.') # pragma: no cover - def test_ssim(self): + npt.assert_raises(ValueError, matrix.rotate, self.array23, np.pi / 2) + + @pytest.mark.parametrize( + ("pm_instance", "value"), + [(True, 1.0), (False, 0.8675467477372257)], + indirect=["pm_instance"], + ) + def test_power_method(self, pm_instance, value): + """Test power method.""" + npt.assert_almost_equal(pm_instance.spec_rad, value) + npt.assert_almost_equal(pm_instance.inv_spec_rad, 1 / value) + + +class TestMetrics: + """Test metrics module.""" + + data1 = np.arange(49).reshape(7, 7) + mask = np.ones(data1.shape) + ssim_res = 0.8963363560519094 + ssim_mask_res = 0.805154442543846 + snr_res = 10.134554256920536 + psnr_res = 14.860761791850397 + mse_res = 0.03265305507330247 + nrmse_res = 0.31136678840022625 + + @pytest.mark.skipif(not SKIMAGE_AVAILABLE, reason="skimage not installed") + @pytest.mark.parametrize( + ("data1", "data2", "result", "mask"), + [ + (data1, data1**2, ssim_res, None), + (data1, data1**2, ssim_mask_res, mask), + failparam(data1, data1, None, 1, raises=ValueError), + ], + ) + def test_ssim(self, data1, data2, result, mask): """Test ssim.""" - npt.assert_almost_equal( - metrics.ssim(self.data1, self.data1 ** 2), - self.ssim_res, - err_msg='Incorrect SSIM result', - ) + npt.assert_almost_equal(metrics.ssim(data1, data2, mask=mask), result) - npt.assert_almost_equal( - metrics.ssim(self.data1, self.data1 ** 2, mask=self.mask), - self.ssim_mask_res, - err_msg='Incorrect SSIM result', - ) - - npt.assert_raises( - ValueError, - metrics.ssim, - self.data1, - self.data1, - mask=1, - ) + @pytest.mark.skipif(SKIMAGE_AVAILABLE, reason="skimage installed") + def test_ssim_fail(self): + """Test ssim.""" + npt.assert_raises(ImportError, metrics.ssim, self.data1, self.data1) - def test_snr(self): + @pytest.mark.parametrize( + ("metric", "data", "result", "mask"), + [ + (metrics.snr, data1, snr_res, None), + (metrics.snr, data1, snr_res, mask), + (metrics.psnr, data1, psnr_res, None), + (metrics.psnr, data1, psnr_res, mask), + (metrics.mse, data1, mse_res, None), + (metrics.mse, data1, mse_res, mask), + (metrics.nrmse, data1, nrmse_res, None), + (metrics.nrmse, data1, nrmse_res, mask), + failparam(metrics.snr, data1, snr_res, "maskfail", raises=ValueError), + ], + ) + def test_metric(self, metric, data, result, mask): """Test snr.""" - npt.assert_almost_equal( - metrics.snr(self.data1, self.data1 ** 2), - self.snr_res, - err_msg='Incorrect SNR result', - ) - - npt.assert_almost_equal( - metrics.snr(self.data1, self.data1 ** 2, mask=self.mask), - self.snr_res, - err_msg='Incorrect SNR result', - ) - - def test_psnr(self): - """Test psnr.""" - npt.assert_almost_equal( - metrics.psnr(self.data1, self.data1 ** 2), - self.psnr_res, - err_msg='Incorrect PSNR result', - ) - - npt.assert_almost_equal( - metrics.psnr(self.data1, self.data1 ** 2, mask=self.mask), - self.psnr_res, - err_msg='Incorrect PSNR result', - ) - - def test_mse(self): - """Test mse.""" - npt.assert_almost_equal( - metrics.mse(self.data1, self.data1 ** 2), - self.mse_res, - err_msg='Incorrect MSE result', - ) - - npt.assert_almost_equal( - metrics.mse(self.data1, self.data1 ** 2, mask=self.mask), - self.mse_res, - err_msg='Incorrect MSE result', - ) - - def test_nrmse(self): - """Test nrmse.""" - npt.assert_almost_equal( - metrics.nrmse(self.data1, self.data1 ** 2), - self.nrmse_res, - err_msg='Incorrect NRMSE result', - ) - - npt.assert_almost_equal( - metrics.nrmse(self.data1, self.data1 ** 2, mask=self.mask), - self.nrmse_res, - err_msg='Incorrect NRMSE result', - ) - - -class StatsTestCase(TestCase): - """Test case for stats module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3) - self.data2 = np.arange(18).reshape(2, 3, 3) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - - @skipIf(import_astropy, 'Astropy is installed.') # pragma: no cover - def test_gaussian_kernel_astropy_error(self): - """Test gaussian_kernel astropy error.""" - npt.assert_raises( - ImportError, - stats.gaussian_kernel, - self.data1.shape, - 1, - ) - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_gaussian_kernel_max(self): - """Test gaussian_kernel with max norm.""" + npt.assert_almost_equal(metric(data, data**2, mask=mask), result) + + +class TestStats: + """Test stats module.""" + + array33 = np.arange(9).reshape(3, 3) + array233 = np.arange(18).reshape(2, 3, 3) + + @pytest.mark.skipif(not ASTROPY_AVAILABLE, reason="astropy not installed") + @pytest.mark.parametrize( + ("norm", "result"), + [ + ( + "max", + np.array( + [ + [0.36787944, 0.60653066, 0.36787944], + [0.60653066, 1.0, 0.60653066], + [0.36787944, 0.60653066, 0.36787944], + ] + ), + ), + ( + "sum", + np.array( + [ + [0.07511361, 0.1238414, 0.07511361], + [0.1238414, 0.20417996, 0.1238414], + [0.07511361, 0.1238414, 0.07511361], + ] + ), + ), + ( + "none", + np.array( + [ + [0.05854983, 0.09653235, 0.05854983], + [0.09653235, 0.15915494, 0.09653235], + [0.05854983, 0.09653235, 0.05854983], + ] + ), + ), + failparam("fail", None, raises=ValueError), + ], + ) + def test_gaussian_kernel(self, norm, result): + """Test Gaussian kernel.""" npt.assert_allclose( - stats.gaussian_kernel(self.data1.shape, 1), - np.array([ - [0.36787944, 0.60653066, 0.36787944], - [0.60653066, 1.0, 0.60653066], - [0.36787944, 0.60653066, 0.36787944], - ]), - err_msg='Incorrect gaussian kernel: max norm', + stats.gaussian_kernel(self.array33.shape, 1, norm=norm), result ) - npt.assert_raises( - ValueError, - stats.gaussian_kernel, - self.data1.shape, - 1, - norm='bla', - ) - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_gaussian_kernel_sum(self): - """Test gaussian_kernel with sum norm.""" - npt.assert_allclose( - stats.gaussian_kernel(self.data1.shape, 1, norm='sum'), - np.array([ - [0.07511361, 0.1238414, 0.07511361], - [0.1238414, 0.20417996, 0.1238414], - [0.07511361, 0.1238414, 0.07511361], - ]), - err_msg='Incorrect gaussian kernel: sum norm', - ) - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_gaussian_kernel_none(self): - """Test gaussian_kernel with no norm.""" - npt.assert_allclose( - stats.gaussian_kernel(self.data1.shape, 1, norm='none'), - np.array([ - [0.05854983, 0.09653235, 0.05854983], - [0.09653235, 0.15915494, 0.09653235], - [0.05854983, 0.09653235, 0.05854983], - ]), - err_msg='Incorrect gaussian kernel: sum norm', - ) + @pytest.mark.skipif(ASTROPY_AVAILABLE, reason="astropy installed") + def test_import_astropy(self): + """Test missing astropy.""" + npt.assert_raises(ImportError, stats.gaussian_kernel, self.array33.shape, 1) def test_mad(self): """Test mad.""" - npt.assert_equal( - stats.mad(self.data1), - 2.0, - err_msg='Incorrect median absolute deviation', - ) - - def test_mse(self): - """Test mse.""" - npt.assert_equal( - stats.mse(self.data1, self.data1 + 2), - 4.0, - err_msg='Incorrect mean squared error', - ) + npt.assert_equal(stats.mad(self.array33), 2.0) - def test_psnr_starck(self): - """Test psnr.""" + def test_sigma_mad(self): + """Test sigma_mad.""" npt.assert_almost_equal( - stats.psnr(self.data1, self.data1 + 2), - 12.041199826559248, - err_msg='Incorrect PSNR: starck', - ) - - npt.assert_raises( - ValueError, - stats.psnr, - self.data1, - self.data1, - method='bla', + stats.sigma_mad(self.array33), + 2.9651999999999998, ) - def test_psnr_wiki(self): - """Test psnr wiki method.""" - npt.assert_almost_equal( - stats.psnr(self.data1, self.data1 + 2, method='wiki'), - 42.110203695399477, - err_msg='Incorrect PSNR: wiki', - ) + @pytest.mark.parametrize( + ("data1", "data2", "method", "result"), + [ + (array33, array33 + 2, "starck", 12.041199826559248), + failparam(array33, array33, "fail", 0, raises=ValueError), + (array33, array33 + 2, "wiki", 42.110203695399477), + ], + ) + def test_psnr(self, data1, data2, method, result): + """Test PSNR.""" + npt.assert_almost_equal(stats.psnr(data1, data2, method=method), result) def test_psnr_stack(self): """Test psnr stack.""" npt.assert_almost_equal( - stats.psnr_stack(self.data2, self.data2 + 2), + stats.psnr_stack(self.array233, self.array233 + 2), 12.041199826559248, - err_msg='Incorrect PSNR stack', ) - npt.assert_raises(ValueError, stats.psnr_stack, self.data1, self.data1) - - def test_sigma_mad(self): - """Test sigma_mad.""" - npt.assert_almost_equal( - stats.sigma_mad(self.data1), - 2.9651999999999998, - err_msg='Incorrect sigma from MAD', - ) + npt.assert_raises(ValueError, stats.psnr_stack, self.array33, self.array33) diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index d5547783..0e45ffb8 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -1,718 +1,275 @@ -# -*- coding: utf-8 -*- - """UNIT TESTS FOR OPT. -This module contains unit tests for the modopt.opt module. - -:Author: Samuel Farrens +This module contains tests for the modopt.opt module. +:Authors: + Samuel Farrens + Pierre-Antoine Comby """ -from builtins import zip -from unittest import TestCase, skipIf, skipUnless - import numpy as np import numpy.testing as npt +import pytest +from pytest_cases import parametrize, parametrize_with_cases, case, fixture, fixture_ref + +from modopt.opt import cost, gradient, linear, proximity, reweight -from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight +from test_helpers import Dummy +SKLEARN_AVAILABLE = True try: import sklearn -except ImportError: # pragma: no cover - import_sklearn = False -else: - import_sklearn = True +except ImportError: + SKLEARN_AVAILABLE = False # Basic functions to be used as operators or as dummy functions func_identity = lambda x_val: x_val func_double = lambda x_val: x_val * 2 -func_sq = lambda x_val: x_val ** 2 -func_cube = lambda x_val: x_val ** 3 - - -class Dummy(object): - """Dummy class for tests.""" - - pass - - -class AlgorithmTestCase(TestCase): - """Test case for algorithms module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = self.data1 + np.random.randn(*self.data1.shape) * 1e-6 - self.data3 = np.arange(9).reshape(3, 3).astype(float) + 1 - - grad_inst = gradient.GradBasic( - self.data1, - func_identity, - func_identity, - ) - - prox_inst = proximity.Positivity() - prox_dual_inst = proximity.IdentityProx() - linear_inst = linear.Identity() - reweight_inst = reweight.cwbReweight(self.data3) - cost_inst = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) - self.setup = algorithms.SetUp() - self.max_iter = 20 - - self.fb_all_iter = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=None, - auto_iterate=False, - beta_update=func_identity, - ) - self.fb_all_iter.iterate(self.max_iter) - - self.fb1 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - ) - - self.fb2 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - lambda_update=None, - ) - - self.fb3 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - a_cd=3, - ) - - self.fb4 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - r_lazy=3, - p_lazy=0.7, - q_lazy=0.7, - ) - - self.fb5 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='adaptive', - xi_restart=0.9, - ) - - self.fb6 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='greedy', - xi_restart=0.9, - min_beta=1.0, - s_greedy=1.1, - ) - - self.gfb_all_iter = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=None, - auto_iterate=False, - gamma_update=func_identity, - beta_update=func_identity, - ) - self.gfb_all_iter.iterate(self.max_iter) - - self.gfb1 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - gamma_update=func_identity, - lambda_update=func_identity, - ) - - self.gfb2 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - ) - - self.gfb3 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - step_size=2, - ) - - self.condat_all_iter = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - cost=None, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - auto_iterate=False, - ) - self.condat_all_iter.iterate(self.max_iter) - - self.condat1 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - ) - - self.condat2 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - linear=linear_inst, - cost=cost_inst, - reweight=reweight_inst, - ) - - self.condat3 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - linear=Dummy(), - cost=cost_inst, - auto_iterate=False, - ) - - self.pogm_all_iter = algorithms.POGM( - u=self.data1, - x=self.data1, - y=self.data1, - z=self.data1, - grad=grad_inst, - prox=prox_inst, - auto_iterate=False, - cost=None, - ) - self.pogm_all_iter.iterate(self.max_iter) - - self.pogm1 = algorithms.POGM( - u=self.data1, - x=self.data1, - y=self.data1, - z=self.data1, - grad=grad_inst, - prox=prox_inst, - ) - - self.dummy = Dummy() - self.dummy.cost = func_identity - self.setup._check_operator(self.dummy.cost) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.setup = None - self.fb_all_iter = None - self.fb1 = None - self.fb2 = None - self.gfb_all_iter = None - self.gfb1 = None - self.gfb2 = None - self.condat_all_iter = None - self.condat1 = None - self.condat2 = None - self.condat3 = None - self.pogm1 = None - self.pogm_all_iter = None - self.dummy = None - - def test_set_up(self): - """Test set_up.""" - npt.assert_raises(TypeError, self.setup._check_input_data, 1) - - npt.assert_raises(TypeError, self.setup._check_param, 1) - - npt.assert_raises(TypeError, self.setup._check_param_update, 1) - - def test_all_iter(self): - """Test if all opt run for all iterations.""" - opts = [ - self.fb_all_iter, - self.gfb_all_iter, - self.condat_all_iter, - self.pogm_all_iter, - ] - for opt in opts: - npt.assert_equal(opt.idx, self.max_iter - 1) - - def test_forward_backward(self): - """Test forward_backward.""" - npt.assert_array_equal( - self.fb1.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb2.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb3.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb4.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb5.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb6.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - def test_gen_forward_backward(self): - """Test gen_forward_backward.""" - npt.assert_array_equal( - self.gfb1.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb2.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb3.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_equal( - self.gfb3.step_size, - 2, - err_msg='Incorrect step size.', - ) - - npt.assert_raises( - TypeError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=1, - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[1], - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5, 0.5], - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5], - ) - - def test_condat(self): - """Test gen_condat.""" - npt.assert_almost_equal( - self.condat1.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) - - npt.assert_almost_equal( - self.condat2.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) +func_sq = lambda x_val: x_val**2 +func_cube = lambda x_val: x_val**3 + + +@case(tags="cost") +@parametrize( + ("cost_interval", "n_calls", "converged"), + [(1, 1, False), (1, 2, True), (2, 5, False), (None, 6, False)], +) +def case_cost_op(cost_interval, n_calls, converged): + """Case function for costs.""" + dummy_inst1 = Dummy() + dummy_inst1.cost = func_sq + dummy_inst2 = Dummy() + dummy_inst2.cost = func_cube + + cost_obj = cost.costObj([dummy_inst1, dummy_inst2], cost_interval=cost_interval) + + for _ in range(n_calls + 1): + cost_obj.get_cost(2) + return cost_obj, converged + + +@parametrize_with_cases("cost_obj, converged", cases=".", has_tag="cost") +def test_costs(cost_obj, converged): + """Test cost.""" + npt.assert_equal(cost_obj.get_cost(2), converged) + if cost_obj._cost_interval: + npt.assert_equal(cost_obj.cost, 12) + + +def test_raise_cost(): + """Test error raising for cost.""" + npt.assert_raises(TypeError, cost.costObj, 1) + npt.assert_raises(ValueError, cost.costObj, [Dummy(), Dummy()]) + + +@case(tags="grad") +@parametrize(call=("op", "trans_op", "trans_op_op")) +def case_grad_parent(call): + """Case for gradient parent.""" + input_data = np.arange(9).reshape(3, 3) + callables = { + "op": func_sq, + "trans_op": func_cube, + "get_grad": func_identity, + "cost": lambda input_val: 1.0, + } + + grad_op = gradient.GradParent( + input_data, + **callables, + data_type=np.floating, + ) + if call != "trans_op_op": + result = callables[call](input_data) + else: + result = callables["trans_op"](callables["op"](input_data)) + + grad_call = getattr(grad_op, call)(input_data) + return grad_call, result + + +@parametrize_with_cases("grad_values, result", cases=".", has_tag="grad") +def test_grad_op(grad_values, result): + """Test Gradient operator.""" + npt.assert_equal(grad_values, result) + + +@pytest.fixture +def grad_basic(): + """Case for GradBasic.""" + input_data = np.arange(9).reshape(3, 3) + grad_op = gradient.GradBasic( + input_data, + func_sq, + func_cube, + verbose=True, + ) + grad_op.get_grad(input_data) + return grad_op + + +def test_grad_basic(grad_basic): + """Test grad basic.""" + npt.assert_array_equal( + grad_basic.grad, + np.array( + [ + [0, 0, 8.0], + [2.16000000e2, 1.72800000e3, 8.0e3], + [2.70000000e4, 7.40880000e4, 1.75616000e5], + ] + ), + err_msg="Incorrect gradient.", + ) - def test_pogm(self): - """Test pogm.""" - npt.assert_almost_equal( - self.pogm1.x_final, - self.data1, - err_msg='Incorrect POGM result.', - ) +def test_grad_basic_cost(grad_basic): + """Test grad_basic cost.""" + npt.assert_almost_equal(grad_basic.cost(np.arange(9).reshape(3, 3)), 3192.0) -class CostTestCase(TestCase): - """Test case for cost module.""" - def setUp(self): - """Set test parameter values.""" - dummy_inst1 = Dummy() - dummy_inst1.cost = func_sq - dummy_inst2 = Dummy() - dummy_inst2.cost = func_cube +def test_grad_op_raises(): + """Test raise error.""" + npt.assert_raises( + TypeError, + gradient.GradParent, + 1, + func_sq, + func_cube, + ) - self.inst1 = cost.costObj([dummy_inst1, dummy_inst2]) - self.inst2 = cost.costObj([dummy_inst1, dummy_inst2], cost_interval=2) - # Test that by default cost of False if interval is None - self.inst_none = cost.costObj( - [dummy_inst1, dummy_inst2], - cost_interval=None, - ) - for _ in range(2): - self.inst1.get_cost(2) - for _ in range(6): - self.inst2.get_cost(2) - self.inst_none.get_cost(2) - self.dummy = Dummy() - - def tearDown(self): - """Unset test parameter values.""" - self.inst = None - - def test_cost_object(self): - """Test cost_object.""" - npt.assert_equal( - self.inst1.get_cost(2), - False, - err_msg='Incorrect cost test result.', - ) - npt.assert_equal( - self.inst1.get_cost(2), - True, - err_msg='Incorrect cost test result.', - ) - npt.assert_equal( - self.inst_none.get_cost(2), - False, - err_msg='Incorrect cost test result.', - ) - npt.assert_equal(self.inst1.cost, 12, err_msg='Incorrect cost value.') +############# +# LINEAR OP # +############# - npt.assert_equal(self.inst2.cost, 12, err_msg='Incorrect cost value.') - npt.assert_raises(TypeError, cost.costObj, 1) +class LinearCases: + """Linear operator cases.""" - npt.assert_raises(ValueError, cost.costObj, [self.dummy, self.dummy]) + def case_linear_identity(self): + """Case linear operator identity.""" + linop = linear.Identity() + data_op, data_adj_op, res_op, res_adj_op = 1, 1, 1, 1 -class GradientTestCase(TestCase): - """Test case for gradient module.""" + return linop, data_op, data_adj_op, res_op, res_adj_op - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.gp = gradient.GradParent( - self.data1, - func_sq, - func_cube, - func_identity, - lambda input_val: 1.0, - data_type=np.floating, - ) - self.gp.grad = self.gp.get_grad(self.data1) - self.gb = gradient.GradBasic( - self.data1, - func_sq, - func_cube, - ) - self.gb.get_grad(self.data1) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.gp = None - self.gb = None - - def test_grad_parent_operators(self): - """Test GradParent.""" - npt.assert_array_equal( - self.gp.op(self.data1), - np.array([[0, 1.0, 4.0], [9.0, 16.0, 25.0], [36.0, 49.0, 64.0]]), - err_msg='Incorrect gradient operation.', - ) - - npt.assert_array_equal( - self.gp.trans_op(self.data1), - np.array( - [[0, 1.0, 8.0], [27.0, 64.0, 125.0], [216.0, 343.0, 512.0]], - ), - err_msg='Incorrect gradient transpose operation.', + def case_linear_wavelet(self): + """Case linear operator wavelet.""" + linop = linear.WaveletConvolve( + filters=np.arange(8).reshape(2, 2, 2).astype(float) ) + data_op = np.arange(4).reshape(1, 2, 2).astype(float) + data_adj_op = np.arange(8).reshape(1, 2, 2, 2).astype(float) + res_op = np.array([[[[0, 0], [0, 4.0]], [[0, 4.0], [8.0, 28.0]]]]) + res_adj_op = np.array([[[28.0, 62.0], [68.0, 140.0]]]) - npt.assert_array_equal( - self.gp.trans_op_op(self.data1), - np.array([ - [0, 1.0, 6.40000000e1], - [7.29000000e2, 4.09600000e3, 1.56250000e4], - [4.66560000e4, 1.17649000e5, 2.62144000e5], - ]), - err_msg='Incorrect gradient transpose operation operation.', - ) - - npt.assert_equal( - self.gp.cost(self.data1), - 1.0, - err_msg='Incorrect cost.', - ) + return linop, data_op, data_adj_op, res_op, res_adj_op - npt.assert_raises( - TypeError, - gradient.GradParent, - 1, + @parametrize(weights=[[1.0, 1.0], None]) + def case_linear_combo(self, weights): + """Case linear operator combo with weights.""" + parent = linear.LinearParent( func_sq, func_cube, ) + linop = linear.LinearCombo([parent, parent], weights) - def test_grad_basic_gradient(self): - """Test GradBasic.""" - npt.assert_array_equal( - self.gb.grad, - np.array([ - [0, 0, 8.0], - [2.16000000e2, 1.72800000e3, 8.0e3], - [2.70000000e4, 7.40880000e4, 1.75616000e5], - ]), - err_msg='Incorrect gradient.', + data_op, data_adj_op, res_op, res_adj_op = ( + 2, + np.array([2, 2]), + np.array([4, 4]), + 8.0 * (2 if weights else 1), ) + return linop, data_op, data_adj_op, res_op, res_adj_op -class LinearTestCase(TestCase): - """Test case for linear module.""" + @parametrize(factor=[1, 1 + 1j]) + def case_linear_matrix(self, factor): + """Case linear operator from matrix.""" + linop = linear.MatrixOperator(np.eye(5) * factor) + data_op = np.arange(5) + data_adj_op = np.arange(5) + res_op = np.arange(5) * factor + res_adj_op = np.arange(5) * np.conjugate(factor) - def setUp(self): - """Set test parameter values.""" - self.parent = linear.LinearParent( - func_sq, - func_cube, - ) - self.ident = linear.Identity() - filters = np.arange(8).reshape(2, 2, 2).astype(float) - self.wave = linear.WaveletConvolve(filters) - self.combo = linear.LinearCombo([self.parent, self.parent]) - self.combo_weight = linear.LinearCombo( - [self.parent, self.parent], - [1.0, 1.0], - ) - self.data1 = np.arange(18).reshape(2, 3, 3).astype(float) - self.data2 = np.arange(4).reshape(1, 2, 2).astype(float) - self.data3 = np.arange(8).reshape(1, 2, 2, 2).astype(float) - self.data4 = np.array([[[[0, 0], [0, 4.0]], [[0, 4.0], [8.0, 28.0]]]]) - self.data5 = np.array([[[28.0, 62.0], [68.0, 140.0]]]) - self.dummy = Dummy() - - def tearDown(self): - """Unset test parameter values.""" - self.parent = None - self.ident = None - self.combo = None - self.combo_weight = None - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - self.dummy = None - - def test_linear_parent(self): - """Test LinearParent.""" - npt.assert_equal( - self.parent.op(2), - 4, - err_msg='Incorrect linear parent operation.', - ) + return linop, data_op, data_adj_op, res_op, res_adj_op - npt.assert_equal( - self.parent.adj_op(2), - 8, - err_msg='Incorrect linear parent adjoint operation.', - ) - npt.assert_raises(TypeError, linear.LinearParent, 0, 0) +@fixture +@parametrize_with_cases( + "linop, data_op, data_adj_op, res_op, res_adj_op", cases=LinearCases +) +def lin_adj_op(linop, data_op, data_adj_op, res_op, res_adj_op): + """Get adj_op relative data.""" + return linop.adj_op, data_adj_op, res_adj_op - def test_identity(self): - """Test Identity.""" - npt.assert_equal( - self.ident.op(1.0), - 1.0, - err_msg='Incorrect identity operation.', - ) - npt.assert_equal( - self.ident.adj_op(1.0), - 1.0, - err_msg='Incorrect identity adjoint operation.', - ) +@fixture +@parametrize_with_cases( + "linop, data_op, data_adj_op, res_op, res_adj_op", cases=LinearCases +) +def lin_op(linop, data_op, data_adj_op, res_op, res_adj_op): + """Get op relative data.""" + return linop.op, data_op, res_op - def test_wavelet_convolve(self): - """Test WaveletConvolve.""" - npt.assert_almost_equal( - self.wave.op(self.data2), - self.data4, - err_msg='Incorrect wavelet convolution operation.', - ) - npt.assert_almost_equal( - self.wave.adj_op(self.data3), - self.data5, - err_msg='Incorrect wavelet convolution adjoint operation.', - ) +@parametrize( + ("action", "data", "result"), [fixture_ref(lin_op), fixture_ref(lin_adj_op)] +) +def test_linear_operator(action, data, result): + """Test linear operator.""" + npt.assert_almost_equal(action(data), result) - def test_linear_combo(self): - """Test LinearCombo.""" - npt.assert_equal( - self.combo.op(2), - np.array([4, 4]).astype(object), - err_msg='Incorrect combined linear operation', - ) - npt.assert_equal( - self.combo.adj_op([2, 2]), - 8.0, - err_msg='Incorrect combined linear adjoint operation', - ) +dummy_with_op = Dummy() +dummy_with_op.op = lambda x: x - npt.assert_raises(TypeError, linear.LinearCombo, self.parent) - npt.assert_raises(ValueError, linear.LinearCombo, []) +@pytest.mark.parametrize( + ("args", "error"), + [ + ([linear.LinearParent(func_sq, func_cube)], TypeError), + ([[]], ValueError), + ([[Dummy()]], ValueError), + ([[dummy_with_op]], ValueError), + ([[]], ValueError), + ([[linear.LinearParent(func_sq, func_cube)] * 2, [1.0]], ValueError), + ([[linear.LinearParent(func_sq, func_cube)] * 2, ["1", "1"]], TypeError), + ], +) +def test_linear_combo_errors(args, error): + """Test linear combo_errors.""" + npt.assert_raises(error, linear.LinearCombo, *args) - npt.assert_raises(ValueError, linear.LinearCombo, [self.dummy]) - self.dummy.op = func_identity +############# +# Proximity # +############# - npt.assert_raises(ValueError, linear.LinearCombo, [self.dummy]) - def test_linear_combo_weight(self): - """Test LinearCombo with weight .""" - npt.assert_equal( - self.combo_weight.op(2), - np.array([4, 4]).astype(object), - err_msg='Incorrect combined linear operation', - ) - - npt.assert_equal( - self.combo_weight.adj_op([2, 2]), - 16.0, - err_msg='Incorrect combined linear adjoint operation', - ) +class ProxCases: + """Class containing all proximal operator cases. - npt.assert_raises( - ValueError, - linear.LinearCombo, - [self.parent, self.parent], - [1.0], - ) - - npt.assert_raises( - TypeError, - linear.LinearCombo, - [self.parent, self.parent], - ['1', '1'], - ) + Each case should return 4 parameters: + 1. The proximal operator + 2. test input data + 3. Expected result data + 4. Expected cost value. + """ + weights = np.ones(9).reshape(3, 3).astype(float) * 3 + array33 = np.arange(9).reshape(3, 3).astype(float) + array33_st = np.array([[-0, -0, -0], [0, 1.0, 2.0], [3.0, 4.0, 5.0]]) + array33_st2 = array33_st * -1 -class ProximityTestCase(TestCase): - """Test case for proximity module.""" + array33_support = np.asarray([[0, 0, 0], [0, 1.0, 1.25], [1.5, 1.75, 2.0]]) - def setUp(self): - """Set test parameter values.""" - self.parent = proximity.ProximityParent( - func_sq, - func_double, - ) - self.identity = proximity.IdentityProx() - self.positivity = proximity.Positivity() - weights = np.ones(9).reshape(3, 3).astype(float) * 3 - self.sparsethresh = proximity.SparseThreshold( - linear.Identity(), - weights, - ) - self.lowrank = proximity.LowRankMatrix(10.0, thresh_type='hard') - self.lowrank_rank = proximity.LowRankMatrix( - 10.0, - initial_rank=1, - thresh_type='hard', - ) - self.lowrank_ngole = proximity.LowRankMatrix( - 10.0, - lowr_type='ngole', - operator=func_double, - ) - self.linear_comp = proximity.LinearCompositionProx( - linear_op=linear.Identity(), - prox_op=self.sparsethresh, - ) - self.combo = proximity.ProximityCombo([self.identity, self.positivity]) - if import_sklearn: - self.owl = proximity.OrderedWeightedL1Norm(weights.flatten()) - self.ridge = proximity.Ridge(linear.Identity(), weights) - self.elasticnet_alpha0 = proximity.ElasticNet( - linear.Identity(), - alpha=0, - beta=weights, - ) - self.elasticnet_beta0 = proximity.ElasticNet( - linear.Identity(), - alpha=weights, - beta=0, - ) - self.one_support = proximity.KSupportNorm(beta=0.2, k_value=1) - self.five_support_norm = proximity.KSupportNorm(beta=3, k_value=5) - self.d_support = proximity.KSupportNorm(beta=3.0 * 2, k_value=19) - self.group_lasso = proximity.GroupLASSO( - weights=np.tile(weights, (4, 1, 1)), - ) - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = np.array([[-0, -0, -0], [0, 1.0, 2.0], [3.0, 4.0, 5.0]]) - self.data3 = np.arange(18).reshape(2, 3, 3).astype(float) - self.data4 = np.array([ + array233 = np.arange(18).reshape(2, 3, 3).astype(float) + array233_2 = np.array( + [ [ [2.73843189, 3.14594066, 3.55344943], [3.9609582, 4.36846698, 4.77597575], @@ -723,349 +280,230 @@ def setUp(self): [11.67394789, 12.87497954, 14.07601119], [15.27704284, 16.47807449, 17.67910614], ], - ]) - self.data5 = np.array([ + ] + ) + array233_3 = np.array( + [ [[0, 0, 0], [0, 0, 0], [0, 0, 0]], [ [4.00795282, 4.60438026, 5.2008077], [5.79723515, 6.39366259, 6.99009003], [7.58651747, 8.18294492, 8.77937236], ], - ]) - self.data6 = self.data3 * -1 - self.data7 = self.combo.op(self.data6) - self.data8 = np.empty(2, dtype=np.ndarray) - self.data8[0] = np.array( - [[-0, -1.0, -2.0], [-3.0, -4.0, -5.0], [-6.0, -7.0, -8.0]], - ) - self.data8[1] = np.array( - [[-0, -0, -0], [-0, -0, -0], [-0, -0, -0]], - ) - self.data9 = self.data1 * (1 + 1j) - self.data10 = self.data9 / (2 * 3 + 1) - self.data11 = np.asarray( - [[0, 0, 0], [0, 1.0, 1.25], [1.5, 1.75, 2.0]], - ) - self.random_data = 3 * np.random.random( - self.group_lasso.weights[0].shape, - ) - self.random_data_tile = np.tile( - self.random_data, - (self.group_lasso.weights.shape[0], 1, 1), - ) - self.gl_result_data = 2 * self.random_data_tile - 3 - self.gl_result_data = np.array( - (self.gl_result_data * (self.gl_result_data > 0).astype('int')) - / 2, - ) - - self.dummy = Dummy() - - def tearDown(self): - """Unset test parameter values.""" - self.parent = None - self.identity = None - self.positivity = None - self.sparsethresh = None - self.lowrank = None - self.lowrank_rank = None - self.lowrank_ngole = None - self.combo = None - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - self.data6 = None - self.data7 = None - self.data8 = None - self.dummy = None - self.random_data = None - self.random_data_tile = None - self.gl_result_data = None - - def test_proximity_parent(self): - """Test ProximityParent.""" - npt.assert_equal( - self.parent.op(3), + ] + ) + + def case_prox_parent(self): + """Case prox parent.""" + return ( + proximity.ProximityParent( + func_sq, + func_double, + ), + 3, 9, - err_msg='Inccoret proximity parent operation.', - ) - - npt.assert_equal( - self.parent.cost(3), 6, - err_msg='Incorrect proximity parent cost.', - ) - - def test_identity(self): - """Test IdentityProx.""" - npt.assert_equal( - self.identity.op(3), - 3, - err_msg='Incorrect proximity identity operation.', - ) - - npt.assert_equal( - self.identity.cost(3), - 0, - err_msg='Incorrect proximity identity cost.', - ) - - def test_positivity(self): - """Test Positivity.""" - npt.assert_equal( - self.positivity.op(-3), - 0, - err_msg='Incorrect proximity positivity operation.', - ) - - npt.assert_equal( - self.positivity.cost(-3, verbose=True), - 0, - err_msg='Incorrect proximity positivity cost.', ) - def test_sparse_threshold(self): - """Test SparseThreshold.""" - npt.assert_array_equal( - self.sparsethresh.op(self.data1), - self.data2, - err_msg='Incorrect sparse threshold operation.', - ) - - npt.assert_equal( - self.sparsethresh.cost(self.data1, verbose=True), - 108.0, - err_msg='Incorrect sparse threshold cost.', - ) - - def test_low_rank_matrix(self): - """Test LowRankMatrix.""" - npt.assert_almost_equal( - self.lowrank.op(self.data3), - self.data4, - err_msg='Incorrect low rank operation: standard', - ) - - npt.assert_almost_equal( - self.lowrank_rank.op(self.data3), - self.data4, - err_msg='Incorrect low rank operation: standard with rank', - ) - npt.assert_almost_equal( - self.lowrank_ngole.op(self.data3), - self.data5, - err_msg='Incorrect low rank operation: ngole', - ) - - npt.assert_almost_equal( - self.lowrank.cost(self.data3, verbose=True), - 469.39132942464983, - err_msg='Incorrect low rank cost.', - ) - - def test_linear_comp_prox(self): - """Test LinearCompositionProx.""" - npt.assert_array_equal( - self.linear_comp.op(self.data1), - self.data2, - err_msg='Incorrect sparse threshold operation.', - ) - - npt.assert_equal( - self.linear_comp.cost(self.data1, verbose=True), - 108.0, - err_msg='Incorrect sparse threshold cost.', + def case_prox_identity(self): + """Case prox identity.""" + return proximity.IdentityProx(), 3, 3, 0 + + def case_prox_positivity(self): + """Case prox positivity.""" + return proximity.Positivity(), -3, 0, 0 + + def case_prox_sparsethresh(self): + """Case prox sparsethreshosld.""" + return ( + proximity.SparseThreshold(linear.Identity(), weights=self.weights), + self.array33, + self.array33_st, + 108, + ) + + @parametrize( + "lowr_type, initial_rank, operator, result, cost", + [ + ("standard", None, None, array233_2, 469.3913294246498), + ("standard", 1, None, array233_2, 469.3913294246498), + ("ngole", None, func_double, array233_3, 469.3913294246498), + ], + ) + def case_prox_lowrank(self, lowr_type, initial_rank, operator, result, cost): + """Case prox lowrank.""" + return ( + proximity.LowRankMatrix( + 10, + lowr_type=lowr_type, + initial_rank=initial_rank, + operator=operator, + thresh_type="hard" if lowr_type == "standard" else "soft", + ), + self.array233, + result, + cost, ) - def test_proximity_combo(self): - """Test ProximityCombo.""" - for data7, data8 in zip(self.data7, self.data8): - npt.assert_array_equal( - data7, - data8, - err_msg='Incorrect combined operation', + def case_prox_linear_comp(self): + """Case prox linear comp.""" + return ( + proximity.LinearCompositionProx( + linear_op=linear.Identity(), prox_op=self.case_prox_sparsethresh()[0] + ), + self.array33, + self.array33_st, + 108, + ) + + def case_prox_ridge(self): + """Case prox ridge.""" + return ( + proximity.Ridge(linear.Identity(), self.weights), + self.array33 * (1 + 1j), + self.array33 * (1 + 1j) / 7, + 1224, + ) + + @parametrize("alpha, beta", [(0, weights), (weights, 0)]) + def case_prox_elasticnet(self, alpha, beta): + """Case prox elastic net.""" + if np.all(alpha == 0): + data = self.case_prox_sparsethresh()[1:] + else: + data = self.case_prox_ridge()[1:] + return (proximity.ElasticNet(linear.Identity(), alpha, beta), *data) + + @parametrize( + "beta, k_value, data, result, cost", + [ + (0.2, 1, array33.flatten(), array33_st.flatten(), 259.2), + (3, 5, array33.flatten(), array33_support.flatten(), 684.0), + ( + 6.0, + 9, + array33.flatten() * (1 + 1j), + array33.flatten() * (1 + 1j) / 7, + 1224, + ), + ], + ) + def case_prox_Ksupport(self, beta, k_value, data, result, cost): + """Case prox K-support norm.""" + return (proximity.KSupportNorm(beta=beta, k_value=k_value), data, result, cost) + + @parametrize(use_weights=[True, False]) + def case_prox_grouplasso(self, use_weights): + """Case GroupLasso proximity.""" + if use_weights: + weights = np.tile(self.weights, (4, 1, 1)) + else: + weights = np.tile(np.zeros((3, 3)), (4, 1, 1)) + + random_data = 3 * np.random.random(weights[0].shape) + random_data_tile = np.tile(random_data, (weights.shape[0], 1, 1)) + if use_weights: + gl_result_data = 2 * random_data_tile - 3 + gl_result_data = ( + np.array(gl_result_data * (gl_result_data > 0).astype("int")) / 2 ) - - npt.assert_equal( - self.combo.cost(self.data6), - 0, - err_msg='Incorrect combined cost.', - ) - - npt.assert_raises(TypeError, proximity.ProximityCombo, 1) - - npt.assert_raises(ValueError, proximity.ProximityCombo, []) - - npt.assert_raises(ValueError, proximity.ProximityCombo, [self.dummy]) - - self.dummy.op = func_identity - - npt.assert_raises(ValueError, proximity.ProximityCombo, [self.dummy]) - - @skipIf(import_sklearn, 'sklearn is installed.') # pragma: no cover - def test_owl_sklearn_error(self): - """Test OrderedWeightedL1Norm with Scikit-Learn.""" - npt.assert_raises(ImportError, proximity.OrderedWeightedL1Norm, 1) - - @skipUnless(import_sklearn, 'sklearn not installed.') # pragma: no cover - def test_sparse_owl(self): - """Test OrderedWeightedL1Norm.""" - npt.assert_array_equal( - self.owl.op(self.data1.flatten()), - self.data2.flatten(), - err_msg='Incorrect sparse threshold operation.', - ) - - npt.assert_equal( - self.owl.cost(self.data1.flatten(), verbose=True), + cost = np.sum(random_data_tile) * 6 + else: + gl_result_data = random_data_tile + cost = 0 + return ( + proximity.GroupLASSO( + weights=weights, + ), + random_data_tile, + gl_result_data, + cost, + ) + + @pytest.mark.skipif(not SKLEARN_AVAILABLE, reason="sklearn not available.") + def case_prox_owl(self): + """Case prox for Ordered Weighted L1 Norm.""" + return ( + proximity.OrderedWeightedL1Norm(self.weights.flatten()), + self.array33.flatten(), + self.array33_st.flatten(), 108.0, - err_msg='Incorrect sparse threshold cost.', ) - npt.assert_raises( - ValueError, - proximity.OrderedWeightedL1Norm, - np.arange(10), - ) - def test_ridge(self): - """Test Ridge.""" - npt.assert_array_equal( - self.ridge.op(self.data9), - self.data10, - err_msg='Incorect shrinkage operation.', - ) +@parametrize_with_cases("operator, input_data, op_result, cost_result", cases=ProxCases) +def test_prox_op(operator, input_data, op_result, cost_result): + """Test proximity operator op.""" + npt.assert_almost_equal(operator.op(input_data), op_result) - npt.assert_equal( - self.ridge.cost(self.data9, verbose=True), - 408.0 * 3.0, - err_msg='Incorect shrinkage cost.', - ) - def test_elastic_net_alpha0(self): - """Test ElasticNet.""" - npt.assert_array_equal( - self.elasticnet_alpha0.op(self.data1), - self.data2, - err_msg='Incorect sparse threshold operation ElasticNet class.', - ) +@parametrize_with_cases("operator, input_data, op_result, cost_result", cases=ProxCases) +def test_prox_cost(operator, input_data, op_result, cost_result): + """Test proximity operator cost.""" + npt.assert_almost_equal(operator.cost(input_data, verbose=True), cost_result) - npt.assert_equal( - self.elasticnet_alpha0.cost(self.data1), - 108.0, - err_msg='Incorect shrinkage cost in ElasticNet class.', - ) - def test_elastic_net_beta0(self): - """Test ElasticNet with beta=0.""" - npt.assert_array_equal( - self.elasticnet_beta0.op(self.data9), - self.data10, - err_msg='Incorect ridge operation ElasticNet class.', - ) +@parametrize( + "arg, error", + [ + (1, TypeError), + ([], ValueError), + ([Dummy()], ValueError), + ([dummy_with_op], ValueError), + ], +) +def test_error_prox_combo(arg, error): + """Test errors for proximity combo.""" + npt.assert_raises(error, proximity.ProximityCombo, arg) - npt.assert_equal( - self.elasticnet_beta0.cost(self.data9, verbose=True), - 408.0 * 3.0, - err_msg='Incorect shrinkage cost in ElasticNet class.', - ) - def test_one_support_norm(self): - """Test KSupportNorm with k=1.""" - npt.assert_allclose( - self.one_support.op(self.data1.flatten()), - self.data2.flatten(), - err_msg='Incorect sparse threshold operation for 1-support norm', - rtol=1e-6, - ) - - npt.assert_equal( - self.one_support.cost(self.data1.flatten(), verbose=True), - 259.2, - err_msg='Incorect sparse threshold cost.', - ) +@pytest.mark.skipif(SKLEARN_AVAILABLE, reason="sklearn is installed") +def test_fail_sklearn(): + """Test fail OWL with sklearn.""" + npt.assert_raises(ImportError, proximity.OrderedWeightedL1Norm, 1) - npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) - def test_five_support_norm(self): - """Test KSupportNorm with k=5.""" - npt.assert_allclose( - self.five_support_norm.op(self.data1.flatten()), - self.data11.flatten(), - err_msg='Incorect sparse Ksupport norm operation', - rtol=1e-6, - ) +@pytest.mark.skipif(not SKLEARN_AVAILABLE, reason="sklearn is not installed.") +def test_fail_owl(): + """Test errors for Ordered Weighted L1 Norm.""" + npt.assert_raises( + ValueError, + proximity.OrderedWeightedL1Norm, + np.arange(10), + ) - npt.assert_equal( - self.five_support_norm.cost(self.data1.flatten(), verbose=True), - 684.0, - err_msg='Incorrect 5-support norm cost.', - ) + npt.assert_raises( + ValueError, + proximity.OrderedWeightedL1Norm, + -np.arange(10), + ) - npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) - def test_d_support_norm(self): - """Test KSupportNorm with k=19.""" - npt.assert_allclose( - self.d_support.op(self.data9.flatten()), - self.data10.flatten(), - err_msg='Incorect shrinkage operation for d-support norm', - rtol=1e-6, - ) +def test_fail_lowrank(): + """Test fail for lowrank.""" + prox_op = proximity.LowRankMatrix(10, lowr_type="fail") + npt.assert_raises(ValueError, prox_op.op, 0) - npt.assert_almost_equal( - self.d_support.cost(self.data9.flatten(), verbose=True), - 408.0 * 3.0, - err_msg='Incorrect shrinkage cost for d-support norm.', - ) - npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) +def test_fail_Ksupport_norm(): + """Test fail for K-support norm.""" + npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) - def test_group_lasso(self): - """Test GroupLASSO.""" - npt.assert_allclose( - self.group_lasso.op(self.random_data_tile), - self.gl_result_data, - ) - npt.assert_equal( - self.group_lasso.cost(self.random_data_tile), - np.sum(6 * self.random_data_tile), - ) - # Check that for 0 weights operator doesnt change result - self.group_lasso.weights = np.zeros_like(self.group_lasso.weights) - npt.assert_equal( - self.group_lasso.op(self.random_data_tile), - self.random_data_tile, - ) - npt.assert_equal(self.group_lasso.cost(self.random_data_tile), 0) +def test_reweight(): + """Test for reweight module.""" + data1 = np.arange(9).reshape(3, 3).astype(float) + 1 + data2 = np.array( + [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0], [3.5, 4.0, 4.5]], + ) -class ReweightTestCase(TestCase): - """Test case for reweight module.""" + rw = reweight.cwbReweight(data1) + rw.reweight(data1) - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) + 1 - self.data2 = np.array( - [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0], [3.5, 4.0, 4.5]], - ) - self.rw = reweight.cwbReweight(self.data1) - self.rw.reweight(self.data1) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.rw = None - - def test_cwbreweight(self): - """Test cwbReweight.""" - npt.assert_array_equal( - self.rw.weights, - self.data2, - err_msg='Incorrect CWB re-weighting.', - ) + npt.assert_array_equal( + rw.weights, + data2, + err_msg="Incorrect CWB re-weighting.", + ) - npt.assert_raises(ValueError, self.rw.reweight, self.data1[0]) + npt.assert_raises(ValueError, rw.reweight, data1[0]) diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py index 7490b98c..202e541b 100644 --- a/modopt/tests/test_signal.py +++ b/modopt/tests/test_signal.py @@ -1,322 +1,240 @@ -# -*- coding: utf-8 -*- - """UNIT TESTS FOR SIGNAL. This module contains unit tests for the modopt.signal module. -:Author: Samuel Farrens - +:Authors: + Samuel Farrens + Pierre-Antoine Comby """ -from unittest import TestCase - import numpy as np import numpy.testing as npt +import pytest +from test_helpers import failparam from modopt.signal import filter, noise, positivity, svd, validation, wavelet -class FilterTestCase(TestCase): - """Test case for filter module.""" - - def test_guassian_filter(self): - """Test guassian_filter.""" - npt.assert_almost_equal( - filter.gaussian_filter(1, 1), - 0.24197072451914337, - err_msg='Incorrect Gaussian filter', - ) +class TestFilter: + """Test filter module""" + @pytest.mark.parametrize( + ("norm", "result"), [(True, 0.24197072451914337), (False, 0.60653065971263342)] + ) + def test_gaussian_filter(self, norm, result): + """Test gaussian filter.""" + npt.assert_almost_equal(filter.gaussian_filter(1, 1, norm=norm), result) - npt.assert_almost_equal( - filter.gaussian_filter(1, 1, norm=False), - 0.60653065971263342, - err_msg='Incorrect Gaussian filter', - ) def test_mex_hat(self): - """Test mex_hat.""" + """Test mexican hat filter.""" npt.assert_almost_equal( filter.mex_hat(2, 1), -0.35213905225713371, - err_msg='Incorrect Mexican hat filter', ) + def test_mex_hat_dir(self): - """Test mex_hat_dir.""" + """Test directional mexican hat filter.""" npt.assert_almost_equal( filter.mex_hat_dir(1, 2, 1), 0.17606952612856686, - err_msg='Incorrect directional Mexican hat filter', ) -class NoiseTestCase(TestCase): - """Test case for noise module.""" +class TestNoise: + """Test noise module.""" - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = np.array( - [[0, 2.0, 2.0], [4.0, 5.0, 10], [11.0, 15.0, 18.0]], - ) - self.data3 = np.array([ + data1 = np.arange(9).reshape(3, 3).astype(float) + data2 = np.array( + [[0, 2.0, 2.0], [4.0, 5.0, 10], [11.0, 15.0, 18.0]], + ) + data3 = np.array( + [ [1.62434536, 0.38824359, 1.47182825], [1.92703138, 4.86540763, 2.6984613], [7.74481176, 6.2387931, 8.3190391], - ]) - self.data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]]) - self.data5 = np.array( - [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]], - ) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - - def test_add_noise_poisson(self): - """Test add_noise with Poisson noise.""" - np.random.seed(1) - npt.assert_array_equal( - noise.add_noise(self.data1, noise_type='poisson'), - self.data2, - err_msg='Incorrect noise: Poisson', - ) - - npt.assert_raises( - ValueError, - noise.add_noise, - self.data1, - noise_type='bla', - ) - - npt.assert_raises(ValueError, noise.add_noise, self.data1, (1, 1)) - - def test_add_noise_gaussian(self): - """Test add_noise with Gaussian noise.""" - np.random.seed(1) - npt.assert_almost_equal( - noise.add_noise(self.data1), - self.data3, - err_msg='Incorrect noise: Gaussian', - ) - + ] + ) + data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]]) + data5 = np.array( + [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]], + ) + + @pytest.mark.parametrize( + ("data", "noise_type", "sigma", "data_noise"), + [ + (data1, "poisson", 1, data2), + (data1, "gauss", 1, data3), + (data1, "gauss", (1, 1, 1), data3), + failparam(data1, "fail", 1, data1, raises=ValueError), + ], + ) + def test_add_noise(self, data, noise_type, sigma, data_noise): + """Test add_noise.""" np.random.seed(1) npt.assert_almost_equal( - noise.add_noise(self.data1, sigma=(1, 1, 1)), - self.data3, - err_msg='Incorrect noise: Gaussian', - ) - - def test_thresh_hard(self): - """Test thresh with hard threshold.""" - npt.assert_array_equal( - noise.thresh(self.data1, 5), - self.data4, - err_msg='Incorrect threshold: hard', - ) - - npt.assert_raises( - ValueError, - noise.thresh, - self.data1, - 5, - threshold_type='bla', + noise.add_noise(data, sigma=sigma, noise_type=noise_type), data_noise ) - def test_thresh_soft(self): - """Test thresh with soft threshold.""" + @pytest.mark.parametrize( + ("threshold_type", "result"), + [("hard", data4), ("soft", data5), failparam("fail", None, raises=ValueError)], + ) + def test_thresh(self, threshold_type, result): + """Test threshold.""" npt.assert_array_equal( - noise.thresh(self.data1, 5, threshold_type='soft'), - self.data5, - err_msg='Incorrect threshold: soft', + noise.thresh(self.data1, 5, threshold_type=threshold_type), result ) - -class PositivityTestCase(TestCase): - """Test case for positivity module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3) - 5 - self.data2 = np.array([[0, 0, 0], [0, 0, 0], [1, 2, 3]]) - self.data3 = np.array( - [np.arange(5) - 3, np.arange(4) - 2], - dtype=object, - ) - self.data4 = np.array( - [np.array([0, 0, 0, 0, 1]), np.array([0, 0, 0, 1])], +class TestPositivity: + """Test positivity module.""" + data1 = np.arange(9).reshape(3, 3).astype(float) + data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]]) + data5 = np.array( + [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]], + ) + @pytest.mark.parametrize( + ("value", "expected"), + [ + (-1.0, -float(0)), + (-1, 0), + (data1 - 5, data5), + ( + np.array([np.arange(3) - 1, np.arange(2) - 1], dtype=object), + np.array([np.array([0, 0, 1]), np.array([0, 0])], dtype=object), + ), + failparam("-1", None, raises=TypeError), + ], + ) + def test_positive(self, value, expected): + """Test positive.""" + if isinstance(value, np.ndarray) and value.dtype == "O": + for v, e in zip(positivity.positive(value), expected): + npt.assert_array_equal(v, e) + else: + npt.assert_array_equal(positivity.positive(value), expected) + + +class TestSVD: + """Test for svd module.""" + + @pytest.fixture + def data(self): + """Initialize test data.""" + data1 = np.arange(18).reshape(9, 2).astype(float) + data2 = np.arange(32).reshape(16, 2).astype(float) + data3 = np.array( + [ + np.array( + [ + [-0.01744594, -0.61438865], + [-0.08435304, -0.50397984], + [-0.15126014, -0.39357102], + [-0.21816724, -0.28316221], + [-0.28507434, -0.17275339], + [-0.35198144, -0.06234457], + [-0.41888854, 0.04806424], + [-0.48579564, 0.15847306], + [-0.55270274, 0.26888188], + ] + ), + np.array([42.23492742, 1.10041151]), + np.array( + [ + [-0.67608034, -0.73682791], + [0.73682791, -0.67608034], + ] + ), + ], dtype=object, ) - self.pos_dtype_obj = positivity.positive(self.data3) - self.err = 'Incorrect positivity' - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - - def test_positivity(self): - """Test positivity.""" - npt.assert_equal(positivity.positive(-1), 0, err_msg=self.err) - - npt.assert_equal( - positivity.positive(-1.0), - -float(0), - err_msg=self.err, + data4 = np.array( + [ + [-1.05426832e-16, 1.0], + [2.0, 3.0], + [4.0, 5.0], + [6.0, 7.0], + [8.0, 9.0], + [1.0e1, 1.1e1], + [1.2e1, 1.3e1], + [1.4e1, 1.5e1], + [1.6e1, 1.7e1], + ] ) - npt.assert_equal( - positivity.positive(self.data1), - self.data2, - err_msg=self.err, + data5 = np.array( + [ + [0.49815487, 0.54291537], + [2.40863386, 2.62505584], + [4.31911286, 4.70719631], + [6.22959185, 6.78933678], + [8.14007085, 8.87147725], + [10.05054985, 10.95361772], + [11.96102884, 13.03575819], + [13.87150784, 15.11789866], + [15.78198684, 17.20003913], + ] ) + return (data1, data2, data3, data4, data5) - for expected, output in zip(self.data4, self.pos_dtype_obj): - print(expected, output) - npt.assert_array_equal(expected, output, err_msg=self.err) + @pytest.fixture + def svd0(self, data): + """Compute SVD of first data sample.""" + return svd.calculate_svd(data[0]) - npt.assert_raises(TypeError, positivity.positive, '-1') - - -class SVDTestCase(TestCase): - """Test case for svd module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(18).reshape(9, 2).astype(float) - self.data2 = np.arange(32).reshape(16, 2).astype(float) - self.data3 = np.array( - [ - np.array([ - [-0.01744594, -0.61438865], - [-0.08435304, -0.50397984], - [-0.15126014, -0.39357102], - [-0.21816724, -0.28316221], - [-0.28507434, -0.17275339], - [-0.35198144, -0.06234457], - [-0.41888854, 0.04806424], - [-0.48579564, 0.15847306], - [-0.55270274, 0.26888188], - ]), - np.array([42.23492742, 1.10041151]), - np.array([ - [-0.67608034, -0.73682791], - [0.73682791, -0.67608034], - ]), - ], - dtype=object, - ) - self.data4 = np.array([ - [-1.05426832e-16, 1.0], - [2.0, 3.0], - [4.0, 5.0], - [6.0, 7.0], - [8.0, 9.0], - [1.0e1, 1.1e1], - [1.2e1, 1.3e1], - [1.4e1, 1.5e1], - [1.6e1, 1.7e1], - ]) - self.data5 = np.array([ - [0.49815487, 0.54291537], - [2.40863386, 2.62505584], - [4.31911286, 4.70719631], - [6.22959185, 6.78933678], - [8.14007085, 8.87147725], - [10.05054985, 10.95361772], - [11.96102884, 13.03575819], - [13.87150784, 15.11789866], - [15.78198684, 17.20003913], - ]) - self.svd = svd.calculate_svd(self.data1) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.svd = None - - def test_find_n_pc(self): - """Test find_n_pc.""" + def test_find_n_pc(self, data): + """Test find number of principal component.""" npt.assert_equal( - svd.find_n_pc(svd.svd(self.data2)[0]), + svd.find_n_pc(svd.svd(data[1])[0]), 2, - err_msg='Incorrect number of principal components.', + err_msg="Incorrect number of principal components.", ) + def test_n_pc_fail_non_square(self): + """Test find_n_pc.""" npt.assert_raises(ValueError, svd.find_n_pc, np.arange(3)) - def test_calculate_svd(self): + def test_calculate_svd(self, data, svd0): """Test calculate_svd.""" + errors = [] + for i, name in enumerate("USV"): + try: + npt.assert_almost_equal(svd0[i], data[2][i]) + except AssertionError: + errors.append(name) + if errors: + raise AssertionError("Incorrect SVD calculation for: " + ", ".join(errors)) + + @pytest.mark.parametrize( + ("n_pc", "idx_res"), + [(None, 3), (1, 4), ("all", 0), failparam("fail", 1, raises=ValueError)], + ) + def test_svd_thresh(self, data, n_pc, idx_res): + """Test svd_tresh.""" npt.assert_almost_equal( - self.svd[0], - np.array(self.data3)[0], - err_msg='Incorrect SVD calculation: U', - ) - - npt.assert_almost_equal( - self.svd[1], - np.array(self.data3)[1], - err_msg='Incorrect SVD calculation: S', - ) - - npt.assert_almost_equal( - self.svd[2], - np.array(self.data3)[2], - err_msg='Incorrect SVD calculation: V', - ) - - def test_svd_thresh(self): - """Test svd_thresh.""" - npt.assert_almost_equal( - svd.svd_thresh(self.data1), - self.data4, - err_msg='Incorrect SVD tresholding', - ) - - npt.assert_almost_equal( - svd.svd_thresh(self.data1, n_pc=1), - self.data5, - err_msg='Incorrect SVD tresholding', - ) - - npt.assert_almost_equal( - svd.svd_thresh(self.data1, n_pc='all'), - self.data1, - err_msg='Incorrect SVD tresholding', + svd.svd_thresh(data[0], n_pc=n_pc), + data[idx_res], ) + def test_svd_tresh_invalid_type(self): + """Test svd_tresh failure.""" npt.assert_raises(TypeError, svd.svd_thresh, 1) - npt.assert_raises(ValueError, svd.svd_thresh, self.data1, n_pc='bla') - - def test_svd_thresh_coef(self): - """Test svd_thresh_coef.""" + @pytest.mark.parametrize("operator", [lambda x: x, failparam(0, raises=TypeError)]) + def test_svd_thresh_coef(self, data, operator): + """Test svd_tresh_coef.""" npt.assert_almost_equal( - svd.svd_thresh_coef(self.data1, lambda x_val: x_val, 0), - self.data1, - err_msg='Incorrect SVD coefficient tresholding', + svd.svd_thresh_coef(data[0], operator, 0), + data[0], + err_msg="Incorrect SVD coefficient tresholding", ) - npt.assert_raises(TypeError, svd.svd_thresh_coef, self.data1, 0, 0) - + # TODO test_svd_thresh_coef_fast -class ValidationTestCase(TestCase): - """Test case for validation module.""" +class TestValidation: + """Test validation Module.""" - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None + array33 = np.arange(9).reshape(3, 3) def test_transpose_test(self): """Test transpose_test.""" @@ -325,90 +243,81 @@ def test_transpose_test(self): validation.transpose_test( lambda x_val, y_val: x_val.dot(y_val), lambda x_val, y_val: x_val.dot(y_val.T), - self.data1.shape, - x_args=self.data1, + self.array33.shape, + x_args=self.array33, ), None, ) - npt.assert_raises( - TypeError, - validation.transpose_test, - 0, - 0, - self.data1.shape, - x_args=self.data1, - ) - -class WaveletTestCase(TestCase): - """Test case for wavelet module.""" +class TestWavelet: + """Test Wavelet Module.""" - def setUp(self): + @pytest.fixture + def data(self): """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = np.arange(36).reshape(4, 3, 3).astype(float) - self.data3 = np.array([ - [ - [6.0, 20, 26.0], - [36.0, 84.0, 84.0], - [90, 164.0, 134.0], - ], + data1 = np.arange(9).reshape(3, 3).astype(float) + data2 = np.arange(36).reshape(4, 3, 3).astype(float) + data3 = np.array( [ - [78.0, 155.0, 134.0], - [225.0, 408.0, 327.0], - [270, 461.0, 350], - ], + [ + [6.0, 20, 26.0], + [36.0, 84.0, 84.0], + [90, 164.0, 134.0], + ], + [ + [78.0, 155.0, 134.0], + [225.0, 408.0, 327.0], + [270, 461.0, 350], + ], + [ + [150, 290, 242.0], + [414.0, 732.0, 570], + [450, 758.0, 566.0], + ], + [ + [222.0, 425.0, 350], + [603.0, 1056.0, 813.0], + [630, 1055.0, 782.0], + ], + ] + ) + + data4 = np.array( [ - [150, 290, 242.0], - [414.0, 732.0, 570], - [450, 758.0, 566.0], - ], + [6496.0, 9796.0, 6544.0], + [9924.0, 14910, 9924.0], + [6544.0, 9796.0, 6496.0], + ] + ) + + data5 = np.array( [ - [222.0, 425.0, 350], - [603.0, 1056.0, 813.0], - [630, 1055.0, 782.0], - ], - ]) - - self.data4 = np.array([ - [6496.0, 9796.0, 6544.0], - [9924.0, 14910, 9924.0], - [6544.0, 9796.0, 6496.0], - ]) - - self.data5 = np.array([ - [[0, 1.0, 4.0], [3.0, 10, 13.0], [6.0, 19.0, 22.0]], - [[3.0, 10, 13.0], [24.0, 46.0, 40], [45.0, 82.0, 67.0]], - [[6.0, 19.0, 22.0], [45.0, 82.0, 67.0], [84.0, 145.0, 112.0]], - ]) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - - def test_filter_convolve(self): - """Test filter_convolve.""" - npt.assert_almost_equal( - wavelet.filter_convolve(self.data1, self.data2), - self.data3, - err_msg='Inccorect filter comvolution.', + [[0, 1.0, 4.0], [3.0, 10, 13.0], [6.0, 19.0, 22.0]], + [[3.0, 10, 13.0], [24.0, 46.0, 40], [45.0, 82.0, 67.0]], + [[6.0, 19.0, 22.0], [45.0, 82.0, 67.0], [84.0, 145.0, 112.0]], + ] ) + return (data1, data2, data3, data4, data5) + @pytest.mark.parametrize( + ("idx_data", "idx_filter", "idx_res", "filter_rot"), + [(0, 1, 2, False), (1, 1, 3, True)], + ) + def test_filter_convolve(self, data, idx_data, idx_filter, idx_res, filter_rot): + """Test filter_convolve.""" npt.assert_almost_equal( - wavelet.filter_convolve(self.data2, self.data2, filter_rot=True), - self.data4, - err_msg='Inccorect filter comvolution.', + wavelet.filter_convolve( + data[idx_data], data[idx_filter], filter_rot=filter_rot + ), + data[idx_res], + err_msg="Inccorect filter comvolution.", ) - def test_filter_convolve_stack(self): + def test_filter_convolve_stack(self, data): """Test filter_convolve_stack.""" npt.assert_almost_equal( - wavelet.filter_convolve_stack(self.data1, self.data1), - self.data5, - err_msg='Inccorect filter stack comvolution.', + wavelet.filter_convolve_stack(data[0], data[0]), + data[4], + err_msg="Inccorect filter stack comvolution.", ) diff --git a/setup.cfg b/setup.cfg index afe46bbc..8d8e821b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -79,16 +79,17 @@ max-string-usages = 20 max-raises = 5 [tool:pytest] +norecursedirs=tests/test_helpers testpaths = modopt addopts = --verbose - --emoji --cov=modopt - --cov-report=term + --cov-report=term-missing --cov-report=xml --junitxml=pytest.xml --pydocstyle [pydocstyle] convention=numpy +add-ignore=D107