From 91d9a3b8ce86b6eaca2bc72c710a7a62eac984ad Mon Sep 17 00:00:00 2001 From: nbara <10333715+nbara@users.noreply.github.com> Date: Tue, 8 Sep 2020 12:45:48 +0200 Subject: [PATCH] [FIX] dss_line() works with nremove>1 --- meegkit/dss.py | 13 +++++++---- meegkit/tspca.py | 26 +++++++++++----------- meegkit/utils/matrix.py | 5 +---- meegkit/utils/sig.py | 6 +++++- tests/test_dss.py | 48 +++++++++++++++++++++++++++-------------- tests/test_tspca.py | 31 +++++++++++++++----------- 6 files changed, 78 insertions(+), 51 deletions(-) diff --git a/meegkit/dss.py b/meegkit/dss.py index 10a3c588..45402580 100644 --- a/meegkit/dss.py +++ b/meegkit/dss.py @@ -2,9 +2,9 @@ import numpy as np from scipy import linalg -from .utils import (demean, mean_over_trials, smooth, pca, theshapeof, - tscov, gaussfilt, wpwr) from .tspca import tsr +from .utils import (demean, gaussfilt, mean_over_trials, pca, smooth, + theshapeof, tscov, wpwr) def dss1(data, weights=None, keep1=None, keep2=1e-12): @@ -175,7 +175,7 @@ def dss_line(x, fline, sfreq, nremove=1, nfft=1024, nkeep=None, show=False): if x.shape[0] < nfft: print('reducing nfft to {}'.format(x.shape[0])) nfft = x.shape[0] - + n_samples, n_chans, n_trials = theshapeof(x) x = demean(x) # cancels line_frequency and harmonics, light lowpass @@ -208,7 +208,12 @@ def dss_line(x, fline, sfreq, nremove=1, nfft=1024, nkeep=None, show=False): plt.show() idx_remove = np.arange(nremove) - xxxx = xxxx @ todss[:, idx_remove] # line-dominated components + if x.ndim == 3: + for t in range(n_trials): # line-dominated components + xxxx[..., t] = xxxx[..., t] @ todss[:, idx_remove] + elif x.ndim == 2: + xxxx = xxxx @ todss[:, idx_remove] + xxx, _, _, _ = tsr(xxx, xxxx) # project them out # reconstruct clean signal diff --git a/meegkit/tspca.py b/meegkit/tspca.py index f48ea9dc..9769bffc 100644 --- a/meegkit/tspca.py +++ b/meegkit/tspca.py @@ -120,7 +120,6 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12): offset1 = np.max((0, -np.min(shifts))) idx = np.arange(offset1, X.shape[0]) # X = X[idx, ...] - # if len(wX) > 0: # wX = wX[idx, ...] # if len(wR) > 0: @@ -130,10 +129,9 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12): # adjust size of X offset2 = np.max((0, np.max(shifts))) - idx = np.arange(X.shape[0]) - offset2 - idx = idx[idx >= 0] + # idx = np.arange(X.shape[0]) - offset2 + # idx = idx[idx >= 0] # X = X[idx, ...] - # if len(wX) > 0: # wX = wX[idx, ...] @@ -157,7 +155,6 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12): weights[..., t] = wr wX = weights - wR = np.zeros((n_samples_R, 1, n_trials_R)) wR = weights # remove weighted means @@ -165,13 +162,14 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12): R = demean(R, wR) # equalize power of R channels, the equalize power of the R PCs - if R.shape[1] > 1: - R = normcol(R, wR) - C, _ = tscov(R, []) - V, _ = pca(C, thresh=1e-6) - R = R * V - + # if R.shape[1] > 1: R = normcol(R, wR) + C, _ = tscov(R) + V, _ = pca(C, thresh=1e-6) + z = np.zeros((n_samples_X, V.shape[1], n_trials_R)) + for t in range(n_trials_R): + z[..., t] = R[..., t] @ V + R = normcol(z, wR) # covariances and cross-covariance with time-shifted refs Cr, twcr = tscov(R, shifts, wR) @@ -182,10 +180,10 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12): # TSPCA: clean x by removing regression on time-shifted refs y = np.zeros((n_samples_X, n_chans_X, n_trials_X)) - for trial in np.arange(n_trials_X): - r = multishift(R[..., trial], shifts, reshape=True) + for t in np.arange(n_trials_X): + r = multishift(R[..., t], shifts, reshape=True) z = r @ regression - y[..., trial] = X[:z.shape[0], :, trial] - z + y[..., t] = X[:z.shape[0], :, t] - z y, mean2 = demean(y, wX, return_mean=True) diff --git a/meegkit/utils/matrix.py b/meegkit/utils/matrix.py index 3087c8bc..af36976c 100644 --- a/meegkit/utils/matrix.py +++ b/meegkit/utils/matrix.py @@ -553,8 +553,8 @@ def normcol(X, weights=None, return_norm=False): """ if X.ndim == 3: n_samples, n_chans, n_trials = theshapeof(X) - X = unfold(X) weights = _check_weights(weights, X) + X = unfold(X) if not weights.any(): # no weights X_norm, N = normcol(X, return_norm=True) N = N ** 2 @@ -568,9 +568,6 @@ def normcol(X, weights=None, return_norm=False): if weights.ndim == 2 and weights.shape[1] == 1: weights = np.tile(weights, (1, n_samples, n_trials)) - if weights.shape != weights.shape: - raise ValueError("Weight array should have be same shape as X") - weights = unfold(weights) X_norm, N = normcol(X, weights, return_norm=True) N = N ** 2 diff --git a/meegkit/utils/sig.py b/meegkit/utils/sig.py index 85d9ee9a..92ecfefe 100644 --- a/meegkit/utils/sig.py +++ b/meegkit/utils/sig.py @@ -399,7 +399,11 @@ def gaussfilt(data, srate, f, fwhm, n_harm=1, shift=0, return_empvals=False, # filter tmp = np.fft.fft(data, axis=0) - tmp *= fx[:, None] + if data.ndim == 2: + tmp *= fx[:, None] + elif data.ndim == 3: + tmp *= fx[:, None, None] + filtdat = 2 * np.real(np.fft.ifft(tmp, axis=0)) if return_empvals or show: diff --git a/tests/test_dss.py b/tests/test_dss.py index 2c9a1e67..22ab6d56 100644 --- a/tests/test_dss.py +++ b/tests/test_dss.py @@ -73,22 +73,38 @@ def test_dss_line(): artifact = artifact ** 3 s = x + 10 * artifact + def _plot(x): + f, ax = plt.subplots(1, 2, sharey=True) + f, Pxx = signal.welch(x, sr, nperseg=1024, axis=0, + return_onesided=True) + ax[1].semilogy(f, Pxx) + f, Pxx = signal.welch(s, sr, nperseg=1024, axis=0, + return_onesided=True) + ax[0].semilogy(f, Pxx) + ax[0].set_xlabel('frequency [Hz]') + ax[1].set_xlabel('frequency [Hz]') + ax[0].set_ylabel('PSD [V**2/Hz]') + ax[0].set_title('before') + ax[1].set_title('after') + plt.show() + + # 2D case, n_outputs == 1 out, _ = dss.dss_line(s, 20, sr) - f, ax = plt.subplots(1, 2, sharey=True) - print(out.shape) - f, Pxx = signal.welch(out, sr, nperseg=1024, axis=0, - return_onesided=True) - ax[1].semilogy(f, Pxx) - f, Pxx = signal.welch(s, sr, nperseg=1024, axis=0, - return_onesided=True) - ax[0].semilogy(f, Pxx) - ax[0].set_xlabel('frequency [Hz]') - ax[1].set_xlabel('frequency [Hz]') - ax[0].set_ylabel('PSD [V**2/Hz]') - ax[0].set_title('before') - ax[1].set_title('after') - plt.show() + _plot(out) + + # Test n_outputs > 1 + out, _ = dss.dss_line(s, 20, sr, nremove=2) + # _plot(out) + + # Test n_trials > 1 + x = np.random.randn(nsamples, nchans, 4) + artifact = np.sin(np.arange(nsamples) / sr * 2 * np.pi * 20)[:, None, None] + artifact[artifact < 0] = 0 + artifact = artifact ** 3 + s = x + 10 * artifact + out, _ = dss.dss_line(s, 20, sr, nremove=1) + if __name__ == '__main__': - pytest.main([__file__]) - # test_dss_line() + # pytest.main([__file__]) + test_dss_line() diff --git a/tests/test_tspca.py b/tests/test_tspca.py index dcd802d6..67d69efd 100644 --- a/tests/test_tspca.py +++ b/tests/test_tspca.py @@ -64,8 +64,12 @@ def test_tsr(show=True): # artifact + harmonics artifact = np.sin(np.arange(nsamples) / sr * 2 * np.pi * 10)[:, None] artifact[artifact < 0] = 0 - artifact = artifact ** 3 - signal = x + 10 * artifact + artifact = 5 * artifact ** 3 + signal = x + artifact + + signal -= np.mean(signal, keepdims=True) + artifact -= np.mean(artifact, keepdims=True) + # Without shifts y, idx, mean_total, weights = tspca.tsr( signal, @@ -73,10 +77,10 @@ def test_tsr(show=True): shifts=[0]) if show: - f, ax = plt.subplots(2, 1) - ax[0].plot(y[:500, 0], 'grey', label='cleaned') - ax[0].plot(x[:500, 0], ':', label='signal') - ax[1].plot((y - x)[:500], label='cleaned - signal') + f, ax = plt.subplots(2, 1, num='without shifts') + ax[0].plot(y[:500, 0], 'grey', label='recovered signal') + ax[0].plot(x[:500, 0], ':', label='real signal') + ax[1].plot((y - x)[:500], label='residual') ax[0].legend() ax[1].legend() # plt.show() @@ -84,19 +88,22 @@ def test_tsr(show=True): # Test residual almost 0.0 np.testing.assert_almost_equal(y - x, np.zeros_like(y), decimal=1) - # With shifts + # With shifts. We slide the input array by one sample, and check that the + # artifact is successfully regressed. y, idx, mean_total, weights = tspca.tsr( - signal + artifact, + signal, np.roll(artifact, 1, axis=0), shifts=[-1, 0, 1]) if show: - f, ax = plt.subplots(2, 1) - ax[0].plot(y[:500, 0], 'grey', label='cleaned') - ax[0].plot(x[:500, 0], ':', label='signal') - ax[1].plot((y - x)[:500, 0], label='cleaned - signal') + f, ax = plt.subplots(3, 1, num='with shifts') + ax[0].plot(signal[:500], label='signal + noise') + ax[1].plot(x[:500, 0], 'grey', label='real signal') + ax[1].plot(y[:500, 0], ':', label='recovered signal') + ax[2].plot((signal - y)[:500, 0], label='before - after') ax[0].legend() ax[1].legend() + ax[2].legend() plt.show() if __name__ == '__main__':