Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,8 @@ if(ARROW_JSON)
arrow_add_object_library(ARROW_JSON
extension/fixed_shape_tensor.cc
extension/opaque.cc
extension/tensor_internal.cc
extension/variable_shape_tensor.cc
json/options.cc
json/chunked_builder.cc
json/chunker.cc
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/extension/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
set(CANONICAL_EXTENSION_TESTS bool8_test.cc uuid_test.cc)

if(ARROW_JSON)
list(APPEND CANONICAL_EXTENSION_TESTS fixed_shape_tensor_test.cc opaque_test.cc)
list(APPEND CANONICAL_EXTENSION_TESTS tensor_extension_array_test.cc opaque_test.cc)
endif()

add_arrow_test(test
Expand Down
61 changes: 7 additions & 54 deletions cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,52 +37,7 @@

namespace rj = arrow::rapidjson;

namespace arrow {

namespace extension {

namespace {

Status ComputeStrides(const FixedWidthType& type, const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation,
std::vector<int64_t>* strides) {
if (permutation.empty()) {
return internal::ComputeRowMajorStrides(type, shape, strides);
}

const int byte_width = type.byte_width();

int64_t remaining = 0;
if (!shape.empty() && shape.front() > 0) {
remaining = byte_width;
for (auto i : permutation) {
if (i > 0) {
if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
return Status::Invalid(
"Strides computed from shape would not fit in 64-bit integer");
}
}
}
}

if (remaining == 0) {
strides->assign(shape.size(), byte_width);
return Status::OK();
}

strides->push_back(remaining);
for (auto i : permutation) {
if (i > 0) {
remaining /= shape[i];
strides->push_back(remaining);
}
}
internal::Permute(permutation, strides);

return Status::OK();
}

} // namespace
namespace arrow::extension {

bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
if (extension_name() != other.extension_name()) {
Expand Down Expand Up @@ -237,7 +192,8 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
}

std::vector<int64_t> strides;
RETURN_NOT_OK(ComputeStrides(value_type, shape, permutation, &strides));
RETURN_NOT_OK(
internal::ComputeStrides(ext_type.value_type(), shape, permutation, &strides));
const auto start_position = array->offset() * byte_width;
const auto size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
std::multiplies<>());
Expand Down Expand Up @@ -376,9 +332,8 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
internal::Permute<int64_t>(permutation, &shape);

std::vector<int64_t> tensor_strides;
const auto* fw_value_type = internal::checked_cast<FixedWidthType*>(value_type.get());
ARROW_RETURN_NOT_OK(
ComputeStrides(*fw_value_type, shape, permutation, &tensor_strides));
internal::ComputeStrides(value_type, shape, permutation, &tensor_strides));

const auto& raw_buffer = this->storage()->data()->child_data[0]->buffers[1];
ARROW_ASSIGN_OR_RAISE(
Expand Down Expand Up @@ -412,10 +367,9 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(

const std::vector<int64_t>& FixedShapeTensorType::strides() {
if (strides_.empty()) {
auto value_type = internal::checked_cast<FixedWidthType*>(this->value_type_.get());
std::vector<int64_t> tensor_strides;
ARROW_CHECK_OK(
ComputeStrides(*value_type, this->shape(), this->permutation(), &tensor_strides));
ARROW_CHECK_OK(internal::ComputeStrides(this->value_type_, this->shape(),
this->permutation(), &tensor_strides));
strides_ = tensor_strides;
}
return strides_;
Expand All @@ -430,5 +384,4 @@ std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& va
return maybe_type.MoveValueUnsafe();
}

} // namespace extension
} // namespace arrow
} // namespace arrow::extension
6 changes: 2 additions & 4 deletions cpp/src/arrow/extension/fixed_shape_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

#include "arrow/extension_type.h"

namespace arrow {
namespace extension {
namespace arrow::extension {

class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
public:
Expand Down Expand Up @@ -126,5 +125,4 @@ ARROW_EXPORT std::shared_ptr<DataType> fixed_shape_tensor(
const std::vector<int64_t>& permutation = {},
const std::vector<std::string>& dim_names = {});

} // namespace extension
} // namespace arrow
} // namespace arrow::extension
Loading