From 8faa85887387cbcfe3ae4f7332080fb969c44257 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Fri, 16 Apr 2021 10:45:30 +0200 Subject: [PATCH 1/3] Fix issues --- modopt/opt/algorithms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modopt/opt/algorithms.py b/modopt/opt/algorithms.py index fea737cb..125ac84c 100644 --- a/modopt/opt/algorithms.py +++ b/modopt/opt/algorithms.py @@ -777,7 +777,7 @@ def _update(self): if self._cost_func: self.converge = ( self.any_convergence_flag() - or self._cost_func.get_cost(self._x_new), + or self._cost_func.get_cost(self._x_new) ) def iterate(self, max_iter=150): @@ -1536,7 +1536,7 @@ def _update(self): if self._cost_func: self.converge = ( self.any_convergence_flag() - or self._cost_func.get_cost(self._x_new), + or self._cost_func.get_cost(self._x_new) ) def iterate(self, max_iter=150): From 526f8652016f391b0a3bbee92cd454b11805e46f Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Fri, 16 Apr 2021 11:02:58 +0200 Subject: [PATCH 2/3] Add right tests --- modopt/tests/test_opt.py | 58 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index b6805006..0276b821 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -58,6 +58,17 @@ def setUp(self): reweight_inst = reweight.cwbReweight(self.data3) cost_inst = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) self.setup = algorithms.SetUp() + self.max_iter = 20 + + self.fb_all_iter = algorithms.ForwardBackward( + self.data1, + grad=grad_inst, + prox=prox_inst, + cost=None, + auto_iterate=False, + beta_update=func_identity, + ) + self.fb_all_iter.iterate(self.max_iter) self.fb1 = algorithms.ForwardBackward( self.data1, @@ -110,6 +121,17 @@ def setUp(self): s_greedy=1.1, ) + self.gfb_all_iter = algorithms.GenForwardBackward( + self.data1, + grad=grad_inst, + prox_list=[prox_inst, prox_dual_inst], + cost=None, + auto_iterate=False, + gamma_update=func_identity, + beta_update=func_identity, + ) + self.gfb_all_iter.iterate(self.max_iter) + self.gfb1 = algorithms.GenForwardBackward( self.data1, grad=grad_inst, @@ -133,6 +155,20 @@ def setUp(self): step_size=2, ) + self.condat_all_iter = algorithms.Condat( + self.data1, + self.data2, + grad=grad_inst, + prox=prox_inst, + cost=None, + prox_dual=prox_dual_inst, + sigma_update=func_identity, + tau_update=func_identity, + rho_update=func_identity, + auto_iterate=False, + ) + self.condat_all_iter.iterate(self.max_iter) + self.condat1 = algorithms.Condat( self.data1, self.data2, @@ -166,6 +202,18 @@ def setUp(self): auto_iterate=False, ) + self.pogm_all_iter = algorithms.POGM( + u=self.data1, + x=self.data1, + y=self.data1, + z=self.data1, + grad=grad_inst, + prox=prox_inst, + auto_iterate=False, + cost=None, + ) + self.pogm_all_iter.iterate(self.max_iter) + self.pogm1 = algorithms.POGM( u=self.data1, x=self.data1, @@ -184,13 +232,18 @@ def tearDown(self): self.data1 = None self.data2 = None self.setup = None + self.fb_all_iter = None self.fb1 = None self.fb2 = None + self.gfb_all_iter = None self.gfb1 = None self.gfb2 = None + self.condat_all_iter = None self.condat1 = None self.condat2 = None self.condat3 = None + self.pogm1 = None + self.pogm_all_iter = None self.dummy = None def test_set_up(self): @@ -201,6 +254,11 @@ def test_set_up(self): npt.assert_raises(TypeError, self.setup._check_param_update, 1) + def test_all_iter(self): + for opt in [self.fb_all_iter, self.gfb_all_iter, + self.condat_all_iter, self.pogm_all_iter]: + npt.assert_equal(opt.idx, self.max_iter - 1) + def test_forward_backward(self): """Test forward_backward.""" npt.assert_array_equal( From 8a9da78be4d70dd9a17a7ccea79161a586447e56 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Fri, 16 Apr 2021 11:16:24 +0200 Subject: [PATCH 3/3] Fix PEP --- modopt/tests/test_opt.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index 0276b821..3c33c948 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -255,8 +255,14 @@ def test_set_up(self): npt.assert_raises(TypeError, self.setup._check_param_update, 1) def test_all_iter(self): - for opt in [self.fb_all_iter, self.gfb_all_iter, - self.condat_all_iter, self.pogm_all_iter]: + """Test if all opt run for all iterations.""" + opts = [ + self.fb_all_iter, + self.gfb_all_iter, + self.condat_all_iter, + self.pogm_all_iter, + ] + for opt in opts: npt.assert_equal(opt.idx, self.max_iter - 1) def test_forward_backward(self):