Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
6c880b6
Check well-formedness in the parser
slyubomirsky Feb 14, 2024
dbf64d0
Correct packed funcs in NN frontend
slyubomirsky Feb 14, 2024
471148c
Support the check_well_formed optional argument to I.ir_module
slyubomirsky Feb 14, 2024
ca0ce7a
Also check well-formedness in TIR
slyubomirsky Feb 14, 2024
76d79a6
Enable normalization for individual Relax functions and PrimFuncs
slyubomirsky Feb 14, 2024
cefdb7e
Use the error raised by the TIR well-formed checker for the message
slyubomirsky Feb 15, 2024
3177880
Fix tvmscript test failures
slyubomirsky Feb 15, 2024
4f9f97a
Whitespace
slyubomirsky Feb 15, 2024
019a85b
Fix errors in verify_well_formed test
slyubomirsky Feb 15, 2024
ab2f2bd
Include a more helpful error message
slyubomirsky Feb 15, 2024
a6db5fa
Fix TIR test failures
slyubomirsky Feb 15, 2024
2048e60
Address well-formed failures in test_tir_specialize
slyubomirsky Feb 15, 2024
831f8f2
Correct well-formedness error in test_tir_analysis_oob
slyubomirsky Feb 15, 2024
e6d5f00
Correct further well-formedness failures
slyubomirsky Feb 15, 2024
3070b18
Remove __tvm_meta__ from test case to avoid parsing error
slyubomirsky Feb 23, 2024
5c7aec3
Avoid circular import in entryy.py
slyubomirsky Feb 26, 2024
87426d3
Formatting fixes
slyubomirsky Feb 26, 2024
7e299e9
lint fix
slyubomirsky Feb 27, 2024
b8616f2
Add pylint exceptions
slyubomirsky Feb 27, 2024
cb114e4
Fix whitespace
slyubomirsky Feb 27, 2024
ea357bb
Fix more failed test cases
slyubomirsky Feb 27, 2024
5d6de7d
Catch inappropriate use of decl_function instead of segfaulting
slyubomirsky Feb 27, 2024
574677a
Fix test_lower.py
slyubomirsky Feb 27, 2024
08819aa
Mark purity in test_relax_2d_buffer_allocation.py
slyubomirsky Feb 27, 2024
f93c406
Mark purity in test_dma_builtin.py
slyubomirsky Feb 27, 2024
299abbc
Remove __tvm_meta___ from test_tir_usmp_analysis_extract_bufferinfo.py
slyubomirsky Feb 27, 2024
fda20a3
Suppress well-formed check in test_tir_transform_convert_blocks_to_op…
slyubomirsky Feb 27, 2024
8b9506d
Remove __tvm_meta__ in test_tir_usmp_algo.py
slyubomirsky Feb 27, 2024
4944c3a
Remove __tvm_meta__ from more USMP tests
slyubomirsky Feb 27, 2024
7889711
Fix incorrect var in test_tir_transform_storage_flatten.py
slyubomirsky Feb 27, 2024
e11cf8c
Remove all remaining instances of __tvm_meta__
slyubomirsky Feb 27, 2024
e35499d
Fix purity error in test_dataflow_pattern.py
slyubomirsky Feb 27, 2024
261e3eb
Fix purity error in test_ast_printer
slyubomirsky Feb 27, 2024
6cc354c
Fix test_arith_domain_touched example
slyubomirsky Feb 28, 2024
355aedf
Okay to set check_well_formed to True in test_tir_analysis_identify_m…
slyubomirsky Feb 28, 2024
f1b6dcb
Define variable in test_tir_analysis_oob
slyubomirsky Feb 28, 2024
d9dbeb5
Typo fix
slyubomirsky Feb 28, 2024
0ca4698
Add explanatory comment to test case
slyubomirsky Feb 28, 2024
8af392d
Define the undefined vars in test_tir_transform_common_subexpr_elim
slyubomirsky Feb 28, 2024
07c8819
Exception no longer necessary in test_tir_transform_inject_rolling_bu…
slyubomirsky Feb 28, 2024
ae8001c
Remove unnecessary check exemption in test_tir_transform_convert_ssa
slyubomirsky Feb 28, 2024
b373a4f
Avoid checking exemption in test_inject_ptx_ldg32
slyubomirsky Feb 28, 2024
785394d
Note special case in test_distributed_transform_propagate_sharding
slyubomirsky Feb 29, 2024
571cf47
Exempt well-formed error in dlight/test_benchmark
slyubomirsky Mar 1, 2024
aff00de
Exempt well-formedness errors in test_ethosu/, mostly uninitialized vars
slyubomirsky Mar 1, 2024
a8e7415
Whitespace
slyubomirsky Mar 1, 2024
f772190
Include non-CUDA GPUs in IsScheduledOnGPU
slyubomirsky Mar 5, 2024
f82ca07
Fix thread binding bug by changing thread binding var dtype
slyubomirsky Mar 6, 2024
2b7c4ca
Include overrides in test_runtime_builtin_paged_attention_kv_cache.py
slyubomirsky Mar 6, 2024
d254177
add exemptions in test_ethosu/test_replace_conv2d
slyubomirsky Mar 6, 2024
3247cb5
Add more ethosu exemptions
slyubomirsky Mar 6, 2024
9cd556d
More exemptions for ethosu tests
slyubomirsky Mar 6, 2024
1800188
Remove unused reference
slyubomirsky Mar 14, 2024
c54267b
Indicate purity in test_transform_rewrite_cuda_graph
slyubomirsky Mar 14, 2024
3684fd3
Indicate purity in test_transform_normalize
slyubomirsky Mar 14, 2024
e77ef6e
Reorder MergeSharedMemoryAllocations in GPU codegen
slyubomirsky Mar 18, 2024
c8fb78c
Add target parameter for FP8StorageLegalize and FP8ComputeLegalize
slyubomirsky Mar 19, 2024
3395808
Don't re-import Target in tvm/tir/transform/transform.py
slyubomirsky Mar 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@
class FunctionScope(object):
"""Auxiliary scope for function"""

