From baf15141911060f23157fb3e55825a0bc355e694 Mon Sep 17 00:00:00 2001 From: Robert Kern Date: Tue, 23 Sep 2014 14:36:14 +0100 Subject: [PATCH 1/4] BUG: Remove equality comparisons with None. --- mayavi/tools/decorations.py | 2 +- mayavi/tools/filters.py | 4 ++-- mayavi/tools/helper_functions.py | 4 ++-- mayavi/tools/modules.py | 2 +- mayavi/tools/pipe_base.py | 4 ++-- mlab_reference.py | 2 +- tvtk/pyface/tvtk_scene.py | 2 +- tvtk/util/gradient_editor.py | 4 ++-- tvtk/util/tk_gradient_editor.py | 4 ++-- 9 files changed, 14 insertions(+), 14 deletions(-) diff --git a/mayavi/tools/decorations.py b/mayavi/tools/decorations.py index 5fa8ee69b..5a0093d22 100644 --- a/mayavi/tools/decorations.py +++ b/mayavi/tools/decorations.py @@ -242,7 +242,7 @@ def __init__(self, *args, **kwargs): raise ValueError("Wrong number of arguments") # Try to find an existing module, if not add one to the pipeline - if parent == None: + if parent is None: target = self._scene else: target = parent diff --git a/mayavi/tools/filters.py b/mayavi/tools/filters.py index 11c43cc72..3baa0577a 100644 --- a/mayavi/tools/filters.py +++ b/mayavi/tools/filters.py @@ -69,7 +69,7 @@ class ThresholdFactory(PipeFactory): low = Trait(None, None, CFloat, help="The lower threshold") def _low_changed(self): - if self.low == None: + if self.low is None: pass else: self._target.lower_threshold = self.low @@ -77,7 +77,7 @@ def _low_changed(self): up = Trait(None, None, CFloat, help="The upper threshold") def _up_changed(self): - if self.up == None: + if self.up is None: pass else: self._target.upper_threshold = self.up diff --git a/mayavi/tools/helper_functions.py b/mayavi/tools/helper_functions.py index 6776cb41a..2979247ab 100644 --- a/mayavi/tools/helper_functions.py +++ b/mayavi/tools/helper_functions.py @@ -525,7 +525,7 @@ def __call_internal__(self, *args, **kwargs): self.store_kwargs(kwargs) # Copy the pipeline so as not to modify it for the next call self.pipeline = self._pipeline[:] - if self.kwargs['tube_radius'] == None: + if self.kwargs['tube_radius'] is None: self.pipeline.remove(TubeFactory) self.pipeline.remove(StripperFactory) return self.build_pipeline() @@ -834,7 +834,7 @@ def __call_internal__(self, *args, **kwargs): self.pipeline.remove(GlyphFactory) self.pipeline = [PolyDataNormalsFactory, ] + self.pipeline else: - if self.kwargs['tube_radius'] == None: + if self.kwargs['tube_radius'] is None: self.pipeline.remove(TubeFactory) if not self.kwargs['representation'] == 'fancymesh': self.pipeline.remove(GlyphFactory) diff --git a/mayavi/tools/modules.py b/mayavi/tools/modules.py index 1a0b1373b..d762add05 100644 --- a/mayavi/tools/modules.py +++ b/mayavi/tools/modules.py @@ -134,7 +134,7 @@ def _colormap_changed(self): If None, the max of the data will be used""") def _vmin_changed(self): - if self.vmin == None and self.vmax == None: + if self.vmin is None and self.vmax is None: self._target.module_manager.scalar_lut_manager.use_default_range\ = True return diff --git a/mayavi/tools/pipe_base.py b/mayavi/tools/pipe_base.py index a75424a8e..a4551b190 100644 --- a/mayavi/tools/pipe_base.py +++ b/mayavi/tools/pipe_base.py @@ -179,7 +179,7 @@ def set(self, trait_change_notify=True, **traits): callback() self._anytrait_changed(trait, value) except TraitError: - if value == None: + if value is None: # This means "default" pass else: @@ -192,7 +192,7 @@ def _anytrait_changed(self, name, value): # Private attribute return # hasattr(traits, "adapts") always returns True :-<. - if not trait.adapts == None: + if not trait.adapts is None: components = trait.adapts.split('.') obj = get_obj(self._target, components[:-1]) setattr(obj, components[-1], value) diff --git a/mlab_reference.py b/mlab_reference.py index ad29b535c..53ce5a93b 100644 --- a/mlab_reference.py +++ b/mlab_reference.py @@ -95,7 +95,7 @@ def document_function(func, func_name=None, example_code=None, """ Creates a rst documentation string for the function, with an image and a example code, if given. """ - if func_name==None: + if func_name is None: func_name = func.__name__ func_doc = func.__doc__ diff --git a/tvtk/pyface/tvtk_scene.py b/tvtk/pyface/tvtk_scene.py index 735476bf7..a542d235d 100644 --- a/tvtk/pyface/tvtk_scene.py +++ b/tvtk/pyface/tvtk_scene.py @@ -519,7 +519,7 @@ def save_rib(self, file_name, bg=0, resolution=None, resfactor=1.0): resfactor -- The resolution factor which scales the resolution. """ - if resolution == None: + if resolution is None: # get present window size Nx, Ny = self.render_window.size else: diff --git a/tvtk/util/gradient_editor.py b/tvtk/util/gradient_editor.py index 8a75b4419..18d9001a0 100644 --- a/tvtk/util/gradient_editor.py +++ b/tvtk/util/gradient_editor.py @@ -1092,7 +1092,7 @@ def __init__(self, master, gradient_table, color_space, width, height): self.cur_drag = None #<- [channel,control_point] while something is dragged. def find_control_point(self, x, y): - """Check if a control point lies near (x,y) or near x if y == None. + """Check if a control point lies near (x,y) or near x if y is None. returns [channel, control point] if found, None otherwise""" for channel in self.channels: for control_point in self.table.control_points: @@ -1103,7 +1103,7 @@ def find_control_point(self, x, y): point_x = channel.get_pos_index( control_point.pos ) point_y = channel.get_value_index( control_point.color ) y_ = y - if ( None == y_ ): + if ( y_ is None ): y_ = point_y if ( (point_x-x)**2 + (point_y-y_)**2 <= self.control_pt_click_tolerance**2 ): return [channel, control_point] diff --git a/tvtk/util/tk_gradient_editor.py b/tvtk/util/tk_gradient_editor.py index c265e2df6..59d9fc0ed 100644 --- a/tvtk/util/tk_gradient_editor.py +++ b/tvtk/util/tk_gradient_editor.py @@ -245,7 +245,7 @@ def update(self): channel.paint(self.canvas) def find_control_point(self, x, y): - """Check if a control point lies near (x,y) or near x if y == None. + """Check if a control point lies near (x,y) or near x if y is None. returns [channel, control point] if found, None otherwise""" for channel in self.channels: for control_point in self.table.control_points: @@ -256,7 +256,7 @@ def find_control_point(self, x, y): point_x = channel.get_pos_index( control_point.pos ) point_y = channel.get_value_index( control_point.color ) y_ = y - if ( None == y_ ): + if ( y_ is None ): y_ = point_y if ( (point_x-x)**2 + (point_y-y_)**2 <= self.control_pt_click_tolerance**2 ): return [channel, control_point] From 13aa42e87c0d41afc10dd02009c81566422ad175 Mon Sep 17 00:00:00 2001 From: Robert Kern Date: Tue, 23 Sep 2014 15:50:09 +0100 Subject: [PATCH 2/4] ENH: Add new ArrayOrNone and ArrayNumberOrNone traits. These avoid doing implicit comparisons of arrays with `None`. --- mayavi/core/trait_defs.py | 42 +++++++++++++++++++++++++- mayavi/tests/test_mayavi_traits.py | 48 ++++++++++++++++++++++++++++-- 2 files changed, 87 insertions(+), 3 deletions(-) diff --git a/mayavi/core/trait_defs.py b/mayavi/core/trait_defs.py index 5cef12da5..b02b0a684 100644 --- a/mayavi/core/trait_defs.py +++ b/mayavi/core/trait_defs.py @@ -16,7 +16,10 @@ # #--------------------------------------------------------------------------- -from traits.api import Property, TraitFactory, TraitError, TraitType, Int +import operator + +from traits.api import (CArray, Int, NO_COMPARE, Property, TraitError, + TraitFactory, TraitType) from traitsui.api import EnumEditor from traits.traits import trait_cast @@ -252,3 +255,40 @@ def callback(value): object.remove_trait(attr) return status + +class ArrayOrNone(CArray): + """ Either an array-like object or None. + """ + + def __init__(self, *args, **metadata): + metadata['comparison_mode'] = NO_COMPARE + super(ArrayOrNone, self).__init__(*args, **metadata) + + def validate(self, object, name, value): + if value is None: + return value + return super(ArrayOrNone, self).validate(object, name, value) + + def get_default_value(self): + return (0, None) + + +class ArrayNumberOrNone(CArray): + """ Either an array-like, number converted to a 1D array, or None. + """ + + def __init__(self, *args, **metadata): + metadata['comparison_mode'] = NO_COMPARE + super(ArrayNumberOrNone, self).__init__(*args, **metadata) + + def validate(self, object, name, value): + if value is None: + return value + elif operator.isNumberType(value): + # Local import to avoid explicit dependency. + import numpy + value = numpy.atleast_1d(value) + return super(ArrayNumberOrNone, self).validate(object, name, value) + + def get_default_value(self): + return (0, None) diff --git a/mayavi/tests/test_mayavi_traits.py b/mayavi/tests/test_mayavi_traits.py index d0a610fd0..72e1db8ca 100644 --- a/mayavi/tests/test_mayavi_traits.py +++ b/mayavi/tests/test_mayavi_traits.py @@ -9,10 +9,10 @@ import numpy from traits.api import (HasTraits, Either, Array, Any, TraitError, Float, Int) -from mayavi.core.trait_defs import ShadowProperty +from mayavi.core.trait_defs import (ArrayNumberOrNone, ArrayOrNone, + ShadowProperty) -ArrayOrNone = Either(None, Array) class DataNotSmart(HasTraits): x = ShadowProperty(ArrayOrNone, smart_notify=False) # Test attribute. @@ -34,6 +34,20 @@ class Simple(HasTraits): def _x_changed(self, value): self._test += 1 +class HasArrays(HasTraits): + x = ArrayOrNone + y = ArrayNumberOrNone + + # Test attribute. + _test_x = Int(0) + _test_y = Int(0) + + def _x_changed(self, value): + self._test_x += 1 + + def _y_changed(self, value): + self._test_y += 1 + class TestShadowProperty(unittest.TestCase): def test_simple_case(self): @@ -98,6 +112,36 @@ def test_set_trait_change_notify(self): self.assertEqual(s.trait_names(), trait_names) self.assertEqual(s._notifiers(False), None) + +class TestArrayOrNone(unittest.TestCase): + + def test_default(self): + a = HasArrays() + self.assertIsNone(a.x) + self.assertIsNone(a.y) + + def test_no_compare(self): + a = HasArrays() + a.x = numpy.arange(10) + self.assertEqual(a._test_x, 1) + a.x = numpy.arange(10) + self.assertEqual(a._test_x, 2) + a.x = a.x + self.assertEqual(a._test_x, 3) + a.x = None + self.assertEqual(a._test_x, 4) + + a.y = numpy.arange(10) + self.assertEqual(a._test_y, 1) + a.y = 1.0 + self.assertEqual(a.y.shape, (1,)) + self.assertEqual(a._test_y, 2) + a.y = a.y + self.assertEqual(a._test_y, 3) + a.y = None + self.assertEqual(a._test_y, 4) + + if __name__ == '__main__': unittest.main() From b79f7b0c172fb10e81cc3155bccf7de21ca3f5f1 Mon Sep 17 00:00:00 2001 From: Robert Kern Date: Tue, 23 Sep 2014 15:50:56 +0100 Subject: [PATCH 3/4] BUG: Avoid a comparison of arrays with None. --- mayavi/tests/test_mlab_source_integration.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mayavi/tests/test_mlab_source_integration.py b/mayavi/tests/test_mlab_source_integration.py index 3b1b5dbe5..37025d292 100644 --- a/mayavi/tests/test_mlab_source_integration.py +++ b/mayavi/tests/test_mlab_source_integration.py @@ -29,8 +29,9 @@ def tearDown(self): def all_close(self, a, b): """ Similar to numpy's allclose, but works also for a=None. """ - if None in (a, b): - self.assert_(a==b) + if a is None or b is None: + self.assertIsNone(a) + self.assertIsNone(b) else: self.assert_(np.allclose(a, a)) From f960c28edd6caa3205794eab7df5538acbd7d2d2 Mon Sep 17 00:00:00 2001 From: Robert Kern Date: Tue, 23 Sep 2014 15:51:14 +0100 Subject: [PATCH 4/4] ENH: Use the new ArrayOrNone and ArrayNumberOrNone traits. --- .../tools/data_wizards/data_source_factory.py | 5 ++--- mayavi/tools/sources.py | 21 ++----------------- 2 files changed, 4 insertions(+), 22 deletions(-) diff --git a/mayavi/tools/data_wizards/data_source_factory.py b/mayavi/tools/data_wizards/data_source_factory.py index 01933df15..429e9da78 100644 --- a/mayavi/tools/data_wizards/data_source_factory.py +++ b/mayavi/tools/data_wizards/data_source_factory.py @@ -2,16 +2,15 @@ from numpy import c_, zeros, arange from traits.api import HasStrictTraits, \ - true, false, CArray, Trait, Instance + true, false, Instance from mayavi.sources.vtk_data_source import VTKDataSource from mayavi.sources.array_source import ArraySource from mayavi.core.source import Source +from mayavi.core.trait_defs import ArrayOrNone from tvtk.api import tvtk -ArrayOrNone = Trait(None, (None, CArray)) - ############################################################################ # The DataSourceFactory class diff --git a/mayavi/tools/sources.py b/mayavi/tools/sources.py index 5958738d1..7390a07da 100644 --- a/mayavi/tools/sources.py +++ b/mayavi/tools/sources.py @@ -7,17 +7,15 @@ # Copyright (c) 2007-2010, Enthought, Inc. # License: BSD Style. -import operator - import numpy as np -from traits.api import (HasTraits, Instance, CArray, Either, - Bool, on_trait_change, NO_COMPARE) +from traits.api import Bool, HasTraits, Instance, on_trait_change from tvtk.api import tvtk from tvtk.common import camel2enthought from mayavi.sources.array_source import ArraySource from mayavi.core.registry import registry +from mayavi.core.trait_defs import ArrayNumberOrNone, ArrayOrNone import tools from engine_manager import get_null_engine, engine_manager @@ -28,17 +26,6 @@ ] -############################################################################### -# A subclass of CArray that will accept floats and do a np.atleast_1d -############################################################################### -class CArrayOrNumber(CArray): - - def validate(self, object, name, value): - if operator.isNumberType(value): - value = np.atleast_1d(value) - return CArray.validate(self, object, name, value) - - ############################################################################### # `MlabSource` class. ############################################################################### @@ -127,10 +114,6 @@ def _m_data_changed(self, ds): ds.mlab_source = self -ArrayOrNone = Either(None, CArray, comparison_mode=NO_COMPARE) -ArrayNumberOrNone = Either(None, CArrayOrNumber, comparison_mode=NO_COMPARE) - - ############################################################################### # `MGlyphSource` class. ###############################################################################