diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index ad1e003d6e3f..213cc40270ae 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1859,10 +1859,7 @@ def __init_subclass__(cls): cls.transform = cls._normalize_transform(cls.transform) @classmethod - def _normalize_before(cls, func): - if hasattr(func, "_pytestfixturefunction"): - return func - + def _normalize_ir_module(cls, func): if isinstance(func, tvm.tir.PrimFunc): def inner(self): @@ -1875,6 +1872,22 @@ def inner(self): # pylint: disable=unused-argument return func(self) + elif inspect.isclass(func): + + def inner(self): + # pylint: disable=unused-argument + func_dict = {} + for name, method in func.__dict__.items(): + if name.startswith("_"): + pass + elif isinstance(method, tvm.ir.function.BaseFunc): + func_dict[name] = method + else: + source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(method)) + prim_func = tvm.script.from_source(source_code) + func_dict[name] = prim_func + return tvm.IRModule(func_dict) + else: def inner(self): @@ -1884,50 +1897,64 @@ def inner(self): return pytest.fixture(inner) + @classmethod + def _normalize_before(cls, func): + if hasattr(func, "_pytestfixturefunction"): + return func + else: + return cls._normalize_ir_module(func) + @classmethod def _normalize_expected(cls, func): if hasattr(func, "_pytestfixturefunction"): return func - if isinstance(func, tvm.tir.PrimFunc) or ( - inspect.isclass(func) and issubclass(func, Exception) - ): + elif inspect.isclass(func) and issubclass(func, Exception): def inner(self): # pylint: disable=unused-argument return func - elif cls._is_method(func): - - def inner(self): - # pylint: disable=unused-argument - return func(self) + return pytest.fixture(inner) else: - - def inner(self): - # pylint: disable=unused-argument - source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func)) - return tvm.script.from_source(source_code) - - return pytest.fixture(inner) + return cls._normalize_ir_module(func) @classmethod def _normalize_transform(cls, transform): + def apply(module_transform): + def inner(obj): + if isinstance(obj, tvm.IRModule): + return module_transform(obj) + elif isinstance(obj, tvm.tir.PrimFunc): + mod = tvm.IRModule({"main": obj}) + mod = module_transform(mod) + return mod["main"] + else: + raise TypeError(f"Expected IRModule or PrimFunc, but received {type(obj)}") + + return inner + if hasattr(transform, "_pytestfixturefunction"): - return transform - if isinstance(transform, tvm.ir.transform.Pass): + if not hasattr(cls, "_transform_orig"): + cls._transform_orig = transform + + def inner(self, _transform_orig): + # pylint: disable=unused-argument + return apply(_transform_orig) + + elif isinstance(transform, tvm.ir.transform.Pass): def inner(self): # pylint: disable=unused-argument - return transform + return apply(transform) elif cls._is_method(transform): def inner(self): # pylint: disable=unused-argument - return transform(self) + return apply(transform(self)) else: @@ -1945,36 +1972,42 @@ def _is_method(func): def test_compare(self, before, expected, transform): """Unit test to compare the expected TIR PrimFunc to actual""" - before_mod = tvm.IRModule.from_expr(before) + def pprint(name, obj): + script = obj.script() + if isinstance(obj, tvm.IRModule): + return script.replace("class Module", f"class {name}") + else: + return script.replace("def func", f"def {name}") if inspect.isclass(expected) and issubclass(expected, Exception): with pytest.raises(expected): - after_mod = transform(before_mod) + after = transform(before) # This portion through pytest.fail isn't strictly # necessary, but gives a better error message that # includes the before/after. - after = after_mod["main"] - script = tvm.IRModule({"after": after, "before": before}).script() + before_str = pprint("before", before) + after_str = pprint("after", after) + pytest.fail( msg=( f"Expected {expected.__name__} to be raised from transformation, " - f"instead received TIR\n:{script}" + f"instead received TIR\n:{before_str}\n{after_str}" ) ) - elif isinstance(expected, tvm.tir.PrimFunc): - after_mod = transform(before_mod) - after = after_mod["main"] + elif isinstance(expected, (tvm.tir.PrimFunc, tvm.ir.IRModule)): + after = transform(before) try: tvm.ir.assert_structural_equal(after, expected) except ValueError as err: - script = tvm.IRModule( - {"expected": expected, "after": after, "before": before} - ).script() + before_str = pprint("before", before) + after_str = pprint("after", after) + expected_str = pprint("expected", expected) raise ValueError( - f"TIR after transformation did not match expected:\n{script}" + f"TIR after transformation did not match expected:\n" + f"{before_str}\n{after_str}\n{expected_str}" ) from err else: @@ -1982,5 +2015,5 @@ def test_compare(self, before, expected, transform): f"tvm.testing.CompareBeforeAfter requires the `expected` fixture " f"to return either `Exception`, an `Exception` subclass, " f"or an instance of `tvm.tir.PrimFunc`. " - f"Instead, received {type(exception)}." + f"Instead, received {type(expected)}." ) diff --git a/tests/python/unittest/test_tvm_testing_before_after.py b/tests/python/unittest/test_tvm_testing_before_after.py index 613d66ccdb2b..946493922ed5 100644 --- a/tests/python/unittest/test_tvm_testing_before_after.py +++ b/tests/python/unittest/test_tvm_testing_before_after.py @@ -18,7 +18,7 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tir as T, ir_module class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): @@ -79,5 +79,52 @@ def func(A: T.Buffer[n, "float32"]): expected = before +class TestBeforeAfterIRModule(BaseBeforeAfter): + """The preferred form for writing TIR unit tests + + All evaluation is done at test-time, with the minimal amount of + additional lines. The `@tvm.testing.fixture`, `@ir_module`, and + `@T.prim_func` annotations are handled by + `tvm.testing.CompareBeforeAfter`. + """ + + class before: + def func_A(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + A[i] = 0.0 + + def func_B(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = 42 + + expected = before + + +class TestBeforeAfterIRModuleExplicitFixture(BaseBeforeAfter): + """Like TestBeforeAfterIRModule, but with an explicit fixture + + If the IRModule depends on additional fixtures, this form can be + used. + """ + + @tvm.testing.fixture + def before(self): + @ir_module + class mod: + @T.prim_func + def func_A(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + A[i] = 0.0 + + @T.prim_func + def func_B(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = 42 + + return mod + + expected = before + + if __name__ == "__main__": tvm.testing.main()