Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 63 additions & 64 deletions meegkit/utils/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,25 @@
mne = None


def rms(X, axis=0):
    """Root-mean-square along given axis.

    Parameters
    ----------
    X : array_like
        Input data. Converted with ``np.asarray`` so plain sequences are
        accepted as well as arrays.
    axis : int | None
        Axis along which the mean of squares is taken (default=0).

    Returns
    -------
    out : ndarray
        Root-mean-square values. The reduced axis is kept with length one
        (``keepdims=True``), so the result broadcasts against `X`.

    """
    X = np.asarray(X)
    return np.sqrt(np.mean(X ** 2, axis=axis, keepdims=True))


def robust_mean(X, axis=0, percentile=(5, 95)):
    """Do robust mean based on JR King's implementation.

    Values lying outside the given percentile bounds (computed along `axis`)
    are discarded before averaging.

    Parameters
    ----------
    X : array_like
        Input data. A copy is made, so the caller's array is left untouched.
    axis : int | None
        Axis along which the mean is computed (default=0). If None, the
        percentiles and the mean are taken over the flattened data.
    percentile : sequence of 2 floats
        Lower and upper percentiles; samples strictly below the first or
        strictly above the second are ignored (default=(5, 95)).

    Returns
    -------
    m : ndarray | float
        Mean of the retained values, with `axis` removed.

    """
    X = np.array(X)  # copy: outliers are overwritten with NaN below
    if not np.issubdtype(X.dtype, np.floating):
        # Integer/boolean arrays cannot store NaN.
        X = X.astype(float)

    # Bring the reduction axis to the front so the percentile bounds
    # broadcast against X along the trailing dimensions.
    if axis is not None and axis != 0:
        X = np.moveaxis(X, axis, 0)
        axis = 0

    bounds = np.percentile(X, percentile, axis=axis)
    low, high = bounds[0], bounds[1]

    outliers = (X < low) | (X > high)
    X[outliers] = np.nan
    return np.nanmean(X, axis=axis)


Expand Down Expand Up @@ -142,7 +142,7 @@ def bootstrap_snr(epochs, n_bootstrap=2000, baseline=None, window=None):
Parameters
----------
epochs : mne.Epochs instance
Epochs instance to compute ERP from.
Epochs instance to compute SNR from.
n_bootstrap : int
Number of bootstrap iterations (should be > 10000 for publication
quality).
Expand Down Expand Up @@ -174,19 +174,19 @@ def bootstrap_snr(epochs, n_bootstrap=2000, baseline=None, window=None):

