
Qwen2_moe: Avoid zero-token forwarding for some experts #32283

@Coco58323

Description

System Info

transformers=4.43.3
python=3.8
Linux

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Under Auto-GPTQ with the Triton kernel, math.log2() is called at Line 96. Meanwhile, the MoE implementation runs the following loop:

for expert_idx in range(self.num_experts):
    expert_layer = self.experts[expert_idx]
    idx, top_x = torch.where(expert_mask[expert_idx])

    # Index the correct hidden states and compute the expert hidden state for
    # the current expert. We need to make sure to multiply the output hidden
    # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
    current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
    current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

With top-k routing, some experts receive zero tokens in a batch, so the expert is called with an empty input (shape [0, hidden_dim]) and the forward pass fails on log2().
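
As a minimal illustration of the failure mode (this is not the actual Auto-GPTQ Triton launcher; hypothetical_kernel_launch is made up for demonstration), a launcher that derives a block size from math.log2() of the batch dimension raises a ValueError as soon as an expert receives an empty batch:

import math
import torch

def hypothetical_kernel_launch(x: torch.Tensor) -> None:
    # Kernel launchers commonly derive grid/block sizes from the batch size;
    # math.log2(0) raises "ValueError: math domain error".
    batch_size = x.shape[0]
    math.log2(batch_size)

hypothetical_kernel_launch(torch.randn(4, 2048))  # works
hypothetical_kernel_launch(torch.empty(0, 2048))  # fails: empty expert batch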

The issue could be solved by checking the number of tokens before Line 675:

if current_state.shape[0] == 0:
    continue
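
For reference, here is a sketch of how that guard could sit in the loop quoted above (variable names follow the Qwen2MoE code; this is not the exact upstream patch):

for expert_idx in range(self.num_experts):
    expert_layer = self.experts[expert_idx]
    idx, top_x = torch.where(expert_mask[expert_idx])
    current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)

    # Skip experts that received no tokens in this batch so the quantized
    # expert layer is never called with a zero-sized batch dimension.
    if current_state.shape[0] == 0:
        continue

    current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]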

Expected behavior

Experts that are routed zero tokens should be skipped instead of being called with an empty input.
