diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index ea99de1538..3bcad52849 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -191,6 +191,7 @@ def setup( batch_size=grpo_config["num_prompts_per_step"], shuffle=False, collate_fn=rl_collate_fn, + drop_last=True, ) if last_checkpoint_path is not None: dataloader_state_dict = torch.load(