diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index bb561a106515..7070c0a1e59d 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -93,8 +93,7 @@ def torch_adam_update(self,
                           bias_correction1,
                           bias_correction2,
                           use_adamw=False):
-        # FIXME(ver217): remove the below line when replace torch adam with fused adam
-        grad = grad.float()
+        grad = grad.to(data.dtype)

         if weight_decay != 0:
             if use_adamw:
@@ -133,10 +132,12 @@ def step(self, closure=None, div_scale: float = -1):

                 if len(state) == 0:
                     state['step'] = 0
+                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
+                    assert p.dtype is torch.float, "CPUAdam only supports fp32 parameters"
                     # gradient momentums
-                    state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
+                    state['exp_avg'] = torch.zeros_like(p, device=target_device)
                     # gradient variances
-                    state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
+                    state['exp_avg_sq'] = torch.zeros_like(p, device=target_device)
                     self._post_state_init(p)

                 state['step'] += 1
@@ -147,9 +148,17 @@ def step(self, closure=None, div_scale: float = -1):
                     assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
                     assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg_sq should stay on cpu"
                     self._pre_update(p, 'exp_avg', 'exp_avg_sq')
-                    self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
-                                          group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
-                                          state['exp_avg_sq'], div_scale)
+                    if p.grad.dtype is torch.bfloat16:
+                        # the CPU adam kernel does not support bf16 yet, so fall back to torch_adam_update
+                        bias_correction1 = 1 - beta1**state['step']
+                        bias_correction2 = 1 - beta2**state['step']
+                        self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
+                                               beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
+                                               bias_correction2, self.adamw_mode)
+                    else:
+                        self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
+                                              group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
+                                              state['exp_avg'], state['exp_avg_sq'], div_scale)
                     self._post_update(p, 'exp_avg', 'exp_avg_sq')
                 elif target_device.type == 'cuda':
                     assert div_scale == -1, "div_scale should remain default"
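For readers skimming the cpu_adam.py hunk above: with bf16 gradients the optimizer now falls back from the C++ kernel to torch_adam_update, i.e. a plain bias-corrected Adam/AdamW step in PyTorch. A minimal standalone sketch of that fallback path (function and argument names here are illustrative, not part of the patch):

import torch

# Standalone sketch of the bf16 fallback: one pure-PyTorch Adam/AdamW step
# with the same bias correction that torch_adam_update applies.
def adam_step(param, grad, exp_avg, exp_avg_sq, step, lr=1e-3, beta1=0.9,
              beta2=0.999, eps=1e-8, weight_decay=0.0, adamw=True):
    grad = grad.to(param.dtype)    # mirrors `grad = grad.to(data.dtype)` above
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    if weight_decay != 0:
        if adamw:
            param.mul_(1 - lr * weight_decay)    # decoupled (AdamW-style) decay
        else:
            grad = grad.add(param, alpha=weight_decay)    # L2-style decay
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = (exp_avg_sq.sqrt() / bias_correction2**0.5).add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)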
diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py
index be6311c6c29f..526071b06f95 100644
--- a/colossalai/nn/optimizer/hybrid_adam.py
+++ b/colossalai/nn/optimizer/hybrid_adam.py
@@ -1,16 +1,17 @@
 from typing import Any, Optional

 import torch
+from torch.optim import Adam

-from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder
+from colossalai.kernel.op_builder import FusedOptimBuilder
 from colossalai.registry import OPTIMIZERS
 from colossalai.utils import multi_tensor_applier

-from .nvme_optimizer import NVMeOptimizer
+from .cpu_adam import CPUAdam


 @OPTIMIZERS.register_module
-class HybridAdam(NVMeOptimizer):
+class HybridAdam(CPUAdam):
     """Implements Adam algorithm.

     Supports parameters updating on both GPU and CPU, depending on the device of parameters.
@@ -74,15 +75,9 @@ def __init__(self,
                  nvme_offload_dir: Optional[str] = None,
                  **defaults: Any):
-        default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
-        super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
-        self.adamw_mode = adamw_mode
-
-        # build during runtime if not found
-        cpu_optim = CPUAdamBuilder().load()
+        super().__init__(model_params, lr, bias_correction, betas, eps, weight_decay, adamw_mode, nvme_offload_fraction,
+                         nvme_offload_dir)
         fused_optim = FusedOptimBuilder().load()
-        self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
-
         self.gpu_adam_op = fused_optim.multi_tensor_adam
         self._dummy_overflow_buf = torch.cuda.IntTensor([0])

@@ -108,10 +103,12 @@ def step(self, closure=None, div_scale: float = -1):

                 if len(state) == 0:
                     state['step'] = 0
+                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
+                    assert p.dtype is torch.float, "HybridAdam only supports fp32 parameters"
                     # gradient momentums
-                    state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
+                    state['exp_avg'] = torch.zeros_like(p, device=target_device)
                     # gradient variances
-                    state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
+                    state['exp_avg_sq'] = torch.zeros_like(p, device=target_device)
                     self._post_state_init(p)

                 state['step'] += 1
@@ -122,9 +119,17 @@ def step(self, closure=None, div_scale: float = -1):
                     assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
                     assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg_sq should stay on cpu"
                     self._pre_update(p, 'exp_avg', 'exp_avg_sq')
-                    self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
-                                          group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
-                                          state['exp_avg_sq'], div_scale)
+                    if p.grad.dtype is torch.bfloat16:
+                        # the CPU adam kernel does not support bf16 yet, so fall back to torch_adam_update
+                        bias_correction1 = 1 - beta1**state['step']
+                        bias_correction2 = 1 - beta2**state['step']
+                        self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
+                                               beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
+                                               bias_correction2, self.adamw_mode)
+                    else:
+                        self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
+                                              group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
+                                              state['exp_avg'], state['exp_avg_sq'], div_scale)
                     self._post_update(p, 'exp_avg', 'exp_avg_sq')

                 elif target_device.type == 'cuda':
diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py
new file mode 100644
index 000000000000..2186a421fe00
--- /dev/null
+++ b/tests/test_optimizer/test_adam_kernel.py
@@ -0,0 +1,131 @@
+# This test checks adam kernels
+# Baseline is pure fp32 torch adam optimizer
+import math
+from abc import abstractmethod
+from typing import Type
+
+import pytest
+import torch
+from torch import Tensor
+
+from colossalai.utils import get_current_device, multi_tensor_applier
+
+_FUSED_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float),
+                            (torch.half, torch.half), (torch.bfloat16, torch.float), (torch.float, torch.bfloat16),
+                            (torch.bfloat16, torch.bfloat16)]
+
+_CPU_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float),
+                          (torch.half, torch.half)]
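Every kernel class in this test file is validated against the same reference: a bias-corrected Adam step kept entirely in fp32 (TorchAdamKernel below). In math form, with gradient $g_t$ and learning rate $\mathrm{lr}$:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2,$$
$$\theta_t = \theta_{t-1} - \frac{\mathrm{lr}}{1-\beta_1^t}\cdot\frac{m_t}{\sqrt{v_t/(1-\beta_2^t)} + \epsilon}.$$

In AdamW mode the weight decay $\lambda$ is decoupled: $\theta$ is first scaled by $(1 - \mathrm{lr}\cdot\lambda)$ instead of $\lambda\theta$ being added to the gradient.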
+
+
+class AdamKernel:
+
+    def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
+        self.lr = lr
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eps = eps
+        self.weight_decay = weight_decay
+        self.use_adamw = use_adamw
+
+    @abstractmethod
+    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
+        pass
+
+
+class TorchAdamKernel(AdamKernel):
+
+    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
+        bias_correction1 = 1 - self.beta1**step
+        bias_correction2 = 1 - self.beta2**step
+
+        if self.weight_decay != 0:
+            if self.use_adamw:
+                # Perform stepweight decay
+                param.mul_(1 - self.lr * self.weight_decay)
+            else:
+                grad = grad.add(param, alpha=self.weight_decay)
+
+        # Decay the first and second moment running average coefficient
+        exp_avg.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
+        exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2)
+        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
+
+        step_size = self.lr / bias_correction1
+
+        param.addcdiv_(exp_avg, denom, value=-step_size)
+
+
+class FusedAdamKernel(AdamKernel):
+
+    def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
+        super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
+        from colossalai.kernel.op_builder import FusedOptimBuilder
+        fused_optim = FusedOptimBuilder().load()
+        self.fused_adam = fused_optim.multi_tensor_adam
+        self.dummy_overflow_buf = torch.cuda.IntTensor([0])
+
+    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
+        multi_tensor_applier(self.fused_adam, self.dummy_overflow_buf, [[grad], [param], [exp_avg], [exp_avg_sq]],
+                             self.lr, self.beta1, self.beta2, self.eps, step, self.use_adamw, True, self.weight_decay,
+                             -1)
+
+
+class CPUAdamKernel(AdamKernel):
+
+    def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
+        super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
+        from colossalai.kernel.op_builder import CPUAdamBuilder
+        cpu_optim = CPUAdamBuilder().load()
+
+        self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw)
+
+    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
+        self.cpu_adam_op.step(step, self.lr, self.beta1, self.beta2, self.eps, self.weight_decay, True, param.view(-1),
+                              grad.view(-1), exp_avg.view(-1), exp_avg_sq.view(-1), -1)
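As a usage sketch (hypothetical, not part of the test file): driving one of the wrappers above by hand looks like this, assuming the C++ CPU Adam extension builds on the local machine when CPUAdamBuilder().load() runs:

import torch

# Hand-driven single step through the CPUAdamKernel wrapper defined above.
kernel = CPUAdamKernel(lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                       weight_decay=0.0, use_adamw=True)
param = torch.rand(64)
grad = torch.rand_like(param)
exp_avg = torch.zeros_like(param)       # first moment, updated in place
exp_avg_sq = torch.zeros_like(param)    # second moment, updated in place
kernel.update(1, param, grad, exp_avg, exp_avg_sq)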
+
+
+def check_adam_kernel(kernel: Type[AdamKernel], adamw: bool, weight_decay: float, p_dtype: torch.dtype,
+                      g_dtype: torch.dtype, device: torch.device, n_steps: int, rtol: float, atol: float):
+    lr = 1e-3
+    beta1, beta2 = 0.9, 0.999
+    eps = 1e-8
+    torch_adam = TorchAdamKernel(lr, beta1, beta2, eps, weight_decay, adamw)
+    adam_kernel = kernel(lr, beta1, beta2, eps, weight_decay, adamw)
+    master_p = torch.rand(64, device=device)
+    master_g = torch.rand_like(master_p)
+    master_exp_avg = torch.zeros_like(master_p)
+    master_exp_avg_sq = torch.zeros_like(master_p)
+    p = master_p.clone().to(p_dtype)
+    g = master_g.clone().to(g_dtype)
+    exp_avg = master_exp_avg.clone()
+    exp_avg_sq = master_exp_avg_sq.clone()
+
+    for step in range(1, 1 + n_steps):
+        torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
+        adam_kernel.update(step, p, g, exp_avg, exp_avg_sq)
+        # if overflow occurs, the weight won't be updated, so there should be no NaN in p
+        assert not torch.isnan(p).any()
+        assert torch.allclose(master_p, p.float(), rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize('adamw', [False, True])
+@pytest.mark.parametrize('weight_decay', [0.0, 0.1])
+@pytest.mark.parametrize('p_dtype, g_dtype', _FUSED_ALLOWED_P_G_TYPES)
+def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
+    rtol, atol = 1e-5, 1e-8
+    if p_dtype is torch.float16 or g_dtype is torch.float16:
+        rtol, atol = 1e-3, 1e-3
+    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
+        rtol, atol = 4e-3, 4e-3
+    check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol)
+
+
+@pytest.mark.parametrize('adamw', [False, True])
+@pytest.mark.parametrize('weight_decay', [0.0, 0.1])
+@pytest.mark.parametrize('p_dtype, g_dtype', _CPU_ALLOWED_P_G_TYPES)
+def test_cpu_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
+    rtol, atol = 1e-5, 1e-8
+    if p_dtype is torch.float16 or g_dtype is torch.float16:
+        rtol, atol = 1e-3, 1e-3
+    check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device('cpu'), 3, rtol, atol)
diff --git a/tests/test_optimizer/test_adam_optim.py b/tests/test_optimizer/test_adam_optim.py
new file mode 100644
index 000000000000..0f72bc134809
--- /dev/null
+++ b/tests/test_optimizer/test_adam_optim.py
@@ -0,0 +1,86 @@
+from copy import deepcopy
+from typing import Type, Union
+
+import pytest
+import torch
+import torch.nn as nn
+from torch.optim import Adam, AdamW
+
+from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
+from tests.kit.model_zoo import model_zoo
+
+_ALLOWED_OPTIM_DEVICES = [
+    (FusedAdam, torch.device('cuda:0')),
+    (CPUAdam, torch.device('cpu')),
+    (CPUAdam, torch.device('cuda:0')),
+    (HybridAdam, torch.device('cpu')),
+    (HybridAdam, torch.device('cuda:0')),
+]
+
+_ALLOWED_P_G_TYPES = [
+    (torch.float, torch.float),    # pure fp32
+    (torch.float, torch.half),    # fp16 amp
+    (torch.float, torch.bfloat16),    # bfloat16 amp
+    # (torch.half, torch.half),    # FIXME(ver217): cpu adam kernel does not support pure fp16
+    # (torch.bfloat16, torch.bfloat16),    # FIXME(ver217): cpu adam kernel does not support pure bfloat16
+]
+
+N_STEPS = 3
+
+
+def setup_param_groups(bert_model: nn.Module) -> list:
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": 0.1,
+        },
+        {
+            "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+    return optimizer_grouped_parameters
+
+
+def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
+    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
+        torch_p.grad = torch.rand_like(torch_p)
+        # avoid inconsistent grad and param dtype error
+        orig_p = p.data
+        p.data = torch_p.grad.clone().to(g_dtype)
+        p.grad = p.data
+        p.data = orig_p
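The data-swap in set_grad above works around PyTorch's check that a tensor assigned to `.grad` must match the parameter's dtype. A minimal standalone illustration (the exact error wording may vary across PyTorch versions):

import torch

p = torch.nn.Parameter(torch.rand(4))            # fp32 parameter
try:
    p.grad = torch.rand(4, dtype=torch.half)     # fp16 grad on an fp32 param
except RuntimeError as err:
    print(err)    # e.g. "assigned grad has data of a different type"
# set_grad sidesteps this by temporarily making p.data itself the g_dtype
# tensor, assigning p.grad while the dtypes agree, then restoring p.data.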
+
+
+@pytest.mark.parametrize('optim_cls, device', _ALLOWED_OPTIM_DEVICES)
+@pytest.mark.parametrize('adamw', [False, True])
+@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES)
+def test_adam_optim_on_bert(optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]], device: torch.device,
+                            adamw: bool, p_dtype: torch.dtype, g_dtype: torch.dtype) -> None:
+    model_fn, *_ = next(iter(model_zoo.get_sub_registry('transformers_bert_for_sequence_classification').values()))
+    torch_model = model_fn().to(device)
+    model = deepcopy(torch_model).to(p_dtype)
+    lr = 1e-3
+    beta1, beta2 = 0.9, 0.999
+    eps = 1e-8
+    torch_optim_cls = AdamW if adamw else Adam
+    torch_optim = torch_optim_cls(setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps)
+    optim = optim_cls(setup_param_groups(model), lr=lr, betas=(beta1, beta2), eps=eps, adamw_mode=adamw)
+
+    rtol, atol = 1e-5, 1e-5
+    if p_dtype is torch.float16 or g_dtype is torch.float16:
+        rtol, atol = 2e-3, 2e-3
+    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
+        rtol, atol = 4e-3, 4e-3
+
+    for _ in range(N_STEPS):
+        set_grad(model, torch_model, g_dtype)
+        torch_optim.step()
+        optim.step()
+        torch_optim.zero_grad()
+        optim.zero_grad()
+        for p, torch_p in zip(model.parameters(), torch_model.parameters()):
+            # if overflow occurs, the weight won't be updated, so there should be no NaN in p
+            assert not torch.isnan(p).any()
+            assert torch.allclose(p.float(), torch_p, rtol=rtol, atol=atol)
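One plausible reading of the tolerance choices above and in test_adam_kernel.py: they track the unit roundoff (half the machine epsilon) of each reduced-precision format, i.e. the largest relative error a single rounding can introduce, with some slack for a few accumulated steps. A quick standalone check:

import torch

# Unit roundoff (eps / 2) of the half-precision formats used in these tests;
# the rtol/atol pairs above sit at or just above these values.
print(torch.finfo(torch.float16).eps / 2)     # 2**-11 ~= 4.9e-4 -> 1e-3 / 2e-3 tolerances
print(torch.finfo(torch.bfloat16).eps / 2)    # 2**-8  ~= 3.9e-3 -> 4e-3 tolerances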
diff --git a/tests/test_optimizer/test_cpu_adam.py b/tests/test_optimizer/test_cpu_adam.py
deleted file mode 100644
index 8b3ecf8517f7..000000000000
--- a/tests/test_optimizer/test_cpu_adam.py
+++ /dev/null
@@ -1,121 +0,0 @@
-import math
-
-import torch
-
-from colossalai.testing import clear_cache_before_run, parameterize
-
-
-def torch_adam_update(
-    step,
-    lr,
-    beta1,
-    beta2,
-    eps,
-    weight_decay,
-    param,
-    grad,
-    exp_avg,
-    exp_avg_sq,
-    use_adamw,
-):
-    bias_correction1 = 1 - beta1**step
-    bias_correction2 = 1 - beta2**step
-
-    if weight_decay != 0:
-        if use_adamw:
-            # Perform stepweight decay
-            param.mul_(1 - lr * weight_decay)
-        else:
-            grad = grad.add(param, alpha=weight_decay)
-
-    # Decay the first and second moment running average coefficient
-    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
-    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
-
-    step_size = lr / bias_correction1
-
-    param.addcdiv_(exp_avg, denom, value=-step_size)
-
-
-def assertLess(data_diff, threshold, msg):
-    assert data_diff < threshold, msg
-
-
-def assertTrue(condition, msg):
-    assert condition, msg
-
-
-@clear_cache_before_run()
-@parameterize('adamw', [True, False])
-@parameterize('step', [1, 2])
-@parameterize('p_dtype', [torch.float, torch.half])
-@parameterize('g_dtype', [torch.float, torch.half])
-def test_cpu_adam(adamw, step, p_dtype, g_dtype):
-    lr = 1e-3
-    beta1, beta2 = 0.9, 0.999
-    eps = 1e-8
-    weight_decay = 0
-
-    for i in range(3):
-        p_data = torch.rand(64, dtype=p_dtype)
-        p_data_copy = p_data.clone().float()
-        p_grad = torch.rand(64, dtype=g_dtype)
-        p_grad_copy = p_grad.clone().float()
-        exp_avg = torch.rand(p_data.shape)
-        exp_avg_copy = exp_avg.clone()
-        exp_avg_sq = torch.rand(p_data.shape)
-        exp_avg_sq_copy = exp_avg_sq.clone()
-
-        from colossalai.kernel.op_builder import CPUAdamBuilder
-        cpu_optim = CPUAdamBuilder().load()
-
-        cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
-
-        cpu_adam_op.step(
-            step,
-            lr,
-            beta1,
-            beta2,
-            eps,
-            weight_decay,
-            True,
-            p_data.view(-1),    # fp32 data
-            p_grad.view(-1),    # fp32 grad
-            exp_avg.view(-1),
-            exp_avg_sq.view(-1),
-            -1,
-        )
-
-        torch_adam_update(
-            step,
-            lr,
-            beta1,
-            beta2,
-            eps,
-            weight_decay,
-            p_data_copy,    # fp32 data
-            p_grad_copy,    # fp32 grad
-            exp_avg_copy,
-            exp_avg_sq_copy,
-            adamw,
-        )
-        var = p_data_copy - p_data
-        data_diff = torch.max(torch.abs(var))
-        threshold = 1e-3
-        assertLess(
-            data_diff,
-            threshold,
-            f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, eps "
-            f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}",
-        )
-        max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
-        assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
-        max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
-        assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
-        max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
-        assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
-
-
-if __name__ == '__main__':
-    test_cpu_adam()
diff --git a/tests/test_optimizer/test_fused_adam.py b/tests/test_optimizer/test_fused_adam.py
deleted file mode 100644
index 511987ca3dbf..000000000000
--- a/tests/test_optimizer/test_fused_adam.py
+++ /dev/null
@@ -1,73 +0,0 @@
-from copy import deepcopy
-
-import pytest
-import torch
-import torch.nn as nn
-from torch.optim import AdamW
-from torch.optim.adam import Adam
-
-from colossalai.nn.optimizer.fused_adam import FusedAdam
-
-
-class FC(nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.fc = nn.Sequential(nn.Linear(64, 64))
-
-    def forward(self, x):
-        return self.fc(x)
-
-
-@pytest.mark.parametrize('adamw', [False, True])
-@pytest.mark.parametrize('running_p_dtype', [torch.float, torch.half, torch.bfloat16])
-@pytest.mark.parametrize('fp32_master_weights', [False, True])
-def test_adam(adamw, running_p_dtype, fp32_master_weights):
-    # baseline is fure fp32 torch adam
-    # g_type is the same as running_p_dtype
-    if running_p_dtype is torch.float and not fp32_master_weights:
-        # pure fp32 must have fp32 weights
-        return
-    if not fp32_master_weights or running_p_dtype is torch.bfloat16:
-        # pure low precision or bf16, high tolerance
-        atol = 4e-3
-        rtol = 4e-3
-    else:
-        # fp32 master weights, low tolerance
-        atol = 2e-3
-        rtol = 2e-3
-    torch_model = FC().cuda()
-    model = deepcopy(torch_model).to(running_p_dtype)
-
-    torch_optim_cls = AdamW if adamw else Adam
-    torch_optim = torch_optim_cls(torch_model.parameters(), lr=1e-3)
-    optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=adamw)
-
-    data = torch.rand(10, 64).cuda()
-    label = torch.rand(10, 64).cuda()
-
-    for d, l in zip(data.to(running_p_dtype), label.to(running_p_dtype)):
-        y = model(d)
-        loss = ((l - y)**2).sum()
-        optim.zero_grad()
-        loss.backward()
-        if fp32_master_weights:
-            for p in model.parameters():
-                p.data = p.data.float()
-        optim.step()
-        if fp32_master_weights:
-            for p in model.parameters():
-                p.data = p.data.to(running_p_dtype)
-
-    for d, l in zip(data, label):
-        y = torch_model(d)
-        loss = ((l - y)**2).sum()
-        torch_optim.zero_grad()
-        loss.backward()
-        torch_optim.step()
-
-    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
-        if torch.isnan(p).any() or torch.isnan(torch_p).any():
-            continue
-        fp32_p = p.float()
-        assert torch.allclose(fp32_p, torch_p, atol=atol, rtol=rtol)
diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py
deleted file mode 100644
index deabaebfec75..000000000000
--- a/tests/test_optimizer/test_fused_adam_kernel.py
+++ /dev/null
@@ -1,95 +0,0 @@
-import math
-
-import pytest
-import torch
-
-from colossalai.utils import multi_tensor_applier
-
-_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float),
-                      (torch.half, torch.half), (torch.bfloat16, torch.float), (torch.float, torch.bfloat16),
-                      (torch.bfloat16, torch.bfloat16)]
-
-
-def torch_adam_update(
-    step,
-    lr,
-    beta1,
-    beta2,
-    eps,
-    weight_decay,
-    param,
-    grad,
-    exp_avg,
-    exp_avg_sq,
-    use_adamw,
-):
-    bias_correction1 = 1 - beta1**step
-    bias_correction2 = 1 - beta2**step
-
-    if weight_decay != 0:
-        if use_adamw:
-            # Perform stepweight decay
-            param.mul_(1 - lr * weight_decay)
-        else:
-            grad = grad.add(param, alpha=weight_decay)
-
-    # Decay the first and second moment running average coefficient
-    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
-    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
-
-    step_size = lr / bias_correction1
-
-    param.addcdiv_(exp_avg, denom, value=-step_size)
-
-
-@pytest.mark.parametrize('adamw', [False, True])
-@pytest.mark.parametrize('step', [1, 2])
-@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES)
-def test_adam(adamw, step, p_dtype, g_dtype):
-    from colossalai.kernel.op_builder import FusedOptimBuilder
-    fused_optim = FusedOptimBuilder().load()
-    fused_adam = fused_optim.multi_tensor_adam
-
-    dummy_overflow_buf = torch.cuda.IntTensor([0])
-
-    count = 0
-
-    for i in range(3):
-        p = torch.rand(64, dtype=p_dtype).cuda()
-        p_copy = p.clone().float()
-        g = torch.rand(p.shape, dtype=g_dtype).cuda()
-        g_copy = g.clone().float()
-        m = torch.rand(p.shape).cuda()
-        m_copy = m.clone()
-        v = torch.rand(p.shape).cuda()
-        v_copy = v.clone()
-
-        lr = 1e-3
-        beta1, beta2 = 0.9, 0.999
-        eps = 1e-8
-        weight_decay = 0
-
-        multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw,
-                             True, weight_decay, -1)
-
-        torch_adam_update(
-            step,
-            lr,
-            beta1,
-            beta2,
-            eps,
-            weight_decay,
-            p_copy,    # fp32 data
-            g_copy,    # fp32 grad
-            m_copy,
-            v_copy,
-            adamw,
-        )
-
-        if torch.isnan(p).any() or torch.isnan(p_copy).any():
-            count += 1
-            continue
-        assert count < 200, "too many nans"
-        assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5,
-                              1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"
diff --git a/tests/test_optimizer/test_hybrid_adam.py b/tests/test_optimizer/test_hybrid_adam.py
deleted file mode 100644
index d075149dfcb1..000000000000
--- a/tests/test_optimizer/test_hybrid_adam.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import torch
-import torch.nn as nn
-from torch.optim import AdamW
-from torch.optim.adam import Adam
-
-from colossalai.nn.optimizer.hybrid_adam import HybridAdam
-from colossalai.testing import clear_cache_before_run, parameterize
-
-RE = 3
-
-
-@clear_cache_before_run()
-@parameterize('adamw', [False, True])
-@parameterize('device', ['cpu', 'cuda:0'])
-@parameterize('p_dtype', [torch.float])
-@parameterize('g_dtype', [torch.float, torch.half])
-def test_adam(adamw, device, p_dtype, g_dtype):
-    rng_state = torch.get_rng_state()
-    p = nn.Parameter(torch.rand(64).to(device, p_dtype))
-    torch.set_rng_state(rng_state)
-    p_copy = nn.Parameter(torch.rand(64).to(device).float())
-
-    if adamw:
-        optim = HybridAdam([p], lr=1e-3, adamw_mode=True)
-        torch_optim = AdamW([p_copy], lr=1e-3)
-    else:
-        optim = HybridAdam([p], lr=1e-3)
-        torch_optim = Adam([p_copy], lr=1e-3)
-
-    print(f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}")
-    for i in range(RE):
-        p.grad = torch.rand(64).to(device, p_dtype)
-        p_copy.grad = p.grad.clone().float()
-        p.grad.data = p.grad.data.to(g_dtype)
-
-        optim.step()
-        torch_optim.step()
-
-        if torch.isnan(p.data).any() or torch.isnan(p_copy.data).any():
-            continue
-        assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2), \
-            f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}"
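To close, an end-to-end sketch of the combination the new tests exercise: fp32 master parameters updated from bf16 gradients via the torch_adam_update fallback on CPU. This is a hypothetical usage example, not part of the patch, and assumes a CUDA-enabled ColossalAI build (HybridAdam also loads the fused GPU op at construction):

import torch
import torch.nn as nn
from colossalai.nn.optimizer import HybridAdam

model = nn.Linear(64, 64)    # fp32 parameters on CPU
optim = HybridAdam(model.parameters(), lr=1e-3, adamw_mode=True)
for p in model.parameters():
    # emulate bf16 amp with the same dtype-swap trick as set_grad above
    orig = p.data
    p.data = torch.rand(p.shape, dtype=torch.bfloat16)
    p.grad = p.data
    p.data = orig
optim.step()     # bf16 grads take the torch_adam_update fallback on CPU
optim.zero_grad()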