Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions colossalai/nn/optimizer/cpu_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ def torch_adam_update(self,
bias_correction1,
bias_correction2,
use_adamw=False):
# FIXME(ver217): remove the below line when replace torch adam with fused adam
grad = grad.float()
grad = grad.to(data.dtype)

if weight_decay != 0:
if use_adamw:
Expand Down Expand Up @@ -133,10 +132,12 @@ def step(self, closure=None, div_scale: float = -1):
if len(state) == 0:
state['step'] = 0

# FIXME(ver217): CPU adam kernel only supports fp32 states now
assert p.dtype is torch.float, "CPUAdam only support fp32 parameters"
# gradient momentums
state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
state['exp_avg'] = torch.zeros_like(p, device=target_device)
# gradient variances
state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
state['exp_avg_sq'] = torch.zeros_like(p, device=target_device)
self._post_state_init(p)

state['step'] += 1
Expand All @@ -147,9 +148,17 @@ def step(self, closure=None, div_scale: float = -1):
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
self._pre_update(p, 'exp_avg', 'exp_avg_sq')
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
state['exp_avg_sq'], div_scale)
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1**state['step']
bias_correction2 = 1 - beta2**state['step']
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
bias_correction2, self.adamw_mode)
else:
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'], div_scale)
self._post_update(p, 'exp_avg', 'exp_avg_sq')
elif target_device.type == 'cuda':
assert div_scale == -1, "div_scale should remain default"
Expand Down
37 changes: 21 additions & 16 deletions colossalai/nn/optimizer/hybrid_adam.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from typing import Any, Optional

import torch
from torch.optim import Adam

from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier

from .nvme_optimizer import NVMeOptimizer
from .cpu_adam import CPUAdam


