From f68eb3117ebdbdc29121d6442c4c406755511b6d Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Fri, 1 May 2026 17:27:12 +0200 Subject: [PATCH 1/2] Try to re-use already loaded libcuda/nvrtc/cudart This way we don't load a different version from the rest of the code by mistake. This should also enable re-using libraries from torch installation when creating custom torch ops --- gpulite/gpulite.hpp | 288 ++++++++++++++++++++++++++++++++------------ 1 file changed, 211 insertions(+), 77 deletions(-) diff --git a/gpulite/gpulite.hpp b/gpulite/gpulite.hpp index 462f3f3..b0a3526 100644 --- a/gpulite/gpulite.hpp +++ b/gpulite/gpulite.hpp @@ -4,6 +4,7 @@ #define GPULITE_HPP #include +#include #include #include @@ -14,34 +15,35 @@ #include #include #include +#include #include +#if defined(__linux__) + #include +#endif + #if defined(__linux__) || defined(__APPLE__) -#include -#include // for getcwd + #include + #include // for getcwd #elif defined(_WIN32) -#include -#include - -#include // for _getcwd -#define getcwd _getcwd - -#include + #include + #include + #include + #include // for _getcwd + #define getcwd _getcwd #else -#error "Platform not supported" + #error "Platform not supported" #endif #if defined(_MSC_VER) - // MSVC historically reports __cplusplus wrong unless /Zc:__cplusplus is enabled, - // so prefer _MSVC_LANG there. - #if !defined(_MSVC_LANG) || _MSVC_LANG < 201703L - #error "This project requires C++17 or newer (/std:c++17)." - #endif -#else - #if __cplusplus < 201703L + // MSVC historically reports __cplusplus wrong unless /Zc:__cplusplus is enabled, + // so prefer _MSVC_LANG there. + #if !defined(_MSVC_LANG) || _MSVC_LANG < 201703L + #error "This project requires C++17 or newer (/std:c++17)." + #endif +#elif __cplusplus < 201703L #error "This project requires C++17 or newer (-std=c++17)." - #endif #endif #if defined(__GNUC__) || defined(__clang__) @@ -459,6 +461,88 @@ inline std::optional FindBestCudaDll(const std::wstring& #endif +inline std::string basename(const std::string& path) { + auto fs_path = std::filesystem::path(path); + return fs_path.filename().string(); +} + +#if defined(_WIN32) +/// Try to find an already loaded library whose name starts with the given +/// prefix, and return its handle if found. +inline HMODULE findLoadedLibrary(const std::wstring& prefix) { + auto process = GetCurrentProcess(); + DWORD needed = 0; + auto status = EnumProcessModulesEx(process, nullptr, 0, &needed, LIST_MODULES_ALL); + if (!status || needed == 0) { + return nullptr; + } + + auto modules = std::vector(needed / sizeof(HMODULE)); + status = EnumProcessModulesEx( + process, + modules.data(), + static_cast(modules.size() * sizeof(HMODULE)), + &needed, + LIST_MODULES_ALL + ); + if (!status) { + return nullptr; + } + + std::wstring name(MAX_PATH, L'\0'); + for (auto hmodule : modules) { + auto n = GetModuleBaseNameW(process, hmodule, name.data(), MAX_PATH); + if (n == 0) continue; + name[n] = '\0'; + if (name.find(prefix) == 0) { + return hmodule; + } + } + + return nullptr; +} + +#elif defined(__linux__) + +inline void* findLoadedLibrary(const char* prefix) { + struct SearchData { + const char* prefix; + void* handle; + }; + + auto data = SearchData{prefix, nullptr}; + + auto callback = [](dl_phdr_info* info, std::size_t, void* user_data) -> int { + auto* data = static_cast(user_data); + if (info == nullptr || info->dlpi_name == nullptr || info->dlpi_name[0] == '\0') { + return 0; + } + + auto base = basename(info->dlpi_name); + if (base.find(data->prefix) != 0) { + return 0; + } + + void* h = dlopen(info->dlpi_name, RTLD_NOW | RTLD_NOLOAD); + if (h != nullptr) { + data->handle = h; + return 1; // stop iteration + } + + return 0; + }; + + dl_iterate_phdr(callback, &data); + return data.handle; +} + +#elif defined(__APPLE__) +inline void* findLoadedLibrary(const char* prefix) { + return nullptr; +} +#endif + + // Helper function to demangle the type name if necessary inline std::string demangleTypeName(const std::string& name) { #if defined(__GNUC__) || defined(__clang__) @@ -566,35 +650,55 @@ class CUDART { CUDART() { #if defined(__linux__) || defined(__APPLE__) - static const char* CANDIDATES[] = { - "libcudart.so", - "libcudart.so.11", - "libcudart.so.12", - "libcudart.so.13", - "libcudart.so.14", - "libcudart.so.15", - }; - for (auto* candidate: CANDIDATES) { - cudartHandle = dlopen(candidate, RTLD_NOW); - if (cudartHandle) { - break; + ownedHandle = true; + // First try to find an already loaded libcudart to avoid loading + // multiple versions in the same process + cudartHandle = details::findLoadedLibrary("libcudart"); + + // otherwise, try multiple candidate names to maximize compatibility + // with different CUDA versions + if (!cudartHandle) { + static const char* CANDIDATES[] = { + "libcudart.so", + "libcudart.so.11", + "libcudart.so.12", + "libcudart.so.13", + "libcudart.so.14", + "libcudart.so.15", + }; + for (auto* candidate: CANDIDATES) { + cudartHandle = dlopen(candidate, RTLD_NOW); + if (cudartHandle) { + break; + } } } #elif defined(_WIN32) - auto dllPathOpt = details::FindBestCudaDll(L"cudart64"); - if (dllPathOpt) { - auto dllPath = *dllPathOpt; - auto dir = dllPath.parent_path(); - // add the directory containing the DLL to the search path - SetDllDirectoryW(dir.c_str()); - - cudartHandle = LoadLibraryExW( - dllPath.c_str(), - nullptr, - LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR | - LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | - LOAD_LIBRARY_SEARCH_USER_DIRS - ); + // First try to find an already loaded cudart.dll to avoid loading + // multiple versions in the same process + cudartHandle = details::findLoadedLibrary(L"cudart64"); + + // then look into known path and pick the most recent version if + // multiple are found (e.g. cudart64_90.dll, cudart64_120.dll, etc.) + if (cudartHandle != nullptr) { + ownedHandle = false; + } else { + auto dllPathOpt = details::FindBestCudaDll(L"cudart64"); + if (dllPathOpt) { + auto dllPath = *dllPathOpt; + auto dir = dllPath.parent_path(); + // add the directory containing the DLL to the search path + SetDllDirectoryW(dir.c_str()); + + cudartHandle = LoadLibraryExW( + dllPath.c_str(), + nullptr, + LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR | + LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | + LOAD_LIBRARY_SEARCH_USER_DIRS + ); + ownedHandle = true; + } } #else #error "Platform not supported" @@ -629,7 +733,7 @@ class CUDART { dlclose(cudartHandle); } #elif defined(_WIN32) - if (cudartHandle) { + if (cudartHandle && ownedHandle) { FreeLibrary(static_cast(cudartHandle)); } #else @@ -642,6 +746,7 @@ class CUDART { CUDART& operator=(const CUDART&) = delete; void* cudartHandle = nullptr; + bool ownedHandle = false; }; /* @@ -721,9 +826,22 @@ class CUDADriver { CUDADriver() { #if defined(__linux__) || defined(__APPLE__) - cudaHandle = dlopen("libcuda.so", RTLD_NOW); + ownedHandle = true; + // check if libcuda is already loaded to avoid loading multiple versions + // in the same process + cudaHandle = details::findLoadedLibrary("libcuda"); + + if (cudaHandle == nullptr) { + cudaHandle = dlopen("libcuda.so", RTLD_NOW); + } #elif defined(_WIN32) - cudaHandle = LoadLibraryA("nvcuda.dll"); + cudaHandle = details::findLoadedLibrary(L"nvcuda.dll"); + if (cudaHandle) { + ownedHandle = false; + } else { + cudaHandle = LoadLibraryA("nvcuda.dll"); + ownedHandle = true; + } #else #error "Platform not supported" #endif @@ -761,7 +879,7 @@ class CUDADriver { dlclose(cudaHandle); } #elif defined(_WIN32) - if (cudaHandle) { + if (cudaHandle && ownedHandle) { FreeLibrary(static_cast(cudaHandle)); } #else @@ -774,6 +892,7 @@ class CUDADriver { CUDADriver& operator=(const CUDADriver&) = delete; void* cudaHandle = nullptr; + bool ownedHandle = false; }; /* @@ -790,8 +909,7 @@ class NVRTC { static bool loaded() { return instance().nvrtcHandle != nullptr; } - using nvrtcCreateProgram_t = - nvrtcResult (*)(nvrtcProgram*, const char*, const char*, int, const char*[], const char*[]); + using nvrtcCreateProgram_t = nvrtcResult (*)(nvrtcProgram*, const char*, const char*, int, const char*[], const char*[]); using nvrtcCompileProgram_t = nvrtcResult (*)(nvrtcProgram, int, const char*[]); using nvrtcGetPTX_t = nvrtcResult (*)(nvrtcProgram, char*); using nvrtcGetPTXSize_t = nvrtcResult (*)(nvrtcProgram, size_t*); @@ -819,36 +937,51 @@ class NVRTC { NVRTC() { #if defined(__linux__) || defined(__APPLE__) - static const char* CANDIDATES[] = { - "libnvrtc.so", - "libnvrtc.so.11", - "libnvrtc.so.12", - "libnvrtc.so.13", - "libnvrtc.so.14", - "libnvrtc.so.15", - }; - for (auto* candidate: CANDIDATES) { - nvrtcHandle = dlopen(candidate, RTLD_NOW); - if (nvrtcHandle != nullptr) { - break; + ownedHandle = true; + // check if libnvrtc is already loaded to avoid loading multiple versions + // in the same process + nvrtcHandle = details::findLoadedLibrary("libnvrtc"); + if (!nvrtcHandle) { + static const char* CANDIDATES[] = { + "libnvrtc.so", + "libnvrtc.so.11", + "libnvrtc.so.12", + "libnvrtc.so.13", + "libnvrtc.so.14", + "libnvrtc.so.15", + }; + for (auto* candidate: CANDIDATES) { + nvrtcHandle = dlopen(candidate, RTLD_NOW); + if (nvrtcHandle != nullptr) { + break; + } } } #elif defined(_WIN32) - auto dllPathOpt = details::FindBestCudaDll(L"nvrtc64"); - if (dllPathOpt) { - auto dllPath = *dllPathOpt; - // add the directory containing the DLL to the search path - auto dir = dllPath.parent_path(); - SetDllDirectoryW(dir.c_str()); - - nvrtcHandle = LoadLibraryExW( - dllPath.c_str(), - nullptr, - LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR | - LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | - LOAD_LIBRARY_SEARCH_USER_DIRS - ); + // check if nvrtc.dll is already loaded to avoid loading multiple versions + nvrtcHandle = details::findLoadedLibrary(L"nvrtc64"); + if (nvrtcHandle) { + ownedHandle = false; + } else { + // otherwise, look into known path and pick the most recent version + // if multiple are found + auto dllPathOpt = details::FindBestCudaDll(L"nvrtc64"); + if (dllPathOpt) { + auto dllPath = *dllPathOpt; + // add the directory containing the DLL to the search path + auto dir = dllPath.parent_path(); + SetDllDirectoryW(dir.c_str()); + + nvrtcHandle = LoadLibraryExW( + dllPath.c_str(), + nullptr, + LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR | + LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | + LOAD_LIBRARY_SEARCH_USER_DIRS + ); + ownedHandle = true; + } } #else #error "Platform not supported" @@ -877,7 +1010,7 @@ class NVRTC { dlclose(nvrtcHandle); } #elif defined(_WIN32) - if (nvrtcHandle) { + if (nvrtcHandle && ownedHandle) { FreeLibrary(static_cast(nvrtcHandle)); } #else @@ -890,6 +1023,7 @@ class NVRTC { NVRTC& operator=(const NVRTC&) = delete; void* nvrtcHandle = nullptr; + bool ownedHandle = false; }; namespace details { From 90d85e3b80a5bd42fea5282c08fda02fae720f1b Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Mon, 4 May 2026 10:51:22 +0200 Subject: [PATCH 2/2] Add more context when we fail to load a symbol --- gpulite/gpulite.hpp | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/gpulite/gpulite.hpp b/gpulite/gpulite.hpp index b0a3526..35b4d8c 100644 --- a/gpulite/gpulite.hpp +++ b/gpulite/gpulite.hpp @@ -276,13 +276,40 @@ namespace details { // Define a template to dynamically load symbols template FuncType loadSymbol(void* handle, const char* functionName) { #if defined(__linux__) || defined(__APPLE__) + dlerror(); // Clear any existing error auto func = reinterpret_cast(dlsym(handle, functionName)); + + if (!func) { + auto* error = dlerror(); + throw std::runtime_error( + std::string("Failed to load function: ") + functionName + ": " + + (error ? error : "unknown error") + ); + } #elif defined(_WIN32) auto func = reinterpret_cast(GetProcAddress(static_cast(handle), functionName)); -#endif + if (!func) { - throw std::runtime_error(std::string("Failed to load function: ") + functionName); + auto errorCode = GetLastError(); + LPSTR errorMsg = nullptr; + FormatMessageA( + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + nullptr, + errorCode, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + reinterpret_cast(&errorMsg), + 0, + nullptr + ); + auto errorStr = ( + "Failed to load function: " + std::string(functionName) + ": " + + (errorMsg ? errorMsg : "unknown error") + ); + LocalFree(errorMsg); + throw std::runtime_error(std::move(errorStr)); } +#endif + return func; }