Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions meegkit/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ class ASR():

Parameters
----------
X : array, shape=()
*zero-mean* (e.g., high-pass filtered) and reasonably clean EEG of not
much less than 30 seconds (this method is typically used with 1 minute
or more).
sfreq : float
Sampling rate of the data, in Hz.

Expand Down Expand Up @@ -162,7 +158,9 @@ def fit(self, X, y=None, **kwargs):
----------
X : array, shape=(n_channels, n_samples)
The calibration data should have been high-pass filtered (for
example at 0.5Hz or 1Hz using a Butterworth IIR filter).
example at 0.5Hz or 1Hz using a Butterworth IIR filter), and be
reasonably clean not less than 30 seconds (this method is typically
used with 1 minute or more).

"""
if X.ndim == 3:
Expand Down Expand Up @@ -202,8 +200,6 @@ def transform(self, X, y=None, **kwargs):
----------
X : array, shape=([n_trials, ]n_channels, n_samples)
Raw data.
X_filt : array, shape=([n_trials, ]n_channels, n_samples)
Yulewalk-filtered data (optional). If this i

Returns
-------
Expand All @@ -218,7 +214,7 @@ def transform(self, X, y=None, **kwargs):
else:
return X

# Yulewalk-filtered data (optional).
# Yulewalk-filtered data
X_filt, self.zi_ = yulewalk_filter(
X, sfreq=self.sfreq, ab=self.ab_, zi=self.zi_)

Expand Down
3 changes: 2 additions & 1 deletion meegkit/utils/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,8 @@ def yulewalk_filter(X, sfreq, zi=None, ab=None, axis=-1):
# apply the signal shaping filter and initialize the IIR filter state
if zi is None:
zi = signal.lfilter_zi(B, A)
out, zf = signal.lfilter(B, A, X, zi=zi[:, None] * X[:, 0], axis=axis)
zi = np.transpose(X[:, 0] * zi[:, None])
out, zf = signal.lfilter(B, A, X, zi=zi, axis=axis)
else:
out, zf = signal.lfilter(B, A, X, zi=zi, axis=axis)

Expand Down
26 changes: 26 additions & 0 deletions tests/test_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,31 @@ def test_yulewalk(sfreq, show=False):
plt.show()


@pytest.mark.parametrize(argnames='n_chans', argvalues=(4, 8, 12))
def test_yulewalk_filter(n_chans, show=False):
"""Test yulewalk filter."""
rawp = raw.copy()
n_chan_orig = rawp.shape[0]
rawp = np.random.randn(n_chans, n_chan_orig) @ rawp
raw_filt, iirstate = yulewalk_filter(rawp, sfreq)

if show:
f, ax = plt.subplots(n_chans, sharex=True, figsize=(8, 5))
for i in range(n_chans):
ax[i].plot(rawp[i], lw=.5, label='before')
ax[i].plot(raw_filt[i], label='after', lw=.5)
ax[i].set_ylim([-50, 50])
if i < n_chans - 1:
ax[i].set_yticks([])
ax[i].set_xlabel('Time (s)')
ax[i].set_ylabel(f'ch{i}')
ax[0].legend(fontsize='small', bbox_to_anchor=(1.04, 1),
borderaxespad=0)
plt.subplots_adjust(hspace=0, right=0.75)
plt.suptitle('Before/after filter')
plt.show()


def test_asr_functions(show=False, method='riemann'):
"""Test ASR functions (offline use).

Expand Down Expand Up @@ -178,3 +203,4 @@ def test_asr_class(show=False):
# test_yulewalk(250, True)
# test_asr_functions(True)
# test_asr_class(True)
# test_yulewalk_filter(16, True)