🐛 Describe the bug
When the pretrain data loss (ptx loss) is used to train the prompt model in stage 3, the code fetches three separate batches from the pretrain dataloader: one each for input_ids, labels, and attention_mask.
The following is the current implementation in ppo.py.
```python
# ptx loss
if self.ptx_coef != 0:
    # Each next(iter(...)) call below creates a fresh iterator and draws a
    # *different* batch, so input_ids, labels, and attention_mask do not
    # belong to the same samples.
    ptx = next(iter(self.pretrain_dataloader))['input_ids'].to(torch.cuda.current_device())
    label = next(iter(self.pretrain_dataloader))['labels'].to(torch.cuda.current_device())[:, 1:]
    attention_mask = next(iter(self.pretrain_dataloader))['attention_mask'].to(torch.cuda.current_device())
    ptx_log_probs = self.actor.get_base_model()(ptx, attention_mask=attention_mask)['logits'][..., :-1, :]
    ptx_loss = self.ptx_loss_fn(ptx_log_probs.view(-1, ptx_log_probs.size(-1)), label.view(-1))
    actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
```
This is an obvious bug. If every batch happens to have exactly the same sequence length, it will not crash the training procedure, but the computed loss is still wrong, because ptx, label, and attention_mask come from three different batches.
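A minimal sketch of a fix (not an official patch, just the obvious correction given the code above): fetch a single batch once and take all three tensors from it, so they describe the same samples.

```python
# ptx loss
if self.ptx_coef != 0:
    # Fetch ONE batch and read all three fields from it, so input_ids,
    # labels, and attention_mask stay aligned to the same samples.
    batch = next(iter(self.pretrain_dataloader))
    ptx = batch['input_ids'].to(torch.cuda.current_device())
    label = batch['labels'].to(torch.cuda.current_device())[:, 1:]
    attention_mask = batch['attention_mask'].to(torch.cuda.current_device())
    ptx_log_probs = self.actor.get_base_model()(ptx, attention_mask=attention_mask)['logits'][..., :-1, :]
    ptx_loss = self.ptx_loss_fn(ptx_log_probs.view(-1, ptx_log_probs.size(-1)), label.view(-1))
    actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
```

As an aside, `next(iter(self.pretrain_dataloader))` also rebuilds the iterator on every training step, which returns the same first batch whenever the dataloader does not shuffle; keeping one persistent iterator is the usual pattern, but that is a separate concern from the mismatch reported here.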
Environment
PyTorch 1.12, CUDA 11.6, CentOS 7