15 changes: 15 additions & 0 deletions colossalai/kernel/cuda_native/csrc/type_shim.h
@@ -171,6 +171,21 @@
using g_scalar_t_##LEVEL = at::Half; \
using p_scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::Float && \
PTYPE == at::ScalarType::BFloat16) { \
using g_scalar_t_##LEVEL = float; \
using p_scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::BFloat16 && \
PTYPE == at::ScalarType::Float) { \
using g_scalar_t_##LEVEL = at::BFloat16; \
using p_scalar_t_##LEVEL = float; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::BFloat16 && \
PTYPE == at::ScalarType::BFloat16) { \
using g_scalar_t_##LEVEL = at::BFloat16; \
using p_scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
} else { \
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
"'"); \
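The new branches let the fused kernels dispatch when gradient and parameter precisions differ, covering every mix of fp32 and bf16. Below is a minimal Python-level sketch of a path that should hit the (GTYPE=Float, PTYPE=BFloat16) case; the explicit grad upcast mirrors what the old test did, and how FusedAdam forwards these dtypes to the kernel is an assumption, not something this diff shows directly.

```python
import torch
import torch.nn as nn

from colossalai.nn.optimizer.fused_adam import FusedAdam

# bf16 parameters whose gradients are upcast to fp32 before the step,
# which should exercise the new fp32-grad / bf16-param dispatch branch.
model = nn.Linear(64, 64).cuda().to(torch.bfloat16)
optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)

out = model(torch.rand(8, 64, device='cuda', dtype=torch.bfloat16))
out.sum().backward()

for p in model.parameters():
    p.grad.data = p.grad.data.float()  # fp32 grads paired with bf16 params

optim.step()
```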
4 changes: 2 additions & 2 deletions colossalai/nn/optimizer/fused_adam.py
@@ -134,8 +134,8 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p)

if p.dtype not in [torch.float16, torch.float32]:
raise RuntimeError('FusedAdam only support fp16 and fp32.')
if p.dtype not in [torch.float16, torch.float32, torch.bfloat16]:
raise RuntimeError('FusedAdam only supports fp16, fp32 and bf16.')

g_l.append(p.grad.data)
p_l.append(p.data)
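With the guard relaxed, fp16, fp32, and bf16 parameters are accepted and anything else still raises at step time. A small sketch of that behavior, assuming the dtype check is reached before any kernel launch:

```python
import torch
import torch.nn as nn

from colossalai.nn.optimizer.fused_adam import FusedAdam

# fp64 parameters are still rejected by the dtype guard above.
model = nn.Linear(8, 8).cuda().double()
optim = FusedAdam(model.parameters(), lr=1e-3)

model(torch.rand(4, 8, device='cuda', dtype=torch.float64)).sum().backward()

try:
    optim.step()
except RuntimeError as err:
    print(err)  # FusedAdam rejects anything outside fp16/fp32/bf16
```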
71 changes: 40 additions & 31 deletions tests/test_optimizer/test_fused_adam.py
@@ -1,10 +1,12 @@
from copy import deepcopy

import pytest
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.adam import Adam

from colossalai.nn.optimizer.fused_adam import FusedAdam
from colossalai.testing import clear_cache_before_run, parameterize


class FC(nn.Module):
@@ -17,48 +19,55 @@ def forward(self, x):
return self.fc(x)


@clear_cache_before_run()
@parameterize('adamw', [False, True])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, p_dtype, g_dtype):
model = FC().cuda().to(p_dtype)
state = model.state_dict()
model_copy = FC().cuda().to(p_dtype)
model_copy.load_state_dict(state.copy())

if adamw:
optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)
torch_optim = AdamW(model_copy.parameters(), lr=1e-3)
@pytest.mark.parametrize('adamw', [False, True])
@pytest.mark.parametrize('running_p_dtype', [torch.float, torch.half, torch.bfloat16])
@pytest.mark.parametrize('fp32_master_weights', [False, True])
def test_adam(adamw, running_p_dtype, fp32_master_weights):
# baseline is pure fp32 torch adam
# grad dtype is the same as running_p_dtype
if running_p_dtype is torch.float and not fp32_master_weights:
# pure fp32 must have fp32 weights
return
if not fp32_master_weights or running_p_dtype is torch.bfloat16:
# pure low precision or bf16, high tolerance
atol = 4e-3
rtol = 4e-3
else:
optim = FusedAdam(model.parameters(), lr=1e-3)
torch_optim = Adam(model_copy.parameters(), lr=1e-3)
# fp32 master weights, low tolerance
atol = 2e-3
rtol = 2e-3
torch_model = FC().cuda()
model = deepcopy(torch_model).to(running_p_dtype)

data = torch.rand(1024, 64).cuda().to(p_dtype)
data_copy = data.clone()
label = torch.rand(1024, 64).cuda().to(p_dtype)
torch_optim_cls = AdamW if adamw else Adam
torch_optim = torch_optim_cls(torch_model.parameters(), lr=1e-3)
optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=adamw)

for d, l in zip(data, label):
data = torch.rand(10, 64).cuda()
label = torch.rand(10, 64).cuda()

for d, l in zip(data.to(running_p_dtype), label.to(running_p_dtype)):
y = model(d)
loss = ((l - y)**2).sum()
optim.zero_grad()
loss.backward()
if p_dtype != g_dtype:
for i in range(len(optim.param_groups[0]['params'])):
optim.param_groups[0]['params'][i].grad.data = optim.param_groups[0]['params'][i].grad.data.to(g_dtype)
if fp32_master_weights:
for p in model.parameters():
p.data = p.data.float()
optim.step()
if fp32_master_weights:
for p in model.parameters():
p.data = p.data.to(running_p_dtype)

for d, l in zip(data_copy, label):
y = model_copy(d)
for d, l in zip(data, label):
y = torch_model(d)
loss = ((l - y)**2).sum()
torch_optim.zero_grad()
loss.backward()
torch_optim.step()

assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params'])

for i in range(len(optim.param_groups[0]['params'])):
if torch.isnan(optim.param_groups[0]['params'][i]).any() \
or torch.isnan(torch_optim.param_groups[0]['params'][i]).any():
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
if torch.isnan(p).any() or torch.isnan(torch_p).any():
continue
assert torch.allclose(optim.param_groups[0]['params'][i], torch_optim.param_groups[0]['params'][i], 2e-3, 2e-3)
fp32_p = p.float()
assert torch.allclose(fp32_p, torch_p, atol=atol, rtol=rtol)
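The fp32_master_weights branch above imitates mixed-precision training where the optimizer updates fp32 copies of the parameters. Pulled out of the test, the pattern is roughly the following sketch; the helper name is made up for illustration:

```python
def step_with_fp32_master_weights(model, optim, running_p_dtype):
    # Upcast parameters so the fused kernel updates fp32 master weights
    # against low-precision gradients, then restore the running dtype
    # for the next forward pass.
    for p in model.parameters():
        p.data = p.data.float()
    optim.step()
    for p in model.parameters():
        p.data = p.data.to(running_p_dtype)
```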
16 changes: 8 additions & 8 deletions tests/test_optimizer/test_fused_adam_kernel.py
@@ -1,12 +1,14 @@
import math

import pytest
import torch
import torch.nn as nn
from numpy import dtype

from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.utils import multi_tensor_applier

_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float),
(torch.half, torch.half), (torch.bfloat16, torch.float), (torch.float, torch.bfloat16),
(torch.bfloat16, torch.bfloat16)]


def torch_adam_update(
step,
@@ -41,11 +43,9 @@ def torch_adam_update(
param.addcdiv_(exp_avg, denom, value=-step_size)


@clear_cache_before_run()
@parameterize('adamw', [False, True])
@parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
@pytest.mark.parametrize('adamw', [False, True])
@pytest.mark.parametrize('step', [1, 2])
@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES)
def test_adam(adamw, step, p_dtype, g_dtype):
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
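For reference, the kernel test compares the fused result against a plain PyTorch Adam/AdamW update (torch_adam_update, truncated in this diff). A hedged sketch of that reference math using standard bias-corrected Adam; the function name and argument order here are illustrative only:

```python
import math

import torch


def reference_adam_update(step, lr, beta1, beta2, eps, weight_decay,
                          param, grad, exp_avg, exp_avg_sq, use_adamw=True):
    """Plain fp32 Adam/AdamW step used as ground truth for the fused kernel."""
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    if use_adamw:
        # AdamW: decoupled weight decay applied directly to the parameters.
        param.mul_(1 - lr * weight_decay)
    else:
        # Classic Adam: weight decay folded into the gradient.
        grad = grad.add(param, alpha=weight_decay)

    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    step_size = lr / bias_correction1
    param.addcdiv_(exp_avg, denom, value=-step_size)
```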