Skip to content
11 changes: 11 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,17 @@ TVM_DLL Pass InstrumentBoundCheckers();
*/
TVM_DLL Pass MakePackedAPI(int num_unpacked_args);

/*!
 * \brief Transform the high-level PrimFunc to a C signature that can be used
 * to call the operator directly.
 *
 * The main task of this pass is to generate code that maps the values passed
 * in api_args to the Vars required by the function body, so the function can
 * be invoked directly rather than through the packed-call convention.
 *
 * \return The pass.
 */
TVM_DLL Pass MakeUnpackedAPI();

/*!
* \brief Remap the thread axis
*
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,17 @@ def MakePackedAPI(num_unpacked_params=0):
return _ffi_api.MakePackedAPI(num_unpacked_params)


def MakeUnpackedAPI():
    """Transform the PrimFuncs in the module to a C API compatible with internal calls.

    Each PrimFunc is rewritten so that it can be invoked directly with its
    arguments instead of going through the packed calling convention.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.MakeUnpackedAPI()
    return fpass


def SplitHostDevice():
"""Split the function into a host function and device functions.

Expand Down
9 changes: 8 additions & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,15 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
mixed_pass_list.push_back(tir::transform::InferFragment());
mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
mixed_pass_list.push_back(tir::transform::MakePackedAPI(0));

if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
} else {
mixed_pass_list.push_back(tir::transform::MakePackedAPI(0));
}

mixed_pass_list.push_back(tir::transform::SplitHostDevice());

auto opt_mixed = transform::Sequential(mixed_pass_list);
mod_mixed = opt_mixed(std::move(mod_mixed));

Expand Down
84 changes: 54 additions & 30 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,17 @@ class AOTExecutorCodegen : public ExprVisitor {
// Pack the sid inside the TVMValue
auto sid_array = te::Var(MakeString("sid_", sid, "_value"), DataType::Handle());
auto sid_value = sids_table_[sid];
tvm::PrimExpr set_tensor =
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
{sid_array, 0, tir::builtin::kArrData, sid_value});
stmts_.push_back(tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor)));

if (!use_unpacked_api_) {
tvm::PrimExpr set_tensor =
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
{sid_array, 0, tir::builtin::kArrData, sid_value});
stmts_.push_back(
tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor)));
} else {
stmts_.push_back(tir::LetStmt(sid_array, sid_value, tir::Evaluate(0)));
}

sid_vars.push_back(sid_array);
}
return sid_vars;
Expand All @@ -161,16 +168,16 @@ class AOTExecutorCodegen : public ExprVisitor {
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
{tir::StringImm(params_by_expr_[expr])});

tvm::PrimExpr set_param_array =
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
{param_array, 0, tir::builtin::kArrData, param_handle});
lookup_call.push_back(tir::Evaluate(set_param_array));

tir::Stmt lookup_body = tir::SeqStmt(lookup_call);
if (!use_unpacked_api_) {
tvm::PrimExpr set_param_array =
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
{param_array, 0, tir::builtin::kArrData, param_handle});
stmts_.push_back(
tir::LetStmt(param_array, StackAlloca("arg_value", 1), tir::Evaluate(set_param_array)));
} else {
stmts_.push_back(tir::LetStmt(param_array, param_handle, tir::Evaluate(0)));
}

// Allocate the DLTensors on the stack
lookup_body = tir::LetStmt(param_array, StackAlloca("arg_value", 1), lookup_body);
stmts_.push_back(lookup_body);
return param_array;
}

Expand Down Expand Up @@ -206,15 +213,20 @@ class AOTExecutorCodegen : public ExprVisitor {
}

auto ret_expr = Downcast<Expr>(call);

// Pack the return(s) value. A call node can produce multiple outputs
for (const auto& var : PackSid(ret_expr)) {
args.push_back(var);
}

// Use tvm_call_packed to execute the function
create_func_call_stmts.push_back(tir::Evaluate(
tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_cpacked(), args)));
// Use tvm_call_packed to execute the function unless we're calling directly
auto calling_pattern = tvm::tir::builtin::tvm_call_cpacked();
if (use_unpacked_api_) {
calling_pattern = tvm::tir::builtin::call_extern();
}

create_func_call_stmts.push_back(
tir::Evaluate(tvm::tir::Call(DataType::Int(32), calling_pattern, args)));

tir::Stmt body = tir::SeqStmt(create_func_call_stmts);
stmts_.push_back(body);
}
Expand All @@ -226,16 +238,20 @@ class AOTExecutorCodegen : public ExprVisitor {
* copy-on-write fashion.
*/
void CopyToOutput(te::Var out, te::Var in, size_t size) {
auto retval_get = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
{in, 0, tir::builtin::kArrData});

