csrc/include/custom_all_reduce.cuh (218 changes: 193 additions & 25 deletions)
@@ -49,7 +49,7 @@ typedef __hip_bfloat16 nv_bfloat16;
namespace aiter
{

constexpr int kMaxBlocks = 64;
constexpr int kMaxBlocks = 80;
// note: we don't want to use atomics for signals because peer atomics are not
// supported on PCIe links
struct Signal
@@ -315,7 +315,7 @@ namespace aiter

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
cross_device_reduce_1stage_naive(RankData *_dp, RankSignals sg,
#ifndef USE_ROCM
volatile
#endif
@@ -350,7 +350,7 @@ namespace aiter

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
cross_device_reduce_2stage_naive(RankData *_dp, RankSignals sg,
#ifndef USE_ROCM
volatile
#endif
@@ -403,6 +403,141 @@ namespace aiter
}
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
#ifndef USE_ROCM
volatile
#endif
Signal *self_sg,
T *__restrict__ result, int rank, int size)
{
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
constexpr int pack_size = packed_t<T>::P::size;
constexpr int tnum_gpu = 512 / ngpus;
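// staging buffer: one packed element per thread, laid out as ngpus tiles of tnum_gpu packs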
__shared__ T tmp_smem[tnum_gpu * ngpus * pack_size];
// note: we don't reorder the addresses, so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto dp = *_dp;

// each group of tnum_gpu threads loads the data of one GPU per iteration
int warp_id = threadIdx.x / tnum_gpu; // which GPU this thread loads from
int lane_id = threadIdx.x % tnum_gpu; // position within that group
start_sync<ngpus>(sg, self_sg, rank);
// do the actual reduction
for (int idx = blockIdx.x * tnum_gpu + lane_id; idx < size;
idx += gridDim.x * tnum_gpu)
{
*(reinterpret_cast<P*>(&tmp_smem[0]) + threadIdx.x) = ((const P**)&dp.ptrs[0])[warp_id][idx];
__syncthreads();
if (warp_id == 0)
{
A add_reg;
#pragma unroll
for (int i = 0; i < pack_size; ++i)
{
add_reg.data[i] = ck_tile::type_convert<float>(tmp_smem[threadIdx.x * pack_size + i]);
}
#pragma unroll
for (int i = 1; i < ngpus; ++i)
{
#pragma unroll
for (int j = 0; j < pack_size; ++j)
{
add_reg.data[j] += ck_tile::type_convert<float>(tmp_smem[512 * i + threadIdx.x * pack_size + j]);
}
}
P write_reg;
#pragma unroll
for (int i = 0; i < pack_size; ++i)
{
write_reg.data[i] = ck_tile::type_convert<T>(add_reg.data[i]);
}
((P *)result)[idx] = write_reg;
}
}
}
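
For orientation, a minimal sketch (not part of the diff) of the staging-buffer layout the kernel above relies on, assuming ngpus = 8 and fp16 data so that packed_t<T>::P holds 8 elements:

// illustration only, assumed values
constexpr int ngpus_ex     = 8;
constexpr int pack_size_ex = 8;               // fp16 elements per 16-byte pack
constexpr int tnum_gpu_ex  = 512 / ngpus_ex;  // 64 threads stage data for each GPU
// thread t writes one pack at pack index t, so the packs loaded from GPU i
// occupy element indices [i * tnum_gpu_ex * pack_size_ex, (i + 1) * tnum_gpu_ex * pack_size_ex)
static_assert(tnum_gpu_ex * pack_size_ex == 512, "tile stride read back by the accumulation loop");

Under these assumptions each GPU's tile spans 512 elements of tmp_smem, which matches the 512 * i offset that the warp_id == 0 threads read back.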

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
#ifndef USE_ROCM
volatile
#endif
Signal *self_sg,
T *__restrict__ result, int rank, int size)
{
constexpr int pack_size = packed_t<T>::P::size;
constexpr int tnum_gpu = 512 / ngpus;
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
__shared__ T tmp_smem[tnum_gpu * ngpus * pack_size];
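// as in the 1-stage kernel, each group of tnum_gpu threads handles one GPU's data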
int warp_id = threadIdx.x / tnum_gpu;
int lane_id = threadIdx.x % tnum_gpu;
int tid = blockIdx.x * tnum_gpu + lane_id;
int stride = gridDim.x * tnum_gpu;
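// each rank reduces a contiguous 1/ngpus slice; the last rank also takes the remainder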
int part = size / ngpus;
int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus;
const P *ptrs[ngpus];
P *tmps[ngpus];
#pragma unroll
for (int i = 0; i < ngpus; i++)
{
int target = (rank + i) % ngpus;
ptrs[i] = (const P *)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
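// tmps[0] is this rank's own tmp buffer (target == rank when i == 0)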
auto tmp_out = tmps[0];
start_sync<ngpus>(sg, self_sg, rank);
// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride)
{
*(reinterpret_cast<P*>(&tmp_smem[0]) + threadIdx.x) = ptrs[warp_id][idx];
__syncthreads();
// accumulate in the first tnum_gpu threads (warp_id == 0)
if (warp_id == 0)
{
A add_reg;
#pragma unroll
for (int i = 0; i < pack_size; ++i)
{
add_reg.data[i] = ck_tile::type_convert<float>(tmp_smem[pack_size * threadIdx.x + i]);
}
#pragma unroll
for (int i = 1; i < ngpus; ++i)
{
#pragma unroll
for (int j = 0; j < pack_size; ++j)
{
add_reg.data[j] += ck_tile::type_convert<float>(tmp_smem[i * 512 + pack_size * threadIdx.x + j]);
}
}
P write_reg;
#pragma unroll
for (int i = 0; i < pack_size; ++i)
{
write_reg.data[i] = ck_tile::type_convert<T>(add_reg.data[i]);
}
tmp_out[idx - start] = write_reg;
}
}
end_sync<ngpus>(sg, self_sg, rank);

// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from all
// ranks.
for (int idx = tid; idx < largest_part; idx += stride)
{
int dst_idx = (warp_id + rank) % ngpus * part + idx;
((P *)result)[dst_idx] = tmps[warp_id][idx];
}
}
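
A small worked example (illustration only, with assumed values) of the slice bookkeeping used by the 2-stage kernel above:

void example_partition(int rank)
{
    const int ngpus = 8, size = 1002;                   // assumed values, in packed elements
    int part = size / ngpus;                            // 125 packs reduced per rank
    int start = rank * part;                            // e.g. rank 3 reduces packs [375, 500)
    int end = rank == ngpus - 1 ? size : start + part;  // the last rank also absorbs the remainder
    int largest_part = part + size % ngpus;             // 127: bound of the stage-2 gather loop
    (void)start; (void)end; (void)largest_part;
}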

// fp8 quant all-reduce code start
template <typename T>
struct Fp16Filter
@@ -595,6 +730,7 @@ namespace aiter
return ret_val;
}


template <typename T, int pack_size, int ngpus>
DINLINE array_t<T, pack_size> multiGPUPackReduce(const array_t<T, pack_size> *ptrs[ngpus], int index)
{
@@ -977,32 +1113,64 @@ namespace aiter

RankData *ptrs = get_buffer_RD(stream, input);

auto bytes = size * sizeof(T);
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads);
int blocks = 16;
bool call_1stage = false;
bool call_2stage = false;
if (world_size_ == 2)
{
call_1stage = true;
}
else if (full_nvlink_)
{
if ((world_size_ <= 4 && bytes < 160 * 1024) || (world_size_ <= 8 && bytes < 80 * 1024))
{
call_1stage = true;
}
else
{
call_2stage = true;
}
}
if (call_1stage)
{
blocks = std::min(kMaxBlocks, (size + (threads / world_size_) - 1) / (threads / world_size_));
}
else if (call_2stage)
{
blocks = std::min(kMaxBlocks, (size / world_size_ + (threads / world_size_) - 1) / (threads / world_size_));
}
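
As a worked example of the launch-size math above (assumed values only: world_size_ = 8, threads = 512, kMaxBlocks = 80, and an 8192-pack fp16 input, i.e. 128 KiB of payload, so the 2-stage path is taken when full_nvlink_ is set):

int example_blocks()
{
    int world_size = 8, threads = 512, k_max_blocks = 80;       // assumed values
    int size = 8192;                                             // packed elements after size /= d
    int per_iter = threads / world_size;                         // 64 elements covered per block per loop iteration
    int blocks = (size / world_size + per_iter - 1) / per_iter;  // ceil(1024 / 64) = 16
    return blocks < k_max_blocks ? blocks : k_max_blocks;        // min(80, 16) = 16
}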

#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size);
#define REDUCE_CASE(ngpus) \
case ngpus: \
{ \
if (world_size_ == 2) \
{ \
KL(ngpus, cross_device_reduce_1stage); \
} \
else if (full_nvlink_) \
{ \
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
(world_size_ <= 8 && bytes < 256 * 1024)) \
{ \
KL(ngpus, cross_device_reduce_1stage); \
} \
else \
{ \
KL(ngpus, cross_device_reduce_2stage); \
} \
} \
break; \

#define dispatch(ngpus, name) \
do \
{ \
if (bytes % 128 == 0) \
{ \
KL(ngpus, name) \
} \
else \
{ \
KL(ngpus, name##_naive) \
} \
} while(0)
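
Note that the original kernels, now carrying the _naive suffix, remain the fallback whenever the payload size is not a multiple of 128 bytes. For reference, dispatch(8, cross_device_reduce_2stage) expands to roughly:

do
{
    if (bytes % 128 == 0)
    {
        cross_device_reduce_2stage<T, 8><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);
    }
    else
    {
        cross_device_reduce_2stage_naive<T, 8><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);
    }
} while (0);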

#define REDUCE_CASE(ngpus) \
case ngpus: \
{ \
if (call_1stage) \
{ \
dispatch(ngpus, cross_device_reduce_1stage); \
} \
else if (call_2stage) \
{ \
dispatch(ngpus, cross_device_reduce_2stage); \
} \
break; \
}

switch (world_size_)