From 8d214539709172945e9e80f112b49e4ece891493 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 27 Oct 2023 13:22:36 +0800 Subject: [PATCH 01/15] fix: add warning for EP different behavior --- tests/test_moe/moe_utils.py | 85 +------------------- tests/test_moe/test_moe_ep_tp.py | 130 ++++++++++++++++++++++++++----- 2 files changed, 114 insertions(+), 101 deletions(-) diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 40adeab717de..721a4796abfd 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -8,7 +8,6 @@ from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_moe_epsize_param_dict -from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor class MoeModel(nn.Module): @@ -76,84 +75,6 @@ def handle_gradient(self): ) -def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: - """Sync the parameters of tp model from ep model - - Args: - tp_model (MoeModule) - ep_model (MoeModule) - """ - for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()): - assert tp_name == ep_name - if not is_moe_tensor(tp_param): - if assert_grad_flag: - assert torch.allclose(tp_param, ep_param) - assert torch.allclose(tp_param.grad, ep_param.grad) - else: - tp_param.data.copy_(ep_param.data) - continue - - # gather param from ep model - param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) - all_param = torch.cat(param_list, dim=0) - if assert_grad_flag: - grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) - all_grad = torch.cat(grad_list, dim=0) - - # get tp param - tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2] - tp_rank = get_ep_rank(tp_param) - tp_dim = tp_dim[0] + 1 - tp_slice = [slice(None)] * tp_dim + [ - slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1)) - ] - new_tp_param = all_param[tuple(tp_slice)] - if assert_grad_flag: - new_grad = all_grad[tuple(tp_slice)] - if assert_grad_flag: - assert torch.allclose(tp_param, new_tp_param) - assert torch.allclose(tp_param.grad, new_grad) - else: - tp_param.data.copy_(new_tp_param.data) - - -def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: - """Sync the parameters of tp model from ep model - - Args: - local_model (MoeModule) - ep_model (MoeModule) - """ - for (local_name, local_param), (ep_name, ep_param) in zip( - local_model.named_parameters(), ep_model.named_parameters() - ): - assert local_name == ep_name - if "experts" not in local_name: - if assert_grad_flag: - assert torch.allclose(local_param, ep_param) - assert torch.allclose(local_param.grad, ep_param.grad) - else: - local_param.data.copy_(ep_param.data) - continue - - # gather param from ep model - param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) - all_param = torch.cat(param_list, dim=0) - if assert_grad_flag: - grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) - all_grad = torch.cat(grad_list, dim=0) - - if assert_grad_flag: - assert torch.allclose(local_param, all_param) - assert torch.allclose(local_param.grad, all_grad) - else: - local_param.data.copy_(all_param.data) - - def assert_not_equal_in_group(tensor, process_group=None): # all gather tensors from different ranks world_size = dist.get_world_size(process_group) @@ -164,6 +85,6 @@ def assert_not_equal_in_group(tensor, process_group=None): for i in range(world_size - 1): a = tensor_list[i] b = tensor_list[i + 1] - assert not torch.allclose( - a, b - ), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}" + assert not torch.allclose(a, b), \ + (f"expected tensors on rank {i} and {i + 1} not to be equal " + f"but they are, {a} vs {b}") diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 2c9bbd446e22..7079df2d3655 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -1,3 +1,5 @@ +import warnings + import pytest import torch import torch.distributed as dist @@ -6,9 +8,75 @@ from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import sync_moe_model_param +from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, sync_tp_from_ep +from tests.test_moe.moe_utils import MoeGradientHandler + + +def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_flag: bool = False) -> None: + """Sync the parameters of tp model from local model + + Args: + tp_model (MoeModule) + local_model (MoeModule) + """ + for (tp_name, tp_param), (local_name, local_param) in \ + zip(tp_model.named_parameters(), local_model.named_parameters()): + assert tp_name == local_name + if not is_moe_tensor(tp_param): + if assert_grad_flag: + assert torch.allclose(tp_param, local_param) + assert torch.allclose(tp_param.grad, local_param.grad) + else: + tp_param.data.copy_(local_param.data) + continue + + tp_rank = get_ep_rank(tp_param) + tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape, local_param.shape)) if d1 != d2][0] + tp_slice = [slice(None)] * tp_dim + [ + slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1)) + ] + + if assert_grad_flag: + assert torch.allclose(tp_param, local_param[tuple(tp_slice)]) + assert torch.allclose(tp_param.grad, local_param.grad[tuple(tp_slice)]) + else: + tp_param.data.copy_(local_param[tuple(tp_slice)].data) + + +def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: + """Sync the parameters of tp model from ep model + + Args: + local_model (MoeModule) + ep_model (MoeModule) + """ + for (local_name, local_param), (ep_name, ep_param) in \ + zip(local_model.named_parameters(), ep_model.named_parameters()): + assert local_name == ep_name + if "experts" not in local_name: + if assert_grad_flag: + assert torch.allclose(local_param, ep_param) + assert torch.allclose(local_param.grad, ep_param.grad) + else: + local_param.data.copy_(ep_param.data) + continue + + # gather param from ep model + param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) + all_param = torch.cat(param_list, dim=0) + if assert_grad_flag: + grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) + all_grad = torch.cat(grad_list, dim=0) + + if assert_grad_flag: + assert torch.allclose(local_param, all_param) + assert torch.allclose(local_param.grad, all_grad) + else: + local_param.data.copy_(all_param.data) def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int): @@ -35,25 +103,44 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group) assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group) grad_handler = MoeGradientHandler(ep_model) - # sync tp param - sync_tp_from_ep(tp_model, ep_model) # sync local param sync_local_from_ep(local_model, ep_model) + # sync tp param + sync_tp_from_local(tp_model, local_model) rank = dist.get_rank() torch.cuda.manual_seed(seed) - tp_data = torch.randn(batch_size, dim, device=get_current_device()) + input_data = torch.randn(batch_size, dim, device=get_current_device()) micro_batch_size = batch_size // world_size - ep_data = tp_data.detach()[micro_batch_size * rank : micro_batch_size * (rank + 1)] + index = rank * micro_batch_size + shard_data = input_data.detach()[index:index + micro_batch_size] - out_local = local_model(tp_data) + out_local = local_model(input_data) MOE_MANAGER.reset_loss() - out_tp = tp_model(tp_data) + out_tp = tp_model(input_data) MOE_MANAGER.reset_loss() - out_ep = ep_model(ep_data) + out_ep = ep_model(shard_data) MOE_MANAGER.reset_loss() - assert torch.allclose(out_ep, out_tp[micro_batch_size * rank : micro_batch_size * (rank + 1)]) - assert torch.allclose(out_ep, out_local[micro_batch_size * rank : micro_batch_size * (rank + 1)]) + + assert torch.allclose(out_tp, out_local, atol=1e-6), \ + f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_local))}" + try: + assert torch.allclose(out_ep, out_local[index:index + micro_batch_size], atol=1e-6), \ + f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local[index:index + micro_batch_size]))}" + except AssertionError as e: + """ + e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1 + router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2 + However, in ep mode, there are 2 separate routers dealing with sharded data. + Assume router 0 handles token [01] and router 1 handles token [23]. + Note that for each router the capacity is only 1 !!! + Thus, router 0 may yields [0] --> [0] or [1] --> [0], but not both. + The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature. + """ + warnings.warn( + "EP may result in different behavior from local model. " + "Please check the comments for details." + ) out_local.mean().backward() out_tp.mean().backward() @@ -62,20 +149,25 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group) assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group) - - sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) - sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True) + sync_tp_from_local(tp_model, local_model, assert_grad_flag=True) + try: + sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) + except AssertionError as e: + warnings.warn( + "EP may result in different behavior from local model. " + "Please check the comments for details." + ) @pytest.mark.dist -@pytest.mark.parametrize("num_experts", [4, 8]) -@pytest.mark.parametrize("batch_size", [4]) -@pytest.mark.parametrize("dim", [32]) -@pytest.mark.parametrize("seed", [42]) +@pytest.mark.parametrize("num_experts", [4, 64]) +@pytest.mark.parametrize("batch_size", [16]) +@pytest.mark.parametrize("dim", [256]) +@pytest.mark.parametrize("seed", [42, 127]) @rerun_if_address_is_in_use() def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int): spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed) -if __name__ == "__main__": - test_moe_ep_tp(num_experts=8, batch_size=8, dim=256, seed=42) +if __name__ == '__main__': + test_moe_ep_tp(num_experts=8, batch_size=32, dim=32, seed=42) From 317dd58c42a6ada9a54a8d35da294ea0edf8d7c7 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Wed, 1 Nov 2023 14:00:49 +0800 Subject: [PATCH 02/15] fix: use shard_data in ep & tp model --- tests/test_moe/test_moe_ep_tp.py | 71 ++++++++++++++++++++++++++------ 1 file changed, 59 insertions(+), 12 deletions(-) diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 7079df2d3655..52dc3eeb060b 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -45,6 +45,49 @@ def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_ tp_param.data.copy_(local_param[tuple(tp_slice)].data) +def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: + """Sync the parameters of tp model from ep model + + Args: + tp_model (MoeModule) + ep_model (MoeModule) + """ + for (tp_name, tp_param), (ep_name, ep_param) in \ + zip(tp_model.named_parameters(), ep_model.named_parameters()): + assert tp_name == ep_name + if not is_moe_tensor(tp_param): + if assert_grad_flag: + assert torch.allclose(tp_param, ep_param) + assert torch.allclose(tp_param.grad, ep_param.grad) + else: + tp_param.data.copy_(ep_param.data) + continue + + # gather param from ep model + param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) + all_param = torch.cat(param_list, dim=0) + if assert_grad_flag: + grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) + all_grad = torch.cat(grad_list, dim=0) + + # get tp param + tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2][0] + 1 + tp_rank = get_ep_rank(tp_param) + tp_slice = [slice(None)] * tp_dim + [ + slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1)) + ] + new_tp_param = all_param[tuple(tp_slice)] + if assert_grad_flag: + new_grad = all_grad[tuple(tp_slice)] + if assert_grad_flag: + assert torch.allclose(tp_param, new_tp_param) + assert torch.allclose(tp_param.grad, new_grad) + else: + tp_param.data.copy_(new_tp_param.data) + + def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: """Sync the parameters of tp model from ep model @@ -102,31 +145,34 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size dist_dict = MOE_MANAGER.parallel_info_dict assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group) assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group) - grad_handler = MoeGradientHandler(ep_model) + ep_grad_handler = MoeGradientHandler(ep_model) # sync local param sync_local_from_ep(local_model, ep_model) # sync tp param - sync_tp_from_local(tp_model, local_model) + sync_tp_from_ep(tp_model, ep_model) + tp_grad_handler = MoeGradientHandler(tp_model) rank = dist.get_rank() torch.cuda.manual_seed(seed) input_data = torch.randn(batch_size, dim, device=get_current_device()) micro_batch_size = batch_size // world_size index = rank * micro_batch_size + # NOTE: ep & tp takes in sharded data for each process shard_data = input_data.detach()[index:index + micro_batch_size] out_local = local_model(input_data) MOE_MANAGER.reset_loss() - out_tp = tp_model(input_data) + out_tp = tp_model(shard_data) MOE_MANAGER.reset_loss() out_ep = ep_model(shard_data) MOE_MANAGER.reset_loss() - assert torch.allclose(out_tp, out_local, atol=1e-6), \ - f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_local))}" + assert torch.allclose(out_tp, out_ep, atol=1e-6), \ + f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}" try: - assert torch.allclose(out_ep, out_local[index:index + micro_batch_size], atol=1e-6), \ - f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local[index:index + micro_batch_size]))}" + out_local_slice = out_local[index:index + micro_batch_size] + assert torch.allclose(out_ep, out_local_slice, atol=1e-6), \ + f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}" except AssertionError as e: """ e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1 @@ -138,23 +184,24 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature. """ warnings.warn( - "EP may result in different behavior from local model. " + "EP & TP may result in different behavior from local model. " "Please check the comments for details." ) out_local.mean().backward() out_tp.mean().backward() + tp_grad_handler.handle_gradient() out_ep.mean().backward() - grad_handler.handle_gradient() + ep_grad_handler.handle_gradient() assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group) assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group) - sync_tp_from_local(tp_model, local_model, assert_grad_flag=True) + sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True) try: sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) except AssertionError as e: warnings.warn( - "EP may result in different behavior from local model. " + "EP & TP may result in different behavior from local model. " "Please check the comments for details." ) @@ -162,7 +209,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size @pytest.mark.dist @pytest.mark.parametrize("num_experts", [4, 64]) @pytest.mark.parametrize("batch_size", [16]) -@pytest.mark.parametrize("dim", [256]) +@pytest.mark.parametrize("dim", [64]) @pytest.mark.parametrize("seed", [42, 127]) @rerun_if_address_is_in_use() def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int): From 5388b7bb7b3fc3f72e436bda250041b4f46b16c0 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Wed, 1 Nov 2023 16:12:01 +0800 Subject: [PATCH 03/15] to: add used_capacity --- colossalai/moe/_operation.py | 6 ++-- colossalai/moe/layers.py | 57 ++++++++++++++++++++++++------------ colossalai/moe/routers.py | 47 +++++++++++++++++------------ 3 files changed, 69 insertions(+), 41 deletions(-) diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index 542c6372790f..ec1ad6cddeba 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -6,8 +6,6 @@ from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup -from colossalai.moe.manager import MOE_MANAGER - MOE_KERNEL = None @@ -121,6 +119,8 @@ def forward( outputs: Tensor handle: Optional[Work], if overlap is True """ + assert ctx is not None or not overlap + if ctx is not None: ctx.comm_grp = group if not inputs.is_contiguous(): @@ -138,7 +138,7 @@ def forward( @staticmethod def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: return ( - AllToAll.forward(None, grad_outputs[0], ctx.comm_grp)[0], + AllToAll.forward(None, grad_outputs[0], ctx.comm_grp, False)[0], None, None, ) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index bd2cefbe9ab8..8b399e6d3f44 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -51,19 +51,19 @@ def __init__( hidden_size: int, intermediate_size: int, router_top_k: int = 1, - router_capacity_factor_train: Optional[float] = 1.25, - router_capacity_factor_eval: Optional[float] = 2.0, - router_min_capacity: Optional[int] = 4, + router_capacity_factor_train: float = 1.25, + router_capacity_factor_eval: float = 2.0, + router_min_capacity: int = 4, router_noisy_policy: Optional[str] = None, - router_drop_tks: Optional[bool] = True, + router_drop_tks: bool = True, mlp_activation: Optional[str] = None, - mlp_gated: Optional[bool] = False, - enable_load_balance: Optional[bool] = False, - load_balance_tolerance: Optional[float] = 0.1, - load_balance_beam_width: Optional[int] = 8, - load_balance_group_swap_factor: Optional[float] = 0.4, - enable_kernel: Optional[bool] = False, - enable_comm_overlap: Optional[bool] = False, + mlp_gated: bool = False, + enable_load_balance: bool = False, + load_balance_tolerance: float = 0.1, + load_balance_beam_width: int = 8, + load_balance_group_swap_factor: float = 0.4, + enable_kernel: bool = False, + enable_comm_overlap: bool = False, ): super().__init__() self.hidden_size = hidden_size @@ -132,7 +132,7 @@ def __init__( def reset_parameters(self): torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size)) - def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor) -> torch.Tensor: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size) @@ -158,7 +158,8 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: self.load_balancer.update_load(expert_load) # the result from the router - route_result_list = self.router(inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group) + used_capacity, *route_result_list = self.router( + inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group) # dispatch_data: (num_experts, capacity, hidden_size) if self.enable_kernel: @@ -170,9 +171,17 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # expert_output: (num_groups, num_experts, capacity, hidden_size) if self.expert_parallel == "EP": - expert_output = self._ep_process(dispatch_data, overlap=self.enable_comm_overlap) + expert_output = self._ep_process( + dispatch_data, + used_capacity, + overlap=self.enable_comm_overlap + ) elif self.expert_parallel == "TP": - expert_output = self._tp_process(dispatch_data, overlap=self.enable_comm_overlap) + expert_output = self._tp_process( + dispatch_data, + used_capacity, + overlap=self.enable_comm_overlap + ) elif self.expert_parallel is None: expert_output = self._local_process(dispatch_data) else: @@ -196,7 +205,12 @@ def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: expert_out = self.experts(expert_in) return expert_out - def _ep_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor: + def _ep_process( + self, + dispatch_data: torch.Tensor, + used_capacity: torch.Tensor, + overlap: bool = False + ) -> torch.Tensor: """ Expert Parallel @@ -261,7 +275,12 @@ class Capsule: return output - def _tp_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor: + def _tp_process( + self, + dispatch_data: torch.Tensor, + used_capacity: torch.Tensor, + overlap: bool = False + ) -> torch.Tensor: """ without overlap: | C | @@ -295,8 +314,8 @@ class Capsule: NUM_CHUNK = 4 NUM_STAGES = 4 - assert (dispatch_data.shape[0] % NUM_CHUNK == 0 - ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" + assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ + "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" chunk_size = dispatch_data.shape[0] // NUM_CHUNK chunk_data = torch.split(dispatch_data, chunk_size, dim=0) output = torch.empty_like(dispatch_data) diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index 7960a74d4539..46c316588f8a 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -138,9 +138,10 @@ def __init__(self, self.select_policy = select_policy assert select_policy in {"first", "random"} if select_policy == "random": - self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()), - high=torch.tensor(1.0, - device=get_current_device())).rsample + self.uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(0.0, device=get_current_device()), + high=torch.tensor(1.0, device=get_current_device()) + ).rsample def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: """ @@ -165,7 +166,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti top1_idx = torch.argmax(inputs, dim=-1) mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - # caculate router loss + # calculate router loss self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts) self.set_z_loss(inputs) self.pop_router_loss() @@ -187,18 +188,19 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti raise NotImplementedError("Not support such select policy yet.") ranks = torch.sum(mask * ranks, dim=-1) + used_capacity = mask.sum(dim=0) if use_kernel: mask = torch.sum(mask, dim=-1) mask = torch.stack([mask], dim=0).to(torch.int32) dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) - return probs, mask, dest_idx, num_experts * capacity + return used_capacity, probs, mask, dest_idx, num_experts * capacity else: ranks = F.one_hot(ranks, num_classes=capacity) weight = mask * probs.type_as(inputs) combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) sec_mask = combine_weights.bool() - return combine_weights, sec_mask + return used_capacity, combine_weights, sec_mask class Top2Router(MoeRouter): @@ -256,7 +258,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti cmask = (mask1 + mask2) # loss: [s, e] cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 - # caculate loss + # calculate loss expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) self.set_aux_loss(probs, expert_indices, num_experts) self.set_z_loss(inputs) @@ -273,6 +275,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti mask1 *= torch.lt(rank1, capacity) mask2 *= torch.lt(rank2, capacity) + used_capacity = mask1.sum(dim=0) + mask2.sum(dim=0) rank1 = torch.sum(mask1 * rank1, dim=-1) rank2 = torch.sum(mask2 * rank2, dim=-1) @@ -284,18 +287,23 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) - return probs, mask, dest_idx, num_experts * capacity + return used_capacity, probs, mask, dest_idx, num_experts * capacity else: - # >>> original code - # weight1 = mask1 * probs.type_as(inputs) - # weight2 = mask2 * probs.type_as(inputs) - # rank1_sc = F.one_hot(rank1, num_classes=capacity) - # rank2_sc = F.one_hot(rank2, num_classes=capacity) - - # cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) - # cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) - # cb_weight = cb_weight1 + cb_weight2 - # sec_mask = cb_weight.bool() + """ + The following code is equivalent to: + + ``` + weight1 = mask1 * probs.type_as(inputs) + weight2 = mask2 * probs.type_as(inputs) + rank1_sc = F.one_hot(rank1, num_classes=capacity) + rank2_sc = F.one_hot(rank2, num_classes=capacity) + + cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) + cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) + cb_weight = cb_weight1 + cb_weight2 + sec_mask = cb_weight.bool() + ``` + """ weight1 = mask1 * probs.type_as(inputs) weight2 = mask2 * probs.type_as(inputs) @@ -308,7 +316,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]] sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]] - return cb_weight, sec_mask + return used_capacity, cb_weight, sec_mask class TopKRouter(MoeRouter): @@ -353,6 +361,7 @@ def forward( Dispatch and combine arrays for routing with masked matmuls. """ # TODO: add parallel group + raise RuntimeError("Not tested yet.") num_groups, _, num_experts = router_probs.shape # Top-k router probability and corresponding expert indices for each token. From ea1ddc905ef53d3e8bc6b5818335dc7d6c87ad58 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Wed, 1 Nov 2023 16:20:10 +0800 Subject: [PATCH 04/15] fix: fix router test --- colossalai/moe/routers.py | 3 +-- tests/test_moe/test_moe_router.py | 10 +++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index 46c316588f8a..c5bb508621b2 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -360,8 +360,7 @@ def forward( Returns: Dispatch and combine arrays for routing with masked matmuls. """ - # TODO: add parallel group - raise RuntimeError("Not tested yet.") + # TODO: FIXME: add parallel group num_groups, _, num_experts = router_probs.shape # Top-k router probability and corresponding expert indices for each token. diff --git a/tests/test_moe/test_moe_router.py b/tests/test_moe/test_moe_router.py index fce0d1064950..c2b581eb1556 100644 --- a/tests/test_moe/test_moe_router.py +++ b/tests/test_moe/test_moe_router.py @@ -20,22 +20,22 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex router.train() if isinstance(router, TopKRouter): - combine_array, dispatch_mask = router(x, expert_capacity=2) + _, combine_array, dispatch_mask = router(x, expert_capacity=2) else: - combine_array, dispatch_mask = router(x) + _, combine_array, dispatch_mask = router(x) assert combine_array.shape[:-1] == x.shape assert dispatch_mask.shape[:-1] == x.shape assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value) router.eval() if isinstance(router, TopKRouter): - combine_array, dispatch_mask = router(x, expert_capacity=2) + _, combine_array, dispatch_mask = router(x, expert_capacity=2) else: - combine_array, dispatch_mask = router(x) + _, combine_array, dispatch_mask = router(x) assert combine_array.shape[:-1] == x.shape assert dispatch_mask.shape[:-1] == x.shape assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value) if __name__ == "__main__": - test_router_forward(Top1Router(), 4, 4, 4, 1) + test_router_forward(Top2Router(), 4, 4, 4, 1) From e7616809565d6ea485b99cddc5dfe16fe37abaeb Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 3 Nov 2023 17:38:02 +0800 Subject: [PATCH 05/15] feat: add create_ep_node_group --- colossalai/moe/utils.py | 37 +++++++++++++++++++++++- colossalai/tensor/moe_tensor/moe_info.py | 8 ++--- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index 0938e4206fda..d5c90e7eb872 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -1,5 +1,6 @@ import contextlib -from typing import Any, Callable, Dict, List +import os +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torch.distributed as dist @@ -175,3 +176,37 @@ def sync_moe_model_param(model: nn.Module): def set_moe_args(config: Any, args: dict): for k, v in args.items(): setattr(config, k, v) + + +def create_ep_node_group( + ep_group: dist.ProcessGroup +) -> Tuple[dist.ProcessGroup, + Optional[dist.ProcessGroup]]: + """ + e.g., If ep_group = [1, 2, 5, 6], and nproc_per_node = 4 + Then, ep_intra_group = [1, 2] & [5, 6], ep_inter_group = [1, 5] & None + """ + assert dist.is_initialized(), "Please initialize torch.distributed first." + + rank = dist.get_rank() + node_rank = int(os.environ["GROUP_RANK"]) + nproc_per_node = int(os.environ["LOCAL_WORLD_SIZE"]) + num_node = dist.get_world_size() // nproc_per_node + ep_ranks = dist.get_process_group_ranks(ep_group) + + ep_intra_ranks = [ + node_rank * nproc_per_node + i + for i in range(nproc_per_node) + if i in ep_ranks + ] + ep_intra_node_group = dist.new_group(ep_intra_ranks) + + ep_inter_ranks = [ + min(ep_ranks) + i * nproc_per_node + for i in range(num_node) + ] + ep_inter_node_group = None + if rank in ep_inter_ranks and len(ep_inter_ranks) > 1: + ep_inter_node_group = dist.new_group(ep_inter_ranks) + + return ep_intra_node_group, ep_inter_node_group diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py index 5097ac1044e7..8044b3a86dce 100644 --- a/colossalai/tensor/moe_tensor/moe_info.py +++ b/colossalai/tensor/moe_tensor/moe_info.py @@ -15,12 +15,8 @@ def __init__(self, ep_inside: bool, ep_size: int, dp_size: int, pp_size: int = 1 ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True. """ self.pp_size, self.dp_size, self.ep_size = pp_size, dp_size, ep_size - if ep_inside: - self.pp_axis, self.dp_axis, self.ep_axis = 0, 1, 2 - self.pg = ProcessGroupMesh(self.pp_size, self.dp_size, self.ep_size) - else: - self.pp_axis, self.ep_axis, self.dp_axis = 0, 1, 2 - self.pg = ProcessGroupMesh(self.pp_size, self.ep_size, self.dp_size) + self.pp_axis, self.dp_axis, self.ep_axis = (0, 1, 2) if ep_inside else (0, 2, 1) + self.pg = ProcessGroupMesh(self.pp_size, self.dp_size, self.ep_size) self.ep_group = self.pg.get_group_along_axis(self.ep_axis) self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group) From 887c09e88f6e96c027eedc980ab99df55767046a Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 6 Nov 2023 16:28:07 +0800 Subject: [PATCH 06/15] feat: add create_ep_hierarchical_group fn --- colossalai/moe/utils.py | 40 +++++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index d5c90e7eb872..4bed71076c39 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -6,6 +6,7 @@ import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F +from colossalai.cluster.process_group_mesh import ProcessGroupMesh from colossalai.moe.manager import MOE_MANAGER from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor @@ -178,35 +179,40 @@ def set_moe_args(config: Any, args: dict): setattr(config, k, v) -def create_ep_node_group( - ep_group: dist.ProcessGroup -) -> Tuple[dist.ProcessGroup, +def create_ep_hierarchical_group( + ep_group: dist.ProcessGroup, + num_node: int, + nproc_per_node: int, +) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]: """ e.g., If ep_group = [1, 2, 5, 6], and nproc_per_node = 4 Then, ep_intra_group = [1, 2] & [5, 6], ep_inter_group = [1, 5] & None """ assert dist.is_initialized(), "Please initialize torch.distributed first." + assert dist.get_world_size() == num_node * nproc_per_node - rank = dist.get_rank() - node_rank = int(os.environ["GROUP_RANK"]) - nproc_per_node = int(os.environ["LOCAL_WORLD_SIZE"]) - num_node = dist.get_world_size() // nproc_per_node + group_mesh = ProcessGroupMesh(num_node, nproc_per_node) ep_ranks = dist.get_process_group_ranks(ep_group) - ep_intra_ranks = [ - node_rank * nproc_per_node + i - for i in range(nproc_per_node) - if i in ep_ranks - ] - ep_intra_node_group = dist.new_group(ep_intra_ranks) + ep_intra_node_group = None + for i in range(num_node): + ep_intra_ranks = [ + i * nproc_per_node + j + for j in range(nproc_per_node) + if j in ep_ranks + ] + group = group_mesh.get_group(ep_intra_ranks) + if group is not None: + assert ep_intra_node_group is None + ep_intra_node_group = group + ep_inter_node_group = None ep_inter_ranks = [ - min(ep_ranks) + i * nproc_per_node + ep_ranks[0] + i * nproc_per_node for i in range(num_node) ] - ep_inter_node_group = None - if rank in ep_inter_ranks and len(ep_inter_ranks) > 1: - ep_inter_node_group = dist.new_group(ep_inter_ranks) + if len(ep_inter_ranks) > 1: + ep_inter_node_group = group_mesh.get_group(ep_inter_ranks) return ep_intra_node_group, ep_inter_node_group From d2e86b078c80379e0bf185d07be68a1e44344083 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 6 Nov 2023 16:30:59 +0800 Subject: [PATCH 07/15] feat: add HierarchicalAllToAll --- colossalai/moe/_operation.py | 62 ++++++++++++++++++++++++++++++++++-- colossalai/moe/layers.py | 24 +++++++++----- 2 files changed, 76 insertions(+), 10 deletions(-) diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index ec1ad6cddeba..4e5798788481 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -62,7 +62,7 @@ class ReduceScatter(torch.autograd.Function): def forward( ctx: Any, inputs: Tensor, - group: Optional[ProcessGroup] = None, + group: ProcessGroup, overlap: bool = False, ) -> Tuple[Tensor, Any]: """ @@ -111,7 +111,7 @@ class AllToAll(torch.autograd.Function): def forward( ctx: Any, inputs: Tensor, - group: Optional[ProcessGroup] = None, + group: ProcessGroup, overlap: bool = False, ) -> Tuple[Tensor, Any]: """ @@ -144,6 +144,64 @@ def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: ) +class HierarchicalAllToAll(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + inputs: Tensor, + groups: Tuple[ProcessGroup], + ) -> Tensor: + """ + Returns: + outputs: Tensor + """ + if ctx is not None: + ctx.comm_grps = groups + intra_node_group, inter_node_group = groups + + local_world_size = dist.get_world_size(intra_node_group) + num_group = dist.get_world_size(inter_node_group) if inter_node_group is not None else 1 + world_size = local_world_size * num_group + src_rank = dist.get_process_group_ranks(intra_node_group)[0] + outputs = torch.empty_like(inputs) + + if dist.get_rank() == src_rank: + # intra-node gather + intra_output = [torch.empty_like(inputs) for _ in range(local_world_size)] + dist.gather(inputs, intra_output, dst=src_rank, group=intra_node_group) + + intra_output = [v.chunk(world_size, dim=0) for v in intra_output] + intra_output = torch.cat(sum(zip(*intra_output), ())) + + # inter-node all-to-all + if inter_node_group is not None: + inter_output = torch.empty_like(intra_output) + dist.all_to_all_single(inter_output, intra_output, group=inter_node_group) + + # layout transform + inter_output = inter_output.chunk(num_group, dim=0) + inter_output = [v.chunk(local_world_size, dim=0) for v in inter_output] + intra_output = torch.cat(sum(zip(*inter_output), ())) + + # intra-node scatter + intra_output = list(intra_output.chunk(local_world_size, dim=0)) + dist.scatter(outputs, intra_output, src=src_rank, group=intra_node_group) + + else: + dist.gather(inputs, dst=src_rank, group=intra_node_group) + dist.scatter(outputs, src=src_rank, group=intra_node_group) + + return outputs + + @staticmethod + def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]: + return ( + HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps), + None, + ) + + class MoeDispatch(torch.autograd.Function): @staticmethod @custom_fwd diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index 8b399e6d3f44..816e2a6ce275 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -1,13 +1,13 @@ import dataclasses import math -from typing import Any, Optional, Tuple +from typing import Any, Callable, Optional, Tuple import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F -from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter +from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter from colossalai.moe.experts import MLPExperts from colossalai.moe.load_balance import LoadBalancer from colossalai.moe.manager import MOE_MANAGER @@ -64,6 +64,7 @@ def __init__( load_balance_group_swap_factor: float = 0.4, enable_kernel: bool = False, enable_comm_overlap: bool = False, + create_hierarchical_group: Optional[Callable] = lambda *args, **kwargs: None, ): super().__init__() self.hidden_size = hidden_size @@ -104,6 +105,7 @@ def __init__( if self.expert_parallel is not None: self.ep_group = get_ep_group(self.experts) self.ep_size = get_ep_size(self.experts) + self.ep_hierarchical_group = create_hierarchical_group(self.ep_group) self.dp_group = get_dp_group(self.experts) else: self.ep_group = None @@ -221,12 +223,18 @@ def _ep_process( torch.Tensor: (num_experts, capacity, hidden_size) """ if not overlap or dist.get_world_size(self.ep_group) == 1: - expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0] - expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) - expert_output = self.experts(expert_input) - expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0] - return expert_output - + if self.ep_hierarchical_group is not None: + expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group) + expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) + expert_output = self.experts(expert_input) + expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group) + return expert_output + else: + expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0] + expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) + expert_output = self.experts(expert_input) + expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0] + return expert_output else: @dataclasses.dataclass From a1bad942adb9b152062a2a81deba73ab8ec61632 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 6 Nov 2023 16:37:55 +0800 Subject: [PATCH 08/15] test: add hierarchical all2all test --- tests/test_moe/test_moe_ep_tp.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 52dc3eeb060b..23c7c556bd1e 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -7,7 +7,7 @@ import colossalai from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import sync_moe_model_param +from colossalai.moe.utils import create_ep_hierarchical_group, sync_moe_model_param from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device @@ -132,7 +132,12 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) MOE_MANAGER.__init__() MOE_MANAGER.setup(parallel="EP") - ep_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) + ep_model = SparseMLP( + num_experts=num_experts, + hidden_size=dim, + intermediate_size=dim * 2, + create_hierarchical_group=lambda group: create_ep_hierarchical_group(group, num_node=1, nproc_per_node=2) + ) MOE_MANAGER.__init__() MOE_MANAGER.setup(parallel="TP") tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) From 374aba31f985a6ebe60b330c70f0aa9557d250d6 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 6 Nov 2023 18:19:44 +0800 Subject: [PATCH 09/15] fix: fix test errors --- colossalai/moe/utils.py | 12 ++++++------ tests/test_moe/test_moe_router.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index 4bed71076c39..b068aa48a345 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -1,12 +1,10 @@ import contextlib -import os from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F -from colossalai.cluster.process_group_mesh import ProcessGroupMesh from colossalai.moe.manager import MOE_MANAGER from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor @@ -192,7 +190,7 @@ def create_ep_hierarchical_group( assert dist.is_initialized(), "Please initialize torch.distributed first." assert dist.get_world_size() == num_node * nproc_per_node - group_mesh = ProcessGroupMesh(num_node, nproc_per_node) + rank = dist.get_rank() ep_ranks = dist.get_process_group_ranks(ep_group) ep_intra_node_group = None @@ -202,8 +200,8 @@ def create_ep_hierarchical_group( for j in range(nproc_per_node) if j in ep_ranks ] - group = group_mesh.get_group(ep_intra_ranks) - if group is not None: + group = dist.new_group(ep_intra_ranks) + if rank in ep_intra_ranks: assert ep_intra_node_group is None ep_intra_node_group = group @@ -213,6 +211,8 @@ def create_ep_hierarchical_group( for i in range(num_node) ] if len(ep_inter_ranks) > 1: - ep_inter_node_group = group_mesh.get_group(ep_inter_ranks) + group = dist.new_group(ep_inter_ranks) + if rank in ep_inter_ranks: + ep_inter_node_group = group return ep_intra_node_group, ep_inter_node_group diff --git a/tests/test_moe/test_moe_router.py b/tests/test_moe/test_moe_router.py index c2b581eb1556..7ba7fa6f6b7d 100644 --- a/tests/test_moe/test_moe_router.py +++ b/tests/test_moe/test_moe_router.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize(["router", "num_groups"], [ (Top1Router(), 1), (Top2Router(), 1), - (TopKRouter(num_selected_experts=3), 4), + # (TopKRouter(num_selected_experts=3), 4), ]) @pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [ (4, 5, 8), From 9c98c73182fa52c201056f45348a456cbddea189 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 7 Nov 2023 10:20:37 +0800 Subject: [PATCH 10/15] fix: simplify create_ep_hierarchical_group --- colossalai/moe/utils.py | 12 +++++++++--- tests/test_moe/test_moe_ep_tp.py | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index b068aa48a345..a2cf99762ebd 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -1,4 +1,5 @@ import contextlib +import os from typing import Any, Callable, Dict, List, Optional, Tuple import torch @@ -179,8 +180,7 @@ def set_moe_args(config: Any, args: dict): def create_ep_hierarchical_group( ep_group: dist.ProcessGroup, - num_node: int, - nproc_per_node: int, + nproc_per_node: Optional[int] = None, ) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]: """ @@ -188,7 +188,13 @@ def create_ep_hierarchical_group( Then, ep_intra_group = [1, 2] & [5, 6], ep_inter_group = [1, 5] & None """ assert dist.is_initialized(), "Please initialize torch.distributed first." - assert dist.get_world_size() == num_node * nproc_per_node + if nproc_per_node is None: + nproc_per_node = os.environ.get("LOCAL_WORLD_SIZE") + assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually." + else: + assert dist.get_world_size() % nproc_per_node == 0, \ + "nproc_per_node should be a divisor of world_size." + num_node = dist.get_world_size() // nproc_per_node rank = dist.get_rank() ep_ranks = dist.get_process_group_ranks(ep_group) diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 23c7c556bd1e..b2331f9680eb 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -136,7 +136,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2, - create_hierarchical_group=lambda group: create_ep_hierarchical_group(group, num_node=1, nproc_per_node=2) + create_hierarchical_group=lambda group: create_ep_hierarchical_group(group, nproc_per_node=world_size) ) MOE_MANAGER.__init__() MOE_MANAGER.setup(parallel="TP") From 3d45d28562166329057980de4dc48ccd4225ebf7 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 7 Nov 2023 11:04:58 +0800 Subject: [PATCH 11/15] fix: add hierarchical_alltoall arg --- colossalai/moe/layers.py | 9 +++++---- examples/language/openmoe/benchmark/benchmark_cai.py | 9 ++++++--- examples/language/openmoe/model/modeling_openmoe.py | 3 +++ examples/language/openmoe/train.py | 7 +++++++ tests/test_moe/test_moe_ep_tp.py | 4 +++- 5 files changed, 24 insertions(+), 8 deletions(-) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index 816e2a6ce275..2714d6316151 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -1,6 +1,6 @@ import dataclasses import math -from typing import Any, Callable, Optional, Tuple +from typing import Any, Optional, Tuple import torch import torch.distributed as dist @@ -12,7 +12,7 @@ from colossalai.moe.load_balance import LoadBalancer from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.routers import MoeRouter, get_router_cls -from colossalai.moe.utils import get_noise_generator +from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size @@ -64,7 +64,7 @@ def __init__( load_balance_group_swap_factor: float = 0.4, enable_kernel: bool = False, enable_comm_overlap: bool = False, - create_hierarchical_group: Optional[Callable] = lambda *args, **kwargs: None, + enable_hierarchical_comm: bool = False, ): super().__init__() self.hidden_size = hidden_size @@ -105,7 +105,8 @@ def __init__( if self.expert_parallel is not None: self.ep_group = get_ep_group(self.experts) self.ep_size = get_ep_size(self.experts) - self.ep_hierarchical_group = create_hierarchical_group(self.ep_group) + self.ep_hierarchical_group = create_ep_hierarchical_group( + self.ep_group) if enable_hierarchical_comm else None self.dp_group = get_dp_group(self.experts) else: self.ep_group = None diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index f48ba9ef89a5..65562b386cf9 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -132,8 +132,10 @@ def parse_args(): # load balance parser.add_argument("--load_balance", action="store_true") - # overlap - parser.add_argument("--overlap_alltoall", action="store_true") + # overlap communication + parser.add_argument("--overlap_comm", action="store_true") + # hierarchical all-to-all + parser.add_argument("--hierarchical_alltoall", action="store_true") args = parser.parse_args() return args @@ -211,7 +213,8 @@ def main(): moe_layer_interval=config.moe_layer_interval, enable_load_balance=args.load_balance, enable_kernel=args.use_kernel, - enable_comm_overlap=args.overlap_alltoall, + enable_comm_overlap=args.overlap_comm, + enable_hierarchical_alltoall=args.hierarchical_alltoall, ) with skip_init(): model = OpenMoeForCausalLM(config) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index 7e3e6b3ed364..ec7644317903 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -70,6 +70,7 @@ def set_openmoe_args( load_balance_group_swap_factor: float = 0.4, enable_kernel: bool = False, enable_comm_overlap: bool = False, + enable_hierarchical_alltoall: bool = False, ) -> None: """ MoE related arguments. @@ -96,6 +97,7 @@ def set_openmoe_args( load_balance_group_swap_factor (float, optional): Expert load balance group swap factor. Longer value encourages less swap. Defaults to 0.4. enable_kernel (bool, optional): Use kernel optimization. Defaults to False. enable_comm_overlap (bool, optional): Use communication overlap for MoE. Recommended to enable for muiti-node training. Defaults to False. + enable_hierarchical_alltoall (bool, optional): Use hierarchical alltoall for MoE. Defaults to False. """ moe_args = dict( num_experts=num_experts, @@ -117,6 +119,7 @@ def set_openmoe_args( load_balance_group_swap_factor=load_balance_group_swap_factor, enable_kernel=enable_kernel, enable_comm_overlap=enable_comm_overlap, + enable_hierarchical_alltoall=enable_hierarchical_alltoall, ) set_moe_args(config, moe_args) diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index b4c45416c199..b084361661ac 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -190,6 +190,12 @@ def parse_args(): action="store_true", help="Use communication overlap for MoE. Recommended to enable for muiti-node training.", ) + # hierarchical all-to-all + parser.add_argument( + "--hierarchical_alltoall", + action="store_true", + help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.", + ) args = parser.parse_args() return args @@ -277,6 +283,7 @@ def main(): z_loss_factor=args.z_loss_factor, enable_load_balance=args.load_balance, enable_comm_overlap=args.comm_overlap, + enable_hierarchical_alltoall=args.hierarchical_alltoall, enable_kernel=args.use_kernel, ) with skip_init(): diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index b2331f9680eb..1bfdd7744a16 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -1,3 +1,4 @@ +import os import warnings import pytest @@ -132,11 +133,12 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) MOE_MANAGER.__init__() MOE_MANAGER.setup(parallel="EP") + os.environ["LOCAL_WORLD_SIZE"] = str(world_size) ep_model = SparseMLP( num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2, - create_hierarchical_group=lambda group: create_ep_hierarchical_group(group, nproc_per_node=world_size) + enable_hierarchical_comm=True, ) MOE_MANAGER.__init__() MOE_MANAGER.setup(parallel="TP") From 610de3428531247507060fb1878ade5a3798016e Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 7 Nov 2023 11:13:31 +0800 Subject: [PATCH 12/15] fix: fix environ typo --- colossalai/moe/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index a2cf99762ebd..5180f6ea6274 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -191,6 +191,7 @@ def create_ep_hierarchical_group( if nproc_per_node is None: nproc_per_node = os.environ.get("LOCAL_WORLD_SIZE") assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually." + nproc_per_node = int(nproc_per_node) else: assert dist.get_world_size() % nproc_per_node == 0, \ "nproc_per_node should be a divisor of world_size." From ee71effce518e8390a35fb4c6b03be31ea1991eb Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 7 Nov 2023 11:33:06 +0800 Subject: [PATCH 13/15] revert: revert process mesh order --- colossalai/tensor/moe_tensor/moe_info.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py index 8044b3a86dce..5097ac1044e7 100644 --- a/colossalai/tensor/moe_tensor/moe_info.py +++ b/colossalai/tensor/moe_tensor/moe_info.py @@ -15,8 +15,12 @@ def __init__(self, ep_inside: bool, ep_size: int, dp_size: int, pp_size: int = 1 ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True. """ self.pp_size, self.dp_size, self.ep_size = pp_size, dp_size, ep_size - self.pp_axis, self.dp_axis, self.ep_axis = (0, 1, 2) if ep_inside else (0, 2, 1) - self.pg = ProcessGroupMesh(self.pp_size, self.dp_size, self.ep_size) + if ep_inside: + self.pp_axis, self.dp_axis, self.ep_axis = 0, 1, 2 + self.pg = ProcessGroupMesh(self.pp_size, self.dp_size, self.ep_size) + else: + self.pp_axis, self.ep_axis, self.dp_axis = 0, 1, 2 + self.pg = ProcessGroupMesh(self.pp_size, self.ep_size, self.dp_size) self.ep_group = self.pg.get_group_along_axis(self.ep_axis) self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group) From f70a73bee4b1e9b1d805a5c3749b4aa12ba5d856 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 7 Nov 2023 11:37:58 +0800 Subject: [PATCH 14/15] to: add todo mark --- colossalai/moe/_operation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index 4e5798788481..abc221fea2ad 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -156,6 +156,7 @@ def forward( Returns: outputs: Tensor """ + # TODO: we can reduce comm volume by removing empty capacity if ctx is not None: ctx.comm_grps = groups intra_node_group, inter_node_group = groups From 867fc33f51121a45ab597e18f5361ffedde5615c Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Thu, 9 Nov 2023 10:42:52 +0800 Subject: [PATCH 15/15] fix: skip hierarchical_comm if torch < 1.13.1 --- tests/test_moe/test_moe_ep_tp.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 1bfdd7744a16..d5557a41f139 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -8,7 +8,7 @@ import colossalai from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import create_ep_hierarchical_group, sync_moe_model_param +from colossalai.moe.utils import sync_moe_model_param from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device @@ -134,11 +134,12 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size MOE_MANAGER.__init__() MOE_MANAGER.setup(parallel="EP") os.environ["LOCAL_WORLD_SIZE"] = str(world_size) + enable_hierarchical_comm = torch.__version__ >= "1.13.1" ep_model = SparseMLP( num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2, - enable_hierarchical_comm=True, + enable_hierarchical_comm=enable_hierarchical_comm ) MOE_MANAGER.__init__() MOE_MANAGER.setup(parallel="TP")