From 41a890c3e0f997f6d7f7d9984aa514dca9ffd56a Mon Sep 17 00:00:00 2001
From: Josh Romero
Date: Thu, 4 Dec 2025 14:01:01 -0800
Subject: [PATCH 1/4] Add missing GPU device handling for NCCL
 initialization/usage for distributed training.

Signed-off-by: Josh Romero
---
 src/csrc/distributed.cpp                | 40 +++++++++++++++++++++----
 src/csrc/include/internal/distributed.h |  5 ++--
 src/csrc/rl/off_policy/ddpg.cpp         | 20 ++++++-------
 src/csrc/rl/off_policy/sac.cpp          | 16 +++++-----
 src/csrc/rl/off_policy/td3.cpp          | 20 ++++++-------
 src/csrc/rl/on_policy/ppo.cpp           |  8 ++---
 src/csrc/torchfort.cpp                  |  4 +--
 7 files changed, 71 insertions(+), 42 deletions(-)

diff --git a/src/csrc/distributed.cpp b/src/csrc/distributed.cpp
index ca5da0c..198f8d8 100644
--- a/src/csrc/distributed.cpp
+++ b/src/csrc/distributed.cpp
@@ -20,6 +20,7 @@
 #include
 #include
+#include
 #endif

 #include
@@ -73,12 +74,15 @@ static ncclComm_t ncclCommFromMPIComm(MPI_Comm mpi_comm) {
 }
 #endif

-void Comm::initialize(bool initialize_nccl) {
+void Comm::initialize() {
   CHECK_MPI(MPI_Comm_rank(mpi_comm, &rank));
   CHECK_MPI(MPI_Comm_size(mpi_comm, &size));

 #ifdef ENABLE_GPU
-  if (initialize_nccl) {
+  // Initialize NCCL if device is GPU
+  if (device.is_cuda()) {
+    c10::cuda::CUDAGuard guard(device);
+
     nccl_comm = ncclCommFromMPIComm(mpi_comm);

     int greatest_priority;
@@ -94,13 +98,17 @@ void Comm::finalize() {
 #ifdef ENABLE_GPU
-  if (nccl_comm)
+  if (nccl_comm) {
+    c10::cuda::CUDAGuard guard(device);
     CHECK_NCCL(ncclCommDestroy(nccl_comm));
-  if (stream) CHECK_CUDA(cudaStreamDestroy(stream));
-  if (event) CHECK_CUDA(cudaEventDestroy(event));
+    nccl_comm = nullptr;
+    stream = nullptr;
+    event = nullptr;
+  }
 #endif
+  initialized = false;
 }

 void Comm::allreduce(torch::Tensor& tensor, bool average) const {
@@ -116,6 +124,14 @@
 #ifdef ENABLE_GPU
   if (tensor.device().type() == torch::kCUDA) {
+    if (tensor.device() != device) {
+      std::stringstream ss;
+      ss << "allreduce called with tensor on device " << tensor.device() << " but the comm was initialized on device " << device << ".";
+      THROW_INVALID_USAGE(ss.str());
+    }
+
+    c10::cuda::CUDAGuard guard(device);
+
     auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
     CHECK_CUDA(cudaEventRecord(event, torch_stream));
     CHECK_CUDA(cudaStreamWaitEvent(stream, event));
@@ -157,6 +173,8 @@ void Comm::allreduce(std::vector<torch::Tensor>& tensors, bool average) const {
 #ifdef ENABLE_GPU
   if (tensors[0].device().type() == torch::kCUDA) {
+    c10::cuda::CUDAGuard guard(device);
+
     auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
     CHECK_CUDA(cudaEventRecord(event, torch_stream));
     CHECK_CUDA(cudaStreamWaitEvent(stream, event));
@@ -173,6 +191,8 @@
 #ifdef ENABLE_GPU
   if (tensors[0].device().type() == torch::kCUDA) {
+    c10::cuda::CUDAGuard guard(device);
+
     CHECK_NCCL(ncclGroupEnd());
     auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
     CHECK_CUDA(cudaEventRecord(event, stream));
@@ -201,6 +221,14 @@ void Comm::broadcast(torch::Tensor& tensor, int root) const {
   auto count = torch::numel(tensor);
 #ifdef ENABLE_GPU
   if (tensor.device().type() == torch::kCUDA) {
+    if (tensor.device() != device) {
+      std::stringstream ss;
+      ss << "broadcast called with tensor on device " << tensor.device() << " but the comm was initialized on device " << device << ".";
+      THROW_INVALID_USAGE(ss.str());
+    }
+
+    c10::cuda::CUDAGuard guard(device);
+
     // Use NCCL for GPU
tensors ncclDataType_t nccl_dtype; if (torch::is_complex(tensor)) { @@ -210,7 +238,7 @@ void Comm::broadcast(torch::Tensor& tensor, int root) const { nccl_dtype = get_nccl_dtype(tensor); } - auto torch_stream = c10::cuda::getCurrentCUDAStream(tensor.device().index()).stream(); + auto torch_stream = c10::cuda::getCurrentCUDAStream().stream(); CHECK_CUDA(cudaEventRecord(event, torch_stream)); CHECK_CUDA(cudaStreamWaitEvent(stream, event)); diff --git a/src/csrc/include/internal/distributed.h b/src/csrc/include/internal/distributed.h index cfadd13..6b54f18 100644 --- a/src/csrc/include/internal/distributed.h +++ b/src/csrc/include/internal/distributed.h @@ -28,7 +28,7 @@ namespace torchfort { struct Comm { - void initialize(bool initialize_nccl = false); + void initialize(); void finalize(); void allreduce(torch::Tensor& tensor, bool average = false) const; void allreduce(std::vector& tensors, bool average = false) const; @@ -38,6 +38,7 @@ struct Comm { int rank; int size; + torch::Device device; MPI_Comm mpi_comm; #ifdef ENABLE_GPU ncclComm_t nccl_comm = nullptr; @@ -46,7 +47,7 @@ struct Comm { #endif bool initialized = false; - Comm(MPI_Comm mpi_comm) : mpi_comm(mpi_comm){}; + Comm(MPI_Comm mpi_comm, torch::Device device) : mpi_comm(mpi_comm), device(device){}; }; } // namespace torchfort diff --git a/src/csrc/rl/off_policy/ddpg.cpp b/src/csrc/rl/off_policy/ddpg.cpp index d90efe0..10d9fb6 100644 --- a/src/csrc/rl/off_policy/ddpg.cpp +++ b/src/csrc/rl/off_policy/ddpg.cpp @@ -256,18 +256,18 @@ int DDPGSystem::getRank() const { void DDPGSystem::initSystemComm(MPI_Comm mpi_comm) { // Set up distributed communicators for all models // system - system_comm_ = std::make_shared(mpi_comm); - system_comm_->initialize(model_device_.is_cuda()); + system_comm_ = std::make_shared(mpi_comm, model_device_); + system_comm_->initialize(); // policy - p_model_.comm = std::make_shared(mpi_comm); - p_model_.comm->initialize(model_device_.is_cuda()); - p_model_target_.comm = std::make_shared(mpi_comm); - p_model_target_.comm->initialize(model_device_.is_cuda()); + p_model_.comm = std::make_shared(mpi_comm, model_device_); + p_model_.comm->initialize(); + p_model_target_.comm = std::make_shared(mpi_comm, model_device_); + p_model_target_.comm->initialize(); // critic - q_model_.comm = std::make_shared(mpi_comm); - q_model_.comm->initialize(model_device_.is_cuda()); - q_model_target_.comm = std::make_shared(mpi_comm); - q_model_target_.comm->initialize(model_device_.is_cuda()); + q_model_.comm = std::make_shared(mpi_comm, model_device_); + q_model_.comm->initialize(); + q_model_target_.comm = std::make_shared(mpi_comm, model_device_); + q_model_target_.comm->initialize(); // move to device before broadcasting // policy diff --git a/src/csrc/rl/off_policy/sac.cpp b/src/csrc/rl/off_policy/sac.cpp index c1905ee..917820f 100644 --- a/src/csrc/rl/off_policy/sac.cpp +++ b/src/csrc/rl/off_policy/sac.cpp @@ -291,19 +291,19 @@ int SACSystem::getRank() const { void SACSystem::initSystemComm(MPI_Comm mpi_comm) { // Set up distributed communicators for all models // system - system_comm_ = std::make_shared(mpi_comm); - system_comm_->initialize(model_device_.is_cuda()); + system_comm_ = std::make_shared(mpi_comm, model_device_); + system_comm_->initialize(); // policy - p_model_.comm = std::make_shared(mpi_comm); - p_model_.comm->initialize(model_device_.is_cuda()); + p_model_.comm = std::make_shared(mpi_comm, model_device_); + p_model_.comm->initialize(); // critic for (auto& q_model : q_models_) { - q_model.comm = 
std::make_shared(mpi_comm); - q_model.comm->initialize(model_device_.is_cuda()); + q_model.comm = std::make_shared(mpi_comm, model_device_); + q_model.comm->initialize(); } for (auto& q_model_target : q_models_target_) { - q_model_target.comm = std::make_shared(mpi_comm); - q_model_target.comm->initialize(model_device_.is_cuda()); + q_model_target.comm = std::make_shared(mpi_comm, model_device_); + q_model_target.comm->initialize(); } // we do not need an alpha comm objects since p-model comm is used for that diff --git a/src/csrc/rl/off_policy/td3.cpp b/src/csrc/rl/off_policy/td3.cpp index 7233d75..2f2c900 100644 --- a/src/csrc/rl/off_policy/td3.cpp +++ b/src/csrc/rl/off_policy/td3.cpp @@ -274,21 +274,21 @@ int TD3System::getRank() const { void TD3System::initSystemComm(MPI_Comm mpi_comm) { // Set up distributed communicators for all models // system - system_comm_ = std::make_shared(mpi_comm); - system_comm_->initialize(model_device_.is_cuda()); + system_comm_ = std::make_shared(mpi_comm, model_device_); + system_comm_->initialize(); // policy - p_model_.comm = std::make_shared(mpi_comm); - p_model_.comm->initialize(model_device_.is_cuda()); - p_model_target_.comm = std::make_shared(mpi_comm); - p_model_target_.comm->initialize(model_device_.is_cuda()); + p_model_.comm = std::make_shared(mpi_comm, model_device_); + p_model_.comm->initialize(); + p_model_target_.comm = std::make_shared(mpi_comm, model_device_); + p_model_target_.comm->initialize(); // critic for (auto& q_model : q_models_) { - q_model.comm = std::make_shared(mpi_comm); - q_model.comm->initialize(model_device_.is_cuda()); + q_model.comm = std::make_shared(mpi_comm, model_device_); + q_model.comm->initialize(); } for (auto& q_model_target : q_models_target_) { - q_model_target.comm = std::make_shared(mpi_comm); - q_model_target.comm->initialize(model_device_.is_cuda()); + q_model_target.comm = std::make_shared(mpi_comm, model_device_); + q_model_target.comm->initialize(); } // move to device before broadcasting diff --git a/src/csrc/rl/on_policy/ppo.cpp b/src/csrc/rl/on_policy/ppo.cpp index 70999e5..728c5fc 100644 --- a/src/csrc/rl/on_policy/ppo.cpp +++ b/src/csrc/rl/on_policy/ppo.cpp @@ -188,11 +188,11 @@ torch::Device PPOSystem::rbDevice() const { return rb_device_; } void PPOSystem::initSystemComm(MPI_Comm mpi_comm) { // Set up distributed communicators for all models // system - system_comm_ = std::make_shared(mpi_comm); - system_comm_->initialize(model_device_.is_cuda()); + system_comm_ = std::make_shared(mpi_comm, model_device_); + system_comm_->initialize(); // policy - pq_model_.comm = std::make_shared(mpi_comm); - pq_model_.comm->initialize(model_device_.is_cuda()); + pq_model_.comm = std::make_shared(mpi_comm, model_device_); + pq_model_.comm->initialize(); // move to device before broadcasting pq_model_.model->to(model_device_); diff --git a/src/csrc/torchfort.cpp b/src/csrc/torchfort.cpp index e52f830..b973ff9 100644 --- a/src/csrc/torchfort.cpp +++ b/src/csrc/torchfort.cpp @@ -140,8 +140,8 @@ torchfort_result_t torchfort_create_distributed_model(const char* name, const ch torchfort_create_model(name, config_fname, device); // Set up distributed communicator - models[name].comm = std::shared_ptr(new Comm(mpi_comm)); - models[name].comm->initialize(models[name].model->device().is_cuda()); + models[name].comm = std::make_shared(mpi_comm, get_device(device)); + models[name].comm->initialize(); // Broadcast initial model parameters from rank 0 for (auto& p : models[name].model->parameters()) { From 
e13fa8292ed7f62f305eb9b2abdcc9082a01c431 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Thu, 4 Dec 2025 15:05:07 -0800 Subject: [PATCH 2/4] Add tests. Signed-off-by: Josh Romero --- tests/supervised/CMakeLists.txt | 7 + .../supervised/test_distributed_training.cpp | 175 ++++++++++++++++++ 2 files changed, 182 insertions(+) create mode 100644 tests/supervised/test_distributed_training.cpp diff --git a/tests/supervised/CMakeLists.txt b/tests/supervised/CMakeLists.txt index 7880e2b..ceb7902 100644 --- a/tests/supervised/CMakeLists.txt +++ b/tests/supervised/CMakeLists.txt @@ -3,6 +3,7 @@ cmake_minimum_required(VERSION 3.14) set(test_targets test_checkpoint test_training + test_distributed_training ) add_executable(test_checkpoint) @@ -17,6 +18,12 @@ target_sources(test_training test_training.cpp ) +add_executable(test_distributed_training) +target_sources(test_distributed_training + PRIVATE + test_distributed_training.cpp + ) + find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED) foreach(tgt ${test_targets}) diff --git a/tests/supervised/test_distributed_training.cpp b/tests/supervised/test_distributed_training.cpp new file mode 100644 index 0000000..a81ad46 --- /dev/null +++ b/tests/supervised/test_distributed_training.cpp @@ -0,0 +1,175 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#ifdef ENABLE_GPU +#include +#endif + +#include +#include + +#include "internal/defines.h" +#include "internal/exceptions.h" +#include "internal/utils.h" +#include "torchfort.h" + +#include "test_utils.h" + +void training_test_distributed(const std::string& model_config, std::vector dev_model, std::vector dev_input, std::vector shape, + bool should_fail_create, bool should_fail_train, bool should_fail_inference, bool check_result) { + + std::string model_name = generate_random_name(10); + MPI_Comm mpi_comm = MPI_COMM_WORLD; + int rank, size; + CHECK_MPI(MPI_Comm_rank(mpi_comm, &rank)); + CHECK_MPI(MPI_Comm_size(mpi_comm, &size)); + + // Skip tests if not running with 2 ranks + if (size != 2) { + GTEST_SKIP() << "This test requires 2 ranks to run. Skipping."; + } + +#ifdef ENABLE_GPU + int ngpu; + cudaGetDeviceCount(&ngpu); + if (ngpu < 2) { + GTEST_SKIP() << "This test requires at least 2 GPUs. 
Skipping."; + } +#endif + + try { + CHECK_TORCHFORT(torchfort_create_distributed_model(model_name.c_str(), model_config.c_str(), mpi_comm, dev_model[rank])); + if (should_fail_create) { + FAIL() << "This test should fail create call, but did not."; + } + } catch (const torchfort::BaseException& e) { + if (should_fail_create) { + // pass + } else { + FAIL(); + } + } + +#ifdef ENABLE_GPU + if (dev_input[rank] != TORCHFORT_DEVICE_CPU) { + CHECK_CUDA(cudaSetDevice(dev_input[rank])); + } +#endif + + auto input = generate_random(shape); + auto label = generate_random(shape); + auto output = generate_random(shape); + float loss_val; + + float* input_ptr = get_data_ptr(input, dev_input[rank]); + float* label_ptr = get_data_ptr(label, dev_input[rank]); + float* output_ptr = get_data_ptr(output, dev_input[rank]); + + try { + CHECK_TORCHFORT(torchfort_train(model_name.c_str(), input_ptr, shape.size(), shape.data(), label_ptr, shape.size(), + shape.data(), &loss_val, TORCHFORT_FLOAT, 0)); + if (should_fail_train) { + FAIL() << "This test should fail train call, but did not."; + } + } catch (const torchfort::BaseException& e) { + if (should_fail_train) { + // pass + } else { + FAIL(); + } + } catch (const c10::Error& e) { + std::cout << e.what() << std::endl; + if (should_fail_train) { + // pass + } else { + FAIL(); + } + } + + try { + CHECK_TORCHFORT(torchfort_inference(model_name.c_str(), input_ptr, shape.size(), shape.data(), output_ptr, + shape.size(), shape.data(), TORCHFORT_FLOAT, 0)); + if (should_fail_inference) { + FAIL() << "This test should fail inference call, but did not."; + } + } catch (const torchfort::BaseException& e) { + if (should_fail_inference) { + // pass + } else { + FAIL(); + } + } catch (const c10::Error& e) { + std::cout << e.what() << std::endl; + if (should_fail_train) { + // pass + } else { + FAIL(); + } + } + +#ifdef ENABLE_GPU + if (dev_input[rank] != TORCHFORT_DEVICE_CPU) { + copy_to_host_vector(output, output_ptr); + } +#endif + + if (check_result) { + EXPECT_EQ(input, output); + } + + free_data_ptr(input_ptr, dev_input[rank]); + free_data_ptr(label_ptr, dev_input[rank]); + free_data_ptr(output_ptr, dev_input[rank]); +} + + +TEST(TorchFort, TrainTestDistributedMLPCPUCPU) { + training_test_distributed("configs/mlp2.yaml", {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {10, 2, 5}, false, false, false, + false); +} + +#ifdef ENABLE_GPU +TEST(TorchFort, TrainTestDistributedMLPGPUCPU) { + training_test_distributed("configs/mlp2.yaml", {0, 1}, {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {10, 2, 5}, false, false, false, false); +} +TEST(TorchFort, TrainTestDistributedMLPGPUReverseCPU) { + training_test_distributed("configs/mlp2.yaml", {1, 0}, {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {10, 2, 5}, false, false, false, false); +} +TEST(TorchFort, TrainTestDistributedMLPCPUGPU) { + training_test_distributed("configs/mlp2.yaml", {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {0, 1}, {10, 2, 5}, false, false, false, false); +} +TEST(TorchFort, TrainTestDistributedMLPCPUGPUReverse) { + training_test_distributed("configs/mlp2.yaml", {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {1, 0}, {10, 2, 5}, false, false, false, false); +} +TEST(TorchFort, TrainTestDistributedMLPGPUGPU) { training_test_distributed("configs/mlp2.yaml", {0, 1}, {0, 1}, {10, 10}, false, false, false, false); } +TEST(TorchFort, TrainTestDistributedMLPGPUGPUReverse) { training_test_distributed("configs/mlp2.yaml", {0, 1}, {1, 0}, {10, 10}, false, false, false, 
false); } +#endif + +// Testing expected error cases + +int main(int argc, char* argv[]) { + ::testing::InitGoogleTest(&argc, argv); + MPI_Init(&argc, &argv); + + + int result = RUN_ALL_TESTS(); + MPI_Finalize(); + + return result; +} From 6c6554c30ae4c20d090e9f82ce251bd0dfc64e14 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Fri, 5 Dec 2025 10:29:03 -0800 Subject: [PATCH 3/4] Formatting fixes. Signed-off-by: Josh Romero --- src/csrc/distributed.cpp | 8 +++-- .../supervised/test_distributed_training.cpp | 36 +++++++++++-------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/src/csrc/distributed.cpp b/src/csrc/distributed.cpp index 198f8d8..352a8cc 100644 --- a/src/csrc/distributed.cpp +++ b/src/csrc/distributed.cpp @@ -19,8 +19,8 @@ #ifdef ENABLE_GPU #include -#include #include +#include #endif #include @@ -126,7 +126,8 @@ void Comm::allreduce(torch::Tensor& tensor, bool average) const { if (tensor.device().type() == torch::kCUDA) { if (tensor.device() != device) { std::stringstream ss; - ss << "allreduce called with tensor on device " << tensor.device() << " but the comm was initialized on device " << device << "."; + ss << "allreduce called with tensor on device " << tensor.device() << " but the comm was initialized on device " + << device << "."; THROW_INVALID_USAGE(ss.str()); } @@ -223,7 +224,8 @@ void Comm::broadcast(torch::Tensor& tensor, int root) const { if (tensor.device().type() == torch::kCUDA) { if (tensor.device() != device) { std::stringstream ss; - ss << "broadcast called with tensor on device " << tensor.device() << " but the comm was initialized on device " << device << "."; + ss << "broadcast called with tensor on device " << tensor.device() << " but the comm was initialized on device " + << device << "."; THROW_INVALID_USAGE(ss.str()); } diff --git a/tests/supervised/test_distributed_training.cpp b/tests/supervised/test_distributed_training.cpp index a81ad46..0c2f6f3 100644 --- a/tests/supervised/test_distributed_training.cpp +++ b/tests/supervised/test_distributed_training.cpp @@ -31,8 +31,9 @@ #include "test_utils.h" -void training_test_distributed(const std::string& model_config, std::vector dev_model, std::vector dev_input, std::vector shape, - bool should_fail_create, bool should_fail_train, bool should_fail_inference, bool check_result) { +void training_test_distributed(const std::string& model_config, std::vector dev_model, std::vector dev_input, + std::vector shape, bool should_fail_create, bool should_fail_train, + bool should_fail_inference, bool check_result) { std::string model_name = generate_random_name(10); MPI_Comm mpi_comm = MPI_COMM_WORLD; @@ -42,7 +43,7 @@ void training_test_distributed(const std::string& model_config, std::vector // Skip tests if not running with 2 ranks if (size != 2) { - GTEST_SKIP() << "This test requires 2 ranks to run. Skipping."; + GTEST_SKIP() << "This test requires 2 ranks to run. 
Skipping."; } #ifdef ENABLE_GPU @@ -54,7 +55,8 @@ void training_test_distributed(const std::string& model_config, std::vector #endif try { - CHECK_TORCHFORT(torchfort_create_distributed_model(model_name.c_str(), model_config.c_str(), mpi_comm, dev_model[rank])); + CHECK_TORCHFORT( + torchfort_create_distributed_model(model_name.c_str(), model_config.c_str(), mpi_comm, dev_model[rank])); if (should_fail_create) { FAIL() << "This test should fail create call, but did not."; } @@ -138,27 +140,34 @@ void training_test_distributed(const std::string& model_config, std::vector free_data_ptr(output_ptr, dev_input[rank]); } - TEST(TorchFort, TrainTestDistributedMLPCPUCPU) { - training_test_distributed("configs/mlp2.yaml", {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {10, 2, 5}, false, false, false, - false); + training_test_distributed("configs/mlp2.yaml", {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, + {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {10, 2, 5}, false, false, false, false); } #ifdef ENABLE_GPU TEST(TorchFort, TrainTestDistributedMLPGPUCPU) { - training_test_distributed("configs/mlp2.yaml", {0, 1}, {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {10, 2, 5}, false, false, false, false); + training_test_distributed("configs/mlp2.yaml", {0, 1}, {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {10, 2, 5}, + false, false, false, false); } TEST(TorchFort, TrainTestDistributedMLPGPUReverseCPU) { - training_test_distributed("configs/mlp2.yaml", {1, 0}, {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {10, 2, 5}, false, false, false, false); + training_test_distributed("configs/mlp2.yaml", {1, 0}, {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {10, 2, 5}, + false, false, false, false); } TEST(TorchFort, TrainTestDistributedMLPCPUGPU) { - training_test_distributed("configs/mlp2.yaml", {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {0, 1}, {10, 2, 5}, false, false, false, false); + training_test_distributed("configs/mlp2.yaml", {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {0, 1}, {10, 2, 5}, + false, false, false, false); } TEST(TorchFort, TrainTestDistributedMLPCPUGPUReverse) { - training_test_distributed("configs/mlp2.yaml", {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {1, 0}, {10, 2, 5}, false, false, false, false); + training_test_distributed("configs/mlp2.yaml", {TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU}, {1, 0}, {10, 2, 5}, + false, false, false, false); +} +TEST(TorchFort, TrainTestDistributedMLPGPUGPU) { + training_test_distributed("configs/mlp2.yaml", {0, 1}, {0, 1}, {10, 10}, false, false, false, false); +} +TEST(TorchFort, TrainTestDistributedMLPGPUGPUReverse) { + training_test_distributed("configs/mlp2.yaml", {0, 1}, {1, 0}, {10, 10}, false, false, false, false); } -TEST(TorchFort, TrainTestDistributedMLPGPUGPU) { training_test_distributed("configs/mlp2.yaml", {0, 1}, {0, 1}, {10, 10}, false, false, false, false); } -TEST(TorchFort, TrainTestDistributedMLPGPUGPUReverse) { training_test_distributed("configs/mlp2.yaml", {0, 1}, {1, 0}, {10, 10}, false, false, false, false); } #endif // Testing expected error cases @@ -167,7 +176,6 @@ int main(int argc, char* argv[]) { ::testing::InitGoogleTest(&argc, argv); MPI_Init(&argc, &argv); - int result = RUN_ALL_TESTS(); MPI_Finalize(); From c2b21f8594aee8761859e267a1f02f6ecacb43ef Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Fri, 5 Dec 2025 10:58:59 -0800 Subject: [PATCH 4/4] Add distributed test to CI. 
Signed-off-by: Josh Romero --- .github/scripts/run_ci_tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/scripts/run_ci_tests.sh b/.github/scripts/run_ci_tests.sh index c35e7bf..7baa46d 100755 --- a/.github/scripts/run_ci_tests.sh +++ b/.github/scripts/run_ci_tests.sh @@ -9,6 +9,7 @@ cd /opt/torchfort/bin/tests/supervised python scripts/setup_tests.py ./test_checkpoint ./test_training +mpirun -np 2 --allow-run-as-root ./test_distributed_training cd /opt/torchfort/bin/tests/rl ./test_distributions
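
Note for reviewers: the recurring pattern in this series is to bind every CUDA/NCCL call in Comm to the device the communicator was created with, and to reject tensors that live on a different device. The sketch below is illustrative only and is not TorchFort's implementation: the function name allreduce_on_comm_device and its argument list are hypothetical stand-ins, the dedicated comm stream and event synchronization from the real code are omitted, errors are reported with a plain exception rather than THROW_INVALID_USAGE, and the tensor is assumed to be contiguous float32.

// Illustrative sketch of the device-guard pattern (assumptions noted above).
#include <sstream>
#include <stdexcept>

#include <cuda_runtime.h>
#include <nccl.h>

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>

void allreduce_on_comm_device(torch::Tensor& t, torch::Device comm_device,
                              ncclComm_t nccl_comm) {
  // Reject tensors that live on a different GPU than the communicator.
  if (t.device() != comm_device) {
    std::stringstream ss;
    ss << "tensor is on " << t.device() << " but the comm was initialized on " << comm_device;
    throw std::invalid_argument(ss.str());
  }

  // Make the CUDA/NCCL calls below target the comm's device, even when the
  // caller's current CUDA device is a different GPU (e.g. one rank per GPU
  // with the model placed on a non-default device).
  c10::cuda::CUDAGuard guard(comm_device);

  // With the guard active, the current PyTorch stream belongs to the comm's
  // device, so the collective is ordered after work already queued there.
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  // In-place sum-allreduce of a float32 tensor.
  ncclAllReduce(t.data_ptr(), t.data_ptr(), t.numel(), ncclFloat, ncclSum,
                nccl_comm, stream);
}

This is also why the tests added in PATCH 2 exercise model device lists {0, 1} and {1, 0}: each rank's communicator is pinned to the device its model lives on rather than to whatever the current CUDA device happens to be when the collective is issued.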