Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions applications/ChatGPT/chatgpt/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .reward_dataset import RmStaticDataset, HhRlhfDataset
from .utils import is_rank_0
from .sft_dataset import SFTDataset
from .sft_dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator

__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0', 'SFTDataset']
__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0', 'SFTDataset', 'AlpacaDataset', 'AlpacaDataCollator']
122 changes: 120 additions & 2 deletions applications/ChatGPT/chatgpt/dataset/sft_dataset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,46 @@
from typing import Callable
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from dataclasses import dataclass, field
from typing import Callable, Dict, Sequence
import random
from torch.utils.data import Dataset
import torch.distributed as dist
from tqdm import tqdm
import torch

from .utils import is_rank_0
from .utils import is_rank_0, jload

import transformers
from colossalai.logging import get_dist_logger

logger = get_dist_logger()

# Label value masked out of the loss; -100 is the default ignore_index of
# torch.nn.CrossEntropyLoss. Used below to hide prompt tokens (preprocess)
# and label padding (AlpacaDataCollator) from the loss.
IGNORE_INDEX = -100
# Alpaca-style prompt templates. "prompt_input" is used when the example has a
# non-empty "input" field; "prompt_no_input" otherwise. Both are filled via
# str.format_map with the raw example dict (keys: instruction, input).
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}

class SFTDataset(Dataset):
"""
Expand Down Expand Up @@ -38,3 +72,87 @@ def __len__(self):

def __getitem__(self, idx):
return self.prompts[idx]


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string and record its unpadded token count.

    Returns a dict with ``input_ids`` / ``labels`` (deliberately the SAME list
    of 1-D tensors — callers deep-copy before masking) and
    ``input_ids_lens`` / ``labels_lens`` (number of non-pad tokens per string).
    """
    encoded = []
    lengths = []
    for text in strings:
        tokens = tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        encoded.append(tokens.input_ids[0])
        # Count tokens that are not padding (single string => usually all of them).
        lengths.append(tokens.input_ids.ne(tokenizer.pad_token_id).sum().item())
    return dict(
        input_ids=encoded,
        labels=encoded,
        input_ids_lens=lengths,
        labels_lens=lengths,
    )

def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize source+target pairs and mask the source span in the labels."""
    full_texts = [src + tgt for src, tgt in zip(sources, targets)]
    full_tokenized = _tokenize_fn(full_texts, tokenizer)
    source_tokenized = _tokenize_fn(sources, tokenizer)
    input_ids = full_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    # Only target tokens should contribute to the loss: overwrite the prompt
    # prefix of each label tensor with IGNORE_INDEX.
    for label, prompt_len in zip(labels, source_tokenized["input_ids_lens"]):
        label[:prompt_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)

class AlpacaDataset(Dataset):
    """Dataset for supervised fine-tuning on Alpaca-style instruction data.

    Args:
        data_path: Path to a JSON file of records with ``instruction``,
            optional ``input``, and ``output`` fields.
        tokenizer: Tokenizer used to encode prompts and responses.
    """

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(AlpacaDataset, self).__init__()
        logger.info("Loading data...")
        records = jload(data_path)

        logger.info("Formatting inputs...")
        with_input = PROMPT_DICT["prompt_input"]
        without_input = PROMPT_DICT["prompt_no_input"]
        sources = []
        for record in records:
            # Pick the template based on whether the example carries extra input.
            template = with_input if record.get("input", "") != "" else without_input
            sources.append(template.format_map(record))
        # The EOS token terminates every target so generation learns to stop.
        targets = [f"{record['output']}{tokenizer.eos_token}" for record in records]

        logger.info("Tokenizing inputs... This may take some time...")
        tokenized = preprocess(sources, targets, tokenizer)

        self.input_ids = tokenized["input_ids"]
        self.labels = tokenized["labels"]

    def __len__(self):
        """Number of examples in the dataset."""
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return the tokenized example (input_ids + labels) at index *i*."""
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