"""
indices = np.arange(len(epochs.selection), dtype=int)
erp_bs = np.empty((n_bootstrap, len(epochs.times)))
gfp_bs = np.empty((n_bootstrap, len(epochs.times)))
n_chans = len(epochs.ch_names)
erp_bs = np.empty((n_bootstrap, n_chans, len(epochs.times)))
gfp_bs = np.empty((n_bootstrap, n_chans, len(epochs.times)))

for i in range(n_bootstrap):
if 100 * i / n_bootstrap == np.floor(100 * i / n_bootstrap):
print('Bootstrapping... {}%'.format(round(100 * i / n_bootstrap)),
end='\r'),
bs_indices = np.random.choice(indices, replace=True, size=len(indices))
erp_bs[i] = np.mean(epochs._data[bs_indices, 0, :], 0)
erp_bs[i] = np.mean(epochs._data[bs_indices, ...], 0)

# Baseline correct mean waveform
erp_bs[i] = mne.baseline.rescale(erp_bs[i], epochs.times,
baseline=baseline, verbose='ERROR')
if baseline:
erp_bs[i] = mne.baseline.rescale(erp_bs[i], epochs.times,
baseline=baseline,
verbose='ERROR')

# Rectify waveform
gfp_bs[i] = np.sqrt(erp_bs[i] ** 2)
Expand All @@ -195,7 +195,7 @@ def bootstrap_snr(epochs, n_bootstrap=2000, baseline=None, window=None):
ci_low, ci_up = np.percentile(erp_bs, (10, 90), axis=0)

# Calculate SNR for each bootstrapped ERP; form distribution in `snr_dist`
snr_dist = np.zeros((n_bootstrap,))
snr_dist = np.zeros((n_bootstrap, n_chans))
if window is not None:
if window[0] is None:
window[0] = 0
Expand All @@ -208,15 +208,12 @@ def bootstrap_snr(epochs, n_bootstrap=2000, baseline=None, window=None):
pre = epochs.times <= 0

# SNR for each bootstrap iteration
for i in range(n_bootstrap):
snr_dist[i] = 20 * np.log10(
np.mean(gfp_bs[i, post]) / np.mean(gfp_bs[i, pre]))

print('Bootstrapping... OK! ')
snr_dist = 20 * np.log10(gfp_bs[..., post].mean(-1) /
gfp_bs[..., pre].mean(-1))

# Mean, lower, and upper bound SNR
snr_low, snr_up = np.percentile(snr_dist, (10, 90), axis=0)
snr_mean = np.mean(snr_dist)
snr_mean = np.mean(snr_dist, axis=0)
mean_bs_erp = np.mean(erp_bs, axis=0)

return (mean_bs_erp, ci_low, ci_up), (snr_mean, snr_low, snr_up)
Expand All @@ -238,56 +235,58 @@ def cronbach(epochs, K=None, n_bootstrap=2000, tmin=None, tmax=None):

Parameters
----------
epochs: instance of mne.Epochs
epochs : mne.Epochs | ndarray, shape=(n_trials, n_chans, n_samples)
Epochs to compute alpha from.
K: int
K : int
Number of trials to use for alpha computation.
n_bootstrap: int
n_bootstrap : int
Number of bootstrap resamplings.
tmin: float
tmin : float
Start time of epoch.
tmax: float
tmax : float
End of epoch.

Returns
-------
alpha:
alpha : array, shape=(n_chans,)
Cronbach alpha value
bounds: length-2 tuple
Lower and higher bound of CI.

"""
if tmin:
tmin = epochs.time_as_index(tmin)[0]
else:
tmin = 0

if tmax:
tmax = epochs.time_as_index(tmax)[0]
if isinstance(epochs, np.ndarray):
erp = epochs
tmin = tmin if tmin else 0
tmax = tmax if tmax else -1
elif isinstance(epochs, mne.BaseEpochs):
erp = epochs.get_data()
tmin = epochs.time_as_index(tmin)[0] if tmin else 0
tmax = epochs.time_as_index(tmax)[0] if tmax else -1
else:
tmax = -1
raise ValueError("epochs must be an mne.Epochs or numpy array.")

n_trials, n_chans, n_samples = erp.shape
if K is None:
K = epochs._data.shape[0]
K = n_trials

alpha = np.empty(n_bootstrap)
alpha = np.empty((n_bootstrap, n_chans))
for b in np.arange(n_bootstrap): # take K trials randomly
idx = np.random.choice(range(epochs._data.shape[0]), K)
X = epochs._data[idx, 0, tmin:tmax]
sigmaY = X.var(axis=1) # var over time
sigmaX = X.sum(axis=0).var() # var of average
alpha[b] = K / (K - 1) * (1 - sigmaY.sum() / sigmaX)
idx = np.random.choice(range(n_trials), K)
X = erp[idx, :, tmin:tmax]
sigmaY = X.var(axis=2).sum(0) # var over time
sigmaX = X.sum(axis=0).var(-1) # var of average
alpha[b] = K / (K - 1) * (1 - sigmaY / sigmaX)

ci_lb, ci_ub = np.percentile(alpha, (10, 90))
return np.max([alpha.mean(), 0]), (ci_lb, ci_ub)
ci_lo, ci_hi = np.percentile(alpha, (10, 90), axis=0)
return alpha.mean(0), ci_lo, ci_hi


