diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 2e5c67e07b6..6b1c25dc604 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -984,6 +984,8 @@ if(ARROW_JSON) arrow_add_object_library(ARROW_JSON extension/fixed_shape_tensor.cc extension/opaque.cc + extension/tensor_internal.cc + extension/variable_shape_tensor.cc json/options.cc json/chunked_builder.cc json/chunker.cc diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 15955b5ef88..546a3e9ffe2 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -22,7 +22,8 @@ arrow_install_all_headers("arrow/compute/kernels") # Define arrow_compute_kernels_testing object library for common test files if(ARROW_TESTING) - add_library(arrow_compute_kernels_testing OBJECT test_util_internal.cc) + add_library(arrow_compute_kernels_testing OBJECT + test_util_internal.cc ../../extension/tensor_extension_array_test.cc) # Even though this is still just an object library we still need to "link" our # dependencies so that include paths are configured correctly target_link_libraries(arrow_compute_kernels_testing PUBLIC ${ARROW_GTEST_GMOCK}) diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index 4ab6a35b52e..ae52bc32a99 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -18,7 +18,7 @@ set(CANONICAL_EXTENSION_TESTS bool8_test.cc json_test.cc uuid_test.cc) if(ARROW_JSON) - list(APPEND CANONICAL_EXTENSION_TESTS fixed_shape_tensor_test.cc opaque_test.cc) + list(APPEND CANONICAL_EXTENSION_TESTS tensor_extension_array_test.cc opaque_test.cc) endif() add_arrow_test(test diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.cc b/cpp/src/arrow/extension/fixed_shape_tensor.cc index bb7082e6976..e7df91f5892 100644 --- a/cpp/src/arrow/extension/fixed_shape_tensor.cc +++ 
b/cpp/src/arrow/extension/fixed_shape_tensor.cc @@ -37,52 +37,7 @@ namespace rj = arrow::rapidjson; -namespace arrow { - -namespace extension { - -namespace { - -Status ComputeStrides(const FixedWidthType& type, const std::vector& shape, - const std::vector& permutation, - std::vector* strides) { - if (permutation.empty()) { - return internal::ComputeRowMajorStrides(type, shape, strides); - } - - const int byte_width = type.byte_width(); - - int64_t remaining = 0; - if (!shape.empty() && shape.front() > 0) { - remaining = byte_width; - for (auto i : permutation) { - if (i > 0) { - if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) { - return Status::Invalid( - "Strides computed from shape would not fit in 64-bit integer"); - } - } - } - } - - if (remaining == 0) { - strides->assign(shape.size(), byte_width); - return Status::OK(); - } - - strides->push_back(remaining); - for (auto i : permutation) { - if (i > 0) { - remaining /= shape[i]; - strides->push_back(remaining); - } - } - internal::Permute(permutation, strides); - - return Status::OK(); -} - -} // namespace +namespace arrow::extension { bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const { if (extension_name() != other.extension_name()) { @@ -237,7 +192,8 @@ Result> FixedShapeTensorType::MakeTensor( } std::vector strides; - RETURN_NOT_OK(ComputeStrides(value_type, shape, permutation, &strides)); + RETURN_NOT_OK( + internal::ComputeStrides(ext_type.value_type(), shape, permutation, &strides)); const auto start_position = array->offset() * byte_width; const auto size = std::accumulate(shape.begin(), shape.end(), static_cast(1), std::multiplies<>()); @@ -376,9 +332,8 @@ const Result> FixedShapeTensorArray::ToTensor() const { internal::Permute(permutation, &shape); std::vector tensor_strides; - const auto* fw_value_type = internal::checked_cast(value_type.get()); ARROW_RETURN_NOT_OK( - ComputeStrides(*fw_value_type, shape, permutation, &tensor_strides)); + 
internal::ComputeStrides(value_type, shape, permutation, &tensor_strides)); const auto& raw_buffer = this->storage()->data()->child_data[0]->buffers[1]; ARROW_ASSIGN_OR_RAISE( @@ -412,10 +367,9 @@ Result> FixedShapeTensorType::Make( const std::vector& FixedShapeTensorType::strides() { if (strides_.empty()) { - auto value_type = internal::checked_cast(this->value_type_.get()); std::vector tensor_strides; - ARROW_CHECK_OK( - ComputeStrides(*value_type, this->shape(), this->permutation(), &tensor_strides)); + ARROW_CHECK_OK(internal::ComputeStrides(this->value_type_, this->shape(), + this->permutation(), &tensor_strides)); strides_ = tensor_strides; } return strides_; @@ -430,5 +384,4 @@ std::shared_ptr fixed_shape_tensor(const std::shared_ptr& va return maybe_type.MoveValueUnsafe(); } -} // namespace extension -} // namespace arrow +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.h b/cpp/src/arrow/extension/fixed_shape_tensor.h index 80a602021c6..5098da0405f 100644 --- a/cpp/src/arrow/extension/fixed_shape_tensor.h +++ b/cpp/src/arrow/extension/fixed_shape_tensor.h @@ -19,8 +19,7 @@ #include "arrow/extension_type.h" -namespace arrow { -namespace extension { +namespace arrow::extension { class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray { public: @@ -126,5 +125,4 @@ ARROW_EXPORT std::shared_ptr fixed_shape_tensor( const std::vector& permutation = {}, const std::vector& dim_names = {}); -} // namespace extension -} // namespace arrow +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc b/cpp/src/arrow/extension/tensor_extension_array_test.cc similarity index 66% rename from cpp/src/arrow/extension/fixed_shape_tensor_test.cc rename to cpp/src/arrow/extension/tensor_extension_array_test.cc index 6d4d2de3265..2305d2a9e9d 100644 --- a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc +++ b/cpp/src/arrow/extension/tensor_extension_array_test.cc @@ -16,6 +16,7 @@ // under the 
License. #include "arrow/extension/fixed_shape_tensor.h" +#include "arrow/extension/variable_shape_tensor.h" #include "arrow/testing/matchers.h" @@ -37,7 +38,11 @@ using arrow::ipc::test::RoundtripBatch; using extension::fixed_shape_tensor; using extension::FixedShapeTensorArray; -class TestExtensionType : public ::testing::Test { +using VariableShapeTensorType = extension::VariableShapeTensorType; +using extension::variable_shape_tensor; +using extension::VariableShapeTensorArray; + +class TestFixedShapeTensorType : public ::testing::Test { public: void SetUp() override { shape_ = {3, 3, 4}; @@ -72,13 +77,13 @@ class TestExtensionType : public ::testing::Test { std::string serialized_; }; -TEST_F(TestExtensionType, CheckDummyRegistration) { +TEST_F(TestFixedShapeTensorType, CheckDummyRegistration) { // We need a registered dummy type at runtime to allow for IPC deserialization auto registered_type = GetExtensionType("arrow.fixed_shape_tensor"); - ASSERT_TRUE(registered_type->type_id == Type::EXTENSION); + ASSERT_EQ(registered_type->id(), Type::EXTENSION); } -TEST_F(TestExtensionType, CreateExtensionType) { +TEST_F(TestFixedShapeTensorType, CreateExtensionType) { auto exact_ext_type = internal::checked_pointer_cast(ext_type_); // Test ExtensionType methods @@ -118,7 +123,7 @@ TEST_F(TestExtensionType, CreateExtensionType) { FixedShapeTensorType::Make(value_type_, {1, 2, 3}, {0, 1, 1})); } -TEST_F(TestExtensionType, EqualsCases) { +TEST_F(TestFixedShapeTensorType, EqualsCases) { auto ext_type_permutation_1 = fixed_shape_tensor(int64(), {3, 4}, {0, 1}, {"x", "y"}); auto ext_type_permutation_2 = fixed_shape_tensor(int64(), {3, 4}, {1, 0}, {"x", "y"}); auto ext_type_no_permutation = fixed_shape_tensor(int64(), {3, 4}, {}, {"x", "y"}); @@ -140,7 +145,7 @@ TEST_F(TestExtensionType, EqualsCases) { ASSERT_FALSE(ext_type_permutation_2->Equals(ext_type_permutation_1)); } -TEST_F(TestExtensionType, CreateFromArray) { +TEST_F(TestFixedShapeTensorType, CreateFromArray) { auto 
exact_ext_type = internal::checked_pointer_cast(ext_type_); std::vector> buffers = {nullptr, Buffer::Wrap(values_)}; @@ -152,7 +157,7 @@ TEST_F(TestExtensionType, CreateFromArray) { ASSERT_EQ(ext_arr->null_count(), 0); } -TEST_F(TestExtensionType, MakeArrayCanGetCorrectScalarType) { +TEST_F(TestFixedShapeTensorType, MakeArrayCanGetCorrectScalarType) { ASSERT_OK_AND_ASSIGN(auto tensor, Tensor::Make(value_type_, Buffer::Wrap(values_), shape_)); @@ -175,23 +180,23 @@ TEST_F(TestExtensionType, MakeArrayCanGetCorrectScalarType) { } void CheckSerializationRoundtrip(const std::shared_ptr& ext_type) { - auto fst_type = internal::checked_pointer_cast(ext_type); - auto serialized = fst_type->Serialize(); + auto type = internal::checked_pointer_cast(ext_type); + auto serialized = type->Serialize(); ASSERT_OK_AND_ASSIGN(auto deserialized, - fst_type->Deserialize(fst_type->storage_type(), serialized)); - ASSERT_TRUE(fst_type->Equals(*deserialized)); + type->Deserialize(type->storage_type(), serialized)); + ASSERT_TRUE(type->Equals(*deserialized)); } -void CheckDeserializationRaises(const std::shared_ptr& storage_type, +void CheckDeserializationRaises(const std::shared_ptr& extension_type, + const std::shared_ptr& storage_type, const std::string& serialized, const std::string& expected_message) { - auto fst_type = internal::checked_pointer_cast( - fixed_shape_tensor(int64(), {3, 4})); + auto ext_type = internal::checked_pointer_cast(extension_type); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr(expected_message), - fst_type->Deserialize(storage_type, serialized)); + ext_type->Deserialize(storage_type, serialized)); } -TEST_F(TestExtensionType, MetadataSerializationRoundtrip) { +TEST_F(TestFixedShapeTensorType, MetadataSerializationRoundtrip) { CheckSerializationRoundtrip(ext_type_); CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {}, {}, {})); CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {0}, {}, {})); @@ -202,19 +207,21 @@ 
TEST_F(TestExtensionType, MetadataSerializationRoundtrip) { fixed_shape_tensor(value_type_, {256, 256, 3}, {2, 0, 1}, {"C", "H", "W"})); auto storage_type = fixed_size_list(int64(), 12); - CheckDeserializationRaises(boolean(), R"({"shape":[3,4]})", + CheckDeserializationRaises(ext_type_, boolean(), R"({"shape":[3,4]})", "Expected FixedSizeList storage type, got bool"); - CheckDeserializationRaises(storage_type, R"({"dim_names":["x","y"]})", + CheckDeserializationRaises(ext_type_, storage_type, R"({"dim_names":["x","y"]})", "Invalid serialized JSON data"); - CheckDeserializationRaises(storage_type, R"({"shape":(3,4)})", + CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":(3,4)})", "Invalid serialized JSON data"); - CheckDeserializationRaises(storage_type, R"({"shape":[3,4],"permutation":[1,0,2]})", + CheckDeserializationRaises(ext_type_, storage_type, + R"({"shape":[3,4],"permutation":[1,0,2]})", "Invalid permutation"); - CheckDeserializationRaises(storage_type, R"({"shape":[3],"dim_names":["x","y"]})", + CheckDeserializationRaises(ext_type_, storage_type, + R"({"shape":[3],"dim_names":["x","y"]})", "Invalid dim_names"); } -TEST_F(TestExtensionType, RoundtripBatch) { +TEST_F(TestFixedShapeTensorType, RoundtripBatch) { auto exact_ext_type = internal::checked_pointer_cast(ext_type_); std::vector> buffers = {nullptr, Buffer::Wrap(values_)}; @@ -242,7 +249,7 @@ TEST_F(TestExtensionType, RoundtripBatch) { CompareBatch(*batch, *read_batch2, /*compare_metadata=*/true); } -TEST_F(TestExtensionType, CreateFromTensor) { +TEST_F(TestFixedShapeTensorType, CreateFromTensor) { std::vector column_major_strides = {8, 24, 72}; std::vector neither_major_strides = {96, 8, 32}; @@ -320,7 +327,7 @@ void CheckFromTensorType(const std::shared_ptr& tensor, ASSERT_TRUE(generated_ext_type->Equals(ext_type)); } -TEST_F(TestExtensionType, TestFromTensorType) { +TEST_F(TestFixedShapeTensorType, TestFromTensorType) { auto values = Buffer::Wrap(values_); auto shapes = 
std::vector>{{3, 3, 4}, {3, 3, 4}, {3, 4, 3}, {3, 4, 3}}; @@ -379,7 +386,7 @@ void CheckToTensor(const std::vector& values, const std::shared_ptr ASSERT_TRUE(actual_tensor->Equals(*expected_tensor)); } -TEST_F(TestExtensionType, ToTensor) { +TEST_F(TestFixedShapeTensorType, ToTensor) { std::vector float_values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35}; @@ -430,7 +437,7 @@ void CheckTensorRoundtrip(const std::shared_ptr& tensor) { ASSERT_TRUE(tensor->Equals(*tensor_from_array)); } -TEST_F(TestExtensionType, RoundtripTensor) { +TEST_F(TestFixedShapeTensorType, RoundtripTensor) { auto values = Buffer::Wrap(values_); auto shapes = std::vector>{ @@ -451,7 +458,7 @@ TEST_F(TestExtensionType, RoundtripTensor) { } } -TEST_F(TestExtensionType, SliceTensor) { +TEST_F(TestFixedShapeTensorType, SliceTensor) { ASSERT_OK_AND_ASSIGN(auto tensor, Tensor::Make(value_type_, Buffer::Wrap(values_), shape_)); ASSERT_OK_AND_ASSIGN( @@ -478,7 +485,7 @@ TEST_F(TestExtensionType, SliceTensor) { ASSERT_EQ(sliced->length(), partial->length()); } -TEST_F(TestExtensionType, RoundtripBatchFromTensor) { +TEST_F(TestFixedShapeTensorType, RoundtripBatchFromTensor) { auto exact_ext_type = internal::checked_pointer_cast(ext_type_); ASSERT_OK_AND_ASSIGN(auto tensor, Tensor::Make(value_type_, Buffer::Wrap(values_), shape_, {}, {"n", "x", "y"})); @@ -495,7 +502,7 @@ TEST_F(TestExtensionType, RoundtripBatchFromTensor) { CompareBatch(*batch, *read_batch, /*compare_metadata=*/true); } -TEST_F(TestExtensionType, ComputeStrides) { +TEST_F(TestFixedShapeTensorType, ComputeStrides) { auto exact_ext_type = internal::checked_pointer_cast(ext_type_); auto ext_type_1 = internal::checked_pointer_cast( @@ -529,7 +536,7 @@ TEST_F(TestExtensionType, ComputeStrides) { ASSERT_EQ(ext_type_7->Serialize(), R"({"shape":[3,4,7],"permutation":[2,0,1]})"); } -TEST_F(TestExtensionType, ToString) { +TEST_F(TestFixedShapeTensorType, 
FixedShapeTensorToString) { auto exact_ext_type = internal::checked_pointer_cast(ext_type_); auto ext_type_1 = internal::checked_pointer_cast( @@ -557,7 +564,7 @@ TEST_F(TestExtensionType, ToString) { ASSERT_EQ(expected_3, result_3); } -TEST_F(TestExtensionType, GetTensor) { +TEST_F(TestFixedShapeTensorType, GetTensor) { auto arr = ArrayFromJSON(element_type_, "[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]," "[12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]"); @@ -649,4 +656,284 @@ TEST_F(TestExtensionType, GetTensor) { exact_ext_type->MakeTensor(ext_scalar)); } +class TestVariableShapeTensorType : public ::testing::Test { + public: + void SetUp() override { + ndim_ = 3; + value_type_ = int64(); + data_type_ = list(value_type_); + shape_type_ = fixed_size_list(int32(), ndim_); + permutation_ = {0, 1, 2}; + dim_names_ = {"x", "y", "z"}; + uniform_shape_ = {std::nullopt, std::optional(1), std::nullopt}; + ext_type_ = internal::checked_pointer_cast(variable_shape_tensor( + value_type_, ndim_, permutation_, dim_names_, uniform_shape_)); + values_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35}; + shapes_ = ArrayFromJSON(fixed_size_list(int32(), ndim_), "[[2,1,3],[2,1,2],[3,1,3]]"); + data_ = ArrayFromJSON(list(value_type_), + "[[0,1,2,3,4,5],[6,7,8,9],[10,11,12,13,14,15,16,17,18]]"); + serialized_ = + R"({"permutation":[0,1,2],"dim_names":["x","y","z"],"uniform_shape":[null,1,null]})"; + storage_arr_ = ArrayFromJSON( + ext_type_->storage_type(), + R"([[[0,1,2,3,4,5],[2,3,1]],[[6,7,8,9],[1,2,2]],[[10,11,12,13,14,15,16,17,18],[3,1,3]]])"); + ext_arr_ = internal::checked_pointer_cast( + ExtensionType::WrapArray(ext_type_, storage_arr_)); + } + + protected: + int32_t ndim_; + std::shared_ptr value_type_; + std::shared_ptr data_type_; + std::shared_ptr shape_type_; + std::vector permutation_; + std::vector> uniform_shape_; + std::vector dim_names_; + std::shared_ptr ext_type_; + 
std::vector values_; + std::shared_ptr shapes_; + std::shared_ptr data_; + std::string serialized_; + std::shared_ptr storage_arr_; + std::shared_ptr ext_arr_; +}; + +TEST_F(TestVariableShapeTensorType, CheckDummyRegistration) { + // We need a registered dummy type at runtime to allow for IPC deserialization + auto registered_type = GetExtensionType("arrow.variable_shape_tensor"); + ASSERT_EQ(registered_type->id(), Type::EXTENSION); +} + +TEST_F(TestVariableShapeTensorType, CreateExtensionType) { + auto exact_ext_type = + internal::checked_pointer_cast(ext_type_); + + // Test ExtensionType methods + ASSERT_EQ(ext_type_->extension_name(), "arrow.variable_shape_tensor"); + ASSERT_TRUE(ext_type_->Equals(*exact_ext_type)); + auto expected_type = + struct_({::arrow::field("data", list(value_type_)), + ::arrow::field("shape", fixed_size_list(int32(), ndim_))}); + + ASSERT_TRUE(ext_type_->storage_type()->Equals(*expected_type)); + ASSERT_EQ(ext_type_->Serialize(), serialized_); + ASSERT_OK_AND_ASSIGN(auto ds, + ext_type_->Deserialize(ext_type_->storage_type(), serialized_)); + auto deserialized = internal::checked_pointer_cast(ds); + ASSERT_TRUE(deserialized->Equals(*exact_ext_type)); + ASSERT_TRUE(deserialized->Equals(*ext_type_)); + + // Test FixedShapeTensorType methods + ASSERT_EQ(exact_ext_type->id(), Type::EXTENSION); + ASSERT_EQ(exact_ext_type->ndim(), ndim_); + ASSERT_EQ(exact_ext_type->value_type(), value_type_); + ASSERT_EQ(exact_ext_type->permutation(), permutation_); + ASSERT_EQ(exact_ext_type->dim_names(), dim_names_); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + testing::HasSubstr("Invalid: permutation size must match ndim. 
Expected: 3 Got: 1"), + VariableShapeTensorType::Make(value_type_, ndim_, {0})); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("Invalid: dim_names size must match ndim."), + VariableShapeTensorType::Make(value_type_, ndim_, {}, {"x"})); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + testing::HasSubstr("Invalid: Permutation indices for 3 dimensional tensors must be " + "unique and within [0, 2] range. Got: [2,0,0]"), + VariableShapeTensorType::Make(value_type_, 3, {2, 0, 0}, {"C", "H", "W"})); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + testing::HasSubstr("Invalid: Permutation indices for 3 dimensional tensors must be " + "unique and within [0, 2] range. Got: [1,2,3]"), + VariableShapeTensorType::Make(value_type_, 3, {1, 2, 3}, {"C", "H", "W"})); +} + +TEST_F(TestVariableShapeTensorType, EqualsCases) { + auto ext_type_permutation_1 = variable_shape_tensor(int64(), 2, {0, 1}, {"x", "y"}); + auto ext_type_permutation_2 = variable_shape_tensor(int64(), 2, {1, 0}, {"x", "y"}); + auto ext_type_no_permutation = variable_shape_tensor(int64(), 2, {}, {"x", "y"}); + + ASSERT_TRUE(ext_type_permutation_1->Equals(ext_type_permutation_1)); + + ASSERT_FALSE( + variable_shape_tensor(int32(), 2, {}, {"x", "y"})->Equals(ext_type_no_permutation)); + ASSERT_FALSE(variable_shape_tensor(int64(), 2, {}, {}) + ->Equals(variable_shape_tensor(int64(), 3, {}, {}))); + ASSERT_FALSE( + variable_shape_tensor(int64(), 2, {}, {"H", "W"})->Equals(ext_type_no_permutation)); + + ASSERT_TRUE(ext_type_no_permutation->Equals(ext_type_permutation_1)); + ASSERT_TRUE(ext_type_permutation_1->Equals(ext_type_no_permutation)); + ASSERT_FALSE(ext_type_no_permutation->Equals(ext_type_permutation_2)); + ASSERT_FALSE(ext_type_permutation_2->Equals(ext_type_no_permutation)); + ASSERT_FALSE(ext_type_permutation_1->Equals(ext_type_permutation_2)); + ASSERT_FALSE(ext_type_permutation_2->Equals(ext_type_permutation_1)); +} + +TEST_F(TestVariableShapeTensorType, MetadataSerializationRoundtrip) 
{ + CheckSerializationRoundtrip(ext_type_); + CheckSerializationRoundtrip( + variable_shape_tensor(value_type_, 3, {1, 2, 0}, {"x", "y", "z"})); + CheckSerializationRoundtrip(variable_shape_tensor(value_type_, 0, {}, {})); + CheckSerializationRoundtrip(variable_shape_tensor(value_type_, 1, {0}, {"x"})); + CheckSerializationRoundtrip( + variable_shape_tensor(value_type_, 3, {0, 1, 2}, {"H", "W", "C"})); + CheckSerializationRoundtrip( + variable_shape_tensor(value_type_, 3, {2, 0, 1}, {"C", "H", "W"})); + CheckSerializationRoundtrip( + variable_shape_tensor(value_type_, 3, {2, 0, 1}, {"C", "H", "W"}, {0, 1, 2})); + + auto storage_type = ext_type_->storage_type(); + CheckDeserializationRaises(ext_type_, boolean(), R"({"shape":[3,4]})", + "Expected Struct storage type, got bool"); + CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":(3,4)})", + "Invalid serialized JSON data"); + CheckDeserializationRaises(ext_type_, storage_type, R"({"permutation":[1,0]})", + "Invalid: permutation"); + CheckDeserializationRaises(ext_type_, storage_type, R"({"dim_names":["x","y"]})", + "Invalid: dim_names"); +} + +TEST_F(TestVariableShapeTensorType, RoundtripBatch) { + auto exact_ext_type = + internal::checked_pointer_cast(ext_type_); + + // Pass extension array, expect getting back extension array + std::shared_ptr read_batch; + auto ext_field = field(/*name=*/"f0", /*type=*/ext_type_); + auto batch = RecordBatch::Make(schema({ext_field}), ext_arr_->length(), {ext_arr_}); + ASSERT_OK(RoundtripBatch(batch, &read_batch)); + CompareBatch(*batch, *read_batch, /*compare_metadata=*/true); + + // Pass extension metadata and storage array, expect getting back extension array + std::shared_ptr read_batch2; + auto ext_metadata = + key_value_metadata({{"ARROW:extension:name", exact_ext_type->extension_name()}, + {"ARROW:extension:metadata", serialized_}}); + ext_field = field(/*name=*/"f0", /*type=*/ext_type_->storage_type(), /*nullable=*/true, + /*metadata=*/ext_metadata); + auto 
batch2 = RecordBatch::Make(schema({ext_field}), ext_arr_->length(), {ext_arr_}); + ASSERT_OK(RoundtripBatch(batch2, &read_batch2)); + CompareBatch(*batch, *read_batch2, /*compare_metadata=*/true); +} + +TEST_F(TestVariableShapeTensorType, ComputeStrides) { + auto shapes = ArrayFromJSON(shape_type_, "[[2,3,1],[2,1,2],[3,1,3],null]"); + auto data = ArrayFromJSON( + data_type_, "[[1,1,2,3,4,5],[2,7,8,9],[10,11,12,13,14,15,16,17,18],null]"); + std::vector> fields = {field("data", data_type_), + field("shapes", shape_type_)}; + ASSERT_OK_AND_ASSIGN(auto storage_arr, StructArray::Make({data, shapes}, fields)); + auto ext_arr = ExtensionType::WrapArray(ext_type_, storage_arr); + auto exact_ext_type = + internal::checked_pointer_cast(ext_type_); + auto ext_array = std::static_pointer_cast(ext_arr); + + std::shared_ptr t, tensor; + + ASSERT_OK_AND_ASSIGN(auto scalar, ext_array->GetScalar(0)); + auto ext_scalar = internal::checked_pointer_cast(scalar); + ASSERT_OK_AND_ASSIGN(t, exact_ext_type->MakeTensor(ext_scalar)); + ASSERT_EQ(t->shape(), (std::vector{2, 3, 1})); + ASSERT_EQ(t->strides(), (std::vector{24, 8, 8})); + + std::vector strides = {sizeof(int64_t) * 3, sizeof(int64_t) * 1, + sizeof(int64_t) * 1}; + tensor = TensorFromJSON(int64(), R"([1,1,2,3,4,5])", {2, 3, 1}, strides, dim_names_); + + ASSERT_TRUE(tensor->Equals(*t)); + + ASSERT_OK_AND_ASSIGN(scalar, ext_array->GetScalar(1)); + ext_scalar = internal::checked_pointer_cast(scalar); + ASSERT_OK_AND_ASSIGN(t, exact_ext_type->MakeTensor(ext_scalar)); + ASSERT_EQ(t->shape(), (std::vector{2, 1, 2})); + ASSERT_EQ(t->strides(), (std::vector{16, 16, 8})); + + ASSERT_OK_AND_ASSIGN(scalar, ext_array->GetScalar(2)); + ext_scalar = internal::checked_pointer_cast(scalar); + ASSERT_OK_AND_ASSIGN(t, exact_ext_type->MakeTensor(ext_scalar)); + ASSERT_EQ(t->shape(), (std::vector{3, 1, 3})); + ASSERT_EQ(t->strides(), (std::vector{24, 24, 8})); + + strides = {sizeof(int64_t) * 3, sizeof(int64_t) * 3, sizeof(int64_t) * 1}; + tensor = 
TensorFromJSON(int64(), R"([10,11,12,13,14,15,16,17,18])", {3, 1, 3}, strides, + dim_names_); + + ASSERT_EQ(tensor->strides(), t->strides()); + ASSERT_EQ(tensor->shape(), t->shape()); + ASSERT_EQ(tensor->dim_names(), t->dim_names()); + ASSERT_EQ(tensor->type(), t->type()); + ASSERT_EQ(tensor->is_contiguous(), t->is_contiguous()); + ASSERT_EQ(tensor->is_column_major(), t->is_column_major()); + ASSERT_TRUE(tensor->Equals(*t)); + + ASSERT_OK_AND_ASSIGN(auto sc, ext_arr->GetScalar(2)); + auto s = internal::checked_pointer_cast(sc); + ASSERT_OK_AND_ASSIGN(t, exact_ext_type->MakeTensor(s)); + ASSERT_EQ(tensor->strides(), t->strides()); + ASSERT_EQ(tensor->shape(), t->shape()); + ASSERT_EQ(tensor->dim_names(), t->dim_names()); + ASSERT_EQ(tensor->type(), t->type()); + ASSERT_EQ(tensor->is_contiguous(), t->is_contiguous()); + ASSERT_EQ(tensor->is_column_major(), t->is_column_major()); + ASSERT_TRUE(tensor->Equals(*t)); + + // Null value in VariableShapeTensorArray produces a tensor with shape {0, 0, 0} + strides = {sizeof(int64_t), sizeof(int64_t), sizeof(int64_t)}; + tensor = TensorFromJSON(int64(), R"([10,11,12,13,14,15,16,17,18])", {0, 0, 0}, strides, + dim_names_); + + ASSERT_OK_AND_ASSIGN(sc, ext_arr->GetScalar(3)); + ASSERT_OK_AND_ASSIGN( + t, exact_ext_type->MakeTensor(internal::checked_pointer_cast(sc))); + ASSERT_EQ(tensor->strides(), t->strides()); + ASSERT_EQ(tensor->shape(), t->shape()); + ASSERT_EQ(tensor->dim_names(), t->dim_names()); + ASSERT_EQ(tensor->type(), t->type()); + ASSERT_EQ(tensor->is_contiguous(), t->is_contiguous()); + ASSERT_EQ(tensor->is_column_major(), t->is_column_major()); + ASSERT_TRUE(tensor->Equals(*t)); +} + +TEST_F(TestVariableShapeTensorType, ToString) { + auto exact_ext_type = + internal::checked_pointer_cast(ext_type_); + + auto uniform_shape = std::vector>{ + std::nullopt, std::optional(1), std::nullopt}; + auto ext_type_1 = internal::checked_pointer_cast( + variable_shape_tensor(int16(), 3)); + auto ext_type_2 = 
internal::checked_pointer_cast( + variable_shape_tensor(int32(), 3, {1, 0, 2})); + auto ext_type_3 = internal::checked_pointer_cast( + variable_shape_tensor(int64(), 3, {}, {"C", "H", "W"})); + auto ext_type_4 = internal::checked_pointer_cast( + variable_shape_tensor(int64(), 3, {}, {}, uniform_shape)); + + std::string result_1 = ext_type_1->ToString(); + std::string expected_1 = + "extension"; + ASSERT_EQ(expected_1, result_1); + + std::string result_2 = ext_type_2->ToString(); + std::string expected_2 = + "extension"; + ASSERT_EQ(expected_2, result_2); + + std::string result_3 = ext_type_3->ToString(); + std::string expected_3 = + "extension"; + ASSERT_EQ(expected_3, result_3); + + std::string result_4 = ext_type_4->ToString(); + std::string expected_4 = + "extension"; + ASSERT_EQ(expected_4, result_4); +} + } // namespace arrow diff --git a/cpp/src/arrow/extension/tensor_internal.cc b/cpp/src/arrow/extension/tensor_internal.cc new file mode 100644 index 00000000000..a875adc55fc --- /dev/null +++ b/cpp/src/arrow/extension/tensor_internal.cc @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/extension/tensor_internal.h" + +#include "arrow/tensor.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/int_util_overflow.h" +#include "arrow/util/sort_internal.h" + +#include "arrow/status.h" +#include "arrow/util/print_internal.h" + +namespace arrow::internal { + +Status IsPermutationValid(const std::vector& permutation) { + const auto size = static_cast(permutation.size()); + std::vector dim_seen(size, 0); + + for (const auto p : permutation) { + if (p < 0 || p >= size || dim_seen[p] != 0) { + return Status::Invalid( + "Permutation indices for ", size, + " dimensional tensors must be unique and within [0, ", size - 1, + "] range. Got: ", ::arrow::internal::PrintVector{permutation, ","}); + } + dim_seen[p] = 1; + } + return Status::OK(); +} + +Status ComputeStrides(const std::shared_ptr& value_type, + const std::vector& shape, + const std::vector& permutation, + std::vector* strides) { + auto fixed_width_type = internal::checked_pointer_cast(value_type); + if (permutation.empty()) { + return internal::ComputeRowMajorStrides(*fixed_width_type.get(), shape, strides); + } + const int byte_width = value_type->byte_width(); + + int64_t remaining = 0; + if (!shape.empty() && shape.front() > 0) { + remaining = byte_width; + for (auto i : permutation) { + if (i > 0) { + if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) { + return Status::Invalid( + "Strides computed from shape would not fit in 64-bit integer"); + } + } + } + } + + if (remaining == 0) { + strides->assign(shape.size(), byte_width); + return Status::OK(); + } + + strides->push_back(remaining); + for (auto i : permutation) { + if (i > 0) { + remaining /= shape[i]; + strides->push_back(remaining); + } + } + internal::Permute(permutation, strides); + + return Status::OK(); +} + +} // namespace arrow::internal diff --git a/cpp/src/arrow/extension/tensor_internal.h b/cpp/src/arrow/extension/tensor_internal.h index 62b1dba6144..1a0bd0b29c2 100644 --- 
a/cpp/src/arrow/extension/tensor_internal.h +++ b/cpp/src/arrow/extension/tensor_internal.h @@ -20,25 +20,17 @@ #include #include -#include "arrow/status.h" -#include "arrow/util/print_internal.h" +#include "arrow/array/array_nested.h" namespace arrow::internal { -inline Status IsPermutationValid(const std::vector& permutation) { - const auto size = static_cast(permutation.size()); - std::vector dim_seen(size, 0); +ARROW_EXPORT +Status IsPermutationValid(const std::vector& permutation); - for (const auto p : permutation) { - if (p < 0 || p >= size || dim_seen[p] != 0) { - return Status::Invalid( - "Permutation indices for ", size, - " dimensional tensors must be unique and within [0, ", size - 1, - "] range. Got: ", ::arrow::internal::PrintVector{permutation, ","}); - } - dim_seen[p] = 1; - } - return Status::OK(); -} +ARROW_EXPORT +Status ComputeStrides(const std::shared_ptr& value_type, + const std::vector& shape, + const std::vector& permutation, + std::vector* strides); } // namespace arrow::internal diff --git a/cpp/src/arrow/extension/variable_shape_tensor.cc b/cpp/src/arrow/extension/variable_shape_tensor.cc new file mode 100644 index 00000000000..ee2dd165c56 --- /dev/null +++ b/cpp/src/arrow/extension/variable_shape_tensor.cc @@ -0,0 +1,314 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/extension/tensor_internal.h" +#include "arrow/extension/variable_shape_tensor.h" + +#include "arrow/array/array_primitive.h" +#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep +#include "arrow/scalar.h" +#include "arrow/tensor.h" +#include "arrow/util/int_util_overflow.h" +#include "arrow/util/logging_internal.h" +#include "arrow/util/print_internal.h" +#include "arrow/util/sort_internal.h" +#include "arrow/util/string.h" + +#include +#include + +namespace rj = arrow::rapidjson; + +namespace arrow::extension { + +bool VariableShapeTensorType::ExtensionEquals(const ExtensionType& other) const { + if (extension_name() != other.extension_name()) { + return false; + } + const auto& other_ext = internal::checked_cast(other); + if (this->ndim() != other_ext.ndim()) { + return false; + } + + auto is_permutation_trivial = [](const std::vector& permutation) { + for (size_t i = 1; i < permutation.size(); ++i) { + if (permutation[i - 1] + 1 != permutation[i]) { + return false; + } + } + return true; + }; + const bool permutation_equivalent = + ((permutation_ == other_ext.permutation()) || + (permutation_.empty() && is_permutation_trivial(other_ext.permutation())) || + (is_permutation_trivial(permutation_) && other_ext.permutation().empty())); + + return (storage_type()->Equals(other_ext.storage_type())) && + (dim_names_ == other_ext.dim_names()) && + (uniform_shape_ == other_ext.uniform_shape()) && permutation_equivalent; +} + +std::string VariableShapeTensorType::ToString(bool show_metadata) const { + std::stringstream ss; + ss << "extension<" << this->extension_name() + << "[value_type=" << value_type_->ToString() << ", ndim=" << ndim_; + + if (!permutation_.empty()) { + ss << ", permutation=" << ::arrow::internal::PrintVector{permutation_, ","}; + } + if (!dim_names_.empty()) { + ss << ", dim_names=[" << 
internal::JoinStrings(dim_names_, ",") << "]"; + } + if (!uniform_shape_.empty()) { + std::vector uniform_shape; + for (const auto& v : uniform_shape_) { + if (v.has_value()) { + uniform_shape.emplace_back(std::to_string(v.value())); + } else { + uniform_shape.emplace_back("null"); + } + } + ss << ", uniform_shape=[" << internal::JoinStrings(uniform_shape, ",") << "]"; + } + ss << "]>"; + return ss.str(); +} + +std::string VariableShapeTensorType::Serialize() const { + rj::Document document; + document.SetObject(); + rj::Document::AllocatorType& allocator = document.GetAllocator(); + + if (!permutation_.empty()) { + rj::Value permutation(rj::kArrayType); + for (auto v : permutation_) { + permutation.PushBack(v, allocator); + } + document.AddMember(rj::Value("permutation", allocator), permutation, allocator); + } + + if (!dim_names_.empty()) { + rj::Value dim_names(rj::kArrayType); + for (std::string v : dim_names_) { + dim_names.PushBack(rj::Value{}.SetString(v.c_str(), allocator), allocator); + } + document.AddMember(rj::Value("dim_names", allocator), dim_names, allocator); + } + + if (!uniform_shape_.empty()) { + rj::Value uniform_shape(rj::kArrayType); + for (auto v : uniform_shape_) { + if (v.has_value()) { + uniform_shape.PushBack(v.value(), allocator); + } else { + uniform_shape.PushBack(rj::Value{}.SetNull(), allocator); + } + } + document.AddMember(rj::Value("uniform_shape", allocator), uniform_shape, allocator); + } + + rj::StringBuffer buffer; + rj::Writer writer(buffer); + document.Accept(writer); + return buffer.GetString(); +} + +Result> VariableShapeTensorType::Deserialize( + std::shared_ptr storage_type, const std::string& serialized_data) const { + if (storage_type->id() != Type::STRUCT) { + return Status::Invalid("Expected Struct storage type, got ", + storage_type->ToString()); + } + if (storage_type->num_fields() != 2) { + return Status::Invalid("Expected Struct storage type with 2 fields, got ", + storage_type->num_fields()); + } + if 
(storage_type->field(0)->type()->id() != Type::LIST) { + return Status::Invalid("Expected List storage type, got ", + storage_type->field(0)->type()->ToString()); + } + if (storage_type->field(1)->type()->id() != Type::FIXED_SIZE_LIST) { + return Status::Invalid("Expected FixedSizeList storage type, got ", + storage_type->field(1)->type()->ToString()); + } + if (internal::checked_cast(*storage_type->field(1)->type()) + .value_type() != int32()) { + return Status::Invalid("Expected FixedSizeList value type int32, got ", + storage_type->field(1)->type()->ToString()); + } + + const auto value_type = storage_type->field(0)->type()->field(0)->type(); + const uint32_t ndim = + internal::checked_cast(*storage_type->field(1)->type()) + .list_size(); + + rj::Document document; + if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError()) { + return Status::Invalid("Invalid serialized JSON data: ", serialized_data); + } + + std::vector permutation; + if (document.HasMember("permutation")) { + permutation.reserve(ndim); + for (const auto& x : document["permutation"].GetArray()) { + permutation.emplace_back(x.GetInt64()); + } + } + std::vector dim_names; + if (document.HasMember("dim_names")) { + dim_names.reserve(ndim); + for (const auto& x : document["dim_names"].GetArray()) { + dim_names.emplace_back(x.GetString()); + } + } + + std::vector> uniform_shape; + if (document.HasMember("uniform_shape")) { + uniform_shape.reserve(ndim); + for (const auto& x : document["uniform_shape"].GetArray()) { + if (x.IsNull()) { + uniform_shape.emplace_back(std::nullopt); + } else { + uniform_shape.emplace_back(x.GetInt64()); + } + } + } + + return VariableShapeTensorType::Make(value_type, ndim, permutation, dim_names, + uniform_shape); +} + +std::shared_ptr VariableShapeTensorType::MakeArray( + std::shared_ptr data) const { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK_EQ("arrow.variable_shape_tensor", + 
internal::checked_cast(*data->type).extension_name()); + return std::make_shared(data); +} + +Result> VariableShapeTensorType::MakeTensor( + const std::shared_ptr& scalar) { + const auto& tensor_scalar = internal::checked_cast(*scalar->value); + const auto& ext_type = + internal::checked_cast(*scalar->type); + + ARROW_ASSIGN_OR_RAISE(const auto data_scalar, tensor_scalar.field(0)); + ARROW_ASSIGN_OR_RAISE(const auto shape_scalar, tensor_scalar.field(1)); + ARROW_CHECK(tensor_scalar.is_valid); + const auto data_array = + internal::checked_pointer_cast(data_scalar)->value; + const auto shape_array = internal::checked_pointer_cast( + internal::checked_pointer_cast(shape_scalar)->value); + + const auto& value_type = + internal::checked_cast(*ext_type.value_type()); + + if (data_array->null_count() > 0) { + return Status::Invalid("Cannot convert data with nulls to Tensor."); + } + + auto permutation = ext_type.permutation(); + if (permutation.empty()) { + permutation.resize(ext_type.ndim()); + std::iota(permutation.begin(), permutation.end(), 0); + } + + ARROW_CHECK_EQ(shape_array->length(), ext_type.ndim()); + std::vector shape; + shape.reserve(ext_type.ndim()); + for (int64_t j = 0; j < static_cast(ext_type.ndim()); ++j) { + const auto size_value = shape_array->Value(j); + if (size_value < 0) { + return Status::Invalid("shape must have non-negative values"); + } + shape.push_back(std::move(size_value)); + } + + std::vector dim_names = ext_type.dim_names(); + if (!dim_names.empty()) { + internal::Permute(permutation, &dim_names); + } + + std::vector strides; + ARROW_RETURN_NOT_OK( + internal::ComputeStrides(ext_type.value_type(), shape, permutation, &strides)); + internal::Permute(permutation, &shape); + + const auto byte_width = value_type.byte_width(); + const auto start_position = data_array->offset() * byte_width; + const auto size = std::accumulate(shape.begin(), shape.end(), static_cast(1), + std::multiplies<>()); + ARROW_CHECK_EQ(size * byte_width, 
data_array->length() * byte_width); + ARROW_ASSIGN_OR_RAISE( + const auto buffer, + SliceBufferSafe(data_array->data()->buffers[1], start_position, size * byte_width)); + + return Tensor::Make(ext_type.value_type(), std::move(buffer), std::move(shape), + std::move(strides), std::move(dim_names)); +} + +Result> VariableShapeTensorType::Make( + const std::shared_ptr& value_type, const int32_t ndim, + const std::vector& permutation, const std::vector& dim_names, + const std::vector>& uniform_shape) { + if (!is_fixed_width(*value_type)) { + return Status::Invalid("Cannot convert non-fixed-width values to Tensor."); + } + + if (!dim_names.empty() && dim_names.size() != static_cast(ndim)) { + return Status::Invalid("dim_names size must match ndim. Expected: ", ndim, + " Got: ", dim_names.size()); + } + if (!uniform_shape.empty() && uniform_shape.size() != static_cast(ndim)) { + return Status::Invalid("uniform_shape size must match ndim. Expected: ", ndim, + " Got: ", uniform_shape.size()); + } + if (!uniform_shape.empty()) { + for (const auto& v : uniform_shape) { + if (v.has_value() && v.value() < 0) { + return Status::Invalid("uniform_shape must have non-negative values"); + } + } + } + if (!permutation.empty()) { + if (permutation.size() != static_cast(ndim)) { + return Status::Invalid("permutation size must match ndim. 
Expected: ", ndim, + " Got: ", permutation.size()); + } + RETURN_NOT_OK(internal::IsPermutationValid(permutation)); + } + + return std::make_shared( + value_type, std::move(ndim), std::move(permutation), std::move(dim_names), + std::move(uniform_shape)); +} + +std::shared_ptr variable_shape_tensor( + const std::shared_ptr& value_type, const int32_t ndim, + const std::vector permutation, const std::vector dim_names, + const std::vector> uniform_shape) { + auto maybe_type = + VariableShapeTensorType::Make(value_type, std::move(ndim), std::move(permutation), + std::move(dim_names), std::move(uniform_shape)); + ARROW_CHECK_OK(maybe_type.status()); + return maybe_type.MoveValueUnsafe(); +} + +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/variable_shape_tensor.h b/cpp/src/arrow/extension/variable_shape_tensor.h new file mode 100644 index 00000000000..7b3e14fbc7e --- /dev/null +++ b/cpp/src/arrow/extension/variable_shape_tensor.h @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/extension_type.h" + +namespace arrow { +namespace extension { + +class ARROW_EXPORT VariableShapeTensorArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +/// \brief Concrete type class for variable-shape Tensor data. +/// This is a canonical arrow extension type. +/// See: https://arrow.apache.org/docs/format/CanonicalExtensions.html +class ARROW_EXPORT VariableShapeTensorType : public ExtensionType { + public: + VariableShapeTensorType(const std::shared_ptr& value_type, const int32_t ndim, + const std::vector permutation = {}, + const std::vector dim_names = {}, + const std::vector> uniform_shape = {}) + : ExtensionType(struct_({::arrow::field("data", list(value_type)), + ::arrow::field("shape", fixed_size_list(int32(), ndim))})), + value_type_(value_type), + ndim_(std::move(ndim)), + permutation_(std::move(permutation)), + dim_names_(std::move(dim_names)), + uniform_shape_(std::move(uniform_shape)) {} + + std::string extension_name() const override { return "arrow.variable_shape_tensor"; } + std::string ToString(bool show_metadata = false) const override; + + /// Number of dimensions of tensor elements + int32_t ndim() const { return ndim_; } + + /// Value type of tensor elements + const std::shared_ptr& value_type() const { return value_type_; } + + /// Permutation mapping from logical to physical memory layout of tensor elements + const std::vector& permutation() const { return permutation_; } + + /// Dimension names of tensor elements. Dimensions are ordered physically. + const std::vector& dim_names() const { return dim_names_; } + + /// Shape of uniform dimensions. 
+ const std::vector>& uniform_shape() const { + return uniform_shape_; + } + + bool ExtensionEquals(const ExtensionType& other) const override; + + std::string Serialize() const override; + + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized_data) const override; + + /// Create a VariableShapeTensorArray from ArrayData + std::shared_ptr MakeArray(std::shared_ptr data) const override; + + /// \brief Convert an ExtensionScalar to a Tensor + /// + /// This method will return a Tensor from ExtensionScalar with strides derived + /// from shape and permutation stored. Shape and dim_names will be permuted + /// according to permutation stored in the VariableShapeTensorType. + static Result> MakeTensor( + const std::shared_ptr&); + + /// \brief Create a VariableShapeTensorType instance + static Result> Make( + const std::shared_ptr& value_type, const int32_t ndim, + const std::vector& permutation = {}, + const std::vector& dim_names = {}, + const std::vector>& uniform_shape = {}); + + private: + std::shared_ptr storage_type_; + std::shared_ptr value_type_; + int32_t ndim_; + std::vector permutation_; + std::vector dim_names_; + std::vector> uniform_shape_; +}; + +/// \brief Return a VariableShapeTensorType instance. 
+ARROW_EXPORT std::shared_ptr variable_shape_tensor( + const std::shared_ptr& value_type, const int32_t ndim, + const std::vector permutation = {}, + const std::vector dim_names = {}, + const std::vector> uniform_shape = {}); + +} // namespace extension +} // namespace arrow diff --git a/cpp/src/arrow/extension_type.cc b/cpp/src/arrow/extension_type.cc index 555ffe0156a..ce88c951741 100644 --- a/cpp/src/arrow/extension_type.cc +++ b/cpp/src/arrow/extension_type.cc @@ -31,6 +31,7 @@ #ifdef ARROW_JSON # include "arrow/extension/fixed_shape_tensor.h" # include "arrow/extension/opaque.h" +# include "arrow/extension/variable_shape_tensor.h" #endif #include "arrow/extension/json.h" #include "arrow/extension/uuid.h" @@ -155,6 +156,7 @@ static void CreateGlobalRegistry() { #ifdef ARROW_JSON ext_types.push_back(extension::fixed_shape_tensor(int64(), {})); ext_types.push_back(extension::opaque(null(), "", "")); + ext_types.push_back(extension::variable_shape_tensor(int64(), 0)); #endif for (const auto& ext_type : ext_types) { diff --git a/cpp/src/arrow/extension_type_test.cc b/cpp/src/arrow/extension_type_test.cc index 23c1ff731da..0b256f1b45b 100644 --- a/cpp/src/arrow/extension_type_test.cc +++ b/cpp/src/arrow/extension_type_test.cc @@ -40,10 +40,10 @@ #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging_internal.h" -namespace arrow { - using arrow::ipc::test::RoundtripBatch; +namespace arrow { + class Parametric1Array : public ExtensionArray { public: using ExtensionArray::ExtensionArray; diff --git a/docs/source/format/CanonicalExtensions.rst b/docs/source/format/CanonicalExtensions.rst index 8608a6388e0..5a4131d8eb8 100644 --- a/docs/source/format/CanonicalExtensions.rst +++ b/docs/source/format/CanonicalExtensions.rst @@ -248,8 +248,8 @@ Variable shape tensor This means the logical tensor has names [z, x, y] and shape [30, 10, 20]. .. 
note:: - Values inside each **data** tensor element are stored in row-major/C-contiguous - order according to the corresponding **shape**. + Elements in a variable shape tensor extension array are stored + in row-major/C-contiguous order. .. _json_extension: diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index ebac37e862b..4344aef3fbb 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -1482,6 +1482,17 @@ def test_tensor_class_methods(np_type_str): assert result.to_tensor().shape == (1, 3, 2, 2) assert result.to_tensor().strides == (12 * bw, 1 * bw, 6 * bw, 2 * bw) + tensor_type = pa.fixed_shape_tensor(arrow_type, [2, 2, 3], permutation=[2, 1, 0]) + result = pa.ExtensionArray.from_storage(tensor_type, storage) + expected = as_strided(flat_arr, shape=(1, 3, 2, 2), + strides=(bw * 12, bw, bw * 3, bw * 6)) + np.testing.assert_array_equal(result.to_numpy_ndarray(), expected) + + assert result.type.permutation == [2, 1, 0] + assert result.type.shape == [2, 2, 3] + assert result.to_tensor().shape == (1, 3, 2, 2) + assert result.to_tensor().strides == (12 * bw, 1 * bw, 3 * bw, 6 * bw) + @pytest.mark.numpy @pytest.mark.parametrize("np_type_str", ("int8", "int64", "float32")) @@ -1712,7 +1723,7 @@ def test_tensor_type_is_picklable(pickle_module): 'fixed_shape_tensor[value_type=int64, shape=[2,2,3], dim_names=[C,H,W]]' ) ]) -def test_tensor_type_str(tensor_type, text): +def test_tensor_type_str(tensor_type, text, pickle_module): tensor_type_str = tensor_type.__str__() assert text in tensor_type_str