Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 69 additions & 36 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1859,10 +1859,7 @@ def __init_subclass__(cls):
cls.transform = cls._normalize_transform(cls.transform)

@classmethod
def _normalize_before(cls, func):
if hasattr(func, "_pytestfixturefunction"):
return func

def _normalize_ir_module(cls, func):
if isinstance(func, tvm.tir.PrimFunc):

def inner(self):
Expand All @@ -1875,6 +1872,22 @@ def inner(self):
# pylint: disable=unused-argument
return func(self)

elif inspect.isclass(func):

def inner(self):
# pylint: disable=unused-argument
func_dict = {}
for name, method in func.__dict__.items():
if name.startswith("_"):
pass
elif isinstance(method, tvm.ir.function.BaseFunc):
func_dict[name] = method
else:
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(method))
prim_func = tvm.script.from_source(source_code)
func_dict[name] = prim_func
return tvm.IRModule(func_dict)

else:

def inner(self):
Expand All @@ -1884,50 +1897,64 @@ def inner(self):

return pytest.fixture(inner)

@classmethod
def _normalize_before(cls, func):
if hasattr(func, "_pytestfixturefunction"):
return func
else:
return cls._normalize_ir_module(func)

@classmethod
def _normalize_expected(cls, func):
if hasattr(func, "_pytestfixturefunction"):
return func

if isinstance(func, tvm.tir.PrimFunc) or (
inspect.isclass(func) and issubclass(func, Exception)
):
elif inspect.isclass(func) and issubclass(func, Exception):

def inner(self):
# pylint: disable=unused-argument
return func

elif cls._is_method(func):

def inner(self):
# pylint: disable=unused-argument
return func(self)
return pytest.fixture(inner)

else:

def inner(self):
# pylint: disable=unused-argument
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
return tvm.script.from_source(source_code)

return pytest.fixture(inner)
return cls._normalize_ir_module(func)

@classmethod
def _normalize_transform(cls, transform):
def apply(module_transform):
def inner(obj):
if isinstance(obj, tvm.IRModule):
return module_transform(obj)
elif isinstance(obj, tvm.tir.PrimFunc):
mod = tvm.IRModule({"main": obj})
mod = module_transform(mod)
return mod["main"]
else:
raise TypeError(f"Expected IRModule or PrimFunc, but received {type(obj)}")

return inner

if hasattr(transform, "_pytestfixturefunction"):
return transform

if isinstance(transform, tvm.ir.transform.Pass):
if not hasattr(cls, "_transform_orig"):
cls._transform_orig = transform

def inner(self, _transform_orig):
# pylint: disable=unused-argument
return apply(_transform_orig)

elif isinstance(transform, tvm.ir.transform.Pass):

def inner(self):
# pylint: disable=unused-argument
return transform
return apply(transform)

elif cls._is_method(transform):

def inner(self):
# pylint: disable=unused-argument
return transform(self)
return apply(transform(self))

else:

Expand All @@ -1945,42 +1972,48 @@ def _is_method(func):
def test_compare(self, before, expected, transform):
"""Unit test to compare the expected TIR PrimFunc to actual"""

before_mod = tvm.IRModule.from_expr(before)
def pprint(name, obj):
script = obj.script()
if isinstance(obj, tvm.IRModule):
return script.replace("class Module", f"class {name}")
else:
return script.replace("def func", f"def {name}")

if inspect.isclass(expected) and issubclass(expected, Exception):
with pytest.raises(expected):
after_mod = transform(before_mod)
after = transform(before)

# This portion through pytest.fail isn't strictly
# necessary, but gives a better error message that
# includes the before/after.
after = after_mod["main"]
script = tvm.IRModule({"after": after, "before": before}).script()
before_str = pprint("before", before)
after_str = pprint("after", after)

pytest.fail(
msg=(
f"Expected {expected.__name__} to be raised from transformation, "
f"instead received TIR\n:{script}"
f"instead received TIR\n:{before_str}\n{after_str}"
)
)

elif isinstance(expected, tvm.tir.PrimFunc):
after_mod = transform(before_mod)
after = after_mod["main"]
elif isinstance(expected, (tvm.tir.PrimFunc, tvm.ir.IRModule)):
after = transform(before)

try:
tvm.ir.assert_structural_equal(after, expected)
except ValueError as err:
script = tvm.IRModule(
{"expected": expected, "after": after, "before": before}
).script()
before_str = pprint("before", before)
after_str = pprint("after", after)
expected_str = pprint("expected", expected)
raise ValueError(
f"TIR after transformation did not match expected:\n{script}"
f"TIR after transformation did not match expected:\n"
f"{before_str}\n{after_str}\n{expected_str}"
) from err

else:
raise TypeError(
f"tvm.testing.CompareBeforeAfter requires the `expected` fixture "
f"to return either `Exception`, an `Exception` subclass, "
f"or an instance of `tvm.tir.PrimFunc`. "
f"Instead, received {type(exception)}."
f"Instead, received {type(expected)}."
)
49 changes: 48 additions & 1 deletion tests/python/unittest/test_tvm_testing_before_after.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import tvm
import tvm.testing
from tvm.script import tir as T
from tvm.script import tir as T, ir_module


class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
Expand Down Expand Up @@ -79,5 +79,52 @@ def func(A: T.Buffer[n, "float32"]):
expected = before


class TestBeforeAfterIRModule(BaseBeforeAfter):
"""The preferred form for writing TIR unit tests

All evaluation is done at test-time, with the minimal amount of
additional lines. The `@tvm.testing.fixture`, `@ir_module`, and
`@T.prim_func` annotations are handled by
`tvm.testing.CompareBeforeAfter`.
"""

class before:
def func_A(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
A[i] = 0.0

def func_B(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
A[i] = 42

expected = before


class TestBeforeAfterIRModuleExplicitFixture(BaseBeforeAfter):
"""Like TestBeforeAfterIRModule, but with an explicit fixture

If the IRModule depends on additional fixtures, this form can be
used.
"""

@tvm.testing.fixture
def before(self):
@ir_module
class mod:
@T.prim_func
def func_A(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
A[i] = 0.0

@T.prim_func
def func_B(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
A[i] = 42

return mod

expected = before


if __name__ == "__main__":
tvm.testing.main()