From 831695b37641be6bff4258c61192515746efd42e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 22 Feb 2023 09:15:58 -0600 Subject: [PATCH 1/2] [Utils] Allow classmethod and staticmethod in TVMDerivedObject Instance methods that exist in the user-defined class but not in the TVM base are forward using `__getattr__`. However, this is only applied for attribute look of instances, and doesn't apply for attribute lookup on the class object itself, such as when calling a classmethod or staticmethod. This commit exposes class methods and static methods in the wrapper class, if they are defined in the user-defined subclass. --- python/tvm/meta_schedule/utils.py | 3 +++ .../test_meta_schedule_post_order_apply.py | 21 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 401fdab08a26..959e85fe17eb 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -128,6 +128,9 @@ def __setattr__(self, name, value): TVMDerivedObject.__name__ = cls.__name__ TVMDerivedObject.__doc__ = cls.__doc__ TVMDerivedObject.__module__ = cls.__module__ + for key,value in cls.__dict__.items(): + if isinstance(value, (classmethod,staticmethod)): + setattr(TVMDerivedObject, key, value) return TVMDerivedObject diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index c1d2dc3d0788..716f829653f3 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -404,5 +404,26 @@ def _get_sch(filter_fn): assert len(schs) == 8 +def test_meta_schedule_derived_object(): + @derived_object + class RemoveBlock(PyScheduleRule): + @classmethod + def class_construct(cls): + return cls() + + @staticmethod + def static_construct(): + return RemoveBlock() + + inst_by_init = RemoveBlock() + assert isinstance(inst_by_init, RemoveBlock) + + inst_by_classmethod = RemoveBlock.class_construct() + assert isinstance(inst_by_classmethod, RemoveBlock) + + inst_by_staticmethod = RemoveBlock.static_construct() + assert isinstance(inst_by_staticmethod, RemoveBlock) + + if __name__ == "__main__": tvm.testing.main() From 5ffdd2ae0ffde1d051dcc4e7c682692f32178848 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 10 Mar 2023 08:30:38 -0600 Subject: [PATCH 2/2] fix linting errors --- python/tvm/meta_schedule/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 959e85fe17eb..fb1ddd6585f2 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -128,8 +128,8 @@ def __setattr__(self, name, value): TVMDerivedObject.__name__ = cls.__name__ TVMDerivedObject.__doc__ = cls.__doc__ TVMDerivedObject.__module__ = cls.__module__ - for key,value in cls.__dict__.items(): - if isinstance(value, (classmethod,staticmethod)): + for key, value in cls.__dict__.items(): + if isinstance(value, (classmethod, staticmethod)): setattr(TVMDerivedObject, key, value) return TVMDerivedObject