diff --git a/python/tvm/utils/roofline/__init__.py b/python/tvm/utils/roofline/__init__.py index 67d80eb05284..45cc880c5b85 100644 --- a/python/tvm/utils/roofline/__init__.py +++ b/python/tvm/utils/roofline/__init__.py @@ -51,10 +51,16 @@ def _create_args(mod: IRModule, dev: Device, func_name: str = "main", remote=Non @pass_instrument class SaveLoweredTIR: - """Save TIR functions from right before final lowering. Right now this - means right before tir.MakePackedAPI.""" + """Save TIR functions for analysis. - def __init__(self, before_pass: str = "tir.MakePackedAPI"): + We need the TIR function in a form that can be handled by + `auto_scheduler.feature.named_features_from_primfunc`, but which + is the closest to the final lowered form as possible. Right now this + means right before tir.SplitHostDevice. + + """ + + def __init__(self, before_pass: str = "tir.SplitHostDevice"): """ Parameters ---------- diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index cfc7fa80c7a9..b75f173e0d00 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -578,6 +578,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::InjectPTXLDG32()); } + mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions()); + mixed_pass_list.push_back(tir::transform::SplitHostDevice()); + bool unpacked_api = mixed_mod->GetAttr(tvm::attr::kExecutor) .value_or(relay::Executor::Create("graph", {})) ->GetAttr("unpacked-api") @@ -590,8 +593,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); - mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions()); - mixed_pass_list.push_back(tir::transform::SplitHostDevice()); mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch()); return transform::Sequential(mixed_pass_list); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 062f1c050961..a6673a19ad01 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -223,6 +223,14 @@ PrimFunc MakePackedAPI(PrimFunc func) { }(); int target_device_type = target->GetTargetDeviceType(); + // A function without a host target has already been lowered. + Target target_host; + if (auto opt = target->GetHost()) { + target_host = opt.value(); + } else { + return func; + } + auto* func_ptr = func.CopyOnWrite(); const Stmt nop = Evaluate(0); int num_args = static_cast(func_ptr->params.size()); @@ -325,7 +333,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { name_hint + "." + kv.first->name_hint); } - func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)); + func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)}, + {tvm::attr::kTarget, target_host}}); Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode); body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope, diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index bdb3a953e99c..4b1b3bf517d0 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -111,6 +111,14 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { }(); int target_device_type = target->GetTargetDeviceType(); + // A function without a host target has already been lowered. + Target target_host; + if (auto opt = target->GetHost()) { + target_host = opt.value(); + } else { + return func; + } + auto* func_ptr = func.CopyOnWrite(); // Setup device context @@ -145,7 +153,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { func_ptr->buffer_map = Map(); // return the function. - return func; + return WithAttrs(std::move(func), {{tvm::attr::kTarget, target_host}}); } namespace transform { diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 2de831e8ad0c..29ecaa4e8e43 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -99,10 +99,6 @@ class HostDeviceSplitter : public StmtMutator { }; PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& gvar) { - auto opt_target = func->GetAttr(tvm::attr::kTarget); - ICHECK(opt_target) << "SplitHostDevice: Require the target attribute"; - Target target = opt_target.value(); - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); auto name_prefix = global_symbol.value_or(gvar->name_hint); @@ -112,10 +108,6 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& g func.CopyOnWrite()->body = body; } - if (auto target_host = target->GetHost()) { - func = WithAttr(std::move(func), tvm::attr::kTarget, target_host.value()); - } - return func; } diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 030976845298..6a8ff28e442e 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -525,7 +525,9 @@ def get_graph(): # nothing else was overrwritten. # With Target Hooks the TIR module needs a target attached # and lowered via make unpacked API. - tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u")) + tir_mod["main"] = tir_mod["main"].with_attr( + "target", tvm.target.Target("ethos-u", host="ethos-u") + ) tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod) tir_to_cs_translator.translate(tir_mod, params) diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index 22f886a5917a..a293e2691923 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -255,7 +255,9 @@ def test_buffer_info_extraction(): # With Target Hooks the TIR module needs a target attached # and lowered via make unpacked API. tir_mod = test_case["tir_module"] - tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u")) + tir_mod["main"] = tir_mod["main"].with_attr( + "target", tvm.target.Target("ethos-u", host="ethos-u") + ) tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod) buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod, test_case["param_dict"]) for buffer_var, info in buffer_info.items(): @@ -959,7 +961,9 @@ def check_buffer(address, region, length, buffer_var): for test_case in test_cases: tir_mod = test_case["tir_module"] - tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u")) + tir_mod["main"] = tir_mod["main"].with_attr( + "target", tvm.target.Target("ethos-u", host="ethos-u") + ) tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod) candidate_regions_for_scratch = [5, 2, 1] ( diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 8af7efb59604..34adcbb9aee4 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -37,7 +37,7 @@ def test_makeapi(): mod = tvm.tir.transform.Apply( lambda f: f.with_attr( { - "target": tvm.target.Target("llvm"), + "target": tvm.target.Target("llvm", host="llvm"), "global_symbol": "main", } ) @@ -90,7 +90,9 @@ def test_variable_passed_from_args(): stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, not_device_context], stmt)) - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) + mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm")) + )(mod) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) func = tvm.tir.transform.MakePackedAPI()(mod)["main"] @@ -132,7 +134,9 @@ def test_device_api_context_implicit_resource_handle(): stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, device_context], stmt)) - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) + mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm")) + )(mod) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) func = tvm.tir.transform.MakePackedAPI()(mod)["main"] @@ -161,7 +165,7 @@ def test_device_api_context_implicit_resource_handle(): @pytest.mark.parametrize("use_global_symbol", [True, False]) def test_no_op_when_global_symbol_is_absent(use_global_symbol): - func_attr = {"target": tvm.target.Target("llvm")} + func_attr = {"target": tvm.target.Target("llvm", host="llvm")} if use_global_symbol: func_attr["global_symbol"] = "main" @@ -177,6 +181,28 @@ def before(): tvm.ir.assert_structural_equal(before, after) +def test_target_host_removed(): + """After MakePackedAPI, host-side target should be the host + + MakePackedAPI is the last transform that requires both the device + and the host. After MakePackedAPI, the target attribute should + only contain the host-side target. + """ + + host = tvm.target.Target("llvm") + + @I.ir_module + class before: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main", "target": T.target("cuda", host=host)}) + T.evaluate(0) + + after = tvm.tir.transform.MakePackedAPI()(before) + target_attr = after["main"].attrs["target"] + assert str(host) == str(target_attr) + + def test_internal_subroutine_call(): """Internal subroutines should not use the PackedFunc API @@ -190,7 +216,7 @@ def test_internal_subroutine_call(): class before: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) before.subroutine(A.data) @T.prim_func @@ -222,12 +248,12 @@ def test_subroutine_call_to_externally_visible_subroutine(): class before: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) before.subroutine(A.data) @T.prim_func def subroutine(A_data: T.handle("float32")): - T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")}) T.evaluate(A_data) after = tvm.tir.transform.MakePackedAPI()(before) diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py b/tests/python/unittest/test_tir_transform_make_unpacked_api.py index bb9fe8ab8267..1931f7aef324 100644 --- a/tests/python/unittest/test_tir_transform_make_unpacked_api.py +++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py @@ -41,9 +41,8 @@ def mod(mod_without_attrs): def test_noop_if_not_global_symbol(mod_without_attrs): - before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))( - mod_without_attrs - ) + target = tvm.target.Target("llvm", host="llvm") + before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_without_attrs) after = tvm.tir.transform.MakeUnpackedAPI()(before) tvm.ir.assert_structural_equal(before, after) @@ -59,7 +58,8 @@ def test_fails_if_no_target(mod_without_attrs): @tvm.testing.parametrize_targets("c", "llvm", "cuda") def test_device_setup(mod, target, dev): - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target(target)))(mod) + target = tvm.target.Target(target, host="llvm") + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 1 assert f.params[0].name == "A" @@ -138,6 +138,49 @@ def test_body(): assert f.params[2].name == "A" +class TestTargetHostRemoved(tvm.testing.CompareBeforeAfter): + """After MakeUnpackedAPI, host-side target should be the host + + MakeUnpackedAPI is the last transform that requires both the device + and the host. After MakeUnpackedAPI, the target attribute should + only contain the host-side target. + """ + + transform = tvm.tir.transform.MakeUnpackedAPI() + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")}) + mod.subroutine(A.data) + + @T.prim_func + def subroutine(A_data: T.handle("float32")): + T.func_attr({"target": T.target("cuda")}) + T.evaluate(A_data) + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(A_data: T.handle("float32")) -> T.int32: + T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 2) + mod.subroutine(A_data) + + @T.prim_func + def subroutine(A_data: T.handle("float32")): + T.func_attr({"target": T.target("cuda")}) + T.evaluate(A_data) + + return mod + + class TestInternalSubroutineCall(tvm.testing.CompareBeforeAfter): """Internal subroutines do not require modification @@ -153,7 +196,7 @@ def before(self): class mod: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) mod.subroutine(A.data) @T.prim_func @@ -195,12 +238,14 @@ def before(self): class mod: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) mod.subroutine(A.data) @T.prim_func def subroutine(A_data: T.handle("float32")): - T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) + T.func_attr( + {"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")} + ) T.evaluate(A_data) return mod @@ -240,7 +285,7 @@ def before(self): class mod: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) mod.subroutine( T.tvm_stack_make_array( A.data, @@ -255,7 +300,9 @@ def main(A: T.Buffer(1, "float32")): @T.prim_func def subroutine(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) + T.func_attr( + {"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")} + ) T.evaluate(A.data) return mod diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index 1599b9a031a0..60bfb8a718d2 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -46,6 +46,7 @@ def test_split_host_device_func_attr(): [ tvm.tir.transform.AnnotateDeviceRegions(), tvm.tir.transform.SplitHostDevice(), + tvm.tir.transform.MakePackedAPI(), tvm.tir.transform.LowerDeviceKernelLaunch(), ] )(mod) @@ -111,7 +112,7 @@ def expected(self): class mod: @T.prim_func def main(n: T.int32): - T.func_attr({"target": T.target("llvm -opt-level=0")}) + T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")}) mod.main_kernel(n) @T.prim_func @@ -168,20 +169,19 @@ def main_kernel(n: T.int32): return mod -class TestSplitHostDevice(BaseCompare): +class TestSplitHostDeviceWithoutDeviceRegion(BaseCompare): """Like TestSplitHostDevice, but no device regions to extract - Even if there are no device regions, the host-side function should - still have its "target" attribute updated. + Because MakePackedAPI/MakeUnpackedAPI still require both the + device and host, SplitHostDevice does not modify the "target" + attribute. """ def before(): T.func_attr({"target": T.target("ext_dev", host="llvm")}) T.evaluate(0) - def expected(): - T.func_attr({"target": T.target("llvm")}) - T.evaluate(0) + expected = before if __name__ == "__main__":