diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h
index 846cda74c67d..913e2ab189ff 100644
--- a/include/tvm/tir/data_type_rewriter.h
+++ b/include/tvm/tir/data_type_rewriter.h
@@ -110,6 +110,7 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
   Stmt VisitStmt_(const IfThenElseNode* op) override;
   Stmt VisitStmt_(const DeclBufferNode* op) override;
   Stmt VisitStmt_(const AllocateNode* op) override;
+  Stmt VisitStmt_(const LetStmtNode* op) override;
   PrimExpr VisitExpr_(const EQNode* op) override;
   PrimExpr VisitExpr_(const NENode* op) override;
   PrimExpr VisitExpr_(const LTNode* op) override;
diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py b/python/tvm/relax/backend/dispatch_sort_scan.py
index e25c28e5711a..53948b8449b0 100644
--- a/python/tvm/relax/backend/dispatch_sort_scan.py
+++ b/python/tvm/relax/backend/dispatch_sort_scan.py
@@ -155,9 +155,13 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
             tgt = self._get_target(call.struct_info)
             axis = int(call.attrs.axis) if call.attrs.axis is not None else call.attrs.axis
             shape = call.struct_info.shape
+            # TODO(tvm-team): Support fully dynamic case with `shape=None`
+            if shape is None:
+                raise ValueError("non-symbolic shape is not supported for now")
             kwargs = {}
             if (
-                (axis == -1 or axis == len(shape) - 1)
+                shape is not None
+                and (axis == -1 or axis == len(shape) - 1)
                 and is_gpu_target(tgt)
                 and not can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan")
                 and call.op.name == "relax.cumsum"
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index a613b8d4bb0c..7450c4762c7c 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -27,6 +27,10 @@
 #include 
 
 #include "./functor_common.h"
+#include "tvm/ir/expr.h"
+#include "tvm/tir/expr.h"
+#include "tvm/tir/stmt.h"
+#include "tvm/tir/var.h"
 
 namespace tvm {
 namespace tir {
@@ -556,6 +560,21 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) {
   }
 }
 
+Stmt IndexDataTypeRewriter::VisitStmt_(const LetStmtNode* op) {
+  LetStmt let_stmt = Downcast<LetStmt>(DataTypeLegalizer::VisitStmt_(op));
+  if (var_remap_.find(let_stmt->var.get()) == var_remap_.end()) {
+    return let_stmt;
+  }
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  PrimExpr value = VisitExpr(op->value);
+  Var var = var_remap_[let_stmt->var.get()];
+  is_enabled_ = is_enabled;
+  ICHECK(value.dtype() == var.dtype());
+  // No need to re-visit body
+  return LetStmt(var, value, let_stmt->body, let_stmt->span);
+}
+
 #define TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \
   PrimExpr IndexDataTypeRewriter::VisitExpr_(const OP* op) {   \
     bool is_enabled = is_enabled_;                             \
diff --git a/tests/python/relax/test_backend_dispatch_sort_scan.py b/tests/python/relax/test_backend_dispatch_sort_scan.py
index a53962106044..2ab5afaabf24 100644
--- a/tests/python/relax/test_backend_dispatch_sort_scan.py
+++ b/tests/python/relax/test_backend_dispatch_sort_scan.py
@@ -273,7 +273,7 @@ def foo2(y: R.Tensor((2, 3), "float32")):
     if can_use_thrust(target, "tvm.contrib.thrust.sort"):
         workspace = bb.emit(
             relax.op.builtin.alloc_tensor(
-                R.shape([4194568]), R.dtype("uint8"), R.prim_value(0), R.str("global")
+                R.shape([8388872]), R.dtype("uint8"), R.prim_value(0), R.str("global")
             )
         )
         out = bb.emit_te(
@@ -400,8 +400,8 @@ def foo(x: R.Tensor((2, 3), "float32", "vulkan")):
     assert_structural_equal(mod, expected_mod)
 
 
-@tvm.testing.requires_cuda
-def test_dispatch_cumsum_gpu():
+@tvm.testing.parametrize_targets("cuda", "vulkan -supports_int64=1")
+def test_dispatch_cumsum_gpu(target, dev):
     """Test cumsum kernel dispatch and numerical correctness"""
 
     @I.ir_module
@@ -416,15 +416,13 @@ def main(x: R.Tensor(("m", "n"), "int32")):
     size = (8, 2000)
     np_data = np.random.randint(0, 10, size).astype("int32")
     np_cumsum = np.cumsum(np_data, axis=-1)
-    for target in ["cuda", "vulkan -supports_int64=1"]:
-        with tvm.target.Target(target):
-            mod = DispatchSortScan()(Module)
-            ex = tvm.relax.build(mod, target)
-            device = tvm.device(target, 0)
-            vm = tvm.relax.VirtualMachine(ex, device)
-            tvm_data = tvm.nd.array(np_data, device)
-            cumsum = vm["main"](tvm_data)
-            tvm.testing.assert_allclose(cumsum.numpy(), np_cumsum)
+    with tvm.target.Target(target):
+        mod = DispatchSortScan()(Module)
+        ex = tvm.relax.build(mod, target)
+        vm = tvm.relax.VirtualMachine(ex, dev)
+        tvm_data = tvm.nd.array(np_data, dev)
+        cumsum = vm["main"](tvm_data)
+        tvm.testing.assert_allclose(cumsum.numpy(), np_cumsum)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py b/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py
index 0be0e5fbb573..c85929e4f6bf 100644
--- a/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py
+++ b/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py
@@ -278,5 +278,30 @@ def main(B: T.Buffer((4,), "int32")):
     tvm.ir.assert_structural_equal(Expected, after)
 
 
+def test_let_binding():
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def main(buf: T.handle):
+            n = T.int64()
+            Buf = T.match_buffer(buf, [n], "int32")
+            ceil_log2 = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n))))
+            for i in T.serial(ceil_log2):
+                T.evaluate(0)
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def main(buf: T.handle):
+            n = T.int32()
+            Buf = T.match_buffer(buf, [n], "int32")
+            ceil_log2 = T.Cast("int32", T.ceil(T.log2(T.Cast("float32", n))))
+            for i in range(ceil_log2):
+                T.evaluate(0)
+
+    after = tvm.tir.transform.ForceNarrowIndexToInt32()(Before)
+    tvm.ir.assert_structural_equal(Expected, after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()