Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 104 additions & 38 deletions cuda_pathfinder/cuda/pathfinder/_dynamic_libs/load_dl_linux.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,57 +5,126 @@
import ctypes
import ctypes.util
import os
from typing import Optional
from typing import Optional, cast

from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import SUPPORTED_LINUX_SONAMES

CDLL_MODE = os.RTLD_NOW | os.RTLD_GLOBAL

LIBDL_PATH = ctypes.util.find_library("dl") or "libdl.so.2"
LIBDL = ctypes.CDLL(LIBDL_PATH)
LIBDL.dladdr.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
LIBDL.dladdr.restype = ctypes.c_int

def _load_libdl() -> ctypes.CDLL:
    """Load libdl and return the resulting ctypes handle.

    The library name is looked up with ``ctypes.util.find_library("dl")``.
    In minimal or stripped-down environments (no ldconfig/gcc, incomplete
    linker cache) that lookup can come back empty even though libdl is
    installed, so the stable SONAME ``libdl.so.2`` is used as the fallback.

    Returns:
        The loaded libdl as a ``ctypes.CDLL``.

    Raises:
        RuntimeError: If the library cannot be loaded; the original
            ``OSError`` is attached as the cause.
    """
    located = ctypes.util.find_library("dl")
    name = located if located else "libdl.so.2"
    try:
        libdl = ctypes.CDLL(name)
    except OSError as exc:
        raise RuntimeError(f"Could not load {name!r} (required for dlinfo/dlerror on Linux)") from exc
    return libdl


# Module-wide libdl handle; loaded once, when this module is imported.
LIBDL = _load_libdl()

# Prototype for dlinfo: (handle, request code, output location) -> int status.
LIBDL.dlinfo.restype = ctypes.c_int
LIBDL.dlinfo.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p]

# Prototype for dlerror: no arguments, returns a C string.
# The error string is thread-local and is cleared after it has been read.
LIBDL.dlerror.restype = ctypes.c_char_p
LIBDL.dlerror.argtypes = []

# dlinfo request codes used by this module.
# First appeared in 2004-era glibc. Universally correct on Linux for all practical purposes.
RTLD_DI_LINKMAP = 2
RTLD_DI_ORIGIN = 6

class DlInfo(ctypes.Structure):
"""Structure used by dladdr to return information about a loaded symbol."""

class _LinkMapLNameView(ctypes.Structure):
"""
Prefix-only view of glibc's `struct link_map` used **solely** to read `l_name`.

Background:
- `dlinfo(handle, RTLD_DI_LINKMAP, ...)` returns a `struct link_map*`.
- The first few members of `struct link_map` (including `l_name`) have been
stable on glibc for decades and are documented as debugger-visible.
- We only need the offset/layout of `l_name`, not the full struct.

Safety constraints:
- This is a **partial** definition (prefix). It must only be used via a pointer
returned by `dlinfo(...)`.
- Do **not** instantiate it or pass it **by value** to any C function.
- Do **not** access any members beyond those declared here.
- Do **not** rely on `ctypes.sizeof(_LinkMapLNameView)` for allocation.

Rationale:
- Defining only the leading fields avoids depending on internal/unstable
tail members while keeping code more readable than raw pointer arithmetic.
"""

_fields_ = (
("dli_fname", ctypes.c_char_p), # path to .so
("dli_fbase", ctypes.c_void_p),
("dli_sname", ctypes.c_char_p),
("dli_saddr", ctypes.c_void_p),
("l_addr", ctypes.c_void_p), # ElfW(Addr)
("l_name", ctypes.c_char_p), # char*
)


