From 4fdf1d157e30bfd36ae60d7be03cabba218cf3de Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 30 May 2023 09:14:35 -0500 Subject: [PATCH 1/4] [Bugfix][TIR][VTA] Update host-side target, even without device func This resolves an issue introduced by the combination of https://github.com/apache/tvm/pull/14918 and https://github.com/apache/tvm/pull/14945. The bug occurred for targets that do not require device-side codegen, but do require a `device_type` other than `kDLCPU`. It wasn't caught by CI, as the issue only occurred with the combination of both PRs. 1. #14918 updated `SplitHostDevice` to only modify the `"target"` attribute when a device-side function has been extracted. 2. For VTA, there is no device-side function, as everything is done through host-side API calls. 3. From (1) and (2), the VTA examples kept the target `T.target("ext_dev", host="llvm")` after the `SplitHostDevice` pass, instead of being updated to `T.target("llvm")`. 4. #14945 restricted CombineContextCall to only apply to host-side passes. 5. From (4) and (5), the `CombineContextCall` pass was no longer applied to the VTA context calls. This PR fixes `SplitHostDevice`, updating the target from `T.target("ext_dev", host="llvm")` to `T.target("llvm")`, even if no device sections have been extracted from the function. --- src/tir/transforms/split_host_device.cc | 10 +++++----- .../test_tir_transform_split_host_device.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 9270b356ba22..2de831e8ad0c 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -108,12 +108,12 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& g HostDeviceSplitter splitter(device_mod, name_prefix); - auto body = splitter(func->body); - - if (!body.same_as(func->body)) { + if (auto body = splitter(func->body); !body.same_as(func->body)) { func.CopyOnWrite()->body = body; - auto target_host = target->GetHost().value_or(Target("llvm")); - func = WithAttr(std::move(func), tvm::attr::kTarget, target_host); + } + + if (auto target_host = target->GetHost()) { + func = WithAttr(std::move(func), tvm::attr::kTarget, target_host.value()); } return func; diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index cf866ae005c8..1599b9a031a0 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -168,5 +168,21 @@ def main_kernel(n: T.int32): return mod +class TestSplitHostDevice(BaseCompare): + """Like TestSplitHostDevice, but no device regions to extract + + Even if there are no device regions, the host-side function should + still have its "target" attribute updated. + """ + + def before(): + T.func_attr({"target": T.target("ext_dev", host="llvm")}) + T.evaluate(0) + + def expected(): + T.func_attr({"target": T.target("llvm")}) + T.evaluate(0) + + if __name__ == "__main__": tvm.testing.main() From 0130f8e10ec5b8e2464f9602d8b9f9c4ff482980 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 22 Mar 2023 12:54:32 -0500 Subject: [PATCH 2/4] [TIR] Move SplitHostDevice to before MakePackedAPI This simplifies the logic used in MakePackedAPI, that it the last user of the host parameter in a function's target. After MakePackedAPI, every PrimFunc has a "target" attribute without a "host". --- src/driver/driver_api.cc | 5 +- src/tir/transforms/make_packed_api.cc | 11 +++- src/tir/transforms/make_unpacked_api.cc | 10 ++- src/tir/transforms/split_host_device.cc | 8 --- .../test_tir_transform_make_packed_api.py | 40 ++++++++++-- .../test_tir_transform_make_unpacked_api.py | 65 ++++++++++++++++--- .../test_tir_transform_split_host_device.py | 14 ++-- 7 files changed, 118 insertions(+), 35 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index e5f71c38320d..46716ed30619 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -577,6 +577,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::InjectPTXLDG32()); } + mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions()); + mixed_pass_list.push_back(tir::transform::SplitHostDevice()); + bool unpacked_api = mixed_mod->GetAttr(tvm::attr::kExecutor) .value_or(relay::Executor::Create("graph", {})) ->GetAttr("unpacked-api") @@ -588,8 +591,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) } mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); - mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions()); - mixed_pass_list.push_back(tir::transform::SplitHostDevice()); mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch()); return transform::Sequential(mixed_pass_list); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 062f1c050961..a6673a19ad01 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -223,6 +223,14 @@ PrimFunc MakePackedAPI(PrimFunc func) { }(); int target_device_type = target->GetTargetDeviceType(); + // A function without a host target has already been lowered. + Target target_host; + if (auto opt = target->GetHost()) { + target_host = opt.value(); + } else { + return func; + } + auto* func_ptr = func.CopyOnWrite(); const Stmt nop = Evaluate(0); int num_args = static_cast(func_ptr->params.size()); @@ -325,7 +333,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { name_hint + "." + kv.first->name_hint); } - func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)); + func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)}, + {tvm::attr::kTarget, target_host}}); Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode); body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope, diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index bdb3a953e99c..4b1b3bf517d0 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -111,6 +111,14 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { }(); int target_device_type = target->GetTargetDeviceType(); + // A function without a host target has already been lowered. + Target target_host; + if (auto opt = target->GetHost()) { + target_host = opt.value(); + } else { + return func; + } + auto* func_ptr = func.CopyOnWrite(); // Setup device context @@ -145,7 +153,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { func_ptr->buffer_map = Map(); // return the function. - return func; + return WithAttrs(std::move(func), {{tvm::attr::kTarget, target_host}}); } namespace transform { diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 2de831e8ad0c..29ecaa4e8e43 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -99,10 +99,6 @@ class HostDeviceSplitter : public StmtMutator { }; PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& gvar) { - auto opt_target = func->GetAttr(tvm::attr::kTarget); - ICHECK(opt_target) << "SplitHostDevice: Require the target attribute"; - Target target = opt_target.value(); - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); auto name_prefix = global_symbol.value_or(gvar->name_hint); @@ -112,10 +108,6 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& g func.CopyOnWrite()->body = body; } - if (auto target_host = target->GetHost()) { - func = WithAttr(std::move(func), tvm::attr::kTarget, target_host.value()); - } - return func; } diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 8af7efb59604..34adcbb9aee4 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -37,7 +37,7 @@ def test_makeapi(): mod = tvm.tir.transform.Apply( lambda f: f.with_attr( { - "target": tvm.target.Target("llvm"), + "target": tvm.target.Target("llvm", host="llvm"), "global_symbol": "main", } ) @@ -90,7 +90,9 @@ def test_variable_passed_from_args(): stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, not_device_context], stmt)) - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) + mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm")) + )(mod) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) func = tvm.tir.transform.MakePackedAPI()(mod)["main"] @@ -132,7 +134,9 @@ def test_device_api_context_implicit_resource_handle(): stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, device_context], stmt)) - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) + mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm")) + )(mod) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) func = tvm.tir.transform.MakePackedAPI()(mod)["main"] @@ -161,7 +165,7 @@ def test_device_api_context_implicit_resource_handle(): @pytest.mark.parametrize("use_global_symbol", [True, False]) def test_no_op_when_global_symbol_is_absent(use_global_symbol): - func_attr = {"target": tvm.target.Target("llvm")} + func_attr = {"target": tvm.target.Target("llvm", host="llvm")} if use_global_symbol: func_attr["global_symbol"] = "main" @@ -177,6 +181,28 @@ def before(): tvm.ir.assert_structural_equal(before, after) +def test_target_host_removed(): + """After MakePackedAPI, host-side target should be the host + + MakePackedAPI is the last transform that requires both the device + and the host. After MakePackedAPI, the target attribute should + only contain the host-side target. + """ + + host = tvm.target.Target("llvm") + + @I.ir_module + class before: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main", "target": T.target("cuda", host=host)}) + T.evaluate(0) + + after = tvm.tir.transform.MakePackedAPI()(before) + target_attr = after["main"].attrs["target"] + assert str(host) == str(target_attr) + + def test_internal_subroutine_call(): """Internal subroutines should not use the PackedFunc API @@ -190,7 +216,7 @@ def test_internal_subroutine_call(): class before: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) before.subroutine(A.data) @T.prim_func @@ -222,12 +248,12 @@ def test_subroutine_call_to_externally_visible_subroutine(): class before: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) before.subroutine(A.data) @T.prim_func def subroutine(A_data: T.handle("float32")): - T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")}) T.evaluate(A_data) after = tvm.tir.transform.MakePackedAPI()(before) diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py b/tests/python/unittest/test_tir_transform_make_unpacked_api.py index bb9fe8ab8267..1931f7aef324 100644 --- a/tests/python/unittest/test_tir_transform_make_unpacked_api.py +++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py @@ -41,9 +41,8 @@ def mod(mod_without_attrs): def test_noop_if_not_global_symbol(mod_without_attrs): - before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))( - mod_without_attrs - ) + target = tvm.target.Target("llvm", host="llvm") + before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_without_attrs) after = tvm.tir.transform.MakeUnpackedAPI()(before) tvm.ir.assert_structural_equal(before, after) @@ -59,7 +58,8 @@ def test_fails_if_no_target(mod_without_attrs): @tvm.testing.parametrize_targets("c", "llvm", "cuda") def test_device_setup(mod, target, dev): - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target(target)))(mod) + target = tvm.target.Target(target, host="llvm") + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 1 assert f.params[0].name == "A" @@ -138,6 +138,49 @@ def test_body(): assert f.params[2].name == "A" +class TestTargetHostRemoved(tvm.testing.CompareBeforeAfter): + """After MakeUnpackedAPI, host-side target should be the host + + MakeUnpackedAPI is the last transform that requires both the device + and the host. After MakeUnpackedAPI, the target attribute should + only contain the host-side target. + """ + + transform = tvm.tir.transform.MakeUnpackedAPI() + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")}) + mod.subroutine(A.data) + + @T.prim_func + def subroutine(A_data: T.handle("float32")): + T.func_attr({"target": T.target("cuda")}) + T.evaluate(A_data) + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(A_data: T.handle("float32")) -> T.int32: + T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 2) + mod.subroutine(A_data) + + @T.prim_func + def subroutine(A_data: T.handle("float32")): + T.func_attr({"target": T.target("cuda")}) + T.evaluate(A_data) + + return mod + + class TestInternalSubroutineCall(tvm.testing.CompareBeforeAfter): """Internal subroutines do not require modification @@ -153,7 +196,7 @@ def before(self): class mod: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) mod.subroutine(A.data) @T.prim_func @@ -195,12 +238,14 @@ def before(self): class mod: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) mod.subroutine(A.data) @T.prim_func def subroutine(A_data: T.handle("float32")): - T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) + T.func_attr( + {"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")} + ) T.evaluate(A_data) return mod @@ -240,7 +285,7 @@ def before(self): class mod: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) mod.subroutine( T.tvm_stack_make_array( A.data, @@ -255,7 +300,9 @@ def main(A: T.Buffer(1, "float32")): @T.prim_func def subroutine(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) + T.func_attr( + {"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")} + ) T.evaluate(A.data) return mod diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index 1599b9a031a0..60bfb8a718d2 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -46,6 +46,7 @@ def test_split_host_device_func_attr(): [ tvm.tir.transform.AnnotateDeviceRegions(), tvm.tir.transform.SplitHostDevice(), + tvm.tir.transform.MakePackedAPI(), tvm.tir.transform.LowerDeviceKernelLaunch(), ] )(mod) @@ -111,7 +112,7 @@ def expected(self): class mod: @T.prim_func def main(n: T.int32): - T.func_attr({"target": T.target("llvm -opt-level=0")}) + T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")}) mod.main_kernel(n) @T.prim_func @@ -168,20 +169,19 @@ def main_kernel(n: T.int32): return mod -class TestSplitHostDevice(BaseCompare): +class TestSplitHostDeviceWithoutDeviceRegion(BaseCompare): """Like TestSplitHostDevice, but no device regions to extract - Even if there are no device regions, the host-side function should - still have its "target" attribute updated. + Because MakePackedAPI/MakeUnpackedAPI still require both the + device and host, SplitHostDevice does not modify the "target" + attribute. """ def before(): T.func_attr({"target": T.target("ext_dev", host="llvm")}) T.evaluate(0) - def expected(): - T.func_attr({"target": T.target("llvm")}) - T.evaluate(0) + expected = before if __name__ == "__main__": From e6b55db984326ec6610d588f072dbb11e59fe0b3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 24 May 2023 10:06:21 -0500 Subject: [PATCH 3/4] [TIR] Cleanup of MakePackedAPI Prior to this commit, the `RequiresPackedAPI` function checked whether a function needed the packed func API. This was used both to generate a list of call-sites to update, and as part of the updates to `PrimFunc` signatures. However, the function that updates the `PrimFunc` signature could still return the original function unmodified, breaking internal method calls. This occurred for functions with a `kTarget` attribute without a host. This commit updates `MakePackedAPI` to first update all `PrimFunc` signatures that require the packed func API, then use the result to determine which call-sites must be updated. This resolves the discrepancy for host-less target annotations, and removes the possibility of similar discrepancies in the future. --- src/tir/transforms/make_packed_api.cc | 68 ++++++++++++--------------- 1 file changed, 29 insertions(+), 39 deletions(-) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 062f1c050961..77d563bf0da8 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -183,33 +183,17 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } -/* \brief Return the global_symbol of the function, if it should be updated - * - * \param func The function to be inspected - * - * \returns The global_symbol to be used for the function at call - * sites, or NullOpt if the function is to remain unchanged. - */ -Optional RequiresPackedAPI(const PrimFunc& func) { +PrimFunc MakePackedAPI(PrimFunc func) { // A function with an explicit calling convention has already been // lowered, and should not be modified. if (auto opt = func->GetAttr(tvm::attr::kCallingConv)) { if (CallingConv(opt.value()->value) != CallingConv::kDefault) { - return NullOpt; + return func; } } // Internal function calls do not need the PackedFunc API auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - if (!global_symbol.defined()) { - return NullOpt; - } - - return global_symbol; -} - -PrimFunc MakePackedAPI(PrimFunc func) { - auto global_symbol = RequiresPackedAPI(func); if (!global_symbol.defined()) { return func; } @@ -218,7 +202,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { Target target = [&]() { auto opt = func->GetAttr(tvm::attr::kTarget); ICHECK(opt) << "MakePackedAPI required the function to be annotated with tvm::attr::kTarget (" - << tvm::attr::kTarget << "), but the function only has attributes " << func->attrs; + << tvm::attr::kTarget << "), but the function " << name_hint + << " only has attributes" << func->attrs; return opt.value(); }(); int target_device_type = target->GetTargetDeviceType(); @@ -368,38 +353,43 @@ namespace transform { Pass MakePackedAPI() { auto pass_func = [](IRModule mod, PassContext ctx) { Map packed_func_methods; - for (const auto& [gvar, base_func] : mod->functions) { - if (auto opt = base_func.as()) { - auto prim_func = opt.value(); - if (auto global_symbol = RequiresPackedAPI(prim_func)) { - packed_func_methods.Set(gvar, global_symbol.value()); - } - } - } - IRModuleNode* mptr = mod.CopyOnWrite(); IRModule updates; - - for (const auto& [gvar, base_func] : mptr->functions) { + for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { - auto func = opt.value(); - auto orig_func = func; - - if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) { - func.CopyOnWrite()->body = body.value(); - } - - func = MakePackedAPI(std::move(func)); + auto orig_func = opt.value(); + auto func = MakePackedAPI(orig_func); if (!func.same_as(orig_func)) { updates->Add(gvar, func); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol).value(); + packed_func_methods.Set(gvar, global_symbol); } } } - if (updates->functions.size()) { mod.CopyOnWrite()->Update(updates); } + + if (packed_func_methods.size()) { + IRModule updates; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto opt = base_func.as()) { + auto func = opt.value(); + auto orig_func = func; + + if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) { + func.CopyOnWrite()->body = body.value(); + updates->Add(gvar, func); + } + } + } + + if (updates->functions.size()) { + mod.CopyOnWrite()->Update(updates); + } + } + return mod; }; From 8d89826c00c88beeb782bac54952f547c38e7ae3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 2 Jun 2023 13:41:22 -0500 Subject: [PATCH 4/4] Include "host" in LowerTVMBuiltin test case --- tests/python/unittest/test_tir_transform_lower_tvm_builtin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py index d224a688d298..949bb0477bab 100644 --- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py @@ -56,7 +56,7 @@ def check_packed_func(target="llvm"): # Construct a valid IRModule to be lowered: mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([a_buffer, b_buffer, c_buffer], stmt)) - target = tvm.target.Target(target) + target = tvm.target.Target(target, host="llvm") mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) mod = tvm.tir.transform.MakePackedAPI()(mod)