Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,16 +192,13 @@ TVM_DLL Pass InstrumentBoundCheckers();
* - Map the values in the api_args to Var that is required by body.
* - Insert assertions to check type/value of the passed arguments.
*
* \param num_unpacked_args Number of arguments that
* are processed in plain form instead of packed form.
*
* \note
* The function signature have two cases
*
* let num_packed_args = len(api_args) - num_unpacked_args;
* let num_packed_args = len(api_args);
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
* f()
*
* if num_packed_args is not zero:
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
Expand All @@ -212,7 +209,7 @@ TVM_DLL Pass InstrumentBoundCheckers();
*
* \return The pass.
*/
TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
TVM_DLL Pass MakePackedAPI();

/*!
* \brief Transform the high-level PrimFunc to a C signature that can be used
Expand Down
11 changes: 2 additions & 9 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,22 +387,15 @@ def LowerCustomDatatypes():
return _ffi_api.LowerCustomDatatypes() # type: ignore


def MakePackedAPI(num_unpacked_params: int = -1):
def MakePackedAPI():
"""Transform the PrimFuncs in the module to a packed func API.

Parameters
----------
num_unpacked_params : int
Number of parameters that we hope to directly pass via normal arguments
following the PackedFunc input signature. If it is specified as -1 or it
is less than the number of arguments, the pass will packed arguments still.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.MakePackedAPI(num_unpacked_params) # type: ignore
return _ffi_api.MakePackedAPI() # type: ignore


def MakeUnpackedAPI():
Expand Down
2 changes: 1 addition & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
if (unpacked_api) {
mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
} else {
mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
}
mixed_pass_list.push_back(tir::transform::SplitHostDevice());

Expand Down
109 changes: 38 additions & 71 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
}

PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
PrimFunc MakePackedAPI(PrimFunc&& func) {
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";

Expand All @@ -152,14 +152,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
auto* func_ptr = func.CopyOnWrite();
const Stmt nop = Evaluate(0);
int num_args = static_cast<int>(func_ptr->params.size());
ICHECK_LE(num_unpacked_args, num_args);
bool pack_args = (num_unpacked_args == -1) || (num_args > num_unpacked_args);
if (num_unpacked_args == -1) {
// reset to zero
num_unpacked_args = 0;
}
ICHECK_GE(num_unpacked_args, 0);
int num_packed_args = num_args - num_unpacked_args;

// Data field definitions
// The packed fields
Var v_packed_args("args", DataType::Handle());
Expand All @@ -170,7 +163,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
Var v_out_ret_tcode("out_ret_tcode", PointerType(PrimType(DataType::Int(32))));
Var v_resource_handle("resource_handle", DataType::Handle());
// The arguments of the function.
Array<Var> args;

// The device context
Var device_id("dev_id");
Integer device_type(target_device_type);
Expand All @@ -194,14 +187,6 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
}
return res;
};
// ---------------------------
// start of logics
// add signature for packed arguments.
if (pack_args) {
args.push_back(v_packed_args);
args.push_back(buf_packed_arg_type_ids->data);
args.push_back(v_num_packed_args);
}

// Need to re-declare vars, in case some arguments also appears in the buffer.
std::vector<std::pair<Var, Var>> var_def;
Expand All @@ -219,7 +204,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {

// Pluck the device API context out based on name
if (param->name_hint == kDeviceContextVar) {
num_packed_args--;
num_args--;
v_resource_handle = param;
continue;
}
Expand All @@ -232,44 +217,34 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
var_def.emplace_back(v_arg, param);
}

if (i < num_packed_args) {
// Value loads
seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop));
// type code checks
Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
seq_init.emplace_back(
LetStmt(tcode, BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}), nop));
DataType t = v_arg.dtype();
if (t.is_handle()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be pointer";
seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
tvm::tir::StringImm(msg.str()), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be int";
seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
} else {
ICHECK(t.is_float());
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be float";
seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
}
// Value loads
seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop));
// type code checks
Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
seq_init.emplace_back(
LetStmt(tcode, BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}), nop));
DataType t = v_arg.dtype();
if (t.is_handle()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be pointer";
seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
tvm::tir::StringImm(msg.str()), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be int";
seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
} else {
args.push_back(v_arg);
ICHECK(t.is_float());
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be float";
seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
}
}

// allow return value if the function is packed.
if (pack_args) {
args.push_back(v_out_ret_value);
args.push_back(v_out_ret_tcode);
args.push_back(v_resource_handle);
}

size_t expected_nargs = num_unpacked_args + (pack_args ? 6 : 0);
ICHECK_EQ(args.size(), expected_nargs);
Array<Var> args{v_packed_args, buf_packed_arg_type_ids->data,
v_num_packed_args, v_out_ret_value,
v_out_ret_tcode, v_resource_handle};

// Arg definitions are defined before buffer binding to avoid the use before
// def errors.
Expand All @@ -286,9 +261,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
binder.BindDLTensor(kv.second, device_type, device_id, kv.first, kv.first->name_hint);
}

if (num_unpacked_args == 0) {
func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
}
func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));

Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
Expand All @@ -307,16 +280,11 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
}
}

if (pack_args) {
std::ostringstream num_args_error;
num_args_error << name_hint << ": num_args should be " << num_packed_args;
std::vector<Stmt> arg_assert = {
MakeAssertEQ(v_num_packed_args, num_packed_args, num_args_error.str())};
func_ptr->body =
MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
} else {
func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
}
std::ostringstream num_args_error;
num_args_error << name_hint << ": num_args should be " << num_args;
std::vector<Stmt> arg_assert = {MakeAssertEQ(v_num_packed_args, num_args, num_args_error.str())};
func_ptr->body =
MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
func_ptr->params = args;

Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
Expand All @@ -339,9 +307,8 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {

namespace transform {

Pass MakePackedAPI(int num_unpacked_args) {
// packed arguments anyway while `num_unpacked_args` is -1
auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) {
Pass MakePackedAPI() {
auto pass_func = [](IRModule m, PassContext ctx) {
IRModuleNode* mptr = m.CopyOnWrite();
std::vector<std::pair<GlobalVar, PrimFunc>> updates;

Expand All @@ -350,7 +317,7 @@ Pass MakePackedAPI(int num_unpacked_args) {
PrimFunc func = GetRef<PrimFunc>(n);
if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
CallingConv::kDefault) {
auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args);
auto updated_func = MakePackedAPI(std::move(func));
updates.push_back({kv.first, updated_func});
}
}
Expand All @@ -365,7 +332,7 @@ Pass MakePackedAPI(int num_unpacked_args) {
return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {});
}

TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI").set_body_typed(MakePackedAPI);
TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI").set_body_typed([]() { return MakePackedAPI(); });
} // namespace transform
} // namespace tir
} // namespace tvm
5 changes: 2 additions & 3 deletions tests/python/unittest/test_tir_transform_make_packed_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@ def test_makeapi():
)
)(mod)

num_unpacked_args = 2
f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
assert len(f.params) == 8
f = tvm.tir.transform.MakePackedAPI()(mod)["main"]
assert len(f.params) == 6


def _find_assignment(stmt, var_name):
Expand Down