Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions meegkit/dss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import numpy as np
from scipy import linalg

from .utils import (demean, mean_over_trials, smooth, pca, theshapeof,
tscov, gaussfilt, wpwr)
from .tspca import tsr
from .utils import (demean, gaussfilt, mean_over_trials, pca, smooth,
theshapeof, tscov, wpwr)


def dss1(data, weights=None, keep1=None, keep2=1e-12):
Expand Down Expand Up @@ -175,7 +175,7 @@ def dss_line(x, fline, sfreq, nremove=1, nfft=1024, nkeep=None, show=False):
if x.shape[0] < nfft:
print('reducing nfft to {}'.format(x.shape[0]))
nfft = x.shape[0]

n_samples, n_chans, n_trials = theshapeof(x)
x = demean(x)

# cancels line_frequency and harmonics, light lowpass
Expand Down Expand Up @@ -208,7 +208,12 @@ def dss_line(x, fline, sfreq, nremove=1, nfft=1024, nkeep=None, show=False):
plt.show()

idx_remove = np.arange(nremove)
xxxx = xxxx @ todss[:, idx_remove] # line-dominated components
if x.ndim == 3:
for t in range(n_trials): # line-dominated components
xxxx[..., t] = xxxx[..., t] @ todss[:, idx_remove]
elif x.ndim == 2:
xxxx = xxxx @ todss[:, idx_remove]

xxx, _, _, _ = tsr(xxx, xxxx) # project them out

# reconstruct clean signal
Expand Down
26 changes: 12 additions & 14 deletions meegkit/tspca.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12):
offset1 = np.max((0, -np.min(shifts)))
idx = np.arange(offset1, X.shape[0])
# X = X[idx, ...]

# if len(wX) > 0:
# wX = wX[idx, ...]
# if len(wR) > 0:
Expand All @@ -130,10 +129,9 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12):

# adjust size of X
offset2 = np.max((0, np.max(shifts)))
idx = np.arange(X.shape[0]) - offset2
idx = idx[idx >= 0]
# idx = np.arange(X.shape[0]) - offset2
# idx = idx[idx >= 0]
# X = X[idx, ...]

# if len(wX) > 0:
# wX = wX[idx, ...]

Expand All @@ -157,21 +155,21 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12):
weights[..., t] = wr

wX = weights
wR = np.zeros((n_samples_R, 1, n_trials_R))
wR = weights

# remove weighted means
X, mean1 = demean(X, wX, return_mean=True)
R = demean(R, wR)

# equalize power of R channels, the equalize power of the R PCs
if R.shape[1] > 1:
R = normcol(R, wR)
C, _ = tscov(R, [])
V, _ = pca(C, thresh=1e-6)
R = R * V

# if R.shape[1] > 1:
R = normcol(R, wR)
C, _ = tscov(R)
V, _ = pca(C, thresh=1e-6)
z = np.zeros((n_samples_X, V.shape[1], n_trials_R))
for t in range(n_trials_R):
z[..., t] = R[..., t] @ V
R = normcol(z, wR)

# covariances and cross-covariance with time-shifted refs
Cr, twcr = tscov(R, shifts, wR)
Expand All @@ -182,10 +180,10 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12):

# TSPCA: clean x by removing regression on time-shifted refs
y = np.zeros((n_samples_X, n_chans_X, n_trials_X))
for trial in np.arange(n_trials_X):
r = multishift(R[..., trial], shifts, reshape=True)
for t in np.arange(n_trials_X):
r = multishift(R[..., t], shifts, reshape=True)
z = r @ regression
y[..., trial] = X[:z.shape[0], :, trial] - z
y[..., t] = X[:z.shape[0], :, t] - z

y, mean2 = demean(y, wX, return_mean=True)

