import numpy as np
from scipy.interpolate import PchipInterpolator as pchip
from scipy.signal import detrend, filtfilt


def PCA_OBS(data, **kwargs):
    """Remove the cardiac artefact from one channel with PCA-OBS.

    Heartbeat-locked epochs of the high-pass filtered signal are decomposed
    with PCA; the epoch mean plus the first four factor loadings form a
    template that is least-squares fitted to the data around every QRS event
    (see ``fit_ecgTemplate``) and subtracted.

    Parameters
    ----------
    data : array
        Samples of a single channel; flattened to one row internally.
    **kwargs
        Required: ``qrs`` (QRS sample indices are read from ``qrs[0]``),
        ``filter_coords`` (FIR numerator passed to ``filtfilt``) and ``sr``
        (sampling rate, only used for debug plot axes).
        Optional: ``debug_mode`` (default ``False``). When it is true,
        ``savename``, ``ch_names``, ``sub_nr``, ``condition`` and
        ``current_channel`` are required as well; figures and an HDF5 file
        with the PCA variables are then written next to ``savename``.

    Returns
    -------
    ndarray
        1-D array with the channel mean and the fitted artefact removed.

    Raises
    ------
    AssertionError
        If a required keyword is missing.
    ValueError
        If fewer than two usable QRS events lie inside the data.
    """
    required_kws = ("qrs", "filter_coords", "sr")
    debug_kws = ("savename", "ch_names", "sub_nr", "condition",
                 "current_channel")
    debug_mode = bool(kwargs.get("debug_mode", False))
    needed = required_kws + (debug_kws if debug_mode else ())
    missing = [kw for kw in needed if kw not in kwargs]
    if missing:
        # Same exception type as the former ``assert`` so existing callers
        # are unaffected, but raised explicitly so it survives ``python -O``.
        raise AssertionError(
            f"Error. Some KWs not passed into PCA_OBS. Missing: {missing}")

    # Imported here so that importing this module does not require
    # scikit-learn (and matplotlib / h5py, which only the debug helpers use).
    from sklearn.decomposition import PCA

    qrs = kwargs["qrs"]
    filter_coords = kwargs["filter_coords"]
    fs = kwargs["sr"]
    dbg = {kw: kwargs[kw] for kw in debug_kws} if debug_mode else {}

    # One row of samples, set to baseline.
    data = np.asarray(data).reshape(1, -1)
    data = data - data.mean(axis=1, keepdims=True)
    n_samples = data.shape[1]
    fitted_art = np.zeros(data.shape)

    # Sorted, de-duplicated QRS samples that lie inside the data. Cast to a
    # signed integer type so later subtractions cannot wrap around.
    qrs_samples = np.asarray(qrs[0]).astype(np.int64).reshape(-1)
    in_data = (qrs_samples >= 0) & (qrs_samples < n_samples)
    peak_idx = np.unique(qrs_samples[in_data]).reshape(-1, 1)
    peak_count = len(peak_idx)
    if peak_count < 2:
        raise ValueError(
            "PCA_OBS needs at least two QRS events inside the data.")

    print('Pulse artifact subtraction in progress...Please wait!')

    # Half the median RR interval defines the epoch half-width.
    mRR = np.median(np.diff(peak_idx[:, 0]))
    peak_range = int(round(mRR / 2))
    midP = peak_range + 1
    baseline_range = [0, round(peak_range / 8)]
    # Samples on either side used for interpolation between fitted windows.
    n_samples_fit = round(peak_range / 8)

    # Drop trailing peaks whose window [peak - range, peak + range] does not
    # fit: the last index used is peak + range, which must be < n_samples.
    while peak_count > 0 and \
            peak_idx[peak_count - 1, 0] + peak_range >= n_samples:
        peak_count -= 1
    if peak_count < 2:
        raise ValueError(
            "PCA_OBS needs at least two QRS events whose epoch window lies "
            "inside the data.")

    # Filter channel
    eegchan = filtfilt(filter_coords, 1, data)

    # PCA matrix [epoch x time]. As before the first peak is not used; peaks
    # whose window would start before sample 0 are skipped as well.
    epoch_centres = [int(c) for c in peak_idx[1:peak_count, 0]
                     if c - peak_range >= 0]
    if not epoch_centres:
        raise ValueError("No complete heartbeat epoch available for PCA.")
    pcamat = np.array([eegchan[0, c - peak_range:c + peak_range + 1]
                       for c in epoch_centres])

    # Remove each epoch's own mean. No second detrend is needed: sklearn's
    # PCA centres every time point across epochs before the SVD.
    pcamat = detrend(pcamat, type='constant', axis=1)
    mean_effect = np.mean(pcamat, axis=0)
    std_effect = np.std(pcamat, axis=0)

    pca = PCA(svd_solver="full")
    pca.fit(pcamat)
    factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_)
    nComponents = 4
    pca_info = {
        "eigen_vectors": pca.components_,
        "factor_loadings": factor_loadings,
        "eigen_values": pca.explained_variance_,
        "expl_var": pca.explained_variance_ratio_,
        "nComponents": nComponents,
        "meanEffect": mean_effect,
    }
    if debug_mode:
        pca_info["chan"] = dbg["current_channel"]
        _plot_pca_vars(pca_info, dbg)
        _plot_template_vars(std_effect, mean_effect, factor_loadings,
                            nComponents, peak_range, fs, dbg)

    # Template of the ECG artefact: epoch mean + leading factor loadings.
    pca_template = np.c_[mean_effect.reshape(-1, 1),
                         factor_loadings[:, 0:nComponents]]

    # ------------------------------------------------------------------
    # Data fitting
    # ------------------------------------------------------------------
    window_start_idx = []
    window_end_idx = []
    post_idx_nextPeak = []
    last = peak_count - 1
    for p in range(peak_count):
        # The fitted range around a peak reaches at most half-way to the
        # neighbouring peak, and never further than peak_range.
        if p == 0:
            pre_range = peak_range
            label = 'Cannot fit first ECG epoch'
        else:
            half_gap = int(peak_idx[p, 0] - peak_idx[p - 1, 0]) // 2
            pre_range = min(peak_range, half_gap)
            label = 'Cannot fit middle section of data'
        if p == last:
            print('On last section - almost there!')
            post_range = peak_range
            label = 'Cannot fit last ECG epoch'
        else:
            half_gap = int(peak_idx[p + 1, 0] - peak_idx[p, 0]) // 2
            post_range = min(peak_range, half_gap)

        try:
            fitted_art, post_idx_nextPeak = fit_ecgTemplate(
                data, pca_template, peak_idx[p], peak_range, pre_range,
                post_range, baseline_range, midP, fitted_art,
                post_idx_nextPeak, n_samples_fit)
        except (ValueError, IndexError, np.linalg.LinAlgError) as e:
            # Best effort: an epoch that cannot be fitted is reported and
            # skipped, the remaining heartbeats are still corrected.
            print(f'{label}. Reason: {e}')
        else:
            window_start_idx.append(peak_idx[p] - peak_range)
            window_end_idx.append(peak_idx[p] + peak_range)

    if debug_mode and dbg["current_channel"] in dbg["ch_names"]:
        _plot_fit_comparison(data, eegchan, fitted_art, fs, dbg)

    # Subtract the artefact; the return must be 1-D like the input.
    # One sample shift purely due to the fact the r-peaks are currently
    # detected in MATLAB (fit_ecgTemplate writes the fit one sample early).
    data = data.reshape(-1)
    fitted_art = fitted_art.reshape(-1)
    cleaned = np.zeros(len(data))
    cleaned[0] = data[0]
    cleaned[1:] = data[1:] - fitted_art[:-1]

    # The fit windows cannot be attached to an mne Raw from in here, so in
    # debug mode they are stored with the PCA variables for the caller.
    if debug_mode:
        pca_info["window_start_idx"] = window_start_idx
        pca_info["window_end_idx"] = window_end_idx
        _save_pca_info(pca_info, dbg)

    return cleaned


def _plot_pca_vars(pca_info, dbg):
    """Save a 2x2 overview figure of the PCA variables (debug only)."""
    import matplotlib.pyplot as plt

    n_comp = pca_info["nComponents"]
    expl_var = pca_info["expl_var"]
    shown = min(n_comp, len(expl_var))
    fig, axs = plt.subplots(2, 2)
    for a in range(shown):
        # sklearn stores one component per ROW of ``components_``.
        axs[0, 0].plot(pca_info["eigen_vectors"][a], label=f"Data {a}")
        axs[1, 1].plot(pca_info["factor_loadings"][:, a], label=f"Data {a}")
    axs[0, 0].set_title('Evec')
    axs[1, 1].set_title('Factor Loadings')
    axs[1, 1].set(xlabel='time')

    axs[0, 1].plot(np.arange(len(expl_var)), expl_var, 'r*')
    axs[0, 1].set(xlabel='components', ylabel='var explained (%)')
    cum_explained = np.cumsum(expl_var)
    # Cumulative variance of the first n components sits at index n - 1.
    axs[0, 1].set_title(
        f"first {n_comp} comp, {cum_explained[shown - 1]} % var")

    axs[1, 0].plot(pca_info["eigen_values"])
    axs[1, 0].set_title('eigenvalues')
    axs[1, 0].set(xlabel='components')

    fig.suptitle(
        f"{dbg['sub_nr']} thresholds PCA vars channel "
        f"{dbg['current_channel']}")
    plt.tight_layout()
    fig.savefig(f"{dbg['savename']}_{dbg['condition']}.jpg")
    plt.close(fig)


def _plot_template_vars(std_effect, mean_effect, factor_loadings,
                        nComponents, peak_range, fs, dbg):
    """Save a figure of the template ingredients (debug only)."""
    import matplotlib.pyplot as plt

    fig = plt.figure()
    pcatime = np.arange(-peak_range, peak_range + 1) / fs
    plt.plot(pcatime, std_effect)
    plt.plot(pcatime, mean_effect)
    plt.plot(pcatime, factor_loadings[:, 0:nComponents])
    plt.legend(['std effect', 'mean effect', 'PCA_1', 'PCA_2', 'PCA_3',
                'PCA_4'])
    fig.suptitle(f"{dbg['sub_nr']} papc channel {dbg['current_channel']}")
    plt.tight_layout()
    fig.savefig(f"{dbg['savename']}_templateVars_{dbg['condition']}.jpg")
    plt.close(fig)


def _plot_fit_comparison(data, eegchan, fitted_art, fs, dbg):
    """Save raw / filtered / fitted / cleaned traces (debug only)."""
    import matplotlib.pyplot as plt

    times = np.arange(fitted_art.shape[1]) / fs
    fig = plt.figure()
    plt.plot(times, data[0], zorder=0)
    plt.plot(times, eegchan[0], 'r', zorder=5)
    plt.plot(times, fitted_art[0], 'g', zorder=10)
    plt.plot(times, data[0] - fitted_art[0], 'm', zorder=15)
    plt.legend(['raw data', 'filtered', 'fitted_art', 'clean'],
               loc='upper right').set_zorder(20)
    plt.xlabel('time [s]')
    plt.ylabel('amplitude [V]')
    plt.title('Subject ' + dbg['sub_nr'] + ', channel '
              + dbg['current_channel'])
    fig.savefig(f"{dbg['savename']}_compareresults_{dbg['condition']}.jpg")
    plt.close(fig)


def _save_pca_info(pca_info, dbg):
    """Write every entry of ``pca_info`` as one HDF5 dataset (debug only)."""
    import h5py

    fn = f"{dbg['savename']}_{dbg['condition']}_pca_info.h5"
    with h5py.File(fn, "w") as outfile:
        for keyword, value in pca_info.items():
            outfile.create_dataset(keyword, data=value)


def fit_ecgTemplate(data, pca_template, aPeak_idx, peak_range, pre_range,
                    post_range, baseline_range, midP, fitted_art,
                    post_idx_previousPeak, n_samples_fit):
    """Fit the artefact template to one heartbeat and bridge to the last one.

    The epoch ``peak - peak_range .. peak + peak_range`` of ``data`` is
    mean-removed and projected onto the template columns by least squares.
    The central part of that fit is written into ``fitted_art`` (in place).
    If a previous fit end is given, the samples between it and the start of
    this fit are replaced by a PCHIP interpolation anchored on
    ``n_samples_fit`` samples either side.

    Parameters
    ----------
    data : ndarray, shape (1, n_samples)
    pca_template : ndarray, shape (2 * peak_range + 1, n_basis)
    aPeak_idx : sequence
        ``aPeak_idx[0]`` is the QRS sample of this heartbeat.
    peak_range : int
        Half-width of the fitted epoch.
    pre_range, post_range : int
        Number of samples before / after the peak that receive the fit.
    baseline_range : unused, kept for interface compatibility.
    midP : int
        One-based centre of the template (``peak_range + 1``).
    fitted_art : ndarray, shape (1, n_samples)
        Accumulated artefact estimate; modified in place and returned.
    post_idx_previousPeak : sequence
        Empty, or ``[end sample of the previous fit]``.
    n_samples_fit : int

    Returns
    -------
    fitted_art : ndarray
    post_idx_nextPeak : list
        ``[peak + post_range]``, to be passed in for the next heartbeat.

    Raises
    ------
    ValueError
        If the epoch or the fitted range does not lie inside ``data``.
    """
    peak = int(aPeak_idx[0])
    win_start = peak - peak_range
    win_stop = peak + peak_range + 1
    # NOTE: the fit is written starting one sample before peak - pre_range;
    # PCA_OBS compensates for this when it subtracts the artefact.
    art_start = peak - pre_range - 1
    if win_start < 0 or art_start < 0 or win_stop > data.shape[1]:
        # Checked up front so nothing is written for an unusable epoch
        # (a negative start would otherwise wrap around to the array end).
        raise ValueError(
            f"epoch window [{win_start}, {win_stop}) around peak {peak} "
            f"does not fit into {data.shape[1]} samples")

    # select window of template
    template = pca_template[midP - peak_range - 1:midP + peak_range + 1, :]

    # select window of data and remove its mean
    epoch = detrend(data[0, win_start:win_stop].reshape(-1), type='constant')

    # map the data onto the template and back into sensor space
    coefficients = np.linalg.lstsq(template, epoch, rcond=None)[0]
    pad_fit = np.dot(template, coefficients)

    fitted_art[0, art_start:peak + post_range] = \
        pad_fit[midP - pre_range - 1:midP + post_range]

    post_idx_nextPeak = [peak + post_range]

    if len(post_idx_previousPeak) != 0:
        # Interpolate the gap between the previous fit and this one.
        gap_start = int(np.ceil(post_idx_previousPeak[0]))
        gap_stop = peak - pre_range
        if gap_start < gap_stop:
            # x_fit: support points on either side of the gap (inclusive of
            # its end points); x_interpol: every sample of the gap.
            x_interpol = np.arange(gap_start, gap_stop + 1)
            x_fit = np.concatenate([
                np.arange(gap_start - n_samples_fit, gap_start + 1),
                np.arange(gap_stop, gap_stop + n_samples_fit + 1)])
            y_fit = fitted_art[0, x_fit]
            fitted_art[0, gap_start:gap_stop + 1] = \
                pchip(x_fit, y_fit)(x_interpol)

    return fitted_art, post_idx_nextPeak
# Function to interpolate based on PCHIP rather than MNE inbuilt linear
# option, followed by the driver that removes the heart artefact via PCA_OBS
# (Principal Component Analysis, Optimal Basis Sets).

import os

import numpy as np
from scipy.interpolate import PchipInterpolator as pchip
from scipy.io import loadmat
from scipy.signal import firls


def PCHIP_interpolation(data, **kwargs):
    """Replace a window around every trigger by a PCHIP interpolation.

    Parameters
    ----------
    data : ndarray, 1-D
        Samples of one channel; modified in place and returned.
    **kwargs
        Required: ``trigger_indices`` (sample index of each event),
        ``interpol_window_sec`` (``[start, end]`` of the window relative to
        the trigger, in seconds) and ``fs`` (sampling rate in Hz).
        Optional: ``debug_mode`` (default ``False``) plots one trial before
        and after the correction.

    Returns
    -------
    ndarray
        ``data`` with the windows interpolated.

    Raises
    ------
    AssertionError
        If a required keyword is missing.
    IndexError
        If the support points of an event do not lie inside ``data``.
    """
    required_kws = ("trigger_indices", "interpol_window_sec", "fs")
    missing = [kw for kw in required_kws if kw not in kwargs]
    if missing:
        # Same exception type as the former ``assert``; raised explicitly so
        # the check is not stripped under ``python -O``.
        raise AssertionError(
            "Error. Some KWs not passed into PCHIP_interpolation. "
            f"Missing: {missing}")

    fs = kwargs['fs']
    interpol_window_sec = kwargs['interpol_window_sec']
    trigger_indices = kwargs['trigger_indices']
    # Plotting needs at least one event to show.
    show_debug = bool(kwargs.get('debug_mode', False)) \
        and len(trigger_indices) > 0

    if show_debug:
        import matplotlib.pyplot as plt

        plot_range = [-50, 100]
        # Trial 100 is shown when there are that many events, else the last.
        test_trial = min(100, len(trigger_indices) - 1)
        shown = slice(int(trigger_indices[test_trial]) + plot_range[0],
                      int(trigger_indices[test_trial]) + plot_range[1])
        xx = np.arange(plot_range[0], plot_range[1]) / fs * 1000
        plt.figure()
        plt.plot(xx, data[shown])  # signal with artifact

    # Convert intpol window to msec then convert to samples
    pre_window = round((interpol_window_sec[0] * 1000) * fs / 1000)
    post_window = round((interpol_window_sec[1] * 1000) * fs / 1000)

    # number of samples before and after cut used for interpolation fit
    n_samples_fit = 5

    # Offsets relative to a trigger: support points on either side of the
    # window, and the samples of the window itself.
    x_fit_raw = np.concatenate([
        np.arange(pre_window - n_samples_fit - 1, pre_window),
        np.arange(post_window + 1, post_window + n_samples_fit + 2)])
    x_interpol_raw = np.arange(pre_window, post_window + 1)

    n_samples = len(data)
    for ii, trigger in enumerate(trigger_indices):
        # int(): unsigned trigger dtypes must not be mixed with the
        # negative offsets.
        trigger = int(trigger)
        x_fit = trigger + x_fit_raw
        x_interpol = trigger + x_interpol_raw

        if x_fit[0] < 0 or x_fit[-1] >= n_samples:
            # Indices past the end already raised IndexError; negative ones
            # used to wrap around silently and read the end of the signal.
            raise IndexError(
                f"stimulation event {ii} at sample {trigger}: interpolation "
                f"support [{x_fit[0]}, {x_fit[-1]}] lies outside the data "
                f"(0..{n_samples - 1})")

        y_fit = data[x_fit]
        data[x_interpol] = pchip(x_fit, y_fit)(x_interpol)

        if ii % 100 == 0:  # talk to the operator every 100th trial
            print(f'stimulation event {ii} \n')

    if show_debug:
        # plot signal with interpolated artifact
        plt.figure()
        plt.plot(xx, data[shown])
        plt.title('After Correction')
        plt.show()

    return data


def rm_heart_artefact(subject, condition, srmr_nr, sampling_rate, pchip):
    """Run PCA_OBS on all spinal channels of one recording and save it.

    The function reads the prepared recording and the MATLAB QRS events from
    the hard-coded ``/data/pt_02569/tmp_data`` tree, runs ``PCA_OBS`` once in
    debug mode on channel ``S35`` to obtain the fit windows (added to the
    recording as ``fit_start`` / ``fit_end`` annotations), then cleans every
    ESG channel and writes the result as a ``.fif`` file.

    Parameters
    ----------
    subject : int
        Subject number; zero-padded to three digits for the folder name.
    condition, srmr_nr
        Passed through to ``get_conditioninfo`` / ``get_channels``.
    sampling_rate : number
        Sampling rate in Hz; part of the file names and of the filter design.
    pchip : bool
        Whether to read (and write) the ``_pchip`` variant of the files.
        NOTE: inside this function the name refers to this flag, not to the
        interpolator imported at module level.
    """
    # Project-local / optional imports are resolved when the pipeline runs.
    # NOTE(review): get_conditioninfo and get_channels are assumed to be
    # provided by the modules of the same name - confirm.
    import h5py
    import mne
    from get_channels import get_channels
    from get_conditioninfo import get_conditioninfo
    from PCA_OBS import PCA_OBS

    # If this is true, use the data 'prepared' by matlab - testing to see
    # where hump at 0 comes from
    matlab = False
    subject_id = f'sub-{str(subject).zfill(3)}'
    cond_name = get_conditioninfo(condition, srmr_nr).cond_name

    # Setting paths
    input_path = "/data/pt_02569/tmp_data/prepared_py/" + subject_id + "/esg/prepro/"
    input_path_m = "/data/pt_02569/tmp_data/prepared/" + subject_id + "/esg/prepro/"
    save_path = "/data/pt_02569/tmp_data/ecg_rm_py/" + subject_id + "/esg/prepro/"
    os.makedirs(save_path, exist_ok=True)

    # For debugging just test one channel
    debug_channel = ['S35']
    # Ignoring ECG and EOG channels
    _, esg_chans, _ = get_channels(subject, False, False, srmr_nr)
    # these channels will be plotted (only for debugging / testing)
    channelNames = ['S35', 'Iz', 'SC1', 'S3', 'SC6', 'S20', 'L1', 'L4']
    variant = '_pchip' if pchip else ''

    # Dynamically set filename
    if matlab:
        fname = f"raw_{sampling_rate}_spinal_{cond_name}.set"
        raw = mne.io.read_raw_eeglab(input_path_m + fname, preload=True)
    else:
        # Read .fif file from the previous step (import_data)
        fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs{variant}"
        raw = mne.io.read_raw_fif(input_path + fname + '.fif', preload=True)

    # Read .mat file with QRS events
    fname_m = f"raw_{sampling_rate}_spinal_{cond_name}"
    QRSevents_m = loadmat(input_path_m + fname_m + '.mat')['QRSevents']

    # Create filter coefficients: 0.9 Hz highpass FIR filter
    fs = sampling_rate
    desired = [0, 0, 1, 1]
    bands = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1]
    filter_order = round(3 * fs / 0.5)
    fwts = firls(filter_order + 1, bands, desired)

    # Run once with a single channel and debug_mode = True to get window
    # information
    for ch in debug_channel:
        name = 'pca_chan_' + ch + variant
        PCA_OBS_kwargs = dict(
            debug_mode=True, qrs=QRSevents_m, filter_coords=fwts,
            sr=sampling_rate, savename=save_path + name,
            ch_names=channelNames, sub_nr=subject_id,
            condition=cond_name, current_channel=ch
        )
        # Work on a copy: this pass is only run for its debug output.
        raw.copy().apply_function(PCA_OBS, picks=[ch], **PCA_OBS_kwargs)

        # This information is the same for each channel - run through the
        # fitting once to get vals, add to all channels. The file name must
        # be built from ``name`` (including the pchip suffix), because that
        # is what PCA_OBS received as ``savename``.
        fn = f"{save_path}{name}_{cond_name}_pca_info.h5"
        with h5py.File(fn, "r") as infile:
            window_start = infile['window_start_idx'][()].reshape(-1)
            window_end = infile['window_end_idx'][()].reshape(-1)

        for samples, label in ((window_start, 'fit_start'),
                               (window_end, 'fit_end')):
            # Divide by sampling rate to make times
            onset = [x / sampling_rate for x in samples]
            duration = np.repeat(0.0, len(samples))
            description = [label] * len(samples)
            raw.annotations.append(onset, duration, description,
                                   ch_names=[esg_chans] * len(samples))

    # Then run parallel for all channels with n_jobs set and
    # debug_mode = False. ch_names, current_channel and savename are dummy
    # values in this case.
    PCA_OBS_kwargs = dict(
        debug_mode=False, qrs=QRSevents_m, filter_coords=fwts,
        sr=sampling_rate, savename=save_path + 'pca_chan',
        ch_names=channelNames, sub_nr=subject_id,
        condition=cond_name, current_channel=debug_channel[-1]
    )

    # apply_function modifies the data in raw in place
    raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs,
                       n_jobs=len(esg_chans))

    # Save the new mne structure with the cleaned data
    suffix = '_mat' if matlab else variant
    raw.save(
        os.path.join(
            save_path,
            f'data_clean_ecg_spinal_{cond_name}_withqrs{suffix}.fif'),
        fmt='double', overwrite=True)
