Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 85 additions & 9 deletions examples/example_ress.ipynb

Large diffs are not rendered by default.

20 changes: 6 additions & 14 deletions examples/example_ress.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import scipy.signal as ss

from meegkit import ress
from meegkit.utils import unfold, rms, fold
from meegkit.utils import unfold, rms, fold, snr_spectrum

# import config

Expand Down Expand Up @@ -73,21 +73,13 @@
df = sfreq / nfft # frequency resolution
bins, psd = ss.welch(out, sfreq, window="hamming", nperseg=nfft,
noverlap=125, axis=0)
psd = psd.mean(axis=1) # average over trials

# Loop over frequencies and compute SNR
skipbins = 1 # skip bins directly next to frequency of interest
n_bins = int(3 / df) # number of bins to average over
snr = np.zeros(len(bins))
for ibin in range(n_bins + 1, len(bins) - n_bins - 1):
numer = psd[ibin]
irange = np.r_[np.arange(ibin - n_bins, ibin - skipbins),
np.arange(ibin + skipbins + 1, ibin + n_bins)]
denom = np.mean(psd[irange])
snr[ibin] = numer / denom # divide amplitude at peak by neighbours
psd = psd.mean(axis=1, keepdims=True) # average over trials
snr = snr_spectrum(psd, bins, skipbins=2, n_avg=2)

f, ax = plt.subplots(1)
ax.plot(bins, snr, 'o')
ax.plot(bins, snr, 'o', label='SNR')
ax.plot(bins[bins == target], snr[bins == target], 'ro', label='Target SNR')
ax.axhline(1, ls=':', c='grey', zorder=0)
ax.axvline(target, ls=':', c='grey', zorder=0)
ax.set_ylabel('SNR (a.u.)')
ax.set_xlabel('Frequency (Hz)')
Expand Down
44 changes: 18 additions & 26 deletions meegkit/utils/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def snr_spectrum(data, freqs, n_avg=1, n_harm=1, skipbins=1):

Parameters
----------
data : ndarray , shape=([n_trials, ]n_chans, n_freqs)
data : ndarray , shape=(n_freqs, n_chans, [n_trials, ])
Power spectrum.
freqs : array, shape=(n_freqs,)
Frequency bins.
Expand Down Expand Up @@ -316,19 +316,19 @@ def snr_spectrum(data, freqs, n_avg=1, n_harm=1, skipbins=1):

