diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 372a3c54e4c5..f40b9a7cf6d3 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -135,9 +135,9 @@ def _find_parser_def(self): def get_macro_def(self): ast_module = self.source.as_ast() for decl in ast_module.body: - if isinstance(decl, doc.FunctionDef) and decl.name == self.__name__: + if isinstance(decl, doc.FunctionDef) and decl.name == self.func.__name__: return decl - raise RuntimeError(f"cannot find macro definition for {self.__name__}") + raise RuntimeError(f"cannot find macro definition for {self.func.__name__}") def __call__(self, *args, **kwargs): param_binding = inspect.signature(self.func).bind(*args, **kwargs) diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index 73a5d7149a81..04a5f985643e 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -128,8 +128,11 @@ def macro(*args, hygienic: bool = True) -> _Callable: def _decorator(func: _Callable) -> ScriptMacro: source, closure_vars = scan_macro(func, utils.inspect_function_capture(func)) obj = RelaxMacro(source, closure_vars, func, hygienic) - obj.__name__ = func.__name__ - return obj + + def wrapper(*args, **kwargs): + return obj(*args, **kwargs) + + return wrapper if len(args) == 0: return _decorator diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 79eb88dfc102..c7d5dc756b32 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -139,8 +139,11 @@ def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: def _decorator(func: Callable) -> TIRMacro: source, closure_vars = scan_macro(func, utils.inspect_function_capture(func)) obj = TIRMacro(source, closure_vars, func, hygienic) - obj.__name__ = func.__name__ - return obj + + def wrapper(*args, **kwargs): + return obj(*args, **kwargs) + + return wrapper if len(args) == 0: return _decorator diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index 2dcbc89d47a6..16b206751402 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -116,8 +116,6 @@ def evaluate0(): def func1(): T.evaluate(0) - assert func1.hygienic - @T.prim_func(private=True) def use1(): func1() @@ -129,8 +127,6 @@ def use1(): def func2(): T.evaluate(0) - assert func2.hygienic - @T.prim_func(private=True) def use2(): func2() @@ -212,6 +208,44 @@ def expected_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32" tvm.ir.assert_structural_equal(use_non_hygienic, expected_non_hygienic) +def test_tir_macro_in_class(): + class Object: + def __init__(self, x: T.Buffer): + self.local_x = T.alloc_buffer(x.shape, x.dtype) + + @T.macro + def load(self, x: T.Buffer): + N, M = T.meta_var(self.local_x.shape) + for i, j in T.grid(N, M): + with T.block("update"): + vi, vj = T.axis.remap("SS", [i, j]) + self.local_x[vi, vj] = x[vi, vj] + + @T.prim_func(private=True) + def func_w_macro(a: T.handle): + A = T.match_buffer(a, [128, 128]) + o1 = T.meta_var(Object(A)) + o1.load(A) + o2 = T.meta_var(Object(A)) + o2.load(o1.local_x) + + @T.prim_func(private=True) + def func_no_macro(a: T.handle): + A = T.match_buffer(a, [128, 128]) + local_a = T.alloc_buffer([128, 128]) + for i, j in T.grid(128, 128): + with T.block("update"): + vi, vj = T.axis.remap("SS", [i, j]) + local_a[vi, vj] = A[vi, vj] + local_b = T.alloc_buffer([128, 128]) + for i, j in T.grid(128, 128): + with T.block("update"): + vi, vj = T.axis.remap("SS", [i, j]) + local_b[vi, vj] = local_a[vi, vj] + + tvm.ir.assert_structural_equal(func_no_macro, func_w_macro) + + def test_tir_starred_expression(): dims = (128, 128)