diff --git a/dmlc-core b/dmlc-core index 46886a6b47f6..a527100d7d50 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 46886a6b47f660cda581e497378204ccc029a01e +Subproject commit a527100d7d5001efc4954848a2fc6027e48c05f4 diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index ecb6200807fb..f3f908f2ef53 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -29,3 +29,4 @@ from .schedule import create_schedule from .build_module import build, lower, build_config from .tag import tag_scope +from .contrib import rocm as _rocm diff --git a/python/tvm/_ffi/ndarray.py b/python/tvm/_ffi/ndarray.py index eb440bf06635..b7bac46157d2 100644 --- a/python/tvm/_ffi/ndarray.py +++ b/python/tvm/_ffi/ndarray.py @@ -59,6 +59,8 @@ def context(dev_type, dev_id=0): if dev_type not in TVMContext.STR2MASK: if dev_type.find("nvptx") != -1: dev_type = "cuda" + if dev_type.find("rocm") != -1: + dev_type = "rocm" if dev_type not in TVMContext.STR2MASK: raise ValueError("Unknown device type %s" % dev_type) dev_type = TVMContext.STR2MASK[dev_type] diff --git a/python/tvm/contrib/rocm.py b/python/tvm/contrib/rocm.py new file mode 100644 index 000000000000..c367aef24e21 --- /dev/null +++ b/python/tvm/contrib/rocm.py @@ -0,0 +1,50 @@ +"""Utility for ROCm backend""" +import subprocess +from . import util +from ..api import register_func + +def rocm_link(in_file, out_file): + """Link relocatable ELF object to shared ELF object using lld + + Parameters + ---------- + in_file : str + Input file name (relocatable ELF object file) + + out_file : str + Output file name (shared ELF object file) + """ + args = ["ld.lld", "-shared", in_file, "-o", out_file] + proc = subprocess.Popen( + args, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + (out, _) = proc.communicate() + + if proc.returncode != 0: + msg = "Linking error using ld.lld:\n" + msg += str(out) + raise RuntimeError(msg) + +@register_func("tvm_callback_rocm_link") +def callback_rocm_link(obj_bin): + """Links object file generated from LLVM to HSA Code Object + + Parameters + ---------- + obj_bin : bytearray + The object file + + Return + ------ + cobj_bin : bytearray + The HSA Code Object + """ + tmp_dir = util.tempdir() + tmp_obj = tmp_dir.relpath("rocm_kernel.o") + tmp_cobj = tmp_dir.relpath("rocm_kernel.co") + with open(tmp_obj, "wb") as out_file: + out_file.write(bytes(obj_bin)) + rocm_link(tmp_obj, tmp_cobj) + cobj_bin = bytearray(open(tmp_cobj, "rb").read()) + return cobj_bin diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc new file mode 100644 index 000000000000..4769efdb0405 --- /dev/null +++ b/src/codegen/llvm/codegen_amdgpu.cc @@ -0,0 +1,188 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file codegen_amdgpu.cc + * \brief AMDGPU code generator. + */ +#ifdef TVM_LLVM_VERSION +#if TVM_ROCM_RUNTIME + +#include +#include +#include +#include "./codegen_llvm.h" +#include "../build_common.h" +#include "../../pass/ir_util.h" +#include "../../runtime/rocm/rocm_module.h" + +namespace tvm { +namespace codegen { + +// AMDGPU code generator. +class CodeGenAMDGPU : public CodeGenLLVM { + public: + void AddFunction(const LoweredFunc& f) final { + // add function as void return value + CodeGenLLVM::AddFunctionInternal(f, true); + function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); + } + + void VisitStmt_(const Allocate* op) final { + CHECK(!is_zero(op->condition)); + llvm::Value* buf = nullptr; + if (op->new_expr.defined()) { + CHECK_EQ(op->free_function, "nop"); + buf = MakeValue(op->new_expr); + } else { + int32_t constant_size = op->constant_allocation_size(); + CHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation in GPU"; + StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->type, constant_size); + } + // maximum necessary alignment in the AMD devices + if (info.alignment > 16) { + info.alignment = 16; + } + if (info.scope.rank == 2) { + // const int local_address_space = 5; + // TODO(tqchen): for higher version of LLVM, local address space can be set. + llvm::AllocaInst* alloca = builder_->CreateAlloca( + LLVMType(op->type), ConstInt32(constant_size)); + if (alloca->getAlignment() < static_cast(info.alignment)) { + alloca->setAlignment(info.alignment); + } + buf = alloca; + } else { + CHECK_EQ(info.scope.rank, 1) + << "Can only allocate shared or local memory inside kernel"; + // Shared memory: address space == 3 + const unsigned shared_address_space = 3; + llvm::Type* type = llvm::ArrayType::get(LLVMType(op->type), constant_size); + // Allocate shared memory in global, address_space = 3 + llvm::GlobalVariable *global = new llvm::GlobalVariable( + *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", + nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); + global->setAlignment(info.alignment); + buf = global; + } + } + buf = builder_->CreatePointerCast( + buf, LLVMType(op->type)->getPointerTo( + buf->getType()->getPointerAddressSpace())); + CHECK(!var_map_.count(op->buffer_var.get())); + var_map_[op->buffer_var.get()] = buf; + this->VisitStmt(op->body); + } + + // Return the thread index via intrinsics. + llvm::Value* GetThreadIndex(const IterVar& iv) final { + runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); + llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; + if (ts.rank == 1) { + switch (ts.dim_index) { + case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; break; + case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_y; break; + case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_z; break; + default: LOG(FATAL) << "unknown workitem idx"; + } + } else { + CHECK_EQ(ts.rank, 0); + switch (ts.dim_index) { + case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_x; break; + case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_y; break; + case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_z; break; + default: LOG(FATAL) << "unknown workgroup idx"; + } + } + llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id); + return builder_->CreateCall(f, {}); + } + + llvm::Value* CreateStorageSync(const Call* op) final { + const std::string& sync = op->args[0].as()->value; + if (sync == "warp") { + // TODO(tqchen) warp sync in CUDA9 + return nullptr; + } else if (sync == "shared") { + llvm::Function* f = llvm::Intrinsic::getDeclaration( + module_.get(), + ::llvm::Intrinsic::amdgcn_s_barrier); + return builder_->CreateCall(f, {}); + } else { + LOG(FATAL) << "Do not support sync " << sync; + return nullptr; + } + } + + void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final { + // Additional optimization hook to tweak the builder. + } + + unsigned GetGlobalAddressSpace() { + return 1; + } + + protected: + void InitTarget(llvm::TargetMachine* tm) final { + // Maximum vector lane = float4 + native_vector_bits_ = 4 * 32; + CodeGenLLVM::InitTarget(tm); + } +}; + +runtime::Module BuildAMDGPU(Array funcs, std::string target) { + CHECK(target.length( +) >= 4 && + target.substr(0, 4) == "rocm"); + llvm::TargetMachine* tm = \ + GetLLVMTargetMachine("-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx900" + \ + target.substr(4, target.length() - 4)); + + std::unique_ptr cg(new CodeGenAMDGPU()); + std::unique_ptr ctx(new llvm::LLVMContext()); + cg->Init(funcs[0]->name, tm, ctx.get(), false, false); + for (LoweredFunc f : funcs) { + cg->AddFunction(f); + } + + std::unique_ptr module = cg->Finish(); + + llvm::SmallString<8> dataObj, data_ll, dataAsm; + llvm::raw_svector_ostream destObj(dataObj), dest_ll(data_ll), destAsm(dataAsm); + destObj.SetUnbuffered(); + dest_ll.SetUnbuffered(); + destAsm.SetUnbuffered(); + module->print(dest_ll, nullptr); + std::unique_ptr mAsm = llvm::CloneModule(module.get()); + std::unique_ptr mObj = llvm::CloneModule(module.get()); + llvm::legacy::PassManager pass; + + CHECK(tm->addPassesToEmitFile( + pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; + pass.run(*mObj); + std::string obj(dataObj.begin(), dataObj.end()); + + const auto* f = tvm::runtime::Registry::Get("tvm_callback_rocm_link"); + CHECK(f != nullptr) << "Require tvm_callback_rocm_link to exist, do import tvm.contrib.rocm"; + + TVMByteArray arr; + arr.data = &obj[0]; + arr.size = obj.length(); + + std::string hsaco = (*f)(arr); + std::string ll(data_ll.begin(), data_ll.end()); + + return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll); +} + +TVM_REGISTER_API("codegen.build_rocm") +.set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = BuildAMDGPU(args[0], args[1]); + }); + +} // namespace codegen +} // namespace tvm +#endif // TVM_ROCM_RUNTIME +#endif // TVM_LLVM_VERSION diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 30437bb911d6..743575cb65aa 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -100,7 +100,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) { Type t = arg.type(); if (t.is_handle() && f->handle_data_type.count(arg)) { arg_type.push_back( - LLVMType(f->handle_data_type[arg].type())->getPointerTo()); + LLVMType(f->handle_data_type[arg].type())->getPointerTo(GetGlobalAddressSpace())); if (!is_restricted_) { alias_var_set_.insert(arg.get()); } @@ -555,6 +555,10 @@ int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) co return native_vector_bits_; } +unsigned CodeGenLLVM::GetGlobalAddressSpace() { + return 0; +} + void CodeGenLLVM::GetAlignment( Type t, const Variable* buf_var, const Expr& index, int* p_alignment, int* p_native_bits) { diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 7fa10c05ad4a..d055d7d5a1c3 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -23,6 +23,7 @@ namespace codegen { using namespace ir; + /*! * \brief A base class to generate a LLVM. */ @@ -148,6 +149,9 @@ class CodeGenLLVM : virtual void Optimize(); // Get the maximim storage align bits of buffer pointer given storage scope. virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const; + // Get correct address space depending on the backend + virtual unsigned GetGlobalAddressSpace(); + void AddFunctionInternal(const LoweredFunc& f, bool ret_void); // Create extern call llvm::CallInst* CreateCallExtern(llvm::Type* ret, diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 2f2b0a214bed..43ad6e523494 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -125,6 +125,8 @@ bool RuntimeEnabled(const std::string& target) { f_name = "device_api.vpi"; } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") { f_name = "codegen.build_nvptx"; + } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") { + f_name = "codegen.build_rocm"; } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") { const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled"); if (pf == nullptr) return false; diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index f7ce1f284ee2..2839e10945f8 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -59,10 +59,17 @@ class ROCMModuleNode : public runtime::ModuleNode { stream->Write(data_); } + std::string GetSource(const std::string& format) final { + if (format == fmt_) { return data_; } + if (fmt_ == "hsaco") { return data_; } + return ""; + } + // get a CUfunction from primary context in device_id hipFunction_t GetFunc(int device_id, const std::string& func_name) { std::lock_guard lock(mutex_); // must recheck under the lock scope + if (module_[device_id] == nullptr) { ROCM_DRIVER_CALL(hipModuleLoadData(&(module_[device_id]), data_.c_str())); } @@ -140,7 +147,9 @@ class ROCMWrappedFunc { if (fcache_[device_id] == nullptr) { fcache_[device_id] = m_->GetFunc(device_id, func_name_); } + hipStream_t strm = static_cast(ROCMThreadEntry::ThreadLocal()->stream); + ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); void* config[] = { HIP_LAUNCH_PARAM_BUFFER_POINTER, &packed_args, @@ -181,7 +190,6 @@ PackedFunc ROCMModuleNode::GetFunction( CHECK_EQ(sptr_to_self.get(), this); CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; - auto it = fmap_.find(name); if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; diff --git a/tests/python/integration/test_gemm.py b/tests/python/integration/test_gemm.py index 0798ecf61e4f..5733e11a84ec 100644 --- a/tests/python/integration/test_gemm.py +++ b/tests/python/integration/test_gemm.py @@ -85,6 +85,8 @@ def check_device(device): np.testing.assert_allclose( c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5) + check_device("nvptx -mcpu=sm_20") + check_device("rocm") check_device("metal") check_device("opencl") check_device("cuda") diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index ee3284cd640c..c4fbb4eccac1 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -82,6 +82,7 @@ def check_module_save(device, host="stackvm"): check_target("cuda", host="llvm") check_module_save("cuda", host="stackvm") check_target("nvptx", host="llvm") + check_target("rocm", host="llvm") if __name__ == "__main__": test_add_pipeline()