diff --git a/include/internal/comm_routines.h b/include/internal/comm_routines.h
index 6d625a2..1521419 100644
--- a/include/internal/comm_routines.h
+++ b/include/internal/comm_routines.h
@@ -53,6 +53,40 @@ static inline MPI_Datatype getMpiDataType(cuda::std::complex<float>) { return MP
 static inline MPI_Datatype getMpiDataType(cuda::std::complex<double>) { return MPI_C_DOUBLE_COMPLEX; }
 template <typename T> static inline MPI_Datatype getMpiDataType() { return getMpiDataType(T(0)); }
 
+static inline bool canUseMpiAlltoall(const std::vector<int>& send_counts,
+                                     const std::vector<int>& send_offsets,
+                                     const std::vector<int>& recv_counts,
+                                     const std::vector<int>& recv_offsets) {
+  auto scount = send_counts[0];
+  auto rcount = recv_counts[0];
+  // Check that send and recv counts are constant across ranks
+  for (int i = 1; i < send_counts.size(); ++i) {
+    if (send_counts[i] != scount) {
+      return false;
+    }
+  }
+  for (int i = 1; i < recv_counts.size(); ++i) {
+    if (recv_counts[i] != rcount) {
+      return false;
+    }
+  }
+
+  // Check that offsets are contiguous and match the counts (offset[i] == i * count)
+  for (int i = 0; i < send_offsets.size(); ++i) {
+    if (send_offsets[i] != i * scount) {
+      return false;
+    }
+  }
+  for (int i = 0; i < recv_offsets.size(); ++i) {
+    if (recv_offsets[i] != i * rcount) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+
 #ifdef ENABLE_NVSHMEM
 #define CUDECOMP_NVSHMEM_CHUNK_SZ (static_cast<size_t>(1024 * 1024 * 1024))
 template <typename T>
@@ -217,20 +251,27 @@ cudecompAlltoall(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_
     }
     case CUDECOMP_TRANSPOSE_COMM_MPI_A2A: {
       CHECK_CUDA(cudaStreamSynchronize(stream));
-      auto send_counts_mod = send_counts;
-      auto recv_counts_mod = recv_counts;
       auto comm =
           (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info.mpi_comm : grid_desc->col_comm_info.mpi_comm;
       int self_rank =
           (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info.rank : grid_desc->col_comm_info.rank;
-      // Self-copy with cudaMemcpy
-      CHECK_CUDA(cudaMemcpyAsync(recv_buff + recv_offsets[self_rank], send_buff + send_offsets[self_rank],
-                                 send_counts[self_rank] * sizeof(T), cudaMemcpyDeviceToDevice, stream));
+      bool use_alltoall = canUseMpiAlltoall(send_counts, send_offsets, recv_counts, recv_offsets);
 
-      send_counts_mod[self_rank] = 0;
-      recv_counts_mod[self_rank] = 0;
-      CHECK_MPI(MPI_Alltoallv(send_buff, send_counts_mod.data(), send_offsets.data(), getMpiDataType<T>(), recv_buff,
-                              recv_counts_mod.data(), recv_offsets.data(), getMpiDataType<T>(), comm));
+      if (use_alltoall) {
+        CHECK_MPI(MPI_Alltoall(send_buff, send_counts[0], getMpiDataType<T>(), recv_buff, recv_counts[0],
+                               getMpiDataType<T>(), comm));
+      } else {
+        // Self-copy with cudaMemcpyAsync
+        CHECK_CUDA(cudaMemcpyAsync(recv_buff + recv_offsets[self_rank], send_buff + send_offsets[self_rank],
+                                   send_counts[self_rank] * sizeof(T), cudaMemcpyDeviceToDevice, stream));
+
+        auto send_counts_mod = send_counts;
+        auto recv_counts_mod = recv_counts;
+        send_counts_mod[self_rank] = 0;
+        recv_counts_mod[self_rank] = 0;
+        CHECK_MPI(MPI_Alltoallv(send_buff, send_counts_mod.data(), send_offsets.data(), getMpiDataType<T>(), recv_buff,
+                                recv_counts_mod.data(), recv_offsets.data(), getMpiDataType<T>(), comm));
+      }
       break;
     }
     default: {
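
Note on the fast path: canUseMpiAlltoall returns true only when every rank sends and receives the same count and the offsets form a dense prefix (offset[i] == i * count), which is exactly the layout MPI_Alltoall assumes. In that case the self-rank block is moved by the collective itself, so the device-to-device self-copy and the zeroed self-counts are only needed on the MPI_Alltoallv fallback. Below is a minimal standalone sketch of the predicate; uniformAndContiguous is a hypothetical stand-in mirroring the patch's check for a single counts/offsets pair, and the example values are illustrative:

#include <cstdio>
#include <vector>

// Mirrors the patch's canUseMpiAlltoall check for one counts/offsets pair:
// every count equal, every offset the dense prefix i * count.
static bool uniformAndContiguous(const std::vector<int>& counts, const std::vector<int>& offsets) {
  for (int i = 0; i < static_cast<int>(counts.size()); ++i) {
    if (counts[i] != counts[0]) return false;      // count varies across ranks
    if (offsets[i] != i * counts[0]) return false; // padding, gap, or overlap
  }
  return true;
}

int main() {
  // 4 ranks, 8 elements per peer, packed back to back: MPI_Alltoall eligible.
  std::vector<int> counts{8, 8, 8, 8}, offsets{0, 8, 16, 24};
  std::printf("uniform:     %s\n", uniformAndContiguous(counts, offsets) ? "MPI_Alltoall" : "MPI_Alltoallv");

  // Uneven split (10 elements over 4 ranks): must keep MPI_Alltoallv.
  std::vector<int> counts2{3, 3, 2, 2}, offsets2{0, 3, 6, 8};
  std::printf("non-uniform: %s\n", uniformAndContiguous(counts2, offsets2) ? "MPI_Alltoall" : "MPI_Alltoallv");
  return 0;
}

Compiled and run on its own (e.g. g++ -std=c++17 sketch.cpp && ./a.out, with the file name being arbitrary), this prints MPI_Alltoall for the first pair and MPI_Alltoallv for the second, matching the branch the patch would take.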