From 33607c05c099a53e8c14e2de0abcf9dcc7766f57 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Mon, 23 May 2022 06:37:55 -0700 Subject: [PATCH 1/6] [LLVM] Create LLVM scope object for use with LLVM libraries This implements RFC 80. See https://github.com/apache/tvm-rfcs/pull/83. Summary of changes: - Created an `LLVMScope` class. Uses of LLVM functions and data structures should be contained within the lifetime of an object of this class. LLVMScope object contains LLVMContext, and implements member functions to deserialize an llvm::Module. - Created an `LLVMTarget` class. Once an LLVMScope object has been created, an object of LLVMTarget class can be created from TVM target string, or Target object for "llvm" target. Once LLVM command line flags are added to the "llvm" target, one of the goals of this object will be to save/restore relevant LLVM global state. Another objective for the LLVMTarget object is to be a single location for all LLVM-related compilation structures and options (such as TargetMachine, FastMathFlags, etc.) 
--- src/target/llvm/codegen_amdgpu.cc | 30 +- src/target/llvm/codegen_arm.cc | 5 +- src/target/llvm/codegen_blob.cc | 24 +- src/target/llvm/codegen_blob.h | 15 +- src/target/llvm/codegen_cpu.cc | 94 ++-- src/target/llvm/codegen_cpu.h | 10 +- src/target/llvm/codegen_hexagon.cc | 51 ++- src/target/llvm/codegen_llvm.cc | 140 +++--- src/target/llvm/codegen_llvm.h | 27 +- src/target/llvm/codegen_nvptx.cc | 33 +- src/target/llvm/codegen_x86_64.cc | 8 +- src/target/llvm/llvm_common.cc | 211 --------- src/target/llvm/llvm_common.h | 89 ---- src/target/llvm/llvm_module.cc | 660 +++++++++++++---------------- src/target/llvm/llvm_module.h | 1 - src/target/llvm/llvm_scope.cc | 368 ++++++++++++++++ src/target/llvm/llvm_scope.h | 117 +++++ 17 files changed, 1003 insertions(+), 880 deletions(-) delete mode 100644 src/target/llvm/llvm_common.cc delete mode 100644 src/target/llvm/llvm_common.h create mode 100644 src/target/llvm/llvm_scope.cc create mode 100644 src/target/llvm/llvm_scope.h diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 2e5a4bc23bd5..daaf9b3da9ae 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -51,7 +51,7 @@ #include "../../runtime/rocm/rocm_module.h" #include "../build_common.h" #include "codegen_llvm.h" -#include "llvm_common.h" +#include "llvm_scope.h" namespace tvm { namespace codegen { @@ -238,27 +238,25 @@ class CodeGenAMDGPU : public CodeGenLLVM { } protected: - void InitTarget(llvm::TargetMachine* tm) final { + void InitTarget() final { // Maximum vector lane = float4 native_vector_bits_ = 4 * 32; - CodeGenLLVM::InitTarget(tm); + CodeGenLLVM::InitTarget(); } }; runtime::Module BuildAMDGPU(IRModule mod, Target target) { + LLVMScope llvm_scope; + + With llvm_target(llvm_scope, target); #if TVM_LLVM_VERSION < 90 LOG(FATAL) << "AMDGPU backend requires at least LLVM 9"; // Lower versions will crash when loading the bitcode, see // issue #4087 for a discussion #endif - 
InitializeLLVM(); - std::unique_ptr tm = GetLLVMTargetMachine(target); - std::unique_ptr ctx(new llvm::LLVMContext()); - // careful: cg will hold a naked pointer reference to ctx, so it should - // have a shorter lifetime than the ctx. std::unique_ptr cg(new CodeGenAMDGPU()); - cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false, false); + cg->Init("TVMAMDGPUModule", llvm_target.get(), false, false, false); cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) { ICHECK(kv.second->template IsInstance()) @@ -266,20 +264,15 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) { return Downcast(kv.second); }); + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path"); Array bitcode_files = (*find_rocm_bitcodes)(); for (auto& bitcode_path : bitcode_files) { - std::string path = bitcode_path; - llvm::SMDiagnostic err; - std::unique_ptr mlib = llvm::parseIRFile(path, err, *ctx); - if (mlib.get() == nullptr) { - std::string msg(err.getMessage()); - LOG(FATAL) << "Fail to load bitcode file " << path << "\n" - << "line " << err.getLineNo() << ":" << msg; - } - mlib->setTargetTriple(tm->getTargetTriple().str()); + std::unique_ptr mlib = llvm_scope.LoadIR(bitcode_path); + mlib->setTargetTriple(llvm_target->GetTargetTriple()); mlib->setDataLayout(tm->createDataLayout()); + for (llvm::Function& f : mlib->functions()) { f.addFnAttr(llvm::Attribute::AlwaysInline); } @@ -351,4 +344,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_rocm") } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index f5ce0d550b1f..15d1699b3b59 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -42,10 +42,10 @@ class CodeGenARM final : public CodeGenCPU { CodeGenARM() = default; virtual ~CodeGenARM() = default; - void 
InitTarget(llvm::TargetMachine* tm) final { + void InitTarget() final { // set native vector bits. native_vector_bits_ = 16 * 8; - CodeGenCPU::InitTarget(tm); + CodeGenCPU::InitTarget(); } llvm::Value* CreateIntrinsic(const CallNode* op) override; @@ -139,4 +139,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_blob.cc b/src/target/llvm/codegen_blob.cc index 8e6041b4c970..b15d3caed653 100644 --- a/src/target/llvm/codegen_blob.cc +++ b/src/target/llvm/codegen_blob.cc @@ -52,25 +52,20 @@ #include #include -#include "llvm_common.h" +#include "llvm_scope.h" namespace tvm { namespace codegen { -std::pair, std::shared_ptr> CodeGenBlob( - const std::string& data, bool system_lib, const std::string& llvm_target_string) { - InitializeLLVM(); - Target target(llvm_target_string); - auto tm = GetLLVMTargetMachine(target); - auto triple = tm->getTargetTriple(); - auto ctx = std::make_shared(); +std::unique_ptr CodeGenBlob(const std::string& data, bool system_lib, + LLVMTarget* llvm_target) { + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); + const llvm::Triple& triple = tm->getTargetTriple(); + llvm::LLVMContext* ctx = llvm_target->GetContext(); std::string module_name = "devc"; - std::unique_ptr module(new llvm::Module(module_name, *ctx)); + auto module = std::make_unique(module_name, *ctx); module->setTargetTriple(triple.str()); - // Store full target string in metadata, because flags such as -mfloat-abi must be preserved for - // ModulePackImportsToLLVM. 
- module->addModuleFlag(llvm::Module::ModFlagBehavior::Override, "tvm_target", - llvm::MDString::get(*ctx, LLVMTargetToString(target))); + llvm_target->SetTargetMetadata(module.get()); module->setDataLayout(tm->createDataLayout()); auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false); auto* tvm_dev_mblob = new llvm::GlobalVariable( @@ -188,9 +183,10 @@ std::pair, std::shared_ptr> Cod ir_builder.CreateRetVoid(); } - return std::make_pair(std::move(module), ctx); + return module; } } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_blob.h b/src/target/llvm/codegen_blob.h index 46c037a30af2..a06c043c07b1 100644 --- a/src/target/llvm/codegen_blob.h +++ b/src/target/llvm/codegen_blob.h @@ -26,15 +26,18 @@ #ifdef TVM_LLVM_VERSION -#include -#include - #include #include -#include + +namespace llvm { +class Module; +} namespace tvm { namespace codegen { + +class LLVMTarget; + /** * \brief Code Generation of blob data * @@ -44,8 +47,8 @@ namespace codegen { * * \return LLVM module and LLVM context */ -std::pair, std::shared_ptr> CodeGenBlob( - const std::string& data, bool system_lib, const std::string& llvm_target_string); +std::unique_ptr CodeGenBlob(const std::string& data, bool system_lib, + LLVMTarget* llvm_target); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index f2ce6fb848b4..4db0df87f916 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -60,6 +60,7 @@ #include "../func_registry_generator.h" #include "../metadata_utils.h" +#include "llvm_scope.h" namespace tvm { namespace codegen { @@ -69,10 +70,9 @@ namespace codegen { CodeGenCPU::CodeGenCPU() = default; CodeGenCPU::~CodeGenCPU() = default; -void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup, - bool target_c_runtime) { - 
CodeGenLLVM::Init(module_name, tm, ctx, system_lib, dynamic_lookup, target_c_runtime); +void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime) { + CodeGenLLVM::Init(module_name, llvm_target, system_lib, dynamic_lookup, target_c_runtime); dbg_info_ = CreateDebugInfo(module_.get()); static_assert(sizeof(TVMValue) == sizeof(double), "invariant"); func_handle_map_.clear(); @@ -80,7 +80,8 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, // Runtime types. - t_tvm_shape_index_ = llvm::Type::getIntNTy(*ctx, DataType::ShapeIndex().bits()); + t_tvm_shape_index_ = + llvm::Type::getIntNTy(*llvm_target_->GetContext(), DataType::ShapeIndex().bits()); // Defined in 3rdparty/dlpack/include/dlpack/dlpack.h: // typedef struct { DLDeviceType device_type; int device_id; } DLDevice; t_tvm_device_ = llvm::StructType::create({t_int_, t_int_}); @@ -177,7 +178,7 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::Function::Create(ftype_tvm_parallel_barrier_, llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get()); } - this->InitGlobalContext(dynamic_lookup); + InitGlobalContext(dynamic_lookup); target_c_runtime_ = target_c_runtime; is_system_lib_ = system_lib; } @@ -240,6 +241,7 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { } llvm::DebugLoc DL; builder.SetCurrentDebugLocation(DL); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); for (size_t i = 0; i < f_llvm->arg_size(); ++i) { auto* paramAlloca = builder.CreateAlloca(f_llvm->getFunctionType()->getParamType(i)); std::string paramName = "arg" + std::to_string(i + 1); @@ -248,7 +250,7 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { GetDebugType(GetType(f_tir->params[i]), f_llvm->getFunctionType()->getParamType(i)), /*alwaysPreserve=*/true); auto* store = builder.CreateStore(f_llvm->arg_begin() 
+ i, paramAlloca); - auto* di_loc = llvm::DILocation::get(*ctx_, 0, 0, DIFunction); + auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, DIFunction); dbg_info_->di_builder_->insertDeclare(paramAlloca, param, dbg_info_->di_builder_->createExpression(), llvm::DebugLoc(di_loc), store); @@ -263,7 +265,7 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { if (I.getDebugLoc()) { continue; } - auto* di_loc = llvm::DILocation::get(*ctx_, 0, 0, scope); + auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, scope); I.setDebugLoc(llvm::DebugLoc(di_loc)); } } @@ -273,7 +275,7 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { llvm::DIType* CodeGenCPU::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) { if (ty_llvm == t_void_) { return nullptr; - } else if (ty_llvm == llvm::Type::getFloatTy(*ctx_)) { + } else if (ty_llvm == llvm::Type::getFloatTy(*llvm_target_->GetContext())) { return dbg_info_->di_builder_->createBasicType("float", 32, llvm::dwarf::DW_ATE_float); } else if (ty_llvm == t_int8_) { return dbg_info_->di_builder_->createBasicType("int8", 8, llvm::dwarf::DW_ATE_signed); @@ -311,13 +313,14 @@ void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { #endif // comdat is needed for windows select any linking to work // set comdat to Any(weak linking) - if (target_machine_->getTargetTriple().isOSWindows()) { + if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) { llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_module_main); comdat->setSelectionKind(llvm::Comdat::Any); global->setComdat(comdat); } - global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, entry_func_name)); + global->setInitializer( + llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), entry_func_name)); global->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass); } @@ -475,7 +478,7 @@ llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, 
std::string gv->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); // comdat is needed for windows select any linking to work // set comdat to Any(weak linking) - if (target_machine_->getTargetTriple().isOSWindows()) { + if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) { llvm::Comdat* comdat = module_->getOrInsertComdat(name); comdat->setSelectionKind(llvm::Comdat::Any); gv->setComdat(comdat); @@ -525,8 +528,9 @@ void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) { // create emit codes that checks and load the function. - auto* fail_block = llvm::BasicBlock::Create(*ctx_, "call_fail", function_); - auto* end_block = llvm::BasicBlock::Create(*ctx_, "call_end", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* fail_block = llvm::BasicBlock::Create(*ctx, "call_fail", function_); + auto* end_block = llvm::BasicBlock::Create(*ctx, "call_end", function_); auto* succ = builder_->CreateICmpEQ(retcode, llvm::ConstantInt::get(t_int_, 0)); builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_); builder_->SetInsertPoint(fail_block); @@ -584,6 +588,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { SetTargetAttributes(fcompute); llvm::BasicBlock* compute_call_end = CheckCallSuccess(builder_->CreateCall(fcompute, arg_values)); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); // enter compute scope and setup compute function. 
With scope_states_guard(this); size_t idx = 0; @@ -607,7 +612,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { if (f != alloc_storage_info_.end()) { unsigned align = f->second.alignment; if (align > 1) { - auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align); + auto attr = llvm::Attribute::get(*ctx, llvm::Attribute::Alignment, align); fcompute->addParamAttr(idx, attr); } } @@ -615,7 +620,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { } function_ = fcompute; - auto* compute_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + auto* compute_entry = llvm::BasicBlock::Create(*ctx, "entry", function_); builder_->SetInsertPoint(compute_entry); this->VisitStmt(op->body); builder_->CreateRet(ConstInt32(0)); @@ -679,7 +684,8 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task, std::strin launch_callee, {f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(num_task)})); // Setup the closure function. - auto* lambda_entry = llvm::BasicBlock::Create(*ctx_, "parallel_closure_entry", f); + auto* lambda_entry = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "parallel_closure_entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); llvm::Value* task_id = &(*it++); @@ -747,7 +753,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod llvm::BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall( finit, {gv, f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(nbytes)})); // Setup the closure function. 
- auto* lambda_entry = llvm::BasicBlock::Create(*ctx_, "entry", f); + auto* lambda_entry = llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType()); @@ -793,9 +799,10 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { hptr = it->second; } // create emit codes that checks and load the function. + llvm::LLVMContext* ctx = llvm_target_->GetContext(); llvm::BasicBlock* pre_block = builder_->GetInsertBlock(); - auto* init_block = llvm::BasicBlock::Create(*ctx_, "handle_init", function_); - auto* end_block = llvm::BasicBlock::Create(*ctx_, "handle_init_end", function_); + auto* init_block = llvm::BasicBlock::Create(*ctx, "handle_init", function_); + auto* end_block = llvm::BasicBlock::Create(*ctx, "handle_init_end", function_); #if TVM_LLVM_VERSION >= 110 llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, llvm::Align(align)); #elif TVM_LLVM_VERSION >= 80 @@ -811,22 +818,22 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { llvm::Value* out = WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); }); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, - llvm::Align(gv_mod_ctx_->getAlignment())); + llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + llvm::Align(gv_mod_ctx_->getAlignment())); #elif TVM_LLVM_VERSION >= 80 - llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, - gv_mod_ctx_->getAlignment()); + llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + gv_mod_ctx_->getAlignment()); #else - llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment()); + llvm::LoadInst* ctx_load = 
builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment()); #endif - ctx->setMetadata("tbaa", - md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); + ctx_load->setMetadata( + "tbaa", md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); #if TVM_LLVM_VERSION >= 90 auto env_callee = llvm::FunctionCallee(ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv()); #else auto env_callee = RuntimeTVMGetFuncFromEnv(); #endif - llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, GetConstString(fname), out}); + llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx_load, GetConstString(fname), out}); init_block = CheckCallSuccess(retcode); #if TVM_LLVM_VERSION >= 110 llvm::Value* loaded_handle = @@ -946,13 +953,14 @@ llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { ICHECK_EQ(op->args.size(), 6U); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, op->args[4].as()->value, true); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); // Get traced value. llvm::Value* traced_value = MakeValue(op->args[5]); // The update_block handles case when we need to update the return value. - llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx_, "update_block", function_); + llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx, "update_block", function_); // The continue_block handles case when we need to return original // traced value. - llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx_, "continue_block", function_); + llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx, "continue_block", function_); // Check the ret_type_code and create cmp instruction. 
llvm::Value* cmp = @@ -1254,14 +1262,15 @@ class MetadataSerializerLLVM : public AttrVisitor { }; void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { + llvm::LLVMContext* ctx = llvm_target_->GetContext(); MetadataLlvmTypes llvm_types{ t_float64_ /* t_float64 */, - llvm::Type::getInt8Ty(*ctx_) /* t_uint8 */, + llvm::Type::getInt8Ty(*ctx) /* t_uint8 */, t_int64_ /* t_int64 */, - llvm::Type::getInt8Ty(*ctx_) /* t_bool */, + llvm::Type::getInt8Ty(*ctx) /* t_bool */, t_char_->getPointerTo() /* t_cstring */, t_void_p_ /* t_void_p */, - llvm::StructType::create(*ctx_, {t_int8_, t_int8_, t_int8_}, "DLDataType") /* t_data_type */, + llvm::StructType::create(*ctx, {t_int8_, t_int8_, t_int8_}, "DLDataType") /* t_data_type */, }; // create sample ConstantInfoMetadata instance for MetadataTypeDefiner @@ -1278,7 +1287,7 @@ void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { metadata::DiscoverComplexTypesVisitor discover_complex{&queue}; discover_complex.Discover(metadata); - MetadataTypeDefiner definer{ctx_, &llvm_types}; + MetadataTypeDefiner definer{ctx, &llvm_types}; for (auto md : queue) { if (md.defined()) { definer.DefineType(md); @@ -1295,7 +1304,7 @@ void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { function_->setCallingConv(llvm::CallingConv::C); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); - llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx, "entry", function_); builder_->SetInsertPoint(entry_point_entry); auto ret_values_p = builder_->CreateBitCast(GetArg(function_, 3), t_void_p_->getPointerTo()); @@ -1350,7 +1359,8 @@ void CodeGenCPU::DefineFunctionRegistry(Array func_names) { function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, "TVMSystemLibEntryPoint", module_.get()); SetTargetAttributes(function_); - llvm::BasicBlock* 
entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + llvm::BasicBlock* entry_point_entry = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", function_); builder_->SetInsertPoint(entry_point_entry); builder_->CreateRet(builder_->CreateBitCast(module, t_void_p_)); } @@ -1361,7 +1371,8 @@ void CodeGenCPU::AddStartupFunction() { function_ = llvm::Function::Create(ftype, llvm::Function::InternalLinkage, "__tvm_module_startup", module_.get()); SetTargetAttributes(function_); - llvm::BasicBlock* startup_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + llvm::BasicBlock* startup_entry = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", function_); builder_->SetInsertPoint(startup_entry); for (const auto& kv : export_system_symbols_) { llvm::Value* name = GetConstString(kv.first); @@ -1385,7 +1396,8 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::tvm_throw_last_error())) { builder_->CreateRet(ConstInt32(-1)); auto next_block = std::next(builder_->GetInsertBlock()->getIterator()); - llvm::BasicBlock* new_bb = llvm::BasicBlock::Create(*ctx_, "cont", function_, &*next_block); + llvm::BasicBlock* new_bb = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "cont", function_, &*next_block); builder_->SetInsertPoint(new_bb); return ConstInt32(-1); } else if (op->op.same_as(builtin::tvm_struct_get())) { @@ -1443,8 +1455,9 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { os << ", " << op->message.as()->value; } llvm::Value* msg = GetConstString(os.str()); - auto* fail_block = llvm::BasicBlock::Create(*ctx_, "assert_fail", function_); - auto* end_block = llvm::BasicBlock::Create(*ctx_, "assert_end", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* fail_block = llvm::BasicBlock::Create(*ctx, "assert_fail", function_); + auto* end_block = llvm::BasicBlock::Create(*ctx, "assert_end", function_); builder_->CreateCondBr(cond, 
end_block, fail_block, md_very_likely_branch_); // fail condition. builder_->SetInsertPoint(fail_block); @@ -1549,4 +1562,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_cpu") } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index eec38b122a0b..e0716ac8be2d 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -24,6 +24,8 @@ #ifndef TVM_TARGET_LLVM_CODEGEN_CPU_H_ #define TVM_TARGET_LLVM_CODEGEN_CPU_H_ +#ifdef TVM_LLVM_VERSION + #include #include #include @@ -54,14 +56,16 @@ class Module; namespace tvm { namespace codegen { +class LLVMTarget; + // CPU host code generation class CodeGenCPU : public CodeGenLLVM { public: CodeGenCPU(); virtual ~CodeGenCPU(); - void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, - bool system_lib, bool dynamic_lookup, bool target_c_runtime) override; + void Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime) override; void AddFunction(const PrimFunc& f) override; void AddMainFunction(const std::string& entry_func_name) override; std::unique_ptr Finish() override; @@ -197,4 +201,6 @@ class CodeGenCPU : public CodeGenLLVM { } // namespace codegen } // namespace tvm + +#endif // TVM_LLVM_VERSION #endif // TVM_TARGET_LLVM_CODEGEN_CPU_H_ diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 7b0081869a27..200e108dad2b 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -62,7 +62,7 @@ #include "../../runtime/hexagon/hexagon_module.h" #include "../build_common.h" #include "codegen_cpu.h" -#include "llvm_common.h" +#include "llvm_scope.h" namespace tvm { namespace codegen { @@ -70,9 +70,9 @@ namespace codegen { // Hexagon code generation class CodeGenHexagon final : public CodeGenCPU { public: - void Init(const std::string& 
module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, - bool system_lib, bool dynamic_lookup, bool target_c_runtime) override; - void InitTarget(llvm::TargetMachine* tm) final; + void Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime) override; + void InitTarget() final; using CodeGenCPU::VisitStmt_; llvm::Value* VisitExpr_(const BufferLoadNode* op) override; @@ -102,29 +102,30 @@ class CodeGenHexagon final : public CodeGenCPU { llvm::Value* Intrinsic(llvm::Intrinsic::ID, llvm::ArrayRef args); }; -void CodeGenHexagon::Init(const std::string& module_name, llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup, - bool target_c_runtime) { - CodeGenCPU::Init(module_name, tm, ctx, system_lib, dynamic_lookup, target_c_runtime); +void CodeGenHexagon::Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime) { + CodeGenCPU::Init(module_name, llvm_target, system_lib, dynamic_lookup, target_c_runtime); } -void CodeGenHexagon::InitTarget(llvm::TargetMachine* tm) { - native_vector_bits_ = 64; // Assume "scalar" vectors at first. - llvm::StringRef fs = tm->getTargetFeatureString(); - size_t npos = llvm::StringRef::npos; +void CodeGenHexagon::InitTarget() { + native_vector_bits_ = 64; // Assume "scalar" vectors at first. const auto hvx_length_feature = "+hvx-length"; // +hvx-length{64|128}b - size_t len_begin = fs.find(hvx_length_feature); - size_t len_end = len_begin != npos ? 
fs.find('b', len_begin) : npos; - if (len_end != npos) { + for (const std::string& f : llvm_target_->GetTargetFeatures()) { + llvm::StringRef fs(f); + if (!fs.startswith(hvx_length_feature)) continue; + + ICHECK(fs.endswith("b")) << "malformed target feature: " << f; int hvx_bytes = 0; - len_begin += std::strlen(hvx_length_feature); - ICHECK(!fs.substr(len_begin, len_end - len_begin).getAsInteger(10, hvx_bytes)) - << "invalid HVX length in feature string: " << fs.str(); + size_t len_begin = std::strlen(hvx_length_feature); + ICHECK(!fs.substr(len_begin, fs.size() - len_begin - 1).getAsInteger(10, hvx_bytes)) + << "invalid HVX length in feature string: " << f; ICHECK(hvx_bytes == 64 || hvx_bytes == 128) << "invalid HVX vector length: " << hvx_bytes << ", should be 64 or 128"; native_vector_bits_ = hvx_bytes * 8; + // There should only be one hvx-length... + break; } - CodeGenLLVM::InitTarget(tm); + CodeGenLLVM::InitTarget(); } llvm::GlobalVariable* CodeGenHexagon::InitContextPtr(llvm::Type* p_type, std::string name) { @@ -451,9 +452,8 @@ void ProcessLLVMOptions(const std::vector& llvm_vec) { } // namespace runtime::Module BuildHexagon(IRModule mod, Target target) { - // Make sure all targets are registered. InitializeLLVM can be called - // multiple times, after the first call all subsequent calls are no-ops. 
- InitializeLLVM(); + LLVMScope llvm_scope; + With llvm_target(llvm_scope, target); auto split = [](const std::string& str, char delim = ' ') { std::vector vec; @@ -493,8 +493,6 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { static bool CallOnce = (ProcessLLVMOptions(llvm_options_vec), true); (void)CallOnce; - std::unique_ptr tm = GetLLVMTargetMachine(target); - std::unique_ptr ctx(new llvm::LLVMContext()); std::unique_ptr cg(new CodeGenHexagon()); std::vector funcs; @@ -515,7 +513,7 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { funcs.emplace_back(f); } - cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false); + cg->Init("TVMHexagonModule", llvm_target.get(), false, false, false); cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); if (entry_func.length() != 0) { cg->AddMainFunction(entry_func); @@ -527,7 +525,7 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { enum CodeGenFileType { Asm, Obj, IR, BC }; - auto EmitToString = [&tm](const llvm::Module& m, CodeGenFileType cgft) { + auto EmitToString = [&llvm_target](const llvm::Module& m, CodeGenFileType cgft) { std::string out; if (cgft == IR || cgft == BC) { @@ -548,6 +546,7 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { llvm::raw_svector_ostream os(ss); std::unique_ptr cm = llvm::CloneModule(m); llvm::legacy::PassManager pass; + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); ICHECK(tm->addPassesToEmitFile(pass, os, nullptr, ft) == 0) << "Cannot emit target code"; pass.run(*cm.get()); out.assign(ss.c_str(), ss.size()); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index f1d891e2c3bd..695eaa30493a 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -89,7 +89,7 @@ #include "../build_common.h" #include "../func_registry_generator.h" #include "codegen_params.h" -#include "llvm_common.h" +#include "llvm_scope.h" namespace tvm { namespace codegen { @@ -102,8 
+102,8 @@ CodeGenLLVM::CodeGenLLVM() = default; CodeGenLLVM::~CodeGenLLVM() = default; CodeGenLLVM::DebugInfo::~DebugInfo() = default; -std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine* tm) { - std::string target = tm->getTarget().getName(); +std::unique_ptr CodeGenLLVM::Create(LLVMTarget* llvm_target) { + std::string target = llvm_target->GetOrCreateTargetMachine()->getTarget().getName(); std::string factory_template = "tvm.codegen.llvm.target_"; void* handle = nullptr; if (const PackedFunc* f = runtime::Registry::Get(factory_template + target)) { @@ -121,38 +121,37 @@ std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine* tm) { } } -void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup, - bool target_c_runtime) { - InitializeLLVM(); - ctx_ = ctx; - builder_.reset(new IRBuilder(*ctx_)); - module_.reset(new llvm::Module(module_name, *ctx_)); - md_builder_.reset(new llvm::MDBuilder(*ctx_)); +void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime) { + llvm_target_ = llvm_target; + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + builder_.reset(new IRBuilder(*ctx)); + module_.reset(new llvm::Module(module_name, *ctx)); + md_builder_.reset(new llvm::MDBuilder(*ctx)); // types - t_void_ = llvm::Type::getVoidTy(*ctx_); - t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo(GetGlobalAddressSpace()); - t_int_ = llvm::Type::getInt32Ty(*ctx_); - t_char_ = llvm::Type::getInt8Ty(*ctx_); - t_int8_ = llvm::Type::getInt8Ty(*ctx_); - t_int16_ = llvm::Type::getInt16Ty(*ctx_); - t_int32_ = llvm::Type::getInt32Ty(*ctx_); - t_int64_ = llvm::Type::getInt64Ty(*ctx_); - t_float64_ = llvm::Type::getDoubleTy(*ctx_); + t_void_ = llvm::Type::getVoidTy(*ctx); + t_void_p_ = llvm::Type::getInt8Ty(*ctx)->getPointerTo(GetGlobalAddressSpace()); + t_int_ = llvm::Type::getInt32Ty(*ctx); + t_char_ = 
llvm::Type::getInt8Ty(*ctx); + t_int8_ = llvm::Type::getInt8Ty(*ctx); + t_int16_ = llvm::Type::getInt16Ty(*ctx); + t_int32_ = llvm::Type::getInt32Ty(*ctx); + t_int64_ = llvm::Type::getInt64Ty(*ctx); + t_float64_ = llvm::Type::getDoubleTy(*ctx); // meta data md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1); md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa"); md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_); - this->InitTarget(tm); + InitTarget(); } void CodeGenLLVM::SetFastMathFlag(llvm::FastMathFlags fmf) { builder_->setFastMathFlags(fmf); } -void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { +void CodeGenLLVM::InitTarget() { + llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine(); module_->setTargetTriple(tm->getTargetTriple().str()); module_->setDataLayout(tm->createDataLayout()); data_layout_.reset(new llvm::DataLayout(module_.get())); - target_machine_ = tm; if (native_vector_bits_ == 0) { const auto& arch = tm->getTargetTriple().getArch(); if (arch == llvm::Triple::x86_64) { @@ -230,7 +229,8 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { } } } - llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx, "entry", function_); builder_->SetInsertPoint(entry); this->VisitStmt(f->body); @@ -242,7 +242,7 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { if (f != alloc_storage_info_.end()) { unsigned align = f->second.alignment; if (align > 1) { - auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align); + auto attr = llvm::Attribute::get(*ctx, llvm::Attribute::Alignment, align); function_->addParamAttr(i, attr); } } @@ -269,28 +269,16 @@ std::unique_ptr CodeGenLLVM::Finish() { } void CodeGenLLVM::HandleImport(const std::string& code) { + llvm::StringRef code_str(code); std::unique_ptr mlib; 
- llvm::SMDiagnostic err; - if (code.length() >= 3 && - (code.substr(code.length() - 3) == ".ll" || code.substr(code.length() - 3) == ".bc")) { - mlib = llvm::parseIRFile(code, err, *ctx_); - if (mlib.get() == nullptr) { - std::string msg = std::string(err.getMessage()); - LOG(FATAL) << "Fail to load bitcode file " << code << "\n" - << "line " << err.getLineNo() << ":" << msg; - } + if (code_str.endswith(".ll") || code_str.endswith(".bc")) { + mlib = llvm_target_->GetScope().LoadIR(code); } else { - std::unique_ptr buf = llvm::MemoryBuffer::getMemBuffer(code); - mlib = llvm::parseIR(*buf, err, *ctx_); - if (mlib.get() == nullptr) { - std::string msg = std::string(err.getMessage()); - LOG(FATAL) << "Fail to load llvm ir " - << "line " << err.getLineNo() << ":" << msg << "\ncontent:\n" - << code; - } + mlib = llvm_target_->GetScope().ParseIR(code); } - mlib->setTargetTriple(target_machine_->getTargetTriple().str()); - mlib->setDataLayout(target_machine_->createDataLayout()); + + mlib->setTargetTriple(llvm_target_->GetTargetTriple()); + mlib->setDataLayout(llvm_target_->GetOrCreateTargetMachine()->createDataLayout()); // mark all the functions as force inline for (llvm::Function& f : mlib->functions()) { f.removeFnAttr(llvm::Attribute::NoInline); @@ -338,16 +326,15 @@ void CodeGenLLVM::Optimize() { // pass manager FPassManager fpass(module_.get()); MPassManager mpass; - mpass.add(llvm::createTargetTransformInfoWrapperPass( - target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis())); - fpass.add(llvm::createTargetTransformInfoWrapperPass( - target_machine_ ? 
target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis())); + llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine(); + mpass.add(llvm::createTargetTransformInfoWrapperPass(tm->getTargetIRAnalysis())); + fpass.add(llvm::createTargetTransformInfoWrapperPass(tm->getTargetIRAnalysis())); // place optimization pass llvm::PassManagerBuilder builder; // Use the same opt-level as specified in TargetMachine for running passes - llvm::CodeGenOpt::Level opt_level = target_machine_->getOptLevel(); + llvm::CodeGenOpt::Level opt_level = llvm_target_->GetOptLevel(); switch (opt_level) { case llvm::CodeGenOpt::Level::None: @@ -376,7 +363,7 @@ void CodeGenLLVM::Optimize() { this->InitPassManagerBuilder(&builder); #if TVM_LLVM_VERSION >= 50 - target_machine_->adjustPassManager(builder); + tm->adjustPassManager(builder); #endif builder.populateFunctionPassManager(fpass); @@ -405,18 +392,19 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { return t_void_; } llvm::Type* etype = nullptr; + llvm::LLVMContext* ctx = llvm_target_->GetContext(); if (dtype.is_int() || dtype.is_uint()) { - etype = llvm::Type::getIntNTy(*ctx_, dtype.bits()); + etype = llvm::Type::getIntNTy(*ctx, dtype.bits()); } else if (dtype.is_float()) { switch (dtype.bits()) { case 16: - etype = llvm::Type::getHalfTy(*ctx_); + etype = llvm::Type::getHalfTy(*ctx); break; case 32: - etype = llvm::Type::getFloatTy(*ctx_); + etype = llvm::Type::getFloatTy(*ctx); break; case 64: - etype = llvm::Type::getDoubleTy(*ctx_); + etype = llvm::Type::getDoubleTy(*ctx); break; default: LOG(FATAL) << "do not support " << dtype; @@ -702,9 +690,10 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va const Var& loop_var, const Stmt& body) { llvm::BasicBlock* pre_block = builder_->GetInsertBlock(); std::string loop_var_name = loop_var->name_hint; - auto* for_begin = llvm::BasicBlock::Create(*ctx_, "for_begin_" + loop_var_name, function_); - auto* for_body = 
llvm::BasicBlock::Create(*ctx_, "for_body_" + loop_var_name, function_); - auto* for_end = llvm::BasicBlock::Create(*ctx_, "for_end_" + loop_var_name, function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* for_begin = llvm::BasicBlock::Create(*ctx, "for_begin_" + loop_var_name, function_); + auto* for_body = llvm::BasicBlock::Create(*ctx, "for_body_" + loop_var_name, function_); + auto* for_end = llvm::BasicBlock::Create(*ctx, "for_end_" + loop_var_name, function_); builder_->CreateBr(for_begin); builder_->SetInsertPoint(for_begin); llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2); @@ -777,7 +766,7 @@ llvm::Constant* CodeGenLLVM::GetGlobalConstant(llvm::Constant* const_data, const llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { auto it = str_map_.find(str); if (it != str_map_.end()) return it->second; - auto llvm_str = llvm::ConstantDataArray::getString(*ctx_, str); + auto llvm_str = llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), str); auto ptr = GetGlobalConstant(llvm_str, ".str", llvm::GlobalValue::PrivateLinkage); str_map_[str] = ptr; return ptr; @@ -950,11 +939,11 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type } void CodeGenLLVM::SetTargetAttributes(llvm::Function* func) { - llvm::StringRef cpu = target_machine_->getTargetCPU(); + const std::string& cpu = llvm_target_->GetCPU(); if (!cpu.empty()) { func->addFnAttr("target-cpu", cpu); } - llvm::StringRef features = target_machine_->getTargetFeatureString(); + const std::string& features = llvm_target_->GetTargetFeatureString(); if (!features.empty()) { func->addFnAttr("target-features", features); } @@ -980,8 +969,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { // mismatch will have to be treated specially here. // TODO(kparzysz-quic): fix this once TVM prefetch uses the same // type as LLVM. - llvm::Type* return_type = (id != llvm::Intrinsic::prefetch) ? 
GetLLVMType(GetRef(op)) - : llvm::Type::getVoidTy(*ctx_); + llvm::Type* return_type = + (id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef(op)) : t_void_; llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " #if TVM_LLVM_VERSION >= 130 @@ -1039,9 +1028,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val); } else if (op->op.same_as(builtin::if_then_else())) { ICHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; - auto* then_block = llvm::BasicBlock::Create(*ctx_, "if_then", function_); - auto* else_block = llvm::BasicBlock::Create(*ctx_, "if_else", function_); - auto* end_block = llvm::BasicBlock::Create(*ctx_, "if_end", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_); + auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else", function_); + auto* end_block = llvm::BasicBlock::Create(*ctx, "if_end", function_); builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block); builder_->SetInsertPoint(then_block); llvm::Value* then_value = MakeValue(op->args[1]); @@ -1065,7 +1055,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { builder_->CreateRet(ConstInt32(0)); // LLVM allows exactly one terminator in a single basic block // append a new dummy basic block to avoid error. 
- llvm::BasicBlock* ret_dummy = llvm::BasicBlock::Create(*ctx_, "ret_dummy", function_); + llvm::BasicBlock* ret_dummy = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "ret_dummy", function_); builder_->SetInsertPoint(ret_dummy); return ret_dummy; } else if (op->op.same_as(builtin::reinterpret())) { @@ -1519,9 +1510,10 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { } void CodeGenLLVM::VisitStmt_(const WhileNode* op) { - auto* while_cond = llvm::BasicBlock::Create(*ctx_, "while_cond", function_); - auto* while_body = llvm::BasicBlock::Create(*ctx_, "while_body", function_); - auto* while_merge = llvm::BasicBlock::Create(*ctx_, "while_merge", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* while_cond = llvm::BasicBlock::Create(*ctx, "while_cond", function_); + auto* while_body = llvm::BasicBlock::Create(*ctx, "while_body", function_); + auto* while_merge = llvm::BasicBlock::Create(*ctx, "while_merge", function_); builder_->CreateBr(while_cond); builder_->SetInsertPoint(while_cond); builder_->CreateCondBr(MakeValue(op->condition), while_body, while_merge); @@ -1533,10 +1525,11 @@ void CodeGenLLVM::VisitStmt_(const WhileNode* op) { void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { llvm::Value* cond = MakeValue(op->condition); - auto* then_block = llvm::BasicBlock::Create(*ctx_, "if_then", function_); - auto* end_block = llvm::BasicBlock::Create(*ctx_, "if_end", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_); + auto* end_block = llvm::BasicBlock::Create(*ctx, "if_end", function_); if (op->else_case.defined()) { - auto* else_block = llvm::BasicBlock::Create(*ctx_, "if_else", function_); + auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else", function_); builder_->CreateCondBr(cond, then_block, else_block); builder_->SetInsertPoint(then_block); this->VisitStmt(op->then_case); @@ -1555,7 +1548,7 @@ void 
CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) { auto data = op->data.value(); - auto array = codegen::NDArrayToLLVMArray(ctx_, data); + auto array = NDArrayToLLVMArray(llvm_target_->GetContext(), data); std::string symbol_name = op->buffer_var->name_hint; llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); @@ -1673,4 +1666,5 @@ void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index c6129c238c7f..d91f9c5fab33 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -23,6 +23,7 @@ */ #ifndef TVM_TARGET_LLVM_CODEGEN_LLVM_H_ #define TVM_TARGET_LLVM_CODEGEN_LLVM_H_ + #ifdef TVM_LLVM_VERSION #include @@ -40,7 +41,6 @@ #include #include #include -#include #include #if TVM_LLVM_VERSION >= 140 #include @@ -78,7 +78,6 @@ class Function; class GlobalVariable; class Instruction; class PassManagerBuilder; -class TargetMachine; class DIFile; class DICompileUnit; class MDNode; @@ -93,6 +92,8 @@ class MDBuilder; namespace tvm { namespace codegen { +class LLVMTarget; + using namespace tir; /*! @@ -109,7 +110,7 @@ class CodeGenLLVM : public ExprFunctor, * \param tm The target machine * \return The created llvm generator. */ - static std::unique_ptr Create(llvm::TargetMachine* tm); + static std::unique_ptr Create(LLVMTarget* llvm_target); /*! * \brief Initialize the code generator with given context * \param module_name The name of the module. @@ -121,8 +122,8 @@ class CodeGenLLVM : public ExprFunctor, * \param target_c_runtime If true, generate a module to be executed by the C runtime. In practice * this option influences whether global ctors are used. 
*/ - virtual void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, - bool system_lib, bool dynamic_lookup, bool target_c_runtime); + virtual void Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime); /*! * \brief Turn on fast math flags for floating point operations. @@ -229,9 +230,6 @@ class CodeGenLLVM : public ExprFunctor, llvm::Constant* GetGlobalConstant( llvm::Constant* const_data, const std::string& name = "", llvm::GlobalValue::LinkageTypes linkage_type = llvm::GlobalValue::InternalLinkage); - inline llvm::ConstantArray* NDArrayToLLVMArray(::tvm::runtime::NDArray arr) { - return codegen::NDArrayToLLVMArray(ctx_, arr); - } protected: /*! @@ -340,7 +338,7 @@ class CodeGenLLVM : public ExprFunctor, bool is_volatile)> make_instruction); // Initialize target - virtual void InitTarget(llvm::TargetMachine* tm); + virtual void InitTarget(); // Add module startup function if needed. virtual void AddStartupFunction() {} // apply optimization on the module. @@ -476,10 +474,8 @@ class CodeGenLLVM : public ExprFunctor, std::unique_ptr data_layout_; // Internal metabuilder std::unique_ptr md_builder_; - // llvm target machine - llvm::TargetMachine* target_machine_{nullptr}; - // llvm context - llvm::LLVMContext* ctx_{nullptr}; + // llvm target info + LLVMTarget* llvm_target_{nullptr}; // helpful data types llvm::Type* t_void_{nullptr}; llvm::PointerType* t_void_p_{nullptr}; @@ -495,7 +491,7 @@ class CodeGenLLVM : public ExprFunctor, llvm::MDNode* md_tbaa_root_{nullptr}; llvm::MDNode* md_tbaa_alias_set_{nullptr}; // modules to be linked. - std::vector > link_modules_; + std::vector> link_modules_; /*! \brief native vector bits of current targetx*/ int native_vector_bits_{0}; /*! 
\brief the storage scope of allocation */ @@ -567,5 +563,6 @@ void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfu } // namespace codegen } // namespace tvm -#endif // LLVM_VERSION + +#endif // TVM_LLVM_VERSION #endif // TVM_TARGET_LLVM_CODEGEN_LLVM_H_ diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index a74274009cf4..67184c495c80 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -56,7 +56,7 @@ #include "../../runtime/cuda/cuda_module.h" #include "../build_common.h" #include "codegen_llvm.h" -#include "llvm_common.h" +#include "llvm_scope.h" namespace tvm { namespace codegen { @@ -68,10 +68,11 @@ class CodeGenNVPTX : public CodeGenLLVM { // add function as void return value CodeGenLLVM::AddFunctionInternal(f, true); // annotate as kernel function + llvm::LLVMContext* ctx = llvm_target_->GetContext(); module_->getOrInsertNamedMetadata("nvvm.annotations") ->addOperand(llvm::MDNode::get( - *ctx_, {llvm::ValueAsMetadata::get(function_), llvm::MDString::get(*ctx_, "kernel"), - llvm::ValueAsMetadata::get(ConstInt32(1))})); + *ctx, {llvm::ValueAsMetadata::get(function_), llvm::MDString::get(*ctx, "kernel"), + llvm::ValueAsMetadata::get(ConstInt32(1))})); } void VisitStmt_(const AllocateNode* op) final { @@ -203,10 +204,10 @@ class CodeGenNVPTX : public CodeGenLLVM { llvm::Value* CreateIntrinsic(const CallNode* op) override; protected: - void InitTarget(llvm::TargetMachine* tm) final { + void InitTarget() final { // Maximum vector lane = float4 native_vector_bits_ = 4 * 32; - CodeGenLLVM::InitTarget(tm); + CodeGenLLVM::InitTarget(); } }; @@ -298,15 +299,13 @@ int GetCUDAComputeVersion(const Target& target) { } runtime::Module BuildNVPTX(IRModule mod, Target target) { - InitializeLLVM(); + LLVMScope llvm_scope; + With llvm_target(llvm_scope, target); + int compute_ver = GetCUDAComputeVersion(target); - std::unique_ptr tm = GetLLVMTargetMachine(target); - std::unique_ptr 
ctx(new llvm::LLVMContext()); - // careful: cg will hold a naked pointer reference to ctx, so it should - // have a shorter lifetime than the ctx. std::unique_ptr cg(new CodeGenNVPTX()); - cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false, false); + cg->Init("TVMPTXModule", llvm_target.get(), false, false, false); cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) { ICHECK(kv.second->template IsInstance()) @@ -314,18 +313,13 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) { return Downcast(kv.second); }); + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path"); if (flibdevice_path != nullptr) { std::string path = (*flibdevice_path)(compute_ver); if (path.length() != 0) { - llvm::SMDiagnostic err; - std::unique_ptr mlib = llvm::parseIRFile(path, err, *ctx); - if (mlib.get() == nullptr) { - std::string msg(err.getMessage()); - LOG(FATAL) << "Fail to load bitcode file " << path << "\n" - << "line " << err.getLineNo() << ":" << msg; - } - mlib->setTargetTriple(tm->getTargetTriple().str()); + std::unique_ptr mlib = llvm_scope.LoadIR(path); + mlib->setTargetTriple(llvm_target->GetTargetTriple()); mlib->setDataLayout(tm->createDataLayout()); cg->AddLinkModule(std::move(mlib)); } @@ -365,4 +359,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_nvptx") } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index 2d36e0b022e1..7793317fe792 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -38,6 +38,7 @@ #include #include "codegen_cpu.h" +#include "llvm_scope.h" namespace tvm { namespace codegen { @@ -91,9 +92,9 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { const auto to = op->dtype; if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) { 
ICHECK_EQ(from.lanes(), to.lanes()); - CHECK_NOTNULL(target_machine_); + llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine(); - const auto has_avx512 = TargetHasFeature(*target_machine_, "avx512f"); + const auto has_avx512 = TargetHasFeature(*tm, "avx512f"); if (from.lanes() >= 16 && has_avx512) { return CallVectorIntrin( @@ -110,7 +111,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { #if TVM_LLVM_VERSION <= 100 // The intrinsic x86_vcvtph2ps_256 was removed in LLVM 11. - const auto has_f16c = TargetHasFeature(*target_machine_, "f16c"); + const auto has_f16c = TargetHasFeature(*tm, "f16c"); if (from.lanes() >= 8 && has_f16c) { return CallVectorIntrin(llvm::Intrinsic::x86_vcvtph2ps_256, 8, @@ -168,4 +169,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64") } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/llvm_common.cc b/src/target/llvm/llvm_common.cc deleted file mode 100644 index 83de839a926e..000000000000 --- a/src/target/llvm/llvm_common.cc +++ /dev/null @@ -1,211 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file llvm_common.cc - */ -#ifdef TVM_LLVM_VERSION - -#include "llvm_common.h" - -#if TVM_LLVM_VERSION >= 140 -#include -#else -#include -#endif -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace tvm { -namespace codegen { - -struct LLVMEnv { - std::mutex mu; - std::atomic all_initialized{false}; - - static LLVMEnv* Global() { - static LLVMEnv inst; - return &inst; - } -}; - -void InitializeLLVM() { - LLVMEnv* e = LLVMEnv::Global(); - if (!e->all_initialized.load(std::memory_order::memory_order_acquire)) { - std::lock_guard lock(e->mu); - if (!e->all_initialized.load(std::memory_order::memory_order_acquire)) { - llvm::InitializeAllTargetInfos(); - llvm::InitializeAllTargets(); - llvm::InitializeAllTargetMCs(); - llvm::InitializeAllAsmParsers(); - llvm::InitializeAllAsmPrinters(); - e->all_initialized.store(true, std::memory_order::memory_order_release); - } - } -} - -void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::string* mcpu, - std::string* mattr, llvm::TargetOptions* options) { - // simple parser - triple->resize(0); - mcpu->resize(0); - mattr->resize(0); - bool soft_float_abi = false; - if (const Optional& v = target->GetAttr("mtriple")) { - *triple = v.value(); - } - if (const Optional& v = target->GetAttr("mcpu")) { - *mcpu = v.value(); - } - if (const Optional>& v = target->GetAttr>("mattr")) { - std::ostringstream os; - bool is_first = true; - for (const String& s : v.value()) { - if (!is_first) { - os << ','; - } - is_first = false; - os << s; - } - *mattr = os.str(); - } - if (const Optional& v = target->GetAttr("mfloat-abi")) { - String value = v.value(); - if (value == "hard") { -#if TVM_LLVM_VERSION < 60 - LOG(FATAL) << "-mfloat-abi hard is only supported for LLVM > 6.0"; -#endif - soft_float_abi = false; - } else if (value == "soft") { - soft_float_abi = true; - } else { - LOG(FATAL) << 
"invalid -mfloat-abi option " << value; - } - } - if (triple->length() == 0 || *triple == "default") { - *triple = llvm::sys::getDefaultTargetTriple(); - } - // set target option - llvm::TargetOptions& opt = *options; - opt = llvm::TargetOptions(); -#if TVM_LLVM_VERSION < 50 - opt.LessPreciseFPMADOption = true; -#endif - // In clang, these are fed from LangOpts which describe language specific features - // TODO(AndrewZhaoLuo): figure out how these relate to fast math flags - opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; - opt.UnsafeFPMath = false; - opt.NoInfsFPMath = false; - opt.NoNaNsFPMath = true; - if (soft_float_abi) { - opt.FloatABIType = llvm::FloatABI::Soft; - } else { - opt.FloatABIType = llvm::FloatABI::Hard; - } - if (const Optional& v = target->GetAttr("mabi")) { - opt.MCOptions.ABIName = v.value(); - } -} - -std::unique_ptr GetLLVMTargetMachine(const Target& target, bool allow_null) { - std::string target_triple, mcpu, mattr; - llvm::TargetOptions opt; - - ParseLLVMTargetOptions(target, &target_triple, &mcpu, &mattr, &opt); - - if (target_triple.length() == 0 || target_triple == "default") { - target_triple = llvm::sys::getDefaultTargetTriple(); - } - if (mcpu.length() == 0) { - mcpu = "generic"; - } - - std::string err; - const llvm::Target* llvm_target = llvm::TargetRegistry::lookupTarget(target_triple, err); - if (llvm_target == nullptr) { - ICHECK(allow_null) << err << " target_triple=" << target_triple; - return nullptr; - } - - int llvm_opt_level = target->GetAttr("opt-level").value_or(Integer(3)).IntValue(); - llvm::CodeGenOpt::Level llvm_opt; - if (llvm_opt_level <= 0) { - llvm_opt = llvm::CodeGenOpt::None; - } else if (llvm_opt_level == 1) { - llvm_opt = llvm::CodeGenOpt::Less; - } else if (llvm_opt_level == 2) { - llvm_opt = llvm::CodeGenOpt::Default; - } else { - // llvm_opt_level >= 3 - llvm_opt = llvm::CodeGenOpt::Aggressive; - } - - llvm::TargetMachine* tm = llvm_target->createTargetMachine( - target_triple, mcpu, mattr, opt, 
llvm::Reloc::PIC_, llvm::CodeModel::Small, llvm_opt); - return std::unique_ptr(tm); -} - -std::string LLVMTargetToString(const Target& target) { - std::ostringstream os; - os << "llvm"; - if (Optional mtriple = target->GetAttr("mtriple")) { - os << " -mtriple=" << mtriple.value(); - } - if (Optional mcpu = target->GetAttr("mcpu")) { - os << " -mcpu=" << mcpu.value(); - } - if (Optional> mattr = target->GetAttr>("mattr")) { - bool is_first = true; - os << " -mattr="; - for (const String& attr : mattr.value()) { - if (!is_first) { - os << ","; - } - is_first = false; - os << attr; - } - } - if (Optional mfloat_abo = target->GetAttr("mfloat-abi")) { - os << " -mfloat-abi=" << mfloat_abo.value(); - } - if (Optional mabi = target->GetAttr("mabi")) { - os << " -mabi=" << mabi.value(); - } - return os.str(); -} - -} // namespace codegen -} // namespace tvm -#endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h deleted file mode 100644 index c127b77c03ac..000000000000 --- a/src/target/llvm/llvm_common.h +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file llvm_common.h - * \brief Common utilities for llvm initialization. 
- */ -#ifndef TVM_TARGET_LLVM_LLVM_COMMON_H_ -#define TVM_TARGET_LLVM_LLVM_COMMON_H_ - -#ifdef _MSC_VER -#pragma warning(disable : 4141 4291 4146 4624) -#endif -#ifdef TVM_LLVM_VERSION - -#include - -#include -#include -#include - -namespace llvm { -class Module; -class Target; -class TargetMachine; -class TargetOptions; -} // namespace llvm - -namespace tvm { - -// The TVM target -class Target; - -namespace codegen { - -/*! - * \brief Initialize LLVM on this process, - * can be called multiple times. - */ -void InitializeLLVM(); - -/*! - * \brief Parse target options - * \param target The TVM target - * \param triple Target triple - * \param mcpu cpu info - * \param options the options - * \param mattr The attributes - */ -void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::string* mcpu, - std::string* mattr, llvm::TargetOptions* options); - -/*! - * \brief Get target machine from TVM target. - * \param target The TVM target - * \param allow_null Whether allow null to be returned. - * \return target machine - */ -std::unique_ptr GetLLVMTargetMachine(const Target& target, - bool allow_null = false); - -/*! 
- * \brief Convert the TVM's LLVM target to string by extracting only relevant fields - * \param target The TVM target to be extracted - * \return The raw string format for the TVM LLVM target - */ -std::string LLVMTargetToString(const Target& target); - -} // namespace codegen -} // namespace tvm - -#endif // TVM_LLVM_VERSION -#endif // TVM_TARGET_LLVM_LLVM_COMMON_H_ diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 80731895a4f6..919fca879582 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -51,11 +51,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include @@ -74,7 +76,7 @@ #include "codegen_blob.h" #include "codegen_cpu.h" #include "codegen_llvm.h" -#include "llvm_common.h" +#include "llvm_scope.h" namespace tvm { namespace codegen { @@ -85,392 +87,332 @@ using runtime::TVMRetValue; class LLVMModuleNode final : public runtime::ModuleNode { public: - ~LLVMModuleNode() { - module_owning_ptr_.reset(); - if (ee_ != nullptr) { - ee_->runStaticConstructorsDestructors(true); - delete ee_; - } - } + ~LLVMModuleNode(); const char* type_key() const final { return "llvm"; } - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - if (name == "__tvm_is_system_module") { - bool flag = (module_->getFunction("__tvm_module_startup") != nullptr); - return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; }); - } else if (name == "get_func_names") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->function_names_; }); - } else if (name == "get_symbol") { - return PackedFunc(nullptr); - } else if (name == "get_const_vars") { - return PackedFunc(nullptr); - } else if (name == "_get_target_string") { - std::string target_string = LLVMTargetToString(target_); - return PackedFunc([target_string](TVMArgs args, TVMRetValue* rv) { *rv = target_string; }); - } - if (ee_ == 
nullptr) LazyInitJIT(); - - std::lock_guard lock(mutex_); - - TVMBackendPackedCFunc faddr; - if (name == runtime::symbol::tvm_module_main) { - const char* entry_name = - reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_main)); - ICHECK(entry_name != nullptr) - << "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; - faddr = reinterpret_cast(GetFunctionAddr(entry_name)); - } else { - faddr = reinterpret_cast(GetFunctionAddr(name)); - } - if (faddr == nullptr) return PackedFunc(); - return WrapPackedFunc(faddr, sptr_to_self); + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + + void SaveToFile(const std::string& file_name, const std::string& format) final; + void SaveToBinary(dmlc::Stream* stream) final; + std::string GetSource(const std::string& format) final; + + void Init(const IRModule& mod, const Target& target); + void Init(std::unique_ptr module, std::unique_ptr llvm_scope); + void LoadIR(const std::string& file_name); + bool IsDSOExportable() const final { return true; } + + bool ImplementsFunction(const String& name, bool query_imports) final; + + private: + void LazyInitJIT(); + bool IsCompatibleWithHost(const llvm::TargetMachine* tm) const; + void* GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) const; + void* GetFunctionAddr(const std::string& name, const LLVMTarget& llvm_target) const; + + // The LLVM scope object. + std::unique_ptr llvm_scope_; + // JIT lock + std::mutex mutex_; + // execution engine + llvm::ExecutionEngine* ee_{nullptr}; + // The raw pointer to the module. + llvm::Module* module_{nullptr}; + // The unique_ptr owning the module. This becomes empty once JIT has been initialized + // (EngineBuilder takes ownership of the module). 
+ std::unique_ptr module_owning_ptr_; + /* \brief names of the functions declared in this module */ + Array function_names_; +}; + +LLVMModuleNode::~LLVMModuleNode() { + if (ee_ != nullptr) { + ee_->runStaticConstructorsDestructors(true); + delete ee_; } + module_owning_ptr_.reset(); +} - void SaveToFile(const std::string& file_name, const std::string& format) final { - std::string fmt = runtime::GetFileFormat(file_name, format); - std::error_code ecode; +PackedFunc LLVMModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + if (name == "__tvm_is_system_module") { + bool flag = (module_->getFunction("__tvm_module_startup") != nullptr); + return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; }); + } else if (name == "get_func_names") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->function_names_; }); + } else if (name == "get_symbol") { + return PackedFunc(nullptr); + } else if (name == "get_const_vars") { + return PackedFunc(nullptr); + } else if (name == "_get_target_string") { + std::string target_string = LLVMTarget::GetTargetMetadata(*module_); + return PackedFunc([target_string](TVMArgs args, TVMRetValue* rv) { *rv = target_string; }); + } + if (ee_ == nullptr) LazyInitJIT(); + + std::lock_guard lock(mutex_); + + TVMBackendPackedCFunc faddr; + With llvm_target(*llvm_scope_, LLVMTarget::GetTargetMetadata(*module_)); + if (name == runtime::symbol::tvm_module_main) { + const char* entry_name = reinterpret_cast( + GetGlobalAddr(runtime::symbol::tvm_module_main, *llvm_target)); + ICHECK(entry_name != nullptr) << "Symbol " << runtime::symbol::tvm_module_main + << " is not presented"; + faddr = reinterpret_cast(GetFunctionAddr(entry_name, *llvm_target)); + } else { + faddr = reinterpret_cast(GetFunctionAddr(name, *llvm_target)); + } + if (faddr == nullptr) return PackedFunc(); + return WrapPackedFunc(faddr, sptr_to_self); +} + +void LLVMModuleNode::SaveToFile(const std::string& 
file_name, const std::string& format) { + std::string fmt = runtime::GetFileFormat(file_name, format); + std::error_code ecode; #if TVM_LLVM_VERSION <= 70 - llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None); + llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None); #else - llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::OF_None); + llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::OF_None); #endif - ICHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " << ecode.message(); - if (fmt == "o" || fmt == "obj") { + ICHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " << ecode.message(); + if (fmt == "o" || fmt == "obj") { + With llvm_target(*llvm_scope_, LLVMTarget::GetTargetMetadata(*module_)); #if TVM_LLVM_VERSION <= 60 - std::unique_ptr m = llvm::CloneModule(module_); + std::unique_ptr m = llvm::CloneModule(module_); #else - std::unique_ptr m = llvm::CloneModule(*module_); + std::unique_ptr m = llvm::CloneModule(*module_); #endif - llvm::legacy::PassManager pass; - ICHECK(tm_); + llvm::legacy::PassManager pass; + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); #if TVM_LLVM_VERSION <= 60 - ICHECK(tm_->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0) - << "Cannot emit target CGFT_ObjectFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 90 - ICHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == - 0) - << "Cannot emit target CGFT_ObjectFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #else - ICHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_ObjectFile) == 0) - << "Cannot emit target CGFT_ObjectFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_ObjectFile) == 
0) + << "Cannot emit target CGFT_ObjectFile"; #endif - pass.run(*m); - } else if (fmt == "s" || fmt == "asm") { + pass.run(*m); + } else if (fmt == "s" || fmt == "asm") { + With llvm_target(*llvm_scope_, LLVMTarget::GetTargetMetadata(*module_)); #if TVM_LLVM_VERSION <= 60 - std::unique_ptr m = llvm::CloneModule(module_); + std::unique_ptr m = llvm::CloneModule(module_); #else - std::unique_ptr m = llvm::CloneModule(*module_); + std::unique_ptr m = llvm::CloneModule(*module_); #endif - llvm::legacy::PassManager pass; - ICHECK(tm_); + llvm::legacy::PassManager pass; + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); #if TVM_LLVM_VERSION <= 60 - ICHECK(tm_->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 90 - ICHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, - llvm::TargetMachine::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == + 0) + << "Cannot emit target CGFT_AssemblyFile"; #else - ICHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; #endif - pass.run(*m); - } else if (fmt == "ll") { - module_->print(dest, nullptr); - } else if (fmt == "bc") { + pass.run(*m); + } else if (fmt == "ll") { + module_->print(dest, nullptr); + } else if (fmt == "bc") { #if TVM_LLVM_VERSION <= 60 - llvm::WriteBitcodeToFile(module_, dest); + llvm::WriteBitcodeToFile(module_, dest); #else - llvm::WriteBitcodeToFile(*module_, dest); + llvm::WriteBitcodeToFile(*module_, dest); #endif - } else { - LOG(FATAL) << "Do not 
know how to save file " << file_name << " with format=\'" << format - << "\'"; - } - dest.close(); + } else { + LOG(FATAL) << "Do not know how to save file " << file_name << " with format=\'" << format + << "\'"; } + dest.close(); +} - void SaveToBinary(dmlc::Stream* stream) final { - LOG(FATAL) << "LLVMModule: SaveToBinary not supported"; - } +void LLVMModuleNode::SaveToBinary(dmlc::Stream* stream) { + LOG(FATAL) << "LLVMModule: SaveToBinary not supported"; +} - std::string GetSource(const std::string& format) final { - std::string fmt = runtime::GetFileFormat("", format); - std::string type_str; - llvm::SmallString<256> str; - llvm::raw_svector_ostream rso(str); +std::string LLVMModuleNode::GetSource(const std::string& format) { + std::string fmt = runtime::GetFileFormat("", format); + std::string type_str; + llvm::SmallString<256> str; + llvm::raw_svector_ostream rso(str); - if (fmt == "s" || fmt == "asm") { + if (fmt == "s" || fmt == "asm") { + With llvm_target(*llvm_scope_, LLVMTarget::GetTargetMetadata(*module_)); #if TVM_LLVM_VERSION <= 60 - std::unique_ptr m = llvm::CloneModule(module_); + std::unique_ptr m = llvm::CloneModule(module_); #else - std::unique_ptr m = llvm::CloneModule(*module_); + std::unique_ptr m = llvm::CloneModule(*module_); #endif - llvm::legacy::PassManager pass; - ICHECK(tm_); + llvm::legacy::PassManager pass; + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); #if TVM_LLVM_VERSION <= 60 - ICHECK(tm_->addPassesToEmitFile(pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 90 - ICHECK(tm_->addPassesToEmitFile(pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == - 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 
0) + << "Cannot emit target CGFT_AssemblyFile"; #else - ICHECK(tm_->addPassesToEmitFile(pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; #endif - pass.run(*m); - return rso.str().str(); - } else if (fmt == "" || fmt == "ll") { - std::string type_str; - llvm::raw_string_ostream rso(type_str); - ICHECK(module_ != nullptr); - module_->print(rso, nullptr); - return rso.str(); - } else { - LOG(FATAL) << "Do not know how to get source code with format: " << format << "\'"; - } - return ""; + pass.run(*m); + return rso.str().str(); + } else if (fmt == "" || fmt == "ll") { + std::string type_str; + llvm::raw_string_ostream rso(type_str); + ICHECK(module_ != nullptr); + module_->print(rso, nullptr); + return rso.str(); + } else { + LOG(FATAL) << "Do not know how to get source code with format: " << format << "\'"; } + return ""; +} - void Init(const IRModule& mod, const Target& target) { - InitializeLLVM(); - tm_ = GetLLVMTargetMachine(target); - ctx_ = std::make_shared(); - std::unique_ptr cg = CodeGenLLVM::Create(tm_.get()); - - std::vector funcs; - std::string entry_func; - relay::Runtime runtime = - mod->GetAttr(tvm::attr::kRuntime).value_or(relay::Runtime::Create("cpp")); - bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); - bool target_c_runtime = runtime->name == "crt"; - - for (auto kv : mod->functions) { - if (!kv.second->IsInstance()) { - // (@jroesch): we relax constraints here, Relay functions will just be ignored. 
- DLOG(INFO) << "Can only lower IR Module with PrimFuncs, but got " - << kv.second->GetTypeKey(); - continue; - } - auto f = Downcast(kv.second); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()); - function_names_.push_back(global_symbol.value()); - if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - entry_func = global_symbol.value(); - } - funcs.push_back(f); - } - // TODO(@jroesch): follow up on this condition. - // ICHECK(funcs.size() > 0); - // TODO(tqchen): remove the entry function behavior as it does not - // makes sense when we start to use multiple modules. - cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime); - - // See https://llvm.org/docs/LangRef.html#fast-math-flags for details - Bool fast_math_all = target->GetAttr("fast-math").value_or(Bool(false)); - Bool fast_math_nnan = target->GetAttr("fast-math-nnan").value_or(Bool(false)); - Bool fast_math_ninf = target->GetAttr("fast-math-ninf").value_or(Bool(false)); - Bool fast_math_nsz = target->GetAttr("fast-math-nsz").value_or(Bool(false)); - Bool fast_math_arcp = target->GetAttr("fast-math-arcp").value_or(Bool(false)); - - llvm::FastMathFlags fmf; - if (fast_math_all) { -#if TVM_LLVM_VERSION >= 60 - fmf.setFast(); -#else - fmf.setUnsafeAlgebra(); -#endif - } +void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { + llvm_scope_ = std::make_unique(); + With llvm_target(*llvm_scope_, target); + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); + std::unique_ptr cg = CodeGenLLVM::Create(llvm_target.get()); - if (fast_math_nnan) { - fmf.setNoNaNs(); - } - if (fast_math_ninf) { - fmf.setNoInfs(); - } - if (fast_math_nsz) { - fmf.setNoSignedZeros(); - } - if (fast_math_arcp) { - fmf.setAllowReciprocal(); - } + std::vector funcs; + std::string entry_func; + relay::Runtime runtime = + mod->GetAttr(tvm::attr::kRuntime).value_or(relay::Runtime::Create("cpp")); + bool system_lib = 
runtime->GetAttr("system-lib").value_or(Bool(false)); + bool target_c_runtime = runtime->name == "crt"; -#if TVM_LLVM_VERSION >= 60 - Bool fast_math_contract = target->GetAttr("fast-math-contract").value_or(Bool(false)); - Bool fast_math_afn = target->GetAttr("fast-math-afn").value_or(Bool(false)); - Bool fast_math_reassoc = target->GetAttr("fast-math-reassoc").value_or(Bool(false)); - if (fast_math_contract) { - fmf.setAllowContract(true); - } - if (fast_math_afn) { - fmf.setApproxFunc(); + for (auto kv : mod->functions) { + if (!kv.second->IsInstance()) { + // (@jroesch): we relax constraints here, Relay functions will just be ignored. + DLOG(INFO) << "Can only lower IR Module with PrimFuncs, but got " << kv.second->GetTypeKey(); + continue; } - if (fast_math_reassoc) { - fmf.setAllowReassoc(); + auto f = Downcast(kv.second); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()); + function_names_.push_back(global_symbol.value()); + if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + entry_func = global_symbol.value(); } -#endif + funcs.push_back(f); + } + // TODO(@jroesch): follow up on this condition. + // ICHECK(funcs.size() > 0); + // TODO(tqchen): remove the entry function behavior as it does not + // makes sense when we start to use multiple modules. 
+ cg->Init("TVMMod", llvm_target.get(), system_lib, system_lib, target_c_runtime); + cg->SetFastMathFlag(llvm_target->GetFastMathFlags()); + + cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); + if (entry_func.length() != 0) { + cg->AddMainFunction(entry_func); + } - cg->SetFastMathFlag(fmf); + module_owning_ptr_ = cg->Finish(); + module_ = module_owning_ptr_.get(); + llvm_target->SetTargetMetadata(module_); + module_->addModuleFlag(llvm::Module::Override, "Debug Info Version", + llvm::DEBUG_METADATA_VERSION); - cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); - if (entry_func.length() != 0) { - cg->AddMainFunction(entry_func); - } + if (tm->getTargetTriple().isOSDarwin()) { + module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); + } - module_owning_ptr_ = cg->Finish(); - module_ = module_owning_ptr_.get(); + std::string verify_errors_storage; + llvm::raw_string_ostream verify_errors(verify_errors_storage); + LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors)) + << "LLVM module verification failed with the following errors: \n" + << verify_errors.str(); +} - module_->addModuleFlag(llvm::Module::Warning, "tvm_target", - llvm::MDString::get(*ctx_, LLVMTargetToString(target))); - module_->addModuleFlag(llvm::Module::Override, "Debug Info Version", - llvm::DEBUG_METADATA_VERSION); +void LLVMModuleNode::Init(std::unique_ptr module, + std::unique_ptr llvm_scope) { + module_owning_ptr_ = std::move(module); + module_ = module_owning_ptr_.get(); + llvm_scope_ = std::move(llvm_scope); +} - if (tm_->getTargetTriple().isOSDarwin()) { - module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); - } +void LLVMModuleNode::LoadIR(const std::string& file_name) { + auto llvm_scope = std::make_unique(); + std::unique_ptr module = llvm_scope->LoadIR(file_name); + Init(std::move(module), std::move(llvm_scope)); +} - std::string verify_errors_storage; - llvm::raw_string_ostream verify_errors(verify_errors_storage); - LOG_IF(FATAL, 
llvm::verifyModule(*module_, &verify_errors)) - << "LLVM module verification failed with the following errors: \n" - << verify_errors.str(); - target_ = target; - } +bool LLVMModuleNode::ImplementsFunction(const String& name, bool query_imports) { + return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end(); +} - void Init(std::unique_ptr module, std::shared_ptr ctx) { - InitializeLLVM(); - ctx_ = ctx; - llvm::SMDiagnostic err; - module_owning_ptr_ = std::move(module); - module_ = module_owning_ptr_.get(); - if (module_ == nullptr) { - std::string msg = std::string(err.getMessage()); - LOG(FATAL) << "Fail to load module: " << msg; - } - std::string target_metadata; - llvm::Metadata* tvm_target = module_->getModuleFlag("tvm_target"); - if (tvm_target != nullptr) { - llvm::MDString* pstr = llvm::dyn_cast(tvm_target); - ICHECK(pstr != nullptr); - target_metadata = pstr->getString().str(); - if (!(target_metadata.length() >= 4 && target_metadata.substr(0, 4) == "llvm")) { - target_metadata = "llvm " + target_metadata; - } - } else { - std::ostringstream os; - os << "llvm -mtriple " << module_->getTargetTriple(); - target_metadata = os.str(); - } - target_ = Target(target_metadata); - tm_ = GetLLVMTargetMachine(target_); +void LLVMModuleNode::LazyInitJIT() { + std::lock_guard lock(mutex_); + if (ee_) { + return; } - - void LoadIR(const std::string& file_name) { - auto ctx = std::make_shared(); - llvm::SMDiagnostic err; - auto module = llvm::parseIRFile(file_name, err, *ctx); - if (module == nullptr) { - std::string msg = std::string(err.getMessage()); - LOG(FATAL) << "Fail to load ir file " << file_name << "\n" - << "line " << err.getLineNo() << ":" << msg; - } - Init(std::move(module), ctx); + With llvm_target(*llvm_scope_, LLVMTarget::GetTargetMetadata(*module_)); + llvm::EngineBuilder builder(std::move(module_owning_ptr_)); + builder.setEngineKind(llvm::EngineKind::JIT); + builder.setOptLevel(llvm::CodeGenOpt::Aggressive); + 
builder.setMCPU(llvm_target->GetCPU()); + builder.setMAttrs(llvm_target->GetTargetFeatures()); + builder.setTargetOptions(llvm_target->GetTargetOptions()); + auto tm = std::unique_ptr(builder.selectTarget()); + if (!IsCompatibleWithHost(tm.get())) { + LOG(FATAL) << "Cannot run module, architecture mismatch"; } - - bool IsDSOExportable() const final { return true; } - - bool ImplementsFunction(const String& name, bool query_imports) final { - return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end(); + llvm::DataLayout layout(tm->createDataLayout()); + ICHECK(layout == module_->getDataLayout()) + << "Data layout mismatch between module(" + << module_->getDataLayout().getStringRepresentation() << ")" + << " and ExecutionEngine (" << layout.getStringRepresentation() << ")"; + ee_ = builder.create(tm.release()); + ICHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << module_->getTargetTriple(); + ee_->runStaticConstructorsDestructors(false); + + if (void** ctx_addr = + reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_ctx, *llvm_target))) { + *ctx_addr = this; } + runtime::InitContextFunctions( + [this, &llvm_target](const char* name) { return GetGlobalAddr(name, *llvm_target); }); +} - private: - void LazyInitJIT() { - std::lock_guard lock(mutex_); - if (ee_) { - return; - } - if (!target_.defined()) { - target_ = Target("llvm"); - } - llvm::EngineBuilder builder(std::move(module_owning_ptr_)); - std::string triple, mcpu, mattr; - llvm::TargetOptions opt; - ParseLLVMTargetOptions(target_, &triple, &mcpu, &mattr, &opt); - builder.setEngineKind(llvm::EngineKind::JIT); - builder.setOptLevel(llvm::CodeGenOpt::Aggressive); - if (mcpu.length() != 0) { - builder.setMCPU(mcpu); - } - if (mattr.length() != 0) { - std::vector mattrs{mattr}; - builder.setMAttrs(mattrs); - } - builder.setTargetOptions(opt); - auto tm = std::unique_ptr(builder.selectTarget()); - std::unique_ptr tm_sys = 
GetLLVMTargetMachine(Target("llvm")); - if (tm_sys->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) { - LOG(FATAL) << "Cannot run module, architecture mismatch " - << " module=" << tm->getTargetTriple().str() - << " system=" << tm_sys->getTargetTriple().str(); - } - llvm::DataLayout layout(tm->createDataLayout()); - ICHECK(layout == module_->getDataLayout()) - << "Data layout mismatch between module(" - << module_->getDataLayout().getStringRepresentation() << ")" - << " and ExecutionEngine (" << layout.getStringRepresentation() << ")"; - ee_ = builder.create(tm.release()); - ICHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << module_->getTargetTriple(); - ee_->runStaticConstructorsDestructors(false); - - if (void** ctx_addr = - reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_ctx))) { - *ctx_addr = this; - } - runtime::InitContextFunctions( - [this](const char* name) { return reinterpret_cast(GetGlobalAddr(name)); }); +bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const { + With host_target(*llvm_scope_, "llvm"); // FIXME(kparzysz-quic): nesting + auto tm_host = host_target->GetOrCreateTargetMachine(); + if (tm_host->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) { + LOG(INFO) << "Architecture mismatch: module=" << tm->getTargetTriple().str() + << " host=" << tm_host->getTargetTriple().str(); + return false; } + return true; +} - // Get global address from execution engine. - uint64_t GetGlobalAddr(const std::string& name) const { - // first verifies if GV exists. - if (module_->getGlobalVariable(name) != nullptr) { - return ee_->getGlobalValueAddress(name); - } else { - return 0; - } +// Get global address from execution engine. +void* LLVMModuleNode::GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) const { + // first verifies if GV exists. 
+ if (module_->getGlobalVariable(name) != nullptr) { + return reinterpret_cast(ee_->getGlobalValueAddress(name)); + } else { + return nullptr; } +} - uint64_t GetFunctionAddr(const std::string& name) const { - // first verifies if GV exists. - if (module_->getFunction(name) != nullptr) { - return ee_->getFunctionAddress(name); - } else { - return 0; - } +void* LLVMModuleNode::GetFunctionAddr(const std::string& name, + const LLVMTarget& llvm_target) const { + // first verifies if GV exists. + if (module_->getFunction(name) != nullptr) { + return reinterpret_cast(ee_->getFunctionAddress(name)); + } else { + return nullptr; } - - // The target configuration string - Target target_; - // JIT lock - std::mutex mutex_; - // execution engine - llvm::ExecutionEngine* ee_{nullptr}; - // The target machine - std::unique_ptr tm_{nullptr}; - // The raw pointer to the module. - llvm::Module* module_{nullptr}; - // The unique_ptr owning the module. This becomes empty once JIT has been initialized - // (EngineBuilder takes ownership of the module). - std::unique_ptr module_owning_ptr_; - // the context. 
- std::shared_ptr ctx_; - /* \brief names of the functions declared in this module */ - Array function_names_; -}; +} TVM_REGISTER_GLOBAL("target.build.llvm") .set_body_typed([](IRModule mod, Target target) -> runtime::Module { @@ -481,18 +423,15 @@ TVM_REGISTER_GLOBAL("target.build.llvm") TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") .set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module { - Target target = Target(target_str); + auto llvm_scope = std::make_unique(); + With llvm_target(*llvm_scope, target_str); auto n = make_object(); // Generate a LLVM module from an input target string - InitializeLLVM(); - auto tm = GetLLVMTargetMachine(target); - auto ctx = std::make_shared(); - std::unique_ptr module(new llvm::Module(module_name, *ctx)); - // Use a default data layout and target triple - auto triple = tm->getTargetTriple(); - module->setTargetTriple(triple.str()); - module->setDataLayout(tm->createDataLayout()); - n->Init(std::move(module), ctx); + auto module = std::make_unique(module_name, *llvm_target->GetContext()); + llvm_target->SetTargetMetadata(module.get()); + module->setTargetTriple(llvm_target->GetTargetTriple()); + module->setDataLayout(llvm_target->GetOrCreateTargetMachine()->createDataLayout()); + n->Init(std::move(module), std::move(llvm_scope)); return runtime::Module(n); }); @@ -529,38 +468,39 @@ TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll") TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled") .set_body_typed([](std::string target_str) -> bool { - InitializeLLVM(); - Target target = Target(target_str); - return (GetLLVMTargetMachine(target, true) != nullptr); + LLVMScope llvm_scope; + auto* tm = With(llvm_scope, target_str) + ->GetOrCreateTargetMachine(/*allow_missing=*/true); + return tm != nullptr; }); TVM_REGISTER_GLOBAL("codegen.codegen_blob") .set_body_typed([](std::string data, bool system_lib, std::string llvm_target_string) -> runtime::Module { auto n = make_object(); - auto p = CodeGenBlob(data, 
system_lib, llvm_target_string); - n->Init(std::move(p.first), p.second); + auto llvm_scope = std::make_unique(); + With llvm_target(*llvm_scope, llvm_target_string); + std::unique_ptr blob = CodeGenBlob(data, system_lib, llvm_target.get()); + n->Init(std::move(blob), std::move(llvm_scope)); return runtime::Module(n); }); runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target, tvm::relay::Runtime runtime) { - InitializeLLVM(); - auto tm = GetLLVMTargetMachine(target); + auto llvm_scope = std::make_unique(); + With llvm_target(*llvm_scope, target); bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); - auto ctx = std::make_shared(); std::unique_ptr cg{new CodeGenCPU()}; - cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, - false /* target_c_runtime */); + cg->Init("TVMMetadataMod", llvm_target.get(), system_lib, system_lib, + /*target_c_runtime=*/false); cg->DefineMetadata(metadata); auto mod = cg->Finish(); - mod->addModuleFlag(llvm::Module::Warning, "tvm_target", - llvm::MDString::get(*ctx, LLVMTargetToString(target))); + llvm_target->SetTargetMetadata(mod.get()); mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); - if (tm->getTargetTriple().isOSDarwin()) { + if (llvm_target->GetOrCreateTargetMachine()->getTargetTriple().isOSDarwin()) { mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); } @@ -571,7 +511,7 @@ runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata << verify_errors.str(); auto n = make_object(); - n->Init(std::move(mod), ctx); + n->Init(std::move(mod), std::move(llvm_scope)); auto meta_mod = MetadataModuleCreate(metadata); meta_mod->Import(runtime::Module(n)); @@ -591,24 +531,22 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array& module } } - InitializeLLVM(); - auto tm = GetLLVMTargetMachine(target); + auto llvm_scope = std::make_unique(); + With llvm_target(*llvm_scope, 
target); bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); bool target_c_runtime = runtime->name == "crt"; ICHECK(system_lib && target_c_runtime) << "For LLVM C-runtime metadata module, must include --system-lib and --runtime=c; " << "got target: " << target->str(); - auto ctx = std::make_shared(); std::unique_ptr cg{new CodeGenCPU()}; - cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, target_c_runtime); + cg->Init("TVMMetadataMod", llvm_target.operator->(), system_lib, system_lib, target_c_runtime); cg->DefineFunctionRegistry(func_names); auto mod = cg->Finish(); - mod->addModuleFlag(llvm::Module::Warning, "tvm_target", - llvm::MDString::get(*ctx, LLVMTargetToString(target))); + llvm_target->SetTargetMetadata(mod.get()); mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); - if (tm->getTargetTriple().isOSDarwin()) { + if (llvm_target->GetOrCreateTargetMachine()->getTargetTriple().isOSDarwin()) { mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); } @@ -619,7 +557,7 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array& module << verify_errors.str(); auto n = make_object(); - n->Init(std::move(mod), ctx); + n->Init(std::move(mod), std::move(llvm_scope)); for (auto m : modules) { n->Import(m); } diff --git a/src/target/llvm/llvm_module.h b/src/target/llvm/llvm_module.h index 3a50c2c4244f..66492f8152e5 100644 --- a/src/target/llvm/llvm_module.h +++ b/src/target/llvm/llvm_module.h @@ -46,5 +46,4 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array& module } // namespace tvm #endif // TVM_LLVM_VERSION - #endif // TVM_TARGET_LLVM_LLVM_MODULE_H_ diff --git a/src/target/llvm/llvm_scope.cc b/src/target/llvm/llvm_scope.cc new file mode 100644 index 000000000000..c39b55f9b69b --- /dev/null +++ b/src/target/llvm/llvm_scope.cc @@ -0,0 +1,368 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifdef TVM_LLVM_VERSION + +#include "llvm_scope.h" + +#include +#include +#include +#if TVM_LLVM_VERSION >= 150 +#include +#else +#include +#endif +#include +#include +#include +#include +#if TVM_LLVM_VERSION >= 140 +#include +#else +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace codegen { + +namespace { +namespace defaults { +static const std::string cpu = "generic"; // NOLINT(runtime/string) +static const llvm::CodeGenOpt::Level opt_level = llvm::CodeGenOpt::Aggressive; +} // namespace defaults +} // namespace + +namespace { +bool InitializeLLVM() { + static std::atomic_flag initialized = ATOMIC_FLAG_INIT; + if (!initialized.test_and_set()) { + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); + } + return true; +} + +std::string Join(std::string sep, llvm::ArrayRef strings) { + std::string result; + bool is_first = true; + for (const std::string& s : strings) { + if (!is_first) { + result += sep; + } + result += s; + is_first = false; + } + return 
result; +} + +} // namespace + +// LLVMScope + +LLVMScope::LLVMScope() { + // Call InitializeLLVM before anything else. + static const bool DMLC_ATTRIBUTE_UNUSED init_llvm = InitializeLLVM(); + ctx_ = std::make_shared(); +} + +LLVMScope::~LLVMScope() = default; + +std::unique_ptr LLVMScope::ParseIR(const std::string& llvm_ir) const { + auto buffer = llvm::MemoryBuffer::getMemBuffer(llvm_ir, /*BufferName=*/"", + /*RequiresNullTerminator=*/false); + return ParseBuffer(*buffer); +} + +std::unique_ptr LLVMScope::LoadIR(const std::string& file_name) const { + llvm::ErrorOr> maybe_buffer = + llvm::MemoryBuffer::getFileAsStream(file_name); + if (std::error_code ec = maybe_buffer.getError()) { + LOG(FATAL) << ec.message(); + } + return ParseBuffer(**maybe_buffer); +} + +std::unique_ptr LLVMScope::ParseBuffer(const llvm::MemoryBuffer& buffer) const { + llvm::SMDiagnostic error; + std::unique_ptr module = llvm::parseIR(buffer.getMemBufferRef(), error, *ctx_); + if (module == nullptr) { + std::string message; + llvm::raw_string_ostream ostream(message); + error.print(/*ProgName=*/nullptr, ostream, /*ShowColors=*/false, /*ShowKindLabel=*/true); + LOG(FATAL) << ostream.str(); + } + + return module; +} + +// LLVMTarget + +LLVMTarget::LLVMTarget(LLVMScope& scope, const Target& target) + : scope_(scope), ctx_(scope.GetContext()) { + triple_ = target->GetAttr("mtriple").value_or("default"); + + if (triple_.empty() || triple_ == "default") { + triple_ = llvm::sys::getDefaultTargetTriple(); + } + cpu_ = target->GetAttr("mcpu").value_or(defaults::cpu); + + if (const Optional>& v = target->GetAttr>("mattr")) { + for (const String& s : v.value()) { + attrs_.push_back(s); + } + } + + llvm::FloatABI::ABIType float_abi = llvm::FloatABI::Default; + if (const Optional& v = target->GetAttr("mfloat-abi")) { + String value = v.value(); + if (value == "hard") { + float_abi = llvm::FloatABI::Hard; + } else if (value == "soft") { + float_abi = llvm::FloatABI::Soft; + } else { + LOG(FATAL) << 
"invalid -mfloat-abi option " << value; + } + } + + // Target options + +#if TVM_LLVM_VERSION < 50 + target_options_.LessPreciseFPMADOption = true; +#endif + // In clang, these are fed from LangOpts which describe language specific features + // TODO(AndrewZhaoLuo): figure out how these relate to fast math flags + target_options_.AllowFPOpFusion = llvm::FPOpFusion::Fast; + target_options_.UnsafeFPMath = false; + target_options_.NoInfsFPMath = false; + target_options_.NoNaNsFPMath = true; + target_options_.FloatABIType = float_abi; + if (const Optional& v = target->GetAttr("mabi")) { + target_options_.MCOptions.ABIName = v.value(); + } + + auto maybe_level = target->GetAttr("opt-level"); + + if (maybe_level.defined()) { + int level = maybe_level.value()->value; + if (level <= 0) { + opt_level_ = llvm::CodeGenOpt::None; + } else if (level == 1) { + opt_level_ = llvm::CodeGenOpt::Less; + } else if (level == 2) { + opt_level_ = llvm::CodeGenOpt::Default; + } else { + // level >= 3 + opt_level_ = llvm::CodeGenOpt::Aggressive; + } + } else { + opt_level_ = defaults::opt_level; + } + + // Fast math options + + auto GetBoolFlag = [&target](llvm::StringRef flag) -> bool { + return target->GetAttr(flag.str()).value_or(Bool(false)); + }; + if (GetBoolFlag("fast-math")) { +#if TVM_LLVM_VERSION >= 60 + fast_math_flags_.setFast(); +#else + fast_math_flags_.setUnsafeAlgebra(); +#endif + } else { +#if TVM_LLVM_VERSION >= 50 + // This option was added in 5.x, and has a boolean argument, + // unlike the rest of options at the time. 
+ fast_math_flags_.setAllowContract(GetBoolFlag("fast-math-contract")); +#endif +#if TVM_LLVM_VERSION >= 70 + fast_math_flags_.setNoNaNs(GetBoolFlag("fast-math-nnan")); + fast_math_flags_.setNoInfs(GetBoolFlag("fast-math-ninf")); + fast_math_flags_.setNoSignedZeros(GetBoolFlag("fast-math-nsz")); + fast_math_flags_.setAllowReciprocal(GetBoolFlag("fast-math-arcp")); + fast_math_flags_.setAllowContract(GetBoolFlag("fast-math-contract")); + fast_math_flags_.setAllowReassoc(GetBoolFlag("fast-math-reassoc")); + fast_math_flags_.setApproxFunc(GetBoolFlag("fast-math-afn")); +#else + // LLVM 4.x, 5.x, and 6.x + if (GetBoolFlag("fast-math-nnan")) fast_math_flags_.setNoNaNs(); + if (GetBoolFlag("fast-math-ninf")) fast_math_flags_.setNoInfs(); + if (GetBoolFlag("fast-math-nsz")) fast_math_flags_.setNoSignedZeros(); + if (GetBoolFlag("fast-math-arcp")) fast_math_flags_.setAllowReciprocal(); +#if TVM_LLVM_VERSION >= 60 + if (GetBoolFlag("fast-math-reassoc")) fast_math_flags_.setAllowReassoc(); + if (GetBoolFlag("fast-math-afn")) fast_math_flags_.setApproxFunc(); +#endif +#endif + } +} + +LLVMTarget::LLVMTarget(LLVMScope& scope, const std::string& target_str) + : LLVMTarget(scope, Target(target_str)) {} + +LLVMTarget::~LLVMTarget() = default; + +llvm::LLVMContext* LLVMTarget::GetContext() const { + ICHECK(!ctx_.expired()) << "LLVM scope has been deleted"; + return ctx_.lock().get(); +} + +llvm::TargetMachine* LLVMTarget::GetOrCreateTargetMachine(bool allow_missing) { + if (target_machine_) return target_machine_.get(); + + std::string error; + if (const llvm::Target* llvm_scope = llvm::TargetRegistry::lookupTarget(triple_, error)) { + llvm::TargetMachine* tm = + llvm_scope->createTargetMachine(triple_, cpu_, GetTargetFeatureString(), target_options_, + reloc_model_, code_model_, opt_level_); + target_machine_ = std::unique_ptr(tm); + if (!allow_missing) { + ICHECK(target_machine_ != nullptr) << error; + } + } + return target_machine_.get(); +} + +std::string 
LLVMTarget::GetTargetFeatureString() const { // + return Join(",", attrs_); +} + +std::string LLVMTarget::str() const { + std::ostringstream os; + os << "llvm"; + if (!triple_.empty()) { + os << " -mtriple=" << triple_; + } + if (!cpu_.empty() && cpu_ != defaults::cpu) { + os << " -mcpu=" << cpu_; + } + if (!attrs_.empty()) { + os << " -mattr=" << GetTargetFeatureString(); + } + + switch (target_options_.FloatABIType) { + case llvm::FloatABI::Soft: + os << " -mfloat-abi=soft"; + break; + case llvm::FloatABI::Hard: + os << " -mfloat-abi=hard"; + break; + case llvm::FloatABI::Default: + break; + } + if (!target_options_.MCOptions.ABIName.empty()) { + os << " -mabi=" << target_options_.MCOptions.ABIName; + } + + bool do_individual = true; +#if TVM_LLVM_VERSION >= 60 + if (fast_math_flags_.isFast()) { + os << " -fast-math"; + do_individual = false; + } +#else + if (fast_math_flags_.unsafeAlgebra()) { + os << " -fast-math"; + do_individual = false; + } +#endif + + if (do_individual) { + if (fast_math_flags_.noNaNs()) os << " -fast-math-nnan"; + if (fast_math_flags_.noInfs()) os << " -fast-math-ninf"; + if (fast_math_flags_.noSignedZeros()) os << " -fast-math-nsz"; + if (fast_math_flags_.allowReciprocal()) os << " -fast-math-arcp"; +#if TVM_LLVM_VERSION >= 50 + if (fast_math_flags_.allowContract()) os << " -fast-math-contract"; +#endif +#if TVM_LLVM_VERSION >= 60 + if (fast_math_flags_.allowReassoc()) os << " -fast-math-reassoc"; + if (fast_math_flags_.approxFunc()) os << " -fast-math-afn"; +#endif + } + + if (opt_level_ != defaults::opt_level) { + os << " -opt-level="; + switch (opt_level_) { + case llvm::CodeGenOpt::None: + os << "0"; + break; + case llvm::CodeGenOpt::Less: + os << "1"; + break; + case llvm::CodeGenOpt::Default: + os << "2"; + break; + case llvm::CodeGenOpt::Aggressive: + os << "3"; + break; + } + } + + return os.str(); +} + +std::string LLVMTarget::GetTargetMetadata(const llvm::Module& module) { + if (llvm::Metadata* tvm_target = 
module.getModuleFlag("tvm_target")) { + auto* mdstr = llvm::cast(tvm_target); + llvm::StringRef meta = mdstr->getString(); + if (meta.startswith("llvm")) { + return meta.str(); + } + } + return "llvm -mtriple " + module.getTargetTriple(); +} + +void LLVMTarget::SetTargetMetadata(llvm::Module* module) const { + module->addModuleFlag(llvm::Module::Warning, "tvm_target", + llvm::MDString::get(*GetContext(), str())); +} + +void LLVMTarget::EnterWithScope() {} +void LLVMTarget::ExitWithScope() {} + +} // namespace codegen +} // namespace tvm + +#endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/llvm_scope.h b/src/target/llvm/llvm_scope.h new file mode 100644 index 000000000000..e9a0aea66fcd --- /dev/null +++ b/src/target/llvm/llvm_scope.h @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef TVM_TARGET_LLVM_LLVM_SCOPE_H_ +#define TVM_TARGET_LLVM_LLVM_SCOPE_H_ + +#ifdef TVM_LLVM_VERSION + +#include +#if TVM_LLVM_VERSION >= 150 +#include +#else +#include +#endif +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace llvm { +class LLVMContext; +class MemoryBuffer; +class Module; +class TargetMachine; +} // namespace llvm + +namespace tvm { +namespace codegen { + +class LLVMTarget; + +class LLVMScope { + public: + LLVMScope(); + ~LLVMScope(); // Must not be "= default" here in the header file. + + std::shared_ptr GetContext() const { return ctx_; } + + std::unique_ptr ParseIR(const std::string& llvm_ir) const; + std::unique_ptr LoadIR(const std::string& file_name) const; + + private: + std::unique_ptr ParseBuffer(const llvm::MemoryBuffer& buffer) const; + + std::shared_ptr ctx_; +}; + +class LLVMTarget { + public: + LLVMTarget(LLVMScope& scope, const Target& target); // NOLINT(runtime/references) + LLVMTarget(LLVMScope& scope, const std::string& target_str); // NOLINT(runtime/references) + ~LLVMTarget(); + + std::string str() const; + + const LLVMScope& GetScope() const { return scope_; } + llvm::LLVMContext* GetContext() const; + llvm::TargetMachine* GetOrCreateTargetMachine(bool allow_missing = false); + + const std::string& GetTargetTriple() const { return triple_; } + const std::string& GetCPU() const { return cpu_; } + llvm::ArrayRef GetTargetFeatures() const { return attrs_; } + std::string GetTargetFeatureString() const; + const llvm::TargetOptions& GetTargetOptions() const { return target_options_; } + llvm::FastMathFlags GetFastMathFlags() const { return fast_math_flags_; } + llvm::CodeGenOpt::Level GetOptLevel() const { return opt_level_; } + + static std::string GetTargetMetadata(const llvm::Module& module); + void SetTargetMetadata(llvm::Module* module) const; + + void EnterWithScope(); + void ExitWithScope(); + + private: + const LLVMScope& scope_; + 
std::weak_ptr ctx_; + + std::string triple_; + std::string cpu_; + std::vector attrs_; + llvm::TargetOptions target_options_; + llvm::FastMathFlags fast_math_flags_; + llvm::CodeGenOpt::Level opt_level_; + llvm::Reloc::Model reloc_model_ = llvm::Reloc::PIC_; + llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small; + std::shared_ptr target_machine_; +}; + +} // namespace codegen +} // namespace tvm + +#endif // TVM_LLVM_VERSION +#endif // TVM_TARGET_LLVM_LLVM_SCOPE_H_ From 6bd9e465b9a4d8baa981c0555196493671c95fa3 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Mon, 25 Jul 2022 06:49:39 -0700 Subject: [PATCH 2/6] Rename CodeGenLLVM::SetFastMathFlag to SetFastMathFlags, NFC --- src/target/llvm/codegen_llvm.cc | 2 +- src/target/llvm/codegen_llvm.h | 2 +- src/target/llvm/llvm_module.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 695eaa30493a..a2f057c7e457 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -145,7 +145,7 @@ void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target, InitTarget(); } -void CodeGenLLVM::SetFastMathFlag(llvm::FastMathFlags fmf) { builder_->setFastMathFlags(fmf); } +void CodeGenLLVM::SetFastMathFlags(llvm::FastMathFlags fmf) { builder_->setFastMathFlags(fmf); } void CodeGenLLVM::InitTarget() { llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine(); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index d91f9c5fab33..e6321be647aa 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -129,7 +129,7 @@ class CodeGenLLVM : public ExprFunctor, * \brief Turn on fast math flags for floating point operations. * \param fmf FastMathFlags to use for code generation. */ - void SetFastMathFlag(llvm::FastMathFlags fmf); + void SetFastMathFlags(llvm::FastMathFlags fmf); /*! 
* \brief Compile and add function f to the current module. diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 919fca879582..82355a556ba2 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -309,7 +309,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { // TODO(tqchen): remove the entry function behavior as it does not // makes sense when we start to use multiple modules. cg->Init("TVMMod", llvm_target.get(), system_lib, system_lib, target_c_runtime); - cg->SetFastMathFlag(llvm_target->GetFastMathFlags()); + cg->SetFastMathFlags(llvm_target->GetFastMathFlags()); cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); if (entry_func.length() != 0) { From 257a188bfe75643eb89cfce5904ed7276e766690 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Mon, 25 Jul 2022 09:10:22 -0700 Subject: [PATCH 3/6] Add doxygen documentation Move empty implementations of EnterWithScope/ExitWithScope to header since it helps see that these functions are only stubs. --- src/target/llvm/llvm_scope.cc | 3 - src/target/llvm/llvm_scope.h | 155 +++++++++++++++++++++++++++++++++- 2 files changed, 152 insertions(+), 6 deletions(-) diff --git a/src/target/llvm/llvm_scope.cc b/src/target/llvm/llvm_scope.cc index c39b55f9b69b..72042a3c1354 100644 --- a/src/target/llvm/llvm_scope.cc +++ b/src/target/llvm/llvm_scope.cc @@ -359,9 +359,6 @@ void LLVMTarget::SetTargetMetadata(llvm::Module* module) const { llvm::MDString::get(*GetContext(), str())); } -void LLVMTarget::EnterWithScope() {} -void LLVMTarget::ExitWithScope() {} - } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/llvm_scope.h b/src/target/llvm/llvm_scope.h index e9a0aea66fcd..c0620004e1bc 100644 --- a/src/target/llvm/llvm_scope.h +++ b/src/target/llvm/llvm_scope.h @@ -17,6 +17,8 @@ * under the License. */ +/*! 
\file llvm_scope.h + */ #ifndef TVM_TARGET_LLVM_LLVM_SCOPE_H_ #define TVM_TARGET_LLVM_LLVM_SCOPE_H_ @@ -53,14 +55,70 @@ namespace codegen { class LLVMTarget; +/*! + * \class LLVMScope + * \brief LLVMScope is a class that (conceptually) starts and stops LLVM. All + * uses of LLVM should take place within a lifetime of an object of this class. + * + * E.g. + * ```{.cpp} + * { + * LLVMScope llvm_scope; + * ... + * someFunctionFromLLVM(...); + * ... + * } + * // no more calls to LLVM here + * ``` + * In addition to that, LLVMScope provides an LLVM context (llvm::LLVMContext). + * The context is a structure in LLVM where common IR constructs are maintained, + * (such as types, constants, etc.) so that they can be identified by their + * address (i.e. pointer comparison). Because of that, it's important to use + * the same context throughout compilation. + * + * At the moment the "starting" of LLVM performs initialization of LLVM, but + * "stopping" doesn't do anything. In the future, if such a need arises, this + * functionality may be extended to perform dlopen/dlclose of the LLVM-based + * code in TVM. + * + * This class provides means to deserialize an LLVM module, either from text + * (in a string), or from a file. In either case, the serialized module can + * be LLVM IR assembly, or binary bitcode enconding. + */ class LLVMScope { public: + /*! + * \brief Constructs LLVMScope + */ LLVMScope(); + /*! + * \brief Destroys LLVMScope object + */ ~LLVMScope(); // Must not be "= default" here in the header file. + /*! + * \brief Get the LLVM context for this scope. + */ std::shared_ptr GetContext() const { return ctx_; } + /*! + * \brief Create `llvm::Module` from a string. + * + * Parse the string in \param llvm_ir, and return the `llvm::Module`. + * At the moment this function will abort if the parsing fails. 
+ * \param llvm_ir string with the LLVM IR assembly or bitcode + * \return created `llvm::Module` + */ std::unique_ptr ParseIR(const std::string& llvm_ir) const; + /*! + * \brief Load `llvm::Module` from a given file + * + * Read the file \param file_name, and return the `llvm::Module`. + * At the moment this function will abort if reading of the file or creation + * of the module fails. + * \param file_name file with the LLVM IR assembly or bitcode + * \return created `llvm::Module` + */ std::unique_ptr LoadIR(const std::string& file_name) const; private: @@ -69,31 +127,122 @@ class LLVMScope { std::shared_ptr ctx_; }; +/*! + * \class LLVMTarget + * \brief Information used by LLVM for code generation for particular target + * + * This class contains all information that LLVM needs for code generation for + * a particular target. Since Target in TVM will soon contain command line + * flags for LLVM, objects of this class will handle saving and restoring + * global LLVM state that may be affected by these flags. This way, code + * generation for each LLVM-based target in TVM will start with the same LLVM + * global state. + * + * Note that objects of this class must be created within the lifetime of an + * LLVMScope object. + */ class LLVMTarget { public: - LLVMTarget(LLVMScope& scope, const Target& target); // NOLINT(runtime/references) + /*! + * \brief Constructs LLVMTarget from `Target` + * \param scope LLVMScope object + * \param target TVM Target object for target "llvm" + */ + LLVMTarget(LLVMScope& scope, const Target& target); // NOLINT(runtime/references) + /*! + * \brief Constructs LLVMTarget from target string + * \param scope LLVMScope object + * \param target TVM target string for target "llvm" + */ LLVMTarget(LLVMScope& scope, const std::string& target_str); // NOLINT(runtime/references) + /*! + * \brief Destroys LLVMTarget object + */ ~LLVMTarget(); + /*! 
+ * \brief Returns string representation (as TVM target) of the LLVMTarget + * \return Target string + * + * Note: If the LLVMTarget object was created from a string `s`, the string + * returned here may not be exactly equal to `s`. For example, if the CPU + * was "default", the returned string will have CPU set to the detected host + * CPU. + */ std::string str() const; + /*! + * \brief Get the LLVMScope object from which the LLVMTarget object was + * created + * \return The enclosing LLVMScope object + */ const LLVMScope& GetScope() const { return scope_; } + /*! + * \brief Get the current LLVM context + * \return the current LLVM context + */ llvm::LLVMContext* GetContext() const; + /*! + * \brief Return LLVM's `TargetMachine`, or nullptr + * \param allow_missing do not abort if the target machine cannot be created, + * return nullptr instead + * \return Pointer to the `TargetMachine` object (or nullptr if it cannot be + * created, \see allow_missing) + */ llvm::TargetMachine* GetOrCreateTargetMachine(bool allow_missing = false); + /*! + * \brief Get the target triple + * \return the target triple + */ const std::string& GetTargetTriple() const { return triple_; } + /*! + * \brief Get the CPU name + * \return the CPU name: the detected host CPU if the original TVM target + * specified it as "default" + */ const std::string& GetCPU() const { return cpu_; } + /*! + * \brief Get the list of LLVM target features + * \return array of individual feature strings + */ llvm::ArrayRef GetTargetFeatures() const { return attrs_; } + /*! + * \brief Get the LLVM target feature string + * \return comma-separated list of LLVM target features + */ std::string GetTargetFeatureString() const; + /*! + * \brief Get the LLVM target options + * \return `llvm::TargetOptions` object for this target + */ const llvm::TargetOptions& GetTargetOptions() const { return target_options_; } + /*! 
+ * \brief Get fast math flags + * \return `llvm::FastMathFlags` for this target + */ llvm::FastMathFlags GetFastMathFlags() const { return fast_math_flags_; } + /*! + * \brief Get the LLVM optimization level + * \return optimization level for this target + */ llvm::CodeGenOpt::Level GetOptLevel() const { return opt_level_; } + /*! + * \brief Extract the target string from given `llvm::Module` + * \param module LLVM module with the TVM target string embedded as metadata + * \return the target string from module's metadata + */ static std::string GetTargetMetadata(const llvm::Module& module); + /*! + * \brief Embed target string as metadata in given `llvm::Module` + * \param module the module to insert the target string into + */ void SetTargetMetadata(llvm::Module* module) const; - void EnterWithScope(); - void ExitWithScope(); + // Stubs to enable use with `With`. + void EnterWithScope() {} + void ExitWithScope() {} private: const LLVMScope& scope_; From 6b1f1e28784adc4f1336985e6a077a071d4ae9a0 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Wed, 27 Jul 2022 19:56:51 -0500 Subject: [PATCH 4/6] Change global std::string to const char* --- src/target/llvm/llvm_scope.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/llvm/llvm_scope.cc b/src/target/llvm/llvm_scope.cc index 72042a3c1354..975bb335e952 100644 --- a/src/target/llvm/llvm_scope.cc +++ b/src/target/llvm/llvm_scope.cc @@ -65,7 +65,7 @@ namespace codegen { namespace { namespace defaults { -static const std::string cpu = "generic"; // NOLINT(runtime/string) +static const char* cpu = "generic"; static const llvm::CodeGenOpt::Level opt_level = llvm::CodeGenOpt::Aggressive; } // namespace defaults } // namespace From 5502144392cdf83af5c828922487fa3327536665 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Wed, 3 Aug 2022 12:59:41 -0700 Subject: [PATCH 5/6] LLVMScope -> LLVMInstance --- src/target/llvm/codegen_amdgpu.cc | 8 +-- src/target/llvm/codegen_blob.cc | 2 +- 
src/target/llvm/codegen_cpu.cc | 2 +- src/target/llvm/codegen_hexagon.cc | 6 +- src/target/llvm/codegen_llvm.cc | 6 +- src/target/llvm/codegen_nvptx.cc | 8 +-- src/target/llvm/codegen_x86_64.cc | 2 +- .../llvm/{llvm_scope.cc => llvm_instance.cc} | 26 ++++---- .../llvm/{llvm_scope.h => llvm_instance.h} | 38 ++++++------ src/target/llvm/llvm_module.cc | 60 +++++++++---------- 10 files changed, 79 insertions(+), 79 deletions(-) rename src/target/llvm/{llvm_scope.cc => llvm_instance.cc} (92%) rename src/target/llvm/{llvm_scope.h => llvm_instance.h} (89%) diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index daaf9b3da9ae..c08081405648 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -51,7 +51,7 @@ #include "../../runtime/rocm/rocm_module.h" #include "../build_common.h" #include "codegen_llvm.h" -#include "llvm_scope.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -246,9 +246,9 @@ class CodeGenAMDGPU : public CodeGenLLVM { }; runtime::Module BuildAMDGPU(IRModule mod, Target target) { - LLVMScope llvm_scope; + LLVMInstance llvm_instance; - With llvm_target(llvm_scope, target); + With llvm_target(llvm_instance, target); #if TVM_LLVM_VERSION < 90 LOG(FATAL) << "AMDGPU backend requires at least LLVM 9"; // Lower versions will crash when loading the bitcode, see @@ -269,7 +269,7 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) { Array bitcode_files = (*find_rocm_bitcodes)(); for (auto& bitcode_path : bitcode_files) { - std::unique_ptr mlib = llvm_scope.LoadIR(bitcode_path); + std::unique_ptr mlib = llvm_instance.LoadIR(bitcode_path); mlib->setTargetTriple(llvm_target->GetTargetTriple()); mlib->setDataLayout(tm->createDataLayout()); diff --git a/src/target/llvm/codegen_blob.cc b/src/target/llvm/codegen_blob.cc index b15d3caed653..b67aac480654 100644 --- a/src/target/llvm/codegen_blob.cc +++ b/src/target/llvm/codegen_blob.cc @@ -52,7 +52,7 @@ #include #include -#include 
"llvm_scope.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 4db0df87f916..c4aed1a237dd 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -60,7 +60,7 @@ #include "../func_registry_generator.h" #include "../metadata_utils.h" -#include "llvm_scope.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index eb768409662a..1b9233d2ad2f 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -62,7 +62,7 @@ #include "../../runtime/hexagon/hexagon_module.h" #include "../build_common.h" #include "codegen_cpu.h" -#include "llvm_scope.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -511,8 +511,8 @@ void ProcessLLVMOptions(const std::vector& llvm_vec) { } // namespace runtime::Module BuildHexagon(IRModule mod, Target target) { - LLVMScope llvm_scope; - With llvm_target(llvm_scope, target); + LLVMInstance llvm_instance; + With llvm_target(llvm_instance, target); auto split = [](const std::string& str, char delim = ' ') { std::vector vec; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index a2f057c7e457..305358d079d0 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -89,7 +89,7 @@ #include "../build_common.h" #include "../func_registry_generator.h" #include "codegen_params.h" -#include "llvm_scope.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -272,9 +272,9 @@ void CodeGenLLVM::HandleImport(const std::string& code) { llvm::StringRef code_str(code); std::unique_ptr mlib; if (code_str.endswith(".ll") || code_str.endswith(".bc")) { - mlib = llvm_target_->GetScope().LoadIR(code); + mlib = llvm_target_->GetInstance().LoadIR(code); } else { - mlib = llvm_target_->GetScope().ParseIR(code); + mlib = 
llvm_target_->GetInstance().ParseIR(code); } mlib->setTargetTriple(llvm_target_->GetTargetTriple()); diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 67184c495c80..c758ca383621 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -56,7 +56,7 @@ #include "../../runtime/cuda/cuda_module.h" #include "../build_common.h" #include "codegen_llvm.h" -#include "llvm_scope.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -299,8 +299,8 @@ int GetCUDAComputeVersion(const Target& target) { } runtime::Module BuildNVPTX(IRModule mod, Target target) { - LLVMScope llvm_scope; - With llvm_target(llvm_scope, target); + LLVMInstance llvm_instance; + With llvm_target(llvm_instance, target); int compute_ver = GetCUDAComputeVersion(target); std::unique_ptr cg(new CodeGenNVPTX()); @@ -318,7 +318,7 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) { if (flibdevice_path != nullptr) { std::string path = (*flibdevice_path)(compute_ver); if (path.length() != 0) { - std::unique_ptr mlib = llvm_scope.LoadIR(path); + std::unique_ptr mlib = llvm_instance.LoadIR(path); mlib->setTargetTriple(llvm_target->GetTargetTriple()); mlib->setDataLayout(tm->createDataLayout()); cg->AddLinkModule(std::move(mlib)); diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index 7793317fe792..efe15c5c4aac 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -38,7 +38,7 @@ #include #include "codegen_cpu.h" -#include "llvm_scope.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { diff --git a/src/target/llvm/llvm_scope.cc b/src/target/llvm/llvm_instance.cc similarity index 92% rename from src/target/llvm/llvm_scope.cc rename to src/target/llvm/llvm_instance.cc index 975bb335e952..772e71b28724 100644 --- a/src/target/llvm/llvm_scope.cc +++ b/src/target/llvm/llvm_instance.cc @@ -19,7 +19,7 @@ #ifdef TVM_LLVM_VERSION -#include 
"llvm_scope.h" +#include "llvm_instance.h" #include #include @@ -98,23 +98,23 @@ std::string Join(std::string sep, llvm::ArrayRef strings) { } // namespace -// LLVMScope +// LLVMInstance -LLVMScope::LLVMScope() { +LLVMInstance::LLVMInstance() { // Call InitializeLLVM before anything else. static const bool DMLC_ATTRIBUTE_UNUSED init_llvm = InitializeLLVM(); ctx_ = std::make_shared(); } -LLVMScope::~LLVMScope() = default; +LLVMInstance::~LLVMInstance() = default; -std::unique_ptr LLVMScope::ParseIR(const std::string& llvm_ir) const { +std::unique_ptr LLVMInstance::ParseIR(const std::string& llvm_ir) const { auto buffer = llvm::MemoryBuffer::getMemBuffer(llvm_ir, /*BufferName=*/"", /*RequiresNullTerminator=*/false); return ParseBuffer(*buffer); } -std::unique_ptr LLVMScope::LoadIR(const std::string& file_name) const { +std::unique_ptr LLVMInstance::LoadIR(const std::string& file_name) const { llvm::ErrorOr> maybe_buffer = llvm::MemoryBuffer::getFileAsStream(file_name); if (std::error_code ec = maybe_buffer.getError()) { @@ -123,7 +123,7 @@ std::unique_ptr LLVMScope::LoadIR(const std::string& file_name) co return ParseBuffer(**maybe_buffer); } -std::unique_ptr LLVMScope::ParseBuffer(const llvm::MemoryBuffer& buffer) const { +std::unique_ptr LLVMInstance::ParseBuffer(const llvm::MemoryBuffer& buffer) const { llvm::SMDiagnostic error; std::unique_ptr module = llvm::parseIR(buffer.getMemBufferRef(), error, *ctx_); if (module == nullptr) { @@ -138,8 +138,8 @@ std::unique_ptr LLVMScope::ParseBuffer(const llvm::MemoryBuffer& b // LLVMTarget -LLVMTarget::LLVMTarget(LLVMScope& scope, const Target& target) - : scope_(scope), ctx_(scope.GetContext()) { +LLVMTarget::LLVMTarget(LLVMInstance& instance, const Target& target) + : instance_(instance), ctx_(instance.GetContext()) { triple_ = target->GetAttr("mtriple").value_or("default"); if (triple_.empty() || triple_ == "default") { @@ -238,7 +238,7 @@ LLVMTarget::LLVMTarget(LLVMScope& scope, const Target& target) } } 
-LLVMTarget::LLVMTarget(LLVMScope& scope, const std::string& target_str) +LLVMTarget::LLVMTarget(LLVMInstance& scope, const std::string& target_str) : LLVMTarget(scope, Target(target_str)) {} LLVMTarget::~LLVMTarget() = default; @@ -252,10 +252,10 @@ llvm::TargetMachine* LLVMTarget::GetOrCreateTargetMachine(bool allow_missing) { if (target_machine_) return target_machine_.get(); std::string error; - if (const llvm::Target* llvm_scope = llvm::TargetRegistry::lookupTarget(triple_, error)) { + if (const llvm::Target* llvm_instance = llvm::TargetRegistry::lookupTarget(triple_, error)) { llvm::TargetMachine* tm = - llvm_scope->createTargetMachine(triple_, cpu_, GetTargetFeatureString(), target_options_, - reloc_model_, code_model_, opt_level_); + llvm_instance->createTargetMachine(triple_, cpu_, GetTargetFeatureString(), target_options_, + reloc_model_, code_model_, opt_level_); target_machine_ = std::unique_ptr(tm); if (!allow_missing) { ICHECK(target_machine_ != nullptr) << error; diff --git a/src/target/llvm/llvm_scope.h b/src/target/llvm/llvm_instance.h similarity index 89% rename from src/target/llvm/llvm_scope.h rename to src/target/llvm/llvm_instance.h index c0620004e1bc..37eacd2a4aec 100644 --- a/src/target/llvm/llvm_scope.h +++ b/src/target/llvm/llvm_instance.h @@ -17,7 +17,7 @@ * under the License. */ -/*! \file llvm_scope.h +/*! \file llvm_instance.h */ #ifndef TVM_TARGET_LLVM_LLVM_SCOPE_H_ #define TVM_TARGET_LLVM_LLVM_SCOPE_H_ @@ -56,21 +56,21 @@ namespace codegen { class LLVMTarget; /*! - * \class LLVMScope - * \brief LLVMScope is a class that (conceptually) starts and stops LLVM. All + * \class LLVMInstance + * \brief LLVMInstance is a class that (conceptually) starts and stops LLVM. All * uses of LLVM should take place within a lifetime of an object of this class. * * E.g. * ```{.cpp} * { - * LLVMScope llvm_scope; + * LLVMInstance llvm_instance; * ... * someFunctionFromLLVM(...); * ... 
* } * // no more calls to LLVM here * ``` - * In addition to that, LLVMScope provides an LLVM context (llvm::LLVMContext). + * In addition to that, LLVMInstance provides an LLVM context (llvm::LLVMContext). * The context is a structure in LLVM where common IR constructs are maintained, * (such as types, constants, etc.) so that they can be identified by their * address (i.e. pointer comparison). Because of that, it's important to use @@ -85,16 +85,16 @@ class LLVMTarget; * (in a string), or from a file. In either case, the serialized module can * be LLVM IR assembly, or binary bitcode enconding. */ -class LLVMScope { +class LLVMInstance { public: /*! - * \brief Constructs LLVMScope + * \brief Constructs LLVMInstance */ - LLVMScope(); + LLVMInstance(); /*! - * \brief Destroys LLVMScope object + * \brief Destroys LLVMInstance object */ - ~LLVMScope(); // Must not be "= default" here in the header file. + ~LLVMInstance(); // Must not be "= default" here in the header file. /*! * \brief Get the LLVM context for this scope. @@ -139,22 +139,22 @@ class LLVMScope { * global state. * * Note that objects of this class must be created within the lifetime of an - * LLVMScope object. + * LLVMInstance object. */ class LLVMTarget { public: /*! * \brief Constructs LLVMTarget from `Target` - * \param scope LLVMScope object + * \param scope LLVMInstance object * \param target TVM Target object for target "llvm" */ - LLVMTarget(LLVMScope& scope, const Target& target); // NOLINT(runtime/references) + LLVMTarget(LLVMInstance& scope, const Target& target); // NOLINT(runtime/references) /*! * \brief Constructs LLVMTarget from target string - * \param scope LLVMScope object + * \param scope LLVMInstance object * \param target TVM target string for target "llvm" */ - LLVMTarget(LLVMScope& scope, const std::string& target_str); // NOLINT(runtime/references) + LLVMTarget(LLVMInstance& scope, const std::string& target_str); // NOLINT(runtime/references) /*! 
* \brief Destroys LLVMTarget object */ @@ -172,11 +172,11 @@ class LLVMTarget { std::string str() const; /*! - * \brief Get the LLVMScope object from which the LLVMTarget object was + * \brief Get the LLVMInstance object from which the LLVMTarget object was * created - * \return The enclosing LLVMScope object + * \return The enclosing LLVMInstance object */ - const LLVMScope& GetScope() const { return scope_; } + const LLVMInstance& GetInstance() const { return instance_; } /*! * \brief Get the current LLVM context * \return the current LLVM context @@ -245,7 +245,7 @@ class LLVMTarget { void ExitWithScope() {} private: - const LLVMScope& scope_; + const LLVMInstance& instance_; std::weak_ptr ctx_; std::string triple_; diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 954d31b594c5..9aed66fffc5c 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -76,7 +76,7 @@ #include "codegen_blob.h" #include "codegen_cpu.h" #include "codegen_llvm.h" -#include "llvm_scope.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -98,7 +98,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { std::string GetSource(const std::string& format) final; void Init(const IRModule& mod, const Target& target); - void Init(std::unique_ptr module, std::unique_ptr llvm_scope); + void Init(std::unique_ptr module, std::unique_ptr llvm_instance); void LoadIR(const std::string& file_name); bool IsDSOExportable() const final { return true; } @@ -111,7 +111,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { void* GetFunctionAddr(const std::string& name, const LLVMTarget& llvm_target) const; // The LLVM scope object. 
- std::unique_ptr llvm_scope_; + std::unique_ptr llvm_instance_; // JIT lock std::mutex mutex_; // execution engine @@ -154,7 +154,7 @@ PackedFunc LLVMModuleNode::GetFunction(const std::string& name, std::lock_guard lock(mutex_); TVMBackendPackedCFunc faddr; - With llvm_target(*llvm_scope_, LLVMTarget::GetTargetMetadata(*module_)); + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); if (name == runtime::symbol::tvm_module_main) { const char* entry_name = reinterpret_cast( GetGlobalAddr(runtime::symbol::tvm_module_main, *llvm_target)); @@ -178,7 +178,7 @@ void LLVMModuleNode::SaveToFile(const std::string& file_name, const std::string& #endif ICHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " << ecode.message(); if (fmt == "o" || fmt == "obj") { - With llvm_target(*llvm_scope_, LLVMTarget::GetTargetMetadata(*module_)); + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); #if TVM_LLVM_VERSION <= 60 std::unique_ptr m = llvm::CloneModule(module_); #else @@ -198,7 +198,7 @@ void LLVMModuleNode::SaveToFile(const std::string& file_name, const std::string& #endif pass.run(*m); } else if (fmt == "s" || fmt == "asm") { - With llvm_target(*llvm_scope_, LLVMTarget::GetTargetMetadata(*module_)); + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); #if TVM_LLVM_VERSION <= 60 std::unique_ptr m = llvm::CloneModule(module_); #else @@ -244,7 +244,7 @@ std::string LLVMModuleNode::GetSource(const std::string& format) { llvm::raw_svector_ostream rso(str); if (fmt == "s" || fmt == "asm") { - With llvm_target(*llvm_scope_, LLVMTarget::GetTargetMetadata(*module_)); + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); #if TVM_LLVM_VERSION <= 60 std::unique_ptr m = llvm::CloneModule(module_); #else @@ -277,8 +277,8 @@ std::string LLVMModuleNode::GetSource(const std::string& format) { } void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { - llvm_scope_ = 
std::make_unique(); - With llvm_target(*llvm_scope_, target); + llvm_instance_ = std::make_unique(); + With llvm_target(*llvm_instance_, target); llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); std::unique_ptr cg = CodeGenLLVM::Create(llvm_target.get()); @@ -334,16 +334,16 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { } void LLVMModuleNode::Init(std::unique_ptr module, - std::unique_ptr llvm_scope) { + std::unique_ptr llvm_instance) { module_owning_ptr_ = std::move(module); module_ = module_owning_ptr_.get(); - llvm_scope_ = std::move(llvm_scope); + llvm_instance_ = std::move(llvm_instance); } void LLVMModuleNode::LoadIR(const std::string& file_name) { - auto llvm_scope = std::make_unique(); - std::unique_ptr module = llvm_scope->LoadIR(file_name); - Init(std::move(module), std::move(llvm_scope)); + auto llvm_instance = std::make_unique(); + std::unique_ptr module = llvm_instance->LoadIR(file_name); + Init(std::move(module), std::move(llvm_instance)); } bool LLVMModuleNode::ImplementsFunction(const String& name, bool query_imports) { @@ -355,7 +355,7 @@ void LLVMModuleNode::LazyInitJIT() { if (ee_) { return; } - With llvm_target(*llvm_scope_, LLVMTarget::GetTargetMetadata(*module_)); + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); llvm::EngineBuilder builder(std::move(module_owning_ptr_)); builder.setEngineKind(llvm::EngineKind::JIT); builder.setOptLevel(llvm::CodeGenOpt::Aggressive); @@ -390,7 +390,7 @@ void LLVMModuleNode::LazyInitJIT() { } bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const { - With host_target(*llvm_scope_, "llvm"); // FIXME(kparzysz-quic): nesting + With host_target(*llvm_instance_, "llvm"); // FIXME(kparzysz-quic): nesting auto tm_host = host_target->GetOrCreateTargetMachine(); if (tm_host->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) { LOG(INFO) << "Architecture mismatch: module=" << tm->getTargetTriple().str() @@ 
-429,15 +429,15 @@ TVM_REGISTER_GLOBAL("target.build.llvm") TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") .set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module { - auto llvm_scope = std::make_unique(); - With llvm_target(*llvm_scope, target_str); + auto llvm_instance = std::make_unique(); + With llvm_target(*llvm_instance, target_str); auto n = make_object(); // Generate a LLVM module from an input target string auto module = std::make_unique(module_name, *llvm_target->GetContext()); llvm_target->SetTargetMetadata(module.get()); module->setTargetTriple(llvm_target->GetTargetTriple()); module->setDataLayout(llvm_target->GetOrCreateTargetMachine()->createDataLayout()); - n->Init(std::move(module), std::move(llvm_scope)); + n->Init(std::move(module), std::move(llvm_instance)); return runtime::Module(n); }); @@ -474,8 +474,8 @@ TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll") TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled") .set_body_typed([](std::string target_str) -> bool { - LLVMScope llvm_scope; - auto* tm = With(llvm_scope, target_str) + LLVMInstance llvm_instance; + auto* tm = With(llvm_instance, target_str) ->GetOrCreateTargetMachine(/*allow_missing=*/true); return tm != nullptr; }); @@ -484,17 +484,17 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob") .set_body_typed([](std::string data, bool system_lib, std::string llvm_target_string) -> runtime::Module { auto n = make_object(); - auto llvm_scope = std::make_unique(); - With llvm_target(*llvm_scope, llvm_target_string); + auto llvm_instance = std::make_unique(); + With llvm_target(*llvm_instance, llvm_target_string); std::unique_ptr blob = CodeGenBlob(data, system_lib, llvm_target.get()); - n->Init(std::move(blob), std::move(llvm_scope)); + n->Init(std::move(blob), std::move(llvm_instance)); return runtime::Module(n); }); runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target, tvm::relay::Runtime runtime) { - auto llvm_scope = 
std::make_unique(); - With llvm_target(*llvm_scope, target); + auto llvm_instance = std::make_unique(); + With llvm_target(*llvm_instance, target); bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); std::unique_ptr cg{new CodeGenCPU()}; @@ -517,7 +517,7 @@ runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata << verify_errors.str(); auto n = make_object(); - n->Init(std::move(mod), std::move(llvm_scope)); + n->Init(std::move(mod), std::move(llvm_instance)); auto meta_mod = MetadataModuleCreate(metadata); meta_mod->Import(runtime::Module(n)); @@ -537,8 +537,8 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array& module } } - auto llvm_scope = std::make_unique(); - With llvm_target(*llvm_scope, target); + auto llvm_instance = std::make_unique(); + With llvm_target(*llvm_instance, target); bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); bool target_c_runtime = runtime->name == "crt"; ICHECK(system_lib && target_c_runtime) @@ -563,7 +563,7 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array& module << verify_errors.str(); auto n = make_object(); - n->Init(std::move(mod), std::move(llvm_scope)); + n->Init(std::move(mod), std::move(llvm_instance)); for (auto m : modules) { n->Import(m); } From 27e5be7748b5d4735994a20cf287dc9d5e58c66b Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Wed, 3 Aug 2022 13:13:19 -0700 Subject: [PATCH 6/6] Fix the header guard name --- src/target/llvm/llvm_instance.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h index 37eacd2a4aec..afb6e58deb1f 100644 --- a/src/target/llvm/llvm_instance.h +++ b/src/target/llvm/llvm_instance.h @@ -19,8 +19,8 @@ /*! 
\file llvm_instance.h */ -#ifndef TVM_TARGET_LLVM_LLVM_SCOPE_H_ -#define TVM_TARGET_LLVM_LLVM_SCOPE_H_ +#ifndef TVM_TARGET_LLVM_LLVM_INSTANCE_H_ +#define TVM_TARGET_LLVM_LLVM_INSTANCE_H_ #ifdef TVM_LLVM_VERSION @@ -263,4 +263,4 @@ class LLVMTarget { } // namespace tvm #endif // TVM_LLVM_VERSION -#endif // TVM_TARGET_LLVM_LLVM_SCOPE_H_ +#endif // TVM_TARGET_LLVM_LLVM_INSTANCE_H_