Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/modules/meegkit.dss.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
dss0
dss1
dss_line
dss_line_iter



Expand Down
92 changes: 92 additions & 0 deletions examples/example_dss_line.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
Remove line noise with ZapLine
==============================

Find a spatial filter to get rid of line noise [1]_.

Uses meegkit.dss_line().

References
----------
.. [1] de Cheveigné, A. (2019). ZapLine: A simple and effective method to
remove power line artifacts [Preprint]. https://doi.org/10.1101/782029

"""
# Authors: Maciej Szul <maciej.szul@isc.cnrs.fr>
# Nicolas Barascud <nicolas.barascud@gmail.com>
import os

import matplotlib.pyplot as plt
import numpy as np
from meegkit import dss
from meegkit.utils import create_line_data, unfold
from scipy import signal

###############################################################################
# Line noise removal
# =============================================================================

###############################################################################
# Remove line noise with dss_line()
# -----------------------------------------------------------------------------
# We first generate some noisy data to work with
sfreq = 250
fline = 50
nsamples = 10000
nchans = 10
data = create_line_data(n_samples=3 * nsamples, n_chans=nchans,
n_trials=1, fline=fline / sfreq, SNR=2)[0]
data = data[..., 0] # only take first trial

# Apply dss_line (ZapLine)
out, _ = dss.dss_line(data, fline, sfreq, nkeep=1)

###############################################################################
# Plot before/after
f, ax = plt.subplots(1, 2, sharey=True)
f, Pxx = signal.welch(data, sfreq, nperseg=500, axis=0, return_onesided=True)
ax[0].semilogy(f, Pxx)
f, Pxx = signal.welch(out, sfreq, nperseg=500, axis=0, return_onesided=True)
ax[1].semilogy(f, Pxx)
ax[0].set_xlabel('frequency [Hz]')
ax[1].set_xlabel('frequency [Hz]')
ax[0].set_ylabel('PSD [V**2/Hz]')
ax[0].set_title('before')
ax[1].set_title('after')
plt.show()


###############################################################################
# Remove line noise with dss_line_iter()
# -----------------------------------------------------------------------------
# We first load some noisy data to work with
data = np.load(os.path.join('..', 'tests', 'data', 'dss_line_data.npy'))
fline = 50
sfreq = 200
print(data.shape) # n_samples, n_chans, n_trials

# Apply dss_line(), removing only one component
out1, _ = dss.dss_line(data, fline, sfreq, nremove=1, nfft=400)

###############################################################################
# Now try dss_line_iter(). This applies dss_line() repeatedly until the
# artifact is gone
out2, iterations = dss.dss_line_iter(data, fline, sfreq, nfft=400)
print(f'Removed {iterations} components')

###############################################################################
# Plot results with dss_line() vs. dss_line_iter()
f, ax = plt.subplots(1, 2, sharey=True)
f, Pxx = signal.welch(unfold(out1), sfreq, nperseg=200, axis=0,
return_onesided=True)
ax[0].semilogy(f, Pxx, lw=.5)
f, Pxx = signal.welch(unfold(out2), sfreq, nperseg=200, axis=0,
return_onesided=True)
ax[1].semilogy(f, Pxx, lw=.5)
ax[0].set_xlabel('frequency [Hz]')
ax[1].set_xlabel('frequency [Hz]')
ax[0].set_ylabel('PSD [V**2/Hz]')
ax[0].set_title('dss_line')
ax[1].set_title('dss_line_iter')
plt.tight_layout()
plt.show()
2 changes: 1 addition & 1 deletion meegkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""M/EEG denoising utilities in python."""
__version__ = '0.1.1'
__version__ = '0.1.2'

from . import asr, cca, detrend, dss, sns, star, ress, trca, tspca, utils

