diff --git a/README.md b/README.md index 82a095ad..27d6f260 100644 --- a/README.md +++ b/README.md @@ -91,3 +91,22 @@ If you use this, you should cite the following article: [1] Cohen, M. X., & Gulbinaite, R. (2017). Rhythmic entrainment source separation: Optimizing analyses of neural responses to rhythmic sensory stimulation. Neuroimage, 147, 43-56. ``` + +### 4. Task-Related Component Analysis (TRCA) + +This code is based on the [Matlab implementation from Masaki Nakanishi](https://github.com/mnakanishi/TRCA-SSVEP), and was adapted to python by [Giuseppe Ferraro](mailto:giuseppe.ferraro@isae-supaero.fr) + +If you use this, you should cite the following articles: + +```sql +[1] M. Nakanishi, Y. Wang, X. Chen, Y.-T. Wang, X. Gao, and T.-P. Jung, + "Enhancing detection of SSVEPs for a high-speed brain speller using + task-related component analysis", IEEE Trans. Biomed. Eng, 65(1): 104-112, + 2018. +[2] X. Chen, Y. Wang, S. Gao, T. -P. Jung and X. Gao, "Filter bank + canonical correlation analysis for implementing a high-speed SSVEP-based + brain-computer interface", J. Neural Eng., 12: 046008, 2015. +[3] X. Chen, Y. Wang, M. Nakanishi, X. Gao, T. -P. Jung, S. Gao, + "High-speed spelling with a noninvasive brain-computer interface", + Proc. Int. Natl. Acad. Sci. U. S. A, 112(44): E6058-6067, 2015. +``` diff --git a/doc/index.rst b/doc/index.rst index f733c353..e4a89444 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -26,6 +26,7 @@ Contents ~meegkit.ress ~meegkit.sns ~meegkit.star + ~meegkit.trca ~meegkit.tspca ~meegkit.utils diff --git a/doc/modules/meegkit.utils.rst b/doc/modules/meegkit.utils.rst index a1919394..9797de75 100644 --- a/doc/modules/meegkit.utils.rst +++ b/doc/modules/meegkit.utils.rst @@ -71,9 +71,19 @@ Signal ---- - Statistics ---------- .. automodule:: meegkit.utils.stats .. autosummary:: + + +| + +---- + +TRCA utilities +-------------- +.. automodule:: meegkit.utils.trca + + .. autosummary:: diff --git a/examples/example_asr.py b/examples/example_asr.py index 77c70752..d8634d6c 100644 --- a/examples/example_asr.py +++ b/examples/example_asr.py @@ -11,7 +11,6 @@ import matplotlib.pyplot as plt from meegkit.asr import ASR -from meegkit.utils.asr import yulewalk_filter from meegkit.utils.matrix import sliding_window # THIS_FOLDER = os.path.dirname(os.path.abspath(__file__)) diff --git a/examples/example_trca.ipynb b/examples/example_trca.ipynb new file mode 100644 index 00000000..78382686 --- /dev/null +++ b/examples/example_trca.ipynb @@ -0,0 +1,258 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "# Task-related component analysis (TRCA)-based SSVEP detection\n", + "\n", + "Sample code for the task-related component analysis (TRCA)-based steady\n", + "-state visual evoked potential (SSVEP) detection method [1]_. The filter\n", + "bank analysis [2, 3]_ can also be combined to the TRCA-based algorithm.\n", + "\n", + "Uses meegkit.trca.TRCA()\n", + "\n", + "References:\n", + "\n", + ".. [1] M. Nakanishi, Y. Wang, X. Chen, Y.-T. Wang, X. Gao, and T.-P. Jung,\n", + " \"Enhancing detection of SSVEPs for a high-speed brain speller using\n", + " task-related component analysis\", IEEE Trans. Biomed. Eng, 65(1): 104-112,\n", + " 2018.\n", + ".. [2] X. Chen, Y. Wang, S. Gao, T. -P. Jung and X. 
Gao, \"Filter bank\n", + " canonical correlation analysis for implementing a high-speed SSVEP-based\n", + " brain-computer interface\", J. Neural Eng., 12: 046008, 2015.\n", + ".. [3] X. Chen, Y. Wang, M. Nakanishi, X. Gao, T. -P. Jung, S. Gao,\n", + " \"High-speed spelling with a noninvasive brain-computer interface\",\n", + " Proc. Int. Natl. Acad. Sci. U. S. A, 112(44): E6058-6067, 2015.\n", + "\n", + "This code is based on the Matlab implementation from\n", + "https://github.com/mnakanishi/TRCA-SSVEP\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Author: Giuseppe Ferraro \n", + "import os\n", + "import time\n", + "\n", + "import numpy as np\n", + "import scipy.io\n", + "from meegkit.trca import TRCA\n", + "from meegkit.utils.trca import itr, normfit, round_half_up\n", + "\n", + "t = time.time()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Parameters\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Results of the ensemble TRCA-based method:\n\n" + ] + } + ], + "source": [ + "len_gaze_s = 0.5 # data length for target identification [s]\n", + "len_delay_s = 0.13 # visual latency being considered in the analysis [s]\n", + "n_bands = 5 # number of sub-bands in filter bank analysis\n", + "is_ensemble = True # True = ensemble TRCA method; False = TRCA method\n", + "alpha_ci = 0.05 # 100*(1-alpha_ci): confidence interval for accuracy\n", + "sfreq = 250 # sampling rate [Hz]\n", + "len_shift_s = 0.5 # duration for gaze shifting [s]\n", + "list_freqs = np.concatenate(\n", + " [[x + 8 for x in range(8)],\n", + " [x + 8.2 for x in range(8)],\n", + " [x + 8.4 for x in range(8)],\n", + " [x + 8.6 for x in range(8)],\n", + " [x + 8.8 for x in range(8)]]) # list of stimulus frequencies\n", + "n_targets = len(list_freqs) # The number of stimuli\n", + "\n", + "# Preparing useful variables (DONT'T need to modify)\n", + "len_gaze_smpl = round_half_up(len_gaze_s * sfreq) # data length [samples]\n", + "len_delay_smpl = round_half_up(len_delay_s * sfreq) # visual latency [samples]\n", + "len_sel_s = len_gaze_s + len_shift_s # selection time [s]\n", + "ci = 100 * (1 - alpha_ci) # confidence interval\n", + "\n", + "# Performing the TRCA-based SSVEP detection algorithm\n", + "print('Results of the ensemble TRCA-based method:\\n')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load data\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "path = os.path.join('..', 'tests', 'data', 'trcadata.mat')\n", + "mat = scipy.io.loadmat(path)\n", + "eeg = mat[\"eeg\"]\n", + "\n", + "n_trials = eeg.shape[0]\n", + "n_chans = eeg.shape[1]\n", + "n_samples = eeg.shape[2]\n", + "n_blocks = eeg.shape[3]\n", + "\n", + "# Convert dummy Matlab format to (sample, channels, trials) and construct\n", + "# vector of labels\n", + "eeg = np.reshape(eeg.transpose([2, 1, 3, 0]),\n", + " (n_samples, n_chans, n_trials * n_blocks))\n", + "labels = np.array([x for x in range(n_targets)] * n_blocks)\n", + "\n", + "crop_data = np.arange(len_delay_smpl, len_delay_smpl + len_gaze_smpl)\n", + "eeg = eeg[crop_data]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## TRCA classification\n", + "Estimate classification performance with 
a Leave-One-Block-Out\n", + "cross-validation approach.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Block 0: accuracy = 97.5, \tITR = 301.3\n", + "Block 1: accuracy = 100.0, \tITR = 319.3\n", + "Block 2: accuracy = 95.0, \tITR = 286.3\n", + "Block 3: accuracy = 95.0, \tITR = 286.3\n", + "Block 4: accuracy = 95.0, \tITR = 286.3\n", + "Block 5: accuracy = 100.0, \tITR = 319.3\n", + "\n", + "Mean accuracy = 97.1%\t(95% CI: 97.0-97.1%)\n", + "Mean ITR = 299.8\t(95% CI: 299.4-300.2%)\n", + "\n", + "Elapsed time: 13.8 seconds\n" + ] + } + ], + "source": [ + "# We use the filterbank specification described in [2]_.\n", + "filterbank = [[(6, 90), (4, 100)], # passband freqs, stopband freqs (Wp, Ws)\n", + " [(14, 90), (10, 100)],\n", + " [(22, 90), (16, 100)],\n", + " [(30, 90), (24, 100)],\n", + " [(38, 90), (32, 100)],\n", + " [(46, 90), (40, 100)],\n", + " [(54, 90), (48, 100)]]\n", + "trca = TRCA(sfreq, filterbank, is_ensemble)\n", + "\n", + "accs = np.zeros(n_blocks)\n", + "itrs = np.zeros(n_blocks)\n", + "for i in range(n_blocks):\n", + "\n", + " # Training stage\n", + " traindata = eeg.copy()\n", + "\n", + " # Select all folds except one for training\n", + " traindata = np.concatenate(\n", + " (traindata[..., :i * n_trials],\n", + " traindata[..., (i + 1) * n_trials:]), 2)\n", + " y_train = np.concatenate(\n", + " (labels[:i * n_trials], labels[(i + 1) * n_trials:]), 0)\n", + "\n", + " # Construction of the spatial filter and the reference signals\n", + " trca.fit(traindata, y_train)\n", + "\n", + " # Test stage\n", + " testdata = eeg[..., i * n_trials:(i + 1) * n_trials]\n", + " y_test = labels[i * n_trials:(i + 1) * n_trials]\n", + " estimated = trca.predict(testdata)\n", + "\n", + " # Evaluation of the performance for this fold (accuracy and ITR)\n", + " is_correct = estimated == y_test\n", + " accs[i] = np.mean(is_correct) * 100\n", + " itrs[i] = itr(n_targets, np.mean(is_correct), len_sel_s)\n", + " print(f\"Block {i}: accuracy = {accs[i]:.1f}, \\tITR = {itrs[i]:.1f}\")\n", + "\n", + "# Mean accuracy and ITR computation\n", + "mu, _, muci, _ = normfit(accs, alpha_ci)\n", + "print()\n", + "print(f\"Mean accuracy = {mu:.1f}%\\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)\") # noqa\n", + "\n", + "mu, _, muci, _ = normfit(itrs, alpha_ci)\n", + "print(f\"Mean ITR = {mu:.1f}\\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)\")\n", + "if is_ensemble:\n", + " ensemble = 'ensemble TRCA-based method'\n", + "else:\n", + " ensemble = 'TRCA-based method'\n", + "\n", + "print(f\"\\nElapsed time: {time.time()-t:.1f} seconds\")" + ] + } + ], + "metadata": { + "kernelspec": { + "name": "python388jvsc74a57bd0d64e410d98a0dc7c6b3fb09ececfc32281268599ac952adfc85e199a2f396698", + "display_name": "Python 3.8.8 64-bit ('miniconda3': virtualenv)" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8-final" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/examples/example_trca.py b/examples/example_trca.py new file mode 100644 index 00000000..f3f93df7 --- /dev/null +++ b/examples/example_trca.py @@ -0,0 +1,143 @@ +""" +Task-related component analysis (TRCA)-based SSVEP detection 
+============================================================ + +Sample code for the task-related component analysis (TRCA)-based steady +-state visual evoked potential (SSVEP) detection method [1]_. The filter +bank analysis [2, 3]_ can also be combined to the TRCA-based algorithm. + +Uses meegkit.trca.TRCA() + +References: + +.. [1] M. Nakanishi, Y. Wang, X. Chen, Y.-T. Wang, X. Gao, and T.-P. Jung, + "Enhancing detection of SSVEPs for a high-speed brain speller using + task-related component analysis", IEEE Trans. Biomed. Eng, 65(1): 104-112, + 2018. +.. [2] X. Chen, Y. Wang, S. Gao, T. -P. Jung and X. Gao, "Filter bank + canonical correlation analysis for implementing a high-speed SSVEP-based + brain-computer interface", J. Neural Eng., 12: 046008, 2015. +.. [3] X. Chen, Y. Wang, M. Nakanishi, X. Gao, T. -P. Jung, S. Gao, + "High-speed spelling with a noninvasive brain-computer interface", + Proc. Int. Natl. Acad. Sci. U. S. A, 112(44): E6058-6067, 2015. + +This code is based on the Matlab implementation from +https://github.com/mnakanishi/TRCA-SSVEP + +""" +# Author: Giuseppe Ferraro +import os +import time + +import numpy as np +import scipy.io +from meegkit.trca import TRCA +from meegkit.utils.trca import itr, normfit, round_half_up + +t = time.time() + +############################################################################### +# Parameters +# ----------------------------------------------------------------------------- +len_gaze_s = 0.5 # data length for target identification [s] +len_delay_s = 0.13 # visual latency being considered in the analysis [s] +n_bands = 5 # number of sub-bands in filter bank analysis +is_ensemble = True # True = ensemble TRCA method; False = TRCA method +alpha_ci = 0.05 # 100*(1-alpha_ci): confidence interval for accuracy +sfreq = 250 # sampling rate [Hz] +len_shift_s = 0.5 # duration for gaze shifting [s] +list_freqs = np.concatenate( + [[x + 8 for x in range(8)], + [x + 8.2 for x in range(8)], + [x + 8.4 for x in range(8)], + [x + 8.6 for x in range(8)], + [x + 8.8 for x in range(8)]]) # list of stimulus frequencies +n_targets = len(list_freqs) # The number of stimuli + +# Preparing useful variables (DONT'T need to modify) +len_gaze_smpl = round_half_up(len_gaze_s * sfreq) # data length [samples] +len_delay_smpl = round_half_up(len_delay_s * sfreq) # visual latency [samples] +len_sel_s = len_gaze_s + len_shift_s # selection time [s] +ci = 100 * (1 - alpha_ci) # confidence interval + +# Performing the TRCA-based SSVEP detection algorithm +print('Results of the ensemble TRCA-based method:\n') + +############################################################################### +# Load data +# ----------------------------------------------------------------------------- +path = os.path.join('..', 'tests', 'data', 'trcadata.mat') +mat = scipy.io.loadmat(path) +eeg = mat["eeg"] + +n_trials = eeg.shape[0] +n_chans = eeg.shape[1] +n_samples = eeg.shape[2] +n_blocks = eeg.shape[3] + +# Convert dummy Matlab format to (sample, channels, trials) and construct +# vector of labels +eeg = np.reshape(eeg.transpose([2, 1, 3, 0]), + (n_samples, n_chans, n_trials * n_blocks)) +labels = np.array([x for x in range(n_targets)] * n_blocks) + +crop_data = np.arange(len_delay_smpl, len_delay_smpl + len_gaze_smpl) +eeg = eeg[crop_data] + +############################################################################### +# TRCA classification +# ----------------------------------------------------------------------------- +# Estimate classification performance with a 
Leave-One-Block-Out +# cross-validation approach. + +# We use the filterbank specification described in [2]_. +filterbank = [[(6, 90), (4, 100)], # passband, stopband freqs [(Wp), (Ws)] + [(14, 90), (10, 100)], + [(22, 90), (16, 100)], + [(30, 90), (24, 100)], + [(38, 90), (32, 100)], + [(46, 90), (40, 100)], + [(54, 90), (48, 100)]] +trca = TRCA(sfreq, filterbank, is_ensemble) + +accs = np.zeros(n_blocks) +itrs = np.zeros(n_blocks) +for i in range(n_blocks): + + # Training stage + traindata = eeg.copy() + + # Select all folds except one for training + traindata = np.concatenate( + (traindata[..., :i * n_trials], + traindata[..., (i + 1) * n_trials:]), 2) + y_train = np.concatenate( + (labels[:i * n_trials], labels[(i + 1) * n_trials:]), 0) + + # Construction of the spatial filter and the reference signals + trca.fit(traindata, y_train) + + # Test stage + testdata = eeg[..., i * n_trials:(i + 1) * n_trials] + y_test = labels[i * n_trials:(i + 1) * n_trials] + estimated = trca.predict(testdata) + + # Evaluation of the performance for this fold (accuracy and ITR) + is_correct = estimated == y_test + accs[i] = np.mean(is_correct) * 100 + itrs[i] = itr(n_targets, np.mean(is_correct), len_sel_s) + print(f"Block {i}: accuracy = {accs[i]:.1f}, \tITR = {itrs[i]:.1f}") + +# Mean accuracy and ITR computation +mu, _, muci, _ = normfit(accs, alpha_ci) +print() +print(f"Mean accuracy = {mu:.1f}%\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)") # noqa + +mu, _, muci, _ = normfit(itrs, alpha_ci) +print(f"Mean ITR = {mu:.1f}\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)") +if is_ensemble: + ensemble = 'ensemble TRCA-based method' +else: + ensemble = 'TRCA-based method' + +print(f"\nElapsed time: {time.time()-t:.1f} seconds") diff --git a/meegkit/__init__.py b/meegkit/__init__.py index 795795bd..162a3c54 100644 --- a/meegkit/__init__.py +++ b/meegkit/__init__.py @@ -1,7 +1,7 @@ """M/EEG denoising utilities in python.""" __version__ = '0.1.1' -from . import asr, cca, detrend, dss, sns, star, ress, tspca, utils +from . import asr, cca, detrend, dss, sns, star, ress, trca, tspca, utils -__all__ = ['asr', 'cca', 'detrend', 'dss', 'ress', 'sns', 'star', 'tspca', - 'utils'] +__all__ = ['asr', 'cca', 'detrend', 'dss', 'ress', 'sns', 'star', 'trca', + 'tspca', 'utils'] diff --git a/meegkit/trca.py b/meegkit/trca.py new file mode 100644 index 00000000..2f3c2e75 --- /dev/null +++ b/meegkit/trca.py @@ -0,0 +1,218 @@ +"""Task-Related Component Analysis (TRCA).""" +# Author: Giuseppe Ferraro +import numpy as np +import scipy.linalg as linalg + +from .utils.trca import bandpass +from .utils import theshapeof + + +def trca(X): + """Task-related component analysis (TRCA). + + This function implements the method described in [1]_. + + Parameters + ---------- + X : array, shape=(n_samples, n_chans[, n_trials]) + Training data. + + Returns + ------- + W : array, shape=(n_chans,) + Weight coefficients for electrodes which can be used as a spatial + filter. + + References + ---------- + .. [1] M. Nakanishi, Y. Wang, X. Chen, Y. -T. Wang, X. Gao, and T.-P. Jung, + "Enhancing detection of SSVEPs for a high-speed brain speller using + task-related component analysis", IEEE Trans. Biomed. Eng, + 65(1):104-112, 2018. 
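+
+    Examples
+    --------
+    A minimal sketch on random data (the array shapes below are
+    illustrative only, not tied to any real recording):
+
+    >>> import numpy as np
+    >>> X = np.random.randn(100, 8, 10)  # samples, channels, trials
+    >>> w = trca(X)
+    >>> w.shape
+    (8,)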
+ + """ + n_samples, n_chans, n_trials = theshapeof(X) + + S = np.zeros((n_chans, n_chans)) + for trial_i in range(n_trials - 1): + x1 = np.squeeze(X[..., trial_i]) + + # Mean centering for the selected trial + x1 -= np.mean(x1, 0) + + # Select a second trial that is different + for trial_j in range(trial_i + 1, n_trials): + x2 = np.squeeze(X[..., trial_j]) + + # Mean centering for the selected trial + x2 -= np.mean(x2, 0) + + # Compute empirical covariance between the two selected trials and + # sum it + S = S + x1.T @ x2 + x2.T @ x1 + + # Reshape to have all the data as a sequence + UX = np.zeros((n_chans, n_samples * n_trials)) + for trial in range(n_trials): + UX[:, trial * n_samples:(trial + 1) * n_samples] = X[..., trial].T + + # Mean centering + UX -= np.mean(UX, 1)[:, None] + + # Compute empirical variance of all data (to be bounded) + Q = np.dot(UX, UX.T) + + # Compute eigenvalues and vectors + lambdas, W = linalg.eig(S, Q, left=True, right=False) + + # Select the eigenvector corresponding to the biggest eigenvalue + W_best = W[:, np.argmax(lambdas)] + + return W_best + + +class TRCA: + """Task-Related Component Analysis (TRCA). + + Parameters + ---------- + sfreq : float + Sampling rate. + filterbank : list[[2-tuple, 2-tuple]] + Filterbank frequencies. Each list element is itself a list of passband + `Wp` and stopband `Ws` edges frequencies `[Wp, Ws]`. For example, this + creates 3 bands, starting at 6, 14, and 22 hz respectively:: + + [[(6, 90), (4, 100)], + [(14, 90), (10, 100)], + [(22, 90), (16, 100)]] + + See :func:`scipy.signal.cheb1ord()` for more information on how to + specify the `Wp` and `Ws`. + ensemble: bool + If True, perform the ensemble TRCA analysis (default=False). + + Attributes + ---------- + traindata : array, shape=(n_bands, n_chans, n_trials) + Reference (training) data decomposed into sub-band components by the + filter bank analysis. + y_train : array, shape=(n_trials) + Labels associated with the train data. + coef_ : array, shape=(n_chans, n_chans) + Weight coefficients for electrodes which can be used as a spatial + filter. + classes : list + Classes. + n_bands : int + Number of sub-bands + + """ + + def __init__(self, sfreq, filterbank, ensemble=False): + self.sfreq = sfreq + self.ensemble = ensemble + self.filterbank = filterbank + self.n_bands = len(self.filterbank) + self.coef_ = None + + def fit(self, X, y): + """Training stage of the TRCA-based SSVEP detection. + + Parameters + ---------- + X : array, shape=(n_samples, n_chans[, n_trials]) + Training EEG data. + y : array, shape=(trials,) + True label corresponding to each trial of the data array. 
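+
+        Returns
+        -------
+        self : TRCA
+            The fitted instance. One spatial filter is learned per
+            (class, sub-band) pair, and the band-passed, trial-averaged
+            training data are stored in ``self.trains`` for use at
+            prediction time.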
+ + """ + n_samples, n_chans, _ = theshapeof(X) + classes = np.unique(y) + + trains = np.zeros((len(classes), self.n_bands, n_samples, n_chans)) + + W = np.zeros((self.n_bands, len(classes), n_chans)) + + for class_i in classes: + # Select data with a specific label + eeg_tmp = X[..., y == class_i] + for fb_i in range(self.n_bands): + # Filter the signal with fb_i + eeg_tmp = bandpass(eeg_tmp, self.sfreq, + Wp=self.filterbank[fb_i][0], + Ws=self.filterbank[fb_i][1]) + if (eeg_tmp.ndim == 3): + # Compute mean of the signal across trials + trains[class_i, fb_i] = np.mean(eeg_tmp, -1) + else: + trains[class_i, fb_i] = eeg_tmp + # Find the spatial filter for the corresponding filtered signal + # and label + w_best = trca(eeg_tmp) + W[fb_i, class_i, :] = w_best # Store the spatial filter + + self.trains = trains + self.coef_ = W + self.classes = classes + + return self + + def predict(self, X): + """Test phase of the TRCA-based SSVEP detection. + + Parameters + ---------- + X: array, shape=(n_samples, n_chans[, n_trials]) + Test data. + model: dict + Fitted model to be used in testing phase. + + Returns + ------- + pred: np.array, shape (trials) + The target estimated by the method. + + """ + if self.coef_ is None: + raise RuntimeError('TRCA is not fitted') + + # Alpha coefficients for the fusion of filterbank analysis + fb_coefs = [(x + 1)**(-1.25) + 0.25 for x in range(self.n_bands)] + _, _, n_trials = theshapeof(X) + + r = np.zeros((self.n_bands, len(self.classes))) + pred = np.zeros((n_trials), 'int') # To store predictions + + for trial in range(n_trials): + test_tmp = X[..., trial] # Pick a trial to be analysed + for fb_i in range(self.n_bands): + + # Filterbank on testdata + testdata = bandpass(test_tmp, self.sfreq, + Wp=self.filterbank[fb_i][0], + Ws=self.filterbank[fb_i][1]) + + for class_i in self.classes: + # Retrieve reference signal for class i + # (shape: n_chans, n_samples) + traindata = np.squeeze(self.trains[class_i, fb_i]) + if self.ensemble: + # Shape of (# of channel, # of class) + w = np.squeeze(self.coef_[fb_i]).T + else: + # Shape of (# of channel) + w = np.squeeze(self.coef_[fb_i, class_i]) + + # Compute 2D correlation of spatially filtered test data + # with ref + r_tmp = np.corrcoef((testdata @ w).flatten(), + (traindata @ w).flatten()) + r[fb_i, class_i] = r_tmp[0, 1] + + rho = np.dot(fb_coefs, r) # Fusion for the filterbank analysis + + tau = np.argmax(rho) # Retrieving the index of the max + pred[trial] = int(tau) + + return pred diff --git a/meegkit/utils/stats.py b/meegkit/utils/stats.py index 1fe1a6e7..cc694857 100644 --- a/meegkit/utils/stats.py +++ b/meegkit/utils/stats.py @@ -33,7 +33,7 @@ def robust_mean(X, axis=0, percentile=[5, 95]): return m -def rolling_corr(X, y, window=None, fs=1, step=1, axis=0): +def rolling_corr(X, y, window=None, sfreq=1, step=1, axis=0): """Calculate rolling correlation between some data and a reference signal. Parameters @@ -44,7 +44,7 @@ def rolling_corr(X, y, window=None, fs=1, step=1, axis=0): Reference signal. window : int Number of timepoints for to include for each correlation calculation. - fs: int + sfreq: int Sampling frequency (default=1). step : int If > 1, only compute correlations every `step` samples. 
@@ -83,7 +83,7 @@ def rolling_corr(X, y, window=None, fs=1, step=1, axis=0): corr = corr.squeeze(-1) # Times relative to end of window - t_corr = (timebins + window) / float(fs) + t_corr = (timebins + window) / float(sfreq) assert len(t_corr) == corr.shape[0] diff --git a/meegkit/utils/trca.py b/meegkit/utils/trca.py new file mode 100644 index 00000000..c596dd2a --- /dev/null +++ b/meegkit/utils/trca.py @@ -0,0 +1,137 @@ +"""TRCA utils.""" +import numpy as np +import scipy + +from scipy.signal import filtfilt, cheb1ord, cheby1 +from scipy.stats import chi2, t + + +def round_half_up(num, decimals=0): + """Round half up round the last decimal of the number. + + The rules are: + from 0 to 4 rounds down + from 5 to 9 rounds up + + Parameters + ---------- + num : float + Number to round + decimals : number of decimals + + Returns + ------- + num rounded + """ + multiplier = 10 ** decimals + return int(np.floor(num * multiplier + 0.5) / multiplier) + + +def normfit(data, ci=0.95): + """Compute the mean, std and confidence interval for them. + + Parameters + ---------- + data: array, shape=() + Input data. + ci : float + Confidence interval (default=0.95). + + Returns + ------- + m : mean + sigma : std deviation + [m - h, m + h] : confidence interval of the mean + [sigmaCI_lower, sigmaCI_upper] : confidence interval of the std + """ + arr = 1.0 * np.array(data) + num = len(arr) + avg, std_err = np.mean(arr), scipy.stats.sem(arr) + h_int = std_err * t.ppf((1 + ci) / 2., num - 1) + var = np.var(data, ddof=1) + var_ci_upper = var * (num - 1) / (chi2.ppf((1 - ci) / 2, num - 1)) + var_ci_lower = var * (num - 1) / (chi2.ppf(1 - (1 - ci) / 2, num - 1)) + sigma = np.sqrt(var) + sigma_ci_lower = np.sqrt(var_ci_lower) + sigma_ci_upper = np.sqrt(var_ci_upper) + + return avg, sigma, [avg - h_int, avg + + h_int], [sigma_ci_lower, sigma_ci_upper] + + +def itr(n, p, t): + """Compute information transfer rate (ITR). + + Inputs + ------ + n : int + Number of targets. + p : float + Target identification accuracy (0 <= p <= 1). + t : float + Average time for a selection (s). + + Returns + ------- + itr : float + Information transfer rate [bits/min] + + References + ---------- + .. [1] M. Cheng, X. Gao, S. Gao, and D. Xu, + "Design and Implementation of a Brain-Computer Interface With High + Transfer Rates", IEEE Trans. Biomed. Eng. 49, 1181-1186, 2002. + + """ + itr = 0 + + if (p < 0 or 1 < p): + raise ValueError('Accuracy need to be between 0 and 1.') + elif (p < 1 / n): + raise ValueError('ITR might be incorrect because accuracy < chance') + itr = 0 + elif (p == 1): + itr = np.log2(n) * 60 / t + else: + itr = (np.log2(n) + p * np.log2(p) + (1 - p) * + np.log2((1 - p) / (n - 1))) * 60 / t + + return itr + + +def bandpass(eeg, sfreq, Wp, Ws): + """Filter bank design for decomposing EEG data into sub-band components. + + Parameters + ---------- + eeg : np.array, shape=(n_samples, n_chans[, n_trials]) + Training data. + sfreq : int + Sampling frequency of the data. + Wp : 2-tuple + Passband for Chebyshev filter. + Ws : 2-tuple + Stopband for Chebyshev filter. + + Returns + ------- + y: np.array, shape=(n_trials, n_chans, n_samples) + Sub-band components decomposed by a filter bank. + + See Also + -------- + scipy.signal.cheb1ord : + Chebyshev type I filter order selection. + + """ + # Chebyshev type I filter order selection. 
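+    # gpass=3 is the maximum ripple allowed in the passband (dB) and
+    # gstop=40 the minimum attenuation required in the stopband (dB).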
+ N, Wn = cheb1ord(Wp, Ws, 3, 40, fs=sfreq) + + # Chebyshev type I filter design + B, A = cheby1(N, 0.5, Wn, btype="bandpass", fs=sfreq) + + # the arguments 'axis=0, padtype='odd', padlen=3*(max(len(B),len(A))-1)' + # correspond to Matlab filtfilt : https://dsp.stackexchange.com/a/47945 + y = filtfilt(B, A, eeg, axis=0, padtype='odd', + padlen=3 * (max(len(B), len(A)) - 1)) + return y diff --git a/setup.cfg b/setup.cfg index 4cd146d5..300b0b8e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [flake8] -exclude = __init__.py,*externals*,constants.py,fixes.py,*examples* +exclude = __init__.py,*externals*,constants.py,fixes.py ignore = E241,E305,W504 [pydocstyle] diff --git a/tests/data/trcadata.mat b/tests/data/trcadata.mat new file mode 100644 index 00000000..a5f2d4e0 Binary files /dev/null and b/tests/data/trcadata.mat differ diff --git a/tests/test_trca.py b/tests/test_trca.py new file mode 100644 index 00000000..7fcb8d35 --- /dev/null +++ b/tests/test_trca.py @@ -0,0 +1,107 @@ +"""TRCA tests.""" +import os + +import numpy as np +import pytest +import scipy.io +from meegkit.trca import TRCA +from meegkit.utils.trca import itr, normfit, round_half_up + +########################################################################## +# Load data +# ----------------------------------------------------------------------------- +path = os.path.join('.', 'tests', 'data', 'trcadata.mat') +mat = scipy.io.loadmat(path) +eeg = mat["eeg"] + +n_trials = eeg.shape[0] +n_chans = eeg.shape[1] +n_samples = eeg.shape[2] +n_blocks = eeg.shape[3] + +list_freqs = np.concatenate( + [[x + 8 for x in range(8)], + [x + 8.2 for x in range(8)], + [x + 8.4 for x in range(8)], + [x + 8.6 for x in range(8)], + [x + 8.8 for x in range(8)]]) # list of stimulus frequencies +n_targets = len(list_freqs) # The number of stimuli + +# Convert dummy Matlab format to (sample, channels, trials) and construct +# vector of labels +eeg = np.reshape(eeg.transpose([2, 1, 3, 0]), + (n_samples, n_chans, n_trials * n_blocks)) +labels = np.array([x for x in range(n_targets)] * n_blocks) + +# We use the filterbank specification described in [2]_. 
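+# Each entry is a pair of [passband (Wp), stopband (Ws)] edge frequencies in
+# Hz, passed to scipy.signal.cheb1ord by meegkit.utils.trca.bandpass.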
+filterbank = [[[6, 90], [4, 100]], # passband freqs, stopband freqs (Wp, Ws) + [[14, 90], [10, 100]], + [[22, 90], [16, 100]], + [[30, 90], [24, 100]], + [[38, 90], [32, 100]]] + + +@pytest.mark.parametrize('ensemble', [True, False]) +def test_trcacode(ensemble): + """Test TRCA.""" + len_gaze_s = 0.5 # data length for target identification [s] + len_delay_s = 0.13 # visual latency being considered in the analysis [s] + alpha_ci = 0.05 # 100*(1-alpha_ci): confidence interval for accuracy + sfreq = 250 # sampling rate [Hz] + len_shift_s = 0.5 # duration for gaze shifting [s] + + # Preparing useful variables (DONT'T need to modify) + len_gaze_smpl = round_half_up(len_gaze_s * sfreq) # data length [samples] + len_delay_smpl = round_half_up(len_delay_s * sfreq) # visual latency [samples] + len_sel_s = len_gaze_s + len_shift_s # selection time [s] + ci = 100 * (1 - alpha_ci) # confidence interval + + crop_data = np.arange(len_delay_smpl, len_delay_smpl + len_gaze_smpl) + + ########################################################################## + # TRCA classification + # ----------------------------------------------------------------------------- + # Estimate classification performance with a Leave-One-Block-Out + # cross-validation approach + trca = TRCA(sfreq, filterbank, ensemble) + accs = np.zeros(2) + itrs = np.zeros(2) + for i in range(2): + + # Training stage + traindata = eeg.copy()[crop_data] + + # Select all folds except one for training + traindata = np.concatenate( + (traindata[..., :i * n_trials], + traindata[..., (i + 1) * n_trials:]), 2) + y_train = np.concatenate( + (labels[:i * n_trials], labels[(i + 1) * n_trials:]), 0) + + # Construction of the spatial filter and the reference signals + trca.fit(traindata, y_train) + + # Test stage + testdata = eeg[crop_data, :, i * n_trials:(i + 1) * n_trials] + y_test = labels[i * n_trials:(i + 1) * n_trials] + estimated = trca.predict(testdata) + + # Evaluation of the performance for this fold (accuracy and ITR) + is_correct = estimated == y_test + accs[i] = np.mean(is_correct) * 100 + itrs[i] = itr(n_targets, np.mean(is_correct), len_sel_s) + print(f"Block {i}: accuracy = {accs[i]:.1f}, \tITR = {itrs[i]:.1f}") + + # Mean accuracy and ITR computation + mu, _, muci, _ = normfit(accs, alpha_ci) + print(f"Mean accuracy = {mu:.1f}%\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)") # noqa + assert mu > 95 + mu, _, muci, _ = normfit(itrs, alpha_ci) + print(f"Mean ITR = {mu:.1f}\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)") + assert mu > 300 + + +if __name__ == '__main__': + import pytest + pytest.main([__file__]) + # test_trcacode()
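
For quick reference, here is a minimal usage sketch of the `meegkit.trca.TRCA` API introduced above. It runs on synthetic random data rather than `tests/data/trcadata.mat`; the array shapes, labels, and the two-band filterbank are illustrative assumptions only, and `examples/example_trca.py` remains the authoritative benchmark.

```python
import numpy as np

from meegkit.trca import TRCA

sfreq = 250  # sampling rate [Hz], as in the example script above
n_samples, n_chans, n_classes, n_blocks = 125, 8, 4, 6

# Synthetic EEG: fit()/predict() expect shape (n_samples, n_chans, n_trials)
rng = np.random.RandomState(0)
X = rng.randn(n_samples, n_chans, n_classes * n_blocks)
y = np.tile(np.arange(n_classes), n_blocks)  # one integer label per trial

# Two sub-bands, each a [passband (Wp), stopband (Ws)] pair of edges in Hz
filterbank = [[(6, 90), (4, 100)],
              [(14, 90), (10, 100)]]

trca = TRCA(sfreq, filterbank, ensemble=True)
trca.fit(X, y)          # one spatial filter per (class, sub-band) pair
pred = trca.predict(X)  # integer class index for each trial
print(pred.shape)       # (24,)
```

The filterbank convention matches the `TRCA` class docstring: each band is a `[Wp, Ws]` pair of (low, high) edge frequencies in Hz that `meegkit.utils.trca.bandpass` hands to `scipy.signal.cheb1ord`.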