Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d69685a
Add bindings and tests for FixedShapeTensorType and Array
AlenkaF Apr 4, 2023
a6292f8
Fix linter error
AlenkaF Apr 5, 2023
1bdba1d
Add pa.fixedshapetensor factory function and update docstring examples
AlenkaF Apr 5, 2023
7c395b0
Apply suggestions from code review - Joris
AlenkaF Apr 5, 2023
d27d48f
Use pa.FixedSizeListArray.from_arrays(..) in from_numpy_ndarray()
AlenkaF Apr 5, 2023
8e790b4
Change fixedshapetensor to fixed_shape_tensor
AlenkaF Apr 5, 2023
64e0cd0
Add tests for all the custom attributes
AlenkaF Apr 5, 2023
48cbeb3
Add test for numpy F-contiguous
AlenkaF Apr 5, 2023
d9ca165
Correct dim_names() to return list of strings, not bytes
AlenkaF Apr 5, 2023
d3530af
Correct dim_names and permutation methods to return None and not empt…
AlenkaF Apr 5, 2023
e2ce8ba
Replace FixedShapeTensorType with fixed_shape_tensor in FixedShapeTen…
AlenkaF Apr 5, 2023
ee5d25c
Update from_numpy_ndarray docstrings
AlenkaF Apr 5, 2023
f5a5c0c
Update public-api.pxi
AlenkaF Apr 5, 2023
52f9e7e
Update python/pyarrow/types.pxi
AlenkaF Apr 5, 2023
f9dee9e
Merge branch 'main' into python-binding-tensor-extension-type
AlenkaF Apr 5, 2023
b171d00
Use ravel insted of flatten and raise ValueError if numpy array is no…
AlenkaF Apr 5, 2023
c0ec94c
Remove CFixedShapeTensorType binding in libarrow
AlenkaF Apr 5, 2023
f2d9fe7
Fix doctest failure
AlenkaF Apr 6, 2023
8b5dc93
Add explanation of permutation from the spec to the docstring of fixe…
AlenkaF Apr 6, 2023
570f086
from_numpy_ndarray should be a static method
AlenkaF Apr 6, 2023
3dbbe20
Apply suggestions from code review
AlenkaF Apr 6, 2023
223968a
Apply suggestions from code review
AlenkaF Apr 6, 2023
dd8fd31
Update to_numpy_ndarraydocstring
AlenkaF Apr 7, 2023
1ebb829
Add a check for non-trivial permutation in to_numpy_ndarray
AlenkaF Apr 11, 2023
b2d0453
Update python/pyarrow/array.pxi
AlenkaF Apr 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/format/CanonicalExtensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ same rules as laid out above, and provide backwards compatibility guarantees.
Official List
=============

.. _fixed_shape_tensor_extension:

Fixed shape tensor
==================

Expand Down
5 changes: 3 additions & 2 deletions python/pyarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def print_entry(label, value):
union, sparse_union, dense_union,
dictionary,
run_end_encoded,
fixed_shape_tensor,
field,
type_for_alias,
DataType, DictionaryType, StructType,
Expand All @@ -178,7 +179,7 @@ def print_entry(label, value):
TimestampType, Time32Type, Time64Type, DurationType,
FixedSizeBinaryType, Decimal128Type, Decimal256Type,
BaseExtensionType, ExtensionType,
RunEndEncodedType,
RunEndEncodedType, FixedShapeTensorType,
PyExtensionType, UnknownExtensionType,
register_extension_type, unregister_extension_type,
DictionaryMemo,
Expand Down Expand Up @@ -209,7 +210,7 @@ def print_entry(label, value):
Time32Array, Time64Array, DurationArray,
MonthDayNanoIntervalArray,
Decimal128Array, Decimal256Array, StructArray, ExtensionArray,
RunEndEncodedArray,
RunEndEncodedArray, FixedShapeTensorArray,
scalar, NA, _NULL as NULL, Scalar,
NullScalar, BooleanScalar,
Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar,
Expand Down
109 changes: 109 additions & 0 deletions python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -3076,6 +3076,115 @@ cdef class ExtensionArray(Array):
return Array._to_pandas(self.storage, options, **kwargs)


