diff --git a/python/tvm/script/parser/_core.py b/python/tvm/script/parser/_core.py index 4f5411dc368f..16e9dca190c2 100644 --- a/python/tvm/script/parser/_core.py +++ b/python/tvm/script/parser/_core.py @@ -18,5 +18,5 @@ # pylint: disable=unused-import from .core import dispatch, doc, utils from .core.dispatch import OpMethod, register_op -from .core.entry import parse +from .core.entry import gen_ast, parse from .core.parser import Parser diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 5315c0f6755e..9be6aec73c24 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -25,6 +25,13 @@ from .parser import Parser +def gen_ast(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any: + # Simply generate the AST. It will be parsed at the time of inclusion. + source = Source(program) + node = source.as_ast() + return node + + def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any: """Register a method for a operand type, AST operator node and operand index. diff --git a/python/tvm/script/parser/tir/__init__.py b/python/tvm/script/parser/tir/__init__.py index ad16821a89a3..6f34cdb11043 100644 --- a/python/tvm/script/parser/tir/__init__.py +++ b/python/tvm/script/parser/tir/__init__.py @@ -30,6 +30,6 @@ # so most tvmscript won't trigger pylint error here. prim_func = staticmethod else: - from .entry import prim_func + from .entry import prim_func, macro, include -__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"] +__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func", "macro", "include"] diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index d5bff7a856d5..bbcdbf628183 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -16,13 +16,13 @@ # under the License. """The entry point of TVM parser for tir.""" import inspect -from typing import Callable, Union +from typing import Any, Callable, Union from tvm.ir.base import deprecated from tvm.tir import Buffer, PrimFunc from ...ir_builder.tir import buffer, ptr -from .._core import parse, utils +from .._core import doc, gen_ast, parse, utils def prim_func(func: Callable) -> Union[PrimFunc, Callable]: @@ -50,6 +50,24 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]: setattr(prim_func, "dispatch_token", "tir") +def macro(func: Callable) -> doc.AST: + f = gen_ast(func, utils.inspect_function_capture(func)) + setattr(f, "__name__", func.__name__) + # We don't need to explicitly store the return value anywhere. + # This function is a decorator, so the return value will replace + # the function definition (to which the decorator it is applied) + # in that function's name space. + return f + + +# There is no dispatch_token for macro, because macro doesn't invoke parser. + + +def include(name: Union[str, doc.Name], *args, **kwargs) -> Any: + """Placeholder function, so that T.include can be parsed without errors.""" + pass + + class BufferProxy: """Buffer proxy class for constructing tir buffer.""" diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index f81f9bd9ea78..388439cd0cc9 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -16,9 +16,10 @@ # under the License. """The base parser for tir""" +import ast import contextlib from functools import partial -from typing import Any +from typing import Any, Union import tvm from tvm.ir import GlobalVar, PrimType @@ -427,6 +428,19 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: node : doc.Expr The doc AST Expr node. """ + def is_include_macro(node: doc.Call) -> bool: + if not isinstance(node.func, doc.Attribute): + return False + attr = node.func + if not isinstance(attr.value, doc.Name): + return False + if attr.value.id != "T" or attr.attr != "include": + return False + return True + + if isinstance(node.value, doc.Call) and is_include_macro(node.value): + return process_include_macro(self, node.value) + res = self.eval_expr(node.value) if res is None: pass @@ -528,3 +542,73 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar # Only ret_type is needed for func_signature. func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type) return I.decl_function(node.name, func_signature) + + +def process_include_macro(self: Parser, call: doc.Call) -> None: + def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]: + for decl in decl_list: + if isinstance(decl, doc.FunctionDef) and decl.name == name: + return decl + return None + + macro_name = call.args[0] + + if not isinstance(macro_name, str): + if not isinstance(macro_name, doc.Name): + self.report_error(node, "Invalid macro name in T.include") + macro_name = macro_name.id + + macro = self.var_table.get().get(macro_name) + if macro is None: + self.report_error(node, f"Undefined macro '{macro_name}'") + + if isinstance(macro, doc.Module): + macro_def = find_macro_def(macro_name, macro.body) + elif not isinstance(macro, doc.FunctionDef) or macro.name != macro_name: + macro_def = None + + if macro_def is None: + self.report_error(macro, f"Undefined macro {macro_name}") + + # `macro_def` is a FunctionDef of the macro. + + # We have the AST for the macro definition, and the AST for the call. We need to + # substitute the actual arguments from the call for the parameters from the + # definition. To allow full flexibility of python, i.e. positional, unnamed, and + # keyword parameters, get the python interpreter to do the work: create and execute + # the following: + # ``` + # def macro_name(...macro parameters...) + # return locals() + # tmp = macro_name(...arguments from the call...) + # ``` + # Obtain the dictionary `tmp` resulting from the execution, and update the var_table + # with it. + + # Construct the function with the macro's parameters, and returning locals(). + macro_ast = doc.from_doc(macro_def) + macro_ast.body = [ + ast.Return(value=ast.Call(func=ast.Name("locals", ctx=ast.Load()), args=[], keywords=[])) + ] + macro_ast.decorator_list = [] + + # Construct the assignment with the call. + call_ast = doc.from_doc(call) + call_ast.func = ast.Name(macro_name, ctx=ast.Load()) + call_ast.args = call_ast.args[1:] + tmp_name = "__tmp_param_eval_64e98b523301204b" + assign_ast = ast.Assign(targets=[ast.Name(tmp_name, ctx=ast.Store())], value=call_ast) + + # Finalize and execute the module: + module_ast = ast.Module(body=[macro_ast, assign_ast], type_ignores=[]) + module_ast = ast.fix_missing_locations(module_ast) + cmacro = compile(module_ast, filename="", mode="exec") + local_vars = {} + exec(cmacro, self.var_table.get(), local_vars) + local_vars = local_vars[tmp_name] + + with self.var_table.with_frame(): + for k, v in local_vars.items(): + self.var_table.add(k, v) + + self.visit_body(macro_def.body)