From fcc1c8309806689da611dd76e6c95f11699d9e62 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Wed, 15 Sep 2021 13:09:16 -0700 Subject: [PATCH 1/8] initial change --- include/tvm/te/tensor.h | 4 ++-- src/te/tensor.cc | 23 +++++++++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 85677a726574..e06e649614aa 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -131,13 +131,13 @@ class Tensor : public DataProducer { * \param indices the indices. * \return the result expression representing tensor read. */ - TVM_DLL PrimExpr operator()(Array indices) const; + TVM_DLL PrimExpr operator()(Array indices, bool support_negative_indices) const; /*! * \brief Take elements from the tensor * \param indices the indices. * \return the result expression representing tensor read. */ - TVM_DLL PrimExpr operator()(Array indices) const; + TVM_DLL PrimExpr operator()(Array indices, bool support_negative_indices) const; /*! * \brief data structure to represent a slice that fixes first k coordinates. * This is used to enable syntax sugar of Tensor[x][y][z] to get the element. diff --git a/src/te/tensor.cc b/src/te/tensor.cc index b48f39a38627..8efb790ad745 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -39,15 +39,26 @@ IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name) Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } // Tensor -PrimExpr Tensor::operator()(Array indices) const { +PrimExpr Tensor::operator()(Array indices, bool support_negative_indices = false) const { Array arr(indices.begin(), indices.end()); - return operator()(arr); + return operator()(arr, support_negative_indices); } -PrimExpr Tensor::operator()(Array indices) const { - if (ndim() != 0) { - ICHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read " - << "ndim = " << ndim() << ", indices.size=" << indices.size(); +PrimExpr Tensor::operator()(Array indices, bool support_negative_indices = false) const { + Array shape = (*this)->shape; + + if (shape.size() != 0) { + ICHECK_EQ(shape.size(), indices.size()) + << "Tensor dimension mismatch in read " + << "ndim = " << ndim() << ", indices.size=" << indices.size(); + } + + if (support_negative_indices) { + for (size_t i = 0; i < shape.size(); i++) { + PrimExpr new_index = if_then_else(indices[i] < make_const(indices[i]->dtype, 0), + indices[i] + shape[i], indices[i]); + indices.Set(i, new_index); + } } return ProducerLoad((*this), indices); From 3d77b0c38a57b4a5553b5245f11b277d2eb14921 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 20 Sep 2021 14:10:35 -0700 Subject: [PATCH 2/8] more explicit api --- include/tvm/te/tensor.h | 31 +++++++++++++++++++++++++++++-- src/te/tensor.cc | 24 +++++++++++++++++------- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index e06e649614aa..14e234e8318f 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -100,6 +100,9 @@ class TensorNode : public DataProducerNode { * or intermediate computation result. */ class Tensor : public DataProducer { + private: + inline PrimExpr index_tensor(Array indices, bool support_negative_indices) const; + public: TVM_DLL Tensor(Array shape, DataType dtype, Operation op, int value_index); /*! @@ -129,15 +132,39 @@ class Tensor : public DataProducer { /*! * \brief Take elements from the tensor * \param indices the indices. + * \param support_negative_indices * \return the result expression representing tensor read. */ - TVM_DLL PrimExpr operator()(Array indices, bool support_negative_indices) const; + TVM_DLL PrimExpr operator()(Array indices) const; /*! * \brief Take elements from the tensor * \param indices the indices. * \return the result expression representing tensor read. */ - TVM_DLL PrimExpr operator()(Array indices, bool support_negative_indices) const; + TVM_DLL PrimExpr operator()(Array indices) const; + /*! + * \brief Take elements from the tensor with support for negative indices. + * \param args The indices + * \return the result expression representing tensor read. + */ + template + TVM_DLL PrimExpr IndexWithNegativeIndices(Args&&... args) const { + Array indices{std::forward(args)...}; + return IndexWithNegativeIndices(indices); + } + /*! + * \brief Take elements from the tensor with support for negative indices. + * \param indices the indices. + * \return the result expression representing tensor read. + */ + TVM_DLL PrimExpr IndexWithNegativeIndices(Array indices) const; + /*! + * \brief Take elements from the tensor with support for negative indices. + * \param indices the indices. + * \return the result expression representing tensor read. + */ + TVM_DLL PrimExpr IndexWithNegativeIndices(Array indices) const; + /*! * \brief data structure to represent a slice that fixes first k coordinates. * This is used to enable syntax sugar of Tensor[x][y][z] to get the element. diff --git a/src/te/tensor.cc b/src/te/tensor.cc index 8efb790ad745..b7475a121393 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -39,12 +39,7 @@ IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name) Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } // Tensor -PrimExpr Tensor::operator()(Array indices, bool support_negative_indices = false) const { - Array arr(indices.begin(), indices.end()); - return operator()(arr, support_negative_indices); -} - -PrimExpr Tensor::operator()(Array indices, bool support_negative_indices = false) const { +inline PrimExpr Tensor::index_tensor(Array indices, bool support_negative_indices) const { Array shape = (*this)->shape; if (shape.size() != 0) { @@ -60,10 +55,25 @@ PrimExpr Tensor::operator()(Array indices, bool support_negative_indic indices.Set(i, new_index); } } - return ProducerLoad((*this), indices); } +PrimExpr Tensor::operator()(Array indices) const { + Array arr(indices.begin(), indices.end()); + return operator()(arr); +} + +PrimExpr Tensor::operator()(Array indices) const { return index_tensor(indices, false); } + +PrimExpr Tensor::IndexWithNegativeIndices(Array indices) const { + Array arr(indices.begin(), indices.end()); + return IndexWithNegativeIndices(arr); +} + +PrimExpr Tensor::IndexWithNegativeIndices(Array indices) const { + return index_tensor(indices, true); +} + String TensorNode::GetNameHint() const { return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index)); } From 788d1f04db6eaa9f95a201a4e9dfb0cfdee7841f Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 20 Sep 2021 14:12:30 -0700 Subject: [PATCH 3/8] switch to select --- src/te/tensor.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/te/tensor.cc b/src/te/tensor.cc index b7475a121393..488e5e73017d 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -50,8 +50,8 @@ inline PrimExpr Tensor::index_tensor(Array indices, bool support_negat if (support_negative_indices) { for (size_t i = 0; i < shape.size(); i++) { - PrimExpr new_index = if_then_else(indices[i] < make_const(indices[i]->dtype, 0), - indices[i] + shape[i], indices[i]); + PrimExpr new_index = + Select(indices[i] < make_const(indices[i]->dtype, 0), indices[i] + shape[i], indices[i]); indices.Set(i, new_index); } } From 2acbcf56ef7e833cacfc062693bd83e977250457 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 20 Sep 2021 14:14:27 -0700 Subject: [PATCH 4/8] add support for negative indices --- include/tvm/te/tensor.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 14e234e8318f..0b6bc351b808 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -101,6 +101,12 @@ class TensorNode : public DataProducerNode { */ class Tensor : public DataProducer { private: + /*! + * \brief Helper for indexing operations into tensors + * \param args The indices + * \param support_negative_indices Whether to normalize indices in the case of negative indices. + * \return the result expression representing tensor read. + */ inline PrimExpr index_tensor(Array indices, bool support_negative_indices) const; public: From 3bef867a995fc3630e94af6d811da0a0fc8a900c Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 20 Sep 2021 14:14:58 -0700 Subject: [PATCH 5/8] reduce things further --- include/tvm/te/tensor.h | 1 - 1 file changed, 1 deletion(-) diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 0b6bc351b808..1cbb2541d0a1 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -138,7 +138,6 @@ class Tensor : public DataProducer { /*! * \brief Take elements from the tensor * \param indices the indices. - * \param support_negative_indices * \return the result expression representing tensor read. */ TVM_DLL PrimExpr operator()(Array indices) const; From 5f75a02260db4fd58d1d31dc75512e9f95669d65 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 21 Sep 2021 10:08:00 -0700 Subject: [PATCH 6/8] lint --- include/tvm/te/tensor.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 1cbb2541d0a1..3dfbd1805256 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -103,7 +103,7 @@ class Tensor : public DataProducer { private: /*! * \brief Helper for indexing operations into tensors - * \param args The indices + * \param indices The indices * \param support_negative_indices Whether to normalize indices in the case of negative indices. * \return the result expression representing tensor read. */ From 72ba9145214ffe26ee62a48e838d8691e7500fd6 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 21 Sep 2021 10:22:02 -0700 Subject: [PATCH 7/8] to CamelCase --- include/tvm/te/tensor.h | 2 +- src/te/tensor.cc | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 3dfbd1805256..30480e150823 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -107,7 +107,7 @@ class Tensor : public DataProducer { * \param support_negative_indices Whether to normalize indices in the case of negative indices. * \return the result expression representing tensor read. */ - inline PrimExpr index_tensor(Array indices, bool support_negative_indices) const; + inline PrimExpr IndexTensor(Array indices, bool support_negative_indices) const; public: TVM_DLL Tensor(Array shape, DataType dtype, Operation op, int value_index); diff --git a/src/te/tensor.cc b/src/te/tensor.cc index 488e5e73017d..1d75761216f1 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -39,7 +39,7 @@ IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name) Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } // Tensor -inline PrimExpr Tensor::index_tensor(Array indices, bool support_negative_indices) const { +inline PrimExpr Tensor::IndexTensor(Array indices, bool support_negative_indices) const { Array shape = (*this)->shape; if (shape.size() != 0) { @@ -63,7 +63,7 @@ PrimExpr Tensor::operator()(Array indices) const { return operator()(arr); } -PrimExpr Tensor::operator()(Array indices) const { return index_tensor(indices, false); } +PrimExpr Tensor::operator()(Array indices) const { return IndexTensor(indices, false); } PrimExpr Tensor::IndexWithNegativeIndices(Array indices) const { Array arr(indices.begin(), indices.end()); @@ -71,7 +71,7 @@ PrimExpr Tensor::IndexWithNegativeIndices(Array indices) const { } PrimExpr Tensor::IndexWithNegativeIndices(Array indices) const { - return index_tensor(indices, true); + return IndexTensor(indices, true); } String TensorNode::GetNameHint() const { From 9b7d01af1984e6295622b39f003dcdc4170d13ca Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 30 Sep 2021 23:02:03 -0700 Subject: [PATCH 8/8] unit test --- tests/cpp/tensor_test.cc | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/cpp/tensor_test.cc b/tests/cpp/tensor_test.cc index a50af838f735..e53f6d05a991 100644 --- a/tests/cpp/tensor_test.cc +++ b/tests/cpp/tensor_test.cc @@ -49,3 +49,14 @@ TEST(Tensor, Reduce) { {m, n}, [&](Var i, Var j) { return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv}); }, "C"); LOG(INFO) << C->op.as()->body; } + +TEST(Tensor, Indexing) { + using namespace tvm; + using namespace tvm::te; + + Var x("x"), y("y"); + te::Tensor A = te::placeholder({x, y}, DataType::Float(32), "A"); + LOG(INFO) << A(0, 0); + LOG(INFO) << A.IndexWithNegativeIndices(-1, -1); + LOG(INFO) << A.IndexWithNegativeIndices(0, -1); +}