From 55259704ac45a6b372806182e68067800e40b630 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Thu, 4 Dec 2025 14:05:20 -0800 Subject: [PATCH 1/6] Better handling of non-default GPU/multi-GPU per process use cases. Signed-off-by: Josh Romero --- src/csrc/include/internal/rl/off_policy.h | 41 ++++++++----------- src/csrc/include/internal/rl/on_policy.h | 49 +++++++++-------------- src/csrc/include/internal/utils.h | 8 ++++ src/csrc/rl/off_policy/interface.cpp | 12 ++---- src/csrc/rl/on_policy/interface.cpp | 8 ++-- src/csrc/training.cpp | 16 +++----- src/csrc/utils.cpp | 26 ++++++++++++ 7 files changed, 81 insertions(+), 79 deletions(-) diff --git a/src/csrc/include/internal/rl/off_policy.h b/src/csrc/include/internal/rl/off_policy.h index ba102b3..d40fd3a 100644 --- a/src/csrc/include/internal/rl/off_policy.h +++ b/src/csrc/include/internal/rl/off_policy.h @@ -30,6 +30,7 @@ #include "internal/defines.h" #include "internal/logging.h" +#include "internal/utils.h" namespace torchfort { @@ -90,12 +91,10 @@ static void update_replay_buffer(const char* name, T* state_old, T* state_new, s torch::NoGradGuard no_grad; #ifdef ENABLE_GPU - c10::cuda::OptionalCUDAStreamGuard guard; auto rb_device = registry[name]->rbDevice(); - if (rb_device.is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, rb_device.index()); - guard.reset_stream(stream); - } + c10::cuda::OptionalCUDAStreamGuard stream_guard; + c10::cuda::OptionalCUDAGuard cuda_guard; + set_device_and_stream(stream_guard, cuda_guard, rb_device, ext_stream); #endif // get tensors and copy: @@ -121,12 +120,10 @@ static void update_replay_buffer(const char* name, T* state_old, T* state_new, s torch::NoGradGuard no_grad; #ifdef ENABLE_GPU - c10::cuda::OptionalCUDAStreamGuard guard; auto rb_device = registry[name]->rbDevice(); - if (rb_device.is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, rb_device.index()); - guard.reset_stream(stream); - } + c10::cuda::OptionalCUDAStreamGuard stream_guard; + c10::cuda::OptionalCUDAGuard cuda_guard; + set_device_and_stream(stream_guard, cuda_guard, rb_device, ext_stream); #endif // get tensors and copy: @@ -152,12 +149,10 @@ static void predict_explore(const char* name, T* state, size_t state_dim, int64_ #ifdef ENABLE_GPU // device and stream handling - c10::cuda::OptionalCUDAStreamGuard guard; auto model_device = registry[name]->modelDevice(); - if (model_device.is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index()); - guard.reset_stream(stream); - } + c10::cuda::OptionalCUDAStreamGuard stream_guard; + c10::cuda::OptionalCUDAGuard cuda_guard; + set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream); #endif // create tensors @@ -190,12 +185,10 @@ static void predict(const char* name, T* state, size_t state_dim, int64_t* state #ifdef ENABLE_GPU // device and stream handling - c10::cuda::OptionalCUDAStreamGuard guard; auto model_device = registry[name]->modelDevice(); - if (model_device.is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index()); - guard.reset_stream(stream); - } + c10::cuda::OptionalCUDAStreamGuard stream_guard; + c10::cuda::OptionalCUDAGuard cuda_guard; + set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream); #endif // create tensors @@ -228,12 +221,10 @@ static void policy_evaluate(const char* name, T* state, size_t state_dim, int64_ #ifdef ENABLE_GPU // device and stream handling - c10::cuda::OptionalCUDAStreamGuard guard; auto 
model_device = registry[name]->modelDevice(); - if (model_device.is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index()); - guard.reset_stream(stream); - } + c10::cuda::OptionalCUDAStreamGuard stream_guard; + c10::cuda::OptionalCUDAGuard cuda_guard; + set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream); #endif // create tensors diff --git a/src/csrc/include/internal/rl/on_policy.h b/src/csrc/include/internal/rl/on_policy.h index 5302ea5..65e0f49 100644 --- a/src/csrc/include/internal/rl/on_policy.h +++ b/src/csrc/include/internal/rl/on_policy.h @@ -30,6 +30,7 @@ #include "internal/defines.h" #include "internal/logging.h" +#include "internal/utils.h" namespace torchfort { @@ -94,14 +95,10 @@ static void update_rollout_buffer(const char* name, T* state, size_t state_dim, auto model_device = registry[name]->modelDevice(); auto rb_device = registry[name]->rbDevice(); #ifdef ENABLE_GPU - c10::cuda::OptionalCUDAStreamGuard guard; - if (model_device.is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index()); - guard.reset_stream(stream); - } else if (rb_device.is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, rb_device.index()); - guard.reset_stream(stream); - } + c10::cuda::OptionalCUDAStreamGuard stream_guard; + c10::cuda::OptionalCUDAGuard cuda_guard; + set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream); + set_device_and_stream(stream_guard, cuda_guard, rb_device, ext_stream); #endif // get tensors and copy: @@ -127,14 +124,10 @@ static void update_rollout_buffer(const char* name, T* state, size_t state_dim, auto model_device = registry[name]->modelDevice(); auto rb_device = registry[name]->rbDevice(); #ifdef ENABLE_GPU - c10::cuda::OptionalCUDAStreamGuard guard; - if (model_device.is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index()); - guard.reset_stream(stream); - } else if (rb_device.is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, rb_device.index()); - guard.reset_stream(stream); - } + c10::cuda::OptionalCUDAStreamGuard stream_guard; + c10::cuda::OptionalCUDAGuard cuda_guard; + set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream); + set_device_and_stream(stream_guard, cuda_guard, rb_device, ext_stream); #endif // get tensors and copy: @@ -157,12 +150,10 @@ static void predict_explore(const char* name, T* state, size_t state_dim, int64_ #ifdef ENABLE_GPU // device and stream handling - c10::cuda::OptionalCUDAStreamGuard guard; auto model_device = registry[name]->modelDevice(); - if (model_device.is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index()); - guard.reset_stream(stream); - } + c10::cuda::OptionalCUDAStreamGuard stream_guard; + c10::cuda::OptionalCUDAGuard cuda_guard; + set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream); #endif // create tensors @@ -194,12 +185,10 @@ static void predict(const char* name, T* state, size_t state_dim, int64_t* state #ifdef ENABLE_GPU // device and stream handling - c10::cuda::OptionalCUDAStreamGuard guard; + c10::cuda::OptionalCUDAStreamGuard stream_guard; + c10::cuda::OptionalCUDAGuard cuda_guard; auto model_device = registry[name]->modelDevice(); - if (model_device.is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index()); - guard.reset_stream(stream); - } + set_device_and_stream(stream_guard, cuda_guard, model_device, 
ext_stream); #endif // create tensors @@ -232,12 +221,10 @@ static void policy_evaluate(const char* name, T* state, size_t state_dim, int64_ #ifdef ENABLE_GPU // device and stream handling - c10::cuda::OptionalCUDAStreamGuard guard; auto model_device = registry[name]->modelDevice(); - if (model_device.is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index()); - guard.reset_stream(stream); - } + c10::cuda::OptionalCUDAStreamGuard stream_guard; + c10::cuda::OptionalCUDAGuard cuda_guard; + set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream); #endif // create tensors diff --git a/src/csrc/include/internal/utils.h b/src/csrc/include/internal/utils.h index f5de09b..4b7de11 100644 --- a/src/csrc/include/internal/utils.h +++ b/src/csrc/include/internal/utils.h @@ -22,6 +22,10 @@ #include #include +#ifdef ENABLE_GPU +#include +#include +#endif #include #ifdef ENABLE_GPU @@ -114,4 +118,8 @@ std::string print_tensor_shape(torch::Tensor tensor); // Helper function to get the lrs std::vector get_current_lrs(const char* name); +#ifdef ENABLE_GPU +// Helper function to set the device and stream with device checks +void set_device_and_stream(c10::cuda::OptionalCUDAStreamGuard& stream_guard, c10::cuda::OptionalCUDAGuard& cuda_guard, torch::Device device, cudaStream_t ext_stream); +#endif } // namespace torchfort diff --git a/src/csrc/rl/off_policy/interface.cpp b/src/csrc/rl/off_policy/interface.cpp index 4bbdbc7..d7b5e6f 100644 --- a/src/csrc/rl/off_policy/interface.cpp +++ b/src/csrc/rl/off_policy/interface.cpp @@ -180,14 +180,10 @@ torchfort_result_t torchfort_rl_off_policy_train_step(const char* name, float* p auto model_device = rl::off_policy::registry[name]->modelDevice(); auto rb_device = rl::off_policy::registry[name]->rbDevice(); #ifdef ENABLE_GPU - c10::cuda::OptionalCUDAStreamGuard guard; - if (model_device.is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index()); - guard.reset_stream(stream); - } else if (rb_device.is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, rb_device.index()); - guard.reset_stream(stream); - } + c10::cuda::OptionalCUDAStreamGuard stream_guard; + c10::cuda::OptionalCUDAGuard cuda_guard; + set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream); + set_device_and_stream(stream_guard, cuda_guard, rb_device, ext_stream); #endif try { diff --git a/src/csrc/rl/on_policy/interface.cpp b/src/csrc/rl/on_policy/interface.cpp index e3cfea0..8e7daa1 100644 --- a/src/csrc/rl/on_policy/interface.cpp +++ b/src/csrc/rl/on_policy/interface.cpp @@ -167,12 +167,10 @@ torchfort_result_t torchfort_rl_on_policy_train_step(const char* name, float* p_ #ifdef ENABLE_GPU // TODO: we need to figure out what to do if RB and Model streams are different - c10::cuda::OptionalCUDAStreamGuard guard; auto model_device = rl::on_policy::registry[name]->modelDevice(); - if (model_device.is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, model_device.index()); - guard.reset_stream(stream); - } + c10::cuda::OptionalCUDAStreamGuard stream_guard; + c10::cuda::OptionalCUDAGuard cuda_guard; + set_device_and_stream(stream_guard, cuda_guard, model_device, ext_stream); #endif try { diff --git a/src/csrc/training.cpp b/src/csrc/training.cpp index f50d3dd..77c7095 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -56,11 +56,9 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor auto model = models[name].model; 
#if ENABLE_GPU - c10::cuda::OptionalCUDAStreamGuard guard; - if (model->device().is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, model->device().index()); - guard.reset_stream(stream); - } + c10::cuda::OptionalCUDAStreamGuard stream_guard; + c10::cuda::OptionalCUDAGuard cuda_guard; + set_device_and_stream(stream_guard, cuda_guard, model->device(), ext_stream); #endif inputs->to(model->device()); @@ -104,11 +102,9 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo auto model = models[name].model; #ifdef ENABLE_GPU - c10::cuda::OptionalCUDAStreamGuard guard; - if (model->device().is_cuda()) { - auto stream = c10::cuda::getStreamFromExternal(ext_stream, model->device().index()); - guard.reset_stream(stream); - } + c10::cuda::OptionalCUDAStreamGuard stream_guard; + c10::cuda::OptionalCUDAGuard cuda_guard; + set_device_and_stream(stream_guard, cuda_guard, model->device(), ext_stream); #endif inputs->to(model->device()); diff --git a/src/csrc/utils.cpp b/src/csrc/utils.cpp index 3e8e1a7..c500e56 100644 --- a/src/csrc/utils.cpp +++ b/src/csrc/utils.cpp @@ -20,8 +20,16 @@ #include #include +#ifdef ENABLE_GPU +#include +#include +#endif #include +#ifdef ENABLE_GPU +#include +#endif + #include "internal/defines.h" #include "internal/model_pack.h" #include "internal/utils.h" @@ -102,4 +110,22 @@ std::vector get_current_lrs(const char* name) { return learnings_rates; } +#ifdef ENABLE_GPU +void set_device_and_stream(c10::cuda::OptionalCUDAStreamGuard& stream_guard, c10::cuda::OptionalCUDAGuard& cuda_guard, torch::Device device, cudaStream_t ext_stream) { + if (device.is_cuda()) { + cuda_guard.set_device(device); + if (ext_stream) { + int ext_stream_device; + CHECK_CUDA(cudaStreamGetDevice(ext_stream, &ext_stream_device)); + if (ext_stream_device != device.index()) { + std::stringstream ss; + ss << "The provided external stream is on device " << get_device(ext_stream_device) << " but the device is on device " << device << "."; + THROW_INVALID_USAGE(ss.str()); + } + stream_guard.reset_stream(c10::cuda::getStreamFromExternal(ext_stream, device.index())); + } + } +} +#endif + } // namespace torchfort From 281a6d681311ba6dd8e0ee6d55868c373d2c6eb9 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Fri, 5 Dec 2025 12:28:27 -0800 Subject: [PATCH 2/6] Add device context switch checks to supervised learning tests. Signed-off-by: Josh Romero --- tests/supervised/test_training.cpp | 77 ++++++++++++++++++++++++++---- tests/test_utils.h | 10 ++++ 2 files changed, 78 insertions(+), 9 deletions(-) diff --git a/tests/supervised/test_training.cpp b/tests/supervised/test_training.cpp index 709b22e..4cfd3b7 100644 --- a/tests/supervised/test_training.cpp +++ b/tests/supervised/test_training.cpp @@ -36,6 +36,22 @@ void training_test(const std::string& model_config, int dev_model, int dev_input std::string model_name = generate_random_name(10); +#ifdef ENABLE_GPU + if (dev_model == 1 || dev_input == 1) { + int ngpu; + cudaGetDeviceCount(&ngpu); + if (ngpu < 2) { + GTEST_SKIP() << "This test requires at least 2 GPUs. 
Skipping."; + } + } +#endif + +#ifdef ENABLE_GPU + if (dev_input != TORCHFORT_DEVICE_CPU) { + CHECK_CUDA(cudaSetDevice(dev_input)); + } +#endif + try { CHECK_TORCHFORT(torchfort_create_model(model_name.c_str(), model_config.c_str(), dev_model)); if (should_fail_create) { @@ -49,11 +65,7 @@ void training_test(const std::string& model_config, int dev_model, int dev_input } } -#ifdef ENABLE_GPU - if (dev_input != TORCHFORT_DEVICE_CPU) { - CHECK_CUDA(cudaSetDevice(dev_input)); - } -#endif + if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_create_model."; auto input = generate_random(shape); auto label = generate_random(shape); @@ -85,6 +97,8 @@ void training_test(const std::string& model_config, int dev_model, int dev_input } } + if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_train."; + try { CHECK_TORCHFORT(torchfort_inference(model_name.c_str(), input_ptr, shape.size(), shape.data(), output_ptr, shape.size(), shape.data(), TORCHFORT_FLOAT, 0)); @@ -106,6 +120,8 @@ void training_test(const std::string& model_config, int dev_model, int dev_input } } + if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_inference."; + #ifdef ENABLE_GPU if (dev_input != TORCHFORT_DEVICE_CPU) { copy_to_host_vector(output, output_ptr); @@ -124,10 +140,15 @@ void training_test(const std::string& model_config, int dev_model, int dev_input void training_test_multiarg(const std::string& model_config, int dev_model, int dev_input, bool use_extra_args, bool should_fail_create, bool should_fail_train, bool should_fail_inference, bool check_result) { - - std::string model_name = generate_random_name(10); - - CHECK_TORCHFORT(torchfort_create_model(model_name.c_str(), model_config.c_str(), dev_model)); +#ifdef ENABLE_GPU + if (dev_model == 1 || dev_input == 1) { + int ngpu; + cudaGetDeviceCount(&ngpu); + if (ngpu < 2) { + GTEST_SKIP() << "This test requires at least 2 GPUs. Skipping."; + } + } +#endif #ifdef ENABLE_GPU if (dev_input != TORCHFORT_DEVICE_CPU) { @@ -135,6 +156,12 @@ void training_test_multiarg(const std::string& model_config, int dev_model, int } #endif + std::string model_name = generate_random_name(10); + + CHECK_TORCHFORT(torchfort_create_model(model_name.c_str(), model_config.c_str(), dev_model)); + + if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_create_model."; + std::vector shape = {10, 10}; std::vector> inputs(2), labels(2), outputs(2); for (int i = 0; i < 2; ++i) { @@ -156,6 +183,7 @@ void training_test_multiarg(const std::string& model_config, int dev_model, int CHECK_TORCHFORT(torchfort_tensor_list_create(&inputs_tl)); CHECK_TORCHFORT(torchfort_tensor_list_create(&labels_tl)); CHECK_TORCHFORT(torchfort_tensor_list_create(&outputs_tl)); + std::vector input_ptrs(2), label_ptrs(2), output_ptrs(2); for (int i = 0; i < 2; ++i) { @@ -170,6 +198,8 @@ void training_test_multiarg(const std::string& model_config, int dev_model, int torchfort_tensor_list_add_tensor(outputs_tl, output_ptrs[i], shape.size(), shape.data(), TORCHFORT_FLOAT)); } + if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_tensor_list_add_tensor."; + torchfort_tensor_list_t extra_args_tl; std::vector extra_args_ptrs(2); if (use_extra_args) { @@ -181,6 +211,8 @@ void training_test_multiarg(const std::string& model_config, int dev_model, int } } + + try { CHECK_TORCHFORT(torchfort_train_multiarg(model_name.c_str(), inputs_tl, labels_tl, &loss_val, (use_extra_args) ? 
extra_args_tl : nullptr, 0)); @@ -195,6 +227,8 @@ void training_test_multiarg(const std::string& model_config, int dev_model, int } } + if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_train_multiarg."; + try { CHECK_TORCHFORT(torchfort_inference_multiarg(model_name.c_str(), inputs_tl, outputs_tl, 0)); if (should_fail_inference) { @@ -208,6 +242,8 @@ void training_test_multiarg(const std::string& model_config, int dev_model, int } } + if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_inference_multiarg."; + // Check inference output if (check_result) { for (int i = 0; i < 2; ++i) { @@ -395,6 +431,15 @@ TEST(TorchFort, TrainTestMLPCPUGPU) { training_test("configs/mlp2.yaml", TORCHFORT_DEVICE_CPU, 0, {10, 2, 5}, false, false, false, false); } TEST(TorchFort, TrainTestMLPGPUGPU) { training_test("configs/mlp2.yaml", 0, 0, {10, 10}, false, false, false, false); } + +TEST(TorchFort, TrainTestMLPGPU1CPU) { + training_test("configs/mlp2.yaml", 1, TORCHFORT_DEVICE_CPU, {10, 2, 5}, false, false, false, false); +} +TEST(TorchFort, TrainTestMLPCPUGPU1) { + training_test("configs/mlp2.yaml", TORCHFORT_DEVICE_CPU, 1, {10, 2, 5}, false, false, false, false); +} +TEST(TorchFort, TrainTestMLPGPU0GPU1) { training_test("configs/mlp2.yaml", 0, 1, {10, 10}, false, false, false, false); } +TEST(TorchFort, TrainTestMLPGPU1GPU0) { training_test("configs/mlp2.yaml", 1, 0, {10, 10}, false, false, false, false); } TEST(TorchFort, TrainTestTorchScriptCPUGPU) { training_test("configs/torchscript.yaml", TORCHFORT_DEVICE_CPU, 0, {10, 2, 10}, false, false, false, true); } @@ -415,6 +460,20 @@ TEST(TorchFort, TrainTestTorchScriptMultiArgGPUCPU) { TEST(TorchFort, TrainTestTorchScriptMultiArgGPUGPU) { training_test_multiarg("configs/torchscript_multiarg.yaml", 0, 0, false, false, false, false, true); } +TEST(TorchFort, TrainTestTorchScriptMultiArgCPUGPU1) { + training_test_multiarg("configs/torchscript_multiarg.yaml", TORCHFORT_DEVICE_CPU, 1, false, false, false, false, + true); +} +TEST(TorchFort, TrainTestTorchScriptMultiArgGPU1CPU) { + training_test_multiarg("configs/torchscript_multiarg.yaml", 1, TORCHFORT_DEVICE_CPU, false, false, false, false, + true); +} +TEST(TorchFort, TrainTestTorchScriptMultiArgGPU0GPU1) { + training_test_multiarg("configs/torchscript_multiarg.yaml", 0, 1, false, false, false, false, true); +} +TEST(TorchFort, TrainTestTorchScriptMultiArgGPU1GPU0) { + training_test_multiarg("configs/torchscript_multiarg.yaml", 1, 0, false, false, false, false, true); +} TEST(TorchFort, TrainTestTorchScriptMultiArgExtraCPUGPU) { training_test_multiarg("configs/torchscript_multiarg_extra.yaml", TORCHFORT_DEVICE_CPU, 0, true, false, false, false, true); diff --git a/tests/test_utils.h b/tests/test_utils.h index b2366c5..9222a97 100644 --- a/tests/test_utils.h +++ b/tests/test_utils.h @@ -105,3 +105,13 @@ template void copy_from_host_vector(T* data_ptr, std::vector& da CHECK_CUDA(cudaMemcpy(data_ptr, data.data(), data.size() * sizeof(T(0)), cudaMemcpyHostToDevice)); } #endif + +bool check_current_device(int expected_device) { +#ifdef ENABLE_GPU + int device; + CHECK_CUDA(cudaGetDevice(&device)); + + return expected_device == TORCHFORT_DEVICE_CPU || device == expected_device; +#endif + return true; +} From 12fce112c16574b4c5e9c884aec8ab95bdce2621 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Fri, 5 Dec 2025 22:42:18 -0800 Subject: [PATCH 3/6] Adding tests. Conditional use of cuStreamGetDevice based on CUDA driver version. 
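cuStreamGetDevice is only available from CUDA 12.8, so its driver entry point is loaded lazily and marked optional. When it is unavailable, getStreamDevice() recovers the stream's device through the stream's context instead: cuStreamGetCtx, cuCtxSetCurrent, cuCtxGetDevice, then restoring the previously current context.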
Signed-off-by: Josh Romero --- CMakeLists.txt | 1 + src/csrc/cuda_wrap.cpp | 81 +++++++++++++++++++++++++++ src/csrc/include/internal/cuda_wrap.h | 49 ++++++++++++++++ src/csrc/include/internal/defines.h | 18 ++++++ src/csrc/utils.cpp | 23 +++++++- tests/supervised/test_training.cpp | 43 ++++++++++++-- 6 files changed, 208 insertions(+), 7 deletions(-) create mode 100644 src/csrc/cuda_wrap.cpp create mode 100644 src/csrc/include/internal/cuda_wrap.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 62cffb8..db2fb27 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -139,6 +139,7 @@ set_target_properties(${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAK target_sources(${PROJECT_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/cuda_wrap.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/distributed.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/logging.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/model_state.cpp diff --git a/src/csrc/cuda_wrap.cpp b/src/csrc/cuda_wrap.cpp new file mode 100644 index 0000000..3e4c953 --- /dev/null +++ b/src/csrc/cuda_wrap.cpp @@ -0,0 +1,81 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef ENABLE_GPU +#include <mutex> + +#include "internal/cuda_wrap.h" +#include "internal/defines.h" + +#if CUDART_VERSION >= 13000 +#define LOAD_SYM(symbol, version, optional) \ + do { \ + cudaDriverEntryPointQueryResult driverStatus = cudaDriverEntryPointSymbolNotFound; \ + cudaError_t err = cudaGetDriverEntryPointByVersion(#symbol, (void**)(&cuFnTable.pfn_##symbol), version, \ + cudaEnableDefault, &driverStatus); \ + if ((driverStatus != cudaDriverEntryPointSuccess || err != cudaSuccess) && !optional) { \ + THROW_CUDA_ERROR("cudaGetDriverEntryPointByVersion failed."); \ + } \ + } while (false) +#elif CUDART_VERSION >= 12000 +#define LOAD_SYM(symbol, version, optional) \ + do { \ + cudaDriverEntryPointQueryResult driverStatus = cudaDriverEntryPointSymbolNotFound; \ + cudaError_t err = cudaGetDriverEntryPoint(#symbol, (void**)(&cuFnTable.pfn_##symbol), cudaEnableDefault, \ + &driverStatus); \ + if ((driverStatus != cudaDriverEntryPointSuccess || err != cudaSuccess) && !optional) { \ + THROW_CUDA_ERROR("cudaGetDriverEntryPoint failed."); \ + } \ + } while (false) +#else +#define LOAD_SYM(symbol, version, optional) \ + do { \ + cudaError_t err = cudaGetDriverEntryPoint(#symbol, (void**)(&cuFnTable.pfn_##symbol), cudaEnableDefault); \ + if (err != cudaSuccess && !optional) { \ + THROW_CUDA_ERROR("cudaGetDriverEntryPoint failed."); \ + } \ + } while (false) +#endif + +namespace torchfort { + +cuFunctionTable cuFnTable; // global table of required CUDA driver functions + +void initCuFunctionTable() { + std::lock_guard<std::mutex> guard(cuFnTable.mutex); + + if (cuFnTable.initialized) { + return; + } + +#if CUDART_VERSION >= 11030 + LOAD_SYM(cuCtxGetCurrent, 4000, false); + LOAD_SYM(cuCtxGetDevice, 2000, false); + LOAD_SYM(cuCtxSetCurrent, 4000, false); + LOAD_SYM(cuGetErrorString, 6000, false); + LOAD_SYM(cuStreamGetCtx, 9020, false); +#if CUDART_VERSION >= 12080 + LOAD_SYM(cuStreamGetDevice, 12080, true); +#endif +#endif + cuFnTable.initialized = true; +} + +} // namespace torchfort + +#undef LOAD_SYM +#endif diff --git a/src/csrc/include/internal/cuda_wrap.h b/src/csrc/include/internal/cuda_wrap.h new file mode 100644 index 0000000..feb7a1b --- /dev/null +++ b/src/csrc/include/internal/cuda_wrap.h @@ -0,0 +1,49 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#pragma once + +#include <cuda_runtime.h> +#if CUDART_VERSION >= 11030 +#include <cudaTypedefs.h> +#endif + +#define DECLARE_CUDA_PFN(symbol, version) PFN_##symbol##_v##version pfn_##symbol = nullptr + +namespace torchfort { + +struct cuFunctionTable { +#if CUDART_VERSION >= 11030 + DECLARE_CUDA_PFN(cuCtxGetCurrent, 4000); + DECLARE_CUDA_PFN(cuCtxGetDevice, 2000); + DECLARE_CUDA_PFN(cuCtxSetCurrent, 4000); + DECLARE_CUDA_PFN(cuGetErrorString, 6000); + DECLARE_CUDA_PFN(cuStreamGetCtx, 9020); +#if CUDART_VERSION >= 12080 + DECLARE_CUDA_PFN(cuStreamGetDevice, 12080); +#endif +#endif + bool initialized = false; + std::mutex mutex; +}; + +extern cuFunctionTable cuFnTable; + +void initCuFunctionTable(); +} // namespace torchfort + +#undef DECLARE_CUDA_PFN diff --git a/src/csrc/include/internal/defines.h b/src/csrc/include/internal/defines.h index 812ac70..0ff766e 100644 --- a/src/csrc/include/internal/defines.h +++ b/src/csrc/include/internal/defines.h @@ -23,6 +23,7 @@ #include "internal/base_loss.h" #include "internal/base_model.h" +#include "internal/cuda_wrap.h" #include "internal/exceptions.h" #include "internal/utils.h" @@ -44,6 +45,17 @@ } \ } while (false) +#define CHECK_CUDA_DRV(call) \ + do { \ + if (!cuFnTable.initialized) {initCuFunctionTable();} \ + CUresult err = cuFnTable.pfn_##call; \ + if (CUDA_SUCCESS != err) { \ + const char* error_str; \ + cuFnTable.pfn_cuGetErrorString(err, &error_str); \ + throw torchfort::CudaError(__FILE__, __LINE__, error_str); \ + } \ + } while (false) + #define CHECK_NCCL(call) \ do { \ ncclResult_t err = call; \ @@ -72,6 +84,12 @@ } \ } while (false) +#define IS_CUDA_DRV_FUNC_AVAILABLE(symbol) \ + ([&]() { if (!cuFnTable.initialized) {initCuFunctionTable();} \ + return cuFnTable.pfn_##symbol != nullptr; \ + })() + + #define BEGIN_MODEL_REGISTRY \ static std::unordered_map()>> model_registry { diff --git a/src/csrc/utils.cpp b/src/csrc/utils.cpp index c500e56..ae3fdc7 100644 --- a/src/csrc/utils.cpp +++ b/src/csrc/utils.cpp @@ -111,15 +111,34 @@ std::vector get_current_lrs(const char* name) { } #ifdef ENABLE_GPU +int getStreamDevice(cudaStream_t stream) { + CUdevice device; + +#if CUDART_VERSION >= 12080 + if (IS_CUDA_DRV_FUNC_AVAILABLE(cuStreamGetDevice)) { + CHECK_CUDA_DRV(cuStreamGetDevice((CUstream)stream, &device)); + return (int)device; + } +#endif + + CUcontext streamCtx, savedCtx; + CHECK_CUDA_DRV(cuCtxGetCurrent(&savedCtx)); + CHECK_CUDA_DRV(cuStreamGetCtx((CUstream)stream, &streamCtx)); + CHECK_CUDA_DRV(cuCtxSetCurrent(streamCtx)); + CHECK_CUDA_DRV(cuCtxGetDevice(&device)); + CHECK_CUDA_DRV(cuCtxSetCurrent(savedCtx)); + return (int)device; +} + void set_device_and_stream(c10::cuda::OptionalCUDAStreamGuard& stream_guard, c10::cuda::OptionalCUDAGuard& cuda_guard, torch::Device device, cudaStream_t ext_stream) { if (device.is_cuda()) { cuda_guard.set_device(device); if (ext_stream) { int ext_stream_device; - CHECK_CUDA(cudaStreamGetDevice(ext_stream, &ext_stream_device)); + ext_stream_device = getStreamDevice(ext_stream); if (ext_stream_device != device.index()) { std::stringstream ss; - ss << "The provided external stream is on device " << get_device(ext_stream_device) << " but the device is on device " << device << "."; + ss << "The provided external stream is on device " << get_device(ext_stream_device) << " but the model is on device " << device << "."; THROW_INVALID_USAGE(ss.str()); } stream_guard.reset_stream(c10::cuda::getStreamFromExternal(ext_stream, device.index())); diff --git a/tests/supervised/test_training.cpp b/tests/supervised/test_training.cpp index
4cfd3b7..b51b2f1 100644 --- a/tests/supervised/test_training.cpp +++ b/tests/supervised/test_training.cpp @@ -32,20 +32,30 @@ #include "test_utils.h" void training_test(const std::string& model_config, int dev_model, int dev_input, std::vector shape, - bool should_fail_create, bool should_fail_train, bool should_fail_inference, bool check_result) { + bool should_fail_create, bool should_fail_train, bool should_fail_inference, bool check_result, + int dev_stream=-1) { std::string model_name = generate_random_name(10); #ifdef ENABLE_GPU - if (dev_model == 1 || dev_input == 1) { + if (dev_model == 1 || dev_input == 1 || dev_stream == 1) { int ngpu; - cudaGetDeviceCount(&ngpu); + CHECK_CUDA(cudaGetDeviceCount(&ngpu)); if (ngpu < 2) { GTEST_SKIP() << "This test requires at least 2 GPUs. Skipping."; } } #endif + cudaStream_t stream = nullptr; +#ifdef ENABLE_GPU + if (dev_stream != -1) { + CHECK_CUDA(cudaSetDevice(dev_stream)); + CHECK_CUDA(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + } +#endif + + #ifdef ENABLE_GPU if (dev_input != TORCHFORT_DEVICE_CPU) { CHECK_CUDA(cudaSetDevice(dev_input)); @@ -76,9 +86,15 @@ void training_test(const std::string& model_config, int dev_model, int dev_input float* label_ptr = get_data_ptr(label, dev_input); float* output_ptr = get_data_ptr(output, dev_input); +#ifdef ENABLE_GPU + if (stream) { + CHECK_CUDA(cudaStreamSynchronize(stream)); + } +#endif + try { CHECK_TORCHFORT(torchfort_train(model_name.c_str(), input_ptr, shape.size(), shape.data(), label_ptr, shape.size(), - shape.data(), &loss_val, TORCHFORT_FLOAT, 0)); + shape.data(), &loss_val, TORCHFORT_FLOAT, stream)); if (should_fail_train) { FAIL() << "This test should fail train call, but did not."; } @@ -99,9 +115,15 @@ void training_test(const std::string& model_config, int dev_model, int dev_input if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_train."; +#ifdef ENABLE_GPU + if (stream) { + CHECK_CUDA(cudaStreamSynchronize(stream)); + } +#endif + try { CHECK_TORCHFORT(torchfort_inference(model_name.c_str(), input_ptr, shape.size(), shape.data(), output_ptr, - shape.size(), shape.data(), TORCHFORT_FLOAT, 0)); + shape.size(), shape.data(), TORCHFORT_FLOAT, stream)); if (should_fail_inference) { FAIL() << "This test should fail inference call, but did not."; } @@ -122,6 +144,12 @@ void training_test(const std::string& model_config, int dev_model, int dev_input if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_inference."; +#ifdef ENABLE_GPU + if (stream) { + CHECK_CUDA(cudaStreamSynchronize(stream)); + } +#endif + #ifdef ENABLE_GPU if (dev_input != TORCHFORT_DEVICE_CPU) { copy_to_host_vector(output, output_ptr); @@ -431,6 +459,7 @@ TEST(TorchFort, TrainTestMLPCPUGPU) { training_test("configs/mlp2.yaml", TORCHFORT_DEVICE_CPU, 0, {10, 2, 5}, false, false, false, false); } TEST(TorchFort, TrainTestMLPGPUGPU) { training_test("configs/mlp2.yaml", 0, 0, {10, 10}, false, false, false, false); } +TEST(TorchFort, TrainTestMLPGPUGPUStream) { training_test("configs/mlp2.yaml", 0, 0, {10, 10}, false, false, false, false, 0); } TEST(TorchFort, TrainTestMLPGPU1CPU) { training_test("configs/mlp2.yaml", 1, TORCHFORT_DEVICE_CPU, {10, 2, 5}, false, false, false, false); @@ -511,6 +540,10 @@ TEST(TorchFort, TrainTestMLPCPUCPU1DDimError) { training_test("configs/mlp2.yaml", TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU, {10}, false, true, true, false); } +#ifdef ENABLE_GPU +TEST(TorchFort, TrainTestMLPGPUGPUStreamWrongDeviceError) { 
training_test("configs/mlp2.yaml", 0, 0, {10, 10}, false, true, true, false, 1); } +#endif + int main(int argc, char* argv[]) { ::testing::InitGoogleTest(&argc, argv); From 83ebe32a049f56da7d8c639193f04bc5383e3368 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Mon, 8 Dec 2025 10:14:36 -0800 Subject: [PATCH 4/6] Update tests. Signed-off-by: Josh Romero --- tests/supervised/test_distributed_training.cpp | 16 +++++++++++----- tests/supervised/test_training.cpp | 7 ++++++- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/tests/supervised/test_distributed_training.cpp b/tests/supervised/test_distributed_training.cpp index 0c2f6f3..695931b 100644 --- a/tests/supervised/test_distributed_training.cpp +++ b/tests/supervised/test_distributed_training.cpp @@ -54,6 +54,12 @@ void training_test_distributed(const std::string& model_config, std::vector } #endif +#ifdef ENABLE_GPU + if (dev_input[rank] != TORCHFORT_DEVICE_CPU) { + CHECK_CUDA(cudaSetDevice(dev_input[rank])); + } +#endif + try { CHECK_TORCHFORT( torchfort_create_distributed_model(model_name.c_str(), model_config.c_str(), mpi_comm, dev_model[rank])); @@ -68,11 +74,7 @@ void training_test_distributed(const std::string& model_config, std::vector } } -#ifdef ENABLE_GPU - if (dev_input[rank] != TORCHFORT_DEVICE_CPU) { - CHECK_CUDA(cudaSetDevice(dev_input[rank])); - } -#endif + if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_create_distributed_model."; auto input = generate_random(shape); auto label = generate_random(shape); @@ -104,6 +106,8 @@ void training_test_distributed(const std::string& model_config, std::vector } } + if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_train."; + try { CHECK_TORCHFORT(torchfort_inference(model_name.c_str(), input_ptr, shape.size(), shape.data(), output_ptr, shape.size(), shape.data(), TORCHFORT_FLOAT, 0)); @@ -125,6 +129,8 @@ void training_test_distributed(const std::string& model_config, std::vector } } + if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_inference."; + #ifdef ENABLE_GPU if (dev_input[rank] != TORCHFORT_DEVICE_CPU) { copy_to_host_vector(output, output_ptr); diff --git a/tests/supervised/test_training.cpp b/tests/supervised/test_training.cpp index b51b2f1..72f3069 100644 --- a/tests/supervised/test_training.cpp +++ b/tests/supervised/test_training.cpp @@ -55,7 +55,6 @@ void training_test(const std::string& model_config, int dev_model, int dev_input } #endif - #ifdef ENABLE_GPU if (dev_input != TORCHFORT_DEVICE_CPU) { CHECK_CUDA(cudaSetDevice(dev_input)); @@ -163,6 +162,12 @@ void training_test(const std::string& model_config, int dev_model, int dev_input free_data_ptr(input_ptr, dev_input); free_data_ptr(label_ptr, dev_input); free_data_ptr(output_ptr, dev_input); + +#ifdef ENABLE_GPU + if (stream) { + CHECK_CUDA(cudaStreamDestroy(stream)); + } +#endif } void training_test_multiarg(const std::string& model_config, int dev_model, int dev_input, bool use_extra_args, From 09631704d3301daa00997f3efe6ee88b04f6d281 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Mon, 8 Dec 2025 10:26:39 -0800 Subject: [PATCH 5/6] Update tests. 
Signed-off-by: Josh Romero --- tests/supervised/test_distributed_training.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/supervised/test_distributed_training.cpp b/tests/supervised/test_distributed_training.cpp index 695931b..7388bd9 100644 --- a/tests/supervised/test_distributed_training.cpp +++ b/tests/supervised/test_distributed_training.cpp @@ -74,7 +74,7 @@ void training_test_distributed(const std::string& model_config, std::vector } } - if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_create_distributed_model."; + if (!check_current_device(dev_input[rank])) FAIL() << "GPU device switched by torchfort_create_distributed_model."; auto input = generate_random(shape); auto label = generate_random(shape); @@ -106,7 +106,7 @@ void training_test_distributed(const std::string& model_config, std::vector } } - if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_train."; + if (!check_current_device(dev_input[rank])) FAIL() << "GPU device switched by torchfort_train."; try { CHECK_TORCHFORT(torchfort_inference(model_name.c_str(), input_ptr, shape.size(), shape.data(), output_ptr, @@ -129,7 +129,7 @@ void training_test_distributed(const std::string& model_config, std::vector } } - if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_inference."; + if (!check_current_device(dev_input[rank])) FAIL() << "GPU device switched by torchfort_inference."; #ifdef ENABLE_GPU if (dev_input[rank] != TORCHFORT_DEVICE_CPU) { From 9e905072b638638b2a799cf44dece31528f1e17a Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Mon, 8 Dec 2025 10:38:49 -0800 Subject: [PATCH 6/6] Formatting fixes. Signed-off-by: Josh Romero --- src/csrc/cuda_wrap.cpp | 42 +++++++++---------- src/csrc/include/internal/defines.h | 10 +++-- src/csrc/include/internal/utils.h | 3 +- src/csrc/utils.cpp | 6 ++- .../supervised/test_distributed_training.cpp | 9 ++-- tests/supervised/test_training.cpp | 41 +++++++++++------- 6 files changed, 67 insertions(+), 44 deletions(-) diff --git a/src/csrc/cuda_wrap.cpp b/src/csrc/cuda_wrap.cpp index 3e4c953..c6b8b5e 100644 --- a/src/csrc/cuda_wrap.cpp +++ b/src/csrc/cuda_wrap.cpp @@ -22,32 +22,32 @@ #include "internal/defines.h" #if CUDART_VERSION >= 13000 -#define LOAD_SYM(symbol, version, optional) \ - do { \ - cudaDriverEntryPointQueryResult driverStatus = cudaDriverEntryPointSymbolNotFound; \ cudaError_t err = cudaGetDriverEntryPointByVersion(#symbol, (void**)(&cuFnTable.pfn_##symbol), version, \ - cudaEnableDefault, &driverStatus); \ - if ((driverStatus != cudaDriverEntryPointSuccess || err != cudaSuccess) && !optional) { \ - THROW_CUDA_ERROR("cudaGetDriverEntryPointByVersion failed."); \ - } \ +#define LOAD_SYM(symbol, version, optional) \ + do { \ + cudaDriverEntryPointQueryResult driverStatus = cudaDriverEntryPointSymbolNotFound; \ cudaError_t err = cudaGetDriverEntryPointByVersion(#symbol, (void**)(&cuFnTable.pfn_##symbol), version, \ + cudaEnableDefault, &driverStatus); \ + if ((driverStatus != cudaDriverEntryPointSuccess || err != cudaSuccess) && !optional) { \ + THROW_CUDA_ERROR("cudaGetDriverEntryPointByVersion failed."); \ + } \ } while (false) #elif CUDART_VERSION >= 12000 -#define LOAD_SYM(symbol, version, optional) \ - do { \ - cudaDriverEntryPointQueryResult driverStatus = cudaDriverEntryPointSymbolNotFound; \ - cudaError_t err = cudaGetDriverEntryPoint(#symbol, (void**)(&cuFnTable.pfn_##symbol), cudaEnableDefault, \ - &driverStatus); \ - if ((driverStatus != cudaDriverEntryPointSuccess || err != cudaSuccess) && !optional) { \ -
THROW_CUDA_ERROR("cudaGetDriverEntryPoint failed."); \ - } \ +#define LOAD_SYM(symbol, version, optional) \ + do { \ + cudaDriverEntryPointQueryResult driverStatus = cudaDriverEntryPointSymbolNotFound; \ + cudaError_t err = \ + cudaGetDriverEntryPoint(#symbol, (void**)(&cuFnTable.pfn_##symbol), cudaEnableDefault, &driverStatus); \ + if ((driverStatus != cudaDriverEntryPointSuccess || err != cudaSuccess) && !optional) { \ + THROW_CUDA_ERROR("cudaGetDriverEntryPoint failed."); \ + } \ } while (false) #else -#define LOAD_SYM(symbol, version, optional) \ - do { \ - cudaError_t err = cudaGetDriverEntryPoint(#symbol, (void**)(&cuFnTable.pfn_##symbol), cudaEnableDefault); \ - if (err != cudaSuccess && !optional) { \ - THROW_CUDA_ERROR("cudaGetDriverEntryPoint failed."); \ - } \ +#define LOAD_SYM(symbol, version, optional) \ + do { \ + cudaError_t err = cudaGetDriverEntryPoint(#symbol, (void**)(&cuFnTable.pfn_##symbol), cudaEnableDefault); \ + if (err != cudaSuccess && !optional) { \ + THROW_CUDA_ERROR("cudaGetDriverEntryPoint failed."); \ + } \ } while (false) #endif diff --git a/src/csrc/include/internal/defines.h b/src/csrc/include/internal/defines.h index 0ff766e..fe3e80a 100644 --- a/src/csrc/include/internal/defines.h +++ b/src/csrc/include/internal/defines.h @@ -47,7 +47,9 @@ #define CHECK_CUDA_DRV(call) \ do { \ - if (!cuFnTable.initialized) {initCuFunctionTable();} \ + if (!cuFnTable.initialized) { \ + initCuFunctionTable(); \ + } \ CUresult err = cuFnTable.pfn_##call; \ if (CUDA_SUCCESS != err) { \ const char* error_str; \ @@ -85,11 +87,13 @@ } while (false) #define IS_CUDA_DRV_FUNC_AVAILABLE(symbol) \ - ([&]() { if (!cuFnTable.initialized) {initCuFunctionTable();} \ + ([&]() { \ + if (!cuFnTable.initialized) { \ + initCuFunctionTable(); \ + } \ return cuFnTable.pfn_##symbol != nullptr; \ })() - #define BEGIN_MODEL_REGISTRY \ static std::unordered_map()>> model_registry { diff --git a/src/csrc/include/internal/utils.h b/src/csrc/include/internal/utils.h index 4b7de11..39df334 100644 --- a/src/csrc/include/internal/utils.h +++ b/src/csrc/include/internal/utils.h @@ -120,6 +120,7 @@ std::vector get_current_lrs(const char* name); #ifdef ENABLE_GPU // Helper function to set the device and stream with device checks -void set_device_and_stream(c10::cuda::OptionalCUDAStreamGuard& stream_guard, c10::cuda::OptionalCUDAGuard& cuda_guard, torch::Device device, cudaStream_t ext_stream); +void set_device_and_stream(c10::cuda::OptionalCUDAStreamGuard& stream_guard, c10::cuda::OptionalCUDAGuard& cuda_guard, + torch::Device device, cudaStream_t ext_stream); #endif } // namespace torchfort diff --git a/src/csrc/utils.cpp b/src/csrc/utils.cpp index ae3fdc7..e601d3e 100644 --- a/src/csrc/utils.cpp +++ b/src/csrc/utils.cpp @@ -130,7 +130,8 @@ int getStreamDevice(cudaStream_t stream) { return (int)device; } -void set_device_and_stream(c10::cuda::OptionalCUDAStreamGuard& stream_guard, c10::cuda::OptionalCUDAGuard& cuda_guard, torch::Device device, cudaStream_t ext_stream) { +void set_device_and_stream(c10::cuda::OptionalCUDAStreamGuard& stream_guard, c10::cuda::OptionalCUDAGuard& cuda_guard, + torch::Device device, cudaStream_t ext_stream) { if (device.is_cuda()) { cuda_guard.set_device(device); if (ext_stream) { @@ -138,7 +139,8 @@ void set_device_and_stream(c10::cuda::OptionalCUDAStreamGuard& stream_guard, c10 ext_stream_device = getStreamDevice(ext_stream); if (ext_stream_device != device.index()) { std::stringstream ss; - ss << "The provided external stream is on device " << get_device(ext_stream_device) 
<< " but the model is on device " << device << "."; + ss << "The provided external stream is on device " << get_device(ext_stream_device) + << " but the model is on device " << device << "."; THROW_INVALID_USAGE(ss.str()); } stream_guard.reset_stream(c10::cuda::getStreamFromExternal(ext_stream, device.index())); diff --git a/tests/supervised/test_distributed_training.cpp b/tests/supervised/test_distributed_training.cpp index 7388bd9..6a34592 100644 --- a/tests/supervised/test_distributed_training.cpp +++ b/tests/supervised/test_distributed_training.cpp @@ -74,7 +74,8 @@ void training_test_distributed(const std::string& model_config, std::vector } } - if (!check_current_device(dev_input[rank])) FAIL() << "GPU device switched by torchfort_create_distributed_model."; + if (!check_current_device(dev_input[rank])) + FAIL() << "GPU device switched by torchfort_create_distributed_model."; auto input = generate_random(shape); auto label = generate_random(shape); @@ -106,7 +107,8 @@ void training_test_distributed(const std::string& model_config, std::vector } } - if (!check_current_device(dev_input[rank])) FAIL() << "GPU device switched by torchfort_train."; + if (!check_current_device(dev_input[rank])) + FAIL() << "GPU device switched by torchfort_train."; try { CHECK_TORCHFORT(torchfort_inference(model_name.c_str(), input_ptr, shape.size(), shape.data(), output_ptr, @@ -129,7 +131,8 @@ void training_test_distributed(const std::string& model_config, std::vector } } - if (!check_current_device(dev_input[rank])) FAIL() << "GPU device switched by torchfort_inference."; + if (!check_current_device(dev_input[rank])) + FAIL() << "GPU device switched by torchfort_inference."; #ifdef ENABLE_GPU if (dev_input[rank] != TORCHFORT_DEVICE_CPU) { diff --git a/tests/supervised/test_training.cpp b/tests/supervised/test_training.cpp index 72f3069..0986752 100644 --- a/tests/supervised/test_training.cpp +++ b/tests/supervised/test_training.cpp @@ -33,7 +33,7 @@ void training_test(const std::string& model_config, int dev_model, int dev_input, std::vector shape, bool should_fail_create, bool should_fail_train, bool should_fail_inference, bool check_result, - int dev_stream=-1) { + int dev_stream = -1) { std::string model_name = generate_random_name(10); @@ -74,7 +74,8 @@ void training_test(const std::string& model_config, int dev_model, int dev_input } } - if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_create_model."; + if (!check_current_device(dev_input)) + FAIL() << "GPU device switched by torchfort_create_model."; auto input = generate_random(shape); auto label = generate_random(shape); @@ -112,7 +113,8 @@ void training_test(const std::string& model_config, int dev_model, int dev_input } } - if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_train."; + if (!check_current_device(dev_input)) + FAIL() << "GPU device switched by torchfort_train."; #ifdef ENABLE_GPU if (stream) { @@ -141,7 +143,8 @@ void training_test(const std::string& model_config, int dev_model, int dev_input } } - if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_inference."; + if (!check_current_device(dev_input)) + FAIL() << "GPU device switched by torchfort_inference."; #ifdef ENABLE_GPU if (stream) { @@ -193,7 +196,8 @@ void training_test_multiarg(const std::string& model_config, int dev_model, int CHECK_TORCHFORT(torchfort_create_model(model_name.c_str(), model_config.c_str(), dev_model)); - if (!check_current_device(dev_input)) FAIL() << "GPU 
device switched by torchfort_create_model."; + if (!check_current_device(dev_input)) + FAIL() << "GPU device switched by torchfort_create_model."; std::vector shape = {10, 10}; std::vector> inputs(2), labels(2), outputs(2); @@ -231,7 +235,8 @@ void training_test_multiarg(const std::string& model_config, int dev_model, int torchfort_tensor_list_add_tensor(outputs_tl, output_ptrs[i], shape.size(), shape.data(), TORCHFORT_FLOAT)); } - if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_tensor_list_add_tensor."; + if (!check_current_device(dev_input)) + FAIL() << "GPU device switched by torchfort_tensor_list_add_tensor."; torchfort_tensor_list_t extra_args_tl; std::vector extra_args_ptrs(2); @@ -244,8 +249,6 @@ void training_test_multiarg(const std::string& model_config, int dev_model, int } } - - try { CHECK_TORCHFORT(torchfort_train_multiarg(model_name.c_str(), inputs_tl, labels_tl, &loss_val, (use_extra_args) ? extra_args_tl : nullptr, 0)); @@ -260,7 +263,8 @@ void training_test_multiarg(const std::string& model_config, int dev_model, int } } - if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_train_multiarg."; + if (!check_current_device(dev_input)) + FAIL() << "GPU device switched by torchfort_train_multiarg."; try { CHECK_TORCHFORT(torchfort_inference_multiarg(model_name.c_str(), inputs_tl, outputs_tl, 0)); @@ -275,7 +279,8 @@ void training_test_multiarg(const std::string& model_config, int dev_model, int } } - if (!check_current_device(dev_input)) FAIL() << "GPU device switched by torchfort_inference_multiarg."; + if (!check_current_device(dev_input)) + FAIL() << "GPU device switched by torchfort_inference_multiarg."; // Check inference output if (check_result) { @@ -464,7 +469,9 @@ TEST(TorchFort, TrainTestMLPCPUGPU) { training_test("configs/mlp2.yaml", TORCHFORT_DEVICE_CPU, 0, {10, 2, 5}, false, false, false, false); } TEST(TorchFort, TrainTestMLPGPUGPU) { training_test("configs/mlp2.yaml", 0, 0, {10, 10}, false, false, false, false); } -TEST(TorchFort, TrainTestMLPGPUGPUStream) { training_test("configs/mlp2.yaml", 0, 0, {10, 10}, false, false, false, false, 0); } +TEST(TorchFort, TrainTestMLPGPUGPUStream) { + training_test("configs/mlp2.yaml", 0, 0, {10, 10}, false, false, false, false, 0); +} TEST(TorchFort, TrainTestMLPGPU1CPU) { training_test("configs/mlp2.yaml", 1, TORCHFORT_DEVICE_CPU, {10, 2, 5}, false, false, false, false); @@ -472,8 +479,12 @@ TEST(TorchFort, TrainTestMLPGPU1CPU) { TEST(TorchFort, TrainTestMLPCPUGPU1) { training_test("configs/mlp2.yaml", TORCHFORT_DEVICE_CPU, 1, {10, 2, 5}, false, false, false, false); } -TEST(TorchFort, TrainTestMLPGPU0GPU1) { training_test("configs/mlp2.yaml", 0, 1, {10, 10}, false, false, false, false); } -TEST(TorchFort, TrainTestMLPGPU1GPU0) { training_test("configs/mlp2.yaml", 1, 0, {10, 10}, false, false, false, false); } +TEST(TorchFort, TrainTestMLPGPU0GPU1) { + training_test("configs/mlp2.yaml", 0, 1, {10, 10}, false, false, false, false); +} +TEST(TorchFort, TrainTestMLPGPU1GPU0) { + training_test("configs/mlp2.yaml", 1, 0, {10, 10}, false, false, false, false); +} TEST(TorchFort, TrainTestTorchScriptCPUGPU) { training_test("configs/torchscript.yaml", TORCHFORT_DEVICE_CPU, 0, {10, 2, 10}, false, false, false, true); } @@ -546,7 +557,9 @@ TEST(TorchFort, TrainTestMLPCPUCPU1DDimError) { } #ifdef ENABLE_GPU -TEST(TorchFort, TrainTestMLPGPUGPUStreamWrongDeviceError) { training_test("configs/mlp2.yaml", 0, 0, {10, 10}, false, true, true, false, 1); } +TEST(TorchFort, 
TrainTestMLPGPUGPUStreamWrongDeviceError) { + training_test("configs/mlp2.yaml", 0, 0, {10, 10}, false, true, true, false, 1); +} #endif int main(int argc, char* argv[]) {
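Reviewer note, not part of the patch series: below is a minimal sketch of the multi-GPU usage pattern these patches enable, written against the public torchfort C API exercised by the tests above. The model name "net", the data sizes, and the main() scaffolding are illustrative assumptions; configs/mlp2.yaml and the torchfort_* call signatures mirror the tests.

// sketch.cpp: train a model on GPU 1 with data and an external stream that
// also live on GPU 1, while the application keeps GPU 0 as its current
// device. Error handling is kept minimal for brevity.
#include <cstdio>
#include <vector>

#include <cuda_runtime.h>
#include <torchfort.h>

int main() {
  int ngpu = 0;
  cudaGetDeviceCount(&ngpu);
  if (ngpu < 2) {
    std::printf("This sketch requires at least 2 GPUs.\n");
    return 0;
  }

  // Place the model on device 1.
  if (torchfort_create_model("net", "configs/mlp2.yaml", 1) != TORCHFORT_RESULT_SUCCESS) {
    return 1;
  }

  // Allocate data and create an external stream on the model's device.
  cudaSetDevice(1);
  std::vector<int64_t> shape = {10, 10};
  std::vector<float> h_input(100, 1.0f), h_label(100, 0.5f);
  float *d_input = nullptr, *d_label = nullptr;
  cudaMalloc(&d_input, 100 * sizeof(float));
  cudaMalloc(&d_label, 100 * sizeof(float));
  cudaMemcpy(d_input, h_input.data(), 100 * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_label, h_label.data(), 100 * sizeof(float), cudaMemcpyHostToDevice);
  cudaStream_t stream;
  cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);

  // Return to the application's home device before calling into torchfort.
  cudaSetDevice(0);

  // The stream matches the model device, so this succeeds; a stream created
  // on device 0 would now be rejected with an invalid-usage error instead of
  // silently enqueueing work on the wrong device.
  float loss = 0.0f;
  if (torchfort_train("net", d_input, shape.size(), shape.data(), d_label, shape.size(),
                      shape.data(), &loss, TORCHFORT_FLOAT, stream) != TORCHFORT_RESULT_SUCCESS) {
    return 1;
  }

  // The guards added in this series restore the caller's device: expect 0 here.
  int current = -1;
  cudaGetDevice(&current);
  cudaStreamSynchronize(stream);
  std::printf("loss = %f, current device after torchfort_train = %d\n", loss, current);

  cudaSetDevice(1);
  cudaStreamDestroy(stream);
  cudaFree(d_input);
  cudaFree(d_label);
  return 0;
}

The behavioral contract shown here is what the new check_current_device() assertions and the GPU1/stream test variants verify: work runs on the caller-supplied stream of the model's device, and the application's current device is unchanged on return.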