Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions cuda_bindings/tests/nvml/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,18 @@ def test_devices_are_the_same_architecture(all_devices):
# they won't be tested properly. This tests for the (hopefully rare) case
# where a system has devices of different architectures and produces a warning.

all_arches = {nvml.DeviceArch(nvml.device_get_architecture(device)) for device in all_devices}
def get_architecture_name(arch):
try:
arch = nvml.DeviceArch(arch)
return arch.name
except ValueError:
return f"UNKNOWN({arch})"

all_arches = {nvml.device_get_architecture(device) for device in all_devices}

if len(all_arches) > 1:
warnings.warn(
f"System has devices of multiple architectures ({', '.join(x.name for x in all_arches)}). "
f"System has devices of multiple architectures ({', '.join(get_architecture_name(x) for x in all_arches)}). "
f" Some tests may be skipped unexpectedly",
UserWarning,
)
Expand Down
2 changes: 1 addition & 1 deletion cuda_bindings/tests/nvml/test_nvlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def test_nvlink_get_link_count(all_devices):
# The feature_nvlink_supported detection is not robust, so we
# can't be more specific about how many links we should find.
if value.nvml_return == nvml.Return.SUCCESS:
assert value.value.ui_val <= nvml.NVLINK_MAX_LINKS, f"Unexpected link count {value.value.ui_val}"
assert value.value.ui_val[0] <= nvml.NVLINK_MAX_LINKS, f"Unexpected link count {value.value.ui_val[0]}"
6 changes: 5 additions & 1 deletion cuda_core/cuda/core/system/_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,11 @@ cdef class Device:
"VOLTA"``, and RTX A6000 will report ``DeviceArchitecture.name ==
"AMPERE"``.
"""
return DeviceArch(nvml.device_get_architecture(self._handle))
arch = nvml.device_get_architecture(self._handle)
try:
return DeviceArch(arch)
except ValueError:
return nvml.DeviceArch.UNKNOWN

@property
def name(self) -> str:
Expand Down