Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cpp/src/arrow/compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,13 @@ inline bool SparseTensorEqualsImplDispatch(const SparseTensorImpl<SparseIndexTyp
right_csr);
}

case SparseTensorFormat::CSC: {
const auto& right_csc =
checked_cast<const SparseTensorImpl<SparseCSCIndex>&>(right);
return SparseTensorEqualsImpl<SparseIndexType, SparseCSCIndex>::Compare(left,
right_csc);
}

default:
return false;
}
Expand Down Expand Up @@ -1220,6 +1227,11 @@ bool SparseTensorEquals(const SparseTensor& left, const SparseTensor& right) {
return SparseTensorEqualsImplDispatch(left_csr, right);
}

case SparseTensorFormat::CSC: {
const auto& left_csc = checked_cast<const SparseTensorImpl<SparseCSCIndex>&>(left);
return SparseTensorEqualsImplDispatch(left_csc, right);
}

default:
return false;
}
Expand Down
166 changes: 100 additions & 66 deletions cpp/src/arrow/ipc/metadata_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -906,62 +906,6 @@ static Status MakeRecordBatch(FBB& fbb, int64_t length, int64_t body_length,
return Status::OK();
}

} // namespace

// Serialize `schema` as an encapsulated IPC Schema message, returning the
// flatbuffer-encoded metadata in *out.  Dictionary-encoded fields are
// registered through `dictionary_memo`.
Status WriteSchemaMessage(const Schema& schema, DictionaryMemo* dictionary_memo,
                          std::shared_ptr<Buffer>* out) {
  FBB builder;
  flatbuffers::Offset<flatbuf::Schema> schema_offset;
  RETURN_NOT_OK(SchemaToFlatbuffer(builder, schema, dictionary_memo, &schema_offset));
  // A Schema message has no message body, hence a body length of 0.
  return WriteFBMessage(builder, flatbuf::MessageHeader_Schema, schema_offset.Union(), 0,
                        out);
}

// Serialize the metadata of a record batch (field nodes + buffer layout) as an
// encapsulated IPC RecordBatch message.  `length` is the number of rows and
// `body_length` the total size of the message body that will follow.
Status WriteRecordBatchMessage(int64_t length, int64_t body_length,
                               const std::vector<FieldMetadata>& nodes,
                               const std::vector<BufferMetadata>& buffers,
                               std::shared_ptr<Buffer>* out) {
  FBB builder;
  RecordBatchOffset batch_offset;
  RETURN_NOT_OK(
      MakeRecordBatch(builder, length, body_length, nodes, buffers, &batch_offset));
  return WriteFBMessage(builder, flatbuf::MessageHeader_RecordBatch, batch_offset.Union(),
                        body_length, out);
}

// Serialize the metadata of a dense tensor as an encapsulated IPC Tensor
// message.  `buffer_start_offset` is the position of the tensor body relative
// to the start of the message body; *out receives the flatbuffer metadata.
Status WriteTensorMessage(const Tensor& tensor, int64_t buffer_start_offset,
                          std::shared_ptr<Buffer>* out) {
  using TensorDimOffset = flatbuffers::Offset<flatbuf::TensorDim>;
  using TensorOffset = flatbuffers::Offset<flatbuf::Tensor>;

  FBB builder;

  // Tensors hold only fixed-width primitive values, so the element size in
  // bytes follows directly from the type's bit width.
  const auto& value_type = checked_cast<const FixedWidthType&>(*tensor.type());
  const int elem_size = value_type.bit_width() / 8;

  flatbuf::Type fb_type_type;
  Offset fb_type;
  RETURN_NOT_OK(TensorTypeToFlatbuffer(builder, *tensor.type(), &fb_type_type, &fb_type));

  // Emit one TensorDim per dimension: its extent plus a (possibly empty) name.
  std::vector<TensorDimOffset> dim_offsets;
  for (int i = 0; i < tensor.ndim(); ++i) {
    const FBString dim_name = builder.CreateString(tensor.dim_name(i));
    dim_offsets.push_back(flatbuf::CreateTensorDim(builder, tensor.shape()[i], dim_name));
  }

  const auto fb_shape =
      builder.CreateVector(util::MakeNonNull(dim_offsets.data()), dim_offsets.size());
  const auto fb_strides = builder.CreateVector(util::MakeNonNull(tensor.strides().data()),
                                               tensor.strides().size());

  const int64_t body_length = tensor.size() * elem_size;
  flatbuf::Buffer buffer(buffer_start_offset, body_length);

  const TensorOffset fb_tensor =
      flatbuf::CreateTensor(builder, fb_type_type, fb_type, fb_shape, fb_strides, &buffer);

  return WriteFBMessage(builder, flatbuf::MessageHeader_Tensor, fb_tensor.Union(),
                        body_length, out);
}

