Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 34 additions & 19 deletions cpp/src/arrow/ipc/metadata_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1245,39 +1245,54 @@ Status GetSparseTensorMetadata(const Buffer& metadata, std::shared_ptr<DataType>
}
int ndim = static_cast<int>(sparse_tensor->shape()->size());

for (int i = 0; i < ndim; ++i) {
auto dim = sparse_tensor->shape()->Get(i);
if (shape || dim_names) {
for (int i = 0; i < ndim; ++i) {
auto dim = sparse_tensor->shape()->Get(i);

shape->push_back(dim->size());
auto fb_name = dim->name();
if (fb_name == 0) {
dim_names->push_back("");
} else {
dim_names->push_back(fb_name->str());
if (shape) {
shape->push_back(dim->size());
}

if (dim_names) {
auto fb_name = dim->name();
if (fb_name == 0) {
dim_names->push_back("");
} else {
dim_names->push_back(fb_name->str());
}
}
}
}

*non_zero_length = sparse_tensor->non_zero_length();
if (non_zero_length) {
*non_zero_length = sparse_tensor->non_zero_length();
}

switch (sparse_tensor->sparseIndex_type()) {
case flatbuf::SparseTensorIndex_SparseTensorIndexCOO:
*sparse_tensor_format_id = SparseTensorFormat::COO;
break;
if (sparse_tensor_format_id) {
switch (sparse_tensor->sparseIndex_type()) {
case flatbuf::SparseTensorIndex_SparseTensorIndexCOO:
*sparse_tensor_format_id = SparseTensorFormat::COO;
break;

case flatbuf::SparseTensorIndex_SparseMatrixIndexCSR:
*sparse_tensor_format_id = SparseTensorFormat::CSR;
break;
case flatbuf::SparseTensorIndex_SparseMatrixIndexCSR:
*sparse_tensor_format_id = SparseTensorFormat::CSR;
break;

default:
return Status::Invalid("Unrecognized sparse index type");
default:
return Status::Invalid("Unrecognized sparse index type");
}
}

auto type_data = sparse_tensor->type();
if (type_data == nullptr) {
return Status::IOError(
"Type-pointer in custom metadata of flatbuffer-encoded SparseTensor is null.");
}
return ConcreteTypeFromFlatbuffer(sparse_tensor->type_type(), type_data, {}, type);
if (type) {
return ConcreteTypeFromFlatbuffer(sparse_tensor->type_type(), type_data, {}, type);
} else {
return Status::OK();
}
}

} // namespace internal
Expand Down
20 changes: 12 additions & 8 deletions cpp/src/arrow/ipc/read_write_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1188,8 +1188,8 @@ void TestSparseTensorRoundTrip<IndexValueType>::CheckSparseTensorRoundTrip(

ASSERT_OK(mmap_->Seek(0));

ASSERT_OK(WriteSparseTensor(sparse_tensor, mmap_.get(), &metadata_length, &body_length,
default_memory_pool()));
ASSERT_OK(
WriteSparseTensor(sparse_tensor, mmap_.get(), &metadata_length, &body_length));

const auto& sparse_index =
checked_cast<const SparseCOOIndex&>(*sparse_tensor.sparse_index());
Expand Down Expand Up @@ -1224,8 +1224,8 @@ void TestSparseTensorRoundTrip<IndexValueType>::CheckSparseTensorRoundTrip(

ASSERT_OK(mmap_->Seek(0));

ASSERT_OK(WriteSparseTensor(sparse_tensor, mmap_.get(), &metadata_length, &body_length,
default_memory_pool()));
ASSERT_OK(
WriteSparseTensor(sparse_tensor, mmap_.get(), &metadata_length, &body_length));

const auto& sparse_index =
checked_cast<const SparseCSRIndex&>(*sparse_tensor.sparse_index());
Expand Down Expand Up @@ -1285,8 +1285,10 @@ TYPED_TEST_P(TestSparseTensorRoundTrip, WithSparseCOOIndexRowMajor) {
0, 2, 0, 0, 2, 2, 1, 0, 1, 1, 0, 3,
1, 1, 0, 1, 1, 2, 1, 2, 1, 1, 2, 3};
const int sizeof_index_value = sizeof(c_index_value_type);
auto si = this->MakeSparseCOOIndex(
{12, 3}, {sizeof_index_value * 3, sizeof_index_value}, coords_values);
std::shared_ptr<SparseCOOIndex> si;
ASSERT_OK(SparseCOOIndex::Make(TypeTraits<IndexValueType>::type_singleton(), {12, 3},
{sizeof_index_value * 3, sizeof_index_value},
Buffer::Wrap(coords_values), &si));

std::vector<int64_t> shape = {2, 3, 4};
std::vector<std::string> dim_names = {"foo", "bar", "baz"};
Expand Down Expand Up @@ -1328,8 +1330,10 @@ TYPED_TEST_P(TestSparseTensorRoundTrip, WithSparseCOOIndexColumnMajor) {
0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2,
0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3};
const int sizeof_index_value = sizeof(c_index_value_type);
auto si = this->MakeSparseCOOIndex(
{12, 3}, {sizeof_index_value, sizeof_index_value * 12}, coords_values);
std::shared_ptr<SparseCOOIndex> si;
ASSERT_OK(SparseCOOIndex::Make(TypeTraits<IndexValueType>::type_singleton(), {12, 3},
{sizeof_index_value, sizeof_index_value * 12},
Buffer::Wrap(coords_values), &si));

std::vector<int64_t> shape = {2, 3, 4};
std::vector<std::string> dim_names = {"foo", "bar", "baz"};
Expand Down
142 changes: 130 additions & 12 deletions cpp/src/arrow/ipc/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -933,32 +933,150 @@ Status MakeSparseTensorWithSparseCSRIndex(
return Status::OK();
}

} // namespace

Status ReadSparseTensor(const Buffer& metadata, io::RandomAccessFile* file,
std::shared_ptr<SparseTensor>* out) {
std::shared_ptr<DataType> type;
std::vector<int64_t> shape;
std::vector<std::string> dim_names;
int64_t non_zero_length;
SparseTensorFormat::type sparse_tensor_format_id;

Status ReadSparseTensorMetadata(const Buffer& metadata,
std::shared_ptr<DataType>* out_type,
std::vector<int64_t>* out_shape,
std::vector<std::string>* out_dim_names,
int64_t* out_non_zero_length,
SparseTensorFormat::type* out_format_id,
const flatbuf::SparseTensor** out_fb_sparse_tensor,
const flatbuf::Buffer** out_buffer) {
RETURN_NOT_OK(internal::GetSparseTensorMetadata(
metadata, &type, &shape, &dim_names, &non_zero_length, &sparse_tensor_format_id));
metadata, out_type, out_shape, out_dim_names, out_non_zero_length, out_format_id));

const flatbuf::Message* message;
RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));

