From c85389a474b2fea1c971090ead266dc8c2128a4c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 5 Aug 2024 11:46:44 -0500 Subject: [PATCH 1/2] [Relax] Avoid wrapping TupleStructInfo into a Tuple for R.call_tir Prior to this commit, the different `R.call_tir*` variations would wrap the arguments into an in-line `relax.Tuple`, if it is not already a `relax.Tuple`. While this allows a tensor to be passed into these functions as a single argument (`R.call_tir(func, arg, ...)` instead of `R.call_tir(func, [arg], ...)`), the wrapped Relax variable may already refer to a tuple. This use of a variable to refer to an argument tuple rather than an in-line argument tuple is not allowed by Relax. (See discussion on https://github.com/apache/tvm/pull/15916 for details.) However, by wrapping a variable `args: R.Tuple(R.Tensor, R.Tensor, ...)` into a tuple-of-tuples, the error occurs after the expression has already been generated, and refers to an expression `R.Tuple(R.Tuple(R.Tensor, R.Tensor, ...))` that doesn't appear anywhere in the user's input. This can make debugging difficult (see https://github.com/apache/tvm/issues/17239 for an example). This commit updates the argument-handling in `R.call_tir` to only generate an in-line `relax.Tuple` if the arguments do not already have `relax.TupleStructInfo`. If the argument was provided as a Relax variable bound to a tuple of arguments, it will still produce an error. However, that error will occur much earlier, and will explicitly state that the argument must be a `relax.Tuple` instead of a `relax.Var`. --- python/tvm/relax/op/base.py | 37 ++++++++++++++----- tests/python/relax/test_tvmscript_parser.py | 40 +++++++++++++++++++++ 2 files changed, 68 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 756d250c1687..03e86a4633a6 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # pylint: disable=redefined-builtin """The base Relax operators.""" + from typing import Dict, Union, List, Tuple, Optional, Callable @@ -25,7 +26,6 @@ from . import _ffi_api from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var -from ..expr import Tuple as RxTuple from ..struct_info import StructInfo, TensorStructInfo from ...ir import PrimExpr from ..utils import args_converter @@ -67,6 +67,29 @@ def null_value() -> Call: return _ffi_api.null_value() # type: ignore +def _wrap_inline_arg_tuple(args) -> Expr: + """Helper function to wrap argument tuple + + Normalize the arguments provided the functions that accept a tuple + of arguments, and require the tuple of arguments to be written + in-line. If the arguments provided are a single relax expression, + and are not a reference to a relax tuple, then wrap them into an + in-line relax Tuple. + + """ + if ( + isinstance(args, Expr) + and not isinstance(args, tvm.relax.Tuple) + and ( + args.struct_info_ is None + or not isinstance(args.struct_info_, tvm.relax.TupleStructInfo) + ) + ): + return tvm.relax.Tuple([args]) + else: + return args + + @args_converter.auto def call_tir( gvar: GlobalVar, @@ -98,8 +121,7 @@ def call_tir( ret: Call A call node for the call_tir operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _wrap_inline_arg_tuple(args) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] @@ -153,8 +175,7 @@ def call_tir_with_grad( ret: Call A call node for the call_tir_with_grad operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _wrap_inline_arg_tuple(args) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] @@ -221,8 +242,7 @@ def call_tir_inplace( ret: Call A call node for the call_tir operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _wrap_inline_arg_tuple(args) if not isinstance(inplace_indices, list): inplace_indices = [inplace_indices] @@ -276,8 +296,7 @@ def call_dps_packed( if isinstance(func, str): func = ExternFunc(func) - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _wrap_inline_arg_tuple(args) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 4f41b662caf2..f66c7329b441 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1043,6 +1043,46 @@ def main( _check(Module) +def test_call_tir_inplace_with_tuple_var_raises_error(): + + + with pytest.raises(tvm.error.DiagnosticError): + + @tvm.script.ir_module + class Module: + @R.function + def main( + x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") + ) : + cls = Module + args = (x, y) + res = R.call_tir_inplace( + cls.copy, + # The `args` tuple must be an in-line tuple, not a + # reference to a tuple. This error should be + # caught and raised during parsing. + args, + inplace_indices = [0, -1], + out_sinfo=[R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")], + ) + return res + + @T.prim_func + def copy( + A: T.Buffer((2, 3), "int32"), + B: T.Buffer((2, 3), "int32"), + out1: T.Buffer((2, 3), "int32"), + ): + # copies the contents of B into A and out1 + T.func_attr({"tir.noalias": True}) + for iters in T.grid(T.int64(2), T.int64(3)): + with T.block("T_zeros"): + i, j = T.axis.remap("SS", iters) + A[i, j] = B[i, j] + out1[i, j] = B[i, j] + + + def test_local_function(): @R.function From eaa5d21cae760801be9e6ffc504e5e77202b8cb0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 5 Aug 2024 14:06:44 -0500 Subject: [PATCH 2/2] lint fixes --- tests/python/relax/test_tvmscript_parser.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index f66c7329b441..ea99d49270a1 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1043,17 +1043,15 @@ def main( _check(Module) -def test_call_tir_inplace_with_tuple_var_raises_error(): +def test_call_tir_inplace_with_tuple_var_raises_error(): with pytest.raises(tvm.error.DiagnosticError): @tvm.script.ir_module class Module: @R.function - def main( - x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") - ) : + def main(x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")): cls = Module args = (x, y) res = R.call_tir_inplace( @@ -1062,7 +1060,7 @@ def main( # reference to a tuple. This error should be # caught and raised during parsing. args, - inplace_indices = [0, -1], + inplace_indices=[0, -1], out_sinfo=[R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")], ) return res @@ -1082,8 +1080,6 @@ def copy( out1[i, j] = B[i, j] - - def test_local_function(): @R.function def main(