77 changes: 25 additions & 52 deletions colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu
@@ -1,12 +1,13 @@
-#include "block_reduce.h"
-#include <cub/cub.cuh>
#include <cuda.h>
#include <cuda_fp16.h>
#include <torch/extension.h>

+#include <cub/cub.cuh>
+
+#include "block_reduce.h"
+
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
-
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;

@@ -28,7 +29,6 @@ __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {

template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
-
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;

@@ -51,7 +51,6 @@ __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
const int cols) {
-
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;

@@ -75,7 +74,6 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
const int cols) {
-
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;

@@ -105,7 +103,6 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
const int cols) {
-
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;

@@ -134,7 +131,6 @@ __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
T *weight_grad, const T weight, const int cols) {
-
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;

@@ -164,15 +160,13 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,

blockReduce<ReduceType::kSum, 1>(&thread_sum);

-if (threadIdx.x == 0)
-*weight_grad = static_cast<T>(thread_sum);
+if (threadIdx.x == 0) *weight_grad = static_cast<T>(thread_sum);
}

template <typename T, int block_size, int pack_size>
__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
const T weight1, const T weight2,
const int cols) {
-
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;

@@ -204,7 +198,6 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
T *tks_row1, T *tks_row2, T *weight_grad1,
T *weight_grad2, const T weight1,
const T weight2, const int cols) {
-
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;

@@ -251,7 +244,6 @@ template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
const int cols, const int indicator1,
const int indicator2) {
-
if (indicator1 != 0 && indicator2 != 0)
moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
cols);
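For reference, a sketch of the full routing pattern this selector implements (an assumption for illustration; the remaining branches are collapsed in this view, so only the two-expert case is visible above):

// Sketch (assumed, illustrative only): each row is copied to both expert
// buffers, exactly one of them, or neither, depending on the two mask
// indicators for that row.
if (indicator1 != 0 && indicator2 != 0)
  moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2, cols);
else if (indicator1 != 0)
  moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row1, cols);
else if (indicator2 != 0)
  moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row2, cols);
// else: the row is not routed to any expert and nothing is copied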
@@ -267,7 +259,6 @@ template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
const int cols, const int indicator1,
const int indicator2) {
-
if (indicator1 != 0 && indicator2 != 0)
moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
cols);
@@ -283,7 +274,6 @@ template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
int *mask1, int *mask2, int *dest1,
int *dest2, const int h) {
-
int row = blockIdx.x;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
moe_dpch_fwd_selector<T, block_size, pack_size>(
@@ -295,7 +285,6 @@ template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1,
int *mask2, int *dest1, int *dest2,
const int h) {
-
int row = blockIdx.x;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
moe_dpch_bwd_selector<T, block_size, pack_size>(
@@ -310,7 +299,6 @@ __device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
const int cols, const T weight1,
const T weight2, const int indicator1,
const int indicator2) {
-
if (indicator1 != 0 && indicator2 != 0)
moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
weight1, weight2, cols);
@@ -328,7 +316,6 @@ __device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
T *wt_grad1, T *wt_grad2, const T weight1,
const T weight2, const int indicator1,
const int indicator2) {
-
if (indicator1 != 0 && indicator2 != 0)
moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
tks_row1, tks_row2, wt_grad1,
@@ -348,7 +335,6 @@ __global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
T *logits, int *mask1, int *mask2, int *dest1,
int *dest2, const int e, const int c,
const int h) {
-
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
T *row_log = logits + (row * e);
@@ -363,7 +349,6 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
T *logits, T *logits_grad, int *mask1,
int *mask2, int *dest1, int *dest2,
const int e, const int c, const int h) {
-
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
@@ -379,7 +364,6 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
template <int block_size, int pack_size>
__global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
const int e) {
-
assert(s % pack_size == 0);
constexpr int bpack_size = block_size * pack_size;
int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;
@@ -426,8 +410,7 @@ __global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
}
__syncthreads();

-if (tid == 0)
-temp[0] = temp[block_size];
+if (tid == 0) temp[0] = temp[block_size];
__syncthreads();

if (idx + tps < s) {
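The hunk above is the carry step of the block-level cumulative sum: once a tile has been scanned into shared memory, thread 0 moves the tile total from temp[block_size] down to temp[0], where it seeds the scan of the next tile. A minimal standalone sketch of this tile-carry idiom (an assumption for illustration, with a serial in-tile scan for clarity; the real kernel scans in parallel and packs pack_size elements per thread):

// One block scans a column of length s in tiles of block_size elements;
// temp[0] always holds the running total carried in from the previous tile.
template <int block_size>
__global__ void tile_carry_cumsum_sketch(const int *in, int *out, const int s) {
  __shared__ int temp[block_size + 1];
  if (threadIdx.x == 0) temp[0] = 0;  // no carry before the first tile

  for (int base = 0; base < s; base += block_size) {
    __syncthreads();  // finish reading the previous tile before overwriting
    if (threadIdx.x == 0) {
      int run = temp[0];
      for (int i = 0; i < block_size && base + i < s; ++i) {
        run += in[base + i];
        temp[i + 1] = run;
      }
      temp[0] = run;  // carry the tile total into the next iteration
    }
    __syncthreads();
    const int idx = base + threadIdx.x;
    if (idx < s) out[idx] = temp[threadIdx.x + 1];
  }
}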
@@ -453,7 +436,6 @@ template <typename T>
void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
int *mask2, int *dest1, int *dest2, const int s,
const int h) {
-
if (h < 256)
moe_dpch_fwd_kernel<T, 32, 4>
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
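The launcher picks its block size from the hidden size h: with pack_size = 4, a 32-thread block moves 128 elements per pass, so rows narrower than 256 columns finish in at most two passes. A sketch of the apparent selection rule (an assumption inferred from the visible h < 256 branch and the "else // if (h < 512)" comment later in the file; branches for wider rows are collapsed in this view):

#include <initializer_list>

// Assumed heuristic: smallest block such that two packed passes cover a row.
constexpr int kPackSize = 4;
inline int pick_block_size(const int h) {
  for (const int b : {32, 64, 128, 256}) {
    if (h < b * kPackSize * 2) return b;  // e.g. h < 256 -> 32 threads
  }
  return 512;  // widest rows fall through to the largest block (assumed)
}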
@@ -474,7 +456,6 @@ void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
template <typename T>
void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
int *dest1, int *dest2, const int s, const int h) {
-
if (h < 256)
moe_dpch_bwd_kernel<T, 32, 4>
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
@@ -496,7 +477,6 @@ template <typename T>
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
int *mask1, int *mask2, int *dest1, int *dest2,
const int s, const int e, const int c, const int h) {
-
if (h < 256)
moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
logits, mask1, mask2, dest1, dest2,
@@ -524,12 +504,11 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
T *logits_grad, int *mask1, int *mask2, int *dest1,
int *dest2, const int s, const int e, const int c,
const int h) {
-
if (h < 256)
moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
logits, logits_grad, mask1, mask2,
dest1, dest2, e, c, h);
-else // if (h < 512)
+else  // if (h < 512)
moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>(tokens_grad, expert_grad, tks,
logits, logits_grad, mask1, mask2,
dest1, dest2, e, c, h);
@@ -544,7 +523,6 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
}

void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
-
if (s <= 256)
cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
else if (s <= 512)
@@ -559,27 +537,26 @@ void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {

// API FUNCTIONS --------------------------------

-#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
-switch (TYPE) { \
-case at::ScalarType::Float: { \
-using scalar_t = float; \
-__VA_ARGS__; \
-break; \
-} \
-case at::ScalarType::Half: { \
-using scalar_t = at::Half; \
-__VA_ARGS__; \
-break; \
-} \
-default: \
-AT_ERROR(#NAME, " not implemented yet for specific data type."); \
+#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
+switch (TYPE) { \
+case at::ScalarType::Float: { \
+using scalar_t = float; \
+__VA_ARGS__; \
+break; \
+} \
+case at::ScalarType::Half: { \
+using scalar_t = at::Half; \
+__VA_ARGS__; \
+break; \
+} \
+default: \
+AT_ERROR(#NAME, " not implemented yet for specific data type."); \
}
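The reformatted macro expands to a switch on the runtime dtype, binding scalar_t to float or at::Half and running the body once. A hypothetical call site (illustrative only; src, res, and the mask/dest pointers are stand-ins, not taken from this diff):

// Dispatch on the tensor's dtype, then launch with the matching scalar type.
DISPATCH_FLOAT_AND_HALF(
    src.scalar_type(), moe_dispatch_forward,
    moe_dpch_fwd_launch<scalar_t>(src.data_ptr<scalar_t>(),
                                  res.data_ptr<scalar_t>(), mask1, mask2,
                                  dest1, dest2, s, h));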

torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask,
torch::Tensor dest_idx) {
-
assert(h % 16 == 0);
auto res = torch::zeros(
{ec, h},
@@ -601,7 +578,6 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx) {
-
assert(h % 16 == 0);
auto res = torch::zeros(
{s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
@@ -622,7 +598,6 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx) {
-
assert(h % 16 == 0);
assert(expert_tokens.dtype() == logits.dtype());

@@ -643,11 +618,10 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
return res;
}

-std::vector<torch::Tensor>
-moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
-torch::Tensor expert_tokens, torch::Tensor logits,
-torch::Tensor mask, torch::Tensor dest_idx) {
-
+std::vector<torch::Tensor> moe_combine_cuda_backward(
+int s, int e, int c, int h, torch::Tensor tokens_grad,
+torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
+torch::Tensor dest_idx) {
assert(h % 16 == 0);
assert(tokens_grad.dtype() == expert_tokens.dtype());
assert(expert_tokens.dtype() == logits.dtype());
Expand All @@ -673,7 +647,6 @@ moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
}

torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
-
assert(mask.dim() == 2);
assert(mask.dtype() == torch::kInt32);