class FixedShapeTensorArray(ExtensionArray):
    """
    Concrete class for fixed shape tensor extension arrays.

    Examples
    --------
    Define the extension type for tensor array

    >>> import pyarrow as pa
    >>> tensor_type = pa.fixed_shape_tensor(pa.int32(), [2, 2])

    Create an extension array

    >>> arr = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]]
    >>> storage = pa.array(arr, pa.list_(pa.int32(), 4))
    >>> pa.ExtensionArray.from_storage(tensor_type, storage)
    <pyarrow.lib.FixedShapeTensorArray object at ...>
    [
      [
        1,
        2,
        3,
        4
      ],
      [
        10,
        20,
        30,
        40
      ],
      [
        100,
        200,
        300,
        400
      ]
    ]
    """

    def to_numpy_ndarray(self):
        """
        Convert fixed shape tensor extension array to a numpy array (with dim+1).

        Note: ``permutation`` should be trivial (``None`` or ``[0, 1, ..., len(shape)-1]``).

        Returns
        -------
        numpy.ndarray
            Array of shape ``(len(self),) + tuple(self.type.shape)``.

        Raises
        ------
        ValueError
            If the type has a non-trivial permutation.
        """
        tensor_type = self.type
        permutation = tensor_type.permutation
        shape = tensor_type.shape
        if permutation is not None and permutation != list(range(len(shape))):
            raise ValueError(
                'Only non-permuted tensors can be converted to numpy tensors.')

        # Use flatten() rather than ``.values``: ``.values`` exposes the whole
        # child array and ignores this array's offset/length, so a sliced
        # tensor array would be reshaped from the wrong elements (or fail to
        # reshape at all).
        # NOTE(review): flatten() is documented to drop null list slots, so an
        # array containing null tensors is expected to fail in reshape below
        # instead of exposing undefined values -- confirm this is acceptable.
        np_flat = np.asarray(self.storage.flatten())
        return np_flat.reshape((len(self),) + tuple(shape))

    @staticmethod
    def from_numpy_ndarray(obj):
        """
        Convert numpy tensors (ndarrays) to a fixed shape tensor extension array.
        The first dimension of ndarray will become the length of the fixed
        shape tensor array.

        Numpy array needs to be C-contiguous in memory
        (``obj.flags["C_CONTIGUOUS"]==True``).

        Parameters
        ----------
        obj : numpy.ndarray

        Returns
        -------
        FixedShapeTensorArray
            Extension array whose tensor shape is ``obj.shape[1:]``.

        Raises
        ------
        ValueError
            If ``obj`` is not C-contiguous or has no dimensions.

        Examples
        --------
        >>> import pyarrow as pa
        >>> import numpy as np
        >>> arr = np.array(
        ...     [[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]],
        ...     dtype=np.float32)
        >>> pa.FixedShapeTensorArray.from_numpy_ndarray(arr)
        <pyarrow.lib.FixedShapeTensorArray object at ...>
        [
          [
            1,
            2,
            3,
            4,
            5,
            6
          ],
          [
            1,
            2,
            3,
            4,
            5,
            6
          ]
        ]
        """
        if not obj.flags["C_CONTIGUOUS"]:
            raise ValueError('The data in the numpy array need to be in a single, '
                             'C-style contiguous segment.')
        if obj.ndim == 0:
            # A 0-d array has no leading dimension to use as the array length.
            raise ValueError('The numpy array needs to have at least one '
                             'dimension.')

        arrow_type = from_numpy_dtype(obj.dtype)
        shape = obj.shape[1:]

        # Number of elements per tensor, as an exact Python int. Computing it
        # from the trailing dimensions (instead of ``obj.size / obj.shape[0]``)
        # avoids producing a float list size and avoids a ZeroDivisionError
        # when the leading dimension is 0.
        size = 1
        for dim in shape:
            size *= dim

        return ExtensionArray.from_storage(
            fixed_shape_tensor(arrow_type, shape),
            FixedSizeListArray.from_arrays(np.ravel(obj, order='C'), size)
        )


