Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/target/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ using runtime::TVMRetValue;
* \param target The target to be built.
* \return The result runtime::Module.
*/
runtime::Module Build(IRModule mod, const Target& target);
runtime::Module Build(IRModule mod, Target target);

/*!
* \brief Pack imported device library to a C file.
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ class TargetNode : public Object {
Array<String> keys;
/*! \brief Collection of attributes */
Map<String, ObjectRef> attrs;

/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;
/*! \return Export target to JSON-like configuration */
TVM_DLL Map<String, ObjectRef> Export() const;

void VisitAttrs(AttrVisitor* v) {
v->Visit("kind", &kind);
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def __enter__(self):
def __exit__(self, ptype, value, trace):
_ffi_api.ExitTargetScope(self)

def export(self):
return _ffi_api.TargetExport(self)

@staticmethod
def current(allow_none=True):
"""Returns the current target.
Expand Down
17 changes: 17 additions & 0 deletions src/target/build_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,23 @@ inline std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
}
return fmap;
}

inline void UpdateTargetConfigKeyValueEntry(const String& key, const String& value,
Map<String, ObjectRef>* target_config,
bool error_if_inconsistent) {
if (target_config->count(key)) {
const ObjectRef& obj = (*target_config)[key];
CHECK(obj->IsInstance<StringObj>()) << "TypeError: Expect target key \"" << key
<< "\" to be String, but gets type: " << obj->GetTypeKey();
if (error_if_inconsistent) {
String old_value = Downcast<String>(obj);
CHECK_EQ(old_value, value) << "ValueError: Target key \"" << key << "\" has been set to \""
<< old_value << "\", and cannot be reset to \"" << value << "\"";
}
}
target_config->Set(key, value);
}

} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_BUILD_COMMON_H_
6 changes: 3 additions & 3 deletions src/target/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
namespace tvm {
namespace codegen {

runtime::Module Build(IRModule mod, const Target& target) {
runtime::Module Build(IRModule mod, Target target) {
if (transform::PassContext::Current()
->GetConfig<Bool>("tir.disable_assert", Bool(false))
.value()) {
Expand All @@ -55,8 +55,8 @@ runtime::Module Build(IRModule mod, const Target& target) {
}
// the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
CHECK(bf != nullptr) << "target.build." << target << " is not enabled";
return (*bf)(mod, target->str());
CHECK(bf != nullptr) << build_f_name << " is not enabled";
return (*bf)(mod, target);
}

/*! \brief Helper class to serialize module */
Expand Down
50 changes: 33 additions & 17 deletions src/target/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,17 @@ class CodeGenAMDGPU : public CodeGenLLVM {
}
};

inline int DetectROCMComputeVersion(const std::string& target) {
size_t pos = target.find("=gfx");
if (pos != std::string::npos) {
int value;
std::stringstream is(target.substr(pos + 4));
if (is >> value) return value;
inline int DetectROCMComputeVersion(const Target& target) {
if (const Optional<String> mcpu = target->GetAttr<String>("mcpu")) {
std::string gfx = mcpu.value();
if (gfx.length() >= 3 && gfx.substr(0, 3) == "gfx") {
int version;
std::stringstream is(gfx.substr(3));
if (is >> version) {
return version;
}
}
LOG(FATAL) << "ValueError: Unrecognized -mcpu value: " << mcpu;
}
TVMContext tvm_ctx;
tvm_ctx.device_type = kDLROCM;
Expand Down Expand Up @@ -228,23 +233,34 @@ inline int DetectROCMApiVersion() {
return 305;
}

runtime::Module BuildAMDGPU(IRModule mod, std::string target) {
Target UpdateTarget(const Target& original_target) {
Map<String, ObjectRef> target_config = original_target->Export();
UpdateTargetConfigKeyValueEntry("mtriple", "amdgcn-amd-amdhsa-hcc", &target_config, true);
UpdateTargetConfigKeyValueEntry("mcpu",
"gfx" + std::to_string(DetectROCMComputeVersion(original_target)),
&target_config, false);
if (DetectROCMApiVersion() < 305) {
// before ROCm 3.5 we needed code object v2, starting
// with 3.5 we need v3 (this argument disables v3)
Array<String> mattr;
if (target_config.count("mattr")) {
mattr = Downcast<Array<String>>(target_config["mattr"]);
}
mattr.push_back("-code-object-v3");
target_config.Set("mattr", mattr);
}
return Target::FromConfig(target_config);
}

runtime::Module BuildAMDGPU(IRModule mod, Target original_target) {
#if TVM_LLVM_VERSION < 90
LOG(FATAL) << "AMDGPU backend requires at least LLVM 9";
// Lower versions will crash when loading the bitcode, see
// issue #4087 for a discussion
#endif
InitializeLLVM();
CHECK(target.length() >= 4 && target.substr(0, 4) == "rocm");
std::ostringstream config;
config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << DetectROCMComputeVersion(target);
if (DetectROCMApiVersion() < 305) {
// before ROCm 3.5 we needed code object v2, starting
// with 3.5 we need v3 (this argument disables v3)
config << " -mattr=-code-object-v3 ";
}
config << target.substr(4, target.length() - 4);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
Target target = UpdateTarget(original_target);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target);
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
// careful: cg will hold a naked pointer reference to ctx, so it should
// have a shorter lifetime than the ctx.
Expand Down
7 changes: 4 additions & 3 deletions src/target/llvm/codegen_blob.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "codegen_blob.h"

#include <tvm/runtime/module.h>
#include <tvm/target/target.h>

#include <cstring>

Expand All @@ -33,8 +34,8 @@ namespace codegen {
std::pair<std::unique_ptr<llvm::Module>, std::shared_ptr<llvm::LLVMContext>> CodeGenBlob(
const std::string& data, bool system_lib, const std::string& target_triple) {
InitializeLLVM();
std::string full_target_triple = std::string("-mtriple ") + target_triple;
auto tm = GetLLVMTargetMachine(full_target_triple);
Target target = Target::Create("llvm -mtriple " + target_triple);
auto tm = GetLLVMTargetMachine(target);
auto triple = tm->getTargetTriple();
auto ctx = std::make_shared<llvm::LLVMContext>();
std::string module_name = "devc";
Expand All @@ -43,7 +44,7 @@ std::pair<std::unique_ptr<llvm::Module>, std::shared_ptr<llvm::LLVMContext>> Cod
// Store full target string in metadata, because flags such as -mfloat-abi must be preserved for
// ModulePackImportsToLLVM.
module->addModuleFlag(llvm::Module::ModFlagBehavior::Override, "tvm_target",
llvm::MDString::get(*ctx, full_target_triple));
llvm::MDString::get(*ctx, LLVMTargetToString(target)));
module->setDataLayout(tm->createDataLayout());
auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false);
auto* tvm_dev_mblob = new llvm::GlobalVariable(
Expand Down
31 changes: 8 additions & 23 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -658,11 +658,7 @@ void ProcessLLVMOptions(const std::vector<std::string>& llvm_vec) {

} // namespace

runtime::Module BuildHexagon(IRModule mod, std::string target_str) {
if (target_str.empty()) {
LOG(FATAL) << "Unknown or invalid target.";
}

runtime::Module BuildHexagon(IRModule mod, Target target) {
// Make sure all targets are registered. InitializeLLVM can be called
// multiple times, after the first call all subsequent calls are no-ops.
InitializeLLVM();
Expand All @@ -675,21 +671,12 @@ runtime::Module BuildHexagon(IRModule mod, std::string target_str) {
}
return vec;
};
auto starts_with = [](const std::string& s, const std::string& p) {
return !s.compare(0, p.size(), p);
};

std::vector<std::string> flags = split(target_str);
std::string llvm_target_str, llvm_options_str = "llvm";

for (const auto& s : flags) {
if (starts_with(s, "-mattr=") || starts_with(s, "-mtriple=") || starts_with(s, "-mcpu=")) {
llvm_target_str += " " + s;
} else if (starts_with(s, "-llvm-options=")) {
llvm_options_str += "," + s.substr(14 /*length of -llvm-options=*/);
}
std::string llvm_options_str;
if (const Optional<String> llvm_options = target->GetAttr<String>("llvm-options")) {
llvm_options_str = "llvm," + llvm_options.value();
} else {
llvm_options_str = "llvm";
}

// Postprocess the LLVM options string: replace '@' with '=', and ',' with ' '.
for (int i = 0, e = llvm_options_str.size(); i != e; ++i) {
switch (llvm_options_str[i]) {
Expand All @@ -716,7 +703,7 @@ runtime::Module BuildHexagon(IRModule mod, std::string target_str) {
static bool CallOnce = (ProcessLLVMOptions(llvm_options_vec), true);
(void)CallOnce;

std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target_str);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target);
std::unique_ptr<CodeGenHexagon> cg(new CodeGenHexagon());
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false);
Expand Down Expand Up @@ -802,9 +789,7 @@ runtime::Module BuildHexagon(IRModule mod, std::string target_str) {
export_abi);
}

TVM_REGISTER_GLOBAL("target.build.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildHexagon(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("target.build.hexagon").set_body_typed(BuildHexagon);

} // namespace codegen
} // namespace tvm
Expand Down
17 changes: 11 additions & 6 deletions src/target/llvm/codegen_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,19 @@ inline int DetectCUDAComputeVersion() {
}
}

runtime::Module BuildNVPTX(IRModule mod, std::string target) {
Target UpdateTarget(const Target& original_target, int compute_ver) {
Map<String, ObjectRef> target_config = original_target->Export();
UpdateTargetConfigKeyValueEntry("mtriple", "nvptx64-nvidia-cuda", &target_config, true);
UpdateTargetConfigKeyValueEntry("mcpu", "sm_" + std::to_string(compute_ver), &target_config,
false);
return Target::FromConfig(target_config);
}

runtime::Module BuildNVPTX(IRModule mod, Target original_target) {
InitializeLLVM();
CHECK(target.length() >= 5 && target.substr(0, 5) == "nvptx");
int compute_ver = DetectCUDAComputeVersion();
std::ostringstream config;
config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" << compute_ver
<< target.substr(5, target.length() - 5);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
Target target = UpdateTarget(original_target, compute_ver);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target);
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
// careful: cg will hold a naked pointer reference to ctx, so it should
// have a shorter lifetime than the ctx.
Expand Down
Loading