Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions colossalai/shardformer/modeling/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,19 @@ def setup_process_groups(
for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)

if self.config.n_shared_experts is not None:
self.shared_experts.gate_proj = Linear1D_Col.from_native_module(
self.shared_experts.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
)

self.shared_experts.up_proj = Linear1D_Col.from_native_module(
self.shared_experts.up_proj, self.tp_group, fp8_communication=self.fp8_communication
)

self.shared_experts.down_proj = Linear1D_Row.from_native_module(
self.shared_experts.down_proj, self.tp_group, fp8_communication=self.fp8_communication
)

@staticmethod
def from_native_module(
module,
Expand Down
10 changes: 6 additions & 4 deletions tests/test_shardformer/test_model/test_shard_deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4
HIDDEN_SIZE_PER_HEAD = 8
NUM_HEADS = 8
TOP_K = 2


def run_deepseek_commom(config: Tuple[int, ...]):
def run_deepseek_commom(parallel_config: Tuple[int, ...]):
Randomizer.reset_index()
stage, ep_size, pp_size, tp_size, sp_size = config
print(f"rank {dist.get_rank()} testing {parallel_config}")
stage, ep_size, pp_size, tp_size, sp_size = parallel_config
world_size = dist.get_world_size()
rank = dist.get_rank()
dtype, precision = torch.bfloat16, "bf16"
Expand Down Expand Up @@ -65,6 +66,7 @@ def run_deepseek_commom(config: Tuple[int, ...]):
attn_implementation="flash_attention_2",
torch_dtype="float16",
n_routed_experts=NUM_EXPERTS,
n_shared_experts=2,
num_experts_per_tok=TOP_K,
trust_remote_code=True,
)
Expand Down Expand Up @@ -159,7 +161,7 @@ def run_deepseek_commom(config: Tuple[int, ...]):
if rank == world_size - 1:
shutil.rmtree(model_dir)

print(f"rank {dist.get_rank()} test passed")
print(f"rank {dist.get_rank()} passed {parallel_config}")


@parameterize(
Expand Down