69 changes: 32 additions & 37 deletions colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu
@@ -16,7 +16,8 @@
 #define BLOCK_SIZE 512
 #define ILP 4

-template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
+template <typename T>
+__device__ __forceinline__ bool is_aligned(T *p) {
   return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
 }

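Note (reviewer aside, not part of the diff): is_aligned() gates the vectorized fast path. With ILP = 4 and 4-byte elements, a pointer must be 16-byte aligned before the kernel may move ILP elements per memory transaction through load_store(). A minimal sketch of the pattern, assuming load_store(dst, src, dst_idx, src_idx) copies ILP consecutive elements as one wide access and that n is a multiple of ILP on the fast path; copy_chunk itself is hypothetical:

// Hypothetical illustration of the alignment gate; not code from this PR.
__device__ void copy_chunk(float *dst, float *src, int n) {
  if (is_aligned(dst) && is_aligned(src)) {
    for (int i = threadIdx.x; i * ILP < n; i += blockDim.x)
      load_store(dst, src, i, i);  // one 16-byte transaction per iteration
  } else {
    for (int i = threadIdx.x; i < n; i += blockDim.x)
      dst[i] = src[i];  // scalar fallback for unaligned buffers
  }
}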
@@ -28,11 +29,12 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
   ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
 }

-template <typename x_t> struct L2NormFunctor {
-  __device__ __forceinline__ void
-  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
-             float *output, float *output_per_tensor, bool per_tensor,
-             int max_chunks_per_tensor) {
+template <typename x_t>
+struct L2NormFunctor {
+  __device__ __forceinline__ void operator()(
+      int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
+      float *output, float *output_per_tensor, bool per_tensor,
+      int max_chunks_per_tensor) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
@@ -48,8 +50,8 @@ template <typename x_t> struct L2NormFunctor {

     __shared__ float s_vals[512];

-    float
-        vals[ILP]; // = {0}; // this probably works too but I want to be sure...
+    float vals[ILP];  // = {0}; // this probably works too but I want to be
+                      // sure...
     x_t r_x[ILP];
     for (int i = 0; i < ILP; i++) {
       vals[i] = 0.f;
@@ -84,15 +86,14 @@
     }

     float val = 0.f;
-    for (int i = 0; i < ILP; i++)
-      val += vals[i];
+    for (int i = 0; i < ILP; i++) val += vals[i];

     float final = reduce_block_into_lanes(s_vals, val);

     if (threadIdx.x == 0) {
       if (!isfinite(final))
         *noop_gmem =
-            1; // Blindly fire off a write. These will race but that's ok.
+            1;  // Blindly fire off a write. These will race but that's ok.
       output[blockIdx.x] += final;
       if (per_tensor)
         output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
@@ -104,11 +105,12 @@

 // Probably better to template, but since we are not likely to support other
 // norm
-template <typename x_t> struct MaxNormFunctor {
-  __device__ __forceinline__ void
-  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
-             float *output, float *output_per_tensor, bool per_tensor,
-             int max_chunks_per_tensor) {
+template <typename x_t>
+struct MaxNormFunctor {
+  __device__ __forceinline__ void operator()(
+      int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
+      float *output, float *output_per_tensor, bool per_tensor,
+      int max_chunks_per_tensor) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
@@ -124,8 +126,8 @@ template <typename x_t> struct MaxNormFunctor {

     __shared__ float s_vals[512];

-    float
-        vals[ILP]; // = {0}; // this probably works too but I want to be sure...
+    float vals[ILP];  // = {0}; // this probably works too but I want to be
+                      // sure...
     x_t r_x[ILP];
     for (int i = 0; i < ILP; i++) {
       vals[i] = 0.f;
@@ -160,15 +162,14 @@
     }

     float val = 0.f;
-    for (int i = 0; i < ILP; i++)
-      val = fmaxf(fabsf(val), fabsf(vals[i]));
+    for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i]));

     float final = reduce_block_into_lanes_max_op(s_vals, val);

     if (threadIdx.x == 0) {
       if (!isfinite(final))
         *noop_gmem =
-            1; // Blindly fire off a write. These will race but that's ok.
+            1;  // Blindly fire off a write. These will race but that's ok.
       output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
       if (per_tensor)
         output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
@@ -185,13 +186,11 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret,

   if (blockIdx.x == 0) {
     float val = 0;
-    if (threadIdx.x < 320)
-      val = output[threadIdx.x];
+    if (threadIdx.x < 320) val = output[threadIdx.x];

     float final = reduce_block_into_lanes(vals, val);

-    if (threadIdx.x == 0)
-      *ret = sqrt(final);
+    if (threadIdx.x == 0) *ret = sqrt(final);
   }

   if (per_tensor) {
@@ -204,8 +203,7 @@

     float final = reduce_block_into_lanes(vals, val);

-    if (threadIdx.x == 0)
-      ret_per_tensor[blockIdx.x] = sqrt(final);
+    if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final);
   }
 }

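Aside (illustrative): the cleanup kernel is the second reduction stage. The functors above leave one partial sum per block in output[] (the threadIdx.x < 320 guard implies 320 such partials), and cleanup() sums them and takes the square root; the per_tensor branch does the same per row of output_per_tensor. Written as straight-line host code under that assumption, with global_l2_norm as a hypothetical name:

#include <cmath>

// Host-side restatement of what cleanup() computes for the global norm.
float global_l2_norm(const float *partials, int num_partials /* 320 here */) {
  float sum_sq = 0.f;
  for (int b = 0; b < num_partials; ++b) sum_sq += partials[b];
  return sqrtf(sum_sq);  // matches *ret = sqrt(final) above
}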
@@ -217,17 +215,14 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,

   if (blockIdx.x == 0) {
     float val = 0;
-    if (threadIdx.x < 320)
-      val = output[threadIdx.x];
+    if (threadIdx.x < 320) val = output[threadIdx.x];

     if (norm_type == 0) {
       float final = reduce_block_into_lanes_max_op(vals, val);
-      if (threadIdx.x == 0)
-        *ret = alpha * (*ret) + beta * final;
+      if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final;
     } else {
       float final = reduce_block_into_lanes(vals, val);
-      if (threadIdx.x == 0)
-        *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
+      if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
     }
   }

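Aside (illustrative): cleanup_v2 blends the freshly reduced partial into an existing running value instead of overwriting it. With running as the previous *ret and final as the block-reduced partial, the two branches reduce to the hypothetical helper below:

#include <cmath>

// Restatement of the update rule in cleanup_v2; updated_norm is a made-up name.
float updated_norm(float running, float final, float alpha, float beta,
                   int norm_type) {
  if (norm_type == 0)  // norm_type 0: running blend of max-style partials
    return alpha * running + beta * final;
  return sqrtf(alpha * running * running + beta * final);  // L2-style blend
}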
@@ -260,10 +255,10 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
   }
 }

-std::tuple<at::Tensor, at::Tensor>
-multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
-                         std::vector<std::vector<at::Tensor>> tensor_lists,
-                         at::optional<bool> per_tensor_python) {
+std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
+    int chunk_size, at::Tensor noop_flag,
+    std::vector<std::vector<at::Tensor>> tensor_lists,
+    at::optional<bool> per_tensor_python) {
   bool per_tensor =
       per_tensor_python.has_value() ? per_tensor_python.value() : false;
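For context, a sketch of how the reworked entry point might be called from other extension code. The declaration is copied verbatim from the diff above; the chunk size, the int32 no-op flag, and the single FP32 tensor list are typical-usage assumptions rather than anything this PR prescribes, and global_grad_norm is a hypothetical helper:

#include <torch/extension.h>

// Declaration as it appears in the diff; the definition lives in this .cu file.
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

// Hypothetical helper: global L2 norm over a list of CUDA FP32 gradients.
at::Tensor global_grad_norm(const std::vector<at::Tensor> &grads) {
  auto noop_flag = at::zeros({1}, grads[0].options().dtype(at::kInt));
  auto out = multi_tensor_l2norm_cuda(/*chunk_size=*/65536, noop_flag, {grads},
                                      /*per_tensor_python=*/false);
  return std::get<0>(out);  // tensor holding the global L2 norm (ret above)
}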