diff --git a/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in b/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in index edd61ecaf6..044dabba04 100644 --- a/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in +++ b/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in @@ -55,54 +55,60 @@ cdef int cuPythonInit() except -1 nogil: except: handle = None - # Else try default search - if not handle: - LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000 - try: - handle = win32api.LoadLibraryEx("nvrtc64_112_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) - except: - try: - handle = win32api.LoadLibraryEx("nvrtc64_111_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) - except: - try: - handle = win32api.LoadLibraryEx("nvrtc64_110_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) - except: - pass - - # Final check if DLLs can be found within pip installations + # Next check if DLLs can be found within pip installations if not handle: + LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000 + LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100 site_packages = [site.getusersitepackages()] + site.getsitepackages() for sp in site_packages: mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin") if not os.path.isdir(mod_path): continue os.add_dll_directory(mod_path) - LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000 - LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100 - try: - handle = win32api.LoadLibraryEx( - # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... - os.path.join(mod_path, "nvrtc64_112_0.dll"), - 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) - - # Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is - # located in the same mod_path. - # Update PATH environ so that the two dlls can find each other - os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path)) - except: try: handle = win32api.LoadLibraryEx( - os.path.join(mod_path, "nvrtc64_111_0.dll"), + # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... + os.path.join(mod_path, "nvrtc64_112_0.dll"), 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) + + # Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is + # located in the same mod_path. + # Update PATH environ so that the two dlls can find each other os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path)) except: try: handle = win32api.LoadLibraryEx( - os.path.join(mod_path, "nvrtc64_110_0.dll"), + os.path.join(mod_path, "nvrtc64_111_0.dll"), 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path)) except: - pass + try: + handle = win32api.LoadLibraryEx( + os.path.join(mod_path, "nvrtc64_110_0.dll"), + 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) + os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path)) + except: + pass + else: + break + else: + break + else: + break + else: + # Else try default search + # Only reached if DLL wasn't found in any site-package path + LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000 + try: + handle = win32api.LoadLibraryEx("nvrtc64_112_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) + except: + try: + handle = win32api.LoadLibraryEx("nvrtc64_111_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) + except: + try: + handle = win32api.LoadLibraryEx("nvrtc64_110_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) + except: + pass if not handle: raise RuntimeError('Failed to LoadLibraryEx nvrtc64_112_0.dll, or nvrtc64_111_0.dll, or nvrtc64_110_0.dll') diff --git a/cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx b/cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx index b8e6795478..7a809d6dde 100644 --- a/cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx +++ b/cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx @@ -41,7 +41,7 @@ cdef void* __nvvmGetProgramLog = NULL cdef inline list get_site_packages(): - return [site.getusersitepackages()] + site.getsitepackages() + return [site.getusersitepackages()] + site.getsitepackages() + ["conda"] cdef load_library(const int driver_ver): @@ -50,44 +50,42 @@ cdef load_library(const int driver_ver): for suffix in get_nvvm_dso_version_suffix(driver_ver): if len(suffix) == 0: continue - dll_name = "nvvm64_40_0" + dll_name = "nvvm64_40_0.dll" # First check if the DLL has been loaded by 3rd parties try: - handle = win32api.GetModuleHandle(dll_name) + return win32api.GetModuleHandle(dll_name) except: pass - else: - break - # Next, check if DLLs are installed via pip + # Next, check if DLLs are installed via pip or conda for sp in get_site_packages(): - mod_path = os.path.join(sp, "nvidia", "cuda_nvcc", "nvvm", "bin") - if not os.path.isdir(mod_path): - continue - os.add_dll_directory(mod_path) - try: - handle = win32api.LoadLibraryEx( - # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... - os.path.join(mod_path, dll_name), - 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) - except: - pass - else: - break + if sp == "conda": + # nvvm is not under $CONDA_PREFIX/lib, so it's not in the default search path + conda_prefix = os.environ.get("CONDA_PREFIX") + if conda_prefix is None: + continue + mod_path = os.path.join(conda_prefix, "Library", "nvvm", "bin") + else: + mod_path = os.path.join(sp, "nvidia", "cuda_nvcc", "nvvm", "bin") + if os.path.isdir(mod_path): + os.add_dll_directory(mod_path) + try: + return win32api.LoadLibraryEx( + # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... + os.path.join(mod_path, dll_name), + 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) + except: + pass # Finally, try default search + # Only reached if DLL wasn't found in any site-package path try: - handle = win32api.LoadLibrary(dll_name) + return win32api.LoadLibrary(dll_name) except: pass - else: - break - else: - raise RuntimeError('Failed to load nvvm') - assert handle != 0 - return handle + raise RuntimeError('Failed to load nvvm') cdef int _check_or_init_nvvm() except -1 nogil: diff --git a/cuda_bindings/setup.py b/cuda_bindings/setup.py index 4968d54eb2..1067d2922a 100644 --- a/cuda_bindings/setup.py +++ b/cuda_bindings/setup.py @@ -307,7 +307,7 @@ def build_extension(self, ext): # to /site-packages/nvidia/cuda_nvcc/nvvm/lib64/ rel1 = "$ORIGIN/../../../nvidia/cuda_nvcc/nvvm/lib64" # from /lib/python3.*/site-packages/cuda/bindings/_internal/ - # to /lib/nvvm/lib64/ + # to /nvvm/lib64/ rel2 = "$ORIGIN/../../../../../../nvvm/lib64" ldflag = f"-Wl,--disable-new-dtags,-rpath,{rel1},-rpath,{rel2}" else: