Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "arrow/array/array_nested.h"
#include "arrow/array/array_primitive.h"
#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
#include "arrow/tensor.h"
#include "arrow/util/int_util_overflow.h"
#include "arrow/util/logging.h"
#include "arrow/util/sort.h"
Expand All @@ -33,8 +34,52 @@
namespace rj = arrow::rapidjson;

namespace arrow {

namespace extension {

namespace {

// Computes byte strides for a tensor of the given `shape` and element `type`,
// appending them to `*strides`.
//
// With an empty `permutation` the work is delegated to
// internal::ComputeRowMajorStrides. Otherwise the strides are derived from the
// dimensions listed in `permutation` (index 0 is skipped in both passes) and then
// reordered with internal::Permute.
//
// Returns Status::Invalid if the accumulated byte extent overflows int64_t.
Status ComputeStrides(const FixedWidthType& type, const std::vector<int64_t>& shape,
                      const std::vector<int64_t>& permutation,
                      std::vector<int64_t>* strides) {
  // No permutation: plain row-major layout.
  if (permutation.empty()) {
    return internal::ComputeRowMajorStrides(type, shape, strides);
  }

  const int byte_width = type.byte_width();

  // Byte extent spanned by every permuted dimension other than index 0. It stays
  // zero when the shape is empty or its first dimension is not positive.
  int64_t extent = 0;
  const bool has_leading_extent = !shape.empty() && shape.front() > 0;
  if (has_leading_extent) {
    extent = byte_width;
    for (const int64_t dim : permutation) {
      if (dim <= 0) {
        continue;
      }
      if (internal::MultiplyWithOverflow(extent, shape[dim], &extent)) {
        return Status::Invalid(
            "Strides computed from shape would not fit in 64-bit integer");
      }
    }
  }

  // Zero-sized tensor: every stride is simply the element width. Returning here
  // also guarantees that the divisions below never see a zero dimension.
  if (extent == 0) {
    strides->assign(shape.size(), byte_width);
    return Status::OK();
  }

  // Peel one dimension off the extent at a time to get each successive stride.
  strides->push_back(extent);
  for (const int64_t dim : permutation) {
    if (dim <= 0) {
      continue;
    }
    extent /= shape[dim];
    strides->push_back(extent);
  }
  internal::Permute(permutation, strides);

  return Status::OK();
}

} // namespace

bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
if (extension_name() != other.extension_name()) {
return false;
Expand Down Expand Up @@ -140,6 +185,132 @@ std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
return std::make_shared<ExtensionArray>(data);
}

// Wraps the data buffer of `tensor` in a FixedShapeTensorArray.
//
// The first tensor dimension becomes the array length. The remaining dimensions,
// ordered by decreasing stride, form the cell shape; that ordering is stored as
// the extension type's permutation, and dim names (if any) are reordered to match.
//
// Returns Status::Invalid if the tensor has no dimensions, if its largest stride
// is not on the first dimension, or if the cell size cannot be represented as a
// 32-bit list size. Returns Status::NotImplemented for element types other than
// the numeric types handled in the switch below.
Result<std::shared_ptr<FixedShapeTensorArray>> FixedShapeTensorArray::FromTensor(
    const std::shared_ptr<Tensor>& tensor) {
  // A tensor without dimensions has no first axis to use as the array length;
  // reject it before indexing into the (empty) permutation.
  if (tensor->shape().empty()) {
    return Status::Invalid(
        "Cannot convert a tensor without dimensions to a FixedShapeTensorArray");
  }

  // Dimension indices ordered from largest to smallest stride.
  auto permutation = internal::ArgSort(tensor->strides(), std::greater<>());
  if (permutation[0] != 0) {
    return Status::Invalid(
        "Only first-major tensors can be zero-copy converted to arrays");
  }
  // Drop the leading (array length) dimension; what remains describes one cell.
  permutation.erase(permutation.begin());

  std::vector<int64_t> cell_shape;
  cell_shape.reserve(permutation.size());
  for (const int64_t i : permutation) {
    cell_shape.emplace_back(tensor->shape()[i]);
  }

  std::vector<std::string> dim_names;
  if (!tensor->dim_names().empty()) {
    dim_names.reserve(permutation.size());
    for (const int64_t i : permutation) {
      dim_names.emplace_back(tensor->dim_names()[i]);
    }
  }

  // Re-base the indices so they refer to cell dimensions (tensor dim k -> k - 1).
  for (int64_t& i : permutation) {
    --i;
  }

  auto ext_type = internal::checked_pointer_cast<ExtensionType>(
      fixed_shape_tensor(tensor->type(), cell_shape, permutation, dim_names));

  // Wrap the tensor's buffer in a primitive array of the matching element type.
  std::shared_ptr<Array> value_array;
  switch (tensor->type_id()) {
    case Type::UINT8: {
      value_array = std::make_shared<UInt8Array>(tensor->size(), tensor->data());
      break;
    }
    case Type::INT8: {
      value_array = std::make_shared<Int8Array>(tensor->size(), tensor->data());
      break;
    }
    case Type::UINT16: {
      value_array = std::make_shared<UInt16Array>(tensor->size(), tensor->data());
      break;
    }
    case Type::INT16: {
      value_array = std::make_shared<Int16Array>(tensor->size(), tensor->data());
      break;
    }
    case Type::UINT32: {
      value_array = std::make_shared<UInt32Array>(tensor->size(), tensor->data());
      break;
    }
    case Type::INT32: {
      value_array = std::make_shared<Int32Array>(tensor->size(), tensor->data());
      break;
    }
    case Type::UINT64: {
      // Must be the unsigned array type so the storage matches the tensor's type.
      value_array = std::make_shared<UInt64Array>(tensor->size(), tensor->data());
      break;
    }
    case Type::INT64: {
      value_array = std::make_shared<Int64Array>(tensor->size(), tensor->data());
      break;
    }
    case Type::HALF_FLOAT: {
      value_array = std::make_shared<HalfFloatArray>(tensor->size(), tensor->data());
      break;
    }
    case Type::FLOAT: {
      value_array = std::make_shared<FloatArray>(tensor->size(), tensor->data());
      break;
    }
    case Type::DOUBLE: {
      value_array = std::make_shared<DoubleArray>(tensor->size(), tensor->data());
      break;
    }
    default: {
      return Status::NotImplemented("Unsupported tensor type: ",
                                    tensor->type()->ToString());
    }
  }

  // Number of values per cell, taken from the cell shape rather than from
  // size() / shape()[0] so that a zero-length tensor does not divide by zero.
  int64_t cell_size = 1;
  for (const int64_t dim : cell_shape) {
    if (internal::MultiplyWithOverflow(cell_size, dim, &cell_size)) {
      return Status::Invalid("Tensor cell size would not fit in 64-bit integer");
    }
  }
  const auto list_size = static_cast<int32_t>(cell_size);
  if (static_cast<int64_t>(list_size) != cell_size) {
    return Status::Invalid("Tensor cell size ", cell_size,
                           " does not fit in a 32-bit list size");
  }

  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> arr,
                        FixedSizeListArray::FromArrays(value_array, list_size));
  std::shared_ptr<Array> ext_arr = ExtensionType::WrapArray(ext_type, arr);
  // NOTE(review): the dynamic type produced by WrapArray depends on
  // FixedShapeTensorType::MakeArray; confirm it is layout-compatible with
  // FixedShapeTensorArray before relying on this cast.
  return std::reinterpret_pointer_cast<FixedShapeTensorArray>(ext_arr);
}

// Converts this array of n-dimensional tensors into a single (n+1)-dimensional
// Tensor whose first dimension is the array's length.
//
// The cell shape and dim names are reordered according to the extension type's
// permutation; an empty permutation is treated as the identity. When dim names
// are present, the added first dimension gets an empty name.
//
// Returns Status::Invalid if the storage value type is not fixed-width, and
// propagates any error from stride computation, Flatten() or Tensor::Make().
const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
  // To convert an array of n dimensional tensors to a n+1 dimensional tensor we
  // interpret the array's length as the first dimension of the new tensor.

