From 65ac46a66cd3f350db4dae04d0e8c93431266ed4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 25 May 2023 09:36:51 -0500 Subject: [PATCH 1/3] [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI PRs https://github.com/apache/tvm/pull/14913 and https://github.com/apache/tvm/pull/14914 made analogous changes to `MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls. Both PRs introduced the same symbol, `tvm::tir::SubroutineCallRewriter`, a local utility to update internal calls to a modified function. While each PR passed CI individually, and was therefore able to merge, having both changes caused a duplicate symbol. This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place their local utilities into anonymous namespaces, avoiding the conflict. --- src/tir/transforms/make_packed_api.cc | 3 +++ src/tir/transforms/make_unpacked_api.cc | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index dd9d471c5066..825a8da45b27 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -42,6 +42,7 @@ namespace tir { static constexpr const char* kDeviceContextVar = "device_api_context"; +namespace { class ReturnRewriter : public StmtMutator { public: explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {} @@ -176,6 +177,8 @@ class SubroutineCallRewriter : public StmtExprMutator { bool made_change_{false}; }; +} // namespace + inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 82685411f592..bdb3a953e99c 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -40,6 +40,8 @@ namespace tvm { namespace tir { +namespace { + class SubroutineCallRewriter : public StmtExprMutator { public: static Optional Apply(const std::unordered_set& external_methods, @@ -84,6 +86,8 @@ class SubroutineCallRewriter : public StmtExprMutator { bool made_change_{false}; }; +} // namespace + PrimFunc MakeUnpackedAPI(PrimFunc func) { // A function with an explicit calling convention has already been // lowered, and should not be modified. From 6867e6787b93df218e4fe6718edc6da8baec2c42 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 29 Mar 2023 13:48:31 -0500 Subject: [PATCH 2/3] [Target] Added WithoutHost method --- include/tvm/target/target.h | 3 +++ src/target/target.cc | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 891700b86a4c..5c88807682d7 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -218,6 +218,9 @@ class Target : public ObjectRef { */ static Target WithHost(const Target& target, const Target& host); + /*! \return The target with the host stripped out */ + Target WithoutHost() const; + /*! * \brief Returns true if \p this target represents an external codegen. If so, * \p this->kind->name can be used as the "Compiler" attribute on partitioned functions, diff --git a/src/target/target.cc b/src/target/target.cc index f05d4db2b888..e479f592c640 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -662,6 +662,16 @@ Map TargetNode::Export() const { Optional TargetNode::GetHost() const { return this->host.as(); } +Target Target::WithoutHost() const { + if ((*this)->GetHost()) { + auto output = make_object(*get()); + output->host = NullOpt; + return Target(output); + } else { + return *this; + } +} + int TargetNode::GetTargetDeviceType() const { if (Optional device_type = GetAttr("target_device_type")) { return Downcast(device_type)->value; From a328c1b81eb341c9b437ac721c6ce99fc490aee2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 24 Mar 2023 14:49:05 -0500 Subject: [PATCH 3/3] [TIR] Preserve existing kTarget function attribute in BindTarget Previously, if a function already has a `tvm::attr::kTarget` attribute, it will be overwritten by the `tir.BindTarget` transform. This commit updates the behavior such that `tir.BindTarget` adds annotations to functions that are missing a target annotation, but preserves any existing target annotations. This is part of a series of commits to simplify the handling of multi-target builds. --- src/tir/transforms/primfunc_utils.cc | 30 ++++- .../unittest/test_tir_transform_helpers.py | 112 ++++++++++++++++++ 2 files changed, 137 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index 257e3eacda90..f844b51f5394 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -30,12 +30,32 @@ namespace tvm { namespace tir { namespace transform { transform::Pass BindTarget(Target target) { - auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { - if (f->GetAttr(tvm::tir::attr::kIsHostFunc) == 1) { - return WithAttr(std::move(WithoutAttr(std::move(f), tvm::tir::attr::kIsHostFunc)), - tvm::attr::kTarget, target->host.value_or(Target("llvm"))); + Target without_host = target.WithoutHost(); + Target target_host = Downcast(target->host.value_or(Target("llvm"))); + + auto fpass = [target, target_host, without_host](tir::PrimFunc func, IRModule m, + transform::PassContext ctx) { + bool is_externally_exposed = func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + + if (auto func_target = func->GetAttr(tvm::attr::kTarget)) { + auto func_target_host = func_target.value()->GetHost(); + auto target_host = target->GetHost(); + + if (target_host && !func_target_host && is_externally_exposed) { + auto new_target = Target::WithHost(func_target.value(), target_host.value()); + func = WithAttr(std::move(func), tvm::attr::kTarget, new_target); + } + } else if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) { + func = WithAttr(std::move(func), tvm::attr::kTarget, target_host); + } else if (is_externally_exposed) { + func = WithAttr(std::move(func), tvm::attr::kTarget, target); + } else { + func = WithAttr(std::move(func), tvm::attr::kTarget, without_host); } - return WithAttr(std::move(f), tvm::attr::kTarget, target); + + func = WithoutAttr(std::move(func), tvm::tir::attr::kIsHostFunc); + + return func; }; return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.BindTarget", {}); } diff --git a/tests/python/unittest/test_tir_transform_helpers.py b/tests/python/unittest/test_tir_transform_helpers.py index 657bda591ae2..00fd12521268 100644 --- a/tests/python/unittest/test_tir_transform_helpers.py +++ b/tests/python/unittest/test_tir_transform_helpers.py @@ -85,6 +85,118 @@ def test_bind_target(): assert after["func2"].attrs["target"] == target +class TestBindTarget(tvm.testing.CompareBeforeAfter): + """BindTarget adds the "target" attribute""" + + transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda")) + + def before(): + T.evaluate(0) + + def expected(): + T.func_attr({"target": T.target("cuda")}) + T.evaluate(0) + + +class TestBindTargetWithHostToExposedFunction(tvm.testing.CompareBeforeAfter): + """BindTarget adds the host target to externally-exposed functions""" + + transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm")) + + def before(): + T.func_attr({"global_symbol": "main"}) + T.evaluate(0) + + def expected(): + T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")}) + T.evaluate(0) + + +class TestBindTargetWithHostToInternalFunction(tvm.testing.CompareBeforeAfter): + """Internal functions have a target annotation, but without the host + + The host portion of the target annotation provides host + parameters, and is used to expose a function externally as part of + `MakePackedAPI` and `MakeUnpackedAPI`. For internal functions, no + external exposure is required, so the host attribute should not be + used. + """ + + transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm")) + + def before(): + T.evaluate(0) + + def expected(): + T.func_attr({"target": T.target("cuda")}) + T.evaluate(0) + + +class TestBindTargetIgnoresExisting(tvm.testing.CompareBeforeAfter): + """BindTarget should not replace existing annotations""" + + transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda")) + + def before(): + T.func_attr({"target": T.target("nvptx")}) + T.evaluate(0) + + expected = before + + +class TestBindTargetUpdatesHost(tvm.testing.CompareBeforeAfter): + """BindTarget should update host for existing annotations""" + + transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm -opt-level=0")) + + def before(): + T.func_attr({"global_symbol": "func", "target": T.target("nvptx")}) + T.evaluate(0) + + def expected(): + T.func_attr( + { + "global_symbol": "func", + "target": T.target("nvptx", host="llvm -opt-level=0"), + } + ) + T.evaluate(0) + + +class TestBindTargetMultipleFunctions(tvm.testing.CompareBeforeAfter): + """BindTarget may apply to multiple functions in a module""" + + transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda")) + + def before(self): + @tvm.script.ir_module + class mod: + @T.prim_func + def func1(): + T.evaluate(0) + + @T.prim_func + def func2(): + T.evaluate(0) + + return mod + + def expected(self): + @tvm.script.ir_module + class mod: + @T.prim_func + def func1(): + T.func_attr({"target": T.target("cuda")}) + T.evaluate(0) + + @T.prim_func + def func2(): + T.func_attr({"target": T.target("cuda")}) + T.evaluate(0) + + return mod + + def test_filter_primfunc(): mod = MockModule assert mod