Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@
from .schedule import create_schedule
from .build_module import build, lower, build_config
from .tag import tag_scope
from .contrib import rocm as _rocm

# Contrib initializers
from .contrib import rocm as _rocm, nvcc as _nvcc
61 changes: 61 additions & 0 deletions python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
from __future__ import absolute_import as _abs

import subprocess
import os
import warnings
from . import util
from .. import ndarray as nd
from ..api import register_func
from .._ffi.base import py_str

def compile_cuda(code,
target="ptx",
Expand Down Expand Up @@ -72,3 +76,60 @@ def compile_cuda(code,
raise RuntimeError(msg)

return bytearray(open(file_target, "rb").read())


def find_cuda_path():
"""Utility function to find cuda path

Returns
-------
path : str
Path to cuda root.
"""
if "CUDA_PATH" in os.environ:
return os.environ["CUDA_PATH"]
cmd = ["which", "nvcc"]
proc = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
out = py_str(out)
if proc.returncode == 0:
return os.path.abspath(os.path.join(str(out).strip(), "../.."))
cuda_path = "/usr/local/cuda"
if os.path.exists(os.path.join(cuda_path, "bin/nvcc")):
return cuda_path
raise RuntimeError("Cannot find cuda path")


@register_func("tvm_callback_libdevice_path")
def find_libdevice_path(arch):
"""Utility function to find libdevice

Parameters
----------
arch : int
The compute architecture in int
"""
cuda_path = find_cuda_path()
lib_path = os.path.join(cuda_path, "nvvm/libdevice")
selected_ver = 0
selected_path = None

for fn in os.listdir(lib_path):
if not fn.startswith("libdevice"):
continue
ver = int(fn.split(".")[-3].split("_")[-1])
if ver > selected_ver and ver <= arch:
selected_ver = ver
selected_path = fn
if selected_path is None:
raise RuntimeError("Cannot find libdevice for arch {}".format(arch))
return os.path.join(lib_path, selected_path)


def callback_libdevice_path(arch):
try:
return find_libdevice_path(arch)
except RuntimeError:
warnings.warn("Cannot find libdevice path")
return ""
1 change: 1 addition & 0 deletions python/tvm/contrib/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def rocm_link(in_file, out_file):
msg += str(out)
raise RuntimeError(msg)


@register_func("tvm_callback_rocm_link")
def callback_rocm_link(obj_bin):
"""Links object file generated from LLVM to HSA Code Object
Expand Down
11 changes: 11 additions & 0 deletions src/codegen/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,21 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {

std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
this->AddStartupFunction();
// link modules
for (size_t i = 0; i < link_modules_.size(); ++i) {
CHECK(!llvm::Linker::linkModules(*module_, std::move(link_modules_[i])))
<< "Failed to link modules";
}
link_modules_.clear();
// optimize
this->Optimize();
return std::move(module_);
}

void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) {
link_modules_.emplace_back(std::move(mod));
}

void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
LOG(FATAL) << "not implemented";
}
Expand Down
8 changes: 7 additions & 1 deletion src/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ class CodeGenLLVM :
* \return the created module.
*/
virtual std::unique_ptr<llvm::Module> Finish();
/*!
* \brief Add mod to be linked with the generated module
* \param mod The module to be linked.
*/
void AddLinkModule(std::unique_ptr<llvm::Module>&& mod);
/*!
* \brief Create Value for expression e
* \param e The expression to be created value for.
Expand Down Expand Up @@ -227,7 +232,8 @@ class CodeGenLLVM :
llvm::MDNode* md_very_likely_branch_{nullptr};
llvm::MDNode* md_tbaa_root_{nullptr};
llvm::MDNode* md_tbaa_alias_set_{nullptr};

// modules to be linked.
std::vector<std::unique_ptr<llvm::Module> > link_modules_;
/*! \brief native vector bits of current targetx*/
int native_vector_bits_{0};
/*! \brief the storage scope of allocation */
Expand Down
22 changes: 21 additions & 1 deletion src/codegen/llvm/codegen_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,10 @@ inline int DetectCUDAComputeVersion() {
runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
CHECK(target.length() >= 5 &&
target.substr(0, 5) == "nvptx");
int compute_ver = DetectCUDAComputeVersion();
std::ostringstream config;
config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_"
<< DetectCUDAComputeVersion()
<< compute_ver
<< target.substr(5, target.length() - 5);
llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
Expand All @@ -164,6 +165,25 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
for (LoweredFunc f : funcs) {
cg->AddFunction(f);
}

const auto* flibdevice_path =
tvm::runtime::Registry::Get("tvm_callback_libdevice_path");
if (flibdevice_path != nullptr) {
std::string path = (*flibdevice_path)(compute_ver);
if (path.length() != 0) {
llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, *ctx);
if (mlib.get() == nullptr) {
std::string msg = err.getMessage();
LOG(FATAL) << "Fail to load bitcode file " << path << "\n"
<< "line " << err.getLineNo() << ":" << msg;
}
mlib->setTargetTriple(tm->getTargetTriple().str());
mlib->setDataLayout(tm->createDataLayout());
// TODO(tqchen) libdevice linking not yet working.
// cg->AddLinkModule(std::move(mlib));
}
}
std::unique_ptr<llvm::Module> module = cg->Finish();
llvm::SmallString<8> data_ptx, data_ll;
llvm::raw_svector_ostream dest_ptx(data_ptx), dest_ll(data_ll);
Expand Down
42 changes: 1 addition & 41 deletions src/codegen/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,52 +4,12 @@
*/
#ifdef TVM_LLVM_VERSION

#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include <tvm/codegen.h>
#include <string>
#include "./llvm_common.h"
#include "./intrin_rule_llvm.h"

namespace tvm {
namespace codegen {
namespace llvm {

using namespace ir;

// num_signature means number of arguments used to query signature
template<unsigned id, int num_signature>
inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
cargs.push_back(UIntImm::make(UInt(32), id));
cargs.push_back(UIntImm::make(UInt(32), num_signature));

for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = Call::make(
call->type, "llvm_intrin", cargs, Call::PureIntrinsic);
}

template<unsigned id, int num_signature>
inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
cargs.push_back(UIntImm::make(UInt(32), id));
cargs.push_back(UIntImm::make(UInt(32), num_signature));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = Call::make(
call->type, "llvm_intrin", cargs, Call::Intrinsic);
}

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch")
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 0>);

Expand Down
56 changes: 56 additions & 0 deletions src/codegen/llvm/intrin_rule_llvm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*!
* Copyright (c) 2017 by Contributors
* \file intrin_rule_llvm.h
* \brief Common utilities for llvm intrinsics.
*/
#ifndef TVM_CODEGEN_LLVM_INTRIN_RULE_LLVM_H_
#define TVM_CODEGEN_LLVM_INTRIN_RULE_LLVM_H_
#ifdef TVM_LLVM_VERSION

#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include <tvm/codegen.h>
#include <string>
#include "./llvm_common.h"

namespace tvm {
namespace codegen {
// num_signature means number of arguments used to query signature
template<unsigned id, int num_signature>
inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const ir::Call* call = e.as<ir::Call>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
cargs.push_back(ir::UIntImm::make(UInt(32), id));
cargs.push_back(ir::UIntImm::make(UInt(32), num_signature));

for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = ir::Call::make(
call->type, "llvm_intrin", cargs, ir::Call::PureIntrinsic);
}

template<unsigned id, int num_signature>
inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const ir::Call* call = e.as<ir::Call>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
cargs.push_back(ir::UIntImm::make(UInt(32), id));
cargs.push_back(ir::UIntImm::make(UInt(32), num_signature));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = ir::Call::make(
call->type, "llvm_intrin", cargs, ir::Call::Intrinsic);
}

} // namespace codegen
} // namespace tvm

#endif // LLVM_VERSION
#endif // TVM_CODEGEN_LLVM_INTRIN_RULE_LLVM_H_
2 changes: 2 additions & 0 deletions src/codegen/llvm/llvm_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
#include <llvm/IRReader/IRReader.h>
#include <llvm/CodeGen/TargetLoweringObjectFileImpl.h>

#include <llvm/Linker/Linker.h>

#include <utility>
#include <string>

Expand Down