diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index a6cfe9083a..0bf4e71477 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -901,7 +901,9 @@ def train(
         if not eval_mode:
             # take one LR step every rollout batch
-            self.scheduler.step(increment=1)
+            # we need to scale the step by gbs to counteract the fact that NeMo automatically
+            # scales lr_warmup_steps by gbs during init
+            self.scheduler.step(increment=gbs)

         # Aggregate metrics across all microbatches
         mb_metrics = defaultdict(list)
