Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 104 additions & 15 deletions mne/viz/backends/_pyvista.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# License: Simplified BSD

from contextlib import contextmanager
import os
import warnings

import numpy as np
Expand Down Expand Up @@ -101,6 +102,21 @@ def visible(self, state):
self.pts.SetVisibility(state)


def _enable_aa(figure, plotter):
"""Enable it everywhere except Azure."""
# XXX for some reason doing this on Azure causes access violations:
# ##[error]Cmd.exe exited with code '-1073741819'
# So for now don't use it there. Maybe has to do with setting these
# before the window has actually been made "active"...?
# For Mayavi we have an "on activated" event or so, we should look into
# using this for Azure at some point, too.
if os.getenv('AZURE_CI_WINDOWS', 'false').lower() == 'true':
return
if figure.is_active():
plotter.enable_anti_aliasing()
plotter.ren_win.LineSmoothingOn()


@copy_base_doc_to_subclass_doc
class _Renderer(_BaseRenderer):
"""Class managing rendering scene.
Expand Down Expand Up @@ -143,11 +159,13 @@ def __init__(self, fig=None, size=(600, 600), bgcolor='black',
with _disabled_depth_peeling():
self.plotter = self.figure.build()
self.plotter.hide_axes()
_enable_aa(self.figure, self.plotter)

def subplot(self, x, y):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
self.plotter.subplot(x, y)
_enable_aa(self.figure, self.plotter)

def scene(self):
return self.figure
Expand Down Expand Up @@ -255,7 +273,7 @@ def sphere(self, center, color, scale, opacity=1.0,
sphere.SetRadius(radius)
sphere.Update()
geom = sphere.GetOutput()
mesh = PolyData(center)
mesh = PolyData(np.array(center))
glyph = mesh.glyph(orient=False, scale=False,
factor=factor, geom=geom)
actor = self.plotter.add_mesh(
Expand Down Expand Up @@ -294,7 +312,7 @@ def tube(self, origin, destination, radius=0.001, color='white',
def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8,
glyph_height=None, glyph_center=None, glyph_resolution=None,
opacity=1.0, scale_mode='none', scalars=None,
backface_culling=False):
backface_culling=False, line_width=2., name=None):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
factor = scale
Expand All @@ -306,21 +324,21 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8,
cells = np.c_[np.full(n_points, 1), range(n_points)]
grid = UnstructuredGrid(offset, cells, cell_type, points)
grid.point_arrays['vec'] = vectors
if scale_mode == "scalar":
if scale_mode == 'scalar':
grid.point_arrays['mag'] = np.array(scalars)
scale = 'mag'
else:
scale = False
if mode == "arrow":
if mode == '2darrow':
return _arrow_glyph(grid, factor)
elif mode == 'arrow' or mode == '3darrow':
self.plotter.add_mesh(grid.glyph(orient='vec',
scale=scale,
factor=factor),
color=color,
opacity=opacity,
backface_culling=backface_culling,
smooth_shading=self.figure.
smooth_shading)
elif mode == "cone":
backface_culling=backface_culling)
elif mode == 'cone':
cone = vtk.vtkConeSource()
if glyph_height is not None:
cone.SetHeight(glyph_height)
Expand All @@ -337,11 +355,9 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8,
geom=geom),
color=color,
opacity=opacity,
backface_culling=backface_culling,
smooth_shading=self.figure.
smooth_shading)
backface_culling=backface_culling)

elif mode == "cylinder":
elif mode == 'cylinder':
cylinder = vtk.vtkCylinderSource()
cylinder.SetHeight(glyph_height)
cylinder.SetRadius(0.15)
Expand All @@ -364,9 +380,7 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8,
geom=geom),
color=color,
opacity=opacity,
backface_culling=backface_culling,
smooth_shading=self.figure.
smooth_shading)
backface_culling=backface_culling)

def text2d(self, x_window, y_window, text, size=14, color='white',
justification=None):
Expand Down Expand Up @@ -441,6 +455,11 @@ def project(self, xyz, ch_names):

return _Projection(xy=xy, pts=pts)

def enable_depth_peeling(self):
if not self.figure.store['off_screen']:
for renderer in self.plotter.renderers:
renderer.enable_depth_peeling()


def _deg2rad(deg):
return deg * np.pi / 180.
Expand Down Expand Up @@ -644,6 +663,76 @@ def _update_picking_callback(plotter,
plotter.picker = picker


def _add_polydata_actor(plotter, polydata, name=None,
hide=False):
mapper = vtk.vtkPolyDataMapper()
mapper.SetInputData(polydata)

actor = vtk.vtkActor()
actor.SetMapper(mapper)
if hide:
actor.VisibilityOff()

plotter.add_actor(actor, name=name)
return actor


def _arrow_glyph(grid, factor):
glyph = vtk.vtkGlyphSource2D()
glyph.SetGlyphTypeToArrow()
glyph.FilledOff()
glyph.Update()
geom = glyph.GetOutput()

# fix position
tr = vtk.vtkTransform()
tr.Translate(0.5, 0., 0.)
trp = vtk.vtkTransformPolyDataFilter()
trp.SetInputData(geom)
trp.SetTransform(tr)
trp.Update()
geom = trp.GetOutput()

polydata = _glyph(
grid,
scale_mode='vector',
scalars=False,
orient='vec',
factor=factor,
geom=geom,
)
return pyvista.wrap(polydata)


def _glyph(dataset, scale_mode='scalar', orient=True, scalars=True, factor=1.0,
geom=None, tolerance=0.0, absolute=False, clamping=False, rng=None):
if geom is None:
arrow = vtk.vtkArrowSource()
arrow.Update()
geom = arrow.GetOutput()
alg = vtk.vtkGlyph3D()
alg.SetSourceData(geom)
if isinstance(scalars, str):
dataset.active_scalars_name = scalars
if isinstance(orient, str):
dataset.active_vectors_name = orient
orient = True
if scale_mode == 'scalar':
alg.SetScaleModeToScaleByScalar()
elif scale_mode == 'vector':
alg.SetScaleModeToScaleByVector()
else:
alg.SetScaleModeToDataScalingOff()
if rng is not None:
alg.SetRange(rng)
alg.SetOrient(orient)
alg.SetInputData(dataset)
alg.SetScaleFactor(factor)
alg.SetClamping(clamping)
alg.Update()
return alg.GetOutput()


def _require_minimum_version(version_required):
from distutils.version import LooseVersion
version = LooseVersion(pyvista.__version__)
Expand Down