From e5c8be41b3d6637fd5b591ba1f22ca36c03151e0 Mon Sep 17 00:00:00 2001 From: Lucien0 <16538059+Lucien0@users.noreply.github.com> Date: Mon, 14 Aug 2023 13:27:42 +0800 Subject: [PATCH] adapt tir for loop var dtype --- src/script/ir_builder/tir/ir.cc | 4 ++-- .../unittest/test_tvmscript_ir_builder_tir.py | 20 +++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 24dfa425ddcb..a60a37c5a318 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -319,7 +319,7 @@ Array Remap(String kinds, Array bindings, DataType dtype) { PrimExpr extent = arith::Analyzer().Simplify(stop - start); \ ObjectPtr n = make_object(); \ int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ - n->vars = {Var("v", DataType::Int(bits))}; \ + n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))}; \ n->doms = {Range::FromMinExtent(min, extent)}; \ n->f_make_for_loop = [annotations](Array vars, Array doms, tvm::tir::Stmt body) { \ ICHECK_EQ(vars.size(), 1); \ @@ -344,7 +344,7 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, PrimExpr extent = arith::Analyzer().Simplify(stop - start); ObjectPtr n = make_object(); int bits = std::max(min.dtype().bits(), extent.dtype().bits()); - n->vars = {Var("v", DataType::Int(bits))}; + n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))}; n->doms = {Range::FromMinExtent(min, extent)}; n->f_make_for_loop = [annotations, thread](Array vars, Array doms, Stmt body) -> For { ICHECK_EQ(vars.size(), 1); diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index 70180ecbb05f..e13b609d86bb 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -267,6 +267,26 @@ def test_ir_builder_tir_for(): assert_structural_equal(for_actual, for_expected, map_free_vars=True) +def test_ir_builder_tir_for_uint(): + with IRBuilder() as ib: + with T.serial(tir.const(128, "uint32")) as a: + T.evaluate(0) + + # the for generated by IRBuilder + for_actual = ib.get() + + for_expected = tir.For( + loop_var=tir.Var("", "uint32"), + min_val=tir.const(0, "uint32"), + extent=tir.const(128, "uint32"), + kind=tir.ForKind.SERIAL, + body=tir.Evaluate(0), + ) + + # Check if the generated ir is expected + assert_structural_equal(for_actual, for_expected, map_free_vars=True) + + def test_ir_builder_tir_assert(): with IRBuilder() as ib: with T.Assert(T.int32() == 0, message="a is 0"):