From 6b5cb2f5fb2f964c47584efd3e15345a75571867 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Thu, 3 Oct 2024 06:52:33 +0000 Subject: [PATCH 1/2] [TVMScript] Enable T.macro decorateing class method This PR refactors the implementation of `T.macro`, so that the `self` argument can be passed through the TVMScript parser. Then we can decroate the class methods with `T.macro`. --- python/tvm/script/parser/core/parser.py | 4 +-- python/tvm/script/parser/relax/entry.py | 7 ++-- python/tvm/script/parser/tir/entry.py | 7 ++-- .../tvmscript/test_tvmscript_parser_tir.py | 35 ++++++++++++++++--- 4 files changed, 43 insertions(+), 10 deletions(-) diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 372a3c54e4c5..f40b9a7cf6d3 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -135,9 +135,9 @@ def _find_parser_def(self): def get_macro_def(self): ast_module = self.source.as_ast() for decl in ast_module.body: - if isinstance(decl, doc.FunctionDef) and decl.name == self.__name__: + if isinstance(decl, doc.FunctionDef) and decl.name == self.func.__name__: return decl - raise RuntimeError(f"cannot find macro definition for {self.__name__}") + raise RuntimeError(f"cannot find macro definition for {self.func.__name__}") def __call__(self, *args, **kwargs): param_binding = inspect.signature(self.func).bind(*args, **kwargs) diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index 73a5d7149a81..04a5f985643e 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -128,8 +128,11 @@ def macro(*args, hygienic: bool = True) -> _Callable: def _decorator(func: _Callable) -> ScriptMacro: source, closure_vars = scan_macro(func, utils.inspect_function_capture(func)) obj = RelaxMacro(source, closure_vars, func, hygienic) - obj.__name__ = func.__name__ - return obj + + def wrapper(*args, **kwargs): + return obj(*args, **kwargs) + + return wrapper if len(args) == 0: return _decorator diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 79eb88dfc102..c7d5dc756b32 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -139,8 +139,11 @@ def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: def _decorator(func: Callable) -> TIRMacro: source, closure_vars = scan_macro(func, utils.inspect_function_capture(func)) obj = TIRMacro(source, closure_vars, func, hygienic) - obj.__name__ = func.__name__ - return obj + + def wrapper(*args, **kwargs): + return obj(*args, **kwargs) + + return wrapper if len(args) == 0: return _decorator diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index 2dcbc89d47a6..d4207ef0fa76 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -116,8 +116,6 @@ def evaluate0(): def func1(): T.evaluate(0) - assert func1.hygienic - @T.prim_func(private=True) def use1(): func1() @@ -129,8 +127,6 @@ def use1(): def func2(): T.evaluate(0) - assert func2.hygienic - @T.prim_func(private=True) def use2(): func2() @@ -212,6 +208,37 @@ def expected_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32" tvm.ir.assert_structural_equal(use_non_hygienic, expected_non_hygienic) +def test_tir_macro_in_class(): + class Object: + def __init__(self, x: T.Buffer): + self.local_x = T.alloc_buffer(x.shape, x.dtype) + + @T.macro + def load(self, x: T.Buffer): + N, M = T.meta_var(self.local_x.shape) + for i, j in T.grid(N, M): + with T.block("update"): + vi, vj = T.axis.remap("SS", [i, j]) + self.local_x[vi, vj] = x[vi, vj] + + @T.prim_func(private=True) + def func_w_macro(a: T.handle): + A = T.match_buffer(a, [128, 128]) + o = T.meta_var(Object(A)) + o.load(A) + + @T.prim_func(private=True) + def func_no_macro(a: T.handle): + A = T.match_buffer(a, [128, 128]) + local_a = T.alloc_buffer([128, 128]) + for i, j in T.grid(128, 128): + with T.block("update"): + vi, vj = T.axis.remap("SS", [i, j]) + local_a[vi, vj] = A[vi, vj] + + tvm.ir.assert_structural_equal(func_no_macro, func_w_macro) + + def test_tir_starred_expression(): dims = (128, 128) From 04dcf1115d4298a47c2107f5811a829df8ce0a8a Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Thu, 3 Oct 2024 06:58:17 +0000 Subject: [PATCH 2/2] update test --- tests/python/tvmscript/test_tvmscript_parser_tir.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index d4207ef0fa76..16b206751402 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -224,8 +224,10 @@ def load(self, x: T.Buffer): @T.prim_func(private=True) def func_w_macro(a: T.handle): A = T.match_buffer(a, [128, 128]) - o = T.meta_var(Object(A)) - o.load(A) + o1 = T.meta_var(Object(A)) + o1.load(A) + o2 = T.meta_var(Object(A)) + o2.load(o1.local_x) @T.prim_func(private=True) def func_no_macro(a: T.handle): @@ -235,6 +237,11 @@ def func_no_macro(a: T.handle): with T.block("update"): vi, vj = T.axis.remap("SS", [i, j]) local_a[vi, vj] = A[vi, vj] + local_b = T.alloc_buffer([128, 128]) + for i, j in T.grid(128, 128): + with T.block("update"): + vi, vj = T.axis.remap("SS", [i, j]) + local_b[vi, vj] = local_a[vi, vj] tvm.ir.assert_structural_equal(func_no_macro, func_w_macro)