From bb03862c56db6d03cc23f47bda100338aa84c550 Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Wed, 13 Dec 2023 21:51:31 -0500 Subject: [PATCH 1/2] convert ssa process entry func first --- src/tir/transforms/ir_utils.cc | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 25c10dd6828d..a4476d46b418 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -717,6 +717,9 @@ Pass ConvertSSA() { bool made_change = false; for (auto [gvar, base_func] : mod->functions) { if (auto* ptr = base_func.as()) { + if (!ptr->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + continue; + } auto updated = converter.VisitPrimFunc(GetRef(ptr)); if (!updated.same_as(base_func)) { made_change = true; @@ -725,6 +728,19 @@ Pass ConvertSSA() { } functions.Set(gvar, base_func); } + for (auto [gvar, base_func] : mod->functions) { + if (auto* ptr = base_func.as()) { + if (ptr->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + continue; + } + auto updated = converter.VisitPrimFunc(GetRef(ptr)); + if (!updated.same_as(base_func)) { + made_change = true; + base_func = updated; + } + functions.Set(gvar, base_func); + } + } if (made_change) { mod.CopyOnWrite()->functions = std::move(functions); } From febe50bdf489cf7806c394fafd18fe56713ea8fe Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Wed, 13 Dec 2023 22:50:18 -0500 Subject: [PATCH 2/2] add doc --- src/tir/transforms/ir_utils.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index a4476d46b418..6b681c07e5d5 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -715,6 +715,8 @@ Pass ConvertSSA() { tir::IRConvertSSA converter; Map functions; bool made_change = false; + // FIXME: This is just a temporal workaround to ensure free vars + // in device function have the same pointer as the host function for (auto [gvar, base_func] : mod->functions) { if (auto* ptr = base_func.as()) { if (!ptr->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {