Merged
8 changes: 5 additions & 3 deletions colossalai/moe/__init__.py
@@ -1,10 +1,12 @@
from .checkpoint import MoeCheckpintIO
from .experts import EPMLPExperts, TPMLPExperts, build_ffn_experts
from .layers import SparseMLP
from .routers import MoeRouter, Top1Router, Top2Router
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator

__all__ = [
'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'NormalNoiseGenerator', 'UniformNoiseGenerator',
'SparseMLP', 'MoeRouter', 'MoeCheckpintIO', 'build_ffn_experts'
'EPMLPExperts', 'TPMLPExperts', 'build_ffn_experts',
'MoeRouter', 'Top1Router', 'Top2Router', 'TopKRouter',
'NormalNoiseGenerator', 'UniformNoiseGenerator',
'SparseMLP', 'MoeCheckpintIO'
]
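For orientation, the reshuffled `__all__` can be exercised straight from the package root. A minimal sketch, using only constructor arguments that appear later in this diff (the values are the defaults shown in routers.py, not a tuned configuration):

```python
# Sketch only: imports mirror the __all__ above.
from colossalai.moe import Top2Router, TopKRouter

token_router = Top2Router(capacity_factor_train=1.25,
                          capacity_factor_eval=2.0,
                          min_capacity=4,
                          drop_tks=True)
grouped_router = TopKRouter(num_selected_experts=2)
```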
26 changes: 21 additions & 5 deletions colossalai/moe/experts.py
@@ -1,9 +1,9 @@
import math
from contextlib import nullcontext
from typing import Callable, Optional

import torch
import torch.nn as nn

from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation
@@ -14,15 +14,24 @@
class BaseMLPExperts(nn.Module):
"""
BaseMLPExperts is a multi-layer perceptron with sparse expert parallel layers.

Args:
num_experts (int): The number of experts
forward: hidden_size --> intermediate_size --> hidden_size
hidden_size (int): The hidden size of MLP
intermediate_size (int): The intermediate size of MLP
expert_parallel (str, optional): The parallelism of experts. Now we have 'EP' and 'TP'.
activation (optional): The activation function of MLP
drop_rate (float, optional): The drop rate of MLP
"""

def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
expert_parallel: str = None,
activation: str = None,
expert_parallel: Optional[str] = None,
activation: Optional[Callable] = None,
drop_rate: float = 0,
gated: bool = False,
):
@@ -76,7 +85,14 @@ def __init__(
for param in self.parameters():
set_moe_tensor_info(param, self.moe_info)

def forward(self, x: torch.Tensor) -> torch.Tensor: # inputs [g, e, c, h]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size)

Returns:
torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
"""
x = MoeInGradScaler.apply(x, self.ep_size)

e = x.size(1)
@@ -97,7 +113,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # inputs [g, e, c, h]
x = x.reshape(inshape)
x = x.transpose(0, 1).contiguous()
x = MoeOutGradScaler.apply(x, self.ep_size)
return x # outputs [g, e, c, h]
return x
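MoeInGradScaler / MoeOutGradScaler, which bracket this forward, are imported but not shown in the diff; judging by their names and the `ep_size` argument, they likely rescale gradients to keep expert-parallel training consistent. A hedged sketch of that pattern, as an assumption rather than the PR's code:

```python
import torch

class GradScaler(torch.autograd.Function):
    """Identity in forward; scales the incoming gradient by a constant in
    backward -- the behaviour MoeInGradScaler/MoeOutGradScaler plausibly
    implement, with ep_size (or its inverse) as the factor."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.scale = scale
        return x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        return grad_output * ctx.scale, None    # no grad for the float factor

x = torch.randn(2, 4, 3, 8, requires_grad=True)   # (g, e, c, h) per the docstring
GradScaler.apply(x, 2.0).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))
```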


class EPMLPExperts(BaseMLPExperts):
20 changes: 15 additions & 5 deletions colossalai/moe/layers.py
@@ -88,7 +88,16 @@ def __init__(self,
self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))
nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))

def forward(self, inputs: torch.Tensor) -> Tuple:
def forward(self,
inputs: torch.Tensor) \
-> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size)

Returns:
torch.Tensor: The output tensor of shape (batch_size, seq_len, hidden_size)
"""
# reshape the input tokens
tokens = inputs.reshape(-1, self.hidden_size)

@@ -100,24 +109,25 @@ def forward(self, inputs: torch.Tensor) -> Tuple:
# the result from the router
route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)

# dispatch_data: (num_experts, capacity, hidden_size)
if self.use_kernel:
dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size)
else:
sec_mask_f = route_result_list[1].type_as(inputs)
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

# dispatch_data [e, c, h]
# expert_output: (num_groups, num_experts, capacity, hidden_size)
if self.expert_parallel == "EP":
expert_output = self._ep_process(dispatch_data)
elif self.expert_parallel == "TP":
expert_output = self._tp_process(dispatch_data)
elif self.expert_parallel is None:
expert_output = self._local_process(dispatch_data)
else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
"build function.")
# expert_output [e, c, h]
raise NotImplementedError("This kind of communication has not been implemented yet.\n"
"Please use Experts build function.")

if self.use_kernel:
expert_output = expert_output.reshape(-1, self.hidden_size)
ans = MoeCombine.apply(expert_output, *route_result_list)
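The non-kernel dispatch path above packs tokens into per-expert buffers with one matmul against the permuted sec_mask. A self-contained replay of that trick with toy shapes (the routing pattern is invented for illustration):

```python
import torch

tokens = torch.randn(5, 8)        # (tokens, hidden_size)
sec_mask = torch.zeros(5, 2, 3)   # (tokens, num_experts, capacity), one-hot
for t in range(5):
    sec_mask[t, t % 2, t // 2] = 1.0   # toy routing: alternate the two experts

# (num_experts, capacity, tokens) @ (tokens, hidden) gathers every token into
# its reserved (expert, slot) row -> dispatch_data of shape (e, c, h).
dispatch_data = torch.matmul(sec_mask.permute(1, 2, 0), tokens)
assert dispatch_data.shape == (2, 3, 8)
assert torch.equal(dispatch_data[1, 0], tokens[1])   # token 1 -> expert 1, slot 0
```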
185 changes: 161 additions & 24 deletions colossalai/moe/routers.py
@@ -1,6 +1,6 @@
import math
from abc import ABC
from typing import Callable, Optional
from typing import Callable, Optional, Tuple

import torch
import torch.distributed as dist
@@ -29,7 +29,7 @@ def __init__(self,
capacity_factor_train: float,
capacity_factor_eval: float,
min_capacity: int,
noisy_func: Callable = None,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__()
self.k_value = k_value
@@ -72,9 +72,10 @@ def pop_router_loss(self) -> torch.Tensor:


class Top1Router(MoeRouter):
"""Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More detailed function can be found in the paper about Switch Transformer
of Google.
"""Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed
function can be found in the paper about Switch Transformer of Google.

Args:
capacity_factor_train (float, optional): Capacity factor in routing of training.
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
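The per-expert capacity these factors control is computed outside the lines shown here. As a hedged sketch, the standard Switch-Transformer-style formula (an assumption about this file, since the actual helper sits in a collapsed region):

```python
import math

def capacity(num_tokens: int, num_experts: int, k: int,
             capacity_factor: float, min_capacity: int = 4) -> int:
    # Each token claims k expert slots; the factor adds headroom so mildly
    # unbalanced routing does not immediately drop tokens.
    c = math.ceil(k * capacity_factor * num_tokens / num_experts)
    return max(c, min_capacity)

# e.g. top-1 over 1024 tokens and 8 experts at factor 1.25 -> 160 slots/expert
assert capacity(1024, 8, 1, 1.25) == 160
```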
@@ -89,7 +90,7 @@ def __init__(self,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Callable = None,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(k_value=1,
capacity_factor_train=capacity_factor_train,
@@ -100,12 +101,27 @@ def __init__(self,
self.select_policy = select_policy
assert select_policy in {"first", "random"}
if select_policy == "random":
self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
high=torch.tensor(1.0,
device=get_current_device())).rsample

def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):

self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(0.0, device=get_current_device()),
high=torch.tensor(1.0, device=get_current_device())
).rsample

def forward(self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None
) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).

Returns:
1. use_kernel is False:
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
2. use_kernel is True:
...
"""
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)

@@ -154,8 +170,10 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti


class Top2Router(MoeRouter):
"""Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More detailed function can be found in the paper about ViT-MoE.
"""Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed
function can be found in the paper about ViT-MoE.

Args:
capacity_factor_train (float, optional): Capacity factor in routing of training.
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
@@ -168,7 +186,7 @@ def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Callable = None,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(k_value=2,
capacity_factor_train=capacity_factor_train,
@@ -177,8 +195,22 @@ def __init__(self,
noisy_func=noisy_func,
drop_tks=drop_tks)

def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
# inputs: [s, h]
def forward(self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None
) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).

Returns:
1. use_kernel is False:
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
2. use_kernel is True:
...
"""
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)

@@ -238,11 +270,116 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti
return cb_weight, sec_mask
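To close the loop on the returned pair: cb_weight folds the expert outputs back into token order as a weighted sum over (expert, capacity) slots. A sketch with the docstring's shapes (the einsum spelling is mine, not necessarily what MoeCombine does internally):

```python
import torch

s, e, c, h = 5, 2, 3, 8                  # tokens, experts, capacity, hidden
cb_weight = torch.rand(s, e, c)          # combine weights from the router
expert_output = torch.randn(e, c, h)     # per-expert FFN results

# Each token sums the slot outputs it was dispatched to, weighted by the
# router probabilities -- the inverse of the dispatch matmul above.
combined = torch.einsum('sec,ech->sh', cb_weight, expert_output)
assert combined.shape == (s, h)
```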


def get_router_cls(top_k: int) -> MoeRouter:
if top_k == 1:
router_cls = Top1Router
elif top_k == 2:
router_cls = Top2Router
class TopKRouter(MoeRouter):
"""Masked matmul router using tokens choose top-k experts assignment.

NOTE: this is modified from flaxformer.
This router uses the same mechanism as in Switch Transformer
(https://arxiv.org/abs/2101.03961) and V-MoE
(https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are
sorted by router_probs and then routed to their choice of expert until the
expert's expert_capacity is reached. There is no guarantee that each token is
processed by an expert, or that each expert receives at least one token.

Attributes:
num_selected_experts: Maximum number of experts to which each token is
routed. Tokens may be routed to fewer experts if particular experts are
oversubscribed / reach capacity.
"""

def __init__(self,
num_selected_experts: int,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(num_selected_experts,
capacity_factor_train,
capacity_factor_eval,
min_capacity,
noisy_func,
drop_tks)

def forward(self,
router_probs: torch.Tensor,
expert_capacity: int,
) -> Tuple:
"""Computes masks for the top-k experts per token.

Args:
router_probs: <float32>[num_groups, tokens_per_group, num_experts]
probabilities used to determine the routing of tokens to the experts.

Returns:
Dispatch and combine arrays for routing with masked matmuls.
"""
num_groups, _, num_experts = router_probs.shape

# Top-k router probability and corresponding expert indices for each token.
# Shape: [num_groups, tokens_per_group, num_selected_experts].
expert_gate, expert_index = torch.topk(router_probs, self.k_value)

# TODO
# auxiliary_loss = _load_balancing_loss(router_probs, expert_index)

# Make num_selected_experts the leading axis to ensure that top-1 choices
# have priority over top-2 choices, which have priority over top-3 choices,
# etc.
expert_index = torch.transpose(expert_index, 1, 2)
# Shape: [num_groups, num_selected_experts * tokens_per_group]
expert_index = expert_index.reshape(num_groups, -1)

# Create mask out of indices.
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)

# Experts have a fixed capacity that we cannot exceed. A token's priority
# within the expert's buffer is given by the masked, cumulative capacity of
# its target expert.
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
# Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts].
token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts))
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
token_priority = torch.transpose(token_priority, 1, 2)
# For each token, across all selected experts, select the only non-negative
# (unmasked) priority. Now, for group G routing to expert E, token T has
# non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E
# is its targeted expert.
# Shape: [num_groups, tokens_per_group, num_experts].
token_priority = torch.max(token_priority, dim=2)[0]

# Token T can only be routed to expert E if its priority is positive and
# less than the expert capacity. One-hot matrix will ignore indices outside
# the range [0, expert_capacity).
# Shape: [num_groups, tokens_per_group, num_experts, expert_capacity].
valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity)
dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)

# The combine array will be used for combining expert outputs, scaled by the
# router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
# expert_capacity].
combine_array = torch.einsum(
'...te,...tec->...tec',
router_probs,
dispatch_mask)

return combine_array, dispatch_mask
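The cumulative-sum priority trick above is easiest to verify on paper-sized tensors; a standalone replay, top-1 for brevity:

```python
import torch
import torch.nn.functional as F

# One group, five tokens, two experts; expert_index as topk would produce it.
expert_index = torch.tensor([[0, 1, 0, 0, 1]])
expert_mask = F.one_hot(expert_index, num_classes=2)        # (1, 5, 2)
token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1

# Expert 0 receives tokens 0, 2, 3 at slots 0, 1, 2; expert 1 gets 1, 4.
assert token_priority[0, :, 0].tolist() == [0, -1, 1, 2, -1]
assert token_priority[0, :, 1].tolist() == [-1, 0, -1, -1, 1]
# With expert_capacity = 2, token 3 (priority 2) would fail the
# token_priority < expert_capacity check above and be dropped.
```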


def get_router_cls(top_k: int,
grouped: bool = False
) -> MoeRouter:
if not grouped:
if top_k == 1:
return Top1Router
elif top_k == 2:
return Top2Router
else:
raise NotImplementedError("top_k > 2 is not supported yet")
else:
raise NotImplementedError("top_k > 2 is not supported yet")
return router_cls
return TopKRouter
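Presumable call-site usage of the extended dispatch function (constructor arguments taken from the defaults above):

```python
# Ungrouped keeps the existing token-level routers; grouped=True opts into
# the new flaxformer-style router regardless of top_k.
router_cls = get_router_cls(top_k=2)                  # Top2Router
grouped_cls = get_router_cls(top_k=2, grouped=True)   # TopKRouter
router = grouped_cls(num_selected_experts=2)
```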