-
Notifications
You must be signed in to change notification settings - Fork 237
Add CUDA version compatibility check #1412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Andy-Jost
wants to merge
8
commits into
NVIDIA:main
Choose a base branch
from
Andy-Jost:runtime-version-check
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+218
−44
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
d999f40
Add CUDA major version compatibility check
Andy-Jost 62dfcca
Merge remote-tracking branch 'origin/main' into runtime-version-check
Andy-Jost 73611ed
Move version check import to local scope
Andy-Jost b2083ed
Merge remote-tracking branch 'origin/main' into runtime-version-check
Andy-Jost fdb3a7e
Refactor Device.__new__ into helper functions
Andy-Jost 3a5c210
Merge branch 'main' into runtime-version-check
Andy-Jost 424a113
Merge branch 'main' into runtime-version-check
Andy-Jost 6071609
Merge branch 'main' into runtime-version-check
Andy-Jost File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE | ||
|
|
||
| import os | ||
| import warnings | ||
|
|
||
| # Track whether we've already checked major version compatibility | ||
| _major_version_compatibility_checked = False | ||
|
|
||
|
|
||
| def warn_if_cuda_major_version_mismatch(): | ||
| """Warn if the CUDA driver major version is older than cuda-bindings compile-time version. | ||
|
|
||
| This function compares the CUDA major version that cuda-bindings was compiled | ||
| against with the CUDA major version supported by the installed driver. If the | ||
| compile-time major version is greater than the driver's major version, a warning | ||
| is issued. | ||
|
|
||
| The check runs only once per process. Subsequent calls are no-ops. | ||
|
|
||
| The warning can be suppressed by setting the environment variable | ||
| ``CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING=1``. | ||
| """ | ||
| global _major_version_compatibility_checked | ||
| if _major_version_compatibility_checked: | ||
| return | ||
| _major_version_compatibility_checked = True | ||
|
|
||
| # Allow users to suppress the warning | ||
| if os.environ.get("CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING"): | ||
| return | ||
|
|
||
| # Import here to avoid circular imports and allow lazy loading | ||
| from cuda.bindings import driver | ||
|
|
||
| # Get compile-time CUDA version from cuda-bindings | ||
| compile_version = driver.CUDA_VERSION # e.g., 13010 | ||
| compile_major = compile_version // 1000 | ||
|
|
||
| # Get runtime driver version | ||
| err, runtime_version = driver.cuDriverGetVersion() | ||
| if err != driver.CUresult.CUDA_SUCCESS: | ||
| raise RuntimeError(f"Failed to query CUDA driver version: {err}") | ||
|
|
||
| runtime_major = runtime_version // 1000 | ||
|
|
||
| if compile_major > runtime_major: | ||
| warnings.warn( | ||
| f"cuda-bindings was built for CUDA major version {compile_major}, but the " | ||
| f"NVIDIA driver only supports up to CUDA {runtime_major}. Some cuda-bindings " | ||
| f"features may not work correctly. Consider updating your NVIDIA driver, " | ||
| f"or using a cuda-bindings version built for CUDA {runtime_major}. " | ||
| f"(Set CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING=1 to suppress this warning.)", | ||
| UserWarning, | ||
| stacklevel=3, | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE | ||
|
|
||
| import os | ||
| import warnings | ||
| from unittest import mock | ||
|
|
||
| import pytest | ||
| from cuda.bindings import driver | ||
| from cuda.bindings.utils import _version_check, warn_if_cuda_major_version_mismatch | ||
|
|
||
|
|
||
| class TestVersionCompatibilityCheck: | ||
| """Tests for CUDA major version mismatch warning function.""" | ||
|
|
||
| def setup_method(self): | ||
| """Reset the version compatibility check flag before each test.""" | ||
| _version_check._major_version_compatibility_checked = False | ||
|
|
||
| def teardown_method(self): | ||
| """Reset the version compatibility check flag after each test.""" | ||
| _version_check._major_version_compatibility_checked = False | ||
|
|
||
| def test_no_warning_when_driver_newer(self): | ||
| """No warning should be issued when driver version >= compile version.""" | ||
| # Mock compile version 12.9 and driver version 13.0 | ||
| with ( | ||
| mock.patch.object(driver, "CUDA_VERSION", 12090), | ||
| mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 13000)), | ||
| warnings.catch_warnings(record=True) as w, | ||
| ): | ||
| warnings.simplefilter("always") | ||
| warn_if_cuda_major_version_mismatch() | ||
| assert len(w) == 0 | ||
|
|
||
| def test_no_warning_when_same_major_version(self): | ||
| """No warning should be issued when major versions match.""" | ||
| # Mock compile version 12.9 and driver version 12.8 | ||
| with ( | ||
| mock.patch.object(driver, "CUDA_VERSION", 12090), | ||
| mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 12080)), | ||
| warnings.catch_warnings(record=True) as w, | ||
| ): | ||
| warnings.simplefilter("always") | ||
| warn_if_cuda_major_version_mismatch() | ||
| assert len(w) == 0 | ||
|
|
||
| def test_warning_when_compile_major_newer(self): | ||
| """Warning should be issued when compile major version > driver major version.""" | ||
| # Mock compile version 13.0 and driver version 12.8 | ||
| with ( | ||
| mock.patch.object(driver, "CUDA_VERSION", 13000), | ||
| mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 12080)), | ||
| warnings.catch_warnings(record=True) as w, | ||
| ): | ||
| warnings.simplefilter("always") | ||
| warn_if_cuda_major_version_mismatch() | ||
| assert len(w) == 1 | ||
| assert issubclass(w[0].category, UserWarning) | ||
| assert "cuda-bindings was built for CUDA major version 13" in str(w[0].message) | ||
| assert "only supports up to CUDA 12" in str(w[0].message) | ||
|
|
||
| def test_warning_only_issued_once(self): | ||
| """Warning should only be issued once per process.""" | ||
| with ( | ||
| mock.patch.object(driver, "CUDA_VERSION", 13000), | ||
| mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 12080)), | ||
| warnings.catch_warnings(record=True) as w, | ||
| ): | ||
| warnings.simplefilter("always") | ||
| warn_if_cuda_major_version_mismatch() | ||
| warn_if_cuda_major_version_mismatch() | ||
| warn_if_cuda_major_version_mismatch() | ||
| # Only one warning despite multiple calls | ||
| assert len(w) == 1 | ||
|
|
||
| def test_warning_suppressed_by_env_var(self): | ||
| """Warning should be suppressed when CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING is set.""" | ||
| with ( | ||
| mock.patch.object(driver, "CUDA_VERSION", 13000), | ||
| mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 12080)), | ||
| mock.patch.dict(os.environ, {"CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING": "1"}), | ||
| warnings.catch_warnings(record=True) as w, | ||
| ): | ||
| warnings.simplefilter("always") | ||
| warn_if_cuda_major_version_mismatch() | ||
| assert len(w) == 0 | ||
|
|
||
| def test_error_when_driver_version_fails(self): | ||
| """Should raise RuntimeError if cuDriverGetVersion fails.""" | ||
| with ( | ||
| mock.patch.object(driver, "CUDA_VERSION", 13000), | ||
| mock.patch.object( | ||
| driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_ERROR_NOT_INITIALIZED, 0) | ||
| ), | ||
| pytest.raises(RuntimeError, match="Failed to query CUDA driver version"), | ||
| ): | ||
| warn_if_cuda_major_version_mismatch() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is effectively going to reset if the version was checked already and after the unit tests executes the check could potentially execute again. Shouldn't we cache the current state and before setting it to false for the unit tests and then restore to its previous value on unit test teardown?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Someone with a dodgy build might get multiple warnings during testing. Not sure it's worth fixing.