From 6c2a3f5b07672020166e356b3c1da2bfec98fa73 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Thu, 23 Apr 2026 00:16:50 -0700 Subject: [PATCH 1/4] fix: enable TE FusedAdam for Qwen3.5 MoE automodel recipes Resolve string-valued torch dtypes (e.g. "torch.bfloat16") in the automodel optimizer kwargs so TE FusedAdam's exp_avg_dtype and exp_avg_sq_dtype can be specified from YAML. Migrate the three Qwen3.5-35B-A3B automodel GRPO recipes (llm 2n8g EP16, llm DAPO 4n8g, vlm geo3k 2n8g EP16) from torch.optim.AdamW to transformer_engine.pytorch.optimizers.fused_adam.FusedAdam, carrying over lr/weight_decay/betas/eps from the prior settings. Use _override_: true on the optimizer block so the base grpo_math_1B.yaml optimizer config (including foreach/fused=False, which FusedAdam does not accept) is replaced rather than merged. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Zhaopeng Qiu --- .../llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml | 12 ++++++++++++ .../llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml | 8 ++++++++ ...rpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml | 8 ++++++++ nemo_rl/models/automodel/setup.py | 7 ++++++- 4 files changed, 34 insertions(+), 1 deletion(-) diff --git a/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml b/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml index 22d8c6dbda..d898ba26d3 100644 --- a/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml +++ b/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml @@ -6,6 +6,18 @@ policy: train_micro_batch_size: 1 logprob_batch_size: 1 max_total_sequence_length: 4096 + optimizer: + _override_: true + name: "transformer_engine.pytorch.optimizers.fused_adam.FusedAdam" + kwargs: + lr: 5.0e-06 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1.0e-08 + master_weights: true + store_param_remainders: true + exp_avg_dtype: "torch.bfloat16" + exp_avg_sq_dtype: "torch.bfloat16" dtensor_cfg: expert_parallel_size: 16 activation_checkpointing: true diff --git a/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml b/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml index 1db1c749fd..181d5e8020 100644 --- a/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml +++ b/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml @@ -24,9 +24,17 @@ policy: logprob_chunk_size: 4096 offload_optimizer_for_logprob: true optimizer: + _override_: true + name: "transformer_engine.pytorch.optimizers.fused_adam.FusedAdam" kwargs: lr: 1.0e-06 weight_decay: 0.1 + betas: [0.9, 0.999] + eps: 1.0e-08 + master_weights: true + store_param_remainders: true + exp_avg_dtype: "torch.bfloat16" + exp_avg_sq_dtype: "torch.bfloat16" scheduler: - name: torch.optim.lr_scheduler.LinearLR kwargs: diff --git a/examples/configs/recipes/vlm/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml b/examples/configs/recipes/vlm/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml index eb7ee542c7..3fa4a45424 100644 --- a/examples/configs/recipes/vlm/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml +++ b/examples/configs/recipes/vlm/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml @@ -13,9 +13,17 @@ policy: logprob_batch_size: 1 max_total_sequence_length: 3072 optimizer: + _override_: true + name: "transformer_engine.pytorch.optimizers.fused_adam.FusedAdam" kwargs: lr: 1.0e-06 weight_decay: 0.1 + betas: [0.9, 0.999] + eps: 1.0e-08 + master_weights: true + 
store_param_remainders: true + exp_avg_dtype: "torch.bfloat16" + exp_avg_sq_dtype: "torch.bfloat16" dtensor_cfg: expert_parallel_size: 16 activation_checkpointing: true diff --git a/nemo_rl/models/automodel/setup.py b/nemo_rl/models/automodel/setup.py index a539560b0d..f5010e932d 100644 --- a/nemo_rl/models/automodel/setup.py +++ b/nemo_rl/models/automodel/setup.py @@ -711,7 +711,12 @@ def setup_model_and_optimizer( optimizer = None if init_optimizer: optimizer_cls = get_class(config["optimizer"]["name"]) - optimizer = optimizer_cls(model.parameters(), **config["optimizer"]["kwargs"]) + optimizer_kwargs = dict(config["optimizer"]["kwargs"]) + # Resolve string-valued torch dtypes (e.g. "torch.bfloat16" -> torch.bfloat16) + for key, value in optimizer_kwargs.items(): + if isinstance(value, str) and value.startswith("torch."): + optimizer_kwargs[key] = getattr(torch, value.removeprefix("torch.")) + optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs) # Initialize scheduler scheduler = None From 8b4f3b211e2336df21eb935d991f5b0914304847 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Thu, 23 Apr 2026 06:47:35 -0700 Subject: [PATCH 2/4] fix: drop foreach/fused=False from grpo_math_1B.yaml base optimizer The `foreach: False` / `fused: False` kwargs were added ~1 year ago in the initial FSDP2/DTensor support PR (#131, commit 085fa666) as a defensive measure for DTensor compatibility. PyTorch DTensor has since added native `_foreach_*` kernel coverage and the auto-selected defaults (`foreach=None`, `fused=None`) are correct for DTensor tensors on the currently pinned `torch==2.10.0`. Dropping these from the base unblocks using TE FusedAdam on recipes that inherit from grpo_math_1B.yaml without needing the previous `_override_: true` trick, because FusedAdam does not accept those AdamW-only kwargs. Re-minimizes the three Qwen3.5 MoE automodel recipes accordingly: the `_override_: true` markers are removed and kwargs that now match the (cleaner) base are elided. Scoped to grpo_math_1B.yaml only; the same cleanup for sft/dpo/rm base configs is deferred to a follow-up once this change is validated in nightly. 
Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Zhaopeng Qiu --- examples/configs/grpo_math_1B.yaml | 4 ---- .../llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml | 12 +++--------- .../llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml | 9 +++------ ...rpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml | 9 +++------ 4 files changed, 9 insertions(+), 25 deletions(-) diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 45582eb591..20328e6c4a 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -255,10 +255,6 @@ policy: weight_decay: 0.01 betas: [0.9, 0.999] eps: 1e-8 - # when using Dtensor, we need to set foreach - # and fused to False - foreach: False - fused: False scheduler: - name: "torch.optim.lr_scheduler.LinearLR" diff --git a/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml b/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml index d898ba26d3..8a9ccefa50 100644 --- a/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml +++ b/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml @@ -7,17 +7,12 @@ policy: logprob_batch_size: 1 max_total_sequence_length: 4096 optimizer: - _override_: true - name: "transformer_engine.pytorch.optimizers.fused_adam.FusedAdam" + name: transformer_engine.pytorch.optimizers.fused_adam.FusedAdam kwargs: - lr: 5.0e-06 - weight_decay: 0.01 - betas: [0.9, 0.999] - eps: 1.0e-08 master_weights: true store_param_remainders: true - exp_avg_dtype: "torch.bfloat16" - exp_avg_sq_dtype: "torch.bfloat16" + exp_avg_dtype: torch.bfloat16 + exp_avg_sq_dtype: torch.bfloat16 dtensor_cfg: expert_parallel_size: 16 activation_checkpointing: true @@ -40,7 +35,6 @@ policy: generation: vllm_cfg: tensor_parallel_size: 8 - # set to eager mode to mitigate https://github.com/vllm-project/vllm/issues/36237 enforce_eager: true logger: wandb_enabled: true diff --git a/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml b/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml index 181d5e8020..b71ff1ce2b 100644 --- a/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml +++ b/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml @@ -24,17 +24,14 @@ policy: logprob_chunk_size: 4096 offload_optimizer_for_logprob: true optimizer: - _override_: true - name: "transformer_engine.pytorch.optimizers.fused_adam.FusedAdam" + name: transformer_engine.pytorch.optimizers.fused_adam.FusedAdam kwargs: lr: 1.0e-06 weight_decay: 0.1 - betas: [0.9, 0.999] - eps: 1.0e-08 master_weights: true store_param_remainders: true - exp_avg_dtype: "torch.bfloat16" - exp_avg_sq_dtype: "torch.bfloat16" + exp_avg_dtype: torch.bfloat16 + exp_avg_sq_dtype: torch.bfloat16 scheduler: - name: torch.optim.lr_scheduler.LinearLR kwargs: diff --git a/examples/configs/recipes/vlm/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml b/examples/configs/recipes/vlm/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml index 3fa4a45424..f9bd24b224 100644 --- a/examples/configs/recipes/vlm/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml +++ b/examples/configs/recipes/vlm/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml @@ -13,17 +13,14 @@ policy: logprob_batch_size: 1 max_total_sequence_length: 3072 optimizer: - _override_: true - name: "transformer_engine.pytorch.optimizers.fused_adam.FusedAdam" + name: transformer_engine.pytorch.optimizers.fused_adam.FusedAdam kwargs: 
lr: 1.0e-06 weight_decay: 0.1 - betas: [0.9, 0.999] - eps: 1.0e-08 master_weights: true store_param_remainders: true - exp_avg_dtype: "torch.bfloat16" - exp_avg_sq_dtype: "torch.bfloat16" + exp_avg_dtype: torch.bfloat16 + exp_avg_sq_dtype: torch.bfloat16 dtensor_cfg: expert_parallel_size: 16 activation_checkpointing: true From 58cc3d708fa7349072978fa88e5cc9f1ed69b335 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Thu, 23 Apr 2026 08:02:18 -0700 Subject: [PATCH 3/4] fix: enable TE FusedAdam for GLM-4.7-Flash automodel recipe Apply the same FusedAdam migration used for the Qwen3.5 MoE recipes: switch torch.optim.AdamW to transformer_engine.pytorch.optimizers.fused_adam.FusedAdam with master_weights=True so the optimizer keeps an internal FP32 master copy, bypassing the missing FP32 master weights on the Automodel custom MoE path. lr / weight_decay are unchanged; betas / eps are inherited from the base grpo_math_1B.yaml. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Zhaopeng Qiu --- .../configs/recipes/llm/grpo-glm47-flash-4n8g-automodel.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/configs/recipes/llm/grpo-glm47-flash-4n8g-automodel.yaml b/examples/configs/recipes/llm/grpo-glm47-flash-4n8g-automodel.yaml index 894f31c388..ef7dcfc514 100644 --- a/examples/configs/recipes/llm/grpo-glm47-flash-4n8g-automodel.yaml +++ b/examples/configs/recipes/llm/grpo-glm47-flash-4n8g-automodel.yaml @@ -18,9 +18,14 @@ policy: logprob_chunk_size: 4096 offload_optimizer_for_logprob: true optimizer: + name: "transformer_engine.pytorch.optimizers.fused_adam.FusedAdam" kwargs: lr: 1.0e-06 weight_decay: 0.1 + master_weights: true + store_param_remainders: true + exp_avg_dtype: "torch.bfloat16" + exp_avg_sq_dtype: "torch.bfloat16" scheduler: - name: torch.optim.lr_scheduler.LinearLR kwargs: From 2a23c0ecea22b43e2b500615d208a82b4d9eda89 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Fri, 24 Apr 2026 06:15:19 -0700 Subject: [PATCH 4/4] fix: skip frozen params when constructing automodel optimizer TE FusedAdam's step() allocates per-parameter state (exp_avg/exp_avg_sq/master_param) before the p.grad-is-None check, so frozen parameters (e.g. the visual encoder in text-only training) still get optimizer state entries. DCP then saves that state, and the next resume fails inside gather_object with a misleading "cannot pickle code objects" (DCP's _wrap_exception captures the real "Size mismatch" ValueError whose traceback contains a CodeType). Pass only requires_grad=True parameters to the optimizer so the frozen visual subtree never enters optimizer state in the first place. This also matches the standard PyTorch idiom and works regardless of which optimizer backend the recipe selects. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Zhaopeng Qiu --- nemo_rl/models/automodel/setup.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/nemo_rl/models/automodel/setup.py b/nemo_rl/models/automodel/setup.py index f5010e932d..5ba5b66b47 100644 --- a/nemo_rl/models/automodel/setup.py +++ b/nemo_rl/models/automodel/setup.py @@ -716,7 +716,15 @@ def setup_model_and_optimizer( for key, value in optimizer_kwargs.items(): if isinstance(value, str) and value.startswith("torch."): optimizer_kwargs[key] = getattr(torch, value.removeprefix("torch.")) - optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs) + # Only pass trainable params to the optimizer. 
TE FusedAdam's step() + # allocates per-param state (exp_avg/exp_avg_sq/master_param) before the + # p.grad-is-None check, so passing frozen params (e.g. the visual + # encoder in text-only training) causes DCP to save unused state that + # later fails to reshard on resume. + optimizer = optimizer_cls( + (p for p in model.parameters() if p.requires_grad), + **optimizer_kwargs, + ) # Initialize scheduler scheduler = None
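Note: the snippet below is an illustrative, self-contained sketch, not part of any patch above. It shows how the optimizer construction in nemo_rl/models/automodel/setup.py behaves once patches 1 and 4 are both applied: string-valued torch dtypes from YAML (e.g. "torch.bfloat16") are resolved to real dtype objects, and only trainable parameters are handed to the optimizer. The build_optimizer helper and its inline import logic are hypothetical stand-ins for the repository's get_class() utility, and the config values shown are examples rather than recipe defaults.

    import torch

    def build_optimizer(model, optimizer_config):
        # Hypothetical stand-in for the repo's get_class(): resolve a dotted
        # class path such as "torch.optim.AdamW" or
        # "transformer_engine.pytorch.optimizers.fused_adam.FusedAdam".
        module_path, _, cls_name = optimizer_config["name"].rpartition(".")
        optimizer_cls = getattr(__import__(module_path, fromlist=[cls_name]), cls_name)

        kwargs = dict(optimizer_config["kwargs"])
        # Patch 1: resolve string-valued torch dtypes coming from YAML,
        # e.g. "torch.bfloat16" -> torch.bfloat16.
        for key, value in kwargs.items():
            if isinstance(value, str) and value.startswith("torch."):
                kwargs[key] = getattr(torch, value.removeprefix("torch."))

        # Patch 4: pass only requires_grad=True parameters so frozen subtrees
        # (e.g. a frozen visual encoder) never acquire optimizer state.
        return optimizer_cls(
            (p for p in model.parameters() if p.requires_grad),
            **kwargs,
        )

For example, build_optimizer(my_module, {"name": "torch.optim.AdamW", "kwargs": {"lr": 1.0e-06, "weight_decay": 0.1}}) constructs a plain AdamW over the trainable parameters of a hypothetical my_module, while the FusedAdam recipes above additionally pass master_weights / exp_avg_dtype / exp_avg_sq_dtype through the same dtype-resolution path.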