diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index f4fec04035fc..c65394f7b7e1 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -43,6 +43,11 @@ class PrinterConfigNode : public Object { std::string ir_prefix = "I"; /*! \brief The prefix of TIR nodes */ std::string tir_prefix = "T"; + /*! + * \brief The alias of the current module at cross-function call + * \note Directly use module name if it's empty. + */ + std::string module_alias = "cls"; /*! \brief Default data type of TIR buffer */ DataType buffer_dtype = DataType::Float(32); /*! \brief Default data type of integer literals */ @@ -76,6 +81,8 @@ class PrinterConfigNode : public Object { v->Visit("binding_names", &binding_names); v->Visit("show_meta", &show_meta); v->Visit("ir_prefix", &ir_prefix); + v->Visit("tir_prefix", &tir_prefix); + v->Visit("module_alias", &module_alias); v->Visit("buffer_dtype", &buffer_dtype); v->Visit("int_dtype", &int_dtype); v->Visit("float_dtype", &float_dtype); diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 3c3fefb6d6c6..1c775b461e6e 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Common expressions data structures in the IR.""" +from numbers import Number + import tvm._ffi from ..runtime import Scriptable, const, convert @@ -86,6 +88,9 @@ def __call__(self, *args): from tvm import relay return relay.Call(self, args) + elif all(isinstance(x, (Number, PrimExpr)) for x in args): + return tvm.tir.call_tir(self, *args) + arg_types = [type(x) for x in args] raise RuntimeError( "Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types) diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 72858a202853..c253f61c3181 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -23,6 +23,7 @@ from tvm._ffi.base import TVMError from tvm.error import DiagnosticError +from tvm.ir import GlobalVar from . import dispatch, doc from .diagnostics import Diagnostics, Source @@ -504,10 +505,10 @@ def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=i _dispatch_wrapper(func)(self, node) post_func(self, node) - def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None: + def visit_tvm_declare_function(self, node: doc.FunctionDef) -> GlobalVar: token = self.get_dispatch_token(node) with self.with_dispatch_token(token): - _dispatch(self, "tvm_declare_function")(self, node) + return _dispatch(self, "tvm_declare_function")(self, node) def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name """The general class definition visiting method. diff --git a/python/tvm/script/parser/core/utils.py b/python/tvm/script/parser/core/utils.py index 6a693df12f89..3edae3f25a33 100644 --- a/python/tvm/script/parser/core/utils.py +++ b/python/tvm/script/parser/core/utils.py @@ -22,6 +22,30 @@ from .diagnostics import findsource +def get_func_nonlocals(func): + """A modified version of `inspect.getclosurevars`""" + + if inspect.ismethod(func): + func = func.__func__ + + if not inspect.isfunction(func): + raise TypeError("{!r} is not a Python function".format(func)) + + code = func.__code__ + # Nonlocal references are named in co_freevars and resolved + # by looking them up in __closure__ by positional index + nonlocal_vars = {} + if func.__closure__ is not None: + for var, cell in zip(code.co_freevars, func.__closure__): + try: + nonlocal_vars[var] = cell.cell_contents + except ValueError as err: + # cell_contents may raise ValueError if the cell is empty. + if "empty" not in str(err): + raise + return nonlocal_vars + + def inspect_function_capture(func: Callable) -> Dict[str, Any]: """Capture function non-locals and global variables. @@ -37,7 +61,7 @@ def inspect_function_capture(func: Callable) -> Dict[str, Any]: """ captured = { **func.__globals__, # type: ignore - **inspect.getclosurevars(func).nonlocals, + **get_func_nonlocals(func), } return captured diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index 201c99074f20..075ca0870341 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -20,6 +20,15 @@ from .._core import Parser, dispatch, doc +class ModuleWithGlobalVars: + """A Module that can add global vars during parsing, to support `Module.function` syntax.""" + + def __getattr__(self, attr): + # Customize the error message. + # NOTE: `__getattr__` is only called when the attribute access fails with an AttributeError + raise AttributeError(f"Cannot find the function `{attr}` in the current IRModule") + + @dispatch.register(token="ir", type_name="ClassDef") def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: """The class definition visiting method for ir module. @@ -35,13 +44,25 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: with self.var_table.with_frame(): with I.ir_module(): + # Step 0. Add the class name to the var table + fake_module = ModuleWithGlobalVars() + self.var_table.add(node.name, fake_module) + + # Step 1. Visit non-function stmts, including but not limited to + # 1. `I.module_attrs` + # 2. `I.module_global_infos` with self.with_dispatch_token("ir"): for stmt in node.body: if not isinstance(stmt, doc.FunctionDef): self.visit(stmt) + + # Step 2. Visit function stmts to declare the global vars for stmt in node.body: if isinstance(stmt, doc.FunctionDef): - self.visit_tvm_declare_function(stmt) + global_var = self.visit_tvm_declare_function(stmt) + fake_module.__setattr__(stmt.name, global_var) + + # Step 3. Visit and parse the functions with self.with_dispatch_token("ir"): for stmt in node.body: if isinstance(stmt, doc.FunctionDef): diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 0a489a8f0401..dfecaacdf655 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -21,7 +21,7 @@ from typing import Any import tvm -from tvm.ir import PrimType +from tvm.ir import GlobalVar, PrimType from tvm.tir import Buffer, IterVar, PrimExpr, Var from ...ir_builder import ir as I @@ -473,7 +473,7 @@ def visit_return(self: Parser, node: doc.Return) -> None: @dispatch.register(token="tir", type_name="tvm_declare_function") -def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: +def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar: """The function declaration step for tir Parameters @@ -493,5 +493,4 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: # Only ret_type is needed for func_signature. func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type) - global_var = I.decl_function(node.name, func_signature) - self.var_table.add(node.name, global_var) + return I.decl_function(node.name, func_signature) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 10e75b915129..6583af6e79e6 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -44,7 +44,7 @@ from .function import PrimFunc, TensorIntrin, IndexMap -from .op import call_packed_lowered, call_cpacked_lowered +from .op import call_packed_lowered, call_cpacked_lowered, call_tir from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace from .op import tvm_check_return diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 419ab2275858..90e3db4cb96b 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -436,6 +436,18 @@ def undef(): return call_intrin("int32", "tir.undef") +def call_tir(global_var: tvm.ir.GlobalVar, *args): + """Performs a call into another PrimFunc in the same IRModule + + Returns + ------- + call : PrimExpr + The call expression. + """ + assert isinstance(global_var, tvm.ir.GlobalVar) + return Call(dtype="handle", op=global_var, args=args) + + def start_profile_intrinsic(id): """Start profile intrinsic. Parameters diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 87e7bfbcd9d2..7b6da4230532 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -64,11 +64,23 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) std::sort(functions.begin(), functions.end()); With f(d); (*f)->AddDispatchToken(d, "ir"); + IdDoc module_doc = d->Define(mod, f(), GetBindingName(d).value_or("Module")); if (mod->attrs.defined() && !mod->attrs->dict.empty()) { (*f)->stmts.push_back( ExprStmtDoc(IR(d, "module_attrs") // ->Call({d->AsDoc(mod->attrs, p->Attr("attrs"))}))); } + + // Declare GlobalVars first + IdDoc module_alias = d->cfg->module_alias.empty() ? module_doc : IdDoc(d->cfg->module_alias); + for (const auto& entry : functions) { + const GlobalVar& gv = entry.gv; + d->Define(gv, f(), [=]() { + return d->AsDoc(mod, p->Attr("global_vars"))->Attr(gv->name_hint); + }); + } + // Print functions + for (const auto& entry : functions) { const GlobalVar& gv = entry.gv; const BaseFunc& func = entry.func; @@ -84,8 +96,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) (*f)->stmts.push_back(Downcast(doc)); } } - return HeaderWrapper(d, ClassDoc(IdDoc(GetBindingName(d).value_or("Module")), - {IR(d, "ir_module")}, (*f)->stmts)); + return HeaderWrapper(d, ClassDoc(module_doc, {IR(d, "ir_module")}, (*f)->stmts)); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index db99c24886bf..8de142f8613e 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -250,8 +250,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) dtype_print_location = static_cast(dtype_locations[op].IntValue()); } - } else if (const auto* gv = call->op.as()) { - prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op")); + } else if (call->op.as()) { + prefix = d->AsDoc(call->op, call_p->Attr("op")); } else { LOG(FATAL) << "call: " << call; } @@ -261,6 +261,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (dtype_print_location == tir::ScriptDtypePrintLocation::kFirst) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } + for (int i = 0; i < n_args; ++i) { args.push_back(d->AsDoc(call->args[i], call_p->Attr("args")->ArrayIndex(i))); } diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index f40d7818d7e1..a8445f23df6c 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -176,6 +176,26 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintTIR); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "tir", [](tvm::GlobalVar n, ObjectPath n_p, IRDocsifier d) -> Doc { // + if (Optional doc = d->GetVarDoc(n)) { + return doc.value(); + } else { + IdDoc ret(n->name_hint); + ret->source_paths.push_back(n_p); + return ret; + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "tir", [](tvm::IRModule mod, ObjectPath n_p, IRDocsifier d) -> Doc { // + Optional doc = d->GetVarDoc(mod); + ICHECK(doc) << "Unable to print IRModule before definition in TIR."; + return doc.value(); + }); + } // namespace printer } // namespace script } // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 757f74ab8396..7eee6013589d 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3792,6 +3792,22 @@ def nested_seqstmt(): return func +def subroutine_call(): + """A GlobalVar may reference other functions in the module""" + + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + mod.subroutine(A.data, T.int32(16)) + + @T.prim_func + def subroutine(A_data: T.handle("float32"), n: T.int32): + T.evaluate(0) + + return mod + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -3861,6 +3877,7 @@ def nested_seqstmt(): tvm_struct_set_generated_in_cpp, ir_module_with_attrs, nested_seqstmt, + subroutine_call, )