diff --git a/cpp/src/arrow/sparse_tensor.cc b/cpp/src/arrow/sparse_tensor.cc index 6b4f99e6368..fc2d5386cb6 100644 --- a/cpp/src/arrow/sparse_tensor.cc +++ b/cpp/src/arrow/sparse_tensor.cc @@ -135,18 +135,6 @@ class SparseTensorConverter using BaseClass::tensor_; }; -template -void MakeSparseTensorFromTensor(const Tensor& tensor, - std::shared_ptr* sparse_index, - std::shared_ptr* data) { - NumericTensor numeric_tensor(tensor.data(), tensor.shape(), tensor.strides()); - SparseTensorConverter converter(numeric_tensor); - Status s = converter.Convert(); - DCHECK_OK(s); - *sparse_index = converter.sparse_index; - *data = converter.data; -} - // ---------------------------------------------------------------------- // SparseTensorConverter for SparseCSRIndex @@ -244,6 +232,87 @@ INSTANTIATE_SPARSE_TENSOR_CONVERTER(SparseCSRIndex); } // namespace +namespace internal { + +namespace { + +template +void MakeSparseTensorFromTensor(const Tensor& tensor, + std::shared_ptr* sparse_index, + std::shared_ptr* data) { + NumericTensor numeric_tensor(tensor.data(), tensor.shape(), tensor.strides()); + SparseTensorConverter converter(numeric_tensor); + ARROW_CHECK_OK(converter.Convert()); + *sparse_index = converter.sparse_index; + *data = converter.data; +} + +template +inline void MakeSparseTensorFromTensor(const Tensor& tensor, + std::shared_ptr* sparse_index, + std::shared_ptr* data) { + switch (tensor.type()->id()) { + case Type::UINT8: + MakeSparseTensorFromTensor(tensor, sparse_index, data); + break; + case Type::INT8: + MakeSparseTensorFromTensor(tensor, sparse_index, data); + break; + case Type::UINT16: + MakeSparseTensorFromTensor(tensor, sparse_index, data); + break; + case Type::INT16: + MakeSparseTensorFromTensor(tensor, sparse_index, data); + break; + case Type::UINT32: + MakeSparseTensorFromTensor(tensor, sparse_index, data); + break; + case Type::INT32: + MakeSparseTensorFromTensor(tensor, sparse_index, data); + break; + case Type::UINT64: + MakeSparseTensorFromTensor(tensor, sparse_index, data); + break; + case Type::INT64: + MakeSparseTensorFromTensor(tensor, sparse_index, data); + break; + case Type::HALF_FLOAT: + MakeSparseTensorFromTensor(tensor, sparse_index, + data); + break; + case Type::FLOAT: + MakeSparseTensorFromTensor(tensor, sparse_index, data); + break; + case Type::DOUBLE: + MakeSparseTensorFromTensor(tensor, sparse_index, data); + break; + default: + ARROW_LOG(FATAL) << "Unsupported Tensor value type"; + break; + } +} + +} // namespace + +void MakeSparseTensorFromTensor(const Tensor& tensor, + SparseTensorFormat::type sparse_format_id, + std::shared_ptr* sparse_index, + std::shared_ptr* data) { + switch (sparse_format_id) { + case SparseTensorFormat::COO: + MakeSparseTensorFromTensor(tensor, sparse_index, data); + break; + case SparseTensorFormat::CSR: + MakeSparseTensorFromTensor(tensor, sparse_index, data); + break; + default: + ARROW_LOG(FATAL) << "Invalid sparse tensor format ID"; + break; + } +} + +} // namespace internal + // ---------------------------------------------------------------------- // SparseCOOIndex @@ -303,113 +372,4 @@ bool SparseTensor::Equals(const SparseTensor& other) const { return SparseTensorEquals(*this, other); } -// ---------------------------------------------------------------------- -// SparseTensorImpl - -// Constructor with a dense tensor -template -SparseTensorImpl::SparseTensorImpl( - const std::shared_ptr& type, const std::vector& shape, - const std::vector& dim_names) - : SparseTensorImpl(nullptr, type, nullptr, shape, dim_names) {} - -// Constructor with a dense tensor -template -template -SparseTensorImpl::SparseTensorImpl(const NumericTensor& tensor) - : SparseTensorImpl(nullptr, tensor.type(), nullptr, tensor.shape(), - tensor.dim_names_) { - SparseTensorConverter converter(tensor); - Status s = converter.Convert(); - DCHECK_OK(s); - sparse_index_ = converter.sparse_index; - data_ = converter.data; -} - -// Constructor with a dense tensor -template -SparseTensorImpl::SparseTensorImpl(const Tensor& tensor) - : SparseTensorImpl(nullptr, tensor.type(), nullptr, tensor.shape(), - tensor.dim_names_) { - switch (tensor.type()->id()) { - case Type::UINT8: - MakeSparseTensorFromTensor(tensor, &sparse_index_, - &data_); - return; - case Type::INT8: - MakeSparseTensorFromTensor(tensor, &sparse_index_, - &data_); - return; - case Type::UINT16: - MakeSparseTensorFromTensor(tensor, &sparse_index_, - &data_); - return; - case Type::INT16: - MakeSparseTensorFromTensor(tensor, &sparse_index_, - &data_); - return; - case Type::UINT32: - MakeSparseTensorFromTensor(tensor, &sparse_index_, - &data_); - return; - case Type::INT32: - MakeSparseTensorFromTensor(tensor, &sparse_index_, - &data_); - return; - case Type::UINT64: - MakeSparseTensorFromTensor(tensor, &sparse_index_, - &data_); - return; - case Type::INT64: - MakeSparseTensorFromTensor(tensor, &sparse_index_, - &data_); - return; - case Type::HALF_FLOAT: - MakeSparseTensorFromTensor(tensor, &sparse_index_, - &data_); - return; - case Type::FLOAT: - MakeSparseTensorFromTensor(tensor, &sparse_index_, - &data_); - return; - case Type::DOUBLE: - MakeSparseTensorFromTensor(tensor, &sparse_index_, - &data_); - return; - default: - break; - } -} - -// ---------------------------------------------------------------------- -// Instantiate templates - -#define INSTANTIATE_SPARSE_TENSOR(IndexType) \ - template class ARROW_TEMPLATE_EXPORT SparseTensorImpl; \ - template ARROW_EXPORT SparseTensorImpl::SparseTensorImpl( \ - const NumericTensor&); \ - template ARROW_EXPORT SparseTensorImpl::SparseTensorImpl( \ - const NumericTensor&); \ - template ARROW_EXPORT SparseTensorImpl::SparseTensorImpl( \ - const NumericTensor&); \ - template ARROW_EXPORT SparseTensorImpl::SparseTensorImpl( \ - const NumericTensor&); \ - template ARROW_EXPORT SparseTensorImpl::SparseTensorImpl( \ - const NumericTensor&); \ - template ARROW_EXPORT SparseTensorImpl::SparseTensorImpl( \ - const NumericTensor&); \ - template ARROW_EXPORT SparseTensorImpl::SparseTensorImpl( \ - const NumericTensor&); \ - template ARROW_EXPORT SparseTensorImpl::SparseTensorImpl( \ - const NumericTensor&); \ - template ARROW_EXPORT SparseTensorImpl::SparseTensorImpl( \ - const NumericTensor&); \ - template ARROW_EXPORT SparseTensorImpl::SparseTensorImpl( \ - const NumericTensor&); \ - template ARROW_EXPORT SparseTensorImpl::SparseTensorImpl( \ - const NumericTensor&) - -INSTANTIATE_SPARSE_TENSOR(SparseCOOIndex); -INSTANTIATE_SPARSE_TENSOR(SparseCSRIndex); - } // namespace arrow diff --git a/cpp/src/arrow/sparse_tensor.h b/cpp/src/arrow/sparse_tensor.h index e622245d633..b6fe4b20597 100644 --- a/cpp/src/arrow/sparse_tensor.h +++ b/cpp/src/arrow/sparse_tensor.h @@ -217,10 +217,20 @@ class ARROW_EXPORT SparseTensor { // ---------------------------------------------------------------------- // SparseTensorImpl class +namespace internal { + +ARROW_EXPORT +void MakeSparseTensorFromTensor(const Tensor& tensor, + SparseTensorFormat::type sparse_format_id, + std::shared_ptr* sparse_index, + std::shared_ptr* data); + +} // namespace internal + /// \brief EXPERIMENTAL: Concrete sparse tensor implementation classes with sparse index /// type template -class ARROW_EXPORT SparseTensorImpl : public SparseTensor { +class SparseTensorImpl : public SparseTensor { public: virtual ~SparseTensorImpl() = default; @@ -234,14 +244,16 @@ class ARROW_EXPORT SparseTensorImpl : public SparseTensor { // Constructor for empty sparse tensor SparseTensorImpl(const std::shared_ptr& type, const std::vector& shape, - const std::vector& dim_names = {}); - - // Constructor with a dense numeric tensor - template - explicit SparseTensorImpl(const NumericTensor& tensor); + const std::vector& dim_names = {}) + : SparseTensorImpl(NULLPTR, type, NULLPTR, shape, dim_names) {} // Constructor with a dense tensor - explicit SparseTensorImpl(const Tensor& tensor); + explicit SparseTensorImpl(const Tensor& tensor) + : SparseTensorImpl(NULLPTR, tensor.type(), NULLPTR, tensor.shape(), + tensor.dim_names_) { + internal::MakeSparseTensorFromTensor(tensor, SparseIndexType::format_id, + &sparse_index_, &data_); + } private: ARROW_DISALLOW_COPY_AND_ASSIGN(SparseTensorImpl); diff --git a/cpp/src/arrow/util/visibility.h b/cpp/src/arrow/util/visibility.h index b224717a62d..95cd9cf5ba2 100644 --- a/cpp/src/arrow/util/visibility.h +++ b/cpp/src/arrow/util/visibility.h @@ -43,14 +43,4 @@ #endif #endif // Non-Windows -// This is a complicated topic, some reading on it: -// http://www.codesynthesis.com/~boris/blog/2010/01/18/dll-export-cxx-templates/ -#if defined(_MSC_VER) || defined(__clang__) -#define ARROW_TEMPLATE_CLASS_EXPORT -#define ARROW_TEMPLATE_EXPORT ARROW_EXPORT -#else -#define ARROW_TEMPLATE_CLASS_EXPORT ARROW_EXPORT -#define ARROW_TEMPLATE_EXPORT -#endif - #endif // ARROW_UTIL_VISIBILITY_H