From bd4c0f36997e3d96255dbdaacd7c20af9848d582 Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 16 Oct 2023 14:24:24 +0800
Subject: [PATCH 1/2] [kernel] fix cpu adam

---
 .../kernel/cuda_native/csrc/cpu_adam.cpp      | 238 +++++-------
 colossalai/kernel/cuda_native/csrc/cpu_adam.h |  30 ++-
 2 files changed, 86 insertions(+), 182 deletions(-)

diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
index 027d18a9dd58..be9300c545c2 100644
--- a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
+++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
@@ -37,30 +37,17 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
                             bool param_half_precision, bool grad_half_precision,
                             bool momentum_half_precision,
                             bool variance_half_precision, float loss_scale) {
-  size_t rounded_size = 0;
+  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
   float betta1_minus1 = 1 - _betta1;
   float betta2_minus1 = 1 - _betta2;
 
   float step_size = -1 * _alpha / _bias_correction1;
   float w_decay = -1 * _alpha * _weight_decay;
-  __half *params_cast_h = NULL;
-  __half *grads_cast_h = NULL;
-  __half *momentum_cast_h = NULL;
-  __half *variance_cast_h = NULL;
-
-  if (param_half_precision) {
-    params_cast_h = reinterpret_cast<__half *>(_params);
-  }
-  if (grad_half_precision) {
-    grads_cast_h = reinterpret_cast<__half *>(grads);
-  }
-  if (momentum_half_precision) {
-    momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
-  }
-  if (variance_half_precision) {
-    variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
-  }
+  __half *params_cast_h = reinterpret_cast<__half *>(_params);
+  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
+  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
+  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
 
 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   AVX_Data betta1_4;
@@ -86,7 +73,6 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
   if (_weight_decay > 0)
     weight_decay_4.data =
         (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
-  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
@@ -96,36 +82,23 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
 #pragma omp parallel for
     for (size_t i = t; i < offset; i += SIMD_WIDTH) {
       AVX_Data grad_4;
-      if (grad_half_precision) {
-        grad_4.data = SIMD_LOAD_HALF(grads_cast_h + i);
-      } else {
-        grad_4.data = SIMD_LOAD(grads + i);
-      }
+      this->simd_load(grad_half_precision, grads + i, grads_cast_h + i, grad_4);
       if (loss_scale > 0) {
         AVX_Data loss_scale_vec;
         loss_scale_vec.data = SIMD_SET(loss_scale);
         grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
       }
       AVX_Data momentum_4;
-      if (momentum_half_precision) {
-        momentum_4.data = SIMD_LOAD_HALF(momentum_cast_h + i);
-      } else {
-        momentum_4.data = SIMD_LOAD(_exp_avg + i);
-      }
+      this->simd_load(momentum_half_precision, _exp_avg + i,
+                      momentum_cast_h + i, momentum_4);
       AVX_Data variance_4;
-      if (variance_half_precision) {
-        variance_4.data = SIMD_LOAD_HALF(variance_cast_h + i);
-      } else {
-        variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
-      }
+      this->simd_load(variance_half_precision, _exp_avg_sq + i,
+                      variance_cast_h + i, variance_4);
       AVX_Data param_4;
-      if (param_half_precision) {
-        param_4.data = SIMD_LOAD_HALF(params_cast_h + i);
-      } else {
-        param_4.data = SIMD_LOAD(_params + i);
-      }
+      this->simd_load(param_half_precision, _params + i, params_cast_h + i,
+                      param_4);
 
       if (_weight_decay > 0 && !_adamw_mode) {
         grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data);
@@ -147,21 +120,12 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
       }
 
       param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
-      if (param_half_precision) {
-        SIMD_STORE_HALF((float *)(params_cast_h + i), param_4.data);
-      } else {
-        SIMD_STORE(_params + i, param_4.data);
-      }
-      if (momentum_half_precision) {
-        SIMD_STORE_HALF((float *)(momentum_cast_h + i), momentum_4.data);
-      } else {
-        SIMD_STORE(_exp_avg + i, momentum_4.data);
-      }
-      if (variance_half_precision) {
-        SIMD_STORE_HALF((float *)(variance_cast_h + i), variance_4.data);
-      } else {
-        SIMD_STORE(_exp_avg_sq + i, variance_4.data);
-      }
+      this->simd_store(param_half_precision, _params + i, params_cast_h + i,
+                       param_4);
+      this->simd_store(momentum_half_precision, _exp_avg + i,
+                       momentum_cast_h + i, momentum_4);
+      this->simd_store(variance_half_precision, _exp_avg_sq + i,
+                       variance_cast_h + i, variance_4);
     }
   }
 #endif
@@ -223,24 +187,12 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                             bool param_half_precision, bool grad_half_precision,
                             bool momentum_half_precision,
                             bool variance_half_precision, float loss_scale) {
-  size_t rounded_size = 0;
-
-  __half *params_cast_h = NULL;
-  __half *grads_cast_h = NULL;
-  __half *momentum_cast_h = NULL;
-  __half *variance_cast_h = NULL;
-  if (param_half_precision) {
-    params_cast_h = reinterpret_cast<__half *>(_params);
-  }
-  if (grad_half_precision) {
-    grads_cast_h = reinterpret_cast<__half *>(grads);
-  }
-  if (momentum_half_precision) {
-    momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
-  }
-  if (variance_half_precision) {
-    variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
-  }
+  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
+
+  __half *params_cast_h = reinterpret_cast<__half *>(_params);
+  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
+  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
+  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
 
 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   AVX_Data betta1_4;
@@ -270,7 +222,6 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
   if (_weight_decay > 0)
     weight_decay_4.data =
         (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
-  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
@@ -285,36 +236,21 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
       AVX_Data param_4[4];
 #pragma unroll 4
       for (int j = 0; j < 4; j++) {
-        if (grad_half_precision) {
-          grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
+                        grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
         if (loss_scale > 0) {
           AVX_Data loss_scale_vec;
           loss_scale_vec.data = SIMD_SET(loss_scale);
           grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
         }
-
-        if (momentum_half_precision) {
-          momentum_4[j].data =
-              SIMD_LOAD_HALF(momentum_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
-        }
-        if (variance_half_precision) {
-          variance_4[j].data =
-              SIMD_LOAD_HALF(variance_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
-        }
-
-        if (param_half_precision) {
-          param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                        momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_load(variance_half_precision,
+                        _exp_avg_sq + i + SIMD_WIDTH * j,
+                        variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
+        this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
+                        params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
 
         if (_weight_decay > 0 && !_adamw_mode) {
           grad_4[j].data =
@@ -337,24 +273,13 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
         }
         param_4[j].data =
             SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
-        if (param_half_precision) {
-          SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
-                          param_4[j].data);
-        } else {
-          SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
-        }
-        if (momentum_half_precision) {
-          SIMD_STORE_HALF((float *)(momentum_cast_h + i + SIMD_WIDTH * j),
-                          momentum_4[j].data);
-        } else {
-          SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
-        }
-        if (variance_half_precision) {
-          SIMD_STORE_HALF((float *)(variance_cast_h + i + SIMD_WIDTH * j),
-                          variance_4[j].data);
-        } else {
-          SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
-        }
+        this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
+                         params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
+        this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                         momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_store(variance_half_precision,
+                         _exp_avg_sq + i + SIMD_WIDTH * j,
+                         variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
       }
     }
   }
@@ -378,23 +303,12 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                             bool param_half_precision, bool grad_half_precision,
                             bool momentum_half_precision,
                             bool variance_half_precision, float loss_scale) {
-  size_t rounded_size = 0;
-  __half *params_cast_h = NULL;
-  __half *grads_cast_h = NULL;
-  __half *momentum_cast_h = NULL;
-  __half *variance_cast_h = NULL;
-  if (param_half_precision) {
-    params_cast_h = reinterpret_cast<__half *>(_params);
-  }
-  if (grad_half_precision) {
-    grads_cast_h = reinterpret_cast<__half *>(grads);
-  }
-  if (momentum_half_precision) {
-    momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
-  }
-  if (variance_half_precision) {
-    variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
-  }
+  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
+  __half *params_cast_h = reinterpret_cast<__half *>(_params);
+  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
+  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
+  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
+
 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   AVX_Data betta1_4;
   betta1_4.data = SIMD_SET(_betta1);
@@ -423,7 +337,6 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
   if (_weight_decay > 0)
     weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
-  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
@@ -438,36 +351,21 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
       AVX_Data param_4[8];
 #pragma unroll 8
       for (int j = 0; j < 8; j++) {
-        if (grad_half_precision) {
-          grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
+                        grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
         if (loss_scale > 0) {
           AVX_Data loss_scale_vec;
           loss_scale_vec.data = SIMD_SET(loss_scale);
           grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
        }
-
-        if (momentum_half_precision) {
-          momentum_4[j].data =
-              SIMD_LOAD_HALF(momentum_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
-        }
-        if (variance_half_precision) {
-          variance_4[j].data =
-              SIMD_LOAD_HALF(variance_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
-        }
-
-        if (param_half_precision) {
-          param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                        momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_load(variance_half_precision,
+                        _exp_avg_sq + i + SIMD_WIDTH * j,
+                        variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
+        this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
+                        params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
 
         if (_weight_decay > 0 && !_adamw_mode) {
           grad_4[j].data =
@@ -490,25 +388,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
 
         param_4[j].data =
             SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
-        if (param_half_precision) {
-          SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
-                          param_4[j].data);
-        } else {
-          SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
-        }
-
-        if (momentum_half_precision) {
-          SIMD_STORE_HALF((float *)(momentum_cast_h + i + SIMD_WIDTH * j),
-                          momentum_4[j].data);
-        } else {
-          SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
-        }
-        if (variance_half_precision) {
-          SIMD_STORE_HALF((float *)(variance_cast_h + i + SIMD_WIDTH * j),
-                          variance_4[j].data);
-        } else {
-          SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
-        }
+        this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
+                         params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
+        this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                         momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_store(variance_half_precision,
+                         _exp_avg_sq + i + SIMD_WIDTH * j,
+                         variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
       }
     }
   }
diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/colossalai/kernel/cuda_native/csrc/cpu_adam.h
index 67f3bffaf46a..bf9b85997c78 100644
--- a/colossalai/kernel/cuda_native/csrc/cpu_adam.h
+++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.h
@@ -50,9 +50,9 @@ SOFTWARE
 #define SIMD_DIV(x, y) _mm512_div_ps(x, y)
 #define SIMD_LOAD_HALF(x) \
   _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
-#define SIMD_STORE_HALF(x, d) \
-  _mm256_store_ps( \
-      x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
+#define SIMD_STORE_HALF(x, d) \
+  _mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \
+                                     d, _MM_FROUND_TO_NEAREST_INT)))
 
 #elif defined(__AVX256__) or defined(__AVX2__)
 #define SIMD_WIDTH 8
@@ -66,9 +66,9 @@ SOFTWARE
 #define SIMD_SQRT(x) _mm256_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm256_div_ps(x, y)
 #define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
-#define SIMD_STORE_HALF(x, d) \
-  _mm_store_ps( \
-      x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
+#define SIMD_STORE_HALF(x, d) \
+  _mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \
+                                  d, _MM_FROUND_TO_NEAREST_INT)))
 
 #endif
 
@@ -142,6 +142,24 @@ class Adam_Optimizer {
     }
   }
 
+  inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
+                        AVX_Data &data) {
+    if (is_half) {
+      data.data = SIMD_LOAD_HALF(h_ptr);
+    } else {
+      data.data = SIMD_LOAD(ptr);
+    }
+  }
+
+  inline void simd_store(bool is_half, float *ptr, __half *h_ptr,
+                         AVX_Data &data) {
+    if (is_half) {
+      SIMD_STORE_HALF(h_ptr, data.data);
+    } else {
+      SIMD_STORE(ptr, data.data);
+    }
+  }
+
   void step(size_t step, float lr, float beta1, float beta2, float epsilon,
             float weight_decay, bool bias_correction, torch::Tensor &params,
             torch::Tensor &grads, torch::Tensor &exp_avg,
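Note on PATCH 1/2: each Step_N now computes `rounded_size` and the four `__half *` views unconditionally at the top of the function (the views are only dereferenced when the matching `*_half_precision` flag is set), and the repeated per-flag branches collapse into the new `Adam_Optimizer::simd_load`/`simd_store` helpers. The likely functional fix is in `SIMD_STORE_HALF`: the old macro used aligned stores (`_mm256_store_ps`/`_mm_store_ps`), which fault unless the destination is 32/16-byte aligned, an alignment the half-precision buffers handed to the optimizer are not guaranteed to have; the new macro uses the unaligned `storeu` variants. Below is a minimal standalone sketch of that half-precision round-trip on the AVX2 path. It is not part of the patch; the file name and build line are hypothetical (any x86-64 compiler with AVX2 and F16C should do, e.g. `g++ -O2 -mavx2 -mf16c half_roundtrip.cpp`), and it uses `_mm_storeu_si128`, which writes the same 16 bytes as the patch's `_mm_storeu_ps` plus cast.

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Widen 8 fp16 values to 8 fp32 lanes (what SIMD_LOAD_HALF does under AVX2).
static __m256 load_half8(const uint16_t *h) {
  return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)h));
}

// Narrow 8 fp32 lanes back to fp16 and store 16 bytes. The unaligned store
// tolerates any address; the pre-fix aligned _mm_store_ps faults unless the
// destination is 16-byte aligned, which a __half buffer at an arbitrary
// element offset cannot guarantee.
static void store_half8(uint16_t *h, __m256 v) {
  _mm_storeu_si128((__m128i *)h,
                   _mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT));
}

int main() {
  alignas(16) uint16_t raw[17] = {0};
  uint16_t *buf = raw + 1;  // deliberately NOT 16-byte aligned

  store_half8(buf, _mm256_set1_ps(1.5f));  // fine; an aligned store would fault
  float out[8];
  _mm256_storeu_ps(out, load_half8(buf));
  printf("%f\n", out[0]);  // prints 1.500000
  return 0;
}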
From 8e22669f11262f4c19a59a3723b2ba4a4ae91adf Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 16 Oct 2023 14:24:59 +0800
Subject: [PATCH 2/2] [test] update gemini optim test

---
 tests/test_zero/test_gemini/test_grad_clip.py | 12 +++++++++---
 tests/test_zero/test_gemini/test_optim.py     | 15 +++++++++++----
 2 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py
index a3af81646a18..4c84e9e5a89a 100644
--- a/tests/test_zero/test_gemini/test_grad_clip.py
+++ b/tests/test_zero/test_gemini/test_grad_clip.py
@@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
 
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["gpt2"])
-def exam_grad_clipping(placement_config, model_name: str):
+@parameterize("master_weights", [True, False])
+def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
     set_seed(1912)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -82,6 +83,7 @@ def exam_grad_clipping(placement_config, model_name: str):
         chunk_config_dict=config_dict,
         chunk_init_device=init_device,
         pin_memory=True,
+        master_weights=master_weights,
         **placement_config,
     )
 
@@ -103,7 +105,10 @@ def exam_grad_clipping(placement_config, model_name: str):
 
         torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
         loss = run_fwd_bwd(model, data, label, criterion, zero_optim)
-        assert_close(torch_loss, loss)
+
+        # as no master weights leads to error accumulation, we don't check the loss
+        if master_weights:
+            assert_close(torch_loss, loss)
 
         import apex.amp as apex_amp
 
@@ -111,7 +116,8 @@ def exam_grad_clipping(placement_config, model_name: str):
         torch_optim.step()
         zero_optim.step()
 
-        check_param(model, torch_model)
+        if master_weights:
+            check_param(model, torch_model)
 
 
 def run_dist(rank, world_size, port):
diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py
index 8e8e508ff483..9b84d68f3c7a 100644
--- a/tests/test_zero/test_gemini/test_optim.py
+++ b/tests/test_zero/test_gemini/test_optim.py
@@ -70,12 +70,14 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", TEST_MODELS)
 @parameterize("mixed_precision", [torch.half, torch.bfloat16])
-def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
+@parameterize("master_weights", [True, False])
+def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
 
     torch_model = model_builder().cuda()
+    # apex no master weights leads to nan, so we don't use it
     amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128)
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
@@ -90,7 +92,9 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = False
-    model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision)
+    model = GeminiDDP(
+        model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights
+    )
 
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)
@@ -109,12 +113,15 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
 
         torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
         loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
-        assert_close(torch_loss, loss, rtol=rtol, atol=atol)
+
+        # as no master weights leads to error accumulation, we don't check the loss
+        if master_weights:
+            assert_close(torch_loss, loss, rtol=rtol, atol=atol)
 
         zero_optim.step()
         torch_optim.step()
 
-        check_param(model, torch_model, mixed_precision)
+        if master_weights:
+            check_param(model, torch_model, mixed_precision)
 
 
 @parameterize("placement_config", PLACEMENT_CONFIGS)
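Background for PATCH 2/2: the tests now parameterize `master_weights` and, as the in-test comments state, skip the loss and parameter comparisons when it is disabled, because updating weights that are stored only in half precision accumulates rounding error relative to the fp32 apex baseline. A standalone toy that shows the mechanism in its most extreme form; it is not from the repo, and the file name, build line, and the 1e-4 per-step update are hypothetical (e.g. `g++ -O2 -mf16c fp16_drift.cpp`):

#include <immintrin.h>
#include <cstdio>

// Round a float through IEEE fp16 storage, as a parameter held in half
// precision between optimizer steps effectively is.
static float roundtrip_fp16(float x) {
  return _cvtsh_ss(_cvtss_sh(x, _MM_FROUND_TO_NEAREST_INT));
}

int main() {
  const float update = 1e-4f;  // hypothetical per-step weight delta
  float master = 1.0f;         // weight accumulated in fp32 (master copy)
  float half_held = 1.0f;      // weight rounded to fp16 after every step

  for (int step = 0; step < 10000; ++step) {
    master += update;
    half_held = roundtrip_fp16(half_held + update);
  }
  // Near 1.0 the fp16 grid spacing is 2^-10 (about 9.8e-4), so each 1e-4
  // update rounds away entirely: the fp16-held weight never moves, while
  // the fp32 master reaches about 2.0. Real training sees a milder but
  // still compounding divergence, hence the relaxed checks in the tests.
  printf("fp32 master weight: %f\n", master);     // ~2.000000
  printf("fp16-held weight:   %f\n", half_held);  // 1.000000
  return 0;
}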