Merged
35 changes: 19 additions & 16 deletions colossalai/shardformer/policies/deepseek.py
@@ -1,3 +1,4 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Union

@@ -39,21 +40,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:

if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for Deepseek model now.")
if getattr(self.shard_config, "ep_group", None) is None:
raise ValueError("You must pass in ep_group via shard_config for expert parallel!")

# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="mlp",
target_module=EPDeepseekMoE,
kwargs={"ep_group": self.shard_config.ep_group},
)
],
policy=policy,
target_key="DeepseekDecoderLayer",
)

if getattr(self.shard_config, "ep_group", None) is not None:
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="mlp",
target_module=EPDeepseekMoE,
kwargs={"ep_group": self.shard_config.ep_group},
)
],
policy=policy,
target_key="DeepseekDecoderLayer",
)

# optimization configuration
if self.shard_config.enable_fused_normalization:
@@ -82,7 +82,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
)

if self.shard_config.enable_flash_attention:
raise NotImplementedError("Flash attention has already been replaced in deepseek.")
warnings.warn(
"Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
)
self.shard_config.enable_flash_attention = False

return policy

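The net effect of the two policy changes above: expert parallelism becomes opt-in (the `EPDeepseekMoE` submodule replacement is only registered when `shard_config.ep_group` is set), and an enabled flash-attention flag is downgraded to a warning plus a reset instead of raising. Below is a minimal, self-contained sketch of that control flow under stated assumptions: `_StubShardConfig` is a stand-in for ColossalAI's `ShardConfig`, and `register_ep_replacement` is a hypothetical callback representing `append_or_create_submodule_replacement`; neither is the actual library API.

```python
# Minimal sketch of the revised guard logic (stand-in types; not the ColossalAI API).
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class _StubShardConfig:
    enable_tensor_parallelism: bool = False
    enable_flash_attention: bool = False
    ep_group: Optional[Any] = None  # expert-parallel process group, if provided


def build_policy(
    shard_config: _StubShardConfig,
    register_ep_replacement: Callable[[Dict, Any], None],  # hypothetical helper
) -> Dict:
    policy: Dict = {}

    if shard_config.enable_tensor_parallelism:
        raise NotImplementedError("Tensor parallelism is not supported for Deepseek model now.")

    # Previously a missing ep_group raised a ValueError; now the expert-parallel
    # replacement is simply skipped unless a process group was provided.
    if getattr(shard_config, "ep_group", None) is not None:
        register_ep_replacement(policy, shard_config.ep_group)

    # Previously an enabled flag raised NotImplementedError; now it is reset with
    # a warning, since deepseek already ships its own flash-attention replacement.
    if shard_config.enable_flash_attention:
        warnings.warn(
            "Flash attention has already been replaced in deepseek, "
            "and now set enable_flash_attention = False."
        )
        shard_config.enable_flash_attention = False

    return policy
```

With this sketch, `build_policy(_StubShardConfig(enable_flash_attention=True), lambda policy, group: None)` emits the warning and returns normally, whereas the pre-PR policy would have raised on either an unset `ep_group` or an enabled flash-attention flag.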
19 changes: 11 additions & 8 deletions tests/test_moe/test_deepseek_layer.py
@@ -25,14 +25,17 @@ def check_deepseek_moe_layer():
ep_size=dist.get_world_size(),
)

config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
config.num_hidden_layers = 1
config.n_routed_experts = n_experts
config.num_experts_per_tok = top_k
config.hidden_size = hidden_size
config.intermediate_size = hidden_size * 2
config.first_k_dense_replace = 0
config.num_attention_heads = 2
config = AutoConfig.from_pretrained(
"deepseek-ai/deepseek-moe-16b-base",
num_hidden_layers=1,
n_routed_experts=n_experts,
num_experts_per_tok=top_k,
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
first_k_dense_replace=0,
num_attention_heads=2,
trust_remote_code=True,
)
torch.manual_seed(0)
# get the moe layer in auto model
orig_model = AutoModel.from_config(config, trust_remote_code=True).layers[0].mlp.cuda()
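The test refactor relies on `transformers.AutoConfig.from_pretrained` accepting keyword overrides that are applied on top of the loaded configuration, so the eight post-hoc attribute assignments collapse into a single call. A hedged sketch of that equivalence is below; the model id and override values come from the diff, and it assumes a standard `transformers` install with network access and remote code trusted.

```python
# Sketch: keyword overrides in from_pretrained replace post-hoc attribute mutation.
from transformers import AutoConfig

# Old style: load the config, then mutate attributes one by one.
cfg_a = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
cfg_a.num_hidden_layers = 1
cfg_a.num_attention_heads = 2

# New style: pass the overrides directly; matching kwargs update the config at load time.
cfg_b = AutoConfig.from_pretrained(
    "deepseek-ai/deepseek-moe-16b-base",
    num_hidden_layers=1,
    num_attention_heads=2,
    trust_remote_code=True,
)

# Both paths yield the same shrunken test configuration.
assert cfg_a.num_hidden_layers == cfg_b.num_hidden_layers == 1
assert cfg_a.num_attention_heads == cfg_b.num_attention_heads == 2
```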