59 changes: 50 additions & 9 deletions include/internal/comm_routines.h
@@ -53,6 +53,40 @@ static inline MPI_Datatype getMpiDataType(cuda::std::complex<float>) { return MPI_C_FLOAT_COMPLEX; }
static inline MPI_Datatype getMpiDataType(cuda::std::complex<double>) { return MPI_C_DOUBLE_COMPLEX; }
template <typename T> static inline MPI_Datatype getMpiDataType() { return getMpiDataType(T(0)); }
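// Illustrative usage sketch (not part of this diff): the templated getMpiDataType<T>()
// forwards to the matching overload above; the helper name below is hypothetical.
static inline MPI_Datatype getMpiDataTypeExample() {
  return getMpiDataType<cuda::std::complex<double>>(); // resolves to MPI_C_DOUBLE_COMPLEX
}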

static inline bool canUseMpiAlltoall(const std::vector<comm_count_t>& send_counts,
const std::vector<comm_count_t>& send_offsets,
const std::vector<comm_count_t>& recv_counts,
const std::vector<comm_count_t>& recv_offsets) {
auto scount = send_counts[0];
auto rcount = recv_counts[0];
// Check that send and recv counts are constant across ranks
for (int i = 1; i < send_counts.size(); ++i) {
if (send_counts[i] != scount) {
return false;
}
}
for (int i = 1; i < recv_counts.size(); ++i) {
if (recv_counts[i] != rcount) {
return false;
}
}

// Check that offsets describe the packed layout, i.e. offset[i] == i * count
for (int i = 0; i < send_offsets.size(); ++i) {
if (send_offsets[i] != i * scount) {
return false;
}
}
for (int i = 0; i < recv_offsets.size(); ++i) {
if (recv_offsets[i] != i * rcount) {
return false;
}
}

return true;
}
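
// Illustrative usage sketch (not part of this diff): canUseMpiAlltoall returns true only
// when every rank sends and receives the same count and the offsets form the packed
// layout offset[i] == i * count, i.e. exactly the layout MPI_Alltoall assumes.
// Assumes comm_count_t is an integer type; the function name below is hypothetical.
static inline void canUseMpiAlltoallExample() {
  std::vector<comm_count_t> counts{4, 4, 4, 4};
  std::vector<comm_count_t> offsets{0, 4, 8, 12};
  bool uniform = canUseMpiAlltoall(counts, offsets, counts, offsets); // true: fast path applies

  std::vector<comm_count_t> ragged_counts{4, 5, 4, 3};
  std::vector<comm_count_t> ragged_offsets{0, 4, 9, 13};
  bool ragged = canUseMpiAlltoall(ragged_counts, ragged_offsets, ragged_counts, ragged_offsets); // false: fall back to MPI_Alltoallv
  (void)uniform;
  (void)ragged;
}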


#ifdef ENABLE_NVSHMEM
#define CUDECOMP_NVSHMEM_CHUNK_SZ (static_cast<size_t>(1024 * 1024 * 1024))
template <typename T>
@@ -217,20 +251,27 @@ cudecompAlltoall(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_desc,
}
case CUDECOMP_TRANSPOSE_COMM_MPI_A2A: {
CHECK_CUDA(cudaStreamSynchronize(stream));
auto send_counts_mod = send_counts;
auto recv_counts_mod = recv_counts;
auto comm =
(comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info.mpi_comm : grid_desc->col_comm_info.mpi_comm;
int self_rank = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info.rank : grid_desc->col_comm_info.rank;

// Self-copy with cudaMemcpy
CHECK_CUDA(cudaMemcpyAsync(recv_buff + recv_offsets[self_rank], send_buff + send_offsets[self_rank],
send_counts[self_rank] * sizeof(T), cudaMemcpyDeviceToDevice, stream));
bool use_alltoall = canUseMpiAlltoall(send_counts, send_offsets, recv_counts, recv_offsets);

send_counts_mod[self_rank] = 0;
recv_counts_mod[self_rank] = 0;
CHECK_MPI(MPI_Alltoallv(send_buff, send_counts_mod.data(), send_offsets.data(), getMpiDataType<T>(), recv_buff,
recv_counts_mod.data(), recv_offsets.data(), getMpiDataType<T>(), comm));
if (use_alltoall) {
CHECK_MPI(MPI_Alltoall(send_buff, send_counts[0], getMpiDataType<T>(), recv_buff, recv_counts[0],
getMpiDataType<T>(), comm));
} else {
// Self-copy with cudaMemcpy
CHECK_CUDA(cudaMemcpyAsync(recv_buff + recv_offsets[self_rank], send_buff + send_offsets[self_rank],
send_counts[self_rank] * sizeof(T), cudaMemcpyDeviceToDevice, stream));

auto send_counts_mod = send_counts;
auto recv_counts_mod = recv_counts;
send_counts_mod[self_rank] = 0;
recv_counts_mod[self_rank] = 0;
CHECK_MPI(MPI_Alltoallv(send_buff, send_counts_mod.data(), send_offsets.data(), getMpiDataType<T>(), recv_buff,
recv_counts_mod.data(), recv_offsets.data(), getMpiDataType<T>(), comm));
}
break;
}
default: {
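Below is a minimal standalone sketch (not part of this PR) of the equivalence the new fast path relies on: when every rank uses the same count and the packed offsets offset[i] == i * count, MPI_Alltoallv reduces to MPI_Alltoall on the same buffers, which is what canUseMpiAlltoall detects. The buffer contents and the use of MPI_DOUBLE here are illustrative assumptions.

#include <mpi.h>
#include <vector>

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  int rank = 0, nranks = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &nranks);

  const int count = 4; // uniform per-rank element count
  std::vector<double> send(count * nranks, static_cast<double>(rank));
  std::vector<double> recv_a2a(count * nranks), recv_a2av(count * nranks);

  // Fast path: uniform counts with the implicit packed layout.
  MPI_Alltoall(send.data(), count, MPI_DOUBLE, recv_a2a.data(), count, MPI_DOUBLE, MPI_COMM_WORLD);

  // General path: the same exchange with counts and offsets spelled out explicitly.
  std::vector<int> counts(nranks, count), offsets(nranks);
  for (int i = 0; i < nranks; ++i) offsets[i] = i * count;
  MPI_Alltoallv(send.data(), counts.data(), offsets.data(), MPI_DOUBLE,
                recv_a2av.data(), counts.data(), offsets.data(), MPI_DOUBLE, MPI_COMM_WORLD);

  // recv_a2a and recv_a2av now hold identical data on every rank.
  MPI_Finalize();
  return 0;
}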