From 68edfce4653e8336509dec7540fd3e1746c96ad2 Mon Sep 17 00:00:00 2001 From: Joris Vankerschaver Date: Sat, 14 Jul 2018 16:06:30 -0500 Subject: [PATCH 1/2] FIX: Ensure seed is int. --- chaco/jitterplot.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/chaco/jitterplot.py b/chaco/jitterplot.py index bff7d22a2..c4d4c55e0 100644 --- a/chaco/jitterplot.py +++ b/chaco/jitterplot.py @@ -1,9 +1,7 @@ - from __future__ import absolute_import from math import sqrt -import six import six.moves as sm import numpy as np @@ -63,7 +61,7 @@ def map_screen(self, data_array): return np.vstack((ys, xs)).T def _make_jitter_vals(self, data_array): - random_state = np.random.RandomState(data_array[:100]) + random_state = np.random.RandomState(data_array[:100].astype(int)) numpts = len(data_array) vals = random_state.uniform(0, self.jitter_width, numpts) vals += self._marker_position From f54719113f1a34c19ca2368dc76e88d92aaba020 Mon Sep 17 00:00:00 2001 From: Joris Vankerschaver Date: Sat, 14 Jul 2018 16:07:19 -0500 Subject: [PATCH 2/2] MAINT: Fix up create_plot_snapshots script. --- .../plot_types/create_plot_snapshots.py | 209 ++++++++++-------- 1 file changed, 113 insertions(+), 96 deletions(-) diff --git a/examples/user_guide/plot_types/create_plot_snapshots.py b/examples/user_guide/plot_types/create_plot_snapshots.py index 59a57ae42..4bada4068 100644 --- a/examples/user_guide/plot_types/create_plot_snapshots.py +++ b/examples/user_guide/plot_types/create_plot_snapshots.py @@ -2,7 +2,9 @@ Relies on sklearn for the datasets. """ +from __future__ import print_function +import argparse from functools import partial from chaco.array_data_source import ArrayDataSource @@ -28,7 +30,6 @@ import chaco.default_colormaps as dc from enable.colors import color_table -import scipy.stats import scipy.stats import numpy as np import sklearn.datasets as datasets @@ -114,9 +115,11 @@ def save_plot(plot, filename): # ---- factories for example plots def get_line_plot(render_style): - prices = datasets.fetch_mldata('regression-datasets stock') + prng = np.random.RandomState(seed=1234) + x_data = np.linspace(0, 10, 50) + y_data = x_data ** 2 + prng.randn(50) - x, y = get_data_sources(y=prices['data'][:70,0]) + x, y = get_data_sources(x=x_data, y=y_data) x_mapper, y_mapper = get_mappers(x, y) line_plot = LinePlot( @@ -126,10 +129,11 @@ def get_line_plot(render_style): **PLOT_DEFAULTS ) - add_axes(line_plot, x_label='Days', y_label='Stock price') + add_axes(line_plot, x_label='x', y_label='y') return line_plot + get_line_plot_connected = partial(get_line_plot, "connectedpoints") get_line_plot_hold = partial(get_line_plot, "hold") get_line_plot_connectedhold = partial(get_line_plot, "connectedhold") @@ -138,7 +142,7 @@ def get_line_plot(render_style): def get_scatter_plot(): boston = datasets.load_boston() prices = boston['target'] - lower_status = boston['data'][:,-1] + lower_status = boston['data'][:, -1] x, y = get_data_sources(x=lower_status, y=prices) x_mapper, y_mapper = get_mappers(x, y) @@ -160,8 +164,8 @@ def get_scatter_plot(): def get_cmap_scatter_plot(): boston = datasets.load_boston() prices = boston['target'] - lower_status = boston['data'][:,-1] - nox = boston['data'][:,4] + lower_status = boston['data'][:, -1] + nox = boston['data'][:, 4] x, y = get_data_sources(x=lower_status, y=prices) x_mapper, y_mapper = get_mappers(x, y) @@ -191,9 +195,9 @@ def get_cmap_scatter_plot(): def get_4d_scatter_plot(): boston = datasets.load_boston() prices = boston['target'] - lower_status = boston['data'][:,-1] - tax = boston['data'][:,9] - nox = boston['data'][:,4] + lower_status = boston['data'][:, -1] + tax = boston['data'][:, 9] + nox = boston['data'][:, 4] x, y = get_data_sources(x=lower_status, y=prices) x_mapper, y_mapper = get_mappers(x, y) @@ -211,7 +215,7 @@ def get_4d_scatter_plot(): index_mapper=x_mapper, value_mapper=y_mapper, color_data=color_source, color_mapper=color_mapper, - fill_alpha = 0.8, + fill_alpha=0.8, marker='circle', marker_size=marker_size, title='Size represents property-tax rate, ' @@ -229,8 +233,8 @@ def get_4d_scatter_plot(): def get_variable_size_scatter_plot(): boston = datasets.load_boston() prices = boston['target'] - lower_status = boston['data'][:,-1] - tax = boston['data'][:,9] + lower_status = boston['data'][:, -1] + tax = boston['data'][:, 9] x, y = get_data_sources(x=lower_status, y=prices) x_mapper, y_mapper = get_mappers(x, y) @@ -263,7 +267,8 @@ def get_jitter_plot(): jitter_plot = JitterPlot( index=y, - mapper=y_mapper, + index_mapper=y_mapper, + orientation='h', marker='circle', jitter_width=100, **PLOT_DEFAULTS @@ -272,7 +277,7 @@ def get_jitter_plot(): x_axis = PlotAxis(orientation='bottom', title='Median house prices', - mapper=jitter_plot.mapper, + mapper=jitter_plot.index_mapper, component=jitter_plot, **AXIS_DEFAULTS) @@ -283,7 +288,7 @@ def get_jitter_plot(): def get_candle_plot(): means = np.array([0.2, 0.8, 0.5]) - stds = np.array([1.0, 0.3, 0.5]) + stds = np.array([1.0, 0.3, 0.5]) data = scipy.stats.norm(loc=means, scale=stds).rvs((100, 3)) x = ArrayDataSource(np.arange(3)) @@ -319,7 +324,7 @@ def get_candle_plot(): def get_errorbar_plot(): x = np.linspace(1., 5., 10) y = 3.2 * x**2 + 4.0 - y_with_noise = (y[None,:] + y_with_noise = (y[None, :] + scipy.stats.norm(loc=0, scale=2.8).rvs((10, 1))) means = y_with_noise.mean(0) @@ -350,9 +355,11 @@ def get_errorbar_plot(): def get_filled_line_plot(): - prices = datasets.fetch_mldata('regression-datasets stock') + prng = np.random.RandomState(seed=1234) + x_data = np.linspace(0, 10, 50) + y_data = x_data ** 2 + prng.randn(50) - x, y = get_data_sources(y=prices['data'][:70,0]) + x, y = get_data_sources(x=x_data, y=y_data) x_mapper, y_mapper = get_mappers(x, y) line_plot = FilledLinePlot( @@ -363,21 +370,21 @@ def get_filled_line_plot(): **PLOT_DEFAULTS ) - add_axes(line_plot, x_label='Days', y_label='Stock price') + add_axes(line_plot, x_label='x', y_label='y') return line_plot def get_image_plot(): # Create some RGBA image data - image = np.zeros((200,400,4), dtype=np.uint8) - image[:,0:40,0] += 255 # Vertical red stripe - image[0:25,:,1] += 255 # Horizontal green stripe; also yellow square - image[-80:,-160:,2] += 255 # Blue square - image[:,:,3] = 255 + image = np.zeros((200, 400, 4), dtype=np.uint8) + image[:, 0:40, 0] += 255 # Vertical red stripe + image[0:25, :, 1] += 255 # Horizontal green stripe; also yellow square + image[-80:, -160:, 2] += 255 # Blue square + image[:, :, 3] = 255 index = GridDataSource(np.linspace(0, 4., 400), np.linspace(-1, 1., 200)) - index_mapper = GridMapper(range=DataRange2D(low=(0,-1), high=(4.,1.))) + index_mapper = GridMapper(range=DataRange2D(low=(0, -1), high=(4., 1.))) image_source = ImageData(data=image, value_depth=4) @@ -395,8 +402,8 @@ def get_image_plot(): def get_image_from_file(): import os.path - filename = os.path.join('..', '..', '..', - 'demo','basic','capitol.jpg') + filename = os.path.join('..', '..', + 'demo', 'basic', 'capitol.jpg') image_source = ImageData.fromfile(filename) w, h = image_source.get_width(), image_source.get_height() @@ -454,23 +461,18 @@ def get_contour_line_plot(): x, y = np.meshgrid(xs, ys) z = scipy.special.jn(2, x)*y*x - # FIXME: we have set the xbounds and ybounds manually to work around - # a bug in CountourLinePlot, see comment in contour_line_plot.py at - # line 112 (the workaround is the +1 at the end) - xs_bounds = np.linspace(xs[0], xs[-1], z.shape[1]+1) - ys_bounds = np.linspace(ys[0], ys[-1], z.shape[0]+1) - index = GridDataSource(xdata=xs_bounds, ydata=ys_bounds) + index = GridDataSource(xdata=xs, ydata=ys) index_mapper = GridMapper(range=DataRange2D(index)) value = ImageData(data=z, value_depth=1) color_mapper = dc.Blues(DataRange1D(value)) contour_plot = ContourLinePlot( - index = index, - index_mapper = index_mapper, - value = value, - colors = color_mapper, - widths = list(range(1, 11)), + index=index, + index_mapper=index_mapper, + value=value, + colors=color_mapper, + widths=list(range(1, 11)), **PLOT_DEFAULTS ) @@ -488,22 +490,17 @@ def get_contour_poly_plot(): x, y = np.meshgrid(xs, ys) z = scipy.special.jn(2, x)*y*x - # FIXME: we have set the xbounds and ybounds manually to work around - # a bug in CountourLinePlot, see comment in contour_line_plot.py at - # line 112 (the workaround is the +1 at the end) - xs_bounds = np.linspace(xs[0], xs[-1], z.shape[1]+1) - ys_bounds = np.linspace(ys[0], ys[-1], z.shape[0]+1) - index = GridDataSource(xdata=xs_bounds, ydata=ys_bounds) + index = GridDataSource(xdata=xs, ydata=ys) index_mapper = GridMapper(range=DataRange2D(index)) value = ImageData(data=z, value_depth=1) color_mapper = dc.Blues(DataRange1D(value)) contour_plot = ContourPolyPlot( - index = index, - index_mapper = index_mapper, - value = value, - colors = color_mapper, + index=index, + index_mapper=index_mapper, + value=value, + colors=color_mapper, **PLOT_DEFAULTS ) @@ -520,12 +517,12 @@ def get_polygon_plot(): x_mapper, y_mapper = get_mappers(x, y) polygon_plot = PolygonPlot( - index = x, - value = y, - index_mapper = x_mapper, - value_mapper = y_mapper, - edge_width = 4.0, - face_color = 'orange', + index=x, + value=y, + index_mapper=x_mapper, + value_mapper=y_mapper, + edge_width=4.0, + face_color='orange', **PLOT_DEFAULTS ) @@ -554,12 +551,12 @@ def get_bar_plot(): y_mapper.range.high += 0.02 bar_plot = BarPlot( - index = x, - value = y, - index_mapper = x_mapper, - value_mapper = y_mapper, - fill_color = 'blue', - bar_width = 3.0, + index=x, + value=y, + index_mapper=x_mapper, + value_mapper=y_mapper, + fill_color='blue', + bar_width=3.0, **PLOT_DEFAULTS ) @@ -586,15 +583,14 @@ def get_quiver_plot(): v_source = MultiArrayDataSource(v.T) quiver_plot = QuiverPlot( - index = x, - value = y, - vectors = v_source, - index_mapper = x_mapper, - value_mapper = y_mapper, - aspect_ratio = 1.0 + index=x, + value=y, + vectors=v_source, + index_mapper=x_mapper, + value_mapper=y_mapper, + aspect_ratio=1.0 ) - add_axes(quiver_plot, x_label='x', y_label='y') return quiver_plot @@ -620,9 +616,9 @@ def get_polar_plot(): polar_plot = PolarLineRenderer( index=x, value=y, - index_mapper = index_mapper, - value_mapper = value_mapper, - aspect_ratio = 1.0, + index_mapper=index_mapper, + value_mapper=value_mapper, + aspect_ratio=1.0, **PLOT_DEFAULTS ) polar_plot.border_visible = False @@ -631,35 +627,40 @@ def get_polar_plot(): def get_multiline_plot(): - prices = datasets.fetch_mldata('regression-datasets stock') - - T, N_LINES = 70, 5 + prng = np.random.RandomState(seed=1234) - prices_data = prices['data'][:T,:N_LINES] - prices_data -= prices_data[0,:] + x = np.linspace(0, 10, 50) + y_data = np.column_stack([ + x ** 2, + 50 * np.sin(x), + 50 * np.cos(x), + 0.5 * x ** 2 + 2 * prng.randn(50), + 0.7 * x ** 2 + prng.randn(50), + ]) # data sources for the two axes - xs = ArrayDataSource(np.arange(T)) - ys = ArrayDataSource(np.arange(N_LINES)) - y_range = DataRange1D(low=-0.5, high=N_LINES - 0.5) + xs = ArrayDataSource(np.arange(50)) + ys = ArrayDataSource(np.arange(y_data.shape[1])) + y_range = DataRange1D(low=-0.5, high=y_data.shape[1] - 0.5) y_mapper = LinearMapper(range=y_range) # data source for the multiple lines - lines_source = MultiArrayDataSource(data=prices_data.T) + lines_source = MultiArrayDataSource(data=y_data.T) colors = ['blue', 'green', 'yellow', 'orange', 'red'] + def color_generator(color_idx): return color_table[colors[color_idx]] multiline_plot = MultiLinePlot( - index = xs, - yindex = ys, - index_mapper = LinearMapper(range=DataRange1D(xs)), - value_mapper = y_mapper, - value = lines_source, - normalized_amplitude = 1.0, - use_global_bounds = False, - color_func = color_generator, + index=xs, + yindex=ys, + index_mapper=LinearMapper(range=DataRange1D(xs)), + value_mapper=y_mapper, + value=lines_source, + normalized_amplitude=1.0, + use_global_bounds=False, + color_func=color_generator, **PLOT_DEFAULTS ) @@ -697,16 +698,32 @@ def color_generator(color_idx): } -if __name__ == '__main__': - name = 'vsize_scatter' +def main(): + p = argparse.ArgumentParser() + p.add_argument("-n", "--name", help=( + "Name of the plot snapshot to create. If this is omitted, all " + "snapshots are created.") + ) + name = p.parse_args().name + if name is None: + names = all_examples.keys() + else: + names = [name] + + create_plot_types(names) - factory_func = all_examples[name] - plot = factory_func() - window = PlotWindow(plot=plot) - ui = window.edit_traits() +def create_plot_types(names): + for name in names: + fname = '{}_plot.png'.format(name) + print("*** creating figure {!r} for type {!r} ***".format(fname, name)) + plot = all_examples[name]() - filename = '{}_plot.png'.format(name) - save_plot(window.container, filename) + window = PlotWindow(plot=plot) + window.edit_traits() + save_plot(window.container, fname) + +if __name__ == '__main__': + main()