From 331ae70423bd430195e0f80826fbf0552c66d008 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Tue, 24 Sep 2024 04:47:12 +0000 Subject: [PATCH] bufferload's index dtype narrowing should not inherit value bits constraint --- src/tir/transforms/narrow_datatype.cc | 14 +++++++++++++- .../test_tir_transform_narrow_datatype.py | 17 +++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 7b6187af64b8..696eae201f3c 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -97,6 +97,13 @@ class DataTypeVisitor final : public StmtExprVisitor { } } + void VisitExpr_(const BufferLoadNode* op) { + int tmp = bits_; + bits_ = target_bits_; + StmtExprVisitor::VisitExpr_(op); + bits_ = tmp; + } + void VisitStmt_(const ForNode* op) { analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); vextent_[op->loop_var.as()] = op->extent.dtype(); @@ -245,7 +252,12 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter { const CastNode* new_op = e.as(); ICHECK(new_op != nullptr) << "Expected type to be CastNode" << ", but get " << e->GetTypeKey(); - return Cast(visitor_.vmap[op], new_op->value); + PrimExpr new_value = new_op->value; + DataType cast_type = visitor_.vmap[op]; + if (new_value.dtype() != cast_type) { + new_value = Cast(cast_type, new_value); + } + return new_value; } return Parent::VisitExpr_(op); } diff --git a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py index c03dd7a5291d..cf85f2e3714c 100644 --- a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py +++ b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py @@ -413,5 +413,22 @@ def expected_after(PSUM: T.Buffer((313600,), "int32"), PAVG: T.Buffer((313600,), tvm.ir.assert_structural_equal(after["main"], expected_after.with_attr("global_symbol", "main")) +def test_narrow_i64_valued_bufferload_index_to_i32(): + @T.prim_func + def before(A: T.Buffer((16,), "int64")): + for i in range(T.int64(15)): + A[i + T.int64(1)] = A[i] + T.int64(1) + + @T.prim_func + def expect(A: T.Buffer((16,), "int64")): + for i in range(15): + A[i + 1] = A[i] + T.int64(1) + + after = tvm.tir.transform.NarrowDataType(32)( + tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + )["main"] + tvm.ir.assert_structural_equal(after, expect.with_attr("global_symbol", "main")) + + if __name__ == "__main__": tvm.testing.main()