50 changes: 24 additions & 26 deletions src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
@@ -289,37 +289,25 @@ def forward(self, hidden_states):
return topk_indices, topk_weights


# FIXME: refactor moe
class Glm4vMoeTextMoE(nn.Module):
"""
A mixed expert module containing shared experts.
"""

def __init__(self, config: Glm4vMoeTextConfig):
class Glm4vMoeTextNaiveMoe(nn.ModuleList):
def __init__(self, config):
super().__init__()
self.config = config
self.experts = nn.ModuleList(
[
Glm4vMoeTextMLP(config, intermediate_size=config.moe_intermediate_size)
for _ in range(config.n_routed_experts)
]
)
self.gate = Glm4vMoeTextTopkRouter(config)
self.shared_experts = Glm4vMoeTextMLP(
config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
)
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
for _ in range(self.num_experts):
self += [Glm4vMoeTextMLP(config, intermediate_size=config.moe_intermediate_size)]

def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
def forward(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
r"""
CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
to not have to do a loop here (deepseek has 256 experts soooo yeah).
"""
final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self))
expert_mask = expert_mask.permute(2, 0, 1)

for expert_idx in range(len(self.experts)):
expert = self.experts[expert_idx]
for expert_idx in range(len(self)):
expert = self[expert_idx]
mask = expert_mask[expert_idx]
token_indices, weight_indices = torch.where(mask)

@@ -335,23 +323,33 @@ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights):
# and all expert are "local" meaning we shard but we don't gather
return final_hidden_states.type(hidden_states.dtype)
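For readers tracing the dispatch above, a minimal standalone sketch of the one-hot expert mask it builds (sizes here are illustrative assumptions, not values from the Glm4vMoe config):

import torch

# 4 tokens, 3 experts, top-2 routing -- toy sizes for illustration only
topk_indices = torch.tensor([[0, 2], [1, 2], [0, 1], [2, 0]])  # (num_tokens, top_k)
num_experts = 3

# one_hot: (num_tokens, top_k, num_experts); permute: (num_experts, num_tokens, top_k)
expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=num_experts)
expert_mask = expert_mask.permute(2, 0, 1)

for expert_idx in range(num_experts):
    # token_indices: rows that routed to this expert; weight_indices: their top-k slot,
    # later used to pick the matching routing weight out of topk_weights
    token_indices, weight_indices = torch.where(expert_mask[expert_idx])
    print(expert_idx, token_indices.tolist(), weight_indices.tolist())
# expert 0 serves tokens [0, 2, 3], expert 1 serves [1, 2], expert 2 serves [0, 1, 3]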


class Glm4vMoeTextMoE(nn.Module):
def __init__(self, config: Glm4vMoeTextConfig):
        super().__init__()
self.config = config
self.experts = Glm4vMoeTextNaiveMoe(config)
self.gate = Glm4vMoeTextTopkRouter(config)
self.shared_experts = Glm4vMoeTextMLP(
config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
)

def forward(self, hidden_states):
residuals = hidden_states
orig_shape = hidden_states.shape
topk_indices, topk_weights = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
hidden_states = hidden_states + self.shared_experts(residuals)
return hidden_states
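As a usage-level illustration of the wrapper's flatten → route → dispatch → restore flow, here is a self-contained toy with the same structure (hypothetical sizes and a plain softmax-top-k router; a sketch only, not the Glm4vMoe gate, which additionally normalizes and scales its routing weights):

import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    """Toy mirror of the routed-experts + shared-expert structure (assumed sizes)."""

    def __init__(self, hidden=16, num_experts=4, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))
        self.router = nn.Linear(hidden, num_experts)  # stand-in for Glm4vMoeTextTopkRouter
        self.shared = nn.Linear(hidden, hidden)       # stand-in for the shared experts
        self.top_k = top_k

    def forward(self, x):
        residuals = x
        orig_shape = x.shape
        flat = x.view(-1, x.shape[-1])  # (batch * seq_len, hidden)
        topk_weights, topk_indices = self.router(flat).softmax(-1).topk(self.top_k, dim=-1)
        out = torch.zeros_like(flat)
        mask = torch.nn.functional.one_hot(topk_indices, len(self.experts)).permute(2, 0, 1)
        for expert_idx, expert in enumerate(self.experts):
            tok, slot = torch.where(mask[expert_idx])
            if tok.numel():
                out[tok] += expert(flat[tok]) * topk_weights[tok, slot, None]
        # restore (batch, seq_len, hidden) and add the shared-expert branch on the residuals
        return out.view(*orig_shape) + self.shared(residuals)

print(ToyMoE()(torch.randn(2, 8, 16)).shape)  # torch.Size([2, 8, 16])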


class Glm4vMoeTextMLP(nn.Module):
def __init__(self, config, hidden_size=None, intermediate_size=None):
def __init__(self, config, intermediate_size=None):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size

self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
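The hunk collapses before `Glm4vMoeTextMLP.forward`; for orientation, a sketch of the standard gated-MLP combination these three projections feed (the SiLU activation is an assumption based on the surrounding model family, not shown in this diff):

import torch.nn as nn

class GatedMLPSketch(nn.Module):
    """Hypothetical stand-in showing how gate_proj/up_proj/down_proj typically combine."""

    def __init__(self, hidden_size=32, intermediate_size=64):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()  # assumed; the real class takes it from the config

    def forward(self, x):
        # the gate branch modulates the up branch, then project back down
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))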
25 changes: 17 additions & 8 deletions src/transformers/models/glm4v_moe/modular_glm4v_moe.py
@@ -17,6 +17,7 @@
import torch
import torch.nn as nn

from ..deepseek_v3.modeling_deepseek_v3 import DeepseekV3NaiveMoe
from ...cache_utils import Cache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -408,22 +409,30 @@ def __init__(self, config: Glm4vMoeTextConfig):
super().__init__(config)


# FIXME: update the expert class
class Glm4vMoeTextMoE(Glm4MoeMoE):
class Glm4vMoeTextNaiveMoe(DeepseekV3NaiveMoe):
pass


class Glm4vMoeTextMoE(nn.Module):
def __init__(self, config: Glm4vMoeTextConfig):
        super().__init__()
self.config = config
self.experts = nn.ModuleList(
[
Glm4vMoeTextMLP(config, intermediate_size=config.moe_intermediate_size)
for _ in range(config.n_routed_experts)
]
)
self.experts = Glm4vMoeTextNaiveMoe(config)
self.gate = Glm4vMoeTextTopkRouter(config)
self.shared_experts = Glm4vMoeTextMLP(
config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
)

def forward(self, hidden_states):
residuals = hidden_states
orig_shape = hidden_states.shape
topk_indices, topk_weights = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
hidden_states = hidden_states + self.shared_experts(residuals)
return hidden_states



class Glm4vMoeTextMLP(Glm4MoeMLP):
pass
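A note on the `pass`-body subclasses above: in modular-transformers files they mean "reuse the parent implementation verbatim under this model's name", which the code generator then expands into the modeling file (as seen for `Glm4vMoeTextNaiveMoe` in the first file of this diff). At runtime the effect is ordinary inheritance, as in this toy sketch:

# Toy classes only -- illustrating why `class Child(Parent): pass` changes nothing but the name.
class ParentNaiveMoe:
    def forward(self, x):
        return x  # stand-in for the real expert loop

class ChildNaiveMoe(ParentNaiveMoe):
    pass  # inherits forward unchanged

assert ChildNaiveMoe().forward(41) == 41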