// Define intermediate DLTensor to load/store the data
auto tmp0 = te::Var("tmp0", DataType::Handle());
auto tmp1 = te::Var("tmp1", DataType::Handle());
te::Var loop_idx("i", DataType::Int(32));
auto retval_i = tir::Load(DataType::UInt(8), tmp0, loop_idx, tir::const_true());
auto tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
{out, 0, tir::builtin::kArrData});

PrimExpr retval_get = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
{in, 0, tir::builtin::kArrData});
PrimExpr tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
{out, 0, tir::builtin::kArrData});
if (use_unpacked_api_) {
retval_get = in;
tostore = out;
}

// Copy the variable from the input to the output
tir::Stmt copy = tir::For(
Expand Down Expand Up @@ -535,6 +551,15 @@ class AOTExecutorCodegen : public ExprVisitor {
TargetsMap targets_;
/*! \brief target host */
Target target_host_;
/*!
* \brief unpacked api toggle
* When set to true the code generated will use unpacked calls to functions:
* func(void* arg0, void* arg1)
* Rather than packed calls:
* func(void* args)
* Defaults to using the packed calling convention
*/
Bool use_unpacked_api_;

/*!
* \brief parameters (i.e. ConstantNodes found in the graph).
Expand Down Expand Up @@ -564,21 +589,20 @@ class AOTExecutorCodegen : public ExprVisitor {

public:
AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host)
: mod_(mod), return_sid_() {
compile_engine_ = CompileEngine::Global();
targets_ = targets;
target_host_ = target_host;
}
: mod_(mod),
targets_(targets),
target_host_(target_host),
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))),
compile_engine_(CompileEngine::Global()) {}

LoweredOutput Codegen(relay::Function func) {
// Get the module, storage map and token sizes
auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
storage_device_map_ = (*pf)(func);

int input_index = 0;
for (auto input : func->params) {
input_vars_.push_back(input);
main_signature_.push_back(tir::Var(MakeString("input_", input_index), DataType::Handle()));
main_signature_.push_back(tir::Var("input", DataType::Handle()));
}

// Define the storage allocator ids
Expand All @@ -592,7 +616,7 @@ class AOTExecutorCodegen : public ExprVisitor {
// Find the return sid
return_sid_ = AotReturnSidVisitor(storage_device_map_).FindReturnSid(func);
for (unsigned int output_index = 0; output_index < return_sid_.size(); output_index++) {
main_signature_.push_back(tir::Var(MakeString("output_", output_index), DataType::Handle()));
main_signature_.push_back(tir::Var("output", DataType::Handle()));
}

VisitExpr(func->body);
Expand Down
50 changes: 46 additions & 4 deletions src/target/source/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,59 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
<< "}\n";
}

void GenerateEntrypointForUnpackedAPI() {
  // Forward-declare the unpacked operator entrypoint: one untyped (void*)
  // data pointer per input and per output tensor.
  //
  // Fixes vs. previous revision:
  //  * parameters were emitted as bare identifiers ("arg0,arg1"), which is
  //    only valid in a K&R-style *definition*, not a prototype — emit
  //    "void* argN" instead so the generated C compiles cleanly;
  //  * the forwarding call emitted a comma after every input, leaving a
  //    trailing comma when num_outputs == 0 — use a single loop with one
  //    separator rule for all arguments.
  code_ << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_run_func_prefix << "(";
  int total_args = (metadata_->num_inputs + metadata_->num_outputs);
  for (int i = 0; i < total_args; ++i) {
    code_ << "void* arg" << i;
    if (i + 1 != total_args) {
      code_ << ",";
    }
  }
  code_ << ");\n";
  // Emit a wrapper with the packed signature that unwraps each DLTensor
  // handle from the TVMValue array and forwards the raw data pointers.
  code_ << "static int32_t " << ::tvm::runtime::symbol::tvm_module_main;
  code_ << "(void* args, void* type_code, int num_args, void* out_value, void* "
           "out_type_code, void* resource_handle) {\n";
  code_ << "return " << ::tvm::runtime::symbol::tvm_run_func_prefix << "(";
  for (int i = 0; i < total_args; ++i) {
    code_ << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data";
    if (i + 1 != total_args) {
      code_ << ",";
    }
  }
  code_ << ");\n";
  code_ << "}\n";
}

void GenerateEntrypointForPackedAPI() {
  // Declare the packed run function, then emit a static entrypoint with the
  // same packed signature that forwards every argument to it unchanged.
  // Both signatures share the exact same parameter list, so spell it once.
  const char* packed_params =
      "(void* args, void* type_code, int num_args, void* out_value, void* "
      "out_type_code, void* resource_handle)";
  code_ << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_run_func_prefix << packed_params
        << ";\n";
  code_ << "static int32_t " << ::tvm::runtime::symbol::tvm_module_main << packed_params
        << " {\n";
  code_ << "return " << ::tvm::runtime::symbol::tvm_run_func_prefix;
  code_ << "(args, type_code, num_args, out_value, out_type_code, resource_handle);\n";
  code_ << "}\n";
}

void GenerateAOTDescriptor() {
code_ << "#include \"tvm/runtime/crt/internal/aot_executor/aot_executor.h\"\n";
code_ << "#include \"tvm/runtime/c_runtime_api.h\"\n";
code_ << "#ifdef __cplusplus\n";
code_ << "extern \"C\"\n";
code_ << "#endif\n";
code_ << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_run_func_prefix;
code_ << "(void* args, void* type_code, int num_args, void* out_value, void* "
"out_type_code, void* resource_handle);\n";
if (target_->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
GenerateEntrypointForUnpackedAPI();
} else {
GenerateEntrypointForPackedAPI();
}
code_ << "const tvm_model_t network = {\n"
<< " .run_func = &" << ::tvm::runtime::symbol::tvm_run_func_prefix << ",\n"
<< " .run_func = &" << ::tvm::runtime::symbol::tvm_module_main << ",\n"
<< " .num_input_tensors = " << metadata_->num_inputs << ",\n"
<< " .num_output_tensors = " << metadata_->num_outputs << ", \n"
<< "};\n";
Expand Down
2 changes: 2 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
.add_attr_option<Bool>("system-lib")
.add_attr_option<String>("runtime")
.add_attr_option<Bool>("link-params", Bool(false))
.add_attr_option<Bool>("unpacked-api")
.set_default_keys({"cpu"});

TVM_REGISTER_TARGET_KIND("c", kDLCPU)
Expand All @@ -268,6 +269,7 @@ TVM_REGISTER_TARGET_KIND("c", kDLCPU)
.add_attr_option<String>("march")
.add_attr_option<String>("executor")
.add_attr_option<Integer>("workspace-byte-alignment")
.add_attr_option<Bool>("unpacked-api")
.set_default_keys({"cpu"});

TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
Expand Down
Loading