From 2013fb487c936c12b9fd6308deef6a223fc9584d Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Sun, 1 Sep 2024 12:57:06 +0200 Subject: [PATCH 1/4] feat: add initial source code --- mne/preprocessing/pca_obs/PCA_OBS.py | 281 ++++++++++++++++++ mne/preprocessing/pca_obs/fit_ecgTemplate.py | 78 +++++ .../pca_obs/pchip_interpolation.py | 59 ++++ .../pca_obs/rm_heart_artefact.py | 130 ++++++++ 4 files changed, 548 insertions(+) create mode 100755 mne/preprocessing/pca_obs/PCA_OBS.py create mode 100755 mne/preprocessing/pca_obs/fit_ecgTemplate.py create mode 100755 mne/preprocessing/pca_obs/pchip_interpolation.py create mode 100755 mne/preprocessing/pca_obs/rm_heart_artefact.py diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py new file mode 100755 index 00000000000..31ff3b4d06b --- /dev/null +++ b/mne/preprocessing/pca_obs/PCA_OBS.py @@ -0,0 +1,281 @@ +import numpy as np +import mne +from scipy.signal import filtfilt, detrend +import matplotlib.pyplot as plt +from sklearn.decomposition import PCA +from fit_ecgTemplate import fit_ecgTemplate +import math +import h5py + + +def PCA_OBS(data, **kwargs): + + # Declare class to hold pca information + class PCAInfo(): + def __init__(self): + pass + + # Instantiate class + pca_info = PCAInfo() + + # Check all necessary arguments sent in + required_kws = ["debug_mode", "qrs", "filter_coords", "sr", "savename", "ch_names", "sub_nr", "condition", + "current_channel"] + assert all([kw in kwargs.keys() for kw in required_kws]), "Error. Some KWs not passed into PCA_OBS." + + # Extract all kwargs + debug_mode = kwargs['debug_mode'] + qrs = kwargs['qrs'] + filter_coords = kwargs['filter_coords'] + sr = kwargs['sr'] + ch_names = kwargs['ch_names'] + sub_nr = kwargs['sub_nr'] + condition = kwargs['condition'] + if debug_mode: # Only need current channel and saving if we're debugging + current_channel = kwargs['current_channel'] + savename = kwargs['savename'] + + fs = sr + + # Standard delay between QRS peak and artifact + delay = 0 + + Gwindow = 2 + GHW = np.floor(Gwindow / 2).astype('int') + rcount = 0 + firstplot = 1 + + # set to baseline + data = data.reshape(-1, 1) + data = data.T + data = data - np.mean(data, axis=1) + + # Allocate memory + fitted_art = np.zeros(data.shape) + peakplot = np.zeros(data.shape) + + # Extract QRS events + for idx in qrs[0]: + if idx < len(peakplot[0, :]): + peakplot[0, idx] = 1 # logical indexed locations of qrs events + # sh = np.zeros((1, delay)) + # np1 = len(peakplot) + # peakplot = [sh, peakplot[0:np1 - delay]] # shifts indexed array by the delay - skipped here since delay=0 + + peak_idx = np.nonzero(peakplot)[1] # Selecting indices along columns + peak_idx = peak_idx.reshape(-1, 1) + peak_count = len(peak_idx) + + ################################################################ + # Preparatory work - reserving memory, configure sizes, de-trend + ################################################################ + print('Pulse artifact subtraction in progress...Please wait!') + + # define peak range based on RR + RR = np.diff(peak_idx[:, 0]) + mRR = np.median(RR) + peak_range = round(mRR/2) # Rounds to an integer + midP = peak_range + 1 + baseline_range = [0, round(peak_range/8)] + n_samples_fit = round(peak_range/8) # sample fit for interpolation between fitted artifact windows + + # make sure array is long enough for PArange (if not cut off last ECG peak) + pa = peak_count # Number of QRS complexes detected + while peak_idx[pa-1, 0] + peak_range > len(data[0]): + pa = pa - 1 + steps = 1 * pa + peak_count = pa + + # Filter channel + eegchan = filtfilt(filter_coords, 1, data) + + # build PCA matrix(heart-beat-epochs x window-length) + pcamat = np.zeros((peak_count - 1, 2*peak_range+1)) # [epoch x time] + # picking out heartbeat epochs + for p in range(1, peak_count): + pcamat[p-1, :] = eegchan[0, peak_idx[p, 0] - peak_range: peak_idx[p, 0] + peak_range+1] + + # detrending matrix(twice) + pcamat = detrend(pcamat, type='constant', axis=1) # [epoch x time] - detrended along the epoch + mean_effect = np.mean(pcamat, axis=0) # [1 x time], contains the mean over all epochs + std_effect = np.std(pcamat, axis=0) # want mean and std of each column + dpcamat = detrend(pcamat, type='constant', axis=1) # [time x epoch] + + ################################################################### + # Perform PCA with sklearn + ################################################################### + # run PCA(performs SVD(singular value decomposition)) + pca = PCA(svd_solver="full") + pca.fit(dpcamat) + eigen_vectors = pca.components_ + eigen_values = pca.explained_variance_ + factor_loadings = pca.components_.T*np.sqrt(pca.explained_variance_) + pca_info.eigen_vectors = eigen_vectors + pca_info.factor_loadings = factor_loadings + pca_info.eigen_values = eigen_values + pca_info.expl_var = pca.explained_variance_ratio_ + + # define selected number of components using profile likelihood + pca_info.nComponents = 4 + + # Creates plots + if debug_mode: + # plot pca variables figure + comp2plot = pca_info.nComponents + fig, axs = plt.subplots(2, 2) + for a in np.arange(comp2plot): + axs[0, 0].plot(pca_info.eigen_vectors[:, a], label=f"Data {a}") + axs[1, 1].plot(pca_info.factor_loadings[:, a], label=f"Data {a}") + axs[0, 0].set_title('Evec') + axs[1, 1].set_title('Factor Loadings') + axs[1, 1].set(xlabel='time') + + axs[0, 1].plot(np.arange(len(pca_info.expl_var)), pca_info.expl_var, 'r*') + axs[0, 1].set(xlabel='components', ylabel='var explained (%)') + cum_explained = np.cumsum(pca_info.expl_var) + axs[0, 1].set_title(f"first {pca_info.nComponents} comp, {cum_explained[pca_info.nComponents]} % var") + + axs[1, 0].plot(pca_info.eigen_values) + axs[1, 0].set_title('eigenvalues') + axs[1, 0].set(xlabel='components') + + fig.suptitle(f"{sub_nr} thresholds PCA vars channel {current_channel}") + plt.tight_layout() + fig.savefig(f"{savename}_{condition}.jpg") + + if debug_mode: + pca_info.chan = current_channel + pca_info.meanEffect = mean_effect.T + nComponents = pca_info.nComponents + + ####################################################################### + # Make template of the ECG artefact + ####################################################################### + mean_effect = mean_effect.reshape(-1, 1) + pca_template = np.c_[mean_effect, factor_loadings[:, 0:nComponents]] + + # Plot template vars + if debug_mode: + # plot template vars + fig = plt.figure() + pcatime = (np.arange(-peak_range, peak_range+1))/fs + pcatime = pcatime.reshape(-1) + plt.plot(pcatime, std_effect) + plt.plot(pcatime, mean_effect) + plt.plot(pcatime, factor_loadings[:, 0: nComponents]) + plt.legend(['std effect', 'mean effect', 'PCA_1', 'PCA_2', 'PCA_3', 'PCA_4']) + fig.suptitle(f"{sub_nr} papc channel {current_channel}") + plt.tight_layout() + fig.savefig(f"{savename}_templateVars_{condition}.jpg") + + ################################################################################### + # Data Fitting + ################################################################################### + window_start_idx = [] + window_end_idx = [] + for p in range(0, peak_count): + # Deals with start portion of data + if p == 0: + pre_range = peak_range + post_range = math.floor((peak_idx[p + 1] - peak_idx[p])/2) + if post_range > peak_range: + post_range = peak_range + try: + post_idx_nextPeak = [] + fitted_art, post_idx_nextPeak = fit_ecgTemplate(data, pca_template, peak_idx[p], peak_range, + pre_range, post_range, baseline_range, midP, + fitted_art, post_idx_nextPeak, n_samples_fit) + # Appending to list instead of using counter + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + except Exception as e: + print(f'Cannot fit first ECG epoch. Reason: {e}') + + # Deals with last edge of data + elif p == peak_count: + print('On last section - almost there!') + try: + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = peak_range + if pre_range > peak_range: + pre_range = peak_range + fitted_art, _ = fit_ecgTemplate(data, pca_template, peak_idx(p), peak_range, pre_range, post_range, + baseline_range, midP, fitted_art, post_idx_nextPeak, n_samples_fit) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + except Exception as e: + print(f'Cannot fit last ECG epoch. Reason: {e}') + + # Deals with middle portion of data + else: + try: + # ---------------- Processing of central data - -------------------- + # cycle through peak artifacts identified by peakplot + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) + if pre_range >= peak_range: + pre_range = peak_range + if post_range > peak_range: + post_range = peak_range + + aTemplate = pca_template[midP - peak_range-1:midP + peak_range+1, :] + fitted_art, post_idx_nextPeak = fit_ecgTemplate(data, aTemplate, peak_idx[p], peak_range, pre_range, + post_range, baseline_range, midP, fitted_art, + post_idx_nextPeak, n_samples_fit) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + except Exception as e: + print(f"Cannot fit middle section of data. Reason: {e}") + + # Plot some channels + if debug_mode: + # check with plot what has been done + # First check if this channel is one we want to plot for debugging + plotChannel = 0 + for ii in range(0,len(ch_names)): + if current_channel == ch_names[ii]: + plotChannel = 1 + + if plotChannel == 1: + fig = plt.figure() + plt.plot((np.arange(0, len(fitted_art[0, :]))/fs).reshape(-1, 1), data[:].T, zorder=0) + plt.plot((np.arange(0, len(fitted_art[0, :]))/fs).reshape(-1, 1), eegchan[:].T, 'r', zorder=5) + plt.plot((np.arange(0, len(fitted_art[0, :]))/fs).reshape(-1, 1), fitted_art[:].T, 'g', zorder=10) + plt.plot((np.arange(0, len(fitted_art[0, :]))/fs).reshape(-1, 1), (np.subtract(data[:], fitted_art[:])).T, 'm', zorder=15) + plt.legend(['raw data', 'filtered', 'fitted_art', 'clean'], loc='upper right').set_zorder(20) + plt.xlabel('time [s]') + plt.ylabel('amplitude [V]') + plt.title('Subject ' + sub_nr + ', channel ' + current_channel) + fig.savefig(f"{savename}_compareresults_{condition}.jpg") + + # Actually subtract the artefact, return needs to be the same shape as input data + # One sample shift purely due to the fact the r-peaks are currently detected in MATLAB + data = data.reshape(-1) + fitted_art = fitted_art.reshape(-1) + + # data -= fitted_art + + data_ = np.zeros(len(data)) + data_[0] = data[0] + data_[1:] = data[1:] - fitted_art[:-1] + data = data_ + + # Original code is this: + # data -= fitted_art + # data = data.T.reshape(-1) + + # Can't add annotations for window start and end (sample number added) to mne raw structure here + # Add it to the pca info class and store it that way + # Can then access in the main rm_heart_artefact and create the annotations + # Only save the pca vars if we're in debug mode + if debug_mode: + pca_info.window_start_idx = window_start_idx + pca_info.window_end_idx = window_end_idx + dataset_keywords = [a for a in dir(pca_info) if not a.startswith('__')] + fn = f"{savename}_{condition}_pca_info.h5" + with h5py.File(fn, "w") as outfile: + for keyword in dataset_keywords: + outfile.create_dataset(keyword, data=getattr(pca_info, keyword)) + + # Can only return data + return data diff --git a/mne/preprocessing/pca_obs/fit_ecgTemplate.py b/mne/preprocessing/pca_obs/fit_ecgTemplate.py new file mode 100755 index 00000000000..00ef6459fb6 --- /dev/null +++ b/mne/preprocessing/pca_obs/fit_ecgTemplate.py @@ -0,0 +1,78 @@ +import numpy as np +from scipy.signal import detrend +from scipy.interpolate import PchipInterpolator as pchip +import h5py + + +def fit_ecgTemplate(data, pca_template, aPeak_idx, peak_range, pre_range, post_range, baseline_range, midP, fitted_art, post_idx_previousPeak, n_samples_fit): + # Declare class to hold ecg fit information + class fitECG(): + def __init__(self): + pass + + # Instantiate class + fitecg = fitECG() + + # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak + # Then nextpeak is returned at the end and the process repeats + # select window of template + template = pca_template[midP - peak_range-1: midP + peak_range+1, :] + + # select window of data and detrend it + slice = data[0, aPeak_idx[0] - peak_range:aPeak_idx[0] + peak_range+1] + detrended_data = detrend(slice.reshape(-1), type='constant') + + # maps data on template and then maps it again back to the sensor space + least_square = np.linalg.lstsq(template, detrended_data, rcond=None) + pad_fit = np.dot(template, least_square[0]) + + # fit artifact, I already loop through externally channel to channel + fitted_art[0, aPeak_idx[0] - pre_range-1: aPeak_idx[0] + post_range] = pad_fit[midP - pre_range-1: midP + post_range].T + + fitecg.fitted_art = fitted_art + fitecg.template = template + fitecg.detrended_data = detrended_data + fitecg.pad_fit = pad_fit + fitecg.aPeak_idx = aPeak_idx + fitecg.midP = midP + fitecg.peak_range = peak_range + fitecg.data = data + + post_idx_nextPeak = [aPeak_idx[0] + post_range] + + # Check it's not empty + if len(post_idx_previousPeak) != 0: + # interpolate time between peaks + intpol_window = np.ceil([post_idx_previousPeak[0], aPeak_idx[0] - pre_range]).astype('int') # interpolation window + fitecg.intpol_window = intpol_window + + if intpol_window[0] < intpol_window[1]: + # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data + + # You have x_fit which is two slices on either side of the interpolation window endpoints + # You have y_fit which is the y vals corresponding to x values above + # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate + # You have y_interpol which is values from pchip at the time points specified in x_interpol + x_interpol = np.arange(intpol_window[0], intpol_window[1] + 1, 1) # points to be interpolated in pt - the gap between the endpoints of the window + x_fit = np.concatenate([np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1), + np.arange(intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1)]) # Entire range of x values in this step (taking some number of samples before and after the window) + y_fit = fitted_art[0, x_fit] + y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation + + # Then make fitted artefact in the desired range equal to the completed fit above + fitted_art[0, post_idx_previousPeak[0]: aPeak_idx[0] - pre_range+1] = y_interpol + + fitecg.x_fit = x_fit + fitecg.y_fit = y_fit + fitecg.x_interpol = x_interpol + fitecg.y_interpol = y_interpol + fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop + + # # Save to file to compare to matlab - only for debugging + # dataset_keywords = [a for a in dir(fitecg) if not a.startswith('__')] + # fn = f"/data/pt_02569/tmp_data/test_data/fitecg_test_{aPeak_idx[0]}_sub-001_tibial_S35.h5" + # with h5py.File(fn, "w") as outfile: + # for keyword in dataset_keywords: + # outfile.create_dataset(keyword, data=getattr(fitecg, keyword)) + + return fitted_art, post_idx_nextPeak diff --git a/mne/preprocessing/pca_obs/pchip_interpolation.py b/mne/preprocessing/pca_obs/pchip_interpolation.py new file mode 100755 index 00000000000..a35d1b75aed --- /dev/null +++ b/mne/preprocessing/pca_obs/pchip_interpolation.py @@ -0,0 +1,59 @@ +# Function to interpolate based on PCHIP rather than MNE inbuilt linear option + +import mne +import numpy as np +from scipy.interpolate import PchipInterpolator as pchip +import matplotlib.pyplot as plt + + +def PCHIP_interpolation(data, **kwargs): + # Check all necessary arguments sent in + required_kws = ["trigger_indices", "interpol_window_sec", "fs", "debug_mode"] + assert all([kw in kwargs.keys() for kw in required_kws]), "Error. Some KWs not passed into PCA_OBS." + + # Extract all kwargs - more elegant ways to do this + fs = kwargs['fs'] + interpol_window_sec = kwargs['interpol_window_sec'] + trigger_indices = kwargs['trigger_indices'] + debug_mode = kwargs['debug_mode'] + + if debug_mode: + plt.figure() + # plot signal with artifact + plot_range = [-50, 100] + test_trial = 100 + xx = (np.arange(plot_range[0], plot_range[1])) / fs * 1000 + plt.plot(xx, data[trigger_indices[test_trial] + plot_range[0]:trigger_indices[test_trial] + plot_range[1]]) + + # Convert intpol window to msec then convert to samples + pre_window = round((interpol_window_sec[0]*1000) * fs / 1000) # in samples + post_window = round((interpol_window_sec[1]*1000) * fs / 1000) # in samples + intpol_window = np.ceil([pre_window, post_window]).astype(int) # interpolation window + + n_samples_fit = 5 # number of samples before and after cut used for interpolation fit + + x_fit_raw = np.concatenate([np.arange(intpol_window[0]-n_samples_fit-1, intpol_window[0], 1), + np.arange(intpol_window[1]+1, intpol_window[1]+n_samples_fit+2, 1)]) + x_interpol_raw = np.arange(intpol_window[0], intpol_window[1]+1, 1) # points to be interpolated; in pt + + for ii in np.arange(0, len(trigger_indices)): # loop through all stimulation events + x_fit = trigger_indices[ii] + x_fit_raw # fit point latencies for this event + x_interpol = trigger_indices[ii] + x_interpol_raw # latencies for to-be-interpolated data points + + # Data is just a string of values + y_fit = data[x_fit] # y values to be fitted + y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation + data[x_interpol] = y_interpol # replace in data + + if np.mod(ii, 100) == 0: # talk to the operator every 100th trial + print(f'stimulation event {ii} \n') + + if debug_mode: + # plot signal with interpolated artifact + plt.figure() + plt.plot(xx, data[trigger_indices[test_trial] + plot_range[0]: trigger_indices[test_trial] + plot_range[1]]) + plt.title('After Correction') + + plt.show() + + return data diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact.py b/mne/preprocessing/pca_obs/rm_heart_artefact.py new file mode 100755 index 00000000000..0de2a73358e --- /dev/null +++ b/mne/preprocessing/pca_obs/rm_heart_artefact.py @@ -0,0 +1,130 @@ +# Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component +# Analysis, Optimal Basis Sets) + +import os +from scipy.io import loadmat +from scipy.signal import firls +from PCA_OBS import * +from get_conditioninfo import * +from get_channels import * + + +def rm_heart_artefact(subject, condition, srmr_nr, sampling_rate, pchip): + matlab = False # If this is true, use the data 'prepared' by matlab - testing to see where hump at 0 comes from + # Incredibly slow without parallelization + # Set variables + subject_id = f'sub-{str(subject).zfill(3)}' + cond_info = get_conditioninfo(condition, srmr_nr) + nblocks = cond_info.nblocks + cond_name = cond_info.cond_name + stimulation = cond_info.stimulation + trigger_name = cond_info.trigger_name + + # Setting paths + input_path = "/data/pt_02569/tmp_data/prepared_py/"+subject_id+"/esg/prepro/" + input_path_m = "/data/pt_02569/tmp_data/prepared/"+subject_id+"/esg/prepro/" + save_path = "/data/pt_02569/tmp_data/ecg_rm_py/"+subject_id+"/esg/prepro/" + os.makedirs(save_path, exist_ok=True) + + figure_path = save_path + + # For debugging just test one channel of each + debug_channel = ['S35'] + _, esg_chans, _ = get_channels(subject, False, False, srmr_nr) # Ignoring ECG and EOG channels + + # Dyanmically set filename + if matlab: + fname = f"raw_{sampling_rate}_spinal_{cond_name}.set" + raw = mne.io.read_raw_eeglab(input_path_m + fname, preload=True) + else: + if pchip: + fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" + else: + fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs" + # Read .fif file from the previous step (import_data) + raw = mne.io.read_raw_fif(input_path + fname + '.fif', preload=True) + + # Read .mat file with QRS events + fname_m = f"raw_{sampling_rate}_spinal_{cond_name}" + matdata = loadmat(input_path_m+fname_m+'.mat') + QRSevents_m = matdata['QRSevents'] + fwts = matdata['fwts'] + + # Read .h5 file with alternative QRS events + # with h5py.File(input_path+fname+'.h5', "r") as infile: + # QRSevents_p = infile["QRS"][()] + + # Create filter coefficients + fs = sampling_rate + a = [0, 0, 1, 1] + f = [0, 0.4/(fs/2), 0.9/(fs / 2), 1] # 0.9 Hz highpass filter + # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter + ord = round(3*fs/0.5) + fwts = firls(ord+1, f, a) + + # Run once with a single channel and debug_mode = True to get window information + for ch in debug_channel: + # set PCA_OBS input variables + channelNames = ['S35', 'Iz', 'SC1', 'S3', 'SC6', 'S20', 'L1', 'L4'] + # these channels will be plotted(only for debugging / testing) + + # run PCA_OBS + if pchip: + name = 'pca_chan_' + ch + '_pchip' + else: + name = 'pca_chan_'+ch + PCA_OBS_kwargs = dict( + debug_mode=True, qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate, + savename=save_path+name, + ch_names=channelNames, sub_nr=subject_id, + condition=cond_name, current_channel=ch + ) + # Apply function modifies the data in raw in place + raw.copy().apply_function(PCA_OBS, picks=[ch], **PCA_OBS_kwargs) + + # This information is the same for each channel - run through fitting once to get vals, add to all channels + keywords = ['window_start_idx', 'window_end_idx'] + fn = f"{save_path}pca_chan_{ch}_{cond_name}_pca_info.h5" + with h5py.File(fn, "r") as infile: + # Get the data + window_start = infile[keywords[0]][()].reshape(-1) + window_end = infile[keywords[1]][()].reshape(-1) + + onset = [x/sampling_rate for x in window_start] # Divide by sampling rate to make times + duration = np.repeat(0.0, len(window_start)) + description = ['fit_start'] * len(window_start) + raw.annotations.append(onset, duration, description, ch_names=[esg_chans] * len(window_start)) + + onset = [x/sampling_rate for x in window_end] + duration = np.repeat(0.0, len(window_end)) + description = ['fit_end'] * len(window_end) + raw.annotations.append(onset, duration, description, ch_names=[esg_chans]*len(window_end)) + + # Then run parallel for all channels with n_jobs set and debug_mode = False + # set PCA_OBS input variables + channelNames = ['S35', 'Iz', 'SC1', 'S3', 'SC6', 'S20', 'L1', 'L4'] + # these channels will be plotted(only for debugging / testing) + + # run PCA_OBS + # In this case ch_names, current_channel and savename are dummy vars - not necessary really + PCA_OBS_kwargs = dict( + debug_mode=False, qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate, + savename=save_path + 'pca_chan', + ch_names=channelNames, sub_nr=subject_id, + condition=cond_name, current_channel=ch + ) + + # Apply function should modifies the data in raw in place + raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans)) + + # Save the new mne structure with the cleaned data + if matlab: + raw.save(os.path.join(save_path, f'data_clean_ecg_spinal_{cond_name}_withqrs_mat.fif'), fmt='double', + overwrite=True) + else: + if pchip: + raw.save(os.path.join(save_path, f'data_clean_ecg_spinal_{cond_name}_withqrs_pchip.fif'), fmt='double', + overwrite=True) + else: + raw.save(os.path.join(save_path, f'data_clean_ecg_spinal_{cond_name}_withqrs.fif'), fmt='double', + overwrite=True) From 6d5f5b234e2b25626dfffd8ab1a4373e2792e13a Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Fri, 6 Sep 2024 16:15:26 +0200 Subject: [PATCH 2/4] Minimum working example with local data --- mne/preprocessing/pca_obs/PCA_OBS.py | 99 +------------ mne/preprocessing/pca_obs/fit_ecgTemplate.py | 7 - .../pca_obs/pchip_interpolation.py | 21 +-- .../pca_obs/rm_heart_artefact.py | 132 +++++------------- 4 files changed, 42 insertions(+), 217 deletions(-) diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py index 31ff3b4d06b..9209da8a6f0 100755 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ b/mne/preprocessing/pca_obs/PCA_OBS.py @@ -1,5 +1,5 @@ import numpy as np -import mne +# import mne from scipy.signal import filtfilt, detrend import matplotlib.pyplot as plt from sklearn.decomposition import PCA @@ -19,32 +19,16 @@ def __init__(self): pca_info = PCAInfo() # Check all necessary arguments sent in - required_kws = ["debug_mode", "qrs", "filter_coords", "sr", "savename", "ch_names", "sub_nr", "condition", - "current_channel"] + required_kws = ["qrs", "filter_coords", "sr"] assert all([kw in kwargs.keys() for kw in required_kws]), "Error. Some KWs not passed into PCA_OBS." # Extract all kwargs - debug_mode = kwargs['debug_mode'] qrs = kwargs['qrs'] filter_coords = kwargs['filter_coords'] sr = kwargs['sr'] - ch_names = kwargs['ch_names'] - sub_nr = kwargs['sub_nr'] - condition = kwargs['condition'] - if debug_mode: # Only need current channel and saving if we're debugging - current_channel = kwargs['current_channel'] - savename = kwargs['savename'] fs = sr - # Standard delay between QRS peak and artifact - delay = 0 - - Gwindow = 2 - GHW = np.floor(Gwindow / 2).astype('int') - rcount = 0 - firstplot = 1 - # set to baseline data = data.reshape(-1, 1) data = data.T @@ -58,9 +42,6 @@ def __init__(self): for idx in qrs[0]: if idx < len(peakplot[0, :]): peakplot[0, idx] = 1 # logical indexed locations of qrs events - # sh = np.zeros((1, delay)) - # np1 = len(peakplot) - # peakplot = [sh, peakplot[0:np1 - delay]] # shifts indexed array by the delay - skipped here since delay=0 peak_idx = np.nonzero(peakplot)[1] # Selecting indices along columns peak_idx = peak_idx.reshape(-1, 1) @@ -117,34 +98,6 @@ def __init__(self): # define selected number of components using profile likelihood pca_info.nComponents = 4 - - # Creates plots - if debug_mode: - # plot pca variables figure - comp2plot = pca_info.nComponents - fig, axs = plt.subplots(2, 2) - for a in np.arange(comp2plot): - axs[0, 0].plot(pca_info.eigen_vectors[:, a], label=f"Data {a}") - axs[1, 1].plot(pca_info.factor_loadings[:, a], label=f"Data {a}") - axs[0, 0].set_title('Evec') - axs[1, 1].set_title('Factor Loadings') - axs[1, 1].set(xlabel='time') - - axs[0, 1].plot(np.arange(len(pca_info.expl_var)), pca_info.expl_var, 'r*') - axs[0, 1].set(xlabel='components', ylabel='var explained (%)') - cum_explained = np.cumsum(pca_info.expl_var) - axs[0, 1].set_title(f"first {pca_info.nComponents} comp, {cum_explained[pca_info.nComponents]} % var") - - axs[1, 0].plot(pca_info.eigen_values) - axs[1, 0].set_title('eigenvalues') - axs[1, 0].set(xlabel='components') - - fig.suptitle(f"{sub_nr} thresholds PCA vars channel {current_channel}") - plt.tight_layout() - fig.savefig(f"{savename}_{condition}.jpg") - - if debug_mode: - pca_info.chan = current_channel pca_info.meanEffect = mean_effect.T nComponents = pca_info.nComponents @@ -154,20 +107,6 @@ def __init__(self): mean_effect = mean_effect.reshape(-1, 1) pca_template = np.c_[mean_effect, factor_loadings[:, 0:nComponents]] - # Plot template vars - if debug_mode: - # plot template vars - fig = plt.figure() - pcatime = (np.arange(-peak_range, peak_range+1))/fs - pcatime = pcatime.reshape(-1) - plt.plot(pcatime, std_effect) - plt.plot(pcatime, mean_effect) - plt.plot(pcatime, factor_loadings[:, 0: nComponents]) - plt.legend(['std effect', 'mean effect', 'PCA_1', 'PCA_2', 'PCA_3', 'PCA_4']) - fig.suptitle(f"{sub_nr} papc channel {current_channel}") - plt.tight_layout() - fig.savefig(f"{savename}_templateVars_{condition}.jpg") - ################################################################################### # Data Fitting ################################################################################### @@ -227,27 +166,6 @@ def __init__(self): except Exception as e: print(f"Cannot fit middle section of data. Reason: {e}") - # Plot some channels - if debug_mode: - # check with plot what has been done - # First check if this channel is one we want to plot for debugging - plotChannel = 0 - for ii in range(0,len(ch_names)): - if current_channel == ch_names[ii]: - plotChannel = 1 - - if plotChannel == 1: - fig = plt.figure() - plt.plot((np.arange(0, len(fitted_art[0, :]))/fs).reshape(-1, 1), data[:].T, zorder=0) - plt.plot((np.arange(0, len(fitted_art[0, :]))/fs).reshape(-1, 1), eegchan[:].T, 'r', zorder=5) - plt.plot((np.arange(0, len(fitted_art[0, :]))/fs).reshape(-1, 1), fitted_art[:].T, 'g', zorder=10) - plt.plot((np.arange(0, len(fitted_art[0, :]))/fs).reshape(-1, 1), (np.subtract(data[:], fitted_art[:])).T, 'm', zorder=15) - plt.legend(['raw data', 'filtered', 'fitted_art', 'clean'], loc='upper right').set_zorder(20) - plt.xlabel('time [s]') - plt.ylabel('amplitude [V]') - plt.title('Subject ' + sub_nr + ', channel ' + current_channel) - fig.savefig(f"{savename}_compareresults_{condition}.jpg") - # Actually subtract the artefact, return needs to be the same shape as input data # One sample shift purely due to the fact the r-peaks are currently detected in MATLAB data = data.reshape(-1) @@ -264,18 +182,5 @@ def __init__(self): # data -= fitted_art # data = data.T.reshape(-1) - # Can't add annotations for window start and end (sample number added) to mne raw structure here - # Add it to the pca info class and store it that way - # Can then access in the main rm_heart_artefact and create the annotations - # Only save the pca vars if we're in debug mode - if debug_mode: - pca_info.window_start_idx = window_start_idx - pca_info.window_end_idx = window_end_idx - dataset_keywords = [a for a in dir(pca_info) if not a.startswith('__')] - fn = f"{savename}_{condition}_pca_info.h5" - with h5py.File(fn, "w") as outfile: - for keyword in dataset_keywords: - outfile.create_dataset(keyword, data=getattr(pca_info, keyword)) - # Can only return data return data diff --git a/mne/preprocessing/pca_obs/fit_ecgTemplate.py b/mne/preprocessing/pca_obs/fit_ecgTemplate.py index 00ef6459fb6..11043b0c56d 100755 --- a/mne/preprocessing/pca_obs/fit_ecgTemplate.py +++ b/mne/preprocessing/pca_obs/fit_ecgTemplate.py @@ -68,11 +68,4 @@ def __init__(self): fitecg.y_interpol = y_interpol fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop - # # Save to file to compare to matlab - only for debugging - # dataset_keywords = [a for a in dir(fitecg) if not a.startswith('__')] - # fn = f"/data/pt_02569/tmp_data/test_data/fitecg_test_{aPeak_idx[0]}_sub-001_tibial_S35.h5" - # with h5py.File(fn, "w") as outfile: - # for keyword in dataset_keywords: - # outfile.create_dataset(keyword, data=getattr(fitecg, keyword)) - return fitted_art, post_idx_nextPeak diff --git a/mne/preprocessing/pca_obs/pchip_interpolation.py b/mne/preprocessing/pca_obs/pchip_interpolation.py index a35d1b75aed..e336f672363 100755 --- a/mne/preprocessing/pca_obs/pchip_interpolation.py +++ b/mne/preprocessing/pca_obs/pchip_interpolation.py @@ -1,6 +1,6 @@ # Function to interpolate based on PCHIP rather than MNE inbuilt linear option -import mne +# import mne import numpy as np from scipy.interpolate import PchipInterpolator as pchip import matplotlib.pyplot as plt @@ -8,22 +8,13 @@ def PCHIP_interpolation(data, **kwargs): # Check all necessary arguments sent in - required_kws = ["trigger_indices", "interpol_window_sec", "fs", "debug_mode"] + required_kws = ["trigger_indices", "interpol_window_sec", "fs"] assert all([kw in kwargs.keys() for kw in required_kws]), "Error. Some KWs not passed into PCA_OBS." # Extract all kwargs - more elegant ways to do this fs = kwargs['fs'] interpol_window_sec = kwargs['interpol_window_sec'] trigger_indices = kwargs['trigger_indices'] - debug_mode = kwargs['debug_mode'] - - if debug_mode: - plt.figure() - # plot signal with artifact - plot_range = [-50, 100] - test_trial = 100 - xx = (np.arange(plot_range[0], plot_range[1])) / fs * 1000 - plt.plot(xx, data[trigger_indices[test_trial] + plot_range[0]:trigger_indices[test_trial] + plot_range[1]]) # Convert intpol window to msec then convert to samples pre_window = round((interpol_window_sec[0]*1000) * fs / 1000) # in samples @@ -48,12 +39,4 @@ def PCHIP_interpolation(data, **kwargs): if np.mod(ii, 100) == 0: # talk to the operator every 100th trial print(f'stimulation event {ii} \n') - if debug_mode: - # plot signal with interpolated artifact - plt.figure() - plt.plot(xx, data[trigger_indices[test_trial] + plot_range[0]: trigger_indices[test_trial] + plot_range[1]]) - plt.title('After Correction') - - plt.show() - return data diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact.py b/mne/preprocessing/pca_obs/rm_heart_artefact.py index 0de2a73358e..f61a8b0e202 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact.py @@ -5,54 +5,35 @@ from scipy.io import loadmat from scipy.signal import firls from PCA_OBS import * -from get_conditioninfo import * -from get_channels import * +from mne.io import read_raw_fif +from mne import events_from_annotations, Epochs -def rm_heart_artefact(subject, condition, srmr_nr, sampling_rate, pchip): - matlab = False # If this is true, use the data 'prepared' by matlab - testing to see where hump at 0 comes from +if __name__ == '__main__': # Incredibly slow without parallelization # Set variables - subject_id = f'sub-{str(subject).zfill(3)}' - cond_info = get_conditioninfo(condition, srmr_nr) - nblocks = cond_info.nblocks - cond_name = cond_info.cond_name - stimulation = cond_info.stimulation - trigger_name = cond_info.trigger_name + subject_id = f'sub-001' + cond_name = 'median' + sampling_rate = 1000 + esg_chans = ['S35', 'S24', 'S36', 'Iz', 'S17', 'S15', 'S32', 'S22', + 'S19', 'S26', 'S28', 'S9', 'S13', 'S11', 'S7', 'SC1', 'S4', 'S18', + 'S8', 'S31', 'SC6', 'S12', 'S16', 'S5', 'S30', 'S20', 'S34', 'AC', + 'S21', 'S25', 'L1', 'S29', 'S14', 'S33', 'S3', 'AL', 'L4', 'S6', + 'S23'] + # For heartbeat epochs + iv_baseline = [-300 / 1000, -200 / 1000] + iv_epoch = [-400 / 1000, 600 / 1000] # Setting paths input_path = "/data/pt_02569/tmp_data/prepared_py/"+subject_id+"/esg/prepro/" - input_path_m = "/data/pt_02569/tmp_data/prepared/"+subject_id+"/esg/prepro/" - save_path = "/data/pt_02569/tmp_data/ecg_rm_py/"+subject_id+"/esg/prepro/" - os.makedirs(save_path, exist_ok=True) - - figure_path = save_path - - # For debugging just test one channel of each - debug_channel = ['S35'] - _, esg_chans, _ = get_channels(subject, False, False, srmr_nr) # Ignoring ECG and EOG channels - - # Dyanmically set filename - if matlab: - fname = f"raw_{sampling_rate}_spinal_{cond_name}.set" - raw = mne.io.read_raw_eeglab(input_path_m + fname, preload=True) - else: - if pchip: - fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" - else: - fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs" - # Read .fif file from the previous step (import_data) - raw = mne.io.read_raw_fif(input_path + fname + '.fif', preload=True) + input_path_m = "/data/pt_02569/tmp_data/prepared/" + subject_id + "/esg/prepro/" + fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" + raw = read_raw_fif(input_path + fname + '.fif', preload=True) # Read .mat file with QRS events fname_m = f"raw_{sampling_rate}_spinal_{cond_name}" matdata = loadmat(input_path_m+fname_m+'.mat') QRSevents_m = matdata['QRSevents'] - fwts = matdata['fwts'] - - # Read .h5 file with alternative QRS events - # with h5py.File(input_path+fname+'.h5', "r") as infile: - # QRSevents_p = infile["QRS"][()] # Create filter coefficients fs = sampling_rate @@ -62,69 +43,32 @@ def rm_heart_artefact(subject, condition, srmr_nr, sampling_rate, pchip): ord = round(3*fs/0.5) fwts = firls(ord+1, f, a) - # Run once with a single channel and debug_mode = True to get window information - for ch in debug_channel: - # set PCA_OBS input variables - channelNames = ['S35', 'Iz', 'SC1', 'S3', 'SC6', 'S20', 'L1', 'L4'] - # these channels will be plotted(only for debugging / testing) - - # run PCA_OBS - if pchip: - name = 'pca_chan_' + ch + '_pchip' - else: - name = 'pca_chan_'+ch - PCA_OBS_kwargs = dict( - debug_mode=True, qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate, - savename=save_path+name, - ch_names=channelNames, sub_nr=subject_id, - condition=cond_name, current_channel=ch - ) - # Apply function modifies the data in raw in place - raw.copy().apply_function(PCA_OBS, picks=[ch], **PCA_OBS_kwargs) - - # This information is the same for each channel - run through fitting once to get vals, add to all channels - keywords = ['window_start_idx', 'window_end_idx'] - fn = f"{save_path}pca_chan_{ch}_{cond_name}_pca_info.h5" - with h5py.File(fn, "r") as infile: - # Get the data - window_start = infile[keywords[0]][()].reshape(-1) - window_end = infile[keywords[1]][()].reshape(-1) - - onset = [x/sampling_rate for x in window_start] # Divide by sampling rate to make times - duration = np.repeat(0.0, len(window_start)) - description = ['fit_start'] * len(window_start) - raw.annotations.append(onset, duration, description, ch_names=[esg_chans] * len(window_start)) - - onset = [x/sampling_rate for x in window_end] - duration = np.repeat(0.0, len(window_end)) - description = ['fit_end'] * len(window_end) - raw.annotations.append(onset, duration, description, ch_names=[esg_chans]*len(window_end)) - - # Then run parallel for all channels with n_jobs set and debug_mode = False - # set PCA_OBS input variables - channelNames = ['S35', 'Iz', 'SC1', 'S3', 'SC6', 'S20', 'L1', 'L4'] - # these channels will be plotted(only for debugging / testing) - # run PCA_OBS # In this case ch_names, current_channel and savename are dummy vars - not necessary really PCA_OBS_kwargs = dict( - debug_mode=False, qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate, - savename=save_path + 'pca_chan', - ch_names=channelNames, sub_nr=subject_id, - condition=cond_name, current_channel=ch + qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate ) + events, event_ids = events_from_annotations(raw) + event_id_dict = {key: value for key, value in event_ids.items() if key == 'qrs'} + epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_before = epochs.average() + # Apply function should modifies the data in raw in place raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans)) - # Save the new mne structure with the cleaned data - if matlab: - raw.save(os.path.join(save_path, f'data_clean_ecg_spinal_{cond_name}_withqrs_mat.fif'), fmt='double', - overwrite=True) - else: - if pchip: - raw.save(os.path.join(save_path, f'data_clean_ecg_spinal_{cond_name}_withqrs_pchip.fif'), fmt='double', - overwrite=True) - else: - raw.save(os.path.join(save_path, f'data_clean_ecg_spinal_{cond_name}_withqrs.fif'), fmt='double', - overwrite=True) + epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_after = epochs.average() + + # Comparison image + fig, axes = plt.subplots(2, 1) + axes[0].plot(evoked_before.times, evoked_before.get_data().T) + axes[0].set_ylim([-0.0005, 0.001]) + axes[0].set_title('Before PCA-OBS') + axes[1].plot(evoked_after.times, evoked_after.get_data().T) + axes[1].set_ylim([-0.0005, 0.001]) + axes[1].set_title('After PCA-OBS') + plt.tight_layout() + plt.show() From 362254062b6472a5317a94b3715a53a487f56eb6 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 9 Sep 2024 16:15:42 +0200 Subject: [PATCH 3/4] Implement testing dataset --- mne/preprocessing/pca_obs/PCA_OBS.py | 15 ++-- .../pca_obs/rm_heart_artefact.py | 10 ++- .../pca_obs/rm_heart_artefact_mnedata.py | 77 +++++++++++++++++++ 3 files changed, 92 insertions(+), 10 deletions(-) create mode 100644 mne/preprocessing/pca_obs/rm_heart_artefact_mnedata.py diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py index 9209da8a6f0..aee06165d11 100755 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ b/mne/preprocessing/pca_obs/PCA_OBS.py @@ -171,16 +171,15 @@ def __init__(self): data = data.reshape(-1) fitted_art = fitted_art.reshape(-1) - # data -= fitted_art - - data_ = np.zeros(len(data)) - data_[0] = data[0] - data_[1:] = data[1:] - fitted_art[:-1] - data = data_ + # One sample shift for my actual data (introduced using matlab r timings) + # data_ = np.zeros(len(data)) + # data_[0] = data[0] + # data_[1:] = data[1:] - fitted_art[:-1] + # data = data_ # Original code is this: - # data -= fitted_art - # data = data.T.reshape(-1) + data -= fitted_art + data = data.T.reshape(-1) # Can only return data return data diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact.py b/mne/preprocessing/pca_obs/rm_heart_artefact.py index f61a8b0e202..1b559ef9f22 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact.py @@ -44,7 +44,6 @@ fwts = firls(ord+1, f, a) # run PCA_OBS - # In this case ch_names, current_channel and savename are dummy vars - not necessary really PCA_OBS_kwargs = dict( qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate ) @@ -57,7 +56,6 @@ # Apply function should modifies the data in raw in place raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans)) - epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline)) evoked_after = epochs.average() @@ -71,4 +69,12 @@ axes[1].set_ylim([-0.0005, 0.001]) axes[1].set_title('After PCA-OBS') plt.tight_layout() + + # Comparison image + fig, axes = plt.subplots(1, 1) + axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') + axes.set_ylim([-0.0005, 0.001]) + axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') + axes.set_title('Before (black) versus after (green)') + plt.tight_layout() plt.show() diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_mnedata.py new file mode 100644 index 00000000000..42e58e99463 --- /dev/null +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_mnedata.py @@ -0,0 +1,77 @@ +# Checking algorithm implementation with EEG data form MNE sample datasets +# Following this tutorial data: https://mne.tools/stable/auto_tutorials/preprocessing/50_artifact_correction_ssp.html#what-is-ssp + +from mne.preprocessing import ( + find_ecg_events, + create_ecg_epochs, +) +from mne.io import read_raw_fif +from mne.datasets.sample import data_path +from PCA_OBS import * +from mne import Epochs +import os +import numpy as np +from scipy.signal import firls +import matplotlib.pyplot as plt + +if __name__ == '__main__': + sample_data_folder = data_path(path='/data/pt_02569/mne_test_data/') + sample_data_raw_file = os.path.join( + sample_data_folder, "MEG", "sample", "sample_audvis_raw.fif" + ) + # here we crop and resample just for speed + raw = read_raw_fif(sample_data_raw_file, preload=True) + + # Find ECG events - no ECG channel in data, uses synthetic + ecg_events, ch_ecg, average_pulse, = find_ecg_events(raw) + # Extract just sample timings of ecg events + ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) + # print(ecg_events) + + # Create filter coefficients + fs = raw.info['sfreq'] + a = [0, 0, 1, 1] + f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter + # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter + ord = round(3 * fs / 0.5) + fwts = firls(ord + 1, f, a) + + # For heartbeat epochs + iv_baseline = [-300 / 1000, -200 / 1000] + iv_epoch = [-400 / 1000, 600 / 1000] + + # run PCA_OBS + # Algorithm is extremely sensitive to accurate R-peak timings, won't work as well with the artificial ECG + # channel estimation as we have here + PCA_OBS_kwargs = dict( + qrs=ecg_event_samples, filter_coords=fwts, sr=fs + ) + + epochs = Epochs(raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_before = epochs.average() + + # Apply function should modifies the data in raw in place + raw.apply_function(PCA_OBS, picks='eeg', **PCA_OBS_kwargs, n_jobs=10) + epochs = Epochs(raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_after = epochs.average() + + # Comparison image + fig, axes = plt.subplots(2, 1) + axes[0].plot(evoked_before.times, evoked_before.get_data().T) + axes[0].set_ylim([-1e-5, 3e-5]) + axes[0].set_title('Before PCA-OBS') + axes[1].plot(evoked_after.times, evoked_after.get_data().T) + axes[1].set_ylim([-1e-5, 3e-5]) + axes[1].set_title('After PCA-OBS') + plt.tight_layout() + + # Comparison image + fig, axes = plt.subplots(1, 1) + axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') + axes.set_ylim([-1e-5, 3e-5]) + axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') + axes.set_title('Before (black) versus after (green)') + plt.tight_layout() + plt.show() From 6c42f33f5fca1dec460fa902b5b6a3c21b6d4131 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 7 Oct 2024 15:43:26 +0200 Subject: [PATCH 4/4] Feat: Update examples --- ... => rm_heart_artefact_cortical_mnedata.py} | 0 ...rm_heart_artefact_spinal_impreciserpeak.py | 80 +++++++++++++++++++ ... rm_heart_artefact_spinal_preciserpeak.py} | 0 3 files changed, 80 insertions(+) rename mne/preprocessing/pca_obs/{rm_heart_artefact_mnedata.py => rm_heart_artefact_cortical_mnedata.py} (100%) create mode 100755 mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py rename mne/preprocessing/pca_obs/{rm_heart_artefact.py => rm_heart_artefact_spinal_preciserpeak.py} (100%) diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py similarity index 100% rename from mne/preprocessing/pca_obs/rm_heart_artefact_mnedata.py rename to mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py new file mode 100755 index 00000000000..7634096c909 --- /dev/null +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -0,0 +1,80 @@ +# Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component +# Analysis, Optimal Basis Sets) + +import os +from scipy.io import loadmat +from scipy.signal import firls +from PCA_OBS import * +from mne.io import read_raw_fif +from mne import events_from_annotations, Epochs +from mne.preprocessing import find_ecg_events + + +if __name__ == '__main__': + # Incredibly slow without parallelization + # Set variables + subject_id = f'sub-001' + cond_name = 'median' + sampling_rate = 1000 + esg_chans = ['S35', 'S24', 'S36', 'Iz', 'S17', 'S15', 'S32', 'S22', + 'S19', 'S26', 'S28', 'S9', 'S13', 'S11', 'S7', 'SC1', 'S4', 'S18', + 'S8', 'S31', 'SC6', 'S12', 'S16', 'S5', 'S30', 'S20', 'S34', 'AC', + 'S21', 'S25', 'L1', 'S29', 'S14', 'S33', 'S3', 'AL', 'L4', 'S6', + 'S23'] + # For heartbeat epochs + iv_baseline = [-300 / 1000, -200 / 1000] + iv_epoch = [-400 / 1000, 600 / 1000] + + # Setting paths + input_path = "/data/pt_02569/tmp_data/prepared_py/"+subject_id+"/esg/prepro/" + fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" + raw = read_raw_fif(input_path + fname + '.fif', preload=True) + + # Find ECG events - no ECG channel in data, uses synthetic + ecg_events, ch_ecg, average_pulse, = find_ecg_events(raw, ch_name='ECG') + # Extract just sample timings of ecg events + ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) + + # Create filter coefficients + fs = sampling_rate + a = [0, 0, 1, 1] + f = [0, 0.4/(fs/2), 0.9/(fs / 2), 1] # 0.9 Hz highpass filter + # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter + ord = round(3*fs/0.5) + fwts = firls(ord+1, f, a) + + # run PCA_OBS + PCA_OBS_kwargs = dict( + qrs=ecg_event_samples, filter_coords=fwts, sr=sampling_rate + ) + + events, event_ids = events_from_annotations(raw) + event_id_dict = {key: value for key, value in event_ids.items() if key == 'qrs'} + epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_before = epochs.average() + + # Apply function should modifies the data in raw in place + raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans)) + epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_after = epochs.average() + + # Comparison image + fig, axes = plt.subplots(2, 1) + axes[0].plot(evoked_before.times, evoked_before.get_data().T) + axes[0].set_ylim([-0.0005, 0.001]) + axes[0].set_title('Before PCA-OBS') + axes[1].plot(evoked_after.times, evoked_after.get_data().T) + axes[1].set_ylim([-0.0005, 0.001]) + axes[1].set_title('After PCA-OBS') + plt.tight_layout() + + # Comparison image + fig, axes = plt.subplots(1, 1) + axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') + axes.set_ylim([-0.0005, 0.001]) + axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') + axes.set_title('Before (black) versus after (green)') + plt.tight_layout() + plt.show() diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py similarity index 100% rename from mne/preprocessing/pca_obs/rm_heart_artefact.py rename to mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py