diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 756d250c1687..03e86a4633a6 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # pylint: disable=redefined-builtin """The base Relax operators.""" + from typing import Dict, Union, List, Tuple, Optional, Callable @@ -25,7 +26,6 @@ from . import _ffi_api from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var -from ..expr import Tuple as RxTuple from ..struct_info import StructInfo, TensorStructInfo from ...ir import PrimExpr from ..utils import args_converter @@ -67,6 +67,29 @@ def null_value() -> Call: return _ffi_api.null_value() # type: ignore +def _wrap_inline_arg_tuple(args) -> Expr: + """Helper function to wrap argument tuple + + Normalize the arguments provided the functions that accept a tuple + of arguments, and require the tuple of arguments to be written + in-line. If the arguments provided are a single relax expression, + and are not a reference to a relax tuple, then wrap them into an + in-line relax Tuple. + + """ + if ( + isinstance(args, Expr) + and not isinstance(args, tvm.relax.Tuple) + and ( + args.struct_info_ is None + or not isinstance(args.struct_info_, tvm.relax.TupleStructInfo) + ) + ): + return tvm.relax.Tuple([args]) + else: + return args + + @args_converter.auto def call_tir( gvar: GlobalVar, @@ -98,8 +121,7 @@ def call_tir( ret: Call A call node for the call_tir operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _wrap_inline_arg_tuple(args) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] @@ -153,8 +175,7 @@ def call_tir_with_grad( ret: Call A call node for the call_tir_with_grad operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _wrap_inline_arg_tuple(args) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] @@ -221,8 +242,7 @@ def call_tir_inplace( ret: Call A call node for the call_tir operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _wrap_inline_arg_tuple(args) if not isinstance(inplace_indices, list): inplace_indices = [inplace_indices] @@ -276,8 +296,7 @@ def call_dps_packed( if isinstance(func, str): func = ExternFunc(func) - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _wrap_inline_arg_tuple(args) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 4f41b662caf2..ea99d49270a1 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1044,6 +1044,42 @@ def main( _check(Module) +def test_call_tir_inplace_with_tuple_var_raises_error(): + + with pytest.raises(tvm.error.DiagnosticError): + + @tvm.script.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")): + cls = Module + args = (x, y) + res = R.call_tir_inplace( + cls.copy, + # The `args` tuple must be an in-line tuple, not a + # reference to a tuple. This error should be + # caught and raised during parsing. + args, + inplace_indices=[0, -1], + out_sinfo=[R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")], + ) + return res + + @T.prim_func + def copy( + A: T.Buffer((2, 3), "int32"), + B: T.Buffer((2, 3), "int32"), + out1: T.Buffer((2, 3), "int32"), + ): + # copies the contents of B into A and out1 + T.func_attr({"tir.noalias": True}) + for iters in T.grid(T.int64(2), T.int64(3)): + with T.block("T_zeros"): + i, j = T.axis.remap("SS", iters) + A[i, j] = B[i, j] + out1[i, j] = B[i, j] + + def test_local_function(): @R.function def main(