Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/tvm/node/script_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ class PrinterConfigNode : public Object {
std::string ir_prefix = "I";
/*! \brief The prefix of TIR nodes */
std::string tir_prefix = "T";
/*!
* \brief The alias of the current module at cross-function call
* \note Directly use module name if it's empty.
*/
std::string module_alias = "cls";
/*! \brief Default data type of TIR buffer */
DataType buffer_dtype = DataType::Float(32);
/*! \brief Default data type of integer literals */
Expand Down Expand Up @@ -76,6 +81,8 @@ class PrinterConfigNode : public Object {
v->Visit("binding_names", &binding_names);
v->Visit("show_meta", &show_meta);
v->Visit("ir_prefix", &ir_prefix);
v->Visit("tir_prefix", &tir_prefix);
v->Visit("module_alias", &module_alias);
v->Visit("buffer_dtype", &buffer_dtype);
v->Visit("int_dtype", &int_dtype);
v->Visit("float_dtype", &float_dtype);
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Common expressions data structures in the IR."""
from numbers import Number

import tvm._ffi

from ..runtime import Scriptable, const, convert
Expand Down Expand Up @@ -86,6 +88,9 @@ def __call__(self, *args):
from tvm import relay

return relay.Call(self, args)
elif all(isinstance(x, (Number, PrimExpr)) for x in args):
return tvm.tir.call_tir(self, *args)

arg_types = [type(x) for x in args]
raise RuntimeError(
"Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types)
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tvm._ffi.base import TVMError

from tvm.error import DiagnosticError
from tvm.ir import GlobalVar

from . import dispatch, doc
from .diagnostics import Diagnostics, Source
Expand Down Expand Up @@ -504,10 +505,10 @@ def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=i
_dispatch_wrapper(func)(self, node)
post_func(self, node)

def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None:
def visit_tvm_declare_function(self, node: doc.FunctionDef) -> GlobalVar:
token = self.get_dispatch_token(node)
with self.with_dispatch_token(token):
_dispatch(self, "tvm_declare_function")(self, node)
return _dispatch(self, "tvm_declare_function")(self, node)

def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name
"""The general class definition visiting method.
Expand Down
26 changes: 25 additions & 1 deletion python/tvm/script/parser/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,30 @@
from .diagnostics import findsource


def get_func_nonlocals(func):
"""A modified version of `inspect.getclosurevars`"""

if inspect.ismethod(func):
func = func.__func__

if not inspect.isfunction(func):
raise TypeError("{!r} is not a Python function".format(func))

code = func.__code__
# Nonlocal references are named in co_freevars and resolved
# by looking them up in __closure__ by positional index
nonlocal_vars = {}
if func.__closure__ is not None:
for var, cell in zip(code.co_freevars, func.__closure__):
try:
nonlocal_vars[var] = cell.cell_contents
except ValueError as err:
# cell_contents may raise ValueError if the cell is empty.
if "empty" not in str(err):
raise
return nonlocal_vars


def inspect_function_capture(func: Callable) -> Dict[str, Any]:
"""Capture function non-locals and global variables.

Expand All @@ -37,7 +61,7 @@ def inspect_function_capture(func: Callable) -> Dict[str, Any]:
"""
captured = {
**func.__globals__, # type: ignore
**inspect.getclosurevars(func).nonlocals,
**get_func_nonlocals(func),
}
return captured

Expand Down
23 changes: 22 additions & 1 deletion python/tvm/script/parser/ir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@
from .._core import Parser, dispatch, doc


class ModuleWithGlobalVars:
"""A Module that can add global vars during parsing, to support `Module.function` syntax."""

def __getattr__(self, attr):
# Customize the error message.
# NOTE: `__getattr__` is only called when the attribute access fails with an AttributeError
raise AttributeError(f"Cannot find the function `{attr}` in the current IRModule")


@dispatch.register(token="ir", type_name="ClassDef")
def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
"""The class definition visiting method for ir module.
Expand All @@ -35,13 +44,25 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:

with self.var_table.with_frame():
with I.ir_module():
# Step 0. Add the class name to the var table
fake_module = ModuleWithGlobalVars()
self.var_table.add(node.name, fake_module)

# Step 1. Visit non-function stmts, including but not limited to
# 1. `I.module_attrs`
# 2. `I.module_global_infos`
with self.with_dispatch_token("ir"):
for stmt in node.body:
if not isinstance(stmt, doc.FunctionDef):
self.visit(stmt)

# Step 2. Visit function stmts to declare the global vars
for stmt in node.body:
if isinstance(stmt, doc.FunctionDef):
self.visit_tvm_declare_function(stmt)
global_var = self.visit_tvm_declare_function(stmt)
fake_module.__setattr__(stmt.name, global_var)

# Step 3. Visit and parse the functions
with self.with_dispatch_token("ir"):
for stmt in node.body:
if isinstance(stmt, doc.FunctionDef):
Expand Down
7 changes: 3 additions & 4 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import Any

import tvm
from tvm.ir import PrimType
from tvm.ir import GlobalVar, PrimType
from tvm.tir import Buffer, IterVar, PrimExpr, Var

from ...ir_builder import ir as I
Expand Down Expand Up @@ -473,7 +473,7 @@ def visit_return(self: Parser, node: doc.Return) -> None:


@dispatch.register(token="tir", type_name="tvm_declare_function")
def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None:
def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar:
"""The function declaration step for tir

Parameters
Expand All @@ -493,5 +493,4 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None:

# Only ret_type is needed for func_signature.
func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type)
global_var = I.decl_function(node.name, func_signature)
self.var_table.add(node.name, global_var)
return I.decl_function(node.name, func_signature)
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

from .function import PrimFunc, TensorIntrin, IndexMap

from .op import call_packed_lowered, call_cpacked_lowered
from .op import call_packed_lowered, call_cpacked_lowered, call_tir
from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
from .op import tvm_check_return
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,18 @@ def undef():
return call_intrin("int32", "tir.undef")


def call_tir(global_var: tvm.ir.GlobalVar, *args):
"""Performs a call into another PrimFunc in the same IRModule

Returns
-------
call : PrimExpr
The call expression.
"""
assert isinstance(global_var, tvm.ir.GlobalVar)
return Call(dtype="handle", op=global_var, args=args)


def start_profile_intrinsic(id):
"""Start profile intrinsic.
Parameters
Expand Down
15 changes: 13 additions & 2 deletions src/script/printer/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,23 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
std::sort(functions.begin(), functions.end());
With<IRFrame> f(d);
(*f)->AddDispatchToken(d, "ir");
IdDoc module_doc = d->Define(mod, f(), GetBindingName(d).value_or("Module"));
if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
(*f)->stmts.push_back(
ExprStmtDoc(IR(d, "module_attrs") //
->Call({d->AsDoc<ExprDoc>(mod->attrs, p->Attr("attrs"))})));
}

// Declare GlobalVars first
IdDoc module_alias = d->cfg->module_alias.empty() ? module_doc : IdDoc(d->cfg->module_alias);
for (const auto& entry : functions) {
const GlobalVar& gv = entry.gv;
d->Define(gv, f(), [=]() {
return d->AsDoc<ExprDoc>(mod, p->Attr("global_vars"))->Attr(gv->name_hint);
});
}
// Print functions

for (const auto& entry : functions) {
const GlobalVar& gv = entry.gv;
const BaseFunc& func = entry.func;
Expand All @@ -84,8 +96,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
(*f)->stmts.push_back(Downcast<FunctionDoc>(doc));
}
}
return HeaderWrapper(d, ClassDoc(IdDoc(GetBindingName(d).value_or("Module")),
{IR(d, "ir_module")}, (*f)->stmts));
return HeaderWrapper(d, ClassDoc(module_doc, {IR(d, "ir_module")}, (*f)->stmts));
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Expand Down
5 changes: 3 additions & 2 deletions src/script/printer/tir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
dtype_print_location =
static_cast<tir::ScriptDtypePrintLocation>(dtype_locations[op].IntValue());
}
} else if (const auto* gv = call->op.as<GlobalVarNode>()) {
prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op"));
} else if (call->op.as<GlobalVarNode>()) {
prefix = d->AsDoc<ExprDoc>(call->op, call_p->Attr("op"));
} else {
LOG(FATAL) << "call: " << call;
}
Expand All @@ -261,6 +261,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (dtype_print_location == tir::ScriptDtypePrintLocation::kFirst) {
args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype")));
}

for (int i = 0; i < n_args; ++i) {
args.push_back(d->AsDoc<ExprDoc>(call->args[i], call_p->Attr("args")->ArrayIndex(i)));
}
Expand Down
20 changes: 20 additions & 0 deletions src/script/printer/tir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,26 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)

TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintTIR);

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tvm::GlobalVar>( //
"tir", [](tvm::GlobalVar n, ObjectPath n_p, IRDocsifier d) -> Doc { //
if (Optional<ExprDoc> doc = d->GetVarDoc(n)) {
return doc.value();
} else {
IdDoc ret(n->name_hint);
ret->source_paths.push_back(n_p);
return ret;
}
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tvm::IRModule>( //
"tir", [](tvm::IRModule mod, ObjectPath n_p, IRDocsifier d) -> Doc { //
Optional<ExprDoc> doc = d->GetVarDoc(mod);
ICHECK(doc) << "Unable to print IRModule before definition in TIR.";
return doc.value();
});

} // namespace printer
} // namespace script
} // namespace tvm
17 changes: 17 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3792,6 +3792,22 @@ def nested_seqstmt():
return func


def subroutine_call():
"""A GlobalVar may reference other functions in the module"""

@I.ir_module
class mod:
@T.prim_func
def main(A: T.Buffer(16, "float32")):
mod.subroutine(A.data, T.int32(16))

@T.prim_func
def subroutine(A_data: T.handle("float32"), n: T.int32):
T.evaluate(0)

return mod


ir_generator = tvm.testing.parameter(
launch_env_thread,
opt_gemm_normalize,
Expand Down Expand Up @@ -3861,6 +3877,7 @@ def nested_seqstmt():
tvm_struct_set_generated_in_cpp,
ir_module_with_attrs,
nested_seqstmt,
subroutine_call,
)


Expand Down