Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cuda_bindings/cuda/bindings/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Callable

from ._ptx_utils import get_minimal_required_cuda_ver_from_ptx_ver, get_ptx_ver
from ._version_check import warn_if_cuda_major_version_mismatch

# Registry keyed by type; each value is a callable that accepts one object and returns an int
# (presumably that object's underlying handle, going by the name -- confirm where entries are registered).
_handle_getters: dict[type, Callable[[Any], int]] = {}

Expand Down
56 changes: 56 additions & 0 deletions cuda_bindings/cuda/bindings/utils/_version_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import os
import warnings

# One-shot guard: set to True the first time the compatibility check is entered,
# regardless of whether that first call ends up warning, returning early, or raising.
_major_version_compatibility_checked = False


def warn_if_cuda_major_version_mismatch():
    """Emit a one-time warning if cuda-bindings targets a newer CUDA major than the driver.

    The major version cuda-bindings was compiled against (``driver.CUDA_VERSION``)
    is compared with the major version reported by ``driver.cuDriverGetVersion()``.
    A ``UserWarning`` is issued only when the compile-time major is strictly greater.

    Only the first call in a process performs the comparison; every later call
    returns immediately.

    Setting the environment variable ``CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING``
    (e.g. ``CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING=1``) skips the comparison
    and therefore the warning.

    Raises:
        RuntimeError: If ``cuDriverGetVersion`` does not report ``CUDA_SUCCESS``.
    """
    global _major_version_compatibility_checked
    if _major_version_compatibility_checked:
        return
    # Flip the guard before doing any work so the check is attempted at most once.
    _major_version_compatibility_checked = True

    # User opt-out: any non-empty value disables the check.
    opt_out = os.environ.get("CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING")
    if opt_out:
        return

    # Deferred import: avoids a circular import and keeps module import cheap.
    from cuda.bindings import driver

    # Versions are encoded as major * 1000 + minor * 10 (e.g. 13010), so
    # integer division by 1000 yields the major component.
    compile_major = driver.CUDA_VERSION // 1000

    err, driver_reported = driver.cuDriverGetVersion()
    if err != driver.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"Failed to query CUDA driver version: {err}")
    runtime_major = driver_reported // 1000

    # Nothing to report when the driver is at least as new as the build target.
    if compile_major <= runtime_major:
        return

    message = (
        f"cuda-bindings was built for CUDA major version {compile_major}, but the "
        f"NVIDIA driver only supports up to CUDA {runtime_major}. Some cuda-bindings "
        f"features may not work correctly. Consider updating your NVIDIA driver, "
        f"or using a cuda-bindings version built for CUDA {runtime_major}. "
        f"(Set CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING=1 to suppress this warning.)"
    )
    # The warn call stays in this frame so stacklevel=3 refers to the same caller as before.
    warnings.warn(message, UserWarning, stacklevel=3)
2 changes: 2 additions & 0 deletions cuda_bindings/docs/source/environment_variables.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ Runtime Environment Variables

- ``CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM`` : When set to 1, the default stream is the per-thread default stream. When set to 0, the default stream is the legacy default stream. This defaults to 0, for the legacy default stream. See `Stream Synchronization Behavior <https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html>`_ for an explanation of the legacy and per-thread default streams.

- ``CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING`` : When set to a non-empty value (for example, 1), suppresses the one-time warning emitted when ``cuda-bindings`` was built for a newer CUDA major version than the installed driver supports.


Build-Time Environment Variables
--------------------------------
Expand Down
98 changes: 98 additions & 0 deletions cuda_bindings/tests/test_version_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import os
import warnings
from unittest import mock

import pytest
from cuda.bindings import driver
from cuda.bindings.utils import _version_check, warn_if_cuda_major_version_mismatch


class TestVersionCompatibilityCheck:
"""Tests for CUDA major version mismatch warning function."""

def setup_method(self):
    """Snapshot the module's one-shot flag, then clear it so the test starts unchecked.

    The flag is process-wide state in ``_version_check``. Its pre-test value is
    kept on this test instance as ``_saved_checked_flag`` so that the state the
    process had before the test is not lost by clearing it here.
    """
    self._saved_checked_flag = _version_check._major_version_compatibility_checked
    _version_check._major_version_compatibility_checked = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This effectively resets the flag even if the version was already checked, so after the unit tests execute, the check could potentially run again. Shouldn't we cache the current state before setting it to false for the unit tests, and then restore its previous value on unit test teardown?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Someone with a dodgy build might get multiple warnings during testing. Not sure it's worth fixing.


def teardown_method(self):
    """Put the module's one-shot flag back after each test.

    If this test instance carries a ``_saved_checked_flag`` attribute, the flag
    is restored to that value; otherwise it is reset to ``False``.
    """
    # getattr with a default keeps the plain "reset to False" behavior when no snapshot exists.
    _version_check._major_version_compatibility_checked = getattr(self, "_saved_checked_flag", False)

