Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,11 @@
class FunctionScope(object):
"""Auxiliary scope for function"""

def __init__(self, block_builder, name, params, attrs, is_pure):
def __init__(self, block_builder, name, params, attrs):
self._bb = block_builder
self._name = name
self._params = params
self._attrs = attrs
self._is_pure = is_pure

# Blocks that have been collected within the function
self._blocks = []
Expand Down Expand Up @@ -209,7 +208,6 @@ def function(
name: str,
params: Optional[Union[Var, Tuple, List[Var]]] = None,
attrs: Optional[Dict[str, Object]] = None,
pure: bool = True,
private: bool = False,
) -> FunctionScope:
"""Annotate a Relax function.
Expand All @@ -227,9 +225,6 @@ def function(
attrs : Dict[str, Object], optional
The function attrs

pure : bool, optional
Whether the function is annotated as pure.

private : bool, optional
Whether the function is annotated as private.
If the function is private, it will not have a global symbol attribute.
Expand Down Expand Up @@ -259,7 +254,7 @@ def function(
if not private:
attrs["global_symbol"] = name

return FunctionScope(self, name, params, attrs, is_pure=pure)
return FunctionScope(self, name, params, attrs)

def testing_scope(self, def_vars: List[tir.Var]) -> TestingScope:
"""Start a scope for unit-testing purposes.
Expand Down Expand Up @@ -645,7 +640,7 @@ def emit_func_output(

# do not specify ret_struct_info and let constructor deduce
# from seqe.struct_info
func = rx.Function(self._func._params, seqe, is_pure=self._func._is_pure)
func = rx.Function(self._func._params, seqe)
for key, value in self._func._attrs.items():
func = func.with_attr(key, value)
self.end_scope()
Expand Down
44 changes: 26 additions & 18 deletions python/tvm/relax/frontend/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import List, Optional, Sequence, Union

from tvm import relax as rx
from tvm import tir
from tvm import tir, ir

from . import op
from .core import Effect, Module, ModuleList, Parameter, Tensor, get_default_dtype
Expand Down Expand Up @@ -599,12 +599,15 @@ def emit_init(self, name_hint: str, bb: rx.BlockBuilder): # pylint: disable=arg
init_shape = rx.ShapeExpr([self.init_seq_len] + self.unit_shape)
return [
bb.emit(
rx.op.call_pure_packed(
"vm.builtin.attention_kv_cache_create",
rx.op.zeros(init_shape, self.dtype),
init_shape,
rx.PrimValue(0),
sinfo_args=rx.ObjectStructInfo(),
rx.Call(
ir.Op.get("relax.call_pure_packed"),
args=[
rx.extern("vm.builtin.attention_kv_cache_create"),
rx.op.zeros(init_shape, self.dtype),
init_shape,
rx.PrimValue(0),
],
sinfo_args=[rx.ObjectStructInfo()],
),
name_hint=name_hint,
)
Expand Down Expand Up @@ -672,11 +675,14 @@ def view(self, seq_len: tir.Var) -> Tensor:
shape = rx.ShapeExpr([seq_len] + self.unit_shape)
return Tensor(
_expr=rx.BlockBuilder.current().emit(
rx.op.call_pure_packed(
"vm.builtin.attention_kv_cache_view",
self.cache,
shape,
sinfo_args=rx.TensorStructInfo(shape, self.dtype),
rx.Call(
ir.Op.get("relax.call_pure_packed"),
args=[
rx.extern("vm.builtin.attention_kv_cache_view"),
self.cache,
shape,
],
sinfo_args=[rx.TensorStructInfo(shape, self.dtype)],
)
)
)
Expand All @@ -696,12 +702,14 @@ def append(self, new_element: Tensor) -> None:
f'but got "{new_element.dtype}"'
)
self.cache = rx.BlockBuilder.current().emit(
rx.op.call_inplace_packed(
"vm.builtin.attention_kv_cache_append",
self.cache,
new_element._expr,
inplace_indices=[0],
sinfo_args=rx.ObjectStructInfo(),
rx.Call(
ir.Op.get("relax.call_pure_packed"),
args=[
rx.extern("vm.builtin.attention_kv_cache_append"),
self.cache,
new_element._expr,
],
sinfo_args=[rx.ObjectStructInfo()],
)
)

Expand Down
8 changes: 2 additions & 6 deletions python/tvm/script/ir_builder/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar:
func_name : str
The function unique name.

func_signature: BaseFunc
func_signature: Optional[BaseFunc]
A Function w/o body, which used to specify the function signature
(i.e. func params and func return type/shape).

Expand All @@ -55,11 +55,7 @@ def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar:
gv : GlobalVar
The corresponding GlobalVar.
"""
if not isinstance(func_signature, BaseFunc):
raise ValueError(
"decl_function expects an instance of BaseFunc, "
f"but {func_signature} is of type {type(func_signature)}"
)

return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member
func_name, func_signature
)
Expand Down
40 changes: 2 additions & 38 deletions python/tvm/script/parser/core/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,12 @@
import inspect
from typing import Any, Dict, Union

from ....ir.module import IRModule
from ...ir_builder import IRBuilder
from . import doc
from .diagnostics import Source
from .error import ParserError
from .parser import Parser

WELL_FORMED_ERROR_MESSAGE = (
"Program is not well-formed. If this is deliberate, consider "
"setting check_well_formed in the top-level decorator to False "
"(e.g., @I.ir_module(check_well_formed=False) or "
"@R.function(check_well_formed=False))."
)


def _default_globals() -> Dict[str, Any]:
import tvm # pylint: disable=import-outside-toplevel
Expand All @@ -51,11 +43,7 @@ def scan_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> A
return source, closure_vars


def parse(
program: Union[doc.AST, Any, str],
extra_vars: Dict[str, Any] = None,
check_well_formed: bool = True,
) -> Any:
def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any:
"""Register a method for a operand type, AST operator node and operand index.

Parameters
Expand All @@ -66,9 +54,6 @@ def parse(
extra_vars : Dict[str, Any]
The extra variable table for parsing.

check_well_formed : bool
Whether to check well-formedness after parsing.

Returns
-------
func : Any
Expand All @@ -92,25 +77,4 @@ def parse(
parser.parse(extra_vars=extra_vars)
except ParserError as err:
parser.report_error(err.node, err.args[0])
ret = builder.get()
# check well-formedness in both Relax and TIR
if check_well_formed:
# (C0415 = import-outside-toplevel. It is necessary here to avoid a circular dependency,
# since importing Relax imports a dependency on the parser)
from ....relax.analysis import well_formed as relax_well_formed # pylint: disable=C0415
from ....tir.analysis import verify_well_formed as tir_well_formed # pylint: disable=C0415

check_ret = ret
if not isinstance(check_ret, IRModule):
check_ret = IRModule.from_expr(ret)
source_ast = source.as_ast()
if not relax_well_formed(check_ret):
parser.report_error(source_ast, err=WELL_FORMED_ERROR_MESSAGE)
try:
tir_well_formed(check_ret)
except Exception as err: # pylint: disable=broad-exception-caught
parser.report_error(
source_ast,
err=f"{WELL_FORMED_ERROR_MESSAGE}\n\nTraceback: {str(err)}",
)
return ret
return builder.get()
30 changes: 7 additions & 23 deletions python/tvm/script/parser/ir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,32 @@
"""The entry point of TVM parser for ir module."""

import inspect
from typing import Optional, Type
from typing import Type

from tvm.ir import IRModule

from .._core import parse, utils


# this formulation allows us to support having @I.ir_module
# appear as a decorator by itself or to have optional arguments
# like @I.ir_module(check_well_formed=False)
def ir_module(mod: Optional[Type] = None, check_well_formed: bool = True) -> IRModule:
def ir_module(mod: Type) -> IRModule:
"""The parsing method for ir module, by using `@ir_module` as decorator.

Parameters
----------
mod : Type
The class to be parsed as ir module.

check_well_formed : bool
Whether to check well-formedness during parsing.

Returns
-------
ir_module : IRModule
The parsed ir module.
"""
if not inspect.isclass(mod):
raise TypeError(f"Expect a class, but got: {mod}")

def decorator_wrapper(mod):
if not inspect.isclass(mod):
raise TypeError(f"Expect a class, but got: {mod}")
m = parse(mod, utils.inspect_class_capture(mod), check_well_formed=check_well_formed)
setattr(m, "__name__", mod.__name__)
return m

if mod is not None:
# if there are no optional args given, this will directly invoke the wrapper
return decorator_wrapper(mod)
else:
# if there is a optional arg given, it returns the wrapper function
# as a new decorator and applies it
setattr(decorator_wrapper, "dispatch_token", "ir")
return decorator_wrapper
m = parse(mod, utils.inspect_class_capture(mod))
setattr(m, "__name__", mod.__name__)
return m


setattr(ir_module, "dispatch_token", "ir")
4 changes: 2 additions & 2 deletions python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
# appear as a decorator by itself or to have optional arguments
# like @R.function(pure=False)
def function(
f: Optional[FType] = None, pure: bool = True, private: bool = False, check_well_formed=True
f: Optional[FType] = None, pure: bool = True, private: bool = False
) -> Union[Function, FType]:
# pylint: disable=unused-argument
# (pure and private aren't used here, but are used later in parsing)
Expand All @@ -66,7 +66,7 @@ def decorator_wrapper(f):
raise TypeError(f"Expect a function, but got: {f}")
if utils.is_defined_in_class(orig_stack, f):
return f
return parse(f, utils.inspect_function_capture(f), check_well_formed=check_well_formed)
return parse(f, utils.inspect_function_capture(f))

if f is not None:
# if there are no optional args given, this will directly invoke the wrapper
Expand Down
6 changes: 2 additions & 4 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
from ..core.parser import Parser, ScriptMacro


def prim_func(
func: Optional[Callable] = None, private: bool = False, check_well_formed=True
) -> Union[PrimFunc, Callable]:
def prim_func(func: Optional[Callable] = None, private: bool = False) -> Union[PrimFunc, Callable]:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.

Parameters
Expand Down Expand Up @@ -62,7 +60,7 @@ def decorator_wrapper(func):
raise TypeError(f"Expect a function, but got: {func}")
if utils.is_defined_in_class(outer_stack, func):
return func
f = parse(func, utils.inspect_function_capture(func), check_well_formed=check_well_formed)
f = parse(func, utils.inspect_function_capture(func))
setattr(f, "__name__", func.__name__)
return f

Expand Down
9 changes: 3 additions & 6 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ def enabled_targets():


class Feature:

"""A feature that may be required to run a test.

Parameters
Expand Down Expand Up @@ -1951,8 +1952,6 @@ def expected(A: T.Buffer(1, "int32")):

"""

check_well_formed: bool = True

def __init_subclass__(cls):
assert len([getattr(cls, name) for name in ["before", "Before"] if hasattr(cls, name)]) <= 1
assert (
Expand Down Expand Up @@ -1996,9 +1995,7 @@ def inner(self):
func_dict[name] = method.with_attr("global_symbol", name)
else:
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(method))
prim_func = tvm.script.from_source(
source_code, check_well_formed=self.check_well_formed
)
prim_func = tvm.script.from_source(source_code)
func_dict[name] = prim_func.with_attr("global_symbol", name)
return tvm.IRModule(func_dict)

Expand All @@ -2007,7 +2004,7 @@ def inner(self):
def inner(self):
# pylint: disable=unused-argument
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
return tvm.script.from_source(source_code, check_well_formed=self.check_well_formed)
return tvm.script.from_source(source_code)

return pytest.fixture(inner)

Expand Down
Loading