From 7ec8017f24ea60d85f228a8ac99d8fdf57e6e010 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Wed, 12 Jul 2023 17:51:38 -0700 Subject: [PATCH] [TIR] Generalize implementation of T.macro to work with other dialects As background info---the script parser works by visiting one "statement" (or top-level expression) at a time. The expression parts of the statement are evaluated, and then the IR corresponding to the statement is constructed if necessary. In TIR, macro calls can only occur at the statement level, and they don't produce any values. This means that the statement visitor (visit_expr_stmt) can see these calls directly in its node parameter. At this point it could simply visit the body of the macro instead, which is the basis of the existing implementation. In other dialects there may be a need for macros to produce values. This means that macro calls can occur in the middle of complex expressions. As a result, these calls will not be present at the statement level, and the TIR approach of intercepting them in visit_expr_stmt will no longer work. Instead, these macros delay the visiting of the macro body until evaluation time. A macro is represented by a ScriptMacro (TIRMacro in the current implementation) object (created via the macro decorator). When the evaluator evaluates an expression with a macro call, it will call the macro object (since macro calls use function call syntax). It is in the macro object's __call__ function where the macro parsing picks up. The remaining issue was to pass the Parser object to the __call__ function. This is done by injecting it into the global dictionary under a reserved name. It turns out that the same approach also works for TIR, and the macro processing can be generalized, leaving only language-specific details to the language-specific macro objects. 
--- python/tvm/script/parser/_core.py | 2 +- python/tvm/script/parser/core/entry.py | 6 +- python/tvm/script/parser/core/parser.py | 105 ++++++++++++++++++++++++ python/tvm/script/parser/tir/entry.py | 41 +++------ python/tvm/script/parser/tir/parser.py | 59 +------------ 5 files changed, 119 insertions(+), 94 deletions(-) diff --git a/python/tvm/script/parser/_core.py b/python/tvm/script/parser/_core.py index b7ba5ee4713f..8c29df7e6211 100644 --- a/python/tvm/script/parser/_core.py +++ b/python/tvm/script/parser/_core.py @@ -18,5 +18,5 @@ # pylint: disable=unused-import from .core import dispatch, doc, utils from .core.dispatch import OpMethod, register_op -from .core.entry import parse, parse_macro +from .core.entry import parse, scan_macro from .core.parser import Parser diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 08a593d5d31b..7604a54b45c2 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -34,14 +34,12 @@ def _default_globals() -> Dict[str, Any]: return extra_vars -def parse_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any: +def scan_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any: """Generate the AST, and the source code for __repr__.""" # The AST will be converted into TIR at the time of expansion. source = Source(program) - source_txt = source.source - source_ast = source.as_ast() closure_vars = extra_vars or _default_globals() - return source_ast, source_txt, closure_vars + return source, closure_vars def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any: diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index c253f61c3181..7032d194be2f 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -16,6 +16,8 @@ # under the License. 
"""The core parser""" +import abc +import inspect from collections import defaultdict from contextlib import contextmanager from typing import Any, Callable, Dict, List, Optional, Set, Union @@ -65,6 +67,108 @@ def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument pass +class ScriptMacro(abc.ABC): + """Representation of a script macro. + + This is a callable object, intended to be called from the expression evaluator. + The evaluator is expected to insert the current parser into the environment + undef the name given by "parser_object_name". + + Once called, the ScriptMacro object will locate the current parser, and use it + to parse the macro's body and produce the result. + + There were two major considerations for this design: + 1. Implementing hygienic and non-hygienic macros. + 2. Implementing macros that return values. + + Macro uses in TIR are only allowed at a statement-level, and they don't produce + any values. Parsing of such macros could easily be done by intercepting doc.Call + nodes in the TIR parser. If a macro is a value-producing expression, then there + may not be a direct way to intercept calls to it if it's embedded in a complex + expression. Because macros use function-call syntax, the evaluator will try to + call the macro object, which this design relies on to parse and evaluate the macro. + """ + + parser_object_name = "__current_script_parser__" + + def __init__( + self, + source: Source, + closure_vars: Dict[str, Any], + func: Callable, + hygienic: bool, + ) -> None: + self.source = source + self.closure_vars = closure_vars + self.func = func + self.hygienic = hygienic + + def __repr__(self): + return self.source.source + + @abc.abstractmethod + def parse_macro(self, parser: "Parser") -> Any: + """The main macro parsing function. Different scripts may have different + ways to parse a macro, and to return a value to the evaluator. 
+ + Parameters + ---------- + parser : Parser + The parser with the appropriate frame already created and populated depending + macro's hygiene settings, + + Returns + ------- + The return value depends on the specifics of the particular script. It can be + "None" or any other value or any type. + """ + + def _find_parser_def(self): + outer_frame_infos = inspect.getouterframes(inspect.currentframe()) + for finfo in outer_frame_infos: + parser = finfo.frame.f_globals.get(ScriptMacro.parser_object_name) + if parser is not None: + return parser + raise RuntimeError(f"{ScriptMacro.parser_object_name} not available") + + def get_macro_def(self): + ast_module = self.source.as_ast() + for decl in ast_module.body: + if isinstance(decl, doc.FunctionDef) and decl.name == self.__name__: + return decl + raise RuntimeError(f"cannot find macro definition for {self.__name__}") + + def __call__(self, *args, **kwargs): + param_binding = inspect.signature(self.func).bind(*args, **kwargs) + param_binding.apply_defaults() + local_vars = param_binding.arguments + parser = self._find_parser_def() + + if self.hygienic: + saved_var_table = parser.var_table + parser.var_table = VarTable() + + with parser.var_table.with_frame(): + for k, v in self.closure_vars.items(): + parser.var_table.add(k, v) + for k, v in local_vars.items(): + parser.var_table.add(k, v) + + parse_result = self.parse_macro(parser) + + parser.var_table = saved_var_table + + else: + with parser.var_table.with_frame(): + for k, v in local_vars.items(): + parser.var_table.add(k, v) + + print(parser.var_table.get()) + parse_result = self.parse_macro(parser) + + return parse_result + + class VarTableFrame: """The variable table frame. A frame of variable table stores the variables created in one block or scope. 
@@ -326,6 +430,7 @@ def eval_expr( if extra_vars is not None: for k, v in extra_vars.items(): var_values[k] = v + var_values[ScriptMacro.parser_object_name] = self return eval_expr(self, node, var_values) def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]: diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 93bf8721c58e..d2fb070aaab1 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -16,13 +16,14 @@ # under the License. """The entry point of TVM parser for tir.""" import inspect -from typing import Any, Callable, Dict, Optional, Union +from typing import Callable, Optional, Union from tvm.ir.base import deprecated from tvm.tir import Buffer, PrimFunc from ...ir_builder.tir import buffer, ptr -from .._core import doc, parse, parse_macro, utils +from .._core import parse, scan_macro, utils +from ..core.parser import Parser, ScriptMacro def prim_func(func: Optional[Callable] = None, private: bool = False) -> Union[PrimFunc, Callable]: @@ -86,25 +87,12 @@ def decorator_wrapper(func): # inserted at the point where the call to the macro is located. 
-class TIRMacro: - """Representation of T.macro.""" +class TIRMacro(ScriptMacro): + """Specialization of the ScriptMacro class for TIR.""" - def __init__( - self, - source_ast: doc.AST, - source_txt: str, - closure_vars: Dict[str, Any], - func: Callable, - hygienic: bool, - ) -> None: - self.source_ast = source_ast - self.source_txt = source_txt - self.closure_vars = closure_vars - self.func = func - self.hygienic = hygienic - - def __repr__(self): - return self.source_txt + def parse_macro(self, parser: Parser) -> None: + macro_def = self.get_macro_def() + parser.visit_body(macro_def.body) def macro(*args, hygienic: bool = True) -> Callable: @@ -147,15 +135,9 @@ def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: """ def _decorator(func: Callable) -> TIRMacro: - source_ast, source_txt, closure_vars = parse_macro( - func, utils.inspect_function_capture(func) - ) - obj = TIRMacro(source_ast, source_txt, closure_vars, func, hygienic) + source, closure_vars = scan_macro(func, utils.inspect_function_capture(func)) + obj = TIRMacro(source, closure_vars, func, hygienic) obj.__name__ = func.__name__ - # We don't need to explicitly store the return value anywhere. - # This function is a decorator, so the return value will replace - # the function definition (to which the decorator it is applied) - # in that function's name space. return obj if len(args) == 0: @@ -168,9 +150,6 @@ def _decorator(func: Callable) -> TIRMacro: ) -# There is no dispatch_token for macro, because macro doesn't invoke parser. 
- - class BufferProxy: """Buffer proxy class for constructing tir buffer.""" diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 5398b471e49d..60bdb7f7924b 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -17,9 +17,8 @@ """The base parser for tir""" import contextlib -import inspect from functools import partial -from typing import Any, Union +from typing import Any import tvm from tvm.ir import GlobalVar, PrimType @@ -30,8 +29,6 @@ from ...ir_builder.base import IRBuilder from ...ir_builder.base import IRBuilderFrame as Frame from .._core import Parser, dispatch, doc -from ..core.parser import VarTable -from .entry import TIRMacro def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: @@ -447,11 +444,6 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: The doc AST Expr node. """ - if isinstance(node.value, doc.Call): - callee = self.eval_expr(node.value.func) - if isinstance(callee, TIRMacro): - return expand_macro(self, callee, node.value) - res = self.eval_expr(node.value) if res is None: pass @@ -472,7 +464,6 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: pass else: self.report_error(node, f"Parsing resulted in unexpected type {type(res)}") - return None # For pylint @dispatch.register(token="tir", type_name="If") @@ -554,51 +545,3 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar # Only ret_type is needed for func_signature. func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type) return I.decl_function(node.name, func_signature) - - -def expand_macro(self: Parser, callee: TIRMacro, call: doc.Call) -> None: - """Bind arguments to the macro invocation to the parameters in the macro definition, - and pass the macro body for further parsing. 
- """ - - assert isinstance(callee, TIRMacro), f"Unexpected macro type {type(callee)}" - - def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]: - for decl in decl_list: - if isinstance(decl, doc.FunctionDef) and decl.name == name: - return decl - return None - - macro_def = find_macro_def(callee.__name__, callee.source_ast.body) - assert macro_def is not None, f"Invalid macro AST for {callee.__name__}" - # `macro_def` is the FunctionDef of the macro. - - args = [self.eval_expr(arg) for arg in call.args] - kwargs = {kw.arg: self.eval_expr(kw.value) for kw in call.keywords} - param_binding = inspect.signature(callee.func).bind(*args, **kwargs) - param_binding.apply_defaults() - local_vars = param_binding.arguments - - if callee.hygienic: - # If the macro was hygienic, construct new var_table with a single frame that - # contains the captured environment, and process the macro's body with that - # frame. - saved_var_table = self.var_table - self.var_table = VarTable() - with self.var_table.with_frame(): - for k, v in callee.closure_vars.items(): - self.var_table.add(k, v) - for k, v in local_vars.items(): - self.var_table.add(k, v) - - self.visit_body(macro_def.body) - - self.var_table = saved_var_table - - else: - # Otherwise, dynamically resolve symbols in the macro's body. - with self.var_table.with_frame(): - for k, v in local_vars.items(): - self.var_table.add(k, v) - - self.visit_body(macro_def.body)