Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion meegkit/dss.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def dss0(c0, c1, keep1=None, keep2=1e-9):
W = np.sqrt(1. / eigval0) # diagonal of whitening matrix

# c1 is projected into whitened PCA space of data channels
c2 = (W * eigvec0.squeeze()).T.dot(c1).dot(eigvec0.squeeze()) * W
c2 = (W * eigvec0).T.dot(c1).dot(eigvec0) * W

# proj. matrix from whitened data space to a space maximizing bias
eigvec2, eigval2 = pca(c2, max_comps=keep1, thresh=keep2)
Expand Down
16 changes: 12 additions & 4 deletions tests/test_dss.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
"""Test DSS functions."""
import matplotlib.pyplot as plt
import numpy as np
from meegkit import dss
from meegkit.utils import demean, fold, rms, tscov, unfold
import pytest
from numpy.testing import assert_allclose
from scipy import signal

from meegkit import dss
from meegkit.utils import demean, fold, rms, tscov, unfold

def test_dss0():

@pytest.mark.parametrize('n_bad_chans', [0, -1])
def test_dss0(n_bad_chans):
"""Test dss0.

Find the linear combinations of multichannel data that
maximize repeatability over trials. Data are time * channel * trials.

Uses dss0().

`n_bad_chans` set the values of the first corresponding number of channels
to zero.
"""
# create synthetic data
n_samples = 100 * 3
Expand All @@ -30,6 +36,9 @@ def test_dss0():
s = s[:, :, np.newaxis]
s = np.tile(s, (1, 1, 100))

# set first `n_bad_chans` to zero
s[:, :n_bad_chans] = 0.

# noise
noise = np.dot(
unfold(np.random.randn(n_samples, noise_dim, n_trials)),
Expand Down Expand Up @@ -81,6 +90,5 @@ def test_dss_line():
plt.show()

if __name__ == '__main__':
import pytest
pytest.main([__file__])
# test_dss_line()