Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ Enhancements

- Update ``surfaces`` argument in :func:`mne.viz.plot_alignment` to allow dict for transparency values, and set default for sEEG data to have transparency (:gh:`8445` by `Keith Doelling`_)

- Add support for ``mri_fiducials='estimated'`` in :func:`mne.viz.plot_alignment` to allow estimating MRI fiducial locations using :func:`mne.coreg.get_mni_fiducials` (:gh:`8553` by `Eric Larson`_)

- Update default values in :ref:`mne coreg` and :func:`mne.viz.plot_alignment` for clearer representation of MRI and digitized fiducial points (:gh:`8553` by `Alex Gramfort`_ and `Eric Larson`_)

- Add ``n_pca_components`` argument to :func:`mne.viz.plot_ica_overlay` (:gh:`8351` by `Eric Larson`_)

- Add :func:`mne.stc_near_sensors` to facilitate plotting ECoG data (:gh:`8190` by `Eric Larson`_)
Expand Down
6 changes: 3 additions & 3 deletions mne/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@
markersize=4),
coreg=dict(
mri_fid_opacity=1.0,
dig_fid_opacity=0.3,
dig_fid_opacity=1.0,

mri_fid_scale=1e-2,
dig_fid_scale=3e-2,
mri_fid_scale=5e-3,
dig_fid_scale=8e-3,
extra_scale=4e-3,
eeg_scale=4e-3, eegp_scale=20e-3, eegp_height=0.1,
ecog_scale=5e-3,
Expand Down
6 changes: 3 additions & 3 deletions mne/gui/_coreg_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1900,18 +1900,18 @@ def _init_plot(self):
point_scale = defaults['mri_fid_scale']
self.mri_lpa_obj = PointObject(scene=self.scene, color=lpa_color,
has_norm=True, point_scale=point_scale,
name='LPA')
name='LPA', view='oct')
self.model.sync_trait('transformed_mri_lpa',
self.mri_lpa_obj, 'points', mutual=False)
self.mri_nasion_obj = PointObject(scene=self.scene, color=nasion_color,
has_norm=True,
point_scale=point_scale,
name='Nasion')
name='Nasion', view='oct')
self.model.sync_trait('transformed_mri_nasion',
self.mri_nasion_obj, 'points', mutual=False)
self.mri_rpa_obj = PointObject(scene=self.scene, color=rpa_color,
has_norm=True, point_scale=point_scale,
name='RPA')
name='RPA', view='oct')
self.model.sync_trait('transformed_mri_rpa',
self.mri_rpa_obj, 'points', mutual=False)

Expand Down
18 changes: 12 additions & 6 deletions mne/gui/_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

