2 changes: 1 addition & 1 deletion colossalai/nn/layer/moe/__init__.py
@@ -5,6 +5,6 @@
from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts

__all__ = [
'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'MoeModule', 'NormalNoiseGenerator',
'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'MoeModule', 'MoeLayer', 'NormalNoiseGenerator',
'UniformNoiseGenerator', 'build_ffn_experts', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO'
]
41 changes: 41 additions & 0 deletions colossalai/nn/layer/moe/_operation.py
@@ -173,3 +173,44 @@ def moe_cumsum(inputs: Tensor):
return moe.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1


class MoeInGradScaler(torch.autograd.Function):
"""
Scale the gradient up by the expert parallel size in backward,
because the batch size increases ep_size times in the moe stage
"""

@staticmethod
def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
if ctx is not None:
ctx.ep_size = ep_size
return inputs

@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
assert len(grad_outputs) == 1
grad = grad_outputs[0]
if ctx.ep_size != 1:
grad = grad * ctx.ep_size
return grad, None


class MoeOutGradScaler(torch.autograd.Function):
"""
Scale the gradient down by the expert parallel size in backward,
undoing the scale-up applied by MoeInGradScaler
"""

@staticmethod
def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
ctx.ep_size = ep_size
return inputs

@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
assert len(grad_outputs) == 1
grad = grad_outputs[0]
if ctx.ep_size != 1:
grad = grad / ctx.ep_size
return grad, None
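A minimal standalone sketch of what this pair does (assuming both classes are importable as in the import added to experts.py below; the ep_size value is illustrative): forward is the identity for both, while backward multiplies the gradient by ep_size on the way into the experts and divides it on the way out, so wrapping a computation with both leaves the end-to-end gradient scale unchanged.

import torch
from colossalai.nn.layer.moe._operation import MoeInGradScaler, MoeOutGradScaler

x = torch.randn(2, 3, requires_grad=True)
ep_size = 4    # illustrative expert parallel size

# MoeInGradScaler: identity in forward, gradient multiplied by ep_size
y = MoeInGradScaler.apply(x, ep_size)
y.sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, float(ep_size)))

# MoeOutGradScaler: identity in forward, gradient divided by ep_size
x.grad = None
z = MoeOutGradScaler.apply(x, ep_size)
z.sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 1.0 / ep_size))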
42 changes: 26 additions & 16 deletions colossalai/nn/layer/moe/experts.py
@@ -7,6 +7,7 @@

from colossalai.context import ParallelMode, seed
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.nn.layer.moe._operation import MoeInGradScaler, MoeOutGradScaler
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size, set_moe_tensor_info


@@ -20,27 +21,31 @@ def __init__(
num_experts: int,
hidden_size: int,
intermediate_size: int,
expert_parallel: str,
expert_parallel: str = None,
activation: str = None,
drop_rate: float = 0,
):
super().__init__()
assert expert_parallel in ["EP", "TP"]
assert expert_parallel in ["EP", "TP", None]
self.expert_parallel = expert_parallel

# get local and total experts
self.num_total_experts = num_experts
self.num_local_experts, self.moe_info = MOE_CONTEXT.get_info(num_experts,
use_tp=True if expert_parallel == "TP" else False)

# get settings for different parallel
if expert_parallel == "TP":
assert intermediate_size % MOE_CONTEXT.max_ep_size == 0, \
"intermediate_size should be divide by maximum expert parallel size"
intermediate_size = intermediate_size // MOE_CONTEXT.max_ep_size
num_experts = self.num_total_experts

# get expert parallel info
if expert_parallel is not None:
self.num_local_experts, self.moe_info = MOE_CONTEXT.get_info(
num_experts, use_tp=True if expert_parallel == "TP" else False)
# get settings for different parallel
if expert_parallel == "TP":
assert intermediate_size % MOE_CONTEXT.max_ep_size == 0, \
"intermediate_size should be divide by maximum expert parallel size"
intermediate_size = intermediate_size // MOE_CONTEXT.max_ep_size
num_experts = self.num_total_experts
else:
num_experts = self.num_local_experts
self.ep_size = get_ep_size(self)
else:
num_experts = self.num_local_experts
self.num_local_experts = self.num_total_experts
self.ep_size = 1

self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size))
@@ -52,10 +57,12 @@ def __init__(
self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate)

for param in self.parameters():
set_moe_tensor_info(param, self.moe_info)
if expert_parallel is not None:
for param in self.parameters():
set_moe_tensor_info(param, self.moe_info)

def forward(self, x: torch.Tensor) -> torch.Tensor: # inputs [g, e, c, h]
x = MoeInGradScaler.apply(x, self.ep_size)

e = x.size(1)
h = x.size(-1)
@@ -72,6 +79,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # inputs [g, e, c, h]

x = x.reshape(inshape)
x = x.transpose(0, 1).contiguous()
x = MoeOutGradScaler.apply(x, self.ep_size)
return x # outputs [g, e, c, h]


@@ -135,5 +143,7 @@ def get_expert_class(name: str) -> BaseMLPExperts:
return TPMLPExperts
elif name == "EP":
return EPMLPExperts
elif name is None:
return BaseMLPExperts
else:
raise ValueError(f"Unknown expert class name: {name}")
136 changes: 42 additions & 94 deletions colossalai/nn/layer/moe/layers.py
@@ -20,90 +20,6 @@
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size


class MoeLayer(nn.Module):
"""A MoE layer, that puts its input tensor to its gate and uses the output logits
to router all tokens, is mainly used to exchange all tokens for every expert across
the moe tensor group by all to all communication. Then it will get the output of all
experts and exchange the output. At last returns the output of the moe system.

Args:
dim_model (int): Dimension of model.
num_experts (int): The number of experts.
router (MoeRouter): Instance of router used in routing.
experts (MoeExperts): Instance of experts generated by Expert.
"""

def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: BaseMLPExperts):
super().__init__()
self.d_model = dim_model
self.num_experts = num_experts
self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
self.router: MoeRouter = router
self.experts: BaseMLPExperts = experts
self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
self.ep_group = get_ep_group(experts)
self.ep_size = get_ep_size(experts)
self.num_local_experts = experts.num_local_experts

nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))

def ep_process(self, dispatch_data: torch.Tensor):
expert_input = AllToAll.apply(dispatch_data, self.ep_group)
input_shape = expert_input.shape
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
expert_output = self.experts(expert_input)
expert_output = expert_output.reshape(input_shape)
expert_output = AllToAll.apply(expert_output, self.ep_group)
return expert_output

def tp_process(self, dispatch_data: torch.Tensor):
expert_in = AllGather.apply(dispatch_data, self.ep_group)
expert_out = self.experts(expert_in)
expert_out = ReduceScatter.apply(expert_out, self.ep_group)
return expert_out

def forward(self, inputs: torch.Tensor) -> Tuple:
# reshape the input tokens
tokens = inputs.reshape(-1, self.d_model)

# the data type of the inputs in the gating should be fp32
fp32_input = tokens.to(torch.float)
fp32_weight = self.gate_weight.to(torch.float)
gate_output = F.linear(fp32_input, fp32_weight)

# the result from the router
route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)

if self.use_kernel:
dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
else:
sec_mask_f = route_result_list[1].type_as(inputs)
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

# dispatch_data [e, c, h]
if self.experts.expert_parallel == "EP":
expert_output = self.ep_process(dispatch_data)
elif self.experts.expert_parallel == "TP":
expert_output = self.tp_process(dispatch_data)
else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
"build function.")
# expert_output [e, c, h]
if self.use_kernel:
expert_output = expert_output.reshape(-1, self.d_model)
ans = MoeCombine.apply(expert_output, *route_result_list)
else:
combine_weights = route_result_list[0].type_as(inputs)
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1])
ans = torch.matmul(combine_weights, expert_output)

ans = ans.reshape(inputs.shape)
l_aux = self.router.pop_routing_loss()
return ans, l_aux


class SparseMLP(nn.Module):
"""A class for users to create MoE modules in their models.

@@ -149,7 +65,8 @@ def __init__(self,
self.hidden_size = hidden_size
self.num_experts = num_experts
self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
assert expert_parallel in ["EP", "TP"], f"Unsupported expert parallel type {expert_parallel}"
self.expert_parallel = expert_parallel
assert expert_parallel in ["EP", "TP", None], f"Unsupported expert parallel type {expert_parallel}"

# moe router
noisy_func = get_noise_generator(noisy_policy, num_experts)
@@ -166,8 +83,11 @@
hidden_size=hidden_size,
intermediate_size=intermediate_size,
activation=activation)
self.ep_group = get_ep_group(self.experts)
self.ep_size = get_ep_size(self.experts)
if expert_parallel is not None:
self.ep_group = get_ep_group(self.experts)
self.ep_size = get_ep_size(self.experts)
else:
self.ep_group = None
self.num_local_experts = self.experts.num_local_experts

self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))
@@ -193,10 +113,12 @@ def forward(self, inputs: torch.Tensor) -> Tuple:
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

# dispatch_data [e, c, h]
if self.experts.expert_parallel == "EP":
expert_output = self.ep_process(dispatch_data)
elif self.experts.expert_parallel == "TP":
expert_output = self.tp_process(dispatch_data)
if self.expert_parallel == "EP":
expert_output = self._ep_process(dispatch_data)
elif self.expert_parallel == "TP":
expert_output = self._tp_process(dispatch_data)
elif self.expert_parallel is None:
expert_output = self._local_process(dispatch_data)
else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
"build function.")
@@ -214,7 +136,12 @@ def forward(self, inputs: torch.Tensor) -> Tuple:
l_aux = self.router.pop_routing_loss()
return ans, l_aux

def ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
expert_in = expert_in.unsqueeze(0)
expert_out = self.experts(expert_in)
return expert_out

def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
expert_input = AllToAll.apply(dispatch_data, self.ep_group)
input_shape = expert_input.shape
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
@@ -223,14 +150,35 @@ def ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
expert_output = AllToAll.apply(expert_output, self.ep_group)
return expert_output

def tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
def _tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
expert_in = AllGather.apply(dispatch_data, self.ep_group)
expert_out = self.experts(expert_in)
expert_out = ReduceScatter.apply(expert_out, self.ep_group)
return expert_out


class MoeModule(nn.Module):
class MoeModule(SparseMLP):
"""
Alias of SparseMLP kept for backward compatibility with code that imports MoeModule.
"""

def __init__(self,
num_experts: int,
top_k: int = 1,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_policy: Optional[str] = None,
drop_tks: bool = True,
expert_parallel: str = "EP",
hidden_size: int = 2048,
intermediate_size: int = 2048,
activation: str = None):
super().__init__(num_experts, top_k, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_policy,
drop_tks, expert_parallel, hidden_size, intermediate_size, activation)


class MoeLayer(SparseMLP):
"""
Alias of SparseMLP kept for backward compatibility with code that imports MoeLayer.
"""
4 changes: 2 additions & 2 deletions colossalai/nn/layer/moe/routers.py
@@ -111,7 +111,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti
l_aux = num_experts * torch.sum(me * ce)
self.set_routing_loss(l_aux)

if not self.training and not self.drop_tks:
if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(mask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
@@ -190,7 +190,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti
l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1
self.set_routing_loss(l_aux)

if not self.training and not self.drop_tks:
if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(cmask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
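The ep_group is not None guard added to both routers matters for the new non-parallel path: without a process group there is nothing to all-reduce, so the dynamic eval-time capacity branch is skipped entirely. A sketch of the guarded pattern (maybe_sync_capacity is a hypothetical helper, not part of this PR):

import torch
import torch.distributed as dist

def maybe_sync_capacity(training: bool, drop_tks: bool,
                        mask: torch.Tensor, ep_group=None):
    # Mirrors the guarded branch above: dynamic capacity is only computed
    # and synchronized when an expert parallel group actually exists.
    if not training and not drop_tks and ep_group is not None:
        max_num = torch.max(torch.sum(mask, dim=0))
        dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
        return max_num.item()
    return None    # caller falls back to the statically computed capacity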
34 changes: 34 additions & 0 deletions tests/test_moe/moe_utils.py
@@ -127,3 +127,37 @@ def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag:
assert torch.allclose(tp_param.grad, new_grad)
else:
tp_param.data.copy_(new_tp_param.data)


def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model

Args:
tp_model (MoeModule)
ep_model (MoeModule)
"""
for (local_name, local_param), (ep_name, ep_param) in zip(local_model.named_parameters(),
ep_model.named_parameters()):
assert local_name == ep_name
if "experts" not in local_name:
if assert_grad_flag:
assert torch.allclose(local_param, ep_param)
assert torch.allclose(local_param.grad, ep_param.grad)
else:
local_param.data.copy_(ep_param.data)
continue

# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)

if assert_grad_flag:
assert torch.allclose(local_param, all_param)
assert torch.allclose(local_param.grad, all_grad)
else:
local_param.data.copy_(all_param.data)
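A hedged usage sketch of the new helper (assumes a launched distributed context with MOE_CONTEXT initialized, the SparseMLP keyword arguments shown here, and deterministic routing; tolerances may need adjusting):

import torch
from colossalai.nn.layer.moe import SparseMLP
from tests.test_moe.moe_utils import sync_local_from_ep

# Build an EP-sharded model and a purely local twin, copy the gathered
# expert weights into the local model, then check the two outputs agree.
local_model = SparseMLP(num_experts=4, hidden_size=32,
                        intermediate_size=64, expert_parallel=None)
ep_model = SparseMLP(num_experts=4, hidden_size=32,
                     intermediate_size=64, expert_parallel="EP")
sync_local_from_ep(local_model, ep_model)

x = torch.randn(2, 8, 32)
local_out, _ = local_model(x)
ep_out, _ = ep_model(x)
assert torch.allclose(local_out, ep_out, atol=1e-5)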