Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 24 additions & 38 deletions zarr/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
import uuid
import time

from numcodecs.compat import ensure_bytes, ensure_contiguous_ndarray
from numcodecs.compat import (
ensure_bytes,
ensure_text,
ensure_contiguous_ndarray
)
from numcodecs.registry import codec_registry

from zarr.errors import (
Expand Down Expand Up @@ -1573,18 +1577,6 @@ def migrate_1to2(store):
del store['attrs']


def _dbm_encode_key(key):
if hasattr(key, 'encode'):
key = key.encode('ascii')
return key


def _dbm_decode_key(key):
if hasattr(key, 'decode'):
key = key.decode('ascii')
return key


# noinspection PyShadowingBuiltins
class DBMStore(MutableMapping):
"""Storage class using a DBM-style database.
Expand Down Expand Up @@ -1730,17 +1722,20 @@ def __exit__(self, *args):
self.close()

def __getitem__(self, key):
key = _dbm_encode_key(key)
if isinstance(key, str):
key = key.encode("ascii")
return self.db[key]

def __setitem__(self, key, value):
key = _dbm_encode_key(key)
if isinstance(key, str):
key = key.encode("ascii")
value = ensure_bytes(value)
with self.write_mutex:
self.db[key] = value

def __delitem__(self, key):
key = _dbm_encode_key(key)
if isinstance(key, str):
key = key.encode("ascii")
with self.write_mutex:
del self.db[key]

Expand All @@ -1754,7 +1749,7 @@ def __eq__(self, other):
)

def keys(self):
return (_dbm_decode_key(k) for k in iter(self.db.keys()))
return (ensure_text(k, "ascii") for k in iter(self.db.keys()))

def __iter__(self):
return self.keys()
Expand All @@ -1763,20 +1758,11 @@ def __len__(self):
return sum(1 for _ in self.keys())

def __contains__(self, key):
key = _dbm_encode_key(key)
if isinstance(key, str):
key = key.encode("ascii")
return key in self.db


def _lmdb_decode_key_buffer(key):
# assume buffers=True
return key.tobytes().decode('ascii')


def _lmdb_decode_key_bytes(key):
# assume buffers=False
return key.decode('ascii')


class LMDBStore(MutableMapping):
"""Storage class using LMDB. Requires the `lmdb <http://lmdb.readthedocs.io/>`_
package to be installed.
Expand Down Expand Up @@ -1866,10 +1852,6 @@ def __init__(self, path, buffers=True, **kwargs):
self.db = lmdb.open(path, **kwargs)

# store properties
if buffers:
self.decode_key = _lmdb_decode_key_buffer
else:
self.decode_key = _lmdb_decode_key_bytes
self.buffers = buffers
self.path = path
self.kwargs = kwargs
Expand Down Expand Up @@ -1901,7 +1883,8 @@ def __exit__(self, *args):
self.close()

def __getitem__(self, key):
key = _dbm_encode_key(key)
if isinstance(key, str):
key = key.encode("ascii")
# use the buffers option, should avoid a memory copy
with self.db.begin(buffers=self.buffers) as txn:
value = txn.get(key)
Expand All @@ -1910,18 +1893,21 @@ def __getitem__(self, key):
return value

def __setitem__(self, key, value):
key = _dbm_encode_key(key)
if isinstance(key, str):
key = key.encode("ascii")
with self.db.begin(write=True, buffers=self.buffers) as txn:
txn.put(key, value)

def __delitem__(self, key):
key = _dbm_encode_key(key)
if isinstance(key, str):
key = key.encode("ascii")
with self.db.begin(write=True) as txn:
if not txn.delete(key):
raise KeyError(key)

def __contains__(self, key):
key = _dbm_encode_key(key)
if isinstance(key, str):
key = key.encode("ascii")
with self.db.begin(buffers=self.buffers) as txn:
with txn.cursor() as cursor:
return cursor.set_key(key)
Expand All @@ -1930,13 +1916,13 @@ def items(self):
with self.db.begin(buffers=self.buffers) as txn:
with txn.cursor() as cursor:
for k, v in cursor.iternext(keys=True, values=True):
yield self.decode_key(k), v
yield ensure_text(k, "ascii"), v

def keys(self):
with self.db.begin(buffers=self.buffers) as txn:
with txn.cursor() as cursor:
for k in cursor.iternext(keys=True, values=False):
yield self.decode_key(k)
yield ensure_text(k, "ascii")

def values(self):
with self.db.begin(buffers=self.buffers) as txn:
Expand Down