auto sparse_tensor = message->header_as_SparseTensor();
if (sparse_tensor == nullptr) {
return Status::IOError(
"Header-type of flatbuffer-encoded Message is not SparseTensor.");
}
const flatbuf::Buffer* buffer = sparse_tensor->data();
*out_fb_sparse_tensor = sparse_tensor;

auto buffer = sparse_tensor->data();
if (!BitUtil::IsMultipleOf8(buffer->offset())) {
return Status::Invalid(
"Buffer of sparse index data did not start on 8-byte aligned offset: ",
buffer->offset());
}
*out_buffer = buffer;

return Status::OK();
}

} // namespace

namespace internal {

namespace {

/// \brief Map a sparse tensor format to the number of body buffers it uses.
///
/// \param[in] format_id the sparse tensor format
/// \param[out] buffer_count set to 2 for COO and to 3 for CSR; left untouched
/// when the format is not recognized
/// \return Status::OK() for COO and CSR, Status::Invalid otherwise
Status GetSparseTensorBodyBufferCount(SparseTensorFormat::type format_id,
                                      size_t* buffer_count) {
  if (format_id == SparseTensorFormat::COO) {
    *buffer_count = 2;
    return Status::OK();
  }

  if (format_id == SparseTensorFormat::CSR) {
    *buffer_count = 3;
    return Status::OK();
  }

  return Status::Invalid("Unrecognized sparse tensor format");
}

/// \brief Verify that a payload carries the number of body buffers required
/// by the given sparse tensor format.
///
/// \param[in] payload the payload whose body_buffers are counted
/// \param[in] sparse_tensor_format_id the format the payload claims to hold
/// \return Status::OK() when the counts match; Status::Invalid when they do
/// not or when the format itself is not recognized
Status CheckSparseTensorBodyBufferCount(
    const IpcPayload& payload, SparseTensorFormat::type sparse_tensor_format_id) {
  size_t required_count = 0;
  RETURN_NOT_OK(
      GetSparseTensorBodyBufferCount(sparse_tensor_format_id, &required_count));

  const bool count_matches = payload.body_buffers.size() == required_count;
  return count_matches
             ? Status::OK()
             : Status::Invalid("Invalid body buffer count for a sparse tensor");
}

} // namespace

/// \brief Determine how many body buffers accompany the sparse tensor
/// described by `metadata`.
///
/// Only the sparse format is requested from the metadata; every other output
/// of GetSparseTensorMetadata is passed as nullptr.
///
/// \param[in] metadata a Buffer containing the sparse tensor metadata
/// \param[out] buffer_count the number of body buffers for the decoded format
/// \return Status
Status ReadSparseTensorBodyBufferCount(const Buffer& metadata, size_t* buffer_count) {
  SparseTensorFormat::type sparse_format;
  RETURN_NOT_OK(internal::GetSparseTensorMetadata(
      metadata, /*type=*/nullptr, /*shape=*/nullptr, /*dim_names=*/nullptr,
      /*non_zero_length=*/nullptr, &sparse_format));

  return GetSparseTensorBodyBufferCount(sparse_format, buffer_count);
}

/// \brief Rebuild a SparseTensor from an in-memory IpcPayload.
///
/// The payload metadata is decoded first, then the number of body buffers is
/// validated against the decoded sparse format before any buffer is indexed.
/// For COO, body_buffers[0] feeds the sparse index and body_buffers[1] is
/// passed on as the tensor data. For CSR, body_buffers[0] and body_buffers[1]
/// feed the sparse index and body_buffers[2] is passed on as the tensor data.
///
/// \param[in] payload an IpcPayload containing a serialized SparseTensor
/// \param[out] out the reconstructed SparseTensor
/// \return Status::Invalid when the payload has no metadata, the body buffer
/// count does not match the format, the CSR index types differ, or the format
/// is unsupported; otherwise the status of the underlying decoding steps
Status ReadSparseTensorPayload(const IpcPayload& payload,
                               std::shared_ptr<SparseTensor>* out) {
  // payload.metadata is dereferenced below; reject a missing buffer explicitly
  // instead of invoking undefined behavior.
  if (payload.metadata == nullptr) {
    return Status::Invalid("IpcPayload of a sparse tensor has no metadata");
  }

  std::shared_ptr<DataType> type;
  std::vector<int64_t> shape;
  std::vector<std::string> dim_names;
  int64_t non_zero_length = 0;
  SparseTensorFormat::type sparse_tensor_format_id;
  const flatbuf::SparseTensor* sparse_tensor = nullptr;
  // Required as an output argument of ReadSparseTensorMetadata; the body
  // buffers are taken from the payload here, so it is not used afterwards.
  const flatbuf::Buffer* buffer = nullptr;

  RETURN_NOT_OK(ReadSparseTensorMetadata(*payload.metadata, &type, &shape, &dim_names,
                                         &non_zero_length, &sparse_tensor_format_id,
                                         &sparse_tensor, &buffer));

  // Must run before any payload.body_buffers[i] access below.
  RETURN_NOT_OK(CheckSparseTensorBodyBufferCount(payload, sparse_tensor_format_id));

  switch (sparse_tensor_format_id) {
    case SparseTensorFormat::COO: {
      std::shared_ptr<SparseCOOIndex> sparse_index;
      std::shared_ptr<DataType> indices_type;
      RETURN_NOT_OK(internal::GetSparseCOOIndexMetadata(
          sparse_tensor->sparseIndex_as_SparseTensorIndexCOO(), &indices_type));
      RETURN_NOT_OK(SparseCOOIndex::Make(indices_type, shape, non_zero_length,
                                         payload.body_buffers[0], &sparse_index));
      return MakeSparseTensorWithSparseCOOIndex(type, shape, dim_names, sparse_index,
                                                non_zero_length, payload.body_buffers[1],
                                                out);
    }

    case SparseTensorFormat::CSR: {
      std::shared_ptr<SparseCSRIndex> sparse_index;
      std::shared_ptr<DataType> indptr_type;
      std::shared_ptr<DataType> indices_type;
      RETURN_NOT_OK(internal::GetSparseCSRIndexMetadata(
          sparse_tensor->sparseIndex_as_SparseMatrixIndexCSR(), &indptr_type,
          &indices_type));
      // The types come from the serialized metadata, so a mismatch is a data
      // error and is reported through Status rather than aborting the process.
      // Same identity comparison as before.
      // NOTE(review): this assumes equal types are represented by the same
      // DataType instance -- confirm against GetSparseCSRIndexMetadata.
      if (indptr_type != indices_type) {
        return Status::Invalid(
            "indptr and indices of a sparse CSR index must have the same type");
      }
      RETURN_NOT_OK(SparseCSRIndex::Make(indices_type, shape, non_zero_length,
                                         payload.body_buffers[0], payload.body_buffers[1],
                                         &sparse_index));
      return MakeSparseTensorWithSparseCSRIndex(type, shape, dim_names, sparse_index,
                                                non_zero_length, payload.body_buffers[2],
                                                out);
    }

    default:
      return Status::Invalid("Unsupported sparse index format");
  }
}

} // namespace internal

