From 65ac46a66cd3f350db4dae04d0e8c93431266ed4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 25 May 2023 09:36:51 -0500 Subject: [PATCH 1/4] [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI PRs https://github.com/apache/tvm/pull/14913 and https://github.com/apache/tvm/pull/14914 made analogous changes to `MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls. Both PRs introduced the same symbol, `tvm::tir::SubroutineCallRewriter`, a local utility to update internal calls to a modified function. While each PR passed CI individually, and was therefore able to merge, having both changes caused a duplicate symbol. This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place their local utilities into anonymous namespaces, avoiding the conflict. --- src/tir/transforms/make_packed_api.cc | 3 +++ src/tir/transforms/make_unpacked_api.cc | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index dd9d471c5066..825a8da45b27 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -42,6 +42,7 @@ namespace tir { static constexpr const char* kDeviceContextVar = "device_api_context"; +namespace { class ReturnRewriter : public StmtMutator { public: explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {} @@ -176,6 +177,8 @@ class SubroutineCallRewriter : public StmtExprMutator { bool made_change_{false}; }; +} // namespace + inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 82685411f592..bdb3a953e99c 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -40,6 +40,8 @@ namespace tvm { namespace tir { +namespace { + class 
SubroutineCallRewriter : public StmtExprMutator { public: static Optional Apply(const std::unordered_set& external_methods, @@ -84,6 +86,8 @@ class SubroutineCallRewriter : public StmtExprMutator { bool made_change_{false}; }; +} // namespace + PrimFunc MakeUnpackedAPI(PrimFunc func) { // A function with an explicit calling convention has already been // lowered, and should not be modified. From 8c27815a0f5b2a7c00ef41e5fbc7728091c22a19 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 24 May 2023 09:49:26 -0500 Subject: [PATCH 2/4] [Target] Added utility method TargetNode::HasKey() This utility method makes it easier to determine if a target contains a specific key. --- include/tvm/target/target.h | 10 ++++++++++ src/target/target.cc | 5 +++++ 2 files changed, 15 insertions(+) diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 891700b86a4c..1c46828c3049 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -71,6 +71,16 @@ class TargetNode : public Object { /*! \return The device type for this target */ TVM_DLL int GetTargetDeviceType() const; + /*! + * \brief Check if the target contains a key + * + * \param query_key The string name of the key to be checked + * + * \return True if the target's `TargetNode::keys` contains the + * specified key, False otherwise. + */ + TVM_DLL bool HasKey(const std::string& query_key) const; + /*! * \brief Returns a human readable representation of \p Target which includes all fields, * especially the host. Useful for diagnostic messages and debugging. 
diff --git a/src/target/target.cc b/src/target/target.cc index f05d4db2b888..3d51e0ad2766 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -669,6 +669,11 @@ int TargetNode::GetTargetDeviceType() const { return kind->default_device_type; } +bool TargetNode::HasKey(const std::string& query_key) const { + return std::any_of(keys.begin(), keys.end(), + [&query_key](const auto& key) { return key == query_key; }); +} + String TargetNode::ToDebugString() const { std::ostringstream os; os << "Target("; From 0918064ff1949b3ce2874712bbe23da50981eba2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 24 May 2023 09:43:31 -0500 Subject: [PATCH 3/4] [TIR] Added utility method tvm::tir::IsHostFunc(const PrimFunc&) For modules that contain both host and device functions, this utility function checks whether a given PrimFunc is a host function, based on the target annotation. --- src/tir/transforms/ir_utils.cc | 10 ++++++++++ src/tir/transforms/ir_utils.h | 12 ++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 9b47d84e6aa2..604dbed325ec 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -692,6 +692,16 @@ std::pair GetWmmaFragmentDimSize(const std::string& shape_str, return std::pair(0, 0); } +std::optional<bool> IsHostFunc(const PrimFunc& func) { + if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) { + return true; + } else if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) { + return target.value()->HasKey("cpu"); + } else { + return std::nullopt; + } +} + namespace transform { Pass ConvertSSA() { auto pass_func = [](IRModule mod, PassContext ctx) { diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 59dc95dcd6a0..b48502871372 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -34,6 +34,7 @@ #include #include +#include <optional> #include #include #include @@ -351,6 +352,17 @@ 
CollectStorageAlignAnnotation(const Stmt& body); std::pair GetWmmaFragmentDimSize(const std::string& shape_str, const std::string& scope); +/*! \brief Check if a PrimFunc is a host function + * + * \param func The function to be inspected + * + * \return True if the function is known to run on the host, false if + * the function is known to run on the device. If it cannot be + * determined (e.g. a function without a tvm::attr::kTarget + * attribute), returns std::nullopt. + */ +std::optional<bool> IsHostFunc(const PrimFunc& func); + } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ From 7b47f29f49be3ad61db876238eda26787be51b3b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 24 Mar 2023 14:42:33 -0500 Subject: [PATCH 4/4] [TIR] Restrict InstallDebugSpans to host functions Previously, the `tir.InstallDebugSpans` pass required the module to contain only a single PrimFunc. This commit relaxes the requirement, to require a single host-side PrimFunc, and to ignore any other device-side functions. 
--- src/tir/transforms/install_debug_spans.cc | 36 ++++++++++------ tests/python/tir/test_debug_info.py | 50 ++++++++++++++++++++++- 2 files changed, 73 insertions(+), 13 deletions(-) diff --git a/src/tir/transforms/install_debug_spans.cc b/src/tir/transforms/install_debug_spans.cc index c97070e1bf89..ea61378ccccc 100644 --- a/src/tir/transforms/install_debug_spans.cc +++ b/src/tir/transforms/install_debug_spans.cc @@ -31,6 +31,7 @@ #include #include "../../relay/printer/tir_text_printer_debug.h" +#include "ir_utils.h" namespace tvm { namespace tir { @@ -128,19 +129,30 @@ TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS namespace transform { Pass InstallDebugSpans() { - auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - ICHECK(m->functions.size() == 1) - << "Debug info can only be added to IRModules with a single function"; - // There is known to be only 1 function in the module at this point - auto entry = m->functions.begin(); - auto name = std::get<0>(*entry)->name_hint; - auto* n = f.CopyOnWrite(); - - n->body = DebugInfoInstaller::InstallInfo(std::move(name), std::move(f->body)); - - return f; + auto pass_func = [](IRModule mod, PassContext ctx) { + Map<GlobalVar, PrimFunc> external_host_functions; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto opt = base_func.as<PrimFunc>()) { + auto prim_func = opt.value(); + if (IsHostFunc(prim_func).value_or(false) && + prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol)) { + external_host_functions.Set(gvar, prim_func); + } + } + } + + ICHECK_EQ(external_host_functions.size(), 1) + << "Debug info can only be added to IRModules with a single host function"; + + for (auto [gvar, prim_func] : external_host_functions) { + auto name = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).value(); + prim_func.CopyOnWrite()->body = DebugInfoInstaller::InstallInfo(name, prim_func->body); + mod.CopyOnWrite()->Update(gvar, prim_func); + } + + return mod; }; - return CreatePrimFuncPass(pass_func, 0, "tir.InstallDebugSpans", {}); + return 
tvm::transform::CreateModulePass(pass_func, 0, "tir.InstallDebugSpans", {}); } TVM_REGISTER_GLOBAL("tir.transform.InstallDebugSpans").set_body_typed(InstallDebugSpans); diff --git a/tests/python/tir/test_debug_info.py b/tests/python/tir/test_debug_info.py index 8ecabbd51a97..d333b43b28f5 100644 --- a/tests/python/tir/test_debug_info.py +++ b/tests/python/tir/test_debug_info.py @@ -46,7 +46,13 @@ class MyModule: @T.prim_func def main(a: T.handle, b: T.handle): # We exchange data between function by handles, which are similar to pointer. - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr( + { + "global_symbol": "main", + "tir.noalias": True, + "target": T.target("llvm"), + } + ) # Create buffer from handles. A = T.match_buffer(a, (8,), dtype="float32") B = T.match_buffer(b, (8,), dtype="float32") @@ -83,6 +89,48 @@ def find_span(m): assert span_after.line == 4 +def test_tir_debug_info_with_subroutine(): + """Like test_tir_debug_info, but with a TIR subroutine + + The current InstallDebugSpans applies to a single PrimFunc. 
This + test verifies that the existence of device-side subroutines + does not prevent installing debug spans on the host function. + """ + + def find_span(m): + func = next(m.functions.values()) + return func.body.block.body.span + + @tvm.script.ir_module + class module_before: + @T.prim_func + def main(a: T.handle, b: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": T.target("llvm")}) + A = T.match_buffer(a, (8,), dtype="float32") + B = T.match_buffer(b, (8,), dtype="float32") + for i in range(8): + with T.block("B"): + vi = T.axis.spatial(8, i) + module_before.subroutine(T.address_of(A[vi]), T.address_of(B[vi])) + + @T.prim_func + def subroutine(a_ptr: T.handle("float32"), b_ptr: T.handle("float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.decl_buffer(1, "float32", data=a_ptr) + B = T.decl_buffer(1, "float32", data=b_ptr) + B[0] = A[1] + 1.0 + + span_before = find_span(module_before) + assert span_before is None + + module_after = tir.transform.InstallDebugSpans()(module_before) + span_after = find_span(module_after) + + # Check that the module name has been added and a line number is present + assert span_after.source_name.name == "main.tir" + assert span_after.line == 4 + + def test_llvm_ir_debug_info(): """ Check that the right amount of debug locations are present