From 3aa306d7f8373ef5cae2f8e949f4604c5bdce839 Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:18:15 +0800 Subject: [PATCH 01/18] add normalize function to value_head in bloom rm --- applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py b/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py index 4dc2646e36ae..2dba227ff7d0 100644 --- a/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py +++ b/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py @@ -33,4 +33,5 @@ def __init__(self, if checkpoint: model.gradient_checkpointing_enable() value_head = nn.Linear(model.config.hidden_size, 1) + value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1)) super().__init__(model, value_head, lora_rank, lora_train_bias) From 9f19beadb49ed564382f1caed05c7cbf91e5b808 Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:19:41 +0800 Subject: [PATCH 02/18] add normalization to value_function in gpt_rm --- applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py b/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py index 0132dbf27ffc..19d673de6825 100644 --- a/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py +++ b/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py @@ -35,4 +35,5 @@ def __init__(self, model.gradient_checkpointing_enable() value_head = nn.Linear(model.config.n_embd, 1) + value_head.weight.data.normal_(mean=0.0, std=1/(model.config.n_embd + 1)) super().__init__(model, value_head, lora_rank, lora_train_bias) From 9cd2a67ea360a0e0fd28b0d6732ad21048e4ce31 Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:20:59 +0800 Subject: [PATCH 03/18] add normalization to value_head of opt_rm --- applications/ChatGPT/chatgpt/models/opt/opt_rm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/applications/ChatGPT/chatgpt/models/opt/opt_rm.py b/applications/ChatGPT/chatgpt/models/opt/opt_rm.py index 7ad7b3887e53..ef7f0fb16fd1 100644 --- a/applications/ChatGPT/chatgpt/models/opt/opt_rm.py +++ b/applications/ChatGPT/chatgpt/models/opt/opt_rm.py @@ -34,4 +34,5 @@ def __init__(self, model.gradient_checkpointing_enable() value_head = nn.Linear(model.config.word_embed_proj_dim, 1) + value_head.weight.data.normal_(mean=0.0, std=1/(model.config.word_embed_proj_dim + 1)) super().__init__(model, value_head, lora_rank, lora_train_bias) From 087eff422dd93ada6953471ca16271bf7877d65b Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:24:13 +0800 Subject: [PATCH 04/18] add Anthropic/hh-rlhf dataset --- .../ChatGPT/chatgpt/dataset/reward_dataset.py | 65 +++++++++++++++++-- 1 file changed, 60 insertions(+), 5 deletions(-) diff --git a/applications/ChatGPT/chatgpt/dataset/reward_dataset.py b/applications/ChatGPT/chatgpt/dataset/reward_dataset.py index 8bc850f2d52d..a86697cf91e8 100644 --- a/applications/ChatGPT/chatgpt/dataset/reward_dataset.py +++ b/applications/ChatGPT/chatgpt/dataset/reward_dataset.py @@ -5,8 +5,8 @@ from .utils import is_rank_0 - -class RewardDataset(Dataset): +# Dahaos/rm-static +class RmStaticDataset(Dataset): """ Dataset for reward model @@ -14,16 +14,71 @@ class RewardDataset(Dataset): dataset: dataset for reward model tokenizer: tokenizer for reward model 
max_length: max length of input + special_token: special token at the end of sentence """ - def __init__(self, dataset, tokenizer: Callable, max_length: int) -> None: + def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: super().__init__() self.chosen = [] self.reject = [] + if special_token is None: + self.end_token = tokenizer.eos_token + else: + self.end_token = special_token for data in tqdm(dataset, disable=not is_rank_0()): prompt = data['prompt'] - chosen = prompt + data['chosen'] + "<|endoftext|>" + chosen = prompt + data['chosen'] + self.end_token + chosen_token = tokenizer(chosen, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.chosen.append({ + "input_ids": chosen_token['input_ids'], + "attention_mask": chosen_token['attention_mask'] + }) + + reject = prompt + data['rejected'] + self.end_token + reject_token = tokenizer(reject, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.reject.append({ + "input_ids": reject_token['input_ids'], + "attention_mask": reject_token['attention_mask'] + }) + + def __len__(self): + length = len(self.chosen) + return length + + def __getitem__(self, idx): + return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ + "input_ids"], self.reject[idx]["attention_mask"] + +# Anthropic/hh-rlhf +class HhRlhfDataset(Dataset): + """ + Dataset for reward model + + Args: + dataset: dataset for reward model + tokenizer: tokenizer for reward model + max_length: max length of input + special_token: special token at the end of sentence + """ + def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: + super().__init__() + self.chosen = [] + self.reject = [] + if special_token is None: + self.end_token = tokenizer.eos_token + else: + self.end_token = special_token + for data in tqdm(dataset, disable=not is_rank_0()): + chosen = data['chosen'] + self.end_token chosen_token = tokenizer(chosen, max_length=max_length, padding="max_length", @@ -34,7 +89,7 @@ def __init__(self, dataset, tokenizer: Callable, max_length: int) -> None: "attention_mask": chosen_token['attention_mask'] }) - reject = prompt + data['rejected'] + "<|endoftext|>" + reject = data['rejected'] + self.end_token reject_token = tokenizer(reject, max_length=max_length, padding="max_length", From bb64cc56f86bcb067cc3af7fd46eec1897929deb Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:24:59 +0800 Subject: [PATCH 05/18] Update __init__.py --- applications/ChatGPT/chatgpt/dataset/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/applications/ChatGPT/chatgpt/dataset/__init__.py b/applications/ChatGPT/chatgpt/dataset/__init__.py index b4599c82ba75..83393098775f 100644 --- a/applications/ChatGPT/chatgpt/dataset/__init__.py +++ b/applications/ChatGPT/chatgpt/dataset/__init__.py @@ -1,4 +1,4 @@ -from .reward_dataset import RewardDataset +from .reward_dataset import RmStaticDataset, HhRlhfDataset from .utils import is_rank_0 -__all__ = ['RewardDataset', 'is_rank_0'] +__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0'] From cf9b1dfc4a0d1bd773627598697a30190690f948 Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:26:27 +0800 Subject: [PATCH 06/18] Add LogExpLoss in RM training --- applications/ChatGPT/chatgpt/models/loss.py | 14 ++++++++++++-- 1 
file changed, 12 insertions(+), 2 deletions(-) diff --git a/applications/ChatGPT/chatgpt/models/loss.py b/applications/ChatGPT/chatgpt/models/loss.py index 0ebcfea061b0..c5b1ccc93228 100644 --- a/applications/ChatGPT/chatgpt/models/loss.py +++ b/applications/ChatGPT/chatgpt/models/loss.py @@ -93,13 +93,23 @@ def forward(self, return policy_loss + self.pretrain_coef * lm_loss -class PairWiseLoss(nn.Module): +class LogSigLoss(nn.Module): """ Pairwise Loss for Reward Model + Details: https://arxiv.org/abs/2203.02155 """ - def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor: probs = torch.sigmoid(chosen_reward - reject_reward) log_probs = torch.log(probs) loss = -log_probs.mean() return loss + + +class LogExpLoss(nn.Module): + """ + Pairwise Loss for Reward Model + Details: https://arxiv.org/abs/2204.05862 + """ + def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor: + loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean() + return loss From 32cf672f7db787e4a171a9ca8c143835efa0266f Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:27:53 +0800 Subject: [PATCH 07/18] Update __init__.py --- applications/ChatGPT/chatgpt/models/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/applications/ChatGPT/chatgpt/models/__init__.py b/applications/ChatGPT/chatgpt/models/__init__.py index 376fed8de792..b274188a21df 100644 --- a/applications/ChatGPT/chatgpt/models/__init__.py +++ b/applications/ChatGPT/chatgpt/models/__init__.py @@ -1,4 +1,4 @@ from .base import Actor, Critic, RewardModel -from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss +from .loss import PolicyLoss, PPOPtxActorLoss, ValueLoss, LogSigLoss, LogExpLoss -__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'PairWiseLoss'] +__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss'] From 450f708d62f839dbb9b659e73aa538358d42eaf2 Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:32:16 +0800 Subject: [PATCH 08/18] update rm trainer to use acc as target --- applications/ChatGPT/chatgpt/trainer/rm.py | 111 +++++++++++++-------- 1 file changed, 69 insertions(+), 42 deletions(-) diff --git a/applications/ChatGPT/chatgpt/trainer/rm.py b/applications/ChatGPT/chatgpt/trainer/rm.py index c07d65f84ca5..7fa87a64968b 100644 --- a/applications/ChatGPT/chatgpt/trainer/rm.py +++ b/applications/ChatGPT/chatgpt/trainer/rm.py @@ -1,13 +1,12 @@ from abc import ABC - +import pandas as pd import loralib as lora import torch -from chatgpt.dataset import RewardDataset -from chatgpt.models.loss import PairWiseLoss -from torch.optim import Adam, Optimizer -from torch.utils.data import DataLoader +from datetime import datetime +from torch.optim import Optimizer, lr_scheduler +from torch.utils.data import DataLoader, Dataset from tqdm import tqdm - + from .strategies import Strategy from .utils import is_rank_0 @@ -20,11 +19,12 @@ class RewardModelTrainer(ABC): model (torch.nn.Module): the model to train strategy (Strategy): the strategy to use for training optim(Optimizer): the optimizer to use for training - train_dataset (RewardDataset): the dataset to use for training - eval_dataset (RewardDataset): the dataset to use for evaluation + loss_fn (callable): the loss function to use for training + train_dataset 
(Dataset): the dataset to use for training + valid_dataset (Dataset): the dataset to use for validation + eval_dataset (Dataset): the dataset to use for evaluation batch_size (int, defaults to 1): the batch size while training max_epochs (int, defaults to 2): the number of epochs to train - optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer """ def __init__( @@ -32,24 +32,52 @@ def __init__( model, strategy: Strategy, optim: Optimizer, - train_dataset: RewardDataset, - eval_dataset: RewardDataset, + loss_fn, + train_dataset: Dataset, + valid_dataset: Dataset, + eval_dataset: Dataset, batch_size: int = 1, - max_epochs: int = 2, + max_epochs: int = 1, ) -> None: super().__init__() self.strategy = strategy self.epochs = max_epochs - self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size) - self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size) - + self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + self.valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True) + self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True) + self.model = strategy.setup_model(model) - if "DDP" in str(self.strategy): - self.model = self.model.module - self.loss_fn = PairWiseLoss() + self.loss_fn = loss_fn self.optimizer = strategy.setup_optimizer(optim, self.model) + self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__()//100) + - def fit(self, use_lora): + def eval_acc(self, dataloader): + dist = 0 + on = 0 + cnt = 0 + self.model.eval() + with torch.no_grad(): + for chosen_ids, c_mask, reject_ids, r_mask in dataloader: + chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) + c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) + reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) + r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) + chosen_reward = self.model(chosen_ids, attention_mask=c_mask) + reject_reward = self.model(reject_ids, attention_mask=r_mask) + for i in range(len(chosen_reward)): + cnt += 1 + if chosen_reward[i] > reject_reward[i]: + on += 1 + dist += (chosen_reward - reject_reward).mean().item() + dist_mean = dist / len(dataloader) + acc = on / cnt + self.model.train() + return dist_mean, acc + + + def fit(self): + time = datetime.now() epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0()) for epoch in range(self.epochs): step_bar = tqdm(range(self.train_dataloader.__len__()), @@ -57,37 +85,36 @@ def fit(self, use_lora): disable=not is_rank_0()) # train self.model.train() + cnt = 0 + acc = 0 + dist = 0 for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader: - chosen_ids = chosen_ids.squeeze(1).cuda() - c_mask = c_mask.squeeze(1).cuda() - reject_ids = reject_ids.squeeze(1).cuda() - r_mask = r_mask.squeeze(1).cuda() + chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) + c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) + reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) + r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) chosen_reward = self.model(chosen_ids, attention_mask=c_mask) reject_reward = self.model(reject_ids, attention_mask=r_mask) loss = self.loss_fn(chosen_reward, reject_reward) self.strategy.backward(loss, self.model, self.optimizer) self.strategy.optimizer_step(self.optimizer) self.optimizer.zero_grad() + cnt += 1 + if cnt == 100: + 
self.scheduler.step() + dist, acc = self.eval_acc(self.valid_dataloader) + cnt = 0 + if is_rank_0(): + log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc']) + log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False) step_bar.update() - step_bar.set_postfix({'loss': loss.item()}) - + step_bar.set_postfix({'dist': dist, 'acc': acc}) + # eval - self.model.eval() - with torch.no_grad(): - dist = 0 - loss_sum = 0 - for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader: - chosen_ids = chosen_ids.squeeze(1).cuda() - c_mask = c_mask.squeeze(1).cuda() - reject_ids = reject_ids.squeeze(1).cuda() - r_mask = r_mask.squeeze(1).cuda() - chosen_reward = self.model(chosen_ids, attention_mask=c_mask) - reject_reward = self.model(reject_ids, attention_mask=r_mask) - dist += (chosen_reward - reject_reward).mean().item() - loss = self.loss_fn(chosen_reward, reject_reward) - loss_sum += loss.item() - dist_mean = dist / self.eval_dataloader.__len__() - loss_mean = loss_sum / self.eval_dataloader.__len__() + dist, acc = self.eval_acc(self.eval_dataloader) + if is_rank_0(): + log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc']) + log.to_csv('log.csv', mode='a', header=False, index=False) epoch_bar.update() - step_bar.set_postfix({'loss': loss_mean, 'dist_mean': dist_mean}) + step_bar.set_postfix({'dist': dist, 'acc': acc}) step_bar.close() From e7cb711216e8f480e19f513c708f247d8e145aaa Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:38:46 +0800 Subject: [PATCH 09/18] update example/train_rm --- .../ChatGPT/examples/train_reward_model.py | 96 +++++++++++++------ 1 file changed, 68 insertions(+), 28 deletions(-) diff --git a/applications/ChatGPT/examples/train_reward_model.py b/applications/ChatGPT/examples/train_reward_model.py index 19b20b0847cc..7c8f8b09eaf2 100644 --- a/applications/ChatGPT/examples/train_reward_model.py +++ b/applications/ChatGPT/examples/train_reward_model.py @@ -2,7 +2,8 @@ import loralib as lora import torch -from chatgpt.dataset import RewardDataset +from chatgpt.dataset import HhRlhfDataset, RmStaticDataset +from chatgpt.model import LogSigLoss, LogExpLoss from chatgpt.models.base import RewardModel from chatgpt.models.bloom import BLOOMRM from chatgpt.models.gpt import GPTRM @@ -10,6 +11,7 @@ from chatgpt.trainer import RewardModelTrainer from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from datasets import load_dataset +from random import randint from torch.optim import Adam from transformers import AutoTokenizer, BloomTokenizerFast from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer @@ -33,69 +35,107 @@ def train(args): # configure model with strategy.model_init_context(): if args.model == 'bloom': - model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() + model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) elif args.model == 'opt': - model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() + model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) elif args.model == 'gpt2': - model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() + model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) else: raise ValueError(f'Unsupported model "{args.model}"') - + + if args.model_path is 
not None: + state_dict = torch.load(args.model_path) + model.load_state_dict(state_dict) + # configure tokenizer if args.model == 'gpt2': tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - tokenizer.pad_token = tokenizer.eos_token elif args.model == 'bloom': tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) - tokenizer.pad_token = tokenizer.eos_token elif args.model == 'opt': tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + elif args.model == 'deberta': + tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large") else: raise ValueError(f'Unsupported model "{args.model}"') - tokenizer.pad_token = tokenizer.eos_token - - max_len = 512 + + max_len = args.max_len # configure optimizer if args.strategy.startswith('colossalai'): - optim = HybridAdam(model.parameters(), lr=5e-5) + optim = HybridAdam(model.parameters(), lr=1.5e-5) else: - optim = Adam(model.parameters(), lr=5e-5) - + optim = Adam(model.parameters(), lr=1.5e-5) + + # configure loss function + if args.loss_fn == 'log_sig': + loss_fn = LogSigLoss() + elif args.loss_fn == 'log_exp': + loss_fn = LogExpLoss() + else: + raise ValueError(f'Unsupported loss function "{args.loss_fn}"') + # prepare for data and dataset - data = load_dataset(args.dataset) - train_data = data["train"] - eval_data = data['test'] - train_dataset = RewardDataset(train_data, tokenizer, max_len) - eval_dataset = RewardDataset(eval_data, tokenizer, max_len) - + if args.subset is not None: + data = load_dataset(args.dataset, data_dir=args.subset) + else: + data = load_dataset(args.dataset) + + if args.test: + train_data = data['train'].select(range(100)) + eval_data = data['test'].select(range(10)) + else: + train_data = data['train'] + eval_data = data['test'] + valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data)//10))) + + if args.dataset == 'Dahoas/rm-static': + train_dataset = RmStaticDataset(train_data, tokenizer, max_len) + valid_dataset = RmStaticDataset(valid_data, tokenizer, max_len) + eval_dataset = RmStaticDataset(eval_data, tokenizer, max_len) + elif args.dataset == 'Anthropic/hh-rlhf': + train_dataset = HhRlhfDataset(train_data, tokenizer, max_len) + valid_dataset = HhRlhfDataset(valid_data, tokenizer, max_len) + eval_dataset = HhRlhfDataset(eval_data, tokenizer, max_len) + else: + raise ValueError(f'Unsupported dataset "{args.dataset}"') + trainer = RewardModelTrainer(model=model, strategy=strategy, optim=optim, + loss_fn = loss_fn, train_dataset=train_dataset, + valid_dataset=valid_dataset, eval_dataset=eval_dataset, batch_size=args.batch_size, max_epochs=args.max_epochs) - trainer.fit(use_lora=args.lora_rank) - + trainer.fit() # save model checkpoint after fitting on only rank0 - strategy.save_model(model, 'rm_checkpoint.pt', only_rank0=True) + strategy.save_model(trainer.model, args.save_path, only_rank0=True) # save optimizer checkpoint on all ranks - strategy.save_optimizer(optim, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False) - + if args.need_optim_ckpt: + strategy.save_optimizer(trainer.optimizer, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--strategy', choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='naive') - parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom') + parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta'], default='bloom') 
parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--dataset', type=str, default='Dahoas/rm-static') - parser.add_argument('--save_path', type=str, default='rm_ckpt.pth') + parser.add_argument('--model_path', type=str, default=None) + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--dataset', type=str, + choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'], + default='Dahoas/rm-static') + parser.add_argument('--subset', type=str, default=None) + parser.add_argument('--save_path', type=str, default='rm_ckpt.pt') parser.add_argument('--max_epochs', type=int, default=1) - parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--batch_size', type=int, default=8) + parser.add_argument('--max_len', type=int, default=512) parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp']) + parser.add_argument('--test', type=bool, default=False) args = parser.parse_args() train(args) From cbf16c135958b27326119480fd75a06cb7cbcf1d Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:40:05 +0800 Subject: [PATCH 10/18] Update train_rm.sh --- applications/ChatGPT/examples/train_rm.sh | 26 ++++++----------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/applications/ChatGPT/examples/train_rm.sh b/applications/ChatGPT/examples/train_rm.sh index 6e11a148bfbe..981b7a15fcd4 100755 --- a/applications/ChatGPT/examples/train_rm.sh +++ b/applications/ChatGPT/examples/train_rm.sh @@ -1,20 +1,8 @@ -set_n_least_used_CUDA_VISIBLE_DEVICES() { - local n=${1:-"9999"} - echo "GPU Memory Usage:" - local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ - | tail -n +2 \ - | nl -v 0 \ - | tee /dev/tty \ - | sort -g -k 2 \ - | awk '{print $1}' \ - | head -n $n) - export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') - echo "Now CUDA_VISIBLE_DEVICES is set to:" - echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" -} +set_n_least_used_CUDA_VISIBLE_DEVICES 1 -set_n_least_used_CUDA_VISIBLE_DEVICES 2 - -# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2 -torchrun --standalone --nproc_per_node=2 train_reward_model.py --model 'gpt2' --strategy colossalai_zero2 -# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2 +python train_reward_model.py --pretrain '/home/lczht/data2/bloom-560m' \ + --model 'bloom' \ + --strategy naive \ + --loss_fn 'log_exp'\ + --save_path 'rmstatic.pt' \ + --test True From 1a16bd2efa66aacfc0947e974b3e0605497eca35 Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:47:20 +0800 Subject: [PATCH 11/18] code style --- applications/ChatGPT/chatgpt/dataset/reward_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/applications/ChatGPT/chatgpt/dataset/reward_dataset.py b/applications/ChatGPT/chatgpt/dataset/reward_dataset.py index a86697cf91e8..9ee13490b893 100644 --- a/applications/ChatGPT/chatgpt/dataset/reward_dataset.py +++ b/applications/ChatGPT/chatgpt/dataset/reward_dataset.py @@ -22,7 +22,7 @@ def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token= self.chosen = [] self.reject = [] if 
special_token is None: - self.end_token = tokenizer.eos_token + self.end_token = tokenizer.eos_token else: self.end_token = special_token for data in tqdm(dataset, disable=not is_rank_0()): @@ -74,7 +74,7 @@ def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token= self.chosen = [] self.reject = [] if special_token is None: - self.end_token = tokenizer.eos_token + self.end_token = tokenizer.eos_token else: self.end_token = special_token for data in tqdm(dataset, disable=not is_rank_0()): From 4c945fba5e6d0aef0077c976695fb4f85c37e8f7 Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Wed, 15 Mar 2023 17:10:06 +0800 Subject: [PATCH 12/18] Update README.md --- applications/ChatGPT/examples/README.md | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/applications/ChatGPT/examples/README.md b/applications/ChatGPT/examples/README.md index 3876d20f02d7..7bea2f230b44 100644 --- a/applications/ChatGPT/examples/README.md +++ b/applications/ChatGPT/examples/README.md @@ -9,16 +9,24 @@ pip install -r requirements.txt ## Train the reward model (Stage 2) We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as dataset to train our reward model. It is a dataset of chosen & rejected response of the same prompt. -You can download the dataset from huggingface automatically. +You can download the datasets from huggingface automatically. Use these code to train your reward model. - ```shell -# Naive reward model training -python train_reward_model.py --pretrain --model --strategy naive +# Take naive reward model training with opt-350m as example +python train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy naive # use colossalai_zero2 -torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain --model --strategy colossalai_zero2 +torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy naive ``` +### About the update in 03/15 + +- [x] We support hh-rlhf dataset from [Anthropic](https://huggingface.co/datasets/Anthropic/hh-rlhf). +- [x] We support 2 kinds of loss_function named 'log_sig'(used by OpenAI) and 'log_exp'(used by Anthropic). +- [x] We change the loss to valid_acc and pair_dist to monitor progress during training. +- [x] We add special token to the end of the sequence to get better result. +- [x] We train a Bloom-560m reward model and find the test acc of the model achieve the performance mentions in [Anthropics paper](https://arxiv.org/abs/2112.00861). +### Experiment result +TODO ## Train with dummy prompt data (Stage 3) From a2a269e3453693c6097eaff16152d9f0e337f721 Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Wed, 15 Mar 2023 19:13:32 +0800 Subject: [PATCH 13/18] Update README.md --- applications/ChatGPT/examples/README.md | 41 ++++++++++++++++--------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/applications/ChatGPT/examples/README.md b/applications/ChatGPT/examples/README.md index 7bea2f230b44..ce73a5407944 100644 --- a/applications/ChatGPT/examples/README.md +++ b/applications/ChatGPT/examples/README.md @@ -7,34 +7,42 @@ pip install -r requirements.txt ``` ## Train the reward model (Stage 2) -We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as dataset to train our reward model. It is a dataset of chosen & rejected response of the same prompt. 
- -You can download the datasets from huggingface automatically. - Use these code to train your reward model. ```shell # Take naive reward model training with opt-350m as example python train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy naive # use colossalai_zero2 -torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy naive +torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2 ``` -### About the update in 03/15 -- [x] We support hh-rlhf dataset from [Anthropic](https://huggingface.co/datasets/Anthropic/hh-rlhf). -- [x] We support 2 kinds of loss_function named 'log_sig'(used by OpenAI) and 'log_exp'(used by Anthropic). -- [x] We change the loss to valid_acc and pair_dist to monitor progress during training. -- [x] We add special token to the end of the sequence to get better result. -- [x] We train a Bloom-560m reward model and find the test acc of the model achieve the performance mentions in [Anthropics paper](https://arxiv.org/abs/2112.00861). +### Features and tricks in RM training +- We support [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)and[rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets. +- We support 2 kinds of loss_function named 'log_sig'(used by OpenAI) and 'log_exp'(used by Anthropic). +- We change the loss to valid_acc and pair_dist to monitor progress during training. +- We add special token to the end of the sequence to get better result. +- We use cosine-reducing lr-scheduler for RM training. +- We set value_head as 1 liner layer and initialize the weight of value_head using N(0,1/(d_model + 1)) distribution. +- We train a Bloom-560m reward model for 1 epoch and find the test acc of the model achieve the performance mentions in [Anthropics paper](https://arxiv.org/abs/2112.00861). + ### Experiment result -TODO +Model performance in [Anthropics paper](https://arxiv.org/abs/2112.00861): + +
[image] + +
Our training & test results for bloom-560m trained for 1 epoch: + +
[image] + +
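Note that the two supported loss functions, `LogSigLoss` and `LogExpLoss`, compute the same quantity, since -log(sigmoid(x)) = log(1 + exp(-x)); a minimal standalone check (the reward values below are made up for illustration):

```python
import torch

# Made-up chosen/rejected rewards, only to illustrate the equivalence.
chosen_reward = torch.tensor([1.2, 0.3, -0.5])
reject_reward = torch.tensor([0.4, 0.9, -1.0])

delta = chosen_reward - reject_reward
log_sig = -torch.log(torch.sigmoid(delta)).mean()    # LogSigLoss, arXiv:2203.02155
log_exp = torch.log(1 + torch.exp(-delta)).mean()    # LogExpLoss, arXiv:2204.05862

assert torch.allclose(log_sig, log_exp)
```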
## Train with dummy prompt data (Stage 3) -This script supports 3 strategies: +This script supports 4 kinds of strategies: - naive - ddp -- colossalai +- colossalai_zero2 +- colossalai_gemini It uses random generated prompt data. @@ -61,7 +69,7 @@ We use [awesome-chatgpt-prompts](https://huggingface.co/datasets/fka/awesome-cha You should download `prompts.csv` first. -This script also supports 3 strategies. +This script also supports 4 strategies. ```shell # display cli help @@ -83,6 +91,9 @@ python inference.py --model_path --model Date: Thu, 16 Mar 2023 09:13:02 +0800 Subject: [PATCH 14/18] add rm test to ci --- applications/ChatGPT/examples/test_ci.sh | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/applications/ChatGPT/examples/test_ci.sh b/applications/ChatGPT/examples/test_ci.sh index 0aa4a36fe514..734ed93c3dc7 100755 --- a/applications/ChatGPT/examples/test_ci.sh +++ b/applications/ChatGPT/examples/test_ci.sh @@ -69,3 +69,17 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \ python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'gpt2' --model gpt2 rm -rf ${BASE}/actor_checkpoint_prompts.pt + +# train rm +torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ + --pretrain 'facebook/opt-350m' --model 'opt' \ + --strategy colossalai_zero2 --loss_fn 'log_sig'\ + --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\ + --test True --lora_rank 4 + +torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ + --pretrain 'gpt2' --model 'gpt2' \ + --strategy colossalai_gemini --loss_fn 'log_exp'\ + --dataset 'Dahoas/rm-static' --test True --lora_rank 4 + +rm -rf ${BASE}/rm_ckpt.pt From 13568b809c97f8bc92b1a9a3848be0be5f67d5df Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Thu, 16 Mar 2023 09:15:01 +0800 Subject: [PATCH 15/18] fix tokenier --- applications/ChatGPT/examples/train_reward_model.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/applications/ChatGPT/examples/train_reward_model.py b/applications/ChatGPT/examples/train_reward_model.py index 7c8f8b09eaf2..6baff8662314 100644 --- a/applications/ChatGPT/examples/train_reward_model.py +++ b/applications/ChatGPT/examples/train_reward_model.py @@ -18,7 +18,6 @@ from colossalai.nn.optimizer import HybridAdam - def train(args): # configure strategy if args.strategy == 'naive': @@ -50,15 +49,13 @@ def train(args): # configure tokenizer if args.model == 'gpt2': tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token elif args.model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') elif args.model == 'opt': tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - elif args.model == 'deberta': - tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large") else: raise ValueError(f'Unsupported model "{args.model}"') - max_len = args.max_len # configure optimizer @@ -122,7 +119,7 @@ def train(args): parser.add_argument('--strategy', choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='naive') - parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta'], default='bloom') + parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom') parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--model_path', type=str, default=None) 
parser.add_argument('--need_optim_ckpt', type=bool, default=False) From eddd6e380c6ad19bc547185dd2f29ed3a910d9ba Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Thu, 16 Mar 2023 10:37:15 +0800 Subject: [PATCH 16/18] fix typo --- applications/ChatGPT/examples/train_reward_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ChatGPT/examples/train_reward_model.py b/applications/ChatGPT/examples/train_reward_model.py index 6baff8662314..3fc698f33c5e 100644 --- a/applications/ChatGPT/examples/train_reward_model.py +++ b/applications/ChatGPT/examples/train_reward_model.py @@ -3,7 +3,7 @@ import loralib as lora import torch from chatgpt.dataset import HhRlhfDataset, RmStaticDataset -from chatgpt.model import LogSigLoss, LogExpLoss +from chatgpt.models import LogSigLoss, LogExpLoss from chatgpt.models.base import RewardModel from chatgpt.models.bloom import BLOOMRM from chatgpt.models.gpt import GPTRM From 915ae5ef1d2a6abd53a895ad16ec7477db7a4ca7 Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Mon, 20 Mar 2023 09:43:50 +0800 Subject: [PATCH 17/18] change batchsize to avoid oom in ci --- applications/ChatGPT/examples/train_reward_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ChatGPT/examples/train_reward_model.py b/applications/ChatGPT/examples/train_reward_model.py index 3fc698f33c5e..47dd988b8117 100644 --- a/applications/ChatGPT/examples/train_reward_model.py +++ b/applications/ChatGPT/examples/train_reward_model.py @@ -129,7 +129,7 @@ def train(args): parser.add_argument('--subset', type=str, default=None) parser.add_argument('--save_path', type=str, default='rm_ckpt.pt') parser.add_argument('--max_epochs', type=int, default=1) - parser.add_argument('--batch_size', type=int, default=8) + parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--max_len', type=int, default=512) parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp']) From 36a2b2e47139f1e2552c8c0977be7f360620a465 Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Mon, 20 Mar 2023 09:45:01 +0800 Subject: [PATCH 18/18] Update test_ci.sh --- applications/ChatGPT/examples/test_ci.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/applications/ChatGPT/examples/test_ci.sh b/applications/ChatGPT/examples/test_ci.sh index 734ed93c3dc7..abc43ab1ee9e 100755 --- a/applications/ChatGPT/examples/test_ci.sh +++ b/applications/ChatGPT/examples/test_ci.sh @@ -81,5 +81,11 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ --pretrain 'gpt2' --model 'gpt2' \ --strategy colossalai_gemini --loss_fn 'log_exp'\ --dataset 'Dahoas/rm-static' --test True --lora_rank 4 + +torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ + --pretrain 'bigscience/bloom-560m' --model 'bloom' \ + --strategy colossalai_zero2 --loss_fn 'log_sig'\ + --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\ + --test True --lora_rank 4 rm -rf ${BASE}/rm_ckpt.pt
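For completeness, a minimal sketch of reloading a trained reward-model checkpoint, assuming the file written by `strategy.save_model` is a plain state dict (the same assumption the `--model_path` branch of `train_reward_model.py` makes when it calls `torch.load` and `load_state_dict`); the paths and base model below are illustrative only:

```python
import torch
from chatgpt.models.bloom import BLOOMRM

# Rebuild the same reward-model class used during training, then load the saved weights.
model = BLOOMRM(pretrained='bigscience/bloom-560m')
state_dict = torch.load('rm_ckpt.pt', map_location='cpu')  # default --save_path
model.load_state_dict(state_dict)
model.eval()
```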