From 4eba952c6d3bf0d26ed58fd67633da94ada837c0 Mon Sep 17 00:00:00 2001 From: CWHer Date: Wed, 11 Oct 2023 12:32:45 +0800 Subject: [PATCH 01/10] test: add more ep/tp test case --- tests/test_moe/test_moe_ep_tp.py | 70 +++++++++++++++++++++++--------- tests/test_moe/test_moe_local.py | 65 ----------------------------- 2 files changed, 50 insertions(+), 85 deletions(-) delete mode 100644 tests/test_moe/test_moe_local.py diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 2bbf739ebbd4..10ccd8ebd2c2 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -8,57 +8,87 @@ from colossalai.moe.utils import sync_moe_model_param from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -from tests.test_moe.moe_utils import MoeGradientHandler, sync_tp_from_ep +from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, sync_tp_from_ep -BATCH_SIZE = 4 -DIM = 16 +def run_test(rank: int, + world_size: int, + port: int, + num_experts: int, + batch_size: int, + dim: int, + seed: int): + assert batch_size % world_size == 0 -def run_test(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(42, parallel="EP") # MOE initialization - ep_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) - MOE_MANAGER.__init__() - MOE_MANAGER.setup(42, parallel="TP") - tp_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) + MOE_MANAGER.setup(seed) # MOE initialization + + ep_model = SparseMLP(num_experts=num_experts, + expert_parallel="EP", + hidden_size=dim, + intermediate_size=dim * 2) + tp_model = SparseMLP(num_experts=num_experts, + expert_parallel="TP", + hidden_size=dim, + intermediate_size=dim * 2) + local_model = SparseMLP(num_experts=num_experts, + expert_parallel=None, + hidden_size=dim, + intermediate_size=dim * 2) ep_model = ep_model.to(get_current_device()) tp_model = tp_model.to(get_current_device()) + local_model = local_model.to(get_current_device()) # sync ep param sync_moe_model_param(ep_model) dist_dict = MOE_MANAGER.parallel_info_dict - assert_equal_in_group(ep_model.experts.wi.data, dist_dict[2].dp_group) - assert_equal_in_group(ep_model.experts.wo.data, dist_dict[2].dp_group) + assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group) + assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group) grad_handler = MoeGradientHandler(ep_model) # sync tp param sync_tp_from_ep(tp_model, ep_model) + # sync local param + sync_local_from_ep(local_model, ep_model) rank = dist.get_rank() - torch.cuda.manual_seed(78) - tp_data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) - ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)] + torch.cuda.manual_seed(seed) + tp_data = torch.randn(batch_size, dim, device=get_current_device()) + micro_batch_size = batch_size // world_size + ep_data = tp_data.detach()[micro_batch_size * rank:micro_batch_size * (rank + 1)] + out_local = local_model(tp_data) + MOE_MANAGER.reset_loss() out_tp = tp_model(tp_data) MOE_MANAGER.reset_loss() out_ep = ep_model(ep_data) MOE_MANAGER.reset_loss() - assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)]) + assert torch.allclose(out_ep, out_tp[micro_batch_size * rank:micro_batch_size * (rank + 1)]) + assert torch.allclose(out_ep, out_local[micro_batch_size * rank:micro_batch_size * (rank + 1)]) + out_local.mean().backward() out_tp.mean().backward() out_ep.mean().backward() grad_handler.handle_gradient() - assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[2].dp_group) - assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[2].dp_group) + assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group) + assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group) + sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True) @pytest.mark.dist +@pytest.mark.parametrize("num_experts", [4, 8]) +@pytest.mark.parametrize("batch_size", [4, 8]) +@pytest.mark.parametrize("dim", [16, 256]) +@pytest.mark.parametrize("seed", [42, 78]) @rerun_if_address_is_in_use() -def test_moe_ep_tp(): - spawn(run_test, 2) +def test_moe_ep_tp(num_experts: int, + batch_size: int, + dim: int, + seed: int): + spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed) if __name__ == '__main__': - test_moe_ep_tp() + test_moe_ep_tp(num_experts=8, batch_size=8, dim=256, seed=42) diff --git a/tests/test_moe/test_moe_local.py b/tests/test_moe/test_moe_local.py deleted file mode 100644 index 1211a0d2d7f1..000000000000 --- a/tests/test_moe/test_moe_local.py +++ /dev/null @@ -1,65 +0,0 @@ -import pytest -import torch -import torch.distributed as dist - -import colossalai -from colossalai.moe import SparseMLP -from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import sync_moe_model_param -from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device -from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep - -BATCH_SIZE = 4 -DIM = 16 - - -def run_test(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(42, parallel=None) - local_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) - MOE_MANAGER.__init__() - MOE_MANAGER.setup(42, parallel="EP") # MOE initialization - ep_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) - ep_model = ep_model.to(get_current_device()) - local_model = local_model.to(get_current_device()) - - # sync ep param - sync_moe_model_param(ep_model) - dist_dict = MOE_MANAGER.parallel_info_dict - assert_equal_in_group(ep_model.experts.wi.data, dist_dict[2].dp_group) - assert_equal_in_group(ep_model.experts.wo.data, dist_dict[2].dp_group) - grad_handler = MoeGradientHandler(ep_model) - # sync tp param - sync_local_from_ep(local_model, ep_model) - - rank = dist.get_rank() - torch.cuda.manual_seed(78) - tp_data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) - ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)] - - out_tp = local_model(tp_data) - MOE_MANAGER.reset_loss() - out_ep = ep_model(ep_data) - MOE_MANAGER.reset_loss() - assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)]) - - out_tp.mean().backward() - out_ep.mean().backward() - grad_handler.handle_gradient() - - assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[2].dp_group) - assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[2].dp_group) - - sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2]) -@rerun_if_address_is_in_use() -def test_moe_local(world_size): - spawn(run_test, world_size) - - -if __name__ == '__main__': - test_moe_local() From aa57ff9a56668bf77b8bb651a237372d28e3e55a Mon Sep 17 00:00:00 2001 From: CWHer Date: Wed, 11 Oct 2023 17:32:25 +0800 Subject: [PATCH 02/10] to: add TPOverlap fn --- colossalai/moe/_operation.py | 95 +++++++++++++++++++++++++++++------- colossalai/moe/layers.py | 79 ++++++------------------------ 2 files changed, 92 insertions(+), 82 deletions(-) diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index a932b96597b6..3ce203ce27e6 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -1,7 +1,8 @@ -from typing import Any, Optional, Tuple +from typing import Any, Callable, Optional, Tuple import torch import torch.distributed as dist +import torch.nn as nn from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup @@ -9,8 +10,6 @@ from colossalai.moe.manager import MOE_MANAGER MOE_KERNEL = None -WORLD_HANDLE_ALLGATHER = None -WORLD_HANDLE_REDUCESCATTER = None def load_moe(): @@ -20,6 +19,64 @@ def load_moe(): MOE_KERNEL = MOEBuilder().load() +class TPOverlap(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + experts: nn.Module, + dispatch_data: Tensor, + group: ProcessGroup, + ) -> Tensor: + + NUM_CHUNK = 1 + NUM_STAGES = 4 + ctx.save_for_backward(experts, dispatch_data) + + assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ + "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" + chunk_size = dispatch_data.shape[0] // NUM_CHUNK + chunk_data = torch.split(dispatch_data, chunk_size, dim=0) + output = torch.empty_like(dispatch_data) + + def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]: + return (slice(idx * chunk_size, (idx + 1) * chunk_size), ) + + expert_in, in_handle, input_indices = None, None, None + partial_expert_out, data_indices = None, None + expert_out, out_handle, output_indices = None, None, None + + for i in range(NUM_CHUNK + NUM_STAGES - 1): + if out_handle is not None: + out_handle.wait() + output[output_indices] = expert_out + expert_out, out_handle, output_indices = None, None, None + + # reduce scatter last output + if partial_expert_out is not None: + output_indices = data_indices + expert_out, out_handle = ReduceScatter.apply(partial_expert_out, group, True) + partial_expert_out = None + + # compute + if in_handle is not None: + in_handle.wait() + data_indices = input_indices + partial_expert_out = experts(expert_in, input_indices) + expert_in, in_handle, input_indices = None, None, None + + # all gather next input + if 0 <= i < NUM_CHUNK: + input_indices = get_chunk_slice(i, chunk_size) + expert_in, in_handle = AllGather.apply(chunk_data[i].contiguous(), group, True) + + return output + + @staticmethod + def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None, None]: + raise NotImplementedError() + + class AllGather(torch.autograd.Function): @staticmethod @@ -28,14 +85,20 @@ def forward( inputs: Tensor, group: Optional[ProcessGroup] = None, overlap: bool = False, - ) -> Tensor: + ) -> Tuple[Tensor, Optional[Callable]]: + """ + Returns: + outputs: Tensor + handle: Optional[Callable], if overlap is True + """ + if ctx is not None: ctx.comm_grp = group ctx.overlap = overlap comm_size = dist.get_world_size(group) if comm_size == 1: - return inputs.unsqueeze(0) + return inputs.unsqueeze(0), None buffer_shape = (comm_size,) + inputs.shape outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device) @@ -51,11 +114,7 @@ def forward( return outputs, handle @staticmethod - def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]: - global WORLD_HANDLE_REDUCESCATTER - if WORLD_HANDLE_REDUCESCATTER is not None: - WORLD_HANDLE_REDUCESCATTER.wait() - WORLD_HANDLE_REDUCESCATTER = None + def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: return ( ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0], None, @@ -71,14 +130,20 @@ def forward( inputs: Tensor, group: Optional[ProcessGroup] = None, overlap: bool = False, - ) -> Tensor: + ) -> Tuple[Tensor, Optional[Callable]]: + """ + Returns: + outputs: Tensor + handle: Optional[Callable], if overlap is True + """ + if ctx is not None: ctx.comm_grp = group ctx.overlap = overlap comm_size = dist.get_world_size(group) if comm_size == 1: - return inputs.squeeze(0) + return inputs.squeeze(0), None if not inputs.is_contiguous(): inputs = inputs.contiguous() @@ -97,11 +162,7 @@ def forward( return outputs, handle @staticmethod - def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]: - global WORLD_HANDLE_ALLGATHER - if WORLD_HANDLE_ALLGATHER is not None: - WORLD_HANDLE_ALLGATHER.wait() - WORLD_HANDLE_ALLGATHER = None + def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: return ( AllGather.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0], None, diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index 3f82a0fa23fd..ff06021e3826 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch.nn.functional as F -from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter +from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter, TPOverlap from colossalai.moe.experts import BaseMLPExperts, get_expert_class from colossalai.moe.load_balance import LoadBalancer from colossalai.moe.manager import MOE_MANAGER @@ -206,19 +206,20 @@ def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: expert_output = AllToAll.apply(expert_output, self.ep_group) return expert_output - def _tp_process(self, dispatch_data: torch.Tensor, use_overlap: bool = False) -> torch.Tensor: + def _tp_process(self, + dispatch_data: torch.Tensor, + overlap: bool = True + ) -> torch.Tensor: """ - TP with overlap. - - origin: + without overlap: | C | | A | | R | - overlap: + with overlap: | C1 || C2 || C3 || C4 | | A1 || A2 | | R1 | A3 || R2 | A4 || R3 | | R4 | - C is computation, A is all gather, R is reduce scatter. + where C is computation, A is all gather, R is reduce scatter. Args: dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size) @@ -226,65 +227,13 @@ def _tp_process(self, dispatch_data: torch.Tensor, use_overlap: bool = False) -> Returns: torch.Tensor: (num_experts, capacity, hidden_size) """ - if use_overlap == False: - expert_in, _ = AllGather.apply(dispatch_data, self.ep_group) - expert_out = self.experts(expert_in) - expert_out, _ = ReduceScatter.apply(expert_out, self.ep_group) + if not overlap: + expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0] + partial_expert_out = self.experts(expert_in) + expert_out = ReduceScatter.apply(partial_expert_out, self.ep_group, False)[0] return expert_out - - # TODO: there is accuracy problem in overlap - chunk_num = 4 - chunk_size = dispatch_data.shape[0] // chunk_num - out = torch.empty_like(dispatch_data) - in_data = None - in_handle = None - out_data = None - out_handle = None - - # backward compatibility for async op - torch.cuda.synchronize() - - def get_chunk_slice(idx: int, gap: int) -> Tuple[slice]: - return (slice(idx * gap, (idx + 1) * gap),) - - for i in range(chunk_num): - cur_chunk_slice = get_chunk_slice(i, chunk_size) - - # if first, all gather - if i == 0: - d = dispatch_data[cur_chunk_slice].contiguous() - expert_in, _ = AllGather.apply(d, self.ep_group) - else: - expert_in = in_data - - # async communication while compute - if i != 0: - # reduce scatter last out - out_data, out_handle = ReduceScatter.apply(out_data, self.ep_group, True) - if i != chunk_num - 1: - # all gather next in - next_d = dispatch_data[get_chunk_slice(i + 1, chunk_size)].contiguous() - in_data, in_handle = AllGather.apply(next_d, self.ep_group, True) - - # compute - expert_out = self.experts(expert_in, cur_chunk_slice) - - # sync handle - if i != 0: - out_handle.wait() - out[get_chunk_slice(i - 1, chunk_size)] = out_data - if i != chunk_num - 1: - in_handle.wait() - out_data = expert_out - - # store out for last loop - if i == chunk_num - 1: - out_data, _ = ReduceScatter.apply(out_data, self.ep_group) - out[cur_chunk_slice] = out_data - - # sync for async op - torch.cuda.synchronize() - return out + else: + return TPOverlap.apply(self.experts, dispatch_data, self.ep_group) def apply_load_balance(model: nn.Module, optim: Any) -> None: From 4861ddc1e75484d8eaa38fe7aab3a8cb11bc185b Mon Sep 17 00:00:00 2001 From: CWHer Date: Wed, 11 Oct 2023 18:08:52 +0800 Subject: [PATCH 03/10] fix: fix tp overlap --- colossalai/moe/_operation.py | 65 ++---------------------------------- colossalai/moe/layers.py | 44 ++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 64 deletions(-) diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index 3ce203ce27e6..e3db66637756 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -2,7 +2,6 @@ import torch import torch.distributed as dist -import torch.nn as nn from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup @@ -19,64 +18,6 @@ def load_moe(): MOE_KERNEL = MOEBuilder().load() -class TPOverlap(torch.autograd.Function): - - @staticmethod - def forward( - ctx: Any, - experts: nn.Module, - dispatch_data: Tensor, - group: ProcessGroup, - ) -> Tensor: - - NUM_CHUNK = 1 - NUM_STAGES = 4 - ctx.save_for_backward(experts, dispatch_data) - - assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ - "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" - chunk_size = dispatch_data.shape[0] // NUM_CHUNK - chunk_data = torch.split(dispatch_data, chunk_size, dim=0) - output = torch.empty_like(dispatch_data) - - def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]: - return (slice(idx * chunk_size, (idx + 1) * chunk_size), ) - - expert_in, in_handle, input_indices = None, None, None - partial_expert_out, data_indices = None, None - expert_out, out_handle, output_indices = None, None, None - - for i in range(NUM_CHUNK + NUM_STAGES - 1): - if out_handle is not None: - out_handle.wait() - output[output_indices] = expert_out - expert_out, out_handle, output_indices = None, None, None - - # reduce scatter last output - if partial_expert_out is not None: - output_indices = data_indices - expert_out, out_handle = ReduceScatter.apply(partial_expert_out, group, True) - partial_expert_out = None - - # compute - if in_handle is not None: - in_handle.wait() - data_indices = input_indices - partial_expert_out = experts(expert_in, input_indices) - expert_in, in_handle, input_indices = None, None, None - - # all gather next input - if 0 <= i < NUM_CHUNK: - input_indices = get_chunk_slice(i, chunk_size) - expert_in, in_handle = AllGather.apply(chunk_data[i].contiguous(), group, True) - - return output - - @staticmethod - def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None, None]: - raise NotImplementedError() - - class AllGather(torch.autograd.Function): @staticmethod @@ -91,7 +32,7 @@ def forward( outputs: Tensor handle: Optional[Callable], if overlap is True """ - + # print(in) if ctx is not None: ctx.comm_grp = group ctx.overlap = overlap @@ -116,7 +57,7 @@ def forward( @staticmethod def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: return ( - ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0], + ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0], None, None, ) @@ -164,7 +105,7 @@ def forward( @staticmethod def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: return ( - AllGather.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0], + AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0], None, None, ) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index ff06021e3826..19ed6d626ac7 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch.nn.functional as F -from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter, TPOverlap +from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter from colossalai.moe.experts import BaseMLPExperts, get_expert_class from colossalai.moe.load_balance import LoadBalancer from colossalai.moe.manager import MOE_MANAGER @@ -233,7 +233,47 @@ def _tp_process(self, expert_out = ReduceScatter.apply(partial_expert_out, self.ep_group, False)[0] return expert_out else: - return TPOverlap.apply(self.experts, dispatch_data, self.ep_group) + NUM_CHUNK = 4 + NUM_STAGES = 4 + + assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ + "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" + chunk_size = dispatch_data.shape[0] // NUM_CHUNK + chunk_data = torch.split(dispatch_data, chunk_size, dim=0) + output = torch.empty_like(dispatch_data) + + def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]: + return (slice(idx * chunk_size, (idx + 1) * chunk_size), ) + + expert_in, in_handle, input_indices = None, None, None + partial_expert_out, data_indices = None, None + expert_out, out_handle, output_indices = None, None, None + + for i in range(NUM_CHUNK + NUM_STAGES - 1): + if out_handle is not None: + out_handle.wait() + output[output_indices] = expert_out + expert_out, out_handle, output_indices = None, None, None + + # reduce scatter last output + if partial_expert_out is not None: + output_indices = data_indices + expert_out, out_handle = ReduceScatter.apply(partial_expert_out, self.ep_group, True) + partial_expert_out = None + + # compute + if in_handle is not None: + in_handle.wait() + data_indices = input_indices + partial_expert_out = self.experts(expert_in, input_indices) + expert_in, in_handle, input_indices = None, None, None + + # all gather next input + if 0 <= i < NUM_CHUNK: + input_indices = get_chunk_slice(i, chunk_size) + expert_in, in_handle = AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True) + + return output def apply_load_balance(model: nn.Module, optim: Any) -> None: From 3e9c6ee7a56b02f96638d465a914544507ae2451 Mon Sep 17 00:00:00 2001 From: CWHer Date: Wed, 11 Oct 2023 18:30:31 +0800 Subject: [PATCH 04/10] fix: remove useless variables --- colossalai/moe/_operation.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index e3db66637756..e16ebbe620e7 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -32,10 +32,10 @@ def forward( outputs: Tensor handle: Optional[Callable], if overlap is True """ - # print(in) + assert ctx is not None or not overlap + if ctx is not None: ctx.comm_grp = group - ctx.overlap = overlap comm_size = dist.get_world_size(group) if comm_size == 1: @@ -49,9 +49,6 @@ def forward( return outputs, None else: handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True) - if ctx is None and overlap: - global WORLD_HANDLE_ALLGATHER - WORLD_HANDLE_ALLGATHER = handle return outputs, handle @staticmethod @@ -77,10 +74,10 @@ def forward( outputs: Tensor handle: Optional[Callable], if overlap is True """ + assert ctx is not None or not overlap if ctx is not None: ctx.comm_grp = group - ctx.overlap = overlap comm_size = dist.get_world_size(group) if comm_size == 1: @@ -97,9 +94,6 @@ def forward( return outputs, None else: handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True) - if ctx is None and overlap: - global WORLD_HANDLE_REDUCESCATTER - WORLD_HANDLE_REDUCESCATTER = handle return outputs, handle @staticmethod From cb16779fcac1a0d1a8a45a1d1fd738606e175560 Mon Sep 17 00:00:00 2001 From: CWHer Date: Fri, 13 Oct 2023 15:37:05 +0800 Subject: [PATCH 05/10] feat: add async all to all --- colossalai/moe/_operation.py | 40 ++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index e16ebbe620e7..b180c6b0cbe5 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Optional, Tuple +from typing import Any, Optional, Tuple import torch import torch.distributed as dist @@ -26,11 +26,11 @@ def forward( inputs: Tensor, group: Optional[ProcessGroup] = None, overlap: bool = False, - ) -> Tuple[Tensor, Optional[Callable]]: + ) -> Tuple[Tensor, Optional[Work]]: """ Returns: outputs: Tensor - handle: Optional[Callable], if overlap is True + handle: Optional[Work], if overlap is True """ assert ctx is not None or not overlap @@ -68,11 +68,11 @@ def forward( inputs: Tensor, group: Optional[ProcessGroup] = None, overlap: bool = False, - ) -> Tuple[Tensor, Optional[Callable]]: + ) -> Tuple[Tensor, Optional[Work]]: """ Returns: outputs: Tensor - handle: Optional[Callable], if overlap is True + handle: Optional[Work], if overlap is True """ assert ctx is not None or not overlap @@ -111,20 +111,38 @@ class AllToAll(torch.autograd.Function): """ @staticmethod - def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + def forward( + ctx: Any, + inputs: Tensor, + group: Optional[ProcessGroup] = None, + overlap: bool = False, + ) -> Tuple[Tensor, Optional[Work]]: + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ if ctx is not None: ctx.comm_grp = group if not inputs.is_contiguous(): inputs = inputs.contiguous() if dist.get_world_size(group) == 1: - return inputs + return inputs, None output = torch.empty_like(inputs) - dist.all_to_all_single(output, inputs, group=group) - return output + if not overlap: + dist.all_to_all_single(output, inputs, group=group) + return output, None + else: + handle = dist.all_to_all_single(output, inputs, group=group, async_op=True) + return output, handle @staticmethod - def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: - return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None + def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: + return ( + AllToAll.forward(None, grad_outputs[0], ctx.comm_grp)[0], + None, + None, + ) class MoeDispatch(torch.autograd.Function): From 0abb7f2ae3a4a2d7e7bc723bf1499be290301e11 Mon Sep 17 00:00:00 2001 From: CWHer Date: Mon, 16 Oct 2023 23:36:30 +0800 Subject: [PATCH 06/10] feat: add overlap ep --- colossalai/moe/layers.py | 61 ++++++++++++++++++++++++++++++++++------ 1 file changed, 52 insertions(+), 9 deletions(-) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index 19ed6d626ac7..eda7ee59c270 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -188,7 +188,10 @@ def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: expert_out = self.experts(expert_in) return expert_out - def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: + def _ep_process(self, + dispatch_data: torch.Tensor, + overlap: bool = True + ) -> torch.Tensor: """ Expert Parallel @@ -198,13 +201,53 @@ def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: (num_experts, capacity, hidden_size) """ - expert_input = AllToAll.apply(dispatch_data, self.ep_group) - input_shape = expert_input.shape - expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) - expert_output = self.experts(expert_input) - expert_output = expert_output.reshape(input_shape) - expert_output = AllToAll.apply(expert_output, self.ep_group) - return expert_output + if not overlap: + expert_input = AllToAll.apply(dispatch_data, self.ep_group) + expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) + expert_output = self.experts(expert_input) + expert_output = AllToAll.apply(expert_output, self.ep_group) + return expert_output + + else: + NUM_CHUNK = 4 + NUM_STAGES = 4 + + assert dispatch_data.shape[1] % NUM_CHUNK == 0, \ + "arbitrary chunk num is not supported yet" + chunk_size = dispatch_data.shape[1] // NUM_CHUNK + + input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size) + dispatch_data = dispatch_data.reshape(*input_shape) + chunk_data = torch.split(dispatch_data, chunk_size, dim=2) + output = torch.empty_like(dispatch_data) + + expert_in, in_handle = None, None + partial_expert_out = None + expert_out, out_handle, offset = None, None, 0 + + for i in range(NUM_CHUNK + NUM_STAGES - 1): + if out_handle is not None: + out_handle.wait() + output[:, :, offset:offset + chunk_size, :] = expert_out + offset += chunk_size + expert_out, out_handle = None, None + + # reduce scatter last output + if partial_expert_out is not None: + expert_out, out_handle = AllToAll.apply(partial_expert_out, self.ep_group, True) + partial_expert_out = None + + # compute + if in_handle is not None: + in_handle.wait() + partial_expert_out = self.experts(expert_in) + expert_in, in_handle = None, None + + # all gather next input + if 0 <= i < NUM_CHUNK: + expert_in, in_handle = AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True) + + return output def _tp_process(self, dispatch_data: torch.Tensor, @@ -233,7 +276,7 @@ def _tp_process(self, expert_out = ReduceScatter.apply(partial_expert_out, self.ep_group, False)[0] return expert_out else: - NUM_CHUNK = 4 + NUM_CHUNK = 2 NUM_STAGES = 4 assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ From 9cb5226a51f3add6f897835127fe392df7f9c518 Mon Sep 17 00:00:00 2001 From: CWHer Date: Tue, 17 Oct 2023 10:42:13 +0800 Subject: [PATCH 07/10] fix: fix import error --- colossalai/moe/_operation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index b180c6b0cbe5..9b6630a32d46 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -26,7 +26,7 @@ def forward( inputs: Tensor, group: Optional[ProcessGroup] = None, overlap: bool = False, - ) -> Tuple[Tensor, Optional[Work]]: + ) -> Tuple[Tensor, Any]: """ Returns: outputs: Tensor @@ -68,7 +68,7 @@ def forward( inputs: Tensor, group: Optional[ProcessGroup] = None, overlap: bool = False, - ) -> Tuple[Tensor, Optional[Work]]: + ) -> Tuple[Tensor, Any]: """ Returns: outputs: Tensor @@ -116,7 +116,7 @@ def forward( inputs: Tensor, group: Optional[ProcessGroup] = None, overlap: bool = False, - ) -> Tuple[Tensor, Optional[Work]]: + ) -> Tuple[Tensor, Any]: """ Returns: outputs: Tensor From 8c83c1b87e3deb9d88f8b168ebfa43e7152d53a6 Mon Sep 17 00:00:00 2001 From: CWHer Date: Tue, 17 Oct 2023 11:06:27 +0800 Subject: [PATCH 08/10] fix: fix ep/tp tests --- colossalai/moe/layers.py | 6 +++--- tests/test_moe/test_moe_ep_tp.py | 16 +++++++++------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index eda7ee59c270..8f118b252213 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -202,14 +202,14 @@ def _ep_process(self, torch.Tensor: (num_experts, capacity, hidden_size) """ if not overlap: - expert_input = AllToAll.apply(dispatch_data, self.ep_group) + expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0] expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) expert_output = self.experts(expert_input) - expert_output = AllToAll.apply(expert_output, self.ep_group) + expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0] return expert_output else: - NUM_CHUNK = 4 + NUM_CHUNK = 2 NUM_STAGES = 4 assert dispatch_data.shape[1] % NUM_CHUNK == 0, \ diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 10ccd8ebd2c2..51fd135483b6 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -21,20 +21,22 @@ def run_test(rank: int, assert batch_size % world_size == 0 colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(seed) # MOE initialization + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed, parallel=None) + local_model = SparseMLP(num_experts=num_experts, + hidden_size=dim, + intermediate_size=dim * 2) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed, parallel="EP") ep_model = SparseMLP(num_experts=num_experts, - expert_parallel="EP", hidden_size=dim, intermediate_size=dim * 2) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed, parallel="TP") tp_model = SparseMLP(num_experts=num_experts, - expert_parallel="TP", hidden_size=dim, intermediate_size=dim * 2) - local_model = SparseMLP(num_experts=num_experts, - expert_parallel=None, - hidden_size=dim, - intermediate_size=dim * 2) ep_model = ep_model.to(get_current_device()) tp_model = tp_model.to(get_current_device()) local_model = local_model.to(get_current_device()) From faf7003011017f5126f666047b552c94ce4309b7 Mon Sep 17 00:00:00 2001 From: CWHer Date: Tue, 17 Oct 2023 15:53:56 +0800 Subject: [PATCH 09/10] perf: optimize overlap --- colossalai/moe/_operation.py | 1 + colossalai/moe/layers.py | 113 ++++++++++++++++++++++------------- 2 files changed, 72 insertions(+), 42 deletions(-) diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index 9b6630a32d46..9b4988b345d9 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -98,6 +98,7 @@ def forward( @staticmethod def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: + # TODO: support async backward return ( AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0], None, diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index 8f118b252213..cd91998ae528 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -1,3 +1,4 @@ +import dataclasses import math from typing import Any, Optional, Tuple @@ -209,43 +210,57 @@ def _ep_process(self, return expert_output else: + @dataclasses.dataclass + class Capsule(): + data: torch.Tensor + handle: Any = None + NUM_CHUNK = 2 NUM_STAGES = 4 assert dispatch_data.shape[1] % NUM_CHUNK == 0, \ "arbitrary chunk num is not supported yet" chunk_size = dispatch_data.shape[1] // NUM_CHUNK - input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size) dispatch_data = dispatch_data.reshape(*input_shape) chunk_data = torch.split(dispatch_data, chunk_size, dim=2) output = torch.empty_like(dispatch_data) - expert_in, in_handle = None, None - partial_expert_out = None - expert_out, out_handle, offset = None, None, 0 + offset = 0 + _expert_in, expert_in, _expert_out, expert_out = None, None, None, None for i in range(NUM_CHUNK + NUM_STAGES - 1): - if out_handle is not None: - out_handle.wait() - output[:, :, offset:offset + chunk_size, :] = expert_out + if expert_out is not None: + expert_out.handle.wait() + output[:, :, offset:offset + chunk_size, :] = expert_out.data offset += chunk_size - expert_out, out_handle = None, None + expert_out = None - # reduce scatter last output - if partial_expert_out is not None: - expert_out, out_handle = AllToAll.apply(partial_expert_out, self.ep_group, True) - partial_expert_out = None + # all2all last output + if _expert_out is not None: + expert_out = Capsule( + *AllToAll.apply(_expert_out.data, self.ep_group, True), + ) + _expert_out = None - # compute - if in_handle is not None: - in_handle.wait() - partial_expert_out = self.experts(expert_in) - expert_in, in_handle = None, None - - # all gather next input + # all2all next input if 0 <= i < NUM_CHUNK: - expert_in, in_handle = AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True) + _expert_in = Capsule( + *AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True) + ) + + # compute + if expert_in is not None: + expert_in.handle.wait() + _expert_out = Capsule( + data=self.experts(expert_in.data), + handle=None + ) + expert_in = None + + if _expert_in is not None: + expert_in = _expert_in + _expert_in = None return output @@ -272,10 +287,16 @@ def _tp_process(self, """ if not overlap: expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0] - partial_expert_out = self.experts(expert_in) - expert_out = ReduceScatter.apply(partial_expert_out, self.ep_group, False)[0] + expert_out = self.experts(expert_in) + expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0] return expert_out else: + @dataclasses.dataclass + class Capsule(): + data: torch.Tensor + handle: Any + indices: Tuple + NUM_CHUNK = 2 NUM_STAGES = 4 @@ -288,33 +309,41 @@ def _tp_process(self, def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]: return (slice(idx * chunk_size, (idx + 1) * chunk_size), ) - expert_in, in_handle, input_indices = None, None, None - partial_expert_out, data_indices = None, None - expert_out, out_handle, output_indices = None, None, None + _expert_in, expert_in, _expert_out, expert_out = None, None, None, None for i in range(NUM_CHUNK + NUM_STAGES - 1): - if out_handle is not None: - out_handle.wait() - output[output_indices] = expert_out - expert_out, out_handle, output_indices = None, None, None + if expert_out is not None: + expert_out.handle.wait() + output[expert_out.indices] = expert_out.data + expert_out = None # reduce scatter last output - if partial_expert_out is not None: - output_indices = data_indices - expert_out, out_handle = ReduceScatter.apply(partial_expert_out, self.ep_group, True) - partial_expert_out = None - - # compute - if in_handle is not None: - in_handle.wait() - data_indices = input_indices - partial_expert_out = self.experts(expert_in, input_indices) - expert_in, in_handle, input_indices = None, None, None + if _expert_out is not None: + expert_out = Capsule( + *ReduceScatter.apply(_expert_out.data, self.ep_group, True), + indices=_expert_out.indices + ) + _expert_out = None # all gather next input if 0 <= i < NUM_CHUNK: - input_indices = get_chunk_slice(i, chunk_size) - expert_in, in_handle = AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True) + _expert_in = Capsule( + *AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True), + indices=get_chunk_slice(i, chunk_size) + ) + + # compute + if expert_in is not None: + expert_in.handle.wait() + _expert_out = Capsule( + self.experts(expert_in.data, expert_in.indices), + handle=None, indices=expert_in.indices + ) + expert_in = None + + if _expert_in is not None: + expert_in = _expert_in + _expert_in = None return output From 117bc9e4cea59cb9c38d5326416792abea10abb1 Mon Sep 17 00:00:00 2001 From: CWHer Date: Tue, 17 Oct 2023 16:31:11 +0800 Subject: [PATCH 10/10] fix: add world_size check --- colossalai/moe/layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index cd91998ae528..9846cd432b53 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -202,7 +202,7 @@ def _ep_process(self, Returns: torch.Tensor: (num_experts, capacity, hidden_size) """ - if not overlap: + if not overlap or dist.get_world_size(self.ep_group) == 1: expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0] expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) expert_output = self.experts(expert_input) @@ -285,7 +285,7 @@ def _tp_process(self, Returns: torch.Tensor: (num_experts, capacity, hidden_size) """ - if not overlap: + if not overlap or dist.get_world_size(self.ep_group) == 1: expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0] expert_out = self.experts(expert_in) expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0]