diff --git a/doc/modules/meegkit.dss.rst b/doc/modules/meegkit.dss.rst
index 497ad2ee..33974b7d 100644
--- a/doc/modules/meegkit.dss.rst
+++ b/doc/modules/meegkit.dss.rst
@@ -16,6 +16,7 @@
    dss0
    dss1
    dss_line
+   dss_line_iter
diff --git a/examples/example_dss_line.py b/examples/example_dss_line.py
new file mode 100644
index 00000000..cc5f6ab9
--- /dev/null
+++ b/examples/example_dss_line.py
@@ -0,0 +1,92 @@
+"""
+Remove line noise with ZapLine
+==============================
+
+Find a spatial filter to get rid of line noise [1]_.
+
+Uses meegkit.dss.dss_line() and meegkit.dss.dss_line_iter().
+
+References
+----------
+.. [1] de Cheveigné, A. (2019). ZapLine: A simple and effective method to
+   remove power line artifacts [Preprint]. https://doi.org/10.1101/782029
+
+"""
+# Authors: Maciej Szul
+#          Nicolas Barascud
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+from meegkit import dss
+from meegkit.utils import create_line_data, unfold
+from scipy import signal
+
+###############################################################################
+# Line noise removal
+# =============================================================================
+
+###############################################################################
+# Remove line noise with dss_line()
+# -----------------------------------------------------------------------------
+# We first generate some noisy data to work with.
+sfreq = 250
+fline = 50
+nsamples = 10000
+nchans = 10
+data = create_line_data(n_samples=3 * nsamples, n_chans=nchans,
+                        n_trials=1, fline=fline / sfreq, SNR=2)[0]
+data = data[..., 0]  # only take first trial
+
+# Apply dss_line (ZapLine)
+out, _ = dss.dss_line(data, fline, sfreq, nkeep=1)
+
+###############################################################################
+# Plot before/after
+f, ax = plt.subplots(1, 2, sharey=True)
+f, Pxx = signal.welch(data, sfreq, nperseg=500, axis=0, return_onesided=True)
+ax[0].semilogy(f, Pxx)
+f, Pxx = signal.welch(out, sfreq, nperseg=500, axis=0, return_onesided=True)
+ax[1].semilogy(f, Pxx)
+ax[0].set_xlabel('frequency [Hz]')
+ax[1].set_xlabel('frequency [Hz]')
+ax[0].set_ylabel('PSD [V**2/Hz]')
+ax[0].set_title('before')
+ax[1].set_title('after')
+plt.show()
+
+
+###############################################################################
+# Remove line noise with dss_line_iter()
+# -----------------------------------------------------------------------------
+# We first load some noisy data to work with.
+data = np.load(os.path.join('..', 'tests', 'data', 'dss_line_data.npy'))
+fline = 50
+sfreq = 200
+print(data.shape)  # n_samples, n_chans, n_trials
+
+# Apply dss_line(), removing only one component
+out1, _ = dss.dss_line(data, fline, sfreq, nremove=1, nfft=400)
+
+###############################################################################
+# Now try dss_line_iter(). This applies dss_line() repeatedly until the
+# artifact is gone.
+out2, iterations = dss.dss_line_iter(data, fline, sfreq, nfft=400)
+print(f'Removed {iterations} components')
+
+###############################################################################
+# Plot results of dss_line() vs. dss_line_iter()
+f, ax = plt.subplots(1, 2, sharey=True)
+f, Pxx = signal.welch(unfold(out1), sfreq, nperseg=200, axis=0,
+                      return_onesided=True)
+ax[0].semilogy(f, Pxx, lw=.5)
+f, Pxx = signal.welch(unfold(out2), sfreq, nperseg=200, axis=0,
+                      return_onesided=True)
+ax[1].semilogy(f, Pxx, lw=.5)
+ax[0].set_xlabel('frequency [Hz]')
+ax[1].set_xlabel('frequency [Hz]')
+ax[0].set_ylabel('PSD [V**2/Hz]')
+ax[0].set_title('dss_line')
+ax[1].set_title('dss_line_iter')
+plt.tight_layout()
+plt.show()
diff --git a/meegkit/__init__.py b/meegkit/__init__.py
index 162a3c54..e01b7aba 100644
--- a/meegkit/__init__.py
+++ b/meegkit/__init__.py
@@ -1,5 +1,5 @@
 """M/EEG denoising utilities in python."""