Status MakeSparseTensorIndexCOO(FBB& fbb, const SparseCOOIndex& sparse_index,
const std::vector<BufferMetadata>& buffers,
flatbuf::SparseTensorIndex* fb_sparse_index_type,
Expand All @@ -988,11 +932,25 @@ Status MakeSparseTensorIndexCOO(FBB& fbb, const SparseCOOIndex& sparse_index,
return Status::OK();
}

Status MakeSparseMatrixIndexCSR(FBB& fbb, const SparseCSRIndex& sparse_index,
// Compile-time mapping from a sparse index type (SparseCSRIndex /
// SparseCSCIndex) to the corresponding flatbuf::SparseMatrixCompressedAxis
// enum value.  The primary template is intentionally empty so that
// instantiating it with an unsupported index type fails to compile.
template <typename SparseIndexType>
struct SparseMatrixCompressedAxis {};

// CSR compresses along the row axis.
template <>
struct SparseMatrixCompressedAxis<SparseCSRIndex> {
  constexpr static const auto value = flatbuf::SparseMatrixCompressedAxis_Row;
};

// CSC compresses along the column axis.
template <>
struct SparseMatrixCompressedAxis<SparseCSCIndex> {
  constexpr static const auto value = flatbuf::SparseMatrixCompressedAxis_Column;
};

template <typename SparseIndexType>
Status MakeSparseMatrixIndexCSX(FBB& fbb, const SparseIndexType& sparse_index,
const std::vector<BufferMetadata>& buffers,
flatbuf::SparseTensorIndex* fb_sparse_index_type,
Offset* fb_sparse_index, size_t* num_buffers) {
*fb_sparse_index_type = flatbuf::SparseTensorIndex_SparseMatrixIndexCSR;
*fb_sparse_index_type = flatbuf::SparseTensorIndex_SparseMatrixIndexCSX;

// We assume that the value type of indptr tensor is an integer.
const auto& indptr_value_type =
Expand All @@ -1012,9 +970,11 @@ Status MakeSparseMatrixIndexCSR(FBB& fbb, const SparseCSRIndex& sparse_index,
const BufferMetadata& indices_metadata = buffers[1];
flatbuf::Buffer indices(indices_metadata.offset, indices_metadata.length);

*fb_sparse_index = flatbuf::CreateSparseMatrixIndexCSR(fbb, indptr_type_offset, &indptr,
indices_type_offset, &indices)
.Union();
auto compressedAxis = SparseMatrixCompressedAxis<SparseIndexType>::value;
*fb_sparse_index =
flatbuf::CreateSparseMatrixIndexCSX(fbb, compressedAxis, indptr_type_offset,
&indptr, indices_type_offset, &indices)
.Union();
*num_buffers = 2;
return Status::OK();
}
Expand All @@ -1031,11 +991,17 @@ Status MakeSparseTensorIndex(FBB& fbb, const SparseIndex& sparse_index,
break;

case SparseTensorFormat::CSR:
RETURN_NOT_OK(MakeSparseMatrixIndexCSR(
RETURN_NOT_OK(MakeSparseMatrixIndexCSX(
fbb, checked_cast<const SparseCSRIndex&>(sparse_index), buffers,
fb_sparse_index_type, fb_sparse_index, num_buffers));
break;

case SparseTensorFormat::CSC:
RETURN_NOT_OK(MakeSparseMatrixIndexCSX(
fbb, checked_cast<const SparseCSCIndex&>(sparse_index), buffers,
fb_sparse_index_type, fb_sparse_index, num_buffers));
break;

default:
std::stringstream ss;
ss << "Unsupporoted sparse tensor format:: " << sparse_index.ToString()
Expand Down Expand Up @@ -1082,6 +1048,62 @@ Status MakeSparseTensor(FBB& fbb, const SparseTensor& sparse_tensor, int64_t bod
return Status::OK();
}

} // namespace

// Serialize `schema` as an encapsulated IPC Schema message, returning the
// flatbuffer-encoded metadata in *out.  Dictionary-encoded fields are
// registered through `dictionary_memo`.
Status WriteSchemaMessage(const Schema& schema, DictionaryMemo* dictionary_memo,
                          std::shared_ptr<Buffer>* out) {
  FBB builder;
  flatbuffers::Offset<flatbuf::Schema> schema_offset;
  RETURN_NOT_OK(SchemaToFlatbuffer(builder, schema, dictionary_memo, &schema_offset));
  // A Schema message has no message body, hence a body length of 0.
  return WriteFBMessage(builder, flatbuf::MessageHeader_Schema, schema_offset.Union(), 0,
                        out);
}

// Serialize the metadata of a record batch (field nodes + buffer layout) as an
// encapsulated IPC RecordBatch message.  `length` is the number of rows and
// `body_length` the total size of the message body that will follow.
Status WriteRecordBatchMessage(int64_t length, int64_t body_length,
                               const std::vector<FieldMetadata>& nodes,
                               const std::vector<BufferMetadata>& buffers,
                               std::shared_ptr<Buffer>* out) {
  FBB builder;
  RecordBatchOffset batch_offset;
  RETURN_NOT_OK(
      MakeRecordBatch(builder, length, body_length, nodes, buffers, &batch_offset));
  return WriteFBMessage(builder, flatbuf::MessageHeader_RecordBatch, batch_offset.Union(),
                        body_length, out);
}

// Serialize the metadata of a dense tensor as an encapsulated IPC Tensor
// message.  `buffer_start_offset` is the position of the tensor body relative
// to the start of the message body; *out receives the flatbuffer metadata.
Status WriteTensorMessage(const Tensor& tensor, int64_t buffer_start_offset,
                          std::shared_ptr<Buffer>* out) {
  using TensorDimOffset = flatbuffers::Offset<flatbuf::TensorDim>;
  using TensorOffset = flatbuffers::Offset<flatbuf::Tensor>;

  FBB builder;

  // Tensors hold only fixed-width primitive values, so the element size in
  // bytes follows directly from the type's bit width.
  const auto& value_type = checked_cast<const FixedWidthType&>(*tensor.type());
  const int elem_size = value_type.bit_width() / 8;

  flatbuf::Type fb_type_type;
  Offset fb_type;
  RETURN_NOT_OK(TensorTypeToFlatbuffer(builder, *tensor.type(), &fb_type_type, &fb_type));

  // Emit one TensorDim per dimension: its extent plus a (possibly empty) name.
  std::vector<TensorDimOffset> dim_offsets;
  for (int i = 0; i < tensor.ndim(); ++i) {
    const FBString dim_name = builder.CreateString(tensor.dim_name(i));
    dim_offsets.push_back(flatbuf::CreateTensorDim(builder, tensor.shape()[i], dim_name));
  }

  const auto fb_shape =
      builder.CreateVector(util::MakeNonNull(dim_offsets.data()), dim_offsets.size());
  const auto fb_strides = builder.CreateVector(util::MakeNonNull(tensor.strides().data()),
                                               tensor.strides().size());

  const int64_t body_length = tensor.size() * elem_size;
  flatbuf::Buffer buffer(buffer_start_offset, body_length);

  const TensorOffset fb_tensor =
      flatbuf::CreateTensor(builder, fb_type_type, fb_type, fb_shape, fb_strides, &buffer);

  return WriteFBMessage(builder, flatbuf::MessageHeader_Tensor, fb_tensor.Union(),
                        body_length, out);
}

Status WriteSparseTensorMessage(const SparseTensor& sparse_tensor, int64_t body_length,
const std::vector<BufferMetadata>& buffers,
std::shared_ptr<Buffer>* out) {
Expand Down Expand Up @@ -1225,7 +1247,7 @@ Status GetSparseCOOIndexMetadata(const flatbuf::SparseTensorIndexCOO* sparse_ind
return IntFromFlatbuffer(sparse_index->indicesType(), indices_type);
}

Status GetSparseCSRIndexMetadata(const flatbuf::SparseMatrixIndexCSR* sparse_index,
Status GetSparseCSXIndexMetadata(const flatbuf::SparseMatrixIndexCSX* sparse_index,
std::shared_ptr<DataType>* indptr_type,
std::shared_ptr<DataType>* indices_type) {
RETURN_NOT_OK(IntFromFlatbuffer(sparse_index->indptrType(), indptr_type));
Expand Down Expand Up @@ -1276,9 +1298,21 @@ Status GetSparseTensorMetadata(const Buffer& metadata, std::shared_ptr<DataType>
*sparse_tensor_format_id = SparseTensorFormat::COO;
break;

case flatbuf::SparseTensorIndex_SparseMatrixIndexCSR:
*sparse_tensor_format_id = SparseTensorFormat::CSR;
break;
case flatbuf::SparseTensorIndex_SparseMatrixIndexCSX: {
auto cs = sparse_tensor->sparseIndex_as_SparseMatrixIndexCSX();
switch (cs->compressedAxis()) {
case flatbuf::SparseMatrixCompressedAxis_Row:
*sparse_tensor_format_id = SparseTensorFormat::CSR;
break;

case flatbuf::SparseMatrixCompressedAxis_Column:
*sparse_tensor_format_id = SparseTensorFormat::CSC;
break;

default:
return Status::Invalid("Invalid value of SparseMatrixCompressedAxis");
}
} break;

default:
return Status::Invalid("Unrecognized sparse index type");
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/ipc/metadata_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ Status GetTensorMetadata(const Buffer& metadata, std::shared_ptr<DataType>* type
Status GetSparseCOOIndexMetadata(const flatbuf::SparseTensorIndexCOO* sparse_index,
std::shared_ptr<DataType>* indices_type);

// EXPERIMENTAL: Extracting metadata of a SparseCSRIndex from the message
Status GetSparseCSRIndexMetadata(const flatbuf::SparseMatrixIndexCSR* sparse_index,
// EXPERIMENTAL: Extracting metadata of a SparseCSXIndex from the message
Status GetSparseCSXIndexMetadata(const flatbuf::SparseMatrixIndexCSX* sparse_index,
std::shared_ptr<DataType>* indptr_type,
std::shared_ptr<DataType>* indices_type);

Expand Down
Loading