# Extract all kwargs - debug_mode = kwargs['debug_mode'] qrs = kwargs['qrs'] filter_coords = kwargs['filter_coords'] sr = kwargs['sr'] - ch_names = kwargs['ch_names'] - sub_nr = kwargs['sub_nr'] - condition = kwargs['condition'] - if debug_mode: # Only need current channel and saving if we're debugging - current_channel = kwargs['current_channel'] - savename = kwargs['savename'] fs = sr - # Standard delay between QRS peak and artifact - delay = 0 - - Gwindow = 2 - GHW = np.floor(Gwindow / 2).astype('int') - rcount = 0 - firstplot = 1 - # set to baseline data = data.reshape(-1, 1) data = data.T @@ -58,9 +42,6 @@ def __init__(self): for idx in qrs[0]: if idx < len(peakplot[0, :]): peakplot[0, idx] = 1 # logical indexed locations of qrs events - # sh = np.zeros((1, delay)) - # np1 = len(peakplot) - # peakplot = [sh, peakplot[0:np1 - delay]] # shifts indexed array by the delay - skipped here since delay=0 peak_idx = np.nonzero(peakplot)[1] # Selecting indices along columns peak_idx = peak_idx.reshape(-1, 1) @@ -117,34 +98,6 @@ def __init__(self): # define selected number of components using profile likelihood pca_info.nComponents = 4 - - # Creates plots - if debug_mode: - # plot pca variables figure - comp2plot = pca_info.nComponents - fig, axs = plt.subplots(2, 2) - for a in np.arange(comp2plot): - axs[0, 0].plot(pca_info.eigen_vectors[:, a], label=f"Data {a}") - axs[1, 1].plot(pca_info.factor_loadings[:, a], label=f"Data {a}") - axs[0, 0].set_title('Evec') - axs[1, 1].set_title('Factor Loadings') - axs[1, 1].set(xlabel='time') - - axs[0, 1].plot(np.arange(len(pca_info.expl_var)), pca_info.expl_var, 'r*') - axs[0, 1].set(xlabel='components', ylabel='var explained (%)') - cum_explained = np.cumsum(pca_info.expl_var) - axs[0, 1].set_title(f"first {pca_info.nComponents} comp, {cum_explained[pca_info.nComponents]} % var") - - axs[1, 0].plot(pca_info.eigen_values) - axs[1, 0].set_title('eigenvalues') - axs[1, 0].set(xlabel='components') - - fig.suptitle(f"{sub_nr} 
thresholds PCA vars channel {current_channel}") - plt.tight_layout() - fig.savefig(f"{savename}_{condition}.jpg") - - if debug_mode: - pca_info.chan = current_channel pca_info.meanEffect = mean_effect.T nComponents = pca_info.nComponents @@ -154,20 +107,6 @@ def __init__(self): mean_effect = mean_effect.reshape(-1, 1) pca_template = np.c_[mean_effect, factor_loadings[:, 0:nComponents]] - # Plot template vars - if debug_mode: - # plot template vars - fig = plt.figure() - pcatime = (np.arange(-peak_range, peak_range+1))/fs - pcatime = pcatime.reshape(-1) - plt.plot(pcatime, std_effect) - plt.plot(pcatime, mean_effect) - plt.plot(pcatime, factor_loadings[:, 0: nComponents]) - plt.legend(['std effect', 'mean effect', 'PCA_1', 'PCA_2', 'PCA_3', 'PCA_4']) - fig.suptitle(f"{sub_nr} papc channel {current_channel}") - plt.tight_layout() - fig.savefig(f"{savename}_templateVars_{condition}.jpg") - ################################################################################### # Data Fitting ################################################################################### @@ -227,27 +166,6 @@ def __init__(self): except Exception as e: print(f"Cannot fit middle section of data. 
Reason: {e}") - # Plot some channels - if debug_mode: - # check with plot what has been done - # First check if this channel is one we want to plot for debugging - plotChannel = 0 - for ii in range(0,len(ch_names)): - if current_channel == ch_names[ii]: - plotChannel = 1 - - if plotChannel == 1: - fig = plt.figure() - plt.plot((np.arange(0, len(fitted_art[0, :]))/fs).reshape(-1, 1), data[:].T, zorder=0) - plt.plot((np.arange(0, len(fitted_art[0, :]))/fs).reshape(-1, 1), eegchan[:].T, 'r', zorder=5) - plt.plot((np.arange(0, len(fitted_art[0, :]))/fs).reshape(-1, 1), fitted_art[:].T, 'g', zorder=10) - plt.plot((np.arange(0, len(fitted_art[0, :]))/fs).reshape(-1, 1), (np.subtract(data[:], fitted_art[:])).T, 'm', zorder=15) - plt.legend(['raw data', 'filtered', 'fitted_art', 'clean'], loc='upper right').set_zorder(20) - plt.xlabel('time [s]') - plt.ylabel('amplitude [V]') - plt.title('Subject ' + sub_nr + ', channel ' + current_channel) - fig.savefig(f"{savename}_compareresults_{condition}.jpg") - # Actually subtract the artefact, return needs to be the same shape as input data # One sample shift purely due to the fact the r-peaks are currently detected in MATLAB data = data.reshape(-1) @@ -264,18 +182,5 @@ def __init__(self): # data -= fitted_art # data = data.T.reshape(-1) - # Can't add annotations for window start and end (sample number added) to mne raw structure here - # Add it to the pca info class and store it that way - # Can then access in the main rm_heart_artefact and create the annotations - # Only save the pca vars if we're in debug mode - if debug_mode: - pca_info.window_start_idx = window_start_idx - pca_info.window_end_idx = window_end_idx - dataset_keywords = [a for a in dir(pca_info) if not a.startswith('__')] - fn = f"{savename}_{condition}_pca_info.h5" - with h5py.File(fn, "w") as outfile: - for keyword in dataset_keywords: - outfile.create_dataset(keyword, data=getattr(pca_info, keyword)) - # Can only return data return data diff --git 
a/mne/preprocessing/pca_obs/fit_ecgTemplate.py b/mne/preprocessing/pca_obs/fit_ecgTemplate.py index 00ef6459fb6..11043b0c56d 100755 --- a/mne/preprocessing/pca_obs/fit_ecgTemplate.py +++ b/mne/preprocessing/pca_obs/fit_ecgTemplate.py @@ -68,11 +68,4 @@ def __init__(self): fitecg.y_interpol = y_interpol fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop - # # Save to file to compare to matlab - only for debugging - # dataset_keywords = [a for a in dir(fitecg) if not a.startswith('__')] - # fn = f"/data/pt_02569/tmp_data/test_data/fitecg_test_{aPeak_idx[0]}_sub-001_tibial_S35.h5" - # with h5py.File(fn, "w") as outfile: - # for keyword in dataset_keywords: - # outfile.create_dataset(keyword, data=getattr(fitecg, keyword)) - return fitted_art, post_idx_nextPeak diff --git a/mne/preprocessing/pca_obs/pchip_interpolation.py b/mne/preprocessing/pca_obs/pchip_interpolation.py index a35d1b75aed..e336f672363 100755 --- a/mne/preprocessing/pca_obs/pchip_interpolation.py +++ b/mne/preprocessing/pca_obs/pchip_interpolation.py @@ -1,6 +1,6 @@ # Function to interpolate based on PCHIP rather than MNE inbuilt linear option -import mne +# import mne import numpy as np from scipy.interpolate import PchipInterpolator as pchip import matplotlib.pyplot as plt @@ -8,22 +8,13 @@ def PCHIP_interpolation(data, **kwargs): # Check all necessary arguments sent in - required_kws = ["trigger_indices", "interpol_window_sec", "fs", "debug_mode"] + required_kws = ["trigger_indices", "interpol_window_sec", "fs"] assert all([kw in kwargs.keys() for kw in required_kws]), "Error. Some KWs not passed into PCA_OBS." 
# Extract all kwargs - more elegant ways to do this fs = kwargs['fs'] interpol_window_sec = kwargs['interpol_window_sec'] trigger_indices = kwargs['trigger_indices'] - debug_mode = kwargs['debug_mode'] - - if debug_mode: - plt.figure() - # plot signal with artifact - plot_range = [-50, 100] - test_trial = 100 - xx = (np.arange(plot_range[0], plot_range[1])) / fs * 1000 - plt.plot(xx, data[trigger_indices[test_trial] + plot_range[0]:trigger_indices[test_trial] + plot_range[1]]) # Convert intpol window to msec then convert to samples pre_window = round((interpol_window_sec[0]*1000) * fs / 1000) # in samples @@ -48,12 +39,4 @@ def PCHIP_interpolation(data, **kwargs): if np.mod(ii, 100) == 0: # talk to the operator every 100th trial print(f'stimulation event {ii} \n') - if debug_mode: - # plot signal with interpolated artifact - plt.figure() - plt.plot(xx, data[trigger_indices[test_trial] + plot_range[0]: trigger_indices[test_trial] + plot_range[1]]) - plt.title('After Correction') - - plt.show() - return data diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact.py b/mne/preprocessing/pca_obs/rm_heart_artefact.py index 0de2a73358e..f61a8b0e202 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact.py @@ -5,54 +5,35 @@ from scipy.io import loadmat from scipy.signal import firls from PCA_OBS import * -from get_conditioninfo import * -from get_channels import * +from mne.io import read_raw_fif +from mne import events_from_annotations, Epochs -def rm_heart_artefact(subject, condition, srmr_nr, sampling_rate, pchip): - matlab = False # If this is true, use the data 'prepared' by matlab - testing to see where hump at 0 comes from +if __name__ == '__main__': # Incredibly slow without parallelization # Set variables - subject_id = f'sub-{str(subject).zfill(3)}' - cond_info = get_conditioninfo(condition, srmr_nr) - nblocks = cond_info.nblocks - cond_name = cond_info.cond_name - stimulation = cond_info.stimulation - 
trigger_name = cond_info.trigger_name + subject_id = f'sub-001' + cond_name = 'median' + sampling_rate = 1000 + esg_chans = ['S35', 'S24', 'S36', 'Iz', 'S17', 'S15', 'S32', 'S22', + 'S19', 'S26', 'S28', 'S9', 'S13', 'S11', 'S7', 'SC1', 'S4', 'S18', + 'S8', 'S31', 'SC6', 'S12', 'S16', 'S5', 'S30', 'S20', 'S34', 'AC', + 'S21', 'S25', 'L1', 'S29', 'S14', 'S33', 'S3', 'AL', 'L4', 'S6', + 'S23'] + # For heartbeat epochs + iv_baseline = [-300 / 1000, -200 / 1000] + iv_epoch = [-400 / 1000, 600 / 1000] # Setting paths input_path = "/data/pt_02569/tmp_data/prepared_py/"+subject_id+"/esg/prepro/" - input_path_m = "/data/pt_02569/tmp_data/prepared/"+subject_id+"/esg/prepro/" - save_path = "/data/pt_02569/tmp_data/ecg_rm_py/"+subject_id+"/esg/prepro/" - os.makedirs(save_path, exist_ok=True) - - figure_path = save_path - - # For debugging just test one channel of each - debug_channel = ['S35'] - _, esg_chans, _ = get_channels(subject, False, False, srmr_nr) # Ignoring ECG and EOG channels - - # Dyanmically set filename - if matlab: - fname = f"raw_{sampling_rate}_spinal_{cond_name}.set" - raw = mne.io.read_raw_eeglab(input_path_m + fname, preload=True) - else: - if pchip: - fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" - else: - fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs" - # Read .fif file from the previous step (import_data) - raw = mne.io.read_raw_fif(input_path + fname + '.fif', preload=True) + input_path_m = "/data/pt_02569/tmp_data/prepared/" + subject_id + "/esg/prepro/" + fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" + raw = read_raw_fif(input_path + fname + '.fif', preload=True) # Read .mat file with QRS events fname_m = f"raw_{sampling_rate}_spinal_{cond_name}" matdata = loadmat(input_path_m+fname_m+'.mat') QRSevents_m = matdata['QRSevents'] - fwts = matdata['fwts'] - - # Read .h5 file with alternative QRS events - # with h5py.File(input_path+fname+'.h5', "r") as infile: - # QRSevents_p = infile["QRS"][()] # Create 
filter coefficients fs = sampling_rate @@ -62,69 +43,32 @@ def rm_heart_artefact(subject, condition, srmr_nr, sampling_rate, pchip): ord = round(3*fs/0.5) fwts = firls(ord+1, f, a) - # Run once with a single channel and debug_mode = True to get window information - for ch in debug_channel: - # set PCA_OBS input variables - channelNames = ['S35', 'Iz', 'SC1', 'S3', 'SC6', 'S20', 'L1', 'L4'] - # these channels will be plotted(only for debugging / testing) - - # run PCA_OBS - if pchip: - name = 'pca_chan_' + ch + '_pchip' - else: - name = 'pca_chan_'+ch - PCA_OBS_kwargs = dict( - debug_mode=True, qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate, - savename=save_path+name, - ch_names=channelNames, sub_nr=subject_id, - condition=cond_name, current_channel=ch - ) - # Apply function modifies the data in raw in place - raw.copy().apply_function(PCA_OBS, picks=[ch], **PCA_OBS_kwargs) - - # This information is the same for each channel - run through fitting once to get vals, add to all channels - keywords = ['window_start_idx', 'window_end_idx'] - fn = f"{save_path}pca_chan_{ch}_{cond_name}_pca_info.h5" - with h5py.File(fn, "r") as infile: - # Get the data - window_start = infile[keywords[0]][()].reshape(-1) - window_end = infile[keywords[1]][()].reshape(-1) - - onset = [x/sampling_rate for x in window_start] # Divide by sampling rate to make times - duration = np.repeat(0.0, len(window_start)) - description = ['fit_start'] * len(window_start) - raw.annotations.append(onset, duration, description, ch_names=[esg_chans] * len(window_start)) - - onset = [x/sampling_rate for x in window_end] - duration = np.repeat(0.0, len(window_end)) - description = ['fit_end'] * len(window_end) - raw.annotations.append(onset, duration, description, ch_names=[esg_chans]*len(window_end)) - - # Then run parallel for all channels with n_jobs set and debug_mode = False - # set PCA_OBS input variables - channelNames = ['S35', 'Iz', 'SC1', 'S3', 'SC6', 'S20', 'L1', 'L4'] - # these channels will 
be plotted(only for debugging / testing) - # run PCA_OBS # In this case ch_names, current_channel and savename are dummy vars - not necessary really PCA_OBS_kwargs = dict( - debug_mode=False, qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate, - savename=save_path + 'pca_chan', - ch_names=channelNames, sub_nr=subject_id, - condition=cond_name, current_channel=ch + qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate ) + events, event_ids = events_from_annotations(raw) + event_id_dict = {key: value for key, value in event_ids.items() if key == 'qrs'} + epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_before = epochs.average() + # Apply function should modifies the data in raw in place raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans)) - # Save the new mne structure with the cleaned data - if matlab: - raw.save(os.path.join(save_path, f'data_clean_ecg_spinal_{cond_name}_withqrs_mat.fif'), fmt='double', - overwrite=True) - else: - if pchip: - raw.save(os.path.join(save_path, f'data_clean_ecg_spinal_{cond_name}_withqrs_pchip.fif'), fmt='double', - overwrite=True) - else: - raw.save(os.path.join(save_path, f'data_clean_ecg_spinal_{cond_name}_withqrs.fif'), fmt='double', - overwrite=True) + epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_after = epochs.average() + + # Comparison image + fig, axes = plt.subplots(2, 1) + axes[0].plot(evoked_before.times, evoked_before.get_data().T) + axes[0].set_ylim([-0.0005, 0.001]) + axes[0].set_title('Before PCA-OBS') + axes[1].plot(evoked_after.times, evoked_after.get_data().T) + axes[1].set_ylim([-0.0005, 0.001]) + axes[1].set_title('After PCA-OBS') + plt.tight_layout() + plt.show() From 362254062b6472a5317a94b3715a53a487f56eb6 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 9 Sep 2024 16:15:42 +0200 Subject: [PATCH 03/71] 
Implement testing dataset --- mne/preprocessing/pca_obs/PCA_OBS.py | 15 ++-- .../pca_obs/rm_heart_artefact.py | 10 ++- .../pca_obs/rm_heart_artefact_mnedata.py | 77 +++++++++++++++++++ 3 files changed, 92 insertions(+), 10 deletions(-) create mode 100644 mne/preprocessing/pca_obs/rm_heart_artefact_mnedata.py diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py index 9209da8a6f0..aee06165d11 100755 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ b/mne/preprocessing/pca_obs/PCA_OBS.py @@ -171,16 +171,15 @@ def __init__(self): data = data.reshape(-1) fitted_art = fitted_art.reshape(-1) - # data -= fitted_art - - data_ = np.zeros(len(data)) - data_[0] = data[0] - data_[1:] = data[1:] - fitted_art[:-1] - data = data_ + # One sample shift for my actual data (introduced using matlab r timings) + # data_ = np.zeros(len(data)) + # data_[0] = data[0] + # data_[1:] = data[1:] - fitted_art[:-1] + # data = data_ # Original code is this: - # data -= fitted_art - # data = data.T.reshape(-1) + data -= fitted_art + data = data.T.reshape(-1) # Can only return data return data diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact.py b/mne/preprocessing/pca_obs/rm_heart_artefact.py index f61a8b0e202..1b559ef9f22 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact.py @@ -44,7 +44,6 @@ fwts = firls(ord+1, f, a) # run PCA_OBS - # In this case ch_names, current_channel and savename are dummy vars - not necessary really PCA_OBS_kwargs = dict( qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate ) @@ -57,7 +56,6 @@ # Apply function should modifies the data in raw in place raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans)) - epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline)) evoked_after = epochs.average() @@ -71,4 +69,12 @@ axes[1].set_ylim([-0.0005, 0.001]) axes[1].set_title('After PCA-OBS') 
plt.tight_layout() + + # Comparison image + fig, axes = plt.subplots(1, 1) + axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') + axes.set_ylim([-0.0005, 0.001]) + axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') + axes.set_title('Before (black) versus after (green)') + plt.tight_layout() plt.show() diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_mnedata.py new file mode 100644 index 00000000000..42e58e99463 --- /dev/null +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_mnedata.py @@ -0,0 +1,77 @@ +# Checking algorithm implementation with EEG data form MNE sample datasets +# Following this tutorial data: https://mne.tools/stable/auto_tutorials/preprocessing/50_artifact_correction_ssp.html#what-is-ssp + +from mne.preprocessing import ( + find_ecg_events, + create_ecg_epochs, +) +from mne.io import read_raw_fif +from mne.datasets.sample import data_path +from PCA_OBS import * +from mne import Epochs +import os +import numpy as np +from scipy.signal import firls +import matplotlib.pyplot as plt + +if __name__ == '__main__': + sample_data_folder = data_path(path='/data/pt_02569/mne_test_data/') + sample_data_raw_file = os.path.join( + sample_data_folder, "MEG", "sample", "sample_audvis_raw.fif" + ) + # here we crop and resample just for speed + raw = read_raw_fif(sample_data_raw_file, preload=True) + + # Find ECG events - no ECG channel in data, uses synthetic + ecg_events, ch_ecg, average_pulse, = find_ecg_events(raw) + # Extract just sample timings of ecg events + ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) + # print(ecg_events) + + # Create filter coefficients + fs = raw.info['sfreq'] + a = [0, 0, 1, 1] + f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter + # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter + ord = round(3 * fs / 0.5) + fwts = firls(ord + 1, f, a) + + # For heartbeat epochs 
+ iv_baseline = [-300 / 1000, -200 / 1000] + iv_epoch = [-400 / 1000, 600 / 1000] + + # run PCA_OBS + # Algorithm is extremely sensitive to accurate R-peak timings, won't work as well with the artificial ECG + # channel estimation as we have here + PCA_OBS_kwargs = dict( + qrs=ecg_event_samples, filter_coords=fwts, sr=fs + ) + + epochs = Epochs(raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_before = epochs.average() + + # Apply function should modifies the data in raw in place + raw.apply_function(PCA_OBS, picks='eeg', **PCA_OBS_kwargs, n_jobs=10) + epochs = Epochs(raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_after = epochs.average() + + # Comparison image + fig, axes = plt.subplots(2, 1) + axes[0].plot(evoked_before.times, evoked_before.get_data().T) + axes[0].set_ylim([-1e-5, 3e-5]) + axes[0].set_title('Before PCA-OBS') + axes[1].plot(evoked_after.times, evoked_after.get_data().T) + axes[1].set_ylim([-1e-5, 3e-5]) + axes[1].set_title('After PCA-OBS') + plt.tight_layout() + + # Comparison image + fig, axes = plt.subplots(1, 1) + axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') + axes.set_ylim([-1e-5, 3e-5]) + axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') + axes.set_title('Before (black) versus after (green)') + plt.tight_layout() + plt.show() From 6c42f33f5fca1dec460fa902b5b6a3c21b6d4131 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 7 Oct 2024 15:43:26 +0200 Subject: [PATCH 04/71] Feat: Update examples --- ... => rm_heart_artefact_cortical_mnedata.py} | 0 ...rm_heart_artefact_spinal_impreciserpeak.py | 80 +++++++++++++++++++ ... 
rm_heart_artefact_spinal_preciserpeak.py} | 0 3 files changed, 80 insertions(+) rename mne/preprocessing/pca_obs/{rm_heart_artefact_mnedata.py => rm_heart_artefact_cortical_mnedata.py} (100%) create mode 100755 mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py rename mne/preprocessing/pca_obs/{rm_heart_artefact.py => rm_heart_artefact_spinal_preciserpeak.py} (100%) diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py similarity index 100% rename from mne/preprocessing/pca_obs/rm_heart_artefact_mnedata.py rename to mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py new file mode 100755 index 00000000000..7634096c909 --- /dev/null +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -0,0 +1,80 @@ +# Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component +# Analysis, Optimal Basis Sets) + +import os +from scipy.io import loadmat +from scipy.signal import firls +from PCA_OBS import * +from mne.io import read_raw_fif +from mne import events_from_annotations, Epochs +from mne.preprocessing import find_ecg_events + + +if __name__ == '__main__': + # Incredibly slow without parallelization + # Set variables + subject_id = f'sub-001' + cond_name = 'median' + sampling_rate = 1000 + esg_chans = ['S35', 'S24', 'S36', 'Iz', 'S17', 'S15', 'S32', 'S22', + 'S19', 'S26', 'S28', 'S9', 'S13', 'S11', 'S7', 'SC1', 'S4', 'S18', + 'S8', 'S31', 'SC6', 'S12', 'S16', 'S5', 'S30', 'S20', 'S34', 'AC', + 'S21', 'S25', 'L1', 'S29', 'S14', 'S33', 'S3', 'AL', 'L4', 'S6', + 'S23'] + # For heartbeat epochs + iv_baseline = [-300 / 1000, -200 / 1000] + iv_epoch = [-400 / 1000, 600 / 1000] + + # Setting paths + input_path = 
"/data/pt_02569/tmp_data/prepared_py/"+subject_id+"/esg/prepro/" + fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" + raw = read_raw_fif(input_path + fname + '.fif', preload=True) + + # Find ECG events - no ECG channel in data, uses synthetic + ecg_events, ch_ecg, average_pulse, = find_ecg_events(raw, ch_name='ECG') + # Extract just sample timings of ecg events + ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) + + # Create filter coefficients + fs = sampling_rate + a = [0, 0, 1, 1] + f = [0, 0.4/(fs/2), 0.9/(fs / 2), 1] # 0.9 Hz highpass filter + # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter + ord = round(3*fs/0.5) + fwts = firls(ord+1, f, a) + + # run PCA_OBS + PCA_OBS_kwargs = dict( + qrs=ecg_event_samples, filter_coords=fwts, sr=sampling_rate + ) + + events, event_ids = events_from_annotations(raw) + event_id_dict = {key: value for key, value in event_ids.items() if key == 'qrs'} + epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_before = epochs.average() + + # Apply function should modifies the data in raw in place + raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans)) + epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_after = epochs.average() + + # Comparison image + fig, axes = plt.subplots(2, 1) + axes[0].plot(evoked_before.times, evoked_before.get_data().T) + axes[0].set_ylim([-0.0005, 0.001]) + axes[0].set_title('Before PCA-OBS') + axes[1].plot(evoked_after.times, evoked_after.get_data().T) + axes[1].set_ylim([-0.0005, 0.001]) + axes[1].set_title('After PCA-OBS') + plt.tight_layout() + + # Comparison image + fig, axes = plt.subplots(1, 1) + axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') + axes.set_ylim([-0.0005, 0.001]) + axes.plot(evoked_after.times, 
evoked_after.get_data().T, color='green') + axes.set_title('Before (black) versus after (green)') + plt.tight_layout() + plt.show() diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py similarity index 100% rename from mne/preprocessing/pca_obs/rm_heart_artefact.py rename to mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py From 61f58b279f11adf25043c671ab491d56aa10d032 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Oct 2024 16:32:20 +0000 Subject: [PATCH 05/71] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/preprocessing/pca_obs/PCA_OBS.py | 112 +++++++++++------ mne/preprocessing/pca_obs/fit_ecgTemplate.py | 53 ++++++-- .../pca_obs/pchip_interpolation.py | 45 ++++--- .../rm_heart_artefact_cortical_mnedata.py | 57 +++++---- ...rm_heart_artefact_spinal_impreciserpeak.py | 116 +++++++++++++----- .../rm_heart_artefact_spinal_preciserpeak.py | 113 ++++++++++++----- 6 files changed, 340 insertions(+), 156 deletions(-) diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py index aee06165d11..ec09d4f9f6e 100755 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ b/mne/preprocessing/pca_obs/PCA_OBS.py @@ -1,17 +1,16 @@ +import math + import numpy as np +from fit_ecgTemplate import fit_ecgTemplate + # import mne -from scipy.signal import filtfilt, detrend -import matplotlib.pyplot as plt +from scipy.signal import detrend, filtfilt from sklearn.decomposition import PCA -from fit_ecgTemplate import fit_ecgTemplate -import math -import h5py def PCA_OBS(data, **kwargs): - # Declare class to hold pca information - class PCAInfo(): + class PCAInfo: def __init__(self): pass @@ -20,12 +19,14 @@ def __init__(self): # Check all necessary arguments sent in required_kws = ["qrs", "filter_coords", "sr"] - assert all([kw in 
kwargs.keys() for kw in required_kws]), "Error. Some KWs not passed into PCA_OBS." + assert all( + [kw in kwargs.keys() for kw in required_kws] + ), "Error. Some KWs not passed into PCA_OBS." # Extract all kwargs - qrs = kwargs['qrs'] - filter_coords = kwargs['filter_coords'] - sr = kwargs['sr'] + qrs = kwargs["qrs"] + filter_coords = kwargs["filter_coords"] + sr = kwargs["sr"] fs = sr @@ -50,19 +51,21 @@ def __init__(self): ################################################################ # Preparatory work - reserving memory, configure sizes, de-trend ################################################################ - print('Pulse artifact subtraction in progress...Please wait!') + print("Pulse artifact subtraction in progress...Please wait!") # define peak range based on RR RR = np.diff(peak_idx[:, 0]) mRR = np.median(RR) - peak_range = round(mRR/2) # Rounds to an integer + peak_range = round(mRR / 2) # Rounds to an integer midP = peak_range + 1 - baseline_range = [0, round(peak_range/8)] - n_samples_fit = round(peak_range/8) # sample fit for interpolation between fitted artifact windows + baseline_range = [0, round(peak_range / 8)] + n_samples_fit = round( + peak_range / 8 + ) # sample fit for interpolation between fitted artifact windows # make sure array is long enough for PArange (if not cut off last ECG peak) pa = peak_count # Number of QRS complexes detected - while peak_idx[pa-1, 0] + peak_range > len(data[0]): + while peak_idx[pa - 1, 0] + peak_range > len(data[0]): pa = pa - 1 steps = 1 * pa peak_count = pa @@ -71,16 +74,22 @@ def __init__(self): eegchan = filtfilt(filter_coords, 1, data) # build PCA matrix(heart-beat-epochs x window-length) - pcamat = np.zeros((peak_count - 1, 2*peak_range+1)) # [epoch x time] + pcamat = np.zeros((peak_count - 1, 2 * peak_range + 1)) # [epoch x time] # picking out heartbeat epochs for p in range(1, peak_count): - pcamat[p-1, :] = eegchan[0, peak_idx[p, 0] - peak_range: peak_idx[p, 0] + peak_range+1] + pcamat[p - 1, :] = 
eegchan[ + 0, peak_idx[p, 0] - peak_range : peak_idx[p, 0] + peak_range + 1 + ] # detrending matrix(twice) - pcamat = detrend(pcamat, type='constant', axis=1) # [epoch x time] - detrended along the epoch - mean_effect = np.mean(pcamat, axis=0) # [1 x time], contains the mean over all epochs + pcamat = detrend( + pcamat, type="constant", axis=1 + ) # [epoch x time] - detrended along the epoch + mean_effect = np.mean( + pcamat, axis=0 + ) # [1 x time], contains the mean over all epochs std_effect = np.std(pcamat, axis=0) # want mean and std of each column - dpcamat = detrend(pcamat, type='constant', axis=1) # [time x epoch] + dpcamat = detrend(pcamat, type="constant", axis=1) # [time x epoch] ################################################################### # Perform PCA with sklearn @@ -90,7 +99,7 @@ def __init__(self): pca.fit(dpcamat) eigen_vectors = pca.components_ eigen_values = pca.explained_variance_ - factor_loadings = pca.components_.T*np.sqrt(pca.explained_variance_) + factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_) pca_info.eigen_vectors = eigen_vectors pca_info.factor_loadings = factor_loadings pca_info.eigen_values = eigen_values @@ -116,34 +125,55 @@ def __init__(self): # Deals with start portion of data if p == 0: pre_range = peak_range - post_range = math.floor((peak_idx[p + 1] - peak_idx[p])/2) + post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) if post_range > peak_range: post_range = peak_range try: post_idx_nextPeak = [] - fitted_art, post_idx_nextPeak = fit_ecgTemplate(data, pca_template, peak_idx[p], peak_range, - pre_range, post_range, baseline_range, midP, - fitted_art, post_idx_nextPeak, n_samples_fit) + fitted_art, post_idx_nextPeak = fit_ecgTemplate( + data, + pca_template, + peak_idx[p], + peak_range, + pre_range, + post_range, + baseline_range, + midP, + fitted_art, + post_idx_nextPeak, + n_samples_fit, + ) # Appending to list instead of using counter window_start_idx.append(peak_idx[p] - peak_range) 
window_end_idx.append(peak_idx[p] + peak_range) except Exception as e: - print(f'Cannot fit first ECG epoch. Reason: {e}') + print(f"Cannot fit first ECG epoch. Reason: {e}") # Deals with last edge of data elif p == peak_count: - print('On last section - almost there!') + print("On last section - almost there!") try: pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) post_range = peak_range if pre_range > peak_range: pre_range = peak_range - fitted_art, _ = fit_ecgTemplate(data, pca_template, peak_idx(p), peak_range, pre_range, post_range, - baseline_range, midP, fitted_art, post_idx_nextPeak, n_samples_fit) + fitted_art, _ = fit_ecgTemplate( + data, + pca_template, + peak_idx(p), + peak_range, + pre_range, + post_range, + baseline_range, + midP, + fitted_art, + post_idx_nextPeak, + n_samples_fit, + ) window_start_idx.append(peak_idx[p] - peak_range) window_end_idx.append(peak_idx[p] + peak_range) except Exception as e: - print(f'Cannot fit last ECG epoch. Reason: {e}') + print(f"Cannot fit last ECG epoch. 
Reason: {e}") # Deals with middle portion of data else: @@ -157,10 +187,22 @@ def __init__(self): if post_range > peak_range: post_range = peak_range - aTemplate = pca_template[midP - peak_range-1:midP + peak_range+1, :] - fitted_art, post_idx_nextPeak = fit_ecgTemplate(data, aTemplate, peak_idx[p], peak_range, pre_range, - post_range, baseline_range, midP, fitted_art, - post_idx_nextPeak, n_samples_fit) + aTemplate = pca_template[ + midP - peak_range - 1 : midP + peak_range + 1, : + ] + fitted_art, post_idx_nextPeak = fit_ecgTemplate( + data, + aTemplate, + peak_idx[p], + peak_range, + pre_range, + post_range, + baseline_range, + midP, + fitted_art, + post_idx_nextPeak, + n_samples_fit, + ) window_start_idx.append(peak_idx[p] - peak_range) window_end_idx.append(peak_idx[p] + peak_range) except Exception as e: diff --git a/mne/preprocessing/pca_obs/fit_ecgTemplate.py b/mne/preprocessing/pca_obs/fit_ecgTemplate.py index 11043b0c56d..3e5b7f55cb2 100755 --- a/mne/preprocessing/pca_obs/fit_ecgTemplate.py +++ b/mne/preprocessing/pca_obs/fit_ecgTemplate.py @@ -1,12 +1,23 @@ import numpy as np -from scipy.signal import detrend from scipy.interpolate import PchipInterpolator as pchip -import h5py +from scipy.signal import detrend -def fit_ecgTemplate(data, pca_template, aPeak_idx, peak_range, pre_range, post_range, baseline_range, midP, fitted_art, post_idx_previousPeak, n_samples_fit): +def fit_ecgTemplate( + data, + pca_template, + aPeak_idx, + peak_range, + pre_range, + post_range, + baseline_range, + midP, + fitted_art, + post_idx_previousPeak, + n_samples_fit, +): # Declare class to hold ecg fit information - class fitECG(): + class fitECG: def __init__(self): pass @@ -16,18 +27,20 @@ def __init__(self): # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak # Then nextpeak is returned at the end and the process repeats # select window of template - template = pca_template[midP - peak_range-1: midP + peak_range+1, :] + template = 
pca_template[midP - peak_range - 1 : midP + peak_range + 1, :] # select window of data and detrend it - slice = data[0, aPeak_idx[0] - peak_range:aPeak_idx[0] + peak_range+1] - detrended_data = detrend(slice.reshape(-1), type='constant') + slice = data[0, aPeak_idx[0] - peak_range : aPeak_idx[0] + peak_range + 1] + detrended_data = detrend(slice.reshape(-1), type="constant") # maps data on template and then maps it again back to the sensor space least_square = np.linalg.lstsq(template, detrended_data, rcond=None) pad_fit = np.dot(template, least_square[0]) # fit artifact, I already loop through externally channel to channel - fitted_art[0, aPeak_idx[0] - pre_range-1: aPeak_idx[0] + post_range] = pad_fit[midP - pre_range-1: midP + post_range].T + fitted_art[0, aPeak_idx[0] - pre_range - 1 : aPeak_idx[0] + post_range] = pad_fit[ + midP - pre_range - 1 : midP + post_range + ].T fitecg.fitted_art = fitted_art fitecg.template = template @@ -43,7 +56,9 @@ def __init__(self): # Check it's not empty if len(post_idx_previousPeak) != 0: # interpolate time between peaks - intpol_window = np.ceil([post_idx_previousPeak[0], aPeak_idx[0] - pre_range]).astype('int') # interpolation window + intpol_window = np.ceil( + [post_idx_previousPeak[0], aPeak_idx[0] - pre_range] + ).astype("int") # interpolation window fitecg.intpol_window = intpol_window if intpol_window[0] < intpol_window[1]: @@ -53,14 +68,26 @@ def __init__(self): # You have y_fit which is the y vals corresponding to x values above # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate # You have y_interpol which is values from pchip at the time points specified in x_interpol - x_interpol = np.arange(intpol_window[0], intpol_window[1] + 1, 1) # points to be interpolated in pt - the gap between the endpoints of the window - x_fit = np.concatenate([np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1), - np.arange(intpol_window[1], intpol_window[1] + 
n_samples_fit + 1, 1)]) # Entire range of x values in this step (taking some number of samples before and after the window) + x_interpol = np.arange( + intpol_window[0], intpol_window[1] + 1, 1 + ) # points to be interpolated in pt - the gap between the endpoints of the window + x_fit = np.concatenate( + [ + np.arange( + intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1 + ), + np.arange( + intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1 + ), + ] + ) # Entire range of x values in this step (taking some number of samples before and after the window) y_fit = fitted_art[0, x_fit] y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation # Then make fitted artefact in the desired range equal to the completed fit above - fitted_art[0, post_idx_previousPeak[0]: aPeak_idx[0] - pre_range+1] = y_interpol + fitted_art[0, post_idx_previousPeak[0] : aPeak_idx[0] - pre_range + 1] = ( + y_interpol + ) fitecg.x_fit = x_fit fitecg.y_fit = y_fit diff --git a/mne/preprocessing/pca_obs/pchip_interpolation.py b/mne/preprocessing/pca_obs/pchip_interpolation.py index e336f672363..56a34eabb43 100755 --- a/mne/preprocessing/pca_obs/pchip_interpolation.py +++ b/mne/preprocessing/pca_obs/pchip_interpolation.py @@ -3,33 +3,46 @@ # import mne import numpy as np from scipy.interpolate import PchipInterpolator as pchip -import matplotlib.pyplot as plt def PCHIP_interpolation(data, **kwargs): # Check all necessary arguments sent in required_kws = ["trigger_indices", "interpol_window_sec", "fs"] - assert all([kw in kwargs.keys() for kw in required_kws]), "Error. Some KWs not passed into PCA_OBS." + assert all( + [kw in kwargs.keys() for kw in required_kws] + ), "Error. Some KWs not passed into PCA_OBS." 
# Extract all kwargs - more elegant ways to do this - fs = kwargs['fs'] - interpol_window_sec = kwargs['interpol_window_sec'] - trigger_indices = kwargs['trigger_indices'] + fs = kwargs["fs"] + interpol_window_sec = kwargs["interpol_window_sec"] + trigger_indices = kwargs["trigger_indices"] # Convert intpol window to msec then convert to samples - pre_window = round((interpol_window_sec[0]*1000) * fs / 1000) # in samples - post_window = round((interpol_window_sec[1]*1000) * fs / 1000) # in samples - intpol_window = np.ceil([pre_window, post_window]).astype(int) # interpolation window - - n_samples_fit = 5 # number of samples before and after cut used for interpolation fit - - x_fit_raw = np.concatenate([np.arange(intpol_window[0]-n_samples_fit-1, intpol_window[0], 1), - np.arange(intpol_window[1]+1, intpol_window[1]+n_samples_fit+2, 1)]) - x_interpol_raw = np.arange(intpol_window[0], intpol_window[1]+1, 1) # points to be interpolated; in pt + pre_window = round((interpol_window_sec[0] * 1000) * fs / 1000) # in samples + post_window = round((interpol_window_sec[1] * 1000) * fs / 1000) # in samples + intpol_window = np.ceil([pre_window, post_window]).astype( + int + ) # interpolation window + + n_samples_fit = ( + 5 # number of samples before and after cut used for interpolation fit + ) + + x_fit_raw = np.concatenate( + [ + np.arange(intpol_window[0] - n_samples_fit - 1, intpol_window[0], 1), + np.arange(intpol_window[1] + 1, intpol_window[1] + n_samples_fit + 2, 1), + ] + ) + x_interpol_raw = np.arange( + intpol_window[0], intpol_window[1] + 1, 1 + ) # points to be interpolated; in pt for ii in np.arange(0, len(trigger_indices)): # loop through all stimulation events x_fit = trigger_indices[ii] + x_fit_raw # fit point latencies for this event - x_interpol = trigger_indices[ii] + x_interpol_raw # latencies for to-be-interpolated data points + x_interpol = ( + trigger_indices[ii] + x_interpol_raw + ) # latencies for to-be-interpolated data points # Data is just a 
string of values y_fit = data[x_fit] # y values to be fitted @@ -37,6 +50,6 @@ def PCHIP_interpolation(data, **kwargs): data[x_interpol] = y_interpol # replace in data if np.mod(ii, 100) == 0: # talk to the operator every 100th trial - print(f'stimulation event {ii} \n') + print(f"stimulation event {ii} \n") return data diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py index 42e58e99463..e84f9de3b90 100644 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py @@ -1,21 +1,22 @@ # Checking algorithm implementation with EEG data form MNE sample datasets # Following this tutorial data: https://mne.tools/stable/auto_tutorials/preprocessing/50_artifact_correction_ssp.html#what-is-ssp -from mne.preprocessing import ( - find_ecg_events, - create_ecg_epochs, -) -from mne.io import read_raw_fif -from mne.datasets.sample import data_path -from PCA_OBS import * -from mne import Epochs import os + +import matplotlib.pyplot as plt import numpy as np +from PCA_OBS import * from scipy.signal import firls -import matplotlib.pyplot as plt -if __name__ == '__main__': - sample_data_folder = data_path(path='/data/pt_02569/mne_test_data/') +from mne import Epochs +from mne.datasets.sample import data_path +from mne.io import read_raw_fif +from mne.preprocessing import ( + find_ecg_events, +) + +if __name__ == "__main__": + sample_data_folder = data_path(path="/data/pt_02569/mne_test_data/") sample_data_raw_file = os.path.join( sample_data_folder, "MEG", "sample", "sample_audvis_raw.fif" ) @@ -23,13 +24,17 @@ raw = read_raw_fif(sample_data_raw_file, preload=True) # Find ECG events - no ECG channel in data, uses synthetic - ecg_events, ch_ecg, average_pulse, = find_ecg_events(raw) + ( + ecg_events, + ch_ecg, + average_pulse, + ) = find_ecg_events(raw) # Extract just sample timings of ecg events ecg_event_samples 
= np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) # print(ecg_events) # Create filter coefficients - fs = raw.info['sfreq'] + fs = raw.info["sfreq"] a = [0, 0, 1, 1] f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter @@ -43,35 +48,35 @@ # run PCA_OBS # Algorithm is extremely sensitive to accurate R-peak timings, won't work as well with the artificial ECG # channel estimation as we have here - PCA_OBS_kwargs = dict( - qrs=ecg_event_samples, filter_coords=fwts, sr=fs - ) + PCA_OBS_kwargs = dict(qrs=ecg_event_samples, filter_coords=fwts, sr=fs) - epochs = Epochs(raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) + epochs = Epochs( + raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline) + ) evoked_before = epochs.average() # Apply function should modifies the data in raw in place - raw.apply_function(PCA_OBS, picks='eeg', **PCA_OBS_kwargs, n_jobs=10) - epochs = Epochs(raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) + raw.apply_function(PCA_OBS, picks="eeg", **PCA_OBS_kwargs, n_jobs=10) + epochs = Epochs( + raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline) + ) evoked_after = epochs.average() # Comparison image fig, axes = plt.subplots(2, 1) axes[0].plot(evoked_before.times, evoked_before.get_data().T) axes[0].set_ylim([-1e-5, 3e-5]) - axes[0].set_title('Before PCA-OBS') + axes[0].set_title("Before PCA-OBS") axes[1].plot(evoked_after.times, evoked_after.get_data().T) axes[1].set_ylim([-1e-5, 3e-5]) - axes[1].set_title('After PCA-OBS') + axes[1].set_title("After PCA-OBS") plt.tight_layout() # Comparison image fig, axes = plt.subplots(1, 1) - axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') + axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") axes.set_ylim([-1e-5, 3e-5]) - axes.plot(evoked_after.times, 
evoked_after.get_data().T, color='green') - axes.set_title('Before (black) versus after (green)') + axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") + axes.set_title("Before (black) versus after (green)") plt.tight_layout() plt.show() diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py index 7634096c909..73682b67867 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -1,80 +1,130 @@ # Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component # Analysis, Optimal Basis Sets) -import os -from scipy.io import loadmat -from scipy.signal import firls from PCA_OBS import * +from scipy.signal import firls + +from mne import Epochs, events_from_annotations from mne.io import read_raw_fif -from mne import events_from_annotations, Epochs from mne.preprocessing import find_ecg_events - -if __name__ == '__main__': +if __name__ == "__main__": # Incredibly slow without parallelization # Set variables - subject_id = f'sub-001' - cond_name = 'median' + subject_id = "sub-001" + cond_name = "median" sampling_rate = 1000 - esg_chans = ['S35', 'S24', 'S36', 'Iz', 'S17', 'S15', 'S32', 'S22', - 'S19', 'S26', 'S28', 'S9', 'S13', 'S11', 'S7', 'SC1', 'S4', 'S18', - 'S8', 'S31', 'SC6', 'S12', 'S16', 'S5', 'S30', 'S20', 'S34', 'AC', - 'S21', 'S25', 'L1', 'S29', 'S14', 'S33', 'S3', 'AL', 'L4', 'S6', - 'S23'] + esg_chans = [ + "S35", + "S24", + "S36", + "Iz", + "S17", + "S15", + "S32", + "S22", + "S19", + "S26", + "S28", + "S9", + "S13", + "S11", + "S7", + "SC1", + "S4", + "S18", + "S8", + "S31", + "SC6", + "S12", + "S16", + "S5", + "S30", + "S20", + "S34", + "AC", + "S21", + "S25", + "L1", + "S29", + "S14", + "S33", + "S3", + "AL", + "L4", + "S6", + "S23", + ] # For heartbeat epochs iv_baseline = [-300 / 1000, 
-200 / 1000] iv_epoch = [-400 / 1000, 600 / 1000] # Setting paths - input_path = "/data/pt_02569/tmp_data/prepared_py/"+subject_id+"/esg/prepro/" + input_path = "/data/pt_02569/tmp_data/prepared_py/" + subject_id + "/esg/prepro/" fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" - raw = read_raw_fif(input_path + fname + '.fif', preload=True) + raw = read_raw_fif(input_path + fname + ".fif", preload=True) # Find ECG events - no ECG channel in data, uses synthetic - ecg_events, ch_ecg, average_pulse, = find_ecg_events(raw, ch_name='ECG') + ( + ecg_events, + ch_ecg, + average_pulse, + ) = find_ecg_events(raw, ch_name="ECG") # Extract just sample timings of ecg events ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) # Create filter coefficients fs = sampling_rate a = [0, 0, 1, 1] - f = [0, 0.4/(fs/2), 0.9/(fs / 2), 1] # 0.9 Hz highpass filter + f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter - ord = round(3*fs/0.5) - fwts = firls(ord+1, f, a) + ord = round(3 * fs / 0.5) + fwts = firls(ord + 1, f, a) # run PCA_OBS - PCA_OBS_kwargs = dict( - qrs=ecg_event_samples, filter_coords=fwts, sr=sampling_rate - ) + PCA_OBS_kwargs = dict(qrs=ecg_event_samples, filter_coords=fwts, sr=sampling_rate) events, event_ids = events_from_annotations(raw) - event_id_dict = {key: value for key, value in event_ids.items() if key == 'qrs'} - epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) + event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} + epochs = Epochs( + raw, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), + ) evoked_before = epochs.average() # Apply function should modifies the data in raw in place - raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans)) - epochs = 
Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) + raw.apply_function( + PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans) + ) + epochs = Epochs( + raw, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), + ) evoked_after = epochs.average() # Comparison image fig, axes = plt.subplots(2, 1) axes[0].plot(evoked_before.times, evoked_before.get_data().T) axes[0].set_ylim([-0.0005, 0.001]) - axes[0].set_title('Before PCA-OBS') + axes[0].set_title("Before PCA-OBS") axes[1].plot(evoked_after.times, evoked_after.get_data().T) axes[1].set_ylim([-0.0005, 0.001]) - axes[1].set_title('After PCA-OBS') + axes[1].set_title("After PCA-OBS") plt.tight_layout() # Comparison image fig, axes = plt.subplots(1, 1) - axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') + axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") axes.set_ylim([-0.0005, 0.001]) - axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') - axes.set_title('Before (black) versus after (green)') + axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") + axes.set_title("Before (black) versus after (green)") plt.tight_layout() plt.show() diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py index 1b559ef9f22..4c40d8f7234 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py @@ -1,80 +1,127 @@ # Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component # Analysis, Optimal Basis Sets) -import os +from PCA_OBS import * from scipy.io import loadmat from scipy.signal import firls -from PCA_OBS import * -from mne.io import read_raw_fif -from mne import events_from_annotations, 
Epochs +from mne import Epochs, events_from_annotations +from mne.io import read_raw_fif -if __name__ == '__main__': +if __name__ == "__main__": # Incredibly slow without parallelization # Set variables - subject_id = f'sub-001' - cond_name = 'median' + subject_id = "sub-001" + cond_name = "median" sampling_rate = 1000 - esg_chans = ['S35', 'S24', 'S36', 'Iz', 'S17', 'S15', 'S32', 'S22', - 'S19', 'S26', 'S28', 'S9', 'S13', 'S11', 'S7', 'SC1', 'S4', 'S18', - 'S8', 'S31', 'SC6', 'S12', 'S16', 'S5', 'S30', 'S20', 'S34', 'AC', - 'S21', 'S25', 'L1', 'S29', 'S14', 'S33', 'S3', 'AL', 'L4', 'S6', - 'S23'] + esg_chans = [ + "S35", + "S24", + "S36", + "Iz", + "S17", + "S15", + "S32", + "S22", + "S19", + "S26", + "S28", + "S9", + "S13", + "S11", + "S7", + "SC1", + "S4", + "S18", + "S8", + "S31", + "SC6", + "S12", + "S16", + "S5", + "S30", + "S20", + "S34", + "AC", + "S21", + "S25", + "L1", + "S29", + "S14", + "S33", + "S3", + "AL", + "L4", + "S6", + "S23", + ] # For heartbeat epochs iv_baseline = [-300 / 1000, -200 / 1000] iv_epoch = [-400 / 1000, 600 / 1000] # Setting paths - input_path = "/data/pt_02569/tmp_data/prepared_py/"+subject_id+"/esg/prepro/" + input_path = "/data/pt_02569/tmp_data/prepared_py/" + subject_id + "/esg/prepro/" input_path_m = "/data/pt_02569/tmp_data/prepared/" + subject_id + "/esg/prepro/" fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" - raw = read_raw_fif(input_path + fname + '.fif', preload=True) + raw = read_raw_fif(input_path + fname + ".fif", preload=True) # Read .mat file with QRS events fname_m = f"raw_{sampling_rate}_spinal_{cond_name}" - matdata = loadmat(input_path_m+fname_m+'.mat') - QRSevents_m = matdata['QRSevents'] + matdata = loadmat(input_path_m + fname_m + ".mat") + QRSevents_m = matdata["QRSevents"] # Create filter coefficients fs = sampling_rate a = [0, 0, 1, 1] - f = [0, 0.4/(fs/2), 0.9/(fs / 2), 1] # 0.9 Hz highpass filter + f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter # f = [0 0.4 / (fs / 
2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter - ord = round(3*fs/0.5) - fwts = firls(ord+1, f, a) + ord = round(3 * fs / 0.5) + fwts = firls(ord + 1, f, a) # run PCA_OBS - PCA_OBS_kwargs = dict( - qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate - ) + PCA_OBS_kwargs = dict(qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate) events, event_ids = events_from_annotations(raw) - event_id_dict = {key: value for key, value in event_ids.items() if key == 'qrs'} - epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) + event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} + epochs = Epochs( + raw, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), + ) evoked_before = epochs.average() # Apply function should modifies the data in raw in place - raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans)) - epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) + raw.apply_function( + PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans) + ) + epochs = Epochs( + raw, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), + ) evoked_after = epochs.average() # Comparison image fig, axes = plt.subplots(2, 1) axes[0].plot(evoked_before.times, evoked_before.get_data().T) axes[0].set_ylim([-0.0005, 0.001]) - axes[0].set_title('Before PCA-OBS') + axes[0].set_title("Before PCA-OBS") axes[1].plot(evoked_after.times, evoked_after.get_data().T) axes[1].set_ylim([-0.0005, 0.001]) - axes[1].set_title('After PCA-OBS') + axes[1].set_title("After PCA-OBS") plt.tight_layout() # Comparison image fig, axes = plt.subplots(1, 1) - axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') + axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") 
axes.set_ylim([-0.0005, 0.001]) - axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') - axes.set_title('Before (black) versus after (green)') + axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") + axes.set_title("Before (black) versus after (green)") plt.tight_layout() plt.show() From 070037dcb42fe04b945891a9d3916a2cb93c09f8 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Wed, 23 Oct 2024 19:12:10 +0200 Subject: [PATCH 06/71] fix: adjust import paths, add init file to module --- mne/preprocessing/pca_obs/PCA_OBS.py | 1 - mne/preprocessing/pca_obs/__init__.py | 0 mne/preprocessing/pca_obs/pchip_interpolation.py | 6 +++--- .../pca_obs/rm_heart_artefact_cortical_mnedata.py | 2 +- .../pca_obs/rm_heart_artefact_spinal_impreciserpeak.py | 5 +++-- .../pca_obs/rm_heart_artefact_spinal_preciserpeak.py | 3 ++- 6 files changed, 9 insertions(+), 8 deletions(-) create mode 100644 mne/preprocessing/pca_obs/__init__.py diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py index ec09d4f9f6e..691cd259cdf 100755 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ b/mne/preprocessing/pca_obs/PCA_OBS.py @@ -3,7 +3,6 @@ import numpy as np from fit_ecgTemplate import fit_ecgTemplate -# import mne from scipy.signal import detrend, filtfilt from sklearn.decomposition import PCA diff --git a/mne/preprocessing/pca_obs/__init__.py b/mne/preprocessing/pca_obs/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/mne/preprocessing/pca_obs/pchip_interpolation.py b/mne/preprocessing/pca_obs/pchip_interpolation.py index 56a34eabb43..941aa6c0e48 100755 --- a/mne/preprocessing/pca_obs/pchip_interpolation.py +++ b/mne/preprocessing/pca_obs/pchip_interpolation.py @@ -1,6 +1,5 @@ # Function to interpolate based on PCHIP rather than MNE inbuilt linear option -# import mne import numpy as np from scipy.interpolate import PchipInterpolator as pchip @@ -8,9 +7,10 @@ def PCHIP_interpolation(data, **kwargs): 
# Check all necessary arguments sent in required_kws = ["trigger_indices", "interpol_window_sec", "fs"] - assert all( + if not all( [kw in kwargs.keys() for kw in required_kws] - ), "Error. Some KWs not passed into PCA_OBS." + ): + raise RuntimeError("Some KWs not passed into PCA_OBS.") # Extract all kwargs - more elegant ways to do this fs = kwargs["fs"] diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py index e84f9de3b90..442d0371e4e 100644 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import numpy as np -from PCA_OBS import * +from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS from scipy.signal import firls from mne import Epochs diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py index 73682b67867..6c6cd615a7c 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -1,9 +1,10 @@ # Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component # Analysis, Optimal Basis Sets) -from PCA_OBS import * +from matplotlib import pyplot as plt +from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS from scipy.signal import firls - +import numpy as np from mne import Epochs, events_from_annotations from mne.io import read_raw_fif from mne.preprocessing import find_ecg_events diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py index 4c40d8f7234..509c7161f0d 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ 
b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py @@ -1,7 +1,8 @@ # Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component # Analysis, Optimal Basis Sets) -from PCA_OBS import * +from matplotlib import pyplot as plt +from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS from scipy.io import loadmat from scipy.signal import firls From e1b6732750d83d4c04f8df8541fd2d4d38dc141f Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Wed, 23 Oct 2024 19:31:34 +0200 Subject: [PATCH 07/71] refactor: rearrange PCA_OBS arg structure, remove kwarg 'sr' --- mne/preprocessing/pca_obs/PCA_OBS.py | 20 +++++++------------ .../pca_obs/pchip_interpolation.py | 1 + .../rm_heart_artefact_cortical_mnedata.py | 17 ++++++++++------ ...rm_heart_artefact_spinal_impreciserpeak.py | 15 ++++++++------ .../rm_heart_artefact_spinal_preciserpeak.py | 20 ++++++++++--------- 5 files changed, 39 insertions(+), 34 deletions(-) diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py index 691cd259cdf..e30006df104 100755 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ b/mne/preprocessing/pca_obs/PCA_OBS.py @@ -1,13 +1,18 @@ import math import numpy as np -from fit_ecgTemplate import fit_ecgTemplate +from mne.preprocessing.pca_obs.fit_ecgTemplate import fit_ecgTemplate from scipy.signal import detrend, filtfilt from sklearn.decomposition import PCA -def PCA_OBS(data, **kwargs): +def PCA_OBS( + data, + qrs, + filter_coords, + **kwargs +): # Declare class to hold pca information class PCAInfo: def __init__(self): @@ -16,18 +21,9 @@ def __init__(self): # Instantiate class pca_info = PCAInfo() - # Check all necessary arguments sent in - required_kws = ["qrs", "filter_coords", "sr"] - assert all( - [kw in kwargs.keys() for kw in required_kws] - ), "Error. Some KWs not passed into PCA_OBS." 
- # Extract all kwargs qrs = kwargs["qrs"] filter_coords = kwargs["filter_coords"] - sr = kwargs["sr"] - - fs = sr # set to baseline data = data.reshape(-1, 1) @@ -66,7 +62,6 @@ def __init__(self): pa = peak_count # Number of QRS complexes detected while peak_idx[pa - 1, 0] + peak_range > len(data[0]): pa = pa - 1 - steps = 1 * pa peak_count = pa # Filter channel @@ -87,7 +82,6 @@ def __init__(self): mean_effect = np.mean( pcamat, axis=0 ) # [1 x time], contains the mean over all epochs - std_effect = np.std(pcamat, axis=0) # want mean and std of each column dpcamat = detrend(pcamat, type="constant", axis=1) # [time x epoch] ################################################################### diff --git a/mne/preprocessing/pca_obs/pchip_interpolation.py b/mne/preprocessing/pca_obs/pchip_interpolation.py index 941aa6c0e48..466946ecdce 100755 --- a/mne/preprocessing/pca_obs/pchip_interpolation.py +++ b/mne/preprocessing/pca_obs/pchip_interpolation.py @@ -3,6 +3,7 @@ import numpy as np from scipy.interpolate import PchipInterpolator as pchip +# TODO: only place defined. is this used? 
def PCHIP_interpolation(data, **kwargs): # Check all necessary arguments sent in diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py index 442d0371e4e..4df2262cc05 100644 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py @@ -31,7 +31,6 @@ ) = find_ecg_events(raw) # Extract just sample timings of ecg events ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) - # print(ecg_events) # Create filter coefficients fs = raw.info["sfreq"] @@ -42,21 +41,27 @@ fwts = firls(ord + 1, f, a) # For heartbeat epochs - iv_baseline = [-300 / 1000, -200 / 1000] - iv_epoch = [-400 / 1000, 600 / 1000] + iv_baseline = [-300 / 1000, -200 / 1000] # 300 ms before to 200 ms after + iv_epoch = [-400 / 1000, 600 / 1000] # 300 ms before to 200 ms after # run PCA_OBS # Algorithm is extremely sensitive to accurate R-peak timings, won't work as well with the artificial ECG # channel estimation as we have here - PCA_OBS_kwargs = dict(qrs=ecg_event_samples, filter_coords=fwts, sr=fs) - epochs = Epochs( raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline) ) evoked_before = epochs.average() # Apply function should modifies the data in raw in place - raw.apply_function(PCA_OBS, picks="eeg", **PCA_OBS_kwargs, n_jobs=10) + raw.apply_function( + PCA_OBS, + picks="eeg", + n_jobs=10 + **{ # args sent to PCA_OBS + "qrs": ecg_event_samples, + "filter_coords": fwts, + }, + ) epochs = Epochs( raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline) ) diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py index 6c6cd615a7c..62c126ae6e0 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py +++ 
b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -14,7 +14,7 @@ # Set variables subject_id = "sub-001" cond_name = "median" - sampling_rate = 1000 + fs = 1000 # sampling rate esg_chans = [ "S35", "S24", @@ -62,7 +62,7 @@ # Setting paths input_path = "/data/pt_02569/tmp_data/prepared_py/" + subject_id + "/esg/prepro/" - fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" + fname = f"noStimart_sr{fs}_{cond_name}_withqrs_pchip" raw = read_raw_fif(input_path + fname + ".fif", preload=True) # Find ECG events - no ECG channel in data, uses synthetic @@ -75,7 +75,6 @@ ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) # Create filter coefficients - fs = sampling_rate a = [0, 0, 1, 1] f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter @@ -83,8 +82,6 @@ fwts = firls(ord + 1, f, a) # run PCA_OBS - PCA_OBS_kwargs = dict(qrs=ecg_event_samples, filter_coords=fwts, sr=sampling_rate) - events, event_ids = events_from_annotations(raw) event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} epochs = Epochs( @@ -99,7 +96,13 @@ # Apply function should modifies the data in raw in place raw.apply_function( - PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans) + PCA_OBS, + picks=esg_chans, + n_jobs=len(esg_chans), + **{ # args sent to PCA_OBS + "qrs": ecg_event_samples, + "filter_coords": fwts, + }, ) epochs = Epochs( raw, diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py index 509c7161f0d..d72bb058c2a 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py @@ -14,7 +14,7 @@ # Set variables subject_id = "sub-001" cond_name = "median" - sampling_rate = 1000 + fs = 1000 # sampling rate esg_chans = [ "S35", "S24", 
@@ -62,17 +62,15 @@ # Setting paths input_path = "/data/pt_02569/tmp_data/prepared_py/" + subject_id + "/esg/prepro/" - input_path_m = "/data/pt_02569/tmp_data/prepared/" + subject_id + "/esg/prepro/" - fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" + fname = f"noStimart_sr{fs}_{cond_name}_withqrs_pchip" raw = read_raw_fif(input_path + fname + ".fif", preload=True) # Read .mat file with QRS events - fname_m = f"raw_{sampling_rate}_spinal_{cond_name}" + input_path_m = "/data/pt_02569/tmp_data/prepared/" + subject_id + "/esg/prepro/" + fname_m = f"raw_{fs}_spinal_{cond_name}" matdata = loadmat(input_path_m + fname_m + ".mat") - QRSevents_m = matdata["QRSevents"] # Create filter coefficients - fs = sampling_rate a = [0, 0, 1, 1] f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter @@ -80,8 +78,6 @@ fwts = firls(ord + 1, f, a) # run PCA_OBS - PCA_OBS_kwargs = dict(qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate) - events, event_ids = events_from_annotations(raw) event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} epochs = Epochs( @@ -96,7 +92,13 @@ # Apply function should modifies the data in raw in place raw.apply_function( - PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans) + PCA_OBS, + picks=esg_chans, + n_jobs=len(esg_chans) + **{ # args sent to PCA_OBS + "qrs": matdata["QRSevents"], + "filter_coords": fwts, + }, ) epochs = Epochs( raw, From ccd42980d325b6b466ae2450a6899a1c7ae8af0c Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Wed, 23 Oct 2024 20:09:05 +0200 Subject: [PATCH 08/71] refactor: move common variables to module init, further cleanup --- mne/preprocessing/pca_obs/PCA_OBS.py | 14 ++-- mne/preprocessing/pca_obs/__init__.py | 58 ++++++++++++++ .../pca_obs/pchip_interpolation.py | 21 ++--- ...rm_heart_artefact_spinal_impreciserpeak.py | 72 +++-------------- .../rm_heart_artefact_spinal_preciserpeak.py | 
77 ++++--------------- 5 files changed, 99 insertions(+), 143 deletions(-) diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py index e30006df104..cc226113de3 100755 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ b/mne/preprocessing/pca_obs/PCA_OBS.py @@ -1,17 +1,19 @@ import math +from typing import Any import numpy as np from mne.preprocessing.pca_obs.fit_ecgTemplate import fit_ecgTemplate from scipy.signal import detrend, filtfilt from sklearn.decomposition import PCA +from numpy.typing import NDArray def PCA_OBS( - data, - qrs, - filter_coords, - **kwargs + data: NDArray[Any], + qrs: NDArray[Any], + filter_coords: NDArray[Any], + **_ # TODO: are there any other kwargs passed in? ): # Declare class to hold pca information class PCAInfo: @@ -21,10 +23,6 @@ def __init__(self): # Instantiate class pca_info = PCAInfo() - # Extract all kwargs - qrs = kwargs["qrs"] - filter_coords = kwargs["filter_coords"] - # set to baseline data = data.reshape(-1, 1) data = data.T diff --git a/mne/preprocessing/pca_obs/__init__.py b/mne/preprocessing/pca_obs/__init__.py index e69de29bb2d..054c4a3d2aa 100644 --- a/mne/preprocessing/pca_obs/__init__.py +++ b/mne/preprocessing/pca_obs/__init__.py @@ -0,0 +1,58 @@ +from dataclasses import dataclass +from mne.io import Raw, read_raw_fif + + +# TODO: description of what this is +ESG_CHANS = [ + "S35", + "S24", + "S36", + "Iz", + "S17", + "S15", + "S32", + "S22", + "S19", + "S26", + "S28", + "S9", + "S13", + "S11", + "S7", + "SC1", + "S4", + "S18", + "S8", + "S31", + "SC6", + "S12", + "S16", + "S5", + "S30", + "S20", + "S34", + "AC", + "S21", + "S25", + "L1", + "S29", + "S14", + "S33", + "S3", + "AL", + "L4", + "S6", + "S23", +] + +# Set variables +fs = 1000 # sampling rate + +# For heartbeat epochs +iv_baseline = [-300 / 1000, -200 / 1000] +iv_epoch = [-400 / 1000, 600 / 1000] + +# Setting paths +input_path = "/data/pt_02569/tmp_data/prepared_py/sub-001/esg/prepro/" +fname = 
f"noStimart_sr{fs}_median_withqrs_pchip" +raw = read_raw_fif(input_path + fname + ".fif", preload=True) diff --git a/mne/preprocessing/pca_obs/pchip_interpolation.py b/mne/preprocessing/pca_obs/pchip_interpolation.py index 466946ecdce..e4a448eeecf 100755 --- a/mne/preprocessing/pca_obs/pchip_interpolation.py +++ b/mne/preprocessing/pca_obs/pchip_interpolation.py @@ -1,23 +1,18 @@ # Function to interpolate based on PCHIP rather than MNE inbuilt linear option +from typing import Any import numpy as np from scipy.interpolate import PchipInterpolator as pchip +from numpy.typing import NDArray # TODO: only place defined. is this used? -def PCHIP_interpolation(data, **kwargs): - # Check all necessary arguments sent in - required_kws = ["trigger_indices", "interpol_window_sec", "fs"] - if not all( - [kw in kwargs.keys() for kw in required_kws] - ): - raise RuntimeError("Some KWs not passed into PCA_OBS.") - - # Extract all kwargs - more elegant ways to do this - fs = kwargs["fs"] - interpol_window_sec = kwargs["interpol_window_sec"] - trigger_indices = kwargs["trigger_indices"] - +def PCHIP_interpolation( + data: NDArray[Any], + trigger_indices, + interpol_window_sec, + fs, +): # Convert intpol window to msec then convert to samples pre_window = round((interpol_window_sec[0] * 1000) * fs / 1000) # in samples post_window = round((interpol_window_sec[1] * 1000) * fs / 1000) # in samples diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py index 62c126ae6e0..f57221e0ed0 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -3,6 +3,13 @@ from matplotlib import pyplot as plt from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS +from mne.preprocessing.pca_obs import ESG_CHANS +from mne.preprocessing.pca_obs import ( + fs, + iv_baseline, + iv_epoch, + raw, +) from 
scipy.signal import firls import numpy as np from mne import Epochs, events_from_annotations @@ -10,60 +17,6 @@ from mne.preprocessing import find_ecg_events if __name__ == "__main__": - # Incredibly slow without parallelization - # Set variables - subject_id = "sub-001" - cond_name = "median" - fs = 1000 # sampling rate - esg_chans = [ - "S35", - "S24", - "S36", - "Iz", - "S17", - "S15", - "S32", - "S22", - "S19", - "S26", - "S28", - "S9", - "S13", - "S11", - "S7", - "SC1", - "S4", - "S18", - "S8", - "S31", - "SC6", - "S12", - "S16", - "S5", - "S30", - "S20", - "S34", - "AC", - "S21", - "S25", - "L1", - "S29", - "S14", - "S33", - "S3", - "AL", - "L4", - "S6", - "S23", - ] - # For heartbeat epochs - iv_baseline = [-300 / 1000, -200 / 1000] - iv_epoch = [-400 / 1000, 600 / 1000] - - # Setting paths - input_path = "/data/pt_02569/tmp_data/prepared_py/" + subject_id + "/esg/prepro/" - fname = f"noStimart_sr{fs}_{cond_name}_withqrs_pchip" - raw = read_raw_fif(input_path + fname + ".fif", preload=True) # Find ECG events - no ECG channel in data, uses synthetic ( @@ -97,12 +50,11 @@ # Apply function should modifies the data in raw in place raw.apply_function( PCA_OBS, - picks=esg_chans, - n_jobs=len(esg_chans), - **{ # args sent to PCA_OBS - "qrs": ecg_event_samples, - "filter_coords": fwts, - }, + picks=ESG_CHANS, + n_jobs=len(ESG_CHANS), + # args sent to PCA_OBS + qrs=ecg_event_samples, + filter_coords=fwts, ) epochs = Epochs( raw, diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py index d72bb058c2a..f1807829642 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py @@ -5,69 +5,23 @@ from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS from scipy.io import loadmat from scipy.signal import firls - +from mne.preprocessing.pca_obs import ESG_CHANS +from mne.preprocessing.pca_obs 
import ( + fs, + iv_baseline, + iv_epoch, + raw, +) from mne import Epochs, events_from_annotations from mne.io import read_raw_fif + if __name__ == "__main__": - # Incredibly slow without parallelization - # Set variables - subject_id = "sub-001" - cond_name = "median" - fs = 1000 # sampling rate - esg_chans = [ - "S35", - "S24", - "S36", - "Iz", - "S17", - "S15", - "S32", - "S22", - "S19", - "S26", - "S28", - "S9", - "S13", - "S11", - "S7", - "SC1", - "S4", - "S18", - "S8", - "S31", - "SC6", - "S12", - "S16", - "S5", - "S30", - "S20", - "S34", - "AC", - "S21", - "S25", - "L1", - "S29", - "S14", - "S33", - "S3", - "AL", - "L4", - "S6", - "S23", - ] - # For heartbeat epochs - iv_baseline = [-300 / 1000, -200 / 1000] - iv_epoch = [-400 / 1000, 600 / 1000] - # Setting paths - input_path = "/data/pt_02569/tmp_data/prepared_py/" + subject_id + "/esg/prepro/" - fname = f"noStimart_sr{fs}_{cond_name}_withqrs_pchip" - raw = read_raw_fif(input_path + fname + ".fif", preload=True) # Read .mat file with QRS events - input_path_m = "/data/pt_02569/tmp_data/prepared/" + subject_id + "/esg/prepro/" - fname_m = f"raw_{fs}_spinal_{cond_name}" + input_path_m = "/data/pt_02569/tmp_data/prepared/sub-001/esg/prepro/" + fname_m = f"raw_1000_spinal_median" matdata = loadmat(input_path_m + fname_m + ".mat") # Create filter coefficients @@ -93,12 +47,11 @@ # Apply function should modifies the data in raw in place raw.apply_function( PCA_OBS, - picks=esg_chans, - n_jobs=len(esg_chans) - **{ # args sent to PCA_OBS - "qrs": matdata["QRSevents"], - "filter_coords": fwts, - }, + picks=ESG_CHANS, + n_jobs=len(ESG_CHANS), + # args sent to PCA_OBS + qrs=matdata["QRSevents"], + filter_coords=fwts, ) epochs = Epochs( raw, From 61a92eaebb332c19204f0181ea95f118b9b975a8 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Fri, 25 Oct 2024 13:09:15 +0200 Subject: [PATCH 09/71] feat/initial-cleanup: Remove custom pchip as not in use --- .../pca_obs/pchip_interpolation.py | 51 ------------------- 1 file 
changed, 51 deletions(-) delete mode 100755 mne/preprocessing/pca_obs/pchip_interpolation.py diff --git a/mne/preprocessing/pca_obs/pchip_interpolation.py b/mne/preprocessing/pca_obs/pchip_interpolation.py deleted file mode 100755 index e4a448eeecf..00000000000 --- a/mne/preprocessing/pca_obs/pchip_interpolation.py +++ /dev/null @@ -1,51 +0,0 @@ -# Function to interpolate based on PCHIP rather than MNE inbuilt linear option - -from typing import Any -import numpy as np -from scipy.interpolate import PchipInterpolator as pchip -from numpy.typing import NDArray - -# TODO: only place defined. is this used? - -def PCHIP_interpolation( - data: NDArray[Any], - trigger_indices, - interpol_window_sec, - fs, -): - # Convert intpol window to msec then convert to samples - pre_window = round((interpol_window_sec[0] * 1000) * fs / 1000) # in samples - post_window = round((interpol_window_sec[1] * 1000) * fs / 1000) # in samples - intpol_window = np.ceil([pre_window, post_window]).astype( - int - ) # interpolation window - - n_samples_fit = ( - 5 # number of samples before and after cut used for interpolation fit - ) - - x_fit_raw = np.concatenate( - [ - np.arange(intpol_window[0] - n_samples_fit - 1, intpol_window[0], 1), - np.arange(intpol_window[1] + 1, intpol_window[1] + n_samples_fit + 2, 1), - ] - ) - x_interpol_raw = np.arange( - intpol_window[0], intpol_window[1] + 1, 1 - ) # points to be interpolated; in pt - - for ii in np.arange(0, len(trigger_indices)): # loop through all stimulation events - x_fit = trigger_indices[ii] + x_fit_raw # fit point latencies for this event - x_interpol = ( - trigger_indices[ii] + x_interpol_raw - ) # latencies for to-be-interpolated data points - - # Data is just a string of values - y_fit = data[x_fit] # y values to be fitted - y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation - data[x_interpol] = y_interpol # replace in data - - if np.mod(ii, 100) == 0: # talk to the operator every 100th trial - print(f"stimulation 
event {ii} \n") - - return data From a1eb6f6ee4c153914a140fd436cc8b86ac746fdb Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Fri, 25 Oct 2024 13:15:39 +0200 Subject: [PATCH 10/71] refactor/initial-cleanup: answer question --- mne/preprocessing/pca_obs/PCA_OBS.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py index cc226113de3..541302853d6 100755 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ b/mne/preprocessing/pca_obs/PCA_OBS.py @@ -13,7 +13,7 @@ def PCA_OBS( data: NDArray[Any], qrs: NDArray[Any], filter_coords: NDArray[Any], - **_ # TODO: are there any other kwargs passed in? + **_ # TODO: are there any other kwargs passed in? No - this is all I believe ): # Declare class to hold pca information class PCAInfo: From d75e92715950441257e1d51884bbd452c51330f2 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 4 Nov 2024 17:42:11 +0100 Subject: [PATCH 11/71] refactor: rename files to match other modules in preprocessing, reduce size of indentation tree --- mne/preprocessing/pca_obs/__init__.py | 13 ++- mne/preprocessing/pca_obs/fit_ecgTemplate.py | 87 ++++++++++--------- .../pca_obs/{PCA_OBS.py => pca_obs.py} | 3 +- .../rm_heart_artefact_cortical_mnedata.py | 4 +- ...rm_heart_artefact_spinal_impreciserpeak.py | 4 +- .../rm_heart_artefact_spinal_preciserpeak.py | 4 +- 6 files changed, 63 insertions(+), 52 deletions(-) rename mne/preprocessing/pca_obs/{PCA_OBS.py => pca_obs.py} (98%) diff --git a/mne/preprocessing/pca_obs/__init__.py b/mne/preprocessing/pca_obs/__init__.py index 054c4a3d2aa..dccbfab698a 100644 --- a/mne/preprocessing/pca_obs/__init__.py +++ b/mne/preprocessing/pca_obs/__init__.py @@ -1,6 +1,15 @@ -from dataclasses import dataclass -from mne.io import Raw, read_raw_fif +"""Principle Component Analysis of OBS (PCA-OBS).""" # TODO: What's OBS? +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. 
+ +from mne.io import read_raw_fif +from .pca_obs import pca_obs + +__all__ = [ + "pca_obs" +] # TODO: description of what this is ESG_CHANS = [ diff --git a/mne/preprocessing/pca_obs/fit_ecgTemplate.py b/mne/preprocessing/pca_obs/fit_ecgTemplate.py index 3e5b7f55cb2..0aab0e15157 100755 --- a/mne/preprocessing/pca_obs/fit_ecgTemplate.py +++ b/mne/preprocessing/pca_obs/fit_ecgTemplate.py @@ -13,7 +13,7 @@ def fit_ecgTemplate( baseline_range, midP, fitted_art, - post_idx_previousPeak, + post_idx_previousPeak: list, n_samples_fit, ): # Declare class to hold ecg fit information @@ -22,6 +22,7 @@ def __init__(self): pass # Instantiate class + # TODO: Why are we storing this to a class? Can't we just use the variables and write to them? fitecg = fitECG() # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak @@ -53,46 +54,48 @@ def __init__(self): post_idx_nextPeak = [aPeak_idx[0] + post_range] - # Check it's not empty - if len(post_idx_previousPeak) != 0: - # interpolate time between peaks - intpol_window = np.ceil( - [post_idx_previousPeak[0], aPeak_idx[0] - pre_range] - ).astype("int") # interpolation window - fitecg.intpol_window = intpol_window - - if intpol_window[0] < intpol_window[1]: - # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data - - # You have x_fit which is two slices on either side of the interpolation window endpoints - # You have y_fit which is the y vals corresponding to x values above - # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate - # You have y_interpol which is values from pchip at the time points specified in x_interpol - x_interpol = np.arange( - intpol_window[0], intpol_window[1] + 1, 1 - ) # points to be interpolated in pt - the gap between the endpoints of the window - x_fit = np.concatenate( - [ - np.arange( - intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1 - ), - np.arange( - intpol_window[1], intpol_window[1] + 
n_samples_fit + 1, 1 - ), - ] - ) # Entire range of x values in this step (taking some number of samples before and after the window) - y_fit = fitted_art[0, x_fit] - y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation - - # Then make fitted artefact in the desired range equal to the completed fit above - fitted_art[0, post_idx_previousPeak[0] : aPeak_idx[0] - pre_range + 1] = ( - y_interpol - ) - - fitecg.x_fit = x_fit - fitecg.y_fit = y_fit - fitecg.x_interpol = x_interpol - fitecg.y_interpol = y_interpol - fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop + # if last peak, return + if not post_idx_previousPeak: + return fitted_art, post_idx_nextPeak + + # interpolate time between peaks + intpol_window = np.ceil( + [post_idx_previousPeak[0], aPeak_idx[0] - pre_range] + ).astype("int") # interpolation window + fitecg.intpol_window = intpol_window + + if intpol_window[0] < intpol_window[1]: + # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data + + # You have x_fit which is two slices on either side of the interpolation window endpoints + # You have y_fit which is the y vals corresponding to x values above + # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate + # You have y_interpol which is values from pchip at the time points specified in x_interpol + x_interpol = np.arange( + intpol_window[0], intpol_window[1] + 1, 1 + ) # points to be interpolated in pt - the gap between the endpoints of the window + x_fit = np.concatenate( + [ + np.arange( + intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1 + ), + np.arange( + intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1 + ), + ] + ) # Entire range of x values in this step (taking some number of samples before and after the window) + y_fit = fitted_art[0, x_fit] + y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation + + # Then make fitted artefact in the desired range equal 
to the completed fit above + fitted_art[0, post_idx_previousPeak[0] : aPeak_idx[0] - pre_range + 1] = ( + y_interpol + ) + + fitecg.x_fit = x_fit + fitecg.y_fit = y_fit + fitecg.x_interpol = x_interpol + fitecg.y_interpol = y_interpol + fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop return fitted_art, post_idx_nextPeak diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/pca_obs.py similarity index 98% rename from mne/preprocessing/pca_obs/PCA_OBS.py rename to mne/preprocessing/pca_obs/pca_obs.py index 541302853d6..6fcf5882cbb 100755 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ b/mne/preprocessing/pca_obs/pca_obs.py @@ -9,11 +9,10 @@ from numpy.typing import NDArray -def PCA_OBS( +def pca_obs( data: NDArray[Any], qrs: NDArray[Any], filter_coords: NDArray[Any], - **_ # TODO: are there any other kwargs passed in? No - this is all I believe ): # Declare class to hold pca information class PCAInfo: diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py index 4df2262cc05..f0e01d6d8d4 100644 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import numpy as np -from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS +from mne.preprocessing.pca_obs import pca_obs from scipy.signal import firls from mne import Epochs @@ -54,7 +54,7 @@ # Apply function should modifies the data in raw in place raw.apply_function( - PCA_OBS, + pca_obs, picks="eeg", n_jobs=10 **{ # args sent to PCA_OBS diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py index f57221e0ed0..5183054dc21 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py +++ 
b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -2,7 +2,7 @@ # Analysis, Optimal Basis Sets) from matplotlib import pyplot as plt -from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS +from mne.preprocessing.pca_obs import pca_obs from mne.preprocessing.pca_obs import ESG_CHANS from mne.preprocessing.pca_obs import ( fs, @@ -49,7 +49,7 @@ # Apply function should modifies the data in raw in place raw.apply_function( - PCA_OBS, + pca_obs, picks=ESG_CHANS, n_jobs=len(ESG_CHANS), # args sent to PCA_OBS diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py index f1807829642..f561fab69a0 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py @@ -2,7 +2,7 @@ # Analysis, Optimal Basis Sets) from matplotlib import pyplot as plt -from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS +from mne.preprocessing.pca_obs import pca_obs from scipy.io import loadmat from scipy.signal import firls from mne.preprocessing.pca_obs import ESG_CHANS @@ -46,7 +46,7 @@ # Apply function should modifies the data in raw in place raw.apply_function( - PCA_OBS, + pca_obs, picks=ESG_CHANS, n_jobs=len(ESG_CHANS), # args sent to PCA_OBS From 035a9f60a3484483bb23428a2e860660f5f01f43 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 4 Nov 2024 17:48:18 +0100 Subject: [PATCH 12/71] refactor: remove unused fit_ecgTemplate variable 'baselinerange' --- mne/preprocessing/pca_obs/fit_ecgTemplate.py | 1 - mne/preprocessing/pca_obs/pca_obs.py | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mne/preprocessing/pca_obs/fit_ecgTemplate.py b/mne/preprocessing/pca_obs/fit_ecgTemplate.py index 0aab0e15157..51665ed4d03 100755 --- a/mne/preprocessing/pca_obs/fit_ecgTemplate.py +++ b/mne/preprocessing/pca_obs/fit_ecgTemplate.py @@ -10,7 +10,6 @@ def fit_ecgTemplate( 
peak_range, pre_range, post_range, - baseline_range, midP, fitted_art, post_idx_previousPeak: list, diff --git a/mne/preprocessing/pca_obs/pca_obs.py b/mne/preprocessing/pca_obs/pca_obs.py index 6fcf5882cbb..846c4a0eab3 100755 --- a/mne/preprocessing/pca_obs/pca_obs.py +++ b/mne/preprocessing/pca_obs/pca_obs.py @@ -9,6 +9,7 @@ from numpy.typing import NDArray +# TODO: Are we able to split this into smaller segmented functions? def pca_obs( data: NDArray[Any], qrs: NDArray[Any], @@ -19,6 +20,8 @@ class PCAInfo: def __init__(self): pass + # NOTE: Here aswell, is there a reason we are storing this + # to a class? Shouldn't variables suffice? # Instantiate class pca_info = PCAInfo() @@ -96,7 +99,7 @@ def __init__(self): pca_info.expl_var = pca.explained_variance_ratio_ # define selected number of components using profile likelihood - pca_info.nComponents = 4 + pca_info.nComponents = 4 # TODO: Is this a variable? Or constant? Seems like a variable pca_info.meanEffect = mean_effect.T nComponents = pca_info.nComponents @@ -127,7 +130,6 @@ def __init__(self): peak_range, pre_range, post_range, - baseline_range, midP, fitted_art, post_idx_nextPeak, @@ -154,7 +156,6 @@ def __init__(self): peak_range, pre_range, post_range, - baseline_range, midP, fitted_art, post_idx_nextPeak, @@ -187,7 +188,6 @@ def __init__(self): peak_range, pre_range, post_range, - baseline_range, midP, fitted_art, post_idx_nextPeak, From 2debb614afd4201644e6ec7953a288218d104dbf Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 4 Nov 2024 17:55:06 +0100 Subject: [PATCH 13/71] refactor: make main methods private, import from module __init__, remove more unused variables and imports, add some types --- mne/preprocessing/pca_obs/__init__.py | 6 ++-- ...it_ecgTemplate.py => _fit_ecg_template.py} | 29 +++++++++++++++++-- .../pca_obs/{pca_obs.py => _pca_obs.py} | 16 +++++----- ...rm_heart_artefact_spinal_impreciserpeak.py | 3 +- .../rm_heart_artefact_spinal_preciserpeak.py | 3 +- 5 files changed, 40 
insertions(+), 17 deletions(-) rename mne/preprocessing/pca_obs/{fit_ecgTemplate.py => _fit_ecg_template.py} (80%) rename mne/preprocessing/pca_obs/{pca_obs.py => _pca_obs.py} (94%) diff --git a/mne/preprocessing/pca_obs/__init__.py b/mne/preprocessing/pca_obs/__init__.py index dccbfab698a..88f6276264b 100644 --- a/mne/preprocessing/pca_obs/__init__.py +++ b/mne/preprocessing/pca_obs/__init__.py @@ -5,10 +5,12 @@ # Copyright the MNE-Python contributors. from mne.io import read_raw_fif -from .pca_obs import pca_obs +from ._pca_obs import pca_obs +from ._fit_ecg_template import fit_ecg_template __all__ = [ - "pca_obs" + "pca_obs", + "fit_ecg_template" ] # TODO: description of what this is diff --git a/mne/preprocessing/pca_obs/fit_ecgTemplate.py b/mne/preprocessing/pca_obs/_fit_ecg_template.py similarity index 80% rename from mne/preprocessing/pca_obs/fit_ecgTemplate.py rename to mne/preprocessing/pca_obs/_fit_ecg_template.py index 51665ed4d03..631bd0598f2 100755 --- a/mne/preprocessing/pca_obs/fit_ecgTemplate.py +++ b/mne/preprocessing/pca_obs/_fit_ecg_template.py @@ -3,7 +3,7 @@ from scipy.signal import detrend -def fit_ecgTemplate( +def fit_ecg_template( data, pca_template, aPeak_idx, @@ -14,7 +14,32 @@ def fit_ecgTemplate( fitted_art, post_idx_previousPeak: list, n_samples_fit, -): +) -> tuple[np.ndarray, list]: + """TODO: Write docstring about what we do here. + Fits the ECG to a template signal (?) + and returns the fitted artefact and the index of the next peak. (?) + + .. note:: This should only be used on data which is ... (TODO: are there any conditions that must be met to use our algos?) 
+ + + # TODO: Fill out input/output and raises + Parameters + ---------- + data (_type_): _description_ + pca_template (_type_): _description_ + aPeak_idx (_type_): _description_ + peak_range (_type_): _description_ + pre_range (_type_): _description_ + post_range (_type_): _description_ + midP (_type_): _description_ + fitted_art (_type_): _description_ + post_idx_previousPeak (list): _description_ + n_samples_fit (_type_): _description_ + + Returns + ------- + _type_: _description_ + """ # Declare class to hold ecg fit information class fitECG: def __init__(self): diff --git a/mne/preprocessing/pca_obs/pca_obs.py b/mne/preprocessing/pca_obs/_pca_obs.py similarity index 94% rename from mne/preprocessing/pca_obs/pca_obs.py rename to mne/preprocessing/pca_obs/_pca_obs.py index 846c4a0eab3..3c11e293bd7 100755 --- a/mne/preprocessing/pca_obs/pca_obs.py +++ b/mne/preprocessing/pca_obs/_pca_obs.py @@ -2,18 +2,17 @@ from typing import Any import numpy as np -from mne.preprocessing.pca_obs.fit_ecgTemplate import fit_ecgTemplate +from mne.preprocessing.pca_obs import fit_ecg_template from scipy.signal import detrend, filtfilt from sklearn.decomposition import PCA -from numpy.typing import NDArray # TODO: Are we able to split this into smaller segmented functions? 
def pca_obs( - data: NDArray[Any], - qrs: NDArray[Any], - filter_coords: NDArray[Any], + data: np.ndarray, + qrs: np.ndarray, + filter_coords: np.ndarray, ): # Declare class to hold pca information class PCAInfo: @@ -53,7 +52,6 @@ def __init__(self): mRR = np.median(RR) peak_range = round(mRR / 2) # Rounds to an integer midP = peak_range + 1 - baseline_range = [0, round(peak_range / 8)] n_samples_fit = round( peak_range / 8 ) # sample fit for interpolation between fitted artifact windows @@ -123,7 +121,7 @@ def __init__(self): post_range = peak_range try: post_idx_nextPeak = [] - fitted_art, post_idx_nextPeak = fit_ecgTemplate( + fitted_art, post_idx_nextPeak = fit_ecg_template( data, pca_template, peak_idx[p], @@ -149,7 +147,7 @@ def __init__(self): post_range = peak_range if pre_range > peak_range: pre_range = peak_range - fitted_art, _ = fit_ecgTemplate( + fitted_art, _ = fit_ecg_template( data, pca_template, peak_idx(p), @@ -181,7 +179,7 @@ def __init__(self): aTemplate = pca_template[ midP - peak_range - 1 : midP + peak_range + 1, : ] - fitted_art, post_idx_nextPeak = fit_ecgTemplate( + fitted_art, post_idx_nextPeak = fit_ecg_template( data, aTemplate, peak_idx[p], diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py index 5183054dc21..10dce937e21 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -1,4 +1,4 @@ -# Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component +# Calls PCA_OBS which in turn calls fit_ecg_template to remove the heart artefact via PCA_OBS (Principal Component # Analysis, Optimal Basis Sets) from matplotlib import pyplot as plt @@ -13,7 +13,6 @@ from scipy.signal import firls import numpy as np from mne import Epochs, events_from_annotations -from mne.io import read_raw_fif from 
mne.preprocessing import find_ecg_events if __name__ == "__main__": diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py index f561fab69a0..21723fc6c93 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py @@ -1,4 +1,4 @@ -# Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component +# Calls PCA_OBS which in turn calls fit_ecg_template to remove the heart artefact via PCA_OBS (Principal Component # Analysis, Optimal Basis Sets) from matplotlib import pyplot as plt @@ -13,7 +13,6 @@ raw, ) from mne import Epochs, events_from_annotations -from mne.io import read_raw_fif if __name__ == "__main__": From 97f9a15f572933f8917d560e7bc4f5bbeba9e541 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 4 Nov 2024 18:09:04 +0100 Subject: [PATCH 14/71] refactor: add docstring templates, gather imports --- mne/preprocessing/pca_obs/_fit_ecg_template.py | 6 +++--- mne/preprocessing/pca_obs/_pca_obs.py | 18 +++++++++++++++++- .../rm_heart_artefact_cortical_mnedata.py | 9 ++++----- .../rm_heart_artefact_spinal_impreciserpeak.py | 4 ++-- .../rm_heart_artefact_spinal_preciserpeak.py | 4 ++-- 5 files changed, 28 insertions(+), 13 deletions(-) diff --git a/mne/preprocessing/pca_obs/_fit_ecg_template.py b/mne/preprocessing/pca_obs/_fit_ecg_template.py index 631bd0598f2..23f5a512539 100755 --- a/mne/preprocessing/pca_obs/_fit_ecg_template.py +++ b/mne/preprocessing/pca_obs/_fit_ecg_template.py @@ -19,8 +19,8 @@ def fit_ecg_template( Fits the ECG to a template signal (?) and returns the fitted artefact and the index of the next peak. (?) - .. note:: This should only be used on data which is ... (TODO: are there any conditions that must be met to use our algos?) - + (TODO: are there any conditions that must be met to use our algos?) + .. 
note:: This should only be used on data which is ... # TODO: Fill out input/output and raises Parameters @@ -38,7 +38,7 @@ def fit_ecg_template( Returns ------- - _type_: _description_ + tuple[np.ndarray, list]: the fitted artifact and the next peak index (if available) """ # Declare class to hold ecg fit information class fitECG: diff --git a/mne/preprocessing/pca_obs/_pca_obs.py b/mne/preprocessing/pca_obs/_pca_obs.py index 3c11e293bd7..c6d9f752d69 100755 --- a/mne/preprocessing/pca_obs/_pca_obs.py +++ b/mne/preprocessing/pca_obs/_pca_obs.py @@ -13,7 +13,23 @@ def pca_obs( data: np.ndarray, qrs: np.ndarray, filter_coords: np.ndarray, -): +) -> np.ndarray: + """ + Algorithm to perform the PCA OBS (Principal Component Analysis, Optimal Basis Sets) + algorithm to remove the heart artefact from EEG data. + + .. note:: This should only be used on data which is ... (TODO: are there any conditions that must be met to use our algos?) + + Parameters + ---------- + data (np.ndarray): The data which we want to remove the heart artefact from. + qrs (np.ndarray): _description_ + filter_coords (np.ndarray): _description_ + + Returns + ------- + np.ndarray: The data with the heart artefact removed. 
+ """ # Declare class to hold pca information class PCAInfo: def __init__(self): diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py index f0e01d6d8d4..10cb7772bbf 100644 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py @@ -56,11 +56,10 @@ raw.apply_function( pca_obs, picks="eeg", - n_jobs=10 - **{ # args sent to PCA_OBS - "qrs": ecg_event_samples, - "filter_coords": fwts, - }, + n_jobs=10, + # args sent to PCA_OBS + qrs=ecg_event_samples, + filter_coords=fwts, ) epochs = Epochs( raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline) diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py index 10dce937e21..0f0df7d6c75 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -2,13 +2,13 @@ # Analysis, Optimal Basis Sets) from matplotlib import pyplot as plt -from mne.preprocessing.pca_obs import pca_obs -from mne.preprocessing.pca_obs import ESG_CHANS from mne.preprocessing.pca_obs import ( fs, iv_baseline, iv_epoch, raw, + pca_obs, + ESG_CHANS, ) from scipy.signal import firls import numpy as np diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py index 21723fc6c93..2a195e85a40 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py @@ -2,15 +2,15 @@ # Analysis, Optimal Basis Sets) from matplotlib import pyplot as plt -from mne.preprocessing.pca_obs import pca_obs from scipy.io import loadmat from scipy.signal import firls -from mne.preprocessing.pca_obs import ESG_CHANS 
from mne.preprocessing.pca_obs import ( fs, iv_baseline, iv_epoch, raw, + pca_obs, + ESG_CHANS, ) from mne import Epochs, events_from_annotations From 34a2e060bcbd5c64d1bbfce2a39b1b00881ca614 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 4 Nov 2024 18:41:21 +0100 Subject: [PATCH 15/71] test: add placeholder tests, add multiple todos to address where we get data from, how we call functions, how we assert outputs --- mne/preprocessing/pca_obs/tests/__init__.py | 0 .../pca_obs/tests/test_fit_ecg.py | 39 ++++++++++++++++ .../pca_obs/tests/test_pca_obs.py | 44 +++++++++++++++++++ 3 files changed, 83 insertions(+) create mode 100644 mne/preprocessing/pca_obs/tests/__init__.py create mode 100644 mne/preprocessing/pca_obs/tests/test_fit_ecg.py create mode 100644 mne/preprocessing/pca_obs/tests/test_pca_obs.py diff --git a/mne/preprocessing/pca_obs/tests/__init__.py b/mne/preprocessing/pca_obs/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/mne/preprocessing/pca_obs/tests/test_fit_ecg.py b/mne/preprocessing/pca_obs/tests/test_fit_ecg.py new file mode 100644 index 00000000000..eb5c80bc8d8 --- /dev/null +++ b/mne/preprocessing/pca_obs/tests/test_fit_ecg.py @@ -0,0 +1,39 @@ +"""Test the fot_ecg_template function.""" + +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +from mne.io import read_raw_fif +from mne.preprocessing.pca_obs import fit_ecg_template +from mne.datasets.testing import data_path, requires_testing_data + +# TODO: Where are the test files we want to use located? 
+fname = data_path(download=False) / "eyetrack" / "test_eyelink.asc" + + +@requires_testing_data +def test_fit_ecg_template(): + """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" + raw = read_raw_fif(fname) + + # Somehow have to "fake" all these inputs to the function + result = fit_ecg_template( + data=None, + pca_template=None, + aPeak_idx=None, + peak_range=None, + pre_range=None, + post_range=None, + midP=None, + fitted_art=None, + post_idx_previousPeak=None, + n_samples_fit=None, + ) + + # assert results + assert result is not None + assert result.shape == (100, 100) + assert result.shape == raw.shape # is this a condition we can test? + assert result[0, 0] == 1.0 + ... \ No newline at end of file diff --git a/mne/preprocessing/pca_obs/tests/test_pca_obs.py b/mne/preprocessing/pca_obs/tests/test_pca_obs.py new file mode 100644 index 00000000000..06d45062d6d --- /dev/null +++ b/mne/preprocessing/pca_obs/tests/test_pca_obs.py @@ -0,0 +1,44 @@ +"""Test the ieeg projection functions.""" + +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +# TODO: migrate this structure to test out function + +import pytest + +from mne.io import read_raw_fif +from mne.preprocessing.pca_obs import pca_obs +from mne.datasets.testing import data_path, requires_testing_data + +# TODO: Where are the test files we want to use located? +fname = data_path(download=False) / "eyetrack" / "test_eyelink.asc" + +@requires_testing_data +@pytest.mark.parametrize( + # TODO: Are there any parameters we can cycle through to + # test multiple? Different fs, windows, highpass freqs, etc.? + # TODO: how do we determine qrs and filter_coords? What are these? 
+ "fs, highpass_freq, qrs, filter_coords", + [ + (0.2, 1.0, 100, 200), + (0.1, 2.0, 100, 200), + ], +) +def test_heart_artifact_removal(fs, highpass_freq, qrs, filter_coords): + """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" + raw = read_raw_fif(fname) + + # Do something with fs and highpass as processing of the data? + ... + + # call pca_obs algorithm + result = pca_obs(raw, qrs=qrs, filter_coords=filter_coords) + + # assert results + assert result is not None + assert result.shape == (100, 100) + assert result.shape == raw.shape # is this a condition we can test? + assert result[0, 0] == 1.0 + ... \ No newline at end of file From 3760c02efda91f709dbab4aaf1277ecbafce050c Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 11 Nov 2024 12:29:09 +0100 Subject: [PATCH 16/71] Removing extra examples --- .../rm_heart_artefact_cortical_mnedata.py | 86 ------------------- .../rm_heart_artefact_spinal_preciserpeak.py | 82 ------------------ 2 files changed, 168 deletions(-) delete mode 100644 mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py delete mode 100755 mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py deleted file mode 100644 index 10cb7772bbf..00000000000 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py +++ /dev/null @@ -1,86 +0,0 @@ -# Checking algorithm implementation with EEG data form MNE sample datasets -# Following this tutorial data: https://mne.tools/stable/auto_tutorials/preprocessing/50_artifact_correction_ssp.html#what-is-ssp - -import os - -import matplotlib.pyplot as plt -import numpy as np -from mne.preprocessing.pca_obs import pca_obs -from scipy.signal import firls - -from mne import Epochs -from mne.datasets.sample import data_path -from mne.io import read_raw_fif -from mne.preprocessing import ( - find_ecg_events, -) - -if 
__name__ == "__main__": - sample_data_folder = data_path(path="/data/pt_02569/mne_test_data/") - sample_data_raw_file = os.path.join( - sample_data_folder, "MEG", "sample", "sample_audvis_raw.fif" - ) - # here we crop and resample just for speed - raw = read_raw_fif(sample_data_raw_file, preload=True) - - # Find ECG events - no ECG channel in data, uses synthetic - ( - ecg_events, - ch_ecg, - average_pulse, - ) = find_ecg_events(raw) - # Extract just sample timings of ecg events - ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) - - # Create filter coefficients - fs = raw.info["sfreq"] - a = [0, 0, 1, 1] - f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter - # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter - ord = round(3 * fs / 0.5) - fwts = firls(ord + 1, f, a) - - # For heartbeat epochs - iv_baseline = [-300 / 1000, -200 / 1000] # 300 ms before to 200 ms after - iv_epoch = [-400 / 1000, 600 / 1000] # 300 ms before to 200 ms after - - # run PCA_OBS - # Algorithm is extremely sensitive to accurate R-peak timings, won't work as well with the artificial ECG - # channel estimation as we have here - epochs = Epochs( - raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline) - ) - evoked_before = epochs.average() - - # Apply function should modifies the data in raw in place - raw.apply_function( - pca_obs, - picks="eeg", - n_jobs=10, - # args sent to PCA_OBS - qrs=ecg_event_samples, - filter_coords=fwts, - ) - epochs = Epochs( - raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline) - ) - evoked_after = epochs.average() - - # Comparison image - fig, axes = plt.subplots(2, 1) - axes[0].plot(evoked_before.times, evoked_before.get_data().T) - axes[0].set_ylim([-1e-5, 3e-5]) - axes[0].set_title("Before PCA-OBS") - axes[1].plot(evoked_after.times, evoked_after.get_data().T) - axes[1].set_ylim([-1e-5, 3e-5]) - axes[1].set_title("After PCA-OBS") - 
plt.tight_layout() - - # Comparison image - fig, axes = plt.subplots(1, 1) - axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") - axes.set_ylim([-1e-5, 3e-5]) - axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") - axes.set_title("Before (black) versus after (green)") - plt.tight_layout() - plt.show() diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py deleted file mode 100755 index 2a195e85a40..00000000000 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ /dev/null @@ -1,82 +0,0 @@ -# Calls PCA_OBS which in turn calls fit_ecg_template to remove the heart artefact via PCA_OBS (Principal Component -# Analysis, Optimal Basis Sets) - -from matplotlib import pyplot as plt -from scipy.io import loadmat -from scipy.signal import firls -from mne.preprocessing.pca_obs import ( - fs, - iv_baseline, - iv_epoch, - raw, - pca_obs, - ESG_CHANS, -) -from mne import Epochs, events_from_annotations - - -if __name__ == "__main__": - - - # Read .mat file with QRS events - input_path_m = "/data/pt_02569/tmp_data/prepared/sub-001/esg/prepro/" - fname_m = f"raw_1000_spinal_median" - matdata = loadmat(input_path_m + fname_m + ".mat") - - # Create filter coefficients - a = [0, 0, 1, 1] - f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter - # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter - ord = round(3 * fs / 0.5) - fwts = firls(ord + 1, f, a) - - # run PCA_OBS - events, event_ids = events_from_annotations(raw) - event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} - epochs = Epochs( - raw, - events, - event_id=event_id_dict, - tmin=iv_epoch[0], - tmax=iv_epoch[1], - baseline=tuple(iv_baseline), - ) - evoked_before = epochs.average() - - # Apply function should modifies the data in raw in place - raw.apply_function( - pca_obs, - picks=ESG_CHANS, - 
n_jobs=len(ESG_CHANS), - # args sent to PCA_OBS - qrs=matdata["QRSevents"], - filter_coords=fwts, - ) - epochs = Epochs( - raw, - events, - event_id=event_id_dict, - tmin=iv_epoch[0], - tmax=iv_epoch[1], - baseline=tuple(iv_baseline), - ) - evoked_after = epochs.average() - - # Comparison image - fig, axes = plt.subplots(2, 1) - axes[0].plot(evoked_before.times, evoked_before.get_data().T) - axes[0].set_ylim([-0.0005, 0.001]) - axes[0].set_title("Before PCA-OBS") - axes[1].plot(evoked_after.times, evoked_after.get_data().T) - axes[1].set_ylim([-0.0005, 0.001]) - axes[1].set_title("After PCA-OBS") - plt.tight_layout() - - # Comparison image - fig, axes = plt.subplots(1, 1) - axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") - axes.set_ylim([-0.0005, 0.001]) - axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") - axes.set_title("Before (black) versus after (green)") - plt.tight_layout() - plt.show() From 0a55712678b21c63a54c94d7a32710046e929579 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 11 Nov 2024 12:36:02 +0100 Subject: [PATCH 17/71] Moving example --- .../preprocessing/esg_rm_heart_artefact_pcaobs.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py => examples/preprocessing/esg_rm_heart_artefact_pcaobs.py (100%) diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py similarity index 100% rename from mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py rename to examples/preprocessing/esg_rm_heart_artefact_pcaobs.py From e887ace5c40f89e0fe52e1d6ff4a68a5932ea98e Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 11 Nov 2024 14:20:52 +0100 Subject: [PATCH 18/71] Adding Niazy reference --- doc/references.bib | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/doc/references.bib b/doc/references.bib index 
a129d2f46a2..e2578ed18f2 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -1335,6 +1335,16 @@ @inproceedings{NdiayeEtAl2016 year = {2016} } +@article{NiazyEtAl2005, + author = {Niazy, R. K. and Beckmann, C.F. and Iannetti, G.D. and Brady, J. M. and Smith, S. M.}, + title = {Removal of FMRI environment artifacts from EEG data using optimal basis sets}, + journal = {NeuroImage}, + year = {2005}, + volume = {28}, + pages = {720-737}, + doi = {10.1016/j.neuroimage.2005.06.067.} +} + @article{NicholsHolmes2002, author = {Nichols, Thomas E. and Holmes, Andrew P.}, doi = {10.1002/hbm.1058}, From ee3b73eb59e7f6396edf7ef127f970d81826af58 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 11 Nov 2024 16:13:12 +0100 Subject: [PATCH 19/71] Update example, put pca-obs in a single file --- .../esg_rm_heart_artefact_pcaobs.py | 252 ++++++++++++------ mne/preprocessing/pca_obs/__init__.py | 62 +---- .../pca_obs/_fit_ecg_template.py | 125 --------- mne/preprocessing/pca_obs/_pca_obs.py | 133 ++++++++- 4 files changed, 296 insertions(+), 276 deletions(-) delete mode 100755 mne/preprocessing/pca_obs/_fit_ecg_template.py diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index 0f0df7d6c75..e3179293e9e 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -1,85 +1,175 @@ -# Calls PCA_OBS which in turn calls fit_ecg_template to remove the heart artefact via PCA_OBS (Principal Component -# Analysis, Optimal Basis Sets) +""" +.. _ex-pcaobs: + +============================================================================================== +Principal Component Analysis - Optimal Basis Sets (PCA-OBS) for removal of cardiac artefact +============================================================================================== + +This script shows an example of how to use an adaptation of PCA-OBS +:footcite:`NiazyEtAl2005`. 
PCA-OBS was originally designed to remove +the ballistocardiographic artefact in simultaneous EEG-fMRI. Here, it +has been adapted to remove the delay between the detected R-peak and the +ballistocardiographic artefact such that the algorithm can be applied to +remove the cardiac artefact in EEG (electroencephalogrpahy) and ESG +(electrospinography) data. We will illustrate how it works by applying the +algorithm to ESG data, where the effect of removal is most pronounced. + +See: https://www.biorxiv.org/content/10.1101/2024.09.05.611423v1 +for more details on the dataset and application for ESG data. + +""" + +# Authors: Emma Bailey , Steinn Hauser Magnusson +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. from matplotlib import pyplot as plt -from mne.preprocessing.pca_obs import ( - fs, - iv_baseline, - iv_epoch, - raw, - pca_obs, - ESG_CHANS, -) +from mne.preprocessing.pca_obs import pca_obs +from mne.preprocessing import find_ecg_events, fix_stim_artifact +from mne.io import read_raw_eeglab from scipy.signal import firls import numpy as np -from mne import Epochs, events_from_annotations -from mne.preprocessing import find_ecg_events - -if __name__ == "__main__": - - # Find ECG events - no ECG channel in data, uses synthetic - ( - ecg_events, - ch_ecg, - average_pulse, - ) = find_ecg_events(raw, ch_name="ECG") - # Extract just sample timings of ecg events - ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) - - # Create filter coefficients - a = [0, 0, 1, 1] - f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter - # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter - ord = round(3 * fs / 0.5) - fwts = firls(ord + 1, f, a) - - # run PCA_OBS - events, event_ids = events_from_annotations(raw) - event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} - epochs = Epochs( - raw, - events, - event_id=event_id_dict, - tmin=iv_epoch[0], - tmax=iv_epoch[1], - 
baseline=tuple(iv_baseline), - ) - evoked_before = epochs.average() - - # Apply function should modifies the data in raw in place - raw.apply_function( - pca_obs, - picks=ESG_CHANS, - n_jobs=len(ESG_CHANS), - # args sent to PCA_OBS - qrs=ecg_event_samples, - filter_coords=fwts, - ) - epochs = Epochs( - raw, - events, - event_id=event_id_dict, - tmin=iv_epoch[0], - tmax=iv_epoch[1], - baseline=tuple(iv_baseline), - ) - evoked_after = epochs.average() - - # Comparison image - fig, axes = plt.subplots(2, 1) - axes[0].plot(evoked_before.times, evoked_before.get_data().T) - axes[0].set_ylim([-0.0005, 0.001]) - axes[0].set_title("Before PCA-OBS") - axes[1].plot(evoked_after.times, evoked_after.get_data().T) - axes[1].set_ylim([-0.0005, 0.001]) - axes[1].set_title("After PCA-OBS") - plt.tight_layout() - - # Comparison image - fig, axes = plt.subplots(1, 1) - axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") - axes.set_ylim([-0.0005, 0.001]) - axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") - axes.set_title("Before (black) versus after (green)") - plt.tight_layout() - plt.show() +from mne import Epochs, events_from_annotations, concatenate_raws + +############################################################################### +# Download sample subject data from OpenNeuro if you haven't already +# This will download simultaneous EEG and ESG data from a single participant after +# median nerve stimulation of the left wrist +# Set the target directory to your desired location +import openneuro as on +import glob +target_dir = '/data/pt_02569/test_data' +file_list = glob.glob(target_dir + '/sub-001/eeg/*median*.set') +if file_list: + print('Data is already downloaded') +else: + on.download(dataset='ds004388', target_dir=target_dir, include='sub-001/*median*_eeg*') + +############################################################################### +# Define the esg channels (arranged in two patches over the neck and lower back) +# 
Also include the ECG channel for artefact correction +esg_chans = ["S35", "S24", "S36", "Iz", "S17", "S15", "S32", "S22", "S19", "S26", "S28", + "S9", "S13", "S11", "S7", "SC1", "S4", "S18", "S8", "S31", "SC6", "S12", + "S16", "S5", "S30", "S20", "S34", "S21", "S25", "L1", "S29", "S14", "S33", + "S3", "L4", "S6", "S23", 'ECG'] + +# Sampling rate +fs = 1000 + +# Interpolation window for ESG data to remove stimulation artefact +tstart_esg = -0.007 +tmax_esg = 0.007 + +# Define timing of heartbeat epochs +iv_baseline = [-300 / 1000, -200 / 1000] +iv_epoch = [-400 / 1000, 600 / 1000] + +############################################################################### +# Read in each of the four blocks and concatenate the raw structures after performing +# some minimal preprocessing including removing the stimulation artefact, downsampling +# and filtering +block_files = glob.glob(target_dir + '/sub-001/eeg/*median*.set') +block_files = sorted(block_files) + +for count, block_file in enumerate(block_files): + raw = read_raw_eeglab(block_file, eog=(), preload=True, uint16_codec=None, verbose=None) + + # Isolate the ESG channels only + raw.pick(esg_chans) + + # Find trigger timings to remove the stimulation artefact + events, event_dict = events_from_annotations(raw) + trigger_name = 'Median - Stimulation' + + fix_stim_artifact(raw, events=events, event_id=trigger_name, tmin=tstart_esg, tmax=tmax_esg, mode='linear', + stim_channel=None) + + # Downsample the data + raw.resample(fs) + + # Append blocks of the same condition + if count == 0: + raw_concat = raw + else: + concatenate_raws([raw_concat, raw]) + +############################################################################### +# Find ECG events and add to the raw structure as event annotations +ecg_events, ch_ecg, average_pulse = find_ecg_events(raw_concat, ch_name="ECG") +ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) # Samples only + +qrs_event_time = [x / fs for x in 
ecg_event_samples.reshape(-1)] # Divide by sampling rate to make times +duration = np.repeat(0.0, len(ecg_event_samples)) +description = ['qrs'] * len(ecg_event_samples) + +print(ecg_event_samples) +print(qrs_event_time) +raw_concat.annotations.append(qrs_event_time, duration, description, ch_names=[esg_chans]*len(qrs_event_time)) + +############################################################################### +# Create filter coefficients +a = [0, 0, 1, 1] +f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter +ord = round(3 * fs / 0.5) +fwts = firls(ord + 1, f, a) + +############################################################################### +# Create evoked response about the detected R-peaks before cardiac artefact correction +# Apply PCA-OBS to remove the cardiac artefact +# Create evoked response about the detected R-peaks after cardiac artefact correction +events, event_ids = events_from_annotations(raw_concat) +event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} +epochs = Epochs( + raw_concat, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), +) +evoked_before = epochs.average() + +# Apply function - modifies the data in place +raw_concat.apply_function( + pca_obs, + picks=esg_chans, + n_jobs=len(esg_chans), + # args sent to PCA_OBS + qrs=ecg_event_samples, + filter_coords=fwts, +) + +epochs = Epochs( + raw_concat, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), +) +evoked_after = epochs.average() + +############################################################################### +# Comparison image +fig, axes = plt.subplots(2, 1) +axes[0].plot(evoked_before.times, evoked_before.get_data().T) +axes[0].set_ylim([-0.0005, 0.001]) +axes[0].set_title("Before PCA-OBS") +axes[1].plot(evoked_after.times, evoked_after.get_data().T) +axes[1].set_ylim([-0.0005, 0.001]) +axes[1].set_title("After PCA-OBS") 
+plt.tight_layout() + +# Comparison image +fig, axes = plt.subplots(1, 1) +axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") +axes.set_ylim([-0.0005, 0.001]) +axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") +axes.set_title("Before (black) versus after (green)") +plt.tight_layout() +plt.show() + +# %% +# References +# ---------- +# .. footbibliography:: diff --git a/mne/preprocessing/pca_obs/__init__.py b/mne/preprocessing/pca_obs/__init__.py index 88f6276264b..cfa48b95fec 100644 --- a/mne/preprocessing/pca_obs/__init__.py +++ b/mne/preprocessing/pca_obs/__init__.py @@ -1,69 +1,11 @@ -"""Principle Component Analysis of OBS (PCA-OBS).""" # TODO: What's OBS? +"""Principle Component Analysis Optimal Basis Sets (PCA-OBS).""" # Authors: The MNE-Python contributors. # License: BSD-3-Clause # Copyright the MNE-Python contributors. -from mne.io import read_raw_fif from ._pca_obs import pca_obs -from ._fit_ecg_template import fit_ecg_template __all__ = [ - "pca_obs", - "fit_ecg_template" + "pca_obs" ] - -# TODO: description of what this is -ESG_CHANS = [ - "S35", - "S24", - "S36", - "Iz", - "S17", - "S15", - "S32", - "S22", - "S19", - "S26", - "S28", - "S9", - "S13", - "S11", - "S7", - "SC1", - "S4", - "S18", - "S8", - "S31", - "SC6", - "S12", - "S16", - "S5", - "S30", - "S20", - "S34", - "AC", - "S21", - "S25", - "L1", - "S29", - "S14", - "S33", - "S3", - "AL", - "L4", - "S6", - "S23", -] - -# Set variables -fs = 1000 # sampling rate - -# For heartbeat epochs -iv_baseline = [-300 / 1000, -200 / 1000] -iv_epoch = [-400 / 1000, 600 / 1000] - -# Setting paths -input_path = "/data/pt_02569/tmp_data/prepared_py/sub-001/esg/prepro/" -fname = f"noStimart_sr{fs}_median_withqrs_pchip" -raw = read_raw_fif(input_path + fname + ".fif", preload=True) diff --git a/mne/preprocessing/pca_obs/_fit_ecg_template.py b/mne/preprocessing/pca_obs/_fit_ecg_template.py deleted file mode 100755 index 23f5a512539..00000000000 --- 
a/mne/preprocessing/pca_obs/_fit_ecg_template.py +++ /dev/null @@ -1,125 +0,0 @@ -import numpy as np -from scipy.interpolate import PchipInterpolator as pchip -from scipy.signal import detrend - - -def fit_ecg_template( - data, - pca_template, - aPeak_idx, - peak_range, - pre_range, - post_range, - midP, - fitted_art, - post_idx_previousPeak: list, - n_samples_fit, -) -> tuple[np.ndarray, list]: - """TODO: Write docstring about what we do here. - Fits the ECG to a template signal (?) - and returns the fitted artefact and the index of the next peak. (?) - - (TODO: are there any conditions that must be met to use our algos?) - .. note:: This should only be used on data which is ... - - # TODO: Fill out input/output and raises - Parameters - ---------- - data (_type_): _description_ - pca_template (_type_): _description_ - aPeak_idx (_type_): _description_ - peak_range (_type_): _description_ - pre_range (_type_): _description_ - post_range (_type_): _description_ - midP (_type_): _description_ - fitted_art (_type_): _description_ - post_idx_previousPeak (list): _description_ - n_samples_fit (_type_): _description_ - - Returns - ------- - tuple[np.ndarray, list]: the fitted artifact and the next peak index (if available) - """ - # Declare class to hold ecg fit information - class fitECG: - def __init__(self): - pass - - # Instantiate class - # TODO: Why are we storing this to a class? Can't we just use the variables and write to them? 
- fitecg = fitECG() - - # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak - # Then nextpeak is returned at the end and the process repeats - # select window of template - template = pca_template[midP - peak_range - 1 : midP + peak_range + 1, :] - - # select window of data and detrend it - slice = data[0, aPeak_idx[0] - peak_range : aPeak_idx[0] + peak_range + 1] - detrended_data = detrend(slice.reshape(-1), type="constant") - - # maps data on template and then maps it again back to the sensor space - least_square = np.linalg.lstsq(template, detrended_data, rcond=None) - pad_fit = np.dot(template, least_square[0]) - - # fit artifact, I already loop through externally channel to channel - fitted_art[0, aPeak_idx[0] - pre_range - 1 : aPeak_idx[0] + post_range] = pad_fit[ - midP - pre_range - 1 : midP + post_range - ].T - - fitecg.fitted_art = fitted_art - fitecg.template = template - fitecg.detrended_data = detrended_data - fitecg.pad_fit = pad_fit - fitecg.aPeak_idx = aPeak_idx - fitecg.midP = midP - fitecg.peak_range = peak_range - fitecg.data = data - - post_idx_nextPeak = [aPeak_idx[0] + post_range] - - # if last peak, return - if not post_idx_previousPeak: - return fitted_art, post_idx_nextPeak - - # interpolate time between peaks - intpol_window = np.ceil( - [post_idx_previousPeak[0], aPeak_idx[0] - pre_range] - ).astype("int") # interpolation window - fitecg.intpol_window = intpol_window - - if intpol_window[0] < intpol_window[1]: - # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data - - # You have x_fit which is two slices on either side of the interpolation window endpoints - # You have y_fit which is the y vals corresponding to x values above - # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate - # You have y_interpol which is values from pchip at the time points specified in x_interpol - x_interpol = np.arange( - intpol_window[0], intpol_window[1] 
+ 1, 1 - ) # points to be interpolated in pt - the gap between the endpoints of the window - x_fit = np.concatenate( - [ - np.arange( - intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1 - ), - np.arange( - intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1 - ), - ] - ) # Entire range of x values in this step (taking some number of samples before and after the window) - y_fit = fitted_art[0, x_fit] - y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation - - # Then make fitted artefact in the desired range equal to the completed fit above - fitted_art[0, post_idx_previousPeak[0] : aPeak_idx[0] - pre_range + 1] = ( - y_interpol - ) - - fitecg.x_fit = x_fit - fitecg.y_fit = y_fit - fitecg.x_interpol = x_interpol - fitecg.y_interpol = y_interpol - fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop - - return fitted_art, post_idx_nextPeak diff --git a/mne/preprocessing/pca_obs/_pca_obs.py b/mne/preprocessing/pca_obs/_pca_obs.py index c6d9f752d69..c3f2178f172 100755 --- a/mne/preprocessing/pca_obs/_pca_obs.py +++ b/mne/preprocessing/pca_obs/_pca_obs.py @@ -2,10 +2,131 @@ from typing import Any import numpy as np -from mne.preprocessing.pca_obs import fit_ecg_template - from scipy.signal import detrend, filtfilt from sklearn.decomposition import PCA +from scipy.interpolate import PchipInterpolator as pchip +from scipy.signal import detrend + +def fit_ecg_template( + data, + pca_template, + aPeak_idx, + peak_range, + pre_range, + post_range, + midP, + fitted_art, + post_idx_previousPeak: list, + n_samples_fit, +) -> tuple[np.ndarray, list]: + """TODO: Write docstring about what we do here. + Fits the ECG to a template signal (?) + and returns the fitted artefact and the index of the next peak. (?) + + (TODO: are there any conditions that must be met to use our algos?) + .. note:: This should only be used on data which is ... 
+ + # TODO: Fill out input/output and raises + Parameters + ---------- + data (_type_): _description_ + pca_template (_type_): _description_ + aPeak_idx (_type_): _description_ + peak_range (_type_): _description_ + pre_range (_type_): _description_ + post_range (_type_): _description_ + midP (_type_): _description_ + fitted_art (_type_): _description_ + post_idx_previousPeak (list): _description_ + n_samples_fit (_type_): _description_ + + Returns + ------- + tuple[np.ndarray, list]: the fitted artifact and the next peak index (if available) + """ + # Declare class to hold ecg fit information + class fitECG: + def __init__(self): + pass + + # Instantiate class + # TODO: Why are we storing this to a class? Can't we just use the variables and write to them? + fitecg = fitECG() + + # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak + # Then nextpeak is returned at the end and the process repeats + # select window of template + template = pca_template[midP - peak_range - 1 : midP + peak_range + 1, :] + + # select window of data and detrend it + slice = data[0, aPeak_idx[0] - peak_range : aPeak_idx[0] + peak_range + 1] + detrended_data = detrend(slice.reshape(-1), type="constant") + + # maps data on template and then maps it again back to the sensor space + least_square = np.linalg.lstsq(template, detrended_data, rcond=None) + pad_fit = np.dot(template, least_square[0]) + + # fit artifact, I already loop through externally channel to channel + fitted_art[0, aPeak_idx[0] - pre_range - 1 : aPeak_idx[0] + post_range] = pad_fit[ + midP - pre_range - 1 : midP + post_range + ].T + + fitecg.fitted_art = fitted_art + fitecg.template = template + fitecg.detrended_data = detrended_data + fitecg.pad_fit = pad_fit + fitecg.aPeak_idx = aPeak_idx + fitecg.midP = midP + fitecg.peak_range = peak_range + fitecg.data = data + + post_idx_nextPeak = [aPeak_idx[0] + post_range] + + # if last peak, return + if not post_idx_previousPeak: + return fitted_art, 
post_idx_nextPeak + + # interpolate time between peaks + intpol_window = np.ceil( + [post_idx_previousPeak[0], aPeak_idx[0] - pre_range] + ).astype("int") # interpolation window + fitecg.intpol_window = intpol_window + + if intpol_window[0] < intpol_window[1]: + # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data + + # You have x_fit which is two slices on either side of the interpolation window endpoints + # You have y_fit which is the y vals corresponding to x values above + # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate + # You have y_interpol which is values from pchip at the time points specified in x_interpol + x_interpol = np.arange( + intpol_window[0], intpol_window[1] + 1, 1 + ) # points to be interpolated in pt - the gap between the endpoints of the window + x_fit = np.concatenate( + [ + np.arange( + intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1 + ), + np.arange( + intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1 + ), + ] + ) # Entire range of x values in this step (taking some number of samples before and after the window) + y_fit = fitted_art[0, x_fit] + y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation + + # Then make fitted artefact in the desired range equal to the completed fit above + fitted_art[0, post_idx_previousPeak[0] : aPeak_idx[0] - pre_range + 1] = ( + y_interpol + ) + + fitecg.x_fit = x_fit + fitecg.y_fit = y_fit + fitecg.x_interpol = x_interpol + fitecg.y_interpol = y_interpol + fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop + + return fitted_art, post_idx_nextPeak # TODO: Are we able to split this into smaller segmented functions? @@ -213,17 +334,9 @@ def __init__(self): print(f"Cannot fit middle section of data. 
Reason: {e}") # Actually subtract the artefact, return needs to be the same shape as input data - # One sample shift purely due to the fact the r-peaks are currently detected in MATLAB data = data.reshape(-1) fitted_art = fitted_art.reshape(-1) - # One sample shift for my actual data (introduced using matlab r timings) - # data_ = np.zeros(len(data)) - # data_[0] = data[0] - # data_[1:] = data[1:] - fitted_art[:-1] - # data = data_ - - # Original code is this: data -= fitted_art data = data.T.reshape(-1) From adae35261cd5efd0eb54290d4cd093dd0f83b176 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 11 Nov 2024 16:16:34 +0100 Subject: [PATCH 20/71] TODO --- mne/preprocessing/pca_obs/_pca_obs.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mne/preprocessing/pca_obs/_pca_obs.py b/mne/preprocessing/pca_obs/_pca_obs.py index c3f2178f172..e987abbe208 100755 --- a/mne/preprocessing/pca_obs/_pca_obs.py +++ b/mne/preprocessing/pca_obs/_pca_obs.py @@ -7,6 +7,9 @@ from scipy.interpolate import PchipInterpolator as pchip from scipy.signal import detrend +# TODO: This needs to be pulled out of the subfolder we've created and moved into the more 'normal' MNE setup +# with the _pca_obs in preprocessing as a single file only, _init integrated in their __init__.py and .pyi + def fit_ecg_template( data, pca_template, From b3c1b4e60f0e4a19642aeac683b6d9a4ebe24aed Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 11 Nov 2024 16:56:56 +0100 Subject: [PATCH 21/71] Update example --- .../esg_rm_heart_artefact_pcaobs.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index e3179293e9e..e24d5f42212 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -61,7 +61,7 @@ tmax_esg = 0.007 # Define timing of heartbeat epochs -iv_baseline = [-300 / 1000, -200 
/ 1000] +iv_baseline = [-400 / 1000, -300 / 1000] iv_epoch = [-400 / 1000, 600 / 1000] ############################################################################### @@ -81,7 +81,7 @@ events, event_dict = events_from_annotations(raw) trigger_name = 'Median - Stimulation' - fix_stim_artifact(raw, events=events, event_id=trigger_name, tmin=tstart_esg, tmax=tmax_esg, mode='linear', + fix_stim_artifact(raw, events=events, event_id=event_dict[trigger_name], tmin=tstart_esg, tmax=tmax_esg, mode='linear', stim_channel=None) # Downsample the data @@ -102,8 +102,6 @@ duration = np.repeat(0.0, len(ecg_event_samples)) description = ['qrs'] * len(ecg_event_samples) -print(ecg_event_samples) -print(qrs_event_time) raw_concat.annotations.append(qrs_event_time, duration, description, ch_names=[esg_chans]*len(qrs_event_time)) ############################################################################### @@ -150,16 +148,6 @@ evoked_after = epochs.average() ############################################################################### -# Comparison image -fig, axes = plt.subplots(2, 1) -axes[0].plot(evoked_before.times, evoked_before.get_data().T) -axes[0].set_ylim([-0.0005, 0.001]) -axes[0].set_title("Before PCA-OBS") -axes[1].plot(evoked_after.times, evoked_after.get_data().T) -axes[1].set_ylim([-0.0005, 0.001]) -axes[1].set_title("After PCA-OBS") -plt.tight_layout() - # Comparison image fig, axes = plt.subplots(1, 1) axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") From e6aa17d2c154bc85966f1fb2a29db03c41c29f69 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Fri, 15 Nov 2024 13:48:57 +0100 Subject: [PATCH 22/71] refactor: change way of calling pca obs method, add comments --- .../esg_rm_heart_artefact_pcaobs.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index e24d5f42212..da8583c9901 100755 --- 
a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -24,7 +24,7 @@ # Copyright the MNE-Python contributors. from matplotlib import pyplot as plt -from mne.preprocessing.pca_obs import pca_obs +import mne from mne.preprocessing import find_ecg_events, fix_stim_artifact from mne.io import read_raw_eeglab from scipy.signal import firls @@ -38,7 +38,11 @@ # Set the target directory to your desired location import openneuro as on import glob + +# add the path where you want the OpenNeuro data downloaded. Files total around 8 GB +# target_dir = "/home/steinnhm/personal/mne-data" target_dir = '/data/pt_02569/test_data' + file_list = glob.glob(target_dir + '/sub-001/eeg/*median*.set') if file_list: print('Data is already downloaded') @@ -128,13 +132,12 @@ evoked_before = epochs.average() # Apply function - modifies the data in place -raw_concat.apply_function( - pca_obs, - picks=esg_chans, - n_jobs=len(esg_chans), - # args sent to PCA_OBS - qrs=ecg_event_samples, - filter_coords=fwts, +mne.preprocessing.apply_pca_obs( + raw_concat, + picks=esg_chans, + n_jobs=4, + qrs=ecg_event_samples, + filter_coords=fwts ) epochs = Epochs( From b150b6c4419cbe02492bf45cc7cd05f382c67bbc Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Fri, 15 Nov 2024 13:50:19 +0100 Subject: [PATCH 23/71] refactor: move pca obs method out of separate python module, change logging to use mne logger instead of prints, add wrapper method in front of private _pca_obs method to handle parallel processing --- mne/preprocessing/__init__.pyi | 2 + .../{pca_obs/_pca_obs.py => pca_obs.py} | 44 ++++++++++++++----- mne/preprocessing/pca_obs/__init__.py | 11 ----- mne/preprocessing/pca_obs/tests/__init__.py | 0 .../pca_obs/tests/test_fit_ecg.py | 39 ---------------- .../{pca_obs => }/tests/test_pca_obs.py | 0 6 files changed, 35 insertions(+), 61 deletions(-) rename mne/preprocessing/{pca_obs/_pca_obs.py => pca_obs.py} (91%) delete mode 100644 
mne/preprocessing/pca_obs/__init__.py delete mode 100644 mne/preprocessing/pca_obs/tests/__init__.py delete mode 100644 mne/preprocessing/pca_obs/tests/test_fit_ecg.py rename mne/preprocessing/{pca_obs => }/tests/test_pca_obs.py (100%) diff --git a/mne/preprocessing/__init__.pyi b/mne/preprocessing/__init__.pyi index 54f1c825c13..d58c6e77d24 100644 --- a/mne/preprocessing/__init__.pyi +++ b/mne/preprocessing/__init__.pyi @@ -44,6 +44,7 @@ __all__ = [ "realign_raw", "regress_artifact", "write_fine_calibration", + "apply_pca_obs", ] from . import eyetracking, ieeg, nirs from ._annotate_amplitude import annotate_amplitude @@ -89,3 +90,4 @@ from .realign import realign_raw from .ssp import compute_proj_ecg, compute_proj_eog from .stim import fix_stim_artifact from .xdawn import Xdawn +from .pca_obs import apply_pca_obs \ No newline at end of file diff --git a/mne/preprocessing/pca_obs/_pca_obs.py b/mne/preprocessing/pca_obs.py similarity index 91% rename from mne/preprocessing/pca_obs/_pca_obs.py rename to mne/preprocessing/pca_obs.py index e987abbe208..0b00607262f 100755 --- a/mne/preprocessing/pca_obs/_pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -1,5 +1,10 @@ +"""Principle Component Analysis Optimal Basis Sets (PCA-OBS).""" + +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. 
+ import math -from typing import Any import numpy as np from scipy.signal import detrend, filtfilt @@ -7,6 +12,9 @@ from scipy.interpolate import PchipInterpolator as pchip from scipy.signal import detrend +from mne.io.fiff.raw import Raw +from mne.utils import logger, warn + # TODO: This needs to be pulled out of the subfolder we've created and moved into the more 'normal' MNE setup # with the _pca_obs in preprocessing as a single file only, _init integrated in their __init__.py and .pyi @@ -132,9 +140,19 @@ def __init__(self): return fitted_art, post_idx_nextPeak +def apply_pca_obs(raw: Raw, picks: list[str], n_jobs: int, qrs: np.ndarray, filter_coords: np.ndarray) -> None: + raw.apply_function( + _pca_obs, + picks=picks, + n_jobs=n_jobs, + # args sent to PCA_OBS + qrs=qrs, + filter_coords=filter_coords, + ) + # TODO: Are we able to split this into smaller segmented functions? -def pca_obs( - data: np.ndarray, +def _pca_obs( + data: np.ndarray, qrs: np.ndarray, filter_coords: np.ndarray, ) -> np.ndarray: @@ -146,9 +164,13 @@ def pca_obs( Parameters ---------- - data (np.ndarray): The data which we want to remove the heart artefact from. - qrs (np.ndarray): _description_ - filter_coords (np.ndarray): _description_ + data: ndarray, shape (n_channels, n_times) + The data which we want to remove the heart artefact from. + + qrs: ndarray, shape (n_peaks, 1) + Array of times in (s), of detected R-peaks in ECG channel. 
+ + filter_coords: ndarray Returns ------- @@ -185,7 +207,7 @@ def __init__(self): ################################################################ # Preparatory work - reserving memory, configure sizes, de-trend ################################################################ - print("Pulse artifact subtraction in progress...Please wait!") + logger.info("Pulse artifact subtraction in progress...Please wait!") # define peak range based on RR RR = np.diff(peak_idx[:, 0]) @@ -277,11 +299,11 @@ def __init__(self): window_start_idx.append(peak_idx[p] - peak_range) window_end_idx.append(peak_idx[p] + peak_range) except Exception as e: - print(f"Cannot fit first ECG epoch. Reason: {e}") + warn(f"Cannot fit first ECG epoch. Reason: {e}") # Deals with last edge of data elif p == peak_count: - print("On last section - almost there!") + logger.info("On last section - almost there!") try: pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) post_range = peak_range @@ -302,7 +324,7 @@ def __init__(self): window_start_idx.append(peak_idx[p] - peak_range) window_end_idx.append(peak_idx[p] + peak_range) except Exception as e: - print(f"Cannot fit last ECG epoch. Reason: {e}") + warn(f"Cannot fit last ECG epoch. Reason: {e}") # Deals with middle portion of data else: @@ -334,7 +356,7 @@ def __init__(self): window_start_idx.append(peak_idx[p] - peak_range) window_end_idx.append(peak_idx[p] + peak_range) except Exception as e: - print(f"Cannot fit middle section of data. Reason: {e}") + warn(f"Cannot fit middle section of data. Reason: {e}") # Actually subtract the artefact, return needs to be the same shape as input data data = data.reshape(-1) diff --git a/mne/preprocessing/pca_obs/__init__.py b/mne/preprocessing/pca_obs/__init__.py deleted file mode 100644 index cfa48b95fec..00000000000 --- a/mne/preprocessing/pca_obs/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Principle Component Analysis Optimal Basis Sets (PCA-OBS).""" - -# Authors: The MNE-Python contributors. 
-# License: BSD-3-Clause -# Copyright the MNE-Python contributors. - -from ._pca_obs import pca_obs - -__all__ = [ - "pca_obs" -] diff --git a/mne/preprocessing/pca_obs/tests/__init__.py b/mne/preprocessing/pca_obs/tests/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/mne/preprocessing/pca_obs/tests/test_fit_ecg.py b/mne/preprocessing/pca_obs/tests/test_fit_ecg.py deleted file mode 100644 index eb5c80bc8d8..00000000000 --- a/mne/preprocessing/pca_obs/tests/test_fit_ecg.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Test the fot_ecg_template function.""" - -# Authors: The MNE-Python contributors. -# License: BSD-3-Clause -# Copyright the MNE-Python contributors. - -from mne.io import read_raw_fif -from mne.preprocessing.pca_obs import fit_ecg_template -from mne.datasets.testing import data_path, requires_testing_data - -# TODO: Where are the test files we want to use located? -fname = data_path(download=False) / "eyetrack" / "test_eyelink.asc" - - -@requires_testing_data -def test_fit_ecg_template(): - """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" - raw = read_raw_fif(fname) - - # Somehow have to "fake" all these inputs to the function - result = fit_ecg_template( - data=None, - pca_template=None, - aPeak_idx=None, - peak_range=None, - pre_range=None, - post_range=None, - midP=None, - fitted_art=None, - post_idx_previousPeak=None, - n_samples_fit=None, - ) - - # assert results - assert result is not None - assert result.shape == (100, 100) - assert result.shape == raw.shape # is this a condition we can test? - assert result[0, 0] == 1.0 - ... 
\ No newline at end of file diff --git a/mne/preprocessing/pca_obs/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py similarity index 100% rename from mne/preprocessing/pca_obs/tests/test_pca_obs.py rename to mne/preprocessing/tests/test_pca_obs.py From 951e87ca95ed57bb2cad63e04d6e772319b5666f Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Fri, 15 Nov 2024 15:16:00 +0100 Subject: [PATCH 24/71] Refactor: Docstrings, removed classes --- .../esg_rm_heart_artefact_pcaobs.py | 11 +- mne/preprocessing/pca_obs.py | 117 ++++++------------ 2 files changed, 43 insertions(+), 85 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index da8583c9901..f5e3ad3757f 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -19,7 +19,8 @@ """ -# Authors: Emma Bailey , Steinn Hauser Magnusson +# Authors: Emma Bailey , +# Steinn Hauser Magnusson # License: BSD-3-Clause # Copyright the MNE-Python contributors. @@ -135,7 +136,7 @@ mne.preprocessing.apply_pca_obs( raw_concat, picks=esg_chans, - n_jobs=4, + n_jobs=5, qrs=ecg_event_samples, filter_coords=fwts ) @@ -154,9 +155,11 @@ # Comparison image fig, axes = plt.subplots(1, 1) axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") -axes.set_ylim([-0.0005, 0.001]) axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") -axes.set_title("Before (black) versus after (green)") +axes.set_ylim([-0.0005, 0.001]) +axes.set_ylabel('Amplitude (V)') +axes.set_xlabel('Time (s)') +axes.set_title("Before (black) vs. 
After (green)") plt.tight_layout() plt.show() diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 0b00607262f..2ef3851a58e 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -1,6 +1,7 @@ """Principle Component Analysis Optimal Basis Sets (PCA-OBS).""" -# Authors: The MNE-Python contributors. +# Authors: Emma Bailey , +# Steinn Hauser Magnusson # License: BSD-3-Clause # Copyright the MNE-Python contributors. @@ -15,8 +16,8 @@ from mne.io.fiff.raw import Raw from mne.utils import logger, warn -# TODO: This needs to be pulled out of the subfolder we've created and moved into the more 'normal' MNE setup -# with the _pca_obs in preprocessing as a single file only, _init integrated in their __init__.py and .pyi + +# TODO: check arguments passed in, raise errors, tests def fit_ecg_template( data, @@ -27,47 +28,37 @@ def fit_ecg_template( post_range, midP, fitted_art, - post_idx_previousPeak: list, + post_idx_previousPeak, n_samples_fit, ) -> tuple[np.ndarray, list]: - """TODO: Write docstring about what we do here. - Fits the ECG to a template signal (?) - and returns the fitted artefact and the index of the next peak. (?) - - (TODO: are there any conditions that must be met to use our algos?) - .. note:: This should only be used on data which is ... + """ + Fits the heartbeat artefact found in the data + Returns the fitted artefact and the index of the next peak. 
- # TODO: Fill out input/output and raises Parameters ---------- - data (_type_): _description_ - pca_template (_type_): _description_ - aPeak_idx (_type_): _description_ - peak_range (_type_): _description_ - pre_range (_type_): _description_ - post_range (_type_): _description_ - midP (_type_): _description_ - fitted_art (_type_): _description_ - post_idx_previousPeak (list): _description_ - n_samples_fit (_type_): _description_ + data (ndarray): Data from the raw signal (n_channels, n_times) + pca_template (ndarray): Mean heartbeat and first N (4) principal components of the heartbeat matrix + aPeak_idx (int): Sample index of current R-peak + peak_range (int): Half the median RR-interval + pre_range (int): Number of samples to fit before the R-peak + post_range (int): Number of samples to fit after the R-peak + midP (float): Sample index marking middle of the median RR interval in the signal. + Used to extract relevant part of PCA_template. + fitted_art (ndarray): The computed heartbeat artefact computed to remove from the data + post_idx_previousPeak (optional int): Sample index of previous R-peak + n_samples_fit (int): Sample fit for interpolation between fitted artifact windows. + Helps reduce sharp edges at the end of fitted heartbeat events. Returns ------- tuple[np.ndarray, list]: the fitted artifact and the next peak index (if available) """ - # Declare class to hold ecg fit information - class fitECG: - def __init__(self): - pass - - # Instantiate class - # TODO: Why are we storing this to a class? Can't we just use the variables and write to them? 
- fitecg = fitECG() # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak # Then nextpeak is returned at the end and the process repeats # select window of template - template = pca_template[midP - peak_range - 1 : midP + peak_range + 1, :] + template = pca_template[midP - peak_range - 1: midP + peak_range + 1, :] # select window of data and detrend it slice = data[0, aPeak_idx[0] - peak_range : aPeak_idx[0] + peak_range + 1] @@ -77,31 +68,19 @@ def __init__(self): least_square = np.linalg.lstsq(template, detrended_data, rcond=None) pad_fit = np.dot(template, least_square[0]) - # fit artifact, I already loop through externally channel to channel - fitted_art[0, aPeak_idx[0] - pre_range - 1 : aPeak_idx[0] + post_range] = pad_fit[ - midP - pre_range - 1 : midP + post_range + # fit artifact + fitted_art[0, aPeak_idx[0] - pre_range - 1: aPeak_idx[0] + post_range] = pad_fit[ + midP - pre_range - 1: midP + post_range ].T - fitecg.fitted_art = fitted_art - fitecg.template = template - fitecg.detrended_data = detrended_data - fitecg.pad_fit = pad_fit - fitecg.aPeak_idx = aPeak_idx - fitecg.midP = midP - fitecg.peak_range = peak_range - fitecg.data = data - - post_idx_nextPeak = [aPeak_idx[0] + post_range] - # if last peak, return - if not post_idx_previousPeak: - return fitted_art, post_idx_nextPeak + if post_idx_previousPeak is None: + return fitted_art, aPeak_idx[0] + post_range # interpolate time between peaks intpol_window = np.ceil( - [post_idx_previousPeak[0], aPeak_idx[0] - pre_range] + [post_idx_previousPeak, aPeak_idx[0] - pre_range] ).astype("int") # interpolation window - fitecg.intpol_window = intpol_window if intpol_window[0] < intpol_window[1]: # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data @@ -127,17 +106,11 @@ def __init__(self): y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation # Then make fitted artefact in the desired range equal to the completed fit above - fitted_art[0, 
post_idx_previousPeak[0] : aPeak_idx[0] - pre_range + 1] = ( + fitted_art[0, post_idx_previousPeak: aPeak_idx[0] - pre_range + 1] = ( y_interpol ) - fitecg.x_fit = x_fit - fitecg.y_fit = y_fit - fitecg.x_interpol = x_interpol - fitecg.y_interpol = y_interpol - fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop - - return fitted_art, post_idx_nextPeak + return fitted_art, aPeak_idx[0] + post_range def apply_pca_obs(raw: Raw, picks: list[str], n_jobs: int, qrs: np.ndarray, filter_coords: np.ndarray) -> None: @@ -160,8 +133,6 @@ def _pca_obs( Algorithm to perform the PCA OBS (Principal Component Analysis, Optimal Basis Sets) algorithm to remove the heart artefact from EEG data. - .. note:: This should only be used on data which is ... (TODO: are there any conditions that must be met to use our algos?) - Parameters ---------- data: ndarray, shape (n_channels, n_times) @@ -170,21 +141,13 @@ def _pca_obs( qrs: ndarray, shape (n_peaks, 1) Array of times in (s), of detected R-peaks in ECG channel. - filter_coords: ndarray + filter_coords: ndarray (N, ) + The numerator coefficient vector of the filter passed to scipy.signal.filtfilt Returns ------- - np.ndarray: The data with the heart artefact removed. + np.ndarray: The data with the heart artefact suppressed. """ - # Declare class to hold pca information - class PCAInfo: - def __init__(self): - pass - - # NOTE: Here aswell, is there a reason we are storing this - # to a class? Shouldn't variables suffice? 
- # Instantiate class - pca_info = PCAInfo() # set to baseline data = data.reshape(-1, 1) @@ -250,18 +213,10 @@ def __init__(self): # run PCA(performs SVD(singular value decomposition)) pca = PCA(svd_solver="full") pca.fit(dpcamat) - eigen_vectors = pca.components_ - eigen_values = pca.explained_variance_ factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_) - pca_info.eigen_vectors = eigen_vectors - pca_info.factor_loadings = factor_loadings - pca_info.eigen_values = eigen_values - pca_info.expl_var = pca.explained_variance_ratio_ # define selected number of components using profile likelihood - pca_info.nComponents = 4 # TODO: Is this a variable? Or constant? Seems like a variable - pca_info.meanEffect = mean_effect.T - nComponents = pca_info.nComponents + nComponents = 4 ####################################################################### # Make template of the ECG artefact @@ -282,7 +237,7 @@ def __init__(self): if post_range > peak_range: post_range = peak_range try: - post_idx_nextPeak = [] + post_idx_nextPeak = None fitted_art, post_idx_nextPeak = fit_ecg_template( data, pca_template, @@ -302,7 +257,7 @@ def __init__(self): warn(f"Cannot fit first ECG epoch. 
Reason: {e}") # Deals with last edge of data - elif p == peak_count: + elif p == peak_count-1: logger.info("On last section - almost there!") try: pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) @@ -312,7 +267,7 @@ def __init__(self): fitted_art, _ = fit_ecg_template( data, pca_template, - peak_idx(p), + peak_idx[p], peak_range, pre_range, post_range, From c70db0a130e1601a171610d062191ac35d30cf11 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Sun, 24 Nov 2024 17:19:23 +0100 Subject: [PATCH 25/71] refactor: remove unrequired index references, adjust variable namings to have consistent patterns --- mne/preprocessing/pca_obs.py | 165 ++++++++++++++++++----------------- 1 file changed, 83 insertions(+), 82 deletions(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 2ef3851a58e..b9e0a01e64a 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -6,6 +6,7 @@ # Copyright the MNE-Python contributors. import math +from typing import Optional import numpy as np from scipy.signal import detrend, filtfilt @@ -20,17 +21,17 @@ # TODO: check arguments passed in, raise errors, tests def fit_ecg_template( - data, - pca_template, - aPeak_idx, - peak_range, - pre_range, - post_range, - midP, - fitted_art, - post_idx_previousPeak, - n_samples_fit, -) -> tuple[np.ndarray, list]: + data: np.ndarray, + pca_template: np.ndarray, + a_peak_idx: int, + peak_range: int, + pre_range: int, + post_range: int, + mid_p: float, + fitted_art: np.ndarray, + post_idx_previous_peak: Optional[int], + n_samples_fit: int, +) -> tuple[np.ndarray, int]: """ Fits the heartbeat artefact found in the data Returns the fitted artefact and the index of the next peak. 
@@ -39,29 +40,29 @@ def fit_ecg_template( ---------- data (ndarray): Data from the raw signal (n_channels, n_times) pca_template (ndarray): Mean heartbeat and first N (4) principal components of the heartbeat matrix - aPeak_idx (int): Sample index of current R-peak + a_peak_idx (int): Sample index of current R-peak peak_range (int): Half the median RR-interval pre_range (int): Number of samples to fit before the R-peak post_range (int): Number of samples to fit after the R-peak - midP (float): Sample index marking middle of the median RR interval in the signal. + mid_p (float): Sample index marking middle of the median RR interval in the signal. Used to extract relevant part of PCA_template. fitted_art (ndarray): The computed heartbeat artefact computed to remove from the data - post_idx_previousPeak (optional int): Sample index of previous R-peak + post_idx_previous_peak (optional int): Sample index of previous R-peak n_samples_fit (int): Sample fit for interpolation between fitted artifact windows. Helps reduce sharp edges at the end of fitted heartbeat events. 
Returns ------- - tuple[np.ndarray, list]: the fitted artifact and the next peak index (if available) + tuple[np.ndarray, int]: the fitted artifact and the next peak index """ - # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak + # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previous_peak # Then nextpeak is returned at the end and the process repeats # select window of template - template = pca_template[midP - peak_range - 1: midP + peak_range + 1, :] + template = pca_template[mid_p - peak_range - 1: mid_p + peak_range + 1, :] # select window of data and detrend it - slice = data[0, aPeak_idx[0] - peak_range : aPeak_idx[0] + peak_range + 1] + slice = data[0, a_peak_idx[0] - peak_range : a_peak_idx[0] + peak_range + 1] detrended_data = detrend(slice.reshape(-1), type="constant") # maps data on template and then maps it again back to the sensor space @@ -69,18 +70,18 @@ def fit_ecg_template( pad_fit = np.dot(template, least_square[0]) # fit artifact - fitted_art[0, aPeak_idx[0] - pre_range - 1: aPeak_idx[0] + post_range] = pad_fit[ - midP - pre_range - 1: midP + post_range + fitted_art[0, a_peak_idx[0] - pre_range - 1: a_peak_idx[0] + post_range] = pad_fit[ + mid_p - pre_range - 1: mid_p + post_range ].T # if last peak, return - if post_idx_previousPeak is None: - return fitted_art, aPeak_idx[0] + post_range + if post_idx_previous_peak is None: + return fitted_art, a_peak_idx[0] + post_range # interpolate time between peaks intpol_window = np.ceil( - [post_idx_previousPeak, aPeak_idx[0] - pre_range] - ).astype("int") # interpolation window + [post_idx_previous_peak, a_peak_idx[0] - pre_range] + ).astype(int) # interpolation window if intpol_window[0] < intpol_window[1]: # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data @@ -106,14 +107,18 @@ def fit_ecg_template( y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation # Then make fitted artefact in the desired range equal to the 
completed fit above - fitted_art[0, post_idx_previousPeak: aPeak_idx[0] - pre_range + 1] = ( + fitted_art[0, post_idx_previous_peak: a_peak_idx[0] - pre_range + 1] = ( y_interpol ) - return fitted_art, aPeak_idx[0] + post_range + return fitted_art, a_peak_idx[0] + post_range def apply_pca_obs(raw: Raw, picks: list[str], n_jobs: int, qrs: np.ndarray, filter_coords: np.ndarray) -> None: + """ + Main convenience function for applying the PCA-OBS algorithm + to certain picks of a Raw object. + """ raw.apply_function( _pca_obs, picks=picks, @@ -123,11 +128,11 @@ def apply_pca_obs(raw: Raw, picks: list[str], n_jobs: int, qrs: np.ndarray, filt filter_coords=filter_coords, ) -# TODO: Are we able to split this into smaller segmented functions? def _pca_obs( data: np.ndarray, qrs: np.ndarray, filter_coords: np.ndarray, + n_components: int = 4 # number of components to pick from the PCA ) -> np.ndarray: """ Algorithm to perform the PCA OBS (Principal Component Analysis, Optimal Basis Sets) @@ -167,25 +172,23 @@ def _pca_obs( peak_idx = peak_idx.reshape(-1, 1) peak_count = len(peak_idx) - ################################################################ - # Preparatory work - reserving memory, configure sizes, de-trend - ################################################################ + ################################################################## + # Preparatory work - reserving memory, configure sizes, de-trend # + ################################################################## logger.info("Pulse artifact subtraction in progress...Please wait!") # define peak range based on RR RR = np.diff(peak_idx[:, 0]) mRR = np.median(RR) peak_range = round(mRR / 2) # Rounds to an integer - midP = peak_range + 1 + mid_p = peak_range + 1 n_samples_fit = round( peak_range / 8 ) # sample fit for interpolation between fitted artifact windows # make sure array is long enough for PArange (if not cut off last ECG peak) - pa = peak_count # Number of QRS complexes detected - while 
peak_idx[pa - 1, 0] + peak_range > len(data[0]): - pa = pa - 1 - peak_count = pa + while peak_idx[peak_count - 1, 0] + peak_range > len(data[0]): + peak_count = peak_count - 1 # reduce number of QRS complexes detected # Filter channel eegchan = filtfilt(filter_coords, 1, data) @@ -202,34 +205,33 @@ def _pca_obs( pcamat = detrend( pcamat, type="constant", axis=1 ) # [epoch x time] - detrended along the epoch - mean_effect = np.mean( + mean_effect: np.ndarray = np.mean( pcamat, axis=0 ) # [1 x time], contains the mean over all epochs dpcamat = detrend(pcamat, type="constant", axis=1) # [time x epoch] - ################################################################### - # Perform PCA with sklearn - ################################################################### - # run PCA(performs SVD(singular value decomposition)) + ############################ + # Perform PCA with sklearn # + ############################ + # run PCA, perform singular value decomposition (SVD) pca = PCA(svd_solver="full") pca.fit(dpcamat) factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_) # define selected number of components using profile likelihood - nComponents = 4 - ####################################################################### - # Make template of the ECG artefact - ####################################################################### + ##################################### + # Make template of the ECG artefact # + ##################################### mean_effect = mean_effect.reshape(-1, 1) - pca_template = np.c_[mean_effect, factor_loadings[:, 0:nComponents]] + pca_template = np.c_[mean_effect, factor_loadings[:, :n_components]] - ################################################################################### - # Data Fitting - ################################################################################### + ################ + # Data Fitting # + ################ window_start_idx = [] window_end_idx = [] - for p in range(0, peak_count): + for p in 
range(peak_count): # Deals with start portion of data if p == 0: pre_range = peak_range @@ -239,16 +241,16 @@ def _pca_obs( try: post_idx_nextPeak = None fitted_art, post_idx_nextPeak = fit_ecg_template( - data, - pca_template, - peak_idx[p], - peak_range, - pre_range, - post_range, - midP, - fitted_art, - post_idx_nextPeak, - n_samples_fit, + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, ) # Appending to list instead of using counter window_start_idx.append(peak_idx[p] - peak_range) @@ -265,16 +267,16 @@ def _pca_obs( if pre_range > peak_range: pre_range = peak_range fitted_art, _ = fit_ecg_template( - data, - pca_template, - peak_idx[p], - peak_range, - pre_range, - post_range, - midP, - fitted_art, - post_idx_nextPeak, - n_samples_fit, + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, ) window_start_idx.append(peak_idx[p] - peak_range) window_end_idx.append(peak_idx[p] + peak_range) @@ -293,20 +295,20 @@ def _pca_obs( if post_range > peak_range: post_range = peak_range - aTemplate = pca_template[ - midP - peak_range - 1 : midP + peak_range + 1, : + a_template = pca_template[ + mid_p - peak_range - 1 : mid_p + peak_range + 1, : ] fitted_art, post_idx_nextPeak = fit_ecg_template( - data, - aTemplate, - peak_idx[p], - peak_range, - pre_range, - post_range, - midP, - fitted_art, - post_idx_nextPeak, - n_samples_fit, + data=data, + pca_template=a_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + 
n_samples_fit=n_samples_fit, ) window_start_idx.append(peak_idx[p] - peak_range) window_end_idx.append(peak_idx[p] + peak_range) @@ -320,5 +322,4 @@ def _pca_obs( data -= fitted_art data = data.T.reshape(-1) - # Can only return data return data From fb8c7ff39a4b08143daf52b7abfa250a34bdfba3 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Sun, 24 Nov 2024 17:35:52 +0100 Subject: [PATCH 26/71] docs: minor edits in docstrings --- mne/preprocessing/pca_obs.py | 58 +++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index b9e0a01e64a..45973f82ded 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -39,17 +39,19 @@ def fit_ecg_template( Parameters ---------- data (ndarray): Data from the raw signal (n_channels, n_times) - pca_template (ndarray): Mean heartbeat and first N (4) principal components of the heartbeat matrix + pca_template (ndarray): Mean heartbeat and first N (default 4) + principal components of the heartbeat matrix a_peak_idx (int): Sample index of current R-peak peak_range (int): Half the median RR-interval pre_range (int): Number of samples to fit before the R-peak post_range (int): Number of samples to fit after the R-peak - mid_p (float): Sample index marking middle of the median RR interval in the signal. - Used to extract relevant part of PCA_template. - fitted_art (ndarray): The computed heartbeat artefact computed to remove from the data + mid_p (float): Sample index marking middle of the median RR interval + in the signal. Used to extract relevant part of PCA_template. + fitted_art (ndarray): The computed heartbeat artefact computed to + remove from the data post_idx_previous_peak (optional int): Sample index of previous R-peak n_samples_fit (int): Sample fit for interpolation between fitted artifact windows. - Helps reduce sharp edges at the end of fitted heartbeat events. 
+ Helps reduce sharp edges at the end of fitted heartbeat events. Returns ------- @@ -114,10 +116,32 @@ def fit_ecg_template( return fitted_art, a_peak_idx[0] + post_range -def apply_pca_obs(raw: Raw, picks: list[str], n_jobs: int, qrs: np.ndarray, filter_coords: np.ndarray) -> None: +def apply_pca_obs( + raw: Raw, + picks: list[str], + qrs: np.ndarray, + filter_coords: np.ndarray, + n_components: int = 4, + n_jobs: Optional[int] = None, +) -> None: """ Main convenience function for applying the PCA-OBS algorithm - to certain picks of a Raw object. + to certain picks of a Raw object. Updates the Raw object in-place. + + Parameters + ---------- + raw: Raw + The raw data to process + picks: list[str] + Channels in the Raw object to remove the heart artefact from + qrs: ndarray, shape (n_peaks, 1) + Array of times in (s), of detected R-peaks in ECG channel. + filter_coords: ndarray (N, ) + The numerator coefficient vector of the filter passed to scipy.signal.filtfilt + n_components: int, default 4 + Number of PCA components to use to form the OBS + n_jobs: int, default None + Number of jobs to perform the PCA-OBS processing in parallel """ raw.apply_function( _pca_obs, @@ -126,32 +150,18 @@ def apply_pca_obs(raw: Raw, picks: list[str], n_jobs: int, qrs: np.ndarray, filt # args sent to PCA_OBS qrs=qrs, filter_coords=filter_coords, + n_components=n_components, ) def _pca_obs( data: np.ndarray, qrs: np.ndarray, filter_coords: np.ndarray, - n_components: int = 4 # number of components to pick from the PCA + n_components: int, ) -> np.ndarray: """ Algorithm to perform the PCA OBS (Principal Component Analysis, Optimal Basis Sets) - algorithm to remove the heart artefact from EEG data. - - Parameters - ---------- - data: ndarray, shape (n_channels, n_times) - The data which we want to remove the heart artefact from. - - qrs: ndarray, shape (n_peaks, 1) - Array of times in (s), of detected R-peaks in ECG channel. 
- - filter_coords: ndarray (N, ) - The numerator coefficient vector of the filter passed to scipy.signal.filtfilt - - Returns - ------- - np.ndarray: The data with the heart artefact suppressed. + algorithm to remove the heart artefact from EEG data (shape [n_channels, n_times]) """ # set to baseline From 4c73a962648c02570a97ea6b85fe0d941246ad7d Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 25 Nov 2024 09:10:12 +0100 Subject: [PATCH 27/71] fix: add minor sanity checks for the function inputs --- mne/preprocessing/pca_obs.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 45973f82ded..fe877f7beb9 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -127,6 +127,7 @@ def apply_pca_obs( """ Main convenience function for applying the PCA-OBS algorithm to certain picks of a Raw object. Updates the Raw object in-place. + Makes sanity checks for all inputs. Parameters ---------- @@ -143,6 +144,13 @@ def apply_pca_obs( n_jobs: int, default None Number of jobs to perform the PCA-OBS processing in parallel """ + + if not qrs: + raise ValueError("qrs must not be empty") + + if not filter_coords: + raise ValueError("filter_coords must not be empty") + raw.apply_function( _pca_obs, picks=picks, From e69039b7ab3f22d6648a56b3f51a31fd48614213 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 25 Nov 2024 09:15:37 +0100 Subject: [PATCH 28/71] test: add initial test structure, missing validation of post-hear-artifact-removed data shapes and values --- mne/preprocessing/tests/test_pca_obs.py | 74 ++++++++++++++++++------- 1 file changed, 54 insertions(+), 20 deletions(-) diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 06d45062d6d..7c67db5899e 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -4,41 +4,75 @@ # License: BSD-3-Clause # Copyright the MNE-Python 
contributors. -# TODO: migrate this structure to test out function +import copy +from pathlib import Path -import pytest +import numpy as np +from scipy.signal import firls +import pytest from mne.io import read_raw_fif -from mne.preprocessing.pca_obs import pca_obs -from mne.datasets.testing import data_path, requires_testing_data -# TODO: Where are the test files we want to use located? -fname = data_path(download=False) / "eyetrack" / "test_eyelink.asc" +from mne.preprocessing import apply_pca_obs +from mne.preprocessing.ecg import find_ecg_events + +data_path = Path(__file__).parents[2] / "io" / "tests" / "data" +raw_fname = data_path / "test_raw.fif" + + +@pytest.fixture() +def short_raw_data(): + """Create a short, picked raw instance.""" + return read_raw_fif(raw_fname, preload=True).crop(0, 7) + -@requires_testing_data @pytest.mark.parametrize( # TODO: Are there any parameters we can cycle through to # test multiple? Different fs, windows, highpass freqs, etc.? # TODO: how do we determine qrs and filter_coords? What are these? - "fs, highpass_freq, qrs, filter_coords", + ("fs", "highpass_freq", "qrs", "filter_coords"), [ (0.2, 1.0, 100, 200), (0.1, 2.0, 100, 200), ], ) -def test_heart_artifact_removal(fs, highpass_freq, qrs, filter_coords): +def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs, filter_coords): """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" - raw = read_raw_fif(fname) - # Do something with fs and highpass as processing of the data? - ... 
+ # get the sampling frequency of the test data and generate the filter coords as in our example + fs = short_raw.info["sfreq"] + a = [0, 0, 1, 1] + f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter + ord = round(3 * fs / 0.5) + filter_coords = firls(ord + 1, f, a) + + # extract the QRS + ecg_events, _, _ = find_ecg_events(short_raw, ch_name=None) + ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) + + # copy the original raw and remove the heart artifact in-place + raw_orig = copy.deepcopy(short_raw) + apply_pca_obs( + raw=short_raw, + picks=["eeg"], + qrs=ecg_event_samples, + filter_coords=filter_coords, + ) + # raw.get_data() ? to get shapes to compare + + assert raw_orig != short_raw + + # # Do something with fs and highpass as processing of the data? + # ... + + # # call pca_obs algorithm + # result = pca_obs(raw, qrs=qrs, filter_coords=filter_coords) - # call pca_obs algorithm - result = pca_obs(raw, qrs=qrs, filter_coords=filter_coords) + # # assert results + # assert result is not None + # assert result.shape == (100, 100) + # assert result.shape == raw.shape # is this a condition we can test? + # assert result[0, 0] == 1.0 - # assert results - assert result is not None - assert result.shape == (100, 100) - assert result.shape == raw.shape # is this a condition we can test? - assert result[0, 0] == 1.0 - ... 
\ No newline at end of file +if __name__ == "__main__": + pytest.main(["mne/preprocessing/tests/test_pca_obs.py"]) \ No newline at end of file From 4bd3a519e36412898b9beea725c70b025d9c2138 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 25 Nov 2024 09:19:00 +0100 Subject: [PATCH 29/71] style: run pre-commit hooks on test file --- mne/preprocessing/tests/test_pca_obs.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 7c67db5899e..2edb6619c10 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -5,14 +5,13 @@ # Copyright the MNE-Python contributors. import copy -from pathlib import Path +from pathlib import Path import numpy as np +import pytest from scipy.signal import firls -import pytest from mne.io import read_raw_fif - from mne.preprocessing import apply_pca_obs from mne.preprocessing.ecg import find_ecg_events @@ -27,7 +26,7 @@ def short_raw_data(): @pytest.mark.parametrize( - # TODO: Are there any parameters we can cycle through to + # TODO: Are there any parameters we can cycle through to # test multiple? Different fs, windows, highpass freqs, etc.? # TODO: how do we determine qrs and filter_coords? What are these? 
("fs", "highpass_freq", "qrs", "filter_coords"), @@ -38,17 +37,17 @@ def short_raw_data(): ) def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs, filter_coords): """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" - - # get the sampling frequency of the test data and generate the filter coords as in our example + # get the sampling frequency of the test data and + # generate the filter coords as in our example fs = short_raw.info["sfreq"] a = [0, 0, 1, 1] f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter - ord = round(3 * fs / 0.5) - filter_coords = firls(ord + 1, f, a) + ord_ = round(3 * fs / 0.5) + filter_coords = firls(ord_ + 1, f, a) # extract the QRS ecg_events, _, _ = find_ecg_events(short_raw, ch_name=None) - ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) + ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) # copy the original raw and remove the heart artifact in-place raw_orig = copy.deepcopy(short_raw) @@ -59,11 +58,10 @@ def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs, filter_coords filter_coords=filter_coords, ) # raw.get_data() ? to get shapes to compare - + assert raw_orig != short_raw # # Do something with fs and highpass as processing of the data? - # ... # # call pca_obs algorithm # result = pca_obs(raw, qrs=qrs, filter_coords=filter_coords) @@ -71,8 +69,9 @@ def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs, filter_coords # # assert results # assert result is not None # assert result.shape == (100, 100) - # assert result.shape == raw.shape # is this a condition we can test? + # assert result.shape == raw.shape # is this a condition we can test? 
# assert result[0, 0] == 1.0 + if __name__ == "__main__": - pytest.main(["mne/preprocessing/tests/test_pca_obs.py"]) \ No newline at end of file + pytest.main(["mne/preprocessing/tests/test_pca_obs.py"]) From 07458ba8fd7ac23cf5a38696056085f3ab1c4a76 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 25 Nov 2024 09:23:49 +0100 Subject: [PATCH 30/71] docs: remove duplicated docstring --- mne/preprocessing/pca_obs.py | 59 +++++++++++++++--------------------- 1 file changed, 25 insertions(+), 34 deletions(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index fe877f7beb9..2eb14a951db 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -6,20 +6,18 @@ # Copyright the MNE-Python contributors. import math -from typing import Optional import numpy as np +from scipy.interpolate import PchipInterpolator as pchip from scipy.signal import detrend, filtfilt from sklearn.decomposition import PCA -from scipy.interpolate import PchipInterpolator as pchip -from scipy.signal import detrend from mne.io.fiff.raw import Raw from mne.utils import logger, warn - # TODO: check arguments passed in, raise errors, tests + def fit_ecg_template( data: np.ndarray, pca_template: np.ndarray, @@ -29,7 +27,7 @@ def fit_ecg_template( post_range: int, mid_p: float, fitted_art: np.ndarray, - post_idx_previous_peak: Optional[int], + post_idx_previous_peak: int | None, n_samples_fit: int, ) -> tuple[np.ndarray, int]: """ @@ -39,15 +37,15 @@ def fit_ecg_template( Parameters ---------- data (ndarray): Data from the raw signal (n_channels, n_times) - pca_template (ndarray): Mean heartbeat and first N (default 4) + pca_template (ndarray): Mean heartbeat and first N (default 4) principal components of the heartbeat matrix a_peak_idx (int): Sample index of current R-peak peak_range (int): Half the median RR-interval pre_range (int): Number of samples to fit before the R-peak post_range (int): Number of samples to fit after the R-peak - mid_p 
(float): Sample index marking middle of the median RR interval + mid_p (float): Sample index marking middle of the median RR interval in the signal. Used to extract relevant part of PCA_template. - fitted_art (ndarray): The computed heartbeat artefact computed to + fitted_art (ndarray): The computed heartbeat artefact computed to remove from the data post_idx_previous_peak (optional int): Sample index of previous R-peak n_samples_fit (int): Sample fit for interpolation between fitted artifact windows. @@ -57,11 +55,10 @@ def fit_ecg_template( ------- tuple[np.ndarray, int]: the fitted artifact and the next peak index """ - # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previous_peak # Then nextpeak is returned at the end and the process repeats # select window of template - template = pca_template[mid_p - peak_range - 1: mid_p + peak_range + 1, :] + template = pca_template[mid_p - peak_range - 1 : mid_p + peak_range + 1, :] # select window of data and detrend it slice = data[0, a_peak_idx[0] - peak_range : a_peak_idx[0] + peak_range + 1] @@ -72,8 +69,8 @@ def fit_ecg_template( pad_fit = np.dot(template, least_square[0]) # fit artifact - fitted_art[0, a_peak_idx[0] - pre_range - 1: a_peak_idx[0] + post_range] = pad_fit[ - mid_p - pre_range - 1: mid_p + post_range + fitted_art[0, a_peak_idx[0] - pre_range - 1 : a_peak_idx[0] + post_range] = pad_fit[ + mid_p - pre_range - 1 : mid_p + post_range ].T # if last peak, return @@ -81,9 +78,9 @@ def fit_ecg_template( return fitted_art, a_peak_idx[0] + post_range # interpolate time between peaks - intpol_window = np.ceil( - [post_idx_previous_peak, a_peak_idx[0] - pre_range] - ).astype(int) # interpolation window + intpol_window = np.ceil([post_idx_previous_peak, a_peak_idx[0] - pre_range]).astype( + int + ) # interpolation window if intpol_window[0] < intpol_window[1]: # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data @@ -97,19 +94,15 @@ def fit_ecg_template( ) # points to be 
interpolated in pt - the gap between the endpoints of the window x_fit = np.concatenate( [ - np.arange( - intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1 - ), - np.arange( - intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1 - ), + np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1), + np.arange(intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1), ] ) # Entire range of x values in this step (taking some number of samples before and after the window) y_fit = fitted_art[0, x_fit] y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation # Then make fitted artefact in the desired range equal to the completed fit above - fitted_art[0, post_idx_previous_peak: a_peak_idx[0] - pre_range + 1] = ( + fitted_art[0, post_idx_previous_peak : a_peak_idx[0] - pre_range + 1] = ( y_interpol ) @@ -117,15 +110,15 @@ def fit_ecg_template( def apply_pca_obs( - raw: Raw, - picks: list[str], - qrs: np.ndarray, + raw: Raw, + picks: list[str], + qrs: np.ndarray, filter_coords: np.ndarray, n_components: int = 4, - n_jobs: Optional[int] = None, + n_jobs: int | None = None, ) -> None: """ - Main convenience function for applying the PCA-OBS algorithm + Main convenience function for applying the PCA-OBS algorithm to certain picks of a Raw object. Updates the Raw object in-place. Makes sanity checks for all inputs. 
@@ -144,10 +137,9 @@ def apply_pca_obs( n_jobs: int, default None Number of jobs to perform the PCA-OBS processing in parallel """ - if not qrs: raise ValueError("qrs must not be empty") - + if not filter_coords: raise ValueError("filter_coords must not be empty") @@ -161,6 +153,7 @@ def apply_pca_obs( n_components=n_components, ) + def _pca_obs( data: np.ndarray, qrs: np.ndarray, @@ -168,10 +161,8 @@ def _pca_obs( n_components: int, ) -> np.ndarray: """ - Algorithm to perform the PCA OBS (Principal Component Analysis, Optimal Basis Sets) - algorithm to remove the heart artefact from EEG data (shape [n_channels, n_times]) + algorithm to remove the heart artefact from EEG data (shape [n_channels, n_times]). """ - # set to baseline data = data.reshape(-1, 1) data = data.T @@ -206,7 +197,7 @@ def _pca_obs( # make sure array is long enough for PArange (if not cut off last ECG peak) while peak_idx[peak_count - 1, 0] + peak_range > len(data[0]): - peak_count = peak_count - 1 # reduce number of QRS complexes detected + peak_count = peak_count - 1 # reduce number of QRS complexes detected # Filter channel eegchan = filtfilt(filter_coords, 1, data) @@ -277,7 +268,7 @@ def _pca_obs( warn(f"Cannot fit first ECG epoch. 
Reason: {e}") # Deals with last edge of data - elif p == peak_count-1: + elif p == peak_count - 1: logger.info("On last section - almost there!") try: pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) From a5e68d7ec3bdf3e77dff723a118e9019a8ab20b4 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Wed, 27 Nov 2024 10:26:36 +0100 Subject: [PATCH 31/71] Removed filter_coords from within the method --- .../esg_rm_heart_artefact_pcaobs.py | 10 +-------- mne/preprocessing/pca_obs.py | 21 ++++++------------- mne/preprocessing/tests/test_pca_obs.py | 17 +++++---------- 3 files changed, 12 insertions(+), 36 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index f5e3ad3757f..cd1f2cd48c0 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -109,13 +109,6 @@ raw_concat.annotations.append(qrs_event_time, duration, description, ch_names=[esg_chans]*len(qrs_event_time)) -############################################################################### -# Create filter coefficients -a = [0, 0, 1, 1] -f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter -ord = round(3 * fs / 0.5) -fwts = firls(ord + 1, f, a) - ############################################################################### # Create evoked response about the detected R-peaks before cardiac artefact correction # Apply PCA-OBS to remove the cardiac artefact @@ -137,8 +130,7 @@ raw_concat, picks=esg_chans, n_jobs=5, - qrs=ecg_event_samples, - filter_coords=fwts + qrs=ecg_event_samples ) epochs = Epochs( diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 2eb14a951db..54e64e730b2 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -9,7 +9,7 @@ import numpy as np from scipy.interpolate import PchipInterpolator as pchip -from scipy.signal import detrend, filtfilt +from scipy.signal 
import detrend from sklearn.decomposition import PCA from mne.io.fiff.raw import Raw @@ -113,7 +113,6 @@ def apply_pca_obs( raw: Raw, picks: list[str], qrs: np.ndarray, - filter_coords: np.ndarray, n_components: int = 4, n_jobs: int | None = None, ) -> None: @@ -130,18 +129,15 @@ def apply_pca_obs( Channels in the Raw object to remove the heart artefact from qrs: ndarray, shape (n_peaks, 1) Array of times in (s), of detected R-peaks in ECG channel. - filter_coords: ndarray (N, ) - The numerator coefficient vector of the filter passed to scipy.signal.filtfilt n_components: int, default 4 Number of PCA components to use to form the OBS n_jobs: int, default None Number of jobs to perform the PCA-OBS processing in parallel """ - if not qrs: - raise ValueError("qrs must not be empty") - - if not filter_coords: - raise ValueError("filter_coords must not be empty") + # TODO: Causes error 'ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()' + # Removed for now + # if not qrs: + # raise ValueError("qrs must not be empty") raw.apply_function( _pca_obs, @@ -149,7 +145,6 @@ def apply_pca_obs( n_jobs=n_jobs, # args sent to PCA_OBS qrs=qrs, - filter_coords=filter_coords, n_components=n_components, ) @@ -157,7 +152,6 @@ def apply_pca_obs( def _pca_obs( data: np.ndarray, qrs: np.ndarray, - filter_coords: np.ndarray, n_components: int, ) -> np.ndarray: """ @@ -199,14 +193,11 @@ def _pca_obs( while peak_idx[peak_count - 1, 0] + peak_range > len(data[0]): peak_count = peak_count - 1 # reduce number of QRS complexes detected - # Filter channel - eegchan = filtfilt(filter_coords, 1, data) - # build PCA matrix(heart-beat-epochs x window-length) pcamat = np.zeros((peak_count - 1, 2 * peak_range + 1)) # [epoch x time] # picking out heartbeat epochs for p in range(1, peak_count): - pcamat[p - 1, :] = eegchan[ + pcamat[p - 1, :] = data[ 0, peak_idx[p, 0] - peak_range : peak_idx[p, 0] + peak_range + 1 ] diff --git 
a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 2edb6619c10..46367d1eb21 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -28,22 +28,16 @@ def short_raw_data(): @pytest.mark.parametrize( # TODO: Are there any parameters we can cycle through to # test multiple? Different fs, windows, highpass freqs, etc.? - # TODO: how do we determine qrs and filter_coords? What are these? - ("fs", "highpass_freq", "qrs", "filter_coords"), + # TODO: how do we determine qrs? What are these? + # QRS is marking the sample index of R-peaks in the signal + ("fs", "highpass_freq", "qrs"), [ (0.2, 1.0, 100, 200), (0.1, 2.0, 100, 200), ], ) -def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs, filter_coords): +def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs): """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" - # get the sampling frequency of the test data and - # generate the filter coords as in our example - fs = short_raw.info["sfreq"] - a = [0, 0, 1, 1] - f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter - ord_ = round(3 * fs / 0.5) - filter_coords = firls(ord_ + 1, f, a) # extract the QRS ecg_events, _, _ = find_ecg_events(short_raw, ch_name=None) @@ -55,7 +49,6 @@ def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs, filter_coords raw=short_raw, picks=["eeg"], qrs=ecg_event_samples, - filter_coords=filter_coords, ) # raw.get_data() ? to get shapes to compare @@ -64,7 +57,7 @@ def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs, filter_coords # # Do something with fs and highpass as processing of the data? 
# # call pca_obs algorithm - # result = pca_obs(raw, qrs=qrs, filter_coords=filter_coords) + # result = pca_obs(raw, qrs=qrs) # # assert results # assert result is not None From 1938573956b2db32d9f02d2d27c592bebdfdf6b0 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Wed, 27 Nov 2024 10:43:07 +0100 Subject: [PATCH 32/71] Adding info to filter to example --- examples/preprocessing/esg_rm_heart_artefact_pcaobs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index cd1f2cd48c0..52bb3af73dc 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -126,6 +126,7 @@ evoked_before = epochs.average() # Apply function - modifies the data in place +# Optionally high-pass filter the data before applying PCA-OBS to remove low frequency drifts mne.preprocessing.apply_pca_obs( raw_concat, picks=esg_chans, From 19e0802aee20e88dcc1de1502987650a2b17bf16 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Wed, 4 Dec 2024 22:17:02 +0100 Subject: [PATCH 33/71] style: run import sorter pre-commit hook --- mne/preprocessing/__init__.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/__init__.pyi b/mne/preprocessing/__init__.pyi index d58c6e77d24..d0a6a1dd742 100644 --- a/mne/preprocessing/__init__.pyi +++ b/mne/preprocessing/__init__.pyi @@ -86,8 +86,8 @@ from .maxwell import ( maxwell_filter_prepare_emptyroom, ) from .otp import oversampled_temporal_projection +from .pca_obs import apply_pca_obs from .realign import realign_raw from .ssp import compute_proj_ecg, compute_proj_eog from .stim import fix_stim_artifact from .xdawn import Xdawn -from .pca_obs import apply_pca_obs \ No newline at end of file From 49a96d09e3c0232d509f639aa57fdcd77450951a Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Wed, 4 Dec 2024 22:17:13 +0100 Subject: [PATCH 34/71] 
refactor,test: migrate data shape to be 1d, add sanity checks for PCA function input, ad tests for copying original data and comparing to data modified in-place, add window size checks and remove generic try-except blocks BREAKING CHANGE --- mne/preprocessing/pca_obs.py | 242 +++++++++++------------- mne/preprocessing/tests/test_pca_obs.py | 75 ++++---- 2 files changed, 146 insertions(+), 171 deletions(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 54e64e730b2..aeb5478e7eb 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -13,9 +13,6 @@ from sklearn.decomposition import PCA from mne.io.fiff.raw import Raw -from mne.utils import logger, warn - -# TODO: check arguments passed in, raise errors, tests def fit_ecg_template( @@ -31,7 +28,8 @@ def fit_ecg_template( n_samples_fit: int, ) -> tuple[np.ndarray, int]: """ - Fits the heartbeat artefact found in the data + Fits the heartbeat artefact found in the data. + Returns the fitted artefact and the index of the next peak. Parameters @@ -48,8 +46,8 @@ def fit_ecg_template( fitted_art (ndarray): The computed heartbeat artefact computed to remove from the data post_idx_previous_peak (optional int): Sample index of previous R-peak - n_samples_fit (int): Sample fit for interpolation between fitted artifact windows. - Helps reduce sharp edges at the end of fitted heartbeat events. + n_samples_fit (int): Sample fit for interpolation in fitted artifact + windows. 
Helps reduce sharp edges at end of fitted heartbeat events Returns ------- @@ -61,52 +59,55 @@ def fit_ecg_template( template = pca_template[mid_p - peak_range - 1 : mid_p + peak_range + 1, :] # select window of data and detrend it - slice = data[0, a_peak_idx[0] - peak_range : a_peak_idx[0] + peak_range + 1] - detrended_data = detrend(slice.reshape(-1), type="constant") + slice_ = data[a_peak_idx - peak_range : a_peak_idx + peak_range + 1] + + detrended_data = detrend(slice_, type="constant") # maps data on template and then maps it again back to the sensor space least_square = np.linalg.lstsq(template, detrended_data, rcond=None) pad_fit = np.dot(template, least_square[0]) # fit artifact - fitted_art[0, a_peak_idx[0] - pre_range - 1 : a_peak_idx[0] + post_range] = pad_fit[ + fitted_art[a_peak_idx - pre_range - 1 : a_peak_idx + post_range] = pad_fit[ mid_p - pre_range - 1 : mid_p + post_range ].T # if last peak, return if post_idx_previous_peak is None: - return fitted_art, a_peak_idx[0] + post_range + return fitted_art, a_peak_idx + post_range # interpolate time between peaks - intpol_window = np.ceil([post_idx_previous_peak, a_peak_idx[0] - pre_range]).astype( + intpol_window = np.ceil([post_idx_previous_peak, a_peak_idx - pre_range]).astype( int ) # interpolation window if intpol_window[0] < intpol_window[1]: # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data - # You have x_fit which is two slices on either side of the interpolation window endpoints + # You have x_fit which is two slices on either side of the interpolation window + # endpoints # You have y_fit which is the y vals corresponding to x values above - # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate - # You have y_interpol which is values from pchip at the time points specified in x_interpol - x_interpol = np.arange( - intpol_window[0], intpol_window[1] + 1, 1 - ) # points to be interpolated in pt - the gap between 
the endpoints of the window + # You have x_interpol which is the time points between the two slices in x_fit + # that you want to interpolate + # You have y_interpol which is values from pchip at the time points specified in + # x_interpol + # points to be interpolated in pt - the gap between the endpoints of the window + x_interpol = np.arange(intpol_window[0], intpol_window[1] + 1, 1) + # Entire range of x values in this step (taking some + # number of samples before and after the window) x_fit = np.concatenate( [ np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1), np.arange(intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1), ] - ) # Entire range of x values in this step (taking some number of samples before and after the window) - y_fit = fitted_art[0, x_fit] + ) + y_fit = fitted_art[x_fit] y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation - # Then make fitted artefact in the desired range equal to the completed fit above - fitted_art[0, post_idx_previous_peak : a_peak_idx[0] - pre_range + 1] = ( - y_interpol - ) + # make fitted artefact in the desired range equal to the completed fit above + fitted_art[post_idx_previous_peak : a_peak_idx - pre_range + 1] = y_interpol - return fitted_art, a_peak_idx[0] + post_range + return fitted_art, a_peak_idx + post_range def apply_pca_obs( @@ -117,9 +118,9 @@ def apply_pca_obs( n_jobs: int | None = None, ) -> None: """ - Main convenience function for applying the PCA-OBS algorithm - to certain picks of a Raw object. Updates the Raw object in-place. - Makes sanity checks for all inputs. + Apply the PCA-OBS algorithm to picks of a Raw object. + + Update the Raw object in-place. Make sanity checks for all inputs. Parameters ---------- @@ -134,10 +135,13 @@ def apply_pca_obs( n_jobs: int, default None Number of jobs to perform the PCA-OBS processing in parallel """ - # TODO: Causes error 'ValueError: The truth value of an array with more than one element is ambiguous. 
Use a.any() or a.all()' - # Removed for now - # if not qrs: - # raise ValueError("qrs must not be empty") + # sanity checks + if len(qrs.shape) > 1: + raise ValueError("qrs must be a 1d array") + if not isinstance(n_jobs, int) or n_jobs < 1: + raise ValueError("n_jobs must be an integer greater than 0") + if not picks: + raise ValueError("picks must be a list of channel names") raw.apply_function( _pca_obs, @@ -154,35 +158,22 @@ def _pca_obs( qrs: np.ndarray, n_components: int, ) -> np.ndarray: - """ - algorithm to remove the heart artefact from EEG data (shape [n_channels, n_times]). - """ + """Algorithm to remove heart artefact from EEG data (array of length n_times).""" # set to baseline - data = data.reshape(-1, 1) - data = data.T - data = data - np.mean(data, axis=1) + data = data - np.mean(data) - # Allocate memory + # Allocate memory for artifact which will be subtracted from the data fitted_art = np.zeros(data.shape) - peakplot = np.zeros(data.shape) - - # Extract QRS events - for idx in qrs[0]: - if idx < len(peakplot[0, :]): - peakplot[0, idx] = 1 # logical indexed locations of qrs events - peak_idx = np.nonzero(peakplot)[1] # Selecting indices along columns - peak_idx = peak_idx.reshape(-1, 1) + # Extract QRS event indexes which are within out data timeframe + peak_idx = qrs[qrs < len(data)] peak_count = len(peak_idx) ################################################################## # Preparatory work - reserving memory, configure sizes, de-trend # ################################################################## - logger.info("Pulse artifact subtraction in progress...Please wait!") - # define peak range based on RR - RR = np.diff(peak_idx[:, 0]) - mRR = np.median(RR) + mRR = np.median(np.diff(peak_idx)) peak_range = round(mRR / 2) # Rounds to an integer mid_p = peak_range + 1 n_samples_fit = round( @@ -190,16 +181,15 @@ def _pca_obs( ) # sample fit for interpolation between fitted artifact windows # make sure array is long enough for PArange (if not 
cut off last ECG peak) - while peak_idx[peak_count - 1, 0] + peak_range > len(data[0]): + # NOTE: Here we previously checked for the last part of the window to be big enough. + while peak_idx[peak_count - 1] + peak_range > len(data): peak_count = peak_count - 1 # reduce number of QRS complexes detected # build PCA matrix(heart-beat-epochs x window-length) pcamat = np.zeros((peak_count - 1, 2 * peak_range + 1)) # [epoch x time] # picking out heartbeat epochs for p in range(1, peak_count): - pcamat[p - 1, :] = data[ - 0, peak_idx[p, 0] - peak_range : peak_idx[p, 0] + peak_range + 1 - ] + pcamat[p - 1, :] = data[peak_idx[p] - peak_range : peak_idx[p] + peak_range + 1] # detrending matrix(twice) pcamat = detrend( @@ -218,7 +208,7 @@ def _pca_obs( pca.fit(dpcamat) factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_) - # define selected number of components using profile likelihood + # define selected number of components using profile likelihood ##################################### # Make template of the ECG artefact # @@ -231,95 +221,87 @@ def _pca_obs( ################ window_start_idx = [] window_end_idx = [] + post_idx_nextPeak = None + for p in range(peak_count): + # if the current peak doesn't have enough data in the + # start of the peak_range, skip fitting the artifact + if peak_idx[p] - peak_range < 0: + continue + # Deals with start portion of data if p == 0: pre_range = peak_range post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) if post_range > peak_range: post_range = peak_range - try: - post_idx_nextPeak = None - fitted_art, post_idx_nextPeak = fit_ecg_template( - data=data, - pca_template=pca_template, - a_peak_idx=peak_idx[p], - peak_range=peak_range, - pre_range=pre_range, - post_range=post_range, - mid_p=mid_p, - fitted_art=fitted_art, - post_idx_previous_peak=post_idx_nextPeak, - n_samples_fit=n_samples_fit, - ) - # Appending to list instead of using counter - window_start_idx.append(peak_idx[p] - peak_range) - 
window_end_idx.append(peak_idx[p] + peak_range) - except Exception as e: - warn(f"Cannot fit first ECG epoch. Reason: {e}") + + fitted_art, post_idx_nextPeak = fit_ecg_template( + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, + ) + # Appending to list instead of using counter + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) # Deals with last edge of data elif p == peak_count - 1: - logger.info("On last section - almost there!") - try: - pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) - post_range = peak_range - if pre_range > peak_range: - pre_range = peak_range - fitted_art, _ = fit_ecg_template( - data=data, - pca_template=pca_template, - a_peak_idx=peak_idx[p], - peak_range=peak_range, - pre_range=pre_range, - post_range=post_range, - mid_p=mid_p, - fitted_art=fitted_art, - post_idx_previous_peak=post_idx_nextPeak, - n_samples_fit=n_samples_fit, - ) - window_start_idx.append(peak_idx[p] - peak_range) - window_end_idx.append(peak_idx[p] + peak_range) - except Exception as e: - warn(f"Cannot fit last ECG epoch. 
Reason: {e}") + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = peak_range + if pre_range > peak_range: + pre_range = peak_range + fitted_art, _ = fit_ecg_template( + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, + ) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) # Deals with middle portion of data else: - try: - # ---------------- Processing of central data - -------------------- - # cycle through peak artifacts identified by peakplot - pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) - post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) - if pre_range >= peak_range: - pre_range = peak_range - if post_range > peak_range: - post_range = peak_range - - a_template = pca_template[ - mid_p - peak_range - 1 : mid_p + peak_range + 1, : - ] - fitted_art, post_idx_nextPeak = fit_ecg_template( - data=data, - pca_template=a_template, - a_peak_idx=peak_idx[p], - peak_range=peak_range, - pre_range=pre_range, - post_range=post_range, - mid_p=mid_p, - fitted_art=fitted_art, - post_idx_previous_peak=post_idx_nextPeak, - n_samples_fit=n_samples_fit, - ) - window_start_idx.append(peak_idx[p] - peak_range) - window_end_idx.append(peak_idx[p] + peak_range) - except Exception as e: - warn(f"Cannot fit middle section of data. 
Reason: {e}") + # ---------------- Processing of central data - -------------------- + # cycle through peak artifacts identified by peakplot + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) + if pre_range >= peak_range: + pre_range = peak_range + if post_range > peak_range: + post_range = peak_range - # Actually subtract the artefact, return needs to be the same shape as input data - data = data.reshape(-1) - fitted_art = fitted_art.reshape(-1) + a_template = pca_template[ + mid_p - peak_range - 1 : mid_p + peak_range + 1, : + ] + fitted_art, post_idx_nextPeak = fit_ecg_template( + data=data, + pca_template=a_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, + ) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + # Actually subtract the artefact, return needs to be the same shape as input data data -= fitted_art - data = data.T.reshape(-1) - return data diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 46367d1eb21..82b1efa6ebf 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -4,16 +4,15 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. 
-import copy from pathlib import Path import numpy as np +import pandas as pd import pytest -from scipy.signal import firls from mne.io import read_raw_fif +from mne.io.fiff.raw import Raw from mne.preprocessing import apply_pca_obs -from mne.preprocessing.ecg import find_ecg_events data_path = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_path / "test_raw.fif" @@ -22,49 +21,43 @@ @pytest.fixture() def short_raw_data(): """Create a short, picked raw instance.""" - return read_raw_fif(raw_fname, preload=True).crop(0, 7) + return read_raw_fif(raw_fname, preload=True) -@pytest.mark.parametrize( - # TODO: Are there any parameters we can cycle through to - # test multiple? Different fs, windows, highpass freqs, etc.? - # TODO: how do we determine qrs? What are these? - # QRS is marking the sample index of R-peaks in the signal - ("fs", "highpass_freq", "qrs"), - [ - (0.2, 1.0, 100, 200), - (0.1, 2.0, 100, 200), - ], -) -def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs): +def test_heart_artifact_removal(short_raw_data: Raw): """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" + # fake some random qrs events + ecg_event_samples = np.arange(0, len(short_raw_data.times), 1400) + 1430 - # extract the QRS - ecg_events, _, _ = find_ecg_events(short_raw, ch_name=None) - ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) + # copy the original raw. heart artifact is removed in-place + orig_df: pd.DataFrame = short_raw_data.to_data_frame().copy(deep=True) - # copy the original raw and remove the heart artifact in-place - raw_orig = copy.deepcopy(short_raw) - apply_pca_obs( - raw=short_raw, - picks=["eeg"], - qrs=ecg_event_samples, - ) - # raw.get_data() ? to get shapes to compare - - assert raw_orig != short_raw - - # # Do something with fs and highpass as processing of the data? 
+ # perform heart artifact removal + apply_pca_obs(raw=short_raw_data, picks=["eeg"], qrs=ecg_event_samples, n_jobs=1) - # # call pca_obs algorithm - # result = pca_obs(raw, qrs=qrs) - - # # assert results - # assert result is not None - # assert result.shape == (100, 100) - # assert result.shape == raw.shape # is this a condition we can test? - # assert result[0, 0] == 1.0 + # compare processed df to original df + removed_heart_artifact_df: pd.DataFrame = short_raw_data.to_data_frame() + # ensure all column names remain the same + pd.testing.assert_index_equal( + orig_df.columns, + removed_heart_artifact_df.columns, + ) -if __name__ == "__main__": - pytest.main(["mne/preprocessing/tests/test_pca_obs.py"]) + # ensure every column starting with EEG has been altered + altered_cols = [c for c in orig_df.columns if c.startswith("EEG")] + for col in altered_cols: + with pytest.raises( + AssertionError + ): # make sure that error is raised when we check equal + pd.testing.assert_series_equal( + orig_df[col], + removed_heart_artifact_df[col], + ) + + # ensure every column not starting with EEG has not been altered + unaltered_cols = [c for c in orig_df.columns if not c.startswith("EEG")] + pd.testing.assert_frame_equal( + orig_df[unaltered_cols], + removed_heart_artifact_df[unaltered_cols], + ) From 925b2931235506cfe1fdad1127280aa8fa518719 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Fri, 6 Dec 2024 16:23:00 +0100 Subject: [PATCH 35/71] example:Reshape ecg_events before pca_obs, tests:ecg_events is in samples for algorith --- examples/preprocessing/esg_rm_heart_artefact_pcaobs.py | 2 +- mne/preprocessing/pca_obs.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index 52bb3af73dc..773cb7410ed 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -131,7 +131,7 
@@ raw_concat, picks=esg_chans, n_jobs=5, - qrs=ecg_event_samples + qrs=ecg_event_samples.reshape(-1) ) epochs = Epochs( diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index aeb5478e7eb..05320cf22ae 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -129,7 +129,7 @@ def apply_pca_obs( picks: list[str] Channels in the Raw object to remove the heart artefact from qrs: ndarray, shape (n_peaks, 1) - Array of times in (s), of detected R-peaks in ECG channel. + Array of times in (sample indices), of detected R-peaks in ECG channel. n_components: int, default 4 Number of PCA components to use to form the OBS n_jobs: int, default None From a3d11237a6a7c56e9d3275f5ca75fc8b27e70446 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Fri, 6 Dec 2024 21:50:08 +0100 Subject: [PATCH 36/71] example: update channel selection --- examples/preprocessing/esg_rm_heart_artefact_pcaobs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index 773cb7410ed..6a66a2057cb 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -56,7 +56,7 @@ esg_chans = ["S35", "S24", "S36", "Iz", "S17", "S15", "S32", "S22", "S19", "S26", "S28", "S9", "S13", "S11", "S7", "SC1", "S4", "S18", "S8", "S31", "SC6", "S12", "S16", "S5", "S30", "S20", "S34", "S21", "S25", "L1", "S29", "S14", "S33", - "S3", "L4", "S6", "S23", 'ECG'] + "S3", "L4", "S6", "S23"] # Sampling rate fs = 1000 @@ -79,8 +79,8 @@ for count, block_file in enumerate(block_files): raw = read_raw_eeglab(block_file, eog=(), preload=True, uint16_codec=None, verbose=None) - # Isolate the ESG channels only - raw.pick(esg_chans) + # Isolate the ESG channels (including ECG for R-peak detection) + raw.pick(esg_chans + ['ECG']) # Find trigger timings to remove the stimulation artefact events, 
event_dict = events_from_annotations(raw) From 2ae84b0f5a69a10be243ead7a54c8cad4431161e Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Wed, 18 Dec 2024 14:29:28 +0100 Subject: [PATCH 37/71] refactor,test: change public qrs kwarg to be more clear about being indices, add sanity checks for input values, add negative-test which verifies proper exceptions when bad data is passed to function --- .../esg_rm_heart_artefact_pcaobs.py | 124 ++++++++++++------ mne/preprocessing/pca_obs.py | 26 ++-- mne/preprocessing/tests/test_pca_obs.py | 38 +++++- 3 files changed, 137 insertions(+), 51 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index 6a66a2057cb..b75552f657a 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -1,9 +1,9 @@ """ .. _ex-pcaobs: -============================================================================================== -Principal Component Analysis - Optimal Basis Sets (PCA-OBS) for removal of cardiac artefact -============================================================================================== +===================================================================================== +Principal Component Analysis - Optimal Basis Sets (PCA-OBS) removing cardiac artefact +===================================================================================== This script shows an example of how to use an adaptation of PCA-OBS :footcite:`NiazyEtAl2005`. PCA-OBS was originally designed to remove @@ -24,13 +24,9 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. 
-from matplotlib import pyplot as plt -import mne -from mne.preprocessing import find_ecg_events, fix_stim_artifact -from mne.io import read_raw_eeglab -from scipy.signal import firls +import glob + import numpy as np -from mne import Epochs, events_from_annotations, concatenate_raws ############################################################################### # Download sample subject data from OpenNeuro if you haven't already @@ -38,25 +34,67 @@ # median nerve stimulation of the left wrist # Set the target directory to your desired location import openneuro as on -import glob +from matplotlib import pyplot as plt + +import mne +from mne import Epochs, concatenate_raws, events_from_annotations +from mne.io import read_raw_eeglab +from mne.preprocessing import find_ecg_events, fix_stim_artifact # add the path where you want the OpenNeuro data downloaded. Files total around 8 GB # target_dir = "/home/steinnhm/personal/mne-data" -target_dir = '/data/pt_02569/test_data' +target_dir = "/data/pt_02569/test_data" -file_list = glob.glob(target_dir + '/sub-001/eeg/*median*.set') +file_list = glob.glob(target_dir + "/sub-001/eeg/*median*.set") if file_list: - print('Data is already downloaded') + print("Data is already downloaded") else: - on.download(dataset='ds004388', target_dir=target_dir, include='sub-001/*median*_eeg*') + on.download( + dataset="ds004388", target_dir=target_dir, include="sub-001/*median*_eeg*" + ) ############################################################################### # Define the esg channels (arranged in two patches over the neck and lower back) # Also include the ECG channel for artefact correction -esg_chans = ["S35", "S24", "S36", "Iz", "S17", "S15", "S32", "S22", "S19", "S26", "S28", - "S9", "S13", "S11", "S7", "SC1", "S4", "S18", "S8", "S31", "SC6", "S12", - "S16", "S5", "S30", "S20", "S34", "S21", "S25", "L1", "S29", "S14", "S33", - "S3", "L4", "S6", "S23"] +esg_chans = [ + "S35", + "S24", + "S36", + "Iz", + "S17", + "S15", + "S32", 
+ "S22", + "S19", + "S26", + "S28", + "S9", + "S13", + "S11", + "S7", + "SC1", + "S4", + "S18", + "S8", + "S31", + "SC6", + "S12", + "S16", + "S5", + "S30", + "S20", + "S34", + "S21", + "S25", + "L1", + "S29", + "S14", + "S33", + "S3", + "L4", + "S6", + "S23", +] # Sampling rate fs = 1000 @@ -73,21 +111,30 @@ # Read in each of the four blocks and concatenate the raw structures after performing # some minimal preprocessing including removing the stimulation artefact, downsampling # and filtering -block_files = glob.glob(target_dir + '/sub-001/eeg/*median*.set') +block_files = glob.glob(target_dir + "/sub-001/eeg/*median*.set") block_files = sorted(block_files) for count, block_file in enumerate(block_files): - raw = read_raw_eeglab(block_file, eog=(), preload=True, uint16_codec=None, verbose=None) + raw = read_raw_eeglab( + block_file, eog=(), preload=True, uint16_codec=None, verbose=None + ) # Isolate the ESG channels (including ECG for R-peak detection) - raw.pick(esg_chans + ['ECG']) + raw.pick(esg_chans + ["ECG"]) # Find trigger timings to remove the stimulation artefact events, event_dict = events_from_annotations(raw) - trigger_name = 'Median - Stimulation' - - fix_stim_artifact(raw, events=events, event_id=event_dict[trigger_name], tmin=tstart_esg, tmax=tmax_esg, mode='linear', - stim_channel=None) + trigger_name = "Median - Stimulation" + + fix_stim_artifact( + raw, + events=events, + event_id=event_dict[trigger_name], + tmin=tstart_esg, + tmax=tmax_esg, + mode="linear", + stim_channel=None, + ) # Downsample the data raw.resample(fs) @@ -101,13 +148,19 @@ ############################################################################### # Find ECG events and add to the raw structure as event annotations ecg_events, ch_ecg, average_pulse = find_ecg_events(raw_concat, ch_name="ECG") -ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) # Samples only +ecg_event_samples = np.asarray( + [[ecg_event[0] for ecg_event in ecg_events]] +) # 
Samples only -qrs_event_time = [x / fs for x in ecg_event_samples.reshape(-1)] # Divide by sampling rate to make times +qrs_event_time = [ + x / fs for x in ecg_event_samples.reshape(-1) +] # Divide by sampling rate to make times duration = np.repeat(0.0, len(ecg_event_samples)) -description = ['qrs'] * len(ecg_event_samples) +description = ["qrs"] * len(ecg_event_samples) -raw_concat.annotations.append(qrs_event_time, duration, description, ch_names=[esg_chans]*len(qrs_event_time)) +raw_concat.annotations.append( + qrs_event_time, duration, description, ch_names=[esg_chans] * len(qrs_event_time) +) ############################################################################### # Create evoked response about the detected R-peaks before cardiac artefact correction @@ -125,13 +178,10 @@ ) evoked_before = epochs.average() -# Apply function - modifies the data in place -# Optionally high-pass filter the data before applying PCA-OBS to remove low frequency drifts +# Apply function - modifies the data in place. Optionally high-pass filter +# the data before applying PCA-OBS to remove low frequency drifts mne.preprocessing.apply_pca_obs( - raw_concat, - picks=esg_chans, - n_jobs=5, - qrs=ecg_event_samples.reshape(-1) + raw_concat, picks=esg_chans, n_jobs=5, qrs_indices=ecg_event_samples.reshape(-1) ) epochs = Epochs( @@ -150,8 +200,8 @@ axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") axes.set_ylim([-0.0005, 0.001]) -axes.set_ylabel('Amplitude (V)') -axes.set_xlabel('Time (s)') +axes.set_ylabel("Amplitude (V)") +axes.set_xlabel("Time (s)") axes.set_title("Before (black) vs. 
After (green)") plt.tight_layout() plt.show() diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 05320cf22ae..0c21dae8158 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -13,6 +13,7 @@ from sklearn.decomposition import PCA from mne.io.fiff.raw import Raw +from mne.utils import logger def fit_ecg_template( @@ -113,7 +114,7 @@ def fit_ecg_template( def apply_pca_obs( raw: Raw, picks: list[str], - qrs: np.ndarray, + qrs_indices: np.ndarray, n_components: int = 4, n_jobs: int | None = None, ) -> None: @@ -128,18 +129,25 @@ def apply_pca_obs( The raw data to process picks: list[str] Channels in the Raw object to remove the heart artefact from - qrs: ndarray, shape (n_peaks, 1) - Array of times in (sample indices), of detected R-peaks in ECG channel. + qrs_indices: ndarray, shape (n_peaks, 1) + Array of indices in the Raw data of detected R-peaks in ECG channel. n_components: int, default 4 Number of PCA components to use to form the OBS n_jobs: int, default None - Number of jobs to perform the PCA-OBS processing in parallel + Number of jobs to perform the PCA-OBS processing in parallel. 
+ Passed on to Raw.apply_function """ # sanity checks - if len(qrs.shape) > 1: - raise ValueError("qrs must be a 1d array") - if not isinstance(n_jobs, int) or n_jobs < 1: - raise ValueError("n_jobs must be an integer greater than 0") + if not isinstance(qrs_indices, np.ndarray): + raise ValueError("qrs_indices must be an array") + if len(qrs_indices.shape) > 1: + raise ValueError("qrs_indices must be a 1d array") + if qrs_indices.dtype != int: + raise ValueError("qrs_indices must be an array of integers") + if np.any(qrs_indices < 0): + raise ValueError("qrs_indices must be strictly positive integers") + if np.any(qrs_indices >= raw.n_times): + logger.warning("out of bound qrs_indices will be ignored..") if not picks: raise ValueError("picks must be a list of channel names") @@ -148,7 +156,7 @@ def apply_pca_obs( picks=picks, n_jobs=n_jobs, # args sent to PCA_OBS - qrs=qrs, + qrs=qrs_indices, n_components=n_components, ) diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 82b1efa6ebf..e2d07a8ce72 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -1,5 +1,3 @@ -"""Test the ieeg projection functions.""" - # Authors: The MNE-Python contributors. # License: BSD-3-Clause # Copyright the MNE-Python contributors. @@ -26,14 +24,17 @@ def short_raw_data(): def test_heart_artifact_removal(short_raw_data: Raw): """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" - # fake some random qrs events - ecg_event_samples = np.arange(0, len(short_raw_data.times), 1400) + 1430 + # fake some random qrs events in the window of the raw data + # remove first and last samples and cast to integer for indexing + ecg_event_indices = np.linspace(0, short_raw_data.n_times, 20, dtype=int)[1:-1] # copy the original raw. 
heart artifact is removed in-place orig_df: pd.DataFrame = short_raw_data.to_data_frame().copy(deep=True) # perform heart artifact removal - apply_pca_obs(raw=short_raw_data, picks=["eeg"], qrs=ecg_event_samples, n_jobs=1) + apply_pca_obs( + raw=short_raw_data, picks=["eeg"], qrs_indices=ecg_event_indices, n_jobs=1 + ) # compare processed df to original df removed_heart_artifact_df: pd.DataFrame = short_raw_data.to_data_frame() @@ -61,3 +62,30 @@ def test_heart_artifact_removal(short_raw_data: Raw): orig_df[unaltered_cols], removed_heart_artifact_df[unaltered_cols], ) + + +# test that various nonsensical inputs raise the proper errors +@pytest.mark.parametrize( + ("picks", "qrs", "error"), + [ + (["eeg"], np.array([[0, 1], [2, 3]]), "qrs_indices must be a 1d array"), + (["eeg"], [2, 3, 4], "qrs_indices must be an array"), + ( + ["eeg"], + np.array([None, "foo", 2]), + "qrs_indices must be an array of integers", + ), + ( + ["eeg"], + np.array([-1, 0, 3]), + "qrs_indices must be strictly positive integers", + ), + ([], np.array([1, 2, 3]), "picks must be a list of channel names"), + ], +) +def test_pca_obs_bad_input( + short_raw_data: Raw, picks: list[str], qrs: np.ndarray, error: str +): + """Test if bad input data raises the proper errors in the function sanity checks.""" + with pytest.raises(ValueError, match=error): + apply_pca_obs(raw=short_raw_data, picks=picks, qrs_indices=qrs) From d19049b88b35b4c774a0aa429c5ce456885cad2c Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Wed, 18 Dec 2024 14:40:36 +0100 Subject: [PATCH 38/71] style: move fit_ecg_template function to bottom of file to improve readability --- mne/preprocessing/pca_obs.py | 190 +++++++++++++++++------------------ 1 file changed, 95 insertions(+), 95 deletions(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 0c21dae8158..99e643cd708 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -16,101 +16,6 @@ from mne.utils import logger -def 
fit_ecg_template( - data: np.ndarray, - pca_template: np.ndarray, - a_peak_idx: int, - peak_range: int, - pre_range: int, - post_range: int, - mid_p: float, - fitted_art: np.ndarray, - post_idx_previous_peak: int | None, - n_samples_fit: int, -) -> tuple[np.ndarray, int]: - """ - Fits the heartbeat artefact found in the data. - - Returns the fitted artefact and the index of the next peak. - - Parameters - ---------- - data (ndarray): Data from the raw signal (n_channels, n_times) - pca_template (ndarray): Mean heartbeat and first N (default 4) - principal components of the heartbeat matrix - a_peak_idx (int): Sample index of current R-peak - peak_range (int): Half the median RR-interval - pre_range (int): Number of samples to fit before the R-peak - post_range (int): Number of samples to fit after the R-peak - mid_p (float): Sample index marking middle of the median RR interval - in the signal. Used to extract relevant part of PCA_template. - fitted_art (ndarray): The computed heartbeat artefact computed to - remove from the data - post_idx_previous_peak (optional int): Sample index of previous R-peak - n_samples_fit (int): Sample fit for interpolation in fitted artifact - windows. 
Helps reduce sharp edges at end of fitted heartbeat events - - Returns - ------- - tuple[np.ndarray, int]: the fitted artifact and the next peak index - """ - # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previous_peak - # Then nextpeak is returned at the end and the process repeats - # select window of template - template = pca_template[mid_p - peak_range - 1 : mid_p + peak_range + 1, :] - - # select window of data and detrend it - slice_ = data[a_peak_idx - peak_range : a_peak_idx + peak_range + 1] - - detrended_data = detrend(slice_, type="constant") - - # maps data on template and then maps it again back to the sensor space - least_square = np.linalg.lstsq(template, detrended_data, rcond=None) - pad_fit = np.dot(template, least_square[0]) - - # fit artifact - fitted_art[a_peak_idx - pre_range - 1 : a_peak_idx + post_range] = pad_fit[ - mid_p - pre_range - 1 : mid_p + post_range - ].T - - # if last peak, return - if post_idx_previous_peak is None: - return fitted_art, a_peak_idx + post_range - - # interpolate time between peaks - intpol_window = np.ceil([post_idx_previous_peak, a_peak_idx - pre_range]).astype( - int - ) # interpolation window - - if intpol_window[0] < intpol_window[1]: - # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data - - # You have x_fit which is two slices on either side of the interpolation window - # endpoints - # You have y_fit which is the y vals corresponding to x values above - # You have x_interpol which is the time points between the two slices in x_fit - # that you want to interpolate - # You have y_interpol which is values from pchip at the time points specified in - # x_interpol - # points to be interpolated in pt - the gap between the endpoints of the window - x_interpol = np.arange(intpol_window[0], intpol_window[1] + 1, 1) - # Entire range of x values in this step (taking some - # number of samples before and after the window) - x_fit = np.concatenate( - [ - 
np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1), - np.arange(intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1), - ] - ) - y_fit = fitted_art[x_fit] - y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation - - # make fitted artefact in the desired range equal to the completed fit above - fitted_art[post_idx_previous_peak : a_peak_idx - pre_range + 1] = y_interpol - - return fitted_art, a_peak_idx + post_range - - def apply_pca_obs( raw: Raw, picks: list[str], @@ -313,3 +218,98 @@ def _pca_obs( # Actually subtract the artefact, return needs to be the same shape as input data data -= fitted_art return data + + +def fit_ecg_template( + data: np.ndarray, + pca_template: np.ndarray, + a_peak_idx: int, + peak_range: int, + pre_range: int, + post_range: int, + mid_p: float, + fitted_art: np.ndarray, + post_idx_previous_peak: int | None, + n_samples_fit: int, +) -> tuple[np.ndarray, int]: + """ + Fits the heartbeat artefact found in the data. + + Returns the fitted artefact and the index of the next peak. + + Parameters + ---------- + data (ndarray): Data from the raw signal (n_channels, n_times) + pca_template (ndarray): Mean heartbeat and first N (default 4) + principal components of the heartbeat matrix + a_peak_idx (int): Sample index of current R-peak + peak_range (int): Half the median RR-interval + pre_range (int): Number of samples to fit before the R-peak + post_range (int): Number of samples to fit after the R-peak + mid_p (float): Sample index marking middle of the median RR interval + in the signal. Used to extract relevant part of PCA_template. + fitted_art (ndarray): The computed heartbeat artefact computed to + remove from the data + post_idx_previous_peak (optional int): Sample index of previous R-peak + n_samples_fit (int): Sample fit for interpolation in fitted artifact + windows. 
Helps reduce sharp edges at end of fitted heartbeat events + + Returns + ------- + tuple[np.ndarray, int]: the fitted artifact and the next peak index + """ + # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previous_peak + # Then nextpeak is returned at the end and the process repeats + # select window of template + template = pca_template[mid_p - peak_range - 1 : mid_p + peak_range + 1, :] + + # select window of data and detrend it + slice_ = data[a_peak_idx - peak_range : a_peak_idx + peak_range + 1] + + detrended_data = detrend(slice_, type="constant") + + # maps data on template and then maps it again back to the sensor space + least_square = np.linalg.lstsq(template, detrended_data, rcond=None) + pad_fit = np.dot(template, least_square[0]) + + # fit artifact + fitted_art[a_peak_idx - pre_range - 1 : a_peak_idx + post_range] = pad_fit[ + mid_p - pre_range - 1 : mid_p + post_range + ].T + + # if last peak, return + if post_idx_previous_peak is None: + return fitted_art, a_peak_idx + post_range + + # interpolate time between peaks + intpol_window = np.ceil([post_idx_previous_peak, a_peak_idx - pre_range]).astype( + int + ) # interpolation window + + if intpol_window[0] < intpol_window[1]: + # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data + + # You have x_fit which is two slices on either side of the interpolation window + # endpoints + # You have y_fit which is the y vals corresponding to x values above + # You have x_interpol which is the time points between the two slices in x_fit + # that you want to interpolate + # You have y_interpol which is values from pchip at the time points specified in + # x_interpol + # points to be interpolated in pt - the gap between the endpoints of the window + x_interpol = np.arange(intpol_window[0], intpol_window[1] + 1, 1) + # Entire range of x values in this step (taking some + # number of samples before and after the window) + x_fit = np.concatenate( + [ + 
np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1), + np.arange(intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1), + ] + ) + y_fit = fitted_art[x_fit] + y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation + + # make fitted artefact in the desired range equal to the completed fit above + fitted_art[post_idx_previous_peak : a_peak_idx - pre_range + 1] = y_interpol + + return fitted_art, a_peak_idx + post_range From 826afeceeb1900301d562324f294b5a366687ad1 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Thu, 19 Dec 2024 19:29:27 +0100 Subject: [PATCH 39/71] test: add pytest importorskip for pandas lib in pca obs tests --- mne/preprocessing/tests/test_pca_obs.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index e2d07a8ce72..8519854736e 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -24,6 +24,8 @@ def short_raw_data(): def test_heart_artifact_removal(short_raw_data: Raw): """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" + pytest.importorskip("pandas") + # fake some random qrs events in the window of the raw data # remove first and last samples and cast to integer for indexing ecg_event_indices = np.linspace(0, short_raw_data.n_times, 20, dtype=int)[1:-1] @@ -87,5 +89,7 @@ def test_pca_obs_bad_input( short_raw_data: Raw, picks: list[str], qrs: np.ndarray, error: str ): """Test if bad input data raises the proper errors in the function sanity checks.""" + pytest.importorskip("pandas") + with pytest.raises(ValueError, match=error): apply_pca_obs(raw=short_raw_data, picks=picks, qrs_indices=qrs) From 16937c8e147360055c77091196462c77d2e4071b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Dec 2024 18:32:34 +0000 Subject: [PATCH 40/71] [pre-commit.ci] auto fixes from pre-commit.com hooks 
for more information, see https://pre-commit.ci --- mne/preprocessing/tests/test_pca_obs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 8519854736e..1fb5cda0cf3 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -90,6 +90,6 @@ def test_pca_obs_bad_input( ): """Test if bad input data raises the proper errors in the function sanity checks.""" pytest.importorskip("pandas") - + with pytest.raises(ValueError, match=error): apply_pca_obs(raw=short_raw_data, picks=picks, qrs_indices=qrs) From 8f4dcdde0a55ec4ad850b211d3f9e34832341f9c Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Thu, 19 Dec 2024 19:34:59 +0100 Subject: [PATCH 41/71] docs: add apply_pca_obs algorithm to doc/api/preprocessing.rst list --- doc/api/preprocessing.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/api/preprocessing.rst b/doc/api/preprocessing.rst index 86ad3aca910..9fe3f995cc4 100644 --- a/doc/api/preprocessing.rst +++ b/doc/api/preprocessing.rst @@ -116,6 +116,7 @@ Projections: read_ica_eeglab read_fine_calibration write_fine_calibration + apply_pca_obs :py:mod:`mne.preprocessing.nirs`: From 4a6f4e84262f1201a735851d8ed9554b676c3f17 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 7 Jan 2025 13:59:03 -0500 Subject: [PATCH 42/71] MAINT: Satisfy CircleCI --- .circleci/config.yml | 7 ++ doc/api/datasets.rst | 1 + .../esg_rm_heart_artefact_pcaobs.py | 106 ++++++++---------- mne/datasets/__init__.pyi | 2 + mne/datasets/utils.py | 30 ++++- 5 files changed, 85 insertions(+), 61 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 644fd8b31b7..0fa373e392e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -218,6 +218,9 @@ jobs: - restore_cache: keys: - data-cache-phantom-kit + - restore_cache: + keys: + - data-cache-ds004388 - run: name: Get data # This limit could be increased, but this is 
helpful for finding slow ones @@ -393,6 +396,10 @@ jobs: key: data-cache-phantom-kit paths: - ~/mne_data/MNE-phantom-KIT-data # (1 G) + - save_cache: + key: data-cache-ds004388 + paths: + - ~/mne_data/ds004388 # (1.8 G) linkcheck: diff --git a/doc/api/datasets.rst b/doc/api/datasets.rst index 2b2c92c8654..87730fbd717 100644 --- a/doc/api/datasets.rst +++ b/doc/api/datasets.rst @@ -18,6 +18,7 @@ Datasets brainstorm.bst_auditory.data_path brainstorm.bst_resting.data_path brainstorm.bst_raw.data_path + default_path eegbci.load_data eegbci.standardize fetch_aparc_sub_parcellation diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index b75552f657a..0014c2b0771 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -33,29 +33,28 @@ # This will download simultaneous EEG and ESG data from a single participant after # median nerve stimulation of the left wrist # Set the target directory to your desired location -import openneuro as on +import openneuro from matplotlib import pyplot as plt import mne -from mne import Epochs, concatenate_raws, events_from_annotations +from mne import Epochs, events_from_annotations from mne.io import read_raw_eeglab from mne.preprocessing import find_ecg_events, fix_stim_artifact # add the path where you want the OpenNeuro data downloaded. 
Files total around 8 GB # target_dir = "/home/steinnhm/personal/mne-data" -target_dir = "/data/pt_02569/test_data" - -file_list = glob.glob(target_dir + "/sub-001/eeg/*median*.set") -if file_list: - print("Data is already downloaded") -else: - on.download( - dataset="ds004388", target_dir=target_dir, include="sub-001/*median*_eeg*" - ) +ds = "ds004388" +target_dir = mne.datasets.default_path() / ds +run_name = "sub-001/eeg/*median_run-03_eeg*.set" +if not glob.glob(str(target_dir / run_name)): + target_dir.mkdir(exist_ok=True) + openneuro.download(dataset=ds, target_dir=target_dir, include=run_name[:-4]) +block_files = glob.glob(str(target_dir / run_name)) +assert len(block_files) == 1 ############################################################################### # Define the esg channels (arranged in two patches over the neck and lower back) -# Also include the ECG channel for artefact correction + esg_chans = [ "S35", "S24", @@ -96,69 +95,54 @@ "S23", ] -# Sampling rate -fs = 1000 - # Interpolation window for ESG data to remove stimulation artefact -tstart_esg = -0.007 -tmax_esg = 0.007 +tstart_esg = -7e-3 +tmax_esg = 7e-3 # Define timing of heartbeat epochs -iv_baseline = [-400 / 1000, -300 / 1000] -iv_epoch = [-400 / 1000, 600 / 1000] +iv_baseline = [-400e-3, -300e-3] +iv_epoch = [-400e-3, 600e-3] ############################################################################### # Read in each of the four blocks and concatenate the raw structures after performing # some minimal preprocessing including removing the stimulation artefact, downsampling # and filtering -block_files = glob.glob(target_dir + "/sub-001/eeg/*median*.set") -block_files = sorted(block_files) - -for count, block_file in enumerate(block_files): - raw = read_raw_eeglab( - block_file, eog=(), preload=True, uint16_codec=None, verbose=None - ) - - # Isolate the ESG channels (including ECG for R-peak detection) - raw.pick(esg_chans + ["ECG"]) - - # Find trigger timings to remove the stimulation 
artefact - events, event_dict = events_from_annotations(raw) - trigger_name = "Median - Stimulation" - - fix_stim_artifact( - raw, - events=events, - event_id=event_dict[trigger_name], - tmin=tstart_esg, - tmax=tmax_esg, - mode="linear", - stim_channel=None, - ) - - # Downsample the data - raw.resample(fs) - - # Append blocks of the same condition - if count == 0: - raw_concat = raw - else: - concatenate_raws([raw_concat, raw]) + +raw = read_raw_eeglab(block_files[0], verbose="error") +raw.set_channel_types(dict(ECG="ecg")) +# Isolate the ESG channels (including ECG for R-peak detection) +raw.pick(esg_chans + ["ECG"]) +# Trim duration and downsample (from 10kHz) to improve example speed +raw.crop(0, 60).load_data().resample(2000) + +# Find trigger timings to remove the stimulation artefact +events, event_dict = events_from_annotations(raw) +trigger_name = "Median - Stimulation" + +fix_stim_artifact( + raw, + events=events, + event_id=event_dict[trigger_name], + tmin=tstart_esg, + tmax=tmax_esg, + mode="linear", + stim_channel=None, +) ############################################################################### # Find ECG events and add to the raw structure as event annotations -ecg_events, ch_ecg, average_pulse = find_ecg_events(raw_concat, ch_name="ECG") +ecg_events, ch_ecg, average_pulse = find_ecg_events(raw, ch_name="ECG") ecg_event_samples = np.asarray( [[ecg_event[0] for ecg_event in ecg_events]] ) # Samples only qrs_event_time = [ - x / fs for x in ecg_event_samples.reshape(-1) + x / raw.info["sfreq"] for x in ecg_event_samples.reshape(-1) ] # Divide by sampling rate to make times duration = np.repeat(0.0, len(ecg_event_samples)) description = ["qrs"] * len(ecg_event_samples) -raw_concat.annotations.append( +raw.annotations.append( qrs_event_time, duration, description, ch_names=[esg_chans] * len(qrs_event_time) ) @@ -166,10 +150,11 @@ # Create evoked response about the detected R-peaks before cardiac artefact correction # Apply PCA-OBS to remove the 
cardiac artefact # Create evoked response about the detected R-peaks after cardiac artefact correction -events, event_ids = events_from_annotations(raw_concat) + +events, event_ids = events_from_annotations(raw) event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} epochs = Epochs( - raw_concat, + raw, events, event_id=event_id_dict, tmin=iv_epoch[0], @@ -181,11 +166,11 @@ # Apply function - modifies the data in place. Optionally high-pass filter # the data before applying PCA-OBS to remove low frequency drifts mne.preprocessing.apply_pca_obs( - raw_concat, picks=esg_chans, n_jobs=5, qrs_indices=ecg_event_samples.reshape(-1) + raw, picks=esg_chans, n_jobs=5, qrs_indices=ecg_event_samples.reshape(-1) ) epochs = Epochs( - raw_concat, + raw, events, event_id=event_id_dict, tmin=iv_epoch[0], @@ -196,7 +181,8 @@ ############################################################################### # Comparison image -fig, axes = plt.subplots(1, 1) + +fig, axes = plt.subplots(1, 1, layout="constrained") axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") axes.set_ylim([-0.0005, 0.001]) diff --git a/mne/datasets/__init__.pyi b/mne/datasets/__init__.pyi index 44cee84fe7f..2f69a1027e5 100644 --- a/mne/datasets/__init__.pyi +++ b/mne/datasets/__init__.pyi @@ -6,6 +6,7 @@ __all__ = [ "epilepsy_ecog", "erp_core", "eyelink", + "default_path", "fetch_aparc_sub_parcellation", "fetch_dataset", "fetch_fsaverage", @@ -70,6 +71,7 @@ from ._infant import fetch_infant_template from ._phantom.base import fetch_phantom from .utils import ( _download_all_example_data, + default_path, fetch_aparc_sub_parcellation, fetch_hcp_mmp_parcellation, has_dataset, diff --git a/mne/datasets/utils.py b/mne/datasets/utils.py index 452e42cffc7..93aabc0841a 100644 --- a/mne/datasets/utils.py +++ b/mne/datasets/utils.py @@ -2,6 +2,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python 
contributors. +import glob import importlib import inspect import logging @@ -92,6 +93,22 @@ def _dataset_version(path, name): return version +@verbose +def default_path(*, verbose=None): + """Get the default MNE_DATA path. + + Parameters + ---------- + %(verbose)s + + Returns + ------- + data_path : instance of Path + Path to the default MNE_DATA directory. + """ + return _get_path(None, None, None) + + def _get_path(path, key, name): """Get a dataset path.""" # 1. Input @@ -113,7 +130,8 @@ def _get_path(path, key, name): return path # 4. ~/mne_data (but use a fake home during testing so we don't # unnecessarily create ~/mne_data) - logger.info(f"Using default location ~/mne_data for {name}...") + extra = f" for {name}" if name else "" + logger.info(f"Using default location ~/mne_data{extra}...") path = Path(os.getenv("_MNE_FAKE_HOME_DIR", "~")).expanduser() / "mne_data" if not path.is_dir(): logger.info(f"Creating {path}") @@ -319,6 +337,8 @@ def _download_all_example_data(verbose=True): # # verbose=True by default so we get nice status messages. 
# Consider adding datasets from here to CircleCI for PR-auto-build + import openneuro + paths = dict() for kind in ( "sample testing misc spm_face somato hf_sef multimodal " @@ -375,6 +395,14 @@ def _download_all_example_data(verbose=True): limo.load_data(subject=1, update_path=True) logger.info("[done limo]") + # for ESG + ds = "ds004388" + target_dir = default_path() / ds + run_name = "sub-001/eeg/*median_run-03_eeg*.set" + if not glob.glob(str(target_dir / run_name)): + target_dir.mkdir(exist_ok=True) + openneuro.download(dataset=ds, target_dir=target_dir, include=run_name[:-4]) + @verbose def fetch_aparc_sub_parcellation(subjects_dir=None, verbose=None): From 28e45ca83efffc33bda1f707d3b51e795bdce9a0 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 7 Jan 2025 14:05:37 -0500 Subject: [PATCH 43/71] FIX: Better --- .../esg_rm_heart_artefact_pcaobs.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index 0014c2b0771..afb19f3656a 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -183,14 +183,14 @@ # Comparison image fig, axes = plt.subplots(1, 1, layout="constrained") -axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") -axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") -axes.set_ylim([-0.0005, 0.001]) -axes.set_ylabel("Amplitude (V)") -axes.set_xlabel("Time (s)") -axes.set_title("Before (black) vs. 
After (green)") -plt.tight_layout() -plt.show() +data_before = evoked_before.get_data(units=dict(eeg="uV")).T +data_after = evoked_after.get_data(units=dict(eeg="uV")).T +hs = list() +hs.append(axes.plot(epochs.times, data_before, color="k")[0]) +hs.append(axes.plot(epochs.times, data_after, color="green", label="after")[0]) +axes.set(ylim=[-500, 1000], ylabel="Amplitude (µV)", xlabel="Time (s)") +axes.set(title="ECG artefact removal using PCA-OBS") +axes.legend(hs, ["before", "after"]) # %% # References From 6415d9732b01cb17c246f56dccbe7f94d972332e Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 7 Jan 2025 14:06:27 -0500 Subject: [PATCH 44/71] FIX: Wording --- examples/preprocessing/esg_rm_heart_artefact_pcaobs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index afb19f3656a..9b7ce9b8ef4 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -131,6 +131,7 @@ ############################################################################### # Find ECG events and add to the raw structure as event annotations + ecg_events, ch_ecg, average_pulse = find_ecg_events(raw, ch_name="ECG") ecg_event_samples = np.asarray( [[ecg_event[0] for ecg_event in ecg_events]] @@ -147,9 +148,8 @@ ) ############################################################################### -# Create evoked response about the detected R-peaks before cardiac artefact correction -# Apply PCA-OBS to remove the cardiac artefact -# Create evoked response about the detected R-peaks after cardiac artefact correction +# Create evoked response about the detected R-peaks before and after cardiac artefact +# correction events, event_ids = events_from_annotations(raw) event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} From 7508056b6ba8176bdbf92d416d0a5973f1a1ec6d Mon 
Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 7 Jan 2025 14:31:44 -0500 Subject: [PATCH 45/71] FIX: Docstring --- mne/preprocessing/pca_obs.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 99e643cd708..3f7cd93024b 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -30,17 +30,16 @@ def apply_pca_obs( Parameters ---------- - raw: Raw - The raw data to process - picks: list[str] - Channels in the Raw object to remove the heart artefact from - qrs_indices: ndarray, shape (n_peaks, 1) + raw : instance of Raw + The raw data to process. + picks : list of str + Channels in the Raw object to remove the heart artefact from. + qrs_indices : ndarray, shape (n_peaks, 1) Array of indices in the Raw data of detected R-peaks in ECG channel. - n_components: int, default 4 - Number of PCA components to use to form the OBS - n_jobs: int, default None + n_components : int + Number of PCA components to use to form the OBS (default 4). + n_jobs : int | None Number of jobs to perform the PCA-OBS processing in parallel. 
- Passed on to Raw.apply_function """ # sanity checks if not isinstance(qrs_indices, np.ndarray): From 6556418b8d305bc732ea59eddff11a7078ac9b72 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 7 Jan 2025 14:44:59 -0500 Subject: [PATCH 46/71] FIX: Name --- doc/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/conf.py b/doc/conf.py index 7dd6ec90d4f..c51de7bd077 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -355,6 +355,7 @@ "n_frequencies", "n_tests", "n_samples", + "n_peaks", "n_permutations", "nchan", "n_points", From c8eb284b72e61451b022879fa1d9d6dfdc36ac8c Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Wed, 8 Jan 2025 09:53:54 +0100 Subject: [PATCH 47/71] example: update comments --- .../esg_rm_heart_artefact_pcaobs.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index 9b7ce9b8ef4..b51051d028f 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -30,9 +30,8 @@ ############################################################################### # Download sample subject data from OpenNeuro if you haven't already -# This will download simultaneous EEG and ESG data from a single participant after +# This will download simultaneous EEG and ESG data from a single run of a single participant after # median nerve stimulation of the left wrist -# Set the target directory to your desired location import openneuro from matplotlib import pyplot as plt @@ -41,8 +40,7 @@ from mne.io import read_raw_eeglab from mne.preprocessing import find_ecg_events, fix_stim_artifact -# add the path where you want the OpenNeuro data downloaded. Files total around 8 GB -# target_dir = "/home/steinnhm/personal/mne-data" +# add the path where you want the OpenNeuro data downloaded. Each run is ~2GB of data to download. 
ds = "ds004388" target_dir = mne.datasets.default_path() / ds run_name = "sub-001/eeg/*median_run-03_eeg*.set" @@ -95,22 +93,21 @@ "S23", ] -# Interpolation window for ESG data to remove stimulation artefact +# Interpolation window in seconds for ESG data to remove stimulation artefact tstart_esg = -7e-3 tmax_esg = 7e-3 -# Define timing of heartbeat epochs +# Define timing of heartbeat epochs in seconds relative to R-peaks iv_baseline = [-400e-3, -300e-3] iv_epoch = [-400e-3, 600e-3] ############################################################################### -# Read in each of the four blocks and concatenate the raw structures after performing -# some minimal preprocessing including removing the stimulation artefact, downsampling +# Perform minimal preprocessing including removing the stimulation artefact, downsampling # and filtering raw = read_raw_eeglab(block_files[0], verbose="error") raw.set_channel_types(dict(ECG="ecg")) -# Isolate the ESG channels (including ECG for R-peak detection) +# Isolate the ESG channels (include the ECG channel for R-peak detection) raw.pick(esg_chans + ["ECG"]) # Trim duration and downsample (from 10kHz) to improve example speed raw.crop(0, 60).load_data().resample(2000) @@ -191,6 +188,7 @@ axes.set(ylim=[-500, 1000], ylabel="Amplitude (µV)", xlabel="Time (s)") axes.set(title="ECG artefact removal using PCA-OBS") axes.legend(hs, ["before", "after"]) +plt.show() # %% # References From 89278a68e531dacd6ca78162ea136b8f1f116deb Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Wed, 8 Jan 2025 10:22:53 +0100 Subject: [PATCH 48/71] tutorial: add ref to pca_obs in SSP tutorial section on ECG artefacts --- tutorials/preprocessing/50_artifact_correction_ssp.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tutorials/preprocessing/50_artifact_correction_ssp.py b/tutorials/preprocessing/50_artifact_correction_ssp.py index 57be25803d5..74e4cae52d0 100644 --- a/tutorials/preprocessing/50_artifact_correction_ssp.py +++ 
b/tutorials/preprocessing/50_artifact_correction_ssp.py @@ -390,6 +390,14 @@ # # See the documentation of each function for further details. # +# %% +# .. note:: +# In situations only limited electrodes are available for analysis, removing the cardiac +# artefact using techniques which rely on the availability of spatial information +# (such as SSP) may not be possible. In these instances, it may be of use to consider +# algorithms which require information only regarding heartbeat instances in the time domain, +# such as `mne.preprocessing.pca_obs`. +# # # Repairing EOG artifacts with SSP # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -535,6 +543,7 @@ # reduced the amplitude of our signals in sensor space, but that it should not # bias the amplitudes in source space. # +# # References # ^^^^^^^^^^ # From 36e0ca0e7e60fe31a7406d65c310bda60d096bda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Steinn=20Hauser=20Magn=C3=BAsson?= <42544980+steinnhauser@users.noreply.github.com> Date: Fri, 10 Jan 2025 17:48:47 +0100 Subject: [PATCH 49/71] Update mne/preprocessing/pca_obs.py make fit_ecg_template a private function Co-authored-by: Eric Larson --- mne/preprocessing/pca_obs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 3f7cd93024b..5627f80b5a1 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -219,7 +219,7 @@ def _pca_obs( return data -def fit_ecg_template( +def _fit_ecg_template( data: np.ndarray, pca_template: np.ndarray, a_peak_idx: int, From fc5bdc837801d0ac8c293262c79796e32b6c2904 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Steinn=20Hauser=20Magn=C3=BAsson?= <42544980+steinnhauser@users.noreply.github.com> Date: Fri, 10 Jan 2025 17:50:12 +0100 Subject: [PATCH 50/71] Update tutorials/preprocessing/50_artifact_correction_ssp.py fix CircleCI error Co-authored-by: Eric Larson --- tutorials/preprocessing/50_artifact_correction_ssp.py | 1 - 1 file changed, 1 deletion(-) diff 
--git a/tutorials/preprocessing/50_artifact_correction_ssp.py b/tutorials/preprocessing/50_artifact_correction_ssp.py index 74e4cae52d0..45e52ede0df 100644 --- a/tutorials/preprocessing/50_artifact_correction_ssp.py +++ b/tutorials/preprocessing/50_artifact_correction_ssp.py @@ -390,7 +390,6 @@ # # See the documentation of each function for further details. # -# %% # .. note:: # In situations only limited electrodes are available for analysis, removing the cardiac # artefact using techniques which rely on the availability of spatial information From d586d620765bc7b20db59c6309db69878766d9e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Steinn=20Hauser=20Magn=C3=BAsson?= <42544980+steinnhauser@users.noreply.github.com> Date: Fri, 10 Jan 2025 17:51:14 +0100 Subject: [PATCH 51/71] Update mne/preprocessing/pca_obs.py improve docstring for PCA-OBS function, adds citation Co-authored-by: Eric Larson --- mne/preprocessing/pca_obs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 5627f80b5a1..46bb56b4a4b 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -26,7 +26,7 @@ def apply_pca_obs( """ Apply the PCA-OBS algorithm to picks of a Raw object. - Update the Raw object in-place. Make sanity checks for all inputs. + Uses the optimal basis set (OBS) algorithm from :footcite:`NiazyEtAl2005`. 
Parameters ---------- From d6ae456b77fb898192a0f5170508dcbe396436a9 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Fri, 10 Jan 2025 19:35:59 +0100 Subject: [PATCH 52/71] refactor: rework qrs_indices to qrs_times, adjust sanity checks, add _validate_type for ndarray check --- mne/preprocessing/pca_obs.py | 48 ++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 46bb56b4a4b..024a0242a82 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -14,12 +14,13 @@ from mne.io.fiff.raw import Raw from mne.utils import logger +from mne.utils.check import _validate_type def apply_pca_obs( raw: Raw, picks: list[str], - qrs_indices: np.ndarray, + qrs_times: np.ndarray, n_components: int = 4, n_jobs: int | None = None, ) -> None: @@ -34,24 +35,23 @@ def apply_pca_obs( The raw data to process. picks : list of str Channels in the Raw object to remove the heart artefact from. - qrs_indices : ndarray, shape (n_peaks, 1) - Array of indices in the Raw data of detected R-peaks in ECG channel. + qrs_times : ndarray, shape (n_peaks,) + Array of times in the Raw data of detected R-peaks in ECG channel. n_components : int Number of PCA components to use to form the OBS (default 4). n_jobs : int | None Number of jobs to perform the PCA-OBS processing in parallel. 
""" # sanity checks - if not isinstance(qrs_indices, np.ndarray): - raise ValueError("qrs_indices must be an array") - if len(qrs_indices.shape) > 1: - raise ValueError("qrs_indices must be a 1d array") - if qrs_indices.dtype != int: - raise ValueError("qrs_indices must be an array of integers") - if np.any(qrs_indices < 0): - raise ValueError("qrs_indices must be strictly positive integers") - if np.any(qrs_indices >= raw.n_times): - logger.warning("out of bound qrs_indices will be ignored..") + _validate_type(qrs_times, np.ndarray, "qrs_times") + if len(qrs_times.shape) > 1: + raise ValueError("qrs_times must be a 1d array") + if qrs_times.dtype not in [int, float]: + raise ValueError("qrs_times must be an array of either integers or floats") + if np.any(qrs_times < 0): + raise ValueError("qrs_times must be strictly positive") + if np.any(qrs_times >= raw.times[-1]): + logger.warning("some out of bound qrs_times will be ignored..") if not picks: raise ValueError("picks must be a list of channel names") @@ -59,8 +59,8 @@ def apply_pca_obs( _pca_obs, picks=picks, n_jobs=n_jobs, - # args sent to PCA_OBS - qrs=qrs_indices, + # args sent to PCA_OBS, convert times to indices + qrs=raw.time_as_index(qrs_times), n_components=n_components, ) @@ -133,7 +133,7 @@ def _pca_obs( ################ window_start_idx = [] window_end_idx = [] - post_idx_nextPeak = None + post_idx_next_peak = None for p in range(peak_count): # if the current peak doesn't have enough data in the @@ -148,7 +148,7 @@ def _pca_obs( if post_range > peak_range: post_range = peak_range - fitted_art, post_idx_nextPeak = fit_ecg_template( + fitted_art, post_idx_next_peak = _fit_ecg_template( data=data, pca_template=pca_template, a_peak_idx=peak_idx[p], @@ -157,7 +157,7 @@ def _pca_obs( post_range=post_range, mid_p=mid_p, fitted_art=fitted_art, - post_idx_previous_peak=post_idx_nextPeak, + post_idx_previous_peak=post_idx_next_peak, n_samples_fit=n_samples_fit, ) # Appending to list instead of using counter @@ 
-170,7 +170,7 @@ def _pca_obs( post_range = peak_range if pre_range > peak_range: pre_range = peak_range - fitted_art, _ = fit_ecg_template( + fitted_art, _ = _fit_ecg_template( data=data, pca_template=pca_template, a_peak_idx=peak_idx[p], @@ -179,7 +179,7 @@ def _pca_obs( post_range=post_range, mid_p=mid_p, fitted_art=fitted_art, - post_idx_previous_peak=post_idx_nextPeak, + post_idx_previous_peak=post_idx_next_peak, n_samples_fit=n_samples_fit, ) window_start_idx.append(peak_idx[p] - peak_range) @@ -199,7 +199,7 @@ def _pca_obs( a_template = pca_template[ mid_p - peak_range - 1 : mid_p + peak_range + 1, : ] - fitted_art, post_idx_nextPeak = fit_ecg_template( + fitted_art, post_idx_next_peak = _fit_ecg_template( data=data, pca_template=a_template, a_peak_idx=peak_idx[p], @@ -208,7 +208,7 @@ def _pca_obs( post_range=post_range, mid_p=mid_p, fitted_art=fitted_art, - post_idx_previous_peak=post_idx_nextPeak, + post_idx_previous_peak=post_idx_next_peak, n_samples_fit=n_samples_fit, ) window_start_idx.append(peak_idx[p] - peak_range) @@ -257,8 +257,8 @@ def _fit_ecg_template( ------- tuple[np.ndarray, int]: the fitted artifact and the next peak index """ - # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previous_peak - # Then nextpeak is returned at the end and the process repeats + # post_idx_next_peak is passed in in PCA_OBS, used here as post_idx_previous_peak + # Then next_peak is returned at the end and the process repeats # select window of template template = pca_template[mid_p - peak_range - 1 : mid_p + peak_range + 1, :] From 420f4fcebf864043088233e79f301a99543c88be Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Fri, 10 Jan 2025 19:37:19 +0100 Subject: [PATCH 53/71] test,example: adjust example and tests to new qrs_times changes, add exception type to test parameterization --- .../esg_rm_heart_artefact_pcaobs.py | 2 +- mne/preprocessing/tests/test_pca_obs.py | 30 ++++++++++--------- 2 files changed, 17 insertions(+), 15 deletions(-) 
diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index b51051d028f..d0490926eb9 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -163,7 +163,7 @@ # Apply function - modifies the data in place. Optionally high-pass filter # the data before applying PCA-OBS to remove low frequency drifts mne.preprocessing.apply_pca_obs( - raw, picks=esg_chans, n_jobs=5, qrs_indices=ecg_event_samples.reshape(-1) + raw, picks=esg_chans, n_jobs=5, qrs_times=raw.times[ecg_event_samples.reshape(-1)] ) epochs = Epochs( diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 1fb5cda0cf3..81a3dc75a7c 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -26,16 +26,16 @@ def test_heart_artifact_removal(short_raw_data: Raw): """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" pytest.importorskip("pandas") - # fake some random qrs events in the window of the raw data - # remove first and last samples and cast to integer for indexing - ecg_event_indices = np.linspace(0, short_raw_data.n_times, 20, dtype=int)[1:-1] - # copy the original raw. 
heart artifact is removed in-place orig_df: pd.DataFrame = short_raw_data.to_data_frame().copy(deep=True) + # fake some random qrs events in the window of the raw data + # remove first and last samples and cast to integer for indexing + ecg_event_times = np.linspace(0, orig_df["time"].iloc[-1], 20)[1:-1] + # perform heart artifact removal apply_pca_obs( - raw=short_raw_data, picks=["eeg"], qrs_indices=ecg_event_indices, n_jobs=1 + raw=short_raw_data, picks=["eeg"], qrs_times=ecg_event_times, n_jobs=1 ) # compare processed df to original df @@ -68,28 +68,30 @@ def test_heart_artifact_removal(short_raw_data: Raw): # test that various nonsensical inputs raise the proper errors @pytest.mark.parametrize( - ("picks", "qrs", "error"), + ("picks", "qrs_times", "error", "exception"), [ - (["eeg"], np.array([[0, 1], [2, 3]]), "qrs_indices must be a 1d array"), - (["eeg"], [2, 3, 4], "qrs_indices must be an array"), + (["eeg"], np.array([[0, 1], [2, 3]]), "qrs_times must be a 1d array", ValueError), + (["eeg"], [2, 3, 4], "qrs_times must be an instance of ndarray, got instead.", TypeError), ( ["eeg"], np.array([None, "foo", 2]), - "qrs_indices must be an array of integers", + "qrs_times must be an array of either integers or floats", + ValueError, ), ( ["eeg"], np.array([-1, 0, 3]), - "qrs_indices must be strictly positive integers", + "qrs_times must be strictly positive", + ValueError, ), - ([], np.array([1, 2, 3]), "picks must be a list of channel names"), + ([], np.array([1, 2, 3]), "picks must be a list of channel names", ValueError), ], ) def test_pca_obs_bad_input( - short_raw_data: Raw, picks: list[str], qrs: np.ndarray, error: str + short_raw_data: Raw, picks: list[str], qrs_times: np.ndarray, error: str, exception: type[Exception] ): """Test if bad input data raises the proper errors in the function sanity checks.""" pytest.importorskip("pandas") - with pytest.raises(ValueError, match=error): - apply_pca_obs(raw=short_raw_data, picks=picks, qrs_indices=qrs) + with 
pytest.raises(exception, match=error): + apply_pca_obs(raw=short_raw_data, picks=picks, qrs_times=qrs_times) From 54d67aad30d9a67aa83d80a627ee1e1cfb0457bb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 18:39:43 +0000 Subject: [PATCH 54/71] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/preprocessing/tests/test_pca_obs.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 81a3dc75a7c..257de07fec4 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -70,8 +70,18 @@ def test_heart_artifact_removal(short_raw_data: Raw): @pytest.mark.parametrize( ("picks", "qrs_times", "error", "exception"), [ - (["eeg"], np.array([[0, 1], [2, 3]]), "qrs_times must be a 1d array", ValueError), - (["eeg"], [2, 3, 4], "qrs_times must be an instance of ndarray, got instead.", TypeError), + ( + ["eeg"], + np.array([[0, 1], [2, 3]]), + "qrs_times must be a 1d array", + ValueError, + ), + ( + ["eeg"], + [2, 3, 4], + "qrs_times must be an instance of ndarray, got instead.", + TypeError, + ), ( ["eeg"], np.array([None, "foo", 2]), @@ -88,7 +98,11 @@ def test_heart_artifact_removal(short_raw_data: Raw): ], ) def test_pca_obs_bad_input( - short_raw_data: Raw, picks: list[str], qrs_times: np.ndarray, error: str, exception: type[Exception] + short_raw_data: Raw, + picks: list[str], + qrs_times: np.ndarray, + error: str, + exception: type[Exception], ): """Test if bad input data raises the proper errors in the function sanity checks.""" pytest.importorskip("pandas") From f4082b4c45f349f116cefd7b6f443ade2decaac1 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Fri, 10 Jan 2025 20:15:47 +0100 Subject: [PATCH 55/71] feat: add copy kwarg, optionally return new instance, add verbose 
decorator, add verbose argument and docstring mention, convert docstring of n_jobs --- mne/preprocessing/pca_obs.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 024a0242a82..4f9a6c163ff 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -14,16 +14,21 @@ from mne.io.fiff.raw import Raw from mne.utils import logger +from mne.utils._logging import verbose from mne.utils.check import _validate_type +@verbose def apply_pca_obs( raw: Raw, picks: list[str], + *, qrs_times: np.ndarray, n_components: int = 4, n_jobs: int | None = None, -) -> None: + copy: bool = True, + verbose : bool | str | int | None = None +) -> Raw | None: """ Apply the PCA-OBS algorithm to picks of a Raw object. @@ -39,8 +44,15 @@ def apply_pca_obs( Array of times in the Raw data of detected R-peaks in ECG channel. n_components : int Number of PCA components to use to form the OBS (default 4). - n_jobs : int | None - Number of jobs to perform the PCA-OBS processing in parallel. + copy : bool + If False, modify the Raw instance in-place. + If True, return a copied, modified Raw instance. Defaults to True. + %(n_jobs)s + %(verbose)s + + References + ---------- + .. 
footbibliography:: """ # sanity checks _validate_type(qrs_times, np.ndarray, "qrs_times") @@ -52,8 +64,9 @@ def apply_pca_obs( raise ValueError("qrs_times must be strictly positive") if np.any(qrs_times >= raw.times[-1]): logger.warning("some out of bound qrs_times will be ignored..") - if not picks: - raise ValueError("picks must be a list of channel names") + + if copy: + raw = raw.copy() raw.apply_function( _pca_obs, @@ -64,6 +77,8 @@ def apply_pca_obs( n_components=n_components, ) + if copy: + return raw def _pca_obs( data: np.ndarray, From fd9c41b63576333981e952e2135fa9025bfbb9b1 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Fri, 10 Jan 2025 20:18:40 +0100 Subject: [PATCH 56/71] test,example: update tests and example with new copy default behavior, remove redundant test --- examples/preprocessing/esg_rm_heart_artefact_pcaobs.py | 2 +- mne/preprocessing/tests/test_pca_obs.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index d0490926eb9..a338b8d6b6b 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -162,7 +162,7 @@ # Apply function - modifies the data in place. 
Optionally high-pass filter # the data before applying PCA-OBS to remove low frequency drifts -mne.preprocessing.apply_pca_obs( +raw = mne.preprocessing.apply_pca_obs( raw, picks=esg_chans, n_jobs=5, qrs_times=raw.times[ecg_event_samples.reshape(-1)] ) diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 257de07fec4..5b6707224aa 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -34,7 +34,7 @@ def test_heart_artifact_removal(short_raw_data: Raw): ecg_event_times = np.linspace(0, orig_df["time"].iloc[-1], 20)[1:-1] # perform heart artifact removal - apply_pca_obs( + short_raw_data = apply_pca_obs( raw=short_raw_data, picks=["eeg"], qrs_times=ecg_event_times, n_jobs=1 ) @@ -94,7 +94,6 @@ def test_heart_artifact_removal(short_raw_data: Raw): "qrs_times must be strictly positive", ValueError, ), - ([], np.array([1, 2, 3]), "picks must be a list of channel names", ValueError), ], ) def test_pca_obs_bad_input( From a578a7d10d8171b73e3ebce437c2980a89e11c31 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 19:21:30 +0000 Subject: [PATCH 57/71] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/preprocessing/pca_obs.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 4f9a6c163ff..b3cfb6ec373 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -27,7 +27,7 @@ def apply_pca_obs( n_components: int = 4, n_jobs: int | None = None, copy: bool = True, - verbose : bool | str | int | None = None + verbose: bool | str | int | None = None, ) -> Raw | None: """ Apply the PCA-OBS algorithm to picks of a Raw object. @@ -45,7 +45,7 @@ def apply_pca_obs( n_components : int Number of PCA components to use to form the OBS (default 4). 
copy : bool - If False, modify the Raw instance in-place. + If False, modify the Raw instance in-place. If True, return a copied, modified Raw instance. Defaults to True. %(n_jobs)s %(verbose)s @@ -80,6 +80,7 @@ def apply_pca_obs( if copy: return raw + def _pca_obs( data: np.ndarray, qrs: np.ndarray, From 41559ffc566843cf3b9466479c29d1495a41b57b Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Fri, 10 Jan 2025 20:26:03 +0100 Subject: [PATCH 58/71] lint: resolve pre-commit errors, format --- .../esg_rm_heart_artefact_pcaobs.py | 22 +++++++++---------- mne/preprocessing/pca_obs.py | 5 +++-- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index a338b8d6b6b..56fd6b0051a 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -30,8 +30,8 @@ ############################################################################### # Download sample subject data from OpenNeuro if you haven't already -# This will download simultaneous EEG and ESG data from a single run of a single participant after -# median nerve stimulation of the left wrist +# This will download simultaneous EEG and ESG data from a single run of a +# single participant after median nerve stimulation of the left wrist import openneuro from matplotlib import pyplot as plt @@ -40,7 +40,7 @@ from mne.io import read_raw_eeglab from mne.preprocessing import find_ecg_events, fix_stim_artifact -# add the path where you want the OpenNeuro data downloaded. Each run is ~2GB of data to download. +# add the path where you want the OpenNeuro data downloaded. 
Each run is ~2GB of data ds = "ds004388" target_dir = mne.datasets.default_path() / ds run_name = "sub-001/eeg/*median_run-03_eeg*.set" @@ -101,9 +101,9 @@ iv_baseline = [-400e-3, -300e-3] iv_epoch = [-400e-3, 600e-3] -############################################################################### -# Perform minimal preprocessing including removing the stimulation artefact, downsampling -# and filtering +###################################################### +# Perform minimal preprocessing including removing the +# stimulation artefact, downsampling and filtering raw = read_raw_eeglab(block_files[0], verbose="error") raw.set_channel_types(dict(ECG="ecg")) @@ -126,7 +126,7 @@ stim_channel=None, ) -############################################################################### +################################################################### # Find ECG events and add to the raw structure as event annotations ecg_events, ch_ecg, average_pulse = find_ecg_events(raw, ch_name="ECG") @@ -144,9 +144,9 @@ qrs_event_time, duration, description, ch_names=[esg_chans] * len(qrs_event_time) ) -############################################################################### -# Create evoked response about the detected R-peaks before and after cardiac artefact -# correction +################################################### +# Create evoked response about the detected R-peaks +# before and after cardiac artefact correction events, event_ids = events_from_annotations(raw) event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} @@ -176,7 +176,7 @@ ) evoked_after = epochs.average() -############################################################################### +################## # Comparison image fig, axes = plt.subplots(1, 1, layout="constrained") diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 4f9a6c163ff..b3cfb6ec373 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -27,7 +27,7 @@ def 
apply_pca_obs( n_components: int = 4, n_jobs: int | None = None, copy: bool = True, - verbose : bool | str | int | None = None + verbose: bool | str | int | None = None, ) -> Raw | None: """ Apply the PCA-OBS algorithm to picks of a Raw object. @@ -45,7 +45,7 @@ def apply_pca_obs( n_components : int Number of PCA components to use to form the OBS (default 4). copy : bool - If False, modify the Raw instance in-place. + If False, modify the Raw instance in-place. If True, return a copied, modified Raw instance. Defaults to True. %(n_jobs)s %(verbose)s @@ -80,6 +80,7 @@ def apply_pca_obs( if copy: return raw + def _pca_obs( data: np.ndarray, qrs: np.ndarray, From d170c39e762640ee62dabfa240c6a9c43355c90b Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 13 Jan 2025 10:04:29 +0100 Subject: [PATCH 59/71] fix: remove conditional raw return, always return raw object regardless of copy kwarg to allow method chaining --- mne/preprocessing/pca_obs.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index b3cfb6ec373..753f84d53e9 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -28,7 +28,7 @@ def apply_pca_obs( n_jobs: int | None = None, copy: bool = True, verbose: bool | str | int | None = None, -) -> Raw | None: +) -> Raw: """ Apply the PCA-OBS algorithm to picks of a Raw object. 
@@ -77,8 +77,7 @@ def apply_pca_obs( n_components=n_components, ) - if copy: - return raw + return raw def _pca_obs( From b70545bd9cc6a431fd72ab03e74176c9dcc482af Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 13 Jan 2025 10:08:59 +0100 Subject: [PATCH 60/71] lint: resolve linter errors --- tutorials/preprocessing/50_artifact_correction_ssp.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tutorials/preprocessing/50_artifact_correction_ssp.py b/tutorials/preprocessing/50_artifact_correction_ssp.py index 45e52ede0df..2ef07168f20 100644 --- a/tutorials/preprocessing/50_artifact_correction_ssp.py +++ b/tutorials/preprocessing/50_artifact_correction_ssp.py @@ -391,11 +391,11 @@ # See the documentation of each function for further details. # # .. note:: -# In situations only limited electrodes are available for analysis, removing the cardiac -# artefact using techniques which rely on the availability of spatial information -# (such as SSP) may not be possible. In these instances, it may be of use to consider -# algorithms which require information only regarding heartbeat instances in the time domain, -# such as `mne.preprocessing.pca_obs`. +# In situations only limited electrodes are available for analysis, removing the +# cardiac artefact using techniques which rely on the availability of spatial +# information (such as SSP) may not be possible. In these instances, it may be of +# use to consider algorithms which require information only regarding heartbeat +# instances in the time domain, such as `mne.preprocessing.pca_obs`. 
# # # Repairing EOG artifacts with SSP From 39c322fda080d4e62f97b762ace04b4e9c1bfd71 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 14 Jan 2025 12:19:35 -0500 Subject: [PATCH 61/71] FIX: Doc --- .circleci/config.yml | 2 +- mne/preprocessing/__init__.pyi | 2 +- mne/preprocessing/{pca_obs.py => _pca_obs.py} | 16 ++++++++++++---- .../preprocessing/50_artifact_correction_ssp.py | 2 +- 4 files changed, 15 insertions(+), 7 deletions(-) rename mne/preprocessing/{pca_obs.py => _pca_obs.py} (98%) diff --git a/.circleci/config.yml b/.circleci/config.yml index 337e01394fa..26b9f600e3c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -255,7 +255,7 @@ jobs: name: Check sphinx log for warnings (which are treated as errors) when: always command: | - ! grep "^.* (WARNING|ERROR): .*$" sphinx_log.txt + ! grep "^.*\(WARNING\|ERROR\): " sphinx_log.txt - run: name: Show profiling output when: always diff --git a/mne/preprocessing/__init__.pyi b/mne/preprocessing/__init__.pyi index d0a6a1dd742..c54685dba34 100644 --- a/mne/preprocessing/__init__.pyi +++ b/mne/preprocessing/__init__.pyi @@ -57,6 +57,7 @@ from ._fine_cal import ( write_fine_calibration, ) from ._lof import find_bad_channels_lof +from ._pca_obs import apply_pca_obs from ._peak_finder import peak_finder from ._regress import EOGRegression, read_eog_regression, regress_artifact from .artifact_detection import ( @@ -86,7 +87,6 @@ from .maxwell import ( maxwell_filter_prepare_emptyroom, ) from .otp import oversampled_temporal_projection -from .pca_obs import apply_pca_obs from .realign import realign_raw from .ssp import compute_proj_ecg, compute_proj_eog from .stim import fix_stim_artifact diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/_pca_obs.py similarity index 98% rename from mne/preprocessing/pca_obs.py rename to mne/preprocessing/_pca_obs.py index 753f84d53e9..830673a189c 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/_pca_obs.py @@ -38,18 +38,26 @@ def 
apply_pca_obs( ---------- raw : instance of Raw The raw data to process. - picks : list of str - Channels in the Raw object to remove the heart artefact from. + %(picks_all_data_noref)s qrs_times : ndarray, shape (n_peaks,) Array of times in the Raw data of detected R-peaks in ECG channel. n_components : int Number of PCA components to use to form the OBS (default 4). + %(n_jobs)s copy : bool If False, modify the Raw instance in-place. - If True, return a copied, modified Raw instance. Defaults to True. - %(n_jobs)s + If True (default), copy the raw instance before processing. %(verbose)s + Returns + ------- + raw : instance of Raw + The modified raw instance. + + Notes + ----- + .. versionadded:: 1.10 + References ---------- .. footbibliography:: diff --git a/tutorials/preprocessing/50_artifact_correction_ssp.py b/tutorials/preprocessing/50_artifact_correction_ssp.py index 2ef07168f20..0721e54f4ba 100644 --- a/tutorials/preprocessing/50_artifact_correction_ssp.py +++ b/tutorials/preprocessing/50_artifact_correction_ssp.py @@ -395,7 +395,7 @@ # cardiac artefact using techniques which rely on the availability of spatial # information (such as SSP) may not be possible. In these instances, it may be of # use to consider algorithms which require information only regarding heartbeat -# instances in the time domain, such as `mne.preprocessing.pca_obs`. +# instances in the time domain, such as :func:`mne.preprocessing.apply_pca_obs`. 
# # # Repairing EOG artifacts with SSP From f62f98b053925776c76cfeceefc98a789ddf9382 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Tue, 14 Jan 2025 17:27:07 +0000 Subject: [PATCH 62/71] [autofix.ci] apply automated fixes --- mne/preprocessing/_pca_obs.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mne/preprocessing/_pca_obs.py b/mne/preprocessing/_pca_obs.py index 830673a189c..4ed303219be 100755 --- a/mne/preprocessing/_pca_obs.py +++ b/mne/preprocessing/_pca_obs.py @@ -1,7 +1,6 @@ """Principle Component Analysis Optimal Basis Sets (PCA-OBS).""" -# Authors: Emma Bailey , -# Steinn Hauser Magnusson +# Authors: The MNE-Python contributors. # License: BSD-3-Clause # Copyright the MNE-Python contributors. From ebda34e300cdeeb28217d2b91e1275623a915513 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 14 Jan 2025 12:53:20 -0500 Subject: [PATCH 63/71] FIX: How --- mne/preprocessing/_pca_obs.py | 9 +++------ mne/utils/numerics.py | 3 +++ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mne/preprocessing/_pca_obs.py b/mne/preprocessing/_pca_obs.py index 4ed303219be..892170003e1 100755 --- a/mne/preprocessing/_pca_obs.py +++ b/mne/preprocessing/_pca_obs.py @@ -9,12 +9,9 @@ import numpy as np from scipy.interpolate import PchipInterpolator as pchip from scipy.signal import detrend -from sklearn.decomposition import PCA -from mne.io.fiff.raw import Raw -from mne.utils import logger -from mne.utils._logging import verbose -from mne.utils.check import _validate_type +from ..io.fiff.raw import Raw +from ..utils import _PCA, _validate_type, logger, verbose @verbose @@ -138,7 +135,7 @@ def _pca_obs( # Perform PCA with sklearn # ############################ # run PCA, perform singular value decomposition (SVD) - pca = PCA(svd_solver="full") + pca = _PCA() pca.fit(dpcamat) factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_) diff --git a/mne/utils/numerics.py 
b/mne/utils/numerics.py index c287fb42305..c64652450a0 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -859,6 +859,9 @@ def fit_transform(self, X, y=None): return U + def fit(self, X): + self._fit(X) + def _fit(self, X): if self.n_components is None: n_components = min(X.shape) From 92fe25c6d85614f72de6d809fdb2789642c34ec0 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 14 Jan 2025 12:54:14 -0500 Subject: [PATCH 64/71] TST: Pre From d38e37ffb85a6ba4ba19e9dd34aec28860062f1d Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 14 Jan 2025 12:58:42 -0500 Subject: [PATCH 65/71] DOC: whats new --- doc/changes/devel/13037.newfeature.rst | 1 + doc/changes/names.inc | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 doc/changes/devel/13037.newfeature.rst diff --git a/doc/changes/devel/13037.newfeature.rst b/doc/changes/devel/13037.newfeature.rst new file mode 100644 index 00000000000..05d44b948d3 --- /dev/null +++ b/doc/changes/devel/13037.newfeature.rst @@ -0,0 +1 @@ +Add PCA-OBS preprocessing for the removal of heart-artefacts from EEG or ESG datasets via :func:`mne.preprocessing.apply_pca_obs`, by :newcontrib:`Emma Bailey` and :newcontrib:`Seinn Hauser Magnusson`. diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 3ac0b1cd9c9..eb444c5e594 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -73,6 +73,7 @@ .. _Eberhard Eich: https://github.com/ebeich .. _Eduard Ort: https://github.com/eort .. _Emily Stephen: https://github.com/emilyps14 +.. _Emma Bailey: https://www.cbs.mpg.de/employees/bailey .. _Enrico Varano: https://github.com/enricovara/ .. _Enzo Altamiranda: https://www.linkedin.com/in/enzoalt .. _Eric Larson: https://larsoner.com @@ -284,6 +285,7 @@ .. _Stanislas Chambon: https://github.com/Slasnista .. _Stefan Appelhoff: https://stefanappelhoff.com .. _Stefan Repplinger: https://github.com/stfnrpplngr +.. _Steinn Hauser Magnusson: https://github.com/steinnhauser .. 
_Steven Bethard: https://github.com/bethard .. _Steven Bierer: https://github.com/neurolaunch .. _Steven Gutstein: https://github.com/smgutstein From 2472c84406d6fc50fe014bcc4c6026c6cbe96b31 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 14 Jan 2025 13:16:10 -0500 Subject: [PATCH 66/71] FIX: Name --- doc/changes/devel/13037.newfeature.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/changes/devel/13037.newfeature.rst b/doc/changes/devel/13037.newfeature.rst index 05d44b948d3..3b28e2294ab 100644 --- a/doc/changes/devel/13037.newfeature.rst +++ b/doc/changes/devel/13037.newfeature.rst @@ -1 +1 @@ -Add PCA-OBS preprocessing for the removal of heart-artefacts from EEG or ESG datasets via :func:`mne.preprocessing.apply_pca_obs`, by :newcontrib:`Emma Bailey` and :newcontrib:`Seinn Hauser Magnusson`. +Add PCA-OBS preprocessing for the removal of heart-artefacts from EEG or ESG datasets via :func:`mne.preprocessing.apply_pca_obs`, by :newcontrib:`Emma Bailey` and :newcontrib:`Steinn Hauser Magnusson`. From dab5623ccde47a4ab05c4a88a3f38a250dbc9fd0 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 16 Jan 2025 12:16:41 -0500 Subject: [PATCH 67/71] Apply suggestions from code review --- mne/preprocessing/tests/test_pca_obs.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 5b6707224aa..ee2568a2080 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -5,7 +5,6 @@ from pathlib import Path import numpy as np -import pandas as pd import pytest from mne.io import read_raw_fif @@ -24,7 +23,7 @@ def short_raw_data(): def test_heart_artifact_removal(short_raw_data: Raw): """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" - pytest.importorskip("pandas") + pd = pytest.importorskip("pandas") # copy the original raw. 
heart artifact is removed in-place orig_df: pd.DataFrame = short_raw_data.to_data_frame().copy(deep=True) @@ -104,7 +103,5 @@ def test_pca_obs_bad_input( exception: type[Exception], ): """Test if bad input data raises the proper errors in the function sanity checks.""" - pytest.importorskip("pandas") - with pytest.raises(exception, match=error): apply_pca_obs(raw=short_raw_data, picks=picks, qrs_times=qrs_times) From a8e84ff986bcfe7686c653f3322b902ebbe358f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Steinn=20Hauser=20Magn=C3=BAsson?= <42544980+steinnhauser@users.noreply.github.com> Date: Fri, 17 Jan 2025 09:23:08 +0100 Subject: [PATCH 68/71] Update examples/preprocessing/esg_rm_heart_artefact_pcaobs.py Fix typo Co-authored-by: Daniel McCloy --- examples/preprocessing/esg_rm_heart_artefact_pcaobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index 56fd6b0051a..24a956e963b 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -10,7 +10,7 @@ the ballistocardiographic artefact in simultaneous EEG-fMRI. Here, it has been adapted to remove the delay between the detected R-peak and the ballistocardiographic artefact such that the algorithm can be applied to -remove the cardiac artefact in EEG (electroencephalogrpahy) and ESG +remove the cardiac artefact in EEG (electroencephalography) and ESG (electrospinography) data. We will illustrate how it works by applying the algorithm to ESG data, where the effect of removal is most pronounced. 
From 12b38343f166e0adf563c88da8ffee514b9690d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Steinn=20Hauser=20Magn=C3=BAsson?= <42544980+steinnhauser@users.noreply.github.com> Date: Fri, 17 Jan 2025 09:23:56 +0100 Subject: [PATCH 69/71] Update examples/preprocessing/esg_rm_heart_artefact_pcaobs.py add punctuation Co-authored-by: Daniel McCloy --- examples/preprocessing/esg_rm_heart_artefact_pcaobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index 24a956e963b..e9103d3b216 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -29,7 +29,7 @@ import numpy as np ############################################################################### -# Download sample subject data from OpenNeuro if you haven't already +# Download sample subject data from OpenNeuro if you haven't already. # This will download simultaneous EEG and ESG data from a single run of a # single participant after median nerve stimulation of the left wrist import openneuro From c1106c257d5c628729898195fd4374639a4bfa04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Steinn=20Hauser=20Magn=C3=BAsson?= <42544980+steinnhauser@users.noreply.github.com> Date: Fri, 17 Jan 2025 09:25:43 +0100 Subject: [PATCH 70/71] Apply suggestions from code review Co-authored-by: Daniel McCloy --- .../esg_rm_heart_artefact_pcaobs.py | 28 +++++++++---------- mne/preprocessing/_pca_obs.py | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index e9103d3b216..a6c6bb3c2ba 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -28,10 +28,10 @@ import numpy as np 
-############################################################################### +# %% # Download sample subject data from OpenNeuro if you haven't already. # This will download simultaneous EEG and ESG data from a single run of a -# single participant after median nerve stimulation of the left wrist +# single participant after median nerve stimulation of the left wrist. import openneuro from matplotlib import pyplot as plt @@ -50,8 +50,8 @@ block_files = glob.glob(str(target_dir / run_name)) assert len(block_files) == 1 -############################################################################### -# Define the esg channels (arranged in two patches over the neck and lower back) +# %% +# Define the esg channels (arranged in two patches over the neck and lower back). esg_chans = [ "S35", @@ -101,9 +101,9 @@ iv_baseline = [-400e-3, -300e-3] iv_epoch = [-400e-3, 600e-3] -###################################################### -# Perform minimal preprocessing including removing the -# stimulation artefact, downsampling and filtering +# %% +# Next, we perform minimal preprocessing including removing the +# stimulation artefact, downsampling and filtering. raw = read_raw_eeglab(block_files[0], verbose="error") raw.set_channel_types(dict(ECG="ecg")) @@ -126,8 +126,8 @@ stim_channel=None, ) -################################################################### -# Find ECG events and add to the raw structure as event annotations +# %% +# Find ECG events and add to the raw structure as event annotations. ecg_events, ch_ecg, average_pulse = find_ecg_events(raw, ch_name="ECG") ecg_event_samples = np.asarray( @@ -144,9 +144,9 @@ qrs_event_time, duration, description, ch_names=[esg_chans] * len(qrs_event_time) ) -################################################### -# Create evoked response about the detected R-peaks -# before and after cardiac artefact correction +# %% +# Create evoked response around the detected R-peaks +# before and after cardiac artefact correction. 
events, event_ids = events_from_annotations(raw) event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} @@ -176,8 +176,8 @@ ) evoked_after = epochs.average() -################## -# Comparison image +# %% +# Compare evoked responses to assess completeness of artefact removal. fig, axes = plt.subplots(1, 1, layout="constrained") data_before = evoked_before.get_data(units=dict(eeg="uV")).T diff --git a/mne/preprocessing/_pca_obs.py b/mne/preprocessing/_pca_obs.py index 892170003e1..be226a73889 100755 --- a/mne/preprocessing/_pca_obs.py +++ b/mne/preprocessing/_pca_obs.py @@ -111,7 +111,7 @@ def _pca_obs( peak_range / 8 ) # sample fit for interpolation between fitted artifact windows - # make sure array is long enough for PArange (if not cut off last ECG peak) + # make sure array is long enough for PArange (if not cut off last ECG peak) # NOTE: Here we previously checked for the last part of the window to be big enough. while peak_idx[peak_count - 1] + peak_range > len(data): peak_count = peak_count - 1 # reduce number of QRS complexes detected From a4ee593d41ae185aaec73543b0e6a1f53646d41e Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 17 Jan 2025 15:35:09 -0500 Subject: [PATCH 71/71] FIX: Move --- pyproject.toml | 1 + tools/circleci_dependencies.sh | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bb56126bc07..f20c495a2bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ doc = [ "mne-gui-addons", "neo", "numpydoc", + "openneuro-py", "psutil", "pydata_sphinx_theme >= 0.15.2", "pygments >= 2.13", diff --git a/tools/circleci_dependencies.sh b/tools/circleci_dependencies.sh index 2ecc9718ab2..dd3216ebf06 100755 --- a/tools/circleci_dependencies.sh +++ b/tools/circleci_dependencies.sh @@ -11,6 +11,6 @@ python -m pip install --upgrade --progress-bar off \ alphaCSC autoreject bycycle conpy emd fooof meggie \ mne-ari mne-bids-pipeline mne-faster mne-features \ 
mne-icalabel mne-lsl mne-microstates mne-nirs mne-rsa \ - neurodsp neurokit2 niseq nitime openneuro-py pactools \ + neurodsp neurokit2 niseq nitime pactools \ plotly pycrostates pyprep pyriemann python-picard sesameeg \ sleepecg tensorpac yasa meegkit eeg_positions