Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 47 additions & 9 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,14 +341,24 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) {
gv_func_map_[op->name] = InitContextPtr(ftype->getPointerTo(), "__" + op->name);
it = gv_func_map_.find(op->name);
}
return builder_->CreateCall(GetContextPtr(it->second), arg_values);
#if TVM_LLVM_VERSION >= 90
auto ext_callee = llvm::FunctionCallee(ftype, GetContextPtr(it->second));
#else
auto ext_callee = GetContextPtr(it->second);
#endif
return builder_->CreateCall(ext_callee, arg_values);
} else {
llvm::Function* f = module_->getFunction(op->name);
if (f == nullptr) {
f = llvm::Function::Create(
ftype, llvm::Function::ExternalLinkage, op->name, module_.get());
}
return builder_->CreateCall(f, arg_values);
#if TVM_LLVM_VERSION >= 90
auto ext_callee = llvm::FunctionCallee(f);
#else
auto ext_callee = f;
#endif
return builder_->CreateCall(ext_callee, arg_values);
}
}

Expand Down Expand Up @@ -524,9 +534,15 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
Array<Var> vfields = tir::UndefinedVars(body, {});
uint64_t nbytes;
llvm::Value* cdata = PackClosureData(vfields, &nbytes);
#if TVM_LLVM_VERSION >= 90
auto launch_callee = llvm::FunctionCallee(
ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch());
#else
auto launch_callee = RuntimeTVMParallelLaunch();
#endif
BasicBlock* par_launch_end = CheckCallSuccess(
builder_->CreateCall(
RuntimeTVMParallelLaunch(),
launch_callee,
{f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)}));
// Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
Expand Down Expand Up @@ -670,8 +686,14 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
ctx->setMetadata(
"tbaa",
md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
#if TVM_LLVM_VERSION >= 90
auto env_callee = llvm::FunctionCallee(
ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv());
#else
auto env_callee = RuntimeTVMGetFuncFromEnv();
#endif
llvm::Value* retcode = builder_->CreateCall(
RuntimeTVMGetFuncFromEnv(), {ctx, GetConstString(fname), out});
env_callee, {ctx, GetConstString(fname), out});
init_block = CheckCallSuccess(retcode);
#if TVM_LLVM_VERSION >= 110
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, llvm::Align(align));
Expand Down Expand Up @@ -710,9 +732,14 @@ CodeGenCPU::MakeCallPacked(const Array<PrimExpr> &args, llvm::Value **rvalue,
builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
ConstInt32(end));
*ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end));
#if TVM_LLVM_VERSION >= 90
auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall());
#else
auto call_callee = RuntimeTVMFuncCall();
#endif
BasicBlock *end_block = CheckCallSuccess(builder_->CreateCall(
RuntimeTVMFuncCall(), {handle, arg_value, arg_tcode, ConstInt32(nargs),
ret_value, *ret_tcode}));
call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs),
ret_value, *ret_tcode}));
DataType r_api_type = tir::APIType(r_type);
llvm::Value* load_ptr = builder_->CreatePointerCast(
ret_value, DTypeToLLVMType(r_api_type)->getPointerTo());
Expand Down Expand Up @@ -890,7 +917,13 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) {
builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_);
// fail condition.
builder_->SetInsertPoint(fail_block);
builder_->CreateCall(RuntimeTVMAPISetLastError(), {msg});
#if TVM_LLVM_VERSION >= 90
auto err_callee = llvm::FunctionCallee(
ftype_tvm_api_set_last_error_, RuntimeTVMAPISetLastError());
#else
auto err_callee = RuntimeTVMAPISetLastError();
#endif
builder_->CreateCall(err_callee, {msg});
builder_->CreateRet(ConstInt32(-1));
// otherwise set it to be new end.
builder_->SetInsertPoint(end_block);
Expand All @@ -917,9 +950,14 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
<< "Cannot not place within parallel loop as the workload may differ, "
<< " place it between parallel and parallel_launch_point";
this->VisitStmt(op->body);
#if TVM_LLVM_VERSION >= 90
auto bar_callee = llvm::FunctionCallee(
ftype_tvm_parallel_barrier_, RuntimeTVMParallelBarrier());
#else
auto bar_callee = RuntimeTVMParallelBarrier();
#endif
builder_->CreateCall(
RuntimeTVMParallelBarrier(),
{MakeValue(parallel_env_.task_id), parallel_env_.penv});
bar_callee, {MakeValue(parallel_env_.task_id), parallel_env_.penv});
} else if (op->attr_key == tir::attr::pragma_import_llvm) {
const StringImmNode* value = op->value.as<StringImmNode>();
CHECK(value != nullptr);
Expand Down