-
-
Notifications
You must be signed in to change notification settings - Fork 1.5k
adding basic auto-scaling functionality for raw and epochs classes #3198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bcd7d1b
799d2e5
eaaaeca
8624f09
9241afb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| import webbrowser | ||
| import tempfile | ||
| import numpy as np | ||
| from copy import deepcopy | ||
|
|
||
| from ..channels.layout import _auto_topomap_coords | ||
| from ..channels.channels import _contains_ch_type | ||
|
|
@@ -797,7 +798,7 @@ def to_layout(self, **kwargs): | |
| **kwargs : dict | ||
| Arguments are passed to generate_2d_layout | ||
| """ | ||
| from mne.channels.layout import generate_2d_layout | ||
| from ..channels.layout import generate_2d_layout | ||
| coords = np.array(self.coords) | ||
| lt = generate_2d_layout(coords, bg_image=self.imdata, **kwargs) | ||
| return lt | ||
|
|
@@ -1032,3 +1033,77 @@ def _plot_sensors(pos, colors, ch_names, title, show_names, show): | |
| fig.suptitle(title) | ||
| plt_show(show) | ||
| return fig | ||
|
|
||
|
|
||
| def _compute_scalings(scalings, inst): | ||
| """Compute scalings for each channel type automatically. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| scalings : dict | ||
| The scalings for each channel type. If any values are | ||
| 'auto', this will automatically compute a reasonable | ||
| scaling for that channel type. Any values that aren't | ||
| 'auto' will not be changed. | ||
| inst : instance of Raw or Epochs | ||
| The data for which you want to compute scalings. If data | ||
| is not preloaded, this will read a subset of times / epochs | ||
| up to 100mb in size in order to compute scalings. | ||
|
|
||
| Returns | ||
| ------- | ||
| scalings : dict | ||
| A scalings dictionary with updated values | ||
| """ | ||
| from ..io.base import _BaseRaw | ||
| from ..io.pick import _picks_by_type | ||
| from ..epochs import _BaseEpochs | ||
| if not isinstance(inst, (_BaseRaw, _BaseEpochs)): | ||
| raise ValueError('Must supply either Raw or Epochs') | ||
| if scalings is None: | ||
| # If scalings is None just return it and do nothing | ||
| return scalings | ||
|
|
||
| ch_types = _picks_by_type(inst.info) | ||
| unique_ch_types = [i_type[0] for i_type in ch_types] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this simplifies into:
Thanks the contains mixin!
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd prefer to leave in the fields for things like stim channels. E.g., I often use this for loading an audio track to visualize next to the ecog channels. The default scaling is 1 (I think) for any non-brain channel types, and I'd like this kind of functionality to deal with this case as well... |
||
| if scalings == 'auto': | ||
| # If we want to auto-compute everything | ||
| scalings = dict((i_type, 'auto') for i_type in unique_ch_types) | ||
| if not isinstance(scalings, dict): | ||
| raise ValueError('scalings must be a dictionary of ch_type: val pairs,' | ||
| ' not type %s ' % type(scalings)) | ||
| scalings = deepcopy(scalings) | ||
|
|
||
| if inst.preload is False: | ||
| if isinstance(inst, _BaseRaw): | ||
| # Load a window of data from the center up to 100mb in size | ||
| n_times = 1e8 // (len(inst.ch_names) * 8) | ||
| n_times = np.clip(n_times, 1, inst.n_times) | ||
| n_secs = n_times / float(inst.info['sfreq']) | ||
| time_middle = np.mean(inst.times) | ||
| tmin = np.clip(time_middle - n_secs / 2., inst.times.min(), None) | ||
| tmax = np.clip(time_middle + n_secs / 2., None, inst.times.max()) | ||
| data = inst._read_segment(tmin, tmax) | ||
| elif isinstance(inst, _BaseEpochs): | ||
| # Load a random subset of epochs up to 100mb in size | ||
| n_epochs = 1e8 // (len(inst.ch_names) * len(inst.times) * 8) | ||
| n_epochs = int(np.clip(n_epochs, 1, len(inst))) | ||
| ixs_epochs = np.random.choice(range(len(inst)), n_epochs, False) | ||
| inst = inst.copy()[ixs_epochs].load_data() | ||
| else: | ||
| data = inst._data | ||
| if isinstance(inst, _BaseEpochs): | ||
| data = inst._data.reshape([len(inst.ch_names), -1]) | ||
|
|
||
| # Iterate through ch types and update scaling if ' auto' | ||
| for key, value in scalings.items(): | ||
| if value != 'auto': | ||
| continue | ||
| if key not in unique_ch_types: | ||
| raise ValueError("Sensor {0} doesn't exist in data".format(key)) | ||
| this_ixs = [i_ixs for key_, i_ixs in ch_types if key_ == key] | ||
| this_data = data[this_ixs] | ||
| scale_factor = np.percentile(this_data.ravel(), [0.5, 99.5]) | ||
| scale_factor = np.max(np.abs(scale_factor)) | ||
| scalings[key] = scale_factor | ||
| return scalings | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:func: in front (see other lines above for examples)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 then +1 for merge