Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 81 additions & 121 deletions cpp/src/arrow/sparse_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,6 @@ class SparseTensorConverter<TYPE, SparseCOOIndex>
using BaseClass::tensor_;
};

template <typename TYPE, typename SparseIndexType>
void MakeSparseTensorFromTensor(const Tensor& tensor,
std::shared_ptr<SparseIndex>* sparse_index,
std::shared_ptr<Buffer>* data) {
NumericTensor<TYPE> numeric_tensor(tensor.data(), tensor.shape(), tensor.strides());
SparseTensorConverter<TYPE, SparseIndexType> converter(numeric_tensor);
Status s = converter.Convert();
DCHECK_OK(s);
*sparse_index = converter.sparse_index;
*data = converter.data;
}

// ----------------------------------------------------------------------
// SparseTensorConverter for SparseCSRIndex

Expand Down Expand Up @@ -244,6 +232,87 @@ INSTANTIATE_SPARSE_TENSOR_CONVERTER(SparseCSRIndex);

} // namespace

namespace internal {

namespace {

template <typename TYPE, typename SparseIndexType>
void MakeSparseTensorFromTensor(const Tensor& tensor,
std::shared_ptr<SparseIndex>* sparse_index,
std::shared_ptr<Buffer>* data) {
NumericTensor<TYPE> numeric_tensor(tensor.data(), tensor.shape(), tensor.strides());
SparseTensorConverter<TYPE, SparseIndexType> converter(numeric_tensor);
ARROW_CHECK_OK(converter.Convert());
*sparse_index = converter.sparse_index;
*data = converter.data;
}

template <typename SparseIndexType>
inline void MakeSparseTensorFromTensor(const Tensor& tensor,
std::shared_ptr<SparseIndex>* sparse_index,
std::shared_ptr<Buffer>* data) {
switch (tensor.type()->id()) {
case Type::UINT8:
MakeSparseTensorFromTensor<UInt8Type, SparseIndexType>(tensor, sparse_index, data);
break;
case Type::INT8:
MakeSparseTensorFromTensor<Int8Type, SparseIndexType>(tensor, sparse_index, data);
break;
case Type::UINT16:
MakeSparseTensorFromTensor<UInt16Type, SparseIndexType>(tensor, sparse_index, data);
break;
case Type::INT16:
MakeSparseTensorFromTensor<Int16Type, SparseIndexType>(tensor, sparse_index, data);
break;
case Type::UINT32:
MakeSparseTensorFromTensor<UInt32Type, SparseIndexType>(tensor, sparse_index, data);
break;
case Type::INT32:
MakeSparseTensorFromTensor<Int32Type, SparseIndexType>(tensor, sparse_index, data);
break;
case Type::UINT64:
MakeSparseTensorFromTensor<UInt64Type, SparseIndexType>(tensor, sparse_index, data);
break;
case Type::INT64:
MakeSparseTensorFromTensor<Int64Type, SparseIndexType>(tensor, sparse_index, data);
break;
case Type::HALF_FLOAT:
MakeSparseTensorFromTensor<HalfFloatType, SparseIndexType>(tensor, sparse_index,
data);
break;
case Type::FLOAT:
MakeSparseTensorFromTensor<FloatType, SparseIndexType>(tensor, sparse_index, data);
break;
case Type::DOUBLE:
MakeSparseTensorFromTensor<DoubleType, SparseIndexType>(tensor, sparse_index, data);
break;
default:
ARROW_LOG(FATAL) << "Unsupported Tensor value type";
break;
}
}

} // namespace

void MakeSparseTensorFromTensor(const Tensor& tensor,
SparseTensorFormat::type sparse_format_id,
std::shared_ptr<SparseIndex>* sparse_index,
std::shared_ptr<Buffer>* data) {
switch (sparse_format_id) {
case SparseTensorFormat::COO:
MakeSparseTensorFromTensor<SparseCOOIndex>(tensor, sparse_index, data);
break;
case SparseTensorFormat::CSR:
MakeSparseTensorFromTensor<SparseCSRIndex>(tensor, sparse_index, data);
break;
default:
ARROW_LOG(FATAL) << "Invalid sparse tensor format ID";
break;
}
}

} // namespace internal

