diff --git a/mne/mixed_norm/optim.py b/mne/mixed_norm/optim.py index a9918bc12a3..6f7751cef12 100644 --- a/mne/mixed_norm/optim.py +++ b/mne/mixed_norm/optim.py @@ -175,22 +175,21 @@ def _mixed_norm_solver_prox(M, G, alpha, maxit=200, tol=1e-8, verbose=None, if init is None: X = 0.0 - active_set = np.zeros(n_sources, dtype=np.bool) R = M.copy() if gram is not None: R = np.dot(G.T, R) else: - X, active_set = init + X = init if gram is None: - R = M - np.dot(G[:, active_set], X) + R = M - np.dot(G, X) else: - R = GTM - np.dot(gram[:, active_set], X) + R = GTM - np.dot(gram, X) t = 1.0 Y = np.zeros((n_sources, n_times)) # FISTA aux variable E = [] # track cost function - active_set = np.ones(n_sources, dtype=np.bool) # HACK + active_set = np.ones(n_sources, dtype=np.bool) # start with full AS for i in xrange(maxit): X0, active_set_0 = X, active_set # store previous values @@ -231,14 +230,12 @@ def _mixed_norm_solver_cd(M, G, alpha, maxit=10000, tol=1e-8, n_sensors, n_times = M.shape n_sensors, n_sources = G.shape - if init is None: - X = np.zeros((n_sources, n_times)) - else: - X, active_set = init + if init is not None: + init = init.T clf = MultiTaskLasso(alpha=alpha / len(M), tol=tol, normalize=False, fit_intercept=False, max_iter=maxit).fit(G, M, - coef_init=X.T) + coef_init=init) X = clf.coef_.T active_set = np.any(X, axis=1) X = X[active_set] @@ -325,7 +322,7 @@ def mixed_norm_solver(M, G, alpha, maxit=3000, tol=1e-8, verbose=None, logger.info("Using proximal iterations") l21_solver = _mixed_norm_solver_prox - init = None + X_init = None if active_set_size is not None: n_sensors, n_times = M.shape @@ -336,7 +333,7 @@ def mixed_norm_solver(M, G, alpha, maxit=3000, tol=1e-8, verbose=None, active_set = np.tile(active_set[:, None], [1, n_orient]).ravel() for k in xrange(maxit): X, as_, E = l21_solver(M, G[:, active_set], alpha, - maxit=maxit, tol=tol, init=init, + maxit=maxit, tol=tol, init=X_init, n_orient=n_orient) as_ = np.where(active_set)[0][as_] gap, pobj, dobj, R = dgap_l21(M, G, X, as_, alpha, n_orient) @@ -362,7 +359,6 @@ def mixed_norm_solver(M, G, alpha, maxit=3000, tol=1e-8, verbose=None, idx_active_set = np.where(active_set)[0] idx = np.searchsorted(idx_active_set, idx_old_active_set) X_init[idx] = X - init = X_init, active_set[active_set == True] if np.all(active_set_old == active_set): logger.info('Convergence stopped (AS did not change) !') break @@ -373,7 +369,7 @@ def mixed_norm_solver(M, G, alpha, maxit=3000, tol=1e-8, verbose=None, active_set[as_] = True else: X, active_set, E = l21_solver(M, G, alpha, maxit=maxit, tol=tol, - init=init, n_orient=n_orient) + init=X_init, n_orient=n_orient) if (active_set.sum() > 0) and debias: bias = compute_bias(M, G[:, active_set], X, n_orient=n_orient)