3 changes: 3 additions & 0 deletions colossalai/zero/low_level/low_level_optim.py
@@ -137,6 +137,9 @@ def __init__(
for group_id, param_group in enumerate(self.optim.param_groups):
group_params = list()
for param in param_group['params']:
+                # skip moe param
+                if hasattr(param, "moe_info"):
+                    continue
if param.requires_grad:
group_params.append(param)

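The three lines added above keep MoE expert parameters out of the ZeRO optimizer's parameter groups, so their gradients are never flattened and partitioned. Below is a minimal sketch of the same filtering, assuming only what the hunk shows (expert parameters carry a "moe_info" attribute); the helper name is hypothetical and not part of this PR:

import torch.nn as nn

def split_moe_params(module: nn.Module):
    # Partition trainable params the way the hunk above does: params tagged
    # with "moe_info" stay outside ZeRO and keep full-rank gradients locally.
    moe_params, zero_params = [], []
    for param in module.parameters():
        if not param.requires_grad:
            continue
        (moe_params if hasattr(param, "moe_info") else zero_params).append(param)
    return moe_params, zero_params
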
41 changes: 41 additions & 0 deletions tests/test_moe/moe_utils.py
@@ -0,0 +1,41 @@
import torch.nn as nn

from colossalai.context import MOE_CONTEXT
from colossalai.nn import CheckpointModule
from colossalai.nn.layer import MoeModule


class MoeModel(nn.Module):

def __init__(self, checkpoint: bool = False):

class TestSubModule(CheckpointModule):

def __init__(self):
super().__init__(checkpoint)
expert_cls = nn.Linear
expert_args_dict = dict(in_features=16, out_features=16)
self.moe = MoeModule(dim_model=16,
num_experts=8,
use_residual=True,
expert_cls=expert_cls,
**expert_args_dict)
self.proj = nn.Linear(16, 4)

def _forward(self, x):
x, y = self.moe(x)
x = self.proj(x)
return x, y

super().__init__()
self.test_embed = nn.Linear(4, 16)
self.test_transform = TestSubModule()

def forward(self, x):
MOE_CONTEXT.reset_loss()

x = self.test_embed(x)
x, y = self.test_transform(x)

MOE_CONTEXT.add_loss(y)
return x
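
A hedged usage sketch for the helper model above (not part of the PR), assuming an already-launched ColossalAI process group with MOE_CONTEXT set up, as the tests below arrange:

import torch
from colossalai.nn import MoeLoss
from tests.test_moe.moe_utils import MoeModel

model = MoeModel(checkpoint=True).cuda()
# MoeLoss adds the auxiliary load-balancing loss collected in MOE_CONTEXT
# (via add_loss in forward above) onto the task loss.
criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)

data = torch.randn(16, 4).cuda()             # test_embed takes in_features=4
label = torch.randint(0, 4, (16,)).cuda()

out = model(data)
loss = criterion(out, label)
loss.backward()
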
3 changes: 1 addition & 2 deletions tests/test_moe/test_grad_handler.py
@@ -13,11 +13,10 @@

BATCH_SIZE = 4
DIM = 16
-CONFIG = dict()


def run_test(rank, world_size, port):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
expert_module = nn.Linear
expert_factor = dict(in_features=DIM, out_features=DIM, device=get_current_device())

3 changes: 1 addition & 2 deletions tests/test_moe/test_kernel.py
@@ -12,7 +12,6 @@

BATCH_SIZE = 16
NUM_EXPERTS = 4
-CONFIG = dict()


def check_equal(tensor_a, tensor_b, atol=1e-06):
@@ -23,7 +22,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
# Here we do not need TF32, since it brings absolute error on results
torch.backends.cuda.matmul.allow_tf32 = False

-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)

MOE_CONTEXT.setup(42) # MOE environment initialization
5 changes: 2 additions & 3 deletions tests/test_moe/test_moe_checkpoint.py
@@ -10,8 +10,7 @@
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
-from tests.test_moe.test_moe_zero_init import MoeModel
-from tests.test_zero.test_legacy.common import CONFIG
+from tests.test_moe.moe_utils import MoeModel


def exam_moe_checkpoint():
@@ -34,7 +33,7 @@ def exam_moe_checkpoint():


def _run_dist(rank, world_size, port):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MOE_CONTEXT.setup(seed=42)
exam_moe_checkpoint()

56 changes: 0 additions & 56 deletions tests/test_moe/test_moe_colo_init.py

This file was deleted.

3 changes: 1 addition & 2 deletions tests/test_moe/test_moe_group.py
@@ -11,12 +11,11 @@

D_MODEL = 4
D_FF = 8
-CONFIG = dict()


def run_test(rank, world_size, port):
world_size = 4
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
expert_module = nn.Linear
expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device())

106 changes: 106 additions & 0 deletions tests/test_moe/test_moe_zero_fwd_bwd.py
@@ -0,0 +1,106 @@
import pytest
import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.context import MOE_CONTEXT
from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.nn import MoeLoss
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeModel


def split_ddp_grad(grad, world_size):
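    # Mimic ZeRO's flat-gradient sharding: pad the flattened gradient to a
    # multiple of world_size, then split it into equal per-rank chunks.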
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
        split_grad = grad.split(grad.numel() // world_size)
    return split_grad


def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()

if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y


def run_zero_test(local_rank, world_size, stage=1):
criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)

zero_model = MoeModel(checkpoint=True)
optimizer = torch.optim.Adam(zero_model.parameters())
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
booster = Booster(plugin=plugin)
zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer)

torch_model = MoeModel(checkpoint=True)
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
torch_param.data.copy_(zero_param.data)
torch_model = torch_model.cuda()
grad_handler = MoeGradientHandler(torch_model)

    # sanity-check the sharded zero model: experts are split evenly across
    # the MoE world and its weights match the torch baseline
assert len(zero_model.module.test_transform.moe.moe_layer.experts.experts) == 8 // MOE_CONTEXT.world_size
for (torch_name, torch_param), (zero_name, zero_param) in zip(torch_model.named_parameters(),
zero_model.module.named_parameters()):
assert zero_name == torch_name
assert torch.allclose(zero_param.data, torch_param.data)

data = torch.randn(16, 4).cuda()
label = torch.randint(0, 4, (16,)).cuda()

torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer)
assert torch.allclose(torch_out, zero_out)
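    # MoeGradientHandler all-reduces the baseline's expert gradients so they
    # can be compared against the ZeRO model's unpartitioned moe gradients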
grad_handler.handle_gradient()

for (zero_name, zero_param), (torch_name, torch_param) in zip(zero_model.module.named_parameters(),
torch_model.named_parameters()):
assert zero_name == torch_name
zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
if hasattr(zero_param, "moe_info"):
assert len(zero_grad_list) == 0
assert torch.allclose(zero_param.grad, torch_param.grad)
else:
assert len(zero_grad_list) > 0
torch_grad_list = split_ddp_grad(torch_param.grad, world_size)
if stage == 2:
torch_grad_list = torch_grad_list[local_rank:local_rank + 1]
assert len(zero_grad_list) == len(torch_grad_list)
for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
assert torch.allclose(zero_grad, torch_grad)


def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MOE_CONTEXT.setup(seed=42)
seed_all(42 + rank)
run_zero_test(rank, world_size, stage=1)
run_zero_test(rank, world_size, stage=2)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size):
spawn(run_dist, world_size)


if __name__ == '__main__':
test_moe_zero_model(world_size=2)
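
For intuition, here is a tiny worked example of the sharding logic exercised by split_ddp_grad above; plain PyTorch, no distributed setup required:

import torch

grad = torch.arange(5.0)                  # numel=5 with world_size=2
padding = (2 - grad.numel() % 2) % 2      # -> 1 trailing zero of padding
flat = torch.nn.functional.pad(grad, [0, padding])
shards = flat.split(flat.numel() // 2)    # two shards of three elements each
# Under stage 2 each rank keeps only shards[rank], hence the test's
# torch_grad_list[local_rank:local_rank + 1] slice; under stage 1 every
# rank holds the whole list.
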
108 changes: 0 additions & 108 deletions tests/test_moe/test_moe_zero_init.py

This file was deleted.
