diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index a60a37c5a318..d6554fc37103 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -344,13 +344,14 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, PrimExpr extent = arith::Analyzer().Simplify(stop - start); ObjectPtr n = make_object(); int bits = std::max(min.dtype().bits(), extent.dtype().bits()); - n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))}; + DataType dtype = DataType(min.dtype().code(), bits, 1); + n->vars = {Var("v", dtype)}; n->doms = {Range::FromMinExtent(min, extent)}; - n->f_make_for_loop = [annotations, thread](Array vars, Array doms, Stmt body) -> For { + n->f_make_for_loop = [annotations, thread, dtype](Array vars, Array doms, + Stmt body) -> For { ICHECK_EQ(vars.size(), 1); ICHECK_EQ(doms.size(), 1); - IterVar iter_var(Range(nullptr), Var("iter", DataType::Int(32)), IterVarType::kThreadIndex, - thread); + IterVar iter_var(Range(nullptr), Var("iter", dtype), IterVarType::kThreadIndex, thread); return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var, annotations.value_or(Map())); }; diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 413894264ea6..249555ad6ed2 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -836,7 +836,7 @@ class CrossThreadReductionTransformer : public StmtMutator { /*kind=*/ForKind::kThreadBinding, // /*body=*/body, // /*thread_binding=*/ - IterVar(NullValue(), Var(""), IterVarType::kThreadIndex, + IterVar(NullValue(), Var("", loop_vars[i]->dtype), IterVarType::kThreadIndex, "threadIdx." + dim_index)); } return body; diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index a7b6e53e6fe4..b2b534064605 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -325,5 +325,20 @@ def evaluated(A: T.Buffer((2, 128, 128), "int32")): tvm.ir.assert_structural_equal(with_builtin, evaluated) +def test_thread_binding_dtype(): + @T.prim_func(private=True) + def func(A: T.Buffer((128, 128)), B: T.Buffer((128, 128))): + for i in T.thread_binding(T.int64(128), "threadIdx.x"): + for j in T.thread_binding(128, "threadIdx.y"): + B[i, j] = A[i, j] + + loop_i = func.body + loop_j = loop_i.body + assert loop_i.loop_var.dtype == "int64" + assert loop_i.thread_binding.var.dtype == "int64" + assert loop_j.loop_var.dtype == "int32" + assert loop_j.thread_binding.var.dtype == "int32" + + if __name__ == "__main__": tvm.testing.main()