diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.cc b/cpp/src/arrow/extension/fixed_shape_tensor.cc index 8b0ed43df5c..1debac0e704 100644 --- a/cpp/src/arrow/extension/fixed_shape_tensor.cc +++ b/cpp/src/arrow/extension/fixed_shape_tensor.cc @@ -23,6 +23,7 @@ #include "arrow/array/array_nested.h" #include "arrow/array/array_primitive.h" #include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep +#include "arrow/tensor.h" #include "arrow/util/int_util_overflow.h" #include "arrow/util/logging.h" #include "arrow/util/sort.h" @@ -33,8 +34,52 @@ namespace rj = arrow::rapidjson; namespace arrow { + namespace extension { +namespace { + +Status ComputeStrides(const FixedWidthType& type, const std::vector& shape, + const std::vector& permutation, + std::vector* strides) { + if (permutation.empty()) { + return internal::ComputeRowMajorStrides(type, shape, strides); + } + + const int byte_width = type.byte_width(); + + int64_t remaining = 0; + if (!shape.empty() && shape.front() > 0) { + remaining = byte_width; + for (auto i : permutation) { + if (i > 0) { + if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) { + return Status::Invalid( + "Strides computed from shape would not fit in 64-bit integer"); + } + } + } + } + + if (remaining == 0) { + strides->assign(shape.size(), byte_width); + return Status::OK(); + } + + strides->push_back(remaining); + for (auto i : permutation) { + if (i > 0) { + remaining /= shape[i]; + strides->push_back(remaining); + } + } + internal::Permute(permutation, strides); + + return Status::OK(); +} + +} // namespace + bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const { if (extension_name() != other.extension_name()) { return false; @@ -140,6 +185,132 @@ std::shared_ptr FixedShapeTensorType::MakeArray( return std::make_shared(data); } +Result> FixedShapeTensorArray::FromTensor( + const std::shared_ptr& tensor) { + auto permutation = internal::ArgSort(tensor->strides(), std::greater<>()); + if (permutation[0] != 0) { + return Status::Invalid( + "Only first-major tensors can be zero-copy converted to arrays"); + } + permutation.erase(permutation.begin()); + + std::vector cell_shape; + for (auto i : permutation) { + cell_shape.emplace_back(tensor->shape()[i]); + } + + std::vector dim_names; + if (!tensor->dim_names().empty()) { + for (auto i : permutation) { + dim_names.emplace_back(tensor->dim_names()[i]); + } + } + + for (int64_t& i : permutation) { + --i; + } + + auto ext_type = internal::checked_pointer_cast( + fixed_shape_tensor(tensor->type(), cell_shape, permutation, dim_names)); + + std::shared_ptr value_array; + switch (tensor->type_id()) { + case Type::UINT8: { + value_array = std::make_shared(tensor->size(), tensor->data()); + break; + } + case Type::INT8: { + value_array = std::make_shared(tensor->size(), tensor->data()); + break; + } + case Type::UINT16: { + value_array = std::make_shared(tensor->size(), tensor->data()); + break; + } + case Type::INT16: { + value_array = std::make_shared(tensor->size(), tensor->data()); + break; + } + case Type::UINT32: { + value_array = std::make_shared(tensor->size(), tensor->data()); + break; + } + case Type::INT32: { + value_array = std::make_shared(tensor->size(), tensor->data()); + break; + } + case Type::UINT64: { + value_array = std::make_shared(tensor->size(), tensor->data()); + break; + } + case Type::INT64: { + value_array = std::make_shared(tensor->size(), tensor->data()); + break; + } + case Type::HALF_FLOAT: { + value_array = std::make_shared(tensor->size(), tensor->data()); + break; + } + case Type::FLOAT: { + value_array = std::make_shared(tensor->size(), tensor->data()); + break; + } + case Type::DOUBLE: { + value_array = std::make_shared(tensor->size(), tensor->data()); + break; + } + default: { + return Status::NotImplemented("Unsupported tensor type: ", + tensor->type()->ToString()); + } + } + auto cell_size = static_cast(tensor->size() / tensor->shape()[0]); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr arr, + FixedSizeListArray::FromArrays(value_array, cell_size)); + std::shared_ptr ext_arr = ExtensionType::WrapArray(ext_type, arr); + return std::reinterpret_pointer_cast(ext_arr); +} + +const Result> FixedShapeTensorArray::ToTensor() const { + // To convert an array of n dimensional tensors to a n+1 dimensional tensor we + // interpret the array's length as the first dimension the new tensor. + + auto ext_arr = internal::checked_pointer_cast(this->storage()); + auto ext_type = internal::checked_pointer_cast(this->type()); + ARROW_RETURN_IF(!is_fixed_width(*ext_arr->value_type()), + Status::Invalid(ext_arr->value_type()->ToString(), + " is not valid data type for a tensor")); + auto permutation = ext_type->permutation(); + + std::vector dim_names; + if (!ext_type->dim_names().empty()) { + for (auto i : permutation) { + dim_names.emplace_back(ext_type->dim_names()[i]); + } + dim_names.insert(dim_names.begin(), 1, ""); + } else { + dim_names = {}; + } + + std::vector shape; + for (int64_t& i : permutation) { + shape.emplace_back(ext_type->shape()[i]); + ++i; + } + shape.insert(shape.begin(), 1, this->length()); + permutation.insert(permutation.begin(), 1, 0); + + std::vector tensor_strides; + auto value_type = internal::checked_pointer_cast(ext_arr->value_type()); + ARROW_RETURN_NOT_OK( + ComputeStrides(*value_type.get(), shape, permutation, &tensor_strides)); + ARROW_ASSIGN_OR_RAISE(auto buffers, ext_arr->Flatten()); + ARROW_ASSIGN_OR_RAISE( + auto tensor, Tensor::Make(ext_arr->value_type(), buffers->data()->buffers[1], shape, + tensor_strides, dim_names)); + return tensor; +} + Result> FixedShapeTensorType::Make( const std::shared_ptr& value_type, const std::vector& shape, const std::vector& permutation, const std::vector& dim_names) { @@ -157,6 +328,17 @@ Result> FixedShapeTensorType::Make( shape, permutation, dim_names); } +const std::vector& FixedShapeTensorType::strides() { + if (strides_.empty()) { + auto value_type = internal::checked_pointer_cast(this->value_type_); + std::vector tensor_strides; + ARROW_CHECK_OK(ComputeStrides(*value_type.get(), this->shape(), this->permutation(), + &tensor_strides)); + strides_ = tensor_strides; + } + return strides_; +} + std::shared_ptr fixed_shape_tensor(const std::shared_ptr& value_type, const std::vector& shape, const std::vector& permutation, diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.h b/cpp/src/arrow/extension/fixed_shape_tensor.h index 4ee2b894ee8..93837f13002 100644 --- a/cpp/src/arrow/extension/fixed_shape_tensor.h +++ b/cpp/src/arrow/extension/fixed_shape_tensor.h @@ -23,6 +23,26 @@ namespace extension { class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray { public: using ExtensionArray::ExtensionArray; + + /// \brief Create a FixedShapeTensorArray from a Tensor + /// + /// This method will create a FixedShapeTensorArray from a Tensor, taking its first + /// dimension as the number of elements in the resulting array and the remaining + /// dimensions as the shape of the individual tensors. If Tensor provides strides, + /// they will be used to determine dimension permutation. Otherwise, row-major layout + /// (i.e. no permutation) will be assumed. + /// + /// \param[in] tensor The Tensor to convert to a FixedShapeTensorArray + static Result> FromTensor( + const std::shared_ptr& tensor); + + /// \brief Create a Tensor from FixedShapeTensorArray + /// + /// This method will create a Tensor from a FixedShapeTensorArray, setting its first + /// dimension as length equal to the FixedShapeTensorArray's length and the remaining + /// dimensions as the FixedShapeTensorType's shape. Shape and dim_names will be + /// permuted according to permutation stored in the FixedShapeTensorType metadata. + const Result> ToTensor() const; }; /// \brief Concrete type class for constant-size Tensor data. @@ -51,6 +71,11 @@ class ARROW_EXPORT FixedShapeTensorType : public ExtensionType { /// Value type of tensor elements const std::shared_ptr value_type() const { return value_type_; } + /// Strides of tensor elements. Strides state offset in bytes between adjacent + /// elements along each dimension. In case permutation is non-empty strides are + /// computed from permuted tensor element's shape. + const std::vector& strides(); + /// Permutation mapping from logical to physical memory layout of tensor elements const std::vector& permutation() const { return permutation_; } @@ -78,6 +103,7 @@ class ARROW_EXPORT FixedShapeTensorType : public ExtensionType { std::shared_ptr storage_type_; std::shared_ptr value_type_; std::vector shape_; + std::vector strides_; std::vector permutation_; std::vector dim_names_; }; diff --git a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc index 16ba9d2014e..50132e25fb1 100644 --- a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc +++ b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc @@ -47,17 +47,26 @@ class TestExtensionType : public ::testing::Test { fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_)); values_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35}; + values_partial_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}; + shape_partial_ = {2, 3, 4}; + tensor_strides_ = {96, 32, 8}; + cell_strides_ = {32, 8}; serialized_ = R"({"shape":[3,4],"dim_names":["x","y"]})"; } protected: std::vector shape_; + std::vector shape_partial_; std::vector cell_shape_; std::shared_ptr value_type_; std::shared_ptr cell_type_; std::vector dim_names_; std::shared_ptr ext_type_; std::vector values_; + std::vector values_partial_; + std::vector tensor_strides_; + std::vector cell_strides_; std::string serialized_; }; @@ -100,6 +109,7 @@ TEST_F(TestExtensionType, CreateExtensionType) { ASSERT_EQ(exact_ext_type->ndim(), cell_shape_.size()); ASSERT_EQ(exact_ext_type->shape(), cell_shape_); ASSERT_EQ(exact_ext_type->value_type(), value_type_); + ASSERT_EQ(exact_ext_type->strides(), cell_strides_); ASSERT_EQ(exact_ext_type->dim_names(), dim_names_); EXPECT_RAISES_WITH_MESSAGE_THAT( @@ -212,4 +222,216 @@ TEST_F(TestExtensionType, RoudtripBatch) { CompareBatch(*batch, *read_batch2, /*compare_metadata=*/true); } +TEST_F(TestExtensionType, CreateFromTensor) { + std::vector column_major_strides = {8, 24, 72}; + std::vector neither_major_strides = {96, 8, 32}; + + ASSERT_OK_AND_ASSIGN(auto tensor, + Tensor::Make(value_type_, Buffer::Wrap(values_), shape_)); + + auto exact_ext_type = internal::checked_pointer_cast(ext_type_); + ASSERT_OK_AND_ASSIGN(auto ext_arr, FixedShapeTensorArray::FromTensor(tensor)); + + ASSERT_OK(ext_arr->ValidateFull()); + ASSERT_TRUE(tensor->is_row_major()); + ASSERT_EQ(tensor->strides(), tensor_strides_); + ASSERT_EQ(ext_arr->length(), shape_[0]); + + auto ext_type_2 = internal::checked_pointer_cast( + fixed_shape_tensor(int64(), {3, 4}, {0, 1})); + ASSERT_OK_AND_ASSIGN(auto ext_arr_2, FixedShapeTensorArray::FromTensor(tensor)); + + ASSERT_OK_AND_ASSIGN( + auto column_major_tensor, + Tensor::Make(value_type_, Buffer::Wrap(values_), shape_, column_major_strides)); + auto ext_type_3 = internal::checked_pointer_cast( + fixed_shape_tensor(int64(), {3, 4}, {0, 1})); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + testing::HasSubstr( + "Invalid: Only first-major tensors can be zero-copy converted to arrays"), + FixedShapeTensorArray::FromTensor(column_major_tensor)); + ASSERT_THAT(FixedShapeTensorArray::FromTensor(column_major_tensor), + Raises(StatusCode::Invalid)); + + auto neither_major_tensor = std::make_shared(value_type_, Buffer::Wrap(values_), + shape_, neither_major_strides); + auto ext_type_4 = internal::checked_pointer_cast( + fixed_shape_tensor(int64(), {3, 4}, {1, 0})); + ASSERT_OK_AND_ASSIGN(auto ext_arr_4, + FixedShapeTensorArray::FromTensor(neither_major_tensor)); + + auto ext_type_5 = internal::checked_pointer_cast( + fixed_shape_tensor(binary(), {1, 3})); + auto arr = ArrayFromJSON(binary(), R"(["abc", "def"])"); + + ASSERT_OK_AND_ASSIGN(auto fsla_arr, + FixedSizeListArray::FromArrays(arr, fixed_size_list(binary(), 2))); + auto ext_arr_5 = std::reinterpret_pointer_cast( + ExtensionType::WrapArray(ext_type_5, fsla_arr)); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("binary is not valid data type for a tensor"), + ext_arr_5->ToTensor()); + + auto ext_type_6 = internal::checked_pointer_cast( + fixed_shape_tensor(int64(), {1, 2})); + auto arr_with_null = ArrayFromJSON(int64(), "[1, 0, null, null, 1, 2]"); + ASSERT_OK_AND_ASSIGN(auto fsla_arr_6, FixedSizeListArray::FromArrays( + arr_with_null, fixed_size_list(int64(), 2))); +} + +void CheckFromTensorType(const std::shared_ptr& tensor, + std::shared_ptr expected_ext_type) { + auto ext_type = internal::checked_pointer_cast(expected_ext_type); + ASSERT_OK_AND_ASSIGN(auto ext_arr, FixedShapeTensorArray::FromTensor(tensor)); + auto generated_ext_type = + internal::checked_cast(ext_arr->extension_type()); + + // Check that generated type is equal to the expected type + ASSERT_EQ(generated_ext_type->type_name(), ext_type->type_name()); + ASSERT_EQ(generated_ext_type->shape(), ext_type->shape()); + ASSERT_EQ(generated_ext_type->dim_names(), ext_type->dim_names()); + ASSERT_EQ(generated_ext_type->permutation(), ext_type->permutation()); + ASSERT_TRUE(generated_ext_type->storage_type()->Equals(*ext_type->storage_type())); + ASSERT_TRUE(generated_ext_type->Equals(ext_type)); +} + +TEST_F(TestExtensionType, TestFromTensorType) { + auto values = Buffer::Wrap(values_); + auto shapes = + std::vector>{{3, 3, 4}, {3, 3, 4}, {3, 4, 3}, {3, 4, 3}}; + auto strides = std::vector>{ + {96, 32, 8}, {96, 8, 24}, {96, 24, 8}, {96, 8, 32}}; + auto tensor_dim_names = std::vector>{ + {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"}, + {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"}}; + auto dim_names = std::vector>{ + {"y", "z"}, {"z", "y"}, {"y", "z"}, {"z", "y"}, + {"y", "z"}, {"y", "z"}, {"y", "z"}, {"y", "z"}}; + auto cell_shapes = std::vector>{{3, 4}, {4, 3}, {4, 3}, {3, 4}}; + auto permutations = std::vector>{{0, 1}, {1, 0}, {0, 1}, {1, 0}}; + + for (size_t i = 0; i < shapes.size(); i++) { + ASSERT_OK_AND_ASSIGN(auto tensor, Tensor::Make(value_type_, values, shapes[i], + strides[i], tensor_dim_names[i])); + ASSERT_OK_AND_ASSIGN(auto ext_arr, FixedShapeTensorArray::FromTensor(tensor)); + auto ext_type = + fixed_shape_tensor(value_type_, cell_shapes[i], permutations[i], dim_names[i]); + CheckFromTensorType(tensor, ext_type); + } +} + +void CheckTensorRoundtrip(const std::shared_ptr& tensor) { + ASSERT_OK_AND_ASSIGN(auto ext_arr, FixedShapeTensorArray::FromTensor(tensor)); + ASSERT_OK_AND_ASSIGN(auto tensor_from_array, ext_arr->ToTensor()); + + ASSERT_EQ(tensor->type(), tensor_from_array->type()); + ASSERT_EQ(tensor->shape(), tensor_from_array->shape()); + for (size_t i = 1; i < tensor->dim_names().size(); i++) { + ASSERT_EQ(tensor->dim_names()[i], tensor_from_array->dim_names()[i]); + } + ASSERT_EQ(tensor->strides(), tensor_from_array->strides()); + ASSERT_TRUE(tensor->data()->Equals(*tensor_from_array->data())); + ASSERT_TRUE(tensor->Equals(*tensor_from_array)); +} + +TEST_F(TestExtensionType, RoundtripTensor) { + auto values = Buffer::Wrap(values_); + + auto shapes = std::vector>{ + {3, 3, 4}, {3, 4, 3}, {3, 4, 3}, {3, 3, 4}, {6, 2, 3}, + {6, 3, 2}, {2, 3, 6}, {2, 6, 3}, {2, 3, 2, 3}, {2, 3, 2, 3}}; + auto strides = std::vector>{ + {96, 32, 8}, {96, 8, 32}, {96, 24, 8}, {96, 8, 24}, {48, 24, 8}, + {48, 8, 24}, {144, 48, 8}, {144, 8, 48}, {144, 48, 24, 8}, {144, 8, 24, 48}}; + auto tensor_dim_names = std::vector>{ + {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"}, + {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"}, + {"N", "H", "W", "C"}, {"N", "H", "W", "C"}}; + + for (size_t i = 0; i < shapes.size(); i++) { + ASSERT_OK_AND_ASSIGN(auto tensor, Tensor::Make(value_type_, values, shapes[i], + strides[i], tensor_dim_names[i])); + CheckTensorRoundtrip(tensor); + } +} + +TEST_F(TestExtensionType, SliceTensor) { + ASSERT_OK_AND_ASSIGN(auto tensor, + Tensor::Make(value_type_, Buffer::Wrap(values_), shape_)); + ASSERT_OK_AND_ASSIGN( + auto tensor_partial, + Tensor::Make(value_type_, Buffer::Wrap(values_partial_), shape_partial_)); + ASSERT_EQ(tensor->strides(), tensor_strides_); + ASSERT_EQ(tensor_partial->strides(), tensor_strides_); + auto ext_type = fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_); + auto exact_ext_type = internal::checked_pointer_cast(ext_type_); + + ASSERT_OK_AND_ASSIGN(auto ext_arr, FixedShapeTensorArray::FromTensor(tensor)); + ASSERT_OK_AND_ASSIGN(auto ext_arr_partial, + FixedShapeTensorArray::FromTensor(tensor_partial)); + ASSERT_OK(ext_arr->ValidateFull()); + ASSERT_OK(ext_arr_partial->ValidateFull()); + + auto sliced = internal::checked_pointer_cast(ext_arr->Slice(0, 2)); + auto partial = internal::checked_pointer_cast(ext_arr_partial); + + ASSERT_TRUE(sliced->Equals(*partial)); + ASSERT_OK(sliced->ValidateFull()); + ASSERT_OK(partial->ValidateFull()); + ASSERT_TRUE(sliced->storage()->Equals(*partial->storage())); + ASSERT_EQ(sliced->length(), partial->length()); +} + +TEST_F(TestExtensionType, RoudtripBatchFromTensor) { + auto exact_ext_type = internal::checked_pointer_cast(ext_type_); + ASSERT_OK_AND_ASSIGN(auto tensor, Tensor::Make(value_type_, Buffer::Wrap(values_), + shape_, {}, {"n", "x", "y"})); + ASSERT_OK_AND_ASSIGN(auto ext_arr, FixedShapeTensorArray::FromTensor(tensor)); + ext_arr->data()->type = exact_ext_type; + + auto ext_metadata = + key_value_metadata({{"ARROW:extension:name", ext_type_->extension_name()}, + {"ARROW:extension:metadata", serialized_}}); + auto ext_field = field("f0", ext_type_, true, ext_metadata); + auto batch = RecordBatch::Make(schema({ext_field}), ext_arr->length(), {ext_arr}); + std::shared_ptr read_batch; + RoundtripBatch(batch, &read_batch); + CompareBatch(*batch, *read_batch, /*compare_metadata=*/true); +} + +TEST_F(TestExtensionType, ComputeStrides) { + auto exact_ext_type = internal::checked_pointer_cast(ext_type_); + + auto ext_type_1 = internal::checked_pointer_cast( + fixed_shape_tensor(int64(), cell_shape_, {}, dim_names_)); + auto ext_type_2 = internal::checked_pointer_cast( + fixed_shape_tensor(int64(), cell_shape_, {}, dim_names_)); + auto ext_type_3 = internal::checked_pointer_cast( + fixed_shape_tensor(int32(), cell_shape_, {}, dim_names_)); + ASSERT_TRUE(ext_type_1->Equals(*ext_type_2)); + ASSERT_FALSE(ext_type_1->Equals(*ext_type_3)); + + auto ext_type_4 = internal::checked_pointer_cast( + fixed_shape_tensor(int64(), {3, 4, 7}, {}, {"x", "y", "z"})); + ASSERT_EQ(ext_type_4->strides(), (std::vector{224, 56, 8})); + ext_type_4 = internal::checked_pointer_cast( + fixed_shape_tensor(int64(), {3, 4, 7}, {0, 1, 2}, {"x", "y", "z"})); + ASSERT_EQ(ext_type_4->strides(), (std::vector{224, 56, 8})); + + auto ext_type_5 = internal::checked_pointer_cast( + fixed_shape_tensor(int64(), {3, 4, 7}, {1, 0, 2})); + ASSERT_EQ(ext_type_5->strides(), (std::vector{56, 224, 8})); + ASSERT_EQ(ext_type_5->Serialize(), R"({"shape":[3,4,7],"permutation":[1,0,2]})"); + + auto ext_type_6 = internal::checked_pointer_cast( + fixed_shape_tensor(int64(), {3, 4, 7}, {1, 2, 0}, {})); + ASSERT_EQ(ext_type_6->strides(), (std::vector{56, 8, 224})); + ASSERT_EQ(ext_type_6->Serialize(), R"({"shape":[3,4,7],"permutation":[1,2,0]})"); + auto ext_type_7 = internal::checked_pointer_cast( + fixed_shape_tensor(int32(), {3, 4, 7}, {2, 0, 1}, {})); + ASSERT_EQ(ext_type_7->strides(), (std::vector{4, 112, 16})); + ASSERT_EQ(ext_type_7->Serialize(), R"({"shape":[3,4,7],"permutation":[2,0,1]})"); +} + } // namespace arrow