Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 10 additions & 14 deletions mne/mixed_norm/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,22 +175,21 @@ def _mixed_norm_solver_prox(M, G, alpha, maxit=200, tol=1e-8, verbose=None,

if init is None:
X = 0.0
active_set = np.zeros(n_sources, dtype=np.bool)
R = M.copy()
if gram is not None:
R = np.dot(G.T, R)
else:
X, active_set = init
X = init
if gram is None:
R = M - np.dot(G[:, active_set], X)
R = M - np.dot(G, X)
else:
R = GTM - np.dot(gram[:, active_set], X)
R = GTM - np.dot(gram, X)

t = 1.0
Y = np.zeros((n_sources, n_times)) # FISTA aux variable
E = [] # track cost function

active_set = np.ones(n_sources, dtype=np.bool) # HACK
active_set = np.ones(n_sources, dtype=np.bool) # start with full AS

for i in xrange(maxit):
X0, active_set_0 = X, active_set # store previous values
Expand Down Expand Up @@ -231,14 +230,12 @@ def _mixed_norm_solver_cd(M, G, alpha, maxit=10000, tol=1e-8,
n_sensors, n_times = M.shape
n_sensors, n_sources = G.shape

if init is None:
X = np.zeros((n_sources, n_times))
else:
X, active_set = init
if init is not None:
init = init.T

clf = MultiTaskLasso(alpha=alpha / len(M), tol=tol, normalize=False,
fit_intercept=False, max_iter=maxit).fit(G, M,
coef_init=X.T)
coef_init=init)
X = clf.coef_.T
active_set = np.any(X, axis=1)
X = X[active_set]
Expand Down Expand Up @@ -325,7 +322,7 @@ def mixed_norm_solver(M, G, alpha, maxit=3000, tol=1e-8, verbose=None,
logger.info("Using proximal iterations")
l21_solver = _mixed_norm_solver_prox

init = None
X_init = None

if active_set_size is not None:
n_sensors, n_times = M.shape
Expand All @@ -336,7 +333,7 @@ def mixed_norm_solver(M, G, alpha, maxit=3000, tol=1e-8, verbose=None,
active_set = np.tile(active_set[:, None], [1, n_orient]).ravel()
for k in xrange(maxit):
X, as_, E = l21_solver(M, G[:, active_set], alpha,
maxit=maxit, tol=tol, init=init,
maxit=maxit, tol=tol, init=X_init,
n_orient=n_orient)
as_ = np.where(active_set)[0][as_]
gap, pobj, dobj, R = dgap_l21(M, G, X, as_, alpha, n_orient)
Expand All @@ -362,7 +359,6 @@ def mixed_norm_solver(M, G, alpha, maxit=3000, tol=1e-8, verbose=None,
idx_active_set = np.where(active_set)[0]
idx = np.searchsorted(idx_active_set, idx_old_active_set)
X_init[idx] = X
init = X_init, active_set[active_set == True]
if np.all(active_set_old == active_set):
logger.info('Convergence stopped (AS did not change) !')
break
Expand All @@ -373,7 +369,7 @@ def mixed_norm_solver(M, G, alpha, maxit=3000, tol=1e-8, verbose=None,
active_set[as_] = True
else:
X, active_set, E = l21_solver(M, G, alpha, maxit=maxit, tol=tol,
init=init, n_orient=n_orient)
init=X_init, n_orient=n_orient)

if (active_set.sum() > 0) and debias:
bias = compute_bias(M, G[:, active_set], X, n_orient=n_orient)
Expand Down