From 725a2cc9770f2222daa770a6c6117c7c05f70c2a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 22 May 2023 16:08:48 -0500 Subject: [PATCH 1/3] [TVMScript][TIR] Parse subroutine calls with no arguments In most cases, the IR dialect in `GlobalVar.__call__` can be inferred from the argument types. If there are no arguments, then the returned value is ambiguous. This commit updates the TIR parser to identify and fix this case of erroneously producing a `relay.Call` instead of `tir.Call`. In addition, to prevent this from re-occuring, an unrecognized type resulting from `def visit_expr_stmt` now results in an error, rather than being silently ignored. --- python/tvm/script/parser/tir/parser.py | 11 ++++++++++- .../unittest/test_tvmscript_roundtrip.py | 19 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index dfecaacdf655..594dbe75d311 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -406,13 +406,22 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: The doc AST Expr node. """ res = self.eval_expr(node.value) - if isinstance(res, Frame): + if res is None: + return res + elif isinstance(res, Frame): res.add_callback(partial(res.__exit__, None, None, None)) res.__enter__() elif isinstance(res, PrimExpr): T.evaluate(res) elif isinstance(res, (int, bool)): T.evaluate(tvm.tir.const(res)) + elif isinstance(res, tvm.relay.Call) and not res.args: + # Using GlobalVar.__call__ with no arguments is ambiguous, as + # each IR has a different function Call representation. If + # this occurs, convert to the TIR representation. + return T.evaluate(tvm.tir.call_tir(res.op)) + else: + self.report_error(node, f"Parsing resulted in unexpected type {type(res)}") @dispatch.register(token="tir", type_name="If") diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 2ea7d3ec6579..3e09887e0d0d 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3855,6 +3855,24 @@ def func(): return func +def subroutine_call_without_arguments(): + @I.ir_module + class mod: + @T.prim_func + def main(): + # Should be equivalent to the bare "mod.subroutine()", but + # that relies on `GlobalVar.__call__` returning the + # correct IR type. Previously, this instead returned a + # `relay.Call` object. + tir.call_tir(mod.subroutine) + + @T.prim_func + def subroutine(): + T.evaluate(0) + + return mod + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -3929,6 +3947,7 @@ def func(): undefined_shape_in_decl_buffer, undefined_stride_in_decl_buffer, undefined_elem_offset_in_decl_buffer, + subroutine_call_without_arguments, ) From 26248a8a39ebc08bce2d62b57fa03768afacbfda Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 23 May 2023 08:10:30 -0500 Subject: [PATCH 2/3] Ignore str for unknown parser result These may are used as docstrings in the TVMScript, even though they are not represented in the TIR. --- python/tvm/script/parser/tir/parser.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 594dbe75d311..67b8cad575ed 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -420,6 +420,9 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: # each IR has a different function Call representation. If # this occurs, convert to the TIR representation. return T.evaluate(tvm.tir.call_tir(res.op)) + elif isinstance(res, str): + # Ignore docstrings + pass else: self.report_error(node, f"Parsing resulted in unexpected type {type(res)}") From b76532e1ac5005d88cf00febea6d621ea76c7c73 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 23 May 2023 09:34:40 -0500 Subject: [PATCH 3/3] Lint fixes --- python/tvm/script/parser/tir/parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 67b8cad575ed..7d81fecedbca 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -407,7 +407,7 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: """ res = self.eval_expr(node.value) if res is None: - return res + pass elif isinstance(res, Frame): res.add_callback(partial(res.__exit__, None, None, None)) res.__enter__() @@ -419,7 +419,7 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: # Using GlobalVar.__call__ with no arguments is ambiguous, as # each IR has a different function Call representation. If # this occurs, convert to the TIR representation. - return T.evaluate(tvm.tir.call_tir(res.op)) + T.evaluate(tvm.tir.call_tir(res.op)) elif isinstance(res, str): # Ignore docstrings pass