diff --git a/doc/source/whats_new.rst b/doc/source/whats_new.rst index 6eea3335981..43ed18f9f12 100644 --- a/doc/source/whats_new.rst +++ b/doc/source/whats_new.rst @@ -27,6 +27,9 @@ Changelog - Add support for BEM solution computation :func:`mne.make_bem_solution` by `Eric Larson`_ + - Add ICA plotters for raw and epoch components by `Jaakko Leppakangas`_ + + BUG ~~~ diff --git a/examples/preprocessing/plot_ica_from_raw.py b/examples/preprocessing/plot_ica_from_raw.py index 1a8c96644dc..22d134b1c30 100644 --- a/examples/preprocessing/plot_ica_from_raw.py +++ b/examples/preprocessing/plot_ica_from_raw.py @@ -75,7 +75,7 @@ show_picks = np.abs(scores).argsort()[::-1][:5] -ica.plot_sources(raw, show_picks, exclude=ecg_inds, title=title % 'eog') +ica.plot_sources(raw, show_picks, exclude=eog_inds, title=title % 'eog') ica.plot_components(eog_inds, title=title % 'eog', colorbar=True) eog_inds = eog_inds[:n_max_eog] diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 4bb3cc7fcfe..65286e78144 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -1335,7 +1335,7 @@ def plot_components(self, picks=None, ch_type=None, res=64, layout=None, head_pos=head_pos) def plot_sources(self, inst, picks=None, exclude=None, start=None, - stop=None, title=None, show=True): + stop=None, title=None, show=True, block=False): """Plot estimated latent sources given the unmixing matrix. Typical usecases: @@ -1363,15 +1363,32 @@ def plot_sources(self, inst, picks=None, exclude=None, start=None, The figure title. If None a default is provided. show : bool If True, all open plots will be shown. + block : bool + Whether to halt program execution until the figure is closed. + Useful for interactive selection of components in raw and epoch + plotter. For evoked, this parameter has no effect. Defaults to + False. Returns ------- fig : instance of pyplot.Figure The figure. + + Notes + ----- + For raw and epoch instances, it is possible to select components for + exclusion by clicking on the line. The selected components are added to + ``ica.exclude`` on close. The independent components can be viewed as + topographies by clicking on the component name on the left of of the + main axes. The topography view tries to infer the correct electrode + layout from the data. This should work at least for Neuromag data. + + .. versionadded:: 0.10.0 """ return plot_ica_sources(self, inst=inst, picks=picks, exclude=exclude, - title=title, start=start, stop=stop, show=show) + title=title, start=start, stop=stop, show=show, + block=block) def plot_scores(self, scores, exclude=None, axhline=None, title='ICA component scores', figsize=(12, 6), diff --git a/mne/viz/__init__.py b/mne/viz/__init__.py index 2e3e0d6f657..c7f3df567bf 100644 --- a/mne/viz/__init__.py +++ b/mne/viz/__init__.py @@ -19,5 +19,6 @@ plot_epochs_trellis, _drop_log_stats, plot_epochs_psd) from .raw import plot_raw, plot_raw_psd from .ica import plot_ica_scores, plot_ica_sources, plot_ica_overlay +from .ica import _plot_sources_raw, _plot_sources_epochs from .montage import plot_montage from .decoding import plot_gat_matrix, plot_gat_times diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index b3f87fef0d9..cc5fd74f353 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -23,6 +23,7 @@ from ..time_frequency import compute_epochs_psd from .utils import tight_layout, _prepare_trellis, figure_nobar from .utils import _toggle_options, _toggle_proj, _layout_figure +from .utils import _channels_changed, _plot_raw_onscroll, _onclick_help from ..defaults import _handle_default @@ -498,29 +499,146 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, vertical line to the plot. """ import matplotlib.pyplot as plt + scalings = _handle_default('scalings_plot_raw', scalings) + + projs = epochs.info['projs'] + + params = {'epochs': epochs, + 'orig_data': np.concatenate(epochs.get_data(), axis=1), + 'info': copy.deepcopy(epochs.info), + 'bad_color': (0.8, 0.8, 0.8), + 't_start': 0} + params['label_click_fun'] = partial(_pick_bad_channels, params=params) + _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings, + title, picks) + + callback_close = partial(_close_event, params=params) + params['fig'].canvas.mpl_connect('close_event', callback_close) + if show: + try: + plt.show(block=block) + except TypeError: # not all versions have this + plt.show() + + return params['fig'] + + +@verbose +def plot_epochs_psd(epochs, fmin=0, fmax=np.inf, proj=False, n_fft=256, + picks=None, ax=None, color='black', area_mode='std', + area_alpha=0.33, n_overlap=0, + dB=True, n_jobs=1, show=True, verbose=None): + """Plot the power spectral density across epochs + + Parameters + ---------- + epochs : instance of Epochs + The epochs object + fmin : float + Start frequency to consider. + fmax : float + End frequency to consider. + proj : bool + Apply projection. + n_fft : int + Number of points to use in Welch FFT calculations. + picks : array-like of int | None + List of channels to use. + ax : instance of matplotlib Axes | None + Axes to plot into. If None, axes will be created. + color : str | tuple + A matplotlib-compatible color to use. + area_mode : str | None + Mode for plotting area. If 'std', the mean +/- 1 STD (across channels) + will be plotted. If 'range', the min and max (across channels) will be + plotted. Bad channels will be excluded from these calculations. + If None, no area will be plotted. + area_alpha : float + Alpha for the area. + n_overlap : int + The number of points of overlap between blocks. + dB : bool + If True, transform data to decibels. + n_jobs : int + Number of jobs to run in parallel. + show : bool + Show figure if True. + verbose : bool, str, int, or None + If not None, override default verbose level (see mne.verbose). + + Returns + ------- + fig : instance of matplotlib figure + Figure distributing one image per channel across sensor topography. + """ + import matplotlib.pyplot as plt + from .raw import _set_psd_plot_params + fig, picks_list, titles_list, ax_list, make_label = _set_psd_plot_params( + epochs.info, proj, picks, ax, area_mode) + + for ii, (picks, title, ax) in enumerate(zip(picks_list, titles_list, + ax_list)): + psds, freqs = compute_epochs_psd(epochs, picks=picks, fmin=fmin, + fmax=fmax, n_fft=n_fft, + n_overlap=n_overlap, proj=proj, + n_jobs=n_jobs) + + # Convert PSDs to dB + if dB: + psds = 10 * np.log10(psds) + unit = 'dB' + else: + unit = 'power' + # mean across epochs and channels + psd_mean = np.mean(psds, axis=0).mean(axis=0) + if area_mode == 'std': + # std across channels + psd_std = np.std(np.mean(psds, axis=0), axis=0) + hyp_limits = (psd_mean - psd_std, psd_mean + psd_std) + elif area_mode == 'range': + hyp_limits = (np.min(np.mean(psds, axis=0), axis=0), + np.max(np.mean(psds, axis=0), axis=0)) + else: # area_mode is None + hyp_limits = None + + ax.plot(freqs, psd_mean, color=color) + if hyp_limits is not None: + ax.fill_between(freqs, hyp_limits[0], y2=hyp_limits[1], + color=color, alpha=area_alpha) + if make_label: + if ii == len(picks_list) - 1: + ax.set_xlabel('Freq (Hz)') + if ii == len(picks_list) // 2: + ax.set_ylabel('Power Spectral Density (%s/Hz)' % unit) + ax.set_title(title) + ax.set_xlim(freqs[0], freqs[-1]) + if make_label: + tight_layout(pad=0.1, h_pad=0.1, w_pad=0.1, fig=fig) + if show: + plt.show() + return fig + + +def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings, + title, picks): + """Helper for setting up the mne_browse_epochs window.""" + import matplotlib.pyplot as plt import matplotlib as mpl from matplotlib.collections import LineCollection from matplotlib.colors import colorConverter - scalings = _handle_default('scalings_plot_raw', scalings) - color = _handle_default('color', None) - bad_color = (0.8, 0.8, 0.8) + epochs = params['epochs'] + if picks is None: picks = _handle_picks(epochs) if len(picks) < 1: raise RuntimeError('No appropriate channels found. Please' ' check your picks') - epoch_data = epochs.get_data() - - n_epochs = min(n_epochs, len(epochs.events)) - duration = len(epochs.times) * n_epochs - n_channels = min(n_channels, len(picks)) - projs = epochs.info['projs'] - + picks = sorted(picks) # Reorganize channels inds = list() types = list() for t in ['grad', 'mag']: - idxs = pick_types(epochs.info, meg=t, ref_meg=False, exclude=[]) + idxs = pick_types(params['info'], meg=t, ref_meg=False, exclude=[]) if len(idxs) < 1: continue mask = _in1d(idxs, picks, assume_unique=True) @@ -530,7 +648,7 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, for ch_type in ['eeg', 'eog', 'ecg', 'emg', 'ref_meg', 'stim', 'resp', 'misc', 'chpi', 'syst', 'ias', 'exci']: pick_kwargs[ch_type] = True - idxs = pick_types(epochs.info, **pick_kwargs) + idxs = pick_types(params['info'], **pick_kwargs) if len(idxs) < 1: continue mask = _in1d(idxs, picks, assume_unique=True) @@ -541,9 +659,13 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, if not len(inds) == len(picks): raise RuntimeError('Some channels not classified. Please' ' check your picks') + ch_names = [params['info']['ch_names'][x] for x in inds] # set up plotting size = get_config('MNE_BROWSE_RAW_SIZE') + n_epochs = min(n_epochs, len(epochs.events)) + duration = len(epochs.times) * n_epochs + n_channels = min(n_channels, len(picks)) if size is not None: size = size.split(',') size = tuple(float(s) for s in size) @@ -558,6 +680,7 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, ax.annotate(title, xy=(0.5, 1), xytext=(0, ax.get_ylim()[1] + 15), ha='center', va='bottom', size=12, xycoords='axes fraction', textcoords='offset points') + color = _handle_default('color', None) ax.axis([0, duration, 0, 200]) ax2 = ax.twiny() @@ -573,13 +696,17 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, ax_help_button = plt.subplot2grid((10, 15), (9, 0), colspan=1) help_button = mpl.widgets.Button(ax_help_button, 'Help') - help_button.on_clicked(_onclick_help) + help_button.on_clicked(partial(_onclick_help, params=params)) # populate vertical and horizontal scrollbars for ci in range(len(picks)): + if ch_names[ci] in params['info']['bads']: + this_color = params['bad_color'] + else: + this_color = color[types[ci]] ax_vscroll.add_patch(mpl.patches.Rectangle((0, ci), 1, 1, - facecolor=color[types[ci]], - edgecolor=color[types[ci]], + facecolor=this_color, + edgecolor=this_color, zorder=3)) vsel_patch = mpl.patches.Rectangle((0, 0), 1, n_channels, alpha=0.5, @@ -590,7 +717,6 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, ax_vscroll.set_title('Ch.') # populate colors list - ch_names = [epochs.info['ch_names'][x] for x in inds] type_colors = [colorConverter.to_rgba(color[c]) for c in types] colors = list() for color_idx in range(len(type_colors)): @@ -606,11 +732,10 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, ax.add_collection(lc) lines.append(lc) - epoch_data = np.concatenate(epoch_data, axis=1) times = epochs.times - data = np.zeros((epochs.info['nchan'], len(times) * len(epochs.events))) + data = np.zeros((params['info']['nchan'], len(times) * len(epochs.events))) - ylim = (25., 0.) + ylim = (25., 0.) # Hardcoded 25 because butterfly has max 5 rows (5*5=25). # make shells for plotting traces offset = ylim[0] / n_channels offsets = np.arange(n_channels) * offset + (offset / 2.) @@ -653,53 +778,53 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, ha='left', fontweight='bold') text.set_visible(False) - params = {'fig': fig, - 'ax': ax, - 'ax2': ax2, - 'ax_hscroll': ax_hscroll, - 'ax_vscroll': ax_vscroll, - 'vsel_patch': vsel_patch, - 'hsel_patch': hsel_patch, - 'epochs': epochs, - 'info': copy.deepcopy(epochs.info), # needed for projs - 'lines': lines, - 'n_channels': n_channels, - 'n_epochs': n_epochs, - 'ch_start': 0, - 't_start': 0, - 'duration': duration, - 'colors': colors, - 'def_colors': type_colors, # don't change at runtime - 'picks': picks, - 'bad_color': bad_color, - 'bads': np.array(list(), dtype=int), - 'ch_names': ch_names, - 'data': data, - 'orig_data': epoch_data, - 'times': times, - 'epoch_times': epoch_times, - 'offsets': offsets, - 'labels': labels, - 'projs': projs, - 'scale_factor': 1.0, - 'butterfly_scale': 1.0, - 'fig_proj': None, - 'inds': inds, - 'scalings': scalings, - 'types': np.array(types), - 'vert_lines': list(), - 'vertline_t': vertline_t, - 'butterfly': False, - 'text': text, - 'ax_help_button': ax_help_button, - 'fig_options': None, - 'settings': [True, True, True, True]} # for options dialog + params.update({'fig': fig, + 'ax': ax, + 'ax2': ax2, + 'ax_hscroll': ax_hscroll, + 'ax_vscroll': ax_vscroll, + 'vsel_patch': vsel_patch, + 'hsel_patch': hsel_patch, + 'lines': lines, + 'projs': projs, + 'ch_names': ch_names, + 'n_channels': n_channels, + 'n_epochs': n_epochs, + 'scalings': scalings, + 'types': types, + 'duration': duration, + 'ch_start': 0, + 'colors': colors, + 'def_colors': type_colors, # don't change at runtime + 'picks': picks, + 'bads': np.array(list(), dtype=int), + 'data': data, + 'times': times, + 'epoch_times': epoch_times, + 'offsets': offsets, + 'labels': labels, + 'scale_factor': 1.0, + 'butterfly_scale': 1.0, + 'fig_proj': None, + 'types': np.array(types), + 'inds': inds, + 'vert_lines': list(), + 'vertline_t': vertline_t, + 'butterfly': False, + 'text': text, + 'ax_help_button': ax_help_button, # needed for positioning + 'help_button': help_button, # reference needed for clicks + 'fig_options': None, + 'settings': [True, True, True, True]}) + + params['plot_fun'] = partial(_plot_traces, params=params) if len(projs) > 0 and not epochs.proj: ax_button = plt.subplot2grid((10, 15), (9, 14)) opt_button = mpl.widgets.Button(ax_button, 'Proj') callback_option = partial(_toggle_options, params=params) opt_button.on_clicked(callback_option) + params['opt_button'] = opt_button params['ax_button'] = ax_button # callbacks @@ -709,8 +834,6 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, fig.canvas.mpl_connect('button_press_event', callback_click) callback_key = partial(_plot_onkey, params=params) fig.canvas.mpl_connect('key_press_event', callback_key) - callback_close = partial(_close_event, params=params) - fig.canvas.mpl_connect('close_event', callback_close) callback_resize = partial(_resize_event, params=params) fig.canvas.mpl_connect('resize_event', callback_resize) fig.canvas.mpl_connect('pick_event', partial(_onpick, params=params)) @@ -730,110 +853,6 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, callback_proj('none') _layout_figure(params) - if show: - try: - plt.show(block=block) - except TypeError: # not all versions have this - plt.show() - - return fig - - -@verbose -def plot_epochs_psd(epochs, fmin=0, fmax=np.inf, proj=False, n_fft=256, - picks=None, ax=None, color='black', area_mode='std', - area_alpha=0.33, n_overlap=0, - dB=True, n_jobs=1, show=True, verbose=None): - """Plot the power spectral density across epochs - - Parameters - ---------- - epochs : instance of Epochs - The epochs object - fmin : float - Start frequency to consider. - fmax : float - End frequency to consider. - proj : bool - Apply projection. - n_fft : int - Number of points to use in Welch FFT calculations. - picks : array-like of int | None - List of channels to use. - ax : instance of matplotlib Axes | None - Axes to plot into. If None, axes will be created. - color : str | tuple - A matplotlib-compatible color to use. - area_mode : str | None - Mode for plotting area. If 'std', the mean +/- 1 STD (across channels) - will be plotted. If 'range', the min and max (across channels) will be - plotted. Bad channels will be excluded from these calculations. - If None, no area will be plotted. - area_alpha : float - Alpha for the area. - n_overlap : int - The number of points of overlap between blocks. - dB : bool - If True, transform data to decibels. - n_jobs : int - Number of jobs to run in parallel. - show : bool - Show figure if True. - verbose : bool, str, int, or None - If not None, override default verbose level (see mne.verbose). - - Returns - ------- - fig : instance of matplotlib figure - Figure distributing one image per channel across sensor topography. - """ - import matplotlib.pyplot as plt - from .raw import _set_psd_plot_params - fig, picks_list, titles_list, ax_list, make_label = _set_psd_plot_params( - epochs.info, proj, picks, ax, area_mode) - - for ii, (picks, title, ax) in enumerate(zip(picks_list, titles_list, - ax_list)): - psds, freqs = compute_epochs_psd(epochs, picks=picks, fmin=fmin, - fmax=fmax, n_fft=n_fft, - n_overlap=n_overlap, proj=proj, - n_jobs=n_jobs) - - # Convert PSDs to dB - if dB: - psds = 10 * np.log10(psds) - unit = 'dB' - else: - unit = 'power' - # mean across epochs and channels - psd_mean = np.mean(psds, axis=0).mean(axis=0) - if area_mode == 'std': - # std across channels - psd_std = np.std(np.mean(psds, axis=0), axis=0) - hyp_limits = (psd_mean - psd_std, psd_mean + psd_std) - elif area_mode == 'range': - hyp_limits = (np.min(np.mean(psds, axis=0), axis=0), - np.max(np.mean(psds, axis=0), axis=0)) - else: # area_mode is None - hyp_limits = None - - ax.plot(freqs, psd_mean, color=color) - if hyp_limits is not None: - ax.fill_between(freqs, hyp_limits[0], y2=hyp_limits[1], - color=color, alpha=area_alpha) - if make_label: - if ii == len(picks_list) - 1: - ax.set_xlabel('Freq (Hz)') - if ii == len(picks_list) // 2: - ax.set_ylabel('Power Spectral Density (%s/Hz)' % unit) - ax.set_title(title) - ax.set_xlim(freqs[0], freqs[-1]) - if make_label: - tight_layout(pad=0.1, h_pad=0.1, w_pad=0.1, fig=fig) - if show: - plt.show() - return fig - def _plot_traces(params): """ Helper for plotting concatenated epochs """ @@ -869,16 +888,16 @@ def _plot_traces(params): break elif ch_idx < len(params['ch_names']): if butterfly: - type = params['types'][ch_idx] - if type == 'grad': + ch_type = params['types'][ch_idx] + if ch_type == 'grad': offset = offsets[0] - elif type == 'mag': + elif ch_type == 'mag': offset = offsets[1] - elif type == 'eeg': + elif ch_type == 'eeg': offset = offsets[2] - elif type == 'eog': + elif ch_type == 'eog': offset = offsets[3] - elif type == 'ecg': + elif ch_type == 'ecg': offset = offsets[4] else: lines[line_idx].set_segments(list()) @@ -895,7 +914,7 @@ def _plot_traces(params): segments = np.split(np.array((xdata, ydata)).T, num_epochs) ch_name = params['ch_names'][ch_idx] - if ch_name in params['epochs'].info['bads']: + if ch_name in params['info']['bads']: if not butterfly: this_color = params['bad_color'] ylabels[line_idx].set_color(this_color) @@ -988,7 +1007,7 @@ def _plot_update_epochs_proj(params, bools): types = params['types'] for pick, ind in enumerate(params['inds']): params['data'][pick] = data[ind] / params['scalings'][types[pick]] - _plot_traces(params) + params['plot_fun']() def _handle_picks(epochs): @@ -1012,18 +1031,7 @@ def _plot_window(value, params): if params['t_start'] != value: params['t_start'] = value params['hsel_patch'].set_x(value) - _plot_traces(params) - - -def _channels_changed(params): - """Deal with vertical shift of the viewport.""" - if params['butterfly']: - return - if params['ch_start'] + params['n_channels'] > len(params['ch_names']): - params['ch_start'] = len(params['ch_names']) - params['n_channels'] - elif params['ch_start'] < 0: - params['ch_start'] = 0 - _plot_traces(params) + params['plot_fun']() def _plot_vert_lines(params): @@ -1031,6 +1039,9 @@ def _plot_vert_lines(params): ax = params['ax'] while len(ax.lines) > 0: ax.lines.pop() + params['vert_lines'] = list() + params['vertline_t'].set_text('') + epochs = params['epochs'] if params['settings'][3]: # if zeroline visible t_zero = np.where(epochs.times == 0.)[0] @@ -1046,6 +1057,10 @@ def _plot_vert_lines(params): def _pick_bad_epochs(event, params): """Helper for selecting / dropping bad epochs""" + if 'ica' in params: + pos = (event.xdata, event.ydata) + _pick_bad_channels(pos, params) + return n_times = len(params['epochs'].times) start_idx = int(params['t_start'] / n_times) xdata = event.xdata @@ -1061,7 +1076,7 @@ def _pick_bad_epochs(event, params): params['colors'][ch_idx][epoch_idx] = params['def_colors'][ch_idx] params['ax_hscroll'].patches[epoch_idx].set_color('w') params['ax_hscroll'].patches[epoch_idx].set_zorder(1) - _plot_traces(params) + params['plot_fun']() return # add bad epoch params['bads'] = np.append(params['bads'], epoch_idx) @@ -1070,7 +1085,31 @@ def _pick_bad_epochs(event, params): params['ax_hscroll'].patches[epoch_idx].set_edgecolor('w') for ch_idx in range(len(params['ch_names'])): params['colors'][ch_idx][epoch_idx] = (1., 0., 0., 1.) - _plot_traces(params) + params['plot_fun']() + + +def _pick_bad_channels(pos, params): + """Helper function for selecting bad channels.""" + labels = params['ax'].yaxis.get_ticklabels() + offsets = np.array(params['offsets']) + params['offsets'][0] + line_idx = np.searchsorted(offsets, pos[1]) + text = labels[line_idx].get_text() + if len(text) == 0: + return + ch_idx = params['ch_start'] + line_idx + if text in params['info']['bads']: + while text in params['info']['bads']: + params['info']['bads'].remove(text) + color = params['def_colors'][ch_idx] + params['ax_vscroll'].patches[ch_idx + 1].set_color(color) + else: + params['info']['bads'].append(text) + color = params['bad_color'] + params['ax_vscroll'].patches[ch_idx + 1].set_color(color) + if 'ica' in params: + params['plot_fun']() + else: + params['plot_update_proj_callback'](params, None) def _plot_onscroll(event, params): @@ -1082,15 +1121,9 @@ def _plot_onscroll(event, params): event.key = '+' _plot_onkey(event, params) return - orig_start = params['ch_start'] - if event.step < 0: - params['ch_start'] = min(params['ch_start'] + params['n_channels'], - len(params['ch_names']) - - params['n_channels']) - else: - params['ch_start'] = max(params['ch_start'] - params['n_channels'], 0) - if orig_start != params['ch_start']: - _channels_changed(params) + if params['butterfly']: + return + _plot_raw_onscroll(event, params, len(params['ch_names'])) def _mouse_click(event, params): @@ -1103,20 +1136,7 @@ def _mouse_click(event, params): pos = ax.transData.inverted().transform((event.x, event.y)) if pos[0] > 0 or pos[1] < 0 or pos[1] > ylim[0]: return - labels = ax.yaxis.get_ticklabels() - offsets = np.array(params['offsets']) + params['offsets'][0] - line_idx = np.searchsorted(offsets, pos[1]) - text = labels[line_idx].get_text() - ch_idx = params['ch_start'] + line_idx - if text in params['epochs'].info['bads']: - params['epochs'].info['bads'].remove(text) - color = params['def_colors'][ch_idx] - params['ax_vscroll'].patches[ch_idx + 1].set_color(color) - else: - params['epochs'].info['bads'].append(text) - color = params['bad_color'] - params['ax_vscroll'].patches[ch_idx + 1].set_color(color) - _plot_traces(params) + params['label_click_fun'](pos) elif event.button == 1: # left click # vertical scroll bar changed if event.inaxes == params['ax_vscroll']: @@ -1125,7 +1145,7 @@ def _mouse_click(event, params): ch_start = max(int(event.ydata) - params['n_channels'] // 2, 0) if params['ch_start'] != ch_start: params['ch_start'] = ch_start - _plot_traces(params) + params['plot_fun']() # horizontal scroll bar changed elif event.inaxes == params['ax_hscroll']: # find the closest epoch time @@ -1150,28 +1170,32 @@ def _mouse_click(event, params): while len(params['vert_lines']) > 0: params['ax'].lines.remove(params['vert_lines'][0][0]) params['vert_lines'].pop(0) - if prev_xdata == xdata: + if prev_xdata == xdata: # lines removed params['vertline_t'].set_text('') - _plot_traces(params) + params['plot_fun']() return ylim = params['ax'].get_ylim() - for epoch_idx in range(params['n_epochs']): + for epoch_idx in range(params['n_epochs']): # plot lines pos = [epoch_idx * n_times + xdata, epoch_idx * n_times + xdata] params['vert_lines'].append(params['ax'].plot(pos, ylim, 'y', zorder=4)) params['vertline_t'].set_text('%0.3f' % params['epochs'].times[xdata]) - _plot_traces(params) + params['plot_fun']() def _plot_onkey(event, params): """Function to handle key presses.""" import matplotlib.pyplot as plt if event.key == 'down': + if params['butterfly']: + return params['ch_start'] += params['n_channels'] - _channels_changed(params) + _channels_changed(params, len(params['ch_names'])) elif event.key == 'up': + if params['butterfly']: + return params['ch_start'] -= params['n_channels'] - _channels_changed(params) + _channels_changed(params, len(params['ch_names'])) elif event.key == 'left': sample = params['t_start'] - params['duration'] sample = np.max([0, sample]) @@ -1187,13 +1211,13 @@ def _plot_onkey(event, params): params['butterfly_scale'] /= 1.1 else: params['scale_factor'] /= 1.1 - _plot_traces(params) + params['plot_fun']() elif event.key in ['+', '=']: if params['butterfly']: params['butterfly_scale'] *= 1.1 else: params['scale_factor'] *= 1.1 - _plot_traces(params) + params['plot_fun']() elif event.key == 'f11': mng = plt.get_current_fig_manager() mng.full_screen_toggle() @@ -1209,7 +1233,7 @@ def _plot_onkey(event, params): params['ax'].set_yticks(params['offsets']) params['lines'].pop() params['vsel_patch'].set_height(n_channels) - _plot_traces(params) + params['plot_fun']() elif event.key == 'pageup': if params['butterfly']: return @@ -1225,7 +1249,7 @@ def _plot_onkey(event, params): params['ax'].set_yticks(params['offsets']) params['lines'].append(lc) params['vsel_patch'].set_height(n_channels) - _plot_traces(params) + params['plot_fun']() elif event.key == 'home': n_epochs = params['n_epochs'] - 1 if n_epochs <= 0: @@ -1236,7 +1260,7 @@ def _plot_onkey(event, params): params['n_epochs'] = n_epochs params['duration'] -= n_times params['hsel_patch'].set_width(params['duration']) - _plot_traces(params) + params['plot_fun']() elif event.key == 'end': n_epochs = params['n_epochs'] + 1 n_times = len(params['epochs'].times) @@ -1258,7 +1282,7 @@ def _plot_onkey(event, params): params['t_start'] -= n_times params['hsel_patch'].set_x(params['t_start']) params['hsel_patch'].set_width(params['duration']) - _plot_traces(params) + params['plot_fun']() elif event.key == 'b': if params['fig_options'] is not None: plt.close(params['fig_options']) @@ -1269,7 +1293,7 @@ def _plot_onkey(event, params): if not params['butterfly']: _open_options(params) elif event.key == '?': - _onclick_help(event) + _onclick_help(event, params) elif event.key == 'escape': plt.close(params['fig']) @@ -1379,6 +1403,7 @@ def _close_event(event, params): """Function to drop selected bad epochs. Called on closing of the plot.""" params['epochs'].drop_epochs(params['bads']) logger.info('Channels marked as bad: %s' % params['epochs'].info['bads']) + params['epochs'].info['bads'] = params['info']['bads'] def _resize_event(event, params): @@ -1388,79 +1413,6 @@ def _resize_event(event, params): _layout_figure(params) -def _onclick_help(event): - """Function for drawing help window""" - import matplotlib.pyplot as plt - - text = u'\u2190 : \n'\ - u'\u2192 : \n'\ - u'\u2193 : \n'\ - u'\u2191 : \n'\ - u'- : \n'\ - u'+ or = : \n'\ - u'Home : \n'\ - u'End : \n'\ - u'Page down : \n'\ - u'Page up : \n'\ - u'b : \n'\ - u'o : \n'\ - u'F11 : \n'\ - u'? : \n'\ - u'Esc : \n\n'\ - u'Mouse controls\n'\ - u'click epoch :\n'\ - u'click channel name :\n'\ - u'right click :\n'\ - u'middle click :\n' - - text2 = 'Navigate left\n'\ - 'Navigate right\n'\ - 'Navigate channels down\n'\ - 'Navigate channels up\n'\ - 'Scale down\n'\ - 'Scale up\n'\ - 'Reduce the number of epochs per view\n'\ - 'Increase the number of epochs per view\n'\ - 'Reduce the number of channels per view\n'\ - 'Increase the number of channels per view\n'\ - 'Toggle butterfly plot on/off\n'\ - 'View settings (orig. view only)\n'\ - 'Toggle full screen mode\n'\ - 'Open help box\n'\ - 'Quit\n\n\n'\ - 'Mark bad epoch\n'\ - 'Mark bad channel\n'\ - 'Verticlal line at a time instant\n'\ - 'Show channel name (butterfly plot)\n' - - width = 5.5 - height = 0.25 * 19 # 19 rows of text - - fig_help = figure_nobar(figsize=(width, height), dpi=80) - fig_help.canvas.set_window_title('Help') - ax = plt.subplot2grid((8, 5), (0, 0), colspan=5) - ax.set_title('Keyboard shortcuts') - plt.axis('off') - ax1 = plt.subplot2grid((8, 5), (1, 0), rowspan=7, colspan=2) - ax1.set_yticklabels(list()) - plt.text(0.99, 1, text, fontname='STIXGeneral', va='top', weight='bold', - ha='right') - plt.axis('off') - - ax2 = plt.subplot2grid((8, 5), (1, 2), rowspan=7, colspan=3) - ax2.set_yticklabels(list()) - plt.text(0, 1, text2, fontname='STIXGeneral', va='top') - plt.axis('off') - - tight_layout(fig=fig_help) - # this should work for non-test cases - try: - fig_help.canvas.draw() - fig_help.show() - except Exception: - pass - - def _update_channels_epochs(event, params): """Function for changing the amount of channels and epochs per view.""" from matplotlib.collections import LineCollection diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 0309e670835..7b748f7919c 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -12,8 +12,17 @@ import numpy as np -from .utils import tight_layout, _prepare_trellis +from .utils import tight_layout, _prepare_trellis, _select_bads +from .utils import _layout_figure, _plot_raw_onscroll, _mouse_click +from .utils import _helper_raw_resize, _plot_raw_onkey +from .raw import _prepare_mne_browse_raw, _plot_raw_traces +from .epochs import _prepare_mne_browse_epochs from .evoked import _butterfly_on_button_press, _butterfly_onpick +from .topomap import _prepare_topo_plot, plot_topomap +from ..utils import logger +from ..defaults import _handle_default +from ..io.meas_info import create_info +from ..io.pick import pick_types def _ica_plot_sources_onpick_(event, sources=None, ylims=None): @@ -41,7 +50,7 @@ def _ica_plot_sources_onpick_(event, sources=None, ylims=None): def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None, - stop=None, show=True, title=None): + stop=None, show=True, title=None, block=False): """Plot estimated latent sources given the unmixing matrix. Typical usecases: @@ -64,18 +73,31 @@ def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None, The components marked for exclusion. If None (default), ICA.exclude will be used. start : int - X-axis start index. If None from the beginning. + X-axis start index. If None, from the beginning. stop : int - X-axis stop index. If None to the end. + X-axis stop index. If None, next 20 are shown, in case of evoked to the + end. show : bool Show figure if True. title : str | None The figure title. If None a default is provided. + block : bool + Whether to halt program execution until the figure is closed. + Useful for interactive selection of components in raw and epoch + plotter. For evoked, this parameter has no effect. Defaults to False. Returns ------- fig : instance of pyplot.Figure The figure. + + Notes + ----- + For raw and epoch instances, it is possible to select components for + exclusion by clicking on the line. The selected components are added to + ``ica.exclude`` on close. + + .. versionadded:: 0.10.0 """ from ..io.base import _BaseRaw @@ -84,24 +106,16 @@ def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None, if exclude is None: exclude = ica.exclude - - if isinstance(inst, (_BaseRaw, _BaseEpochs)): - if isinstance(inst, _BaseRaw): - sources = ica._transform_raw(inst, start, stop) - else: - if start is not None or stop is not None: - inst = inst.crop(start, stop, copy=True) - sources = ica._transform_epochs(inst, concatenate=True) - if picks is not None: - if np.isscalar(picks): - picks = [picks] - sources = np.atleast_2d(sources[picks]) - - fig = _plot_ica_grid(sources, start=start, stop=stop, - ncol=len(sources) // 10 or 1, - exclude=exclude, - source_idx=picks, - title=title, show=show) + elif len(ica.exclude) > 0: + exclude = np.union1d(ica.exclude, exclude) + if isinstance(inst, _BaseRaw): + fig = _plot_sources_raw(ica, inst, picks, exclude, start=start, + stop=stop, show=show, title=title, + block=block) + elif isinstance(inst, _BaseEpochs): + fig = _plot_sources_epochs(ica, inst, picks, exclude, start=start, + stop=stop, show=show, title=title, + block=block) elif isinstance(inst, Evoked): sources = ica.get_sources(inst) if start is not None or stop is not None: @@ -234,7 +248,7 @@ def _plot_ica_sources_evoked(evoked, picks, exclude, title, show): ax.set_xlim(times[[0, -1]]) ax.set_xlabel('Time (ms)') ax.set_ylabel('(NA)') - if exclude: + if len(exclude) > 0: plt.legend(loc='best') tight_layout(fig=fig) @@ -315,7 +329,7 @@ def plot_ica_scores(ica, scores, exclude=None, axhline=None, plt.suptitle(title) for this_scores, ax in zip(scores, axes): if len(my_range) != len(this_scores): - raise ValueError('The length ofr `scores` must equal the ' + raise ValueError('The length of `scores` must equal the ' 'number of ICA components.') ax.bar(my_range, this_scores, color='w') for excl in exclude: @@ -510,3 +524,199 @@ def _plot_ica_overlay_evoked(evoked, evoked_cln, title, show): plt.show() return fig + + +def _plot_sources_raw(ica, raw, picks, exclude, start, stop, show, title, + block): + """Function for plotting the ICA components as raw array.""" + import matplotlib.pyplot as plt + plt.ion() + color = _handle_default('color', (0., 0., 0.)) + orig_data = ica._transform_raw(raw, 0, len(raw.times)) * 0.2 + if picks is None: + picks = range(len(orig_data)) + types = ['misc' for _ in picks] + picks = list(sorted(picks)) + data = [orig_data[pick] for pick in picks] + c_names = ['ICA %03d' % x for x in range(len(orig_data))] + if title is None: + title = 'ICA components' + info = create_info([c_names[x] for x in picks], raw.info['sfreq']) + + info['bads'] = [c_names[x] for x in exclude] + if start is None: + start = 0 + if stop is None: + stop = start + 20 + stop = min(stop, raw.times[-1]) + duration = stop - start + if duration <= 0: + raise RuntimeError('Stop must be larger than start.') + t_end = int(duration * raw.info['sfreq']) + times = raw.times[0:t_end] + bad_color = (1., 0., 0.) + inds = list(range(len(picks))) + data = np.array(data) + n_channels = min([20, len(picks)]) + params = dict(raw=raw, orig_data=data, data=data[:, 0:t_end], + ch_start=0, t_start=start, info=info, duration=duration, + ica=ica, n_channels=n_channels, times=times, types=types, + n_times=raw.n_times, bad_color=bad_color, picks=picks) + _prepare_mne_browse_raw(params, title, 'w', color, bad_color, inds, + n_channels) + params['scale_factor'] = 1.0 + params['plot_fun'] = partial(_plot_raw_traces, params=params, inds=inds, + color=color, bad_color=bad_color) + params['update_fun'] = partial(_update_data, params) + params['pick_bads_fun'] = partial(_pick_bads, params=params) + params['label_click_fun'] = partial(_label_clicked, params=params) + _layout_figure(params) + # callbacks + callback_key = partial(_plot_raw_onkey, params=params) + params['fig'].canvas.mpl_connect('key_press_event', callback_key) + callback_scroll = partial(_plot_raw_onscroll, params=params) + params['fig'].canvas.mpl_connect('scroll_event', callback_scroll) + callback_pick = partial(_mouse_click, params=params) + params['fig'].canvas.mpl_connect('button_press_event', callback_pick) + callback_resize = partial(_helper_raw_resize, params=params) + params['fig'].canvas.mpl_connect('resize_event', callback_resize) + callback_close = partial(_close_event, params=params) + params['fig'].canvas.mpl_connect('close_event', callback_close) + params['fig_proj'] = None + params['event_times'] = None + params['update_fun']() + params['plot_fun']() + if show: + try: + plt.show(block=block) + except TypeError: # not all versions have this + plt.show() + + return params['fig'] + + +def _update_data(params): + """Function for preparing the data on horizontal shift of the viewport.""" + sfreq = params['info']['sfreq'] + start = int(params['t_start'] * sfreq) + end = int((params['t_start'] + params['duration']) * sfreq) + params['data'] = params['orig_data'][:, start:end] + params['times'] = params['raw'].times[start:end] + + +def _pick_bads(event, params): + """Function for selecting components on click.""" + bads = params['info']['bads'] + params['info']['bads'] = _select_bads(event, params, bads) + params['update_fun']() + params['plot_fun']() + + +def _close_event(events, params): + """Function for excluding the selected components on close.""" + info = params['info'] + c_names = ['ICA %03d' % x for x in range(params['ica'].n_components_)] + exclude = [c_names.index(x) for x in info['bads']] + params['ica'].exclude = exclude + + +def _plot_sources_epochs(ica, epochs, picks, exclude, start, stop, show, + title, block): + """Function for plotting the components as epochs.""" + import matplotlib.pyplot as plt + plt.ion() # Turn interactive mode on to avoid warnings. + data = ica._transform_epochs(epochs, concatenate=True) + c_names = ['ICA %03d' % x for x in range(ica.n_components_)] + scalings = {'misc': 5.0} + info = create_info(ch_names=c_names, sfreq=epochs.info['sfreq']) + info['projs'] = list() + info['bads'] = [c_names[x] for x in exclude] + if title is None: + title = 'ICA components' + if picks is None: + picks = list(range(len(c_names))) + if start is None: + start = 0 + if stop is None: + stop = start + 20 + stop = min(stop, len(epochs.events)) + n_epochs = stop - start + if n_epochs <= 0: + raise RuntimeError('Stop must be larger than start.') + params = {'ica': ica, + 'epochs': epochs, + 'info': info, + 'orig_data': data, + 'bads': list(), + 'bad_color': (1., 0., 0.), + 't_start': start * len(epochs.times)} + params['label_click_fun'] = partial(_label_clicked, params=params) + _prepare_mne_browse_epochs(params, projs=list(), n_channels=20, + n_epochs=n_epochs, scalings=scalings, + title=title, picks=picks) + params['hsel_patch'].set_x(params['t_start']) + callback_close = partial(_close_epochs_event, params=params) + params['fig'].canvas.mpl_connect('close_event', callback_close) + if show: + try: + plt.show(block=block) + except TypeError: # not all versions have this + plt.show() + + return params['fig'] + + +def _close_epochs_event(events, params): + """Function for excluding the selected components on close.""" + info = params['info'] + exclude = [info['ch_names'].index(x) for x in info['bads']] + params['ica'].exclude = exclude + + +def _label_clicked(pos, params): + """Function for plotting independent components on click to label.""" + import matplotlib.pyplot as plt + offsets = np.array(params['offsets']) + params['offsets'][0] + line_idx = np.searchsorted(offsets, pos[1]) + params['ch_start'] + if line_idx >= len(params['picks']): + return + ic_idx = [params['picks'][line_idx]] + types = list() + info = params['ica'].info + if len(pick_types(info, meg=False, eeg=True, ref_meg=False)) > 0: + types.append('eeg') + if len(pick_types(info, meg='mag', ref_meg=False)) > 0: + types.append('mag') + if len(pick_types(info, meg='grad', ref_meg=False)) > 0: + types.append('grad') + + ica = params['ica'] + data = np.dot(ica.mixing_matrix_[:, ic_idx].T, + ica.pca_components_[:ica.n_components_]) + data = np.atleast_2d(data) + fig, axes = _prepare_trellis(len(types), max_col=3) + for ch_idx, ch_type in enumerate(types): + try: + data_picks, pos, merge_grads, _, _ = _prepare_topo_plot(ica, + ch_type, + None) + except Exception as exc: + logger.warning(exc) + plt.close(fig) + return + this_data = data[:, data_picks] + ax = axes[ch_idx] + if merge_grads: + from ..channels.layout import _merge_grad_data + for ii, data_ in zip(ic_idx, this_data): + ax.set_title('IC #%03d ' % ii + ch_type, fontsize=12) + data_ = _merge_grad_data(data_) if merge_grads else data_ + plot_topomap(data_.flatten(), pos, axis=ax, show=False)[0] + ax.set_yticks([]) + ax.set_xticks([]) + ax.set_frame_on(False) + tight_layout(fig=fig) + fig.subplots_adjust(top=0.95) + fig.canvas.draw() + + plt.show() diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 753c910a126..e4bdf839656 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -15,10 +15,12 @@ from ..externals.six import string_types from ..io.pick import pick_types from ..io.proj import setup_proj -from ..utils import set_config, get_config, verbose +from ..utils import verbose, get_config from ..time_frequency import compute_raw_psd -from .utils import figure_nobar, _toggle_options, _toggle_proj, tight_layout -from .utils import _layout_figure +from .utils import _toggle_options, _toggle_proj, tight_layout +from .utils import _layout_figure, _plot_raw_onkey, figure_nobar +from .utils import _plot_raw_onscroll, _mouse_click +from .utils import _helper_raw_resize, _select_bads, _onclick_help from ..defaults import _handle_default @@ -31,7 +33,7 @@ def _plot_update_raw_proj(params, bools): params['proj_bools'] = bools params['projector'], _ = setup_proj(params['info'], add_eeg_ref=False, verbose=False) - _update_raw_data(params) + params['update_fun']() params['plot_fun']() @@ -66,266 +68,14 @@ def _update_raw_data(params): params['times'] = times -def _helper_resize(event, params): - """Helper for resizing""" - size = ','.join([str(s) for s in params['fig'].get_size_inches()]) - set_config('MNE_BROWSE_RAW_SIZE', size) - _layout_figure(params) - - def _pick_bad_channels(event, params): """Helper for selecting / dropping bad channels onpick""" + # Both bad lists are updated. params['info'] used for colors. bads = params['raw'].info['bads'] - - # trade-off, avoid selecting more than one channel when drifts are present - # however for clean data don't click on peaks but on flat segments - def f(x, y): - return y(np.mean(x), x.std() * 2) - for l in event.inaxes.lines: - ydata = l.get_ydata() - if not isinstance(ydata, list) and not np.isnan(ydata).any(): - ymin, ymax = f(ydata, np.subtract), f(ydata, np.add) - if ymin <= event.ydata <= ymax: - this_chan = vars(l)['ch_name'] - if this_chan in params['raw'].ch_names: - if this_chan not in bads: - bads.append(this_chan) - l.set_color(params['bad_color']) - l.set_zorder(-1) - else: - bads.pop(bads.index(this_chan)) - l.set_color(vars(l)['def_color']) - l.set_zorder(0) - break - else: - x = np.array([event.xdata] * 2) - params['ax_vertline'].set_data(x, np.array(params['ax'].get_ylim())) - params['ax_hscroll_vertline'].set_data(x, np.array([0., 1.])) - params['vertline_t'].set_text('%0.3f' % x[0]) - # update deep-copied info to persistently draw bads - params['info']['bads'] = bads + params['info']['bads'] = _select_bads(event, params, bads) _plot_update_raw_proj(params, None) -def _mouse_click(event, params): - """Vertical select callback""" - if event.inaxes is None or event.button != 1: - return - # vertical scrollbar changed - if event.inaxes == params['ax_vscroll']: - ch_start = max(int(event.ydata) - params['n_channels'] // 2, 0) - if params['ch_start'] != ch_start: - params['ch_start'] = ch_start - params['plot_fun']() - # horizontal scrollbar changed - elif event.inaxes == params['ax_hscroll']: - _plot_raw_time(event.xdata - params['duration'] / 2, params) - - elif event.inaxes == params['ax']: - _pick_bad_channels(event, params) - - -def _plot_raw_time(value, params): - """Deal with changed time value""" - info = params['info'] - max_times = params['n_times'] / float(info['sfreq']) - params['duration'] - if value > max_times: - value = params['n_times'] / info['sfreq'] - params['duration'] - if value < 0: - value = 0 - if params['t_start'] != value: - params['t_start'] = value - params['hsel_patch'].set_x(value) - _update_raw_data(params) - params['plot_fun']() - - -def _plot_raw_onkey(event, params): - """Interpret key presses""" - import matplotlib.pyplot as plt - # check for initial plot - if event is None: - params['plot_fun']() - return - - # quit event - if event.key == 'escape': - plt.close(params['fig']) - return - - # change plotting params - ch_changed = False - if event.key == 'down': - params['ch_start'] += params['n_channels'] - ch_changed = True - elif event.key == 'up': - params['ch_start'] -= params['n_channels'] - ch_changed = True - elif event.key == 'right': - _plot_raw_time(params['t_start'] + params['duration'], params) - return - elif event.key == 'left': - _plot_raw_time(params['t_start'] - params['duration'], params) - return - elif event.key in ['o', 'p']: - _toggle_options(None, params) - return - elif event.key in ['+', '=']: - params['scale_factor'] *= 1.1 - params['plot_fun']() - return - elif event.key == '-': - params['scale_factor'] /= 1.1 - params['plot_fun']() - return - elif event.key == 'pageup': - n_channels = params['n_channels'] + 1 - offset = params['ax'].get_ylim()[0] / n_channels - params['offsets'] = np.arange(n_channels) * offset + (offset / 2.) - params['n_channels'] = n_channels - params['ax'].set_yticks(params['offsets']) - params['vsel_patch'].set_height(n_channels) - ch_changed = True - elif event.key == 'pagedown': - n_channels = params['n_channels'] - 1 - if n_channels == 0: - return - offset = params['ax'].get_ylim()[0] / n_channels - params['offsets'] = np.arange(n_channels) * offset + (offset / 2.) - params['n_channels'] = n_channels - params['ax'].set_yticks(params['offsets']) - params['vsel_patch'].set_height(n_channels) - if len(params['lines']) > n_channels: # remove line from view - params['lines'][n_channels].set_xdata([]) - params['lines'][n_channels].set_ydata([]) - ch_changed = True - elif event.key == 'home': - duration = params['duration'] - 1.0 - if duration <= 0: - return - params['duration'] = duration - params['hsel_patch'].set_width(params['duration']) - _update_raw_data(params) - params['plot_fun']() - elif event.key == 'end': - duration = params['duration'] + 1.0 - if duration > params['raw'].times[-1]: - duration = params['raw'].times[-1] - params['duration'] = duration - params['hsel_patch'].set_width(params['duration']) - _update_raw_data(params) - params['plot_fun']() - elif event.key == 'f11': - mng = plt.get_current_fig_manager() - mng.full_screen_toggle() - return - # deal with plotting changes - if ch_changed: - _channels_changed(params) - - -def _channels_changed(params): - len_channels = len(params['info']['ch_names']) - if params['ch_start'] + params['n_channels'] >= len_channels: - params['ch_start'] = len_channels - params['n_channels'] - if params['ch_start'] < 0: - params['ch_start'] = 0 - params['plot_fun']() - - -def _plot_raw_onscroll(event, params): - """Interpret scroll events""" - orig_start = params['ch_start'] - if event.step < 0: - params['ch_start'] = min(params['ch_start'] + params['n_channels'], - len(params['info']['ch_names']) - - params['n_channels']) - else: # event.key == 'up': - params['ch_start'] = max(params['ch_start'] - params['n_channels'], 0) - if orig_start != params['ch_start']: - _channels_changed(params) - - -def _plot_traces(params, inds, color, bad_color, event_lines, event_color): - """Helper for plotting raw""" - lines = params['lines'] - info = params['info'] - n_channels = params['n_channels'] - params['bad_color'] = bad_color - # do the plotting - tick_list = [] - for ii in range(n_channels): - ch_ind = ii + params['ch_start'] - # let's be generous here and allow users to pass - # n_channels per view >= the number of traces available - if ii >= len(lines): - break - elif ch_ind < len(info['ch_names']): - # scale to fit - ch_name = info['ch_names'][inds[ch_ind]] - tick_list += [ch_name] - offset = params['offsets'][ii] - - # do NOT operate in-place lest this get screwed up - this_data = params['data'][inds[ch_ind]] * params['scale_factor'] - this_color = bad_color if ch_name in info['bads'] else color - this_z = -1 if ch_name in info['bads'] else 0 - if isinstance(this_color, dict): - this_color = this_color[params['types'][inds[ch_ind]]] - - # subtraction here gets corect orientation for flipped ylim - lines[ii].set_ydata(offset - this_data) - lines[ii].set_xdata(params['times']) - lines[ii].set_color(this_color) - lines[ii].set_zorder(this_z) - vars(lines[ii])['ch_name'] = ch_name - vars(lines[ii])['def_color'] = color[params['types'][inds[ch_ind]]] - else: - # "remove" lines - lines[ii].set_xdata([]) - lines[ii].set_ydata([]) - # deal with event lines - if params['event_times'] is not None: - # find events in the time window - event_times = params['event_times'] - mask = np.logical_and(event_times >= params['times'][0], - event_times <= params['times'][-1]) - event_times = event_times[mask] - event_nums = params['event_nums'][mask] - # plot them with appropriate colors - # go through the list backward so we end with -1, the catchall - used = np.zeros(len(event_times), bool) - ylim = params['ax'].get_ylim() - for ev_num, line in zip(sorted(event_color.keys())[::-1], - event_lines[::-1]): - mask = (event_nums == ev_num) if ev_num >= 0 else ~used - assert not np.any(used[mask]) - used[mask] = True - t = event_times[mask] - if len(t) > 0: - xs = list() - ys = list() - for tt in t: - xs += [tt, tt, np.nan] - ys += [0, ylim[0], np.nan] - line.set_xdata(xs) - line.set_ydata(ys) - else: - line.set_xdata([]) - line.set_ydata([]) - # finalize plot - params['ax'].set_xlim(params['times'][0], - params['times'][0] + params['duration'], False) - params['ax'].set_yticklabels(tick_list) - params['vsel_patch'].set_y(params['ch_start']) - params['fig'].canvas.draw() - # XXX This is a hack to make sure this figure gets drawn last - # so that when matplotlib goes to calculate bounds we don't get a - # CGContextRef error on the MacOSX backend :( - if params['fig_proj'] is not None: - params['fig_proj'].canvas.draw() - - def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None, bgcolor='w', color=None, bad_color=(0.8, 0.8, 0.8), event_color='cyan', scalings=None, remove_dc=True, order='type', @@ -529,80 +279,20 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None, n_times=n_times, event_times=event_times, event_nums=event_nums, clipping=clipping, fig_proj=None) - # set up plotting - size = get_config('MNE_BROWSE_RAW_SIZE') - if size is not None: - size = size.split(',') - size = tuple([float(s) for s in size]) - # have to try/catch when there's no toolbar - fig = figure_nobar(facecolor=bgcolor, figsize=size, dpi=80) - fig.canvas.set_window_title('mne_browse_raw') - ax = plt.subplot2grid((10, 10), (0, 0), colspan=9, rowspan=9) - ax.set_title(title, fontsize=12) - ax_hscroll = plt.subplot2grid((10, 10), (9, 0), colspan=9) - ax_hscroll.get_yaxis().set_visible(False) - ax_hscroll.set_xlabel('Time (s)') - ax_vscroll = plt.subplot2grid((10, 10), (0, 9), rowspan=9) - ax_vscroll.set_axis_off() - # store these so they can be fixed on resize - params['fig'] = fig - params['ax'] = ax - params['ax_hscroll'] = ax_hscroll - params['ax_vscroll'] = ax_vscroll + _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, + n_channels) - # populate vertical and horizontal scrollbars - for ci in range(len(info['ch_names'])): - this_color = (bad_color if info['ch_names'][inds[ci]] in info['bads'] - else color) - if isinstance(this_color, dict): - this_color = this_color[types[inds[ci]]] - ax_vscroll.add_patch(mpl.patches.Rectangle((0, ci), 1, 1, - facecolor=this_color, - edgecolor=this_color)) - vsel_patch = mpl.patches.Rectangle((0, 0), 1, n_channels, alpha=0.5, - facecolor='w', edgecolor='w') - ax_vscroll.add_patch(vsel_patch) - params['vsel_patch'] = vsel_patch - hsel_patch = mpl.patches.Rectangle((start, 0), duration, 1, edgecolor='k', - facecolor=(0.75, 0.75, 0.75), - alpha=0.25, linewidth=1, clip_on=False) - ax_hscroll.add_patch(hsel_patch) - params['hsel_patch'] = hsel_patch - ax_hscroll.set_xlim(0, n_times / float(info['sfreq'])) - n_ch = len(info['ch_names']) - ax_vscroll.set_ylim(n_ch, 0) - ax_vscroll.set_title('Ch.') - - # make shells for plotting traces - ylim = [n_channels * 2 + 1, 0] - offset = ylim[0] / n_channels - offsets = np.arange(n_channels) * offset + (offset / 2.) - ax.set_yticks(offsets) - ax.set_ylim(ylim) # plot event_line first so it's in the back - event_lines = [ax.plot([np.nan], color=event_color[ev_num])[0] + event_lines = [params['ax'].plot([np.nan], color=event_color[ev_num])[0] for ev_num in sorted(event_color.keys())] - params['offsets'] = offsets - params['lines'] = [ax.plot([np.nan], antialiased=False, linewidth=0.5)[0] - for _ in range(n_ch)] - ax.set_yticklabels(['X' * max([len(ch) for ch in info['ch_names']])]) - vertline_color = (0., 0.75, 0.) - params['ax_vertline'] = ax.plot([0, 0], ylim, color=vertline_color, - zorder=-1)[0] - params['ax_vertline'].ch_name = '' - params['vertline_t'] = ax_hscroll.text(0, 0.5, '', color=vertline_color, - verticalalignment='center', - horizontalalignment='right') - params['ax_hscroll_vertline'] = ax_hscroll.plot([0, 0], [0, 1], - color=vertline_color, - zorder=1)[0] - - params['plot_fun'] = partial(_plot_traces, params=params, inds=inds, + params['plot_fun'] = partial(_plot_raw_traces, params=params, inds=inds, color=color, bad_color=bad_color, event_lines=event_lines, event_color=event_color) + params['update_fun'] = partial(_update_raw_data, params=params) + params['pick_bads_fun'] = partial(_pick_bad_channels, params=params) + params['label_click_fun'] = partial(_label_clicked, params=params) params['scale_factor'] = 1.0 - # set up callbacks opt_button = None if len(raw.info['projs']) > 0 and not raw.proj: @@ -612,13 +302,13 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None, callback_option = partial(_toggle_options, params=params) opt_button.on_clicked(callback_option) callback_key = partial(_plot_raw_onkey, params=params) - fig.canvas.mpl_connect('key_press_event', callback_key) + params['fig'].canvas.mpl_connect('key_press_event', callback_key) callback_scroll = partial(_plot_raw_onscroll, params=params) - fig.canvas.mpl_connect('scroll_event', callback_scroll) + params['fig'].canvas.mpl_connect('scroll_event', callback_scroll) callback_pick = partial(_mouse_click, params=params) - fig.canvas.mpl_connect('button_press_event', callback_pick) - callback_resize = partial(_helper_resize, params=params) - fig.canvas.mpl_connect('resize_event', callback_resize) + params['fig'].canvas.mpl_connect('button_press_event', callback_pick) + callback_resize = partial(_helper_raw_resize, params=params) + params['fig'].canvas.mpl_connect('resize_event', callback_resize) # As here code is shared with plot_evoked, some extra steps: # first the actual plot update function @@ -645,7 +335,30 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None, except TypeError: # not all versions have this plt.show() - return fig + return params['fig'] + + +def _label_clicked(pos, params): + """Helper function for selecting bad channels.""" + labels = params['ax'].yaxis.get_ticklabels() + offsets = np.array(params['offsets']) + params['offsets'][0] + line_idx = np.searchsorted(offsets, pos[1]) + text = labels[line_idx].get_text() + if len(text) == 0: + return + ch_idx = params['ch_start'] + line_idx + bads = params['info']['bads'] + if text in bads: + while text in bads: # to make sure duplicates are removed + bads.remove(text) + color = vars(params['lines'][line_idx])['def_color'] + params['ax_vscroll'].patches[ch_idx].set_color(color) + else: + bads.append(text) + color = params['bad_color'] + params['ax_vscroll'].patches[ch_idx].set_color(color) + params['raw'].info['bads'] = bads + _plot_update_raw_proj(params, None) def _set_psd_plot_params(info, proj, picks, ax, area_mode): @@ -790,3 +503,163 @@ def plot_raw_psd(raw, tmin=0., tmax=np.inf, fmin=0, fmax=np.inf, proj=False, if show is True: plt.show() return fig + + +def _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, + n_channels): + """Helper for setting up the mne_browse_raw window.""" + import matplotlib.pyplot as plt + import matplotlib as mpl + size = get_config('MNE_BROWSE_RAW_SIZE') + if size is not None: + size = size.split(',') + size = tuple([float(s) for s in size]) + + fig = figure_nobar(facecolor=bgcolor, figsize=size) + fig.canvas.set_window_title('mne_browse_raw') + ax = plt.subplot2grid((10, 10), (0, 1), colspan=8, rowspan=9) + ax.set_title(title, fontsize=12) + ax_hscroll = plt.subplot2grid((10, 10), (9, 1), colspan=8) + ax_hscroll.get_yaxis().set_visible(False) + ax_hscroll.set_xlabel('Time (s)') + ax_vscroll = plt.subplot2grid((10, 10), (0, 9), rowspan=9) + ax_vscroll.set_axis_off() + ax_help_button = plt.subplot2grid((10, 10), (0, 0), colspan=1) + help_button = mpl.widgets.Button(ax_help_button, 'Help') + help_button.on_clicked(partial(_onclick_help, params=params)) + # store these so they can be fixed on resize + params['fig'] = fig + params['ax'] = ax + params['ax_hscroll'] = ax_hscroll + params['ax_vscroll'] = ax_vscroll + params['ax_help_button'] = ax_help_button + params['help_button'] = help_button + + # populate vertical and horizontal scrollbars + info = params['info'] + for ci in range(len(info['ch_names'])): + this_color = (bad_color if info['ch_names'][inds[ci]] in info['bads'] + else color) + if isinstance(this_color, dict): + this_color = this_color[params['types'][inds[ci]]] + ax_vscroll.add_patch(mpl.patches.Rectangle((0, ci), 1, 1, + facecolor=this_color, + edgecolor=this_color)) + vsel_patch = mpl.patches.Rectangle((0, 0), 1, n_channels, alpha=0.5, + facecolor='w', edgecolor='w') + ax_vscroll.add_patch(vsel_patch) + params['vsel_patch'] = vsel_patch + hsel_patch = mpl.patches.Rectangle((params['t_start'], 0), + params['duration'], 1, edgecolor='k', + facecolor=(0.75, 0.75, 0.75), + alpha=0.25, linewidth=1, clip_on=False) + ax_hscroll.add_patch(hsel_patch) + params['hsel_patch'] = hsel_patch + ax_hscroll.set_xlim(0, params['n_times'] / float(info['sfreq'])) + n_ch = len(info['ch_names']) + ax_vscroll.set_ylim(n_ch, 0) + ax_vscroll.set_title('Ch.') + + # make shells for plotting traces + ylim = [n_channels * 2 + 1, 0] + offset = ylim[0] / n_channels + offsets = np.arange(n_channels) * offset + (offset / 2.) + ax.set_yticks(offsets) + ax.set_ylim(ylim) + ax.set_xlim(params['t_start'], params['t_start'] + params['duration'], + False) + + params['offsets'] = offsets + params['lines'] = [ax.plot([np.nan], antialiased=False, linewidth=0.5)[0] + for _ in range(n_ch)] + ax.set_yticklabels(['X' * max([len(ch) for ch in info['ch_names']])]) + vertline_color = (0., 0.75, 0.) + params['ax_vertline'] = ax.plot([0, 0], ylim, color=vertline_color, + zorder=-1)[0] + params['ax_vertline'].ch_name = '' + params['vertline_t'] = ax_hscroll.text(0, 1, '', color=vertline_color, + va='bottom', ha='right') + params['ax_hscroll_vertline'] = ax_hscroll.plot([0, 0], [0, 1], + color=vertline_color, + zorder=1)[0] + + +def _plot_raw_traces(params, inds, color, bad_color, event_lines=None, + event_color=None): + """Helper for plotting raw""" + lines = params['lines'] + info = params['info'] + n_channels = params['n_channels'] + params['bad_color'] = bad_color + # do the plotting + tick_list = list() + for ii in range(n_channels): + ch_ind = ii + params['ch_start'] + # let's be generous here and allow users to pass + # n_channels per view >= the number of traces available + if ii >= len(lines): + break + elif ch_ind < len(info['ch_names']): + # scale to fit + ch_name = info['ch_names'][inds[ch_ind]] + tick_list += [ch_name] + offset = params['offsets'][ii] + + # do NOT operate in-place lest this get screwed up + this_data = params['data'][inds[ch_ind]] * params['scale_factor'] + this_color = bad_color if ch_name in info['bads'] else color + this_z = -1 if ch_name in info['bads'] else 0 + if isinstance(this_color, dict): + this_color = this_color[params['types'][inds[ch_ind]]] + + # subtraction here gets corect orientation for flipped ylim + lines[ii].set_ydata(offset - this_data) + lines[ii].set_xdata(params['times']) + lines[ii].set_color(this_color) + lines[ii].set_zorder(this_z) + vars(lines[ii])['ch_name'] = ch_name + vars(lines[ii])['def_color'] = color[params['types'][inds[ch_ind]]] + else: + # "remove" lines + lines[ii].set_xdata([]) + lines[ii].set_ydata([]) + # deal with event lines + if params['event_times'] is not None: + # find events in the time window + event_times = params['event_times'] + mask = np.logical_and(event_times >= params['times'][0], + event_times <= params['times'][-1]) + event_times = event_times[mask] + event_nums = params['event_nums'][mask] + # plot them with appropriate colors + # go through the list backward so we end with -1, the catchall + used = np.zeros(len(event_times), bool) + ylim = params['ax'].get_ylim() + for ev_num, line in zip(sorted(event_color.keys())[::-1], + event_lines[::-1]): + mask = (event_nums == ev_num) if ev_num >= 0 else ~used + assert not np.any(used[mask]) + used[mask] = True + t = event_times[mask] + if len(t) > 0: + xs = list() + ys = list() + for tt in t: + xs += [tt, tt, np.nan] + ys += [0, ylim[0], np.nan] + line.set_xdata(xs) + line.set_ydata(ys) + else: + line.set_xdata([]) + line.set_ydata([]) + # finalize plot + params['ax'].set_xlim(params['times'][0], + params['times'][0] + params['duration'], False) + params['ax'].set_yticklabels(tick_list) + params['vsel_patch'].set_y(params['ch_start']) + params['fig'].canvas.draw() + # XXX This is a hack to make sure this figure gets drawn last + # so that when matplotlib goes to calculate bounds we don't get a + # CGContextRef error on the MacOSX backend :( + if params['fig_proj'] is not None: + params['fig_proj'].canvas.draw() diff --git a/mne/viz/tests/test_epochs.py b/mne/viz/tests/test_epochs.py index 484b9faac7e..e2692dd40c8 100644 --- a/mne/viz/tests/test_epochs.py +++ b/mne/viz/tests/test_epochs.py @@ -112,6 +112,8 @@ def test_plot_epochs(): fig = epochs.plot(trellis=False) fig.canvas.key_press_event('left') fig.canvas.key_press_event('right') + fig.canvas.scroll_event(0.5, 0.5, -0.5) # scroll down + fig.canvas.scroll_event(0.5, 0.5, 0.5) # scroll up fig.canvas.key_press_event('up') fig.canvas.key_press_event('down') fig.canvas.key_press_event('pageup') @@ -128,6 +130,8 @@ def test_plot_epochs(): fig.canvas.resize_event() fig.canvas.close_event() # closing and epoch dropping plt.close('all') + assert_raises(RuntimeError, epochs.plot, picks=[], trellis=False) + plt.close('all') with warnings.catch_warnings(record=True): fig = epochs.plot(trellis=False) # test mouse clicks @@ -138,15 +142,13 @@ def test_plot_epochs(): _fake_click(fig, data_ax, [x, y], xform='data') # mark a bad epoch _fake_click(fig, data_ax, [x, y], xform='data') # unmark a bad epoch _fake_click(fig, data_ax, [0.5, 0.999]) # click elsewhere in 1st axes + _fake_click(fig, data_ax, [-0.1, 0.9]) # click on y-label _fake_click(fig, fig.get_axes()[2], [0.5, 0.5]) # change epochs _fake_click(fig, fig.get_axes()[3], [0.5, 0.5]) # change channels fig.canvas.close_event() # closing and epoch dropping assert(n_epochs - 1 == len(epochs)) plt.close('all') - assert_raises(RuntimeError, epochs.plot, picks=[], trellis=False) - plt.close('all') - def test_plot_image_epochs(): """Test plotting of epochs image diff --git a/mne/viz/tests/test_ica.py b/mne/viz/tests/test_ica.py index cf612ec507c..deb42f16ff0 100644 --- a/mne/viz/tests/test_ica.py +++ b/mne/viz/tests/test_ica.py @@ -143,4 +143,57 @@ def test_plot_ica_scores(): plt.close('all') +@requires_sklearn +def test_plot_instance_components(): + """Test plotting of components as instances of raw and epochs.""" + import matplotlib.pyplot as plt + raw = _get_raw() + picks = _get_picks(raw) + ica = ICA(noise_cov=read_cov(cov_fname), n_components=2, + max_pca_components=3, n_pca_components=3) + ica.fit(raw, picks=picks) + fig = ica.plot_sources(raw, exclude=[0], title='Components') + fig.canvas.key_press_event('down') + fig.canvas.key_press_event('up') + fig.canvas.key_press_event('right') + fig.canvas.key_press_event('left') + fig.canvas.key_press_event('o') + fig.canvas.key_press_event('-') + fig.canvas.key_press_event('+') + fig.canvas.key_press_event('=') + fig.canvas.key_press_event('pageup') + fig.canvas.key_press_event('pagedown') + fig.canvas.key_press_event('home') + fig.canvas.key_press_event('end') + fig.canvas.key_press_event('f11') + ax = fig.get_axes()[0] + line = ax.lines[0] + _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], 'data') + _fake_click(fig, ax, [-0.1, 0.9]) # click on y-label + fig.canvas.key_press_event('escape') + plt.close('all') + epochs = _get_epochs() + fig = ica.plot_sources(epochs, exclude=[0], title='Components') + fig.canvas.key_press_event('down') + fig.canvas.key_press_event('up') + fig.canvas.key_press_event('right') + fig.canvas.key_press_event('left') + fig.canvas.key_press_event('o') + fig.canvas.key_press_event('-') + fig.canvas.key_press_event('+') + fig.canvas.key_press_event('=') + fig.canvas.key_press_event('pageup') + fig.canvas.key_press_event('pagedown') + fig.canvas.key_press_event('home') + fig.canvas.key_press_event('end') + fig.canvas.key_press_event('f11') + # Test a click + ax = fig.get_axes()[0] + line = ax.lines[0] + _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], 'data') + _fake_click(fig, ax, [-0.1, 0.9]) # click on y-label + fig.canvas.key_press_event('escape') + plt.close('all') + + run_tests_if_main() diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index cabcfca7728..bdb8e46d3d8 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -48,10 +48,13 @@ def test_plot_raw(): _fake_click(fig, data_ax, [x, y], xform='data') # mark a bad channel _fake_click(fig, data_ax, [x, y], xform='data') # unmark a bad channel _fake_click(fig, data_ax, [0.5, 0.999]) # click elsewhere in 1st axes + _fake_click(fig, data_ax, [-0.1, 0.9]) # click on y-label _fake_click(fig, fig.get_axes()[1], [0.5, 0.5]) # change time _fake_click(fig, fig.get_axes()[2], [0.5, 0.5]) # change channels _fake_click(fig, fig.get_axes()[3], [0.5, 0.5]) # open SSP window fig.canvas.button_press_event(1, 1, 1) # outside any axes + fig.canvas.scroll_event(0.5, 0.5, -0.5) # scroll down + fig.canvas.scroll_event(0.5, 0.5, 0.5) # scroll up # sadly these fail when no renderer is used (i.e., when using Agg): # ssp_fig = set(plt.get_fignums()) - set([fig.number]) # assert_equal(len(ssp_fig), 1) @@ -76,6 +79,7 @@ def test_plot_raw(): fig.canvas.key_press_event('pagedown') fig.canvas.key_press_event('home') fig.canvas.key_press_event('end') + fig.canvas.key_press_event('?') fig.canvas.key_press_event('f11') fig.canvas.key_press_event('escape') # Color setting diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 668e20ac3d3..6d4929190a6 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -19,7 +19,7 @@ import numpy as np from ..io import show_fiff -from ..utils import verbose +from ..utils import verbose, set_config COLORS = ['b', 'g', 'r', 'c', 'm', 'y', 'k', '#473C8B', '#458B74', @@ -213,6 +213,91 @@ def _toggle_proj(event, params): params['plot_update_proj_callback'](params, bools) +def _get_help_text(params): + """Aux function for customizing help dialogs text.""" + text, text2 = list(), list() + + text.append(u'\u2190 : \n') + text.append(u'\u2192 : \n') + text.append(u'\u2193 : \n') + text.append(u'\u2191 : \n') + text.append(u'- : \n') + text.append(u'+ or = : \n') + text.append(u'Home : \n') + text.append(u'End : \n') + text.append(u'Page down : \n') + text.append(u'Page up : \n') + + text.append(u'F11 : \n') + text.append(u'? : \n') + text.append(u'Esc : \n\n') + text.append(u'Mouse controls\n') + text.append(u'click on data :\n') + + text2.append('Navigate left\n') + text2.append('Navigate right\n') + + text2.append('Scale down\n') + text2.append('Scale up\n') + + text2.append('Toggle full screen mode\n') + text2.append('Open help box\n') + text2.append('Quit\n\n\n') + if 'raw' in params: + text2.insert(4, 'Reduce the time shown per view\n') + text2.insert(5, 'Increase the time shown per view\n') + text.append(u'click elsewhere in the plot :\n') + if 'ica' in params: + text.append(u'click component name :\n') + text2.insert(2, 'Navigate components down\n') + text2.insert(3, 'Navigate components up\n') + text2.insert(8, 'Reduce the number of components per view\n') + text2.insert(9, 'Increase the number of components per view\n') + text2.append('Mark bad channel\n') + text2.append('Vertical line at a time instant\n') + text2.append('Show topography for the component\n') + else: + text.append(u'click channel name :\n') + text2.insert(2, 'Navigate channels down\n') + text2.insert(3, 'Navigate channels up\n') + text2.insert(8, 'Reduce the number of channels per view\n') + text2.insert(9, 'Increase the number of channels per view\n') + text2.append('Mark bad channel\n') + text2.append('Vertical line at a time instant\n') + text2.append('Mark bad channel\n') + + elif 'epochs' in params: + text.append(u'right click :\n') + text2.insert(4, 'Reduce the number of epochs per view\n') + text2.insert(5, 'Increase the number of epochs per view\n') + if 'ica' in params: + text.append(u'click component name :\n') + text2.insert(2, 'Navigate components down\n') + text2.insert(3, 'Navigate components up\n') + text2.insert(8, 'Reduce the number of components per view\n') + text2.insert(9, 'Increase the number of components per view\n') + text2.append('Mark component for exclusion\n') + text2.append('Vertical line at a time instant\n') + text2.append('Show topography for the component\n') + else: + text.append(u'click channel name :\n') + text2.insert(2, 'Navigate channels down\n') + text2.insert(3, 'Navigate channels up\n') + text2.insert(8, 'Reduce the number of channels per view\n') + text2.insert(9, 'Increase the number of channels per view\n') + text.insert(10, u'b : \n') + text2.insert(10, 'Toggle butterfly plot on/off\n') + text2.append('Mark bad epoch\n') + text2.append('Vertical line at a time instant\n') + text2.append('Mark bad channel\n') + text.append(u'middle click :\n') + text2.append('Show channel name (butterfly plot)\n') + text.insert(11, u'o : \n') + text2.insert(11, 'View settings (orig. view only)\n') + + return ''.join(text), ''.join(text2) + + def _prepare_trellis(n_cells, max_col): """Aux function """ @@ -390,6 +475,217 @@ def figure_nobar(*args, **kwargs): return fig +def _helper_raw_resize(event, params): + """Helper for resizing""" + size = ','.join([str(s) for s in params['fig'].get_size_inches()]) + set_config('MNE_BROWSE_RAW_SIZE', size) + _layout_figure(params) + + +def _plot_raw_onscroll(event, params, len_channels=None): + """Interpret scroll events""" + if len_channels is None: + len_channels = len(params['info']['ch_names']) + orig_start = params['ch_start'] + if event.step < 0: + params['ch_start'] = min(params['ch_start'] + params['n_channels'], + len_channels - params['n_channels']) + else: # event.key == 'up': + params['ch_start'] = max(params['ch_start'] - params['n_channels'], 0) + if orig_start != params['ch_start']: + _channels_changed(params, len_channels) + + +def _channels_changed(params, len_channels): + """Helper function for dealing with the vertical shift of the viewport.""" + if params['ch_start'] + params['n_channels'] > len_channels: + params['ch_start'] = len_channels - params['n_channels'] + if params['ch_start'] < 0: + params['ch_start'] = 0 + params['plot_fun']() + + +def _plot_raw_time(value, params): + """Deal with changed time value""" + info = params['info'] + max_times = params['n_times'] / float(info['sfreq']) - params['duration'] + if value > max_times: + value = params['n_times'] / info['sfreq'] - params['duration'] + if value < 0: + value = 0 + if params['t_start'] != value: + params['t_start'] = value + params['hsel_patch'].set_x(value) + + +def _plot_raw_onkey(event, params): + """Interpret key presses""" + import matplotlib.pyplot as plt + if event.key == 'escape': + plt.close(params['fig']) + elif event.key == 'down': + params['ch_start'] += params['n_channels'] + _channels_changed(params, len(params['info']['ch_names'])) + elif event.key == 'up': + params['ch_start'] -= params['n_channels'] + _channels_changed(params, len(params['info']['ch_names'])) + elif event.key == 'right': + value = params['t_start'] + params['duration'] + _plot_raw_time(value, params) + params['update_fun']() + params['plot_fun']() + elif event.key == 'left': + value = params['t_start'] - params['duration'] + _plot_raw_time(value, params) + params['update_fun']() + params['plot_fun']() + elif event.key in ['+', '=']: + params['scale_factor'] *= 1.1 + params['plot_fun']() + elif event.key == '-': + params['scale_factor'] /= 1.1 + params['plot_fun']() + elif event.key == 'pageup': + n_channels = params['n_channels'] + 1 + offset = params['ax'].get_ylim()[0] / n_channels + params['offsets'] = np.arange(n_channels) * offset + (offset / 2.) + params['n_channels'] = n_channels + params['ax'].set_yticks(params['offsets']) + params['vsel_patch'].set_height(n_channels) + _channels_changed(params, len(params['info']['ch_names'])) + elif event.key == 'pagedown': + n_channels = params['n_channels'] - 1 + if n_channels == 0: + return + offset = params['ax'].get_ylim()[0] / n_channels + params['offsets'] = np.arange(n_channels) * offset + (offset / 2.) + params['n_channels'] = n_channels + params['ax'].set_yticks(params['offsets']) + params['vsel_patch'].set_height(n_channels) + if len(params['lines']) > n_channels: # remove line from view + params['lines'][n_channels].set_xdata([]) + params['lines'][n_channels].set_ydata([]) + _channels_changed(params, len(params['info']['ch_names'])) + elif event.key == 'home': + duration = params['duration'] - 1.0 + if duration <= 0: + return + params['duration'] = duration + params['hsel_patch'].set_width(params['duration']) + params['update_fun']() + params['plot_fun']() + elif event.key == 'end': + duration = params['duration'] + 1.0 + if duration > params['raw'].times[-1]: + duration = params['raw'].times[-1] + params['duration'] = duration + params['hsel_patch'].set_width(params['duration']) + params['update_fun']() + params['plot_fun']() + elif event.key == '?': + _onclick_help(event, params) + elif event.key == 'f11': + mng = plt.get_current_fig_manager() + mng.full_screen_toggle() + + +def _mouse_click(event, params): + """Vertical select callback""" + if event.button != 1: + return + if event.inaxes is None: + if params['n_channels'] > 100: + return + ax = params['ax'] + ylim = ax.get_ylim() + pos = ax.transData.inverted().transform((event.x, event.y)) + if pos[0] > params['t_start'] or pos[1] < 0 or pos[1] > ylim[0]: + return + params['label_click_fun'](pos) + # vertical scrollbar changed + if event.inaxes == params['ax_vscroll']: + ch_start = max(int(event.ydata) - params['n_channels'] // 2, 0) + if params['ch_start'] != ch_start: + params['ch_start'] = ch_start + params['plot_fun']() + # horizontal scrollbar changed + elif event.inaxes == params['ax_hscroll']: + _plot_raw_time(event.xdata - params['duration'] / 2, params) + params['update_fun']() + params['plot_fun']() + + elif event.inaxes == params['ax']: + params['pick_bads_fun'](event) + + +def _select_bads(event, params, bads): + """Helper for selecting bad channels onpick. Returns updated bads list.""" + # trade-off, avoid selecting more than one channel when drifts are present + # however for clean data don't click on peaks but on flat segments + def f(x, y): + return y(np.mean(x), x.std() * 2) + lines = event.inaxes.lines + for line in lines: + ydata = line.get_ydata() + if not isinstance(ydata, list) and not np.isnan(ydata).any(): + ymin, ymax = f(ydata, np.subtract), f(ydata, np.add) + if ymin <= event.ydata <= ymax: + this_chan = vars(line)['ch_name'] + if this_chan in params['info']['ch_names']: + ch_idx = params['ch_start'] + lines.index(line) + if this_chan not in bads: + bads.append(this_chan) + color = params['bad_color'] + line.set_zorder(-1) + else: + while this_chan in bads: + bads.remove(this_chan) + color = vars(line)['def_color'] + line.set_zorder(0) + line.set_color(color) + params['ax_vscroll'].patches[ch_idx].set_color(color) + break + else: + x = np.array([event.xdata] * 2) + params['ax_vertline'].set_data(x, np.array(params['ax'].get_ylim())) + params['ax_hscroll_vertline'].set_data(x, np.array([0., 1.])) + params['vertline_t'].set_text('%0.3f' % x[0]) + return bads + + +def _onclick_help(event, params): + """Function for drawing help window""" + import matplotlib.pyplot as plt + text, text2 = _get_help_text(params) + + width = 6 + height = 5 + + fig_help = figure_nobar(figsize=(width, height), dpi=80) + fig_help.canvas.set_window_title('Help') + ax = plt.subplot2grid((8, 5), (0, 0), colspan=5) + ax.set_title('Keyboard shortcuts') + plt.axis('off') + ax1 = plt.subplot2grid((8, 5), (1, 0), rowspan=7, colspan=2) + ax1.set_yticklabels(list()) + plt.text(0.99, 1, text, fontname='STIXGeneral', va='top', weight='bold', + ha='right') + plt.axis('off') + + ax2 = plt.subplot2grid((8, 5), (1, 2), rowspan=7, colspan=3) + ax2.set_yticklabels(list()) + plt.text(0, 1, text2, fontname='STIXGeneral', va='top') + plt.axis('off') + + tight_layout(fig=fig_help) + # this should work for non-test cases + try: + fig_help.canvas.draw() + fig_help.show() + except Exception: + pass + + class ClickableImage(object): """