diff --git a/modopt/math/matrix.py b/modopt/math/matrix.py index 939cf41f..8361531d 100644 --- a/modopt/math/matrix.py +++ b/modopt/math/matrix.py @@ -285,9 +285,9 @@ class PowerMethod(object): >>> np.random.seed(1) >>> pm = PowerMethod(lambda x: x.dot(x.T), (3, 3)) >>> np.around(pm.spec_rad, 6) - 0.904292 + 1.0 >>> np.around(pm.inv_spec_rad, 6) - 1.105837 + 1.0 Notes ----- @@ -348,17 +348,21 @@ def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0): # Set (or reset) values of x. x_old = self._set_initial_x() + xp = get_array_module(x_old) + x_old_norm = xp.linalg.norm(x_old) + + x_old /= x_old_norm + # Iterate until the L2 norm of x converges. for i_elem in range(max_iter): - xp = get_array_module(x_old) - x_old_norm = xp.linalg.norm(x_old) - - x_new = self._operator(x_old) / x_old_norm + x_new = self._operator(x_old) x_new_norm = xp.linalg.norm(x_new) + x_new /= x_new_norm + if (xp.abs(x_new_norm - x_old_norm) < tolerance): message = ( ' - Power Method converged after {0} iterations!' @@ -374,6 +378,7 @@ def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0): print(message.format(max_iter)) xp.copyto(x_old, x_new) + x_old_norm = x_new_norm self.spec_rad = x_new_norm * extra_factor self.inv_spec_rad = 1.0 / self.spec_rad diff --git a/modopt/tests/test_math.py b/modopt/tests/test_math.py index 99908e02..ba175ae6 100644 --- a/modopt/tests/test_math.py +++ b/modopt/tests/test_math.py @@ -234,13 +234,13 @@ def test_powermethod_converged(self): """Test PowerMethod converged.""" npt.assert_almost_equal( self.pmInstance1.spec_rad, - 0.90429242629600837, + 1.0, err_msg='Incorrect spectral radius: converged', ) npt.assert_almost_equal( self.pmInstance1.inv_spec_rad, - 1.1058369736612865, + 1.0, err_msg='Incorrect inverse spectral radius: converged', ) @@ -248,13 +248,13 @@ def test_powermethod_unconverged(self): """Test PowerMethod unconverged.""" npt.assert_almost_equal( self.pmInstance2.spec_rad, - 0.92048833577059219, + 0.8675467477372257, err_msg='Incorrect spectral radius: unconverged', ) npt.assert_almost_equal( self.pmInstance2.inv_spec_rad, - 1.0863798715741946, + 1.152675636913221, err_msg='Incorrect inverse spectral radius: unconverged', )