Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion applications/Chat/coati/dataset/prompt_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self,
self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())

def __len__(self):
return len(self.keyed_prompt)
return len(self.keyed_prompt["input_ids"])

def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return {k: v[i] for k, v in self.keyed_prompt.items()}
2 changes: 1 addition & 1 deletion applications/Chat/examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ torchrun --standalone --nproc_per_node=4 train_prompts.py \
--rm_path /your/rm/model/path
```

Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/example_data_reformat.py) to reformat [seed_prompts_ch.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_ch.jsonl) or [seed_prompts_en.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_en.jsonl) in InstructionWild.
Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/generate_prompt_dataset.py) which samples `instinwild_en.json` or `instinwild_ch.json` in [InstructionWild](https://github.com/XueFuzhao/InstructionWild/tree/main/data#instructwild-data) to generate the prompt dataset.
Pretrain dataset: the pretrain dataset including the instruction and corresponding response, e.g. you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) in stage 1 supervised instructs tuning.

### Arg List
Expand Down
12 changes: 0 additions & 12 deletions applications/Chat/examples/example_data_reformat.py

This file was deleted.

30 changes: 30 additions & 0 deletions applications/Chat/examples/generate_prompt_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import argparse

import random
import json

random.seed(42)


def sample(args):
with open(args.dataset_path, mode='r') as f:
dataset_list = json.load(f)

sampled_dataset = [{"instruction": sample["instruction"], "id":idx}
for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))]

with open(args.save_path, mode='w') as f:
json.dump(sampled_dataset, f, indent=4,
default=str, ensure_ascii=False)


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_path', type=str, default=None,
required=True, help="path to the pretrain dataset")
parser.add_argument('--save_path', type=str, default='prompt.json',
help="path to save the prompt dataset")
parser.add_argument('--sample_size', type=int,
default=16384, help="size of the prompt dataset")
args = parser.parse_args()
sample(args)