diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 1370758f5a5b..0e74114ba29c 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -20,6 +20,7 @@ from functools import partial from typing import Any +import tvm from tvm.ir import PrimType from tvm.tir import Buffer, IterVar, PrimExpr, Var @@ -411,6 +412,10 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: if isinstance(res, Frame): res.add_callback(partial(res.__exit__, None, None, None)) res.__enter__() + elif isinstance(res, PrimExpr): + T.evaluate(res) + elif isinstance(res, (int, bool)): + T.evaluate(tvm.tir.const(res)) @dispatch.register(token="tir", type_name="If") diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index d7a3a406e352..f1d68ee43845 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1275,16 +1275,17 @@ Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) { } Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) { - if (auto* call = op->value.as()) { - if (call->op.same_as(builtin::assume())) { - Doc doc; - doc << tir_prefix_ << ".assume(" << Print(call->args[0]) << ")"; - return doc; - } - } - + // When parsing TVMScript, a PrimExpr that occurs as a statement is + // automatically wrapped in `tir::Evaluate`. Therefore, when + // printing, it's only necessary to print the value. For + // readability, though, we still print T.evaluate() when the + // expression is something other than a call node. Doc doc; - doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")"; + if (op->value.as()) { + doc << Print(op->value); + } else { + doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")"; + } return doc; } diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index f22e61e1838d..b8c8379c8a16 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3458,6 +3458,15 @@ def func() -> None: return func +def implicit_evaluate(): + @T.prim_func + def func(A: T.Buffer[1, "int32"]): + T.evaluate(T.assume(A[0] == 5)) + A[0] = 10 + + return func + + ir_generator = tvm.testing.parameter( opt_gemm_normalize, opt_gemm_lower, @@ -3509,6 +3518,7 @@ def func() -> None: bool_primitive, bool_cast, return_none, + implicit_evaluate, ) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 16f1cb04945a..a39354b9552a 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -402,5 +402,31 @@ def int64_grid_expanded( assert_structural_equal(int64_grid, int64_grid_expanded) +def test_implicit_evaluate_assume(): + @T.prim_func + def explicit(A: T.Buffer[1, "int32"]): + T.evaluate(T.assume(A[0] == 5)) + A[0] = 10 + + @T.prim_func + def implicit(A: T.Buffer[1, "int32"]): + T.assume(A[0] == 5) + A[0] = 10 + + assert_structural_equal(implicit, explicit) + + +def test_implicit_evaluate_call_extern(): + @T.prim_func + def explicit(A: T.Buffer[1, "int32"]): + T.evaluate(T.call_extern("extern_func", A.data, dtype="int32")) + + @T.prim_func + def implicit(A: T.Buffer[1, "int32"]): + T.call_extern("extern_func", A.data, dtype="int32") + + assert_structural_equal(implicit, explicit) + + if __name__ == "__main__": tvm.testing.main()