diff --git a/apptools/persistence/state_pickler.py b/apptools/persistence/state_pickler.py index aed0bdcb2..e4f327dc1 100644 --- a/apptools/persistence/state_pickler.py +++ b/apptools/persistence/state_pickler.py @@ -100,10 +100,10 @@ # Standard library imports. import base64 import sys -import types import pickle import gzip from io import BytesIO, StringIO +import warnings import numpy @@ -139,6 +139,23 @@ def gunzip_string(data): writer.close() return data + +def base64_encode(data): + if PY_VER > 2: + return base64.encodebytes(data) + else: + return base64.encodestring(data) + + +def base64_decode(data): + if PY_VER > 2: + if isinstance(data, str): + data = data.encode('utf-8') + return base64.decodebytes(data) + else: + return data.decode('base64') + + class StatePicklerError(Exception): pass @@ -330,6 +347,8 @@ def _do(self, obj): return self._do_reference(obj) elif obj_type in self.type_map: return self.type_map[obj_type](obj) + elif isinstance(obj, numpy.generic): + return self._do_numpy_generic_type(obj) elif isinstance(obj, tuple): # Takes care of StateTuples. return self._do_tuple(obj) @@ -341,8 +360,19 @@ def _do(self, obj): return self._do_dict(obj) elif hasattr(obj, '__dict__'): return self._do_instance(obj) + else: + warnings.warn("Cannot pickle unrecognized type {}. Returning None" + " for backward compatibility.".format(obj_type)) + return None def _get_id(self, value): + # We consider a special case for numpy scalar values, because + # they hash as native types, but they are special because we + # want to recover their true type, and the only way of doing so + # is to consider them as objects. + if isinstance(value, numpy.generic): + return id(value) + try: key = hash(value) except TypeError: @@ -445,12 +475,14 @@ def _do_dict(self, value): def _do_numeric(self, value): idx = self._register(value) - if PY_VER > 2: - data = base64.encodebytes(gzip_string(numpy.ndarray.dumps(value))) - else: - data = base64.encodestring(gzip_string(numpy.ndarray.dumps(value))) + data = base64_encode(gzip_string(numpy.ndarray.dumps(value))) return dict(type='numeric', id=idx, data=data) + def _do_numpy_generic_type(self, value): + idx = self._register(value) + self._misc_cache.append(value) + data = base64_encode(value.dumps()) + return dict(type='numpy', id=idx, data=data) ###################################################################### @@ -502,6 +534,7 @@ def __init__(self): 'list': self._do_list, 'dict': self._do_dict, 'numeric': self._do_numeric, + 'numpy': self._do_numpy_generic_type, } def load_state(self, file): @@ -666,6 +699,10 @@ def _do_numeric(self, value, path): self._obj_cache[value['id']] = result return result + def _do_numpy_generic_type(self, value, path): + result = numpy.loads(base64_decode(value["data"])) + self._obj_cache[value['id']] = result + return result ###################################################################### # `StateSetter` class diff --git a/apptools/persistence/tests/test_state_pickler.py b/apptools/persistence/tests/test_state_pickler.py index 70bf42573..c5393149f 100644 --- a/apptools/persistence/tests/test_state_pickler.py +++ b/apptools/persistence/tests/test_state_pickler.py @@ -44,6 +44,7 @@ def __init__(self): self.i = 7 self.l = 1234567890123456789 self.f = math.pi + self.fnumpy = numpy.float64(3.0) self.c = complex(1.01234, 2.3) self.n = None self.s = 'String' @@ -65,6 +66,7 @@ class TestTraits(HasTraits): i = Int(7) l = Long(12345678901234567890) f = Float(math.pi) + fnumpy = Float(numpy.float64(3.0)) c = Complex(complex(1.01234, 2.3)) n = Any s = Str('String') @@ -168,6 +170,9 @@ def verify_unpickled(self, obj, state): self.assertEqual(state.i, obj.i) self.assertEqual(state.l, obj.l) self.assertEqual(state.f, obj.f) + self.assertEqual(state.fnumpy, obj.fnumpy) + self.assertIsInstance(obj.fnumpy, numpy.generic) + self.assertEqual(type(state.fnumpy), type(obj.fnumpy)) self.assertEqual(state.c, obj.c) self.assertEqual(state.n, obj.n) self.assertEqual(state.s, obj.s) @@ -464,5 +469,33 @@ def test_dump_to_file_str(self): os.remove(filepath) +class TestStatePickler(unittest.TestCase): + + def setUp(self): + self.pickler = state_pickler.StatePickler() + + def tearDown(self): + self.pickler = None + + def test_on_base_types(self): + state = self.pickler.dump_state(1) + self.assertEqual(state, 1) + + def test_on_lists(self): + l = [1,2.0, None, [1,2,3]] + state = self.pickler.dump_state(l) + self.assertEqual( + state, + {'data': [1, 2.0, None, {'data': [1, 2, 3], 'type': 'list', 'id': 1}], + 'id': 0, + 'type': 'list'}) + + def test_on_numpy_scalars(self): + state = self.pickler.dumps(numpy.int32(78)) + loaded_state = state_pickler.StateUnpickler().loads_state(state) + self.assertEqual(loaded_state, 78) + self.assertEqual(loaded_state.dtype, numpy.int32) + + if __name__ == "__main__": unittest.main() diff --git a/apptools/preferences/tests/example.ini b/apptools/preferences/tests/example.ini index 100dc6e16..9b2f1311d 100644 --- a/apptools/preferences/tests/example.ini +++ b/apptools/preferences/tests/example.ini @@ -1,10 +1,10 @@ [acme.ui] -bgcolor = blue -width = 50 ratio = 1.0 -visible = True -description = 'acme ui' +description = acme ui +width = 50 offsets = "[1, 2, 3, 4]" +bgcolor = blue +visible = True names = "['joe', 'fred', 'jane']" [acme.ui.splash_screen]