diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 1ce19c2a9ed2..f8f727d144c7 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -437,6 +437,19 @@ def emit_var_binding(value: VarBinding) -> Var: return _ffi_api.EmitVarBinding(value) # type: ignore +############################### SeqExpr ############################### + + +def SeqExpr() -> frame.SeqExprFrame: # pylint: disable=invalid-name + """Create a SeqExpr frame. + Returns + ------- + res : frame.SeqExprFrame + The result SeqExprFrame + """ + return _ffi_api.SeqExpr() # type: ignore[attr-defined] # pylint: disable=no-member + + ############################# If Then Else ############################# @@ -562,6 +575,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: __all__ = [ "Else", "If", + "SeqExpr", "Then", "TupleGetItem", "abs", diff --git a/python/tvm/script/parser/_core.py b/python/tvm/script/parser/_core.py index b7ba5ee4713f..8c29df7e6211 100644 --- a/python/tvm/script/parser/_core.py +++ b/python/tvm/script/parser/_core.py @@ -18,5 +18,5 @@ # pylint: disable=unused-import from .core import dispatch, doc, utils from .core.dispatch import OpMethod, register_op -from .core.entry import parse, parse_macro +from .core.entry import parse, scan_macro from .core.parser import Parser diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index de1afb6245be..9a7430643cd8 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -35,14 +35,12 @@ def _default_globals() -> Dict[str, Any]: return extra_vars -def parse_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any: +def scan_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any: """Generate the AST, and the source code for __repr__.""" # The AST will be converted into TIR at the time of expansion. source = Source(program) - source_txt = source.source - source_ast = source.as_ast() closure_vars = extra_vars or _default_globals() - return source_ast, source_txt, closure_vars + return source, closure_vars def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any: diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 69e262b1d327..7bdcc8e69a25 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -16,6 +16,8 @@ # under the License. """The core parser""" +import abc +import inspect from collections import defaultdict from contextlib import contextmanager from typing import Any, Callable, Dict, List, Optional, Set, Union @@ -66,6 +68,108 @@ def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument pass +class ScriptMacro(abc.ABC): + """Representation of a script macro. + + This is a callable object, intended to be called from the expression evaluator. + The evaluator is expected to insert the current parser into the environment + undef the name given by "parser_object_name". + + Once called, the ScriptMacro object will locate the current parser, and use it + to parse the macro's body and produce the result. + + There were two major considerations for this design: + 1. Implementing hygienic and non-hygienic macros. + 2. Implementing macros that return values. + + Macro uses in TIR are only allowed at a statement-level, and they don't produce + any values. Parsing of such macros could easily be done by intercepting doc.Call + nodes in the TIR parser. If a macro is a value-producing expression, then there + may not be a direct way to intercept calls to it if it's embedded in a complex + expression. Because macros use function-call syntax, the evaluator will try to + call the macro object, which this design relies on to parse and evaluate the macro. + """ + + parser_object_name = "__current_script_parser__" + + def __init__( + self, + source: Source, + closure_vars: Dict[str, Any], + func: Callable, + hygienic: bool, + ) -> None: + self.source = source + self.closure_vars = closure_vars + self.func = func + self.hygienic = hygienic + + def __repr__(self): + return self.source.source + + @abc.abstractmethod + def parse_macro(self, parser: "Parser") -> Any: + """The main macro parsing function. Different scripts may have different + ways to parse a macro, and to return a value to the evaluator. + + Parameters + ---------- + parser : Parser + The parser with the appropriate frame already created and populated depending + macro's hygiene settings, + + Returns + ------- + The return value depends on the specifics of the particular script. It can be + "None" or any other value or any type. + """ + + def _find_parser_def(self): + outer_frame_infos = inspect.getouterframes(inspect.currentframe()) + for finfo in outer_frame_infos: + parser = finfo.frame.f_globals.get(ScriptMacro.parser_object_name) + if parser is not None: + return parser + raise RuntimeError(f"{ScriptMacro.parser_object_name} not available") + + def get_macro_def(self): + ast_module = self.source.as_ast() + for decl in ast_module.body: + if isinstance(decl, doc.FunctionDef) and decl.name == self.__name__: + return decl + raise RuntimeError(f"cannot find macro definition for {self.__name__}") + + def __call__(self, *args, **kwargs): + param_binding = inspect.signature(self.func).bind(*args, **kwargs) + param_binding.apply_defaults() + local_vars = param_binding.arguments + parser = self._find_parser_def() + + if self.hygienic: + saved_var_table = parser.var_table + parser.var_table = VarTable() + + with parser.var_table.with_frame(): + for k, v in self.closure_vars.items(): + parser.var_table.add(k, v) + for k, v in local_vars.items(): + parser.var_table.add(k, v) + + parse_result = self.parse_macro(parser) + + parser.var_table = saved_var_table + + else: + with parser.var_table.with_frame(): + for k, v in local_vars.items(): + parser.var_table.add(k, v) + + print(parser.var_table.get()) + parse_result = self.parse_macro(parser) + + return parse_result + + class VarTableFrame: """The variable table frame. A frame of variable table stores the variables created in one block or scope. @@ -336,6 +440,7 @@ def eval_expr( if extra_vars is not None: for k, v in extra_vars.items(): var_values[k] = v + var_values[ScriptMacro.parser_object_name] = self return eval_expr(self, node, var_values) def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]: diff --git a/python/tvm/script/parser/relax/__init__.py b/python/tvm/script/parser/relax/__init__.py index 7c93bdebba8c..0b1ed168de1d 100644 --- a/python/tvm/script/parser/relax/__init__.py +++ b/python/tvm/script/parser/relax/__init__.py @@ -32,7 +32,7 @@ # so most tvmscript won't trigger pylint error here. function = staticmethod else: - from .entry import function + from .entry import function, macro __all__ = ( _relax.__all__ @@ -45,6 +45,7 @@ "Tensor", "Tuple", "function", + "macro", "match_cast", ] ) diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index ff237a5600e7..057a2d70c05e 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -22,6 +22,7 @@ from tvm.relax import ( Expr, + SeqExpr, ShapeExpr, FuncStructInfo, Function, @@ -36,7 +37,10 @@ from tvm.runtime import ObjectGeneric from tvm.tir import PrimExpr -from .._core import parse, utils +from .._core import doc, parse, utils +from ..core.entry import scan_macro +from ..core.parser import Parser, ScriptMacro +from ...ir_builder import relax as R FType = TypeVar("FType", bound=_Callable) @@ -75,6 +79,65 @@ def decorator_wrapper(f): setattr(function, "dispatch_token", "relax") +############################## R.macro ############################## + + +class RelaxMacro(ScriptMacro): + """Specialization of the ScriptMacro class for Relax.""" + + def parse_macro(self, parser: Parser) -> Expr: + macro_def = self.get_macro_def() + ret_value = None + + with R.SeqExpr() as seq: + for idx, stmt in enumerate(macro_def.body): + # Normally, a "return" statement is only allowed in a R.function. We don't + # want to parse the macro's body as if it was a body of a function, because + # the latter imposes some constraints that we want to avoid. + # At the same time, we want to use "return" to indicate the value of the + # macro (since in Relax everything is an expression), so add special handling + # of "return". + if isinstance(stmt, doc.Return): + ret_value = parser.eval_expr(stmt.value) + if idx + 1 != len(macro_def.body): + parser.report_error(macro_def, "'return' should be the last statement") + break + parser.visit(stmt) + + if ret_value is None: + parser.report_error(macro_def, "Macros must end with a return statement") + + return SeqExpr(seq.binding_blocks, ret_value) + + +def macro(*args, hygienic: bool = True) -> _Callable: + """Decorator for macro definitions. + + Parameters + ---------- + hygienic: bool + Specifies whether the macro is hygienic or not. + A macro is hygienic if all symbols used in the macro's body are resolved + to values from the location of the macro definition. A non-hygienic macro + will have its symbols resolved to values at the time of the macro's use. + """ + + def _decorator(func: _Callable) -> ScriptMacro: + source, closure_vars = scan_macro(func, utils.inspect_function_capture(func)) + obj = RelaxMacro(source, closure_vars, func, hygienic) + obj.__name__ = func.__name__ + return obj + + if len(args) == 0: + return _decorator + if len(args) == 1 and inspect.isfunction(args[0]): + return _decorator(args[0]) + + raise ValueError( + "Invalid use of R.macro. Usage: @R.macro, @R.macro(), @R.macro(hygienic=[True|False])" + ) + + ############################# Struct Info ############################## diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 64b71d699f3d..63db31b0d933 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -16,13 +16,14 @@ # under the License. """The entry point of TVM parser for tir.""" import inspect -from typing import Any, Callable, Dict, Union +from typing import Callable, Union from tvm.ir.base import deprecated from tvm.tir import Buffer, PrimFunc from ...ir_builder.tir import buffer, ptr -from .._core import doc, parse, parse_macro, utils +from .._core import parse, scan_macro, utils +from ..core.parser import Parser, ScriptMacro def prim_func(func: Callable) -> Union[PrimFunc, Callable]: @@ -60,25 +61,12 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]: # inserted at the point where the call to the macro is located. -class TIRMacro: - """Representation of T.macro.""" +class TIRMacro(ScriptMacro): + """Specialization of the ScriptMacro class for TIR.""" - def __init__( - self, - source_ast: doc.AST, - source_txt: str, - closure_vars: Dict[str, Any], - func: Callable, - hygienic: bool, - ) -> None: - self.source_ast = source_ast - self.source_txt = source_txt - self.closure_vars = closure_vars - self.func = func - self.hygienic = hygienic - - def __repr__(self): - return self.source_txt + def parse_macro(self, parser: Parser) -> None: + macro_def = self.get_macro_def() + parser.visit_body(macro_def.body) def macro(*args, hygienic: bool = True) -> Callable: @@ -121,15 +109,9 @@ def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: """ def _decorator(func: Callable) -> TIRMacro: - source_ast, source_txt, closure_vars = parse_macro( - func, utils.inspect_function_capture(func) - ) - obj = TIRMacro(source_ast, source_txt, closure_vars, func, hygienic) + source, closure_vars = scan_macro(func, utils.inspect_function_capture(func)) + obj = TIRMacro(source, closure_vars, func, hygienic) obj.__name__ = func.__name__ - # We don't need to explicitly store the return value anywhere. - # This function is a decorator, so the return value will replace - # the function definition (to which the decorator it is applied) - # in that function's name space. return obj if len(args) == 0: @@ -142,9 +124,6 @@ def _decorator(func: Callable) -> TIRMacro: ) -# There is no dispatch_token for macro, because macro doesn't invoke parser. - - class BufferProxy: """Buffer proxy class for constructing tir buffer.""" diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 67e14d0e9772..9ea5d955c617 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -17,9 +17,8 @@ """The base parser for tir""" import contextlib -import inspect from functools import partial -from typing import Any, Union +from typing import Any import tvm from tvm.ir import GlobalVar, PrimType @@ -30,8 +29,6 @@ from ...ir_builder.base import IRBuilder from ...ir_builder.base import IRBuilderFrame as Frame from .._core import Parser, dispatch, doc -from ..core.parser import VarTable -from .entry import TIRMacro def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: @@ -431,11 +428,6 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: The doc AST Expr node. """ - if isinstance(node.value, doc.Call): - callee = self.eval_expr(node.value.func) - if isinstance(callee, TIRMacro): - return expand_macro(self, callee, node.value) - res = self.eval_expr(node.value) if res is None: pass @@ -456,7 +448,6 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: pass else: self.report_error(node, f"Parsing resulted in unexpected type {type(res)}") - return None # For pylint @dispatch.register(token="tir", type_name="If") @@ -538,51 +529,3 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar # Only ret_type is needed for func_signature. func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type) return I.decl_function(node.name, func_signature) - - -def expand_macro(self: Parser, callee: TIRMacro, call: doc.Call) -> None: - """Bind arguments to the macro invocation to the parameters in the macro definition, - and pass the macro body for further parsing. - """ - - assert isinstance(callee, TIRMacro), f"Unexpected macro type {type(callee)}" - - def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]: - for decl in decl_list: - if isinstance(decl, doc.FunctionDef) and decl.name == name: - return decl - return None - - macro_def = find_macro_def(callee.__name__, callee.source_ast.body) - assert macro_def is not None, f"Invalid macro AST for {callee.__name__}" - # `macro_def` is the FunctionDef of the macro. - - args = [self.eval_expr(arg) for arg in call.args] - kwargs = {kw.arg: self.eval_expr(kw.value) for kw in call.keywords} - param_binding = inspect.signature(callee.func).bind(*args, **kwargs) - param_binding.apply_defaults() - local_vars = param_binding.arguments - - if callee.hygienic: - # If the macro was hygienic, construct new var_table with a single frame that - # contains the captured environment, and process the macro's body with that - # frame. - saved_var_table = self.var_table - self.var_table = VarTable() - with self.var_table.with_frame(): - for k, v in callee.closure_vars.items(): - self.var_table.add(k, v) - for k, v in local_vars.items(): - self.var_table.add(k, v) - - self.visit_body(macro_def.body) - - self.var_table = saved_var_table - - else: - # Otherwise, dynamically resolve symbols in the macro's body. - with self.var_table.with_frame(): - for k, v in local_vars.items(): - self.var_table.add(k, v) - - self.visit_body(macro_def.body) diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 116cd02eb573..9af52fa80bd4 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -221,6 +221,15 @@ TVM_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit); TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchCast").set_body_typed(EmitMatchCast); TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitVarBinding").set_body_typed(EmitVarBinding); +/////////////////////////////// SeqExpr /////////////////////////////// + +SeqExprFrame SeqExpr() { + ObjectPtr n = make_object(); + return SeqExprFrame(n); +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.SeqExpr").set_body_typed(SeqExpr); + ///////////////////////////// If Then Else ///////////////////////////// IfFrame If(tvm::relax::Expr condition) {