4 changes: 0 additions & 4 deletions examples/configs/grpo_math_1B.yaml
@@ -255,10 +255,6 @@ policy:
      weight_decay: 0.01
      betas: [0.9, 0.999]
      eps: 1e-8
      # when using Dtensor, we need to set foreach
      # and fused to False
      foreach: False
      fused: False

  scheduler:
    - name: "torch.optim.lr_scheduler.LinearLR"
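For context on the hunk above: foreach and fused are standard keyword arguments of torch's Adam-family optimizers, and the deleted comment notes they had to be disabled under DTensor. Below is a minimal sketch of what the removed keys configured, assuming a torch.optim.AdamW optimizer (the optimizer name is not shown in this hunk); it is an illustration, not the repository's code.

import torch

# Stand-in policy model; only here to have parameters to optimize.
model = torch.nn.Linear(8, 8)

optimizer = torch.optim.AdamW(
    model.parameters(),
    weight_decay=0.01,
    betas=(0.9, 0.999),
    eps=1e-8,
    foreach=False,  # disable the multi-tensor ("foreach") implementation
    fused=False,    # disable the fused CUDA kernel
)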
@@ -18,9 +18,14 @@ policy:
  logprob_chunk_size: 4096
  offload_optimizer_for_logprob: true
  optimizer:
    name: "transformer_engine.pytorch.optimizers.fused_adam.FusedAdam"
    kwargs:
      lr: 1.0e-06
      weight_decay: 0.1
      master_weights: true
      store_param_remainders: true
      exp_avg_dtype: "torch.bfloat16"
      exp_avg_sq_dtype: "torch.bfloat16"
  scheduler:
    - name: torch.optim.lr_scheduler.LinearLR
      kwargs:
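The optimizer block above swaps in TransformerEngine's FusedAdam with bfloat16 Adam moments. Below is a minimal sketch of instantiating it directly with those kwargs, assuming a CUDA build of transformer_engine is installed; the model and device are placeholders, and the inline comments are a best-effort reading of the flags rather than the PR's documentation.

import torch
from transformer_engine.pytorch.optimizers.fused_adam import FusedAdam

# Placeholder bf16 model on GPU; FusedAdam needs CUDA tensors.
model = torch.nn.Linear(4096, 4096, dtype=torch.bfloat16, device="cuda")

optimizer = FusedAdam(
    model.parameters(),
    lr=1.0e-06,
    weight_decay=0.1,
    master_weights=True,              # keep higher-precision master weights for the update
    store_param_remainders=True,      # store compact remainders instead of full fp32 master copies
    exp_avg_dtype=torch.bfloat16,     # first Adam moment kept in bf16
    exp_avg_sq_dtype=torch.bfloat16,  # second Adam moment kept in bf16
)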
@@ -6,6 +6,13 @@ policy:
  train_micro_batch_size: 1
  logprob_batch_size: 1
  max_total_sequence_length: 4096
  optimizer:
    name: transformer_engine.pytorch.optimizers.fused_adam.FusedAdam
    kwargs:
      master_weights: true
      store_param_remainders: true
      exp_avg_dtype: torch.bfloat16
      exp_avg_sq_dtype: torch.bfloat16
  dtensor_cfg:
    expert_parallel_size: 16
    activation_checkpointing: true
@@ -28,7 +35,6 @@ policy:
  generation:
    vllm_cfg:
      tensor_parallel_size: 8
      # set to eager mode to mitigate https://github.com/vllm-project/vllm/issues/36237
      enforce_eager: true
logger:
  wandb_enabled: true
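The vllm_cfg fields above map onto vLLM engine arguments. Below is a minimal sketch of the same two settings on vLLM's offline LLM entry point, not how NeMo RL wires its generation backend; the model name is an arbitrary placeholder and tensor_parallel_size=8 assumes eight GPUs are available.

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder; the real policy model comes from the config
    tensor_parallel_size=8,            # vllm_cfg.tensor_parallel_size
    enforce_eager=True,                # vllm_cfg.enforce_eager: skip CUDA graph capture, run eagerly
)
outputs = llm.generate(["What is 2 + 2?"], SamplingParams(max_tokens=16))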
@@ -24,9 +24,14 @@ policy:
  logprob_chunk_size: 4096
  offload_optimizer_for_logprob: true
  optimizer:
    name: transformer_engine.pytorch.optimizers.fused_adam.FusedAdam
    kwargs:
      lr: 1.0e-06
      weight_decay: 0.1
      master_weights: true
      store_param_remainders: true
      exp_avg_dtype: torch.bfloat16
      exp_avg_sq_dtype: torch.bfloat16
  scheduler:
    - name: torch.optim.lr_scheduler.LinearLR
      kwargs:
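The scheduler entry above follows the same dotted-name-plus-kwargs pattern as the optimizer. Below is a minimal sketch of resolving and constructing it, assuming the class is looked up from its dotted path and the kwargs are forwarded as-is; the get_class helper and the LinearLR arguments here are illustrative placeholders, since the real kwargs are truncated in this view.

import importlib
import torch

def get_class(dotted_path: str):
    # Illustrative helper: "torch.optim.lr_scheduler.LinearLR" -> the class object.
    module_name, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1.0e-06)

scheduler_cls = get_class("torch.optim.lr_scheduler.LinearLR")
# Placeholder kwargs: ramp up from 10% of the base lr over the first 50 steps.
scheduler = scheduler_cls(optimizer, start_factor=0.1, total_iters=50)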
@@ -13,9 +13,14 @@ policy:
  logprob_batch_size: 1
  max_total_sequence_length: 3072
  optimizer:
    name: transformer_engine.pytorch.optimizers.fused_adam.FusedAdam
    kwargs:
      lr: 1.0e-06
      weight_decay: 0.1
      master_weights: true
      store_param_remainders: true
      exp_avg_dtype: torch.bfloat16
      exp_avg_sq_dtype: torch.bfloat16
  dtensor_cfg:
    expert_parallel_size: 16
    activation_checkpointing: true
15 changes: 14 additions & 1 deletion nemo_rl/models/automodel/setup.py
@@ -711,7 +711,20 @@ def setup_model_and_optimizer(
    optimizer = None
    if init_optimizer:
        optimizer_cls = get_class(config["optimizer"]["name"])
        optimizer = optimizer_cls(model.parameters(), **config["optimizer"]["kwargs"])
        optimizer_kwargs = dict(config["optimizer"]["kwargs"])
        # Resolve string-valued torch dtypes (e.g. "torch.bfloat16" -> torch.bfloat16)
        for key, value in optimizer_kwargs.items():
            if isinstance(value, str) and value.startswith("torch."):
                optimizer_kwargs[key] = getattr(torch, value.removeprefix("torch."))
        # Only pass trainable params to the optimizer. TE FusedAdam's step()
        # allocates per-param state (exp_avg/exp_avg_sq/master_param) before the
        # p.grad-is-None check, so passing frozen params (e.g. the visual
        # encoder in text-only training) causes DCP to save unused state that
        # later fails to reshard on resume.
        optimizer = optimizer_cls(
            (p for p in model.parameters() if p.requires_grad),
            **optimizer_kwargs,
        )

    # Initialize scheduler
    scheduler = None
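The two behaviors added in setup.py, resolving string-valued dtypes and passing only trainable parameters, can be reproduced in isolation. Below is a minimal self-contained sketch, with torch.optim.Adam standing in for the configured optimizer class and a toy frozen first layer standing in for a frozen visual encoder; both stand-ins are assumptions for illustration, not the repository's code.

import torch

# 1) Resolve string-valued torch dtypes the way the new setup code does.
kwargs = {"lr": 1.0e-06, "exp_avg_dtype": "torch.bfloat16"}
resolved = {
    key: getattr(torch, value.removeprefix("torch."))
    if isinstance(value, str) and value.startswith("torch.")
    else value
    for key, value in kwargs.items()
}
assert resolved["exp_avg_dtype"] is torch.bfloat16

# 2) Hand the optimizer only trainable parameters, so no per-param state is
#    created (or checkpointed) for frozen modules.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))
for p in model[0].parameters():
    p.requires_grad = False  # pretend model[0] is a frozen visual encoder

optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=resolved["lr"],
)
assert len(optimizer.param_groups[0]["params"]) == 2  # weight + bias of model[1]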