From 447706b5e7613d2a80a1feacaec8c3ffcbc63922 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 13 Dec 2023 16:43:00 +0000 Subject: [PATCH 1/3] [Unity][TVMScript] Optionally hide StructInfo that can be inferred By default, TVMScript prints the struct info of every variable being bound, which can become quite verbose. This commit adds the configuration option `show_inferable_type_annotations`, which determines whether struct info annotations are shown in cases where they can be inferred. The `show_inferable_type_annotations` option defaults to `True`, preserving the current default behavior. --- include/tvm/node/script_printer.h | 3 ++ python/tvm/runtime/script_printer.py | 21 +++++++-- src/node/script_printer.cc | 3 ++ src/script/printer/relax/binding.cc | 5 ++- src/script/printer/relax/utils.h | 37 ++++++++++++++++ .../relax/test_tvmscript_printer_relax.py | 43 +++++++++++++++++++ 6 files changed, 107 insertions(+), 5 deletions(-) diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index 703844d99ae9..515a5f8d1215 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -72,6 +72,8 @@ class PrinterConfigNode : public Object { bool syntax_sugar = true; /*! \brief Whether variable names should include the object's address */ bool show_object_address = false; + /*! \brief Whether to show StructInfo that can be inferred from arguments */ + bool show_inferable_type_annotations = true; /* \brief Object path to be underlined */ Array path_to_underline = Array(); /*! \brief Object path to be annotated. */ @@ -97,6 +99,7 @@ class PrinterConfigNode : public Object { v->Visit("num_context_lines", &num_context_lines); v->Visit("syntax_sugar", &syntax_sugar); v->Visit("show_object_address", &show_object_address); + v->Visit("show_inferable_type_annotations", &show_inferable_type_annotations); v->Visit("path_to_underline", &path_to_underline); v->Visit("path_to_annotate", &path_to_annotate); v->Visit("obj_to_underline", &obj_to_underline); diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 260d0ead9d8e..c5407fd59483 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -44,6 +44,7 @@ class PrinterConfig(Object): num_context_lines: int syntax_sugar: bool show_object_address: bool + show_inferable_type_annotations: bool path_to_underline: Optional[List[ObjectPath]] path_to_annotate: Optional[Dict[ObjectPath, str]] obj_to_underline: Optional[List[Object]] @@ -67,6 +68,7 @@ def __init__( num_context_lines: Optional[int] = None, syntax_sugar: bool = True, show_object_address: bool = False, + show_inferable_type_annotations: bool = True, path_to_underline: Optional[List[ObjectPath]] = None, path_to_annotate: Optional[Dict[ObjectPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, @@ -89,6 +91,7 @@ def __init__( "num_context_lines": num_context_lines, "syntax_sugar": syntax_sugar, "show_object_address": show_object_address, + "show_inferable_type_annotations": show_inferable_type_annotations, "path_to_underline": path_to_underline, "path_to_annotate": path_to_annotate, "obj_to_underline": obj_to_underline, @@ -132,6 +135,7 @@ def script( num_context_lines: int = -1, syntax_sugar: bool = True, show_object_address: bool = False, + show_inferable_type_annotations: bool = True, path_to_underline: Optional[List[ObjectPath]] = None, path_to_annotate: Optional[Dict[ObjectPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, @@ -169,9 +173,12 @@ def script( num_context_lines : int = -1 The number of lines of context to print before and after the line to underline. syntax_sugar: bool = True - Whether to output with syntax sugar, set false for complete printing. + Whether to output with syntax sugar, set false for complete printing. show_object_address: bool = False - Whether to include the object's address as part of the TVMScript name + Whether to include the object's address as part of the TVMScript name + show_inferable_type_annotations: bool = True + Whether to show type annotations that can be inferred from previous + annotations. path_to_underline : Optional[List[ObjectPath]] = None Object path to be underlined path_to_annotate : Optional[Dict[ObjectPath, str]] = None @@ -204,6 +211,7 @@ def script( num_context_lines=num_context_lines, syntax_sugar=syntax_sugar, show_object_address=show_object_address, + show_inferable_type_annotations=show_inferable_type_annotations, path_to_underline=path_to_underline, path_to_annotate=path_to_annotate, obj_to_underline=obj_to_underline, @@ -279,6 +287,7 @@ def show( num_context_lines: int = -1, syntax_sugar: bool = True, show_object_address: bool = False, + show_inferable_type_annotations: bool = True, path_to_underline: Optional[List[ObjectPath]] = None, path_to_annotate: Optional[Dict[ObjectPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, @@ -339,9 +348,12 @@ def show( num_context_lines : int = -1 The number of lines of context to print before and after the line to underline. syntax_sugar: bool = True - Whether to output with syntax sugar, set false for complete printing. + Whether to output with syntax sugar, set false for complete printing. show_object_address: bool = False - Whether to include the object's address as part of the TVMScript name + Whether to include the object's address as part of the TVMScript name + show_inferable_type_annotations: bool = True + Whether to show type annotations that can be inferred from previous + annotations. path_to_underline : Optional[List[ObjectPath]] = None Object path to be underlined path_to_annotate : Optional[Dict[ObjectPath, str]] = None @@ -377,6 +389,7 @@ def show( num_context_lines=num_context_lines, syntax_sugar=syntax_sugar, show_object_address=show_object_address, + show_inferable_type_annotations=show_inferable_type_annotations, path_to_underline=path_to_underline, path_to_annotate=path_to_annotate, obj_to_underline=obj_to_underline, diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index 38334de357d8..74c3fb730b45 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -112,6 +112,9 @@ PrinterConfig::PrinterConfig(Map config_dict) { if (auto v = config_dict.Get("show_object_address")) { n->show_object_address = Downcast(v)->value; } + if (auto v = config_dict.Get("show_inferable_type_annotations")) { + n->show_inferable_type_annotations = Downcast(v)->value; + } // Checking prefixes if they are valid Python identifiers. CHECK(IsIdentifier(n->ir_prefix)) << "Invalid `ir_prefix`: " << n->ir_prefix; diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc index 395b4251fb5e..efbdc9d8a5b7 100644 --- a/src/script/printer/relax/binding.cc +++ b/src/script/printer/relax/binding.cc @@ -44,7 +44,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) "", [](relax::MatchCast n, ObjectPath n_p, IRDocsifier d) -> Doc { using relax::StructInfo; using relax::MatchStructInfo; - Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + Optional ann = NullOpt; + if (d->cfg->show_inferable_type_annotations) { + ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + } ExprDoc rhs = Relax(d, "match_cast") ->Call({d->AsDoc(n->value, n_p->Attr("value")), d->AsDoc(n->struct_info, n_p->Attr("struct_info_"))}); diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index 58b8bf443173..1526277a0fce 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -19,6 +19,8 @@ #ifndef TVM_SCRIPT_PRINTER_RELAX_UTILS_H_ #define TVM_SCRIPT_PRINTER_RELAX_UTILS_H_ +#include +#include #include #include #include @@ -89,6 +91,41 @@ inline Optional StructInfoAsAnn(const relax::Var& v, const ObjectPath& return NullOpt; } } + if (!d->cfg->show_inferable_type_annotations) { + Optional inferred_sinfo = NullOpt; + if (auto opt = rhs.as()) { + auto call = opt.value(); + if (auto opt = call->op.as()) { + auto op = opt.value(); + + static auto op_map_infer_struct_info = + Op::GetAttrMap("FInferStructInfo"); + + auto temp_builder = relax::BlockBuilder::Create(NullOpt); + inferred_sinfo = op_map_infer_struct_info[op](call, temp_builder); + } else if (auto opt = call->op.as()) { + auto temp_builder = relax::BlockBuilder::Create(NullOpt); + inferred_sinfo = + DeriveCallRetStructInfo(opt.value(), call, temp_builder, temp_builder->GetAnalyzer()); + } + + } else if (const auto* tuple = rhs.as()) { + inferred_sinfo = relax::TupleStructInfo(tuple->fields.Map(relax::GetStructInfo)); + + } else if (const auto* get_item = rhs.as()) { + if (auto ptr = get_item->tuple->struct_info_.as(); + ptr && get_item->index < static_cast(ptr->fields.size())) { + inferred_sinfo = ptr->fields[get_item->index]; + } + + } else if (const auto* trivial_binding = rhs.as()) { + inferred_sinfo = trivial_binding->struct_info_.as(); + } + + if (inferred_sinfo && StructuralEqual()(inferred_sinfo, v->struct_info_)) { + return NullOpt; + } + } return d->AsDoc(v->struct_info_, v_p->Attr("struct_info_")); } diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 530e45e61074..50b5e41598db 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -829,5 +829,48 @@ def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype ) +def test_hide_inferable_struct_info(): + @R.function + def func(A: R.Tensor([10, 20], "float32"), B: R.Tensor(ndim=2, dtype="float32")): + # R.match_cast has the struct info as an argument, so it can + # be omitted from the variable annotation. + B2 = R.match_cast(B, R.Tensor([10, 20], "float32")) + + # Call nodes may have inferable shapes from their arguments. + C = R.add(A, B2) + + # Trivial bindings can be inferred to have the same struct + # info as the RHS. + D = C + + # Here, the struct info cannot be omitted. `R.add(D,B)` has + # struct info `R.Tensor(ndim=2)`, but the variable has a shape + # `R.Tensor([10,20])`. This is compatible, so it is not an + # error to have this annotation, but it is not inferrable from + # the RHS. Therefore, it must still be printed. + E: R.Tensor([10, 20], "float32") = R.add(D, B) + + # The return type can be inferred from function body, but is + # still always printed in the TVMScript. When parsing an + # IRModule with functions calling each other, the return type + # of each callee must be available for use in the caller's + # shape inference. + return E + + _assert_print( + func.script(show_inferable_type_annotations=False), + """ +# from tvm.script import relax as R + +@R.function +def func(A: R.Tensor((10, 20), dtype="float32"), B: R.Tensor(dtype="float32", ndim=2)) -> R.Tensor((10, 20), dtype="float32"): + B2 = R.match_cast(B, R.Tensor((10, 20), dtype="float32")) + C = R.add(A, B2) + D = C + E: R.Tensor((10, 20), dtype="float32") = R.add(D, B) + return E""", + ) + + if __name__ == "__main__": tvm.testing.main() From 75282b67a9a5e5b7de33e64e04e5d5a3c8bab59d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 8 Jan 2024 14:20:03 +0000 Subject: [PATCH 2/3] Rename show_inferable_type_annotations to show_all_struct_info --- include/tvm/node/script_printer.h | 38 +++++++++++++++++-- python/tvm/runtime/script_printer.py | 29 +++++++------- src/node/script_printer.cc | 4 +- src/script/printer/relax/binding.cc | 2 +- src/script/printer/relax/utils.h | 2 +- .../relax/test_tvmscript_printer_relax.py | 8 +++- 6 files changed, 62 insertions(+), 21 deletions(-) diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index 515a5f8d1215..2b812eef32db 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -72,8 +72,40 @@ class PrinterConfigNode : public Object { bool syntax_sugar = true; /*! \brief Whether variable names should include the object's address */ bool show_object_address = false; - /*! \brief Whether to show StructInfo that can be inferred from arguments */ - bool show_inferable_type_annotations = true; + + /*! \brief In Relax, whether to show all StructInfo annotations + * + * If true (default), all variable bindings will be annotated with + * the struct info of the variable being bound. + * + * If false, the annotations will only be shown when they are + * required for correct parsing of the Relax function. For example, + * function parameters must always have struct info annotations, but + * the struct info for expressions within a function body may be inferred from their + * arguments, and are therefore + * + * Example: + * + * # func.show(show_all_struct_info=True) + * @R.function + * def func( + * A: R.Tensor((10, 20), dtype="float32"), + * B: R.Tensor((10,20), dtype="float32"), + * ) -> R.Tensor((10, 20), dtype="float32"): + * C: R.Tensor((10,20), dtype="float32") = R.add(A, B2) + * return C + * + * # func.show(show_all_struct_info=False) + * @R.function + * def func( + * A: R.Tensor((10, 20), dtype="float32"), + * B: R.Tensor((10,20), dtype="float32"), + * ) -> R.Tensor((10, 20), dtype="float32"): + * C = R.add(A, B2) + * return C + */ + bool show_all_struct_info = true; + /* \brief Object path to be underlined */ Array path_to_underline = Array(); /*! \brief Object path to be annotated. */ @@ -99,7 +131,7 @@ class PrinterConfigNode : public Object { v->Visit("num_context_lines", &num_context_lines); v->Visit("syntax_sugar", &syntax_sugar); v->Visit("show_object_address", &show_object_address); - v->Visit("show_inferable_type_annotations", &show_inferable_type_annotations); + v->Visit("show_all_struct_info", &show_all_struct_info); v->Visit("path_to_underline", &path_to_underline); v->Visit("path_to_annotate", &path_to_annotate); v->Visit("obj_to_underline", &obj_to_underline); diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index c5407fd59483..ad3f612c4e29 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -44,7 +44,7 @@ class PrinterConfig(Object): num_context_lines: int syntax_sugar: bool show_object_address: bool - show_inferable_type_annotations: bool + show_all_struct_info: bool path_to_underline: Optional[List[ObjectPath]] path_to_annotate: Optional[Dict[ObjectPath, str]] obj_to_underline: Optional[List[Object]] @@ -68,7 +68,7 @@ def __init__( num_context_lines: Optional[int] = None, syntax_sugar: bool = True, show_object_address: bool = False, - show_inferable_type_annotations: bool = True, + show_all_struct_info: bool = True, path_to_underline: Optional[List[ObjectPath]] = None, path_to_annotate: Optional[Dict[ObjectPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, @@ -91,7 +91,7 @@ def __init__( "num_context_lines": num_context_lines, "syntax_sugar": syntax_sugar, "show_object_address": show_object_address, - "show_inferable_type_annotations": show_inferable_type_annotations, + "show_all_struct_info": show_all_struct_info, "path_to_underline": path_to_underline, "path_to_annotate": path_to_annotate, "obj_to_underline": obj_to_underline, @@ -135,7 +135,7 @@ def script( num_context_lines: int = -1, syntax_sugar: bool = True, show_object_address: bool = False, - show_inferable_type_annotations: bool = True, + show_all_struct_info: bool = True, path_to_underline: Optional[List[ObjectPath]] = None, path_to_annotate: Optional[Dict[ObjectPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, @@ -176,9 +176,10 @@ def script( Whether to output with syntax sugar, set false for complete printing. show_object_address: bool = False Whether to include the object's address as part of the TVMScript name - show_inferable_type_annotations: bool = True - Whether to show type annotations that can be inferred from previous - annotations. + show_all_struct_info: bool = True + If True (default), annotate all variable bindings with the struct + info of that variable. If False, only add annotations where + required for unambiguous round-trip of Relax -> TVMScript -> Relax. path_to_underline : Optional[List[ObjectPath]] = None Object path to be underlined path_to_annotate : Optional[Dict[ObjectPath, str]] = None @@ -192,6 +193,7 @@ def script( ------- script : str The TVM Script of the given TVM IR + """ return _script( self, @@ -211,7 +213,7 @@ def script( num_context_lines=num_context_lines, syntax_sugar=syntax_sugar, show_object_address=show_object_address, - show_inferable_type_annotations=show_inferable_type_annotations, + show_all_struct_info=show_all_struct_info, path_to_underline=path_to_underline, path_to_annotate=path_to_annotate, obj_to_underline=obj_to_underline, @@ -287,7 +289,7 @@ def show( num_context_lines: int = -1, syntax_sugar: bool = True, show_object_address: bool = False, - show_inferable_type_annotations: bool = True, + show_all_struct_info: bool = True, path_to_underline: Optional[List[ObjectPath]] = None, path_to_annotate: Optional[Dict[ObjectPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, @@ -351,9 +353,10 @@ def show( Whether to output with syntax sugar, set false for complete printing. show_object_address: bool = False Whether to include the object's address as part of the TVMScript name - show_inferable_type_annotations: bool = True - Whether to show type annotations that can be inferred from previous - annotations. + show_all_struct_info: bool = True + If True (default), annotate all variable bindings with the struct + info of that variable. If False, only add annotations where + required for unambiguous round-trip of Relax -> TVMScript -> Relax. path_to_underline : Optional[List[ObjectPath]] = None Object path to be underlined path_to_annotate : Optional[Dict[ObjectPath, str]] = None @@ -389,7 +392,7 @@ def show( num_context_lines=num_context_lines, syntax_sugar=syntax_sugar, show_object_address=show_object_address, - show_inferable_type_annotations=show_inferable_type_annotations, + show_all_struct_info=show_all_struct_info, path_to_underline=path_to_underline, path_to_annotate=path_to_annotate, obj_to_underline=obj_to_underline, diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index 74c3fb730b45..6e7d82ee4a59 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -112,8 +112,8 @@ PrinterConfig::PrinterConfig(Map config_dict) { if (auto v = config_dict.Get("show_object_address")) { n->show_object_address = Downcast(v)->value; } - if (auto v = config_dict.Get("show_inferable_type_annotations")) { - n->show_inferable_type_annotations = Downcast(v)->value; + if (auto v = config_dict.Get("show_all_struct_info")) { + n->show_all_struct_info = Downcast(v)->value; } // Checking prefixes if they are valid Python identifiers. diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc index efbdc9d8a5b7..acf0072c0f45 100644 --- a/src/script/printer/relax/binding.cc +++ b/src/script/printer/relax/binding.cc @@ -45,7 +45,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) using relax::StructInfo; using relax::MatchStructInfo; Optional ann = NullOpt; - if (d->cfg->show_inferable_type_annotations) { + if (d->cfg->show_all_struct_info) { ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); } ExprDoc rhs = Relax(d, "match_cast") diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index 1526277a0fce..25c56087a742 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -91,7 +91,7 @@ inline Optional StructInfoAsAnn(const relax::Var& v, const ObjectPath& return NullOpt; } } - if (!d->cfg->show_inferable_type_annotations) { + if (!d->cfg->show_all_struct_info) { Optional inferred_sinfo = NullOpt; if (auto opt = rhs.as()) { auto call = opt.value(); diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 50b5e41598db..a75977ff9910 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -830,6 +830,12 @@ def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype def test_hide_inferable_struct_info(): + """Redundant type annotations can be omitted + + When `show_all_struct_info=False`, TVMScript type annotations that + provide redundant struct info can be omitted. + """ + @R.function def func(A: R.Tensor([10, 20], "float32"), B: R.Tensor(ndim=2, dtype="float32")): # R.match_cast has the struct info as an argument, so it can @@ -858,7 +864,7 @@ def func(A: R.Tensor([10, 20], "float32"), B: R.Tensor(ndim=2, dtype="float32")) return E _assert_print( - func.script(show_inferable_type_annotations=False), + func.script(show_all_struct_info=False), """ # from tvm.script import relax as R From f11725f75fff40f8dd46a2d2f42f6bac61e426a6 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 22 Jan 2024 19:10:11 +0000 Subject: [PATCH 3/3] Add unit test for round-trip of opaque function --- src/script/printer/relax/utils.h | 6 ++- .../tvmscript/test_tvmscript_roundtrip.py | 42 ++++++++++++++++++- 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index 25c56087a742..989e9a63b1d9 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -84,14 +84,16 @@ inline Optional StructInfoAsAnn(const relax::Var& v, const ObjectPath& if (!v->struct_info_.defined()) { return NullOpt; } + bool attempt_to_hide_struct_info = !d->cfg->show_all_struct_info; + if (const auto* call = rhs.as()) { static const Op& call_tir_op = Op::Get("relax.call_tir"); static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); if (call->op.same_as(call_tir_op) || call->op.same_as(call_dps_packed_op)) { - return NullOpt; + attempt_to_hide_struct_info = true; } } - if (!d->cfg->show_all_struct_info) { + if (attempt_to_hide_struct_info) { Optional inferred_sinfo = NullOpt; if (auto opt = rhs.as()) { auto call = opt.value(); diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 5b3e68e22fa9..66eef5ad81c8 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -21,7 +21,7 @@ import tvm import tvm.testing from tvm import tir -from tvm.script import tir as T, ir as I +from tvm.script import tir as T, ir as I, relax as R import numpy as np @@ -3996,6 +3996,24 @@ def func(): yield make_ir_generator(op, arg) +def relax_extern_func(): + @R.function + def func(A: R.Tensor([10, 20], "float32")): + func = R.ExternFunc("dummy_func") + + B: R.Tensor([10, 20], "float32") = R.call_dps_packed( + func, [A], out_sinfo=R.Tensor([10, 20], "float32") + ) + + C: R.Tensor(ndim=2, dtype="float32") = R.call_dps_packed( + func, [B], out_sinfo=R.Tensor([10, 20], "float32") + ) + + return C + + return func + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -4081,6 +4099,17 @@ def func(): *op_of_literal(), ) +relax_ir_generator = tvm.testing.parameter( + relax_extern_func, +) + +show_all_relax_struct_info = tvm.testing.parameter( + by_dict={ + "show_all_struct_info": True, + "hide_inferable_struct_info": False, + } +) + def test_roundtrip(ir_generator): original = ir_generator() @@ -4088,6 +4117,17 @@ def test_roundtrip(ir_generator): tvm.ir.assert_structural_equal(original, after_roundtrip, True) +def test_relax_roundtrip(relax_ir_generator, show_all_relax_struct_info): + original = relax_ir_generator() + after_roundtrip = tvm.script.from_source( + original.script( + show_meta=True, + show_all_struct_info=show_all_relax_struct_info, + ) + ) + tvm.ir.assert_structural_equal(original, after_roundtrip, True) + + def test_return_none_no_trailing_type(): func = return_none() script = func.script()