  auto ext_arr = internal::checked_pointer_cast<FixedSizeListArray>(this->storage());
  auto ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(this->type());
  ARROW_RETURN_IF(!is_fixed_width(*ext_arr->value_type()),
                  Status::Invalid(ext_arr->value_type()->ToString(),
                                  " is not valid data type for a tensor"));

  // Work on a copy: the indices are shifted below to make room for the new
  // leading dimension.
  auto permutation = ext_type->permutation();
  if (permutation.empty()) {
    // An empty permutation means "no permutation". Expand it to the identity so
    // the cell dimensions (and their names) are not silently dropped.
    const auto ndim = static_cast<int64_t>(ext_type->shape().size());
    permutation.reserve(ext_type->shape().size());
    for (int64_t axis = 0; axis < ndim; ++axis) {
      permutation.push_back(axis);
    }
  }

  std::vector<std::string> dim_names;
  if (!ext_type->dim_names().empty()) {
    dim_names.reserve(permutation.size() + 1);
    // The array-length dimension has no name of its own.
    dim_names.emplace_back("");
    for (const int64_t i : permutation) {
      dim_names.emplace_back(ext_type->dim_names()[i]);
    }
  }

  std::vector<int64_t> shape;
  shape.reserve(permutation.size() + 1);
  shape.emplace_back(this->length());
  for (int64_t& i : permutation) {
    shape.emplace_back(ext_type->shape()[i]);
    // Shift cell dimension k to tensor dimension k + 1.
    ++i;
  }
  permutation.insert(permutation.begin(), 1, 0);

