From a74f587dca6b34ef81f2580a4a3e58bee92266c6 Mon Sep 17 00:00:00 2001
From: ver217
Date: Thu, 25 May 2023 14:03:18 +0800
Subject: [PATCH 1/3] [bf16] fused adam kernel support bf16

---
 colossalai/kernel/cuda_native/csrc/type_shim.h | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/colossalai/kernel/cuda_native/csrc/type_shim.h b/colossalai/kernel/cuda_native/csrc/type_shim.h
index 2f180a7783ec..03ccc02635fa 100644
--- a/colossalai/kernel/cuda_native/csrc/type_shim.h
+++ b/colossalai/kernel/cuda_native/csrc/type_shim.h
@@ -171,6 +171,21 @@
     using g_scalar_t_##LEVEL = at::Half;                                       \
     using p_scalar_t_##LEVEL = at::Half;                                       \
     __VA_ARGS__;                                                               \
+  } else if (GTYPE == at::ScalarType::Float &&                                 \
+             PTYPE == at::ScalarType::BFloat16) {                              \
+    using g_scalar_t_##LEVEL = float;                                          \
+    using p_scalar_t_##LEVEL = at::BFloat16;                                   \
+    __VA_ARGS__;                                                               \
+  } else if (GTYPE == at::ScalarType::BFloat16 &&                              \
+             PTYPE == at::ScalarType::Float) {                                 \
+    using g_scalar_t_##LEVEL = at::BFloat16;                                   \
+    using p_scalar_t_##LEVEL = float;                                          \
+    __VA_ARGS__;                                                               \
+  } else if (GTYPE == at::ScalarType::BFloat16 &&                              \
+             PTYPE == at::ScalarType::BFloat16) {                              \
+    using g_scalar_t_##LEVEL = at::BFloat16;                                   \
+    using p_scalar_t_##LEVEL = at::BFloat16;                                   \
+    __VA_ARGS__;                                                               \
   } else {                                                                     \
     AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
              "'");                                                             \

From 3a8c3a40f3075187e7995bdad50e9dfda67e768c Mon Sep 17 00:00:00 2001
From: ver217
Date: Thu, 25 May 2023 15:10:40 +0800
Subject: [PATCH 2/3] [test] update fused adam kernel test

---
 colossalai/nn/optimizer/fused_adam.py          |  4 ++--
 tests/test_optimizer/test_fused_adam_kernel.py | 16 ++++++++--------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py
index 987af8a968b7..82a6250f1fd1 100644
--- a/colossalai/nn/optimizer/fused_adam.py
+++ b/colossalai/nn/optimizer/fused_adam.py
@@ -134,8 +134,8 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p)
 
-                if p.dtype not in [torch.float16, torch.float32]:
-                    raise RuntimeError('FusedAdam only support fp16 and fp32.')
+                if p.dtype not in [torch.float16, torch.float32, torch.bfloat16]:
+                    raise RuntimeError('FusedAdam only supports fp16, fp32 and bf16.')
 
                 g_l.append(p.grad.data)
                 p_l.append(p.data)
diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py
index 4afa13349c1b..deabaebfec75 100644
--- a/tests/test_optimizer/test_fused_adam_kernel.py
+++ b/tests/test_optimizer/test_fused_adam_kernel.py
@@ -1,12 +1,14 @@
 import math
 
+import pytest
 import torch
-import torch.nn as nn
-from numpy import dtype
 
-from colossalai.testing import clear_cache_before_run, parameterize
 from colossalai.utils import multi_tensor_applier
 
+_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float),
+                      (torch.half, torch.half), (torch.bfloat16, torch.float), (torch.float, torch.bfloat16),
+                      (torch.bfloat16, torch.bfloat16)]
+
 
 def torch_adam_update(
     step,
@@ -41,11 +43,9 @@ def torch_adam_update(
         param.addcdiv_(exp_avg, denom, value=-step_size)
 
 
-@clear_cache_before_run()
-@parameterize('adamw', [False, True])
-@parameterize('step', [1, 2])
-@parameterize('p_dtype', [torch.float, torch.half])
-@parameterize('g_dtype', [torch.float, torch.half])
+@pytest.mark.parametrize('adamw', [False, True])
+@pytest.mark.parametrize('step', [1, 2])
+@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES)
 def test_adam(adamw, step, p_dtype, g_dtype):
     from colossalai.kernel.op_builder import FusedOptimBuilder
     fused_optim = FusedOptimBuilder().load()

From 332a2bbeec2c3f18a4ebe362a2c53a96616bdaa2 Mon Sep 17 00:00:00 2001
From: ver217
Date: Thu, 25 May 2023 15:53:10 +0800
Subject: [PATCH 3/3] [test] update fused adam test

---
 tests/test_optimizer/test_fused_adam.py | 71 ++++++++++++++-----------
 1 file changed, 40 insertions(+), 31 deletions(-)

diff --git a/tests/test_optimizer/test_fused_adam.py b/tests/test_optimizer/test_fused_adam.py
index 114d5293dad9..511987ca3dbf 100644
--- a/tests/test_optimizer/test_fused_adam.py
+++ b/tests/test_optimizer/test_fused_adam.py
@@ -1,10 +1,12 @@
+from copy import deepcopy
+
+import pytest
 import torch
 import torch.nn as nn
 from torch.optim import AdamW
 from torch.optim.adam import Adam
 
 from colossalai.nn.optimizer.fused_adam import FusedAdam
-from colossalai.testing import clear_cache_before_run, parameterize
 
 
 class FC(nn.Module):
@@ -17,48 +19,55 @@ def forward(self, x):
         return self.fc(x)
 
 
-@clear_cache_before_run()
-@parameterize('adamw', [False, True])
-@parameterize('p_dtype', [torch.float, torch.half])
-@parameterize('g_dtype', [torch.float, torch.half])
-def test_adam(adamw, p_dtype, g_dtype):
-    model = FC().cuda().to(p_dtype)
-    state = model.state_dict()
-    model_copy = FC().cuda().to(p_dtype)
-    model_copy.load_state_dict(state.copy())
-
-    if adamw:
-        optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)
-        torch_optim = AdamW(model_copy.parameters(), lr=1e-3)
+@pytest.mark.parametrize('adamw', [False, True])
+@pytest.mark.parametrize('running_p_dtype', [torch.float, torch.half, torch.bfloat16])
+@pytest.mark.parametrize('fp32_master_weights', [False, True])
+def test_adam(adamw, running_p_dtype, fp32_master_weights):
+    # baseline is pure fp32 torch adam
+    # grad dtype is the same as running_p_dtype
+    if running_p_dtype is torch.float and not fp32_master_weights:
+        # pure fp32 must have fp32 weights
+        return
+    if not fp32_master_weights or running_p_dtype is torch.bfloat16:
+        # pure low precision or bf16, high tolerance
+        atol = 4e-3
+        rtol = 4e-3
     else:
-        optim = FusedAdam(model.parameters(), lr=1e-3)
-        torch_optim = Adam(model_copy.parameters(), lr=1e-3)
+        # fp32 master weights, low tolerance
+        atol = 2e-3
+        rtol = 2e-3
+    torch_model = FC().cuda()
+    model = deepcopy(torch_model).to(running_p_dtype)
 
-    data = torch.rand(1024, 64).cuda().to(p_dtype)
-    data_copy = data.clone()
-    label = torch.rand(1024, 64).cuda().to(p_dtype)
+    torch_optim_cls = AdamW if adamw else Adam
+    torch_optim = torch_optim_cls(torch_model.parameters(), lr=1e-3)
+    optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=adamw)
 
-    for d, l in zip(data, label):
+    data = torch.rand(10, 64).cuda()
+    label = torch.rand(10, 64).cuda()
+
+    for d, l in zip(data.to(running_p_dtype), label.to(running_p_dtype)):
         y = model(d)
         loss = ((l - y)**2).sum()
         optim.zero_grad()
         loss.backward()
-        if p_dtype != g_dtype:
-            for i in range(len(optim.param_groups[0]['params'])):
-                optim.param_groups[0]['params'][i].grad.data = optim.param_groups[0]['params'][i].grad.data.to(g_dtype)
+        if fp32_master_weights:
+            for p in model.parameters():
+                p.data = p.data.float()
         optim.step()
+        if fp32_master_weights:
+            for p in model.parameters():
+                p.data = p.data.to(running_p_dtype)
 
-    for d, l in zip(data_copy, label):
-        y = model_copy(d)
+    for d, l in zip(data, label):
+        y = torch_model(d)
         loss = ((l - y)**2).sum()
         torch_optim.zero_grad()
         loss.backward()
         torch_optim.step()
 
-    assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params'])
-
-    for i in range(len(optim.param_groups[0]['params'])):
-        if torch.isnan(optim.param_groups[0]['params'][i]).any() \
-                or torch.isnan(torch_optim.param_groups[0]['params'][i]).any():
+    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
+        if torch.isnan(p).any() or torch.isnan(torch_p).any():
             continue
-        assert torch.allclose(optim.param_groups[0]['params'][i], torch_optim.param_groups[0]['params'][i], 2e-3, 2e-3)
+        fp32_p = p.float()
+        assert torch.allclose(fp32_p, torch_p, atol=atol, rtol=rtol)
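
Note: the sketch below is a minimal usage example of the new code path, not part of the series itself. The toy nn.Linear model, tensor shapes, and hyperparameters are illustrative assumptions; it presumes a CUDA build of colossalai with the fused optim extension compiled.

    import torch
    import torch.nn as nn

    from colossalai.nn.optimizer.fused_adam import FusedAdam

    # Hypothetical toy setup: weights, activations and grads all in bf16,
    # which exercises the (GTYPE=BFloat16, PTYPE=BFloat16) dispatch branch
    # added to type_shim.h in patch 1/3.
    model = nn.Linear(64, 64).cuda().to(torch.bfloat16)
    optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)

    x = torch.rand(16, 64, device='cuda', dtype=torch.bfloat16)
    loss = model(x).sum()
    optim.zero_grad()
    loss.backward()  # grads are bf16, same dtype as the running params
    optim.step()     # before this series, step() raised RuntimeError for bf16 params

Keeping a master copy of the weights in fp32 while the grads stay bf16 (the pattern test_fused_adam.py checks with fp32_master_weights=True) instead hits the (GTYPE=BFloat16, PTYPE=Float) branch.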