-__version__ = '0.1.1'
+__version__ = '0.1.2'
 
 from . import asr, cca, detrend, dss, sns, star, ress, trca, tspca, utils
diff --git a/meegkit/dss.py b/meegkit/dss.py
index f408951c..f46371b5 100644
--- a/meegkit/dss.py
+++ b/meegkit/dss.py
@@ -1,6 +1,10 @@
 """Denoising source separation."""
+# Authors: Nicolas Barascud
+#          Maciej Szul
+
 import numpy as np
 from scipy import linalg
+from scipy.signal import welch
 
 from .tspca import tsr
 from .utils import (demean, gaussfilt, mean_over_trials, pca, smooth,
@@ -230,3 +234,130 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, show=False):
     p = wpwr(X - y)[0] / wpwr(X)[0]
     print('Power of components removed by DSS: {:.2f}'.format(p))
     return y, artifact
+
+
+def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5,
+                  nfft=512, show=False, prefix="dss_iter", n_iter_max=100):
+    """Remove power line artifact iteratively.
+
+    This method applies dss_line() until the artifact has been smoothed out
+    from the spectrum.
+
+    Parameters
+    ----------
+    data : ndarray, shape=(n_samples, n_chans, n_trials)
+        Input data.
+    fline : float
+        Line frequency.
+    sfreq : float
+        Sampling frequency.
+    win_sz : float
+        Half of the width of the window around the target frequency used to
+        fit the polynomial (default=10).
+    spot_sz : float
+        Half of the width of the window around the target frequency used to
+        remove the peak and interpolate (default=2.5).
+    nfft : int
+        FFT size for the internal PSD calculation (default=512).
+    show : bool
+        Produce a visual output of each iteration (default=False).
+    prefix : str
+        Path and first part of the visualisation output file
+        "{prefix}_{iteration number}.png" (default="dss_iter").
+    n_iter_max : int
+        Maximum number of iterations (default=100).
+
+    Returns
+    -------
+    data : ndarray, shape=(n_samples, n_chans, n_trials)
+        Denoised data.
+    iterations : int
+        Number of iterations.
+    """
+
+    def nan_basic_interp(array):
+        """Nan interpolation."""
+        nans, ix = np.isnan(array), lambda x: x.nonzero()[0]
+        array[nans] = np.interp(ix(nans), ix(~nans), array[~nans])
+        return array
+
+    freq_rn = [fline - win_sz, fline + win_sz]
+    freq_sp = [fline - spot_sz, fline + spot_sz]
+    freq, psd = welch(data, fs=sfreq, nfft=nfft, axis=0)
+
+    freq_rn_ix = np.logical_and(freq >= freq_rn[0], freq <= freq_rn[1])
+    freq_used = freq[freq_rn_ix]
+    freq_sp_ix = np.logical_and(freq_used >= freq_sp[0],
+                                freq_used <= freq_sp[1])
+
+    if psd.ndim == 3:
+        mean_psd = np.mean(psd, axis=(1, 2))[freq_rn_ix]
+    elif psd.ndim == 2:
+        mean_psd = np.mean(psd, axis=1)[freq_rn_ix]
+
+    mean_psd_wospot = mean_psd.copy()
+    mean_psd_wospot[freq_sp_ix] = np.nan
+    mean_psd_tf = nan_basic_interp(mean_psd_wospot)
+    pf = np.polyfit(freq_used, mean_psd_tf, 3)
+    p = np.poly1d(pf)
+    clean_fit_line = p(freq_used)
+
+    aggr_resid = []
+    iterations = 0
+    while iterations < n_iter_max:
+        data, _ = dss_line(data, fline, sfreq, nfft=nfft, nremove=1)
+        freq, psd = welch(data, fs=sfreq, nfft=nfft, axis=0)
+        if psd.ndim == 3:
+            mean_psd = np.mean(psd, axis=(1, 2))[freq_rn_ix]
+        elif psd.ndim == 2:
+            mean_psd = np.mean(psd, axis=1)[freq_rn_ix]
+
+        residuals = mean_psd - clean_fit_line
+        mean_score = np.mean(residuals[freq_sp_ix])
+        aggr_resid.append(mean_score)
+
+        print("Iteration {} score: {}".format(iterations, mean_score))
+
+        if show:
+            import matplotlib.pyplot as plt
+            f, ax = plt.subplots(2, 2, figsize=(12, 6), facecolor="white")
+
+            if psd.ndim == 3:
+                mean_sens = np.mean(psd, axis=2)
+            elif psd.ndim == 2:
+                mean_sens = psd
+
+            y = mean_sens[freq_rn_ix]
+            ax.flat[0].plot(freq_used, y)
+            ax.flat[0].set_title("Mean PSD across trials")
+
+            ax.flat[1].plot(freq_used, mean_psd_tf, c="gray")
+            ax.flat[1].plot(freq_used, mean_psd, c="blue")
+            ax.flat[1].plot(freq_used, clean_fit_line, c="red")
+            ax.flat[1].set_title("Mean PSD across trials and sensors")
+
+            tf_ix = np.where(freq_used <= fline)[0][-1]
+            ax.flat[2].plot(residuals, freq_used)
+            color = "green"
+            if mean_score <= 0:
+                color = "red"
+            ax.flat[2].scatter(residuals[tf_ix], freq_used[tf_ix], c=color)
+            ax.flat[2].set_title("Residuals")
+
+            ax.flat[3].plot(np.arange(iterations + 1), aggr_resid, marker='o')
+            ax.flat[3].set_title("Iterations")
+
+            f.set_tight_layout(True)
+            plt.savefig(f"{prefix}_{iterations:03}.png")
+            plt.close("all")
+
+        if mean_score <= 0:
+            break
+
+        iterations += 1
+
+    if iterations == n_iter_max:
+        raise RuntimeError('Could not converge. Consider increasing the '
+                           'maximum number of iterations')
+
+    return data, iterations
diff --git a/meegkit/utils/__init__.py b/meegkit/utils/__init__.py
index cfca680e..b34ee4ab 100644
--- a/meegkit/utils/__init__.py
+++ b/meegkit/utils/__init__.py
@@ -13,3 +13,4 @@
                      spectral_envelope, teager_kaiser)
 from .stats import (bootstrap_ci, bootstrap_snr, cronbach, rms, robust_mean,
                     rolling_corr, snr_spectrum)