def test_no_warning_when_driver_newer(self):
    """No warning should be issued when driver version >= compile version.

    The suppression variable is removed for the duration of the test so that the
    version comparison is actually exercised, instead of the test passing
    vacuously through the opt-out early return.
    """
    # Mock compile version 12.9 and driver version 13.0
    with (
        mock.patch.dict(os.environ),  # snapshot of the environment, restored on exit
        mock.patch.object(driver, "CUDA_VERSION", 12090),
        mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 13000)),
        warnings.catch_warnings(record=True) as w,
    ):
        os.environ.pop("CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING", None)
        warnings.simplefilter("always")
        warn_if_cuda_major_version_mismatch()
        assert len(w) == 0

def test_no_warning_when_same_major_version(self):
    """No warning should be issued when major versions match.

    The suppression variable is removed for the duration of the test so that the
    version comparison is actually exercised, instead of the test passing
    vacuously through the opt-out early return.
    """
    # Mock compile version 12.9 and driver version 12.8
    with (
        mock.patch.dict(os.environ),  # snapshot of the environment, restored on exit
        mock.patch.object(driver, "CUDA_VERSION", 12090),
        mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 12080)),
        warnings.catch_warnings(record=True) as w,
    ):
        os.environ.pop("CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING", None)
        warnings.simplefilter("always")
        warn_if_cuda_major_version_mismatch()
        assert len(w) == 0

def test_warning_when_compile_major_newer(self):
    """Warning should be issued when compile major version > driver major version.

    The suppression variable is removed for the duration of the test; otherwise
    a developer or CI environment that exports it would make this test fail,
    because the function returns before comparing versions.
    """
    # Mock compile version 13.0 and driver version 12.8
    with (
        mock.patch.dict(os.environ),  # snapshot of the environment, restored on exit
        mock.patch.object(driver, "CUDA_VERSION", 13000),
        mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 12080)),
        warnings.catch_warnings(record=True) as w,
    ):
        os.environ.pop("CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING", None)
        warnings.simplefilter("always")
        warn_if_cuda_major_version_mismatch()
        assert len(w) == 1
        assert issubclass(w[0].category, UserWarning)
        assert "cuda-bindings was built for CUDA major version 13" in str(w[0].message)
        assert "only supports up to CUDA 12" in str(w[0].message)

def test_warning_only_issued_once(self):
    """Warning should only be issued once per process.

    The suppression variable is removed for the duration of the test; otherwise
    an environment that exports it would yield zero warnings and fail the test.
    """
    with (
        mock.patch.dict(os.environ),  # snapshot of the environment, restored on exit
        mock.patch.object(driver, "CUDA_VERSION", 13000),
        mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 12080)),
        warnings.catch_warnings(record=True) as w,
    ):
        os.environ.pop("CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING", None)
        warnings.simplefilter("always")
        warn_if_cuda_major_version_mismatch()
        warn_if_cuda_major_version_mismatch()
        warn_if_cuda_major_version_mismatch()
        # Only one warning despite multiple calls
        assert len(w) == 1

def test_warning_suppressed_by_env_var(self):
    """Setting CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING must silence the mismatch warning."""
    # Compile version 13.0 against driver 12.8 would trigger the warning without the opt-out.
    older_driver_reply = (driver.CUresult.CUDA_SUCCESS, 12080)
    opt_out_env = {"CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING": "1"}
    with (
        mock.patch.object(driver, "CUDA_VERSION", 13000),
        mock.patch.object(driver, "cuDriverGetVersion", return_value=older_driver_reply),
        mock.patch.dict(os.environ, opt_out_env),
        warnings.catch_warnings(record=True) as caught,
    ):
        warnings.simplefilter("always")
        warn_if_cuda_major_version_mismatch()
        assert not caught

def test_error_when_driver_version_fails(self):
    """Should raise RuntimeError if cuDriverGetVersion fails.

    The suppression variable is removed for the duration of the test; with it
    set, the function returns before querying the driver and nothing is raised.
    """
    with (
        mock.patch.dict(os.environ),  # snapshot of the environment, restored on exit
        mock.patch.object(driver, "CUDA_VERSION", 13000),
        mock.patch.object(
            driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_ERROR_NOT_INITIALIZED, 0)
        ),
        pytest.raises(RuntimeError, match="Failed to query CUDA driver version"),
    ):
        os.environ.pop("CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING", None)
        warn_if_cuda_major_version_mismatch()
105 changes: 61 additions & 44 deletions cuda_core/cuda/core/_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -960,53 +960,12 @@ class Device:
__slots__ = ("_device_id", "_memory_resource", "_has_inited", "_properties", "_uuid", "_context")

def __new__(cls, device_id: Device | int | None = None):
# Handle device_id argument.
if isinstance(device_id, Device):
return device_id
else:
device_id = getattr(device_id, 'device_id', device_id)

