Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::InjectPTXLDG32());
}

mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());

bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
.value_or(relay::Executor::Create("graph", {}))
->GetAttr<Bool>("unpacked-api")
Expand All @@ -590,8 +593,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());

mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());

return transform::Sequential(mixed_pass_list);
Expand Down
79 changes: 39 additions & 40 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,33 +183,17 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
}

/* \brief Return the global_symbol of the function, if it should be updated
*
* \param func The function to be inspected
*
* \returns The global_symbol to be used for the function at call
* sites, or NullOpt if the function is to remain unchanged.
*/
Optional<String> RequiresPackedAPI(const PrimFunc& func) {
PrimFunc MakePackedAPI(PrimFunc func) {
// A function with an explicit calling convention has already been
// lowered, and should not be modified.
if (auto opt = func->GetAttr<Integer>(tvm::attr::kCallingConv)) {
if (CallingConv(opt.value()->value) != CallingConv::kDefault) {
return NullOpt;
return func;
}
}

// Internal function calls do not need the PackedFunc API
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (!global_symbol.defined()) {
return NullOpt;
}

return global_symbol;
}

PrimFunc MakePackedAPI(PrimFunc func) {
auto global_symbol = RequiresPackedAPI(func);
if (!global_symbol.defined()) {
return func;
}
Expand All @@ -218,11 +202,20 @@ PrimFunc MakePackedAPI(PrimFunc func) {
Target target = [&]() {
auto opt = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(opt) << "MakePackedAPI required the function to be annotated with tvm::attr::kTarget ("
<< tvm::attr::kTarget << "), but the function only has attributes " << func->attrs;
<< tvm::attr::kTarget << "), but the function " << name_hint
<< " only has attributes" << func->attrs;
return opt.value();
}();
int target_device_type = target->GetTargetDeviceType();

// A function without a host target has already been lowered.
Target target_host;
if (auto opt = target->GetHost()) {
target_host = opt.value();
} else {
return func;
}

auto* func_ptr = func.CopyOnWrite();
const Stmt nop = Evaluate(0);
int num_args = static_cast<int>(func_ptr->params.size());
Expand Down Expand Up @@ -325,7 +318,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
name_hint + "." + kv.first->name_hint);
}

func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)},
{tvm::attr::kTarget, target_host}});

Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
Expand Down Expand Up @@ -368,38 +362,43 @@ namespace transform {
Pass MakePackedAPI() {
auto pass_func = [](IRModule mod, PassContext ctx) {
Map<GlobalVar, String> packed_func_methods;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<PrimFunc>()) {
auto prim_func = opt.value();
if (auto global_symbol = RequiresPackedAPI(prim_func)) {
packed_func_methods.Set(gvar, global_symbol.value());
}
}
}

IRModuleNode* mptr = mod.CopyOnWrite();
IRModule updates;

for (const auto& [gvar, base_func] : mptr->functions) {
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<PrimFunc>()) {
auto func = opt.value();
auto orig_func = func;

if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) {
func.CopyOnWrite()->body = body.value();
}

func = MakePackedAPI(std::move(func));
auto orig_func = opt.value();
auto func = MakePackedAPI(orig_func);

if (!func.same_as(orig_func)) {
updates->Add(gvar, func);
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
packed_func_methods.Set(gvar, global_symbol);
}
}
}

if (updates->functions.size()) {
mod.CopyOnWrite()->Update(updates);
}

if (packed_func_methods.size()) {
IRModule updates;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<PrimFunc>()) {
auto func = opt.value();
auto orig_func = func;

if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) {
func.CopyOnWrite()->body = body.value();
updates->Add(gvar, func);
}
}
}

if (updates->functions.size()) {
mod.CopyOnWrite()->Update(updates);
}
}

return mod;
};

Expand Down
10 changes: 9 additions & 1 deletion src/tir/transforms/make_unpacked_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
}();
int target_device_type = target->GetTargetDeviceType();

// A function without a host target has already been lowered.
Target target_host;
if (auto opt = target->GetHost()) {
target_host = opt.value();
} else {
return func;
}

auto* func_ptr = func.CopyOnWrite();

// Setup device context
Expand Down Expand Up @@ -145,7 +153,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
func_ptr->buffer_map = Map<Var, Buffer>();

// return the function.
return func;
return WithAttrs(std::move(func), {{tvm::attr::kTarget, target_host}});
}

namespace transform {
Expand Down
8 changes: 0 additions & 8 deletions src/tir/transforms/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,6 @@ class HostDeviceSplitter : public StmtMutator {
};

PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& gvar) {
auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(opt_target) << "SplitHostDevice: Require the target attribute";
Target target = opt_target.value();

auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
auto name_prefix = global_symbol.value_or(gvar->name_hint);

Expand All @@ -112,10 +108,6 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& g
func.CopyOnWrite()->body = body;
}

if (auto target_host = target->GetHost()) {
func = WithAttr(std::move(func), tvm::attr::kTarget, target_host.value());
}

return func;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def check_packed_func(target="llvm"):
# Construct a valid IRModule to be lowered:
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([a_buffer, b_buffer, c_buffer], stmt))

target = tvm.target.Target(target)
target = tvm.target.Target(target, host="llvm")
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
mod = tvm.tir.transform.MakePackedAPI()(mod)
Expand Down
40 changes: 33 additions & 7 deletions tests/python/unittest/test_tir_transform_make_packed_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_makeapi():
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr(
{
"target": tvm.target.Target("llvm"),
"target": tvm.target.Target("llvm", host="llvm"),
"global_symbol": "main",
}
)
Expand Down Expand Up @@ -90,7 +90,9 @@ def test_variable_passed_from_args():
stmt = ib.get()

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, not_device_context], stmt))
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod)
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm"))
)(mod)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
func = tvm.tir.transform.MakePackedAPI()(mod)["main"]

Expand Down Expand Up @@ -132,7 +134,9 @@ def test_device_api_context_implicit_resource_handle():
stmt = ib.get()

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, device_context], stmt))
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod)
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm"))
)(mod)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
func = tvm.tir.transform.MakePackedAPI()(mod)["main"]

Expand Down Expand Up @@ -161,7 +165,7 @@ def test_device_api_context_implicit_resource_handle():

@pytest.mark.parametrize("use_global_symbol", [True, False])
def test_no_op_when_global_symbol_is_absent(use_global_symbol):
func_attr = {"target": tvm.target.Target("llvm")}
func_attr = {"target": tvm.target.Target("llvm", host="llvm")}
if use_global_symbol:
func_attr["global_symbol"] = "main"

Expand All @@ -177,6 +181,28 @@ def before():
tvm.ir.assert_structural_equal(before, after)


def test_target_host_removed():
"""After MakePackedAPI, host-side target should be the host

MakePackedAPI is the last transform that requires both the device
and the host. After MakePackedAPI, the target attribute should
only contain the host-side target.
"""

host = tvm.target.Target("llvm")

@I.ir_module
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("cuda", host=host)})
T.evaluate(0)

after = tvm.tir.transform.MakePackedAPI()(before)
target_attr = after["main"].attrs["target"]
assert str(host) == str(target_attr)


def test_internal_subroutine_call():
"""Internal subroutines should not use the PackedFunc API

Expand All @@ -190,7 +216,7 @@ def test_internal_subroutine_call():
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
before.subroutine(A.data)

@T.prim_func
Expand Down Expand Up @@ -222,12 +248,12 @@ def test_subroutine_call_to_externally_visible_subroutine():
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
before.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")})
T.evaluate(A_data)

after = tvm.tir.transform.MakePackedAPI()(before)
Expand Down
Loading