Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ Changelog

- Add option for ``first_samp`` in :func:`mne.make_fixed_length_events` by `Jon Houck`_

- Add ability to auto-scale channel types for `mne.viz.plot_raw` and `mne.viz.plot_epochs` and corresponding object plotting methods by `Chris Holdgraf`_
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:func: in front (see other lines above for examples)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 then +1 for merge


BUG
~~~

Expand Down
8 changes: 6 additions & 2 deletions examples/io/plot_objects_from_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@
raw.plot(n_channels=4, scalings=scalings, title='Data from arrays',
show=True, block=True)

# It is also possible to auto-compute scalings
scalings = 'auto' # Could also pass a dictionary with some value == 'auto'
raw.plot(n_channels=4, scalings=scalings, title='Auto-scaled Data from arrays',
show=True, block=True)


###############################################################################
# EpochsArray
Expand All @@ -66,7 +71,6 @@
epochs_data = np.array([[sin[:700], cos[:700]],
[sin[1000:1700], cos[1000:1700]],
[sin[1800:2500], cos[1800:2500]]])
epochs_data *= 1e-12 # Scale to match usual magnetometer amplitudes.

ch_names = ['sin', 'cos']
ch_types = ['mag', 'mag']
Expand All @@ -77,7 +81,7 @@

picks = mne.pick_types(info, meg=True, eeg=False, misc=False)

epochs.plot(picks=picks, show=True, block=True)
epochs.plot(picks=picks, scalings='auto', show=True, block=True)


###############################################################################
Expand Down
14 changes: 11 additions & 3 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,9 +814,17 @@ def plot(self, picks=None, scalings=None, show=True,
Channels to be included. If None only good data channels are used.
Defaults to None
scalings : dict | None
Scale factors for the traces. If None, defaults to
``dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1, resp=1, chpi=1e-4)``.
Scaling factors for the traces. If any fields in scalings are
'auto', the scaling factor is set to match the 99.5th percentile of
a subset of the corresponding data. If scalings == 'auto', all
scalings fields are set to 'auto'. If any fields are 'auto' and
data is not preloaded, a subset of epochs up to 100mb will be
loaded. If None, defaults to::

dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1, resp=1,
chpi=1e-4)

show : bool
Whether to show the figure or not.
block : bool
Expand Down
7 changes: 6 additions & 1 deletion mne/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1404,7 +1404,12 @@ def plot(self, events=None, duration=10.0, start=0.0, n_channels=20,
event_color : color object
Color to use for events.
scalings : dict | None
Scale factors for the traces. If None, defaults to::
Scaling factors for the traces. If any fields in scalings are
'auto', the scaling factor is set to match the 99.5th percentile of
a subset of the corresponding data. If scalings == 'auto', all
scalings fields are set to 'auto'. If any fields are 'auto' and
data is not preloaded, a subset of times up to 100mb will be
loaded. If None, defaults to::

dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1,
Expand Down
12 changes: 9 additions & 3 deletions mne/viz/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from ..time_frequency import psd_multitaper
from .utils import (tight_layout, figure_nobar, _toggle_proj, _toggle_options,
_layout_figure, _setup_vmin_vmax, _channels_changed,
_plot_raw_onscroll, _onclick_help, plt_show)
_plot_raw_onscroll, _onclick_help, plt_show,
_compute_scalings)
from ..defaults import _handle_default


Expand Down Expand Up @@ -364,8 +365,12 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20,
picks : array-like of int | None
Channels to be included. If None only good data channels are used.
Defaults to None
scalings : dict | None
Scale factors for the traces. If None, defaults to::
scalings : dict | 'auto' | None
Scaling factors for the traces. If any fields in scalings are 'auto',
the scaling factor is set to match the 99.5th percentile of a subset of
the corresponding data. If scalings == 'auto', all scalings fields are
set to 'auto'. If any fields are 'auto' and data is not preloaded,
a subset of epochs up to 100mb will be loaded. If None, defaults to::

dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1, resp=1, chpi=1e-4)
Expand Down Expand Up @@ -400,6 +405,7 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20,
with ``b`` key. Right mouse click adds a vertical line to the plot.
"""
epochs.drop_bad()
scalings = _compute_scalings(scalings, epochs)
scalings = _handle_default('scalings_plot_raw', scalings)

projs = epochs.info['projs']
Expand Down
9 changes: 7 additions & 2 deletions mne/viz/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
_layout_figure, _plot_raw_onkey, figure_nobar,
_plot_raw_onscroll, _mouse_click, plt_show,
_helper_raw_resize, _select_bads, _onclick_help,
_setup_browser_offsets)
_setup_browser_offsets, _compute_scalings)
from ..defaults import _handle_default
from ..annotations import _onset_to_seconds

Expand Down Expand Up @@ -117,7 +117,11 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
``{event_number: color}`` pairings. Use ``event_number==-1`` for
any event numbers in the events list that are not in the dictionary.
scalings : dict | None
Scale factors for the traces. If None, defaults to::
Scaling factors for the traces. If any fields in scalings are 'auto',
the scaling factor is set to match the 99.5th percentile of a subset of
the corresponding data. If scalings == 'auto', all scalings fields are
set to 'auto'. If any fields are 'auto' and data is not preloaded, a
subset of times up to 100mb will be loaded. If None, defaults to::

dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1,
Expand Down Expand Up @@ -177,6 +181,7 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
import matplotlib as mpl
from scipy.signal import butter
color = _handle_default('color', color)
scalings = _compute_scalings(scalings, raw)
scalings = _handle_default('scalings_plot_raw', scalings)

if clipping is not None and clipping not in ('clamp', 'transparent'):
Expand Down
33 changes: 32 additions & 1 deletion mne/viz/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from nose.tools import assert_true, assert_raises
from numpy.testing import assert_allclose

from mne.viz.utils import compare_fiff, _fake_click
from mne.viz.utils import compare_fiff, _fake_click, _compute_scalings
from mne.viz import ClickableImage, add_background_image, mne_analyze_colormap
from mne.utils import run_tests_if_main
from mne.io import read_raw_fif
from mne.event import read_events
from mne.epochs import Epochs

# Set our plotters to test mode
import matplotlib
Expand All @@ -21,6 +24,7 @@
base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
raw_fname = op.join(base_dir, 'test_raw.fif')
cov_fname = op.join(base_dir, 'test-cov.fif')
ev_fname = op.join(base_dir, 'test_raw-eve.fif')


def test_mne_analyze_colormap():
Expand Down Expand Up @@ -85,4 +89,31 @@ def test_add_background_image():
assert_true(ax.get_aspect() == 'auto')


def test_auto_scale():
"""Test auto-scaling of channels for quick plotting."""
raw = read_raw_fif(raw_fname, preload=False)
ev = read_events(ev_fname)
epochs = Epochs(raw, ev)
rand_data = np.random.randn(10, 100)

for inst in [raw, epochs]:
scale_grad = 1e10
scalings_def = dict([('eeg', 'auto'), ('grad', scale_grad)])

# Test for wrong inputs
assert_raises(ValueError, inst.plot, scalings='foo')
assert_raises(ValueError, _compute_scalings, 'foo', inst)

# Make sure compute_scalings doesn't change anything not auto
scalings_new = _compute_scalings(scalings_def, inst)
assert_true(scale_grad == scalings_new['grad'])
assert_true(scalings_new['eeg'] != 'auto')

assert_raises(ValueError, _compute_scalings, scalings_def, rand_data)
epochs = epochs[0].load_data()
epochs.pick_types(eeg=True, meg=False, copy=False)
assert_raises(ValueError, _compute_scalings,
dict(grad='auto'), epochs)


run_tests_if_main()
77 changes: 76 additions & 1 deletion mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import webbrowser
import tempfile
import numpy as np
from copy import deepcopy

from ..channels.layout import _auto_topomap_coords
from ..channels.channels import _contains_ch_type
Expand Down Expand Up @@ -797,7 +798,7 @@ def to_layout(self, **kwargs):
**kwargs : dict
Arguments are passed to generate_2d_layout
"""
from mne.channels.layout import generate_2d_layout
from ..channels.layout import generate_2d_layout
coords = np.array(self.coords)
lt = generate_2d_layout(coords, bg_image=self.imdata, **kwargs)
return lt
Expand Down Expand Up @@ -1032,3 +1033,77 @@ def _plot_sensors(pos, colors, ch_names, title, show_names, show):
fig.suptitle(title)
plt_show(show)
return fig


