15 changes: 10 additions & 5 deletions colossalai/moe/__init__.py
@@ -1,12 +1,17 @@
 from .checkpoint import MoeCheckpintIO
-from .experts import EPMLPExperts, TPMLPExperts, build_ffn_experts
+from .experts import MLPExperts
 from .layers import SparseMLP
 from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
 from .utils import NormalNoiseGenerator, UniformNoiseGenerator
 
 __all__ = [
-    'EPMLPExperts', 'TPMLPExperts', 'build_ffn_experts',
-    'MoeRouter', 'Top1Router', 'Top2Router', 'TopKRouter',
-    'NormalNoiseGenerator', 'UniformNoiseGenerator',
-    'SparseMLP', 'MoeCheckpintIO'
+    "MLPExperts",
+    "MoeRouter",
+    "Top1Router",
+    "Top2Router",
+    "TopKRouter",
+    "NormalNoiseGenerator",
+    "UniformNoiseGenerator",
+    "SparseMLP",
+    "MoeCheckpintIO",
 ]
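
Note: downstream imports change with this consolidation. A minimal sketch of the updated call site (sizes are hypothetical, not from this PR):

    from colossalai.moe import MLPExperts

    # hypothetical sizes, for illustration only
    experts = MLPExperts(
        num_experts=8,
        hidden_size=1024,
        intermediate_size=4096,
        expert_parallel="EP",  # one of None, "EP", "TP" per the new docstring
    )

The former EPMLPExperts/TPMLPExperts entry points and the build_ffn_experts factory leave the public API; the parallel mode is now a constructor argument.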
4 changes: 2 additions & 2 deletions colossalai/moe/_operation.py
@@ -227,10 +227,10 @@ def backward(ctx, tokens_grad):
         return d_expert, d_logits, None, None, None
 
 
-def moe_cumsum(inputs: Tensor):
+def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
     dim0 = inputs.size(0)
     flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
-    if flag and MOE_MANAGER.use_kernel_optim:
+    if flag and use_kernel:
         if MOE_KERNEL is None:
             load_moe()
         return MOE_KERNEL.cumsum_sub_one(inputs)
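
Note: moe_cumsum no longer reads the global MOE_MANAGER inside the function; the caller passes use_kernel explicitly, which keeps the routine free of hidden global state. A hedged sketch of what a call site presumably looks like after this change (the actual callers are not shown in this diff; token_mask is a stand-in name):

    # assumption: callers forward the manager's flag themselves
    ranks = moe_cumsum(token_mask, use_kernel=MOE_MANAGER.use_kernel_optim)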
118 changes: 31 additions & 87 deletions colossalai/moe/experts.py
@@ -15,18 +15,19 @@
 from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
 
 
-class BaseMLPExperts(nn.Module):
+class MLPExperts(nn.Module):
     """
     SparseMLP is a multi-layer perceptron with sparse expert parallel layers.
 
     Args:
         num_experts (int): The number of experts
-        forward: hidden_size --> intermediate_size --> hidden_size
-        hidden_size (int): The hidden size of MLP
-        intermediate_size (int): The intermediate size of MLP
-        expert_parallel (str, optional): The parallelism of experts. Now we have 'EP' and 'TP'.
+        hidden_size (int): The hidden size of MLP
+        intermediate_size (int): The intermediate size of MLP
+        expert_parallel (str, optional): The parallelism of experts. Now we have None, EP and TP.
         activation (optional): The activation function of MLP
         drop_rate (float, optional): The drop rate of MLP
         gated (bool, optional): Whether to use gated MLP
+        use_kernel (bool, optional): Whether to use kernel optimization
     """
 
     def __init__(
@@ -36,9 +37,9 @@ def __init__(
         intermediate_size: int,
         expert_parallel: Optional[str] = None,
         activation: Optional[Callable] = None,
-        drop_rate: float = 0,
-        gated: bool = False,
-        use_kernel: bool = False,
+        drop_rate: Optional[float] = 0,
+        gated: Optional[bool] = False,
+        use_kernel: Optional[bool] = False,
     ):
         super().__init__()
         assert expert_parallel in ["EP", "TP", None]
@@ -97,8 +98,15 @@ def reset_parameters(self):
         torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size))
         torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size))
 
-    def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        param_slice: Tuple[slice] = (slice(None),),
+        use_sparse: bool = True,
+    ) -> torch.Tensor:
         """
+        forward: hidden_size --> intermediate_size --> hidden_size
+
         Args:
             x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size)
@@ -114,6 +122,16 @@ def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -> torch.Tensor:
         inshape = x.shape
         x = x.reshape(e, -1, h)
 
+        if self.use_kernel and use_sparse:
+            seq_len = x.shape[1]
+            with torch.no_grad():
+                mask = x[:, :, 0] != 0.0
+                mask = torch.sum(mask, dim=-1)
+            x_list = []
+            for i in range(e):
+                x_list.append(x[i, :mask[i]])
+            x = x_list
+
         if self.gated:
             x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)]
             x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)]
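
The gated branch computes a SwiGLU-style MLP per expert, with LlamaActCombine fusing the activation and elementwise product when use_kernel is set. A minimal eager-mode sketch of the math (SiLU is an assumption here; the activation is configurable):

    import torch
    import torch.nn.functional as F

    x = torch.randn(32, 8)        # (tokens, hidden_size)
    wi_gate = torch.randn(8, 16)  # (hidden_size, intermediate_size)
    wi_up = torch.randn(8, 16)
    wo = torch.randn(16, 8)

    h = F.silu(x @ wi_gate) * (x @ wi_up)  # act(gate) * up
    y = h @ wo                             # project back to hidden_size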
@@ -127,86 +145,12 @@ def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -> torch.Tensor:
         x = [self.drop(x[i]) for i in range(e)]
         x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)]
 
+        if self.use_kernel and use_sparse:
+            for i in range(e):
+                x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0)
+
         x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
         x = x.reshape(inshape)
         x = x.transpose(0, 1).contiguous()
         x = MoeOutGradScaler.apply(x, self.ep_size)
         return x
-
-
-class EPMLPExperts(BaseMLPExperts):
-    """
-    Use expert parallelism to split each expert evenly, which can deploy experts in
-    """
-
-    def __init__(
-        self,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size: int,
-        activation=None,
-        drop_rate: float = 0,
-        gated: bool = False,
-        use_kernel: bool = False,
-    ):
-        # TODO: This class can be aborted
-        super().__init__(
-            num_experts,
-            hidden_size,
-            intermediate_size,
-            "EP",
-            activation,
-            drop_rate,
-            gated,
-            use_kernel,
-        )
-
-
-class TPMLPExperts(BaseMLPExperts):
-    """Use tensor parallelism to split each expert evenly, which can deploy experts in
-    case that the number of experts can't be divide by maximum expert parallel size or
-    maximum expert parallel size can't be divide by the number of experts.
-    """
-
-    def __init__(
-        self,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size: int,
-        activation: str = None,
-        drop_rate: float = 0,
-        gated: bool = False,
-        use_kernel: bool = False,
-    ):
-        # TODO: This class can be aborted
-        super().__init__(
-            num_experts,
-            hidden_size,
-            intermediate_size,
-            "TP",
-            activation,
-            drop_rate,
-            gated,
-            use_kernel,
-        )
-
-
-def get_expert_class(name: str) -> BaseMLPExperts:
-    if name == "TP":
-        return TPMLPExperts
-    elif name == "EP":
-        return EPMLPExperts
-    elif name is None:
-        return BaseMLPExperts
-    else:
-        raise ValueError(f"Unknown expert class name: {name}")
-
-
-def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-    mep_size = MOE_MANAGER.max_ep_size
-    if num_experts % mep_size == 0 or mep_size % num_experts == 0:
-        return EPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate)
-    elif d_ff % mep_size == 0:
-        return TPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate)
-    else:
-        raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.")
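
The use_sparse path added above skips the zero-padded slots of each expert's capacity buffer: valid rows are counted via the first hidden element, only that prefix goes through the matmuls, and the output is padded back to seq_len before the cat. A self-contained sketch of that roundtrip (the expert MLP is replaced by a stand-in; note the implicit assumption that a real token never has x[..., 0] equal to exactly 0.0):

    import torch
    import torch.nn.functional as F

    e, capacity, h = 4, 16, 8
    x = torch.randn(e, capacity, h)
    x[:, 10:] = 0.0                        # simulate empty capacity slots
    valid = (x[:, :, 0] != 0.0).sum(-1)    # valid rows per expert, as in the diff
    out = []
    for i in range(e):
        xi = x[i, :valid[i]]               # compute only on real tokens
        yi = xi @ torch.eye(h)             # stand-in for the expert MLP
        out.append(F.pad(yi, (0, 0, 0, capacity - yi.shape[0])))  # pad back
    y = torch.stack(out)                   # (e, capacity, h) again

With build_ffn_experts and get_expert_class deleted, choosing between EP and TP presumably moves to the MLPExperts caller via the expert_parallel argument.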