applications/ChatGPT/chatgpt/dataset/sft_dataset.py (6 additions & 1 deletion)
@@ -119,10 +119,15 @@ def preprocess(
 class AlpacaDataset(Dataset):
     """Dataset for supervised fine-tuning."""
 
-    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
+    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_length: int=None):
         super(AlpacaDataset, self).__init__()
         logger.info("Loading data...")
         list_data_dict = jload(data_path)
+        logger.info(f"Loaded {len(list_data_dict)} examples.")
+
+        if max_length is not None:
+            logger.info(f"Truncating data to max length {max_length}...")
+            list_data_dict = [example for example in list_data_dict if len(example["input"]) <= max_length]
 
         logger.info("Formatting inputs...")
         prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
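
Note on the new max_length argument: it filters on the character length of each example's raw "input" string and drops whole examples rather than truncating them at the token level. A minimal sketch of the intended call site follows; the JSON path, the 1024-character cutoff, and the OPT tokenizer are placeholders, not part of the PR.

import transformers
from chatgpt.dataset.sft_dataset import AlpacaDataset

# Placeholder tokenizer; train_sft.py builds the real one for the chosen --model.
tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-350m")

# Examples whose raw "input" field is longer than 1024 characters are dropped entirely.
train_dataset = AlpacaDataset(data_path="alpaca_data.json",
                              tokenizer=tokenizer,
                              max_length=1024)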

applications/ChatGPT/chatgpt/models/base/actor.py (3 additions)
@@ -60,3 +60,6 @@ def forward(self,
         logits = output['logits']
         log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
         return log_probs[:, -num_actions:]
+
+    def get_base_model(self):
+        return self.model

applications/ChatGPT/chatgpt/models/llama/llama_lm.py (2 additions)
@@ -36,3 +36,5 @@ def __init__(self,

         super().__init__(model, lora_rank, lora_train_bias)
 
+    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
+        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
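
With this forward, LlamaLM delegates straight to the wrapped Hugging Face causal LM, so a call with labels returns an output object whose .loss is the shifted next-token cross-entropy (positions labelled -100 are ignored) and whose .logits are the per-token scores. That is the contract the updated SFTTrainer below relies on. A self-contained illustration, using the tiny public sshleifer/tiny-gpt2 checkpoint purely as a stand-in for LLaMA:

import transformers

model = transformers.AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
tokenizer = transformers.AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")

enc = tokenizer("Below is an instruction that describes a task.", return_tensors="pt")
labels = enc["input_ids"].clone()   # prompt positions could be set to -100 to drop them from the loss
outputs = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], labels=labels)
print(outputs.loss)                 # scalar language-modelling loss
print(outputs.logits.shape)         # (batch, seq_len, vocab_size)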

applications/ChatGPT/chatgpt/trainer/sft.py (17 additions & 12 deletions)
@@ -61,13 +61,15 @@ def fit(self, logger, use_lora, log_interval=10):
             # train
             self.model.train()
             for batch_id, batch in enumerate(self.train_dataloader):
-                prompt_ids = batch["input_ids"]
-                p_mask = batch["attention_mask"]
-                labels = batch["labels"]
-                prompt_ids = prompt_ids.squeeze(1).cuda()
-                p_mask = p_mask.squeeze(1).cuda()
+                prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
+                p_mask = batch["attention_mask"].to(torch.cuda.current_device())
+                labels = batch["labels"].to(torch.cuda.current_device())
+                # prompt_ids = prompt_ids.squeeze(1).cuda()
+                # p_mask = p_mask.squeeze(1).cuda()
+                # prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
-                loss, prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
+                outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
+                loss = outputs.loss
+                prompt_logits = outputs.logits
 
                 # loss = self.loss_fn(prompt_logits, labels)
                 self.strategy.backward(loss, self.model, self.optimizer)
@@ -83,13 +85,16 @@ def fit(self, logger, use_lora, log_interval=10):
                     loss_sum = 0
                     num_seen = 0
                     for batch in self.eval_dataloader:
-                        prompt_ids = batch["input_ids"]
-                        p_mask = batch["attention_mask"]
-                        prompt_ids = prompt_ids.squeeze(1).cuda()
-                        p_mask = p_mask.squeeze(1).cuda()
+                        prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
+                        p_mask = batch["attention_mask"].to(torch.cuda.current_device())
+                        labels = batch["labels"].to(torch.cuda.current_device())
+                        # prompt_ids = prompt_ids.squeeze(1).cuda()
+                        # p_mask = p_mask.squeeze(1).cuda()
+
+                        outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
+                        loss = outputs.loss
+                        # prompt_logits = outputs.logits
+
-                        prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
-                        loss = self.loss_fn(prompt_logits, prompt_ids)
                         loss_sum += loss.item()
                         num_seen += prompt_ids.size(0)

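
Two things change in both loops: the loss is now read from the model output instead of being recomputed, and batches are moved with .to(torch.cuda.current_device()) rather than .squeeze(1).cuda(), presumably because the Alpaca collator already yields 2-D (batch, seq_len) tensors. A small helper equivalent to the repeated device moves; the CPU fallback is an addition so the sketch also runs without a GPU.

import torch

def to_current_device(batch: dict) -> dict:
    # Mirrors the repeated .to(torch.cuda.current_device()) calls in the diff.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{torch.cuda.current_device()}")
    else:
        device = torch.device("cpu")   # fallback for running the sketch without a GPU
    return {key: value.to(device) for key, value in batch.items()}

# Usage inside the training or evaluation loop:
# batch = to_current_device(batch)
# prompt_ids, p_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]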

@@ -9,7 +9,6 @@
 from chatgpt.models.lora import LoraLinear
 from torch.optim import Optimizer
 
-
 from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase


applications/ChatGPT/chatgpt/utils/tokenizer_utils.py (6 additions)
@@ -16,6 +16,8 @@

 import transformers
 
+from ..models.llama.llama_lm import LlamaLM
+
 DEFAULT_PAD_TOKEN = "[PAD]"
 DEFAULT_EOS_TOKEN = "</s>"
 DEFAULT_BOS_TOKEN = "</s>"
@@ -60,6 +62,10 @@ def smart_tokenizer_and_embedding_resize(

     if tokenizer.pad_token is None:
         num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+
+        if isinstance(model, LlamaLM):
+            model = model.get_base_model()
+
         model.resize_token_embeddings(len(tokenizer))
 
         if num_new_tokens > 0:
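
The unwrap is needed because resize_token_embeddings lives on the underlying transformers model, not on the LlamaLM wrapper, which is why get_base_model() was added to Actor above. A self-contained sketch of what the resize has to guarantee once a pad token is introduced; gpt2 is only a convenient stand-in for a local LLaMA checkpoint.

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

if tokenizer.pad_token is None:
    num_new_tokens = tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    # Must be called on the underlying transformers model (hence the unwrap above).
    model.resize_token_embeddings(len(tokenizer))
    assert model.get_input_embeddings().weight.size(0) == len(tokenizer)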

applications/ChatGPT/examples/train_sft.py (10 additions & 8 deletions)
@@ -93,25 +93,27 @@ def train(args):
     elif 'alpaca' in args.dataset:
         train_dataset = AlpacaDataset(tokenizer=tokenizer, data_path=args.dataset)
         eval_dataset = None
-        eval_dataset
+        data_collator = AlpacaDataCollator(tokenizer=tokenizer)
 
     if dist.is_initialized() and dist.get_world_size() > 1:
-        sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
+        logger.info("Using Distributed Sampler")
+        train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
+        if eval_dataset is not None:
+            eval_sampler = DistributedSampler(eval_dataset, shuffle=False, seed=42, drop_last=False)
     else:
-        sampler = None
+        train_sampler = None
+        eval_sampler = None
 
-    train_dataloader = DataLoader(train_dataset, shuffle=(sampler is None), sampler=sampler, batch_size=args.batch_size)
+    train_dataloader = DataLoader(train_dataset, shuffle=(train_sampler is None), sampler=train_sampler, batch_size=args.batch_size, collate_fn=data_collator)
     if eval_dataset is not None:
-        eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size)
+        eval_dataloader = DataLoader(eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, collate_fn=data_collator)
     else:
         eval_dataloader = None
 
     trainer = SFTTrainer(model=model,
                          strategy=strategy,
                          optim=optim,
                          train_dataloader=train_dataloader,
                          eval_dataloader=eval_dataloader,
-                         sampler=sampler,
                          batch_size=args.batch_size,
                          max_epochs=args.max_epochs)

@@ -128,7 +130,7 @@ def train(args):
     parser.add_argument('--strategy',
                         choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                         default='naive')
-    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
+    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--dataset', type=str, default='yizhongw/self_instruct')
     parser.add_argument('--save_path', type=str, default='sft_ckpt.pth')
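
The dataloader changes follow the standard PyTorch rule that shuffle and sampler are mutually exclusive: under multi-process training the DistributedSampler shards and shuffles the data, otherwise the DataLoader shuffles it directly, and the Alpaca collator pads each batch. A compact sketch of that pattern; the dataset, collator, and batch size are placeholders.

import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_dataloader(dataset, batch_size, collator, shuffle=True):
    if dist.is_initialized() and dist.get_world_size() > 1:
        sampler = DistributedSampler(dataset, shuffle=shuffle, seed=42, drop_last=shuffle)
    else:
        sampler = None
    # When a sampler is supplied, shuffling is the sampler's job, so DataLoader shuffle must be off.
    return DataLoader(dataset,
                      shuffle=(sampler is None and shuffle),
                      sampler=sampler,
                      batch_size=batch_size,
                      collate_fn=collator)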

applications/ChatGPT/examples/train_sft.sh (7 additions & 1 deletion)
@@ -17,4 +17,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 8

 #torchrun --standalone --nproc_per_node=2 train_sft.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2 --log_interval 10
 #torchrun --standalone --nproc_per_node=8 train_sft.py --model 'gpt2' --strategy colossalai_zero2 --batch_size 1 --log_interval 10
-torchrun --standalone --nproc_per_node=2 train_sft.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2 --log_interval 10
+torchrun --standalone --nproc_per_node=8 train_sft.py \
+    --pretrain "/data/personal/nus-mql/LLAMA-7B" \
+    --model 'llama' \
+    --strategy colossalai_zero2 \
+    --log_interval 10 \
+    --save_path /data/personal/nus-mql/Coati-7B \
+    --dataset /data/personal/nus-mql/stanford_alpaca/alpaca_data.json

applications/ChatGPT/version.txt (1 addition & 1 deletion)
@@ -1 +1 @@
-0.1.0
+1.0.0