+from .testing import create_line_data
diff --git a/meegkit/utils/testing.py b/meegkit/utils/testing.py
new file mode 100644
index 00000000..008f616c
--- /dev/null
+++ b/meegkit/utils/testing.py
@@ -0,0 +1,74 @@
+"""Synthetic test data."""
+import numpy as np
+from meegkit.utils import fold, rms, unfold
+
+import matplotlib.pyplot as plt
+
+
+def create_line_data(n_samples=100 * 3, n_chans=30, n_trials=100, noise_dim=20,
+                     n_bad_chans=1, SNR=.1, fline=1, t0=None, show=False):
+    """Create synthetic data.
+
+    Parameters
+    ----------
+    n_samples : int
+        Number of samples (default=100*3).
+    n_chans : int
+        Number of channels (default=30).
+    n_trials : int
+        Number of trials (default=100).
+    noise_dim : int
+        Dimensionality of noise (default=20).
+    n_bad_chans : int
+        Number of bad channels (default=1).
+    t0 : int
+        Onset sample of the artifact (default=n_samples // 3).
+    fline : float
+        Normalized frequency of the artifact (freq/samplerate) (default=1).
+
+    Returns
+    -------
+    data : ndarray, shape=(n_samples, n_chans, n_trials)
+    source : ndarray, shape=(n_samples,)
+    """
+    rng = np.random.RandomState(2022)
+
+    if t0 is None:
+        t0 = n_samples // 3
+    t1 = n_samples - 2 * t0  # artifact duration
+
+    # create source signal
+    source = np.hstack((
+        np.zeros(t0),
+        np.sin(2 * np.pi * fline * np.arange(t1)),
+        np.zeros(t0)))  # noise -> artifact -> noise
+    source = source[:, None]
+
+    # mix source in channels
+    s = source * rng.randn(1, n_chans)
+    s = s[:, :, np.newaxis]
+    s = np.tile(s, (1, 1, n_trials))  # create trials
+
+    # set first `n_bad_chans` to zero
+    s[:, :n_bad_chans] = 0.
+
+    # noise
+    noise = np.dot(
+        unfold(rng.randn(n_samples, noise_dim, n_trials)),
+        rng.randn(noise_dim, n_chans))
+    noise = fold(noise, n_samples)
+
+    # mix signal and noise
+    data = noise / rms(noise.flatten()) + SNR * s / rms(s.flatten())
+
+    if show:
+        f, ax = plt.subplots(3)
+        ax[0].plot(source.mean(-1), label='source')
+        ax[1].plot(noise[:, 1].mean(-1), label='noise (avg over trials)')
+        ax[2].plot(data[:, 1].mean(-1), label='mixture (avg over trials)')
+        ax[0].legend()
+        ax[1].legend()
+        ax[2].legend()
+        plt.show()
+
+    return data, source
diff --git a/setup.py b/setup.py
index dbc721c4..608ce40a 100644
--- a/setup.py
+++ b/setup.py
@@ -9,6 +9,6 @@
       author='N Barascud',
       author_email='nicolas.barascud@gmail.com',
       license='UNLICENSED',
