diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index 07b86cd638c8..8ebda357b380 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union
 
@@ -39,21 +40,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
 
         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for Deepseek model now.")
-        if getattr(self.shard_config, "ep_group", None) is None:
-            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
-
-        # expert parallel
-        self.append_or_create_submodule_replacement(
-            description=[
-                SubModuleReplacementDescription(
-                    suffix="mlp",
-                    target_module=EPDeepseekMoE,
-                    kwargs={"ep_group": self.shard_config.ep_group},
-                )
-            ],
-            policy=policy,
-            target_key="DeepseekDecoderLayer",
-        )
+
+        if getattr(self.shard_config, "ep_group", None) is not None:
+            # expert parallel
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="mlp",
+                        target_module=EPDeepseekMoE,
+                        kwargs={"ep_group": self.shard_config.ep_group},
+                    )
+                ],
+                policy=policy,
+                target_key="DeepseekDecoderLayer",
+            )
 
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
@@ -82,7 +82,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             )
 
         if self.shard_config.enable_flash_attention:
-            raise NotImplementedError("Flash attention has already been replaced in deepseek.")
+            warnings.warn(
+                "Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
+            )
+            self.shard_config.enable_flash_attention = False
 
         return policy
 
diff --git a/tests/test_moe/test_deepseek_layer.py b/tests/test_moe/test_deepseek_layer.py
index 328ffb1de5f8..85cc986959fd 100644
--- a/tests/test_moe/test_deepseek_layer.py
+++ b/tests/test_moe/test_deepseek_layer.py
@@ -25,14 +25,17 @@ def check_deepseek_moe_layer():
         ep_size=dist.get_world_size(),
     )
 
-    config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
-    config.num_hidden_layers = 1
-    config.n_routed_experts = n_experts
-    config.num_experts_per_tok = top_k
-    config.hidden_size = hidden_size
-    config.intermediate_size = hidden_size * 2
-    config.first_k_dense_replace = 0
-    config.num_attention_heads = 2
+    config = AutoConfig.from_pretrained(
+        "deepseek-ai/deepseek-moe-16b-base",
+        num_hidden_layers=1,
+        n_routed_experts=n_experts,
+        num_experts_per_tok=top_k,
+        hidden_size=hidden_size,
+        intermediate_size=hidden_size * 2,
+        first_k_dense_replace=0,
+        num_attention_heads=2,
+        trust_remote_code=True,
+    )
     torch.manual_seed(0)
     # get the moe layer in auto model
     orig_model = AutoModel.from_config(config, trust_remote_code=True).layers[0].mlp.cuda()