diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index a932b96597b6..9b4988b345d9 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -9,8 +9,6 @@ from colossalai.moe.manager import MOE_MANAGER
 
 
 MOE_KERNEL = None
-WORLD_HANDLE_ALLGATHER = None
-WORLD_HANDLE_REDUCESCATTER = None
 
 
 def load_moe():
@@ -28,14 +26,20 @@ def forward(
         inputs: Tensor,
         group: Optional[ProcessGroup] = None,
         overlap: bool = False,
-    ) -> Tensor:
+    ) -> Tuple[Tensor, Any]:
+        """
+        Returns:
+            outputs: Tensor
+            handle: Optional[Work], if overlap is True
+        """
+        assert ctx is not None or not overlap
+
         if ctx is not None:
             ctx.comm_grp = group
-            ctx.overlap = overlap
 
         comm_size = dist.get_world_size(group)
         if comm_size == 1:
-            return inputs.unsqueeze(0)
+            return inputs.unsqueeze(0), None
 
         buffer_shape = (comm_size,) + inputs.shape
         outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
@@ -45,19 +49,12 @@ def forward(
             return outputs, None
         else:
             handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
-            if ctx is None and overlap:
-                global WORLD_HANDLE_ALLGATHER
-                WORLD_HANDLE_ALLGATHER = handle
             return outputs, handle
 
     @staticmethod
-    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
-        global WORLD_HANDLE_REDUCESCATTER
-        if WORLD_HANDLE_REDUCESCATTER is not None:
-            WORLD_HANDLE_REDUCESCATTER.wait()
-            WORLD_HANDLE_REDUCESCATTER = None
+    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
         return (
-            ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0],
+            ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
             None,
             None,
         )
@@ -71,14 +68,20 @@ def forward(
         inputs: Tensor,
         group: Optional[ProcessGroup] = None,
         overlap: bool = False,
-    ) -> Tensor:
+    ) -> Tuple[Tensor, Any]:
+        """
+        Returns:
+            outputs: Tensor
+            handle: Optional[Work], if overlap is True
+        """
+        assert ctx is not None or not overlap
+
         if ctx is not None:
             ctx.comm_grp = group
-            ctx.overlap = overlap
 
         comm_size = dist.get_world_size(group)
         if comm_size == 1:
-            return inputs.squeeze(0)
+            return inputs.squeeze(0), None
 
         if not inputs.is_contiguous():
             inputs = inputs.contiguous()
@@ -91,19 +94,13 @@ def forward(
             return outputs, None
         else:
             handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
-            if ctx is None and overlap:
-                global WORLD_HANDLE_REDUCESCATTER
-                WORLD_HANDLE_REDUCESCATTER = handle
             return outputs, handle
 
     @staticmethod
-    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
-        global WORLD_HANDLE_ALLGATHER
-        if WORLD_HANDLE_ALLGATHER is not None:
-            WORLD_HANDLE_ALLGATHER.wait()
-            WORLD_HANDLE_ALLGATHER = None
+    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
+        # TODO: support async backward
         return (
-            AllGather.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0],
+            AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
             None,
             None,
         )
@@ -115,20 +112,38 @@ class AllToAll(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
+    def forward(
+        ctx: Any,
+        inputs: Tensor,
+        group: Optional[ProcessGroup] = None,
+        overlap: bool = False,
+    ) -> Tuple[Tensor, Any]:
+        """
+        Returns:
+            outputs: Tensor
+            handle: Optional[Work], if overlap is True
+        """
         if ctx is not None:
             ctx.comm_grp = group
         if not inputs.is_contiguous():
             inputs = inputs.contiguous()
         if dist.get_world_size(group) == 1:
-            return inputs
+            return inputs, None
         output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs, group=group)
-        return output
+        if not overlap:
+            dist.all_to_all_single(output, inputs, group=group)
+            return output, None
+        else:
+            handle = dist.all_to_all_single(output, inputs, group=group, async_op=True)
+            return output, handle
 
     @staticmethod
-    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
+    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
+        return (
+            AllToAll.forward(None, grad_outputs[0], ctx.comm_grp)[0],
+            None,
+            None,
+        )
 
 
 class MoeDispatch(torch.autograd.Function):
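Note: after this change, AllGather, ReduceScatter and AllToAll all return a (tensor, handle) pair; when overlap=True the collective is launched with async_op=True and the caller must wait on the handle before reading the buffer. A minimal sketch of the intended calling pattern follows (the wrapper name and the x/group arguments are illustrative, not part of the patch):

import torch
import torch.distributed as dist

from colossalai.moe._operation import AllToAll


def all_to_all_overlapped(x: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # Launch the async all-to-all; with overlap=True the op returns immediately.
    out, handle = AllToAll.apply(x, group, True)
    # ... independent computation could run here to hide the communication latency ...
    if handle is not None:    # handle is None when the group has a single rank
        handle.wait()         # `out` is only valid after wait()
    return out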
diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py
index 3f82a0fa23fd..9846cd432b53 100644
--- a/colossalai/moe/layers.py
+++ b/colossalai/moe/layers.py
@@ -1,3 +1,4 @@
+import dataclasses
 import math
 from typing import Any, Optional, Tuple
 
@@ -188,7 +189,10 @@ def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
         expert_out = self.experts(expert_in)
         return expert_out
 
-    def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
+    def _ep_process(self,
+                    dispatch_data: torch.Tensor,
+                    overlap: bool = True
+                    ) -> torch.Tensor:
         """
         Expert Parallel
 
@@ -198,27 +202,82 @@ def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: (num_experts, capacity, hidden_size)
         """
-        expert_input = AllToAll.apply(dispatch_data, self.ep_group)
-        input_shape = expert_input.shape
-        expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
-        expert_output = self.experts(expert_input)
-        expert_output = expert_output.reshape(input_shape)
-        expert_output = AllToAll.apply(expert_output, self.ep_group)
-        return expert_output
-
-    def _tp_process(self, dispatch_data: torch.Tensor, use_overlap: bool = False) -> torch.Tensor:
-        """
-        TP with overlap.
+        if not overlap or dist.get_world_size(self.ep_group) == 1:
+            expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
+            expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
+            expert_output = self.experts(expert_input)
+            expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
+            return expert_output
 
-        origin:
+        else:
+            @dataclasses.dataclass
+            class Capsule():
+                data: torch.Tensor
+                handle: Any = None
+
+            NUM_CHUNK = 2
+            NUM_STAGES = 4
+
+            assert dispatch_data.shape[1] % NUM_CHUNK == 0, \
+                "arbitrary chunk num is not supported yet"
+            chunk_size = dispatch_data.shape[1] // NUM_CHUNK
+            input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
+            dispatch_data = dispatch_data.reshape(*input_shape)
+            chunk_data = torch.split(dispatch_data, chunk_size, dim=2)
+            output = torch.empty_like(dispatch_data)
+
+            offset = 0
+            _expert_in, expert_in, _expert_out, expert_out = None, None, None, None
+
+            for i in range(NUM_CHUNK + NUM_STAGES - 1):
+                if expert_out is not None:
+                    expert_out.handle.wait()
+                    output[:, :, offset:offset + chunk_size, :] = expert_out.data
+                    offset += chunk_size
+                    expert_out = None
+
+                # all2all last output
+                if _expert_out is not None:
+                    expert_out = Capsule(
+                        *AllToAll.apply(_expert_out.data, self.ep_group, True),
+                    )
+                    _expert_out = None
+
+                # all2all next input
+                if 0 <= i < NUM_CHUNK:
+                    _expert_in = Capsule(
+                        *AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True)
+                    )
+
+                # compute
+                if expert_in is not None:
+                    expert_in.handle.wait()
+                    _expert_out = Capsule(
+                        data=self.experts(expert_in.data),
+                        handle=None
+                    )
+                    expert_in = None
+
+                if _expert_in is not None:
+                    expert_in = _expert_in
+                    _expert_in = None
+
+            return output
+
+    def _tp_process(self,
+                    dispatch_data: torch.Tensor,
+                    overlap: bool = True
+                    ) -> torch.Tensor:
+        """
+        without overlap:
         | C |
             | A |
                 | R |
-        overlap:
+        with overlap:
         | C1 || C2 || C3 || C4 |
             | A1 || A2 | | R1 | A3 || R2 | A4 || R3 | | R4 |
 
-        C is computation, A is all gather, R is reduce scatter.
+        where C is computation, A is all gather, R is reduce scatter.
 
         Args:
             dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
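Note: the staged loop added above runs NUM_CHUNK + NUM_STAGES - 1 iterations so that each chunk flows through dispatch all-to-all, expert compute, and combine all-to-all while the next chunk's communication is already in flight. The following framework-free sketch shows the same scheduling skeleton with synchronous stand-ins for the comm and compute steps (handles are therefore unused); every name in it is illustrative, not taken from the patch:

from dataclasses import dataclass
from typing import Any, Callable, List, Optional


@dataclass
class Capsule:
    data: Any
    handle: Any = None   # unused here because the stand-in comm is synchronous


NUM_CHUNK = 2
NUM_STAGES = 4


def pipeline(chunks: List[Any], comm: Callable, compute: Callable) -> List[Any]:
    out: List[Any] = []
    _in: Optional[Capsule] = None       # chunk whose dispatch comm was just launched
    cur_in: Optional[Capsule] = None    # chunk ready for expert compute
    _out: Optional[Capsule] = None      # chunk whose compute just finished
    cur_out: Optional[Capsule] = None   # chunk whose combine comm was just launched
    for i in range(NUM_CHUNK + NUM_STAGES - 1):
        if cur_out is not None:         # drain: combine comm finished last iteration
            out.append(cur_out.data)
            cur_out = None
        if _out is not None:            # launch combine comm for the computed chunk
            cur_out = Capsule(comm(_out.data))
            _out = None
        if i < NUM_CHUNK:               # launch dispatch comm for the next chunk
            _in = Capsule(comm(chunks[i]))
        if cur_in is not None:          # expert compute on the dispatched chunk
            _out = Capsule(compute(cur_in.data))
            cur_in = None
        if _in is not None:             # promote the freshly dispatched chunk
            cur_in, _in = _in, None
    return out


# With identity "communication" and a doubling "expert", both chunks come out in order.
assert pipeline([1, 2], comm=lambda x: x, compute=lambda x: x * 2) == [2, 4]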
@@ -226,65 +285,67 @@ def _tp_process(self, dispatch_data: torch.Tensor, use_overlap: bool = False) -> torch.Tensor:
         Returns:
             torch.Tensor: (num_experts, capacity, hidden_size)
         """
-        if use_overlap == False:
-            expert_in, _ = AllGather.apply(dispatch_data, self.ep_group)
+        if not overlap or dist.get_world_size(self.ep_group) == 1:
+            expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0]
             expert_out = self.experts(expert_in)
-            expert_out, _ = ReduceScatter.apply(expert_out, self.ep_group)
+            expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0]
             return expert_out
-
-        # TODO: there is accuracy problem in overlap
-        chunk_num = 4
-        chunk_size = dispatch_data.shape[0] // chunk_num
-        out = torch.empty_like(dispatch_data)
-        in_data = None
-        in_handle = None
-        out_data = None
-        out_handle = None
-
-        # backward compatibility for async op
-        torch.cuda.synchronize()
-
-        def get_chunk_slice(idx: int, gap: int) -> Tuple[slice]:
-            return (slice(idx * gap, (idx + 1) * gap),)
-
-        for i in range(chunk_num):
-            cur_chunk_slice = get_chunk_slice(i, chunk_size)
-
-            # if first, all gather
-            if i == 0:
-                d = dispatch_data[cur_chunk_slice].contiguous()
-                expert_in, _ = AllGather.apply(d, self.ep_group)
-            else:
-                expert_in = in_data
-
-            # async communication while compute
-            if i != 0:
-                # reduce scatter last out
-                out_data, out_handle = ReduceScatter.apply(out_data, self.ep_group, True)
-            if i != chunk_num - 1:
-                # all gather next in
-                next_d = dispatch_data[get_chunk_slice(i + 1, chunk_size)].contiguous()
-                in_data, in_handle = AllGather.apply(next_d, self.ep_group, True)
-
-            # compute
-            expert_out = self.experts(expert_in, cur_chunk_slice)
-
-            # sync handle
-            if i != 0:
-                out_handle.wait()
-                out[get_chunk_slice(i - 1, chunk_size)] = out_data
-            if i != chunk_num - 1:
-                in_handle.wait()
-            out_data = expert_out
-
-            # store out for last loop
-            if i == chunk_num - 1:
-                out_data, _ = ReduceScatter.apply(out_data, self.ep_group)
-                out[cur_chunk_slice] = out_data
-
-        # sync for async op
-        torch.cuda.synchronize()
-        return out
+        else:
+            @dataclasses.dataclass
+            class Capsule():
+                data: torch.Tensor
+                handle: Any
+                indices: Tuple
+
+            NUM_CHUNK = 2
+            NUM_STAGES = 4
+
+            assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
+                "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
+            chunk_size = dispatch_data.shape[0] // NUM_CHUNK
+            chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
+            output = torch.empty_like(dispatch_data)
+
+            def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
+                return (slice(idx * chunk_size, (idx + 1) * chunk_size), )
+
+            _expert_in, expert_in, _expert_out, expert_out = None, None, None, None
+
+            for i in range(NUM_CHUNK + NUM_STAGES - 1):
+                if expert_out is not None:
+                    expert_out.handle.wait()
+                    output[expert_out.indices] = expert_out.data
+                    expert_out = None
+
+                # reduce scatter last output
+                if _expert_out is not None:
+                    expert_out = Capsule(
+                        *ReduceScatter.apply(_expert_out.data, self.ep_group, True),
+                        indices=_expert_out.indices
+                    )
+                    _expert_out = None
+
+                # all gather next input
+                if 0 <= i < NUM_CHUNK:
+                    _expert_in = Capsule(
+                        *AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True),
+                        indices=get_chunk_slice(i, chunk_size)
+                    )
+
+                # compute
+                if expert_in is not None:
+                    expert_in.handle.wait()
+                    _expert_out = Capsule(
+                        self.experts(expert_in.data, expert_in.indices),
+                        handle=None, indices=expert_in.indices
+                    )
+                    expert_in = None
+
+                if _expert_in is not None:
+                    expert_in = _expert_in
+                    _expert_in = None
+
+            return output
 
 
 def apply_load_balance(model: nn.Module, optim: Any) -> None:
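Note: unlike _ep_process, _tp_process chunks along the expert dimension (dim 0) and carries a slice tuple with each Capsule so results can be written straight into the preallocated output and so self.experts can be told which expert slice a chunk belongs to. A small standalone illustration of that slice bookkeeping (tensor shapes and the doubling stand-in are made up, not from the patch):

import torch
from typing import Tuple


def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
    # Same shape of helper as in _tp_process: a 1-tuple of slices indexes dim 0 only.
    return (slice(idx * chunk_size, (idx + 1) * chunk_size),)


x = torch.arange(8.0).reshape(4, 2)      # stand-in for (num_experts=4, hidden=2)
NUM_CHUNK = 2
chunk_size = x.shape[0] // NUM_CHUNK
chunks = torch.split(x, chunk_size, dim=0)

out = torch.empty_like(x)
for i, chunk in enumerate(chunks):
    # In _tp_process this assignment happens with the reduce-scattered result.
    out[get_chunk_slice(i, chunk_size)] = chunk * 2

assert torch.equal(out, x * 2)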
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index 2bbf739ebbd4..51fd135483b6 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -8,57 +8,89 @@ from colossalai.moe.utils import sync_moe_model_param
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
-from tests.test_moe.moe_utils import MoeGradientHandler, sync_tp_from_ep
+from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, sync_tp_from_ep
 
-BATCH_SIZE = 4
-DIM = 16
+
+def run_test(rank: int,
+             world_size: int,
+             port: int,
+             num_experts: int,
+             batch_size: int,
+             dim: int,
+             seed: int):
+    assert batch_size % world_size == 0
 
-def run_test(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    MOE_MANAGER.setup(42, parallel="EP")    # MOE initialization
-    ep_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM)
+
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(seed, parallel=None)
+    local_model = SparseMLP(num_experts=num_experts,
+                            hidden_size=dim,
+                            intermediate_size=dim * 2)
     MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(42, parallel="TP")
-    tp_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM)
+    MOE_MANAGER.setup(seed, parallel="EP")
+    ep_model = SparseMLP(num_experts=num_experts,
+                         hidden_size=dim,
+                         intermediate_size=dim * 2)
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(seed, parallel="TP")
+    tp_model = SparseMLP(num_experts=num_experts,
+                         hidden_size=dim,
+                         intermediate_size=dim * 2)
     ep_model = ep_model.to(get_current_device())
     tp_model = tp_model.to(get_current_device())
+    local_model = local_model.to(get_current_device())
 
     # sync ep param
     sync_moe_model_param(ep_model)
     dist_dict = MOE_MANAGER.parallel_info_dict
-    assert_equal_in_group(ep_model.experts.wi.data, dist_dict[2].dp_group)
-    assert_equal_in_group(ep_model.experts.wo.data, dist_dict[2].dp_group)
+    assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group)
+    assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group)
     grad_handler = MoeGradientHandler(ep_model)
     # sync tp param
     sync_tp_from_ep(tp_model, ep_model)
+    # sync local param
+    sync_local_from_ep(local_model, ep_model)
 
     rank = dist.get_rank()
-    torch.cuda.manual_seed(78)
-    tp_data = torch.randn(BATCH_SIZE, DIM, device=get_current_device())
-    ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)]
+    torch.cuda.manual_seed(seed)
+    tp_data = torch.randn(batch_size, dim, device=get_current_device())
+    micro_batch_size = batch_size // world_size
+    ep_data = tp_data.detach()[micro_batch_size * rank:micro_batch_size * (rank + 1)]
 
+    out_local = local_model(tp_data)
+    MOE_MANAGER.reset_loss()
     out_tp = tp_model(tp_data)
     MOE_MANAGER.reset_loss()
     out_ep = ep_model(ep_data)
     MOE_MANAGER.reset_loss()
-    assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)])
+    assert torch.allclose(out_ep, out_tp[micro_batch_size * rank:micro_batch_size * (rank + 1)])
+    assert torch.allclose(out_ep, out_local[micro_batch_size * rank:micro_batch_size * (rank + 1)])
 
+    out_local.mean().backward()
     out_tp.mean().backward()
     out_ep.mean().backward()
     grad_handler.handle_gradient()
 
-    assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[2].dp_group)
-    assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[2].dp_group)
+    assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group)
+    assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group)
+    sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
 
     sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)
 
 
 @pytest.mark.dist
+@pytest.mark.parametrize("num_experts", [4, 8])
+@pytest.mark.parametrize("batch_size", [4, 8])
+@pytest.mark.parametrize("dim", [16, 256])
+@pytest.mark.parametrize("seed", [42, 78])
 @rerun_if_address_is_in_use()
-def test_moe_ep_tp():
-    spawn(run_test, 2)
+def test_moe_ep_tp(num_experts: int,
+                   batch_size: int,
+                   dim: int,
+                   seed: int):
+    spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed)
 
 
 if __name__ == '__main__':
-    test_moe_ep_tp()
+    test_moe_ep_tp(num_experts=8, batch_size=8, dim=256, seed=42)
diff --git a/tests/test_moe/test_moe_local.py b/tests/test_moe/test_moe_local.py
deleted file mode 100644
index 1211a0d2d7f1..000000000000
--- a/tests/test_moe/test_moe_local.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import pytest
-import torch
-import torch.distributed as dist
-
-import colossalai
-from colossalai.moe import SparseMLP
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import sync_moe_model_param
-from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
-from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep
-
-BATCH_SIZE = 4
-DIM = 16
-
-
-def run_test(rank, world_size, port):
-    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    MOE_MANAGER.setup(42, parallel=None)
-    local_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM)
-    MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(42, parallel="EP")    # MOE initialization
-    ep_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM)
-    ep_model = ep_model.to(get_current_device())
-    local_model = local_model.to(get_current_device())
-
-    # sync ep param
-    sync_moe_model_param(ep_model)
-    dist_dict = MOE_MANAGER.parallel_info_dict
-    assert_equal_in_group(ep_model.experts.wi.data, dist_dict[2].dp_group)
-    assert_equal_in_group(ep_model.experts.wo.data, dist_dict[2].dp_group)
-    grad_handler = MoeGradientHandler(ep_model)
-    # sync tp param
-    sync_local_from_ep(local_model, ep_model)
-
-    rank = dist.get_rank()
-    torch.cuda.manual_seed(78)
-    tp_data = torch.randn(BATCH_SIZE, DIM, device=get_current_device())
-    ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)]
-
-    out_tp = local_model(tp_data)
-    MOE_MANAGER.reset_loss()
-    out_ep = ep_model(ep_data)
-    MOE_MANAGER.reset_loss()
-    assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)])
-
-    out_tp.mean().backward()
-    out_ep.mean().backward()
-    grad_handler.handle_gradient()
-
-    assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[2].dp_group)
-    assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[2].dp_group)
-
-    sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2])
-@rerun_if_address_is_in_use()
-def test_moe_local(world_size):
-    spawn(run_test, world_size)
-
-
-if __name__ == '__main__':
-    test_moe_local()