diff --git a/meegkit/utils/stats.py b/meegkit/utils/stats.py
index 84acaaf6..1fe1a6e7 100644
--- a/meegkit/utils/stats.py
+++ b/meegkit/utils/stats.py
@@ -11,25 +11,25 @@ mne = None
 
 
-def rms(x, axis=0):
+def rms(X, axis=0):
     """Root-mean-square along given axis."""
-    return np.sqrt(np.mean(x ** 2, axis=axis, keepdims=True))
+    return np.sqrt(np.mean(X ** 2, axis=axis, keepdims=True))
 
 
-def robust_mean(x, axis=0, percentile=[5, 95]):
-    """Do robust mean based on JR Kings implementation."""
-    x = np.array(x)
+def robust_mean(X, axis=0, percentile=(5, 95)):
+    """Compute a robust (trimmed) mean, based on J.R. King's implementation."""
+    X = np.array(X, dtype=float)  # copy, since values are set to NaN below
     axis_ = axis
     # force axis to be 0 for facilitation
     if axis is not None and axis != 0:
-        x = np.transpose(x, [axis] + range(0, axis) + range(axis + 1, x.ndim))
+        X = np.transpose(X, [axis] + list(range(0, axis)) +
+                         list(range(axis + 1, X.ndim)))
         axis_ = 0
-    mM = np.percentile(x, percentile, axis=axis_)
-    indices_min = np.where((x - mM[0][np.newaxis, ...]) < 0)
-    indices_max = np.where((x - mM[1][np.newaxis, ...]) > 0)
-    x[indices_min] = np.nan
-    x[indices_max] = np.nan
-    m = np.nanmean(x, axis=axis_)
+    mM = np.percentile(X, percentile, axis=axis_)
+    indices_min = np.where((X - mM[0][np.newaxis, ...]) < 0)
+    indices_max = np.where((X - mM[1][np.newaxis, ...]) > 0)
+    X[indices_min] = np.nan
+    X[indices_max] = np.nan
+    m = np.nanmean(X, axis=axis_)
     return m
@@ -142,7 +142,7 @@ def bootstrap_snr(epochs, n_bootstrap=2000, baseline=None, window=None):
     Parameters
     ----------
     epochs : mne.Epochs instance
-        Epochs instance to compute ERP from.
+        Epochs instance to compute SNR from.
     n_bootstrap : int
         Number of bootstrap iterations (should be > 10000 for publication
         quality).
@@ -174,19 +174,19 @@
     """
     indices = np.arange(len(epochs.selection), dtype=int)
-    erp_bs = np.empty((n_bootstrap, len(epochs.times)))
-    gfp_bs = np.empty((n_bootstrap, len(epochs.times)))
+    n_chans = len(epochs.ch_names)
+    erp_bs = np.empty((n_bootstrap, n_chans, len(epochs.times)))
+    gfp_bs = np.empty((n_bootstrap, n_chans, len(epochs.times)))
 
     for i in range(n_bootstrap):
-        if 100 * i / n_bootstrap == np.floor(100 * i / n_bootstrap):
-            print('Bootstrapping... {}%'.format(round(100 * i / n_bootstrap)),
-                  end='\r')
         bs_indices = np.random.choice(indices, replace=True,
                                       size=len(indices))
-        erp_bs[i] = np.mean(epochs._data[bs_indices, 0, :], 0)
+        erp_bs[i] = np.mean(epochs._data[bs_indices, ...], 0)
 
         # Baseline correct mean waveform
-        erp_bs[i] = mne.baseline.rescale(erp_bs[i], epochs.times,
-                                         baseline=baseline, verbose='ERROR')
+        if baseline is not None:
+            erp_bs[i] = mne.baseline.rescale(erp_bs[i], epochs.times,
+                                             baseline=baseline,
+                                             verbose='ERROR')
 
         # Rectify waveform
         gfp_bs[i] = np.sqrt(erp_bs[i] ** 2)
@@ -195,7 +195,7 @@ def bootstrap_snr(epochs, n_bootstrap=2000, baseline=None, window=None):
     ci_low, ci_up = np.percentile(erp_bs, (10, 90), axis=0)
 
     # Calculate SNR for each bootstrapped ERP; form distribution in `snr_dist`
-    snr_dist = np.zeros((n_bootstrap,))
     if window is not None:
         if window[0] is None:
             window[0] = 0
@@ -208,15 +208,12 @@
     pre = epochs.times <= 0
 
     # SNR for each bootstrap iteration
-    for i in range(n_bootstrap):
-        snr_dist[i] = 20 * np.log10(
-            np.mean(gfp_bs[i, post]) / np.mean(gfp_bs[i, pre]))
-
-    print('Bootstrapping... OK!\n')
+    snr_dist = 20 * np.log10(gfp_bs[..., post].mean(-1) /
+                             gfp_bs[..., pre].mean(-1))
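+    # NB: gfp_bs has shape (n_bootstrap, n_chans, n_times); masking and
+    # averaging over the last axis gives one post/pre power ratio per
+    # bootstrap iteration and channel, so snr_dist is (n_bootstrap, n_chans).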
 
     # Mean, lower, and upper bound SNR
     snr_low, snr_up = np.percentile(snr_dist, (10, 90), axis=0)
-    snr_mean = np.mean(snr_dist)
+    snr_mean = np.mean(snr_dist, axis=0)
     mean_bs_erp = np.mean(erp_bs, axis=0)
 
     return (mean_bs_erp, ci_low, ci_up), (snr_mean, snr_low, snr_up)
@@ -238,56 +235,58 @@
     Parameters
     ----------
-    epochs: instance of mne.Epochs
+    epochs : mne.Epochs | ndarray, shape=(n_trials, n_chans, n_samples)
         Epochs to compute alpha from.
-    K: int
+    K : int
         Number of trials to use for alpha computation.
-    n_bootstrap: int
+    n_bootstrap : int
         Number of bootstrap resamplings.
-    tmin: float
-        Start time of epoch.
-    tmax: float
-        End of epoch.
+    tmin : float
+        Start of the window to use, in seconds for Epochs input and in
+        samples for array input. Defaults to the start of the epoch.
+    tmax : float
+        End of the window to use (same units as tmin). Defaults to the end
+        of the epoch.
 
     Returns
     -------
-    alpha:
-        Cronbach alpha value
-    bounds: length-2 tuple
-        Lower and higher bound of CI.
+    alpha : array, shape=(n_chans,)
+        Mean Cronbach alpha for each channel.
+    ci_lo, ci_hi : arrays, shape=(n_chans,)
+        Lower and upper bounds of the 80% confidence interval.
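+
+    Notes
+    -----
+    Alpha is computed per channel as ``K / (K - 1) * (1 - sum(var(Y_i)) /
+    var(sum(Y_i)))``, where ``Y_i`` is the time course of trial ``i`` and
+    the variances are taken over time.
+
+    Examples
+    --------
+    A minimal sketch on random data (shapes are illustrative):
+
+    >>> X = np.random.randn(100, 8, 500)  # trials, channels, samples
+    >>> alpha, ci_lo, ci_hi = cronbach(X, n_bootstrap=200)
+    >>> alpha.shape
+    (8,)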
""" - if data.ndim == 3: - n_freqs = data.shape[0] - n_chans = data.shape[1] - n_trials = data.shape[-1] - elif data.ndim == 2: + if X.ndim == 3: + n_freqs = X.shape[0] + n_chans = X.shape[1] + n_trials = X.shape[-1] + elif X.ndim == 2: n_trials = 1 - n_freqs = data.shape[0] - n_chans = data.shape[1] + n_freqs = X.shape[0] + n_chans = X.shape[1] else: raise ValueError('Data must have shape (n_freqs, n_chans, [n_trials,])' - f', got {data.shape}') + f', got {X.shape}') # Number of points to get desired resolution - data = np.reshape(data, (n_freqs, n_chans * n_trials)) - SNR = np.zeros_like(data) + X = np.reshape(X, (n_freqs, n_chans * n_trials)) + SNR = np.zeros_like(X) for i_bin in range(n_freqs): @@ -376,12 +375,12 @@ def snr_spectrum(data, freqs, n_avg=1, n_harm=1, skipbins=1): for i_trial in range(n_chans * n_trials): # Mean of signal over fundamental+harmonics - A = np.mean(data[bin_peaks, i_trial] ** 2) + A = np.mean(X[bin_peaks, i_trial] ** 2) # Noise around fundamental+harmonics B = np.zeros(len(bin_noise)) for h in range(len(B)): - B[h] = np.mean(data[bin_noise[h], i_trial].flatten() ** 2) + B[h] = np.mean(X[bin_noise[h], i_trial].flatten() ** 2) # Ratio with np.errstate(divide='ignore', invalid='ignore'): diff --git a/tests/test_utils.py b/tests/test_utils.py index 4026a239..01678516 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,10 +2,33 @@ from meegkit.utils import (bootstrap_ci, demean, find_outlier_samples, find_outlier_trials, fold, mean_over_trials, multishift, multismooth, relshift, rms, shift, - shiftnd, unfold, widen_mask) + shiftnd, unfold, widen_mask, cronbach, robust_mean) from numpy.testing import assert_almost_equal, assert_equal +def _sim_data(n_times, n_chans, n_trials, noise_dim, SNR=1, t0=100): + """Create synthetic data.""" + # source + source = np.sin(2 * np.pi * np.linspace(0, .5, n_times - t0))[np.newaxis].T + s = source * np.random.randn(1, n_chans) + s = s[:, :, np.newaxis] + s = np.tile(s, (1, 1, n_trials)) + signal = np.zeros((n_times, n_chans, n_trials)) + signal[t0:, :, :] = s + + # noise + noise = np.dot( + unfold(np.random.randn(n_times, noise_dim, n_trials)), + np.random.randn(noise_dim, n_chans)) + noise = fold(noise, n_times) + + # mix signal and noise + signal = SNR * signal / rms(signal.flatten()) + noise = noise / rms(noise.flatten()) + noisy_data = signal + noise + return noisy_data, signal + + def test_multishift(): """Test matrix multi-shifting.""" # multishift() @@ -119,7 +142,7 @@ def test_demean(show=False): n_chans = 8 n_times = 1000 x = np.random.randn(n_times, n_chans, n_trials) - x, s = _stim_data(n_times, n_chans, n_trials, 8, SNR=10) + x, s = _sim_data(n_times, n_chans, n_trials, 8, SNR=10) # 1. 
+    assert np.sum(alpha2 > alpha) >= 6
+
+    m = robust_mean(X, axis=0)
+    assert m.shape == (X.shape[1], X.shape[2])
+
+
 if __name__ == '__main__':
     import pytest
     pytest.main([__file__])
@@ -241,3 +258,5 @@
 # y = multismooth(x, np.arange(1, 200, 4))
 # plt.imshow(y.T, aspect='auto')
 # plt.show()
+
+# test_cronbach()
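+
+# A minimal usage sketch for the multichannel bootstrap_snr (assumes
+# `epochs` is an mne.Epochs instance; the arguments are illustrative):
+# (erp, lo, hi), (snr, snr_lo, snr_hi) = bootstrap_snr(
+#     epochs, n_bootstrap=2000, baseline=(None, 0), window=[0.1, 0.4])
+# snr, snr_lo and snr_hi each have shape (n_chans,)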