Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion mayavi/core/trait_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
#
#---------------------------------------------------------------------------

from traits.api import Property, TraitFactory, TraitError, TraitType, Int
import operator

from traits.api import (CArray, Int, NO_COMPARE, Property, TraitError,
TraitFactory, TraitType)
from traitsui.api import EnumEditor
from traits.traits import trait_cast

Expand Down Expand Up @@ -252,3 +255,40 @@ def callback(value):
object.remove_trait(attr)
return status


class ArrayOrNone(CArray):
""" Either an array-like object or None.
"""

def __init__(self, *args, **metadata):
metadata['comparison_mode'] = NO_COMPARE
super(ArrayOrNone, self).__init__(*args, **metadata)

def validate(self, object, name, value):
if value is None:
return value
return super(ArrayOrNone, self).validate(object, name, value)

def get_default_value(self):
return (0, None)


class ArrayNumberOrNone(CArray):
""" Either an array-like, number converted to a 1D array, or None.
"""

def __init__(self, *args, **metadata):
metadata['comparison_mode'] = NO_COMPARE
super(ArrayNumberOrNone, self).__init__(*args, **metadata)

def validate(self, object, name, value):
if value is None:
return value
elif operator.isNumberType(value):
# Local import to avoid explicit dependency.
import numpy
value = numpy.atleast_1d(value)
return super(ArrayNumberOrNone, self).validate(object, name, value)

def get_default_value(self):
return (0, None)
48 changes: 46 additions & 2 deletions mayavi/tests/test_mayavi_traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import numpy
from traits.api import (HasTraits, Either, Array, Any,
TraitError, Float, Int)
from mayavi.core.trait_defs import ShadowProperty
from mayavi.core.trait_defs import (ArrayNumberOrNone, ArrayOrNone,
ShadowProperty)


ArrayOrNone = Either(None, Array)
class DataNotSmart(HasTraits):
x = ShadowProperty(ArrayOrNone, smart_notify=False)
# Test attribute.
Expand All @@ -34,6 +34,20 @@ class Simple(HasTraits):
def _x_changed(self, value):
self._test += 1

class HasArrays(HasTraits):
x = ArrayOrNone
y = ArrayNumberOrNone

# Test attribute.
_test_x = Int(0)
_test_y = Int(0)

def _x_changed(self, value):
self._test_x += 1

def _y_changed(self, value):
self._test_y += 1


class TestShadowProperty(unittest.TestCase):
def test_simple_case(self):
Expand Down Expand Up @@ -98,6 +112,36 @@ def test_set_trait_change_notify(self):
self.assertEqual(s.trait_names(), trait_names)
self.assertEqual(s._notifiers(False), None)


class TestArrayOrNone(unittest.TestCase):

def test_default(self):
a = HasArrays()
self.assertIsNone(a.x)
self.assertIsNone(a.y)

def test_no_compare(self):
a = HasArrays()
a.x = numpy.arange(10)
self.assertEqual(a._test_x, 1)
a.x = numpy.arange(10)
self.assertEqual(a._test_x, 2)
a.x = a.x
self.assertEqual(a._test_x, 3)
a.x = None
self.assertEqual(a._test_x, 4)

a.y = numpy.arange(10)
self.assertEqual(a._test_y, 1)
a.y = 1.0
self.assertEqual(a.y.shape, (1,))
self.assertEqual(a._test_y, 2)
a.y = a.y
self.assertEqual(a._test_y, 3)
a.y = None
self.assertEqual(a._test_y, 4)


if __name__ == '__main__':
unittest.main()

5 changes: 3 additions & 2 deletions mayavi/tests/test_mlab_source_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ def tearDown(self):
def all_close(self, a, b):
""" Similar to numpy's allclose, but works also for a=None.
"""
if None in (a, b):
self.assert_(a==b)
if a is None or b is None:
self.assertIsNone(a)
self.assertIsNone(b)
else:
self.assert_(np.allclose(a, a))

Expand Down
5 changes: 2 additions & 3 deletions mayavi/tools/data_wizards/data_source_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@
from numpy import c_, zeros, arange

from traits.api import HasStrictTraits, \
true, false, CArray, Trait, Instance
true, false, Instance

from mayavi.sources.vtk_data_source import VTKDataSource
from mayavi.sources.array_source import ArraySource
from mayavi.core.source import Source
from mayavi.core.trait_defs import ArrayOrNone

from tvtk.api import tvtk

ArrayOrNone = Trait(None, (None, CArray))


############################################################################
# The DataSourceFactory class
Expand Down
2 changes: 1 addition & 1 deletion mayavi/tools/decorations.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def __init__(self, *args, **kwargs):
raise ValueError("Wrong number of arguments")

# Try to find an existing module, if not add one to the pipeline
if parent == None:
if parent is None:
target = self._scene
else:
target = parent
Expand Down
4 changes: 2 additions & 2 deletions mayavi/tools/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,15 @@ class ThresholdFactory(PipeFactory):
low = Trait(None, None, CFloat, help="The lower threshold")

def _low_changed(self):
if self.low == None:
if self.low is None:
pass
else:
self._target.lower_threshold = self.low

up = Trait(None, None, CFloat, help="The upper threshold")

