Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,11 @@ class TVMScriptParser(Transformer):

# pylint gets confused here with synr.Transformer which doesn't have a
# custom init, so just disable it
def __init__(self, base_lineno, tir_namespace): # pylint: disable=super-init-not-called
def __init__(
self, base_lineno, tir_namespace, params: Optional[Dict[str, Any]] = None
): # pylint: disable=super-init-not-called
self.context = None
self.params = params

self.base_lineno = base_lineno
self.current_lineno = 0
Expand Down Expand Up @@ -435,6 +438,10 @@ def check_decorator(decorators: List[ast.Expr]) -> bool:
if len(decorators) != 1:
return False
d: ast.Expr = decorators[0]
# Allow the decorator to be either T.prim_func or
# T.prim_func(params=...)
if isinstance(d, ast.Call):
d = d.func_name
return (
isinstance(d, ast.Attr)
and isinstance(d.object, ast.Var)
Expand Down Expand Up @@ -1136,6 +1143,12 @@ def transform_Var(self, node):
symbol = self.context.lookup_symbol(name)
if symbol is not None:
return symbol

if self.params and name in self.params:
obj = self.params[name]
span = tvm_span_from_synr(node.span)
return tvm.runtime.convert(obj, span=span)

self.report_error(f"Unknown identifier {name}.", node.span)

def transform_TypeVar(self, node):
Expand All @@ -1147,6 +1160,10 @@ def transform_TypeVar(self, node):
symbol = Registry.lookup(name) or self.context.lookup_symbol(name)
if symbol is not None:
return symbol

if self.params and name in self.params:
return self.params[name]

self.report_error(f"Unknown identifier {name}.", node.span)

def transform_Constant(self, node):
Expand Down Expand Up @@ -1230,7 +1247,9 @@ def get_tir_namespace(script: Union[Callable, type]) -> List[str]:


def from_source(
input_func: Union[str, Callable], tir_prefix: Optional[List[str]] = None
input_func: Union[str, Callable],
tir_prefix: Optional[List[str]] = None,
params: Optional[Dict[str, Any]] = None,
) -> Union[PrimFunc, IRModule]:
"""Parse function or string into PrimFunc or IRModule.

Expand All @@ -1252,12 +1271,12 @@ def from_source(
"""
if isinstance(input_func, str):
tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix
return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix))
return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix, params))
elif inspect.isfunction(input_func):
_, start_line = inspect.getsourcelines(input_func)
env: Dict[str, Any] = input_func.__globals__
namespace = [key for key in env.keys() if env[key] is tir]
parser = TVMScriptParser(start_line, namespace)
parser = TVMScriptParser(start_line, namespace, params)
result = to_ast(input_func, TVMDiagnosticCtx(), parser)
return result
else:
Expand Down
28 changes: 24 additions & 4 deletions python/tvm/script/tir/prim_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,47 @@
"""TVM Script Interface for PrimFunc"""

import inspect
from typing import Callable
from typing import Callable, Dict, Any, Optional

from tvm.tir.function import PrimFunc
from ..parser import from_source


def prim_func(input_func: Callable) -> PrimFunc:
def prim_func(
input_func: Optional[Callable] = None, *, params: Optional[Dict[str, Any]] = None
) -> PrimFunc:
"""Decorate a python function as tvm script.

Parameters
----------
func : input_func
input_func : Optional[Callable]

The function to be parsed.

params: Optional[Dict[str,Any]]

A variable look-up table. Variables defined within the body
of the function take precedence over the variables in the
params dictionary. Variable types should be supported by
tvm.runtime.convert.

This is a keyword-only argument, to require explicit opt-in to
using parameters.

Returns
-------
output : PrimFunc
The result functions.
"""
if input_func is None:

def wrapper(input_func: Callable):
return prim_func(input_func=input_func, params=params)

return wrapper

if inspect.isfunction(input_func):
result = from_source(input_func)
result = from_source(input_func, params=params)
result.__name__ = input_func.__name__
result.__qualname__ = input_func.__qualname__
return result
Expand Down
72 changes: 72 additions & 0 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import sys

import pytest

import tvm
from tvm.ir import assert_structural_equal
from tvm.script import tir as T
from tvm.script.parser import from_source
Expand Down Expand Up @@ -235,5 +237,75 @@ def test_match_buffer_int64():
assert_structural_equal(original, after_roundtrip, True)


def test_primfunc_with_numeric_parameter():
# A parameter accessed in the body of the PrimFunc must be passed
# in
@T.prim_func(params={"N": 16})
def func_with_parameter(A: T.Buffer[(1,), "int32"]):
for i in T.serial(N):
A[0] = A[0] + i

@T.prim_func
def func_without_parameter(A: T.Buffer[(1,), "int32"]):
for i in T.serial(16):
A[0] = A[0] + i

assert_structural_equal(func_with_parameter, func_without_parameter)


def test_primfunc_with_numeric_shape_parameter():
# A parameter used in the argument list must also be within the
# parent scope. This is because names in the argument list is
# resolved twice, once by the python interpreter and once by
# tvmscript. Only the value passed to tvmscript as a parameter is
# used in the generated PrimFunc.
A_size = 16

@T.prim_func(params={"A_size": 256})
def func_with_parameter(A: T.Buffer[(A_size,), "int32"]):
T.evaluate(A[0])

@T.prim_func
def func_without_parameter(A: T.Buffer[(256,), "int32"]):
T.evaluate(A[0])

assert_structural_equal(func_with_parameter, func_without_parameter)


def test_primfunc_with_tuple_shape_parameter():
# Parameters can be any type supported by tvm.runtime.convert.
A_shape = (16, 16)

@T.prim_func(params={"A_shape": A_shape})
def func_with_parameter(A: T.Buffer[A_shape, "int32"]):
T.evaluate(A[0, 0])

@T.prim_func
def func_without_parameter(A: T.Buffer[(16, 16), "int32"]):
T.evaluate(A[0, 0])

assert_structural_equal(func_with_parameter, func_without_parameter)


def test_irmodule_parameter():
# Parameters can occur within T.prim_func definitions within an
# `ir_module` annotation.
A_shape = (16, 16)

@tvm.script.ir_module
class module_with_parameter:
@T.prim_func(params={"A_shape": A_shape})
def func(A: T.Buffer[A_shape, "int32"]):
T.evaluate(A[0, 0])

@tvm.script.ir_module
class module_without_parameter:
@T.prim_func
def func(A: T.Buffer[(16, 16), "int32"]):
T.evaluate(A[0, 0])

assert_structural_equal(module_with_parameter, module_without_parameter)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))