diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 02dd941080db..e844b4abded3 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -27,9 +27,16 @@ from . import _ffi_api from .expr import Tuple as rx_Tuple from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor +from .expr import _update_struct_info from ..te import Tensor as te_Tensor, create_relax_prim_func from ..ir import Array, Attrs, Type, Map -from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo +from .struct_info import ( + PrimStructInfo, + ShapeStructInfo, + TensorStructInfo, + FuncStructInfo, + TupleStructInfo, +) def metadata_partitioner(rx_txt: str) -> List[str]: @@ -455,10 +462,19 @@ def _shape_with_old_tir_var( # with old set of variables. tir_var_inverse_map = {v: k for k, v in tir_var_map.items()} - output_sinfo = [ - TensorStructInfo(_shape_with_old_tir_var(out.shape, tir_var_inverse_map), out.dtype) - for out in outs - ] + def te_to_sinfo(arg): + if isinstance(arg, tir.Var): + return PrimStructInfo(arg.dtype) + else: + return TensorStructInfo( + _shape_with_old_tir_var(arg.shape, tir_var_inverse_map), arg.dtype + ) + + input_sinfo = [te_to_sinfo(arg) for arg in [*te_args, *unbound_tir_vars]] + output_sinfo = [te_to_sinfo(out) for out in outs] + + primfunc_sinfo = FuncStructInfo([*input_sinfo, *output_sinfo], PrimStructInfo("void")) + _update_struct_info(tir_func, primfunc_sinfo) tir_vars = None if len(unbound_tir_vars) > 0: