Summary
On the Automodel (FSDP2) path, Qwen3.5 MoE (and other custom MoE models) end up with the optimizer operating on BF16 parameters rather than FP32 master weights. With torch.optim.AdamW, updates of magnitude lr · m / (√v + ε) at typical RLHF learning rates are well below the BF16 quantization step and get rounded to zero on most parameters. Convergence lags the MCore (distributed optimizer) path, which keeps FP32 master weights externally.

The root-cause fix belongs in nemo-automodel and is tracked separately; see NVIDIA-NeMo/Automodel#1896. This issue tracks the NeMo-RL side:

- Document the symptom so users picking the automodel path for MoE models are aware.
- Short-term: switch the affected recipes to FusedAdam, which keeps an internal FP32 master copy of the parameters (see "Workaround (#2320)" below).
- Long-term: revisit the override once the nemo-automodel fix lands (see "Resolution criteria" below).
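For intuition on the rounding claim above, a minimal standalone sketch (illustrative only; the learning rate, weight scale, and toy parameter are assumptions, not the actual training setup):

```python
import torch

# The same AdamW step is a no-op on a BF16 parameter but takes effect on an FP32
# parameter: a ~1e-6 update is far below the BF16 quantization step (~6e-5 near a
# weight of magnitude 0.01), so it rounds away.
for dtype in (torch.bfloat16, torch.float32):
    p = torch.nn.Parameter(torch.full((4,), 0.01, dtype=dtype))
    opt = torch.optim.AdamW([p], lr=1e-6)
    p.grad = torch.ones_like(p)          # first step: |m / (sqrt(v) + eps)| ≈ 1
    before = p.detach().clone()
    opt.step()
    print(dtype, "parameter unchanged:", torch.equal(p.detach(), before))
# Expected: bfloat16 -> True (update rounded to zero), float32 -> False
```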
Symptoms
Same workload (Qwen3.5-35B-A3B-Base, DAPO-style GRPO, BF16), three setups:

- Automodel path with torch.optim.AdamW (default)
- Automodel path with FusedAdam (master_weights=True)
- MCore (distributed optimizer) path

Affected recipes (NeMo-RL side)
Automodel-path GRPO recipes for Qwen3.5 MoE:
- examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml
- examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml
- examples/configs/recipes/vlm/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml

(The -megatron recipes go through the MCore distributed optimizer and are not affected.)

Workaround (#2320)
Switch the affected recipes to transformer_engine.pytorch.optimizers.fused_adam.FusedAdam with master_weights=True. FusedAdam then holds an internal FP32 master copy of the parameters, so the BF16 parameter storage on the Automodel MoE path no longer matters for optimizer precision.
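A sketch of what the override amounts to (illustrative only; the hyperparameter values are placeholders and the toy model stands in for the FSDP2-wrapped policy, so see #2320 for the actual recipe wiring):

```python
import torch
from transformer_engine.pytorch.optimizers.fused_adam import FusedAdam

# Stand-in for the Automodel/FSDP2 policy, whose parameters are stored in BF16.
model = torch.nn.Linear(16, 16, dtype=torch.bfloat16, device="cuda")

optimizer = FusedAdam(
    model.parameters(),
    lr=1e-6,               # placeholder; use the recipe's learning rate
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,
    master_weights=True,   # keep an internal FP32 master copy of the BF16 params
)
```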
Resolution criteria
The root-cause fix in nemo-automodel lands and the automodel MoE path produces FP32 master weights with torch.optim.AdamW (NVIDIA-NeMo/Automodel#1896, "fix: fp32 master weights for custom MoE models under FSDP2"). Once that ships, the FusedAdam override in these recipes can be revisited (switch back to AdamW, or keep FusedAdam on performance grounds).