diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index ee8032236252..0cc385d876a8 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -71,6 +71,8 @@ class PrimFuncFrameNode : public TIRFrameNode { Optional name; /*! \brief Function parameters. */ Array args; + /*! \brief Whether the PrimFunc is annotated as private. */ + bool is_private; /*! \brief The return type of the function. */ Optional ret_type; /*! \brief Maps some parameters to specific Buffer data structures. */ @@ -86,6 +88,7 @@ class PrimFuncFrameNode : public TIRFrameNode { TIRFrameNode::VisitAttrs(v); v->Visit("name", &name); v->Visit("args", &args); + v->Visit("is_private", &is_private); v->Visit("ret_type", &ret_type); v->Visit("buffer_map", &buffer_map); v->Visit("attrs", &attrs); diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 9ce478da43cd..735d5ba6c0a1 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -56,7 +56,7 @@ Buffer BufferDecl(Array shape, DataType dtype, String buffer_name, Opt * \brief The primitive function statement. * \return The PrimFuncFrame. */ -PrimFuncFrame PrimFunc(); +PrimFuncFrame PrimFunc(bool is_private); /*! * \brief The PrimFunc variable arguments adding function. diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py index 60961027be40..862077ecd8f0 100644 --- a/python/tvm/ir/function.py +++ b/python/tvm/ir/function.py @@ -67,3 +67,19 @@ def with_attr(self, attr_key_or_dict, attr_value=None): return _ffi_api.BaseFuncWithAttr( res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value) ) + + def without_attr(self, attr_key: str) -> "BaseFunc": + """Create a new copy of the function with an attribute without provided key. + + Parameters + ---------- + attr_key : str + The attribute key to delete from the attrubte pairs. + + + Returns + ------- + func : BaseFunc + A new copy of the function + """ + return _ffi_api.BaseFuncWithoutAttr(self, attr_key) diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py index 6689e45245e8..00004cb99247 100644 --- a/python/tvm/meta_schedule/testing/space_generation.py +++ b/python/tvm/meta_schedule/testing/space_generation.py @@ -39,6 +39,21 @@ def get_rules( return [rule for rule in rules if isinstance(rule, types)] +def structural_equal_no_gs(mod1: IRModule, mod2: IRModule) -> bool: + """ + Checks structural equality but ignores global symbols + """ + # for every function in the modules, remove global symbols from the attrs and then compare + def remove_global_symbols(mod: IRModule) -> IRModule: + stripped_mod = IRModule() + for global_var in mod.get_global_vars(): + func = mod[global_var] + stripped_mod[global_var] = func.without_attr("global_symbol") + return stripped_mod + + return structural_equal(remove_global_symbols(mod1), remove_global_symbols(mod2)) + + def generate_design_space( kind: Literal["llvm", "cuda", "cuda-tensorcore", "hexagon"], mod: IRModule, @@ -87,7 +102,7 @@ def _find_match_sketch_id( insts=sketch.trace.insts, decisions=new_decisions, ).apply_to_schedule(sch, remove_postproc=True) - if structural_equal(sch.mod, expected_mod): + if structural_equal_no_gs(sch.mod, expected_mod): verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask, text_format="json") return sketch_id return None diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 22f815c3d812..afe9388b9a14 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -162,15 +162,22 @@ def buffer_decl(*args, **kwargs): return buffer(*args, **kwargs) -def prim_func() -> frame.PrimFuncFrame: +def prim_func(is_private: bool = False) -> frame.PrimFuncFrame: """The primitive function statement. + Parameters + ---------- + is_private : bool + Whether the PrimFunc is annotated as private + (if yes, it does not have a global symbol assigned; + otherwise, the global symbol is the PrimFunc's name) + Returns ------- res : frame.PrimFuncFrame The PrimFuncFrame. """ - return _ffi_api.PrimFunc() # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.PrimFunc(is_private) # type: ignore[attr-defined] # pylint: disable=no-member def arg(name: str, obj: Union[Var, Buffer]) -> Union[Var, Buffer]: diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 64b71d699f3d..93bf8721c58e 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -16,7 +16,7 @@ # under the License. """The entry point of TVM parser for tir.""" import inspect -from typing import Any, Callable, Dict, Union +from typing import Any, Callable, Dict, Optional, Union from tvm.ir.base import deprecated from tvm.tir import Buffer, PrimFunc @@ -25,26 +25,52 @@ from .._core import doc, parse, parse_macro, utils -def prim_func(func: Callable) -> Union[PrimFunc, Callable]: +def prim_func(func: Optional[Callable] = None, private: bool = False) -> Union[PrimFunc, Callable]: """The parsing method for tir prim func, by using `@prim_func` as decorator. Parameters ---------- func : Callable The function to be parsed as prim func. + (Listed as optional to allow the decorator to be used + without arguments, like `@prim_func`, + or with an argument, `@prim_func(private=True)`) + + private : bool, optional + Whether the function should be treated as private. + A private function has no global symbol attribute; + if the function is not private, it will have a global symbol + matching the function name. Returns ------- res : Union[PrimFunc, Callable] The parsed tir prim func. """ - if not inspect.isfunction(func): - raise TypeError(f"Expect a function, but got: {func}") - if utils.is_defined_in_class(inspect.stack(), func): - return func - f = parse(func, utils.inspect_function_capture(func)) - setattr(f, "__name__", func.__name__) - return f + # pylint: disable=unused-argument + # (private will be used in the parser, but not immediately) + + # need to capture this var outside the wrapper because the wrapper + # adds to the stack + outer_stack = inspect.stack() + + def decorator_wrapper(func): + if not inspect.isfunction(func): + raise TypeError(f"Expect a function, but got: {func}") + if utils.is_defined_in_class(outer_stack, func): + return func + f = parse(func, utils.inspect_function_capture(func)) + setattr(f, "__name__", func.__name__) + return f + + if func is not None: + # no optional args given => use wrapper directly + return decorator_wrapper(func) + else: + # if there is an optional arg given, return a new decorator + # that will then be invoked + setattr(decorator_wrapper, "dispatch_token", "tir") + return decorator_wrapper setattr(prim_func, "dispatch_token", "tir") diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 67e14d0e9772..5398b471e49d 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -156,6 +156,21 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - return var +def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: bool = True) -> bool: + """ + Check the value of given annotation (argument name) in the prim_func decorator. + Returns the value of the annotation if present, otherwise giving the default value. + """ + # look for the named argument in the prim_func decorator + for dec in node.decorator_list: + if not isinstance(dec, doc.Call) or dec.func.attr != "prim_func": + continue + for keyword in dec.keywords: + if keyword.arg == annotation: + return keyword.value.value + return default + + @dispatch.register(token="tir", type_name="For") def visit_for(self: Parser, node: doc.For) -> None: """The for visiting method for tir. @@ -365,10 +380,11 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: """ supplied_annotation = self.function_annotations func_annotation = supplied_annotation.get(node.name, {}) + privacy = find_decorator_annotation(node, "private", default=False) self.function_annotations = None with self.var_table.with_frame(): self.var_table.add("range", T.serial) - with T.prim_func(): + with T.prim_func(is_private=privacy): T.func_name(node.name) if node.returns is not None: ret_type = self.eval_expr(node.returns) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index f29b7c4394c2..70cd7a02dab0 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1983,11 +1983,11 @@ def inner(self): if name.startswith("_"): pass elif isinstance(method, tvm.ir.function.BaseFunc): - func_dict[name] = method + func_dict[name] = method.with_attr("global_symbol", name) else: source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(method)) prim_func = tvm.script.from_source(source_code) - func_dict[name] = prim_func + func_dict[name] = prim_func.with_attr("global_symbol", name) return tvm.IRModule(func_dict) else: @@ -2093,6 +2093,10 @@ def test_compare(self, before, expected, transform): after = transform(before) try: + # overwrite global symbol so it doesn't come up in the comparison + if isinstance(after, tvm.tir.PrimFunc): + after = after.with_attr("global_symbol", "main") + expected = expected.with_attr("global_symbol", "main") tvm.ir.assert_structural_equal(after, expected) except ValueError as err: before_str = before.script(name="before") diff --git a/python/tvm/tir/schedule/testing.py b/python/tvm/tir/schedule/testing.py index a293b54b46a1..7fc6cd7b7d48 100644 --- a/python/tvm/tir/schedule/testing.py +++ b/python/tvm/tir/schedule/testing.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=dangerous-default-value """Testing utilities for the TensorIR schedule API""" -from typing import Sequence, Union +from typing import Any, Sequence, Union import tvm from tvm.ir import IRModule, assert_structural_equal @@ -24,6 +24,25 @@ from tvm.tir.schedule import Schedule, Trace +def assert_structural_equal_ignore_global_symbol( + func1: PrimFunc, + func2: PrimFunc, + *args: Any, + **kwargs: Any, +) -> None: + """ + Asserts that PrimFuncs func1 and func2 are structurally equal, setting both + their global symbol attributes to main so that the global symbol + will not be a point of comparison. + """ + assert_structural_equal( + func1.with_attr("global_symbol", "main"), + func2.with_attr("global_symbol", "main"), + *args, + **kwargs, + ) + + def verify_trace_roundtrip( sch: Schedule, mod: Union[PrimFunc, IRModule], diff --git a/src/ir/function.cc b/src/ir/function.cc index ce294708b2a9..d4522733992c 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -22,6 +22,7 @@ * \brief The function data structure. */ #include +#include #include #include @@ -44,4 +45,16 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); }); +TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr") + .set_body_typed([](BaseFunc func, String key) -> BaseFunc { + if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + return func; + } + }); + } // namespace tvm diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index dd8d3c2ed3f3..c15a290bf03d 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -30,6 +30,22 @@ namespace tir { void PrimFuncFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); + // if the prim func is not private and there isn't already a global symbol, + // add a global symbol + if (!is_private && name.defined()) { + if (!attrs.defined()) { + attrs = {{tvm::attr::kGlobalSymbol, name.value()}}; + } else if (!attrs.value().count(tvm::attr::kGlobalSymbol)) { + // copy over attributes (can't mutate the dict inside the optional in-place) + Map new_attrs; + for (auto kv : attrs.value()) { + new_attrs.Set(kv.first, kv.second); + } + new_attrs.Set(tvm::attr::kGlobalSymbol, name.value()); + attrs = std::move(new_attrs); + } + } + tvm::tir::PrimFunc func( /*params=*/args, /*body=*/AsStmt(stmts), diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 154a1ab3b01b..24dfa425ddcb 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -54,9 +54,10 @@ Buffer BufferDecl(Array shape, DataType dtype, String buffer_name, Opt axis_separators.value_or(Array())); } -PrimFuncFrame PrimFunc() { +PrimFuncFrame PrimFunc(bool is_private) { ObjectPtr n = make_object(); n->name = NullOpt; + n->is_private = is_private; n->args.clear(); n->ret_type = NullOpt; n->buffer_map.clear(); @@ -96,6 +97,10 @@ void FuncAttrs(Map attrs) { if (frame->attrs.defined()) { LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is " << frame->attrs; } + if (attrs.count(tvm::attr::kGlobalSymbol) && frame->is_private) { + LOG(FATAL) << "ValueError: Specifying the global symbol even though the PrimFunc is annotated " + "as private"; + } frame->attrs = attrs; } diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index 19cc67cb14f1..3342405ba28c 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -68,6 +68,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc { With f(d, func); (*f)->AddDispatchToken(d, "tir"); + auto func_name = IdDoc(FindFunctionName(d, func).value_or("main")); d->SetCommonPrefix(func, [](const ObjectRef& obj) { return obj->IsInstance() || obj->IsInstance(); }); @@ -104,9 +105,25 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 2. Handle `func->attrs` if (func->attrs.defined() && !func->attrs->dict.empty()) { - (*f)->stmts.push_back( - ExprStmtDoc(TIR(d, "func_attr") // - ->Call({d->AsDoc(func->attrs, p->Attr("attrs"))}))); + // for global symbol, don't display it if it matches the func name + if (func->attrs->dict.count(tvm::attr::kGlobalSymbol) && + Downcast(func->attrs->dict.at(tvm::attr::kGlobalSymbol)) == func_name->name) { + Map new_attrs; + for (auto kv : func->attrs->dict) { + if (kv.first != tvm::attr::kGlobalSymbol) { + new_attrs.Set(kv.first, kv.second); + } + } + if (!new_attrs.empty()) { + (*f)->stmts.push_back(ExprStmtDoc( + TIR(d, "func_attr") // + ->Call({d->AsDoc(DictAttrs(new_attrs), p->Attr("attrs"))}))); + } + } else { + (*f)->stmts.push_back( + ExprStmtDoc(TIR(d, "func_attr") // + ->Call({d->AsDoc(func->attrs, p->Attr("attrs"))}))); + } } // Step 3. Handle `func->buffer_map` for (int i = 0; i < n_args; ++i) { @@ -168,10 +185,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ret_type = d->AsDoc(func->ret_type, p->Attr("ret_type")); } } + // Step 5. Determine if we need to display the private annotation in the decorator + ExprDoc decorator = TIR(d, "prim_func"); + // mark private if there is no global symbol + if (!func->attrs.defined() || !func->attrs->dict.count(tvm::attr::kGlobalSymbol)) { + Array pos_args; + decorator = std::move(decorator->Call(pos_args, {"private"}, + {LiteralDoc::Boolean(true, Optional())})); + } + return HeaderWrapper(d, FunctionDoc( - /*name=*/IdDoc(FindFunctionName(d, func).value_or("main")), + /*name=*/func_name, /*args=*/args, - /*decorators=*/{TIR(d, "prim_func")}, + /*decorators=*/{decorator}, /*return_type=*/ret_type, /*body=*/(*f)->stmts)); }); diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index 45d060567c68..0426c0cb1e7d 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -241,7 +241,13 @@ PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams( // We dont need attrs of PrimFunc that might include non printable attrs such as target // for unit tests where emit_tvmscript_printable_ is to be used. if (emit_tvmscript_printable_) { - original_attrs = DictAttrs(); + // keep global symbol if it's there because it determines if the private attribute is printed + if (original_attrs->dict.count(tvm::attr::kGlobalSymbol)) { + original_attrs = DictAttrs( + {{tvm::attr::kGlobalSymbol, original_attrs->dict.at(tvm::attr::kGlobalSymbol)}}); + } else { + original_attrs = DictAttrs(); + } } PrimFunc ret = PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, original_attrs); @@ -464,8 +470,12 @@ IRModule PoolAllocationToOffsetConverter::operator()() { PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, main_func->attrs); main_func = WithAttr(main_func, tvm::attr::kPoolArgs, si.allocated_pool_params); } else { - main_func = - PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, DictAttrs()); + auto new_attrs = DictAttrs(); + if (main_func->attrs->dict.count(tvm::attr::kGlobalSymbol)) { + new_attrs = DictAttrs( + {{tvm::attr::kGlobalSymbol, main_func->attrs->dict.at(tvm::attr::kGlobalSymbol)}}); + } + main_func = PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, new_attrs); } module_->Update(gv, main_func); if (!emit_tvmscript_printable_) { diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py index 937ece1f8243..c7f42103cab3 100644 --- a/tests/python/relay/test_pass_plan_devices.py +++ b/tests/python/relay/test_pass_plan_devices.py @@ -1650,7 +1650,7 @@ def test_lowered(): of device_copies to mediate any scope changes. """ - @T.prim_func + @T.prim_func(private=True) def input_gem(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: A = T.match_buffer(a, [128, 128], scope="scopeA") # will flow out B = T.match_buffer(b, [128, 128], scope="") # will flow in @@ -1664,7 +1664,7 @@ def input_gem(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: D[vi, vj] = C[vi, vj] D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] - @T.prim_func + @T.prim_func(private=True) def expected_gem(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: A = T.match_buffer(a, [128, 128], scope="scopeA") B = T.match_buffer(b, [128, 128], scope="scopeB") # flowed in diff --git a/tests/python/tir/test_debug_info.py b/tests/python/tir/test_debug_info.py index d333b43b28f5..16a589a2ddcc 100644 --- a/tests/python/tir/test_debug_info.py +++ b/tests/python/tir/test_debug_info.py @@ -48,7 +48,6 @@ def main(a: T.handle, b: T.handle): # We exchange data between function by handles, which are similar to pointer. T.func_attr( { - "global_symbol": "main", "tir.noalias": True, "target": T.target("llvm"), } @@ -165,7 +164,7 @@ def test_llvm_ir_debug_accuracy(): # Check that it matches the expected line number (in main.tir) debug_line_no = int(locations[directive_idx]) - assert debug_line_no == 42 + assert debug_line_no == 43 if __name__ == "__main__": diff --git a/tests/python/unittest/test_evaluator_with_preproc.py b/tests/python/unittest/test_evaluator_with_preproc.py index fc6eec25b8da..d9ef63054030 100644 --- a/tests/python/unittest/test_evaluator_with_preproc.py +++ b/tests/python/unittest/test_evaluator_with_preproc.py @@ -39,7 +39,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: @tvm.testing.requires_cuda @pytest.mark.parametrize("f_preproc", ["", "l2_cache_flush_cuda"]) def test_time_evalutor_with_preproc(f_preproc: str): - mod = tvm.IRModule.from_expr(matmul) + mod = tvm.IRModule.from_expr(matmul.with_attr("global_symbol", "main")) sch = tvm.tir.Schedule(mod) blk = sch.get_block("matmul") i, j, k = sch.get_loops(blk) diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py index 94d76a76922c..e2305de2afaf 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py @@ -20,6 +20,7 @@ from tvm import meta_schedule as ms from tvm.script import tir as T from tvm.target import Target +from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol def _target() -> Target: @@ -201,7 +202,7 @@ def test_layout_rewrite(): sch = tvm.tir.Schedule(tir_matmul, debug_mask="all") sch.enter_postproc() assert ctx.space_generator.postprocs[0].apply(sch) - tvm.ir.assert_structural_equal(sch.mod["main"], rewritten_tir_matmul) + assert_structural_equal_ignore_global_symbol(sch.mod["main"], rewritten_tir_matmul) # fmt: off diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py index 932a5d156c28..9170e1310226 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py @@ -20,6 +20,7 @@ from tvm.meta_schedule.postproc import RewriteParallelVectorizeUnroll from tvm.script import tir as T from tvm.tir.schedule import Schedule +from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant # fmt: off @@ -195,14 +196,14 @@ def test_vectorize_inner_loop(): sch = Schedule(before_matmul_vectorize) rule = RewriteParallelVectorizeUnroll() assert rule.apply(sch) - tvm.ir.assert_structural_equal(sch.mod["main"], after_matmul_vectorize) + assert_structural_equal_ignore_global_symbol(sch.mod["main"], after_matmul_vectorize) def test_parallel_vectorize_add(): sch = Schedule(before_postproc_add) rule = RewriteParallelVectorizeUnroll() assert rule.apply(sch) - tvm.ir.assert_structural_equal(sch.mod["main"], after_postproc_add) + assert_structural_equal_ignore_global_symbol(sch.mod["main"], after_postproc_add) def test_no_unroll_for_spatial_block(): @@ -264,7 +265,7 @@ def expected(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "flo sch = Schedule(layer_norm) assert postproc.apply(sch) mod = tvm.tir.transform.Simplify()(sch.mod) - tvm.ir.assert_structural_equal(mod["main"], expected) + assert_structural_equal_ignore_global_symbol(mod["main"], expected) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index e71c975f3590..7dc164496501 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -23,10 +23,10 @@ def _check(original, transformed): - mod = tvm.IRModule.from_expr(original) + mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main")) mod = tvm.tir.transform.LowerMatchBuffer()(mod) mod = tvm.tir.transform.Simplify()(mod) - tvm.ir.assert_structural_equal(mod["main"], transformed) + tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main")) def _check_fail(original): diff --git a/tests/python/unittest/test_tir_reorder_block_iter_var.py b/tests/python/unittest/test_tir_reorder_block_iter_var.py index 99e07aa525f9..fe1a832f49cb 100644 --- a/tests/python/unittest/test_tir_reorder_block_iter_var.py +++ b/tests/python/unittest/test_tir_reorder_block_iter_var.py @@ -57,7 +57,9 @@ def test_reorder_block_iter_var(): sch = tir.Schedule(matmul, debug_mask="all") C = sch.get_block("C") sch.reorder_block_iter_var(C, [2, 1, 0]) - tvm.ir.assert_structural_equal(matmul_after_reorder_block_iter_var, sch.mod["main"]) + tvm.ir.assert_structural_equal( + matmul_after_reorder_block_iter_var.with_attr("global_symbol", "matmul"), sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=matmul) diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py index d151e4b43809..631df7a82dc3 100644 --- a/tests/python/unittest/test_tir_schedule_blockize.py +++ b/tests/python/unittest/test_tir_schedule_blockize.py @@ -54,7 +54,9 @@ def after_blockize_outer( s = tir.Schedule(func, debug_mask="all") x, _ = s.get_loops(s.get_block("B")) s.blockize(x) - tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_outer) + tvm.ir.assert_structural_equal( + s.mod["main"], after_blockize_outer.with_attr("global_symbol", "single_elementwise") + ) verify_trace_roundtrip(sch=s, mod=func) @@ -77,7 +79,9 @@ def after_blockize_inner( s = tir.Schedule(func, debug_mask="all") _, y = s.get_loops(s.get_block("B")) s.blockize(y) - tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_inner) + tvm.ir.assert_structural_equal( + s.mod["main"], after_blockize_inner.with_attr("global_symbol", "single_elementwise") + ) verify_trace_roundtrip(sch=s, mod=func) @@ -139,7 +143,9 @@ def after_blockize_rca( s = tir.Schedule(func, debug_mask="all") _, _, x, _ = s.get_loops(s.get_block("C")) s.blockize(x) - tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_rca) + tvm.ir.assert_structural_equal( + s.mod["main"], after_blockize_rca.with_attr("global_symbol", "before_blockize_rca") + ) verify_trace_roundtrip(sch=s, mod=func) @@ -209,7 +215,10 @@ def after_blockize_compute_at( s = tir.Schedule(func, debug_mask="all") _, _, x, _ = s.get_loops(s.get_block("B")) s.blockize(x) - tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_compute_at) + tvm.ir.assert_structural_equal( + s.mod["main"], + after_blockize_compute_at.with_attr("global_symbol", "before_blockize_compute_at"), + ) verify_trace_roundtrip(sch=s, mod=func) @@ -244,7 +253,9 @@ def after_rowsum_blockize( s = tir.Schedule(rowsum, debug_mask="all") k, _ = s.get_loops(s.get_block("B")) s.blockize(k) - tvm.ir.assert_structural_equal(s.mod["main"], after_rowsum_blockize) + tvm.ir.assert_structural_equal( + s.mod["main"], after_rowsum_blockize.with_attr("global_symbol", "rowsum") + ) verify_trace_roundtrip(sch=s, mod=rowsum) @@ -301,7 +312,9 @@ def after_single_elementwise_int64_blockize_preserve_unit_iters( if preserve_unit_iters else after_single_elementwise_int64_blockize ) - tvm.ir.assert_structural_equal(s.mod["main"], expected) + tvm.ir.assert_structural_equal( + s.mod["main"], expected.with_attr("global_symbol", "single_elementwise_int64") + ) verify_trace_roundtrip(sch=s, mod=single_elementwise_int64) @@ -350,7 +363,9 @@ def after_blocks_blockize( blocks = [s.get_block("B"), s.get_block("C")] s.blockize(blocks, preserve_unit_iters=False) expected = after_blocks_blockize - tvm.ir.assert_structural_equal(s.mod["main"], expected) + tvm.ir.assert_structural_equal( + s.mod["main"], expected.with_attr("global_symbol", "blocks_func") + ) verify_trace_roundtrip(sch=s, mod=blocks_func) diff --git a/tests/python/unittest/test_tir_schedule_cache_index.py b/tests/python/unittest/test_tir_schedule_cache_index.py index a509c02b37f3..5ef39958823b 100644 --- a/tests/python/unittest/test_tir_schedule_cache_index.py +++ b/tests/python/unittest/test_tir_schedule_cache_index.py @@ -456,7 +456,9 @@ def test_basic_cache_index(): sch = tvm.tir.Schedule(resize, debug_mask="all") block = sch.get_block("A") sch.cache_index(block, "global") - tvm.ir.assert_structural_equal(resize_cache_index, sch.mod["main"]) + tvm.ir.assert_structural_equal( + resize_cache_index, sch.mod["main"].with_attr("global_symbol", "resize_cache_index") + ) verify_trace_roundtrip(sch=sch, mod=resize) @@ -464,7 +466,9 @@ def test_resize_bilinear_cache_index(): sch = tvm.tir.Schedule(bilinear_resize, debug_mask="all") block = sch.get_block("resize") sch.cache_index(block, "global", 4) - tvm.ir.assert_structural_equal(sch.mod["main"], cached_bilinear_resize) + tvm.ir.assert_structural_equal( + sch.mod["main"], cached_bilinear_resize.with_attr("global_symbol", "bilinear_resize") + ) verify_trace_roundtrip(sch=sch, mod=bilinear_resize) diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index 2d460b359181..840a18ae6aea 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -22,7 +22,10 @@ import tvm.testing from tvm import tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) # pylint: disable=no-member,invalid-name,unused-variable @@ -1284,7 +1287,7 @@ def test_cache_read_elementwise(use_block_name): assert sch.get(cached_b) == sch.get(sch.get_block("B_local")) assert sch.get(block_b) == sch.get(sch.get_block("B")) assert sch.get(block_c) == sch.get(sch.get_block("C")) - tvm.ir.assert_structural_equal(cache_read_elementwise, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cache_read_elementwise, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) @@ -1294,7 +1297,7 @@ def test_cache_read_under_scope(use_block_name): block_c = "C" if use_block_name else sch.get_block("C") sch.cache_read(block_b, 0, "local") sch.cache_read(block_c, 0, "global") - tvm.ir.assert_structural_equal(cache_read_under_scope, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cache_read_under_scope, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=access_under_scope) @@ -1302,7 +1305,7 @@ def test_cache_read_opaque_access(use_block_name): sch = tir.Schedule(opaque_access, debug_mask="all") block = "load_store" if use_block_name else sch.get_block("load_store") sch.cache_read(block, 0, "global") - tvm.ir.assert_structural_equal(cache_read_opaque_access, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cache_read_opaque_access, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=opaque_access) @@ -1310,7 +1313,7 @@ def test_cache_read_location(use_block_name): sch = tir.Schedule(func_multi_consumer, debug_mask="all") block_b = "B" if use_block_name else sch.get_block("B") sch.cache_read(block_b, 0, "global") - tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cache_read_multi_consumer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) # Test that specific consumer block targeting works. @@ -1318,7 +1321,7 @@ def test_cache_read_location(use_block_name): block_b = "B" if use_block_name else sch.get_block("B") block_c = "C" if use_block_name else sch.get_block("C") sch.cache_read(block_b, 0, "global", consumer_blocks=[block_c]) - tvm.ir.assert_structural_equal(cache_read_multi_consumer_target, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cache_read_multi_consumer_target, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) # Also test setting multiple consumers yields same result as unspecified. @@ -1326,7 +1329,7 @@ def test_cache_read_location(use_block_name): block_b = "B" if use_block_name else sch.get_block("B") block_c = "C" if use_block_name else sch.get_block("C") sch.cache_read(block_b, 0, "global", consumer_blocks=[block_b, block_c]) - tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cache_read_multi_consumer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) @@ -1335,7 +1338,7 @@ def test_continuous_cache_read(use_block_name): block_c = "C" if use_block_name else sch.get_block("C") sch.cache_read(block_c, 0, "shared") sch.cache_read(block_c, 0, "local") - tvm.ir.assert_structural_equal(continuous_cache_read, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(continuous_cache_read, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) @@ -1343,7 +1346,7 @@ def test_cache_read_with_block_predicate(use_block_name): sch = tir.Schedule(func_with_block_predicate, debug_mask="all") block = "consumer" if use_block_name else sch.get_block("consumer") sch.cache_read(block, 0, "shared") - tvm.ir.assert_structural_equal(block_predicate_cache_read, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(block_predicate_cache_read, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate) @@ -1351,7 +1354,7 @@ def test_cache_read_non_int32_shape(use_block_name): sch = tir.Schedule(elementwise_shape_int64, debug_mask="all") block_b = "B" if use_block_name else sch.get_block("B") sch.cache_read(block_b, 0, "global") - tvm.ir.assert_structural_equal(cache_read_shape_int64, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cache_read_shape_int64, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_shape_int64) @@ -1380,7 +1383,7 @@ def test_inplace_cache_read(): sch = tvm.tir.Schedule(inplace_func, debug_mask="all") block = sch.get_block("copy_in") sch.cache_read(block, 0, "local", [block]) - tvm.ir.assert_structural_equal(cache_read_inplace, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cache_read_inplace, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=inplace_func) @@ -1393,7 +1396,7 @@ def test_cache_inplace(): block = sch.cache_read(blocks[0], 0, "global", [blocks[0]]) block = sch.cache_write(blocks[1], 0, "global") - tvm.ir.assert_structural_equal(cache_inplace_buffer, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cache_inplace_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=inplace_call, debug_mask=debug_mask) @@ -1401,7 +1404,7 @@ def test_cache_read_nested_seq(use_block_name): sch = tir.Schedule(func_nested_seq, debug_mask="all") block_c = "C" if use_block_name else sch.get_block("C") sch.cache_read(block_c, 0, "global", consumer_blocks=[block_c]) - tvm.ir.assert_structural_equal(cache_read_nested_seq_target, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cache_read_nested_seq_target, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_nested_seq) @@ -1418,7 +1421,7 @@ def test_cache_write_elementwise(use_block_name): assert sch.get(cached_c) == sch.get(sch.get_block("C_global")) assert sch.get(block_b) == sch.get(sch.get_block("B")) assert sch.get(block_c) == sch.get(sch.get_block("C")) - tvm.ir.assert_structural_equal(cache_write_elementwise, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cache_write_elementwise, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) @@ -1430,7 +1433,7 @@ def test_cache_write_under_scope(use_block_name): sch.cache_write(block_a, 0, "local") sch.cache_write(block_b, 0, "global") sch.cache_write(block_scope, 0, "global") - tvm.ir.assert_structural_equal(cache_write_under_scope, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cache_write_under_scope, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=access_under_scope) @@ -1442,7 +1445,7 @@ def test_cache_write_opaque_access(use_block_name): sch.cache_write(block_store, 0, "global") sch.cache_write(block_opaque, 0, "global") sch.cache_write(block_match_buffer, 0, "global") - tvm.ir.assert_structural_equal(cache_write_opaque_access, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cache_write_opaque_access, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=opaque_access) @@ -1450,7 +1453,7 @@ def test_cache_write_location(use_block_name): sch = tir.Schedule(func_multi_consumer, debug_mask="all") block_a = "A" if use_block_name else sch.get_block("A") sch.cache_write(block_a, 0, "global") - tvm.ir.assert_structural_equal(cache_write_multi_consumer, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cache_write_multi_consumer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) # Test that specific consumer block targeting works. @@ -1459,7 +1462,9 @@ def test_cache_write_location(use_block_name): block_a = "A" if use_block_name else sch.get_block("A") block_b = "B" if use_block_name else sch.get_block("B") sch.cache_write(block_a, 0, "global", consumer_blocks=[block_b]) - tvm.ir.assert_structural_equal(cache_write_multi_consumer_B_consume_cache, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + cache_write_multi_consumer_B_consume_cache, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) # Test that specific consumer block targeting works. @@ -1468,7 +1473,9 @@ def test_cache_write_location(use_block_name): block_a = "A" if use_block_name else sch.get_block("A") block_c = "C" if use_block_name else sch.get_block("C") sch.cache_write(block_a, 0, "global", consumer_blocks=[block_c]) - tvm.ir.assert_structural_equal(cache_write_multi_consumer_C_consume_cache, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + cache_write_multi_consumer_C_consume_cache, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) # Test that specific consumer block targeting works. @@ -1478,7 +1485,9 @@ def test_cache_write_location(use_block_name): block_b = "B" if use_block_name else sch.get_block("B") block_c = "C" if use_block_name else sch.get_block("C") sch.cache_write(block_a, 0, "global", consumer_blocks=[block_b, block_c]) - tvm.ir.assert_structural_equal(cache_write_multi_consumer_all_consume_cache, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + cache_write_multi_consumer_all_consume_cache, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) @@ -1487,7 +1496,7 @@ def test_continuous_cache_write(use_block_name): block_b = "B" if use_block_name else sch.get_block("B") sch.cache_write(block_b, 0, "shared") sch.cache_write(block_b, 0, "local") - tvm.ir.assert_structural_equal(continuous_cache_write, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(continuous_cache_write, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) @@ -1496,13 +1505,17 @@ def test_cache_write_with_block_predicate(use_block_name): sch = tir.Schedule(func_with_block_predicate, debug_mask="all") block = "producer" if use_block_name else sch.get_block("producer") sch.cache_write(block, 0, "shared") - tvm.ir.assert_structural_equal(block_predicate_cache_write_intermediate_buf, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + block_predicate_cache_write_intermediate_buf, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate) # cache write for external buffer sch = tir.Schedule(func_with_block_predicate, debug_mask="all") block = "consumer" if use_block_name else sch.get_block("consumer") sch.cache_write(block, 0, "shared") - tvm.ir.assert_structural_equal(block_predicate_cache_write_output_buf, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + block_predicate_cache_write_output_buf, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate) @@ -1601,21 +1614,21 @@ def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float1 after = sch.mod["main"] - tvm.ir.assert_structural_equal(expected, after) + assert_structural_equal_ignore_global_symbol(expected, after) verify_trace_roundtrip(sch=sch, mod=before) def test_reindex_cache_read(): sch = tir.Schedule(elementwise, debug_mask="all") sch.reindex_cache_read("C", 0, "shared", lambda i, j: (j, i // 2, i % 2)) - tvm.ir.assert_structural_equal(elementwise_reindex_cache_read, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_reindex_cache_read, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) def test_reindex_cache_read_multi_consumer(): sch = tir.Schedule(func_multi_consumer) sch.reindex_cache_read("B", 0, "shared", lambda i: (i // 32, i % 32)) - tvm.ir.assert_structural_equal(reindex_cache_read_multi_consumer, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(reindex_cache_read_multi_consumer, sch.mod["main"]) # NOTE(zihao): we do not verify trace roundtrip because of in set analysis issues. @@ -1639,16 +1652,16 @@ def test_reindex_cache_read_failed_not_single_point(): def test_reindex_cache_write(): sch = tir.Schedule(elementwise, debug_mask="all") sch.reindex_cache_write("B", 0, "shared", lambda i, j: (j, i)) - tvm.ir.assert_structural_equal(elementwise_reindex_cache_write, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_reindex_cache_write, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) def test_reindex_cache_write_reduce(): sch = tir.Schedule(reduce, debug_mask="all") sch.reindex_cache_write("B", 0, "shared", lambda i, j, k, l: (j, i, k)) - tvm.ir.assert_structural_equal(reduce_reindex_cache_write_0, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(reduce_reindex_cache_write_0, sch.mod["main"]) sch.reindex_cache_write("C", 0, "shared", lambda i, j, k: [j, i]) - tvm.ir.assert_structural_equal(reduce_reindex_cache_write_1, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(reduce_reindex_cache_write_1, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=reduce) @@ -1673,7 +1686,9 @@ def test_symbolic_matmul_blocked_cache_read(use_block_name): sch = tir.Schedule(symbolic_matmul_blocked, debug_mask="all") block = "matmul" if use_block_name else sch.get_block("matmul") sch.cache_read(block=block, read_buffer_index=0, storage_scope="shared") - tvm.ir.assert_structural_equal(sch.mod["main"], symbolic_matmul_blocked_cache_read) + assert_structural_equal_ignore_global_symbol( + sch.mod["main"], symbolic_matmul_blocked_cache_read + ) verify_trace_roundtrip(sch=sch, mod=symbolic_matmul_blocked) @@ -1681,7 +1696,9 @@ def test_symbolic_matmul_blocked_cache_write(use_block_name): sch = tir.Schedule(symbolic_matmul_blocked, debug_mask="all") block = "matmul" if use_block_name else sch.get_block("matmul") sch.cache_write(block=block, write_buffer_index=0, storage_scope="local") - tvm.ir.assert_structural_equal(sch.mod["main"], symbolic_matmul_blocked_cache_write) + assert_structural_equal_ignore_global_symbol( + sch.mod["main"], symbolic_matmul_blocked_cache_write + ) verify_trace_roundtrip(sch=sch, mod=symbolic_matmul_blocked) diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index 2e44776a0fdc..963d9586bcaa 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -20,7 +20,10 @@ import tvm.testing from tvm import te, tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks @@ -1102,7 +1105,7 @@ def test_compute_at_two_elementwise(use_block_name): block = "B" if use_block_name else sch.get_block("B") loop, _ = sch.get_loops("C" if use_block_name else sch.get_block("C")) sch.compute_at(block, loop, preserve_unit_loops=True) - tvm.ir.assert_structural_equal(two_elementwise_after_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(two_elementwise_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) @@ -1111,7 +1114,7 @@ def test_compute_at_blockized_1(use_block_name): block = sch.get_block("B") _, loop = sch.get_loops(sch.get_block("C_outer")) sch.compute_at(block, loop, preserve_unit_loops=True) - tvm.ir.assert_structural_equal(blockized_after_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(blockized_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=blockized_1) @@ -1120,7 +1123,7 @@ def test_compute_at_blockized_2(use_block_name): block = sch.get_block("B_outer") _, loop, _, _ = sch.get_loops(sch.get_block("C")) sch.compute_at(block, loop, preserve_unit_loops=True) - tvm.ir.assert_structural_equal(blockized_2_after_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(blockized_2_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=blockized_2) @@ -1129,7 +1132,7 @@ def test_compute_at_cuda_matmul_0(use_block_name): block = sch.get_block("C") _, _, _, _, _, loop, _, _ = sch.get_loops(sch.get_block("C_local")) sch.compute_at(block, loop, preserve_unit_loops=True) - tvm.ir.assert_structural_equal(cuda_matmul_0_after_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cuda_matmul_0_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=cuda_matmul_0) @@ -1138,7 +1141,7 @@ def test_compute_at_cuda_matmul_1(use_block_name): block = sch.get_block("A_shared_local") _, _, _, _, _, _, _, loop, _, _, _ = sch.get_loops(sch.get_block("C")) sch.compute_at(block, loop, preserve_unit_loops=True) - tvm.ir.assert_structural_equal(cuda_matmul_2, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cuda_matmul_2, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=cuda_matmul_1) @@ -1147,7 +1150,7 @@ def test_compute_at_cuda_matmul_2(use_block_name): block = sch.get_block("B_shared_local") _, _, _, _, _, _, _, loop, _, _, _ = sch.get_loops(sch.get_block("C")) sch.compute_at(block, loop, preserve_unit_loops=True) - tvm.ir.assert_structural_equal(cuda_matmul_3, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cuda_matmul_3, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=cuda_matmul_2) @@ -1156,7 +1159,7 @@ def test_compute_at_cuda_matmul_3(use_block_name): block = sch.get_block("A_shared") _, _, _, _, _, _, loop, _, _, _, _ = sch.get_loops(sch.get_block("C")) sch.compute_at(block, loop, preserve_unit_loops=True) - tvm.ir.assert_structural_equal(cuda_matmul_4, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cuda_matmul_4, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=cuda_matmul_3) @@ -1165,7 +1168,7 @@ def test_compute_at_cuda_matmul_4(use_block_name): block = sch.get_block("B_shared") _, _, _, _, _, _, loop, _, _, _, _ = sch.get_loops(sch.get_block("C")) sch.compute_at(block, loop, preserve_unit_loops=True) - tvm.ir.assert_structural_equal(cuda_matmul_5, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(cuda_matmul_5, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=cuda_matmul_4) @@ -1174,7 +1177,7 @@ def test_compute_at_reduction_block(use_block_name): block = sch.get_block("B") (loop,) = sch.get_loops(sch.get_block("C")) sch.compute_at(block, loop, preserve_unit_loops=False) - tvm.ir.assert_structural_equal(multi_reduction_after_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(multi_reduction_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=multi_reduction) @@ -1184,7 +1187,9 @@ def test_compute_at_tiled_pooling_read_cache(use_block_name): _, w_o, _, _, _, _ = sch.get_loops(compute) cache = sch.get_block("cache") sch.compute_at(cache, w_o) - tvm.ir.assert_structural_equal(tiled_pooling_read_cache_after_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + tiled_pooling_read_cache_after_compute_at, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=tiled_pooling_read_cache) @@ -1192,7 +1197,9 @@ def test_compute_at_non_uniform_tiled_conv(use_block_name): sch = tir.Schedule(non_uniform_tiled_conv, debug_mask="all") compute = sch.get_block("compute") sch.compute_at(sch.get_block("cache"), sch.get_loops(compute)[1]) - tvm.ir.assert_structural_equal(non_uniform_tiled_conv_after_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + non_uniform_tiled_conv_after_compute_at, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=non_uniform_tiled_conv) @@ -1204,7 +1211,9 @@ def test_compute_at_concat(use_block_name): axis = sch.get_loops(concat)[0] sch.compute_at(add1, axis) sch.compute_at(add2, axis) - tvm.ir.assert_structural_equal(concat_two_elemwise_after_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + concat_two_elemwise_after_compute_at, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=concat_two_elemwise) @@ -1212,7 +1221,7 @@ def test_compute_at_tiled_repeat_op(use_block_name): sch = tir.Schedule(tiled_repeat_op, debug_mask="all") outer_ax, _ = sch.get_loops(sch.get_block("T_repeat")) sch.compute_at(sch.get_block("T_add"), outer_ax) - tvm.ir.assert_structural_equal(tiled_repeat_op_after_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(tiled_repeat_op_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=tiled_repeat_op) @@ -1246,7 +1255,7 @@ def after(X: T.Buffer[(10, 10), "float32"], Z: T.Buffer[(10, 10), "float32"]): sch = tir.Schedule(before, debug_mask="all") axis = sch.get_loops(sch.get_block("b1"))[0] sch.compute_at(sch.get_block("b0"), axis) - tvm.ir.assert_structural_equal(after, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=before) @@ -1255,7 +1264,7 @@ def test_reverse_compute_at_tiled(use_block_name): block = sch.get_block("C") _, _, loop, _ = sch.get_loops(sch.get_block("B")) sch.reverse_compute_at(block, loop, preserve_unit_loops=False) - tvm.ir.assert_structural_equal(tiled_after_reverse_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(tiled_after_reverse_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=tiled) @@ -1264,7 +1273,9 @@ def test_reverse_compute_at_tiled_trivial_binding(use_block_name): block = sch.get_block("C") _, _, loop, _ = sch.get_loops(sch.get_block("B")) sch.reverse_compute_at(block, loop, preserve_unit_loops=False) - tvm.ir.assert_structural_equal(tiled_trivial_binding_after_reverse_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + tiled_trivial_binding_after_reverse_compute_at, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=tiled_trivial_binding) @@ -1273,7 +1284,9 @@ def test_reverse_compute_at_blockized_2(use_block_name): block = sch.get_block("C") _, loop = sch.get_loops(sch.get_block("B_outer")) sch.reverse_compute_at(block, loop, preserve_unit_loops=True) - tvm.ir.assert_structural_equal(blockized_2_after_reverse_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + blockized_2_after_reverse_compute_at, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=blockized_2) @@ -1282,7 +1295,9 @@ def test_reverse_compute_at_factorized(use_block_name): block = sch.get_block("B") _, loop, _, _ = sch.get_loops(sch.get_block("B_rf")) sch.reverse_compute_at(block, loop, preserve_unit_loops=False) - tvm.ir.assert_structural_equal(factorized_after_reverse_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + factorized_after_reverse_compute_at, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=factorized) @@ -1291,7 +1306,7 @@ def test_reverse_compute_at_floordiv_and_floormod_indices(use_block_name): A = sch.get_block("A") B = sch.get_block("B") sch.reverse_compute_at(B, sch.get_loops(A)[0]) - tvm.ir.assert_structural_equal( + assert_structural_equal_ignore_global_symbol( floordiv_and_floormod_indices_after_reverse_compute_at, sch.mod["main"] ) verify_trace_roundtrip(sch=sch, mod=floordiv_and_floormod_indices) @@ -1301,7 +1316,7 @@ def test_reverse_compute_at_floordiv_and_floormod_recursive(use_block_name): sch = tir.Schedule(recursive_floordiv_floormod, debug_mask="all") write_block = sch.get_block("Out") sch.reverse_compute_at(write_block, sch.get_loops("In")[2]) - tvm.ir.assert_structural_equal( + assert_structural_equal_ignore_global_symbol( recursive_floordiv_floormod_after_reverse_compute_at, sch.mod["main"] ) verify_trace_roundtrip(sch=sch, mod=recursive_floordiv_floormod) @@ -1312,7 +1327,9 @@ def test_read_out_of_bound(use_block_name): block = sch.get_block("B") (loop,) = sch.get_loops(sch.get_block("C")) sch.compute_at(block, loop) - tvm.ir.assert_structural_equal(read_out_of_bound_after_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + read_out_of_bound_after_compute_at, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=read_out_of_bound) @@ -1321,7 +1338,9 @@ def test_compact_dataflow(use_block_name): block = sch.get_block("B") _, loop = sch.get_loops(sch.get_block("C_1")) sch.compute_at(block, loop) - tvm.ir.assert_structural_equal(not_all_compact_data_flow_after_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + not_all_compact_data_flow_after_compute_at, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=not_all_compact_data_flow) @@ -1330,7 +1349,7 @@ def test_compute_at_simplify_static_bound(use_block_name): block = sch.get_block("B") loop, _ = sch.get_loops(sch.get_block("C")) sch.compute_at(block, loop, preserve_unit_loops=True) - tvm.ir.assert_structural_equal(static_bound_after_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(static_bound_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=static_bound) @@ -1410,7 +1429,9 @@ def grouped_channel_bias_non_perfect_tiled( sch = tir.Schedule(grouped_channel_bias, debug_mask="all") loop = sch.get_loops(sch.get_block("compute"))[0] sch.compute_at(sch.get_block("init"), loop) - tvm.ir.assert_structural_equal(sch.mod["main"], grouped_channel_bias_non_perfect_tiled) + assert_structural_equal_ignore_global_symbol( + sch.mod["main"], grouped_channel_bias_non_perfect_tiled + ) def test_fail_subtree_complete_block(use_block_name): @@ -1566,7 +1587,7 @@ def multi_producers_after_compute_at( block_c = sch.get_block("pad") axis = sch.get_loops("conv")[0] sch.compute_at(block_c, axis, index=-2) - tvm.ir.assert_structural_equal(multi_producers_after_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(multi_producers_after_compute_at, sch.mod["main"]) def test_reverse_compute_at_to_index(): @@ -1629,7 +1650,7 @@ def main_reverse_compute_at( block_c = sch.get_block("D") axis = sch.get_loops("B")[2] sch.reverse_compute_at(block_c, axis, index=1) - tvm.ir.assert_structural_equal(main_reverse_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(main_reverse_compute_at, sch.mod["main"]) def test_reverse_compute_at_with_unit_loop(): @@ -1681,7 +1702,7 @@ def main_reverse_compute_at( block_d = sch.get_block("D") axis = sch.get_loops("B")[2] sch.reverse_compute_at(block_d, axis, preserve_unit_loops=True, index=1) - tvm.ir.assert_structural_equal(main_reverse_compute_at, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(main_reverse_compute_at, sch.mod["main"]) def test_reverse_compute_at_layout_trans(): @@ -1720,7 +1741,7 @@ def after(A: T.Buffer((1, 3, 5, 5, 16), "float32"), C: T.Buffer((1, 6, 5, 5, 8), trans = sch.get_block("T_layout_trans") axis = sch.get_loops("compute")[1] sch.reverse_compute_at(trans, axis) - tvm.ir.assert_structural_equal(after, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=before) @@ -1777,7 +1798,7 @@ def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")) after = sch.mod["main"] - tvm.ir.assert_structural_equal(expected, after) + assert_structural_equal_ignore_global_symbol(expected, after) verify_trace_roundtrip(sch=sch, mod=before) @@ -1819,7 +1840,7 @@ def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")) sch.compute_inline(block) after = sch.mod["main"] - tvm.ir.assert_structural_equal(expected, after) + assert_structural_equal_ignore_global_symbol(expected, after) verify_trace_roundtrip(sch=sch, mod=before) @@ -1885,11 +1906,13 @@ def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c: T.handle): C[v0, 0, v1] = T.float32(0) C[v0, 0, v1] = C[v0, 0, v1] + C_rf[vax2_fused_1, v0, 0, v1] # fmt: on - sch = tir.Schedule(before, debug_mask="all") + sch = tir.Schedule(before.with_attr("global_symbol", "main"), debug_mask="all") block = sch.get_block("NT_matmul") loop, _, _ = sch.get_loops(sch.get_block("NT_matmul_rf")) sch.reverse_compute_at(block, loop, preserve_unit_loops=True) - tvm.ir.assert_structural_equal(sch.mod["main"], expected, True) + tvm.ir.assert_structural_equal( + sch.mod["main"], expected.with_attr("global_symbol", "main"), True + ) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index 8d90189507d7..b84c80b0a1df 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -21,7 +21,10 @@ import tvm.tir.tensor_intrin from tvm import tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) # pylint: disable=no-member,invalid-name,unused-variable @@ -969,7 +972,7 @@ def test_compute_inline_elementwise(use_block_name): block_b = "B" if use_block_name else sch.get_block("B") block_c = sch.get_block("C") sch.compute_inline(block_b) - tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_inlined, sch.mod["main"]) assert sch.get(block_c).name_hint == "C" verify_trace_roundtrip(sch=sch, mod=elementwise) @@ -979,7 +982,7 @@ def test_compute_inline_under_loop(use_block_name): block_b = "B" if use_block_name else sch.get_block("B") block_c = sch.get_block("C") sch.compute_inline(block_b) - tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_inlined, sch.mod["main"]) assert sch.get(block_c).name_hint == "C" verify_trace_roundtrip(sch=sch, mod=elementwise_under_loop) @@ -989,7 +992,7 @@ def test_compute_inline_as_dce(use_block_name): block_b = "B" if use_block_name else sch.get_block("B") block_c = sch.get_block("C") sch.compute_inline(block_b) - tvm.ir.assert_structural_equal(elementwise_standalone_dce, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_standalone_dce, sch.mod["main"]) assert sch.get(block_c).name_hint == "C" verify_trace_roundtrip(sch=sch, mod=elementwise_standalone) @@ -1000,7 +1003,9 @@ def test_compute_inline_multi_consumer(use_block_name): block_c = sch.get_block("C") block_d = sch.get_block("D") sch.compute_inline(block_b) - tvm.ir.assert_structural_equal(elementwise_multi_consumer_inlined, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + elementwise_multi_consumer_inlined, sch.mod["main"] + ) assert sch.get(block_c).name_hint == "C" assert sch.get(block_d).name_hint == "D" verify_trace_roundtrip(sch=sch, mod=elementwise_multi_producer_consumer) @@ -1018,7 +1023,7 @@ def test_reverse_compute_inline_elementwise(use_block_name): block_b = sch.get_block("B") block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) - tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_inlined, sch.mod["main"]) assert sch.get(block_b).name_hint == "B" verify_trace_roundtrip(sch=sch, mod=elementwise) @@ -1028,7 +1033,7 @@ def test_reverse_compute_inline_under_loop(use_block_name): block_b = sch.get_block("B") block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) - tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_inlined, sch.mod["main"]) assert sch.get(block_b).name_hint == "B" verify_trace_roundtrip(sch=sch, mod=elementwise_under_loop) @@ -1058,7 +1063,9 @@ def test_reverse_compute_multi_reverse_loads(use_block_name): sch = tir.Schedule(elementwise_multi_reverse_loads, debug_mask="all") block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) - tvm.ir.assert_structural_equal(elementwise_multi_reverse_loads_inlined, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + elementwise_multi_reverse_loads_inlined, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_loads) @@ -1066,7 +1073,9 @@ def test_reverse_compute_inline_affine_load(use_block_name): sch = tir.Schedule(elementwise_reverse_affine_load, debug_mask="all") block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) - tvm.ir.assert_structural_equal(elementwise_reverse_affine_load_inlined, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + elementwise_reverse_affine_load_inlined, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load) @@ -1074,7 +1083,9 @@ def test_reverse_compute_inline_multi_affine_load(use_block_name): sch = tir.Schedule(elementwise_multi_reverse_affine_load, debug_mask="all") block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) - tvm.ir.assert_structural_equal(elementwise_multi_reverse_affine_load_inlined, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + elementwise_multi_reverse_affine_load_inlined, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_affine_load) @@ -1082,7 +1093,7 @@ def test_reverse_compute_inline_affine_load_unit_iter(use_block_name): sch = tir.Schedule(elementwise_reverse_affine_load_unit_iter, debug_mask="all") block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) - tvm.ir.assert_structural_equal( + assert_structural_equal_ignore_global_symbol( elementwise_reverse_affine_load_unit_iter_inlined, sch.mod["main"] ) verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load_unit_iter) @@ -1092,7 +1103,7 @@ def test_reverse_compute_inline_affine_load_unit_iter_simplified(use_block_name) sch = tir.Schedule(elementwise_reverse_affine_load_unit_iter_simplified, debug_mask="all") block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) - tvm.ir.assert_structural_equal( + assert_structural_equal_ignore_global_symbol( elementwise_reverse_affine_load_unit_iter_simplified_inlined, sch.mod["main"] ) verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load_unit_iter_simplified) @@ -1109,7 +1120,9 @@ def test_reverse_compute_inline_affine_chain(use_block_name, reverse_order): else: sch.reverse_compute_inline(block_c) sch.reverse_compute_inline(block_d) - tvm.ir.assert_structural_equal(elementwise_reverse_affine_chain_inlined, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + elementwise_reverse_affine_chain_inlined, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_chain) @@ -1168,7 +1181,7 @@ def test_compute_inline_predicate(use_block_name): sch = tir.Schedule(elementwise_predicate, debug_mask="all") block_b = "B" if use_block_name else sch.get_block("B") sch.compute_inline(block_b) - tvm.ir.assert_structural_equal(elementwise_predicate_inlined, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_predicate_inlined, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_predicate) @@ -1176,7 +1189,7 @@ def test_compute_inline_multi_loads(use_block_name): sch = tir.Schedule(elementwise_multi_loads, debug_mask="all") block_b = "B" if use_block_name else sch.get_block("B") sch.compute_inline(block_b) - tvm.ir.assert_structural_equal(elementwise_multi_loads_inlined, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_multi_loads_inlined, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_multi_loads) @@ -1185,7 +1198,9 @@ def test_compute_inline_with_opaque_access(use_block_name): sch = tir.Schedule(access_opaque_ptr_then_elemwise, debug_mask="all") BB = "BB" if use_block_name else sch.get_block("BB") sch.compute_inline(BB) - tvm.ir.assert_structural_equal(access_opaque_ptr_then_elemwise_inline, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + access_opaque_ptr_then_elemwise_inline, sch.mod["main"] + ) def test_inline_block_with_init(): @@ -1200,7 +1215,7 @@ def test_compute_inline_opaque_access_with_tvm_access_ptr(use_block_name): sch = tir.Schedule(exp_exp_opaque_access_with_tvm_access_ptr, debug_mask="all") compute = "compute" if use_block_name else sch.get_block("compute") sch.compute_inline(compute) - tvm.ir.assert_structural_equal( + assert_structural_equal_ignore_global_symbol( exp_exp_opaque_access_with_tvm_access_ptr_inlined, sch.mod["main"] ) @@ -1210,7 +1225,7 @@ def test_reverse_compute_inline_overcomputed_producer(use_block_name): sch = tir.Schedule(elementwise_overcomputed_producer, debug_mask="all") compute = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(compute) - tvm.ir.assert_structural_equal( + assert_structural_equal_ignore_global_symbol( elementwise_overcomputed_producer_reverse_inlined, sch.mod["main"] ) @@ -1220,7 +1235,7 @@ def test_reverse_compute_inline_overcomputed_producer_simplify_predicate(use_blo sch = tir.Schedule(elementwise_overcomputed_producer_simplify_predicate, debug_mask="all") compute = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(compute) - tvm.ir.assert_structural_equal( + assert_structural_equal_ignore_global_symbol( elementwise_overcomputed_producer_simplify_predicate_reverse_inlined, sch.mod["main"] ) @@ -1230,7 +1245,7 @@ def test_reverse_compute_inline_overcomputed_producer_injective_load(use_block_n sch = tir.Schedule(elementwise_overcomputed_producer_injective_load, debug_mask="all") compute = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(compute) - tvm.ir.assert_structural_equal( + assert_structural_equal_ignore_global_symbol( elementwise_overcomputed_producer_injective_load_reverse_inlined, sch.mod["main"] ) @@ -1252,7 +1267,9 @@ def test_reverse_compute_inline_producer_predicate_allowed(): sch = tir.Schedule(elementwise_predicate_producer, debug_mask="all") sch.reverse_compute_inline(sch.get_block("C")) - tvm.ir.assert_structural_equal(elementwise_predicate_producer_inlined, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + elementwise_predicate_producer_inlined, sch.mod["main"] + ) def test_reverse_compute_inline_producer_predicate_disallowed(): @@ -1262,7 +1279,7 @@ def test_reverse_compute_inline_producer_predicate_disallowed(): sch = tir.Schedule(Conv2dInt8_TensorCore_with_predicate_before, debug_mask="all") sch.reverse_compute_inline(sch.get_block("compute_4")) - tvm.ir.assert_structural_equal( + assert_structural_equal_ignore_global_symbol( Conv2dInt8_TensorCore_with_predicate_after["main"], sch.mod["main"] ) @@ -1358,7 +1375,7 @@ def after(p_lv44: T.handle, p_output0: T.handle): sch = tir.Schedule(before) sch.compute_inline(sch.get_block("T_softmax_exp")) - tvm.ir.assert_structural_equal(after, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) def test_reverse_compute_inline_layer_norm(): @@ -1442,7 +1459,7 @@ def after(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: sch = tir.Schedule(before) sch.reverse_compute_inline(sch.get_block("compute")) - tvm.ir.assert_structural_equal(after, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_schedule_decompose_padding.py b/tests/python/unittest/test_tir_schedule_decompose_padding.py index 15ed194328a5..e8ba0b4e21e0 100644 --- a/tests/python/unittest/test_tir_schedule_decompose_padding.py +++ b/tests/python/unittest/test_tir_schedule_decompose_padding.py @@ -19,13 +19,14 @@ import tvm import tvm.testing from tvm import tir +from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol from tvm.script import tir as T # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg def check_decompose_padding(origin, scheduled, expected, check_run=False): - tvm.ir.assert_structural_equal(scheduled, expected) + assert_structural_equal_ignore_global_symbol(scheduled, expected) if check_run: in_buffer = origin.buffer_map[origin.params[0]] out_buffer = origin.buffer_map[origin.params[1]] diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py index 8994f9de0ed4..fb0939f99086 100644 --- a/tests/python/unittest/test_tir_schedule_for_kind.py +++ b/tests/python/unittest/test_tir_schedule_for_kind.py @@ -22,7 +22,10 @@ import tvm.testing from tvm import tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) # pylint: disable=no-member,invalid-name,unused-variable @@ -477,7 +480,7 @@ def test_parallel(): s = tir.Schedule(element_wise, debug_mask="all") i, _ = s.get_loops(s.get_block("B")) s.parallel(i) - tvm.ir.assert_structural_equal(s.mod["main"], element_wise_parallelized) + assert_structural_equal_ignore_global_symbol(s.mod["main"], element_wise_parallelized) verify_trace_roundtrip(s, mod=element_wise) @@ -485,7 +488,9 @@ def test_parallel_predicate(): s = tir.Schedule(element_wise_split_predicate, debug_mask="all") _, j, _ = s.get_loops(s.get_block("B")) s.parallel(j) - tvm.ir.assert_structural_equal(s.mod["main"], element_wise_split_predicate_parallelized) + assert_structural_equal_ignore_global_symbol( + s.mod["main"], element_wise_split_predicate_parallelized + ) verify_trace_roundtrip(s, mod=element_wise_split_predicate) @@ -514,7 +519,9 @@ def test_vectorize(): s = tir.Schedule(element_wise_compute_at_split, debug_mask="all") _, _, j1i = s.get_loops(s.get_block("C")) s.vectorize(j1i) - tvm.ir.assert_structural_equal(s.mod["main"], element_wise_compute_at_split_vectorized) + assert_structural_equal_ignore_global_symbol( + s.mod["main"], element_wise_compute_at_split_vectorized + ) verify_trace_roundtrip(s, mod=element_wise_compute_at_split) @@ -522,7 +529,9 @@ def test_vectorize_predicate(): s = tir.Schedule(element_wise_split_predicate, debug_mask="all") i, _, _ = s.get_loops(s.get_block("B")) s.vectorize(i) - tvm.ir.assert_structural_equal(s.mod["main"], element_wise_split_predicate_vectorized) + assert_structural_equal_ignore_global_symbol( + s.mod["main"], element_wise_split_predicate_vectorized + ) verify_trace_roundtrip(s, mod=element_wise_split_predicate) @@ -537,7 +546,7 @@ def test_unroll(): s = tir.Schedule(rowsum, debug_mask="all") i, _ = s.get_loops(s.get_block("B")) s.unroll(i) - tvm.ir.assert_structural_equal(s.mod["main"], rowsum_unrolled) + assert_structural_equal_ignore_global_symbol(s.mod["main"], rowsum_unrolled) verify_trace_roundtrip(s, mod=rowsum) @@ -546,7 +555,7 @@ def test_unroll_after_bind(): i, _ = s.get_loops(s.get_block("B")) s.bind(i, "blockIdx.x") s.unroll(i) - tvm.ir.assert_structural_equal(s.mod["main"], rowsum_unrolled) + assert_structural_equal_ignore_global_symbol(s.mod["main"], rowsum_unrolled) verify_trace_roundtrip(s, mod=rowsum) @@ -554,7 +563,7 @@ def test_bind1(): s = tir.Schedule(element_wise, debug_mask="all") i, _ = s.get_loops(s.get_block("B")) s.bind(i, "threadIdx.x") - tvm.ir.assert_structural_equal(s.mod["main"], element_wise_i_bound) + assert_structural_equal_ignore_global_symbol(s.mod["main"], element_wise_i_bound) verify_trace_roundtrip(s, mod=element_wise) @@ -564,7 +573,9 @@ def test_bind2(): _, j1o, _ = s.get_loops(s.get_block("C")) s.bind(j0, "threadIdx.x") s.bind(j1o, "threadIdx.x") - tvm.ir.assert_structural_equal(s.mod["main"], element_wise_compute_at_split_j0_j1o_bound) + assert_structural_equal_ignore_global_symbol( + s.mod["main"], element_wise_compute_at_split_j0_j1o_bound + ) verify_trace_roundtrip(s, mod=element_wise_compute_at_split) @@ -572,7 +583,7 @@ def test_bind_cross_thread_reduction(): s = tir.Schedule(rowsum, debug_mask="all") _, k = s.get_loops(s.get_block("B")) s.bind(k, "threadIdx.x") - tvm.ir.assert_structural_equal(s.mod["main"], rowsum_cross_thread_reduction) + assert_structural_equal_ignore_global_symbol(s.mod["main"], rowsum_cross_thread_reduction) verify_trace_roundtrip(s, mod=rowsum) @@ -588,7 +599,7 @@ def test_bind_after_bind(): i, _ = s.get_loops(s.get_block("B")) s.bind(i, "blockIdx.x") s.bind(i, "threadIdx.x") - tvm.ir.assert_structural_equal(s.mod["main"], element_wise_i_bound) + assert_structural_equal_ignore_global_symbol(s.mod["main"], element_wise_i_bound) verify_trace_roundtrip(s, mod=element_wise) @@ -596,7 +607,7 @@ def test_block_inside_init(): s = tir.Schedule(block_inside_init, debug_mask="all") (i,) = s.get_loops(s.get_block("outer")) s.bind(i, "threadIdx.x") - tvm.ir.assert_structural_equal(s.mod["main"], thread_bound_block_inside_init) + assert_structural_equal_ignore_global_symbol(s.mod["main"], thread_bound_block_inside_init) verify_trace_roundtrip(s, mod=block_inside_init) @@ -604,7 +615,7 @@ def test_vectorize_after_decompose(): s = tir.Schedule(decomposed_gemm, debug_mask="all") jj = s.get_loops(s.get_block("C"))[-1] s.vectorize(jj) - tvm.ir.assert_structural_equal(s.mod["main"], decomposed_gemm_after_vectorize) + assert_structural_equal_ignore_global_symbol(s.mod["main"], decomposed_gemm_after_vectorize) verify_trace_roundtrip(s, mod=decomposed_gemm) @@ -616,7 +627,7 @@ def test_nested_block_bind(): _, l = s.get_loops(block_inner) s.bind(l, "threadIdx.x") s.bind(j, "blockIdx.x") - tvm.ir.assert_structural_equal(s.mod["main"], thread_bound_nested_block) + assert_structural_equal_ignore_global_symbol(s.mod["main"], thread_bound_nested_block) verify_trace_roundtrip(s, mod=nested_block_bind) @@ -628,7 +639,9 @@ def test_nexted_block_bind_after_cache_read(): (j,) = s.get_loops(block_inner) s.bind(i, "blockIdx.x") s.bind(j, "threadIdx.x") - tvm.ir.assert_structural_equal(s.mod["main"], thread_bound_nested_block_after_cache_read) + assert_structural_equal_ignore_global_symbol( + s.mod["main"], thread_bound_nested_block_after_cache_read + ) verify_trace_roundtrip(s, mod=nested_block_bind_after_cache_read) @@ -639,7 +652,7 @@ def test_vectorize_init(): _, _, ii_0, jj_0 = s.get_loops(init_blk) _, _, k_1, ii_1, jj_1 = s.get_loops(upd_blk) s.vectorize(jj_0) - tvm.ir.assert_structural_equal(s.mod["main"], decomposed_gemm_parallelize_init) + assert_structural_equal_ignore_global_symbol(s.mod["main"], decomposed_gemm_parallelize_init) verify_trace_roundtrip(s, mod=decomposed_gemm) @@ -651,7 +664,7 @@ def test_scatter_parallelize(): (i_1,) = s.get_loops(last) s.parallel(i_0) s.parallel(i_1) - tvm.ir.assert_structural_equal(s.mod["main"], scatter_compute_parallelize) + assert_structural_equal_ignore_global_symbol(s.mod["main"], scatter_compute_parallelize) verify_trace_roundtrip(s, mod=scatter_compute) diff --git a/tests/python/unittest/test_tir_schedule_merge.py b/tests/python/unittest/test_tir_schedule_merge.py index 2c3df04d6f75..b3e72943bf6e 100644 --- a/tests/python/unittest/test_tir_schedule_merge.py +++ b/tests/python/unittest/test_tir_schedule_merge.py @@ -20,7 +20,10 @@ import tvm.testing from tvm import tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) # pylint: disable=no-member,invalid-name,unused-variable @@ -118,7 +121,7 @@ def test_merge(): i = sch.get_loops(block_c)[0] j = sch.get_loops(block_d)[0] sch.merge(i, j) - tvm.ir.assert_structural_equal(elementwise_merged, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_merged, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) @@ -129,7 +132,7 @@ def test_merge2(): i = sch.get_loops(block_c)[1] j = sch.get_loops(block_d)[1] sch.merge(i, j) - tvm.ir.assert_structural_equal(elementwise_merged2, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_merged2, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) diff --git a/tests/python/unittest/test_tir_schedule_pad_einsum.py b/tests/python/unittest/test_tir_schedule_pad_einsum.py index 0b0288e3f79c..d27a8ca985fd 100644 --- a/tests/python/unittest/test_tir_schedule_pad_einsum.py +++ b/tests/python/unittest/test_tir_schedule_pad_einsum.py @@ -20,7 +20,85 @@ import tvm.testing from tvm import tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) + +# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg + + +@T.prim_func +def matmul_before( + A: T.Buffer((128, 127), "float32"), + B: T.Buffer((127, 127), "float32"), + C: T.Buffer((128, 127), "float32"), +) -> None: + A_shared = T.alloc_buffer((128, 127), "float32", scope="shared") + B_shared = T.alloc_buffer((127, 127), "float32", scope="shared") + C_shared = T.alloc_buffer((128, 127), "float32", scope="shared") + for i0, i1 in T.grid(128, 127): + with T.block("A"): + i, j = T.axis.remap("SS", [i0, i1]) + A_shared[i, j] = A[i, j] + for i0, i1 in T.grid(127, 127): + with T.block("B"): + i, j = T.axis.remap("SS", [i0, i1]) + B_shared[i, j] = B[i, j] + for i0, i1, i2 in T.grid(128, 127, 127): + with T.block("C_shared"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + with T.init(): + C_shared[i, j] = T.float32(0) + C_shared[i, j] = C_shared[i, j] + A_shared[i, k] * B_shared[k, j] + for i0, i1 in T.grid(128, 127): + with T.block("C"): + i, j = T.axis.remap("SS", [i0, i1]) + C[i, j] = C_shared[i, j] + + +@T.prim_func +def matmul_expected( + A: T.Buffer((128, 127), "float32"), + B: T.Buffer((127, 127), "float32"), + C: T.Buffer((128, 127), "float32"), +) -> None: + A_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared") + B_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared") + C_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared") + for i0, i1 in T.grid(128, 128): + with T.block("A"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads(A[i, j]) + T.writes(A_shared_padded[i, j]) + A_shared_padded[i, j] = T.if_then_else(j < 127, A[i, j], T.float32(0), dtype="float32") + for i0, i1 in T.grid(128, 128): + with T.block("B"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads(B[i, j]) + T.writes(B_shared_padded[i, j]) + B_shared_padded[i, j] = T.if_then_else( + i < 127 and j < 127, B[i, j], T.float32(0), dtype="float32" + ) + for i0, i1, i2 in T.grid(128, 128, 128): + with T.block("C_shared"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(A_shared_padded[i, k], B_shared_padded[k, j]) + T.writes(C_shared_padded[i, j]) + with T.init(): + C_shared_padded[i, j] = T.float32(0) + C_shared_padded[i, j] = ( + C_shared_padded[i, j] + A_shared_padded[i, k] * B_shared_padded[k, j] + ) + for i0, i1 in T.grid(128, 127): + with T.block("C"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads(C_shared_padded[i, j]) + T.writes(C[i, j]) + C[i, j] = C_shared_padded[i, j] + + +# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg def test_pad_matmul(): @@ -75,7 +153,7 @@ def matmul_after( sch = tir.Schedule(matmul_before, debug_mask="all") C = sch.get_block("C") sch.pad_einsum(C, [32, 32, 32]) - tvm.ir.assert_structural_equal(matmul_after, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(matmul_after, sch.mod["main"]) verify_trace_roundtrip(sch, mod=matmul_before) @@ -145,7 +223,7 @@ def after(a: T.handle, b: T.handle, m: T.handle, d: T.handle): sch = tir.Schedule(before, debug_mask="all") C = sch.get_block("C") sch.pad_einsum(C, [1, 32, 32, 32]) - tvm.ir.assert_structural_equal(after, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) verify_trace_roundtrip(sch, mod=before) @@ -217,7 +295,7 @@ def after(a: T.handle, w: T.handle, r: T.handle): sch = tir.Schedule(before, debug_mask="all") C = sch.get_block("S") sch.pad_einsum(C, [1, 32, 1]) - tvm.ir.assert_structural_equal(after, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) verify_trace_roundtrip(sch, mod=before) diff --git a/tests/python/unittest/test_tir_schedule_read_write_at.py b/tests/python/unittest/test_tir_schedule_read_write_at.py index dd61a4d62be1..39d6b4f82272 100644 --- a/tests/python/unittest/test_tir_schedule_read_write_at.py +++ b/tests/python/unittest/test_tir_schedule_read_write_at.py @@ -21,7 +21,10 @@ import tvm from tvm import tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) # fmt: off @@ -191,7 +194,7 @@ def test_read_at_global_to_shared_a(): _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) # pylint: enable=invalid-name sch.read_at(k0, block, 1, "shared") - tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_a) + assert_structural_equal_ignore_global_symbol(sch.mod["main"], cuda_matmul_read_at_a) verify_trace_roundtrip(sch, cuda_matmul) @@ -202,7 +205,7 @@ def test_read_at_global_to_shared_ab(): _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) # pylint: enable=invalid-name sch.read_at(k0, block, 2, "shared") - tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_ab) + assert_structural_equal_ignore_global_symbol(sch.mod["main"], cuda_matmul_read_at_ab) verify_trace_roundtrip(sch, cuda_matmul_read_at_a) @@ -213,7 +216,7 @@ def test_read_at_local_to_shared_c(): _by, _bx, _vy, _vx, _ty, tx, _k0, _k1, _, _i, _j = sch.get_loops(block) # pylint: enable=invalid-name sch.write_at(tx, block, 0, "shared") - tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_write_at_c) + assert_structural_equal_ignore_global_symbol(sch.mod["main"], cuda_matmul_write_at_c) verify_trace_roundtrip(sch, cuda_matmul_read_at_ab) diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index b24ecee3762a..a1e5ed74c228 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -22,7 +22,10 @@ import tvm.testing from tvm import tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -223,7 +226,7 @@ def test_reduction_decompose0(use_block_name): C = "update" if use_block_name else s.get_block("update") i, j, k = s.get_loops(C) s.decompose_reduction(C, i) - tvm.ir.assert_structural_equal(matmul_decompose0, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(matmul_decompose0, s.mod["main"]) verify_trace_roundtrip(s, mod=matmul) @@ -232,7 +235,7 @@ def test_reduction_decompose1(use_block_name): blockized_B = "blockized_B" if use_block_name else s.get_block("blockized_B") io, ko = s.get_loops(blockized_B) s.decompose_reduction(blockized_B, io) - tvm.ir.assert_structural_equal(matmul_decompose1, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(matmul_decompose1, s.mod["main"]) verify_trace_roundtrip(s, mod=rowsum_blockized) @@ -241,7 +244,7 @@ def test_reduction_decompose2(): C = s.get_block("update") i, j, k = s.get_loops(C) s.decompose_reduction(C, k) - tvm.ir.assert_structural_equal(matmul_decompose2, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(matmul_decompose2, s.mod["main"]) verify_trace_roundtrip(s, mod=matmul) @@ -260,7 +263,7 @@ def test_reduction_decompose4(): io, ii = s.split(i, factors=[16, 8]) ko, ki = s.split(k, factors=[19, 7]) s.decompose_reduction(C, ii) - tvm.ir.assert_structural_equal(matmul_decompose4, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(matmul_decompose4, s.mod["main"]) verify_trace_roundtrip(s, mod=matmul) @@ -269,7 +272,7 @@ def test_reduction_decompose_with_annotation(): C = s.get_block("update") i, j, k = s.get_loops(C) s.decompose_reduction(C, i) - tvm.ir.assert_structural_equal(matmul_decompose_with_annotation, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(matmul_decompose_with_annotation, s.mod["main"]) verify_trace_roundtrip(s, mod=matmul_with_annotation) @@ -278,14 +281,14 @@ def test_reduction_decompose_with_different_for_kind(): B = s.get_block("B") k, _ = s.get_loops(B) B_init = s.decompose_reduction(B, k) - tvm.ir.assert_structural_equal(s.mod["main"], colsum_decompose_with_vectorization) + assert_structural_equal_ignore_global_symbol(s.mod["main"], colsum_decompose_with_vectorization) assert s.get(B).same_as(s.get(s.get_block("B_update"))) assert s.get(B_init).same_as(s.get(s.get_block("B_init"))) verify_trace_roundtrip(s, mod=colsum_with_vectorization) def test_decompose_reduction_ref_hash_check(): - mod = tvm.IRModule.from_expr(matmul) + mod = tvm.IRModule.from_expr(matmul.with_attr("global_symbol", "main")) mod_bak = mod hash_before = tvm.ir.structural_hash(mod_bak) s = tir.Schedule(mod["main"], debug_mask="all") @@ -346,7 +349,7 @@ def decomposed_nested_block(A: T.Buffer((1, 64), "float32"), B: T.Buffer((1,), " i, ko = sch.get_loops(outer) sch.decompose_reduction(outer, ko) - tvm.ir.assert_structural_equal(decomposed_nested_block, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(decomposed_nested_block, sch.mod["main"]) verify_trace_roundtrip(sch, mod=nested_block) diff --git a/tests/python/unittest/test_tir_schedule_reindex.py b/tests/python/unittest/test_tir_schedule_reindex.py index 60e3f004f59c..a410c293bcb3 100644 --- a/tests/python/unittest/test_tir_schedule_reindex.py +++ b/tests/python/unittest/test_tir_schedule_reindex.py @@ -21,7 +21,10 @@ from tvm import tir from tvm.script import tir as T from tvm.tir.schedule.schedule import ScheduleError -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + assert_structural_equal_ignore_global_symbol, + verify_trace_roundtrip, +) @T.prim_func @@ -287,7 +290,9 @@ def test_reindex_read_basic(use_block_name, use_buffer_name): block = "B" if use_block_name else sch.get_block("B") buf = "A" if use_buffer_name else ("read", 0) sch.reindex(block, buf) - tvm.ir.assert_structural_equal(transpose_elementwise_reindex_read, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + transpose_elementwise_reindex_read, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=transpose_elementwise) @@ -296,7 +301,7 @@ def test_conv2d_reindex_weight(use_block_name, use_buffer_name): block = "conv2d_nhwc" if use_block_name else sch.get_block("conv2d_nhwc") buf = "Weight" if use_buffer_name else ("read", 1) sch.reindex(block, buf) - tvm.ir.assert_structural_equal(conv2d_nhwc_reindex_weight, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(conv2d_nhwc_reindex_weight, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc) @@ -305,7 +310,7 @@ def test_conv2d_reindex_data(use_block_name, use_buffer_name): block = "conv2d_nhwc" if use_block_name else sch.get_block("conv2d_nhwc") buf = "PadInput" if use_buffer_name else ("read", 0) sch.reindex(block, buf) - tvm.ir.assert_structural_equal(conv2d_nhwc_reindex_data, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(conv2d_nhwc_reindex_data, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc) @@ -314,7 +319,7 @@ def test_matmul_reindex_write(use_block_name, use_buffer_name): block = "matmul" if use_block_name else sch.get_block("matmul") buf = "C" if use_buffer_name else ("write", 0) sch.reindex(block, buf) - tvm.ir.assert_structural_equal(matmul_reindex_write, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(matmul_reindex_write, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=matmul) @@ -331,7 +336,7 @@ def test_reindex_mixed_dtype(use_block_name, use_buffer_name): block = "T_matmul_NT" if use_block_name else sch.get_block("T_matmul_NT") buf = "T_matmul_NT" if use_buffer_name else ("write", 0) sch.reindex(block, buf) - tvm.ir.assert_structural_equal(mixed_dtype_reindex_write, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(mixed_dtype_reindex_write, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=mixed_dtype) @@ -340,7 +345,7 @@ def test_matmul_unit_dim_reindex_write(use_block_name, use_buffer_name): block = "matmul" if use_block_name else sch.get_block("matmul") buf = "C" if use_buffer_name else ("write", 0) sch.reindex(block, buf) - tvm.ir.assert_structural_equal(matmul_unit_dim_reindex_write, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(matmul_unit_dim_reindex_write, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=matmul_unit_dim) diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index 763ce8c36ef0..7ca9d35ea09d 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -22,7 +22,10 @@ import tvm.testing from tvm import tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + assert_structural_equal_ignore_global_symbol, + verify_trace_roundtrip, +) # pylint: disable=no-member,invalid-name,unused-variable @@ -189,7 +192,7 @@ def test_reorder(): block_b = sch.get_block("B") i, j, k, l = sch.get_loops(block_b) sch.reorder(l, i) - tvm.ir.assert_structural_equal(elementwise_reordered, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_reordered, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) @@ -198,7 +201,7 @@ def test_reorder2(): block_b = sch.get_block("B") i, j, k, l = sch.get_loops(block_b) sch.reorder(k, i, l) - tvm.ir.assert_structural_equal(elementwise_reordered2, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_reordered2, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) @@ -210,7 +213,7 @@ def test_reorder_with_opaque_access(): block_b = sch.get_block("B") i, j = sch.get_loops(block_b) sch.reorder(j, i) - tvm.ir.assert_structural_equal(opaque_access_reorder, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(opaque_access_reorder, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=opaque_access) @@ -236,7 +239,7 @@ def overlapped_access_reorder(A: T.Buffer((14, 4), "float32"), B: T.Buffer((14, sch = tir.Schedule(overlapped_access, debug_mask="all") v0, v1, v2 = sch.get_loops(sch.get_block("block")) sch.reorder(v0, v2, v1) - tvm.ir.assert_structural_equal(overlapped_access_reorder, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(overlapped_access_reorder, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=overlapped_access) @@ -263,7 +266,7 @@ def non_affine_func_reorder(A: T.Buffer((14, 4), "float32"), B: T.Buffer((14, 4) sch.reorder(v0, v2, v1) sch.reorder(v2, v1) - tvm.ir.assert_structural_equal(non_affine_func_reorder, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(non_affine_func_reorder, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=non_affine_func) @@ -323,7 +326,9 @@ def cascade_pool_ops_tile_reordered( sch.compute_at(pool_0, ho) _, _, _, h_i, w, _, _ = sch.get_loops(pool_0) sch.reorder(w, h_i) - tvm.ir.assert_structural_equal(cascade_pool_ops_tile_reordered, sch.mod["main"], True) + assert_structural_equal_ignore_global_symbol( + cascade_pool_ops_tile_reordered, sch.mod["main"], True + ) verify_trace_roundtrip(sch=sch, mod=cascade_pool_ops) diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py index 199e822e84b8..c1eb04b7c314 100644 --- a/tests/python/unittest/test_tir_schedule_rfactor.py +++ b/tests/python/unittest/test_tir_schedule_rfactor.py @@ -20,7 +20,10 @@ import tvm.testing from tvm import te, tir, topi from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + assert_structural_equal_ignore_global_symbol, + verify_trace_roundtrip, +) # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -1261,7 +1264,7 @@ def test_reduction_rfactor_matmul(): update = s.get_block("update") _, _, _, _, kii = s.get_loops(update) rf_block = s.rfactor(kii, 0) - tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor) + assert_structural_equal_ignore_global_symbol(s.mod["main"], matmul_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) assert s.get(update).same_as(s.get(s.get_block("update"))) verify_trace_roundtrip(s, mod=transformed_matmul) @@ -1272,7 +1275,7 @@ def test_reduction_rfactor_matmul_with_let(): update = s.get_block("update") _, _, _, _, kii = s.get_loops(update) rf_block = s.rfactor(kii, 0) - tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor) + assert_structural_equal_ignore_global_symbol(s.mod["main"], matmul_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) assert s.get(update).same_as(s.get(s.get_block("update"))) verify_trace_roundtrip(s, mod=transformed_matmul_with_let) @@ -1283,7 +1286,7 @@ def test_reduction_rfactor_square_sum(): C = s.get_block("C") _, _, j = s.get_loops(C) rf_block = s.rfactor(j, 1) - tvm.ir.assert_structural_equal(s.mod["main"], square_sum_rfactor) + assert_structural_equal_ignore_global_symbol(s.mod["main"], square_sum_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) assert s.get(C).same_as(s.get(s.get_block("C"))) verify_trace_roundtrip(s, mod=square_sum) @@ -1294,7 +1297,7 @@ def test_reduction_rfactor_square_sum_square_root(): C = s.get_block("C") _, _, f_i = s.get_loops(C) rf_block = s.rfactor(f_i, 0) - tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_rfactor) + assert_structural_equal_ignore_global_symbol(s.mod["main"], square_sum_square_root_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) assert s.get(C).same_as(s.get(s.get_block("C"))) verify_trace_roundtrip(s, mod=transformed_square_sum_square_root) @@ -1363,7 +1366,7 @@ def test_reduction_rfactor_factor_axis_range(): update = s.get_block("update") _, _, _, _, kii = s.get_loops(update) rf_block = s.rfactor(kii, -3) - tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor) + assert_structural_equal_ignore_global_symbol(s.mod["main"], matmul_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) assert s.get(update).same_as(s.get(s.get_block("update"))) verify_trace_roundtrip(s, mod=transformed_matmul) @@ -1409,7 +1412,7 @@ def test_reduction_rfactor_zero_dim(): B = s.get_block("B") (k,) = s.get_loops(B) rf_block = s.rfactor(k, 0) - tvm.ir.assert_structural_equal(s.mod["main"], rowsum_zero_dim_rfactor) + assert_structural_equal_ignore_global_symbol(s.mod["main"], rowsum_zero_dim_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("B_rf"))) assert s.get(B).same_as(s.get(s.get_block("B"))) verify_trace_roundtrip(s, mod=rowsum_zero_dim) @@ -1439,7 +1442,7 @@ def test_reduction_rfactor_outermost_loop_multiple_children(): # pylint: disabl C = s.get_block("C") _, _, k1o, _ = s.get_loops(C) rf_block = s.rfactor(k1o, 2) - tvm.ir.assert_structural_equal(s.mod["main"], multiple_reduction_blocks_rfactor) + assert_structural_equal_ignore_global_symbol(s.mod["main"], multiple_reduction_blocks_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) assert s.get(C).same_as(s.get(s.get_block("C"))) verify_trace_roundtrip(s, mod=multiple_reduction_blocks) @@ -1459,7 +1462,7 @@ def test_reduction_rfactor_with_annotation(): C = s.get_block("C") _, _, j = s.get_loops(C) rf_block = s.rfactor(j, 1) - tvm.ir.assert_structural_equal(s.mod["main"], square_sum_with_annotation_rfactor) + assert_structural_equal_ignore_global_symbol(s.mod["main"], square_sum_with_annotation_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) assert s.get(C).same_as(s.get(s.get_block("C"))) verify_trace_roundtrip(s, mod=square_sum_with_annotation) @@ -1470,7 +1473,7 @@ def test_reduction_rfactor_spatial_only(): block = s.get_block(name="acc", func_name="main") _, _, _, _, loop, _ = s.get_loops(block) rf_block = s.rfactor(loop=loop, factor_axis=4) - tvm.ir.assert_structural_equal(s.mod["main"], rfactor_spatial_only_after) + assert_structural_equal_ignore_global_symbol(s.mod["main"], rfactor_spatial_only_after) assert s.get(rf_block).same_as(s.get(s.get_block("acc_rf"))) assert s.get(block).same_as(s.get(s.get_block("acc"))) verify_trace_roundtrip(s, mod=rfactor_spatial_only) @@ -1481,7 +1484,7 @@ def test_reduction_rfactor_argmax(): argmax = s.get_block("argmax") _, _, ki = s.get_loops(argmax) rf_block = s.rfactor(ki, 1) - tvm.ir.assert_structural_equal(s.mod["main"], argmax_split_rfactor) + assert_structural_equal_ignore_global_symbol(s.mod["main"], argmax_split_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("argmax_rf"))) assert s.get(argmax).same_as(s.get(s.get_block("argmax"))) verify_trace_roundtrip(s, mod=argmax_split) @@ -1492,7 +1495,7 @@ def test_reduction_rfactor_argmin_init_update_reordeded(): argmin = s.get_block("argmin") _, _, ki = s.get_loops(argmin) rf_block = s.rfactor(ki, 1) - tvm.ir.assert_structural_equal(s.mod["main"], argmin_split_rfactor) + assert_structural_equal_ignore_global_symbol(s.mod["main"], argmin_split_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("argmin_rf"))) assert s.get(argmin).same_as(s.get(s.get_block("argmin"))) verify_trace_roundtrip(s, mod=argmin_split_init_update_reordered) @@ -1619,7 +1622,7 @@ def test_reduction_rfactor_topi_argmax(): _, k = s.get_loops(argmax) _, ki = s.split(k, [None, 8]) rf_block = s.rfactor(ki, 1) - tvm.ir.assert_structural_equal(s.mod["main"], argmax_topi_rfactor) + assert_structural_equal_ignore_global_symbol(s.mod["main"], argmax_topi_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf"))) assert s.get(argmax).same_as(s.get(s.get_block("placeholder_red_temp"))) verify_trace_roundtrip(s, mod=argmax_topi) @@ -1634,7 +1637,7 @@ def test_reduction_rfactor_topi_argmin(): _, k = s.get_loops(argmin) _, ki = s.split(k, [None, 8]) rf_block = s.rfactor(ki, 1) - tvm.ir.assert_structural_equal(s.mod["main"], argmin_topi_rfactor) + assert_structural_equal_ignore_global_symbol(s.mod["main"], argmin_topi_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf"))) assert s.get(argmin).same_as(s.get(s.get_block("placeholder_red_temp"))) verify_trace_roundtrip(s, mod=argmin_topi) diff --git a/tests/python/unittest/test_tir_schedule_rolling_buffer.py b/tests/python/unittest/test_tir_schedule_rolling_buffer.py index 9597a5db72fc..9d19dd877b75 100644 --- a/tests/python/unittest/test_tir_schedule_rolling_buffer.py +++ b/tests/python/unittest/test_tir_schedule_rolling_buffer.py @@ -20,7 +20,10 @@ import tvm.testing from tvm import tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + assert_structural_equal_ignore_global_symbol, + verify_trace_roundtrip, +) import pytest @@ -28,7 +31,7 @@ def check_rolling_buffer( sch: tir.Schedule, origin: tir.PrimFunc, expected: tir.PrimFunc, check_run=False ): scheduled = sch.mod["main"] - tvm.ir.assert_structural_equal(scheduled, expected) + assert_structural_equal_ignore_global_symbol(scheduled, expected) verify_trace_roundtrip(sch, origin) if check_run: in_buffer = origin.buffer_map[origin.params[0]] diff --git a/tests/python/unittest/test_tir_schedule_set_axis_separator.py b/tests/python/unittest/test_tir_schedule_set_axis_separator.py index 75c650733ae0..76a6ade42f50 100644 --- a/tests/python/unittest/test_tir_schedule_set_axis_separator.py +++ b/tests/python/unittest/test_tir_schedule_set_axis_separator.py @@ -21,7 +21,10 @@ from tvm import tir from tvm.tir import IndexMap from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + assert_structural_equal_ignore_global_symbol, + verify_trace_roundtrip, +) # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -122,7 +125,7 @@ def test_set_axis_separator(argument_style): else: raise ValueError(f'Unexpected argument_style: {argument_style}') - tvm.ir.assert_structural_equal(element_wise_set_axis_separator, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(element_wise_set_axis_separator, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) @@ -150,7 +153,7 @@ def test_set_axis_separator_input_buffer(argument_style): raise ValueError(f'Unexpected argument_style: {argument_style}') - tvm.ir.assert_structural_equal(element_wise_set_axis_separator_input_buffer, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(element_wise_set_axis_separator_input_buffer, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) @@ -168,7 +171,7 @@ def test_set_axis_separator_subregion(argument_style): else: raise ValueError(f'Unexpected argument_style: {argument_style}') - tvm.ir.assert_structural_equal(element_wise_subregion_match_set_axis_separator, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(element_wise_subregion_match_set_axis_separator, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) class TestIndexedLookup(tvm.testing.CompareBeforeAfter): diff --git a/tests/python/unittest/test_tir_schedule_set_dtype.py b/tests/python/unittest/test_tir_schedule_set_dtype.py index 7f0900619b9b..96441b630b05 100644 --- a/tests/python/unittest/test_tir_schedule_set_dtype.py +++ b/tests/python/unittest/test_tir_schedule_set_dtype.py @@ -21,7 +21,10 @@ import tvm.testing from tvm import tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + assert_structural_equal_ignore_global_symbol, + verify_trace_roundtrip, +) # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -96,7 +99,7 @@ def test_set_dtype(use_block_name): func = element_wise sch = tir.Schedule(func, debug_mask="all") sch.unsafe_set_dtype("B" if use_block_name else sch.get_block("B"), 0, "float16") - tvm.ir.assert_structural_equal(element_wise_set_dtype, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(element_wise_set_dtype, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func) def test_set_dtype_fail_on_output_buffer(use_block_name): @@ -117,7 +120,7 @@ def test_set_dtype_subregion(): func = element_wise_subregion_match sch = tir.Schedule(func, debug_mask='all') sch.unsafe_set_dtype(sch.get_block("B"), 0, "float16") - tvm.ir.assert_structural_equal(element_wise_subregion_match_set_dtype, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(element_wise_subregion_match_set_dtype, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func) diff --git a/tests/python/unittest/test_tir_schedule_set_scope.py b/tests/python/unittest/test_tir_schedule_set_scope.py index 40df049783f9..991a4ca9b77f 100644 --- a/tests/python/unittest/test_tir_schedule_set_scope.py +++ b/tests/python/unittest/test_tir_schedule_set_scope.py @@ -20,7 +20,10 @@ import tvm.testing from tvm import tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + assert_structural_equal_ignore_global_symbol, + verify_trace_roundtrip, +) # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -94,7 +97,7 @@ def test_set_scope(use_block_name, use_buffer_name): func = element_wise s = tir.Schedule(func, debug_mask='all') s.set_scope('B' if use_block_name else s.get_block("B"), 'B' if use_buffer_name else 0, "shared") - tvm.ir.assert_structural_equal(element_wise_set_scope, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(element_wise_set_scope, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) @@ -125,7 +128,7 @@ def test_set_scope_subregion(): func = element_wise_subregion_match s = tir.Schedule(func, debug_mask='all') s.set_scope(s.get_block("B"), 0, "shared") - tvm.ir.assert_structural_equal(element_wise_subregion_match_set_scope, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(element_wise_subregion_match_set_scope, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index f6373fa727a1..679b147446ea 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -21,7 +21,10 @@ from tvm import te, tir from tvm.script import tir as T from tvm.tir.expr import IntImm -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + assert_structural_equal_ignore_global_symbol, + verify_trace_roundtrip, +) # pylint: disable=no-member,invalid-name,unused-variable @@ -359,7 +362,7 @@ def test_fuse(): block_b = sch.get_block("B") i, j, k = sch.get_loops(block_b) sch.fuse(i, j, k) - tvm.ir.assert_structural_equal(elementwise_fused, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_fused, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) @@ -370,7 +373,7 @@ def test_split(): sch.split(i, factors=[2, 1, 64]) sch.split(j, factors=[4, 32]) sch.split(k, factors=[16, 8]) - tvm.ir.assert_structural_equal(elementwise_split_case0, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_split_case0, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) @@ -381,7 +384,7 @@ def test_split_with_inferred_factor(): sch.split(i, factors=[None, 1, 64]) sch.split(j, factors=[2, None, 64]) sch.split(k, factors=[2, 1, None]) - tvm.ir.assert_structural_equal(elementwise_split_case1, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_split_case1, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) @@ -392,7 +395,7 @@ def test_split_with_predicate(): sch.split(i, factors=[1000, 2, 3]) sch.split(j, factors=[None, 129]) sch.split(k, factors=[3, None]) - tvm.ir.assert_structural_equal(elementwise_split_with_predicate, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_split_with_predicate, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) @@ -429,7 +432,9 @@ def test_fuse_with_opaque_block(): block_opaque = sch.get_block("opaque") i, j, k = sch.get_loops(block_opaque) sch.fuse(i, j, k) - tvm.ir.assert_structural_equal(elementwise_fuse_with_opaque_block, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + elementwise_fuse_with_opaque_block, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=elementwise_with_opaque_block) @@ -441,7 +446,7 @@ def test_fuse_with_opaque_access(): block_b = sch.get_block("B") i, j = sch.get_loops(block_b) sch.fuse(i, j) - tvm.ir.assert_structural_equal(opaque_access_fused, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(opaque_access_fused, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=opaque_access) @@ -450,7 +455,9 @@ def test_split_with_opaque_block(): block_opaque = sch.get_block("opaque") i, _, _ = sch.get_loops(block_opaque) sch.split(i, factors=[None, 16]) - tvm.ir.assert_structural_equal(elementwise_split_with_opaque_block, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + elementwise_split_with_opaque_block, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=elementwise_with_opaque_block) @@ -462,7 +469,7 @@ def test_split_with_opaque_access(): block_b = sch.get_block("B") _, j = sch.get_loops(block_b) sch.split(j, factors=[None, 4]) - tvm.ir.assert_structural_equal(opaque_access_split, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(opaque_access_split, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=opaque_access) @@ -493,7 +500,7 @@ def test_fuse_symbolic(): block_b = sch.get_block("B") i, j, k = sch.get_loops(block_b) sch.fuse(i, j, k) - tvm.ir.assert_structural_equal(elementwise_symbolic_fused, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_symbolic_fused, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_symbolic) @@ -502,7 +509,7 @@ def test_split_symbolic(): block_b = sch.get_block("B") _, _, k = sch.get_loops(block_b) sch.split(k, factors=[10, None]) - tvm.ir.assert_structural_equal(elementwise_symbolic_split, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_symbolic_split, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_symbolic) @@ -519,7 +526,7 @@ def test_fuse_not_affine(): block_b = sch.get_block("B") _, j, k = sch.get_loops(block_b) sch.fuse(j, k) - tvm.ir.assert_structural_equal(elementwise_not_affine_fused, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_not_affine_fused, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_not_affine) @@ -548,7 +555,7 @@ def zero_dim_added( sch = tir.Schedule(zero_dim, debug_mask="all") block = sch.get_block("C") sch.add_unit_loop(block) - tvm.ir.assert_structural_equal(zero_dim_added, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(zero_dim_added, sch.mod["main"]) def test_add_unit_loop_above_loop(): @@ -578,7 +585,7 @@ def zero_dim_added( block = sch.get_block("C") (loop,) = sch.get_loops(block) sch.add_unit_loop(loop) - tvm.ir.assert_structural_equal(zero_dim_added, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(zero_dim_added, sch.mod["main"]) @pytest.mark.skip("Pending fix in affine analysis") @@ -643,7 +650,7 @@ def test_split_int64_factors(): block_b = sch.get_block("B") _, _, k = sch.get_loops(block_b) sch.split(k, factors=[IntImm(dtype="int64", value=10), None]) - tvm.ir.assert_structural_equal(elementwise_symbolic_split, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_symbolic_split, sch.mod["main"]) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_schedule_storage_align.py b/tests/python/unittest/test_tir_schedule_storage_align.py index 23cb5d3b5339..3825234c20e0 100644 --- a/tests/python/unittest/test_tir_schedule_storage_align.py +++ b/tests/python/unittest/test_tir_schedule_storage_align.py @@ -19,7 +19,10 @@ import tvm from tvm import tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + assert_structural_equal_ignore_global_symbol, + verify_trace_roundtrip, +) # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name @@ -105,7 +108,7 @@ def test_storage_align(use_block_name): s = tir.Schedule(func, debug_mask='all') B = 'B' if use_block_name else s.get_block("B") s.storage_align(B, 0, axis=0, factor=128, offset=127) - tvm.ir.assert_structural_equal(element_wise_storage_align, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(element_wise_storage_align, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) @@ -115,7 +118,7 @@ def test_storage_align_update(): B = s.get_block("B") s.storage_align(B, 0, axis=0, factor=128, offset=0) s.storage_align(B, 0, axis=0, factor=128, offset=127) - tvm.ir.assert_structural_equal(element_wise_storage_align, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(element_wise_storage_align, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py index 94dc650ac016..9646355f0a5d 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize.py +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -21,7 +21,10 @@ import tvm.testing from tvm import te, tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + assert_structural_equal_ignore_global_symbol, + verify_trace_roundtrip, +) from tvm.tir.tensor_intrin.arm_cpu import ( DP4A_INTRIN, ARM_DOT_4x4_i8_NEON_INTRIN, @@ -507,7 +510,7 @@ def test_tensorize_matmul(): s.reorder(io, jo, ko, ii, ji, ki) s.decompose_reduction(update, ko) s.tensorize(ii, "test_mma_intrin") - tvm.ir.assert_structural_equal(tensorized_matmul, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(tensorized_matmul, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) @@ -521,7 +524,7 @@ def test_tensorize_batch_matmul(): ko, ki = s.split(k, factors=[None, 16]) s.reorder(io, jo, ko, ii, ji, ki) s.tensorize(ii, "test_mma_intrin") - tvm.ir.assert_structural_equal(tensorized_batch_matmul_mma, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(tensorized_batch_matmul_mma, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=batch_matmul) @@ -532,7 +535,7 @@ def test_tensorize_dot_product(): _, _, _, k = s.get_loops(C) _, ki = s.split(k, factors=[None, 4]) s.tensorize(ki, "test_dot_product_intrin") - tvm.ir.assert_structural_equal(tensorized_batch_matmul_dot_product, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(tensorized_batch_matmul_dot_product, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) @@ -545,7 +548,7 @@ def test_tensorize_outer_product(): jo, ji = s.split(j, factors=[None, 16]) s.reorder(io, jo, k, ii, ji) s.tensorize(ii, "test_outer_product_intrin") - tvm.ir.assert_structural_equal(tensorized_batch_matmul_outer_product, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(tensorized_batch_matmul_outer_product, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) @@ -560,7 +563,7 @@ def test_tensorize_with_annotation(): s.reorder(io, jo, ko, ii, ji, ki) s.decompose_reduction(update, ko) s.tensorize(ii, "test_annotated_mma_intrin") - tvm.ir.assert_structural_equal(annotated_tensorized_matmul, s.mod["main"]) + assert_structural_equal_ignore_global_symbol(annotated_tensorized_matmul, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) @@ -820,7 +823,7 @@ def tensorized_matmul_int64_shape( update = s.get_block("update") ii = s.get_loops(update)[-3] s.tensorize(ii, "test_mma_intrin") - tvm.ir.assert_structural_equal(s.mod["main"], tensorized_matmul_int64_shape) + assert_structural_equal_ignore_global_symbol(s.mod["main"], tensorized_matmul_int64_shape) verify_trace_roundtrip(sch=s, mod=matmul_int64_shape) diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py index a87fd4ed5b56..a793699ca755 100644 --- a/tests/python/unittest/test_tir_schedule_trace.py +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -24,6 +24,7 @@ from tvm import tir from tvm.script import tir as T from tvm.tir.schedule import BlockRV, Instruction, InstructionKind, LoopRV, Trace +from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol # pylint: disable=no-member,invalid-name,unused-variable @@ -229,7 +230,7 @@ def test_trace_apply_to_schedule(): trace = _make_trace_2(BlockRV()) sch = tir.Schedule(elementwise, debug_mask="all") trace.apply_to_schedule(sch, remove_postproc=False, decision_provider=None) - tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_inlined, sch.mod["main"]) def test_trace_as_json_1(): @@ -313,7 +314,7 @@ def test_apply_json_to_schedule_1(): json_obj = trace.as_json() sch = tir.Schedule(elementwise, debug_mask="all") Trace.apply_json_to_schedule(json_obj, sch) - tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_inlined, sch.mod["main"]) def test_apply_json_to_schedule_sample_categorical(): @@ -373,7 +374,7 @@ def elementwise_expected(a: T.handle, c: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 - tvm.ir.assert_structural_equal(elementwise_expected, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_expected, sch.mod["main"]) def test_apply_annotation_from_json(): diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 04bd00111ef3..52f2e2e3419f 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -22,7 +22,10 @@ import tvm.testing from tvm import tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + assert_structural_equal_ignore_global_symbol, + verify_trace_roundtrip, +) # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks @@ -241,7 +244,9 @@ def test_two_elementwise_transform_intermediate_buffer(use_block_name): block = sch.get_block("B") sch.transform_layout(block, ("write", 0), packed_index_map_func) - tvm.ir.assert_structural_equal(two_elementwise_transformed_intermediate_buffer, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + two_elementwise_transformed_intermediate_buffer, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=two_elementwise) @@ -267,7 +272,9 @@ def test_two_elementwise_transform_input_buffer(use_block_name): block = sch.get_block("B") sch.transform_layout(block, ("read", 0), packed_index_map_func) - tvm.ir.assert_structural_equal(two_elementwise_transformed_input_buffer, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + two_elementwise_transformed_input_buffer, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=two_elementwise) @@ -284,7 +291,9 @@ def test_two_elementwise_transform_output_buffer(use_block_name): block = sch.get_block("C") sch.transform_layout(block, ("write", 0), packed_index_map_func) - tvm.ir.assert_structural_equal(two_elementwise_transformed_output_buffer, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + two_elementwise_transformed_output_buffer, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=two_elementwise) @@ -302,7 +311,7 @@ def test_two_elementwise_unit_dim(use_block_name): block = sch.get_block("B") sch.transform_layout(block, ("write", 0), index_map) - tvm.ir.assert_structural_equal(two_elementwise_unit_dim, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(two_elementwise_unit_dim, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise_unit_dim) @@ -343,6 +352,7 @@ def ref(B: T.Buffer((8, 8, 16, 16), "float32"), C: T.Buffer((128, 128), "float32 # T.reads(B[vi // 16 + vi_o, vj // 16 + vj_o, vi % 16, vj % 16]) # C[...] = B[vi // 16 + vi_o, vj // 16 + vj_o, vi % 16, vj % 16] + T.float32(1) + # not comparing PrimFuncs tvm.ir.assert_structural_equal(ref.body.block.body, sch.get(sch.get_loops(block_outer)[0])) @@ -371,14 +381,14 @@ def summation_3d_split( sch.transform_layout( index_map=lambda *indices, k: [*indices, k // 4, k % 4], block="compute", buffer="A" ) - tvm.ir.assert_structural_equal(summation_3d_split, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(summation_3d_split, sch.mod["main"]) def test_transform_block_layout_basic(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") block = "B" if use_block_name else sch.get_block("B") sch.transform_block_layout(block, lambda i, j: (i * 128 + j,)) - tvm.ir.assert_structural_equal(elementwise_transformed, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(elementwise_transformed, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) @@ -389,7 +399,7 @@ def test_transform_block_layout_conv2d_nhwc(use_block_name): block, lambda n, h, w, co, rh, rw, rc: (n * 112 * 112 + h * 112 + w, co, rh * 7 * 3 + rw * 3 + rc), ) - tvm.ir.assert_structural_equal(conv2d_nhwc_transformed, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(conv2d_nhwc_transformed, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc) @@ -412,7 +422,9 @@ def two_elementwise_unit_dim_transformed( vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 - tvm.ir.assert_structural_equal(two_elementwise_unit_dim_transformed, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + two_elementwise_unit_dim_transformed, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=two_elementwise_unit_dim) @@ -459,7 +471,9 @@ def elementwise_int64_extent_transformed( sch = tir.Schedule(elementwise_int64_extent, debug_mask="all") block = "B" if use_block_name else sch.get_block("B") sch.transform_block_layout(block, lambda i, j: (i * 128 + j,)) - tvm.ir.assert_structural_equal(elementwise_int64_extent_transformed, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol( + elementwise_int64_extent_transformed, sch.mod["main"] + ) verify_trace_roundtrip(sch=sch, mod=elementwise_int64_extent) diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index a8be97488b25..f7b0e672b23c 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -23,7 +23,10 @@ from tvm import tir from tvm.ir import IRModule from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + assert_structural_equal_ignore_global_symbol, + verify_trace_roundtrip, +) # pylint: disable=no-member,invalid-name,unused-variable @@ -333,7 +336,7 @@ def test_annotate_unannotate_loop(): sch.annotate(sch.get_loops(matmul)[1], "test2", 612) sch.annotate(sch.get_loops(matmul)[1], "test3", ["aa", 1]) sch.annotate(sch.get_loops(matmul)[0], "test4", {"arr": [0, 0], "key": 3}) - tvm.ir.assert_structural_equal(sch.mod["main"], matmul_relu_ann1) + assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_relu_ann1) verify_trace_roundtrip(sch=sch, mod=matmul_relu) sch.unannotate(sch.get_loops(matmul)[0], "test1") sch.unannotate(sch.get_loops(matmul)[1], "test2") @@ -350,7 +353,7 @@ def test_annotate_unannotate_block(): sch.annotate(relu, "test2", 0.22) sch.annotate(relu, "test3", ["aa", 1]) sch.annotate(matmul, "test4", {"arr": [0, 0], "key": 3}) - tvm.ir.assert_structural_equal(sch.mod["main"], matmul_relu_ann2) + assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_relu_ann2) verify_trace_roundtrip(sch=sch, mod=matmul_relu) sch.unannotate(matmul, "test1") sch.unannotate(relu, "test2") diff --git a/tests/python/unittest/test_tir_specialize.py b/tests/python/unittest/test_tir_specialize.py index ebae827ef5ad..508730aacfe2 100644 --- a/tests/python/unittest/test_tir_specialize.py +++ b/tests/python/unittest/test_tir_specialize.py @@ -18,6 +18,7 @@ import tvm from tvm.script import tir as T +from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol @T.prim_func @@ -199,13 +200,13 @@ def test_specialize_matmul(): a, _, _, n = matmul.params # fully specialized func = matmul.specialize({a: tvm.tir.decl_buffer((128, 128))}) - tvm.ir.assert_structural_equal(func, matmul_128) + assert_structural_equal_ignore_global_symbol(func, matmul_128) # partially specialized func = matmul.specialize({n: 128}) - tvm.ir.assert_structural_equal(func, matmul_m_128) + assert_structural_equal_ignore_global_symbol(func, matmul_m_128) # symbolic specialized func = matmul.specialize({n: tvm.tir.Var("x", "int32") * 8}) - tvm.ir.assert_structural_equal(func, matmul_m_8x) + assert_structural_equal_ignore_global_symbol(func, matmul_m_8x) def test_specialize_elemwise(): @@ -213,22 +214,22 @@ def test_specialize_elemwise(): C = element_wise.buffer_map[c] # fully specialized func = element_wise.specialize({a: tvm.tir.decl_buffer((128, 64))}) - tvm.ir.assert_structural_equal(func, element_wise_128_64) + assert_structural_equal_ignore_global_symbol(func, element_wise_128_64) # partially specialized func = element_wise.specialize({c: tvm.tir.decl_buffer((128, C.shape[1]))}) - tvm.ir.assert_structural_equal(func, element_wise_128_n) + assert_structural_equal_ignore_global_symbol(func, element_wise_128_n) def test_specialize_mem_copy(): a, _, m, n, p, q = mem_copy.params # fully specialized func = mem_copy.specialize({a: tvm.tir.decl_buffer((16, 16), strides=[8, 1], elem_offset=4)}) - tvm.ir.assert_structural_equal(func, mem_copy_16_16_8_4) + assert_structural_equal_ignore_global_symbol(func, mem_copy_16_16_8_4) func = mem_copy.specialize({n: 16, m: 16, p: 8, q: 4}) - tvm.ir.assert_structural_equal(func, mem_copy_16_16_8_4) + assert_structural_equal_ignore_global_symbol(func, mem_copy_16_16_8_4) # partially specialized func = mem_copy.specialize({q: n}) - tvm.ir.assert_structural_equal(func, mem_copy_m_n_p_n) + assert_structural_equal_ignore_global_symbol(func, mem_copy_m_n_p_n) def test_specialize_recursive_load(): @@ -239,7 +240,7 @@ def test_specialize_recursive_load(): def test_specialize_with_const_folding(): b = param_in_arith_exprs.params[1] func = param_in_arith_exprs.specialize({b: tvm.tir.decl_buffer([16])}) - tvm.ir.assert_structural_equal(func, param_in_arith_exprs_n_16) + assert_structural_equal_ignore_global_symbol(func, param_in_arith_exprs_n_16) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_texture_scope.py b/tests/python/unittest/test_tir_texture_scope.py index 2af4710751d7..fb98b6536e66 100644 --- a/tests/python/unittest/test_tir_texture_scope.py +++ b/tests/python/unittest/test_tir_texture_scope.py @@ -29,7 +29,7 @@ def test_texture_scope(): class PlusOneMultTwo: @T.prim_func def main(a: T.handle, b: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"tir.noalias": True}) A = T.match_buffer(a, (128, 128, 4), dtype="float32", scope="global.texture") B = T.alloc_buffer((128, 128, 4), dtype="float32", scope="global.texture") C = T.match_buffer(b, (128, 128, 4), dtype="float32", scope="global.texture") diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py index 5ba2824e74dd..7be1038ce5d4 100644 --- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py @@ -384,10 +384,10 @@ def func_associativity_expected( def _check(original, transformed): func = original - mod = tvm.IRModule.from_expr(func) + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) body = tvm.tir.transform.CommonSubexprElimTIR(identify_equiv_terms=True)(mod) tvm.transform.PrintIR()(body) - tvm.ir.assert_structural_equal(body["main"], transformed) + tvm.ir.assert_structural_equal(body["main"], transformed.with_attr("global_symbol", "main")) def test_semantic_equiv_distributivity(): diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 1a5ef95a374a..d268403c1be4 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -34,8 +34,8 @@ def test_compact(self): is_lower_order_free = getattr(self, "is_lower_order_free", True) is_strict = getattr(self, "is_strict_mode", True) - before = tvm.IRModule.from_expr(self.before) - expected = tvm.IRModule.from_expr(self.expected) + before = tvm.IRModule.from_expr(self.before.with_attr("global_symbol", "main")) + expected = tvm.IRModule.from_expr(self.expected.with_attr("global_symbol", "main")) simplify = tvm.transform.Sequential([tir.transform.Simplify(), tir.transform.RemoveNoOp()]) after = simplify(tir.transform.CompactBufferAllocation(is_strict=is_strict)(before)) expected = simplify(expected) @@ -1056,7 +1056,7 @@ def before( C[i_0 * 26 + ax0, j_0 * 26 + ax1] = C_local[ax0, ax1] # Get partitioned workload to compact - before_mod = tvm.IRModule.from_expr(before) + before_mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): before_mod = tvm.tir.transform.LowerOpaqueBlock()(before_mod) before_mod = tvm.tir.transform.LoopPartition()(before_mod) diff --git a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py index 73b5203b56f0..8fbbaf59bb58 100644 --- a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py +++ b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py @@ -21,10 +21,10 @@ def _check(original, transformed): func = original - mod = tvm.IRModule.from_expr(func) + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod) mod = tvm.tir.transform.Simplify()(mod) - tvm.ir.assert_structural_equal(mod["main"], transformed) + tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main")) @T.prim_func diff --git a/tests/python/unittest/test_tir_transform_convert_ssa.py b/tests/python/unittest/test_tir_transform_convert_ssa.py index 918fe6b90738..38a93b199e44 100644 --- a/tests/python/unittest/test_tir_transform_convert_ssa.py +++ b/tests/python/unittest/test_tir_transform_convert_ssa.py @@ -103,7 +103,12 @@ def func(): with T.LetStmt(10) as var: T.evaluate(var) - return tvm.IRModule({"func_a": func, "func_b": func}) + return tvm.IRModule( + { + "func_a": func.with_attr("global_symbol", "func_a"), + "func_b": func.with_attr("global_symbol", "func_b"), + } + ) def expected(self): @I.ir_module @@ -133,7 +138,12 @@ def before(self): def func(n: T.int32): T.evaluate(n) - return tvm.IRModule({"func_a": func, "func_b": func}) + return tvm.IRModule( + { + "func_a": func.with_attr("global_symbol", "func_a"), + "func_b": func.with_attr("global_symbol", "func_b"), + } + ) def expected(self): @I.ir_module @@ -158,7 +168,12 @@ def func(a: T.handle("float32")): A = T.Buffer(shape=1, dtype="float32", data=a) T.evaluate(A[0]) - return tvm.IRModule({"func_a": func, "func_b": func}) + return tvm.IRModule( + { + "func_a": func.with_attr("global_symbol", "func_a"), + "func_b": func.with_attr("global_symbol", "func_b"), + } + ) def expected(self): @I.ir_module @@ -184,7 +199,12 @@ def before(self): def func(A: T.Buffer(1, "float32")): T.evaluate(A[0]) - return tvm.IRModule({"func_a": func, "func_b": func}) + return tvm.IRModule( + { + "func_a": func.with_attr("global_symbol", "func_a"), + "func_b": func.with_attr("global_symbol", "func_b"), + } + ) def expected(self): @I.ir_module diff --git a/tests/python/unittest/test_tir_transform_helpers.py b/tests/python/unittest/test_tir_transform_helpers.py index 00fd12521268..d4cd01ade248 100644 --- a/tests/python/unittest/test_tir_transform_helpers.py +++ b/tests/python/unittest/test_tir_transform_helpers.py @@ -24,7 +24,7 @@ def test_annotate_entry_func_single_primfunc(): @tvm.script.ir_module class MockModule: - @T.prim_func + @T.prim_func(private=True) def func1(A: T.Buffer((16,), "float32")): for i in T.serial(16): if i == 5: @@ -45,14 +45,14 @@ def func1(A: T.Buffer((16,), "float32")): # Test module @tvm.script.ir_module class MockModule: - @T.prim_func + @T.prim_func(private=True) def func1(A: T.Buffer((16,), "float32")): for i in T.serial(16): if i == 5: if i == 5: A[i] = 0.0 - @T.prim_func + @T.prim_func(private=True) def func2(A: T.Buffer((32,), "float32")): for i in T.serial(32): if i == 15: @@ -124,12 +124,24 @@ class TestBindTargetWithHostToInternalFunction(tvm.testing.CompareBeforeAfter): transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm")) - def before(): - T.evaluate(0) + def before(self): + @I.ir_module + class module: + @T.prim_func(private=True) + def main(): + T.evaluate(0) - def expected(): - T.func_attr({"target": T.target("cuda")}) - T.evaluate(0) + return module + + def expected(self): + @I.ir_module + class module: + @T.prim_func(private=True) + def main(): + T.func_attr({"target": T.target("cuda")}) + T.evaluate(0) + + return module class TestBindTargetIgnoresExisting(tvm.testing.CompareBeforeAfter): diff --git a/tests/python/unittest/test_tir_transform_hoist_expression.py b/tests/python/unittest/test_tir_transform_hoist_expression.py index a0b624a15c31..c9fdbd0c424e 100644 --- a/tests/python/unittest/test_tir_transform_hoist_expression.py +++ b/tests/python/unittest/test_tir_transform_hoist_expression.py @@ -27,7 +27,7 @@ class BaseBeforeAfter: def test_hoist(self, hoisted_conditionals, hoisted_let_bindings): before = self.before - before_mod = tvm.IRModule.from_expr(before) + before_mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) config = { "tir.HoistExpression": { @@ -40,7 +40,7 @@ def test_hoist(self, hoisted_conditionals, hoisted_let_bindings): after_mod = tvm.tir.transform.HoistExpression()(before_mod) after = after_mod["main"] - expected = self.expected + expected = self.expected.with_attr("global_symbol", "main") try: tvm.ir.assert_structural_equal(after, expected) diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index b9f35ed553e1..2a1ce2be28f5 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -37,10 +37,12 @@ def _check(original, transformed): func = original - mod = tvm.IRModule.from_expr(func) + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.tir.transform.InjectSoftwarePipeline()(mod) mod = tvm.tir.transform.Simplify()(mod) - tvm.ir.assert_structural_equal(mod["main"], transformed, True) + tvm.ir.assert_structural_equal( + mod["main"], transformed.with_attr("global_symbol", "main"), True + ) def _check_error(func): @@ -1108,7 +1110,7 @@ def test_error_missing_annotation(): def test_simple_compute_async(): - mod = tvm.IRModule.from_expr(gen_simple_compute(1)) + mod = tvm.IRModule.from_expr(gen_simple_compute(1).with_attr("global_symbol", "main")) sch = tvm.tir.Schedule(mod) _, loop = sch.get_loops(sch.get_block("compute")) @@ -1153,9 +1155,9 @@ def ref(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): with T.attr(0, "async_wait_inflight_count", 0): C[tx, 15] = B[T.FloorMod(15, 2), tx, 0] + T.float32(1) - tvm.ir.assert_structural_equal(mod["main"], ref, True) + tvm.ir.assert_structural_equal(mod["main"], ref.with_attr("global_symbol", "main"), True) - mod = tvm.IRModule.from_expr(gen_simple_compute(3)) + mod = tvm.IRModule.from_expr(gen_simple_compute(3).with_attr("global_symbol", "main")) sch = tvm.tir.Schedule(mod) _, loop = sch.get_loops(sch.get_block("compute")) @@ -1210,7 +1212,7 @@ def ref(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> N with T.attr(0, "async_wait_inflight_count", 2 - i): C[tx, i - 3 + 16] = B[(i - 3 + 16) % 4, tx, 0] + T.float32(1) - tvm.ir.assert_structural_equal(mod["main"], ref, True) + tvm.ir.assert_structural_equal(mod["main"], ref.with_attr("global_symbol", "main"), True) def test_async_producer_interleaving(): @@ -1240,7 +1242,7 @@ def simple_compute( T.writes(C[tx, i]) C[tx, i] = A_shared[tx, 0] + B_shared[tx, 0] - mod = tvm.IRModule.from_expr(simple_compute) + mod = tvm.IRModule.from_expr(simple_compute.with_attr("global_symbol", "main")) sch = tvm.tir.Schedule(mod) _, loop = sch.get_loops(sch.get_block("compute")) @@ -1317,11 +1319,11 @@ def ref( + B_shared[(i - 3 + 16) % 4, tx, 0] ) - tvm.ir.assert_structural_equal(mod["main"], ref, True) + tvm.ir.assert_structural_equal(mod["main"], ref.with_attr("global_symbol", "main"), True) def test_three_stage_compute_two_stage_async(): - mod = tvm.IRModule.from_expr(three_stage_compute) + mod = tvm.IRModule.from_expr(three_stage_compute.with_attr("global_symbol", "main")) sch = tvm.tir.Schedule(mod) _, loop = sch.get_loops(sch.get_block("compute")) @@ -1415,7 +1417,7 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N ): D[tx, i - 2 + 16] = C[(i - 2 + 16) % 2, tx, 0] + T.float32(1) - tvm.ir.assert_structural_equal(mod["main"], ref, True) + tvm.ir.assert_structural_equal(mod["main"], ref.with_attr("global_symbol", "main"), True) N = K = M = 4096 diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index beb20fd43ba6..593f9447d44c 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -103,7 +103,9 @@ def get_vthread(name): C_expected_alloc = m * nthread * nthread stmt = tvm.tir.transform.InjectVirtualThread()( - tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread(vthread_name))) + tvm.IRModule.from_expr( + tvm.tir.PrimFunc([], get_vthread(vthread_name)).with_attr("global_symbol", "main") + ) )["main"] assert list(stmt.body.body.extents) == [A_expected_alloc] @@ -127,7 +129,7 @@ def test_vthread_if_then_else(): stmt = ib.get() stmt = tvm.tir.transform.InjectVirtualThread()( - tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) )["main"] assert stmt.body.body.body[0].else_case != None @@ -160,11 +162,11 @@ def expected_func(): B[T.Mul(2, 4) : T.Mul(2, 4) + 4] = T.broadcast(2, 4) B[T.Mul(3, 4) : T.Mul(3, 4) + 4] = T.broadcast(3, 4) - before_mod = tvm.IRModule.from_expr(before_func) + before_mod = tvm.IRModule.from_expr(before_func.with_attr("global_symbol", "main")) after_mod = tvm.tir.transform.InjectVirtualThread()(before_mod) after_func = after_mod["main"] - tvm.ir.assert_structural_equal(after_func, expected_func) + tvm.ir.assert_structural_equal(after_func, expected_func.with_attr("global_symbol", "main")) def test_vthread_vectorized(): @@ -187,12 +189,12 @@ def expected_func(): B[T.Div(T.Mul(2, 4), 4)] = T.broadcast(2, 4) B[T.Div(T.Mul(3, 4), 4)] = T.broadcast(3, 4) - before_mod = tvm.IRModule.from_expr(before_func) + before_mod = tvm.IRModule.from_expr(before_func.with_attr("global_symbol", "main")) intermediate_mod = tvm.tir.transform.InjectVirtualThread()(before_mod) after_mod = tvm.tir.transform.StorageRewrite()(intermediate_mod) after_func = after_mod["main"] - tvm.ir.assert_structural_equal(after_func, expected_func) + tvm.ir.assert_structural_equal(after_func, expected_func.with_attr("global_symbol", "main")) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_lift_thread_binding.py b/tests/python/unittest/test_tir_transform_lift_thread_binding.py index defcc6d6c1dc..84868ae2ed16 100644 --- a/tests/python/unittest/test_tir_transform_lift_thread_binding.py +++ b/tests/python/unittest/test_tir_transform_lift_thread_binding.py @@ -130,9 +130,9 @@ def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c: T.handle): T.writes(C[blockIdx_x // n, 0, blockIdx_x % n]) C[blockIdx_x // n, 0, blockIdx_x % n] = D_local[blockIdx_x // n, 0, blockIdx_x % n] * T.float32(0.088397790055248615) # fmt: on - mod = tvm.IRModule({"main": before}) + mod = tvm.IRModule({"main": before.with_attr("global_symbol", "main")}) after = tir.transform.LiftThreadBinding()(mod) - tvm.ir.assert_structural_equal(expected, after["main"]) + tvm.ir.assert_structural_equal(expected.with_attr("global_symbol", "main"), after["main"]) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index b88f8d1e3e72..aa11ae5a5f7b 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -40,7 +40,7 @@ def test_basic(): bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt).with_attr("global_symbol", "main")) mod = tvm.tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"] @@ -60,7 +60,7 @@ def test_const_loop(): bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body @@ -80,7 +80,7 @@ def test_no_unroll_loop(): bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) with tvm.transform.PassContext( config={ "tir.LoopPartition": { @@ -109,7 +109,7 @@ def test_multi_loop(): ib.emit(tvm.tir.Evaluate(n)) stmt = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n, m], stmt)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n, m], stmt).with_attr("global_symbol", "main")) mod = tvm.tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body @@ -133,7 +133,7 @@ def test_multi_if(): ib.emit(tvm.tir.Evaluate(n)) stmt = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) mod = tvm.tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body @@ -156,7 +156,7 @@ def test_thread_axis(): bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) mod = tvm.tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"] @@ -195,7 +195,7 @@ def test_condition(): ib.emit(tvm.tir.Evaluate(tvm.tir.Select(ib.likely(i * 4 + j < n), m, n))) stmt = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt).with_attr("global_symbol", "main")) mod = tvm.tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body @@ -210,7 +210,7 @@ def test_condition_EQ(): ib.emit(tvm.tir.Evaluate(tvm.tir.Select(ib.likely(tvm.tir.EQ(i, 5)), m, n))) stmt = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt).with_attr("global_symbol", "main")) with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body @@ -247,7 +247,7 @@ def test_everything_during_deduction(): ib.emit(tvm.tir.Evaluate(m)) stmt = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt).with_attr("global_symbol", "main")) mod = tvm.tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body @@ -267,7 +267,7 @@ def test_single_likely(): bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.tir.transform.LoopPartition()(mod) @@ -294,7 +294,7 @@ def test_multi_likely(): bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.tir.transform.LoopPartition()(mod) @@ -327,7 +327,9 @@ def test_oneD_pool(): stmt = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, data, out], stmt)) + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([m, data, out], stmt).with_attr("global_symbol", "main") + ) with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.tir.transform.LoopPartition()(mod) @@ -356,7 +358,9 @@ def test_cce_loop_1(): ) stmt = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([Ab, Bb], stmt).with_attr("global_symbol", "main") + ) with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body @@ -380,7 +384,7 @@ def test_cce_loop_2(): stmt = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body @@ -401,7 +405,7 @@ def test_cce_loop_3(): ib.emit(tvm.tir.call_extern("float16", "cce_intrisic", head1)) stmt = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.tir.transform.LoopPartition()(mod) @@ -436,7 +440,7 @@ def test_conv_tiling(): oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16) bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.tir.transform.LoopPartition()(mod) stmt = tvm.tir.transform.Simplify()(mod)["main"].body @@ -528,7 +532,7 @@ def test_simple_rfactor(): bounds = tvm.te.schedule.InferBound(s) stmt1 = tvm.te.schedule.ScheduleOps(s, bounds) - mod1 = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt1)) + mod1 = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt1).with_attr("global_symbol", "main")) stmt1 = tvm.tir.transform.Simplify()(mod1)["main"].body with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): @@ -567,7 +571,7 @@ def test_explicit_partition_hint(): def partition_from_scheduled_tir(prim_func, pass_cfg): with tvm.transform.PassContext(config=pass_cfg): - mod = IRModule.from_expr(prim_func) + mod = IRModule.from_expr(prim_func.with_attr("global_symbol", "main")) mod = tvm.tir.transform.LowerOpaqueBlock()(mod) mod = tvm.tir.transform.FlattenBuffer()(mod) mod = tvm.tir.transform.LoopPartition()(mod) @@ -624,7 +628,9 @@ def test_condition_mutually_exclusive(): mod = partition_from_scheduled_tir( concat_func_3, {"tir.LoopPartition": {"partition_const_loop": True}} ) - assert tvm.ir.structural_equal(mod["main"], partitioned_concat_3) + assert tvm.ir.structural_equal( + mod["main"], partitioned_concat_3.with_attr("global_symbol", "main") + ) def test_loop_partition_unroll_hint(): @@ -674,7 +680,7 @@ def partitioned_main( mod = tvm.tir.transform.UnrollLoop()(mod) mod = tvm.tir.transform.RemoveNoOp()(mod) mod = tvm.tir.transform.Simplify()(mod) - assert tvm.ir.structural_equal(mod["main"], partitioned_main) + assert tvm.ir.structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main")) def test_loop_partition_recursive_unroll_hint(): @@ -743,7 +749,7 @@ def partitioned_main(): } }, ) - assert tvm.ir.structural_equal(mod["main"], partitioned_main) + assert tvm.ir.structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main")) def test_loop_partition_keep_loop_annotations(): @@ -777,7 +783,7 @@ def after(A: T.Buffer(160, "int32"), B: T.Buffer(160, "int32")) -> None: } }, ) - assert tvm.ir.structural_equal(mod["main"], after) + assert tvm.ir.structural_equal(mod["main"], after.with_attr("global_symbol", "main")) def test_loop_partition_with_unit_loop_in_condition(): @@ -825,7 +831,7 @@ def after( } }, ) - assert tvm.ir.structural_equal(mod["main"], after) + assert tvm.ir.structural_equal(mod["main"], after.with_attr("global_symbol", "main")) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py index 2334fe535076..6162233b6583 100644 --- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -27,9 +27,11 @@ def _check(original, transformed): - mod = tvm.IRModule.from_expr(original) + mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main")) mod = tvm.tir.transform.LowerCrossThreadReduction()(mod) - tvm.ir.assert_structural_equal(mod["main"], transformed, True) + tvm.ir.assert_structural_equal( + mod["main"], transformed.with_attr("global_symbol", "main"), True + ) def _check_fail(original): diff --git a/tests/python/unittest/test_tir_transform_lower_opaque_block.py b/tests/python/unittest/test_tir_transform_lower_opaque_block.py index a7502edd31ab..444e36bfbb7a 100644 --- a/tests/python/unittest/test_tir_transform_lower_opaque_block.py +++ b/tests/python/unittest/test_tir_transform_lower_opaque_block.py @@ -22,10 +22,12 @@ def _check(original, transformed): func = original - mod = tvm.IRModule.from_expr(func) + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.tir.transform.LowerOpaqueBlock()(mod) mod = tvm.tir.transform.Simplify()(mod) - tvm.ir.assert_structural_equal(mod["main"], transformed, True) + tvm.ir.assert_structural_equal( + mod["main"], transformed.with_attr("global_symbol", "main"), True + ) @T.prim_func @@ -309,7 +311,7 @@ def test_lower_te(): def test_annotated_loops(): - mod = tvm.IRModule.from_expr(annotated_loops) + mod = tvm.IRModule.from_expr(annotated_loops.with_attr("global_symbol", "main")) mod = tvm.tir.transform.LowerOpaqueBlock()(mod) attr1 = mod["main"].body attr2 = attr1.body @@ -328,7 +330,7 @@ def annotated_block() -> None: T.block_attr({"pragma_1": "str_value", "pragma_2": 1, "pragma_3": 0.0}) T.evaluate(0) - mod = tvm.IRModule.from_expr(annotated_block) + mod = tvm.IRModule.from_expr(annotated_block.with_attr("global_symbol", "main")) mod = tvm.tir.transform.LowerOpaqueBlock()(mod) attr1 = mod["main"].body attr2 = attr1.body @@ -353,9 +355,9 @@ def after(A: T.Buffer(8, "float32"), B: T.Buffer(8, "float32")): for i in T.serial(8, annotations={"k_0": 1, "k_1": [2, 3], "k_2": 3.14}): B[i] = A[i] + 1.0 - mod = tvm.IRModule.from_expr(before) + mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) mod = tvm.tir.transform.LowerOpaqueBlock()(mod) - tvm.ir.assert_structural_equal(mod["main"], after) + tvm.ir.assert_structural_equal(mod["main"], after.with_attr("global_symbol", "main")) def test_boolean_handling(): diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py index 21db36d1f9db..de1020ef2078 100644 --- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py @@ -23,6 +23,8 @@ import numpy as np +from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol + @tvm.register_func("tvm.test_matmul") def my_matmul(a, b, c): @@ -234,7 +236,7 @@ def expected(): def test_compare(self, before, expected, transform): after = transform(before) - tvm.ir.assert_structural_equal(after, expected, map_free_vars=True) + assert_structural_equal_ignore_global_symbol(after, expected, map_free_vars=True) class TestLowerCPUAllocation(tvm.testing.CompareBeforeAfter): diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 6f84b6f6d48c..2f871a246f53 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -44,7 +44,7 @@ def test_makeapi(): )(mod) before = mod - after = tvm.tir.transform.MakePackedAPI()(mod) + after = tvm.tir.transform.MakePackedAPI()(before) f = after["main"] assert len(f.params) == 6 @@ -175,14 +175,15 @@ def test_device_api_context_implicit_resource_handle(): @pytest.mark.parametrize("use_global_symbol", [True, False]) def test_no_op_when_global_symbol_is_absent(use_global_symbol): func_attr = {"target": tvm.target.Target("llvm", host="llvm")} - if use_global_symbol: - func_attr["global_symbol"] = "main" - @T.prim_func + @T.prim_func(private=True) def before(): T.func_attr(func_attr) T.evaluate(0) + if use_global_symbol: + before = before.with_attr("global_symbol", "main") + after = tvm.tir.transform.MakePackedAPI()(tvm.IRModule.from_expr(before))["main"] if use_global_symbol: assert len(after.params) == 6 @@ -225,10 +226,11 @@ def test_internal_subroutine_call(): class before: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) + T.func_attr({"target": T.target("llvm", host="llvm")}) before.subroutine(A.data) - @T.prim_func + # this test fails if it's made public + @T.prim_func(private=True) def subroutine(A_data: T.handle("float32")): T.func_attr({"target": T.target("llvm")}) T.evaluate(A_data) diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py b/tests/python/unittest/test_tir_transform_make_unpacked_api.py index 868d30db3618..46fd4104544a 100644 --- a/tests/python/unittest/test_tir_transform_make_unpacked_api.py +++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py @@ -156,7 +156,7 @@ def main(A: T.Buffer(1, "float32")): T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")}) mod.subroutine(A.data) - @T.prim_func + @T.prim_func(private=True) def subroutine(A_data: T.handle("float32")): T.func_attr({"target": T.target("cuda")}) T.evaluate(A_data) @@ -174,7 +174,7 @@ def main(A_data: T.handle("float32")) -> T.int32: mod.subroutine(A_data) T.ret(T.int32(0)) - @T.prim_func + @T.prim_func(private=True) def subroutine(A_data: T.handle("float32")): T.func_attr({"target": T.target("cuda")}) T.evaluate(A_data) @@ -200,7 +200,7 @@ def main(A: T.Buffer(1, "float32")): T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) mod.subroutine(A.data) - @T.prim_func + @T.prim_func(private=True) def subroutine(A_data: T.handle("float32")): T.func_attr({"target": T.target("llvm")}) T.evaluate(A_data) @@ -218,7 +218,7 @@ def main(A_data: T.handle("float32")) -> T.int32: mod.subroutine(A_data) T.ret(T.int32(0)) - @T.prim_func + @T.prim_func(private=True) def subroutine(A_data: T.handle("float32")): T.func_attr({"target": T.target("llvm")}) T.evaluate(A_data) diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index b3b0c6f59b0f..c03dd7a5291d 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -321,8 +321,10 @@ def expected_after(A: T.Buffer(128, "float32"), B: T.Buffer(130, "float32")): i * 65 + j >= 0 and i * 65 + j < 128, A[i * 65 + j], T.float32(0), dtype="float32" ) - after = tvm.tir.transform.NarrowDataType(32)(tvm.IRModule.from_expr(before))["main"] - tvm.ir.assert_structural_equal(after, expected_after) + after = tvm.tir.transform.NarrowDataType(32)( + tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + )["main"] + tvm.ir.assert_structural_equal(after, expected_after.with_attr("global_symbol", "main")) def test_block(): @@ -342,8 +344,10 @@ def expected_after(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32" vi = T.axis.spatial(T.int32(128), i * T.int32(8) + j) B[vi] = A[vi] + T.float32(1) - after = tvm.tir.transform.NarrowDataType(32)(tvm.IRModule.from_expr(before))["main"] - tvm.ir.assert_structural_equal(after, expected_after) + after = tvm.tir.transform.NarrowDataType(32)( + tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + )["main"] + tvm.ir.assert_structural_equal(after, expected_after.with_attr("global_symbol", "main")) def test_avg_pool2d(): @@ -402,9 +406,11 @@ def expected_after(PSUM: T.Buffer((313600,), "int32"), PAVG: T.Buffer((313600,), ), ) - after = tvm.tir.transform.NarrowDataType(32)(tvm.IRModule.from_expr(before)) + after = tvm.tir.transform.NarrowDataType(32)( + tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + ) after = tvm.tir.transform.Simplify()(after) - tvm.ir.assert_structural_equal(after["main"], expected_after) + tvm.ir.assert_structural_equal(after["main"], expected_after.with_attr("global_symbol", "main")) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index 758a395da6d7..fe724ad0c981 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -27,9 +27,9 @@ def _check(original, transformed): func = original - mod = tvm.IRModule.from_expr(func) + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) - tvm.ir.assert_structural_equal(mod["main"], transformed) + tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main")) @T.prim_func diff --git a/tests/python/unittest/test_tir_transform_profiling_instr.py b/tests/python/unittest/test_tir_transform_profiling_instr.py index d14e2a4c8925..4084ad0feb27 100644 --- a/tests/python/unittest/test_tir_transform_profiling_instr.py +++ b/tests/python/unittest/test_tir_transform_profiling_instr.py @@ -277,9 +277,11 @@ def test6_expected_output(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> # By default, only loops with siblings are instrumented. def test1(): with tvm.transform.PassContext(config=default_lwp_test_config): - mod = tvm.IRModule.from_expr(input1) + mod = tvm.IRModule.from_expr(input1.with_attr("global_symbol", "main")) mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod) - tvm.ir.assert_structural_equal(mod["main"], test1_expected_output) + tvm.ir.assert_structural_equal( + mod["main"], test1_expected_output.with_attr("global_symbol", "main") + ) # By default, only loops with siblings are instrumented. Here, 'lwp_max_depth' @@ -288,9 +290,11 @@ def test2(): test2_config = default_lwp_test_config.copy() test2_config.update({"tir.lwp_max_depth": 3}) with tvm.transform.PassContext(config=test2_config): - mod = tvm.IRModule.from_expr(input1) + mod = tvm.IRModule.from_expr(input1.with_attr("global_symbol", "main")) mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod) - tvm.ir.assert_structural_equal(mod["main"], test1_expected_output) + tvm.ir.assert_structural_equal( + mod["main"], test1_expected_output.with_attr("global_symbol", "main") + ) # test3: Use 'lwp_max_depth' to instrument loops upto a certain depth. This flag @@ -301,18 +305,22 @@ def test3(): test3_config = default_lwp_test_config.copy() test3_config.update({"tir.lwp_max_depth": 3, "tir.instr_siblings": False}) with tvm.transform.PassContext(config=test3_config): - mod = tvm.IRModule.from_expr(input1) + mod = tvm.IRModule.from_expr(input1.with_attr("global_symbol", "main")) mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod) - tvm.ir.assert_structural_equal(mod["main"], test3_expected_output) + tvm.ir.assert_structural_equal( + mod["main"], test3_expected_output.with_attr("global_symbol", "main") + ) # test4: Use 'lwp_min_height' to exclude inner loops upto a certain height from # instrumentation. def test4(): with tvm.transform.PassContext(config=default_lwp_test_config): - mod = tvm.IRModule.from_expr(input2) + mod = tvm.IRModule.from_expr(input2.with_attr("global_symbol", "main")) mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod) - tvm.ir.assert_structural_equal(mod["main"], test4_expected_output) + tvm.ir.assert_structural_equal( + mod["main"], test4_expected_output.with_attr("global_symbol", "main") + ) # test5: Use both 'lwp_min_height' and 'lwp_max_depth'. @@ -323,17 +331,21 @@ def test5(): {"tir.lwp_max_depth": 3, "tir.instr_siblings": False, "tir.lwp_min_height": 2} ) with tvm.transform.PassContext(config=test5_config): - mod = tvm.IRModule.from_expr(input1) + mod = tvm.IRModule.from_expr(input1.with_attr("global_symbol", "main")) mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod) - tvm.ir.assert_structural_equal(mod["main"], test5_expected_output) + tvm.ir.assert_structural_equal( + mod["main"], test5_expected_output.with_attr("global_symbol", "main") + ) # test6: Tests instrumentation for the parallel loops def test6(): with tvm.transform.PassContext(config=default_lwp_test_config): - mod = tvm.IRModule.from_expr(input3) + mod = tvm.IRModule.from_expr(input3.with_attr("global_symbol", "main")) mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod) - tvm.ir.assert_structural_equal(mod["main"], test6_expected_output) + tvm.ir.assert_structural_equal( + mod["main"], test6_expected_output.with_attr("global_symbol", "main") + ) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_remove_weight_layout_rewrite_block.py b/tests/python/unittest/test_tir_transform_remove_weight_layout_rewrite_block.py index 6d6e0da71cc5..3b099aa0838e 100644 --- a/tests/python/unittest/test_tir_transform_remove_weight_layout_rewrite_block.py +++ b/tests/python/unittest/test_tir_transform_remove_weight_layout_rewrite_block.py @@ -24,9 +24,9 @@ def _check(before, expect): if isinstance(before, PrimFunc): - before = IRModule({"main": before}) + before = IRModule({"main": before.with_attr("global_symbol", "main")}) if isinstance(expect, PrimFunc): - expect = IRModule({"main": expect}) + expect = IRModule({"main": expect.with_attr("global_symbol", "main")}) mod = tvm.tir.transform.RemoveWeightLayoutRewriteBlock()(before) tvm.ir.assert_structural_equal(mod, expect) diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index a4dbb6b6b9a3..b61fcc66014e 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -115,7 +115,7 @@ def main(n: T.int32): T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")}) mod.main_kernel(n) - @T.prim_func + @T.prim_func(private=True) def main_kernel(n: T.int32): T.func_attr( { @@ -152,7 +152,7 @@ def main(n: T.int32): err = mod.main_kernel(n) assert err == 0, "Error executing compute kernel" - @T.prim_func + @T.prim_func(private=True) def main_kernel(n: T.int32) -> T.int32: T.func_attr( { @@ -193,7 +193,7 @@ def main(n: T.int32): T.func_attr({"target": T.target("llvm")}) mod.main_kernel(n) - @T.prim_func + @T.prim_func(private=True) def main_kernel(n: T.int32): T.func_attr( { @@ -254,7 +254,7 @@ def main(n: T.int32): T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")}) mod.main_kernel_1(n) - @T.prim_func + @T.prim_func(private=True) def main_kernel_1(n: T.int32): T.func_attr( { diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 34de6fcabf3a..197e81818ee3 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -299,7 +299,7 @@ def verify(n): total_alloc[0] += n.extents[0].value total_alloc = [0] - mod = tvm.IRModule.from_expr(before) + mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) mod.show() tvm.tir.stmt_functor.post_order_visit(mod["main"].body, verify) assert total_alloc[0] == 24 @@ -722,8 +722,10 @@ def func_rewritten(A: T.Buffer((8,), "float32")) -> None: x: T.float32 = T.exp(B[0], dtype="float32") A[i] = (x + 1.0) / (x - 1.0) - mod = tvm.tir.transform.StorageRewrite()(tvm.IRModule.from_expr(func)) - tvm.ir.assert_structural_equal(mod["main"], func_rewritten) + mod = tvm.tir.transform.StorageRewrite()( + tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + ) + tvm.ir.assert_structural_equal(mod["main"], func_rewritten.with_attr("global_symbol", "main")) class BaseCompare(tvm.testing.CompareBeforeAfter): diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py b/tests/python/unittest/test_tir_transform_unify_thread_binding.py index 0b4f4bfb39f9..d42adfcee4bb 100644 --- a/tests/python/unittest/test_tir_transform_unify_thread_binding.py +++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py @@ -24,10 +24,12 @@ def _check(original, transformed): - mod = tvm.IRModule.from_expr(original) + mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main")) mod = tvm.tir.transform.UnifyThreadBinding()(mod) mod = tvm.tir.transform.Simplify()(mod) - tvm.ir.assert_structural_equal(mod["main"], transformed, True) + tvm.ir.assert_structural_equal( + mod["main"], transformed.with_attr("global_symbol", "main"), True + ) def _check_fail(original): diff --git a/tests/python/unittest/test_tir_unsafe_hide_buffer_access.py b/tests/python/unittest/test_tir_unsafe_hide_buffer_access.py index 18fb2a5d5841..80944dc21da6 100644 --- a/tests/python/unittest/test_tir_unsafe_hide_buffer_access.py +++ b/tests/python/unittest/test_tir_unsafe_hide_buffer_access.py @@ -20,7 +20,10 @@ import tvm.testing from tvm import tir from tvm.script import tir as T -from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule.testing import ( + assert_structural_equal_ignore_global_symbol, + verify_trace_roundtrip, +) @T.prim_func @@ -72,7 +75,7 @@ def test_hide_buffer_access_read(): sch = tir.Schedule(indirect_mem_access, debug_mask="all") block_b = sch.get_block("B") sch.unsafe_hide_buffer_access(block_b, "read", [1]) - tvm.ir.assert_structural_equal(indirect_mem_access_hide_ia, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(indirect_mem_access_hide_ia, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=indirect_mem_access) @@ -80,7 +83,7 @@ def test_hide_buffer_access_write(): sch = tir.Schedule(indirect_mem_access, debug_mask="all") block_b = sch.get_block("B") sch.unsafe_hide_buffer_access(block_b, "write", [1]) - tvm.ir.assert_structural_equal(indirect_mem_access_hide_ib, sch.mod["main"]) + assert_structural_equal_ignore_global_symbol(indirect_mem_access_hide_ib, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=indirect_mem_access) diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index 6d435a906e37..2723566d8c2c 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -253,10 +253,18 @@ def expected_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> N def test_complete_buffer_indices(): - new_func = tvm.script.from_source(func_with_bufferslice_indices.script()) - tvm.ir.assert_structural_equal(new_func, expected_bufferslice_indices) - new_func = tvm.script.from_source(func_with_recursive_bufferslice_indices.script()) - tvm.ir.assert_structural_equal(new_func, expected_recursive_bufferslice_indices) + new_func = tvm.script.from_source(func_with_bufferslice_indices.script()).with_attr( + "global_symbol", "main" + ) + tvm.ir.assert_structural_equal( + new_func, expected_bufferslice_indices.with_attr("global_symbol", "main") + ) + new_func = tvm.script.from_source(func_with_recursive_bufferslice_indices.script()).with_attr( + "global_symbol", "main" + ) + tvm.ir.assert_structural_equal( + new_func, expected_recursive_bufferslice_indices.with_attr("global_symbol", "main") + ) @T.prim_func @@ -292,7 +300,10 @@ def expected_match_buffer_func(a: T.handle) -> None: def test_complete_match_buffer(): - tvm.ir.assert_structural_equal(match_buffer_func, expected_match_buffer_func) + tvm.ir.assert_structural_equal( + match_buffer_func.with_attr("global_symbol", "main"), + expected_match_buffer_func.with_attr("global_symbol", "main"), + ) @T.prim_func @@ -319,8 +330,10 @@ def expect_alloc_buffer_func(a: T.handle, b: T.handle) -> None: def test_complete_alloc_buffer(): - rt_func = tvm.script.from_source(alloc_buffer_func.script()) - tvm.ir.assert_structural_equal(alloc_buffer_func, expect_alloc_buffer_func) + rt_func = tvm.script.from_source(alloc_buffer_func.script()).with_attr("global_symbol", "main") + tvm.ir.assert_structural_equal( + rt_func, expect_alloc_buffer_func.with_attr("global_symbol", "main") + ) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tvmscript_meta_programming.py b/tests/python/unittest/test_tvmscript_meta_programming.py index ed567b659444..83b71e1447c7 100644 --- a/tests/python/unittest/test_tvmscript_meta_programming.py +++ b/tests/python/unittest/test_tvmscript_meta_programming.py @@ -49,8 +49,8 @@ def matmul_128_128_128_fp16(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = T.float32(0) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] - f = matmul_generator(128, 128, 128, "float16") - tvm.ir.assert_structural_equal(f, matmul_128_128_128_fp16) + f = matmul_generator(128, 128, 128, "float16").with_attr("global_symbol", "main") + tvm.ir.assert_structural_equal(f, matmul_128_128_128_fp16.with_attr("global_symbol", "main")) def test_meta_programming_uncaptured_var(): @@ -75,8 +75,8 @@ def fp16(A: T.Buffer((1,), "float16"), C: T.Buffer((1,), "float16")): with T.block("C"): C[i] = T.erf(A[i]) - tvm.ir.assert_structural_equal(fp16, generate_erf("float16")) - tvm.ir.assert_structural_equal(fp32, generate_erf("float32")) + tvm.ir.assert_structural_equal(fp16.with_attr("global_symbol", "main"), generate_erf("float16")) + tvm.ir.assert_structural_equal(fp32.with_attr("global_symbol", "main"), generate_erf("float32")) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tvmscript_ops.py b/tests/python/unittest/test_tvmscript_ops.py index 671fe3cc199d..dc539b9327c8 100644 --- a/tests/python/unittest/test_tvmscript_ops.py +++ b/tests/python/unittest/test_tvmscript_ops.py @@ -156,8 +156,13 @@ def test_alloc_zero_dim_buffer_round_trip(): rt_func_with_block = tvm.script.from_source(func_with_block.script()) rt_mod = tvm.build(rt_func, "llvm") rt_mod_with_block = tvm.build(rt_func_with_block, "llvm") - tvm.ir.assert_structural_equal(func, func_with_block) - tvm.ir.assert_structural_equal(rt_func, rt_func_with_block) + tvm.ir.assert_structural_equal( + func.with_attr("global_symbol", "main"), func_with_block.with_attr("global_symbol", "main") + ) + tvm.ir.assert_structural_equal( + rt_func.with_attr("global_symbol", "main"), + rt_func_with_block.with_attr("global_symbol", "main"), + ) _check_alloc_zero_dim_buffer(rt_mod) _check_alloc_zero_dim_buffer(rt_mod_with_block) @@ -242,7 +247,10 @@ def slice_op_test_ref( def test_slice_op(): - tvm.ir.assert_structural_equal(slice_op_test, slice_op_test_ref) + tvm.ir.assert_structural_equal( + slice_op_test.with_attr("global_symbol", "main"), + slice_op_test_ref.with_attr("global_symbol", "main"), + ) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index 38d3e1474656..36df55610868 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -70,10 +70,44 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] assert matmul.__name__ == "matmul" + assert matmul.attrs["global_symbol"] == "matmul" + + +def test_tir_func_private_attrs(): + @T.prim_func(private=True) + def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"attr": "value"}) + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + assert "global_symbol" not in matmul.attrs + + +def test_tir_func_private_manual_global_symbol_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @T.prim_func(private=True) + def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "matmul"}) + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + # should not execute + assert matmul.__name__ == "matmul" def test_tir_macro_decorator_signature(): - @T.prim_func + @T.prim_func(private=True) def evaluate0(): T.evaluate(0) @@ -84,7 +118,7 @@ def func1(): assert func1.hygienic - @T.prim_func + @T.prim_func(private=True) def use1(): func1() @@ -97,7 +131,7 @@ def func2(): assert func2.hygienic - @T.prim_func + @T.prim_func(private=True) def use2(): func2() @@ -116,7 +150,7 @@ def assign(i, *args, t1, **kwargs): vi, vj, vk = T.axis.remap("SSR", [i, args[0], args[1]]) kwargs["t3"][vi, vj] = kwargs["t3"][vi, vj] + t1[vi, vk] * kwargs["t2"][vj, vk] - @T.prim_func + @T.prim_func(private=True) def matmul_w_macro(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) @@ -125,7 +159,7 @@ def matmul_w_macro(a: T.handle, b: T.handle, c: T.handle) -> None: with T.block("update"): assign(i, j, k, t1=A, t2=B, t3=C) - @T.prim_func + @T.prim_func(private=True) def matmul_no_macro(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) @@ -145,12 +179,12 @@ def test_tir_macro_hygienic(): def static_capture(A, B): B[()] = A[x_value] - @T.prim_func + @T.prim_func(private=True) def use_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: for x_value in T.serial(10): static_capture(A, B) - @T.prim_func + @T.prim_func(private=True) def expected_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: for x_value in range(10): B[()] = A[128] @@ -165,12 +199,12 @@ def test_tir_macro_non_hygienic(): def dynamic_capture(A, B): B[()] = A[x_value] - @T.prim_func + @T.prim_func(private=True) def use_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: for x_value in T.serial(10): dynamic_capture(A, B) - @T.prim_func + @T.prim_func(private=True) def expected_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: for x_value in range(10): B[()] = A[x_value] diff --git a/tests/python/unittest/test_tvmscript_printer_annotation.py b/tests/python/unittest/test_tvmscript_printer_annotation.py index 72d2238b2b63..98e6d7c0596c 100644 --- a/tests/python/unittest/test_tvmscript_printer_annotation.py +++ b/tests/python/unittest/test_tvmscript_printer_annotation.py @@ -35,7 +35,7 @@ def _func(): def test_annotation_multi_object_paths(): - result = _func.script( + result = _func.with_attr("global_symbol", "main").script( path_to_annotate={ ObjectPath.root().attr("body").attr("seq").array_index(1): "annotation 1", ObjectPath.root().attr("body").attr("seq").array_index(3): "annotation 3", @@ -61,7 +61,7 @@ def main(): def test_annotate_from_multi_obj(): - result = _func.script( + result = _func.with_attr("global_symbol", "main").script( obj_to_annotate={ _func.body.seq[1]: "annotation 1", _func.body.seq[3]: "annotation 3", diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index e6334553d64f..b16a3b05282c 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -37,7 +37,7 @@ def test_prim_func(): b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"), }, body=tir.Evaluate(0), - ) + ).with_attr("global_symbol", "main") _assert_print( func, expected=""" @@ -60,7 +60,7 @@ def test_prim_func_no_sugar_inlined_buffer(): b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"), }, body=tir.Evaluate(a), - ) + ).with_attr("global_symbol", "main") _assert_print( func, expected=""" @@ -86,7 +86,7 @@ def test_prim_func_no_sugar_shared_buffer_data(): b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B", data=buffer_data), }, body=tir.Evaluate(0), - ) + ).with_attr("global_symbol", "main") _assert_print( func, expected=""" @@ -759,8 +759,8 @@ def main(): T.reads() T.writes() T.evaluate(0)""" - _assert_print(block_with_remap_explicitly, expected_output) - _assert_print(block_with_remap_implicitly, expected_output) + _assert_print(block_with_remap_explicitly.with_attr("global_symbol", "main"), expected_output) + _assert_print(block_with_remap_implicitly.with_attr("global_symbol", "main"), expected_output) def test_root_block(): @@ -794,8 +794,51 @@ def main(): T.writes() T.evaluate(0) """ - _assert_print(root_block_implicitly, expected_output) - _assert_print(root_block_explicitly, expected_output) + _assert_print(root_block_implicitly.with_attr("global_symbol", "main"), expected_output) + _assert_print(root_block_explicitly.with_attr("global_symbol", "main"), expected_output) + + +def test_private_primfunc(): + from tvm.script import tir as T + + a = tir.Var("a", "handle") + b = tir.Var("b", "handle") + func = tir.PrimFunc( + params=[a, b], + ret_type=None, + buffer_map={ + a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"), + b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"), + }, + body=tir.Evaluate(0), + ) + _assert_print( + func, + expected=""" +# from tvm.script import tir as T + +@T.prim_func(private=True) +def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + T.evaluate(0)""", + ) + + +def test_prim_func_different_symbol(): + from tvm.script import tir as T + + @T.prim_func + def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + T.func_attr({"global_symbol": "func"}) + T.evaluate(0) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + T.evaluate(0) + """ + _assert_print(main, expected_output) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tvmscript_printer_underlining.py b/tests/python/unittest/test_tvmscript_printer_underlining.py index 569f03d0f828..a0fc139a2d29 100644 --- a/tests/python/unittest/test_tvmscript_printer_underlining.py +++ b/tests/python/unittest/test_tvmscript_printer_underlining.py @@ -410,7 +410,7 @@ def func(a: T.int32, b: T.int32): T.evaluate(a) T.evaluate(b) - result = func.script(obj_to_underline=[func.params[0]]) + result = func.with_attr("global_symbol", "main").script(obj_to_underline=[func.params[0]]) assert result == format_script( """ # from tvm.script import tir as T @@ -442,7 +442,7 @@ def func(): T.evaluate(6) T.evaluate(7) - result = func.script( + result = func.with_attr("global_symbol", "main").script( obj_to_underline=[ func.body.seq[1], func.body.seq[3], @@ -477,7 +477,7 @@ def test_underline_func(): def func(): T.evaluate(0) - result = func.script( + result = func.with_attr("global_symbol", "main").script( path_to_underline=[ ObjectPath.root(), ] diff --git a/tests/python/unittest/test_tvmscript_regression.py b/tests/python/unittest/test_tvmscript_regression.py index c4ca23b3f037..d531acc2e993 100644 --- a/tests/python/unittest/test_tvmscript_regression.py +++ b/tests/python/unittest/test_tvmscript_regression.py @@ -54,7 +54,9 @@ def func_ref(): a = T.alloc_buffer([10, 10], dtype="int8") T.evaluate(0) - tvm.ir.assert_structural_equal(test_case, func_ref) + tvm.ir.assert_structural_equal( + test_case.with_attr("global_symbol", "main"), func_ref.with_attr("global_symbol", "main") + ) def test_var_capturing_order(): @@ -69,7 +71,9 @@ def func_ref(): k: T.int32 = 2 T.evaluate(0) - tvm.ir.assert_structural_equal(test_case, func_ref) + tvm.ir.assert_structural_equal( + test_case.with_attr("global_symbol", "main"), func_ref.with_attr("global_symbol", "main") + ) def test_tir_buffer_region_extent_correct_dtype(): diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 90d2599b58bd..105ea62fd572 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -32,7 +32,7 @@ class Module: @T.prim_func def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: # function attr dict - T.func_attr({"global_symbol": "mmult", "tir.noalias": True}) + T.func_attr({"tir.noalias": True}) # buffer definition C_global = T.Buffer([1024, 1024], elem_offset=0, align=64, offset_factor=1) packedB = T.Buffer([32, 1024, 32], elem_offset=0, align=64, offset_factor=1) @@ -89,7 +89,7 @@ class Module: @T.prim_func def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: # function attr dict - T.func_attr({"global_symbol": "mmult", "tir.noalias": True}) + T.func_attr({"tir.noalias": True}) A_1 = T.match_buffer(A, [16384], elem_offset=0, align=64, offset_factor=1) B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=64, offset_factor=1) C_1 = T.match_buffer(C, [16384], elem_offset=0, align=64, offset_factor=1) @@ -196,7 +196,6 @@ def mmult( T.func_attr( { "tir.noalias": True, - "global_symbol": "mmult", "tir.is_entry_func": True, "calling_conv": 1, } @@ -3567,7 +3566,9 @@ def func(A: T.Buffer(128, "float32"), C: T.Buffer(128, "float32")): for i in T.thread_binding(128, thread="threadIdx.x"): C[i] = B[i] + 2.0 - mod = tvm.tir.transform.LowerOpaqueBlock()(tvm.IRModule.from_expr(func)) + mod = tvm.tir.transform.LowerOpaqueBlock()( + tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + ) return mod["main"] @@ -3906,6 +3907,23 @@ def func() -> T.int32: return func +def return_zero_private(): + @T.prim_func(private=True) + def func() -> T.int32: + T.ret(0) + + return func + + +def return_zero_private_with_attr(): + @T.prim_func(private=True) + def func() -> T.int32: + T.func_attr({"greeting": "hello"}) + T.ret(0) + + return func + + def op_of_literal(): op_list = [ (T.exp, 0), @@ -4032,6 +4050,8 @@ def func(): undefined_elem_offset_in_decl_buffer, subroutine_call_without_arguments, return_zero, + return_zero_private, + return_zero_private_with_attr, *op_of_literal(), ) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 41262a6669a3..ecde549b4afa 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -16,12 +16,13 @@ # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement import sys +from typing import Any import pytest import tvm.testing -from tvm.ir import assert_structural_equal from tvm.script import from_source from tvm.script import tir as T +from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol @T.prim_func @@ -61,7 +62,9 @@ def transformed_matmul_syntax_sugar(a: T.handle, b: T.handle, c: T.handle) -> No def test_reads_writes_syntax_sugar(): - assert_structural_equal(transformed_matmul_no_syntax_sugar, transformed_matmul_syntax_sugar) + assert_structural_equal_ignore_global_symbol( + transformed_matmul_no_syntax_sugar, transformed_matmul_syntax_sugar + ) @T.prim_func @@ -89,7 +92,7 @@ def loop_syntax_sugar(a: T.handle) -> None: def test_loop_syntax_sugar(): - assert_structural_equal(loop_no_syntax_sugar, loop_syntax_sugar) + assert_structural_equal_ignore_global_symbol(loop_no_syntax_sugar, loop_syntax_sugar) # match buffer - use kwargs @@ -132,9 +135,9 @@ def elementwise_buffer_no_kwargs( def test_match_buffer_syntax_sugar(): # with kwargs - assert_structural_equal(elementwise_handle, elementwise_buffer_kwargs) + assert_structural_equal_ignore_global_symbol(elementwise_handle, elementwise_buffer_kwargs) # without kwargs - assert_structural_equal(elementwise_handle, elementwise_buffer_no_kwargs) + assert_structural_equal_ignore_global_symbol(elementwise_handle, elementwise_buffer_no_kwargs) def test_match_buffer_1d(): @@ -149,7 +152,7 @@ def func_with_sugar(A: T.Buffer(16, "float32")): for i in T.serial(16): A[i] = 0.0 - assert_structural_equal(func_no_sugar, func_with_sugar) + assert_structural_equal_ignore_global_symbol(func_no_sugar, func_with_sugar) # dynamic shape gemm @@ -171,7 +174,7 @@ def gemm_dyn_shape(a: T.handle, b: T.handle, c: T.handle): def test_dynamic_shape_gemm(): gemm_dyn_shape_roundtrip = from_source(gemm_dyn_shape.script()) - assert_structural_equal(gemm_dyn_shape, gemm_dyn_shape_roundtrip) + assert_structural_equal_ignore_global_symbol(gemm_dyn_shape, gemm_dyn_shape_roundtrip) @T.prim_func @@ -208,7 +211,7 @@ def match_buffer_int64_after_roundtrip( def test_match_buffer_int64(): original = match_buffer_int64 after_roundtrip = match_buffer_int64_after_roundtrip - assert_structural_equal(original, after_roundtrip, True) + assert_structural_equal_ignore_global_symbol(original, after_roundtrip, True) def test_match_buffer_region_has_implicit_shape_dtype(): @@ -224,7 +227,7 @@ def implicit_shape_dtype(A: T.Buffer((16, 64), "int32")): B = T.match_buffer(A[8:16, 32:64]) T.evaluate(0) - assert_structural_equal(explicit_shape_dtype, implicit_shape_dtype) + assert_structural_equal_ignore_global_symbol(explicit_shape_dtype, implicit_shape_dtype) def test_match_buffer_input_requires_shape_arg(): @@ -263,7 +266,7 @@ def constant_binds_wrapped(): y = T.meta_var(T.float32(42.0)) T.evaluate(T.cast(x, "float32") + y) - assert_structural_equal(constant_binds, constant_binds_wrapped) + assert_structural_equal_ignore_global_symbol(constant_binds, constant_binds_wrapped) def test_func_call(): @@ -322,7 +325,9 @@ def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle) -> Non * B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2] ) - assert_structural_equal(mma_sync_m16n16k16_desc, mma_sync_m16n16k16_desc_manual) + assert_structural_equal_ignore_global_symbol( + mma_sync_m16n16k16_desc, mma_sync_m16n16k16_desc_manual + ) # The following is an example of an error message from calling an invalid function @@ -370,7 +375,7 @@ def int64_grid_expanded( vj = T.axis.spatial(T.int64(128), j) B[vi, vj] = A[vi, vj] + 1.0 - assert_structural_equal(int64_grid, int64_grid_expanded) + assert_structural_equal_ignore_global_symbol(int64_grid, int64_grid_expanded) def test_implicit_evaluate_assume(): @@ -384,7 +389,7 @@ def implicit(A: T.Buffer(1, "int32")): T.assume(A[0] == 5) A[0] = 10 - assert_structural_equal(implicit, explicit) + assert_structural_equal_ignore_global_symbol(implicit, explicit) def test_implicit_evaluate_call_extern(): @@ -396,7 +401,7 @@ def explicit(A: T.Buffer(1, "int32")): def implicit(A: T.Buffer(1, "int32")): T.call_extern("extern_func", A.data, dtype="int32") - assert_structural_equal(implicit, explicit) + assert_structural_equal_ignore_global_symbol(implicit, explicit) def test_preserve_trivial_let_binding(): @@ -411,7 +416,7 @@ def implicit(i: T.int32): j = i T.evaluate(j) - assert_structural_equal(implicit, explicit) + assert_structural_equal_ignore_global_symbol(implicit, explicit) def test_preserve_trivial_let_binding_of_value(): @@ -426,7 +431,7 @@ def implicit(i: T.int32): j = 42 T.evaluate(j) - assert_structural_equal(implicit, explicit) + assert_structural_equal_ignore_global_symbol(implicit, explicit) def test_preserve_parameter_name(): @@ -463,7 +468,7 @@ def explicit(): def implicit(): T.evaluate(True) - assert_structural_equal(implicit, explicit) + assert_structural_equal_ignore_global_symbol(implicit, explicit) def test_foldable_boolean_in_assert(): @@ -484,7 +489,7 @@ def implicit(): assert 0 == 1, "Message" T.evaluate(0) - assert_structural_equal(implicit, explicit) + assert_structural_equal_ignore_global_symbol(implicit, explicit) if __name__ == "__main__":