diff --git a/meegkit/dss.py b/meegkit/dss.py index 0277020e..cc32280a 100644 --- a/meegkit/dss.py +++ b/meegkit/dss.py @@ -194,7 +194,7 @@ def dss_line(x, fline, sfreq, nremove=1, nfft=1024, nkeep=None, show=False): if nkeep is not None: xxx_cov = tscov(xxx)[0] V, _ = pca(xxx_cov, nkeep) - xxxx = xxx * V + xxxx = xxx @ V else: xxxx = xxx.copy() diff --git a/tests/test_dss.py b/tests/test_dss.py index ad6e36ec..250e69f7 100644 --- a/tests/test_dss.py +++ b/tests/test_dss.py @@ -128,13 +128,15 @@ def test_dss1(show=False): atol=1e-6) # use abs as DSS component might be flipped -def test_dss_line(): +@pytest.mark.parametrize('nkeep', [None, 2]) +def test_dss_line(nkeep): """Test line noise removal.""" sr = 200 + fline = 20 nsamples = 10000 nchans = 10 x = np.random.randn(nsamples, nchans) - artifact = np.sin(np.arange(nsamples) / sr * 2 * np.pi * 20)[:, None] + artifact = np.sin(np.arange(nsamples) / sr * 2 * np.pi * fline)[:, None] artifact[artifact < 0] = 0 artifact = artifact ** 3 s = x + 10 * artifact @@ -155,24 +157,25 @@ def _plot(x): plt.show() # 2D case, n_outputs == 1 - out, _ = dss.dss_line(s, 20, sr) + out, _ = dss.dss_line(s, fline, sr, nkeep=nkeep) _plot(out) # Test n_outputs > 1 - out, _ = dss.dss_line(s, 20, sr, nremove=2) + out, _ = dss.dss_line(s, fline, sr, nkeep=nkeep, nremove=2) # _plot(out) # Test n_trials > 1 x = np.random.randn(nsamples, nchans, 4) - artifact = np.sin(np.arange(nsamples) / sr * 2 * np.pi * 20)[:, None, None] + artifact = np.sin( + np.arange(nsamples) / sr * 2 * np.pi * fline)[:, None, None] artifact[artifact < 0] = 0 artifact = artifact ** 3 s = x + 10 * artifact - out, _ = dss.dss_line(s, 20, sr, nremove=1) + out, _ = dss.dss_line(s, fline, sr, nremove=1) if __name__ == '__main__': pytest.main([__file__]) # create_data(SNR=5, show=True) # test_dss1(True) - # test_dss_line() + # test_dss_line(None)