From 55289a01bb71aee637d771ec7408c5ee50215e38 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Tue, 30 Jun 2015 09:13:55 +0300 Subject: [PATCH 01/36] Added functions for plotting ica components. --- mne/preprocessing/ica.py | 18 +++++++++++++++++- mne/viz/__init__.py | 1 + mne/viz/ica.py | 14 ++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 4bb3cc7fcfe..a0ab9f25712 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -32,7 +32,7 @@ from ..io.constants import Bunch, FIFF from ..io.base import _BaseRaw from ..epochs import _BaseEpochs -from ..viz import (plot_ica_components, plot_ica_scores, +from ..viz import (plot_ica_components, plot_ica_scores, _plot_raw_components, plot_ica_sources, plot_ica_overlay) from ..channels.channels import _contains_ch_type, ContainsMixin from ..io.write import start_file, end_file, write_id @@ -1573,6 +1573,22 @@ def _check_n_pca_components(self, _n_pca_comp, verbose=None): return _n_pca_comp + def plot_raw_components(self, raw): + """Plot ICA components + + Note. This is still experimental and will most likely change. Over + + Parameters + ---------- + raw : instance of Raw + Raw object to draw sources from. + start_find : int | float | None + First sample to include for artifact search. If float, data will be + interpreted as time in seconds. If None, data will be used from the + """ + _plot_raw_components(self, raw) + + def _check_start_stop(raw, start, stop): """Aux function""" return [c if (isinstance(c, int) or c is None) else diff --git a/mne/viz/__init__.py b/mne/viz/__init__.py index 2e3e0d6f657..968f235e303 100644 --- a/mne/viz/__init__.py +++ b/mne/viz/__init__.py @@ -19,5 +19,6 @@ plot_epochs_trellis, _drop_log_stats, plot_epochs_psd) from .raw import plot_raw, plot_raw_psd from .ica import plot_ica_scores, plot_ica_sources, plot_ica_overlay +from .ica import _plot_raw_components from .montage import plot_montage from .decoding import plot_gat_matrix, plot_gat_times diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 0309e670835..2798052ed27 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -510,3 +510,17 @@ def _plot_ica_overlay_evoked(evoked, evoked_cln, title, show): plt.show() return fig + + +def _plot_raw_components(ica, raw): + """Helper function for plotting the ICA components as raw array.""" + data = ica._transform_raw(raw, 0, len(raw.times)) + c_names = ['ICA ' + str(x + 1) for x in range(len(data))] + #info = create_info(c_names, raw.info['sfreq']) + #raw_ica = RawArray(data, info) + scalings = {'misc': 2} + plot_traces(data) + return raw_ica.plot(scalings=scalings) + +def plot_traces(data): + pass \ No newline at end of file From 94a81b26ce07d326a04e95fc5521d7c26f88aa01 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Tue, 30 Jun 2015 12:33:45 +0300 Subject: [PATCH 02/36] Work towards ica plot. Refactoring. --- mne/viz/epochs.py | 27 ++++++------ mne/viz/ica.py | 75 +++++++++++++++++++++++++++++----- mne/viz/raw.py | 102 ++++++++-------------------------------------- mne/viz/utils.py | 83 ++++++++++++++++++++++++++++++++++++- 4 files changed, 176 insertions(+), 111 deletions(-) diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index b3f87fef0d9..099ce890761 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -23,6 +23,7 @@ from ..time_frequency import compute_epochs_psd from .utils import tight_layout, _prepare_trellis, figure_nobar from .utils import _toggle_options, _toggle_proj, _layout_figure +from .utils import _channels_changed from ..defaults import _handle_default @@ -1015,17 +1016,6 @@ def _plot_window(value, params): _plot_traces(params) -def _channels_changed(params): - """Deal with vertical shift of the viewport.""" - if params['butterfly']: - return - if params['ch_start'] + params['n_channels'] > len(params['ch_names']): - params['ch_start'] = len(params['ch_names']) - params['n_channels'] - elif params['ch_start'] < 0: - params['ch_start'] = 0 - _plot_traces(params) - - def _plot_vert_lines(params): """ Helper function for plotting vertical lines.""" ax = params['ax'] @@ -1082,6 +1072,8 @@ def _plot_onscroll(event, params): event.key = '+' _plot_onkey(event, params) return + if params['butterfly']: + return orig_start = params['ch_start'] if event.step < 0: params['ch_start'] = min(params['ch_start'] + params['n_channels'], @@ -1090,7 +1082,8 @@ def _plot_onscroll(event, params): else: params['ch_start'] = max(params['ch_start'] - params['n_channels'], 0) if orig_start != params['ch_start']: - _channels_changed(params) + _channels_changed(params, len(params['ch_names'])) + _plot_traces(params) def _mouse_click(event, params): @@ -1167,11 +1160,17 @@ def _plot_onkey(event, params): """Function to handle key presses.""" import matplotlib.pyplot as plt if event.key == 'down': + if params['butterfly']: + return params['ch_start'] += params['n_channels'] - _channels_changed(params) + _channels_changed(params, len(params['ch_names'])) + _plot_traces(params) elif event.key == 'up': + if params['butterfly']: + return params['ch_start'] -= params['n_channels'] - _channels_changed(params) + _channels_changed(params, len(params['ch_names'])) + _plot_traces(params) elif event.key == 'left': sample = params['t_start'] - params['duration'] sample = np.max([0, sample]) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 2798052ed27..30840a33280 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -12,8 +12,10 @@ import numpy as np -from .utils import tight_layout, _prepare_trellis +from .utils import tight_layout, _prepare_trellis, _prepare_mne_browse_raw +from .utils import _layout_figure from .evoked import _butterfly_on_button_press, _butterfly_onpick +from ..io.meas_info import create_info def _ica_plot_sources_onpick_(event, sources=None, ylims=None): @@ -512,15 +514,68 @@ def _plot_ica_overlay_evoked(evoked, evoked_cln, title, show): return fig -def _plot_raw_components(ica, raw): +def _plot_raw_components(ica, raw, title=None, duration=10.0, start=0.0, + n_channels=20, bgcolor='w', color=None, + bad_color=(0.8, 0.8, 0.8), event_color='cyan'): """Helper function for plotting the ICA components as raw array.""" data = ica._transform_raw(raw, 0, len(raw.times)) - c_names = ['ICA ' + str(x + 1) for x in range(len(data))] - #info = create_info(c_names, raw.info['sfreq']) - #raw_ica = RawArray(data, info) + inds = range(len(data)) + c_names = ['ICA ' + str(x + 1) for x in inds] + if title is None: + title = 'ICA components' + info = create_info(c_names, raw.info['sfreq']) scalings = {'misc': 2} - plot_traces(data) - return raw_ica.plot(scalings=scalings) - -def plot_traces(data): - pass \ No newline at end of file + params = dict(raw=raw, data=data, ch_start=0, t_start=start, info=info, + duration=duration, n_channels=n_channels, scalings=scalings, + n_times=raw.n_times, bad_color=(0.8, 0.8, 0.8)) + _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, + n_channels) + params['scale_factor'] = 1.0 + _layout_figure(params) + plot_traces(params) + return params['fig'] + + +def plot_traces(params): + lines = params['lines'] + info = params['info'] + n_channels = params['n_channels'] + bad_color = params['bad_color'] + color = 'black' + # do the plotting + tick_list = [] + for ii in range(n_channels): + ch_ind = ii + params['ch_start'] + # let's be generous here and allow users to pass + # n_channels per view >= the number of traces available + if ii >= len(lines): + break + elif ch_ind < len(info['ch_names']): + # scale to fit + ch_name = info['ch_names'][ch_ind] + tick_list += [ch_name] + offset = params['offsets'][ii] + + # do NOT operate in-place lest this get screwed up + this_data = params['data'][ch_ind] * params['scale_factor'] + this_color = bad_color if ch_name in info['bads'] else color + this_z = -1 if ch_name in info['bads'] else 0 + + # subtraction here gets correct orientation for flipped ylim + lines[ii].set_ydata(offset - this_data) + lines[ii].set_xdata(params['raw'].times) + lines[ii].set_color(this_color) + lines[ii].set_zorder(this_z) + vars(lines[ii])['ch_name'] = ch_name + vars(lines[ii])['def_color'] = this_color + else: + # "remove" lines + lines[ii].set_xdata([]) + lines[ii].set_ydata([]) + + # finalize plot + params['ax'].set_xlim(params['raw'].times[0], + params['raw'].times[0] + params['duration'], False) + params['ax'].set_yticklabels(tick_list) + params['vsel_patch'].set_y(params['ch_start']) + params['fig'].canvas.draw() diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 753c910a126..9d290e6c120 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -15,10 +15,10 @@ from ..externals.six import string_types from ..io.pick import pick_types from ..io.proj import setup_proj -from ..utils import set_config, get_config, verbose +from ..utils import set_config, verbose from ..time_frequency import compute_raw_psd -from .utils import figure_nobar, _toggle_options, _toggle_proj, tight_layout -from .utils import _layout_figure +from .utils import _toggle_options, _toggle_proj, tight_layout +from .utils import _layout_figure, _prepare_mne_browse_raw, _channels_changed from ..defaults import _handle_default @@ -221,16 +221,8 @@ def _plot_raw_onkey(event, params): return # deal with plotting changes if ch_changed: - _channels_changed(params) - - -def _channels_changed(params): - len_channels = len(params['info']['ch_names']) - if params['ch_start'] + params['n_channels'] >= len_channels: - params['ch_start'] = len_channels - params['n_channels'] - if params['ch_start'] < 0: - params['ch_start'] = 0 - params['plot_fun']() + _channels_changed(params, len(params['info']['ch_names'])) + params['plot_fun']() def _plot_raw_onscroll(event, params): @@ -243,7 +235,8 @@ def _plot_raw_onscroll(event, params): else: # event.key == 'up': params['ch_start'] = max(params['ch_start'] - params['n_channels'], 0) if orig_start != params['ch_start']: - _channels_changed(params) + _channels_changed(params, len(params['info']['ch_names'])) + params['plot_fun']() def _plot_traces(params, inds, color, bad_color, event_lines, event_color): @@ -529,80 +522,17 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None, n_times=n_times, event_times=event_times, event_nums=event_nums, clipping=clipping, fig_proj=None) - # set up plotting - size = get_config('MNE_BROWSE_RAW_SIZE') - if size is not None: - size = size.split(',') - size = tuple([float(s) for s in size]) - # have to try/catch when there's no toolbar - fig = figure_nobar(facecolor=bgcolor, figsize=size, dpi=80) - fig.canvas.set_window_title('mne_browse_raw') - ax = plt.subplot2grid((10, 10), (0, 0), colspan=9, rowspan=9) - ax.set_title(title, fontsize=12) - ax_hscroll = plt.subplot2grid((10, 10), (9, 0), colspan=9) - ax_hscroll.get_yaxis().set_visible(False) - ax_hscroll.set_xlabel('Time (s)') - ax_vscroll = plt.subplot2grid((10, 10), (0, 9), rowspan=9) - ax_vscroll.set_axis_off() - # store these so they can be fixed on resize - params['fig'] = fig - params['ax'] = ax - params['ax_hscroll'] = ax_hscroll - params['ax_vscroll'] = ax_vscroll - - # populate vertical and horizontal scrollbars - for ci in range(len(info['ch_names'])): - this_color = (bad_color if info['ch_names'][inds[ci]] in info['bads'] - else color) - if isinstance(this_color, dict): - this_color = this_color[types[inds[ci]]] - ax_vscroll.add_patch(mpl.patches.Rectangle((0, ci), 1, 1, - facecolor=this_color, - edgecolor=this_color)) - vsel_patch = mpl.patches.Rectangle((0, 0), 1, n_channels, alpha=0.5, - facecolor='w', edgecolor='w') - ax_vscroll.add_patch(vsel_patch) - params['vsel_patch'] = vsel_patch - hsel_patch = mpl.patches.Rectangle((start, 0), duration, 1, edgecolor='k', - facecolor=(0.75, 0.75, 0.75), - alpha=0.25, linewidth=1, clip_on=False) - ax_hscroll.add_patch(hsel_patch) - params['hsel_patch'] = hsel_patch - ax_hscroll.set_xlim(0, n_times / float(info['sfreq'])) - n_ch = len(info['ch_names']) - ax_vscroll.set_ylim(n_ch, 0) - ax_vscroll.set_title('Ch.') - - # make shells for plotting traces - ylim = [n_channels * 2 + 1, 0] - offset = ylim[0] / n_channels - offsets = np.arange(n_channels) * offset + (offset / 2.) - ax.set_yticks(offsets) - ax.set_ylim(ylim) + _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, + n_channels) + # plot event_line first so it's in the back - event_lines = [ax.plot([np.nan], color=event_color[ev_num])[0] + event_lines = [params['ax'].plot([np.nan], color=event_color[ev_num])[0] for ev_num in sorted(event_color.keys())] - params['offsets'] = offsets - params['lines'] = [ax.plot([np.nan], antialiased=False, linewidth=0.5)[0] - for _ in range(n_ch)] - ax.set_yticklabels(['X' * max([len(ch) for ch in info['ch_names']])]) - vertline_color = (0., 0.75, 0.) - params['ax_vertline'] = ax.plot([0, 0], ylim, color=vertline_color, - zorder=-1)[0] - params['ax_vertline'].ch_name = '' - params['vertline_t'] = ax_hscroll.text(0, 0.5, '', color=vertline_color, - verticalalignment='center', - horizontalalignment='right') - params['ax_hscroll_vertline'] = ax_hscroll.plot([0, 0], [0, 1], - color=vertline_color, - zorder=1)[0] - params['plot_fun'] = partial(_plot_traces, params=params, inds=inds, color=color, bad_color=bad_color, event_lines=event_lines, event_color=event_color) params['scale_factor'] = 1.0 - # set up callbacks opt_button = None if len(raw.info['projs']) > 0 and not raw.proj: @@ -612,13 +542,13 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None, callback_option = partial(_toggle_options, params=params) opt_button.on_clicked(callback_option) callback_key = partial(_plot_raw_onkey, params=params) - fig.canvas.mpl_connect('key_press_event', callback_key) + params['fig'].canvas.mpl_connect('key_press_event', callback_key) callback_scroll = partial(_plot_raw_onscroll, params=params) - fig.canvas.mpl_connect('scroll_event', callback_scroll) + params['fig'].canvas.mpl_connect('scroll_event', callback_scroll) callback_pick = partial(_mouse_click, params=params) - fig.canvas.mpl_connect('button_press_event', callback_pick) + params['fig'].canvas.mpl_connect('button_press_event', callback_pick) callback_resize = partial(_helper_resize, params=params) - fig.canvas.mpl_connect('resize_event', callback_resize) + params['fig'].canvas.mpl_connect('resize_event', callback_resize) # As here code is shared with plot_evoked, some extra steps: # first the actual plot update function @@ -645,7 +575,7 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None, except TypeError: # not all versions have this plt.show() - return fig + return params['fig'] def _set_psd_plot_params(info, proj, picks, ax, area_mode): diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 668e20ac3d3..1e07e4ebdb6 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -19,7 +19,7 @@ import numpy as np from ..io import show_fiff -from ..utils import verbose +from ..utils import verbose, get_config COLORS = ['b', 'g', 'r', 'c', 'm', 'y', 'k', '#473C8B', '#458B74', @@ -390,6 +390,87 @@ def figure_nobar(*args, **kwargs): return fig +def _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, + n_channels): + """Helper for setting up the mne_browse_raw window.""" + import matplotlib.pyplot as plt + import matplotlib as mpl + size = get_config('MNE_BROWSE_RAW_SIZE') + if size is not None: + size = size.split(',') + size = tuple([float(s) for s in size]) + + fig = figure_nobar(facecolor=bgcolor, figsize=size) + fig.canvas.set_window_title('mne_browse_raw') + ax = plt.subplot2grid((10, 10), (0, 0), colspan=9, rowspan=9) + ax.set_title(title, fontsize=12) + ax_hscroll = plt.subplot2grid((10, 10), (9, 0), colspan=9) + ax_hscroll.get_yaxis().set_visible(False) + ax_hscroll.set_xlabel('Time (s)') + ax_vscroll = plt.subplot2grid((10, 10), (0, 9), rowspan=9) + ax_vscroll.set_axis_off() + # store these so they can be fixed on resize + params['fig'] = fig + params['ax'] = ax + params['ax_hscroll'] = ax_hscroll + params['ax_vscroll'] = ax_vscroll + + # populate vertical and horizontal scrollbars + info = params['info'] + for ci in range(len(info['ch_names'])): + this_color = (bad_color if info['ch_names'][inds[ci]] in info['bads'] + else color) + if isinstance(this_color, dict): + this_color = this_color[params['types'][inds[ci]]] + ax_vscroll.add_patch(mpl.patches.Rectangle((0, ci), 1, 1, + facecolor=this_color, + edgecolor=this_color)) + vsel_patch = mpl.patches.Rectangle((0, 0), 1, n_channels, alpha=0.5, + facecolor='w', edgecolor='w') + ax_vscroll.add_patch(vsel_patch) + params['vsel_patch'] = vsel_patch + hsel_patch = mpl.patches.Rectangle((params['t_start'], 0), + params['duration'], 1, edgecolor='k', + facecolor=(0.75, 0.75, 0.75), + alpha=0.25, linewidth=1, clip_on=False) + ax_hscroll.add_patch(hsel_patch) + params['hsel_patch'] = hsel_patch + ax_hscroll.set_xlim(0, params['n_times'] / float(info['sfreq'])) + n_ch = len(info['ch_names']) + ax_vscroll.set_ylim(n_ch, 0) + ax_vscroll.set_title('Ch.') + + # make shells for plotting traces + ylim = [n_channels * 2 + 1, 0] + offset = ylim[0] / n_channels + offsets = np.arange(n_channels) * offset + (offset / 2.) + ax.set_yticks(offsets) + ax.set_ylim(ylim) + + params['offsets'] = offsets + params['lines'] = [ax.plot([np.nan], antialiased=False, linewidth=0.5)[0] + for _ in range(n_ch)] + ax.set_yticklabels(['X' * max([len(ch) for ch in info['ch_names']])]) + vertline_color = (0., 0.75, 0.) + params['ax_vertline'] = ax.plot([0, 0], ylim, color=vertline_color, + zorder=-1)[0] + params['ax_vertline'].ch_name = '' + params['vertline_t'] = ax_hscroll.text(0, 0.5, '', color=vertline_color, + verticalalignment='center', + horizontalalignment='right') + params['ax_hscroll_vertline'] = ax_hscroll.plot([0, 0], [0, 1], + color=vertline_color, + zorder=1)[0] + + +def _channels_changed(params, len_channels): + """Helper function for dealing with the vertical shift of the viewport.""" + if params['ch_start'] + params['n_channels'] > len_channels: + params['ch_start'] = len_channels - params['n_channels'] + elif params['ch_start'] < 0: + params['ch_start'] = 0 + + class ClickableImage(object): """ From fd0c26367ee2739d490f5c4ea9ac0ade7682b2dc Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Tue, 30 Jun 2015 15:23:45 +0300 Subject: [PATCH 03/36] Almost working plotter for ica components. Refactoring. --- mne/preprocessing/ica.py | 35 ++++++++---- mne/viz/epochs.py | 46 +++++++--------- mne/viz/ica.py | 114 ++++++++++++++++++++++++++++++++++----- mne/viz/raw.py | 40 ++++---------- mne/viz/utils.py | 30 ++++++++++- 5 files changed, 184 insertions(+), 81 deletions(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index a0ab9f25712..64ce57ec91c 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -1572,21 +1572,38 @@ def _check_n_pca_components(self, _n_pca_comp, verbose=None): return _n_pca_comp - - def plot_raw_components(self, raw): + def plot_raw_components(self, raw, bads=[], title=None, duration=10.0, + start=0.0, n_channels=20, bgcolor='w', color=None, + bad_color=(1., 0., 0.)): """Plot ICA components - Note. This is still experimental and will most likely change. Over - Parameters ---------- raw : instance of Raw Raw object to draw sources from. - start_find : int | float | None - First sample to include for artifact search. If float, data will be - interpreted as time in seconds. If None, data will be used from the - """ - _plot_raw_components(self, raw) + bads : list + List of components to be marked in the plot. Defaults to empty + list. + title : str + Title for the plot. If None, ``ICA components`` is displayed. + Defaults to None + duration : float + Time window (sec) to plot in a given time. Defaults to 10.0. + start : float + Starting point for the plot. Defaults to 0.0. + n_channel : int + The number of channels per view. Defaults to 20. + bgcolor : color object + Color of the background. + color : color object | None + Color for the data traces. If None ``black`` is used. Defaults to + None. + bad_color : color object + Color to use for components marked as bad. + Defaults to (1., 0., 0.) (red). + """ + _plot_raw_components(self, raw, bads, title, duration, start, + n_channels, bgcolor, color, bad_color) def _check_start_stop(raw, start, stop): diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index 099ce890761..6bd155e1587 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -23,7 +23,7 @@ from ..time_frequency import compute_epochs_psd from .utils import tight_layout, _prepare_trellis, figure_nobar from .utils import _toggle_options, _toggle_proj, _layout_figure -from .utils import _channels_changed +from .utils import _channels_changed, _plot_raw_onscroll from ..defaults import _handle_default @@ -696,6 +696,8 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, 'fig_options': None, 'settings': [True, True, True, True]} # for options dialog + params['plot_fun'] = partial(_plot_traces, params=params) + if len(projs) > 0 and not epochs.proj: ax_button = plt.subplot2grid((10, 15), (9, 14)) opt_button = mpl.widgets.Button(ax_button, 'Proj') @@ -989,7 +991,7 @@ def _plot_update_epochs_proj(params, bools): types = params['types'] for pick, ind in enumerate(params['inds']): params['data'][pick] = data[ind] / params['scalings'][types[pick]] - _plot_traces(params) + params['plot_fun']() def _handle_picks(epochs): @@ -1013,7 +1015,7 @@ def _plot_window(value, params): if params['t_start'] != value: params['t_start'] = value params['hsel_patch'].set_x(value) - _plot_traces(params) + params['plot_fun']() def _plot_vert_lines(params): @@ -1051,7 +1053,7 @@ def _pick_bad_epochs(event, params): params['colors'][ch_idx][epoch_idx] = params['def_colors'][ch_idx] params['ax_hscroll'].patches[epoch_idx].set_color('w') params['ax_hscroll'].patches[epoch_idx].set_zorder(1) - _plot_traces(params) + params['plot_fun']() return # add bad epoch params['bads'] = np.append(params['bads'], epoch_idx) @@ -1060,7 +1062,7 @@ def _pick_bad_epochs(event, params): params['ax_hscroll'].patches[epoch_idx].set_edgecolor('w') for ch_idx in range(len(params['ch_names'])): params['colors'][ch_idx][epoch_idx] = (1., 0., 0., 1.) - _plot_traces(params) + params['plot_fun']() def _plot_onscroll(event, params): @@ -1074,16 +1076,7 @@ def _plot_onscroll(event, params): return if params['butterfly']: return - orig_start = params['ch_start'] - if event.step < 0: - params['ch_start'] = min(params['ch_start'] + params['n_channels'], - len(params['ch_names']) - - params['n_channels']) - else: - params['ch_start'] = max(params['ch_start'] - params['n_channels'], 0) - if orig_start != params['ch_start']: - _channels_changed(params, len(params['ch_names'])) - _plot_traces(params) + _plot_raw_onscroll(event, params, len(params['ch_names'])) def _mouse_click(event, params): @@ -1109,7 +1102,7 @@ def _mouse_click(event, params): params['epochs'].info['bads'].append(text) color = params['bad_color'] params['ax_vscroll'].patches[ch_idx + 1].set_color(color) - _plot_traces(params) + params['plot_fun']() elif event.button == 1: # left click # vertical scroll bar changed if event.inaxes == params['ax_vscroll']: @@ -1118,7 +1111,7 @@ def _mouse_click(event, params): ch_start = max(int(event.ydata) - params['n_channels'] // 2, 0) if params['ch_start'] != ch_start: params['ch_start'] = ch_start - _plot_traces(params) + params['plot_fun']() # horizontal scroll bar changed elif event.inaxes == params['ax_hscroll']: # find the closest epoch time @@ -1145,7 +1138,7 @@ def _mouse_click(event, params): params['vert_lines'].pop(0) if prev_xdata == xdata: params['vertline_t'].set_text('') - _plot_traces(params) + params['plot_fun']() return ylim = params['ax'].get_ylim() for epoch_idx in range(params['n_epochs']): @@ -1153,7 +1146,7 @@ def _mouse_click(event, params): params['vert_lines'].append(params['ax'].plot(pos, ylim, 'y', zorder=4)) params['vertline_t'].set_text('%0.3f' % params['epochs'].times[xdata]) - _plot_traces(params) + params['plot_fun']() def _plot_onkey(event, params): @@ -1164,13 +1157,11 @@ def _plot_onkey(event, params): return params['ch_start'] += params['n_channels'] _channels_changed(params, len(params['ch_names'])) - _plot_traces(params) elif event.key == 'up': if params['butterfly']: return params['ch_start'] -= params['n_channels'] _channels_changed(params, len(params['ch_names'])) - _plot_traces(params) elif event.key == 'left': sample = params['t_start'] - params['duration'] sample = np.max([0, sample]) @@ -1186,13 +1177,13 @@ def _plot_onkey(event, params): params['butterfly_scale'] /= 1.1 else: params['scale_factor'] /= 1.1 - _plot_traces(params) + params['plot_fun']() elif event.key in ['+', '=']: if params['butterfly']: params['butterfly_scale'] *= 1.1 else: params['scale_factor'] *= 1.1 - _plot_traces(params) + params['plot_fun']() elif event.key == 'f11': mng = plt.get_current_fig_manager() mng.full_screen_toggle() @@ -1208,7 +1199,7 @@ def _plot_onkey(event, params): params['ax'].set_yticks(params['offsets']) params['lines'].pop() params['vsel_patch'].set_height(n_channels) - _plot_traces(params) + params['plot_fun']() elif event.key == 'pageup': if params['butterfly']: return @@ -1224,7 +1215,7 @@ def _plot_onkey(event, params): params['ax'].set_yticks(params['offsets']) params['lines'].append(lc) params['vsel_patch'].set_height(n_channels) - _plot_traces(params) + params['plot_fun']() elif event.key == 'home': n_epochs = params['n_epochs'] - 1 if n_epochs <= 0: @@ -1235,7 +1226,7 @@ def _plot_onkey(event, params): params['n_epochs'] = n_epochs params['duration'] -= n_times params['hsel_patch'].set_width(params['duration']) - _plot_traces(params) + params['plot_fun']() elif event.key == 'end': n_epochs = params['n_epochs'] + 1 n_times = len(params['epochs'].times) @@ -1257,7 +1248,7 @@ def _plot_onkey(event, params): params['t_start'] -= n_times params['hsel_patch'].set_x(params['t_start']) params['hsel_patch'].set_width(params['duration']) - _plot_traces(params) + params['plot_fun']() elif event.key == 'b': if params['fig_options'] is not None: plt.close(params['fig_options']) @@ -1269,6 +1260,7 @@ def _plot_onkey(event, params): _open_options(params) elif event.key == '?': _onclick_help(event) + params['plot_fun']() elif event.key == 'escape': plt.close(params['fig']) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 30840a33280..d044e10b9eb 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -13,7 +13,8 @@ import numpy as np from .utils import tight_layout, _prepare_trellis, _prepare_mne_browse_raw -from .utils import _layout_figure +from .utils import _layout_figure, _plot_raw_onscroll, _plot_raw_time +from .utils import _channels_changed from .evoked import _butterfly_on_button_press, _butterfly_onpick from ..io.meas_info import create_info @@ -514,29 +515,37 @@ def _plot_ica_overlay_evoked(evoked, evoked_cln, title, show): return fig -def _plot_raw_components(ica, raw, title=None, duration=10.0, start=0.0, - n_channels=20, bgcolor='w', color=None, - bad_color=(0.8, 0.8, 0.8), event_color='cyan'): +def _plot_raw_components(ica, raw, bads=[], title=None, duration=10.0, + start=0.0, n_channels=20, bgcolor='w', color=None, + bad_color=(1., 0., 0.)): """Helper function for plotting the ICA components as raw array.""" - data = ica._transform_raw(raw, 0, len(raw.times)) + data = ica._transform_raw(raw, 0, len(raw.times)) * 0.2 # custom scaling inds = range(len(data)) c_names = ['ICA ' + str(x + 1) for x in inds] if title is None: title = 'ICA components' info = create_info(c_names, raw.info['sfreq']) - scalings = {'misc': 2} + params = dict(raw=raw, data=data, ch_start=0, t_start=start, info=info, - duration=duration, n_channels=n_channels, scalings=scalings, - n_times=raw.n_times, bad_color=(0.8, 0.8, 0.8)) + duration=duration, n_channels=n_channels, + n_times=raw.n_times, bad_color=bad_color, bads=bads) _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, n_channels) params['scale_factor'] = 1.0 + params['plot_fun'] = partial(_plot_traces, params=params) _layout_figure(params) - plot_traces(params) + # callbacks + callback_key = partial(_plot_onkey, params=params) + params['fig'].canvas.mpl_connect('key_press_event', callback_key) + callback_scroll = partial(_plot_raw_onscroll, params=params) + params['fig'].canvas.mpl_connect('scroll_event', callback_scroll) + callback_pick = partial(_mouse_click, params=params) + params['fig'].canvas.mpl_connect('button_press_event', callback_pick) + _plot_traces(params) return params['fig'] -def plot_traces(params): +def _plot_traces(params): lines = params['lines'] info = params['info'] n_channels = params['n_channels'] @@ -558,8 +567,8 @@ def plot_traces(params): # do NOT operate in-place lest this get screwed up this_data = params['data'][ch_ind] * params['scale_factor'] - this_color = bad_color if ch_name in info['bads'] else color - this_z = -1 if ch_name in info['bads'] else 0 + this_color = bad_color if ch_ind in params['bads'] else color + this_z = -1 if ch_ind in params['bads'] else 0 # subtraction here gets correct orientation for flipped ylim lines[ii].set_ydata(offset - this_data) @@ -567,7 +576,7 @@ def plot_traces(params): lines[ii].set_color(this_color) lines[ii].set_zorder(this_z) vars(lines[ii])['ch_name'] = ch_name - vars(lines[ii])['def_color'] = this_color + lines[ii].set_color(this_color) else: # "remove" lines lines[ii].set_xdata([]) @@ -579,3 +588,82 @@ def plot_traces(params): params['ax'].set_yticklabels(tick_list) params['vsel_patch'].set_y(params['ch_start']) params['fig'].canvas.draw() + + +def _plot_onkey(event, params): + """Interpret key presses""" + import matplotlib.pyplot as plt + if event.key == 'escape': + plt.close(params['fig']) + elif event.key == 'down': + params['ch_start'] += params['n_channels'] + _channels_changed(params, len(params['info']['ch_names'])) + elif event.key == 'up': + params['ch_start'] -= params['n_channels'] + _channels_changed(params, len(params['info']['ch_names'])) + elif event.key == 'right': + _plot_raw_time(params['t_start'] + params['duration'], params) + params['plot_fun']() + elif event.key == 'left': + _plot_raw_time(params['t_start'] - params['duration'], params) + params['plot_fun']() + elif event.key in ['+', '=']: + params['scale_factor'] *= 1.1 + params['plot_fun']() + elif event.key == '-': + params['scale_factor'] /= 1.1 + params['plot_fun']() + elif event.key == 'pageup': + n_channels = params['n_channels'] + 1 + offset = params['ax'].get_ylim()[0] / n_channels + params['offsets'] = np.arange(n_channels) * offset + (offset / 2.) + params['n_channels'] = n_channels + params['ax'].set_yticks(params['offsets']) + params['vsel_patch'].set_height(n_channels) + _channels_changed(params, len(params['info']['ch_names'])) + elif event.key == 'pagedown': + n_channels = params['n_channels'] - 1 + if n_channels == 0: + return + offset = params['ax'].get_ylim()[0] / n_channels + params['offsets'] = np.arange(n_channels) * offset + (offset / 2.) + params['n_channels'] = n_channels + params['ax'].set_yticks(params['offsets']) + params['vsel_patch'].set_height(n_channels) + if len(params['lines']) > n_channels: # remove line from view + params['lines'][n_channels].set_xdata([]) + params['lines'][n_channels].set_ydata([]) + _channels_changed(params, len(params['info']['ch_names'])) + elif event.key == 'home': + duration = params['duration'] - 1.0 + if duration <= 0: + return + params['duration'] = duration + params['hsel_patch'].set_width(params['duration']) + params['plot_fun']() + elif event.key == 'end': + duration = params['duration'] + 1.0 + if duration > params['raw'].times[-1]: + duration = params['raw'].times[-1] + params['duration'] = duration + params['hsel_patch'].set_width(params['duration']) + params['plot_fun']() + elif event.key == 'f11': + mng = plt.get_current_fig_manager() + mng.full_screen_toggle() + + +def _mouse_click(event, params): + """Vertical select callback""" + if event.inaxes is None or event.button != 1: + return + # vertical scrollbar changed + if event.inaxes == params['ax_vscroll']: + ch_start = max(int(event.ydata) - params['n_channels'] // 2, 0) + if params['ch_start'] != ch_start: + params['ch_start'] = ch_start + params['plot_fun']() + # horizontal scrollbar changed + elif event.inaxes == params['ax_hscroll']: + _plot_raw_time(event.xdata - params['duration'] / 2, params) + params['plot_fun']() diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 9d290e6c120..d4d7468892d 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -19,6 +19,7 @@ from ..time_frequency import compute_raw_psd from .utils import _toggle_options, _toggle_proj, tight_layout from .utils import _layout_figure, _prepare_mne_browse_raw, _channels_changed +from .utils import _plot_raw_onscroll, _plot_raw_time from ..defaults import _handle_default @@ -120,26 +121,13 @@ def _mouse_click(event, params): # horizontal scrollbar changed elif event.inaxes == params['ax_hscroll']: _plot_raw_time(event.xdata - params['duration'] / 2, params) + _update_raw_data(params) + params['plot_fun']() elif event.inaxes == params['ax']: _pick_bad_channels(event, params) -def _plot_raw_time(value, params): - """Deal with changed time value""" - info = params['info'] - max_times = params['n_times'] / float(info['sfreq']) - params['duration'] - if value > max_times: - value = params['n_times'] / info['sfreq'] - params['duration'] - if value < 0: - value = 0 - if params['t_start'] != value: - params['t_start'] = value - params['hsel_patch'].set_x(value) - _update_raw_data(params) - params['plot_fun']() - - def _plot_raw_onkey(event, params): """Interpret key presses""" import matplotlib.pyplot as plt @@ -163,9 +151,13 @@ def _plot_raw_onkey(event, params): ch_changed = True elif event.key == 'right': _plot_raw_time(params['t_start'] + params['duration'], params) + _update_raw_data(params) + params['plot_fun']() return elif event.key == 'left': _plot_raw_time(params['t_start'] - params['duration'], params) + _update_raw_data(params) + params['plot_fun']() return elif event.key in ['o', 'p']: _toggle_options(None, params) @@ -221,22 +213,8 @@ def _plot_raw_onkey(event, params): return # deal with plotting changes if ch_changed: - _channels_changed(params, len(params['info']['ch_names'])) - params['plot_fun']() - - -def _plot_raw_onscroll(event, params): - """Interpret scroll events""" - orig_start = params['ch_start'] - if event.step < 0: - params['ch_start'] = min(params['ch_start'] + params['n_channels'], - len(params['info']['ch_names']) - - params['n_channels']) - else: # event.key == 'up': - params['ch_start'] = max(params['ch_start'] - params['n_channels'], 0) - if orig_start != params['ch_start']: - _channels_changed(params, len(params['info']['ch_names'])) - params['plot_fun']() + len_channels = len(params['info']['ch_names']) + _channels_changed(params, len_channels) def _plot_traces(params, inds, color, bad_color, event_lines, event_color): diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 1e07e4ebdb6..6f575115fdc 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -463,12 +463,40 @@ def _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, zorder=1)[0] +def _plot_raw_onscroll(event, params, len_channels=None): + """Interpret scroll events""" + if len_channels is None: + len_channels = len(params['info']['ch_names']) + orig_start = params['ch_start'] + if event.step < 0: + params['ch_start'] = min(params['ch_start'] + params['n_channels'], + len_channels - params['n_channels']) + else: # event.key == 'up': + params['ch_start'] = max(params['ch_start'] - params['n_channels'], 0) + if orig_start != params['ch_start']: + _channels_changed(params, len_channels) + + def _channels_changed(params, len_channels): """Helper function for dealing with the vertical shift of the viewport.""" if params['ch_start'] + params['n_channels'] > len_channels: params['ch_start'] = len_channels - params['n_channels'] - elif params['ch_start'] < 0: + if params['ch_start'] < 0: params['ch_start'] = 0 + params['plot_fun']() + + +def _plot_raw_time(value, params): + """Deal with changed time value""" + info = params['info'] + max_times = params['n_times'] / float(info['sfreq']) - params['duration'] + if value > max_times: + value = params['n_times'] / info['sfreq'] - params['duration'] + if value < 0: + value = 0 + if params['t_start'] != value: + params['t_start'] = value + params['hsel_patch'].set_x(value) class ClickableImage(object): From 9ef7398469ca38c60a2eceac09aae10c57cd2941 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Tue, 30 Jun 2015 16:34:08 +0300 Subject: [PATCH 04/36] Fixes. Horizontal navigation. --- mne/preprocessing/ica.py | 17 +++++++++++------ mne/viz/ica.py | 23 +++++++++++++++-------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 64ce57ec91c..0ac5a61240d 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -1572,7 +1572,7 @@ def _check_n_pca_components(self, _n_pca_comp, verbose=None): return _n_pca_comp - def plot_raw_components(self, raw, bads=[], title=None, duration=10.0, + def plot_raw_components(self, raw, exclude=None, title=None, duration=10.0, start=0.0, n_channels=20, bgcolor='w', color=None, bad_color=(1., 0., 0.)): """Plot ICA components @@ -1581,9 +1581,9 @@ def plot_raw_components(self, raw, bads=[], title=None, duration=10.0, ---------- raw : instance of Raw Raw object to draw sources from. - bads : list - List of components to be marked in the plot. Defaults to empty - list. + exclude : array_like of int | None + The components marked for exclusion. If None (default), ICA.exclude + will be used. title : str Title for the plot. If None, ``ICA components`` is displayed. Defaults to None @@ -1591,7 +1591,7 @@ def plot_raw_components(self, raw, bads=[], title=None, duration=10.0, Time window (sec) to plot in a given time. Defaults to 10.0. start : float Starting point for the plot. Defaults to 0.0. - n_channel : int + n_channels : int The number of channels per view. Defaults to 20. bgcolor : color object Color of the background. @@ -1601,8 +1601,13 @@ def plot_raw_components(self, raw, bads=[], title=None, duration=10.0, bad_color : color object Color to use for components marked as bad. Defaults to (1., 0., 0.) (red). + + Returns + ------- + fig : Instance of matplotlib.figure.Figure + The figure. """ - _plot_raw_components(self, raw, bads, title, duration, start, + _plot_raw_components(self, raw, exclude, title, duration, start, n_channels, bgcolor, color, bad_color) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index d044e10b9eb..b5cf64a5a61 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -515,7 +515,7 @@ def _plot_ica_overlay_evoked(evoked, evoked_cln, title, show): return fig -def _plot_raw_components(ica, raw, bads=[], title=None, duration=10.0, +def _plot_raw_components(ica, raw, exclude=None, title=None, duration=10.0, start=0.0, n_channels=20, bgcolor='w', color=None, bad_color=(1., 0., 0.)): """Helper function for plotting the ICA components as raw array.""" @@ -526,9 +526,11 @@ def _plot_raw_components(ica, raw, bads=[], title=None, duration=10.0, title = 'ICA components' info = create_info(c_names, raw.info['sfreq']) + if exclude is None: + exclude = list() # TODO -> ica.exclude params = dict(raw=raw, data=data, ch_start=0, t_start=start, info=info, duration=duration, n_channels=n_channels, - n_times=raw.n_times, bad_color=bad_color, bads=bads) + n_times=raw.n_times, bad_color=bad_color, bads=exclude) _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, n_channels) params['scale_factor'] = 1.0 @@ -551,6 +553,9 @@ def _plot_traces(params): n_channels = params['n_channels'] bad_color = params['bad_color'] color = 'black' + scale_factor = params['scale_factor'] + t_start = int(params['t_start'] * params['info']['sfreq']) + t_end = int((t_start + params['duration']) * params['info']['sfreq']) # do the plotting tick_list = [] for ii in range(n_channels): @@ -566,13 +571,13 @@ def _plot_traces(params): offset = params['offsets'][ii] # do NOT operate in-place lest this get screwed up - this_data = params['data'][ch_ind] * params['scale_factor'] + this_data = params['data'][ch_ind][t_start:t_end] * scale_factor this_color = bad_color if ch_ind in params['bads'] else color this_z = -1 if ch_ind in params['bads'] else 0 # subtraction here gets correct orientation for flipped ylim lines[ii].set_ydata(offset - this_data) - lines[ii].set_xdata(params['raw'].times) + lines[ii].set_xdata(params['raw'].times[t_start:t_end]) lines[ii].set_color(this_color) lines[ii].set_zorder(this_z) vars(lines[ii])['ch_name'] = ch_name @@ -583,8 +588,8 @@ def _plot_traces(params): lines[ii].set_ydata([]) # finalize plot - params['ax'].set_xlim(params['raw'].times[0], - params['raw'].times[0] + params['duration'], False) + params['ax'].set_xlim(params['t_start'], + params['t_start'] + params['duration'], False) params['ax'].set_yticklabels(tick_list) params['vsel_patch'].set_y(params['ch_start']) params['fig'].canvas.draw() @@ -602,10 +607,12 @@ def _plot_onkey(event, params): params['ch_start'] -= params['n_channels'] _channels_changed(params, len(params['info']['ch_names'])) elif event.key == 'right': - _plot_raw_time(params['t_start'] + params['duration'], params) + value = params['t_start'] + params['duration'] + _plot_raw_time(value, params) params['plot_fun']() elif event.key == 'left': - _plot_raw_time(params['t_start'] - params['duration'], params) + value = params['t_start'] - params['duration'] + _plot_raw_time(value, params) params['plot_fun']() elif event.key in ['+', '=']: params['scale_factor'] *= 1.1 From d0b58a7f69ef86ad8724b8bf44084dbcb43e5138 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Wed, 1 Jul 2015 12:15:17 +0300 Subject: [PATCH 05/36] More refactoring. --- mne/preprocessing/ica.py | 4 +- mne/viz/ica.py | 96 ++++++++++++++++------------------------ mne/viz/raw.py | 84 +---------------------------------- mne/viz/utils.py | 81 +++++++++++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 143 deletions(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 0ac5a61240d..348d3338be7 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -1573,8 +1573,8 @@ def _check_n_pca_components(self, _n_pca_comp, verbose=None): return _n_pca_comp def plot_raw_components(self, raw, exclude=None, title=None, duration=10.0, - start=0.0, n_channels=20, bgcolor='w', color=None, - bad_color=(1., 0., 0.)): + start=0.0, n_channels=20, bgcolor='w', + color=(0., 0., 0.), bad_color=(1., 0., 0.)): """Plot ICA components Parameters diff --git a/mne/viz/ica.py b/mne/viz/ica.py index b5cf64a5a61..684ed5412ea 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -14,8 +14,9 @@ from .utils import tight_layout, _prepare_trellis, _prepare_mne_browse_raw from .utils import _layout_figure, _plot_raw_onscroll, _plot_raw_time -from .utils import _channels_changed +from .utils import _channels_changed, _plot_raw_traces from .evoked import _butterfly_on_button_press, _butterfly_onpick +from ..defaults import _handle_default from ..io.meas_info import create_info @@ -516,11 +517,15 @@ def _plot_ica_overlay_evoked(evoked, evoked_cln, title, show): def _plot_raw_components(ica, raw, exclude=None, title=None, duration=10.0, - start=0.0, n_channels=20, bgcolor='w', color=None, - bad_color=(1., 0., 0.)): + start=0.0, n_channels=20, bgcolor='w', + color=(0., 0., 0.), bad_color=(1., 0., 0.)): """Helper function for plotting the ICA components as raw array.""" - data = ica._transform_raw(raw, 0, len(raw.times)) * 0.2 # custom scaling - inds = range(len(data)) + color = _handle_default('color', color) + scalings = {'misc': 0.2} + orig_data = ica._transform_raw(raw, 0, len(raw.times)) * scalings['misc'] + inds = range(len(orig_data)) + types = np.repeat('misc', len(inds)) + c_names = ['ICA ' + str(x + 1) for x in inds] if title is None: title = 'ICA components' @@ -528,13 +533,18 @@ def _plot_raw_components(ica, raw, exclude=None, title=None, duration=10.0, if exclude is None: exclude = list() # TODO -> ica.exclude - params = dict(raw=raw, data=data, ch_start=0, t_start=start, info=info, - duration=duration, n_channels=n_channels, - n_times=raw.n_times, bad_color=bad_color, bads=exclude) + info['bads'] = [c_names[x] for x in exclude] + t_end = int(duration * raw.info['sfreq']) + times = raw.times[0:t_end] + params = dict(raw=raw, orig_data=orig_data, data=orig_data[:, 0:t_end], + ch_start=0, t_start=start, info=info, duration=duration, + n_channels=n_channels, times=times, types=types, + n_times=raw.n_times, bad_color=bad_color) _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, n_channels) params['scale_factor'] = 1.0 - params['plot_fun'] = partial(_plot_traces, params=params) + params['plot_fun'] = partial(_plot_raw_traces, params=params, inds=inds, + color=color, bad_color=bad_color) _layout_figure(params) # callbacks callback_key = partial(_plot_onkey, params=params) @@ -543,58 +553,12 @@ def _plot_raw_components(ica, raw, exclude=None, title=None, duration=10.0, params['fig'].canvas.mpl_connect('scroll_event', callback_scroll) callback_pick = partial(_mouse_click, params=params) params['fig'].canvas.mpl_connect('button_press_event', callback_pick) - _plot_traces(params) + params['fig_proj'] = None + params['event_times'] = None + params['plot_fun']() return params['fig'] -def _plot_traces(params): - lines = params['lines'] - info = params['info'] - n_channels = params['n_channels'] - bad_color = params['bad_color'] - color = 'black' - scale_factor = params['scale_factor'] - t_start = int(params['t_start'] * params['info']['sfreq']) - t_end = int((t_start + params['duration']) * params['info']['sfreq']) - # do the plotting - tick_list = [] - for ii in range(n_channels): - ch_ind = ii + params['ch_start'] - # let's be generous here and allow users to pass - # n_channels per view >= the number of traces available - if ii >= len(lines): - break - elif ch_ind < len(info['ch_names']): - # scale to fit - ch_name = info['ch_names'][ch_ind] - tick_list += [ch_name] - offset = params['offsets'][ii] - - # do NOT operate in-place lest this get screwed up - this_data = params['data'][ch_ind][t_start:t_end] * scale_factor - this_color = bad_color if ch_ind in params['bads'] else color - this_z = -1 if ch_ind in params['bads'] else 0 - - # subtraction here gets correct orientation for flipped ylim - lines[ii].set_ydata(offset - this_data) - lines[ii].set_xdata(params['raw'].times[t_start:t_end]) - lines[ii].set_color(this_color) - lines[ii].set_zorder(this_z) - vars(lines[ii])['ch_name'] = ch_name - lines[ii].set_color(this_color) - else: - # "remove" lines - lines[ii].set_xdata([]) - lines[ii].set_ydata([]) - - # finalize plot - params['ax'].set_xlim(params['t_start'], - params['t_start'] + params['duration'], False) - params['ax'].set_yticklabels(tick_list) - params['vsel_patch'].set_y(params['ch_start']) - params['fig'].canvas.draw() - - def _plot_onkey(event, params): """Interpret key presses""" import matplotlib.pyplot as plt @@ -609,10 +573,12 @@ def _plot_onkey(event, params): elif event.key == 'right': value = params['t_start'] + params['duration'] _plot_raw_time(value, params) + _update_data(params) params['plot_fun']() elif event.key == 'left': value = params['t_start'] - params['duration'] _plot_raw_time(value, params) + _update_data(params) params['plot_fun']() elif event.key in ['+', '=']: params['scale_factor'] *= 1.1 @@ -647,6 +613,7 @@ def _plot_onkey(event, params): return params['duration'] = duration params['hsel_patch'].set_width(params['duration']) + _update_data(params) params['plot_fun']() elif event.key == 'end': duration = params['duration'] + 1.0 @@ -654,6 +621,7 @@ def _plot_onkey(event, params): duration = params['raw'].times[-1] params['duration'] = duration params['hsel_patch'].set_width(params['duration']) + _update_data(params) params['plot_fun']() elif event.key == 'f11': mng = plt.get_current_fig_manager() @@ -661,7 +629,7 @@ def _plot_onkey(event, params): def _mouse_click(event, params): - """Vertical select callback""" + """Function for handling mouse clicks.""" if event.inaxes is None or event.button != 1: return # vertical scrollbar changed @@ -673,4 +641,14 @@ def _mouse_click(event, params): # horizontal scrollbar changed elif event.inaxes == params['ax_hscroll']: _plot_raw_time(event.xdata - params['duration'] / 2, params) + _update_data(params) params['plot_fun']() + + +def _update_data(params): + """Function for preparing the data on horizontal shift of the viewport.""" + sfreq = params['info']['sfreq'] + start = int(params['t_start'] * sfreq) + end = int((params['t_start'] + params['duration']) * sfreq) + params['data'] = params['orig_data'][:, start:end] + params['times'] = params['raw'].times[start:end] diff --git a/mne/viz/raw.py b/mne/viz/raw.py index d4d7468892d..bd149997b8a 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -19,7 +19,7 @@ from ..time_frequency import compute_raw_psd from .utils import _toggle_options, _toggle_proj, tight_layout from .utils import _layout_figure, _prepare_mne_browse_raw, _channels_changed -from .utils import _plot_raw_onscroll, _plot_raw_time +from .utils import _plot_raw_onscroll, _plot_raw_time, _plot_raw_traces from ..defaults import _handle_default @@ -217,86 +217,6 @@ def _plot_raw_onkey(event, params): _channels_changed(params, len_channels) -def _plot_traces(params, inds, color, bad_color, event_lines, event_color): - """Helper for plotting raw""" - lines = params['lines'] - info = params['info'] - n_channels = params['n_channels'] - params['bad_color'] = bad_color - # do the plotting - tick_list = [] - for ii in range(n_channels): - ch_ind = ii + params['ch_start'] - # let's be generous here and allow users to pass - # n_channels per view >= the number of traces available - if ii >= len(lines): - break - elif ch_ind < len(info['ch_names']): - # scale to fit - ch_name = info['ch_names'][inds[ch_ind]] - tick_list += [ch_name] - offset = params['offsets'][ii] - - # do NOT operate in-place lest this get screwed up - this_data = params['data'][inds[ch_ind]] * params['scale_factor'] - this_color = bad_color if ch_name in info['bads'] else color - this_z = -1 if ch_name in info['bads'] else 0 - if isinstance(this_color, dict): - this_color = this_color[params['types'][inds[ch_ind]]] - - # subtraction here gets corect orientation for flipped ylim - lines[ii].set_ydata(offset - this_data) - lines[ii].set_xdata(params['times']) - lines[ii].set_color(this_color) - lines[ii].set_zorder(this_z) - vars(lines[ii])['ch_name'] = ch_name - vars(lines[ii])['def_color'] = color[params['types'][inds[ch_ind]]] - else: - # "remove" lines - lines[ii].set_xdata([]) - lines[ii].set_ydata([]) - # deal with event lines - if params['event_times'] is not None: - # find events in the time window - event_times = params['event_times'] - mask = np.logical_and(event_times >= params['times'][0], - event_times <= params['times'][-1]) - event_times = event_times[mask] - event_nums = params['event_nums'][mask] - # plot them with appropriate colors - # go through the list backward so we end with -1, the catchall - used = np.zeros(len(event_times), bool) - ylim = params['ax'].get_ylim() - for ev_num, line in zip(sorted(event_color.keys())[::-1], - event_lines[::-1]): - mask = (event_nums == ev_num) if ev_num >= 0 else ~used - assert not np.any(used[mask]) - used[mask] = True - t = event_times[mask] - if len(t) > 0: - xs = list() - ys = list() - for tt in t: - xs += [tt, tt, np.nan] - ys += [0, ylim[0], np.nan] - line.set_xdata(xs) - line.set_ydata(ys) - else: - line.set_xdata([]) - line.set_ydata([]) - # finalize plot - params['ax'].set_xlim(params['times'][0], - params['times'][0] + params['duration'], False) - params['ax'].set_yticklabels(tick_list) - params['vsel_patch'].set_y(params['ch_start']) - params['fig'].canvas.draw() - # XXX This is a hack to make sure this figure gets drawn last - # so that when matplotlib goes to calculate bounds we don't get a - # CGContextRef error on the MacOSX backend :( - if params['fig_proj'] is not None: - params['fig_proj'].canvas.draw() - - def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None, bgcolor='w', color=None, bad_color=(0.8, 0.8, 0.8), event_color='cyan', scalings=None, remove_dc=True, order='type', @@ -506,7 +426,7 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None, # plot event_line first so it's in the back event_lines = [params['ax'].plot([np.nan], color=event_color[ev_num])[0] for ev_num in sorted(event_color.keys())] - params['plot_fun'] = partial(_plot_traces, params=params, inds=inds, + params['plot_fun'] = partial(_plot_raw_traces, params=params, inds=inds, color=color, bad_color=bad_color, event_lines=event_lines, event_color=event_color) diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 6f575115fdc..86ba9e1e348 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -499,6 +499,87 @@ def _plot_raw_time(value, params): params['hsel_patch'].set_x(value) +def _plot_raw_traces(params, inds, color, bad_color, event_lines=None, + event_color=None): + """Helper for plotting raw""" + lines = params['lines'] + info = params['info'] + n_channels = params['n_channels'] + params['bad_color'] = bad_color + # do the plotting + tick_list = [] + for ii in range(n_channels): + ch_ind = ii + params['ch_start'] + # let's be generous here and allow users to pass + # n_channels per view >= the number of traces available + if ii >= len(lines): + break + elif ch_ind < len(info['ch_names']): + # scale to fit + ch_name = info['ch_names'][inds[ch_ind]] + tick_list += [ch_name] + offset = params['offsets'][ii] + + # do NOT operate in-place lest this get screwed up + this_data = params['data'][inds[ch_ind]] * params['scale_factor'] + this_color = bad_color if ch_name in info['bads'] else color + this_z = -1 if ch_name in info['bads'] else 0 + if isinstance(this_color, dict): + this_color = this_color[params['types'][inds[ch_ind]]] + + # subtraction here gets corect orientation for flipped ylim + lines[ii].set_ydata(offset - this_data) + lines[ii].set_xdata(params['times']) + lines[ii].set_color(this_color) + lines[ii].set_zorder(this_z) + vars(lines[ii])['ch_name'] = ch_name + vars(lines[ii])['def_color'] = color[params['types'][inds[ch_ind]]] + else: + # "remove" lines + lines[ii].set_xdata([]) + lines[ii].set_ydata([]) + # deal with event lines + if params['event_times'] is not None: + # find events in the time window + event_times = params['event_times'] + mask = np.logical_and(event_times >= params['times'][0], + event_times <= params['times'][-1]) + event_times = event_times[mask] + event_nums = params['event_nums'][mask] + # plot them with appropriate colors + # go through the list backward so we end with -1, the catchall + used = np.zeros(len(event_times), bool) + ylim = params['ax'].get_ylim() + for ev_num, line in zip(sorted(event_color.keys())[::-1], + event_lines[::-1]): + mask = (event_nums == ev_num) if ev_num >= 0 else ~used + assert not np.any(used[mask]) + used[mask] = True + t = event_times[mask] + if len(t) > 0: + xs = list() + ys = list() + for tt in t: + xs += [tt, tt, np.nan] + ys += [0, ylim[0], np.nan] + line.set_xdata(xs) + line.set_ydata(ys) + else: + line.set_xdata([]) + line.set_ydata([]) + # finalize plot + params['ax'].set_xlim(params['times'][0], + params['times'][0] + params['duration'], False) + params['ax'].set_yticklabels(tick_list) + params['vsel_patch'].set_y(params['ch_start']) + params['fig'].canvas.draw() + # XXX This is a hack to make sure this figure gets drawn last + # so that when matplotlib goes to calculate bounds we don't get a + # CGContextRef error on the MacOSX backend :( + if params['fig_proj'] is not None: + params['fig_proj'].canvas.draw() + + class ClickableImage(object): """ From 089bb2d65f5e76b6a478646a8723e65cbb43a699 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Wed, 1 Jul 2015 12:59:16 +0300 Subject: [PATCH 06/36] Updated example. Removed parameter. --- examples/preprocessing/plot_ica_from_raw.py | 1 + mne/preprocessing/ica.py | 13 +++++-------- mne/viz/ica.py | 6 +++--- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/examples/preprocessing/plot_ica_from_raw.py b/examples/preprocessing/plot_ica_from_raw.py index 1a8c96644dc..90a460cbd74 100644 --- a/examples/preprocessing/plot_ica_from_raw.py +++ b/examples/preprocessing/plot_ica_from_raw.py @@ -62,6 +62,7 @@ show_picks = np.abs(scores).argsort()[::-1][:5] +ica.plot_raw_components(raw, exclude=ecg_inds, title=title % 'ecg') ica.plot_sources(raw, show_picks, exclude=ecg_inds, title=title % 'ecg') ica.plot_components(ecg_inds, title=title % 'ecg', colorbar=True) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 348d3338be7..9c3d116f5e7 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -1573,8 +1573,8 @@ def _check_n_pca_components(self, _n_pca_comp, verbose=None): return _n_pca_comp def plot_raw_components(self, raw, exclude=None, title=None, duration=10.0, - start=0.0, n_channels=20, bgcolor='w', - color=(0., 0., 0.), bad_color=(1., 0., 0.)): + n_channels=20, bgcolor='w', color=(0., 0., 0.), + bad_color=(1., 0., 0.)): """Plot ICA components Parameters @@ -1589,15 +1589,12 @@ def plot_raw_components(self, raw, exclude=None, title=None, duration=10.0, Defaults to None duration : float Time window (sec) to plot in a given time. Defaults to 10.0. - start : float - Starting point for the plot. Defaults to 0.0. n_channels : int The number of channels per view. Defaults to 20. bgcolor : color object Color of the background. - color : color object | None - Color for the data traces. If None ``black`` is used. Defaults to - None. + color : color object + Color for the data traces. Defaults to (0., 0., 0.) (black). bad_color : color object Color to use for components marked as bad. Defaults to (1., 0., 0.) (red). @@ -1607,7 +1604,7 @@ def plot_raw_components(self, raw, exclude=None, title=None, duration=10.0, fig : Instance of matplotlib.figure.Figure The figure. """ - _plot_raw_components(self, raw, exclude, title, duration, start, + _plot_raw_components(self, raw, exclude, title, duration, n_channels, bgcolor, color, bad_color) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 684ed5412ea..e012ff5e485 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -517,8 +517,8 @@ def _plot_ica_overlay_evoked(evoked, evoked_cln, title, show): def _plot_raw_components(ica, raw, exclude=None, title=None, duration=10.0, - start=0.0, n_channels=20, bgcolor='w', - color=(0., 0., 0.), bad_color=(1., 0., 0.)): + n_channels=20, bgcolor='w', color=(0., 0., 0.), + bad_color=(1., 0., 0.)): """Helper function for plotting the ICA components as raw array.""" color = _handle_default('color', color) scalings = {'misc': 0.2} @@ -537,7 +537,7 @@ def _plot_raw_components(ica, raw, exclude=None, title=None, duration=10.0, t_end = int(duration * raw.info['sfreq']) times = raw.times[0:t_end] params = dict(raw=raw, orig_data=orig_data, data=orig_data[:, 0:t_end], - ch_start=0, t_start=start, info=info, duration=duration, + ch_start=0, t_start=0, info=info, duration=duration, n_channels=n_channels, times=times, types=types, n_times=raw.n_times, bad_color=bad_color) _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, From 6688a1a59d6df79529b6ef8ba9e8aa3b67922fdd Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Thu, 2 Jul 2015 10:57:38 +0300 Subject: [PATCH 07/36] Refactoring. Interactive selection of components. --- mne/viz/ica.py | 131 +++++++++++++++-------------------------------- mne/viz/raw.py | 129 +++------------------------------------------- mne/viz/utils.py | 98 ++++++++++++++++++++++++++++++++++- 3 files changed, 147 insertions(+), 211 deletions(-) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index e012ff5e485..bdf4ab0b66e 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -13,8 +13,8 @@ import numpy as np from .utils import tight_layout, _prepare_trellis, _prepare_mne_browse_raw -from .utils import _layout_figure, _plot_raw_onscroll, _plot_raw_time -from .utils import _channels_changed, _plot_raw_traces +from .utils import _layout_figure, _plot_raw_onscroll, _mouse_click +from .utils import _plot_raw_traces, _helper_raw_resize, _plot_raw_onkey from .evoked import _butterfly_on_button_press, _butterfly_onpick from ..defaults import _handle_default from ..io.meas_info import create_info @@ -545,106 +545,24 @@ def _plot_raw_components(ica, raw, exclude=None, title=None, duration=10.0, params['scale_factor'] = 1.0 params['plot_fun'] = partial(_plot_raw_traces, params=params, inds=inds, color=color, bad_color=bad_color) + params['update_fun'] = partial(_update_data, params) + params['pick_bads_fun'] = partial(_pick_bads, params=params) _layout_figure(params) # callbacks - callback_key = partial(_plot_onkey, params=params) + callback_key = partial(_plot_raw_onkey, params=params) params['fig'].canvas.mpl_connect('key_press_event', callback_key) callback_scroll = partial(_plot_raw_onscroll, params=params) params['fig'].canvas.mpl_connect('scroll_event', callback_scroll) callback_pick = partial(_mouse_click, params=params) params['fig'].canvas.mpl_connect('button_press_event', callback_pick) + callback_resize = partial(_helper_raw_resize, params=params) + params['fig'].canvas.mpl_connect('resize_event', callback_resize) params['fig_proj'] = None params['event_times'] = None params['plot_fun']() return params['fig'] -def _plot_onkey(event, params): - """Interpret key presses""" - import matplotlib.pyplot as plt - if event.key == 'escape': - plt.close(params['fig']) - elif event.key == 'down': - params['ch_start'] += params['n_channels'] - _channels_changed(params, len(params['info']['ch_names'])) - elif event.key == 'up': - params['ch_start'] -= params['n_channels'] - _channels_changed(params, len(params['info']['ch_names'])) - elif event.key == 'right': - value = params['t_start'] + params['duration'] - _plot_raw_time(value, params) - _update_data(params) - params['plot_fun']() - elif event.key == 'left': - value = params['t_start'] - params['duration'] - _plot_raw_time(value, params) - _update_data(params) - params['plot_fun']() - elif event.key in ['+', '=']: - params['scale_factor'] *= 1.1 - params['plot_fun']() - elif event.key == '-': - params['scale_factor'] /= 1.1 - params['plot_fun']() - elif event.key == 'pageup': - n_channels = params['n_channels'] + 1 - offset = params['ax'].get_ylim()[0] / n_channels - params['offsets'] = np.arange(n_channels) * offset + (offset / 2.) - params['n_channels'] = n_channels - params['ax'].set_yticks(params['offsets']) - params['vsel_patch'].set_height(n_channels) - _channels_changed(params, len(params['info']['ch_names'])) - elif event.key == 'pagedown': - n_channels = params['n_channels'] - 1 - if n_channels == 0: - return - offset = params['ax'].get_ylim()[0] / n_channels - params['offsets'] = np.arange(n_channels) * offset + (offset / 2.) - params['n_channels'] = n_channels - params['ax'].set_yticks(params['offsets']) - params['vsel_patch'].set_height(n_channels) - if len(params['lines']) > n_channels: # remove line from view - params['lines'][n_channels].set_xdata([]) - params['lines'][n_channels].set_ydata([]) - _channels_changed(params, len(params['info']['ch_names'])) - elif event.key == 'home': - duration = params['duration'] - 1.0 - if duration <= 0: - return - params['duration'] = duration - params['hsel_patch'].set_width(params['duration']) - _update_data(params) - params['plot_fun']() - elif event.key == 'end': - duration = params['duration'] + 1.0 - if duration > params['raw'].times[-1]: - duration = params['raw'].times[-1] - params['duration'] = duration - params['hsel_patch'].set_width(params['duration']) - _update_data(params) - params['plot_fun']() - elif event.key == 'f11': - mng = plt.get_current_fig_manager() - mng.full_screen_toggle() - - -def _mouse_click(event, params): - """Function for handling mouse clicks.""" - if event.inaxes is None or event.button != 1: - return - # vertical scrollbar changed - if event.inaxes == params['ax_vscroll']: - ch_start = max(int(event.ydata) - params['n_channels'] // 2, 0) - if params['ch_start'] != ch_start: - params['ch_start'] = ch_start - params['plot_fun']() - # horizontal scrollbar changed - elif event.inaxes == params['ax_hscroll']: - _plot_raw_time(event.xdata - params['duration'] / 2, params) - _update_data(params) - params['plot_fun']() - - def _update_data(params): """Function for preparing the data on horizontal shift of the viewport.""" sfreq = params['info']['sfreq'] @@ -652,3 +570,38 @@ def _update_data(params): end = int((params['t_start'] + params['duration']) * sfreq) params['data'] = params['orig_data'][:, start:end] params['times'] = params['raw'].times[start:end] + + +def _pick_bads(event, params): + """Method for selecting components on click.""" + bads = params['info']['bads'] + + # trade-off, avoid selecting more than one channel when drifts are present + # however for clean data don't click on peaks but on flat segments + def f(x, y): + return y(np.mean(x), x.std() * 2) + for l in event.inaxes.lines: + ydata = l.get_ydata() + if not isinstance(ydata, list) and not np.isnan(ydata).any(): + ymin, ymax = f(ydata, np.subtract), f(ydata, np.add) + if ymin <= event.ydata <= ymax: + this_chan = vars(l)['ch_name'] + if this_chan in params['info']['ch_names']: + if this_chan not in bads: + bads.append(this_chan) + l.set_color(params['bad_color']) + l.set_zorder(-1) + else: + bads.pop(bads.index(this_chan)) + l.set_color(vars(l)['def_color']) + l.set_zorder(0) + break + else: + x = np.array([event.xdata] * 2) + params['ax_vertline'].set_data(x, np.array(params['ax'].get_ylim())) + params['ax_hscroll_vertline'].set_data(x, np.array([0., 1.])) + params['vertline_t'].set_text('%0.3f' % x[0]) + # update deep-copied info to persistently draw bads + params['info']['bads'] = bads + params['update_fun']() + params['plot_fun']() diff --git a/mne/viz/raw.py b/mne/viz/raw.py index bd149997b8a..9d9137d4e57 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -15,11 +15,12 @@ from ..externals.six import string_types from ..io.pick import pick_types from ..io.proj import setup_proj -from ..utils import set_config, verbose +from ..utils import verbose from ..time_frequency import compute_raw_psd from .utils import _toggle_options, _toggle_proj, tight_layout -from .utils import _layout_figure, _prepare_mne_browse_raw, _channels_changed -from .utils import _plot_raw_onscroll, _plot_raw_time, _plot_raw_traces +from .utils import _layout_figure, _prepare_mne_browse_raw, _plot_raw_onkey +from .utils import _plot_raw_onscroll, _plot_raw_traces, _mouse_click +from .utils import _helper_raw_resize from ..defaults import _handle_default @@ -32,7 +33,7 @@ def _plot_update_raw_proj(params, bools): params['proj_bools'] = bools params['projector'], _ = setup_proj(params['info'], add_eeg_ref=False, verbose=False) - _update_raw_data(params) + params['update_fun']() params['plot_fun']() @@ -67,13 +68,6 @@ def _update_raw_data(params): params['times'] = times -def _helper_resize(event, params): - """Helper for resizing""" - size = ','.join([str(s) for s in params['fig'].get_size_inches()]) - set_config('MNE_BROWSE_RAW_SIZE', size) - _layout_figure(params) - - def _pick_bad_channels(event, params): """Helper for selecting / dropping bad channels onpick""" bads = params['raw'].info['bads'] @@ -108,115 +102,6 @@ def f(x, y): _plot_update_raw_proj(params, None) -def _mouse_click(event, params): - """Vertical select callback""" - if event.inaxes is None or event.button != 1: - return - # vertical scrollbar changed - if event.inaxes == params['ax_vscroll']: - ch_start = max(int(event.ydata) - params['n_channels'] // 2, 0) - if params['ch_start'] != ch_start: - params['ch_start'] = ch_start - params['plot_fun']() - # horizontal scrollbar changed - elif event.inaxes == params['ax_hscroll']: - _plot_raw_time(event.xdata - params['duration'] / 2, params) - _update_raw_data(params) - params['plot_fun']() - - elif event.inaxes == params['ax']: - _pick_bad_channels(event, params) - - -def _plot_raw_onkey(event, params): - """Interpret key presses""" - import matplotlib.pyplot as plt - # check for initial plot - if event is None: - params['plot_fun']() - return - - # quit event - if event.key == 'escape': - plt.close(params['fig']) - return - - # change plotting params - ch_changed = False - if event.key == 'down': - params['ch_start'] += params['n_channels'] - ch_changed = True - elif event.key == 'up': - params['ch_start'] -= params['n_channels'] - ch_changed = True - elif event.key == 'right': - _plot_raw_time(params['t_start'] + params['duration'], params) - _update_raw_data(params) - params['plot_fun']() - return - elif event.key == 'left': - _plot_raw_time(params['t_start'] - params['duration'], params) - _update_raw_data(params) - params['plot_fun']() - return - elif event.key in ['o', 'p']: - _toggle_options(None, params) - return - elif event.key in ['+', '=']: - params['scale_factor'] *= 1.1 - params['plot_fun']() - return - elif event.key == '-': - params['scale_factor'] /= 1.1 - params['plot_fun']() - return - elif event.key == 'pageup': - n_channels = params['n_channels'] + 1 - offset = params['ax'].get_ylim()[0] / n_channels - params['offsets'] = np.arange(n_channels) * offset + (offset / 2.) - params['n_channels'] = n_channels - params['ax'].set_yticks(params['offsets']) - params['vsel_patch'].set_height(n_channels) - ch_changed = True - elif event.key == 'pagedown': - n_channels = params['n_channels'] - 1 - if n_channels == 0: - return - offset = params['ax'].get_ylim()[0] / n_channels - params['offsets'] = np.arange(n_channels) * offset + (offset / 2.) - params['n_channels'] = n_channels - params['ax'].set_yticks(params['offsets']) - params['vsel_patch'].set_height(n_channels) - if len(params['lines']) > n_channels: # remove line from view - params['lines'][n_channels].set_xdata([]) - params['lines'][n_channels].set_ydata([]) - ch_changed = True - elif event.key == 'home': - duration = params['duration'] - 1.0 - if duration <= 0: - return - params['duration'] = duration - params['hsel_patch'].set_width(params['duration']) - _update_raw_data(params) - params['plot_fun']() - elif event.key == 'end': - duration = params['duration'] + 1.0 - if duration > params['raw'].times[-1]: - duration = params['raw'].times[-1] - params['duration'] = duration - params['hsel_patch'].set_width(params['duration']) - _update_raw_data(params) - params['plot_fun']() - elif event.key == 'f11': - mng = plt.get_current_fig_manager() - mng.full_screen_toggle() - return - # deal with plotting changes - if ch_changed: - len_channels = len(params['info']['ch_names']) - _channels_changed(params, len_channels) - - def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None, bgcolor='w', color=None, bad_color=(0.8, 0.8, 0.8), event_color='cyan', scalings=None, remove_dc=True, order='type', @@ -430,6 +315,8 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None, color=color, bad_color=bad_color, event_lines=event_lines, event_color=event_color) + params['update_fun'] = partial(_update_raw_data, params=params) + params['pick_bads_fun'] = partial(_pick_bad_channels, params=params) params['scale_factor'] = 1.0 # set up callbacks opt_button = None @@ -445,7 +332,7 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None, params['fig'].canvas.mpl_connect('scroll_event', callback_scroll) callback_pick = partial(_mouse_click, params=params) params['fig'].canvas.mpl_connect('button_press_event', callback_pick) - callback_resize = partial(_helper_resize, params=params) + callback_resize = partial(_helper_raw_resize, params=params) params['fig'].canvas.mpl_connect('resize_event', callback_resize) # As here code is shared with plot_evoked, some extra steps: diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 86ba9e1e348..75af8233d20 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -19,7 +19,7 @@ import numpy as np from ..io import show_fiff -from ..utils import verbose, get_config +from ..utils import verbose, get_config, set_config COLORS = ['b', 'g', 'r', 'c', 'm', 'y', 'k', '#473C8B', '#458B74', @@ -463,6 +463,13 @@ def _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, zorder=1)[0] +def _helper_raw_resize(event, params): + """Helper for resizing""" + size = ','.join([str(s) for s in params['fig'].get_size_inches()]) + set_config('MNE_BROWSE_RAW_SIZE', size) + _layout_figure(params) + + def _plot_raw_onscroll(event, params, len_channels=None): """Interpret scroll events""" if len_channels is None: @@ -499,6 +506,95 @@ def _plot_raw_time(value, params): params['hsel_patch'].set_x(value) +def _plot_raw_onkey(event, params): + """Interpret key presses""" + import matplotlib.pyplot as plt + if event.key == 'escape': + plt.close(params['fig']) + elif event.key == 'down': + params['ch_start'] += params['n_channels'] + _channels_changed(params, len(params['info']['ch_names'])) + elif event.key == 'up': + params['ch_start'] -= params['n_channels'] + _channels_changed(params, len(params['info']['ch_names'])) + elif event.key == 'right': + value = params['t_start'] + params['duration'] + _plot_raw_time(value, params) + params['update_fun']() + params['plot_fun']() + elif event.key == 'left': + value = params['t_start'] - params['duration'] + _plot_raw_time(value, params) + params['update_fun']() + params['plot_fun']() + elif event.key in ['+', '=']: + params['scale_factor'] *= 1.1 + params['plot_fun']() + elif event.key == '-': + params['scale_factor'] /= 1.1 + params['plot_fun']() + elif event.key == 'pageup': + n_channels = params['n_channels'] + 1 + offset = params['ax'].get_ylim()[0] / n_channels + params['offsets'] = np.arange(n_channels) * offset + (offset / 2.) + params['n_channels'] = n_channels + params['ax'].set_yticks(params['offsets']) + params['vsel_patch'].set_height(n_channels) + _channels_changed(params, len(params['info']['ch_names'])) + elif event.key == 'pagedown': + n_channels = params['n_channels'] - 1 + if n_channels == 0: + return + offset = params['ax'].get_ylim()[0] / n_channels + params['offsets'] = np.arange(n_channels) * offset + (offset / 2.) + params['n_channels'] = n_channels + params['ax'].set_yticks(params['offsets']) + params['vsel_patch'].set_height(n_channels) + if len(params['lines']) > n_channels: # remove line from view + params['lines'][n_channels].set_xdata([]) + params['lines'][n_channels].set_ydata([]) + _channels_changed(params, len(params['info']['ch_names'])) + elif event.key == 'home': + duration = params['duration'] - 1.0 + if duration <= 0: + return + params['duration'] = duration + params['hsel_patch'].set_width(params['duration']) + params['update_fun']() + params['plot_fun']() + elif event.key == 'end': + duration = params['duration'] + 1.0 + if duration > params['raw'].times[-1]: + duration = params['raw'].times[-1] + params['duration'] = duration + params['hsel_patch'].set_width(params['duration']) + params['update_fun']() + params['plot_fun']() + elif event.key == 'f11': + mng = plt.get_current_fig_manager() + mng.full_screen_toggle() + + +def _mouse_click(event, params): + """Vertical select callback""" + if event.inaxes is None or event.button != 1: + return + # vertical scrollbar changed + if event.inaxes == params['ax_vscroll']: + ch_start = max(int(event.ydata) - params['n_channels'] // 2, 0) + if params['ch_start'] != ch_start: + params['ch_start'] = ch_start + params['plot_fun']() + # horizontal scrollbar changed + elif event.inaxes == params['ax_hscroll']: + _plot_raw_time(event.xdata - params['duration'] / 2, params) + params['update_fun']() + params['plot_fun']() + + elif event.inaxes == params['ax']: + params['pick_bads_fun'](event) + + def _plot_raw_traces(params, inds, color, bad_color, event_lines=None, event_color=None): """Helper for plotting raw""" From 7d2ca2b79479b21cc7011cff4543486db555cb9f Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Thu, 2 Jul 2015 12:16:55 +0300 Subject: [PATCH 08/36] More refactoring. Bad channels shown on vertical scroll bar. --- mne/viz/ica.py | 30 ++---------------------------- mne/viz/raw.py | 31 ++----------------------------- mne/viz/utils.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 57 deletions(-) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index bdf4ab0b66e..5ffdaed8c2a 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -15,6 +15,7 @@ from .utils import tight_layout, _prepare_trellis, _prepare_mne_browse_raw from .utils import _layout_figure, _plot_raw_onscroll, _mouse_click from .utils import _plot_raw_traces, _helper_raw_resize, _plot_raw_onkey +from .utils import _select_bads from .evoked import _butterfly_on_button_press, _butterfly_onpick from ..defaults import _handle_default from ..io.meas_info import create_info @@ -575,33 +576,6 @@ def _update_data(params): def _pick_bads(event, params): """Method for selecting components on click.""" bads = params['info']['bads'] - - # trade-off, avoid selecting more than one channel when drifts are present - # however for clean data don't click on peaks but on flat segments - def f(x, y): - return y(np.mean(x), x.std() * 2) - for l in event.inaxes.lines: - ydata = l.get_ydata() - if not isinstance(ydata, list) and not np.isnan(ydata).any(): - ymin, ymax = f(ydata, np.subtract), f(ydata, np.add) - if ymin <= event.ydata <= ymax: - this_chan = vars(l)['ch_name'] - if this_chan in params['info']['ch_names']: - if this_chan not in bads: - bads.append(this_chan) - l.set_color(params['bad_color']) - l.set_zorder(-1) - else: - bads.pop(bads.index(this_chan)) - l.set_color(vars(l)['def_color']) - l.set_zorder(0) - break - else: - x = np.array([event.xdata] * 2) - params['ax_vertline'].set_data(x, np.array(params['ax'].get_ylim())) - params['ax_hscroll_vertline'].set_data(x, np.array([0., 1.])) - params['vertline_t'].set_text('%0.3f' % x[0]) - # update deep-copied info to persistently draw bads - params['info']['bads'] = bads + params['info']['bads'] = _select_bads(event, params, bads) params['update_fun']() params['plot_fun']() diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 9d9137d4e57..5d34c5244e9 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -20,7 +20,7 @@ from .utils import _toggle_options, _toggle_proj, tight_layout from .utils import _layout_figure, _prepare_mne_browse_raw, _plot_raw_onkey from .utils import _plot_raw_onscroll, _plot_raw_traces, _mouse_click -from .utils import _helper_raw_resize +from .utils import _helper_raw_resize, _select_bads from ..defaults import _handle_default @@ -71,34 +71,7 @@ def _update_raw_data(params): def _pick_bad_channels(event, params): """Helper for selecting / dropping bad channels onpick""" bads = params['raw'].info['bads'] - - # trade-off, avoid selecting more than one channel when drifts are present - # however for clean data don't click on peaks but on flat segments - def f(x, y): - return y(np.mean(x), x.std() * 2) - for l in event.inaxes.lines: - ydata = l.get_ydata() - if not isinstance(ydata, list) and not np.isnan(ydata).any(): - ymin, ymax = f(ydata, np.subtract), f(ydata, np.add) - if ymin <= event.ydata <= ymax: - this_chan = vars(l)['ch_name'] - if this_chan in params['raw'].ch_names: - if this_chan not in bads: - bads.append(this_chan) - l.set_color(params['bad_color']) - l.set_zorder(-1) - else: - bads.pop(bads.index(this_chan)) - l.set_color(vars(l)['def_color']) - l.set_zorder(0) - break - else: - x = np.array([event.xdata] * 2) - params['ax_vertline'].set_data(x, np.array(params['ax'].get_ylim())) - params['ax_hscroll_vertline'].set_data(x, np.array([0., 1.])) - params['vertline_t'].set_text('%0.3f' % x[0]) - # update deep-copied info to persistently draw bads - params['info']['bads'] = bads + params['info']['bads'] = _select_bads(event, params, bads) _plot_update_raw_proj(params, None) diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 75af8233d20..9a337f5485a 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -595,6 +595,40 @@ def _mouse_click(event, params): params['pick_bads_fun'](event) +def _select_bads(event, params, bads): + """Helper for selecting bad channels onpick. Returns updated bads list.""" + # trade-off, avoid selecting more than one channel when drifts are present + # however for clean data don't click on peaks but on flat segments + def f(x, y): + return y(np.mean(x), x.std() * 2) + lines = event.inaxes.lines + for line in lines: + ydata = line.get_ydata() + if not isinstance(ydata, list) and not np.isnan(ydata).any(): + ymin, ymax = f(ydata, np.subtract), f(ydata, np.add) + if ymin <= event.ydata <= ymax: + this_chan = vars(line)['ch_name'] + if this_chan in params['info']['ch_names']: + ch_idx = params['ch_start'] + lines.index(line) + if this_chan not in bads: + bads.append(this_chan) + color = params['bad_color'] + line.set_zorder(-1) + else: + bads.pop(bads.index(this_chan)) + color = vars(line)['def_color'] + line.set_zorder(0) + line.set_color(color) + params['ax_vscroll'].patches[ch_idx].set_color(color) + break + else: + x = np.array([event.xdata] * 2) + params['ax_vertline'].set_data(x, np.array(params['ax'].get_ylim())) + params['ax_hscroll_vertline'].set_data(x, np.array([0., 1.])) + params['vertline_t'].set_text('%0.3f' % x[0]) + return bads + + def _plot_raw_traces(params, inds, color, bad_color, event_lines=None, event_color=None): """Helper for plotting raw""" From be345c3bd7cdfa2b933c7d8aedbaeb3136a1a987 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Thu, 2 Jul 2015 12:59:47 +0300 Subject: [PATCH 09/36] Interactive exclusion. --- mne/preprocessing/ica.py | 18 +++++++++++++++--- mne/viz/ica.py | 27 +++++++++++++++++++++------ 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 9c3d116f5e7..006cffa80c4 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -1574,7 +1574,7 @@ def _check_n_pca_components(self, _n_pca_comp, verbose=None): def plot_raw_components(self, raw, exclude=None, title=None, duration=10.0, n_channels=20, bgcolor='w', color=(0., 0., 0.), - bad_color=(1., 0., 0.)): + bad_color=(1., 0., 0.), show=True, block=False): """Plot ICA components Parameters @@ -1598,14 +1598,26 @@ def plot_raw_components(self, raw, exclude=None, title=None, duration=10.0, bad_color : color object Color to use for components marked as bad. Defaults to (1., 0., 0.) (red). + show : bool + Show figures if True. Defaults to True. + block : bool + Whether to halt program execution until the figure is closed. + Useful for selecting components for exclusion on the fly + (click on line). May not work on all systems / platforms. + Defaults to False. Returns ------- fig : Instance of matplotlib.figure.Figure The figure. + exclude : list + Updated list of components marked for exclusion. """ - _plot_raw_components(self, raw, exclude, title, duration, - n_channels, bgcolor, color, bad_color) + return _plot_raw_components(self, raw, exclude=exclude, title=title, + duration=duration, n_channels=n_channels, + bgcolor=bgcolor, color=color, + bad_color=bad_color, show=show, + block=block) def _check_start_stop(raw, start, stop): diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 5ffdaed8c2a..7cde5a043c0 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -519,8 +519,9 @@ def _plot_ica_overlay_evoked(evoked, evoked_cln, title, show): def _plot_raw_components(ica, raw, exclude=None, title=None, duration=10.0, n_channels=20, bgcolor='w', color=(0., 0., 0.), - bad_color=(1., 0., 0.)): - """Helper function for plotting the ICA components as raw array.""" + bad_color=(1., 0., 0.), show=True, block=False): + """Function for plotting the ICA components as raw array.""" + import matplotlib.pyplot as plt color = _handle_default('color', color) scalings = {'misc': 0.2} orig_data = ica._transform_raw(raw, 0, len(raw.times)) * scalings['misc'] @@ -533,14 +534,14 @@ def _plot_raw_components(ica, raw, exclude=None, title=None, duration=10.0, info = create_info(c_names, raw.info['sfreq']) if exclude is None: - exclude = list() # TODO -> ica.exclude + exclude = list() info['bads'] = [c_names[x] for x in exclude] t_end = int(duration * raw.info['sfreq']) times = raw.times[0:t_end] params = dict(raw=raw, orig_data=orig_data, data=orig_data[:, 0:t_end], ch_start=0, t_start=0, info=info, duration=duration, n_channels=n_channels, times=times, types=types, - n_times=raw.n_times, bad_color=bad_color) + n_times=raw.n_times, bad_color=bad_color, exclude=exclude) _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, n_channels) params['scale_factor'] = 1.0 @@ -558,10 +559,18 @@ def _plot_raw_components(ica, raw, exclude=None, title=None, duration=10.0, params['fig'].canvas.mpl_connect('button_press_event', callback_pick) callback_resize = partial(_helper_raw_resize, params=params) params['fig'].canvas.mpl_connect('resize_event', callback_resize) + callback_close = partial(_close_event, params=params) + params['fig'].canvas.mpl_connect('close_event', callback_close) params['fig_proj'] = None params['event_times'] = None params['plot_fun']() - return params['fig'] + if show: + try: + plt.show(block=block) + except TypeError: # not all versions have this + plt.show() + + return params['fig'], params['exclude'] def _update_data(params): @@ -574,8 +583,14 @@ def _update_data(params): def _pick_bads(event, params): - """Method for selecting components on click.""" + """Function for selecting components on click.""" bads = params['info']['bads'] params['info']['bads'] = _select_bads(event, params, bads) params['update_fun']() params['plot_fun']() + + +def _close_event(events, params): + """Function for updating the list of excluded components.""" + info = params['info'] + params['exclude'] = [info['ch_names'].index(x) for x in info['bads']] From e8710b89366632a1a58d6bebeefd02932c139d3e Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Thu, 2 Jul 2015 13:39:45 +0300 Subject: [PATCH 10/36] Added test for plot_raw_components. --- mne/viz/tests/test_ica.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/mne/viz/tests/test_ica.py b/mne/viz/tests/test_ica.py index cf612ec507c..ea99bc108b7 100644 --- a/mne/viz/tests/test_ica.py +++ b/mne/viz/tests/test_ica.py @@ -143,4 +143,31 @@ def test_plot_ica_scores(): plt.close('all') +@requires_sklearn +def test_plot_raw_components(): + """Test plotting of raw components.""" + import matplotlib.pyplot as plt + raw = _get_raw() + picks = _get_picks(raw) + ica = ICA(noise_cov=read_cov(cov_fname), n_components=2, + max_pca_components=3, n_pca_components=3) + ica.fit(raw, picks=picks) + fig, _ = ica.plot_raw_components(raw, exclude=[0], title='Components') + fig.canvas.key_press_event('down') + fig.canvas.key_press_event('up') + fig.canvas.key_press_event('right') + fig.canvas.key_press_event('left') + fig.canvas.key_press_event('o') + fig.canvas.key_press_event('-') + fig.canvas.key_press_event('+') + fig.canvas.key_press_event('=') + fig.canvas.key_press_event('pageup') + fig.canvas.key_press_event('pagedown') + fig.canvas.key_press_event('home') + fig.canvas.key_press_event('end') + fig.canvas.key_press_event('f11') + fig.canvas.key_press_event('escape') + plt.close('all') + + run_tests_if_main() From ceb4a587d1ffba63883f434cc986c868b172e586 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Fri, 3 Jul 2015 11:14:09 +0300 Subject: [PATCH 11/36] Components added to ice.exclude. Docs. --- mne/preprocessing/ica.py | 10 +++++++--- mne/viz/ica.py | 14 +++++++------- mne/viz/tests/test_ica.py | 2 +- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 006cffa80c4..0bccab6f65f 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -1575,7 +1575,7 @@ def _check_n_pca_components(self, _n_pca_comp, verbose=None): def plot_raw_components(self, raw, exclude=None, title=None, duration=10.0, n_channels=20, bgcolor='w', color=(0., 0., 0.), bad_color=(1., 0., 0.), show=True, block=False): - """Plot ICA components + """Plot ICA components. Parameters ---------- @@ -1610,8 +1610,12 @@ def plot_raw_components(self, raw, exclude=None, title=None, duration=10.0, ------- fig : Instance of matplotlib.figure.Figure The figure. - exclude : list - Updated list of components marked for exclusion. + + Notes + ----- + To mark or un-mark a component for exclusion, click on the rather flat + segments of a channel's time series. The changes will be reflected + immediately in the ica object's ``ica.exclude`` entry. """ return _plot_raw_components(self, raw, exclude=exclude, title=title, duration=duration, n_channels=n_channels, diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 7cde5a043c0..f7cdb1feb7e 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -523,8 +523,7 @@ def _plot_raw_components(ica, raw, exclude=None, title=None, duration=10.0, """Function for plotting the ICA components as raw array.""" import matplotlib.pyplot as plt color = _handle_default('color', color) - scalings = {'misc': 0.2} - orig_data = ica._transform_raw(raw, 0, len(raw.times)) * scalings['misc'] + orig_data = ica._transform_raw(raw, 0, len(raw.times)) * 0.2 inds = range(len(orig_data)) types = np.repeat('misc', len(inds)) @@ -539,9 +538,9 @@ def _plot_raw_components(ica, raw, exclude=None, title=None, duration=10.0, t_end = int(duration * raw.info['sfreq']) times = raw.times[0:t_end] params = dict(raw=raw, orig_data=orig_data, data=orig_data[:, 0:t_end], - ch_start=0, t_start=0, info=info, duration=duration, + ch_start=0, t_start=0, info=info, duration=duration, ica=ica, n_channels=n_channels, times=times, types=types, - n_times=raw.n_times, bad_color=bad_color, exclude=exclude) + n_times=raw.n_times, bad_color=bad_color) _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, n_channels) params['scale_factor'] = 1.0 @@ -570,7 +569,7 @@ def _plot_raw_components(ica, raw, exclude=None, title=None, duration=10.0, except TypeError: # not all versions have this plt.show() - return params['fig'], params['exclude'] + return params['fig'] def _update_data(params): @@ -591,6 +590,7 @@ def _pick_bads(event, params): def _close_event(events, params): - """Function for updating the list of excluded components.""" + """Function for excluding the selected components on close.""" info = params['info'] - params['exclude'] = [info['ch_names'].index(x) for x in info['bads']] + exclude = [info['ch_names'].index(x) for x in info['bads']] + params['ica'].exclude = exclude diff --git a/mne/viz/tests/test_ica.py b/mne/viz/tests/test_ica.py index ea99bc108b7..2014cf1e2fe 100644 --- a/mne/viz/tests/test_ica.py +++ b/mne/viz/tests/test_ica.py @@ -152,7 +152,7 @@ def test_plot_raw_components(): ica = ICA(noise_cov=read_cov(cov_fname), n_components=2, max_pca_components=3, n_pca_components=3) ica.fit(raw, picks=picks) - fig, _ = ica.plot_raw_components(raw, exclude=[0], title='Components') + fig = ica.plot_raw_components(raw, exclude=[0], title='Components') fig.canvas.key_press_event('down') fig.canvas.key_press_event('up') fig.canvas.key_press_event('right') From 679f8cca09aeba06b3d07c6408411cccd7ad490f Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Mon, 6 Jul 2015 13:13:49 +0300 Subject: [PATCH 12/36] ICA plotter for epochs. --- mne/preprocessing/ica.py | 54 +++++++- mne/viz/__init__.py | 2 +- mne/viz/epochs.py | 285 +++++++++++++++++++++------------------ mne/viz/ica.py | 43 ++++++ 4 files changed, 249 insertions(+), 135 deletions(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 0bccab6f65f..407da00afd2 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -33,7 +33,7 @@ from ..io.base import _BaseRaw from ..epochs import _BaseEpochs from ..viz import (plot_ica_components, plot_ica_scores, _plot_raw_components, - plot_ica_sources, plot_ica_overlay) + plot_ica_sources, plot_ica_overlay, _plot_epoch_components) from ..channels.channels import _contains_ch_type, ContainsMixin from ..io.write import start_file, end_file, write_id from ..utils import (check_sklearn_version, logger, check_fname, verbose, @@ -1623,6 +1623,58 @@ def plot_raw_components(self, raw, exclude=None, title=None, duration=10.0, bad_color=bad_color, show=show, block=block) + def plot_epoch_components(self, epochs, exclude=None, title=None, + n_epochs=10, n_channels=20, bgcolor='w', + color=(0., 0., 0.), bad_color=(1., 0., 0.), + show=True, block=False): + """Plot ICA components. + + Parameters + ---------- + epochs : instance of Epochs + Epochs object to draw sources from. + exclude : array_like of int | None + The components marked for exclusion. If None (default), ICA.exclude + will be used. + title : str + Title for the plot. If None, ``ICA components`` is displayed. + Defaults to None + n_epochs : int + Number of epoch per view. Defaults to 10. + n_channels : int + The number of channels per view. Defaults to 20. + bgcolor : color object + Color of the background. + color : color object + Color for the data traces. Defaults to (0., 0., 0.) (black). + bad_color : color object + Color to use for epochs marked as bad. + Defaults to (1., 0., 0.) (red). + show : bool + Show figures if True. Defaults to True. + block : bool + Whether to halt program execution until the figure is closed. + Useful for selecting components for exclusion on the fly + (click on line). May not work on all systems / platforms. + Defaults to False. + + Returns + ------- + fig : Instance of matplotlib.figure.Figure + The figure. + + Notes + ----- + To mark or un-mark a component for exclusion, click on the component + name left of the main axes. The changes will be reflected + immediately in the ica object's ``ica.exclude`` entry. + """ + return _plot_epoch_components(self, epochs, exclude=exclude, + title=title, n_epochs=n_epochs, + n_channels=n_channels, bgcolor=bgcolor, + color=color, bad_color=bad_color, + show=show, block=block) + def _check_start_stop(raw, start, stop): """Aux function""" diff --git a/mne/viz/__init__.py b/mne/viz/__init__.py index 968f235e303..d6475014dce 100644 --- a/mne/viz/__init__.py +++ b/mne/viz/__init__.py @@ -19,6 +19,6 @@ plot_epochs_trellis, _drop_log_stats, plot_epochs_psd) from .raw import plot_raw, plot_raw_psd from .ica import plot_ica_scores, plot_ica_sources, plot_ica_overlay -from .ica import _plot_raw_components +from .ica import _plot_raw_components, _plot_epoch_components from .montage import plot_montage from .decoding import plot_gat_matrix, plot_gat_times diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index 6bd155e1587..1a16e558ad2 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -499,29 +499,143 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, vertical line to the plot. """ import matplotlib.pyplot as plt + scalings = _handle_default('scalings_plot_raw', scalings) + + projs = epochs.info['projs'] + + params = {'epochs': epochs, + 'orig_data': np.concatenate(epochs.get_data(), axis=1), + 'info': copy.deepcopy(epochs.info)} + _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings, + title, picks) + + callback_close = partial(_close_event, params=params) + params['fig'].canvas.mpl_connect('close_event', callback_close) + if show: + try: + plt.show(block=block) + except TypeError: # not all versions have this + plt.show() + + return params['fig'] + + +@verbose +def plot_epochs_psd(epochs, fmin=0, fmax=np.inf, proj=False, n_fft=256, + picks=None, ax=None, color='black', area_mode='std', + area_alpha=0.33, n_overlap=0, + dB=True, n_jobs=1, show=True, verbose=None): + """Plot the power spectral density across epochs + + Parameters + ---------- + epochs : instance of Epochs + The epochs object + fmin : float + Start frequency to consider. + fmax : float + End frequency to consider. + proj : bool + Apply projection. + n_fft : int + Number of points to use in Welch FFT calculations. + picks : array-like of int | None + List of channels to use. + ax : instance of matplotlib Axes | None + Axes to plot into. If None, axes will be created. + color : str | tuple + A matplotlib-compatible color to use. + area_mode : str | None + Mode for plotting area. If 'std', the mean +/- 1 STD (across channels) + will be plotted. If 'range', the min and max (across channels) will be + plotted. Bad channels will be excluded from these calculations. + If None, no area will be plotted. + area_alpha : float + Alpha for the area. + n_overlap : int + The number of points of overlap between blocks. + dB : bool + If True, transform data to decibels. + n_jobs : int + Number of jobs to run in parallel. + show : bool + Show figure if True. + verbose : bool, str, int, or None + If not None, override default verbose level (see mne.verbose). + + Returns + ------- + fig : instance of matplotlib figure + Figure distributing one image per channel across sensor topography. + """ + import matplotlib.pyplot as plt + from .raw import _set_psd_plot_params + fig, picks_list, titles_list, ax_list, make_label = _set_psd_plot_params( + epochs.info, proj, picks, ax, area_mode) + + for ii, (picks, title, ax) in enumerate(zip(picks_list, titles_list, + ax_list)): + psds, freqs = compute_epochs_psd(epochs, picks=picks, fmin=fmin, + fmax=fmax, n_fft=n_fft, + n_overlap=n_overlap, proj=proj, + n_jobs=n_jobs) + + # Convert PSDs to dB + if dB: + psds = 10 * np.log10(psds) + unit = 'dB' + else: + unit = 'power' + # mean across epochs and channels + psd_mean = np.mean(psds, axis=0).mean(axis=0) + if area_mode == 'std': + # std across channels + psd_std = np.std(np.mean(psds, axis=0), axis=0) + hyp_limits = (psd_mean - psd_std, psd_mean + psd_std) + elif area_mode == 'range': + hyp_limits = (np.min(np.mean(psds, axis=0), axis=0), + np.max(np.mean(psds, axis=0), axis=0)) + else: # area_mode is None + hyp_limits = None + + ax.plot(freqs, psd_mean, color=color) + if hyp_limits is not None: + ax.fill_between(freqs, hyp_limits[0], y2=hyp_limits[1], + color=color, alpha=area_alpha) + if make_label: + if ii == len(picks_list) - 1: + ax.set_xlabel('Freq (Hz)') + if ii == len(picks_list) // 2: + ax.set_ylabel('Power Spectral Density (%s/Hz)' % unit) + ax.set_title(title) + ax.set_xlim(freqs[0], freqs[-1]) + if make_label: + tight_layout(pad=0.1, h_pad=0.1, w_pad=0.1, fig=fig) + if show: + plt.show() + return fig + + +def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings, + title, picks): + """Helper for setting up the mne_browse_epochs window.""" + import matplotlib.pyplot as plt import matplotlib as mpl from matplotlib.collections import LineCollection from matplotlib.colors import colorConverter - scalings = _handle_default('scalings_plot_raw', scalings) - color = _handle_default('color', None) - bad_color = (0.8, 0.8, 0.8) + epochs = params['epochs'] + if picks is None: picks = _handle_picks(epochs) if len(picks) < 1: raise RuntimeError('No appropriate channels found. Please' ' check your picks') - epoch_data = epochs.get_data() - - n_epochs = min(n_epochs, len(epochs.events)) - duration = len(epochs.times) * n_epochs - n_channels = min(n_channels, len(picks)) - projs = epochs.info['projs'] # Reorganize channels inds = list() types = list() for t in ['grad', 'mag']: - idxs = pick_types(epochs.info, meg=t, ref_meg=False, exclude=[]) + idxs = pick_types(params['info'], meg=t, ref_meg=False, exclude=[]) if len(idxs) < 1: continue mask = _in1d(idxs, picks, assume_unique=True) @@ -531,7 +645,7 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, for ch_type in ['eeg', 'eog', 'ecg', 'emg', 'ref_meg', 'stim', 'resp', 'misc', 'chpi', 'syst', 'ias', 'exci']: pick_kwargs[ch_type] = True - idxs = pick_types(epochs.info, **pick_kwargs) + idxs = pick_types(params['info'], **pick_kwargs) if len(idxs) < 1: continue mask = _in1d(idxs, picks, assume_unique=True) @@ -542,9 +656,14 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, if not len(inds) == len(picks): raise RuntimeError('Some channels not classified. Please' ' check your picks') + ch_names = [params['info']['ch_names'][x] for x in inds] # set up plotting + size = get_config('MNE_BROWSE_RAW_SIZE') + n_epochs = min(n_epochs, len(epochs.events)) + duration = len(epochs.times) * n_epochs + n_channels = min(n_channels, len(picks)) if size is not None: size = size.split(',') size = tuple(float(s) for s in size) @@ -559,7 +678,8 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, ax.annotate(title, xy=(0.5, 1), xytext=(0, ax.get_ylim()[1] + 15), ha='center', va='bottom', size=12, xycoords='axes fraction', textcoords='offset points') - + color = _handle_default('color', None) + bad_color = (0.8, 0.8, 0.8) ax.axis([0, duration, 0, 200]) ax2 = ax.twiny() ax2.set_zorder(-1) @@ -578,9 +698,13 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, # populate vertical and horizontal scrollbars for ci in range(len(picks)): + if ch_names[ci] in params['info']['bads']: + this_color = bad_color + else: + this_color = color[types[ci]] ax_vscroll.add_patch(mpl.patches.Rectangle((0, ci), 1, 1, - facecolor=color[types[ci]], - edgecolor=color[types[ci]], + facecolor=this_color, + edgecolor=this_color, zorder=3)) vsel_patch = mpl.patches.Rectangle((0, 0), 1, n_channels, alpha=0.5, @@ -591,7 +715,6 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, ax_vscroll.set_title('Ch.') # populate colors list - ch_names = [epochs.info['ch_names'][x] for x in inds] type_colors = [colorConverter.to_rgba(color[c]) for c in types] colors = list() for color_idx in range(len(type_colors)): @@ -607,9 +730,8 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, ax.add_collection(lc) lines.append(lc) - epoch_data = np.concatenate(epoch_data, axis=1) times = epochs.times - data = np.zeros((epochs.info['nchan'], len(times) * len(epochs.events))) + data = np.zeros((params['info']['nchan'], len(times) * len(epochs.events))) ylim = (25., 0.) # make shells for plotting traces @@ -703,6 +825,7 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, opt_button = mpl.widgets.Button(ax_button, 'Proj') callback_option = partial(_toggle_options, params=params) opt_button.on_clicked(callback_option) + params['opt_button'] = opt_button params['ax_button'] = ax_button # callbacks @@ -712,8 +835,6 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, fig.canvas.mpl_connect('button_press_event', callback_click) callback_key = partial(_plot_onkey, params=params) fig.canvas.mpl_connect('key_press_event', callback_key) - callback_close = partial(_close_event, params=params) - fig.canvas.mpl_connect('close_event', callback_close) callback_resize = partial(_resize_event, params=params) fig.canvas.mpl_connect('resize_event', callback_resize) fig.canvas.mpl_connect('pick_event', partial(_onpick, params=params)) @@ -733,110 +854,6 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, callback_proj('none') _layout_figure(params) - if show: - try: - plt.show(block=block) - except TypeError: # not all versions have this - plt.show() - - return fig - - -@verbose -def plot_epochs_psd(epochs, fmin=0, fmax=np.inf, proj=False, n_fft=256, - picks=None, ax=None, color='black', area_mode='std', - area_alpha=0.33, n_overlap=0, - dB=True, n_jobs=1, show=True, verbose=None): - """Plot the power spectral density across epochs - - Parameters - ---------- - epochs : instance of Epochs - The epochs object - fmin : float - Start frequency to consider. - fmax : float - End frequency to consider. - proj : bool - Apply projection. - n_fft : int - Number of points to use in Welch FFT calculations. - picks : array-like of int | None - List of channels to use. - ax : instance of matplotlib Axes | None - Axes to plot into. If None, axes will be created. - color : str | tuple - A matplotlib-compatible color to use. - area_mode : str | None - Mode for plotting area. If 'std', the mean +/- 1 STD (across channels) - will be plotted. If 'range', the min and max (across channels) will be - plotted. Bad channels will be excluded from these calculations. - If None, no area will be plotted. - area_alpha : float - Alpha for the area. - n_overlap : int - The number of points of overlap between blocks. - dB : bool - If True, transform data to decibels. - n_jobs : int - Number of jobs to run in parallel. - show : bool - Show figure if True. - verbose : bool, str, int, or None - If not None, override default verbose level (see mne.verbose). - - Returns - ------- - fig : instance of matplotlib figure - Figure distributing one image per channel across sensor topography. - """ - import matplotlib.pyplot as plt - from .raw import _set_psd_plot_params - fig, picks_list, titles_list, ax_list, make_label = _set_psd_plot_params( - epochs.info, proj, picks, ax, area_mode) - - for ii, (picks, title, ax) in enumerate(zip(picks_list, titles_list, - ax_list)): - psds, freqs = compute_epochs_psd(epochs, picks=picks, fmin=fmin, - fmax=fmax, n_fft=n_fft, - n_overlap=n_overlap, proj=proj, - n_jobs=n_jobs) - - # Convert PSDs to dB - if dB: - psds = 10 * np.log10(psds) - unit = 'dB' - else: - unit = 'power' - # mean across epochs and channels - psd_mean = np.mean(psds, axis=0).mean(axis=0) - if area_mode == 'std': - # std across channels - psd_std = np.std(np.mean(psds, axis=0), axis=0) - hyp_limits = (psd_mean - psd_std, psd_mean + psd_std) - elif area_mode == 'range': - hyp_limits = (np.min(np.mean(psds, axis=0), axis=0), - np.max(np.mean(psds, axis=0), axis=0)) - else: # area_mode is None - hyp_limits = None - - ax.plot(freqs, psd_mean, color=color) - if hyp_limits is not None: - ax.fill_between(freqs, hyp_limits[0], y2=hyp_limits[1], - color=color, alpha=area_alpha) - if make_label: - if ii == len(picks_list) - 1: - ax.set_xlabel('Freq (Hz)') - if ii == len(picks_list) // 2: - ax.set_ylabel('Power Spectral Density (%s/Hz)' % unit) - ax.set_title(title) - ax.set_xlim(freqs[0], freqs[-1]) - if make_label: - tight_layout(pad=0.1, h_pad=0.1, w_pad=0.1, fig=fig) - if show: - plt.show() - return fig - def _plot_traces(params): """ Helper for plotting concatenated epochs """ @@ -872,16 +889,16 @@ def _plot_traces(params): break elif ch_idx < len(params['ch_names']): if butterfly: - type = params['types'][ch_idx] - if type == 'grad': + ch_type = params['types'][ch_idx] + if ch_type == 'grad': offset = offsets[0] - elif type == 'mag': + elif ch_type == 'mag': offset = offsets[1] - elif type == 'eeg': + elif ch_type == 'eeg': offset = offsets[2] - elif type == 'eog': + elif ch_type == 'eog': offset = offsets[3] - elif type == 'ecg': + elif ch_type == 'ecg': offset = offsets[4] else: lines[line_idx].set_segments(list()) @@ -898,7 +915,7 @@ def _plot_traces(params): segments = np.split(np.array((xdata, ydata)).T, num_epochs) ch_name = params['ch_names'][ch_idx] - if ch_name in params['epochs'].info['bads']: + if ch_name in params['info']['bads']: if not butterfly: this_color = params['bad_color'] ylabels[line_idx].set_color(this_color) @@ -1084,6 +1101,7 @@ def _mouse_click(event, params): if event.inaxes is None: if params['butterfly'] or not params['settings'][0]: return + # Bad channel selected ax = params['ax'] ylim = ax.get_ylim() pos = ax.transData.inverted().transform((event.x, event.y)) @@ -1094,12 +1112,12 @@ def _mouse_click(event, params): line_idx = np.searchsorted(offsets, pos[1]) text = labels[line_idx].get_text() ch_idx = params['ch_start'] + line_idx - if text in params['epochs'].info['bads']: - params['epochs'].info['bads'].remove(text) + if text in params['info']['bads']: + params['info']['bads'].remove(text) color = params['def_colors'][ch_idx] params['ax_vscroll'].patches[ch_idx + 1].set_color(color) else: - params['epochs'].info['bads'].append(text) + params['info']['bads'].append(text) color = params['bad_color'] params['ax_vscroll'].patches[ch_idx + 1].set_color(color) params['plot_fun']() @@ -1370,6 +1388,7 @@ def _close_event(event, params): """Function to drop selected bad epochs. Called on closing of the plot.""" params['epochs'].drop_epochs(params['bads']) logger.info('Channels marked as bad: %s' % params['epochs'].info['bads']) + params['epochs'].info['bads'] = params['info']['bads'] def _resize_event(event, params): diff --git a/mne/viz/ica.py b/mne/viz/ica.py index f7cdb1feb7e..aa2bca363a6 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -16,6 +16,7 @@ from .utils import _layout_figure, _plot_raw_onscroll, _mouse_click from .utils import _plot_raw_traces, _helper_raw_resize, _plot_raw_onkey from .utils import _select_bads +from .epochs import _prepare_mne_browse_epochs from .evoked import _butterfly_on_button_press, _butterfly_onpick from ..defaults import _handle_default from ..io.meas_info import create_info @@ -594,3 +595,45 @@ def _close_event(events, params): info = params['info'] exclude = [info['ch_names'].index(x) for x in info['bads']] params['ica'].exclude = exclude + + +def _plot_epoch_components(ica, epochs, exclude=None, title=None, n_epochs=20, + n_channels=20, bgcolor='w', color=(0., 0., 0.), + bad_color=(1., 0., 0.), show=True, block=False): + """Function for plotting the components as epochs.""" + import matplotlib.pyplot as plt + data = ica._transform_epochs(epochs, concatenate=True) + inds = range(ica.n_components_) + c_names = ['ICA ' + str(x + 1) for x in range(ica.n_components_)] + scalings = {'misc': 2.0} + info = create_info(ch_names=c_names, sfreq=epochs.info['sfreq']) + info['projs'] = list() + if exclude is None: + exclude = ica.exclude + else: + exclude += ica.exclude + info['bads'] = [c_names[x] for x in exclude] + params = {'ica': ica, + 'epochs': epochs, + 'info': info, + 'orig_data': data, + 'bads': list()} + _prepare_mne_browse_epochs(params, projs=list(), n_channels=n_channels, + n_epochs=n_epochs, scalings=scalings, + title=title, picks=inds) + callback_close = partial(_close_epochs_event, params=params) + params['fig'].canvas.mpl_connect('close_event', callback_close) + if show: + try: + plt.show(block=block) + except TypeError: # not all versions have this + plt.show() + + return params['fig'] + + +def _close_epochs_event(events, params): + """Function for excluding the selected components on close.""" + info = params['info'] + exclude = [info['ch_names'].index(x) for x in info['bads']] + params['ica'].exclude = exclude From c73b5409d7b21e4b255b7927cf499c2a2ebfa246 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Mon, 6 Jul 2015 13:27:35 +0300 Subject: [PATCH 13/36] Updated test. --- mne/viz/tests/test_epochs.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mne/viz/tests/test_epochs.py b/mne/viz/tests/test_epochs.py index 484b9faac7e..3f341c2d50e 100644 --- a/mne/viz/tests/test_epochs.py +++ b/mne/viz/tests/test_epochs.py @@ -128,6 +128,8 @@ def test_plot_epochs(): fig.canvas.resize_event() fig.canvas.close_event() # closing and epoch dropping plt.close('all') + assert_raises(RuntimeError, epochs.plot, picks=[], trellis=False) + plt.close('all') with warnings.catch_warnings(record=True): fig = epochs.plot(trellis=False) # test mouse clicks @@ -144,9 +146,6 @@ def test_plot_epochs(): assert(n_epochs - 1 == len(epochs)) plt.close('all') - assert_raises(RuntimeError, epochs.plot, picks=[], trellis=False) - plt.close('all') - def test_plot_image_epochs(): """Test plotting of epochs image From 318fe9e42947f22f7f4850c36ab41d6cdd6dfad3 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Mon, 6 Jul 2015 14:14:04 +0300 Subject: [PATCH 14/36] Disabled bad epoch selection for ica. Color for excluded components. --- mne/viz/epochs.py | 45 +++++++++++++++++++++++++++------------------ mne/viz/ica.py | 3 ++- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index 1a16e558ad2..3bcd6ea2236 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -505,7 +505,8 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, params = {'epochs': epochs, 'orig_data': np.concatenate(epochs.get_data(), axis=1), - 'info': copy.deepcopy(epochs.info)} + 'info': copy.deepcopy(epochs.info), + 'bad_color': (0.8, 0.8, 0.8)} _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings, title, picks) @@ -679,7 +680,7 @@ def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings, ha='center', va='bottom', size=12, xycoords='axes fraction', textcoords='offset points') color = _handle_default('color', None) - bad_color = (0.8, 0.8, 0.8) + ax.axis([0, duration, 0, 200]) ax2 = ax.twiny() ax2.set_zorder(-1) @@ -699,7 +700,7 @@ def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings, # populate vertical and horizontal scrollbars for ci in range(len(picks)): if ch_names[ci] in params['info']['bads']: - this_color = bad_color + this_color = params['bad_color'] else: this_color = color[types[ci]] ax_vscroll.add_patch(mpl.patches.Rectangle((0, ci), 1, 1, @@ -1055,6 +1056,10 @@ def _plot_vert_lines(params): def _pick_bad_epochs(event, params): """Helper for selecting / dropping bad epochs""" + if 'ica' in params: + pos = (event.xdata, event.ydata) + _pick_bad_channels(pos, params) + return n_times = len(params['epochs'].times) start_idx = int(params['t_start'] / n_times) xdata = event.xdata @@ -1082,6 +1087,24 @@ def _pick_bad_epochs(event, params): params['plot_fun']() +def _pick_bad_channels(pos, params): + """Helper function for selecting bad channels.""" + labels = params['ax'].yaxis.get_ticklabels() + offsets = np.array(params['offsets']) + params['offsets'][0] + line_idx = np.searchsorted(offsets, pos[1]) + text = labels[line_idx].get_text() + ch_idx = params['ch_start'] + line_idx + if text in params['info']['bads']: + params['info']['bads'].remove(text) + color = params['def_colors'][ch_idx] + params['ax_vscroll'].patches[ch_idx + 1].set_color(color) + else: + params['info']['bads'].append(text) + color = params['bad_color'] + params['ax_vscroll'].patches[ch_idx + 1].set_color(color) + params['plot_fun']() + + def _plot_onscroll(event, params): """Function to handle scroll events.""" if event.key == 'control': @@ -1101,26 +1124,12 @@ def _mouse_click(event, params): if event.inaxes is None: if params['butterfly'] or not params['settings'][0]: return - # Bad channel selected ax = params['ax'] ylim = ax.get_ylim() pos = ax.transData.inverted().transform((event.x, event.y)) if pos[0] > 0 or pos[1] < 0 or pos[1] > ylim[0]: return - labels = ax.yaxis.get_ticklabels() - offsets = np.array(params['offsets']) + params['offsets'][0] - line_idx = np.searchsorted(offsets, pos[1]) - text = labels[line_idx].get_text() - ch_idx = params['ch_start'] + line_idx - if text in params['info']['bads']: - params['info']['bads'].remove(text) - color = params['def_colors'][ch_idx] - params['ax_vscroll'].patches[ch_idx + 1].set_color(color) - else: - params['info']['bads'].append(text) - color = params['bad_color'] - params['ax_vscroll'].patches[ch_idx + 1].set_color(color) - params['plot_fun']() + _pick_bad_channels(pos, params) elif event.button == 1: # left click # vertical scroll bar changed if event.inaxes == params['ax_vscroll']: diff --git a/mne/viz/ica.py b/mne/viz/ica.py index aa2bca363a6..3ea9747bfbb 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -617,7 +617,8 @@ def _plot_epoch_components(ica, epochs, exclude=None, title=None, n_epochs=20, 'epochs': epochs, 'info': info, 'orig_data': data, - 'bads': list()} + 'bads': list(), + 'bad_color': bad_color} _prepare_mne_browse_epochs(params, projs=list(), n_channels=n_channels, n_epochs=n_epochs, scalings=scalings, title=title, picks=inds) From b10c69f8c343376fe19490bd1d6c5f7dead0961d Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Mon, 6 Jul 2015 15:27:03 +0300 Subject: [PATCH 15/36] Changed plotters to be called by ica.plot_sources. --- mne/preprocessing/ica.py | 119 ++++---------------------------------- mne/viz/epochs.py | 81 +++++++++++++------------- mne/viz/ica.py | 95 ++++++++++++++++-------------- mne/viz/tests/test_ica.py | 2 +- 4 files changed, 105 insertions(+), 192 deletions(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 407da00afd2..770c040cfe0 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -32,8 +32,8 @@ from ..io.constants import Bunch, FIFF from ..io.base import _BaseRaw from ..epochs import _BaseEpochs -from ..viz import (plot_ica_components, plot_ica_scores, _plot_raw_components, - plot_ica_sources, plot_ica_overlay, _plot_epoch_components) +from ..viz import (plot_ica_components, plot_ica_scores, + plot_ica_sources, plot_ica_overlay) from ..channels.channels import _contains_ch_type, ContainsMixin from ..io.write import start_file, end_file, write_id from ..utils import (check_sklearn_version, logger, check_fname, verbose, @@ -1335,7 +1335,7 @@ def plot_components(self, picks=None, ch_type=None, res=64, layout=None, head_pos=head_pos) def plot_sources(self, inst, picks=None, exclude=None, start=None, - stop=None, title=None, show=True): + stop=None, title=None, show=True, block=False): """Plot estimated latent sources given the unmixing matrix. Typical usecases: @@ -1363,15 +1363,23 @@ def plot_sources(self, inst, picks=None, exclude=None, start=None, The figure title. If None a default is provided. show : bool If True, all open plots will be shown. + block : bool + Whether to halt program execution until the figure is closed. + Useful for interactive selection of components in raw and epoch + plotter. For evoked, this parameter has no effect. Defaults to + False. Returns ------- fig : instance of pyplot.Figure The figure. + + .. versionadded:: 0.10.0 """ return plot_ica_sources(self, inst=inst, picks=picks, exclude=exclude, - title=title, start=start, stop=stop, show=show) + title=title, start=start, stop=stop, show=show, + block=block) def plot_scores(self, scores, exclude=None, axhline=None, title='ICA component scores', figsize=(12, 6), @@ -1572,109 +1580,6 @@ def _check_n_pca_components(self, _n_pca_comp, verbose=None): return _n_pca_comp - def plot_raw_components(self, raw, exclude=None, title=None, duration=10.0, - n_channels=20, bgcolor='w', color=(0., 0., 0.), - bad_color=(1., 0., 0.), show=True, block=False): - """Plot ICA components. - - Parameters - ---------- - raw : instance of Raw - Raw object to draw sources from. - exclude : array_like of int | None - The components marked for exclusion. If None (default), ICA.exclude - will be used. - title : str - Title for the plot. If None, ``ICA components`` is displayed. - Defaults to None - duration : float - Time window (sec) to plot in a given time. Defaults to 10.0. - n_channels : int - The number of channels per view. Defaults to 20. - bgcolor : color object - Color of the background. - color : color object - Color for the data traces. Defaults to (0., 0., 0.) (black). - bad_color : color object - Color to use for components marked as bad. - Defaults to (1., 0., 0.) (red). - show : bool - Show figures if True. Defaults to True. - block : bool - Whether to halt program execution until the figure is closed. - Useful for selecting components for exclusion on the fly - (click on line). May not work on all systems / platforms. - Defaults to False. - - Returns - ------- - fig : Instance of matplotlib.figure.Figure - The figure. - - Notes - ----- - To mark or un-mark a component for exclusion, click on the rather flat - segments of a channel's time series. The changes will be reflected - immediately in the ica object's ``ica.exclude`` entry. - """ - return _plot_raw_components(self, raw, exclude=exclude, title=title, - duration=duration, n_channels=n_channels, - bgcolor=bgcolor, color=color, - bad_color=bad_color, show=show, - block=block) - - def plot_epoch_components(self, epochs, exclude=None, title=None, - n_epochs=10, n_channels=20, bgcolor='w', - color=(0., 0., 0.), bad_color=(1., 0., 0.), - show=True, block=False): - """Plot ICA components. - - Parameters - ---------- - epochs : instance of Epochs - Epochs object to draw sources from. - exclude : array_like of int | None - The components marked for exclusion. If None (default), ICA.exclude - will be used. - title : str - Title for the plot. If None, ``ICA components`` is displayed. - Defaults to None - n_epochs : int - Number of epoch per view. Defaults to 10. - n_channels : int - The number of channels per view. Defaults to 20. - bgcolor : color object - Color of the background. - color : color object - Color for the data traces. Defaults to (0., 0., 0.) (black). - bad_color : color object - Color to use for epochs marked as bad. - Defaults to (1., 0., 0.) (red). - show : bool - Show figures if True. Defaults to True. - block : bool - Whether to halt program execution until the figure is closed. - Useful for selecting components for exclusion on the fly - (click on line). May not work on all systems / platforms. - Defaults to False. - - Returns - ------- - fig : Instance of matplotlib.figure.Figure - The figure. - - Notes - ----- - To mark or un-mark a component for exclusion, click on the component - name left of the main axes. The changes will be reflected - immediately in the ica object's ``ica.exclude`` entry. - """ - return _plot_epoch_components(self, epochs, exclude=exclude, - title=title, n_epochs=n_epochs, - n_channels=n_channels, bgcolor=bgcolor, - color=color, bad_color=bad_color, - show=show, block=block) - def _check_start_stop(raw, start, stop): """Aux function""" diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index 3bcd6ea2236..624460c4ae1 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -506,7 +506,8 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, params = {'epochs': epochs, 'orig_data': np.concatenate(epochs.get_data(), axis=1), 'info': copy.deepcopy(epochs.info), - 'bad_color': (0.8, 0.8, 0.8)} + 'bad_color': (0.8, 0.8, 0.8), + 't_start': 0} _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings, title, picks) @@ -777,47 +778,43 @@ def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings, ha='left', fontweight='bold') text.set_visible(False) - params = {'fig': fig, - 'ax': ax, - 'ax2': ax2, - 'ax_hscroll': ax_hscroll, - 'ax_vscroll': ax_vscroll, - 'vsel_patch': vsel_patch, - 'hsel_patch': hsel_patch, - 'epochs': epochs, - 'info': copy.deepcopy(epochs.info), # needed for projs - 'lines': lines, - 'n_channels': n_channels, - 'n_epochs': n_epochs, - 'ch_start': 0, - 't_start': 0, - 'duration': duration, - 'colors': colors, - 'def_colors': type_colors, # don't change at runtime - 'picks': picks, - 'bad_color': bad_color, - 'bads': np.array(list(), dtype=int), - 'ch_names': ch_names, - 'data': data, - 'orig_data': epoch_data, - 'times': times, - 'epoch_times': epoch_times, - 'offsets': offsets, - 'labels': labels, - 'projs': projs, - 'scale_factor': 1.0, - 'butterfly_scale': 1.0, - 'fig_proj': None, - 'inds': inds, - 'scalings': scalings, - 'types': np.array(types), - 'vert_lines': list(), - 'vertline_t': vertline_t, - 'butterfly': False, - 'text': text, - 'ax_help_button': ax_help_button, - 'fig_options': None, - 'settings': [True, True, True, True]} # for options dialog + params.update({'fig': fig, + 'ax': ax, + 'ax2': ax2, + 'ax_hscroll': ax_hscroll, + 'ax_vscroll': ax_vscroll, + 'vsel_patch': vsel_patch, + 'hsel_patch': hsel_patch, + 'lines': lines, + 'projs': projs, + 'ch_names': ch_names, + 'n_channels': n_channels, + 'n_epochs': n_epochs, + 'scalings': scalings, + 'types': types, + 'duration': duration, + 'ch_start': 0, + 'colors': colors, + 'def_colors': type_colors, # don't change at runtime + 'picks': picks, + 'bads': np.array(list(), dtype=int), + 'data': data, + 'times': times, + 'epoch_times': epoch_times, + 'offsets': offsets, + 'labels': labels, + 'scale_factor': 1.0, + 'butterfly_scale': 1.0, + 'fig_proj': None, + 'types': np.array(types), + 'inds': inds, + 'vert_lines': list(), + 'vertline_t': vertline_t, + 'butterfly': False, + 'text': text, + 'ax_help_button': ax_help_button, + 'fig_options': None, + 'settings': [True, True, True, True]}) params['plot_fun'] = partial(_plot_traces, params=params) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 3ea9747bfbb..4239c80d584 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -47,7 +47,7 @@ def _ica_plot_sources_onpick_(event, sources=None, ylims=None): def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None, - stop=None, show=True, title=None): + stop=None, show=True, title=None, block=False): """Plot estimated latent sources given the unmixing matrix. Typical usecases: @@ -70,18 +70,25 @@ def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None, The components marked for exclusion. If None (default), ICA.exclude will be used. start : int - X-axis start index. If None from the beginning. + X-axis start index. If None, from the beginning. stop : int - X-axis stop index. If None to the end. + X-axis stop index. If None, next 20 are shown, in case of evoked to the + end. show : bool Show figure if True. title : str | None The figure title. If None a default is provided. + block : bool + Whether to halt program execution until the figure is closed. + Useful for interactive selection of components in raw and epoch + plotter. For evoked, this parameter has no effect. Defaults to False. Returns ------- fig : instance of pyplot.Figure The figure. + + .. versionadded:: 0.10.0 """ from ..io.base import _BaseRaw @@ -91,23 +98,14 @@ def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None, if exclude is None: exclude = ica.exclude - if isinstance(inst, (_BaseRaw, _BaseEpochs)): - if isinstance(inst, _BaseRaw): - sources = ica._transform_raw(inst, start, stop) - else: - if start is not None or stop is not None: - inst = inst.crop(start, stop, copy=True) - sources = ica._transform_epochs(inst, concatenate=True) - if picks is not None: - if np.isscalar(picks): - picks = [picks] - sources = np.atleast_2d(sources[picks]) - - fig = _plot_ica_grid(sources, start=start, stop=stop, - ncol=len(sources) // 10 or 1, - exclude=exclude, - source_idx=picks, - title=title, show=show) + if isinstance(inst, _BaseRaw): + fig = _plot_raw_components(ica, inst, picks, exclude, start=start, + stop=stop, show=show, title=title, + block=block) + elif isinstance(inst, _BaseEpochs): + fig = _plot_epoch_components(ica, inst, picks, exclude, start=start, + stop=stop, show=show, title=title, + block=block) elif isinstance(inst, Evoked): sources = ica.get_sources(inst) if start is not None or stop is not None: @@ -518,12 +516,11 @@ def _plot_ica_overlay_evoked(evoked, evoked_cln, title, show): return fig -def _plot_raw_components(ica, raw, exclude=None, title=None, duration=10.0, - n_channels=20, bgcolor='w', color=(0., 0., 0.), - bad_color=(1., 0., 0.), show=True, block=False): +def _plot_raw_components(ica, raw, picks, exclude, start, stop, show, title, + block): """Function for plotting the ICA components as raw array.""" import matplotlib.pyplot as plt - color = _handle_default('color', color) + color = _handle_default('color', (0., 0., 0.)) orig_data = ica._transform_raw(raw, 0, len(raw.times)) * 0.2 inds = range(len(orig_data)) types = np.repeat('misc', len(inds)) @@ -533,17 +530,23 @@ def _plot_raw_components(ica, raw, exclude=None, title=None, duration=10.0, title = 'ICA components' info = create_info(c_names, raw.info['sfreq']) - if exclude is None: - exclude = list() info['bads'] = [c_names[x] for x in exclude] + if start is None: + start = 0 + if stop is None: + stop = start + 20 + stop = min(stop, raw.times[-1]) + duration = stop - start + if duration <= 0: + raise RuntimeError('Stop must be larger than start.') t_end = int(duration * raw.info['sfreq']) times = raw.times[0:t_end] + bad_color = (1., 0., 0.) params = dict(raw=raw, orig_data=orig_data, data=orig_data[:, 0:t_end], - ch_start=0, t_start=0, info=info, duration=duration, ica=ica, - n_channels=n_channels, times=times, types=types, + ch_start=0, t_start=start, info=info, duration=duration, + ica=ica, n_channels=20, times=times, types=types, n_times=raw.n_times, bad_color=bad_color) - _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, - n_channels) + _prepare_mne_browse_raw(params, title, 'w', color, bad_color, inds, 20) params['scale_factor'] = 1.0 params['plot_fun'] = partial(_plot_raw_traces, params=params, inds=inds, color=color, bad_color=bad_color) @@ -597,31 +600,39 @@ def _close_event(events, params): params['ica'].exclude = exclude -def _plot_epoch_components(ica, epochs, exclude=None, title=None, n_epochs=20, - n_channels=20, bgcolor='w', color=(0., 0., 0.), - bad_color=(1., 0., 0.), show=True, block=False): +def _plot_epoch_components(ica, epochs, picks, exclude, start, stop, show, + title, block): """Function for plotting the components as epochs.""" import matplotlib.pyplot as plt data = ica._transform_epochs(epochs, concatenate=True) - inds = range(ica.n_components_) c_names = ['ICA ' + str(x + 1) for x in range(ica.n_components_)] - scalings = {'misc': 2.0} + scalings = {'misc': 5.0} info = create_info(ch_names=c_names, sfreq=epochs.info['sfreq']) info['projs'] = list() - if exclude is None: - exclude = ica.exclude - else: - exclude += ica.exclude info['bads'] = [c_names[x] for x in exclude] + if title is None: + title = 'ICA components' + if picks is None: + picks = range(len(c_names)) + if start is None: + start = 0 + if stop is None: + stop = start + 20 + stop = min(stop, len(epochs.events)) + n_epochs = stop - start + if n_epochs <= 0: + raise RuntimeError('Stop must be larger than start.') params = {'ica': ica, 'epochs': epochs, 'info': info, 'orig_data': data, 'bads': list(), - 'bad_color': bad_color} - _prepare_mne_browse_epochs(params, projs=list(), n_channels=n_channels, + 'bad_color': (1., 0., 0.), + 't_start': start} + + _prepare_mne_browse_epochs(params, projs=list(), n_channels=20, n_epochs=n_epochs, scalings=scalings, - title=title, picks=inds) + title=title, picks=picks) callback_close = partial(_close_epochs_event, params=params) params['fig'].canvas.mpl_connect('close_event', callback_close) if show: diff --git a/mne/viz/tests/test_ica.py b/mne/viz/tests/test_ica.py index 2014cf1e2fe..8f4a676bac1 100644 --- a/mne/viz/tests/test_ica.py +++ b/mne/viz/tests/test_ica.py @@ -152,7 +152,7 @@ def test_plot_raw_components(): ica = ICA(noise_cov=read_cov(cov_fname), n_components=2, max_pca_components=3, n_pca_components=3) ica.fit(raw, picks=picks) - fig = ica.plot_raw_components(raw, exclude=[0], title='Components') + fig = ica.plot_sources(raw, exclude=[0], title='Components') fig.canvas.key_press_event('down') fig.canvas.key_press_event('up') fig.canvas.key_press_event('right') From 51b5f009a330e9d221cbd39eecc9ab8a0cf480ae Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Mon, 6 Jul 2015 15:45:15 +0300 Subject: [PATCH 16/36] Updated examples and whats_new. Docstrings. --- doc/source/whats_new.rst | 3 +++ examples/preprocessing/plot_ica_from_raw.py | 1 - mne/preprocessing/ica.py | 6 ++++++ mne/viz/ica.py | 6 ++++++ 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/doc/source/whats_new.rst b/doc/source/whats_new.rst index 6eea3335981..43ed18f9f12 100644 --- a/doc/source/whats_new.rst +++ b/doc/source/whats_new.rst @@ -27,6 +27,9 @@ Changelog - Add support for BEM solution computation :func:`mne.make_bem_solution` by `Eric Larson`_ + - Add ICA plotters for raw and epoch components by `Jaakko Leppakangas`_ + + BUG ~~~ diff --git a/examples/preprocessing/plot_ica_from_raw.py b/examples/preprocessing/plot_ica_from_raw.py index 90a460cbd74..1a8c96644dc 100644 --- a/examples/preprocessing/plot_ica_from_raw.py +++ b/examples/preprocessing/plot_ica_from_raw.py @@ -62,7 +62,6 @@ show_picks = np.abs(scores).argsort()[::-1][:5] -ica.plot_raw_components(raw, exclude=ecg_inds, title=title % 'ecg') ica.plot_sources(raw, show_picks, exclude=ecg_inds, title=title % 'ecg') ica.plot_components(ecg_inds, title=title % 'ecg', colorbar=True) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 770c040cfe0..48abbb9bca6 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -1374,6 +1374,12 @@ def plot_sources(self, inst, picks=None, exclude=None, start=None, fig : instance of pyplot.Figure The figure. + Notes + ----- + For raw and epoch instances, it is possible to select components for + exclusion by clicking on the line. The selected components are added to + ``ica.exclude`` on close. + .. versionadded:: 0.10.0 """ diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 4239c80d584..f320f4ad315 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -88,6 +88,12 @@ def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None, fig : instance of pyplot.Figure The figure. + Notes + ----- + For raw and epoch instances, it is possible to select components for + exclusion by clicking on the line. The selected components are added to + ``ica.exclude`` on close. + .. versionadded:: 0.10.0 """ From a0752338ce09fb44d7b8099f0fc5fb3f14ee96fb Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Tue, 7 Jul 2015 11:03:29 +0300 Subject: [PATCH 17/36] Fix to whats_new. Tests. --- mne/viz/tests/test_ica.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/mne/viz/tests/test_ica.py b/mne/viz/tests/test_ica.py index 8f4a676bac1..fde26575464 100644 --- a/mne/viz/tests/test_ica.py +++ b/mne/viz/tests/test_ica.py @@ -144,8 +144,8 @@ def test_plot_ica_scores(): @requires_sklearn -def test_plot_raw_components(): - """Test plotting of raw components.""" +def test_plot_instance_components(): + """Test plotting of components as instances of raw and epochs.""" import matplotlib.pyplot as plt raw = _get_raw() picks = _get_picks(raw) @@ -166,6 +166,30 @@ def test_plot_raw_components(): fig.canvas.key_press_event('home') fig.canvas.key_press_event('end') fig.canvas.key_press_event('f11') + ax = fig.get_axes()[0] + line = ax.lines[0] + _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], 'data') + fig.canvas.key_press_event('escape') + plt.close('all') + epochs = _get_epochs() + fig = ica.plot_sources(epochs, exclude=[0], title='Components') + fig.canvas.key_press_event('down') + fig.canvas.key_press_event('up') + fig.canvas.key_press_event('right') + fig.canvas.key_press_event('left') + fig.canvas.key_press_event('o') + fig.canvas.key_press_event('-') + fig.canvas.key_press_event('+') + fig.canvas.key_press_event('=') + fig.canvas.key_press_event('pageup') + fig.canvas.key_press_event('pagedown') + fig.canvas.key_press_event('home') + fig.canvas.key_press_event('end') + fig.canvas.key_press_event('f11') + # Test a click + ax = fig.get_axes()[0] + line = ax.lines[0] + _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], 'data') fig.canvas.key_press_event('escape') plt.close('all') From 1a5752cc860e99b2dfef903ead51d627bcd3da0e Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Tue, 7 Jul 2015 13:20:48 +0300 Subject: [PATCH 18/36] Fixes. Components plotted on click to label. --- mne/viz/epochs.py | 3 ++- mne/viz/ica.py | 11 ++++++++++- mne/viz/raw.py | 6 ++++++ mne/viz/utils.py | 13 ++++++++++++- 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index 624460c4ae1..a86aeb4c6bf 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -508,6 +508,7 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, 'info': copy.deepcopy(epochs.info), 'bad_color': (0.8, 0.8, 0.8), 't_start': 0} + params['label_click_fun'] = partial(_pick_bad_channels, params=params) _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings, title, picks) @@ -1126,7 +1127,7 @@ def _mouse_click(event, params): pos = ax.transData.inverted().transform((event.x, event.y)) if pos[0] > 0 or pos[1] < 0 or pos[1] > ylim[0]: return - _pick_bad_channels(pos, params) + params['label_click_fun'](pos) elif event.button == 1: # left click # vertical scroll bar changed if event.inaxes == params['ax_vscroll']: diff --git a/mne/viz/ica.py b/mne/viz/ica.py index f320f4ad315..82ba1cdd784 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -558,6 +558,7 @@ def _plot_raw_components(ica, raw, picks, exclude, start, stop, show, title, color=color, bad_color=bad_color) params['update_fun'] = partial(_update_data, params) params['pick_bads_fun'] = partial(_pick_bads, params=params) + params['label_click_fun'] = partial(_label_clicked, params=params) _layout_figure(params) # callbacks callback_key = partial(_plot_raw_onkey, params=params) @@ -572,6 +573,7 @@ def _plot_raw_components(ica, raw, picks, exclude, start, stop, show, title, params['fig'].canvas.mpl_connect('close_event', callback_close) params['fig_proj'] = None params['event_times'] = None + params['update_fun']() params['plot_fun']() if show: try: @@ -635,7 +637,7 @@ def _plot_epoch_components(ica, epochs, picks, exclude, start, stop, show, 'bads': list(), 'bad_color': (1., 0., 0.), 't_start': start} - + params['label_click_fun'] = partial(_label_clicked, params=params) _prepare_mne_browse_epochs(params, projs=list(), n_channels=20, n_epochs=n_epochs, scalings=scalings, title=title, picks=picks) @@ -655,3 +657,10 @@ def _close_epochs_event(events, params): info = params['info'] exclude = [info['ch_names'].index(x) for x in info['bads']] params['ica'].exclude = exclude + + +def _label_clicked(pos, params): + """""" + offsets = np.array(params['offsets']) + params['offsets'][0] + line_idx = [np.searchsorted(offsets, pos[1]) + params['ch_start']] + params['ica'].plot_components(picks=line_idx) diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 5d34c5244e9..3b625e6e70d 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -290,6 +290,7 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None, event_color=event_color) params['update_fun'] = partial(_update_raw_data, params=params) params['pick_bads_fun'] = partial(_pick_bad_channels, params=params) + params['label_click_fun'] = partial(_label_clicked, params=params) params['scale_factor'] = 1.0 # set up callbacks opt_button = None @@ -336,6 +337,11 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None, return params['fig'] +def _label_clicked(pos, params): + """Empty placeholder for clicks on channel names.""" + pass + + def _set_psd_plot_params(info, proj, picks, ax, area_mode): """Aux function""" import matplotlib.pyplot as plt diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 9a337f5485a..a4f92fadb57 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -446,6 +446,8 @@ def _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, offsets = np.arange(n_channels) * offset + (offset / 2.) ax.set_yticks(offsets) ax.set_ylim(ylim) + ax.set_xlim(params['t_start'], params['t_start'] + params['duration'], + False) params['offsets'] = offsets params['lines'] = [ax.plot([np.nan], antialiased=False, linewidth=0.5)[0] @@ -577,8 +579,17 @@ def _plot_raw_onkey(event, params): def _mouse_click(event, params): """Vertical select callback""" - if event.inaxes is None or event.button != 1: + if event.button != 1: return + if event.inaxes is None: + if params['n_channels'] > 100: + return + ax = params['ax'] + ylim = ax.get_ylim() + pos = ax.transData.inverted().transform((event.x, event.y)) + if pos[0] > params['t_start'] or pos[1] < 0 or pos[1] > ylim[0]: + return + params['label_click_fun'](pos) # vertical scrollbar changed if event.inaxes == params['ax_vscroll']: ch_start = max(int(event.ydata) - params['n_channels'] // 2, 0) From 53e8239420667d1bc519ea8b0c3a2c12fe5a0e71 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Tue, 7 Jul 2015 14:02:56 +0300 Subject: [PATCH 19/36] Fixes. --- mne/viz/ica.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 82ba1cdd784..ce318efc415 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -531,7 +531,7 @@ def _plot_raw_components(ica, raw, picks, exclude, start, stop, show, title, inds = range(len(orig_data)) types = np.repeat('misc', len(inds)) - c_names = ['ICA ' + str(x + 1) for x in inds] + c_names = ['ICA ' + str(x) for x in inds] if title is None: title = 'ICA components' info = create_info(c_names, raw.info['sfreq']) @@ -612,8 +612,9 @@ def _plot_epoch_components(ica, epochs, picks, exclude, start, stop, show, title, block): """Function for plotting the components as epochs.""" import matplotlib.pyplot as plt + plt.ion() # Turn interactive mode on to avoid warnings. data = ica._transform_epochs(epochs, concatenate=True) - c_names = ['ICA ' + str(x + 1) for x in range(ica.n_components_)] + c_names = ['ICA ' + str(x) for x in range(ica.n_components_)] scalings = {'misc': 5.0} info = create_info(ch_names=c_names, sfreq=epochs.info['sfreq']) info['projs'] = list() @@ -636,11 +637,12 @@ def _plot_epoch_components(ica, epochs, picks, exclude, start, stop, show, 'orig_data': data, 'bads': list(), 'bad_color': (1., 0., 0.), - 't_start': start} + 't_start': start * len(epochs.times)} params['label_click_fun'] = partial(_label_clicked, params=params) _prepare_mne_browse_epochs(params, projs=list(), n_channels=20, n_epochs=n_epochs, scalings=scalings, title=title, picks=picks) + params['hsel_patch'].set_x(params['t_start']) callback_close = partial(_close_epochs_event, params=params) params['fig'].canvas.mpl_connect('close_event', callback_close) if show: @@ -660,7 +662,7 @@ def _close_epochs_event(events, params): def _label_clicked(pos, params): - """""" + """Function for plotting independent components on click to label.""" offsets = np.array(params['offsets']) + params['offsets'][0] line_idx = [np.searchsorted(offsets, pos[1]) + params['ch_start']] params['ica'].plot_components(picks=line_idx) From c83336a30e156b7fbacece23b0b04d1124f84536 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Tue, 7 Jul 2015 15:19:21 +0300 Subject: [PATCH 20/36] Bad channel selection by clicking label in mne_browse_raw. --- mne/viz/raw.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 3b625e6e70d..6f713bc795e 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -70,6 +70,7 @@ def _update_raw_data(params): def _pick_bad_channels(event, params): """Helper for selecting / dropping bad channels onpick""" + # Both bad lists are updated. params['info'] used for colors. bads = params['raw'].info['bads'] params['info']['bads'] = _select_bads(event, params, bads) _plot_update_raw_proj(params, None) @@ -338,8 +339,23 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None, def _label_clicked(pos, params): - """Empty placeholder for clicks on channel names.""" - pass + """Helper function for selecting bad channels.""" + labels = params['ax'].yaxis.get_ticklabels() + offsets = np.array(params['offsets']) + params['offsets'][0] + line_idx = np.searchsorted(offsets, pos[1]) + text = labels[line_idx].get_text() + ch_idx = params['ch_start'] + line_idx + bads = params['info']['bads'] + if text in bads: + bads.remove(text) + color = vars(params['lines'][ch_idx])['def_color'] + params['ax_vscroll'].patches[ch_idx + 1].set_color(color) + else: + bads.append(text) + color = params['bad_color'] + params['ax_vscroll'].patches[ch_idx + 1].set_color(color) + params['raw'].info['bads'] = bads + params['plot_fun']() def _set_psd_plot_params(info, proj, picks, ax, area_mode): From e3a5b65ac20fe49cd27e0ee7e43e3046e38d4ced Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Wed, 8 Jul 2015 12:41:46 +0300 Subject: [PATCH 21/36] Separate topomaps for eeg, mag and grad. Fixes. --- mne/viz/epochs.py | 2 +- mne/viz/ica.py | 62 +++++++++++++++++++++++++++++++++++++++++------ mne/viz/raw.py | 2 +- 3 files changed, 56 insertions(+), 10 deletions(-) diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index a86aeb4c6bf..47489c161ec 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -633,7 +633,7 @@ def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings, if len(picks) < 1: raise RuntimeError('No appropriate channels found. Please' ' check your picks') - + picks = sorted(picks) # Reorganize channels inds = list() types = list() diff --git a/mne/viz/ica.py b/mne/viz/ica.py index ce318efc415..139876e1ab9 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -18,8 +18,10 @@ from .utils import _select_bads from .epochs import _prepare_mne_browse_epochs from .evoked import _butterfly_on_button_press, _butterfly_onpick +from .topomap import _prepare_topo_plot, plot_topomap from ..defaults import _handle_default from ..io.meas_info import create_info +from mne.io.pick import pick_types def _ica_plot_sources_onpick_(event, sources=None, ylims=None): @@ -528,13 +530,15 @@ def _plot_raw_components(ica, raw, picks, exclude, start, stop, show, title, import matplotlib.pyplot as plt color = _handle_default('color', (0., 0., 0.)) orig_data = ica._transform_raw(raw, 0, len(raw.times)) * 0.2 - inds = range(len(orig_data)) - types = np.repeat('misc', len(inds)) + if picks is None: + picks = range(len(orig_data)) + types = np.repeat('misc', len(picks)) + picks = sorted(picks) - c_names = ['ICA ' + str(x) for x in inds] + c_names = ['ICA ' + str(x) for x in range(len(orig_data))] if title is None: title = 'ICA components' - info = create_info(c_names, raw.info['sfreq']) + info = create_info([c_names[x] for x in picks], raw.info['sfreq']) info['bads'] = [c_names[x] for x in exclude] if start is None: @@ -548,10 +552,11 @@ def _plot_raw_components(ica, raw, picks, exclude, start, stop, show, title, t_end = int(duration * raw.info['sfreq']) times = raw.times[0:t_end] bad_color = (1., 0., 0.) + inds = range(len(picks)) params = dict(raw=raw, orig_data=orig_data, data=orig_data[:, 0:t_end], ch_start=0, t_start=start, info=info, duration=duration, ica=ica, n_channels=20, times=times, types=types, - n_times=raw.n_times, bad_color=bad_color) + n_times=raw.n_times, bad_color=bad_color, picks=picks) _prepare_mne_browse_raw(params, title, 'w', color, bad_color, inds, 20) params['scale_factor'] = 1.0 params['plot_fun'] = partial(_plot_raw_traces, params=params, inds=inds, @@ -604,7 +609,8 @@ def _pick_bads(event, params): def _close_event(events, params): """Function for excluding the selected components on close.""" info = params['info'] - exclude = [info['ch_names'].index(x) for x in info['bads']] + picks = params['picks'] + exclude = [picks[info['ch_names'].index(x)] for x in info['bads']] params['ica'].exclude = exclude @@ -663,6 +669,46 @@ def _close_epochs_event(events, params): def _label_clicked(pos, params): """Function for plotting independent components on click to label.""" + import matplotlib.pyplot as plt offsets = np.array(params['offsets']) + params['offsets'][0] - line_idx = [np.searchsorted(offsets, pos[1]) + params['ch_start']] - params['ica'].plot_components(picks=line_idx) + line_idx = np.searchsorted(offsets, pos[1]) + params['ch_start'] + ic_idx = [params['picks'][line_idx]] + types = list() + info = params['ica'].info + if len(pick_types(info, meg=False, eeg=True, ref_meg=False)) > 0: + types.append('eeg') + if len(pick_types(info, meg='mag', ref_meg=False)) > 0: + types.append('mag') + if len(pick_types(info, meg='grad', ref_meg=False)) > 0: + types.append('grad') + + ica = params['ica'] + data = np.dot(ica.mixing_matrix_[:, ic_idx].T, + ica.pca_components_[:ica.n_components_]) + data = np.atleast_2d(data) + fig, axes = _prepare_trellis(len(types), max_col=3) + for ch_idx, ch_type in enumerate(types): + data_picks, pos, merge_grads, _, _ = _prepare_topo_plot(ica, ch_type, + None) + this_data = data[:, data_picks] + ax = axes[ch_idx] + if merge_grads: + from ..channels.layout import _merge_grad_data + for ii, data_ in zip(ic_idx, this_data): + ax.set_title('IC #%03d ' % ii + ch_type, fontsize=12) + data_ = _merge_grad_data(data_) if merge_grads else data_ + plot_topomap(data_.flatten(), pos, axis=ax, show=False)[0] + ax.set_yticks([]) + ax.set_xticks([]) + ax.set_frame_on(False) + tight_layout(fig=fig) + fig.subplots_adjust(top=0.95) + fig.canvas.draw() + + plt.show() + """ + try: + params['ica'].plot_components(picks=line_idx) + except: + pass + """ diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 6f713bc795e..9a4f2534adc 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -348,7 +348,7 @@ def _label_clicked(pos, params): bads = params['info']['bads'] if text in bads: bads.remove(text) - color = vars(params['lines'][ch_idx])['def_color'] + color = vars(params['lines'][line_idx])['def_color'] params['ax_vscroll'].patches[ch_idx + 1].set_color(color) else: bads.append(text) From 063a221d657eb2acd25954beb043bc600f729404 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Wed, 8 Jul 2015 12:58:32 +0300 Subject: [PATCH 22/36] Cleaning. Docs. --- mne/preprocessing/ica.py | 4 +++- mne/viz/ica.py | 6 ------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 48abbb9bca6..35c068788c2 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -1378,7 +1378,9 @@ def plot_sources(self, inst, picks=None, exclude=None, start=None, ----- For raw and epoch instances, it is possible to select components for exclusion by clicking on the line. The selected components are added to - ``ica.exclude`` on close. + ``ica.exclude`` on close. The independent components can be viewed as + topographies by clicking on the component name on the left of of the + main axes. .. versionadded:: 0.10.0 """ diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 139876e1ab9..17547d9e3f4 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -706,9 +706,3 @@ def _label_clicked(pos, params): fig.canvas.draw() plt.show() - """ - try: - params['ica'].plot_components(picks=line_idx) - except: - pass - """ From fff6c7f5708aa6620c2200e3a2a715daf41a5b34 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Wed, 8 Jul 2015 15:16:08 +0300 Subject: [PATCH 23/36] Docs. Tests. Checks. --- mne/preprocessing/ica.py | 3 ++- mne/viz/epochs.py | 2 ++ mne/viz/ica.py | 15 ++++++++++++--- mne/viz/raw.py | 2 ++ mne/viz/tests/test_epochs.py | 1 + mne/viz/tests/test_raw.py | 1 + 6 files changed, 20 insertions(+), 4 deletions(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 35c068788c2..65286e78144 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -1380,7 +1380,8 @@ def plot_sources(self, inst, picks=None, exclude=None, start=None, exclusion by clicking on the line. The selected components are added to ``ica.exclude`` on close. The independent components can be viewed as topographies by clicking on the component name on the left of of the - main axes. + main axes. The topography view tries to infer the correct electrode + layout from the data. This should work at least for Neuromag data. .. versionadded:: 0.10.0 """ diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index 47489c161ec..c0feca41f8c 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -1091,6 +1091,8 @@ def _pick_bad_channels(pos, params): offsets = np.array(params['offsets']) + params['offsets'][0] line_idx = np.searchsorted(offsets, pos[1]) text = labels[line_idx].get_text() + if len(text) == 0: + return ch_idx = params['ch_start'] + line_idx if text in params['info']['bads']: params['info']['bads'].remove(text) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 17547d9e3f4..9334f6b50d3 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -19,9 +19,10 @@ from .epochs import _prepare_mne_browse_epochs from .evoked import _butterfly_on_button_press, _butterfly_onpick from .topomap import _prepare_topo_plot, plot_topomap +from ..utils import logger from ..defaults import _handle_default from ..io.meas_info import create_info -from mne.io.pick import pick_types +from ..io.pick import pick_types def _ica_plot_sources_onpick_(event, sources=None, ylims=None): @@ -672,6 +673,8 @@ def _label_clicked(pos, params): import matplotlib.pyplot as plt offsets = np.array(params['offsets']) + params['offsets'][0] line_idx = np.searchsorted(offsets, pos[1]) + params['ch_start'] + if line_idx >= len(params['picks']): + return ic_idx = [params['picks'][line_idx]] types = list() info = params['ica'].info @@ -688,8 +691,14 @@ def _label_clicked(pos, params): data = np.atleast_2d(data) fig, axes = _prepare_trellis(len(types), max_col=3) for ch_idx, ch_type in enumerate(types): - data_picks, pos, merge_grads, _, _ = _prepare_topo_plot(ica, ch_type, - None) + try: + data_picks, pos, merge_grads, _, _ = _prepare_topo_plot(ica, + ch_type, + None) + except Exception as exc: + logger.warning(exc) + plt.close(fig) + return this_data = data[:, data_picks] ax = axes[ch_idx] if merge_grads: diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 9a4f2534adc..57d327c850f 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -344,6 +344,8 @@ def _label_clicked(pos, params): offsets = np.array(params['offsets']) + params['offsets'][0] line_idx = np.searchsorted(offsets, pos[1]) text = labels[line_idx].get_text() + if len(text) == 0: + return ch_idx = params['ch_start'] + line_idx bads = params['info']['bads'] if text in bads: diff --git a/mne/viz/tests/test_epochs.py b/mne/viz/tests/test_epochs.py index 3f341c2d50e..7e826c647e9 100644 --- a/mne/viz/tests/test_epochs.py +++ b/mne/viz/tests/test_epochs.py @@ -140,6 +140,7 @@ def test_plot_epochs(): _fake_click(fig, data_ax, [x, y], xform='data') # mark a bad epoch _fake_click(fig, data_ax, [x, y], xform='data') # unmark a bad epoch _fake_click(fig, data_ax, [0.5, 0.999]) # click elsewhere in 1st axes + _fake_click(fig, data_ax, [-0.1, 0.9]) # click on y-label _fake_click(fig, fig.get_axes()[2], [0.5, 0.5]) # change epochs _fake_click(fig, fig.get_axes()[3], [0.5, 0.5]) # change channels fig.canvas.close_event() # closing and epoch dropping diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index cabcfca7728..f8b97babbfe 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -48,6 +48,7 @@ def test_plot_raw(): _fake_click(fig, data_ax, [x, y], xform='data') # mark a bad channel _fake_click(fig, data_ax, [x, y], xform='data') # unmark a bad channel _fake_click(fig, data_ax, [0.5, 0.999]) # click elsewhere in 1st axes + _fake_click(fig, data_ax, [-0.1, 0.9]) # click on y-label _fake_click(fig, fig.get_axes()[1], [0.5, 0.5]) # change time _fake_click(fig, fig.get_axes()[2], [0.5, 0.5]) # change channels _fake_click(fig, fig.get_axes()[3], [0.5, 0.5]) # open SSP window From d6d33faec11cf43e4f85d91818c43907ab4dea32 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Thu, 9 Jul 2015 13:17:01 +0300 Subject: [PATCH 24/36] Customized help dialog. Fixes. --- mne/viz/epochs.py | 72 +++++++++++++---------------------------------- mne/viz/raw.py | 2 +- mne/viz/utils.py | 59 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+), 54 deletions(-) diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index c0feca41f8c..1e13092c414 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -23,7 +23,7 @@ from ..time_frequency import compute_epochs_psd from .utils import tight_layout, _prepare_trellis, figure_nobar from .utils import _toggle_options, _toggle_proj, _layout_figure -from .utils import _channels_changed, _plot_raw_onscroll +from .utils import _channels_changed, _plot_raw_onscroll, _get_help_text from ..defaults import _handle_default @@ -662,7 +662,6 @@ def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings, ch_names = [params['info']['ch_names'][x] for x in inds] # set up plotting - size = get_config('MNE_BROWSE_RAW_SIZE') n_epochs = min(n_epochs, len(epochs.events)) duration = len(epochs.times) * n_epochs @@ -697,7 +696,7 @@ def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings, ax_help_button = plt.subplot2grid((10, 15), (9, 0), colspan=1) help_button = mpl.widgets.Button(ax_help_button, 'Help') - help_button.on_clicked(_onclick_help) + help_button.on_clicked(partial(_onclick_help, params=params)) # populate vertical and horizontal scrollbars for ci in range(len(picks)): @@ -736,7 +735,7 @@ def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings, times = epochs.times data = np.zeros((params['info']['nchan'], len(times) * len(epochs.events))) - ylim = (25., 0.) + ylim = (25., 0.) # Hardcoded 25 because butterfly has max 5 rows (5*5=25). # make shells for plotting traces offset = ylim[0] / n_channels offsets = np.arange(n_channels) * offset + (offset / 2.) @@ -813,7 +812,8 @@ def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings, 'vertline_t': vertline_t, 'butterfly': False, 'text': text, - 'ax_help_button': ax_help_button, + 'ax_help_button': ax_help_button, # needed for positioning + 'help_button': help_button, # reference needed for clicks 'fig_options': None, 'settings': [True, True, True, True]}) @@ -1039,6 +1039,9 @@ def _plot_vert_lines(params): ax = params['ax'] while len(ax.lines) > 0: ax.lines.pop() + params['vert_lines'] = list() + params['vertline_t'].set_text('') + epochs = params['epochs'] if params['settings'][3]: # if zeroline visible t_zero = np.where(epochs.times == 0.)[0] @@ -1102,7 +1105,10 @@ def _pick_bad_channels(pos, params): params['info']['bads'].append(text) color = params['bad_color'] params['ax_vscroll'].patches[ch_idx + 1].set_color(color) - params['plot_fun']() + if 'ica' in params: + params['plot_fun']() + else: + params['plot_update_proj_callback'](params, None) def _plot_onscroll(event, params): @@ -1163,12 +1169,12 @@ def _mouse_click(event, params): while len(params['vert_lines']) > 0: params['ax'].lines.remove(params['vert_lines'][0][0]) params['vert_lines'].pop(0) - if prev_xdata == xdata: + if prev_xdata == xdata: # lines removed params['vertline_t'].set_text('') params['plot_fun']() return ylim = params['ax'].get_ylim() - for epoch_idx in range(params['n_epochs']): + for epoch_idx in range(params['n_epochs']): # plot lines pos = [epoch_idx * n_times + xdata, epoch_idx * n_times + xdata] params['vert_lines'].append(params['ax'].plot(pos, ylim, 'y', zorder=4)) @@ -1286,7 +1292,7 @@ def _plot_onkey(event, params): if not params['butterfly']: _open_options(params) elif event.key == '?': - _onclick_help(event) + _onclick_help(event, params) params['plot_fun']() elif event.key == 'escape': plt.close(params['fig']) @@ -1407,53 +1413,13 @@ def _resize_event(event, params): _layout_figure(params) -def _onclick_help(event): +def _onclick_help(event, params): """Function for drawing help window""" import matplotlib.pyplot as plt + text, text2 = _get_help_text(params) - text = u'\u2190 : \n'\ - u'\u2192 : \n'\ - u'\u2193 : \n'\ - u'\u2191 : \n'\ - u'- : \n'\ - u'+ or = : \n'\ - u'Home : \n'\ - u'End : \n'\ - u'Page down : \n'\ - u'Page up : \n'\ - u'b : \n'\ - u'o : \n'\ - u'F11 : \n'\ - u'? : \n'\ - u'Esc : \n\n'\ - u'Mouse controls\n'\ - u'click epoch :\n'\ - u'click channel name :\n'\ - u'right click :\n'\ - u'middle click :\n' - - text2 = 'Navigate left\n'\ - 'Navigate right\n'\ - 'Navigate channels down\n'\ - 'Navigate channels up\n'\ - 'Scale down\n'\ - 'Scale up\n'\ - 'Reduce the number of epochs per view\n'\ - 'Increase the number of epochs per view\n'\ - 'Reduce the number of channels per view\n'\ - 'Increase the number of channels per view\n'\ - 'Toggle butterfly plot on/off\n'\ - 'View settings (orig. view only)\n'\ - 'Toggle full screen mode\n'\ - 'Open help box\n'\ - 'Quit\n\n\n'\ - 'Mark bad epoch\n'\ - 'Mark bad channel\n'\ - 'Verticlal line at a time instant\n'\ - 'Show channel name (butterfly plot)\n' - - width = 5.5 - height = 0.25 * 19 # 19 rows of text + width = 6 + height = 5 fig_help = figure_nobar(figsize=(width, height), dpi=80) fig_help.canvas.set_window_title('Help') diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 57d327c850f..27a1636cabc 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -357,7 +357,7 @@ def _label_clicked(pos, params): color = params['bad_color'] params['ax_vscroll'].patches[ch_idx + 1].set_color(color) params['raw'].info['bads'] = bads - params['plot_fun']() + _plot_update_raw_proj(params, None) def _set_psd_plot_params(info, proj, picks, ax, area_mode): diff --git a/mne/viz/utils.py b/mne/viz/utils.py index a4f92fadb57..848309afc7e 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -213,6 +213,65 @@ def _toggle_proj(event, params): params['plot_update_proj_callback'](params, bools) +def _get_help_text(params): + """Aux function for customizing help dialogs text.""" + text, text2 = list(), list() + + text.append(u'\u2190 : \n') + text.append(u'\u2192 : \n') + text.append(u'\u2193 : \n') + text.append(u'\u2191 : \n') + text.append(u'- : \n') + text.append(u'+ or = : \n') + text2.append('Navigate left\n') + text2.append('Navigate right\n') + text2.append('Scale down\n') + text2.append('Scale up\n') + if 'epochs' in params: + text.append(u'Home : \n') + text.append(u'End : \n') + text.append(u'Page down : \n') + text.append(u'Page up : \n') + text.append(u'o : \n') + text.append(u'F11 : \n') + text.append(u'? : \n') + text.append(u'Esc : \n\n') + text.append(u'Mouse controls\n') + text.append(u'click on main axes :\n') + text.append(u'right click :\n') + + text2.append('Reduce the number of epochs per view\n') + text2.append('Increase the number of epochs per view\n') + text2.append('View settings (orig. view only)\n') + text2.append('Toggle full screen mode\n') + text2.append('Open help box\n') + text2.append('Quit\n\n\n') + if 'ica' in params: + text.append(u'click component name :\n') + text2.insert(2, 'Navigate components down\n') + text2.insert(3, 'Navigate components up\n') + text2.insert(8, 'Reduce the number of components per view\n') + text2.insert(9, 'Increase the number of components per view\n') + text2.append('Mark component for exclusion\n') + text2.append('Vertical line at a time instant\n') + text2.append('Show topography for the component\n') + else: + text.append(u'click channel name :\n') + text2.insert(2, 'Navigate channels down\n') + text2.insert(3, 'Navigate channels up\n') + text2.insert(8, 'Reduce the number of channels per view\n') + text2.insert(9, 'Increase the number of channels per view\n') + text.insert(10, u'b : \n') + text2.insert(10, 'Toggle butterfly plot on/off\n') + text.append(u'middle click :\n') + text2.append('Mark bad epoch\n') + text2.append('Vertical line at a time instant\n') + text2.append('Mark bad channel\n') + text2.append('Show channel name (butterfly plot)\n') + + return ''.join(text), ''.join(text2) + + def _prepare_trellis(n_cells, max_col): """Aux function """ From 265992a5742da9f2c6a7b78a3353e142f0968ae8 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Thu, 9 Jul 2015 14:43:26 +0300 Subject: [PATCH 25/36] Help box for browse_raw. --- mne/viz/epochs.py | 36 +----------- mne/viz/tests/test_raw.py | 1 + mne/viz/utils.py | 113 ++++++++++++++++++++++++++++++-------- 3 files changed, 91 insertions(+), 59 deletions(-) diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index 1e13092c414..db0e3eaa89c 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -23,7 +23,7 @@ from ..time_frequency import compute_epochs_psd from .utils import tight_layout, _prepare_trellis, figure_nobar from .utils import _toggle_options, _toggle_proj, _layout_figure -from .utils import _channels_changed, _plot_raw_onscroll, _get_help_text +from .utils import _channels_changed, _plot_raw_onscroll, _onclick_help from ..defaults import _handle_default @@ -1293,7 +1293,6 @@ def _plot_onkey(event, params): _open_options(params) elif event.key == '?': _onclick_help(event, params) - params['plot_fun']() elif event.key == 'escape': plt.close(params['fig']) @@ -1413,39 +1412,6 @@ def _resize_event(event, params): _layout_figure(params) -def _onclick_help(event, params): - """Function for drawing help window""" - import matplotlib.pyplot as plt - text, text2 = _get_help_text(params) - - width = 6 - height = 5 - - fig_help = figure_nobar(figsize=(width, height), dpi=80) - fig_help.canvas.set_window_title('Help') - ax = plt.subplot2grid((8, 5), (0, 0), colspan=5) - ax.set_title('Keyboard shortcuts') - plt.axis('off') - ax1 = plt.subplot2grid((8, 5), (1, 0), rowspan=7, colspan=2) - ax1.set_yticklabels(list()) - plt.text(0.99, 1, text, fontname='STIXGeneral', va='top', weight='bold', - ha='right') - plt.axis('off') - - ax2 = plt.subplot2grid((8, 5), (1, 2), rowspan=7, colspan=3) - ax2.set_yticklabels(list()) - plt.text(0, 1, text2, fontname='STIXGeneral', va='top') - plt.axis('off') - - tight_layout(fig=fig_help) - # this should work for non-test cases - try: - fig_help.canvas.draw() - fig_help.show() - except Exception: - pass - - def _update_channels_epochs(event, params): """Function for changing the amount of channels and epochs per view.""" from matplotlib.collections import LineCollection diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index f8b97babbfe..62f2e9b06d2 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -77,6 +77,7 @@ def test_plot_raw(): fig.canvas.key_press_event('pagedown') fig.canvas.key_press_event('home') fig.canvas.key_press_event('end') + fig.canvas.key_press_event('?') fig.canvas.key_press_event('f11') fig.canvas.key_press_event('escape') # Color setting diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 848309afc7e..e1b60b59ee6 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -223,29 +223,53 @@ def _get_help_text(params): text.append(u'\u2191 : \n') text.append(u'- : \n') text.append(u'+ or = : \n') + text.append(u'Home : \n') + text.append(u'End : \n') + text.append(u'Page down : \n') + text.append(u'Page up : \n') + + text.append(u'F11 : \n') + text.append(u'? : \n') + text.append(u'Esc : \n\n') + text.append(u'Mouse controls\n') + text.append(u'click on data :\n') + text2.append('Navigate left\n') text2.append('Navigate right\n') + text2.append('Scale down\n') text2.append('Scale up\n') - if 'epochs' in params: - text.append(u'Home : \n') - text.append(u'End : \n') - text.append(u'Page down : \n') - text.append(u'Page up : \n') - text.append(u'o : \n') - text.append(u'F11 : \n') - text.append(u'? : \n') - text.append(u'Esc : \n\n') - text.append(u'Mouse controls\n') - text.append(u'click on main axes :\n') - text.append(u'right click :\n') - text2.append('Reduce the number of epochs per view\n') - text2.append('Increase the number of epochs per view\n') - text2.append('View settings (orig. view only)\n') - text2.append('Toggle full screen mode\n') - text2.append('Open help box\n') - text2.append('Quit\n\n\n') + text2.append('Toggle full screen mode\n') + text2.append('Open help box\n') + text2.append('Quit\n\n\n') + if 'raw' in params: + text2.insert(4, 'Reduce the time shown per view\n') + text2.insert(5, 'Increase the time shown per view\n') + text.append(u'click elsewhere in the plot :\n') + if 'ica' in params: + text.append(u'click component name :\n') + text2.insert(2, 'Navigate components down\n') + text2.insert(3, 'Navigate components up\n') + text2.insert(8, 'Reduce the number of components per view\n') + text2.insert(9, 'Increase the number of components per view\n') + text2.append('Mark bad channel\n') + text2.append('Vertical line at a time instant\n') + text2.append('Show topography for the component\n') + else: + text.append(u'click channel name :\n') + text2.insert(2, 'Navigate channels down\n') + text2.insert(3, 'Navigate channels up\n') + text2.insert(8, 'Reduce the number of channels per view\n') + text2.insert(9, 'Increase the number of channels per view\n') + text2.append('Mark bad channel\n') + text2.append('Vertical line at a time instant\n') + text2.append('Mark bad channel\n') + + elif 'epochs' in params: + text.append(u'right click :\n') + text2.insert(4, 'Reduce the number of epochs per view\n') + text2.insert(5, 'Increase the number of epochs per view\n') if 'ica' in params: text.append(u'click component name :\n') text2.insert(2, 'Navigate components down\n') @@ -263,11 +287,13 @@ def _get_help_text(params): text2.insert(9, 'Increase the number of channels per view\n') text.insert(10, u'b : \n') text2.insert(10, 'Toggle butterfly plot on/off\n') - text.append(u'middle click :\n') text2.append('Mark bad epoch\n') text2.append('Vertical line at a time instant\n') text2.append('Mark bad channel\n') + text.append(u'middle click :\n') text2.append('Show channel name (butterfly plot)\n') + text.insert(11, u'o : \n') + text2.insert(11, 'View settings (orig. view only)\n') return ''.join(text), ''.join(text2) @@ -461,18 +487,23 @@ def _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, fig = figure_nobar(facecolor=bgcolor, figsize=size) fig.canvas.set_window_title('mne_browse_raw') - ax = plt.subplot2grid((10, 10), (0, 0), colspan=9, rowspan=9) + ax = plt.subplot2grid((10, 10), (0, 1), colspan=8, rowspan=9) ax.set_title(title, fontsize=12) - ax_hscroll = plt.subplot2grid((10, 10), (9, 0), colspan=9) + ax_hscroll = plt.subplot2grid((10, 10), (9, 1), colspan=8) ax_hscroll.get_yaxis().set_visible(False) ax_hscroll.set_xlabel('Time (s)') ax_vscroll = plt.subplot2grid((10, 10), (0, 9), rowspan=9) ax_vscroll.set_axis_off() + ax_help_button = plt.subplot2grid((10, 10), (0, 0), colspan=1) + help_button = mpl.widgets.Button(ax_help_button, 'Help') + help_button.on_clicked(partial(_onclick_help, params=params)) # store these so they can be fixed on resize params['fig'] = fig params['ax'] = ax params['ax_hscroll'] = ax_hscroll params['ax_vscroll'] = ax_vscroll + params['ax_help_button'] = ax_help_button + params['help_button'] = help_button # populate vertical and horizontal scrollbars info = params['info'] @@ -516,9 +547,8 @@ def _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, params['ax_vertline'] = ax.plot([0, 0], ylim, color=vertline_color, zorder=-1)[0] params['ax_vertline'].ch_name = '' - params['vertline_t'] = ax_hscroll.text(0, 0.5, '', color=vertline_color, - verticalalignment='center', - horizontalalignment='right') + params['vertline_t'] = ax_hscroll.text(0, 1, '', color=vertline_color, + va='bottom', ha='right') params['ax_hscroll_vertline'] = ax_hscroll.plot([0, 0], [0, 1], color=vertline_color, zorder=1)[0] @@ -631,6 +661,8 @@ def _plot_raw_onkey(event, params): params['hsel_patch'].set_width(params['duration']) params['update_fun']() params['plot_fun']() + elif event.key == '?': + _onclick_help(event, params) elif event.key == 'f11': mng = plt.get_current_fig_manager() mng.full_screen_toggle() @@ -780,6 +812,39 @@ def _plot_raw_traces(params, inds, color, bad_color, event_lines=None, params['fig_proj'].canvas.draw() +def _onclick_help(event, params): + """Function for drawing help window""" + import matplotlib.pyplot as plt + text, text2 = _get_help_text(params) + + width = 6 + height = 5 + + fig_help = figure_nobar(figsize=(width, height), dpi=80) + fig_help.canvas.set_window_title('Help') + ax = plt.subplot2grid((8, 5), (0, 0), colspan=5) + ax.set_title('Keyboard shortcuts') + plt.axis('off') + ax1 = plt.subplot2grid((8, 5), (1, 0), rowspan=7, colspan=2) + ax1.set_yticklabels(list()) + plt.text(0.99, 1, text, fontname='STIXGeneral', va='top', weight='bold', + ha='right') + plt.axis('off') + + ax2 = plt.subplot2grid((8, 5), (1, 2), rowspan=7, colspan=3) + ax2.set_yticklabels(list()) + plt.text(0, 1, text2, fontname='STIXGeneral', va='top') + plt.axis('off') + + tight_layout(fig=fig_help) + # this should work for non-test cases + try: + fig_help.canvas.draw() + fig_help.show() + except Exception: + pass + + class ClickableImage(object): """ From a3781a1c6c562327dfb6725f0206efeb324941e1 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Fri, 10 Jul 2015 10:04:54 +0300 Subject: [PATCH 26/36] Added interactive mode for viewing ica topos. --- mne/viz/ica.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 9334f6b50d3..684ef45866b 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -529,6 +529,7 @@ def _plot_raw_components(ica, raw, picks, exclude, start, stop, show, title, block): """Function for plotting the ICA components as raw array.""" import matplotlib.pyplot as plt + plt.ion() color = _handle_default('color', (0., 0., 0.)) orig_data = ica._transform_raw(raw, 0, len(raw.times)) * 0.2 if picks is None: From 3682087f345c6412830058e067bf1732e625c854 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Fri, 10 Jul 2015 11:17:43 +0300 Subject: [PATCH 27/36] Tests. --- mne/viz/tests/test_epochs.py | 2 ++ mne/viz/tests/test_ica.py | 2 ++ mne/viz/tests/test_raw.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/mne/viz/tests/test_epochs.py b/mne/viz/tests/test_epochs.py index 7e826c647e9..e2692dd40c8 100644 --- a/mne/viz/tests/test_epochs.py +++ b/mne/viz/tests/test_epochs.py @@ -112,6 +112,8 @@ def test_plot_epochs(): fig = epochs.plot(trellis=False) fig.canvas.key_press_event('left') fig.canvas.key_press_event('right') + fig.canvas.scroll_event(0.5, 0.5, -0.5) # scroll down + fig.canvas.scroll_event(0.5, 0.5, 0.5) # scroll up fig.canvas.key_press_event('up') fig.canvas.key_press_event('down') fig.canvas.key_press_event('pageup') diff --git a/mne/viz/tests/test_ica.py b/mne/viz/tests/test_ica.py index fde26575464..deb42f16ff0 100644 --- a/mne/viz/tests/test_ica.py +++ b/mne/viz/tests/test_ica.py @@ -169,6 +169,7 @@ def test_plot_instance_components(): ax = fig.get_axes()[0] line = ax.lines[0] _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], 'data') + _fake_click(fig, ax, [-0.1, 0.9]) # click on y-label fig.canvas.key_press_event('escape') plt.close('all') epochs = _get_epochs() @@ -190,6 +191,7 @@ def test_plot_instance_components(): ax = fig.get_axes()[0] line = ax.lines[0] _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], 'data') + _fake_click(fig, ax, [-0.1, 0.9]) # click on y-label fig.canvas.key_press_event('escape') plt.close('all') diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index 62f2e9b06d2..bdb8e46d3d8 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -53,6 +53,8 @@ def test_plot_raw(): _fake_click(fig, fig.get_axes()[2], [0.5, 0.5]) # change channels _fake_click(fig, fig.get_axes()[3], [0.5, 0.5]) # open SSP window fig.canvas.button_press_event(1, 1, 1) # outside any axes + fig.canvas.scroll_event(0.5, 0.5, -0.5) # scroll down + fig.canvas.scroll_event(0.5, 0.5, 0.5) # scroll up # sadly these fail when no renderer is used (i.e., when using Agg): # ssp_fig = set(plt.get_fignums()) - set([fig.number]) # assert_equal(len(ssp_fig), 1) From 9fc33c69ec6ebfa4f14cc98bb2d947f5da6a2a85 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Fri, 10 Jul 2015 12:40:16 +0300 Subject: [PATCH 28/36] Reorganizing. Fix. --- mne/viz/ica.py | 6 +- mne/viz/raw.py | 172 +++++++++++++++++++++++++++++++++++++++++++++-- mne/viz/utils.py | 162 +------------------------------------------- 3 files changed, 170 insertions(+), 170 deletions(-) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 684ef45866b..a6900cd0322 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -12,10 +12,10 @@ import numpy as np -from .utils import tight_layout, _prepare_trellis, _prepare_mne_browse_raw +from .utils import tight_layout, _prepare_trellis, _select_bads from .utils import _layout_figure, _plot_raw_onscroll, _mouse_click -from .utils import _plot_raw_traces, _helper_raw_resize, _plot_raw_onkey -from .utils import _select_bads +from .utils import _helper_raw_resize, _plot_raw_onkey +from .raw import _prepare_mne_browse_raw, _plot_raw_traces from .epochs import _prepare_mne_browse_epochs from .evoked import _butterfly_on_button_press, _butterfly_onpick from .topomap import _prepare_topo_plot, plot_topomap diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 27a1636cabc..42667fc2db0 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -15,12 +15,12 @@ from ..externals.six import string_types from ..io.pick import pick_types from ..io.proj import setup_proj -from ..utils import verbose +from ..utils import verbose, get_config from ..time_frequency import compute_raw_psd from .utils import _toggle_options, _toggle_proj, tight_layout -from .utils import _layout_figure, _prepare_mne_browse_raw, _plot_raw_onkey -from .utils import _plot_raw_onscroll, _plot_raw_traces, _mouse_click -from .utils import _helper_raw_resize, _select_bads +from .utils import _layout_figure, _plot_raw_onkey, figure_nobar +from .utils import _plot_raw_onscroll, _mouse_click +from .utils import _helper_raw_resize, _select_bads, _onclick_help from ..defaults import _handle_default @@ -351,11 +351,11 @@ def _label_clicked(pos, params): if text in bads: bads.remove(text) color = vars(params['lines'][line_idx])['def_color'] - params['ax_vscroll'].patches[ch_idx + 1].set_color(color) + params['ax_vscroll'].patches[ch_idx].set_color(color) else: bads.append(text) color = params['bad_color'] - params['ax_vscroll'].patches[ch_idx + 1].set_color(color) + params['ax_vscroll'].patches[ch_idx].set_color(color) params['raw'].info['bads'] = bads _plot_update_raw_proj(params, None) @@ -502,3 +502,163 @@ def plot_raw_psd(raw, tmin=0., tmax=np.inf, fmin=0, fmax=np.inf, proj=False, if show is True: plt.show() return fig + + +def _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, + n_channels): + """Helper for setting up the mne_browse_raw window.""" + import matplotlib.pyplot as plt + import matplotlib as mpl + size = get_config('MNE_BROWSE_RAW_SIZE') + if size is not None: + size = size.split(',') + size = tuple([float(s) for s in size]) + + fig = figure_nobar(facecolor=bgcolor, figsize=size) + fig.canvas.set_window_title('mne_browse_raw') + ax = plt.subplot2grid((10, 10), (0, 1), colspan=8, rowspan=9) + ax.set_title(title, fontsize=12) + ax_hscroll = plt.subplot2grid((10, 10), (9, 1), colspan=8) + ax_hscroll.get_yaxis().set_visible(False) + ax_hscroll.set_xlabel('Time (s)') + ax_vscroll = plt.subplot2grid((10, 10), (0, 9), rowspan=9) + ax_vscroll.set_axis_off() + ax_help_button = plt.subplot2grid((10, 10), (0, 0), colspan=1) + help_button = mpl.widgets.Button(ax_help_button, 'Help') + help_button.on_clicked(partial(_onclick_help, params=params)) + # store these so they can be fixed on resize + params['fig'] = fig + params['ax'] = ax + params['ax_hscroll'] = ax_hscroll + params['ax_vscroll'] = ax_vscroll + params['ax_help_button'] = ax_help_button + params['help_button'] = help_button + + # populate vertical and horizontal scrollbars + info = params['info'] + for ci in range(len(info['ch_names'])): + this_color = (bad_color if info['ch_names'][inds[ci]] in info['bads'] + else color) + if isinstance(this_color, dict): + this_color = this_color[params['types'][inds[ci]]] + ax_vscroll.add_patch(mpl.patches.Rectangle((0, ci), 1, 1, + facecolor=this_color, + edgecolor=this_color)) + vsel_patch = mpl.patches.Rectangle((0, 0), 1, n_channels, alpha=0.5, + facecolor='w', edgecolor='w') + ax_vscroll.add_patch(vsel_patch) + params['vsel_patch'] = vsel_patch + hsel_patch = mpl.patches.Rectangle((params['t_start'], 0), + params['duration'], 1, edgecolor='k', + facecolor=(0.75, 0.75, 0.75), + alpha=0.25, linewidth=1, clip_on=False) + ax_hscroll.add_patch(hsel_patch) + params['hsel_patch'] = hsel_patch + ax_hscroll.set_xlim(0, params['n_times'] / float(info['sfreq'])) + n_ch = len(info['ch_names']) + ax_vscroll.set_ylim(n_ch, 0) + ax_vscroll.set_title('Ch.') + + # make shells for plotting traces + ylim = [n_channels * 2 + 1, 0] + offset = ylim[0] / n_channels + offsets = np.arange(n_channels) * offset + (offset / 2.) + ax.set_yticks(offsets) + ax.set_ylim(ylim) + ax.set_xlim(params['t_start'], params['t_start'] + params['duration'], + False) + + params['offsets'] = offsets + params['lines'] = [ax.plot([np.nan], antialiased=False, linewidth=0.5)[0] + for _ in range(n_ch)] + ax.set_yticklabels(['X' * max([len(ch) for ch in info['ch_names']])]) + vertline_color = (0., 0.75, 0.) + params['ax_vertline'] = ax.plot([0, 0], ylim, color=vertline_color, + zorder=-1)[0] + params['ax_vertline'].ch_name = '' + params['vertline_t'] = ax_hscroll.text(0, 1, '', color=vertline_color, + va='bottom', ha='right') + params['ax_hscroll_vertline'] = ax_hscroll.plot([0, 0], [0, 1], + color=vertline_color, + zorder=1)[0] + + +def _plot_raw_traces(params, inds, color, bad_color, event_lines=None, + event_color=None): + """Helper for plotting raw""" + lines = params['lines'] + info = params['info'] + n_channels = params['n_channels'] + params['bad_color'] = bad_color + # do the plotting + tick_list = [] + for ii in range(n_channels): + ch_ind = ii + params['ch_start'] + # let's be generous here and allow users to pass + # n_channels per view >= the number of traces available + if ii >= len(lines): + break + elif ch_ind < len(info['ch_names']): + # scale to fit + ch_name = info['ch_names'][inds[ch_ind]] + tick_list += [ch_name] + offset = params['offsets'][ii] + + # do NOT operate in-place lest this get screwed up + this_data = params['data'][inds[ch_ind]] * params['scale_factor'] + this_color = bad_color if ch_name in info['bads'] else color + this_z = -1 if ch_name in info['bads'] else 0 + if isinstance(this_color, dict): + this_color = this_color[params['types'][inds[ch_ind]]] + + # subtraction here gets corect orientation for flipped ylim + lines[ii].set_ydata(offset - this_data) + lines[ii].set_xdata(params['times']) + lines[ii].set_color(this_color) + lines[ii].set_zorder(this_z) + vars(lines[ii])['ch_name'] = ch_name + vars(lines[ii])['def_color'] = color[params['types'][inds[ch_ind]]] + else: + # "remove" lines + lines[ii].set_xdata([]) + lines[ii].set_ydata([]) + # deal with event lines + if params['event_times'] is not None: + # find events in the time window + event_times = params['event_times'] + mask = np.logical_and(event_times >= params['times'][0], + event_times <= params['times'][-1]) + event_times = event_times[mask] + event_nums = params['event_nums'][mask] + # plot them with appropriate colors + # go through the list backward so we end with -1, the catchall + used = np.zeros(len(event_times), bool) + ylim = params['ax'].get_ylim() + for ev_num, line in zip(sorted(event_color.keys())[::-1], + event_lines[::-1]): + mask = (event_nums == ev_num) if ev_num >= 0 else ~used + assert not np.any(used[mask]) + used[mask] = True + t = event_times[mask] + if len(t) > 0: + xs = list() + ys = list() + for tt in t: + xs += [tt, tt, np.nan] + ys += [0, ylim[0], np.nan] + line.set_xdata(xs) + line.set_ydata(ys) + else: + line.set_xdata([]) + line.set_ydata([]) + # finalize plot + params['ax'].set_xlim(params['times'][0], + params['times'][0] + params['duration'], False) + params['ax'].set_yticklabels(tick_list) + params['vsel_patch'].set_y(params['ch_start']) + params['fig'].canvas.draw() + # XXX This is a hack to make sure this figure gets drawn last + # so that when matplotlib goes to calculate bounds we don't get a + # CGContextRef error on the MacOSX backend :( + if params['fig_proj'] is not None: + params['fig_proj'].canvas.draw() diff --git a/mne/viz/utils.py b/mne/viz/utils.py index e1b60b59ee6..517f6645998 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -19,7 +19,7 @@ import numpy as np from ..io import show_fiff -from ..utils import verbose, get_config, set_config +from ..utils import verbose, set_config COLORS = ['b', 'g', 'r', 'c', 'm', 'y', 'k', '#473C8B', '#458B74', @@ -475,85 +475,6 @@ def figure_nobar(*args, **kwargs): return fig -def _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds, - n_channels): - """Helper for setting up the mne_browse_raw window.""" - import matplotlib.pyplot as plt - import matplotlib as mpl - size = get_config('MNE_BROWSE_RAW_SIZE') - if size is not None: - size = size.split(',') - size = tuple([float(s) for s in size]) - - fig = figure_nobar(facecolor=bgcolor, figsize=size) - fig.canvas.set_window_title('mne_browse_raw') - ax = plt.subplot2grid((10, 10), (0, 1), colspan=8, rowspan=9) - ax.set_title(title, fontsize=12) - ax_hscroll = plt.subplot2grid((10, 10), (9, 1), colspan=8) - ax_hscroll.get_yaxis().set_visible(False) - ax_hscroll.set_xlabel('Time (s)') - ax_vscroll = plt.subplot2grid((10, 10), (0, 9), rowspan=9) - ax_vscroll.set_axis_off() - ax_help_button = plt.subplot2grid((10, 10), (0, 0), colspan=1) - help_button = mpl.widgets.Button(ax_help_button, 'Help') - help_button.on_clicked(partial(_onclick_help, params=params)) - # store these so they can be fixed on resize - params['fig'] = fig - params['ax'] = ax - params['ax_hscroll'] = ax_hscroll - params['ax_vscroll'] = ax_vscroll - params['ax_help_button'] = ax_help_button - params['help_button'] = help_button - - # populate vertical and horizontal scrollbars - info = params['info'] - for ci in range(len(info['ch_names'])): - this_color = (bad_color if info['ch_names'][inds[ci]] in info['bads'] - else color) - if isinstance(this_color, dict): - this_color = this_color[params['types'][inds[ci]]] - ax_vscroll.add_patch(mpl.patches.Rectangle((0, ci), 1, 1, - facecolor=this_color, - edgecolor=this_color)) - vsel_patch = mpl.patches.Rectangle((0, 0), 1, n_channels, alpha=0.5, - facecolor='w', edgecolor='w') - ax_vscroll.add_patch(vsel_patch) - params['vsel_patch'] = vsel_patch - hsel_patch = mpl.patches.Rectangle((params['t_start'], 0), - params['duration'], 1, edgecolor='k', - facecolor=(0.75, 0.75, 0.75), - alpha=0.25, linewidth=1, clip_on=False) - ax_hscroll.add_patch(hsel_patch) - params['hsel_patch'] = hsel_patch - ax_hscroll.set_xlim(0, params['n_times'] / float(info['sfreq'])) - n_ch = len(info['ch_names']) - ax_vscroll.set_ylim(n_ch, 0) - ax_vscroll.set_title('Ch.') - - # make shells for plotting traces - ylim = [n_channels * 2 + 1, 0] - offset = ylim[0] / n_channels - offsets = np.arange(n_channels) * offset + (offset / 2.) - ax.set_yticks(offsets) - ax.set_ylim(ylim) - ax.set_xlim(params['t_start'], params['t_start'] + params['duration'], - False) - - params['offsets'] = offsets - params['lines'] = [ax.plot([np.nan], antialiased=False, linewidth=0.5)[0] - for _ in range(n_ch)] - ax.set_yticklabels(['X' * max([len(ch) for ch in info['ch_names']])]) - vertline_color = (0., 0.75, 0.) - params['ax_vertline'] = ax.plot([0, 0], ylim, color=vertline_color, - zorder=-1)[0] - params['ax_vertline'].ch_name = '' - params['vertline_t'] = ax_hscroll.text(0, 1, '', color=vertline_color, - va='bottom', ha='right') - params['ax_hscroll_vertline'] = ax_hscroll.plot([0, 0], [0, 1], - color=vertline_color, - zorder=1)[0] - - def _helper_raw_resize(event, params): """Helper for resizing""" size = ','.join([str(s) for s in params['fig'].get_size_inches()]) @@ -731,87 +652,6 @@ def f(x, y): return bads -def _plot_raw_traces(params, inds, color, bad_color, event_lines=None, - event_color=None): - """Helper for plotting raw""" - lines = params['lines'] - info = params['info'] - n_channels = params['n_channels'] - params['bad_color'] = bad_color - # do the plotting - tick_list = [] - for ii in range(n_channels): - ch_ind = ii + params['ch_start'] - # let's be generous here and allow users to pass - # n_channels per view >= the number of traces available - if ii >= len(lines): - break - elif ch_ind < len(info['ch_names']): - # scale to fit - ch_name = info['ch_names'][inds[ch_ind]] - tick_list += [ch_name] - offset = params['offsets'][ii] - - # do NOT operate in-place lest this get screwed up - this_data = params['data'][inds[ch_ind]] * params['scale_factor'] - this_color = bad_color if ch_name in info['bads'] else color - this_z = -1 if ch_name in info['bads'] else 0 - if isinstance(this_color, dict): - this_color = this_color[params['types'][inds[ch_ind]]] - - # subtraction here gets corect orientation for flipped ylim - lines[ii].set_ydata(offset - this_data) - lines[ii].set_xdata(params['times']) - lines[ii].set_color(this_color) - lines[ii].set_zorder(this_z) - vars(lines[ii])['ch_name'] = ch_name - vars(lines[ii])['def_color'] = color[params['types'][inds[ch_ind]]] - else: - # "remove" lines - lines[ii].set_xdata([]) - lines[ii].set_ydata([]) - # deal with event lines - if params['event_times'] is not None: - # find events in the time window - event_times = params['event_times'] - mask = np.logical_and(event_times >= params['times'][0], - event_times <= params['times'][-1]) - event_times = event_times[mask] - event_nums = params['event_nums'][mask] - # plot them with appropriate colors - # go through the list backward so we end with -1, the catchall - used = np.zeros(len(event_times), bool) - ylim = params['ax'].get_ylim() - for ev_num, line in zip(sorted(event_color.keys())[::-1], - event_lines[::-1]): - mask = (event_nums == ev_num) if ev_num >= 0 else ~used - assert not np.any(used[mask]) - used[mask] = True - t = event_times[mask] - if len(t) > 0: - xs = list() - ys = list() - for tt in t: - xs += [tt, tt, np.nan] - ys += [0, ylim[0], np.nan] - line.set_xdata(xs) - line.set_ydata(ys) - else: - line.set_xdata([]) - line.set_ydata([]) - # finalize plot - params['ax'].set_xlim(params['times'][0], - params['times'][0] + params['duration'], False) - params['ax'].set_yticklabels(tick_list) - params['vsel_patch'].set_y(params['ch_start']) - params['fig'].canvas.draw() - # XXX This is a hack to make sure this figure gets drawn last - # so that when matplotlib goes to calculate bounds we don't get a - # CGContextRef error on the MacOSX backend :( - if params['fig_proj'] is not None: - params['fig_proj'].canvas.draw() - - def _onclick_help(event, params): """Function for drawing help window""" import matplotlib.pyplot as plt From 5e817d9c69d30f37f2f65f307c9626c893343139 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Fri, 10 Jul 2015 13:25:08 +0300 Subject: [PATCH 29/36] Unselecting removes duplicate bads. Removed adding of bad channel manually from example. --- examples/plot_from_raw_to_epochs_to_evoked.py | 1 - mne/viz/epochs.py | 3 ++- mne/viz/raw.py | 3 ++- mne/viz/utils.py | 3 ++- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/plot_from_raw_to_epochs_to_evoked.py b/examples/plot_from_raw_to_epochs_to_evoked.py index de5c10afce4..4b697be1153 100644 --- a/examples/plot_from_raw_to_epochs_to_evoked.py +++ b/examples/plot_from_raw_to_epochs_to_evoked.py @@ -39,7 +39,6 @@ # Set up pick list: EEG + STI 014 - bad channels (modify to your needs) include = [] # or stim channels ['STI 014'] -raw.info['bads'] += ['EEG 053'] # bads + 1 more # pick EEG and MEG channels picks = mne.pick_types(raw.info, meg=True, eeg=True, stim=False, eog=True, diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index db0e3eaa89c..cc5fd74f353 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -1098,7 +1098,8 @@ def _pick_bad_channels(pos, params): return ch_idx = params['ch_start'] + line_idx if text in params['info']['bads']: - params['info']['bads'].remove(text) + while text in params['info']['bads']: + params['info']['bads'].remove(text) color = params['def_colors'][ch_idx] params['ax_vscroll'].patches[ch_idx + 1].set_color(color) else: diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 42667fc2db0..3f6e65b62db 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -349,7 +349,8 @@ def _label_clicked(pos, params): ch_idx = params['ch_start'] + line_idx bads = params['info']['bads'] if text in bads: - bads.remove(text) + while text in bads: # to make sure duplicates are removed + bads.remove(text) color = vars(params['lines'][line_idx])['def_color'] params['ax_vscroll'].patches[ch_idx].set_color(color) else: diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 517f6645998..6d4929190a6 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -638,7 +638,8 @@ def f(x, y): color = params['bad_color'] line.set_zorder(-1) else: - bads.pop(bads.index(this_chan)) + while this_chan in bads: + bads.remove(this_chan) color = vars(line)['def_color'] line.set_zorder(0) line.set_color(color) From 696f3694d4faa3a286a086bc1089e0d161407edf Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Mon, 13 Jul 2015 12:11:40 +0300 Subject: [PATCH 30/36] Component name format. Manual bad channel back to example. --- examples/plot_from_raw_to_epochs_to_evoked.py | 1 + mne/viz/ica.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/plot_from_raw_to_epochs_to_evoked.py b/examples/plot_from_raw_to_epochs_to_evoked.py index 4b697be1153..de5c10afce4 100644 --- a/examples/plot_from_raw_to_epochs_to_evoked.py +++ b/examples/plot_from_raw_to_epochs_to_evoked.py @@ -39,6 +39,7 @@ # Set up pick list: EEG + STI 014 - bad channels (modify to your needs) include = [] # or stim channels ['STI 014'] +raw.info['bads'] += ['EEG 053'] # bads + 1 more # pick EEG and MEG channels picks = mne.pick_types(raw.info, meg=True, eeg=True, stim=False, eog=True, diff --git a/mne/viz/ica.py b/mne/viz/ica.py index a6900cd0322..9c9d6dc8c58 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -537,7 +537,7 @@ def _plot_raw_components(ica, raw, picks, exclude, start, stop, show, title, types = np.repeat('misc', len(picks)) picks = sorted(picks) - c_names = ['ICA ' + str(x) for x in range(len(orig_data))] + c_names = ['ICA %03d' % x for x in range(len(orig_data))] if title is None: title = 'ICA components' info = create_info([c_names[x] for x in picks], raw.info['sfreq']) @@ -622,7 +622,7 @@ def _plot_epoch_components(ica, epochs, picks, exclude, start, stop, show, import matplotlib.pyplot as plt plt.ion() # Turn interactive mode on to avoid warnings. data = ica._transform_epochs(epochs, concatenate=True) - c_names = ['ICA ' + str(x) for x in range(ica.n_components_)] + c_names = ['ICA %03d' % x for x in range(ica.n_components_)] scalings = {'misc': 5.0} info = create_info(ch_names=c_names, sfreq=epochs.info['sfreq']) info['projs'] = list() From 71daaeb6f3da83a1bb96c04b93a620dfc1725aa8 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Mon, 13 Jul 2015 13:07:44 +0300 Subject: [PATCH 31/36] Fix to exclusion. --- mne/viz/ica.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 9c9d6dc8c58..7e0bea05693 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -106,7 +106,8 @@ def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None, if exclude is None: exclude = ica.exclude - + else: + exclude = np.union1d(ica.exclude, exclude) if isinstance(inst, _BaseRaw): fig = _plot_raw_components(ica, inst, picks, exclude, start=start, stop=stop, show=show, title=title, @@ -611,8 +612,8 @@ def _pick_bads(event, params): def _close_event(events, params): """Function for excluding the selected components on close.""" info = params['info'] - picks = params['picks'] - exclude = [picks[info['ch_names'].index(x)] for x in info['bads']] + c_names = ['ICA %03d' % x for x in range(len(params['orig_data']))] + exclude = [c_names.index(x) for x in info['bads']] params['ica'].exclude = exclude From ec50bd4ea41ab2770883d5cfb427d1c9ee61be2c Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Mon, 13 Jul 2015 15:48:06 +0300 Subject: [PATCH 32/36] Fix. --- mne/viz/ica.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 7e0bea05693..ca63a5fa745 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -106,7 +106,7 @@ def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None, if exclude is None: exclude = ica.exclude - else: + elif len(ica.exclude) > 0: exclude = np.union1d(ica.exclude, exclude) if isinstance(inst, _BaseRaw): fig = _plot_raw_components(ica, inst, picks, exclude, start=start, From f8d234f5a9dc1467768d0faeea7f764c3db3b276 Mon Sep 17 00:00:00 2001 From: dengemann Date: Mon, 13 Jul 2015 18:31:45 +0200 Subject: [PATCH 33/36] fixes fixes yeah --- examples/preprocessing/plot_ica_from_raw.py | 2 +- mne/viz/__init__.py | 2 +- mne/viz/ica.py | 26 ++++++++++----------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/examples/preprocessing/plot_ica_from_raw.py b/examples/preprocessing/plot_ica_from_raw.py index 1a8c96644dc..22d134b1c30 100644 --- a/examples/preprocessing/plot_ica_from_raw.py +++ b/examples/preprocessing/plot_ica_from_raw.py @@ -75,7 +75,7 @@ show_picks = np.abs(scores).argsort()[::-1][:5] -ica.plot_sources(raw, show_picks, exclude=ecg_inds, title=title % 'eog') +ica.plot_sources(raw, show_picks, exclude=eog_inds, title=title % 'eog') ica.plot_components(eog_inds, title=title % 'eog', colorbar=True) eog_inds = eog_inds[:n_max_eog] diff --git a/mne/viz/__init__.py b/mne/viz/__init__.py index d6475014dce..c7f3df567bf 100644 --- a/mne/viz/__init__.py +++ b/mne/viz/__init__.py @@ -19,6 +19,6 @@ plot_epochs_trellis, _drop_log_stats, plot_epochs_psd) from .raw import plot_raw, plot_raw_psd from .ica import plot_ica_scores, plot_ica_sources, plot_ica_overlay -from .ica import _plot_raw_components, _plot_epoch_components +from .ica import _plot_sources_raw, _plot_sources_epochs from .montage import plot_montage from .decoding import plot_gat_matrix, plot_gat_times diff --git a/mne/viz/ica.py b/mne/viz/ica.py index ca63a5fa745..3418623434d 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -109,13 +109,13 @@ def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None, elif len(ica.exclude) > 0: exclude = np.union1d(ica.exclude, exclude) if isinstance(inst, _BaseRaw): - fig = _plot_raw_components(ica, inst, picks, exclude, start=start, + fig = _plot_sources_raw(ica, inst, picks, exclude, start=start, + stop=stop, show=show, title=title, + block=block) + elif isinstance(inst, _BaseEpochs): + fig = _plot_sources_epochs(ica, inst, picks, exclude, start=start, stop=stop, show=show, title=title, block=block) - elif isinstance(inst, _BaseEpochs): - fig = _plot_epoch_components(ica, inst, picks, exclude, start=start, - stop=stop, show=show, title=title, - block=block) elif isinstance(inst, Evoked): sources = ica.get_sources(inst) if start is not None or stop is not None: @@ -248,7 +248,7 @@ def _plot_ica_sources_evoked(evoked, picks, exclude, title, show): ax.set_xlim(times[[0, -1]]) ax.set_xlabel('Time (ms)') ax.set_ylabel('(NA)') - if exclude: + if len(exclude) > 0: plt.legend(loc='best') tight_layout(fig=fig) @@ -329,7 +329,7 @@ def plot_ica_scores(ica, scores, exclude=None, axhline=None, plt.suptitle(title) for this_scores, ax in zip(scores, axes): if len(my_range) != len(this_scores): - raise ValueError('The length ofr `scores` must equal the ' + raise ValueError('The length of `scores` must equal the ' 'number of ICA components.') ax.bar(my_range, this_scores, color='w') for excl in exclude: @@ -526,8 +526,8 @@ def _plot_ica_overlay_evoked(evoked, evoked_cln, title, show): return fig -def _plot_raw_components(ica, raw, picks, exclude, start, stop, show, title, - block): +def _plot_sources_raw(ica, raw, picks, exclude, start, stop, show, title, + block): """Function for plotting the ICA components as raw array.""" import matplotlib.pyplot as plt plt.ion() @@ -535,8 +535,8 @@ def _plot_raw_components(ica, raw, picks, exclude, start, stop, show, title, orig_data = ica._transform_raw(raw, 0, len(raw.times)) * 0.2 if picks is None: picks = range(len(orig_data)) - types = np.repeat('misc', len(picks)) - picks = sorted(picks) + types = ['misc' for _ in picks] + picks = list(sorted(picks)) c_names = ['ICA %03d' % x for x in range(len(orig_data))] if title is None: @@ -617,8 +617,8 @@ def _close_event(events, params): params['ica'].exclude = exclude -def _plot_epoch_components(ica, epochs, picks, exclude, start, stop, show, - title, block): +def _plot_sources_epochs(ica, epochs, picks, exclude, start, stop, show, + title, block): """Function for plotting the components as epochs.""" import matplotlib.pyplot as plt plt.ion() # Turn interactive mode on to avoid warnings. From 95e0ac1e2d7f73ae57e037387cd94e3af48451f0 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Tue, 14 Jul 2015 11:53:07 +0300 Subject: [PATCH 34/36] Fix to raw sources. --- mne/viz/ica.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 3418623434d..df67f06ae6f 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -537,7 +537,7 @@ def _plot_sources_raw(ica, raw, picks, exclude, start, stop, show, title, picks = range(len(orig_data)) types = ['misc' for _ in picks] picks = list(sorted(picks)) - + data = [orig_data[pick] for pick in picks] c_names = ['ICA %03d' % x for x in range(len(orig_data))] if title is None: title = 'ICA components' @@ -556,7 +556,8 @@ def _plot_sources_raw(ica, raw, picks, exclude, start, stop, show, title, times = raw.times[0:t_end] bad_color = (1., 0., 0.) inds = range(len(picks)) - params = dict(raw=raw, orig_data=orig_data, data=orig_data[:, 0:t_end], + data = np.array(data) + params = dict(raw=raw, orig_data=data, data=data[:, 0:t_end], ch_start=0, t_start=start, info=info, duration=duration, ica=ica, n_channels=20, times=times, types=types, n_times=raw.n_times, bad_color=bad_color, picks=picks) @@ -612,7 +613,7 @@ def _pick_bads(event, params): def _close_event(events, params): """Function for excluding the selected components on close.""" info = params['info'] - c_names = ['ICA %03d' % x for x in range(len(params['orig_data']))] + c_names = ['ICA %03d' % x for x in range(params['ica'].n_components_)] exclude = [c_names.index(x) for x in info['bads']] params['ica'].exclude = exclude From 00f65f56e1f9fc73d194f1ba21eb60a5020aac62 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Tue, 14 Jul 2015 12:19:34 +0300 Subject: [PATCH 35/36] Fix to channels per view. --- mne/viz/ica.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index df67f06ae6f..a7805ea8967 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -557,11 +557,13 @@ def _plot_sources_raw(ica, raw, picks, exclude, start, stop, show, title, bad_color = (1., 0., 0.) inds = range(len(picks)) data = np.array(data) + n_channels = min([20, len(picks)]) params = dict(raw=raw, orig_data=data, data=data[:, 0:t_end], ch_start=0, t_start=start, info=info, duration=duration, - ica=ica, n_channels=20, times=times, types=types, + ica=ica, n_channels=n_channels, times=times, types=types, n_times=raw.n_times, bad_color=bad_color, picks=picks) - _prepare_mne_browse_raw(params, title, 'w', color, bad_color, inds, 20) + _prepare_mne_browse_raw(params, title, 'w', color, bad_color, inds, + n_channels) params['scale_factor'] = 1.0 params['plot_fun'] = partial(_plot_raw_traces, params=params, inds=inds, color=color, bad_color=bad_color) From 57cd76bf3d7e22fa84c772920a8bc3b4104d79a9 Mon Sep 17 00:00:00 2001 From: Jaakko Leppakangas Date: Tue, 14 Jul 2015 15:07:17 +0300 Subject: [PATCH 36/36] Small fixes. --- mne/viz/ica.py | 4 ++-- mne/viz/raw.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mne/viz/ica.py b/mne/viz/ica.py index a7805ea8967..7b748f7919c 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -555,7 +555,7 @@ def _plot_sources_raw(ica, raw, picks, exclude, start, stop, show, title, t_end = int(duration * raw.info['sfreq']) times = raw.times[0:t_end] bad_color = (1., 0., 0.) - inds = range(len(picks)) + inds = list(range(len(picks))) data = np.array(data) n_channels = min([20, len(picks)]) params = dict(raw=raw, orig_data=data, data=data[:, 0:t_end], @@ -634,7 +634,7 @@ def _plot_sources_epochs(ica, epochs, picks, exclude, start, stop, show, if title is None: title = 'ICA components' if picks is None: - picks = range(len(c_names)) + picks = list(range(len(c_names))) if start is None: start = 0 if stop is None: diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 3f6e65b62db..e4bdf839656 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -592,7 +592,7 @@ def _plot_raw_traces(params, inds, color, bad_color, event_lines=None, n_channels = params['n_channels'] params['bad_color'] = bad_color # do the plotting - tick_list = [] + tick_list = list() for ii in range(n_channels): ch_ind = ii + params['ch_start'] # let's be generous here and allow users to pass