from ..defaults import DEFAULTS
from ..surface import _CheckInside, _DistanceQuery
from ..transforms import apply_trans
from ..transforms import apply_trans, rotation
from ..utils import SilenceStdout
from ..viz.backends._pysurfer_mayavi import (_create_mesh_surf,
from ..viz.backends._pysurfer_mayavi import (_create_mesh_surf, _oct_glyph,
_toggle_mlab_render)

try:
Expand Down Expand Up @@ -235,14 +235,14 @@ def __init__(self, view='points', has_norm=False, *args, **kwargs):

Parameters
----------
view : 'points' | 'cloud'
view : 'points' | 'cloud' | 'arrow' | 'oct'
Whether the view options should be tailored to individual points
or a point cloud.
has_norm : bool
Whether a norm can be defined; adds view options based on point
norms (default False).
"""
assert view in ('points', 'cloud', 'arrow')
assert view in ('points', 'cloud', 'arrow', 'oct')
self._view = view
self._has_norm = bool(has_norm)
super(PointObject, self).__init__(*args, **kwargs)
Expand All @@ -264,7 +264,7 @@ def default_traits_view(self): # noqa: D102
if self._view == 'arrow':
visible = Item('visible', label='Show', show_label=False)
return View(HGroup(visible, scale, 'opacity', 'label', Spring()))
elif self._view == 'points':
elif self._view in ('points', 'oct'):
visible = Item('visible', label='Show', show_label=True)
views = (visible, color, scale, 'label')
else:
Expand Down Expand Up @@ -327,11 +327,15 @@ def _plot_points(self):
# this can occur sometimes during testing w/ui.dispose()
return
# fig.scene.engine.current_object is scatter
mode = 'arrow' if self._view == 'arrow' else 'sphere'
mode = {'cloud': 'sphere', 'points': 'sphere', 'oct': 'sphere'}.get(
self._view, self._view)
assert mode in ('sphere', 'arrow')
glyph = pipeline.glyph(scatter, color=self.color,
figure=fig, scale_factor=self.point_scale,
opacity=1., resolution=self.resolution,
mode=mode)
if self._view == 'oct':
_oct_glyph(glyph.glyph.glyph_source, rotation(0, 0, np.pi / 4))
glyph.actor.property.backface_culling = True
glyph.glyph.glyph.vector_mode = 'use_normal'
glyph.glyph.glyph.clamping = False
Expand Down Expand Up @@ -430,6 +434,8 @@ def _update_marker_type(self):
gs = self.glyph.glyph.glyph_source
res = getattr(gs.glyph_source, 'theta_resolution',
getattr(gs.glyph_source, 'resolution', None))
if res is None:
return
if self.project_to_surface or self.orient_to_surface:
gs.glyph_source = tvtk.CylinderSource()
gs.glyph_source.height = defaults['eegp_height']
Expand Down
56 changes: 40 additions & 16 deletions mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
_reorder_ccw, _complete_sphere_surf)
from ..transforms import (_find_trans, apply_trans, rot_to_quat,
combine_transforms, _get_trans, _ensure_trans,
invert_transform, Transform,
invert_transform, Transform, rotation,
read_ras_mni_t, _print_coord_trans)
from ..utils import (get_subjects_dir, logger, _check_subject, verbose, warn,
has_nibabel, check_version, fill_doc, _pl, get_config,
Expand Down Expand Up @@ -489,8 +489,12 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
If not None, also plot the source space points.
mri_fiducials : bool | str
Plot MRI fiducials (default False). If ``True``, look for a file with
the canonical name (``bem/{subject}-fiducials.fif``). If ``str`` it
should provide the full path to the fiducials file.
the canonical name (``bem/{subject}-fiducials.fif``). If ``str``,
it can be ``'estimated'`` to use :func:`mne.coreg.get_mni_fiducials`,
otherwise it should provide the full path to the fiducials file.

.. versionadded:: 0.22
Support for ``'estimated'``.
bem : list of dict | instance of ConductorModel | None
Can be either the BEM surfaces (list of dict), a BEM solution or a
sphere model. If None, we first try loading
Expand Down Expand Up @@ -550,6 +554,7 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
.. versionadded:: 0.15
"""
from ..forward import _create_meg_coils, Forward
from ..coreg import get_mni_fiducials
# Update the backend
from .backends.renderer import _get_renderer

Expand Down Expand Up @@ -811,9 +816,12 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
mri_fiducials = op.join(subjects_dir, subject, 'bem',
subject + '-fiducials.fif')
if isinstance(mri_fiducials, str):
mri_fiducials, cf = read_fiducials(mri_fiducials)
if cf != FIFF.FIFFV_COORD_MRI:
raise ValueError("Fiducials are not in MRI space")
if mri_fiducials == 'estimated':
mri_fiducials = get_mni_fiducials(subject, subjects_dir)
else:
mri_fiducials, cf = read_fiducials(mri_fiducials)
if cf != FIFF.FIFFV_COORD_MRI:
raise ValueError("Fiducials are not in MRI space")
fid_loc = _fiducial_coords(mri_fiducials, FIFF.FIFFV_COORD_MRI)
fid_loc = apply_trans(mri_trans, fid_loc)
else:
Expand Down Expand Up @@ -1014,7 +1022,8 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
for k, v in user_alpha.items():
if v is not None:
alphas[k] = v
colors = dict(head=(0.6,) * 3, helmet=(0.0, 0.0, 0.6), lh=(0.5,) * 3,
colors = dict(head=DEFAULTS['coreg']['head_color'],
helmet=(0.0, 0.0, 0.6), lh=(0.5,) * 3,
rh=(0.5,) * 3)
colors.update(skull_colors)
for key, surf in surfs.items():
Expand Down Expand Up @@ -1060,19 +1069,34 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
defaults['extra_scale']
] + [defaults[key + '_scale'] for key in other_keys]
assert len(datas) == len(colors) == len(alphas) == len(scales)
fid_colors = tuple(
defaults[f'{key}_color'] for key in ('lpa', 'nasion', 'rpa'))
glyphs = ['sphere'] * len(datas)
for kind, loc in (('dig', car_loc), ('mri', fid_loc)):
if len(loc) > 0:
datas.extend(loc[:, np.newaxis])
colors.extend((defaults['lpa_color'],
defaults['nasion_color'],
defaults['rpa_color']))
alphas.extend(3 * (defaults[kind + '_fid_opacity'],))
scales.extend(3 * (defaults[kind + '_fid_scale'],))

for data, color, alpha, scale in zip(datas, colors, alphas, scales):
colors.extend(fid_colors)
alphas.extend(3 * (defaults[f'{kind}_fid_opacity'],))
scales.extend(3 * (defaults[f'{kind}_fid_scale'],))
glyphs.extend(3 * (('oct' if kind == 'mri' else 'sphere'),))
for data, color, alpha, scale, glyph in zip(
datas, colors, alphas, scales, glyphs):
if len(data) > 0:
renderer.sphere(center=data, color=color, scale=scale,
opacity=alpha, backface_culling=True)
if glyph == 'oct':
transform = np.eye(4)
transform[:3, :3] = mri_trans['trans'][:3, :3] * scale
# rotate around Z axis 45 deg first
transform = transform @ rotation(0, 0, np.pi / 4)
renderer.quiver3d(
x=data[:, 0], y=data[:, 1], z=data[:, 2],
u=1., v=0., w=0., color=color, mode='oct',
scale=1., opacity=alpha, backface_culling=True,
solid_transform=transform)
else:
assert glyph == 'sphere'
assert data.ndim == 2 and data.shape[1] == 3, data.shape
renderer.sphere(center=data, color=color, scale=scale,
opacity=alpha, backface_culling=True)
if len(eegp_loc) > 0:
renderer.quiver3d(
x=eegp_loc[:, 0], y=eegp_loc[:, 1], z=eegp_loc[:, 2],
Expand Down
38 changes: 35 additions & 3 deletions mne/viz/backends/_pysurfer_mayavi.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8,
glyph_height=None, glyph_center=None, glyph_resolution=None,
opacity=1.0, scale_mode='none', scalars=None,
backface_culling=False, colormap=None, vmin=None, vmax=None,
line_width=2., name=None):
line_width=2., name=None, solid_transform=None):
_check_option('mode', mode, ALLOWED_QUIVER_MODES)
color = _check_color(color)
with warnings.catch_warnings(record=True): # traits
Expand All @@ -244,12 +244,15 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8,
scale_mode=scale_mode,
resolution=resolution, scalars=scalars,
opacity=opacity, figure=self.fig)
elif mode in ('cone', 'sphere'):
elif mode in ('cone', 'sphere', 'oct'):
use_mode = 'sphere' if mode == 'oct' else mode
quiv = self.mlab.quiver3d(x, y, z, u, v, w, color=color,
mode=mode, scale_factor=scale,
mode=use_mode, scale_factor=scale,
opacity=opacity, figure=self.fig)
if mode == 'sphere':
quiv.glyph.glyph_source.glyph_source.center = 0., 0., 0.
elif mode == 'oct':
_oct_glyph(quiv.glyph.glyph_source, solid_transform)
else:
assert mode == 'cylinder', mode # should be guaranteed above
quiv = self.mlab.quiver3d(x, y, z, u, v, w, mode=mode,
Expand Down Expand Up @@ -523,3 +526,32 @@ def _testing_context(interactive):
yield
finally:
mlab.options.backend = orig_backend


def _oct_glyph(glyph_source, transform):
from tvtk.api import tvtk
from tvtk.common import configure_input
from traits.api import Array
gs = tvtk.PlatonicSolidSource()

# Workaround for:
# File "mayavi/components/glyph_source.py", line 231, in _glyph_position_changed # noqa: E501
# g.center = 0.0, 0.0, 0.0
# traits.trait_errors.TraitError: Cannot set the undefined 'center' attribute of a 'TransformPolyDataFilter' object. # noqa: E501
class SafeTransformPolyDataFilter(tvtk.TransformPolyDataFilter):
center = Array(shape=(3,), value=np.zeros(3))

gs.solid_type = 'octahedron'
if transform is not None:
# glyph: mayavi.modules.vectors.Vectors
# glyph.glyph: vtkGlyph3D
# glyph.glyph.glyph: mayavi.components.glyph.Glyph
assert transform.shape == (4, 4)
tr = tvtk.Transform()
tr.set_matrix(transform.ravel())
trp = SafeTransformPolyDataFilter()
configure_input(trp, gs)
trp.transform = tr
trp.update()
gs = trp
glyph_source.glyph_source = gs
19 changes: 16 additions & 3 deletions mne/viz/backends/_pyvista.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,9 @@ def tube(self, origin, destination, radius=0.001, color='white',
def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8,
glyph_height=None, glyph_center=None, glyph_resolution=None,
opacity=1.0, scale_mode='none', scalars=None,
backface_culling=False, line_width=2., name=None):
backface_culling=False, line_width=2., name=None,
glyph_width=None, glyph_depth=None,
solid_transform=None):
_check_option('mode', mode, ALLOWED_QUIVER_MODES)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
Expand Down Expand Up @@ -514,13 +516,17 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8,
)
mesh = pyvista.wrap(alg.GetOutput())
else:
tr = None
if mode == 'cone':
glyph = vtk.vtkConeSource()
glyph.SetCenter(0.5, 0, 0)
glyph.SetRadius(0.15)
elif mode == 'cylinder':
glyph = vtk.vtkCylinderSource()
glyph.SetRadius(0.15)
elif mode == 'oct':
glyph = vtk.vtkPlatonicSolidSource()
glyph.SetSolidTypeToOctahedron()
else:
assert mode == 'sphere', mode # guaranteed above
glyph = vtk.vtkSphereSource()
Expand All @@ -531,10 +537,17 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8,
glyph.SetCenter(glyph_center)
if glyph_resolution is not None:
glyph.SetResolution(glyph_resolution)
# fix orientation
glyph.Update()
tr = vtk.vtkTransform()
tr.RotateWXYZ(90, 0, 0, 1)
elif mode == 'oct':
if solid_transform is not None:
assert solid_transform.shape == (4, 4)
tr = vtk.vtkTransform()
tr.SetMatrix(
solid_transform.astype(np.float64).ravel())
if tr is not None:
# fix orientation
glyph.Update()
trp = vtk.vtkTransformPolyDataFilter()
trp.SetInputData(glyph.GetOutput())
trp.SetTransform(tr)
Expand Down
3 changes: 2 additions & 1 deletion mne/viz/backends/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
'mayavi',
'notebook',
)
ALLOWED_QUIVER_MODES = ('2darrow', 'arrow', 'cone', 'cylinder', 'sphere')
ALLOWED_QUIVER_MODES = ('2darrow', 'arrow', 'cone', 'cylinder', 'sphere',
'oct')


def _get_colormap_from_array(colormap=None, normalized_colormap=False,
Expand Down
2 changes: 1 addition & 1 deletion tutorials/source-modeling/plot_source_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
fig = mne.viz.plot_alignment(raw.info, trans=trans, subject='sample',
subjects_dir=subjects_dir, surfaces='head-dense',
show_axes=True, dig=True, eeg=[], meg='sensors',
coord_frame='meg')
coord_frame='meg', mri_fiducials='estimated')
mne.viz.set_3d_view(fig, 45, 90, distance=0.6, focalpoint=(0., 0., 0.))
print('Distance from head origin to MEG origin: %0.1f mm'
% (1000 * np.linalg.norm(raw.info['dev_head_t']['trans'][:3, 3])))
Expand Down