4 changes: 2 additions & 2 deletions .github/workflows/pythonpackage.yml
@@ -2,7 +2,7 @@ name: unit-tests

on:
push:
branches:
branches:
- master
pull_request:
branches:
@@ -35,7 +35,7 @@ jobs:
make pep
- name: Test with pytest
run: |
pytest --cov=meegkit --cov-report=xml tests/
pytest --cov=./ --cov-report=xml tests/
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
with:
105 changes: 58 additions & 47 deletions meegkit/asr.py
@@ -2,21 +2,18 @@
import logging

import numpy as np

from scipy import linalg, signal
from statsmodels.robust.scale import mad

from .utils import nonlinear_eigenspace, block_covariance
from .utils.asr import (block_geometric_median, fit_eeg_distribution, yulewalk,
from .utils import block_covariance, nonlinear_eigenspace
from .utils.asr import (geometric_median, fit_eeg_distribution, yulewalk,
yulewalk_filter)

try:
import pyriemann
except ImportError:
pyriemann = None

__all__ = ['ASR', 'clean_windows', 'asr_calibrate', 'asr_process']


class ASR():
"""Artifact Subspace Reconstruction.
@@ -64,6 +61,9 @@ class ASR():
ASR [2]_.
memory : float
Memory size (s), regulates the number of covariance matrices to store.
estimator : str in {'scm', 'lwf', 'oas', 'mcd'}
Covariance estimator (default: 'scm' which computes the sample
covariance). Use 'lwf' if you need regularization (requires pyriemann).

Attributes
----------
@@ -99,10 +99,10 @@ class ASR():

"""

def __init__(self, sfreq=250, cutoff=5, blocksize=10, win_len=0.5,
def __init__(self, sfreq=250, cutoff=5, blocksize=100, win_len=0.5,
win_overlap=0.66, max_dropout_fraction=0.1,
min_clean_fraction=0.25, name='asrfilter', method='euclid',
**kwargs):
estimator='scm', **kwargs):

if pyriemann is None and method == 'riemann':
logging.warning('Need pyriemann to use riemannian ASR flavor.')
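The new estimator argument is easiest to see from the caller's side. A minimal usage sketch, with random data purely for illustration ('lwf' additionally assumes pyriemann is installed):

    import numpy as np
    from meegkit.asr import ASR

    rng = np.random.default_rng(0)
    X = rng.standard_normal((8, 250 * 60))    # 8 channels, 60 s at 250 Hz (synthetic)

    asr = ASR(sfreq=250, method='euclid', estimator='lwf')  # 'lwf' = Ledoit-Wolf shrinkage
    asr.fit(X[:, :250 * 30])                  # calibrate on the first 30 s
    clean = asr.transform(X[:, 250 * 30:])    # clean the remainder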
@@ -116,12 +116,19 @@ def __init__(self, sfreq=250, cutoff=5, blocksize=10, win_len=0.5,
self.min_clean_fraction = min_clean_fraction
self.max_bad_chans = 0.3
self.method = method
self.memory = 1 * sfreq # smoothing window for covariances
self.memory = int(2 * sfreq) # smoothing window for covariances
self.sample_weight = np.geomspace(0.05, 1, num=self.memory + 1)
self.sfreq = sfreq
self.estimator = estimator

self.reset()

def reset(self):
"""Reset filter."""
# Initialise yulewalk-filter coefficients with sensible defaults
F = np.array([0, 2, 3, 13, 16, 40, np.minimum(
80.0, (sfreq / 2.0) - 1.0), sfreq / 2.0]) * 2.0 / sfreq
F = np.array([0, 2, 3, 13, 16, 40,
np.minimum(80.0, (self.sfreq / 2.0) - 1.0),
self.sfreq / 2.0]) * 2.0 / self.sfreq
M = np.array([3, 0.75, 0.33, 0.33, 1, 1, 3, 3])
B, A = yulewalk(8, F, M)
self.ab_ = (A, B)
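For context, reset() uses yulewalk() to design an order-8 IIR spectral weighting filter: F lists band edges normalized to Nyquist and M the desired amplitude at each edge. A standalone sketch of the same design, assuming sfreq = 250 Hz, with a freqz check added here only to inspect the realized response:

    import numpy as np
    from scipy import signal
    from meegkit.utils.asr import yulewalk

    sfreq = 250
    F = np.array([0, 2, 3, 13, 16, 40,
                  min(80., sfreq / 2. - 1.), sfreq / 2.]) * 2. / sfreq
    M = np.array([3, 0.75, 0.33, 0.33, 1, 1, 3, 3])  # target amplitude at each edge
    B, A = yulewalk(8, F, M)                         # order-8 least-squares IIR fit

    w, h = signal.freqz(B, A, fs=sfreq)
    # the target gain in the 3-13 Hz band is 0.33; the order-8 fit only approximates it
    print(np.abs(h[np.argmin(np.abs(w - 10))]))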
@@ -131,10 +138,6 @@ def __init__(self, sfreq=250, cutoff=5, blocksize=10, win_len=0.5,
self._counter = []
self._fitted = False

def _reset(self):
"""Reset filter."""
return

def fit(self, X, y=None, **kwargs):
"""Calibration for the Artifact Subspace Reconstruction method.

@@ -186,7 +189,8 @@ def fit(self, X, y=None, **kwargs):
win_overlap=self.win_overlap,
max_dropout_fraction=self.max_dropout_fraction,
min_clean_fraction=self.min_clean_fraction,
method=self.method)
method=self.method,
estimator=self.estimator)

self.state_ = dict(M=M, T=T, R=None)
self._fitted = True
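For reference, fit() boils down to two stages: clean_windows() discards the most contaminated calibration windows, then asr_calibrate() estimates the mixing matrix M and threshold matrix T from what remains. A condensed sketch of that flow, with X standing in for calibration data of shape (n_chans, n_samples):

    from meegkit.asr import asr_calibrate, clean_windows

    X_clean, sample_mask = clean_windows(X, sfreq=250, max_bad_chans=0.3)
    M, T = asr_calibrate(X_clean, sfreq=250, cutoff=5,
                         method='euclid', estimator='scm')
    state = dict(M=M, T=T, R=None)  # R is only computed later, while processing data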
@@ -212,18 +216,24 @@ def transform(self, X, y=None, **kwargs):
out = self.transform(X[0])
return out[None, ...]
else:
return X

# Yulewalk-filtered data
X_filt, self.zi_ = yulewalk_filter(
X, sfreq=self.sfreq, ab=self.ab_, zi=self.zi_)
outs = [self.transform(x) for x in X]
return np.stack(outs, axis=0)
else:
# Yulewalk-filtered data
X_filt, self.zi_ = yulewalk_filter(
X, sfreq=self.sfreq, ab=self.ab_, zi=self.zi_)

if not self._fitted:
logging.warning('ASR is not fitted ! Returning unfiltered data.')
return X

cov = 1 / X.shape[-1] * X_filt @ X_filt.T
self._counter.append(X.shape[-1])
if self.estimator == 'scm':
cov = 1 / X.shape[-1] * X_filt @ X_filt.T
else:
cov = pyriemann.estimation.covariances(X_filt[None, ...],
self.estimator)[0]

self._counter.append(X_filt.shape[-1])
self.cov_.append(cov)

# Regulate the number of covariance matrices that are stored
@@ -232,19 +242,20 @@ def transform(self, X, y=None, **kwargs):
self.cov_.pop(0)
self._counter.pop(0)
else:
self._counter[0] = self.memory
self._counter = [self.memory, ]
break

# Exponential covariance weight – the most recent covariance has a
# Exponential covariance weights – the most recent covariance has a
# weight of 1, while the oldest one in memory has a weight of 5%
sample_weight = np.geomspace(0.05, 1, num=self.memory + 1)
sample_weight = sample_weight[self._counter]
weights = [1, ]
for c in np.cumsum(self._counter[1:]):
weights = [self.sample_weight[-c]] + weights

# Clean data, using covariances weighted by sample_weight
out, self.state_ = asr_process(X, X_filt, self.state_,
cov=np.stack(self.cov_),
method=self.method,
sample_weight=sample_weight)
sample_weight=weights)

return out
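The weights used above decay geometrically with the age (in samples) of each stored covariance block, from 1 for the newest block down to 0.05 at the memory horizon. A small numerical sketch of the same bookkeeping, with invented block sizes:

    import numpy as np

    memory = 500                               # e.g. 2 s at 250 Hz
    sample_weight = np.geomspace(0.05, 1, num=memory + 1)

    counter = [200, 150, 150]                  # samples per stored block, oldest first
    weights = [1, ]                            # newest block gets full weight
    for c in np.cumsum(counter[1:]):
        weights = [sample_weight[-c]] + weights
    # weights now runs oldest -> newest, aligned with the stacked covariances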

@@ -409,9 +420,9 @@ def clean_windows(X, sfreq, max_bad_chans=0.2, zthresholds=[-3.5, 5],
return clean, sample_mask


def asr_calibrate(X, sfreq, cutoff=5, blocksize=10, win_len=0.5,
def asr_calibrate(X, sfreq, cutoff=5, blocksize=100, win_len=0.5,
win_overlap=0.66, max_dropout_fraction=0.1,
min_clean_fraction=0.25, method='euclid'):
min_clean_fraction=0.25, method='euclid', estimator='scm'):
"""Calibration function for the Artifact Subspace Reconstruction method.

The input to this data is a multi-channel time series of calibration data.
@@ -455,8 +466,8 @@ def asr_calibrate(X, sfreq, cutoff=5, blocksize=10, win_len=0.5,
blocksize : int
Block size for calculating the robust data covariance and thresholds,
in samples; allows to reduce the memory and time requirements of the
robust estimators by this factor (down to Channels x Channels x Samples
x 16 / Blocksize bytes) (default=10).
robust estimators by this factor (down to n_chans x n_chans x n_samples
x 16 / blocksize bytes) (default=100).
win_len : float
Window length that is used to check the data for artifact content. This
is ideally as long as the expected time scale of the artifacts but
@@ -490,23 +501,18 @@ def asr_calibrate(X, sfreq, cutoff=5, blocksize=10, win_len=0.5,
# window length for calculating thresholds
N = int(np.round(win_len * sfreq))

U = block_covariance(X, window=blocksize, overlap=win_overlap,
estimator=estimator)
if method == 'euclid':
U = np.zeros((blocksize, nc, nc))
for k in range(blocksize):
rangevect = np.minimum(ns - 1, np.arange(k, ns + k, blocksize))
x = X[:, rangevect]
U[k, ...] = x @ x.T
Uavg = block_geometric_median(U.reshape((-1, nc * nc)) / blocksize, 2)
Uavg = geometric_median(U.reshape((-1, nc * nc)))
Uavg = Uavg.reshape((nc, nc))
elif method == 'riemann':
blocksize = int(ns // blocksize)
U = block_covariance(X, window=blocksize, overlap=win_overlap)
else: # method == 'riemann'
Uavg = pyriemann.utils.mean.mean_covariance(U, metric='riemann')

# get the mixing matrix M
M = linalg.sqrtm(np.real(Uavg))
D, Vtmp = linalg.eig(M)
# D, Vtmp = nonlinear_eigenspace(M, nc)
D, Vtmp = linalg.eigh(M)
# D, Vtmp = nonlinear_eigenspace(M, nc) TODO
V = Vtmp[:, np.argsort(D)]
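Put together, the calibration now always computes per-block covariances with block_covariance() and only differs in how it averages them: a geometric median of the flattened matrices for the Euclidean flavor, or a Riemannian mean when pyriemann is available. An illustrative condensation of the Euclidean path, with X of shape (n_chans, n_samples) and the defaults above:

    import numpy as np
    from scipy import linalg
    from meegkit.utils import block_covariance
    from meegkit.utils.asr import geometric_median

    nc = X.shape[0]
    U = block_covariance(X, window=100, overlap=0.66, estimator='scm')   # (n_blocks, nc, nc)
    Uavg = geometric_median(U.reshape((-1, nc * nc))).reshape((nc, nc))  # robust average
    M = linalg.sqrtm(np.real(Uavg))   # mixing matrix
    D, Vtmp = linalg.eigh(M)          # symmetric input, real eigenvalues
    V = Vtmp[:, np.argsort(D)]        # components sorted by eigenvalue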

# get the threshold matrix T
@@ -573,19 +579,24 @@ def asr_process(X, X_filt, state, cov=None, detrend=False, method='riemann',
if cov is None:
if detrend:
X_filt = signal.detrend(X_filt, axis=1, type='constant')
cov = np.cov(X_filt, bias=True)
else:
if cov.ndim == 3:
cov = block_covariance(X_filt, window=nc ** 2)

cov = cov.squeeze()
if cov.ndim == 3:
if method == 'riemann':
cov = pyriemann.utils.mean.mean_covariance(
cov, metric='riemann', sample_weight=sample_weight)
else:
cov = geometric_median(cov.reshape((-1, nc * nc)))
cov = cov.reshape((nc, nc))

maxdims = int(np.fix(0.66 * nc)) # constant TODO make param

# do a PCA to find potential artifacts
if method == 'riemann':
D, Vtmp = nonlinear_eigenspace(cov, nc) # TODO
else:
D, Vtmp = np.linalg.eig(cov)
D, Vtmp = linalg.eigh(cov)

V = np.real(Vtmp[:, np.argsort(D)])
D = np.real(D[np.argsort(D)])
Expand All @@ -602,7 +613,7 @@ def asr_process(X, X_filt, state, cov=None, detrend=False, method='riemann',
else:
VT = np.dot(V.T, M)
demux = VT * keep[:, None]
R = np.dot(np.dot(M, np.linalg.pinv(demux)), V.T)
R = np.dot(np.dot(M, linalg.pinv(demux)), V.T)
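# In effect, R re-expresses the data in the eigenbasis of the current covariance
# (V.T), drops the components flagged as artifactual (keep == False), and
# reconstructs the rejected subspace through the calibration mixing matrix M:
#     R = M @ pinv(keep[:, None] * (V.T @ M)) @ V.T
# Clean data is then obtained (roughly) as R @ X, blended with the previous R
# for continuity across processing blocks (raised-cosine weighting below).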

if state['R'] is not None:
# apply the reconstruction to intermediate samples (using raised-cosine