Skip to content
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
9ad6bcd
_decide_nvjitlink_or_driver(): catch RuntimeError (bug fix), use impo…
rwgk Sep 30, 2025
9390c85
Fix misunderstanding: RuntimeError is raised only from inner_nvjitlin…
rwgk Oct 1, 2025
2ccdc21
Merge branch 'main' into linker_decide_nvjitlink_or_driver_fix
rwgk Oct 1, 2025
6bf421e
Better way of formatting warning messages.
rwgk Oct 1, 2025
9220d58
Change from importlib.import_module() to plain import (the latter doe…
rwgk Oct 1, 2025
c9a95c0
Merge branch 'main' into linker_decide_nvjitlink_or_driver_fix
rwgk Oct 2, 2025
7e9fce4
Enhance to warning messages, to make them actionable.
rwgk Oct 2, 2025
4724769
Factor out _nvjitlink_has_version_symbol() for clarity and testability
rwgk Oct 2, 2025
6630d8e
Add test_linker_warnings.py
rwgk Oct 2, 2025
f1c29f6
Fix "the the" oversight
rwgk Oct 2, 2025
0948942
Replace "culink APIs" → "driver APIs" in warning message.
rwgk Oct 2, 2025
00e6421
Merge branch 'main' into linker_decide_nvjitlink_or_driver_fix
rwgk Oct 3, 2025
b237669
Fix oversight: test_linker_warnings.py needs to be updated after comm…
rwgk Oct 3, 2025
6787ddf
Merge branch 'main' into linker_decide_nvjitlink_or_driver_fix
rwgk Oct 3, 2025
4dc08e7
fix skipping the check for nvidia-smi (#1084)
leofang Oct 4, 2025
9721041
Merge branch 'main' into linker_decide_nvjitlink_or_driver_fix
leofang Oct 6, 2025
c1a767f
Merge branch 'main' into linker_decide_nvjitlink_or_driver_fix
rwgk Oct 7, 2025
5ba11be
rm cuda_core/tests/test_linker_warnings.py: see https://github.com/NV…
rwgk Oct 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 33 additions & 19 deletions cuda_core/cuda/core/experimental/_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import ctypes
import sys
import weakref
from contextlib import contextmanager
from dataclasses import dataclass
Expand All @@ -28,6 +29,11 @@
_nvjitlink_input_types = None # populated if nvJitLink cannot be used


def _nvjitlink_has_version_symbol(inner_nvjitlink) -> bool:
# This condition is equivalent to testing for version >= 12.3
return bool(inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion"))


# Note: this function is reused in the tests
def _decide_nvjitlink_or_driver() -> bool:
"""Returns True if falling back to the cuLink* driver APIs."""
Expand All @@ -37,28 +43,36 @@ def _decide_nvjitlink_or_driver() -> bool:

_driver_ver = handle_return(driver.cuDriverGetVersion())
_driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10)

warn_txt_common = (
"the driver APIs will be used instead, which do not support"
" minor version compatibility or linking LTO IRs."
" For best results, consider upgrading to a recent version of"
)

try:
from cuda.bindings import nvjitlink as _nvjitlink
from cuda.bindings._internal import nvjitlink as inner_nvjitlink
except ImportError:
# binding is not available
_nvjitlink = None
import cuda.bindings.nvjitlink as _nvjitlink
except ModuleNotFoundError:
warn_txt = f"cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings."
else:
if inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") == 0:
# binding is available, but nvJitLink is not installed
_nvjitlink = None

if _nvjitlink is None:
warn(
"nvJitLink is not installed or too old (<12.3). Therefore it is not usable "
"and the culink APIs will be used instead.",
stacklevel=3,
category=RuntimeWarning,
from cuda.bindings._internal import nvjitlink as inner_nvjitlink

try:
if _nvjitlink_has_version_symbol(inner_nvjitlink):
return False # Use nvjitlink
except RuntimeError:
warn_detail = "not available"
else:
warn_detail = "too old (<12.3)"
warn_txt = (
f"{'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is {warn_detail}."
f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink."
)
_driver = driver
return True
else:
return False
_nvjitlink = None

warn(warn_txt, stacklevel=2, category=RuntimeWarning)
_driver = driver
return True


def _lazy_init():
Expand Down