From 986451685f0d5285000303fdc3c67f785f0395b3 Mon Sep 17 00:00:00 2001
From: haze188
Date: Fri, 28 Jun 2024 06:37:59 +0000
Subject: [PATCH 1/2] [misc] fix typo

---
 colossalai/shardformer/modeling/deepseek.py    |  31 ++----
 .../shardformer/modeling/deepseekmoe.py        | 100 ------------------
 2 files changed, 11 insertions(+), 120 deletions(-)
 delete mode 100644 colossalai/shardformer/modeling/deepseekmoe.py

diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 6edaca71e168..1e546c60d03f 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -2,29 +2,28 @@
 
 import torch
 import torch.distributed as dist
-import torch.nn.functional as F
 from torch.distributed import ProcessGroup
 
 # from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from colossalai.shardformer.modeling.deepseek_moe_16b_base.modeling_deepseek import (
-    AddAuxiliaryLoss,
-    DeepseekModel,
-    DeepseekMoE,
-    DeepseekForCausalLM,
-    CausalLMOutputWithPast
-)
-from colossalai.shardformer.modeling.deepseek_moe_16b_base.configuration_deepseek import DeepseekConfig
 from transformers.utils import logging
-from transformers.models import Ca
 
 from colossalai.lazy import LazyInitContext
 from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.modeling.deepseek_moe_16b_base.configuration_deepseek import DeepseekConfig
+from colossalai.shardformer.modeling.deepseek_moe_16b_base.modeling_deepseek import (
+    AddAuxiliaryLoss,
+    CausalLMOutputWithPast,
+    DeepseekForCausalLM,
+    DeepseekModel,
+    DeepseekMoE,
+)
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
 
+
 class EPDeepseekMoE(DeepseekMoE):
     def __init__(self, config: DeepseekConfig):
         super().__init__(config)
@@ -95,9 +94,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                     output_states_list.append(split_states)
                 output_states = torch.cat(output_states_list)
         output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