  std::vector<int64_t> tensor_strides;
  auto value_type = internal::checked_pointer_cast<FixedWidthType>(ext_arr->value_type());
  ARROW_RETURN_NOT_OK(ComputeStrides(*value_type, shape, permutation, &tensor_strides));
  // NOTE(review): buffers[1] is used without accounting for any offset on the
  // flattened values array; confirm this is correct for sliced arrays.
  ARROW_ASSIGN_OR_RAISE(auto flattened, ext_arr->Flatten());
  ARROW_ASSIGN_OR_RAISE(
      auto tensor, Tensor::Make(ext_arr->value_type(), flattened->data()->buffers[1],
                                shape, tensor_strides, dim_names));
  return tensor;
}

Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation, const std::vector<std::string>& dim_names) {
Expand All @@ -157,6 +328,17 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
shape, permutation, dim_names);
}

const std::vector<int64_t>& FixedShapeTensorType::strides() {
if (strides_.empty()) {
auto value_type = internal::checked_pointer_cast<FixedWidthType>(this->value_type_);
std::vector<int64_t> tensor_strides;
ARROW_CHECK_OK(ComputeStrides(*value_type.get(), this->shape(), this->permutation(),
&tensor_strides));
strides_ = tensor_strides;
}
return strides_;
}

std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& value_type,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation,
Expand Down
26 changes: 26 additions & 0 deletions cpp/src/arrow/extension/fixed_shape_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,26 @@ namespace extension {
/// \brief Extension array whose elements are tensors sharing one fixed shape.
///
/// Constructors are inherited from ExtensionArray. The class adds conversions
/// between the array and a single arrow::Tensor.
class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
public:
using ExtensionArray::ExtensionArray;

/// \brief Create a FixedShapeTensorArray from a Tensor
///
/// This method will create a FixedShapeTensorArray from a Tensor, taking its first
/// dimension as the number of elements in the resulting array and the remaining
/// dimensions as the shape of the individual tensors. If Tensor provides strides,
/// they will be used to determine dimension permutation. Otherwise, row-major layout
/// (i.e. no permutation) will be assumed.
///
/// \param[in] tensor The Tensor to convert to a FixedShapeTensorArray
/// \return The resulting array, or an error Status. The implementation returns
/// Invalid when the tensor's largest stride is not on its first dimension, and
/// NotImplemented for element types it does not handle.
static Result<std::shared_ptr<FixedShapeTensorArray>> FromTensor(
const std::shared_ptr<Tensor>& tensor);

/// \brief Create a Tensor from FixedShapeTensorArray
///
/// This method will create a Tensor from a FixedShapeTensorArray, setting its first
/// dimension as length equal to the FixedShapeTensorArray's length and the remaining
/// dimensions as the FixedShapeTensorType's shape. Shape and dim_names will be
/// permuted according to permutation stored in the FixedShapeTensorType metadata.
///
/// \return The resulting Tensor, or an error Status. The implementation returns
/// Invalid when the storage value type is not fixed-width.
const Result<std::shared_ptr<Tensor>> ToTensor() const;
};

/// \brief Concrete type class for constant-size Tensor data.
Expand Down Expand Up @@ -51,6 +71,11 @@ class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
/// Value type of tensor elements
const std::shared_ptr<DataType> value_type() const { return value_type_; }

/// Strides of tensor elements, i.e. the offset in bytes between adjacent
/// elements along each dimension. If the permutation is non-empty, the strides
/// are computed from the permuted shape of the tensor element.
const std::vector<int64_t>& strides();

/// Permutation mapping from logical to physical memory layout of tensor elements
const std::vector<int64_t>& permutation() const { return permutation_; }

Expand Down Expand Up @@ -78,6 +103,7 @@ class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
std::shared_ptr<DataType> storage_type_;
std::shared_ptr<DataType> value_type_;
std::vector<int64_t> shape_;
std::vector<int64_t> strides_;
std::vector<int64_t> permutation_;
std::vector<std::string> dim_names_;
};
Expand Down
Loading