diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 70cd7a02dab0..7817ddcb0189 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1890,7 +1890,7 @@ class CompareBeforeAfter: input, apply a transformation, then either compare against an expected output or assert that the transformation raised an error. A test should subclass CompareBeforeAfter, defining class members - `before`, `transform`, and `expected`. CompareBeforeAfter will + `before` / `Before`, `transform`, and `expected` / `Expected`. CompareBeforeAfter will then use these members to define a test method and test fixture. `transform` may be one of the following. @@ -1901,7 +1901,7 @@ class CompareBeforeAfter: - A pytest fixture that returns a `tvm.ir.transform.Pass` - `before` may be any one of the following. + `before` / `Before` may be any one of the following. - An instance of `tvm.tir.PrimFunc`. This is allowed, but is not the preferred method, as any errors in constructing the @@ -1916,13 +1916,13 @@ class CompareBeforeAfter: - A pytest fixture that returns a `tvm.tir.PrimFunc` - `expected` may be any one of the following. The type of - `expected` defines the test being performed. If `expected` + `expected` / `Expected` may be any one of the following. The type of + `expected` / `Expected` defines the test being performed. If `expected` provides a `tvm.tir.PrimFunc`, the result of the transformation must match `expected`. If `expected` is an exception, then the transformation must raise that exception type. - - Any option supported for `before`. + - Any option supported for `before` / `Before`. - The `Exception` class object, or a class object that inherits from `Exception`. @@ -1953,10 +1953,19 @@ def expected(A: T.Buffer(1, "int32")): """ def __init_subclass__(cls): - if hasattr(cls, "before"): - cls.before = cls._normalize_before(cls.before) - if hasattr(cls, "expected"): - cls.expected = cls._normalize_expected(cls.expected) + assert len([getattr(cls, name) for name in ["before", "Before"] if hasattr(cls, name)]) <= 1 + assert ( + len([getattr(cls, name) for name in ["expected", "Expected"] if hasattr(cls, name)]) + <= 1 + ) + for name in ["before", "Before"]: + if hasattr(cls, name): + cls.before = cls._normalize_before(getattr(cls, name)) + break + for name in ["expected", "Expected"]: + if hasattr(cls, name): + cls.expected = cls._normalize_expected(getattr(cls, name)) + break if hasattr(cls, "transform"): cls.transform = cls._normalize_transform(cls.transform)