From 0f240268c2ff85089123480c7b8628b2b335491e Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Sun, 15 Oct 2023 16:16:19 +0800 Subject: [PATCH 01/10] add load balance --- colossalai/moe/experts.py | 1 - colossalai/moe/load_balance.py | 317 +++++++++++++++++++++++++++++++++ 2 files changed, 317 insertions(+), 1 deletion(-) create mode 100644 colossalai/moe/load_balance.py diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 81a7b21544e4..d2ca1cbc8276 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -1,5 +1,4 @@ import math -from contextlib import nullcontext from typing import Callable, Optional, Tuple import torch diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py new file mode 100644 index 000000000000..5fbb1bdb839f --- /dev/null +++ b/colossalai/moe/load_balance.py @@ -0,0 +1,317 @@ +from copy import deepcopy +from typing import List, Optional, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor, nn +from torch.distributed import ProcessGroup + +from colossalai.moe.experts import BaseMLPExperts + + +class LoadBalance: + + def __init__( + self, + experts: BaseMLPExperts, + gate: nn.Parameter, + local_expert_num: float, + ep_group: ProcessGroup, + dp_group: ProcessGroup, + tolerance: Optional[float] = 0.1, + beam_width: Optional[int] = 8, + group_swap_factor: Optional[float] = 0.4, + ) -> None: + self.experts: BaseMLPExperts = experts + self.gate: nn.Parameter = gate + self.ep_group: ProcessGroup = ep_group + self.dp_group: ProcessGroup = dp_group + self.tolerance = tolerance + self.beam_width = beam_width + self.group_swap_factor = group_swap_factor + self.local_expert_num = local_expert_num + self.local_load = None + + def _clear_load(self) -> None: + self.local_load = None + + def _sync_load(self) -> Tensor: + # all gather load between ep group + new_load = [torch.zeros_like(self.local_load) for _ in range(dist.get_world_size(self.ep_group))] + dist.all_gather(new_load, self.local_load, group=self.ep_group) + new_load = torch.cat(new_load, dim=0) + # all reduce load between dp group + dist.all_reduce(new_load, group=self.dp_group) + return new_load + + @staticmethod + def _get_diff_from_avg(data: List, group: int, avg: float) -> float: + return abs(sum(data[group]) / len(data[group]) - avg) + + @staticmethod + def _swap_data(data: List, group_i: int, index_i: int, group_j: int, index_j: int) -> None: + data[group_i][index_i], data[group_j][index_j] = ( + data[group_j][index_j], + data[group_i][index_i], + ) + + @staticmethod + def _normalize_data(data: List) -> List: + max_value = max(max(sublist) for sublist in data) + data = [[i / max_value for i in sublist] for sublist in data] + return data + + @staticmethod + def _get_swap_loss( + group_swap_factor: float, + swap_list: List, + group_i: int, + index_i: int, + group_j: int, + index_j: int, + ) -> float: + """ + Get swap loss. The swap loss is used to avoid the situation that + the same index is swapped twice and the same group is swapped for multiple times. 
+ """ + swap_loss = 0 + for swap in swap_list: + for group_id, index_id in zip([group_i, group_j], [index_i, index_j]): + # the group has been swapped + if group_id in [swap[0], swap[2]]: + # the index has been swapped + # we want to avoid the situation that the same index is swapped twice + if index_id in [swap[1], swap[3]]: + swap_loss += 1e5 + # the index has not been swapped + # this is acceptable but as less as possible + else: + swap_loss += group_swap_factor + return swap_loss + + @staticmethod + def _check_convergence(data: List, avg: float, tolerance: float): + """ + Check whether the data is converged after swap. + """ + for sublist in data: + if abs(sum(sublist) / len(sublist) - avg) > tolerance * avg: + return False + return True + + def _beam_search( + self, + inputs: Tuple[List, float, List], + beam_width: int, + avg: float, + group_swap_factor: float, + ) -> List: + """ + Beam search for the best swap combination. + Specifically, we swap two elements from two groups and calculate the score. + The score is the difference between the origin group sum and the new group sum. + The larger the score, the better the swap combination. + + Args: + inputs (Tuple): (data, origin_score, swap_list) + beam_width (int): beam width for beam search + avg (float): average value of the data + group_swap_factor (float): group loss for group swap loss + + Returns: + List: results list + """ + data, origin_score, swap_list = inputs + results = [] + group_num = len(data) + group_size = len(data[0]) + origin_diff_list = [self._get_diff_from_avg(data, i, avg) for i in range(group_num)] + + for group_num_i in range(group_num): + for group_size_i in range(group_size): + for group_num_j in range(group_num_i + 1, group_num): + for group_size_j in range(group_size): + new_data = deepcopy(data) + # calculate origin group sum + origin_diff = (origin_diff_list[group_num_i] + origin_diff_list[group_num_j]) + # swap data + self._swap_data( + new_data, + group_num_i, + group_size_i, + group_num_j, + group_size_j, + ) + # calculate new group sum + new_diff = self._get_diff_from_avg(new_data, group_num_i, avg) + self._get_diff_from_avg( + new_data, group_num_j, avg) + # caculate score + new_score = origin_diff - new_diff + if new_score > 0: + new_score = origin_score + new_score + # get swap loss + swap_loss = self._get_swap_loss( + group_swap_factor, + swap_list, + group_num_i, + group_size_i, + group_num_j, + group_size_j, + ) + new_score = new_score - swap_loss + # update swap list + new_swap_list = swap_list + [(group_num_i, group_size_i, group_num_j, group_size_j)] + results.append((new_data, new_score, new_swap_list)) + # sort results + results.sort(key=lambda x: x[1], reverse=True) + # select top k results + results = results[:beam_width] + return results + + def _load_to_list(self, load: Tensor) -> List: + load_len = len(load) + load_list = [] + for _ in range(load_len): + tmp_list = [] + for j in range(self.local_expert_num): + tmp_list.append(float(load[j])) + load_list.append(tmp_list) + return load_list + + def _search_balance( + self, + data: List, + tolerance: Optional[float] = 0.1, + beam_width: Optional[int] = 8, + group_swap_factor: Optional[float] = 0.4, + return_swapped_data: Optional[bool] = False, + ) -> Tuple[List, List]: + """ + Search for the best swap combination to balance the data within the specified tolerance. + And return the balanced data and the swap list. The swap list is used to record the swap. + The swap list is a list of tuples. Each tuple is a swap operation. 
+    def _search_balance(
+        self,
+        data: List,
+        tolerance: Optional[float] = 0.1,
+        beam_width: Optional[int] = 8,
+        group_swap_factor: Optional[float] = 0.4,
+        return_swapped_data: Optional[bool] = False,
+    ) -> Tuple[List, List]:
+        """
+        Search for the best swap combination to balance the data within the
+        specified tolerance, and return the swap list that records each swap
+        operation (plus the balanced data, if requested).
+
+        Args:
+            data (List): expert load list.
+                E.g. [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]]
+                This means there are 4 devices and each device has 2 experts.
+                The value is the load of the expert.
+            tolerance (float): tolerance for balance.
+            beam_width (int): beam width for beam search.
+            group_swap_factor (float): group swap factor for group swap loss.
+                The bigger it is, the fewer times a group will be swapped.
+            return_swapped_data (bool): whether to return the swapped data.
+
+        Returns:
+            Tuple: (balanced data, swap list).
+            The swap list is a list of tuples. Each tuple is one swap operation.
+            E.g. [(0, 0, 1, 0), (...), (...)]. The first tuple means
+            the first expert of the first device is swapped with the first expert
+            of the second device.
+        """
+        norm_data = self._normalize_data(data)
+        avg = sum(sum(sublist) / len(sublist) for sublist in norm_data) / len(norm_data)
+        results = [(norm_data, 0, [])]
+        stop_flag = False
+
+        while not stop_flag:
+            new_results = []
+            best_score = results[0][1]
+            for i in range(len(results)):
+                new_results.extend(self._beam_search(results[i], beam_width, avg, group_swap_factor))
+            if len(new_results) == 0:
+                stop_flag = True
+                break
+            new_results.sort(key=lambda x: x[1], reverse=True)
+            new_best_score = new_results[0][1]
+            if new_best_score == best_score:
+                stop_flag = True
+                break
+            new_results = new_results[:beam_width]
+            results = new_results
+            if self._check_convergence(results[0][0], avg, tolerance):
+                stop_flag = True
+
+        swap_list = results[0][2]
+        if return_swapped_data:
+            out = deepcopy(data)
+            for swap in swap_list:
+                self._swap_data(out, *swap)
+            return out, swap_list
+        else:
+            return swap_list
+
+    def _swap_expert(self, swap_list: List) -> None:
+        local_rank = dist.get_rank(self.ep_group)
+        for swap in swap_list:
+            source_group, source_idx, target_group, target_idx = swap
+            # exchange expert
+            if local_rank == source_group:
+                local_expert = self.experts[source_idx]
+                new_expert = torch.empty_like(local_expert)
+                dist.send(local_expert, dst=target_group, group=self.ep_group)
+                dist.recv(new_expert, src=target_group, group=self.ep_group)
+                self.experts[source_idx] = new_expert
+            elif local_rank == target_group:
+                local_expert = self.experts[target_idx]
+                new_expert = torch.empty_like(local_expert)
+                dist.recv(new_expert, src=source_group, group=self.ep_group)
+                dist.send(local_expert, dst=source_group, group=self.ep_group)
+                self.experts[target_idx] = new_expert
+            # exchange gate
+            source_expert_pos = source_group * self.local_expert_num + source_idx
+            target_expert_pos = target_group * self.local_expert_num + target_idx
+            self.gate.data[source_expert_pos], self.gate.data[target_expert_pos] = self.gate.data[
+                target_expert_pos], self.gate.data[source_expert_pos]
+
+    @torch.no_grad()
+    def update_load(self, load: Tensor) -> None:
+        if self.local_load is None:
+            self.local_load = load
+        else:
+            self.local_load += load
+
+    @torch.no_grad()
+    def balance_load(self) -> None:
+        # prepare load
+        load = self._sync_load()
+        load = self._load_to_list(load)
+        # search balance
+        swap_list = self._search_balance(load)
+        # swap expert and gate
+        self._swap_expert(swap_list)
+        # clear load
+        self._clear_load()
+
+
+data = [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]]
+data = [
+    [10, 397, 479, 661],
+    [447, 654, 552, 312],
+    [769, 339, 780, 705],
+    [491, 562, 837, 434],
+    [509, 291, 626, 851],
+    [539, 484, 538, 406],
+    [541, 660, 160, 498],
+]
+result, swap_list = LoadBalance()._search_balance(
+    data=data,
tolerance=0.1, + beam_width=8, + group_swap_factor=0.4, + return_swapped_data=True, +) +if result: + print("Balanced Lists:") + for sublist in result: + print(f"{sublist} sum: {sum(sublist)}") +else: + print("Unable to balance lists within the specified constraints.") +print(f"Swap List:\n{swap_list}") +swap_dict = {i: 0 for i in range(len(data))} +for swap in swap_list: + swap_dict[swap[0]] += 1 + swap_dict[swap[2]] += 1 +print(f"Swap Dict:\n{swap_dict}") From 3bb9a38fa1a1d4fb5b78673b52acfad145c4b747 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Sun, 15 Oct 2023 18:30:37 +0800 Subject: [PATCH 02/10] update test --- tests/test_moe/test_moe_load_balance.py | 108 ++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 tests/test_moe/test_moe_load_balance.py diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py new file mode 100644 index 000000000000..72dca2e96fca --- /dev/null +++ b/tests/test_moe/test_moe_load_balance.py @@ -0,0 +1,108 @@ +import pytest +import torch + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel +from colossalai.moe.layers import apply_load_balance +from colossalai.moe.manager import MOE_MANAGER +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel + + +def split_ddp_grad(grad, world_size): + with torch.no_grad(): + grad = grad.clone().detach().flatten() + padding_size = (world_size - grad.numel() % world_size) % world_size + if padding_size > 0: + grad = torch.nn.functional.pad(grad, [0, padding_size]) + splited_grad = grad.split(grad.numel() // world_size) + return splited_grad + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + + if isinstance(model, LowLevelZeroModel): + optimizer.backward(loss) + else: + loss.backward() + return y + + +def run_zero_optim_test(local_rank, world_size, stage=1): + criterion = torch.nn.CrossEntropyLoss() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, parallel="EP") + zero_model = MoeModel(checkpoint=True) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") + booster = Booster(plugin=plugin) + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, parallel="EP") + torch_model = MoeModel(checkpoint=True) + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + torch_param.data.copy_(zero_param.data) + torch_optimizer = torch.optim.Adam(torch_model.parameters()) + torch_model = torch_model.cuda() + grad_handler = MoeGradientHandler(torch_model) + + # run to update expert load + data = torch.randn(16, 4).cuda() / (local_rank + 1) + label = torch.randint(0, 4, (16,)).cuda() + run_fwd_bwd(torch_model, data, label, criterion, None) + grad_handler.handle_gradient() + with torch.no_grad(): + zero_model(data) + + # load balance + apply_load_balance(zero_model) + + # run again to test + run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + + for (zero_name, zero_param), (torch_name, torch_param) in 
zip(zero_model.module.named_parameters(), + torch_model.named_parameters()): + assert zero_name == torch_name + zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param)) + if hasattr(zero_param, "moe_info"): + assert len(zero_grad_list) == 0 + assert torch.allclose(zero_param.grad, torch_param.grad) + else: + assert len(zero_grad_list) > 0 + torch_grad_list = split_ddp_grad(torch_param.grad, world_size) + if stage == 2: + torch_grad_list = torch_grad_list[local_rank:local_rank + 1] + assert len(zero_grad_list) == len(torch_grad_list) + for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list): + assert torch.allclose(zero_grad, torch_grad), f"{zero_name} {zero_grad} {torch_grad}" + + torch_optimizer.zero_grad() + zero_optimizer.zero_grad() + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_zero_optim_test(rank, world_size, stage=1) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2]) +@rerun_if_address_is_in_use() +def test_moe_zero_optim(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_moe_zero_optim(world_size=2) From 8be2fce280681bdc4d481237e9be6bb455a493e3 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Sun, 15 Oct 2023 18:37:37 +0800 Subject: [PATCH 03/10] update param exchange --- colossalai/moe/layers.py | 45 ++++++++++++++++++++++++- colossalai/moe/load_balance.py | 33 +----------------- colossalai/moe/manager.py | 38 +++++++++++++++------ tests/test_moe/test_moe_load_balance.py | 7 +++- 4 files changed, 79 insertions(+), 44 deletions(-) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index 036bd32ae7c0..d3480287a64f 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -2,15 +2,17 @@ from typing import Optional, Tuple import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter from colossalai.moe.experts import BaseMLPExperts, get_expert_class +from colossalai.moe.load_balance import LoadBalancer from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.routers import MoeRouter, get_router_cls from colossalai.moe.utils import get_noise_generator -from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size +from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size class SparseMLP(nn.Module): @@ -72,6 +74,7 @@ def __init__( # moe router noisy_func = get_noise_generator(noisy_policy, num_experts) router_cls = get_router_cls(top_k) + self.topk = top_k self.router: MoeRouter = router_cls( capacity_factor_train=capacity_factor_train, capacity_factor_eval=capacity_factor_eval, @@ -91,13 +94,29 @@ def __init__( if self.expert_parallel is not None: self.ep_group = get_ep_group(self.experts) self.ep_size = get_ep_size(self.experts) + self.dp_group = get_dp_group(self.experts) else: self.ep_group = None + self.dp_group = None self.num_local_experts = self.experts.num_local_experts # gate self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) + # load balance + self.enable_load_balance = MOE_MANAGER.load_balance + if self.enable_load_balance == True: + self.load_balancer = LoadBalancer( + experts=self.experts, + gate=self.gate_weight, + local_expert_num=self.num_local_experts, + ep_group=self.ep_group, + dp_group=self.dp_group, + 
tolerance=MOE_MANAGER.tolerance, + beam_width=MOE_MANAGER.beam_width, + group_swap_factor=MOE_MANAGER.group_swap_factor, + ) + # init param self.reset_parameters() @@ -121,6 +140,14 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: fp32_weight = self.gate_weight.to(torch.float) gate_output = F.linear(fp32_input, fp32_weight) + # update expert load + if self.enable_load_balance == True: + with torch.no_grad(): + # TODO: optimize computation + expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1] + expert_load = torch.bincount(expert_load.view(-1)) + self.load_balancer.update_load(expert_load) + # the result from the router route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) @@ -257,3 +284,19 @@ def get_chunk_slice(idx: int, gap: int) -> Tuple[slice]: # sync for async op torch.cuda.synchronize() return out + + +def apply_load_balance(model: nn.Module) -> None: + """ + apply load balance to every experts in the model + """ + + def _apply_recursive(module: nn.Module): + for _, sub_module in module.named_children(): + if isinstance(sub_module, SparseMLP): + if sub_module.enable_load_balance == True: + sub_module.load_balancer.balance_load() + _apply_recursive(sub_module) + + _apply_recursive(model) + dist.barrier() diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py index 5fbb1bdb839f..855d99b4ae32 100644 --- a/colossalai/moe/load_balance.py +++ b/colossalai/moe/load_balance.py @@ -9,7 +9,7 @@ from colossalai.moe.experts import BaseMLPExperts -class LoadBalance: +class LoadBalancer: def __init__( self, @@ -284,34 +284,3 @@ def balance_load(self) -> None: self._swap_expert(swap_list) # clear load self._clear_load() - - -data = [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]] -data = [ - [10, 397, 479, 661], - [447, 654, 552, 312], - [769, 339, 780, 705], - [491, 562, 837, 434], - [509, 291, 626, 851], - [539, 484, 538, 406], - [541, 660, 160, 498], -] -result, swap_list = LoadBalance()._search_balance( - data=data, - tolerance=0.1, - beam_width=8, - group_swap_factor=0.4, - return_swapped_data=True, -) -if result: - print("Balanced Lists:") - for sublist in result: - print(f"{sublist} sum: {sum(sublist)}") -else: - print("Unable to balance lists within the specified constraints.") -print(f"Swap List:\n{swap_list}") -swap_dict = {i: 0 for i in range(len(data))} -for swap in swap_list: - swap_dict[swap[0]] += 1 - swap_dict[swap[2]] += 1 -print(f"Swap Dict:\n{swap_dict}") diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py index 1e949bb9a6dd..5fa93a2899ec 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -28,6 +28,12 @@ def __init__(self): self.use_kernel_optim = False self.use_ep_inside = None + # load balance param + self.load_balance = None + self.tolerance = None + self.beam_width = None + self.group_swap_factor = None + self.has_setup = False self._parallel_info_dict = dict() @@ -39,16 +45,22 @@ def parallel_info_dict(self): def is_initialized(self): return self.has_setup - def setup(self, - seed: int, - use_kernel_optim: bool = False, - parallel: str = None, - mode: str = "dynamic", - max_ep_size: int = 8, - fixed_dp_size: int = 0, - fixed_ep_size: int = 0, - fixed_pp_size: int = 0, - use_ep_inside: bool = True) -> None: + def setup( + self, + seed: int, + use_kernel_optim: bool = False, + parallel: str = None, + mode: str = "dynamic", + max_ep_size: int = 8, + fixed_dp_size: int = 0, + fixed_ep_size: int = 0, + fixed_pp_size: int = 0, + 
use_ep_inside: bool = True, + enable_load_balance: bool = False, + tolerance: float = 0.1, + beam_width: int = 8, + group_swap_factor: float = 0.4, + ) -> None: """ Setup MoE distributed context. @@ -91,6 +103,12 @@ def setup(self, # Users can close kernel optimization manually self.use_kernel_optim = use_kernel_optim + # update load balance + self.load_balance = enable_load_balance + self.tolerance = tolerance + self.beam_width = beam_width + self.group_swap_factor = group_swap_factor + self.has_setup = True def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]: diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py index 72dca2e96fca..1bc86ee7f0d1 100644 --- a/tests/test_moe/test_moe_load_balance.py +++ b/tests/test_moe/test_moe_load_balance.py @@ -42,7 +42,12 @@ def run_zero_optim_test(local_rank, world_size, stage=1): criterion = torch.nn.CrossEntropyLoss() MOE_MANAGER.__init__() - MOE_MANAGER.setup(seed=42, parallel="EP") + MOE_MANAGER.setup(seed=42, + parallel="EP", + enable_load_balance=True, + tolerance=0.1, + beam_width=8, + group_swap_factor=0.4) zero_model = MoeModel(checkpoint=True) zero_optimizer = torch.optim.Adam(zero_model.parameters()) plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") From 2ac2b82288b754b59ea7fa807a4c23fe736569b6 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Mon, 16 Oct 2023 00:55:44 +0800 Subject: [PATCH 04/10] pass test --- colossalai/moe/layers.py | 7 +- colossalai/moe/load_balance.py | 156 +++++++++++++++++++----- colossalai/moe/manager.py | 1 + tests/test_moe/test_moe_load_balance.py | 48 ++++---- 4 files changed, 153 insertions(+), 59 deletions(-) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index d3480287a64f..bd878047da78 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -1,5 +1,5 @@ import math -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import torch import torch.distributed as dist @@ -110,6 +110,7 @@ def __init__( experts=self.experts, gate=self.gate_weight, local_expert_num=self.num_local_experts, + expert_num=self.num_experts, ep_group=self.ep_group, dp_group=self.dp_group, tolerance=MOE_MANAGER.tolerance, @@ -286,7 +287,7 @@ def get_chunk_slice(idx: int, gap: int) -> Tuple[slice]: return out -def apply_load_balance(model: nn.Module) -> None: +def apply_load_balance(model: nn.Module, optim: Any) -> None: """ apply load balance to every experts in the model """ @@ -295,7 +296,7 @@ def _apply_recursive(module: nn.Module): for _, sub_module in module.named_children(): if isinstance(sub_module, SparseMLP): if sub_module.enable_load_balance == True: - sub_module.load_balancer.balance_load() + sub_module.load_balancer.balance_load(optim) _apply_recursive(sub_module) _apply_recursive(model) diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py index 855d99b4ae32..302cd1c1526a 100644 --- a/colossalai/moe/load_balance.py +++ b/colossalai/moe/load_balance.py @@ -1,12 +1,15 @@ from copy import deepcopy -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple import torch import torch.distributed as dist from torch import Tensor, nn from torch.distributed import ProcessGroup +from colossalai.cluster import ProcessGroupMesh from colossalai.moe.experts import BaseMLPExperts +from colossalai.moe.manager import MOE_MANAGER +from colossalai.zero.low_level import LowLevelZeroOptimizer class LoadBalancer: @@ -15,7 +18,8 @@ def __init__( self, 
experts: BaseMLPExperts, gate: nn.Parameter, - local_expert_num: float, + local_expert_num: int, + expert_num: int, ep_group: ProcessGroup, dp_group: ProcessGroup, tolerance: Optional[float] = 0.1, @@ -24,24 +28,28 @@ def __init__( ) -> None: self.experts: BaseMLPExperts = experts self.gate: nn.Parameter = gate - self.ep_group: ProcessGroup = ep_group - self.dp_group: ProcessGroup = dp_group + self.moe_ep_group: ProcessGroup = ep_group + self.moe_dp_group: ProcessGroup = dp_group self.tolerance = tolerance self.beam_width = beam_width self.group_swap_factor = group_swap_factor self.local_expert_num = local_expert_num + self.expert_num = expert_num self.local_load = None + # TODO: use a global process group mesh + pp_size = 1 if MOE_MANAGER.pp_size is None else MOE_MANAGER.pp_size + global_dp_group = ProcessGroupMesh(pp_size, dist.get_world_size() // pp_size) + self.global_dp_group = global_dp_group.get_group_along_axis(1) def _clear_load(self) -> None: self.local_load = None def _sync_load(self) -> Tensor: - # all gather load between ep group - new_load = [torch.zeros_like(self.local_load) for _ in range(dist.get_world_size(self.ep_group))] - dist.all_gather(new_load, self.local_load, group=self.ep_group) - new_load = torch.cat(new_load, dim=0) + new_load = self.local_load.clone().detach() + # all reduce load between ep group + dist.all_reduce(new_load, group=self.moe_ep_group) # all reduce load between dp group - dist.all_reduce(new_load, group=self.dp_group) + dist.all_reduce(new_load, group=self.moe_dp_group) return new_load @staticmethod @@ -170,12 +178,14 @@ def _beam_search( def _load_to_list(self, load: Tensor) -> List: load_len = len(load) + assert load_len % self.local_expert_num == 0 load_list = [] - for _ in range(load_len): - tmp_list = [] - for j in range(self.local_expert_num): - tmp_list.append(float(load[j])) - load_list.append(tmp_list) + tmp_list = [] + for i in range(len(load)): + tmp_list.append(float(load[i])) + if (i + 1) % self.local_expert_num == 0: + load_list.append(tmp_list) + tmp_list = [] return load_list def _search_balance( @@ -243,44 +253,126 @@ def _search_balance( else: return swap_list - def _swap_expert(self, swap_list: List) -> None: - local_rank = dist.get_rank(self.ep_group) + @staticmethod + def _swap_expert_single_tensor( + weight: nn.Parameter, + expert_idx: int, + comm_group: ProcessGroup, + send_first: bool, + comm_rank: int, + working_weight: nn.Parameter = None, + ): + # exchange weight + local_weight = weight.data[expert_idx] + new_weight = torch.empty_like(local_weight) + if send_first: + dist.send(local_weight, dst=comm_rank, group=comm_group) + dist.recv(new_weight, src=comm_rank, group=comm_group) + else: + dist.recv(new_weight, src=comm_rank, group=comm_group) + dist.send(local_weight, dst=comm_rank, group=comm_group) + weight.data[expert_idx] = new_weight + if working_weight is not None: + working_weight.data[expert_idx] = new_weight.to(working_weight.device).to(working_weight.dtype) + + def _swap_expert_param_and_optim( + self, + weight: nn.Parameter, + expert_idx: int, + comm_group: ProcessGroup, + send_first: bool, + comm_rank: int, + optim: LowLevelZeroOptimizer, + ): + # need to update master and working param if master param exists + # else just update working param + if weight in optim.optim.state: + master_weight_ptr = weight + working_weight_ptr = None + else: + master_weight_ptr = optim._param_store.working_to_master_param[id(weight)] + working_weight_ptr = master_weight_ptr + exp_avg_ptr = 
optim.optim.state[master_weight_ptr]['exp_avg'] + exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]['exp_avg_sq'] + + # exchange weight + self._swap_expert_single_tensor(master_weight_ptr, expert_idx, comm_group, send_first, comm_rank, + working_weight_ptr) + # exchange optim + self._swap_expert_single_tensor(exp_avg_ptr, expert_idx, comm_group, send_first, comm_rank) + self._swap_expert_single_tensor(exp_avg_sq_ptr, expert_idx, comm_group, send_first, comm_rank) + + def gather_global_dp_group(self, data: Tensor) -> Tensor: + data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size(self.global_dp_group))] + dist.all_gather(data_list, data, group=self.global_dp_group) + data_list = torch.cat(data_list, dim=0) + return data_list + + def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None: + # get all experts weights + local_rank = dist.get_rank(self.moe_ep_group) + if self.experts.gated: + weight_list = [self.experts.wi_up, self.experts.wi_gate] + else: + weight_list = [self.experts.wi] + weight_list.append(self.experts.wo) + + # gate optim should be obtained first + local_range = slice(local_rank * self.local_expert_num, (local_rank + 1) * self.local_expert_num) + local_gate_shape = self.gate[local_range].shape + # get master weight and optim + master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)] + gate_exp_avg = optim.optim.state[master_gate_weight]['exp_avg'] + gate_exp_avg_sq = optim.optim.state[master_gate_weight]['exp_avg_sq'] + # gather + global_master_gate_weight = self.gather_global_dp_group(master_gate_weight.view(local_gate_shape)) + global_gate_exp_avg = self.gather_global_dp_group(gate_exp_avg.view(local_gate_shape)) + global_gate_exp_avg_sq = self.gather_global_dp_group(gate_exp_avg_sq.view(local_gate_shape)) + assert self.gate.shape == global_master_gate_weight.shape == global_gate_exp_avg.shape == global_gate_exp_avg_sq.shape + for swap in swap_list: source_group, source_idx, target_group, target_idx = swap # exchange expert - if local_rank == source_group: - local_expert = self.experts[source_idx] - new_expert = torch.empty_like(local_expert) - dist.send(local_expert, dst=target_group, group=self.ep_group) - dist.recv(new_expert, src=target_group, group=self.ep_group) - self.experts[source_idx] = new_expert - elif local_rank == target_group: - local_expert = self.experts[target_idx] - new_expert = torch.empty_like(local_expert) - dist.recv(new_expert, src=source_group, group=self.ep_group) - dist.send(local_expert, dst=source_group, group=self.ep_group) - self.experts[target_idx] = new_expert + if local_rank in [source_group, target_group]: + for weight in weight_list: + if local_rank == source_group: + self._swap_expert_param_and_optim(weight, source_idx, self.moe_ep_group, True, target_group, + optim) + elif local_rank == target_group: + self._swap_expert_param_and_optim(weight, target_idx, self.moe_ep_group, False, source_group, + optim) # exchange gate source_expert_pos = source_group * self.local_expert_num + source_idx target_expert_pos = target_group * self.local_expert_num + target_idx - self.gate.data[source_expert_pos], self.gate.data[target_expert_pos] = self.gate.data[ - target_expert_pos], self.gate.data[source_expert_pos] + for gate in [self.gate, global_master_gate_weight, global_gate_exp_avg, global_gate_exp_avg_sq]: + origin_source = gate.data[source_expert_pos].clone().detach() + origin_target = gate.data[target_expert_pos].clone().detach() + gate.data[source_expert_pos], 
gate.data[target_expert_pos] = origin_target, origin_source + + # update gate + master_gate_weight.data.copy_(global_master_gate_weight[local_range].data.view(-1)) + gate_exp_avg.data.copy_(global_gate_exp_avg[local_range].data.view(-1)) + gate_exp_avg_sq.data.copy_(global_gate_exp_avg_sq[local_range].data.view(-1)) @torch.no_grad() def update_load(self, load: Tensor) -> None: + if len(load) != self.expert_num: + padding_size = self.expert_num - len(load) + padding = torch.zeros(padding_size, dtype=load.dtype, device=load.device) + load = torch.cat((load, padding), dim=0) if self.local_load is None: self.local_load = load else: self.local_load += load @torch.no_grad() - def balance_load(self) -> None: + def balance_load(self, optim: LowLevelZeroOptimizer) -> None: # prepare load load = self._sync_load() load = self._load_to_list(load) # search balance swap_list = self._search_balance(load) # swap expert and gate - self._swap_expert(swap_list) + self._swap_moe_param(swap_list, optim) # clear load self._clear_load() diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py index 5fa93a2899ec..e3659ef43fbd 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -27,6 +27,7 @@ def __init__(self): self.mode = None self.use_kernel_optim = False self.use_ep_inside = None + self.pp_size = None # load balance param self.load_balance = None diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py index 1bc86ee7f0d1..404ca22f547b 100644 --- a/tests/test_moe/test_moe_load_balance.py +++ b/tests/test_moe/test_moe_load_balance.py @@ -49,7 +49,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1): beam_width=8, group_swap_factor=0.4) zero_model = MoeModel(checkpoint=True) - zero_optimizer = torch.optim.Adam(zero_model.parameters()) + zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1) plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") booster = Booster(plugin=plugin) zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) @@ -59,42 +59,42 @@ def run_zero_optim_test(local_rank, world_size, stage=1): torch_model = MoeModel(checkpoint=True) for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): torch_param.data.copy_(zero_param.data) - torch_optimizer = torch.optim.Adam(torch_model.parameters()) + torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) torch_model = torch_model.cuda() grad_handler = MoeGradientHandler(torch_model) # run to update expert load data = torch.randn(16, 4).cuda() / (local_rank + 1) label = torch.randint(0, 4, (16,)).cuda() + + # run torch model twice + run_fwd_bwd(torch_model, data, label, criterion, None) + grad_handler.handle_gradient() + torch_optimizer.step() + torch_optimizer.zero_grad() run_fwd_bwd(torch_model, data, label, criterion, None) grad_handler.handle_gradient() + + # get optim and load status in zero model + run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + zero_optimizer.step() + zero_optimizer.zero_grad() with torch.no_grad(): - zero_model(data) + origin_out = zero_model(data) # load balance - apply_load_balance(zero_model) + apply_load_balance(zero_model, zero_optimizer) # run again to test - run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) - - for (zero_name, zero_param), (torch_name, torch_param) in zip(zero_model.module.named_parameters(), - torch_model.named_parameters()): - assert zero_name == torch_name - zero_grad_list = 
zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
-        if hasattr(zero_param, "moe_info"):
-            assert len(zero_grad_list) == 0
-            assert torch.allclose(zero_param.grad, torch_param.grad)
-        else:
-            assert len(zero_grad_list) > 0
-            torch_grad_list = split_ddp_grad(torch_param.grad, world_size)
-            if stage == 2:
-                torch_grad_list = torch_grad_list[local_rank:local_rank + 1]
-            assert len(zero_grad_list) == len(torch_grad_list)
-            for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
-                assert torch.allclose(zero_grad, torch_grad), f"{zero_name} {zero_grad} {torch_grad}"
-
-    torch_optimizer.zero_grad()
-    zero_optimizer.zero_grad()
+    zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+    assert torch.allclose(origin_out, zero_out)

+    # assert optim
+    torch_optimizer.step()
+    torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
+    zero_optimizer.step()
+    zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+    assert torch.allclose(zero_out, torch_out), f"zero_out:{zero_out}\ntorch_out{torch_out}"

 def run_dist(rank, world_size, port):
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
 def test_moe_zero_optim(world_size):
     spawn(run_dist, world_size)

From e863dc2d6e8e45bce4406385b6a1250a793fda52 Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao
Date: Mon, 16 Oct 2023 01:06:58 +0800
Subject: [PATCH 05/10] update test

---
 colossalai/moe/load_balance.py          | 71 +++++++++++++++++++-------
 tests/test_moe/test_moe_load_balance.py | 25 ++++---
 2 files changed, 71 insertions(+), 25 deletions(-)

diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py
index 302cd1c1526a..b1f2ffedc650 100644
--- a/colossalai/moe/load_balance.py
+++ b/colossalai/moe/load_balance.py
@@ -292,23 +292,39 @@ def _swap_expert_param_and_optim(
         else:
             master_weight_ptr = optim._param_store.working_to_master_param[id(weight)]
             working_weight_ptr = master_weight_ptr
-            exp_avg_ptr = optim.optim.state[master_weight_ptr]['exp_avg']
-            exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]['exp_avg_sq']
+            exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"]
+            exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"]

         # exchange weight
-        self._swap_expert_single_tensor(master_weight_ptr, expert_idx, comm_group, send_first, comm_rank,
-                                        working_weight_ptr)
+        self._swap_expert_single_tensor(
+            master_weight_ptr,
+            expert_idx,
+            comm_group,
+            send_first,
+            comm_rank,
+            working_weight_ptr,
+        )
         # exchange optim
         self._swap_expert_single_tensor(exp_avg_ptr, expert_idx, comm_group, send_first, comm_rank)
         self._swap_expert_single_tensor(exp_avg_sq_ptr, expert_idx, comm_group, send_first, comm_rank)

-    def gather_global_dp_group(self, data: Tensor) -> Tensor:
+    def _gather_global_dp_group(self, data: Tensor) -> Tensor:
         data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size(self.global_dp_group))]
         dist.all_gather(data_list, data, group=self.global_dp_group)
         data_list = torch.cat(data_list, dim=0)
         return data_list

     def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None:
+        """
+        Swap MoE params and optim states.
+        We use different strategies to swap expert and gate.
+        For expert, we exchange the param and optim of the expert by p2p.
+        For gate, we all-gather the gate and choose the part we want.
+ + Args: + swap_list (List) + optim (LowLevelZeroOptimizer) + """ # get all experts weights local_rank = dist.get_rank(self.moe_ep_group) if self.experts.gated: @@ -322,13 +338,14 @@ def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None local_gate_shape = self.gate[local_range].shape # get master weight and optim master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)] - gate_exp_avg = optim.optim.state[master_gate_weight]['exp_avg'] - gate_exp_avg_sq = optim.optim.state[master_gate_weight]['exp_avg_sq'] + gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"] + gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"] # gather - global_master_gate_weight = self.gather_global_dp_group(master_gate_weight.view(local_gate_shape)) - global_gate_exp_avg = self.gather_global_dp_group(gate_exp_avg.view(local_gate_shape)) - global_gate_exp_avg_sq = self.gather_global_dp_group(gate_exp_avg_sq.view(local_gate_shape)) - assert self.gate.shape == global_master_gate_weight.shape == global_gate_exp_avg.shape == global_gate_exp_avg_sq.shape + global_master_gate_weight = self._gather_global_dp_group(master_gate_weight.view(local_gate_shape)) + global_gate_exp_avg = self._gather_global_dp_group(gate_exp_avg.view(local_gate_shape)) + global_gate_exp_avg_sq = self._gather_global_dp_group(gate_exp_avg_sq.view(local_gate_shape)) + assert (self.gate.shape == global_master_gate_weight.shape == global_gate_exp_avg.shape == + global_gate_exp_avg_sq.shape) for swap in swap_list: source_group, source_idx, target_group, target_idx = swap @@ -336,18 +353,38 @@ def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None if local_rank in [source_group, target_group]: for weight in weight_list: if local_rank == source_group: - self._swap_expert_param_and_optim(weight, source_idx, self.moe_ep_group, True, target_group, - optim) + self._swap_expert_param_and_optim( + weight, + source_idx, + self.moe_ep_group, + True, + target_group, + optim, + ) elif local_rank == target_group: - self._swap_expert_param_and_optim(weight, target_idx, self.moe_ep_group, False, source_group, - optim) + self._swap_expert_param_and_optim( + weight, + target_idx, + self.moe_ep_group, + False, + source_group, + optim, + ) # exchange gate source_expert_pos = source_group * self.local_expert_num + source_idx target_expert_pos = target_group * self.local_expert_num + target_idx - for gate in [self.gate, global_master_gate_weight, global_gate_exp_avg, global_gate_exp_avg_sq]: + for gate in [ + self.gate, + global_master_gate_weight, + global_gate_exp_avg, + global_gate_exp_avg_sq, + ]: origin_source = gate.data[source_expert_pos].clone().detach() origin_target = gate.data[target_expert_pos].clone().detach() - gate.data[source_expert_pos], gate.data[target_expert_pos] = origin_target, origin_source + gate.data[source_expert_pos], gate.data[target_expert_pos] = ( + origin_target, + origin_source, + ) # update gate master_gate_weight.data.copy_(global_master_gate_weight[local_range].data.view(-1)) diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py index 404ca22f547b..a9f7dee3a6e0 100644 --- a/tests/test_moe/test_moe_load_balance.py +++ b/tests/test_moe/test_moe_load_balance.py @@ -42,12 +42,14 @@ def run_zero_optim_test(local_rank, world_size, stage=1): criterion = torch.nn.CrossEntropyLoss() MOE_MANAGER.__init__() - MOE_MANAGER.setup(seed=42, - parallel="EP", - enable_load_balance=True, - tolerance=0.1, - beam_width=8, 
- group_swap_factor=0.4) + MOE_MANAGER.setup( + seed=42, + parallel="EP", + enable_load_balance=True, + tolerance=0.1, + beam_width=8, + group_swap_factor=0.4, + ) zero_model = MoeModel(checkpoint=True) zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1) plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") @@ -98,7 +100,14 @@ def run_zero_optim_test(local_rank, world_size, stage=1): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch( + config=dict(), + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_zero_optim_test(rank, world_size, stage=1) @@ -109,5 +118,5 @@ def test_moe_zero_optim(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_zero_optim(world_size=2) From 5589539ccf4bac7dab6eba20b616fd6b4ef4dd6d Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Mon, 16 Oct 2023 01:21:23 +0800 Subject: [PATCH 06/10] update test --- tests/test_moe/test_moe_load_balance.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py index a9f7dee3a6e0..70e6a468b349 100644 --- a/tests/test_moe/test_moe_load_balance.py +++ b/tests/test_moe/test_moe_load_balance.py @@ -51,7 +51,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1): group_swap_factor=0.4, ) zero_model = MoeModel(checkpoint=True) - zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") booster = Booster(plugin=plugin) zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) @@ -61,7 +61,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1): torch_model = MoeModel(checkpoint=True) for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): torch_param.data.copy_(zero_param.data) - torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) + torch_optimizer = torch.optim.Adam(torch_model.parameters()) torch_model = torch_model.cuda() grad_handler = MoeGradientHandler(torch_model) @@ -119,4 +119,4 @@ def test_moe_zero_optim(world_size): if __name__ == "__main__": - test_moe_zero_optim(world_size=2) + test_moe_zero_optim(world_size=4) From 5429cac4a5e71a11fd5d04183c6d90990a7f56be Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Mon, 16 Oct 2023 01:38:30 +0800 Subject: [PATCH 07/10] update test --- tests/test_moe/test_moe_load_balance.py | 62 +++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py index 70e6a468b349..64e8b40bd16e 100644 --- a/tests/test_moe/test_moe_load_balance.py +++ b/tests/test_moe/test_moe_load_balance.py @@ -1,5 +1,6 @@ import pytest import torch +import torch.distributed as dist import colossalai from colossalai.booster import Booster @@ -7,6 +8,7 @@ from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.moe.layers import apply_load_balance from colossalai.moe.manager import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import is_moe_tensor from colossalai.testing import rerun_if_address_is_in_use, spawn from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel @@ -99,6 +101,64 @@ def run_zero_optim_test(local_rank, world_size, stage=1): 
assert torch.allclose(zero_out, torch_out), f"zero_out:{zero_out}\ntorch_out{torch_out}" +def run_hybrid_zero_optim_test(local_rank, world_size, stage=1): + criterion = torch.nn.CrossEntropyLoss() + data = torch.randn(16, 4).cuda() + label = torch.randint(0, 4, (16,)).cuda() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, parallel=None) + torch_model = MoeModel(checkpoint=True) + torch_optimizer = torch.optim.Adam(torch_model.parameters()) + torch_model = torch_model.cuda() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, max_ep_size=2, use_ep_inside=False, parallel="EP") + zero_model = MoeModel(checkpoint=True) + extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group + ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group) + ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + if is_moe_tensor(zero_param): + num_expert = torch_param.data.shape[0] + zero_param.data.copy_(torch_param.data[ep_rank * (num_expert // ep_size):(ep_rank + 1) * + (num_expert // ep_size)].detach().clone()) + else: + zero_param.data.copy_(torch_param.data.detach().clone()) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") + plugin.zero_optim_kwargs["extra_dp_process_group"] = extra_dp_group + booster = Booster(plugin=plugin) + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + # run torch for twice + run_fwd_bwd(torch_model, data, label, criterion, None) + torch_optimizer.step() + torch_optimizer.zero_grad() + run_fwd_bwd(torch_model, data, label, criterion, None) + torch_optimizer.step() + + # run zero + run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + zero_optimizer.step() + zero_optimizer.zero_grad() + with torch.no_grad(): + origin_out = zero_model(data) + + # load balance + apply_load_balance(zero_model, zero_optimizer) + + # assert out + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + torch.allclose(origin_out, zero_out) + + # assert optim + zero_optimizer.step() + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + torch_out = run_fwd_bwd(torch_model, data, label, criterion, None) + assert torch.allclose(zero_out, torch_out, atol=3e-7), f"zero_out:{zero_out}\ntorch_out{torch_out}" + + def run_dist(rank, world_size, port): colossalai.launch( config=dict(), @@ -109,6 +169,8 @@ def run_dist(rank, world_size, port): backend="nccl", ) run_zero_optim_test(rank, world_size, stage=1) + run_zero_optim_test(rank, world_size, stage=2) + run_hybrid_zero_optim_test(rank, world_size, stage=1) @pytest.mark.dist From 063c2798638e47b1aa15e19edb4bcb524dbc9d4f Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Mon, 16 Oct 2023 03:15:26 +0800 Subject: [PATCH 08/10] update test --- colossalai/moe/load_balance.py | 56 +++++++++++++++---------- tests/test_moe/test_moe_load_balance.py | 17 ++++++-- 2 files changed, 48 insertions(+), 25 deletions(-) diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py index b1f2ffedc650..2690aeb4fe63 100644 --- a/colossalai/moe/load_balance.py +++ b/colossalai/moe/load_balance.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, List, Optional, Tuple +from typing import List, Optional, Tuple import torch import torch.distributed as dist @@ -29,6 +29,7 @@ def __init__( self.experts: BaseMLPExperts = experts self.gate: nn.Parameter = gate self.moe_ep_group: ProcessGroup 
= ep_group + self.moe_ep_ranks = dist.get_process_group_ranks(self.moe_ep_group) self.moe_dp_group: ProcessGroup = dp_group self.tolerance = tolerance self.beam_width = beam_width @@ -260,7 +261,6 @@ def _swap_expert_single_tensor( comm_group: ProcessGroup, send_first: bool, comm_rank: int, - working_weight: nn.Parameter = None, ): # exchange weight local_weight = weight.data[expert_idx] @@ -272,8 +272,6 @@ def _swap_expert_single_tensor( dist.recv(new_weight, src=comm_rank, group=comm_group) dist.send(local_weight, dst=comm_rank, group=comm_group) weight.data[expert_idx] = new_weight - if working_weight is not None: - working_weight.data[expert_idx] = new_weight.to(working_weight.device).to(working_weight.dtype) def _swap_expert_param_and_optim( self, @@ -287,23 +285,30 @@ def _swap_expert_param_and_optim( # need to update master and working param if master param exists # else just update working param if weight in optim.optim.state: - master_weight_ptr = weight - working_weight_ptr = None + master_weight_ptr = None + working_weight_ptr = weight + exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"] + exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"] else: master_weight_ptr = optim._param_store.working_to_master_param[id(weight)] - working_weight_ptr = master_weight_ptr - exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"] - exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"] + working_weight_ptr = weight + exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"] + exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"] # exchange weight self._swap_expert_single_tensor( - master_weight_ptr, + working_weight_ptr, expert_idx, comm_group, send_first, comm_rank, - working_weight_ptr, ) + if master_weight_ptr is not None: + # TODO: exchange master weight, skip for now + # master weight is shared by dp group + tmp = working_weight_ptr.view(-1).split( + working_weight_ptr.numel() // dist.get_world_size(self.moe_dp_group))[dist.get_rank(self.moe_dp_group)] + master_weight_ptr.data.copy_(tmp.clone().detach().to(master_weight_ptr.device).to(master_weight_ptr.dtype)) # exchange optim self._swap_expert_single_tensor(exp_avg_ptr, expert_idx, comm_group, send_first, comm_rank) self._swap_expert_single_tensor(exp_avg_sq_ptr, expert_idx, comm_group, send_first, comm_rank) @@ -334,21 +339,22 @@ def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None weight_list.append(self.experts.wo) # gate optim should be obtained first - local_range = slice(local_rank * self.local_expert_num, (local_rank + 1) * self.local_expert_num) - local_gate_shape = self.gate[local_range].shape + gate_shape = self.gate.shape # get master weight and optim master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)] gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"] gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"] # gather - global_master_gate_weight = self._gather_global_dp_group(master_gate_weight.view(local_gate_shape)) - global_gate_exp_avg = self._gather_global_dp_group(gate_exp_avg.view(local_gate_shape)) - global_gate_exp_avg_sq = self._gather_global_dp_group(gate_exp_avg_sq.view(local_gate_shape)) + global_master_gate_weight = self._gather_global_dp_group(master_gate_weight).view(gate_shape) + global_gate_exp_avg = self._gather_global_dp_group(gate_exp_avg).view(gate_shape) + global_gate_exp_avg_sq = self._gather_global_dp_group(gate_exp_avg_sq).view(gate_shape) assert 
(self.gate.shape == global_master_gate_weight.shape == global_gate_exp_avg.shape == global_gate_exp_avg_sq.shape) for swap in swap_list: source_group, source_idx, target_group, target_idx = swap + source_rank = self.moe_ep_ranks[source_group] + target_rank = self.moe_ep_ranks[target_group] # exchange expert if local_rank in [source_group, target_group]: for weight in weight_list: @@ -358,7 +364,7 @@ def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None source_idx, self.moe_ep_group, True, - target_group, + target_rank, optim, ) elif local_rank == target_group: @@ -367,7 +373,7 @@ def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None target_idx, self.moe_ep_group, False, - source_group, + source_rank, optim, ) # exchange gate @@ -387,9 +393,17 @@ def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None ) # update gate - master_gate_weight.data.copy_(global_master_gate_weight[local_range].data.view(-1)) - gate_exp_avg.data.copy_(global_gate_exp_avg[local_range].data.view(-1)) - gate_exp_avg_sq.data.copy_(global_gate_exp_avg_sq[local_range].data.view(-1)) + dp_group_rank = dist.get_rank(self.global_dp_group) + dp_group_size = dist.get_world_size(self.global_dp_group) + global_master_gate_weight = global_master_gate_weight.view(-1).split(global_master_gate_weight.numel() // + dp_group_size)[dp_group_rank] + master_gate_weight.data.copy_(global_master_gate_weight) + global_gate_exp_avg = global_gate_exp_avg.view(-1).split(global_gate_exp_avg.numel() // + dp_group_size)[dp_group_rank] + gate_exp_avg.data.copy_(global_gate_exp_avg) + global_gate_exp_avg_sq = global_gate_exp_avg_sq.view(-1).split(global_gate_exp_avg_sq.numel() // + dp_group_size)[dp_group_rank] + gate_exp_avg_sq.data.copy_(global_gate_exp_avg_sq) @torch.no_grad() def update_load(self, load: Tensor) -> None: diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py index 64e8b40bd16e..9e1d649519ee 100644 --- a/tests/test_moe/test_moe_load_balance.py +++ b/tests/test_moe/test_moe_load_balance.py @@ -113,7 +113,16 @@ def run_hybrid_zero_optim_test(local_rank, world_size, stage=1): torch_model = torch_model.cuda() MOE_MANAGER.__init__() - MOE_MANAGER.setup(seed=42, max_ep_size=2, use_ep_inside=False, parallel="EP") + MOE_MANAGER.setup( + seed=42, + max_ep_size=2, + use_ep_inside=False, + parallel="EP", + enable_load_balance=True, + tolerance=0.1, + beam_width=8, + group_swap_factor=0.4, + ) zero_model = MoeModel(checkpoint=True) extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group) @@ -156,7 +165,7 @@ def run_hybrid_zero_optim_test(local_rank, world_size, stage=1): zero_optimizer.step() zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) torch_out = run_fwd_bwd(torch_model, data, label, criterion, None) - assert torch.allclose(zero_out, torch_out, atol=3e-7), f"zero_out:{zero_out}\ntorch_out{torch_out}" + assert torch.allclose(zero_out, torch_out, atol=8e-4), f"zero_out:{zero_out}\ntorch_out{torch_out}" def run_dist(rank, world_size, port): @@ -176,9 +185,9 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @pytest.mark.parametrize("world_size", [4]) @rerun_if_address_is_in_use() -def test_moe_zero_optim(world_size): +def test_moe_load_balance(world_size): spawn(run_dist, world_size) if __name__ == "__main__": - test_moe_zero_optim(world_size=4) + test_moe_load_balance(world_size=4) From 
c2ebb1f15987f306903b50db31732cba61e48b11 Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao
Date: Mon, 16 Oct 2023 10:05:39 +0800
Subject: [PATCH 09/10] fix ranks

---
 colossalai/moe/load_balance.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py
index 2690aeb4fe63..b2fb672329c2 100644
--- a/colossalai/moe/load_balance.py
+++ b/colossalai/moe/load_balance.py
@@ -29,7 +29,7 @@ def __init__(
         self.experts: BaseMLPExperts = experts
         self.gate: nn.Parameter = gate
         self.moe_ep_group: ProcessGroup = ep_group
-        self.moe_ep_ranks = dist.get_process_group_ranks(self.moe_ep_group)
+        self.moe_ep_ranks = MOE_MANAGER.parallel_info_dict[dist.get_world_size(self.moe_ep_group)].ep_group_ranks
         self.moe_dp_group: ProcessGroup = dp_group
         self.tolerance = tolerance
         self.beam_width = beam_width

From d153572b78c23b831e091a65aa407967b188f5ea Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao
Date: Mon, 16 Oct 2023 17:04:11 +0800
Subject: [PATCH 10/10] update

---
 colossalai/moe/layers.py                | 1 -
 tests/test_moe/test_moe_load_balance.py | 6 +++---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py
index bd878047da78..3f82a0fa23fd 100644
--- a/colossalai/moe/layers.py
+++ b/colossalai/moe/layers.py
@@ -300,4 +300,3 @@ def _apply_recursive(module: nn.Module):
             _apply_recursive(sub_module)

     _apply_recursive(model)
-    dist.barrier()

diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py
index 9e1d649519ee..b4eea04bc85a 100644
--- a/tests/test_moe/test_moe_load_balance.py
+++ b/tests/test_moe/test_moe_load_balance.py
@@ -54,7 +54,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
     )
     zero_model = MoeModel(checkpoint=True)
     zero_optimizer = torch.optim.Adam(zero_model.parameters())
-    plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
+    plugin = LowLevelZeroPlugin(stage=stage, precision="bf16", verbose=True)
     booster = Booster(plugin=plugin)
     zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)

@@ -64,11 +64,11 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
     for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
         torch_param.data.copy_(zero_param.data)
     torch_optimizer = torch.optim.Adam(torch_model.parameters())
-    torch_model = torch_model.cuda()
+    torch_model = torch_model.cuda().bfloat16()
     grad_handler = MoeGradientHandler(torch_model)

     # run to update expert load
-    data = torch.randn(16, 4).cuda() / (local_rank + 1)
+    data = torch.randn(16, 4).cuda().bfloat16() / 1000 / (local_rank + 1)
     label = torch.randint(0, 4, (16,)).cuda()

     # run torch model twice
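Taken together, the series leaves a small user-facing surface: enable the feature through MOE_MANAGER.setup() and call apply_load_balance(model, optim) periodically during training. The sketch below shows how the pieces fit; it is distilled from the tests above, and the model, data loader, and rebalance interval are placeholders rather than anything prescribed by the patches:

    import torch

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin
    from colossalai.moe.layers import apply_load_balance
    from colossalai.moe.manager import MOE_MANAGER
    from tests.test_moe.moe_utils import MoeModel

    colossalai.launch_from_torch(config=dict())

    # configure load balancing before the model is built so SparseMLP picks it up
    MOE_MANAGER.setup(
        seed=42,
        parallel="EP",
        enable_load_balance=True,
        tolerance=0.1,          # relative load imbalance tolerated per device
        beam_width=8,           # beam width of the swap search
        group_swap_factor=0.4,  # penalty for swapping the same group repeatedly
    )

    model = MoeModel(checkpoint=True)
    optimizer = torch.optim.Adam(model.parameters())
    booster = Booster(plugin=LowLevelZeroPlugin(stage=1, precision="fp32"))
    model, optimizer, _, _, _ = booster.boost(model, optimizer)
    criterion = torch.nn.CrossEntropyLoss()

    for step, (data, label) in enumerate(train_loader):  # train_loader is a placeholder
        loss = criterion(model(data.cuda()), label.cuda())
        optimizer.backward(loss)  # the boosted LowLevelZeroOptimizer owns backward
        optimizer.step()
        optimizer.zero_grad()
        # every forward pass accumulates expert load; rebalance once in a while
        if (step + 1) % 100 == 0:
            apply_load_balance(model, optimizer)

Note that apply_load_balance takes the optimizer because rebalancing must move the matching Adam exp_avg/exp_avg_sq states and ZeRO master weights along with each swapped expert, exactly as _swap_moe_param above does.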