diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml
index 45582eb591..20328e6c4a 100644
--- a/examples/configs/grpo_math_1B.yaml
+++ b/examples/configs/grpo_math_1B.yaml
@@ -255,10 +255,6 @@ policy:
       weight_decay: 0.01
       betas: [0.9, 0.999]
       eps: 1e-8
-      # when using Dtensor, we need to set foreach
-      # and fused to False
-      foreach: False
-      fused: False

   scheduler:
     - name: "torch.optim.lr_scheduler.LinearLR"
diff --git a/examples/configs/recipes/llm/grpo-glm47-flash-4n8g-automodel.yaml b/examples/configs/recipes/llm/grpo-glm47-flash-4n8g-automodel.yaml
index 894f31c388..ef7dcfc514 100644
--- a/examples/configs/recipes/llm/grpo-glm47-flash-4n8g-automodel.yaml
+++ b/examples/configs/recipes/llm/grpo-glm47-flash-4n8g-automodel.yaml
@@ -18,9 +18,14 @@ policy:
   logprob_chunk_size: 4096
   offload_optimizer_for_logprob: true
   optimizer:
+    name: "transformer_engine.pytorch.optimizers.fused_adam.FusedAdam"
     kwargs:
       lr: 1.0e-06
       weight_decay: 0.1
+      master_weights: true
+      store_param_remainders: true
+      exp_avg_dtype: "torch.bfloat16"
+      exp_avg_sq_dtype: "torch.bfloat16"
   scheduler:
     - name: torch.optim.lr_scheduler.LinearLR
       kwargs:
diff --git a/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml b/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml
index 22d8c6dbda..8a9ccefa50 100644
--- a/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml
+++ b/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml
@@ -6,6 +6,13 @@ policy:
   train_micro_batch_size: 1
   logprob_batch_size: 1
   max_total_sequence_length: 4096
+  optimizer:
+    name: transformer_engine.pytorch.optimizers.fused_adam.FusedAdam
+    kwargs:
+      master_weights: true
+      store_param_remainders: true
+      exp_avg_dtype: torch.bfloat16
+      exp_avg_sq_dtype: torch.bfloat16
   dtensor_cfg:
     expert_parallel_size: 16
     activation_checkpointing: true
@@ -28,7 +35,6 @@ policy:
   generation:
     vllm_cfg:
       tensor_parallel_size: 8
-      # set to eager mode to mitigate https://github.com/vllm-project/vllm/issues/36237
       enforce_eager: true
 logger:
   wandb_enabled: true
diff --git a/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml b/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml
index 1db1c749fd..b71ff1ce2b 100644
--- a/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml
+++ b/examples/configs/recipes/llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml
@@ -24,9 +24,14 @@ policy:
   logprob_chunk_size: 4096
   offload_optimizer_for_logprob: true
   optimizer:
+    name: transformer_engine.pytorch.optimizers.fused_adam.FusedAdam
     kwargs:
       lr: 1.0e-06
       weight_decay: 0.1
+      master_weights: true
+      store_param_remainders: true
+      exp_avg_dtype: torch.bfloat16
+      exp_avg_sq_dtype: torch.bfloat16
   scheduler:
     - name: torch.optim.lr_scheduler.LinearLR
       kwargs:
diff --git a/examples/configs/recipes/vlm/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml b/examples/configs/recipes/vlm/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml
index eb7ee542c7..f9bd24b224 100644
--- a/examples/configs/recipes/vlm/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml
+++ b/examples/configs/recipes/vlm/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml
@@ -13,9 +13,14 @@ policy:
   logprob_batch_size: 1
   max_total_sequence_length: 3072
   optimizer:
+    name: transformer_engine.pytorch.optimizers.fused_adam.FusedAdam
     kwargs:
       lr: 1.0e-06
       weight_decay: 0.1
+      master_weights: true
+      store_param_remainders: true
+      exp_avg_dtype: torch.bfloat16
+      exp_avg_sq_dtype: torch.bfloat16
   dtensor_cfg:
     expert_parallel_size: 16
     activation_checkpointing: true
diff --git a/nemo_rl/models/automodel/setup.py b/nemo_rl/models/automodel/setup.py
index a539560b0d..5ba5b66b47 100644
--- a/nemo_rl/models/automodel/setup.py
+++ b/nemo_rl/models/automodel/setup.py
@@ -711,7 +711,20 @@ def setup_model_and_optimizer(
     optimizer = None
     if init_optimizer:
         optimizer_cls = get_class(config["optimizer"]["name"])
-        optimizer = optimizer_cls(model.parameters(), **config["optimizer"]["kwargs"])
+        optimizer_kwargs = dict(config["optimizer"]["kwargs"])
+        # Resolve string-valued torch dtypes (e.g. "torch.bfloat16" -> torch.bfloat16)
+        for key, value in optimizer_kwargs.items():
+            if isinstance(value, str) and value.startswith("torch."):
+                optimizer_kwargs[key] = getattr(torch, value.removeprefix("torch."))
+        # Only pass trainable params to the optimizer. TE FusedAdam's step()
+        # allocates per-param state (exp_avg/exp_avg_sq/master_param) before the
+        # p.grad-is-None check, so passing frozen params (e.g. the visual
+        # encoder in text-only training) causes DCP to save unused state that
+        # later fails to reshard on resume.
+        optimizer = optimizer_cls(
+            (p for p in model.parameters() if p.requires_grad),
+            **optimizer_kwargs,
+        )

     # Initialize scheduler
     scheduler = None
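
For reference, below is a minimal, self-contained sketch of the two behaviours the setup.py hunk introduces: resolving string-valued dtypes coming from YAML, and passing only trainable parameters to the optimizer. The helper name `resolve_torch_dtypes` and the toy model are illustrative only, and plain `torch.optim.AdamW` stands in for TE FusedAdam so the snippet runs without Transformer Engine installed.

```python
import torch
import torch.nn as nn


def resolve_torch_dtypes(kwargs: dict) -> dict:
    """Return a copy of kwargs with "torch.<dtype>" strings turned into torch dtypes.

    Hypothetical helper; the in-tree change inlines this loop in setup.py.
    """
    resolved = dict(kwargs)
    for key, value in resolved.items():
        if isinstance(value, str) and value.startswith("torch."):
            resolved[key] = getattr(torch, value.removeprefix("torch."))
    return resolved


# Kwargs shaped like the FusedAdam settings in the recipe YAMLs above.
optimizer_kwargs = resolve_torch_dtypes(
    {
        "lr": 1.0e-06,
        "weight_decay": 0.1,
        "exp_avg_dtype": "torch.bfloat16",
        "exp_avg_sq_dtype": "torch.bfloat16",
    }
)
assert optimizer_kwargs["exp_avg_dtype"] is torch.bfloat16

# Trainable-param filtering: freeze one submodule (standing in for a visual
# encoder) and confirm only requires_grad params reach the optimizer.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
for p in model[0].parameters():
    p.requires_grad_(False)

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=optimizer_kwargs["lr"],
    weight_decay=optimizer_kwargs["weight_decay"],
)
assert len(optimizer.param_groups[0]["params"]) == 2  # weight + bias of the trainable layer
```

With FusedAdam, the same filtering keeps exp_avg/exp_avg_sq/master_param state from being allocated for frozen parameters, so the checkpoint only contains state that exists on resume.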