diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 399adf27e4e5..884a885fb21d 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -2065,13 +2065,6 @@ def _is_method(func): def test_compare(self, before, expected, transform): """Unit test to compare the expected TIR PrimFunc to actual""" - def pprint(name, obj): - script = obj.script() - if isinstance(obj, tvm.IRModule): - return script.replace("class Module", f"class {name}") - else: - return script.replace("def func", f"def {name}") - if inspect.isclass(expected) and issubclass(expected, Exception): with pytest.raises(expected): after = transform(before) @@ -2079,8 +2072,8 @@ def pprint(name, obj): # This portion through pytest.fail isn't strictly # necessary, but gives a better error message that # includes the before/after. - before_str = pprint("before", before) - after_str = pprint("after", after) + before_str = before.script(name="before") + after_str = after.script(name="after") pytest.fail( msg=( @@ -2095,9 +2088,9 @@ def pprint(name, obj): try: tvm.ir.assert_structural_equal(after, expected) except ValueError as err: - before_str = pprint("before", before) - after_str = pprint("after", after) - expected_str = pprint("expected", expected) + before_str = before.script(name="before") + after_str = after.script(name="after") + expected_str = expected.script(name="expected") raise ValueError( f"TIR after transformation did not match expected:\n" f"{before_str}\n{after_str}\n{expected_str}"