Expand Down
131 changes: 131 additions & 0 deletions meegkit/dss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""Denoising source separation."""
# Authors: Nicolas Barascud <nicolas.barascud@gmail.com>
# Maciej Szul <maciej.szul@isc.cnrs.fr>

import numpy as np
from scipy import linalg
from scipy.signal import welch

from .tspca import tsr
from .utils import (demean, gaussfilt, mean_over_trials, pca, smooth,
Expand Down Expand Up @@ -230,3 +234,130 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, show=False):
p = wpwr(X - y)[0] / wpwr(X)[0]
print('Power of components removed by DSS: {:.2f}'.format(p))
return y, artifact


def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5,
nfft=512, show=False, prefix="dss_iter", n_iter_max=100):
"""Remove power line artifact iteratively.

This method applies dss_line() until the artifact has been smoothed out
from the spectrum.

Parameters
----------
data : data, shape=(n_samples, n_chans, n_trials)
Input data.
fline : float
Line frequency.
sfreq : float
Sampling frequency.
win_sz : float
Half of the width of the window around the target frequency used to fit
the polynomial (default=10).
spot_sz : float
Half of the width of the window around the target frequency used to
remove the peak and interpolate (default=2.5).
nfft : int
FFT size for the internal PSD calculation (default=512).
show: bool
Produce a visual output of each iteration (default=False).
prefix : str
Path and first part of the visualisation output file
"{prefix}_{iteration number}.png" (default="dss_iter").
n_iter_max : int
Maximum number of iterations (default=100).

Returns
-------
data : array, shape=(n_samples, n_chans, n_trials)
Denoised data.
iterations : int
Number of iterations.
"""

def nan_basic_interp(array):
"""Nan interpolation."""
nans, ix = np.isnan(array), lambda x: x.nonzero()[0]
array[nans] = np.interp(ix(nans), ix(~nans), array[~nans])
return array

freq_rn = [fline - win_sz, fline + win_sz]
freq_sp = [fline - spot_sz, fline + spot_sz]
freq, psd = welch(data, fs=sfreq, nfft=nfft, axis=0)

freq_rn_ix = np.logical_and(freq >= freq_rn[0], freq <= freq_rn[1])
freq_used = freq[freq_rn_ix]
freq_sp_ix = np.logical_and(freq_used >= freq_sp[0],
freq_used <= freq_sp[1])

if psd.ndim == 3:
mean_psd = np.mean(psd, axis=(1, 2))[freq_rn_ix]
elif psd.ndim == 2:
mean_psd = np.mean(psd, axis=(1))[freq_rn_ix]

mean_psd_wospot = mean_psd.copy()
mean_psd_wospot[freq_sp_ix] = np.nan
mean_psd_tf = nan_basic_interp(mean_psd_wospot)
pf = np.polyfit(freq_used, mean_psd_tf, 3)
p = np.poly1d(pf)
clean_fit_line = p(freq_used)

aggr_resid = []
iterations = 0
while iterations < n_iter_max:
data, _ = dss_line(data, fline, sfreq, nfft=nfft, nremove=1)
freq, psd = welch(data, fs=sfreq, nfft=nfft, axis=0)
if psd.ndim == 3:
mean_psd = np.mean(psd, axis=(1, 2))[freq_rn_ix]
elif psd.ndim == 2:
mean_psd = np.mean(psd, axis=(1))[freq_rn_ix]

residuals = mean_psd - clean_fit_line
mean_score = np.mean(residuals[freq_sp_ix])
aggr_resid.append(mean_score)

print("Iteration {} score: {}".format(iterations, mean_score))

if show:
import matplotlib.pyplot as plt
f, ax = plt.subplots(2, 2, figsize=(12, 6), facecolor="white")

if psd.ndim == 3:
mean_sens = np.mean(psd, axis=2)
elif psd.ndim == 2:
mean_sens = psd

y = mean_sens[freq_rn_ix]
ax.flat[0].plot(freq_used, y)
ax.flat[0].set_title("Mean PSD across trials")

ax.flat[1].plot(freq_used, mean_psd_tf, c="gray")
ax.flat[1].plot(freq_used, mean_psd, c="blue")
ax.flat[1].plot(freq_used, clean_fit_line, c="red")
ax.flat[1].set_title("Mean PSD across trials and sensors")

tf_ix = np.where(freq_used <= fline)[0][-1]
ax.flat[2].plot(residuals, freq_used)
color = "green"
if mean_score <= 0:
color = "red"
ax.flat[2].scatter(residuals[tf_ix], freq_used[tf_ix], c=color)
ax.flat[2].set_title("Residuals")

ax.flat[3].plot(np.arange(iterations + 1), aggr_resid, marker='o')
ax.flat[3].set_title("Iterations")

f.set_tight_layout(True)
plt.savefig(f"{prefix}_{iterations:03}.png")
plt.close("all")

if mean_score <= 0:
break

iterations += 1

if iterations == n_iter_max:
raise RuntimeError('Could not converge. Consider increasing the '
'maximum number of iterations')

return data, iterations
1 change: 1 addition & 0 deletions meegkit/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
spectral_envelope, teager_kaiser)
from .stats import (bootstrap_ci, bootstrap_snr, cronbach, rms, robust_mean,
rolling_corr, snr_spectrum)
from .testing import create_line_data
74 changes: 74 additions & 0 deletions meegkit/utils/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Synthetic test data."""
import numpy as np
from meegkit.utils import fold, rms, unfold

