diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py index b41f4072a405..5622bd271735 100644 --- a/colossalai/context/moe_context.py +++ b/colossalai/context/moe_context.py @@ -3,29 +3,9 @@ import torch import torch.distributed as dist -from colossalai.context.parallel_mode import ParallelMode from colossalai.context.singleton_meta import SingletonMeta -from colossalai.tensor import ProcessGroup - - -def _check_sanity(): - from colossalai.core import global_context as gpc - if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1: - raise NotImplementedError("Moe is not compatible with tensor or " - "pipeline parallel at present.") - - -class MoeParallelInfo: - """Moe parallelism information, storing parallel sizes and groups. - """ - - def __init__(self, ep_size: int, dp_size: int): - _check_sanity() - self.ep_size = ep_size - self.dp_size = dp_size - self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size) - self.ep_group = self.pg.tp_process_group() - self.dp_group = self.pg.dp_process_group() +from colossalai.tensor.moe_tensor.api import get_moe_info +from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo class MoeContext(metaclass=SingletonMeta): @@ -34,12 +14,12 @@ class MoeContext(metaclass=SingletonMeta): """ def __init__(self): - self.world_size = 1 + self.world_size = None # Users may want to set maximum expert parallel size smaller than the world size # since very low bandwidth across nodes may constrain the performance of MoE # When we have a maximum expert parallel size, we have a minimum data parallel size naturally - self.max_ep_size = 1 - self.min_dp_size = 1 + self.max_ep_size = None + self.min_dp_size = None self.aux_loss = None self.use_kernel_optim = True @@ -54,17 +34,12 @@ def parallel_info_dict(self): def is_initialized(self): return self.has_setup - def setup(self, seed: int, use_kernel_optim: bool = True): + def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8): assert not self.is_initialized, "MoE distributed context shouldn't be set up again" - _check_sanity() assert torch.cuda.is_available(), "MoE requires to enable CUDA first" self.world_size = dist.get_world_size() - - from colossalai.core import global_context as gpc - self.max_ep_size = gpc.config.get('max_ep_size', self.world_size) - assert self.world_size % self.max_ep_size == 0, \ - "Maximum expert parallel size must be a factor of the number of GPUs" + self.max_ep_size = min(max_ep_size, dist.get_world_size()) self.min_dp_size = self.world_size // self.max_ep_size # Enabling kernel optimization may raise error in some cases @@ -75,7 +50,7 @@ def setup(self, seed: int, use_kernel_optim: bool = True): moe_set_seed(seed) self.has_setup = True - def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]: + def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]: """Calculate the Data Parallel Group and Expert Parallel Group. 
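The new `setup()` caps the expert parallel size with the `max_ep_size` argument instead of reading it from the global config. Below is a minimal sketch of the resulting size bookkeeping in plain Python, with no distributed runtime; the helper name `plan_moe_sizes` and the divisibility assumptions are illustrative only, not part of this patch.

```python
def plan_moe_sizes(world_size: int, max_ep_size: int, num_experts: int):
    # mirrors setup(): the effective expert parallel size is capped by the world size
    max_ep_size = min(max_ep_size, world_size)
    min_dp_size = world_size // max_ep_size
    # fewer experts than max_ep_size shrinks ep_size (the lt_flag branch of get_info)
    ep_size = min(num_experts, max_ep_size)
    dp_size = (max_ep_size // ep_size) * min_dp_size
    # the invariant MoeParallelInfo relies on
    assert ep_size * dp_size == world_size
    return ep_size, dp_size

print(plan_moe_sizes(world_size=8, max_ep_size=4, num_experts=2))     # (2, 4)
print(plan_moe_sizes(world_size=8, max_ep_size=4, num_experts=16))    # (4, 2)
```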
Parameters @@ -104,12 +79,15 @@ def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]: ep_size = self.max_ep_size // dp_size # Calculate the number of experts for each GPU - num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size + if use_tp: + num_local_experts = num_experts + else: + num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size # Don't forget to multiply minimum data parallel size dp_size *= self.min_dp_size if not (ep_size in self.parallel_info_dict): - self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size) + self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size) return num_local_experts, self.parallel_info_dict[ep_size] diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py index 05333fe965f1..ffeeac796441 100644 --- a/colossalai/nn/layer/moe/__init__.py +++ b/colossalai/nn/layer/moe/__init__.py @@ -1,10 +1,10 @@ -from .checkpoint import load_moe_model, save_moe_model -from .experts import Experts, FFNExperts, TPExperts -from .layers import MoeLayer, MoeModule +from .checkpoint import MoeCheckpintIO +from .experts import EPMLPExperts, TPMLPExperts +from .layers import MoeLayer, MoeModule, SparseMLP from .routers import MoeRouter, Top1Router, Top2Router from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts __all__ = [ - 'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator', - 'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter', 'save_moe_model', 'load_moe_model' + 'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'MoeModule', 'NormalNoiseGenerator', + 'UniformNoiseGenerator', 'build_ffn_experts', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO' ] diff --git a/colossalai/nn/layer/moe/checkpoint.py b/colossalai/nn/layer/moe/checkpoint.py index efda1f22252d..34af87bd9d47 100644 --- a/colossalai/nn/layer/moe/checkpoint.py +++ b/colossalai/nn/layer/moe/checkpoint.py @@ -1,40 +1,61 @@ +from pathlib import Path +from typing import Optional + import torch import torch.distributed as dist import torch.nn as nn +from torch.optim import Optimizer + +from colossalai.checkpoint_io import CheckpointIO +from colossalai.tensor.moe_tensor.api import get_ep_group + + +class MoeCheckpintIO(CheckpointIO): + + def __init__(self) -> None: + super().__init__() + + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): + state_dict = torch.load(checkpoint) + for name, param in model.named_parameters(): + if '.experts.' 
in name: + ep_rank = dist.get_rank(get_ep_group(param)) + ep_size = dist.get_world_size(get_ep_group(param)) + for rank in range(ep_size): + new_name = name.replace('.experts.', f'.experts.{rank}.') + if rank == ep_rank: + state_dict[name] = state_dict.pop(new_name) + else: + state_dict.pop(new_name) -from .experts import MoeExperts + model.load_state_dict(state_dict, strict=strict) + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + state_dict = model.state_dict() + if dist.get_rank() == 0: + torch.save(state_dict, checkpoint) + dist.barrier() -def save_moe_model(model: nn.Module, save_path: str): - state_dict = model.state_dict() - if dist.get_rank() == 0: - torch.save(state_dict, save_path) - dist.barrier() + def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool): + raise NotImplementedError() + def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str], + size_per_shard: int, use_safetensors: bool): + raise NotImplementedError() -def load_moe_model(model: nn.Module, load_path: str): - state_dict = torch.load(load_path) + # ======================================================== + # Abstract methods for optimizer loading/saving implementation + # ======================================================== - for prefix, module in model.named_modules(): - if prefix.endswith('.moe_layer.experts'): - # this module should be an Experts instance - assert isinstance(module, MoeExperts) + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): + raise NotImplementedError() - ep_rank = dist.get_rank(module.dist_info.ep_group) - num_local = module.num_local_experts - for i in range(num_local): - expert_id = ep_rank * num_local + i - for name, _ in module.experts[i].named_parameters(): - cur_key = f'{prefix}.experts.{i}.{name}' - param_key = f'{prefix}.experts.{expert_id}.{name}' - load_param = state_dict[param_key] - state_dict[cur_key] = load_param + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + raise NotImplementedError() - for name, _ in module.experts[0].named_parameters(): - pop_pre = f'{prefix}.experts.' - pop_suf = f'.{name}' - for i in range(num_local, module.num_total_experts): - pop_key = f'{pop_pre}{i}{pop_suf}' - state_dict.pop(pop_key) + def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, + size_per_shard: int): + raise NotImplementedError() - model.load_state_dict(state_dict) + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool): + raise NotImplementedError() diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 25de4364cb39..0ed2f1fd2513 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -1,6 +1,5 @@ import math from copy import deepcopy -from typing import Type import torch import torch.distributed as dist @@ -8,197 +7,133 @@ from colossalai.context import ParallelMode, seed from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.tensor.moe_tensor.api import set_moe_param_info -from colossalai.utils import get_current_device -from colossalai.zero.legacy.init_ctx import no_shard_zero_decrator +from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size, set_moe_tensor_info -class MoeExperts(nn.Module): - """Basic class for experts in MoE. 
It stores what kind of communication experts use - to exchange tokens, how many experts in a single GPU and parallel information such as - expert parallel size, data parallel size and their distributed communication groups. +class BaseMLPExperts(nn.Module): """ - - def __init__(self, comm_name: str, num_experts: int): - super().__init__() - assert comm_name in {"all_to_all", "all_gather"}, \ - "This kind of communication has not been implemented yet.\n Please use Experts build function." - self.comm_name = comm_name - self.num_total_experts = num_experts - # Get the configuration of experts' deployment and parallel information from moe context - self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts) - - -@no_shard_zero_decrator(is_replicated=False) -class Experts(MoeExperts): - """A wrapper class to create experts. It will create E experts across the - moe model parallel group, where E is the number of experts. Every expert - is a instance of the class, 'expert' in initialization parameters. - - Args: - expert_cls (:class:`torch.nn.Module`): The class of all experts - num_experts (int): The number of experts - expert_args: Args used to initialize experts, the args could be found in corresponding expert class + SparseMLP is a multi-layer perceptron with sparse expert parallel layers. """ - def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args): - super().__init__("all_to_all", num_experts) - - # Use seed to make every expert different from others - with seed(ParallelMode.TENSOR): - self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)]) - - # Attach parallel information for all parameters in Experts - for exp in self.experts: - for param in exp.parameters(): - set_moe_param_info(param, self.dist_info) - - def forward(self, inputs: torch.Tensor): - # Split inputs for each expert - expert_input = torch.chunk(inputs, self.num_local_experts, dim=1) - expert_output = [] - - # Get outputs from each expert - for i in range(self.num_local_experts): - expert_output.append(self.experts[i](expert_input[i])) - - # Concatenate all outputs together - output = torch.cat(expert_output, dim=1).contiguous() - return output - - def state_dict(self, destination=None, prefix='', keep_vars=False): - assert keep_vars == False, "Only support keep_vars=False now" - dp_rank = dist.get_rank(self.dist_info.dp_group) - ep_rank = dist.get_rank(self.dist_info.ep_group) - submodule_dict = dict() - example_submodule = None - for name, subm in self.experts.named_modules(): - if subm is self.experts: - continue - module_number = self.num_local_experts * ep_rank + int(name) - submodule_dict[module_number] = subm - example_submodule = subm - - if dp_rank == 0: - local_prefix = prefix + 'experts.' - buffer_module = deepcopy(example_submodule) - for i in range(self.num_total_experts): - source_rank = i // self.num_local_experts - current_prefix = local_prefix + str(i) + '.' - comm_module = submodule_dict.get(i, buffer_module) - for name, param in comm_module.named_parameters(): - dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group) - if ep_rank == 0: - destination[current_prefix + name] = param.data.cpu() - - dist.barrier() - - -class FFNExperts(MoeExperts): - """Use torch.bmm to speed up for multiple experts. 
- """ - - def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - super().__init__("all_to_all", num_experts) + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + expert_parallel: str, + activation: str = None, + drop_rate: float = 0, + ): + super().__init__() + assert expert_parallel in ["EP", "TP"] + self.expert_parallel = expert_parallel - self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device())) - self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device())) + # get local and total experts + self.num_total_experts = num_experts + self.num_local_experts, self.moe_info = MOE_CONTEXT.get_info(num_experts, + use_tp=True if expert_parallel == "TP" else False) - self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device())) - self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device())) + # get settings for different parallel + if expert_parallel == "TP": + assert intermediate_size % MOE_CONTEXT.max_ep_size == 0, \ + "intermediate_size should be divide by maximum expert parallel size" + intermediate_size = intermediate_size // MOE_CONTEXT.max_ep_size + num_experts = self.num_total_experts + else: + num_experts = self.num_local_experts - s1 = math.sqrt(0.1 / d_model) - s2 = math.sqrt(0.1 / d_ff) + self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) + self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size)) with seed(ParallelMode.TENSOR): - nn.init.trunc_normal_(self.w1, std=s1) - nn.init.trunc_normal_(self.b1, std=s1) - nn.init.trunc_normal_(self.w2, std=s2) - nn.init.trunc_normal_(self.b2, std=s2) + nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size)) + nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size)) self.act = nn.GELU() if activation is None else activation self.drop = nn.Dropout(p=drop_rate) for param in self.parameters(): - param.__setattr__('moe_info', self.dist_info) - - def forward(self, inputs): # inputs [g, el, c, h] + set_moe_tensor_info(param, self.moe_info) - el = inputs.size(1) - h = inputs.size(-1) + def forward(self, x: torch.Tensor) -> torch.Tensor: # inputs [g, e, c, h] - inputs = inputs.transpose(0, 1) - inshape = inputs.shape - inputs = inputs.reshape(el, -1, h) + e = x.size(1) + h = x.size(-1) - out_ff = torch.baddbmm(self.b1, inputs, self.w1) - out_act = self.act(out_ff) - with seed(ParallelMode.TENSOR): - out_inter = self.drop(out_act) + x = x.transpose(0, 1) + inshape = x.shape + x = x.reshape(e, -1, h) - out_model = torch.baddbmm(self.b2, out_inter, self.w2) + x = torch.bmm(x, self.wi) + x = self.act(x) with seed(ParallelMode.TENSOR): - outputs = self.drop(out_model) # outputs [el, gc, h] + x = self.drop(x) + x = torch.bmm(x, self.wo) - outputs = outputs.reshape(inshape) - outputs = outputs.transpose(0, 1).contiguous() - return outputs + x = x.reshape(inshape) + x = x.transpose(0, 1).contiguous() + return x # outputs [g, e, c, h] -class TPExperts(MoeExperts): - """Use tensor parallelism to split each expert evenly, which can deploy experts in - case that the number of experts can't be divide by maximum expert parallel size or - maximum expert parallel size can't be divide by the number of experts. 
+class EPMLPExperts(BaseMLPExperts):
+    """
+    Use expert parallelism to deploy the experts: each rank in the expert parallel group holds a distinct subset of the experts.
+    """
-    def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-        super().__init__("all_gather", MOE_CONTEXT.max_ep_size)
-
-        assert d_ff % MOE_CONTEXT.max_ep_size == 0, \
-            "d_ff should be divide by maximum expert parallel size"
-
-        p_ff = d_ff // MOE_CONTEXT.max_ep_size
-
-        self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
-        self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))
-
-        self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device()))
-        self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device()))
-
-        s1 = math.sqrt(0.1 / d_model)
-        s2 = math.sqrt(0.1 / d_ff)
-
-        with seed(ParallelMode.TENSOR):
-            nn.init.trunc_normal_(self.w1, std=s1)
-            nn.init.trunc_normal_(self.b1, std=s1)
-            nn.init.trunc_normal_(self.w2, std=s2)
-
-            nn.init.trunc_normal_(self.b2, std=s2)
-
-        self.act = nn.GELU() if activation is None else activation
-        self.drop = nn.Dropout(p=drop_rate)
-
-        self.w1.__setattr__('moe_info', self.dist_info)
-        self.w2.__setattr__('moe_info', self.dist_info)
-        self.b1.__setattr__('moe_info', self.dist_info)
-
-    def forward(self, inputs):    # inputs [g, e, c, h]
+    def __init__(self,
+                 num_experts: int,
+                 hidden_size: int,
+                 intermediate_size: int,
+                 activation=None,
+                 drop_rate: float = 0):
+        super().__init__(num_experts, hidden_size, intermediate_size, "EP", activation, drop_rate)
 
-        e = inputs.size(1)
-        h = inputs.size(-1)
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        dp_rank = dist.get_rank(get_dp_group(self))
+        ep_rank = dist.get_rank(get_ep_group(self))
+        ep_size = get_ep_size(self)
+        # dp rank 0 will save the state dict
+        if dp_rank == 0:
+            for name, param in self.named_parameters():
+                if param is self:
+                    continue
+                # create buffer
+                buffer_module = deepcopy(param)
+                # gather param from every ep rank
+                for source_rank in range(ep_size):
+                    current_prefix = f"{prefix}{source_rank}."
+                    if ep_rank == source_rank:
+                        dist.broadcast(param.data, src=source_rank, group=self.moe_info.ep_group)
+                    else:
+                        dist.broadcast(buffer_module.data, src=source_rank, group=self.moe_info.ep_group)
+                    if ep_rank == 0:
+                        if keep_vars:
+                            destination[current_prefix + name] = buffer_module.cpu()
+                        else:
+                            destination[current_prefix + name] = buffer_module.data.cpu()
 
-        inputs = inputs.transpose(0, 1)
-        inshape = inputs.shape
-        inputs = inputs.reshape(e, -1, h)
         dist.barrier()
 
-        out_ff = torch.baddbmm(self.b1, inputs, self.w1)
-        out_act = self.act(out_ff)
-        with seed(ParallelMode.TENSOR):
-            out_inter = self.drop(out_act)
-        out_model = torch.baddbmm(self.b2, out_inter, self.w2)
-        outputs = self.drop(out_model)    # outputs [e, gc, h]
 
+class TPMLPExperts(BaseMLPExperts):
+    """Use tensor parallelism to split each expert evenly, which can deploy experts in
+    case that the number of experts can't be divided by the maximum expert parallel size or
+    the maximum expert parallel size can't be divided by the number of experts.
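The gather in `state_dict` above writes each rank's shard under an `experts.{rank}.` prefix, which is exactly the layout that `MoeCheckpintIO.load_unsharded_model` (earlier in this diff) undoes at load time. A toy sketch of that round trip with made-up key names and two ep ranks:

```python
ep_size, ep_rank = 2, 1
# what the EP state_dict gather produces on disk
ckpt = {f"moe.experts.{r}.wi": f"shard-{r}" for r in range(ep_size)}

# what load_unsharded_model does on each rank: keep only this rank's shard
name = "moe.experts.wi"                       # parameter name in the live model
for r in range(ep_size):
    sharded = name.replace(".experts.", f".experts.{r}.")
    if r == ep_rank:
        ckpt[name] = ckpt.pop(sharded)
    else:
        ckpt.pop(sharded)

print(ckpt)                                   # {'moe.experts.wi': 'shard-1'}
```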
+ """ - outputs = outputs.reshape(inshape) - outputs = outputs.transpose(0, 1).contiguous() - return outputs # outputs [g, e, c, h] + def __init__(self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + activation: str = None, + drop_rate: float = 0): + super().__init__(num_experts, hidden_size, intermediate_size, "TP", activation, drop_rate) + + +def get_expert_class(name: str) -> BaseMLPExperts: + if name == "TP": + return TPMLPExperts + elif name == "EP": + return EPMLPExperts + else: + raise ValueError(f"Unknown expert class name: {name}") diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index 03f55d91f3a8..d870781d29c4 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -1,5 +1,5 @@ import math -from typing import Optional, Tuple, Type +from typing import Optional, Tuple import torch import torch.nn as nn @@ -14,14 +14,12 @@ MoeDispatch, ReduceScatter, ) -from colossalai.nn.layer.moe.experts import Experts, MoeExperts -from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router -from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator -from colossalai.utils import get_current_device -from colossalai.zero.legacy.init_ctx import no_shard_zero_context, no_shard_zero_decrator +from colossalai.nn.layer.moe.experts import BaseMLPExperts, get_expert_class +from colossalai.nn.layer.moe.routers import MoeRouter, get_router_cls +from colossalai.nn.layer.moe.utils import get_noise_generator +from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size -@no_shard_zero_decrator(is_replicated=True) class MoeLayer(nn.Module): """A MoE layer, that puts its input tensor to its gate and uses the output logits to router all tokens, is mainly used to exchange all tokens for every expert across @@ -35,21 +33,21 @@ class MoeLayer(nn.Module): experts (MoeExperts): Instance of experts generated by Expert. 
""" - def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts): + def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: BaseMLPExperts): super().__init__() self.d_model = dim_model self.num_experts = num_experts self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model)) self.router: MoeRouter = router - self.experts: MoeExperts = experts + self.experts: BaseMLPExperts = experts self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False - self.ep_group = experts.dist_info.ep_group - self.ep_size = experts.dist_info.ep_size + self.ep_group = get_ep_group(experts) + self.ep_size = get_ep_size(experts) self.num_local_experts = experts.num_local_experts nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model)) - def a2a_process(self, dispatch_data: torch.Tensor): + def ep_process(self, dispatch_data: torch.Tensor): expert_input = AllToAll.apply(dispatch_data, self.ep_group) input_shape = expert_input.shape expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model) @@ -84,9 +82,9 @@ def forward(self, inputs: torch.Tensor) -> Tuple: dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) # dispatch_data [e, c, h] - if self.experts.comm_name == "all_to_all": - expert_output = self.a2a_process(dispatch_data) - elif self.experts.comm_name == "all_gather": + if self.experts.expert_parallel == "EP": + expert_output = self.ep_process(dispatch_data) + elif self.experts.expert_parallel == "TP": expert_output = self.tp_process(dispatch_data) else: raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " @@ -106,7 +104,7 @@ def forward(self, inputs: torch.Tensor) -> Tuple: return ans, l_aux -class MoeModule(nn.Module): +class SparseMLP(nn.Module): """A class for users to create MoE modules in their models. 
Args: @@ -136,7 +134,6 @@ class MoeModule(nn.Module): """ def __init__(self, - dim_model: int, num_experts: int, top_k: int = 1, capacity_factor_train: float = 1.25, @@ -144,67 +141,111 @@ def __init__(self, min_capacity: int = 4, noisy_policy: Optional[str] = None, drop_tks: bool = True, - use_residual: bool = False, - residual_instance: Optional[nn.Module] = None, - expert_instance: Optional[MoeExperts] = None, - expert_cls: Optional[Type[nn.Module]] = None, - **expert_args): + expert_parallel: str = "EP", + hidden_size: int = 2048, + intermediate_size: int = 2048, + activation: str = None): super().__init__() + self.hidden_size = hidden_size + self.num_experts = num_experts + self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False + assert expert_parallel in ["EP", "TP"], f"Unsupported expert parallel type {expert_parallel}" + + # moe router + noisy_func = get_noise_generator(noisy_policy, num_experts) + router_cls = get_router_cls(top_k) + self.router: MoeRouter = router_cls(capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + + # moe experts + expert_cls = get_expert_class(expert_parallel) + self.experts: BaseMLPExperts = expert_cls(num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + activation=activation) + self.ep_group = get_ep_group(self.experts) + self.ep_size = get_ep_size(self.experts) + self.num_local_experts = self.experts.num_local_experts + + self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) + nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size)) + + def forward(self, inputs: torch.Tensor) -> Tuple: + # reshape the input tokens + tokens = inputs.reshape(-1, self.hidden_size) - noisy_func = None - if noisy_policy is not None: - if noisy_policy == 'Jitter': - noisy_func = UniformNoiseGenerator() - elif noisy_policy == 'Gaussian': - noisy_func = NormalNoiseGenerator(num_experts) - else: - raise NotImplementedError("Unsupported input noisy policy") - - if top_k == 1: - moe_router_cls = Top1Router - elif top_k == 2: - moe_router_cls = Top2Router + # the data type of the inputs in the gating should be fp32 + fp32_input = tokens.to(torch.float) + fp32_weight = self.gate_weight.to(torch.float) + gate_output = F.linear(fp32_input, fp32_weight) + + # the result from the router + route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) + + if self.use_kernel: + dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:]) + dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size) else: - raise NotImplementedError("top_k > 2 is not supported yet") - - self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) - self.use_residual = use_residual - if use_residual: - if residual_instance is not None: - self.residual_module = residual_instance - else: - assert expert_cls is not None, \ - "Expert class can't be None when residual instance is not given" - self.residual_module = expert_cls(**expert_args) - - with no_shard_zero_context(): - self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device()) - - if expert_instance is not None: - my_experts = expert_instance + sec_mask_f = route_result_list[1].type_as(inputs) + 
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
+
+        # dispatch_data [e, c, h]
+        if self.experts.expert_parallel == "EP":
+            expert_output = self.ep_process(dispatch_data)
+        elif self.experts.expert_parallel == "TP":
+            expert_output = self.tp_process(dispatch_data)
+        else:
-            assert expert_cls is not None, \
-                "Expert class can't be None when experts instance is not given"
-            my_experts = Experts(expert_cls, num_experts, **expert_args)
-
-        self.moe_layer = MoeLayer(dim_model=dim_model,
-                                  num_experts=num_experts,
-                                  router=self.moe_router,
-                                  experts=my_experts)
-
-    def forward(self, inputs: torch.Tensor):
-        moe_output, l_aux = self.moe_layer(inputs)
-
-        if self.use_residual:
-            residual_output = self.residual_module(inputs)
-            combine_coef = self.residual_combine(inputs)
-            combine_coef = F.softmax(combine_coef, dim=-1)
-            output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:]
+            raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
+                                      "build function.")
+        # expert_output [e, c, h]
+        if self.use_kernel:
+            expert_output = expert_output.reshape(-1, self.hidden_size)
+            ans = MoeCombine.apply(expert_output, *route_result_list)
         else:
-            output = moe_output
+            combine_weights = route_result_list[0].type_as(inputs)
+            combine_weights = combine_weights.view(combine_weights.shape[0], -1)
+            expert_output = expert_output.view(-1, expert_output.shape[-1])
+            ans = torch.matmul(combine_weights, expert_output)
 
-        return output, l_aux
+        ans = ans.reshape(inputs.shape)
+        l_aux = self.router.pop_routing_loss()
+        return ans, l_aux
+
+    def ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
+        expert_input = AllToAll.apply(dispatch_data, self.ep_group)
+        input_shape = expert_input.shape
+        expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
+        expert_output = self.experts(expert_input)
+        expert_output = expert_output.reshape(input_shape)
+        expert_output = AllToAll.apply(expert_output, self.ep_group)
+        return expert_output
+
+    def tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
+        expert_in = AllGather.apply(dispatch_data, self.ep_group)
+        expert_out = self.experts(expert_in)
+        expert_out = ReduceScatter.apply(expert_out, self.ep_group)
+        return expert_out
+
+
+class MoeModule(SparseMLP):
+    """
+    Deprecated alias of SparseMLP, kept so that existing imports of MoeModule keep working.
+    """
+
+    def __init__(self,
+                 num_experts: int,
+                 top_k: int = 1,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 noisy_policy: Optional[str] = None,
+                 drop_tks: bool = True,
+                 expert_parallel: str = "EP",
+                 hidden_size: int = 2048,
+                 intermediate_size: int = 2048,
+                 activation: str = None):
+        super().__init__(num_experts, top_k, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_policy,
+                         drop_tks, expert_parallel, hidden_size, intermediate_size, activation)
diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/nn/layer/moe/routers.py
index c5b8390bf047..53fd8fd43e91 100644
--- a/colossalai/nn/layer/moe/routers.py
+++ b/colossalai/nn/layer/moe/routers.py
@@ -1,226 +1,237 @@
-import math
-from abc import ABC
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.distributed as dist
-from colossalai.utils import get_current_device
-from colossalai.context import MOE_CONTEXT
-from colossalai.nn.layer.moe._operation import moe_cumsum
-from typing import Callable, Optional
-from torch.distributed import ProcessGroup
-
-
-class MoeRouter(nn.Module, 
ABC): - """Base class for all MoE routers. - Args: - k_value (int): The value of top_k. - capacity_factor_train (float): Capacity factor in routing of training. - capacity_factor_eval (float): Capacity factor in routing of evaluation. - min_capacity (int): The minimum number of the capacity of each expert. - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__(self, - k_value: int, - capacity_factor_train: float, - capacity_factor_eval: float, - min_capacity: int, - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__() - self.k_value = k_value - self.capacity_factor_train = capacity_factor_train - self.capacity_factor_eval = capacity_factor_eval - self.min_capacity = min_capacity - self.noisy_func = noisy_func - self.drop_tks = drop_tks - self._routing_loss = None - - def get_capacity(self, logits_shape): - capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval - capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) - capacity += capacity % 2 - capacity = max(capacity, self.min_capacity) - assert capacity > 0 - return capacity - - def set_routing_loss(self, aux_loss: torch.Tensor) -> None: - assert self._routing_loss is None - self._routing_loss = aux_loss - - def pop_routing_loss(self) -> torch.Tensor: - assert self._routing_loss is not None - reservation = self._routing_loss - self._routing_loss = None - return reservation - - -class Top1Router(MoeRouter): - """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] - for routing usage. More detailed function can be found in the paper about Switch Transformer - of Google. - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert. - select_policy (str, optional): The policy about tokens selection. - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. 
- drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - select_policy: str = "first", - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__(k_value=1, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) - self.select_policy = select_policy - assert select_policy in {"first", "random"} - if select_policy == "random": - self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()), - high=torch.tensor(1.0, - device=get_current_device())).rsample - - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): - - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - logits = F.softmax(inputs, dim=-1) - num_experts = logits.size(-1) - capacity = self.get_capacity(logits.shape) - - top1_idx = torch.argmax(inputs, dim=-1) - mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - - # caculate the auxiliary loss - me = torch.mean(logits, dim=0) - ce = torch.mean(mask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) - self.set_routing_loss(l_aux) - - if not self.training and not self.drop_tks: - max_num = torch.max(torch.sum(mask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - if self.select_policy == "random": - rand_mask = mask * self.uniform(mask.shape) - _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) - mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) - ranks = moe_cumsum(mask) - elif self.select_policy == "first": - ranks = moe_cumsum(mask) - mask = mask * torch.lt(ranks, capacity) - else: - raise NotImplementedError("Not support such select policy yet.") - - ranks = torch.sum(mask * ranks, dim=-1) - - if use_kernel: - mask = torch.sum(mask, dim=-1) - mask = torch.stack([mask], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) - return logits, mask, dest_idx, num_experts * capacity - else: - ranks = F.one_hot(ranks, num_classes=capacity) - weight = mask * logits.type_as(inputs) - combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) - sec_mask = combine_weights.bool() - return combine_weights, sec_mask - - -class Top2Router(MoeRouter): - """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] - for routing usage. More detailed function can be found in the paper about ViT-MoE. - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation. 
- """ - - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__(k_value=2, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) - - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): - # inputs: [s, h] - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - logits = F.softmax(inputs, dim=-1) # logits: [s, e] - num_experts = logits.size(-1) - capacity = self.get_capacity(logits.shape) - - top1_idx = torch.argmax(logits, dim=-1) - mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) - top2_idx = torch.argmax(logits_except1, dim=-1) - mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) - - cmask = (mask1 + mask2) # loss: [s, e] - - # caculate the auxiliary loss - me = torch.mean(logits, dim=0) - ce = torch.mean(cmask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 - self.set_routing_loss(l_aux) - - if not self.training and not self.drop_tks: - max_num = torch.max(torch.sum(cmask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - rank1 = moe_cumsum(mask1) # rank1: [s, e] - rank2 = moe_cumsum(mask2) - rank2 += torch.sum(mask1, dim=-2, keepdim=True) - - mask1 *= torch.lt(rank1, capacity) - mask2 *= torch.lt(rank2, capacity) - - rank1 = torch.sum(mask1 * rank1, dim=-1) - rank2 = torch.sum(mask2 * rank2, dim=-1) - - if use_kernel: - mask1 = torch.sum(mask1, dim=-1) - mask2 = torch.sum(mask2, dim=-1) - - mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) - - return logits, mask, dest_idx, num_experts * capacity - else: - weight1 = mask1 * logits.type_as(inputs) - weight2 = mask2 * logits.type_as(inputs) - rank1_sc = F.one_hot(rank1, num_classes=capacity) - rank2_sc = F.one_hot(rank2, num_classes=capacity) - - cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) - cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) - cb_weight = cb_weight1 + cb_weight2 - sec_mask = cb_weight.bool() - - return cb_weight, sec_mask +import math +from abc import ABC +from typing import Callable, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed import ProcessGroup + +from colossalai.context import MOE_CONTEXT +from colossalai.nn.layer.moe._operation import moe_cumsum +from colossalai.utils import get_current_device + + +class MoeRouter(nn.Module, ABC): + """Base class for all MoE routers. + Args: + k_value (int): The value of top_k. + capacity_factor_train (float): Capacity factor in routing of training. + capacity_factor_eval (float): Capacity factor in routing of evaluation. + min_capacity (int): The minimum number of the capacity of each expert. + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. 
+        drop_tks (bool, optional): Whether to drop tokens during evaluation.
+    """
+
+    def __init__(self,
+                 k_value: int,
+                 capacity_factor_train: float,
+                 capacity_factor_eval: float,
+                 min_capacity: int,
+                 noisy_func: Callable = None,
+                 drop_tks: bool = True):
+        super().__init__()
+        self.k_value = k_value
+        self.capacity_factor_train = capacity_factor_train
+        self.capacity_factor_eval = capacity_factor_eval
+        self.min_capacity = min_capacity
+        self.noisy_func = noisy_func
+        self.drop_tks = drop_tks
+        self._routing_loss = None
+
+    def get_capacity(self, logits_shape):
+        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
+        capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
+        capacity += capacity % 2
+        capacity = max(capacity, self.min_capacity)
+        assert capacity > 0
+        return capacity
+
+    def set_routing_loss(self, aux_loss: torch.Tensor) -> None:
+        assert self._routing_loss is None
+        self._routing_loss = aux_loss
+
+    def pop_routing_loss(self) -> torch.Tensor:
+        assert self._routing_loss is not None
+        reservation = self._routing_loss
+        self._routing_loss = None
+        return reservation
+
+
+class Top1Router(MoeRouter):
+    """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
+    for routing usage. More details can be found in Google's Switch Transformer
+    paper.
+    Args:
+        capacity_factor_train (float, optional): Capacity factor in routing of training.
+        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
+        min_capacity (int, optional): The minimum capacity of each expert.
+        select_policy (str, optional): The policy used to select tokens.
+        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
+        drop_tks (bool, optional): Whether to drop tokens during evaluation.
+    """
+
+    def __init__(self,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 select_policy: str = "first",
+                 noisy_func: Callable = None,
+                 drop_tks: bool = True):
+        super().__init__(k_value=1,
+                         capacity_factor_train=capacity_factor_train,
+                         capacity_factor_eval=capacity_factor_eval,
+                         min_capacity=min_capacity,
+                         noisy_func=noisy_func,
+                         drop_tks=drop_tks)
+        self.select_policy = select_policy
+        assert select_policy in {"first", "random"}
+        if select_policy == "random":
+            self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
+                                                               high=torch.tensor(1.0,
+                                                                                 device=get_current_device())).rsample
+
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
+
+        if self.noisy_func is not None and self.training:
+            inputs = self.noisy_func(inputs)
+
+        assert inputs.dtype == torch.float
+        logits = F.softmax(inputs, dim=-1)
+        num_experts = logits.size(-1)
+        capacity = self.get_capacity(logits.shape)
+
+        top1_idx = torch.argmax(inputs, dim=-1)
+        mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
+
+        # calculate the auxiliary loss
+        me = torch.mean(logits, dim=0)
+        ce = torch.mean(mask.float(), dim=0)
+        l_aux = num_experts * torch.sum(me * ce)
+        self.set_routing_loss(l_aux)
+
+        if not self.training and not self.drop_tks:
+            max_num = torch.max(torch.sum(mask, dim=0))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
+            capacity = max_num.item()
+
+        if self.select_policy == "random":
+            rand_mask = mask * self.uniform(mask.shape)
+            _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
+            mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
+            ranks = moe_cumsum(mask)
+        elif self.select_policy == "first":
+            ranks = moe_cumsum(mask)
+            mask = mask * torch.lt(ranks, capacity)
+        else:
+            raise NotImplementedError("This select policy is not supported yet.")
+
+        ranks = torch.sum(mask * ranks, dim=-1)
+
+        if use_kernel:
+            mask = torch.sum(mask, dim=-1)
+            mask = torch.stack([mask], dim=0).to(torch.int32)
+            dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
+            return logits, mask, dest_idx, num_experts * capacity
+        else:
+            ranks = F.one_hot(ranks, num_classes=capacity)
+            weight = mask * logits.type_as(inputs)
+            combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
+            sec_mask = combine_weights.bool()
+            return combine_weights, sec_mask
+
+
+class Top2Router(MoeRouter):
+    """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
+    for routing usage. More details can be found in the ViT-MoE paper.
+    Args:
+        capacity_factor_train (float, optional): Capacity factor in routing of training.
+        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
+        min_capacity (int, optional): The minimum capacity of each expert.
+        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
+        drop_tks (bool, optional): Whether to drop tokens during evaluation.
+ """ + + def __init__(self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Callable = None, + drop_tks: bool = True): + super().__init__(k_value=2, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): + # inputs: [s, h] + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + logits = F.softmax(inputs, dim=-1) # logits: [s, e] + num_experts = logits.size(-1) + capacity = self.get_capacity(logits.shape) + + top1_idx = torch.argmax(logits, dim=-1) + mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) + top2_idx = torch.argmax(logits_except1, dim=-1) + mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) + + cmask = (mask1 + mask2) # loss: [s, e] + + # caculate the auxiliary loss + me = torch.mean(logits, dim=0) + ce = torch.mean(cmask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 + self.set_routing_loss(l_aux) + + if not self.training and not self.drop_tks: + max_num = torch.max(torch.sum(cmask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + rank1 = moe_cumsum(mask1) # rank1: [s, e] + rank2 = moe_cumsum(mask2) + rank2 += torch.sum(mask1, dim=-2, keepdim=True) + + mask1 *= torch.lt(rank1, capacity) + mask2 *= torch.lt(rank2, capacity) + + rank1 = torch.sum(mask1 * rank1, dim=-1) + rank2 = torch.sum(mask2 * rank2, dim=-1) + + if use_kernel: + mask1 = torch.sum(mask1, dim=-1) + mask2 = torch.sum(mask2, dim=-1) + + mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) + + return logits, mask, dest_idx, num_experts * capacity + else: + weight1 = mask1 * logits.type_as(inputs) + weight2 = mask2 * logits.type_as(inputs) + rank1_sc = F.one_hot(rank1, num_classes=capacity) + rank2_sc = F.one_hot(rank2, num_classes=capacity) + + cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) + cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) + cb_weight = cb_weight1 + cb_weight2 + sec_mask = cb_weight.bool() + + return cb_weight, sec_mask + + +def get_router_cls(top_k: int) -> MoeRouter: + if top_k == 1: + router_cls = Top1Router + elif top_k == 2: + router_cls = Top2Router + else: + raise NotImplementedError("top_k > 2 is not supported yet") + return router_cls diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py index 4ca8bd703386..eb3bef70998d 100644 --- a/colossalai/nn/layer/moe/utils.py +++ b/colossalai/nn/layer/moe/utils.py @@ -1,68 +1,82 @@ -import torch -import torch.nn.functional as F -from colossalai.utils import get_current_device -from colossalai.context.moe_context import MOE_CONTEXT -from .experts import FFNExperts, TPExperts - - -class ForceFP32Parameter(torch.nn.Parameter): - - def half(self, memory_format=None): - return self.data.clone() - - -class NormalNoiseGenerator: - """Generates a random noisy mask for logits tensor. - - All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where - `E = the number of experts`. 
- - Args: - num_experts (int): The number of experts. - """ - - def __init__(self, num_experts: int): - self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()), - scale=torch.tensor(1.0 / num_experts**2, - device=get_current_device())).rsample - - def __call__(self, inputs: torch.Tensor): - noisy = self.normal(inputs.shape) - return inputs + noisy - - -class UniformNoiseGenerator: - """Generates a random noisy mask for logits tensor. - copied from mesh tensorflow: - Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`. - Makes models more resilient to rounding errors introduced by bfloat16. - This seems particularly important for logits. - - Args: - eps (float, optional): Epsilon in generator, defaults 1e-2. - """ - - def __init__(self, eps: float = 1e-2): - self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()), - high=torch.tensor(1.0 + eps, - device=get_current_device())).rsample - - def __call__(self, inputs: torch.Tensor): - noisy = self.uniform(inputs.shape) - return inputs * noisy - - -def autocast_softmax(logit: torch.Tensor, dim: int): - if logit.dtype != torch.float32: - logit = logit.float() - return F.softmax(logit, dim=dim) - - -def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - mep_size = MOE_CONTEXT.max_ep_size - if num_experts % mep_size == 0 or mep_size % num_experts == 0: - return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate) - elif d_ff % mep_size == 0: - return TPExperts(num_experts, d_model, d_ff, activation, drop_rate) - else: - raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") +from typing import Callable + +import torch +import torch.nn.functional as F + +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.utils import get_current_device + +from .experts import EPMLPExperts, TPMLPExperts + + +class ForceFP32Parameter(torch.nn.Parameter): + + def half(self, memory_format=None): + return self.data.clone() + + +class NormalNoiseGenerator: + """Generates a random noisy mask for logits tensor. + + All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where + `E = the number of experts`. + + Args: + num_experts (int): The number of experts. + """ + + def __init__(self, num_experts: int): + self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()), + scale=torch.tensor(1.0 / num_experts**2, + device=get_current_device())).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.normal(inputs.shape) + return inputs + noisy + + +class UniformNoiseGenerator: + """Generates a random noisy mask for logits tensor. + copied from mesh tensorflow: + Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`. + Makes models more resilient to rounding errors introduced by bfloat16. + This seems particularly important for logits. + + Args: + eps (float, optional): Epsilon in generator, defaults 1e-2. 
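The two noise policies above, applied directly to a logits tensor; this is a standalone equivalent for illustration, not the generator classes themselves:

```python
import torch

E, eps = 8, 1e-2
logits = torch.randn(4, E)

# 'Gaussian' policy: additive normal noise with scale 1 / E**2
gaussian = logits + torch.randn_like(logits) * (1.0 / E**2)

# 'Jitter' policy: multiplicative uniform noise in [1 - eps, 1 + eps]
jitter = logits * torch.empty_like(logits).uniform_(1.0 - eps, 1.0 + eps)
```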
+ """ + + def __init__(self, eps: float = 1e-2): + self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()), + high=torch.tensor(1.0 + eps, + device=get_current_device())).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.uniform(inputs.shape) + return inputs * noisy + + +def autocast_softmax(logit: torch.Tensor, dim: int): + return F.softmax(logit, dim=dim, detype=torch.float32) + + +def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + mep_size = MOE_CONTEXT.max_ep_size + if num_experts % mep_size == 0 or mep_size % num_experts == 0: + return EPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate) + elif d_ff % mep_size == 0: + return TPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate) + else: + raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") + + +def get_noise_generator(noise_type: str, num_experts: int) -> Callable: + if noise_type is None: + return None + elif noise_type == 'Jitter': + noisy_func = UniformNoiseGenerator() + elif noise_type == 'Gaussian': + noisy_func = NormalNoiseGenerator(num_experts) + else: + raise NotImplementedError("Unsupported input noisy policy") + return noisy_func diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py index 11d07ef8c804..b9b6d338438e 100644 --- a/colossalai/tensor/moe_tensor/api.py +++ b/colossalai/tensor/moe_tensor/api.py @@ -1,20 +1,25 @@ import torch +import torch.distributed as dist +from colossalai.tensor import ProcessGroup -def is_moe_param(tensor: torch.Tensor) -> bool: +from .moe_info import MoeParallelInfo + + +def is_moe_tensor(tensor: torch.Tensor) -> bool: """ - Check whether the given tensor is a moe param. + Check whether the given tensor is a moe tensor. Args: tensor (torch.Tensor): The tensor to be checked. Returns: - bool: Whether the given tensor is a moe param. + bool: Whether the given tensor is a moe tensor. """ return hasattr(tensor, "moe_info") -def set_moe_param_info(tensor: torch.Tensor, moe_info: dict) -> None: +def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None: """ Set moe info for the given tensor. @@ -24,3 +29,81 @@ def set_moe_param_info(tensor: torch.Tensor, moe_info: dict) -> None: """ tensor.__setattr__('moe_info', moe_info) + + +def get_moe_info(ep_size: int, dp_size: int) -> MoeParallelInfo: + """ + Get moe info for the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + dict: The moe info of the given tensor. + """ + return MoeParallelInfo(ep_size, dp_size) + + +def get_ep_group(tensor: torch.Tensor) -> ProcessGroup: + """ + Get the expert parallel group of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + torch.distributed.ProcessGroup: The expert parallel group of the given tensor. + """ + return tensor.moe_info.ep_group + + +def get_ep_size(tensor: torch.Tensor) -> int: + """ + Get the expert parallel size of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + int: The expert parallel size of the given tensor. + """ + return tensor.moe_info.ep_size + + +def get_dp_group(tensor: torch.Tensor) -> ProcessGroup: + """ + Get the data parallel group of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + torch.distributed.ProcessGroup: The data parallel group of the given tensor. 
+ """ + return tensor.moe_info.dp_group + + +def get_ep_rank(tensor: torch.Tensor) -> int: + """ + Get the expert parallel rank of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + int: The expert parallel rank of the given tensor. + """ + return dist.get_rank(get_ep_group(tensor)) + + +def get_dp_rank(tensor: torch.Tensor) -> int: + """ + Get the data parallel rank of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + int: The data parallel rank of the given tensor. + """ + return dist.get_rank(get_dp_group(tensor)) diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py new file mode 100644 index 000000000000..89f79f162b5b --- /dev/null +++ b/colossalai/tensor/moe_tensor/moe_info.py @@ -0,0 +1,15 @@ +from colossalai.cluster import ProcessGroupMesh + + +class MoeParallelInfo: + """Moe parallelism information, storing parallel sizes and groups. + """ + + def __init__(self, ep_size: int, dp_size: int): + self.dp_axis = 0 + self.dp_size = dp_size + self.ep_axis = 1 + self.ep_size = ep_size + self.pg = ProcessGroupMesh(self.dp_size, self.ep_size) + self.ep_group = self.pg.get_group_along_axis(self.ep_axis) + self.dp_group = self.pg.get_group_along_axis(self.dp_axis) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 3516e4df4bba..17624856e7ce 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -16,7 +16,7 @@ ) from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.tensor.moe_tensor.api import is_moe_param +from colossalai.tensor.moe_tensor.api import is_moe_tensor # from colossalai.tensor import ColoParameter, ProcessGroup from colossalai.utils.cuda import get_current_device @@ -146,7 +146,7 @@ def __init__( for param in param_group['params']: if param.requires_grad: # skip moe param - if is_moe_param(param): + if is_moe_tensor(param): moe_params.append(param) continue group_params.append(param) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 4f901671a4ba..58a3567d1e5a 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -17,4 +17,4 @@ requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggi SentencePiece datasets ninja -flash-attn>=2.0 +flash-attn==2.0.5 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 65eecce2c34f..f6be6a624c70 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -10,5 +10,4 @@ contexttimer ninja torch>=1.11 safetensors -flash_attn>=2.0 einops diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index d86d78886e23..b57567e74be3 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -1,3 +1,5 @@ +import torch +import torch.distributed as dist import torch.nn as nn from colossalai.context import MOE_CONTEXT @@ -7,26 +9,24 @@ from colossalai.engine.gradient_handler._base_gradient_handler import BaseGradientHandler from colossalai.engine.gradient_handler.utils import bucket_allreduce from colossalai.nn import CheckpointModule -from colossalai.nn.layer import MoeModule +from colossalai.nn.layer import SparseMLP from colossalai.registry import GRADIENT_HANDLER +from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor from 
colossalai.utils.moe import get_moe_epsize_param_dict
 
 
 class MoeModel(nn.Module):
 
-    def __init__(self, checkpoint: bool = False):
+    def __init__(self, checkpoint: bool = False, expert_parallel: str = "EP"):
 
         class TestSubModule(CheckpointModule):
 
             def __init__(self):
                 super().__init__(checkpoint)
-                expert_cls = nn.Linear
-                expert_args_dict = dict(in_features=16, out_features=16)
-                self.moe = MoeModule(dim_model=16,
-                                     num_experts=8,
-                                     use_residual=True,
-                                     expert_cls=expert_cls,
-                                     **expert_args_dict)
+                self.moe = SparseMLP(num_experts=8,
+                                     expert_parallel=expert_parallel,
+                                     hidden_size=16,
+                                     intermediate_size=32)
                 self.proj = nn.Linear(16, 4)
 
             def _forward(self, x):
@@ -84,3 +84,46 @@ def handle_gradient(self):
         if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
             bucket_allreduce(param_list=epsize_param_dict[ep_size],
                              group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
+
+
+def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
+    """Sync the parameters of the TP model from the EP model.
+
+    Args:
+        tp_model (SparseMLP): the tensor-parallel model to be updated in place
+        ep_model (SparseMLP): the expert-parallel model to copy parameters from
+        assert_grad_flag (bool): if True, assert that params and grads already match instead of copying
+    """
+    for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()):
+        assert tp_name == ep_name
+        if not is_moe_tensor(tp_param):
+            if assert_grad_flag:
+                assert torch.allclose(tp_param, ep_param)
+                assert torch.allclose(tp_param.grad, ep_param.grad)
+            else:
+                tp_param.data.copy_(ep_param.data)
+            continue
+
+        # gather the full expert weight from the EP model
+        param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+        dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
+        all_param = torch.cat(param_list, dim=0)
+        if assert_grad_flag:
+            grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+            dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
+            all_grad = torch.cat(grad_list, dim=0)
+
+        # slice out this rank's TP shard along the dimension where the shapes differ
+        tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2]
+        tp_rank = get_ep_rank(tp_param)
+        tp_dim = tp_dim[0] + 1
+        tp_slice = [slice(None)] * tp_dim + [
+            slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
+        ]
+        new_tp_param = all_param[tuple(tp_slice)]
+        if assert_grad_flag:
+            new_grad = all_grad[tuple(tp_slice)]
+            assert torch.allclose(tp_param, new_tp_param)
+            assert torch.allclose(tp_param.grad, new_grad)
+        else:
+            tp_param.data.copy_(new_tp_param.data)
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py
index 87f0f4b2abe4..6135a386e7c8 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_moe/test_grad_handler.py
@@ -5,7 +5,7 @@
 
 import colossalai
 from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator
+from colossalai.nn.layer.moe import EPMLPExperts, MoeLayer, Top1Router, UniformNoiseGenerator
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
 from colossalai.utils.moe import sync_moe_model_param
@@ -17,8 +17,7 @@ def run_test(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
-    expert_module = nn.Linear
-    expert_factor = dict(in_features=DIM, out_features=DIM, device=get_current_device())
+    expert_factor = dict(hidden_size=DIM, intermediate_size=DIM * 2)
 
     MOE_CONTEXT.setup(42)    # MOE initialization
noisy_func = UniformNoiseGenerator() @@ -26,7 +25,7 @@ def run_test(rank, world_size, port): num_experts_list = [1, 2, 4] layer_list = [] for num_experts in num_experts_list: - exp = Experts(expert_module, num_experts, **expert_factor) + exp = EPMLPExperts(num_experts, **expert_factor) moe_layer = MoeLayer(DIM, num_experts, router, exp) layer_list.append(moe_layer) @@ -35,8 +34,10 @@ def run_test(rank, world_size, port): sync_moe_model_param(model) dist_dict = MOE_CONTEXT.parallel_info_dict - assert_equal_in_group(layer_list[0].experts.experts[0].weight.data, dist_dict[1].dp_group) - assert_equal_in_group(layer_list[1].experts.experts[0].weight.data, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group) # MoE model synchronization passed grad_handler = MoeGradientHandler(model, 0) @@ -52,11 +53,10 @@ def run_test(rank, world_size, port): data.backward(grad) grad_handler.handle_gradient() - assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group) - assert_equal_in_group(layer_list[0].experts.experts[0].bias.grad, dist_dict[1].dp_group) - - assert_equal_in_group(layer_list[1].experts.experts[0].weight.grad, dist_dict[2].dp_group) - assert_equal_in_group(layer_list[1].experts.experts[0].bias.grad, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[0].experts.wi.grad, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[0].experts.wo.grad, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[1].experts.wi.grad, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[1].experts.wo.grad, dist_dict[2].dp_group) # MoE grad handler test passed diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 9d11fd9bcd6d..867437f00c82 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -1,12 +1,11 @@ import pytest import torch -import torch.nn as nn import colossalai from colossalai.context import ParallelMode from colossalai.context.moe_context import MOE_CONTEXT from colossalai.core import global_context as gpc -from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, Top2Router +from colossalai.nn.layer.moe import EPMLPExperts, MoeLayer, Top1Router, Top2Router from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device @@ -32,9 +31,8 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f # get randomized data tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) - expert_module = nn.Linear - expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device()) - expert = Experts(expert_module, NUM_EXPERTS, **expert_factor) + expert_factor = dict(hidden_size=hidden_size, intermediate_size=hidden_size * 2) + expert = EPMLPExperts(NUM_EXPERTS, **expert_factor) layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert) layer = layer.to(get_current_device()) if data_type == torch.float16: diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index df0fa164c068..fc51c2217233 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -6,21 +6,19 @@ 
import colossalai from colossalai.context import MOE_CONTEXT -from colossalai.nn.layer.moe import load_moe_model, save_moe_model +from colossalai.nn.layer.moe import MoeCheckpintIO from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext from tests.test_moe.moe_utils import MoeModel def exam_moe_checkpoint(): - with ColoInitContext(device=get_current_device()): - model = MoeModel(checkpoint=True) - save_moe_model(model, 'temp_path.pth') + ckpt = MoeCheckpintIO() + model = MoeModel(checkpoint=True).to(get_current_device()) + ckpt.save_model(model, 'temp_path.pth') - with ColoInitContext(device=get_current_device()): - other_model = MoeModel(checkpoint=True) - load_moe_model(other_model, 'temp_path.pth') + other_model = MoeModel(checkpoint=True).to(get_current_device()) + ckpt.load_model(other_model, 'temp_path.pth') state_0 = model.state_dict() state_1 = other_model.state_dict() @@ -42,7 +40,7 @@ def _run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [2, 4]) @rerun_if_address_is_in_use() def test_moe_checkpoint(world_size): - spawn(_run_dist) + spawn(_run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py new file mode 100644 index 000000000000..13c66cf73e4d --- /dev/null +++ b/tests/test_moe/test_moe_ep_tp.py @@ -0,0 +1,63 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.nn.layer.moe import SparseMLP +from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.utils.moe import sync_moe_model_param +from tests.test_moe.moe_utils import MoeGradientHandler, sync_tp_from_ep + +BATCH_SIZE = 4 +DIM = 4 + + +def run_test(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(42) # MOE initialization + + ep_model = SparseMLP(num_experts=4, expert_parallel="EP", hidden_size=DIM, intermediate_size=DIM) + tp_model = SparseMLP(num_experts=4, expert_parallel="TP", hidden_size=DIM, intermediate_size=DIM) + ep_model = ep_model.to(get_current_device()) + tp_model = tp_model.to(get_current_device()) + + # sync ep param + sync_moe_model_param(ep_model) + dist_dict = MOE_CONTEXT.parallel_info_dict + assert_equal_in_group(ep_model.experts.wi.data, dist_dict[2].dp_group) + assert_equal_in_group(ep_model.experts.wo.data, dist_dict[2].dp_group) + grad_handler = MoeGradientHandler(ep_model) + # sync tp param + sync_tp_from_ep(tp_model, ep_model) + + rank = dist.get_rank() + torch.cuda.manual_seed(78) + tp_data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) + ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)] + + out_tp = tp_model(tp_data)[0] + MOE_CONTEXT.reset_loss() + out_ep = ep_model(ep_data)[0] + MOE_CONTEXT.reset_loss() + assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)]) + + out_tp.mean().backward() + out_ep.mean().backward() + grad_handler.handle_gradient() + + assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[2].dp_group) + assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[2].dp_group) + + sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_moe_ep_tp(): + spawn(run_test, 2) + + +if __name__ == 
'__main__': + test_moe_ep_tp() diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index d073e0e5c08f..fd87a9a3135d 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -4,36 +4,37 @@ import colossalai from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.nn.layer.moe import Experts +from colossalai.nn.layer.moe import EPMLPExperts, TPMLPExperts from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device from colossalai.utils.moe import sync_moe_model_param -D_MODEL = 4 -D_FF = 8 +HIDDEN_SIZE = 4 +INTERMEDIATE_SIZE = 8 -def run_test(rank, world_size, port): - world_size = 4 - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - expert_module = nn.Linear - expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device()) - - MOE_CONTEXT.setup(42) # MOE environment initialization - exp0 = Experts(expert_module, 1, **expert_factor) - exp1 = Experts(expert_module, 2, **expert_factor) - exp2 = Experts(expert_module, 4, **expert_factor) - exp3 = Experts(expert_module, 8, **expert_factor) +def run_moe_init(expert_cls): + expert_args = dict(hidden_size=HIDDEN_SIZE, intermediate_size=INTERMEDIATE_SIZE) + exp0 = expert_cls(1, **expert_args) + exp1 = expert_cls(2, **expert_args) + exp2 = expert_cls(4, **expert_args) + exp3 = expert_cls(8, **expert_args) - assert exp0.num_local_experts == 1 - assert exp1.num_local_experts == 1 - assert exp2.num_local_experts == 1 - assert exp3.num_local_experts == 2 - # experts deployment passed + if expert_cls is EPMLPExperts: + assert exp0.num_local_experts == 1 + assert exp1.num_local_experts == 1 + assert exp2.num_local_experts == 1 + assert exp3.num_local_experts == 2 + else: + assert exp0.num_local_experts == 1 + assert exp1.num_local_experts == 2 + assert exp2.num_local_experts == 4 + assert exp3.num_local_experts == 8 parallel_info_dict = MOE_CONTEXT.parallel_info_dict rank = dist.get_rank() + # group creation assert assert len(parallel_info_dict) == 3 assert dist.get_rank(parallel_info_dict[4].ep_group) == rank assert dist.get_rank(parallel_info_dict[2].ep_group) == rank % 2 @@ -42,26 +43,33 @@ def run_test(rank, world_size, port): assert dist.get_rank(parallel_info_dict[4].dp_group) == 0 assert dist.get_rank(parallel_info_dict[2].dp_group) == rank // 2 assert dist.get_rank(parallel_info_dict[1].dp_group) == rank - # group creation passed model = nn.ModuleList([exp0, exp1, exp2, exp3]) model = model.to(get_current_device()) sync_moe_model_param(model) - assert_equal_in_group(exp0.experts[0].weight.data, parallel_info_dict[1].dp_group) - assert_equal_in_group(exp0.experts[0].bias.data, parallel_info_dict[1].dp_group) # MOE experts layout success when ep_size = 1 + assert_equal_in_group(exp0.wi.data, parallel_info_dict[1].dp_group) + assert_equal_in_group(exp0.wo.data, parallel_info_dict[1].dp_group) - assert_equal_in_group(exp1.experts[0].weight.data, parallel_info_dict[2].dp_group) - assert_equal_in_group(exp1.experts[0].bias.data, parallel_info_dict[2].dp_group) # MOE experts layout success when ep_size = 2 + assert_equal_in_group(exp1.wi.data, parallel_info_dict[2].dp_group) + assert_equal_in_group(exp1.wo.data, parallel_info_dict[2].dp_group) + + +def _run_test(rank, world_size, port, expert_cls): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + 
MOE_CONTEXT.setup(seed=42) + run_moe_init(expert_cls) @pytest.mark.dist +@pytest.mark.parametrize("expert_cls", [EPMLPExperts, TPMLPExperts]) @rerun_if_address_is_in_use() -def test_moe_initialization(): - spawn(run_test, 4) +def test_moe_initialization(expert_cls): + spawn(_run_test, 4, expert_cls=expert_cls) if __name__ == '__main__': - test_moe_initialization() + test_moe_initialization(EPMLPExperts) + test_moe_initialization(TPMLPExperts) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index e2acb0702f1c..9d19ee830f77 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -55,7 +55,6 @@ def run_zero_test(local_rank, world_size, stage=1): grad_handler = MoeGradientHandler(torch_model) # assert zero model - assert len(zero_model.module.test_transform.moe.moe_layer.experts.experts) == 8 // MOE_CONTEXT.world_size for (torch_name, torch_param), (zero_name, zero_param) in zip(torch_model.named_parameters(), zero_model.module.named_parameters()): assert zero_name == torch_name
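
For reference, a minimal usage sketch of the moe_tensor API introduced in this patch. It is illustrative only and not part of the diff: the 2x2 parallel layout, the tensor shape, and the variable names are assumptions, and it presumes torch.distributed is already initialized (e.g. via colossalai.launch) with world_size == ep_size * dp_size == 4.

import torch

from colossalai.tensor.moe_tensor.api import (
    get_dp_group,
    get_dp_rank,
    get_ep_group,
    get_ep_rank,
    get_ep_size,
    get_moe_info,
    is_moe_tensor,
    set_moe_tensor_info,
)

# MoeParallelInfo lays the ranks out on a (dp_size, ep_size) ProcessGroupMesh:
# data parallel on axis 0, expert parallel on axis 1.
moe_info = get_moe_info(ep_size=2, dp_size=2)

# Tag a local expert weight so the rest of the stack can recognize it, e.g.
# the is_moe_tensor() check in the low-level ZeRO optimizer and the dp_group
# all-reduce in MoeGradientHandler. The shape here is made up.
wi = torch.nn.Parameter(torch.randn(2, 16, 32))
set_moe_tensor_info(wi, moe_info)

assert is_moe_tensor(wi)
assert get_ep_size(wi) == 2
# Expert shards communicate over the EP group; replicas of the same shard
# are all-reduced over the DP group after backward.
print(get_ep_rank(wi), get_dp_rank(wi))
print(get_ep_group(wi), get_dp_group(wi))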