Skip to content
Merged
133 changes: 88 additions & 45 deletions examples/preprocessing/esg_rm_heart_artefact_pcaobs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
.. _ex-pcaobs:

==============================================================================================
Principal Component Analysis - Optimal Basis Sets (PCA-OBS) for removal of cardiac artefact
==============================================================================================
=====================================================================================
Principal Component Analysis - Optimal Basis Sets (PCA-OBS) removing cardiac artefact
=====================================================================================

This script shows an example of how to use an adaptation of PCA-OBS
:footcite:`NiazyEtAl2005`. PCA-OBS was originally designed to remove
Expand All @@ -24,39 +24,77 @@
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from matplotlib import pyplot as plt
import mne
from mne.preprocessing import find_ecg_events, fix_stim_artifact
from mne.io import read_raw_eeglab
from scipy.signal import firls
import glob

import numpy as np
from mne import Epochs, events_from_annotations, concatenate_raws

###############################################################################
# Download sample subject data from OpenNeuro if you haven't already
# This will download simultaneous EEG and ESG data from a single participant after
# median nerve stimulation of the left wrist
# Set the target directory to your desired location
import openneuro as on
import glob
from matplotlib import pyplot as plt

import mne
from mne import Epochs, concatenate_raws, events_from_annotations
from mne.io import read_raw_eeglab
from mne.preprocessing import find_ecg_events, fix_stim_artifact

# add the path where you want the OpenNeuro data downloaded. Files total around 8 GB
# target_dir = "/home/steinnhm/personal/mne-data"
target_dir = '/data/pt_02569/test_data'
target_dir = "/data/pt_02569/test_data"

file_list = glob.glob(target_dir + '/sub-001/eeg/*median*.set')
file_list = glob.glob(target_dir + "/sub-001/eeg/*median*.set")
if file_list:
print('Data is already downloaded')
print("Data is already downloaded")
else:
on.download(dataset='ds004388', target_dir=target_dir, include='sub-001/*median*_eeg*')
on.download(
dataset="ds004388", target_dir=target_dir, include="sub-001/*median*_eeg*"
)

###############################################################################
# Define the esg channels (arranged in two patches over the neck and lower back)
# Also include the ECG channel for artefact correction
esg_chans = ["S35", "S24", "S36", "Iz", "S17", "S15", "S32", "S22", "S19", "S26", "S28",
"S9", "S13", "S11", "S7", "SC1", "S4", "S18", "S8", "S31", "SC6", "S12",
"S16", "S5", "S30", "S20", "S34", "S21", "S25", "L1", "S29", "S14", "S33",
"S3", "L4", "S6", "S23", 'ECG']
esg_chans = [
"S35",
"S24",
"S36",
"Iz",
"S17",
"S15",
"S32",
"S22",
"S19",
"S26",
"S28",
"S9",
"S13",
"S11",
"S7",
"SC1",
"S4",
"S18",
"S8",
"S31",
"SC6",
"S12",
"S16",
"S5",
"S30",
"S20",
"S34",
"S21",
"S25",
"L1",
"S29",
"S14",
"S33",
"S3",
"L4",
"S6",
"S23",
]

# Sampling rate
fs = 1000
Expand All @@ -73,21 +111,30 @@
# Read in each of the four blocks and concatenate the raw structures after performing
# some minimal preprocessing including removing the stimulation artefact, downsampling
# and filtering
block_files = glob.glob(target_dir + '/sub-001/eeg/*median*.set')
block_files = glob.glob(target_dir + "/sub-001/eeg/*median*.set")
block_files = sorted(block_files)

for count, block_file in enumerate(block_files):
raw = read_raw_eeglab(block_file, eog=(), preload=True, uint16_codec=None, verbose=None)
raw = read_raw_eeglab(
block_file, eog=(), preload=True, uint16_codec=None, verbose=None
)

# Isolate the ESG channels only
raw.pick(esg_chans)
# Isolate the ESG channels (including ECG for R-peak detection)
raw.pick(esg_chans + ["ECG"])

# Find trigger timings to remove the stimulation artefact
events, event_dict = events_from_annotations(raw)
trigger_name = 'Median - Stimulation'

fix_stim_artifact(raw, events=events, event_id=event_dict[trigger_name], tmin=tstart_esg, tmax=tmax_esg, mode='linear',
stim_channel=None)
trigger_name = "Median - Stimulation"

fix_stim_artifact(
raw,
events=events,
event_id=event_dict[trigger_name],
tmin=tstart_esg,
tmax=tmax_esg,
mode="linear",
stim_channel=None,
)

# Downsample the data
raw.resample(fs)
Expand All @@ -101,20 +148,19 @@
###############################################################################
# Find ECG events and add to the raw structure as event annotations
ecg_events, ch_ecg, average_pulse = find_ecg_events(raw_concat, ch_name="ECG")
ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) # Samples only
ecg_event_samples = np.asarray(
[[ecg_event[0] for ecg_event in ecg_events]]
) # Samples only

qrs_event_time = [x / fs for x in ecg_event_samples.reshape(-1)] # Divide by sampling rate to make times
qrs_event_time = [
x / fs for x in ecg_event_samples.reshape(-1)
] # Divide by sampling rate to make times
duration = np.repeat(0.0, len(ecg_event_samples))
description = ['qrs'] * len(ecg_event_samples)

raw_concat.annotations.append(qrs_event_time, duration, description, ch_names=[esg_chans]*len(qrs_event_time))
description = ["qrs"] * len(ecg_event_samples)

###############################################################################
# Create filter coefficients
a = [0, 0, 1, 1]
f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter
ord = round(3 * fs / 0.5)
fwts = firls(ord + 1, f, a)
raw_concat.annotations.append(
qrs_event_time, duration, description, ch_names=[esg_chans] * len(qrs_event_time)
)

###############################################################################
# Create evoked response about the detected R-peaks before cardiac artefact correction
Expand All @@ -132,13 +178,10 @@
)
evoked_before = epochs.average()

# Apply function - modifies the data in place
# Apply function - modifies the data in place. Optionally high-pass filter
# the data before applying PCA-OBS to remove low frequency drifts
mne.preprocessing.apply_pca_obs(
raw_concat,
picks=esg_chans,
n_jobs=5,
qrs=ecg_event_samples,
filter_coords=fwts
raw_concat, picks=esg_chans, n_jobs=5, qrs_indices=ecg_event_samples.reshape(-1)
)

epochs = Epochs(
Expand All @@ -157,8 +200,8 @@
axes.plot(evoked_before.times, evoked_before.get_data().T, color="black")
axes.plot(evoked_after.times, evoked_after.get_data().T, color="green")
axes.set_ylim([-0.0005, 0.001])
axes.set_ylabel('Amplitude (V)')
axes.set_xlabel('Time (s)')
axes.set_ylabel("Amplitude (V)")
axes.set_xlabel("Time (s)")
axes.set_title("Before (black) vs. After (green)")
plt.tight_layout()
plt.show()
Expand Down
2 changes: 1 addition & 1 deletion mne/preprocessing/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ from .maxwell import (
maxwell_filter_prepare_emptyroom,
)
from .otp import oversampled_temporal_projection
from .pca_obs import apply_pca_obs
from .realign import realign_raw
from .ssp import compute_proj_ecg, compute_proj_eog
from .stim import fix_stim_artifact
from .xdawn import Xdawn
from .pca_obs import apply_pca_obs
Loading
Loading