cdef dict _array_classes = {
_Type_NA: NullArray,
_Type_BOOL: BooleanArray,
Expand Down
21 changes: 21 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2619,6 +2619,27 @@ cdef extern from "arrow/extension_type.h" namespace "arrow":
shared_ptr[CArray] storage()


# Cython declarations for the C++ class arrow::extension::FixedShapeTensorType
# declared in "arrow/extension/fixed_shape_tensor.h". It is exposed to Cython
# as CFixedShapeTensorType and declared as deriving from CExtensionType.
cdef extern from "arrow/extension/fixed_shape_tensor.h" namespace "arrow::extension":
cdef cppclass CFixedShapeTensorType \
" arrow::extension::FixedShapeTensorType"(CExtensionType):

# Static factory: takes the value type, shape, permutation and dimension
# names, and returns a CResult wrapping a shared_ptr[CDataType].
@staticmethod
CResult[shared_ptr[CDataType]] Make(const shared_ptr[CDataType]& value_type,
const vector[int64_t]& shape,
const vector[int64_t]& permutation,
const vector[c_string]& dim_names)

# Takes a storage type and a serialized byte string and returns a CResult
# wrapping a shared_ptr[CDataType].
CResult[shared_ptr[CDataType]] Deserialize(const shared_ptr[CDataType] storage_type,
const c_string& serialized_data) const

# Returns the serialized form of the type as a byte string.
c_string Serialize() const

# Accessors for the parameters that Make() accepts.
const shared_ptr[CDataType] value_type()
const vector[int64_t] shape()
const vector[int64_t] permutation()
const vector[c_string] dim_names()


cdef extern from "arrow/util/compression.h" namespace "arrow" nogil:
cdef enum CCompressionType" arrow::Compression::type":
CCompressionType_UNCOMPRESSED" arrow::Compression::UNCOMPRESSED"
Expand Down
5 changes: 5 additions & 0 deletions python/pyarrow/lib.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,11 @@ cdef class ExtensionType(BaseExtensionType):
const CPyExtensionType* cpy_ext_type


# Cython-level declaration of the FixedShapeTensorType extension class, a
# subclass of BaseExtensionType.
cdef class FixedShapeTensorType(BaseExtensionType):
cdef:
# Const pointer to the C++ FixedShapeTensorType.
# NOTE(review): presumably points at the type this wrapper already
# holds -- confirm where it is assigned.
const CFixedShapeTensorType* tensor_ext_type


cdef class PyExtensionType(ExtensionType):
pass

Expand Down
2 changes: 2 additions & 0 deletions python/pyarrow/public-api.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ cdef api object pyarrow_wrap_data_type(
cpy_ext_type = dynamic_cast[_CPyExtensionTypePtr](ext_type)
if cpy_ext_type != nullptr:
return cpy_ext_type.GetInstance()
elif ext_type.extension_name() == b"arrow.fixed_shape_tensor":
out = FixedShapeTensorType.__new__(FixedShapeTensorType)
else:
out = BaseExtensionType.__new__(BaseExtensionType)
else:
Expand Down
96 changes: 96 additions & 0 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,102 @@ def test_cpp_extension_in_python(tmpdir):
assert reconstructed_array == array


def test_tensor_type():
    """Check the attributes exposed by types built with ``pa.fixed_shape_tensor``.

    Covers a plain type, a type with a permutation and a type with dimension
    names; optional attributes that were not given must be ``None``.
    """
    # (value type, shape, optional keyword arguments, expected storage list size)
    scenarios = (
        (pa.int8(), [2, 3], {}, 6),
        (pa.float64(), [2, 2, 3], {"permutation": [0, 2, 1]}, 12),
        (pa.bool_(), [2, 2, 3], {"dim_names": ['C', 'H', 'W']}, 12),
    )

    for value_type, shape, extras, list_size in scenarios:
        tensor_type = pa.fixed_shape_tensor(value_type, shape, **extras)

        assert tensor_type.extension_name == "arrow.fixed_shape_tensor"
        assert tensor_type.storage_type == pa.list_(value_type, list_size)
        assert tensor_type.shape == shape

        expected_dim_names = extras.get("dim_names")
        if expected_dim_names is None:
            assert tensor_type.dim_names is None
        else:
            assert tensor_type.dim_names == expected_dim_names

        expected_permutation = extras.get("permutation")
        if expected_permutation is None:
            assert tensor_type.permutation is None
        else:
            assert tensor_type.permutation == expected_permutation


def test_tensor_class_methods():
    """Exercise ``to_numpy_ndarray`` and ``from_numpy_ndarray``.

    Checks the successful conversions in both directions, plus the two
    ``ValueError`` cases: a non C-contiguous numpy input and a tensor type
    with a non-trivial permutation.
    """
    nested_values = [[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]

    # Arrow -> numpy
    float_tensor_type = pa.fixed_shape_tensor(pa.float32(), [2, 3])
    float_storage = pa.array([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]],
                             pa.list_(pa.float32(), 6))
    tensor_array = pa.ExtensionArray.from_storage(
        float_tensor_type, float_storage)
    np.testing.assert_array_equal(
        tensor_array.to_numpy_ndarray(),
        np.array(nested_values, dtype=np.float32))

    # numpy (C order) -> Arrow
    c_ordered = np.array(nested_values, dtype=np.float32, order="C")
    converted = pa.FixedShapeTensorArray.from_numpy_ndarray(c_ordered)
    converted_type = converted.type
    assert isinstance(converted_type, pa.FixedShapeTensorType)
    assert converted_type.value_type == pa.float32()
    assert converted_type.shape == [2, 3]

    # numpy (Fortran order) is rejected
    f_ordered = np.array(nested_values, dtype=np.float32, order="F")
    with pytest.raises(ValueError, match="C-style contiguous segment"):
        pa.FixedShapeTensorArray.from_numpy_ndarray(f_ordered)

    # A permuted tensor type cannot be converted to numpy
    permuted_type = pa.fixed_shape_tensor(
        pa.int8(), [2, 2, 3], permutation=[0, 2, 1])
    permuted_storage = pa.array(
        [[1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]], pa.list_(pa.int8(), 12))
    permuted_array = pa.ExtensionArray.from_storage(
        permuted_type, permuted_storage)
    with pytest.raises(ValueError, match="non-permuted tensors"):
        permuted_array.to_numpy_ndarray()


@pytest.mark.parametrize("tensor_type", (
    pa.fixed_shape_tensor(pa.int8(), [2, 2, 3]),
    pa.fixed_shape_tensor(pa.int8(), [2, 2, 3], permutation=[0, 2, 1]),
    pa.fixed_shape_tensor(pa.int8(), [2, 2, 3], dim_names=['C', 'H', 'W'])
))
def test_tensor_type_ipc(tensor_type):
    """Round-trip a fixed shape tensor array through IPC.

    The deserialized column must have the same extension array class, the
    fixed shape tensor type and the same storage values as the array that
    was written.
    """
    storage = pa.array(
        [[1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]], pa.list_(pa.int8(), 12))
    arr = pa.ExtensionArray.from_storage(tensor_type, storage)
    batch = pa.RecordBatch.from_arrays([arr], ["ext"])

    # check the built array has exactly the expected class
    tensor_class = tensor_type.__arrow_ext_class__()
    assert type(arr) is tensor_class

    buf = ipc_write_batch(batch)
    # drop the original batch so only the deserialized one is inspected below
    del batch
    batch = ipc_read_batch(buf)

    result = batch.column(0)
    # check the deserialized array class is the expected one
    assert type(result) is tensor_class
    assert result.type.extension_name == "arrow.fixed_shape_tensor"
    # verify the data that came back from IPC, not the array we started with
    assert result.storage.to_pylist() == [
        [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]]

    # we get back an actual TensorType
    assert isinstance(result.type, pa.FixedShapeTensorType)
    assert result.type.value_type == pa.int8()
    assert result.type.shape == [2, 2, 3]


def test_tensor_type_equality():
    """Tensor types compare equal only when their parameters match."""
    reference = pa.fixed_shape_tensor(pa.int8(), [2, 2, 3])
    assert reference.extension_name == "arrow.fixed_shape_tensor"

    # Same value type and shape -> equal; different value type -> not equal.
    same_parameters = pa.fixed_shape_tensor(pa.int8(), [2, 2, 3])
    other_value_type = pa.fixed_shape_tensor(pa.uint8(), [2, 2, 3])
    assert reference == same_parameters
    assert not reference == other_value_type


@pytest.mark.pandas
def test_extension_to_pandas_storage_type(registered_period_type):
period_type, _ = registered_period_type
Expand Down
Loading