Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 7 additions & 15 deletions apptools/persistence/state_pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@
from . import version_registry
from .file_path import FilePath

PY_VER = sys.version_info[0]
NumpyArrayType = type(numpy.array([]))


Expand All @@ -130,7 +129,7 @@ def gunzip_string(data):
"""Given a gzipped string (`data`) this unzips the string and
returns it.
"""
if PY_VER== 2 or (bytes is not str and type(data) is bytes):
if type(data) is bytes:
s = BytesIO(data)
else:
s = StringIO(data)
Expand Down Expand Up @@ -442,10 +441,7 @@ def _do_dict(self, value):

def _do_numeric(self, value):
idx = self._register(value)
if PY_VER > 2:
data = base64.encodebytes(gzip_string(numpy.ndarray.dumps(value)))
else:
data = base64.encodestring(gzip_string(numpy.ndarray.dumps(value)))
data = base64.encodebytes(gzip_string(numpy.ndarray.dumps(value)))
return dict(type='numeric', id=idx, data=data)


Expand Down Expand Up @@ -650,15 +646,11 @@ def _do_dict(self, value, path):
return result

def _do_numeric(self, value, path):
if PY_VER > 2:
data = value['data']
if isinstance(data, str):
data = value['data'].encode('utf-8')
junk = gunzip_string(base64.decodebytes(data))
result = pickle.loads(junk, encoding='bytes')
else:
junk = gunzip_string(value['data'].decode('base64'))
result = pickle.loads(junk)
data = value['data']
if isinstance(data, str):
data = value['data'].encode('utf-8')
junk = gunzip_string(base64.decodebytes(data))
result = pickle.loads(junk, encoding='bytes')
self._numeric[value['id']] = (path, result)
self._obj_cache[value['id']] = result
return result
Expand Down
6 changes: 1 addition & 5 deletions apptools/persistence/tests/test_version_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# License: BSD Style.

# Standard library imports.
import sys
from imp import reload
import unittest

Expand Down Expand Up @@ -43,10 +42,7 @@ def upgrade1(self, state, version):
class TestVersionRegistry(unittest.TestCase):
def test_get_version(self):
"""Test the get_version function."""
if sys.version_info[0] > 2:
extra = [(('object', 'builtins'), -1)]
else:
extra = []
extra = [(('object', 'builtins'), -1)]
c = Classic()
v = version_registry.get_version(c)
res = extra + [(('Classic', __name__), 0)]
Expand Down
26 changes: 4 additions & 22 deletions apptools/persistence/versioned_unpickler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Standard library imports
from pickle import *
import sys
import logging
from types import GeneratorType, MethodType

Expand All @@ -11,19 +10,6 @@
logger = logging.getLogger(__name__)


def _unbound_method(method, klass):
"""
Python-version-agnostic unbound_method generator.

For Python 2, use MethodType. Python 3 doesn't have a separate
type for unbound methods; just return the method itself.
"""
if sys.version_info < (3,):
return MethodType(method, None, klass)
else:
return method


##############################################################################
# class 'NewUnpickler'
##############################################################################
Expand Down Expand Up @@ -159,8 +145,7 @@ def find_class(self, module, name):
# restore the original __setstate__ if necessary
fn = getattr(klass, '__setstate_original__', False)
if fn:
m = _unbound_method(fn, klass)
setattr(klass, '__setstate__', m)
setattr(klass, '__setstate__', fn)

return klass

Expand All @@ -177,12 +162,10 @@ class as the __setstate__ method.
self.backup_setstate(module, klass)

# add the updater into the class
m = _unbound_method(fn, klass)
setattr(klass, '__updater__', m)
setattr(klass, '__updater__', fn)

# hook up our __setstate__ which updates self.__dict__
m = _unbound_method(__replacement_setstate__, klass)
setattr(klass, '__setstate__', m)
setattr(klass, '__setstate__', __replacement_setstate__)

else:
pass
Expand All @@ -206,8 +189,7 @@ def backup_setstate(self, module, klass):

#logger.debug('renaming __setstate__ to %s' % name)
method = getattr(klass, '__setstate__')
m = _unbound_method(method, klass)
setattr(klass, name, m)
setattr(klass, name, method)

else:
# the class has no __setstate__ method so do nothing
Expand Down