From 5bd0f0dc598f23f496714f0ee151e63e461443d6 Mon Sep 17 00:00:00 2001
From: Josh Romero
Date: Wed, 10 Jul 2024 17:16:28 -0700
Subject: [PATCH 1/2] Use MPI_Alltoall instead of MPI_Alltoallv if able.

---
 include/internal/comm_routines.h | 55 ++++++++++++++++++++++++++++----
 1 file changed, 48 insertions(+), 7 deletions(-)

diff --git a/include/internal/comm_routines.h b/include/internal/comm_routines.h
index 6d625a2..0f984e6 100644
--- a/include/internal/comm_routines.h
+++ b/include/internal/comm_routines.h
@@ -53,6 +53,40 @@ static inline MPI_Datatype getMpiDataType(cuda::std::complex<float>) { return MP
 static inline MPI_Datatype getMpiDataType(cuda::std::complex<double>) { return MPI_C_DOUBLE_COMPLEX; }
 template <typename T> static inline MPI_Datatype getMpiDataType() { return getMpiDataType(T(0)); }
 
+static inline bool canUseMpiAlltoall(const std::vector<int>& send_counts,
+                                     const std::vector<int>& send_offsets,
+                                     const std::vector<int>& recv_counts,
+                                     const std::vector<int>& recv_offsets) {
+  auto scount = send_counts[0];
+  auto rcount = recv_counts[0];
+  // Check that send and recv counts are constants
+  for (int i = 1; i < send_counts.size(); ++i) {
+    if (send_counts[i] != scount) {
+      return false;
+    }
+  }
+  for (int i = 1; i < recv_counts.size(); ++i) {
+    if (recv_counts[i] != rcount) {
+      return false;
+    }
+  }
+
+  // Check that offsets are contiguous and equal to counts
+  for (int i = 0; i < send_offsets.size(); ++i) {
+    if (send_offsets[i] != i * scount) {
+      return false;
+    }
+  }
+  for (int i = 0; i < recv_offsets.size(); ++i) {
+    if (recv_offsets[i] != i * rcount) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+
 #ifdef ENABLE_NVSHMEM
 #define CUDECOMP_NVSHMEM_CHUNK_SZ (static_cast<size_t>(1024 * 1024 * 1024))
 template <typename T>
@@ -223,14 +257,21 @@ cudecompAlltoall(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_
         (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info.mpi_comm : grid_desc->col_comm_info.mpi_comm;
     int self_rank = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info.rank : grid_desc->col_comm_info.rank;
 
-    // Self-copy with cudaMemcpy
-    CHECK_CUDA(cudaMemcpyAsync(recv_buff + recv_offsets[self_rank], send_buff + send_offsets[self_rank],
-                               send_counts[self_rank] * sizeof(T), cudaMemcpyDeviceToDevice, stream));
+    bool use_alltoall = canUseMpiAlltoall(send_counts, send_offsets, recv_counts, recv_offsets);
 
-    send_counts_mod[self_rank] = 0;
-    recv_counts_mod[self_rank] = 0;
-    CHECK_MPI(MPI_Alltoallv(send_buff, send_counts_mod.data(), send_offsets.data(), getMpiDataType<T>(), recv_buff,
-                            recv_counts_mod.data(), recv_offsets.data(), getMpiDataType<T>(), comm));
+    if (use_alltoall) {
+      CHECK_MPI(MPI_Alltoall(send_buff, send_counts[0], getMpiDataType<T>(), recv_buff, recv_counts[0],
+                             getMpiDataType<T>(), comm));
+    } else {
+      // Self-copy with cudaMemcpy
+      CHECK_CUDA(cudaMemcpyAsync(recv_buff + recv_offsets[self_rank], send_buff + send_offsets[self_rank],
+                                 send_counts[self_rank] * sizeof(T), cudaMemcpyDeviceToDevice, stream));
+
+      send_counts_mod[self_rank] = 0;
+      recv_counts_mod[self_rank] = 0;
+      CHECK_MPI(MPI_Alltoallv(send_buff, send_counts_mod.data(), send_offsets.data(), getMpiDataType<T>(), recv_buff,
+                              recv_counts_mod.data(), recv_offsets.data(), getMpiDataType<T>(), comm));
+    }
     break;
   }
   default: {

From c40cab80851285374f4fbccfdb61b3ad6444fbfd Mon Sep 17 00:00:00 2001
From: Josh Romero
Date: Fri, 11 Jul 2025 11:04:03 -0700
Subject: [PATCH 2/2] Minor update.

---
 include/internal/comm_routines.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/include/internal/comm_routines.h b/include/internal/comm_routines.h
index 0f984e6..1521419 100644
--- a/include/internal/comm_routines.h
+++ b/include/internal/comm_routines.h
@@ -251,8 +251,6 @@ cudecompAlltoall(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_
   }
   case CUDECOMP_TRANSPOSE_COMM_MPI_A2A: {
     CHECK_CUDA(cudaStreamSynchronize(stream));
-    auto send_counts_mod = send_counts;
-    auto recv_counts_mod = recv_counts;
     auto comm =
         (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info.mpi_comm : grid_desc->col_comm_info.mpi_comm;
     int self_rank = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info.rank : grid_desc->col_comm_info.rank;
@@ -267,6 +265,8 @@ cudecompAlltoall(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_
       CHECK_CUDA(cudaMemcpyAsync(recv_buff + recv_offsets[self_rank], send_buff + send_offsets[self_rank],
                                  send_counts[self_rank] * sizeof(T), cudaMemcpyDeviceToDevice, stream));
 
+      auto send_counts_mod = send_counts;
+      auto recv_counts_mod = recv_counts;
       send_counts_mod[self_rank] = 0;
       recv_counts_mod[self_rank] = 0;
       CHECK_MPI(MPI_Alltoallv(send_buff, send_counts_mod.data(), send_offsets.data(), getMpiDataType<T>(), recv_buff,
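
Note (not part of the patch series): the check introduced in PATCH 1/2 reduces MPI_Alltoallv to
MPI_Alltoall whenever every rank exchanges the same element count and the offsets are densely
packed (offset[i] == i * count). Below is a minimal standalone MPI sketch of the same predicate
and dispatch, assuming int counts/offsets as MPI requires; the harness (count value, buffers,
main) is illustrative and not taken from cuDecomp.

// Build: mpicxx -o a2a_demo a2a_demo.cpp ; Run: mpirun -np 4 ./a2a_demo
#include <mpi.h>
#include <cstdio>
#include <vector>

// Same predicate as the patch: all counts equal, offsets densely packed.
static bool canUseMpiAlltoall(const std::vector<int>& send_counts,
                              const std::vector<int>& send_offsets,
                              const std::vector<int>& recv_counts,
                              const std::vector<int>& recv_offsets) {
  int scount = send_counts[0];
  int rcount = recv_counts[0];
  for (size_t i = 1; i < send_counts.size(); ++i)
    if (send_counts[i] != scount) return false;
  for (size_t i = 1; i < recv_counts.size(); ++i)
    if (recv_counts[i] != rcount) return false;
  for (size_t i = 0; i < send_offsets.size(); ++i)
    if (send_offsets[i] != static_cast<int>(i) * scount) return false;
  for (size_t i = 0; i < recv_offsets.size(); ++i)
    if (recv_offsets[i] != static_cast<int>(i) * rcount) return false;
  return true;
}

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  int rank, nranks;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &nranks);

  const int count = 4; // elements exchanged with each peer (uniform case)
  std::vector<int> counts(nranks, count), offsets(nranks);
  for (int i = 0; i < nranks; ++i) offsets[i] = i * count;

  std::vector<double> send(nranks * count, static_cast<double>(rank));
  std::vector<double> recv(nranks * count, 0.0);

  if (canUseMpiAlltoall(counts, offsets, counts, offsets)) {
    // Uniform counts and contiguous offsets: the dense collective suffices.
    MPI_Alltoall(send.data(), count, MPI_DOUBLE,
                 recv.data(), count, MPI_DOUBLE, MPI_COMM_WORLD);
  } else {
    // General case: fall back to the vector variant.
    MPI_Alltoallv(send.data(), counts.data(), offsets.data(), MPI_DOUBLE,
                  recv.data(), counts.data(), offsets.data(), MPI_DOUBLE, MPI_COMM_WORLD);
  }

  // recv segment i now holds `count` copies of i (each rank sent its own rank id).
  if (rank == 0) printf("recv[last] = %g (expect %d)\n", recv.back(), nranks - 1);
  MPI_Finalize();
  return 0;
}

Both branches deliver identical results here; the dense form simply drops the per-rank
count/displacement arrays, which lets MPI implementations select specialized uniform
all-to-all algorithms instead of the more general (and often slower) vector path.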