From 55a77b322824bb6d3a64820ca302a26a70563c6e Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 25 May 2023 16:59:08 +0800 Subject: [PATCH 1/6] [bf16] add bf16 support for fused adam (#3844) * [bf16] fused adam kernel support bf16 * [test] update fused adam kernel test * [test] update fused adam test --- .../kernel/cuda_native/csrc/type_shim.h | 15 ++++ colossalai/nn/optimizer/fused_adam.py | 4 +- tests/test_optimizer/test_fused_adam.py | 71 +++++++++++-------- .../test_optimizer/test_fused_adam_kernel.py | 16 ++--- 4 files changed, 65 insertions(+), 41 deletions(-) diff --git a/colossalai/kernel/cuda_native/csrc/type_shim.h b/colossalai/kernel/cuda_native/csrc/type_shim.h index 2f180a7783ec..03ccc02635fa 100644 --- a/colossalai/kernel/cuda_native/csrc/type_shim.h +++ b/colossalai/kernel/cuda_native/csrc/type_shim.h @@ -171,6 +171,21 @@ using g_scalar_t_##LEVEL = at::Half; \ using p_scalar_t_##LEVEL = at::Half; \ __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::Float && \ + PTYPE == at::ScalarType::BFloat16) { \ + using g_scalar_t_##LEVEL = float; \ + using p_scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::BFloat16 && \ + PTYPE == at::ScalarType::Float) { \ + using g_scalar_t_##LEVEL = at::BFloat16; \ + using p_scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::BFloat16 && \ + PTYPE == at::ScalarType::BFloat16) { \ + using g_scalar_t_##LEVEL = at::BFloat16; \ + using p_scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ } else { \ AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \ "'"); \ diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index 987af8a968b7..82a6250f1fd1 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -134,8 +134,8 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like(p) - if p.dtype not in [torch.float16, torch.float32]: - raise RuntimeError('FusedAdam only support fp16 and fp32.') + if p.dtype not in [torch.float16, torch.float32, torch.bfloat16]: + raise RuntimeError('FusedAdam only support fp16, fp32 and bf16.') g_l.append(p.grad.data) p_l.append(p.data) diff --git a/tests/test_optimizer/test_fused_adam.py b/tests/test_optimizer/test_fused_adam.py index 114d5293dad9..511987ca3dbf 100644 --- a/tests/test_optimizer/test_fused_adam.py +++ b/tests/test_optimizer/test_fused_adam.py @@ -1,10 +1,12 @@ +from copy import deepcopy + +import pytest import torch import torch.nn as nn from torch.optim import AdamW from torch.optim.adam import Adam from colossalai.nn.optimizer.fused_adam import FusedAdam -from colossalai.testing import clear_cache_before_run, parameterize class FC(nn.Module): @@ -17,48 +19,55 @@ def forward(self, x): return self.fc(x) -@clear_cache_before_run() -@parameterize('adamw', [False, True]) -@parameterize('p_dtype', [torch.float, torch.half]) -@parameterize('g_dtype', [torch.float, torch.half]) -def test_adam(adamw, p_dtype, g_dtype): - model = FC().cuda().to(p_dtype) - state = model.state_dict() - model_copy = FC().cuda().to(p_dtype) - model_copy.load_state_dict(state.copy()) - - if adamw: - optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True) - torch_optim = AdamW(model_copy.parameters(), lr=1e-3) +@pytest.mark.parametrize('adamw', [False, True]) +@pytest.mark.parametrize('running_p_dtype', [torch.float, torch.half, torch.bfloat16]) +@pytest.mark.parametrize('fp32_master_weights', [False, True]) +def test_adam(adamw, running_p_dtype, fp32_master_weights): + # baseline is fure fp32 torch adam + # g_type is the same as running_p_dtype + if running_p_dtype is torch.float and not fp32_master_weights: + # pure fp32 must have fp32 weights + return + if not fp32_master_weights or running_p_dtype is torch.bfloat16: + # pure low precision or bf16, high tolerance + atol = 4e-3 + rtol = 4e-3 else: - optim = FusedAdam(model.parameters(), lr=1e-3) - torch_optim = Adam(model_copy.parameters(), lr=1e-3) + # fp32 master weights, low tolerance + atol = 2e-3 + rtol = 2e-3 + torch_model = FC().cuda() + model = deepcopy(torch_model).to(running_p_dtype) - data = torch.rand(1024, 64).cuda().to(p_dtype) - data_copy = data.clone() - label = torch.rand(1024, 64).cuda().to(p_dtype) + torch_optim_cls = AdamW if adamw else Adam + torch_optim = torch_optim_cls(torch_model.parameters(), lr=1e-3) + optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=adamw) - for d, l in zip(data, label): + data = torch.rand(10, 64).cuda() + label = torch.rand(10, 64).cuda() + + for d, l in zip(data.to(running_p_dtype), label.to(running_p_dtype)): y = model(d) loss = ((l - y)**2).sum() optim.zero_grad() loss.backward() - if p_dtype != g_dtype: - for i in range(len(optim.param_groups[0]['params'])): - optim.param_groups[0]['params'][i].grad.data = optim.param_groups[0]['params'][i].grad.data.to(g_dtype) + if fp32_master_weights: + for p in model.parameters(): + p.data = p.data.float() optim.step() + if fp32_master_weights: + for p in model.parameters(): + p.data = p.data.to(running_p_dtype) - for d, l in zip(data_copy, label): - y = model_copy(d) + for d, l in zip(data, label): + y = torch_model(d) loss = ((l - y)**2).sum() torch_optim.zero_grad() loss.backward() torch_optim.step() - assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params']) - - for i in range(len(optim.param_groups[0]['params'])): - if torch.isnan(optim.param_groups[0]['params'][i]).any() \ - or torch.isnan(torch_optim.param_groups[0]['params'][i]).any(): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + if torch.isnan(p).any() or torch.isnan(torch_p).any(): continue - assert torch.allclose(optim.param_groups[0]['params'][i], torch_optim.param_groups[0]['params'][i], 2e-3, 2e-3) + fp32_p = p.float() + assert torch.allclose(fp32_p, torch_p, atol=atol, rtol=rtol) diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py index 4afa13349c1b..deabaebfec75 100644 --- a/tests/test_optimizer/test_fused_adam_kernel.py +++ b/tests/test_optimizer/test_fused_adam_kernel.py @@ -1,12 +1,14 @@ import math +import pytest import torch -import torch.nn as nn -from numpy import dtype -from colossalai.testing import clear_cache_before_run, parameterize from colossalai.utils import multi_tensor_applier +_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float), + (torch.half, torch.half), (torch.bfloat16, torch.float), (torch.float, torch.bfloat16), + (torch.bfloat16, torch.bfloat16)] + def torch_adam_update( step, @@ -41,11 +43,9 @@ def torch_adam_update( param.addcdiv_(exp_avg, denom, value=-step_size) -@clear_cache_before_run() -@parameterize('adamw', [False, True]) -@parameterize('step', [1, 2]) -@parameterize('p_dtype', [torch.float, torch.half]) -@parameterize('g_dtype', [torch.float, torch.half]) +@pytest.mark.parametrize('adamw', [False, True]) +@pytest.mark.parametrize('step', [1, 2]) +@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES) def test_adam(adamw, step, p_dtype, g_dtype): from colossalai.kernel.op_builder import FusedOptimBuilder fused_optim = FusedOptimBuilder().load() From 2e799c78aef754dbd0ea94f442c17caf9b1b5a49 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 30 May 2023 09:59:29 +0800 Subject: [PATCH 2/6] [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860) --- colossalai/nn/optimizer/cpu_adam.py | 23 ++- colossalai/nn/optimizer/hybrid_adam.py | 37 ++--- tests/test_optimizer/test_adam_kernel.py | 131 ++++++++++++++++++ tests/test_optimizer/test_adam_optim.py | 86 ++++++++++++ tests/test_optimizer/test_cpu_adam.py | 121 ---------------- tests/test_optimizer/test_fused_adam.py | 73 ---------- .../test_optimizer/test_fused_adam_kernel.py | 95 ------------- tests/test_optimizer/test_hybrid_adam.py | 42 ------ 8 files changed, 254 insertions(+), 354 deletions(-) create mode 100644 tests/test_optimizer/test_adam_kernel.py create mode 100644 tests/test_optimizer/test_adam_optim.py delete mode 100644 tests/test_optimizer/test_cpu_adam.py delete mode 100644 tests/test_optimizer/test_fused_adam.py delete mode 100644 tests/test_optimizer/test_fused_adam_kernel.py delete mode 100644 tests/test_optimizer/test_hybrid_adam.py diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index bb561a106515..7070c0a1e59d 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -93,8 +93,7 @@ def torch_adam_update(self, bias_correction1, bias_correction2, use_adamw=False): - # FIXME(ver217): remove the below line when replace torch adam with fused adam - grad = grad.float() + grad = grad.to(data.dtype) if weight_decay != 0: if use_adamw: @@ -133,10 +132,12 @@ def step(self, closure=None, div_scale: float = -1): if len(state) == 0: state['step'] = 0 + # FIXME(ver217): CPU adam kernel only supports fp32 states now + assert p.dtype is torch.float, "CPUAdam only support fp32 parameters" # gradient momentums - state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + state['exp_avg'] = torch.zeros_like(p, device=target_device) # gradient variances - state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + state['exp_avg_sq'] = torch.zeros_like(p, device=target_device) self._post_state_init(p) state['step'] += 1 @@ -147,9 +148,17 @@ def step(self, closure=None, div_scale: float = -1): assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" self._pre_update(p, 'exp_avg', 'exp_avg_sq') - self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], - group['bias_correction'], p.data, p.grad.data, state['exp_avg'], - state['exp_avg_sq'], div_scale) + if p.grad.dtype is torch.bfloat16: + # cpu adam kernel does not support bf16 now + bias_correction1 = 1 - beta1**state['step'] + bias_correction2 = 1 - beta2**state['step'] + self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'], + beta1, beta2, group['eps'], group['weight_decay'], bias_correction1, + bias_correction2, self.adamw_mode) + else: + self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], + group['weight_decay'], group['bias_correction'], p.data, p.grad.data, + state['exp_avg'], state['exp_avg_sq'], div_scale) self._post_update(p, 'exp_avg', 'exp_avg_sq') elif target_device.type == 'cuda': assert div_scale == -1, "div_scale should remain default" diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index be6311c6c29f..526071b06f95 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -1,16 +1,17 @@ from typing import Any, Optional import torch +from torch.optim import Adam -from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder +from colossalai.kernel.op_builder import FusedOptimBuilder from colossalai.registry import OPTIMIZERS from colossalai.utils import multi_tensor_applier -from .nvme_optimizer import NVMeOptimizer +from .cpu_adam import CPUAdam @OPTIMIZERS.register_module -class HybridAdam(NVMeOptimizer): +class HybridAdam(CPUAdam): """Implements Adam algorithm. Supports parameters updating on both GPU and CPU, depanding on the device of parameters. @@ -74,15 +75,9 @@ def __init__(self, nvme_offload_dir: Optional[str] = None, **defaults: Any): - default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) - super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) - self.adamw_mode = adamw_mode - - # build during runtime if not found - cpu_optim = CPUAdamBuilder().load() + super().__init__(model_params, lr, bias_correction, betas, eps, weight_decay, adamw_mode, nvme_offload_fraction, + nvme_offload_dir) fused_optim = FusedOptimBuilder().load() - self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) - self.gpu_adam_op = fused_optim.multi_tensor_adam self._dummy_overflow_buf = torch.cuda.IntTensor([0]) @@ -108,10 +103,12 @@ def step(self, closure=None, div_scale: float = -1): if len(state) == 0: state['step'] = 0 + # FIXME(ver217): CPU adam kernel only supports fp32 states now + assert p.dtype is torch.float, "HybridAdam only support fp32 parameters" # gradient momentums - state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + state['exp_avg'] = torch.zeros_like(p, device=target_device) # gradient variances - state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + state['exp_avg_sq'] = torch.zeros_like(p, device=target_device) self._post_state_init(p) state['step'] += 1 @@ -122,9 +119,17 @@ def step(self, closure=None, div_scale: float = -1): assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" self._pre_update(p, 'exp_avg', 'exp_avg_sq') - self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], - group['bias_correction'], p.data, p.grad.data, state['exp_avg'], - state['exp_avg_sq'], div_scale) + if p.grad.dtype is torch.bfloat16: + # cpu adam kernel does not support bf16 now + bias_correction1 = 1 - beta1**state['step'] + bias_correction2 = 1 - beta2**state['step'] + self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'], + beta1, beta2, group['eps'], group['weight_decay'], bias_correction1, + bias_correction2, self.adamw_mode) + else: + self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], + group['weight_decay'], group['bias_correction'], p.data, p.grad.data, + state['exp_avg'], state['exp_avg_sq'], div_scale) self._post_update(p, 'exp_avg', 'exp_avg_sq') elif target_device.type == 'cuda': diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py new file mode 100644 index 000000000000..2186a421fe00 --- /dev/null +++ b/tests/test_optimizer/test_adam_kernel.py @@ -0,0 +1,131 @@ +# This test checks adam kernels +# Baseline is pure fp32 torch adam optimizer +import math +from abc import abstractmethod +from typing import Type + +import pytest +import torch +from torch import Tensor + +from colossalai.utils import get_current_device, multi_tensor_applier + +_FUSED_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float), + (torch.half, torch.half), (torch.bfloat16, torch.float), (torch.float, torch.bfloat16), + (torch.bfloat16, torch.bfloat16)] + +_CPU_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float), + (torch.half, torch.half)] + + +class AdamKernel: + + def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: + self.lr = lr + self.beta1 = beta1 + self.beta2 = beta2 + self.eps = eps + self.weight_decay = weight_decay + self.use_adamw = use_adamw + + @abstractmethod + def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): + pass + + +class TorchAdamKernel(AdamKernel): + + def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): + bias_correction1 = 1 - self.beta1**step + bias_correction2 = 1 - self.beta2**step + + if self.weight_decay != 0: + if self.use_adamw: + # Perform stepweight decay + param.mul_(1 - self.lr * self.weight_decay) + else: + grad = grad.add(param, alpha=self.weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(self.beta1).add_(grad, alpha=1 - self.beta1) + exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps) + + step_size = self.lr / bias_correction1 + + param.addcdiv_(exp_avg, denom, value=-step_size) + + +class FusedAdamKernel(AdamKernel): + + def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: + super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() + self.fused_adam = fused_optim.multi_tensor_adam + self.dummy_overflow_buf = torch.cuda.IntTensor([0]) + + def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): + multi_tensor_applier(self.fused_adam, self.dummy_overflow_buf, [[grad], [param], [exp_avg], [exp_avg_sq]], + self.lr, self.beta1, self.beta2, self.eps, step, self.use_adamw, True, self.weight_decay, + -1) + + +class CPUAdamKernel(AdamKernel): + + def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: + super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) + from colossalai.kernel.op_builder import CPUAdamBuilder + cpu_optim = CPUAdamBuilder().load() + + self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw) + + def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): + self.cpu_adam_op.step(step, self.lr, self.beta1, self.beta2, self.eps, self.weight_decay, True, param.view(-1), + grad.view(-1), exp_avg.view(-1), exp_avg_sq.view(-1), -1) + + +def check_adam_kernel(kernel: Type[AdamKernel], adamw: bool, weight_decay: float, p_dtype: torch.dtype, + g_dtype: torch.dtype, device: torch.device, n_steps: int, rtol: float, atol: float): + lr = 1e-3 + beta1, beta2 = 0.9, 0.999 + eps = 1e-8 + torch_adam = TorchAdamKernel(lr, beta1, beta2, eps, weight_decay, adamw) + adam_kernel = kernel(lr, beta1, beta2, eps, weight_decay, adamw) + master_p = torch.rand(64, device=device) + master_g = torch.rand_like(master_p) + master_exp_avg = torch.zeros_like(master_p) + master_exp_avg_sq = torch.zeros_like(master_p) + p = master_p.clone().to(p_dtype) + g = master_g.clone().to(g_dtype) + exp_avg = master_exp_avg.clone() + exp_avg_sq = master_exp_avg_sq.clone() + + for step in range(1, 1 + n_steps): + torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq) + adam_kernel.update(step, p, g, exp_avg, exp_avg_sq) + # if overflow, the weight won't be updated. so there will be no nan in p + assert not torch.isnan(p).any() + assert torch.allclose(master_p, p.float(), rtol=rtol, atol=atol) + + +@pytest.mark.parametrize('adamw', [False, True]) +@pytest.mark.parametrize('weight_decay', [0.0, 0.1]) +@pytest.mark.parametrize('p_dtype, g_dtype', _FUSED_ALLOWED_P_G_TYPES) +def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): + rtol, atol = 1e-5, 1e-8 + if p_dtype is torch.float16 or g_dtype is torch.float16: + rtol, atol = 1e-3, 1e-3 + if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16: + rtol, atol = 4e-3, 4e-3 + check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol) + + +@pytest.mark.parametrize('adamw', [False, True]) +@pytest.mark.parametrize('weight_decay', [0.0, 0.1]) +@pytest.mark.parametrize('p_dtype, g_dtype', _CPU_ALLOWED_P_G_TYPES) +def test_cpu_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): + rtol, atol = 1e-5, 1e-8 + if p_dtype is torch.float16 or g_dtype is torch.float16: + rtol, atol = 1e-3, 1e-3 + check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device('cpu'), 3, rtol, atol) diff --git a/tests/test_optimizer/test_adam_optim.py b/tests/test_optimizer/test_adam_optim.py new file mode 100644 index 000000000000..0f72bc134809 --- /dev/null +++ b/tests/test_optimizer/test_adam_optim.py @@ -0,0 +1,86 @@ +from copy import deepcopy +from typing import Type, Union + +import pytest +import torch +import torch.nn as nn +from torch.optim import Adam, AdamW + +from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam +from tests.kit.model_zoo import model_zoo + +_ALLOWED_OPTIM_DEVICES = [ + (FusedAdam, torch.device('cuda:0')), + (CPUAdam, torch.device('cpu')), + (CPUAdam, torch.device('cuda:0')), + (HybridAdam, torch.device('cpu')), + (HybridAdam, torch.device('cuda:0')), +] + +_ALLOWED_P_G_TYPES = [ + (torch.float, torch.float), # pure fp32 + (torch.float, torch.half), # fp16 amp + (torch.float, torch.bfloat16), # bfloat16 amp + # (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16 + # (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16 +] + +N_STEPS = 3 + + +def setup_param_groups(bert_model: nn.Module) -> list: + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": 0.1, + }, + { + "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + +def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None: + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + torch_p.grad = torch.rand_like(torch_p) + # avoid inconsistent grad and param dtype error + orig_p = p.data + p.data = torch_p.grad.clone().to(g_dtype) + p.grad = p.data + p.data = orig_p + + +@pytest.mark.parametrize('optim_cls, device', _ALLOWED_OPTIM_DEVICES) +@pytest.mark.parametrize('adamw', [False, True]) +@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES) +def test_adam_optim_on_bert(optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]], device: torch.device, + adamw: bool, p_dtype: torch.dtype, g_dtype: torch.dtype) -> None: + model_fn, *_ = next(iter(model_zoo.get_sub_registry('transformers_bert_for_sequence_classification').values())) + torch_model = model_fn().to(device) + model = deepcopy(torch_model).to(p_dtype) + lr = 1e-3 + beta1, beta2 = 0.9, 0.999 + eps = 1e-8 + torch_optim_cls = AdamW if adamw else Adam + torch_optim = torch_optim_cls(setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps) + optim = optim_cls(setup_param_groups(model), lr=lr, betas=(beta1, beta2), eps=eps, adamw_mode=adamw) + + rtol, atol = 1e-5, 1e-5 + if p_dtype is torch.float16 or g_dtype is torch.float16: + rtol, atol = 2e-3, 2e-3 + if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16: + rtol, atol = 4e-3, 4e-3 + + for _ in range(N_STEPS): + set_grad(model, torch_model, g_dtype) + torch_optim.step() + optim.step() + torch_optim.zero_grad() + optim.zero_grad() + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + # if overflow, the weight won't be updated. so there will be no nan in p + assert not torch.isnan(p).any() + assert torch.allclose(p.float(), torch_p, rtol=rtol, atol=atol) diff --git a/tests/test_optimizer/test_cpu_adam.py b/tests/test_optimizer/test_cpu_adam.py deleted file mode 100644 index 8b3ecf8517f7..000000000000 --- a/tests/test_optimizer/test_cpu_adam.py +++ /dev/null @@ -1,121 +0,0 @@ -import math - -import torch - -from colossalai.testing import clear_cache_before_run, parameterize - - -def torch_adam_update( - step, - lr, - beta1, - beta2, - eps, - weight_decay, - param, - grad, - exp_avg, - exp_avg_sq, - use_adamw, -): - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - if weight_decay != 0: - if use_adamw: - # Perform stepweight decay - param.mul_(1 - lr * weight_decay) - else: - grad = grad.add(param, alpha=weight_decay) - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) - - step_size = lr / bias_correction1 - - param.addcdiv_(exp_avg, denom, value=-step_size) - - -def assertLess(data_diff, threshold, msg): - assert data_diff < threshold, msg - - -def assertTrue(condition, msg): - assert condition, msg - - -@clear_cache_before_run() -@parameterize('adamw', [True, False]) -@parameterize('step', [1, 2]) -@parameterize('p_dtype', [torch.float, torch.half]) -@parameterize('g_dtype', [torch.float, torch.half]) -def test_cpu_adam(adamw, step, p_dtype, g_dtype): - lr = 1e-3 - beta1, beta2 = 0.9, 0.999 - eps = 1e-8 - weight_decay = 0 - - for i in range(3): - p_data = torch.rand(64, dtype=p_dtype) - p_data_copy = p_data.clone().float() - p_grad = torch.rand(64, dtype=g_dtype) - p_grad_copy = p_grad.clone().float() - exp_avg = torch.rand(p_data.shape) - exp_avg_copy = exp_avg.clone() - exp_avg_sq = torch.rand(p_data.shape) - exp_avg_sq_copy = exp_avg_sq.clone() - - from colossalai.kernel.op_builder import CPUAdamBuilder - cpu_optim = CPUAdamBuilder().load() - - cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw) - - cpu_adam_op.step( - step, - lr, - beta1, - beta2, - eps, - weight_decay, - True, - p_data.view(-1), # fp32 data - p_grad.view(-1), # fp32 grad - exp_avg.view(-1), - exp_avg_sq.view(-1), - -1, - ) - - torch_adam_update( - step, - lr, - beta1, - beta2, - eps, - weight_decay, - p_data_copy, # fp32 data - p_grad_copy, # fp32 grad - exp_avg_copy, - exp_avg_sq_copy, - adamw, - ) - var = p_data_copy - p_data - data_diff = torch.max(torch.abs(var)) - threshold = 1e-3 - assertLess( - data_diff, - threshold, - f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, eps " - f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}", - ) - max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad)) - assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}") - max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg)) - assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}") - max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq)) - assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}") - - -if __name__ == '__main__': - test_cpu_adam() diff --git a/tests/test_optimizer/test_fused_adam.py b/tests/test_optimizer/test_fused_adam.py deleted file mode 100644 index 511987ca3dbf..000000000000 --- a/tests/test_optimizer/test_fused_adam.py +++ /dev/null @@ -1,73 +0,0 @@ -from copy import deepcopy - -import pytest -import torch -import torch.nn as nn -from torch.optim import AdamW -from torch.optim.adam import Adam - -from colossalai.nn.optimizer.fused_adam import FusedAdam - - -class FC(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Sequential(nn.Linear(64, 64)) - - def forward(self, x): - return self.fc(x) - - -@pytest.mark.parametrize('adamw', [False, True]) -@pytest.mark.parametrize('running_p_dtype', [torch.float, torch.half, torch.bfloat16]) -@pytest.mark.parametrize('fp32_master_weights', [False, True]) -def test_adam(adamw, running_p_dtype, fp32_master_weights): - # baseline is fure fp32 torch adam - # g_type is the same as running_p_dtype - if running_p_dtype is torch.float and not fp32_master_weights: - # pure fp32 must have fp32 weights - return - if not fp32_master_weights or running_p_dtype is torch.bfloat16: - # pure low precision or bf16, high tolerance - atol = 4e-3 - rtol = 4e-3 - else: - # fp32 master weights, low tolerance - atol = 2e-3 - rtol = 2e-3 - torch_model = FC().cuda() - model = deepcopy(torch_model).to(running_p_dtype) - - torch_optim_cls = AdamW if adamw else Adam - torch_optim = torch_optim_cls(torch_model.parameters(), lr=1e-3) - optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=adamw) - - data = torch.rand(10, 64).cuda() - label = torch.rand(10, 64).cuda() - - for d, l in zip(data.to(running_p_dtype), label.to(running_p_dtype)): - y = model(d) - loss = ((l - y)**2).sum() - optim.zero_grad() - loss.backward() - if fp32_master_weights: - for p in model.parameters(): - p.data = p.data.float() - optim.step() - if fp32_master_weights: - for p in model.parameters(): - p.data = p.data.to(running_p_dtype) - - for d, l in zip(data, label): - y = torch_model(d) - loss = ((l - y)**2).sum() - torch_optim.zero_grad() - loss.backward() - torch_optim.step() - - for p, torch_p in zip(model.parameters(), torch_model.parameters()): - if torch.isnan(p).any() or torch.isnan(torch_p).any(): - continue - fp32_p = p.float() - assert torch.allclose(fp32_p, torch_p, atol=atol, rtol=rtol) diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py deleted file mode 100644 index deabaebfec75..000000000000 --- a/tests/test_optimizer/test_fused_adam_kernel.py +++ /dev/null @@ -1,95 +0,0 @@ -import math - -import pytest -import torch - -from colossalai.utils import multi_tensor_applier - -_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float), - (torch.half, torch.half), (torch.bfloat16, torch.float), (torch.float, torch.bfloat16), - (torch.bfloat16, torch.bfloat16)] - - -def torch_adam_update( - step, - lr, - beta1, - beta2, - eps, - weight_decay, - param, - grad, - exp_avg, - exp_avg_sq, - use_adamw, -): - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - if weight_decay != 0: - if use_adamw: - # Perform stepweight decay - param.mul_(1 - lr * weight_decay) - else: - grad = grad.add(param, alpha=weight_decay) - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) - - step_size = lr / bias_correction1 - - param.addcdiv_(exp_avg, denom, value=-step_size) - - -@pytest.mark.parametrize('adamw', [False, True]) -@pytest.mark.parametrize('step', [1, 2]) -@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES) -def test_adam(adamw, step, p_dtype, g_dtype): - from colossalai.kernel.op_builder import FusedOptimBuilder - fused_optim = FusedOptimBuilder().load() - fused_adam = fused_optim.multi_tensor_adam - - dummy_overflow_buf = torch.cuda.IntTensor([0]) - - count = 0 - - for i in range(3): - p = torch.rand(64, dtype=p_dtype).cuda() - p_copy = p.clone().float() - g = torch.rand(p.shape, dtype=g_dtype).cuda() - g_copy = g.clone().float() - m = torch.rand(p.shape).cuda() - m_copy = m.clone() - v = torch.rand(p.shape).cuda() - v_copy = v.clone() - - lr = 1e-3 - beta1, beta2 = 0.9, 0.999 - eps = 1e-8 - weight_decay = 0 - - multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw, - True, weight_decay, -1) - - torch_adam_update( - step, - lr, - beta1, - beta2, - eps, - weight_decay, - p_copy, # fp32 data - g_copy, # fp32 grad - m_copy, - v_copy, - adamw, - ) - - if torch.isnan(p).any() or torch.isnan(p_copy).any(): - count += 1 - continue - assert count < 200, "too many nans" - assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5, - 1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}" diff --git a/tests/test_optimizer/test_hybrid_adam.py b/tests/test_optimizer/test_hybrid_adam.py deleted file mode 100644 index d075149dfcb1..000000000000 --- a/tests/test_optimizer/test_hybrid_adam.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch -import torch.nn as nn -from torch.optim import AdamW -from torch.optim.adam import Adam - -from colossalai.nn.optimizer.hybrid_adam import HybridAdam -from colossalai.testing import clear_cache_before_run, parameterize - -RE = 3 - - -@clear_cache_before_run() -@parameterize('adamw', [False, True]) -@parameterize('device', ['cpu', 'cuda:0']) -@parameterize('p_dtype', [torch.float]) -@parameterize('g_dtype', [torch.float, torch.half]) -def test_adam(adamw, device, p_dtype, g_dtype): - rng_state = torch.get_rng_state() - p = nn.Parameter(torch.rand(64).to(device, p_dtype)) - torch.set_rng_state(rng_state) - p_copy = nn.Parameter(torch.rand(64).to(device).float()) - - if adamw: - optim = HybridAdam([p], lr=1e-3, adamw_mode=True) - torch_optim = AdamW([p_copy], lr=1e-3) - else: - optim = HybridAdam([p], lr=1e-3) - torch_optim = Adam([p_copy], lr=1e-3) - - print(f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}") - for i in range(RE): - p.grad = torch.rand(64).to(device, p_dtype) - p_copy.grad = p.grad.clone().float() - p.grad.data = p.grad.data.to(g_dtype) - - optim.step() - torch_optim.step() - - if torch.isnan(p.data).any() or torch.isnan(p_copy.data).any(): - continue - assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2), \ - f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}" From d175d92a4f3bacf15508573f5d2e99957afdf861 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 31 May 2023 13:35:23 +0800 Subject: [PATCH 3/6] [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869) * [bf16] add mixed precision mixin * [bf16] low level zero optim support bf16 * [text] update low level zero test * [text] fix low level zero grad acc test --- .../mixed_precision_mixin/__init__.py | 9 ++ .../naive_amp/mixed_precision_mixin/base.py | 91 +++++++++++++++ .../naive_amp/mixed_precision_mixin/bf16.py | 23 ++++ .../naive_amp/mixed_precision_mixin/fp16.py | 84 ++++++++++++++ colossalai/zero/low_level/low_level_optim.py | 106 +++++++++--------- .../test_zero/test_low_level/test_grad_acc.py | 5 +- .../test_zero/test_low_level/test_zero1_2.py | 35 +++--- 7 files changed, 281 insertions(+), 72 deletions(-) create mode 100644 colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py create mode 100644 colossalai/amp/naive_amp/mixed_precision_mixin/base.py create mode 100644 colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py create mode 100644 colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py new file mode 100644 index 000000000000..b0348e1477bb --- /dev/null +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py @@ -0,0 +1,9 @@ +from .base import MixedPrecisionMixin +from .bf16 import BF16MixedPrecisionMixin +from .fp16 import FP16MixedPrecisionMixin + +__all__ = [ + 'MixedPrecisionMixin', + 'FP16MixedPrecisionMixin', + 'BF16MixedPrecisionMixin', +] diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py new file mode 100644 index 000000000000..a52a9747ad1e --- /dev/null +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py @@ -0,0 +1,91 @@ +from abc import ABC, abstractmethod + +import torch +from torch import Tensor + + +class MixedPrecisionMixin(ABC): + """A helper class for mixed precision training. This mixin is used in mixed precision optimizers. + + Attributes: + dtype (torc.dtype): The expected dtype of the gradients. + + Examples: + ```python + class MyMixedPrecisionOptimizer(OptimizerWrapper): + def __init__(self, optim: Optimizer): + super().__init__(optim) + self.mixed_precision = MixedPrecisionMixin() + + def backward(self, loss): + loss = self.mixed_precision.pre_backward(loss) + loss.backward() + + def backward_by_grad(self, tensor, grad): + grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) + tensor.backward(grad) + + def step(self): + if self.mixed_precision.should_skip_step(): + self.zero_grad() + return + div_scale = self.mixed_precision.get_grad_div_scale() + # maybe clip grad here + # maybe scale grad here + self.optim.step() + + def zero_grad(self): + self.mixed_precision.pre_zero_grad() + return self.optim.zero_grad() + ``` + """ + dtype: torch.dtype + + @abstractmethod + def pre_backward(self, loss: Tensor) -> Tensor: + """Called before backward. + + Args: + loss (Tensor): Loss value. + + Returns: + Tensor: Loss value (possibly scaled). + """ + pass + + @abstractmethod + def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor: + """Called before backward by grad. This is helpful for pipeline parallelism. + + Args: + tensor (Tensor): Tensor to backward. + grad (Tensor): Gradient of the tensor. + + Returns: + Tensor: Gradient of the tensor (possibly scaled). + """ + pass + + @abstractmethod + def should_skip_step(self) -> bool: + """Called before step. + + Returns: + bool: Whether to skip the step. + """ + pass + + @abstractmethod + def pre_zero_grad(self) -> None: + """Called before zero_grad. + """ + pass + + @abstractmethod + def get_grad_div_scale(self) -> float: + """Called before step or clip_grad. To keep computation efficiency, this method does not (maybe) unscale grads. + + Returns: + float: A divisor for gradient clipping or step. + """ + pass diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py new file mode 100644 index 000000000000..9454f6eb8413 --- /dev/null +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py @@ -0,0 +1,23 @@ +import torch +from torch import Tensor + +from .base import MixedPrecisionMixin + + +class BF16MixedPrecisionMixin(MixedPrecisionMixin): + dtype = torch.bfloat16 + + def pre_backward(self, loss: Tensor) -> Tensor: + return loss + + def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor: + return grad + + def should_skip_step(self) -> bool: + return False + + def pre_zero_grad(self) -> None: + pass + + def get_grad_div_scale(self) -> float: + return 1.0 diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py new file mode 100644 index 000000000000..1ce8e42eb3ed --- /dev/null +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py @@ -0,0 +1,84 @@ +from abc import abstractmethod +from enum import Enum + +import torch +import torch.distributed as dist +from torch import Tensor + +from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.utils import get_current_device + +from .base import MixedPrecisionMixin + + +class OptimState(Enum): + SCALED = 0 + UNSCALED = 1 + + +class FP16MixedPrecisionMixin(MixedPrecisionMixin): + dtype = torch.float16 + + def __init__(self, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32) -> None: + super().__init__() + self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale) + self.optim_state = OptimState.UNSCALED + self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device()) + + @property + def loss_scale(self) -> float: + return self.grad_scaler.scale.item() + + @abstractmethod + def check_local_overflow(self) -> bool: + """Check whether there is overflow in the local process. This method should be implemented by subclasses. + + Returns: + bool: Whether there is overflow in the local process. + """ + pass + + def check_overflow(self) -> bool: + # clear previous overflow record + self.found_overflow.fill_(0.0) + if self.check_local_overflow(): + self.found_overflow.fill_(1.0) + dist.all_reduce(self.found_overflow, op=dist.ReduceOp.MAX) + return self.found_overflow.item() > 0 + + def pre_backward(self, loss: Tensor) -> Tensor: + loss = self.loss_scale * loss + self.optim_state = OptimState.SCALED + return loss + + def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor: + self.optim_state = OptimState.SCALED + return grad + + def should_skip_step(self) -> bool: + found_inf = self.check_overflow() + self.grad_scaler.update(found_inf) + if found_inf: + self.optim_state = OptimState.UNSCALED + return found_inf + + def pre_zero_grad(self) -> None: + pass + + def get_grad_div_scale(self) -> float: + assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping' + self.optim_state = OptimState.UNSCALED + return self.loss_scale diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 3e7661ecab76..d4d03e5b5fcd 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -6,7 +6,11 @@ import torch.distributed as dist from torch.optim import Optimizer -from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.amp.naive_amp.mixed_precision_mixin import ( + BF16MixedPrecisionMixin, + FP16MixedPrecisionMixin, + MixedPrecisionMixin, +) from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger @@ -27,6 +31,31 @@ from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket +class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): + + def __init__(self, + num_working_param_groups: int, + grad_store: GradientStore, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32) -> None: + super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, + max_scale) + self.num_working_param_groups = num_working_param_groups + self.grad_store = grad_store + + def check_local_overflow(self) -> bool: + for group_id in range(self.num_working_param_groups): + for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id): + if avg_grad is not None and has_inf_or_nan(avg_grad): + return True + return False + + class LowLevelZeroOptimizer(ColossalaiOptimizer): """Optimizer used for ZeRO-1 and ZeRO-2. """ @@ -100,17 +129,6 @@ def __init__( self._reduce_bucket_size = reduce_bucket_size self._communication_dtype = communication_dtype - # gradient scaler - self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale, - verbose=verbose) - self._found_overflow = torch.FloatTensor([0]).to(get_current_device()) - # gradient clipping self._clip_grad_norm = clip_grad_norm @@ -200,14 +218,25 @@ def __init__( if self._overlap_communication or self._partition_grads: self._attach_reduction_hook() + # initialize mixed precision mixin + self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None + if self._dtype is torch.float16: + self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups, + self._grad_store, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale) + elif self._dtype is torch.bfloat16: + self.mixed_precision_mixin = BF16MixedPrecisionMixin() + @property def dtype(self): return self._dtype - @property - def loss_scale(self): - return self.grad_scaler.scale - @property def num_param_groups(self): return len(self._working_param_groups) @@ -392,7 +421,8 @@ def _add_to_reduction_bucket(self, param, reduce_rank=None): ################################ def backward(self, loss, retain_graph=False, sync_grad=True): - loss = self.loss_scale * loss + if self.mixed_precision_mixin is not None: + loss = self.mixed_precision_mixin.pre_backward(loss) loss.backward(retain_graph=retain_graph) # finish gradient reduction @@ -419,6 +449,8 @@ def zero_grad(self, set_to_none=True): :param set_to_none: Whether set the gradient to None. Default value is True. :type set_to_none: bool """ + if self.mixed_precision_mixin is not None: + self.mixed_precision_mixin.pre_zero_grad() for _, param_group in self._working_param_groups.items(): for param in param_group: if set_to_none: @@ -435,12 +467,7 @@ def zero_grad(self, set_to_none=True): def step(self, closure=None): assert closure is None, 'closure is not supported by step()' - # check for overflow - found_inf = self._check_overflow() - self.grad_scaler.update(found_inf) - - # update loss scale if overflow occurs - if found_inf: + if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step(): self._grad_store.reset_all_average_gradients() if self._verbose: self._logger.info(f'Found overflow. Skip step') @@ -507,41 +534,20 @@ def step(self, closure=None): # Mixed Precision Utilities # ############################# - def _check_overflow(self): - # clear previous overflow record - self._found_overflow.fill_(0.0) - - # check for overflow - for group_id in range(len(self._working_param_groups)): - for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id): - if avg_grad is not None and has_inf_or_nan(avg_grad): - self._found_overflow.fill_(1.0) - break - - # all-reduce across dp group - dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group) - - # all-reduce over model parallel group - if self._mp_torch_group: - dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group) - - if self._found_overflow.item() > 0: - return True - else: - return False - def _unscale_and_clip_grads(self, grad_groups_flat, total_norm): # compute combined scale factor for this group - combined_scale = self.loss_scale + div_scale = 1.0 + if self.mixed_precision_mixin is not None: + div_scale = self.mixed_precision_mixin.get_grad_div_scale() if self._clip_grad_norm > 0.: # norm is in fact norm*scale - clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm + clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm if clip > 1: - combined_scale = clip * self.loss_scale + div_scale = clip * div_scale for grad in grad_groups_flat: - grad.data.mul_(1. / combined_scale) + grad.data.mul_(1. / div_scale) ############################ # Gradient Synchronization # diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index 2ae1f3a99d79..c264a8077d2a 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -82,7 +82,6 @@ def fwd_bwd_func(number, cur_data): def exam_zero_1_grad_acc(): local_rank = torch.distributed.get_rank() - grad_scale = 32 seed_all(2008) # create models @@ -101,7 +100,6 @@ def exam_zero_1_grad_acc(): # level 1 and 2 will produce exactly the same results zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, overlap_communication=False, - initial_scale=grad_scale, reduce_bucket_size=262144, clip_grad_norm=1.0) @@ -128,9 +126,8 @@ def fwd_bwd_func(number, cur_data, check_flag): if check_flag: # check grad for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - unscale_grad = z1p.grad / grad_scale # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad))) - assert torch.equal(p.grad, unscale_grad) + assert torch.equal(p.grad, z1p.grad) zero_optimizer._sync_grad() diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index 4086af9d896e..8e2206fe6c8d 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -7,7 +7,7 @@ from torch.testing import assert_close import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from colossalai.zero import LowLevelZeroOptimizer @@ -25,15 +25,18 @@ def forward(self, x): return x -def half_close(a, b, loose=False): +def loose_close(a, b, dtype: torch.dtype = torch.float32): rtol = None atol = None - if loose: + if dtype is torch.float16: rtol = 5e-2 atol = 5e-4 + elif dtype is torch.bfloat16: + rtol = 4e-3 + atol = 4e-3 - a = a.detach().half() - b = b.detach().half() + a = a.detach().to(dtype) + b = b.detach().to(dtype) assert_close(a, b, rtol=rtol, atol=atol) @@ -96,7 +99,8 @@ def exam_zero_1_2(): assert torch.equal(z1p.data, z2p.data) -def exam_zero_1_torch_ddp(): +@parameterize('dtype', [torch.float16, torch.bfloat16]) +def exam_zero_1_torch_ddp(dtype: torch.dtype): """ In this test, two pairs of model and optimizers are created. 1. zero: use sharded optimizer and fp16 parameters @@ -109,15 +113,10 @@ def exam_zero_1_torch_ddp(): seed_all(1453) # create models - zero_model = MlpModel() - torch_model = copy.deepcopy(zero_model) + torch_model = MlpModel().cuda() + zero_model = copy.deepcopy(torch_model).to(dtype) - zero_model = zero_model.cuda().half() - torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0) - torch_model = torch_model.cuda() - - # for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - # half_close(p.data, z1p.data) + torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0).cuda() # create optimizer zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) @@ -137,11 +136,11 @@ def exam_zero_1_torch_ddp(): input_data = torch.rand(32, 128).cuda() # zero-dp forward - zero_output = zero_model(input_data.half()) + zero_output = zero_model(input_data.to(dtype)) # torch-ddp forward torch_output = torch_model(input_data) - half_close(zero_output, torch_output, loose=True) + loose_close(zero_output, torch_output, dtype=dtype) # zero-dp backward zero_optimizer.backward(zero_output.mean().float(), sync_grad=False) @@ -151,7 +150,7 @@ def exam_zero_1_torch_ddp(): # check grad for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - half_close(p.grad, z1p.grad, loose=True) + loose_close(p.grad, z1p.grad, dtype=dtype) # zero-dp step zero_optimizer._sync_grad() @@ -163,7 +162,7 @@ def exam_zero_1_torch_ddp(): # check updated param for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): # print(n, torch.max(torch.abs(p.data - z1p.data))) - half_close(p.data, z1p.data, loose=True) + loose_close(p.data, z1p.data, dtype=dtype) def run_dist(rank, world_size, port): From c9c3c823eb8b0b5dec972ba616819de0ab00642d Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 1 Jun 2023 12:58:25 +0800 Subject: [PATCH 4/6] [bf16] add bf16 support for gemini (#3872) * [bf16] gemini support bf16 * [test] update gemini bf16 test * [doc] update gemini docstring --- colossalai/zero/gemini/gemini_ddp.py | 82 +++++++++++-------- colossalai/zero/gemini/gemini_optimizer.py | 92 ++++++++++------------ tests/test_zero/test_gemini/test_optim.py | 46 ++++++++--- 3 files changed, 125 insertions(+), 95 deletions(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 878c25be7094..7e230896aac8 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -2,7 +2,7 @@ from collections import OrderedDict from contextlib import nullcontext from functools import partial -from typing import Dict, Iterator, List, Optional, Union, Tuple, Set +from typing import Dict, Iterator, List, Optional, Set, Tuple, Union import torch import torch.distributed as dist @@ -51,6 +51,7 @@ class ZeroDDP(ColoDDP): strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated. Defaults to False. Users can set it to True, when they clearly know that they only need DDP. scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference. + mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16. """ def __init__(self, @@ -59,7 +60,9 @@ def __init__(self, pin_memory: bool = False, force_outputs_fp32: bool = False, strict_ddp_mode: bool = False, - scatter_after_inference: bool = True) -> None: + scatter_after_inference: bool = True, + mixed_precision: torch.dtype = torch.float16) -> None: + assert mixed_precision in (torch.float16, torch.bfloat16) self.gemini_manager = gemini_manager self.chunk_manager: ChunkManager = gemini_manager.chunk_manager self.force_outputs_fp32 = force_outputs_fp32 @@ -71,6 +74,7 @@ def __init__(self, self.param2name: Dict[nn.Parameter, str] = dict() self.name2param: Dict[str, nn.Parameter] = dict() self.scatter_after_inference = scatter_after_inference + self.mixed_precision = mixed_precision self._logger = get_dist_logger() @@ -96,34 +100,38 @@ def __init__(self, param_name = m_name + '.' + p_name if m_name else p_name self.name2param[param_name] = p_var super().__init__(module, process_group=ColoProcessGroup()) - self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module) + self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module) self._cast_buffers() - def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True): - - r""" - Args: - memo: a memo to store the set of modules already added to the result - prefix: a prefix that will be added to the name of the module - remove_duplicate: whether to remove the duplicated module instances in the result - or not - """ - - if memo is None: - memo = set() - self_non_persistent_set = set() - if module not in memo: - if remove_duplicate: - memo.add(module) - self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set)) - for name, sub_module in module._modules.items(): - if sub_module is None: - continue - submodule_prefix = prefix + ('.' if prefix else '') + name - child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate) - self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set) - return self_non_persistent_set - + def _get_non_persistent_buffers_set(self, + module, + memo: Optional[Set[nn.Module]] = None, + prefix: str = '', + remove_duplicate: bool = True): + r""" + Args: + memo: a memo to store the set of modules already added to the result + prefix: a prefix that will be added to the name of the module + remove_duplicate: whether to remove the duplicated module instances in the result + or not + """ + + if memo is None: + memo = set() + self_non_persistent_set = set() + if module not in memo: + if remove_duplicate: + memo.add(module) + self_non_persistent_set = set( + map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set)) + for name, sub_module in module._modules.items(): + if sub_module is None: + continue + submodule_prefix = prefix + ('.' if prefix else '') + name + child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, + remove_duplicate) + self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set) + return self_non_persistent_set def _post_forward(self): """This function is only triggered for inference. @@ -147,7 +155,7 @@ def forward(self, *args, **kwargs): assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup( ), "You should run a completed iteration as your warmup iter" - args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) + args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision) self.module.zero_grad(set_to_none=True) if not grad_flag: outputs = self._inference_forward(*args, **kwargs) @@ -566,14 +574,14 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi # move ignored parameters to CUDA if is_ddp_ignored(p): - p.data = p.data.to(device=get_current_device(), dtype=torch.float16) + p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision) continue # create a fp32 parameter fp32_data = p.data.float() fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) # create a fp16 parameter - p.data = p.data.half() + p.data = p.data.to(self.mixed_precision) # register the fp16 parameter and fp32 parameter in the chunk manager dp_world_size = p.process_group.dp_world_size() @@ -609,7 +617,7 @@ def _cast_buffers(self): buffer.materialize() buffer.data = buffer.cuda() if torch.is_floating_point(buffer): - buffer.data = buffer.half() + buffer.data = buffer.to(self.mixed_precision) def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None: """Convert parameter to ColoParameter in-place. @@ -732,6 +740,7 @@ def __init__(self, hidden_dim: Optional[int] = None, min_chunk_size_mb: float = 32, memstats: Optional[MemStats] = None, + mixed_precision: torch.dtype = torch.float16, verbose: bool = False) -> None: """ A torch.Module wrapper using ZeRO-DP and Gemini. @@ -772,5 +781,10 @@ def __init__(self, strict_ddp_flag=strict_ddp_mode, verbose=verbose) gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) - super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode, - scatter_after_inference) + super().__init__(module, + gemini_manager, + pin_memory, + force_outputs_fp32, + strict_ddp_mode, + scatter_after_inference, + mixed_precision=mixed_precision) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 71c4f65cb8d2..267deb1e8699 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -1,7 +1,6 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch import math import warnings -from enum import Enum from typing import Any, Dict, Set, Tuple import torch @@ -9,7 +8,7 @@ from torch.nn import Parameter from torch.optim import Optimizer -from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam from colossalai.utils import disposable, get_current_device, is_ddp_ignored @@ -22,9 +21,26 @@ _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam} -class OptimState(Enum): - SCALED = 0 - UNSCALED = 1 +class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): + + def __init__(self, + module: ZeroDDP, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32) -> None: + super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, + max_scale) + self.module = module + + def check_local_overflow(self) -> bool: + return self.module.overflow_counter > 0 + + def pre_zero_grad(self) -> None: + self.module.overflow_counter = 0 class ZeroOptimizer(ColossalaiOptimizer): @@ -79,7 +95,6 @@ def __init__(self, self.module = module self.gemini_manager = module.gemini_manager self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager - self.optim_state = OptimState.UNSCALED self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict() self.param_to_chunk32: Dict[Parameter, Chunk] = dict() self.chunk16_set: Set[Chunk] = set() @@ -107,15 +122,20 @@ def __init__(self, self.__init__optimizer() - # Grad scaler - self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) - self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + if module.mixed_precision is torch.float16: + self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(module, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale) + elif module.mixed_precision is torch.bfloat16: + self.mix_precision_mixin = BF16MixedPrecisionMixin() + else: + raise RuntimeError(f"Unsupported mixed precision type: {module.mixed_precision}") + self._logger = get_dist_logger() self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) @@ -151,15 +171,6 @@ def _update_fp16_params(self): for chunk16 in self.chunk16_set: chunk16.optim_update() - def _check_overflow(self): - # clear previous overflow record - self._found_overflow.fill_(self.module.overflow_counter) - - # all-reduce across global group - dist.all_reduce(self._found_overflow) - - return self._found_overflow.item() > 0 - def _clear_global_norm(self) -> None: for c16 in self.chunk16_set: c16.l2_norm = None @@ -190,40 +201,25 @@ def _calc_global_norm(self) -> float: return global_norm def _get_combined_scale(self): - loss_scale = 1 - - if self.optim_state == OptimState.SCALED: - loss_scale = self.loss_scale - self.optim_state = OptimState.UNSCALED + div_scale = self.mix_precision_mixin.get_grad_div_scale() - combined_scale = loss_scale if self.clipping_flag: total_norm = self._calc_global_norm() - clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm + clip = ((total_norm / div_scale) + 1e-6) / self.max_norm if clip > 1: - combined_scale = clip * loss_scale + div_scale = clip * div_scale - if combined_scale == 1: - return -1 - else: - return combined_scale - - @property - def loss_scale(self): - return self.grad_scaler.scale.item() + return -1 if div_scale == 1.0 else div_scale def zero_grad(self, *args, **kwargs): - self.module.overflow_counter = 0 + self.mix_precision_mixin.pre_zero_grad() return self.optim.zero_grad(set_to_none=True) def step(self, *args, **kwargs): self._maybe_move_fp32_params() self._set_grad_ptr() - found_inf = self._check_overflow() - if found_inf: - self.optim_state = OptimState.UNSCALED # no need to unscale grad - self.grad_scaler.update(found_inf) # update gradient scaler + if self.mix_precision_mixin.should_skip_step(): if self.verbose: self._logger.info(f'Found overflow. Skip step') self._clear_global_norm() # clear recorded norm @@ -234,7 +230,6 @@ def step(self, *args, **kwargs): # get combined scale. combined scale = loss scale * clipping norm # so that gradient = gradient / combined scale combined_scale = self._get_combined_scale() - self.grad_scaler.update(found_inf) ret = self.optim.step(div_scale=combined_scale, *args, **kwargs) self._register_states() @@ -246,8 +241,7 @@ def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: flo raise NotImplementedError def backward(self, loss: torch.Tensor): - loss = self.loss_scale * loss - self.optim_state = OptimState.SCALED + loss = self.mix_precision_mixin.pre_backward(loss) self.module.backward(loss) def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor): @@ -255,7 +249,7 @@ def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor): # It receives the scaled grad from the previous rank # No need to scale the grad again # Need to unscale when optimizing - self.optim_state = OptimState.SCALED + grad = self.mix_precision_mixin.pre_backward_by_grad(grad) self.module.backward_by_grad(tensor, grad) def _maybe_move_fp32_params(self): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 8ce20c16e8f9..66611bcd2419 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -21,23 +21,40 @@ # these models are too small, all parameters in these models are compacted into one chunk EXAMPLE_MODELS = ['albert', 'beit', 'bert', 'hanging_param_model', 'nested_model', 'repeated_computed_layers'] +# bfloat16 cannot represent them exactly +BF16_IGNORED_KEYS = [ + 'albert.embeddings.word_embeddings.weight', + 'albert.embeddings.position_embeddings.weight', + 'masked_bias', +] -def check_param(model: ZeroDDP, torch_model: torch.nn.Module): - zero_dict = model.state_dict(only_rank_0=False) + +def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype): + zero_dict = model.state_dict(only_rank_0=False, dtype=dtype) torch_dict = torch_model.state_dict() for key, value in torch_dict.items(): # key is 'module.model.PARAMETER', so we truncate it key = key[7:] assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) - temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + temp_zero_value = zero_dict[key].to(device=value.device) + if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS): + continue + rtol, atol = 1e-3, 4e-3 + if dtype is torch.bfloat16: + rtol, atol = 4e-3, 8e-3 # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) - assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) + assert_close(value.float(), + temp_zero_value.float(), + rtol=rtol, + atol=atol, + msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}') @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) @parameterize('model_name', TEST_MODELS) -def exam_model_step(placement_policy, model_name: str): +@parameterize('mixed_precision', [torch.half, torch.bfloat16]) +def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dtype): set_seed(42) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -65,7 +82,7 @@ def exam_model_step(placement_policy, model_name: str): init_device = None chunk_manager = ChunkManager(config_dict, init_device=init_device) gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128) @@ -74,6 +91,7 @@ def exam_model_step(placement_policy, model_name: str): torch_model.eval() set_seed(dist.get_rank() * 3 + 128) + rtol, atol = 1e-4, 1e-5 for i, (input_ids, label) in enumerate(train_dataloader): if i > 2: break @@ -83,17 +101,18 @@ def exam_model_step(placement_policy, model_name: str): torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss) + assert_close(torch_loss, loss, rtol=rtol, atol=atol) zero_optim.step() torch_optim.step() - check_param(model, torch_model) + check_param(model, torch_model, mixed_precision) @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) @parameterize('model_name', EXAMPLE_MODELS) -def exam_tiny_example(placement_policy, model_name: str): +@parameterize('mixed_precision', [torch.half, torch.bfloat16]) +def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.dtype): set_seed(2008) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -113,7 +132,7 @@ def exam_tiny_example(placement_policy, model_name: str): chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1) gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) @@ -121,6 +140,9 @@ def exam_tiny_example(placement_policy, model_name: str): torch_model.eval() set_seed(dist.get_rank() * 3 + 128) + rtol, atol = 1.5e-6, 2e-5 + if mixed_precision is torch.bfloat16: + rtol, atol = 2e-3, 2e-3 for i, (input_ids, label) in enumerate(train_dataloader): if i > 2: break @@ -133,12 +155,12 @@ def exam_tiny_example(placement_policy, model_name: str): torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss, rtol=1.5e-6, atol=2e-5) # atol should be 2e-5 for torch lower than 1.12 + assert_close(torch_loss, loss, rtol=rtol, atol=atol) # atol should be 2e-5 for torch lower than 1.12 zero_optim.step() torch_optim.step() - check_param(model, torch_model) + check_param(model, torch_model, mixed_precision) def run_dist(rank, world_size, port): From 05ec0f6de22668e9d7d2d0e12e5de6a8e5aead1b Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 1 Jun 2023 15:27:58 +0800 Subject: [PATCH 5/6] [bf16] add bf16 support for plugins (#3877) --- colossalai/booster/plugin/gemini_plugin.py | 9 ++++- .../booster/plugin/low_level_zero_plugin.py | 33 ++++++++++++------- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index adbf4803eefe..46714fe1c679 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -23,6 +23,9 @@ __all__ = ['GeminiPlugin'] +SUPPORTED_PRECISION = ['fp16', 'bf16'] +PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16} + class GeminiCheckpointIO(GeneralCheckpointIO): @@ -171,6 +174,7 @@ class GeminiPlugin(DPPluginBase): Args: device (torch.device): device to place the model. placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu". + precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'. pin_memory (bool, optional): use pin memory on CPU. Defaults to False. force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False. @@ -203,6 +207,7 @@ def __init__( self, device: Optional[torch.device] = None, placement_policy: str = "cpu", + precision: str = "fp16", pin_memory: bool = False, force_outputs_fp32: bool = False, strict_ddp_mode: bool = False, @@ -223,6 +228,7 @@ def __init__( verbose: bool = False, ) -> None: super().__init__() + assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported' self.gemini_config = dict( device=(device or get_current_device()), placement_policy=placement_policy, @@ -233,6 +239,7 @@ def __init__( hidden_dim=hidden_dim, min_chunk_size_mb=min_chunk_size_mb, memstats=memstats, + mixed_precision=PRECISION_STR_TO_DTYPE[precision], ) self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,) self.optim_kwargs = dict(initial_scale=initial_scale, @@ -253,7 +260,7 @@ def control_precision(self) -> bool: return True def supported_precisions(self) -> List[str]: - return ['fp16'] + return SUPPORTED_PRECISION def control_device(self) -> bool: return True diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 5d93cf0e33be..2b312d0f9947 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -1,4 +1,5 @@ import warnings +from functools import partial from typing import Callable, Iterator, List, Optional, Tuple, Union import torch @@ -20,12 +21,15 @@ __all__ = ['LowLevelZeroPlugin'] -def _convert_to_fp16(x): +def _convert_floating_point(x, dtype: torch.dtype = torch.float16): if isinstance(x, torch.Tensor) and torch.is_floating_point(x): - return x.half() + return x.to(dtype) return x +SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32'] + + class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): @@ -49,17 +53,24 @@ class LowLevelZeroModel(ModelWrapper): def __init__(self, module: nn.Module, stage: int, precision: str) -> None: super().__init__(module) - self.convert_inputs = (precision == 'fp16') - module = zero_model_wrapper(module, zero_stage=stage) + self.dtype = None if precision == 'fp16': - module = module.half() + self.dtype = torch.float16 + elif precision == 'bf16': + self.dtype = torch.bfloat16 + module = zero_model_wrapper(module, zero_stage=stage) + if self.dtype is not None: + module = module.to(self.dtype) module = module.to(get_current_device()) self.module = module + self.convert_fn = None + if self.dtype is not None: + self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) def forward(self, *args, **kwargs): - if self.convert_inputs: - args = tree_map(_convert_to_fp16, args) - kwargs = tree_map(_convert_to_fp16, kwargs) + if self.convert_fn is not None: + args = tree_map(self.convert_fn, args) + kwargs = tree_map(self.convert_fn, kwargs) return super().forward(*args, **kwargs) @@ -110,7 +121,7 @@ class LowLevelZeroPlugin(DPPluginBase): Args: strage (int, optional): ZeRO stage. Defaults to 1. - precision (str, optional): precision. Support 'fp16' and 'fp32'. Defaults to 'fp16'. + precision (str, optional): precision. Support 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'. initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. @@ -149,7 +160,7 @@ def __init__( ) -> None: super().__init__() assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training' - assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training' + assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training' self.stage = stage self.precision = precision @@ -175,7 +186,7 @@ def control_precision(self) -> bool: return True def supported_precisions(self) -> List[str]: - return ['fp16', 'fp32'] + return SUPPORTED_PRECISION def control_device(self) -> bool: return True From 7f18c7e1670a0ae9a2ec46dd93a3eac3a70ed863 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 1 Jun 2023 16:49:40 +0800 Subject: [PATCH 6/6] [bf16] add bf16 support for legacy zero (#3879) * [zero] init context support bf16 * [zero] legacy zero support bf16 * [test] add zero bf16 test * [doc] add bf16 related docstring for legacy zero --- .../zero/legacy/init_ctx/init_context.py | 11 ++++-- .../zero/legacy/sharded_model/_utils.py | 10 ++++- .../legacy/sharded_model/sharded_model_v2.py | 7 +++- .../legacy/sharded_optim/sharded_optim_v2.py | 39 ++++++++++++------- .../test_zero/test_legacy/test_zero_engine.py | 21 ++++++---- 5 files changed, 62 insertions(+), 26 deletions(-) diff --git a/colossalai/zero/legacy/init_ctx/init_context.py b/colossalai/zero/legacy/init_ctx/init_context.py index a921ca0aa83a..a3fa46b38b5a 100644 --- a/colossalai/zero/legacy/init_ctx/init_context.py +++ b/colossalai/zero/legacy/init_ctx/init_context.py @@ -14,7 +14,7 @@ from colossalai.logging import get_dist_logger from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses from colossalai.zero.legacy.shard_utils import BaseShardStrategy -from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16 from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2 from colossalai.zero.legacy.sharded_param import ShardedParamV2 @@ -55,6 +55,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): seed (int, optional): Random seed for weight initialization shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False. default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16. + bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False. model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int). """ @@ -64,6 +65,7 @@ def __init__(self, seed: int = 2**10 - 1, shard_param: bool = False, default_dtype: Optional[torch.dtype] = None, + bf16: bool = False, model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)): super().__init__(default_dtype=default_dtype) @@ -71,6 +73,7 @@ def __init__(self, self.param_list = [] self.model_numel_tensor = model_numel_tensor self.seed = seed + self.bf16 = bf16 self.dp_process_group = gpc.get_group(ParallelMode.DATA) self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param) @@ -183,9 +186,10 @@ def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): NOTE() The module may be passed to this function multiple times. """ self.top_module = module + half_dtype = torch.float16 if not self.bf16 else torch.bfloat16 def half_fn(t: torch.Tensor): - return t.half() if t.is_floating_point() else t + return t.to(half_dtype) if t.is_floating_point() else t for param in module.parameters(recurse=False): # avoid adapting a param to ShardedParam twice @@ -226,9 +230,10 @@ def half_fn(t: torch.Tensor): # We must cast buffers # If we use BN, buffers may be on CPU and Float # We must cast them + cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16 for buffer in module.buffers(recurse=False): buffer.data = buffer.data.to(device=torch.cuda.current_device()) - buffer.data = cast_tensor_to_fp16(buffer.data) + buffer.data = cast_fn(buffer.data) class ZeroContextMgr(metaclass=SingletonMeta): diff --git a/colossalai/zero/legacy/sharded_model/_utils.py b/colossalai/zero/legacy/sharded_model/_utils.py index 2bd01531a78f..f1d642cf3f13 100644 --- a/colossalai/zero/legacy/sharded_model/_utils.py +++ b/colossalai/zero/legacy/sharded_model/_utils.py @@ -43,11 +43,19 @@ def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Te if isinstance(tensor, StatefulTensor): tensor = tensor.payload - if torch.is_floating_point(tensor) and tensor.dtype is torch.float16: + if torch.is_floating_point(tensor) and tensor.dtype in (torch.float16, torch.bfloat16): return tensor.float() return tensor +def cast_tensor_to_bf16(tensor: torch.Tensor) -> torch.Tensor: + if isinstance(tensor, StatefulTensor): + tensor = tensor.payload + if torch.is_floating_point(tensor) and tensor.dtype is torch.float32: + return tensor.bfloat16() + return tensor + + def apply_to_tensors(x: Any, fn: Callable): if torch.is_tensor(x): return fn(x) diff --git a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py index b3a83b741825..be3842beb208 100644 --- a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py @@ -28,6 +28,7 @@ from ._utils import ( cast_float_arguments, + cast_tensor_to_bf16, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, @@ -74,6 +75,7 @@ class ShardedModelV2(nn.Module): In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad). We find that PyTorch's optimizers don't support mixed precision, so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False. + bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False. """ def __init__(self, @@ -86,11 +88,13 @@ def __init__(self, tensor_placement_policy: str = 'cuda', gradient_predivide_factor: Optional[float] = 1.0, reuse_fp16_shard: bool = False, + bf16: bool = False, *args, **kwargs): assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.' super().__init__() self.logger = get_dist_logger() + self.bf16 = bf16 # We force users to use ZeroInitContext for submodule in module.modules(): @@ -232,7 +236,8 @@ def _post_forward_operations(self): def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: self._pre_forward_operations(*args) - args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs) + cast_fn = cast_tensor_to_bf16 if self.bf16 else cast_tensor_to_fp16 + args, kwargs = cast_float_arguments(cast_fn, *args, **kwargs) outputs = self.module(*args, **kwargs) self._post_forward_operations() return outputs diff --git a/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py b/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py index be60209af434..41dd174cb65a 100644 --- a/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py @@ -94,6 +94,7 @@ def __init__(self, super().__init__(optimizer) self.shard_strategy = sharded_model.shard_strategy self.model: ShardedModelV2 = sharded_model + self.bf16 = sharded_model.bf16 self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0' @@ -117,6 +118,7 @@ def __init__(self, self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device()) self._logger = get_dist_logger("ShardedOptimizerV2") self._verbose = verbose + self._grad_prepared: bool = False # this should be set to true when _prepare_grads() and reset to false when backward # Store fp32 param shards self._register_master_weight() @@ -166,8 +168,10 @@ def zero_grad(self, *args, **kwargs): self._zero_grad() def backward(self, loss: Tensor) -> None: - loss = self.loss_scale * loss - self.optim_state = OptimState.SCALED + if not self.bf16: + loss = self.loss_scale * loss + self.optim_state = OptimState.SCALED + self._grad_prepared = False self.model.backward(loss) def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None: @@ -175,30 +179,33 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None: # It receives the scaled grad from the previous rank # No need to scale the grad again # Need to unscale when optimizing - self.optim_state = OptimState.SCALED + if not self.bf16: + self.optim_state = OptimState.SCALED + self._grad_prepared = False self.model.backward_by_grad(tensor, grad) def clip_grad_norm(self, model: nn.Module, max_norm: float): - if self.optim_state == OptimState.SCALED: - self._prepare_grads() + self._prepare_grads() + if not self.bf16 and self.optim_state == OptimState.SCALED: self._unscale_grads() return super().clip_grad_norm(model, max_norm) def step(self, *args, **kwargs): + self._prepare_grads() # unscale grads if scaled - if self.optim_state == OptimState.SCALED: - self._prepare_grads() + if not self.bf16 and self.optim_state == OptimState.SCALED: self._unscale_grads() self._maybe_move_fp32_shards() - found_inf = self._check_overflow() - self.grad_scaler.update(found_inf) + if not self.bf16: + found_inf = self._check_overflow() + self.grad_scaler.update(found_inf) - if found_inf: - self._logger.warning('found inf during ShardedOptimV2 step') - self._zero_grad(recover_data=True) - return + if found_inf: + self._logger.warning('found inf during ShardedOptimV2 step') + self._zero_grad(recover_data=True) + return self._point_param_fp16_to_master_param() @@ -304,6 +311,8 @@ def _maybe_move_fp32_shards(self): state[k] = v.cuda() def _prepare_grads(self): + if self._grad_prepared: + return for group in self.optim.param_groups: for p in group['params']: if p.colo_attr.saved_grad.is_null(): @@ -320,6 +329,7 @@ def _prepare_grads(self): p.grad = p.colo_attr.grad_payload # Set p.data to empty tensor, in case of memory leaking p.colo_attr.set_data_none() + self._grad_prepared = True def _point_param_fp16_to_master_param(self): # assign master param pointers to p.data. @@ -357,7 +367,8 @@ def _copy_master_param_to_param_fp16(self, p): torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device)) # TODO() optimize this line CPU (fp32) -> GPU (fp16) - p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach()) + half_dtype = torch.bfloat16 if self.bf16 else torch.float16 + p.colo_attr.sharded_data_tensor.payload_copy(p.to(half_dtype).detach()) p.colo_attr.set_data_none() if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated: diff --git a/tests/test_zero/test_legacy/test_zero_engine.py b/tests/test_zero/test_legacy/test_zero_engine.py index dc8847ce56ab..826a543db861 100644 --- a/tests/test_zero/test_legacy/test_zero_engine.py +++ b/tests/test_zero/test_legacy/test_zero_engine.py @@ -16,7 +16,11 @@ from tests.components_to_test.registry import non_distributed_component_funcs -def run_dist(rank, world_size, port, parallel_config): +def run_dist(rank, world_size, port, parallel_config, bf16): + is_mp_config = parallel_config == MP_PARALLEL_CONFIG + is_zero_config = parallel_config == ZERO_PARALLEL_CONFIG + if bf16: + parallel_config['zero']['model_config']['bf16'] = True colossalai.launch(config=parallel_config, rank=rank, world_size=world_size, @@ -30,7 +34,8 @@ def run_dist(rank, world_size, port, parallel_config): model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=gpc.config.zero.model_config.shard_strategy, - shard_param=True): + shard_param=True, + bf16=bf16): colo_model = model_builder(checkpoint=True) colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3) @@ -38,7 +43,8 @@ def run_dist(rank, world_size, port, parallel_config): optimizer=colo_optimizer, criterion=criterion, train_dataloader=train_dataloader) - torch_model = model_builder(checkpoint=True).half() + dtype = torch.bfloat16 if bf16 else torch.float16 + torch_model = model_builder(checkpoint=True).to(dtype) col_model_deepcopy(engine.model, torch_model) torch_model = torch_model.cuda().float() @@ -80,9 +86,9 @@ def run_dist(rank, world_size, port, parallel_config): torch_optimizer.step() i += 1 - if parallel_config == MP_PARALLEL_CONFIG: + if is_mp_config: check_params(torch_model, colo_model, loose=True) - elif parallel_config == ZERO_PARALLEL_CONFIG: + elif is_zero_config: check_sharded_model_params(torch_model, colo_model, loose=True) @@ -97,9 +103,10 @@ def test_mp_engine(world_size): @pytest.mark.dist @pytest.mark.parametrize("world_size", [1, 2]) +@pytest.mark.parametrize("bf16", [True, False]) @rerun_if_address_is_in_use() -def test_zero_engine(world_size): - spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG) +def test_zero_engine(world_size, bf16): + spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG, bf16=bf16) if __name__ == '__main__':