From a3294f75e2976e77bd26091e450876b63b2ffd97 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 15 Aug 2023 15:46:05 -0700 Subject: [PATCH 1/3] [Testing] Allow Capitalized name in CompareBeforeAfter --- python/tvm/testing/utils.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 70cd7a02dab0..7e856a0f670a 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1890,7 +1890,7 @@ class CompareBeforeAfter: input, apply a transformation, then either compare against an expected output or assert that the transformation raised an error. A test should subclass CompareBeforeAfter, defining class members - `before`, `transform`, and `expected`. CompareBeforeAfter will + `before` / `Before`, `transform`, and `expected` / `Expected`. CompareBeforeAfter will then use these members to define a test method and test fixture. `transform` may be one of the following. @@ -1901,7 +1901,7 @@ class CompareBeforeAfter: - A pytest fixture that returns a `tvm.ir.transform.Pass` - `before` may be any one of the following. + `before` / `Before` may be any one of the following. - An instance of `tvm.tir.PrimFunc`. This is allowed, but is not the preferred method, as any errors in constructing the @@ -1916,13 +1916,13 @@ class CompareBeforeAfter: - A pytest fixture that returns a `tvm.tir.PrimFunc` - `expected` may be any one of the following. The type of - `expected` defines the test being performed. If `expected` + `expected` / `Expected` may be any one of the following. The type of + `expected` / `Expected` defines the test being performed. If `expected` provides a `tvm.tir.PrimFunc`, the result of the transformation must match `expected`. If `expected` is an exception, then the transformation must raise that exception type. - - Any option supported for `before`. + - Any option supported for `before` / `Before`. - The `Exception` class object, or a class object that inherits from `Exception`. @@ -1953,10 +1953,13 @@ def expected(A: T.Buffer(1, "int32")): """ def __init_subclass__(cls): - if hasattr(cls, "before"): - cls.before = cls._normalize_before(cls.before) - if hasattr(cls, "expected"): - cls.expected = cls._normalize_expected(cls.expected) + for name in ["before", "Before"]: + if hasattr(cls, name): + cls.before = cls._normalize_before(getattr(cls, name)) + break + for name in ["expected", "Expected"]: + if hasattr(cls, name): + cls.expected = cls._normalize_expected(getattr(cls, name)) if hasattr(cls, "transform"): cls.transform = cls._normalize_transform(cls.transform) From 16a476e597a2de5cc65b60612868e4374e0a89da Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 16 Aug 2023 12:00:42 -0700 Subject: [PATCH 2/3] Address comments --- python/tvm/testing/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 7e856a0f670a..bd35e3fb4c86 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1953,6 +1953,10 @@ def expected(A: T.Buffer(1, "int32")): """ def __init_subclass__(cls): + assert (hasattr(cls, "before") and hasattr(cls, "Expected")) or ( + hasattr(cls, "Before") and hasattr(cls, "Expected") + ), "The subclass of CompareBeforeAfter should have either 'before' and 'expected', or " + "'Before' and 'Expected' defined." for name in ["before", "Before"]: if hasattr(cls, name): cls.before = cls._normalize_before(getattr(cls, name)) From b8fa0d33779222cd9073eb7608b92dc0b84f87b8 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 16 Aug 2023 15:19:00 -0700 Subject: [PATCH 3/3] Address comments --- python/tvm/testing/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index bd35e3fb4c86..7817ddcb0189 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1953,10 +1953,11 @@ def expected(A: T.Buffer(1, "int32")): """ def __init_subclass__(cls): - assert (hasattr(cls, "before") and hasattr(cls, "Expected")) or ( - hasattr(cls, "Before") and hasattr(cls, "Expected") - ), "The subclass of CompareBeforeAfter should have either 'before' and 'expected', or " - "'Before' and 'Expected' defined." + assert len([getattr(cls, name) for name in ["before", "Before"] if hasattr(cls, name)]) <= 1 + assert ( + len([getattr(cls, name) for name in ["expected", "Expected"] if hasattr(cls, name)]) + <= 1 + ) for name in ["before", "Before"]: if hasattr(cls, name): cls.before = cls._normalize_before(getattr(cls, name)) @@ -1964,6 +1965,7 @@ def __init_subclass__(cls): for name in ["expected", "Expected"]: if hasattr(cls, name): cls.expected = cls._normalize_expected(getattr(cls, name)) + break if hasattr(cls, "transform"): cls.transform = cls._normalize_transform(cls.transform)