31 changes: 11 additions & 20 deletions colossalai/shardformer/modeling/deepseek.py
@@ -2,29 +2,28 @@

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ProcessGroup

# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
from torch.nn import CrossEntropyLoss
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from colossalai.shardformer.modeling.deepseek_moe_16b_base.modeling_deepseek import (
AddAuxiliaryLoss,
DeepseekModel,
DeepseekMoE,
DeepseekForCausalLM,
CausalLMOutputWithPast
)
from colossalai.shardformer.modeling.deepseek_moe_16b_base.configuration_deepseek import DeepseekConfig
from transformers.utils import logging
from transformers.models import Ca

from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.modeling.deepseek_moe_16b_base.configuration_deepseek import DeepseekConfig
from colossalai.shardformer.modeling.deepseek_moe_16b_base.modeling_deepseek import (
AddAuxiliaryLoss,
CausalLMOutputWithPast,
DeepseekForCausalLM,
DeepseekModel,
DeepseekMoE,
)
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none


class EPDeepseekMoE(DeepseekMoE):
def __init__(self, config: DeepseekConfig):
super().__init__(config)
@@ -95,9 +94,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(
output_states, output_split_list, input_split_list, self.ep_group
)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
recover_token_idx = torch.empty_like(flat_topk_token_idx)
recover_token_idx[flat_topk_token_idx] = torch.arange(
flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
@@ -112,7 +109,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
output_hidden_states = output_hidden_states + self.shared_experts(identity)
return output_hidden_states
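
The forward pass above dispatches tokens to the experts with `all_to_all_uneven`, scales gradients with `MoeOutGradScaler`, and finally uses `recover_token_idx` to undo the dispatch permutation. A minimal single-process sketch of that inverse-permutation step (hypothetical shapes, no distributed collectives, the expert MLPs replaced by a scalar multiply):

```python
import torch

# Hypothetical sketch of the recover_token_idx trick: tokens are permuted into
# expert order, processed, then written back to their original positions.
hidden = torch.randn(6, 4)                                # 6 tokens, hidden size 4
flat_topk_token_idx = torch.tensor([3, 0, 5, 1, 4, 2])    # dispatch (expert) order

dispatch_states = hidden[flat_topk_token_idx]             # tokens grouped for the experts
expert_out = dispatch_states * 2.0                        # stand-in for the expert MLPs

# Inverse permutation: the entry at dispatched position i belongs to token
# flat_topk_token_idx[i], so it must return to that slot.
recover_token_idx = torch.empty_like(flat_topk_token_idx)
recover_token_idx[flat_topk_token_idx] = torch.arange(flat_topk_token_idx.size(0))
restored = expert_out[recover_token_idx]                  # original token order

assert torch.allclose(restored, hidden * 2.0)
```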



class DeepseekPipelineForwards:
"""
Expand Down Expand Up @@ -298,11 +294,7 @@ def custom_forward(*inputs):
next_cache = next_decoder_cache if use_cache else None

if stage_manager.is_last_stage():
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
# always return dict for intermediate stage
return {
"hidden_states": hidden_states,
Expand Down Expand Up @@ -403,7 +395,6 @@ def deepseek_for_causal_lm_forward(
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)


if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
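
For context on the `shift_logits` / `shift_labels` lines in this hunk: a causal LM predicts token t+1 from positions up to t, so the last logit and the first label are dropped before cross-entropy. A minimal sketch with hypothetical shapes:

```python
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

# Predict token t+1 from position t: drop the last logit and the first label.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss_fct = CrossEntropyLoss()
loss = loss_fct(
    shift_logits.view(-1, vocab_size),               # (batch * (seq_len - 1), vocab)
    shift_labels.view(-1).to(shift_logits.device),   # keep labels on the logits' device
)
print(loss.item())
```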
100 changes: 0 additions & 100 deletions colossalai/shardformer/modeling/deepseekmoe.py

This file was deleted.

8 changes: 6 additions & 2 deletions colossalai/shardformer/policies/deepseek.py
@@ -4,10 +4,14 @@
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.modeling.deepseek_moe_16b_base.modeling_deepseek import DeepseekDecoderLayer, DeepseekForCausalLM, DeepseekModel

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE, DeepseekPipelineForwards
from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
from colossalai.shardformer.modeling.deepseek_moe_16b_base.modeling_deepseek import (
DeepseekDecoderLayer,
DeepseekForCausalLM,
DeepseekModel,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]
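
The policy module is where `EPDeepseekMoE` gets wired into shardformer. As a rough, hedged sketch of how such a policy typically registers the MoE replacement (the exact suffix and kwargs used by this PR may differ):

```python
from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE
from colossalai.shardformer.modeling.deepseek_moe_16b_base.modeling_deepseek import DeepseekDecoderLayer
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription

# Hedged sketch: swap each decoder layer's dense "mlp" for the expert-parallel
# EPDeepseekMoE module; a real policy may also pass kwargs such as the EP group.
policy = {
    DeepseekDecoderLayer: ModulePolicyDescription(
        sub_module_replacement=[
            SubModuleReplacementDescription(
                suffix="mlp",
                target_module=EPDeepseekMoE,
            ),
        ],
    ),
}
```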
2 changes: 1 addition & 1 deletion tests/test_moe/test_deepseek_layer.py
@@ -7,9 +7,9 @@

import colossalai
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE
from colossalai.shardformer.modeling.deepseek_moe_16b_base.configuration_deepseek import DeepseekConfig
from colossalai.shardformer.modeling.deepseek_moe_16b_base.modeling_deepseek import DeepseekMoE
from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE
from colossalai.testing.utils import spawn

tokens, n_experts = 7, 4