Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion nemo_rl/models/policy/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,9 @@ def train(

if not eval_mode:
# take one LR step every rollout batch
- self.scheduler.step(increment=1)
+ # we need to scale the step by gbs to counteract the fact that NeMo automatically
+ # scales lr_warmup_steps by gbs during init
+ self.scheduler.step(increment=gbs)

# Aggregate metrics across all microbatches
mb_metrics = defaultdict(list)
Expand Down
Loading