From 0ad8af2feb08e322b8059d84a8cb5cea23fd5da6 Mon Sep 17 00:00:00 2001 From: sung Date: Sun, 19 Mar 2023 20:57:18 -0700 Subject: [PATCH] register the dispatch for runtime::Module --- src/script/printer/ir/ir.cc | 8 ++++ .../relax/test_tvmscript_printer_relax.py | 44 +++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index f23820927db6..ff6ed4a497a2 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -102,6 +102,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return HeaderWrapper(d, ClassDoc(module_doc, {IR(d, "ir_module")}, (*f)->stmts)); }); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", + [](runtime::Module rtmod, ObjectPath p, IRDocsifier d) -> Doc { + std::ostringstream oss; + oss << rtmod << ", " << rtmod.get(); + return LiteralDoc::Str(String(oss.str()), NullOpt); + }); + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](DictAttrs attrs, ObjectPath p, IRDocsifier d) -> Doc { return d->AsDoc(attrs->dict, p->Attr("dict")); diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index bffa741353a9..6114b0bfabc1 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -529,5 +529,49 @@ def foo(x: R.Tensor((128,), dtype="float32")) -> R.Tensor((128,), dtype="float32 ) +def test_runtime_module_in_irmodule_attrs(): + @I.ir_module + class TestModule: + @T.prim_func + def tir_func( + x: T.Buffer((T.int64(128),), "float32"), y: T.Buffer((T.int64(128),), "float32") + ): + T.evaluate(0) + + @R.function + def foo(x: R.Tensor((128,), "float32")) -> R.Tensor((128,), "float32"): + cls = TestModule + gv0 = R.call_tir(cls.tir_func, x, R.Tensor((128,), dtype="float32")) + return gv0 + + exec = relax.build(TestModule, "llvm") + NewTestModule = TestModule.with_attr("test", exec.mod) + # empty module alias + module_str = NewTestModule.script(module_alias="") + _assert_print( + module_str, + """ +# from tvm.script import ir as I +# from tvm.script import tir as T +# from tvm.script import relax as R + +@I.ir_module +class Module: + I.module_attrs({"test": "Module(type_key= relax.Executable), +""".rstrip() + + f" {exec.mod.handle.value:#x}".rstrip() + + """"}) + @T.prim_func + def tir_func(x: T.Buffer((T.int64(128),), "float32"), y: T.Buffer((T.int64(128),), "float32")): + T.evaluate(0) + + @R.function + def foo(x: R.Tensor((128,), dtype="float32")) -> R.Tensor((128,), dtype="float32"): + gv0 = R.call_tir(Module.tir_func, (x,), out_sinfo=R.Tensor((128,), dtype="float32")) + return gv0 + """, + ) + + if __name__ == "__main__": tvm.testing.main()