From 60153ae1dcdaa3f8ba010d020cee033e5ec3da10 Mon Sep 17 00:00:00 2001 From: Roujansky Date: Wed, 3 Jun 2020 12:45:58 +0200 Subject: [PATCH] rm squeeze() in dss.dss0() --- meegkit/dss.py | 2 +- tests/test_dss.py | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/meegkit/dss.py b/meegkit/dss.py index 449ab4b6..10a3c588 100644 --- a/meegkit/dss.py +++ b/meegkit/dss.py @@ -101,7 +101,7 @@ def dss0(c0, c1, keep1=None, keep2=1e-9): W = np.sqrt(1. / eigval0) # diagonal of whitening matrix # c1 is projected into whitened PCA space of data channels - c2 = (W * eigvec0.squeeze()).T.dot(c1).dot(eigvec0.squeeze()) * W + c2 = (W * eigvec0).T.dot(c1).dot(eigvec0) * W # proj. matrix from whitened data space to a space maximizing bias eigvec2, eigval2 = pca(c2, max_comps=keep1, thresh=keep2) diff --git a/tests/test_dss.py b/tests/test_dss.py index 3a806a4b..2c9a1e67 100644 --- a/tests/test_dss.py +++ b/tests/test_dss.py @@ -1,19 +1,25 @@ """Test DSS functions.""" import matplotlib.pyplot as plt import numpy as np -from meegkit import dss -from meegkit.utils import demean, fold, rms, tscov, unfold +import pytest from numpy.testing import assert_allclose from scipy import signal +from meegkit import dss +from meegkit.utils import demean, fold, rms, tscov, unfold -def test_dss0(): + +@pytest.mark.parametrize('n_bad_chans', [0, -1]) +def test_dss0(n_bad_chans): """Test dss0. Find the linear combinations of multichannel data that maximize repeatability over trials. Data are time * channel * trials. Uses dss0(). + + `n_bad_chans` set the values of the first corresponding number of channels + to zero. """ # create synthetic data n_samples = 100 * 3 @@ -30,6 +36,9 @@ def test_dss0(): s = s[:, :, np.newaxis] s = np.tile(s, (1, 1, 100)) + # set first `n_bad_chans` to zero + s[:, :n_bad_chans] = 0. + # noise noise = np.dot( unfold(np.random.randn(n_samples, noise_dim, n_trials)), @@ -81,6 +90,5 @@ def test_dss_line(): plt.show() if __name__ == '__main__': - import pytest pytest.main([__file__]) # test_dss_line()