import matplotlib.pyplot as plt


def create_line_data(n_samples=100 * 3, n_chans=30, n_trials=100, noise_dim=20,
n_bad_chans=1, SNR=.1, fline=1, t0=None, show=False):
"""Create synthetic data.

Parameters
----------
n_samples : int
Number of samples (default=100*3).
n_chans : int
Number of channels (default=30).
n_trials : int
Number of trials (default=100).
noise_dim : int
Dimensionality of noise (default=20).
n_bad_chans : int
Number of bad channels (default=1).
t0 : int
Onset sample of artifact.
fline : float
Normalized frequency of artifact (freq/samplerate), (default=1).

Returns
-------
data : ndarray, shape=(n_samples, n_chans, n_trials)
source : ndarray, shape=(n_samples,)
"""
rng = np.random.RandomState(2022)

if t0 is None:
t0 = n_samples // 3
t1 = n_samples - 2 * t0 # artifact duration

# create source signal
source = np.hstack((
np.zeros(t0),
np.sin(2 * np.pi * fline * np.arange(t1)),
np.zeros(t0))) # noise -> artifact -> noise
source = source[:, None]

# mix source in channels
s = source * rng.randn(1, n_chans)
s = s[:, :, np.newaxis]
s = np.tile(s, (1, 1, n_trials)) # create trials

# set first `n_bad_chans` to zero
s[:, :n_bad_chans] = 0.

# noise
noise = np.dot(
unfold(rng.randn(n_samples, noise_dim, n_trials)),
rng.randn(noise_dim, n_chans))
noise = fold(noise, n_samples)

# mix signal and noise
data = noise / rms(noise.flatten()) + SNR * s / rms(s.flatten())

if show:
f, ax = plt.subplots(3)
ax[0].plot(source.mean(-1), label='source')
ax[1].plot(noise[:, 1].mean(-1), label='noise (avg over trials)')
ax[2].plot(data[:, 1].mean(-1), label='mixture (avg over trials)')
ax[0].legend()
ax[1].legend()
ax[2].legend()
plt.show()

return data, source
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@
author='N Barascud',
author_email='nicolas.barascud@gmail.com',
license='UNLICENSED',
version='0.1.1',
version='0.1.2',
packages=find_packages(exclude=['doc', 'tests']),
zip_safe=False)
36 changes: 30 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,33 @@
import pytest
import numpy as np
import random as rand

import matplotlib.pyplot as plt

@pytest.fixture
def random():
rand.seed(9)
np.random.seed(9)

def pytest_addoption(parser):
"""Add command line option to pytest."""
parser.addoption(
"--runslow",
action="store_true",
default=False,
help="run slow tests"
)
parser.addoption(
"--noplots",
action="store_true",
default=False,
help="halt on plots"
)


def pytest_collection_modifyitems(config, items):
"""Do not skip slow test if option provided."""
if config.getoption("--noplots"):
plt.switch_backend('agg')

if config.getoption("--runslow"):
# --runslow given in cli: do not skip slow tests
return
skip_slow = pytest.mark.skip(reason="need --runslow option to run")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)
Binary file added tests/data/dss_line_data.npy
Binary file not shown.
Loading