From bfd0830edf04512bf8077dd72453aa8d8a423d78 Mon Sep 17 00:00:00 2001 From: nbara <10333715+nbara@users.noreply.github.com> Date: Mon, 20 Sep 2021 14:32:08 +0200 Subject: [PATCH 1/3] [FIX] Use linalg.eigh in RESS --- meegkit/asr.py | 16 ++++++---------- meegkit/dss.py | 2 +- meegkit/ress.py | 15 +++++++++++++-- tests/test_ress.py | 4 ++-- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/meegkit/asr.py b/meegkit/asr.py index 14b6bc21..2674e376 100755 --- a/meegkit/asr.py +++ b/meegkit/asr.py @@ -363,25 +363,21 @@ def clean_windows(X, sfreq, max_bad_chans=0.2, zthresholds=[-3.5, 5], # combine the three masks remove_mask = np.logical_or.reduce((mask1, mask2, mask3)) - removed_wins = np.where(remove_mask) + removed_wins = np.where(remove_mask)[0] # reconstruct the samples to remove sample_maskidx = [] - for i in range(len(removed_wins[0])): + for i, win in enumerate(removed_wins): if i == 0: - sample_maskidx = np.arange( - offsets[removed_wins[0][i]], offsets[removed_wins[0][i]] + N) + sample_maskidx = np.arange(offsets[win], offsets[win] + N) else: - sample_maskidx = np.vstack(( - sample_maskidx, - np.arange(offsets[removed_wins[0][i]], - offsets[removed_wins[0][i]] + N) - )) + sample_maskidx = np.r_[(sample_maskidx, + np.arange(offsets[win], offsets[win] + N))] # delete the bad chunks from the data sample_mask2remove = np.unique(sample_maskidx) if sample_mask2remove.size: - clean = np.delete(X, sample_mask2remove, 1) + clean = np.delete(X, sample_mask2remove, axis=1) sample_mask = np.ones((1, ns), dtype=bool) sample_mask[0, sample_mask2remove] = False else: diff --git a/meegkit/dss.py b/meegkit/dss.py index 1a1bd0f5..f12cac8b 100644 --- a/meegkit/dss.py +++ b/meegkit/dss.py @@ -35,7 +35,7 @@ def dss1(X, weights=None, keep1=None, keep2=1e-12): Power per component (averaged). """ - n_samples, n_chans, n_trials = theshapeof(X) + n_trials = theshapeof(X)[-1] # if demean: # remove weighted mean # X = demean(X, weights) diff --git a/meegkit/ress.py b/meegkit/ress.py index 5d9d916e..dbb33a84 100644 --- a/meegkit/ress.py +++ b/meegkit/ress.py @@ -7,7 +7,7 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1, peak_width: float = .5, neig_width: float = 1, n_keep: int = 1, - return_maps: bool = False): + gamma: float = 0.01, return_maps: bool = False): """Rhythmic Entrainment Source Separation. As described in [1]_. @@ -29,6 +29,10 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1, FWHM of the neighboring frequencies (default=1). n_keep : int Number of components to keep (default=1). -1 keeps all components. + gamma : float + Regularization coefficient, between 0 and 1 (default=0.01, which + corresponds to 1 % regularization and helps reduce numerical problems + for noisy or reduced-rank matrices [2]_). return_maps : bool If True, also output mixing (to_ress) and unmixing matrices (from_ress), used to transform the data into RESS component space and @@ -67,6 +71,9 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1, .. [1] Cohen, M. X., & Gulbinaite, R. (2017). Rhythmic entrainment source separation: Optimizing analyses of neural responses to rhythmic sensory stimulation. Neuroimage, 147, 43-56. + .. [2] Cohen, M. X. (2021). A tutorial on generalized eigendecomposition + for source separation in multichannel electrophysiology. + ArXiv:2104.12356 [Eess, q-Bio]. """ n_samples, n_chans, n_trials = theshapeof(X) @@ -82,8 +89,12 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1, fwhm=neig_width, n_harm=1)) c1, _ = tscov(gaussfilt(X, sfreq, peak_freq, fwhm=peak_width, n_harm=1)) + # add 1% regularization to avoid numerical precision problems in the GED + c0 = (c01 + c02) / 2 + c0 = c0 * (1 - gamma) + gamma * np.trace(c0) / len(c0) * np.eye(len(c0)) + # perform generalized eigendecomposition - d, to_ress = linalg.eig(c1, (c01 + c02) / 2) + d, to_ress = linalg.eigh(c1, c0) d = d.real to_ress = to_ress.real diff --git a/tests/test_ress.py b/tests/test_ress.py index 0bbad244..df42a7bf 100644 --- a/tests/test_ress.py +++ b/tests/test_ress.py @@ -155,5 +155,5 @@ def test_ress(target, n_trials, peak_width, neig_width, neig_freq, show=False): if __name__ == '__main__': import pytest - pytest.main([__file__]) - # test_ress(12, 20, 1, 1, 1, show=True) + # pytest.main([__file__]) + test_ress(20, 16, 1, 1, 1, show=False) From 1247ea9c1379b263ad2256d11444b23d24295e72 Mon Sep 17 00:00:00 2001 From: nbara <10333715+nbara@users.noreply.github.com> Date: Mon, 20 Sep 2021 14:33:10 +0200 Subject: [PATCH 2/3] Update dss.py --- meegkit/dss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/meegkit/dss.py b/meegkit/dss.py index f12cac8b..f408951c 100644 --- a/meegkit/dss.py +++ b/meegkit/dss.py @@ -21,7 +21,7 @@ def dss1(X, weights=None, keep1=None, keep2=1e-12): keep1: int Number of PCs to retain in function:`dss0` (default=all). keep2: float - Ignore PCs smaller than keep2 in function:`dss0` (default=10^-12). + Ignore PCs smaller than keep2 in function:`dss0` (default=1e-12). Returns ------- From c1546001597fc8011a4171ed8323874dc326628e Mon Sep 17 00:00:00 2001 From: nbara <10333715+nbara@users.noreply.github.com> Date: Mon, 20 Sep 2021 14:38:53 +0200 Subject: [PATCH 3/3] Update test_ress.py --- tests/test_ress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_ress.py b/tests/test_ress.py index df42a7bf..d5c39d41 100644 --- a/tests/test_ress.py +++ b/tests/test_ress.py @@ -155,5 +155,5 @@ def test_ress(target, n_trials, peak_width, neig_width, neig_freq, show=False): if __name__ == '__main__': import pytest - # pytest.main([__file__]) - test_ress(20, 16, 1, 1, 1, show=False) + pytest.main([__file__]) + # test_ress(20, 16, 1, 1, 1, show=False)