7 changes: 4 additions & 3 deletions applications/Chat/coati/trainer/ppo.py
@@ -92,9 +92,10 @@ def training_step(self, experience: Experience) -> Dict[str, float]:

         # ptx loss
         if self.ptx_coef != 0:
-            ptx = next(iter(self.pretrain_dataloader))['input_ids'].to(torch.cuda.current_device())
-            label = next(iter(self.pretrain_dataloader))['labels'].to(torch.cuda.current_device())[:, 1:]
-            attention_mask = next(iter(self.pretrain_dataloader))['attention_mask'].to(torch.cuda.current_device())
+            batch = next(iter(self.pretrain_dataloader))
+            ptx = batch['input_ids'].to(torch.cuda.current_device())
+            label = batch['labels'].to(torch.cuda.current_device())[:, 1:]
+            attention_mask = batch['attention_mask'].to(torch.cuda.current_device())
             ptx_log_probs = self.actor.get_base_model()(ptx, attention_mask=attention_mask)['logits'][..., :-1, :]
             ptx_loss = self.ptx_loss_fn(ptx_log_probs.view(-1, ptx_log_probs.size(-1)), label.view(-1))
             actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
24 changes: 24 additions & 0 deletions applications/Chat/examples/community/EasyPeftModel.md
@@ -0,0 +1,24 @@
# Add PEFT support for SFT and prompts model training

The original implementation adopts loralib and merges the LoRA layers into the final model. Hugging Face PEFT is a better LoRA implementation that is easier to train and to use in distributed settings.

Since the reward model is relatively small, I keep it as in the original implementation. I suggest training the full model to get a proper reward/critic model.
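As a rough sketch of the idea (not the exact code in this PR; the model name and LoRA hyperparameters below are illustrative), wrapping a causal LM with peft looks like this:
```
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
# Inject trainable low-rank adapters; the base weights stay frozen.
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.1)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()    # only the LoRA parameters are trainable
```
Adapters saved with save_pretrained can later be loaded back onto a base model via PeftModel.from_pretrained, which is how easy_models.py restores a trained actor.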

# Preliminary installation
Since the current PyPI peft package (0.2) has some bugs, please install peft from source.
```
git clone https://github.com/huggingface/peft
cd peft
pip install .
```

# Usage
For SFT training, just call train_peft_sft.py.

Its arguments are almost identical to train_sft.py, except that it adds a new eval_dataset argument for the case where you have an evaluation data file. The data file is a plain text file; please check the format in easy_dataset.py.
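A typical invocation might look like the following; the flag names are assumptions inferred from train_sft.py and the description above, not verified against the script, so check its argparse definitions before running:
```
torchrun --standalone --nproc_per_node=1 train_peft_sft.py \
    --model bloom \
    --pretrain bigscience/bloom-560m \
    --dataset /path/to/train_data.txt \
    --eval_dataset /path/to/eval_data.txt \
    --save_path /path/to/sft_output
```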

For stage-3 RLHF training, call train_peft_prompts.py.
Its arguments are almost identical to train_prompts.py. The only difference is that I use plain text files for the prompt and pretraining data. The models are included in easy_models.py. Currently only BLOOM models are tested, but technically GPT-2/OPT/LLaMA should be supported.
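Similarly, a hypothetical stage-3 invocation (again, flag names are assumptions based on train_prompts.py; verify against the script itself):
```
torchrun --standalone --nproc_per_node=1 train_peft_prompts.py \
    --model bloom \
    --pretrain /path/to/sft_output \
    --prompt_dataset /path/to/test_prompts.txt \
    --pretrain_dataset /path/to/test_pretrained.txt \
    --save_path /path/to/actor_output
```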

# Data format
Please refer to the formats in test_sft.txt, test_prompts.txt, and test_pretrained.txt.
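Based on how easy_dataset.py parses these files, the shapes are roughly as follows (the sample contents are illustrative placeholders):
```
# test_sft.txt / test_pretrained.txt: plain text, one sample per line; the text up to
# and including the marker "回答:" ("Answer:") is the source, the rest is the target
提问:<question text> 回答:<answer text>

# test_prompts.txt: same marker; only the text up to and including "回答:" is kept
提问:<question text> 回答:

# reward data: one JSON object per line with prompt/chosen/rejected fields
{"prompt": "<question text>", "chosen": "<better answer>", "rejected": "<worse answer>"}
```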
242 changes: 242 additions & 0 deletions applications/Chat/examples/community/easy_dataset.py
@@ -0,0 +1,242 @@
import copy
import json
from typing import Dict, Sequence

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

IGNORE_INDEX = -100


def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: AutoTokenizer,
    max_length: int = 512,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [
        _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
    ]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    # Mask out the source tokens so the loss is only computed on the target part.
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


class EasySupervisedDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
        super().__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        # Split each line into source and target: the source is everything up to and
        # including the marker "回答:" ("Answer:"), the target is everything after it.
        sources, targets = [], []
        for line in all_lines:
            if "回答:" in line:
                sep_index = line.index("回答:")
                sources.append(line[:sep_index + 3])
                targets.append(line[sep_index + 3:] + tokenizer.eos_token)
            else:
                sources.append(line)
                targets.append("" + tokenizer.eos_token)
        data_dict = preprocess(sources, targets, tokenizer, max_length)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.data_file = data_file

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __repr__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"

    def __str__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"

class EasyPromptsDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
        super().__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        # Keep only the prompt part of each line: everything up to and including "回答:".
        all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
        self.prompts = [
            tokenizer(line,
                      return_tensors='pt',
                      max_length=max_length,
                      padding='max_length',
                      truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
            for line in tqdm(all_lines)
        ]
        self.data_file = data_file

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]

    def __repr__(self):
        return f"EasyPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"

    def __str__(self):
        return f"EasyPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"


class EasyRewardDataset(Dataset):

    def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
        super().__init__()
        self.chosen = []
        self.reject = []
        if special_token is None:
            self.end_token = tokenizer.eos_token
        else:
            self.end_token = special_token
        print(self.end_token)
        # Read all lines of train_file; each line is a JSON object with
        # 'prompt', 'chosen' and 'rejected' fields.
        with open(train_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        for line in tqdm(all_lines):
            data = json.loads(line)
            # "提问:" means "Question:", "回答:" means "Answer:".
            prompt = "提问:" + data['prompt'] + " 回答:"

            chosen = prompt + data['chosen'] + self.end_token
            chosen_token = tokenizer(chosen,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.chosen.append({
                "input_ids": chosen_token['input_ids'],
                "attention_mask": chosen_token['attention_mask']
            })

            reject = prompt + data['rejected'] + self.end_token
            reject_token = tokenizer(reject,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.reject.append({
                "input_ids": reject_token['input_ids'],
                "attention_mask": reject_token['attention_mask']
            })

    def __len__(self):
        return len(self.chosen)

    def __getitem__(self, idx):
        return (self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"],
                self.reject[idx]["input_ids"], self.reject[idx]["attention_mask"])

    def __repr__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"

    def __str__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"

'''
EasySFTDataset accepts a plain text file that can be read line by line. By default the dataset
groups consecutive texts together up to max_length so the LLM learns the texts' meaning better.
If individual lines are unrelated, just set is_group_texts to False.
'''
class EasySFTDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
        super().__init__()
        # Read data_file line by line, encode each line, and collect the raw
        # python lists of input ids.
        with open(data_file, "r", encoding="UTF-8") as f:
            raw_input_ids = []
            for line in f:
                encoded_ids = tokenizer.encode(line)
                # If encoded_ids is longer than max_length, split it into several chunks.
                if len(encoded_ids) > max_length:
                    for i in range(0, len(encoded_ids), max_length):
                        raw_input_ids.append(encoded_ids[i:i + max_length])
                else:
                    raw_input_ids.append(encoded_ids)

        grouped_input_ids = []
        current_input_ids = []
        attention_mask = []
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        if is_group_texts:
            for input_ids in raw_input_ids:
                if len(current_input_ids) + len(input_ids) > max_length:
                    # Pad the finished group to max_length with tokenizer.pad_token_id.
                    padded_length = max_length - len(current_input_ids)
                    current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                    grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                    attention_mask.append(
                        torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                    # Start the next group with the line that did not fit.
                    current_input_ids = list(input_ids)
                else:
                    current_input_ids.extend(input_ids)
            if len(current_input_ids) > 0:
                padded_length = max_length - len(current_input_ids)
                current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
        else:
            # No grouping: just pad each line's input_ids to max_length.
            for input_ids in raw_input_ids:
                padded_length = max_length - len(input_ids)
                input_ids.extend([tokenizer.pad_token_id] * padded_length)
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
        self.input_ids = grouped_input_ids
        self.labels = copy.deepcopy(self.input_ids)
        self.file_name = data_file
        self.attention_mask = attention_mask

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])

    def __repr__(self):
        return f"EasySFTDataset(len={len(self)}, file_name={self.file_name})"

    def __str__(self):
        return f"EasySFTDataset(len={len(self)}, file_name={self.file_name})"





97 changes: 97 additions & 0 deletions applications/Chat/examples/community/easy_models.py
@@ -0,0 +1,97 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import Module

from coati.models.generation import generate
from coati.models.utils import log_probs_from_logits
from peft import PeftModel
from transformers import BloomConfig, BloomForCausalLM

class Actor(Module):
    """
    Actor model base class.

    Args:
        model (nn.Module): Actor Model.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        return_action_mask: bool = True,
        **kwargs
    ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
        sequences = generate(self.model, input_ids, **kwargs)
        attention_mask = None
        pad_token_id = kwargs.get('pad_token_id', None)
        if pad_token_id is not None:
            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
        if not return_action_mask:
            return sequences, attention_mask, None
        input_len = input_ids.size(1)
        eos_token_id = kwargs.get('eos_token_id', None)
        if eos_token_id is None:
            action_mask = torch.ones_like(sequences, dtype=torch.bool)
        else:
            # left padding may be applied, only mask action
            action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
        action_mask[:, :input_len] = False
        action_mask = action_mask[:, 1:]
        return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]

    def forward(self,
                sequences: torch.LongTensor,
                num_actions: int,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Returns action log probs."""
        output = self.model(sequences, attention_mask=attention_mask)
        logits = output['logits']
        log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
        return log_probs[:, -num_actions:]

    def get_base_model(self):
        return self.model


class BLOOMActor(Actor):
    """
    BLOOM Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (BloomConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_path (str): Path to a saved PEFT (LoRA) checkpoint to load.
    """

    def __init__(self,
                 pretrained: str = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_path: str = None) -> None:
        if pretrained is not None:
            model = BloomForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = BloomForCausalLM(config)
        else:
            model = BloomForCausalLM(BloomConfig())
        if lora_path is not None:
            # Load LoRA adapter weights on top of the base model.
            model = PeftModel.from_pretrained(model, lora_path)
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model)

    def print_trainable_parameters(self):
        self.get_base_model().print_trainable_parameters()
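

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the model name and LoRA checkpoint
# path are placeholders, not part of this PR):
#
#   actor = BLOOMActor(pretrained="bigscience/bloom-560m", lora_path="./lora_ckpt")
#   actor.print_trainable_parameters()    # only LoRA weights should be trainable
# ---------------------------------------------------------------------------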
