From ccf665c855f347b6c365f54a7734d05a97fb08be Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 5 Jun 2023 12:12:32 -0500
Subject: [PATCH 1/5] [Unity] Provide FuncStructInfo from `bb.emit_te`

Prior to this commit, the PrimFunc generated by `bb.call_te` had no
struct info associated with it.  This commit updates
`gen_call_tir_inputs`, which converts a TE expression into a TIR
PrimFunc, to annotate the PrimFunc with `FuncStructInfo` representing
the input and output shapes.

Providing this functionality for PrimFuncs produced from TE is simpler
than for a general PrimFunc, as TE has well-defined input and output
tensors.
---
 python/tvm/relax/utils.py | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index 02dd941080db..39b640272fb5 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -27,9 +27,10 @@
 from . import _ffi_api
 from .expr import Tuple as rx_Tuple
 from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor
+from .expr import _update_struct_info
 from ..te import Tensor as te_Tensor, create_relax_prim_func
 from ..ir import Array, Attrs, Type, Map
-from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo
+from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo, FuncStructInfo


 def metadata_partitioner(rx_txt: str) -> List[str]:
@@ -455,10 +456,19 @@ def _shape_with_old_tir_var(
     # with old set of variables.
     tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}

-    output_sinfo = [
-        TensorStructInfo(_shape_with_old_tir_var(out.shape, tir_var_inverse_map), out.dtype)
-        for out in outs
-    ]
+    def te_to_sinfo(te_tensor):
+        return TensorStructInfo(
+            _shape_with_old_tir_var(te_tensor.shape, tir_var_inverse_map), te_tensor.dtype
+        )
+
+    input_sinfo = [te_to_sinfo(arg) for arg in te_args]
+    if len(outs) == 1:
+        output_sinfo = te_to_sinfo(outs[0])
+    else:
+        output_sinfo = TupleStructInfo(output_sinfo=[te_to_sinfo(out) for out in outs])
+
+    primfunc_sinfo = FuncStructInfo(input_sinfo, output_sinfo)
+    _update_struct_info(tir_func, primfunc_sinfo)

     tir_vars = None
     if len(unbound_tir_vars) > 0:

From 0cc5f49f1447070ca0680cd4f8b1412c2d20eb67 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 5 Jun 2023 13:10:37 -0500
Subject: [PATCH 2/5] Lint fix

---
 python/tvm/relax/utils.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index 39b640272fb5..0cf703713d57 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -456,10 +456,8 @@ def _shape_with_old_tir_var(
     # with old set of variables.
     tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}

-    def te_to_sinfo(te_tensor):
-        return TensorStructInfo(
-            _shape_with_old_tir_var(te_tensor.shape, tir_var_inverse_map), te_tensor.dtype
-        )
+    def te_to_sinfo(arg):
+        return TensorStructInfo(_shape_with_old_tir_var(arg.shape, tir_var_inverse_map), arg.dtype)

     input_sinfo = [te_to_sinfo(arg) for arg in te_args]
     if len(outs) == 1:

From 58e7a3c3459079d3ff64e40557024a879887545a Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 6 Jun 2023 10:05:26 -0500
Subject: [PATCH 3/5] Import TupleStructInfo, maintain return type for output_sinfo

---
 python/tvm/relax/utils.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index 0cf703713d57..3f4869ceadbe 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -30,7 +30,13 @@
 from .expr import _update_struct_info
 from ..te import Tensor as te_Tensor, create_relax_prim_func
 from ..ir import Array, Attrs, Type, Map
-from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo, FuncStructInfo
+from .struct_info import (
+    PrimStructInfo,
+    ShapeStructInfo,
+    TensorStructInfo,
+    FuncStructInfo,
+    TupleStructInfo,
+)


 def metadata_partitioner(rx_txt: str) -> List[str]:
@@ -460,12 +466,11 @@ def te_to_sinfo(arg):
         return TensorStructInfo(_shape_with_old_tir_var(arg.shape, tir_var_inverse_map), arg.dtype)

     input_sinfo = [te_to_sinfo(arg) for arg in te_args]
-    if len(outs) == 1:
-        output_sinfo = te_to_sinfo(outs[0])
-    else:
-        output_sinfo = TupleStructInfo(output_sinfo=[te_to_sinfo(out) for out in outs])
+    output_sinfo = [te_to_sinfo(out) for out in outs]

-    primfunc_sinfo = FuncStructInfo(input_sinfo, output_sinfo)
+    primfunc_sinfo = FuncStructInfo(
+        input_sinfo, output_sinfo[0] if len(output_sinfo) == 1 else TupleStructInfo(output_sinfo)
+    )
     _update_struct_info(tir_func, primfunc_sinfo)

     tir_vars = None

From 6e0777f22b262250218897c9c785f9cf694fa58c Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 6 Jun 2023 10:22:50 -0500
Subject: [PATCH 4/5] Use FuncStructInfo for output-passing style

The PrimFunc's GlobalVar is later used as the CallNode::op, and must
provide correct shape inference at that point.
---
 python/tvm/relax/utils.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index 3f4869ceadbe..c2e0f9638f7c 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -468,9 +468,7 @@ def te_to_sinfo(arg):
     input_sinfo = [te_to_sinfo(arg) for arg in te_args]
     output_sinfo = [te_to_sinfo(out) for out in outs]

-    primfunc_sinfo = FuncStructInfo(
-        input_sinfo, output_sinfo[0] if len(output_sinfo) == 1 else TupleStructInfo(output_sinfo)
-    )
+    primfunc_sinfo = FuncStructInfo([*input_sinfo, *output_sinfo], PrimStructInfo("void"))
     _update_struct_info(tir_func, primfunc_sinfo)

     tir_vars = None

From 74e207bf337f499a82c83e8d3eca245e55331a76 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 3 Jul 2023 08:18:52 -0500
Subject: [PATCH 5/5] Annotate non-tensor arguments with PrimStructInfo

---
 python/tvm/relax/utils.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index c2e0f9638f7c..e844b4abded3 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -463,9 +463,14 @@ def _shape_with_old_tir_var(
     tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}

     def te_to_sinfo(arg):
-        return TensorStructInfo(_shape_with_old_tir_var(arg.shape, tir_var_inverse_map), arg.dtype)
-
-    input_sinfo = [te_to_sinfo(arg) for arg in te_args]
+        if isinstance(arg, tir.Var):
+            return PrimStructInfo(arg.dtype)
+        else:
+            return TensorStructInfo(
+                _shape_with_old_tir_var(arg.shape, tir_var_inverse_map), arg.dtype
+            )
+
+    input_sinfo = [te_to_sinfo(arg) for arg in [*te_args, *unbound_tir_vars]]
     output_sinfo = [te_to_sinfo(out) for out in outs]

     primfunc_sinfo = FuncStructInfo([*input_sinfo, *output_sinfo], PrimStructInfo("void"))
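
Usage sketch (reviewer note, not part of the patch series): the snippet below is a minimal, hypothetical example of driving the updated code path through `relax.BlockBuilder.emit_te`. It assumes a TVM Unity build where `BlockBuilder.emit_te` is available; the `te_add` helper, the variable names, and the static shapes are illustrative only, not taken from the patches. With this series applied, the PrimFunc generated for the TE computation should be annotated with a `FuncStructInfo` whose parameters cover the input tensors, any unbound TIR variables, and the output tensors, returning void in the output-passing style described in patch 4.

from tvm import relax, te


def te_add(a, b):
    # Element-wise addition expressed in TE; emit_te lowers it to a TIR PrimFunc.
    return te.compute(a.shape, lambda *i: a(*i) + b(*i), name="add")


bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo([16, 32], "float32"))
y = relax.Var("y", relax.TensorStructInfo([16, 32], "float32"))
with bb.function("main", [x, y]):
    lv = bb.emit_te(te_add, x, y)
    bb.emit_func_output(lv)

# Printing the module shows the Relax "main" function calling the generated
# PrimFunc; with this series, that PrimFunc should carry the FuncStructInfo
# used for shape inference when its GlobalVar appears as CallNode::op.
print(bb.get())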