diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 401fdab08a26..fb1ddd6585f2 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -128,6 +128,9 @@ def __setattr__(self, name, value): TVMDerivedObject.__name__ = cls.__name__ TVMDerivedObject.__doc__ = cls.__doc__ TVMDerivedObject.__module__ = cls.__module__ + for key, value in cls.__dict__.items(): + if isinstance(value, (classmethod, staticmethod)): + setattr(TVMDerivedObject, key, value) return TVMDerivedObject diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index c1d2dc3d0788..716f829653f3 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -404,5 +404,26 @@ def _get_sch(filter_fn): assert len(schs) == 8 +def test_meta_schedule_derived_object(): + @derived_object + class RemoveBlock(PyScheduleRule): + @classmethod + def class_construct(cls): + return cls() + + @staticmethod + def static_construct(): + return RemoveBlock() + + inst_by_init = RemoveBlock() + assert isinstance(inst_by_init, RemoveBlock) + + inst_by_classmethod = RemoveBlock.class_construct() + assert isinstance(inst_by_classmethod, RemoveBlock) + + inst_by_staticmethod = RemoveBlock.static_construct() + assert isinstance(inst_by_staticmethod, RemoveBlock) + + if __name__ == "__main__": tvm.testing.main()