Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions applications/ChatGPT/chatgpt/models/deberta/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .deberta_critic import DebertaCritic
from .deberta_rm import DebertaRM

__all__ = ['DebertaCritic', 'DebertaRM']
36 changes: 36 additions & 0 deletions applications/ChatGPT/chatgpt/models/deberta/deberta_critic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Optional

import torch.nn as nn
from transformers import DebertaV2Config, DebertaV2Model

from ..base import Critic


class DebertaCritic(Critic):
    """Critic model backed by a DeBERTa-v2 transformer.

    The backbone is resolved in priority order: a pretrained checkpoint,
    then an explicit config, then the library-default ``DebertaV2Config``.

    Args:
        pretrained (str, optional): Pretrained model name or path.
        config (DebertaV2Config, optional): Model config, used only when
            ``pretrained`` is not given.
        checkpoint (bool): Enable gradient checkpointing on the backbone.
        lora_rank (int): Rank of the LoRA decomposition (0 disables LoRA).
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[DebertaV2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        # Pick the backbone: pretrained checkpoint > explicit config > defaults.
        if pretrained is not None:
            backbone = DebertaV2Model.from_pretrained(pretrained)
        elif config is not None:
            backbone = DebertaV2Model(config)
        else:
            backbone = DebertaV2Model(DebertaV2Config())
        if checkpoint:
            backbone.gradient_checkpointing_enable()
        # Scalar value head mapping the backbone hidden state to one value.
        head = nn.Linear(backbone.config.hidden_size, 1)
        super().__init__(backbone, head, lora_rank, lora_train_bias)
37 changes: 37 additions & 0 deletions applications/ChatGPT/chatgpt/models/deberta/deberta_rm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Optional

import torch.nn as nn
from transformers import DebertaV2Config, DebertaV2Model

from ..base import RewardModel


class DebertaRM(RewardModel):
    """Reward model backed by a DeBERTa-v2 transformer.

    The backbone is resolved in priority order: a pretrained checkpoint,
    then an explicit config, then the library-default ``DebertaV2Config``.

    Args:
        pretrained (str, optional): Pretrained model name or path.
        config (DebertaV2Config, optional): Model config, used only when
            ``pretrained`` is not given.
        checkpoint (bool): Enable gradient checkpointing on the backbone.
        lora_rank (int): Rank of the LoRA decomposition (0 disables LoRA).
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,  # fix: annotation was `str = None`
                 config: Optional[DebertaV2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        # Pick the backbone: pretrained checkpoint > explicit config > defaults.
        if pretrained is not None:
            model = DebertaV2Model.from_pretrained(pretrained)
        elif config is not None:
            model = DebertaV2Model(config)
        else:
            model = DebertaV2Model(DebertaV2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        # Small-variance init for the reward head: std = 1 / (hidden_size + 1).
        value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
        super().__init__(model, value_head, lora_rank, lora_train_bias)
1 change: 1 addition & 0 deletions applications/ChatGPT/examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pandas>=1.4.1
sentencepiece
6 changes: 6 additions & 0 deletions applications/ChatGPT/examples/test_ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,10 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
--test True --lora_rank 4

torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'microsoft/deberta-v3-large' --model 'deberta' \
--strategy colossalai_zero2 --loss_fn 'log_sig'\
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
--test True --lora_rank 4

rm -rf ${BASE}/rm_ckpt.pt
9 changes: 7 additions & 2 deletions applications/ChatGPT/examples/train_reward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
from chatgpt.models.bloom import BLOOMRM
from chatgpt.models.gpt import GPTRM
from chatgpt.models.opt import OPTRM
from chatgpt.models.deberta import DebertaRM
from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset
from random import randint
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam
Expand All @@ -39,6 +40,8 @@ def train(args):
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
elif args.model == 'gpt2':
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
elif args.model == 'deberta':
model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
else:
raise ValueError(f'Unsupported model "{args.model}"')

Expand All @@ -54,6 +57,8 @@ def train(args):
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
elif args.model == 'deberta':
tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large')
else:
raise ValueError(f'Unsupported model "{args.model}"')
max_len = args.max_len
Expand Down Expand Up @@ -119,7 +124,7 @@ def train(args):
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta'], default='bloom')
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--model_path', type=str, default=None)
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
Expand Down
4 changes: 2 additions & 2 deletions applications/ChatGPT/examples/train_rm.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
set_n_least_used_CUDA_VISIBLE_DEVICES 1

python train_reward_model.py --pretrain '/home/lczht/data2/bloom-560m' \
--model 'bloom' \
python train_reward_model.py --pretrain 'microsoft/deberta-v3-large' \
--model 'deberta' \
--strategy naive \
--loss_fn 'log_exp'\
--save_path 'rmstatic.pt' \
Expand Down