Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/script/parser/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# pylint: disable=unused-import
from .core import dispatch, doc, utils
from .core.dispatch import OpMethod, register_op
from .core.entry import parse
from .core.entry import gen_ast, parse
from .core.parser import Parser
7 changes: 7 additions & 0 deletions python/tvm/script/parser/core/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
from .parser import Parser


def gen_ast(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any:
# Simply generate the AST. It will be parsed at the time of inclusion.
source = Source(program)
node = source.as_ast()
return node


def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any:
"""Register a method for a operand type, AST operator node and operand index.

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/parser/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@
# so most tvmscript won't trigger pylint error here.
prim_func = staticmethod
else:
from .entry import prim_func
from .entry import prim_func, macro, include

__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"]
__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func", "macro", "include"]
22 changes: 20 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
# under the License.
"""The entry point of TVM parser for tir."""
import inspect
from typing import Callable, Union
from typing import Any, Callable, Union

from tvm.ir.base import deprecated
from tvm.tir import Buffer, PrimFunc

from ...ir_builder.tir import buffer, ptr
from .._core import parse, utils
from .._core import doc, gen_ast, parse, utils


def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
Expand Down Expand Up @@ -50,6 +50,24 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
setattr(prim_func, "dispatch_token", "tir")


def macro(func: Callable) -> doc.AST:
f = gen_ast(func, utils.inspect_function_capture(func))
setattr(f, "__name__", func.__name__)
# We don't need to explicitly store the return value anywhere.
# This function is a decorator, so the return value will replace
# the function definition (to which the decorator it is applied)
# in that function's name space.
return f


# There is no dispatch_token for macro, because macro doesn't invoke parser.


def include(name: Union[str, doc.Name], *args, **kwargs) -> Any:
"""Placeholder function, so that T.include can be parsed without errors."""
pass


class BufferProxy:
"""Buffer proxy class for constructing tir buffer."""

Expand Down
86 changes: 85 additions & 1 deletion python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# under the License.
"""The base parser for tir"""

import ast
import contextlib
from functools import partial
from typing import Any
from typing import Any, Union

import tvm
from tvm.ir import GlobalVar, PrimType
Expand Down Expand Up @@ -427,6 +428,19 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
node : doc.Expr
The doc AST Expr node.
"""
def is_include_macro(node: doc.Call) -> bool:
if not isinstance(node.func, doc.Attribute):
return False
attr = node.func
if not isinstance(attr.value, doc.Name):
return False
if attr.value.id != "T" or attr.attr != "include":
return False
return True

if isinstance(node.value, doc.Call) and is_include_macro(node.value):
return process_include_macro(self, node.value)

res = self.eval_expr(node.value)
if res is None:
pass
Expand Down Expand Up @@ -528,3 +542,73 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar
# Only ret_type is needed for func_signature.
func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type)
return I.decl_function(node.name, func_signature)


def process_include_macro(self: Parser, call: doc.Call) -> None:
def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]:
for decl in decl_list:
if isinstance(decl, doc.FunctionDef) and decl.name == name:
return decl
return None

macro_name = call.args[0]

if not isinstance(macro_name, str):
if not isinstance(macro_name, doc.Name):
self.report_error(node, "Invalid macro name in T.include")
macro_name = macro_name.id

macro = self.var_table.get().get(macro_name)
if macro is None:
self.report_error(node, f"Undefined macro '{macro_name}'")

if isinstance(macro, doc.Module):
macro_def = find_macro_def(macro_name, macro.body)
elif not isinstance(macro, doc.FunctionDef) or macro.name != macro_name:
macro_def = None

if macro_def is None:
self.report_error(macro, f"Undefined macro {macro_name}")

# `macro_def` is a FunctionDef of the macro.

# We have the AST for the macro definition, and the AST for the call. We need to
# substitute the actual arguments from the call for the parameters from the
# definition. To allow full flexibility of python, i.e. positional, unnamed, and
# keyword parameters, get the python interpreter to do the work: create and execute
# the following:
# ```
# def macro_name(...macro parameters...)
# return locals()
# tmp = macro_name(...arguments from the call...)
# ```
# Obtain the dictionary `tmp` resulting from the execution, and update the var_table
# with it.

# Construct the function with the macro's parameters, and returning locals().
macro_ast = doc.from_doc(macro_def)
macro_ast.body = [
ast.Return(value=ast.Call(func=ast.Name("locals", ctx=ast.Load()), args=[], keywords=[]))
]
macro_ast.decorator_list = []

# Construct the assignment with the call.
call_ast = doc.from_doc(call)
call_ast.func = ast.Name(macro_name, ctx=ast.Load())
call_ast.args = call_ast.args[1:]
tmp_name = "__tmp_param_eval_64e98b523301204b"
assign_ast = ast.Assign(targets=[ast.Name(tmp_name, ctx=ast.Store())], value=call_ast)

# Finalize and execute the module:
module_ast = ast.Module(body=[macro_ast, assign_ast], type_ignores=[])
module_ast = ast.fix_missing_locations(module_ast)
cmacro = compile(module_ast, filename="<tmp-string>", mode="exec")
local_vars = {}
exec(cmacro, self.var_table.get(), local_vars)
local_vars = local_vars[tmp_name]

with self.var_table.with_frame():
for k, v in local_vars.items():
self.var_table.add(k, v)

self.visit_body(macro_def.body)