From a8c5be5d474040901e1435ecaf08e93b794dc285 Mon Sep 17 00:00:00 2001
From: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
Date: Fri, 13 May 2022 15:07:19 +0800
Subject: [PATCH 1/3] [NFC] Polish colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu code style. (#937)

---
 .../cuda_native/csrc/multi_tensor_lamb.cu | 50 +++++++++----------
 1 file changed, 25 insertions(+), 25 deletions(-)

diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu
index 15ac209149eb..54c4220190d8 100644
--- a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu
+++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu
@@ -15,7 +15,8 @@
 #define BLOCK_SIZE 512
 #define ILP 4
 
-template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
+template <typename T>
+__device__ __forceinline__ bool is_aligned(T *p) {
   return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
 }
 
@@ -28,24 +29,25 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
 }
 
 typedef enum {
-  MOMENT_MODE_0 = 0, // L2 regularization mode
-  MOMENT_MODE_1 = 1  // Decoupled weight decay mode
+  MOMENT_MODE_0 = 0,  // L2 regularization mode
+  MOMENT_MODE_1 = 1   // Decoupled weight decay mode
 } adamMode_t;
 
-std::tuple<at::Tensor, at::Tensor>
-multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
-                         std::vector<std::vector<at::Tensor>> tensor_lists,
-                         at::optional<bool> per_tensor_python);
+std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
+    int chunk_size, at::Tensor noop_flag,
+    std::vector<std::vector<at::Tensor>> tensor_lists,
+    at::optional<bool> per_tensor_python);
 
 using MATH_T = float;
 
-template <typename T> struct LAMBStage1Functor {
-  __device__ __forceinline__ void
-  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
-             const float beta1, const float beta2, const float beta3,
-             const float beta1_correction, const float beta2_correction,
-             const float epsilon, adamMode_t mode, const float decay,
-             const float *global_grad_norm, const float max_global_grad_norm) {
+template <typename T>
+struct LAMBStage1Functor {
+  __device__ __forceinline__ void operator()(
+      int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
+      const float beta1, const float beta2, const float beta3,
+      const float beta1_correction, const float beta2_correction,
+      const float epsilon, adamMode_t mode, const float decay,
+      const float *global_grad_norm, const float max_global_grad_norm) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
@@ -89,8 +91,7 @@ template <typename T> struct LAMBStage1Functor {
            i_start += blockDim.x) {
         // load
         load_store(l_g, g, 0, i_start);
-        if (decay != 0)
-          load_store(l_p, p, 0, i_start);
+        if (decay != 0) load_store(l_p, p, 0, i_start);
         load_store(l_m, m, 0, i_start);
         load_store(l_v, v, 0, i_start);
         // unpack
@@ -204,12 +205,12 @@ template <typename T> struct LAMBStage1Functor {
 
 // Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
 // It computes new parameter value.
-template <typename T> struct LAMBStage2Functor {
-  __device__ __forceinline__ void
-  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
-             const float *per_tensor_param_norm,
-             const float *per_tensor_update_norm, const float learning_rate,
-             const float decay, bool use_nvlamb) {
+template <typename T>
+struct LAMBStage2Functor {
+  __device__ __forceinline__ void operator()(
+      int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
+      const float *per_tensor_param_norm, const float *per_tensor_update_norm,
+      const float learning_rate, const float decay, bool use_nvlamb) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
@@ -310,8 +311,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
 
   // Handle grad averaging mode
   float beta3 = 1.0f;
-  if (grad_averaging == 1)
-    beta3 = 1 - beta1;
+  if (grad_averaging == 1) beta3 = 1 - beta1;
 
   std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(),
                                                  tensor_lists.begin() + 1);
@@ -330,7 +330,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
       tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
       multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                             LAMBStage1Functor<scalar_t_0>(), beta1, beta2,
-                            beta3, // 1-beta1 or 1 depends on averaging mode
+                            beta3,  // 1-beta1 or 1 depends on averaging mode
                             bias_correction1, bias_correction2, epsilon,
                             (adamMode_t)mode, weight_decay,
                             global_grad_norm.DATA_PTR<float>(), max_grad_norm);)

From 565994b90c0095292dde59b12cbb7f8e8742b714 Mon Sep 17 00:00:00 2001
From: yuxuan-lou
Date: Fri, 13 May 2022 15:22:30 +0800
Subject: [PATCH 2/3] [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h code style

---
 .../kernel/cuda_native/csrc/kernels/include/cuda_util.h | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h
index bc22587628d8..1595257be0f5 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h
+++ b/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h
@@ -20,7 +20,8 @@ void check_gpu_error(T result, char const *const func, const char *const file,
 template <typename T>
 void print_vec(const T *outv, std::string outn, int num_output_ele);
 
-template <typename T> T *cuda_malloc(size_t ele_num);
+template <typename T>
+T *cuda_malloc(size_t ele_num);
 
 void cuda_free(void *pdata);
 
@@ -28,6 +29,6 @@ template <typename T>
 void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
                    std::string file, int line, cudaStream_t stream);
 
-#define CHECK_NAN_INF(ptr, size, stream)                             \
-  check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream));  \
+#define CHECK_NAN_INF(ptr, size, stream)                                   \
+  check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream));        \
   check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream))