"""
if data.ndim == 3:
n_trials = data.shape[0]
n_freqs = data.shape[0]
n_chans = data.shape[1]
n_freqs = data.shape[-1]
n_trials = data.shape[-1]
elif data.ndim == 2:
n_trials = 1
n_chans = data.shape[0]
n_freqs = data.shape[-1]
n_freqs = data.shape[0]
n_chans = data.shape[1]
else:
raise ValueError('Data must have shape (n_trials, n_chans, n_freqs)'
' or (n_chans, n_freqs)')
raise ValueError('Data must have shape (n_freqs, n_chans, [n_trials,])'
f', got {data.shape}')

# Number of points to get desired resolution
data = np.reshape(data, (n_trials * n_chans, n_freqs))
data = np.reshape(data, (n_freqs, n_chans * n_trials))
SNR = np.zeros_like(data)

for i_bin in range(n_freqs):
Expand All @@ -351,12 +351,12 @@ def snr_spectrum(data, freqs, n_avg=1, n_harm=1, skipbins=1):
# Now get indices of noise (i.e., neighbouring FFT bins)
# eg if currentbin=54, navg=3, skipbins=1 :
# bin_noise = 51, 52, 56, 57
tmp = np.stack((np.arange(bin_peaks[h] - skipbins - n_avg,
bin_peaks[h] - skipbins),
np.arange(bin_peaks[h] + skipbins + 1,
bin_peaks[h] + skipbins + 1 + n_avg)
))
tmp = tmp.flatten().astype(int)
tmp = np.r_[
(np.arange(bin_peaks[h] - skipbins - n_avg,
bin_peaks[h] - skipbins),
np.arange(bin_peaks[h] + skipbins + 1,
bin_peaks[h] + skipbins + 1 + n_avg))]
tmp = tmp.astype(int)

# Remove impossible bin values (eg <1 or >n_samp)
tmp = [t for t in tmp if t >= 0 and t < n_freqs]
Expand All @@ -369,24 +369,16 @@ def snr_spectrum(data, freqs, n_avg=1, n_harm=1, skipbins=1):
for i_trial in range(n_trials * n_chans):

# RMS of signal over fundamental+harmonics
A = data[i_trial, bin_peaks]
A = data[bin_peaks, i_trial]

# Noise around fundamental+harmonics
B = np.zeros(len(bin_noise))
for h in range(len(B)):

if n_trials > 1:
# Mean over samples and median over trials
B[h] = np.median(np.mean(data[i_trial::n_chans,
bin_noise[h]],
1))
else:
# Mean over samples
B[h] = np.mean(data[i_trial, bin_noise[h]])
B[h] = np.mean(data[bin_noise[h], i_trial::n_trials].flatten())

# Ratio
with np.errstate(divide='ignore', invalid='ignore'):
SNR[i_trial, i_bin] = np.sqrt(np.sum(A)) / np.sqrt(np.sum(B))
SNR[i_bin, i_trial] = np.sqrt(np.sum(A)) / np.sqrt(np.sum(B))
SNR[np.abs(SNR) == np.inf] = 1
SNR[SNR == 0] = 1
SNR[np.isnan(SNR)] = 1
Expand All @@ -396,6 +388,6 @@ def snr_spectrum(data, freqs, n_avg=1, n_harm=1, skipbins=1):

# Reshape matrix if necessary
if np.min((n_trials, n_chans)) > 1:
SNR = np.reshape(SNR, (n_trials, n_chans, n_freqs))
SNR = np.reshape(SNR, (n_freqs, n_chans, n_trials))

return SNR
25 changes: 8 additions & 17 deletions tests/test_ress.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import pytest
import scipy.signal as ss
from meegkit import ress
from meegkit.utils import fold, rms, unfold
from numpy.testing import assert_allclose
from meegkit.utils import fold, rms, unfold, snr_spectrum


def create_data(n_times, n_chans=10, n_trials=50, freq=12, sfreq=250,
Expand Down Expand Up @@ -56,31 +55,22 @@ def test_ress(target, n_trials, show=False):
"""Test RESS."""
sfreq = 250
data, source = create_data(n_times=1000, n_trials=n_trials, freq=target,
sfreq=sfreq, show=show)
sfreq=sfreq, show=False)

out = ress.RESS(data, sfreq=sfreq, peak_freq=target)

nfft = 250
df = sfreq / nfft # frequency resolution
bins, psd = ss.welch(out, sfreq, window="hamming", nperseg=nfft,
noverlap=125, axis=0)
psd = psd.mean(axis=1) # average over trials

skipbins = 1 # .5 Hz, hard-coded!
n_bins = int(3 / df) # 2 Hz

# loop over frequencies and compute SNR
snr = np.zeros(len(bins))
for ibin in range(n_bins + 1, len(bins) - n_bins - 1):
numer = psd[ibin]
irange = np.r_[np.arange(ibin - n_bins, ibin - skipbins),
np.arange(ibin + skipbins + 1, ibin + n_bins)]
denom = np.mean(psd[irange])
snr[ibin] = numer / denom
print(psd.shape)
psd = psd.mean(axis=1, keepdims=True) # average over trials
snr = snr_spectrum(psd, bins, skipbins=2, n_avg=2)

if show:
f, ax = plt.subplots(1)
ax.plot(bins, snr, 'o')
ax.axhline(1, ls=':', c='grey', zorder=0)
ax.axvline(target, ls=':', c='grey', zorder=0)
ax.set_ylabel('SNR (a.u.)')
ax.set_xlabel('Frequency (Hz)')
Expand All @@ -93,4 +83,5 @@ def test_ress(target, n_trials, show=False):

if __name__ == '__main__':
import pytest
pytest.main([__file__])
# pytest.main([__file__])
test_ress(12, 10, show=True)