[WIP] feat: add router replay for megatron engine #1207
TaoZex wants to merge 89 commits into inclusionAI:main
Conversation
Code Review
This pull request implements Router Replay (R3) to align Mixture-of-Experts (MoE) routing decisions between rollout inference and training, preventing performance degradation caused by weight staleness in RL. The changes include monkey-patches for Megatron-Core components, engine-level wrappers for micro-batch scheduling, and workflow integrations to propagate routing indices from SGLang. Feedback focuses on critical architectural issues regarding global state and thread safety, specifically the risks of patching class-level iterators and using global lists for router instances. Additionally, there are recommendations to fix potential data loss in uneven batch splitting and to optimize performance by removing GPU-CPU synchronization points in the data processing pipeline.
Metric Comparison: Router Replay (R3)

Using the moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml configuration, compare the metric results with router replay (R3) enabled versus disabled:
[figure: metric curves with R3 enabled vs. disabled]
Description
This PR implements Rollout Routing Replay (R3) for MoE models, addressing training instability caused by inference-training routing discrepancy in asynchronous RL training. R3 records expert routing indices from the inference engine and replays them during training, ensuring consistent expert selection regardless of weight staleness.
Key Changes
Core MoE Patch (router_replay_patch.py):

- `RouterReplay` class (one per MoE layer) with `RECORD`/`REPLAY_FORWARD`/`REPLAY_BACKWARD` actions
- `patched_routing`: replaces `TopKRouter.routing`; uses `scores.gather(1, target_topk_idx)` in replay mode instead of `torch.topk`, preserving gradient flow (see the sketch after this list)
- Monkey-patched entry points: `TransformerConfig.__init__`, `TopKRouter.__init__`, `TopKRouter.routing`, `MoEAlltoAllTokenDispatcher.preprocess`
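A minimal sketch of the replay-mode selection, assuming `scores` is the `[num_tokens, num_experts]` router probability matrix and `target_topk_idx` holds the indices recorded at rollout time (function name and shapes are illustrative; the real patch lives in `router_replay_patch.py`):

```python
import torch

def replay_topk(scores: torch.Tensor, target_topk_idx: torch.Tensor):
    """Replay-mode expert selection: pick recorded indices instead of torch.topk.

    scores:          [num_tokens, num_experts] router probabilities (grad-tracked)
    target_topk_idx: [num_tokens, topk] int64 expert ids recorded during rollout
    """
    # gather keeps the autograd graph through `scores`, so the router weights
    # still receive gradients even though the expert selection itself is fixed.
    probs = scores.gather(1, target_topk_idx)
    return probs, target_topk_idx
```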
Data Distribution (router_replay_utils.py):

- `set_router_replay_data`: 4-step pipeline: right-pad→left-align → TP/SP scatter → PP layer slice → Dense/MoE mapping (first step sketched below)
- `RouterReplayHelper`: locates `RouterReplay` instances by `(pp_rank, vp_stage)`
- `get_num_layers_to_build`, `get_moe_num_layers_to_build` (PP/VP aware)
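A hedged sketch of the first pipeline step, assuming the routing indices arrive left-padded as `[batch, seq, topk]` with known per-row valid lengths (names and layout are assumptions, not the exact `set_router_replay_data` signature):

```python
import torch

def left_align(routed: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Move each row's valid tokens to the front so padding sits on the right.

    routed:  [batch, seq, topk] left-padded routing indices
    lengths: [batch] number of valid tokens per row
    """
    out = torch.zeros_like(routed)
    seq = routed.size(1)
    for i, n in enumerate(lengths.tolist()):
        out[i, :n] = routed[i, seq - n:]  # valid tokens were right-aligned (left-padded)
    return out
```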
MegatronEngine Integration (megatron_engine_r3_patch.py):

- `forward_backward_batch`: retrieves `routed_experts` via side-channel, splits per micro-batch, injects replay setup via per-instance class swap, toggles forward/backward replay mode, cleans up in `finally` (control flow sketched below)
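An illustrative-only sketch of that control flow (the wrapper, `run_micro_batch`, and the batch layout are hypothetical; only the side-channel attribute name comes from this PR):

```python
def forward_backward_with_replay(engine, micro_batches, run_micro_batch):
    """Fetch routed_experts from the side-channel, split per micro-batch, clean up."""
    routed = getattr(engine, "_r3_pending_routed_experts", None)
    if routed is None:                       # R3 disabled: unchanged fast path
        return [run_micro_batch(mb, None) for mb in micro_batches]
    sizes = [mb["input_ids"].size(0) for mb in micro_batches]
    chunks = routed.split(sizes, dim=0)      # assumed layout: [batch, seq, layers, topk]
    try:
        return [run_micro_batch(mb, c) for mb, c in zip(micro_batches, chunks)]
    finally:
        engine._r3_pending_routed_experts = None  # never leak stale routing indices
```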
Actor & Workflow Integration (actor_r3_patch.py, rlvr_r3_patch.py):

- Slices `routed_experts` per mini-batch and delivers it via the engine side-channel (bypasses `pack_tensor_dict` 4D incompatibility)
- `resolve_r3_moe_config` auto-resolves `num_moe_layers`/`topk` from the HF config; `extract_routed_experts` converts SGLang numpy output to a left-padded torch tensor (sketched below)
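A hedged sketch of the numpy→tensor conversion, assuming SGLang returns one array of shape `[num_tokens, num_moe_layers, topk]` per sequence (the real `extract_routed_experts` may differ):

```python
import numpy as np
import torch

def extract_routed_experts(per_seq: list, max_len: int) -> torch.Tensor:
    """per_seq: one numpy array per sequence, shape [num_tokens, num_moe_layers, topk]."""
    layers, topk = per_seq[0].shape[1], per_seq[0].shape[2]
    out = torch.zeros(len(per_seq), max_len, layers, topk, dtype=torch.long)
    for i, arr in enumerate(per_seq):
        n = arr.shape[0]
        out[i, max_len - n:] = torch.from_numpy(np.ascontiguousarray(arr, dtype=np.int64))
    return out  # left-padded: valid tokens occupy the rightmost n positions
```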
SGLang Integration (sglang_r3_patch.py, sglang_remote.py):

- Serializes `routed_experts` as base64 in `TokenizerManager._handle_batch_output` (fixes `jsonable_encoder` silently flattening `torch.Tensor` to `{}` when `skip_tokenizer_init=True`); a round-trip sketch follows this list
- Validates `num_sgl_token` divisibility
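One plausible encoding for that workaround (the exact wire format used by `sglang_r3_patch` is an assumption; the point is that bytes survive `jsonable_encoder` where a raw `torch.Tensor` does not):

```python
import base64
import io
import torch

def encode_routed_experts(t: torch.Tensor) -> str:
    # Serialize the tensor to bytes, then base64 so it travels safely in JSON.
    buf = io.BytesIO()
    torch.save(t.cpu(), buf)
    return base64.b64encode(buf.getvalue()).decode("ascii")

def decode_routed_experts(s: str) -> torch.Tensor:
    return torch.load(io.BytesIO(base64.b64decode(s)))
```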
Orchestrator & Config (rl_trainer.py, cli_args.py):

- `return_routed_experts=True` → auto-sets `enable_router_replay`, resolves MoE config, forces `skip_tokenizer_init=True`, validates SGLang-only support (see the sketch below)

Supported Parallelism
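A hypothetical sketch of that auto-wiring (argument and HF-config attribute names such as `rollout_backend` and `num_experts_per_tok` are assumptions about this codebase):

```python
def apply_r3_config(args, hf_config):
    """Auto-wire R3 settings when the user opts in via return_routed_experts."""
    if not args.return_routed_experts:
        return                                  # R3 fully inactive by default
    if args.rollout_backend != "sglang":
        raise ValueError("return_routed_experts currently requires the SGLang backend")
    args.enable_router_replay = True
    args.skip_tokenizer_init = True             # needed for the base64 side-channel
    # Simplification: the real resolve_r3_moe_config also handles dense/MoE layer mixes.
    args.num_moe_layers = hf_config.num_hidden_layers
    args.moe_topk = hf_config.num_experts_per_tok
```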
- TP/SP: `scatter_to_sequence_parallel_region` + `seq_align_to` by `tp_size`
- PP: `get_current_rank_layer_info` slices per PP rank's MoE layers
- VPP: `vp_stage` in `RouterReplayHelper`
- CP: `seq_align_to = tp_size * cp_size * 2` when `cp_size > 1` (alignment rule sketched after this list)

New Metrics
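The alignment rule above, written out as a small sketch (helper names are hypothetical; the `2 * cp_size` factor matches Megatron-style context parallelism, which shards each sequence into `2 * cp_size` chunks for load balancing):

```python
def resolve_seq_align_to(tp_size: int, cp_size: int) -> int:
    # CP shards each sequence into 2*cp_size chunks, hence the extra factor of 2.
    return tp_size * cp_size * 2 if cp_size > 1 else tp_size

def padded_seq_len(seq_len: int, align: int) -> int:
    return (seq_len + align - 1) // align * align  # round up to the alignment
```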
- `rollout_train_logprobs_abs_diff_mean`: Mean absolute difference between rollout and training log-probs over response tokens. Reflects routing-inconsistency-induced policy deviation; R3 should reduce this to only weight-update drift.
- `rollout_train_logprobs_abs_diff_std`: Standard deviation of the above differences. High values indicate extreme outliers from completely inconsistent routing, the primary cause of PPO training collapse via disproportionate gradients.
Both computed in `torch.no_grad()` on detached tensors; negligible overhead (see the sketch below).
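A minimal sketch of how the two metrics can be computed (tensor names are assumed; only the metric definitions come from this PR):

```python
import torch

@torch.no_grad()
def logprob_diff_metrics(rollout_logprobs, train_logprobs, response_mask):
    diff = (rollout_logprobs.detach() - train_logprobs.detach()).abs()
    diff = diff[response_mask.bool()]           # restrict to response tokens
    return {
        "rollout_train_logprobs_abs_diff_mean": diff.mean().item(),
        "rollout_train_logprobs_abs_diff_std": diff.std().item(),
    }
```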
Related Paper

Stabilizing MoE Reinforcement Learning by Aligning Training and Inference Routers (Ma et al., arXiv:2510.11370, 2025) proposes R3 to reduce training-inference policy KL divergence and prevent MoE RL training collapse.
Related Issue
Fixes #(issue)
Type of Change
Checklist
- `pre-commit run --all-files`
- `./docs/build_all.sh`
- `main` / `/review-pr` command / `/create-pr`

Breaking Change Details (if applicable):
N/A
Additional Context
- `return_routed_experts=False` (default) → all R3 code inactive, zero overhead
- Only SGLang rollout supports `return_routed_experts`; config validation raises an explicit error otherwise
- `routed_experts` delivered via `engine._r3_pending_routed_experts` to bypass `pack_tensor_dict` 4D incompatibility
- `sglang_r3_patch` must be installed on the inference server to fix `torch.Tensor` serialization when `skip_tokenizer_init=True`