// ----------------------------------------------------------------------
// SparseCOOIndex

Expand Down Expand Up @@ -303,113 +372,4 @@ bool SparseTensor::Equals(const SparseTensor& other) const {
return SparseTensorEquals(*this, other);
}

// ----------------------------------------------------------------------
// SparseTensorImpl

// Constructor with a dense tensor
template <typename SparseIndexType>
SparseTensorImpl<SparseIndexType>::SparseTensorImpl(
const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
const std::vector<std::string>& dim_names)
: SparseTensorImpl(nullptr, type, nullptr, shape, dim_names) {}

// Constructor with a dense tensor
template <typename SparseIndexType>
template <typename TYPE>
SparseTensorImpl<SparseIndexType>::SparseTensorImpl(const NumericTensor<TYPE>& tensor)
: SparseTensorImpl(nullptr, tensor.type(), nullptr, tensor.shape(),
tensor.dim_names_) {
SparseTensorConverter<TYPE, SparseIndexType> converter(tensor);
Status s = converter.Convert();
DCHECK_OK(s);
sparse_index_ = converter.sparse_index;
data_ = converter.data;
}

// Constructor with a dense tensor
template <typename SparseIndexType>
SparseTensorImpl<SparseIndexType>::SparseTensorImpl(const Tensor& tensor)
: SparseTensorImpl(nullptr, tensor.type(), nullptr, tensor.shape(),
tensor.dim_names_) {
switch (tensor.type()->id()) {
case Type::UINT8:
MakeSparseTensorFromTensor<UInt8Type, SparseIndexType>(tensor, &sparse_index_,
&data_);
return;
case Type::INT8:
MakeSparseTensorFromTensor<Int8Type, SparseIndexType>(tensor, &sparse_index_,
&data_);
return;
case Type::UINT16:
MakeSparseTensorFromTensor<UInt16Type, SparseIndexType>(tensor, &sparse_index_,
&data_);
return;
case Type::INT16:
MakeSparseTensorFromTensor<Int16Type, SparseIndexType>(tensor, &sparse_index_,
&data_);
return;
case Type::UINT32:
MakeSparseTensorFromTensor<UInt32Type, SparseIndexType>(tensor, &sparse_index_,
&data_);
return;
case Type::INT32:
MakeSparseTensorFromTensor<Int32Type, SparseIndexType>(tensor, &sparse_index_,
&data_);
return;
case Type::UINT64:
MakeSparseTensorFromTensor<UInt64Type, SparseIndexType>(tensor, &sparse_index_,
&data_);
return;
case Type::INT64:
MakeSparseTensorFromTensor<Int64Type, SparseIndexType>(tensor, &sparse_index_,
&data_);
return;
case Type::HALF_FLOAT:
MakeSparseTensorFromTensor<HalfFloatType, SparseIndexType>(tensor, &sparse_index_,
&data_);
return;
case Type::FLOAT:
MakeSparseTensorFromTensor<FloatType, SparseIndexType>(tensor, &sparse_index_,
&data_);
return;
case Type::DOUBLE:
MakeSparseTensorFromTensor<DoubleType, SparseIndexType>(tensor, &sparse_index_,
&data_);
return;
default:
break;
}
}

// ----------------------------------------------------------------------
// Instantiate templates

