Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions meegkit/detrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False):
xx = demean(x[:, i], wc) * wc

# remove channel-specific-weighted mean from regressor
r = demean(r, wc)
r = demean(r, wc, inplace=True)
rr = r * wc
V, _ = pca(rr.T @ rr, thresh=threshold)
rr = rr.dot(V)
Expand Down Expand Up @@ -282,12 +282,14 @@ def _plot_detrend(x, y, w):
f = plt.figure()
gs = GridSpec(4, 1, figure=f)
ax1 = f.add_subplot(gs[:3, 0])
plt.plot(x, label='original', color='C0')
plt.plot(y, label='detrended', color='C1')
lines = ax1.plot(x, label='original', color='C0')
plt.setp(lines[1:], label="_")
lines = ax1.plot(y, label='detrended', color='C1')
plt.setp(lines[1:], label="_")
ax1.set_xlim(0, n_times)
ax1.set_xticklabels('')
ax1.set_title('Robust detrending')
ax1.legend()
ax1.legend(fontsize='smaller')

ax2 = f.add_subplot(gs[3, 0])
ax2.pcolormesh(w.T, cmap='Greys')
Expand Down
14 changes: 6 additions & 8 deletions meegkit/dss.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,17 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None,
if X.shape[0] < nfft:
print('Reducing nfft to {}'.format(X.shape[0]))
nfft = X.shape[0]
n_samples, n_chans, n_trials = theshapeof(X)
n_samples, n_chans, _ = theshapeof(X)
if blocksize is None:
blocksize = n_samples

# Recentre data
X = demean(X)
X = demean(X, inplace=True)

# Cancel line_frequency and harmonics + light lowpass
X_filt = smooth(X, sfreq / fline)

# Subtract clean data from original data. The result is the artifact plus
# some residual biological signal
# X - X_filt results in the artifact plus some residual biological signal
X_noise = X - X_filt

# Reduce dimensionality to avoid overfitting
Expand All @@ -226,9 +225,8 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None,
X_block = X_block.transpose(1, 2, 0)

# bias data
X_bias = gaussfilt(X_block, sfreq, fline, fwhm=1, n_harm=n_harm)
c0 += tscov(X_block)[0]
c1 += tscov(X_bias)[0]
c1 += tscov(gaussfilt(X_block, sfreq, fline, fwhm=1, n_harm=n_harm))[0]

# DSS to isolate line components from residual
todss, _, pwr0, pwr1 = dss0(c0, c1)
Expand All @@ -248,12 +246,12 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None,

# reconstruct clean signal
y = X_filt + X_res
artifact = X - y

# Power of components
p = wpwr(X - y)[0] / wpwr(X)[0]
print('Power of components removed by DSS: {:.2f}'.format(p))
return y, artifact
# return the reconstructed clean signal, and the artifact
return y, X - y


def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5,
Expand Down
10 changes: 4 additions & 6 deletions meegkit/tspca.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12):
wR = weights

# remove weighted means
X, mean1 = demean(X, wX, return_mean=True)
R = demean(R, wR)
X, mean1 = demean(X, wX, return_mean=True, inplace=True)
R = demean(R, wR, inplace=True)

# equalize power of R channels, the equalize power of the R PCs
# if R.shape[1] > 1:
Expand All @@ -182,11 +182,9 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12):
y = np.zeros((n_samples_X, n_chans_X, n_trials_X))
for t in np.arange(n_trials_X):
r = multishift(R[..., t], shifts, reshape=True)
z = r @ regression
y[..., t] = X[:z.shape[0], :, t] - z

y, mean2 = demean(y, wX, return_mean=True)
y[..., t] = X[:z.shape[0], :, t] - (r @ regression)

y, mean2 = demean(y, wX, return_mean=True, inplace=True)
idx = np.arange(offset1, initial_samples - offset2)
mean_total = mean1 + mean2
weights = wR
Expand Down
23 changes: 16 additions & 7 deletions meegkit/utils/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,32 @@
from .matrix import fold, theshapeof, unfold, _check_weights


def demean(X, weights=None, return_mean=False):
def demean(X, weights=None, return_mean=False, inplace=False):
"""Remove weighted mean over rows (samples).

Parameters
----------
X : array, shape=(n_samples, n_channels[, n_trials])
Data.
weights : array, shape=(n_samples)
return_mean : bool
If True, also return signal mean (default=False).
inplace : bool
If True, save the resulting array in X (default=False).

Returns
-------
demeaned_X : array, shape=(n_samples, n_channels[, n_trials])
X : array, shape=(n_samples, n_channels[, n_trials])
Centered data.
mn : array
Mean value.

"""
weights = _check_weights(weights, X)
ndims = X.ndim

