diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
index bfa106324dc4..6068cb7f891d 100644
--- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
+++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
@@ -289,37 +289,25 @@ def forward(self, hidden_states):
         return topk_indices, topk_weights
 
 
-# FIXME: refactor moe
-class Glm4vMoeTextMoE(nn.Module):
-    """
-    A mixed expert module containing shared experts.
-    """
-
-    def __init__(self, config: Glm4vMoeTextConfig):
+class Glm4vMoeTextNaiveMoe(nn.ModuleList):
+    def __init__(self, config):
         super().__init__()
-        self.config = config
-        self.experts = nn.ModuleList(
-            [
-                Glm4vMoeTextMLP(config, intermediate_size=config.moe_intermediate_size)
-                for _ in range(config.n_routed_experts)
-            ]
-        )
-        self.gate = Glm4vMoeTextTopkRouter(config)
-        self.shared_experts = Glm4vMoeTextMLP(
-            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
-        )
+        self.top_k = config.num_experts_per_tok
+        self.num_experts = config.num_local_experts
+        for _ in range(self.num_experts):
+            self += [Glm4vMoeTextMLP(config, intermediate_size=config.moe_intermediate_size)]
 
-    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
         r"""
         CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
         to not have to do a loop here (deepseek has 256 experts soooo yeah).
         """
         final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
-        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
+        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self))
         expert_mask = expert_mask.permute(2, 0, 1)
 
-        for expert_idx in range(len(self.experts)):
-            expert = self.experts[expert_idx]
+        for expert_idx in range(len(self)):
+            expert = self[expert_idx]
             mask = expert_mask[expert_idx]
             token_indices, weight_indices = torch.where(mask)
 
@@ -335,23 +323,33 @@ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weig
         # and all expert are "local" meaning we shard but we don't gather
         return final_hidden_states.type(hidden_states.dtype)
 
+
+class Glm4vMoeTextMoE(nn.Module):
+    def __init__(self, config: Glm4vMoeTextConfig):
+        super().__init__(config)
+        self.config = config
+        self.experts = Glm4vMoeTextNaiveMoe(config)
+        self.gate = Glm4vMoeTextTopkRouter(config)
+        self.shared_experts = Glm4vMoeTextMLP(
+            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+        )
+
     def forward(self, hidden_states):
         residuals = hidden_states
         orig_shape = hidden_states.shape
         topk_indices, topk_weights = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+        hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
         hidden_states = hidden_states + self.shared_experts(residuals)
         return hidden_states
 
 
 class Glm4vMoeTextMLP(nn.Module):
-    def __init__(self, config, hidden_size=None, intermediate_size=None):
+    def __init__(self, config, intermediate_size=None):
         super().__init__()
         self.config = config
-        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+        self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
-
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
diff --git a/src/transformers/models/glm4v_moe/modular_glm4v_moe.py b/src/transformers/models/glm4v_moe/modular_glm4v_moe.py
index 4f333c2ef8b2..572756680985 100644
--- a/src/transformers/models/glm4v_moe/modular_glm4v_moe.py
+++ b/src/transformers/models/glm4v_moe/modular_glm4v_moe.py
@@ -17,6 +17,7 @@
 import torch
 import torch.nn as nn
 
+from ..deepseek_v3.modeling_deepseek_v3 import DeepseekV3NaiveMoe
 from ...cache_utils import Cache
 from ...configuration_utils import PretrainedConfig
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -408,22 +409,30 @@ def __init__(self, config: Glm4vMoeTextConfig):
         super().__init__(config)
 
 
-# FIXME: update the expert class
-class Glm4vMoeTextMoE(Glm4MoeMoE):
+class Glm4vMoeTextNaiveMoe(DeepseekV3NaiveMoe):
+    pass
+
+
+class Glm4vMoeTextMoE(nn.Module):
     def __init__(self, config: Glm4vMoeTextConfig):
         super().__init__(config)
         self.config = config
-        self.experts = nn.ModuleList(
-            [
-                Glm4vMoeTextMLP(config, intermediate_size=config.moe_intermediate_size)
-                for _ in range(config.n_routed_experts)
-            ]
-        )
+        self.experts = Glm4vMoeTextNaiveMoe(config)
         self.gate = Glm4vMoeTextTopkRouter(config)
         self.shared_experts = Glm4vMoeTextMLP(
             config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
         )
 
+    def forward(self, hidden_states):
+        residuals = hidden_states
+        orig_shape = hidden_states.shape
+        topk_indices, topk_weights = self.gate(hidden_states)
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+        hidden_states = hidden_states + self.shared_experts(residuals)
+        return hidden_states
+
+
 class Glm4vMoeTextMLP(Glm4MoeMLP):
     pass
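
Note for reviewers: the one-hot dispatch loop that Glm4vMoeTextNaiveMoe.forward performs (the same pattern as the DeepseekV3NaiveMoe it now reuses) can be hard to follow inside the diff. The snippet below is a minimal, self-contained sketch of that routing pattern; the helper name naive_moe_dispatch and the toy nn.Linear experts are illustrative only and are not part of this change.

import torch
import torch.nn as nn


def naive_moe_dispatch(hidden_states, topk_indices, topk_weights, experts):
    """Route each token to its top-k experts and sum the weighted expert outputs.

    hidden_states: (num_tokens, hidden_dim); topk_indices/topk_weights: (num_tokens, top_k).
    """
    final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
    # expert_mask[e, t, k] == 1 iff expert e is token t's k-th routed choice
    expert_mask = nn.functional.one_hot(topk_indices, num_classes=len(experts)).permute(2, 0, 1)
    for expert_idx, expert in enumerate(experts):
        token_indices, weight_indices = torch.where(expert_mask[expert_idx])
        if token_indices.numel() == 0:
            continue  # no tokens routed to this expert
        expert_weights = topk_weights[token_indices, weight_indices]
        expert_output = expert(hidden_states[token_indices])
        weighted_output = expert_output * expert_weights.unsqueeze(-1)
        # scatter-add each expert's weighted output back into the owning token's row
        final_hidden_states.index_add_(0, token_indices, weighted_output)
    return final_hidden_states.type(hidden_states.dtype)


if __name__ == "__main__":
    num_tokens, hidden_dim, num_experts, top_k = 8, 16, 4, 2
    experts = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts))
    x = torch.randn(num_tokens, hidden_dim)
    topk_weights, topk_indices = torch.randn(num_tokens, num_experts).softmax(-1).topk(top_k, dim=-1)
    print(naive_moe_dispatch(x, topk_indices, topk_weights, experts).shape)  # torch.Size([8, 16])

The per-expert Python loop is the bottleneck the "CALL FOR CONTRIBUTION" docstring refers to: fusing the expert weights into batched matmuls would remove the loop without changing the routing math.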