diff --git a/meegkit/asr.py b/meegkit/asr.py index 14b6bc21..2674e376 100755 --- a/meegkit/asr.py +++ b/meegkit/asr.py @@ -363,25 +363,21 @@ def clean_windows(X, sfreq, max_bad_chans=0.2, zthresholds=[-3.5, 5], # combine the three masks remove_mask = np.logical_or.reduce((mask1, mask2, mask3)) - removed_wins = np.where(remove_mask) + removed_wins = np.where(remove_mask)[0] # reconstruct the samples to remove sample_maskidx = [] - for i in range(len(removed_wins[0])): + for i, win in enumerate(removed_wins): if i == 0: - sample_maskidx = np.arange( - offsets[removed_wins[0][i]], offsets[removed_wins[0][i]] + N) + sample_maskidx = np.arange(offsets[win], offsets[win] + N) else: - sample_maskidx = np.vstack(( - sample_maskidx, - np.arange(offsets[removed_wins[0][i]], - offsets[removed_wins[0][i]] + N) - )) + sample_maskidx = np.r_[(sample_maskidx, + np.arange(offsets[win], offsets[win] + N))] # delete the bad chunks from the data sample_mask2remove = np.unique(sample_maskidx) if sample_mask2remove.size: - clean = np.delete(X, sample_mask2remove, 1) + clean = np.delete(X, sample_mask2remove, axis=1) sample_mask = np.ones((1, ns), dtype=bool) sample_mask[0, sample_mask2remove] = False else: diff --git a/meegkit/dss.py b/meegkit/dss.py index 1a1bd0f5..f408951c 100644 --- a/meegkit/dss.py +++ b/meegkit/dss.py @@ -21,7 +21,7 @@ def dss1(X, weights=None, keep1=None, keep2=1e-12): keep1: int Number of PCs to retain in function:`dss0` (default=all). keep2: float - Ignore PCs smaller than keep2 in function:`dss0` (default=10^-12). + Ignore PCs smaller than keep2 in function:`dss0` (default=1e-12). Returns ------- @@ -35,7 +35,7 @@ def dss1(X, weights=None, keep1=None, keep2=1e-12): Power per component (averaged). """ - n_samples, n_chans, n_trials = theshapeof(X) + n_trials = theshapeof(X)[-1] # if demean: # remove weighted mean # X = demean(X, weights) diff --git a/meegkit/ress.py b/meegkit/ress.py index 5d9d916e..dbb33a84 100644 --- a/meegkit/ress.py +++ b/meegkit/ress.py @@ -7,7 +7,7 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1, peak_width: float = .5, neig_width: float = 1, n_keep: int = 1, - return_maps: bool = False): + gamma: float = 0.01, return_maps: bool = False): """Rhythmic Entrainment Source Separation. As described in [1]_. @@ -29,6 +29,10 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1, FWHM of the neighboring frequencies (default=1). n_keep : int Number of components to keep (default=1). -1 keeps all components. + gamma : float + Regularization coefficient, between 0 and 1 (default=0.01, which + corresponds to 1 % regularization and helps reduce numerical problems + for noisy or reduced-rank matrices [2]_). return_maps : bool If True, also output mixing (to_ress) and unmixing matrices (from_ress), used to transform the data into RESS component space and @@ -67,6 +71,9 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1, .. [1] Cohen, M. X., & Gulbinaite, R. (2017). Rhythmic entrainment source separation: Optimizing analyses of neural responses to rhythmic sensory stimulation. Neuroimage, 147, 43-56. + .. [2] Cohen, M. X. (2021). A tutorial on generalized eigendecomposition + for source separation in multichannel electrophysiology. + ArXiv:2104.12356 [Eess, q-Bio]. """ n_samples, n_chans, n_trials = theshapeof(X) @@ -82,8 +89,12 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1, fwhm=neig_width, n_harm=1)) c1, _ = tscov(gaussfilt(X, sfreq, peak_freq, fwhm=peak_width, n_harm=1)) + # add 1% regularization to avoid numerical precision problems in the GED + c0 = (c01 + c02) / 2 + c0 = c0 * (1 - gamma) + gamma * np.trace(c0) / len(c0) * np.eye(len(c0)) + # perform generalized eigendecomposition - d, to_ress = linalg.eig(c1, (c01 + c02) / 2) + d, to_ress = linalg.eigh(c1, c0) d = d.real to_ress = to_ress.real diff --git a/tests/test_ress.py b/tests/test_ress.py index 0bbad244..d5c39d41 100644 --- a/tests/test_ress.py +++ b/tests/test_ress.py @@ -156,4 +156,4 @@ def test_ress(target, n_trials, peak_width, neig_width, neig_freq, show=False): if __name__ == '__main__': import pytest pytest.main([__file__]) - # test_ress(12, 20, 1, 1, 1, show=True) + # test_ress(20, 16, 1, 1, 1, show=False)