diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 21416f619fae..769f01063ff2 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -42,8 +42,9 @@ class CUDADeviceAPI final : public DeviceAPI { int value = 0; switch (kind) { case kExist: - value = (cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, dev.device_id) == - cudaSuccess); + int count; + CUDA_CALL(cudaGetDeviceCount(&count)); + value = static_cast(dev.device_id < count); break; case kMaxThreadsPerBlock: { CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, dev.device_id));