-        dispatch_states, _ = all_to_all_uneven(
-            output_states, output_split_list, input_split_list, self.ep_group
-        )
+        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
         recover_token_idx = torch.empty_like(flat_topk_token_idx)
         recover_token_idx[flat_topk_token_idx] = torch.arange(
             flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
@@ -112,7 +109,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             output_hidden_states = output_hidden_states + self.shared_experts(identity)
         return output_hidden_states
 
-
 
 class DeepseekPipelineForwards:
     """
@@ -298,11 +294,7 @@ def custom_forward(*inputs):
         next_cache = next_decoder_cache if use_cache else None
 
         if stage_manager.is_last_stage():
-            return tuple(
-                v
-                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
-                if v is not None
-            )
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         # always return dict for imediate stage
         return {
             "hidden_states": hidden_states,
@@ -403,7 +395,6 @@ def deepseek_for_causal_lm_forward(
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)
 
-
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
diff --git a/colossalai/shardformer/modeling/deepseekmoe.py b/colossalai/shardformer/modeling/deepseekmoe.py
deleted file mode 100644
index d135c5feb9f0..000000000000
--- a/colossalai/shardformer/modeling/deepseekmoe.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import torch
-import torch.distributed as dist
-from torch.distributed import ProcessGroup
-
-from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
-from colossalai.shardformer.shard.utils import set_tensors_to_none
-
-from .deepseek_moe_16b_base.configuration_deepseek import DeepseekConfig
-
-# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
-from .deepseek_moe_16b_base.modeling_deepseek import AddAuxiliaryLoss, DeepseekMoE
-
-
-class EPDeepseekMoE(DeepseekMoE):
-    def __init__(self, config: DeepseekConfig):
-        super().__init__(config)
-
-    def setup_ep(self, ep_group: ProcessGroup):
-        ep_group = ep_group
-        self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
-        self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
-        self.num_experts = self.config.n_routed_experts
-        assert self.num_experts % self.ep_size == 0
-        self.ep_group = ep_group
-        self.num_experts_per_ep = self.num_experts // self.ep_size
-        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
-        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
-        set_tensors_to_none(self.experts, exclude=set(held_experts))
-        for p in self.experts.parameters():
-            p.ep_group = ep_group
-
-    @staticmethod
-    def from_native_module(module: DeepseekMoE, *args, **kwargs) -> "EPDeepseekMoE":
-        LazyInitContext.materialize(module)
-        module.__class__ = EPDeepseekMoE
-        assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
-        module.setup_ep(kwargs["ep_group"])
-        return module
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        identity = hidden_states
-        orig_shape = hidden_states.shape
-
-        topk_experts_idx, topk_experts_weight, aux_loss = self.gate(hidden_states)
-
-        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])  # [t0, t1, t2 ...]
-        hidden_states = hidden_states.repeat_interleave(
-            self.num_experts_per_tok, dim=0
-        )  # after repeat_interleave: [t0 t0 t1 t1 t2 t2 ... ]
-
-        flat_topk_experts_idx = topk_experts_idx.view(-1)  # [e0 e1 e2 ...]
-        # The elements of flat_topk_token_idx are token ids, which are arranged in ascending order of expert ids.
-        flat_topk_token_idx = flat_topk_experts_idx.argsort()
-
-        # Now we adjust the order of the hidden states, also in ascending order of expert id
-        dispatch_states = hidden_states[flat_topk_token_idx]
-        input_split_sizes = flat_topk_experts_idx.bincount(minlength=self.num_experts)  # [n0, n1, n2, n3]
-        print(f"{input_split_sizes=}")
-        output_split_sizes = torch.zeros_like(input_split_sizes)
-
-        # [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
-        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
-
-        input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
-        output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
-        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
-        output_states = MoeInGradScaler.apply(output_states, self.ep_size)
-
-        if output_states.size(0) > 0:
-            if self.num_experts_per_ep == 1:
-                expert = self.experts[self.expert_start_idx]
-                output_states = expert(output_states)
-            else:
-                output_states_splits = output_states.split(output_split_sizes.tolist())
-                output_states_list = []
-                for i, split_states in enumerate(output_states_splits):
-                    if split_states.size(0) == 0:  # no token routed to this expert
-                        continue
-                    expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
-                    split_states = expert(split_states)
-                    output_states_list.append(split_states)
-                output_states = torch.cat(output_states_list)  # (4, h) (8, h)
-        output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
-        dispatch_states, _ = all_to_all_uneven(
-            output_states, output_split_list, input_split_list, self.ep_group
-        )  # once the experts have processed their tokens, the outputs are sent back to the ranks they came from
-        recover_token_idx = torch.empty_like(flat_topk_token_idx)  # (6,)
-        recover_token_idx[flat_topk_token_idx] = torch.arange(
-            flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
-        )
-
-        output_hidden_states = dispatch_states[recover_token_idx]  # t0 t0 t1 t1 t2 t2
-        output_hidden_states = output_hidden_states.view(-1, self.num_experts_per_tok, orig_shape[-1])
-        output_hidden_states = (output_hidden_states * topk_experts_weight[:, :, None]).sum(dim=-2)  # (BS, h)
-        output_hidden_states = output_hidden_states.view(*orig_shape)
-        output_hidden_states = AddAuxiliaryLoss.apply(output_hidden_states, aux_loss)
-        if self.config.n_shared_experts is not None:
-            output_hidden_states = output_hidden_states + self.shared_experts(identity)
-        return output_hidden_states

From 76d2c64ec0f7db51eaf84777fc5aafbc38f5287e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 28 Jun 2024 06:46:08 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/shardformer/policies/deepseek.py | 8 ++++++--
 tests/test_moe/test_deepseek_layer.py       | 2 +-
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index 953c74beecea..88d0a91ca93c 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -4,10 +4,14 @@
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
-from colossalai.shardformer.modeling.deepseek_moe_16b_base.modeling_deepseek import DeepseekDecoderLayer, DeepseekForCausalLM, DeepseekModel
 
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
-from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE, DeepseekPipelineForwards
+from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
+from colossalai.shardformer.modeling.deepseek_moe_16b_base.modeling_deepseek import (
+    DeepseekDecoderLayer,
+    DeepseekForCausalLM,
+    DeepseekModel,
+)
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]
diff --git a/tests/test_moe/test_deepseek_layer.py b/tests/test_moe/test_deepseek_layer.py
index 69952e16acc6..06dfbfe3b515 100644
--- a/tests/test_moe/test_deepseek_layer.py
+++ b/tests/test_moe/test_deepseek_layer.py
@@ -7,9 +7,9 @@
 
 import colossalai
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE
 from colossalai.shardformer.modeling.deepseek_moe_16b_base.configuration_deepseek import DeepseekConfig
 from colossalai.shardformer.modeling.deepseek_moe_16b_base.modeling_deepseek import DeepseekMoE
-from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE
 from colossalai.testing.utils import spawn
 
 tokens, n_experts = 7, 4