-      version='0.1.1',
+      version='0.1.2',
       packages=find_packages(exclude=['doc', 'tests']),
       zip_safe=False)
diff --git a/tests/conftest.py b/tests/conftest.py
index e9548083..0b54abf0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,9 +1,33 @@
 import pytest
-import numpy as np
-import random as rand
+import matplotlib.pyplot as plt
 
 
-@pytest.fixture
-def random():
-    rand.seed(9)
-    np.random.seed(9)
+def pytest_addoption(parser):
+    """Add command line options to pytest."""
+    parser.addoption(
+        "--runslow",
+        action="store_true",
+        default=False,
+        help="run slow tests"
+    )
+    parser.addoption(
+        "--noplots",
+        action="store_true",
+        default=False,
+        help="do not display plots"
+    )
+
+
+def pytest_collection_modifyitems(config, items):
+    """Skip slow tests unless --runslow is given; honour --noplots."""
+    if config.getoption("--noplots"):
+        plt.switch_backend('agg')
+
+    if config.getoption("--runslow"):
+        # --runslow given in cli: do not skip slow tests
+        return
+    skip_slow = pytest.mark.skip(reason="need --runslow option to run")
+    for item in items:
+        if "slow" in item.keywords:
+            item.add_marker(skip_slow)
diff --git a/tests/data/dss_line_data.npy b/tests/data/dss_line_data.npy
new file mode 100644
index 00000000..472d6512
Binary files /dev/null and b/tests/data/dss_line_data.npy differ
diff --git a/tests/test_dss.py b/tests/test_dss.py
index 250e69f7..d892ec3c 100644
--- a/tests/test_dss.py
+++ b/tests/test_dss.py
@@ -1,69 +1,15 @@
 """Test DSS functions."""
+import os
+from tempfile import TemporaryDirectory
+
 import matplotlib.pyplot as plt
 import numpy as np
 import pytest
+from meegkit import dss
+from meegkit.utils import create_line_data, fold, tscov, unfold
 from numpy.testing import assert_allclose
 from scipy import signal
 
