Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def test_something():
import platform
import shutil
import sys
import textwrap
import time

from pathlib import Path
Expand Down Expand Up @@ -1707,3 +1708,208 @@ def fetch_model_from_url(
def main():
test_file = inspect.getsourcefile(sys._getframe(1))
sys.exit(pytest.main([test_file] + sys.argv[1:]))


class CompareBeforeAfter:
"""Utility for comparing before/after of TIR transforms

A standard framework for writing tests that take a TIR PrimFunc as
input, apply a transformation, then either compare against an
expected output or assert that the transformation raised an error.
A test should subclass CompareBeforeAfter, defining class members
`before`, `transform`, and `expected`. CompareBeforeAfter will
then use these members to define a test method and test fixture.

`transform` may be one of the following.

- An instance of `tvm.ir.transform.Pass`

- A method that takes no arguments and returns a `tvm.ir.transform.Pass`

- A pytest fixture that returns a `tvm.ir.transform.Pass`

`before` may be any one of the following.

- An instance of `tvm.tir.PrimFunc`. This is allowed, but is not
the preferred method, as any errors in constructing the
`PrimFunc` occur while collecting the test, preventing any other
tests in the same file from being run.

- An TVMScript function, without the ``@T.prim_func`` decoration.
The ``@T.prim_func`` decoration will be applied when running the
test, rather than at module import.

- A method that takes no arguments and returns a `tvm.tir.PrimFunc`

- A pytest fixture that returns a `tvm.tir.PrimFunc`

`expected` may be any one of the following. The type of
`expected` defines the test being performed. If `expected`
provides a `tvm.tir.PrimFunc`, the result of the transformation
must match `expected`. If `expected` is an exception, then the
transformation must raise that exception type.

- Any option supported for `before`.

- The `Exception` class object, or a class object that inherits
from `Exception`.

- A method that takes no arguments and returns `Exception` or a
class object that inherits from `Exception`.

- A pytest fixture that returns `Exception` or an class object
that inherits from `Exception`.

Examples
--------

.. python::

class TestRemoveIf(tvm.testing.CompareBeforeAfter):
transform = tvm.tir.transform.Simplify()

def before(A: T.Buffer[1, "int32"]):
if True:
A[0] = 42
else:
A[0] = 5

def expected(A: T.Buffer[1, "int32"]):
A[0] = 42

"""

def __init_subclass__(cls):
if hasattr(cls, "before"):
cls.before = cls._normalize_before(cls.before)
if hasattr(cls, "expected"):
cls.expected = cls._normalize_expected(cls.expected)
if hasattr(cls, "transform"):
cls.transform = cls._normalize_transform(cls.transform)

@classmethod
def _normalize_before(cls, func):
if hasattr(func, "_pytestfixturefunction"):
return func

if isinstance(func, tvm.tir.PrimFunc):

def inner(self):
# pylint: disable=unused-argument
return func

elif cls._is_method(func):

def inner(self):
# pylint: disable=unused-argument
return func(self)

else:

def inner(self):
# pylint: disable=unused-argument
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
return tvm.script.from_source(source_code)

return pytest.fixture(inner)

@classmethod
def _normalize_expected(cls, func):
if hasattr(func, "_pytestfixturefunction"):
return func

if isinstance(func, tvm.tir.PrimFunc) or (
inspect.isclass(func) and issubclass(func, Exception)
):

def inner(self):
# pylint: disable=unused-argument
return func

elif cls._is_method(func):

def inner(self):
# pylint: disable=unused-argument
return func(self)

else:

def inner(self):
# pylint: disable=unused-argument
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
return tvm.script.from_source(source_code)

return pytest.fixture(inner)

@classmethod
def _normalize_transform(cls, transform):
if hasattr(transform, "_pytestfixturefunction"):
return transform

if isinstance(transform, tvm.ir.transform.Pass):

def inner(self):
# pylint: disable=unused-argument
return transform

elif cls._is_method(transform):

def inner(self):
# pylint: disable=unused-argument
return transform(self)

else:

raise TypeError(
"Expected transform to be a tvm.ir.transform.Pass, or a method returning a Pass"
)

return pytest.fixture(inner)

@staticmethod
def _is_method(func):
sig = inspect.signature(func)
return "self" in sig.parameters

def test_compare(self, before, expected, transform):
"""Unit test to compare the expected TIR PrimFunc to actual"""

before_mod = tvm.IRModule.from_expr(before)

if inspect.isclass(expected) and issubclass(expected, Exception):
with pytest.raises(expected):
after_mod = transform(before_mod)

# This portion through pytest.fail isn't strictly
# necessary, but gives a better error message that
# includes the before/after.
after = after_mod["main"]
script = tvm.IRModule({"after": after, "before": before}).script()
pytest.fail(
msg=(
f"Expected {expected.__name__} to be raised from transformation, "
f"instead received TIR\n:{script}"
)
)

elif isinstance(expected, tvm.tir.PrimFunc):
after_mod = transform(before_mod)
after = after_mod["main"]

try:
tvm.ir.assert_structural_equal(after, expected)
except ValueError as err:
script = tvm.IRModule(
{"expected": expected, "after": after, "before": before}
).script()
raise ValueError(
f"TIR after transformation did not match expected:\n{script}"
) from err

else:
raise TypeError(
f"tvm.testing.CompareBeforeAfter requires the `expected` fixture "
f"to return either `Exception`, an `Exception` subclass, "
f"or an instance of `tvm.tir.PrimFunc`. "
f"Instead, received {type(exception)}."
)
38 changes: 2 additions & 36 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,31 +136,16 @@ def sls(n, d):
assert "if" not in str(stmt)


class BaseBeforeAfter:
def test_simplify(self):
before = self.before
before_mod = tvm.IRModule.from_expr(before)
after_mod = tvm.tir.transform.Simplify()(before_mod)
after = after_mod["main"]
expected = self.expected

try:
tvm.ir.assert_structural_equal(after, expected)
except ValueError as err:
script = tvm.IRModule({"expected": expected, "after": after, "before": before}).script()
raise ValueError(
f"Function after simplification did not match expected:\n{script}"
) from err
class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
transform = tvm.tir.transform.Simplify()


class TestLoadStoreNoop(BaseBeforeAfter):
"""Store of a value that was just read from the same location is a no-op."""

@T.prim_func
def before(A: T.Buffer[(1,), "float32"]):
A[0] = A[0]

@T.prim_func
def expected(A: T.Buffer[(1,), "float32"]):
T.evaluate(0)

Expand All @@ -174,11 +159,9 @@ class TestLoadStoreNoopAfterSimplify(BaseBeforeAfter):
regression.
"""

@T.prim_func
def before(A: T.Buffer[(1,), "float32"]):
A[0] = A[0] + (5.0 - 5.0)

@T.prim_func
def expected(A: T.Buffer[(1,), "float32"]):
T.evaluate(0)

Expand All @@ -191,14 +174,12 @@ class TestNestedCondition(BaseBeforeAfter):
constraint.
"""

@T.prim_func
def before(A: T.Buffer[(16,), "float32"]):
for i in T.serial(16):
if i == 5:
if i == 5:
A[i] = 0.0

@T.prim_func
def expected(A: T.Buffer[(16,), "float32"]):
for i in T.serial(16):
if i == 5:
Expand All @@ -212,14 +193,12 @@ class TestNestedProvableCondition(BaseBeforeAfter):
conditional.
"""

@T.prim_func
def before(A: T.Buffer[(16,), "float32"]):
for i in T.serial(16):
if i == 5:
if i < 7:
A[i] = 0.0

@T.prim_func
def expected(A: T.Buffer[(16,), "float32"]):
for i in T.serial(16):
if i == 5:
Expand All @@ -233,14 +212,12 @@ class TestNestedVarCondition(BaseBeforeAfter):
constraint.
"""

@T.prim_func
def before(A: T.Buffer[(16,), "float32"], n: T.int32):
for i in T.serial(16):
if i == n:
if i == n:
A[i] = 0.0

@T.prim_func
def expected(A: T.Buffer[(16,), "float32"], n: T.int32):
for i in T.serial(16):
if i == n:
Expand All @@ -256,7 +233,6 @@ class TestAlteredBufferContents(BaseBeforeAfter):
may not.
"""

@T.prim_func
def before(A: T.Buffer[(1,), "int32"], n: T.int32):
if A[0] == n:
A[0] = A[0] + 1
Expand All @@ -273,7 +249,6 @@ class TestNegationOfCondition(BaseBeforeAfter):
condition is known to be false.
"""

@T.prim_func
def before(A: T.Buffer[(16,), "int32"]):
for i in T.serial(16):
if i == 5:
Expand All @@ -282,7 +257,6 @@ def before(A: T.Buffer[(16,), "int32"]):
else:
A[i] = 1

@T.prim_func
def expected(A: T.Buffer[(16,), "int32"]):
for i in T.serial(16):
if i == 5:
Expand All @@ -298,7 +272,6 @@ class TestNegationOfNotEqual(BaseBeforeAfter):
``i==5`` as the negation of a literal constraint.
"""

@T.prim_func
def before(A: T.Buffer[(16,), "int32"]):
for i in T.serial(16):
if i != 5:
Expand All @@ -307,7 +280,6 @@ def before(A: T.Buffer[(16,), "int32"]):
else:
A[i] = 1

@T.prim_func
def expected(A: T.Buffer[(16,), "int32"]):
for i in T.serial(16):
if i != 5:
Expand All @@ -321,7 +293,6 @@ class TestNegationOfVarCondition(BaseBeforeAfter):
must rely on RewriteSimplifier recognizing the repeated literal.
"""

@T.prim_func
def before(A: T.Buffer[(16,), "int32"], n: T.int32):
for i in T.serial(16):
if i == n:
Expand All @@ -330,7 +301,6 @@ def before(A: T.Buffer[(16,), "int32"], n: T.int32):
else:
A[i] = 1

@T.prim_func
def expected(A: T.Buffer[(16,), "int32"], n: T.int32):
for i in T.serial(16):
if i == n:
Expand All @@ -346,14 +316,12 @@ class TestLiteralConstraintSplitBooleanAnd(BaseBeforeAfter):
the condition is to ensure we exercise RewriteSimplifier.
"""

@T.prim_func
def before(A: T.Buffer[(16, 16), "int32"], n: T.int32):
for i, j in T.grid(16, 16):
if i == n and j == n:
if i == n:
A[i, j] = 0

@T.prim_func
def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32):
for i, j in T.grid(16, 16):
if i == n and j == n:
Expand All @@ -371,7 +339,6 @@ class TestLiteralConstraintSplitBooleanOr(BaseBeforeAfter):
RewriteSimplifier.
"""

@T.prim_func
def before(A: T.Buffer[(16, 16), "int32"], n: T.int32):
for i, j in T.grid(16, 16):
if i == n or j == n:
Expand All @@ -382,7 +349,6 @@ def before(A: T.Buffer[(16, 16), "int32"], n: T.int32):
else:
A[i, j] = 2

@T.prim_func
def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32):
for i, j in T.grid(16, 16):
if i == n or j == n:
Expand Down
Loading