From 7b8d20e8014e76115eb7ed18b8075b14362f2b52 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Sat, 16 Apr 2022 17:32:13 -0700
Subject: [PATCH] [BugFix][TIR] Fix narrower dtype of loop vars in CreatePrimFunc

---
 src/te/operation/create_primfunc.cc            |  4 +++-
 .../python/unittest/test_te_create_primfunc.py | 18 +++++++++++++++---
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 5cf6e5c7dc1b..6254d5997aca 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -244,7 +244,9 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
   axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end());
   Array bindings;
   for (size_t i = 0; i < axes.size(); ++i) {
-    bindings.push_back(Var("i" + std::to_string(i)));
+    const IterVar& axis = axes[i];
+    int bits = std::max(axis->dom->min.dtype().bits(), axis->dom->extent.dtype().bits());
+    bindings.push_back(Var("i" + std::to_string(i), runtime::DataType::Int(bits)));
   }
   // Step 2. Generate block bodies.
   Array seq_stmt;
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index 23d264d4ef38..eba71cf5e484 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -15,11 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-function-docstring,missing-module-docstring
-import tvm
-from tvm.script import tir as T
-from tvm import te, tir, topi
 import numpy as np
+import tvm
 import tvm.testing
+from tvm import te, tir, topi
+from tvm.script import tir as T
 
 
 def test_unique_name_complete_block():
@@ -473,6 +473,17 @@ def test_argmax_val_idx():
     _check_workload(te_argmax_val_idx, tir_argmax_val_idx)
 
 
+def test_int64_indices():
+    n = te.var("n", "int64")
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1, name="B")
+    prim_func = te.create_prim_func([A, B])
+    loop = prim_func.body.block.body
+    assert loop.loop_var.dtype == "int64"
+    assert loop.min.dtype == "int64"
+    assert loop.extent.dtype == "int64"
+
+
 if __name__ == "__main__":
     test_unique_name_complete_block()
     test_unique_name_reduction_block()
@@ -488,3 +499,4 @@ def test_argmax_val_idx():
     test_tensor_attr()
     test_argmax_idx_val()
     test_argmax_val_idx()
+    test_int64_indices()