#define INSTANTIATE_SPARSE_TENSOR(IndexType) \
template class ARROW_TEMPLATE_EXPORT SparseTensorImpl<IndexType>; \
template ARROW_EXPORT SparseTensorImpl<IndexType>::SparseTensorImpl( \
const NumericTensor<UInt8Type>&); \
template ARROW_EXPORT SparseTensorImpl<IndexType>::SparseTensorImpl( \
const NumericTensor<UInt16Type>&); \
template ARROW_EXPORT SparseTensorImpl<IndexType>::SparseTensorImpl( \
const NumericTensor<UInt32Type>&); \
template ARROW_EXPORT SparseTensorImpl<IndexType>::SparseTensorImpl( \
const NumericTensor<UInt64Type>&); \
template ARROW_EXPORT SparseTensorImpl<IndexType>::SparseTensorImpl( \
const NumericTensor<Int8Type>&); \
template ARROW_EXPORT SparseTensorImpl<IndexType>::SparseTensorImpl( \
const NumericTensor<Int16Type>&); \
template ARROW_EXPORT SparseTensorImpl<IndexType>::SparseTensorImpl( \
const NumericTensor<Int32Type>&); \
template ARROW_EXPORT SparseTensorImpl<IndexType>::SparseTensorImpl( \
const NumericTensor<Int64Type>&); \
template ARROW_EXPORT SparseTensorImpl<IndexType>::SparseTensorImpl( \
const NumericTensor<HalfFloatType>&); \
template ARROW_EXPORT SparseTensorImpl<IndexType>::SparseTensorImpl( \
const NumericTensor<FloatType>&); \
template ARROW_EXPORT SparseTensorImpl<IndexType>::SparseTensorImpl( \
const NumericTensor<DoubleType>&)

INSTANTIATE_SPARSE_TENSOR(SparseCOOIndex);
INSTANTIATE_SPARSE_TENSOR(SparseCSRIndex);

} // namespace arrow
26 changes: 19 additions & 7 deletions cpp/src/arrow/sparse_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,20 @@ class ARROW_EXPORT SparseTensor {
// ----------------------------------------------------------------------
// SparseTensorImpl class

namespace internal {

ARROW_EXPORT
void MakeSparseTensorFromTensor(const Tensor& tensor,
SparseTensorFormat::type sparse_format_id,
std::shared_ptr<SparseIndex>* sparse_index,
std::shared_ptr<Buffer>* data);

} // namespace internal

/// \brief EXPERIMENTAL: Concrete sparse tensor implementation classes with sparse index
/// type
template <typename SparseIndexType>
class ARROW_EXPORT SparseTensorImpl : public SparseTensor {
class SparseTensorImpl : public SparseTensor {
public:
virtual ~SparseTensorImpl() = default;

Expand All @@ -234,14 +244,16 @@ class ARROW_EXPORT SparseTensorImpl : public SparseTensor {
// Constructor for empty sparse tensor
SparseTensorImpl(const std::shared_ptr<DataType>& type,
const std::vector<int64_t>& shape,
const std::vector<std::string>& dim_names = {});

// Constructor with a dense numeric tensor
template <typename TYPE>
explicit SparseTensorImpl(const NumericTensor<TYPE>& tensor);
const std::vector<std::string>& dim_names = {})
: SparseTensorImpl(NULLPTR, type, NULLPTR, shape, dim_names) {}

// Constructor with a dense tensor
explicit SparseTensorImpl(const Tensor& tensor);
explicit SparseTensorImpl(const Tensor& tensor)
: SparseTensorImpl(NULLPTR, tensor.type(), NULLPTR, tensor.shape(),
tensor.dim_names_) {
internal::MakeSparseTensorFromTensor(tensor, SparseIndexType::format_id,
&sparse_index_, &data_);
}

private:
ARROW_DISALLOW_COPY_AND_ASSIGN(SparseTensorImpl);
Expand Down
10 changes: 0 additions & 10 deletions cpp/src/arrow/util/visibility.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,4 @@
#endif
#endif // Non-Windows

// This is a complicated topic, some reading on it:
// http://www.codesynthesis.com/~boris/blog/2010/01/18/dll-export-cxx-templates/
#if defined(_MSC_VER) || defined(__clang__)
#define ARROW_TEMPLATE_CLASS_EXPORT
#define ARROW_TEMPLATE_EXPORT ARROW_EXPORT
#else
#define ARROW_TEMPLATE_CLASS_EXPORT ARROW_EXPORT
#define ARROW_TEMPLATE_EXPORT
#endif

#endif // ARROW_UTIL_VISIBILITY_H