diff --git a/modopt/opt/proximity.py b/modopt/opt/proximity.py index e0f28e96..f8f368ef 100644 --- a/modopt/opt/proximity.py +++ b/modopt/opt/proximity.py @@ -28,7 +28,7 @@ from modopt.math.matrix import nuclear_norm from modopt.signal.noise import thresh from modopt.signal.positivity import positive -from modopt.signal.svd import svd_thresh, svd_thresh_coef +from modopt.signal.svd import svd_thresh, svd_thresh_coef, svd_thresh_coef_fast class ProximityParent(object): @@ -237,6 +237,9 @@ class LowRankMatrix(ProximityParent): lowr_type : {'standard', 'ngole'} Low-rank implementation (options are 'standard' or 'ngole', default is 'standard') + initial_rank: int, optional + Initial guess of the rank of future input_data. + If provided this will save computation time. operator : class Operator class ('ngole' only) @@ -268,6 +271,7 @@ def __init__( threshold, thresh_type='soft', lowr_type='standard', + initial_rank=None, operator=None, ): @@ -277,8 +281,9 @@ def __init__( self.operator = operator self.op = self._op_method self.cost = self._cost_method + self.rank = initial_rank - def _op_method(self, input_data, extra_factor=1.0): + def _op_method(self, input_data, extra_factor=1.0, rank=None): """Operator. This method returns the input data after the singular values have been @@ -290,22 +295,37 @@ def _op_method(self, input_data, extra_factor=1.0): Input data array extra_factor : float Additional multiplication factor (default is ``1.0``) + rank: int, optional + Estimation of the rank to save computation time in standard mode, + if not set an internal estimation is used. Returns ------- numpy.ndarray SVD thresholded data + Raises + ------ + ValueError + if lowr_type is not in ``{'standard', 'ngole'}`` """ # Update threshold with extra factor. threshold = self.thresh * extra_factor - - if self.lowr_type == 'standard': + if self.lowr_type == 'standard' and self.rank is None and rank is None: data_matrix = svd_thresh( cube2matrix(input_data), threshold, thresh_type=self.thresh_type, ) + elif self.lowr_type == 'standard': + data_matrix, update_rank = svd_thresh_coef_fast( + cube2matrix(input_data), + threshold, + n_vals=rank or self.rank, + extra_vals=5, + thresh_type=self.thresh_type, + ) + self.rank = update_rank # save for future use elif self.lowr_type == 'ngole': data_matrix = svd_thresh_coef( @@ -314,6 +334,8 @@ def _op_method(self, input_data, extra_factor=1.0): threshold, thresh_type=self.thresh_type, ) + else: + raise ValueError('lowr_type should be standard or ngole') # Return updated data. return matrix2cube(data_matrix, input_data.shape[1:]) diff --git a/modopt/signal/svd.py b/modopt/signal/svd.py index 41241b33..6dcb9eda 100644 --- a/modopt/signal/svd.py +++ b/modopt/signal/svd.py @@ -10,6 +10,7 @@ import numpy as np from scipy.linalg import svd +from scipy.sparse.linalg import svds from modopt.base.transform import matrix2cube from modopt.interface.errors import warn @@ -200,6 +201,64 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'): return np.dot(u_vec, np.dot(s_new, v_vec)) +def svd_thresh_coef_fast( + input_data, + threshold, + n_vals=-1, + extra_vals=5, + thresh_type='hard', +): + """Threshold the singular values coefficients. + + This method thresholds the input data by using singular value + decomposition, but only computing the the greastest ``n_vals`` + values. + + Parameters + ---------- + input_data : numpy.ndarray + Input data array, 2D matrix + Operator class instance + threshold : float or numpy.ndarray + Threshold value(s) + n_vals: int, optional + Number of singular values to compute. + If None, compute all singular values. + extra_vals: int, optional + If the number of values computed is not enough to perform thresholding, + recompute by using ``n_vals + extra_vals`` (default is ``5``) + thresh_type : {'hard', 'soft'} + Type of noise to be added (default is ``'hard'``) + + Returns + ------- + tuple + The thresholded data (numpy.ndarray) and the estimated rank after + thresholding (int) + """ + if n_vals == -1: + n_vals = min(input_data.shape) - 1 + ok = False + while not ok: + (u_vec, s_values, v_vec) = svds(input_data, k=n_vals) + ok = (s_values[0] <= threshold or n_vals == min(input_data.shape) - 1) + n_vals = min(n_vals + extra_vals, *input_data.shape) + + s_values = thresh( + s_values, + threshold, + threshold_type=thresh_type, + ) + rank = np.count_nonzero(s_values) + return ( + np.dot( + u_vec[:, -rank:] * s_values[-rank:], + v_vec[-rank:, :], + ), + rank, + ) + + def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'): """Threshold the singular values coefficients. diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index 3c33c948..d5547783 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -675,6 +675,11 @@ def setUp(self): weights, ) self.lowrank = proximity.LowRankMatrix(10.0, thresh_type='hard') + self.lowrank_rank = proximity.LowRankMatrix( + 10.0, + initial_rank=1, + thresh_type='hard', + ) self.lowrank_ngole = proximity.LowRankMatrix( 10.0, lowr_type='ngole', @@ -763,6 +768,8 @@ def tearDown(self): self.positivity = None self.sparsethresh = None self.lowrank = None + self.lowrank_rank = None + self.lowrank_ngole = None self.combo = None self.data1 = None self.data2 = None @@ -841,6 +848,11 @@ def test_low_rank_matrix(self): err_msg='Incorrect low rank operation: standard', ) + npt.assert_almost_equal( + self.lowrank_rank.op(self.data3), + self.data4, + err_msg='Incorrect low rank operation: standard with rank', + ) npt.assert_almost_equal( self.lowrank_ngole.op(self.data3), self.data5, diff --git a/setup.cfg b/setup.cfg index eada1b8c..cabd35a0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -58,7 +58,8 @@ per-file-ignores = #Justification: Needed to import matplotlib.pyplot modopt/plot/cost_plot.py: N802,WPS301 #Todo: Investigate possible bug in find_n_pc function - modopt/signal/svd.py: WPS345 + #Todo: Investigate darglint error + modopt/signal/svd.py: WPS345, DAR000 #Todo: Check security of using system executable call modopt/signal/wavelet.py: S404,S603 #Todo: Clean up tests