def abs_path_for_dynamic_library(libname: str, handle: ctypes.CDLL) -> Optional[str]:
"""Get the absolute path of a loaded dynamic library on Linux.
# Defensive assertions, mainly to document the invariants we depend on
assert _LinkMapLNameView.l_addr.offset == 0
assert _LinkMapLNameView.l_name.offset == ctypes.sizeof(ctypes.c_void_p)

Args:
libname: The name of the library
handle: The library handle

Returns:
The absolute path to the library file, or None if no expected symbol is found
def _dl_last_error() -> Optional[str]:
    """Return the pending ``dlerror()`` message, or ``None`` if there is none.

    Returns:
        The message decoded as UTF-8, or ``None`` when ``dlerror()`` yields
        nothing (no pending error).
    """
    raw = cast(Optional[bytes], LIBDL.dlerror())
    # Decoding never raises: undecodable bytes are mapped to U+DC80..U+DCFF.
    return raw.decode("utf-8", "surrogateescape") if raw else None

Raises:
OSError: If dladdr fails to get information about the symbol
"""
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import EXPECTED_LIB_SYMBOLS

for symbol_name in EXPECTED_LIB_SYMBOLS[libname]:
symbol = getattr(handle, symbol_name, None)
if symbol is not None:
break
else:
return None
def l_name_for_dynamic_library(libname: str, handle: ctypes.CDLL) -> str:
lm_view = ctypes.POINTER(_LinkMapLNameView)()
rc = LIBDL.dlinfo(ctypes.c_void_p(handle._handle), RTLD_DI_LINKMAP, ctypes.byref(lm_view))
if rc != 0:
err = _dl_last_error()
raise OSError(f"dlinfo failed for {libname=!r} (rc={rc})" + (f": {err}" if err else ""))
if not lm_view: # NULL link_map**
raise OSError(f"dlinfo returned NULL link_map pointer for {libname=!r}")

addr = ctypes.cast(symbol, ctypes.c_void_p)
info = DlInfo()
if LIBDL.dladdr(addr, ctypes.byref(info)) == 0:
raise OSError(f"dladdr failed for {libname=!r}")
return info.dli_fname.decode() # type: ignore[no-any-return]
l_name_bytes = lm_view.contents.l_name
if not l_name_bytes:
raise OSError(f"dlinfo returned empty link_map->l_name for {libname=!r}")

path = os.fsdecode(l_name_bytes)
if not path:
raise OSError(f"dlinfo returned empty l_name string for {libname=!r}")

return path


def l_origin_for_dynamic_library(libname: str, handle: ctypes.CDLL) -> str:
    """Query ``dlinfo(RTLD_DI_ORIGIN)`` for a loaded library and return the result.

    Args:
        libname: Library name; used only in error messages.
        handle: The loaded library whose ``_handle`` is passed to ``dlinfo``.

    Returns:
        The string written by ``dlinfo``, decoded with ``os.fsdecode``.

    Raises:
        OSError: If ``dlinfo`` returns a non-zero status, or if the string it
            wrote is empty.
    """
    # NOTE(review): no buffer length is passed to dlinfo here; 4096 presumably
    # mirrors Linux PATH_MAX -- confirm the origin string can never exceed it.
    origin_buf = ctypes.create_string_buffer(4096)
    status = LIBDL.dlinfo(ctypes.c_void_p(handle._handle), RTLD_DI_ORIGIN, origin_buf)
    if status != 0:
        detail = _dl_last_error()
        suffix = f": {detail}" if detail else ""
        raise OSError(f"dlinfo failed for {libname=!r} (rc={status})" + suffix)

    origin = os.fsdecode(origin_buf.value)
    if origin:
        return origin
    raise OSError(f"dlinfo returned empty l_origin string for {libname=!r}")


def abs_path_for_dynamic_library(libname: str, handle: ctypes.CDLL) -> str:
    """Build the path of a loaded library from its origin and file name.

    The result joins the string from ``l_origin_for_dynamic_library`` with the
    basename of the string from ``l_name_for_dynamic_library``.

    Args:
        libname: Library name, forwarded to both helpers.
        handle: The loaded library handle, forwarded to both helpers.

    Returns:
        ``os.path.join(origin, basename_of_l_name)``.
    """
    # l_name is queried before l_origin, so any failure surfaces in that order.
    filename = os.path.basename(l_name_for_dynamic_library(libname, handle))
    directory = l_origin_for_dynamic_library(libname, handle)
    return os.path.join(directory, filename)


