
Qwen3.5 MoE Automodel path does not use FP32 master weights by default, causing slower convergence #2322

@zpqiu

Description


Summary

On the Automodel (FSDP2) path, Qwen3.5 MoE (and other custom MoE models) end up with the optimizer operating on BF16 parameters rather than FP32 master weights. With torch.optim.AdamW, updates of magnitude lr · m / (√v + ε) at typical RLHF learning rates are well below the BF16 quantization step and get rounded to zero on most parameters. Convergence lags the MCore (distributed optimizer) path, which keeps FP32 master weights externally.
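
To see the scale of the problem, here is a minimal sketch (illustrative magnitudes, not values measured from the affected runs) of a single AdamW-style step applied to a BF16 parameter versus an FP32 master copy:

import torch

# Illustrative values: a parameter of typical magnitude stored in BF16,
# and an update of size lr * m / (sqrt(v) + eps) ~ lr at an RLHF-scale lr.
p = torch.tensor(0.01, dtype=torch.bfloat16)
update = 1e-6

# Approximate spacing between adjacent BF16 values near p (7-bit mantissa):
# roughly 8e-5 here, far larger than the update, so the step rounds away.
spacing = torch.finfo(torch.bfloat16).eps * p.abs().float()
print(f"update {update:.0e} vs BF16 spacing {spacing.item():.0e}")
print(bool(p - update == p))            # True: the BF16 parameter does not move

# An FP32 master copy resolves the same step without trouble.
master = p.float()
print(bool(master - update == master))  # False: FP32 registers the update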

The root-cause fix belongs in nemo-automodel and is tracked separately — see NVIDIA-NeMo/Automodel#1896.

This issue tracks the NeMo-RL side:

  1. Document the symptom so users picking the automodel path for MoE models are aware.
  2. Ship the FusedAdam workaround ("fix: enable TE FusedAdam for Qwen3.5 MoE & GLM-4.7-Flash automodel recipes" #2320) for the affected recipes; it sidesteps the missing FP32 master weights by having the optimizer hold its own FP32 master copy internally.

Symptoms

Same workload (Qwen3.5-35B-A3B-Base, DAPO-style GRPO, BF16), three setups:

  • Yellow: MCore path
  • Red: Automodel + torch.optim.AdamW (default)
  • Grey: Automodel + TE FusedAdam (master_weights=True)
[Figure: convergence curves for the three setups]

Affected recipes (NeMo-RL side)

Automodel-path GRPO recipes for Qwen3.5 MoE:

  • examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml
  • examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml
  • examples/configs/recipes/vlm/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml

(The corresponding -megatron recipes go through the MCore distributed optimizer and are not affected.)

Workaround (#2320)

Switch the affected recipes to transformer_engine.pytorch.optimizers.fused_adam.FusedAdam with:

policy:
  optimizer:
    name: "transformer_engine.pytorch.optimizers.fused_adam.FusedAdam"
    kwargs:
      lr: ...
      weight_decay: ...
      betas: [0.9, 0.999]
      eps: 1.0e-08
      master_weights: true
      store_param_remainders: true
      exp_avg_dtype: "torch.bfloat16"
      exp_avg_sq_dtype: "torch.bfloat16"

FusedAdam holds an internal FP32 master copy of the parameters, so the BF16 parameter storage on the Automodel MoE path no longer matters for optimizer precision.
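
For reference, a minimal sketch of what that config resolves to, assuming NeMo-RL forwards policy.optimizer.kwargs verbatim to the optimizer constructor (the model, lr, and weight_decay below are placeholders):

import torch
from transformer_engine.pytorch.optimizers.fused_adam import FusedAdam

# Stand-in for the BF16 policy; TE FusedAdam expects CUDA parameters when stepping.
model = torch.nn.Linear(16, 16, dtype=torch.bfloat16, device="cuda")

optimizer = FusedAdam(
    model.parameters(),
    lr=1e-6,                      # placeholder; take from the recipe
    weight_decay=0.0,             # placeholder; take from the recipe
    betas=(0.9, 0.999),
    eps=1e-8,
    master_weights=True,          # optimizer keeps its own FP32 master copy
    store_param_remainders=True,  # keep the master copy as 16-bit remainders over the BF16 params
    exp_avg_dtype=torch.bfloat16,
    exp_avg_sq_dtype=torch.bfloat16,
)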

Resolution criteria
