diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 81a7b21544e4..d2ca1cbc8276 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -1,5 +1,4 @@ import math -from contextlib import nullcontext from typing import Callable, Optional, Tuple import torch diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index 036bd32ae7c0..3f82a0fa23fd 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -1,16 +1,18 @@ import math -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter from colossalai.moe.experts import BaseMLPExperts, get_expert_class +from colossalai.moe.load_balance import LoadBalancer from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.routers import MoeRouter, get_router_cls from colossalai.moe.utils import get_noise_generator -from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size +from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size class SparseMLP(nn.Module): @@ -72,6 +74,7 @@ def __init__( # moe router noisy_func = get_noise_generator(noisy_policy, num_experts) router_cls = get_router_cls(top_k) + self.topk = top_k self.router: MoeRouter = router_cls( capacity_factor_train=capacity_factor_train, capacity_factor_eval=capacity_factor_eval, @@ -91,13 +94,30 @@ def __init__( if self.expert_parallel is not None: self.ep_group = get_ep_group(self.experts) self.ep_size = get_ep_size(self.experts) + self.dp_group = get_dp_group(self.experts) else: self.ep_group = None + self.dp_group = None self.num_local_experts = self.experts.num_local_experts # gate self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) + # load balance + self.enable_load_balance = MOE_MANAGER.load_balance + if self.enable_load_balance == True: + self.load_balancer = LoadBalancer( + experts=self.experts, + gate=self.gate_weight, + local_expert_num=self.num_local_experts, + expert_num=self.num_experts, + ep_group=self.ep_group, + dp_group=self.dp_group, + tolerance=MOE_MANAGER.tolerance, + beam_width=MOE_MANAGER.beam_width, + group_swap_factor=MOE_MANAGER.group_swap_factor, + ) + # init param self.reset_parameters() @@ -121,6 +141,14 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: fp32_weight = self.gate_weight.to(torch.float) gate_output = F.linear(fp32_input, fp32_weight) + # update expert load + if self.enable_load_balance == True: + with torch.no_grad(): + # TODO: optimize computation + expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1] + expert_load = torch.bincount(expert_load.view(-1)) + self.load_balancer.update_load(expert_load) + # the result from the router route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) @@ -257,3 +285,18 @@ def get_chunk_slice(idx: int, gap: int) -> Tuple[slice]: # sync for async op torch.cuda.synchronize() return out + + +def apply_load_balance(model: nn.Module, optim: Any) -> None: + """ + apply load balance to every experts in the model + """ + + def _apply_recursive(module: nn.Module): + for _, sub_module in module.named_children(): + if isinstance(sub_module, SparseMLP): + if sub_module.enable_load_balance == True: + sub_module.load_balancer.balance_load(optim) + _apply_recursive(sub_module) + + 
_apply_recursive(model) diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py new file mode 100644 index 000000000000..b2fb672329c2 --- /dev/null +++ b/colossalai/moe/load_balance.py @@ -0,0 +1,429 @@ +from copy import deepcopy +from typing import List, Optional, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor, nn +from torch.distributed import ProcessGroup + +from colossalai.cluster import ProcessGroupMesh +from colossalai.moe.experts import BaseMLPExperts +from colossalai.moe.manager import MOE_MANAGER +from colossalai.zero.low_level import LowLevelZeroOptimizer + + +class LoadBalancer: + + def __init__( + self, + experts: BaseMLPExperts, + gate: nn.Parameter, + local_expert_num: int, + expert_num: int, + ep_group: ProcessGroup, + dp_group: ProcessGroup, + tolerance: Optional[float] = 0.1, + beam_width: Optional[int] = 8, + group_swap_factor: Optional[float] = 0.4, + ) -> None: + self.experts: BaseMLPExperts = experts + self.gate: nn.Parameter = gate + self.moe_ep_group: ProcessGroup = ep_group + self.moe_ep_ranks = MOE_MANAGER.parallel_info_dict[dist.get_world_size(self.moe_ep_group)].ep_group_ranks + self.moe_dp_group: ProcessGroup = dp_group + self.tolerance = tolerance + self.beam_width = beam_width + self.group_swap_factor = group_swap_factor + self.local_expert_num = local_expert_num + self.expert_num = expert_num + self.local_load = None + # TODO: use a global process group mesh + pp_size = 1 if MOE_MANAGER.pp_size is None else MOE_MANAGER.pp_size + global_dp_group = ProcessGroupMesh(pp_size, dist.get_world_size() // pp_size) + self.global_dp_group = global_dp_group.get_group_along_axis(1) + + def _clear_load(self) -> None: + self.local_load = None + + def _sync_load(self) -> Tensor: + new_load = self.local_load.clone().detach() + # all reduce load between ep group + dist.all_reduce(new_load, group=self.moe_ep_group) + # all reduce load between dp group + dist.all_reduce(new_load, group=self.moe_dp_group) + return new_load + + @staticmethod + def _get_diff_from_avg(data: List, group: int, avg: float) -> float: + return abs(sum(data[group]) / len(data[group]) - avg) + + @staticmethod + def _swap_data(data: List, group_i: int, index_i: int, group_j: int, index_j: int) -> None: + data[group_i][index_i], data[group_j][index_j] = ( + data[group_j][index_j], + data[group_i][index_i], + ) + + @staticmethod + def _normalize_data(data: List) -> List: + max_value = max(max(sublist) for sublist in data) + data = [[i / max_value for i in sublist] for sublist in data] + return data + + @staticmethod + def _get_swap_loss( + group_swap_factor: float, + swap_list: List, + group_i: int, + index_i: int, + group_j: int, + index_j: int, + ) -> float: + """ + Get swap loss. The swap loss is used to avoid the situation that + the same index is swapped twice and the same group is swapped for multiple times. + """ + swap_loss = 0 + for swap in swap_list: + for group_id, index_id in zip([group_i, group_j], [index_i, index_j]): + # the group has been swapped + if group_id in [swap[0], swap[2]]: + # the index has been swapped + # we want to avoid the situation that the same index is swapped twice + if index_id in [swap[1], swap[3]]: + swap_loss += 1e5 + # the index has not been swapped + # this is acceptable but as less as possible + else: + swap_loss += group_swap_factor + return swap_loss + + @staticmethod + def _check_convergence(data: List, avg: float, tolerance: float): + """ + Check whether the data is converged after swap. 
+ """ + for sublist in data: + if abs(sum(sublist) / len(sublist) - avg) > tolerance * avg: + return False + return True + + def _beam_search( + self, + inputs: Tuple[List, float, List], + beam_width: int, + avg: float, + group_swap_factor: float, + ) -> List: + """ + Beam search for the best swap combination. + Specifically, we swap two elements from two groups and calculate the score. + The score is the difference between the origin group sum and the new group sum. + The larger the score, the better the swap combination. + + Args: + inputs (Tuple): (data, origin_score, swap_list) + beam_width (int): beam width for beam search + avg (float): average value of the data + group_swap_factor (float): group loss for group swap loss + + Returns: + List: results list + """ + data, origin_score, swap_list = inputs + results = [] + group_num = len(data) + group_size = len(data[0]) + origin_diff_list = [self._get_diff_from_avg(data, i, avg) for i in range(group_num)] + + for group_num_i in range(group_num): + for group_size_i in range(group_size): + for group_num_j in range(group_num_i + 1, group_num): + for group_size_j in range(group_size): + new_data = deepcopy(data) + # calculate origin group sum + origin_diff = (origin_diff_list[group_num_i] + origin_diff_list[group_num_j]) + # swap data + self._swap_data( + new_data, + group_num_i, + group_size_i, + group_num_j, + group_size_j, + ) + # calculate new group sum + new_diff = self._get_diff_from_avg(new_data, group_num_i, avg) + self._get_diff_from_avg( + new_data, group_num_j, avg) + # caculate score + new_score = origin_diff - new_diff + if new_score > 0: + new_score = origin_score + new_score + # get swap loss + swap_loss = self._get_swap_loss( + group_swap_factor, + swap_list, + group_num_i, + group_size_i, + group_num_j, + group_size_j, + ) + new_score = new_score - swap_loss + # update swap list + new_swap_list = swap_list + [(group_num_i, group_size_i, group_num_j, group_size_j)] + results.append((new_data, new_score, new_swap_list)) + # sort results + results.sort(key=lambda x: x[1], reverse=True) + # select top k results + results = results[:beam_width] + return results + + def _load_to_list(self, load: Tensor) -> List: + load_len = len(load) + assert load_len % self.local_expert_num == 0 + load_list = [] + tmp_list = [] + for i in range(len(load)): + tmp_list.append(float(load[i])) + if (i + 1) % self.local_expert_num == 0: + load_list.append(tmp_list) + tmp_list = [] + return load_list + + def _search_balance( + self, + data: List, + tolerance: Optional[float] = 0.1, + beam_width: Optional[int] = 8, + group_swap_factor: Optional[float] = 0.4, + return_swapped_data: Optional[bool] = False, + ) -> Tuple[List, List]: + """ + Search for the best swap combination to balance the data within the specified tolerance. + And return the balanced data and the swap list. The swap list is used to record the swap. + The swap list is a list of tuples. Each tuple is a swap operation. + + Args: + data (List): expert load list. + E.g. [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]] + This means there are 4 devices and each devices has 2 experts. + The value is the load of the expert. + tolerance (float): tolerance for balance. + beam_width (int): beam width for beam search. + group_swap_factor (float): group swap factor for group swap loss. + The bigger it is, the less times a group will be swapped. + return_swapped_data (bool): whether to return the swapped data. + + Returns: + Tuple: (balanced data, swap list). + The swap list is a list of tuples. 
Each tuple is a swap operation. + E.g. [(0, 0, 1, 0), (...), (...)]. The first tuple means + the first expert of the first device is swapped with the first expert + of the second device. + """ + norm_data = self._normalize_data(data) + avg = sum(sum(sublist) / len(sublist) for sublist in norm_data) / len(norm_data) + results = [(norm_data, 0, [])] + stop_flag = False + + while stop_flag == False: + new_results = [] + best_score = results[0][1] + for i in range(len(results)): + new_results.extend(self._beam_search(results[i], beam_width, avg, group_swap_factor)) + if len(new_results) == 0: + stop_flag = True + break + new_results.sort(key=lambda x: x[1], reverse=True) + new_best_score = new_results[0][1] + if new_best_score == best_score: + stop_flag = True + break + new_results = new_results[:beam_width] + results = new_results + for i in results: + if self._check_convergence(results[0][0], avg, tolerance): + stop_flag = True + break + + swap_list = results[0][2] + if return_swapped_data: + out = deepcopy(data) + for swap in swap_list: + self._swap_data(out, *swap) + return out, swap_list + else: + return swap_list + + @staticmethod + def _swap_expert_single_tensor( + weight: nn.Parameter, + expert_idx: int, + comm_group: ProcessGroup, + send_first: bool, + comm_rank: int, + ): + # exchange weight + local_weight = weight.data[expert_idx] + new_weight = torch.empty_like(local_weight) + if send_first: + dist.send(local_weight, dst=comm_rank, group=comm_group) + dist.recv(new_weight, src=comm_rank, group=comm_group) + else: + dist.recv(new_weight, src=comm_rank, group=comm_group) + dist.send(local_weight, dst=comm_rank, group=comm_group) + weight.data[expert_idx] = new_weight + + def _swap_expert_param_and_optim( + self, + weight: nn.Parameter, + expert_idx: int, + comm_group: ProcessGroup, + send_first: bool, + comm_rank: int, + optim: LowLevelZeroOptimizer, + ): + # need to update master and working param if master param exists + # else just update working param + if weight in optim.optim.state: + master_weight_ptr = None + working_weight_ptr = weight + exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"] + exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"] + else: + master_weight_ptr = optim._param_store.working_to_master_param[id(weight)] + working_weight_ptr = weight + exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"] + exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"] + + # exchange weight + self._swap_expert_single_tensor( + working_weight_ptr, + expert_idx, + comm_group, + send_first, + comm_rank, + ) + if master_weight_ptr is not None: + # TODO: exchange master weight, skip for now + # master weight is shared by dp group + tmp = working_weight_ptr.view(-1).split( + working_weight_ptr.numel() // dist.get_world_size(self.moe_dp_group))[dist.get_rank(self.moe_dp_group)] + master_weight_ptr.data.copy_(tmp.clone().detach().to(master_weight_ptr.device).to(master_weight_ptr.dtype)) + # exchange optim + self._swap_expert_single_tensor(exp_avg_ptr, expert_idx, comm_group, send_first, comm_rank) + self._swap_expert_single_tensor(exp_avg_sq_ptr, expert_idx, comm_group, send_first, comm_rank) + + def _gather_global_dp_group(self, data: Tensor) -> Tensor: + data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size(self.global_dp_group))] + dist.all_gather(data_list, data, group=self.global_dp_group) + data_list = torch.cat(data_list, dim=0) + return data_list + + def _swap_moe_param(self, swap_list: List, optim: 
LowLevelZeroOptimizer) -> None: + """ + Swap moe param and optim. + We use different strategies to swap expert and gate. + For expert, we exchange the param and optim of the expert by p2p. + For gate, we all gather the gate choose the part we want. + + Args: + swap_list (List) + optim (LowLevelZeroOptimizer) + """ + # get all experts weights + local_rank = dist.get_rank(self.moe_ep_group) + if self.experts.gated: + weight_list = [self.experts.wi_up, self.experts.wi_gate] + else: + weight_list = [self.experts.wi] + weight_list.append(self.experts.wo) + + # gate optim should be obtained first + gate_shape = self.gate.shape + # get master weight and optim + master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)] + gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"] + gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"] + # gather + global_master_gate_weight = self._gather_global_dp_group(master_gate_weight).view(gate_shape) + global_gate_exp_avg = self._gather_global_dp_group(gate_exp_avg).view(gate_shape) + global_gate_exp_avg_sq = self._gather_global_dp_group(gate_exp_avg_sq).view(gate_shape) + assert (self.gate.shape == global_master_gate_weight.shape == global_gate_exp_avg.shape == + global_gate_exp_avg_sq.shape) + + for swap in swap_list: + source_group, source_idx, target_group, target_idx = swap + source_rank = self.moe_ep_ranks[source_group] + target_rank = self.moe_ep_ranks[target_group] + # exchange expert + if local_rank in [source_group, target_group]: + for weight in weight_list: + if local_rank == source_group: + self._swap_expert_param_and_optim( + weight, + source_idx, + self.moe_ep_group, + True, + target_rank, + optim, + ) + elif local_rank == target_group: + self._swap_expert_param_and_optim( + weight, + target_idx, + self.moe_ep_group, + False, + source_rank, + optim, + ) + # exchange gate + source_expert_pos = source_group * self.local_expert_num + source_idx + target_expert_pos = target_group * self.local_expert_num + target_idx + for gate in [ + self.gate, + global_master_gate_weight, + global_gate_exp_avg, + global_gate_exp_avg_sq, + ]: + origin_source = gate.data[source_expert_pos].clone().detach() + origin_target = gate.data[target_expert_pos].clone().detach() + gate.data[source_expert_pos], gate.data[target_expert_pos] = ( + origin_target, + origin_source, + ) + + # update gate + dp_group_rank = dist.get_rank(self.global_dp_group) + dp_group_size = dist.get_world_size(self.global_dp_group) + global_master_gate_weight = global_master_gate_weight.view(-1).split(global_master_gate_weight.numel() // + dp_group_size)[dp_group_rank] + master_gate_weight.data.copy_(global_master_gate_weight) + global_gate_exp_avg = global_gate_exp_avg.view(-1).split(global_gate_exp_avg.numel() // + dp_group_size)[dp_group_rank] + gate_exp_avg.data.copy_(global_gate_exp_avg) + global_gate_exp_avg_sq = global_gate_exp_avg_sq.view(-1).split(global_gate_exp_avg_sq.numel() // + dp_group_size)[dp_group_rank] + gate_exp_avg_sq.data.copy_(global_gate_exp_avg_sq) + + @torch.no_grad() + def update_load(self, load: Tensor) -> None: + if len(load) != self.expert_num: + padding_size = self.expert_num - len(load) + padding = torch.zeros(padding_size, dtype=load.dtype, device=load.device) + load = torch.cat((load, padding), dim=0) + if self.local_load is None: + self.local_load = load + else: + self.local_load += load + + @torch.no_grad() + def balance_load(self, optim: LowLevelZeroOptimizer) -> None: + # prepare load + load = self._sync_load() + load = 
self._load_to_list(load) + # search balance + swap_list = self._search_balance(load) + # swap expert and gate + self._swap_moe_param(swap_list, optim) + # clear load + self._clear_load() diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py index 1e949bb9a6dd..e3659ef43fbd 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -27,6 +27,13 @@ def __init__(self): self.mode = None self.use_kernel_optim = False self.use_ep_inside = None + self.pp_size = None + + # load balance param + self.load_balance = None + self.tolerance = None + self.beam_width = None + self.group_swap_factor = None self.has_setup = False self._parallel_info_dict = dict() @@ -39,16 +46,22 @@ def parallel_info_dict(self): def is_initialized(self): return self.has_setup - def setup(self, - seed: int, - use_kernel_optim: bool = False, - parallel: str = None, - mode: str = "dynamic", - max_ep_size: int = 8, - fixed_dp_size: int = 0, - fixed_ep_size: int = 0, - fixed_pp_size: int = 0, - use_ep_inside: bool = True) -> None: + def setup( + self, + seed: int, + use_kernel_optim: bool = False, + parallel: str = None, + mode: str = "dynamic", + max_ep_size: int = 8, + fixed_dp_size: int = 0, + fixed_ep_size: int = 0, + fixed_pp_size: int = 0, + use_ep_inside: bool = True, + enable_load_balance: bool = False, + tolerance: float = 0.1, + beam_width: int = 8, + group_swap_factor: float = 0.4, + ) -> None: """ Setup MoE distributed context. @@ -91,6 +104,12 @@ def setup(self, # Users can close kernel optimization manually self.use_kernel_optim = use_kernel_optim + # update load balance + self.load_balance = enable_load_balance + self.tolerance = tolerance + self.beam_width = beam_width + self.group_swap_factor = group_swap_factor + self.has_setup = True def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]: diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py new file mode 100644 index 000000000000..b4eea04bc85a --- /dev/null +++ b/tests/test_moe/test_moe_load_balance.py @@ -0,0 +1,193 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel +from colossalai.moe.layers import apply_load_balance +from colossalai.moe.manager import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import is_moe_tensor +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel + + +def split_ddp_grad(grad, world_size): + with torch.no_grad(): + grad = grad.clone().detach().flatten() + padding_size = (world_size - grad.numel() % world_size) % world_size + if padding_size > 0: + grad = torch.nn.functional.pad(grad, [0, padding_size]) + splited_grad = grad.split(grad.numel() // world_size) + return splited_grad + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + + if isinstance(model, LowLevelZeroModel): + optimizer.backward(loss) + else: + loss.backward() + return y + + +def run_zero_optim_test(local_rank, world_size, stage=1): + criterion = torch.nn.CrossEntropyLoss() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup( + seed=42, + parallel="EP", 
+ enable_load_balance=True, + tolerance=0.1, + beam_width=8, + group_swap_factor=0.4, + ) + zero_model = MoeModel(checkpoint=True) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="bf16", verbose=True) + booster = Booster(plugin=plugin) + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, parallel="EP") + torch_model = MoeModel(checkpoint=True) + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + torch_param.data.copy_(zero_param.data) + torch_optimizer = torch.optim.Adam(torch_model.parameters()) + torch_model = torch_model.cuda().bfloat16() + grad_handler = MoeGradientHandler(torch_model) + + # run to update expert load + data = torch.randn(16, 4).cuda().bfloat16() / 1000 / (local_rank + 1) + label = torch.randint(0, 4, (16,)).cuda() + + # run torch model twice + run_fwd_bwd(torch_model, data, label, criterion, None) + grad_handler.handle_gradient() + torch_optimizer.step() + torch_optimizer.zero_grad() + run_fwd_bwd(torch_model, data, label, criterion, None) + grad_handler.handle_gradient() + + # get optim and load status in zero model + run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + zero_optimizer.step() + zero_optimizer.zero_grad() + with torch.no_grad(): + origin_out = zero_model(data) + + # load balance + apply_load_balance(zero_model, zero_optimizer) + + # run again to test + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + torch.allclose(origin_out, zero_out) + + # assert optim + torch_optimizer.step() + torch_out = run_fwd_bwd(torch_model, data, label, criterion, None) + zero_optimizer.step() + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + assert torch.allclose(zero_out, torch_out), f"zero_out:{zero_out}\ntorch_out{torch_out}" + + +def run_hybrid_zero_optim_test(local_rank, world_size, stage=1): + criterion = torch.nn.CrossEntropyLoss() + data = torch.randn(16, 4).cuda() + label = torch.randint(0, 4, (16,)).cuda() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, parallel=None) + torch_model = MoeModel(checkpoint=True) + torch_optimizer = torch.optim.Adam(torch_model.parameters()) + torch_model = torch_model.cuda() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup( + seed=42, + max_ep_size=2, + use_ep_inside=False, + parallel="EP", + enable_load_balance=True, + tolerance=0.1, + beam_width=8, + group_swap_factor=0.4, + ) + zero_model = MoeModel(checkpoint=True) + extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group + ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group) + ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + if is_moe_tensor(zero_param): + num_expert = torch_param.data.shape[0] + zero_param.data.copy_(torch_param.data[ep_rank * (num_expert // ep_size):(ep_rank + 1) * + (num_expert // ep_size)].detach().clone()) + else: + zero_param.data.copy_(torch_param.data.detach().clone()) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") + plugin.zero_optim_kwargs["extra_dp_process_group"] = extra_dp_group + booster = Booster(plugin=plugin) + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + # run torch for twice + run_fwd_bwd(torch_model, data, label, criterion, None) + torch_optimizer.step() + 
torch_optimizer.zero_grad() + run_fwd_bwd(torch_model, data, label, criterion, None) + torch_optimizer.step() + + # run zero + run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + zero_optimizer.step() + zero_optimizer.zero_grad() + with torch.no_grad(): + origin_out = zero_model(data) + + # load balance + apply_load_balance(zero_model, zero_optimizer) + + # assert out + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + torch.allclose(origin_out, zero_out) + + # assert optim + zero_optimizer.step() + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + torch_out = run_fwd_bwd(torch_model, data, label, criterion, None) + assert torch.allclose(zero_out, torch_out, atol=8e-4), f"zero_out:{zero_out}\ntorch_out{torch_out}" + + +def run_dist(rank, world_size, port): + colossalai.launch( + config=dict(), + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) + run_zero_optim_test(rank, world_size, stage=1) + run_zero_optim_test(rank, world_size, stage=2) + run_hybrid_zero_optim_test(rank, world_size, stage=1) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_moe_load_balance(world_size): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_moe_load_balance(world_size=4)
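For context on how the new pieces fit together, here is a minimal usage sketch of the load-balance path, condensed from tests/test_moe/test_moe_load_balance.py. It assumes the distributed context has already been set up with colossalai.launch as in that test; the 1000-step loop and the 100-step rebalance interval are illustrative choices, not values fixed by this PR.

import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from tests.test_moe.moe_utils import MoeModel

# colossalai.launch(...) is assumed to have been called already, as in the test
MOE_MANAGER.setup(
    seed=42,
    parallel="EP",
    enable_load_balance=True,      # every SparseMLP starts accumulating expert load
    tolerance=0.1,
    beam_width=8,
    group_swap_factor=0.4,
)
model = MoeModel(checkpoint=True)
optimizer = torch.optim.Adam(model.parameters())
booster = Booster(plugin=LowLevelZeroPlugin(stage=1, precision="bf16"))
model, optimizer, _, _, _ = booster.boost(model, optimizer)
criterion = torch.nn.CrossEntropyLoss()

for step in range(1000):
    data = torch.randn(16, 4).cuda().bfloat16()
    label = torch.randint(0, 4, (16,)).cuda()
    loss = criterion(model(data), label).float()
    optimizer.backward(loss)       # LowLevelZeroOptimizer drives the backward pass
    optimizer.step()
    optimizer.zero_grad()
    # rebalance periodically, and only after optimizer.step() has run at least once,
    # so that Adam's exp_avg / exp_avg_sq exist for the expert weights being swapped
    if step > 0 and step % 100 == 0:
        apply_load_balance(model, optimizer)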
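The load-tracking hook added to SparseMLP.forward is just a top-k over the gate logits followed by a bincount. A small standalone sketch (hypothetical tensor sizes) also shows why LoadBalancer.update_load pads its input: bincount only returns entries up to the largest expert index that actually received a token.

import torch

num_tokens, num_experts, top_k = 16, 8, 2
gate_output = torch.randn(num_tokens, num_experts)

# indices of the top-k experts chosen for each token, as in SparseMLP.forward
chosen = torch.topk(gate_output, k=top_k, dim=-1)[1]        # (num_tokens, top_k)
expert_load = torch.bincount(chosen.view(-1))               # tokens routed to each expert

# the vector can be shorter than num_experts when the highest-indexed experts
# receive no tokens; update_load pads it with zeros before accumulating
print(expert_load.shape, expert_load)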
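The core of LoadBalancer._search_balance is a search over pairwise expert swaps that reduces each device's deviation from the mean load. The sketch below is a deliberately simplified greedy variant, run on the example load matrix from the docstring; it omits the normalization step, the beam of width beam_width, and the group_swap_factor penalty used by the real implementation, but it produces swap tuples in the same (group_i, index_i, group_j, index_j) layout.

from copy import deepcopy

def device_diff(load, avg):
    # total deviation of each device's mean expert load from the global mean
    return sum(abs(sum(group) / len(group) - avg) for group in load)

def greedy_balance(load, max_steps=10):
    load = deepcopy(load)
    avg = sum(sum(group) / len(group) for group in load) / len(load)
    swaps = []
    for _ in range(max_steps):
        best = None
        for gi in range(len(load)):
            for gj in range(gi + 1, len(load)):
                for ii in range(len(load[gi])):
                    for jj in range(len(load[gj])):
                        cand = deepcopy(load)
                        cand[gi][ii], cand[gj][jj] = cand[gj][jj], cand[gi][ii]
                        gain = device_diff(load, avg) - device_diff(cand, avg)
                        if gain > 1e-9 and (best is None or gain > best[0]):
                            best = (gain, cand, (gi, ii, gj, jj))
        if best is None:          # no swap improves the balance any further
            break
        load = best[1]
        swaps.append(best[2])
    return load, swaps

balanced, swap_list = greedy_balance([[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]])
print(swap_list)    # list of (group_i, index_i, group_j, index_j) tuples, as in the diff
print(balanced)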
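The trickiest part of _swap_moe_param is the gate: under LowLevelZeroOptimizer its master copy and Adam states are sharded across the global data-parallel group, so the code all-gathers them, swaps the rows of the moved experts, and copies back only the local shard. The following is a standalone sketch of that gather-swap-rescatter pattern for a single sharded tensor (the function name and the even-divisibility assumption are mine, mirroring the simplification the diff itself makes; the real code repeats this for the working gate, the master weight, exp_avg and exp_avg_sq).

import torch
import torch.distributed as dist

def rebalance_sharded_gate(master_shard, gate_shape, swap_list, local_expert_num, dp_group):
    world = dist.get_world_size(dp_group)
    rank = dist.get_rank(dp_group)
    # 1) rebuild the full gate from the per-rank ZeRO shards
    shards = [torch.zeros_like(master_shard) for _ in range(world)]
    dist.all_gather(shards, master_shard, group=dp_group)
    full = torch.cat(shards).view(gate_shape)
    # 2) swap the rows belonging to the experts that were exchanged
    for src_grp, src_idx, dst_grp, dst_idx in swap_list:
        a = src_grp * local_expert_num + src_idx
        b = dst_grp * local_expert_num + dst_idx
        full[[a, b]] = full[[b, a]]
    # 3) keep only this rank's shard of the updated gate
    new_shard = full.view(-1).split(full.numel() // world)[rank]
    master_shard.data.copy_(new_shard)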