From e077573a139edcc2c4635ee2eb36a7e33f720867 Mon Sep 17 00:00:00 2001
From: yuxuan-lou
Date: Tue, 12 Jul 2022 19:20:59 +0800
Subject: [PATCH 3/3] [NFC] polish colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp code style

---
 .../csrc/scaled_masked_softmax.cpp | 84 ++++++++-----------
 1 file changed, 35 insertions(+), 49 deletions(-)

diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp
index 4ae3c853ca5e..8c2982b0cff9 100644
--- a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp
+++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp
@@ -3,82 +3,68 @@
 #include <cuda_fp16.h>
 #include <torch/extension.h>
+
 #include <vector>
 
 namespace multihead_attn {
 namespace fused_softmax {
 namespace scaled_masked_softmax {
 
-torch::Tensor fwd_cuda(
-    torch::Tensor const& input,
-    torch::Tensor const& mask,
-    float scale_factor);
-
-torch::Tensor bwd_cuda(
-    torch::Tensor const& output_grads,
-    torch::Tensor const& softmax_results,
-    float scale_factor);
-
-int get_batch_per_block_cuda(
-    int query_seq_len,
-    int key_seq_len,
-    int batches,
-    int attn_heads);
-
-torch::Tensor fwd(
-    torch::Tensor const& input,
-    torch::Tensor const& mask,
-    float scale_factor) {
+torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
+                       float scale_factor);
+
+torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
+                       torch::Tensor const& softmax_results,
+                       float scale_factor);
+
+int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
+                             int attn_heads);
+
+torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask,
+                  float scale_factor) {
   AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
   AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
-             (input.scalar_type() == at::ScalarType::BFloat16),
-         "Only fp16 and bf16 are supported");
+                 (input.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");
   AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
 
   return fwd_cuda(input, mask, scale_factor);
 }
 
-torch::Tensor bwd(
-    torch::Tensor const& output_grads,
-    torch::Tensor const& softmax_results,
-    float scale_factor) {
-
+torch::Tensor bwd(torch::Tensor const& output_grads,
+                  torch::Tensor const& softmax_results, float scale_factor) {
   AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
   AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
 
   AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
-             (output_grads.scalar_type() == at::ScalarType::BFloat16),
-         "Only fp16 and bf16 are supported");
+                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");
   AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
-             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
-         "Only fp16 and bf16 are supported");
+                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");
 
   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }
 
-int get_batch_per_block(
-    int query_seq_len,
-    int key_seq_len,
-    int batches,
-    int attn_heads) {
-  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
+int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
+                        int attn_heads) {
+  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches,
+                                  attn_heads);
 }
 
-} // end namespace scaled_masked_softmax
-} // end namespace fused_softmax
-} // end namespace multihead_attn
+}  // end namespace scaled_masked_softmax
+}  // end namespace fused_softmax
+}  // end namespace multihead_attn
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward",
-        &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
-        "Self Multihead Attention scaled, time masked softmax -- Forward.");
+  m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
+        "Self Multihead Attention scaled, time masked softmax -- Forward.");
 
-  m.def("backward",
-        &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
-        "Self Multihead Attention scaled, time masked softmax -- Backward.");
+  m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
+        "Self Multihead Attention scaled, time masked softmax -- Backward.");
 
   m.def("get_batch_per_block",
-        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
-        "Return Batch per block size."
-      );
+        &multihead_attn::fused_softmax::scaled_masked_softmax::
+            get_batch_per_block,
+        "Return Batch per block size.");
 }