diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 2872933f865..8e1c842ac49 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -67,6 +67,8 @@ Changelog - Add option for ``first_samp`` in :func:`mne.make_fixed_length_events` by `Jon Houck`_ + - Add ability to auto-scale channel types for `mne.viz.plot_raw` and `mne.viz.plot_epochs` and corresponding object plotting methods by `Chris Holdgraf`_ + BUG ~~~ diff --git a/examples/io/plot_objects_from_arrays.py b/examples/io/plot_objects_from_arrays.py index 64be6607b30..1e1f90619b0 100644 --- a/examples/io/plot_objects_from_arrays.py +++ b/examples/io/plot_objects_from_arrays.py @@ -51,6 +51,11 @@ raw.plot(n_channels=4, scalings=scalings, title='Data from arrays', show=True, block=True) +# It is also possible to auto-compute scalings +scalings = 'auto' # Could also pass a dictionary with some value == 'auto' +raw.plot(n_channels=4, scalings=scalings, title='Auto-scaled Data from arrays', + show=True, block=True) + ############################################################################### # EpochsArray @@ -66,7 +71,6 @@ epochs_data = np.array([[sin[:700], cos[:700]], [sin[1000:1700], cos[1000:1700]], [sin[1800:2500], cos[1800:2500]]]) -epochs_data *= 1e-12 # Scale to match usual magnetometer amplitudes. ch_names = ['sin', 'cos'] ch_types = ['mag', 'mag'] @@ -77,7 +81,7 @@ picks = mne.pick_types(info, meg=True, eeg=False, misc=False) -epochs.plot(picks=picks, show=True, block=True) +epochs.plot(picks=picks, scalings='auto', show=True, block=True) ############################################################################### diff --git a/mne/epochs.py b/mne/epochs.py index 944b28207ce..fb3869fca4b 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -814,9 +814,17 @@ def plot(self, picks=None, scalings=None, show=True, Channels to be included. If None only good data channels are used. Defaults to None scalings : dict | None - Scale factors for the traces. If None, defaults to - ``dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4, - emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1, resp=1, chpi=1e-4)``. + Scaling factors for the traces. If any fields in scalings are + 'auto', the scaling factor is set to match the 99.5th percentile of + a subset of the corresponding data. If scalings == 'auto', all + scalings fields are set to 'auto'. If any fields are 'auto' and + data is not preloaded, a subset of epochs up to 100mb will be + loaded. If None, defaults to:: + + dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4, + emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1, resp=1, + chpi=1e-4) + show : bool Whether to show the figure or not. block : bool diff --git a/mne/io/base.py b/mne/io/base.py index 609d506e48b..bea66d12289 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -1404,7 +1404,12 @@ def plot(self, events=None, duration=10.0, start=0.0, n_channels=20, event_color : color object Color to use for events. scalings : dict | None - Scale factors for the traces. If None, defaults to:: + Scaling factors for the traces. If any fields in scalings are + 'auto', the scaling factor is set to match the 99.5th percentile of + a subset of the corresponding data. If scalings == 'auto', all + scalings fields are set to 'auto'. If any fields are 'auto' and + data is not preloaded, a subset of times up to 100mb will be + loaded. If None, defaults to:: dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4, emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1, diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index 3f9b0319f62..25854583f8e 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -21,7 +21,8 @@ from ..time_frequency import psd_multitaper from .utils import (tight_layout, figure_nobar, _toggle_proj, _toggle_options, _layout_figure, _setup_vmin_vmax, _channels_changed, - _plot_raw_onscroll, _onclick_help, plt_show) + _plot_raw_onscroll, _onclick_help, plt_show, + _compute_scalings) from ..defaults import _handle_default @@ -364,8 +365,12 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, picks : array-like of int | None Channels to be included. If None only good data channels are used. Defaults to None - scalings : dict | None - Scale factors for the traces. If None, defaults to:: + scalings : dict | 'auto' | None + Scaling factors for the traces. If any fields in scalings are 'auto', + the scaling factor is set to match the 99.5th percentile of a subset of + the corresponding data. If scalings == 'auto', all scalings fields are + set to 'auto'. If any fields are 'auto' and data is not preloaded, + a subset of epochs up to 100mb will be loaded. If None, defaults to:: dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4, emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1, resp=1, chpi=1e-4) @@ -400,6 +405,7 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, with ``b`` key. Right mouse click adds a vertical line to the plot. """ epochs.drop_bad() + scalings = _compute_scalings(scalings, epochs) scalings = _handle_default('scalings_plot_raw', scalings) projs = epochs.info['projs'] diff --git a/mne/viz/raw.py b/mne/viz/raw.py index a6a0a6d1253..09724e68724 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -23,7 +23,7 @@ _layout_figure, _plot_raw_onkey, figure_nobar, _plot_raw_onscroll, _mouse_click, plt_show, _helper_raw_resize, _select_bads, _onclick_help, - _setup_browser_offsets) + _setup_browser_offsets, _compute_scalings) from ..defaults import _handle_default from ..annotations import _onset_to_seconds @@ -117,7 +117,11 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20, ``{event_number: color}`` pairings. Use ``event_number==-1`` for any event numbers in the events list that are not in the dictionary. scalings : dict | None - Scale factors for the traces. If None, defaults to:: + Scaling factors for the traces. If any fields in scalings are 'auto', + the scaling factor is set to match the 99.5th percentile of a subset of + the corresponding data. If scalings == 'auto', all scalings fields are + set to 'auto'. If any fields are 'auto' and data is not preloaded, a + subset of times up to 100mb will be loaded. If None, defaults to:: dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4, emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1, @@ -177,6 +181,7 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20, import matplotlib as mpl from scipy.signal import butter color = _handle_default('color', color) + scalings = _compute_scalings(scalings, raw) scalings = _handle_default('scalings_plot_raw', scalings) if clipping is not None and clipping not in ('clamp', 'transparent'): diff --git a/mne/viz/tests/test_utils.py b/mne/viz/tests/test_utils.py index 3a8b69dbdf6..8cb99d5224f 100644 --- a/mne/viz/tests/test_utils.py +++ b/mne/viz/tests/test_utils.py @@ -8,9 +8,12 @@ from nose.tools import assert_true, assert_raises from numpy.testing import assert_allclose -from mne.viz.utils import compare_fiff, _fake_click +from mne.viz.utils import compare_fiff, _fake_click, _compute_scalings from mne.viz import ClickableImage, add_background_image, mne_analyze_colormap from mne.utils import run_tests_if_main +from mne.io import read_raw_fif +from mne.event import read_events +from mne.epochs import Epochs # Set our plotters to test mode import matplotlib @@ -21,6 +24,7 @@ base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data') raw_fname = op.join(base_dir, 'test_raw.fif') cov_fname = op.join(base_dir, 'test-cov.fif') +ev_fname = op.join(base_dir, 'test_raw-eve.fif') def test_mne_analyze_colormap(): @@ -85,4 +89,31 @@ def test_add_background_image(): assert_true(ax.get_aspect() == 'auto') +def test_auto_scale(): + """Test auto-scaling of channels for quick plotting.""" + raw = read_raw_fif(raw_fname, preload=False) + ev = read_events(ev_fname) + epochs = Epochs(raw, ev) + rand_data = np.random.randn(10, 100) + + for inst in [raw, epochs]: + scale_grad = 1e10 + scalings_def = dict([('eeg', 'auto'), ('grad', scale_grad)]) + + # Test for wrong inputs + assert_raises(ValueError, inst.plot, scalings='foo') + assert_raises(ValueError, _compute_scalings, 'foo', inst) + + # Make sure compute_scalings doesn't change anything not auto + scalings_new = _compute_scalings(scalings_def, inst) + assert_true(scale_grad == scalings_new['grad']) + assert_true(scalings_new['eeg'] != 'auto') + + assert_raises(ValueError, _compute_scalings, scalings_def, rand_data) + epochs = epochs[0].load_data() + epochs.pick_types(eeg=True, meg=False, copy=False) + assert_raises(ValueError, _compute_scalings, + dict(grad='auto'), epochs) + + run_tests_if_main() diff --git a/mne/viz/utils.py b/mne/viz/utils.py index cb1c6b3f1e1..e17b436ba0f 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -16,6 +16,7 @@ import webbrowser import tempfile import numpy as np +from copy import deepcopy from ..channels.layout import _auto_topomap_coords from ..channels.channels import _contains_ch_type @@ -797,7 +798,7 @@ def to_layout(self, **kwargs): **kwargs : dict Arguments are passed to generate_2d_layout """ - from mne.channels.layout import generate_2d_layout + from ..channels.layout import generate_2d_layout coords = np.array(self.coords) lt = generate_2d_layout(coords, bg_image=self.imdata, **kwargs) return lt @@ -1032,3 +1033,77 @@ def _plot_sensors(pos, colors, ch_names, title, show_names, show): fig.suptitle(title) plt_show(show) return fig + + +def _compute_scalings(scalings, inst): + """Compute scalings for each channel type automatically. + + Parameters + ---------- + scalings : dict + The scalings for each channel type. If any values are + 'auto', this will automatically compute a reasonable + scaling for that channel type. Any values that aren't + 'auto' will not be changed. + inst : instance of Raw or Epochs + The data for which you want to compute scalings. If data + is not preloaded, this will read a subset of times / epochs + up to 100mb in size in order to compute scalings. + + Returns + ------- + scalings : dict + A scalings dictionary with updated values + """ + from ..io.base import _BaseRaw + from ..io.pick import _picks_by_type + from ..epochs import _BaseEpochs + if not isinstance(inst, (_BaseRaw, _BaseEpochs)): + raise ValueError('Must supply either Raw or Epochs') + if scalings is None: + # If scalings is None just return it and do nothing + return scalings + + ch_types = _picks_by_type(inst.info) + unique_ch_types = [i_type[0] for i_type in ch_types] + if scalings == 'auto': + # If we want to auto-compute everything + scalings = dict((i_type, 'auto') for i_type in unique_ch_types) + if not isinstance(scalings, dict): + raise ValueError('scalings must be a dictionary of ch_type: val pairs,' + ' not type %s ' % type(scalings)) + scalings = deepcopy(scalings) + + if inst.preload is False: + if isinstance(inst, _BaseRaw): + # Load a window of data from the center up to 100mb in size + n_times = 1e8 // (len(inst.ch_names) * 8) + n_times = np.clip(n_times, 1, inst.n_times) + n_secs = n_times / float(inst.info['sfreq']) + time_middle = np.mean(inst.times) + tmin = np.clip(time_middle - n_secs / 2., inst.times.min(), None) + tmax = np.clip(time_middle + n_secs / 2., None, inst.times.max()) + data = inst._read_segment(tmin, tmax) + elif isinstance(inst, _BaseEpochs): + # Load a random subset of epochs up to 100mb in size + n_epochs = 1e8 // (len(inst.ch_names) * len(inst.times) * 8) + n_epochs = int(np.clip(n_epochs, 1, len(inst))) + ixs_epochs = np.random.choice(range(len(inst)), n_epochs, False) + inst = inst.copy()[ixs_epochs].load_data() + else: + data = inst._data + if isinstance(inst, _BaseEpochs): + data = inst._data.reshape([len(inst.ch_names), -1]) + + # Iterate through ch types and update scaling if ' auto' + for key, value in scalings.items(): + if value != 'auto': + continue + if key not in unique_ch_types: + raise ValueError("Sensor {0} doesn't exist in data".format(key)) + this_ixs = [i_ixs for key_, i_ixs in ch_types if key_ == key] + this_data = data[this_ixs] + scale_factor = np.percentile(this_data.ravel(), [0.5, 99.5]) + scale_factor = np.max(np.abs(scale_factor)) + scalings[key] = scale_factor + return scalings diff --git a/tutorials/plot_visualize_raw.py b/tutorials/plot_visualize_raw.py index 06934f01839..4d1ce89919d 100644 --- a/tutorials/plot_visualize_raw.py +++ b/tutorials/plot_visualize_raw.py @@ -35,10 +35,12 @@ # color coded gray. By clicking the lines or channel names on the left, you can # mark or unmark a bad channel interactively. You can use +/- keys to adjust # the scale (also = works for magnifying the data). Note that the initial -# scaling factors can be set with parameter ``scalings``. With -# ``pageup/pagedown`` and ``home/end`` keys you can adjust the amount of data -# viewed at once. To see all the interactive features, hit ``?`` or click -# ``help`` in the lower left corner of the browser window. +# scaling factors can be set with parameter ``scalings``. If you don't know the +# scaling factor for channels, you can automatically set them by passing +# scalings='auto'. With ``pageup/pagedown`` and ``home/end`` keys you can +# adjust the amount of data viewed at once. To see all the interactive +# features, hit ``?`` or click ``help`` in the lower left corner of the +# browser window. # # We read the events from a file and passed it as a parameter when calling the # method. The events are plotted as vertical lines so you can see how they