def __init__(self, block_builder, name, params, attrs):
def __init__(self, block_builder, name, params, attrs, is_pure):
self._bb = block_builder
self._name = name
self._params = params
self._attrs = attrs
self._is_pure = is_pure

# Blocks that have been collected within the function
self._blocks = []
Expand Down Expand Up @@ -208,6 +209,7 @@ def function(
name: str,
params: Optional[Union[Var, Tuple, List[Var]]] = None,
attrs: Optional[Dict[str, Object]] = None,
pure: bool = True,
private: bool = False,
) -> FunctionScope:
"""Annotate a Relax function.
Expand All @@ -225,6 +227,9 @@ def function(
attrs : Dict[str, Object], optional
The function attrs

pure : bool, optional
Whether the function is annotated as pure.

private : bool, optional
Whether the function is annotated as private.
If the function is private, it will not have a global symbol attribute.
Expand Down Expand Up @@ -254,7 +259,7 @@ def function(
if not private:
attrs["global_symbol"] = name

return FunctionScope(self, name, params, attrs)
return FunctionScope(self, name, params, attrs, is_pure=pure)

def testing_scope(self, def_vars: List[tir.Var]) -> TestingScope:
"""Start a scope for unit-testing purposes.
Expand Down Expand Up @@ -640,7 +645,7 @@ def emit_func_output(

# do not specify ret_struct_info and let constructor deduce
# from seqe.struct_info
func = rx.Function(self._func._params, seqe)
func = rx.Function(self._func._params, seqe, is_pure=self._func._is_pure)
for key, value in self._func._attrs.items():
func = func.with_attr(key, value)
self.end_scope()
Expand Down
44 changes: 18 additions & 26 deletions python/tvm/relax/frontend/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import List, Optional, Sequence, Union

from tvm import relax as rx
from tvm import tir, ir
from tvm import tir

from . import op
from .core import Effect, Module, ModuleList, Parameter, Tensor, get_default_dtype
Expand Down Expand Up @@ -599,15 +599,12 @@ def emit_init(self, name_hint: str, bb: rx.BlockBuilder): # pylint: disable=arg
init_shape = rx.ShapeExpr([self.init_seq_len] + self.unit_shape)
return [
bb.emit(
rx.Call(
ir.Op.get("relax.call_pure_packed"),
args=[
rx.extern("vm.builtin.attention_kv_cache_create"),
rx.op.zeros(init_shape, self.dtype),
init_shape,
rx.PrimValue(0),
],
sinfo_args=[rx.ObjectStructInfo()],
rx.op.call_pure_packed(
"vm.builtin.attention_kv_cache_create",
rx.op.zeros(init_shape, self.dtype),
init_shape,
rx.PrimValue(0),
sinfo_args=rx.ObjectStructInfo(),
),
name_hint=name_hint,
)
Expand Down Expand Up @@ -675,14 +672,11 @@ def view(self, seq_len: tir.Var) -> Tensor:
shape = rx.ShapeExpr([seq_len] + self.unit_shape)
return Tensor(
_expr=rx.BlockBuilder.current().emit(
rx.Call(
ir.Op.get("relax.call_pure_packed"),
args=[
rx.extern("vm.builtin.attention_kv_cache_view"),
self.cache,
shape,
],
sinfo_args=[rx.TensorStructInfo(shape, self.dtype)],
rx.op.call_pure_packed(
"vm.builtin.attention_kv_cache_view",
self.cache,
shape,
sinfo_args=rx.TensorStructInfo(shape, self.dtype),
)
)
)
Expand All @@ -702,14 +696,12 @@ def append(self, new_element: Tensor) -> None:
f'but got "{new_element.dtype}"'
)
self.cache = rx.BlockBuilder.current().emit(
rx.Call(
ir.Op.get("relax.call_pure_packed"),
args=[
rx.extern("vm.builtin.attention_kv_cache_append"),
self.cache,
new_element._expr,
],
sinfo_args=[rx.ObjectStructInfo()],
rx.op.call_inplace_packed(
"vm.builtin.attention_kv_cache_append",
self.cache,
new_element._expr,
inplace_indices=[0],
sinfo_args=rx.ObjectStructInfo(),
)
)

Expand Down
8 changes: 6 additions & 2 deletions python/tvm/script/ir_builder/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar:
func_name : str
The function unique name.

func_signature: Optional[BaseFunc]
func_signature: BaseFunc
A Function w/o body, which used to specify the function signature
(i.e. func params and func return type/shape).

Expand All @@ -55,7 +55,11 @@ def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar:
gv : GlobalVar
The corresponding GlobalVar.
"""

if not isinstance(func_signature, BaseFunc):
raise ValueError(
"decl_function expects an instance of BaseFunc, "
f"but {func_signature} is of type {type(func_signature)}"
)
return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member
func_name, func_signature
)
Expand Down
40 changes: 38 additions & 2 deletions python/tvm/script/parser/core/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,20 @@
import inspect
from typing import Any, Dict, Union

from ....ir.module import IRModule
from ...ir_builder import IRBuilder
from . import doc
from .diagnostics import Source
from .error import ParserError
from .parser import Parser

WELL_FORMED_ERROR_MESSAGE = (
"Program is not well-formed. If this is deliberate, consider "
"setting check_well_formed in the top-level decorator to False "
"(e.g., @I.ir_module(check_well_formed=False) or "
"@R.function(check_well_formed=False))."
)


def _default_globals() -> Dict[str, Any]:
import tvm # pylint: disable=import-outside-toplevel
Expand All @@ -43,7 +51,11 @@ def scan_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> A
return source, closure_vars


def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any:
def parse(
program: Union[doc.AST, Any, str],
extra_vars: Dict[str, Any] = None,
check_well_formed: bool = True,
) -> Any:
"""Register a method for a operand type, AST operator node and operand index.

Parameters
Expand All @@ -54,6 +66,9 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None)
extra_vars : Dict[str, Any]
The extra variable table for parsing.

check_well_formed : bool
Whether to check well-formedness after parsing.

Returns
-------
func : Any
Expand All @@ -77,4 +92,25 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None)
parser.parse(extra_vars=extra_vars)
except ParserError as err:
parser.report_error(err.node, err.args[0])
return builder.get()
ret = builder.get()
# check well-formedness in both Relax and TIR
if check_well_formed:
# (C0415 = import-outside-toplevel. It is necessary here to avoid a circular dependency,
# since importing Relax imports a dependency on the parser)
from ....relax.analysis import well_formed as relax_well_formed # pylint: disable=C0415
from ....tir.analysis import verify_well_formed as tir_well_formed # pylint: disable=C0415

check_ret = ret
if not isinstance(check_ret, IRModule):
check_ret = IRModule.from_expr(ret)
source_ast = source.as_ast()
if not relax_well_formed(check_ret):
parser.report_error(source_ast, err=WELL_FORMED_ERROR_MESSAGE)
try:
tir_well_formed(check_ret)
except Exception as err: # pylint: disable=broad-exception-caught
parser.report_error(
source_ast,
err=f"{WELL_FORMED_ERROR_MESSAGE}\n\nTraceback: {str(err)}",
)
return ret
30 changes: 23 additions & 7 deletions python/tvm/script/parser/ir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,48 @@
"""The entry point of TVM parser for ir module."""

import inspect
from typing import Type
from typing import Optional, Type

from tvm.ir import IRModule

from .._core import parse, utils


def ir_module(mod: Type) -> IRModule:
# this formulation allows us to support having @I.ir_module
# appear as a decorator by itself or to have optional arguments
# like @I.ir_module(check_well_formed=False)
def ir_module(mod: Optional[Type] = None, check_well_formed: bool = True) -> IRModule:
"""The parsing method for ir module, by using `@ir_module` as decorator.

Parameters
----------
mod : Type
The class to be parsed as ir module.

check_well_formed : bool
Whether to check well-formedness during parsing.

Returns
-------
ir_module : IRModule
The parsed ir module.
"""
if not inspect.isclass(mod):
raise TypeError(f"Expect a class, but got: {mod}")

m = parse(mod, utils.inspect_class_capture(mod))
setattr(m, "__name__", mod.__name__)
return m
def decorator_wrapper(mod):
if not inspect.isclass(mod):
raise TypeError(f"Expect a class, but got: {mod}")
m = parse(mod, utils.inspect_class_capture(mod), check_well_formed=check_well_formed)
setattr(m, "__name__", mod.__name__)
return m

if mod is not None:
# if there are no optional args given, this will directly invoke the wrapper
return decorator_wrapper(mod)
else:
# if there is a optional arg given, it returns the wrapper function
# as a new decorator and applies it
setattr(decorator_wrapper, "dispatch_token", "ir")
return decorator_wrapper


setattr(ir_module, "dispatch_token", "ir")
4 changes: 2 additions & 2 deletions python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
# appear as a decorator by itself or to have optional arguments
# like @R.function(pure=False)
def function(
f: Optional[FType] = None, pure: bool = True, private: bool = False
f: Optional[FType] = None, pure: bool = True, private: bool = False, check_well_formed=True
) -> Union[Function, FType]:
# pylint: disable=unused-argument
# (pure and private aren't used here, but are used later in parsing)
Expand All @@ -66,7 +66,7 @@ def decorator_wrapper(f):
raise TypeError(f"Expect a function, but got: {f}")
if utils.is_defined_in_class(orig_stack, f):
return f
return parse(f, utils.inspect_function_capture(f))
return parse(f, utils.inspect_function_capture(f), check_well_formed=check_well_formed)

if f is not None:
# if there are no optional args given, this will directly invoke the wrapper
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from ..core.parser import Parser, ScriptMacro


def prim_func(func: Optional[Callable] = None, private: bool = False) -> Union[PrimFunc, Callable]:
def prim_func(
func: Optional[Callable] = None, private: bool = False, check_well_formed=True
) -> Union[PrimFunc, Callable]:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.

Parameters
Expand Down Expand Up @@ -60,7 +62,7 @@ def decorator_wrapper(func):
raise TypeError(f"Expect a function, but got: {func}")
if utils.is_defined_in_class(outer_stack, func):
return func
f = parse(func, utils.inspect_function_capture(func))
f = parse(func, utils.inspect_function_capture(func), check_well_formed=check_well_formed)
setattr(f, "__name__", func.__name__)
return f

Expand Down
9 changes: 6 additions & 3 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,6 @@ def enabled_targets():


class Feature:

"""A feature that may be required to run a test.

Parameters
Expand Down Expand Up @@ -1952,6 +1951,8 @@ def expected(A: T.Buffer(1, "int32")):

"""

check_well_formed: bool = True

def __init_subclass__(cls):
assert len([getattr(cls, name) for name in ["before", "Before"] if hasattr(cls, name)]) <= 1
assert (
Expand Down Expand Up @@ -1995,7 +1996,9 @@ def inner(self):
func_dict[name] = method.with_attr("global_symbol", name)
else:
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(method))
prim_func = tvm.script.from_source(source_code)
prim_func = tvm.script.from_source(
source_code, check_well_formed=self.check_well_formed
)
func_dict[name] = prim_func.with_attr("global_symbol", name)
return tvm.IRModule(func_dict)

Expand All @@ -2004,7 +2007,7 @@ def inner(self):
def inner(self):
# pylint: disable=unused-argument
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
return tvm.script.from_source(source_code)
return tvm.script.from_source(source_code, check_well_formed=self.check_well_formed)

return pytest.fixture(inner)

Expand Down
Loading