diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 2e5a4bc23bd5..c08081405648 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -51,7 +51,7 @@ #include "../../runtime/rocm/rocm_module.h" #include "../build_common.h" #include "codegen_llvm.h" -#include "llvm_common.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -238,27 +238,25 @@ class CodeGenAMDGPU : public CodeGenLLVM { } protected: - void InitTarget(llvm::TargetMachine* tm) final { + void InitTarget() final { // Maximum vector lane = float4 native_vector_bits_ = 4 * 32; - CodeGenLLVM::InitTarget(tm); + CodeGenLLVM::InitTarget(); } }; runtime::Module BuildAMDGPU(IRModule mod, Target target) { + LLVMInstance llvm_instance; + + With llvm_target(llvm_instance, target); #if TVM_LLVM_VERSION < 90 LOG(FATAL) << "AMDGPU backend requires at least LLVM 9"; // Lower versions will crash when loading the bitcode, see // issue #4087 for a discussion #endif - InitializeLLVM(); - std::unique_ptr tm = GetLLVMTargetMachine(target); - std::unique_ptr ctx(new llvm::LLVMContext()); - // careful: cg will hold a naked pointer reference to ctx, so it should - // have a shorter lifetime than the ctx. std::unique_ptr cg(new CodeGenAMDGPU()); - cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false, false); + cg->Init("TVMAMDGPUModule", llvm_target.get(), false, false, false); cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) { ICHECK(kv.second->template IsInstance()) @@ -266,20 +264,15 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) { return Downcast(kv.second); }); + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path"); Array bitcode_files = (*find_rocm_bitcodes)(); for (auto& bitcode_path : bitcode_files) { - std::string path = bitcode_path; - llvm::SMDiagnostic err; - std::unique_ptr mlib = llvm::parseIRFile(path, err, *ctx); - if (mlib.get() == nullptr) { - std::string msg(err.getMessage()); - LOG(FATAL) << "Fail to load bitcode file " << path << "\n" - << "line " << err.getLineNo() << ":" << msg; - } - mlib->setTargetTriple(tm->getTargetTriple().str()); + std::unique_ptr mlib = llvm_instance.LoadIR(bitcode_path); + mlib->setTargetTriple(llvm_target->GetTargetTriple()); mlib->setDataLayout(tm->createDataLayout()); + for (llvm::Function& f : mlib->functions()) { f.addFnAttr(llvm::Attribute::AlwaysInline); } @@ -351,4 +344,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_rocm") } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index f5ce0d550b1f..15d1699b3b59 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -42,10 +42,10 @@ class CodeGenARM final : public CodeGenCPU { CodeGenARM() = default; virtual ~CodeGenARM() = default; - void InitTarget(llvm::TargetMachine* tm) final { + void InitTarget() final { // set native vector bits. native_vector_bits_ = 16 * 8; - CodeGenCPU::InitTarget(tm); + CodeGenCPU::InitTarget(); } llvm::Value* CreateIntrinsic(const CallNode* op) override; @@ -139,4 +139,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_blob.cc b/src/target/llvm/codegen_blob.cc index 8e6041b4c970..b67aac480654 100644 --- a/src/target/llvm/codegen_blob.cc +++ b/src/target/llvm/codegen_blob.cc @@ -52,25 +52,20 @@ #include #include -#include "llvm_common.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { -std::pair, std::shared_ptr> CodeGenBlob( - const std::string& data, bool system_lib, const std::string& llvm_target_string) { - InitializeLLVM(); - Target target(llvm_target_string); - auto tm = GetLLVMTargetMachine(target); - auto triple = tm->getTargetTriple(); - auto ctx = std::make_shared(); +std::unique_ptr CodeGenBlob(const std::string& data, bool system_lib, + LLVMTarget* llvm_target) { + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); + const llvm::Triple& triple = tm->getTargetTriple(); + llvm::LLVMContext* ctx = llvm_target->GetContext(); std::string module_name = "devc"; - std::unique_ptr module(new llvm::Module(module_name, *ctx)); + auto module = std::make_unique(module_name, *ctx); module->setTargetTriple(triple.str()); - // Store full target string in metadata, because flags such as -mfloat-abi must be preserved for - // ModulePackImportsToLLVM. - module->addModuleFlag(llvm::Module::ModFlagBehavior::Override, "tvm_target", - llvm::MDString::get(*ctx, LLVMTargetToString(target))); + llvm_target->SetTargetMetadata(module.get()); module->setDataLayout(tm->createDataLayout()); auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false); auto* tvm_dev_mblob = new llvm::GlobalVariable( @@ -188,9 +183,10 @@ std::pair, std::shared_ptr> Cod ir_builder.CreateRetVoid(); } - return std::make_pair(std::move(module), ctx); + return module; } } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_blob.h b/src/target/llvm/codegen_blob.h index 46c037a30af2..a06c043c07b1 100644 --- a/src/target/llvm/codegen_blob.h +++ b/src/target/llvm/codegen_blob.h @@ -26,15 +26,18 @@ #ifdef TVM_LLVM_VERSION -#include -#include - #include #include -#include + +namespace llvm { +class Module; +} namespace tvm { namespace codegen { + +class LLVMTarget; + /** * \brief Code Generation of blob data * @@ -44,8 +47,8 @@ namespace codegen { * * \return LLVM module and LLVM context */ -std::pair, std::shared_ptr> CodeGenBlob( - const std::string& data, bool system_lib, const std::string& llvm_target_string); +std::unique_ptr CodeGenBlob(const std::string& data, bool system_lib, + LLVMTarget* llvm_target); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index f2ce6fb848b4..c4aed1a237dd 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -60,6 +60,7 @@ #include "../func_registry_generator.h" #include "../metadata_utils.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -69,10 +70,9 @@ namespace codegen { CodeGenCPU::CodeGenCPU() = default; CodeGenCPU::~CodeGenCPU() = default; -void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup, - bool target_c_runtime) { - CodeGenLLVM::Init(module_name, tm, ctx, system_lib, dynamic_lookup, target_c_runtime); +void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime) { + CodeGenLLVM::Init(module_name, llvm_target, system_lib, dynamic_lookup, target_c_runtime); dbg_info_ = CreateDebugInfo(module_.get()); static_assert(sizeof(TVMValue) == sizeof(double), "invariant"); func_handle_map_.clear(); @@ -80,7 +80,8 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, // Runtime types. - t_tvm_shape_index_ = llvm::Type::getIntNTy(*ctx, DataType::ShapeIndex().bits()); + t_tvm_shape_index_ = + llvm::Type::getIntNTy(*llvm_target_->GetContext(), DataType::ShapeIndex().bits()); // Defined in 3rdparty/dlpack/include/dlpack/dlpack.h: // typedef struct { DLDeviceType device_type; int device_id; } DLDevice; t_tvm_device_ = llvm::StructType::create({t_int_, t_int_}); @@ -177,7 +178,7 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::Function::Create(ftype_tvm_parallel_barrier_, llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get()); } - this->InitGlobalContext(dynamic_lookup); + InitGlobalContext(dynamic_lookup); target_c_runtime_ = target_c_runtime; is_system_lib_ = system_lib; } @@ -240,6 +241,7 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { } llvm::DebugLoc DL; builder.SetCurrentDebugLocation(DL); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); for (size_t i = 0; i < f_llvm->arg_size(); ++i) { auto* paramAlloca = builder.CreateAlloca(f_llvm->getFunctionType()->getParamType(i)); std::string paramName = "arg" + std::to_string(i + 1); @@ -248,7 +250,7 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { GetDebugType(GetType(f_tir->params[i]), f_llvm->getFunctionType()->getParamType(i)), /*alwaysPreserve=*/true); auto* store = builder.CreateStore(f_llvm->arg_begin() + i, paramAlloca); - auto* di_loc = llvm::DILocation::get(*ctx_, 0, 0, DIFunction); + auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, DIFunction); dbg_info_->di_builder_->insertDeclare(paramAlloca, param, dbg_info_->di_builder_->createExpression(), llvm::DebugLoc(di_loc), store); @@ -263,7 +265,7 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { if (I.getDebugLoc()) { continue; } - auto* di_loc = llvm::DILocation::get(*ctx_, 0, 0, scope); + auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, scope); I.setDebugLoc(llvm::DebugLoc(di_loc)); } } @@ -273,7 +275,7 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { llvm::DIType* CodeGenCPU::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) { if (ty_llvm == t_void_) { return nullptr; - } else if (ty_llvm == llvm::Type::getFloatTy(*ctx_)) { + } else if (ty_llvm == llvm::Type::getFloatTy(*llvm_target_->GetContext())) { return dbg_info_->di_builder_->createBasicType("float", 32, llvm::dwarf::DW_ATE_float); } else if (ty_llvm == t_int8_) { return dbg_info_->di_builder_->createBasicType("int8", 8, llvm::dwarf::DW_ATE_signed); @@ -311,13 +313,14 @@ void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { #endif // comdat is needed for windows select any linking to work // set comdat to Any(weak linking) - if (target_machine_->getTargetTriple().isOSWindows()) { + if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) { llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_module_main); comdat->setSelectionKind(llvm::Comdat::Any); global->setComdat(comdat); } - global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, entry_func_name)); + global->setInitializer( + llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), entry_func_name)); global->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass); } @@ -475,7 +478,7 @@ llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string gv->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); // comdat is needed for windows select any linking to work // set comdat to Any(weak linking) - if (target_machine_->getTargetTriple().isOSWindows()) { + if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) { llvm::Comdat* comdat = module_->getOrInsertComdat(name); comdat->setSelectionKind(llvm::Comdat::Any); gv->setComdat(comdat); @@ -525,8 +528,9 @@ void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) { // create emit codes that checks and load the function. - auto* fail_block = llvm::BasicBlock::Create(*ctx_, "call_fail", function_); - auto* end_block = llvm::BasicBlock::Create(*ctx_, "call_end", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* fail_block = llvm::BasicBlock::Create(*ctx, "call_fail", function_); + auto* end_block = llvm::BasicBlock::Create(*ctx, "call_end", function_); auto* succ = builder_->CreateICmpEQ(retcode, llvm::ConstantInt::get(t_int_, 0)); builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_); builder_->SetInsertPoint(fail_block); @@ -584,6 +588,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { SetTargetAttributes(fcompute); llvm::BasicBlock* compute_call_end = CheckCallSuccess(builder_->CreateCall(fcompute, arg_values)); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); // enter compute scope and setup compute function. With scope_states_guard(this); size_t idx = 0; @@ -607,7 +612,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { if (f != alloc_storage_info_.end()) { unsigned align = f->second.alignment; if (align > 1) { - auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align); + auto attr = llvm::Attribute::get(*ctx, llvm::Attribute::Alignment, align); fcompute->addParamAttr(idx, attr); } } @@ -615,7 +620,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { } function_ = fcompute; - auto* compute_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + auto* compute_entry = llvm::BasicBlock::Create(*ctx, "entry", function_); builder_->SetInsertPoint(compute_entry); this->VisitStmt(op->body); builder_->CreateRet(ConstInt32(0)); @@ -679,7 +684,8 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task, std::strin launch_callee, {f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(num_task)})); // Setup the closure function. - auto* lambda_entry = llvm::BasicBlock::Create(*ctx_, "parallel_closure_entry", f); + auto* lambda_entry = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "parallel_closure_entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); llvm::Value* task_id = &(*it++); @@ -747,7 +753,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod llvm::BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall( finit, {gv, f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(nbytes)})); // Setup the closure function. - auto* lambda_entry = llvm::BasicBlock::Create(*ctx_, "entry", f); + auto* lambda_entry = llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType()); @@ -793,9 +799,10 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { hptr = it->second; } // create emit codes that checks and load the function. + llvm::LLVMContext* ctx = llvm_target_->GetContext(); llvm::BasicBlock* pre_block = builder_->GetInsertBlock(); - auto* init_block = llvm::BasicBlock::Create(*ctx_, "handle_init", function_); - auto* end_block = llvm::BasicBlock::Create(*ctx_, "handle_init_end", function_); + auto* init_block = llvm::BasicBlock::Create(*ctx, "handle_init", function_); + auto* end_block = llvm::BasicBlock::Create(*ctx, "handle_init_end", function_); #if TVM_LLVM_VERSION >= 110 llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, llvm::Align(align)); #elif TVM_LLVM_VERSION >= 80 @@ -811,22 +818,22 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { llvm::Value* out = WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); }); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, - llvm::Align(gv_mod_ctx_->getAlignment())); + llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + llvm::Align(gv_mod_ctx_->getAlignment())); #elif TVM_LLVM_VERSION >= 80 - llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, - gv_mod_ctx_->getAlignment()); + llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + gv_mod_ctx_->getAlignment()); #else - llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment()); + llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment()); #endif - ctx->setMetadata("tbaa", - md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); + ctx_load->setMetadata( + "tbaa", md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); #if TVM_LLVM_VERSION >= 90 auto env_callee = llvm::FunctionCallee(ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv()); #else auto env_callee = RuntimeTVMGetFuncFromEnv(); #endif - llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, GetConstString(fname), out}); + llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx_load, GetConstString(fname), out}); init_block = CheckCallSuccess(retcode); #if TVM_LLVM_VERSION >= 110 llvm::Value* loaded_handle = @@ -946,13 +953,14 @@ llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { ICHECK_EQ(op->args.size(), 6U); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, op->args[4].as()->value, true); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); // Get traced value. llvm::Value* traced_value = MakeValue(op->args[5]); // The update_block handles case when we need to update the return value. - llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx_, "update_block", function_); + llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx, "update_block", function_); // The continue_block handles case when we need to return original // traced value. - llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx_, "continue_block", function_); + llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx, "continue_block", function_); // Check the ret_type_code and create cmp instruction. llvm::Value* cmp = @@ -1254,14 +1262,15 @@ class MetadataSerializerLLVM : public AttrVisitor { }; void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { + llvm::LLVMContext* ctx = llvm_target_->GetContext(); MetadataLlvmTypes llvm_types{ t_float64_ /* t_float64 */, - llvm::Type::getInt8Ty(*ctx_) /* t_uint8 */, + llvm::Type::getInt8Ty(*ctx) /* t_uint8 */, t_int64_ /* t_int64 */, - llvm::Type::getInt8Ty(*ctx_) /* t_bool */, + llvm::Type::getInt8Ty(*ctx) /* t_bool */, t_char_->getPointerTo() /* t_cstring */, t_void_p_ /* t_void_p */, - llvm::StructType::create(*ctx_, {t_int8_, t_int8_, t_int8_}, "DLDataType") /* t_data_type */, + llvm::StructType::create(*ctx, {t_int8_, t_int8_, t_int8_}, "DLDataType") /* t_data_type */, }; // create sample ConstantInfoMetadata instance for MetadataTypeDefiner @@ -1278,7 +1287,7 @@ void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { metadata::DiscoverComplexTypesVisitor discover_complex{&queue}; discover_complex.Discover(metadata); - MetadataTypeDefiner definer{ctx_, &llvm_types}; + MetadataTypeDefiner definer{ctx, &llvm_types}; for (auto md : queue) { if (md.defined()) { definer.DefineType(md); @@ -1295,7 +1304,7 @@ void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { function_->setCallingConv(llvm::CallingConv::C); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); - llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx, "entry", function_); builder_->SetInsertPoint(entry_point_entry); auto ret_values_p = builder_->CreateBitCast(GetArg(function_, 3), t_void_p_->getPointerTo()); @@ -1350,7 +1359,8 @@ void CodeGenCPU::DefineFunctionRegistry(Array func_names) { function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, "TVMSystemLibEntryPoint", module_.get()); SetTargetAttributes(function_); - llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + llvm::BasicBlock* entry_point_entry = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", function_); builder_->SetInsertPoint(entry_point_entry); builder_->CreateRet(builder_->CreateBitCast(module, t_void_p_)); } @@ -1361,7 +1371,8 @@ void CodeGenCPU::AddStartupFunction() { function_ = llvm::Function::Create(ftype, llvm::Function::InternalLinkage, "__tvm_module_startup", module_.get()); SetTargetAttributes(function_); - llvm::BasicBlock* startup_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + llvm::BasicBlock* startup_entry = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", function_); builder_->SetInsertPoint(startup_entry); for (const auto& kv : export_system_symbols_) { llvm::Value* name = GetConstString(kv.first); @@ -1385,7 +1396,8 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::tvm_throw_last_error())) { builder_->CreateRet(ConstInt32(-1)); auto next_block = std::next(builder_->GetInsertBlock()->getIterator()); - llvm::BasicBlock* new_bb = llvm::BasicBlock::Create(*ctx_, "cont", function_, &*next_block); + llvm::BasicBlock* new_bb = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "cont", function_, &*next_block); builder_->SetInsertPoint(new_bb); return ConstInt32(-1); } else if (op->op.same_as(builtin::tvm_struct_get())) { @@ -1443,8 +1455,9 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { os << ", " << op->message.as()->value; } llvm::Value* msg = GetConstString(os.str()); - auto* fail_block = llvm::BasicBlock::Create(*ctx_, "assert_fail", function_); - auto* end_block = llvm::BasicBlock::Create(*ctx_, "assert_end", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* fail_block = llvm::BasicBlock::Create(*ctx, "assert_fail", function_); + auto* end_block = llvm::BasicBlock::Create(*ctx, "assert_end", function_); builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_); // fail condition. builder_->SetInsertPoint(fail_block); @@ -1549,4 +1562,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_cpu") } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index eec38b122a0b..e0716ac8be2d 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -24,6 +24,8 @@ #ifndef TVM_TARGET_LLVM_CODEGEN_CPU_H_ #define TVM_TARGET_LLVM_CODEGEN_CPU_H_ +#ifdef TVM_LLVM_VERSION + #include #include #include @@ -54,14 +56,16 @@ class Module; namespace tvm { namespace codegen { +class LLVMTarget; + // CPU host code generation class CodeGenCPU : public CodeGenLLVM { public: CodeGenCPU(); virtual ~CodeGenCPU(); - void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, - bool system_lib, bool dynamic_lookup, bool target_c_runtime) override; + void Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime) override; void AddFunction(const PrimFunc& f) override; void AddMainFunction(const std::string& entry_func_name) override; std::unique_ptr Finish() override; @@ -197,4 +201,6 @@ class CodeGenCPU : public CodeGenLLVM { } // namespace codegen } // namespace tvm + +#endif // TVM_LLVM_VERSION #endif // TVM_TARGET_LLVM_CODEGEN_CPU_H_ diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index cab77697164d..1b9233d2ad2f 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -62,7 +62,7 @@ #include "../../runtime/hexagon/hexagon_module.h" #include "../build_common.h" #include "codegen_cpu.h" -#include "llvm_common.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -70,9 +70,9 @@ namespace codegen { // Hexagon code generation class CodeGenHexagon final : public CodeGenCPU { public: - void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, - bool system_lib, bool dynamic_lookup, bool target_c_runtime) override; - void InitTarget(llvm::TargetMachine* tm) final; + void Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime) override; + void InitTarget() final; using CodeGenCPU::VisitStmt_; llvm::Value* VisitExpr_(const BufferLoadNode* op) override; @@ -117,29 +117,30 @@ class CodeGenHexagon final : public CodeGenCPU { llvm::Value* Intrinsic(llvm::Intrinsic::ID, llvm::ArrayRef args); }; -void CodeGenHexagon::Init(const std::string& module_name, llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup, - bool target_c_runtime) { - CodeGenCPU::Init(module_name, tm, ctx, system_lib, dynamic_lookup, target_c_runtime); +void CodeGenHexagon::Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime) { + CodeGenCPU::Init(module_name, llvm_target, system_lib, dynamic_lookup, target_c_runtime); } -void CodeGenHexagon::InitTarget(llvm::TargetMachine* tm) { - native_vector_bits_ = 64; // Assume "scalar" vectors at first. - llvm::StringRef fs = tm->getTargetFeatureString(); - size_t npos = llvm::StringRef::npos; +void CodeGenHexagon::InitTarget() { + native_vector_bits_ = 64; // Assume "scalar" vectors at first. const auto hvx_length_feature = "+hvx-length"; // +hvx-length{64|128}b - size_t len_begin = fs.find(hvx_length_feature); - size_t len_end = len_begin != npos ? fs.find('b', len_begin) : npos; - if (len_end != npos) { + for (const std::string& f : llvm_target_->GetTargetFeatures()) { + llvm::StringRef fs(f); + if (!fs.startswith(hvx_length_feature)) continue; + + ICHECK(fs.endswith("b")) << "malformed target feature: " << f; int hvx_bytes = 0; - len_begin += std::strlen(hvx_length_feature); - ICHECK(!fs.substr(len_begin, len_end - len_begin).getAsInteger(10, hvx_bytes)) - << "invalid HVX length in feature string: " << fs.str(); + size_t len_begin = std::strlen(hvx_length_feature); + ICHECK(!fs.substr(len_begin, fs.size() - len_begin - 1).getAsInteger(10, hvx_bytes)) + << "invalid HVX length in feature string: " << f; ICHECK(hvx_bytes == 64 || hvx_bytes == 128) << "invalid HVX vector length: " << hvx_bytes << ", should be 64 or 128"; native_vector_bits_ = hvx_bytes * 8; + // There should only be one hvx-length... + break; } - CodeGenLLVM::InitTarget(tm); + CodeGenLLVM::InitTarget(); } llvm::Value* CodeGenHexagon::CreateCallExternQHL(Type ret_type, String global_symbol, @@ -510,9 +511,8 @@ void ProcessLLVMOptions(const std::vector& llvm_vec) { } // namespace runtime::Module BuildHexagon(IRModule mod, Target target) { - // Make sure all targets are registered. InitializeLLVM can be called - // multiple times, after the first call all subsequent calls are no-ops. - InitializeLLVM(); + LLVMInstance llvm_instance; + With llvm_target(llvm_instance, target); auto split = [](const std::string& str, char delim = ' ') { std::vector vec; @@ -552,8 +552,6 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { static bool CallOnce = (ProcessLLVMOptions(llvm_options_vec), true); (void)CallOnce; - std::unique_ptr tm = GetLLVMTargetMachine(target); - std::unique_ptr ctx(new llvm::LLVMContext()); std::unique_ptr cg(new CodeGenHexagon()); std::vector funcs; @@ -574,7 +572,7 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { funcs.emplace_back(f); } - cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false); + cg->Init("TVMHexagonModule", llvm_target.get(), false, false, false); cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); if (entry_func.length() != 0) { cg->AddMainFunction(entry_func); @@ -586,7 +584,7 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { enum CodeGenFileType { Asm, Obj, IR, BC }; - auto EmitToString = [&tm](const llvm::Module& m, CodeGenFileType cgft) { + auto EmitToString = [&llvm_target](const llvm::Module& m, CodeGenFileType cgft) { std::string out; if (cgft == IR || cgft == BC) { @@ -607,6 +605,7 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { llvm::raw_svector_ostream os(ss); std::unique_ptr cm = llvm::CloneModule(m); llvm::legacy::PassManager pass; + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); ICHECK(tm->addPassesToEmitFile(pass, os, nullptr, ft) == 0) << "Cannot emit target code"; pass.run(*cm.get()); out.assign(ss.c_str(), ss.size()); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index f1d891e2c3bd..305358d079d0 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -89,7 +89,7 @@ #include "../build_common.h" #include "../func_registry_generator.h" #include "codegen_params.h" -#include "llvm_common.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -102,8 +102,8 @@ CodeGenLLVM::CodeGenLLVM() = default; CodeGenLLVM::~CodeGenLLVM() = default; CodeGenLLVM::DebugInfo::~DebugInfo() = default; -std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine* tm) { - std::string target = tm->getTarget().getName(); +std::unique_ptr CodeGenLLVM::Create(LLVMTarget* llvm_target) { + std::string target = llvm_target->GetOrCreateTargetMachine()->getTarget().getName(); std::string factory_template = "tvm.codegen.llvm.target_"; void* handle = nullptr; if (const PackedFunc* f = runtime::Registry::Get(factory_template + target)) { @@ -121,38 +121,37 @@ std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine* tm) { } } -void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup, - bool target_c_runtime) { - InitializeLLVM(); - ctx_ = ctx; - builder_.reset(new IRBuilder(*ctx_)); - module_.reset(new llvm::Module(module_name, *ctx_)); - md_builder_.reset(new llvm::MDBuilder(*ctx_)); +void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime) { + llvm_target_ = llvm_target; + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + builder_.reset(new IRBuilder(*ctx)); + module_.reset(new llvm::Module(module_name, *ctx)); + md_builder_.reset(new llvm::MDBuilder(*ctx)); // types - t_void_ = llvm::Type::getVoidTy(*ctx_); - t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo(GetGlobalAddressSpace()); - t_int_ = llvm::Type::getInt32Ty(*ctx_); - t_char_ = llvm::Type::getInt8Ty(*ctx_); - t_int8_ = llvm::Type::getInt8Ty(*ctx_); - t_int16_ = llvm::Type::getInt16Ty(*ctx_); - t_int32_ = llvm::Type::getInt32Ty(*ctx_); - t_int64_ = llvm::Type::getInt64Ty(*ctx_); - t_float64_ = llvm::Type::getDoubleTy(*ctx_); + t_void_ = llvm::Type::getVoidTy(*ctx); + t_void_p_ = llvm::Type::getInt8Ty(*ctx)->getPointerTo(GetGlobalAddressSpace()); + t_int_ = llvm::Type::getInt32Ty(*ctx); + t_char_ = llvm::Type::getInt8Ty(*ctx); + t_int8_ = llvm::Type::getInt8Ty(*ctx); + t_int16_ = llvm::Type::getInt16Ty(*ctx); + t_int32_ = llvm::Type::getInt32Ty(*ctx); + t_int64_ = llvm::Type::getInt64Ty(*ctx); + t_float64_ = llvm::Type::getDoubleTy(*ctx); // meta data md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1); md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa"); md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_); - this->InitTarget(tm); + InitTarget(); } -void CodeGenLLVM::SetFastMathFlag(llvm::FastMathFlags fmf) { builder_->setFastMathFlags(fmf); } +void CodeGenLLVM::SetFastMathFlags(llvm::FastMathFlags fmf) { builder_->setFastMathFlags(fmf); } -void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { +void CodeGenLLVM::InitTarget() { + llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine(); module_->setTargetTriple(tm->getTargetTriple().str()); module_->setDataLayout(tm->createDataLayout()); data_layout_.reset(new llvm::DataLayout(module_.get())); - target_machine_ = tm; if (native_vector_bits_ == 0) { const auto& arch = tm->getTargetTriple().getArch(); if (arch == llvm::Triple::x86_64) { @@ -230,7 +229,8 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { } } } - llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx, "entry", function_); builder_->SetInsertPoint(entry); this->VisitStmt(f->body); @@ -242,7 +242,7 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { if (f != alloc_storage_info_.end()) { unsigned align = f->second.alignment; if (align > 1) { - auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align); + auto attr = llvm::Attribute::get(*ctx, llvm::Attribute::Alignment, align); function_->addParamAttr(i, attr); } } @@ -269,28 +269,16 @@ std::unique_ptr CodeGenLLVM::Finish() { } void CodeGenLLVM::HandleImport(const std::string& code) { + llvm::StringRef code_str(code); std::unique_ptr mlib; - llvm::SMDiagnostic err; - if (code.length() >= 3 && - (code.substr(code.length() - 3) == ".ll" || code.substr(code.length() - 3) == ".bc")) { - mlib = llvm::parseIRFile(code, err, *ctx_); - if (mlib.get() == nullptr) { - std::string msg = std::string(err.getMessage()); - LOG(FATAL) << "Fail to load bitcode file " << code << "\n" - << "line " << err.getLineNo() << ":" << msg; - } + if (code_str.endswith(".ll") || code_str.endswith(".bc")) { + mlib = llvm_target_->GetInstance().LoadIR(code); } else { - std::unique_ptr buf = llvm::MemoryBuffer::getMemBuffer(code); - mlib = llvm::parseIR(*buf, err, *ctx_); - if (mlib.get() == nullptr) { - std::string msg = std::string(err.getMessage()); - LOG(FATAL) << "Fail to load llvm ir " - << "line " << err.getLineNo() << ":" << msg << "\ncontent:\n" - << code; - } + mlib = llvm_target_->GetInstance().ParseIR(code); } - mlib->setTargetTriple(target_machine_->getTargetTriple().str()); - mlib->setDataLayout(target_machine_->createDataLayout()); + + mlib->setTargetTriple(llvm_target_->GetTargetTriple()); + mlib->setDataLayout(llvm_target_->GetOrCreateTargetMachine()->createDataLayout()); // mark all the functions as force inline for (llvm::Function& f : mlib->functions()) { f.removeFnAttr(llvm::Attribute::NoInline); @@ -338,16 +326,15 @@ void CodeGenLLVM::Optimize() { // pass manager FPassManager fpass(module_.get()); MPassManager mpass; - mpass.add(llvm::createTargetTransformInfoWrapperPass( - target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis())); - fpass.add(llvm::createTargetTransformInfoWrapperPass( - target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis())); + llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine(); + mpass.add(llvm::createTargetTransformInfoWrapperPass(tm->getTargetIRAnalysis())); + fpass.add(llvm::createTargetTransformInfoWrapperPass(tm->getTargetIRAnalysis())); // place optimization pass llvm::PassManagerBuilder builder; // Use the same opt-level as specified in TargetMachine for running passes - llvm::CodeGenOpt::Level opt_level = target_machine_->getOptLevel(); + llvm::CodeGenOpt::Level opt_level = llvm_target_->GetOptLevel(); switch (opt_level) { case llvm::CodeGenOpt::Level::None: @@ -376,7 +363,7 @@ void CodeGenLLVM::Optimize() { this->InitPassManagerBuilder(&builder); #if TVM_LLVM_VERSION >= 50 - target_machine_->adjustPassManager(builder); + tm->adjustPassManager(builder); #endif builder.populateFunctionPassManager(fpass); @@ -405,18 +392,19 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { return t_void_; } llvm::Type* etype = nullptr; + llvm::LLVMContext* ctx = llvm_target_->GetContext(); if (dtype.is_int() || dtype.is_uint()) { - etype = llvm::Type::getIntNTy(*ctx_, dtype.bits()); + etype = llvm::Type::getIntNTy(*ctx, dtype.bits()); } else if (dtype.is_float()) { switch (dtype.bits()) { case 16: - etype = llvm::Type::getHalfTy(*ctx_); + etype = llvm::Type::getHalfTy(*ctx); break; case 32: - etype = llvm::Type::getFloatTy(*ctx_); + etype = llvm::Type::getFloatTy(*ctx); break; case 64: - etype = llvm::Type::getDoubleTy(*ctx_); + etype = llvm::Type::getDoubleTy(*ctx); break; default: LOG(FATAL) << "do not support " << dtype; @@ -702,9 +690,10 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va const Var& loop_var, const Stmt& body) { llvm::BasicBlock* pre_block = builder_->GetInsertBlock(); std::string loop_var_name = loop_var->name_hint; - auto* for_begin = llvm::BasicBlock::Create(*ctx_, "for_begin_" + loop_var_name, function_); - auto* for_body = llvm::BasicBlock::Create(*ctx_, "for_body_" + loop_var_name, function_); - auto* for_end = llvm::BasicBlock::Create(*ctx_, "for_end_" + loop_var_name, function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* for_begin = llvm::BasicBlock::Create(*ctx, "for_begin_" + loop_var_name, function_); + auto* for_body = llvm::BasicBlock::Create(*ctx, "for_body_" + loop_var_name, function_); + auto* for_end = llvm::BasicBlock::Create(*ctx, "for_end_" + loop_var_name, function_); builder_->CreateBr(for_begin); builder_->SetInsertPoint(for_begin); llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2); @@ -777,7 +766,7 @@ llvm::Constant* CodeGenLLVM::GetGlobalConstant(llvm::Constant* const_data, const llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { auto it = str_map_.find(str); if (it != str_map_.end()) return it->second; - auto llvm_str = llvm::ConstantDataArray::getString(*ctx_, str); + auto llvm_str = llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), str); auto ptr = GetGlobalConstant(llvm_str, ".str", llvm::GlobalValue::PrivateLinkage); str_map_[str] = ptr; return ptr; @@ -950,11 +939,11 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type } void CodeGenLLVM::SetTargetAttributes(llvm::Function* func) { - llvm::StringRef cpu = target_machine_->getTargetCPU(); + const std::string& cpu = llvm_target_->GetCPU(); if (!cpu.empty()) { func->addFnAttr("target-cpu", cpu); } - llvm::StringRef features = target_machine_->getTargetFeatureString(); + const std::string& features = llvm_target_->GetTargetFeatureString(); if (!features.empty()) { func->addFnAttr("target-features", features); } @@ -980,8 +969,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { // mismatch will have to be treated specially here. // TODO(kparzysz-quic): fix this once TVM prefetch uses the same // type as LLVM. - llvm::Type* return_type = (id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef(op)) - : llvm::Type::getVoidTy(*ctx_); + llvm::Type* return_type = + (id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef(op)) : t_void_; llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " #if TVM_LLVM_VERSION >= 130 @@ -1039,9 +1028,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val); } else if (op->op.same_as(builtin::if_then_else())) { ICHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; - auto* then_block = llvm::BasicBlock::Create(*ctx_, "if_then", function_); - auto* else_block = llvm::BasicBlock::Create(*ctx_, "if_else", function_); - auto* end_block = llvm::BasicBlock::Create(*ctx_, "if_end", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_); + auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else", function_); + auto* end_block = llvm::BasicBlock::Create(*ctx, "if_end", function_); builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block); builder_->SetInsertPoint(then_block); llvm::Value* then_value = MakeValue(op->args[1]); @@ -1065,7 +1055,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { builder_->CreateRet(ConstInt32(0)); // LLVM allows exactly one terminator in a single basic block // append a new dummy basic block to avoid error. - llvm::BasicBlock* ret_dummy = llvm::BasicBlock::Create(*ctx_, "ret_dummy", function_); + llvm::BasicBlock* ret_dummy = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "ret_dummy", function_); builder_->SetInsertPoint(ret_dummy); return ret_dummy; } else if (op->op.same_as(builtin::reinterpret())) { @@ -1519,9 +1510,10 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { } void CodeGenLLVM::VisitStmt_(const WhileNode* op) { - auto* while_cond = llvm::BasicBlock::Create(*ctx_, "while_cond", function_); - auto* while_body = llvm::BasicBlock::Create(*ctx_, "while_body", function_); - auto* while_merge = llvm::BasicBlock::Create(*ctx_, "while_merge", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* while_cond = llvm::BasicBlock::Create(*ctx, "while_cond", function_); + auto* while_body = llvm::BasicBlock::Create(*ctx, "while_body", function_); + auto* while_merge = llvm::BasicBlock::Create(*ctx, "while_merge", function_); builder_->CreateBr(while_cond); builder_->SetInsertPoint(while_cond); builder_->CreateCondBr(MakeValue(op->condition), while_body, while_merge); @@ -1533,10 +1525,11 @@ void CodeGenLLVM::VisitStmt_(const WhileNode* op) { void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { llvm::Value* cond = MakeValue(op->condition); - auto* then_block = llvm::BasicBlock::Create(*ctx_, "if_then", function_); - auto* end_block = llvm::BasicBlock::Create(*ctx_, "if_end", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_); + auto* end_block = llvm::BasicBlock::Create(*ctx, "if_end", function_); if (op->else_case.defined()) { - auto* else_block = llvm::BasicBlock::Create(*ctx_, "if_else", function_); + auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else", function_); builder_->CreateCondBr(cond, then_block, else_block); builder_->SetInsertPoint(then_block); this->VisitStmt(op->then_case); @@ -1555,7 +1548,7 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) { auto data = op->data.value(); - auto array = codegen::NDArrayToLLVMArray(ctx_, data); + auto array = NDArrayToLLVMArray(llvm_target_->GetContext(), data); std::string symbol_name = op->buffer_var->name_hint; llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); @@ -1673,4 +1666,5 @@ void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index c6129c238c7f..e6321be647aa 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -23,6 +23,7 @@ */ #ifndef TVM_TARGET_LLVM_CODEGEN_LLVM_H_ #define TVM_TARGET_LLVM_CODEGEN_LLVM_H_ + #ifdef TVM_LLVM_VERSION #include @@ -40,7 +41,6 @@ #include #include #include -#include #include #if TVM_LLVM_VERSION >= 140 #include @@ -78,7 +78,6 @@ class Function; class GlobalVariable; class Instruction; class PassManagerBuilder; -class TargetMachine; class DIFile; class DICompileUnit; class MDNode; @@ -93,6 +92,8 @@ class MDBuilder; namespace tvm { namespace codegen { +class LLVMTarget; + using namespace tir; /*! @@ -109,7 +110,7 @@ class CodeGenLLVM : public ExprFunctor, * \param tm The target machine * \return The created llvm generator. */ - static std::unique_ptr Create(llvm::TargetMachine* tm); + static std::unique_ptr Create(LLVMTarget* llvm_target); /*! * \brief Initialize the code generator with given context * \param module_name The name of the module. @@ -121,14 +122,14 @@ class CodeGenLLVM : public ExprFunctor, * \param target_c_runtime If true, generate a module to be executed by the C runtime. In practice * this option influences whether global ctors are used. */ - virtual void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, - bool system_lib, bool dynamic_lookup, bool target_c_runtime); + virtual void Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime); /*! * \brief Turn on fast math flags for floating point operations. * \param fmf FastMathFlags to use for code generation. */ - void SetFastMathFlag(llvm::FastMathFlags fmf); + void SetFastMathFlags(llvm::FastMathFlags fmf); /*! * \brief Compile and add function f to the current module. @@ -229,9 +230,6 @@ class CodeGenLLVM : public ExprFunctor, llvm::Constant* GetGlobalConstant( llvm::Constant* const_data, const std::string& name = "", llvm::GlobalValue::LinkageTypes linkage_type = llvm::GlobalValue::InternalLinkage); - inline llvm::ConstantArray* NDArrayToLLVMArray(::tvm::runtime::NDArray arr) { - return codegen::NDArrayToLLVMArray(ctx_, arr); - } protected: /*! @@ -340,7 +338,7 @@ class CodeGenLLVM : public ExprFunctor, bool is_volatile)> make_instruction); // Initialize target - virtual void InitTarget(llvm::TargetMachine* tm); + virtual void InitTarget(); // Add module startup function if needed. virtual void AddStartupFunction() {} // apply optimization on the module. @@ -476,10 +474,8 @@ class CodeGenLLVM : public ExprFunctor, std::unique_ptr data_layout_; // Internal metabuilder std::unique_ptr md_builder_; - // llvm target machine - llvm::TargetMachine* target_machine_{nullptr}; - // llvm context - llvm::LLVMContext* ctx_{nullptr}; + // llvm target info + LLVMTarget* llvm_target_{nullptr}; // helpful data types llvm::Type* t_void_{nullptr}; llvm::PointerType* t_void_p_{nullptr}; @@ -495,7 +491,7 @@ class CodeGenLLVM : public ExprFunctor, llvm::MDNode* md_tbaa_root_{nullptr}; llvm::MDNode* md_tbaa_alias_set_{nullptr}; // modules to be linked. - std::vector > link_modules_; + std::vector> link_modules_; /*! \brief native vector bits of current targetx*/ int native_vector_bits_{0}; /*! \brief the storage scope of allocation */ @@ -567,5 +563,6 @@ void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfu } // namespace codegen } // namespace tvm -#endif // LLVM_VERSION + +#endif // TVM_LLVM_VERSION #endif // TVM_TARGET_LLVM_CODEGEN_LLVM_H_ diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index a74274009cf4..c758ca383621 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -56,7 +56,7 @@ #include "../../runtime/cuda/cuda_module.h" #include "../build_common.h" #include "codegen_llvm.h" -#include "llvm_common.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -68,10 +68,11 @@ class CodeGenNVPTX : public CodeGenLLVM { // add function as void return value CodeGenLLVM::AddFunctionInternal(f, true); // annotate as kernel function + llvm::LLVMContext* ctx = llvm_target_->GetContext(); module_->getOrInsertNamedMetadata("nvvm.annotations") ->addOperand(llvm::MDNode::get( - *ctx_, {llvm::ValueAsMetadata::get(function_), llvm::MDString::get(*ctx_, "kernel"), - llvm::ValueAsMetadata::get(ConstInt32(1))})); + *ctx, {llvm::ValueAsMetadata::get(function_), llvm::MDString::get(*ctx, "kernel"), + llvm::ValueAsMetadata::get(ConstInt32(1))})); } void VisitStmt_(const AllocateNode* op) final { @@ -203,10 +204,10 @@ class CodeGenNVPTX : public CodeGenLLVM { llvm::Value* CreateIntrinsic(const CallNode* op) override; protected: - void InitTarget(llvm::TargetMachine* tm) final { + void InitTarget() final { // Maximum vector lane = float4 native_vector_bits_ = 4 * 32; - CodeGenLLVM::InitTarget(tm); + CodeGenLLVM::InitTarget(); } }; @@ -298,15 +299,13 @@ int GetCUDAComputeVersion(const Target& target) { } runtime::Module BuildNVPTX(IRModule mod, Target target) { - InitializeLLVM(); + LLVMInstance llvm_instance; + With llvm_target(llvm_instance, target); + int compute_ver = GetCUDAComputeVersion(target); - std::unique_ptr tm = GetLLVMTargetMachine(target); - std::unique_ptr ctx(new llvm::LLVMContext()); - // careful: cg will hold a naked pointer reference to ctx, so it should - // have a shorter lifetime than the ctx. std::unique_ptr cg(new CodeGenNVPTX()); - cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false, false); + cg->Init("TVMPTXModule", llvm_target.get(), false, false, false); cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) { ICHECK(kv.second->template IsInstance()) @@ -314,18 +313,13 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) { return Downcast(kv.second); }); + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path"); if (flibdevice_path != nullptr) { std::string path = (*flibdevice_path)(compute_ver); if (path.length() != 0) { - llvm::SMDiagnostic err; - std::unique_ptr mlib = llvm::parseIRFile(path, err, *ctx); - if (mlib.get() == nullptr) { - std::string msg(err.getMessage()); - LOG(FATAL) << "Fail to load bitcode file " << path << "\n" - << "line " << err.getLineNo() << ":" << msg; - } - mlib->setTargetTriple(tm->getTargetTriple().str()); + std::unique_ptr mlib = llvm_instance.LoadIR(path); + mlib->setTargetTriple(llvm_target->GetTargetTriple()); mlib->setDataLayout(tm->createDataLayout()); cg->AddLinkModule(std::move(mlib)); } @@ -365,4 +359,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_nvptx") } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index 2d36e0b022e1..efe15c5c4aac 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -38,6 +38,7 @@ #include #include "codegen_cpu.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -91,9 +92,9 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { const auto to = op->dtype; if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) { ICHECK_EQ(from.lanes(), to.lanes()); - CHECK_NOTNULL(target_machine_); + llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine(); - const auto has_avx512 = TargetHasFeature(*target_machine_, "avx512f"); + const auto has_avx512 = TargetHasFeature(*tm, "avx512f"); if (from.lanes() >= 16 && has_avx512) { return CallVectorIntrin( @@ -110,7 +111,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { #if TVM_LLVM_VERSION <= 100 // The intrinsic x86_vcvtph2ps_256 was removed in LLVM 11. - const auto has_f16c = TargetHasFeature(*target_machine_, "f16c"); + const auto has_f16c = TargetHasFeature(*tm, "f16c"); if (from.lanes() >= 8 && has_f16c) { return CallVectorIntrin(llvm::Intrinsic::x86_vcvtph2ps_256, 8, @@ -168,4 +169,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64") } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/llvm_common.cc b/src/target/llvm/llvm_common.cc deleted file mode 100644 index 83de839a926e..000000000000 --- a/src/target/llvm/llvm_common.cc +++ /dev/null @@ -1,211 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file llvm_common.cc - */ -#ifdef TVM_LLVM_VERSION - -#include "llvm_common.h" - -#if TVM_LLVM_VERSION >= 140 -#include -#else -#include -#endif -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace tvm { -namespace codegen { - -struct LLVMEnv { - std::mutex mu; - std::atomic all_initialized{false}; - - static LLVMEnv* Global() { - static LLVMEnv inst; - return &inst; - } -}; - -void InitializeLLVM() { - LLVMEnv* e = LLVMEnv::Global(); - if (!e->all_initialized.load(std::memory_order::memory_order_acquire)) { - std::lock_guard lock(e->mu); - if (!e->all_initialized.load(std::memory_order::memory_order_acquire)) { - llvm::InitializeAllTargetInfos(); - llvm::InitializeAllTargets(); - llvm::InitializeAllTargetMCs(); - llvm::InitializeAllAsmParsers(); - llvm::InitializeAllAsmPrinters(); - e->all_initialized.store(true, std::memory_order::memory_order_release); - } - } -} - -void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::string* mcpu, - std::string* mattr, llvm::TargetOptions* options) { - // simple parser - triple->resize(0); - mcpu->resize(0); - mattr->resize(0); - bool soft_float_abi = false; - if (const Optional& v = target->GetAttr("mtriple")) { - *triple = v.value(); - } - if (const Optional& v = target->GetAttr("mcpu")) { - *mcpu = v.value(); - } - if (const Optional>& v = target->GetAttr>("mattr")) { - std::ostringstream os; - bool is_first = true; - for (const String& s : v.value()) { - if (!is_first) { - os << ','; - } - is_first = false; - os << s; - } - *mattr = os.str(); - } - if (const Optional& v = target->GetAttr("mfloat-abi")) { - String value = v.value(); - if (value == "hard") { -#if TVM_LLVM_VERSION < 60 - LOG(FATAL) << "-mfloat-abi hard is only supported for LLVM > 6.0"; -#endif - soft_float_abi = false; - } else if (value == "soft") { - soft_float_abi = true; - } else { - LOG(FATAL) << "invalid -mfloat-abi option " << value; - } - } - if (triple->length() == 0 || *triple == "default") { - *triple = llvm::sys::getDefaultTargetTriple(); - } - // set target option - llvm::TargetOptions& opt = *options; - opt = llvm::TargetOptions(); -#if TVM_LLVM_VERSION < 50 - opt.LessPreciseFPMADOption = true; -#endif - // In clang, these are fed from LangOpts which describe language specific features - // TODO(AndrewZhaoLuo): figure out how these relate to fast math flags - opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; - opt.UnsafeFPMath = false; - opt.NoInfsFPMath = false; - opt.NoNaNsFPMath = true; - if (soft_float_abi) { - opt.FloatABIType = llvm::FloatABI::Soft; - } else { - opt.FloatABIType = llvm::FloatABI::Hard; - } - if (const Optional& v = target->GetAttr("mabi")) { - opt.MCOptions.ABIName = v.value(); - } -} - -std::unique_ptr GetLLVMTargetMachine(const Target& target, bool allow_null) { - std::string target_triple, mcpu, mattr; - llvm::TargetOptions opt; - - ParseLLVMTargetOptions(target, &target_triple, &mcpu, &mattr, &opt); - - if (target_triple.length() == 0 || target_triple == "default") { - target_triple = llvm::sys::getDefaultTargetTriple(); - } - if (mcpu.length() == 0) { - mcpu = "generic"; - } - - std::string err; - const llvm::Target* llvm_target = llvm::TargetRegistry::lookupTarget(target_triple, err); - if (llvm_target == nullptr) { - ICHECK(allow_null) << err << " target_triple=" << target_triple; - return nullptr; - } - - int llvm_opt_level = target->GetAttr("opt-level").value_or(Integer(3)).IntValue(); - llvm::CodeGenOpt::Level llvm_opt; - if (llvm_opt_level <= 0) { - llvm_opt = llvm::CodeGenOpt::None; - } else if (llvm_opt_level == 1) { - llvm_opt = llvm::CodeGenOpt::Less; - } else if (llvm_opt_level == 2) { - llvm_opt = llvm::CodeGenOpt::Default; - } else { - // llvm_opt_level >= 3 - llvm_opt = llvm::CodeGenOpt::Aggressive; - } - - llvm::TargetMachine* tm = llvm_target->createTargetMachine( - target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_, llvm::CodeModel::Small, llvm_opt); - return std::unique_ptr(tm); -} - -std::string LLVMTargetToString(const Target& target) { - std::ostringstream os; - os << "llvm"; - if (Optional mtriple = target->GetAttr("mtriple")) { - os << " -mtriple=" << mtriple.value(); - } - if (Optional mcpu = target->GetAttr("mcpu")) { - os << " -mcpu=" << mcpu.value(); - } - if (Optional> mattr = target->GetAttr>("mattr")) { - bool is_first = true; - os << " -mattr="; - for (const String& attr : mattr.value()) { - if (!is_first) { - os << ","; - } - is_first = false; - os << attr; - } - } - if (Optional mfloat_abo = target->GetAttr("mfloat-abi")) { - os << " -mfloat-abi=" << mfloat_abo.value(); - } - if (Optional mabi = target->GetAttr("mabi")) { - os << " -mabi=" << mabi.value(); - } - return os.str(); -} - -} // namespace codegen -} // namespace tvm -#endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h deleted file mode 100644 index c127b77c03ac..000000000000 --- a/src/target/llvm/llvm_common.h +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file llvm_common.h - * \brief Common utilities for llvm initialization. - */ -#ifndef TVM_TARGET_LLVM_LLVM_COMMON_H_ -#define TVM_TARGET_LLVM_LLVM_COMMON_H_ - -#ifdef _MSC_VER -#pragma warning(disable : 4141 4291 4146 4624) -#endif -#ifdef TVM_LLVM_VERSION - -#include - -#include -#include -#include - -namespace llvm { -class Module; -class Target; -class TargetMachine; -class TargetOptions; -} // namespace llvm - -namespace tvm { - -// The TVM target -class Target; - -namespace codegen { - -/*! - * \brief Initialize LLVM on this process, - * can be called multiple times. - */ -void InitializeLLVM(); - -/*! - * \brief Parse target options - * \param target The TVM target - * \param triple Target triple - * \param mcpu cpu info - * \param options the options - * \param mattr The attributes - */ -void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::string* mcpu, - std::string* mattr, llvm::TargetOptions* options); - -/*! - * \brief Get target machine from TVM target. - * \param target The TVM target - * \param allow_null Whether allow null to be returned. - * \return target machine - */ -std::unique_ptr GetLLVMTargetMachine(const Target& target, - bool allow_null = false); - -/*! - * \brief Convert the TVM's LLVM target to string by extracting only relevant fields - * \param target The TVM target to be extracted - * \return The raw string format for the TVM LLVM target - */ -std::string LLVMTargetToString(const Target& target); - -} // namespace codegen -} // namespace tvm - -#endif // TVM_LLVM_VERSION -#endif // TVM_TARGET_LLVM_LLVM_COMMON_H_ diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc new file mode 100644 index 000000000000..772e71b28724 --- /dev/null +++ b/src/target/llvm/llvm_instance.cc @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifdef TVM_LLVM_VERSION + +#include "llvm_instance.h" + +#include +#include +#include +#if TVM_LLVM_VERSION >= 150 +#include +#else +#include +#endif +#include +#include +#include +#include +#if TVM_LLVM_VERSION >= 140 +#include +#else +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace codegen { + +namespace { +namespace defaults { +static const char* cpu = "generic"; +static const llvm::CodeGenOpt::Level opt_level = llvm::CodeGenOpt::Aggressive; +} // namespace defaults +} // namespace + +namespace { +bool InitializeLLVM() { + static std::atomic_flag initialized = ATOMIC_FLAG_INIT; + if (!initialized.test_and_set()) { + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); + } + return true; +} + +std::string Join(std::string sep, llvm::ArrayRef strings) { + std::string result; + bool is_first = true; + for (const std::string& s : strings) { + if (!is_first) { + result += sep; + } + result += s; + is_first = false; + } + return result; +} + +} // namespace + +// LLVMInstance + +LLVMInstance::LLVMInstance() { + // Call InitializeLLVM before anything else. + static const bool DMLC_ATTRIBUTE_UNUSED init_llvm = InitializeLLVM(); + ctx_ = std::make_shared(); +} + +LLVMInstance::~LLVMInstance() = default; + +std::unique_ptr LLVMInstance::ParseIR(const std::string& llvm_ir) const { + auto buffer = llvm::MemoryBuffer::getMemBuffer(llvm_ir, /*BufferName=*/"", + /*RequiresNullTerminator=*/false); + return ParseBuffer(*buffer); +} + +std::unique_ptr LLVMInstance::LoadIR(const std::string& file_name) const { + llvm::ErrorOr> maybe_buffer = + llvm::MemoryBuffer::getFileAsStream(file_name); + if (std::error_code ec = maybe_buffer.getError()) { + LOG(FATAL) << ec.message(); + } + return ParseBuffer(**maybe_buffer); +} + +std::unique_ptr LLVMInstance::ParseBuffer(const llvm::MemoryBuffer& buffer) const { + llvm::SMDiagnostic error; + std::unique_ptr module = llvm::parseIR(buffer.getMemBufferRef(), error, *ctx_); + if (module == nullptr) { + std::string message; + llvm::raw_string_ostream ostream(message); + error.print(/*ProgName=*/nullptr, ostream, /*ShowColors=*/false, /*ShowKindLabel=*/true); + LOG(FATAL) << ostream.str(); + } + + return module; +} + +// LLVMTarget + +LLVMTarget::LLVMTarget(LLVMInstance& instance, const Target& target) + : instance_(instance), ctx_(instance.GetContext()) { + triple_ = target->GetAttr("mtriple").value_or("default"); + + if (triple_.empty() || triple_ == "default") { + triple_ = llvm::sys::getDefaultTargetTriple(); + } + cpu_ = target->GetAttr("mcpu").value_or(defaults::cpu); + + if (const Optional>& v = target->GetAttr>("mattr")) { + for (const String& s : v.value()) { + attrs_.push_back(s); + } + } + + llvm::FloatABI::ABIType float_abi = llvm::FloatABI::Default; + if (const Optional& v = target->GetAttr("mfloat-abi")) { + String value = v.value(); + if (value == "hard") { + float_abi = llvm::FloatABI::Hard; + } else if (value == "soft") { + float_abi = llvm::FloatABI::Soft; + } else { + LOG(FATAL) << "invalid -mfloat-abi option " << value; + } + } + + // Target options + +#if TVM_LLVM_VERSION < 50 + target_options_.LessPreciseFPMADOption = true; +#endif + // In clang, these are fed from LangOpts which describe language specific features + // TODO(AndrewZhaoLuo): figure out how these relate to fast math flags + target_options_.AllowFPOpFusion = llvm::FPOpFusion::Fast; + target_options_.UnsafeFPMath = false; + target_options_.NoInfsFPMath = false; + target_options_.NoNaNsFPMath = true; + target_options_.FloatABIType = float_abi; + if (const Optional& v = target->GetAttr("mabi")) { + target_options_.MCOptions.ABIName = v.value(); + } + + auto maybe_level = target->GetAttr("opt-level"); + + if (maybe_level.defined()) { + int level = maybe_level.value()->value; + if (level <= 0) { + opt_level_ = llvm::CodeGenOpt::None; + } else if (level == 1) { + opt_level_ = llvm::CodeGenOpt::Less; + } else if (level == 2) { + opt_level_ = llvm::CodeGenOpt::Default; + } else { + // level >= 3 + opt_level_ = llvm::CodeGenOpt::Aggressive; + } + } else { + opt_level_ = defaults::opt_level; + } + + // Fast math options + + auto GetBoolFlag = [&target](llvm::StringRef flag) -> bool { + return target->GetAttr(flag.str()).value_or(Bool(false)); + }; + if (GetBoolFlag("fast-math")) { +#if TVM_LLVM_VERSION >= 60 + fast_math_flags_.setFast(); +#else + fast_math_flags_.setUnsafeAlgebra(); +#endif + } else { +#if TVM_LLVM_VERSION >= 50 + // This option was added in 5.x, and has a boolean argument, + // unlike the rest of options at the time. + fast_math_flags_.setAllowContract(GetBoolFlag("fast-math-contract")); +#endif +#if TVM_LLVM_VERSION >= 70 + fast_math_flags_.setNoNaNs(GetBoolFlag("fast-math-nnan")); + fast_math_flags_.setNoInfs(GetBoolFlag("fast-math-ninf")); + fast_math_flags_.setNoSignedZeros(GetBoolFlag("fast-math-nsz")); + fast_math_flags_.setAllowReciprocal(GetBoolFlag("fast-math-arcp")); + fast_math_flags_.setAllowContract(GetBoolFlag("fast-math-contract")); + fast_math_flags_.setAllowReassoc(GetBoolFlag("fast-math-reassoc")); + fast_math_flags_.setApproxFunc(GetBoolFlag("fast-math-afn")); +#else + // LLVM 4.x, 5.x, and 6.x + if (GetBoolFlag("fast-math-nnan")) fast_math_flags_.setNoNaNs(); + if (GetBoolFlag("fast-math-ninf")) fast_math_flags_.setNoInfs(); + if (GetBoolFlag("fast-math-nsz")) fast_math_flags_.setNoSignedZeros(); + if (GetBoolFlag("fast-math-arcp")) fast_math_flags_.setAllowReciprocal(); +#if TVM_LLVM_VERSION >= 60 + if (GetBoolFlag("fast-math-reassoc")) fast_math_flags_.setAllowReassoc(); + if (GetBoolFlag("fast-math-afn")) fast_math_flags_.setApproxFunc(); +#endif +#endif + } +} + +LLVMTarget::LLVMTarget(LLVMInstance& scope, const std::string& target_str) + : LLVMTarget(scope, Target(target_str)) {} + +LLVMTarget::~LLVMTarget() = default; + +llvm::LLVMContext* LLVMTarget::GetContext() const { + ICHECK(!ctx_.expired()) << "LLVM scope has been deleted"; + return ctx_.lock().get(); +} + +llvm::TargetMachine* LLVMTarget::GetOrCreateTargetMachine(bool allow_missing) { + if (target_machine_) return target_machine_.get(); + + std::string error; + if (const llvm::Target* llvm_instance = llvm::TargetRegistry::lookupTarget(triple_, error)) { + llvm::TargetMachine* tm = + llvm_instance->createTargetMachine(triple_, cpu_, GetTargetFeatureString(), target_options_, + reloc_model_, code_model_, opt_level_); + target_machine_ = std::unique_ptr(tm); + if (!allow_missing) { + ICHECK(target_machine_ != nullptr) << error; + } + } + return target_machine_.get(); +} + +std::string LLVMTarget::GetTargetFeatureString() const { // + return Join(",", attrs_); +} + +std::string LLVMTarget::str() const { + std::ostringstream os; + os << "llvm"; + if (!triple_.empty()) { + os << " -mtriple=" << triple_; + } + if (!cpu_.empty() && cpu_ != defaults::cpu) { + os << " -mcpu=" << cpu_; + } + if (!attrs_.empty()) { + os << " -mattr=" << GetTargetFeatureString(); + } + + switch (target_options_.FloatABIType) { + case llvm::FloatABI::Soft: + os << " -mfloat-abi=soft"; + break; + case llvm::FloatABI::Hard: + os << " -mfloat-abi=hard"; + break; + case llvm::FloatABI::Default: + break; + } + if (!target_options_.MCOptions.ABIName.empty()) { + os << " -mabi=" << target_options_.MCOptions.ABIName; + } + + bool do_individual = true; +#if TVM_LLVM_VERSION >= 60 + if (fast_math_flags_.isFast()) { + os << " -fast-math"; + do_individual = false; + } +#else + if (fast_math_flags_.unsafeAlgebra()) { + os << " -fast-math"; + do_individual = false; + } +#endif + + if (do_individual) { + if (fast_math_flags_.noNaNs()) os << " -fast-math-nnan"; + if (fast_math_flags_.noInfs()) os << " -fast-math-ninf"; + if (fast_math_flags_.noSignedZeros()) os << " -fast-math-nsz"; + if (fast_math_flags_.allowReciprocal()) os << " -fast-math-arcp"; +#if TVM_LLVM_VERSION >= 50 + if (fast_math_flags_.allowContract()) os << " -fast-math-contract"; +#endif +#if TVM_LLVM_VERSION >= 60 + if (fast_math_flags_.allowReassoc()) os << " -fast-math-reassoc"; + if (fast_math_flags_.approxFunc()) os << " -fast-math-afn"; +#endif + } + + if (opt_level_ != defaults::opt_level) { + os << " -opt-level="; + switch (opt_level_) { + case llvm::CodeGenOpt::None: + os << "0"; + break; + case llvm::CodeGenOpt::Less: + os << "1"; + break; + case llvm::CodeGenOpt::Default: + os << "2"; + break; + case llvm::CodeGenOpt::Aggressive: + os << "3"; + break; + } + } + + return os.str(); +} + +std::string LLVMTarget::GetTargetMetadata(const llvm::Module& module) { + if (llvm::Metadata* tvm_target = module.getModuleFlag("tvm_target")) { + auto* mdstr = llvm::cast(tvm_target); + llvm::StringRef meta = mdstr->getString(); + if (meta.startswith("llvm")) { + return meta.str(); + } + } + return "llvm -mtriple " + module.getTargetTriple(); +} + +void LLVMTarget::SetTargetMetadata(llvm::Module* module) const { + module->addModuleFlag(llvm::Module::Warning, "tvm_target", + llvm::MDString::get(*GetContext(), str())); +} + +} // namespace codegen +} // namespace tvm + +#endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h new file mode 100644 index 000000000000..afb6e58deb1f --- /dev/null +++ b/src/target/llvm/llvm_instance.h @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! \file llvm_instance.h + */ +#ifndef TVM_TARGET_LLVM_LLVM_INSTANCE_H_ +#define TVM_TARGET_LLVM_LLVM_INSTANCE_H_ + +#ifdef TVM_LLVM_VERSION + +#include +#if TVM_LLVM_VERSION >= 150 +#include +#else +#include +#endif +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace llvm { +class LLVMContext; +class MemoryBuffer; +class Module; +class TargetMachine; +} // namespace llvm + +namespace tvm { +namespace codegen { + +class LLVMTarget; + +/*! + * \class LLVMInstance + * \brief LLVMInstance is a class that (conceptually) starts and stops LLVM. All + * uses of LLVM should take place within a lifetime of an object of this class. + * + * E.g. + * ```{.cpp} + * { + * LLVMInstance llvm_instance; + * ... + * someFunctionFromLLVM(...); + * ... + * } + * // no more calls to LLVM here + * ``` + * In addition to that, LLVMInstance provides an LLVM context (llvm::LLVMContext). + * The context is a structure in LLVM where common IR constructs are maintained, + * (such as types, constants, etc.) so that they can be identified by their + * address (i.e. pointer comparison). Because of that, it's important to use + * the same context throughout compilation. + * + * At the moment the "starting" of LLVM performs initialization of LLVM, but + * "stopping" doesn't do anything. In the future, if such a need arises, this + * functionality may be extended to perform dlopen/dlclose of the LLVM-based + * code in TVM. + * + * This class provides means to deserialize an LLVM module, either from text + * (in a string), or from a file. In either case, the serialized module can + * be LLVM IR assembly, or binary bitcode enconding. + */ +class LLVMInstance { + public: + /*! + * \brief Constructs LLVMInstance + */ + LLVMInstance(); + /*! + * \brief Destroys LLVMInstance object + */ + ~LLVMInstance(); // Must not be "= default" here in the header file. + + /*! + * \brief Get the LLVM context for this scope. + */ + std::shared_ptr GetContext() const { return ctx_; } + + /*! + * \brief Create `llvm::Module` from a string. + * + * Parse the string in \param llvm_ir, and return the `llvm::Module`. + * At the moment this function will abort if the parsing fails. + * \param llvm_ir string with the LLVM IR assembly or bitcode + * \return created `llvm::Module` + */ + std::unique_ptr ParseIR(const std::string& llvm_ir) const; + /*! + * \brief Load `llvm::Module` from a given file + * + * Read the file \param file_name, and return the `llvm::Module`. + * At the moment this function will abort if reading of the file or creation + * of the module fails. + * \param file_name file with the LLVM IR assembly or bitcode + * \return created `llvm::Module` + */ + std::unique_ptr LoadIR(const std::string& file_name) const; + + private: + std::unique_ptr ParseBuffer(const llvm::MemoryBuffer& buffer) const; + + std::shared_ptr ctx_; +}; + +/*! + * \class LLVMTarget + * \brief Information used by LLVM for code generation for particular target + * + * This class contains all information that LLVM needs for code generation for + * a particular target. Since Target in TVM will soon contain command line + * flags for LLVM, objects of this class will handle saving and restoring + * global LLVM state that may be affected by these flags. This way, code + * generation for each LLVM-based target in TVM will start with the same LLVM + * global state. + * + * Note that objects of this class must be created within the lifetime of an + * LLVMInstance object. + */ +class LLVMTarget { + public: + /*! + * \brief Constructs LLVMTarget from `Target` + * \param scope LLVMInstance object + * \param target TVM Target object for target "llvm" + */ + LLVMTarget(LLVMInstance& scope, const Target& target); // NOLINT(runtime/references) + /*! + * \brief Constructs LLVMTarget from target string + * \param scope LLVMInstance object + * \param target TVM target string for target "llvm" + */ + LLVMTarget(LLVMInstance& scope, const std::string& target_str); // NOLINT(runtime/references) + /*! + * \brief Destroys LLVMTarget object + */ + ~LLVMTarget(); + + /*! + * \brief Returns string representation (as TVM target) of the LLVMTarget + * \return Target string + * + * Note: If the LLVMTarget object was created from a string `s`, the string + * returned here may not be exactly equal to `s`. For example, if the CPU + * was "default", the returned string will have CPU set to the detected host + * CPU. + */ + std::string str() const; + + /*! + * \brief Get the LLVMInstance object from which the LLVMTarget object was + * created + * \return The enclosing LLVMInstance object + */ + const LLVMInstance& GetInstance() const { return instance_; } + /*! + * \brief Get the current LLVM context + * \return the current LLVM context + */ + llvm::LLVMContext* GetContext() const; + /*! + * \brief Return LLVM's `TargetMachine`, or nullptr + * \param allow_missing do not abort if the target machine cannot be created, + * return nullptr instead + * \return Pointer to the `TargetMachine` object (or nullptr if it cannot be + * created, \see allow_missing) + */ + llvm::TargetMachine* GetOrCreateTargetMachine(bool allow_missing = false); + + /*! + * \brief Get the target triple + * \return the target triple + */ + const std::string& GetTargetTriple() const { return triple_; } + /*! + * \brief Get the CPU name + * \return the CPU name: the detected host CPU if the original TVM target + * specified it as "default" + */ + const std::string& GetCPU() const { return cpu_; } + /*! + * \brief Get the list of LLVM target features + * \return array of individual feature strings + */ + llvm::ArrayRef GetTargetFeatures() const { return attrs_; } + /*! + * \brief Get the LLVM target feature string + * \return comma-separated list of LLVM target features + */ + std::string GetTargetFeatureString() const; + /*! + * \brief Get the LLVM target options + * \return `llvm::TargetOptions` object for this target + */ + const llvm::TargetOptions& GetTargetOptions() const { return target_options_; } + /*! + * \brief Get fast math flags + * \return `llvm::FastMathFlags` for this target + */ + llvm::FastMathFlags GetFastMathFlags() const { return fast_math_flags_; } + /*! + * \brief Get the LLVM optimization level + * \return optimization level for this target + */ + llvm::CodeGenOpt::Level GetOptLevel() const { return opt_level_; } + + /*! + * \brief Extract the target string from given `llvm::Module` + * \param module LLVM module with the TVM target string embedded as metadata + * \return the target string from module's metadata + */ + static std::string GetTargetMetadata(const llvm::Module& module); + /*! + * \brief Embed target string as metadata in given `llvm::Module` + * \param module the module to insert the target string into + */ + void SetTargetMetadata(llvm::Module* module) const; + + // Stubs to enable use with `With`. + void EnterWithScope() {} + void ExitWithScope() {} + + private: + const LLVMInstance& instance_; + std::weak_ptr ctx_; + + std::string triple_; + std::string cpu_; + std::vector attrs_; + llvm::TargetOptions target_options_; + llvm::FastMathFlags fast_math_flags_; + llvm::CodeGenOpt::Level opt_level_; + llvm::Reloc::Model reloc_model_ = llvm::Reloc::PIC_; + llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small; + std::shared_ptr target_machine_; +}; + +} // namespace codegen +} // namespace tvm + +#endif // TVM_LLVM_VERSION +#endif // TVM_TARGET_LLVM_LLVM_INSTANCE_H_ diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 69c7632d65ea..9aed66fffc5c 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -51,11 +51,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include @@ -74,7 +76,7 @@ #include "codegen_blob.h" #include "codegen_cpu.h" #include "codegen_llvm.h" -#include "llvm_common.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -85,398 +87,338 @@ using runtime::TVMRetValue; class LLVMModuleNode final : public runtime::ModuleNode { public: - ~LLVMModuleNode() { - module_owning_ptr_.reset(); - if (ee_ != nullptr) { - ee_->runStaticConstructorsDestructors(true); - delete ee_; - } - } + ~LLVMModuleNode(); const char* type_key() const final { return "llvm"; } - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - if (name == "__tvm_is_system_module") { - bool flag = (module_->getFunction("__tvm_module_startup") != nullptr); - return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; }); - } else if (name == "get_func_names") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->function_names_; }); - } else if (name == "get_symbol") { - return PackedFunc(nullptr); - } else if (name == "get_const_vars") { - return PackedFunc(nullptr); - } else if (name == "_get_target_string") { - std::string target_string = LLVMTargetToString(target_); - return PackedFunc([target_string](TVMArgs args, TVMRetValue* rv) { *rv = target_string; }); - } - if (ee_ == nullptr) LazyInitJIT(); - - std::lock_guard lock(mutex_); - - TVMBackendPackedCFunc faddr; - if (name == runtime::symbol::tvm_module_main) { - const char* entry_name = - reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_main)); - ICHECK(entry_name != nullptr) - << "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; - faddr = reinterpret_cast(GetFunctionAddr(entry_name)); - } else { - faddr = reinterpret_cast(GetFunctionAddr(name)); - } - if (faddr == nullptr) return PackedFunc(); - return WrapPackedFunc(faddr, sptr_to_self); + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + + void SaveToFile(const std::string& file_name, const std::string& format) final; + void SaveToBinary(dmlc::Stream* stream) final; + std::string GetSource(const std::string& format) final; + + void Init(const IRModule& mod, const Target& target); + void Init(std::unique_ptr module, std::unique_ptr llvm_instance); + void LoadIR(const std::string& file_name); + bool IsDSOExportable() const final { return true; } + + bool ImplementsFunction(const String& name, bool query_imports) final; + + private: + void LazyInitJIT(); + bool IsCompatibleWithHost(const llvm::TargetMachine* tm) const; + void* GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) const; + void* GetFunctionAddr(const std::string& name, const LLVMTarget& llvm_target) const; + + // The LLVM scope object. + std::unique_ptr llvm_instance_; + // JIT lock + std::mutex mutex_; + // execution engine + llvm::ExecutionEngine* ee_{nullptr}; + // The raw pointer to the module. + llvm::Module* module_{nullptr}; + // The unique_ptr owning the module. This becomes empty once JIT has been initialized + // (EngineBuilder takes ownership of the module). + std::unique_ptr module_owning_ptr_; + /* \brief names of the functions declared in this module */ + Array function_names_; +}; + +LLVMModuleNode::~LLVMModuleNode() { + if (ee_ != nullptr) { + ee_->runStaticConstructorsDestructors(true); + delete ee_; } + module_owning_ptr_.reset(); +} - void SaveToFile(const std::string& file_name, const std::string& format) final { - std::string fmt = runtime::GetFileFormat(file_name, format); - std::error_code ecode; +PackedFunc LLVMModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + if (name == "__tvm_is_system_module") { + bool flag = (module_->getFunction("__tvm_module_startup") != nullptr); + return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; }); + } else if (name == "get_func_names") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->function_names_; }); + } else if (name == "get_symbol") { + return PackedFunc(nullptr); + } else if (name == "get_const_vars") { + return PackedFunc(nullptr); + } else if (name == "_get_target_string") { + std::string target_string = LLVMTarget::GetTargetMetadata(*module_); + return PackedFunc([target_string](TVMArgs args, TVMRetValue* rv) { *rv = target_string; }); + } + if (ee_ == nullptr) LazyInitJIT(); + + std::lock_guard lock(mutex_); + + TVMBackendPackedCFunc faddr; + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); + if (name == runtime::symbol::tvm_module_main) { + const char* entry_name = reinterpret_cast( + GetGlobalAddr(runtime::symbol::tvm_module_main, *llvm_target)); + ICHECK(entry_name != nullptr) << "Symbol " << runtime::symbol::tvm_module_main + << " is not presented"; + faddr = reinterpret_cast(GetFunctionAddr(entry_name, *llvm_target)); + } else { + faddr = reinterpret_cast(GetFunctionAddr(name, *llvm_target)); + } + if (faddr == nullptr) return PackedFunc(); + return WrapPackedFunc(faddr, sptr_to_self); +} + +void LLVMModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { + std::string fmt = runtime::GetFileFormat(file_name, format); + std::error_code ecode; #if TVM_LLVM_VERSION <= 70 - llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None); + llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None); #else - llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::OF_None); + llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::OF_None); #endif - ICHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " << ecode.message(); - if (fmt == "o" || fmt == "obj") { + ICHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " << ecode.message(); + if (fmt == "o" || fmt == "obj") { + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); #if TVM_LLVM_VERSION <= 60 - std::unique_ptr m = llvm::CloneModule(module_); + std::unique_ptr m = llvm::CloneModule(module_); #else - std::unique_ptr m = llvm::CloneModule(*module_); + std::unique_ptr m = llvm::CloneModule(*module_); #endif - llvm::legacy::PassManager pass; - ICHECK(tm_); + llvm::legacy::PassManager pass; + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); #if TVM_LLVM_VERSION <= 60 - ICHECK(tm_->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0) - << "Cannot emit target CGFT_ObjectFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 90 - ICHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == - 0) - << "Cannot emit target CGFT_ObjectFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #else - ICHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_ObjectFile) == 0) - << "Cannot emit target CGFT_ObjectFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #endif - pass.run(*m); - } else if (fmt == "s" || fmt == "asm") { + pass.run(*m); + } else if (fmt == "s" || fmt == "asm") { + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); #if TVM_LLVM_VERSION <= 60 - std::unique_ptr m = llvm::CloneModule(module_); + std::unique_ptr m = llvm::CloneModule(module_); #else - std::unique_ptr m = llvm::CloneModule(*module_); + std::unique_ptr m = llvm::CloneModule(*module_); #endif - llvm::legacy::PassManager pass; - ICHECK(tm_); + llvm::legacy::PassManager pass; + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); #if TVM_LLVM_VERSION <= 60 - ICHECK(tm_->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 90 - ICHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, - llvm::TargetMachine::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == + 0) + << "Cannot emit target CGFT_AssemblyFile"; #else - ICHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; #endif - pass.run(*m); - } else if (fmt == "ll") { - module_->print(dest, nullptr); - } else if (fmt == "bc") { + pass.run(*m); + } else if (fmt == "ll") { + module_->print(dest, nullptr); + } else if (fmt == "bc") { #if TVM_LLVM_VERSION <= 60 - llvm::WriteBitcodeToFile(module_, dest); + llvm::WriteBitcodeToFile(module_, dest); #else - llvm::WriteBitcodeToFile(*module_, dest); + llvm::WriteBitcodeToFile(*module_, dest); #endif - } else { - LOG(FATAL) << "Do not know how to save file " << file_name << " with format=\'" << format - << "\'"; - } - dest.close(); + } else { + LOG(FATAL) << "Do not know how to save file " << file_name << " with format=\'" << format + << "\'"; } + dest.close(); +} - void SaveToBinary(dmlc::Stream* stream) final { - LOG(FATAL) << "LLVMModule: SaveToBinary not supported"; - } +void LLVMModuleNode::SaveToBinary(dmlc::Stream* stream) { + LOG(FATAL) << "LLVMModule: SaveToBinary not supported"; +} - std::string GetSource(const std::string& format) final { - std::string fmt = runtime::GetFileFormat("", format); - std::string type_str; - llvm::SmallString<256> str; - llvm::raw_svector_ostream rso(str); +std::string LLVMModuleNode::GetSource(const std::string& format) { + std::string fmt = runtime::GetFileFormat("", format); + std::string type_str; + llvm::SmallString<256> str; + llvm::raw_svector_ostream rso(str); - if (fmt == "s" || fmt == "asm") { + if (fmt == "s" || fmt == "asm") { + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); #if TVM_LLVM_VERSION <= 60 - std::unique_ptr m = llvm::CloneModule(module_); + std::unique_ptr m = llvm::CloneModule(module_); #else - std::unique_ptr m = llvm::CloneModule(*module_); + std::unique_ptr m = llvm::CloneModule(*module_); #endif - llvm::legacy::PassManager pass; - ICHECK(tm_); + llvm::legacy::PassManager pass; + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); #if TVM_LLVM_VERSION <= 60 - ICHECK(tm_->addPassesToEmitFile(pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 90 - ICHECK(tm_->addPassesToEmitFile(pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == - 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; #else - ICHECK(tm_->addPassesToEmitFile(pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; #endif - pass.run(*m); - return rso.str().str(); - } else if (fmt == "" || fmt == "ll") { - std::string type_str; - llvm::raw_string_ostream rso(type_str); - ICHECK(module_ != nullptr); - module_->print(rso, nullptr); - return rso.str(); - } else { - LOG(FATAL) << "Do not know how to get source code with format: " << format << "\'"; - } - return ""; + pass.run(*m); + return rso.str().str(); + } else if (fmt == "" || fmt == "ll") { + std::string type_str; + llvm::raw_string_ostream rso(type_str); + ICHECK(module_ != nullptr); + module_->print(rso, nullptr); + return rso.str(); + } else { + LOG(FATAL) << "Do not know how to get source code with format: " << format << "\'"; } + return ""; +} - void Init(const IRModule& mod, const Target& target) { - InitializeLLVM(); - tm_ = GetLLVMTargetMachine(target); - ctx_ = std::make_shared(); - std::unique_ptr cg = CodeGenLLVM::Create(tm_.get()); - - std::vector funcs; - std::string entry_func; - relay::Runtime runtime = - mod->GetAttr(tvm::attr::kRuntime).value_or(relay::Runtime::Create("cpp")); - bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); - bool target_c_runtime = runtime->name == "crt"; - - for (auto kv : mod->functions) { - if (!kv.second->IsInstance()) { - // (@jroesch): we relax constraints here, Relay functions will just be ignored. - DLOG(INFO) << "Can only lower IR Module with PrimFuncs, but got " - << kv.second->GetTypeKey(); - continue; - } - auto f = Downcast(kv.second); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()); - function_names_.push_back(global_symbol.value()); - if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - entry_func = global_symbol.value(); - } - funcs.push_back(f); - } - // TODO(@jroesch): follow up on this condition. - // ICHECK(funcs.size() > 0); - // TODO(tqchen): remove the entry function behavior as it does not - // makes sense when we start to use multiple modules. - cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime); - - // See https://llvm.org/docs/LangRef.html#fast-math-flags for details - Bool fast_math_all = target->GetAttr("fast-math").value_or(Bool(false)); - Bool fast_math_nnan = target->GetAttr("fast-math-nnan").value_or(Bool(false)); - Bool fast_math_ninf = target->GetAttr("fast-math-ninf").value_or(Bool(false)); - Bool fast_math_nsz = target->GetAttr("fast-math-nsz").value_or(Bool(false)); - Bool fast_math_arcp = target->GetAttr("fast-math-arcp").value_or(Bool(false)); - - llvm::FastMathFlags fmf; - if (fast_math_all) { -#if TVM_LLVM_VERSION >= 60 - fmf.setFast(); -#else - fmf.setUnsafeAlgebra(); -#endif - } +void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { + llvm_instance_ = std::make_unique(); + With llvm_target(*llvm_instance_, target); + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); + std::unique_ptr cg = CodeGenLLVM::Create(llvm_target.get()); - if (fast_math_nnan) { - fmf.setNoNaNs(); - } - if (fast_math_ninf) { - fmf.setNoInfs(); - } - if (fast_math_nsz) { - fmf.setNoSignedZeros(); - } - if (fast_math_arcp) { - fmf.setAllowReciprocal(); - } + std::vector funcs; + std::string entry_func; + relay::Runtime runtime = + mod->GetAttr(tvm::attr::kRuntime).value_or(relay::Runtime::Create("cpp")); + bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); + bool target_c_runtime = runtime->name == "crt"; -#if TVM_LLVM_VERSION >= 60 - Bool fast_math_contract = target->GetAttr("fast-math-contract").value_or(Bool(false)); - Bool fast_math_afn = target->GetAttr("fast-math-afn").value_or(Bool(false)); - Bool fast_math_reassoc = target->GetAttr("fast-math-reassoc").value_or(Bool(false)); - if (fast_math_contract) { - fmf.setAllowContract(true); - } - if (fast_math_afn) { - fmf.setApproxFunc(); + for (auto kv : mod->functions) { + if (!kv.second->IsInstance()) { + // (@jroesch): we relax constraints here, Relay functions will just be ignored. + DLOG(INFO) << "Can only lower IR Module with PrimFuncs, but got " << kv.second->GetTypeKey(); + continue; } - if (fast_math_reassoc) { - fmf.setAllowReassoc(); + auto f = Downcast(kv.second); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()); + function_names_.push_back(global_symbol.value()); + if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + entry_func = global_symbol.value(); } -#endif + funcs.push_back(f); + } + // TODO(@jroesch): follow up on this condition. + // ICHECK(funcs.size() > 0); + // TODO(tqchen): remove the entry function behavior as it does not + // makes sense when we start to use multiple modules. + cg->Init("TVMMod", llvm_target.get(), system_lib, system_lib, target_c_runtime); + cg->SetFastMathFlags(llvm_target->GetFastMathFlags()); + + cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); + if (entry_func.length() != 0) { + cg->AddMainFunction(entry_func); + } - cg->SetFastMathFlag(fmf); + module_owning_ptr_ = cg->Finish(); + module_ = module_owning_ptr_.get(); + llvm_target->SetTargetMetadata(module_); + module_->addModuleFlag(llvm::Module::Override, "Debug Info Version", + llvm::DEBUG_METADATA_VERSION); - cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); - if (entry_func.length() != 0) { - cg->AddMainFunction(entry_func); - } + if (tm->getTargetTriple().isOSDarwin()) { + module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); + } - module_owning_ptr_ = cg->Finish(); - module_ = module_owning_ptr_.get(); + std::string verify_errors_storage; + llvm::raw_string_ostream verify_errors(verify_errors_storage); + LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors)) + << "LLVM module verification failed with the following errors: \n" + << verify_errors.str(); +} - module_->addModuleFlag(llvm::Module::Warning, "tvm_target", - llvm::MDString::get(*ctx_, LLVMTargetToString(target))); - module_->addModuleFlag(llvm::Module::Override, "Debug Info Version", - llvm::DEBUG_METADATA_VERSION); +void LLVMModuleNode::Init(std::unique_ptr module, + std::unique_ptr llvm_instance) { + module_owning_ptr_ = std::move(module); + module_ = module_owning_ptr_.get(); + llvm_instance_ = std::move(llvm_instance); +} - if (tm_->getTargetTriple().isOSDarwin()) { - module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); - } +void LLVMModuleNode::LoadIR(const std::string& file_name) { + auto llvm_instance = std::make_unique(); + std::unique_ptr module = llvm_instance->LoadIR(file_name); + Init(std::move(module), std::move(llvm_instance)); +} - std::string verify_errors_storage; - llvm::raw_string_ostream verify_errors(verify_errors_storage); - LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors)) - << "LLVM module verification failed with the following errors: \n" - << verify_errors.str(); - target_ = target; - } +bool LLVMModuleNode::ImplementsFunction(const String& name, bool query_imports) { + return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end(); +} - void Init(std::unique_ptr module, std::shared_ptr ctx) { - InitializeLLVM(); - ctx_ = ctx; - llvm::SMDiagnostic err; - module_owning_ptr_ = std::move(module); - module_ = module_owning_ptr_.get(); - if (module_ == nullptr) { - std::string msg = std::string(err.getMessage()); - LOG(FATAL) << "Fail to load module: " << msg; - } - std::string target_metadata; - llvm::Metadata* tvm_target = module_->getModuleFlag("tvm_target"); - if (tvm_target != nullptr) { - llvm::MDString* pstr = llvm::dyn_cast(tvm_target); - ICHECK(pstr != nullptr); - target_metadata = pstr->getString().str(); - if (!(target_metadata.length() >= 4 && target_metadata.substr(0, 4) == "llvm")) { - target_metadata = "llvm " + target_metadata; - } - } else { - std::ostringstream os; - os << "llvm -mtriple " << module_->getTargetTriple(); - target_metadata = os.str(); - } - target_ = Target(target_metadata); - tm_ = GetLLVMTargetMachine(target_); +void LLVMModuleNode::LazyInitJIT() { + std::lock_guard lock(mutex_); + if (ee_) { + return; } - - void LoadIR(const std::string& file_name) { - auto ctx = std::make_shared(); - llvm::SMDiagnostic err; - auto module = llvm::parseIRFile(file_name, err, *ctx); - if (module == nullptr) { - std::string msg = std::string(err.getMessage()); - LOG(FATAL) << "Fail to load ir file " << file_name << "\n" - << "line " << err.getLineNo() << ":" << msg; - } - Init(std::move(module), ctx); + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); + llvm::EngineBuilder builder(std::move(module_owning_ptr_)); + builder.setEngineKind(llvm::EngineKind::JIT); + builder.setOptLevel(llvm::CodeGenOpt::Aggressive); + builder.setMCPU(llvm_target->GetCPU()); + builder.setMAttrs(llvm_target->GetTargetFeatures()); + builder.setTargetOptions(llvm_target->GetTargetOptions()); + auto tm = std::unique_ptr(builder.selectTarget()); + if (!IsCompatibleWithHost(tm.get())) { + LOG(FATAL) << "Cannot run module, architecture mismatch"; } - - bool IsDSOExportable() const final { return true; } - - bool ImplementsFunction(const String& name, bool query_imports) final { - return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end(); + llvm::DataLayout layout(tm->createDataLayout()); + ICHECK(layout == module_->getDataLayout()) + << "Data layout mismatch between module(" + << module_->getDataLayout().getStringRepresentation() << ")" + << " and ExecutionEngine (" << layout.getStringRepresentation() << ")"; + ee_ = builder.create(tm.release()); + ICHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << module_->getTargetTriple(); + ee_->runStaticConstructorsDestructors(false); + + if (void** ctx_addr = + reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_ctx, *llvm_target))) { + *ctx_addr = this; } + runtime::InitContextFunctions( + [this, &llvm_target](const char* name) { return GetGlobalAddr(name, *llvm_target); }); + // There is a problem when a JITed function contains a call to a runtime function. + // The runtime function (e.g. __truncsfhf2) may not be resolved, and calling it will + // lead to a runtime crash. + // Do name lookup on a symbol that doesn't exist. This will force MCJIT to finalize + // all loaded objects, which will resolve symbols in JITed code. + ee_->getFunctionAddress("__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91"); +} - private: - void LazyInitJIT() { - std::lock_guard lock(mutex_); - if (ee_) { - return; - } - if (!target_.defined()) { - target_ = Target("llvm"); - } - llvm::EngineBuilder builder(std::move(module_owning_ptr_)); - std::string triple, mcpu, mattr; - llvm::TargetOptions opt; - ParseLLVMTargetOptions(target_, &triple, &mcpu, &mattr, &opt); - builder.setEngineKind(llvm::EngineKind::JIT); - builder.setOptLevel(llvm::CodeGenOpt::Aggressive); - if (mcpu.length() != 0) { - builder.setMCPU(mcpu); - } - if (mattr.length() != 0) { - std::vector mattrs{mattr}; - builder.setMAttrs(mattrs); - } - builder.setTargetOptions(opt); - auto tm = std::unique_ptr(builder.selectTarget()); - std::unique_ptr tm_sys = GetLLVMTargetMachine(Target("llvm")); - if (tm_sys->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) { - LOG(FATAL) << "Cannot run module, architecture mismatch " - << " module=" << tm->getTargetTriple().str() - << " system=" << tm_sys->getTargetTriple().str(); - } - llvm::DataLayout layout(tm->createDataLayout()); - ICHECK(layout == module_->getDataLayout()) - << "Data layout mismatch between module(" - << module_->getDataLayout().getStringRepresentation() << ")" - << " and ExecutionEngine (" << layout.getStringRepresentation() << ")"; - ee_ = builder.create(tm.release()); - ICHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << module_->getTargetTriple(); - ee_->runStaticConstructorsDestructors(false); - - if (void** ctx_addr = - reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_ctx))) { - *ctx_addr = this; - } - runtime::InitContextFunctions( - [this](const char* name) { return reinterpret_cast(GetGlobalAddr(name)); }); - // There is a problem when a JITed function contains a call to a runtime function. - // The runtime function (e.g. __truncsfhf2) may not be resolved, and calling it will - // lead to a runtime crash. - // Do name lookup on a symbol that doesn't exist. This will force MCJIT to finalize - // all loaded objects, which will resolve symbols in JITed code. - ee_->getFunctionAddress("__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91"); +bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const { + With host_target(*llvm_instance_, "llvm"); // FIXME(kparzysz-quic): nesting + auto tm_host = host_target->GetOrCreateTargetMachine(); + if (tm_host->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) { + LOG(INFO) << "Architecture mismatch: module=" << tm->getTargetTriple().str() + << " host=" << tm_host->getTargetTriple().str(); + return false; } + return true; +} - // Get global address from execution engine. - uint64_t GetGlobalAddr(const std::string& name) const { - // first verifies if GV exists. - if (module_->getGlobalVariable(name) != nullptr) { - return ee_->getGlobalValueAddress(name); - } else { - return 0; - } +// Get global address from execution engine. +void* LLVMModuleNode::GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) const { + // first verifies if GV exists. + if (module_->getGlobalVariable(name) != nullptr) { + return reinterpret_cast(ee_->getGlobalValueAddress(name)); + } else { + return nullptr; } +} - uint64_t GetFunctionAddr(const std::string& name) const { - // first verifies if GV exists. - if (module_->getFunction(name) != nullptr) { - return ee_->getFunctionAddress(name); - } else { - return 0; - } +void* LLVMModuleNode::GetFunctionAddr(const std::string& name, + const LLVMTarget& llvm_target) const { + // first verifies if GV exists. + if (module_->getFunction(name) != nullptr) { + return reinterpret_cast(ee_->getFunctionAddress(name)); + } else { + return nullptr; } - - // The target configuration string - Target target_; - // JIT lock - std::mutex mutex_; - // execution engine - llvm::ExecutionEngine* ee_{nullptr}; - // The target machine - std::unique_ptr tm_{nullptr}; - // The raw pointer to the module. - llvm::Module* module_{nullptr}; - // The unique_ptr owning the module. This becomes empty once JIT has been initialized - // (EngineBuilder takes ownership of the module). - std::unique_ptr module_owning_ptr_; - // the context. - std::shared_ptr ctx_; - /* \brief names of the functions declared in this module */ - Array function_names_; -}; +} TVM_REGISTER_GLOBAL("target.build.llvm") .set_body_typed([](IRModule mod, Target target) -> runtime::Module { @@ -487,18 +429,15 @@ TVM_REGISTER_GLOBAL("target.build.llvm") TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") .set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module { - Target target = Target(target_str); + auto llvm_instance = std::make_unique(); + With llvm_target(*llvm_instance, target_str); auto n = make_object(); // Generate a LLVM module from an input target string - InitializeLLVM(); - auto tm = GetLLVMTargetMachine(target); - auto ctx = std::make_shared(); - std::unique_ptr module(new llvm::Module(module_name, *ctx)); - // Use a default data layout and target triple - auto triple = tm->getTargetTriple(); - module->setTargetTriple(triple.str()); - module->setDataLayout(tm->createDataLayout()); - n->Init(std::move(module), ctx); + auto module = std::make_unique(module_name, *llvm_target->GetContext()); + llvm_target->SetTargetMetadata(module.get()); + module->setTargetTriple(llvm_target->GetTargetTriple()); + module->setDataLayout(llvm_target->GetOrCreateTargetMachine()->createDataLayout()); + n->Init(std::move(module), std::move(llvm_instance)); return runtime::Module(n); }); @@ -535,38 +474,39 @@ TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll") TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled") .set_body_typed([](std::string target_str) -> bool { - InitializeLLVM(); - Target target = Target(target_str); - return (GetLLVMTargetMachine(target, true) != nullptr); + LLVMInstance llvm_instance; + auto* tm = With(llvm_instance, target_str) + ->GetOrCreateTargetMachine(/*allow_missing=*/true); + return tm != nullptr; }); TVM_REGISTER_GLOBAL("codegen.codegen_blob") .set_body_typed([](std::string data, bool system_lib, std::string llvm_target_string) -> runtime::Module { auto n = make_object(); - auto p = CodeGenBlob(data, system_lib, llvm_target_string); - n->Init(std::move(p.first), p.second); + auto llvm_instance = std::make_unique(); + With llvm_target(*llvm_instance, llvm_target_string); + std::unique_ptr blob = CodeGenBlob(data, system_lib, llvm_target.get()); + n->Init(std::move(blob), std::move(llvm_instance)); return runtime::Module(n); }); runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target, tvm::relay::Runtime runtime) { - InitializeLLVM(); - auto tm = GetLLVMTargetMachine(target); + auto llvm_instance = std::make_unique(); + With llvm_target(*llvm_instance, target); bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); - auto ctx = std::make_shared(); std::unique_ptr cg{new CodeGenCPU()}; - cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, - false /* target_c_runtime */); + cg->Init("TVMMetadataMod", llvm_target.get(), system_lib, system_lib, + /*target_c_runtime=*/false); cg->DefineMetadata(metadata); auto mod = cg->Finish(); - mod->addModuleFlag(llvm::Module::Warning, "tvm_target", - llvm::MDString::get(*ctx, LLVMTargetToString(target))); + llvm_target->SetTargetMetadata(mod.get()); mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); - if (tm->getTargetTriple().isOSDarwin()) { + if (llvm_target->GetOrCreateTargetMachine()->getTargetTriple().isOSDarwin()) { mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); } @@ -577,7 +517,7 @@ runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata << verify_errors.str(); auto n = make_object(); - n->Init(std::move(mod), ctx); + n->Init(std::move(mod), std::move(llvm_instance)); auto meta_mod = MetadataModuleCreate(metadata); meta_mod->Import(runtime::Module(n)); @@ -597,24 +537,22 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array& module } } - InitializeLLVM(); - auto tm = GetLLVMTargetMachine(target); + auto llvm_instance = std::make_unique(); + With llvm_target(*llvm_instance, target); bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); bool target_c_runtime = runtime->name == "crt"; ICHECK(system_lib && target_c_runtime) << "For LLVM C-runtime metadata module, must include --system-lib and --runtime=c; " << "got target: " << target->str(); - auto ctx = std::make_shared(); std::unique_ptr cg{new CodeGenCPU()}; - cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, target_c_runtime); + cg->Init("TVMMetadataMod", llvm_target.operator->(), system_lib, system_lib, target_c_runtime); cg->DefineFunctionRegistry(func_names); auto mod = cg->Finish(); - mod->addModuleFlag(llvm::Module::Warning, "tvm_target", - llvm::MDString::get(*ctx, LLVMTargetToString(target))); + llvm_target->SetTargetMetadata(mod.get()); mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); - if (tm->getTargetTriple().isOSDarwin()) { + if (llvm_target->GetOrCreateTargetMachine()->getTargetTriple().isOSDarwin()) { mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); } @@ -625,7 +563,7 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array& module << verify_errors.str(); auto n = make_object(); - n->Init(std::move(mod), ctx); + n->Init(std::move(mod), std::move(llvm_instance)); for (auto m : modules) { n->Import(m); } diff --git a/src/target/llvm/llvm_module.h b/src/target/llvm/llvm_module.h index 3a50c2c4244f..66492f8152e5 100644 --- a/src/target/llvm/llvm_module.h +++ b/src/target/llvm/llvm_module.h @@ -46,5 +46,4 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array& module } // namespace tvm #endif // TVM_LLVM_VERSION - #endif // TVM_TARGET_LLVM_LLVM_MODULE_H_