Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions meegkit/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,25 +363,21 @@ def clean_windows(X, sfreq, max_bad_chans=0.2, zthresholds=[-3.5, 5],

# combine the three masks
remove_mask = np.logical_or.reduce((mask1, mask2, mask3))
removed_wins = np.where(remove_mask)
removed_wins = np.where(remove_mask)[0]

# reconstruct the samples to remove
sample_maskidx = []
for i in range(len(removed_wins[0])):
for i, win in enumerate(removed_wins):
if i == 0:
sample_maskidx = np.arange(
offsets[removed_wins[0][i]], offsets[removed_wins[0][i]] + N)
sample_maskidx = np.arange(offsets[win], offsets[win] + N)
else:
sample_maskidx = np.vstack((
sample_maskidx,
np.arange(offsets[removed_wins[0][i]],
offsets[removed_wins[0][i]] + N)
))
sample_maskidx = np.r_[(sample_maskidx,
np.arange(offsets[win], offsets[win] + N))]

# delete the bad chunks from the data
sample_mask2remove = np.unique(sample_maskidx)
if sample_mask2remove.size:
clean = np.delete(X, sample_mask2remove, 1)
clean = np.delete(X, sample_mask2remove, axis=1)
sample_mask = np.ones((1, ns), dtype=bool)
sample_mask[0, sample_mask2remove] = False
else:
Expand Down
4 changes: 2 additions & 2 deletions meegkit/dss.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def dss1(X, weights=None, keep1=None, keep2=1e-12):
keep1: int
Number of PCs to retain in function:`dss0` (default=all).
keep2: float
Ignore PCs smaller than keep2 in function:`dss0` (default=10^-12).
Ignore PCs smaller than keep2 in function:`dss0` (default=1e-12).

Returns
-------
Expand All @@ -35,7 +35,7 @@ def dss1(X, weights=None, keep1=None, keep2=1e-12):
Power per component (averaged).

"""
n_samples, n_chans, n_trials = theshapeof(X)
n_trials = theshapeof(X)[-1]

# if demean: # remove weighted mean
# X = demean(X, weights)
Expand Down
15 changes: 13 additions & 2 deletions meegkit/ress.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1,
peak_width: float = .5, neig_width: float = 1, n_keep: int = 1,
return_maps: bool = False):
gamma: float = 0.01, return_maps: bool = False):
"""Rhythmic Entrainment Source Separation.

As described in [1]_.
Expand All @@ -29,6 +29,10 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1,
FWHM of the neighboring frequencies (default=1).
n_keep : int
Number of components to keep (default=1). -1 keeps all components.
gamma : float
Regularization coefficient, between 0 and 1 (default=0.01, which
corresponds to 1 % regularization and helps reduce numerical problems
for noisy or reduced-rank matrices [2]_).
return_maps : bool
If True, also output mixing (to_ress) and unmixing matrices
(from_ress), used to transform the data into RESS component space and
Expand Down Expand Up @@ -67,6 +71,9 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1,
.. [1] Cohen, M. X., & Gulbinaite, R. (2017). Rhythmic entrainment source
separation: Optimizing analyses of neural responses to rhythmic sensory
stimulation. Neuroimage, 147, 43-56.
.. [2] Cohen, M. X. (2021). A tutorial on generalized eigendecomposition
for source separation in multichannel electrophysiology.
ArXiv:2104.12356 [Eess, q-Bio].

"""
n_samples, n_chans, n_trials = theshapeof(X)
Expand All @@ -82,8 +89,12 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1,
fwhm=neig_width, n_harm=1))
c1, _ = tscov(gaussfilt(X, sfreq, peak_freq, fwhm=peak_width, n_harm=1))

# add 1% regularization to avoid numerical precision problems in the GED
c0 = (c01 + c02) / 2
c0 = c0 * (1 - gamma) + gamma * np.trace(c0) / len(c0) * np.eye(len(c0))

# perform generalized eigendecomposition
d, to_ress = linalg.eig(c1, (c01 + c02) / 2)
d, to_ress = linalg.eigh(c1, c0)
d = d.real
to_ress = to_ress.real

Expand Down
2 changes: 1 addition & 1 deletion tests/test_ress.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,4 @@ def test_ress(target, n_trials, peak_width, neig_width, neig_freq, show=False):
if __name__ == '__main__':
import pytest
pytest.main([__file__])
# test_ress(12, 20, 1, 1, 1, show=True)
# test_ress(20, 16, 1, 1, 1, show=False)