diff --git a/.travis.yml b/.travis.yml index a2505ea202c..df9f261d4e5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -142,7 +142,7 @@ script: pip install -e .; python mne/tests/test_evoked.py; fi; - - echo 'pytest -m "${CONDITION}" --cov=mne -vv ${USE_DIRS}' + - echo 'pytest -m "${CONDITION}" --tb=short --cov=mne -vv ${USE_DIRS}' - pytest -m "${CONDITION}" --tb=short --cov=mne -vv ${USE_DIRS} # run the minimal one with the testing data - if [ "${DEPS}" == "minimal" ]; then diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index 099d0899764..2e074b12abd 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -2568,7 +2568,7 @@ def plot_sparse_source_estimates(src, stcs, colors=None, linewidth=2, z=points[:, 2], triangles=use_faces, color=brain_color, opacity=opacity, backface_culling=True, shading=True, - **kwargs) + normals=normals, **kwargs) # Show time courses fig = plt.figure(fig_number) diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index a6e3ccc77a2..e7ac2e840b8 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -270,7 +270,7 @@ def __init__(self, subject_id, hemi, surf, title=None, self._hemi_meshes[h] = mesh self._hemi_actors[h] = actor else: - self._renderer._mesh( + self._renderer.polydata( self._hemi_meshes[h], **kwargs, ) @@ -528,7 +528,7 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, self._data[hemi]['actor'] = actor self._data[hemi]['mesh'] = mesh else: - self._renderer._mesh( + self._renderer.polydata( self._data[hemi]['mesh'], **kwargs, ) @@ -1153,8 +1153,7 @@ def set_time_point(self, time_idx): self._update() def update_glyphs(self, hemi, vectors): - from ..backends._pyvista import (_set_colormap_range, - _add_polydata_actor) + from ..backends._pyvista import _set_colormap_range hemi_data = self._data.get(hemi) if hemi_data is not None: vertices = hemi_data['vertices'] @@ -1177,14 +1176,11 @@ def update_glyphs(self, hemi, vectors): ) if polydata is not None: if hemi_data['glyph_mesh'] is None: - hemi_data['glyph_mesh'] = polydata - glyph_actor = _add_polydata_actor( - plotter=self._renderer.plotter, - polydata=polydata, - hide=True - ) - hemi_data['glyph_actor'] = glyph_actor + glyph_actor, _ = self._renderer.polydata(polydata) + glyph_actor.VisibilityOff() glyph_actor.GetProperty().SetLineWidth(2.) + hemi_data['glyph_actor'] = glyph_actor + hemi_data['glyph_mesh'] = polydata else: glyph_actor = hemi_data['glyph_actor'] glyph_mesh = hemi_data['glyph_mesh'] diff --git a/mne/viz/backends/_pysurfer_mayavi.py b/mne/viz/backends/_pysurfer_mayavi.py index 0277f4c3117..5ad63ee45dc 100644 --- a/mne/viz/backends/_pysurfer_mayavi.py +++ b/mne/viz/backends/_pysurfer_mayavi.py @@ -95,12 +95,6 @@ def set_interactive(self): self.fig.scene.interactor.interactor_style = \ tvtk.InteractorStyleTerrain() - def _mesh(self, mesh, color, opacity=1.0, - backface_culling=False, scalars=None, colormap=None, - vmin=None, vmax=None, interpolate_before_map=True, - representation='surface', line_width=1., **kwargs): - raise NotImplementedError("This feature is not available with mayavi.") - def mesh(self, x, y, z, triangles, color, opacity=1.0, shading=False, backface_culling=False, scalars=None, colormap=None, vmin=None, vmax=None, interpolate_before_map=True, diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index 16dc6310038..4405d1b5d08 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -49,7 +49,7 @@ def __init__(self, plotter=None, size=(600, 600), shape=(1, 1), background_color='black', - smooth_shading=False, + smooth_shading=True, off_screen=False, notebook=False): self.plotter = plotter @@ -144,7 +144,7 @@ class _Renderer(_BaseRenderer): def __init__(self, fig=None, size=(600, 600), bgcolor='black', name="PyVista Scene", show=False, shape=(1, 1), - notebook=None, smooth_shading=False): + notebook=None, smooth_shading=True): from .renderer import MNE_3D_BACKEND_TESTING from .._3d import _get_3d_option figure = _Figure(show=show, title=name, size=size, shape=shape, @@ -217,17 +217,16 @@ def scene(self): def set_interactive(self): self.plotter.enable_terrain_style() - def _mesh(self, mesh, color, opacity=1.0, - backface_culling=False, scalars=None, colormap=None, - vmin=None, vmax=None, interpolate_before_map=True, - representation='surface', line_width=1., **kwargs): + def polydata(self, mesh, color=None, opacity=1.0, normals=None, + backface_culling=False, scalars=None, colormap=None, + vmin=None, vmax=None, interpolate_before_map=True, + representation='surface', line_width=1., **kwargs): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=FutureWarning) - n_vertices = mesh.n_points rgba = False - if color is not None and len(color) == n_vertices: + if color is not None and len(color) == mesh.n_points: if color.shape[1] == 3: - scalars = np.c_[color, np.ones(n_vertices)] + scalars = np.c_[color, np.ones(mesh.n_points)] else: scalars = color scalars = (scalars * 255).astype('ubyte') @@ -238,7 +237,11 @@ def _mesh(self, mesh, color, opacity=1.0, colormap = colormap.astype(np.float64) / 255. from matplotlib.colors import ListedColormap colormap = ListedColormap(colormap) - + if normals is not None: + mesh.point_arrays["Normals"] = normals + mesh.GetPointData().SetActiveNormals("Normals") + else: + _compute_normals(mesh) actor = _add_mesh( plotter=self.plotter, mesh=mesh, color=color, scalars=scalars, @@ -247,16 +250,9 @@ def _mesh(self, mesh, color, opacity=1.0, rng=[vmin, vmax], show_scalar_bar=False, smooth_shading=self.figure.smooth_shading, interpolate_before_map=interpolate_before_map, - representation=representation, line_width=line_width, **kwargs, + style=representation, line_width=line_width, **kwargs, ) - try: - mesh.point_arrays["Normals"] - except KeyError: - pass - else: - prop = actor.GetProperty() - prop.SetInterpolationToPhong() return actor, mesh def mesh(self, x, y, z, triangles, color, opacity=1.0, shading=False, @@ -268,21 +264,19 @@ def mesh(self, x, y, z, triangles, color, opacity=1.0, shading=False, vertices = np.c_[x, y, z] triangles = np.c_[np.full(len(triangles), 3), triangles] mesh = PolyData(vertices, triangles) - if normals is not None: - mesh.point_arrays["Normals"] = normals - mesh.GetPointData().SetActiveNormals("Normals") - return self._mesh( - mesh, - color, - opacity, - backface_culling, - scalars, - colormap, - vmin, - vmax, - interpolate_before_map, - representation, - line_width, + return self.polydata( + mesh=mesh, + color=color, + opacity=opacity, + normals=normals, + backface_culling=backface_culling, + scalars=scalars, + colormap=colormap, + vmin=vmin, + vmax=vmax, + interpolate_before_map=interpolate_before_map, + representation=representation, + line_width=line_width, **kwargs, ) @@ -323,24 +317,25 @@ def surface(self, surface, color=None, opacity=1.0, backface_culling=False): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=FutureWarning) - cmap = _get_colormap_from_array(colormap, normalized_colormap) + normals = surface.get('nn', None) vertices = np.array(surface['rr']) triangles = np.array(surface['tris']) - n_triangles = len(triangles) - triangles = np.c_[np.full(n_triangles, 3), triangles] + triangles = np.c_[np.full(len(triangles), 3), triangles] mesh = PolyData(vertices, triangles) - if scalars is not None: - mesh.point_arrays['scalars'] = scalars - _add_mesh( - plotter=self.plotter, - mesh=mesh, color=color, - rng=[vmin, vmax], - show_scalar_bar=False, - opacity=opacity, - cmap=cmap, - backface_culling=backface_culling, - smooth_shading=self.figure.smooth_shading - ) + colormap = _get_colormap_from_array(colormap, normalized_colormap) + if scalars is not None: + mesh.point_arrays['scalars'] = scalars + return self.polydata( + mesh=mesh, + color=color, + opacity=opacity, + normals=normals, + backface_culling=backface_culling, + scalars=scalars, + colormap=colormap, + vmin=vmin, + vmax=vmax, + ) def sphere(self, center, color, scale, opacity=1.0, resolution=8, backface_culling=False, @@ -360,7 +355,7 @@ def sphere(self, center, color, scale, opacity=1.0, factor=factor, geom=geom) actor = _add_mesh( self.plotter, - glyph, color=color, opacity=opacity, + mesh=glyph, color=color, opacity=opacity, backface_culling=backface_culling, smooth_shading=self.figure.smooth_shading ) @@ -421,9 +416,9 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8, elif mode == 'arrow' or mode == '3darrow': _add_mesh( self.plotter, - grid.glyph(orient='vec', - scale=scale, - factor=factor), + mesh=grid.glyph(orient='vec', + scale=scale, + factor=factor), color=color, opacity=opacity, backface_culling=backface_culling @@ -441,10 +436,10 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8, geom = cone.GetOutput() _add_mesh( self.plotter, - grid.glyph(orient='vec', - scale=scale, - factor=factor, - geom=geom), + mesh=grid.glyph(orient='vec', + scale=scale, + factor=factor, + geom=geom), color=color, opacity=opacity, backface_culling=backface_culling @@ -468,10 +463,10 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8, geom = trp.GetOutput() _add_mesh( self.plotter, - grid.glyph(orient='vec', - scale=scale, - factor=factor, - geom=geom), + mesh=grid.glyph(orient='vec', + scale=scale, + factor=factor, + geom=geom), color=color, opacity=opacity, backface_culling=backface_culling @@ -567,10 +562,30 @@ def remove_mesh(self, mesh_data): self.plotter.renderer.remove_actor(actor) +def _compute_normals(mesh): + """Patch PyVista compute_normals.""" + if 'Normals' not in mesh.point_arrays: + mesh.compute_normals( + cell_normals=False, + consistent_normals=False, + non_manifold_traversal=False, + inplace=True, + ) + + def _add_mesh(plotter, *args, **kwargs): + """Patch PyVista add_mesh.""" _process_events(plotter) - kwargs['style'] = kwargs.pop('representation', 'wireframe') - return plotter.add_mesh(*args, **kwargs) + mesh = kwargs.get('mesh') + if 'smooth_shading' in kwargs: + smooth_shading = kwargs.pop('smooth_shading') + else: + smooth_shading = True + actor = plotter.add_mesh(*args, **kwargs) + if smooth_shading and 'Normals' in mesh.point_arrays: + prop = actor.GetProperty() + prop.SetInterpolationToPhong() + return actor def _deg2rad(deg): @@ -800,20 +815,6 @@ def _update_picking_callback(plotter, plotter.picker = picker -def _add_polydata_actor(plotter, polydata, name=None, - hide=False): - mapper = vtk.vtkPolyDataMapper() - mapper.SetInputData(polydata) - - actor = vtk.vtkActor() - actor.SetMapper(mapper) - if hide: - actor.VisibilityOff() - - plotter.add_actor(actor, name=name) - return actor - - def _arrow_glyph(grid, factor): glyph = vtk.vtkGlyphSource2D() glyph.SetGlyphTypeToArrow() @@ -878,7 +879,11 @@ def _sphere(plotter, center, color, radius): sphere.SetCenter(center) sphere.Update() mesh = pyvista.wrap(sphere.GetOutput()) - actor = _add_mesh(plotter, mesh, color=color) + actor = _add_mesh( + plotter, + mesh=mesh, + color=color + ) return actor, mesh diff --git a/mne/viz/backends/tests/test_renderer.py b/mne/viz/backends/tests/test_renderer.py index ae168c1eb13..87e4075002c 100644 --- a/mne/viz/backends/tests/test_renderer.py +++ b/mne/viz/backends/tests/test_renderer.py @@ -102,7 +102,11 @@ def test_3d_backend(renderer): cam_distance = 5 * tet_size # init scene - rend = renderer.backend._Renderer(size=win_size, bgcolor=win_color) + rend = renderer.backend._Renderer( + size=win_size, + bgcolor=win_color, + smooth_shading=True, + ) rend.set_interactive() # use mesh