diff --git a/meegkit/detrend.py b/meegkit/detrend.py index b829a1a5..de2f2778 100644 --- a/meegkit/detrend.py +++ b/meegkit/detrend.py @@ -80,8 +80,8 @@ def detrend(x, order, w=None, basis='polynomials', threshold=3, n_iter=4, elif basis == 'sinusoids': r = np.zeros((n_times, order * 2)) for i, o in enumerate(range(1, order + 1)): - r[:, 2 * i] = np.sin[2 * np.pi * o * lin / 2] - r[:, 2 * i + 1] = np.cos[2 * np.pi * o * lin / 2] + r[:, 2 * i] = np.sin(2 * np.pi * o * lin / 2) + r[:, 2 * i + 1] = np.cos(2 * np.pi * o * lin / 2) else: raise ValueError('!') @@ -129,7 +129,8 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False): Weight to apply to `x`. `w` is either a matrix of same size as `x`, or a column vector to be applied to each column of `x`. threshold : float - PCA threshold (default=1e-7). + PCA threshold (default=1e-7). Dimensions of x with eigenvalue lower + than this value will be discarded. return_mean : bool If True, also return the signal mean prior to regression. @@ -145,6 +146,7 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False): w = _check_weights(w, x) n_times = x.shape[0] n_chans = x.shape[1] + n_regs = r.shape[1] r = unfold(r) x = unfold(x) if r.shape[0] != x.shape[0]: @@ -178,9 +180,9 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False): else: yy = demean(x, w) * w rr = demean(r, w) * w - V, _ = pca(rr.T.dot(rr), thresh=threshold) - rrr = rr.dot(V) - b = mrdivide(yy.T.dot(rrr), rrr.T.dot(rrr)) + V, _ = pca(rr.T @ rr, thresh=threshold) + rr = rr @ V + b = mrdivide(yy.T @ rr, rr.T @ rr) z = demean(r, w).dot(V).dot(b.T) z = z + mn @@ -189,25 +191,22 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False): if w.shape[1] != x.shape[1]: raise ValueError('!') z = np.zeros(x.shape) - b = np.zeros((n_chans, n_chans)) + b = np.zeros((n_chans, n_regs)) for i in range(n_chans): - if sum(w[:, i]) == 0: - print('weights all zero for channel {}'.format(i)) - c = np.zeros(x.shape[1], 1) + if not np.any(w[:, i]): + print(f'weights are all zero for channel {i}') else: wc = w[:, i][:, None] # channel-specific weight - yy = demean(x[:, i], wc) * wc + xx = demean(x[:, i], wc) * wc + # remove channel-specific-weighted mean from regressor r = demean(r, wc) rr = r * wc - V, _ = pca(rr.T.dot(rr), thresh=threshold) + V, _ = pca(rr.T @ rr, thresh=threshold) rr = rr.dot(V) - c = mrdivide(yy.T.dot(rr), rr.T.dot(rr)) + b[i, :V.shape[1]] = mrdivide(xx.T @ rr, rr.T @ rr) - z[:, i] = r.dot(V.dot(c.T)).flatten() - z[:, i] += mn[:, i] - b[i] = c - b = b[:, :V.shape[1]] + z[:, i] = np.squeeze(r @ (V @ b[i, :V.shape[1]].T)) + mn[:, i] return b, z diff --git a/tests/test_detrend.py b/tests/test_detrend.py index b4bd7761..477e6cb6 100644 --- a/tests/test_detrend.py +++ b/tests/test_detrend.py @@ -39,14 +39,15 @@ def test_regress(): w[:, 1] == .8 [b, z] = regress(y, x, w) assert z.shape == (1000, 2) - assert b.shape == (2, 1) + assert b.shape == (2, 3) def test_detrend(show=False): """Test detrending.""" # basic - x = np.arange(100)[:, None] - x = x + np.random.randn(*x.shape) + x = np.arange(100)[:, None] # trend + source = np.random.randn(*x.shape) + x = x + source y, _, _ = detrend(x, 1) assert y.shape == x.shape @@ -57,20 +58,37 @@ def test_detrend(show=False): assert y.shape == x.shape - # weights + # test weights trend = np.linspace(0, 100, 1000)[:, None] data = 3 * np.random.randn(*trend.shape) + data[:100, :] = 100 x = trend + data - x[:100, :] = 100 w = np.ones(x.shape) w[:100, :] = 0 + + # Without weights – detrending fails y, _, _ = detrend(x, 3, None, show=show) + + # With weights – detrending works yy, _, _ = detrend(x, 3, w, show=show) assert y.shape == x.shape assert yy.shape == x.shape + assert np.all(np.abs(yy[100:] - data[100:]) < 1.) + + # detrend higher-dimensional data + x = np.cumsum(np.random.randn(1000, 16) + 0.1, axis=0) + y, _, _ = detrend(x, 1, show=False) + + # detrend higher-dimensional data with order 3 polynomial + x = np.cumsum(np.random.randn(1000, 16) + 0.1, axis=0) + y, _, _ = detrend(x, 3, basis='polynomials', show=True) + + # detrend with sinusoids + x = np.random.randn(1000, 2) + x += 2 * np.sin(2 * np.pi * np.arange(1000) / 200)[:, None] + y, _, _ = detrend(x, 5, basis='sinusoids', show=True) - # assert_almost_equal(yy[100:], data[100:], decimal=1) def test_ringing(): """Test reduce_ringing function.""" @@ -87,4 +105,4 @@ def test_ringing(): if __name__ == '__main__': import pytest pytest.main([__file__]) - # test_detrend(True) + # test_detrend(False)