
[BUG]: Fetching wrong data from pretrain dataloader when ptx_coef is not zero in Stage 3 training #3432

@yynil


🐛 Describe the bug

When the pretraining data loss is used to train the actor model in stage 3, the pretrain dataloader is queried three separate times to fetch input_ids, labels, and attention_mask.
The following is the current implementation in ppo.py:

        # ptx loss
        if self.ptx_coef != 0:
            # Each of the three lines below builds a fresh iterator and pulls
            # its own batch, so input_ids, labels, and attention_mask may come
            # from three different batches.
            ptx = next(iter(self.pretrain_dataloader))['input_ids'].to(torch.cuda.current_device())
            label = next(iter(self.pretrain_dataloader))['labels'].to(torch.cuda.current_device())[:, 1:]
            attention_mask = next(iter(self.pretrain_dataloader))['attention_mask'].to(torch.cuda.current_device())
            ptx_log_probs = self.actor.get_base_model()(ptx, attention_mask=attention_mask)['logits'][..., :-1, :]
            ptx_loss = self.ptx_loss_fn(ptx_log_probs.view(-1, ptx_log_probs.size(-1)), label.view(-1))
            actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)

This is an obvious bug, but if all of your batches happen to have exactly the same sequence length, it will not crash the training run. The loss is still completely wrong, though, since ptx, label, and attention_mask come from different batches.
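
A minimal sketch of a possible fix (not an official patch, and assuming the surrounding ppo.py code stays the same): fetch a single batch and read all three tensors from it so they stay aligned.

        # ptx loss
        if self.ptx_coef != 0:
            # Fetch one batch, then take input_ids, labels, and attention_mask
            # from that same batch so the three tensors stay aligned.
            batch = next(iter(self.pretrain_dataloader))
            ptx = batch['input_ids'].to(torch.cuda.current_device())
            label = batch['labels'].to(torch.cuda.current_device())[:, 1:]
            attention_mask = batch['attention_mask'].to(torch.cuda.current_device())
            ptx_log_probs = self.actor.get_base_model()(ptx, attention_mask=attention_mask)['logits'][..., :-1, :]
            ptx_loss = self.ptx_loss_fn(ptx_log_probs.view(-1, ptx_log_probs.size(-1)), label.view(-1))
            actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)

Note that next(iter(self.pretrain_dataloader)) still rebuilds the iterator on every training step, so with a non-shuffled dataloader it would keep returning the first batch; keeping a persistent iterator created once outside the training step would be a more thorough fix.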

Environment

PyTorch 1.12 + CUDA 11.6 + CentOS 7
