From 61f58b279f11adf25043c671ab491d56aa10d032 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Oct 2024 16:32:20 +0000 Subject: [PATCH 01/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/preprocessing/pca_obs/PCA_OBS.py | 112 +++++++++++------ mne/preprocessing/pca_obs/fit_ecgTemplate.py | 53 ++++++-- .../pca_obs/pchip_interpolation.py | 45 ++++--- .../rm_heart_artefact_cortical_mnedata.py | 57 +++++---- ...rm_heart_artefact_spinal_impreciserpeak.py | 116 +++++++++++++----- .../rm_heart_artefact_spinal_preciserpeak.py | 113 ++++++++++++----- 6 files changed, 340 insertions(+), 156 deletions(-) diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py index aee06165d11..ec09d4f9f6e 100755 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ b/mne/preprocessing/pca_obs/PCA_OBS.py @@ -1,17 +1,16 @@ +import math + import numpy as np +from fit_ecgTemplate import fit_ecgTemplate + # import mne -from scipy.signal import filtfilt, detrend -import matplotlib.pyplot as plt +from scipy.signal import detrend, filtfilt from sklearn.decomposition import PCA -from fit_ecgTemplate import fit_ecgTemplate -import math -import h5py def PCA_OBS(data, **kwargs): - # Declare class to hold pca information - class PCAInfo(): + class PCAInfo: def __init__(self): pass @@ -20,12 +19,14 @@ def __init__(self): # Check all necessary arguments sent in required_kws = ["qrs", "filter_coords", "sr"] - assert all([kw in kwargs.keys() for kw in required_kws]), "Error. Some KWs not passed into PCA_OBS." + assert all( + [kw in kwargs.keys() for kw in required_kws] + ), "Error. Some KWs not passed into PCA_OBS." # Extract all kwargs - qrs = kwargs['qrs'] - filter_coords = kwargs['filter_coords'] - sr = kwargs['sr'] + qrs = kwargs["qrs"] + filter_coords = kwargs["filter_coords"] + sr = kwargs["sr"] fs = sr @@ -50,19 +51,21 @@ def __init__(self): ################################################################ # Preparatory work - reserving memory, configure sizes, de-trend ################################################################ - print('Pulse artifact subtraction in progress...Please wait!') + print("Pulse artifact subtraction in progress...Please wait!") # define peak range based on RR RR = np.diff(peak_idx[:, 0]) mRR = np.median(RR) - peak_range = round(mRR/2) # Rounds to an integer + peak_range = round(mRR / 2) # Rounds to an integer midP = peak_range + 1 - baseline_range = [0, round(peak_range/8)] - n_samples_fit = round(peak_range/8) # sample fit for interpolation between fitted artifact windows + baseline_range = [0, round(peak_range / 8)] + n_samples_fit = round( + peak_range / 8 + ) # sample fit for interpolation between fitted artifact windows # make sure array is long enough for PArange (if not cut off last ECG peak) pa = peak_count # Number of QRS complexes detected - while peak_idx[pa-1, 0] + peak_range > len(data[0]): + while peak_idx[pa - 1, 0] + peak_range > len(data[0]): pa = pa - 1 steps = 1 * pa peak_count = pa @@ -71,16 +74,22 @@ def __init__(self): eegchan = filtfilt(filter_coords, 1, data) # build PCA matrix(heart-beat-epochs x window-length) - pcamat = np.zeros((peak_count - 1, 2*peak_range+1)) # [epoch x time] + pcamat = np.zeros((peak_count - 1, 2 * peak_range + 1)) # [epoch x time] # picking out heartbeat epochs for p in range(1, peak_count): - pcamat[p-1, :] = eegchan[0, peak_idx[p, 0] - peak_range: peak_idx[p, 0] + peak_range+1] + pcamat[p - 1, :] = eegchan[ + 0, peak_idx[p, 0] - peak_range : peak_idx[p, 0] + peak_range + 1 + ] # detrending matrix(twice) - pcamat = detrend(pcamat, type='constant', axis=1) # [epoch x time] - detrended along the epoch - mean_effect = np.mean(pcamat, axis=0) # [1 x time], contains the mean over all epochs + pcamat = detrend( + pcamat, type="constant", axis=1 + ) # [epoch x time] - detrended along the epoch + mean_effect = np.mean( + pcamat, axis=0 + ) # [1 x time], contains the mean over all epochs std_effect = np.std(pcamat, axis=0) # want mean and std of each column - dpcamat = detrend(pcamat, type='constant', axis=1) # [time x epoch] + dpcamat = detrend(pcamat, type="constant", axis=1) # [time x epoch] ################################################################### # Perform PCA with sklearn @@ -90,7 +99,7 @@ def __init__(self): pca.fit(dpcamat) eigen_vectors = pca.components_ eigen_values = pca.explained_variance_ - factor_loadings = pca.components_.T*np.sqrt(pca.explained_variance_) + factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_) pca_info.eigen_vectors = eigen_vectors pca_info.factor_loadings = factor_loadings pca_info.eigen_values = eigen_values @@ -116,34 +125,55 @@ def __init__(self): # Deals with start portion of data if p == 0: pre_range = peak_range - post_range = math.floor((peak_idx[p + 1] - peak_idx[p])/2) + post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) if post_range > peak_range: post_range = peak_range try: post_idx_nextPeak = [] - fitted_art, post_idx_nextPeak = fit_ecgTemplate(data, pca_template, peak_idx[p], peak_range, - pre_range, post_range, baseline_range, midP, - fitted_art, post_idx_nextPeak, n_samples_fit) + fitted_art, post_idx_nextPeak = fit_ecgTemplate( + data, + pca_template, + peak_idx[p], + peak_range, + pre_range, + post_range, + baseline_range, + midP, + fitted_art, + post_idx_nextPeak, + n_samples_fit, + ) # Appending to list instead of using counter window_start_idx.append(peak_idx[p] - peak_range) window_end_idx.append(peak_idx[p] + peak_range) except Exception as e: - print(f'Cannot fit first ECG epoch. Reason: {e}') + print(f"Cannot fit first ECG epoch. Reason: {e}") # Deals with last edge of data elif p == peak_count: - print('On last section - almost there!') + print("On last section - almost there!") try: pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) post_range = peak_range if pre_range > peak_range: pre_range = peak_range - fitted_art, _ = fit_ecgTemplate(data, pca_template, peak_idx(p), peak_range, pre_range, post_range, - baseline_range, midP, fitted_art, post_idx_nextPeak, n_samples_fit) + fitted_art, _ = fit_ecgTemplate( + data, + pca_template, + peak_idx(p), + peak_range, + pre_range, + post_range, + baseline_range, + midP, + fitted_art, + post_idx_nextPeak, + n_samples_fit, + ) window_start_idx.append(peak_idx[p] - peak_range) window_end_idx.append(peak_idx[p] + peak_range) except Exception as e: - print(f'Cannot fit last ECG epoch. Reason: {e}') + print(f"Cannot fit last ECG epoch. Reason: {e}") # Deals with middle portion of data else: @@ -157,10 +187,22 @@ def __init__(self): if post_range > peak_range: post_range = peak_range - aTemplate = pca_template[midP - peak_range-1:midP + peak_range+1, :] - fitted_art, post_idx_nextPeak = fit_ecgTemplate(data, aTemplate, peak_idx[p], peak_range, pre_range, - post_range, baseline_range, midP, fitted_art, - post_idx_nextPeak, n_samples_fit) + aTemplate = pca_template[ + midP - peak_range - 1 : midP + peak_range + 1, : + ] + fitted_art, post_idx_nextPeak = fit_ecgTemplate( + data, + aTemplate, + peak_idx[p], + peak_range, + pre_range, + post_range, + baseline_range, + midP, + fitted_art, + post_idx_nextPeak, + n_samples_fit, + ) window_start_idx.append(peak_idx[p] - peak_range) window_end_idx.append(peak_idx[p] + peak_range) except Exception as e: diff --git a/mne/preprocessing/pca_obs/fit_ecgTemplate.py b/mne/preprocessing/pca_obs/fit_ecgTemplate.py index 11043b0c56d..3e5b7f55cb2 100755 --- a/mne/preprocessing/pca_obs/fit_ecgTemplate.py +++ b/mne/preprocessing/pca_obs/fit_ecgTemplate.py @@ -1,12 +1,23 @@ import numpy as np -from scipy.signal import detrend from scipy.interpolate import PchipInterpolator as pchip -import h5py +from scipy.signal import detrend -def fit_ecgTemplate(data, pca_template, aPeak_idx, peak_range, pre_range, post_range, baseline_range, midP, fitted_art, post_idx_previousPeak, n_samples_fit): +def fit_ecgTemplate( + data, + pca_template, + aPeak_idx, + peak_range, + pre_range, + post_range, + baseline_range, + midP, + fitted_art, + post_idx_previousPeak, + n_samples_fit, +): # Declare class to hold ecg fit information - class fitECG(): + class fitECG: def __init__(self): pass @@ -16,18 +27,20 @@ def __init__(self): # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak # Then nextpeak is returned at the end and the process repeats # select window of template - template = pca_template[midP - peak_range-1: midP + peak_range+1, :] + template = pca_template[midP - peak_range - 1 : midP + peak_range + 1, :] # select window of data and detrend it - slice = data[0, aPeak_idx[0] - peak_range:aPeak_idx[0] + peak_range+1] - detrended_data = detrend(slice.reshape(-1), type='constant') + slice = data[0, aPeak_idx[0] - peak_range : aPeak_idx[0] + peak_range + 1] + detrended_data = detrend(slice.reshape(-1), type="constant") # maps data on template and then maps it again back to the sensor space least_square = np.linalg.lstsq(template, detrended_data, rcond=None) pad_fit = np.dot(template, least_square[0]) # fit artifact, I already loop through externally channel to channel - fitted_art[0, aPeak_idx[0] - pre_range-1: aPeak_idx[0] + post_range] = pad_fit[midP - pre_range-1: midP + post_range].T + fitted_art[0, aPeak_idx[0] - pre_range - 1 : aPeak_idx[0] + post_range] = pad_fit[ + midP - pre_range - 1 : midP + post_range + ].T fitecg.fitted_art = fitted_art fitecg.template = template @@ -43,7 +56,9 @@ def __init__(self): # Check it's not empty if len(post_idx_previousPeak) != 0: # interpolate time between peaks - intpol_window = np.ceil([post_idx_previousPeak[0], aPeak_idx[0] - pre_range]).astype('int') # interpolation window + intpol_window = np.ceil( + [post_idx_previousPeak[0], aPeak_idx[0] - pre_range] + ).astype("int") # interpolation window fitecg.intpol_window = intpol_window if intpol_window[0] < intpol_window[1]: @@ -53,14 +68,26 @@ def __init__(self): # You have y_fit which is the y vals corresponding to x values above # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate # You have y_interpol which is values from pchip at the time points specified in x_interpol - x_interpol = np.arange(intpol_window[0], intpol_window[1] + 1, 1) # points to be interpolated in pt - the gap between the endpoints of the window - x_fit = np.concatenate([np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1), - np.arange(intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1)]) # Entire range of x values in this step (taking some number of samples before and after the window) + x_interpol = np.arange( + intpol_window[0], intpol_window[1] + 1, 1 + ) # points to be interpolated in pt - the gap between the endpoints of the window + x_fit = np.concatenate( + [ + np.arange( + intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1 + ), + np.arange( + intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1 + ), + ] + ) # Entire range of x values in this step (taking some number of samples before and after the window) y_fit = fitted_art[0, x_fit] y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation # Then make fitted artefact in the desired range equal to the completed fit above - fitted_art[0, post_idx_previousPeak[0]: aPeak_idx[0] - pre_range+1] = y_interpol + fitted_art[0, post_idx_previousPeak[0] : aPeak_idx[0] - pre_range + 1] = ( + y_interpol + ) fitecg.x_fit = x_fit fitecg.y_fit = y_fit diff --git a/mne/preprocessing/pca_obs/pchip_interpolation.py b/mne/preprocessing/pca_obs/pchip_interpolation.py index e336f672363..56a34eabb43 100755 --- a/mne/preprocessing/pca_obs/pchip_interpolation.py +++ b/mne/preprocessing/pca_obs/pchip_interpolation.py @@ -3,33 +3,46 @@ # import mne import numpy as np from scipy.interpolate import PchipInterpolator as pchip -import matplotlib.pyplot as plt def PCHIP_interpolation(data, **kwargs): # Check all necessary arguments sent in required_kws = ["trigger_indices", "interpol_window_sec", "fs"] - assert all([kw in kwargs.keys() for kw in required_kws]), "Error. Some KWs not passed into PCA_OBS." + assert all( + [kw in kwargs.keys() for kw in required_kws] + ), "Error. Some KWs not passed into PCA_OBS." # Extract all kwargs - more elegant ways to do this - fs = kwargs['fs'] - interpol_window_sec = kwargs['interpol_window_sec'] - trigger_indices = kwargs['trigger_indices'] + fs = kwargs["fs"] + interpol_window_sec = kwargs["interpol_window_sec"] + trigger_indices = kwargs["trigger_indices"] # Convert intpol window to msec then convert to samples - pre_window = round((interpol_window_sec[0]*1000) * fs / 1000) # in samples - post_window = round((interpol_window_sec[1]*1000) * fs / 1000) # in samples - intpol_window = np.ceil([pre_window, post_window]).astype(int) # interpolation window - - n_samples_fit = 5 # number of samples before and after cut used for interpolation fit - - x_fit_raw = np.concatenate([np.arange(intpol_window[0]-n_samples_fit-1, intpol_window[0], 1), - np.arange(intpol_window[1]+1, intpol_window[1]+n_samples_fit+2, 1)]) - x_interpol_raw = np.arange(intpol_window[0], intpol_window[1]+1, 1) # points to be interpolated; in pt + pre_window = round((interpol_window_sec[0] * 1000) * fs / 1000) # in samples + post_window = round((interpol_window_sec[1] * 1000) * fs / 1000) # in samples + intpol_window = np.ceil([pre_window, post_window]).astype( + int + ) # interpolation window + + n_samples_fit = ( + 5 # number of samples before and after cut used for interpolation fit + ) + + x_fit_raw = np.concatenate( + [ + np.arange(intpol_window[0] - n_samples_fit - 1, intpol_window[0], 1), + np.arange(intpol_window[1] + 1, intpol_window[1] + n_samples_fit + 2, 1), + ] + ) + x_interpol_raw = np.arange( + intpol_window[0], intpol_window[1] + 1, 1 + ) # points to be interpolated; in pt for ii in np.arange(0, len(trigger_indices)): # loop through all stimulation events x_fit = trigger_indices[ii] + x_fit_raw # fit point latencies for this event - x_interpol = trigger_indices[ii] + x_interpol_raw # latencies for to-be-interpolated data points + x_interpol = ( + trigger_indices[ii] + x_interpol_raw + ) # latencies for to-be-interpolated data points # Data is just a string of values y_fit = data[x_fit] # y values to be fitted @@ -37,6 +50,6 @@ def PCHIP_interpolation(data, **kwargs): data[x_interpol] = y_interpol # replace in data if np.mod(ii, 100) == 0: # talk to the operator every 100th trial - print(f'stimulation event {ii} \n') + print(f"stimulation event {ii} \n") return data diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py index 42e58e99463..e84f9de3b90 100644 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py @@ -1,21 +1,22 @@ # Checking algorithm implementation with EEG data form MNE sample datasets # Following this tutorial data: https://mne.tools/stable/auto_tutorials/preprocessing/50_artifact_correction_ssp.html#what-is-ssp -from mne.preprocessing import ( - find_ecg_events, - create_ecg_epochs, -) -from mne.io import read_raw_fif -from mne.datasets.sample import data_path -from PCA_OBS import * -from mne import Epochs import os + +import matplotlib.pyplot as plt import numpy as np +from PCA_OBS import * from scipy.signal import firls -import matplotlib.pyplot as plt -if __name__ == '__main__': - sample_data_folder = data_path(path='/data/pt_02569/mne_test_data/') +from mne import Epochs +from mne.datasets.sample import data_path +from mne.io import read_raw_fif +from mne.preprocessing import ( + find_ecg_events, +) + +if __name__ == "__main__": + sample_data_folder = data_path(path="/data/pt_02569/mne_test_data/") sample_data_raw_file = os.path.join( sample_data_folder, "MEG", "sample", "sample_audvis_raw.fif" ) @@ -23,13 +24,17 @@ raw = read_raw_fif(sample_data_raw_file, preload=True) # Find ECG events - no ECG channel in data, uses synthetic - ecg_events, ch_ecg, average_pulse, = find_ecg_events(raw) + ( + ecg_events, + ch_ecg, + average_pulse, + ) = find_ecg_events(raw) # Extract just sample timings of ecg events ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) # print(ecg_events) # Create filter coefficients - fs = raw.info['sfreq'] + fs = raw.info["sfreq"] a = [0, 0, 1, 1] f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter @@ -43,35 +48,35 @@ # run PCA_OBS # Algorithm is extremely sensitive to accurate R-peak timings, won't work as well with the artificial ECG # channel estimation as we have here - PCA_OBS_kwargs = dict( - qrs=ecg_event_samples, filter_coords=fwts, sr=fs - ) + PCA_OBS_kwargs = dict(qrs=ecg_event_samples, filter_coords=fwts, sr=fs) - epochs = Epochs(raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) + epochs = Epochs( + raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline) + ) evoked_before = epochs.average() # Apply function should modifies the data in raw in place - raw.apply_function(PCA_OBS, picks='eeg', **PCA_OBS_kwargs, n_jobs=10) - epochs = Epochs(raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) + raw.apply_function(PCA_OBS, picks="eeg", **PCA_OBS_kwargs, n_jobs=10) + epochs = Epochs( + raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline) + ) evoked_after = epochs.average() # Comparison image fig, axes = plt.subplots(2, 1) axes[0].plot(evoked_before.times, evoked_before.get_data().T) axes[0].set_ylim([-1e-5, 3e-5]) - axes[0].set_title('Before PCA-OBS') + axes[0].set_title("Before PCA-OBS") axes[1].plot(evoked_after.times, evoked_after.get_data().T) axes[1].set_ylim([-1e-5, 3e-5]) - axes[1].set_title('After PCA-OBS') + axes[1].set_title("After PCA-OBS") plt.tight_layout() # Comparison image fig, axes = plt.subplots(1, 1) - axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') + axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") axes.set_ylim([-1e-5, 3e-5]) - axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') - axes.set_title('Before (black) versus after (green)') + axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") + axes.set_title("Before (black) versus after (green)") plt.tight_layout() plt.show() diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py index 7634096c909..73682b67867 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -1,80 +1,130 @@ # Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component # Analysis, Optimal Basis Sets) -import os -from scipy.io import loadmat -from scipy.signal import firls from PCA_OBS import * +from scipy.signal import firls + +from mne import Epochs, events_from_annotations from mne.io import read_raw_fif -from mne import events_from_annotations, Epochs from mne.preprocessing import find_ecg_events - -if __name__ == '__main__': +if __name__ == "__main__": # Incredibly slow without parallelization # Set variables - subject_id = f'sub-001' - cond_name = 'median' + subject_id = "sub-001" + cond_name = "median" sampling_rate = 1000 - esg_chans = ['S35', 'S24', 'S36', 'Iz', 'S17', 'S15', 'S32', 'S22', - 'S19', 'S26', 'S28', 'S9', 'S13', 'S11', 'S7', 'SC1', 'S4', 'S18', - 'S8', 'S31', 'SC6', 'S12', 'S16', 'S5', 'S30', 'S20', 'S34', 'AC', - 'S21', 'S25', 'L1', 'S29', 'S14', 'S33', 'S3', 'AL', 'L4', 'S6', - 'S23'] + esg_chans = [ + "S35", + "S24", + "S36", + "Iz", + "S17", + "S15", + "S32", + "S22", + "S19", + "S26", + "S28", + "S9", + "S13", + "S11", + "S7", + "SC1", + "S4", + "S18", + "S8", + "S31", + "SC6", + "S12", + "S16", + "S5", + "S30", + "S20", + "S34", + "AC", + "S21", + "S25", + "L1", + "S29", + "S14", + "S33", + "S3", + "AL", + "L4", + "S6", + "S23", + ] # For heartbeat epochs iv_baseline = [-300 / 1000, -200 / 1000] iv_epoch = [-400 / 1000, 600 / 1000] # Setting paths - input_path = "/data/pt_02569/tmp_data/prepared_py/"+subject_id+"/esg/prepro/" + input_path = "/data/pt_02569/tmp_data/prepared_py/" + subject_id + "/esg/prepro/" fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" - raw = read_raw_fif(input_path + fname + '.fif', preload=True) + raw = read_raw_fif(input_path + fname + ".fif", preload=True) # Find ECG events - no ECG channel in data, uses synthetic - ecg_events, ch_ecg, average_pulse, = find_ecg_events(raw, ch_name='ECG') + ( + ecg_events, + ch_ecg, + average_pulse, + ) = find_ecg_events(raw, ch_name="ECG") # Extract just sample timings of ecg events ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) # Create filter coefficients fs = sampling_rate a = [0, 0, 1, 1] - f = [0, 0.4/(fs/2), 0.9/(fs / 2), 1] # 0.9 Hz highpass filter + f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter - ord = round(3*fs/0.5) - fwts = firls(ord+1, f, a) + ord = round(3 * fs / 0.5) + fwts = firls(ord + 1, f, a) # run PCA_OBS - PCA_OBS_kwargs = dict( - qrs=ecg_event_samples, filter_coords=fwts, sr=sampling_rate - ) + PCA_OBS_kwargs = dict(qrs=ecg_event_samples, filter_coords=fwts, sr=sampling_rate) events, event_ids = events_from_annotations(raw) - event_id_dict = {key: value for key, value in event_ids.items() if key == 'qrs'} - epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) + event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} + epochs = Epochs( + raw, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), + ) evoked_before = epochs.average() # Apply function should modifies the data in raw in place - raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans)) - epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) + raw.apply_function( + PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans) + ) + epochs = Epochs( + raw, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), + ) evoked_after = epochs.average() # Comparison image fig, axes = plt.subplots(2, 1) axes[0].plot(evoked_before.times, evoked_before.get_data().T) axes[0].set_ylim([-0.0005, 0.001]) - axes[0].set_title('Before PCA-OBS') + axes[0].set_title("Before PCA-OBS") axes[1].plot(evoked_after.times, evoked_after.get_data().T) axes[1].set_ylim([-0.0005, 0.001]) - axes[1].set_title('After PCA-OBS') + axes[1].set_title("After PCA-OBS") plt.tight_layout() # Comparison image fig, axes = plt.subplots(1, 1) - axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') + axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") axes.set_ylim([-0.0005, 0.001]) - axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') - axes.set_title('Before (black) versus after (green)') + axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") + axes.set_title("Before (black) versus after (green)") plt.tight_layout() plt.show() diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py index 1b559ef9f22..4c40d8f7234 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py @@ -1,80 +1,127 @@ # Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component # Analysis, Optimal Basis Sets) -import os +from PCA_OBS import * from scipy.io import loadmat from scipy.signal import firls -from PCA_OBS import * -from mne.io import read_raw_fif -from mne import events_from_annotations, Epochs +from mne import Epochs, events_from_annotations +from mne.io import read_raw_fif -if __name__ == '__main__': +if __name__ == "__main__": # Incredibly slow without parallelization # Set variables - subject_id = f'sub-001' - cond_name = 'median' + subject_id = "sub-001" + cond_name = "median" sampling_rate = 1000 - esg_chans = ['S35', 'S24', 'S36', 'Iz', 'S17', 'S15', 'S32', 'S22', - 'S19', 'S26', 'S28', 'S9', 'S13', 'S11', 'S7', 'SC1', 'S4', 'S18', - 'S8', 'S31', 'SC6', 'S12', 'S16', 'S5', 'S30', 'S20', 'S34', 'AC', - 'S21', 'S25', 'L1', 'S29', 'S14', 'S33', 'S3', 'AL', 'L4', 'S6', - 'S23'] + esg_chans = [ + "S35", + "S24", + "S36", + "Iz", + "S17", + "S15", + "S32", + "S22", + "S19", + "S26", + "S28", + "S9", + "S13", + "S11", + "S7", + "SC1", + "S4", + "S18", + "S8", + "S31", + "SC6", + "S12", + "S16", + "S5", + "S30", + "S20", + "S34", + "AC", + "S21", + "S25", + "L1", + "S29", + "S14", + "S33", + "S3", + "AL", + "L4", + "S6", + "S23", + ] # For heartbeat epochs iv_baseline = [-300 / 1000, -200 / 1000] iv_epoch = [-400 / 1000, 600 / 1000] # Setting paths - input_path = "/data/pt_02569/tmp_data/prepared_py/"+subject_id+"/esg/prepro/" + input_path = "/data/pt_02569/tmp_data/prepared_py/" + subject_id + "/esg/prepro/" input_path_m = "/data/pt_02569/tmp_data/prepared/" + subject_id + "/esg/prepro/" fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" - raw = read_raw_fif(input_path + fname + '.fif', preload=True) + raw = read_raw_fif(input_path + fname + ".fif", preload=True) # Read .mat file with QRS events fname_m = f"raw_{sampling_rate}_spinal_{cond_name}" - matdata = loadmat(input_path_m+fname_m+'.mat') - QRSevents_m = matdata['QRSevents'] + matdata = loadmat(input_path_m + fname_m + ".mat") + QRSevents_m = matdata["QRSevents"] # Create filter coefficients fs = sampling_rate a = [0, 0, 1, 1] - f = [0, 0.4/(fs/2), 0.9/(fs / 2), 1] # 0.9 Hz highpass filter + f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter - ord = round(3*fs/0.5) - fwts = firls(ord+1, f, a) + ord = round(3 * fs / 0.5) + fwts = firls(ord + 1, f, a) # run PCA_OBS - PCA_OBS_kwargs = dict( - qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate - ) + PCA_OBS_kwargs = dict(qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate) events, event_ids = events_from_annotations(raw) - event_id_dict = {key: value for key, value in event_ids.items() if key == 'qrs'} - epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) + event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} + epochs = Epochs( + raw, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), + ) evoked_before = epochs.average() # Apply function should modifies the data in raw in place - raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans)) - epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) + raw.apply_function( + PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans) + ) + epochs = Epochs( + raw, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), + ) evoked_after = epochs.average() # Comparison image fig, axes = plt.subplots(2, 1) axes[0].plot(evoked_before.times, evoked_before.get_data().T) axes[0].set_ylim([-0.0005, 0.001]) - axes[0].set_title('Before PCA-OBS') + axes[0].set_title("Before PCA-OBS") axes[1].plot(evoked_after.times, evoked_after.get_data().T) axes[1].set_ylim([-0.0005, 0.001]) - axes[1].set_title('After PCA-OBS') + axes[1].set_title("After PCA-OBS") plt.tight_layout() # Comparison image fig, axes = plt.subplots(1, 1) - axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') + axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") axes.set_ylim([-0.0005, 0.001]) - axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') - axes.set_title('Before (black) versus after (green)') + axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") + axes.set_title("Before (black) versus after (green)") plt.tight_layout() plt.show() From 070037dcb42fe04b945891a9d3916a2cb93c09f8 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Wed, 23 Oct 2024 19:12:10 +0200 Subject: [PATCH 02/23] fix: adjust import paths, add init file to module --- mne/preprocessing/pca_obs/PCA_OBS.py | 1 - mne/preprocessing/pca_obs/__init__.py | 0 mne/preprocessing/pca_obs/pchip_interpolation.py | 6 +++--- .../pca_obs/rm_heart_artefact_cortical_mnedata.py | 2 +- .../pca_obs/rm_heart_artefact_spinal_impreciserpeak.py | 5 +++-- .../pca_obs/rm_heart_artefact_spinal_preciserpeak.py | 3 ++- 6 files changed, 9 insertions(+), 8 deletions(-) create mode 100644 mne/preprocessing/pca_obs/__init__.py diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py index ec09d4f9f6e..691cd259cdf 100755 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ b/mne/preprocessing/pca_obs/PCA_OBS.py @@ -3,7 +3,6 @@ import numpy as np from fit_ecgTemplate import fit_ecgTemplate -# import mne from scipy.signal import detrend, filtfilt from sklearn.decomposition import PCA diff --git a/mne/preprocessing/pca_obs/__init__.py b/mne/preprocessing/pca_obs/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/mne/preprocessing/pca_obs/pchip_interpolation.py b/mne/preprocessing/pca_obs/pchip_interpolation.py index 56a34eabb43..941aa6c0e48 100755 --- a/mne/preprocessing/pca_obs/pchip_interpolation.py +++ b/mne/preprocessing/pca_obs/pchip_interpolation.py @@ -1,6 +1,5 @@ # Function to interpolate based on PCHIP rather than MNE inbuilt linear option -# import mne import numpy as np from scipy.interpolate import PchipInterpolator as pchip @@ -8,9 +7,10 @@ def PCHIP_interpolation(data, **kwargs): # Check all necessary arguments sent in required_kws = ["trigger_indices", "interpol_window_sec", "fs"] - assert all( + if not all( [kw in kwargs.keys() for kw in required_kws] - ), "Error. Some KWs not passed into PCA_OBS." + ): + raise RuntimeError("Some KWs not passed into PCA_OBS.") # Extract all kwargs - more elegant ways to do this fs = kwargs["fs"] diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py index e84f9de3b90..442d0371e4e 100644 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import numpy as np -from PCA_OBS import * +from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS from scipy.signal import firls from mne import Epochs diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py index 73682b67867..6c6cd615a7c 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -1,9 +1,10 @@ # Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component # Analysis, Optimal Basis Sets) -from PCA_OBS import * +from matplotlib import pyplot as plt +from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS from scipy.signal import firls - +import numpy as np from mne import Epochs, events_from_annotations from mne.io import read_raw_fif from mne.preprocessing import find_ecg_events diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py index 4c40d8f7234..509c7161f0d 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py @@ -1,7 +1,8 @@ # Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component # Analysis, Optimal Basis Sets) -from PCA_OBS import * +from matplotlib import pyplot as plt +from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS from scipy.io import loadmat from scipy.signal import firls From e1b6732750d83d4c04f8df8541fd2d4d38dc141f Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Wed, 23 Oct 2024 19:31:34 +0200 Subject: [PATCH 03/23] refactor: rearrange PCA_OBS arg structure, remove kwarg 'sr' --- mne/preprocessing/pca_obs/PCA_OBS.py | 20 +++++++------------ .../pca_obs/pchip_interpolation.py | 1 + .../rm_heart_artefact_cortical_mnedata.py | 17 ++++++++++------ ...rm_heart_artefact_spinal_impreciserpeak.py | 15 ++++++++------ .../rm_heart_artefact_spinal_preciserpeak.py | 20 ++++++++++--------- 5 files changed, 39 insertions(+), 34 deletions(-) diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py index 691cd259cdf..e30006df104 100755 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ b/mne/preprocessing/pca_obs/PCA_OBS.py @@ -1,13 +1,18 @@ import math import numpy as np -from fit_ecgTemplate import fit_ecgTemplate +from mne.preprocessing.pca_obs.fit_ecgTemplate import fit_ecgTemplate from scipy.signal import detrend, filtfilt from sklearn.decomposition import PCA -def PCA_OBS(data, **kwargs): +def PCA_OBS( + data, + qrs, + filter_coords, + **kwargs +): # Declare class to hold pca information class PCAInfo: def __init__(self): @@ -16,18 +21,9 @@ def __init__(self): # Instantiate class pca_info = PCAInfo() - # Check all necessary arguments sent in - required_kws = ["qrs", "filter_coords", "sr"] - assert all( - [kw in kwargs.keys() for kw in required_kws] - ), "Error. Some KWs not passed into PCA_OBS." - # Extract all kwargs qrs = kwargs["qrs"] filter_coords = kwargs["filter_coords"] - sr = kwargs["sr"] - - fs = sr # set to baseline data = data.reshape(-1, 1) @@ -66,7 +62,6 @@ def __init__(self): pa = peak_count # Number of QRS complexes detected while peak_idx[pa - 1, 0] + peak_range > len(data[0]): pa = pa - 1 - steps = 1 * pa peak_count = pa # Filter channel @@ -87,7 +82,6 @@ def __init__(self): mean_effect = np.mean( pcamat, axis=0 ) # [1 x time], contains the mean over all epochs - std_effect = np.std(pcamat, axis=0) # want mean and std of each column dpcamat = detrend(pcamat, type="constant", axis=1) # [time x epoch] ################################################################### diff --git a/mne/preprocessing/pca_obs/pchip_interpolation.py b/mne/preprocessing/pca_obs/pchip_interpolation.py index 941aa6c0e48..466946ecdce 100755 --- a/mne/preprocessing/pca_obs/pchip_interpolation.py +++ b/mne/preprocessing/pca_obs/pchip_interpolation.py @@ -3,6 +3,7 @@ import numpy as np from scipy.interpolate import PchipInterpolator as pchip +# TODO: only place defined. is this used? def PCHIP_interpolation(data, **kwargs): # Check all necessary arguments sent in diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py index 442d0371e4e..4df2262cc05 100644 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py @@ -31,7 +31,6 @@ ) = find_ecg_events(raw) # Extract just sample timings of ecg events ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) - # print(ecg_events) # Create filter coefficients fs = raw.info["sfreq"] @@ -42,21 +41,27 @@ fwts = firls(ord + 1, f, a) # For heartbeat epochs - iv_baseline = [-300 / 1000, -200 / 1000] - iv_epoch = [-400 / 1000, 600 / 1000] + iv_baseline = [-300 / 1000, -200 / 1000] # 300 ms before to 200 ms after + iv_epoch = [-400 / 1000, 600 / 1000] # 300 ms before to 200 ms after # run PCA_OBS # Algorithm is extremely sensitive to accurate R-peak timings, won't work as well with the artificial ECG # channel estimation as we have here - PCA_OBS_kwargs = dict(qrs=ecg_event_samples, filter_coords=fwts, sr=fs) - epochs = Epochs( raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline) ) evoked_before = epochs.average() # Apply function should modifies the data in raw in place - raw.apply_function(PCA_OBS, picks="eeg", **PCA_OBS_kwargs, n_jobs=10) + raw.apply_function( + PCA_OBS, + picks="eeg", + n_jobs=10 + **{ # args sent to PCA_OBS + "qrs": ecg_event_samples, + "filter_coords": fwts, + }, + ) epochs = Epochs( raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline) ) diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py index 6c6cd615a7c..62c126ae6e0 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -14,7 +14,7 @@ # Set variables subject_id = "sub-001" cond_name = "median" - sampling_rate = 1000 + fs = 1000 # sampling rate esg_chans = [ "S35", "S24", @@ -62,7 +62,7 @@ # Setting paths input_path = "/data/pt_02569/tmp_data/prepared_py/" + subject_id + "/esg/prepro/" - fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" + fname = f"noStimart_sr{fs}_{cond_name}_withqrs_pchip" raw = read_raw_fif(input_path + fname + ".fif", preload=True) # Find ECG events - no ECG channel in data, uses synthetic @@ -75,7 +75,6 @@ ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) # Create filter coefficients - fs = sampling_rate a = [0, 0, 1, 1] f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter @@ -83,8 +82,6 @@ fwts = firls(ord + 1, f, a) # run PCA_OBS - PCA_OBS_kwargs = dict(qrs=ecg_event_samples, filter_coords=fwts, sr=sampling_rate) - events, event_ids = events_from_annotations(raw) event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} epochs = Epochs( @@ -99,7 +96,13 @@ # Apply function should modifies the data in raw in place raw.apply_function( - PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans) + PCA_OBS, + picks=esg_chans, + n_jobs=len(esg_chans), + **{ # args sent to PCA_OBS + "qrs": ecg_event_samples, + "filter_coords": fwts, + }, ) epochs = Epochs( raw, diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py index 509c7161f0d..d72bb058c2a 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py @@ -14,7 +14,7 @@ # Set variables subject_id = "sub-001" cond_name = "median" - sampling_rate = 1000 + fs = 1000 # sampling rate esg_chans = [ "S35", "S24", @@ -62,17 +62,15 @@ # Setting paths input_path = "/data/pt_02569/tmp_data/prepared_py/" + subject_id + "/esg/prepro/" - input_path_m = "/data/pt_02569/tmp_data/prepared/" + subject_id + "/esg/prepro/" - fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" + fname = f"noStimart_sr{fs}_{cond_name}_withqrs_pchip" raw = read_raw_fif(input_path + fname + ".fif", preload=True) # Read .mat file with QRS events - fname_m = f"raw_{sampling_rate}_spinal_{cond_name}" + input_path_m = "/data/pt_02569/tmp_data/prepared/" + subject_id + "/esg/prepro/" + fname_m = f"raw_{fs}_spinal_{cond_name}" matdata = loadmat(input_path_m + fname_m + ".mat") - QRSevents_m = matdata["QRSevents"] # Create filter coefficients - fs = sampling_rate a = [0, 0, 1, 1] f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter @@ -80,8 +78,6 @@ fwts = firls(ord + 1, f, a) # run PCA_OBS - PCA_OBS_kwargs = dict(qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate) - events, event_ids = events_from_annotations(raw) event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} epochs = Epochs( @@ -96,7 +92,13 @@ # Apply function should modifies the data in raw in place raw.apply_function( - PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans) + PCA_OBS, + picks=esg_chans, + n_jobs=len(esg_chans) + **{ # args sent to PCA_OBS + "qrs": matdata["QRSevents"], + "filter_coords": fwts, + }, ) epochs = Epochs( raw, From ccd42980d325b6b466ae2450a6899a1c7ae8af0c Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Wed, 23 Oct 2024 20:09:05 +0200 Subject: [PATCH 04/23] refactor: move common variables to module init, further cleanup --- mne/preprocessing/pca_obs/PCA_OBS.py | 14 ++-- mne/preprocessing/pca_obs/__init__.py | 58 ++++++++++++++ .../pca_obs/pchip_interpolation.py | 21 ++--- ...rm_heart_artefact_spinal_impreciserpeak.py | 72 +++-------------- .../rm_heart_artefact_spinal_preciserpeak.py | 77 ++++--------------- 5 files changed, 99 insertions(+), 143 deletions(-) diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py index e30006df104..cc226113de3 100755 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ b/mne/preprocessing/pca_obs/PCA_OBS.py @@ -1,17 +1,19 @@ import math +from typing import Any import numpy as np from mne.preprocessing.pca_obs.fit_ecgTemplate import fit_ecgTemplate from scipy.signal import detrend, filtfilt from sklearn.decomposition import PCA +from numpy.typing import NDArray def PCA_OBS( - data, - qrs, - filter_coords, - **kwargs + data: NDArray[Any], + qrs: NDArray[Any], + filter_coords: NDArray[Any], + **_ # TODO: are there any other kwargs passed in? ): # Declare class to hold pca information class PCAInfo: @@ -21,10 +23,6 @@ def __init__(self): # Instantiate class pca_info = PCAInfo() - # Extract all kwargs - qrs = kwargs["qrs"] - filter_coords = kwargs["filter_coords"] - # set to baseline data = data.reshape(-1, 1) data = data.T diff --git a/mne/preprocessing/pca_obs/__init__.py b/mne/preprocessing/pca_obs/__init__.py index e69de29bb2d..054c4a3d2aa 100644 --- a/mne/preprocessing/pca_obs/__init__.py +++ b/mne/preprocessing/pca_obs/__init__.py @@ -0,0 +1,58 @@ +from dataclasses import dataclass +from mne.io import Raw, read_raw_fif + + +# TODO: description of what this is +ESG_CHANS = [ + "S35", + "S24", + "S36", + "Iz", + "S17", + "S15", + "S32", + "S22", + "S19", + "S26", + "S28", + "S9", + "S13", + "S11", + "S7", + "SC1", + "S4", + "S18", + "S8", + "S31", + "SC6", + "S12", + "S16", + "S5", + "S30", + "S20", + "S34", + "AC", + "S21", + "S25", + "L1", + "S29", + "S14", + "S33", + "S3", + "AL", + "L4", + "S6", + "S23", +] + +# Set variables +fs = 1000 # sampling rate + +# For heartbeat epochs +iv_baseline = [-300 / 1000, -200 / 1000] +iv_epoch = [-400 / 1000, 600 / 1000] + +# Setting paths +input_path = "/data/pt_02569/tmp_data/prepared_py/sub-001/esg/prepro/" +fname = f"noStimart_sr{fs}_median_withqrs_pchip" +raw = read_raw_fif(input_path + fname + ".fif", preload=True) diff --git a/mne/preprocessing/pca_obs/pchip_interpolation.py b/mne/preprocessing/pca_obs/pchip_interpolation.py index 466946ecdce..e4a448eeecf 100755 --- a/mne/preprocessing/pca_obs/pchip_interpolation.py +++ b/mne/preprocessing/pca_obs/pchip_interpolation.py @@ -1,23 +1,18 @@ # Function to interpolate based on PCHIP rather than MNE inbuilt linear option +from typing import Any import numpy as np from scipy.interpolate import PchipInterpolator as pchip +from numpy.typing import NDArray # TODO: only place defined. is this used? -def PCHIP_interpolation(data, **kwargs): - # Check all necessary arguments sent in - required_kws = ["trigger_indices", "interpol_window_sec", "fs"] - if not all( - [kw in kwargs.keys() for kw in required_kws] - ): - raise RuntimeError("Some KWs not passed into PCA_OBS.") - - # Extract all kwargs - more elegant ways to do this - fs = kwargs["fs"] - interpol_window_sec = kwargs["interpol_window_sec"] - trigger_indices = kwargs["trigger_indices"] - +def PCHIP_interpolation( + data: NDArray[Any], + trigger_indices, + interpol_window_sec, + fs, +): # Convert intpol window to msec then convert to samples pre_window = round((interpol_window_sec[0] * 1000) * fs / 1000) # in samples post_window = round((interpol_window_sec[1] * 1000) * fs / 1000) # in samples diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py index 62c126ae6e0..f57221e0ed0 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -3,6 +3,13 @@ from matplotlib import pyplot as plt from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS +from mne.preprocessing.pca_obs import ESG_CHANS +from mne.preprocessing.pca_obs import ( + fs, + iv_baseline, + iv_epoch, + raw, +) from scipy.signal import firls import numpy as np from mne import Epochs, events_from_annotations @@ -10,60 +17,6 @@ from mne.preprocessing import find_ecg_events if __name__ == "__main__": - # Incredibly slow without parallelization - # Set variables - subject_id = "sub-001" - cond_name = "median" - fs = 1000 # sampling rate - esg_chans = [ - "S35", - "S24", - "S36", - "Iz", - "S17", - "S15", - "S32", - "S22", - "S19", - "S26", - "S28", - "S9", - "S13", - "S11", - "S7", - "SC1", - "S4", - "S18", - "S8", - "S31", - "SC6", - "S12", - "S16", - "S5", - "S30", - "S20", - "S34", - "AC", - "S21", - "S25", - "L1", - "S29", - "S14", - "S33", - "S3", - "AL", - "L4", - "S6", - "S23", - ] - # For heartbeat epochs - iv_baseline = [-300 / 1000, -200 / 1000] - iv_epoch = [-400 / 1000, 600 / 1000] - - # Setting paths - input_path = "/data/pt_02569/tmp_data/prepared_py/" + subject_id + "/esg/prepro/" - fname = f"noStimart_sr{fs}_{cond_name}_withqrs_pchip" - raw = read_raw_fif(input_path + fname + ".fif", preload=True) # Find ECG events - no ECG channel in data, uses synthetic ( @@ -97,12 +50,11 @@ # Apply function should modifies the data in raw in place raw.apply_function( PCA_OBS, - picks=esg_chans, - n_jobs=len(esg_chans), - **{ # args sent to PCA_OBS - "qrs": ecg_event_samples, - "filter_coords": fwts, - }, + picks=ESG_CHANS, + n_jobs=len(ESG_CHANS), + # args sent to PCA_OBS + qrs=ecg_event_samples, + filter_coords=fwts, ) epochs = Epochs( raw, diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py index d72bb058c2a..f1807829642 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py @@ -5,69 +5,23 @@ from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS from scipy.io import loadmat from scipy.signal import firls - +from mne.preprocessing.pca_obs import ESG_CHANS +from mne.preprocessing.pca_obs import ( + fs, + iv_baseline, + iv_epoch, + raw, +) from mne import Epochs, events_from_annotations from mne.io import read_raw_fif + if __name__ == "__main__": - # Incredibly slow without parallelization - # Set variables - subject_id = "sub-001" - cond_name = "median" - fs = 1000 # sampling rate - esg_chans = [ - "S35", - "S24", - "S36", - "Iz", - "S17", - "S15", - "S32", - "S22", - "S19", - "S26", - "S28", - "S9", - "S13", - "S11", - "S7", - "SC1", - "S4", - "S18", - "S8", - "S31", - "SC6", - "S12", - "S16", - "S5", - "S30", - "S20", - "S34", - "AC", - "S21", - "S25", - "L1", - "S29", - "S14", - "S33", - "S3", - "AL", - "L4", - "S6", - "S23", - ] - # For heartbeat epochs - iv_baseline = [-300 / 1000, -200 / 1000] - iv_epoch = [-400 / 1000, 600 / 1000] - # Setting paths - input_path = "/data/pt_02569/tmp_data/prepared_py/" + subject_id + "/esg/prepro/" - fname = f"noStimart_sr{fs}_{cond_name}_withqrs_pchip" - raw = read_raw_fif(input_path + fname + ".fif", preload=True) # Read .mat file with QRS events - input_path_m = "/data/pt_02569/tmp_data/prepared/" + subject_id + "/esg/prepro/" - fname_m = f"raw_{fs}_spinal_{cond_name}" + input_path_m = "/data/pt_02569/tmp_data/prepared/sub-001/esg/prepro/" + fname_m = f"raw_1000_spinal_median" matdata = loadmat(input_path_m + fname_m + ".mat") # Create filter coefficients @@ -93,12 +47,11 @@ # Apply function should modifies the data in raw in place raw.apply_function( PCA_OBS, - picks=esg_chans, - n_jobs=len(esg_chans) - **{ # args sent to PCA_OBS - "qrs": matdata["QRSevents"], - "filter_coords": fwts, - }, + picks=ESG_CHANS, + n_jobs=len(ESG_CHANS), + # args sent to PCA_OBS + qrs=matdata["QRSevents"], + filter_coords=fwts, ) epochs = Epochs( raw, From 61a92eaebb332c19204f0181ea95f118b9b975a8 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Fri, 25 Oct 2024 13:09:15 +0200 Subject: [PATCH 05/23] feat/initial-cleanup: Remove custom pchip as not in use --- .../pca_obs/pchip_interpolation.py | 51 ------------------- 1 file changed, 51 deletions(-) delete mode 100755 mne/preprocessing/pca_obs/pchip_interpolation.py diff --git a/mne/preprocessing/pca_obs/pchip_interpolation.py b/mne/preprocessing/pca_obs/pchip_interpolation.py deleted file mode 100755 index e4a448eeecf..00000000000 --- a/mne/preprocessing/pca_obs/pchip_interpolation.py +++ /dev/null @@ -1,51 +0,0 @@ -# Function to interpolate based on PCHIP rather than MNE inbuilt linear option - -from typing import Any -import numpy as np -from scipy.interpolate import PchipInterpolator as pchip -from numpy.typing import NDArray - -# TODO: only place defined. is this used? - -def PCHIP_interpolation( - data: NDArray[Any], - trigger_indices, - interpol_window_sec, - fs, -): - # Convert intpol window to msec then convert to samples - pre_window = round((interpol_window_sec[0] * 1000) * fs / 1000) # in samples - post_window = round((interpol_window_sec[1] * 1000) * fs / 1000) # in samples - intpol_window = np.ceil([pre_window, post_window]).astype( - int - ) # interpolation window - - n_samples_fit = ( - 5 # number of samples before and after cut used for interpolation fit - ) - - x_fit_raw = np.concatenate( - [ - np.arange(intpol_window[0] - n_samples_fit - 1, intpol_window[0], 1), - np.arange(intpol_window[1] + 1, intpol_window[1] + n_samples_fit + 2, 1), - ] - ) - x_interpol_raw = np.arange( - intpol_window[0], intpol_window[1] + 1, 1 - ) # points to be interpolated; in pt - - for ii in np.arange(0, len(trigger_indices)): # loop through all stimulation events - x_fit = trigger_indices[ii] + x_fit_raw # fit point latencies for this event - x_interpol = ( - trigger_indices[ii] + x_interpol_raw - ) # latencies for to-be-interpolated data points - - # Data is just a string of values - y_fit = data[x_fit] # y values to be fitted - y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation - data[x_interpol] = y_interpol # replace in data - - if np.mod(ii, 100) == 0: # talk to the operator every 100th trial - print(f"stimulation event {ii} \n") - - return data From a1eb6f6ee4c153914a140fd436cc8b86ac746fdb Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Fri, 25 Oct 2024 13:15:39 +0200 Subject: [PATCH 06/23] refactor/initial-cleanup: answer question --- mne/preprocessing/pca_obs/PCA_OBS.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py index cc226113de3..541302853d6 100755 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ b/mne/preprocessing/pca_obs/PCA_OBS.py @@ -13,7 +13,7 @@ def PCA_OBS( data: NDArray[Any], qrs: NDArray[Any], filter_coords: NDArray[Any], - **_ # TODO: are there any other kwargs passed in? + **_ # TODO: are there any other kwargs passed in? No - this is all I believe ): # Declare class to hold pca information class PCAInfo: From d75e92715950441257e1d51884bbd452c51330f2 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 4 Nov 2024 17:42:11 +0100 Subject: [PATCH 07/23] refactor: rename files to match other modules in preprocessing, reduce size of indentation tree --- mne/preprocessing/pca_obs/__init__.py | 13 ++- mne/preprocessing/pca_obs/fit_ecgTemplate.py | 87 ++++++++++--------- .../pca_obs/{PCA_OBS.py => pca_obs.py} | 3 +- .../rm_heart_artefact_cortical_mnedata.py | 4 +- ...rm_heart_artefact_spinal_impreciserpeak.py | 4 +- .../rm_heart_artefact_spinal_preciserpeak.py | 4 +- 6 files changed, 63 insertions(+), 52 deletions(-) rename mne/preprocessing/pca_obs/{PCA_OBS.py => pca_obs.py} (98%) diff --git a/mne/preprocessing/pca_obs/__init__.py b/mne/preprocessing/pca_obs/__init__.py index 054c4a3d2aa..dccbfab698a 100644 --- a/mne/preprocessing/pca_obs/__init__.py +++ b/mne/preprocessing/pca_obs/__init__.py @@ -1,6 +1,15 @@ -from dataclasses import dataclass -from mne.io import Raw, read_raw_fif +"""Principle Component Analysis of OBS (PCA-OBS).""" # TODO: What's OBS? +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +from mne.io import read_raw_fif +from .pca_obs import pca_obs + +__all__ = [ + "pca_obs" +] # TODO: description of what this is ESG_CHANS = [ diff --git a/mne/preprocessing/pca_obs/fit_ecgTemplate.py b/mne/preprocessing/pca_obs/fit_ecgTemplate.py index 3e5b7f55cb2..0aab0e15157 100755 --- a/mne/preprocessing/pca_obs/fit_ecgTemplate.py +++ b/mne/preprocessing/pca_obs/fit_ecgTemplate.py @@ -13,7 +13,7 @@ def fit_ecgTemplate( baseline_range, midP, fitted_art, - post_idx_previousPeak, + post_idx_previousPeak: list, n_samples_fit, ): # Declare class to hold ecg fit information @@ -22,6 +22,7 @@ def __init__(self): pass # Instantiate class + # TODO: Why are we storing this to a class? Can't we just use the variables and write to them? fitecg = fitECG() # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak @@ -53,46 +54,48 @@ def __init__(self): post_idx_nextPeak = [aPeak_idx[0] + post_range] - # Check it's not empty - if len(post_idx_previousPeak) != 0: - # interpolate time between peaks - intpol_window = np.ceil( - [post_idx_previousPeak[0], aPeak_idx[0] - pre_range] - ).astype("int") # interpolation window - fitecg.intpol_window = intpol_window - - if intpol_window[0] < intpol_window[1]: - # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data - - # You have x_fit which is two slices on either side of the interpolation window endpoints - # You have y_fit which is the y vals corresponding to x values above - # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate - # You have y_interpol which is values from pchip at the time points specified in x_interpol - x_interpol = np.arange( - intpol_window[0], intpol_window[1] + 1, 1 - ) # points to be interpolated in pt - the gap between the endpoints of the window - x_fit = np.concatenate( - [ - np.arange( - intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1 - ), - np.arange( - intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1 - ), - ] - ) # Entire range of x values in this step (taking some number of samples before and after the window) - y_fit = fitted_art[0, x_fit] - y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation - - # Then make fitted artefact in the desired range equal to the completed fit above - fitted_art[0, post_idx_previousPeak[0] : aPeak_idx[0] - pre_range + 1] = ( - y_interpol - ) - - fitecg.x_fit = x_fit - fitecg.y_fit = y_fit - fitecg.x_interpol = x_interpol - fitecg.y_interpol = y_interpol - fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop + # if last peak, return + if not post_idx_previousPeak: + return fitted_art, post_idx_nextPeak + + # interpolate time between peaks + intpol_window = np.ceil( + [post_idx_previousPeak[0], aPeak_idx[0] - pre_range] + ).astype("int") # interpolation window + fitecg.intpol_window = intpol_window + + if intpol_window[0] < intpol_window[1]: + # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data + + # You have x_fit which is two slices on either side of the interpolation window endpoints + # You have y_fit which is the y vals corresponding to x values above + # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate + # You have y_interpol which is values from pchip at the time points specified in x_interpol + x_interpol = np.arange( + intpol_window[0], intpol_window[1] + 1, 1 + ) # points to be interpolated in pt - the gap between the endpoints of the window + x_fit = np.concatenate( + [ + np.arange( + intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1 + ), + np.arange( + intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1 + ), + ] + ) # Entire range of x values in this step (taking some number of samples before and after the window) + y_fit = fitted_art[0, x_fit] + y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation + + # Then make fitted artefact in the desired range equal to the completed fit above + fitted_art[0, post_idx_previousPeak[0] : aPeak_idx[0] - pre_range + 1] = ( + y_interpol + ) + + fitecg.x_fit = x_fit + fitecg.y_fit = y_fit + fitecg.x_interpol = x_interpol + fitecg.y_interpol = y_interpol + fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop return fitted_art, post_idx_nextPeak diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/pca_obs.py similarity index 98% rename from mne/preprocessing/pca_obs/PCA_OBS.py rename to mne/preprocessing/pca_obs/pca_obs.py index 541302853d6..6fcf5882cbb 100755 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ b/mne/preprocessing/pca_obs/pca_obs.py @@ -9,11 +9,10 @@ from numpy.typing import NDArray -def PCA_OBS( +def pca_obs( data: NDArray[Any], qrs: NDArray[Any], filter_coords: NDArray[Any], - **_ # TODO: are there any other kwargs passed in? No - this is all I believe ): # Declare class to hold pca information class PCAInfo: diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py index 4df2262cc05..f0e01d6d8d4 100644 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import numpy as np -from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS +from mne.preprocessing.pca_obs import pca_obs from scipy.signal import firls from mne import Epochs @@ -54,7 +54,7 @@ # Apply function should modifies the data in raw in place raw.apply_function( - PCA_OBS, + pca_obs, picks="eeg", n_jobs=10 **{ # args sent to PCA_OBS diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py index f57221e0ed0..5183054dc21 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -2,7 +2,7 @@ # Analysis, Optimal Basis Sets) from matplotlib import pyplot as plt -from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS +from mne.preprocessing.pca_obs import pca_obs from mne.preprocessing.pca_obs import ESG_CHANS from mne.preprocessing.pca_obs import ( fs, @@ -49,7 +49,7 @@ # Apply function should modifies the data in raw in place raw.apply_function( - PCA_OBS, + pca_obs, picks=ESG_CHANS, n_jobs=len(ESG_CHANS), # args sent to PCA_OBS diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py index f1807829642..f561fab69a0 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py @@ -2,7 +2,7 @@ # Analysis, Optimal Basis Sets) from matplotlib import pyplot as plt -from mne.preprocessing.pca_obs.PCA_OBS import PCA_OBS +from mne.preprocessing.pca_obs import pca_obs from scipy.io import loadmat from scipy.signal import firls from mne.preprocessing.pca_obs import ESG_CHANS @@ -46,7 +46,7 @@ # Apply function should modifies the data in raw in place raw.apply_function( - PCA_OBS, + pca_obs, picks=ESG_CHANS, n_jobs=len(ESG_CHANS), # args sent to PCA_OBS From 035a9f60a3484483bb23428a2e860660f5f01f43 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 4 Nov 2024 17:48:18 +0100 Subject: [PATCH 08/23] refactor: remove unused fit_ecgTemplate variable 'baselinerange' --- mne/preprocessing/pca_obs/fit_ecgTemplate.py | 1 - mne/preprocessing/pca_obs/pca_obs.py | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mne/preprocessing/pca_obs/fit_ecgTemplate.py b/mne/preprocessing/pca_obs/fit_ecgTemplate.py index 0aab0e15157..51665ed4d03 100755 --- a/mne/preprocessing/pca_obs/fit_ecgTemplate.py +++ b/mne/preprocessing/pca_obs/fit_ecgTemplate.py @@ -10,7 +10,6 @@ def fit_ecgTemplate( peak_range, pre_range, post_range, - baseline_range, midP, fitted_art, post_idx_previousPeak: list, diff --git a/mne/preprocessing/pca_obs/pca_obs.py b/mne/preprocessing/pca_obs/pca_obs.py index 6fcf5882cbb..846c4a0eab3 100755 --- a/mne/preprocessing/pca_obs/pca_obs.py +++ b/mne/preprocessing/pca_obs/pca_obs.py @@ -9,6 +9,7 @@ from numpy.typing import NDArray +# TODO: Are we able to split this into smaller segmented functions? def pca_obs( data: NDArray[Any], qrs: NDArray[Any], @@ -19,6 +20,8 @@ class PCAInfo: def __init__(self): pass + # NOTE: Here aswell, is there a reason we are storing this + # to a class? Shouldn't variables suffice? # Instantiate class pca_info = PCAInfo() @@ -96,7 +99,7 @@ def __init__(self): pca_info.expl_var = pca.explained_variance_ratio_ # define selected number of components using profile likelihood - pca_info.nComponents = 4 + pca_info.nComponents = 4 # TODO: Is this a variable? Or constant? Seems like a variable pca_info.meanEffect = mean_effect.T nComponents = pca_info.nComponents @@ -127,7 +130,6 @@ def __init__(self): peak_range, pre_range, post_range, - baseline_range, midP, fitted_art, post_idx_nextPeak, @@ -154,7 +156,6 @@ def __init__(self): peak_range, pre_range, post_range, - baseline_range, midP, fitted_art, post_idx_nextPeak, @@ -187,7 +188,6 @@ def __init__(self): peak_range, pre_range, post_range, - baseline_range, midP, fitted_art, post_idx_nextPeak, From 2debb614afd4201644e6ec7953a288218d104dbf Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 4 Nov 2024 17:55:06 +0100 Subject: [PATCH 09/23] refactor: make main methods private, import from module __init__, remove more unused variables and imports, add some types --- mne/preprocessing/pca_obs/__init__.py | 6 ++-- ...it_ecgTemplate.py => _fit_ecg_template.py} | 29 +++++++++++++++++-- .../pca_obs/{pca_obs.py => _pca_obs.py} | 16 +++++----- ...rm_heart_artefact_spinal_impreciserpeak.py | 3 +- .../rm_heart_artefact_spinal_preciserpeak.py | 3 +- 5 files changed, 40 insertions(+), 17 deletions(-) rename mne/preprocessing/pca_obs/{fit_ecgTemplate.py => _fit_ecg_template.py} (80%) rename mne/preprocessing/pca_obs/{pca_obs.py => _pca_obs.py} (94%) diff --git a/mne/preprocessing/pca_obs/__init__.py b/mne/preprocessing/pca_obs/__init__.py index dccbfab698a..88f6276264b 100644 --- a/mne/preprocessing/pca_obs/__init__.py +++ b/mne/preprocessing/pca_obs/__init__.py @@ -5,10 +5,12 @@ # Copyright the MNE-Python contributors. from mne.io import read_raw_fif -from .pca_obs import pca_obs +from ._pca_obs import pca_obs +from ._fit_ecg_template import fit_ecg_template __all__ = [ - "pca_obs" + "pca_obs", + "fit_ecg_template" ] # TODO: description of what this is diff --git a/mne/preprocessing/pca_obs/fit_ecgTemplate.py b/mne/preprocessing/pca_obs/_fit_ecg_template.py similarity index 80% rename from mne/preprocessing/pca_obs/fit_ecgTemplate.py rename to mne/preprocessing/pca_obs/_fit_ecg_template.py index 51665ed4d03..631bd0598f2 100755 --- a/mne/preprocessing/pca_obs/fit_ecgTemplate.py +++ b/mne/preprocessing/pca_obs/_fit_ecg_template.py @@ -3,7 +3,7 @@ from scipy.signal import detrend -def fit_ecgTemplate( +def fit_ecg_template( data, pca_template, aPeak_idx, @@ -14,7 +14,32 @@ def fit_ecgTemplate( fitted_art, post_idx_previousPeak: list, n_samples_fit, -): +) -> tuple[np.ndarray, list]: + """TODO: Write docstring about what we do here. + Fits the ECG to a template signal (?) + and returns the fitted artefact and the index of the next peak. (?) + + .. note:: This should only be used on data which is ... (TODO: are there any conditions that must be met to use our algos?) + + + # TODO: Fill out input/output and raises + Parameters + ---------- + data (_type_): _description_ + pca_template (_type_): _description_ + aPeak_idx (_type_): _description_ + peak_range (_type_): _description_ + pre_range (_type_): _description_ + post_range (_type_): _description_ + midP (_type_): _description_ + fitted_art (_type_): _description_ + post_idx_previousPeak (list): _description_ + n_samples_fit (_type_): _description_ + + Returns + ------- + _type_: _description_ + """ # Declare class to hold ecg fit information class fitECG: def __init__(self): diff --git a/mne/preprocessing/pca_obs/pca_obs.py b/mne/preprocessing/pca_obs/_pca_obs.py similarity index 94% rename from mne/preprocessing/pca_obs/pca_obs.py rename to mne/preprocessing/pca_obs/_pca_obs.py index 846c4a0eab3..3c11e293bd7 100755 --- a/mne/preprocessing/pca_obs/pca_obs.py +++ b/mne/preprocessing/pca_obs/_pca_obs.py @@ -2,18 +2,17 @@ from typing import Any import numpy as np -from mne.preprocessing.pca_obs.fit_ecgTemplate import fit_ecgTemplate +from mne.preprocessing.pca_obs import fit_ecg_template from scipy.signal import detrend, filtfilt from sklearn.decomposition import PCA -from numpy.typing import NDArray # TODO: Are we able to split this into smaller segmented functions? def pca_obs( - data: NDArray[Any], - qrs: NDArray[Any], - filter_coords: NDArray[Any], + data: np.ndarray, + qrs: np.ndarray, + filter_coords: np.ndarray, ): # Declare class to hold pca information class PCAInfo: @@ -53,7 +52,6 @@ def __init__(self): mRR = np.median(RR) peak_range = round(mRR / 2) # Rounds to an integer midP = peak_range + 1 - baseline_range = [0, round(peak_range / 8)] n_samples_fit = round( peak_range / 8 ) # sample fit for interpolation between fitted artifact windows @@ -123,7 +121,7 @@ def __init__(self): post_range = peak_range try: post_idx_nextPeak = [] - fitted_art, post_idx_nextPeak = fit_ecgTemplate( + fitted_art, post_idx_nextPeak = fit_ecg_template( data, pca_template, peak_idx[p], @@ -149,7 +147,7 @@ def __init__(self): post_range = peak_range if pre_range > peak_range: pre_range = peak_range - fitted_art, _ = fit_ecgTemplate( + fitted_art, _ = fit_ecg_template( data, pca_template, peak_idx(p), @@ -181,7 +179,7 @@ def __init__(self): aTemplate = pca_template[ midP - peak_range - 1 : midP + peak_range + 1, : ] - fitted_art, post_idx_nextPeak = fit_ecgTemplate( + fitted_art, post_idx_nextPeak = fit_ecg_template( data, aTemplate, peak_idx[p], diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py index 5183054dc21..10dce937e21 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -1,4 +1,4 @@ -# Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component +# Calls PCA_OBS which in turn calls fit_ecg_template to remove the heart artefact via PCA_OBS (Principal Component # Analysis, Optimal Basis Sets) from matplotlib import pyplot as plt @@ -13,7 +13,6 @@ from scipy.signal import firls import numpy as np from mne import Epochs, events_from_annotations -from mne.io import read_raw_fif from mne.preprocessing import find_ecg_events if __name__ == "__main__": diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py index f561fab69a0..21723fc6c93 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py @@ -1,4 +1,4 @@ -# Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component +# Calls PCA_OBS which in turn calls fit_ecg_template to remove the heart artefact via PCA_OBS (Principal Component # Analysis, Optimal Basis Sets) from matplotlib import pyplot as plt @@ -13,7 +13,6 @@ raw, ) from mne import Epochs, events_from_annotations -from mne.io import read_raw_fif if __name__ == "__main__": From 97f9a15f572933f8917d560e7bc4f5bbeba9e541 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 4 Nov 2024 18:09:04 +0100 Subject: [PATCH 10/23] refactor: add docstring templates, gather imports --- mne/preprocessing/pca_obs/_fit_ecg_template.py | 6 +++--- mne/preprocessing/pca_obs/_pca_obs.py | 18 +++++++++++++++++- .../rm_heart_artefact_cortical_mnedata.py | 9 ++++----- .../rm_heart_artefact_spinal_impreciserpeak.py | 4 ++-- .../rm_heart_artefact_spinal_preciserpeak.py | 4 ++-- 5 files changed, 28 insertions(+), 13 deletions(-) diff --git a/mne/preprocessing/pca_obs/_fit_ecg_template.py b/mne/preprocessing/pca_obs/_fit_ecg_template.py index 631bd0598f2..23f5a512539 100755 --- a/mne/preprocessing/pca_obs/_fit_ecg_template.py +++ b/mne/preprocessing/pca_obs/_fit_ecg_template.py @@ -19,8 +19,8 @@ def fit_ecg_template( Fits the ECG to a template signal (?) and returns the fitted artefact and the index of the next peak. (?) - .. note:: This should only be used on data which is ... (TODO: are there any conditions that must be met to use our algos?) - + (TODO: are there any conditions that must be met to use our algos?) + .. note:: This should only be used on data which is ... # TODO: Fill out input/output and raises Parameters @@ -38,7 +38,7 @@ def fit_ecg_template( Returns ------- - _type_: _description_ + tuple[np.ndarray, list]: the fitted artifact and the next peak index (if available) """ # Declare class to hold ecg fit information class fitECG: diff --git a/mne/preprocessing/pca_obs/_pca_obs.py b/mne/preprocessing/pca_obs/_pca_obs.py index 3c11e293bd7..c6d9f752d69 100755 --- a/mne/preprocessing/pca_obs/_pca_obs.py +++ b/mne/preprocessing/pca_obs/_pca_obs.py @@ -13,7 +13,23 @@ def pca_obs( data: np.ndarray, qrs: np.ndarray, filter_coords: np.ndarray, -): +) -> np.ndarray: + """ + Algorithm to perform the PCA OBS (Principal Component Analysis, Optimal Basis Sets) + algorithm to remove the heart artefact from EEG data. + + .. note:: This should only be used on data which is ... (TODO: are there any conditions that must be met to use our algos?) + + Parameters + ---------- + data (np.ndarray): The data which we want to remove the heart artefact from. + qrs (np.ndarray): _description_ + filter_coords (np.ndarray): _description_ + + Returns + ------- + np.ndarray: The data with the heart artefact removed. + """ # Declare class to hold pca information class PCAInfo: def __init__(self): diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py index f0e01d6d8d4..10cb7772bbf 100644 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py @@ -56,11 +56,10 @@ raw.apply_function( pca_obs, picks="eeg", - n_jobs=10 - **{ # args sent to PCA_OBS - "qrs": ecg_event_samples, - "filter_coords": fwts, - }, + n_jobs=10, + # args sent to PCA_OBS + qrs=ecg_event_samples, + filter_coords=fwts, ) epochs = Epochs( raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline) diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py index 10dce937e21..0f0df7d6c75 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -2,13 +2,13 @@ # Analysis, Optimal Basis Sets) from matplotlib import pyplot as plt -from mne.preprocessing.pca_obs import pca_obs -from mne.preprocessing.pca_obs import ESG_CHANS from mne.preprocessing.pca_obs import ( fs, iv_baseline, iv_epoch, raw, + pca_obs, + ESG_CHANS, ) from scipy.signal import firls import numpy as np diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py index 21723fc6c93..2a195e85a40 100755 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py @@ -2,15 +2,15 @@ # Analysis, Optimal Basis Sets) from matplotlib import pyplot as plt -from mne.preprocessing.pca_obs import pca_obs from scipy.io import loadmat from scipy.signal import firls -from mne.preprocessing.pca_obs import ESG_CHANS from mne.preprocessing.pca_obs import ( fs, iv_baseline, iv_epoch, raw, + pca_obs, + ESG_CHANS, ) from mne import Epochs, events_from_annotations From 34a2e060bcbd5c64d1bbfce2a39b1b00881ca614 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 4 Nov 2024 18:41:21 +0100 Subject: [PATCH 11/23] test: add placeholder tests, add multiple todos to address where we get data from, how we call functions, how we assert outputs --- mne/preprocessing/pca_obs/tests/__init__.py | 0 .../pca_obs/tests/test_fit_ecg.py | 39 ++++++++++++++++ .../pca_obs/tests/test_pca_obs.py | 44 +++++++++++++++++++ 3 files changed, 83 insertions(+) create mode 100644 mne/preprocessing/pca_obs/tests/__init__.py create mode 100644 mne/preprocessing/pca_obs/tests/test_fit_ecg.py create mode 100644 mne/preprocessing/pca_obs/tests/test_pca_obs.py diff --git a/mne/preprocessing/pca_obs/tests/__init__.py b/mne/preprocessing/pca_obs/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/mne/preprocessing/pca_obs/tests/test_fit_ecg.py b/mne/preprocessing/pca_obs/tests/test_fit_ecg.py new file mode 100644 index 00000000000..eb5c80bc8d8 --- /dev/null +++ b/mne/preprocessing/pca_obs/tests/test_fit_ecg.py @@ -0,0 +1,39 @@ +"""Test the fot_ecg_template function.""" + +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +from mne.io import read_raw_fif +from mne.preprocessing.pca_obs import fit_ecg_template +from mne.datasets.testing import data_path, requires_testing_data + +# TODO: Where are the test files we want to use located? +fname = data_path(download=False) / "eyetrack" / "test_eyelink.asc" + + +@requires_testing_data +def test_fit_ecg_template(): + """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" + raw = read_raw_fif(fname) + + # Somehow have to "fake" all these inputs to the function + result = fit_ecg_template( + data=None, + pca_template=None, + aPeak_idx=None, + peak_range=None, + pre_range=None, + post_range=None, + midP=None, + fitted_art=None, + post_idx_previousPeak=None, + n_samples_fit=None, + ) + + # assert results + assert result is not None + assert result.shape == (100, 100) + assert result.shape == raw.shape # is this a condition we can test? + assert result[0, 0] == 1.0 + ... \ No newline at end of file diff --git a/mne/preprocessing/pca_obs/tests/test_pca_obs.py b/mne/preprocessing/pca_obs/tests/test_pca_obs.py new file mode 100644 index 00000000000..06d45062d6d --- /dev/null +++ b/mne/preprocessing/pca_obs/tests/test_pca_obs.py @@ -0,0 +1,44 @@ +"""Test the ieeg projection functions.""" + +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +# TODO: migrate this structure to test out function + +import pytest + +from mne.io import read_raw_fif +from mne.preprocessing.pca_obs import pca_obs +from mne.datasets.testing import data_path, requires_testing_data + +# TODO: Where are the test files we want to use located? +fname = data_path(download=False) / "eyetrack" / "test_eyelink.asc" + +@requires_testing_data +@pytest.mark.parametrize( + # TODO: Are there any parameters we can cycle through to + # test multiple? Different fs, windows, highpass freqs, etc.? + # TODO: how do we determine qrs and filter_coords? What are these? + "fs, highpass_freq, qrs, filter_coords", + [ + (0.2, 1.0, 100, 200), + (0.1, 2.0, 100, 200), + ], +) +def test_heart_artifact_removal(fs, highpass_freq, qrs, filter_coords): + """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" + raw = read_raw_fif(fname) + + # Do something with fs and highpass as processing of the data? + ... + + # call pca_obs algorithm + result = pca_obs(raw, qrs=qrs, filter_coords=filter_coords) + + # assert results + assert result is not None + assert result.shape == (100, 100) + assert result.shape == raw.shape # is this a condition we can test? + assert result[0, 0] == 1.0 + ... \ No newline at end of file From 3760c02efda91f709dbab4aaf1277ecbafce050c Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 11 Nov 2024 12:29:09 +0100 Subject: [PATCH 12/23] Removing extra examples --- .../rm_heart_artefact_cortical_mnedata.py | 86 ------------------- .../rm_heart_artefact_spinal_preciserpeak.py | 82 ------------------ 2 files changed, 168 deletions(-) delete mode 100644 mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py delete mode 100755 mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py deleted file mode 100644 index 10cb7772bbf..00000000000 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py +++ /dev/null @@ -1,86 +0,0 @@ -# Checking algorithm implementation with EEG data form MNE sample datasets -# Following this tutorial data: https://mne.tools/stable/auto_tutorials/preprocessing/50_artifact_correction_ssp.html#what-is-ssp - -import os - -import matplotlib.pyplot as plt -import numpy as np -from mne.preprocessing.pca_obs import pca_obs -from scipy.signal import firls - -from mne import Epochs -from mne.datasets.sample import data_path -from mne.io import read_raw_fif -from mne.preprocessing import ( - find_ecg_events, -) - -if __name__ == "__main__": - sample_data_folder = data_path(path="/data/pt_02569/mne_test_data/") - sample_data_raw_file = os.path.join( - sample_data_folder, "MEG", "sample", "sample_audvis_raw.fif" - ) - # here we crop and resample just for speed - raw = read_raw_fif(sample_data_raw_file, preload=True) - - # Find ECG events - no ECG channel in data, uses synthetic - ( - ecg_events, - ch_ecg, - average_pulse, - ) = find_ecg_events(raw) - # Extract just sample timings of ecg events - ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) - - # Create filter coefficients - fs = raw.info["sfreq"] - a = [0, 0, 1, 1] - f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter - # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter - ord = round(3 * fs / 0.5) - fwts = firls(ord + 1, f, a) - - # For heartbeat epochs - iv_baseline = [-300 / 1000, -200 / 1000] # 300 ms before to 200 ms after - iv_epoch = [-400 / 1000, 600 / 1000] # 300 ms before to 200 ms after - - # run PCA_OBS - # Algorithm is extremely sensitive to accurate R-peak timings, won't work as well with the artificial ECG - # channel estimation as we have here - epochs = Epochs( - raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline) - ) - evoked_before = epochs.average() - - # Apply function should modifies the data in raw in place - raw.apply_function( - pca_obs, - picks="eeg", - n_jobs=10, - # args sent to PCA_OBS - qrs=ecg_event_samples, - filter_coords=fwts, - ) - epochs = Epochs( - raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], baseline=tuple(iv_baseline) - ) - evoked_after = epochs.average() - - # Comparison image - fig, axes = plt.subplots(2, 1) - axes[0].plot(evoked_before.times, evoked_before.get_data().T) - axes[0].set_ylim([-1e-5, 3e-5]) - axes[0].set_title("Before PCA-OBS") - axes[1].plot(evoked_after.times, evoked_after.get_data().T) - axes[1].set_ylim([-1e-5, 3e-5]) - axes[1].set_title("After PCA-OBS") - plt.tight_layout() - - # Comparison image - fig, axes = plt.subplots(1, 1) - axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") - axes.set_ylim([-1e-5, 3e-5]) - axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") - axes.set_title("Before (black) versus after (green)") - plt.tight_layout() - plt.show() diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py deleted file mode 100755 index 2a195e85a40..00000000000 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ /dev/null @@ -1,82 +0,0 @@ -# Calls PCA_OBS which in turn calls fit_ecg_template to remove the heart artefact via PCA_OBS (Principal Component -# Analysis, Optimal Basis Sets) - -from matplotlib import pyplot as plt -from scipy.io import loadmat -from scipy.signal import firls -from mne.preprocessing.pca_obs import ( - fs, - iv_baseline, - iv_epoch, - raw, - pca_obs, - ESG_CHANS, -) -from mne import Epochs, events_from_annotations - - -if __name__ == "__main__": - - - # Read .mat file with QRS events - input_path_m = "/data/pt_02569/tmp_data/prepared/sub-001/esg/prepro/" - fname_m = f"raw_1000_spinal_median" - matdata = loadmat(input_path_m + fname_m + ".mat") - - # Create filter coefficients - a = [0, 0, 1, 1] - f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter - # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter - ord = round(3 * fs / 0.5) - fwts = firls(ord + 1, f, a) - - # run PCA_OBS - events, event_ids = events_from_annotations(raw) - event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} - epochs = Epochs( - raw, - events, - event_id=event_id_dict, - tmin=iv_epoch[0], - tmax=iv_epoch[1], - baseline=tuple(iv_baseline), - ) - evoked_before = epochs.average() - - # Apply function should modifies the data in raw in place - raw.apply_function( - pca_obs, - picks=ESG_CHANS, - n_jobs=len(ESG_CHANS), - # args sent to PCA_OBS - qrs=matdata["QRSevents"], - filter_coords=fwts, - ) - epochs = Epochs( - raw, - events, - event_id=event_id_dict, - tmin=iv_epoch[0], - tmax=iv_epoch[1], - baseline=tuple(iv_baseline), - ) - evoked_after = epochs.average() - - # Comparison image - fig, axes = plt.subplots(2, 1) - axes[0].plot(evoked_before.times, evoked_before.get_data().T) - axes[0].set_ylim([-0.0005, 0.001]) - axes[0].set_title("Before PCA-OBS") - axes[1].plot(evoked_after.times, evoked_after.get_data().T) - axes[1].set_ylim([-0.0005, 0.001]) - axes[1].set_title("After PCA-OBS") - plt.tight_layout() - - # Comparison image - fig, axes = plt.subplots(1, 1) - axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") - axes.set_ylim([-0.0005, 0.001]) - axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") - axes.set_title("Before (black) versus after (green)") - plt.tight_layout() - plt.show() From 0a55712678b21c63a54c94d7a32710046e929579 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 11 Nov 2024 12:36:02 +0100 Subject: [PATCH 13/23] Moving example --- .../preprocessing/esg_rm_heart_artefact_pcaobs.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py => examples/preprocessing/esg_rm_heart_artefact_pcaobs.py (100%) diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py similarity index 100% rename from mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py rename to examples/preprocessing/esg_rm_heart_artefact_pcaobs.py From e887ace5c40f89e0fe52e1d6ff4a68a5932ea98e Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 11 Nov 2024 14:20:52 +0100 Subject: [PATCH 14/23] Adding Niazy reference --- doc/references.bib | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/doc/references.bib b/doc/references.bib index a129d2f46a2..e2578ed18f2 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -1335,6 +1335,16 @@ @inproceedings{NdiayeEtAl2016 year = {2016} } +@article{NiazyEtAl2005, + author = {Niazy, R. K. and Beckmann, C.F. and Iannetti, G.D. and Brady, J. M. and Smith, S. M.}, + title = {Removal of FMRI environment artifacts from EEG data using optimal basis sets}, + journal = {NeuroImage}, + year = {2005}, + volume = {28}, + pages = {720-737}, + doi = {10.1016/j.neuroimage.2005.06.067.} +} + @article{NicholsHolmes2002, author = {Nichols, Thomas E. and Holmes, Andrew P.}, doi = {10.1002/hbm.1058}, From ee3b73eb59e7f6396edf7ef127f970d81826af58 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 11 Nov 2024 16:13:12 +0100 Subject: [PATCH 15/23] Update example, put pca-obs in a single file --- .../esg_rm_heart_artefact_pcaobs.py | 252 ++++++++++++------ mne/preprocessing/pca_obs/__init__.py | 62 +---- .../pca_obs/_fit_ecg_template.py | 125 --------- mne/preprocessing/pca_obs/_pca_obs.py | 133 ++++++++- 4 files changed, 296 insertions(+), 276 deletions(-) delete mode 100755 mne/preprocessing/pca_obs/_fit_ecg_template.py diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index 0f0df7d6c75..e3179293e9e 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -1,85 +1,175 @@ -# Calls PCA_OBS which in turn calls fit_ecg_template to remove the heart artefact via PCA_OBS (Principal Component -# Analysis, Optimal Basis Sets) +""" +.. _ex-pcaobs: + +============================================================================================== +Principal Component Analysis - Optimal Basis Sets (PCA-OBS) for removal of cardiac artefact +============================================================================================== + +This script shows an example of how to use an adaptation of PCA-OBS +:footcite:`NiazyEtAl2005`. PCA-OBS was originally designed to remove +the ballistocardiographic artefact in simultaneous EEG-fMRI. Here, it +has been adapted to remove the delay between the detected R-peak and the +ballistocardiographic artefact such that the algorithm can be applied to +remove the cardiac artefact in EEG (electroencephalogrpahy) and ESG +(electrospinography) data. We will illustrate how it works by applying the +algorithm to ESG data, where the effect of removal is most pronounced. + +See: https://www.biorxiv.org/content/10.1101/2024.09.05.611423v1 +for more details on the dataset and application for ESG data. + +""" + +# Authors: Emma Bailey , Steinn Hauser Magnusson +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. from matplotlib import pyplot as plt -from mne.preprocessing.pca_obs import ( - fs, - iv_baseline, - iv_epoch, - raw, - pca_obs, - ESG_CHANS, -) +from mne.preprocessing.pca_obs import pca_obs +from mne.preprocessing import find_ecg_events, fix_stim_artifact +from mne.io import read_raw_eeglab from scipy.signal import firls import numpy as np -from mne import Epochs, events_from_annotations -from mne.preprocessing import find_ecg_events - -if __name__ == "__main__": - - # Find ECG events - no ECG channel in data, uses synthetic - ( - ecg_events, - ch_ecg, - average_pulse, - ) = find_ecg_events(raw, ch_name="ECG") - # Extract just sample timings of ecg events - ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) - - # Create filter coefficients - a = [0, 0, 1, 1] - f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter - # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter - ord = round(3 * fs / 0.5) - fwts = firls(ord + 1, f, a) - - # run PCA_OBS - events, event_ids = events_from_annotations(raw) - event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} - epochs = Epochs( - raw, - events, - event_id=event_id_dict, - tmin=iv_epoch[0], - tmax=iv_epoch[1], - baseline=tuple(iv_baseline), - ) - evoked_before = epochs.average() - - # Apply function should modifies the data in raw in place - raw.apply_function( - pca_obs, - picks=ESG_CHANS, - n_jobs=len(ESG_CHANS), - # args sent to PCA_OBS - qrs=ecg_event_samples, - filter_coords=fwts, - ) - epochs = Epochs( - raw, - events, - event_id=event_id_dict, - tmin=iv_epoch[0], - tmax=iv_epoch[1], - baseline=tuple(iv_baseline), - ) - evoked_after = epochs.average() - - # Comparison image - fig, axes = plt.subplots(2, 1) - axes[0].plot(evoked_before.times, evoked_before.get_data().T) - axes[0].set_ylim([-0.0005, 0.001]) - axes[0].set_title("Before PCA-OBS") - axes[1].plot(evoked_after.times, evoked_after.get_data().T) - axes[1].set_ylim([-0.0005, 0.001]) - axes[1].set_title("After PCA-OBS") - plt.tight_layout() - - # Comparison image - fig, axes = plt.subplots(1, 1) - axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") - axes.set_ylim([-0.0005, 0.001]) - axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") - axes.set_title("Before (black) versus after (green)") - plt.tight_layout() - plt.show() +from mne import Epochs, events_from_annotations, concatenate_raws + +############################################################################### +# Download sample subject data from OpenNeuro if you haven't already +# This will download simultaneous EEG and ESG data from a single participant after +# median nerve stimulation of the left wrist +# Set the target directory to your desired location +import openneuro as on +import glob +target_dir = '/data/pt_02569/test_data' +file_list = glob.glob(target_dir + '/sub-001/eeg/*median*.set') +if file_list: + print('Data is already downloaded') +else: + on.download(dataset='ds004388', target_dir=target_dir, include='sub-001/*median*_eeg*') + +############################################################################### +# Define the esg channels (arranged in two patches over the neck and lower back) +# Also include the ECG channel for artefact correction +esg_chans = ["S35", "S24", "S36", "Iz", "S17", "S15", "S32", "S22", "S19", "S26", "S28", + "S9", "S13", "S11", "S7", "SC1", "S4", "S18", "S8", "S31", "SC6", "S12", + "S16", "S5", "S30", "S20", "S34", "S21", "S25", "L1", "S29", "S14", "S33", + "S3", "L4", "S6", "S23", 'ECG'] + +# Sampling rate +fs = 1000 + +# Interpolation window for ESG data to remove stimulation artefact +tstart_esg = -0.007 +tmax_esg = 0.007 + +# Define timing of heartbeat epochs +iv_baseline = [-300 / 1000, -200 / 1000] +iv_epoch = [-400 / 1000, 600 / 1000] + +############################################################################### +# Read in each of the four blocks and concatenate the raw structures after performing +# some minimal preprocessing including removing the stimulation artefact, downsampling +# and filtering +block_files = glob.glob(target_dir + '/sub-001/eeg/*median*.set') +block_files = sorted(block_files) + +for count, block_file in enumerate(block_files): + raw = read_raw_eeglab(block_file, eog=(), preload=True, uint16_codec=None, verbose=None) + + # Isolate the ESG channels only + raw.pick(esg_chans) + + # Find trigger timings to remove the stimulation artefact + events, event_dict = events_from_annotations(raw) + trigger_name = 'Median - Stimulation' + + fix_stim_artifact(raw, events=events, event_id=trigger_name, tmin=tstart_esg, tmax=tmax_esg, mode='linear', + stim_channel=None) + + # Downsample the data + raw.resample(fs) + + # Append blocks of the same condition + if count == 0: + raw_concat = raw + else: + concatenate_raws([raw_concat, raw]) + +############################################################################### +# Find ECG events and add to the raw structure as event annotations +ecg_events, ch_ecg, average_pulse = find_ecg_events(raw_concat, ch_name="ECG") +ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) # Samples only + +qrs_event_time = [x / fs for x in ecg_event_samples.reshape(-1)] # Divide by sampling rate to make times +duration = np.repeat(0.0, len(ecg_event_samples)) +description = ['qrs'] * len(ecg_event_samples) + +print(ecg_event_samples) +print(qrs_event_time) +raw_concat.annotations.append(qrs_event_time, duration, description, ch_names=[esg_chans]*len(qrs_event_time)) + +############################################################################### +# Create filter coefficients +a = [0, 0, 1, 1] +f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter +ord = round(3 * fs / 0.5) +fwts = firls(ord + 1, f, a) + +############################################################################### +# Create evoked response about the detected R-peaks before cardiac artefact correction +# Apply PCA-OBS to remove the cardiac artefact +# Create evoked response about the detected R-peaks after cardiac artefact correction +events, event_ids = events_from_annotations(raw_concat) +event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} +epochs = Epochs( + raw_concat, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), +) +evoked_before = epochs.average() + +# Apply function - modifies the data in place +raw_concat.apply_function( + pca_obs, + picks=esg_chans, + n_jobs=len(esg_chans), + # args sent to PCA_OBS + qrs=ecg_event_samples, + filter_coords=fwts, +) + +epochs = Epochs( + raw_concat, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), +) +evoked_after = epochs.average() + +############################################################################### +# Comparison image +fig, axes = plt.subplots(2, 1) +axes[0].plot(evoked_before.times, evoked_before.get_data().T) +axes[0].set_ylim([-0.0005, 0.001]) +axes[0].set_title("Before PCA-OBS") +axes[1].plot(evoked_after.times, evoked_after.get_data().T) +axes[1].set_ylim([-0.0005, 0.001]) +axes[1].set_title("After PCA-OBS") +plt.tight_layout() + +# Comparison image +fig, axes = plt.subplots(1, 1) +axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") +axes.set_ylim([-0.0005, 0.001]) +axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") +axes.set_title("Before (black) versus after (green)") +plt.tight_layout() +plt.show() + +# %% +# References +# ---------- +# .. footbibliography:: diff --git a/mne/preprocessing/pca_obs/__init__.py b/mne/preprocessing/pca_obs/__init__.py index 88f6276264b..cfa48b95fec 100644 --- a/mne/preprocessing/pca_obs/__init__.py +++ b/mne/preprocessing/pca_obs/__init__.py @@ -1,69 +1,11 @@ -"""Principle Component Analysis of OBS (PCA-OBS).""" # TODO: What's OBS? +"""Principle Component Analysis Optimal Basis Sets (PCA-OBS).""" # Authors: The MNE-Python contributors. # License: BSD-3-Clause # Copyright the MNE-Python contributors. -from mne.io import read_raw_fif from ._pca_obs import pca_obs -from ._fit_ecg_template import fit_ecg_template __all__ = [ - "pca_obs", - "fit_ecg_template" + "pca_obs" ] - -# TODO: description of what this is -ESG_CHANS = [ - "S35", - "S24", - "S36", - "Iz", - "S17", - "S15", - "S32", - "S22", - "S19", - "S26", - "S28", - "S9", - "S13", - "S11", - "S7", - "SC1", - "S4", - "S18", - "S8", - "S31", - "SC6", - "S12", - "S16", - "S5", - "S30", - "S20", - "S34", - "AC", - "S21", - "S25", - "L1", - "S29", - "S14", - "S33", - "S3", - "AL", - "L4", - "S6", - "S23", -] - -# Set variables -fs = 1000 # sampling rate - -# For heartbeat epochs -iv_baseline = [-300 / 1000, -200 / 1000] -iv_epoch = [-400 / 1000, 600 / 1000] - -# Setting paths -input_path = "/data/pt_02569/tmp_data/prepared_py/sub-001/esg/prepro/" -fname = f"noStimart_sr{fs}_median_withqrs_pchip" -raw = read_raw_fif(input_path + fname + ".fif", preload=True) diff --git a/mne/preprocessing/pca_obs/_fit_ecg_template.py b/mne/preprocessing/pca_obs/_fit_ecg_template.py deleted file mode 100755 index 23f5a512539..00000000000 --- a/mne/preprocessing/pca_obs/_fit_ecg_template.py +++ /dev/null @@ -1,125 +0,0 @@ -import numpy as np -from scipy.interpolate import PchipInterpolator as pchip -from scipy.signal import detrend - - -def fit_ecg_template( - data, - pca_template, - aPeak_idx, - peak_range, - pre_range, - post_range, - midP, - fitted_art, - post_idx_previousPeak: list, - n_samples_fit, -) -> tuple[np.ndarray, list]: - """TODO: Write docstring about what we do here. - Fits the ECG to a template signal (?) - and returns the fitted artefact and the index of the next peak. (?) - - (TODO: are there any conditions that must be met to use our algos?) - .. note:: This should only be used on data which is ... - - # TODO: Fill out input/output and raises - Parameters - ---------- - data (_type_): _description_ - pca_template (_type_): _description_ - aPeak_idx (_type_): _description_ - peak_range (_type_): _description_ - pre_range (_type_): _description_ - post_range (_type_): _description_ - midP (_type_): _description_ - fitted_art (_type_): _description_ - post_idx_previousPeak (list): _description_ - n_samples_fit (_type_): _description_ - - Returns - ------- - tuple[np.ndarray, list]: the fitted artifact and the next peak index (if available) - """ - # Declare class to hold ecg fit information - class fitECG: - def __init__(self): - pass - - # Instantiate class - # TODO: Why are we storing this to a class? Can't we just use the variables and write to them? - fitecg = fitECG() - - # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak - # Then nextpeak is returned at the end and the process repeats - # select window of template - template = pca_template[midP - peak_range - 1 : midP + peak_range + 1, :] - - # select window of data and detrend it - slice = data[0, aPeak_idx[0] - peak_range : aPeak_idx[0] + peak_range + 1] - detrended_data = detrend(slice.reshape(-1), type="constant") - - # maps data on template and then maps it again back to the sensor space - least_square = np.linalg.lstsq(template, detrended_data, rcond=None) - pad_fit = np.dot(template, least_square[0]) - - # fit artifact, I already loop through externally channel to channel - fitted_art[0, aPeak_idx[0] - pre_range - 1 : aPeak_idx[0] + post_range] = pad_fit[ - midP - pre_range - 1 : midP + post_range - ].T - - fitecg.fitted_art = fitted_art - fitecg.template = template - fitecg.detrended_data = detrended_data - fitecg.pad_fit = pad_fit - fitecg.aPeak_idx = aPeak_idx - fitecg.midP = midP - fitecg.peak_range = peak_range - fitecg.data = data - - post_idx_nextPeak = [aPeak_idx[0] + post_range] - - # if last peak, return - if not post_idx_previousPeak: - return fitted_art, post_idx_nextPeak - - # interpolate time between peaks - intpol_window = np.ceil( - [post_idx_previousPeak[0], aPeak_idx[0] - pre_range] - ).astype("int") # interpolation window - fitecg.intpol_window = intpol_window - - if intpol_window[0] < intpol_window[1]: - # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data - - # You have x_fit which is two slices on either side of the interpolation window endpoints - # You have y_fit which is the y vals corresponding to x values above - # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate - # You have y_interpol which is values from pchip at the time points specified in x_interpol - x_interpol = np.arange( - intpol_window[0], intpol_window[1] + 1, 1 - ) # points to be interpolated in pt - the gap between the endpoints of the window - x_fit = np.concatenate( - [ - np.arange( - intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1 - ), - np.arange( - intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1 - ), - ] - ) # Entire range of x values in this step (taking some number of samples before and after the window) - y_fit = fitted_art[0, x_fit] - y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation - - # Then make fitted artefact in the desired range equal to the completed fit above - fitted_art[0, post_idx_previousPeak[0] : aPeak_idx[0] - pre_range + 1] = ( - y_interpol - ) - - fitecg.x_fit = x_fit - fitecg.y_fit = y_fit - fitecg.x_interpol = x_interpol - fitecg.y_interpol = y_interpol - fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop - - return fitted_art, post_idx_nextPeak diff --git a/mne/preprocessing/pca_obs/_pca_obs.py b/mne/preprocessing/pca_obs/_pca_obs.py index c6d9f752d69..c3f2178f172 100755 --- a/mne/preprocessing/pca_obs/_pca_obs.py +++ b/mne/preprocessing/pca_obs/_pca_obs.py @@ -2,10 +2,131 @@ from typing import Any import numpy as np -from mne.preprocessing.pca_obs import fit_ecg_template - from scipy.signal import detrend, filtfilt from sklearn.decomposition import PCA +from scipy.interpolate import PchipInterpolator as pchip +from scipy.signal import detrend + +def fit_ecg_template( + data, + pca_template, + aPeak_idx, + peak_range, + pre_range, + post_range, + midP, + fitted_art, + post_idx_previousPeak: list, + n_samples_fit, +) -> tuple[np.ndarray, list]: + """TODO: Write docstring about what we do here. + Fits the ECG to a template signal (?) + and returns the fitted artefact and the index of the next peak. (?) + + (TODO: are there any conditions that must be met to use our algos?) + .. note:: This should only be used on data which is ... + + # TODO: Fill out input/output and raises + Parameters + ---------- + data (_type_): _description_ + pca_template (_type_): _description_ + aPeak_idx (_type_): _description_ + peak_range (_type_): _description_ + pre_range (_type_): _description_ + post_range (_type_): _description_ + midP (_type_): _description_ + fitted_art (_type_): _description_ + post_idx_previousPeak (list): _description_ + n_samples_fit (_type_): _description_ + + Returns + ------- + tuple[np.ndarray, list]: the fitted artifact and the next peak index (if available) + """ + # Declare class to hold ecg fit information + class fitECG: + def __init__(self): + pass + + # Instantiate class + # TODO: Why are we storing this to a class? Can't we just use the variables and write to them? + fitecg = fitECG() + + # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak + # Then nextpeak is returned at the end and the process repeats + # select window of template + template = pca_template[midP - peak_range - 1 : midP + peak_range + 1, :] + + # select window of data and detrend it + slice = data[0, aPeak_idx[0] - peak_range : aPeak_idx[0] + peak_range + 1] + detrended_data = detrend(slice.reshape(-1), type="constant") + + # maps data on template and then maps it again back to the sensor space + least_square = np.linalg.lstsq(template, detrended_data, rcond=None) + pad_fit = np.dot(template, least_square[0]) + + # fit artifact, I already loop through externally channel to channel + fitted_art[0, aPeak_idx[0] - pre_range - 1 : aPeak_idx[0] + post_range] = pad_fit[ + midP - pre_range - 1 : midP + post_range + ].T + + fitecg.fitted_art = fitted_art + fitecg.template = template + fitecg.detrended_data = detrended_data + fitecg.pad_fit = pad_fit + fitecg.aPeak_idx = aPeak_idx + fitecg.midP = midP + fitecg.peak_range = peak_range + fitecg.data = data + + post_idx_nextPeak = [aPeak_idx[0] + post_range] + + # if last peak, return + if not post_idx_previousPeak: + return fitted_art, post_idx_nextPeak + + # interpolate time between peaks + intpol_window = np.ceil( + [post_idx_previousPeak[0], aPeak_idx[0] - pre_range] + ).astype("int") # interpolation window + fitecg.intpol_window = intpol_window + + if intpol_window[0] < intpol_window[1]: + # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data + + # You have x_fit which is two slices on either side of the interpolation window endpoints + # You have y_fit which is the y vals corresponding to x values above + # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate + # You have y_interpol which is values from pchip at the time points specified in x_interpol + x_interpol = np.arange( + intpol_window[0], intpol_window[1] + 1, 1 + ) # points to be interpolated in pt - the gap between the endpoints of the window + x_fit = np.concatenate( + [ + np.arange( + intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1 + ), + np.arange( + intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1 + ), + ] + ) # Entire range of x values in this step (taking some number of samples before and after the window) + y_fit = fitted_art[0, x_fit] + y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation + + # Then make fitted artefact in the desired range equal to the completed fit above + fitted_art[0, post_idx_previousPeak[0] : aPeak_idx[0] - pre_range + 1] = ( + y_interpol + ) + + fitecg.x_fit = x_fit + fitecg.y_fit = y_fit + fitecg.x_interpol = x_interpol + fitecg.y_interpol = y_interpol + fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop + + return fitted_art, post_idx_nextPeak # TODO: Are we able to split this into smaller segmented functions? @@ -213,17 +334,9 @@ def __init__(self): print(f"Cannot fit middle section of data. Reason: {e}") # Actually subtract the artefact, return needs to be the same shape as input data - # One sample shift purely due to the fact the r-peaks are currently detected in MATLAB data = data.reshape(-1) fitted_art = fitted_art.reshape(-1) - # One sample shift for my actual data (introduced using matlab r timings) - # data_ = np.zeros(len(data)) - # data_[0] = data[0] - # data_[1:] = data[1:] - fitted_art[:-1] - # data = data_ - - # Original code is this: data -= fitted_art data = data.T.reshape(-1) From adae35261cd5efd0eb54290d4cd093dd0f83b176 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 11 Nov 2024 16:16:34 +0100 Subject: [PATCH 16/23] TODO --- mne/preprocessing/pca_obs/_pca_obs.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mne/preprocessing/pca_obs/_pca_obs.py b/mne/preprocessing/pca_obs/_pca_obs.py index c3f2178f172..e987abbe208 100755 --- a/mne/preprocessing/pca_obs/_pca_obs.py +++ b/mne/preprocessing/pca_obs/_pca_obs.py @@ -7,6 +7,9 @@ from scipy.interpolate import PchipInterpolator as pchip from scipy.signal import detrend +# TODO: This needs to be pulled out of the subfolder we've created and moved into the more 'normal' MNE setup +# with the _pca_obs in preprocessing as a single file only, _init integrated in their __init__.py and .pyi + def fit_ecg_template( data, pca_template, From b3c1b4e60f0e4a19642aeac683b6d9a4ebe24aed Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Mon, 11 Nov 2024 16:56:56 +0100 Subject: [PATCH 17/23] Update example --- .../esg_rm_heart_artefact_pcaobs.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index e3179293e9e..e24d5f42212 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -61,7 +61,7 @@ tmax_esg = 0.007 # Define timing of heartbeat epochs -iv_baseline = [-300 / 1000, -200 / 1000] +iv_baseline = [-400 / 1000, -300 / 1000] iv_epoch = [-400 / 1000, 600 / 1000] ############################################################################### @@ -81,7 +81,7 @@ events, event_dict = events_from_annotations(raw) trigger_name = 'Median - Stimulation' - fix_stim_artifact(raw, events=events, event_id=trigger_name, tmin=tstart_esg, tmax=tmax_esg, mode='linear', + fix_stim_artifact(raw, events=events, event_id=event_dict[trigger_name], tmin=tstart_esg, tmax=tmax_esg, mode='linear', stim_channel=None) # Downsample the data @@ -102,8 +102,6 @@ duration = np.repeat(0.0, len(ecg_event_samples)) description = ['qrs'] * len(ecg_event_samples) -print(ecg_event_samples) -print(qrs_event_time) raw_concat.annotations.append(qrs_event_time, duration, description, ch_names=[esg_chans]*len(qrs_event_time)) ############################################################################### @@ -150,16 +148,6 @@ evoked_after = epochs.average() ############################################################################### -# Comparison image -fig, axes = plt.subplots(2, 1) -axes[0].plot(evoked_before.times, evoked_before.get_data().T) -axes[0].set_ylim([-0.0005, 0.001]) -axes[0].set_title("Before PCA-OBS") -axes[1].plot(evoked_after.times, evoked_after.get_data().T) -axes[1].set_ylim([-0.0005, 0.001]) -axes[1].set_title("After PCA-OBS") -plt.tight_layout() - # Comparison image fig, axes = plt.subplots(1, 1) axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") From e6aa17d2c154bc85966f1fb2a29db03c41c29f69 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Fri, 15 Nov 2024 13:48:57 +0100 Subject: [PATCH 18/23] refactor: change way of calling pca obs method, add comments --- .../esg_rm_heart_artefact_pcaobs.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index e24d5f42212..da8583c9901 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -24,7 +24,7 @@ # Copyright the MNE-Python contributors. from matplotlib import pyplot as plt -from mne.preprocessing.pca_obs import pca_obs +import mne from mne.preprocessing import find_ecg_events, fix_stim_artifact from mne.io import read_raw_eeglab from scipy.signal import firls @@ -38,7 +38,11 @@ # Set the target directory to your desired location import openneuro as on import glob + +# add the path where you want the OpenNeuro data downloaded. Files total around 8 GB +# target_dir = "/home/steinnhm/personal/mne-data" target_dir = '/data/pt_02569/test_data' + file_list = glob.glob(target_dir + '/sub-001/eeg/*median*.set') if file_list: print('Data is already downloaded') @@ -128,13 +132,12 @@ evoked_before = epochs.average() # Apply function - modifies the data in place -raw_concat.apply_function( - pca_obs, - picks=esg_chans, - n_jobs=len(esg_chans), - # args sent to PCA_OBS - qrs=ecg_event_samples, - filter_coords=fwts, +mne.preprocessing.apply_pca_obs( + raw_concat, + picks=esg_chans, + n_jobs=4, + qrs=ecg_event_samples, + filter_coords=fwts ) epochs = Epochs( From b150b6c4419cbe02492bf45cc7cd05f382c67bbc Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Fri, 15 Nov 2024 13:50:19 +0100 Subject: [PATCH 19/23] refactor: move pca obs method out of separate python module, change logging to use mne logger instead of prints, add wrapper method in front of private _pca_obs method to handle parallel processing --- mne/preprocessing/__init__.pyi | 2 + .../{pca_obs/_pca_obs.py => pca_obs.py} | 44 ++++++++++++++----- mne/preprocessing/pca_obs/__init__.py | 11 ----- mne/preprocessing/pca_obs/tests/__init__.py | 0 .../pca_obs/tests/test_fit_ecg.py | 39 ---------------- .../{pca_obs => }/tests/test_pca_obs.py | 0 6 files changed, 35 insertions(+), 61 deletions(-) rename mne/preprocessing/{pca_obs/_pca_obs.py => pca_obs.py} (91%) delete mode 100644 mne/preprocessing/pca_obs/__init__.py delete mode 100644 mne/preprocessing/pca_obs/tests/__init__.py delete mode 100644 mne/preprocessing/pca_obs/tests/test_fit_ecg.py rename mne/preprocessing/{pca_obs => }/tests/test_pca_obs.py (100%) diff --git a/mne/preprocessing/__init__.pyi b/mne/preprocessing/__init__.pyi index 54f1c825c13..d58c6e77d24 100644 --- a/mne/preprocessing/__init__.pyi +++ b/mne/preprocessing/__init__.pyi @@ -44,6 +44,7 @@ __all__ = [ "realign_raw", "regress_artifact", "write_fine_calibration", + "apply_pca_obs", ] from . import eyetracking, ieeg, nirs from ._annotate_amplitude import annotate_amplitude @@ -89,3 +90,4 @@ from .realign import realign_raw from .ssp import compute_proj_ecg, compute_proj_eog from .stim import fix_stim_artifact from .xdawn import Xdawn +from .pca_obs import apply_pca_obs \ No newline at end of file diff --git a/mne/preprocessing/pca_obs/_pca_obs.py b/mne/preprocessing/pca_obs.py similarity index 91% rename from mne/preprocessing/pca_obs/_pca_obs.py rename to mne/preprocessing/pca_obs.py index e987abbe208..0b00607262f 100755 --- a/mne/preprocessing/pca_obs/_pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -1,5 +1,10 @@ +"""Principle Component Analysis Optimal Basis Sets (PCA-OBS).""" + +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + import math -from typing import Any import numpy as np from scipy.signal import detrend, filtfilt @@ -7,6 +12,9 @@ from scipy.interpolate import PchipInterpolator as pchip from scipy.signal import detrend +from mne.io.fiff.raw import Raw +from mne.utils import logger, warn + # TODO: This needs to be pulled out of the subfolder we've created and moved into the more 'normal' MNE setup # with the _pca_obs in preprocessing as a single file only, _init integrated in their __init__.py and .pyi @@ -132,9 +140,19 @@ def __init__(self): return fitted_art, post_idx_nextPeak +def apply_pca_obs(raw: Raw, picks: list[str], n_jobs: int, qrs: np.ndarray, filter_coords: np.ndarray) -> None: + raw.apply_function( + _pca_obs, + picks=picks, + n_jobs=n_jobs, + # args sent to PCA_OBS + qrs=qrs, + filter_coords=filter_coords, + ) + # TODO: Are we able to split this into smaller segmented functions? -def pca_obs( - data: np.ndarray, +def _pca_obs( + data: np.ndarray, qrs: np.ndarray, filter_coords: np.ndarray, ) -> np.ndarray: @@ -146,9 +164,13 @@ def pca_obs( Parameters ---------- - data (np.ndarray): The data which we want to remove the heart artefact from. - qrs (np.ndarray): _description_ - filter_coords (np.ndarray): _description_ + data: ndarray, shape (n_channels, n_times) + The data which we want to remove the heart artefact from. + + qrs: ndarray, shape (n_peaks, 1) + Array of times in (s), of detected R-peaks in ECG channel. + + filter_coords: ndarray Returns ------- @@ -185,7 +207,7 @@ def __init__(self): ################################################################ # Preparatory work - reserving memory, configure sizes, de-trend ################################################################ - print("Pulse artifact subtraction in progress...Please wait!") + logger.info("Pulse artifact subtraction in progress...Please wait!") # define peak range based on RR RR = np.diff(peak_idx[:, 0]) @@ -277,11 +299,11 @@ def __init__(self): window_start_idx.append(peak_idx[p] - peak_range) window_end_idx.append(peak_idx[p] + peak_range) except Exception as e: - print(f"Cannot fit first ECG epoch. Reason: {e}") + warn(f"Cannot fit first ECG epoch. Reason: {e}") # Deals with last edge of data elif p == peak_count: - print("On last section - almost there!") + logger.info("On last section - almost there!") try: pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) post_range = peak_range @@ -302,7 +324,7 @@ def __init__(self): window_start_idx.append(peak_idx[p] - peak_range) window_end_idx.append(peak_idx[p] + peak_range) except Exception as e: - print(f"Cannot fit last ECG epoch. Reason: {e}") + warn(f"Cannot fit last ECG epoch. Reason: {e}") # Deals with middle portion of data else: @@ -334,7 +356,7 @@ def __init__(self): window_start_idx.append(peak_idx[p] - peak_range) window_end_idx.append(peak_idx[p] + peak_range) except Exception as e: - print(f"Cannot fit middle section of data. Reason: {e}") + warn(f"Cannot fit middle section of data. Reason: {e}") # Actually subtract the artefact, return needs to be the same shape as input data data = data.reshape(-1) diff --git a/mne/preprocessing/pca_obs/__init__.py b/mne/preprocessing/pca_obs/__init__.py deleted file mode 100644 index cfa48b95fec..00000000000 --- a/mne/preprocessing/pca_obs/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Principle Component Analysis Optimal Basis Sets (PCA-OBS).""" - -# Authors: The MNE-Python contributors. -# License: BSD-3-Clause -# Copyright the MNE-Python contributors. - -from ._pca_obs import pca_obs - -__all__ = [ - "pca_obs" -] diff --git a/mne/preprocessing/pca_obs/tests/__init__.py b/mne/preprocessing/pca_obs/tests/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/mne/preprocessing/pca_obs/tests/test_fit_ecg.py b/mne/preprocessing/pca_obs/tests/test_fit_ecg.py deleted file mode 100644 index eb5c80bc8d8..00000000000 --- a/mne/preprocessing/pca_obs/tests/test_fit_ecg.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Test the fot_ecg_template function.""" - -# Authors: The MNE-Python contributors. -# License: BSD-3-Clause -# Copyright the MNE-Python contributors. - -from mne.io import read_raw_fif -from mne.preprocessing.pca_obs import fit_ecg_template -from mne.datasets.testing import data_path, requires_testing_data - -# TODO: Where are the test files we want to use located? -fname = data_path(download=False) / "eyetrack" / "test_eyelink.asc" - - -@requires_testing_data -def test_fit_ecg_template(): - """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" - raw = read_raw_fif(fname) - - # Somehow have to "fake" all these inputs to the function - result = fit_ecg_template( - data=None, - pca_template=None, - aPeak_idx=None, - peak_range=None, - pre_range=None, - post_range=None, - midP=None, - fitted_art=None, - post_idx_previousPeak=None, - n_samples_fit=None, - ) - - # assert results - assert result is not None - assert result.shape == (100, 100) - assert result.shape == raw.shape # is this a condition we can test? - assert result[0, 0] == 1.0 - ... \ No newline at end of file diff --git a/mne/preprocessing/pca_obs/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py similarity index 100% rename from mne/preprocessing/pca_obs/tests/test_pca_obs.py rename to mne/preprocessing/tests/test_pca_obs.py From 951e87ca95ed57bb2cad63e04d6e772319b5666f Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Fri, 15 Nov 2024 15:16:00 +0100 Subject: [PATCH 20/23] Refactor: Docstrings, removed classes --- .../esg_rm_heart_artefact_pcaobs.py | 11 +- mne/preprocessing/pca_obs.py | 117 ++++++------------ 2 files changed, 43 insertions(+), 85 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index da8583c9901..f5e3ad3757f 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -19,7 +19,8 @@ """ -# Authors: Emma Bailey , Steinn Hauser Magnusson +# Authors: Emma Bailey , +# Steinn Hauser Magnusson # License: BSD-3-Clause # Copyright the MNE-Python contributors. @@ -135,7 +136,7 @@ mne.preprocessing.apply_pca_obs( raw_concat, picks=esg_chans, - n_jobs=4, + n_jobs=5, qrs=ecg_event_samples, filter_coords=fwts ) @@ -154,9 +155,11 @@ # Comparison image fig, axes = plt.subplots(1, 1) axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") -axes.set_ylim([-0.0005, 0.001]) axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") -axes.set_title("Before (black) versus after (green)") +axes.set_ylim([-0.0005, 0.001]) +axes.set_ylabel('Amplitude (V)') +axes.set_xlabel('Time (s)') +axes.set_title("Before (black) vs. After (green)") plt.tight_layout() plt.show() diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 0b00607262f..2ef3851a58e 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -1,6 +1,7 @@ """Principle Component Analysis Optimal Basis Sets (PCA-OBS).""" -# Authors: The MNE-Python contributors. +# Authors: Emma Bailey , +# Steinn Hauser Magnusson # License: BSD-3-Clause # Copyright the MNE-Python contributors. @@ -15,8 +16,8 @@ from mne.io.fiff.raw import Raw from mne.utils import logger, warn -# TODO: This needs to be pulled out of the subfolder we've created and moved into the more 'normal' MNE setup -# with the _pca_obs in preprocessing as a single file only, _init integrated in their __init__.py and .pyi + +# TODO: check arguments passed in, raise errors, tests def fit_ecg_template( data, @@ -27,47 +28,37 @@ def fit_ecg_template( post_range, midP, fitted_art, - post_idx_previousPeak: list, + post_idx_previousPeak, n_samples_fit, ) -> tuple[np.ndarray, list]: - """TODO: Write docstring about what we do here. - Fits the ECG to a template signal (?) - and returns the fitted artefact and the index of the next peak. (?) - - (TODO: are there any conditions that must be met to use our algos?) - .. note:: This should only be used on data which is ... + """ + Fits the heartbeat artefact found in the data + Returns the fitted artefact and the index of the next peak. - # TODO: Fill out input/output and raises Parameters ---------- - data (_type_): _description_ - pca_template (_type_): _description_ - aPeak_idx (_type_): _description_ - peak_range (_type_): _description_ - pre_range (_type_): _description_ - post_range (_type_): _description_ - midP (_type_): _description_ - fitted_art (_type_): _description_ - post_idx_previousPeak (list): _description_ - n_samples_fit (_type_): _description_ + data (ndarray): Data from the raw signal (n_channels, n_times) + pca_template (ndarray): Mean heartbeat and first N (4) principal components of the heartbeat matrix + aPeak_idx (int): Sample index of current R-peak + peak_range (int): Half the median RR-interval + pre_range (int): Number of samples to fit before the R-peak + post_range (int): Number of samples to fit after the R-peak + midP (float): Sample index marking middle of the median RR interval in the signal. + Used to extract relevant part of PCA_template. + fitted_art (ndarray): The computed heartbeat artefact computed to remove from the data + post_idx_previousPeak (optional int): Sample index of previous R-peak + n_samples_fit (int): Sample fit for interpolation between fitted artifact windows. + Helps reduce sharp edges at the end of fitted heartbeat events. Returns ------- tuple[np.ndarray, list]: the fitted artifact and the next peak index (if available) """ - # Declare class to hold ecg fit information - class fitECG: - def __init__(self): - pass - - # Instantiate class - # TODO: Why are we storing this to a class? Can't we just use the variables and write to them? - fitecg = fitECG() # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak # Then nextpeak is returned at the end and the process repeats # select window of template - template = pca_template[midP - peak_range - 1 : midP + peak_range + 1, :] + template = pca_template[midP - peak_range - 1: midP + peak_range + 1, :] # select window of data and detrend it slice = data[0, aPeak_idx[0] - peak_range : aPeak_idx[0] + peak_range + 1] @@ -77,31 +68,19 @@ def __init__(self): least_square = np.linalg.lstsq(template, detrended_data, rcond=None) pad_fit = np.dot(template, least_square[0]) - # fit artifact, I already loop through externally channel to channel - fitted_art[0, aPeak_idx[0] - pre_range - 1 : aPeak_idx[0] + post_range] = pad_fit[ - midP - pre_range - 1 : midP + post_range + # fit artifact + fitted_art[0, aPeak_idx[0] - pre_range - 1: aPeak_idx[0] + post_range] = pad_fit[ + midP - pre_range - 1: midP + post_range ].T - fitecg.fitted_art = fitted_art - fitecg.template = template - fitecg.detrended_data = detrended_data - fitecg.pad_fit = pad_fit - fitecg.aPeak_idx = aPeak_idx - fitecg.midP = midP - fitecg.peak_range = peak_range - fitecg.data = data - - post_idx_nextPeak = [aPeak_idx[0] + post_range] - # if last peak, return - if not post_idx_previousPeak: - return fitted_art, post_idx_nextPeak + if post_idx_previousPeak is None: + return fitted_art, aPeak_idx[0] + post_range # interpolate time between peaks intpol_window = np.ceil( - [post_idx_previousPeak[0], aPeak_idx[0] - pre_range] + [post_idx_previousPeak, aPeak_idx[0] - pre_range] ).astype("int") # interpolation window - fitecg.intpol_window = intpol_window if intpol_window[0] < intpol_window[1]: # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data @@ -127,17 +106,11 @@ def __init__(self): y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation # Then make fitted artefact in the desired range equal to the completed fit above - fitted_art[0, post_idx_previousPeak[0] : aPeak_idx[0] - pre_range + 1] = ( + fitted_art[0, post_idx_previousPeak: aPeak_idx[0] - pre_range + 1] = ( y_interpol ) - fitecg.x_fit = x_fit - fitecg.y_fit = y_fit - fitecg.x_interpol = x_interpol - fitecg.y_interpol = y_interpol - fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop - - return fitted_art, post_idx_nextPeak + return fitted_art, aPeak_idx[0] + post_range def apply_pca_obs(raw: Raw, picks: list[str], n_jobs: int, qrs: np.ndarray, filter_coords: np.ndarray) -> None: @@ -160,8 +133,6 @@ def _pca_obs( Algorithm to perform the PCA OBS (Principal Component Analysis, Optimal Basis Sets) algorithm to remove the heart artefact from EEG data. - .. note:: This should only be used on data which is ... (TODO: are there any conditions that must be met to use our algos?) - Parameters ---------- data: ndarray, shape (n_channels, n_times) @@ -170,21 +141,13 @@ def _pca_obs( qrs: ndarray, shape (n_peaks, 1) Array of times in (s), of detected R-peaks in ECG channel. - filter_coords: ndarray + filter_coords: ndarray (N, ) + The numerator coefficient vector of the filter passed to scipy.signal.filtfilt Returns ------- - np.ndarray: The data with the heart artefact removed. + np.ndarray: The data with the heart artefact suppressed. """ - # Declare class to hold pca information - class PCAInfo: - def __init__(self): - pass - - # NOTE: Here aswell, is there a reason we are storing this - # to a class? Shouldn't variables suffice? - # Instantiate class - pca_info = PCAInfo() # set to baseline data = data.reshape(-1, 1) @@ -250,18 +213,10 @@ def __init__(self): # run PCA(performs SVD(singular value decomposition)) pca = PCA(svd_solver="full") pca.fit(dpcamat) - eigen_vectors = pca.components_ - eigen_values = pca.explained_variance_ factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_) - pca_info.eigen_vectors = eigen_vectors - pca_info.factor_loadings = factor_loadings - pca_info.eigen_values = eigen_values - pca_info.expl_var = pca.explained_variance_ratio_ # define selected number of components using profile likelihood - pca_info.nComponents = 4 # TODO: Is this a variable? Or constant? Seems like a variable - pca_info.meanEffect = mean_effect.T - nComponents = pca_info.nComponents + nComponents = 4 ####################################################################### # Make template of the ECG artefact @@ -282,7 +237,7 @@ def __init__(self): if post_range > peak_range: post_range = peak_range try: - post_idx_nextPeak = [] + post_idx_nextPeak = None fitted_art, post_idx_nextPeak = fit_ecg_template( data, pca_template, @@ -302,7 +257,7 @@ def __init__(self): warn(f"Cannot fit first ECG epoch. Reason: {e}") # Deals with last edge of data - elif p == peak_count: + elif p == peak_count-1: logger.info("On last section - almost there!") try: pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) @@ -312,7 +267,7 @@ def __init__(self): fitted_art, _ = fit_ecg_template( data, pca_template, - peak_idx(p), + peak_idx[p], peak_range, pre_range, post_range, From c70db0a130e1601a171610d062191ac35d30cf11 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Sun, 24 Nov 2024 17:19:23 +0100 Subject: [PATCH 21/23] refactor: remove unrequired index references, adjust variable namings to have consistent patterns --- mne/preprocessing/pca_obs.py | 165 ++++++++++++++++++----------------- 1 file changed, 83 insertions(+), 82 deletions(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 2ef3851a58e..b9e0a01e64a 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -6,6 +6,7 @@ # Copyright the MNE-Python contributors. import math +from typing import Optional import numpy as np from scipy.signal import detrend, filtfilt @@ -20,17 +21,17 @@ # TODO: check arguments passed in, raise errors, tests def fit_ecg_template( - data, - pca_template, - aPeak_idx, - peak_range, - pre_range, - post_range, - midP, - fitted_art, - post_idx_previousPeak, - n_samples_fit, -) -> tuple[np.ndarray, list]: + data: np.ndarray, + pca_template: np.ndarray, + a_peak_idx: int, + peak_range: int, + pre_range: int, + post_range: int, + mid_p: float, + fitted_art: np.ndarray, + post_idx_previous_peak: Optional[int], + n_samples_fit: int, +) -> tuple[np.ndarray, int]: """ Fits the heartbeat artefact found in the data Returns the fitted artefact and the index of the next peak. @@ -39,29 +40,29 @@ def fit_ecg_template( ---------- data (ndarray): Data from the raw signal (n_channels, n_times) pca_template (ndarray): Mean heartbeat and first N (4) principal components of the heartbeat matrix - aPeak_idx (int): Sample index of current R-peak + a_peak_idx (int): Sample index of current R-peak peak_range (int): Half the median RR-interval pre_range (int): Number of samples to fit before the R-peak post_range (int): Number of samples to fit after the R-peak - midP (float): Sample index marking middle of the median RR interval in the signal. + mid_p (float): Sample index marking middle of the median RR interval in the signal. Used to extract relevant part of PCA_template. fitted_art (ndarray): The computed heartbeat artefact computed to remove from the data - post_idx_previousPeak (optional int): Sample index of previous R-peak + post_idx_previous_peak (optional int): Sample index of previous R-peak n_samples_fit (int): Sample fit for interpolation between fitted artifact windows. Helps reduce sharp edges at the end of fitted heartbeat events. Returns ------- - tuple[np.ndarray, list]: the fitted artifact and the next peak index (if available) + tuple[np.ndarray, int]: the fitted artifact and the next peak index """ - # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak + # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previous_peak # Then nextpeak is returned at the end and the process repeats # select window of template - template = pca_template[midP - peak_range - 1: midP + peak_range + 1, :] + template = pca_template[mid_p - peak_range - 1: mid_p + peak_range + 1, :] # select window of data and detrend it - slice = data[0, aPeak_idx[0] - peak_range : aPeak_idx[0] + peak_range + 1] + slice = data[0, a_peak_idx[0] - peak_range : a_peak_idx[0] + peak_range + 1] detrended_data = detrend(slice.reshape(-1), type="constant") # maps data on template and then maps it again back to the sensor space @@ -69,18 +70,18 @@ def fit_ecg_template( pad_fit = np.dot(template, least_square[0]) # fit artifact - fitted_art[0, aPeak_idx[0] - pre_range - 1: aPeak_idx[0] + post_range] = pad_fit[ - midP - pre_range - 1: midP + post_range + fitted_art[0, a_peak_idx[0] - pre_range - 1: a_peak_idx[0] + post_range] = pad_fit[ + mid_p - pre_range - 1: mid_p + post_range ].T # if last peak, return - if post_idx_previousPeak is None: - return fitted_art, aPeak_idx[0] + post_range + if post_idx_previous_peak is None: + return fitted_art, a_peak_idx[0] + post_range # interpolate time between peaks intpol_window = np.ceil( - [post_idx_previousPeak, aPeak_idx[0] - pre_range] - ).astype("int") # interpolation window + [post_idx_previous_peak, a_peak_idx[0] - pre_range] + ).astype(int) # interpolation window if intpol_window[0] < intpol_window[1]: # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data @@ -106,14 +107,18 @@ def fit_ecg_template( y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation # Then make fitted artefact in the desired range equal to the completed fit above - fitted_art[0, post_idx_previousPeak: aPeak_idx[0] - pre_range + 1] = ( + fitted_art[0, post_idx_previous_peak: a_peak_idx[0] - pre_range + 1] = ( y_interpol ) - return fitted_art, aPeak_idx[0] + post_range + return fitted_art, a_peak_idx[0] + post_range def apply_pca_obs(raw: Raw, picks: list[str], n_jobs: int, qrs: np.ndarray, filter_coords: np.ndarray) -> None: + """ + Main convenience function for applying the PCA-OBS algorithm + to certain picks of a Raw object. + """ raw.apply_function( _pca_obs, picks=picks, @@ -123,11 +128,11 @@ def apply_pca_obs(raw: Raw, picks: list[str], n_jobs: int, qrs: np.ndarray, filt filter_coords=filter_coords, ) -# TODO: Are we able to split this into smaller segmented functions? def _pca_obs( data: np.ndarray, qrs: np.ndarray, filter_coords: np.ndarray, + n_components: int = 4 # number of components to pick from the PCA ) -> np.ndarray: """ Algorithm to perform the PCA OBS (Principal Component Analysis, Optimal Basis Sets) @@ -167,25 +172,23 @@ def _pca_obs( peak_idx = peak_idx.reshape(-1, 1) peak_count = len(peak_idx) - ################################################################ - # Preparatory work - reserving memory, configure sizes, de-trend - ################################################################ + ################################################################## + # Preparatory work - reserving memory, configure sizes, de-trend # + ################################################################## logger.info("Pulse artifact subtraction in progress...Please wait!") # define peak range based on RR RR = np.diff(peak_idx[:, 0]) mRR = np.median(RR) peak_range = round(mRR / 2) # Rounds to an integer - midP = peak_range + 1 + mid_p = peak_range + 1 n_samples_fit = round( peak_range / 8 ) # sample fit for interpolation between fitted artifact windows # make sure array is long enough for PArange (if not cut off last ECG peak) - pa = peak_count # Number of QRS complexes detected - while peak_idx[pa - 1, 0] + peak_range > len(data[0]): - pa = pa - 1 - peak_count = pa + while peak_idx[peak_count - 1, 0] + peak_range > len(data[0]): + peak_count = peak_count - 1 # reduce number of QRS complexes detected # Filter channel eegchan = filtfilt(filter_coords, 1, data) @@ -202,34 +205,33 @@ def _pca_obs( pcamat = detrend( pcamat, type="constant", axis=1 ) # [epoch x time] - detrended along the epoch - mean_effect = np.mean( + mean_effect: np.ndarray = np.mean( pcamat, axis=0 ) # [1 x time], contains the mean over all epochs dpcamat = detrend(pcamat, type="constant", axis=1) # [time x epoch] - ################################################################### - # Perform PCA with sklearn - ################################################################### - # run PCA(performs SVD(singular value decomposition)) + ############################ + # Perform PCA with sklearn # + ############################ + # run PCA, perform singular value decomposition (SVD) pca = PCA(svd_solver="full") pca.fit(dpcamat) factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_) # define selected number of components using profile likelihood - nComponents = 4 - ####################################################################### - # Make template of the ECG artefact - ####################################################################### + ##################################### + # Make template of the ECG artefact # + ##################################### mean_effect = mean_effect.reshape(-1, 1) - pca_template = np.c_[mean_effect, factor_loadings[:, 0:nComponents]] + pca_template = np.c_[mean_effect, factor_loadings[:, :n_components]] - ################################################################################### - # Data Fitting - ################################################################################### + ################ + # Data Fitting # + ################ window_start_idx = [] window_end_idx = [] - for p in range(0, peak_count): + for p in range(peak_count): # Deals with start portion of data if p == 0: pre_range = peak_range @@ -239,16 +241,16 @@ def _pca_obs( try: post_idx_nextPeak = None fitted_art, post_idx_nextPeak = fit_ecg_template( - data, - pca_template, - peak_idx[p], - peak_range, - pre_range, - post_range, - midP, - fitted_art, - post_idx_nextPeak, - n_samples_fit, + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, ) # Appending to list instead of using counter window_start_idx.append(peak_idx[p] - peak_range) @@ -265,16 +267,16 @@ def _pca_obs( if pre_range > peak_range: pre_range = peak_range fitted_art, _ = fit_ecg_template( - data, - pca_template, - peak_idx[p], - peak_range, - pre_range, - post_range, - midP, - fitted_art, - post_idx_nextPeak, - n_samples_fit, + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, ) window_start_idx.append(peak_idx[p] - peak_range) window_end_idx.append(peak_idx[p] + peak_range) @@ -293,20 +295,20 @@ def _pca_obs( if post_range > peak_range: post_range = peak_range - aTemplate = pca_template[ - midP - peak_range - 1 : midP + peak_range + 1, : + a_template = pca_template[ + mid_p - peak_range - 1 : mid_p + peak_range + 1, : ] fitted_art, post_idx_nextPeak = fit_ecg_template( - data, - aTemplate, - peak_idx[p], - peak_range, - pre_range, - post_range, - midP, - fitted_art, - post_idx_nextPeak, - n_samples_fit, + data=data, + pca_template=a_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, ) window_start_idx.append(peak_idx[p] - peak_range) window_end_idx.append(peak_idx[p] + peak_range) @@ -320,5 +322,4 @@ def _pca_obs( data -= fitted_art data = data.T.reshape(-1) - # Can only return data return data From fb8c7ff39a4b08143daf52b7abfa250a34bdfba3 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Sun, 24 Nov 2024 17:35:52 +0100 Subject: [PATCH 22/23] docs: minor edits in docstrings --- mne/preprocessing/pca_obs.py | 58 +++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index b9e0a01e64a..45973f82ded 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -39,17 +39,19 @@ def fit_ecg_template( Parameters ---------- data (ndarray): Data from the raw signal (n_channels, n_times) - pca_template (ndarray): Mean heartbeat and first N (4) principal components of the heartbeat matrix + pca_template (ndarray): Mean heartbeat and first N (default 4) + principal components of the heartbeat matrix a_peak_idx (int): Sample index of current R-peak peak_range (int): Half the median RR-interval pre_range (int): Number of samples to fit before the R-peak post_range (int): Number of samples to fit after the R-peak - mid_p (float): Sample index marking middle of the median RR interval in the signal. - Used to extract relevant part of PCA_template. - fitted_art (ndarray): The computed heartbeat artefact computed to remove from the data + mid_p (float): Sample index marking middle of the median RR interval + in the signal. Used to extract relevant part of PCA_template. + fitted_art (ndarray): The computed heartbeat artefact computed to + remove from the data post_idx_previous_peak (optional int): Sample index of previous R-peak n_samples_fit (int): Sample fit for interpolation between fitted artifact windows. - Helps reduce sharp edges at the end of fitted heartbeat events. + Helps reduce sharp edges at the end of fitted heartbeat events. Returns ------- @@ -114,10 +116,32 @@ def fit_ecg_template( return fitted_art, a_peak_idx[0] + post_range -def apply_pca_obs(raw: Raw, picks: list[str], n_jobs: int, qrs: np.ndarray, filter_coords: np.ndarray) -> None: +def apply_pca_obs( + raw: Raw, + picks: list[str], + qrs: np.ndarray, + filter_coords: np.ndarray, + n_components: int = 4, + n_jobs: Optional[int] = None, +) -> None: """ Main convenience function for applying the PCA-OBS algorithm - to certain picks of a Raw object. + to certain picks of a Raw object. Updates the Raw object in-place. + + Parameters + ---------- + raw: Raw + The raw data to process + picks: list[str] + Channels in the Raw object to remove the heart artefact from + qrs: ndarray, shape (n_peaks, 1) + Array of times in (s), of detected R-peaks in ECG channel. + filter_coords: ndarray (N, ) + The numerator coefficient vector of the filter passed to scipy.signal.filtfilt + n_components: int, default 4 + Number of PCA components to use to form the OBS + n_jobs: int, default None + Number of jobs to perform the PCA-OBS processing in parallel """ raw.apply_function( _pca_obs, @@ -126,32 +150,18 @@ def apply_pca_obs(raw: Raw, picks: list[str], n_jobs: int, qrs: np.ndarray, filt # args sent to PCA_OBS qrs=qrs, filter_coords=filter_coords, + n_components=n_components, ) def _pca_obs( data: np.ndarray, qrs: np.ndarray, filter_coords: np.ndarray, - n_components: int = 4 # number of components to pick from the PCA + n_components: int, ) -> np.ndarray: """ Algorithm to perform the PCA OBS (Principal Component Analysis, Optimal Basis Sets) - algorithm to remove the heart artefact from EEG data. - - Parameters - ---------- - data: ndarray, shape (n_channels, n_times) - The data which we want to remove the heart artefact from. - - qrs: ndarray, shape (n_peaks, 1) - Array of times in (s), of detected R-peaks in ECG channel. - - filter_coords: ndarray (N, ) - The numerator coefficient vector of the filter passed to scipy.signal.filtfilt - - Returns - ------- - np.ndarray: The data with the heart artefact suppressed. + algorithm to remove the heart artefact from EEG data (shape [n_channels, n_times]) """ # set to baseline From 4c73a962648c02570a97ea6b85fe0d941246ad7d Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 25 Nov 2024 09:10:12 +0100 Subject: [PATCH 23/23] fix: add minor sanity checks for the function inputs --- mne/preprocessing/pca_obs.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 45973f82ded..fe877f7beb9 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -127,6 +127,7 @@ def apply_pca_obs( """ Main convenience function for applying the PCA-OBS algorithm to certain picks of a Raw object. Updates the Raw object in-place. + Makes sanity checks for all inputs. Parameters ---------- @@ -143,6 +144,13 @@ def apply_pca_obs( n_jobs: int, default None Number of jobs to perform the PCA-OBS processing in parallel """ + + if not qrs: + raise ValueError("qrs must not be empty") + + if not filter_coords: + raise ValueError("filter_coords must not be empty") + raw.apply_function( _pca_obs, picks=picks,