
[Doc] MoE routing capture and replay recipe #44925

Merged
kashif merged 9 commits into huggingface:main from kashif:moe-selected-expert-indices
Apr 14, 2026
Conversation

@kashif
Contributor

@kashif kashif commented Mar 22, 2026

What does this PR do?

This PR adds a first-class MoE routing capture/replay API for Qwen2Moe and introduces shared MoE routing helpers for reuse by other MoE model families.

It adds:

  • a structured MoERouting payload in modeling outputs
  • output_moe_routing=True to return exact selected experts from the real forward path
  • moe_routing=... to replay those expert choices on a later forward
  • shared helper utilities in src/transformers/integrations/moe_routing.py
  • focused tests covering capture, replay, and fail-closed validation

The replay semantics follow the minimal router-replay contract used in systems like Megatron: the model reuses the recorded expert indices while recomputing current routing scores for those forced experts.

This PR uses Qwen2Moe as the prototype implementation. It does not add transport/runtime-specific logic such as DeepEP integration yet.

Fixes #42638

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@kashif kashif force-pushed the moe-selected-expert-indices branch from b0cda06 to a862843 on March 22, 2026 at 14:09
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Rocketknight1
Member

@Dlove123 blocked for extremely obvious reasons. Also sent you a $100 Paypal invoice for the time wasted in dealing with this.

@casinca
Contributor

casinca commented Mar 24, 2026

Linking this PR to Quentin's issue (#42638) seems appropriate.

@huggingface huggingface deleted a comment from Dlove123 Mar 24, 2026
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen2_moe

@ArthurZucker
Collaborator

@ArthurZucker ArthurZucker left a comment

  from __future__ import annotations
  from contextlib import contextmanager

  import torch
  from transformers import Qwen2MoeConfig, Qwen2MoeForCausalLM
  from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeTopKRouter
  from transformers.monkey_patching import (
      apply_patches, clear_patch_mapping, register_patch_mapping,
  )
  from transformers.utils.generic import _CAN_RECORD_REGISTRY, OutputRecorder

  # ---  REPLAY: router subclass reading forced indices off self
  class ReplayableQwen2MoeTopKRouter(Qwen2MoeTopKRouter):
      _forced_indices: torch.Tensor | None = None

      def forward(self, hidden_states):
          hidden_states = hidden_states.reshape(-1, self.hidden_dim)
          logits = torch.nn.functional.linear(hidden_states, self.weight)
          logits = torch.nn.functional.softmax(logits, dtype=torch.float, dim=-1)

          forced = self._forced_indices
          if forced is not None:
              idx = forced
              val = logits.gather(-1, forced)           # Megatron-style: forced experts,
          else:                                         # fresh scores from current logits
              val, idx = torch.topk(logits, self.top_k, dim=-1)

          if self.norm_topk_prob:
              val = val / val.sum(dim=-1, keepdim=True)
          return logits, val.to(logits.dtype), idx


  @contextmanager
  def replay_moe_routing(model, selected_experts_per_layer):
      routers = [m for m in model.modules() if isinstance(m, ReplayableQwen2MoeTopKRouter)]
      assert len(routers) == len(selected_experts_per_layer)
      for r, t in zip(routers, selected_experts_per_layer):
          r._forced_indices = t
      try:
          yield
      finally:
          for r in routers:
              r._forced_indices = None

and a test:

      torch.manual_seed(0)
      cfg = Qwen2MoeConfig()
      register_patch_mapping({"Qwen2MoeTopKRouter": ReplayableQwen2MoeTopKRouter})
      with apply_patches():
          model = Qwen2MoeForCausalLM(cfg).eval()

      _CAN_RECORD_REGISTRY[model.model.__class__]["selected_experts"] = OutputRecorder(Qwen2MoeTopKRouter, index=2)

      input_ids = torch.randint(0, cfg.vocab_size, (1, 8))

      with torch.no_grad():
          captured = model.model(input_ids=input_ids, output_selected_experts=True)
      selected_experts = captured["selected_experts"]
      ref = captured.last_hidden_state.clone()

      with torch.no_grad(), replay_moe_routing(model.model, list(selected_experts)):
          replayed = model.model(input_ids=input_ids)

      diff = (ref - replayed.last_hidden_state).abs().max().item()
      print(f"captured {len(selected_experts)} layers, replay max |Δhidden| = {diff:.3e}")
      assert diff < 1e-5

Verified output:
captured 2 layers of selected_experts
layer 0: shape=(8, 2) dtype=torch.int64
layer 1: shape=(8, 2) dtype=torch.int64
eager replay max |Δhidden| = 0.000e+00
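The follow-up commit also mentions a one-line conversion from vLLM's `CompletionOutput.routed_experts` numpy array to the per-layer tuple that `replay_moe_routing` expects. A minimal sketch, assuming a `(num_layers, num_tokens, top_k)` layout (that layout, and the random stand-in array, are assumptions here, not vLLM's documented contract):

```python
import numpy as np
import torch

# Stand-in for vLLM's CompletionOutput.routed_experts: assumed layout is
# (num_layers, num_tokens, top_k) with integer expert ids.
routed_experts = np.random.randint(0, 8, size=(2, 8, 2))

# Per-layer tuple of int64 tensors, one (num_tokens, top_k) tensor per layer,
# matching what the replay context manager above consumes.
selected_experts = tuple(torch.from_numpy(layer).long() for layer in routed_experts)
```

This lets an RLHF loop generate with vLLM, then replay the recorded expert choices during the transformers training forward.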

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=44925&sha=200b25

@kashif kashif changed the title from "[MOE] MoE routing capture and replay support" to "[Doc] MoE routing capture and replay support" Apr 13, 2026
@kashif kashif changed the title from "[Doc] MoE routing capture and replay support" to "[Doc] MoE routing capture and replay recipe" Apr 13, 2026
@kashif kashif added this pull request to the merge queue Apr 14, 2026
Merged via the queue into huggingface:main with commit def5e68 Apr 14, 2026
16 checks passed
@kashif kashif deleted the moe-selected-expert-indices branch April 14, 2026 08:09
sirzechs66 pushed a commit to sirzechs66/transformers that referenced this pull request Apr 18, 2026
* inital protoype

* remove unneeded selected_experts

* Revert MoE expert replay; document pattern via monkey patching

Replace the intrusive record/replay implementation across modeling files
with a documentation-only guide. All three pieces — the replayable router
subclass, the replay context manager, and the runtime OutputRecorder
registration — can be built on top of the existing monkey_patching and
output_capturing APIs without touching core MoE modeling code.

Also shows the one-line conversion from vLLM's CompletionOutput.routed_experts
numpy array to the per-layer tuple this pattern expects, enabling RLHF
workflows that generate with vLLM and train with transformers.

* Preserve unrelated forward-progress in __init__.py and generic.py

The previous revert commit accidentally rolled back unrelated work on
these two files — version bump, TorchvisionBackend addition, and
module-alias refactor. Restore those while keeping the MoE-specific
additions (MoERouting export, output_moe_routing kwarg) removed.


Development

Successfully merging this pull request may close these issues.

Routing Replay for MoEs

5 participants