Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions chaco/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@

from abstract_plot_data import AbstractPlotData
from array_plot_data import ArrayPlotData
from data_frame_plot_data import DataFramePlotData
from plot import Plot
from toolbar_plot import ToolbarPlot

Expand Down
181 changes: 181 additions & 0 deletions chaco/data_frame_plot_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
""" Defines DataFramePlotData.
"""

# Enthought library imports
from traits.api import Bool, Instance, Property

# Local, relative imports
from .abstract_plot_data import AbstractPlotData


class DataFramePlotData(AbstractPlotData):
""" A PlotData implementation class that handles a DataFrame.

By default, it doesn't allow its input data to be modified by downstream
Chaco components or interactors. The index is available as data unless
there is a column named 'index', in which case that column masks the
DataFrame index. (Rename that column if the DataFrame index must be
accessible.)

"""

#-------------------------------------------------------------------------
# Public traits
#-------------------------------------------------------------------------

# The DataFrame backing this object.
data_frame = Instance('pandas.core.frame.DataFrame')

# Consumers can write data to this object (overrides AbstractPlotData).
writable = True

#-------------------------------------------------------------------------
# Private traits
#-------------------------------------------------------------------------

_has_index_column = Property(Bool)

def _get__has_index_column(self):
return 'index' in self.data_frame.columns

#------------------------------------------------------------------------
# AbstractPlotData Interface
#------------------------------------------------------------------------

def list_data(self):
""" Returns a list of the names of the columns of the DataFrame. The
name 'index' is added to this unless there is a column named 'index'.
"""
names = self.data_frame.columns.tolist()
if not self._has_index_column:
names = ['index'] + names
return names

def get_data(self, name):
""" Returns the array associated with *name*.

Implements AbstractDataSource.
"""
if name == 'index' and not self._has_index_column:
return self.data_frame.index.values
series = self.data_frame.get(name)
return series if series is None else series.values

def del_data(self, name):
""" Deletes the column specified by *name*, or raises a KeyError if
the named column does not exist.
"""
if not self.writable:
return None

if name == 'index' and not self._has_index_column:
raise KeyError("Cannot delete the index.")

if name in self.data_frame.columns:
del self.data_frame[name]
if name == 'index':
# It is impossible to remove the 'index' in the PlotData.
# Removing a column named 'index' in the DataFrame means that
# the DataFrame index is now the 'index' in the PlotData. Thus,
# this results in a 'changed' event instead of a 'removed'
# event.
self.data_changed = {'changed': [name]}
else:
self.data_changed = {'removed': [name]}
else:
raise KeyError("Column '{}' does not exist.".format(name))

def set_data(self, name, new_data, generate_name=False):
""" Sets the specified index or column as the value for either the
specified
name or a generated name.

If the instance's `writable` attribute is True, then this method sets
the data associated with the given name to the new value, otherwise it
does nothing.

Parameters
----------
name : string
The name of the array whose value is to be set.
new_data : array
The array to set as the value of *name*.
generate_name : Boolean
If True, a unique name of the form 'seriesN' is created for the
array, and is used in place of *name*. The 'N' in 'seriesN' is
one greater the largest N already used.

Returns
-------
The name under which the array was set.

"""
if not self.writable:
return None

if generate_name:
names = self._generate_names(1)
name = names[0]

self.update_data({name: new_data})
return name

def update_data(self, *args, **kwargs):
""" Sets the specified column or index as the value for either the
specified name or a generated name.

Implements AbstractPlotData's update_data() method. This method has
the same signature as the dictionary update() method.

"""
if not self.writable:
return None

data = dict(*args, **kwargs)
event = {}
for name in data:
if name == 'index' or name in self.data_frame.columns:
event.setdefault('changed', []).append(name)
else:
event.setdefault('added', []).append(name)

self._update_data(data)
self.data_changed = event

def set_selection(self, name, selection):
""" Overrides AbstractPlotData to do nothing and not raise an error.
"""
pass

#------------------------------------------------------------------------
# Private methods
#------------------------------------------------------------------------

def _generate_names(self, n):
""" Generate n new names
"""
max_index = max(self._generate_indices())
names = [
"series{0:d}".format(i)
for i in range(max_index + 1, max_index + n + 1)
]
return names

def _generate_indices(self):
""" Generator that yields all integers that match "series%d" in keys
"""
yield 0 # default minimum
for name in self.list_data():
if name.startswith('series'):
try:
v = int(name[6:])
except ValueError:
continue
yield v

def _update_data(self, data):
for name, value in data.items():
if name == 'index' and not self._has_index_column:
self.data_frame.index = value
else:
self.data_frame[name] = value
112 changes: 112 additions & 0 deletions chaco/tests/data_frame_plot_data_test_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import contextlib
from traits.testing.unittest_tools import unittest

import numpy as np
from numpy.testing import assert_array_equal
from pandas import DataFrame

from chaco.api import DataFramePlotData
from traits.api import HasTraits, Instance, List, on_trait_change


class DataFramePlotDataEventsCollector(HasTraits):
plot_data = Instance(DataFramePlotData)