def get_candidate_sonames(libname: str) -> list[str]:
    """Return the SONAMEs to try for ``libname``, in order.

    The entries registered for ``libname`` in ``SUPPORTED_LINUX_SONAMES`` (if
    any) come first, followed by the generic ``lib<libname>.so``.

    Args:
        libname: Library name without the ``lib`` prefix or ``.so`` suffix.

    Returns:
        A new list of candidate SONAME strings.
    """
    registered = SUPPORTED_LINUX_SONAMES.get(libname, ())
    return [*registered, f"lib{libname}.so"]


def check_if_already_loaded_from_elsewhere(libname: str) -> Optional[LoadedDL]:
Expand All @@ -72,9 +141,8 @@ def check_if_already_loaded_from_elsewhere(libname: str) -> Optional[LoadedDL]:
>>> if loaded is not None:
... print(f"Library already loaded from {loaded.abs_path}")
"""
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import SUPPORTED_LINUX_SONAMES

for soname in SUPPORTED_LINUX_SONAMES.get(libname, ()):
for soname in get_candidate_sonames(libname):
try:
handle = ctypes.CDLL(soname, mode=os.RTLD_NOLOAD)
except OSError:
Expand All @@ -96,9 +164,7 @@ def load_with_system_search(libname: str) -> Optional[LoadedDL]:
Raises:
RuntimeError: If the library is loaded but no expected symbol is found
"""
candidate_sonames = list(SUPPORTED_LINUX_SONAMES.get(libname, ()))
candidate_sonames.append(f"lib{libname}.so")
for soname in candidate_sonames:
for soname in get_candidate_sonames(libname):
try:
handle = ctypes.CDLL(soname, CDLL_MODE)
abs_path = abs_path_for_dynamic_library(libname, handle)
Expand Down
11 changes: 4 additions & 7 deletions cuda_pathfinder/cuda/pathfinder/_dynamic_libs/load_dl_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
from typing import Optional

from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import (
LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY,
SUPPORTED_WINDOWS_DLLS,
)

# Mirrors WinBase.h (unfortunately not defined already elsewhere)
WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
Expand Down Expand Up @@ -110,7 +114,6 @@ def check_if_already_loaded_from_elsewhere(libname: str) -> Optional[LoadedDL]:
>>> if loaded is not None:
... print(f"Library already loaded from {loaded.abs_path}")
"""
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import SUPPORTED_WINDOWS_DLLS

for dll_name in SUPPORTED_WINDOWS_DLLS.get(libname, ()):
handle = kernel32.GetModuleHandleW(dll_name)
Expand All @@ -129,8 +132,6 @@ def load_with_system_search(libname: str) -> Optional[LoadedDL]:
Returns:
A LoadedDL object if successful, None if the library cannot be loaded
"""
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import SUPPORTED_WINDOWS_DLLS

