Merged
Commits
57 commits
e5ce311
style: rename replay buffer
cwher Jul 4, 2023
9bba20f
fix: fix wrong zero2 default arg
cwher Jul 4, 2023
f23babe
test: update experience tests
cwher Jul 4, 2023
9b7d798
style: rename zero_pad fn
cwher Jul 4, 2023
1bb3747
fix: defer init in CycledDataLoader
cwher Jul 11, 2023
961f0a9
test: add benchmark test
cwher Jul 11, 2023
d3da644
style: rename internal fn of generation
cwher Jul 11, 2023
974dea6
style: rename internal fn of lora
cwher Jul 11, 2023
5108c64
fix: remove unused loss fn
cwher Jul 11, 2023
7222858
fix: remove unused utils fn
cwher Jul 11, 2023
aa47fc1
refactor: remove generate_with_actor fn
cwher Jul 11, 2023
0b65963
fix: fix type annotation
cwher Jul 11, 2023
5167991
test: add models tests
cwher Jul 11, 2023
aa1ffd0
fix: skip llama due to long execution time
cwher Jul 11, 2023
467781b
style: modify dataset
cwher Jul 12, 2023
6961ca9
style: apply formatter
cwher Jul 13, 2023
e17aa6c
perf: update reward dataset
cwher Jul 13, 2023
7ce6a6e
fix: fix wrong IGNORE_INDEX in sft dataset
cwher Jul 13, 2023
554ddf8
fix: remove DataCollatorForSupervisedDataset
cwher Jul 13, 2023
7b7af17
test: add dataset tests
cwher Jul 13, 2023
cd5408a
style: apply formatter
cwher Jul 13, 2023
cf1add3
style: rename test_ci to test_train
cwher Jul 13, 2023
35cce90
feat: add llama in inference
cwher Jul 13, 2023
85365d1
test: add inference tests
cwher Jul 13, 2023
47f4446
test: change test scripts directory
cwher Jul 14, 2023
456e6b3
fix: update ci
cwher Jul 14, 2023
412fd3e
fix: fix typo
cwher Jul 14, 2023
b6c3b0b
fix: skip llama due to oom
cwher Jul 14, 2023
d736e80
fix: fix file mod
cwher Jul 14, 2023
57b4750
style: apply formatter
cwher Jul 14, 2023
d06bf53
refactor: remove duplicated llama_gptq
cwher Jul 14, 2023
f12cf08
style: apply formatter
cwher Jul 14, 2023
8d778ea
to: update rm test
cwher Jul 18, 2023
dfee23c
feat: add tokenizer arg
cwher Jul 18, 2023
fe407ff
feat: add download model script
cwher Jul 18, 2023
7d650ab
test: update train tests
cwher Jul 18, 2023
e4c8846
fix: modify gemini load and save pretrained
cwher Jul 18, 2023
efe392d
test: update checkpoint io test
cwher Jul 18, 2023
e906bc6
to: modify nproc_per_node
cwher Jul 18, 2023
4cb6b6a
fix: do not remove existing dir
cwher Jul 18, 2023
dc46c6d
fix: modify save path
cwher Jul 18, 2023
fde8b86
test: add random choice
cwher Jul 18, 2023
5bac8dd
fix: fix sft path
cwher Jul 18, 2023
f564a06
fix: enlarge nproc_per_node to avoid oom
cwher Jul 19, 2023
360605d
fix: add num_retry
cwher Jul 19, 2023
ca4a817
fix: make lora config of rm and critic consistent
cwher Jul 19, 2023
bcfe320
fix: add warning about lora weights
cwher Jul 19, 2023
6ad3755
fix: skip some gpt2 tests
cwher Jul 19, 2023
fb2051d
fix: remove grad ckpt in rm and critic due to errors
cwher Jul 22, 2023
d854f69
refactor: directly use Actor in train_sft
cwher Jul 22, 2023
1614a7d
test: add more arguments
cwher Jul 22, 2023
44b2955
fix: disable grad ckpt when using lora
cwher Jul 22, 2023
d1e18bf
fix: fix save_pretrained and related tests
cwher Jul 24, 2023
9591545
test: enable zero2 tests
cwher Aug 1, 2023
d0d3a24
revert: remove useless fn
cwher Aug 1, 2023
73e2347
style: polish code
cwher Aug 1, 2023
bb20f43
test: modify test args
cwher Aug 1, 2023
4 changes: 3 additions & 1 deletion .github/workflows/run_chatgpt_examples.yml
@@ -43,7 +43,9 @@ jobs:
run: |
cd applications/Chat
rm -rf ~/.cache/colossalai
./examples/test_ci.sh
./tests/test_inference.sh
./tests/test_benchmarks.sh
./tests/test_train.sh
env:
NCCL_SHM_DISABLE: 1
MAX_JOBS: 8
7 changes: 4 additions & 3 deletions applications/Chat/coati/dataset/__init__.py
@@ -1,9 +1,10 @@
from .prompt_dataset import PromptDataset
from .reward_dataset import HhRlhfDataset, RmStaticDataset
from .sft_dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
from .sft_dataset import SFTDataset, SupervisedDataset
from .utils import is_rank_0

__all__ = [
'RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'SupervisedDataset',
'DataCollatorForSupervisedDataset', 'PromptDataset'
'RmStaticDataset', 'HhRlhfDataset',
'SFTDataset', 'SupervisedDataset',
'PromptDataset', 'is_rank_0',
]
18 changes: 6 additions & 12 deletions applications/Chat/coati/dataset/prompt_dataset.py
@@ -1,20 +1,13 @@
import copy
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Callable, Dict, Sequence
from typing import Dict

import torch
import torch.distributed as dist
import transformers
from torch.utils.data import Dataset
from tqdm import tqdm

from colossalai.logging import get_dist_logger

from .utils import is_rank_0, jload

logger = get_dist_logger()
from .utils import jload


class PromptDataset(Dataset):
@@ -27,12 +20,13 @@ def __init__(self,
max_length: int = 96):
super(PromptDataset, self).__init__()
self.keyed_prompt = defaultdict(list)
logger.info("Loading data...")
self.logger = get_dist_logger()
self.logger.info("Loading data...")
list_data_dict = jload(data_path)
logger.info(f"Loaded {len(list_data_dict)} examples.")
self.logger.info(f"Loaded {len(list_data_dict)} examples.")

if max_datasets_size is not None:
logger.info(f"Limiting dataset to {max_datasets_size} examples.")
self.logger.info(f"Limiting dataset to {max_datasets_size} examples.")
list_data_dict = list_data_dict[:max_datasets_size]

instructions = [data_dict["instruction"] for data_dict in list_data_dict]
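The prompt_dataset.py hunk above moves logger creation from module import time into `__init__`, so the distributed logger is only built once a `PromptDataset` is actually instantiated. Below is a minimal sketch of that deferral pattern, not the project's code: it uses the standard `logging` module as a stand-in for colossalai's `get_dist_logger`, and the class name and `records` argument are hypothetical.

```python
import logging

from torch.utils.data import Dataset


class LazyLoggerDataset(Dataset):
    """Illustrative only: defer logger construction to instance creation."""

    def __init__(self, records):
        super().__init__()
        # Built here rather than at module scope, so importing the module
        # has no logging side effects (mirrors the PromptDataset change above).
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.info("Loaded %d examples.", len(records))
        self.records = records

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        return self.records[idx]
```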
130 changes: 66 additions & 64 deletions applications/Chat/coati/dataset/reward_dataset.py
@@ -20,44 +20,44 @@ class RmStaticDataset(Dataset):

def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
self.chosen = []
self.reject = []
if special_token is None:
self.end_token = tokenizer.eos_token
else:
self.end_token = special_token
for data in tqdm(dataset, disable=not is_rank_0()):
prompt = data['prompt']

chosen = prompt + data['chosen'] + self.end_token
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen.append({
"input_ids": chosen_token['input_ids'],
"attention_mask": chosen_token['attention_mask']
})

reject = prompt + data['rejected'] + self.end_token
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject.append({
"input_ids": reject_token['input_ids'],
"attention_mask": reject_token['attention_mask']
})
self.end_token = tokenizer.eos_token \
if special_token is None else special_token

chosen = [
data["prompt"] + data["chosen"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen = {
"input_ids": chosen_token["input_ids"],
"attention_mask": chosen_token["attention_mask"]
}

reject = [
data["prompt"] + data["rejected"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject = {
"input_ids": reject_token["input_ids"],
"attention_mask": reject_token["attention_mask"]
}

def __len__(self):
length = len(self.chosen)
length = self.chosen["input_ids"].shape[0]
return length

def __getitem__(self, idx):
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
"input_ids"], self.reject[idx]["attention_mask"]
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
self.reject["input_ids"][idx], self.reject["attention_mask"][idx]


# Anthropic/hh-rlhf
@@ -74,39 +74,41 @@ class HhRlhfDataset(Dataset):

def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
self.chosen = []
self.reject = []
if special_token is None:
self.end_token = tokenizer.eos_token
else:
self.end_token = special_token
for data in tqdm(dataset, disable=not is_rank_0()):
chosen = data['chosen'] + self.end_token
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen.append({
"input_ids": chosen_token['input_ids'],
"attention_mask": chosen_token['attention_mask']
})

reject = data['rejected'] + self.end_token
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject.append({
"input_ids": reject_token['input_ids'],
"attention_mask": reject_token['attention_mask']
})
self.end_token = tokenizer.eos_token \
if special_token is None else special_token

chosen = [
data["chosen"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen = {
"input_ids": chosen_token["input_ids"],
"attention_mask": chosen_token["attention_mask"]
}

reject = [
data["rejected"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject = {
"input_ids": reject_token["input_ids"],
"attention_mask": reject_token["attention_mask"]
}

def __len__(self):
length = len(self.chosen)
length = self.chosen["input_ids"].shape[0]
return length

def __getitem__(self, idx):
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
"input_ids"], self.reject[idx]["attention_mask"]
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
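The reward_dataset.py refactor above tokenizes all chosen and rejected sequences in a single batched call and stores the resulting tensors, so `__len__` and `__getitem__` now index into those batched encodings instead of unpacking per-sample dicts. A minimal usage sketch, assuming a HuggingFace-style tokenizer; the model name and toy records are illustrative, not taken from the PR.

```python
from transformers import AutoTokenizer

from coati.dataset import RmStaticDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token by default

# Toy records in the rm-static format: prompt / chosen / rejected.
records = [
    {"prompt": "Q: 1 + 1?\nA:", "chosen": " 2", "rejected": " 3"},
    {"prompt": "Q: capital of France?\nA:", "chosen": " Paris", "rejected": " Rome"},
]

dataset = RmStaticDataset(records, tokenizer, max_length=32)
assert len(dataset) == len(records)

# Each item is four tensors sliced from the batched encodings.
chosen_ids, chosen_mask, reject_ids, reject_mask = dataset[0]
print(chosen_ids.shape, reject_ids.shape)  # e.g. torch.Size([32]) torch.Size([32])
```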