diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py index ffeeac796441..c20d16181909 100644 --- a/colossalai/nn/layer/moe/__init__.py +++ b/colossalai/nn/layer/moe/__init__.py @@ -5,6 +5,6 @@ from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts __all__ = [ - 'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'MoeModule', 'NormalNoiseGenerator', + 'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'MoeModule', 'MoeLayer', 'NormalNoiseGenerator', 'UniformNoiseGenerator', 'build_ffn_experts', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO' ] diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py index 37f31c16709b..a0753d8581b4 100644 --- a/colossalai/nn/layer/moe/_operation.py +++ b/colossalai/nn/layer/moe/_operation.py @@ -173,3 +173,44 @@ def moe_cumsum(inputs: Tensor): return moe.cumsum_sub_one(inputs) else: return torch.cumsum(inputs, dim=0) - 1 + + +class MoeInGradScaler(torch.autograd.Function): + """ + Scale the gradient back by the number of experts + because the batch size increases in the moe stage + """ + + @staticmethod + def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor: + if ctx is not None: + ctx.ep_size = ep_size + return inputs + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: + assert len(grad_outputs) == 1 + grad = grad_outputs[0] + if ctx.ep_size != 1: + grad = grad * ctx.ep_size + return grad, None + + +class MoeOutGradScaler(torch.autograd.Function): + """ + Scale the gradient by the number of experts + because the batch size increases in the moe stage + """ + + @staticmethod + def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor: + ctx.ep_size = ep_size + return inputs + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: + assert len(grad_outputs) == 1 + grad = grad_outputs[0] + if ctx.ep_size != 1: + grad = grad / 
ctx.ep_size + return grad, None diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 0ed2f1fd2513..608eca05435e 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -7,6 +7,7 @@ from colossalai.context import ParallelMode, seed from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.nn.layer.moe._operation import MoeInGradScaler, MoeOutGradScaler from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size, set_moe_tensor_info @@ -20,27 +21,31 @@ def __init__( num_experts: int, hidden_size: int, intermediate_size: int, - expert_parallel: str, + expert_parallel: str = None, activation: str = None, drop_rate: float = 0, ): super().__init__() - assert expert_parallel in ["EP", "TP"] + assert expert_parallel in ["EP", "TP", None] self.expert_parallel = expert_parallel - - # get local and total experts self.num_total_experts = num_experts - self.num_local_experts, self.moe_info = MOE_CONTEXT.get_info(num_experts, - use_tp=True if expert_parallel == "TP" else False) - - # get settings for different parallel - if expert_parallel == "TP": - assert intermediate_size % MOE_CONTEXT.max_ep_size == 0, \ - "intermediate_size should be divide by maximum expert parallel size" - intermediate_size = intermediate_size // MOE_CONTEXT.max_ep_size - num_experts = self.num_total_experts + + # get expert parallel info + if expert_parallel is not None: + self.num_local_experts, self.moe_info = MOE_CONTEXT.get_info( + num_experts, use_tp=True if expert_parallel == "TP" else False) + # get settings for different parallel + if expert_parallel == "TP": + assert intermediate_size % MOE_CONTEXT.max_ep_size == 0, \ + "intermediate_size should be divide by maximum expert parallel size" + intermediate_size = intermediate_size // MOE_CONTEXT.max_ep_size + num_experts = self.num_total_experts + else: + num_experts = self.num_local_experts + self.ep_size = get_ep_size(self) else: - 
num_experts = self.num_local_experts + self.num_local_experts = self.num_total_experts + self.ep_size = 1 self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size)) @@ -52,10 +57,12 @@ def __init__( self.act = nn.GELU() if activation is None else activation self.drop = nn.Dropout(p=drop_rate) - for param in self.parameters(): - set_moe_tensor_info(param, self.moe_info) + if expert_parallel is not None: + for param in self.parameters(): + set_moe_tensor_info(param, self.moe_info) def forward(self, x: torch.Tensor) -> torch.Tensor: # inputs [g, e, c, h] + x = MoeInGradScaler.apply(x, self.ep_size) e = x.size(1) h = x.size(-1) @@ -72,6 +79,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # inputs [g, e, c, h] x = x.reshape(inshape) x = x.transpose(0, 1).contiguous() + x = MoeOutGradScaler.apply(x, self.ep_size) return x # outputs [g, e, c, h] @@ -135,5 +143,7 @@ def get_expert_class(name: str) -> BaseMLPExperts: return TPMLPExperts elif name == "EP": return EPMLPExperts + elif name is None: + return BaseMLPExperts else: raise ValueError(f"Unknown expert class name: {name}") diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index d870781d29c4..f39eab40d28b 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -20,90 +20,6 @@ from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size -class MoeLayer(nn.Module): - """A MoE layer, that puts its input tensor to its gate and uses the output logits - to router all tokens, is mainly used to exchange all tokens for every expert across - the moe tensor group by all to all communication. Then it will get the output of all - experts and exchange the output. At last returns the output of the moe system. - - Args: - dim_model (int): Dimension of model. - num_experts (int): The number of experts. 
- router (MoeRouter): Instance of router used in routing. - experts (MoeExperts): Instance of experts generated by Expert. - """ - - def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: BaseMLPExperts): - super().__init__() - self.d_model = dim_model - self.num_experts = num_experts - self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model)) - self.router: MoeRouter = router - self.experts: BaseMLPExperts = experts - self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False - self.ep_group = get_ep_group(experts) - self.ep_size = get_ep_size(experts) - self.num_local_experts = experts.num_local_experts - - nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model)) - - def ep_process(self, dispatch_data: torch.Tensor): - expert_input = AllToAll.apply(dispatch_data, self.ep_group) - input_shape = expert_input.shape - expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model) - expert_output = self.experts(expert_input) - expert_output = expert_output.reshape(input_shape) - expert_output = AllToAll.apply(expert_output, self.ep_group) - return expert_output - - def tp_process(self, dispatch_data: torch.Tensor): - expert_in = AllGather.apply(dispatch_data, self.ep_group) - expert_out = self.experts(expert_in) - expert_out = ReduceScatter.apply(expert_out, self.ep_group) - return expert_out - - def forward(self, inputs: torch.Tensor) -> Tuple: - # reshape the input tokens - tokens = inputs.reshape(-1, self.d_model) - - # the data type of the inputs in the gating should be fp32 - fp32_input = tokens.to(torch.float) - fp32_weight = self.gate_weight.to(torch.float) - gate_output = F.linear(fp32_input, fp32_weight) - - # the result from the router - route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) - - if self.use_kernel: - dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:]) - 
dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model) - else: - sec_mask_f = route_result_list[1].type_as(inputs) - dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) - - # dispatch_data [e, c, h] - if self.experts.expert_parallel == "EP": - expert_output = self.ep_process(dispatch_data) - elif self.experts.expert_parallel == "TP": - expert_output = self.tp_process(dispatch_data) - else: - raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " - "build function.") - # expert_output [e, c, h] - if self.use_kernel: - expert_output = expert_output.reshape(-1, self.d_model) - ans = MoeCombine.apply(expert_output, *route_result_list) - else: - combine_weights = route_result_list[0].type_as(inputs) - combine_weights = combine_weights.view(combine_weights.shape[0], -1) - expert_output = expert_output.view(-1, expert_output.shape[-1]) - ans = torch.matmul(combine_weights, expert_output) - - ans = ans.reshape(inputs.shape) - l_aux = self.router.pop_routing_loss() - return ans, l_aux - - class SparseMLP(nn.Module): """A class for users to create MoE modules in their models. 
@@ -149,7 +65,8 @@ def __init__(self, self.hidden_size = hidden_size self.num_experts = num_experts self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False - assert expert_parallel in ["EP", "TP"], f"Unsupported expert parallel type {expert_parallel}" + self.expert_parallel = expert_parallel + assert expert_parallel in ["EP", "TP", None], f"Unsupported expert parallel type {expert_parallel}" # moe router noisy_func = get_noise_generator(noisy_policy, num_experts) @@ -166,8 +83,11 @@ def __init__(self, hidden_size=hidden_size, intermediate_size=intermediate_size, activation=activation) - self.ep_group = get_ep_group(self.experts) - self.ep_size = get_ep_size(self.experts) + if expert_parallel is not None: + self.ep_group = get_ep_group(self.experts) + self.ep_size = get_ep_size(self.experts) + else: + self.ep_group = None self.num_local_experts = self.experts.num_local_experts self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) @@ -193,10 +113,12 @@ def forward(self, inputs: torch.Tensor) -> Tuple: dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) # dispatch_data [e, c, h] - if self.experts.expert_parallel == "EP": - expert_output = self.ep_process(dispatch_data) - elif self.experts.expert_parallel == "TP": - expert_output = self.tp_process(dispatch_data) + if self.expert_parallel == "EP": + expert_output = self._ep_process(dispatch_data) + elif self.expert_parallel == "TP": + expert_output = self._tp_process(dispatch_data) + elif self.expert_parallel is None: + expert_output = self._local_process(dispatch_data) else: raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " "build function.") @@ -214,7 +136,12 @@ def forward(self, inputs: torch.Tensor) -> Tuple: l_aux = self.router.pop_routing_loss() return ans, l_aux - def ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: + def _local_process(self, expert_in: torch.Tensor) 
-> torch.Tensor: + expert_in = expert_in.unsqueeze(0) + expert_out = self.experts(expert_in) + return expert_out + + def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: expert_input = AllToAll.apply(dispatch_data, self.ep_group) input_shape = expert_input.shape expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) @@ -223,14 +150,35 @@ def ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: expert_output = AllToAll.apply(expert_output, self.ep_group) return expert_output - def tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: + def _tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: expert_in = AllGather.apply(dispatch_data, self.ep_group) expert_out = self.experts(expert_in) expert_out = ReduceScatter.apply(expert_out, self.ep_group) return expert_out -class MoeModule(nn.Module): +class MoeModule(SparseMLP): + """ + For other dependency + """ + + def __init__(self, + num_experts: int, + top_k: int = 1, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_policy: Optional[str] = None, + drop_tks: bool = True, + expert_parallel: str = "EP", + hidden_size: int = 2048, + intermediate_size: int = 2048, + activation: str = None): + super().__init__(num_experts, top_k, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_policy, + drop_tks, expert_parallel, hidden_size, intermediate_size, activation) + + +class MoeLayer(SparseMLP): """ For other dependency """ diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/nn/layer/moe/routers.py index 53fd8fd43e91..962aec9cf1e7 100644 --- a/colossalai/nn/layer/moe/routers.py +++ b/colossalai/nn/layer/moe/routers.py @@ -111,7 +111,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti l_aux = num_experts * torch.sum(me * ce) self.set_routing_loss(l_aux) - if not self.training and not self.drop_tks: + if not self.training and not 
self.drop_tks and ep_group is not None: max_num = torch.max(torch.sum(mask, dim=0)) dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) capacity = max_num.item() @@ -190,7 +190,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 self.set_routing_loss(l_aux) - if not self.training and not self.drop_tks: + if not self.training and not self.drop_tks and ep_group is not None: max_num = torch.max(torch.sum(cmask, dim=0)) dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) capacity = max_num.item() diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index b57567e74be3..a3f6f86e6fe9 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -127,3 +127,37 @@ def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: assert torch.allclose(tp_param.grad, new_grad) else: tp_param.data.copy_(new_tp_param.data) + + +def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: + """Sync the parameters of local model from ep model + + Args: + local_model (SparseMLP) + ep_model (SparseMLP) + """ + for (local_name, local_param), (ep_name, ep_param) in zip(local_model.named_parameters(), + ep_model.named_parameters()): + assert local_name == ep_name + if "experts" not in local_name: + if assert_grad_flag: + assert torch.allclose(local_param, ep_param) + assert torch.allclose(local_param.grad, ep_param.grad) + else: + local_param.data.copy_(ep_param.data) + continue + + # gather param from ep model + param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) + all_param = torch.cat(param_list, dim=0) + if assert_grad_flag: + grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(grad_list, ep_param.grad, 
group=get_ep_group(ep_param)) + all_grad = torch.cat(grad_list, dim=0) + + if assert_grad_flag: + assert torch.allclose(local_param, all_param) + assert torch.allclose(local_param.grad, all_grad) + else: + local_param.data.copy_(all_param.data) diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index 6135a386e7c8..9544aa0daf01 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -5,7 +5,7 @@ import colossalai from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.nn.layer.moe import EPMLPExperts, MoeLayer, Top1Router, UniformNoiseGenerator +from colossalai.nn.layer.moe import SparseMLP from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device from colossalai.utils.moe import sync_moe_model_param @@ -17,16 +17,17 @@ def run_test(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - expert_factor = dict(hidden_size=DIM, intermediate_size=DIM * 2) MOE_CONTEXT.setup(42) # MOE initialization - noisy_func = UniformNoiseGenerator() - router = Top1Router(noisy_func=noisy_func) num_experts_list = [1, 2, 4] layer_list = [] for num_experts in num_experts_list: - exp = EPMLPExperts(num_experts, **expert_factor) - moe_layer = MoeLayer(DIM, num_experts, router, exp) + moe_layer = SparseMLP(hidden_size=DIM, + intermediate_size=DIM * 4, + num_experts=num_experts, + top_k=1, + expert_parallel="EP", + noisy_policy="Jitter") layer_list.append(moe_layer) model = nn.ModuleList(layer_list) diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 867437f00c82..6e40e53311a6 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -5,7 +5,7 @@ from colossalai.context import ParallelMode from colossalai.context.moe_context import MOE_CONTEXT from colossalai.core import global_context as gpc 
-from colossalai.nn.layer.moe import EPMLPExperts, MoeLayer, Top1Router, Top2Router +from colossalai.nn.layer.moe import SparseMLP from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device @@ -17,7 +17,7 @@ def check_equal(tensor_a, tensor_b, atol=1e-06): assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol) is True -def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, router=Top2Router): +def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, topk=1): # Here we do not need TF32, since it brings absolute error on results torch.backends.cuda.matmul.allow_tf32 = False @@ -31,9 +31,12 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f # get randomized data tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) - expert_factor = dict(hidden_size=hidden_size, intermediate_size=hidden_size * 2) - expert = EPMLPExperts(NUM_EXPERTS, **expert_factor) - layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert) + layer = SparseMLP(hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_experts=NUM_EXPERTS, + top_k=topk, + expert_parallel="EP", + capacity_factor_train=1.0) layer = layer.to(get_current_device()) if data_type == torch.float16: layer = layer.half() @@ -83,11 +86,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f @pytest.mark.parametrize("rs", [131]) @pytest.mark.parametrize("hidden_size", [32, 144]) @pytest.mark.parametrize("data_type", [torch.float32, torch.float16]) -@pytest.mark.parametrize("router", [Top1Router, Top2Router]) +@pytest.mark.parametrize("topk", [1, 2]) @rerun_if_address_is_in_use() -def test_moe_kernel(rs, hidden_size, data_type, router): - spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, router=router) +def test_moe_kernel(rs, 
hidden_size, data_type, topk): + spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, topk=topk) if __name__ == '__main__': - test_moe_kernel(2, 256, torch.float16, Top2Router) + test_moe_kernel(2, 256, torch.float16, 2) diff --git a/tests/test_moe/test_moe_local.py b/tests/test_moe/test_moe_local.py new file mode 100644 index 000000000000..d240ad46ce71 --- /dev/null +++ b/tests/test_moe/test_moe_local.py @@ -0,0 +1,63 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.nn.layer.moe import SparseMLP +from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.utils.moe import sync_moe_model_param +from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep + +BATCH_SIZE = 4 +DIM = 4 + + +def run_test(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(42) # MOE initialization + + ep_model = SparseMLP(num_experts=4, expert_parallel="EP", hidden_size=DIM, intermediate_size=DIM) + local_model = SparseMLP(num_experts=4, expert_parallel=None, hidden_size=DIM, intermediate_size=DIM) + ep_model = ep_model.to(get_current_device()) + local_model = local_model.to(get_current_device()) + + # sync ep param + sync_moe_model_param(ep_model) + dist_dict = MOE_CONTEXT.parallel_info_dict + assert_equal_in_group(ep_model.experts.wi.data, dist_dict[2].dp_group) + assert_equal_in_group(ep_model.experts.wo.data, dist_dict[2].dp_group) + grad_handler = MoeGradientHandler(ep_model) + # sync local param + sync_local_from_ep(local_model, ep_model) + + rank = dist.get_rank() + torch.cuda.manual_seed(78) + tp_data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) + ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)] + + out_tp = 
local_model(tp_data)[0] + MOE_CONTEXT.reset_loss() + out_ep = ep_model(ep_data)[0] + MOE_CONTEXT.reset_loss() + assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)]) + + out_tp.mean().backward() + out_ep.mean().backward() + grad_handler.handle_gradient() + + assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[2].dp_group) + assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[2].dp_group) + + sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_moe_ep_tp(): + spawn(run_test, 2) + + +if __name__ == '__main__': + test_moe_ep_tp()