if not inplace:
X = X.copy()
n_samples, n_chans, n_trials = theshapeof(X)
X = unfold(X)

Expand All @@ -43,18 +50,20 @@ def demean(X, weights=None, return_mean=False):
raise ValueError('Weight array should have either the same ' +
'number of columns as X array, or 1 column.')

demeaned_X = X - mn
else:
mn = np.mean(X, axis=0, keepdims=True)
demeaned_X = X - mn

# Remove mean (no copy)
X -= mn
# np.subtract(X, mn, out=X)

if n_trials > 1 or ndims == 3:
demeaned_X = fold(demeaned_X, n_samples)
X = fold(X, n_samples)

if return_mean:
return demeaned_X, mn # the_mean.shape=(1, the_mean.shape[0])
return X, mn # the_mean.shape=(1, the_mean.shape[0])
else:
return demeaned_X
return X


def mean_over_trials(X, weights=None):
Expand Down
11 changes: 6 additions & 5 deletions meegkit/utils/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,17 +495,18 @@ def unsqueeze(X):


def fold(X, epoch_size):
"""Fold 2D X into 3D."""
"""Fold 2D (n_times, n_channels) X into 3D (n_times, n_chans, n_trials)."""
if X.ndim == 1:
X = X[:, np.newaxis]
if X.ndim > 2:
raise AttributeError('X must be 2D at most')

n_chans = X.shape[0] // epoch_size
nt = X.shape[0] // epoch_size
nc = X.shape[1]
if X.shape[0] / epoch_size >= 1:
X = np.transpose(np.reshape(X, (epoch_size, n_chans, X.shape[1]),
order="F").copy(), [0, 2, 1])
return X
return X.reshape((epoch_size, nt, nc), order="F").transpose([0, 2, 1])
else:
return X


def unfold(X):
Expand Down
10 changes: 5 additions & 5 deletions meegkit/utils/sig.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,13 @@ def gaussfilt(data, srate, f, fwhm, n_harm=1, shift=0, return_empvals=False,
fx = fx + gauss

# filter
tmp = np.fft.fft(data, axis=0)
if data.ndim == 2:
tmp *= fx[:, None]
filtdat = 2 * np.real(np.fft.ifft(
np.fft.fft(data, axis=0) * fx[:, None], axis=0))
elif data.ndim == 3:
tmp *= fx[:, None, None]

filtdat = 2 * np.real(np.fft.ifft(tmp, axis=0))
filtdat = 2 * np.real(np.fft.ifft(
np.fft.fft(data, axis=0) * fx[:, None, None],
axis=0))

if return_empvals or show:
empVals[0] = hz[idx_p]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_detrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_regress():
# Simple regression example, no weights
# fit random walk
y = np.cumsum(np.random.randn(1000, 1), axis=0)
x = np.arange(1000)[:, None]
x = np.arange(1000.)[:, None]
x = np.hstack([x, x ** 2, x ** 3])
[b, z] = regress(y, x)

Expand Down
8 changes: 3 additions & 5 deletions tests/test_dss.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def test_dss0(n_bad_chans):
atol=1e-6) # use abs as DSS component might be flipped


def test_dss1(show=False):
def test_dss1(show=True):
"""Test DSS1 (evoked)."""
n_samples = 300
data, source = create_line_data(n_samples=n_samples)
data, source = create_line_data(n_samples=n_samples, fline=.01)

todss, _, pwr0, pwr1 = dss.dss1(data, weights=None, )
z = fold(np.dot(unfold(data), todss), epoch_size=n_samples)
Expand Down Expand Up @@ -179,12 +179,10 @@ def profile_dss_line(nkeep):
sr = 200
fline = 20
nsamples = 1000000
nchans = 16
nchans = 99
s = create_line_data(n_samples=3 * nsamples, n_chans=nchans,
n_trials=1, fline=fline / sr, SNR=2)[0][..., 0]

# 2D case, n_outputs == 1

pr = cProfile.Profile()
pr.enable()
out, _ = dss.dss_line(s, fline, sr, nkeep=nkeep)
Expand Down