@OPTIMIZERS.register_module
class HybridAdam(NVMeOptimizer):
class HybridAdam(CPUAdam):
"""Implements Adam algorithm.

Supports parameters updating on both GPU and CPU, depanding on the device of parameters.
Expand Down Expand Up @@ -74,15 +75,9 @@ def __init__(self,
nvme_offload_dir: Optional[str] = None,
**defaults: Any):

default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode

# build during runtime if not found
cpu_optim = CPUAdamBuilder().load()
super().__init__(model_params, lr, bias_correction, betas, eps, weight_decay, adamw_mode, nvme_offload_fraction,
nvme_offload_dir)
fused_optim = FusedOptimBuilder().load()
self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)

self.gpu_adam_op = fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0])

Expand All @@ -108,10 +103,12 @@ def step(self, closure=None, div_scale: float = -1):
if len(state) == 0:
state['step'] = 0

# FIXME(ver217): CPU adam kernel only supports fp32 states now
assert p.dtype is torch.float, "HybridAdam only support fp32 parameters"
# gradient momentums
state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
state['exp_avg'] = torch.zeros_like(p, device=target_device)
# gradient variances
state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
state['exp_avg_sq'] = torch.zeros_like(p, device=target_device)
self._post_state_init(p)

state['step'] += 1
Expand All @@ -122,9 +119,17 @@ def step(self, closure=None, div_scale: float = -1):
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
self._pre_update(p, 'exp_avg', 'exp_avg_sq')
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
state['exp_avg_sq'], div_scale)
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1**state['step']
bias_correction2 = 1 - beta2**state['step']
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
bias_correction2, self.adamw_mode)
else:
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'], div_scale)
self._post_update(p, 'exp_avg', 'exp_avg_sq')

elif target_device.type == 'cuda':
Expand Down
131 changes: 131 additions & 0 deletions tests/test_optimizer/test_adam_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# This test checks adam kernels
# Baseline is pure fp32 torch adam optimizer
import math
from abc import ABC, abstractmethod
from typing import Type

import pytest
import torch
from torch import Tensor

from colossalai.utils import get_current_device, multi_tensor_applier

_FUSED_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float),
(torch.half, torch.half), (torch.bfloat16, torch.float), (torch.float, torch.bfloat16),
(torch.bfloat16, torch.bfloat16)]

_CPU_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float),
(torch.half, torch.half)]


class AdamKernel(ABC):
    """Abstract base for the Adam/AdamW update kernels under test.

    Stores the shared hyperparameters; concrete subclasses implement
    :meth:`update` to apply one optimizer step in place.

    Note: the original class declared ``@abstractmethod`` without an ABC
    metaclass, which has no effect — the base class was silently
    instantiable. Inheriting :class:`abc.ABC` makes the contract enforced.
    """

    def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
        # Plain attributes so every kernel implementation reads the exact
        # same hyperparameter values.
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.weight_decay = weight_decay
        self.use_adamw = use_adamw

    @abstractmethod
    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        """Apply one Adam step (1-indexed ``step``) in place to ``param`` and the moment buffers."""


class TorchAdamKernel(AdamKernel):
    """Reference Adam/AdamW written with plain torch ops — the fp32 baseline."""

    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        # Bias corrections for the first and second moment estimates.
        correction1 = 1 - self.beta1**step
        correction2 = 1 - self.beta2**step

        if self.weight_decay != 0:
            if self.use_adamw:
                # AdamW: decoupled weight decay applied directly to the weights.
                param.mul_(1 - self.lr * self.weight_decay)
            else:
                # Classic Adam: fold the L2 penalty into the gradient.
                grad = grad.add(param, alpha=self.weight_decay)

        # Exponential moving averages of the gradient and its square.
        exp_avg.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
        exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2)

        denom = (exp_avg_sq.sqrt() / math.sqrt(correction2)).add_(self.eps)
        # param <- param - (lr / correction1) * exp_avg / denom
        param.addcdiv_(exp_avg, denom, value=-(self.lr / correction1))


class FusedAdamKernel(AdamKernel):
    """Adam/AdamW backed by the fused CUDA multi-tensor kernel."""

    def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
        super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
        # Loaded lazily here so importing this module does not require the
        # compiled extension to be present.
        from colossalai.kernel.op_builder import FusedOptimBuilder
        self.fused_adam = FusedOptimBuilder().load().multi_tensor_adam
        # Overflow-flag buffer required by the multi-tensor applier.
        self.dummy_overflow_buf = torch.cuda.IntTensor([0])

    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        tensor_lists = [[grad], [param], [exp_avg], [exp_avg_sq]]
        # Trailing args: step, adamw mode, bias correction enabled,
        # weight decay, div_scale disabled (-1).
        multi_tensor_applier(self.fused_adam, self.dummy_overflow_buf, tensor_lists, self.lr, self.beta1,
                             self.beta2, self.eps, step, self.use_adamw, True, self.weight_decay, -1)


class CPUAdamKernel(AdamKernel):
    """Adam/AdamW backed by the compiled CPU extension."""

    def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
        super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
        # Loaded lazily so importing this module does not require the extension.
        from colossalai.kernel.op_builder import CPUAdamBuilder
        # The C++ optimizer object keeps its own copy of the hyperparameters.
        self.cpu_adam_op = CPUAdamBuilder().load().CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw)

    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        # The kernel expects flat tensors; `True` enables bias correction,
        # -1 disables div_scale.
        self.cpu_adam_op.step(step, self.lr, self.beta1, self.beta2, self.eps, self.weight_decay, True,
                              param.view(-1), grad.view(-1), exp_avg.view(-1), exp_avg_sq.view(-1), -1)


def check_adam_kernel(kernel: Type[AdamKernel], adamw: bool, weight_decay: float, p_dtype: torch.dtype,
                      g_dtype: torch.dtype, device: torch.device, n_steps: int, rtol: float, atol: float):
    """Run ``kernel`` against the fp32 torch baseline and compare parameters each step."""
    lr, eps = 1e-3, 1e-8
    beta1, beta2 = 0.9, 0.999
    baseline = TorchAdamKernel(lr, beta1, beta2, eps, weight_decay, adamw)
    candidate = kernel(lr, beta1, beta2, eps, weight_decay, adamw)

    # fp32 master copies drive the baseline; the candidate kernel sees the
    # same values cast to the requested param/grad dtypes.
    master_p = torch.rand(64, device=device)
    master_g = torch.rand_like(master_p)
    master_exp_avg = torch.zeros_like(master_p)
    master_exp_avg_sq = torch.zeros_like(master_p)
    p = master_p.clone().to(p_dtype)
    g = master_g.clone().to(g_dtype)
    exp_avg = master_exp_avg.clone()
    exp_avg_sq = master_exp_avg_sq.clone()

    for step in range(1, 1 + n_steps):
        baseline.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
        candidate.update(step, p, g, exp_avg, exp_avg_sq)
        # On overflow the kernel skips the update, so p must never contain NaN.
        assert not torch.isnan(p).any()
        assert torch.allclose(master_p, p.float(), rtol=rtol, atol=atol)


@pytest.mark.parametrize('adamw', [False, True])
@pytest.mark.parametrize('weight_decay', [0.0, 0.1])
@pytest.mark.parametrize('p_dtype, g_dtype', _FUSED_ALLOWED_P_G_TYPES)
def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
    # Loosen tolerances for low-precision parameter/gradient dtypes.
    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
        rtol = atol = 4e-3
    elif p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol = atol = 1e-3
    else:
        rtol, atol = 1e-5, 1e-8
    check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol)


@pytest.mark.parametrize('adamw', [False, True])
@pytest.mark.parametrize('weight_decay', [0.0, 0.1])
@pytest.mark.parametrize('p_dtype, g_dtype', _CPU_ALLOWED_P_G_TYPES)
def test_cpu_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
    # fp16 needs looser tolerances than pure fp32.
    has_fp16 = p_dtype is torch.float16 or g_dtype is torch.float16
    rtol, atol = (1e-3, 1e-3) if has_fp16 else (1e-5, 1e-8)
    check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device('cpu'), 3, rtol, atol)
86 changes: 86 additions & 0 deletions tests/test_optimizer/test_adam_optim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from copy import deepcopy
from typing import Type, Union

import pytest
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW

from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from tests.kit.model_zoo import model_zoo

# Every (optimizer class, device) combination exercised by the test below.
_ALLOWED_OPTIM_DEVICES = [
    (FusedAdam, torch.device('cuda:0')),    # fused kernel is CUDA-only
    (CPUAdam, torch.device('cpu')),
    (CPUAdam, torch.device('cuda:0')),
    (HybridAdam, torch.device('cpu')),
    (HybridAdam, torch.device('cuda:0')),
]

_ALLOWED_P_G_TYPES = [
(torch.float, torch.float), # pure fp32
(torch.float, torch.half), # fp16 amp
(torch.float, torch.bfloat16), # bfloat16 amp
# (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16
# (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16
]

N_STEPS = 3


def setup_param_groups(bert_model: nn.Module) -> list:
    """Split parameters into decay / no-decay groups, mirroring BERT fine-tuning.

    Biases and LayerNorm weights get weight_decay=0.0; everything else 0.1.
    """
    no_decay = ["bias", "LayerNorm.weight"]
    decay_params, no_decay_params = [], []
    for name, param in bert_model.named_parameters():
        if any(pattern in name for pattern in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    return [
        {"params": decay_params, "weight_decay": 0.1},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]


def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
    """Give both models identical random gradients; ``model`` gets them in ``g_dtype``.

    PyTorch validates that ``.grad`` matches the parameter's dtype at
    assignment time, so the parameter data is temporarily swapped to the
    ``g_dtype`` tensor while the gradient is attached, then restored.
    """
    for param, torch_param in zip(model.parameters(), torch_model.parameters()):
        torch_param.grad = torch.rand_like(torch_param)
        fake_grad = torch_param.grad.clone().to(g_dtype)
        original_data = param.data
        param.data = fake_grad    # make dtypes agree for the .grad assignment
        param.grad = fake_grad
        param.data = original_data


@pytest.mark.parametrize('optim_cls, device', _ALLOWED_OPTIM_DEVICES)
@pytest.mark.parametrize('adamw', [False, True])
@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES)
def test_adam_optim_on_bert(optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]], device: torch.device,
                            adamw: bool, p_dtype: torch.dtype, g_dtype: torch.dtype) -> None:
    """Compare each ColossalAI Adam variant against torch.optim on a BERT model."""
    model_fn, *_ = next(iter(model_zoo.get_sub_registry('transformers_bert_for_sequence_classification').values()))
    torch_model = model_fn().to(device)
    model = deepcopy(torch_model).to(p_dtype)

    lr, eps = 1e-3, 1e-8
    beta1, beta2 = 0.9, 0.999
    torch_optim_cls = AdamW if adamw else Adam
    torch_optim = torch_optim_cls(setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps)
    optim = optim_cls(setup_param_groups(model), lr=lr, betas=(beta1, beta2), eps=eps, adamw_mode=adamw)

    # Tolerances scale with the lowest precision involved.
    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
        rtol = atol = 4e-3
    elif p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol = atol = 2e-3
    else:
        rtol = atol = 1e-5

    for _ in range(N_STEPS):
        set_grad(model, torch_model, g_dtype)
        torch_optim.step()
        optim.step()
        torch_optim.zero_grad()
        optim.zero_grad()
        for p, torch_p in zip(model.parameters(), torch_model.parameters()):
            # On overflow the update is skipped, so p must stay NaN-free.
            assert not torch.isnan(p).any()
            assert torch.allclose(p.float(), torch_p, rtol=rtol, atol=atol)
Loading