diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py
index 492cdaf13d1d..1614987538c1 100644
--- a/colossalai/moe/__init__.py
+++ b/colossalai/moe/__init__.py
@@ -1,10 +1,12 @@
 from .checkpoint import MoeCheckpintIO
 from .experts import EPMLPExperts, TPMLPExperts, build_ffn_experts
 from .layers import SparseMLP
-from .routers import MoeRouter, Top1Router, Top2Router
+from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
 from .utils import NormalNoiseGenerator, UniformNoiseGenerator

 __all__ = [
-    'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'NormalNoiseGenerator', 'UniformNoiseGenerator',
-    'SparseMLP', 'MoeRouter', 'MoeCheckpintIO', 'build_ffn_experts'
+    'EPMLPExperts', 'TPMLPExperts', 'build_ffn_experts',
+    'MoeRouter', 'Top1Router', 'Top2Router', 'TopKRouter',
+    'NormalNoiseGenerator', 'UniformNoiseGenerator',
+    'SparseMLP', 'MoeCheckpintIO'
 ]
diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py
index da4fe58977e8..9715f4dc37b3 100644
--- a/colossalai/moe/experts.py
+++ b/colossalai/moe/experts.py
@@ -1,9 +1,9 @@
 import math
 from contextlib import nullcontext
+from typing import Callable, Optional

 import torch
 import torch.nn as nn
-
 from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import get_activation
@@ -14,6 +14,15 @@
 class BaseMLPExperts(nn.Module):
     """
     SparseMLP is a multi-layer perceptron with sparse expert parallel layers.
+
+    Args:
+        num_experts (int): The number of experts.
+        hidden_size (int): The hidden size of the MLP.
+        intermediate_size (int): The intermediate size of the MLP
+            (forward: hidden_size --> intermediate_size --> hidden_size).
+        expert_parallel (str, optional): The parallelism mode of the experts, either 'EP' or 'TP'.
+        activation (Callable, optional): The activation function of the MLP.
+        drop_rate (float, optional): The dropout rate of the MLP.
     """

     def __init__(
@@ -21,8 +30,8 @@ def __init__(
         num_experts: int,
         hidden_size: int,
         intermediate_size: int,
-        expert_parallel: str = None,
-        activation: str = None,
+        expert_parallel: Optional[str] = None,
+        activation: Optional[Callable] = None,
         drop_rate: float = 0,
         gated: bool = False,
     ):
@@ -76,7 +85,14 @@ def __init__(
         for param in self.parameters():
             set_moe_tensor_info(param, self.moe_info)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:    # inputs [g, e, c, h]
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size).
+
+        Returns:
+            torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size).
+        """
         x = MoeInGradScaler.apply(x, self.ep_size)
         e = x.size(1)
@@ -97,7 +113,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:    # inputs [g, e, c, h]
         x = x.reshape(inshape)
         x = x.transpose(0, 1).contiguous()
         x = MoeOutGradScaler.apply(x, self.ep_size)
-        return x    # outputs [g, e, c, h]
+        return x


 class EPMLPExperts(BaseMLPExperts):
diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py
index ace81b543273..a3f68cf7a6f1 100644
--- a/colossalai/moe/layers.py
+++ b/colossalai/moe/layers.py
@@ -88,7 +88,16 @@ def __init__(self,
         self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))
         nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))

-    def forward(self, inputs: torch.Tensor) -> Tuple:
+    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size).
+
+        Returns:
+            torch.Tensor: The output tensor of shape (batch_size, seq_len, hidden_size).
+        """
         # reshape the input tokens
         tokens = inputs.reshape(-1, self.hidden_size)
@@ -100,6 +109,7 @@ def forward(self, inputs: torch.Tensor) -> Tuple:
         # the result from the router
         route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)

+        # dispatch_data: (num_experts, capacity, hidden_size)
         if self.use_kernel:
             dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
             dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size)
@@ -107,7 +117,7 @@ def forward(self, inputs: torch.Tensor) -> Tuple:
             sec_mask_f = route_result_list[1].type_as(inputs)
             dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

-        # dispatch_data [e, c, h]
+        # expert_output: (num_groups, num_experts, capacity, hidden_size)
         if self.expert_parallel == "EP":
             expert_output = self._ep_process(dispatch_data)
         elif self.expert_parallel == "TP":
@@ -115,9 +125,9 @@ def forward(self, inputs: torch.Tensor) -> Tuple:
         elif self.expert_parallel is None:
             expert_output = self._local_process(dispatch_data)
         else:
-            raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
-                                      "build function.")
-        # expert_output [e, c, h]
+            raise NotImplementedError("This kind of communication has not been implemented yet.\n"
+                                      "Please use the Experts build function.")
+
         if self.use_kernel:
             expert_output = expert_output.reshape(-1, self.hidden_size)
             ans = MoeCombine.apply(expert_output, *route_result_list)
diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py
index dd9243421667..688471530758 100644
--- a/colossalai/moe/routers.py
+++ b/colossalai/moe/routers.py
@@ -1,6 +1,6 @@
 import math
 from abc import ABC
-from typing import Callable, Optional
+from typing import Callable, Optional, Tuple, Type

 import torch
 import torch.distributed as dist
@@ -29,7 +29,7 @@ def __init__(self,
                  capacity_factor_train: float,
                  capacity_factor_eval: float,
                  min_capacity: int,
-                 noisy_func: Callable = None,
+                 noisy_func: Optional[Callable] = None,
                  drop_tks: bool = True):
         super().__init__()
         self.k_value = k_value
@@ -72,9 +72,10 @@ def pop_router_loss(self) -> torch.Tensor:


 class Top1Router(MoeRouter):
-    """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
-    for routing usage. More detailed function can be found in the paper about Switch Transformer
-    of Google.
+    """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
+    and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More
+    details can be found in Google's Switch Transformer paper.
+
     Args:
         capacity_factor_train (float, optional): Capacity factor in routing of training.
         capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
@@ -89,7 +90,7 @@ def __init__(self,
                  capacity_factor_eval: float = 2.0,
                  min_capacity: int = 4,
                  select_policy: str = "first",
-                 noisy_func: Callable = None,
+                 noisy_func: Optional[Callable] = None,
                  drop_tks: bool = True):
         super().__init__(k_value=1,
                          capacity_factor_train=capacity_factor_train,
@@ -100,12 +101,27 @@ def __init__(self,
         self.select_policy = select_policy
         assert select_policy in {"first", "random"}
         if select_policy == "random":
-            self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
-                                                               high=torch.tensor(1.0,
-                                                                                 device=get_current_device())).rsample
-
-    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
-
+            self.uniform = torch.distributions.uniform.Uniform(
+                low=torch.tensor(0.0, device=get_current_device()),
+                high=torch.tensor(1.0, device=get_current_device())
+            ).rsample
+
+    def forward(self,
+                inputs: torch.Tensor,
+                use_kernel: bool = False,
+                ep_group: Optional[ProcessGroup] = None
+                ) -> Tuple:
+        """
+        Args:
+            inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
+
+        Returns:
+            1. use_kernel is False:
+                The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
+                The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
+            2. use_kernel is True:
+                ...
+        """
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)
@@ -154,8 +170,10 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti


 class Top2Router(MoeRouter):
-    """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
-    for routing usage. More detailed function can be found in the paper about ViT-MoE.
+    """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
+    and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More
+    details can be found in the ViT-MoE paper.
+
     Args:
         capacity_factor_train (float, optional): Capacity factor in routing of training.
         capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
@@ -168,7 +186,7 @@ def __init__(self,
                  capacity_factor_train: float = 1.25,
                  capacity_factor_eval: float = 2.0,
                  min_capacity: int = 4,
-                 noisy_func: Callable = None,
+                 noisy_func: Optional[Callable] = None,
                  drop_tks: bool = True):
         super().__init__(k_value=2,
                          capacity_factor_train=capacity_factor_train,
@@ -177,8 +195,22 @@ def __init__(self,
                          noisy_func=noisy_func,
                          drop_tks=drop_tks)

-    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
-        # inputs: [s, h]
+    def forward(self,
+                inputs: torch.Tensor,
+                use_kernel: bool = False,
+                ep_group: Optional[ProcessGroup] = None
+                ) -> Tuple:
+        """
+        Args:
+            inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
+
+        Returns:
+            1. use_kernel is False:
+                The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
+                The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
+            2. use_kernel is True:
+                ...
+        """
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)
@@ -238,11 +270,116 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti
         return cb_weight, sec_mask


-def get_router_cls(top_k: int) -> MoeRouter:
-    if top_k == 1:
-        router_cls = Top1Router
-    elif top_k == 2:
-        router_cls = Top2Router
+class TopKRouter(MoeRouter):
+    """Masked matmul router in which tokens choose their top-k experts.
+
+    NOTE: this is modified from flaxformer.
+    This router uses the same mechanism as in Switch Transformer
+    (https://arxiv.org/abs/2101.03961) and V-MoE
+    (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are
+    sorted by router_probs and then routed to their choice of expert until the
+    expert's expert_capacity is reached. There is no guarantee that each token is
+    processed by an expert, or that each expert receives at least one token.
+
+    Attributes:
+        num_selected_experts: Maximum number of experts to which each token is
+            routed. Tokens may be routed to fewer experts if particular experts are
+            oversubscribed / reach capacity.
+    """
+
+    def __init__(self,
+                 num_selected_experts: int,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 noisy_func: Optional[Callable] = None,
+                 drop_tks: bool = True):
+        super().__init__(num_selected_experts,
+                         capacity_factor_train,
+                         capacity_factor_eval,
+                         min_capacity,
+                         noisy_func,
+                         drop_tks)
+
+    def forward(self,
+                router_probs: torch.Tensor,
+                expert_capacity: int,
+                ) -> Tuple:
+        """Computes masks for the top-k experts per token.
+
+        Args:
+            router_probs (torch.Tensor): Tensor of shape
+                (num_groups, tokens_per_group, num_experts) with the probabilities
+                used to determine the routing of tokens to the experts.
+            expert_capacity (int): The maximum number of tokens each expert may process.
+
+        Returns:
+            The combine array and dispatch mask for routing with masked matmuls.
+        """
+        num_groups, _, num_experts = router_probs.shape
+
+        # Top-k router probability and corresponding expert indices for each token.
+        # Shape: [num_groups, tokens_per_group, num_selected_experts].
+        expert_gate, expert_index = torch.topk(router_probs, self.k_value)
+
+        # TODO
+        # auxiliary_loss = _load_balancing_loss(router_probs, expert_index)
+
+        # Make num_selected_experts the leading axis to ensure that top-1 choices
+        # have priority over top-2 choices, which have priority over top-3 choices,
+        # etc.
+        expert_index = torch.transpose(expert_index, 1, 2)
+        # Shape: [num_groups, num_selected_experts * tokens_per_group]
+        expert_index = expert_index.reshape(num_groups, -1)
+
+        # Create mask out of indices.
+        # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
+        expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)
+
+        # Experts have a fixed capacity that we cannot exceed. A token's priority
+        # within the expert's buffer is given by the masked, cumulative capacity of
+        # its target expert.
+        # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
+        token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
+        # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts].
+        token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts))
+        # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
+        token_priority = torch.transpose(token_priority, 1, 2)
+        # For each token, across all selected experts, select the only non-negative
+        # (unmasked) priority. Now, for group G routing to expert E, token T has
+        # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E
+        # is its targeted expert.
+        # Shape: [num_groups, tokens_per_group, num_experts].
+        token_priority = torch.max(token_priority, dim=2)[0]
+
+        # Token T can only be routed to expert E if its priority is non-negative and
+        # less than the expert capacity. The one-hot matrix ignores indices outside
+        # the range [0, expert_capacity).
+        # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity].
+        valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
+        token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
+        # F.one_hot expects an int64 index tensor.
+        dispatch_mask = F.one_hot(token_priority.to(torch.int64), expert_capacity).to(torch.bool)
+        valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity)
+        dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
+
+        # The combine array will be used for combining expert outputs, scaled by the
+        # router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
+        # expert_capacity].
+        combine_array = torch.einsum(
+            '...te,...tec->...tec',
+            router_probs,
+            dispatch_mask)
+
+        return combine_array, dispatch_mask
+
+
+def get_router_cls(top_k: int,
+                   grouped: bool = False
+                   ) -> Type[MoeRouter]:
+    if not grouped:
+        if top_k == 1:
+            return Top1Router
+        elif top_k == 2:
+            return Top2Router
+        else:
+            raise NotImplementedError("top_k > 2 is not supported yet")
     else:
-        raise NotImplementedError("top_k > 2 is not supported yet")
-    return router_cls
+        return TopKRouter
diff --git a/tests/test_moe/test_moe_router.py b/tests/test_moe/test_moe_router.py
new file mode 100644
index 000000000000..94c263baa5a3
--- /dev/null
+++ b/tests/test_moe/test_moe_router.py
@@ -0,0 +1,48 @@
+import pytest
+import torch
+
+from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter
+
+
+@pytest.mark.parametrize(["router", "num_groups"], [
+    (Top1Router(), 1),
+    (Top2Router(), 1),
+    (TopKRouter(num_selected_experts=3), 4),
+])
+@pytest.mark.parametrize(
+    ["batch_size", "seq_len", "num_experts"],
+    [
+        (4, 5, 8),
+        (3, 4, 4),
+    ]
+)
+def test_router_forward(router: MoeRouter,
+                        batch_size: int,
+                        seq_len: int,
+                        num_experts: int,
+                        num_groups: int):
+    x = torch.randn((batch_size * seq_len, num_experts))
+    if num_groups > 1:
+        x = x.expand(num_groups, -1, -1)
+
+    router.train()
+    if isinstance(router, TopKRouter):
+        combine_array, dispatch_mask = router(x, expert_capacity=2)
+    else:
+        combine_array, dispatch_mask = router(x)
+    assert combine_array.shape[:-1] == x.shape
+    assert dispatch_mask.shape[:-1] == x.shape
+    assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
+
+    router.eval()
+    if isinstance(router, TopKRouter):
+        combine_array, dispatch_mask = router(x, expert_capacity=2)
+    else:
+        combine_array, dispatch_mask = router(x)
+    assert combine_array.shape[:-1] == x.shape
+    assert dispatch_mask.shape[:-1] == x.shape
+    assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
+
+
+if __name__ == "__main__":
+    test_router_forward(Top1Router(), 4, 4, 4, 1)
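
Reviewer note: below is a minimal usage sketch, not part of this PR, of how TopKRouter's outputs are meant to be consumed with the masked matmuls its docstring describes. All sizes (num_groups=2, tokens_per_group=6, expert_capacity=2, etc.) and the identity stand-in for the experts are illustrative assumptions; only TopKRouter itself comes from this PR.

    import torch
    import torch.nn.functional as F

    from colossalai.moe.routers import TopKRouter

    # Illustrative sizes (assumptions, not from the PR).
    num_groups, tokens_per_group, num_experts, hidden_size = 2, 6, 4, 8
    expert_capacity = 2

    tokens = torch.randn(num_groups, tokens_per_group, hidden_size)
    gate_logits = torch.randn(num_groups, tokens_per_group, num_experts)
    router_probs = F.softmax(gate_logits, dim=-1)

    router = TopKRouter(num_selected_experts=2)
    combine_array, dispatch_mask = router(router_probs, expert_capacity=expert_capacity)

    # Dispatch: scatter each token into its expert's buffer slot.
    # (g, t, e, c) x (g, t, h) -> (g, e, c, h); each (e, c) slot holds at most one token.
    expert_inputs = torch.einsum('gtec,gth->gech', dispatch_mask.to(tokens.dtype), tokens)

    # Stand-in for the expert computation (identity, for illustration only).
    expert_outputs = expert_inputs

    # Combine: gather expert outputs back per token, scaled by the router probabilities.
    # (g, t, e, c) x (g, e, c, h) -> (g, t, h)
    outputs = torch.einsum('gtec,gech->gth', combine_array, expert_outputs)
    assert outputs.shape == tokens.shape

Since dropped tokens never appear in dispatch_mask, their combined output is zero; MoE layers typically rely on a residual connection so that such tokens pass through unchanged.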
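For the pre-existing Top1/Top2 path touched in layers.py, here is a small self-contained sketch (again an illustration, not PR code) of the non-kernel dispatch: the routers' sec_mask has shape (tokens, num_experts, capacity), so permuting it to (num_experts, capacity, tokens) turns dispatch into a single batched matmul, which is exactly what dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) computes. The round-robin assignment below is an assumed toy mask.

    import torch

    num_tokens, num_experts, capacity, hidden_size = 12, 4, 3, 8
    tokens = torch.randn(num_tokens, hidden_size)

    # Toy dispatch mask: token t goes to expert t % num_experts, filling capacity
    # slots round-robin (a real mask comes from Top1Router / Top2Router).
    sec_mask = torch.zeros(num_tokens, num_experts, capacity)
    for t in range(num_tokens):
        sec_mask[t, t % num_experts, (t // num_experts) % capacity] = 1.0

    # (e, c, s) @ (s, h) -> (e, c, h): the 2D `tokens` matrix broadcasts over the
    # expert dimension, so every expert's buffer is gathered in one matmul.
    dispatch_data = torch.matmul(sec_mask.permute(1, 2, 0), tokens)
    assert dispatch_data.shape == (num_experts, capacity, hidden_size)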