diff --git a/meegkit/utils/base.py b/meegkit/utils/base.py index a446ff7d..4b1c1ee6 100644 --- a/meegkit/utils/base.py +++ b/meegkit/utils/base.py @@ -1,5 +1,5 @@ """Math utils.""" -from scipy.linalg import lstsq, solve +from scipy import linalg def mrdivide(A, B): @@ -36,7 +36,14 @@ def mldivide(A, B): https://docs.scipy.org/doc/numpy/user/numpy-for-matlab-users.html """ - if A.shape[0] == A.shape[1]: - return solve(A, B) - else: - return lstsq(A, B) + try: + # Note: we must use overwrite_a=False in order to be able to + # use the fall-back solution below in case a LinAlgError is raised + return linalg.solve(A, B, sym_pos=True, overwrite_a=False) + except linalg.LinAlgError: + # Singular matrix in solving dual problem. Using least-squares + # solution instead. + return linalg.lstsq(A, B, lapack_driver='gelsy')[0] + except linalg.LinAlgError: + print('Solution not stable. Model not updated!') + return None diff --git a/meegkit/utils/stats.py b/meegkit/utils/stats.py index 5f9195a9..f3e5b86b 100644 --- a/meegkit/utils/stats.py +++ b/meegkit/utils/stats.py @@ -287,21 +287,23 @@ def snr_spectrum(data, freqs, n_avg=1, n_harm=1, skipbins=1): Parameters ---------- - data : ndarray , shape=(n_freqs, n_chans, [n_trials, ]) - Power spectrum. + data : ndarray , shape=(n_freqs, n_chans,[ n_trials,]) + One-sided power spectral density estimate, specified as a real-valued, + nonnegative array. The power spectral density must be expressed in + linear units, not decibels. freqs : array, shape=(n_freqs,) Frequency bins. n_avg : int Number of neighbour bins to estimate noise over. Make sure that this - value doesn't overlap with neighbouring target frequencies + value doesn't overlap with neighbouring target frequencies. n_harm : int Compute SNR at each frequency bin as a pooled RMS over this bin and - n_harm harmonics (see references below) + n_harm harmonics (see references below). Returns ------- - SNR : ndarray, shape=(Nconds, Nchans, Nsamples) or (Nchans, Nsamples) - Signal-to-Noise-corrected spectrum + SNR : ndarray, shape=(n_freqs, n_chans, n_trials) or (n_freqs, n_chans) + Signal-to-Noise-corrected spectrum. References ---------- @@ -362,30 +364,34 @@ def snr_spectrum(data, freqs, n_avg=1, n_harm=1, skipbins=1): tmp = [t for t in tmp if t >= 0 and t < n_freqs] bin_noise.append(tmp) del tmp + else: + bin_peaks.append(0) + bin_noise.append(0) # SNR at central bin is ratio between (power at central # bin) to (average of N surrounding bins) # -------------------------------------------------------------------------- - for i_trial in range(n_trials * n_chans): + for i_trial in range(n_chans * n_trials): - # RMS of signal over fundamental+harmonics - A = data[bin_peaks, i_trial] + # Mean of signal over fundamental+harmonics + A = np.mean(data[bin_peaks, i_trial] ** 2) # Noise around fundamental+harmonics B = np.zeros(len(bin_noise)) for h in range(len(B)): - B[h] = np.mean(data[bin_noise[h], i_trial::n_trials].flatten()) + B[h] = np.mean(data[bin_noise[h], i_trial].flatten() ** 2) # Ratio with np.errstate(divide='ignore', invalid='ignore'): - SNR[i_bin, i_trial] = np.sqrt(np.sum(A)) / np.sqrt(np.sum(B)) - SNR[np.abs(SNR) == np.inf] = 1 - SNR[SNR == 0] = 1 - SNR[np.isnan(SNR)] = 1 + SNR[i_bin, i_trial] = np.sqrt(A) / np.sqrt(B) del A del B + # SNR[np.abs(SNR) == np.inf] = 1 + # SNR[SNR == 0] = 1 + # SNR[np.isnan(SNR)] = 1 + # Reshape matrix if necessary if np.min((n_trials, n_chans)) > 1: SNR = np.reshape(SNR, (n_freqs, n_chans, n_trials)) diff --git a/tests/test_ress.py b/tests/test_ress.py index 9760799a..ba770100 100644 --- a/tests/test_ress.py +++ b/tests/test_ress.py @@ -7,7 +7,7 @@ from meegkit.utils import fold, rms, unfold, snr_spectrum -def create_data(n_times, n_chans=10, n_trials=50, freq=12, sfreq=250, +def create_data(n_times, n_chans=10, n_trials=20, freq=12, sfreq=250, noise_dim=8, SNR=1, t0=100, show=False): """Create synthetic data. @@ -59,29 +59,39 @@ def test_ress(target, n_trials, show=False): out = ress.RESS(data, sfreq=sfreq, peak_freq=target) - nfft = 250 - bins, psd = ss.welch(out, sfreq, window="hamming", nperseg=nfft, - noverlap=125, axis=0) - - print(psd.shape) - psd = psd.mean(axis=1, keepdims=True) # average over trials - snr = snr_spectrum(psd, bins, skipbins=2, n_avg=2) - + nfft = 500 + bins, psd = ss.welch(out, sfreq, window="boxcar", nperseg=nfft, + noverlap=0, axis=0, average='mean') + # psd = np.abs(np.fft.fft(out, nfft, axis=0)) + # psd = psd[0:psd.shape[0] // 2 + 1] + # bins = np.linspace(0, sfreq // 2, psd.shape[0]) + # print(psd.shape) + # print(bins[:10]) + + psd = psd.mean(axis=-1, keepdims=True) # average over trials + snr = snr_spectrum(psd + psd.max() / 20, bins, skipbins=1, n_avg=2) + # snr = snr.mean(1) if show: - f, ax = plt.subplots(1) - ax.plot(bins, snr, 'o') - ax.axhline(1, ls=':', c='grey', zorder=0) - ax.axvline(target, ls=':', c='grey', zorder=0) - ax.set_ylabel('SNR (a.u.)') - ax.set_xlabel('Frequency (Hz)') - ax.set_xlim([0, 40]) + f, ax = plt.subplots(2) + ax[0].plot(bins, snr, ':o') + ax[0].axhline(1, ls=':', c='grey', zorder=0) + ax[0].axvline(target, ls=':', c='grey', zorder=0) + ax[0].set_ylabel('SNR (a.u.)') + ax[0].set_xlabel('Frequency (Hz)') + ax[0].set_xlim([0, 40]) + ax[0].set_ylim([0, 10]) + ax[1].plot(bins, psd) + ax[1].axvline(target, ls=':', c='grey', zorder=0) + ax[1].set_ylabel('PSD') + ax[1].set_xlabel('Frequency (Hz)') + ax[1].set_xlim([0, 40]) plt.show() assert snr[bins == target] > 10 - assert (snr[(bins < target - 1) | (bins > target + 1)] < 2).all() + assert (snr[(bins <= target - 2) | (bins >= target + 2)] < 2).all() if __name__ == '__main__': import pytest - # pytest.main([__file__]) - test_ress(12, 10, show=True) + pytest.main([__file__]) + # test_ress(12, 20, show=True)