From f29100b4f0c82c7a2f8b497821ab103997a28d10 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 11 Apr 2022 09:38:51 -0500 Subject: [PATCH 1/2] [TVMScript] Allow parameters in @T.prim_func definition This extends the existing `@T.prim_func` decorator to accept a keyword-only argument `@T.prim_func(params=param_dict)`, defining a string to object mapping. This mapping allows variables inside a PrimFunc to refer to values set outside the `PrimFunc` definition. --- python/tvm/script/parser.py | 27 +++++++++++++++++++++++---- python/tvm/script/tir/prim_func.py | 28 ++++++++++++++++++++++++---- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 32919128e063..c2042c7a2c9d 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -158,8 +158,11 @@ class TVMScriptParser(Transformer): # pylint gets confused here with synr.Transformer which doesn't have a # custom init, so just disable it - def __init__(self, base_lineno, tir_namespace): # pylint: disable=super-init-not-called + def __init__( + self, base_lineno, tir_namespace, params: Optional[Dict[str, Any]] = None + ): # pylint: disable=super-init-not-called self.context = None + self.params = params self.base_lineno = base_lineno self.current_lineno = 0 @@ -435,6 +438,10 @@ def check_decorator(decorators: List[ast.Expr]) -> bool: if len(decorators) != 1: return False d: ast.Expr = decorators[0] + # Allow the decorator to be either T.prim_func or + # T.prim_func(params=...) + if isinstance(d, ast.Call): + d = d.func_name return ( isinstance(d, ast.Attr) and isinstance(d.object, ast.Var) @@ -1136,6 +1143,12 @@ def transform_Var(self, node): symbol = self.context.lookup_symbol(name) if symbol is not None: return symbol + + if self.params and name in self.params: + obj = self.params[name] + span = tvm_span_from_synr(node.span) + return tvm.runtime.convert(obj, span=span) + self.report_error(f"Unknown identifier {name}.", node.span) def transform_TypeVar(self, node): @@ -1147,6 +1160,10 @@ def transform_TypeVar(self, node): symbol = Registry.lookup(name) or self.context.lookup_symbol(name) if symbol is not None: return symbol + + if self.params and name in self.params: + return self.params[name] + self.report_error(f"Unknown identifier {name}.", node.span) def transform_Constant(self, node): @@ -1230,7 +1247,9 @@ def get_tir_namespace(script: Union[Callable, type]) -> List[str]: def from_source( - input_func: Union[str, Callable], tir_prefix: Optional[List[str]] = None + input_func: Union[str, Callable], + tir_prefix: Optional[List[str]] = None, + params: Optional[Dict[str, Any]] = None, ) -> Union[PrimFunc, IRModule]: """Parse function or string into PrimFunc or IRModule. @@ -1252,12 +1271,12 @@ def from_source( """ if isinstance(input_func, str): tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix - return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix)) + return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix, params)) elif inspect.isfunction(input_func): _, start_line = inspect.getsourcelines(input_func) env: Dict[str, Any] = input_func.__globals__ namespace = [key for key in env.keys() if env[key] is tir] - parser = TVMScriptParser(start_line, namespace) + parser = TVMScriptParser(start_line, namespace, params) result = to_ast(input_func, TVMDiagnosticCtx(), parser) return result else: diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/tir/prim_func.py index 923eb97d2758..7117a9f08188 100644 --- a/python/tvm/script/tir/prim_func.py +++ b/python/tvm/script/tir/prim_func.py @@ -17,27 +17,47 @@ """TVM Script Interface for PrimFunc""" import inspect -from typing import Callable +from typing import Callable, Dict, Any, Optional from tvm.tir.function import PrimFunc from ..parser import from_source -def prim_func(input_func: Callable) -> PrimFunc: +def prim_func( + input_func: Optional[Callable] = None, *, params: Optional[Dict[str, Any]] = None +) -> PrimFunc: """Decorate a python function as tvm script. Parameters ---------- - func : input_func + input_func : Optional[Callable] + The function to be parsed. + params: Optional[Dict[str,Any]] + + A variable look-up table. Variables defined within the body + of the function take precedence over the variables in the + params dictionary. Variable types should be supported by + tvm.runtime.convert. + + This is a keyword-only argument, to require explicit opt-in to + using parameters. + Returns ------- output : PrimFunc The result functions. """ + if input_func is None: + + def wrapper(input_func: Callable): + return prim_func(input_func=input_func, params=params) + + return wrapper + if inspect.isfunction(input_func): - result = from_source(input_func) + result = from_source(input_func, params=params) result.__name__ = input_func.__name__ result.__qualname__ = input_func.__qualname__ return result From 5683fe5c2724add201967285f00b01a94ae1a397 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 19 Apr 2022 10:42:19 -0500 Subject: [PATCH 2/2] Added unit tests showing desired behavior --- .../unittest/test_tvmscript_syntax_sugar.py | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 4a2482c11d22..4b4f801c531a 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -18,6 +18,8 @@ import sys import pytest + +import tvm from tvm.ir import assert_structural_equal from tvm.script import tir as T from tvm.script.parser import from_source @@ -235,5 +237,75 @@ def test_match_buffer_int64(): assert_structural_equal(original, after_roundtrip, True) +def test_primfunc_with_numeric_parameter(): + # A parameter accessed in the body of the PrimFunc must be passed + # in + @T.prim_func(params={"N": 16}) + def func_with_parameter(A: T.Buffer[(1,), "int32"]): + for i in T.serial(N): + A[0] = A[0] + i + + @T.prim_func + def func_without_parameter(A: T.Buffer[(1,), "int32"]): + for i in T.serial(16): + A[0] = A[0] + i + + assert_structural_equal(func_with_parameter, func_without_parameter) + + +def test_primfunc_with_numeric_shape_parameter(): + # A parameter used in the argument list must also be within the + # parent scope. This is because names in the argument list is + # resolved twice, once by the python interpreter and once by + # tvmscript. Only the value passed to tvmscript as a parameter is + # used in the generated PrimFunc. + A_size = 16 + + @T.prim_func(params={"A_size": 256}) + def func_with_parameter(A: T.Buffer[(A_size,), "int32"]): + T.evaluate(A[0]) + + @T.prim_func + def func_without_parameter(A: T.Buffer[(256,), "int32"]): + T.evaluate(A[0]) + + assert_structural_equal(func_with_parameter, func_without_parameter) + + +def test_primfunc_with_tuple_shape_parameter(): + # Parameters can be any type supported by tvm.runtime.convert. + A_shape = (16, 16) + + @T.prim_func(params={"A_shape": A_shape}) + def func_with_parameter(A: T.Buffer[A_shape, "int32"]): + T.evaluate(A[0, 0]) + + @T.prim_func + def func_without_parameter(A: T.Buffer[(16, 16), "int32"]): + T.evaluate(A[0, 0]) + + assert_structural_equal(func_with_parameter, func_without_parameter) + + +def test_irmodule_parameter(): + # Parameters can occur within T.prim_func definitions within an + # `ir_module` annotation. + A_shape = (16, 16) + + @tvm.script.ir_module + class module_with_parameter: + @T.prim_func(params={"A_shape": A_shape}) + def func(A: T.Buffer[A_shape, "int32"]): + T.evaluate(A[0, 0]) + + @tvm.script.ir_module + class module_without_parameter: + @T.prim_func + def func(A: T.Buffer[(16, 16), "int32"]): + T.evaluate(A[0, 0]) + + assert_structural_equal(module_with_parameter, module_without_parameter) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))