def _compute_scalings(scalings, inst):
"""Compute scalings for each channel type automatically.

Parameters
----------
scalings : dict
The scalings for each channel type. If any values are
'auto', this will automatically compute a reasonable
scaling for that channel type. Any values that aren't
'auto' will not be changed.
inst : instance of Raw or Epochs
The data for which you want to compute scalings. If data
is not preloaded, this will read a subset of times / epochs
up to 100mb in size in order to compute scalings.

Returns
-------
scalings : dict
A scalings dictionary with updated values
"""
from ..io.base import _BaseRaw
from ..io.pick import _picks_by_type
from ..epochs import _BaseEpochs
if not isinstance(inst, (_BaseRaw, _BaseEpochs)):
raise ValueError('Must supply either Raw or Epochs')
if scalings is None:
# If scalings is None just return it and do nothing
return scalings

ch_types = _picks_by_type(inst.info)
unique_ch_types = [i_type[0] for i_type in ch_types]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this simplifies into:

unique_ch_types = [ch for ch in ['eeg', 'mag', 'grad', 'ecog'] if ch in inst]

Thanks the contains mixin!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to leave in the fields for things like stim channels. E.g., I often use this for loading an audio track to visualize next to the ecog channels. The default scaling is 1 (I think) for any non-brain channel types, and I'd like this kind of functionality to deal with this case as well...

if scalings == 'auto':
# If we want to auto-compute everything
scalings = dict((i_type, 'auto') for i_type in unique_ch_types)
if not isinstance(scalings, dict):
raise ValueError('scalings must be a dictionary of ch_type: val pairs,'
' not type %s ' % type(scalings))
scalings = deepcopy(scalings)

if inst.preload is False:
if isinstance(inst, _BaseRaw):
# Load a window of data from the center up to 100mb in size
n_times = 1e8 // (len(inst.ch_names) * 8)
n_times = np.clip(n_times, 1, inst.n_times)
n_secs = n_times / float(inst.info['sfreq'])
time_middle = np.mean(inst.times)
tmin = np.clip(time_middle - n_secs / 2., inst.times.min(), None)
tmax = np.clip(time_middle + n_secs / 2., None, inst.times.max())
data = inst._read_segment(tmin, tmax)
elif isinstance(inst, _BaseEpochs):
# Load a random subset of epochs up to 100mb in size
n_epochs = 1e8 // (len(inst.ch_names) * len(inst.times) * 8)
n_epochs = int(np.clip(n_epochs, 1, len(inst)))
ixs_epochs = np.random.choice(range(len(inst)), n_epochs, False)
inst = inst.copy()[ixs_epochs].load_data()
else:
data = inst._data
if isinstance(inst, _BaseEpochs):
data = inst._data.reshape([len(inst.ch_names), -1])

# Iterate through ch types and update scaling if ' auto'
for key, value in scalings.items():
if value != 'auto':
continue
if key not in unique_ch_types:
raise ValueError("Sensor {0} doesn't exist in data".format(key))
this_ixs = [i_ixs for key_, i_ixs in ch_types if key_ == key]
this_data = data[this_ixs]
scale_factor = np.percentile(this_data.ravel(), [0.5, 99.5])
scale_factor = np.max(np.abs(scale_factor))
scalings[key] = scale_factor
return scalings
10 changes: 6 additions & 4 deletions tutorials/plot_visualize_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
# color coded gray. By clicking the lines or channel names on the left, you can
# mark or unmark a bad channel interactively. You can use +/- keys to adjust
# the scale (also = works for magnifying the data). Note that the initial
# scaling factors can be set with parameter ``scalings``. With
# ``pageup/pagedown`` and ``home/end`` keys you can adjust the amount of data
# viewed at once. To see all the interactive features, hit ``?`` or click
# ``help`` in the lower left corner of the browser window.
# scaling factors can be set with parameter ``scalings``. If you don't know the
# scaling factor for channels, you can automatically set them by passing
# scalings='auto'. With ``pageup/pagedown`` and ``home/end`` keys you can
# adjust the amount of data viewed at once. To see all the interactive
# features, hit ``?`` or click ``help`` in the lower left corner of the
# browser window.
#
# We read the events from a file and passed it as a parameter when calling the
# method. The events are plotted as vertical lines so you can see how they
Expand Down