From 29b3fadb0f69ed65145e0af830f276832db57bf8 Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 10 May 2023 16:11:57 +0100 Subject: [PATCH] Handle nans in barplots; better handling of cropping. --- chaco/plots/barplot.py | 24 ++++---- chaco/plots/tests/test_barplot.py | 98 +++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 10 deletions(-) create mode 100644 chaco/plots/tests/test_barplot.py diff --git a/chaco/plots/barplot.py b/chaco/plots/barplot.py index 4e8915654..cdfa55a5b 100644 --- a/chaco/plots/barplot.py +++ b/chaco/plots/barplot.py @@ -18,7 +18,7 @@ column_stack, empty, invert, - isnan, + isfinite, transpose, zeros, ) @@ -271,22 +271,26 @@ def _gather_points(self): self._cache_valid = True return - # TODO: Until we code up a better handling of value-based culling that - # takes into account starting_value and dataspace bar widths, just use - # the index culling for now. - # value_range_mask = self.value_mapper.range.mask_data(value) - # nan_mask = invert(isnan(index_mask)) & invert(isnan(value_mask)) - # point_mask = index_mask & value_mask & nan_mask & \ - # index_range_mask & value_range_mask + # TODO: better accounting for intersection of boxes with visible region + # current strategy simply masks the index values and then does a 1D + # dilation of the mask. This will work in many situations but will + # fail on extreme zoom in. + # Ideally we would work out all the boxes and compute intersections. index_range_mask = self.index_mapper.range.mask_data(index) - nan_mask = invert(isnan(index_mask)) - point_mask = index_mask & nan_mask & index_range_mask + # include points on either side of clipped range (1D dilation) + # - not perfect, but better than simple clipping + index_range_mask[:-1] |= index_range_mask[1:] + index_range_mask[1:] |= index_range_mask[:-1] + + nan_mask = isfinite(index) & isfinite(value) + point_mask = index_mask & value_mask & nan_mask & index_range_mask if self.starting_value is None: starting_values = zeros(len(index)) else: starting_values = self.starting_value.get_data() + point_mask &= isfinite(starting_values) if self.bar_width_type == "data": half_width = self.bar_width / 2.0 diff --git a/chaco/plots/tests/test_barplot.py b/chaco/plots/tests/test_barplot.py new file mode 100644 index 000000000..d92554bd8 --- /dev/null +++ b/chaco/plots/tests/test_barplot.py @@ -0,0 +1,98 @@ +# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX +# All rights reserved. +# +# This software is provided without warranty under the terms of the BSD +# license included in LICENSE.txt and may be redistributed only under +# the conditions described in the aforementioned license. The license +# is also available online at http://www.enthought.com/licenses/BSD.txt +# +# Thanks for using Enthought open source! + +import unittest + +from numpy import alltrue, arange, nan + +from traits.testing.api import UnittestTools + +# Chaco imports +from chaco.api import ( + ArrayDataSource, + DataRange1D, + LinearMapper, + PlotGraphicsContext, +) +from chaco.plots.api import BarPlot + + +class BarPlotTest(UnittestTools, unittest.TestCase): + def setUp(self): + self.size = (250, 250) + values = arange(10.0, 0.0, -1.0) + values[2] = nan + value_data_source = ArrayDataSource(values) + value_range = DataRange1D() + value_range.add(value_data_source) + value_mapper = LinearMapper(range=value_range) + starting_value_data_source = ArrayDataSource(-values) + value_range.add(starting_value_data_source) + indices = arange(10.0) + indices[4] = nan + index_data_source = ArrayDataSource(indices) + index_range = DataRange1D() + index_range.add(index_data_source) + index_mapper = LinearMapper(range=index_range) + self.barplot = BarPlot( + index=index_data_source, + value=value_data_source, + starting_value=starting_value_data_source, + index_mapper=index_mapper, + value_mapper=value_mapper, + border_visible=False, + ) + self.barplot.outer_bounds = list(self.size) + + def test_barplot(self): + self.assertEqual(self.barplot.origin, "bottom left") + self.assertIs( + self.barplot.x_mapper, self.barplot.index_mapper + ) + self.assertIs( + self.barplot.y_mapper, self.barplot.value_mapper + ) + self.assertIs( + self.barplot.index_range, + self.barplot.index_mapper.range, + ) + self.assertIs( + self.barplot.value_range, + self.barplot.value_mapper.range, + ) + + gc = PlotGraphicsContext(self.size) + gc.render_component(self.barplot) + actual = gc.bmp_array[:, :, :] + self.assertFalse(alltrue(actual == 255)) + + def test_barplot_horizontal(self): + self.barplot.orientation = 'v' + + self.assertEqual(self.barplot.origin, "bottom left") + self.assertIs( + self.barplot.x_mapper, self.barplot.value_mapper + ) + self.assertIs( + self.barplot.y_mapper, self.barplot.index_mapper + ) + self.assertIs( + self.barplot.index_range, + self.barplot.index_mapper.range, + ) + self.assertIs( + self.barplot.value_range, + self.barplot.value_mapper.range, + ) + + gc = PlotGraphicsContext(self.size) + gc.render_component(self.barplot) + actual = gc.bmp_array[:, :, :] + self.assertFalse(alltrue(actual == 255))