From afd478d446cb4137e7267fef7caf53d49235459a Mon Sep 17 00:00:00 2001
From: ashors1
Date: Thu, 17 Jul 2025 09:59:42 -0700
Subject: [PATCH] fix LR increment

Signed-off-by: ashors1
---
 nemo_rl/models/policy/megatron_policy_worker.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index a6cfe9083a..0bf4e71477 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -901,7 +901,9 @@ def train(
 
             if not eval_mode:
                 # take one LR step every rollout batch
-                self.scheduler.step(increment=1)
+                # we need to scale the step by gbs to counteract the fact that NeMo automatically
+                # scales lr_warmup_steps by gbs during init
+                self.scheduler.step(increment=gbs)
 
             # Aggregate metrics across all microbatches
             mb_metrics = defaultdict(list)