diff --git a/chaco/tools/legend_highlighter.py b/chaco/tools/legend_highlighter.py index c0df6cf37..5ad9c3579 100644 --- a/chaco/tools/legend_highlighter.py +++ b/chaco/tools/legend_highlighter.py @@ -1,11 +1,19 @@ -import operator - -import six.moves as sm +from itertools import chain # ETS imports from chaco.tools.api import LegendTool from traits.api import List, Float +concat = chain.from_iterable + + +def _ensure_list(obj): + """ NOTE: The Legend stores plots in a dictionary with either single + renderers as values, or lists of renderers. + This function helps us assume we're always working with lists + """ + return obj if isinstance(obj, list) else [obj] + def get_hit_plots(legend, event): if legend is None or not legend.is_in(event.x, event.y): @@ -26,7 +34,7 @@ def get_hit_plots(legend, event): ndx = legend._cached_labels.index(label) label_name = legend._cached_label_names[ndx] renderers = legend.plots[label_name] - return renderers + return _ensure_list(renderers) except (ValueError, KeyError): return [] @@ -54,35 +62,34 @@ def normal_left_down(self, event): return plots = get_hit_plots(self.component, event) - - if len(plots) > 0: - plot = plots[0] - - if event.shift_down: - # User in multi-select mode by using [shift] key. + if event.shift_down: + # User in multi-select mode by using [shift] key. + for plot in plots: if plot in self._selected_renderers: self._selected_renderers.remove(plot) else: self._selected_renderers.append(plot) - - else: - # User in single-select mode. - add_plot = plot not in self._selected_renderers - self._selected_renderers = [] - if add_plot: - self._selected_renderers.append(plot) - - if self._selected_renderers: - self._set_states(self.component.plots) - else: - self._reset_selects(self.component.plots) - plot.request_redraw() + elif plots: + # User in single-select mode. + add_plot = any(plot not in self._selected_renderers + for plot in plots) + self._selected_renderers = [] + if add_plot: + self._selected_renderers.extend(plots) + + if self._selected_renderers: + self._set_states(self.component.plots) + else: + self._reset_selects(self.component.plots) + + if plots: + plots[0].request_redraw() event.handled = True def _reset_selects(self, plots): """ Set all renderers to their default values. """ - for plot in sm.reduce(operator.add, plots.values()): + for plot in concat(_ensure_list(p) for p in plots.values()): if not hasattr(plot, '_orig_alpha'): plot._orig_alpha = plot.alpha plot._orig_line_width = plot.line_width @@ -92,7 +99,7 @@ def _reset_selects(self, plots): def _set_states(self, plots): """ Decorates a plot to indicate it is selected """ - for plot in sm.reduce(operator.add, plots.values()): + for plot in concat(_ensure_list(p) for p in plots.values()): if not hasattr(plot, '_orig_alpha'): # FIXME: These attributes should be put into the class def. plot._orig_alpha = plot.alpha diff --git a/examples/demo/multiaxis.py b/examples/demo/multiaxis.py index 451a95fbc..f6a4a3c26 100644 --- a/examples/demo/multiaxis.py +++ b/examples/demo/multiaxis.py @@ -25,8 +25,8 @@ from chaco.api import create_line_plot, add_default_axes, \ add_default_grids, OverlayPlotContainer, \ PlotLabel, Legend, PlotAxis -from chaco.tools.api import PanTool, LegendTool, TraitsTool, \ - BroadcasterTool +from chaco.tools.api import (PanTool, LegendTool, LegendHighlighter, + TraitsTool, BroadcasterTool) #=============================================================================== # # Create the Chaco plot. @@ -76,6 +76,7 @@ def _create_plot_component(): legend = Legend(component=container, padding=10, align="ur") legend.tools.append(LegendTool(legend, drag_button="right")) + legend.tools.append(LegendHighlighter(legend)) container.overlays.append(legend) # Set the list of plots on the legend