From 368cb37ee0ed4ccbba58950bb5a8568440c99742 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 9 Feb 2024 12:01:29 -0800 Subject: [PATCH 1/5] int->DeviceIndex --- test/utils.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/utils.h b/test/utils.h index 657d774fce0..00f30ea6cb6 100644 --- a/test/utils.h +++ b/test/utils.h @@ -395,7 +395,8 @@ inline bool maybeClearAllocator(int64_t max_bytes = ((int64_t)1 << 32)) { #if TORCH_VERSION_GREATER(2, 0, 1) // GetDevice was introduced in https://github.com/pytorch/pytorch/pull/94864 // in order to properly handle new CUDA 112 behavior - c10::cuda::GetDevice(&device); + // c10::cuda uses DeviceIndex instead of int https://github.com/pytorch/pytorch/pull/119142 + c10::cuda::GetDevice(reinterpret_cast<c10::DeviceIndex*>(&device)); #else cudaGetDevice(&device); #endif From 9f41c93d96583a340628c86e1d09c703dc85f794 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 9 Feb 2024 12:15:39 -0800 Subject: [PATCH 2/5] format --- test/utils.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/utils.h b/test/utils.h index 00f30ea6cb6..50b42acf250 100644 --- a/test/utils.h +++ b/test/utils.h @@ -395,7 +395,8 @@ inline bool maybeClearAllocator(int64_t max_bytes = ((int64_t)1 << 32)) { #if TORCH_VERSION_GREATER(2, 0, 1) // GetDevice was introduced in https://github.com/pytorch/pytorch/pull/94864 // in order to properly handle new CUDA 112 behavior - // c10::cuda uses DeviceIndex instead of int https://github.com/pytorch/pytorch/pull/119142 + // c10::cuda uses DeviceIndex instead of int + // https://github.com/pytorch/pytorch/pull/119142 c10::cuda::GetDevice(reinterpret_cast<c10::DeviceIndex*>(&device)); #else cudaGetDevice(&device); From 125da1b3762f12012c018fd01364b33277e55d68 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 9 Feb 2024 12:17:20 -0800 Subject: [PATCH 3/5] format --- test/utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.h 
b/test/utils.h index 50b42acf250..7ee064747fc 100644 --- a/test/utils.h +++ b/test/utils.h @@ -395,7 +395,7 @@ inline bool maybeClearAllocator(int64_t max_bytes = ((int64_t)1 << 32)) { #if TORCH_VERSION_GREATER(2, 0, 1) // GetDevice was introduced in https://github.com/pytorch/pytorch/pull/94864 // in order to properly handle new CUDA 112 behavior - // c10::cuda uses DeviceIndex instead of int + // c10::cuda uses DeviceIndex instead of int // https://github.com/pytorch/pytorch/pull/119142 c10::cuda::GetDevice(reinterpret_cast<c10::DeviceIndex*>(&device)); #else From 9cefabd79cbe7ca51db5e24578f111007fe87af3 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Sat, 10 Feb 2024 00:42:05 +0000 Subject: [PATCH 4/5] use device variable --- test/utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.h b/test/utils.h index 7ee064747fc..6cccc5c221c 100644 --- a/test/utils.h +++ b/test/utils.h @@ -402,7 +402,7 @@ inline bool maybeClearAllocator(int64_t max_bytes = ((int64_t)1 << 32)) { cudaGetDevice(&device); #endif - auto device_stats = allocator->getDeviceStats(0); + auto device_stats = allocator->getDeviceStats(device); // allocated_bytes[] holds multiple statistics but the first is sum across // both small and large blocks if (uint64_t(device_stats.reserved_bytes[0].current) > From d0fd2e4635856c0dea3a03cfdf8e6678a236120d Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Sat, 10 Feb 2024 01:07:32 +0000 Subject: [PATCH 5/5] avoid reinterpret cast --- test/utils.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/utils.h b/test/utils.h index 6cccc5c221c..8b250b8bd06 100644 --- a/test/utils.h +++ b/test/utils.h @@ -397,7 +397,9 @@ inline bool maybeClearAllocator(int64_t max_bytes = ((int64_t)1 << 32)) { // in order to properly handle new CUDA 112 behavior // c10::cuda uses DeviceIndex instead of int // https://github.com/pytorch/pytorch/pull/119142 - c10::cuda::GetDevice(reinterpret_cast<c10::DeviceIndex*>(&device)); + 
c10::DeviceIndex device_index; + c10::cuda::GetDevice(&device_index); + device = static_cast<int>(device_index); #else cudaGetDevice(&device); #endif