Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from functools import partial
from typing import Any

import tvm
from tvm.ir import PrimType
from tvm.tir import Buffer, IterVar, PrimExpr, Var

Expand Down Expand Up @@ -411,6 +412,10 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
if isinstance(res, Frame):
res.add_callback(partial(res.__exit__, None, None, None))
res.__enter__()
elif isinstance(res, PrimExpr):
T.evaluate(res)
elif isinstance(res, (int, bool)):
T.evaluate(tvm.tir.const(res))
Comment on lines +417 to +418
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hey i was wondering if we want to explicitly print T.evaluate(0) vs 0? The former one might look a bit clearer from my perspective

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going back and forth on it, and I like the idea of printing T.evaluate() except in the case of CallNode. The more aggressive sugaring would be to render tir::Evaluate(0) as pass, which would be even clearer to read from a casual Python reader.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good idea!



@dispatch.register(token="tir", type_name="If")
Expand Down
19 changes: 10 additions & 9 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1275,16 +1275,17 @@ Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) {
}

Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) {
if (auto* call = op->value.as<CallNode>()) {
if (call->op.same_as(builtin::assume())) {
Doc doc;
doc << tir_prefix_ << ".assume(" << Print(call->args[0]) << ")";
return doc;
}
}

// When parsing TVMScript, a PrimExpr that occurs as a statement is
// automatically wrapped in `tir::Evaluate`. Therefore, when
// printing, it's only necessary to print the value. For
// readability, though, we still print T.evaluate() when the
// expression is something other than a call node.
Doc doc;
doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")";
if (op->value.as<CallNode>()) {
doc << Print(op->value);
} else {
doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")";
}
return doc;
}

Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3458,6 +3458,15 @@ def func() -> None:
return func


def implicit_evaluate():
@T.prim_func
def func(A: T.Buffer[1, "int32"]):
T.evaluate(T.assume(A[0] == 5))
A[0] = 10
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just curious what's this line for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly because T.assume doesn't have any effect, and I was pulling this in from a failing case of T.assume in a separate PR. This was from a unit test that validated a pass that removed T.assume calls from a primfunc.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see. Thanks for the clarification!


return func


ir_generator = tvm.testing.parameter(
opt_gemm_normalize,
opt_gemm_lower,
Expand Down Expand Up @@ -3509,6 +3518,7 @@ def func() -> None:
bool_primitive,
bool_cast,
return_none,
implicit_evaluate,
)


Expand Down
26 changes: 26 additions & 0 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,5 +402,31 @@ def int64_grid_expanded(
assert_structural_equal(int64_grid, int64_grid_expanded)


def test_implicit_evaluate_assume():
@T.prim_func
def explicit(A: T.Buffer[1, "int32"]):
T.evaluate(T.assume(A[0] == 5))
A[0] = 10

@T.prim_func
def implicit(A: T.Buffer[1, "int32"]):
T.assume(A[0] == 5)
A[0] = 10

assert_structural_equal(implicit, explicit)


def test_implicit_evaluate_call_extern():
@T.prim_func
def explicit(A: T.Buffer[1, "int32"]):
T.evaluate(T.call_extern("extern_func", A.data, dtype="int32"))

@T.prim_func
def implicit(A: T.Buffer[1, "int32"]):
T.call_extern("extern_func", A.data, dtype="int32")

assert_structural_equal(implicit, explicit)


if __name__ == "__main__":
tvm.testing.main()