[Doc] MoE routing capture and replay recipe #44925
Conversation
Force-pushed b0cda06 to a862843.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@Dlove123 blocked for extremely obvious reasons. Also sent you a $100 Paypal invoice for the time wasted in dealing with this. |
I'm linking this PR to Quentin's issue, #42638; that seems appropriate.
[For maintainers] Suggested jobs to run (before merge) run-slow: qwen2_moe |
ArthurZucker left a comment
```python
from __future__ import annotations

from contextlib import contextmanager

import torch

from transformers import Qwen2MoeConfig, Qwen2MoeForCausalLM
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeTopKRouter
from transformers.monkey_patching import apply_patches, register_patch_mapping
from transformers.utils.generic import _CAN_RECORD_REGISTRY, OutputRecorder


# --- REPLAY: router subclass reading forced indices off self ---
class ReplayableQwen2MoeTopKRouter(Qwen2MoeTopKRouter):
    _forced_indices: torch.Tensor | None = None

    def forward(self, hidden_states):
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        logits = torch.nn.functional.linear(hidden_states, self.weight)
        logits = torch.nn.functional.softmax(logits, dtype=torch.float, dim=-1)
        forced = self._forced_indices
        if forced is not None:
            # Megatron-style replay: reuse the forced expert indices but
            # score them with fresh probabilities from the current logits.
            idx = forced
            val = logits.gather(-1, forced)
        else:
            val, idx = torch.topk(logits, self.top_k, dim=-1)
        if self.norm_topk_prob:
            val = val / val.sum(dim=-1, keepdim=True)
        return logits, val.to(logits.dtype), idx


@contextmanager
def replay_moe_routing(model, selected_experts_per_layer):
    routers = [m for m in model.modules() if isinstance(m, ReplayableQwen2MoeTopKRouter)]
    assert len(routers) == len(selected_experts_per_layer)
    for router, indices in zip(routers, selected_experts_per_layer):
        router._forced_indices = indices
    try:
        yield
    finally:
        # Always clear the forced indices, even if the forward pass raised.
        for router in routers:
            router._forced_indices = None
```
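Stashing the forced indices on the router instance, rather than threading them through forward kwargs, is what keeps the pattern non-intrusive: the router's call signature and the surrounding MoE modeling code stay untouched.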
and a test:

```python
torch.manual_seed(0)
cfg = Qwen2MoeConfig()

# Build the model with the replayable router swapped in for the stock one.
register_patch_mapping({"Qwen2MoeTopKRouter": ReplayableQwen2MoeTopKRouter})
with apply_patches():
    model = Qwen2MoeForCausalLM(cfg).eval()

# CAPTURE: record the router's third output (index=2), the selected expert indices.
_CAN_RECORD_REGISTRY[model.model.__class__]["selected_experts"] = OutputRecorder(
    Qwen2MoeTopKRouter, index=2
)

input_ids = torch.randint(0, cfg.vocab_size, (1, 8))
with torch.no_grad():
    captured = model.model(input_ids=input_ids, output_selected_experts=True)
selected_experts = captured["selected_experts"]
ref = captured.last_hidden_state.clone()

# REPLAY: force the recorded expert choices on a second forward pass.
with torch.no_grad(), replay_moe_routing(model.model, list(selected_experts)):
    replayed = model.model(input_ids=input_ids)

diff = (ref - replayed.last_hidden_state).abs().max().item()
print(f"captured {len(selected_experts)} layers, replay max |Δhidden| = {diff:.3e}")
assert diff < 1e-5
```

Verified output:

```
captured 2 layers of selected_experts
layer 0: shape=(8, 2) dtype=torch.int64
layer 1: shape=(8, 2) dtype=torch.int64
eager replay max |Δhidden| = 0.000e+00
```
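An exact zero is expected here: with the same inputs and the same forced indices, the replayed forward recomputes the same scores for the same experts, so the hidden states match exactly in eager mode.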
Replace the intrusive record/replay implementation across modeling files with a documentation-only guide. All three pieces — the replayable router subclass, the replay context manager, and the runtime OutputRecorder registration — can be built on top of the existing monkey_patching and output_capturing APIs without touching core MoE modeling code. Also shows the one-line conversion from vLLM's CompletionOutput.routed_experts numpy array to the per-layer tuple this pattern expects, enabling RLHF workflows that generate with vLLM and train with transformers.
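For the vLLM half of that workflow, the conversion could look like the sketch below. It assumes `CompletionOutput.routed_experts` stacks layers along the first axis as an integer array of shape `(num_layers, num_tokens, top_k)`; that layout is an assumption here, not a documented vLLM contract.

```python
import numpy as np
import torch

def routed_experts_to_per_layer(routed_experts: np.ndarray) -> tuple[torch.Tensor, ...]:
    # One int64 tensor of shape (num_tokens, top_k) per MoE layer, i.e. the
    # selected_experts_per_layer format that replay_moe_routing expects.
    # Assumes layer-major layout: routed_experts[i] are layer i's indices.
    return tuple(torch.from_numpy(np.ascontiguousarray(layer)).long() for layer in routed_experts)
```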
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=44925&sha=200b25
The previous revert commit accidentally rolled back unrelated work on these two files — version bump, TorchvisionBackend addition, and module-alias refactor. Restore those while keeping the MoE-specific additions (MoERouting export, output_moe_routing kwarg) removed.
Commits:
* inital protoype
* remove unneeded selected_experts
* Revert MoE expert replay; document pattern via monkey patching
* Preserve unrelated forward-progress in __init__.py and generic.py
What does this PR do?
This PR adds a first-class MoE routing capture/replay API for Qwen2Moe and introduces shared MoE routing helpers for reuse by other MoE model families.
It adds:
- a `MoERouting` payload in modeling outputs
- `output_moe_routing=True` to return exact selected experts from the real forward path
- `moe_routing=...` to replay those expert choices on a later forward
- shared helpers in `src/transformers/integrations/moe_routing.py`

The replay semantics follow the minimal router-replay contract used in systems like Megatron: the model reuses the recorded expert indices while recomputing current routing scores for those forced experts.
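For orientation, here is a minimal sketch of how the proposed surface would be used. The kwarg names (`output_moe_routing`, `moe_routing`) follow the PR description above, and the `moe_routing` output attribute is assumed; treat all of them as provisional, not final API.

```python
import torch
from transformers import Qwen2MoeConfig, Qwen2MoeForCausalLM

model = Qwen2MoeForCausalLM(Qwen2MoeConfig()).eval()
input_ids = torch.randint(0, model.config.vocab_size, (1, 8))

with torch.no_grad():
    # Capture: the MoERouting payload holds the exact per-layer expert choices
    # taken on the real forward path.
    out = model(input_ids=input_ids, output_moe_routing=True)
    # Replay: reuse those expert indices on a later forward; scores are
    # recomputed from the new logits (the Megatron-style contract above).
    replayed = model(input_ids=input_ids, moe_routing=out.moe_routing)
```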
This PR uses Qwen2Moe as the prototype implementation. It does not add transport/runtime-specific logic such as DeepEP integration yet.
Fixes #42638
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.