From 0d23a9690a9e211a9530230b7e31646bf498a116 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 9 Jan 2024 17:05:16 -0800 Subject: [PATCH] [Runtime] Use cudaGetDeviceCount to check if device exists Using `cudaDeviceGetAttribute` will set the global error code when the device doesn't exist and will impact subsequent CUDA API calls. --- src/runtime/cuda/cuda_device_api.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 21416f619fae..769f01063ff2 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -42,8 +42,9 @@ class CUDADeviceAPI final : public DeviceAPI { int value = 0; switch (kind) { case kExist: - value = (cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, dev.device_id) == - cudaSuccess); + int count; + CUDA_CALL(cudaGetDeviceCount(&count)); + value = static_cast(dev.device_id < count); break; case kMaxThreadsPerBlock: { CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, dev.device_id));