From a7a66f00178b10bde358a522fbf5aea62cb90847 Mon Sep 17 00:00:00 2001 From: Roujansky Date: Tue, 17 Nov 2020 15:40:56 +0100 Subject: [PATCH] removed hardcoded values + added arg 'peak_width' in ress.RESS --- meegkit/ress.py | 13 ++++++++----- meegkit/utils/stats.py | 2 ++ tests/test_ress.py | 16 +++++++++++----- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/meegkit/ress.py b/meegkit/ress.py index 56b468d8..fa037b06 100644 --- a/meegkit/ress.py +++ b/meegkit/ress.py @@ -6,8 +6,8 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1, - neig_width: float = 1, n_keep: int = 1, return_maps: bool = False, - show: bool = False): + peak_width: float = .5, neig_width: float = 1, n_keep: int = 1, + return_maps: bool = False, show: bool = False): """Rhythmic entrainment source separation [1]_. Parameters @@ -19,7 +19,10 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1, peak_freq : float Peak frequency. neig_freq : float - Distance of neighbouring frequencies away from peak frequency, in Hz. + Distance of neighbouring frequencies away from peak frequency, +/- in + Hz (default=1). + peak_width : float + FWHM of the peak frequency (default=.5). neig_width : float FWHM of the neighboring frequencies (default=1). n_keep : int @@ -61,8 +64,8 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1, c01, _ = tscov(gaussfilt(X, sfreq, peak_freq + neig_freq, fwhm=neig_width, n_harm=1)) c02, _ = tscov(gaussfilt(X, sfreq, peak_freq - neig_freq, - fwhm=1, n_harm=1)) - c1, _ = tscov(gaussfilt(X, sfreq, peak_freq, fwhm=1, n_harm=1)) + fwhm=neig_width, n_harm=1)) + c1, _ = tscov(gaussfilt(X, sfreq, peak_freq, fwhm=peak_width, n_harm=1)) # perform generalized eigendecomposition d, V = linalg.eig(c1, (c01 + c02) / 2) diff --git a/meegkit/utils/stats.py b/meegkit/utils/stats.py index f3e5b86b..84acaaf6 100644 --- a/meegkit/utils/stats.py +++ b/meegkit/utils/stats.py @@ -299,6 +299,8 @@ def snr_spectrum(data, freqs, n_avg=1, n_harm=1, skipbins=1): n_harm : int Compute SNR at each frequency bin as a pooled RMS over this bin and n_harm harmonics (see references below). + skipbins : int + Number of bins skipped to estimate noise of neighbouring bins. Returns ------- diff --git a/tests/test_ress.py b/tests/test_ress.py index aedb43f2..395dd362 100644 --- a/tests/test_ress.py +++ b/tests/test_ress.py @@ -51,16 +51,21 @@ def create_data(n_times, n_chans=10, n_trials=20, freq=12, sfreq=250, @pytest.mark.parametrize('target', [12, 15, 20]) @pytest.mark.parametrize('n_trials', [16, 20]) -def test_ress(target, n_trials, show=False): +@pytest.mark.parametrize('peak_width', [.5, 1]) +@pytest.mark.parametrize('neig_width', [.5, 1]) +@pytest.mark.parametrize('neig_freq', [.5, 1]) +def test_ress(target, n_trials, peak_width, neig_width, neig_freq, show=False): """Test RESS.""" sfreq = 250 data, source = create_data(n_times=1000, n_trials=n_trials, freq=target, sfreq=sfreq, show=False) - out = ress.RESS(data, sfreq=sfreq, peak_freq=target) + out = ress.RESS(data, sfreq=sfreq, peak_freq=target, neig_freq=neig_freq, + peak_width=peak_width, neig_width=neig_width) nfft = 500 - bins, psd = ss.welch(out.squeeze(1), sfreq, window="boxcar", nperseg=nfft, + bins, psd = ss.welch(out.squeeze(1), sfreq, window="boxcar", + nperseg=nfft / (peak_width * 2), noverlap=0, axis=0, average='mean') # psd = np.abs(np.fft.fft(out, nfft, axis=0)) # psd = psd[0:psd.shape[0] // 2 + 1] @@ -91,8 +96,9 @@ def test_ress(target, n_trials, show=False): assert (snr[(bins <= target - 2) | (bins >= target + 2)] < 2).all() # test multiple components - out, maps = ress.RESS(data, sfreq=sfreq, peak_freq=target, n_keep=1, - return_maps=True) + out, maps = ress.RESS(data, sfreq=sfreq, peak_freq=target, + neig_freq=neig_freq, peak_width=peak_width, + neig_width=neig_width, n_keep=1, return_maps=True) _ = ress.RESS(data, sfreq=sfreq, peak_freq=target, n_keep=2) _ = ress.RESS(data, sfreq=sfreq, peak_freq=target, n_keep=-1)