diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 4fe2a343a52..48b411ef05e 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -29,6 +29,10 @@ Enhancements - Update ``surfaces`` argument in :func:`mne.viz.plot_alignment` to allow dict for transparency values, and set default for sEEG data to have transparency (:gh:`8445` by `Keith Doelling`_) +- Add support for ``mri_fiducials='estimated'`` in :func:`mne.viz.plot_alignment` to allow estimating MRI fiducial locations using :func:`mne.coreg.get_mni_fiducials` (:gh:`8553` by `Eric Larson`_) + +- Update default values in :ref:`mne coreg` and :func:`mne.viz.plot_alignment` for clearer representation of MRI and digitized fiducial points (:gh:`8553` by `Alex Gramfort`_ and `Eric Larson`_) + - Add ``n_pca_components`` argument to :func:`mne.viz.plot_ica_overlay` (:gh:`8351` by `Eric Larson`_) - Add :func:`mne.stc_near_sensors` to facilitate plotting ECoG data (:gh:`8190` by `Eric Larson`_) diff --git a/mne/defaults.py b/mne/defaults.py index 82b9d4a58bd..5fb0a7dc6eb 100644 --- a/mne/defaults.py +++ b/mne/defaults.py @@ -49,10 +49,10 @@ markersize=4), coreg=dict( mri_fid_opacity=1.0, - dig_fid_opacity=0.3, + dig_fid_opacity=1.0, - mri_fid_scale=1e-2, - dig_fid_scale=3e-2, + mri_fid_scale=5e-3, + dig_fid_scale=8e-3, extra_scale=4e-3, eeg_scale=4e-3, eegp_scale=20e-3, eegp_height=0.1, ecog_scale=5e-3, diff --git a/mne/gui/_coreg_gui.py b/mne/gui/_coreg_gui.py index dd36d3013f9..3f1cdeee8ff 100644 --- a/mne/gui/_coreg_gui.py +++ b/mne/gui/_coreg_gui.py @@ -1900,18 +1900,18 @@ def _init_plot(self): point_scale = defaults['mri_fid_scale'] self.mri_lpa_obj = PointObject(scene=self.scene, color=lpa_color, has_norm=True, point_scale=point_scale, - name='LPA') + name='LPA', view='oct') self.model.sync_trait('transformed_mri_lpa', self.mri_lpa_obj, 'points', mutual=False) self.mri_nasion_obj = PointObject(scene=self.scene, color=nasion_color, has_norm=True, point_scale=point_scale, - name='Nasion') + name='Nasion', view='oct') self.model.sync_trait('transformed_mri_nasion', self.mri_nasion_obj, 'points', mutual=False) self.mri_rpa_obj = PointObject(scene=self.scene, color=rpa_color, has_norm=True, point_scale=point_scale, - name='RPA') + name='RPA', view='oct') self.model.sync_trait('transformed_mri_rpa', self.mri_rpa_obj, 'points', mutual=False) diff --git a/mne/gui/_viewer.py b/mne/gui/_viewer.py index e33f44a0680..4fe6e9b1002 100644 --- a/mne/gui/_viewer.py +++ b/mne/gui/_viewer.py @@ -21,9 +21,9 @@ from ..defaults import DEFAULTS from ..surface import _CheckInside, _DistanceQuery -from ..transforms import apply_trans +from ..transforms import apply_trans, rotation from ..utils import SilenceStdout -from ..viz.backends._pysurfer_mayavi import (_create_mesh_surf, +from ..viz.backends._pysurfer_mayavi import (_create_mesh_surf, _oct_glyph, _toggle_mlab_render) try: @@ -235,14 +235,14 @@ def __init__(self, view='points', has_norm=False, *args, **kwargs): Parameters ---------- - view : 'points' | 'cloud' + view : 'points' | 'cloud' | 'arrow' | 'oct' Whether the view options should be tailored to individual points or a point cloud. has_norm : bool Whether a norm can be defined; adds view options based on point norms (default False). """ - assert view in ('points', 'cloud', 'arrow') + assert view in ('points', 'cloud', 'arrow', 'oct') self._view = view self._has_norm = bool(has_norm) super(PointObject, self).__init__(*args, **kwargs) @@ -264,7 +264,7 @@ def default_traits_view(self): # noqa: D102 if self._view == 'arrow': visible = Item('visible', label='Show', show_label=False) return View(HGroup(visible, scale, 'opacity', 'label', Spring())) - elif self._view == 'points': + elif self._view in ('points', 'oct'): visible = Item('visible', label='Show', show_label=True) views = (visible, color, scale, 'label') else: @@ -327,11 +327,15 @@ def _plot_points(self): # this can occur sometimes during testing w/ui.dispose() return # fig.scene.engine.current_object is scatter - mode = 'arrow' if self._view == 'arrow' else 'sphere' + mode = {'cloud': 'sphere', 'points': 'sphere', 'oct': 'sphere'}.get( + self._view, self._view) + assert mode in ('sphere', 'arrow') glyph = pipeline.glyph(scatter, color=self.color, figure=fig, scale_factor=self.point_scale, opacity=1., resolution=self.resolution, mode=mode) + if self._view == 'oct': + _oct_glyph(glyph.glyph.glyph_source, rotation(0, 0, np.pi / 4)) glyph.actor.property.backface_culling = True glyph.glyph.glyph.vector_mode = 'use_normal' glyph.glyph.glyph.clamping = False @@ -430,6 +434,8 @@ def _update_marker_type(self): gs = self.glyph.glyph.glyph_source res = getattr(gs.glyph_source, 'theta_resolution', getattr(gs.glyph_source, 'resolution', None)) + if res is None: + return if self.project_to_surface or self.orient_to_surface: gs.glyph_source = tvtk.CylinderSource() gs.glyph_source.height = defaults['eegp_height'] diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index 0b836dd1071..e034f4479ba 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -34,7 +34,7 @@ _reorder_ccw, _complete_sphere_surf) from ..transforms import (_find_trans, apply_trans, rot_to_quat, combine_transforms, _get_trans, _ensure_trans, - invert_transform, Transform, + invert_transform, Transform, rotation, read_ras_mni_t, _print_coord_trans) from ..utils import (get_subjects_dir, logger, _check_subject, verbose, warn, has_nibabel, check_version, fill_doc, _pl, get_config, @@ -489,8 +489,12 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None, If not None, also plot the source space points. mri_fiducials : bool | str Plot MRI fiducials (default False). If ``True``, look for a file with - the canonical name (``bem/{subject}-fiducials.fif``). If ``str`` it - should provide the full path to the fiducials file. + the canonical name (``bem/{subject}-fiducials.fif``). If ``str``, + it can be ``'estimated'`` to use :func:`mne.coreg.get_mni_fiducials`, + otherwise it should provide the full path to the fiducials file. + + .. versionadded:: 0.22 + Support for ``'estimated'``. bem : list of dict | instance of ConductorModel | None Can be either the BEM surfaces (list of dict), a BEM solution or a sphere model. If None, we first try loading @@ -550,6 +554,7 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None, .. versionadded:: 0.15 """ from ..forward import _create_meg_coils, Forward + from ..coreg import get_mni_fiducials # Update the backend from .backends.renderer import _get_renderer @@ -811,9 +816,12 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None, mri_fiducials = op.join(subjects_dir, subject, 'bem', subject + '-fiducials.fif') if isinstance(mri_fiducials, str): - mri_fiducials, cf = read_fiducials(mri_fiducials) - if cf != FIFF.FIFFV_COORD_MRI: - raise ValueError("Fiducials are not in MRI space") + if mri_fiducials == 'estimated': + mri_fiducials = get_mni_fiducials(subject, subjects_dir) + else: + mri_fiducials, cf = read_fiducials(mri_fiducials) + if cf != FIFF.FIFFV_COORD_MRI: + raise ValueError("Fiducials are not in MRI space") fid_loc = _fiducial_coords(mri_fiducials, FIFF.FIFFV_COORD_MRI) fid_loc = apply_trans(mri_trans, fid_loc) else: @@ -1014,7 +1022,8 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None, for k, v in user_alpha.items(): if v is not None: alphas[k] = v - colors = dict(head=(0.6,) * 3, helmet=(0.0, 0.0, 0.6), lh=(0.5,) * 3, + colors = dict(head=DEFAULTS['coreg']['head_color'], + helmet=(0.0, 0.0, 0.6), lh=(0.5,) * 3, rh=(0.5,) * 3) colors.update(skull_colors) for key, surf in surfs.items(): @@ -1060,19 +1069,34 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None, defaults['extra_scale'] ] + [defaults[key + '_scale'] for key in other_keys] assert len(datas) == len(colors) == len(alphas) == len(scales) + fid_colors = tuple( + defaults[f'{key}_color'] for key in ('lpa', 'nasion', 'rpa')) + glyphs = ['sphere'] * len(datas) for kind, loc in (('dig', car_loc), ('mri', fid_loc)): if len(loc) > 0: datas.extend(loc[:, np.newaxis]) - colors.extend((defaults['lpa_color'], - defaults['nasion_color'], - defaults['rpa_color'])) - alphas.extend(3 * (defaults[kind + '_fid_opacity'],)) - scales.extend(3 * (defaults[kind + '_fid_scale'],)) - - for data, color, alpha, scale in zip(datas, colors, alphas, scales): + colors.extend(fid_colors) + alphas.extend(3 * (defaults[f'{kind}_fid_opacity'],)) + scales.extend(3 * (defaults[f'{kind}_fid_scale'],)) + glyphs.extend(3 * (('oct' if kind == 'mri' else 'sphere'),)) + for data, color, alpha, scale, glyph in zip( + datas, colors, alphas, scales, glyphs): if len(data) > 0: - renderer.sphere(center=data, color=color, scale=scale, - opacity=alpha, backface_culling=True) + if glyph == 'oct': + transform = np.eye(4) + transform[:3, :3] = mri_trans['trans'][:3, :3] * scale + # rotate around Z axis 45 deg first + transform = transform @ rotation(0, 0, np.pi / 4) + renderer.quiver3d( + x=data[:, 0], y=data[:, 1], z=data[:, 2], + u=1., v=0., w=0., color=color, mode='oct', + scale=1., opacity=alpha, backface_culling=True, + solid_transform=transform) + else: + assert glyph == 'sphere' + assert data.ndim == 2 and data.shape[1] == 3, data.shape + renderer.sphere(center=data, color=color, scale=scale, + opacity=alpha, backface_culling=True) if len(eegp_loc) > 0: renderer.quiver3d( x=eegp_loc[:, 0], y=eegp_loc[:, 1], z=eegp_loc[:, 2], diff --git a/mne/viz/backends/_pysurfer_mayavi.py b/mne/viz/backends/_pysurfer_mayavi.py index 5077a81dee5..1af9bf0a805 100644 --- a/mne/viz/backends/_pysurfer_mayavi.py +++ b/mne/viz/backends/_pysurfer_mayavi.py @@ -234,7 +234,7 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8, glyph_height=None, glyph_center=None, glyph_resolution=None, opacity=1.0, scale_mode='none', scalars=None, backface_culling=False, colormap=None, vmin=None, vmax=None, - line_width=2., name=None): + line_width=2., name=None, solid_transform=None): _check_option('mode', mode, ALLOWED_QUIVER_MODES) color = _check_color(color) with warnings.catch_warnings(record=True): # traits @@ -244,12 +244,15 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8, scale_mode=scale_mode, resolution=resolution, scalars=scalars, opacity=opacity, figure=self.fig) - elif mode in ('cone', 'sphere'): + elif mode in ('cone', 'sphere', 'oct'): + use_mode = 'sphere' if mode == 'oct' else mode quiv = self.mlab.quiver3d(x, y, z, u, v, w, color=color, - mode=mode, scale_factor=scale, + mode=use_mode, scale_factor=scale, opacity=opacity, figure=self.fig) if mode == 'sphere': quiv.glyph.glyph_source.glyph_source.center = 0., 0., 0. + elif mode == 'oct': + _oct_glyph(quiv.glyph.glyph_source, solid_transform) else: assert mode == 'cylinder', mode # should be guaranteed above quiv = self.mlab.quiver3d(x, y, z, u, v, w, mode=mode, @@ -523,3 +526,32 @@ def _testing_context(interactive): yield finally: mlab.options.backend = orig_backend + + +def _oct_glyph(glyph_source, transform): + from tvtk.api import tvtk + from tvtk.common import configure_input + from traits.api import Array + gs = tvtk.PlatonicSolidSource() + + # Workaround for: + # File "mayavi/components/glyph_source.py", line 231, in _glyph_position_changed # noqa: E501 + # g.center = 0.0, 0.0, 0.0 + # traits.trait_errors.TraitError: Cannot set the undefined 'center' attribute of a 'TransformPolyDataFilter' object. # noqa: E501 + class SafeTransformPolyDataFilter(tvtk.TransformPolyDataFilter): + center = Array(shape=(3,), value=np.zeros(3)) + + gs.solid_type = 'octahedron' + if transform is not None: + # glyph: mayavi.modules.vectors.Vectors + # glyph.glyph: vtkGlyph3D + # glyph.glyph.glyph: mayavi.components.glyph.Glyph + assert transform.shape == (4, 4) + tr = tvtk.Transform() + tr.set_matrix(transform.ravel()) + trp = SafeTransformPolyDataFilter() + configure_input(trp, gs) + trp.transform = tr + trp.update() + gs = trp + glyph_source.glyph_source = gs diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index 4b8fa1ae041..60dd7632d0a 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -483,7 +483,9 @@ def tube(self, origin, destination, radius=0.001, color='white', def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8, glyph_height=None, glyph_center=None, glyph_resolution=None, opacity=1.0, scale_mode='none', scalars=None, - backface_culling=False, line_width=2., name=None): + backface_culling=False, line_width=2., name=None, + glyph_width=None, glyph_depth=None, + solid_transform=None): _check_option('mode', mode, ALLOWED_QUIVER_MODES) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=FutureWarning) @@ -514,6 +516,7 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8, ) mesh = pyvista.wrap(alg.GetOutput()) else: + tr = None if mode == 'cone': glyph = vtk.vtkConeSource() glyph.SetCenter(0.5, 0, 0) @@ -521,6 +524,9 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8, elif mode == 'cylinder': glyph = vtk.vtkCylinderSource() glyph.SetRadius(0.15) + elif mode == 'oct': + glyph = vtk.vtkPlatonicSolidSource() + glyph.SetSolidTypeToOctahedron() else: assert mode == 'sphere', mode # guaranteed above glyph = vtk.vtkSphereSource() @@ -531,10 +537,17 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8, glyph.SetCenter(glyph_center) if glyph_resolution is not None: glyph.SetResolution(glyph_resolution) - # fix orientation - glyph.Update() tr = vtk.vtkTransform() tr.RotateWXYZ(90, 0, 0, 1) + elif mode == 'oct': + if solid_transform is not None: + assert solid_transform.shape == (4, 4) + tr = vtk.vtkTransform() + tr.SetMatrix( + solid_transform.astype(np.float64).ravel()) + if tr is not None: + # fix orientation + glyph.Update() trp = vtk.vtkTransformPolyDataFilter() trp.SetInputData(glyph.GetOutput()) trp.SetTransform(tr) diff --git a/mne/viz/backends/_utils.py b/mne/viz/backends/_utils.py index 4c271ead23f..cc60cc39ac4 100644 --- a/mne/viz/backends/_utils.py +++ b/mne/viz/backends/_utils.py @@ -15,7 +15,8 @@ 'mayavi', 'notebook', ) -ALLOWED_QUIVER_MODES = ('2darrow', 'arrow', 'cone', 'cylinder', 'sphere') +ALLOWED_QUIVER_MODES = ('2darrow', 'arrow', 'cone', 'cylinder', 'sphere', + 'oct') def _get_colormap_from_array(colormap=None, normalized_colormap=False, diff --git a/tutorials/source-modeling/plot_source_alignment.py b/tutorials/source-modeling/plot_source_alignment.py index 45c38343fd7..9739eb83d6e 100644 --- a/tutorials/source-modeling/plot_source_alignment.py +++ b/tutorials/source-modeling/plot_source_alignment.py @@ -104,7 +104,7 @@ fig = mne.viz.plot_alignment(raw.info, trans=trans, subject='sample', subjects_dir=subjects_dir, surfaces='head-dense', show_axes=True, dig=True, eeg=[], meg='sensors', - coord_frame='meg') + coord_frame='meg', mri_fiducials='estimated') mne.viz.set_3d_view(fig, 45, 90, distance=0.6, focalpoint=(0., 0., 0.)) print('Distance from head origin to MEG origin: %0.1f mm' % (1000 * np.linalg.norm(raw.info['dev_head_t']['trans'][:3, 3])))