diff --git a/CMakeLists.txt b/CMakeLists.txt index b901f41c29f2..c5104f139b88 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -586,11 +586,11 @@ if(USE_CUDA) string(REPLACE ";" " " CUDA_ARCH_FLAGS_SPACES "${CUDA_ARCH_FLAGS}") - find_package(CUDAToolkit REQUIRED cublas cufft cusolver curand nvrtc cuda_driver + find_package(CUDAToolkit REQUIRED cublas cufft cusolver curand nvrtc OPTIONAL_COMPONENTS nvToolsExt) list(APPEND mxnet_LINKER_LIBS CUDA::cudart CUDA::cublas CUDA::cufft CUDA::cusolver CUDA::curand - CUDA::nvrtc CUDA::cuda_driver) + CUDA::nvrtc) list(APPEND SOURCE ${CUDA}) add_definitions(-DMXNET_USE_CUDA=1) diff --git a/src/common/cuda/rtc.cc b/src/common/cuda/rtc.cc index 5b27e0bbd225..dda3b7421bed 100644 --- a/src/common/cuda/rtc.cc +++ b/src/common/cuda/rtc.cc @@ -32,6 +32,7 @@ #include #include "rtc.h" +#include "../../initialize.h" #include "rtc/half-inl.h" #include "rtc/util-inl.h" #include "rtc/forward_functions-inl.h" @@ -41,12 +42,30 @@ #include "rtc/reducer-inl.h" #include "utils.h" +typedef CUresult (*cuDeviceGetPtr) (CUdevice* device, int ordinal); +typedef CUresult (*cuDevicePrimaryCtxRetainPtr) (CUcontext* pctx, CUdevice dev); +typedef CUresult (*cuModuleLoadDataExPtr) (CUmodule* module, const void* image, + unsigned int numOptions, CUjit_option* options, void** optionValues); +typedef CUresult (*cuModuleGetFunctionPtr) (CUfunction* hfunc, CUmodule hmod, + const char* name); +typedef CUresult (*cuLaunchKernelPtr) (CUfunction f, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, + void** extra); +typedef CUresult (*cuGetErrorStringPtr) (CUresult error, const char** pStr); namespace mxnet { namespace common { namespace cuda { namespace rtc { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + const char cuda_lib_name[] = "nvcuda.dll"; +#else + const char cuda_lib_name[] = "libcuda.so"; +#endif + std::mutex lock; namespace util { @@ -149,6 +168,8 @@ CUfunction get_function(const std::string ¶meters, std::string ptx; std::vector functions; }; + void* cuda_lib_handle = LibraryInitializer::Get()->lib_load(cuda_lib_name); + // Maps from the kernel name and parameters to the ptx and jit-compiled CUfunctions. using KernelCache = std::unordered_map; // Per-gpu-architecture compiled kernel cache with jit-compiled function for each device context @@ -233,8 +254,12 @@ CUfunction get_function(const std::string ¶meters, // Make sure driver context is set to the proper device CUdevice cu_device; CUcontext context; - CUDA_DRIVER_CALL(cuDeviceGet(&cu_device, dev_id)); - CUDA_DRIVER_CALL(cuDevicePrimaryCtxRetain(&context, cu_device)); + cuDeviceGetPtr device_get_ptr = get_func(cuda_lib_handle, "cuDeviceGet"); + CUDA_DRIVER_CALL((*device_get_ptr)(&cu_device, dev_id)); + cuDevicePrimaryCtxRetainPtr device_primary_ctx_retain_ptr = + get_func(cuda_lib_handle, "cuDevicePrimaryCtxRetain"); + CUDA_DRIVER_CALL((*device_primary_ctx_retain_ptr)(&context, cu_device)); + // Jit-compile ptx for the driver's current context CUmodule module; @@ -250,10 +275,15 @@ CUfunction get_function(const std::string ¶meters, void* jit_opt_values[] = {reinterpret_cast(debug_info), reinterpret_cast(line_info)}; - CUDA_DRIVER_CALL(cuModuleLoadDataEx(&module, kinfo.ptx.c_str(), 2, jit_opts, jit_opt_values)); - CUDA_DRIVER_CALL(cuModuleGetFunction(&kinfo.functions[dev_id], - module, - kinfo.mangled_name.c_str())); + cuModuleLoadDataExPtr module_load_data_ex_ptr = + get_func(cuda_lib_handle, "cuModuleLoadDataEx"); + CUDA_DRIVER_CALL((*module_load_data_ex_ptr)(&module, kinfo.ptx.c_str(), 2, + jit_opts, jit_opt_values)); + cuModuleGetFunctionPtr module_get_function_ptr = + get_func(cuda_lib_handle, "cuModuleGetFunction"); + CUDA_DRIVER_CALL((*module_get_function_ptr)(&kinfo.functions[dev_id], + module, + kinfo.mangled_name.c_str())); } return kinfo.functions[dev_id]; } @@ -266,8 +296,10 @@ void launch(CUfunction function, std::vector *args) { CHECK(args->size() != 0) << "Empty argument list passed to a kernel."; - // CUDA_DRIVER_CALL( - CUresult err = cuLaunchKernel(function, // function to launch + void* cuda_lib_handle = LibraryInitializer::Get()->lib_load(cuda_lib_name); + cuLaunchKernelPtr launch_kernel_ptr = + get_func(cuda_lib_handle, "cuLaunchKernel"); + CUresult err = (*launch_kernel_ptr)(function, // function to launch grid_dim.x, grid_dim.y, grid_dim.z, // grid dim block_dim.x, block_dim.y, block_dim.z, // block dim shared_mem_bytes, // shared memory @@ -276,7 +308,9 @@ void launch(CUfunction function, nullptr); // ); if (err != CUDA_SUCCESS) { const char* error_string; - cuGetErrorString(err, &error_string); + cuGetErrorStringPtr get_error_string_ptr = + get_func(cuda_lib_handle, "cuGetErrorString"); + (*get_error_string_ptr)(err, &error_string); LOG(FATAL) << "cuLaunchKernel failed: " << err << " " << error_string << ": " << reinterpret_cast(function) << " " diff --git a/src/initialize.cc b/src/initialize.cc index 9ef51219609f..6be13e61ae9e 100644 --- a/src/initialize.cc +++ b/src/initialize.cc @@ -173,7 +173,7 @@ void LibraryInitializer::lib_close(void* handle) { * \param func function pointer that gets output address * \param name function name to be fetched */ -void LibraryInitializer::get_sym(void* handle, void** func, char* name) { +void LibraryInitializer::get_sym(void* handle, void** func, const char* name) { #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) *func = GetProcAddress((HMODULE)handle, name); if (!(*func)) { diff --git a/src/initialize.h b/src/initialize.h index 8a6dc3aa5f7f..d792613aefb2 100644 --- a/src/initialize.h +++ b/src/initialize.h @@ -69,7 +69,7 @@ class LibraryInitializer { bool lib_is_loaded(const std::string& path) const; void* lib_load(const char* path); void lib_close(void* handle); - static void get_sym(void* handle, void** func, char* name); + static void get_sym(void* handle, void** func, const char* name); /** * Original pid of the process which first loaded and initialized the library @@ -114,7 +114,7 @@ class LibraryInitializer { * \return func a function pointer */ template -T get_func(void *lib, char *func_name) { +T get_func(void *lib, const char *func_name) { T func; LibraryInitializer::Get()->get_sym(lib, reinterpret_cast(&func), func_name); if (!func)