diff --git a/include/mxnet/rtc.h b/include/mxnet/rtc.h
index 747c0b5c94ab..56717f4a34c7 100644
--- a/include/mxnet/rtc.h
+++ b/include/mxnet/rtc.h
@@ -60,7 +60,7 @@ class CudaModule {
     /*! \brief nvrtc program handle. */
     nvrtcProgram prog_;
     /*! \brief compiled cuda PTX */
-    char* ptx_;
+    std::vector<char> ptx_;
     /*! \brief lazily loaded cuda module */
     std::unordered_map<int, CUmodule> mod_;
     /*! \brief exported names */
diff --git a/src/common/cuda/rtc.cc b/src/common/cuda/rtc.cc
index 8f3b3391f5e4..2284beec11cd 100644
--- a/src/common/cuda/rtc.cc
+++ b/src/common/cuda/rtc.cc
@@ -28,6 +28,8 @@
 #include
 #include
 #include
+#include <algorithm>
+#include <tuple>
 
 #include "rtc.h"
 #include "rtc/half-inl.h"
@@ -78,15 +80,47 @@ std::string GetCompileLog(nvrtcProgram program) {
 }
 
 // Obtain compilation result (ptx assembly) from the program.
-std::string GetPtx(nvrtcProgram program) {
+std::string GetCompiledCode(nvrtcProgram program, bool use_cubin) {
+#if CUDA_VERSION >= 11010
+  const auto getSize = use_cubin ? nvrtcGetCUBINSize : nvrtcGetPTXSize;
+  const auto getFunc = use_cubin ? nvrtcGetCUBIN : nvrtcGetPTX;
+#else
+  const auto getSize = nvrtcGetPTXSize;
+  const auto getFunc = nvrtcGetPTX;
+#endif
   size_t ptx_size_including_null;
-  NVRTC_CALL(nvrtcGetPTXSize(program, &ptx_size_including_null));
+  NVRTC_CALL(getSize(program, &ptx_size_including_null));
   std::string ptx(ptx_size_including_null - 1, '\0');
   // Room for terminating null character ensured since C++11
-  NVRTC_CALL(nvrtcGetPTX(program, &ptx[0]));
+  NVRTC_CALL(getFunc(program, &ptx[0]));
   return ptx;
 }
 
+std::tuple<bool, std::string> GetArchString(const int sm_arch) {
+#if CUDA_VERSION < 10000
+  constexpr int max_supported_sm_arch = 72;
+#elif CUDA_VERSION < 11000
+  constexpr int max_supported_sm_arch = 75;
+#elif CUDA_VERSION < 11010
+  constexpr int max_supported_sm_arch = 80;
+#else
+  constexpr int max_supported_sm_arch = 86;
+#endif
+
+#if CUDA_VERSION <= 11000
+  // Always use PTX for CUDA <= 11.0
+  const bool known_arch = false;
+#else
+  const bool known_arch = sm_arch <= max_supported_sm_arch;
+#endif
+  const int actual_sm_arch = std::min(sm_arch, max_supported_sm_arch);
+  if (known_arch) {
+    return {known_arch, "sm_" + std::to_string(actual_sm_arch)};
+  } else {
+    return {known_arch, "compute_" + std::to_string(actual_sm_arch)};
+  }
+}
+
 }  // namespace
 
 CUfunction get_function(const std::string &parameters,
@@ -141,14 +175,14 @@ CUfunction get_function(const std::string &parameters,
                                      0,          // num headers
                                      nullptr,    // headers
                                      nullptr));  // include names
-
-  std::string gpu_arch_arg = "--gpu-architecture=compute_" + std::to_string(sm_arch);
+  const auto [use_cubin, gpu_arch] = GetArchString(sm_arch);  // NOLINT(*)
+  std::string gpu_arch_arg = "--gpu-architecture=" + gpu_arch;
   const char *opts[] = {gpu_arch_arg.c_str(),
 #if NDEBUG == 0
                         "-G",
 #endif
                         "--std=c++14"};
-  const std::string kernel_name_demangled = kernel_name;
+  const std::string& kernel_name_demangled = kernel_name;
   NVRTC_CALL(nvrtcAddNameExpression(program, (kernel_name_demangled).c_str()));
 
   nvrtcResult compileResult = nvrtcCompileProgram(program,  // prog
@@ -165,7 +199,7 @@ CUfunction get_function(const std::string &parameters,
                << "The generated code was stored in " << dump_file << "\n"
                << GetCompileLog(program);
 
-  kinfo.ptx = GetPtx(program);
+  kinfo.ptx = GetCompiledCode(program, use_cubin);
   const char *mangled_name;
   NVRTC_CALL(nvrtcGetLoweredName(program,
                                  kernel_name_demangled.c_str(),
diff --git a/src/common/rtc.cc b/src/common/rtc.cc
index 21d3061e5209..a2662ce9f59c 100644
--- a/src/common/rtc.cc
+++ b/src/common/rtc.cc
@@ -44,8 +44,8 @@ CudaModule::Chunk::Chunk(
     << "For lower version of CUDA, please prepend your kernel definitions "
     << "with extern \"C\" instead.";
 #endif
-  std::vector<const char*> c_options;
-  for (const auto& i : options) c_options.push_back(i.c_str());
+  std::vector<const char*> c_options(options.size());
+  for (size_t i = 0; i < options.size(); ++i) c_options[i] = options[i].c_str();
   nvrtcResult compile_res = nvrtcCompileProgram(prog_, c_options.size(), c_options.data());
   if (compile_res != NVRTC_SUCCESS) {
     size_t err_size;
@@ -55,10 +55,30 @@ CudaModule::Chunk::Chunk(
     LOG(FATAL) << err.data();
   }
 
-  size_t ptx_size;
-  NVRTC_CALL(nvrtcGetPTXSize(prog_, &ptx_size));
-  ptx_ = new char[ptx_size];
-  NVRTC_CALL(nvrtcGetPTX(prog_, ptx_));
+  bool use_ptx = true;
+  for (const auto& opt : options) {
+    if (opt.find("sm_") != std::string::npos) {
+      use_ptx = false;
+      break;
+    }
+  }
+
+  if (use_ptx) {
+    size_t ptx_size;
+    NVRTC_CALL(nvrtcGetPTXSize(prog_, &ptx_size));
+    ptx_.resize(ptx_size);
+    NVRTC_CALL(nvrtcGetPTX(prog_, ptx_.data()));
+  } else {
+#if CUDA_VERSION >= 11010
+    size_t cubin_size;
+    NVRTC_CALL(nvrtcGetCUBINSize(prog_, &cubin_size));
+    ptx_.resize(cubin_size);
+    NVRTC_CALL(nvrtcGetCUBIN(prog_, ptx_.data()));
+#else
+    LOG(FATAL) << "Your CUDA version does not support compiling for an sm_XX target. "
+               << "Use a compute_XX target instead or upgrade to CUDA 11.1 or later.";
+#endif
+  }
 }
 
 
@@ -67,7 +87,6 @@ CudaModule::Chunk::~Chunk() {
     CUDA_DRIVER_CALL(cuModuleUnload(kv.second));
   }
   NVRTC_CALL(nvrtcDestroyProgram(&prog_));
-  delete ptx_;
 }
 
 
@@ -83,7 +102,7 @@ CUfunction CudaModule::Chunk::GetFunction(
     module = iter->second;
   } else {
     device_store.SetDevice(ctx.dev_id);
-    CUDA_DRIVER_CALL(cuModuleLoadDataEx(&module, ptx_, 0, 0, 0));
+    CUDA_DRIVER_CALL(cuModuleLoadDataEx(&module, ptx_.data(), 0, nullptr, nullptr));
     mod_[ctx.dev_id] = module;
   }
   CUfunction function;
@@ -176,7 +195,7 @@ void CudaModule::Kernel::Launch(
         function, grid_dim_x, grid_dim_y, grid_dim_z,
         block_dim_x, block_dim_y, block_dim_z,
         shared_mem, s->stream_,
-        p_args.data(), 0));
+        p_args.data(), nullptr));
     CUDA_CALL(cudaStreamSynchronize(s->stream_));
   }, ctx, read_vars, write_vars, FnProperty::kNormal, 0, mangled_name_.c_str());
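
Reviewer note (illustration only, not part of the patch): the change boils down to asking NVRTC for SASS (a CUBIN image) when the target is given as a supported "sm_XX" architecture, and falling back to forward-compatible PTX ("compute_XX") otherwise. The sketch below shows that decision with plain NVRTC calls, independent of MXNet. The function name CompileToBinary, the placeholder source name "example.cu", and the reduced error handling are assumptions made for this example; it also assumes a CUDA 11.1+ toolkit, where nvrtcGetCUBIN/nvrtcGetCUBINSize exist.

#include <cuda.h>
#include <nvrtc.h>

#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

// Abort with the NVRTC error string if a call fails.
#define NVRTC_CHECK(x)                                                        \
  do {                                                                        \
    nvrtcResult nvrtc_check_result = (x);                                     \
    if (nvrtc_check_result != NVRTC_SUCCESS) {                                \
      std::fprintf(stderr, "%s\n", nvrtcGetErrorString(nvrtc_check_result));  \
      std::exit(1);                                                           \
    }                                                                         \
  } while (0)

// Compile `code` for SM architecture `sm_arch`. When the toolkit supports the
// architecture we request SASS (CUBIN); otherwise we request PTX, which the
// driver can still JIT for GPUs newer than the toolkit knows about.
std::vector<char> CompileToBinary(const std::string& code, int sm_arch, bool arch_is_supported) {
  nvrtcProgram prog;
  NVRTC_CHECK(nvrtcCreateProgram(&prog, code.c_str(), "example.cu", 0, nullptr, nullptr));

  // "sm_XX" makes NVRTC emit SASS for exactly that GPU; "compute_XX" emits PTX.
  const std::string arch_opt = std::string("--gpu-architecture=") +
      (arch_is_supported ? "sm_" : "compute_") + std::to_string(sm_arch);
  const char* opts[] = {arch_opt.c_str(), "--std=c++14"};
  NVRTC_CHECK(nvrtcCompileProgram(prog, 2, opts));

  size_t size = 0;
  std::vector<char> binary;
  if (arch_is_supported) {
    NVRTC_CHECK(nvrtcGetCUBINSize(prog, &size));  // requires CUDA >= 11.1
    binary.resize(size);
    NVRTC_CHECK(nvrtcGetCUBIN(prog, binary.data()));
  } else {
    NVRTC_CHECK(nvrtcGetPTXSize(prog, &size));
    binary.resize(size);
    NVRTC_CHECK(nvrtcGetPTX(prog, binary.data()));
  }
  NVRTC_CHECK(nvrtcDestroyProgram(&prog));
  // Either image can be handed to cuModuleLoadDataEx(&module, binary.data(), 0, nullptr, nullptr).
  return binary;
}

Loading SASS through cuModuleLoadDataEx skips the driver's PTX JIT step at first use, which appears to be the motivation for preferring the sm_XX path when the architecture is known; PTX remains the fallback so kernels still run on newer, unknown GPUs.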