def _up_changed(self):
if self.up == None:
if self.up is None:
pass
else:
self._target.upper_threshold = self.up
Expand Down
4 changes: 2 additions & 2 deletions mayavi/tools/helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def __call_internal__(self, *args, **kwargs):
self.store_kwargs(kwargs)
# Copy the pipeline so as not to modify it for the next call
self.pipeline = self._pipeline[:]
if self.kwargs['tube_radius'] == None:
if self.kwargs['tube_radius'] is None:
self.pipeline.remove(TubeFactory)
self.pipeline.remove(StripperFactory)
return self.build_pipeline()
Expand Down Expand Up @@ -834,7 +834,7 @@ def __call_internal__(self, *args, **kwargs):
self.pipeline.remove(GlyphFactory)
self.pipeline = [PolyDataNormalsFactory, ] + self.pipeline
else:
if self.kwargs['tube_radius'] == None:
if self.kwargs['tube_radius'] is None:
self.pipeline.remove(TubeFactory)
if not self.kwargs['representation'] == 'fancymesh':
self.pipeline.remove(GlyphFactory)
Expand Down
2 changes: 1 addition & 1 deletion mayavi/tools/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _colormap_changed(self):
If None, the max of the data will be used""")

def _vmin_changed(self):
if self.vmin == None and self.vmax == None:
if self.vmin is None and self.vmax is None:
self._target.module_manager.scalar_lut_manager.use_default_range\
= True
return
Expand Down
4 changes: 2 additions & 2 deletions mayavi/tools/pipe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def set(self, trait_change_notify=True, **traits):
callback()
self._anytrait_changed(trait, value)
except TraitError:
if value == None:
if value is None:
# This means "default"
pass
else:
Expand All @@ -192,7 +192,7 @@ def _anytrait_changed(self, name, value):
# Private attribute
return
# hasattr(traits, "adapts") always returns True :-<.
if not trait.adapts == None:
if not trait.adapts is None:
components = trait.adapts.split('.')
obj = get_obj(self._target, components[:-1])
setattr(obj, components[-1], value)
21 changes: 2 additions & 19 deletions mayavi/tools/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,15 @@
# Copyright (c) 2007-2010, Enthought, Inc.
# License: BSD Style.

import operator

import numpy as np

from traits.api import (HasTraits, Instance, CArray, Either,
Bool, on_trait_change, NO_COMPARE)
from traits.api import Bool, HasTraits, Instance, on_trait_change
from tvtk.api import tvtk
from tvtk.common import camel2enthought

from mayavi.sources.array_source import ArraySource
from mayavi.core.registry import registry
from mayavi.core.trait_defs import ArrayNumberOrNone, ArrayOrNone

import tools
from engine_manager import get_null_engine, engine_manager
Expand All @@ -28,17 +26,6 @@
]


###############################################################################
# A subclass of CArray that will accept floats and do a np.atleast_1d
###############################################################################
class CArrayOrNumber(CArray):

def validate(self, object, name, value):
if operator.isNumberType(value):
value = np.atleast_1d(value)
return CArray.validate(self, object, name, value)


###############################################################################
# `MlabSource` class.
###############################################################################
Expand Down Expand Up @@ -127,10 +114,6 @@ def _m_data_changed(self, ds):
ds.mlab_source = self


ArrayOrNone = Either(None, CArray, comparison_mode=NO_COMPARE)
ArrayNumberOrNone = Either(None, CArrayOrNumber, comparison_mode=NO_COMPARE)


###############################################################################
# `MGlyphSource` class.
###############################################################################
Expand Down
2 changes: 1 addition & 1 deletion mlab_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def document_function(func, func_name=None, example_code=None,
""" Creates a rst documentation string for the function, with an
image and a example code, if given.
"""
if func_name==None:
if func_name is None:
func_name = func.__name__

func_doc = func.__doc__
Expand Down
2 changes: 1 addition & 1 deletion tvtk/pyface/tvtk_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def save_rib(self, file_name, bg=0, resolution=None, resfactor=1.0):

resfactor -- The resolution factor which scales the resolution.
"""
if resolution == None:
if resolution is None:
# get present window size
Nx, Ny = self.render_window.size
else:
Expand Down
4 changes: 2 additions & 2 deletions tvtk/util/gradient_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ def __init__(self, master, gradient_table, color_space, width, height):
self.cur_drag = None #<- [channel,control_point] while something is dragged.

def find_control_point(self, x, y):
"""Check if a control point lies near (x,y) or near x if y == None.
"""Check if a control point lies near (x,y) or near x if y is None.
returns [channel, control point] if found, None otherwise"""
for channel in self.channels:
for control_point in self.table.control_points:
Expand All @@ -1103,7 +1103,7 @@ def find_control_point(self, x, y):
point_x = channel.get_pos_index( control_point.pos )
point_y = channel.get_value_index( control_point.color )
y_ = y
if ( None == y_ ):
if ( y_ is None ):
y_ = point_y
if ( (point_x-x)**2 + (point_y-y_)**2 <= self.control_pt_click_tolerance**2 ):
return [channel, control_point]
Expand Down
4 changes: 2 additions & 2 deletions tvtk/util/tk_gradient_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def update(self):
channel.paint(self.canvas)

def find_control_point(self, x, y):
"""Check if a control point lies near (x,y) or near x if y == None.
"""Check if a control point lies near (x,y) or near x if y is None.
returns [channel, control point] if found, None otherwise"""
for channel in self.channels:
for control_point in self.table.control_points:
Expand All @@ -256,7 +256,7 @@ def find_control_point(self, x, y):
point_x = channel.get_pos_index( control_point.pos )
point_y = channel.get_value_index( control_point.color )
y_ = y
if ( None == y_ ):
if ( y_ is None ):
y_ = point_y
if ( (point_x-x)**2 + (point_y-y_)**2 <= self.control_pt_click_tolerance**2 ):
return [channel, control_point]
Expand Down