Expand Down
5 changes: 1 addition & 4 deletions meegkit/utils/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,8 +553,8 @@ def normcol(X, weights=None, return_norm=False):
"""
if X.ndim == 3:
n_samples, n_chans, n_trials = theshapeof(X)
X = unfold(X)
weights = _check_weights(weights, X)
X = unfold(X)
if not weights.any(): # no weights
X_norm, N = normcol(X, return_norm=True)
N = N ** 2
Expand All @@ -568,9 +568,6 @@ def normcol(X, weights=None, return_norm=False):
if weights.ndim == 2 and weights.shape[1] == 1:
weights = np.tile(weights, (1, n_samples, n_trials))

if weights.shape != weights.shape:
raise ValueError("Weight array should have be same shape as X")

weights = unfold(weights)
X_norm, N = normcol(X, weights, return_norm=True)
N = N ** 2
Expand Down
6 changes: 5 additions & 1 deletion meegkit/utils/sig.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,11 @@ def gaussfilt(data, srate, f, fwhm, n_harm=1, shift=0, return_empvals=False,

# filter
tmp = np.fft.fft(data, axis=0)
tmp *= fx[:, None]
if data.ndim == 2:
tmp *= fx[:, None]
elif data.ndim == 3:
tmp *= fx[:, None, None]

filtdat = 2 * np.real(np.fft.ifft(tmp, axis=0))

if return_empvals or show:
Expand Down
48 changes: 32 additions & 16 deletions tests/test_dss.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,22 +73,38 @@ def test_dss_line():
artifact = artifact ** 3
s = x + 10 * artifact

def _plot(x):
f, ax = plt.subplots(1, 2, sharey=True)
f, Pxx = signal.welch(x, sr, nperseg=1024, axis=0,
return_onesided=True)
ax[1].semilogy(f, Pxx)
f, Pxx = signal.welch(s, sr, nperseg=1024, axis=0,
return_onesided=True)
ax[0].semilogy(f, Pxx)
ax[0].set_xlabel('frequency [Hz]')
ax[1].set_xlabel('frequency [Hz]')
ax[0].set_ylabel('PSD [V**2/Hz]')
ax[0].set_title('before')
ax[1].set_title('after')
plt.show()

# 2D case, n_outputs == 1
out, _ = dss.dss_line(s, 20, sr)
f, ax = plt.subplots(1, 2, sharey=True)
print(out.shape)
f, Pxx = signal.welch(out, sr, nperseg=1024, axis=0,
return_onesided=True)
ax[1].semilogy(f, Pxx)
f, Pxx = signal.welch(s, sr, nperseg=1024, axis=0,
return_onesided=True)
ax[0].semilogy(f, Pxx)
ax[0].set_xlabel('frequency [Hz]')
ax[1].set_xlabel('frequency [Hz]')
ax[0].set_ylabel('PSD [V**2/Hz]')
ax[0].set_title('before')
ax[1].set_title('after')
plt.show()
_plot(out)

# Test n_outputs > 1
out, _ = dss.dss_line(s, 20, sr, nremove=2)
# _plot(out)

# Test n_trials > 1
x = np.random.randn(nsamples, nchans, 4)
artifact = np.sin(np.arange(nsamples) / sr * 2 * np.pi * 20)[:, None, None]
artifact[artifact < 0] = 0
artifact = artifact ** 3
s = x + 10 * artifact
out, _ = dss.dss_line(s, 20, sr, nremove=1)


if __name__ == '__main__':
pytest.main([__file__])
# test_dss_line()
# pytest.main([__file__])
test_dss_line()
31 changes: 19 additions & 12 deletions tests/test_tspca.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,39 +64,46 @@ def test_tsr(show=True):
# artifact + harmonics
artifact = np.sin(np.arange(nsamples) / sr * 2 * np.pi * 10)[:, None]
artifact[artifact < 0] = 0
artifact = artifact ** 3
signal = x + 10 * artifact
artifact = 5 * artifact ** 3
signal = x + artifact

signal -= np.mean(signal, keepdims=True)
artifact -= np.mean(artifact, keepdims=True)

# Without shifts
y, idx, mean_total, weights = tspca.tsr(
signal,
artifact,
shifts=[0])

if show:
f, ax = plt.subplots(2, 1)
ax[0].plot(y[:500, 0], 'grey', label='cleaned')
ax[0].plot(x[:500, 0], ':', label='signal')
ax[1].plot((y - x)[:500], label='cleaned - signal')
f, ax = plt.subplots(2, 1, num='without shifts')
ax[0].plot(y[:500, 0], 'grey', label='recovered signal')
ax[0].plot(x[:500, 0], ':', label='real signal')
ax[1].plot((y - x)[:500], label='residual')
ax[0].legend()
ax[1].legend()
# plt.show()

# Test residual almost 0.0
np.testing.assert_almost_equal(y - x, np.zeros_like(y), decimal=1)

# With shifts
# With shifts. We slide the input array by one sample, and check that the
# artifact is successfully regressed.
y, idx, mean_total, weights = tspca.tsr(
signal + artifact,
signal,
np.roll(artifact, 1, axis=0),
shifts=[-1, 0, 1])

if show:
f, ax = plt.subplots(2, 1)
ax[0].plot(y[:500, 0], 'grey', label='cleaned')
ax[0].plot(x[:500, 0], ':', label='signal')
ax[1].plot((y - x)[:500, 0], label='cleaned - signal')
f, ax = plt.subplots(3, 1, num='with shifts')
ax[0].plot(signal[:500], label='signal + noise')
ax[1].plot(x[:500, 0], 'grey', label='real signal')
ax[1].plot(y[:500, 0], ':', label='recovered signal')
ax[2].plot((signal - y)[:500, 0], label='before - after')
ax[0].legend()
ax[1].legend()
ax[2].legend()
plt.show()

if __name__ == '__main__':
Expand Down