diff --git a/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in b/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in index a0f8a27a0f..caf36d40e8 100644 --- a/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in +++ b/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in @@ -60,36 +60,37 @@ cdef int cuPythonInit() except -1 nogil: except: handle = None - # Else try default search - if not handle: - LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000 - try: - handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) - except: - pass - - # Final check if DLLs can be found within pip installations + # Check if DLLs can be found within pip installations if not handle: + LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000 + LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100 site_packages = [site.getusersitepackages()] + site.getsitepackages() for sp in site_packages: mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin") - if not os.path.isdir(mod_path): - continue - os.add_dll_directory(mod_path) - LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000 - LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100 - try: - handle = win32api.LoadLibraryEx( - # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... - os.path.join(mod_path, "nvrtc64_120_0.dll"), - 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) - - # Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is - # located in the same mod_path. - # Update PATH environ so that the two dlls can find each other - os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path)) - except: - pass + if os.path.isdir(mod_path): + os.add_dll_directory(mod_path) + try: + handle = win32api.LoadLibraryEx( + # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... + os.path.join(mod_path, "nvrtc64_120_0.dll"), + 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) + + # Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is + # located in the same mod_path. + # Update PATH environ so that the two dlls can find each other + os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path)) + except: + pass + else: + break + else: + # Else try default search + # Only reached if DLL wasn't found in any site-package path + LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000 + try: + handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) + except: + pass if not handle: raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll') diff --git a/cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx b/cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx index c8c7e6b298..9798204424 100644 --- a/cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx +++ b/cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx @@ -56,40 +56,30 @@ cdef load_library(const int driver_ver): # First check if the DLL has been loaded by 3rd parties try: - handle = win32api.GetModuleHandle(dll_name) + return win32api.GetModuleHandle(dll_name) except: pass - else: - break # Next, check if DLLs are installed via pip for sp in get_site_packages(): mod_path = os.path.join(sp, "nvidia", "nvJitLink", "bin") - if not os.path.isdir(mod_path): - continue - os.add_dll_directory(mod_path) - try: - handle = win32api.LoadLibraryEx( - # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... - os.path.join(mod_path, dll_name), - 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) - except: - pass - else: - break - + if os.path.isdir(mod_path): + os.add_dll_directory(mod_path) + try: + return win32api.LoadLibraryEx( + # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... + os.path.join(mod_path, dll_name), + 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) + except: + pass # Finally, try default search + # Only reached if DLL wasn't found in any site-package path try: - handle = win32api.LoadLibrary(dll_name) + return win32api.LoadLibrary(dll_name) except: pass - else: - break - else: - raise RuntimeError('Failed to load nvJitLink') - assert handle != 0 - return handle + raise RuntimeError('Failed to load nvJitLink') cdef int _check_or_init_nvjitlink() except -1 nogil: diff --git a/cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx b/cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx index 76ce232542..9f507e8e1b 100644 --- a/cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx +++ b/cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx @@ -41,7 +41,7 @@ cdef void* __nvvmGetProgramLog = NULL cdef inline list get_site_packages(): - return [site.getusersitepackages()] + site.getsitepackages() + return [site.getusersitepackages()] + site.getsitepackages() + ["conda"] cdef load_library(const int driver_ver): @@ -50,44 +50,42 @@ cdef load_library(const int driver_ver): for suffix in get_nvvm_dso_version_suffix(driver_ver): if len(suffix) == 0: continue - dll_name = "nvvm64_40_0" + dll_name = "nvvm64_40_0.dll" # First check if the DLL has been loaded by 3rd parties try: - handle = win32api.GetModuleHandle(dll_name) + return win32api.GetModuleHandle(dll_name) except: pass - else: - break - # Next, check if DLLs are installed via pip + # Next, check if DLLs are installed via pip or conda for sp in get_site_packages(): - mod_path = os.path.join(sp, "nvidia", "cuda_nvcc", "nvvm", "bin") - if not os.path.isdir(mod_path): - continue - os.add_dll_directory(mod_path) - try: - handle = win32api.LoadLibraryEx( - # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... - os.path.join(mod_path, dll_name), - 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) - except: - pass - else: - break + if sp == "conda": + # nvvm is not under $CONDA_PREFIX/lib, so it's not in the default search path + conda_prefix = os.environ.get("CONDA_PREFIX") + if conda_prefix is None: + continue + mod_path = os.path.join(conda_prefix, "Library", "nvvm", "bin") + else: + mod_path = os.path.join(sp, "nvidia", "cuda_nvcc", "nvvm", "bin") + if os.path.isdir(mod_path): + os.add_dll_directory(mod_path) + try: + return win32api.LoadLibraryEx( + # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... + os.path.join(mod_path, dll_name), + 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) + except: + pass # Finally, try default search + # Only reached if DLL wasn't found in any site-package path try: - handle = win32api.LoadLibrary(dll_name) + return win32api.LoadLibrary(dll_name) except: pass - else: - break - else: - raise RuntimeError('Failed to load nvvm') - assert handle != 0 - return handle + raise RuntimeError('Failed to load nvvm') cdef int _check_or_init_nvvm() except -1 nogil: diff --git a/cuda_bindings/setup.py b/cuda_bindings/setup.py index d9fa93d5ce..ed84ef4f1b 100644 --- a/cuda_bindings/setup.py +++ b/cuda_bindings/setup.py @@ -390,7 +390,7 @@ def build_extension(self, ext): # to /site-packages/nvidia/cuda_nvcc/nvvm/lib64/ rel1 = "$ORIGIN/../../../nvidia/cuda_nvcc/nvvm/lib64" # from /lib/python3.*/site-packages/cuda/bindings/_internal/ - # to /lib/nvvm/lib64/ + # to /nvvm/lib64/ rel2 = "$ORIGIN/../../../../../../nvvm/lib64" ldflag = f"-Wl,--disable-new-dtags,-rpath,{rel1},-rpath,{rel2}" else: