Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 16 additions & 17 deletions meegkit/detrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def detrend(x, order, w=None, basis='polynomials', threshold=3, n_iter=4,
elif basis == 'sinusoids':
r = np.zeros((n_times, order * 2))
for i, o in enumerate(range(1, order + 1)):
r[:, 2 * i] = np.sin[2 * np.pi * o * lin / 2]
r[:, 2 * i + 1] = np.cos[2 * np.pi * o * lin / 2]
r[:, 2 * i] = np.sin(2 * np.pi * o * lin / 2)
r[:, 2 * i + 1] = np.cos(2 * np.pi * o * lin / 2)
else:
raise ValueError('!')

Expand Down Expand Up @@ -129,7 +129,8 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False):
Weight to apply to `x`. `w` is either a matrix of same size as `x`, or
a column vector to be applied to each column of `x`.
threshold : float
PCA threshold (default=1e-7).
PCA threshold (default=1e-7). Dimensions of x with eigenvalue lower
than this value will be discarded.
return_mean : bool
If True, also return the signal mean prior to regression.

Expand All @@ -145,6 +146,7 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False):
w = _check_weights(w, x)
n_times = x.shape[0]
n_chans = x.shape[1]
n_regs = r.shape[1]
r = unfold(r)
x = unfold(x)
if r.shape[0] != x.shape[0]:
Expand Down Expand Up @@ -178,9 +180,9 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False):
else:
yy = demean(x, w) * w
rr = demean(r, w) * w
V, _ = pca(rr.T.dot(rr), thresh=threshold)
rrr = rr.dot(V)
b = mrdivide(yy.T.dot(rrr), rrr.T.dot(rrr))
V, _ = pca(rr.T @ rr, thresh=threshold)
rr = rr @ V
b = mrdivide(yy.T @ rr, rr.T @ rr)

z = demean(r, w).dot(V).dot(b.T)
z = z + mn
Expand All @@ -189,25 +191,22 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False):
if w.shape[1] != x.shape[1]:
raise ValueError('!')
z = np.zeros(x.shape)
b = np.zeros((n_chans, n_chans))
b = np.zeros((n_chans, n_regs))
for i in range(n_chans):
if sum(w[:, i]) == 0:
print('weights all zero for channel {}'.format(i))
c = np.zeros(x.shape[1], 1)
if not np.any(w[:, i]):
print(f'weights are all zero for channel {i}')
else:
wc = w[:, i][:, None] # channel-specific weight
yy = demean(x[:, i], wc) * wc
xx = demean(x[:, i], wc) * wc

# remove channel-specific-weighted mean from regressor
r = demean(r, wc)
rr = r * wc
V, _ = pca(rr.T.dot(rr), thresh=threshold)
V, _ = pca(rr.T @ rr, thresh=threshold)
rr = rr.dot(V)
c = mrdivide(yy.T.dot(rr), rr.T.dot(rr))
b[i, :V.shape[1]] = mrdivide(xx.T @ rr, rr.T @ rr)

z[:, i] = r.dot(V.dot(c.T)).flatten()
z[:, i] += mn[:, i]
b[i] = c
b = b[:, :V.shape[1]]
z[:, i] = np.squeeze(r @ (V @ b[i, :V.shape[1]].T)) + mn[:, i]

return b, z

Expand Down
32 changes: 25 additions & 7 deletions tests/test_detrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@ def test_regress():
w[:, 1] == .8
[b, z] = regress(y, x, w)
assert z.shape == (1000, 2)
assert b.shape == (2, 1)
assert b.shape == (2, 3)


def test_detrend(show=False):
"""Test detrending."""
# basic
x = np.arange(100)[:, None]
x = x + np.random.randn(*x.shape)
x = np.arange(100)[:, None] # trend
source = np.random.randn(*x.shape)
x = x + source
y, _, _ = detrend(x, 1)

assert y.shape == x.shape
Expand All @@ -57,20 +58,37 @@ def test_detrend(show=False):

assert y.shape == x.shape

# weights
# test weights
trend = np.linspace(0, 100, 1000)[:, None]
data = 3 * np.random.randn(*trend.shape)
data[:100, :] = 100
x = trend + data
x[:100, :] = 100
w = np.ones(x.shape)
w[:100, :] = 0

# Without weights – detrending fails
y, _, _ = detrend(x, 3, None, show=show)

# With weights – detrending works
yy, _, _ = detrend(x, 3, w, show=show)

assert y.shape == x.shape
assert yy.shape == x.shape
assert np.all(np.abs(yy[100:] - data[100:]) < 1.)

# detrend higher-dimensional data
x = np.cumsum(np.random.randn(1000, 16) + 0.1, axis=0)
y, _, _ = detrend(x, 1, show=False)

# detrend higher-dimensional data with order 3 polynomial
x = np.cumsum(np.random.randn(1000, 16) + 0.1, axis=0)
y, _, _ = detrend(x, 3, basis='polynomials', show=True)

# detrend with sinusoids
x = np.random.randn(1000, 2)
x += 2 * np.sin(2 * np.pi * np.arange(1000) / 200)[:, None]
y, _, _ = detrend(x, 5, basis='sinusoids', show=True)

# assert_almost_equal(yy[100:], data[100:], decimal=1)

def test_ringing():
"""Test reduce_ringing function."""
Expand All @@ -87,4 +105,4 @@ def test_ringing():
if __name__ == '__main__':
import pytest
pytest.main([__file__])
# test_detrend(True)
# test_detrend(False)