Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/mxnet/rtc.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class CudaModule {
/*! \brief nvrtc program handle. */
nvrtcProgram prog_;
/*! \brief compiled cuda PTX */
char* ptx_;
std::vector<char> ptx_;
/*! \brief lazily loaded cuda module */
std::unordered_map<int, CUmodule> mod_;
/*! \brief exported names */
Expand Down
48 changes: 41 additions & 7 deletions src/common/cuda/rtc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <fstream>
#include <unordered_map>
#include <vector>
#include <tuple>
#include <algorithm>

#include "rtc.h"
#include "rtc/half-inl.h"
Expand Down Expand Up @@ -78,15 +80,47 @@ std::string GetCompileLog(nvrtcProgram program) {
}

// Obtain compilation result (ptx assembly) from the program.
std::string GetPtx(nvrtcProgram program) {
std::string GetCompiledCode(nvrtcProgram program, bool use_cubin) {
#if CUDA_VERSION >= 11010
const auto getSize = use_cubin ? nvrtcGetCUBINSize : nvrtcGetPTXSize;
const auto getFunc = use_cubin ? nvrtcGetCUBIN : nvrtcGetPTX;
Comment on lines +85 to +86
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, while nvrtcGetCUBINSize() and nvrtcGetCUBIN() are not yet in the nvrtc docs, their use is described in https://docs.nvidia.com/deploy/cuda-compatibility/

#else
const auto getSize = nvrtcGetPTXSize;
const auto getFunc = nvrtcGetPTX;
#endif
size_t ptx_size_including_null;
NVRTC_CALL(nvrtcGetPTXSize(program, &ptx_size_including_null));
NVRTC_CALL(getSize(program, &ptx_size_including_null));
std::string ptx(ptx_size_including_null - 1, '\0');
// Room for terminating null character ensured since C++11
NVRTC_CALL(nvrtcGetPTX(program, &ptx[0]));
NVRTC_CALL(getFunc(program, &ptx[0]));
return ptx;
}

// Decide how to target the given SM architecture with NVRTC.
//
// Returns {true, "sm_XX"} when a native cubin can be produced for the
// architecture (needs CUDA >= 11.1 for the nvrtcGetCUBIN API and an
// architecture this toolkit knows), otherwise {false, "compute_XX"} so PTX
// is emitted and JIT-compiled by the driver. The requested architecture is
// clamped to the newest one the compiling CUDA toolkit supports.
std::tuple<bool, std::string> GetArchString(const int sm_arch) {
#if CUDA_VERSION < 10000
  constexpr int max_supported_sm_arch = 72;
#elif CUDA_VERSION < 11000
  constexpr int max_supported_sm_arch = 75;
#elif CUDA_VERSION < 11010
  constexpr int max_supported_sm_arch = 80;
#else
  constexpr int max_supported_sm_arch = 86;
#endif

#if CUDA_VERSION < 11010
  // The cubin getters (nvrtcGetCUBIN/nvrtcGetCUBINSize) only exist from
  // CUDA 11.1 on, so always fall back to PTX before that. This mirrors the
  // CUDA_VERSION >= 11010 guard in GetCompiledCode.
  const bool known_arch = false;
#else
  const bool known_arch = sm_arch <= max_supported_sm_arch;
#endif
  const int actual_sm_arch = std::min(sm_arch, max_supported_sm_arch);
  if (known_arch) {
    return {known_arch, "sm_" + std::to_string(actual_sm_arch)};
  } else {
    return {known_arch, "compute_" + std::to_string(actual_sm_arch)};
  }
}

} // namespace

CUfunction get_function(const std::string &parameters,
Expand Down Expand Up @@ -141,14 +175,14 @@ CUfunction get_function(const std::string &parameters,
0, // num headers
nullptr, // headers
nullptr)); // include names

std::string gpu_arch_arg = "--gpu-architecture=compute_" + std::to_string(sm_arch);
const auto [use_cubin, gpu_arch] = GetArchString(sm_arch); // NOLINT(*)
std::string gpu_arch_arg = "--gpu-architecture=" + gpu_arch;
const char *opts[] = {gpu_arch_arg.c_str(),
#if NDEBUG == 0
"-G",
#endif
"--std=c++14"};
const std::string kernel_name_demangled = kernel_name;
const std::string& kernel_name_demangled = kernel_name;
NVRTC_CALL(nvrtcAddNameExpression(program, (kernel_name_demangled).c_str()));

nvrtcResult compileResult = nvrtcCompileProgram(program, // prog
Expand All @@ -165,7 +199,7 @@ CUfunction get_function(const std::string &parameters,
<< "The generated code was stored in " << dump_file << "\n"
<< GetCompileLog(program);

kinfo.ptx = GetPtx(program);
kinfo.ptx = GetCompiledCode(program, use_cubin);
const char *mangled_name;
NVRTC_CALL(nvrtcGetLoweredName(program,
kernel_name_demangled.c_str(),
Expand Down
37 changes: 28 additions & 9 deletions src/common/rtc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ CudaModule::Chunk::Chunk(
<< "For lower version of CUDA, please prepend your kernel defintiions "
<< "with extern \"C\" instead.";
#endif
std::vector<const char*> c_options;
for (const auto& i : options) c_options.push_back(i.c_str());
std::vector<const char*> c_options(options.size());
for (const auto& i : options) c_options.emplace_back(i.c_str());
nvrtcResult compile_res = nvrtcCompileProgram(prog_, c_options.size(), c_options.data());
if (compile_res != NVRTC_SUCCESS) {
size_t err_size;
Expand All @@ -55,10 +55,30 @@ CudaModule::Chunk::Chunk(
LOG(FATAL) << err.data();
}

size_t ptx_size;
NVRTC_CALL(nvrtcGetPTXSize(prog_, &ptx_size));
ptx_ = new char[ptx_size];
NVRTC_CALL(nvrtcGetPTX(prog_, ptx_));
bool use_ptx = true;
for (const auto& opt : options) {
if (opt.find("sm_") != std::string::npos) {
use_ptx = false;
break;
}
}

if (use_ptx) {
size_t ptx_size;
NVRTC_CALL(nvrtcGetPTXSize(prog_, &ptx_size));
ptx_.resize(ptx_size);
NVRTC_CALL(nvrtcGetPTX(prog_, ptx_.data()));
} else {
#if CUDA_VERSION >= 11010
size_t cubin_size;
NVRTC_CALL(nvrtcGetCUBINSize(prog_, &cubin_size));
ptx_.resize(cubin_size);
NVRTC_CALL(nvrtcGetCUBIN(prog_, ptx_.data()));
#else
LOG(FATAL) << "Your CUDA version does not support compiling for sm_XX target. "
<< "Use compute_XX target instead or upgrade to CUDA 11.1 or later.";
#endif
}
}


Expand All @@ -67,7 +87,6 @@ CudaModule::Chunk::~Chunk() {
CUDA_DRIVER_CALL(cuModuleUnload(kv.second));
}
NVRTC_CALL(nvrtcDestroyProgram(&prog_));
delete ptx_;
}


Expand All @@ -83,7 +102,7 @@ CUfunction CudaModule::Chunk::GetFunction(
module = iter->second;
} else {
device_store.SetDevice(ctx.dev_id);
CUDA_DRIVER_CALL(cuModuleLoadDataEx(&module, ptx_, 0, 0, 0));
CUDA_DRIVER_CALL(cuModuleLoadDataEx(&module, ptx_.data(), 0, nullptr, nullptr));
mod_[ctx.dev_id] = module;
}
CUfunction function;
Expand Down Expand Up @@ -176,7 +195,7 @@ void CudaModule::Kernel::Launch(
function, grid_dim_x, grid_dim_y, grid_dim_z,
block_dim_x, block_dim_y, block_dim_z,
shared_mem, s->stream_,
p_args.data(), 0));
p_args.data(), nullptr));
CUDA_CALL(cudaStreamSynchronize(s->stream_));
}, ctx, read_vars, write_vars, FnProperty::kNormal, 0,
mangled_name_.c_str());
Expand Down