Status ReadSparseTensor(const Buffer& metadata, io::RandomAccessFile* file,
std::shared_ptr<SparseTensor>* out) {
std::shared_ptr<DataType> type;
std::vector<int64_t> shape;
std::vector<std::string> dim_names;
int64_t non_zero_length;
SparseTensorFormat::type sparse_tensor_format_id;
const flatbuf::SparseTensor* sparse_tensor;
const flatbuf::Buffer* buffer;

RETURN_NOT_OK(ReadSparseTensorMetadata(metadata, &type, &shape, &dim_names,
&non_zero_length, &sparse_tensor_format_id,
&sparse_tensor, &buffer));

std::shared_ptr<Buffer> data;
RETURN_NOT_OK(file->ReadAt(buffer->offset(), buffer->length(), &data));
Expand Down
23 changes: 23 additions & 0 deletions cpp/src/arrow/ipc/reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
#include "arrow/ipc/dictionary.h"
#include "arrow/ipc/message.h"
#include "arrow/ipc/options.h"
#include "arrow/ipc/writer.h"
#include "arrow/record_batch.h"
#include "arrow/sparse_tensor.h"
#include "arrow/util/visibility.h"

namespace arrow {
Expand Down Expand Up @@ -286,6 +288,27 @@ Status ReadSparseTensor(io::InputStream* file, std::shared_ptr<SparseTensor>* ou
ARROW_EXPORT
Status ReadSparseTensor(const Message& message, std::shared_ptr<SparseTensor>* out);

namespace internal {

// These internal APIs may change without warning or deprecation

/// \brief EXPERIMENTAL: Read the number of body buffers of a sparse tensor
/// from its metadata
/// \param[in] metadata a Buffer containing the sparse tensor metadata
/// \param[out] buffer_count the returned count of the body buffers
/// \return Status
ARROW_EXPORT
Status ReadSparseTensorBodyBufferCount(const Buffer& metadata, size_t* buffer_count);

/// \brief EXPERIMENTAL: Read arrow::SparseTensor from an IpcPayload
/// \param[in] payload an IpcPayload containing a serialized SparseTensor
/// \param[out] out the returned SparseTensor
/// \return Status
ARROW_EXPORT
Status ReadSparseTensorPayload(const IpcPayload& payload,
std::shared_ptr<SparseTensor>* out);

} // namespace internal

} // namespace ipc
} // namespace arrow

Expand Down
17 changes: 14 additions & 3 deletions cpp/src/arrow/ipc/writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -814,8 +814,7 @@ Status GetSparseTensorPayload(const SparseTensor& sparse_tensor, MemoryPool* poo
} // namespace internal

Status WriteSparseTensor(const SparseTensor& sparse_tensor, io::OutputStream* dst,
int32_t* metadata_length, int64_t* body_length,
MemoryPool* pool) {
int32_t* metadata_length, int64_t* body_length) {
internal::IpcPayload payload;
internal::SparseTensorSerializer writer(0, &payload);
RETURN_NOT_OK(writer.Assemble(sparse_tensor));
Expand All @@ -824,6 +823,18 @@ Status WriteSparseTensor(const SparseTensor& sparse_tensor, io::OutputStream* ds
return internal::WriteIpcPayload(payload, IpcOptions::Defaults(), dst, metadata_length);
}

/// \brief Build an IPC Message from a SparseTensor.
///
/// The tensor is first turned into an IpcPayload; the payload metadata and
/// its first body buffer are then wrapped into a Message.
///
/// \param[in] sparse_tensor the SparseTensor to convert
/// \param[in] pool forwarded to internal::GetSparseTensorPayload
/// \param[out] out the resulting Message
/// \return Status::Invalid when the payload has no body buffer, otherwise the
/// status of building the payload
Status GetSparseTensorMessage(const SparseTensor& sparse_tensor, MemoryPool* pool,
                              std::unique_ptr<Message>* out) {
  internal::IpcPayload payload;
  RETURN_NOT_OK(internal::GetSparseTensorPayload(sparse_tensor, pool, &payload));

  // Reading the first element of an empty vector is undefined behavior.
  if (payload.body_buffers.empty()) {
    return Status::Invalid("Sparse tensor payload has no body buffers");
  }

  // NOTE(review): only the first body buffer ends up in the Message; any
  // further body buffers of the payload are not attached -- confirm that this
  // is intended.
  out->reset(new Message(payload.metadata, payload.body_buffers.front()));
  return Status::OK();
}

Status GetRecordBatchSize(const RecordBatch& batch, int64_t* size) {
// emulates the behavior of Write without actually writing
auto options = IpcOptions::Defaults();
Expand Down Expand Up @@ -1029,7 +1040,7 @@ class StreamBookKeeper {
int64_t position_;
};

/// A IpcPayloadWriter implementation that writes to a IPC stream
/// An IpcPayloadWriter implementation that writes to an IPC stream
/// (with an end-of-stream marker)
class PayloadStreamWriter : public internal::IpcPayloadWriter,
protected StreamBookKeeper {
Expand Down
Loading