for dll_name in SUPPORTED_WINDOWS_DLLS.get(libname, ()):
handle = kernel32.LoadLibraryExW(dll_name, None, 0)
if handle:
Expand All @@ -153,10 +154,6 @@ def load_with_abs_path(libname: str, found_path: str) -> LoadedDL:
Raises:
RuntimeError: If the DLL cannot be loaded
"""
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import (
LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY,
)

if libname in LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY:
add_dll_directory(found_path)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# SUPPORTED_LIBNAMES
# SUPPORTED_WINDOWS_DLLS
# SUPPORTED_LINUX_SONAMES
# EXPECTED_LIB_SYMBOLS

import sys

Expand Down Expand Up @@ -401,39 +400,3 @@ def is_suppressed_dll_file(path_basename: str) -> bool:
# nvrtc64_120_0.dll
return path_basename.endswith(".alt.dll") or "-builtins" in path_basename
return path_basename.startswith(("cudart32_", "nvvm32"))


# Based on `nm -D --defined-only` output for Linux x86_64 distributions.
EXPECTED_LIB_SYMBOLS = {
"nvJitLink": (
"__nvJitLinkCreate_12_0", # 12.0 through 12.9
"nvJitLinkVersion", # 12.3 and up
),
"nvrtc": ("nvrtcVersion",),
"nvvm": ("nvvmVersion",),
"cudart": ("cudaRuntimeGetVersion",),
"nvfatbin": ("nvFatbinVersion",),
"cublas": ("cublasGetVersion",),
"cublasLt": ("cublasLtGetVersion",),
"cufft": ("cufftGetVersion",),
"cufftw": ("fftwf_malloc",),
"curand": ("curandGetVersion",),
"cusolver": ("cusolverGetVersion",),
"cusolverMg": ("cusolverMgCreate",),
"cusparse": ("cusparseGetVersion",),
"nppc": ("nppGetLibVersion",),
"nppial": ("nppiAdd_32f_C1R_Ctx",),
"nppicc": ("nppiColorToGray_8u_C3C1R_Ctx",),
"nppidei": ("nppiCopy_8u_C1R_Ctx",),
"nppif": ("nppiFilterSobelHorizBorder_8u_C1R_Ctx",),
"nppig": ("nppiResize_8u_C1R_Ctx",),
"nppim": ("nppiErode_8u_C1R_Ctx",),
"nppist": ("nppiMean_8u_C1R_Ctx",),
"nppisu": ("nppiFree",),
"nppitc": ("nppiThreshold_8u_C1R_Ctx",),
"npps": ("nppsAdd_32f_Ctx",),
"nvblas": ("dgemm",),
"cufile": ("cuFileGetVersion",),
# "cufile_rdma": ("rdma_buffer_reg",),
"nvjpeg": ("nvjpegCreate",),
}
2 changes: 1 addition & 1 deletion cuda_pathfinder/cuda/pathfinder/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

__version__ = "1.1.1a0"
__version__ = "1.1.1a1"
14 changes: 8 additions & 6 deletions cuda_pathfinder/tests/test_load_nvidia_dynamic_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,6 @@ def test_supported_libnames_windows_libnames_requiring_os_add_dll_directory_cons
)


def test_supported_libnames_all_expected_lib_symbols_consistency():
assert tuple(sorted(supported_nvidia_libs.SUPPORTED_LIBNAMES_ALL)) == tuple(
sorted(supported_nvidia_libs.EXPECTED_LIB_SYMBOLS.keys())
)


def test_runtime_error_on_non_64bit_python():
with (
patch("struct.calcsize", return_value=3), # fake 24-bit pointer
Expand All @@ -68,6 +62,12 @@ def build_child_process_failed_for_libname_message(libname, result):
)


def validate_abs_path(abs_path):
    """Assert that ``abs_path`` is a non-empty absolute path to an existing file.

    Raises:
        AssertionError: If the path is empty, is not absolute, or is not a file.
    """
    shown = f"{abs_path=!r}"
    assert abs_path, "empty path: " + shown
    assert os.path.isabs(abs_path), "not absolute: " + shown
    assert os.path.isfile(abs_path), "not a file: " + shown


def child_process_func(libname):
import os

Expand All @@ -76,6 +76,7 @@ def child_process_func(libname):
loaded_dl_fresh = load_nvidia_dynamic_lib(libname)
if loaded_dl_fresh.was_already_loaded_from_elsewhere:
raise RuntimeError("loaded_dl_fresh.was_already_loaded_from_elsewhere")
validate_abs_path(loaded_dl_fresh.abs_path)

loaded_dl_from_cache = load_nvidia_dynamic_lib(libname)
if loaded_dl_from_cache is not loaded_dl_fresh:
Expand All @@ -86,6 +87,7 @@ def child_process_func(libname):
raise RuntimeError("loaded_dl_no_cache.was_already_loaded_from_elsewhere")
if not os.path.samefile(loaded_dl_no_cache.abs_path, loaded_dl_fresh.abs_path):
raise RuntimeError(f"not os.path.samefile({loaded_dl_no_cache.abs_path=!r}, {loaded_dl_fresh.abs_path=!r})")
validate_abs_path(loaded_dl_no_cache.abs_path)

sys.stdout.write(f"{loaded_dl_fresh.abs_path!r}\n")

Expand Down
Loading