def snr_spectrum(data, freqs, n_avg=1, n_harm=1, skipbins=1):
def snr_spectrum(X, freqs, n_avg=1, n_harm=1, skipbins=1):
"""Compute Signal-to-Noise-corrected spectrum.

Parameters
----------
data : ndarray , shape=(n_freqs, n_chans,[ n_trials,])
X : ndarray , shape=(n_freqs, n_chans,[ n_trials,])
One-sided power spectral density estimate, specified as a real-valued,
nonnegative array. The power spectral density must be expressed in
linear units, not decibels.
Expand Down Expand Up @@ -319,21 +318,21 @@ def snr_spectrum(data, freqs, n_avg=1, n_harm=1, skipbins=1):
natural face images in the infant right hemisphere. Elife, 4.

"""
if data.ndim == 3:
n_freqs = data.shape[0]
n_chans = data.shape[1]
n_trials = data.shape[-1]
elif data.ndim == 2:
if X.ndim == 3:
n_freqs = X.shape[0]
n_chans = X.shape[1]
n_trials = X.shape[-1]
elif X.ndim == 2:
n_trials = 1
n_freqs = data.shape[0]
n_chans = data.shape[1]
n_freqs = X.shape[0]
n_chans = X.shape[1]
else:
raise ValueError('Data must have shape (n_freqs, n_chans, [n_trials,])'
f', got {data.shape}')
f', got {X.shape}')

# Number of points to get desired resolution
data = np.reshape(data, (n_freqs, n_chans * n_trials))
SNR = np.zeros_like(data)
X = np.reshape(X, (n_freqs, n_chans * n_trials))
SNR = np.zeros_like(X)

for i_bin in range(n_freqs):

Expand Down Expand Up @@ -376,12 +375,12 @@ def snr_spectrum(data, freqs, n_avg=1, n_harm=1, skipbins=1):
for i_trial in range(n_chans * n_trials):

# Mean of signal over fundamental+harmonics
A = np.mean(data[bin_peaks, i_trial] ** 2)
A = np.mean(X[bin_peaks, i_trial] ** 2)

# Noise around fundamental+harmonics
B = np.zeros(len(bin_noise))
for h in range(len(B)):
B[h] = np.mean(data[bin_noise[h], i_trial].flatten() ** 2)
B[h] = np.mean(X[bin_noise[h], i_trial].flatten() ** 2)

# Ratio
with np.errstate(divide='ignore', invalid='ignore'):
Expand Down
69 changes: 44 additions & 25 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,33 @@
from meegkit.utils import (bootstrap_ci, demean, find_outlier_samples,
find_outlier_trials, fold, mean_over_trials,
multishift, multismooth, relshift, rms, shift,
shiftnd, unfold, widen_mask)
shiftnd, unfold, widen_mask, cronbach, robust_mean)
from numpy.testing import assert_almost_equal, assert_equal


def _sim_data(n_times, n_chans, n_trials, noise_dim, SNR=1, t0=100):
    """Create synthetic data.

    A half-period sine starting at sample `t0` is projected onto the channels
    with random gains and repeated identically over trials, then mixed with
    noise built from `noise_dim` random components.

    Returns
    -------
    noisy_data : ndarray
        Sum of the scaled signal and the normalised noise.
    signal : ndarray, shape=(n_times, n_chans, n_trials)
        Scaled noise-free signal.

    """
    # Source: zero before t0, half a sine period afterwards.
    waveform = np.sin(2 * np.pi * np.linspace(0, .5, n_times - t0))
    gains = np.random.randn(1, n_chans)
    projected = waveform[:, np.newaxis] * gains
    signal = np.zeros((n_times, n_chans, n_trials))
    signal[t0:, :, :] = np.repeat(projected[:, :, np.newaxis], n_trials,
                                  axis=2)

    # Noise: random components mixed into channel space.
    # NOTE(review): unfold/fold presumably flatten and restore the trial
    # dimension -- confirm against meegkit.utils.
    components = np.random.randn(n_times, noise_dim, n_trials)
    mixing = np.random.randn(noise_dim, n_chans)
    noise = fold(np.dot(unfold(components), mixing), n_times)

    # Normalise both parts by their RMS, then scale the signal by SNR.
    signal = SNR * signal / rms(signal.flatten())
    noise = noise / rms(noise.flatten())
    return signal + noise, signal


