Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 92 additions & 5 deletions applications/Chat/examples/community/peft/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Add support for ChatGLM-6B and graphic memory friendly PPO trainer

# Add Peft support for SFT and Prompts model training

The orginal implementation just adopts the loralib and merges the layers into the final model. The huggingface peft is a better lora model implementation and can be easily training and distributed.
Expand All @@ -13,12 +15,97 @@ pip install .
```

# Usage
For SFT training, just call train_peft_sft.py

Its arguments are almost identical to train_sft.py instead adding a new eval_dataset if you have a eval_dataset file. The data file is just a plain datafile, please check the format in the easy_dataset.py.
## Data preparation
There are 3 stages to train a RLHF LLM. So we need to prepare data for 3 stages.

## Stage 1, fine tuning LLM with domain knowledge

In common cases, we want LLM to learn some domain knowledge that's not possible for public available LLMs. For an example, if courts want to let a LLM be familar with law cases, the courts will feed law cases texts to LLM to finetune the public available LLMs.

Even you can crawl instructions from other LLMs to finetune you LLM to get a close performance with the LLM you crawl from.

Anyway, let's just use the toys_sft.txt as an example

Since I'm using ChatGLM-6B as my finetuned LLM, I face a huge problem that the tokenizer is a sentencepiece based one with poor performance compared to dictionary based tokenizer. So I need to preprocess the large text corpus and save them to parquet format to save my training/evaluation time.

The following I'll stick to use the model THUDM/chatglm-6b as example.

```
python easy_dataset.py \
--input_file toys_sft.txt \
--output_file toys_sft \
--need_trust_code \
--tokenizer_name THUDM/chatglm-6b \
--dataset_type easy_sft
```
The above command is used to generate binary dataset used by sft training. After finishing the command, we will get a directory 'toys_sft' which contains the binary data in SFT tuning.

Run the training sft script to get the stage-1 sft model.
```
sh train_peft_sft.sh
```

After several seconds, you will get a new directory toys_sft_lora. Only LoRa parameters are stored inside. Your stage-1 is finished!

## Stage-2 Reward model training

In theory, it's unnecessary that the reward model and SFT model share the same pretrained model. However the reality is they have to be the same.

The reason is that in Stage-3 PPO training, the actions are actually the generated ids. These ids has to be evaluated by Critic/Reward model(They are same when the PPO training starts). If reward model and SFT model(aka Actor)'s tokenizers are not identical, the ids generated by SFT mean nothing to Reward model.

Sometimes the same tokenizer might be shared by different models, but in my case, I can't find smaller model sharing the tokenizer with ChatGLM-6B. Even the same model with different scales, it's not garaunteed that share the same tokenizer.

So I use the same pretrained model to train reward model.

We first generate reward binary dataset:
```
python easy_dataset.py \
--input_file toys_rewards_train.jsonl \
--output_file toys_rewards_train \
--need_trust_code \
--tokenizer_name THUDM/chatglm-6b \
--dataset_type reward
```
```
python easy_dataset.py \
--input_file toys_rewards_test.jsonl \
--output_file toys_rewards_test \
--need_trust_code \
--tokenizer_name THUDM/chatglm-6b \
--dataset_type reward
```

Now we get two new directories: toys_rewards_train and toys_rewards_test which contains our data to be used by stage-2 training.

Now we can run the following script to train reward model:
```
sh train_reward_model.sh
```

After seconds of training, we'll get a new directory chatglmrm which also contains only LoRa parameter. Only several megabytes only!

We get the SFT and Reward model training models, then let's move to the final stage-3 PPO training.

For stage-3 rlhf training, call train_peft_prompts.py.
Its arguments are almost idential to train_prompts.py. The only difference is that I use text files to indicate the prompt and pretrained data file. The models are included in easy_models.py. Currently only bloom models are tested, but technically gpt2/opt/llama should be supported.
## Stage 3 PPO Training

PPO training is a difficult process because it has to maintain four models(Actor, Initial model, Critic, Reward) in memory at the same time. I modified the initial trainer that support 6B model training in one 3090ti with 24G VRam.

First we'll prepare the training data. We need two text files which are used as prompts and LLM's pretrain file.

In the directory, toys_prompts.txt and toys_pretrain.txt are these two files. Actually they have the same type of contents, prompt file provid prompts for Actor to generate samples, and pretrain file just keep the Actor prameters moving so farwary from original model.

Run the following scripts:
```
sh train_prompts.sh
```

After several minitues, we will get our first RLHF LLM which is stored in a new directory lora_ppo! Again we only store the LoRa parameters, it's easy and the same usage like standard peft!


---

# Dataformat
Please refer the formats in test_sft.txt, test_prompts.txt, test_pretrained.txt.
Please refer the formats in toys_pretrain.txt,toys_sft.txt, toys_prompts.txt,toys_rewards_test.jsonl, toys_reward_train.jsonl.

It's easy to use your own data to train all of 3 stages model.
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import json
import os
import sys
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--input_file', type=str, default=None, help='input file')
parser.add_argument('--output_file', type=str, default=None, help='output file')
args = parser.parse_args()
with open(args.input_file, 'r',encoding='UTF-8') as f:
with open(args.output_file, 'w',encoding='UTF-8') as f2:
for line in f:
line = line.strip()
data = json.loads(line)
prompt = data['prompt']
response = data['chosen']
instructions = "请根据法律知识回答。提问:"+prompt+"\t回答\t"+response+'\n'
f2.write(instructions)
Loading