diff --git a/meegkit/detrend.py b/meegkit/detrend.py index 0b882e25..34de9f74 100644 --- a/meegkit/detrend.py +++ b/meegkit/detrend.py @@ -200,7 +200,7 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False): xx = demean(x[:, i], wc) * wc # remove channel-specific-weighted mean from regressor - r = demean(r, wc) + r = demean(r, wc, inplace=True) rr = r * wc V, _ = pca(rr.T @ rr, thresh=threshold) rr = rr.dot(V) @@ -282,12 +282,14 @@ def _plot_detrend(x, y, w): f = plt.figure() gs = GridSpec(4, 1, figure=f) ax1 = f.add_subplot(gs[:3, 0]) - plt.plot(x, label='original', color='C0') - plt.plot(y, label='detrended', color='C1') + lines = ax1.plot(x, label='original', color='C0') + plt.setp(lines[1:], label="_") + lines = ax1.plot(y, label='detrended', color='C1') + plt.setp(lines[1:], label="_") ax1.set_xlim(0, n_times) ax1.set_xticklabels('') ax1.set_title('Robust detrending') - ax1.legend() + ax1.legend(fontsize='smaller') ax2 = f.add_subplot(gs[3, 0]) ax2.pcolormesh(w.T, cmap='Greys') diff --git a/meegkit/dss.py b/meegkit/dss.py index b67b81b7..f4841ba3 100644 --- a/meegkit/dss.py +++ b/meegkit/dss.py @@ -192,18 +192,17 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None, if X.shape[0] < nfft: print('Reducing nfft to {}'.format(X.shape[0])) nfft = X.shape[0] - n_samples, n_chans, n_trials = theshapeof(X) + n_samples, n_chans, _ = theshapeof(X) if blocksize is None: blocksize = n_samples # Recentre data - X = demean(X) + X = demean(X, inplace=True) # Cancel line_frequency and harmonics + light lowpass X_filt = smooth(X, sfreq / fline) - # Subtract clean data from original data. The result is the artifact plus - # some residual biological signal + # X - X_filt results in the artifact plus some residual biological signal X_noise = X - X_filt # Reduce dimensionality to avoid overfitting @@ -226,9 +225,8 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None, X_block = X_block.transpose(1, 2, 0) # bias data - X_bias = gaussfilt(X_block, sfreq, fline, fwhm=1, n_harm=n_harm) c0 += tscov(X_block)[0] - c1 += tscov(X_bias)[0] + c1 += tscov(gaussfilt(X_block, sfreq, fline, fwhm=1, n_harm=n_harm))[0] # DSS to isolate line components from residual todss, _, pwr0, pwr1 = dss0(c0, c1) @@ -248,12 +246,12 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None, # reconstruct clean signal y = X_filt + X_res - artifact = X - y # Power of components p = wpwr(X - y)[0] / wpwr(X)[0] print('Power of components removed by DSS: {:.2f}'.format(p)) - return y, artifact + # return the reconstructed clean signal, and the artifact + return y, X - y def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5, diff --git a/meegkit/tspca.py b/meegkit/tspca.py index 9769bffc..0efbac21 100644 --- a/meegkit/tspca.py +++ b/meegkit/tspca.py @@ -158,8 +158,8 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12): wR = weights # remove weighted means - X, mean1 = demean(X, wX, return_mean=True) - R = demean(R, wR) + X, mean1 = demean(X, wX, return_mean=True, inplace=True) + R = demean(R, wR, inplace=True) # equalize power of R channels, the equalize power of the R PCs # if R.shape[1] > 1: @@ -182,11 +182,9 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12): y = np.zeros((n_samples_X, n_chans_X, n_trials_X)) for t in np.arange(n_trials_X): r = multishift(R[..., t], shifts, reshape=True) - z = r @ regression - y[..., t] = X[:z.shape[0], :, t] - z - - y, mean2 = demean(y, wX, return_mean=True) + y[..., t] = X[:z.shape[0], :, t] - (r @ regression) + y, mean2 = demean(y, wX, return_mean=True, inplace=True) idx = np.arange(offset1, initial_samples - offset2) mean_total = mean1 + mean2 weights = wR diff --git a/meegkit/utils/denoise.py b/meegkit/utils/denoise.py index dc87bcd7..92ed6321 100644 --- a/meegkit/utils/denoise.py +++ b/meegkit/utils/denoise.py @@ -7,7 +7,7 @@ from .matrix import fold, theshapeof, unfold, _check_weights -def demean(X, weights=None, return_mean=False): +def demean(X, weights=None, return_mean=False, inplace=False): """Remove weighted mean over rows (samples). Parameters @@ -15,10 +15,14 @@ def demean(X, weights=None, return_mean=False): X : array, shape=(n_samples, n_channels[, n_trials]) Data. weights : array, shape=(n_samples) + return_mean : bool + If True, also return signal mean (default=False). + inplace : bool + If True, save the resulting array in X (default=False). Returns ------- - demeaned_X : array, shape=(n_samples, n_channels[, n_trials]) + X : array, shape=(n_samples, n_channels[, n_trials]) Centered data. mn : array Mean value. @@ -26,6 +30,9 @@ def demean(X, weights=None, return_mean=False): """ weights = _check_weights(weights, X) ndims = X.ndim + + if not inplace: + X = X.copy() n_samples, n_chans, n_trials = theshapeof(X) X = unfold(X) @@ -43,18 +50,20 @@ def demean(X, weights=None, return_mean=False): raise ValueError('Weight array should have either the same ' + 'number of columns as X array, or 1 column.') - demeaned_X = X - mn else: mn = np.mean(X, axis=0, keepdims=True) - demeaned_X = X - mn + + # Remove mean (no copy) + X -= mn + # np.subtract(X, mn, out=X) if n_trials > 1 or ndims == 3: - demeaned_X = fold(demeaned_X, n_samples) + X = fold(X, n_samples) if return_mean: - return demeaned_X, mn # the_mean.shape=(1, the_mean.shape[0]) + return X, mn # the_mean.shape=(1, the_mean.shape[0]) else: - return demeaned_X + return X def mean_over_trials(X, weights=None): diff --git a/meegkit/utils/matrix.py b/meegkit/utils/matrix.py index bacc3580..e186b095 100644 --- a/meegkit/utils/matrix.py +++ b/meegkit/utils/matrix.py @@ -495,17 +495,18 @@ def unsqueeze(X): def fold(X, epoch_size): - """Fold 2D X into 3D.""" + """Fold 2D (n_times, n_channels) X into 3D (n_times, n_chans, n_trials).""" if X.ndim == 1: X = X[:, np.newaxis] if X.ndim > 2: raise AttributeError('X must be 2D at most') - n_chans = X.shape[0] // epoch_size + nt = X.shape[0] // epoch_size + nc = X.shape[1] if X.shape[0] / epoch_size >= 1: - X = np.transpose(np.reshape(X, (epoch_size, n_chans, X.shape[1]), - order="F").copy(), [0, 2, 1]) - return X + return X.reshape((epoch_size, nt, nc), order="F").transpose([0, 2, 1]) + else: + return X def unfold(X): diff --git a/meegkit/utils/sig.py b/meegkit/utils/sig.py index c8a312fe..ada58d5d 100644 --- a/meegkit/utils/sig.py +++ b/meegkit/utils/sig.py @@ -345,13 +345,13 @@ def gaussfilt(data, srate, f, fwhm, n_harm=1, shift=0, return_empvals=False, fx = fx + gauss # filter - tmp = np.fft.fft(data, axis=0) if data.ndim == 2: - tmp *= fx[:, None] + filtdat = 2 * np.real(np.fft.ifft( + np.fft.fft(data, axis=0) * fx[:, None], axis=0)) elif data.ndim == 3: - tmp *= fx[:, None, None] - - filtdat = 2 * np.real(np.fft.ifft(tmp, axis=0)) + filtdat = 2 * np.real(np.fft.ifft( + np.fft.fft(data, axis=0) * fx[:, None, None], + axis=0)) if return_empvals or show: empVals[0] = hz[idx_p] diff --git a/tests/test_detrend.py b/tests/test_detrend.py index 3419ec97..5d115044 100644 --- a/tests/test_detrend.py +++ b/tests/test_detrend.py @@ -11,7 +11,7 @@ def test_regress(): # Simple regression example, no weights # fit random walk y = np.cumsum(np.random.randn(1000, 1), axis=0) - x = np.arange(1000)[:, None] + x = np.arange(1000.)[:, None] x = np.hstack([x, x ** 2, x ** 3]) [b, z] = regress(y, x) diff --git a/tests/test_dss.py b/tests/test_dss.py index b15d794e..be390355 100644 --- a/tests/test_dss.py +++ b/tests/test_dss.py @@ -39,10 +39,10 @@ def test_dss0(n_bad_chans): atol=1e-6) # use abs as DSS component might be flipped -def test_dss1(show=False): +def test_dss1(show=True): """Test DSS1 (evoked).""" n_samples = 300 - data, source = create_line_data(n_samples=n_samples) + data, source = create_line_data(n_samples=n_samples, fline=.01) todss, _, pwr0, pwr1 = dss.dss1(data, weights=None, ) z = fold(np.dot(unfold(data), todss), epoch_size=n_samples) @@ -179,12 +179,10 @@ def profile_dss_line(nkeep): sr = 200 fline = 20 nsamples = 1000000 - nchans = 16 + nchans = 99 s = create_line_data(n_samples=3 * nsamples, n_chans=nchans, n_trials=1, fline=fline / sr, SNR=2)[0][..., 0] - # 2D case, n_outputs == 1 - pr = cProfile.Profile() pr.enable() out, _ = dss.dss_line(s, fline, sr, nkeep=nkeep)