diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py new file mode 100755 index 00000000000..aee06165d11 --- /dev/null +++ b/mne/preprocessing/pca_obs/PCA_OBS.py @@ -0,0 +1,185 @@ +import numpy as np +# import mne +from scipy.signal import filtfilt, detrend +import matplotlib.pyplot as plt +from sklearn.decomposition import PCA +from fit_ecgTemplate import fit_ecgTemplate +import math +import h5py + + +def PCA_OBS(data, **kwargs): + + # Declare class to hold pca information + class PCAInfo(): + def __init__(self): + pass + + # Instantiate class + pca_info = PCAInfo() + + # Check all necessary arguments sent in + required_kws = ["qrs", "filter_coords", "sr"] + assert all([kw in kwargs.keys() for kw in required_kws]), "Error. Some KWs not passed into PCA_OBS." + + # Extract all kwargs + qrs = kwargs['qrs'] + filter_coords = kwargs['filter_coords'] + sr = kwargs['sr'] + + fs = sr + + # set to baseline + data = data.reshape(-1, 1) + data = data.T + data = data - np.mean(data, axis=1) + + # Allocate memory + fitted_art = np.zeros(data.shape) + peakplot = np.zeros(data.shape) + + # Extract QRS events + for idx in qrs[0]: + if idx < len(peakplot[0, :]): + peakplot[0, idx] = 1 # logical indexed locations of qrs events + + peak_idx = np.nonzero(peakplot)[1] # Selecting indices along columns + peak_idx = peak_idx.reshape(-1, 1) + peak_count = len(peak_idx) + + ################################################################ + # Preparatory work - reserving memory, configure sizes, de-trend + ################################################################ + print('Pulse artifact subtraction in progress...Please wait!') + + # define peak range based on RR + RR = np.diff(peak_idx[:, 0]) + mRR = np.median(RR) + peak_range = round(mRR/2) # Rounds to an integer + midP = peak_range + 1 + baseline_range = [0, round(peak_range/8)] + n_samples_fit = round(peak_range/8) # sample fit for interpolation between fitted artifact windows + + # make sure array is long enough for PArange (if not cut off last ECG peak) + pa = peak_count # Number of QRS complexes detected + while peak_idx[pa-1, 0] + peak_range > len(data[0]): + pa = pa - 1 + steps = 1 * pa + peak_count = pa + + # Filter channel + eegchan = filtfilt(filter_coords, 1, data) + + # build PCA matrix(heart-beat-epochs x window-length) + pcamat = np.zeros((peak_count - 1, 2*peak_range+1)) # [epoch x time] + # picking out heartbeat epochs + for p in range(1, peak_count): + pcamat[p-1, :] = eegchan[0, peak_idx[p, 0] - peak_range: peak_idx[p, 0] + peak_range+1] + + # detrending matrix(twice) + pcamat = detrend(pcamat, type='constant', axis=1) # [epoch x time] - detrended along the epoch + mean_effect = np.mean(pcamat, axis=0) # [1 x time], contains the mean over all epochs + std_effect = np.std(pcamat, axis=0) # want mean and std of each column + dpcamat = detrend(pcamat, type='constant', axis=1) # [time x epoch] + + ################################################################### + # Perform PCA with sklearn + ################################################################### + # run PCA(performs SVD(singular value decomposition)) + pca = PCA(svd_solver="full") + pca.fit(dpcamat) + eigen_vectors = pca.components_ + eigen_values = pca.explained_variance_ + factor_loadings = pca.components_.T*np.sqrt(pca.explained_variance_) + pca_info.eigen_vectors = eigen_vectors + pca_info.factor_loadings = factor_loadings + pca_info.eigen_values = eigen_values + pca_info.expl_var = pca.explained_variance_ratio_ + + # define selected number of components using profile likelihood + pca_info.nComponents = 4 + pca_info.meanEffect = mean_effect.T + nComponents = pca_info.nComponents + + ####################################################################### + # Make template of the ECG artefact + ####################################################################### + mean_effect = mean_effect.reshape(-1, 1) + pca_template = np.c_[mean_effect, factor_loadings[:, 0:nComponents]] + + ################################################################################### + # Data Fitting + ################################################################################### + window_start_idx = [] + window_end_idx = [] + for p in range(0, peak_count): + # Deals with start portion of data + if p == 0: + pre_range = peak_range + post_range = math.floor((peak_idx[p + 1] - peak_idx[p])/2) + if post_range > peak_range: + post_range = peak_range + try: + post_idx_nextPeak = [] + fitted_art, post_idx_nextPeak = fit_ecgTemplate(data, pca_template, peak_idx[p], peak_range, + pre_range, post_range, baseline_range, midP, + fitted_art, post_idx_nextPeak, n_samples_fit) + # Appending to list instead of using counter + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + except Exception as e: + print(f'Cannot fit first ECG epoch. Reason: {e}') + + # Deals with last edge of data + elif p == peak_count: + print('On last section - almost there!') + try: + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = peak_range + if pre_range > peak_range: + pre_range = peak_range + fitted_art, _ = fit_ecgTemplate(data, pca_template, peak_idx(p), peak_range, pre_range, post_range, + baseline_range, midP, fitted_art, post_idx_nextPeak, n_samples_fit) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + except Exception as e: + print(f'Cannot fit last ECG epoch. Reason: {e}') + + # Deals with middle portion of data + else: + try: + # ---------------- Processing of central data - -------------------- + # cycle through peak artifacts identified by peakplot + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) + if pre_range >= peak_range: + pre_range = peak_range + if post_range > peak_range: + post_range = peak_range + + aTemplate = pca_template[midP - peak_range-1:midP + peak_range+1, :] + fitted_art, post_idx_nextPeak = fit_ecgTemplate(data, aTemplate, peak_idx[p], peak_range, pre_range, + post_range, baseline_range, midP, fitted_art, + post_idx_nextPeak, n_samples_fit) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + except Exception as e: + print(f"Cannot fit middle section of data. Reason: {e}") + + # Actually subtract the artefact, return needs to be the same shape as input data + # One sample shift purely due to the fact the r-peaks are currently detected in MATLAB + data = data.reshape(-1) + fitted_art = fitted_art.reshape(-1) + + # One sample shift for my actual data (introduced using matlab r timings) + # data_ = np.zeros(len(data)) + # data_[0] = data[0] + # data_[1:] = data[1:] - fitted_art[:-1] + # data = data_ + + # Original code is this: + data -= fitted_art + data = data.T.reshape(-1) + + # Can only return data + return data diff --git a/mne/preprocessing/pca_obs/fit_ecgTemplate.py b/mne/preprocessing/pca_obs/fit_ecgTemplate.py new file mode 100755 index 00000000000..11043b0c56d --- /dev/null +++ b/mne/preprocessing/pca_obs/fit_ecgTemplate.py @@ -0,0 +1,71 @@ +import numpy as np +from scipy.signal import detrend +from scipy.interpolate import PchipInterpolator as pchip +import h5py + + +def fit_ecgTemplate(data, pca_template, aPeak_idx, peak_range, pre_range, post_range, baseline_range, midP, fitted_art, post_idx_previousPeak, n_samples_fit): + # Declare class to hold ecg fit information + class fitECG(): + def __init__(self): + pass + + # Instantiate class + fitecg = fitECG() + + # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak + # Then nextpeak is returned at the end and the process repeats + # select window of template + template = pca_template[midP - peak_range-1: midP + peak_range+1, :] + + # select window of data and detrend it + slice = data[0, aPeak_idx[0] - peak_range:aPeak_idx[0] + peak_range+1] + detrended_data = detrend(slice.reshape(-1), type='constant') + + # maps data on template and then maps it again back to the sensor space + least_square = np.linalg.lstsq(template, detrended_data, rcond=None) + pad_fit = np.dot(template, least_square[0]) + + # fit artifact, I already loop through externally channel to channel + fitted_art[0, aPeak_idx[0] - pre_range-1: aPeak_idx[0] + post_range] = pad_fit[midP - pre_range-1: midP + post_range].T + + fitecg.fitted_art = fitted_art + fitecg.template = template + fitecg.detrended_data = detrended_data + fitecg.pad_fit = pad_fit + fitecg.aPeak_idx = aPeak_idx + fitecg.midP = midP + fitecg.peak_range = peak_range + fitecg.data = data + + post_idx_nextPeak = [aPeak_idx[0] + post_range] + + # Check it's not empty + if len(post_idx_previousPeak) != 0: + # interpolate time between peaks + intpol_window = np.ceil([post_idx_previousPeak[0], aPeak_idx[0] - pre_range]).astype('int') # interpolation window + fitecg.intpol_window = intpol_window + + if intpol_window[0] < intpol_window[1]: + # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data + + # You have x_fit which is two slices on either side of the interpolation window endpoints + # You have y_fit which is the y vals corresponding to x values above + # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate + # You have y_interpol which is values from pchip at the time points specified in x_interpol + x_interpol = np.arange(intpol_window[0], intpol_window[1] + 1, 1) # points to be interpolated in pt - the gap between the endpoints of the window + x_fit = np.concatenate([np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1), + np.arange(intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1)]) # Entire range of x values in this step (taking some number of samples before and after the window) + y_fit = fitted_art[0, x_fit] + y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation + + # Then make fitted artefact in the desired range equal to the completed fit above + fitted_art[0, post_idx_previousPeak[0]: aPeak_idx[0] - pre_range+1] = y_interpol + + fitecg.x_fit = x_fit + fitecg.y_fit = y_fit + fitecg.x_interpol = x_interpol + fitecg.y_interpol = y_interpol + fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop + + return fitted_art, post_idx_nextPeak diff --git a/mne/preprocessing/pca_obs/pchip_interpolation.py b/mne/preprocessing/pca_obs/pchip_interpolation.py new file mode 100755 index 00000000000..e336f672363 --- /dev/null +++ b/mne/preprocessing/pca_obs/pchip_interpolation.py @@ -0,0 +1,42 @@ +# Function to interpolate based on PCHIP rather than MNE inbuilt linear option + +# import mne +import numpy as np +from scipy.interpolate import PchipInterpolator as pchip +import matplotlib.pyplot as plt + + +def PCHIP_interpolation(data, **kwargs): + # Check all necessary arguments sent in + required_kws = ["trigger_indices", "interpol_window_sec", "fs"] + assert all([kw in kwargs.keys() for kw in required_kws]), "Error. Some KWs not passed into PCA_OBS." + + # Extract all kwargs - more elegant ways to do this + fs = kwargs['fs'] + interpol_window_sec = kwargs['interpol_window_sec'] + trigger_indices = kwargs['trigger_indices'] + + # Convert intpol window to msec then convert to samples + pre_window = round((interpol_window_sec[0]*1000) * fs / 1000) # in samples + post_window = round((interpol_window_sec[1]*1000) * fs / 1000) # in samples + intpol_window = np.ceil([pre_window, post_window]).astype(int) # interpolation window + + n_samples_fit = 5 # number of samples before and after cut used for interpolation fit + + x_fit_raw = np.concatenate([np.arange(intpol_window[0]-n_samples_fit-1, intpol_window[0], 1), + np.arange(intpol_window[1]+1, intpol_window[1]+n_samples_fit+2, 1)]) + x_interpol_raw = np.arange(intpol_window[0], intpol_window[1]+1, 1) # points to be interpolated; in pt + + for ii in np.arange(0, len(trigger_indices)): # loop through all stimulation events + x_fit = trigger_indices[ii] + x_fit_raw # fit point latencies for this event + x_interpol = trigger_indices[ii] + x_interpol_raw # latencies for to-be-interpolated data points + + # Data is just a string of values + y_fit = data[x_fit] # y values to be fitted + y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation + data[x_interpol] = y_interpol # replace in data + + if np.mod(ii, 100) == 0: # talk to the operator every 100th trial + print(f'stimulation event {ii} \n') + + return data diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py new file mode 100644 index 00000000000..42e58e99463 --- /dev/null +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py @@ -0,0 +1,77 @@ +# Checking algorithm implementation with EEG data form MNE sample datasets +# Following this tutorial data: https://mne.tools/stable/auto_tutorials/preprocessing/50_artifact_correction_ssp.html#what-is-ssp + +from mne.preprocessing import ( + find_ecg_events, + create_ecg_epochs, +) +from mne.io import read_raw_fif +from mne.datasets.sample import data_path +from PCA_OBS import * +from mne import Epochs +import os +import numpy as np +from scipy.signal import firls +import matplotlib.pyplot as plt + +if __name__ == '__main__': + sample_data_folder = data_path(path='/data/pt_02569/mne_test_data/') + sample_data_raw_file = os.path.join( + sample_data_folder, "MEG", "sample", "sample_audvis_raw.fif" + ) + # here we crop and resample just for speed + raw = read_raw_fif(sample_data_raw_file, preload=True) + + # Find ECG events - no ECG channel in data, uses synthetic + ecg_events, ch_ecg, average_pulse, = find_ecg_events(raw) + # Extract just sample timings of ecg events + ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) + # print(ecg_events) + + # Create filter coefficients + fs = raw.info['sfreq'] + a = [0, 0, 1, 1] + f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter + # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter + ord = round(3 * fs / 0.5) + fwts = firls(ord + 1, f, a) + + # For heartbeat epochs + iv_baseline = [-300 / 1000, -200 / 1000] + iv_epoch = [-400 / 1000, 600 / 1000] + + # run PCA_OBS + # Algorithm is extremely sensitive to accurate R-peak timings, won't work as well with the artificial ECG + # channel estimation as we have here + PCA_OBS_kwargs = dict( + qrs=ecg_event_samples, filter_coords=fwts, sr=fs + ) + + epochs = Epochs(raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_before = epochs.average() + + # Apply function should modifies the data in raw in place + raw.apply_function(PCA_OBS, picks='eeg', **PCA_OBS_kwargs, n_jobs=10) + epochs = Epochs(raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_after = epochs.average() + + # Comparison image + fig, axes = plt.subplots(2, 1) + axes[0].plot(evoked_before.times, evoked_before.get_data().T) + axes[0].set_ylim([-1e-5, 3e-5]) + axes[0].set_title('Before PCA-OBS') + axes[1].plot(evoked_after.times, evoked_after.get_data().T) + axes[1].set_ylim([-1e-5, 3e-5]) + axes[1].set_title('After PCA-OBS') + plt.tight_layout() + + # Comparison image + fig, axes = plt.subplots(1, 1) + axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') + axes.set_ylim([-1e-5, 3e-5]) + axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') + axes.set_title('Before (black) versus after (green)') + plt.tight_layout() + plt.show() diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py new file mode 100755 index 00000000000..7634096c909 --- /dev/null +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py @@ -0,0 +1,80 @@ +# Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component +# Analysis, Optimal Basis Sets) + +import os +from scipy.io import loadmat +from scipy.signal import firls +from PCA_OBS import * +from mne.io import read_raw_fif +from mne import events_from_annotations, Epochs +from mne.preprocessing import find_ecg_events + + +if __name__ == '__main__': + # Incredibly slow without parallelization + # Set variables + subject_id = f'sub-001' + cond_name = 'median' + sampling_rate = 1000 + esg_chans = ['S35', 'S24', 'S36', 'Iz', 'S17', 'S15', 'S32', 'S22', + 'S19', 'S26', 'S28', 'S9', 'S13', 'S11', 'S7', 'SC1', 'S4', 'S18', + 'S8', 'S31', 'SC6', 'S12', 'S16', 'S5', 'S30', 'S20', 'S34', 'AC', + 'S21', 'S25', 'L1', 'S29', 'S14', 'S33', 'S3', 'AL', 'L4', 'S6', + 'S23'] + # For heartbeat epochs + iv_baseline = [-300 / 1000, -200 / 1000] + iv_epoch = [-400 / 1000, 600 / 1000] + + # Setting paths + input_path = "/data/pt_02569/tmp_data/prepared_py/"+subject_id+"/esg/prepro/" + fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" + raw = read_raw_fif(input_path + fname + '.fif', preload=True) + + # Find ECG events - no ECG channel in data, uses synthetic + ecg_events, ch_ecg, average_pulse, = find_ecg_events(raw, ch_name='ECG') + # Extract just sample timings of ecg events + ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) + + # Create filter coefficients + fs = sampling_rate + a = [0, 0, 1, 1] + f = [0, 0.4/(fs/2), 0.9/(fs / 2), 1] # 0.9 Hz highpass filter + # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter + ord = round(3*fs/0.5) + fwts = firls(ord+1, f, a) + + # run PCA_OBS + PCA_OBS_kwargs = dict( + qrs=ecg_event_samples, filter_coords=fwts, sr=sampling_rate + ) + + events, event_ids = events_from_annotations(raw) + event_id_dict = {key: value for key, value in event_ids.items() if key == 'qrs'} + epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_before = epochs.average() + + # Apply function should modifies the data in raw in place + raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans)) + epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_after = epochs.average() + + # Comparison image + fig, axes = plt.subplots(2, 1) + axes[0].plot(evoked_before.times, evoked_before.get_data().T) + axes[0].set_ylim([-0.0005, 0.001]) + axes[0].set_title('Before PCA-OBS') + axes[1].plot(evoked_after.times, evoked_after.get_data().T) + axes[1].set_ylim([-0.0005, 0.001]) + axes[1].set_title('After PCA-OBS') + plt.tight_layout() + + # Comparison image + fig, axes = plt.subplots(1, 1) + axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') + axes.set_ylim([-0.0005, 0.001]) + axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') + axes.set_title('Before (black) versus after (green)') + plt.tight_layout() + plt.show() diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py new file mode 100755 index 00000000000..1b559ef9f22 --- /dev/null +++ b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py @@ -0,0 +1,80 @@ +# Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component +# Analysis, Optimal Basis Sets) + +import os +from scipy.io import loadmat +from scipy.signal import firls +from PCA_OBS import * +from mne.io import read_raw_fif +from mne import events_from_annotations, Epochs + + +if __name__ == '__main__': + # Incredibly slow without parallelization + # Set variables + subject_id = f'sub-001' + cond_name = 'median' + sampling_rate = 1000 + esg_chans = ['S35', 'S24', 'S36', 'Iz', 'S17', 'S15', 'S32', 'S22', + 'S19', 'S26', 'S28', 'S9', 'S13', 'S11', 'S7', 'SC1', 'S4', 'S18', + 'S8', 'S31', 'SC6', 'S12', 'S16', 'S5', 'S30', 'S20', 'S34', 'AC', + 'S21', 'S25', 'L1', 'S29', 'S14', 'S33', 'S3', 'AL', 'L4', 'S6', + 'S23'] + # For heartbeat epochs + iv_baseline = [-300 / 1000, -200 / 1000] + iv_epoch = [-400 / 1000, 600 / 1000] + + # Setting paths + input_path = "/data/pt_02569/tmp_data/prepared_py/"+subject_id+"/esg/prepro/" + input_path_m = "/data/pt_02569/tmp_data/prepared/" + subject_id + "/esg/prepro/" + fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" + raw = read_raw_fif(input_path + fname + '.fif', preload=True) + + # Read .mat file with QRS events + fname_m = f"raw_{sampling_rate}_spinal_{cond_name}" + matdata = loadmat(input_path_m+fname_m+'.mat') + QRSevents_m = matdata['QRSevents'] + + # Create filter coefficients + fs = sampling_rate + a = [0, 0, 1, 1] + f = [0, 0.4/(fs/2), 0.9/(fs / 2), 1] # 0.9 Hz highpass filter + # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter + ord = round(3*fs/0.5) + fwts = firls(ord+1, f, a) + + # run PCA_OBS + PCA_OBS_kwargs = dict( + qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate + ) + + events, event_ids = events_from_annotations(raw) + event_id_dict = {key: value for key, value in event_ids.items() if key == 'qrs'} + epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_before = epochs.average() + + # Apply function should modifies the data in raw in place + raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans)) + epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], + baseline=tuple(iv_baseline)) + evoked_after = epochs.average() + + # Comparison image + fig, axes = plt.subplots(2, 1) + axes[0].plot(evoked_before.times, evoked_before.get_data().T) + axes[0].set_ylim([-0.0005, 0.001]) + axes[0].set_title('Before PCA-OBS') + axes[1].plot(evoked_after.times, evoked_after.get_data().T) + axes[1].set_ylim([-0.0005, 0.001]) + axes[1].set_title('After PCA-OBS') + plt.tight_layout() + + # Comparison image + fig, axes = plt.subplots(1, 1) + axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') + axes.set_ylim([-0.0005, 0.001]) + axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') + axes.set_title('Before (black) versus after (green)') + plt.tight_layout() + plt.show()