@dataclass
class AlpacaDataCollator(object):
    """Collate examples for supervised fine-tuning.

    Pads ``input_ids`` with the tokenizer's pad token and ``labels`` with
    ``IGNORE_INDEX``; the attention mask flags non-pad positions.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids = [inst["input_ids"] for inst in instances]
        labels = [inst["labels"] for inst in instances]
        pad_id = self.tokenizer.pad_token_id
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_id)
        # Label padding uses IGNORE_INDEX so padded positions never enter the loss.
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(pad_id),
        )
15 changes: 15 additions & 0 deletions applications/ChatGPT/chatgpt/dataset/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
import io
import json

import torch.distributed as dist


def is_rank_0() -> bool:
    """Return True on the global rank-0 process, or when torch.distributed
    has not been initialized (single-process runs)."""
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True

def _make_r_io_base(f, mode: str):
    """Coerce *f* into a readable file object: already-open streams pass
    through untouched, anything else is treated as a path and opened."""
    if isinstance(f, io.IOBase):
        return f
    return open(f, mode=mode)

def jload(f, mode="r"):
    """Load JSON content from *f* (a file path or an open file object).

    Returns the deserialized top-level JSON value (a dict for Alpaca data,
    but any valid JSON document is accepted). The handle is always closed —
    including a file object passed in by the caller, which matches the
    original behavior.
    """
    fp = _make_r_io_base(f, mode)
    try:
        # try/finally releases the descriptor even when json.load raises
        # (e.g. JSONDecodeError); the close-after-load version leaked it.
        jdict = json.load(fp)
    finally:
        fp.close()
    return jdict
3 changes: 2 additions & 1 deletion applications/ChatGPT/chatgpt/models/llama/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .llama_actor import LlamaActor
from .llama_critic import LlamaCritic
from .llama_rm import LlamaRM
from .llama_lm import LlamaLM

__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM']
38 changes: 38 additions & 0 deletions applications/ChatGPT/chatgpt/models/llama/llama_lm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Optional

from transformers import LlamaConfig, LlamaForCausalLM

from ..base import LM


class LlamaLM(LM):
    """
    Llama language model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        # Precedence: pretrained weights > explicit config > default config.
        if pretrained is not None:
            model = LlamaForCausalLM.from_pretrained(pretrained)
        else:
            model = LlamaForCausalLM(config if config is not None else LlamaConfig())

        if checkpoint:
            # Trade compute for memory during backprop.
            model.gradient_checkpointing_enable()

        super().__init__(model, lora_rank, lora_train_bias)

50 changes: 24 additions & 26 deletions applications/ChatGPT/chatgpt/trainer/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Optional
import loralib as lora
import torch
from chatgpt.dataset import SFTDataset
from chatgpt.models.loss import GPTLMLoss
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader
Expand All @@ -22,8 +21,8 @@ class SFTTrainer(ABC):
model (torch.nn.Module): the model to train
strategy (Strategy): the strategy to use for training
optim(Optimizer): the optimizer to use for training
train_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for training
eval_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for evaluation
train_dataloader: the dataloader to use for training
eval_dataloader: the dataloader to use for evaluation
batch_size (int, defaults to 1): the batch size while training
max_epochs (int, defaults to 2): the number of epochs to train
optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
Expand All @@ -34,22 +33,19 @@ def __init__(
model,
strategy: Strategy,
optim: Optimizer,
train_dataset: SFTDataset,
eval_dataset: SFTDataset,
train_dataloader: DataLoader,
eval_dataloader: DataLoader = None,
sampler: Optional[DistributedSampler] = None,
batch_size: int = 1,
max_epochs: int = 2,
) -> None:
super().__init__()
self.strategy = strategy
self.epochs = max_epochs
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.sampler = sampler

self.train_dataloader = DataLoader(self.train_dataset, shuffle=(sampler is None),
sampler=sampler, batch_size=batch_size)
self.eval_dataloader = DataLoader(self.eval_dataset, batch_size=batch_size)
self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader

self.model = strategy.setup_model(model)
if "DDP" in str(self.strategy):
Expand Down Expand Up @@ -79,23 +75,25 @@ def fit(self, logger, use_lora, log_interval=10):
logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')

