Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions meegkit/utils/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Math utils."""
from scipy.linalg import lstsq, solve
from scipy import linalg


def mrdivide(A, B):
Expand Down Expand Up @@ -36,7 +36,14 @@ def mldivide(A, B):
https://docs.scipy.org/doc/numpy/user/numpy-for-matlab-users.html

"""
if A.shape[0] == A.shape[1]:
return solve(A, B)
else:
return lstsq(A, B)
try:
# Note: we must use overwrite_a=False in order to be able to
# use the fall-back solution below in case a LinAlgError is raised
return linalg.solve(A, B, sym_pos=True, overwrite_a=False)
except linalg.LinAlgError:
# Singular matrix in solving dual problem. Using least-squares
# solution instead.
return linalg.lstsq(A, B, lapack_driver='gelsy')[0]
except linalg.LinAlgError:
print('Solution not stable. Model not updated!')
return None
34 changes: 20 additions & 14 deletions meegkit/utils/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,21 +287,23 @@ def snr_spectrum(data, freqs, n_avg=1, n_harm=1, skipbins=1):

Parameters
----------
data : ndarray , shape=(n_freqs, n_chans, [n_trials, ])
Power spectrum.
data : ndarray , shape=(n_freqs, n_chans,[ n_trials,])
One-sided power spectral density estimate, specified as a real-valued,
nonnegative array. The power spectral density must be expressed in
linear units, not decibels.
freqs : array, shape=(n_freqs,)
Frequency bins.
n_avg : int
Number of neighbour bins to estimate noise over. Make sure that this
value doesn't overlap with neighbouring target frequencies
value doesn't overlap with neighbouring target frequencies.
n_harm : int
Compute SNR at each frequency bin as a pooled RMS over this bin and
n_harm harmonics (see references below)
n_harm harmonics (see references below).

Returns
-------
SNR : ndarray, shape=(Nconds, Nchans, Nsamples) or (Nchans, Nsamples)
Signal-to-Noise-corrected spectrum
SNR : ndarray, shape=(n_freqs, n_chans, n_trials) or (n_freqs, n_chans)
Signal-to-Noise-corrected spectrum.

References
----------
Expand Down Expand Up @@ -362,30 +364,34 @@ def snr_spectrum(data, freqs, n_avg=1, n_harm=1, skipbins=1):
tmp = [t for t in tmp if t >= 0 and t < n_freqs]
bin_noise.append(tmp)
del tmp
else:
bin_peaks.append(0)
bin_noise.append(0)

# SNR at central bin is ratio between (power at central
# bin) to (average of N surrounding bins)
# --------------------------------------------------------------------------
for i_trial in range(n_trials * n_chans):
for i_trial in range(n_chans * n_trials):

# RMS of signal over fundamental+harmonics
A = data[bin_peaks, i_trial]
# Mean of signal over fundamental+harmonics
A = np.mean(data[bin_peaks, i_trial] ** 2)

# Noise around fundamental+harmonics
B = np.zeros(len(bin_noise))
for h in range(len(B)):
B[h] = np.mean(data[bin_noise[h], i_trial::n_trials].flatten())
B[h] = np.mean(data[bin_noise[h], i_trial].flatten() ** 2)

# Ratio
with np.errstate(divide='ignore', invalid='ignore'):
SNR[i_bin, i_trial] = np.sqrt(np.sum(A)) / np.sqrt(np.sum(B))
SNR[np.abs(SNR) == np.inf] = 1
SNR[SNR == 0] = 1
SNR[np.isnan(SNR)] = 1
SNR[i_bin, i_trial] = np.sqrt(A) / np.sqrt(B)

del A
del B

# SNR[np.abs(SNR) == np.inf] = 1
# SNR[SNR == 0] = 1
# SNR[np.isnan(SNR)] = 1

# Reshape matrix if necessary
if np.min((n_trials, n_chans)) > 1:
SNR = np.reshape(SNR, (n_freqs, n_chans, n_trials))
Expand Down
48 changes: 29 additions & 19 deletions tests/test_ress.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from meegkit.utils import fold, rms, unfold, snr_spectrum


def create_data(n_times, n_chans=10, n_trials=50, freq=12, sfreq=250,
def create_data(n_times, n_chans=10, n_trials=20, freq=12, sfreq=250,
noise_dim=8, SNR=1, t0=100, show=False):
"""Create synthetic data.

Expand Down Expand Up @@ -59,29 +59,39 @@ def test_ress(target, n_trials, show=False):

out = ress.RESS(data, sfreq=sfreq, peak_freq=target)

nfft = 250
bins, psd = ss.welch(out, sfreq, window="hamming", nperseg=nfft,
noverlap=125, axis=0)

print(psd.shape)
psd = psd.mean(axis=1, keepdims=True) # average over trials
snr = snr_spectrum(psd, bins, skipbins=2, n_avg=2)

nfft = 500
bins, psd = ss.welch(out, sfreq, window="boxcar", nperseg=nfft,
noverlap=0, axis=0, average='mean')
# psd = np.abs(np.fft.fft(out, nfft, axis=0))
# psd = psd[0:psd.shape[0] // 2 + 1]
# bins = np.linspace(0, sfreq // 2, psd.shape[0])
# print(psd.shape)
# print(bins[:10])

psd = psd.mean(axis=-1, keepdims=True) # average over trials
snr = snr_spectrum(psd + psd.max() / 20, bins, skipbins=1, n_avg=2)
# snr = snr.mean(1)
if show:
f, ax = plt.subplots(1)
ax.plot(bins, snr, 'o')
ax.axhline(1, ls=':', c='grey', zorder=0)
ax.axvline(target, ls=':', c='grey', zorder=0)
ax.set_ylabel('SNR (a.u.)')
ax.set_xlabel('Frequency (Hz)')
ax.set_xlim([0, 40])
f, ax = plt.subplots(2)
ax[0].plot(bins, snr, ':o')
ax[0].axhline(1, ls=':', c='grey', zorder=0)
ax[0].axvline(target, ls=':', c='grey', zorder=0)
ax[0].set_ylabel('SNR (a.u.)')
ax[0].set_xlabel('Frequency (Hz)')
ax[0].set_xlim([0, 40])
ax[0].set_ylim([0, 10])
ax[1].plot(bins, psd)
ax[1].axvline(target, ls=':', c='grey', zorder=0)
ax[1].set_ylabel('PSD')
ax[1].set_xlabel('Frequency (Hz)')
ax[1].set_xlim([0, 40])
plt.show()

assert snr[bins == target] > 10
assert (snr[(bins < target - 1) | (bins > target + 1)] < 2).all()
assert (snr[(bins <= target - 2) | (bins >= target + 2)] < 2).all()


if __name__ == '__main__':
import pytest
# pytest.main([__file__])
test_ress(12, 10, show=True)
pytest.main([__file__])
# test_ress(12, 20, show=True)