-
Notifications
You must be signed in to change notification settings - Fork 21
Handle numpy pure scalar types #62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
158cac1
c12f57d
70639d8
755f6cd
b299f9f
44fee2f
daf1b79
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -100,10 +100,10 @@ | |
| # Standard library imports. | ||
| import base64 | ||
| import sys | ||
| import types | ||
| import pickle | ||
| import gzip | ||
| from io import BytesIO, StringIO | ||
| import warnings | ||
|
|
||
| import numpy | ||
|
|
||
|
|
@@ -139,6 +139,23 @@ def gunzip_string(data): | |
| writer.close() | ||
| return data | ||
|
|
||
|
|
||
| def base64_encode(data): | ||
| if PY_VER > 2: | ||
| return base64.encodebytes(data) | ||
| else: | ||
| return base64.encodestring(data) | ||
|
|
||
|
|
||
| def base64_decode(data): | ||
| if PY_VER > 2: | ||
| if isinstance(data, str): | ||
| data = data.encode('utf-8') | ||
| return base64.decodebytes(data) | ||
| else: | ||
| return data.decode('base64') | ||
|
|
||
|
|
||
| class StatePicklerError(Exception): | ||
| pass | ||
|
|
||
|
|
@@ -330,6 +347,8 @@ def _do(self, obj): | |
| return self._do_reference(obj) | ||
| elif obj_type in self.type_map: | ||
| return self.type_map[obj_type](obj) | ||
| elif isinstance(obj, numpy.generic): | ||
| return self._do_numpy_generic_type(obj) | ||
| elif isinstance(obj, tuple): | ||
| # Takes care of StateTuples. | ||
| return self._do_tuple(obj) | ||
|
|
@@ -341,8 +360,19 @@ def _do(self, obj): | |
| return self._do_dict(obj) | ||
| elif hasattr(obj, '__dict__'): | ||
| return self._do_instance(obj) | ||
| else: | ||
| warnings.warn("Cannot pickle unrecognized type {}. Returning None" | ||
| " for backward compatibility.".format(obj_type)) | ||
| return None | ||
|
|
||
| def _get_id(self, value): | ||
| # We consider a special case for numpy scalar values, because | ||
| # they hash as native types, but they are special because we | ||
| # want to recover their true type, and the only way of doing so | ||
| # is to consider them as objects. | ||
| if isinstance(value, numpy.generic): | ||
| return id(value) | ||
|
|
||
| try: | ||
| key = hash(value) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The use of
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| except TypeError: | ||
|
|
@@ -445,12 +475,14 @@ def _do_dict(self, value): | |
|
|
||
| def _do_numeric(self, value): | ||
| idx = self._register(value) | ||
| if PY_VER > 2: | ||
| data = base64.encodebytes(gzip_string(numpy.ndarray.dumps(value))) | ||
| else: | ||
| data = base64.encodestring(gzip_string(numpy.ndarray.dumps(value))) | ||
| data = base64_encode(gzip_string(numpy.ndarray.dumps(value))) | ||
| return dict(type='numeric', id=idx, data=data) | ||
|
|
||
| def _do_numpy_generic_type(self, value): | ||
| idx = self._register(value) | ||
| self._misc_cache.append(value) | ||
| data = base64_encode(value.dumps()) | ||
| return dict(type='numpy', id=idx, data=data) | ||
|
|
||
|
|
||
| ###################################################################### | ||
|
|
@@ -502,6 +534,7 @@ def __init__(self): | |
| 'list': self._do_list, | ||
| 'dict': self._do_dict, | ||
| 'numeric': self._do_numeric, | ||
| 'numpy': self._do_numpy_generic_type, | ||
| } | ||
|
|
||
| def load_state(self, file): | ||
|
|
@@ -666,6 +699,10 @@ def _do_numeric(self, value, path): | |
| self._obj_cache[value['id']] = result | ||
| return result | ||
|
|
||
| def _do_numpy_generic_type(self, value, path): | ||
| result = numpy.loads(base64_decode(value["data"])) | ||
| self._obj_cache[value['id']] = result | ||
| return result | ||
|
|
||
| ###################################################################### | ||
| # `StateSetter` class | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a small problem with using
idhere. The sameidcan be re-used when the original object is garbage collected. So if during the pickling process the numpy arrays are garbage collected then the registry might point to the wrong object.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have updated the pickling
_do_...method to keep the object alive in the chase (for the duration of the pickling process). This should be enough to avoid theidreuse case.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a simple way to avoid the caching altogether? The current caching code is broken (#66), and I don't think it makes a lot of sense to be caching simple immutable scalars in the first place: the size of the "reference" dictionary generated is going to be comparable to the size of the object dictionary.