Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 16 additions & 30 deletions colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#include <cooperative_groups.h>

#include <chrono>
#include <ctime>

#include "kernels.h"

#include <cooperative_groups.h>

namespace cg = cooperative_groups;

curandStatePhilox4_32_10_t *curandstate;
Expand Down Expand Up @@ -165,8 +165,7 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;

if (i * 4 >= total_count)
return;
if (i * 4 >= total_count) return;

curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
Expand Down Expand Up @@ -202,8 +201,7 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,

int i = blockIdx.x * blockDim.x + threadIdx.x;

if (i * 8 >= total_count)
return;
if (i * 8 >= total_count) return;

curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
Expand Down Expand Up @@ -261,8 +259,7 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;

if (i * 4 >= total_count)
return;
if (i * 4 >= total_count) return;

uint8_t m[4];

Expand All @@ -289,8 +286,7 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,

int i = blockIdx.x * blockDim.x + threadIdx.x;

if (i * 8 >= total_count)
return;
if (i * 8 >= total_count) return;

float4 *out4 = reinterpret_cast<float4 *>(out);
const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
Expand Down Expand Up @@ -380,8 +376,7 @@ __global__ void ls_dropout_res_bias_kernel(
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;

if (i * 4 >= total_count)
return;
if (i * 4 >= total_count) return;

curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
Expand Down Expand Up @@ -424,8 +419,7 @@ __global__ void ls_dropout_res_bias_kernel(

int i = blockIdx.x * blockDim.x + threadIdx.x;

if (i * 8 >= total_count)
return;
if (i * 8 >= total_count) return;

curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
Expand Down Expand Up @@ -565,11 +559,9 @@ __global__ void ls_dropout_bias_bwd_kernel(
}
__syncthreads();

for (int i = 1; i < 32; i <<= 1)
sum += g.shfl_down(sum, i);
for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);

if (y == 0)
tile[0][x] = sum;
if (y == 0) tile[0][x] = sum;
__syncthreads();

if (threadIdx.x < 8) {
Expand Down Expand Up @@ -621,11 +613,9 @@ __global__ void ls_dropout_bias_bwd_kernel(
}
__syncthreads();

for (int i = 1; i < WARP_SIZE; i <<= 1)
sum += g.shfl_down(sum, i);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

if (y == 0)
tile[0][x] = sum;
if (y == 0) tile[0][x] = sum;
__syncthreads();

if (threadIdx.x < 8) {
Expand Down Expand Up @@ -689,8 +679,7 @@ __global__ void ls_dropout_act_bias_kernel(
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;

if (i * 4 >= total_count)
return;
if (i * 4 >= total_count) return;

curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
Expand Down Expand Up @@ -735,8 +724,7 @@ __global__ void ls_dropout_act_bias_kernel(

int i = blockIdx.x * blockDim.x + threadIdx.x;

if (i * 8 >= total_count)
return;
if (i * 8 >= total_count) return;

curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
Expand Down Expand Up @@ -897,11 +885,9 @@ __global__ void ls_dropout_act_bias_bwd_kernel(
float sum = tile[threadIdx.y][threadIdx.x];
__syncthreads();

for (int i = 1; i < WARP_SIZE; i <<= 1)
sum += g.shfl_down(sum, i);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

if (threadIdx.x == 0)
tile[0][threadIdx.y] = sum;
if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
__syncthreads();

if (threadIdx.y == 0) {
Expand Down