Skip to content
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# MEEGkit

[![unit-tests](https://github.com/nbara/python-meegkit/workflows/unit-tests/badge.svg?style=flat)](https://github.com/nbara/python-meegkit/actions?workflow=unit-tests)
[![documentation](https://img.shields.io/travis/nbara/python-meegkit.svg?label=documentation&logo=travis)](https://travis-ci.org/nbara/python-meegkit)
[![documentation](https://img.shields.io/travis/nbara/python-meegkit.svg?label=documentation&logo=travis)](https://www.travis-ci.com/github/nbara/python-meegkit)
[![codecov](https://codecov.io/gh/nbara/python-meegkit/branch/master/graph/badge.svg)](https://codecov.io/gh/nbara/python-meegkit)
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/nbara/python-meegkit/master)
[![twitter](https://img.shields.io/twitter/follow/lebababa?label=Twitter&style=flat&logo=Twitter)](https://twitter.com/intent/follow?screen_name=lebababa)
Expand Down
171 changes: 110 additions & 61 deletions examples/example_trca.ipynb

Large diffs are not rendered by default.

103 changes: 63 additions & 40 deletions examples/example_trca.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""
Task-related component analysis (TRCA)-based SSVEP detection
============================================================
Task-related component analysis for SSVEP detection
===================================================

Sample code for the task-related component analysis (TRCA)-based steady
-state visual evoked potential (SSVEP) detection method [1]_. The filter
bank analysis [2, 3]_ can also be combined to the TRCA-based algorithm.
bank analysis can also be combined to the TRCA-based algorithm [2]_ [3]_.

Uses meegkit.trca.TRCA()
This code is based on the Matlab implementation from:
https://github.com/mnakanishi/TRCA-SSVEP

Uses `meegkit.trca.TRCA()`.

References:

Expand All @@ -21,14 +24,13 @@
"High-speed spelling with a noninvasive brain-computer interface",
Proc. Int. Natl. Acad. Sci. U. S. A, 112(44): E6058-6067, 2015.

This code is based on the Matlab implementation from
https://github.com/mnakanishi/TRCA-SSVEP

"""
# Author: Giuseppe Ferraro <giuseppe.ferraro@isae-supaero.fr>
# Authors: Giuseppe Ferraro <giuseppe.ferraro@isae-supaero.fr>
# Nicolas Barascud <nicolas.barascud@gmail.com>
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import scipy.io
from meegkit.trca import TRCA
Expand All @@ -39,78 +41,100 @@
###############################################################################
# Parameters
# -----------------------------------------------------------------------------
len_gaze_s = 0.5 # data length for target identification [s]
len_delay_s = 0.13 # visual latency being considered in the analysis [s]
dur_gaze = 0.5 # data length for target identification [s]
delay = 0.13 # visual latency being considered in the analysis [s]
n_bands = 5 # number of sub-bands in filter bank analysis
is_ensemble = True # True = ensemble TRCA method; False = TRCA method
alpha_ci = 0.05 # 100*(1-alpha_ci): confidence interval for accuracy
sfreq = 250 # sampling rate [Hz]
len_shift_s = 0.5 # duration for gaze shifting [s]
list_freqs = np.concatenate(
[[x + 8 for x in range(8)],
dur_shift = 0.5 # duration for gaze shifting [s]
list_freqs = np.array(
[[x + 8.0 for x in range(8)],
[x + 8.2 for x in range(8)],
[x + 8.4 for x in range(8)],
[x + 8.6 for x in range(8)],
[x + 8.8 for x in range(8)]]) # list of stimulus frequencies
n_targets = len(list_freqs) # The number of stimuli
[x + 8.8 for x in range(8)]]).T # list of stimulus frequencies
n_targets = list_freqs.size # The number of stimuli

# Preparing useful variables (DONT'T need to modify)
len_gaze_smpl = round_half_up(len_gaze_s * sfreq) # data length [samples]
len_delay_smpl = round_half_up(len_delay_s * sfreq) # visual latency [samples]
len_sel_s = len_gaze_s + len_shift_s # selection time [s]
# Useful variables (no need to modify)
dur_gaze_s = round_half_up(dur_gaze * sfreq) # data length [samples]
delay_s = round_half_up(delay * sfreq) # visual latency [samples]
dur_sel_s = dur_gaze + dur_shift # selection time [s]
ci = 100 * (1 - alpha_ci) # confidence interval

# Performing the TRCA-based SSVEP detection algorithm
print('Results of the ensemble TRCA-based method:\n')

###############################################################################
# Load data
# -----------------------------------------------------------------------------
path = os.path.join('..', 'tests', 'data', 'trcadata.mat')
mat = scipy.io.loadmat(path)
eeg = mat["eeg"]
eeg = scipy.io.loadmat(path)["eeg"]

n_trials = eeg.shape[0]
n_chans = eeg.shape[1]
n_samples = eeg.shape[2]
n_blocks = eeg.shape[3]
n_trials, n_chans, n_samples, n_blocks = eeg.shape

# Convert dummy Matlab format to (sample, channels, trials) and construct
# vector of labels
eeg = np.reshape(eeg.transpose([2, 1, 3, 0]),
(n_samples, n_chans, n_trials * n_blocks))
labels = np.array([x for x in range(n_targets)] * n_blocks)

crop_data = np.arange(len_delay_smpl, len_delay_smpl + len_gaze_smpl)
crop_data = np.arange(delay_s, delay_s + dur_gaze_s)
eeg = eeg[crop_data]

###############################################################################
# TRCA classification
# -----------------------------------------------------------------------------
# Estimate classification performance with a Leave-One-Block-Out
# cross-validation approach.
#
# To get a sense of the filterbank specification in relation to the stimuli
# we can plot the individual filterbank sub-bands as well as the target
# frequencies (with their expected harmonics in the EEG spectrum). We use the
# filterbank specification described in [2]_.

# We use the filterbank specification described in [2]_.
filterbank = [[(6, 90), (4, 100)], # passband, stopband freqs [(Wp), (Ws)]
[(14, 90), (10, 100)],
[(22, 90), (16, 100)],
[(30, 90), (24, 100)],
[(38, 90), (32, 100)],
[(46, 90), (40, 100)],
[(54, 90), (48, 100)]]

f, ax = plt.subplots(1, figsize=(7, 4))
for i, band in enumerate(filterbank):
ax.axvspan(ymin=i / len(filterbank) + .02,
ymax=(i + 1) / len(filterbank) - .02,
xmin=filterbank[i][1][0], xmax=filterbank[i][1][1],
alpha=0.2, facecolor=f'C{i}')
ax.axvspan(ymin=i / len(filterbank) + .02,
ymax=(i + 1) / len(filterbank) - .02,
xmin=filterbank[i][0][0], xmax=filterbank[i][0][1],
alpha=0.5, label=f'sub-band{i}', facecolor=f'C{i}')

for f in list_freqs.flat:
colors = np.ones((9, 4))
colors[:, :3] = np.linspace(0, .5, 9)[:, None]
ax.scatter(f * np.arange(1, 10), [f] * 9, c=colors, s=8, zorder=100)

ax.set_ylabel('Stimulus frequency (Hz)')
ax.set_xlabel('EEG response frequency (Hz)')
ax.set_xlim([0, 102])
ax.set_xticks(np.arange(0, 100, 10))
ax.grid(True, ls=':', axis='x')
ax.legend(bbox_to_anchor=(1.05, .5), fontsize='small')
plt.tight_layout()
plt.show()

###############################################################################
# Now perform the TRCA-based SSVEP detection algorithm
trca = TRCA(sfreq, filterbank, is_ensemble)

print('Results of the ensemble TRCA-based method:\n')
accs = np.zeros(n_blocks)
itrs = np.zeros(n_blocks)
for i in range(n_blocks):

# Training stage
traindata = eeg.copy()

# Select all folds except one for training
traindata = np.concatenate(
(traindata[..., :i * n_trials],
traindata[..., (i + 1) * n_trials:]), 2)
(eeg[..., :i * n_trials],
eeg[..., (i + 1) * n_trials:]), 2)
y_train = np.concatenate(
(labels[:i * n_trials], labels[(i + 1) * n_trials:]), 0)

Expand All @@ -125,16 +149,15 @@
# Evaluation of the performance for this fold (accuracy and ITR)
is_correct = estimated == y_test
accs[i] = np.mean(is_correct) * 100
itrs[i] = itr(n_targets, np.mean(is_correct), len_sel_s)
itrs[i] = itr(n_targets, np.mean(is_correct), dur_sel_s)
print(f"Block {i}: accuracy = {accs[i]:.1f}, \tITR = {itrs[i]:.1f}")

# Mean accuracy and ITR computation
mu, _, muci, _ = normfit(accs, alpha_ci)
print()
print(f"Mean accuracy = {mu:.1f}%\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)") # noqa
print(f"\nMean accuracy = {mu:.1f}%\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)") # noqa

mu, _, muci, _ = normfit(itrs, alpha_ci)
print(f"Mean ITR = {mu:.1f}\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)")
print(f"Mean ITR = {mu:.1f}\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f})")
if is_ensemble:
ensemble = 'ensemble TRCA-based method'
else:
Expand Down
Loading