diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index 703844d99ae9..2b812eef32db 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -72,6 +72,40 @@ class PrinterConfigNode : public Object { bool syntax_sugar = true; /*! \brief Whether variable names should include the object's address */ bool show_object_address = false; + + /*! \brief In Relax, whether to show all StructInfo annotations + * + * If true (default), all variable bindings will be annotated with + * the struct info of the variable being bound. + * + * If false, the annotations will only be shown when they are + * required for correct parsing of the Relax function. For example, + * function parameters must always have struct info annotations, but + * the struct info for expressions within a function body may be inferred from their + * arguments, and are therefore + * + * Example: + * + * # func.show(show_all_struct_info=True) + * @R.function + * def func( + * A: R.Tensor((10, 20), dtype="float32"), + * B: R.Tensor((10,20), dtype="float32"), + * ) -> R.Tensor((10, 20), dtype="float32"): + * C: R.Tensor((10,20), dtype="float32") = R.add(A, B2) + * return C + * + * # func.show(show_all_struct_info=False) + * @R.function + * def func( + * A: R.Tensor((10, 20), dtype="float32"), + * B: R.Tensor((10,20), dtype="float32"), + * ) -> R.Tensor((10, 20), dtype="float32"): + * C = R.add(A, B2) + * return C + */ + bool show_all_struct_info = true; + /* \brief Object path to be underlined */ Array path_to_underline = Array(); /*! \brief Object path to be annotated. */ @@ -97,6 +131,7 @@ class PrinterConfigNode : public Object { v->Visit("num_context_lines", &num_context_lines); v->Visit("syntax_sugar", &syntax_sugar); v->Visit("show_object_address", &show_object_address); + v->Visit("show_all_struct_info", &show_all_struct_info); v->Visit("path_to_underline", &path_to_underline); v->Visit("path_to_annotate", &path_to_annotate); v->Visit("obj_to_underline", &obj_to_underline); diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 260d0ead9d8e..ad3f612c4e29 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -44,6 +44,7 @@ class PrinterConfig(Object): num_context_lines: int syntax_sugar: bool show_object_address: bool + show_all_struct_info: bool path_to_underline: Optional[List[ObjectPath]] path_to_annotate: Optional[Dict[ObjectPath, str]] obj_to_underline: Optional[List[Object]] @@ -67,6 +68,7 @@ def __init__( num_context_lines: Optional[int] = None, syntax_sugar: bool = True, show_object_address: bool = False, + show_all_struct_info: bool = True, path_to_underline: Optional[List[ObjectPath]] = None, path_to_annotate: Optional[Dict[ObjectPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, @@ -89,6 +91,7 @@ def __init__( "num_context_lines": num_context_lines, "syntax_sugar": syntax_sugar, "show_object_address": show_object_address, + "show_all_struct_info": show_all_struct_info, "path_to_underline": path_to_underline, "path_to_annotate": path_to_annotate, "obj_to_underline": obj_to_underline, @@ -132,6 +135,7 @@ def script( num_context_lines: int = -1, syntax_sugar: bool = True, show_object_address: bool = False, + show_all_struct_info: bool = True, path_to_underline: Optional[List[ObjectPath]] = None, path_to_annotate: Optional[Dict[ObjectPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, @@ -169,9 +173,13 @@ def script( num_context_lines : int = -1 The number of lines of context to print before and after the line to underline. syntax_sugar: bool = True - Whether to output with syntax sugar, set false for complete printing. + Whether to output with syntax sugar, set false for complete printing. show_object_address: bool = False - Whether to include the object's address as part of the TVMScript name + Whether to include the object's address as part of the TVMScript name + show_all_struct_info: bool = True + If True (default), annotate all variable bindings with the struct + info of that variable. If False, only add annotations where + required for unambiguous round-trip of Relax -> TVMScript -> Relax. path_to_underline : Optional[List[ObjectPath]] = None Object path to be underlined path_to_annotate : Optional[Dict[ObjectPath, str]] = None @@ -185,6 +193,7 @@ def script( ------- script : str The TVM Script of the given TVM IR + """ return _script( self, @@ -204,6 +213,7 @@ def script( num_context_lines=num_context_lines, syntax_sugar=syntax_sugar, show_object_address=show_object_address, + show_all_struct_info=show_all_struct_info, path_to_underline=path_to_underline, path_to_annotate=path_to_annotate, obj_to_underline=obj_to_underline, @@ -279,6 +289,7 @@ def show( num_context_lines: int = -1, syntax_sugar: bool = True, show_object_address: bool = False, + show_all_struct_info: bool = True, path_to_underline: Optional[List[ObjectPath]] = None, path_to_annotate: Optional[Dict[ObjectPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, @@ -339,9 +350,13 @@ def show( num_context_lines : int = -1 The number of lines of context to print before and after the line to underline. syntax_sugar: bool = True - Whether to output with syntax sugar, set false for complete printing. + Whether to output with syntax sugar, set false for complete printing. show_object_address: bool = False - Whether to include the object's address as part of the TVMScript name + Whether to include the object's address as part of the TVMScript name + show_all_struct_info: bool = True + If True (default), annotate all variable bindings with the struct + info of that variable. If False, only add annotations where + required for unambiguous round-trip of Relax -> TVMScript -> Relax. path_to_underline : Optional[List[ObjectPath]] = None Object path to be underlined path_to_annotate : Optional[Dict[ObjectPath, str]] = None @@ -377,6 +392,7 @@ def show( num_context_lines=num_context_lines, syntax_sugar=syntax_sugar, show_object_address=show_object_address, + show_all_struct_info=show_all_struct_info, path_to_underline=path_to_underline, path_to_annotate=path_to_annotate, obj_to_underline=obj_to_underline, diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index 38334de357d8..6e7d82ee4a59 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -112,6 +112,9 @@ PrinterConfig::PrinterConfig(Map config_dict) { if (auto v = config_dict.Get("show_object_address")) { n->show_object_address = Downcast(v)->value; } + if (auto v = config_dict.Get("show_all_struct_info")) { + n->show_all_struct_info = Downcast(v)->value; + } // Checking prefixes if they are valid Python identifiers. CHECK(IsIdentifier(n->ir_prefix)) << "Invalid `ir_prefix`: " << n->ir_prefix; diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc index 395b4251fb5e..acf0072c0f45 100644 --- a/src/script/printer/relax/binding.cc +++ b/src/script/printer/relax/binding.cc @@ -44,7 +44,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) "", [](relax::MatchCast n, ObjectPath n_p, IRDocsifier d) -> Doc { using relax::StructInfo; using relax::MatchStructInfo; - Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + Optional ann = NullOpt; + if (d->cfg->show_all_struct_info) { + ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + } ExprDoc rhs = Relax(d, "match_cast") ->Call({d->AsDoc(n->value, n_p->Attr("value")), d->AsDoc(n->struct_info, n_p->Attr("struct_info_"))}); diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index 58b8bf443173..989e9a63b1d9 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -19,6 +19,8 @@ #ifndef TVM_SCRIPT_PRINTER_RELAX_UTILS_H_ #define TVM_SCRIPT_PRINTER_RELAX_UTILS_H_ +#include +#include #include #include #include @@ -82,10 +84,47 @@ inline Optional StructInfoAsAnn(const relax::Var& v, const ObjectPath& if (!v->struct_info_.defined()) { return NullOpt; } + bool attempt_to_hide_struct_info = !d->cfg->show_all_struct_info; + if (const auto* call = rhs.as()) { static const Op& call_tir_op = Op::Get("relax.call_tir"); static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); if (call->op.same_as(call_tir_op) || call->op.same_as(call_dps_packed_op)) { + attempt_to_hide_struct_info = true; + } + } + if (attempt_to_hide_struct_info) { + Optional inferred_sinfo = NullOpt; + if (auto opt = rhs.as()) { + auto call = opt.value(); + if (auto opt = call->op.as()) { + auto op = opt.value(); + + static auto op_map_infer_struct_info = + Op::GetAttrMap("FInferStructInfo"); + + auto temp_builder = relax::BlockBuilder::Create(NullOpt); + inferred_sinfo = op_map_infer_struct_info[op](call, temp_builder); + } else if (auto opt = call->op.as()) { + auto temp_builder = relax::BlockBuilder::Create(NullOpt); + inferred_sinfo = + DeriveCallRetStructInfo(opt.value(), call, temp_builder, temp_builder->GetAnalyzer()); + } + + } else if (const auto* tuple = rhs.as()) { + inferred_sinfo = relax::TupleStructInfo(tuple->fields.Map(relax::GetStructInfo)); + + } else if (const auto* get_item = rhs.as()) { + if (auto ptr = get_item->tuple->struct_info_.as(); + ptr && get_item->index < static_cast(ptr->fields.size())) { + inferred_sinfo = ptr->fields[get_item->index]; + } + + } else if (const auto* trivial_binding = rhs.as()) { + inferred_sinfo = trivial_binding->struct_info_.as(); + } + + if (inferred_sinfo && StructuralEqual()(inferred_sinfo, v->struct_info_)) { return NullOpt; } } diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 530e45e61074..a75977ff9910 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -829,5 +829,54 @@ def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype ) +def test_hide_inferable_struct_info(): + """Redundant type annotations can be omitted + + When `show_all_struct_info=False`, TVMScript type annotations that + provide redundant struct info can be omitted. + """ + + @R.function + def func(A: R.Tensor([10, 20], "float32"), B: R.Tensor(ndim=2, dtype="float32")): + # R.match_cast has the struct info as an argument, so it can + # be omitted from the variable annotation. + B2 = R.match_cast(B, R.Tensor([10, 20], "float32")) + + # Call nodes may have inferable shapes from their arguments. + C = R.add(A, B2) + + # Trivial bindings can be inferred to have the same struct + # info as the RHS. + D = C + + # Here, the struct info cannot be omitted. `R.add(D,B)` has + # struct info `R.Tensor(ndim=2)`, but the variable has a shape + # `R.Tensor([10,20])`. This is compatible, so it is not an + # error to have this annotation, but it is not inferrable from + # the RHS. Therefore, it must still be printed. + E: R.Tensor([10, 20], "float32") = R.add(D, B) + + # The return type can be inferred from function body, but is + # still always printed in the TVMScript. When parsing an + # IRModule with functions calling each other, the return type + # of each callee must be available for use in the caller's + # shape inference. + return E + + _assert_print( + func.script(show_all_struct_info=False), + """ +# from tvm.script import relax as R + +@R.function +def func(A: R.Tensor((10, 20), dtype="float32"), B: R.Tensor(dtype="float32", ndim=2)) -> R.Tensor((10, 20), dtype="float32"): + B2 = R.match_cast(B, R.Tensor((10, 20), dtype="float32")) + C = R.add(A, B2) + D = C + E: R.Tensor((10, 20), dtype="float32") = R.add(D, B) + return E""", + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 5b3e68e22fa9..66eef5ad81c8 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -21,7 +21,7 @@ import tvm import tvm.testing from tvm import tir -from tvm.script import tir as T, ir as I +from tvm.script import tir as T, ir as I, relax as R import numpy as np @@ -3996,6 +3996,24 @@ def func(): yield make_ir_generator(op, arg) +def relax_extern_func(): + @R.function + def func(A: R.Tensor([10, 20], "float32")): + func = R.ExternFunc("dummy_func") + + B: R.Tensor([10, 20], "float32") = R.call_dps_packed( + func, [A], out_sinfo=R.Tensor([10, 20], "float32") + ) + + C: R.Tensor(ndim=2, dtype="float32") = R.call_dps_packed( + func, [B], out_sinfo=R.Tensor([10, 20], "float32") + ) + + return C + + return func + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -4081,6 +4099,17 @@ def func(): *op_of_literal(), ) +relax_ir_generator = tvm.testing.parameter( + relax_extern_func, +) + +show_all_relax_struct_info = tvm.testing.parameter( + by_dict={ + "show_all_struct_info": True, + "hide_inferable_struct_info": False, + } +) + def test_roundtrip(ir_generator): original = ir_generator() @@ -4088,6 +4117,17 @@ def test_roundtrip(ir_generator): tvm.ir.assert_structural_equal(original, after_roundtrip, True) +def test_relax_roundtrip(relax_ir_generator, show_all_relax_struct_info): + original = relax_ir_generator() + after_roundtrip = tvm.script.from_source( + original.script( + show_meta=True, + show_all_struct_info=show_all_relax_struct_info, + ) + ) + tvm.ir.assert_structural_equal(original, after_roundtrip, True) + + def test_return_none_no_trailing_type(): func = return_none() script = func.script()