# eval
self.model.eval()
with torch.no_grad():
loss_sum = 0
num_seen = 0
for batch in self.eval_dataloader:
prompt_ids = batch["input_ids"]
p_mask = batch["attention_mask"]
prompt_ids = prompt_ids.squeeze(1).cuda()
p_mask = p_mask.squeeze(1).cuda()
if self.eval_dataloader is not None:
self.model.eval()
with torch.no_grad():
loss_sum = 0
num_seen = 0
for batch in self.eval_dataloader:
prompt_ids = batch["input_ids"]
p_mask = batch["attention_mask"]
prompt_ids = prompt_ids.squeeze(1).cuda()
p_mask = p_mask.squeeze(1).cuda()

prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
loss = self.loss_fn(prompt_logits, prompt_ids)
loss_sum += loss.item()
num_seen += prompt_ids.size(0)
prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
loss = self.loss_fn(prompt_logits, prompt_ids)
loss_sum += loss.item()
num_seen += prompt_ids.size(0)

loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')

epoch_bar.update()

3 changes: 3 additions & 0 deletions applications/ChatGPT/chatgpt/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .tokenizer_utils import smart_tokenizer_and_embedding_resize, prepare_llama_tokenizer_and_embedding

__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']
74 changes: 74 additions & 0 deletions applications/ChatGPT/chatgpt/utils/tokenizer_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

import transformers

# Special-token fallbacks used when the (Llama) tokenizer lacks them.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
# NOTE(review): bos and unk both reuse the EOS string "</s>" here; Llama's
# canonical tokens are "<s>" (bos) and "<unk>" (unk). Confirm this is
# intentional before relying on generation/stopping behavior.
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"

def prepare_llama_tokenizer_and_embedding(
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    special_tokens_dict: Dict = None,
):
    """Prepare a Llama tokenizer and the model embeddings for fine-tuning.

    Adds a pad token (resizing the embeddings) when the tokenizer lacks one,
    then registers the default eos/bos/unk special tokens.

    Args:
        tokenizer: Tokenizer to update in place.
        model: Model whose embeddings are resized if tokens are added.
        special_tokens_dict: Tokens to add when pad_token is missing.
            Defaults to ``{"pad_token": DEFAULT_PAD_TOKEN}`` (kept as ``None``
            in the signature to avoid a mutable default argument).

    Returns:
        The updated tokenizer.
    """
    if special_tokens_dict is None:
        special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)

    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            # Bug fix: the caller-supplied dict was previously ignored and a
            # hard-coded dict(pad_token=DEFAULT_PAD_TOKEN) passed instead.
            special_tokens_dict=special_tokens_dict,
            tokenizer=tokenizer,
            model=model,
        )

    # NOTE(review): if any of these tokens were genuinely new, the embedding
    # matrix would not be resized here — presumably they already exist in the
    # Llama vocabulary; verify for custom tokenizers.
    tokenizer.add_special_tokens(
        {
            "eos_token": DEFAULT_EOS_TOKEN,
            "bos_token": DEFAULT_BOS_TOKEN,
            "unk_token": DEFAULT_UNK_TOKEN,
        }
    )

    return tokenizer


def smart_tokenizer_and_embedding_resize(
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    special_tokens_dict: Dict = None,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.

    Args:
        tokenizer: Tokenizer to extend in place.
        model: Model whose input/output embeddings are resized to match.
        special_tokens_dict: Tokens to add when pad_token is missing.
            Defaults to ``{"pad_token": DEFAULT_PAD_TOKEN}`` (kept as ``None``
            in the signature to avoid a mutable default argument).
    """
    if special_tokens_dict is None:
        special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)

    if tokenizer.pad_token is None:
        num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
        model.resize_token_embeddings(len(tokenizer))

        # All new-token handling is nested under the guard so num_new_tokens
        # can never be referenced unbound (a latent NameError otherwise).
        if num_new_tokens > 0:
            input_embeddings = model.get_input_embeddings().weight.data
            output_embeddings = model.get_output_embeddings().weight.data

            # Initialize the new rows with the mean of the pre-existing
            # embeddings instead of random values.
            input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
            output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

            input_embeddings[-num_new_tokens:] = input_embeddings_avg
            output_embeddings[-num_new_tokens:] = output_embeddings_avg

Loading