diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py
index 1614987538c1..f32e89dfad3f 100644
--- a/colossalai/moe/__init__.py
+++ b/colossalai/moe/__init__.py
@@ -1,12 +1,17 @@
 from .checkpoint import MoeCheckpintIO
-from .experts import EPMLPExperts, TPMLPExperts, build_ffn_experts
+from .experts import MLPExperts
 from .layers import SparseMLP
 from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
 from .utils import NormalNoiseGenerator, UniformNoiseGenerator

 __all__ = [
-    'EPMLPExperts', 'TPMLPExperts', 'build_ffn_experts',
-    'MoeRouter', 'Top1Router', 'Top2Router', 'TopKRouter',
-    'NormalNoiseGenerator', 'UniformNoiseGenerator',
-    'SparseMLP', 'MoeCheckpintIO'
+    "MLPExperts",
+    "MoeRouter",
+    "Top1Router",
+    "Top2Router",
+    "TopKRouter",
+    "NormalNoiseGenerator",
+    "UniformNoiseGenerator",
+    "SparseMLP",
+    "MoeCheckpintIO",
 ]
diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index 9b4988b345d9..14e0935b72e4 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -227,10 +227,10 @@ def backward(ctx, tokens_grad):
     return d_expert, d_logits, None, None, None


-def moe_cumsum(inputs: Tensor):
+def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
     dim0 = inputs.size(0)
     flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
-    if flag and MOE_MANAGER.use_kernel_optim:
+    if flag and use_kernel:
        if MOE_KERNEL is None:
            load_moe()
        return MOE_KERNEL.cumsum_sub_one(inputs)
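`moe_cumsum` now takes `use_kernel` from its caller instead of reading the global `MOE_MANAGER`. For reference, the kernel's `cumsum_sub_one` computes an inclusive cumsum along the token dimension minus one; this is a minimal plain-PyTorch sketch of that semantics (my reference, not the library's fallback verbatim), with hypothetical example data:

```python
import torch

def moe_cumsum_reference(mask: torch.Tensor) -> torch.Tensor:
    # For each expert column, the k-th routed token gets rank k (0-based):
    # an inclusive cumsum over tokens, shifted down by one.
    return torch.cumsum(mask, dim=0) - 1

# a (tokens, experts) one-hot routing mask
mask = torch.tensor([[1, 0], [1, 0], [0, 1]])
print(moe_cumsum_reference(mask))
# tensor([[ 0, -1], [ 1, -1], [ 1,  0]]) -- entries are only read where mask == 1
```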
diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py
index 076f160adb79..3471b2876e9b 100644
--- a/colossalai/moe/experts.py
+++ b/colossalai/moe/experts.py
@@ -15,18 +15,19 @@
 from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine


-class BaseMLPExperts(nn.Module):
+class MLPExperts(nn.Module):
     """
     SparseMLP is a multi-layer perceptron with sparse expert parallel layers.

     Args:
         num_experts (int): The number of experts
-        forward: hidden_size --> intermediate_size --> hidden_size
-        hidden_size (int): The hidden size of MLP
-        intermediate_size (int): The intermediate size of MLP
-        expert_parallel (str, optional): The parallelism of experts. Now we have 'EP' and 'TP'.
+        hidden_size (int): The hidden size of MLP
+        intermediate_size (int): The intermediate size of MLP
+        expert_parallel (str, optional): The parallelism of experts. Now we have None, EP and TP.
         activation (optional): The activation function of MLP
         drop_rate (float, optional): The drop rate of MLP
+        gated (bool, optional): Whether to use gated MLP
+        use_kernel (bool, optional): Whether to use kernel optimization
     """

     def __init__(
@@ -36,9 +37,9 @@ def __init__(
         intermediate_size: int,
         expert_parallel: Optional[str] = None,
         activation: Optional[Callable] = None,
-        drop_rate: float = 0,
-        gated: bool = False,
-        use_kernel: bool = False,
+        drop_rate: Optional[float] = 0,
+        gated: Optional[bool] = False,
+        use_kernel: Optional[bool] = False,
     ):
         super().__init__()
         assert expert_parallel in ["EP", "TP", None]
@@ -97,8 +98,15 @@ def reset_parameters(self):
         torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size))
         torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size))

-    def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        param_slice: Tuple[slice] = (slice(None),),
+        use_sparse: bool = True,
+    ) -> torch.Tensor:
         """
+        forward: hidden_size --> intermediate_size --> hidden_size
+
         Args:
             x (torch.Tensor): The input tensor of shape
                 (num_groups, num_experts, capacity, hidden_size)
@@ -114,6 +122,16 @@ def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -
         inshape = x.shape
         x = x.reshape(e, -1, h)

+        if self.use_kernel and use_sparse:
+            seq_len = x.shape[1]
+            with torch.no_grad():
+                mask = x[:, :, 0] != 0.0
+                mask = torch.sum(mask, dim=-1)
+            x_list = []
+            for i in range(e):
+                x_list.append(x[i, :mask[i]])
+            x = x_list
+
         if self.gated:
             x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)]
             x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)]
@@ -127,86 +145,12 @@ def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -
         x = [self.drop(x[i]) for i in range(e)]
         x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)]

+        if self.use_kernel and use_sparse:
+            for i in range(e):
+                x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0)
+
         x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
         x = x.reshape(inshape)
         x = x.transpose(0, 1).contiguous()
         x = MoeOutGradScaler.apply(x, self.ep_size)
         return x
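The new `use_sparse` path trims zero-padded capacity slots per expert before the matmuls and pads the outputs back afterwards. A self-contained sketch of that round trip, under the same assumption the patch makes (padded slots are trailing and have a zero first feature); the shapes are made up:

```python
import torch
import torch.nn.functional as F

e, capacity, h = 2, 4, 3
x = torch.zeros(e, capacity, h)
x[0, :3] = torch.randn(3, h)   # expert 0 received 3 tokens
x[1, :1] = torch.randn(1, h)   # expert 1 received 1 token

# count real tokens per expert: a slot is "real" if its first feature is non-zero
valid = (x[:, :, 0] != 0.0).sum(dim=-1)

outs = []
for i in range(e):
    dense = x[i, :valid[i]]            # (valid_i, h): compute on real tokens only
    dense = dense * 2                  # stand-in for the per-expert MLP
    outs.append(F.pad(dense, (0, 0, 0, capacity - dense.shape[0])))  # pad back
y = torch.stack(outs)                  # (e, capacity, h) again
```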
- """ - - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - activation: str = None, - drop_rate: float = 0, - gated: bool = False, - use_kernel: bool = False, - ): - # TODO: This class can be aborted - super().__init__( - num_experts, - hidden_size, - intermediate_size, - "TP", - activation, - drop_rate, - gated, - use_kernel, - ) - - -def get_expert_class(name: str) -> BaseMLPExperts: - if name == "TP": - return TPMLPExperts - elif name == "EP": - return EPMLPExperts - elif name is None: - return BaseMLPExperts - else: - raise ValueError(f"Unknown expert class name: {name}") - - -def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - mep_size = MOE_MANAGER.max_ep_size - if num_experts % mep_size == 0 or mep_size % num_experts == 0: - return EPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate) - elif d_ff % mep_size == 0: - return TPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate) - else: - raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index 9846cd432b53..bd2cefbe9ab8 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter -from colossalai.moe.experts import BaseMLPExperts, get_expert_class +from colossalai.moe.experts import MLPExperts from colossalai.moe.load_balance import LoadBalancer from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.routers import MoeRouter, get_router_cls @@ -48,50 +48,59 @@ class SparseMLP(nn.Module): def __init__( self, num_experts: int, - top_k: int = 1, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_policy: Optional[str] = None, - drop_tks: bool = True, - hidden_size: int = 2048, - intermediate_size: int = 2048, - activation: str = None, - gated: bool = False, + hidden_size: int, + intermediate_size: int, + router_top_k: int = 1, + router_capacity_factor_train: Optional[float] = 1.25, + router_capacity_factor_eval: Optional[float] = 2.0, + router_min_capacity: Optional[int] = 4, + router_noisy_policy: Optional[str] = None, + router_drop_tks: Optional[bool] = True, + mlp_activation: Optional[str] = None, + mlp_gated: Optional[bool] = False, + enable_load_balance: Optional[bool] = False, + load_balance_tolerance: Optional[float] = 0.1, + load_balance_beam_width: Optional[int] = 8, + load_balance_group_swap_factor: Optional[float] = 0.4, + enable_kernel: Optional[bool] = False, + enable_comm_overlap: Optional[bool] = False, ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_experts = num_experts - self.use_kernel = MOE_MANAGER.use_kernel_optim + self.gated = mlp_gated + self.enable_kernel = enable_kernel + self.enable_comm_overlap = enable_comm_overlap self.expert_parallel = MOE_MANAGER.get_parallel() - self.gated = gated - assert self.expert_parallel in [ - "EP", - "TP", - None, - ], f"Unsupported expert parallel type {self.expert_parallel}" # moe router - noisy_func = get_noise_generator(noisy_policy, num_experts) - router_cls = get_router_cls(top_k) - self.topk = top_k + noisy_func = get_noise_generator(router_noisy_policy, num_experts) + router_cls = get_router_cls(router_top_k) + self.topk = router_top_k self.router: MoeRouter = router_cls( - 
diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py
index 9846cd432b53..bd2cefbe9ab8 100644
--- a/colossalai/moe/layers.py
+++ b/colossalai/moe/layers.py
@@ -8,7 +8,7 @@
 import torch.nn.functional as F

 from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter
-from colossalai.moe.experts import BaseMLPExperts, get_expert_class
+from colossalai.moe.experts import MLPExperts
 from colossalai.moe.load_balance import LoadBalancer
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.routers import MoeRouter, get_router_cls
@@ -48,50 +48,59 @@ class SparseMLP(nn.Module):
     def __init__(
         self,
         num_experts: int,
-        top_k: int = 1,
-        capacity_factor_train: float = 1.25,
-        capacity_factor_eval: float = 2.0,
-        min_capacity: int = 4,
-        noisy_policy: Optional[str] = None,
-        drop_tks: bool = True,
-        hidden_size: int = 2048,
-        intermediate_size: int = 2048,
-        activation: str = None,
-        gated: bool = False,
+        hidden_size: int,
+        intermediate_size: int,
+        router_top_k: int = 1,
+        router_capacity_factor_train: Optional[float] = 1.25,
+        router_capacity_factor_eval: Optional[float] = 2.0,
+        router_min_capacity: Optional[int] = 4,
+        router_noisy_policy: Optional[str] = None,
+        router_drop_tks: Optional[bool] = True,
+        mlp_activation: Optional[str] = None,
+        mlp_gated: Optional[bool] = False,
+        enable_load_balance: Optional[bool] = False,
+        load_balance_tolerance: Optional[float] = 0.1,
+        load_balance_beam_width: Optional[int] = 8,
+        load_balance_group_swap_factor: Optional[float] = 0.4,
+        enable_kernel: Optional[bool] = False,
+        enable_comm_overlap: Optional[bool] = False,
    ):
         super().__init__()
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_experts = num_experts
-        self.use_kernel = MOE_MANAGER.use_kernel_optim
+        self.gated = mlp_gated
+        self.enable_kernel = enable_kernel
+        self.enable_comm_overlap = enable_comm_overlap
         self.expert_parallel = MOE_MANAGER.get_parallel()
-        self.gated = gated
-        assert self.expert_parallel in [
-            "EP",
-            "TP",
-            None,
-        ], f"Unsupported expert parallel type {self.expert_parallel}"

         # moe router
-        noisy_func = get_noise_generator(noisy_policy, num_experts)
-        router_cls = get_router_cls(top_k)
-        self.topk = top_k
+        noisy_func = get_noise_generator(router_noisy_policy, num_experts)
+        router_cls = get_router_cls(router_top_k)
+        self.topk = router_top_k
         self.router: MoeRouter = router_cls(
-            capacity_factor_train=capacity_factor_train,
-            capacity_factor_eval=capacity_factor_eval,
-            min_capacity=min_capacity,
+            capacity_factor_train=router_capacity_factor_train,
+            capacity_factor_eval=router_capacity_factor_eval,
+            min_capacity=router_min_capacity,
             noisy_func=noisy_func,
-            drop_tks=drop_tks,
+            drop_tks=router_drop_tks,
         )

+        # gate
+        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))
+
         # moe experts
-        expert_cls = get_expert_class(self.expert_parallel)
-        self.experts: BaseMLPExperts = expert_cls(num_experts=num_experts,
-                                                  hidden_size=hidden_size,
-                                                  intermediate_size=intermediate_size,
-                                                  activation=activation,
-                                                  gated=gated,
-                                                  use_kernel=self.use_kernel)
+        self.experts = MLPExperts(
+            num_experts=self.num_experts,
+            expert_parallel=self.expert_parallel,
+            hidden_size=self.hidden_size,
+            intermediate_size=self.intermediate_size,
+            activation=mlp_activation,
+            gated=mlp_gated,
+            use_kernel=self.enable_kernel,
+        )
+
+        # get parallel settings
         if self.expert_parallel is not None:
             self.ep_group = get_ep_group(self.experts)
             self.ep_size = get_ep_size(self.experts)
@@ -101,11 +110,8 @@ def __init__(
             self.dp_group = None
         self.num_local_experts = self.experts.num_local_experts

-        # gate
-        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))
-
         # load balance
-        self.enable_load_balance = MOE_MANAGER.load_balance
+        self.enable_load_balance = enable_load_balance
         if self.enable_load_balance == True:
             self.load_balancer = LoadBalancer(
                 experts=self.experts,
@@ -114,9 +120,9 @@ def __init__(
                 expert_num=self.num_experts,
                 ep_group=self.ep_group,
                 dp_group=self.dp_group,
-                tolerance=MOE_MANAGER.tolerance,
-                beam_width=MOE_MANAGER.beam_width,
-                group_swap_factor=MOE_MANAGER.group_swap_factor,
+                tolerance=load_balance_tolerance,
+                beam_width=load_balance_beam_width,
+                group_swap_factor=load_balance_group_swap_factor,
             )

         # init param
@@ -147,14 +153,15 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
             with torch.no_grad():
                 # TODO: optimize computation
                 expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1]
+                # TODO: bincount introduces synchronize, fix it
                 expert_load = torch.bincount(expert_load.view(-1))
                 self.load_balancer.update_load(expert_load)

         # the result from the router
-        route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
+        route_result_list = self.router(inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)

         # dispatch_data: (num_experts, capacity, hidden_size)
-        if self.use_kernel:
+        if self.enable_kernel:
             dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
             dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size)
         else:
@@ -163,16 +170,16 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

         # expert_output: (num_groups, num_experts, capacity, hidden_size)
         if self.expert_parallel == "EP":
-            expert_output = self._ep_process(dispatch_data)
+            expert_output = self._ep_process(dispatch_data, overlap=self.enable_comm_overlap)
         elif self.expert_parallel == "TP":
-            expert_output = self._tp_process(dispatch_data)
+            expert_output = self._tp_process(dispatch_data, overlap=self.enable_comm_overlap)
         elif self.expert_parallel is None:
             expert_output = self._local_process(dispatch_data)
         else:
             raise NotImplementedError("This kind of communication has not been implemented yet.\n"
                                       "Please use Experts build function.")

-        if self.use_kernel:
+        if self.enable_kernel:
             expert_output = expert_output.reshape(-1, self.hidden_size)
             ans = MoeCombine.apply(expert_output, *route_result_list)
         else:
@@ -189,10 +196,7 @@ def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
         expert_out = self.experts(expert_in)
         return expert_out

-    def _ep_process(self,
-                    dispatch_data: torch.Tensor,
-                    overlap: bool = True
-                    ) -> torch.Tensor:
+    def _ep_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor:
         """
         Expert Parallel

@@ -210,16 +214,16 @@ def _ep_process(self,
             return expert_output
         else:

+            @dataclasses.dataclass
-            class Capsule():
+            class Capsule:
                 data: torch.Tensor
                 handle: Any = None

-            NUM_CHUNK = 2
+            NUM_CHUNK = 4
             NUM_STAGES = 4

-            assert dispatch_data.shape[1] % NUM_CHUNK == 0, \
-                "arbitrary chunk num is not supported yet"
+            assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
             chunk_size = dispatch_data.shape[1] // NUM_CHUNK
             input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
             dispatch_data = dispatch_data.reshape(*input_shape)
@@ -238,24 +242,17 @@ class Capsule():

                 # all2all last output
                 if _expert_out is not None:
-                    expert_out = Capsule(
-                        *AllToAll.apply(_expert_out.data, self.ep_group, True),
-                    )
+                    expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
                     _expert_out = None

                 # all2all next input
                 if 0 <= i < NUM_CHUNK:
-                    _expert_in = Capsule(
-                        *AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True)
-                    )
+                    _expert_in = Capsule(*AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True))

                 # compute
                 if expert_in is not None:
                     expert_in.handle.wait()
-                    _expert_out = Capsule(
-                        data=self.experts(expert_in.data),
-                        handle=None
-                    )
+                    _expert_out = Capsule(data=self.experts(expert_in.data), handle=None)
                     expert_in = None

                 if _expert_in is not None:
@@ -264,10 +261,7 @@ class Capsule():

             return output
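For readers who want the pipelining idea outside the class: the `Capsule` loop kicks off the all-to-all for chunk `i` asynchronously and runs the experts on chunk `i-1` while it is in flight. A stripped-down sketch of just that forward overlap (the real loop also overlaps the return all-to-all, omitted here; `experts` is any callable, and this is my illustration, not the patch's code):

```python
import torch
import torch.distributed as dist

def overlapped_expert_forward(chunks, experts, group):
    """Pipeline sketch: comm(chunk i) overlaps compute(chunk i-1)."""
    outputs, in_flight = [], None
    for i in range(len(chunks) + 1):
        arrived = None
        if in_flight is not None:
            data, handle = in_flight
            handle.wait()              # chunk i-1 finished its all-to-all
            arrived = data
        if i < len(chunks):
            buf = torch.empty_like(chunks[i])
            handle = dist.all_to_all_single(buf, chunks[i].contiguous(), group=group, async_op=True)
            in_flight = (buf, handle)  # comm for chunk i runs during compute below
        else:
            in_flight = None
        if arrived is not None:
            outputs.append(experts(arrived))   # compute overlaps the comm above
    return outputs
```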
-    def _tp_process(self,
-                    dispatch_data: torch.Tensor,
-                    overlap: bool = True
-                    ) -> torch.Tensor:
+    def _tp_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor:
         """
         without overlap:
         | C |
@@ -291,23 +285,24 @@ def _tp_process(self,
             expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0]
             return expert_out
         else:

+            @dataclasses.dataclass
-            class Capsule():
+            class Capsule:
                 data: torch.Tensor
                 handle: Any
                 indices: Tuple

-            NUM_CHUNK = 2
+            NUM_CHUNK = 4
             NUM_STAGES = 4

-            assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
-                "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
+            assert (dispatch_data.shape[0] % NUM_CHUNK == 0
+                    ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
             chunk_size = dispatch_data.shape[0] // NUM_CHUNK
             chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
             output = torch.empty_like(dispatch_data)

             def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
-                return (slice(idx * chunk_size, (idx + 1) * chunk_size), )
+                return (slice(idx * chunk_size, (idx + 1) * chunk_size),)

             _expert_in, expert_in, _expert_out, expert_out = None, None, None, None

@@ -321,7 +316,7 @@ def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
                 if _expert_out is not None:
                     expert_out = Capsule(
                         *ReduceScatter.apply(_expert_out.data, self.ep_group, True),
-                        indices=_expert_out.indices
+                        indices=_expert_out.indices,
                     )
                     _expert_out = None

@@ -329,7 +324,7 @@ def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
                 if 0 <= i < NUM_CHUNK:
                     _expert_in = Capsule(
                         *AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True),
-                        indices=get_chunk_slice(i, chunk_size)
+                        indices=get_chunk_slice(i, chunk_size),
                     )

                 # compute
@@ -337,7 +332,8 @@ def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
                     expert_in.handle.wait()
                     _expert_out = Capsule(
                         self.experts(expert_in.data, expert_in.indices),
-                        handle=None, indices=expert_in.indices
+                        handle=None,
+                        indices=expert_in.indices,
                     )
                     expert_in = None

@@ -360,4 +356,6 @@ def _apply_recursive(module: nn.Module):
                 sub_module.load_balancer.balance_load(optim)
             _apply_recursive(sub_module)

+    torch.cuda.empty_cache()
     _apply_recursive(model)
+    torch.cuda.empty_cache()
diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py
index b2fb672329c2..4a3d0fe4d096 100644
--- a/colossalai/moe/load_balance.py
+++ b/colossalai/moe/load_balance.py
@@ -7,7 +7,7 @@
 from torch.distributed import ProcessGroup

 from colossalai.cluster import ProcessGroupMesh
-from colossalai.moe.experts import BaseMLPExperts
+from colossalai.moe.experts import MLPExperts
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.zero.low_level import LowLevelZeroOptimizer

@@ -16,7 +16,7 @@ class LoadBalancer:
     def __init__(
         self,
-        experts: BaseMLPExperts,
+        experts: MLPExperts,
         gate: nn.Parameter,
         local_expert_num: int,
         expert_num: int,
@@ -26,7 +26,7 @@ def __init__(
         beam_width: Optional[int] = 8,
         group_swap_factor: Optional[float] = 0.4,
     ) -> None:
-        self.experts: BaseMLPExperts = experts
+        self.experts: MLPExperts = experts
         self.gate: nn.Parameter = gate
         self.moe_ep_group: ProcessGroup = ep_group
         self.moe_ep_ranks = MOE_MANAGER.parallel_info_dict[dist.get_world_size(self.moe_ep_group)].ep_group_ranks
@@ -423,6 +423,11 @@ def balance_load(self, optim: LowLevelZeroOptimizer) -> None:
         load = self._load_to_list(load)
         # search balance
         swap_list = self._search_balance(load)
+        if dist.get_rank() == 0:
+            if len(swap_list) > 0:
+                print(f"[Load Balance] Applying expert swap...")
+            else:
+                print(f"[Load Balance] Invalid swap, skip...")
         # swap expert and gate
         self._swap_moe_param(swap_list, optim)
         # clear load
diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py
index e3659ef43fbd..f237ea134638 100644
--- a/colossalai/moe/manager.py
+++ b/colossalai/moe/manager.py
@@ -14,29 +14,29 @@ class MoeManager(metaclass=SingletonMeta):
     """

     def __init__(self):
-        self.world_size = None
-        # Users may want to set maximum expert parallel size smaller than the world size
-        # since very low bandwidth across nodes may constrain the performance of MoE
-        # When we have a maximum expert parallel size, we have a minimum data parallel size naturally
-        self.max_ep_size = None
-        self.min_dp_size = None
-        self.router_aux_loss = []
-        self.router_z_loss = []
         self.parallel = None
         self.seed = None
         self.mode = None
-        self.use_kernel_optim = False
         self.use_ep_inside = None
+        self.world_size = None
+        self._parallel_info_dict = dict()
+
+        # router
+        self.router_aux_loss = []
+        self.router_z_loss = []
+
+        # fixed mode
         self.pp_size = None
+        self.dp_size = None
+        self.ep_size = None

-        # load balance param
-        self.load_balance = None
-        self.tolerance = None
-        self.beam_width = None
-        self.group_swap_factor = None
+        # dynamic mode
+        # Users may want to set maximum expert parallel size smaller than the world size
+        # since very low bandwidth across nodes may constrain the performance of MoE
+        # When we have a maximum expert parallel size, we have a minimum data parallel size naturally
+        self.max_ep_size = None

         self.has_setup = False
-        self._parallel_info_dict = dict()

     @property
     def parallel_info_dict(self):
@@ -49,7 +49,6 @@ def is_initialized(self):
     def setup(
         self,
         seed: int,
-        use_kernel_optim: bool = False,
         parallel: str = None,
         mode: str = "dynamic",
         max_ep_size: int = 8,
@@ -57,10 +56,6 @@ def setup(
         fixed_ep_size: int = 0,
         fixed_pp_size: int = 0,
         use_ep_inside: bool = True,
-        enable_load_balance: bool = False,
-        tolerance: float = 0.1,
-        beam_width: int = 8,
-        group_swap_factor: float = 0.4,
     ) -> None:
         """
         Setup MoE distributed context.
@@ -78,38 +73,28 @@ def setup(
             fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
             use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True.
         """
-        assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
+        assert (not self.is_initialized), "MoE distributed context shouldn't be set up again"
         assert torch.cuda.is_available(), "MoE requires to enable CUDA first"

-        self.world_size = dist.get_world_size()
         self.seed = seed + dist.get_rank()
         self.parallel = parallel
         self.use_ep_inside = use_ep_inside
+        self.world_size = dist.get_world_size()

         # init by mode
         self.mode = mode
         assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic"
         if self.mode == "dynamic":
-            self.max_ep_size = min(max_ep_size, dist.get_world_size())
-            self.min_dp_size = self.world_size // self.max_ep_size
+            self.max_ep_size = min(max_ep_size, self.world_size)
         else:
-            assert fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0, "dp_size, ep_size and pp_size should be greater than 0"
-            assert isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance(
-                fixed_pp_size, int), "dp_size, ep_size and pp_size should be int"
+            assert (fixed_dp_size > 0 and fixed_ep_size > 0
+                    and fixed_pp_size > 0), "dp_size, ep_size and pp_size should be greater than 0"
+            assert (isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int)
+                    and isinstance(fixed_pp_size, int)), "dp_size, ep_size and pp_size should be int"
             self.ep_size = fixed_ep_size
             self.dp_size = fixed_dp_size
             self.pp_size = fixed_pp_size

-        # Enabling kernel optimization may raise error in some cases
-        # Users can close kernel optimization manually
-        self.use_kernel_optim = use_kernel_optim
-
-        # update load balance
-        self.load_balance = enable_load_balance
-        self.tolerance = tolerance
-        self.beam_width = beam_width
-        self.group_swap_factor = group_swap_factor
-
         self.has_setup = True
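A hedged usage sketch of the slimmed-down `setup` in its two modes (values are illustrative; the singleton asserts if set up twice, hence the second call is commented out):

```python
from colossalai.moe.manager import MOE_MANAGER

# dynamic mode: only cap the expert-parallel size; dp is derived per layer
MOE_MANAGER.setup(seed=42, parallel="EP", mode="dynamic", max_ep_size=8)

# fixed mode: spell out every axis explicitly
# MOE_MANAGER.setup(seed=42, parallel="EP", mode="fixed",
#                   fixed_dp_size=2, fixed_ep_size=2, fixed_pp_size=2)
```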
@@ -127,21 +112,13 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara
         """
         if self.mode == "dynamic":
-            gt_flag = num_experts % self.max_ep_size == 0    # check whether num_experts is greater
-            lt_flag = self.max_ep_size % num_experts == 0    # check whether num_experts is less
-
-            assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \
-                " is not a multiple of ep size or vice versa."
-
-            # If the number of experts is greater than maximum expert parallel size. a.k.a ep_size,
-            # there are multiple experts in each GPU and each GPU has different experts
-            # So it's data parallel size is 1
-            # Otherwise, there is only one expert in each GPU
-            # The data parallel size should be calculated
-            dp_size = 1 if gt_flag else self.max_ep_size // num_experts
-            ep_size = self.max_ep_size // dp_size
-            # Don't forget to multiply minimum data parallel size
-            dp_size *= self.min_dp_size
+            gt_flag = (num_experts % self.max_ep_size == 0)    # check whether num_experts is greater
+            lt_flag = (self.max_ep_size % num_experts == 0)    # check whether num_experts is less
+            assert gt_flag or lt_flag, ("Automatic experts placement does not support expert number"
+                                        " that is not a multiple of ep size or vice versa.")
+            dp_size = 1 if gt_flag else self.world_size // num_experts
+            ep_size = min(self.world_size // dp_size, self.max_ep_size)
+            dp_size = self.world_size // ep_size
             pp_size = 1
         else:
             dp_size = self.dp_size
@@ -167,13 +144,10 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara

         return num_local_experts, self.parallel_info_dict[ep_size]

-    def set_kernel_not_use(self):
-        self.use_kernel_optim = False
-
     def reset_loss(self):
         self.router_aux_loss, self.router_z_loss = [], []

-    def add_loss(self, aux_loss: float = 0., z_loss: float = 0.):
+    def add_loss(self, aux_loss: float = 0.0, z_loss: float = 0.0):
         self.router_aux_loss.append(aux_loss)
         self.router_z_loss.append(z_loss)
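A quick worked example of the new dynamic sizing (numbers are illustrative), mirroring `get_info`'s dynamic branch:

```python
def dynamic_sizes(world_size: int, max_ep_size: int, num_experts: int):
    # mirrors MoeManager.get_info's dynamic branch
    gt_flag = num_experts % max_ep_size == 0
    dp_size = 1 if gt_flag else world_size // num_experts
    ep_size = min(world_size // dp_size, max_ep_size)
    return world_size // ep_size, ep_size

print(dynamic_sizes(8, 8, 4))    # (2, 4): 4 experts, each shared by a pair of ranks
print(dynamic_sizes(8, 8, 16))   # (1, 8): two local experts per rank, no extra dp
```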
diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py
index 1ac66f7bb78f..7960a74d4539 100644
--- a/colossalai/moe/routers.py
+++ b/colossalai/moe/routers.py
@@ -30,7 +30,8 @@ def __init__(self,
                  capacity_factor_eval: float,
                  min_capacity: int,
                  noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
+                 drop_tks: bool = True,
+                 use_kernel: bool = False):
         super().__init__()
         self.k_value = k_value
         self.capacity_factor_train = capacity_factor_train
@@ -40,6 +41,7 @@ def __init__(self,
         self.drop_tks = drop_tks
         self._aux_loss = None
         self._z_loss = None
+        self.use_kernel = use_kernel

     def get_capacity(self, logits_shape):
         capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
@@ -49,11 +51,7 @@ def get_capacity(self, logits_shape):
         assert capacity > 0
         return int(capacity)

-    def set_aux_loss(self,
-                     router_probs: torch.Tensor,
-                     expert_indices: torch.Tensor,
-                     num_experts: int
-                     ) -> None:
+    def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None:
         """Computes auxiliary load balancing loss as in Switch Transformer.

         See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
@@ -81,8 +79,8 @@ def set_aux_loss(self,
         tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2)
         router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2)
-        aux_loss = num_experts**2 * torch.mean(
-            tokens_per_group_and_expert * router_prob_per_group_and_expert)
+        aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert)
         self._aux_loss = aux_loss

     def set_z_loss(self, router_logits: torch.Tensor):
@@ -101,8 +98,7 @@ def set_z_loss(self, router_logits: torch.Tensor):
         assert router_logits.dim() == 3, "router_logits must be 3D tensor"
         num_groups, tokens_per_group, _ = router_logits.shape
         log_z = torch.logsumexp(router_logits, dim=-1)
-        z_loss = torch.sum(log_z**2, dtype=torch.float32
-                           ) / (num_groups * tokens_per_group)
+        z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group)
         self._z_loss = z_loss

     def pop_router_loss(self) -> torch.Tensor:
@@ -113,8 +109,8 @@ def pop_router_loss(self) -> torch.Tensor:

 class Top1Router(MoeRouter):
-    """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) 
-    and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed 
+    """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
+    and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed
     function can be found in the paper about Switch Transformer of Google.

     Args:
@@ -142,22 +138,17 @@ def __init__(self,
         self.select_policy = select_policy
         assert select_policy in {"first", "random"}
         if select_policy == "random":
-            self.uniform = torch.distributions.uniform.Uniform(
-                low=torch.tensor(0.0, device=get_current_device()),
-                high=torch.tensor(1.0, device=get_current_device())
-            ).rsample
-
-    def forward(self,
-                inputs: torch.Tensor,
-                use_kernel: bool = False,
-                ep_group: Optional[ProcessGroup] = None
-                ) -> Tuple:
+            self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
+                                                               high=torch.tensor(1.0,
+                                                                                 device=get_current_device())).rsample
+
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
         """
         Args:
             inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).

         Returns:
-            1. use_kernel is False: 
+            1. use_kernel is False:
                 The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
                 The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
             2. use_kernel is True:
@@ -188,9 +179,9 @@ def forward(self,
             rand_mask = mask * self.uniform(mask.shape)
             _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
             mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
-            ranks = moe_cumsum(mask)
+            ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
         elif self.select_policy == "first":
-            ranks = moe_cumsum(mask)
+            ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
             mask = mask * torch.lt(ranks, capacity)
         else:
             raise NotImplementedError("Not support such select policy yet.")
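In math form, the two losses computed by `set_aux_loss` and `set_z_loss` above are, per group (my transcription of the code; $E$ experts, $f_e$ the fraction of tokens routed to expert $e$, $p_e$ its mean router probability, $N$ groups, $T$ tokens per group, logits $x_{n,t,e}$):

```latex
\mathcal{L}_{\mathrm{aux}} = E^{2}\cdot\operatorname{mean}_{e}\!\left(f_e\,p_e\right)
                           = E\sum_{e=1}^{E} f_e\,p_e,
\qquad
\mathcal{L}_{z} = \frac{1}{N\,T}\sum_{n=1}^{N}\sum_{t=1}^{T}
                  \Bigl(\log\sum_{e=1}^{E}\exp x_{n,t,e}\Bigr)^{2}
```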
@@ -211,8 +202,8 @@ def forward(self,

 class Top2Router(MoeRouter):
-    """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) 
-    and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed 
+    """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
+    and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed
     function can be found in the paper about ViT-MoE.

     Args:
@@ -236,17 +227,13 @@ def __init__(self,
                          noisy_func=noisy_func,
                          drop_tks=drop_tks)

-    def forward(self,
-                inputs: torch.Tensor,
-                use_kernel: bool = False,
-                ep_group: Optional[ProcessGroup] = None
-                ) -> Tuple:
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
         """
         Args:
             inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).

         Returns:
-            1. use_kernel is False: 
+            1. use_kernel is False:
                 The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
                 The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
             2. use_kernel is True:
@@ -280,8 +267,8 @@ def forward(self,
             dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()

-        rank1 = moe_cumsum(mask1)    # rank1: [s, e]
-        rank2 = moe_cumsum(mask2)
+        rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel)    # rank1: [s, e]
+        rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
         rank2 += torch.sum(mask1, dim=-2, keepdim=True)

         mask1 *= torch.lt(rank1, capacity)
@@ -313,7 +300,7 @@ def forward(self,
         weight1 = mask1 * probs.type_as(inputs)
         weight2 = mask2 * probs.type_as(inputs)

-        cb_weight = torch.zeros(inputs.shape + (capacity, ), device=inputs.device)
+        cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device)
         sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool)
         indices = torch.arange(0, inputs.shape[0], device=inputs.device)
         cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]]
@@ -348,17 +335,14 @@ def __init__(self,
                  min_capacity: int = 4,
                  noisy_func: Optional[Callable] = None,
                  drop_tks: bool = True):
-        super().__init__(num_selected_experts,
-                         capacity_factor_train,
-                         capacity_factor_eval,
-                         min_capacity,
-                         noisy_func,
+        super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func,
                          drop_tks)

-    def forward(self,
-                router_probs: torch.Tensor,
-                expert_capacity: int,
-                ) -> Tuple:
+    def forward(
+        self,
+        router_probs: torch.Tensor,
+        expert_capacity: int,
+    ) -> Tuple:
         """Computes masks for the top-k experts per token.

         Args:
@@ -418,17 +402,12 @@ def forward(self,
         # The combine array will be used for combining expert outputs, scaled by the
         # router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
         # expert_capacity].
-        combine_array = torch.einsum(
-            '...te,...tec->...tec',
-            router_probs,
-            dispatch_mask)
+        combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask)

         return combine_array, dispatch_mask


-def get_router_cls(top_k: int,
-                   grouped: bool = False
-                   ) -> MoeRouter:
+def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter:
     if not grouped:
         if top_k == 1:
             return Top1Router
diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py
index e3bc6d3cac9a..0938e4206fda 100644
--- a/colossalai/moe/utils.py
+++ b/colossalai/moe/utils.py
@@ -1,5 +1,5 @@
 import contextlib
-from typing import Callable, Dict, List
+from typing import Any, Callable, Dict, List

 import torch
 import torch.distributed as dist
@@ -170,3 +170,8 @@ def sync_moe_model_param(model: nn.Module):
         for param in param_dict[ep_size]:
             src_rank = get_dp_group_ranks(param)[0]
             dist.broadcast(param, src=src_rank, group=get_dp_group(param))
+
+
+def set_moe_args(config: Any, args: dict):
+    for k, v in args.items():
+        setattr(config, k, v)
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index f08ebea58589..59c302103729 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -7,7 +7,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch._utils import _flatten_dense_tensors
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer

@@ -281,6 +281,40 @@ def _run_reduction(self):
         if self.extra_dp_pg is None:
             flat_grads = self._bucket_store.get_flatten_grad()
             flat_grads /= self._world_size
+        else:
+            # record moe and non moe param
+            moe_list = []
+            for param in self._bucket_store._param_list:
+                moe_list.append(is_moe_tensor(param))
+
+            # divide them into different groups
+            moe_grad_list = []
+            non_moe_grad_list = []
+            for grad_list in self._bucket_store._grad_in_bucket.values():
+                non_moe_cur_grad = []
+                moe_cur_grad = []
+                for i in range(len(grad_list)):
+                    if moe_list[i] == True:
+                        moe_cur_grad.append(grad_list[i])
+                    else:
+                        non_moe_cur_grad.append(grad_list[i])
+                if len(moe_cur_grad) > 0:
+                    moe_grad_list.append(moe_cur_grad)
+                if len(non_moe_cur_grad) > 0:
+                    non_moe_grad_list.append(non_moe_cur_grad)
+
+            if len(non_moe_grad_list) > 0:
+                non_moe_flat_grads = []
+                for grad_list in non_moe_grad_list:
+                    non_moe_flat_grads.append(_flatten_dense_tensors(grad_list))
+                non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads)
+                non_moe_flat_grads /= self._world_size
+
+            if len(moe_grad_list) > 0:
+                moe_flat_grads = []
+                for grad_list in moe_grad_list:
+                    moe_flat_grads.append(_flatten_dense_tensors(grad_list))
+                moe_flat_grads = _flatten_dense_tensors(moe_flat_grads)

         # ready to add other tensors to bucket
         self._bucket_store.reset_num_elements_in_bucket()
@@ -290,6 +324,11 @@ def _run_reduction(self):
                 # in case of the memory being reused in the default stream
                 if self.extra_dp_pg is None:
                     flat_grads.record_stream(stream)
+                else:
+                    if len(non_moe_grad_list) > 0:
+                        non_moe_flat_grads.record_stream(stream)
+                    if len(moe_grad_list) > 0:
+                        moe_flat_grads.record_stream(stream)
                 # waiting for ops in the default stream finishing
                 stream.wait_stream(torch.cuda.current_stream())
             else:
@@ -324,64 +363,70 @@ def _run_reduction(self):

                 # sync extra zero group
                 else:
-                    # record moe and non moe param
-                    moe_list = []
-                    for param in self._bucket_store._param_list:
-                        moe_list.append(is_moe_tensor(param))
-
-                    # divide them into different groups
-                    moe_grad_list = []
-                    non_moe_grad_list = []
-                    for grad_list in self._bucket_store._grad_in_bucket.values():
-                        non_moe_cur_grad = []
-                        moe_cur_grad = []
-                        for i in range(len(grad_list)):
-                            if moe_list[i] == True:
-                                moe_cur_grad.append(grad_list[i])
-                            else:
-                                non_moe_cur_grad.append(grad_list[i])
-                        if len(moe_cur_grad) > 0:
-                            moe_grad_list.append(moe_cur_grad)
-                        if len(non_moe_cur_grad) > 0:
-                            non_moe_grad_list.append(non_moe_cur_grad)
-
                     # sync non moe param in global dp group
                     if len(non_moe_grad_list) > 0:
-                        flat_grads = []
-                        for grad_list in non_moe_grad_list:
-                            flat_grads.append(_flatten_dense_tensors(grad_list))
-                        flat_grads = _flatten_dense_tensors(flat_grads)
-                        flat_grads /= self._world_size
-                        dist.all_reduce(flat_grads, group=self.dp_pg)
-                        flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
+                        dist.all_reduce(non_moe_flat_grads, group=self.dp_pg)
+                        flat_grads_per_rank = non_moe_flat_grads.split(non_moe_flat_grads.numel() //
+                                                                       self._world_size)
                         self._sync_unpartitioned_grad(non_moe_grad_list, flat_grads_per_rank, group_id)

                     # sync moe param only in zero group
                     if len(moe_grad_list) > 0:
-                        flat_grads = []
-                        for grad_list in moe_grad_list:
-                            flat_grads.append(_flatten_dense_tensors(grad_list))
-                        flat_grads = _flatten_dense_tensors(flat_grads)
-                        dist.all_reduce(flat_grads, group=self.extra_dp_pg)
-                        flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
+                        dist.all_reduce(moe_flat_grads, group=self.extra_dp_pg)
+                        flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size)
                         self._sync_unpartitioned_grad(moe_grad_list, flat_grads_per_rank, group_id)
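An aside on the idea behind this restructuring: the bucket's gradients are partitioned into MoE and non-MoE groups once and flattened early, so each group can then be reduced over its own process group (global dp for dense params, the smaller `extra_dp_pg` for expert params). A minimal stand-alone sketch of that partition-and-flatten step (`is_moe_flags` is a stand-in predicate list, not the library API):

```python
import torch
from torch._utils import _flatten_dense_tensors

def split_and_flatten(grads, is_moe_flags):
    # partition one bucket's gradients by ownership, then flatten each group
    moe = [g for g, m in zip(grads, is_moe_flags) if m]
    dense = [g for g, m in zip(grads, is_moe_flags) if not m]
    moe_flat = _flatten_dense_tensors(moe) if moe else None
    dense_flat = _flatten_dense_tensors(dense) if dense else None
    return dense_flat, moe_flat

grads = [torch.randn(4), torch.randn(2), torch.randn(3)]
dense_flat, moe_flat = split_and_flatten(grads, [False, True, False])
# dense_flat would be all-reduced over the global dp group, moe_flat over extra_dp_pg
```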
             else:
-                flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
-                recieved_grad = torch.zeros_like(flat_grads_list[0])
-                dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
-
-                if recieved_grad.dtype != grad_dtype:
-                    recieved_grad = recieved_grad.to(grad_dtype)
-
-                grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
-                sync_tensor(recieved_grad, grad_in_bucket_current_rank)
-                for grad in grad_in_bucket_current_rank:
-                    param_id = self._bucket_store.get_param_id_of_grad(grad)
-                    if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1:
-                        self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
-                    else:
-                        self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)
+                if self.extra_dp_pg is None:
+                    flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
+                    recieved_grad = torch.zeros_like(flat_grads_list[0])
+                    dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
+
+                    if recieved_grad.dtype != grad_dtype:
+                        recieved_grad = recieved_grad.to(grad_dtype)
+
+                    grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
+                    sync_tensor(recieved_grad, grad_in_bucket_current_rank)
+                    for grad in grad_in_bucket_current_rank:
+                        param_id = self._bucket_store.get_param_id_of_grad(grad)
+                        if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1:
+                            self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+                        else:
+                            self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)
+                else:
+                    if len(non_moe_grad_list) > 0:
+                        flat_grads_list = list(non_moe_flat_grads.split(
+                            len(non_moe_flat_grads) // self._world_size))
+                        recieved_grad = torch.zeros_like(flat_grads_list[0])
+                        dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
+
+                        grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
+                        sync_tensor(recieved_grad, grad_in_bucket_current_rank)
+                        for grad in grad_in_bucket_current_rank:
+                            param_id = self._bucket_store.get_param_id_of_grad(grad)
+                            if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1:
+                                self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+                            else:
+                                self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)
+
+                    if len(moe_grad_list) > 0:
+                        flat_grads_list = list(moe_flat_grads.split(len(moe_flat_grads) // self.extra_dp_pg_size))
+                        recieved_grad = torch.zeros_like(flat_grads_list[0])
+                        dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.extra_dp_pg)
+
+                        param_slice = self._world_size // self.extra_dp_pg_size
+                        recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
+                        grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
+                        for split_recieved_grad in recieved_grad:
+                            split_recieved_grad = _unflatten_dense_tensors(split_recieved_grad,
+                                                                           grad_in_bucket_current_rank)
+                            for grad in grad_in_bucket_current_rank:
+                                param_id = self._bucket_store.get_param_id_of_grad(grad)
+                                if len(self._grad_store.get_partitioned_gradients_by_param_id(
+                                        group_id, param_id)) < param_slice:
+                                    self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+                                else:
+                                    self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)

         self._bucket_store.reset()

@@ -512,18 +557,19 @@ def step(self, closure=None):
                     # moe hybrid zero
                     if self.extra_dp_pg is not None and is_moe_tensor(working_param):
                         real_working_params[group_id].append(working_param)
-                        param_slice = self._world_size // self.extra_dp_pg_size
-                        grad = grads[self.extra_dp_pg_rank * param_slice:(self.extra_dp_pg_rank + 1) * param_slice]
-                        grad = flatten(grad).to(splited_param.dtype).to(splited_param.device)
-                        splited_param.grad = grad
-                        grad_partition_groups.append(grad)
-                        real_master_params[group_id].append(splited_param)
+                        if self._partition_grads:
+                            grad = grads
+                        else:
+                            param_slice = self._world_size // self.extra_dp_pg_size
+                            grad = grads[self.extra_dp_pg_rank * param_slice:(self.extra_dp_pg_rank + 1) * param_slice]
+                            grad = flatten(grad)
                     else:
                         real_working_params[group_id].append(working_param)
-                        grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
-                        splited_param.grad = grad
-                        grad_partition_groups.append(grad)
-                        real_master_params[group_id].append(splited_param)
+                        grad = grads[grad_index]
+                    grad = grad.to(splited_param.dtype).to(splited_param.device)
+                    splited_param.grad = grad
+                    grad_partition_groups.append(grad)
+                    real_master_params[group_id].append(splited_param)

                 # compute norm
                 working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
@@ -539,13 +585,21 @@ def step(self, closure=None):
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
         self._unscale_and_clip_grads(grad_partition_groups, global_norm)

+        # TODO: we should store master param for ep
+        if len(self.param_groups) > len(self._working_param_groups):
+            for param in self.param_groups[-1]['params']:
+                param.data = param.data.to(torch.float32)
+                param.grad = param.grad.to(torch.float32)
+
         # update the parameters
         self.optim.step()

-        # release the moe grad
+        # TODO: release the moe grad. we should store master param
         if len(self.param_groups) > len(self._working_param_groups):
+            dtype = real_working_params[0][0].dtype
             for param in self.param_groups[-1]['params']:
                 param.grad = None
+                param.data = param.data.to(dtype)

         # release the grad
         grad_partition_groups = []
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py
index e1acba5c88b0..2f6bfa0f89a2 100644
--- a/examples/language/openmoe/benchmark/benchmark_cai.py
+++ b/examples/language/openmoe/benchmark/benchmark_cai.py
@@ -15,12 +15,12 @@
 import colossalai
 from colossalai import get_default_parser
 from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.moe.layers import apply_load_balance
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import skip_init
+from colossalai.moe.utils import set_moe_args, skip_init
+from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device

@@ -118,7 +118,7 @@ def parse_args():
     parser.add_argument("--pp_size", type=int, default=2, help="pp size")
     parser.add_argument("--dp_size", type=int, default=1, help="dp size")
     parser.add_argument("--ep_size", type=int, default=2, help="ep size")
-    parser.add_argument("--zero_stage", type=int, default=1, help="zero stage in hybrid plugin")
+    parser.add_argument("--zero_stage", type=int, default=2, help="zero stage in hybrid plugin")
     parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size")
     parser.add_argument("--extra_dp_size", type=int, default=1)
     # kernel
@@ -132,6 +132,9 @@ def parse_args():
     parser.add_argument("--active", type=int, default=20)
     # load balance
     parser.add_argument("--load_balance", action="store_true")
+
+    # overlap
+    parser.add_argument("--overlap_alltoall", action="store_true")
     args = parser.parse_args()
     return args

@@ -150,49 +153,38 @@ def main():
         "custom_policy": OpenMoeForCausalLMPolicy(),
         "enable_fused_normalization": args.use_kernel,
         "enable_jit_fused": args.use_kernel,
-        "precision": "bf16"
+        "precision": "bf16",
+        "zero_stage": args.zero_stage,
+    }
+    mgr_dict = {
+        "seed": 42,
     }
-    mgr_dict = {"seed": 42, "use_kernel_optim": args.use_kernel, "enable_load_balance": args.load_balance}
     if args.plugin == "zero":
-        dp_size = dist.get_world_size()
-        plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2)
-        MOE_MANAGER.setup(
-            parallel=None,
-            **mgr_dict,
-        )
-    elif args.plugin == "ep":
         dp_size = dist.get_world_size()
         plugin = MoeHybridParallelPlugin(
             pp_size=1,
-            zero_stage=2,
             **hybrid_dict,
         )
         MOE_MANAGER.setup(
-            parallel="EP",
+            parallel=None,
             **mgr_dict,
         )
-    elif args.plugin == "ep_zero":
+    elif args.plugin == "ep":
         dp_size = dist.get_world_size()
-        use_ep_inside = False
         plugin = MoeHybridParallelPlugin(
             pp_size=1,
-            zero_stage=1,
-            extra_dp_size=args.extra_dp_size,
-            use_ep_inside=use_ep_inside,
             **hybrid_dict,
         )
         MOE_MANAGER.setup(
             parallel="EP",
-            max_ep_size=dp_size // args.extra_dp_size,
-            use_ep_inside=use_ep_inside,
+            max_ep_size=dp_size,
             **mgr_dict,
         )
-    elif args.plugin == "zero_ep":
+    elif args.plugin == "ep_zero":
         dp_size = dist.get_world_size()
-        use_ep_inside = True
+        use_ep_inside = False
         plugin = MoeHybridParallelPlugin(
             pp_size=1,
-            zero_stage=1,
             extra_dp_size=args.extra_dp_size,
             use_ep_inside=use_ep_inside,
             **hybrid_dict,
         )
@@ -226,10 +218,28 @@ def main():
     # Build OpenMoe model
repo_name = "hpcaitech/openmoe-" + args.model_name config = LlamaConfig.from_pretrained(repo_name) - setattr(config, "router_aux_loss_factor", 0.1) - setattr(config, "router_z_loss_factor", 0.1) - setattr(config, "label_smoothing", 0.1) - setattr(config, "z_loss_factor", 0.1) + moe_args = { + "num_experts": config.num_experts, + "moe_layer_interval": config.moe_layer_interval, + "router_topk": 2, + "router_capacity_factor_train": 1.25, + "router_capacity_factor_eval": 2.0, + "router_min_capacity": 4, + "router_noisy_policy": None, + "router_drop_tks": True, + "router_aux_loss_factor": 0.01, + "router_z_loss_factor": 0.01, + "mlp_gated": True, + "label_smoothing": 0.001, + "z_loss_factor": 0.01, + "enable_load_balance": args.load_balance, + "load_balance_tolerance": 0.1, + "load_balance_beam_width": 8, + "load_balance_group_swap_factor": 0.4, + "enable_kernel": args.use_kernel, + "enable_comm_overlap": args.overlap_alltoall, + } + set_moe_args(config, moe_args) with skip_init(): model = OpenMoeForCausalLM(config) coordinator.print_on_master(f"Finish init model with config:\n{config}") @@ -247,7 +257,7 @@ def main(): dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size) # Set optimizer - optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5) + optimizer = HybridAdam(model.parameters(), weight_decay=0.01, lr=1e-5) model_numel = get_model_numel(model) performance_evaluator = PerformanceEvaluator( @@ -259,8 +269,8 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) load_ckpt(repo_name, model, booster) + model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) use_pipeline = (isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1) is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() coordinator.print_on_master(f"Finish init booster") @@ -302,8 +312,8 @@ def main(): optimizer.zero_grad() performance_evaluator.on_step_end(exmaple_data["input_ids"]) if (step == args.warmup // 2) and args.load_balance: - apply_load_balance(model, optimizer) coordinator.print_on_master(f"Apply load balance") + apply_load_balance(model, optimizer) performance_evaluator.on_fit_end() diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh index ec4490faa55d..f269e260d8db 100755 --- a/examples/language/openmoe/benchmark/benchmark_cai.sh +++ b/examples/language/openmoe/benchmark/benchmark_cai.sh @@ -2,10 +2,10 @@ set -xue -NUM_GPU=4 +NUM_GPU=8 MODEL="8b" SEQ_LENGTH=2048 -WARMUP=8 +WARMUP=20 ACTIVE=4 # HACK: make model importable @@ -16,51 +16,50 @@ else export PYTHONPATH=$example_dir:$PYTHONPATH fi -# zero -torchrun --standalone --nproc_per_node $NUM_GPU \ - $example_dir/benchmark/benchmark_cai.py \ - --model_name $MODEL \ - --batch_size 4 \ - --seq_length $SEQ_LENGTH \ - --warmup $WARMUP \ - --active $ACTIVE \ - --plugin zero \ - --use_kernel # ep +echo -e "\n\n Naive EP \n\n" torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ - --batch_size 12 \ + --batch_size 8 \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ --plugin ep \ - --use_kernel + --zero_stage 2 + # ep_zero +echo -e "\n\n EP-ZERO \n\n" torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ 
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh
index ec4490faa55d..f269e260d8db 100755
--- a/examples/language/openmoe/benchmark/benchmark_cai.sh
+++ b/examples/language/openmoe/benchmark/benchmark_cai.sh
@@ -2,10 +2,10 @@

 set -xue

-NUM_GPU=4
+NUM_GPU=8
 MODEL="8b"
 SEQ_LENGTH=2048
-WARMUP=8
+WARMUP=20
 ACTIVE=4

 # HACK: make model importable
@@ -16,51 +16,50 @@ else
     export PYTHONPATH=$example_dir:$PYTHONPATH
 fi

-# zero
-torchrun --standalone --nproc_per_node $NUM_GPU \
-    $example_dir/benchmark/benchmark_cai.py \
-    --model_name $MODEL \
-    --batch_size 4 \
-    --seq_length $SEQ_LENGTH \
-    --warmup $WARMUP \
-    --active $ACTIVE \
-    --plugin zero \
-    --use_kernel

 # ep
+echo -e "\n\n Naive EP \n\n"
 torchrun --standalone --nproc_per_node $NUM_GPU \
     $example_dir/benchmark/benchmark_cai.py \
     --model_name $MODEL \
-    --batch_size 12 \
+    --batch_size 8 \
     --seq_length $SEQ_LENGTH \
     --warmup $WARMUP \
     --active $ACTIVE \
     --plugin ep \
-    --use_kernel
+    --zero_stage 2
+

 # ep_zero
+echo -e "\n\n EP-ZERO \n\n"
 torchrun --standalone --nproc_per_node $NUM_GPU \
     $example_dir/benchmark/benchmark_cai.py \
     --model_name $MODEL \
-    --batch_size 12 \
+    --batch_size 16 \
     --seq_length $SEQ_LENGTH \
     --warmup $WARMUP \
     --active $ACTIVE \
     --plugin ep_zero \
     --use_kernel \
-    --extra_dp_size 2
+    --extra_dp_size 2 \
+    --zero_stage 1 \
+    --load_balance

-# zero_ep
+echo -e "\n\n EP-ZERO + Overlap \n\n"
 torchrun --standalone --nproc_per_node $NUM_GPU \
     $example_dir/benchmark/benchmark_cai.py \
     --model_name $MODEL \
-    --batch_size 12 \
+    --batch_size 16 \
     --seq_length $SEQ_LENGTH \
     --warmup $WARMUP \
     --active $ACTIVE \
-    --plugin zero_ep \
+    --plugin ep_zero \
     --use_kernel \
-    --extra_dp_size 2
+    --extra_dp_size 2 \
+    --zero_stage 1 \
+    --load_balance \
+    --overlap_alltoall
+

 # hybrid
 torchrun --standalone --nproc_per_node $NUM_GPU \
diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py
index 0edf102d640c..531e18313798 100644
--- a/examples/language/openmoe/benchmark/benchmark_fsdp.py
+++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py
@@ -12,7 +12,6 @@
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import Dataset
 from torch.utils.data.distributed import DistributedSampler
-from transformers import Adafactor
 from transformers.models.llama import LlamaConfig
 from utils import PerformanceEvaluator, get_model_numel

@@ -80,7 +79,7 @@ def fsdp_main(rank, world_size, args):
         auto_wrap_policy=auto_wrap_policy,
         device_id=torch.cuda.current_device(),
     )
-    optimizer = Adafactor(model.parameters())
+    optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5)

     model.train()
     model_numel = get_model_numel(model)
diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/examples/language/openmoe/benchmark/benchmark_fsdp.sh
index e1eb2a9c6053..0380ee1ade20 100755
--- a/examples/language/openmoe/benchmark/benchmark_fsdp.sh
+++ b/examples/language/openmoe/benchmark/benchmark_fsdp.sh
@@ -6,8 +6,8 @@ NUM_GPU=8
 MODEL="8b"
 BATCH_SIZE=1
 SEQ_LENGTH=2048
-WARMUP=6
-ACTIVE=3
+WARMUP=8
+ACTIVE=4

 # HACK: make model importable
 example_dir=$(dirname $(realpath $(dirname $0)))
diff --git a/examples/language/openmoe/benchmark/hostfile.txt b/examples/language/openmoe/benchmark/hostfile.txt
new file mode 100644
index 000000000000..994b3e2cfc4f
--- /dev/null
+++ b/examples/language/openmoe/benchmark/hostfile.txt
@@ -0,0 +1,2 @@
+host1
+host2
diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py
index f59772189827..1ad1456b9c56 100644
--- a/examples/language/openmoe/infer.py
+++ b/examples/language/openmoe/infer.py
@@ -5,6 +5,8 @@
 from transformers import T5Tokenizer
 from transformers.models.llama import LlamaConfig

+from colossalai.moe.utils import set_moe_args
+

 def parse_args():
     parser = ArgumentParser()
@@ -17,9 +19,54 @@ def inference(args):
     tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
     if args.model == "test":
         config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
+        moe_args = {
+            "num_experts": config.num_experts,
+            "moe_layer_interval": config.moe_layer_interval,
+            "router_topk": 2,
+            "router_capacity_factor_train": 1.25,
+            "router_capacity_factor_eval": 2.0,
+            "router_min_capacity": 4,
+            "router_noisy_policy": None,
+            "router_drop_tks": True,
+            "router_aux_loss_factor": 0.01,
+            "router_z_loss_factor": 0.01,
+            "mlp_gated": True,
+            "label_smoothing": 0.001,
+            "z_loss_factor": 0.01,
+            "enable_load_balance": False,
+            "load_balance_tolerance": 0.1,
+            "load_balance_beam_width": 8,
+            "load_balance_group_swap_factor": 0.4,
+            "enable_kernel": False,
+            "enable_comm_overlap": False,
+        }
+        set_moe_args(config, moe_args)
         model = OpenMoeForCausalLM(config)
     else:
-        model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}")
+        config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model}")
+        moe_args = {
+            "num_experts": config.num_experts,
+            "moe_layer_interval": config.moe_layer_interval,
+            "router_topk": 2,
+            "router_capacity_factor_train": 1.25,
+            "router_capacity_factor_eval": 2.0,
+            "router_min_capacity": 4,
+            "router_noisy_policy": None,
+            "router_drop_tks": True,
+            "router_aux_loss_factor": 0.01,
+            "router_z_loss_factor": 0.01,
+            "mlp_gated": True,
+            "label_smoothing": 0.001,
+            "z_loss_factor": 0.01,
+            "enable_load_balance": False,
+            "load_balance_tolerance": 0.1,
+            "load_balance_beam_width": 8,
+            "load_balance_group_swap_factor": 0.4,
+            "enable_kernel": False,
+            "enable_comm_overlap": False,
+        }
+        set_moe_args(config, moe_args)
+        model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}", config=config)
     model = model.eval().half()
     model = model.to(torch.cuda.current_device())
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index 357c0f22a783..6f9b668e4597 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -39,6 +39,7 @@
 from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
 from colossalai.moe.layers import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import get_activation

 if HAS_TRITON:
     from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
@@ -166,7 +167,7 @@ def SwiGLU(x):

 class OpenMoeMLP(nn.Module):

-    def __init__(self, config):
+    def __init__(self, config: LlamaConfig):
         super().__init__()
         self.pretraining_tp = config.pretraining_tp
         self.hidden_size = config.hidden_size
@@ -174,8 +175,9 @@ def __init__(self, config):
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = SwiGLU
-        self.use_kernel = True if MOE_MANAGER.use_kernel_optim else False
+        self.hidden_act = config.hidden_act
+        self.act_fn = get_activation(self.hidden_act)
+        self.use_kernel = config.enable_kernel

     def forward(self, x):
         if self.pretraining_tp > 1:
@@ -191,7 +193,7 @@ def forward(self, x):
             down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)]
             down_proj = sum(down_proj)
         else:
-            if HAS_TRITON and self.use_kernel:
+            if HAS_TRITON and self.use_kernel and self.hidden_act == "swiglu":
                 down_proj = self.down_proj(LlamaActCombine.apply(self.gate_proj(x), self.up_proj(x)))
             else:
                 down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
@@ -361,16 +363,22 @@ def __init__(self, config: LlamaConfig, moe: bool):
         self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         if self.moe:
             self.mlp = SparseMLP(num_experts=config.num_experts,
-                                 top_k=config.topk,
-                                 capacity_factor_train=config.capacity_factor_train,
-                                 capacity_factor_eval=config.capacity_factor_eval,
-                                 min_capacity=config.min_capacity,
-                                 noisy_policy=config.noisy_policy,
-                                 drop_tks=config.drop_tks,
                                  hidden_size=config.hidden_size,
                                  intermediate_size=config.intermediate_size,
-                                 activation=config.hidden_act,
-                                 gated=config.gated)
+                                 router_top_k=config.router_topk,
+                                 router_capacity_factor_train=config.router_capacity_factor_train,
+                                 router_capacity_factor_eval=config.router_capacity_factor_eval,
+                                 router_min_capacity=config.router_min_capacity,
+                                 router_noisy_policy=config.router_noisy_policy,
+                                 router_drop_tks=config.router_drop_tks,
+                                 mlp_activation=config.hidden_act,
+                                 mlp_gated=config.mlp_gated,
+                                 enable_load_balance=config.enable_load_balance,
+                                 load_balance_tolerance=config.load_balance_tolerance,
+                                 load_balance_beam_width=config.load_balance_beam_width,
+                                 load_balance_group_swap_factor=config.load_balance_group_swap_factor,
+                                 enable_kernel=config.enable_kernel,
+                                 enable_comm_overlap=config.enable_comm_overlap)
             self.pre_extra_mlp_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
             self.extra_mlp = OpenMoeMLP(config)
         else:
diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py
index 6f239104328c..ec9ec21b55dc 100644
--- a/examples/language/openmoe/train.py
+++ b/examples/language/openmoe/train.py
@@ -19,7 +19,7 @@
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.moe import MoeCheckpintIO
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import skip_init
+from colossalai.moe.utils import set_moe_args, skip_init
 from colossalai.utils import get_current_device

@@ -157,7 +157,6 @@ def main():
         MOE_MANAGER.setup(
             seed=42,
             parallel="EP",
-            use_kernel_optim=args.use_kernel if not test_mode else False,
         )
     elif args.plugin == "zero2_ep":
         plugin = MoeHybridParallelPlugin(
@@ -171,7 +170,6 @@ def main():
         MOE_MANAGER.setup(
             seed=42,
             parallel="EP",
-            use_kernel_optim=args.use_kernel if not test_mode else False,
         )
     elif args.plugin == "hybrid":
         plugin = MoeHybridParallelPlugin(
@@ -190,7 +188,6 @@ def main():
             fixed_dp_size=args.dp_size,
             fixed_ep_size=args.ep_size,
             fixed_pp_size=args.pp_size,
-            use_kernel_optim=args.use_kernel,
         )
     else:
         raise ValueError(f"Invalid plugin {args.plugin}")
@@ -205,10 +202,28 @@ def main():
     else:
         repo_name = "hpcaitech/openmoe-" + args.model_name
     config = LlamaConfig.from_pretrained(repo_name)
-    setattr(config, "router_aux_loss_factor", args.router_aux_loss_factor)
-    setattr(config, "router_z_loss_factor", args.router_z_loss_factor)
-    setattr(config, "label_smoothing", args.label_smoothing)
-    setattr(config, "z_loss_factor", args.z_loss_factor)
+    moe_args = {
+        "num_experts": config.num_experts,
+        "moe_layer_interval": config.moe_layer_interval,
+        "router_topk": 2,
+        "router_capacity_factor_train": 1.25,
+        "router_capacity_factor_eval": 2.0,
+        "router_min_capacity": 4,
+        "router_noisy_policy": None,
+        "router_drop_tks": True,
+        "router_aux_loss_factor": 0.01,
+        "router_z_loss_factor": 0.01,
+        "mlp_gated": True,
+        "label_smoothing": 0.001,
+        "z_loss_factor": 0.01,
+        "enable_load_balance": False,
+        "load_balance_tolerance": 0.1,
+        "load_balance_beam_width": 8,
+        "load_balance_group_swap_factor": 0.4,
+        "enable_kernel": False,
+        "enable_comm_overlap": False,
+    }
+    set_moe_args(config, moe_args)
     with skip_init():
         model = OpenMoeForCausalLM(config)
     logger.info(f"Finish init model with config:\n{config}", ranks=[0])
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index 934061ae4417..2e116de2db7d 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -14,13 +14,16 @@

 class MoeModel(nn.Module):

-    def __init__(self, checkpoint: bool = False):
+    def __init__(self, checkpoint: bool = False, enable_load_balance: bool = False):
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index 934061ae4417..2e116de2db7d 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -14,13 +14,16 @@
 class MoeModel(nn.Module):
 
-    def __init__(self, checkpoint: bool = False):
+    def __init__(self, checkpoint: bool = False, enable_load_balance: bool = False):
 
         class TestSubModule(CheckpointModule):
 
             def __init__(self):
                 super().__init__(checkpoint)
-                self.moe = SparseMLP(num_experts=8, hidden_size=16, intermediate_size=32)
+                self.moe = SparseMLP(num_experts=8,
+                                     hidden_size=16,
+                                     intermediate_size=32,
+                                     enable_load_balance=enable_load_balance)
                 self.proj = nn.Linear(16, 4)
 
             def _forward(self, x):
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py
index 13e142aadd7a..28ee618e1ba7 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_moe/test_grad_handler.py
@@ -16,17 +16,26 @@
 
 def run_test(rank, world_size, port):
-    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(
+        config=dict(),
+        rank=rank,
+        world_size=world_size,
+        host="localhost",
+        port=port,
+        backend="nccl",
+    )
     MOE_MANAGER.setup(42, parallel="EP")    # MOE initialization
     num_experts_list = [1, 2, 4]
     layer_list = []
     for num_experts in num_experts_list:
-        moe_layer = SparseMLP(hidden_size=DIM,
-                              intermediate_size=DIM * 4,
-                              num_experts=num_experts,
-                              top_k=1,
-                              noisy_policy="Jitter")
+        moe_layer = SparseMLP(
+            hidden_size=DIM,
+            intermediate_size=DIM * 4,
+            num_experts=num_experts,
+            router_top_k=1,
+            router_noisy_policy="Jitter",
+        )
         layer_list.append(moe_layer)
 
     model = nn.ModuleList(layer_list)
@@ -77,5 +86,5 @@ def test_grad_handler():
     spawn(run_test, 4)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_grad_handler()
diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py
index db40110d8d9e..4b9885ab0691 100644
--- a/tests/test_moe/test_kernel.py
+++ b/tests/test_moe/test_kernel.py
@@ -33,14 +33,14 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     layer = SparseMLP(hidden_size=hidden_size,
                       intermediate_size=hidden_size * 2,
                       num_experts=NUM_EXPERTS,
-                      top_k=topk,
-                      capacity_factor_train=1.0)
+                      router_top_k=topk,
+                      router_capacity_factor_train=1.0)
     layer = layer.to(get_current_device())
     if data_type == torch.float16:
         layer = layer.half()
 
     # use matrix multiplication instead of COL_MOE_KERNEL in MOE dispatch and combine
-    layer.use_kernel = False
+    layer.enable_kernel = False
     old_out = layer(tokens)
     ech = old_out.shape
     grad = torch.randn(ech, device=get_current_device())
@@ -54,7 +54,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     tokens.grad.zero_()
     layer.gate_weight.grad.zero_()
 
-    layer.use_kernel = True
+    layer.enable_kernel = True
     new_out = layer(tokens)    # get outputs through colossal kernel
 
     if data_type == torch.float32:
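test_kernel.py toggles the renamed enable_kernel flag to compare the plain-matmul dispatch/combine path against the ColossalAI kernel path on the same layer. The shape of that check, assuming `layer` and `tokens` built as in run_routing above (the tolerances below are illustrative, not the test's):

```python
# Sketch of the kernel-vs-matmul equivalence check in run_routing. Assumes
# `layer` and `tokens` exist as in the test above; tolerances are illustrative.
import torch

layer.enable_kernel = False
ref_out = layer(tokens)      # dispatch/combine via plain matrix multiplication

layer.enable_kernel = True
ker_out = layer(tokens)      # dispatch/combine via the ColossalAI kernel

torch.testing.assert_close(ker_out, ref_out, rtol=1e-4, atol=1e-4)
```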
"router_aux_loss_factor": 0.1, "router_z_loss_factor": 0.1, "label_smoothing": 0.1, "z_loss_factor": 0.1, + "mlp_gated": True, + "label_smoothing": 0.001, + "z_loss_factor": 0.01, + "enable_load_balance": False, + "load_balance_tolerance": 0.1, + "load_balance_beam_width": 8, + "load_balance_group_swap_factor": 0.4, + "enable_kernel": False, + "enable_comm_overlap": False, } for key, value in settings.items(): setattr(config, key, value) diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 51fd135483b6..11d0664fd580 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -11,32 +11,20 @@ from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, sync_tp_from_ep -def run_test(rank: int, - world_size: int, - port: int, - num_experts: int, - batch_size: int, - dim: int, - seed: int): +def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int): assert batch_size % world_size == 0 colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') MOE_MANAGER.__init__() MOE_MANAGER.setup(seed, parallel=None) - local_model = SparseMLP(num_experts=num_experts, - hidden_size=dim, - intermediate_size=dim * 2) + local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) MOE_MANAGER.__init__() MOE_MANAGER.setup(seed, parallel="EP") - ep_model = SparseMLP(num_experts=num_experts, - hidden_size=dim, - intermediate_size=dim * 2) + ep_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) MOE_MANAGER.__init__() MOE_MANAGER.setup(seed, parallel="TP") - tp_model = SparseMLP(num_experts=num_experts, - hidden_size=dim, - intermediate_size=dim * 2) + tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) ep_model = ep_model.to(get_current_device()) tp_model = tp_model.to(get_current_device()) local_model = local_model.to(get_current_device()) @@ -81,14 +69,11 @@ def run_test(rank: int, @pytest.mark.dist @pytest.mark.parametrize("num_experts", [4, 8]) -@pytest.mark.parametrize("batch_size", [4, 8]) -@pytest.mark.parametrize("dim", [16, 256]) -@pytest.mark.parametrize("seed", [42, 78]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("dim", [32]) +@pytest.mark.parametrize("seed", [42]) @rerun_if_address_is_in_use() -def test_moe_ep_tp(num_experts: int, - batch_size: int, - dim: int, - seed: int): +def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int): spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index e111ea6bb18d..3cd5acc0d953 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -3,7 +3,7 @@ import torch.nn as nn import colossalai -from colossalai.moe import EPMLPExperts, TPMLPExperts +from colossalai.moe.experts import MLPExperts from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import sync_moe_model_param from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn @@ -13,38 +13,39 @@ INTERMEDIATE_SIZE = 8 -def run_moe_init(expert_cls): - expert_args = dict(hidden_size=HIDDEN_SIZE, intermediate_size=INTERMEDIATE_SIZE) - exp0 = expert_cls(1, **expert_args) - exp1 = expert_cls(2, **expert_args) - exp2 = expert_cls(4, **expert_args) - exp3 = expert_cls(8, **expert_args) +def 
diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py
index e111ea6bb18d..3cd5acc0d953 100644
--- a/tests/test_moe/test_moe_group.py
+++ b/tests/test_moe/test_moe_group.py
@@ -3,7 +3,7 @@
 import torch.nn as nn
 
 import colossalai
-from colossalai.moe import EPMLPExperts, TPMLPExperts
+from colossalai.moe.experts import MLPExperts
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import sync_moe_model_param
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
@@ -13,38 +13,39 @@
 INTERMEDIATE_SIZE = 8
 
 
-def run_moe_init(expert_cls):
-    expert_args = dict(hidden_size=HIDDEN_SIZE, intermediate_size=INTERMEDIATE_SIZE)
-    exp0 = expert_cls(1, **expert_args)
-    exp1 = expert_cls(2, **expert_args)
-    exp2 = expert_cls(4, **expert_args)
-    exp3 = expert_cls(8, **expert_args)
+def run_moe_init(expert_parallel):
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(seed=42, parallel=expert_parallel)
+    expert_args = dict(
+        hidden_size=HIDDEN_SIZE,
+        intermediate_size=INTERMEDIATE_SIZE,
+        expert_parallel=expert_parallel,
+    )
+    exp0 = MLPExperts(1, **expert_args)
+    exp1 = MLPExperts(2, **expert_args)
+    exp2 = MLPExperts(4, **expert_args)
 
-    if expert_cls is EPMLPExperts:
+    if expert_parallel == "EP":
         assert exp0.num_local_experts == 1
         assert exp1.num_local_experts == 1
-        assert exp2.num_local_experts == 1
-        assert exp3.num_local_experts == 2
+        assert exp2.num_local_experts == 2
     else:
         assert exp0.num_local_experts == 1
         assert exp1.num_local_experts == 2
         assert exp2.num_local_experts == 4
-        assert exp3.num_local_experts == 8
 
     parallel_info_dict = MOE_MANAGER.parallel_info_dict
     rank = dist.get_rank()
 
     # group creation assert
-    assert len(parallel_info_dict) == 3
-    assert dist.get_rank(parallel_info_dict[4].ep_group) == rank
+    assert len(parallel_info_dict) == 2
     assert dist.get_rank(parallel_info_dict[2].ep_group) == rank % 2
     assert dist.get_rank(parallel_info_dict[1].ep_group) == 0
 
-    assert dist.get_rank(parallel_info_dict[4].dp_group) == 0
     assert dist.get_rank(parallel_info_dict[2].dp_group) == rank // 2
     assert dist.get_rank(parallel_info_dict[1].dp_group) == rank
 
-    model = nn.ModuleList([exp0, exp1, exp2, exp3])
+    model = nn.ModuleList([exp0, exp1, exp2])
     model = model.to(get_current_device())
     sync_moe_model_param(model)
@@ -57,19 +58,25 @@ def run_moe_init(expert_cls):
     assert_equal_in_group(exp1.wo.data, parallel_info_dict[2].dp_group)
 
 
-def _run_test(rank, world_size, port, expert_cls):
-    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    MOE_MANAGER.setup(seed=42, parallel="EP")
-    run_moe_init(expert_cls)
+def _run_test(rank, world_size, port, expert_parallel):
+    colossalai.launch(
+        config=dict(),
+        rank=rank,
+        world_size=world_size,
+        host="localhost",
+        port=port,
+        backend="nccl",
+    )
+    run_moe_init(expert_parallel)
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize("expert_cls", [EPMLPExperts, TPMLPExperts])
+@pytest.mark.parametrize("expert_parallel", ["EP", "TP"])
 @rerun_if_address_is_in_use()
-def test_moe_initialization(expert_cls):
-    spawn(_run_test, 4, expert_cls=expert_cls)
+def test_moe_initialization(expert_parallel):
+    spawn(_run_test, 2, expert_parallel=expert_parallel)
 
 
-if __name__ == '__main__':
-    test_moe_initialization(EPMLPExperts)
-    test_moe_initialization(TPMLPExperts)
+if __name__ == "__main__":
+    test_moe_initialization("EP")
+    test_moe_initialization("TP")
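The rewritten assertions pin down how MLPExperts distributes experts on the 2-rank world used by spawn(_run_test, 2): under "EP" the experts are sharded across ranks, while under "TP" each rank keeps every expert and splits the weight tensors instead. The expected counts can be summarized as follows (the EP expression is inferred from the assertions above, not quoted from MLPExperts itself):

```python
# Expected num_local_experts on world_size=2, matching the assertions above.
# The EP formula is inferred from those assertions, not from MLPExperts.
world_size = 2
for num_experts in (1, 2, 4):
    ep_local = max(num_experts // world_size, 1)  # EP shards experts across ranks
    tp_local = num_experts                        # TP keeps all experts on each rank
    print(f"{num_experts} experts -> EP: {ep_local}, TP: {tp_local}")
# 1 -> EP 1, TP 1;  2 -> EP 1, TP 2;  4 -> EP 2, TP 4
```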
diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py
index b4eea04bc85a..5126c61ae92f 100644
--- a/tests/test_moe/test_moe_load_balance.py
+++ b/tests/test_moe/test_moe_load_balance.py
@@ -47,12 +47,8 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
     MOE_MANAGER.setup(
         seed=42,
         parallel="EP",
-        enable_load_balance=True,
-        tolerance=0.1,
-        beam_width=8,
-        group_swap_factor=0.4,
     )
-    zero_model = MoeModel(checkpoint=True)
+    zero_model = MoeModel(checkpoint=True, enable_load_balance=True)
     zero_optimizer = torch.optim.Adam(zero_model.parameters())
     plugin = LowLevelZeroPlugin(stage=stage, precision="bf16", verbose=True)
     booster = Booster(plugin=plugin)
@@ -118,12 +114,8 @@ def run_hybrid_zero_optim_test(local_rank, world_size, stage=1):
         max_ep_size=2,
         use_ep_inside=False,
         parallel="EP",
-        enable_load_balance=True,
-        tolerance=0.1,
-        beam_width=8,
-        group_swap_factor=0.4,
     )
-    zero_model = MoeModel(checkpoint=True)
+    zero_model = MoeModel(checkpoint=True, enable_load_balance=True)
     extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group
     ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group)
     ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size
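With load balancing now configured on the layer rather than through MOE_MANAGER.setup, a test model opts in per layer. A minimal sketch using the sizes from MoeModel and the load-balance knobs this diff uses elsewhere (SparseMLP's actual defaults may differ):

```python
# Sketch: per-layer load balancing after this change. Sizes mirror MoeModel;
# the load_balance_* values mirror the knobs used elsewhere in this diff.
from colossalai.moe.layers import SparseMLP

balanced_moe = SparseMLP(
    num_experts=8,
    hidden_size=16,
    intermediate_size=32,
    enable_load_balance=True,
    load_balance_tolerance=0.1,
    load_balance_beam_width=8,
    load_balance_group_swap_factor=0.4,
)
```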