# Initialize CUDA.
global _is_cuInit
if _is_cuInit is False:
with _lock, nogil:
HANDLE_RETURN(cydriver.cuInit(0))
_is_cuInit = True

# important: creating a Device instance does not initialize the GPU!
cdef cydriver.CUdevice dev
cdef cydriver.CUcontext ctx
if device_id is None:
with nogil:
err = cydriver.cuCtxGetDevice(&dev)
if err == cydriver.CUresult.CUDA_SUCCESS:
device_id = int(dev)
elif err == cydriver.CUresult.CUDA_ERROR_INVALID_CONTEXT:
# No context is current - verify and default to device 0 (cudart behavior)
assert cydriver.cuCtxGetCurrent(&ctx) == cydriver.CUresult.CUDA_SUCCESS and ctx == NULL
device_id = 0
else:
HANDLE_RETURN(err)
elif device_id < 0:
raise ValueError(f"device_id must be >= 0, got {device_id}")

# ensure Device is singleton
cdef int total
try:
devices = _tls.devices
except AttributeError:
with nogil:
HANDLE_RETURN(cydriver.cuDeviceGetCount(&total))
devices = _tls.devices = []
for i in range(total):
device = super().__new__(cls)
device._device_id = i
device._memory_resource = None
device._has_inited = False
device._properties = None
device._uuid = None
device._context = None
devices.append(device)
Device_ensure_cuda_initialized()
device_id = Device_resolve_device_id(device_id)
devices = Device_ensure_tls_devices(cls)

try:
return devices[device_id]
Expand Down Expand Up @@ -1393,3 +1352,61 @@ class Device:
"""
self._check_context_initialized()
return GraphBuilder._init(stream=self.create_stream(), is_stream_owner=True)


cdef inline void Device_ensure_cuda_initialized() except *:
    """Run ``cuInit`` and the cuda-bindings/driver version check on first use.

    Once the module-level ``_is_cuInit`` flag is no longer ``False`` this is a no-op.
    """
    global _is_cuInit
    if _is_cuInit is not False:
        return

    # Initialize the driver while holding the module lock, with the GIL released.
    with _lock, nogil:
        HANDLE_RETURN(cydriver.cuInit(0))
        _is_cuInit = True

    # The version-check helper is optional: when cuda.bindings.utils does not
    # provide it, the import fails and the check is skipped.
    try:
        from cuda.bindings.utils import warn_if_cuda_major_version_mismatch
    except ImportError:
        return
    warn_if_cuda_major_version_mismatch()


cdef inline int Device_resolve_device_id(device_id) except? -1:
    """Turn the caller-supplied ``device_id`` into a concrete device ordinal.

    ``None`` selects the device of the current context, or 0 when no context is
    current. A negative value raises ``ValueError``. Any other value is returned
    as given.
    """
    cdef cydriver.CUdevice current_dev
    cdef cydriver.CUcontext current_ctx
    cdef cydriver.CUresult status

    # Explicit ordinal: validate the sign and hand it back.
    if device_id is not None:
        if device_id < 0:
            raise ValueError(f"device_id must be >= 0, got {device_id}")
        return device_id

    # No ordinal given: ask the driver which device the current context uses.
    with nogil:
        status = cydriver.cuCtxGetDevice(&current_dev)
    if status == cydriver.CUresult.CUDA_SUCCESS:
        return int(current_dev)
    if status == cydriver.CUresult.CUDA_ERROR_INVALID_CONTEXT:
        # Confirm that there really is no current context before defaulting.
        with nogil:
            HANDLE_RETURN(cydriver.cuCtxGetCurrent(&current_ctx))
        assert <void*>(current_ctx) == NULL
        return 0  # cudart behavior

    # Any other status is handed to HANDLE_RETURN (presumably raising -- confirm);
    # the trailing return mirrors the original fall-through.
    HANDLE_RETURN(status)
    return device_id


cdef inline list Device_ensure_tls_devices(cls):
    """Return the ``Device`` singleton list stored on ``_tls``, building it on first access.

    When ``_tls.devices`` does not exist yet, one bare ``Device`` instance is
    created per device reported by ``cuDeviceGetCount``, with every slot reset
    to its uninitialized value.
    """
    cdef int device_count
    try:
        return _tls.devices
    except AttributeError:
        with nogil:
            HANDLE_RETURN(cydriver.cuDeviceGetCount(&device_count))
        # Publish the (still empty) list on _tls before filling it, as the original did.
        singletons = []
        _tls.devices = singletons
        for ordinal in range(device_count):
            # Bypass Device.__new__ itself; go straight to the parent allocator.
            instance = super(Device, cls).__new__(cls)
            instance._device_id = ordinal
            instance._memory_resource = None
            instance._has_inited = False
            instance._properties = None
            instance._uuid = None
            instance._context = None
            singletons.append(instance)
        return singletons
Loading