data_changed_events = List

@on_trait_change('plot_data:data_changed')
def _got_data_changed_event(self, event):
self.data_changed_events.append(event)


@contextlib.contextmanager
def monitor_events(plot_data):
"""
Context manager to collect data_changed events.

"""
collector = DataFramePlotDataEventsCollector(plot_data=plot_data)
yield collector.data_changed_events


class DataFramePlotDataTestCase(unittest.TestCase):

def test_data_changed_events(self):
# Test data.
arr = np.zeros(16)
arr2 = np.ones(16)

df = DataFrame(index=np.arange(16))
plot_data = DataFramePlotData(data_frame=df)

assert_array_equal(plot_data.get_data('index'), df.index.values)

with monitor_events(plot_data) as events:
plot_data.set_data('arr', arr)
self.assertEqual(events, [{'added': ['arr']}])

assert_array_equal(df['arr'].values, arr)

# While we're here, check that get_data works as advertised.
out = plot_data.get_data('arr')
assert_array_equal(arr, out)

with monitor_events(plot_data) as events:
plot_data.set_data('arr', arr2)
self.assertEqual(events, [{'changed': ['arr']}])
assert_array_equal(df['arr'].values, arr2)

with monitor_events(plot_data) as events:
plot_data.del_data('arr')
self.assertEqual(events, [{'removed': ['arr']}])

def test_no_index_column(self):
# Test data.
idx = np.arange(16)
arr = np.zeros(16)
df = DataFrame(index=idx)
plot_data = DataFramePlotData(data_frame=df)

assert_array_equal(plot_data.get_data('index'), df.index.values)

# Can set 'index'
with monitor_events(plot_data) as events:
plot_data.set_data('index', arr)
self.assertEqual(events, [{'changed': ['index']}])
self.assertNotIn('index', df.columns)
assert_array_equal(df.index.values, arr)

# Cannot remove 'index' column
with self.assertRaises(KeyError):
plot_data.del_data('index')

def test_index_column(self):
# Test data.
idx = np.arange(16)
arr = np.zeros(16)
arr2 = np.ones(16)
data = {'index': arr}
df = DataFrame(data, index=idx)
plot_data = DataFramePlotData(data_frame=df)

assert_array_equal(plot_data.get_data('index'), df['index'].values)

# Can set 'index' column
with monitor_events(plot_data) as events:
plot_data.set_data('index', arr2)
self.assertEqual(events, [{'changed': ['index']}])
assert_array_equal(df['index'].values, arr2)

# Can remove 'index' column
with monitor_events(plot_data) as events:
plot_data.del_data('index')
self.assertNotIn('index', df.columns)
# Since there is always an index, this will register a 'changed'
# event instead of a 'removed' event.
self.assertEqual(events, [{'changed': ['index']}])
assert_array_equal(plot_data.get_data('index'), df.index.values)


if __name__ == '__main__':
import nose
nose.run()
1 change: 1 addition & 0 deletions ci/edmtool.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
"nose",
"mock",
"numpy",
"pandas",
"pygments",
"pyparsing",
"cython"
Expand Down
1 change: 0 additions & 1 deletion ci/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
nose-exclude
coverage

66 changes: 66 additions & 0 deletions examples/demo/basic/pandas_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#!/usr/bin/env python
""" Plot backed by a pandas DataFrame. """

# Major library imports
from numpy import linspace
from pandas import DataFrame
from scipy.special import jn

# Enthought library imports
from enable.api import Component, ComponentEditor
from traits.api import HasTraits, Instance
from traitsui.api import Item, Group, View

# Chaco imports
from chaco.api import DataFramePlotData, Plot


#==============================================================================
# # Demo class that is used by the demo.py application.
#==============================================================================

class Demo(HasTraits):

plot_data = Instance(DataFramePlotData)

plot = Instance(Component)

traits_view = View(
Group(
Item(
'plot',
editor=ComponentEditor(size=(900, 500)),
show_label=False
),
orientation="vertical",
),
resizable=True,
title="pandas data example"
)

def _plot_data_default(self):
# Create a DataFrame with plottable data
index = linspace(-2.0, 10.0, 100)
df = DataFrame(index=index)
for i in range(5):
name = "y" + str(i)
df[name] = jn(i, index)

plot_data = DataFramePlotData(data_frame=df)
return plot_data

def _plot_default(self):
plot = Plot(self.plot_data, padding=50)
plot.plot(("index", "y0", "y1", "y2"), name="j_n, n<3", color="red")
plot.plot(("index", "y3"), name="j_3", color="blue")
plot.x_axis.title = "index"
plot.y_axis.title = "j_n"
return plot


demo = Demo()

if __name__ == "__main__":
demo.configure_traits()

#--EOF---
6 changes: 6 additions & 0 deletions examples/demo/demo.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ sourcedir = basic
[[[Editing plots]]]
files = ../edit_line.py

[[Data Sources]]
sourcedir = ''

[[[Pandas DataFrame]]]
files = pandas_data.py

[Image and Contour Plots]
sourcedir = basic

Expand Down