def test_multishift():
"""Test matrix multi-shifting."""
# multishift()
Expand Down Expand Up @@ -119,7 +142,7 @@ def test_demean(show=False):
n_chans = 8
n_times = 1000
x = np.random.randn(n_times, n_chans, n_trials)
x, s = _stim_data(n_times, n_chans, n_trials, 8, SNR=10)
x, s = _sim_data(n_times, n_chans, n_trials, 8, SNR=10)

# 1. demean and check trial average is almost zero
x1 = demean(x)
Expand Down Expand Up @@ -169,29 +192,6 @@ def test_demean(show=False):
np.testing.assert_array_equal(x1, x3)


def _stim_data(n_times, n_chans, n_trials, noise_dim, SNR=1, t0=100):
    """Create synthetic data.

    Builds a trial-repeated half-sine source (silent before sample `t0`)
    with random channel gains, adds noise mixed from `noise_dim` random
    components, and returns both the noisy mixture and the clean signal.
    """
    # source
    n_active = n_times - t0
    source = np.sin(2 * np.pi * np.linspace(0, .5, n_active))
    source = source.reshape(n_active, 1)
    per_chan = source * np.random.randn(1, n_chans)
    signal = np.zeros((n_times, n_chans, n_trials))
    signal[t0:, :, :] = np.tile(per_chan[:, :, np.newaxis], (1, 1, n_trials))

    # noise
    # NOTE(review): unfold/fold are project helpers; presumably they collapse
    # and restore the trial axis -- confirm.
    raw_noise = unfold(np.random.randn(n_times, noise_dim, n_trials))
    mix = np.random.randn(noise_dim, n_chans)
    noise = fold(np.dot(raw_noise, mix), n_times)

    # mix signal and noise
    signal = SNR * signal / rms(signal.flatten())
    noise = noise / rms(noise.flatten())
    noisy_data = signal + noise
    return noisy_data, signal


def test_computeci():
"""Compute CI."""
x = np.random.randn(1000, 8, 100)
Expand Down Expand Up @@ -231,6 +231,23 @@ def test_outliers(show=False):
assert idx.shape == x.shape


def test_cronbach():
    """Test Cronbach's alpha on simulated data, and robust_mean's shape."""
    n_boot = 100

    # Low-SNR data: only check that the confidence bounds are ordered.
    weak, _ = _sim_data(800, 8, 80, noise_dim=6, SNR=.2)
    weak = np.transpose(weak, (2, 1, 0))  # trials, channels, samples
    alpha, lo, hi = cronbach(weak, tmin=0, n_bootstrap=n_boot)
    print(alpha)
    assert np.all(lo < hi)

    # Higher-SNR data: alpha is expected to increase on most channels.
    X, _ = _sim_data(800, 8, 80, noise_dim=6, SNR=1)
    X = np.transpose(X, (2, 1, 0))
    alpha2, lo, hi = cronbach(X, tmin=100, n_bootstrap=n_boot)
    print(alpha2)
    n_improved = np.sum(alpha2 > alpha)
    assert n_improved >= 6

    # Averaging over trials leaves a (channels, samples) array.
    m = robust_mean(X, axis=0)
    assert m.shape == (X.shape[1], X.shape[2])

if __name__ == '__main__':
import pytest
pytest.main([__file__])
Expand All @@ -241,3 +258,5 @@ def test_outliers(show=False):
# y = multismooth(x, np.arange(1, 200, 4))
# plt.imshow(y.T, aspect='auto')
# plt.show()

# test_cronbach()