Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Release notes
* Update Cython to 0.29.6.
By :user:`John Kirkham <jakirkham>`, :issue:`168`, :issue:`177`.

* Add a ``subok`` option to the ``ensure_*`` functions to handle their behavior
with subclasses of the expected type.
By :user:`John Kirkham <jakirkham>`, :issue:`173`.


.. _release_0.6.2:

Expand Down
20 changes: 14 additions & 6 deletions numcodecs/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,16 @@ def ensure_text(l, encoding='utf-8'):
return text_type(l, encoding=encoding)


def ensure_ndarray(buf):
def ensure_ndarray(buf, subok=True):
"""Convenience function to coerce `buf` to a numpy array, if it is not already a
numpy array.

Parameters
----------
buf : array-like or bytes-like
A numpy array or any object exporting a buffer interface.
subok : bool
Whether to allow `ndarray` subclasses or not

Returns
-------
Expand All @@ -54,7 +56,11 @@ def ensure_ndarray(buf):

"""

if isinstance(buf, np.ndarray):
if type(buf) is np.ndarray:
# already a numpy array
arr = buf

elif subok and isinstance(buf, np.ndarray):
# already a numpy array
arr = buf

Expand Down Expand Up @@ -90,7 +96,7 @@ def ensure_ndarray(buf):
return arr


def ensure_contiguous_ndarray(buf, max_buffer_size=None):
def ensure_contiguous_ndarray(buf, max_buffer_size=None, subok=True):
"""Convenience function to coerce `buf` to a numpy array, if it is not already a
numpy array. Also ensures that the returned value exports fully contiguous memory,
and supports the new-style buffer interface. If the optional max_buffer_size is
Expand All @@ -104,6 +110,8 @@ def ensure_contiguous_ndarray(buf, max_buffer_size=None):
max_buffer_size : int
If specified, the largest allowable value of arr.nbytes, where arr
is the retured array.
subok : bool
Whether to allow `ndarray` subclasses or not

Returns
-------
Expand All @@ -118,7 +126,7 @@ def ensure_contiguous_ndarray(buf, max_buffer_size=None):
"""

# ensure input is a numpy array
arr = ensure_ndarray(buf)
arr = ensure_ndarray(buf, subok=subok)

# check for object arrays, these are just memory pointers, actual memory holding
# item data is scattered elsewhere
Expand All @@ -144,10 +152,10 @@ def ensure_contiguous_ndarray(buf, max_buffer_size=None):
return arr


def ensure_bytes(buf):
def ensure_bytes(buf, subok=True):
"""Obtain a bytes object from memory exposed by `buf`."""

if not isinstance(buf, binary_type):
if not (type(buf) is binary_type or (subok and isinstance(buf, binary_type))):

# go via numpy, for convenience
arr = ensure_ndarray(buf)
Expand Down
34 changes: 34 additions & 0 deletions numcodecs/tests/test_compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function, division
import array
import itertools
import mmap


Expand All @@ -23,6 +24,24 @@ def test_ensure_bytes():
assert isinstance(b, bytes)


def test_ensure_bytes_subok():
class MyBytes(bytes):
pass

bufs = [
b'adsdasdas',
MyBytes(b'adsdasdas'),
]
suboks = [True, False]
for buf, subok in itertools.product(bufs, suboks):
b = ensure_bytes(buf, subok=subok)
assert isinstance(b, bytes)
if subok:
assert type(b) is type(buf)
else:
assert type(b) is bytes


def test_ensure_contiguous_ndarray_shares_memory():
typed_bufs = [
('u', 1, b'adsdasdas'),
Expand Down Expand Up @@ -51,6 +70,21 @@ def test_ensure_contiguous_ndarray_shares_memory():
assert np.shares_memory(a, memoryview(buf))


def test_ensure_contiguous_ndarray_subok():
bufs = [
np.arange(100, dtype=np.int64),
np.ma.arange(100, dtype=np.int64),
]
suboks = [True, False]
for buf, subok in itertools.product(bufs, suboks):
a = ensure_contiguous_ndarray(buf, subok=subok)
assert isinstance(a, np.ndarray)
if subok:
assert type(a) is type(buf)
else:
assert type(a) is np.ndarray


def test_ensure_bytes_invalid_inputs():

# object array not allowed
Expand Down