diff --git a/meegkit/asr.py b/meegkit/asr.py index 8eec218b..f2a6c0d9 100755 --- a/meegkit/asr.py +++ b/meegkit/asr.py @@ -27,10 +27,6 @@ class ASR(): Parameters ---------- - X : array, shape=() - *zero-mean* (e.g., high-pass filtered) and reasonably clean EEG of not - much less than 30 seconds (this method is typically used with 1 minute - or more). sfreq : float Sampling rate of the data, in Hz. @@ -162,7 +158,9 @@ def fit(self, X, y=None, **kwargs): ---------- X : array, shape=(n_channels, n_samples) The calibration data should have been high-pass filtered (for - example at 0.5Hz or 1Hz using a Butterworth IIR filter). + example at 0.5Hz or 1Hz using a Butterworth IIR filter), and be + reasonably clean not less than 30 seconds (this method is typically + used with 1 minute or more). """ if X.ndim == 3: @@ -202,8 +200,6 @@ def transform(self, X, y=None, **kwargs): ---------- X : array, shape=([n_trials, ]n_channels, n_samples) Raw data. - X_filt : array, shape=([n_trials, ]n_channels, n_samples) - Yulewalk-filtered data (optional). If this i Returns ------- @@ -218,7 +214,7 @@ def transform(self, X, y=None, **kwargs): else: return X - # Yulewalk-filtered data (optional). + # Yulewalk-filtered data X_filt, self.zi_ = yulewalk_filter( X, sfreq=self.sfreq, ab=self.ab_, zi=self.zi_) diff --git a/meegkit/utils/asr.py b/meegkit/utils/asr.py index 8fe9e584..f93498af 100755 --- a/meegkit/utils/asr.py +++ b/meegkit/utils/asr.py @@ -316,7 +316,8 @@ def yulewalk_filter(X, sfreq, zi=None, ab=None, axis=-1): # apply the signal shaping filter and initialize the IIR filter state if zi is None: zi = signal.lfilter_zi(B, A) - out, zf = signal.lfilter(B, A, X, zi=zi[:, None] * X[:, 0], axis=axis) + zi = np.transpose(X[:, 0] * zi[:, None]) + out, zf = signal.lfilter(B, A, X, zi=zi, axis=axis) else: out, zf = signal.lfilter(B, A, X, zi=zi, axis=axis) diff --git a/tests/test_asr.py b/tests/test_asr.py index b6995c32..8a0ac114 100644 --- a/tests/test_asr.py +++ b/tests/test_asr.py @@ -90,6 +90,31 @@ def test_yulewalk(sfreq, show=False): plt.show() +@pytest.mark.parametrize(argnames='n_chans', argvalues=(4, 8, 12)) +def test_yulewalk_filter(n_chans, show=False): + """Test yulewalk filter.""" + rawp = raw.copy() + n_chan_orig = rawp.shape[0] + rawp = np.random.randn(n_chans, n_chan_orig) @ rawp + raw_filt, iirstate = yulewalk_filter(rawp, sfreq) + + if show: + f, ax = plt.subplots(n_chans, sharex=True, figsize=(8, 5)) + for i in range(n_chans): + ax[i].plot(rawp[i], lw=.5, label='before') + ax[i].plot(raw_filt[i], label='after', lw=.5) + ax[i].set_ylim([-50, 50]) + if i < n_chans - 1: + ax[i].set_yticks([]) + ax[i].set_xlabel('Time (s)') + ax[i].set_ylabel(f'ch{i}') + ax[0].legend(fontsize='small', bbox_to_anchor=(1.04, 1), + borderaxespad=0) + plt.subplots_adjust(hspace=0, right=0.75) + plt.suptitle('Before/after filter') + plt.show() + + def test_asr_functions(show=False, method='riemann'): """Test ASR functions (offline use). @@ -178,3 +203,4 @@ def test_asr_class(show=False): # test_yulewalk(250, True) # test_asr_functions(True) # test_asr_class(True) + # test_yulewalk_filter(16, True)