-from meegkit import dss
-from meegkit.utils import fold, rms, tscov, unfold
-
-
-def create_data(n_samples=100 * 3, n_chans=30, n_trials=100, noise_dim=20,
-                n_bad_chans=1, SNR=.1, show=False):
-    """Create synthetic data.
-
-    Parameters
-    ----------
-    n_samples : int
-        [description], by default 100*3
-    n_chans : int
-        [description], by default 30
-    n_trials : int
-        [description], by default 100
-    noise_dim : int
-        Dimensionality of noise, by default 20
-    n_bad_chans : int
-        [description], by default 1
-
-    Returns
-    -------
-    data : ndarray, shape=(n_samples, n_chans, n_trials)
-    source : ndarray, shape=(n_samples,)
-    """
-    # source
-    source = np.hstack((
-        np.zeros((n_samples // 3,)),
-        np.sin(2 * np.pi * np.arange(n_samples // 3) / (n_samples / 3)).T,
-        np.zeros((n_samples // 3,))))[np.newaxis].T
-    s = source * np.random.randn(1, n_chans)  # 300 * 30
-    s = s[:, :, np.newaxis]
-    s = np.tile(s, (1, 1, 100))
-
-    # set first `n_bad_chans` to zero
-    s[:, :n_bad_chans] = 0.
-
-    # noise
-    noise = np.dot(
-        unfold(np.random.randn(n_samples, noise_dim, n_trials)),
-        np.random.randn(noise_dim, n_chans))
-    noise = fold(noise, n_samples)
-
-    # mix signal and noise
-    data = noise / rms(noise.flatten()) + SNR * s / rms(s.flatten())
-
-    if show:
-        f, ax = plt.subplots(3)
-        ax[0].plot(source[:, 0], label='source')
-        ax[1].plot(noise[:, 1, 0], label='noise')
-        ax[2].plot(data[:, 1, 0], label='mixture')
-        ax[0].legend()
-        ax[1].legend()
-        ax[2].legend()
-        plt.show()
-
-    return data, source
-
 
 @pytest.mark.parametrize('n_bad_chans', [0, -1])
 def test_dss0(n_bad_chans):
@@ -78,7 +24,7 @@ def test_dss0(n_bad_chans):
     to zero.
     """
     n_samples = 300
-    data, source = create_data(n_samples=n_samples, n_bad_chans=n_bad_chans)
+    data, source = create_line_data(n_samples=n_samples, n_bad_chans=n_bad_chans)
 
     # apply DSS to clean them
     c0, _ = tscov(data)
@@ -96,7 +42,7 @@ def test_dss0(n_bad_chans):
 def test_dss1(show=False):
     """Test DSS1 (evoked)."""
     n_samples = 300
-    data, source = create_data(n_samples=n_samples)
+    data, source = create_line_data(n_samples=n_samples)
 
     todss, _, pwr0, pwr1 = dss.dss1(data, weights=None, )
     z = fold(np.dot(unfold(data), todss), epoch_size=n_samples)
@@ -135,11 +81,8 @@ def test_dss_line(nkeep):
     fline = 20
     nsamples = 10000
     nchans = 10
-    x = np.random.randn(nsamples, nchans)
-    artifact = np.sin(np.arange(nsamples) / sr * 2 * np.pi * fline)[:, None]
-    artifact[artifact < 0] = 0
-    artifact = artifact ** 3
-    s = x + 10 * artifact
+    s = create_line_data(n_samples=3 * nsamples, n_chans=nchans,
+                         n_trials=1, fline=fline / sr, SNR=2)[0][..., 0]
 
     def _plot(x):
         f, ax = plt.subplots(1, 2, sharey=True)
@@ -174,8 +117,57 @@ def _plot(x):
     out, _ = dss.dss_line(s, fline, sr, nremove=1)
 
 
+def test_dss_line_iter():
+    """Test iterative line noise removal."""
+
+    # data = np.load("data/dss_line_iter_test_data.npy")
+    # # time x channel x trial sf=200 fline=50
+
+    sr = 200
+    fline = 25
+    n_samples = 9000
+    n_chans = 10
+
+    # 2D case, n_outputs == 1
+    x, _ = create_line_data(n_samples, n_chans=n_chans, n_trials=1,
+                            noise_dim=10, SNR=2, fline=fline / sr)
+    x = x[..., 0]
+
+    # RuntimeError when the maximum number of iterations is reached
+    with pytest.raises(RuntimeError):
+        out, _ = dss.dss_line_iter(x, fline + 1, sr,
+                                   show=False, n_iter_max=2)
+
+    with TemporaryDirectory() as tmpdir:
+        out, _ = dss.dss_line_iter(x, fline + .5, sr,
+                                   prefix=os.path.join(tmpdir, 'dss_iter_'),
+                                   show=True)
+
+    def _plot(before, after):
+        f, ax = plt.subplots(1, 2, sharey=True)
+        f, Pxx = signal.welch(before[:, -1], sr, nperseg=1024, axis=0,
+                              return_onesided=True)
+        ax[0].semilogy(f, Pxx)
+        f, Pxx = signal.welch(after[:, -1], sr, nperseg=1024, axis=0,
+                              return_onesided=True)
+        ax[1].semilogy(f, Pxx)
+        ax[0].set_xlabel('frequency [Hz]')
+        ax[1].set_xlabel('frequency [Hz]')
+        ax[0].set_ylabel('PSD [V**2/Hz]')
+        ax[0].set_title('before')
+        ax[1].set_title('after')
+        plt.show()
+
+    _plot(x, out)
+
+    # Test n_trials > 1 (TODO)
+    x, _ = create_line_data(n_samples, n_chans=n_chans, n_trials=2,
+                            noise_dim=10, SNR=2, fline=fline / sr)
+    out, _ = dss.dss_line_iter(x, fline, sr, show=False)
+
+
 if __name__ == '__main__':
-    pytest.main([__file__])
+    # pytest.main([__file__])
     # create_data(SNR=5, show=True)
     # test_dss1(True)
     # test_dss_line(None)
+    test_dss_line_iter()