1 change: 0 additions & 1 deletion colossalai/moe/experts.py
@@ -1,5 +1,4 @@
import math
from contextlib import nullcontext
from typing import Callable, Optional, Tuple

import torch
47 changes: 45 additions & 2 deletions colossalai/moe/layers.py
@@ -1,16 +1,18 @@
import math
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.moe.experts import BaseMLPExperts, get_expert_class
from colossalai.moe.load_balance import LoadBalancer
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.routers import MoeRouter, get_router_cls
from colossalai.moe.utils import get_noise_generator
-from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size
+from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size


class SparseMLP(nn.Module):
@@ -72,6 +74,7 @@ def __init__(
# moe router
noisy_func = get_noise_generator(noisy_policy, num_experts)
router_cls = get_router_cls(top_k)
self.topk = top_k
self.router: MoeRouter = router_cls(
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
@@ -91,13 +94,30 @@ def __init__(
if self.expert_parallel is not None:
self.ep_group = get_ep_group(self.experts)
self.ep_size = get_ep_size(self.experts)
self.dp_group = get_dp_group(self.experts)
else:
self.ep_group = None
self.dp_group = None
self.num_local_experts = self.experts.num_local_experts

# gate
self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))

# load balance
self.enable_load_balance = MOE_MANAGER.load_balance
if self.enable_load_balance:
self.load_balancer = LoadBalancer(
experts=self.experts,
gate=self.gate_weight,
local_expert_num=self.num_local_experts,
expert_num=self.num_experts,
ep_group=self.ep_group,
dp_group=self.dp_group,
tolerance=MOE_MANAGER.tolerance,
beam_width=MOE_MANAGER.beam_width,
group_swap_factor=MOE_MANAGER.group_swap_factor,
)

# init param
self.reset_parameters()

@@ -121,6 +141,14 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
fp32_weight = self.gate_weight.to(torch.float)
gate_output = F.linear(fp32_input, fp32_weight)

# update expert load
if self.enable_load_balance:
with torch.no_grad():
# TODO: optimize computation
expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1]
expert_load = torch.bincount(expert_load.view(-1))
self.load_balancer.update_load(expert_load)

# the result from the router
route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
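
For orientation, here is a minimal standalone sketch (not part of the PR) of the load-counting step above: the top-k expert indices from the gate output are flattened and bincounted into a per-expert token count. The toy shapes and the minlength argument are illustrative assumptions.

import torch

num_tokens, num_experts, top_k = 8, 4, 2
gate_output = torch.randn(num_tokens, num_experts)  # stand-in for F.linear(fp32_input, fp32_weight)

with torch.no_grad():
    # indices of the k highest-scoring experts for each token: shape [num_tokens, top_k]
    chosen = torch.topk(gate_output, k=top_k, dim=-1)[1]
    # number of tokens routed to each expert; minlength keeps the count vector at num_experts entries
    expert_load = torch.bincount(chosen.view(-1), minlength=num_experts)

print(expert_load.tolist())  # sums to num_tokens * top_k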

@@ -257,3 +285,18 @@ def get_chunk_slice(idx: int, gap: int) -> Tuple[slice]:
# sync for async op
torch.cuda.synchronize()
return out


def apply_load_balance(model: nn.Module, optim: Any) -> None:
"""
Apply load balancing to every SparseMLP layer in the model.
"""

def _apply_recursive(module: nn.Module):
for _, sub_module in module.named_children():
if isinstance(sub_module, SparseMLP):
if sub_module.enable_load_balance:
sub_module.load_balancer.balance_load(optim)
_apply_recursive(sub_module)

_apply_recursive(model)
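
A hedged usage sketch of where apply_load_balance could sit in a training loop; the loop structure, loss, rebalancing interval, and the helper name train_with_load_balance are illustrative assumptions rather than part of this PR.

import torch
import torch.nn as nn

from colossalai.moe.layers import apply_load_balance


def train_with_load_balance(model: nn.Module, optimizer: torch.optim.Optimizer, dataloader, balance_interval: int = 100) -> None:
    for step, (inputs, targets) in enumerate(dataloader):
        loss = nn.functional.cross_entropy(model(inputs), targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Periodically let every SparseMLP swap experts (and the matching optimizer state)
        # according to the per-expert load it has been accumulating.
        if step > 0 and step % balance_interval == 0:
            apply_load_balance(model, optimizer)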