diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index cfc7fa80c7a9..b75f173e0d00 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -578,6 +578,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::InjectPTXLDG32()); } + mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions()); + mixed_pass_list.push_back(tir::transform::SplitHostDevice()); + bool unpacked_api = mixed_mod->GetAttr(tvm::attr::kExecutor) .value_or(relay::Executor::Create("graph", {})) ->GetAttr("unpacked-api") @@ -590,8 +593,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); - mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions()); - mixed_pass_list.push_back(tir::transform::SplitHostDevice()); mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch()); return transform::Sequential(mixed_pass_list); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 062f1c050961..804d01325b6b 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -183,33 +183,17 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } -/* \brief Return the global_symbol of the function, if it should be updated - * - * \param func The function to be inspected - * - * \returns The global_symbol to be used for the function at call - * sites, or NullOpt if the function is to remain unchanged. - */ -Optional RequiresPackedAPI(const PrimFunc& func) { +PrimFunc MakePackedAPI(PrimFunc func) { // A function with an explicit calling convention has already been // lowered, and should not be modified. if (auto opt = func->GetAttr(tvm::attr::kCallingConv)) { if (CallingConv(opt.value()->value) != CallingConv::kDefault) { - return NullOpt; + return func; } } // Internal function calls do not need the PackedFunc API auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - if (!global_symbol.defined()) { - return NullOpt; - } - - return global_symbol; -} - -PrimFunc MakePackedAPI(PrimFunc func) { - auto global_symbol = RequiresPackedAPI(func); if (!global_symbol.defined()) { return func; } @@ -218,11 +202,20 @@ PrimFunc MakePackedAPI(PrimFunc func) { Target target = [&]() { auto opt = func->GetAttr(tvm::attr::kTarget); ICHECK(opt) << "MakePackedAPI required the function to be annotated with tvm::attr::kTarget (" - << tvm::attr::kTarget << "), but the function only has attributes " << func->attrs; + << tvm::attr::kTarget << "), but the function " << name_hint + << " only has attributes" << func->attrs; return opt.value(); }(); int target_device_type = target->GetTargetDeviceType(); + // A function without a host target has already been lowered. + Target target_host; + if (auto opt = target->GetHost()) { + target_host = opt.value(); + } else { + return func; + } + auto* func_ptr = func.CopyOnWrite(); const Stmt nop = Evaluate(0); int num_args = static_cast(func_ptr->params.size()); @@ -325,7 +318,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { name_hint + "." + kv.first->name_hint); } - func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)); + func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)}, + {tvm::attr::kTarget, target_host}}); Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode); body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope, @@ -368,38 +362,43 @@ namespace transform { Pass MakePackedAPI() { auto pass_func = [](IRModule mod, PassContext ctx) { Map packed_func_methods; - for (const auto& [gvar, base_func] : mod->functions) { - if (auto opt = base_func.as()) { - auto prim_func = opt.value(); - if (auto global_symbol = RequiresPackedAPI(prim_func)) { - packed_func_methods.Set(gvar, global_symbol.value()); - } - } - } - IRModuleNode* mptr = mod.CopyOnWrite(); IRModule updates; - - for (const auto& [gvar, base_func] : mptr->functions) { + for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { - auto func = opt.value(); - auto orig_func = func; - - if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) { - func.CopyOnWrite()->body = body.value(); - } - - func = MakePackedAPI(std::move(func)); + auto orig_func = opt.value(); + auto func = MakePackedAPI(orig_func); if (!func.same_as(orig_func)) { updates->Add(gvar, func); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol).value(); + packed_func_methods.Set(gvar, global_symbol); } } } - if (updates->functions.size()) { mod.CopyOnWrite()->Update(updates); } + + if (packed_func_methods.size()) { + IRModule updates; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto opt = base_func.as()) { + auto func = opt.value(); + auto orig_func = func; + + if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) { + func.CopyOnWrite()->body = body.value(); + updates->Add(gvar, func); + } + } + } + + if (updates->functions.size()) { + mod.CopyOnWrite()->Update(updates); + } + } + return mod; }; diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index bdb3a953e99c..4b1b3bf517d0 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -111,6 +111,14 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { }(); int target_device_type = target->GetTargetDeviceType(); + // A function without a host target has already been lowered. + Target target_host; + if (auto opt = target->GetHost()) { + target_host = opt.value(); + } else { + return func; + } + auto* func_ptr = func.CopyOnWrite(); // Setup device context @@ -145,7 +153,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { func_ptr->buffer_map = Map(); // return the function. - return func; + return WithAttrs(std::move(func), {{tvm::attr::kTarget, target_host}}); } namespace transform { diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 2de831e8ad0c..29ecaa4e8e43 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -99,10 +99,6 @@ class HostDeviceSplitter : public StmtMutator { }; PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& gvar) { - auto opt_target = func->GetAttr(tvm::attr::kTarget); - ICHECK(opt_target) << "SplitHostDevice: Require the target attribute"; - Target target = opt_target.value(); - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); auto name_prefix = global_symbol.value_or(gvar->name_hint); @@ -112,10 +108,6 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& g func.CopyOnWrite()->body = body; } - if (auto target_host = target->GetHost()) { - func = WithAttr(std::move(func), tvm::attr::kTarget, target_host.value()); - } - return func; } diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py index d224a688d298..949bb0477bab 100644 --- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py @@ -56,7 +56,7 @@ def check_packed_func(target="llvm"): # Construct a valid IRModule to be lowered: mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([a_buffer, b_buffer, c_buffer], stmt)) - target = tvm.target.Target(target) + target = tvm.target.Target(target, host="llvm") mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) mod = tvm.tir.transform.MakePackedAPI()(mod) diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 8af7efb59604..34adcbb9aee4 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -37,7 +37,7 @@ def test_makeapi(): mod = tvm.tir.transform.Apply( lambda f: f.with_attr( { - "target": tvm.target.Target("llvm"), + "target": tvm.target.Target("llvm", host="llvm"), "global_symbol": "main", } ) @@ -90,7 +90,9 @@ def test_variable_passed_from_args(): stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, not_device_context], stmt)) - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) + mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm")) + )(mod) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) func = tvm.tir.transform.MakePackedAPI()(mod)["main"] @@ -132,7 +134,9 @@ def test_device_api_context_implicit_resource_handle(): stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, device_context], stmt)) - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) + mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm")) + )(mod) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) func = tvm.tir.transform.MakePackedAPI()(mod)["main"] @@ -161,7 +165,7 @@ def test_device_api_context_implicit_resource_handle(): @pytest.mark.parametrize("use_global_symbol", [True, False]) def test_no_op_when_global_symbol_is_absent(use_global_symbol): - func_attr = {"target": tvm.target.Target("llvm")} + func_attr = {"target": tvm.target.Target("llvm", host="llvm")} if use_global_symbol: func_attr["global_symbol"] = "main" @@ -177,6 +181,28 @@ def before(): tvm.ir.assert_structural_equal(before, after) +def test_target_host_removed(): + """After MakePackedAPI, host-side target should be the host + + MakePackedAPI is the last transform that requires both the device + and the host. After MakePackedAPI, the target attribute should + only contain the host-side target. + """ + + host = tvm.target.Target("llvm") + + @I.ir_module + class before: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main", "target": T.target("cuda", host=host)}) + T.evaluate(0) + + after = tvm.tir.transform.MakePackedAPI()(before) + target_attr = after["main"].attrs["target"] + assert str(host) == str(target_attr) + + def test_internal_subroutine_call(): """Internal subroutines should not use the PackedFunc API @@ -190,7 +216,7 @@ def test_internal_subroutine_call(): class before: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) before.subroutine(A.data) @T.prim_func @@ -222,12 +248,12 @@ def test_subroutine_call_to_externally_visible_subroutine(): class before: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) before.subroutine(A.data) @T.prim_func def subroutine(A_data: T.handle("float32")): - T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")}) T.evaluate(A_data) after = tvm.tir.transform.MakePackedAPI()(before) diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py b/tests/python/unittest/test_tir_transform_make_unpacked_api.py index bb9fe8ab8267..1931f7aef324 100644 --- a/tests/python/unittest/test_tir_transform_make_unpacked_api.py +++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py @@ -41,9 +41,8 @@ def mod(mod_without_attrs): def test_noop_if_not_global_symbol(mod_without_attrs): - before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))( - mod_without_attrs - ) + target = tvm.target.Target("llvm", host="llvm") + before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_without_attrs) after = tvm.tir.transform.MakeUnpackedAPI()(before) tvm.ir.assert_structural_equal(before, after) @@ -59,7 +58,8 @@ def test_fails_if_no_target(mod_without_attrs): @tvm.testing.parametrize_targets("c", "llvm", "cuda") def test_device_setup(mod, target, dev): - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target(target)))(mod) + target = tvm.target.Target(target, host="llvm") + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 1 assert f.params[0].name == "A" @@ -138,6 +138,49 @@ def test_body(): assert f.params[2].name == "A" +class TestTargetHostRemoved(tvm.testing.CompareBeforeAfter): + """After MakeUnpackedAPI, host-side target should be the host + + MakeUnpackedAPI is the last transform that requires both the device + and the host. After MakeUnpackedAPI, the target attribute should + only contain the host-side target. + """ + + transform = tvm.tir.transform.MakeUnpackedAPI() + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")}) + mod.subroutine(A.data) + + @T.prim_func + def subroutine(A_data: T.handle("float32")): + T.func_attr({"target": T.target("cuda")}) + T.evaluate(A_data) + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(A_data: T.handle("float32")) -> T.int32: + T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 2) + mod.subroutine(A_data) + + @T.prim_func + def subroutine(A_data: T.handle("float32")): + T.func_attr({"target": T.target("cuda")}) + T.evaluate(A_data) + + return mod + + class TestInternalSubroutineCall(tvm.testing.CompareBeforeAfter): """Internal subroutines do not require modification @@ -153,7 +196,7 @@ def before(self): class mod: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) mod.subroutine(A.data) @T.prim_func @@ -195,12 +238,14 @@ def before(self): class mod: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) mod.subroutine(A.data) @T.prim_func def subroutine(A_data: T.handle("float32")): - T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) + T.func_attr( + {"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")} + ) T.evaluate(A_data) return mod @@ -240,7 +285,7 @@ def before(self): class mod: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) mod.subroutine( T.tvm_stack_make_array( A.data, @@ -255,7 +300,9 @@ def main(A: T.Buffer(1, "float32")): @T.prim_func def subroutine(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) + T.func_attr( + {"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")} + ) T.evaluate(A.data) return mod diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index 1599b9a031a0..60bfb8a718d2 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -46,6 +46,7 @@ def test_split_host_device_func_attr(): [ tvm.tir.transform.AnnotateDeviceRegions(), tvm.tir.transform.SplitHostDevice(), + tvm.tir.transform.MakePackedAPI(), tvm.tir.transform.LowerDeviceKernelLaunch(), ] )(mod) @@ -111,7 +112,7 @@ def expected(self): class mod: @T.prim_func def main(n: T.int32): - T.func_attr({"target": T.target("llvm -opt-level=0")}) + T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")}) mod.main_kernel(n) @T.prim_func @@ -168,20 +169,19 @@ def main_kernel(n: T.int32): return mod -class TestSplitHostDevice(BaseCompare): +class TestSplitHostDeviceWithoutDeviceRegion(BaseCompare): """Like TestSplitHostDevice, but no device regions to extract - Even if there are no device regions, the host-side function should - still have its "target" attribute updated. + Because MakePackedAPI/MakeUnpackedAPI still require both the + device and host, SplitHostDevice does not modify the "target" + attribute. """ def before(): T.func_attr({"target": T.target("ext_dev", host="llvm")}) T.evaluate(0) - def expected(): - T.func_attr({"target": T.target("llvm")}) - T.evaluate(0) + expected = before if __name__ == "__main__":