diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py
index 34a621063d..30fb77adeb 100644
--- a/nemo_reinforcer/models/policy/hf_policy.py
+++ b/nemo_reinforcer/models/policy/hf_policy.py
@@ -268,6 +268,12 @@ def train(
         for gb_start in range(0, dataset_size, local_gbs):
             self.optimizer.zero_grad()
             mb_losses = []
+
+            # Calculate the number of microbatches to process.
+            # make_microbatch_iterator assumes the batch size is a multiple of the microbatch
+            # size, so it's safe to not check whether the last data slice is smaller than mbs.
+            num_microbatches = min(local_gbs, dataset_size - gb_start) // mbs
+
             for mb in data.slice(
                 gb_start, gb_start + local_gbs
             ).make_microbatch_iterator(mbs):
@@ -298,6 +304,9 @@ def train(
                     loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"]
 
                 # Backward pass
+
+                # Gradients accumulate across microbatches, so scale each loss by the number of microbatches.
+                loss = loss / num_microbatches
                 if not eval_mode:
                     loss.backward()
                 mb_losses.append(loss.item())
@@ -310,7 +319,7 @@ def train(
                 # Update parameters
                 self.optimizer.step()
                 self.scheduler.step()
-            losses.append(torch.tensor(mb_losses).mean().item())
+            losses.append(torch.tensor(mb_losses).sum().item())
 
         # Compute global loss across all ranks
         with torch.no_grad():
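
A minimal standalone sketch (not part of the PR, using placeholder names like `model`, `global_batch`, and `mbs`) of why each microbatch loss is divided by `num_microbatches`: summing the scaled losses and their backward passes reproduces the gradient of the mean loss over the whole global batch, which is also why the per-batch metric switches from `.mean()` to `.sum()`.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
global_batch = torch.randn(8, 4)
targets = torch.randn(8, 1)
mbs = 2
num_microbatches = global_batch.shape[0] // mbs

# Reference: a single backward pass over the full global batch.
model.zero_grad()
torch.nn.functional.mse_loss(model(global_batch), targets).backward()
reference_grad = model.weight.grad.clone()

# Accumulation: scale each microbatch loss by 1 / num_microbatches before backward.
model.zero_grad()
mb_losses = []
for i in range(num_microbatches):
    mb_x = global_batch[i * mbs:(i + 1) * mbs]
    mb_y = targets[i * mbs:(i + 1) * mbs]
    loss = torch.nn.functional.mse_loss(model(mb_x), mb_y) / num_microbatches
    loss.backward()  # gradients accumulate in .grad across microbatches
    mb_losses.append(loss.item())

# The accumulated gradient matches the full-batch gradient, and the *sum* of the
# scaled microbatch losses (not their mean) equals the full-batch loss.
print(torch.allclose(model.weight.grad, reference_grad, atol=1e-6))  # True
print(sum(mb_losses))  # ~= full-batch MSE loss
```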