Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 25 additions & 25 deletions colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
#define BLOCK_SIZE 512
#define ILP 4

// Reports whether the address held in `p` is a multiple of
// ILP * sizeof(T) bytes, i.e. whether it sits on a boundary of ILP
// consecutive elements of type T.
//
// p: pointer whose address is tested; it is never dereferenced.
// Returns true when the address is evenly divisible by ILP * sizeof(T).
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
  // The pointer is only inspected as an integer address.
  const uint64_t address = reinterpret_cast<uint64_t>(p);
  const size_t span_bytes = ILP * sizeof(T);
  return address % span_bytes == 0;
}

Expand All @@ -28,24 +29,25 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
}

// Mode selector accepted by the LAMB stage-1 functor through its `mode`
// parameter; the host launcher casts its integer `mode` argument to this
// type.
typedef enum {
  MOMENT_MODE_0 = 0,  // L2 regularization mode
  MOMENT_MODE_1 = 1   // Decoupled weight decay mode
} adamMode_t;

// Forward declaration only; the definition is not part of this view.
// Takes a chunk size, a no-op flag tensor, a list of tensor lists and an
// optional per-tensor switch, and returns a pair of tensors.
// NOTE(review): going by the name, this presumably computes L2 norms over
// the given tensor lists -- confirm against the definition.
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

// Alias for float.
// NOTE(review): presumably the type used for intermediate arithmetic in the
// functors below; their bodies are not visible here -- confirm.
using MATH_T = float;

template <typename T> struct LAMBStage1Functor {
__device__ __forceinline__ void
operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
const float beta1, const float beta2, const float beta3,
const float beta1_correction, const float beta2_correction,
const float epsilon, adamMode_t mode, const float decay,
const float *global_grad_norm, const float max_global_grad_norm) {
template <typename T>
struct LAMBStage1Functor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
const float beta1, const float beta2, const float beta3,
const float beta1_correction, const float beta2_correction,
const float epsilon, adamMode_t mode, const float decay,
const float *global_grad_norm, const float max_global_grad_norm) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
Expand Down Expand Up @@ -89,8 +91,7 @@ template <typename T> struct LAMBStage1Functor {
i_start += blockDim.x) {
// load
load_store(l_g, g, 0, i_start);
if (decay != 0)
load_store(l_p, p, 0, i_start);
if (decay != 0) load_store(l_p, p, 0, i_start);
load_store(l_m, m, 0, i_start);
load_store(l_v, v, 0, i_start);
// unpack
Expand Down Expand Up @@ -204,12 +205,12 @@ template <typename T> struct LAMBStage1Functor {

// Stage 2 reads in the 'update' values along with the per-tensor param_norm
// and update_norm, and computes the new parameter values.
template <typename T> struct LAMBStage2Functor {
__device__ __forceinline__ void
operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
const float *per_tensor_param_norm,
const float *per_tensor_update_norm, const float learning_rate,
const float decay, bool use_nvlamb) {
template <typename T>
struct LAMBStage2Functor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
const float *per_tensor_param_norm, const float *per_tensor_update_norm,
const float learning_rate, const float decay, bool use_nvlamb) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
Expand Down Expand Up @@ -310,8 +311,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,

// Handle grad averaging mode
float beta3 = 1.0f;
if (grad_averaging == 1)
beta3 = 1 - beta1;
if (grad_averaging == 1) beta3 = 1 - beta1;

std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(),
tensor_lists.begin() + 1);
Expand All @@ -330,7 +330,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
LAMBStage1Functor<scalar_t_0>(), beta1, beta2,
beta3, // 1-beta1 or 1 depends on averaging mode
beta3, // 1-beta1 or 1 depends on averaging mode
bias_correction1, bias_correction2, epsilon,
(adamMode_t)mode, weight_decay,
global_grad_norm.DATA_PTR<float>(), max_grad_norm);)
Expand Down