Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,19 @@ def emit_var_binding(value: VarBinding) -> Var:
return _ffi_api.EmitVarBinding(value) # type: ignore


############################### SeqExpr ###############################


def SeqExpr() -> frame.SeqExprFrame: # pylint: disable=invalid-name
"""Create a SeqExpr frame.
Returns
-------
res : frame.SeqExprFrame
The result SeqExprFrame
"""
return _ffi_api.SeqExpr() # type: ignore[attr-defined] # pylint: disable=no-member


############################# If Then Else #############################


Expand Down Expand Up @@ -562,6 +575,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
__all__ = [
"Else",
"If",
"SeqExpr",
"Then",
"TupleGetItem",
"abs",
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/parser/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# pylint: disable=unused-import
from .core import dispatch, doc, utils
from .core.dispatch import OpMethod, register_op
from .core.entry import parse, parse_macro
from .core.entry import parse, scan_macro
from .core.parser import Parser
6 changes: 2 additions & 4 deletions python/tvm/script/parser/core/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,12 @@ def _default_globals() -> Dict[str, Any]:
return extra_vars


def parse_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any:
def scan_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any:
"""Generate the AST, and the source code for __repr__."""
# The AST will be converted into TIR at the time of expansion.
source = Source(program)
source_txt = source.source
source_ast = source.as_ast()
closure_vars = extra_vars or _default_globals()
return source_ast, source_txt, closure_vars
return source, closure_vars


def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any:
Expand Down
105 changes: 105 additions & 0 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
"""The core parser"""

import abc
import inspect
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Set, Union
Expand Down Expand Up @@ -66,6 +68,108 @@ def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument
pass


class ScriptMacro(abc.ABC):
"""Representation of a script macro.

This is a callable object, intended to be called from the expression evaluator.
The evaluator is expected to insert the current parser into the environment
undef the name given by "parser_object_name".

Once called, the ScriptMacro object will locate the current parser, and use it
to parse the macro's body and produce the result.

There were two major considerations for this design:
1. Implementing hygienic and non-hygienic macros.
2. Implementing macros that return values.

Macro uses in TIR are only allowed at a statement-level, and they don't produce
any values. Parsing of such macros could easily be done by intercepting doc.Call
nodes in the TIR parser. If a macro is a value-producing expression, then there
may not be a direct way to intercept calls to it if it's embedded in a complex
expression. Because macros use function-call syntax, the evaluator will try to
call the macro object, which this design relies on to parse and evaluate the macro.
"""

parser_object_name = "__current_script_parser__"

def __init__(
self,
source: Source,
closure_vars: Dict[str, Any],
func: Callable,
hygienic: bool,
) -> None:
self.source = source
self.closure_vars = closure_vars
self.func = func
self.hygienic = hygienic

def __repr__(self):
return self.source.source

@abc.abstractmethod
def parse_macro(self, parser: "Parser") -> Any:
"""The main macro parsing function. Different scripts may have different
ways to parse a macro, and to return a value to the evaluator.

Parameters
----------
parser : Parser
The parser with the appropriate frame already created and populated depending
macro's hygiene settings,

Returns
-------
The return value depends on the specifics of the particular script. It can be
"None" or any other value or any type.
"""

def _find_parser_def(self):
outer_frame_infos = inspect.getouterframes(inspect.currentframe())
for finfo in outer_frame_infos:
parser = finfo.frame.f_globals.get(ScriptMacro.parser_object_name)
if parser is not None:
return parser
raise RuntimeError(f"{ScriptMacro.parser_object_name} not available")

def get_macro_def(self):
ast_module = self.source.as_ast()
for decl in ast_module.body:
if isinstance(decl, doc.FunctionDef) and decl.name == self.__name__:
return decl
raise RuntimeError(f"cannot find macro definition for {self.__name__}")

def __call__(self, *args, **kwargs):
param_binding = inspect.signature(self.func).bind(*args, **kwargs)
param_binding.apply_defaults()
local_vars = param_binding.arguments
parser = self._find_parser_def()

if self.hygienic:
saved_var_table = parser.var_table
parser.var_table = VarTable()

with parser.var_table.with_frame():
for k, v in self.closure_vars.items():
parser.var_table.add(k, v)
for k, v in local_vars.items():
parser.var_table.add(k, v)

parse_result = self.parse_macro(parser)

parser.var_table = saved_var_table

else:
with parser.var_table.with_frame():
for k, v in local_vars.items():
parser.var_table.add(k, v)

print(parser.var_table.get())
parse_result = self.parse_macro(parser)

return parse_result


class VarTableFrame:
"""The variable table frame.
A frame of variable table stores the variables created in one block or scope.
Expand Down Expand Up @@ -336,6 +440,7 @@ def eval_expr(
if extra_vars is not None:
for k, v in extra_vars.items():
var_values[k] = v
var_values[ScriptMacro.parser_object_name] = self
return eval_expr(self, node, var_values)

def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]:
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/script/parser/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# so most tvmscript won't trigger pylint error here.
function = staticmethod
else:
from .entry import function
from .entry import function, macro

__all__ = (
_relax.__all__
Expand All @@ -45,6 +45,7 @@
"Tensor",
"Tuple",
"function",
"macro",
"match_cast",
]
)
65 changes: 64 additions & 1 deletion python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from tvm.relax import (
Expr,
SeqExpr,
ShapeExpr,
FuncStructInfo,
Function,
Expand All @@ -36,7 +37,10 @@
from tvm.runtime import ObjectGeneric
from tvm.tir import PrimExpr

from .._core import parse, utils
from .._core import doc, parse, utils
from ..core.entry import scan_macro
from ..core.parser import Parser, ScriptMacro
from ...ir_builder import relax as R

FType = TypeVar("FType", bound=_Callable)

Expand Down Expand Up @@ -75,6 +79,65 @@ def decorator_wrapper(f):
setattr(function, "dispatch_token", "relax")


############################## R.macro ##############################


class RelaxMacro(ScriptMacro):
"""Specialization of the ScriptMacro class for Relax."""

def parse_macro(self, parser: Parser) -> Expr:
macro_def = self.get_macro_def()
ret_value = None

with R.SeqExpr() as seq:
for idx, stmt in enumerate(macro_def.body):
# Normally, a "return" statement is only allowed in a R.function. We don't
# want to parse the macro's body as if it was a body of a function, because
# the latter imposes some constraints that we want to avoid.
# At the same time, we want to use "return" to indicate the value of the
# macro (since in Relax everything is an expression), so add special handling
# of "return".
if isinstance(stmt, doc.Return):
ret_value = parser.eval_expr(stmt.value)
if idx + 1 != len(macro_def.body):
parser.report_error(macro_def, "'return' should be the last statement")
break
parser.visit(stmt)

if ret_value is None:
parser.report_error(macro_def, "Macros must end with a return statement")

return SeqExpr(seq.binding_blocks, ret_value)


def macro(*args, hygienic: bool = True) -> _Callable:
"""Decorator for macro definitions.

Parameters
----------
hygienic: bool
Specifies whether the macro is hygienic or not.
A macro is hygienic if all symbols used in the macro's body are resolved
to values from the location of the macro definition. A non-hygienic macro
will have its symbols resolved to values at the time of the macro's use.
"""

def _decorator(func: _Callable) -> ScriptMacro:
source, closure_vars = scan_macro(func, utils.inspect_function_capture(func))
obj = RelaxMacro(source, closure_vars, func, hygienic)
obj.__name__ = func.__name__
return obj

if len(args) == 0:
return _decorator
if len(args) == 1 and inspect.isfunction(args[0]):
return _decorator(args[0])

raise ValueError(
"Invalid use of R.macro. Usage: @R.macro, @R.macro(), @R.macro(hygienic=[True|False])"
)


############################# Struct Info ##############################


Expand Down
41 changes: 10 additions & 31 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
# under the License.
"""The entry point of TVM parser for tir."""
import inspect
from typing import Any, Callable, Dict, Union
from typing import Callable, Union

from tvm.ir.base import deprecated
from tvm.tir import Buffer, PrimFunc

from ...ir_builder.tir import buffer, ptr
from .._core import doc, parse, parse_macro, utils
from .._core import parse, scan_macro, utils
from ..core.parser import Parser, ScriptMacro


def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
Expand Down Expand Up @@ -60,25 +61,12 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
# inserted at the point where the call to the macro is located.


class TIRMacro:
"""Representation of T.macro."""
class TIRMacro(ScriptMacro):
"""Specialization of the ScriptMacro class for TIR."""

def __init__(
self,
source_ast: doc.AST,
source_txt: str,
closure_vars: Dict[str, Any],
func: Callable,
hygienic: bool,
) -> None:
self.source_ast = source_ast
self.source_txt = source_txt
self.closure_vars = closure_vars
self.func = func
self.hygienic = hygienic

def __repr__(self):
return self.source_txt
def parse_macro(self, parser: Parser) -> None:
macro_def = self.get_macro_def()
parser.visit_body(macro_def.body)


def macro(*args, hygienic: bool = True) -> Callable:
Expand Down Expand Up @@ -121,15 +109,9 @@ def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
"""

def _decorator(func: Callable) -> TIRMacro:
source_ast, source_txt, closure_vars = parse_macro(
func, utils.inspect_function_capture(func)
)
obj = TIRMacro(source_ast, source_txt, closure_vars, func, hygienic)
source, closure_vars = scan_macro(func, utils.inspect_function_capture(func))
obj = TIRMacro(source, closure_vars, func, hygienic)
obj.__name__ = func.__name__
# We don't need to explicitly store the return value anywhere.
# This function is a decorator, so the return value will replace
# the function definition (to which the decorator it is applied)
# in that function's name space.
return obj

if len(args) == 0:
Expand All @@ -142,9 +124,6 @@ def _decorator(func: Callable) -> TIRMacro:
)


# There is no dispatch_token for macro, because macro doesn't invoke parser.


class BufferProxy:
"""Buffer proxy class for constructing tir buffer."""

Expand Down
Loading