
[WIP] feat: add router replay for megatron engine #1207

Draft

TaoZex wants to merge 89 commits into inclusionAI:main from TaoZex:final_moe

Conversation


TaoZex (Collaborator) commented Apr 18, 2026

Description

This PR implements Rollout Routing Replay (R3) for MoE models, addressing training instability caused by inference-training routing discrepancy in asynchronous RL training. R3 records expert routing indices from the inference engine and replays them during training, ensuring consistent expert selection regardless of weight staleness.

Key Changes

Core MoE Patch (router_replay_patch.py):

  • RouterReplay class (one per MoE layer) with RECORD/REPLAY_FORWARD/REPLAY_BACKWARD actions
  • patched_routing: replaces TopKRouter.routing; in replay mode it selects experts with scores.gather(1, target_topk_idx) instead of torch.topk, preserving gradient flow through the router scores (see the sketch after this list)
  • Four monkey-patches: TransformerConfig.__init__, TopKRouter.__init__, TopKRouter.routing, MoEAlltoAllTokenDispatcher.preprocess
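
A minimal sketch of that replay-mode selection, assuming scores is a [num_tokens, num_experts] probability matrix and target_topk_idx holds the recorded expert ids (shapes and names beyond those in the PR are assumptions):

```python
import torch

def replay_topk(scores: torch.Tensor, target_topk_idx: torch.Tensor):
    # scores: [num_tokens, num_experts], grad-tracked router probabilities
    # target_topk_idx: [num_tokens, topk] expert ids recorded at rollout time
    probs = scores.gather(1, target_topk_idx)  # differentiable w.r.t. scores
    return probs, target_topk_idx  # same contract as torch.topk(scores, k)
```

Because gather only indexes, gradients still reach the router weights through the selected entries of scores, while the expert choice itself is pinned to what the inference engine did.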

Data Distribution (router_replay_utils.py):

  • set_router_replay_data: a four-step pipeline: re-pad the left-padded input to a left-aligned (right-padded) layout → TP/SP scatter → PP layer slice → dense/MoE layer mapping
  • RouterReplayHelper: locates RouterReplay instances by (pp_rank, vp_stage)
  • Layer allocation helpers: get_num_layers_to_build, get_moe_num_layers_to_build (PP/VP aware)
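
The TP/SP scatter requires the sequence length to divide evenly across ranks. A sketch of the alignment step, assuming routed_experts is laid out as [batch, seq, topk] (the seq_align_to name and multipliers come from the parallelism table below; the padding mechanics are an assumption):

```python
import torch
import torch.nn.functional as F

def seq_align_to(routed_experts: torch.Tensor, multiple: int, pad_value: int = -1):
    # Right-pad the sequence dim to a multiple of `multiple` (tp_size, or
    # tp_size * cp_size * 2 when cp_size > 1) so the scatter divides evenly.
    seq_len = routed_experts.shape[1]
    pad = (-seq_len) % multiple
    if pad:
        # F.pad tuple: (topk-left, topk-right, seq-left, seq-right)
        routed_experts = F.pad(routed_experts, (0, 0, 0, pad), value=pad_value)
    return routed_experts
```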

MegatronEngine Integration (megatron_engine_r3_patch.py):

  • Wraps forward_backward_batch: retrieves routed_experts via a side-channel, splits them per micro-batch, injects the replay setup via a per-instance class swap, toggles forward/backward replay modes, and cleans up in a finally block (sketched below)
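
A sketch of that control flow; set_router_replay_data is the PR's helper, while original_forward_backward_batch and clear_replay_state are hypothetical stand-ins for the wrapped function and the cleanup:

```python
def forward_backward_batch_r3(engine, *args, **kwargs):
    routed = getattr(engine, "_r3_pending_routed_experts", None)  # side-channel
    try:
        if routed is not None:
            # Split per micro-batch and distribute to each layer's RouterReplay
            # before the scheduled forward/backward passes run.
            set_router_replay_data(routed)
        return original_forward_backward_batch(engine, *args, **kwargs)
    finally:
        engine._r3_pending_routed_experts = None  # always drop the side-channel
        clear_replay_state()  # hypothetical helper: reset replay mode/buffers
```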

Actor & Workflow Integration (actor_r3_patch.py, rlvr_r3_patch.py):

  • Actor: splits routed_experts per mini-batch and delivers them via the engine side-channel (bypassing pack_tensor_dict's incompatibility with 4-D tensors)
  • Workflow: resolve_r3_moe_config auto-resolves num_moe_layers/topk from the HF config; extract_routed_experts converts SGLang's numpy output into a left-padded torch tensor (see the sketch below)
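
A sketch of that conversion, assuming each sequence contributes a [seq_len, topk] numpy array (the per-sequence shapes and pad value are assumptions; the left-padding itself is from the PR):

```python
import numpy as np
import torch
import torch.nn.functional as F

def extract_routed_experts(per_seq_arrays, max_len: int, pad_value: int = -1):
    rows = []
    for arr in per_seq_arrays:  # each: [seq_i, topk] numpy array from SGLang
        t = torch.from_numpy(np.ascontiguousarray(arr))
        # Left-pad the sequence dim so responses stay right-aligned.
        t = F.pad(t, (0, 0, max_len - t.shape[0], 0), value=pad_value)
        rows.append(t)
    return torch.stack(rows)  # [batch, max_len, topk]
```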

SGLang Integration (sglang_r3_patch.py, sglang_remote.py):

  • Server patch: pre-encodes routed_experts as base64 in TokenizerManager._handle_batch_output (fixes jsonable_encoder silently flattening torch.Tensor to {} when skip_tokenizer_init=True)
  • Client: decodes the base64 payload and validates that its size is divisible by num_sgl_token
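
A sketch of the round-trip (the "data"/"dtype"/"shape" field names and the exact form of the divisibility check are assumptions; the motivation is the jsonable_encoder behavior described above):

```python
import base64
import numpy as np

def encode_routed_experts(routed) -> dict:
    arr = routed.cpu().numpy()
    # Raw bytes survive JSON encoding; a bare torch.Tensor would be
    # silently flattened to {} by jsonable_encoder.
    return {"data": base64.b64encode(arr.tobytes()).decode("ascii"),
            "dtype": str(arr.dtype), "shape": list(arr.shape)}

def decode_routed_experts(payload: dict, num_sgl_token: int) -> np.ndarray:
    flat = np.frombuffer(base64.b64decode(payload["data"]), dtype=payload["dtype"])
    if flat.size % num_sgl_token != 0:
        raise ValueError("routed_experts size is not divisible by num_sgl_token")
    return flat.reshape(payload["shape"])
```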

Orchestrator & Config (rl_trainer.py, cli_args.py):

  • return_routed_experts=True → auto-sets enable_router_replay, resolves the MoE config, forces skip_tokenizer_init=True, and validates SGLang-only support (sketched below)
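
A sketch of that auto-wiring; the attribute names follow the PR text, but the config object and its backend field are hypothetical:

```python
def resolve_r3_config(cfg):
    if not cfg.return_routed_experts:
        return  # default path: all R3 code stays inactive, zero overhead
    if cfg.inference_backend != "sglang":  # hypothetical field name
        raise ValueError("return_routed_experts requires the SGLang backend")
    cfg.enable_router_replay = True
    cfg.skip_tokenizer_init = True  # required by the server-side tensor patch
    # num_moe_layers / topk would be resolved from the HF config here
```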

Supported Parallelism

| Dimension | Supported | Mechanism |
| --- | --- | --- |
| TP | ✅ | scatter_to_sequence_parallel_region + seq_align_to by tp_size |
| PP | ✅ | get_current_rank_layer_info slices each PP rank's MoE layers |
| VP | ✅ | Cumulative offset by vp_stage in RouterReplayHelper |
| CP | ✅ | seq_align_to = tp_size * cp_size * 2 when cp_size > 1 |
| DP | ✅ | Data flows with mini-batches; no conflict |
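
For PP, a simplified sketch of the layer slicing under an even split (the real get_current_rank_layer_info must also account for VP stages and uneven splits):

```python
def pp_layer_slice(num_layers: int, pp_size: int, pp_rank: int):
    # Global [start, end) layer indices owned by this PP rank; routed_experts
    # is cut along its layer dimension to this window before replay.
    per_rank = num_layers // pp_size
    start = pp_rank * per_rank
    return start, start + per_rank
```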

New Metrics

rollout_train_logprobs_abs_diff_mean:

mean = (1/|M|) * Σ_{i∈M} |log π_rollout(a_i|s_i) - log π_train(a_i|s_i)|,  M = {i: loss_mask[i]=1}

Mean absolute difference between rollout and training log-probs over response tokens. Reflects routing-inconsistency-induced policy deviation; R3 should reduce this to only weight-update drift.

rollout_train_logprobs_abs_diff_std:

std = sqrt((1/(|M|-1)) * Σ_{i∈M} (|log π_rollout(a_i|s_i) - log π_train(a_i|s_i)| - mean)²)

Standard deviation of the above differences. High values indicate extreme outliers from completely inconsistent routing — the primary cause of PPO training collapse via disproportionate gradients.

Both computed in torch.no_grad() on detached tensors; negligible overhead.
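
A sketch of the computation matching the two formulas (tensor names are assumptions):

```python
import torch

@torch.no_grad()
def logprob_diff_stats(rollout_logprobs, train_logprobs, loss_mask):
    diff = (rollout_logprobs - train_logprobs).abs()
    masked = diff[loss_mask.bool()]     # M = {i : loss_mask[i] = 1}
    # torch.std defaults to the |M| - 1 denominator, matching the formula.
    return masked.mean(), masked.std()
```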

Related Paper

Stabilizing MoE Reinforcement Learning by Aligning Training and Inference Routers (Ma et al., arXiv:2510.11370, 2025) — proposes R3 to reduce training-inference policy KL divergence and prevent MoE RL training collapse.

Related Issue

Fixes #(issue)

Type of Change

  • 🐛 Bug fix
  • ✨ New feature
  • 💥 Breaking change
  • 📝 Documentation update
  • ♻️ Refactoring
  • ⚡ Performance improvement
  • ✅ Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated (if applicable; built with ./docs/build_all.sh)
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Breaking Change Details (if applicable):

N/A

Additional Context

  • Backward Compatible: return_routed_experts=False (default) → all R3 code inactive, zero overhead
  • SGLang Only: vLLM backend does not support return_routed_experts; config validation raises explicit error
  • Side-Channel Delivery: routed_experts delivered via engine._r3_pending_routed_experts to bypass pack_tensor_dict 4D incompatibility
  • Server Patch Required: sglang_r3_patch must be installed on inference server to fix torch.Tensor serialization when skip_tokenizer_init=True

gemini-code-assist (bot) left a comment

Code Review

This pull request implements Router Replay (R3) to align Mixture-of-Experts (MoE) routing decisions between rollout inference and training, preventing performance degradation caused by weight staleness in RL. The changes include monkey-patches for Megatron-Core components, engine-level wrappers for micro-batch scheduling, and workflow integrations to propagate routing indices from SGLang. Feedback focuses on critical architectural issues regarding global state and thread safety, specifically the risks of patching class-level iterators and using global lists for router instances. Additionally, there are recommendations to fix potential data loss in uneven batch splitting and to optimize performance by removing GPU-CPU synchronization points in the data processing pipeline.

Review comment threads were opened on areal/engine/router_replay_patch.py, areal/engine/router_replay_utils.py, areal/engine/megatron_engine_r3_patch.py, and areal/trainer/ppo/actor_r3_patch.py (several since marked outdated).
TaoZex marked this pull request as ready for review on April 29, 2026.

TaoZex commented Apr 29, 2026

  1. Unit test results (including test_r3_mask_alignment.py and test_router_replay.py): [screenshot]
  2. End-to-end test results (including test_router_replay_e2e.py): [screenshot]


TaoZex commented Apr 30, 2026

Metric Comparison: Router Replay (R3)

Using the moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml configuration, we compare metrics with Router Replay (R3) enabled versus disabled:

  1. rollout_train_logprobs_abs_diff_std comparison (standard deviation of the absolute differences between rollout and training log-probs): [screenshot]
  2. rollout_train_logprobs_abs_diff_mean comparison (mean absolute difference between rollout and training log-probs): [screenshot]

TaoZex marked this pull request as draft on April 30, 2026.
