diff --git a/utils/get_test_info.py b/utils/get_test_info.py index 83dd4e297053..59201ef17949 100644 --- a/utils/get_test_info.py +++ b/utils/get_test_info.py @@ -15,6 +15,7 @@ import importlib import os import sys +import unittest # This is required to make the module import works (when the python process is running from the root of the repo) @@ -87,11 +88,19 @@ def get_test_classes(test_file): test_module = get_test_module(test_file) for attr in dir(test_module): attr_value = getattr(test_module, attr) - # ModelTesterMixin is also an attribute in specific model test module. Let's exclude them by checking - # `all_model_classes` is not empty (which also excludes other special classes). - model_classes = getattr(attr_value, "all_model_classes", []) - if len(model_classes) > 0: - test_classes.append(attr_value) + + # Look for the test classes (subclass of `unittest.TestCase`) with `all_model_classes` attribute. + # This also excludes `ModelTesterMixin` and `CausalLMModelTest`. + if isinstance(attr_value, type) and issubclass(attr_value, unittest.TestCase): + model_classes = getattr(attr_value, "all_model_classes", []) + # `CausalLMModelTest` (subclass of `ModelTesterMixin`) has `all_model_classes` as a class attribute with + # the value being `None`. For a real test class of `CausalLMModelTest`, the value is only set during `setUp`. + if model_classes is None: + test_instance = attr_value() + test_instance.setUp() + model_classes = getattr(test_instance, "all_model_classes", []) + if len(model_classes) > 0: + test_classes.append(attr_value) # sort with class names return sorted(test_classes, key=lambda x: x.__name__) @@ -102,7 +111,12 @@ def get_model_classes(test_file): test_classes = get_test_classes(test_file) model_classes = set() for test_class in test_classes: - model_classes.update(test_class.all_model_classes) + all_model_classes = test_class.all_model_classes + if all_model_classes is None: + test_instance = test_class() + test_instance.setUp() + all_model_classes = test_instance.all_model_classes + model_classes.update(all_model_classes) # sort with class names return sorted(model_classes, key=lambda x: x.__name__) @@ -128,8 +142,15 @@ def get_test_classes_for_model(test_file, model_class): test_classes = get_test_classes(test_file) target_test_classes = [] + for test_class in test_classes: - if model_class in test_class.all_model_classes: + all_model_classes = test_class.all_model_classes + if all_model_classes is None: + test_instance = test_class() + test_instance.setUp() + all_model_classes = test_instance.all_model_classes + + if model_class in all_model_classes: target_test_classes.append(test_class) # sort with class names