1 change: 1 addition & 0 deletions .github/scripts/run_ci_tests.sh
@@ -9,6 +9,7 @@ cd /opt/torchfort/bin/tests/supervised
python scripts/setup_tests.py
./test_checkpoint
./test_training
mpirun -np 2 --allow-run-as-root ./test_distributed_training

cd /opt/torchfort/bin/tests/rl
./test_distributions
42 changes: 36 additions & 6 deletions src/csrc/distributed.cpp
@@ -19,6 +19,7 @@
#ifdef ENABLE_GPU
#include <nccl.h>

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#endif
#include <torch/torch.h>
@@ -73,12 +74,15 @@ static ncclComm_t ncclCommFromMPIComm(MPI_Comm mpi_comm) {
}
#endif

void Comm::initialize(bool initialize_nccl) {
void Comm::initialize() {
CHECK_MPI(MPI_Comm_rank(mpi_comm, &rank));
CHECK_MPI(MPI_Comm_size(mpi_comm, &size));

#ifdef ENABLE_GPU
if (initialize_nccl) {
// Initialize NCCL if the device is a GPU
if (device.is_cuda()) {
c10::cuda::CUDAGuard guard(device);

nccl_comm = ncclCommFromMPIComm(mpi_comm);

int greatest_priority;
@@ -94,13 +98,17 @@ void Comm::initialize(bool initialize_nccl) {

void Comm::finalize() {
#ifdef ENABLE_GPU
if (nccl_comm)
if (nccl_comm) {
c10::cuda::CUDAGuard guard(device);
CHECK_NCCL(ncclCommDestroy(nccl_comm));
if (stream)
CHECK_CUDA(cudaStreamDestroy(stream));
if (event)
CHECK_CUDA(cudaEventDestroy(event));
nccl_comm = nullptr;
stream = nullptr;
event = nullptr;
}
#endif
initialized = false;
}

void Comm::allreduce(torch::Tensor& tensor, bool average) const {
@@ -116,6 +124,15 @@ void Comm::allreduce(torch::Tensor& tensor, bool average) const {

#ifdef ENABLE_GPU
if (tensor.device().type() == torch::kCUDA) {
if (tensor.device() != device) {
std::stringstream ss;
ss << "allreduce called with tensor on device " << tensor.device() << " but the comm was initialized on device "
<< device << ".";
THROW_INVALID_USAGE(ss.str());
}

c10::cuda::CUDAGuard guard(device);

auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
CHECK_CUDA(cudaEventRecord(event, torch_stream));
CHECK_CUDA(cudaStreamWaitEvent(stream, event));
@@ -157,6 +174,8 @@ void Comm::allreduce(std::vector<torch::Tensor>& tensors, bool average) const {

#ifdef ENABLE_GPU
if (tensors[0].device().type() == torch::kCUDA) {
c10::cuda::CUDAGuard guard(device);

auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
CHECK_CUDA(cudaEventRecord(event, torch_stream));
CHECK_CUDA(cudaStreamWaitEvent(stream, event));
@@ -173,6 +192,8 @@

#ifdef ENABLE_GPU
if (tensors[0].device().type() == torch::kCUDA) {
c10::cuda::CUDAGuard guard(device);

CHECK_NCCL(ncclGroupEnd());
auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
CHECK_CUDA(cudaEventRecord(event, stream));
@@ -201,6 +222,15 @@ void Comm::broadcast(torch::Tensor& tensor, int root) const {
auto count = torch::numel(tensor);
#ifdef ENABLE_GPU
if (tensor.device().type() == torch::kCUDA) {
if (tensor.device() != device) {
std::stringstream ss;
ss << "broadcast called with tensor on device " << tensor.device() << " but the comm was initialized on device "
<< device << ".";
THROW_INVALID_USAGE(ss.str());
}

c10::cuda::CUDAGuard guard(device);

// Use NCCL for GPU tensors
ncclDataType_t nccl_dtype;
if (torch::is_complex(tensor)) {
@@ -210,7 +240,7 @@ void Comm::broadcast(torch::Tensor& tensor, int root) const {
nccl_dtype = get_nccl_dtype(tensor);
}

auto torch_stream = c10::cuda::getCurrentCUDAStream(tensor.device().index()).stream();
auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
CHECK_CUDA(cudaEventRecord(event, torch_stream));
CHECK_CUDA(cudaStreamWaitEvent(stream, event));

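Note for readers skimming the truncated hunks above: the GPU collective paths now pin the comm's device with a `c10::cuda::CUDAGuard` and order the dedicated comm stream against the current PyTorch stream through the comm's event before and after the NCCL call. A minimal sketch of that handoff pattern, assuming a float32 tensor and omitting the `CHECK_*` error handling and dtype mapping the real code uses:

```cpp
#include <cuda_runtime.h>
#include <nccl.h>
#include <torch/torch.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

// Sketch: enqueue an in-place NCCL allreduce on a dedicated comm stream while
// preserving ordering with work already queued on the current PyTorch stream.
void allreduce_sketch(torch::Tensor& t, torch::Device device, ncclComm_t nccl_comm,
                      cudaStream_t comm_stream, cudaEvent_t event) {
  c10::cuda::CUDAGuard guard(device);  // make the CUDA calls below target the comm's device
  cudaStream_t torch_stream = c10::cuda::getCurrentCUDAStream().stream();

  cudaEventRecord(event, torch_stream);        // mark the point where the tensor is ready
  cudaStreamWaitEvent(comm_stream, event, 0);  // comm stream waits for it

  // float32 is hard-coded here purely for illustration
  ncclAllReduce(t.data_ptr(), t.data_ptr(), t.numel(), ncclFloat, ncclSum,
                nccl_comm, comm_stream);

  cudaEventRecord(event, comm_stream);         // mark completion of the collective
  cudaStreamWaitEvent(torch_stream, event, 0); // torch stream resumes only after it
}
```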
5 changes: 3 additions & 2 deletions src/csrc/include/internal/distributed.h
@@ -28,7 +28,7 @@
namespace torchfort {

struct Comm {
void initialize(bool initialize_nccl = false);
void initialize();
void finalize();
void allreduce(torch::Tensor& tensor, bool average = false) const;
void allreduce(std::vector<torch::Tensor>& tensors, bool average = false) const;
@@ -38,6 +38,7 @@ struct Comm {

int rank;
int size;
torch::Device device;
MPI_Comm mpi_comm;
#ifdef ENABLE_GPU
ncclComm_t nccl_comm = nullptr;
@@ -46,7 +47,7 @@
#endif
bool initialized = false;

Comm(MPI_Comm mpi_comm) : mpi_comm(mpi_comm){};
Comm(MPI_Comm mpi_comm, torch::Device device) : mpi_comm(mpi_comm), device(device){};
};

} // namespace torchfort
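Putting the updated interface together, a minimal usage sketch of the device-aware `Comm` (the include path, the one-GPU-per-rank device mapping, and the tensor shape are illustrative assumptions, not something this PR prescribes):

```cpp
#include <mpi.h>
#include <torch/torch.h>
#include "internal/distributed.h"  // include path assumed from the diff above

int main(int argc, char* argv[]) {
  MPI_Init(&argc, &argv);

  int rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  // One GPU per rank assumed here; a CPU comm would pass torch::Device(torch::kCPU).
  torch::Device device(torch::kCUDA, rank);

  torchfort::Comm comm(MPI_COMM_WORLD, device);
  comm.initialize();  // sets rank/size and, for CUDA devices, the NCCL comm, stream, and event

  auto t = torch::ones({8}, torch::TensorOptions().device(device));
  comm.allreduce(t, /*average=*/true);  // tensor must live on the device the comm was built with
  comm.broadcast(t, /*root=*/0);

  comm.finalize();
  MPI_Finalize();
  return 0;
}
```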
20 changes: 10 additions & 10 deletions src/csrc/rl/off_policy/ddpg.cpp
@@ -256,18 +256,18 @@ int DDPGSystem::getRank() const {
void DDPGSystem::initSystemComm(MPI_Comm mpi_comm) {
// Set up distributed communicators for all models
// system
system_comm_ = std::make_shared<Comm>(mpi_comm);
system_comm_->initialize(model_device_.is_cuda());
system_comm_ = std::make_shared<Comm>(mpi_comm, model_device_);
system_comm_->initialize();
// policy
p_model_.comm = std::make_shared<Comm>(mpi_comm);
p_model_.comm->initialize(model_device_.is_cuda());
p_model_target_.comm = std::make_shared<Comm>(mpi_comm);
p_model_target_.comm->initialize(model_device_.is_cuda());
p_model_.comm = std::make_shared<Comm>(mpi_comm, model_device_);
p_model_.comm->initialize();
p_model_target_.comm = std::make_shared<Comm>(mpi_comm, model_device_);
p_model_target_.comm->initialize();
// critic
q_model_.comm = std::make_shared<Comm>(mpi_comm);
q_model_.comm->initialize(model_device_.is_cuda());
q_model_target_.comm = std::make_shared<Comm>(mpi_comm);
q_model_target_.comm->initialize(model_device_.is_cuda());
q_model_.comm = std::make_shared<Comm>(mpi_comm, model_device_);
q_model_.comm->initialize();
q_model_target_.comm = std::make_shared<Comm>(mpi_comm, model_device_);
q_model_target_.comm->initialize();

// move to device before broadcasting
// policy
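The same two-step construct-and-initialize change repeats in sac.cpp, td3.cpp, and ppo.cpp below. If one wanted to factor it out, a small helper along these lines would do (hypothetical; not part of this PR):

```cpp
#include <memory>
#include <mpi.h>
#include <torch/torch.h>
#include "internal/distributed.h"  // assumed include path

// Hypothetical convenience wrapper around the pattern used in the RL systems:
// construct a Comm bound to the model's device, then initialize it.
static std::shared_ptr<torchfort::Comm> make_initialized_comm(MPI_Comm mpi_comm,
                                                              torch::Device device) {
  auto comm = std::make_shared<torchfort::Comm>(mpi_comm, device);
  comm->initialize();
  return comm;
}

// e.g. system_comm_ = make_initialized_comm(mpi_comm, model_device_);
```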
16 changes: 8 additions & 8 deletions src/csrc/rl/off_policy/sac.cpp
@@ -291,19 +291,19 @@ int SACSystem::getRank() const {
void SACSystem::initSystemComm(MPI_Comm mpi_comm) {
// Set up distributed communicators for all models
// system
system_comm_ = std::make_shared<Comm>(mpi_comm);
system_comm_->initialize(model_device_.is_cuda());
system_comm_ = std::make_shared<Comm>(mpi_comm, model_device_);
system_comm_->initialize();
// policy
p_model_.comm = std::make_shared<Comm>(mpi_comm);
p_model_.comm->initialize(model_device_.is_cuda());
p_model_.comm = std::make_shared<Comm>(mpi_comm, model_device_);
p_model_.comm->initialize();
// critic
for (auto& q_model : q_models_) {
q_model.comm = std::make_shared<Comm>(mpi_comm);
q_model.comm->initialize(model_device_.is_cuda());
q_model.comm = std::make_shared<Comm>(mpi_comm, model_device_);
q_model.comm->initialize();
}
for (auto& q_model_target : q_models_target_) {
q_model_target.comm = std::make_shared<Comm>(mpi_comm);
q_model_target.comm->initialize(model_device_.is_cuda());
q_model_target.comm = std::make_shared<Comm>(mpi_comm, model_device_);
q_model_target.comm->initialize();
}
// we do not need an alpha comm objects since p-model comm is used for that

20 changes: 10 additions & 10 deletions src/csrc/rl/off_policy/td3.cpp
@@ -274,21 +274,21 @@ int TD3System::getRank() const {
void TD3System::initSystemComm(MPI_Comm mpi_comm) {
// Set up distributed communicators for all models
// system
system_comm_ = std::make_shared<Comm>(mpi_comm);
system_comm_->initialize(model_device_.is_cuda());
system_comm_ = std::make_shared<Comm>(mpi_comm, model_device_);
system_comm_->initialize();
// policy
p_model_.comm = std::make_shared<Comm>(mpi_comm);
p_model_.comm->initialize(model_device_.is_cuda());
p_model_target_.comm = std::make_shared<Comm>(mpi_comm);
p_model_target_.comm->initialize(model_device_.is_cuda());
p_model_.comm = std::make_shared<Comm>(mpi_comm, model_device_);
p_model_.comm->initialize();
p_model_target_.comm = std::make_shared<Comm>(mpi_comm, model_device_);
p_model_target_.comm->initialize();
// critic
for (auto& q_model : q_models_) {
q_model.comm = std::make_shared<Comm>(mpi_comm);
q_model.comm->initialize(model_device_.is_cuda());
q_model.comm = std::make_shared<Comm>(mpi_comm, model_device_);
q_model.comm->initialize();
}
for (auto& q_model_target : q_models_target_) {
q_model_target.comm = std::make_shared<Comm>(mpi_comm);
q_model_target.comm->initialize(model_device_.is_cuda());
q_model_target.comm = std::make_shared<Comm>(mpi_comm, model_device_);
q_model_target.comm->initialize();
}

// move to device before broadcasting
8 changes: 4 additions & 4 deletions src/csrc/rl/on_policy/ppo.cpp
@@ -188,11 +188,11 @@ torch::Device PPOSystem::rbDevice() const { return rb_device_; }
void PPOSystem::initSystemComm(MPI_Comm mpi_comm) {
// Set up distributed communicators for all models
// system
system_comm_ = std::make_shared<Comm>(mpi_comm);
system_comm_->initialize(model_device_.is_cuda());
system_comm_ = std::make_shared<Comm>(mpi_comm, model_device_);
system_comm_->initialize();
// policy
pq_model_.comm = std::make_shared<Comm>(mpi_comm);
pq_model_.comm->initialize(model_device_.is_cuda());
pq_model_.comm = std::make_shared<Comm>(mpi_comm, model_device_);
pq_model_.comm->initialize();

// move to device before broadcasting
pq_model_.model->to(model_device_);
4 changes: 2 additions & 2 deletions src/csrc/torchfort.cpp
@@ -140,8 +140,8 @@ torchfort_result_t torchfort_create_distributed_model(const char* name, const ch
torchfort_create_model(name, config_fname, device);

// Set up distributed communicator
models[name].comm = std::shared_ptr<Comm>(new Comm(mpi_comm));
models[name].comm->initialize(models[name].model->device().is_cuda());
models[name].comm = std::make_shared<Comm>(mpi_comm, get_device(device));
models[name].comm->initialize();

// Broadcast initial model parameters from rank 0
for (auto& p : models[name].model->parameters()) {
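For context, a sketch of how the distributed C API entry point touched above is typically driven. The argument order is inferred from the hunk header, and the public header name, config file name, and success constant are assumptions drawn from the library's documented conventions, not from this diff:

```cpp
#include <mpi.h>
#include "torchfort.h"  // assumed public C header name

int main(int argc, char* argv[]) {
  MPI_Init(&argc, &argv);

  int rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  // Assumed signature: name, config file, MPI communicator, device index.
  // "model.yaml" and the one-GPU-per-rank mapping are illustrative only.
  torchfort_result_t res =
      torchfort_create_distributed_model("mymodel", "model.yaml", MPI_COMM_WORLD, rank);
  if (res != TORCHFORT_RESULT_SUCCESS) {
    MPI_Abort(MPI_COMM_WORLD, 1);
  }

  // ... training calls elided ...

  MPI_Finalize();
  return 0;
}
```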
7 changes: 7 additions & 0 deletions tests/supervised/CMakeLists.txt
@@ -3,6 +3,7 @@ cmake_minimum_required(VERSION 3.14)
set(test_targets
test_checkpoint
test_training
test_distributed_training
)

add_executable(test_checkpoint)
@@ -17,6 +18,12 @@ target_sources(test_training
test_training.cpp
)

add_executable(test_distributed_training)
target_sources(test_distributed_training
PRIVATE
test_distributed_training.cpp
)

find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED)

foreach(tgt ${test_targets})