diff --git a/test/utils.h b/test/utils.h index 657d774fce0..8b250b8bd06 100644 --- a/test/utils.h +++ b/test/utils.h @@ -395,12 +395,16 @@ inline bool maybeClearAllocator(int64_t max_bytes = ((int64_t)1 << 32)) { #if TORCH_VERSION_GREATER(2, 0, 1) // GetDevice was introduced in https://github.com/pytorch/pytorch/pull/94864 // in order to properly handle new CUDA 112 behavior - c10::cuda::GetDevice(&device); + // c10::cuda uses DeviceIndex instead of int + // https://github.com/pytorch/pytorch/pull/119142 + c10::DeviceIndex device_index; + c10::cuda::GetDevice(&device_index); + device = static_cast<int>(device_index); #else cudaGetDevice(&device); #endif - auto device_stats = allocator->getDeviceStats(0); + auto device_stats = allocator->getDeviceStats(device); // allocated_bytes[] holds multiple statistics but the first is sum across // both small and large blocks if (uint64_t(device_stats.reserved_bytes[0].current) >