From 5f7707ac9b5226bcbcc1b351e8a479cd0dc3d419 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 26 Oct 2024 13:57:02 -0400 Subject: [PATCH 1/3] fix(pt): set device for PT C++ Fix #4171. Signed-off-by: Jinzhe Zeng --- source/api_cc/src/DeepPotPT.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 4c7aac19b8..eb99abc41a 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -80,6 +80,7 @@ void DeepPotPT::init(const std::string& model, device = torch::Device(torch::kCPU); std::cout << "load model from: " << model << " to cpu " << std::endl; } else { + c10::cuda::set_device(gpu_id); std::cout << "load model from: " << model << " to gpu " << gpu_id << std::endl; } From ba015a1a7ed0f20e0dd4599a7c2e0957cb2d5e45 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 26 Oct 2024 14:21:37 -0400 Subject: [PATCH 2/3] try CUDAGuard Signed-off-by: Jinzhe Zeng --- source/api_cc/src/DeepPotPT.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index eb99abc41a..5d43515e2d 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -2,6 +2,7 @@ #ifdef BUILD_PYTORCH #include "DeepPotPT.h" +#include #include #include @@ -80,7 +81,7 @@ void DeepPotPT::init(const std::string& model, device = torch::Device(torch::kCPU); std::cout << "load model from: " << model << " to cpu " << std::endl; } else { - c10::cuda::set_device(gpu_id); + c10::cuda::CUDAGuard guard_(device); std::cout << "load model from: " << model << " to gpu " << gpu_id << std::endl; } From 9bee6f4cffc7b9878f48f2c7b8cc69ae2034b5df Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 26 Oct 2024 14:44:32 -0400 Subject: [PATCH 3/3] use dp api Signed-off-by: Jinzhe Zeng --- source/api_cc/src/DeepPotPT.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 5d43515e2d..780a8007f3 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -2,7 +2,6 @@ #ifdef BUILD_PYTORCH #include "DeepPotPT.h" -#include #include #include @@ -81,7 +80,9 @@ void DeepPotPT::init(const std::string& model, device = torch::Device(torch::kCPU); std::cout << "load model from: " << model << " to cpu " << std::endl; } else { - c10::cuda::CUDAGuard guard_(device); +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + DPErrcheck(DPSetDevice(gpu_id)); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM std::cout << "load model from: " << model << " to gpu " << gpu_id << std::endl; }