-We also train a reward model based on LLaMA-7B, which reaches an accuracy (ACC) of 72.06% after 1 epoch, performing almost the same as Anthropic's best RM.
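-
-Here, ACC measures the fraction of preference pairs on which the reward model scores the chosen response above the rejected one. A minimal sketch of the metric (illustrative; the exact evaluation code may differ):
-
-```python
-import torch
-
-def ranking_accuracy(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> float:
-    # Fraction of pairs where the chosen response outscores the rejected one.
-    return (chosen_reward > reject_reward).float().mean().item()
-```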
-
-### Arg List
-
-- `--strategy`: the strategy to use for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
-- `--model`: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
-- `--pretrain`: pretrained model name or path, type=str, default=None
-- `--model_path`: the path of the RM model (to continue training from a checkpoint), type=str, default=None
-- `--save_path`: path to save the model, type=str, default='rm_ckpt'
-- `--need_optim_ckpt`: whether to save the optimizer checkpoint, type=bool, default=False
-- `--max_epochs`: max number of training epochs, type=int, default=1
-- `--dataset`: dataset name, type=str, choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'], default='Dahoas/rm-static'
-- `--subset`: subset of the dataset, type=str, default=None
-- `--batch_size`: training batch size, type=int, default=1
-- `--lora_rank`: rank of the low-rank adaptation (LoRA) matrices, type=int, default=0
-- `--loss_fn`: which loss function to use, choices=['log_sig', 'log_exp'], default='log_sig' (see the sketch below)
-- `--max_len`: max sentence length, type=int, default=512
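-
-For reference, the two pairwise ranking losses selectable via `--loss_fn` can be sketched as follows (a minimal illustration of the shipped `LogSigLoss`/`LogExpLoss`; the two are mathematically equivalent since log(1 + exp(-x)) = -log(sigmoid(x))):
-
-```python
-import torch
-import torch.nn as nn
-
-class LogSigLoss(nn.Module):
-    # InstructGPT-style pairwise loss: -log(sigmoid(r_chosen - r_rejected)).
-    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
-        return -torch.log(torch.sigmoid(chosen_reward - reject_reward)).mean()
-
-class LogExpLoss(nn.Module):
-    # Equivalent formulation: log(1 + exp(r_rejected - r_chosen)).
-    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
-        return torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
-```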
-
-## Stage3 - Training model using prompts with RL
-
-Stage3 uses a reinforcement learning algorithm and is the most complex part of the training process.
-
-You can run `examples/train_prompts.sh` to start PPO training.
-
-You can also use the following command to start PPO training; see this [[Stage3 tutorial video]](https://www.youtube.com/watch?v=Z8wwSHxPL9g) for a walkthrough.
-
-```bash
-torchrun --standalone --nproc_per_node=4 train_prompts.py \
- --pretrain "/path/to/LLaMa-7B/" \
- --model 'llama' \
- --strategy colossalai_zero2 \
- --prompt_dataset /path/to/your/prompt_dataset \
- --pretrain_dataset /path/to/your/pretrain_dataset \
- --rm_pretrain /your/pretrain/rm/definition \
- --rm_path /your/rm/model/path
-```
-
-Prompt dataset: the instruction dataset used to collect PPO experience; e.g. you can use this [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/generate_prompt_dataset.py), which samples `instinwild_en.json` or `instinwild_ch.json` from [InstructionWild](https://github.com/XueFuzhao/InstructionWild/tree/main/data#instructwild-data), to generate the prompt dataset.
-Pretrain dataset: the pretraining dataset containing instructions and the corresponding responses; e.g. you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) from stage 1 supervised instruction tuning.
-
-**Note**: the required datasets must follow the formats below (a quick sanity-check snippet follows the examples).
-
-- `pretrain dataset`
-
- ```json
- [
- {
- "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
- "input": "",
- "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
- "id": 0
- },
- ...
- ]
- ```
-
-- `prompt dataset`
-
- ```json
- [
- {
- "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
- "id": 0
- },
- {
- "instruction": "Write a descriptive paragraph about a memorable vacation you went on",
- "id": 1
- },
- ...
- ]
- ```
-
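-As a quick sanity check before training, you can verify that your JSON files contain the expected keys (an illustrative helper, not part of the repo):
-
-```python
-import json
-
-def check_dataset(path: str, required_keys: set) -> None:
-    # Every sample must be a JSON object containing at least the required keys.
-    with open(path) as f:
-        samples = json.load(f)
-    for sample in samples:
-        assert required_keys <= sample.keys(), f"missing keys in sample {sample.get('id')}"
-
-check_dataset("/path/to/your/prompt_dataset", {"instruction", "id"})
-check_dataset("/path/to/your/pretrain_dataset", {"instruction", "input", "output", "id"})
-```
-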
-### Arg List
-
-- `--strategy`: the strategy to use for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
-- `--model`: model type of the actor, choices=['gpt2', 'bloom', 'opt', 'llama'], default='gpt2'
-- `--pretrain`: pretrained model name or path, type=str, default=None
-- `--rm_model`: reward model type, type=str, choices=['gpt2', 'bloom', 'opt', 'llama'], default=None
-- `--rm_pretrain`: pretrained model for the reward model, type=str, default=None
-- `--rm_path`: the path of the RM model checkpoint, type=str, default=None
-- `--save_path`: path to save the model, type=str, default='actor_checkpoint_prompts'
-- `--prompt_dataset`: path of the prompt dataset, type=str, default=None
-- `--pretrain_dataset`: path of the ptx (pretraining) dataset, type=str, default=None
-- `--need_optim_ckpt`: whether to save the optimizer checkpoint, type=bool, default=False
-- `--num_episodes`: number of episodes for training, type=int, default=10
-- `--num_update_steps`: number of policy update steps per episode, type=int, default=5
-- `--num_collect_steps`: number of experience collection steps per episode, type=int, default=10
-- `--train_batch_size`: training batch size, type=int, default=8
-- `--ptx_batch_size`: batch size for computing the ptx loss, type=int, default=1
-- `--experience_batch_size`: batch size for collecting experience, type=int, default=8
-- `--lora_rank`: rank of the low-rank adaptation (LoRA) matrices, type=int, default=0
-- `--kl_coef`: coefficient of the KL penalty when computing the reward, type=float, default=0.1 (see the sketch below)
-- `--ptx_coef`: coefficient of the ptx loss when computing the policy loss, type=float, default=0.9 (see the sketch below)
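-
-A minimal sketch of how these two coefficients typically enter the objective (illustrative only; the exact `PPOTrainer` implementation may differ):
-
-```python
-import torch
-
-def penalized_reward(r: torch.Tensor, log_probs: torch.Tensor,
-                     log_probs_base: torch.Tensor, kl_coef: float) -> torch.Tensor:
-    # The raw reward is reduced by an approximate KL between the actor and the
-    # frozen initial model, keeping the policy close to its starting point.
-    approx_kl = (log_probs - log_probs_base).sum(dim=-1)
-    return r - kl_coef * approx_kl
-
-def mixed_policy_loss(ppo_loss: torch.Tensor, ptx_loss: torch.Tensor, ptx_coef: float) -> torch.Tensor:
-    # The ptx (pretraining) language-modeling loss is mixed into the PPO
-    # objective, as in InstructGPT, to reduce regression on the base task.
-    return ptx_coef * ptx_loss + (1 - ptx_coef) * ppo_loss
-```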
-
-## Inference example - After Stage3
-
-We support different inference options, including int8 and int4 quantization.
-For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
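-
-For instance, int8 inference with HuggingFace transformers can be sketched as follows (assumes `bitsandbytes` is installed; the scripts in `inference/` are the supported path):
-
-```python
-import torch
-from transformers import LlamaForCausalLM
-
-# Load the trained actor with int8 weight quantization to cut GPU memory.
-model = LlamaForCausalLM.from_pretrained(
-    "/path/to/your/trained/actor",
-    load_in_8bit=True,
-    torch_dtype=torch.float16,
-    device_map="auto",
-)
-```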
-
-## Attention
-
-These examples are demos of the whole training process. You need to tune the hyper-parameters to reach good performance.
-
-#### Data
-
-- [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
-- [x] [hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)
-- [ ] [openai/summarize_from_feedback](https://huggingface.co/datasets/openai/summarize_from_feedback)
-- [ ] [openai/webgpt_comparisons](https://huggingface.co/datasets/openai/webgpt_comparisons)
-- [ ] [Dahoas/instruct-synthetic-prompt-responses](https://huggingface.co/datasets/Dahoas/instruct-synthetic-prompt-responses)
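-
-Each ticked dataset can be pulled directly from the HuggingFace Hub, e.g.:
-
-```python
-from datasets import load_dataset
-
-# Preference data with "prompt", "chosen" and "rejected" fields,
-# as consumed by the RM dataset wrappers.
-data = load_dataset("Dahoas/rm-static")
-print(data["train"][0].keys())
-```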
-
-## Supported Models
-
-### GPT
-
-- [x] GPT2-S (s)
-- [x] GPT2-M (m)
-- [x] GPT2-L (l)
-- [x] GPT2-XL (xl)
-- [x] GPT2-4B (4b)
-- [ ] GPT2-6B (6b)
-
-### BLOOM
-
-- [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m)
-- [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1)
-- [x] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
-- [x] [BLOOM-7b1](https://huggingface.co/bigscience/bloom-7b1)
-- [ ] [BLOOM-175b](https://huggingface.co/bigscience/bloom)
-
-### OPT
-
-- [x] [OPT-125M](https://huggingface.co/facebook/opt-125m)
-- [x] [OPT-350M](https://huggingface.co/facebook/opt-350m)
-- [x] [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b)
-- [x] [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b)
-- [x] [OPT-6.7B](https://huggingface.co/facebook/opt-6.7b)
-- [ ] [OPT-13B](https://huggingface.co/facebook/opt-13b)
-- [ ] [OPT-30B](https://huggingface.co/facebook/opt-30b)
-
-### [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
-
-- [x] LLaMA-7B
-- [x] LLaMA-13B
-- [ ] LLaMA-33B
-- [ ] LLaMA-65B
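-
-Each supported model is exposed through its Coati wrapper classes; a minimal sketch (`/path/to/LLaMa-7B/` is a placeholder):
-
-```python
-from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
-
-# Build the three RLHF roles from the same pretrained backbone.
-actor = LlamaActor(pretrained="/path/to/LLaMa-7B/")
-critic = LlamaCritic(pretrained="/path/to/LLaMa-7B/")
-reward_model = LlamaRM(pretrained="/path/to/LLaMa-7B/")
-```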
-
-## Add your own models
-
-If you want to support your own model in Coati, please refer to the pull request that adds RoBERTa support as an example, [[chatgpt] add pre-trained model RoBERTa for RLHF stage 2 & 3](https://github.com/hpcaitech/ColossalAI/pull/3223), and submit a PR to us.
-
-You should complete the implementation of four model classes: the Reward model, the Critic model, the LM model, and the Actor model.
-
-Here is some example code for a new model named `Coati`.
-If the model is available in HuggingFace [transformers](https://github.com/huggingface/transformers), you can load it with `from_pretrained`; otherwise, you can build the model yourself.
-
-### Actor model
-
-```python
-from typing import Optional
-
-from ..base import Actor
-from transformers.models.coati import CoatiModel
-
-class CoatiActor(Actor):
-    def __init__(self,
-                 pretrained: Optional[str] = None,
-                 checkpoint: bool = False,
-                 lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
-        if pretrained is not None:
-            model = CoatiModel.from_pretrained(pretrained)
-        else:
-            # Build your own model here if it is not supported by transformers.
-            model = build_model()
-
-        super().__init__(model, lora_rank, lora_train_bias)
-```
-
-### Reward model
-
-```python
-from typing import Optional
-
-import torch.nn as nn
-
-from ..base import RewardModel
-from transformers.models.coati import CoatiModel
-
-class CoatiRM(RewardModel):
-    def __init__(self,
-                 pretrained: Optional[str] = None,
-                 checkpoint: bool = False,
-                 lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
-        if pretrained is not None:
-            model = CoatiModel.from_pretrained(pretrained)
-        else:
-            # Build your own model here if it is not supported by transformers.
-            model = build_model()
-
-        value_head = nn.Linear(model.config.n_embd, 1)
-        value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1))
-        super().__init__(model, value_head, lora_rank, lora_train_bias)
-```
-
-### Critic model
-
-```python
-from typing import Optional
-
-import torch.nn as nn
-
-from ..base import Critic
-from transformers.models.coati import CoatiModel
-
-class CoatiCritic(Critic):
-    def __init__(self,
-                 pretrained: Optional[str] = None,
-                 checkpoint: bool = False,
-                 lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
-        if pretrained is not None:
-            model = CoatiModel.from_pretrained(pretrained)
-        else:
-            # Build your own model here if it is not supported by transformers.
-            model = build_model()
-
-        value_head = nn.Linear(model.config.n_embd, 1)
-        value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1))
-        super().__init__(model, value_head, lora_rank, lora_train_bias)
-```
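-
-Once implemented, the new classes can be used just like the built-in model families (a hypothetical usage sketch based on the `Coati` example above):
-
-```python
-# "coati-base" is a hypothetical checkpoint name for the new model.
-actor = CoatiActor(pretrained='coati-base')
-reward_model = CoatiRM(pretrained='coati-base')
-critic = CoatiCritic(pretrained='coati-base')
-```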
diff --git a/applications/Chat/examples/download_model.py b/applications/Chat/examples/download_model.py
deleted file mode 100644
index ec3482b5f789..000000000000
--- a/applications/Chat/examples/download_model.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import argparse
-import dataclasses
-import os
-from typing import List
-
-import tqdm
-from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
-from coati.models.gpt import GPTRM, GPTActor, GPTCritic
-from coati.models.opt import OPTRM, OPTActor, OPTCritic
-from huggingface_hub import hf_hub_download, snapshot_download
-from transformers import AutoConfig, AutoTokenizer, BloomConfig, BloomTokenizerFast, GPT2Config, GPT2Tokenizer
-
-
-@dataclasses.dataclass
-class HFRepoFiles:
- repo_id: str
- files: List[str]
-
- def download(self, dir_path: str):
- for file in self.files:
-            hf_hub_download(self.repo_id, file, local_dir=dir_path)
-
- def download_all(self):
- snapshot_download(self.repo_id)
-
-
-def test_init(model: str, dir_path: str):
- if model == "gpt2":
- config = GPT2Config.from_pretrained(dir_path)
- actor = GPTActor(config=config)
- critic = GPTCritic(config=config)
- reward_model = GPTRM(config=config)
- GPT2Tokenizer.from_pretrained(dir_path)
- elif model == "bloom":
- config = BloomConfig.from_pretrained(dir_path)
- actor = BLOOMActor(config=config)
- critic = BLOOMCritic(config=config)
- reward_model = BLOOMRM(config=config)
- BloomTokenizerFast.from_pretrained(dir_path)
- elif model == "opt":
- config = AutoConfig.from_pretrained(dir_path)
- actor = OPTActor(config=config)
- critic = OPTCritic(config=config)
- reward_model = OPTRM(config=config)
- AutoTokenizer.from_pretrained(dir_path)
- else:
- raise NotImplementedError(f"Model {model} not implemented")
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--model-dir", type=str, default="test_models")
- parser.add_argument("--config-only", default=False, action="store_true")
- args = parser.parse_args()
-
- if os.path.exists(args.model_dir):
- print(f"[INFO]: {args.model_dir} already exists")
- exit(0)
-
- repo_list = {
- "gpt2": HFRepoFiles(repo_id="gpt2", files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]),
- "bloom": HFRepoFiles(
- repo_id="bigscience/bloom-560m", files=["config.json", "tokenizer.json", "tokenizer_config.json"]
- ),
- "opt": HFRepoFiles(
- repo_id="facebook/opt-350m", files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
- ),
- }
-
- os.mkdir(args.model_dir)
- for model_name in tqdm.tqdm(repo_list):
- dir_path = os.path.join(args.model_dir, model_name)
- if args.config_only:
- os.mkdir(dir_path)
- repo_list[model_name].download(dir_path)
- else:
- repo_list[model_name].download_all()
- test_init(model_name, dir_path)
diff --git a/applications/Chat/examples/generate_conversation_dataset.py b/applications/Chat/examples/generate_conversation_dataset.py
deleted file mode 100644
index 7e03b2d54260..000000000000
--- a/applications/Chat/examples/generate_conversation_dataset.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import argparse
-import json
-
-from datasets import load_dataset
-
-
-def generate_alpaca():
-    # Datasets in the same format as Alpaca ("instruction", "input", "output") can be converted into one-round conversations.
- conversation_dataset = []
- dataset = load_dataset("tatsu-lab/alpaca", split="train")
-
- instructions = dataset["instruction"]
- inputs = dataset["input"]
- outputs = dataset["output"]
-
- assert len(instructions) == len(inputs) == len(outputs)
-
- for idx in range(len(instructions)):
- human_utterance = instructions[idx] + "\n\n" + inputs[idx] if inputs[idx] else instructions[idx]
- human = {"from": "human", "value": human_utterance}
-
- gpt_utterance = outputs[idx]
- gpt = {"from": "gpt", "value": gpt_utterance}
-
- conversation = dict(type="instruction", language="English", dataset="Alpaca", conversations=[human, gpt])
- conversation_dataset.append(conversation)
-
- return conversation_dataset
-
-
-def generate_sharegpt():
- # ShareGPT data requires less processing.
- conversation_dataset = []
- dataset = load_dataset(
- "anon8231489123/ShareGPT_Vicuna_unfiltered",
- data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
- split="train",
- )
-
- conversations = dataset["conversations"]
-
- for idx in range(len(conversations)):
- for conv in conversations[idx]:
-            # We don't need the "markdown" and "text" fields.
- del conv["markdown"]
- del conv["text"]
-
- conversation = dict(
- type="conversation", language="Multilingual", dataset="ShareGPT", conversations=conversations[idx]
- )
- conversation_dataset.append(conversation)
-
- return conversation_dataset
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--dataset",
- type=str,
- default="All",
- choices=["Alpaca", "ShareGPT", "All"],
- help="which dataset to convert, All will combine Alpaca and ShareGPT",
- )
- parser.add_argument("--save_path", type=str, default="dataset.json", help="path to save the converted dataset")
- args = parser.parse_args()
-
- conversation_dataset = []
-
- if args.dataset == "Alpaca":
- conversation_dataset.extend(generate_alpaca())
- elif args.dataset == "ShareGPT":
- conversation_dataset.extend(generate_sharegpt())
- else:
- conversation_dataset.extend(generate_alpaca())
- conversation_dataset.extend(generate_sharegpt())
-
- for idx, sample in enumerate(conversation_dataset):
- sample["id"] = idx + 1
-
- with open(args.save_path, mode="w") as f:
- json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)
diff --git a/applications/Chat/examples/generate_prompt_dataset.py b/applications/Chat/examples/generate_prompt_dataset.py
deleted file mode 100644
index 4eec6feae505..000000000000
--- a/applications/Chat/examples/generate_prompt_dataset.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import argparse
-import json
-import random
-
-random.seed(42)
-
-
-def sample(args):
- with open(args.dataset_path, mode="r") as f:
- dataset_list = json.load(f)
-
- sampled_dataset = [
- {"instruction": sample["instruction"], "id": idx}
- for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))
- ]
-
- with open(args.save_path, mode="w") as f:
- json.dump(sampled_dataset, f, indent=4, default=str, ensure_ascii=False)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--dataset_path", type=str, default=None, required=True, help="path to the pretrain dataset")
- parser.add_argument("--save_path", type=str, default="prompt.json", help="path to save the prompt dataset")
- parser.add_argument("--sample_size", type=int, default=16384, help="size of the prompt dataset")
- args = parser.parse_args()
- sample(args)
diff --git a/applications/Chat/examples/inference.py b/applications/Chat/examples/inference.py
deleted file mode 100644
index 9df8649d9c61..000000000000
--- a/applications/Chat/examples/inference.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import argparse
-
-import torch
-from coati.models.bloom import BLOOMActor
-from coati.models.generation import generate
-from coati.models.gpt import GPTActor
-from coati.models.llama import LlamaActor
-from coati.models.opt import OPTActor
-from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
-
-
-def eval(args):
- # configure model
- if args.model == "gpt2":
- actor = GPTActor(pretrained=args.pretrain)
- elif args.model == "bloom":
- actor = BLOOMActor(pretrained=args.pretrain)
- elif args.model == "opt":
- actor = OPTActor(pretrained=args.pretrain)
- elif args.model == "llama":
- actor = LlamaActor(pretrained=args.pretrain)
- else:
- raise ValueError(f'Unsupported model "{args.model}"')
-
- actor.to(torch.cuda.current_device())
- if args.model_path is not None:
- state_dict = torch.load(args.model_path)
- actor.load_state_dict(state_dict)
-
- # configure tokenizer
- if args.model == "gpt2":
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == "bloom":
- tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == "opt":
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == "llama":
- tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-        tokenizer.eos_token = "</s>"  # LLaMA's end-of-sequence token
- tokenizer.pad_token = tokenizer.unk_token
- else:
- raise ValueError(f'Unsupported model "{args.model}"')
-
- actor.eval()
- tokenizer.padding_side = "left"
- input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device())
- outputs = generate(
- actor,
- input_ids,
- tokenizer=tokenizer,
- max_length=args.max_length,
- do_sample=True,
- top_k=50,
- top_p=0.95,
- num_return_sequences=1,
- )
- output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
- print(f"[Output]: {''.join(output)}")
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
-    # We suggest using a pretrained model from HuggingFace; use --pretrain to configure the model.
- parser.add_argument("--pretrain", type=str, default=None)
- parser.add_argument("--model_path", type=str, default=None)
- parser.add_argument("--input", type=str, default="Question: How are you ? Answer:")
- parser.add_argument("--max_length", type=int, default=100)
- args = parser.parse_args()
- eval(args)
diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py
deleted file mode 100644
index 40e06043ab57..000000000000
--- a/applications/Chat/examples/train_prompts.py
+++ /dev/null
@@ -1,249 +0,0 @@
-import argparse
-import warnings
-
-import torch
-import torch.distributed as dist
-from coati.dataset import PromptDataset, SupervisedDataset
-from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
-from coati.models.gpt import GPTRM, GPTActor, GPTCritic
-from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
-from coati.models.opt import OPTRM, OPTActor, OPTCritic
-from coati.trainer import PPOTrainer
-from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
-from torch.optim import Adam
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
-
-from colossalai.nn.optimizer import HybridAdam
-
-
-def main(args):
- # configure strategy
- if args.strategy == "ddp":
- strategy = DDPStrategy()
- elif args.strategy == "colossalai_gemini":
- strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)
- elif args.strategy == "colossalai_zero2":
- strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
- else:
- raise ValueError(f'Unsupported strategy "{args.strategy}"')
-
- if args.rm_path is not None:
- warnings.warn("LoRA weights should be merged with the model weights")
- state_dict = torch.load(args.rm_path, map_location="cpu")
-
- if args.lora_rank > 0:
- warnings.warn("Lora is not supported yet.")
- args.lora_rank = 0
-
- with strategy.model_init_context():
- # configure model
- if args.model == "gpt2":
- initial_model = GPTActor(pretrained=args.pretrain)
- elif args.model == "bloom":
- initial_model = BLOOMActor(pretrained=args.pretrain)
- elif args.model == "opt":
- initial_model = OPTActor(pretrained=args.pretrain)
- elif args.model == "llama":
- initial_model = LlamaActor(pretrained=args.pretrain)
- else:
- raise ValueError(f'Unsupported actor model "{args.model}"')
-
- if args.rm_model is None:
- rm_model_name = args.model
- else:
- rm_model_name = args.rm_model
-
- if rm_model_name == "gpt2":
- reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
- elif rm_model_name == "bloom":
- reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
- elif rm_model_name == "opt":
- reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
- elif rm_model_name == "llama":
- reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
- else:
- raise ValueError(f'Unsupported reward model "{rm_model_name}"')
-
- if args.rm_path is not None:
- reward_model.load_state_dict(state_dict, strict=False)
-
- initial_model.to(torch.bfloat16).to(torch.cuda.current_device())
- reward_model.to(torch.bfloat16).to(torch.cuda.current_device())
-
- if args.model == "gpt2":
- actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == "bloom":
- actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == "opt":
- actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == "llama":
- actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- else:
- raise ValueError(f'Unsupported actor model "{args.model}"')
-
- if rm_model_name == "gpt2":
- critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
- elif rm_model_name == "bloom":
- critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
- elif rm_model_name == "opt":
- critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
- elif rm_model_name == "llama":
- critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
- else:
- raise ValueError(f'Unsupported reward model "{rm_model_name}"')
-
- if args.rm_path is not None:
- critic.load_state_dict(state_dict, strict=False)
- del state_dict
-
- actor.to(torch.bfloat16).to(torch.cuda.current_device())
- critic.to(torch.bfloat16).to(torch.cuda.current_device())
-
- # configure optimizer
- if args.strategy.startswith("colossalai"):
- actor_optim = HybridAdam(actor.parameters(), lr=args.lr)
- critic_optim = HybridAdam(critic.parameters(), lr=args.lr)
- else:
- actor_optim = Adam(actor.parameters(), lr=args.lr)
- critic_optim = Adam(critic.parameters(), lr=args.lr)
-
- # configure tokenizer
- if args.model == "gpt2":
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == "bloom":
- tokenizer = BloomTokenizerFast.from_pretrained(
- "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
- )
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == "opt":
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == "llama":
- tokenizer = LlamaTokenizer.from_pretrained(
- "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
- )
-        tokenizer.eos_token = "</s>"  # LLaMA's end-of-sequence token
- tokenizer.pad_token = tokenizer.unk_token
- else:
- raise ValueError(f'Unsupported model "{args.model}"')
- # NOTE: generate() requires padding_side to be "left"
- tokenizer.padding_side = "left"
-
- prompt_dataset = PromptDataset(
- tokenizer=tokenizer,
- data_path=args.prompt_dataset,
- max_datasets_size=args.max_datasets_size,
- max_length=args.max_input_len,
- )
- if dist.is_initialized() and dist.get_world_size() > 1:
- prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
- else:
- prompt_sampler = None
- prompt_dataloader = DataLoader(
- prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.experience_batch_size
- )
-
- pretrain_dataset = SupervisedDataset(
- tokenizer=tokenizer,
- data_path=args.pretrain_dataset,
- max_datasets_size=args.max_datasets_size,
- max_length=args.max_input_len,
- )
- if dist.is_initialized() and dist.get_world_size() > 1:
- pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
- else:
- pretrain_sampler = None
- pretrain_dataloader = DataLoader(
- pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, batch_size=args.ptx_batch_size
- )
-
- # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
- (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
- (actor, actor_optim), (critic, critic_optim), reward_model, initial_model
- )
-
- # configure trainer
- trainer = PPOTrainer(
- strategy,
- actor,
- critic,
- reward_model,
- initial_model,
- actor_optim,
- critic_optim,
- tokenizer=tokenizer,
- kl_coef=args.kl_coef,
- ptx_coef=args.ptx_coef,
- train_batch_size=args.train_batch_size,
- max_length=args.max_seq_len,
- use_cache=True,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- offload_inference_models=args.strategy != "colossalai_gemini",
- )
-
- trainer.fit(
- num_episodes=args.num_episodes,
- num_collect_steps=args.num_collect_steps,
- num_update_steps=args.num_update_steps,
- prompt_dataloader=prompt_dataloader,
- pretrain_dataloader=pretrain_dataloader,
- log_dir=args.log_dir,
- use_wandb=args.use_wandb,
- )
-
- if args.lora_rank > 0 and args.merge_lora_weights:
- from coati.models.lora import LORA_MANAGER
-
- # NOTE: set model to eval to merge LoRA weights
- LORA_MANAGER.merge_weights = True
- actor.eval()
- # save model checkpoint after fitting
- strategy.save_pretrained(actor, path=args.save_path)
- # save optimizer checkpoint on all ranks
- if args.need_optim_ckpt:
- strategy.save_optimizer(
- actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
- )
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset")
- parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
- parser.add_argument("--max_datasets_size", type=int, default=50000)
- parser.add_argument(
- "--strategy",
- choices=["ddp", "colossalai_gemini", "colossalai_zero2"],
- default="colossalai_zero2",
- help="strategy to use",
- )
- parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
- parser.add_argument("--tokenizer", type=str, default=None)
- parser.add_argument("--pretrain", type=str, default=None)
- parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
- parser.add_argument("--rm_path", type=str, default=None)
- parser.add_argument("--rm_pretrain", type=str, default=None)
- parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
- parser.add_argument("--need_optim_ckpt", type=bool, default=False)
- parser.add_argument("--num_episodes", type=int, default=10)
- parser.add_argument("--num_collect_steps", type=int, default=10)
- parser.add_argument("--num_update_steps", type=int, default=5)
- parser.add_argument("--train_batch_size", type=int, default=8)
- parser.add_argument("--ptx_batch_size", type=int, default=1)
- parser.add_argument("--experience_batch_size", type=int, default=8)
- parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument("--merge_lora_weights", type=bool, default=True)
- parser.add_argument("--lr", type=float, default=1e-7)
- parser.add_argument("--kl_coef", type=float, default=0.1)
- parser.add_argument("--ptx_coef", type=float, default=0.9)
- parser.add_argument("--max_input_len", type=int, default=96)
- parser.add_argument("--max_seq_len", type=int, default=128)
- parser.add_argument("--log_dir", default="logs", type=str)
- parser.add_argument("--use_wandb", default=False, action="store_true")
- args = parser.parse_args()
- main(args)
diff --git a/applications/Chat/examples/train_prompts.sh b/applications/Chat/examples/train_prompts.sh
deleted file mode 100755
index d04c416015b1..000000000000
--- a/applications/Chat/examples/train_prompts.sh
+++ /dev/null
@@ -1,25 +0,0 @@
-set_n_least_used_CUDA_VISIBLE_DEVICES() {
- local n=${1:-"9999"}
- echo "GPU Memory Usage:"
- local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
- tail -n +2 |
- nl -v 0 |
- tee /dev/tty |
- sort -g -k 2 |
- awk '{print $1}' |
- head -n $n)
- export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
- echo "Now CUDA_VISIBLE_DEVICES is set to:"
- echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
-}
-
-set_n_least_used_CUDA_VISIBLE_DEVICES 2
-
-torchrun --standalone --nproc_per_node=2 train_prompts.py \
- --pretrain_dataset /path/to/data.json \
- --prompt_dataset /path/to/data.json \
- --strategy colossalai_zero2 \
- --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
- --train_batch_size 2
diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py
deleted file mode 100644
index fcdd29b2954b..000000000000
--- a/applications/Chat/examples/train_reward_model.py
+++ /dev/null
@@ -1,208 +0,0 @@
-import argparse
-import warnings
-
-import torch
-import torch.distributed as dist
-from coati.dataset import HhRlhfDataset, RmStaticDataset
-from coati.models import LogExpLoss, LogSigLoss
-from coati.models.bloom import BLOOMRM
-from coati.models.gpt import GPTRM
-from coati.models.llama import LlamaRM
-from coati.models.opt import OPTRM
-from coati.trainer import RewardModelTrainer
-from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
-from datasets import load_dataset
-from torch.optim import Adam
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
-from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-
-from colossalai.nn.optimizer import HybridAdam
-
-
-def train(args):
- # configure strategy
- if args.strategy == "ddp":
- strategy = DDPStrategy()
- elif args.strategy == "colossalai_gemini":
- strategy = GeminiStrategy(placement_policy="auto")
- elif args.strategy == "colossalai_zero2":
- strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
- else:
- raise ValueError(f'Unsupported strategy "{args.strategy}"')
-
- # configure model
- if args.lora_rank > 0:
- warnings.warn("Lora is not supported yet.")
- args.lora_rank = 0
-
- with strategy.model_init_context():
- if args.model == "bloom":
- model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == "opt":
- model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == "gpt2":
- model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == "llama":
- model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
- else:
- raise ValueError(f'Unsupported model "{args.model}"')
-
- model.to(torch.bfloat16).to(torch.cuda.current_device())
-
- if args.model_path is not None:
- state_dict = torch.load(args.model_path)
- model.load_state_dict(state_dict)
-
- # configure tokenizer
- if args.model == "gpt2":
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == "bloom":
- tokenizer = BloomTokenizerFast.from_pretrained(
- "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
- )
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == "opt":
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == "llama":
- tokenizer = LlamaTokenizer.from_pretrained(
- "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
- )
-        tokenizer.eos_token = "</s>"  # LLaMA's end-of-sequence token
- tokenizer.pad_token = tokenizer.unk_token
- else:
- raise ValueError(f'Unsupported model "{args.model}"')
-
- # configure optimizer
- if args.strategy.startswith("colossalai"):
- optim = HybridAdam(model.parameters(), lr=args.lr)
- else:
- optim = Adam(model.parameters(), lr=args.lr)
-
- # configure loss function
- if args.loss_fn == "log_sig":
- loss_fn = LogSigLoss()
- elif args.loss_fn == "log_exp":
- loss_fn = LogExpLoss()
- else:
- raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
-
- # prepare for data and dataset
- if args.subset is not None:
- data = load_dataset(args.dataset, data_dir=args.subset)
- else:
- data = load_dataset(args.dataset)
-
- train_data = data["train"].select(range(min(args.max_datasets_size, len(data["train"]))))
- eval_data = data["test"].select(range(min(args.max_datasets_size, len(data["test"]))))
-
- if args.dataset == "Dahoas/rm-static":
- train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len)
- eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len)
- elif args.dataset == "Anthropic/hh-rlhf":
- train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len)
- eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len)
- else:
- raise ValueError(f'Unsupported dataset "{args.dataset}"')
-
- if dist.is_initialized() and dist.get_world_size() > 1:
- train_sampler = DistributedSampler(
- train_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size(),
- )
- eval_sampler = DistributedSampler(
- eval_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size(),
- )
- else:
- train_sampler = None
- eval_sampler = None
-
- train_dataloader = DataLoader(
- train_dataset,
- shuffle=(train_sampler is None),
- sampler=train_sampler,
- batch_size=args.batch_size,
- pin_memory=True,
- )
-
- eval_dataloader = DataLoader(
- eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True
- )
-
-    lr_scheduler = CosineAnnealingLR(optim, len(train_dataloader) // 100)
- strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
- model = strategy_dict["model"]
- optim = strategy_dict["optimizer"]
- lr_scheduler = strategy_dict["lr_scheduler"]
- trainer = RewardModelTrainer(
- model=model,
- strategy=strategy,
- optim=optim,
- lr_scheduler=lr_scheduler,
- loss_fn=loss_fn,
- max_epochs=args.max_epochs,
- )
-
- trainer.fit(
- train_dataloader=train_dataloader,
- eval_dataloader=eval_dataloader,
- log_dir=args.log_dir,
- use_wandb=args.use_wandb,
- )
-
- if args.lora_rank > 0 and args.merge_lora_weights:
- from coati.models.lora import LORA_MANAGER
-
- # NOTE: set model to eval to merge LoRA weights
- LORA_MANAGER.merge_weights = True
- model.eval()
- # save model checkpoint after fitting on only rank0
- state_dict = model.state_dict()
- torch.save(state_dict, args.save_path)
- # save optimizer checkpoint on all ranks
- if args.need_optim_ckpt:
- strategy.save_optimizer(
- trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
- )
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="colossalai_zero2"
- )
- parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
- parser.add_argument("--tokenizer", type=str, default=None)
- parser.add_argument("--pretrain", type=str, default=None)
- parser.add_argument("--model_path", type=str, default=None)
- parser.add_argument("--need_optim_ckpt", type=bool, default=False)
- parser.add_argument(
- "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static"
- )
- parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)
- parser.add_argument("--max_datasets_size", type=int, default=1000000)
- parser.add_argument("--save_path", type=str, default="rm_ckpt")
- parser.add_argument("--max_epochs", type=int, default=1)
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--max_len", type=int, default=512)
- parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument("--merge_lora_weights", type=bool, default=True)
- parser.add_argument("--lr", type=float, default=9e-6)
- parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
- parser.add_argument("--log_dir", default="logs", type=str)
- parser.add_argument("--use_wandb", default=False, action="store_true")
- args = parser.parse_args()
- train(args)
diff --git a/applications/Chat/examples/train_rm.sh b/applications/Chat/examples/train_rm.sh
deleted file mode 100755
index c5ebaf708ddc..000000000000
--- a/applications/Chat/examples/train_rm.sh
+++ /dev/null
@@ -1,25 +0,0 @@
-set_n_least_used_CUDA_VISIBLE_DEVICES() {
- local n=${1:-"9999"}
- echo "GPU Memory Usage:"
- local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
- tail -n +2 |
- nl -v 0 |
- tee /dev/tty |
- sort -g -k 2 |
- awk '{print $1}' |
- head -n $n)
- export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
- echo "Now CUDA_VISIBLE_DEVICES is set to:"
- echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
-}
-
-set_n_least_used_CUDA_VISIBLE_DEVICES 2
-
-torchrun --standalone --nproc_per_node=2 train_reward_model.py \
- --pretrain 'gpt2' \
- --model 'gpt2' \
- --strategy colossalai_zero2 \
- --loss_fn 'log_exp' \
- --dataset 'Anthropic/hh-rlhf' \
- --batch_size 16 \
- --max_epochs 10
diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
deleted file mode 100644
index d00c04809a2d..000000000000
--- a/applications/Chat/examples/train_sft.py
+++ /dev/null
@@ -1,221 +0,0 @@
-import argparse
-import math
-import warnings
-
-import torch
-import torch.distributed as dist
-from coati.dataset import SFTDataset, SupervisedDataset
-from coati.models.bloom import BLOOMActor
-from coati.models.chatglm import ChatGLMActor
-from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
-from coati.models.gpt import GPTActor
-from coati.models.llama import LlamaActor
-from coati.models.opt import OPTActor
-from coati.trainer import SFTTrainer
-from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
-from datasets import load_dataset
-from torch.optim import Adam
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
-from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-from transformers.trainer import get_scheduler
-
-from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import HybridAdam
-
-
-def train(args):
- # configure strategy
- if args.strategy == "ddp":
- strategy = DDPStrategy()
- elif args.strategy == "colossalai_gemini":
- strategy = GeminiStrategy(placement_policy="auto")
- elif args.strategy == "colossalai_zero2":
- strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
- elif args.strategy == "colossalai_zero2_cpu":
- strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
- else:
- raise ValueError(f'Unsupported strategy "{args.strategy}"')
-
- # configure model
- if args.lora_rank > 0:
- warnings.warn("Lora is not supported yet.")
- args.lora_rank = 0
-
- with strategy.model_init_context():
- if args.model == "bloom":
- model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
- elif args.model == "opt":
- model = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
- elif args.model == "gpt2":
- model = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
- elif args.model == "llama":
- model = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
- elif args.model == "chatglm":
- model = ChatGLMActor(pretrained=args.pretrain)
- else:
- raise ValueError(f'Unsupported model "{args.model}"')
-
- model.to(torch.bfloat16).to(torch.cuda.current_device())
-
- # configure tokenizer
- if args.model == "gpt2":
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == "bloom":
- tokenizer = BloomTokenizerFast.from_pretrained(
- "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
- )
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == "opt":
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == "llama":
- tokenizer = LlamaTokenizer.from_pretrained(
- "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
- )
-        tokenizer.eos_token = "</s>"  # LLaMA's end-of-sequence token
- tokenizer.pad_token = tokenizer.unk_token
- elif args.model == "chatglm":
- tokenizer = ChatGLMTokenizer.from_pretrained(
- "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True
- )
- else:
- raise ValueError(f'Unsupported model "{args.model}"')
-
- # configure optimizer
- if args.strategy.startswith("colossalai"):
- optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
- else:
- optim = Adam(model.parameters(), lr=args.lr)
-
- # configure dataset
- if args.dataset == "yizhongw/self_instruct":
- train_data = load_dataset(args.dataset, "super_natural_instructions", split="train")
- eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test")
-
- if args.max_datasets_size is not None:
- train_data = train_data.select(range(min(args.max_datasets_size, len(train_data))))
- eval_data = eval_data.select(range(min(args.max_datasets_size, len(eval_data))))
-
- train_dataset = SFTDataset(train_data, tokenizer, args.max_len)
- eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len)
-
- else:
- train_dataset = SupervisedDataset(
- tokenizer=tokenizer,
- data_path=args.dataset,
- max_datasets_size=args.max_datasets_size,
- max_length=args.max_len,
- )
- eval_dataset = None
-
- if dist.is_initialized() and dist.get_world_size() > 1:
- train_sampler = DistributedSampler(
- train_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size(),
- )
- if eval_dataset is not None:
- eval_sampler = DistributedSampler(
- eval_dataset,
- shuffle=False,
- seed=42,
- drop_last=False,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size(),
- )
- else:
- train_sampler = None
- eval_sampler = None
-
- train_dataloader = DataLoader(
- train_dataset,
- shuffle=(train_sampler is None),
- sampler=train_sampler,
- batch_size=args.batch_size,
- pin_memory=True,
- )
- if eval_dataset is not None:
- eval_dataloader = DataLoader(
- eval_dataset,
- shuffle=(eval_sampler is None),
- sampler=eval_sampler,
- batch_size=args.batch_size,
- pin_memory=True,
- )
- else:
- eval_dataloader = None
-
- num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
- max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
- lr_scheduler = get_scheduler(
- "cosine", optim, num_warmup_steps=math.ceil(max_steps * 0.03), num_training_steps=max_steps
- )
- strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
- model = strategy_dict["model"]
- optim = strategy_dict["optimizer"]
- lr_scheduler = strategy_dict["lr_scheduler"]
- trainer = SFTTrainer(
- model=model,
- strategy=strategy,
- optim=optim,
- lr_scheduler=lr_scheduler,
- max_epochs=args.max_epochs,
- accumulation_steps=args.accumulation_steps,
- )
-
- logger = get_dist_logger()
- trainer.fit(
- train_dataloader=train_dataloader,
- eval_dataloader=eval_dataloader,
- logger=logger,
- log_dir=args.log_dir,
- use_wandb=args.use_wandb,
- )
-
- if args.lora_rank > 0 and args.merge_lora_weights:
- from coati.models.lora import LORA_MANAGER
-
- # NOTE: set model to eval to merge LoRA weights
- LORA_MANAGER.merge_weights = True
- model.eval()
- # save model checkpoint after fitting on only rank0
- strategy.save_pretrained(model, path=args.save_path, tokenizer=tokenizer)
- # save optimizer checkpoint on all ranks
- if args.need_optim_ckpt:
- strategy.save_optimizer(
- trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
- )
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--strategy",
- choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_zero2_cpu"],
- default="colossalai_zero2",
- )
- parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama", "chatglm"], default="bloom")
- parser.add_argument("--tokenizer", type=str, default=None)
- parser.add_argument("--pretrain", type=str, default=None)
- parser.add_argument("--dataset", type=str, default=None)
- parser.add_argument("--max_datasets_size", type=int, default=None)
- parser.add_argument("--save_path", type=str, default="output")
- parser.add_argument("--need_optim_ckpt", type=bool, default=False)
- parser.add_argument("--max_epochs", type=int, default=3)
- parser.add_argument("--batch_size", type=int, default=4)
- parser.add_argument("--max_len", type=int, default=512)
- parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument("--merge_lora_weights", type=bool, default=True)
- parser.add_argument("--lr", type=float, default=5e-6)
- parser.add_argument("--accumulation_steps", type=int, default=8)
- parser.add_argument("--log_dir", default="logs", type=str)
- parser.add_argument("--use_wandb", default=False, action="store_true")
- parser.add_argument("--grad_checkpoint", default=False, action="store_true")
- args = parser.parse_args()
- train(args)
diff --git a/applications/Chat/examples/train_sft.sh b/applications/Chat/examples/train_sft.sh
deleted file mode 100755
index 0fb4da3d3ce8..000000000000
--- a/applications/Chat/examples/train_sft.sh
+++ /dev/null
@@ -1,28 +0,0 @@
-set_n_least_used_CUDA_VISIBLE_DEVICES() {
- local n=${1:-"9999"}
- echo "GPU Memory Usage:"
- local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
- tail -n +2 |
- nl -v 0 |
- tee /dev/tty |
- sort -g -k 2 |
- awk '{print $1}' |
- head -n $n)
- export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
- echo "Now CUDA_VISIBLE_DEVICES is set to:"
- echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
-}
-
-set_n_least_used_CUDA_VISIBLE_DEVICES 4
-
-torchrun --standalone --nproc_per_node=4 train_sft.py \
- --pretrain "/path/to/LLaMa-7B/" \
- --model 'llama' \
- --strategy colossalai_zero2 \
- --save_path /path/to/Coati-7B \
- --dataset /path/to/data.json \
- --batch_size 4 \
- --accumulation_steps 8 \
- --lr 2e-5 \
- --max_datasets_size 512 \
- --max_epochs 1
diff --git a/applications/Chat/inference/benchmark.py b/applications/Chat/inference/benchmark.py
deleted file mode 100644
index dbb5490a63dc..000000000000
--- a/applications/Chat/inference/benchmark.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# Adapted from https://github.com/tloen/alpaca-lora/blob/main/generate.py
-
-import argparse
-from time import time
-
-import torch
-from coati.quant import llama_load_quant, low_resource_init
-from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
-
-
-def generate_prompt(instruction, input=None):
- if input:
- return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-
-### Instruction:
-{instruction}
-
-### Input:
-{input}
-
-### Response:"""
- else:
- return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
-
-### Instruction:
-{instruction}
-
-### Response:"""
-
-
-@torch.no_grad()
-def evaluate(
- model,
- tokenizer,
- instruction,
- input=None,
- temperature=0.1,
- top_p=0.75,
- top_k=40,
- num_beams=4,
- max_new_tokens=128,
- **kwargs,
-):
- prompt = generate_prompt(instruction, input)
- inputs = tokenizer(prompt, return_tensors="pt")
- input_ids = inputs["input_ids"].cuda()
- generation_config = GenerationConfig(
- temperature=temperature,
- top_p=top_p,
- top_k=top_k,
- num_beams=num_beams,
- **kwargs,
- )
- generation_output = model.generate(
- input_ids=input_ids,
- generation_config=generation_config,
- return_dict_in_generate=True,
- output_scores=True,
- max_new_tokens=max_new_tokens,
- do_sample=True,
- )
- s = generation_output.sequences[0]
- output = tokenizer.decode(s)
- n_new_tokens = s.size(0) - input_ids.size(1)
- return output.split("### Response:")[1].strip(), n_new_tokens
-
-
-instructions = [
- "Tell me about alpacas.",
- "Tell me about the president of Mexico in 2019.",
- "Tell me about the king of France in 2019.",
- "List all Canadian provinces in alphabetical order.",
- "Write a Python program that prints the first 10 Fibonacci numbers.",
- "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",
- "Tell me five words that rhyme with 'shock'.",
- "Translate the sentence 'I have no mouth but I must scream' into Spanish.",
- "Count up from 1 to 500.",
- # ===
- "How to play support in legends of league",
- "Write a Python program that calculate Fibonacci numbers.",
-]
-inst = [instructions[0]] * 4
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "pretrained",
- help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
- )
- parser.add_argument(
- "--quant",
- choices=["8bit", "4bit"],
- default=None,
- help="Quantization mode. Default: None (no quantization, fp16).",
- )
- parser.add_argument(
- "--gptq_checkpoint",
- default=None,
- help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
- )
- parser.add_argument(
- "--gptq_group_size",
- type=int,
- default=128,
- help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
- )
- args = parser.parse_args()
-
- if args.quant == "4bit":
- assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint."
-
- tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
-
- if args.quant == "4bit":
- with low_resource_init():
- config = LlamaConfig.from_pretrained(args.pretrained)
- model = LlamaForCausalLM(config)
- model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
- model.cuda()
- else:
- model = LlamaForCausalLM.from_pretrained(
- args.pretrained,
- load_in_8bit=(args.quant == "8bit"),
- torch_dtype=torch.float16,
- device_map="auto",
- )
- if args.quant != "8bit":
- model.half() # seems to fix bugs for some users.
- model.eval()
-
- total_tokens = 0
- start = time()
- for instruction in instructions:
- print(f"Instruction: {instruction}")
- resp, tokens = evaluate(model, tokenizer, instruction, temperature=0.2, num_beams=1)
- total_tokens += tokens
- print(f"Response: {resp}")
- print("\n----------------------------\n")
- duration = time() - start
- print(f"Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s")
- print(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
diff --git a/applications/Chat/inference/tests/test_chat_prompt.py b/applications/Chat/inference/tests/test_chat_prompt.py
deleted file mode 100644
index 9835e71894c6..000000000000
--- a/applications/Chat/inference/tests/test_chat_prompt.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import os
-
-from transformers import AutoTokenizer
-from utils import ChatPromptProcessor, Dialogue
-
-CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."
-tokenizer = AutoTokenizer.from_pretrained(os.environ["PRETRAINED_PATH"])
-
-samples = [
- (
- [
- Dialogue(
- instruction="Who is the best player in the history of NBA?",
- response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
- ),
- Dialogue(instruction="continue this talk", response=""),
- ],
- 128,
- "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n",
- ),
- (
- [
- Dialogue(
- instruction="Who is the best player in the history of NBA?",
- response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
- ),
- Dialogue(instruction="continue this talk", response=""),
- ],
- 200,
- "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n",
- ),
- (
- [
- Dialogue(
- instruction="Who is the best player in the history of NBA?",
- response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
- ),
- Dialogue(instruction="continue this talk", response=""),
- ],
- 211,
- "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n",
- ),
- (
- [
- Dialogue(instruction="Who is the best player in the history of NBA?", response=""),
- ],
- 128,
- "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n",
- ),
-]
-
-
-def test_chat_prompt_processor():
- processor = ChatPromptProcessor(tokenizer, CONTEXT, 256)
- for history, max_new_tokens, result in samples:
- prompt = processor.preprocess_prompt(history, max_new_tokens)
- assert prompt == result
-
-
-if __name__ == "__main__":
- test_chat_prompt_processor()
diff --git a/applications/Chat/inference/utils.py b/applications/Chat/inference/utils.py
deleted file mode 100644
index af018adf6e9d..000000000000
--- a/applications/Chat/inference/utils.py
+++ /dev/null
@@ -1,209 +0,0 @@
-import json
-import re
-from threading import Lock
-from typing import Any, Callable, Generator, List, Optional
-
-import jieba
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-from pydantic import BaseModel, Field
-
-try:
- from transformers.generation_logits_process import (
- LogitsProcessorList,
- TemperatureLogitsWarper,
- TopKLogitsWarper,
- TopPLogitsWarper,
- )
-except ImportError:
- from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
-
-
-def prepare_logits_processor(
- top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
-) -> LogitsProcessorList:
- processor_list = LogitsProcessorList()
- if temperature is not None and temperature != 1.0:
- processor_list.append(TemperatureLogitsWarper(temperature))
- if top_k is not None and top_k != 0:
- processor_list.append(TopKLogitsWarper(top_k))
- if top_p is not None and top_p < 1.0:
- processor_list.append(TopPLogitsWarper(top_p))
- return processor_list
-
-
-def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
- if dist.is_initialized() and dist.get_world_size() > 1:
- # consider DP
- unfinished_sequences = unfinished_sequences.clone()
- dist.all_reduce(unfinished_sequences)
- return unfinished_sequences.max() == 0
-
-
-def sample_streamingly(
- model: nn.Module,
- input_ids: torch.Tensor,
- max_generate_tokens: int,
- early_stopping: bool = False,
- eos_token_id: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None,
- prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
- update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
- **model_kwargs,
-) -> Generator:
- logits_processor = prepare_logits_processor(top_k, top_p, temperature)
- unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
-
- for _ in range(max_generate_tokens):
- model_inputs = (
- prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
- )
- outputs = model(**model_inputs)
-
- next_token_logits = outputs["logits"][:, -1, :]
- # pre-process distribution
- next_token_logits = logits_processor(input_ids, next_token_logits)
- # sample
- probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-
- # finished sentences should have their next token be a padding token
- if eos_token_id is not None:
- if pad_token_id is None:
- raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
- next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
- yield next_tokens
-
- # update generated ids, model inputs for next step
- input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
- if update_model_kwargs_fn is not None:
- model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs)
-
- # if eos_token was found in one sentence, set sentence to finished
- if eos_token_id is not None:
- unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
-
- # stop when each sentence is finished if early_stopping=True
- if early_stopping and _is_sequence_finished(unfinished_sequences):
- break
-
-
-def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
- if "past_key_values" in outputs:
- model_kwargs["past"] = outputs["past_key_values"]
- else:
- model_kwargs["past"] = None
-
- # update token_type_ids with last value
- if "token_type_ids" in model_kwargs:
- token_type_ids = model_kwargs["token_type_ids"]
- model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
-
- # update attention mask
- if "attention_mask" in model_kwargs:
- attention_mask = model_kwargs["attention_mask"]
- model_kwargs["attention_mask"] = torch.cat(
- [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
- )
-
- return model_kwargs
-
-
-class Dialogue(BaseModel):
- instruction: str = Field(min_length=1, example="Count up from 1 to 500.")
- response: str = Field(example="")
-
-
-def _format_dialogue(instruction: str, response: str = ""):
- return f"\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}"
-
-
-STOP_PAT = re.compile(r"(###|instruction:).*", flags=(re.I | re.S))
-
-
-class ChatPromptProcessor:
- SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt."
-
- def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []):
- self.tokenizer = tokenizer
- self.context = context
- self.max_len = max_len
- self.censored_words = set([word.lower() for word in censored_words])
- # These will be initialized after the first call of preprocess_prompt()
- self.context_len: Optional[int] = None
- self.dialogue_placeholder_len: Optional[int] = None
-
- def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str:
- if self.context_len is None:
- self.context_len = len(self.tokenizer(self.context)["input_ids"])
- if self.dialogue_placeholder_len is None:
- self.dialogue_placeholder_len = len(
- self.tokenizer(_format_dialogue(""), add_special_tokens=False)["input_ids"]
- )
- prompt = self.context
- # the last dialogue must be in the prompt
- last_dialogue = history.pop()
- # the response of the last dialogue is empty
- assert last_dialogue.response == ""
- if (
- len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)["input_ids"])
- + max_new_tokens
- + self.context_len
- >= self.max_len
- ):
- # to avoid truncate placeholder, apply truncate to the original instruction
- instruction_truncated = self.tokenizer(
- last_dialogue.instruction,
- add_special_tokens=False,
- truncation=True,
- max_length=(self.max_len - max_new_tokens - self.context_len - self.dialogue_placeholder_len),
- )["input_ids"]
- instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip()
- prompt += _format_dialogue(instruction_truncated)
- return prompt
-
- res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)["input_ids"])
-
- rows = []
- for dialogue in history[::-1]:
- text = _format_dialogue(dialogue.instruction, dialogue.response)
- cur_len = len(self.tokenizer(text, add_special_tokens=False)["input_ids"])
- if res_len - cur_len < 0:
- break
- res_len -= cur_len
- rows.insert(0, text)
- prompt += "".join(rows) + _format_dialogue(last_dialogue.instruction)
- return prompt
-
- def postprocess_output(self, output: str) -> str:
- output = STOP_PAT.sub("", output)
- return output.strip()
-
- def has_censored_words(self, text: str) -> bool:
- if len(self.censored_words) == 0:
- return False
- intersection = set(jieba.cut(text.lower())) & self.censored_words
- return len(intersection) > 0
-
-
-class LockedIterator:
- def __init__(self, it, lock: Lock) -> None:
- self.lock = lock
- self.it = iter(it)
-
- def __iter__(self):
- return self
-
- def __next__(self):
- with self.lock:
- return next(self.it)
-
-
-def load_json(path: str):
- with open(path) as f:
- return json.load(f)
diff --git a/applications/Chat/requirements-test.txt b/applications/Chat/requirements-test.txt
deleted file mode 100644
index 93d48bcb6f79..000000000000
--- a/applications/Chat/requirements-test.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-pytest
-colossalai==0.3.3
diff --git a/applications/Chat/requirements.txt b/applications/Chat/requirements.txt
deleted file mode 100644
index e56aaca0e7cb..000000000000
--- a/applications/Chat/requirements.txt
+++ /dev/null
@@ -1,14 +0,0 @@
-transformers>=4.20.1
-tqdm
-datasets
-loralib
-colossalai==0.3.3
-torch<2.0.0, >=1.12.1
-langchain
-tokenizers
-fastapi
-sse_starlette
-wandb
-sentencepiece
-gpustat
-tensorboard
diff --git a/applications/Chat/tests/test_benchmarks.sh b/applications/Chat/tests/test_benchmarks.sh
deleted file mode 100755
index 3fdb25181342..000000000000
--- a/applications/Chat/tests/test_benchmarks.sh
+++ /dev/null
@@ -1,33 +0,0 @@
-#!/bin/bash
-
-set -xue
-
-echo "Hint: You can run this script with 'verbose' as the first argument to run all strategies."
-
-if [[ $# -ne 0 && "$1" == "verbose" ]]; then
- STRATEGIES=(
- 'ddp'
- 'colossalai_gemini'
- 'colossalai_gemini_cpu'
- 'colossalai_zero2'
- 'colossalai_zero2_cpu'
- 'colossalai_zero1'
- 'colossalai_zero1_cpu'
- )
-else
- STRATEGIES=(
- 'colossalai_zero2'
- )
-fi
-
-BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
-BENCHMARKS_DIR=$BASE_DIR/benchmarks
-
-echo "[Test]: testing benchmarks ..."
-
-for strategy in ${STRATEGIES[@]}; do
- torchrun --standalone --nproc_per_node 1 $BENCHMARKS_DIR/benchmark_opt_lora_dummy.py \
- --model 125m --critic_model 125m --strategy ${strategy} --lora_rank 4 \
- --num_episodes 2 --num_collect_steps 4 --num_update_steps 2 \
- --train_batch_size 2 --experience_batch_size 4
-done
diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py
deleted file mode 100644
index 9c08aa36c9b4..000000000000
--- a/applications/Chat/tests/test_checkpoint.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import os
-import tempfile
-from contextlib import nullcontext
-
-import pytest
-import torch
-import torch.distributed as dist
-from coati.models.gpt import GPTActor
-from coati.models.utils import calc_action_log_probs
-from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
-from transformers.models.gpt2.configuration_gpt2 import GPT2Config
-
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
-
-
-def get_data(batch_size: int, seq_len: int = 10) -> dict:
- input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda")
- attention_mask = torch.ones_like(input_ids)
- return dict(input_ids=input_ids, attention_mask=attention_mask)
-
-
-def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8):
- data = get_data(batch_size)
- action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool)
- actor_logits = actor(data["input_ids"], data["attention_mask"])["logits"]
- action_log_probs = calc_action_log_probs(actor_logits, data["input_ids"], action_mask.size(1))
- loss = action_log_probs.sum()
- strategy.backward(loss, actor, actor_optim)
- strategy.optimizer_step(actor_optim)
-
-
-def run_test_checkpoint(strategy_name: str, shard: bool):
- if strategy_name == "ddp":
- strategy = DDPStrategy()
- elif strategy_name == "colossalai_gemini":
- strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
- elif strategy_name == "colossalai_zero2":
- strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
- else:
- raise ValueError(f"Unsupported strategy '{strategy_name}'")
-
- with strategy.model_init_context():
- actor = GPTActor(config=GPT_CONFIG).cuda()
- actor_optim = HybridAdam(actor.parameters())
- actor, actor_optim = strategy.prepare((actor, actor_optim))
-
- train_step(strategy, actor, actor_optim)
-
- ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
-
- with ctx as dirname:
- rank0_dirname = [dirname]
- dist.broadcast_object_list(rank0_dirname)
- rank0_dirname = rank0_dirname[0]
-
- model_path = os.path.join(rank0_dirname, "model" if shard else f"model.pt")
- strategy.save_model(actor, model_path)
- optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt")
- strategy.save_optimizer(actor_optim, optim_path)
- dist.barrier()
-
- strategy.load_model(actor, model_path, strict=False)
- strategy.load_optimizer(actor_optim, optim_path)
- dist.barrier()
-
- train_step(strategy, actor, actor_optim)
-
-
-def run_dist(rank: int, world_size: int, port: int, strategy_name: str, shard: bool):
- os.environ["RANK"] = str(rank)
- os.environ["LOCAL_RANK"] = str(rank)
- os.environ["WORLD_SIZE"] = str(world_size)
- os.environ["MASTER_ADDR"] = "localhost"
- os.environ["MASTER_PORT"] = str(port)
- run_test_checkpoint(strategy_name, shard)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [4])
-@pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"])
-@pytest.mark.parametrize("shard", [False, True])
-@rerun_if_address_is_in_use()
-def test_checkpoint(world_size: int, strategy_name: str, shard: bool):
- spawn(run_dist, world_size, strategy_name=strategy_name, shard=shard)
-
-
-if __name__ == "__main__":
- test_checkpoint(2, "colossalai_gemini", shard=False)
diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py
deleted file mode 100644
index ec61bbb13fd7..000000000000
--- a/applications/Chat/tests/test_dataset.py
+++ /dev/null
@@ -1,241 +0,0 @@
-import json
-import os
-import tempfile
-from typing import Optional
-
-import pytest
-import torch
-from coati.dataset.prompt_dataset import PromptDataset
-from coati.dataset.reward_dataset import HhRlhfDataset, RmStaticDataset
-from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDataset
-from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
-from datasets import load_dataset
-from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer
-from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-
-SFT_DATASET = [
- {
- "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
- "input": "",
- "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
- "id": 0,
- },
- {
- "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level",
- "input": "",
- "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
- "id": 1,
- },
- {
- "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise",
- "input": "",
- "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
- "id": 2,
- },
-]
-
-PROMPT_DATASET = [
- {
- "instruction": 'Edit this paragraph to make it more concise: "Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends."',
- "id": 0,
- },
- {"instruction": "Write a descriptive paragraph about a memorable vacation you went on", "id": 1},
- {"instruction": "Write a persuasive essay arguing why homework should be banned in schools", "id": 2},
- {"instruction": "Create a chart comparing the statistics on student debt in the United States.", "id": 3},
-]
-
-
-def make_tokenizer(model: str):
- if model == "gpt2":
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
- tokenizer.pad_token = tokenizer.eos_token
- elif model == "bloom":
- tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
- tokenizer.pad_token = tokenizer.eos_token
- elif model == "opt":
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
- tokenizer.pad_token = tokenizer.eos_token
- elif model == "llama":
- tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
- tokenizer.pad_token = tokenizer.unk_token
- elif model == "chatglm":
- tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
- else:
- raise ValueError(f"Unsupported model '{model}'")
- return tokenizer
-
-
-def check_content(input_ids_stripped: torch.Tensor, tokenizer: PreTrainedTokenizer, model: str):
- if model == "opt":
- # NOTE: Contrary to GPT2, OPT adds the EOS token to the beginning of every prompt.
- assert input_ids_stripped[0] == tokenizer.eos_token_id
- input_ids_stripped = input_ids_stripped[1:]
- elif model == "llama":
- assert input_ids_stripped[0] == tokenizer.bos_token_id
- input_ids_stripped = input_ids_stripped[1:]
- elif model == "chatglm":
- assert input_ids_stripped[0] == tokenizer.bos_token_id
- assert input_ids_stripped[-1] == tokenizer.eos_token_id
- input_ids_stripped = input_ids_stripped[1:-1]
- assert torch.all(input_ids_stripped != tokenizer.pad_token_id)
- assert torch.all(input_ids_stripped != tokenizer.bos_token_id)
- assert torch.all(input_ids_stripped != tokenizer.eos_token_id)
- assert input_ids_stripped != tokenizer.sep_token_id
- assert input_ids_stripped != tokenizer.cls_token_id
- if model == "chatglm":
- assert torch.all(input_ids_stripped != tokenizer.mask_token_id)
- else:
- assert input_ids_stripped != tokenizer.mask_token_id
-
-
-@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
-@pytest.mark.parametrize("max_length", [32, 1024])
-@pytest.mark.parametrize("max_datasets_size", [2])
-def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int):
- with tempfile.TemporaryDirectory() as tmp_dir:
- dataset_name = "prompt_dataset.json"
- with open(os.path.join(tmp_dir, dataset_name), "w") as f:
- json.dump(PROMPT_DATASET, f)
- tokenizer = make_tokenizer(model)
- assert tokenizer.padding_side in ("left", "right")
- prompt_dataset = PromptDataset(
- data_path=os.path.join(tmp_dir, dataset_name),
- tokenizer=tokenizer,
- max_datasets_size=max_datasets_size,
- max_length=max_length,
- )
- assert len(prompt_dataset) == min(max_datasets_size, len(PROMPT_DATASET))
- for i in range(len(prompt_dataset)):
- assert isinstance(prompt_dataset[i], dict)
- assert list(prompt_dataset[i].keys()) == ["input_ids", "attention_mask"]
- input_ids = prompt_dataset[i]["input_ids"]
- attention_mask = prompt_dataset[i]["attention_mask"]
- attention_mask = attention_mask.bool()
- assert input_ids.shape == attention_mask.shape == torch.Size([max_length])
- assert torch.all(input_ids[torch.logical_not(attention_mask)] == tokenizer.pad_token_id)
- check_content(input_ids.masked_select(attention_mask), tokenizer, model)
-
-
-@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
-@pytest.mark.parametrize(
- ["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), ("Dahoas/rm-static", None)]
-)
-@pytest.mark.parametrize("max_datasets_size", [32])
-@pytest.mark.parametrize("max_length", [32, 1024])
-def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int):
- data = load_dataset(dataset_path, data_dir=subset)
- assert max_datasets_size <= len(data["train"]) and max_datasets_size <= len(data["test"])
- train_data = data["train"].select(range(max_datasets_size))
- test_data = data["test"].select(range(max_datasets_size))
- tokenizer = make_tokenizer(model)
- assert tokenizer.padding_side in ("left", "right")
-
- if dataset_path == "Anthropic/hh-rlhf":
- train_dataset = HhRlhfDataset(train_data, tokenizer, max_length)
- test_dataset = HhRlhfDataset(test_data, tokenizer, max_length)
- elif dataset_path == "Dahoas/rm-static":
- train_dataset = RmStaticDataset(train_data, tokenizer, max_length)
- test_dataset = RmStaticDataset(test_data, tokenizer, max_length)
- else:
- raise ValueError(f'Unsupported dataset "{dataset_path}"')
-
- assert len(train_dataset) == len(test_dataset) == max_datasets_size
- for i in range(max_datasets_size):
- chosen_ids, c_mask, reject_ids, r_mask = train_dataset[i]
- assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length])
- c_mask = c_mask.to(torch.bool)
- r_mask = r_mask.to(torch.bool)
- if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
- check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model)
- assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id)
- else:
- check_content(chosen_ids.masked_select(c_mask), tokenizer, model)
- assert torch.all(c_mask)
- if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id:
- check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model)
- assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id)
- else:
- check_content(reject_ids.masked_select(r_mask), tokenizer, model)
- assert torch.all(r_mask)
-
- chosen_ids, c_mask, reject_ids, r_mask = test_dataset[i]
- assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length])
- c_mask = c_mask.to(torch.bool)
- r_mask = r_mask.to(torch.bool)
- if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
- check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model)
- assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id)
- else:
- check_content(chosen_ids.masked_select(c_mask), tokenizer, model)
- assert torch.all(c_mask)
- if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id:
- check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model)
- assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id)
- else:
- check_content(reject_ids.masked_select(r_mask), tokenizer, model)
- assert torch.all(r_mask)
-
-
-@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"])
-@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
-@pytest.mark.parametrize("max_dataset_size", [2])
-@pytest.mark.parametrize("max_length", [32, 1024])
-def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: int, max_length: int):
- tokenizer = make_tokenizer(model)
- if dataset_path == "yizhongw/self_instruct":
- data = load_dataset(dataset_path, "super_natural_instructions")
- train_data = data["train"].select(range(max_dataset_size))
- sft_dataset = SFTDataset(train_data, tokenizer, max_length)
- else:
- with tempfile.TemporaryDirectory() as tmp_dir:
- dataset_name = "sft_dataset.json"
- with open(os.path.join(tmp_dir, dataset_name), "w") as f:
- json.dump(SFT_DATASET, f)
- sft_dataset = SupervisedDataset(
- tokenizer=tokenizer,
- data_path=os.path.join(tmp_dir, dataset_name),
- max_datasets_size=max_dataset_size,
- max_length=max_length,
- )
- assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET))
-
- if isinstance(tokenizer, ChatGLMTokenizer):
- for i in range(max_dataset_size):
- assert isinstance(sft_dataset[i], dict)
- assert list(sft_dataset[i].keys()) == ["input_ids", "labels"]
- input_ids = sft_dataset[i]["input_ids"]
- labels = sft_dataset[i]["labels"]
- assert input_ids.shape == labels.shape == torch.Size([max_length])
-
- ignore_mask = labels == IGNORE_INDEX
- assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id
- check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model)
- return
-
- for i in range(max_dataset_size):
- assert isinstance(sft_dataset[i], dict)
- assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"]
- input_ids = sft_dataset[i]["input_ids"]
- labels = sft_dataset[i]["labels"]
- attention_mask = sft_dataset[i]["attention_mask"].to(torch.bool)
- assert input_ids.shape == labels.shape == attention_mask.shape == torch.Size([max_length])
- if input_ids.masked_select(attention_mask)[-1] == tokenizer.eos_token_id:
- check_content(input_ids.masked_select(attention_mask)[:-1], tokenizer, model)
- assert torch.all(input_ids.masked_select(torch.logical_not(attention_mask)) == tokenizer.pad_token_id)
- else:
- check_content(input_ids.masked_select(attention_mask), tokenizer, model)
- assert torch.all(attention_mask)
- ignore_mask = labels == IGNORE_INDEX
- prompt_mask = torch.logical_and(ignore_mask, attention_mask)
- check_content(input_ids.masked_select(prompt_mask), tokenizer, model)
- assert torch.all(input_ids.masked_select(ignore_mask ^ prompt_mask) == tokenizer.pad_token_id)
-
-
-if __name__ == "__main__":
- test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256)
-
- test_reward_dataset(
- model="gpt2", dataset_path="Anthropic/hh-rlhf", subset="harmless-base", max_datasets_size=8, max_length=256
- )
-
- test_prompt_dataset(model="opt", max_datasets_size=2, max_length=128)
diff --git a/applications/Chat/tests/test_experience.py b/applications/Chat/tests/test_experience.py
deleted file mode 100644
index a9591259800d..000000000000
--- a/applications/Chat/tests/test_experience.py
+++ /dev/null
@@ -1,130 +0,0 @@
-import copy
-import os
-
-import pytest
-import torch
-import torch.distributed as dist
-from coati.experience_buffer import NaiveExperienceBuffer
-from coati.experience_maker import NaiveExperienceMaker
-from coati.models.base import RewardModel
-from coati.models.gpt import GPTActor, GPTCritic
-from coati.trainer.ppo import _set_default_generate_kwargs
-from coati.trainer.strategies import DDPStrategy, GeminiStrategy
-from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
-from transformers.models.gpt2.configuration_gpt2 import GPT2Config
-
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
-
-
-def get_data(batch_size: int, seq_len: int = 10) -> dict:
- input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda")
- attention_mask = torch.ones_like(input_ids)
- return dict(input_ids=input_ids, attention_mask=attention_mask)
-
-
-def gather_and_equal(tensor: torch.Tensor) -> bool:
- world_size = dist.get_world_size()
- outputs = [torch.empty_like(tensor) for _ in range(world_size)]
- dist.all_gather(outputs, tensor.contiguous())
- for t in outputs[1:]:
- if not torch.equal(outputs[0], t):
- return False
- return True
-
-
-def make_and_consume_experience(strategy):
- EXPERIENCE_BATCH_SIZE = 4
- SAMPLE_BATCH_SIZE = 2
-
- if strategy == "ddp":
- strategy = DDPStrategy()
- elif strategy == "colossalai-zero2":
- strategy = LowLevelZeroStrategy()
- elif strategy == "colossalai-gemini":
- strategy = GeminiStrategy(placement_policy="static")
- else:
- raise ValueError(f'Unsupported strategy "{strategy}"')
-
- with strategy.model_init_context():
- actor = GPTActor(config=GPT_CONFIG).cuda()
- critic = GPTCritic(config=GPT_CONFIG).cuda()
-
- initial_model = GPTActor(config=GPT_CONFIG).cuda()
- reward_model = RewardModel(model=copy.deepcopy(critic.model)).cuda()
-
- actor, critic, initial_model, reward_model = strategy.prepare(actor, critic, initial_model, reward_model)
-
- class MockTokenizer:
- def __init__(self):
- self.padding_side = "left"
- self.eos_token_id = 0
- self.pad_token_id = 0
-
- tokenizer = MockTokenizer()
- experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer)
- data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
-
- generate_kwargs = dict(do_sample=True, max_length=16)
- generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
-
- # experience of all ranks should be the same
- for _ in range(2):
- data = get_data(EXPERIENCE_BATCH_SIZE)
- assert gather_and_equal(data["input_ids"])
- assert gather_and_equal(data["attention_mask"])
- experience = experience_maker.make_experience(**data, do_sample=True, max_length=16)
- assert gather_and_equal(experience.sequences)
- assert gather_and_equal(experience.action_log_probs)
- assert gather_and_equal(experience.values)
- assert gather_and_equal(experience.reward)
- assert gather_and_equal(experience.advantages)
- assert gather_and_equal(experience.action_mask)
- assert gather_and_equal(experience.attention_mask)
- data_buffer.append(experience)
-
- # data buffer's data should be the same
- buffer_size = torch.tensor([len(data_buffer)], device="cuda")
- assert gather_and_equal(buffer_size)
- for item in data_buffer.items:
- assert gather_and_equal(item.sequences)
- assert gather_and_equal(item.action_log_probs)
- assert gather_and_equal(item.values)
- assert gather_and_equal(item.reward)
- assert gather_and_equal(item.advantages)
- assert gather_and_equal(item.action_mask)
- assert gather_and_equal(item.attention_mask)
-
- # dataloader of each rank should have the same size and different batch
- dataloader = strategy.setup_dataloader(data_buffer)
- dataloader_size = torch.tensor([len(dataloader)], device="cuda")
- assert gather_and_equal(dataloader_size)
- for experience in dataloader:
- assert not gather_and_equal(experience.sequences)
- assert not gather_and_equal(experience.action_log_probs)
- assert not gather_and_equal(experience.values)
- assert not gather_and_equal(experience.reward)
- assert not gather_and_equal(experience.advantages)
- # action mask and attention mask may be same
-
-
-def run_dist(rank, world_size, port, strategy):
- os.environ["RANK"] = str(rank)
- os.environ["LOCAL_RANK"] = str(rank)
- os.environ["WORLD_SIZE"] = str(world_size)
- os.environ["MASTER_ADDR"] = "localhost"
- os.environ["MASTER_PORT"] = str(port)
- make_and_consume_experience(strategy)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("strategy", ["ddp", "colossalai-zero2", "colossalai-gemini"])
-@rerun_if_address_is_in_use()
-def test_experience(world_size, strategy):
- spawn(run_dist, world_size, strategy=strategy)
-
-
-if __name__ == "__main__":
- test_experience(2, "colossalai-zero2")
diff --git a/applications/Chat/tests/test_inference.sh b/applications/Chat/tests/test_inference.sh
deleted file mode 100755
index 849db06e58ab..000000000000
--- a/applications/Chat/tests/test_inference.sh
+++ /dev/null
@@ -1,11 +0,0 @@
-set -xue
-
-BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
-EXAMPLES_DIR=$BASE_DIR/examples
-
-echo "[Test]: testing inference ..."
-
-# HACK: skip llama due to oom
-for model in 'gpt2' 'bloom' 'opt'; do
- python $EXAMPLES_DIR/inference.py --model $model
-done
diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py
deleted file mode 100644
index b2c22ac6a3b9..000000000000
--- a/applications/Chat/tests/test_models.py
+++ /dev/null
@@ -1,245 +0,0 @@
-import copy
-from typing import Any, Callable, Dict, Tuple
-
-import pytest
-import torch
-import torch.nn as nn
-from coati.models.base import Actor, Critic, RewardModel, get_base_model
-from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
-from coati.models.chatglm import ChatGLMActor
-from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
-from coati.models.generation import generate
-from coati.models.gpt import GPTRM, GPTActor, GPTCritic
-from coati.models.llama import LlamaActor
-from coati.models.lora import LoraLinear, convert_to_lora_module
-from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
-from coati.models.opt import OPTRM, OPTActor, OPTCritic
-from coati.models.utils import calc_action_log_probs, masked_mean
-
-
-@pytest.mark.parametrize("batch_size", [4])
-@pytest.mark.parametrize("seq_len", [32])
-@pytest.mark.parametrize(
- "actor_maker",
- [
- lambda: BLOOMActor(),
- lambda: GPTActor(),
- # HACK: skip llama due to long execution time
- # lambda: LlamaActor(),
- lambda: OPTActor(),
- ],
-)
-@pytest.mark.parametrize(
- "generate_kwargs",
- [
- {
- "max_length": 64,
- "use_cache": True,
- "do_sample": True,
- "temperature": 1.0,
- "top_k": 50,
- }
- ],
-)
-def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
- class MockTokenizer:
- def __init__(self):
- self.padding_side = "left"
- self.eos_token_id = 0
- self.pad_token_id = 0
-
- actor = actor_maker()
- input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
- tokenizer = MockTokenizer()
- sequences = generate(actor.cuda(), input_ids, tokenizer, **generate_kwargs)
- assert sequences.shape == (batch_size, generate_kwargs["max_length"])
-
-
-def test_utils():
- fn_input = {"tensor": torch.ones((10,)), "mask": torch.randint(0, 2, (10,))}
- fn_output = masked_mean(dim=0, **fn_input)
- assert fn_output.dim() == 0
- assert torch.allclose(fn_output, torch.tensor(1.0))
-
- batch_size = 4
- seq_len = 32
- num_labels = 10
- num_actions = 2
- fn_input = {
- "logits": torch.randn((batch_size, seq_len, num_labels)),
- "sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
- "num_actions": num_actions,
- }
- fn_output = calc_action_log_probs(**fn_input)
- assert fn_output.shape == (batch_size, num_actions)
-
-
-@pytest.mark.parametrize("lora_rank", [4])
-@pytest.mark.parametrize("num_dim", [32])
-@pytest.mark.parametrize("num_layers", [4])
-def test_lora(lora_rank: int, num_dim: int, num_layers: int):
- model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)])
- lora_model = convert_to_lora_module(model, lora_rank)
- assert isinstance(lora_model, nn.ModuleList)
- for i in range(num_layers):
- assert isinstance(lora_model[i], LoraLinear)
- assert lora_model[i].lora_A.shape == (lora_rank, num_dim)
- assert lora_model[i].lora_B.shape == (num_dim, lora_rank)
-
- old_model = copy.deepcopy(lora_model)
- for i in range(num_layers):
- assert isinstance(lora_model[i], LoraLinear)
- assert torch.allclose(old_model[i].weight, lora_model[i].weight)
- assert torch.allclose(old_model[i].bias, lora_model[i].bias)
- assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A)
- optimizer = torch.optim.Adam(lora_model.parameters())
- x = torch.randn(8, num_dim)
- for i in range(num_layers):
- x = lora_model[i](x)
- loss = x.sum()
- loss.backward()
- optimizer.step()
- for i in range(num_layers):
- assert isinstance(lora_model[i], LoraLinear)
- assert torch.allclose(old_model[i].weight, lora_model[i].weight)
- assert torch.allclose(old_model[i].bias, lora_model[i].bias)
- assert not torch.allclose(
- old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A
- )
-
-
-@pytest.mark.parametrize("batch_size", [8])
-@pytest.mark.parametrize("seq_len", [128])
-@pytest.mark.parametrize(
- "models_maker",
- [
- lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
- lambda: (GPTActor(), GPTCritic(), GPTRM()),
- # HACK: skip llama due to long execution time
- # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
- lambda: (OPTActor(), OPTCritic(), OPTRM()),
- lambda: (ChatGLMActor(), None, None),
- ],
-)
-@torch.no_grad()
-def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int):
- actor_input = {
- "input_ids": torch.randint(0, 100, (batch_size, seq_len)),
- "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
- }
- critic_input = {
- "sequences": torch.randint(0, 100, (batch_size, seq_len)),
- "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
- }
- rm_input = {
- "sequences": torch.randint(0, 100, (batch_size, seq_len)),
- "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
- }
-
- actor, critic, rm = models_maker()
- if isinstance(actor, ChatGLMActor):
- actor = actor.float()
- tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
- chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1)
- actor_input = {
- "input_ids": torch.cat(
- (
- torch.randint(0, 100, (batch_size, seq_len // 2)),
- chatglm_special_token,
- torch.randint(0, 100, (batch_size, seq_len // 2 - 2)),
- ),
- dim=1,
- ),
- "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)),
- }
- assert isinstance(actor, Actor)
- get_base_model(actor)
- actor_output = actor(**actor_input)
- assert actor_output.logits.shape[:2] == (batch_size, seq_len)
-
- if critic:
- assert isinstance(critic, Critic)
- get_base_model(critic)
- critic_output = critic(**critic_input)
- assert critic_output.shape == (batch_size,)
-
- if rm:
- assert isinstance(rm, RewardModel)
- get_base_model(rm)
- rm_output = rm(**rm_input)
- assert rm_output.shape == (batch_size,)
-
-
-@pytest.mark.parametrize("batch_size", [16])
-@pytest.mark.parametrize("seq_len", [128])
-@pytest.mark.parametrize("num_labels", [100])
-def test_loss(batch_size: int, seq_len: int, num_labels: int):
- loss = GPTLMLoss()
- loss_input = {
- "logits": torch.randn(batch_size, seq_len, num_labels),
- "labels": torch.randint(0, num_labels, (batch_size, seq_len)),
- }
- loss(**loss_input)
-
- loss = PolicyLoss()
- loss_input = {
- "log_probs": torch.randn(
- batch_size,
- ),
- "old_log_probs": torch.randn(
- batch_size,
- ),
- "advantages": torch.randn(
- batch_size,
- ),
- }
- loss(**loss_input)
-
- loss = ValueLoss()
- loss_input = {
- "values": torch.randn(
- batch_size,
- ),
- "old_values": torch.randn(
- batch_size,
- ),
- "reward": torch.randn(
- batch_size,
- ),
- }
- loss(**loss_input)
-
- loss = LogSigLoss()
- loss_input = {
- "chosen_reward": torch.randn(
- batch_size,
- ),
- "reject_reward": torch.randn(
- batch_size,
- ),
- }
- loss(**loss_input)
-
- loss = LogExpLoss()
- loss_input = {
- "chosen_reward": torch.randn(
- batch_size,
- ),
- "reject_reward": torch.randn(
- batch_size,
- ),
- }
- loss(**loss_input)
-
-
-if __name__ == "__main__":
- generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50)
- test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs)
-
- test_utils()
-
- test_lora(lora_rank=2, num_dim=8, num_layers=2)
-
- test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)
-
- test_loss(batch_size=8, seq_len=128, num_labels=100)
diff --git a/applications/Chat/tests/test_train.sh b/applications/Chat/tests/test_train.sh
deleted file mode 100755
index 68fca7fbf8c0..000000000000
--- a/applications/Chat/tests/test_train.sh
+++ /dev/null
@@ -1,233 +0,0 @@
-#!/usr/bin/env bash
-
-set_n_least_used_CUDA_VISIBLE_DEVICES() {
- local n=${1:-"9999"}
- echo "GPU Memory Usage:"
- local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
- tail -n +2 |
- nl -v 0 |
- tee /dev/tty |
- sort -g -k 2 |
- awk '{print $1}' |
- head -n $n)
- export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
- echo "Now CUDA_VISIBLE_DEVICES is set to:"
- echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
-}
-
-set_n_least_used_CUDA_VISIBLE_DEVICES 4
-
-set -xu
-
-if [ -z "$SFT_DATASET" ]; then
- echo "Please set \$SFT_DATASET to the path to sft dataset."
- exit 1
-fi
-
-if [ -z "$PROMPT_DATASET" ]; then
- echo "Please set \$PROMPT_DATASET to the path to prompts csv."
- exit 1
-fi
-
-if [ -z "$PRETRAIN_DATASET" ]; then
- echo "Please set \$PRETRAIN_DATASET to the path to alpaca data."
- exit 1
-fi
-
-NUM_RETRY=3
-BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
-EXAMPLES_DIR=$BASE_DIR/examples
-MODELS_DIR=$BASE_DIR/examples/models_config
-MODELS=('gpt2' 'bloom' 'opt' 'llama')
-STRATEGIES=('ddp' 'colossalai_gemini' 'colossalai_zero2')
-
-
-export OMP_NUM_THREADS=8
-
-# install requirements
-pip install -r $EXAMPLES_DIR/requirements.txt
-
-python $EXAMPLES_DIR/download_model.py --model-dir $MODELS_DIR --config-only
-
-get_pretrain() {
- local model=$1
- if [[ $model == "gpt2" ]]; then
- echo "gpt2"
- elif [[ $model == "bloom" ]]; then
- echo "bigscience/bloom-560m"
- elif [[ $model == "opt" ]]; then
- echo "facebook/opt-350m"
- else
- echo "Unknown model $model"
- exit 1
- fi
-}
-
-random_choice() {
- local arr=("$@")
- local len=${#arr[@]}
- local idx=$((RANDOM % len))
- echo ${arr[$idx]}
-}
-
-echo "[Test]: testing sft ..."
-
-# FIXME: This is a hack to skip tests that are not working
-# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
-# - llama-*: These tests can be passed locally, skipped for long execution time
-# - *-gemini: Gemini plugin does not support `from_pretrained` yet
-SKIPPED_TESTS=(
- "gpt2-ddp"
- "llama-ddp"
- "llama-colossalai_gemini"
- "llama-colossalai_zero2"
-)
-
-GRAD_CKPTS=('' '--grad_checkpoint')
-for lora_rank in '0'; do
- for model in ${MODELS[@]}; do
- strategies=($(shuf -e "${STRATEGIES[@]}"))
- for strategy in ${strategies[@]}; do
- if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
- echo "[Test]: Skipped $model-$strategy-$lora_rank"
- continue
- elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then
- echo "[Test]: Skipped $model-$strategy"
- continue
- fi
- pretrain=$(get_pretrain $model)
- pretrain_model=""
- if [[ $lora_rank -gt 0 ]]; then
- pretrain_model="--pretrain $pretrain"
- fi
- grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
- for i in $(seq $NUM_RETRY); do
- echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
- torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_sft.py \
- $pretrain_model --tokenizer $MODELS_DIR/$model \
- --model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \
- --dataset $SFT_DATASET --max_datasets_size 8 \
- --max_epochs 1 --batch_size 1 --accumulation_steps 1 --lr 1e-8 \
- --save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
- passed=$?
- if [ $passed -eq 0 ]; then
- break
- fi
- done
- if [ $passed -ne 0 ]; then
- echo "[Test]: Failed $model-$strategy-$lora_rank"
- exit 1
- fi
- done
- done
-done
-
-echo "[Test]: testing reward model ..."
-
-# FIXME: This is a hack to skip tests that are not working
-# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
-# - llama-*: These tests can be passed locally, skipped for long execution time
-# - *-gemini: Gemini plugin does not support `from_pretrained` yet
-SKIPPED_TESTS=(
- "gpt2-ddp"
- "llama-ddp"
- "llama-colossalai_gemini"
- "llama-colossalai_zero2"
-)
-
-LOSS_FNS=('log_sig' 'log_exp')
-DATASETS=('Anthropic/hh-rlhf' 'Dahoas/rm-static')
-for lora_rank in '0'; do
- for model in ${MODELS[@]}; do
- strategies=($(shuf -e "${STRATEGIES[@]}"))
- for strategy in ${strategies[@]}; do
- if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
- echo "[Test]: Skipped $model-$strategy-$lora_rank"
- continue
- elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then
- echo "[Test]: Skipped $model-$strategy"
- continue
- fi
- pretrain=$(get_pretrain $model)
- pretrain_model=""
- if [[ $lora_rank -gt 0 ]]; then
- pretrain_model="--pretrain $pretrain"
- fi
- loss_fn=$(random_choice "${LOSS_FNS[@]}")
- dataset=$(random_choice "${DATASETS[@]}")
- subset=$(if [[ $dataset == "Dahoas/rm-static" ]]; then echo "None"; else echo "harmless-base"; fi)
- for i in $(seq $NUM_RETRY); do
- echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
- torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \
- $pretrain_model --tokenizer $MODELS_DIR/$model \
- --dataset $dataset --subset $subset --max_datasets_size 8 \
- --model $model --strategy $strategy --lora_rank $lora_rank \
- --loss_fn $loss_fn --batch_size 1 --lr 1e-8 \
- --save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
- passed=$?
- if [ $passed -eq 0 ]; then
- break
- fi
- done
- if [ $passed -ne 0 ]; then
- echo "[Test]: Failed to train reward model $model-$strategy-$lora_rank"
- exit 1
- fi
- done
- done
-done
-
-echo "[Test]: testing RLHF ..."
-
-# FIXME: This is a hack to skip tests that are not working
-# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
-# - llama-*: These tests can be passed locally, skipped for long execution time
-# - *-gemini: Gemini plugin does not support `from_pretrained` yet
-SKIPPED_TESTS=(
- "gpt2-ddp"
- "llama-ddp"
- "llama-colossalai_gemini"
- "llama-colossalai_zero2"
-)
-
-for model in ${MODELS[@]}; do
- for lora_rank in '0'; do
- strategies=($(shuf -e "${STRATEGIES[@]}"))
- for strategy in ${strategies[@]}; do
- if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
- echo "[Test]: Skipped $model-$strategy-$lora_rank"
- continue
- elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then
- echo "[Test]: Skipped $model-$strategy"
- continue
- fi
- rm_pretrain=$(get_pretrain $model)
- rm_pretrain_model=""
- if [[ $lora_rank -gt 0 ]]; then
- rm_pretrain_model="--rm_pretrain $rm_pretrain"
- fi
- for i in $(seq $NUM_RETRY); do
- echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
- torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \
- --prompt_dataset $PROMPT_DATASET --pretrain_dataset $PRETRAIN_DATASET --max_datasets_size 32 \
- --strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \
- --num_episodes 1 --num_collect_steps 1 --num_update_steps 1 --lr 1e-8 \
- --experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \
- --pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \
- $rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \
- --save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts
- passed=$?
- if [ $passed -eq 0 ]; then
- break
- fi
- done
- if [ $passed -ne 0 ]; then
- echo "[Test]: Failed to train RLHF $model-$strategy-$lora_rank"
- exit 1
- fi
- done
- rm -rf $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
- rm $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
- done
-done
-rm -rf $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts
diff --git a/applications/Chat/.gitignore b/applications/ColossalChat/.gitignore
old mode 100644
new mode 100755
similarity index 87%
rename from applications/Chat/.gitignore
rename to applications/ColossalChat/.gitignore
index 5fa068105e26..33950adc0bb5
--- a/applications/Chat/.gitignore
+++ b/applications/ColossalChat/.gitignore
@@ -143,6 +143,17 @@ docs/.build
*.pt
# wandb log
-example/wandb/
+examples/wandb/
+examples/logs/
+examples/output/
examples/awesome-chatgpt-prompts/
+temp/
+
+# ColossalChat
+applications/ColossalChat/logs
+applications/ColossalChat/models
+applications/ColossalChat/sft_data
+applications/ColossalChat/prompt_data
+applications/ColossalChat/preference_data
+applications/ColossalChat/temp
diff --git a/applications/Chat/LICENSE b/applications/ColossalChat/LICENSE
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/LICENSE
rename to applications/ColossalChat/LICENSE
diff --git a/applications/Chat/README.md b/applications/ColossalChat/README.md
old mode 100644
new mode 100755
similarity index 59%
rename from applications/Chat/README.md
rename to applications/ColossalChat/README.md
index 349c26aad746..769f0b3d072c
--- a/applications/Chat/README.md
+++ b/applications/ColossalChat/README.md
@@ -13,10 +13,10 @@
- [Install the environment](#install-the-environment)
- [Install the Transformers](#install-the-transformers)
- [How to use?](#how-to-use)
- - [Supervised datasets collection](#supervised-datasets-collection)
- - [RLHF Training Stage1 - Supervised instructs tuning](#RLHF-training-stage1---supervised-instructs-tuning)
- - [RLHF Training Stage2 - Training reward model](#RLHF-training-stage2---training-reward-model)
- - [RLHF Training Stage3 - Training model with reinforcement learning by human feedback](#RLHF-training-stage3---training-model-with-reinforcement-learning-by-human-feedback)
+ - [Supervised datasets collection](#step-1-data-collection)
+ - [RLHF Training Stage1 - Supervised instructs tuning](#rlhf-training-stage1---supervised-instructs-tuning)
+ - [RLHF Training Stage2 - Training reward model](#rlhf-training-stage2---training-reward-model)
+ - [RLHF Training Stage3 - Training model with reinforcement learning by human feedback](#rlhf-training-stage3---proximal-policy-optimization)
- [Inference Quantization and Serving - After Training](#inference-quantization-and-serving---after-training)
- [Coati7B examples](#coati7b-examples)
- [Generation](#generation)
@@ -36,7 +36,7 @@
---
-## What is ColossalChat and Coati ?
+## What Is ColossalChat and Coati?
[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) is the project to implement LLM with RLHF, powered by the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) project.
@@ -91,107 +91,191 @@ More details can be found in the latest news.
## Install
-### Install the environment
+### Install the Environment
```bash
-conda create -n coati
-conda activate coati
+# Create new environment
+conda create -n colossal-chat python=3.10.9  # any Python >= 3.8.7 works
+conda activate colossal-chat
+
+# Install flash-attention
+git clone -b v2.0.5 https://github.com/Dao-AILab/flash-attention.git
+cd $FLASH_ATTENTION_ROOT/
+pip install .
+cd $FLASH_ATTENTION_ROOT/csrc/xentropy
+pip install .
+cd $FLASH_ATTENTION_ROOT/csrc/layer_norm
+pip install .
+cd $FLASH_ATTENTION_ROOT/csrc/rotary
+pip install .
+
+# Clone Colossalai
git clone https://github.com/hpcaitech/ColossalAI.git
-cd ColossalAI/applications/Chat
+
+# Install ColossalAI
+cd $COLOSSAL_AI_ROOT
+BUILD_EXT=1 pip install .
+
+# Install ColossalChat
+cd $COLOSSAL_AI_ROOT/applications/Chat
pip install .
```
-### Install the Transformers
+## How To Use?
-```bash
-pip install transformers==4.30.2
-```
+### RLHF Training Stage1 - Supervised Instructs Tuning
-## How to use?
+Stage1 is supervised instructs fine-tuning (SFT). This step is a crucial part of the RLHF training process, as it involves training a machine learning model using human-provided instructions to learn the initial behavior for the task at hand. Here's a detailed guide on how to SFT your LLM with ColossalChat. More details can be found in [example guideline](./examples/README.md).
-### Supervised datasets collection
+#### Step 1: Data Collection
+The first step in Stage 1 is to collect a dataset of human demonstrations of the following format.
-We collected 104K bilingual datasets of Chinese and English, and you can find the datasets in this repo
-[InstructionWild](https://github.com/XueFuzhao/InstructionWild) and in this [file](https://github.com/XueFuzhao/InstructionWild/blob/main/data/README.md).
+```json
+[
+ {"messages":
+ [
+ {
+ "from": "human",
+ "content": "what are some pranks with a pen i can do?"
+ },
+ {
+ "from": "assistant",
+ "content": "Are you looking for practical joke ideas?"
+ },
+ ...
+ ]
+ },
+ ...
+]
+```
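+
+A few lines of Python can sanity-check that a collected file matches this schema (a minimal sketch; the file name and the checks are illustrative, not part of ColossalChat):
+
+```python
+import json
+
+# Load a collected SFT file (the file name is a placeholder).
+with open("sft_data.json") as f:
+    samples = json.load(f)
+
+# Verify every sample follows the "messages" schema shown above.
+for sample in samples:
+    for turn in sample["messages"]:
+        assert turn["from"] in ("human", "assistant")
+        assert isinstance(turn["content"], str)
+```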
-Here is how we collected the data
+#### Step 2: Preprocessing
+Once you have collected your SFT dataset, you will need to preprocess it. This involves four steps: data cleaning, data deduplication, formatting and tokenization. In this section, we will focus on formatting and tokenization.
-
-
-
+In this codebase, we provide a flexible way for users to set the conversation template for formatting chat data, using Hugging Face's chat template feature. Please follow the [example guideline](./examples/README.md) on how to format and tokenize data.
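+
+As a rough sketch of what formatting and tokenization can look like with a Hugging Face chat template (the model name and the "from" to "role" mapping below are illustrative assumptions, not ColossalChat's actual converter):
+
+```python
+from transformers import AutoTokenizer
+
+# Any tokenizer that ships a chat template works here; the name is a placeholder.
+tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
+
+sample = {"messages": [
+    {"from": "human", "content": "what are some pranks with a pen i can do?"},
+    {"from": "assistant", "content": "Are you looking for practical joke ideas?"},
+]}
+
+# Map the dataset's "from" field onto the "role" field expected by chat templates.
+role_map = {"human": "user", "assistant": "assistant"}
+conversation = [{"role": role_map[m["from"]], "content": m["content"]} for m in sample["messages"]]
+
+# Render the conversation with the tokenizer's template, then tokenize it.
+text = tokenizer.apply_chat_template(conversation, tokenize=False)
+input_ids = tokenizer(text, truncation=True, max_length=2048)["input_ids"]
+```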
-### RLHF Training Stage1 - Supervised instructs tuning
+#### Step 3: Training
+Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./examples/training_scripts/train_sft.sh) to start supervised instruction fine-tuning. More details can be found in the [example guideline](./examples/README.md).
-Stage1 is supervised instructs fine-tuning, which uses the datasets mentioned earlier to fine-tune the model.
+### RLHF Training Stage2 - Training Reward Model
-You can run the `examples/train_sft.sh` to start a supervised instructs fine-tuning.
-[[Stage1 tutorial video]](https://www.youtube.com/watch?v=-qFBZFmOJfg)
+Stage2 trains a reward model. Human annotators rank different outputs for the same prompt, and these rankings provide the supervision signal for training the reward model to assign scores.
-**Note**: the supervised dataset follows the following format,
+#### Step 1: Data Collection
+The preference dataset format used to train the reward model is shown below.
```json
[
- {
- "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
- "input": "",
- "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
- "id": 0
+ {"context": [
+ {
+ "from": "human",
+ "content": "Introduce butterflies species in Oregon."
+ }
+ ],
+ "chosen": [
+ {
+ "from": "assistant",
+ "content": "About 150 species of butterflies live in Oregon, with about 100 species are moths..."
+ },
+ ...
+ ],
+ "rejected": [
+ {
+ "from": "assistant",
+ "content": "Are you interested in just the common butterflies? There are a few common ones which will be easy to find..."
+ },
+ ...
+ ]
},
...
]
```
-### RLHF Training Stage2 - Training reward model
-
-Stage2 trains a reward model, which obtains corresponding scores by manually ranking different outputs for the same prompt and supervises the training of the reward model
+#### Step 2: Preprocessing
+Similar to the previous stage, we format the preference data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./examples/data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
-You can run the `examples/train_rm.sh` to start a reward model training.
-[[Stage2 tutorial video]](https://www.youtube.com/watch?v=gMx2CApKhuo)
+#### Step 3: Training
+You can run [train_rm.sh](./examples/training_scripts/train_rm.sh) to start the reward model training. More details can be found in [example guideline](./examples/README.md).
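+
+For intuition, reward model training typically minimizes a pairwise ranking loss over (chosen, rejected) reward pairs. Below is a minimal PyTorch sketch of the log-sigmoid variant (the same idea as the repo's `log_sig` loss option), shown for illustration rather than as the actual training code:
+
+```python
+import torch
+import torch.nn.functional as F
+
+def pairwise_ranking_loss(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
+    # -log(sigmoid(r_chosen - r_rejected)) pushes chosen scores above rejected ones.
+    return -F.logsigmoid(chosen_reward - reject_reward).mean()
+
+# Toy usage with random scores for a batch of 8 preference pairs.
+loss = pairwise_ranking_loss(torch.randn(8), torch.randn(8))
+```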
-### RLHF Training Stage3 - Training model with reinforcement learning by human feedback
+### RLHF Training Stage3 - Proximal Policy Optimization
-Stage3 uses reinforcement learning algorithm, which is the most complex part of the training process:
+In stage3 we use a reinforcement learning algorithm, Proximal Policy Optimization (PPO), which is the most complex part of the training process:
-You can run the `examples/train_prompts.sh` to start training PPO with human feedback.
-[[Stage3 tutorial video]](https://www.youtube.com/watch?v=Z8wwSHxPL9g)
-
-**Note**: the required datasets follow the following format,
-
-- `pretrain dataset`
-
- ```json
- [
- {
- "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
- "input": "",
- "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
- "id": 0
- },
- ...
- ]
- ```
-
-- `prompt dataset`
-
- ```json
- [
- {
- "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
- "id": 0
- },
- {
- "instruction": "Write a descriptive paragraph about a memorable vacation you went on",
- "id": 1
- },
- ...
- ]
- ```
-
-For more details, see [`examples/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples).
+#### Step 1: Data Collection
+PPO uses two kinds of training data: the prompt data and the SFT data (optional). The prompt dataset is mandatory; each of its samples ends with a line from "human", so the "assistant" needs to generate a response to answer the "human". Note that you can still use a conversation that ends with a line from the "assistant"; in that case, the last line will be dropped. Here is an example of the prompt dataset format.
+
+```json
+[
+ {"messages":
+ [
+ {
+ "from": "human",
+ "content": "what are some pranks with a pen i can do?"
+ }
+ ...
+ ]
+ },
+]
+```
+
+#### Step 2: Data Preprocessing
+To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./examples/data_preparation_scripts/prepare_prompt_dataset.sh).
+
+#### Step 3: Training
+You can run [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to start PPO training. Here are some arguments unique to PPO; please refer to the training configuration section for the other training configurations. More details can be found in the [example guideline](./examples/README.md).
+
+```bash
+--pretrain $PRETRAINED_MODEL_PATH \
+--rm_pretrain $PRETRAINED_MODEL_PATH \ # reward model architecture
+--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+--rm_checkpoint_path $REWARD_MODEL_PATH \ # reward model checkpoint path
+--prompt_dataset ${prompt_dataset[@]} \ # List of string, the prompt dataset
+--ptx_dataset ${ptx_dataset[@]} \ # List of string, the SFT data used in the SFT stage
+--ptx_batch_size 1 \ # batch size for calculating the ptx loss
+--ptx_coef 0.0 \ # non-zero if the ptx loss is enabled
+--num_episodes 2000 \ # number of episodes to train
+--num_collect_steps 1 \
+--num_update_steps 1 \
+--experience_batch_size 8 \
+--train_batch_size 4 \
+--accumulation_steps 2
+```
+
+Each episode has two phases, the collect phase and the update phase. During the collect phase, we collect experiences (answers generated by the actor) and store them in the ExperienceBuffer. The data in the ExperienceBuffer is then used during the update phase to update the parameters of the actor and the critic.
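+
+The following is a minimal, illustrative sketch of this loop; the function and variable names are hypothetical stand-ins, not the actual coati `PPOTrainer` internals:
+
+```python
+# Hypothetical sketch of the episode structure described above.
+def make_experience(prompts):  # stand-in for generation + reward scoring
+    return [{"prompt": p} for p in prompts]
+
+def ppo_update(batch):  # stand-in for one PPO gradient step on actor and critic
+    pass
+
+num_episodes, num_collect_steps, num_update_steps = 2000, 1, 1
+experience_batch_size, train_batch_size = 8, 4
+
+for episode in range(num_episodes):
+    buffer = []  # the ExperienceBuffer
+    # Collect phase: the actor answers prompts; experiences are buffered.
+    for _ in range(num_collect_steps):
+        prompts = [f"prompt-{i}" for i in range(experience_batch_size)]
+        buffer.extend(make_experience(prompts))
+    # Update phase: actor and critic are updated on the buffered experiences.
+    for _ in range(num_update_steps):
+        for i in range(0, len(buffer), train_batch_size):
+            ppo_update(buffer[i : i + train_batch_size])
+```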
+
+- Without tensor parallelism,
+```
+experience buffer size
+= num_process * num_collect_steps * experience_batch_size
+= train_batch_size * accumulation_steps * num_process
+```
+
+- With tensor parallelism,
+```
+num_tp_group = num_process / tp
+experience buffer size
+= num_tp_group * num_collect_steps * experience_batch_size
+= train_batch_size * accumulation_steps * num_tp_group
+```
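+
+For example, plugging the illustrative arguments above into the no-tensor-parallelism case on 4 processes (an assumed count), both expressions agree:
+
+```
+experience buffer size
+= 4 (processes) * 1 (collect step) * 8 (experience batch size) = 32
+= 4 (train batch size) * 2 (accumulation steps) * 4 (processes) = 32
+```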
+
+## Alternative Option For RLHF: Direct Preference Optimization
+
+For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in the paper available at [https://arxiv.org/abs/2305.18290](https://arxiv.org/abs/2305.18290), offers a low-cost way to perform RLHF and usually requires fewer computational resources than PPO.
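+
+At its core, DPO replaces the reward model and the PPO loop with a single classification-style loss on preference pairs. Below is a minimal sketch of that loss following the cited paper; the tensor names are illustrative and this is not the exact coati implementation:
+
+```python
+import torch
+import torch.nn.functional as F
+
+def dpo_loss(policy_chosen_logps, policy_rejected_logps,
+             ref_chosen_logps, ref_rejected_logps, beta=0.1):
+    """DPO loss from arXiv:2305.18290; inputs are per-sample summed log-probs."""
+    # Implicit rewards: log-ratio of the policy vs. the frozen reference model.
+    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
+    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
+    # Maximize the margin between the chosen and the rejected response.
+    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
+
+# Toy usage with random log-probabilities:
+logps = [torch.randn(4) for _ in range(4)]
+print(dpo_loss(*logps))
+```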
+
+### DPO Training Stage1 - Supervised Instructs Tuning
+
+Please refer to the [SFT section](#rlhf-training-stage1---supervised-instructs-tuning) in the PPO part.
+
+### DPO Training Stage2 - DPO Training
+#### Step 1: Data Collection & Preparation
+For DPO training, you only need the preference dataset. Please follow the instructions in the [preference dataset preparation section](#rlhf-training-stage2---training-reward-model) to prepare the preference data for DPO training.
+
+#### Step 2: Training
+You can run [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. More details can be found in the [example guideline](./examples/README.md).
### Inference Quantization and Serving - After Training
@@ -301,91 +385,60 @@ You can find more examples in this [repo](https://github.com/XueFuzhao/Instructi
We have integrated the Transformers save and load pipeline, allowing users to freely call Hugging Face's language models and save them in the HF format.
+- Option 1: Save the model weights, model config and generation config (note: the tokenizer will not be saved), which can be loaded using HF's `from_pretrained` method.
```python
-from coati.models.llama import LlamaLM
-from coati.trainer import SFTTrainer
-
-model = LlamaLM(pretrained=args.pretrain)
-tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
-
-(model, optim) = strategy.prepare((model, optim))
-trainer = SFTTrainer(model=model,
- strategy=strategy,
- optim=optim,
- train_dataloader=train_dataloader,
- eval_dataloader=eval_dataloader,
- batch_size=args.batch_size,
- max_epochs=args.max_epochs,
- accumulation_steps=args.accumulation_steps
- )
-
-trainer.fit()
-# this saves in pytorch format
-strategy.save_model(model, args.save_path, only_rank0=True)
-
-# this saves in HF format
-strategy.save_pretrained(model, args.save_path, only_rank0=True, tokenizer=tokenizer)
+# if LoRA is used, you can choose to merge the LoRA weights before saving
+if args.lora_rank > 0 and args.merge_lora_weights:
+ from coati.models.lora import LORA_MANAGER
+
+ # NOTE: set model to eval to merge LoRA weights
+ LORA_MANAGER.merge_weights = True
+ model.eval()
+# save model checkpoint after fitting on only rank0
+booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
+
```
+- Option 2: Save the model weights, model config and generation config, as well as the optimizer, learning rate scheduler and running states (note: the tokenizer will not be saved), which are needed for resuming training.
+```python
+from coati.utils import save_checkpoint
+# save model checkpoint after fitting on only rank0
+save_checkpoint(
+ save_dir=actor_save_dir,
+ booster=actor_booster,
+ model=model,
+ optimizer=optim,
+ lr_scheduler=lr_scheduler,
+ epoch=0,
+ step=step,
+ batch_size=train_batch_size,
+ coordinator=coordinator,
+ )
+```
+To load the saved checkpoint:
+```python
+from coati.utils import load_checkpoint
+start_epoch, start_step, sampler_start_idx = load_checkpoint(
+ load_dir=checkpoint_path,
+ booster=booster,
+ model=model,
+ optimizer=optim,
+ lr_scheduler=lr_scheduler,
+ )
+```
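+
+If you use the `StatefulDistributedSampler` shipped with coati, you can additionally skip the samples consumed before the checkpoint, mirroring what the training scripts in this PR do. In the sketch below, `train_dataloader` stands for your boosted dataloader, and `sampler_start_idx` comes from `load_checkpoint` above:
+
+```python
+from coati.dataset import StatefulDistributedSampler
+
+# Resume the dataloader from where the checkpoint left off.
+assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
+train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
+```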
How to train with limited resources
-Here are some examples that can allow you to train a 7B model on a single or multiple consumer-grade GPUs.
-
-If you only have a single 24G GPU, you can use the following script. `batch_size`, `lora_rank` and `grad_checkpoint` are the most important parameters to successfully train the model.
-
-```bash
-// [INFO]: MAX GPU MEMORY ALLOCATED: 19148.9345703125 MB
-torchrun --standalone --nproc_per_node=1 train_sft.py \
- --pretrain "/path/to/LLaMa-7B/" \
- --model 'llama' \
- --strategy ddp \
- --save_path /path/to/Coati-7B \
- --dataset /path/to/data.json \
- --batch_size 1 \
- --accumulation_steps 8 \
- --lr 2e-5 \
- --max_datasets_size 512 \
- --max_epochs 1 \
- --lora_rank 16 \
- --grad_checkpoint
-```
+Here are some suggestions to help you train a 7B model on one or more consumer-grade GPUs.
-`colossalai_gemini` strategy can enable a single 24G GPU to train the whole model without using LoRA if you have sufficient CPU memory. You can use the following script.
+`batch_size`, `lora_rank` and `grad_checkpoint` are the most important parameters for successfully training the model. To maintain a decent effective batch size for gradient calculation, consider increasing `accumulation_steps` and reducing the `batch_size` on each rank.
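+
+For example, the effective batch size stays the same when trading per-rank batch size for accumulation steps (4 ranks assumed):
+
+```
+effective batch size = batch_size per rank * accumulation_steps * number of ranks
+e.g. 1 * 8 * 4 = 32
+```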
-```bash
-torchrun --standalone --nproc_per_node=1 train_sft.py \
- --pretrain "/path/to/LLaMa-7B/" \
- --model 'llama' \
- --strategy colossalai_gemini \
- --save_path /path/to/Coati-7B \
- --dataset /path/to/data.json \
- --batch_size 1 \
- --accumulation_steps 8 \
- --lr 2e-5 \
- --max_datasets_size 512 \
- --max_epochs 1 \
- --grad_checkpoint
-```
+If you only have a single 24G GPU, using LoRA and the `zero2_cpu` plugin will generally be sufficient.
-If you have 4x32 GB GPUs, you can even train the whole 7B model using our `colossalai_zero2_cpu` strategy! The script is given as follows.
-
-```bash
-torchrun --standalone --nproc_per_node=4 train_sft.py \
- --pretrain "/path/to/LLaMa-7B/" \
- --model 'llama' \
- --strategy colossalai_zero2_cpu \
- --save_path /path/to/Coati-7B \
- --dataset /path/to/data.json \
- --batch_size 1 \
- --accumulation_steps 8 \
- --lr 2e-5 \
- --max_datasets_size 512 \
- --max_epochs 1 \
- --grad_checkpoint
-```
+`gemini` and `gemini_auto` can enable a single 24G GPU to train the whole model without using LoRA if you have sufficient CPU memory. However, these plugins don't support gradient accumulation.
+If you have multiple GPUs, each with very limited VRAM (say 8GB), you can try the `3d` plugin option, which supports tensor parallelism; set `--tp` to the number of GPUs that you have.
## The Plan
@@ -396,6 +449,8 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
- [x] support inference
- [x] support llama from [facebook](https://github.com/facebookresearch/llama)
- [x] implement PPO-ptx fine-tuning
+- [x] support flash-attention
+- [x] implement DPO fine-tuning
- [ ] integrate with Ray
- [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL),
- [ ] support chain-of-thought by [langchain](https://github.com/hwchase17/langchain)
@@ -467,6 +522,7 @@ Coati is developed by ColossalAI Team:
- [Fazzie](https://fazzie-key.cool/about/index.html) Contributing to the algorithm and development for SFT.
- [ofey404](https://github.com/ofey404) Contributing to both front-end and back-end development.
- [Wenhao Chen](https://github.com/CWHer) Contributing to subsequent code enhancements and performance improvements.
+- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored version with updated acceleration framework, LoRA, DPO and PPO.
The PhD student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
- [Zangwei Zheng](https://github.com/zhengzangw)
diff --git a/applications/ColossalChat/benchmarks/Opt.json b/applications/ColossalChat/benchmarks/Opt.json
new file mode 100644
index 000000000000..6d47666bb056
--- /dev/null
+++ b/applications/ColossalChat/benchmarks/Opt.json
@@ -0,0 +1,17 @@
+{
+ "chat_template": "{% for message in messages %}{% if message['role'] == 'user' %}{{'Human: ' + bos_token + message['content'].strip() + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + bos_token + message['content'].strip() + eos_token }}{% endif %}{% endfor %}",
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ "human_line_start": [
+ 2
+ ],
+ "human_line_end": [
+ 2
+ ],
+ "assistant_line_start": [
+ 2
+ ],
+ "assistant_line_end": [
+ 2
+ ],
+ "end_of_system_line_position": 0
+}
diff --git a/applications/ColossalChat/benchmarks/README.md b/applications/ColossalChat/benchmarks/README.md
new file mode 100755
index 000000000000..08c5e0e6c685
--- /dev/null
+++ b/applications/ColossalChat/benchmarks/README.md
@@ -0,0 +1,37 @@
+# Benchmarks
+
+## Benchmark OPT with LoRA on dummy prompt data
+
+We provide various OPT models (string in parentheses is the corresponding model name used in this script):
+
+- OPT-125M (125m)
+- OPT-350M (350m)
+- OPT-700M (700m)
+- OPT-1.3B (1.3b)
+- OPT-2.7B (2.7b)
+- OPT-3.5B (3.5b)
+- OPT-5.5B (5.5b)
+- OPT-6.7B (6.7b)
+- OPT-10B (10b)
+- OPT-13B (13b)
+
+We also provide various training strategies:
+
+- gemini: ColossalAI GeminiPlugin with `placement_policy="cuda"`, like zero3
+- gemini_auto: ColossalAI GeminiPlugin with `placement_policy="auto"`, like zero3-offload
+- zero2: ColossalAI zero2
+- zero2_cpu: ColossalAI zero2-offload
+- 3d: ColossalAI HybridParallelPlugin with TP, DP support
+
+## How to Run
+```bash
+cd ../tests
+# Prepare data for benchmark
+SFT_DATASET=/path/to/sft/data/ \
+PROMPT_DATASET=/path/to/prompt/data/ \
+PRETRAIN_DATASET=/path/to/ptx/data/ \
+PREFERENCE_DATASET=/path/to/preference/data \
+./test_data_preparation.sh
+# Start benchmark
+./benchmark_ppo.sh
+```
diff --git a/applications/ColossalChat/benchmarks/benchmark_memory_consumption.txt b/applications/ColossalChat/benchmarks/benchmark_memory_consumption.txt
new file mode 100644
index 000000000000..049285552d4f
--- /dev/null
+++ b/applications/ColossalChat/benchmarks/benchmark_memory_consumption.txt
@@ -0,0 +1,4 @@
+Model=Opt-125m; lora_rank=0; plugin=zero2
+Max CUDA memory usage: 26123.16 MB
+Model=Opt-125m; lora_rank=0; plugin=zero2
+Max CUDA memory usage: 26123.91 MB
diff --git a/applications/ColossalChat/benchmarks/benchmark_performance_summarization.txt b/applications/ColossalChat/benchmarks/benchmark_performance_summarization.txt
new file mode 100644
index 000000000000..b2a1ff1d77f2
--- /dev/null
+++ b/applications/ColossalChat/benchmarks/benchmark_performance_summarization.txt
@@ -0,0 +1,16 @@
+facebook/opt-125m; 0; zero2
+Performance summary:
+Generate 768 samples, throughput: 188.48 samples/s, TFLOPS per GPU: 361.23
+Train 768 samples, throughput: 448.38 samples/s, TFLOPS per GPU: 82.84
+Overall throughput: 118.42 samples/s
+Overall time per sample: 0.01 s
+Make experience time per sample: 0.01 s, 62.83%
+Learn time per sample: 0.00 s, 26.41%
+facebook/opt-125m; 0; zero2
+Performance summary:
+Generate 768 samples, throughput: 26.32 samples/s, TFLOPS per GPU: 50.45
+Train 768 samples, throughput: 71.15 samples/s, TFLOPS per GPU: 13.14
+Overall throughput: 18.86 samples/s
+Overall time per sample: 0.05 s
+Make experience time per sample: 0.04 s, 71.66%
+Learn time per sample: 0.01 s, 26.51%
diff --git a/applications/ColossalChat/benchmarks/benchmark_ppo.py b/applications/ColossalChat/benchmarks/benchmark_ppo.py
new file mode 100644
index 000000000000..e1b7a313f981
--- /dev/null
+++ b/applications/ColossalChat/benchmarks/benchmark_ppo.py
@@ -0,0 +1,523 @@
+"""
+For benchmarking PPO. Modified from examples/training_scripts/train_ppo.py
+"""
+
+import argparse
+import json
+import os
+import resource
+from contextlib import nullcontext
+
+import torch
+import torch.distributed as dist
+from coati.dataset import (
+ DataCollatorForPromptDataset,
+ DataCollatorForSupervisedDataset,
+ StatefulDistributedSampler,
+ load_tokenized_dataset,
+ setup_conversation_template,
+ setup_distributed_dataloader,
+)
+from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout
+from coati.trainer import PPOTrainer
+from coati.trainer.callbacks import PerformanceEvaluator
+from coati.trainer.utils import is_rank_0
+from coati.utils import load_checkpoint, replace_with_flash_attention
+from transformers import AutoTokenizer, OPTForCausalLM
+from transformers.models.opt.configuration_opt import OPTConfig
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
+from colossalai.logging import get_dist_logger
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
+
+# NOTE: `logger` is used in the tokenizer setup below; define it here to avoid a NameError.
+logger = get_dist_logger()
+
+
+def get_model_numel(model: torch.nn.Module, plugin: str, tp: int) -> int:
+ numel = sum(p.numel() for p in model.parameters())
+ if plugin == "3d" and tp > 1:
+ numel *= dist.get_world_size()
+ return numel
+
+
+def get_gpt_config(model_name: str) -> OPTConfig:
+ model_map = {
+ "125m": OPTConfig.from_pretrained("facebook/opt-125m"),
+ "350m": OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
+ "700m": OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
+ "1.3b": OPTConfig.from_pretrained("facebook/opt-1.3b"),
+ "2.7b": OPTConfig.from_pretrained("facebook/opt-2.7b"),
+ "3.5b": OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
+ "5.5b": OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
+ "6.7b": OPTConfig.from_pretrained("facebook/opt-6.7b"),
+ "10b": OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
+ "13b": OPTConfig.from_pretrained("facebook/opt-13b"),
+ }
+ try:
+ return model_map[model_name]
+ except KeyError:
+ raise ValueError(f'Unknown model "{model_name}"')
+
+
+def benchmark_train(args):
+ # ==============================
+ # Initialize Distributed Training
+ # ==============================
+ colossalai.launch_from_torch({})
+ coordinator = DistCoordinator()
+
+ # ======================================================
+ # Initialize Model, Objective, Optimizer and LR Scheduler
+ # ======================================================
+ init_ctx = LazyInitContext(default_device=get_current_device()) if "gemini" in args.plugin else nullcontext()
+
+ booster_policy = None
+ with init_ctx:
+ actor = OPTForCausalLM(config=get_gpt_config(args.pretrain))
+ # Disable dropout
+ disable_dropout(actor)
+ ref_model = OPTForCausalLM(config=get_gpt_config(args.pretrain))
+ reward_model = RewardModel(config=get_gpt_config("350m"))
+ critic = Critic(config=get_gpt_config("350m"))
+ disable_dropout(critic)
+
+ actor_numel = get_model_numel(actor, args.plugin, args.tp)
+ critic_numel = get_model_numel(critic, args.plugin, args.tp)
+ initial_model_numel = get_model_numel(ref_model, args.plugin, args.tp)
+ reward_model_numel = get_model_numel(reward_model, args.plugin, args.tp)
+
+ performance_evaluator = PerformanceEvaluator(
+ actor_numel,
+ critic_numel,
+ initial_model_numel,
+ reward_model_numel,
+ enable_grad_checkpoint=False,
+ ignore_episodes=2,
+ train_config={"model": "facebook/opt-" + args.pretrain, "lora_rank": args.lora_rank, "plugin": args.plugin},
+ save_path="./benchmark_performance_summarization.txt",
+ )
+
+ if args.tp > 1:
+ if reward_model.model.config.architectures[0] != critic.model.config.architectures[0]:
+ raise ValueError("Reward model and critic model must have the same architecture")
+ if reward_model.model.config.architectures[0] == "BloomForCausalLM":
+ from colossalai.shardformer.policies.bloom import BloomPolicy
+
+ booster_policy = BloomPolicy()
+ elif reward_model.model.config.architectures[0] == "LlamaForCausalLM":
+ from colossalai.shardformer.policies.llama import LlamaPolicy
+
+ booster_policy = LlamaPolicy()
+ elif reward_model.model.config.architectures[0] == "GPT2LMHeadModel":
+ from colossalai.shardformer.policies.gpt2 import GPT2Policy
+
+ booster_policy = GPT2Policy()
+ elif reward_model.model.config.architectures[0] == "ChatGLMModel":
+ from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
+
+ booster_policy = ChatGLMPolicy()
+ elif reward_model.model.config.architectures[0] == "OPTForCausalLM":
+ from colossalai.shardformer.policies.opt import OPTPolicy
+
+ booster_policy = OPTPolicy()
+ else:
+ raise ValueError("Unknown model architecture for policy")
+
+ if args.lora_rank > 0:
+ actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
+ critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
+
+ if args.grad_checkpoint and args.lora_rank == 0:
+ actor.gradient_checkpointing_enable()
+ critic.model.gradient_checkpointing_enable()
+ coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
+ elif args.lora_rank > 0:
+ coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
+
+ if args.use_flash_attn:
+ replace_with_flash_attention(model=actor)
+ replace_with_flash_attention(model=critic)
+ coordinator.print_on_master(msg="Flash-attention enabled successfully")
+
+ # configure tokenizer
+ tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
+ if os.path.exists(args.conversation_template_config):
+ conversation_template_config = json.load(open(args.conversation_template_config, "r", encoding="utf8"))
+ conversation_template = setup_conversation_template(
+ tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config
+ )
+ stop_token_ids = (
+ conversation_template.assistant_line_end if len(conversation_template.assistant_line_end) > 0 else None
+ )
+ else:
+ raise ValueError("Conversation template config is not provided or incorrect")
+ if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
+ try:
+            # Some tokenizers don't allow the pad_token to be set manually, e.g., Qwen
+ tokenizer.pad_token = tokenizer.eos_token
+ except AttributeError as e:
+ logger.warning(f"Unable to set pad token to eos token, {str(e)}")
+ if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
+        logger.warning(
+            "The tokenizer does not have a pad token, which is required; this may lead to unintended behavior during training. Please consider setting it manually."
+        )
+ tokenizer.add_bos_token = False
+ tokenizer.add_eos_token = False
+ tokenizer.padding_side = "left" # left padding for generation (online learning)
+
+ # configure generation config
+ actor.generation_config.update(
+ pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id
+ )
+
+ # configure optimizer
+ coordinator.print_on_master(f"setting up optimizer for actor: lr={args.lr}, weight_decay={args.weight_decay}")
+ actor_optim = HybridAdam(
+ model_params=actor.parameters(),
+ lr=args.lr,
+ betas=(0.9, 0.95),
+ weight_decay=args.weight_decay,
+ adamw_mode=True,
+ )
+
+ coordinator.print_on_master(f"setting up optimizer for critic: lr={args.lr}, weight_decay={args.weight_decay}")
+ critic_optim = HybridAdam(
+ model_params=critic.parameters(),
+ lr=args.critic_lr,
+ betas=(0.9, 0.95),
+ weight_decay=args.weight_decay,
+ adamw_mode=True,
+ )
+
+ # configure dataset
+ coordinator.print_on_master(f"Load dataset: {args.prompt_dataset}")
+ mode_map = {"train": "train", "valid": "validation", "test": "test"}
+ train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map)
+ coordinator.print_on_master(f"prompt dataset size: {len(train_prompt_dataset)}")
+ data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len)
+ train_prompt_dataloader = setup_distributed_dataloader(
+ dataset=train_prompt_dataset,
+ batch_size=args.experience_batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=data_collator,
+ use_tp=args.tp > 1,
+ )
+
+ if len(args.pretrain_dataset) > 0:
+ train_pretrain_dataset = load_tokenized_dataset(
+ dataset_paths=args.pretrain_dataset, mode="train", mode_map=mode_map
+ )
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
+ train_pretrain_dataloader = setup_distributed_dataloader(
+ dataset=train_pretrain_dataset,
+ batch_size=args.ptx_batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=data_collator,
+ use_tp=args.tp > 1,
+ )
+ else:
+ train_pretrain_dataloader = None
+
+ if args.warmup_steps is None:
+ args.warmup_steps = int(0.025 * args.num_episodes)
+ coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
+
+ actor_lr_scheduler = CosineAnnealingWarmupLR(
+ optimizer=actor_optim,
+ total_steps=args.num_episodes,
+ warmup_steps=args.warmup_steps,
+ eta_min=0.1 * args.lr,
+ )
+
+ critic_lr_scheduler = CosineAnnealingWarmupLR(
+ optimizer=critic_optim,
+ total_steps=args.num_episodes,
+ warmup_steps=args.warmup_steps,
+ eta_min=0.1 * args.lr,
+ )
+
+ # ==============================
+ # Initialize Booster
+ # ==============================
+ if args.plugin == "gemini":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "gemini_auto":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ placement_policy="auto",
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2_cpu":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ cpu_offload=True,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "3d":
+ plugin = HybridParallelPlugin(
+ tp_size=args.tp,
+ pp_size=1,
+ zero_stage=0,
+ precision=args.mixed_precision,
+ )
+ custom_plugin = HybridParallelPlugin(
+ tp_size=args.tp,
+ pp_size=1,
+ zero_stage=0,
+ precision=args.mixed_precision,
+ custom_policy=booster_policy,
+ )
+ else:
+ raise ValueError(f"Unknown plugin {args.plugin}")
+
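+    # For non-3d plugins, the reward and critic models reuse the same plugin as the
+    # actor; a custom shardformer policy is only needed when TP-sharding them.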
+ if args.plugin != "3d":
+ custom_plugin = plugin
+
+ actor_booster = Booster(plugin=plugin)
+ ref_booster = Booster(plugin=plugin)
+ rm_booster = Booster(plugin=custom_plugin)
+ critic_booster = Booster(plugin=custom_plugin)
+
+ default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
+ torch.set_default_dtype(default_dtype)
+ actor, actor_optim, _, train_prompt_dataloader, actor_lr_scheduler = actor_booster.boost(
+ model=actor,
+ optimizer=actor_optim,
+ lr_scheduler=actor_lr_scheduler,
+ dataloader=train_prompt_dataloader,
+ )
+
+ critic, critic_optim, _, _, critic_lr_scheduler = critic_booster.boost(
+ model=critic,
+ optimizer=critic_optim,
+ lr_scheduler=critic_lr_scheduler,
+ dataloader=train_prompt_dataloader,
+ )
+ reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
+ ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader)
+
+ torch.set_default_dtype(torch.float)
+
+ coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+ coordinator.print_on_master(
+ f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ sampler_start_idx = 0
+ start_step = 0
+
+ if args.rm_checkpoint_path is not None:
+ if "modeling" in args.rm_checkpoint_path:
+ rm_booster.load_model(reward_model, args.rm_checkpoint_path)
+ else:
+ _, _, _ = load_checkpoint(
+ load_dir=args.rm_checkpoint_path,
+ booster=rm_booster,
+ model=reward_model,
+ optimizer=None,
+ lr_scheduler=None,
+ )
+ coordinator.print_on_master(f"Loaded reward model checkpoint {args.rm_checkpoint_path}")
+
+ if args.checkpoint_path is not None:
+ if "modeling" in args.checkpoint_path:
+ actor_booster.load_model(actor, args.checkpoint_path)
+ ref_booster.load_model(ref_model, args.checkpoint_path)
+ coordinator.print_on_master(f"Loaded actor and reference model {args.checkpoint_path}")
+ else:
+ _, start_step, sampler_start_idx = load_checkpoint(
+ load_dir=args.checkpoint_path,
+ booster=actor_booster,
+ model=actor,
+ optimizer=actor_optim,
+ lr_scheduler=actor_lr_scheduler,
+ )
+ _, _, _ = load_checkpoint(
+ load_dir=args.checkpoint_path,
+ booster=ref_booster,
+ model=ref_model,
+ optimizer=critic_optim,
+ lr_scheduler=critic_lr_scheduler,
+ )
+ assert isinstance(train_prompt_dataloader.sampler, StatefulDistributedSampler)
+ train_prompt_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
+
+ coordinator.print_on_master(
+ f"Loaded actor and reference model checkpoint {args.checkpoint_path} at spisode {start_step}"
+ )
+ coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
+
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ if args.critic_checkpoint_path is not None:
+ if "modeling" in args.critic_checkpoint_path:
+ critic_booster.load_model(critic, args.critic_checkpoint_path)
+ else:
+ _, _, _ = load_checkpoint(
+ load_dir=args.critic_checkpoint_path,
+ booster=critic_booster,
+ model=critic,
+ optimizer=critic_optim,
+ lr_scheduler=critic_lr_scheduler,
+ )
+ coordinator.print_on_master(f"Loaded critic checkpoint {args.critic_checkpoint_path}")
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ # configure trainer
+ trainer = PPOTrainer(
+ actor_booster,
+ critic_booster,
+ actor,
+ critic,
+ reward_model,
+ ref_model,
+ actor_optim,
+ critic_optim,
+ actor_lr_scheduler,
+ critic_lr_scheduler,
+ tokenizer=tokenizer,
+ stop_token_ids=stop_token_ids,
+ kl_coef=args.kl_coef,
+ ptx_coef=args.ptx_coef,
+ train_batch_size=args.train_batch_size,
+ buffer_limit=args.num_collect_steps * args.experience_batch_size,
+ max_length=args.max_length,
+ max_new_tokens=args.max_seq_len,
+ use_cache=True,
+ do_sample=True,
+ temperature=0.7,
+ accumulation_steps=args.accumulation_steps,
+ save_dir=args.save_path,
+ save_interval=args.save_interval,
+ top_k=50,
+ use_tp=args.tp > 1,
+ offload_inference_models="gemini" not in args.plugin,
+ callbacks=[performance_evaluator],
+ coordinator=coordinator,
+ )
+
+ trainer.fit(
+ num_episodes=args.num_episodes,
+ num_collect_steps=args.num_collect_steps,
+ num_update_steps=args.num_update_steps,
+ prompt_dataloader=train_prompt_dataloader,
+ pretrain_dataloader=train_pretrain_dataloader,
+ log_dir=args.log_dir,
+ use_wandb=args.use_wandb,
+ )
+
+ if args.lora_rank > 0 and args.merge_lora_weights:
+ from coati.models.lora import LORA_MANAGER
+
+ # NOTE: set model to eval to merge LoRA weights
+ LORA_MANAGER.merge_weights = True
+ actor.eval()
+ critic.eval()
+ # save model checkpoint after fitting on only rank0
+ coordinator.print_on_master("Start saving final actor model checkpoint")
+ actor_booster.save_model(actor, os.path.join(trainer.actor_save_dir, "modeling"), shard=True)
+ coordinator.print_on_master(
+ f"Saved final actor model checkpoint at episodes {args.num_episodes} at folder {args.save_path}"
+ )
+ coordinator.print_on_master("Start saving final critic model checkpoint")
+ critic_booster.save_model(critic, os.path.join(trainer.critic_save_dir, "modeling"), shard=True)
+ coordinator.print_on_master(
+ f"Saved final critic model checkpoint at episodes {args.num_episodes} at folder {args.save_path}"
+ )
+ memory_consumption = torch.cuda.max_memory_allocated() / 1024**2
+ if is_rank_0():
+ with open("./benchmark_memory_consumption.txt", "a+") as f:
+ f.write(
+ f"Model=Opt-{args.pretrain}; lora_rank={args.lora_rank}; plugin={args.plugin}\nMax CUDA memory usage: {memory_consumption:.2f} MB\n"
+ )
+ coordinator.print_on_master(f"Max CUDA memory usage: {memory_consumption:.2f} MB")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--prompt_dataset", nargs="+", default=[])
+ parser.add_argument("--pretrain_dataset", nargs="+", default=[])
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="gemini",
+ choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
+ help="Choose which plugin to use",
+ )
+ parser.add_argument(
+ "--conversation_template_config",
+ type=str,
+ default=None,
+ help="Path \
+ to save conversation template config files.",
+ )
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
+ parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
+ parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
+ parser.add_argument("--tokenizer_dir", type=str, default=None)
+ parser.add_argument("--tp", type=int, default=1)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--checkpoint_path", type=str, default=None)
+ parser.add_argument("--critic_checkpoint_path", type=str, default=None)
+ parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path")
+ parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
+ parser.add_argument("--num_episodes", type=int, default=1)
+ parser.add_argument("--num_collect_steps", type=int, default=2)
+ parser.add_argument("--num_update_steps", type=int, default=5)
+ parser.add_argument("--save_interval", type=int, default=1000)
+ parser.add_argument("--train_batch_size", type=int, default=16)
+ parser.add_argument("--experience_batch_size", type=int, default=16)
+ parser.add_argument("--ptx_batch_size", type=int, default=1)
+ parser.add_argument("--lora_train_bias", type=str, default="none")
+ parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
+ parser.add_argument("--accumulation_steps", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument("--merge_lora_weights", type=bool, default=True)
+ parser.add_argument("--lr", type=float, default=9e-6)
+ parser.add_argument("--critic_lr", type=float, default=9e-6)
+ parser.add_argument("--kl_coef", type=float, default=0.1)
+ parser.add_argument("--ptx_coef", type=float, default=0.0)
+ parser.add_argument("--max_length", type=int, default=512)
+ parser.add_argument("--max_seq_len", type=int, default=256)
+ parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--use_wandb", default=False, action="store_true")
+ parser.add_argument("--grad_checkpoint", default=False, action="store_true")
+ parser.add_argument("--use_flash_attn", default=False, action="store_true")
+ args = parser.parse_args()
+ benchmark_train(args)
diff --git a/applications/ColossalChat/benchmarks/benchmark_ppo.sh b/applications/ColossalChat/benchmarks/benchmark_ppo.sh
new file mode 100755
index 000000000000..e88757659685
--- /dev/null
+++ b/applications/ColossalChat/benchmarks/benchmark_ppo.sh
@@ -0,0 +1,119 @@
+#!/usr/bin/env bash
+
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+set_n_least_used_CUDA_VISIBLE_DEVICES 8
+
+set -xu
+
+NUM_RETRY=3
+BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
+EXAMPLES_DIR=$BASE_DIR/examples
+TEMP_DIR=$BASE_DIR/temp
+MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
+MODELS_DIR=$TEMP_DIR/models_config
+# To benchmark different models, change the following line
+# MODELS=('125m' '350m' '700m' '1.3b' '2.7b' '3.5b' '5.5b' '6.7b' '10b' '13b')
+MODELS=('125m')
+# To benchmark different strategies, change the following line
+# PLUGINS=('zero2' 'zero2_cpu' '3d')
+PLUGINS=('zero2')
+LORA_RANK=('0')
+
+export OMP_NUM_THREADS=8
+
+rm -f ./benchmark_memory_consumption.txt
+rm -f ./benchmark_performance_summarization.txt
+
+# install requirements
+pip install -r $EXAMPLES_DIR/requirements.txt
+
+random_choice() {
+ local arr=("$@")
+ local len=${#arr[@]}
+ local idx=$((RANDOM % len))
+ echo ${arr[$idx]}
+}
+
+echo "[Test]: testing ppo ..."
+
+SKIPPED_TESTS=(
+)
+
+GRAD_CKPTS=('' '--grad_checkpoint')
+GRAD_CKPTS=('')
+for lora_rank in ${LORA_RANK[@]}; do
+ for model in ${MODELS[@]}; do
+ plugins=($(shuf -e "${PLUGINS[@]}"))
+ for plugin in ${plugins[@]}; do
+ if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
+ echo "[Test]: Skipped $model-$plugin-$lora_rank"
+ continue
+ elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
+ echo "[Test]: Skipped $model-$plugin"
+ continue
+ fi
+ pretrain=$model
+ tokenizer_dir="facebook/opt-125m"
+ grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
+ tp='1'
+ if [[ $plugin == "3d" ]]; then
+ tp='4'
+ fi
+ for i in $(seq $NUM_RETRY); do
+ echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
+ declare -a prompt_dataset=()
+ for split in $(seq -f "%05g" 0 9); do
+ prompt_dataset+=("$TEMP_DIR/benchmark/arrow/part-$split")
+ done
+ colossalai run --nproc_per_node 8 --master_port 28547 $BASE_DIR/benchmarks/benchmark_ppo.py \
+ --pretrain $pretrain \
+ --tokenizer_dir $tokenizer_dir \
+ --prompt_dataset ${prompt_dataset[@]} \
+ --ptx_coef 0 \
+ --save_path $MODEL_SAVE_PATH \
+ --conversation_template_config ./Opt.json \
+ --lora_rank $lora_rank \
+ --plugin $plugin \
+ --num_episodes 5 \
+ --num_collect_steps 1 \
+ --num_update_steps 1 \
+ --max_seq_len 128 \
+ --max_length 512 \
+ --experience_batch_size 32 \
+ --train_batch_size 32 \
+ --accumulation_steps 1 \
+ --lr 9e-6 \
+ --mixed_precision "bf16" \
+ --grad_clip 1.0 \
+ --use_flash_attn \
+                    --tp $tp \
+                    $grad_ckpt
+ passed=$?
+ if [ $passed -eq 0 ]; then
+ rm -rf $MODEL_SAVE_PATH/*
+ rm -rf $MODELS_DIR/*
+ break
+ fi
+ done
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed $model-$plugin-$lora_rank"
+ exit 1
+ fi
+ done
+ done
+done
diff --git a/applications/ColossalChat/benchmarks/data_preparation.sh b/applications/ColossalChat/benchmarks/data_preparation.sh
new file mode 100755
index 000000000000..ca2986be43d5
--- /dev/null
+++ b/applications/ColossalChat/benchmarks/data_preparation.sh
@@ -0,0 +1,16 @@
+SAVE_DIR=""
+
+
+BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
+EXAMPLES_DIR=$BASE_DIR/examples
+SAVE_DIR=$BASE_DIR/temp/benchmark
+
+rm -rf $SAVE_DIR
+
+python $EXAMPLES_DIR/data_preparation_scripts/prepare_prompt_dataset.py --data_input_dirs "/home/yeanbang/data/dataset/sft_data/alpaca/data_preprocessed/train" \
+ --conversation_template_config ./Opt.json \
+ --tokenizer_dir "facebook/opt-125m" \
+ --data_cache_dir $SAVE_DIR/cache \
+ --data_jsonl_output_dir $SAVE_DIR/jsonl \
+ --data_arrow_output_dir $SAVE_DIR/arrow \
+ --num_samples_per_datafile 30
diff --git a/applications/Chat/benchmarks/ray/1mmt_dummy.py b/applications/ColossalChat/benchmarks/ray/1mmt_dummy.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/benchmarks/ray/1mmt_dummy.py
rename to applications/ColossalChat/benchmarks/ray/1mmt_dummy.py
diff --git a/applications/Chat/benchmarks/ray/mmmt_dummy.py b/applications/ColossalChat/benchmarks/ray/mmmt_dummy.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/benchmarks/ray/mmmt_dummy.py
rename to applications/ColossalChat/benchmarks/ray/mmmt_dummy.py
diff --git a/applications/Chat/coati/__init__.py b/applications/ColossalChat/coati/__init__.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/__init__.py
rename to applications/ColossalChat/coati/__init__.py
diff --git a/applications/ColossalChat/coati/dataset/__init__.py b/applications/ColossalChat/coati/dataset/__init__.py
new file mode 100755
index 000000000000..e216c37e1c62
--- /dev/null
+++ b/applications/ColossalChat/coati/dataset/__init__.py
@@ -0,0 +1,26 @@
+from .conversation import Conversation, setup_conversation_template
+from .loader import (
+ DataCollatorForPreferenceDataset,
+ DataCollatorForPromptDataset,
+ DataCollatorForSupervisedDataset,
+ StatefulDistributedSampler,
+ load_tokenized_dataset,
+ setup_distributed_dataloader,
+)
+from .tokenization_utils import supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
+
+__all__ = [
+    "tokenize_prompt_dataset",
+    "DataCollatorForPromptDataset",
+    "DataCollatorForPreferenceDataset",
+    "DataCollatorForSupervisedDataset",
+    "StatefulDistributedSampler",
+    "load_tokenized_dataset",
+    "setup_distributed_dataloader",
+    "supervised_tokenize_sft",
+    "tokenize_rlhf",
+    "setup_conversation_template",
+    "Conversation",
+]
diff --git a/applications/ColossalChat/coati/dataset/conversation.py b/applications/ColossalChat/coati/dataset/conversation.py
new file mode 100755
index 000000000000..15a33be93966
--- /dev/null
+++ b/applications/ColossalChat/coati/dataset/conversation.py
@@ -0,0 +1,143 @@
+import dataclasses
+import json
+import os
+from typing import Any, Dict, List
+
+import torch.distributed as dist
+from transformers import AutoTokenizer, PreTrainedTokenizer
+
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+
+@dataclasses.dataclass
+class Conversation:
+ tokenizer: PreTrainedTokenizer
+ system_message: str
+ chat_template: str
+ stop_ids: List[int]
+
+ @classmethod
+ def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
+ """
+ Setup the conversation template from config
+ """
+ tokenizer.chat_template = config["chat_template"]
+ conv = cls(tokenizer, config["system_message"], config["chat_template"], config["stop_ids"])
+ conv.clear()
+ return conv
+
+ def clear(self):
+ self.messages = []
+
+ @classmethod
+ def get_conversation_template_keys(cls):
+ return ["system_message", "chat_template"]
+
+ def __str__(self):
+ return json.dumps(
+ {k: self.__dict__[k] for k in self.__dict__ if k not in ["tokenizer", "messages"]},
+ ensure_ascii=False,
+ indent=4,
+ )
+
+ def get_prompt(self, length: int = None, add_generation_prompt=False) -> Any:
+ """
+ Retrieves the prompt for the conversation.
+
+ Args:
+            length (int, optional): The number of messages to include in the prompt. Defaults to None (all messages).
+            add_generation_prompt (bool, optional): Whether to add the assistant line start token (for generation only). Defaults to False.
+
+        Returns:
+            str: The formatted prompt string.
+ """
+
+ if length is None:
+ length = len(self.messages)
+
+ assert length <= len(self.messages)
+ if self.system_message is not None:
+ messages = [{"role": "system", "content": self.system_message}] + self.messages[:length]
+ else:
+ messages = self.messages[:length]
+ prompt = self.tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=add_generation_prompt
+ )
+ return prompt
+
+ def save_prompt(self):
+ return self.get_prompt()
+
+ def append_message(self, role: str, message: str):
+ """
+ Append a message to the conversation.
+
+ Args:
+ role (str): The role of the message sender. Must be either 'user' or 'assistant'.
+ message (str): The content of the message.
+
+ Raises:
+ AssertionError: If the role is not 'user' or 'assistant'.
+ """
+ assert role in ["user", "assistant"]
+ self.messages.append({"role": role, "content": message})
+
+    def copy(self):
+        return Conversation(
+            tokenizer=self.tokenizer,
+            system_message=self.system_message,
+            chat_template=self.chat_template,
+            stop_ids=self.stop_ids,
+        )
+
+
+def setup_conversation_template(
+ tokenizer: PreTrainedTokenizer, chat_template_config: Dict = None, save_path: str = None
+) -> Conversation:
+ """
+ Setup the conversation template, if chat_template is given, will replace the default chat_template of the tokenizer
+ with it. Otherwise, the default chat_template will be used. If the tokenizer doesn't have a default chat_template,
+ raise error to remind the user to set it manually.
+
+ Args:
+        tokenizer: The tokenizer to use.
+        chat_template_config:
+            {
+                "system_message": str. The system message to use.
+                "chat_template": str. The chat template to use; it can be a chat template string, a Hugging Face model path or a local model path.
+                    If a Hugging Face model path or a local model path is given, the chat template will be loaded from that model's tokenizer's default chat template.
+                "stop_ids": List[int]. The token ids used to terminate generation. You need to provide this for PPO training and generation.
+ }
+ """
+ if any([s not in chat_template_config.keys() for s in Conversation.get_conversation_template_keys()]):
+ # Try to automatically set up conversation template, if fail, it throws an error that you need to do it manually
+ if "system_message" not in chat_template_config:
+ logger.warning("No system message is provided, will not use system message.")
+ if "chat_template" not in chat_template_config:
+ logger.warning("No chat_template is provided, will try to load it from the tokenizer.")
+            if tokenizer.chat_template is not None:
+                chat_template_config["chat_template"] = tokenizer.chat_template
+            else:
+                raise ValueError(
+                    "The tokenizer does not have a default chat template, please set `chat_template` in the config manually."
+                )
+ else:
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(chat_template_config["chat_template"])
+                if tokenizer.chat_template is not None:
+ chat_template_config["chat_template"] = tokenizer.chat_template
+ else:
+ raise ValueError(
+ f"Load a tokenizer from {chat_template_config['chat_template']}, which doesn't have a default chat template, please set it manually."
+ )
+ logger.warning(
+ f"chat_template is provided as a local model path or huggingface model path, loaded chat_template from \"{chat_template_config['chat_template']}\"."
+ )
+ except OSError:
+ pass
+ except ValueError as e:
+ raise ValueError(e)
+ if not dist.is_initialized() or dist.get_rank() == 0:
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ with open(save_path, "w", encoding="utf8") as f:
+ logger.info(f"Successfully generated a conversation tempalte config, save to {save_path}.")
+ json.dump(chat_template_config, f, indent=4, ensure_ascii=False)
+ return Conversation.from_config(tokenizer, chat_template_config)
diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py
new file mode 100755
index 000000000000..93cc1dab8d21
--- /dev/null
+++ b/applications/ColossalChat/coati/dataset/loader.py
@@ -0,0 +1,383 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Dataloader for sft, dpo, ppo
+"""
+
+import math
+import os
+import random
+from dataclasses import dataclass
+from typing import Callable, Dict, Iterator, List, Optional, Sequence, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from coati.dataset.utils import chuncate_sequence, pad_to_max_len
+from datasets import Dataset as HFDataset
+from datasets import dataset_dict, load_from_disk
+from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.utils.data import ConcatDataset, DataLoader, Dataset, DistributedSampler
+from transformers.tokenization_utils import PreTrainedTokenizer
+
+DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
+PathType = Union[str, os.PathLike]
+
+
+def load_tokenized_dataset(
+ dataset_paths: Union[PathType, List[PathType]], mode: str = "train", **kwargs
+) -> Optional[DatasetType]:
+ """
+ Load pre-tokenized dataset.
+ Each instance of dataset is a dictionary with
+ `{'input_ids': List[int], 'labels': List[int], sequence: str}` format.
+ """
+ mode_map = kwargs.get("mode_map", {"train": "train", "dev": "validation", "test": "test"})
+ assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"
+
+ if isinstance(dataset_paths, (str, os.PathLike)):
+ dataset_paths = [dataset_paths]
+
+ datasets = [] # `List[datasets.dataset_dict.Dataset]`
+ for ds_path in dataset_paths:
+ ds_path = os.path.abspath(ds_path)
+ assert os.path.exists(ds_path), f"Not existed file path {ds_path}"
+ ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False)
+ if isinstance(ds_dict, HFDataset):
+ datasets.append(ds_dict)
+ else:
+ if mode_map[mode] in ds_dict:
+ datasets.append(ds_dict[mode_map[mode]])
+ if len(datasets) == 0:
+ return None
+ if len(datasets) == 1:
+ return datasets.pop()
+ return ConcatDataset(datasets=datasets)
+
+
+@dataclass
+class DataCollatorForSupervisedDataset(object):
+ """
+ Collate instances for supervised dataset.
+ Each instance is a tokenized dictionary with fields
+ `input_ids`(List[int]), `labels`(List[int]) and `sequence`(str).
+ """
+
+ tokenizer: PreTrainedTokenizer
+ max_length: int = 4096
+ ignore_index: int = -100
+
+ def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
+ """
+
+ Args:
+ instances (`Sequence[Dict[str, List[int]]]`):
+ Mini-batch samples, each sample is stored in an individual dictionary.
+
+ Returns:
+ (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
+ `input_ids`: `torch.Tensor` of shape (bsz, max_len);
+ `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
+ `labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
+ """
+ assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
+ f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
+ f"but now `{self.tokenizer.pad_token_id}`"
+ )
+
+ # `List[torch.Tensor]`
+ batch_input_ids = [
+ torch.LongTensor(instance["input_ids"][: self.max_length])
+ if len(instance["input_ids"]) > self.max_length
+ else torch.LongTensor(instance["input_ids"])
+ for instance in instances
+ ]
+ batch_labels = [
+ torch.LongTensor(instance["labels"][: self.max_length])
+ if len(instance["labels"]) > self.max_length
+ else torch.LongTensor(instance["labels"])
+ for instance in instances
+ ]
+ if self.tokenizer.padding_side == "right":
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ sequences=batch_input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id,
+ ) # (bsz, max_len)
+ labels = torch.nn.utils.rnn.pad_sequence(
+ sequences=batch_labels,
+ batch_first=True,
+ padding_value=self.ignore_index,
+ ) # (bsz, max_len)
+ # pad to max
+ to_pad = self.max_length - input_ids.size(1)
+ input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
+ labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
+ elif self.tokenizer.padding_side == "left":
+ reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
+ reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
+ sequences=reversed_input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id,
+ ) # (bsz, max_len)
+ input_ids = torch.flip(reversed_input_ids, dims=(1,)) # (bsz, max_len)
+ reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels]
+ reversed_labels = torch.nn.utils.rnn.pad_sequence(
+ sequences=reversed_labels,
+ batch_first=True,
+ padding_value=self.ignore_index,
+ ) # (bsz, max_len)
+ labels = torch.flip(reversed_labels, dims=(1,)) # (bsz, max_len)
+ else:
+ raise RuntimeError(
+ f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, "
+ f"but now `{self.tokenizer.padding_side}`"
+ )
+
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
+
+ return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+
+
+@dataclass
+class DataCollatorForPromptDataset(DataCollatorForSupervisedDataset):
+ def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
+ """
+
+ Args:
+ instances (`Sequence[Dict[str, List[int]]]`):
+ Mini-batch samples, each sample is stored in an individual dictionary.
+
+ Returns:
+ (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
+ `input_ids`: `torch.Tensor` of shape (bsz, max_len);
+ `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
+ """
+ instances = [{"input_ids": ins["input_ids"], "labels": ins["input_ids"]} for ins in instances]
+ ret = super().__call__(instances=instances)
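+        # Left-pad up to the full max_length so every prompt ends at the same position for generation.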
+ input_ids = F.pad(
+ ret["input_ids"], (self.max_length - ret["input_ids"].size(1), 0), value=self.tokenizer.pad_token_id
+ )
+ attention_mask = F.pad(ret["attention_mask"], (self.max_length - ret["attention_mask"].size(1), 0), value=False)
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
+@dataclass
+class DataCollatorForPreferenceDataset(object):
+ """
+    Collate instances for a preference (chosen/rejected) dataset.
+    Each instance is a tokenized dictionary with fields
+    `chosen_input_ids`, `chosen_loss_mask`, `rejected_input_ids` and `rejected_loss_mask` (each `List[int]`).
+ """
+
+ tokenizer: PreTrainedTokenizer
+ max_length: int = 4096
+
+ def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
+ """
+
+ Args:
+ instances (`Sequence[Dict[str, List[int]]]`):
+ Mini-batch samples, each sample is stored in an individual dictionary.
+
+ Returns:
+            (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`s, each of shape (bsz, max_len):
+                `chosen_input_ids`, `chosen_attention_mask`, `chosen_loss_mask`,
+                `reject_input_ids`, `reject_attention_mask`, `reject_loss_mask`.
+ """
+ assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
+ f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
+ f"but now `{self.tokenizer.pad_token_id}`"
+ )
+
+ (
+ chosen_input_ids,
+ chosen_loss_mask, # [batch_size * seq_len]
+ reject_input_ids,
+ reject_loss_mask,
+ ) = (
+ chuncate_sequence([ins["chosen_input_ids"] for ins in instances], self.max_length, torch.int64),
+ chuncate_sequence([ins["chosen_loss_mask"] for ins in instances], self.max_length, torch.bool),
+ chuncate_sequence([ins["rejected_input_ids"] for ins in instances], self.max_length, torch.int64),
+ chuncate_sequence([ins["rejected_loss_mask"] for ins in instances], self.max_length, torch.bool),
+ )
+
+ padding_side = self.tokenizer.padding_side
+ chosen_attention_mask = [torch.ones_like(seq).bool() for seq in chosen_input_ids]
+ reject_attention_mask = [torch.ones_like(seq).bool() for seq in reject_input_ids]
+
+ (
+ chosen_input_ids,
+ chosen_attention_mask,
+ chosen_loss_mask,
+ reject_input_ids,
+ reject_attention_mask,
+ reject_loss_mask,
+ ) = (
+ pad_to_max_len(chosen_input_ids, self.max_length, self.tokenizer.pad_token_id, padding_side=padding_side),
+ pad_to_max_len(chosen_attention_mask, self.max_length, False, padding_side=padding_side),
+ pad_to_max_len(chosen_loss_mask, self.max_length, False, padding_side=padding_side),
+ pad_to_max_len(reject_input_ids, self.max_length, self.tokenizer.pad_token_id, padding_side=padding_side),
+ pad_to_max_len(reject_attention_mask, self.max_length, False, padding_side=padding_side),
+ pad_to_max_len(reject_loss_mask, self.max_length, False, padding_side=padding_side),
+ )
+
+ return dict(
+ chosen_input_ids=chosen_input_ids,
+ chosen_attention_mask=chosen_attention_mask,
+ chosen_loss_mask=chosen_loss_mask,
+ reject_input_ids=reject_input_ids,
+ reject_attention_mask=reject_attention_mask,
+ reject_loss_mask=reject_loss_mask,
+ )
+
+
+class StatefulDistributedSampler(DistributedSampler):
+ """
+ Stateful distributed sampler for multi-stage training.
+ """
+
+ def __init__(
+ self,
+ dataset: DatasetType,
+ num_replicas: Optional[int] = None,
+ rank: Optional[int] = None,
+ shuffle: bool = True,
+ seed: int = 0,
+ drop_last: bool = False,
+ use_tp: Optional[bool] = False,
+ ) -> None:
+ if not use_tp:
+ super().__init__(
+ dataset=dataset,
+ num_replicas=num_replicas,
+ rank=rank,
+ shuffle=shuffle,
+ seed=seed,
+ drop_last=drop_last,
+ )
+ else:
+ # adapted from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L62
+ # TODO: support tp_group>1. will fix it later
+ num_replicas = 1
+ if rank is None:
+ rank = dist.get_rank()
+ if rank < 0:
+ raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, 0]")
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.drop_last = drop_last
+ # If the dataset length is evenly divisible by # of replicas, then there
+ # is no need to drop any data, since the dataset will be split equally.
+ if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
+ # Split to nearest available length that is evenly divisible.
+ # This is to ensure each rank receives the same amount of data when
+ # using this Sampler.
+ self.num_samples = math.ceil(
+ (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
+ )
+ else:
+ self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
+ self.total_size = self.num_samples * self.num_replicas
+ self.shuffle = shuffle
+ self.seed = seed
+ self.start_index = 0
+ self.use_tp = use_tp
+
+ def __iter__(self) -> Iterator:
+ if self.use_tp:
+ # TODO Add support for tp_group not equal to 1
+ pass
+        # adapted from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L96
+ if self.shuffle:
+ # deterministically shuffle based on epoch and seed
+ g = torch.Generator()
+ g.manual_seed(self.seed + self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
+ else:
+ indices = list(range(len(self.dataset))) # type: ignore[arg-type]
+
+ if not self.drop_last:
+ # add extra samples to make it evenly divisible
+ padding_size = self.total_size - len(indices)
+ if padding_size <= len(indices):
+ indices += indices[:padding_size]
+ else:
+ indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
+ else:
+ # remove tail of data to make it evenly divisible.
+ indices = indices[: self.total_size]
+ assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[
+ : self.total_size : self.num_replicas
+ ] # num_replicas=tp_group=1, we only support tp_group==1 for now
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ else:
+ iterator = super().__iter__()
+ indices = list(iterator)
+ indices = indices[self.start_index :]
+ return iter(indices)
+
+ def __len__(self) -> int:
+ return self.num_samples - self.start_index
+
+ def set_start_index(self, start_index: int) -> None:
+ self.start_index = start_index
+
+
+def setup_distributed_dataloader(
+ dataset: DatasetType,
+ batch_size: int = 1,
+ shuffle: bool = False,
+ seed: int = 1024,
+ drop_last: bool = False,
+ pin_memory: bool = False,
+ num_workers: int = 0,
+ collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
+ process_group: Optional[ProcessGroup] = None,
+ use_tp: Optional[bool] = False,
+ **kwargs,
+) -> DataLoader:
+ """
+ Setup dataloader for distributed training.
+ """
+ _kwargs = kwargs.copy()
+ process_group = process_group or _get_default_group()
+ sampler = StatefulDistributedSampler(
+ dataset=dataset,
+ num_replicas=process_group.size() if not use_tp else 1,
+ rank=process_group.rank(),
+ shuffle=shuffle,
+ seed=seed,
+ drop_last=drop_last,
+ use_tp=use_tp,
+ )
+
+ # Deterministic dataloader
+ def seed_worker(worker_id: int) -> None:
+ worker_seed = seed
+ np.random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
+ random.seed(worker_seed)
+
+ return DataLoader(
+ dataset=dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ num_workers=num_workers,
+ collate_fn=collate_fn,
+ pin_memory=pin_memory,
+ drop_last=drop_last,
+ worker_init_fn=seed_worker,
+ **_kwargs,
+ )
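The sampler above exists so that an interrupted run can resume mid-epoch: `set_start_index` tells it how many samples each rank has already consumed. A minimal usage sketch (the `ToyDataset` and the `coati.dataset.loader` import path are illustrative assumptions; it also assumes `torch.distributed` is already initialized, e.g. via `torchrun`):

```python
# Minimal resume sketch; ToyDataset and the import path are assumptions.
import torch
from torch.utils.data import Dataset

from coati.dataset.loader import setup_distributed_dataloader  # path assumed


class ToyDataset(Dataset):
    """Hypothetical dataset used only for illustration."""

    def __len__(self) -> int:
        return 100

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.tensor(idx)


dataloader = setup_distributed_dataloader(ToyDataset(), batch_size=4, shuffle=True)

# Skip the per-rank samples a previous run already consumed before it stopped.
consumed_batches = 10
dataloader.sampler.set_start_index(consumed_batches * 4)
for batch in dataloader:
    ...  # training step
```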
diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
new file mode 100755
index 000000000000..7606bc2a97ba
--- /dev/null
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -0,0 +1,383 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Tokenization utilities for constructing datasets for PPO, DPO, SFT, and RM training.
+"""
+
+import warnings
+from copy import deepcopy
+from typing import Any, Dict, List, Union
+
+from coati.dataset.conversation import Conversation
+from coati.dataset.utils import split_templated_prompt_into_chunks, tokenize_and_concatenate
+from datasets import dataset_dict
+from torch.utils.data import ConcatDataset, Dataset
+from transformers import PreTrainedTokenizer
+
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+IGNORE_INDEX = -100
+
+DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
+
+
+def supervised_tokenize_sft(
+ data_point: Dict[str, str],
+ tokenizer: PreTrainedTokenizer,
+ conversation_template: Conversation = None,
+ ignore_index: int = None,
+ max_length: int = 4096,
+) -> Dict[str, Union[int, str, List[int]]]:
+ """
+ A tokenization function that tokenizes an original pretraining data point as follows
+ and computes the corresponding labels for SFT training:
+ "Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line end]Something here"
+ ^
+ end_of_system_line_position
+
+ Args:
+ data_point: the data point of the following format
+ {"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
+ tokenizer: the tokenizer to use for tokenization
+ conversation_template: the conversation template to apply
+ ignore_index: the label index to ignore when calculating the loss during training
+ max_length: the maximum context length
+ """
+
+ if ignore_index is None:
+ ignore_index = IGNORE_INDEX
+
+ messages = data_point["messages"]
+ template = deepcopy(conversation_template)
+ template.messages = []
+
+ for mess in messages:
+ from_str = mess["from"]
+ if from_str.lower() == "human":
+ from_str = "user"
+ elif from_str.lower() == "assistant":
+ from_str = "assistant"
+ else:
+ raise ValueError(f"Unsupported role {from_str.lower()}")
+
+ template.append_message(from_str, mess["content"])
+
+ if len(template.messages) % 2 != 0:
+ template.messages = template.messages[0:-1]
+
+ # `target_turn_index` is the smallest number of turns whose tokenized prompt exceeds `max_length - 1` (found via binary search below).
+ turns = [i for i in range(1, len(messages) // 2 + 1)]
+
+ lo, hi = 0, len(turns)
+ while lo < hi:
+ mid = (lo + hi) // 2
+ if max_length - 1 < len(
+ tokenizer([template.get_prompt(2 * turns[mid] - 1)], add_special_tokens=False)["input_ids"][0]
+ ):
+ hi = mid
+ else:
+ lo = mid + 1
+ target_turn_index = lo
+
+ # The tokenized length of the first turn already exceeds `max_length - 1`.
+ if target_turn_index - 1 < 0:
+ warnings.warn("The tokenized length of the first turn already exceeds `max_length - 1`.")
+ return dict(
+ input_ids=None,
+ labels=None,
+ inputs_decode=None,
+ labels_decode=None,
+ seq_length=None,
+ seq_category=None,
+ )
+
+ target_turn = turns[target_turn_index - 1]
+ prompt = template.get_prompt(2 * target_turn)
+ chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt)
+ tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
+
+ labels = [ignore_index] * len(tokenized)
+ label_decode = []
+ for start, end in zip(starts, ends):
+ if end == len(tokenized):
+ tokenized = tokenized + [tokenizer.eos_token_id]
+ labels = labels + [ignore_index]
+ labels[start : end + 1] = tokenized[start : end + 1]
+ label_decode.append(tokenizer.decode(tokenized[start : end + 1], skip_special_tokens=False))
+
+ if tokenizer.bos_token_id is not None:
+ if tokenized[0] != tokenizer.bos_token_id:
+ tokenized = [tokenizer.bos_token_id] + tokenized
+ labels = [ignore_index] + labels
+
+ if tokenizer.eos_token_id is not None:
+ # Force to add eos token at the end of the tokenized sequence
+ if tokenized[-1] != tokenizer.eos_token_id:
+ tokenized = tokenized + [tokenizer.eos_token_id]
+ labels = labels + [tokenizer.eos_token_id]
+ else:
+ labels[-1] = tokenizer.eos_token_id
+
+ # For some model without bos/eos may raise the following errors
+ try:
+ inputs_decode = tokenizer.decode(tokenized)
+ except TypeError as e:
+ raise TypeError(str(e) + f"\nUnable to decode input_ids: {tokenized}")
+
+ # Check whether all labels are ignored; this can happen when the tokenized length is too long
+ if labels.count(ignore_index) == len(labels):
+ return dict(
+ input_ids=None,
+ labels=None,
+ inputs_decode=None,
+ labels_decode=None,
+ seq_length=None,
+ seq_category=None,
+ )
+
+ return dict(
+ input_ids=tokenized,
+ labels=labels,
+ inputs_decode=inputs_decode,
+ labels_decode=label_decode,
+ seq_length=len(tokenized),
+ seq_category=data_point["category"] if "category" in data_point else "None",
+ )
+
+
+def tokenize_prompt_dataset(
+ data_point: Dict[str, str],
+ tokenizer: PreTrainedTokenizer,
+ conversation_template: Conversation = None,
+ ignore_index: int = None,
+ max_length: int = 4096,
+) -> Dict[str, Union[int, str, List[int]]]:
+ """
+ A tokenization function that tokenizes an original pretraining data point as follows for PPO training:
+ "Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line start]"
+ Args:
+ data_point: the data point of the following format
+ {"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
+ tokenizer: the tokenizer to use for tokenization
+ conversation_template: the conversation template to apply
+ ignore_index: the label index to ignore when calculating the loss during training
+ max_length: the maximum context length
+ """
+ if ignore_index is None:
+ ignore_index = IGNORE_INDEX
+
+ messages = data_point["messages"]
+ template = deepcopy(conversation_template)
+ template.messages = []
+
+ for mess in messages:
+ from_str = mess["from"]
+ if from_str.lower() == "human":
+ from_str = "user"
+ elif from_str.lower() == "assistant":
+ from_str = "assistant"
+ else:
+ raise ValueError(f"Unsupported role {from_str.lower()}")
+
+ template.append_message(from_str, mess["content"])
+
+ # Keep only the prompt: if an assistant answer is included at the end, drop it.
+ target_turn = len(template.messages)
+ if target_turn % 2 != 1:
+ # Exclude the answer if provided; keep only the prompt
+ target_turn = target_turn - 1
+
+ # Prepare data
+ prompt = template.get_prompt(target_turn, add_generation_prompt=True)
+ tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
+ if tokenizer.bos_token_id is not None:
+ if tokenized[0] != tokenizer.bos_token_id:
+ tokenized = [tokenizer.bos_token_id] + tokenized
+
+ # Skip overlength data
+ if max_length - 1 < len(tokenized):
+ return dict(
+ input_ids=None,
+ inputs_decode=None,
+ seq_length=None,
+ seq_category=None,
+ )
+
+ # `inputs_decode` can be used to check whether the tokenization is correct.
+ return dict(
+ input_ids=tokenized,
+ inputs_decode=tokenizer.decode(tokenized),
+ seq_length=len(tokenized),
+ seq_category=data_point["category"] if "category" in data_point else "None",
+ )
+
+
+def apply_rlhf_data_format(
+ template: Conversation, tokenizer: Any, context_len: int, mask_out_target_assistant_line_end=False
+):
+ target_turn = int(len(template.messages) / 2)
+ prompt = template.get_prompt(target_turn * 2)
+ chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt)
+ tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
+ loss_mask = [0] * len(tokenized)
+ mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
+ if mask_token is None:
+ mask_token = 1 # If the tokenizer doesn't have eos_token or pad_token: Qwen
+
+ label_decode = []
+ for start, end in zip(starts[-1:], ends[-1:]):
+ # only the last round (chosen/rejected) counts
+ if end == len(tokenized):
+ tokenized = tokenized + [tokenizer.eos_token_id]
+ loss_mask = loss_mask + [1]
+ loss_mask[start : end + 1] = [1] * len(loss_mask[start : end + 1])
+ label_decode.append(tokenizer.decode(tokenized[start : end + 1], skip_special_tokens=False))
+ if tokenizer.bos_token_id is not None:
+ if tokenized[0] != tokenizer.bos_token_id:
+ tokenized = [tokenizer.bos_token_id] + tokenized
+ loss_mask = [0] + loss_mask
+
+ if tokenizer.eos_token_id is not None:
+ # Force to add eos token at the end of the tokenized sequence
+ if tokenized[-1] != tokenizer.eos_token_id:
+ tokenized = tokenized + [tokenizer.eos_token_id]
+ loss_mask = loss_mask + [1]
+ else:
+ loss_mask[-1] = 1
+
+ return {"input_ids": tokenized, "loss_mask": loss_mask, "label_decode": label_decode}
+
+
+def tokenize_rlhf(
+ data_point: Dict[str, str],
+ tokenizer: PreTrainedTokenizer,
+ conversation_template: Conversation = None,
+ ignore_index: int = None,
+ max_length: int = 4096,
+) -> Dict[str, Union[int, str, List[int]]]:
+ """
+ A tokenization function to tokenize an original pretraining data point as follows:
+ {"context": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
+ "chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}}
+ """
+ if ignore_index is None:
+ ignore_index = IGNORE_INDEX
+
+ context = data_point["context"]
+ template = deepcopy(conversation_template)
+ template.clear()
+
+ for mess in context:
+ from_str = mess["from"]
+ if from_str.lower() == "human":
+ from_str = "user"
+ elif from_str.lower() == "assistant":
+ from_str = "assistant"
+ else:
+ raise ValueError(f"Unsupported role {from_str.lower()}")
+
+ if len(template.messages) > 0 and from_str == template.messages[-1]["role"]:
+ # Concatenate adjacent messages from the same role
+ template.messages[-1]["content"] = str(template.messages[-1]["content"] + " " + mess["content"])
+ else:
+ template.append_message(from_str, mess["content"])
+
+ if len(template.messages) % 2 != 1:
+ warnings.warn(
+ "Please make sure leading context starts and ends with a line from human\nLeading context: "
+ + str(template.messages)
+ )
+ return dict(
+ chosen_input_ids=None,
+ chosen_loss_mask=None,
+ chosen_label_decode=None,
+ rejected_input_ids=None,
+ rejected_loss_mask=None,
+ rejected_label_decode=None,
+ )
+ round_of_context = int((len(template.messages) - 1) / 2)
+
+ assert context[-1]["from"].lower() == "human", "The last message in context should be from human."
+ chosen = deepcopy(template)
+ rejected = deepcopy(template)
+
+ for round in range(len(data_point["chosen"])):
+ from_str = data_point["chosen"][round]["from"]
+ if from_str.lower() == "human":
+ from_str = "user"
+ elif from_str.lower() == "assistant":
+ from_str = "assistant"
+ else:
+ raise ValueError(f"Unsupported role {from_str.lower()}")
+ chosen.append_message(from_str, data_point["chosen"][round]["content"])
+
+ for round in range(len(data_point["rejected"])):
+ from_str = data_point["rejected"][round]["from"]
+ if from_str.lower() == "human":
+ from_str = "user"
+ elif from_str.lower() == "assistant":
+ from_str = "assistant"
+ else:
+ raise ValueError(f"Unsupported role {from_str.lower()}")
+ rejected.append_message(from_str, data_point["rejected"][round]["content"])
+
+ (
+ chosen_input_ids,
+ chosen_loss_mask,
+ chosen_label_decode,
+ rejected_input_ids,
+ rejected_loss_mask,
+ rejected_label_decode,
+ ) = (None, None, None, None, None, None)
+ if (
+ len(tokenizer([chosen.get_prompt(len(chosen.messages))], add_special_tokens=False)["input_ids"][0])
+ <= max_length - 1
+ and len(tokenizer([rejected.get_prompt(len(rejected.messages))], add_special_tokens=False)["input_ids"][0])
+ <= max_length - 1
+ ):
+ chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
+ (chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
+ chosen_data_packed["input_ids"],
+ chosen_data_packed["loss_mask"],
+ chosen_data_packed["label_decode"],
+ )
+
+ rejected_data_packed = apply_rlhf_data_format(
+ rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
+ )
+ (rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
+ rejected_data_packed["input_ids"],
+ rejected_data_packed["loss_mask"],
+ rejected_data_packed["label_decode"],
+ )
+
+ # Check whether the loss mask is all zeros (no loss); this can happen when the tokenized length is too long
+ if chosen_loss_mask.count(0) == len(chosen_loss_mask) or rejected_loss_mask.count(0) == len(rejected_loss_mask):
+ return dict(
+ chosen_input_ids=None,
+ chosen_loss_mask=None,
+ chosen_label_decode=None,
+ rejected_input_ids=None,
+ rejected_loss_mask=None,
+ rejected_label_decode=None,
+ )
+
+ return {
+ "chosen_input_ids": chosen_input_ids,
+ "chosen_loss_mask": chosen_loss_mask,
+ "chosen_label_decode": chosen_label_decode,
+ "rejected_input_ids": rejected_input_ids,
+ "rejected_loss_mask": rejected_loss_mask,
+ "rejected_label_decode": rejected_label_decode,
+ }
+ else:
+ return dict(
+ chosen_input_ids=None,
+ chosen_loss_mask=None,
+ chosen_label_decode=None,
+ rejected_input_ids=None,
+ rejected_loss_mask=None,
+ rejected_label_decode=None,
+ )
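All three tokenization entry points above take a raw message-format data point plus a `Conversation` template and return tensor-ready fields, signalling "drop this sample" by returning `None` values. A minimal sketch of mapping `supervised_tokenize_sft` over a dataset (constructing the `Conversation` template itself is repository-specific and not shown here, so it is passed in as an argument):

```python
# Minimal sketch; assumes a Conversation template built elsewhere per the repo's setup.
from typing import Dict, List

from coati.dataset.conversation import Conversation
from coati.dataset.tokenization_utils import supervised_tokenize_sft
from transformers import PreTrainedTokenizer


def tokenize_sft_dataset(
    data_points: List[Dict], tokenizer: PreTrainedTokenizer, template: Conversation
) -> List[Dict]:
    """Tokenize SFT data points, dropping overlength or fully masked samples."""
    results = []
    for dp in data_points:
        # dp looks like:
        # {"messages": [{"from": "human", "content": "..."},
        #               {"from": "assistant", "content": "..."}]}
        out = supervised_tokenize_sft(
            dp, tokenizer, conversation_template=template, max_length=4096
        )
        if out["input_ids"] is not None:  # None marks a skipped sample
            results.append(out)
    return results
```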
diff --git a/applications/ColossalChat/coati/dataset/utils.py b/applications/ColossalChat/coati/dataset/utils.py
new file mode 100755
index 000000000000..ada2afef0154
--- /dev/null
+++ b/applications/ColossalChat/coati/dataset/utils.py
@@ -0,0 +1,138 @@
+import io
+import json
+from typing import Any, Dict, List
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from transformers import PreTrainedTokenizer
+
+
+def is_rank_0() -> bool:
+ return not dist.is_initialized() or dist.get_rank() == 0
+
+
+def _make_r_io_base(f, mode: str):
+ if not isinstance(f, io.IOBase):
+ f = open(f, mode=mode)
+ return f
+
+
+def jload(f, mode="r"):
+ """Load a .json file into a dictionary."""
+ f = _make_r_io_base(f, mode)
+ jdict = json.load(f)
+ f.close()
+ return jdict
+
+
+def read_string_by_schema(data: Dict[str, Any], schema: str) -> str:
+ """
+ Read a field of the dataset by schema
+ Args:
+ data: Dict[str, Any]
+ schema: cascaded field names separated by '.', e.g. person.name.first will access data['person']['name']['first']
+ """
+ keys = schema.split(".")
+ result = data
+ for key in keys:
+ result = result.get(key, None)
+ if result is None:
+ return ""
+ assert isinstance(result, str), f"dataset element is not a string: {result}"
+ return result
+
+
+def pad_to_max_len(
+ sequence: List[torch.Tensor], max_length: int, padding_value: int, batch_first: bool = True, padding_side="left"
+):
+ """
+ Args:
+ sequence: a list of 1-D tensors to pad and stack into a batch of shape [batch_size, max_length]
+ """
+ if padding_side == "left":
+ reversed_sequence = [seq.flip(dims=(0,)) for seq in sequence]
+ padded = torch.nn.utils.rnn.pad_sequence(
+ sequences=reversed_sequence, batch_first=batch_first, padding_value=padding_value
+ )
+ to_pad = max_length - padded.size(1)
+ padded = F.pad(padded, (0, to_pad), value=padding_value)
+ return torch.flip(padded, dims=(1,))
+ elif padding_side == "right":
+ padded = torch.nn.utils.rnn.pad_sequence(
+ sequences=sequence, batch_first=batch_first, padding_value=padding_value
+ )
+ to_pad = max_length - padded.size(1)
+ return F.pad(padded, (0, to_pad), value=padding_value)
+ else:
+ raise RuntimeError(f"`padding_side` can only be `left` or `right`, but got `{padding_side}`")
+
+
+def chuncate_sequence(sequence: List[torch.Tensor], max_length: int, dtype: Any):
+ """
+ Args:
+ sequence: a list of sequences to truncate to at most `max_length` tokens each
+ """
+ return [
+ torch.Tensor(seq[:max_length]).to(dtype) if len(seq) > max_length else torch.Tensor(seq).to(dtype)
+ for seq in sequence
+ ]
+
+
+def find_first_occurrence_subsequence(seq: torch.Tensor, subseq: torch.Tensor, start_index: int = 0) -> int:
+ if subseq is None:
+ return 0
+ for i in range(start_index, len(seq) - len(subseq) + 1):
+ if torch.all(seq[i : i + len(subseq)] == subseq):
+ return i
+ return -1
+
+
+def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], require_loss: List[bool]):
+ """
+ Tokenizes a list of texts using the provided tokenizer and concatenates the tokenized outputs.
+
+ Args:
+ tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenization.
+ text (List[str]): The list of texts to tokenize.
+ require_loss (List[bool]): A list of boolean values indicating whether each text requires loss calculation.
+
+ Returns:
+ Tuple[List[int], List[int], List[int]]: A tuple containing the concatenated tokenized input ids,
+ the start positions of loss spans, and the end positions of loss spans.
+ """
+ input_ids = []
+ loss_starts = []
+ loss_ends = []
+ for s, r in zip(text, require_loss):
+ tokenized = tokenizer(s, add_special_tokens=False)["input_ids"]
+ if r:
+ loss_starts.append(len(input_ids))
+ loss_ends.append(len(input_ids) + len(tokenized))
+ input_ids.extend(tokenized)
+ return input_ids, loss_starts, loss_ends
+
+
+def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: str):
+ # Separate the templated prompt into chunks by human/assistant lines; prepares data for tokenize_and_concatenate
+ start_idx = 0
+ chunks = []
+ require_loss = []
+ for line in messages:
+ first_occur = prompt.find(line["content"], start_idx)
+ if prompt[first_occur - 1] != " ":
+ chunks.append(prompt[start_idx:first_occur])
+ chunks.append(prompt[first_occur : first_occur + len(line["content"])])
+ else:
+ chunks.append(prompt[start_idx : first_occur - 1])
+ chunks.append(prompt[first_occur - 1 : first_occur + len(line["content"])])
+ start_idx = first_occur + len(line["content"])
+ if line["role"].lower() == "assistant":
+ require_loss.append(False)
+ require_loss.append(True)
+ else:
+ require_loss.append(False)
+ require_loss.append(False)
+ chunks.append(prompt[start_idx:])
+ require_loss.append(False)
+ return chunks, require_loss
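Because the pipeline mixes left padding (generation) with right padding (reward and critic scoring), `pad_to_max_len` supports both sides. A quick sketch of its behaviour on toy tensors:

```python
# Toy demonstration of pad_to_max_len on two variable-length sequences.
import torch

from coati.dataset.utils import pad_to_max_len

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

left = pad_to_max_len(seqs, max_length=5, padding_value=0, padding_side="left")
# tensor([[0, 0, 1, 2, 3],
#         [0, 0, 0, 4, 5]])

right = pad_to_max_len(seqs, max_length=5, padding_value=0, padding_side="right")
# tensor([[1, 2, 3, 0, 0],
#         [4, 5, 0, 0, 0]])
```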
diff --git a/applications/Chat/coati/experience_buffer/__init__.py b/applications/ColossalChat/coati/experience_buffer/__init__.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/experience_buffer/__init__.py
rename to applications/ColossalChat/coati/experience_buffer/__init__.py
diff --git a/applications/Chat/coati/experience_buffer/base.py b/applications/ColossalChat/coati/experience_buffer/base.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/experience_buffer/base.py
rename to applications/ColossalChat/coati/experience_buffer/base.py
diff --git a/applications/Chat/coati/experience_buffer/naive.py b/applications/ColossalChat/coati/experience_buffer/naive.py
old mode 100644
new mode 100755
similarity index 86%
rename from applications/Chat/coati/experience_buffer/naive.py
rename to applications/ColossalChat/coati/experience_buffer/naive.py
index d47b67dbe713..b912df26818d
--- a/applications/Chat/coati/experience_buffer/naive.py
+++ b/applications/ColossalChat/coati/experience_buffer/naive.py
@@ -1,13 +1,16 @@
import random
-import warnings
from typing import List
import torch
from coati.experience_maker.base import Experience
+from colossalai.logging import get_dist_logger
+
from .base import ExperienceBuffer
from .utils import BufferItem, make_experience_batch, split_experience_batch
+logger = get_dist_logger()
+
class NaiveExperienceBuffer(ExperienceBuffer):
"""Naive experience buffer class. It stores experience.
@@ -35,7 +38,7 @@ def append(self, experience: Experience) -> None:
if self.limit > 0:
samples_to_remove = len(self.items) - self.limit
if samples_to_remove > 0:
- warnings.warn(f"Experience buffer is full. Removing {samples_to_remove} samples.")
+ logger.warning(f"Experience buffer is full. Removing {samples_to_remove} samples.")
self.items = self.items[samples_to_remove:]
def clear(self) -> None:
@@ -43,6 +46,12 @@ def clear(self) -> None:
@torch.no_grad()
def sample(self) -> Experience:
+ """
+ Randomly samples experiences from the buffer.
+
+ Returns:
+ A batch of sampled experiences.
+ """
items = random.sample(self.items, self.sample_batch_size)
experience = make_experience_batch(items)
if self.cpu_offload:
diff --git a/applications/Chat/coati/experience_buffer/utils.py b/applications/ColossalChat/coati/experience_buffer/utils.py
old mode 100644
new mode 100755
similarity index 94%
rename from applications/Chat/coati/experience_buffer/utils.py
rename to applications/ColossalChat/coati/experience_buffer/utils.py
index baedbebd184f..c4807d179d90
--- a/applications/Chat/coati/experience_buffer/utils.py
+++ b/applications/ColossalChat/coati/experience_buffer/utils.py
@@ -26,6 +26,7 @@ class BufferItem:
action_log_probs: torch.Tensor
values: torch.Tensor
reward: torch.Tensor
+ kl: torch.Tensor
advantages: torch.Tensor
attention_mask: Optional[torch.LongTensor]
action_mask: Optional[torch.BoolTensor]
@@ -34,7 +35,7 @@ class BufferItem:
def split_experience_batch(experience: Experience) -> List[BufferItem]:
batch_size = experience.sequences.size(0)
batch_kwargs = [{} for _ in range(batch_size)]
- keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
+ keys = ("sequences", "action_log_probs", "values", "reward", "kl", "advantages", "attention_mask", "action_mask")
for key in keys:
value = getattr(experience, key)
if isinstance(value, torch.Tensor):
@@ -63,7 +64,7 @@ def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> to
def make_experience_batch(items: List[BufferItem]) -> Experience:
kwargs = {}
to_pad_keys = set(("action_log_probs", "action_mask"))
- keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
+ keys = ("sequences", "action_log_probs", "values", "reward", "kl", "advantages", "attention_mask", "action_mask")
for key in keys:
vals = [getattr(item, key) for item in items]
if key in to_pad_keys:
diff --git a/applications/Chat/coati/experience_maker/__init__.py b/applications/ColossalChat/coati/experience_maker/__init__.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/experience_maker/__init__.py
rename to applications/ColossalChat/coati/experience_maker/__init__.py
diff --git a/applications/Chat/coati/experience_maker/base.py b/applications/ColossalChat/coati/experience_maker/base.py
old mode 100644
new mode 100755
similarity index 74%
rename from applications/Chat/coati/experience_maker/base.py
rename to applications/ColossalChat/coati/experience_maker/base.py
index 0731f6e0f97f..55054c3a0611
--- a/applications/Chat/coati/experience_maker/base.py
+++ b/applications/ColossalChat/coati/experience_maker/base.py
@@ -3,7 +3,8 @@
from typing import Optional
import torch
-from coati.models.base import Actor, Critic, RewardModel
+from coati.models import Critic, RewardModel
+from transformers import PreTrainedModel
@dataclass
@@ -28,6 +29,7 @@ class Experience:
action_log_probs: torch.Tensor
values: torch.Tensor
reward: torch.Tensor
+ kl: torch.Tensor
advantages: torch.Tensor
attention_mask: Optional[torch.LongTensor]
action_mask: Optional[torch.BoolTensor]
@@ -39,6 +41,7 @@ def to_device(self, device: torch.device) -> None:
self.values = self.values.to(device)
self.reward = self.reward.to(device)
self.advantages = self.advantages.to(device)
+ self.kl = self.kl.to(device)
if self.attention_mask is not None:
self.attention_mask = self.attention_mask.to(device)
if self.action_mask is not None:
@@ -50,6 +53,7 @@ def pin_memory(self):
self.values = self.values.pin_memory()
self.reward = self.reward.pin_memory()
self.advantages = self.advantages.pin_memory()
+ self.kl = self.kl.pin_memory()
if self.attention_mask is not None:
self.attention_mask = self.attention_mask.pin_memory()
if self.action_mask is not None:
@@ -58,7 +62,13 @@ def pin_memory(self):
class ExperienceMaker(ABC):
- def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None:
+ """
+ Base class for experience makers.
+ """
+
+ def __init__(
+ self, actor: PreTrainedModel, critic: Critic, reward_model: RewardModel, initial_model: PreTrainedModel
+ ) -> None:
super().__init__()
self.actor = actor
self.critic = critic
@@ -67,4 +77,14 @@ def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, init
@abstractmethod
def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
- pass
+ """
+ Abstract method to generate an experience.
+
+ Args:
+ input_ids (torch.Tensor): The input tensor.
+ attention_mask (torch.Tensor): The attention mask tensor.
+ **generate_kwargs: Additional keyword arguments for generating the experience.
+
+ Returns:
+ Experience: The generated experience.
+ """
diff --git a/applications/ColossalChat/coati/experience_maker/naive.py b/applications/ColossalChat/coati/experience_maker/naive.py
new file mode 100755
index 000000000000..945bb95577c7
--- /dev/null
+++ b/applications/ColossalChat/coati/experience_maker/naive.py
@@ -0,0 +1,180 @@
+"""
+Experience maker.
+"""
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from coati.dataset.utils import find_first_occurrence_subsequence
+from coati.models import Critic, RewardModel
+from coati.models.generation import generate
+from coati.models.utils import calc_action_log_probs, compute_reward
+from transformers import PreTrainedModel, PreTrainedTokenizer
+
+from colossalai.logging import get_dist_logger
+
+from .base import Experience, ExperienceMaker
+
+logger = get_dist_logger()
+
+
+def is_rank_0() -> bool:
+ return not dist.is_initialized() or dist.get_rank() == 0
+
+
+class NaiveExperienceMaker(ExperienceMaker):
+ """
+ Naive experience maker.
+ """
+
+ def __init__(
+ self,
+ actor: PreTrainedModel,
+ critic: Critic,
+ reward_model: RewardModel,
+ initial_model: PreTrainedModel,
+ tokenizer: PreTrainedTokenizer,
+ kl_coef: float = 0.01,
+ gamma: float = 1.0,
+ lam: float = 0.95,
+ ) -> None:
+ super().__init__(actor, critic, reward_model, initial_model)
+ self.tokenizer = tokenizer
+ self.kl_coef = kl_coef
+ self.gamma = gamma
+ self.lam = lam
+
+ @torch.no_grad()
+ def calculate_advantage(self, value: torch.Tensor, reward: torch.Tensor, num_actions: int) -> torch.Tensor:
+ """
+ Calculates the advantage values for each action based on the value and reward tensors.
+
+ Args:
+ value (torch.Tensor): Tensor containing the predicted values from critic.
+ reward (torch.Tensor): reward of the shape [B, len].
+ num_actions (int): Number of actions.
+
+ Returns:
+ torch.Tensor: Tensor containing the calculated advantages for each action.
+ """
+ lastgaelam = 0
+ advantages_reversed = []
+ for t in reversed(range(num_actions)):
+ nextvalues = value[:, t + 1] if t < num_actions - 1 else 0.0
+ delta = reward[:, t] + self.gamma * nextvalues - value[:, t]
+ lastgaelam = delta + self.gamma * self.lam * lastgaelam
+ advantages_reversed.append(lastgaelam)
+ advantages = torch.stack(advantages_reversed[::-1], dim=1)
+ return advantages
+
+ @torch.no_grad()
+ def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
+ """
+ Generates an experience using the given input_ids and attention_mask.
+
+ Args:
+ input_ids (torch.Tensor): The input tensor containing the tokenized input sequence.
+ attention_mask (torch.Tensor): The attention mask tensor indicating which tokens to attend to.
+ **generate_kwargs: Additional keyword arguments for the generation process.
+
+ Returns:
+ Experience: The generated experience object.
+
+ """
+ self.actor.eval()
+ self.critic.eval()
+ self.initial_model.eval()
+ self.reward_model.eval()
+ pad_token_id = self.tokenizer.pad_token_id
+
+ stop_token_ids = generate_kwargs.get("stop_token_ids", None)
+ torch.manual_seed(41)  # for TP, guarantee the same input for the reward model
+
+ sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs)
+
+ # Pad to max length
+ sequences = F.pad(sequences, (0, generate_kwargs["max_length"] - sequences.size(1)), value=pad_token_id)
+ sequence_length = sequences.size(1)
+
+ # Calculate auxiliary tensors
+ attention_mask = None
+ if pad_token_id is not None:
+ attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
+
+ input_len = input_ids.size(1)
+ if stop_token_ids is None:
+ # End the sequence with eos token
+ eos_token_id = self.tokenizer.eos_token_id
+ if eos_token_id is None:
+ action_mask = torch.ones_like(sequences, dtype=torch.bool)
+ else:
+ # Left padding may be applied, only mask action
+ action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
+ action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
+ else:
+ # stop_token_ids are given, generation ends with stop_token_ids
+ action_mask = torch.ones_like(sequences, dtype=torch.bool)
+ for i in range(sequences.size(0)):
+ stop_index = find_first_occurrence_subsequence(
+ sequences[i][input_len:], torch.tensor(stop_token_ids).to(sequences.device)
+ )
+ if stop_index == -1:
+ # Sequence does not contain stop_token_ids; this should not happen
+ logger.warning(
+ "Generated sequence does not contain stop_token_ids. Please check your chat template config"
+ )
+ else:
+ # Keep stop tokens
+ stop_index = input_len + stop_index
+ action_mask[i, stop_index + len(stop_token_ids) :] = False
+
+ generation_end_index = action_mask.sum(dim=-1) - 1
+ action_mask[:, :input_len] = False
+ action_mask = action_mask[:, 1:]
+ action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
+ num_actions = action_mask.size(1)
+
+ actor_output = self.actor(input_ids=sequences, attention_mask=attention_mask)["logits"]
+ action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
+
+ base_model_output = self.initial_model(input_ids=sequences, attention_mask=attention_mask)["logits"]
+
+ base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
+
+ # Convert to right padding for the reward model and the critic model
+ input_ids_rm = torch.zeros_like(sequences, device=sequences.device)
+ attention_mask_rm = torch.zeros_like(sequences, device=sequences.device)
+ for i in range(sequences.size(0)):
+ sequence = sequences[i]
+ bos_index = (sequence != pad_token_id).nonzero().reshape([-1])[0]
+ eos_index = generation_end_index[i]
+ sequence_to_pad = sequence[bos_index:eos_index]
+ sequence_padded = F.pad(
+ sequence_to_pad, (0, sequence_length - sequence_to_pad.size(0)), value=self.tokenizer.pad_token_id
+ )
+ input_ids_rm[i] = sequence_padded
+ if sequence_length - sequence_to_pad.size(0) > 0:
+ attention_mask_rm[i, : sequence_to_pad.size(0) + 1] = 1
+ else:
+ attention_mask_rm[i, :] = 1
+ attention_mask_rm = attention_mask_rm.to(dtype=torch.bool)
+
+ r = self.reward_model(
+ input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
+ attention_mask=attention_mask_rm.to(device=sequences.device),
+ )
+
+ value = self.critic(
+ input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
+ attention_mask=attention_mask_rm.to(device=sequences.device),
+ )
+ reward, kl = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
+ value = value[:, -num_actions:] * action_mask
+ advantages = self.calculate_advantage(value, reward, num_actions)
+
+ advantages = advantages.detach()
+ value = value.detach()
+ r = r.detach()
+
+ return Experience(sequences, action_log_probs, value, r, kl, advantages, attention_mask, action_mask)
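`calculate_advantage` above implements Generalized Advantage Estimation (GAE): the TD residual `delta_t = r_t + gamma * V_{t+1} - V_t` is accumulated backwards with decay `gamma * lam`. A standalone numeric sketch of the same recursion (the values and coefficients are illustrative):

```python
# Standalone sketch of the GAE recursion used in calculate_advantage above.
import torch

value = torch.tensor([[0.5, 0.4, 0.3]])   # critic values, shape [B, num_actions]
reward = torch.tensor([[0.0, 0.0, 1.0]])  # per-token rewards, same shape
gamma, lam = 1.0, 0.95                    # illustrative coefficients
num_actions = value.size(1)

lastgaelam, advantages_reversed = 0.0, []
for t in reversed(range(num_actions)):
    nextvalues = value[:, t + 1] if t < num_actions - 1 else 0.0
    delta = reward[:, t] + gamma * nextvalues - value[:, t]  # TD residual
    lastgaelam = delta + gamma * lam * lastgaelam            # decayed accumulation
    advantages_reversed.append(lastgaelam)

advantages = torch.stack(advantages_reversed[::-1], dim=1)
# advantages ≈ tensor([[0.4368, 0.5650, 0.7000]])
```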
diff --git a/applications/ColossalChat/coati/models/__init__.py b/applications/ColossalChat/coati/models/__init__.py
new file mode 100755
index 000000000000..14073207f150
--- /dev/null
+++ b/applications/ColossalChat/coati/models/__init__.py
@@ -0,0 +1,24 @@
+from .base import BaseModel
+from .critic import Critic
+from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
+from .lora import convert_to_lora_module
+from .loss import DpoLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
+from .reward_model import RewardModel
+from .utils import disable_dropout
+
+__all__ = [
+ "BaseModel",
+ "Critic",
+ "RewardModel",
+ "PolicyLoss",
+ "ValueLoss",
+ "LogSigLoss",
+ "LogExpLoss",
+ "convert_to_lora_module",
+ "DpoLoss",
+ "generate",
+ "generate_streaming",
+ "disable_dropout",
+ "update_model_kwargs_fn",
+ "prepare_inputs_fn",
+]
diff --git a/applications/ColossalChat/coati/models/base.py b/applications/ColossalChat/coati/models/base.py
new file mode 100755
index 000000000000..fcea9414b430
--- /dev/null
+++ b/applications/ColossalChat/coati/models/base.py
@@ -0,0 +1,58 @@
+"""
+Base class for critic and reward model
+"""
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from transformers import AutoModel, PretrainedConfig
+
+
+class BaseModel(nn.Module):
+ """
+ Base model class, wrapping a transformer backbone for the critic and reward model.
+
+ Args:
+ pretrained (str): path to pretrained model.
+ config (PretrainedConfig): PretrainedConfig used to instantiate the base model.
+ **kwargs: all other kwargs as in AutoModel.from_pretrained
+ """
+
+ def __init__(self, pretrained: str = None, config: Optional[PretrainedConfig] = None, **kwargs) -> None:
+ super().__init__()
+ if pretrained is not None:
+ if config is not None:
+ # initialize with config and load weights from pretrained
+ self.model = AutoModel.from_pretrained(pretrained, config=config, **kwargs)
+ else:
+ # initialize with pretrained
+ self.model = AutoModel.from_pretrained(pretrained, **kwargs)
+ elif config is not None:
+ # initialize with config
+ self.model = AutoModel.from_config(config, **kwargs)
+ else:
+ raise ValueError("Either pretrained or config must be provided.")
+
+ self.config = self.model.config
+ # create dummy input to get the size of the last hidden state
+ if "use_flash_attention_2" in kwargs:
+ self.model = self.model.cuda()
+ dummy_input = torch.zeros((1, 1), dtype=torch.long).to(self.model.device)
+ out = self.model(dummy_input)
+ self.last_hidden_state_size = out.last_hidden_state.shape[-1]
+ self.model = self.model.cpu()
+ # print("self.last_hidden_state_size: ",self.last_hidden_state_size)
+
+ def resize_token_embeddings(self, *args, **kwargs):
+ """
+ Resize the token embeddings of the model.
+
+ Args:
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
+
+ Returns:
+ The resized token embeddings.
+ """
+ return self.model.resize_token_embeddings(*args, **kwargs)
diff --git a/applications/ColossalChat/coati/models/critic.py b/applications/ColossalChat/coati/models/critic.py
new file mode 100755
index 000000000000..80340d9bd43d
--- /dev/null
+++ b/applications/ColossalChat/coati/models/critic.py
@@ -0,0 +1,34 @@
+"""
+Critic model
+"""
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from coati.models import BaseModel
+from transformers import PretrainedConfig
+
+
+class Critic(BaseModel):
+ """
+ Critic model class.
+
+ Args:
+ pretrained (str): path to pretrained model.
+ config (PretrainedConfig): PretrainedConfig used to instantiate the base model.
+ """
+
+ def __init__(self, pretrained: str = None, config: Optional[PretrainedConfig] = None, **kwargs) -> None:
+ super().__init__(pretrained=pretrained, config=config, **kwargs)
+ # `last_hidden_state_size` was determined in BaseModel.__init__ via a dummy forward pass
+ self.value_head = nn.Linear(self.last_hidden_state_size, 1)
+
+ def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ outputs = self.model(input_ids, attention_mask=attention_mask)
+ last_hidden_states = outputs["last_hidden_state"]
+ sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), :].type(
+ self.value_head.weight.dtype
+ )
+ values = self.value_head(sequence_hidden_states).squeeze(-1) # ensure shape is (B, sequence length)
+ return values
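Like the reward model, the critic consumes right-padded sequences and emits one scalar value per token position. A minimal scoring sketch (the pretrained path is a placeholder):

```python
# Minimal critic-scoring sketch; the model path is a placeholder.
import torch

from coati.models import Critic

critic = Critic(pretrained="/path/to/base/model")  # placeholder path
input_ids = torch.randint(0, 1000, (2, 16))        # [batch_size, seq_len]
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)

values = critic(input_ids, attention_mask=attention_mask)
print(values.shape)  # torch.Size([2, 16]) -- one value per token position
```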
diff --git a/applications/ColossalChat/coati/models/generation.py b/applications/ColossalChat/coati/models/generation.py
new file mode 100755
index 000000000000..b671ef124063
--- /dev/null
+++ b/applications/ColossalChat/coati/models/generation.py
@@ -0,0 +1,428 @@
+from typing import Any, Callable, List, Optional
+
+import torch
+import torch.distributed as dist
+from transformers import PreTrainedTokenizer
+
+try:
+ from transformers.generation_logits_process import (
+ LogitsProcessorList,
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
+ )
+except ImportError:
+ from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
+
+
+def _prepare_logits_processor(
+ top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
+) -> LogitsProcessorList:
+ """
+ Prepare the logits processor list based on the given parameters.
+
+ Args:
+ top_k (Optional[int]): The number of highest probability logits to keep for each token.
+ top_p (Optional[float]): The cumulative probability threshold for selecting tokens.
+ temperature (Optional[float]): The temperature value to apply to the logits.
+
+ Returns:
+ LogitsProcessorList: The list of logits processors.
+
+ """
+ processor_list = LogitsProcessorList()
+ if temperature is not None and temperature != 1.0:
+ processor_list.append(TemperatureLogitsWarper(temperature))
+ if top_k is not None and top_k != 0:
+ processor_list.append(TopKLogitsWarper(top_k))
+ if top_p is not None and top_p < 1.0:
+ processor_list.append(TopPLogitsWarper(top_p))
+ return processor_list
+
+
+def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
+ """
+ Check if the sequence generation is finished.
+
+ Args:
+ unfinished_sequences (torch.Tensor): Tensor indicating the unfinished sequences.
+
+ Returns:
+ bool: True if all sequences are finished, False otherwise.
+ """
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ # consider DP
+ unfinished_sequences = unfinished_sequences.clone()
+ dist.all_reduce(unfinished_sequences)
+ return unfinished_sequences.max() == 0
+
+
+def update_model_kwargs_fn(outputs: dict, new_mask, **model_kwargs) -> dict:
+ """
+ Update the model keyword arguments based on the outputs and new mask.
+
+ Args:
+ outputs (dict): The outputs from the model.
+ new_mask: The new attention mask.
+ **model_kwargs: Additional model keyword arguments.
+
+ Returns:
+ dict: The updated model keyword arguments.
+ """
+
+ if "past_key_values" in outputs:
+ model_kwargs["past_key_values"] = outputs["past_key_values"]
+ else:
+ model_kwargs["past_key_values"] = None
+
+ # update token_type_ids with last value
+ if "token_type_ids" in model_kwargs:
+ token_type_ids = model_kwargs["token_type_ids"]
+ model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+ # update attention mask
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ model_kwargs["attention_mask"] = torch.cat([attention_mask, new_mask], dim=-1)
+
+ return model_kwargs
+
+
+def prepare_inputs_fn(input_ids: torch.Tensor, pad_token_id: int, **model_kwargs) -> dict:
+ model_kwargs["input_ids"] = input_ids
+ return model_kwargs
+
+
+def _sample(
+ model: Any,
+ input_ids: torch.Tensor,
+ max_length: int,
+ early_stopping: bool = True,
+ eos_token_id: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ stop_token_ids: Optional[List[int]] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ max_new_tokens: int = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ stream_interval: int = 2,
+ **model_kwargs,
+) -> torch.Tensor:
+ """
+ Generates new tokens using the given model and input_ids.
+
+ Args:
+ model (Any): The model used for token generation.
+ input_ids (torch.Tensor): The input tensor containing the initial tokens.
+ max_length (int): The maximum length of the generated tokens.
+ early_stopping (bool, optional): Whether to stop generating tokens early if all sequences are finished. Defaults to True.
+ eos_token_id (int, optional): The ID of the end-of-sequence token. Defaults to None.
+ pad_token_id (int, optional): The ID of the padding token. Defaults to None.
+ stop_token_ids (List[int], optional): A list of token IDs that, if encountered, will stop the generation process. Defaults to None.
+ top_k (int, optional): The number of top-k tokens to consider during sampling. Defaults to None.
+ top_p (float, optional): The cumulative probability threshold for top-p sampling. Defaults to None.
+ temperature (float, optional): The temperature value for token sampling. Defaults to None.
+ max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to None.
+ prepare_inputs_fn (Callable[[torch.Tensor, Any], dict], optional): A function to prepare the model inputs. Defaults to None.
+ update_model_kwargs_fn (Callable[[dict, Any], dict], optional): A function to update the model kwargs. Defaults to None.
+ stream_interval (int, optional): The interval for streaming generation. Defaults to 2.
+ **model_kwargs: Additional keyword arguments for the model.
+
+ Returns:
+ torch.Tensor: The tensor containing the generated tokens.
+ """
+ context_length = input_ids.size(1)
+ if max_new_tokens is None:
+ max_new_tokens = max_length - context_length
+ if context_length + max_new_tokens > max_length or max_new_tokens == 0:
+ return input_ids
+
+ logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+ past = None
+ for i in range(context_length, context_length + max_new_tokens):
+ # Calculate attention mask
+ if "attention_mask" not in model_kwargs:
+ model_kwargs["attention_mask"] = input_ids.ne(pad_token_id)
+ model_inputs = (
+ prepare_inputs_fn(input_ids, past=past, **model_kwargs)
+ if prepare_inputs_fn is not None
+ else {"input_ids": input_ids, "attention_mask": input_ids.ne(pad_token_id)}
+ )
+ outputs = model(**model_inputs)
+
+ if "past_key_values" in outputs:
+ past = outputs.past_key_values
+ elif "mems" in outputs:
+ past = outputs.mems
+
+ # NOTE: this is correct only in left padding mode
+ next_token_logits = outputs["logits"][:, -1, :]
+ next_token_logits = logits_processor(input_ids, next_token_logits)
+
+ # Sample
+ probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+ # Finished sentences should have their next token be a padding token
+ if eos_token_id is not None:
+ assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+ # Update generated ids, model inputs for next step
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+
+ if update_model_kwargs_fn is not None:
+ model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)
+
+ # If eos_token was found in one sentence, set sentence to finished
+ if eos_token_id is not None:
+ unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+
+ if stop_token_ids is not None:
+ # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
+ tokens_to_check = input_ids[:, -len(stop_token_ids) :]
+ unfinished_sequences = unfinished_sequences.mul(
+ torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
+ )
+
+ # Stop when each sentence is finished if early_stopping=True
+ if (early_stopping and _is_sequence_finished(unfinished_sequences)) or i == context_length + max_new_tokens - 1:
+ if i == context_length + max_new_tokens - 1 and stop_token_ids is not None:
+ # Force unfinished sequences to end with the stop token ids
+ input_ids[input_ids[:, -1] != pad_token_id, -len(stop_token_ids) :] = (
+ torch.LongTensor(stop_token_ids).to(input_ids.device).long()
+ )
+ return input_ids
+
+
+@torch.inference_mode()
+def generate(
+ model: Any,
+ input_ids: torch.Tensor,
+ tokenizer: PreTrainedTokenizer,
+ max_length: int,
+ num_beams: int = 1,
+ do_sample: bool = True,
+ early_stopping: bool = True,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ **model_kwargs,
+) -> torch.Tensor:
+ """Generate token sequence. The returned sequence is input_ids + generated_tokens.
+
+ Args:
+ model (nn.Module): model
+ input_ids (torch.Tensor): input sequence
+ max_length (int): max length of the returned sequence
+ num_beams (int, optional): number of beams. Defaults to 1.
+ do_sample (bool, optional): whether to do sample. Defaults to True.
+ early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to True.
+ top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
+ top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
+ temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
+ prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
+ update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
+ """
+ assert tokenizer.padding_side == "left", "Current generation only supports left padding."
+ is_greedy_gen_mode = (num_beams == 1) and do_sample is False
+ is_sample_gen_mode = (num_beams == 1) and do_sample is True
+ is_beam_gen_mode = (num_beams > 1) and do_sample is False
+ if is_greedy_gen_mode:
+ raise NotImplementedError
+ elif is_sample_gen_mode:
+ # Run sample
+ res = _sample(
+ model,
+ input_ids,
+ max_length,
+ early_stopping=early_stopping,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ prepare_inputs_fn=prepare_inputs_fn,
+ update_model_kwargs_fn=update_model_kwargs_fn,
+ **model_kwargs,
+ )
+ return res
+ elif is_beam_gen_mode:
+ raise NotImplementedError
+ else:
+ raise ValueError("Unsupported generation mode")
+
+
+def _sample_streaming(
+ model: Any,
+ input_ids: torch.Tensor,
+ max_length: int,
+ early_stopping: bool = False,
+ eos_token_id: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ stop_token_ids: Optional[List[int]] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ max_new_tokens: int = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ stream_interval: int = 2,
+ **model_kwargs,
+) -> torch.Tensor:
+ """
+ Generates new tokens using a streaming approach.
+
+ Args:
+ model (Any): The model used for token generation.
+ input_ids (torch.Tensor): The input tensor containing the initial tokens.
+ max_length (int): The maximum length of the generated sequence.
+ early_stopping (bool, optional): Whether to stop generating tokens for a sequence if it is finished. Defaults to False.
+ eos_token_id (int, optional): The ID of the end-of-sequence token. Defaults to None.
+ pad_token_id (int, optional): The ID of the padding token. Defaults to None.
+ stop_token_ids (List[int], optional): A list of token IDs that, if encountered, will mark the sequence as finished. Defaults to None.
+ top_k (int, optional): The number of top-k tokens to consider during sampling. Defaults to None.
+ top_p (float, optional): The cumulative probability threshold for top-p sampling. Defaults to None.
+ temperature (float, optional): The temperature value for sampling. Defaults to None.
+ max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to None.
+ prepare_inputs_fn (Callable[[torch.Tensor, Any], dict], optional): A function to prepare the model inputs. Defaults to None.
+ update_model_kwargs_fn (Callable[[dict, Any], dict], optional): A function to update the model keyword arguments. Defaults to None.
+ stream_interval (int, optional): The interval at which to yield the generated tokens. Defaults to 2.
+ **model_kwargs: Additional keyword arguments to be passed to the model.
+
+ Yields:
+ torch.Tensor: The generated tokens at each step.
+
+ Returns:
+ torch.Tensor: The final generated tokens.
+ """
+
+ context_length = input_ids.size(1)
+ if max_new_tokens is None:
+ max_new_tokens = max_length - context_length
+ if context_length + max_new_tokens > max_length or max_new_tokens == 0:
+ return input_ids
+
+ logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+
+ past = None
+ for i in range(context_length, context_length + max_new_tokens):
+ # calculate attention mask
+ if "attention_mask" not in model_kwargs:
+ model_kwargs["attention_mask"] = input_ids.ne(pad_token_id)
+ model_inputs = (
+ prepare_inputs_fn(input_ids, past=past, **model_kwargs)
+ if prepare_inputs_fn is not None
+ else {"input_ids": input_ids, "attention_mask": input_ids.ne(pad_token_id)}
+ )
+ outputs = model(**model_inputs)
+ if "past_key_values" in outputs:
+ past = outputs.past_key_values
+ elif "mems" in outputs:
+ past = outputs.mems
+
+ # NOTE: this is correct only in left padding mode
+ next_token_logits = outputs["logits"][:, -1, :]
+ next_token_logits = logits_processor(input_ids, next_token_logits)
+ # sample
+ probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+ # finished sentences should have their next token be a padding token
+ if eos_token_id is not None:
+ assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+ # update generated ids, model inputs for next step
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+
+ if update_model_kwargs_fn is not None:
+ model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)
+
+ # if eos_token was found in one sentence, set sentence to finished
+ if eos_token_id is not None:
+ unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+
+ if stop_token_ids is not None:
+ # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
+ tokens_to_check = input_ids[:, -len(stop_token_ids) :]
+ unfinished_sequences = unfinished_sequences.mul(
+ torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
+ )
+
+ # Stop when each sentence is finished if early_stopping=True
+ if (
+ (early_stopping and _is_sequence_finished(unfinished_sequences))
+ or (i - context_length) % stream_interval == 0
+ or i == context_length + max_new_tokens - 1
+ ):
+ yield input_ids
+ if early_stopping and _is_sequence_finished(unfinished_sequences):
+ break
+
+
+@torch.inference_mode()
+def generate_streaming(
+ model: Any,
+ input_ids: torch.Tensor,
+ tokenizer: PreTrainedTokenizer,
+ max_length: int,
+ num_beams: int = 1,
+ do_sample: bool = True,
+ early_stopping: bool = False,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ **model_kwargs,
+):
+ """Generate token sequence. The returned sequence is input_ids + generated_tokens.
+
+ Args:
+ model (nn.Module): model
+ input_ids (torch.Tensor): input sequence
+ max_length (int): max length of the returned sequence
+ num_beams (int, optional): number of beams. Defaults to 1.
+ do_sample (bool, optional): whether to do sample. Defaults to True.
+ early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
+ top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
+ top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
+ temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
+ prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
+ update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
+ """
+ assert tokenizer.padding_side == "left", "Current generation only supports left padding."
+ is_greedy_gen_mode = (num_beams == 1) and do_sample is False
+ is_sample_gen_mode = (num_beams == 1) and do_sample is True
+ is_beam_gen_mode = (num_beams > 1) and do_sample is False
+ if is_greedy_gen_mode:
+ # run greedy search
+ raise NotImplementedError
+ elif is_sample_gen_mode:
+ # run sample
+ for res in _sample_streaming(
+ model,
+ input_ids,
+ max_length,
+ early_stopping=early_stopping,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ prepare_inputs_fn=prepare_inputs_fn,
+ update_model_kwargs_fn=update_model_kwargs_fn,
+ **model_kwargs,
+ ):
+ yield res
+ elif is_beam_gen_mode:
+ raise NotImplementedError
+ else:
+ raise ValueError("Unsupported generation mode")
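`generate_streaming` yields the growing sequence every `stream_interval` steps, which is what a streaming chat front-end consumes. A minimal sketch with a Hugging Face causal LM (the `gpt2` checkpoint is an illustrative choice; any left-padded causal LM should behave the same):

```python
# Minimal streaming sketch; the gpt2 checkpoint is an illustrative choice.
from transformers import AutoModelForCausalLM, AutoTokenizer

from coati.models.generation import generate_streaming

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
tokenizer.padding_side = "left"            # required by the assert above
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(["The quick brown fox"], return_tensors="pt")
for partial_ids in generate_streaming(
    model,
    inputs["input_ids"],
    tokenizer,
    max_length=32,
    top_k=50,
    temperature=0.7,
):
    # Each yield is the full prompt plus the tokens generated so far.
    print(tokenizer.decode(partial_ids[0], skip_special_tokens=True))
```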
diff --git a/applications/Chat/coati/models/lora.py b/applications/ColossalChat/coati/models/lora.py
old mode 100644
new mode 100755
similarity index 85%
rename from applications/Chat/coati/models/lora.py
rename to applications/ColossalChat/coati/models/lora.py
index e9bd7b2ed8f0..9553b00ff2a8
--- a/applications/Chat/coati/models/lora.py
+++ b/applications/ColossalChat/coati/models/lora.py
@@ -1,3 +1,7 @@
+"""
+LORA utils
+"""
+
import dataclasses
import math
import warnings
@@ -8,6 +12,10 @@
import torch.nn as nn
import torch.nn.functional as F
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
@dataclasses.dataclass
class LoRAManager:
@@ -58,6 +66,10 @@ def reset_parameters(self):
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
+ """
+        This function runs when model.train() is invoked. It is used to prepare the linear layer for training.
+ """
+
def T(w):
return w.T if self.fan_in_fan_out else w
@@ -101,6 +113,16 @@ def T(w):
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
+ """
+ Wraps a linear layer with LoRA functionality.
+
+ Args:
+ linear (nn.Linear): The linear layer to be wrapped.
+ lora_rank (int): The rank of the LoRA decomposition.
+
+ Returns:
+ LoraLinear: The wrapped linear layer with LoRA functionality.
+ """
assert (
lora_rank <= linear.in_features
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
@@ -109,6 +131,16 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
+ """
+    Recursively converts the given module and its children to LoRA (Low-Rank Adaptation) form.
+
+ Args:
+ module (nn.Module): The module to convert to LoRA form.
+        lora_rank (int): The rank of the LoRA decomposition.
+
+ Returns:
+ None
+ """
for name, child in module.named_children():
if isinstance(child, nn.Linear):
setattr(module, name, _lora_linear_wrapper(child, lora_rank))
@@ -131,23 +163,3 @@ def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: s
_convert_to_lora_recursively(module, lora_rank)
lora.mark_only_lora_as_trainable(module, lora_train_bias)
return module
-
-
-class LoRAModule(nn.Module):
- """A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
- This class will convert all torch.nn.Linear layer to LoraLinear layer.
-
- Args:
- lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
- lora_train_bias (str, optional): Whether LoRA train biases.
- 'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers.
- Defaults to 'none'.
- """
-
- def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
- super().__init__()
- self.lora_rank = lora_rank
- self.lora_train_bias = lora_train_bias
-
- def convert_to_lora(self) -> None:
- convert_to_lora_module(self, self.lora_rank, self.lora_train_bias)
diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py
new file mode 100755
index 000000000000..aaef447a4383
--- /dev/null
+++ b/applications/ColossalChat/coati/models/loss.py
@@ -0,0 +1,169 @@
+"""
+loss functions
+"""
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from .utils import masked_mean
+
+
+class GPTLMLoss(nn.Module):
+ """
+ GPT Language Model Loss
+ """
+
+ def __init__(self):
+ super().__init__()
+ # NOTE: default ignore_index is -100, which is equal to IGNORE_INDEX in sft_dataset.py
+ self.loss = nn.CrossEntropyLoss()
+
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
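+        # next-token prediction: logits at position t score the token at t + 1,
+        # so drop the final logit and the first label before computing cross entropy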
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+
+class PolicyLoss(nn.Module):
+ """
+ Policy Loss for PPO
+ """
+
+ def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0) -> None:
+ super().__init__()
+ self.clip_eps = clip_eps
+ self.skip_threshold = skip_threshold
+
+ def forward(
+ self,
+ log_probs: torch.Tensor,
+ old_log_probs: torch.Tensor,
+ advantages: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, bool, torch.Tensor]:
+ skip = False
+ ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
+
+        # note: if dropout is disabled (recommended), the ratio equals 1 on the first pass over each experience batch.
+ if ratio_.mean() > self.skip_threshold:
+ skip = True
+
+ ratio = ratio_.clamp(0.0, 10.0)
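+        # PPO clipped surrogate: take the pessimistic minimum of the unclipped and
+        # clipped ratio-weighted advantages (with clip_eps=0.2, a ratio of 1.5
+        # enters the clipped branch as 1.2)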
+ surr1 = ratio * advantages
+ surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
+ loss = -torch.min(surr1, surr2)
+ loss = masked_mean(loss, action_mask)
+ loss = loss.mean()
+ return loss, skip, ratio_.max()
+
+
+class ValueLoss(nn.Module):
+ """
+ Value Loss for PPO
+ """
+
+ def __init__(self, clip_eps: float = 0.2) -> None:
+ super().__init__()
+ self.clip_eps = clip_eps
+
+ def forward(
+ self,
+ values: torch.Tensor,
+ old_values: torch.Tensor,
+ advantage: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ returns = advantage + old_values
+ values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
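+        # clipped value update: penalize the larger of the clipped and unclipped
+        # squared errors against the returns, then average over valid tokens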
+ surr1 = (values_clipped - returns) ** 2
+ surr2 = (values - returns) ** 2
+ loss = torch.max(surr1, surr2) / torch.sum(action_mask)
+ loss = torch.sum(loss * action_mask)
+ return 0.5 * loss
+
+
+class DpoLoss(nn.Module):
+ """
+    DPO loss
+ Details: https://arxiv.org/pdf/2305.18290.pdf
+ """
+
+ def __init__(self, beta: float = 0.1):
+ super().__init__()
+ self.beta = beta
+
+ def forward(
+ self,
+ logprob_actor_chosen: torch.Tensor,
+ logprob_actor_reject: torch.Tensor,
+ logprob_ref_chosen: torch.Tensor,
+ logprob_ref_reject: torch.Tensor,
+ chosen_mask: torch.Tensor,
+ reject_mask: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute the DPO loss for a batch of policy and reference model log probabilities.
+
+ # adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L328
+
+ Args:
+            logprob_actor_chosen: Log probabilities of the policy model for the chosen responses. Shape: (batch_size, sequence_length)
+            logprob_actor_reject: Log probabilities of the policy model for the rejected responses. Shape: (batch_size, sequence_length)
+            logprob_ref_chosen: Log probabilities of the reference model for the chosen responses. Shape: (batch_size, sequence_length)
+            logprob_ref_reject: Log probabilities of the reference model for the rejected responses. Shape: (batch_size, sequence_length)
+            chosen_mask: Loss mask for the chosen responses. Shape: (batch_size, sequence_length)
+            reject_mask: Loss mask for the rejected responses. Shape: (batch_size, sequence_length)
+
+ Returns:
+ A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
+ The losses tensor contains the DPO loss for each example in the batch.
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
+ """
+ logprob_actor_chosen = logprob_actor_chosen * chosen_mask
+ logprob_actor_reject = logprob_actor_reject * reject_mask
+ if logprob_ref_chosen is not None and logprob_ref_reject is not None:
+ logprob_ref_chosen = logprob_ref_chosen * chosen_mask
+ logprob_ref_reject = logprob_ref_reject * reject_mask
+ if len(logprob_ref_chosen.shape) == 2:
+ ref_logratios = logprob_ref_chosen.sum(-1) - logprob_ref_reject.sum(-1)
+ else:
+ ref_logratios = logprob_ref_chosen.squeeze() - logprob_ref_reject.squeeze()
+ else:
+ # If no reference model is provided
+ ref_logratios = 0.0
+
+ pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
+ logits = pi_logratios - ref_logratios
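+        # DPO objective: loss = -log sigmoid(beta * ((log pi(y_w) - log pi_ref(y_w)) - (log pi(y_l) - log pi_ref(y_l))))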
+ losses = -torch.nn.functional.logsigmoid(self.beta * logits)
+
+ # Calculate rewards for logging
+ if logprob_ref_chosen is not None:
+ chosen_rewards = self.beta * (logprob_actor_chosen.sum(-1) - logprob_ref_chosen.sum(-1)).detach()
+ else:
+ chosen_rewards = self.beta * logprob_actor_chosen.sum(-1).detach()
+ if logprob_ref_reject is not None:
+ rejected_rewards = self.beta * (logprob_actor_reject.sum(-1) - logprob_ref_reject.sum(-1)).detach()
+ else:
+ rejected_rewards = self.beta * logprob_actor_reject.sum(-1).detach()
+
+ return losses, chosen_rewards, rejected_rewards
+
+
+class LogSigLoss(nn.Module):
+    """
+    Pairwise Loss for Reward Model
+    Details: https://arxiv.org/abs/2203.02155
+    """
+
+    def __init__(self, beta: float = 1.0) -> None:
+        super().__init__()
+        # beta scales the reward margin; RewardModelTrainer instantiates this loss as LogSigLoss(beta=beta)
+        self.beta = beta
+
+    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
+        return -torch.nn.functional.logsigmoid(self.beta * (chosen_reward - reject_reward)).mean()
+
+
+class LogExpLoss(nn.Module):
+ """
+ Pairwise Loss for Reward Model
+ Details: https://arxiv.org/abs/2204.05862
+ """
+
+ def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
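+        # log(1 + exp(reject - chosen)) == -logsigmoid(chosen - reject), so this loss equals LogSigLoss with beta = 1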
+ loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
+ return loss
diff --git a/applications/ColossalChat/coati/models/reward_model.py b/applications/ColossalChat/coati/models/reward_model.py
new file mode 100755
index 000000000000..18c5eca41a71
--- /dev/null
+++ b/applications/ColossalChat/coati/models/reward_model.py
@@ -0,0 +1,38 @@
+"""
+reward model
+"""
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from coati.models import BaseModel
+from transformers import PretrainedConfig
+
+
+class RewardModel(BaseModel):
+ """
+ Reward model class.
+
+ Args:
+        pretrained (str): huggingface or local model path
+        config (PretrainedConfig): the model config
+ **kwargs: all other kwargs as in AutoModel.from_pretrained
+ """
+
+ def __init__(self, pretrained: str = None, config: Optional[PretrainedConfig] = None, **kwargs) -> None:
+ super().__init__(pretrained=pretrained, config=config, **kwargs)
+ self.value_head = nn.Linear(self.last_hidden_state_size, 1)
+ self.value_head.weight.data.normal_(mean=0.0, std=1 / (self.last_hidden_state_size + 1))
+
+ def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ outputs = self.model(input_ids, attention_mask=attention_mask)
+
+ last_hidden_states = outputs["last_hidden_state"]
+ sequence_lengths = torch.max(attention_mask * torch.arange(input_ids.size(1), device=input_ids.device), dim=1)[
+ 0
+ ]
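+        # index of the last attended (non-padding) token in each sequence; its hidden
+        # state is fed to the value head to score the whole sequence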
+ sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths].type(
+ self.value_head.weight.dtype
+ )
+ values = self.value_head(sequence_hidden_states).squeeze(-1) # Ensure shape is (B,)
+ return values
diff --git a/applications/ColossalChat/coati/models/utils.py b/applications/ColossalChat/coati/models/utils.py
new file mode 100755
index 000000000000..ce672534c28e
--- /dev/null
+++ b/applications/ColossalChat/coati/models/utils.py
@@ -0,0 +1,137 @@
+import json
+import os
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+
+
+def get_model_numel(model: torch.nn.Module) -> int:
+ return sum(p.numel() for p in model.parameters())
+
+
+def compute_reward(
+ r: Union[torch.Tensor, float],
+ kl_coef: float,
+ log_probs: torch.Tensor,
+ log_probs_base: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None,
+ reward_eps=5,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Args:
+        r: sequence-level reward, float or [batch_size]
+        kl_coef: coefficient of the KL penalty
+        log_probs: [batch_size, response_length]
+        log_probs_base: [batch_size, response_length]
+        action_mask: [batch_size, response_length]
+    Returns:
+        reward: [batch_size, response_length]
+        kl: per-token KL penalty estimate, [batch_size, response_length]
+    """
+ log_ratio = log_probs - log_probs_base # address numerical instability issue
+ kl = -kl_coef * log_ratio * action_mask
+ reward = kl
+ r_clip = torch.clamp(r, -reward_eps, reward_eps)
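+    # add the clipped sequence-level reward to every valid response token and zero
+    # the padded tail (assumes the valid tokens occupy a contiguous prefix)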
+ for i in range(action_mask.size(0)):
+ assert action_mask[i].sum() > 0
+ reward[i, : action_mask[i].sum()] += r_clip[i]
+ reward[i, action_mask[i].sum() :] *= 0
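+    # the second return value is a non-negative per-token KL penalty estimate,
+    # exp(r) - 1 - r with r = log_ratio; log-ratios >= 10 are zeroed inside exp() to avoid overflow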
+ return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
+
+
+def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the log probabilities from logits for the given labels.
+
+ Args:
+ logits (torch.Tensor): The input logits.
+ labels (torch.Tensor): The target labels.
+
+ Returns:
+ torch.Tensor: The log probabilities corresponding to the labels.
+ """
+ log_probs = F.log_softmax(logits, dim=-1)
+ log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
+ return log_probs_labels.squeeze(-1)
+
+
+def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
+ """Calculate action log probs.
+
+ Args:
+        logits (torch.Tensor): Output logits of Actor.forward.
+ sequences (torch.LongTensor): Input sequences.
+ num_actions (int): Number of actions.
+
+ Returns:
+ torch.Tensor: Action log probs.
+ """
+ log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
+ return log_probs[:, -num_actions:]
+
+
+def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
+ """
+ Compute the masked mean of a tensor along a specified dimension.
+
+ Args:
+ tensor (torch.Tensor): The input tensor.
+ mask (torch.Tensor): The mask tensor with the same shape as the input tensor.
+ dim (int, optional): The dimension along which to compute the mean. Default is 1.
+
+ Returns:
+ torch.Tensor: The masked mean tensor.
+
+ """
+ tensor = tensor * mask
+ tensor = tensor.sum(dim=dim)
+ mask_sum = mask.sum(dim=dim)
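+    # the small epsilon keeps the division finite for rows whose mask sums to zero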
+ mean = tensor / (mask_sum + 1e-8)
+ return mean
+
+
+def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor) -> torch.Tensor:
+ """
+ Calculate the masked log probabilities for a given sequence of logits.
+
+ Args:
+ logits (torch.Tensor): The input logits tensor of shape (batch_size, sequence_length, vocab_size).
+ sequences (torch.LongTensor): The input sequence tensor of shape (batch_size, sequence_length).
+ mask (torch.Tensor): The mask tensor of shape (batch_size, sequence_length).
+
+ Returns:
+ torch.Tensor: The masked log probabilities tensor of shape (batch_size, sequence_length - 1).
+ """
+    # logits at position t predict token t + 1, so align them with the labels by shifting one position
+ log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
+ return log_probs * mask
+
+
+def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
+ """
+ Load file in JSON format
+ """
+ with open(file=file_path, mode="r", encoding="utf-8") as fp:
+ return json.load(fp)
+
+
+def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
+ """
+ Save as JSON format
+ """
+ with open(file=file_path, mode="w", encoding="utf-8") as fp:
+ json.dump(data, fp=fp, ensure_ascii=False, indent=4)
+
+
+def disable_dropout(model: torch.nn.Module):
+ """
+ Disables dropout in a PyTorch model. This is used in PPO Training
+
+ Args:
+ model (torch.nn.Module): The PyTorch model.
+
+ Returns:
+ None
+ """
+ for module in model.modules():
+ if isinstance(module, torch.nn.Dropout):
+ module.p = 0.0
diff --git a/applications/Chat/coati/quant/__init__.py b/applications/ColossalChat/coati/quant/__init__.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/quant/__init__.py
rename to applications/ColossalChat/coati/quant/__init__.py
diff --git a/applications/Chat/coati/quant/llama_gptq/__init__.py b/applications/ColossalChat/coati/quant/llama_gptq/__init__.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/quant/llama_gptq/__init__.py
rename to applications/ColossalChat/coati/quant/llama_gptq/__init__.py
diff --git a/applications/Chat/coati/quant/llama_gptq/loader.py b/applications/ColossalChat/coati/quant/llama_gptq/loader.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/quant/llama_gptq/loader.py
rename to applications/ColossalChat/coati/quant/llama_gptq/loader.py
diff --git a/applications/Chat/coati/quant/llama_gptq/model_utils.py b/applications/ColossalChat/coati/quant/llama_gptq/model_utils.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/quant/llama_gptq/model_utils.py
rename to applications/ColossalChat/coati/quant/llama_gptq/model_utils.py
diff --git a/applications/Chat/coati/quant/llama_gptq/quant.py b/applications/ColossalChat/coati/quant/llama_gptq/quant.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/quant/llama_gptq/quant.py
rename to applications/ColossalChat/coati/quant/llama_gptq/quant.py
diff --git a/applications/Chat/coati/quant/utils.py b/applications/ColossalChat/coati/quant/utils.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/quant/utils.py
rename to applications/ColossalChat/coati/quant/utils.py
diff --git a/applications/Chat/coati/ray/README.md b/applications/ColossalChat/coati/ray/README.md
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/ray/README.md
rename to applications/ColossalChat/coati/ray/README.md
diff --git a/applications/Chat/coati/ray/__init__.py b/applications/ColossalChat/coati/ray/__init__.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/ray/__init__.py
rename to applications/ColossalChat/coati/ray/__init__.py
diff --git a/applications/Chat/coati/ray/callbacks/__init__.py b/applications/ColossalChat/coati/ray/callbacks/__init__.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/ray/callbacks/__init__.py
rename to applications/ColossalChat/coati/ray/callbacks/__init__.py
diff --git a/applications/Chat/coati/ray/callbacks/base.py b/applications/ColossalChat/coati/ray/callbacks/base.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/ray/callbacks/base.py
rename to applications/ColossalChat/coati/ray/callbacks/base.py
diff --git a/applications/Chat/coati/ray/callbacks/performance_evaluator.py b/applications/ColossalChat/coati/ray/callbacks/performance_evaluator.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/ray/callbacks/performance_evaluator.py
rename to applications/ColossalChat/coati/ray/callbacks/performance_evaluator.py
diff --git a/applications/Chat/coati/ray/detached_replay_buffer.py b/applications/ColossalChat/coati/ray/detached_replay_buffer.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/ray/detached_replay_buffer.py
rename to applications/ColossalChat/coati/ray/detached_replay_buffer.py
diff --git a/applications/Chat/coati/ray/detached_trainer_base.py b/applications/ColossalChat/coati/ray/detached_trainer_base.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/ray/detached_trainer_base.py
rename to applications/ColossalChat/coati/ray/detached_trainer_base.py
diff --git a/applications/Chat/coati/ray/detached_trainer_ppo.py b/applications/ColossalChat/coati/ray/detached_trainer_ppo.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/ray/detached_trainer_ppo.py
rename to applications/ColossalChat/coati/ray/detached_trainer_ppo.py
diff --git a/applications/Chat/coati/ray/experience_maker_holder.py b/applications/ColossalChat/coati/ray/experience_maker_holder.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/ray/experience_maker_holder.py
rename to applications/ColossalChat/coati/ray/experience_maker_holder.py
diff --git a/applications/Chat/coati/ray/lora_constructor.py b/applications/ColossalChat/coati/ray/lora_constructor.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/ray/lora_constructor.py
rename to applications/ColossalChat/coati/ray/lora_constructor.py
diff --git a/applications/Chat/coati/ray/utils.py b/applications/ColossalChat/coati/ray/utils.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/coati/ray/utils.py
rename to applications/ColossalChat/coati/ray/utils.py
diff --git a/applications/ColossalChat/coati/trainer/__init__.py b/applications/ColossalChat/coati/trainer/__init__.py
new file mode 100755
index 000000000000..2eff8ca7676a
--- /dev/null
+++ b/applications/ColossalChat/coati/trainer/__init__.py
@@ -0,0 +1,7 @@
+from .base import OLTrainer, SLTrainer
+from .dpo import DPOTrainer
+from .ppo import PPOTrainer
+from .rm import RewardModelTrainer
+from .sft import SFTTrainer
+
+__all__ = ["SLTrainer", "OLTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer", "DPOTrainer"]
diff --git a/applications/Chat/coati/trainer/base.py b/applications/ColossalChat/coati/trainer/base.py
old mode 100644
new mode 100755
similarity index 80%
rename from applications/Chat/coati/trainer/base.py
rename to applications/ColossalChat/coati/trainer/base.py
index 0a41d450d41e..63c903a51940
--- a/applications/Chat/coati/trainer/base.py
+++ b/applications/ColossalChat/coati/trainer/base.py
@@ -1,6 +1,14 @@
+"""
+Base trainers for online and offline training
+ SLTrainer: supervised learning trainer
+ pretrain, sft, dpo, reward model training
+ OLTrainer: online learning trainer
+ rlhf-ppo
+"""
+
from abc import ABC, abstractmethod
from contextlib import contextmanager
-from typing import List
+from typing import Callable, List
import torch.nn as nn
import tqdm
@@ -8,8 +16,8 @@
from coati.experience_maker import Experience
from torch.optim import Optimizer
-from .callbacks import Callback
-from .strategies import Strategy
+from colossalai.booster import Booster
+
from .utils import is_rank_0
@@ -26,16 +34,18 @@ class SLTrainer(ABC):
def __init__(
self,
- strategy: Strategy,
+ booster: Booster,
max_epochs: int,
model: nn.Module,
optimizer: Optimizer,
+ start_epoch: int = 0,
) -> None:
super().__init__()
- self.strategy = strategy
+ self.booster = booster
self.max_epochs = max_epochs
self.model = model
self.optimizer = optimizer
+ self.start_epoch = start_epoch
@abstractmethod
def _train(self, epoch):
@@ -45,19 +55,20 @@ def _train(self, epoch):
def _eval(self, epoch):
raise NotImplementedError()
+ @abstractmethod
def _before_fit(self):
raise NotImplementedError()
def fit(self, *args, **kwargs):
self._before_fit(*args, **kwargs)
- for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0()):
+ for epoch in tqdm.trange(self.start_epoch, self.max_epochs, desc="Epochs", disable=not is_rank_0()):
self._train(epoch)
self._eval(epoch)
-class OnPolicyTrainer(ABC):
+class OLTrainer(ABC):
"""
- Base class for on-policy rl trainers, e.g. PPO.
+ Base class for online learning trainers, e.g. PPO.
Args:
strategy (Strategy):the strategy to use for training
@@ -69,14 +80,16 @@ class OnPolicyTrainer(ABC):
def __init__(
self,
- strategy: Strategy,
+ actor_booster: Booster,
+ critic_booster: Booster,
data_buffer: NaiveExperienceBuffer,
sample_buffer: bool,
dataloader_pin_memory: bool,
- callbacks: List[Callback] = [],
+ callbacks: List[Callable] = [],
) -> None:
super().__init__()
- self.strategy = strategy
+ self.actor_booster = actor_booster
+ self.critic_booster = critic_booster
self.data_buffer = data_buffer
self.sample_buffer = sample_buffer
self.dataloader_pin_memory = dataloader_pin_memory
@@ -141,6 +154,20 @@ def _learn(self, update_step: int):
"""
raise NotImplementedError()
+ @abstractmethod
+ def _setup_update_phrase_dataload(self):
+ """
+ Implement this method to setup dataloader for update phase.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def _save_checkpoint(self, episode: int = 0):
+ """
+ Implement this method to save checkpoint.
+ """
+ raise NotImplementedError()
+
def _collect_phase(self, collect_step: int):
self._on_make_experience_start()
experience = self._make_experience(collect_step)
@@ -178,11 +205,10 @@ def fit(
for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()):
self._collect_phase(collect_step)
if not self.sample_buffer:
- # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
- # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
- # I only call strategy.setup_dataloader() to setup dataloader.
- self.dataloader = self.strategy.setup_dataloader(self.data_buffer, self.dataloader_pin_memory)
+ self._setup_update_phrase_dataload()
for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
self._update_phase(update_step)
# NOTE: this is for on-policy algorithms
self.data_buffer.clear()
+ if self.save_interval > 0 and (episode + 1) % (self.save_interval) == 0:
+ self._save_checkpoint(episode + 1)
diff --git a/applications/ColossalChat/coati/trainer/callbacks/__init__.py b/applications/ColossalChat/coati/trainer/callbacks/__init__.py
new file mode 100644
index 000000000000..a765485072c1
--- /dev/null
+++ b/applications/ColossalChat/coati/trainer/callbacks/__init__.py
@@ -0,0 +1,4 @@
+from .base import Callback
+from .performance_evaluator import PerformanceEvaluator
+
+__all__ = ["Callback", "PerformanceEvaluator"]
diff --git a/applications/Chat/coati/trainer/callbacks/base.py b/applications/ColossalChat/coati/trainer/callbacks/base.py
similarity index 100%
rename from applications/Chat/coati/trainer/callbacks/base.py
rename to applications/ColossalChat/coati/trainer/callbacks/base.py
diff --git a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py b/applications/ColossalChat/coati/trainer/callbacks/performance_evaluator.py
similarity index 92%
rename from applications/Chat/coati/trainer/callbacks/performance_evaluator.py
rename to applications/ColossalChat/coati/trainer/callbacks/performance_evaluator.py
index b286c766c263..86384e5e39fb 100644
--- a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
+++ b/applications/ColossalChat/coati/trainer/callbacks/performance_evaluator.py
@@ -14,9 +14,11 @@ def get_world_size() -> int:
return 1
-def print_rank_0(*args, **kwargs) -> None:
+def save_eval_result_rank_0(s: str, save_path: str, **kwargs) -> None:
if not dist.is_initialized() or dist.get_rank() == 0:
- print(*args, **kwargs)
+ with open(save_path, "a+") as f:
+ train_config = "; ".join([str(kwargs[key]) for key in kwargs])
+ f.write(train_config + "\n" + s + "\n")
def divide(x: float, y: float) -> float:
@@ -74,6 +76,8 @@ def __init__(
reward_model_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_episodes: int = 0,
+ train_config: Optional[dict] = None,
+ save_path: Optional[str] = None,
) -> None:
super().__init__()
self.world_size = get_world_size()
@@ -92,6 +96,8 @@ def __init__(
self.make_experience_flop: int = 0
self.learn_num_samples: int = 0
self.learn_flop: int = 0
+ self.train_config = train_config
+ self.save_path = save_path
def on_episode_start(self, episode: int) -> None:
self.disable = self.ignore_episodes > 0 and episode < self.ignore_episodes
@@ -172,12 +178,14 @@ def on_fit_end(self) -> None:
make_experience_time_per_sample = divide(avg_make_experience_duration, num_effective_samples)
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)
- print_rank_0(
+ save_eval_result_rank_0(
f"Performance summary:\n"
+ f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n"
+ f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n"
+ f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n"
+ f"Overall time per sample: {overall_time_per_sample:.2f} s\n"
+ f"Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n"
- + f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%"
+ + f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%",
+ self.save_path,
+ **self.train_config,
)
diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py
new file mode 100755
index 000000000000..cbe7d7ca811a
--- /dev/null
+++ b/applications/ColossalChat/coati/trainer/dpo.py
@@ -0,0 +1,336 @@
+"""
+DPO trainer
+"""
+
+from typing import Any, Optional
+
+import torch
+from coati.models.loss import DpoLoss
+from coati.models.utils import calc_masked_log_probs
+from coati.trainer.utils import all_reduce_mean
+from coati.utils import AccumulativeMeanMeter, save_checkpoint
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader
+from tqdm import trange
+from transformers import PreTrainedTokenizerBase
+
+from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+from colossalai.utils import get_current_device
+
+from .base import SLTrainer
+from .utils import is_rank_0, to_device
+
+
+class DPOTrainer(SLTrainer):
+ """
+    Trainer for DPO algorithm.
+
+    Args:
+        actor (Actor): the actor model to be aligned
+        ref_model (Actor): the frozen reference model
+        booster (Booster): the booster to use for training
+        actor_optim (Optimizer): the optimizer to use for actor model
+        actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model
+        tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
+        max_epochs (int, defaults to 1): the max number of epochs to train
+        beta (float, defaults to 0.1): the beta parameter in dpo loss
+        accumulation_steps (int): the number of steps to accumulate gradients
+        start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
+        save_interval (int): the interval to save model checkpoints, defaults to 0, which means no checkpoint will be saved during training
+        save_dir (str): the directory to save checkpoints
+        coordinator (DistCoordinator): the coordinator to use for distributed logging
+ """
+
+ def __init__(
+ self,
+ actor: Any,
+ ref_model: Any,
+ booster: Booster,
+ actor_optim: Optimizer,
+ actor_lr_scheduler: _LRScheduler,
+ tokenizer: PreTrainedTokenizerBase,
+ max_epochs: int = 1,
+ beta: float = 0.1,
+ accumulation_steps: int = 1,
+ start_epoch: int = 0,
+ save_interval: int = 0,
+ save_dir: str = None,
+ coordinator: DistCoordinator = None,
+ ) -> None:
+ super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
+ self.ref_model = ref_model
+ self.actor_scheduler = actor_lr_scheduler
+ self.tokenizer = tokenizer
+ self.actor_loss_fn = DpoLoss(beta)
+ self.save_interval = save_interval
+ self.coordinator = coordinator
+ self.save_dir = save_dir
+ self.num_train_step = 0
+ self.accumulation_steps = accumulation_steps
+ self.device = get_current_device()
+ self.accumulative_meter = AccumulativeMeanMeter()
+
+ def _before_fit(
+ self,
+ train_preference_dataloader: DataLoader = None,
+ eval_preference_dataloader: DataLoader = None,
+ log_dir: Optional[str] = None,
+ use_wandb: bool = False,
+ ):
+ """
+ Args:
+            train_preference_dataloader (DataLoader): the dataloader for preference training data
+            eval_preference_dataloader (DataLoader): the dataloader for preference evaluation data
+ """
+ self.train_dataloader = train_preference_dataloader
+ self.eval_dataloader = eval_preference_dataloader
+ self.writer = None
+ if use_wandb and is_rank_0():
+ assert log_dir is not None, "log_dir must be provided when use_wandb is True"
+ import wandb
+
+ self.wandb_run = wandb.init(project="Coati-dpo", sync_tensorboard=True)
+ if log_dir is not None and is_rank_0():
+ import os
+ import time
+
+ from torch.utils.tensorboard import SummaryWriter
+
+ log_dir = os.path.join(log_dir, "dpo")
+ log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+ self.writer = SummaryWriter(log_dir=log_dir)
+
+ def _train(self, epoch: int):
+ """
+ Args:
+ epoch int: the number of current epoch
+ """
+ self.model.train()
+ self.accumulative_meter.reset()
+ step_bar = trange(
+ len(self.train_dataloader) // self.accumulation_steps,
+ desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+ disable=not is_rank_0(),
+ )
+ for i, batch in enumerate(self.train_dataloader):
+ batch = to_device(batch, self.device)
+ (
+ chosen_input_ids,
+ chosen_attention_mask,
+ chosen_loss_mask,
+ reject_input_ids,
+ reject_attention_mask,
+ reject_loss_mask,
+ ) = (
+ batch["chosen_input_ids"],
+ batch["chosen_attention_mask"],
+ batch["chosen_loss_mask"],
+ batch["reject_input_ids"],
+ batch["reject_attention_mask"],
+ batch["reject_loss_mask"],
+ )
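+            # exclude the final position of the rejected sequence from the loss
+            # (e.g. a trailing eos/pad token)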
+ reject_loss_mask[:, -1] = False
+ batch_size = chosen_input_ids.size()[0]
+
+ actor_all_logits = self.model(
+ input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
+ attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
+ )["logits"].to(torch.float32)
+ actor_chosen_logits = actor_all_logits[:batch_size]
+ actor_reject_logits = actor_all_logits[batch_size:]
+ logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
+
+ logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
+
+ if self.ref_model is not None:
+ self.ref_model.eval()
+ with torch.no_grad():
+ ref_all_logits = self.ref_model(
+ input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
+ attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
+ )["logits"].to(torch.float32)
+ ref_chosen_logits = ref_all_logits[:batch_size]
+ ref_reject_logits = ref_all_logits[batch_size:]
+ logprob_ref_chosen = calc_masked_log_probs(
+ ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
+ )
+ logprob_ref_reject = calc_masked_log_probs(
+ ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
+ )
+ else:
+ logprob_ref_chosen = None
+ logprob_ref_reject = None
+
+ losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
+ logprob_actor_chosen,
+ logprob_actor_reject,
+                logprob_ref_chosen,
+                logprob_ref_reject,
+ chosen_loss_mask[:, 1:],
+ reject_loss_mask[:, 1:],
+ )
+ reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
+
+ # DPO Loss
+ loss = losses.mean()
+
+ self.booster.backward(loss=loss, optimizer=self.optimizer)
+ if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ self.actor_scheduler.step()
+
+ # sync
+ loss_mean = all_reduce_mean(tensor=loss)
+ chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
+ rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
+ reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
+ self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+ self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
+
+ if i % self.accumulation_steps == self.accumulation_steps - 1:
+ self.num_train_step += 1
+ step_bar.update()
+ # logging
+ if self.writer and is_rank_0():
+ self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
+ self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
+ self.writer.add_scalar(
+ "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
+ )
+ self.writer.add_scalar(
+ "train/rejected_rewards",
+ self.accumulative_meter.get("rejected_rewards"),
+ self.num_train_step,
+ )
+ self.writer.add_scalar(
+ "train/margin",
+ self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
+ self.num_train_step,
+ )
+ self.writer.add_scalar(
+ "train/accuracy",
+ self.accumulative_meter.get("accuracy"),
+ self.num_train_step,
+ )
+ self.accumulative_meter.reset()
+
+            if self.save_interval > 0 and (self.num_train_step + 1) % self.save_interval == 0:
+ # save checkpoint
+ self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
+ save_checkpoint(
+ save_dir=self.save_dir,
+ booster=self.booster,
+ model=self.model,
+ optimizer=self.optimizer,
+ lr_scheduler=self.actor_scheduler,
+ epoch=epoch,
+ step=i + 1,
+ batch_size=batch_size,
+ coordinator=self.coordinator,
+ )
+ self.coordinator.print_on_master(
+ f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
+ )
+
+ step_bar.close()
+
+ def _eval(self, epoch: int):
+ """
+ Args:
+ epoch int: the number of current epoch
+ """
+ if self.eval_dataloader is None:
+ self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
+ return
+ self.model.eval()
+ self.ref_model.eval()
+ self.coordinator.print_on_master("\nStart evaluation...")
+
+ step_bar = trange(
+ len(self.eval_dataloader),
+ desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+ disable=not is_rank_0(),
+ )
+
+ self.accumulative_meter.reset()
+
+ with torch.no_grad():
+ for i, batch in enumerate(self.eval_dataloader):
+ batch = to_device(batch, self.device)
+ (
+ chosen_input_ids,
+ chosen_attention_mask,
+ chosen_loss_mask,
+ reject_input_ids,
+ reject_attention_mask,
+ reject_loss_mask,
+ ) = (
+ batch["chosen_input_ids"],
+ batch["chosen_attention_mask"],
+ batch["chosen_loss_mask"],
+ batch["reject_input_ids"],
+ batch["reject_attention_mask"],
+ batch["reject_loss_mask"],
+ )
+
+ batch_size = chosen_input_ids.size()[0]
+
+ actor_all_logits = self.model(
+ torch.cat([chosen_input_ids, reject_input_ids]),
+ torch.cat([chosen_attention_mask, reject_attention_mask]),
+ )["logits"].to(torch.float32)
+ actor_chosen_logits = actor_all_logits[:batch_size]
+ actor_reject_logits = actor_all_logits[batch_size:]
+
+ logprob_actor_chosen = calc_masked_log_probs(
+ actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
+ )
+
+ logprob_actor_reject = calc_masked_log_probs(
+ actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
+ )
+
+ self.ref_model.eval()
+
+ ref_all_logits = self.ref_model(
+ torch.cat([chosen_input_ids, reject_input_ids]),
+ torch.cat([chosen_attention_mask, reject_attention_mask]),
+ )["logits"].to(torch.float32)
+ ref_chosen_logits = ref_all_logits[:batch_size]
+ ref_reject_logits = ref_all_logits[batch_size:]
+ logprob_ref_chosen = calc_masked_log_probs(ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
+ logprob_ref_reject = calc_masked_log_probs(ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
+
+ losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
+ logprob_actor_chosen,
+ logprob_actor_reject,
+                    logprob_ref_chosen,
+                    logprob_ref_reject,
+ chosen_loss_mask[:, 1:],
+ reject_loss_mask[:, 1:],
+ )
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
+ loss = losses.mean()
+ loss_mean = all_reduce_mean(tensor=loss)
+ chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
+ rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
+ reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
+ self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+ self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
+ self.accumulative_meter.add(
+ "margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
+ )
+ step_bar.update()
+
+ msg = "Evaluation Result:\n"
+ for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
+ msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+ self.coordinator.print_on_master(msg)
+ step_bar.close()
diff --git a/applications/ColossalChat/coati/trainer/ppo.py b/applications/ColossalChat/coati/trainer/ppo.py
new file mode 100755
index 000000000000..287767669516
--- /dev/null
+++ b/applications/ColossalChat/coati/trainer/ppo.py
@@ -0,0 +1,403 @@
+"""
+PPO trainer
+"""
+
+import os
+from typing import Dict, List, Optional
+
+import torch
+import wandb
+from coati.experience_buffer import NaiveExperienceBuffer
+from coati.experience_maker import Experience, NaiveExperienceMaker
+from coati.models import Critic, RewardModel
+from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
+from coati.models.utils import calc_action_log_probs
+from coati.trainer.callbacks import Callback
+from coati.trainer.utils import all_reduce_mean
+from coati.utils import AccumulativeMeanMeter, save_checkpoint
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader, DistributedSampler
+from tqdm import tqdm
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.utils import get_current_device
+
+from .base import OLTrainer
+from .utils import CycledDataLoader, is_rank_0, to_device
+
+
+def _set_default_generate_kwargs(actor: PreTrainedModel) -> Dict:
+ """
+ Set default keyword arguments for generation based on the actor model.
+
+ Args:
+ actor (PreTrainedModel): The actor model.
+
+ Returns:
+ Dict: A dictionary containing the default keyword arguments for generation.
+ """
+ unwrapped_model = actor.unwrap()
+ new_kwargs = {}
+ # use huggingface models method directly
+ if hasattr(unwrapped_model, "prepare_inputs_for_generation"):
+ new_kwargs["prepare_inputs_fn"] = unwrapped_model.prepare_inputs_for_generation
+
+ if hasattr(unwrapped_model, "_update_model_kwargs_for_generation"):
+ new_kwargs["update_model_kwargs_fn"] = unwrapped_model._update_model_kwargs_for_generation
+ return new_kwargs
+
+
+class PPOTrainer(OLTrainer):
+ """
+ Trainer for PPO algorithm.
+
+ Args:
+        actor_booster (Booster): the booster to use for the actor model
+        critic_booster (Booster): the booster to use for the critic model
+ actor (Actor): the actor model in ppo algorithm
+ critic (Critic): the critic model in ppo algorithm
+ reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences
+        initial_model (Actor): the initial model in rlhf algorithm to generate reference logits to limit the update of the actor
+ actor_optim (Optimizer): the optimizer to use for actor model
+ critic_optim (Optimizer): the optimizer to use for critic model
+ kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
+ train_batch_size (int, defaults to 8): the batch size to use for training
+ buffer_limit (int, defaults to 0): the max_size limitation of buffer
+ buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu
+ eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
+ vf_coef (float, defaults to 1.0): the coefficient of value loss
+ ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
+        value_clip (float, defaults to 0.2): the clip coefficient of value loss
+ sample_buffer (bool, defaults to False): whether to sample from buffer
+ dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
+ offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
+ callbacks (List[Callback], defaults to []): the callbacks to call during training process
+ generate_kwargs (dict, optional): the kwargs to use while model generating
+ """
+
+ def __init__(
+ self,
+ actor_booster: Booster,
+ critic_booster: Booster,
+ actor: PreTrainedModel,
+ critic: Critic,
+ reward_model: RewardModel,
+ initial_model: PreTrainedModel,
+ actor_optim: Optimizer,
+ critic_optim: Optimizer,
+ actor_lr_scheduler: _LRScheduler,
+ critic_lr_scheduler: _LRScheduler,
+ tokenizer: PreTrainedTokenizerBase,
+ kl_coef: float = 0.1,
+ ptx_coef: float = 0.9,
+ train_batch_size: int = 8,
+ buffer_limit: int = 0,
+ buffer_cpu_offload: bool = True,
+ eps_clip: float = 0.2,
+ vf_coef: float = 1.0,
+ value_clip: float = 0.2,
+ sample_buffer: bool = False,
+ dataloader_pin_memory: bool = True,
+ offload_inference_models: bool = True,
+ accumulation_steps: int = 1,
+ save_interval: int = 0,
+ save_dir: str = None,
+ use_tp: bool = False,
+ coordinator: DistCoordinator = None,
+ callbacks: List[Callback] = [],
+ **generate_kwargs,
+ ) -> None:
+ if isinstance(actor_booster, GeminiPlugin):
+ assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
+
+ data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
+ super().__init__(
+ actor_booster, critic_booster, data_buffer, sample_buffer, dataloader_pin_memory, callbacks=callbacks
+ )
+ self.generate_kwargs = _set_default_generate_kwargs(actor)
+ self.generate_kwargs.update(generate_kwargs)
+
+ self.actor = actor
+ self.critic = critic
+ self.actor_booster = actor_booster
+ self.critic_booster = critic_booster
+ self.actor_scheduler = actor_lr_scheduler
+ self.critic_scheduler = critic_lr_scheduler
+ self.tokenizer = tokenizer
+ self.experience_maker = NaiveExperienceMaker(
+ self.actor, self.critic, reward_model, initial_model, self.tokenizer, kl_coef
+ )
+ self.train_batch_size = train_batch_size
+
+ self.actor_loss_fn = PolicyLoss(eps_clip)
+ self.critic_loss_fn = ValueLoss(value_clip)
+ self.vf_coef = vf_coef
+ self.ptx_loss_fn = GPTLMLoss()
+ self.ptx_coef = ptx_coef
+ self.actor_optim = actor_optim
+ self.critic_optim = critic_optim
+ self.save_interval = save_interval
+ self.coordinator = coordinator
+ self.actor_save_dir = os.path.join(save_dir, "actor")
+ self.critic_save_dir = os.path.join(save_dir, "critic")
+ self.num_train_step = 0
+ self.accumulation_steps = accumulation_steps
+ self.use_tp = use_tp
+ self.accumulative_meter = AccumulativeMeanMeter()
+ self.offload_inference_models = offload_inference_models
+ self.device = get_current_device()
+
+ def _before_fit(
+ self,
+ prompt_dataloader: DataLoader,
+ pretrain_dataloader: Optional[DataLoader] = None,
+ log_dir: Optional[str] = None,
+ use_wandb: bool = False,
+ ):
+ """
+ Args:
+ prompt_dataloader (DataLoader): the dataloader to use for prompt data
+ pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
+ """
+ self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
+ self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) if pretrain_dataloader is not None else None
+
+ self.writer = None
+ if use_wandb and is_rank_0():
+ assert log_dir is not None, "log_dir must be provided when use_wandb is True"
+ import wandb
+
+ self.wandb_run = wandb.init(project="Coati-ppo", sync_tensorboard=True)
+ if log_dir is not None and is_rank_0():
+ import os
+ import time
+
+ from torch.utils.tensorboard import SummaryWriter
+
+ log_dir = os.path.join(log_dir, "ppo")
+ log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+ self.writer = SummaryWriter(log_dir=log_dir)
+
+ def _setup_update_phrase_dataload(self):
+ """
+        Why not use a distributed dataloader?
+        If tp is used, the input on each rank is identical, so the same dataloader feeds the same experience to all ranks.
+        If tp is not used, the input on each rank differs, and we expect different experiences to be fed to each rank.
+ """
+ self.dataloader = DataLoader(
+ self.data_buffer,
+ batch_size=self.train_batch_size,
+ shuffle=True,
+ drop_last=True,
+ pin_memory=self.dataloader_pin_memory,
+ collate_fn=self.data_buffer.collate_fn,
+ )
+
+ def _make_experience(self, collect_step: int) -> Experience:
+ """
+ Make experience
+ """
+ prompts = self.prompt_dataloader.next()
+ if self.offload_inference_models:
+ # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
+ self.experience_maker.initial_model.to(self.device)
+ self.experience_maker.reward_model.to(self.device)
+ return self.experience_maker.make_experience(
+ input_ids=prompts["input_ids"].to(get_current_device()),
+ attention_mask=prompts["attention_mask"].to(get_current_device()),
+ **self.generate_kwargs,
+ )
+
+ def _training_step(self, experience: Experience):
+ """
+ Args:
+ experience:
+                sequences: [batch_size, prompt_length + response_length], the left-padded prompt
+                    followed by the generated response tokens
+ """
+ self.num_train_step += 1
+ self.actor.train()
+ self.critic.train()
+ num_actions = experience.action_log_probs.size(1)
+ # policy loss
+
+ actor_logits = self.actor(input_ids=experience.sequences, attention_mask=experience.attention_mask)[
+ "logits"
+ ] # [batch size, prompt_length + response_length]
+ action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
+
+ actor_loss, to_skip, max_ratio = self.actor_loss_fn(
+ action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+ )
+ actor_loss = (1 - self.ptx_coef) * actor_loss
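+        # the actor objective mixes policy and pretraining (ptx) losses:
+        # (1 - ptx_coef) * policy_loss + ptx_coef * ptx_loss, each backpropagated separately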
+ if not to_skip:
+ self.actor_booster.backward(loss=actor_loss, optimizer=self.actor_optim)
+
+ # ptx loss
+ if self.ptx_coef != 0:
+ batch = self.pretrain_dataloader.next()
+ batch = to_device(batch, self.device)
+ outputs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+ ptx_loss = outputs.loss
+ ptx_loss = self.ptx_coef * ptx_loss
+ self.actor_booster.backward(loss=ptx_loss, optimizer=self.actor_optim)
+
+ # value loss
+ values = self.critic(
+ input_ids=experience.sequences, attention_mask=experience.attention_mask
+ ) # [batch size, prompt_length + response_length]
+ critic_loss = self.critic_loss_fn(
+ values[:, -num_actions:], experience.values, experience.advantages, action_mask=experience.action_mask
+ )
+ critic_loss = critic_loss * self.vf_coef
+ self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim)
+
+ # sync
+ actor_loss_mean = all_reduce_mean(tensor=actor_loss)
+ critic_loss_mean = all_reduce_mean(tensor=critic_loss)
+ max_ratio_mean = all_reduce_mean(tensor=max_ratio)
+ reward_mean = all_reduce_mean(tensor=experience.reward.mean())
+ value_mean = all_reduce_mean(tensor=experience.values.mean())
+ advantages_mean = all_reduce_mean(tensor=experience.advantages.mean())
+ kl_mean = all_reduce_mean(tensor=experience.kl.mean())
+ if self.ptx_coef != 0:
+ ptx_loss_mean = all_reduce_mean(tensor=ptx_loss)
+
+ self.accumulative_meter.add("actor_loss", actor_loss_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("critic_loss", critic_loss_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("max_ratio", max_ratio_mean.to(torch.float16).item())
+ self.accumulative_meter.add("reward", reward_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("value", value_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("advantages", advantages_mean.to(torch.float16).item())
+ self.accumulative_meter.add("skip_ratio", 1.0 if to_skip else 0.0)
+ self.accumulative_meter.add("kl", kl_mean.to(torch.float16).item())
+ if self.ptx_coef != 0:
+ self.accumulative_meter.add("ptx_loss", ptx_loss_mean.to(torch.float16).mean().item())
+
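+        # gradient accumulation: step optimizers and schedulers only every accumulation_steps training steps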
+ if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
+ self.actor_optim.step()
+ self.critic_optim.step()
+ self.actor_optim.zero_grad()
+ self.critic_optim.zero_grad()
+ self.actor_scheduler.step()
+ self.critic_scheduler.step()
+
+            # prepare model responses and corresponding rewards for logging.
+ if self.num_train_step % 10 == 1:
+ response_text = self.experience_maker.tokenizer.batch_decode(
+ experience.sequences, skip_special_tokens=True
+ )
+ for i in range(len(response_text)):
+ response_text[i] = response_text[i] + f"\n\nReward: {experience.reward[i]}"
+
+ if self.writer and is_rank_0() and "wandb_run" in self.__dict__:
+ # log output to wandb
+ my_table = wandb.Table(
+ columns=[f"sample response {i}" for i in range(len(response_text))], data=[response_text]
+ )
+ try:
+ self.wandb_run.log({"sample_response": my_table})
+ except OSError as e:
+ self.coordinator.print_on_master(e)
+ elif self.writer and is_rank_0():
+ for line in response_text:
+ self.coordinator.print_on_master(line)
+
+ if self.writer and is_rank_0():
+ self.writer.add_scalar("train/max_ratio", self.accumulative_meter.get("max_ratio"), self.num_train_step)
+ self.writer.add_scalar(
+ "train/skip_ratio", self.accumulative_meter.get("skip_ratio"), self.num_train_step
+ )
+ self.writer.add_scalar(
+ "train/actor_loss", self.accumulative_meter.get("actor_loss"), self.num_train_step
+ )
+ self.writer.add_scalar("train/lr_actor", self.actor_optim.param_groups[0]["lr"], self.num_train_step)
+ self.writer.add_scalar("train/lr_critic", self.critic_optim.param_groups[0]["lr"], self.num_train_step)
+ self.writer.add_scalar(
+ "train/critic_loss", self.accumulative_meter.get("critic_loss"), self.num_train_step
+ )
+ if self.ptx_coef != 0:
+ self.writer.add_scalar(
+ "train/ptx_loss", self.accumulative_meter.get("ptx_loss"), self.num_train_step
+ )
+ self.writer.add_scalar("reward", self.accumulative_meter.get("reward"), self.num_train_step)
+ self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), self.num_train_step)
+ self.writer.add_scalar("value", self.accumulative_meter.get("value"), self.num_train_step)
+ self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), self.num_train_step)
+ self.accumulative_meter.reset()
+
+ def _learn(self, update_step: int):
+ """
+ Perform the learning step of the PPO algorithm.
+
+ Args:
+ update_step (int): The current update step.
+
+ Returns:
+ None
+ """
+ if self.offload_inference_models:
+ self.experience_maker.initial_model.to("cpu")
+ self.experience_maker.reward_model.to("cpu")
+
+ # buffer may be empty at first, we should rebuild at each training
+ if self.sample_buffer:
+ experience = self.data_buffer.sample()
+ self._on_learn_batch_start()
+ experience.to_device(self.device)
+ self._training_step(experience)
+ self._on_learn_batch_end(experience)
+ else:
+ if isinstance(self.dataloader.sampler, DistributedSampler):
+ self.dataloader.sampler.set_epoch(update_step)
+ pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
+ for experience in pbar:
+ self._on_learn_batch_start()
+ experience.to_device(self.device)
+ self._training_step(experience)
+ self._on_learn_batch_end(experience)
+
+ def _save_checkpoint(self, episode: int = 0):
+ """
+ Save the actor and critic checkpoints with running states.
+
+ Args:
+ episode (int): The current episode number.
+
+ Returns:
+ None
+ """
+
+ self.coordinator.print_on_master("\nStart saving actor checkpoint with running states")
+ save_checkpoint(
+ save_dir=self.actor_save_dir,
+ booster=self.actor_booster,
+ model=self.actor,
+ optimizer=self.actor_optim,
+ lr_scheduler=self.actor_scheduler,
+ epoch=0,
+ step=episode + 1,
+ batch_size=self.train_batch_size,
+ coordinator=self.coordinator,
+ )
+ self.coordinator.print_on_master(
+ f"Saved actor checkpoint at episode {(episode + 1)} at folder {self.actor_save_dir}"
+ )
+
+ self.coordinator.print_on_master("\nStart saving critic checkpoint with running states")
+ save_checkpoint(
+ save_dir=self.critic_save_dir,
+ booster=self.critic_booster,
+ model=self.critic,
+ optimizer=self.critic_optim,
+ lr_scheduler=self.critic_scheduler,
+ epoch=0,
+ step=episode + 1,
+ batch_size=self.train_batch_size,
+ coordinator=self.coordinator,
+ )
+ self.coordinator.print_on_master(
+ f"Saved critic checkpoint at episode {(episode + 1)} at folder {self.critic_save_dir}"
+ )
diff --git a/applications/ColossalChat/coati/trainer/rm.py b/applications/ColossalChat/coati/trainer/rm.py
new file mode 100755
index 000000000000..0fb714a62bce
--- /dev/null
+++ b/applications/ColossalChat/coati/trainer/rm.py
@@ -0,0 +1,242 @@
+"""
+Reward model trainer
+"""
+
+import os
+from typing import Any, Callable, Optional
+
+import torch
+import tqdm
+from coati.models import LogSigLoss
+from coati.trainer.utils import all_reduce_mean
+from coati.utils import AccumulativeMeanMeter, save_checkpoint
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader
+from transformers import PreTrainedTokenizerBase
+
+from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+from colossalai.utils import get_current_device
+
+from .base import SLTrainer
+from .utils import is_rank_0, to_device
+
+
+class RewardModelTrainer(SLTrainer):
+ """
+    Trainer for the reward model.
+
+    Args:
+        model (BaseModel): the reward model to train
+        booster (Booster): the booster to use for training
+        optimizer (Optimizer): the optimizer to use for the model
+        lr_scheduler (_LRScheduler): the lr scheduler to use for the model
+        tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
+        loss_fn (Callable, optional): the loss function; defaults to LogSigLoss
+        max_epochs (int, defaults to 1): the max number of epochs to train
+        beta (float, defaults to 0.1): the beta parameter of the default LogSigLoss
+        accumulation_steps (int): the number of steps to accumulate gradients
+        start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
+        save_interval (int): the interval to save model checkpoints, defaults to 0, which means no checkpoint will be saved during training
+        save_dir (str): the directory to save checkpoints
+        coordinator (DistCoordinator): the coordinator to use for distributed logging
+ """
+
+ def __init__(
+ self,
+ model: Any,
+ booster: Booster,
+ optimizer: Optimizer,
+ lr_scheduler: _LRScheduler,
+ tokenizer: PreTrainedTokenizerBase,
+ loss_fn: Optional[Callable] = None,
+ max_epochs: int = 1,
+ beta: float = 0.1,
+ accumulation_steps: int = 1,
+ start_epoch: int = 0,
+ save_interval: int = 0,
+ save_dir: str = None,
+ coordinator: DistCoordinator = None,
+ ) -> None:
+ super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, start_epoch=start_epoch)
+ self.actor_scheduler = lr_scheduler
+ self.tokenizer = tokenizer
+ self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta)
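+        # Defaults to a pairwise ranking loss; LogSigLoss is assumed here to follow the
+        # Bradley-Terry formulation, i.e. -log(sigmoid(r_chosen - r_rejected)) scaled by beta.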
+ self.save_interval = save_interval
+ self.coordinator = coordinator
+ self.save_dir = save_dir
+ self.num_train_step = 0
+ self.accumulation_steps = accumulation_steps
+ self.device = get_current_device()
+ self.accumulative_meter = AccumulativeMeanMeter()
+
+ def _before_fit(
+ self,
+ train_preference_dataloader: DataLoader = None,
+ eval_preference_dataloader: DataLoader = None,
+ log_dir: Optional[str] = None,
+ use_wandb: bool = False,
+ ):
+ """
+ Args:
+ prompt_dataloader (DataLoader): the dataloader to use for prompt data
+ pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
+ """
+ self.train_dataloader = train_preference_dataloader
+ self.eval_dataloader = eval_preference_dataloader
+ self.writer = None
+ if use_wandb and is_rank_0():
+ assert log_dir is not None, "log_dir must be provided when use_wandb is True"
+ import wandb
+
+ self.wandb_run = wandb.init(project="Coati-rm", sync_tensorboard=True)
+ if log_dir is not None and is_rank_0():
+ import os
+ import time
+
+ from torch.utils.tensorboard import SummaryWriter
+
+ log_dir = os.path.join(log_dir, "rm")
+ log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+ self.writer = SummaryWriter(log_dir=log_dir)
+
+ def _train(self, epoch):
+ self.model.train()
+ step_bar = tqdm.trange(
+ len(self.train_dataloader) // self.accumulation_steps,
+ desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+ disable=not is_rank_0(),
+ )
+ for i, batch in enumerate(self.train_dataloader):
+ batch = to_device(batch, self.device)
+
+ (
+ chosen_input_ids,
+ chosen_attention_mask,
+ reject_input_ids,
+ reject_attention_mask,
+ ) = (
+ batch["chosen_input_ids"],
+ batch["chosen_attention_mask"],
+ batch["reject_input_ids"],
+ batch["reject_attention_mask"],
+ )
+ batch_size = chosen_input_ids.size()[0]
+
+            # Concatenate for better parallelism
+ reward = self.model(
+ torch.cat([chosen_input_ids, reject_input_ids], dim=0),
+ attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask], dim=0),
+ )
+ chosen_reward = reward[:batch_size]
+ reject_reward = reward[batch_size:]
+ loss = self.loss_fn(chosen_reward, reject_reward).mean()
+
+ self.booster.backward(loss=loss, optimizer=self.optimizer)
+
+ accuracy = (chosen_reward > reject_reward).float()
+
+ # Sync
+ loss_mean = all_reduce_mean(tensor=loss)
+ chosen_rewards_mean = all_reduce_mean(tensor=chosen_reward)
+ rejected_rewards_mean = all_reduce_mean(tensor=reject_reward)
+ accuracy_mean = all_reduce_mean(tensor=accuracy)
+ self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+ self.accumulative_meter.add("accuracy", accuracy_mean.mean().to(torch.float16).item())
+
+ if (i + 1) % self.accumulation_steps == 0:
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ self.actor_scheduler.step()
+ step_bar.update()
+ self.num_train_step += 1
+
+ # Logging
+ if self.writer and is_rank_0():
+ self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
+ self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
+ self.writer.add_scalar(
+ "train/dist",
+ self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
+ self.num_train_step,
+ )
+ self.writer.add_scalar(
+ "train/reward_chosen", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
+ )
+ self.writer.add_scalar(
+ "train/reward_reject", self.accumulative_meter.get("rejected_rewards"), self.num_train_step
+ )
+ self.writer.add_scalar("train/acc", self.accumulative_meter.get("accuracy"), self.num_train_step)
+
+ self.accumulative_meter.reset()
+
+ # Save checkpoint
+ if self.save_interval > 0 and (self.num_train_step + 1) % self.save_interval == 0:
+ self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
+ save_checkpoint(
+ save_dir=self.save_dir,
+ booster=self.booster,
+ model=self.model,
+ optimizer=self.optimizer,
+ lr_scheduler=self.actor_scheduler,
+ epoch=epoch,
+ step=i + 1,
+ batch_size=batch_size,
+ coordinator=self.coordinator,
+ )
+ self.coordinator.print_on_master(
+ f"Saved checkpoint at epoch {epoch} step {(i + 1)/self.accumulation_steps} at folder {self.save_dir}"
+ )
+ step_bar.close()
+
+ def _eval(self, epoch):
+ if self.eval_dataloader is None:
+ self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
+ return
+ self.model.eval()
+ step_bar = tqdm.trange(
+ len(self.eval_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0()
+ )
+ with torch.no_grad():
+ for i, batch in enumerate(self.eval_dataloader):
+ batch = to_device(batch, self.device)
+ (
+ chosen_input_ids,
+ chosen_attention_mask,
+ reject_input_ids,
+ reject_attention_mask,
+ ) = (
+ batch["chosen_input_ids"],
+ batch["chosen_attention_mask"],
+ batch["reject_input_ids"],
+ batch["reject_attention_mask"],
+ )
+
+ chosen_reward = self.model(chosen_input_ids, attention_mask=chosen_attention_mask)
+ reject_reward = self.model(reject_input_ids, attention_mask=reject_attention_mask)
+ loss = self.loss_fn(chosen_reward, reject_reward).mean()
+
+ # Sync
+ loss_mean = all_reduce_mean(tensor=loss)
+ chosen_rewards_mean = all_reduce_mean(tensor=chosen_reward)
+ rejected_rewards_mean = all_reduce_mean(tensor=reject_reward)
+ self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+
+ step_bar.update()
+
+ msg = "Evaluation Result:\n"
+ for tag in ["loss", "chosen_rewards", "rejected_rewards"]:
+ msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+ msg = (
+ msg
+ + f"distance: {self.accumulative_meter.get('chosen_rewards')-self.accumulative_meter.get('rejected_rewards')}\n"
+ )
+ self.coordinator.print_on_master(msg)
+ with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
+ f.write(msg)
+ step_bar.close()
diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py
new file mode 100755
index 000000000000..c95f5b65a822
--- /dev/null
+++ b/applications/ColossalChat/coati/trainer/sft.py
@@ -0,0 +1,170 @@
+"""
+SFT trainer
+"""
+
+import os
+from typing import Optional
+
+import torch
+from coati.trainer.utils import all_reduce_mean
+from coati.utils import AccumulativeMeanMeter, save_checkpoint
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader
+from tqdm import trange
+
+from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+
+from .base import SLTrainer
+from .utils import is_rank_0, to_device
+
+
+class SFTTrainer(SLTrainer):
+    """
+    Trainer for supervised fine-tuning (SFT).
+
+    Args:
+        model (torch.nn.Module): the model to train
+        booster (Booster): the booster to use for training
+        optim (Optimizer): the optimizer to use for training
+        lr_scheduler (_LRScheduler): the lr scheduler to use for training
+        max_epochs (int, defaults to 2): the number of epochs to train
+        accumulation_steps (int, defaults to 8): the number of steps to accumulate gradients
+        start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
+        save_interval (int): the interval to save model checkpoints
+        save_dir (str): the directory to save checkpoints
+        coordinator (DistCoordinator): the coordinator to use for distributed logging
+    """
+
+ def __init__(
+ self,
+ model,
+ booster: Booster,
+ optim: Optimizer,
+ lr_scheduler: _LRScheduler,
+ max_epochs: int = 2,
+ accumulation_steps: int = 8,
+ start_epoch=0,
+ save_interval: int = None,
+ save_dir: str = None,
+ coordinator: Optional[DistCoordinator] = None,
+ ) -> None:
+ super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch)
+
+ self.accumulation_steps = accumulation_steps
+ self.scheduler = lr_scheduler
+ self.save_interval = save_interval
+ self.save_dir = save_dir
+ self.coordinator = coordinator
+ self.num_train_step = 0
+ self.num_eval_step = 0
+ self.accumulative_meter = AccumulativeMeanMeter()
+
+ def _before_fit(
+ self,
+ train_dataloader: DataLoader,
+ eval_dataloader: Optional[DataLoader] = None,
+ log_dir: Optional[str] = None,
+ use_wandb: bool = False,
+ ):
+ """
+ Args:
+ train_dataloader: the dataloader to use for training
+ eval_dataloader: the dataloader to use for evaluation
+ log_dir: the directory to save logs
+ use_wandb: whether to use wandb for logging
+ """
+ self.train_dataloader = train_dataloader
+ self.eval_dataloader = eval_dataloader
+
+ self.writer = None
+ if use_wandb and is_rank_0():
+ assert log_dir is not None, "log_dir must be provided when use_wandb is True"
+ import wandb
+
+ wandb.init(project="Coati-sft", sync_tensorboard=True)
+ if log_dir is not None and is_rank_0():
+ import os
+ import time
+
+ from torch.utils.tensorboard import SummaryWriter
+
+ log_dir = os.path.join(log_dir, "sft")
+ log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+ self.writer = SummaryWriter(log_dir=log_dir)
+
+ def _train(self, epoch: int):
+ self.model.train()
+ step_bar = trange(
+ len(self.train_dataloader) // self.accumulation_steps,
+ desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+ disable=not is_rank_0(),
+ )
+ for i, batch in enumerate(self.train_dataloader):
+ batch = to_device(batch, torch.cuda.current_device())
+ batch_size = batch["input_ids"].size(0)
+ outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+ loss = outputs.loss
+ self.booster.backward(loss=loss, optimizer=self.optimizer)
+
+ loss_mean = all_reduce_mean(tensor=loss)
+ self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+
+            # Gradient accumulation
+            if (i + 1) % self.accumulation_steps == 0:
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+                self.scheduler.step()
+
+                if self.writer:
+                    self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
+                    self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
+                # Increment on every optimizer step, not only when a writer exists on
+                # rank 0, so the save-interval check below stays in sync across ranks.
+                self.num_train_step += 1
+                self.accumulative_meter.reset()
+                step_bar.update()
+
+ # Save checkpoint
+ if (
+ self.save_dir is not None
+ and self.save_interval is not None
+ and (self.num_train_step + 1) % self.save_interval == 0
+ ):
+ save_checkpoint(
+ save_dir=self.save_dir,
+ booster=self.booster,
+ model=self.model,
+ optimizer=self.optimizer,
+ lr_scheduler=self.scheduler,
+ epoch=epoch,
+ step=self.num_train_step + 1,
+ batch_size=batch_size,
+ coordinator=self.coordinator,
+ )
+ self.coordinator.print_on_master(
+ f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
+ )
+ step_bar.close()
+
+ def _eval(self, epoch: int):
+ if self.eval_dataloader is None:
+ self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
+ return
+ self.accumulative_meter.reset()
+ self.model.eval()
+ with torch.no_grad():
+ step_bar = trange(
+ len(self.eval_dataloader),
+ desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+ disable=not is_rank_0(),
+ )
+ for batch in self.eval_dataloader:
+ batch = to_device(batch, torch.cuda.current_device())
+ outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+ loss_mean = all_reduce_mean(tensor=outputs.loss)
+ self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
+ step_bar.update()
+ loss_mean = self.accumulative_meter.get("loss")
+ msg = "Evaluation Result:\n"
+ for tag in ["loss"]:
+ msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+ self.coordinator.print_on_master(msg)
+ with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
+ f.write(msg)
+ step_bar.close()
diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py
new file mode 100755
index 000000000000..5ce1e9ef009c
--- /dev/null
+++ b/applications/ColossalChat/coati/trainer/utils.py
@@ -0,0 +1,113 @@
+"""
+Training utilities for Coati.
+"""
+from typing import Any
+
+import torch
+import torch.distributed as dist
+from torch.utils._pytree import tree_map
+from torch.utils.data import DataLoader
+
+
+class CycledDataLoader:
+ """
+ A data loader that cycles through the data when it reaches the end.
+
+ Args:
+ dataloader (DataLoader): The original data loader.
+
+ Attributes:
+ dataloader (DataLoader): The original data loader.
+ count (int): The number of times the data loader has been cycled.
+ dataloader_iter (iterable): The iterator for the data loader.
+
+ Methods:
+ next(): Returns the next batch of data from the data loader, cycling through the data if necessary.
+ """
+
+ def __init__(
+ self,
+ dataloader: DataLoader,
+ ) -> None:
+ self.dataloader = dataloader
+
+ self.count = 0
+ self.dataloader_iter = None
+
+ def next(self):
+ """
+ Returns the next batch of data from the data loader, cycling through the data if necessary.
+
+ Returns:
+ Any: The next batch of data from the data loader.
+ """
+ # defer initialization
+ if self.dataloader_iter is None:
+ self.dataloader_iter = iter(self.dataloader)
+
+ self.count += 1
+ try:
+ return next(self.dataloader_iter)
+ except StopIteration:
+ self.count = 0
+ self.dataloader_iter = iter(self.dataloader)
+ return next(self.dataloader_iter)
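+
+    # Usage sketch: wrap a finite prompt DataLoader so training can keep drawing
+    # batches indefinitely, e.g.
+    #   cycled = CycledDataLoader(prompt_dataloader)
+    #   batch = cycled.next()  # restarts from the beginning when exhausted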
+
+
+def is_rank_0() -> bool:
+ """
+ Check if the current process is the rank 0 process in a distributed training setup.
+
+ Returns:
+ bool: True if the current process is the rank 0 process, False otherwise.
+ """
+ return not dist.is_initialized() or dist.get_rank() == 0
+
+
+def to_device(x: Any, device: torch.device) -> Any:
+ """
+ Move the input tensor or nested structure of tensors to the specified device.
+
+ Args:
+ x (Any): The input tensor or nested structure of tensors.
+ device (torch.device): The target device to move the tensors to.
+
+ Returns:
+ Any: The tensor or nested structure of tensors moved to the target device.
+ """
+
+ def _to(t: Any):
+ if isinstance(t, torch.Tensor):
+ return t.to(device)
+ return t
+
+ return tree_map(_to, x)
+
+
+def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Perform all-reduce operation on the given tensor and compute the mean across all processes.
+
+ Args:
+ tensor (torch.Tensor): The input tensor to be reduced.
+
+ Returns:
+ torch.Tensor: The reduced tensor with mean computed across all processes.
+ """
+ dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
+ tensor.div_(dist.get_world_size())
+ return tensor
+
+
+def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Performs an all-reduce operation to sum the values of the given tensor across all processes.
+
+ Args:
+ tensor (torch.Tensor): The input tensor to be reduced.
+
+ Returns:
+ torch.Tensor: The reduced tensor with the sum of values across all processes.
+ """
+ dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
+ return tensor
diff --git a/applications/ColossalChat/coati/utils/__init__.py b/applications/ColossalChat/coati/utils/__init__.py
new file mode 100755
index 000000000000..5d651b202cc8
--- /dev/null
+++ b/applications/ColossalChat/coati/utils/__init__.py
@@ -0,0 +1,4 @@
+from .accumulative_meter import AccumulativeMeanMeter
+from .ckpt_io import load_checkpoint, save_checkpoint
+
+__all__ = ["load_checkpoint", "save_checkpoint", "AccumulativeMeanMeter"]
diff --git a/applications/ColossalChat/coati/utils/accumulative_meter.py b/applications/ColossalChat/coati/utils/accumulative_meter.py
new file mode 100755
index 000000000000..9ed662d2d25a
--- /dev/null
+++ b/applications/ColossalChat/coati/utils/accumulative_meter.py
@@ -0,0 +1,69 @@
+"""
+A class that can be used to calculate the mean of a variable
+"""
+
+
+class AccumulativeMeanVariable:
+ """
+ A class that calculates the accumulative mean of a variable.
+ """
+
+ def __init__(self):
+ self._sum = 0
+ self._count = 0
+
+ def add(self, value, count_update=1):
+ """
+ Adds a value to the sum and updates the count.
+
+ Args:
+ value (float): The value to be added.
+ count_update (int, optional): The amount to update the count by. Defaults to 1.
+ """
+ self._sum += value
+ self._count += count_update
+
+ def get(self):
+ """
+ Calculates and returns the accumulative mean.
+
+ Returns:
+ float: The accumulative mean.
+ """
+ return self._sum / self._count if self._count > 0 else 0
+
+ def reset(self):
+ """
+ Resets the sum and count to zero.
+ """
+ self._sum = 0
+ self._count = 0
+
+
+class AccumulativeMeanMeter:
+ """
+ A class for calculating and storing the accumulative mean of variables.
+
+ Attributes:
+ variable_dict (dict): A dictionary to store the accumulative mean variables.
+
+ Methods:
+ add(name, value, count_update=1): Adds a value to the specified variable.
+ get(name): Retrieves the accumulative mean value of the specified variable.
+ reset(): Resets all the accumulative mean variables to their initial state.
+ """
+
+ def __init__(self):
+ self.variable_dict = {}
+
+ def add(self, name, value, count_update=1):
+ if name not in self.variable_dict:
+ self.variable_dict[name] = AccumulativeMeanVariable()
+ self.variable_dict[name].add(value, count_update=count_update)
+
+ def get(self, name):
+ return self.variable_dict[name].get()
+
+ def reset(self):
+ for name in self.variable_dict:
+ self.variable_dict[name].reset()
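+
+
+# Usage sketch:
+#   meter = AccumulativeMeanMeter()
+#   meter.add("loss", 0.5)
+#   meter.add("loss", 0.7)
+#   meter.get("loss")  # -> 0.6
+#   meter.reset()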
diff --git a/applications/ColossalChat/coati/utils/ckpt_io.py b/applications/ColossalChat/coati/utils/ckpt_io.py
new file mode 100755
index 000000000000..5b804f0acc14
--- /dev/null
+++ b/applications/ColossalChat/coati/utils/ckpt_io.py
@@ -0,0 +1,93 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+Helper functions for saving and loading checkpoints
+"""
+
+import json
+import os
+from typing import Any, Dict, Tuple, Union
+
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.optimizer import Optimizer
+
+from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+
+
+def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
+ """
+ Load file in JSON format
+ """
+ with open(file=file_path, mode="r", encoding="utf-8") as fp:
+ return json.load(fp)
+
+
+def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
+ """
+ Save as JSON format
+ """
+ with open(file=file_path, mode="w", encoding="utf-8") as fp:
+ json.dump(data, fp=fp, ensure_ascii=False, indent=4)
+
+
+def save_checkpoint(
+ save_dir: Union[str, os.PathLike],
+ booster: Booster,
+ model: torch.nn.Module,
+ optimizer: Optimizer,
+ lr_scheduler: _LRScheduler,
+ epoch: int,
+ step: int,
+ batch_size: int,
+ coordinator: DistCoordinator,
+) -> None:
+ """
+    Save model checkpoint, optimizer, LR scheduler and intermediate running states.
+ """
+
+ save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
+ os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
+
+ booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
+
+ """
+ Temporary disable the following as save_optimizer causes all processes to hang in a multi-gpu environment,
+ working on fixing this bug
+ """
+
+ booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
+ booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
+ running_states = {
+ "epoch": epoch,
+ "step": step,
+ "sample_start_index": step * batch_size,
+ }
+ if coordinator.is_master():
+ save_json(running_states, os.path.join(save_dir, "running_states.json"))
+
+
+def load_checkpoint(
+ load_dir: Union[str, os.PathLike],
+ booster: Booster,
+ model: torch.nn.Module,
+ optimizer: Optimizer,
+ lr_scheduler: _LRScheduler,
+) -> Tuple[int, int, int]:
+ """
+    Load model checkpoint, optimizer, LR scheduler and intermediate running states.
+ """
+
+ # Update booster params states.
+ booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
+ booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
+ booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
+
+ running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
+ return (
+ running_states["epoch"],
+ running_states["step"],
+ running_states["sample_start_index"],
+ )
diff --git a/applications/ColossalChat/config/conversation_template/Qwen.json b/applications/ColossalChat/config/conversation_template/Qwen.json
new file mode 100644
index 000000000000..09f706ffed90
--- /dev/null
+++ b/applications/ColossalChat/config/conversation_template/Qwen.json
@@ -0,0 +1,7 @@
+{
+ "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ "stop_ids": [
+ null
+ ]
+}
diff --git a/applications/ColossalChat/config/conversation_template/Vicuna.json b/applications/ColossalChat/config/conversation_template/Vicuna.json
new file mode 100644
index 000000000000..2b00b6529720
--- /dev/null
+++ b/applications/ColossalChat/config/conversation_template/Vicuna.json
@@ -0,0 +1,7 @@
+{
+ "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don\\'t know the answer to a question, please don\\'t share false information.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ "stop_ids": [
+ 2
+ ]
+}
diff --git a/applications/ColossalChat/config/conversation_template/Yi.json b/applications/ColossalChat/config/conversation_template/Yi.json
new file mode 100644
index 000000000000..9716413b53ad
--- /dev/null
+++ b/applications/ColossalChat/config/conversation_template/Yi.json
@@ -0,0 +1,7 @@
+{
+ "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ "stop_ids": [
+ 2
+ ]
+}
diff --git a/applications/ColossalChat/config/conversation_template/chatGLM2.json b/applications/ColossalChat/config/conversation_template/chatGLM2.json
new file mode 100644
index 000000000000..a2638dbe7439
--- /dev/null
+++ b/applications/ColossalChat/config/conversation_template/chatGLM2.json
@@ -0,0 +1,7 @@
+{
+ "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ "stop_ids": [
+ 2
+ ]
+}
diff --git a/applications/ColossalChat/config/conversation_template/colossal-llama2.json b/applications/ColossalChat/config/conversation_template/colossal-llama2.json
new file mode 100644
index 000000000000..cc7f1e5d76fc
--- /dev/null
+++ b/applications/ColossalChat/config/conversation_template/colossal-llama2.json
@@ -0,0 +1,7 @@
+{
+ "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}{{'Human: ' + bos_token + message['content'].strip() + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + bos_token + message['content'].strip() + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant: ' + bos_token }}{% endif %}",
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ "stop_ids": [
+ 2
+ ]
+}
diff --git a/applications/ColossalChat/config/conversation_template/llama2.json b/applications/ColossalChat/config/conversation_template/llama2.json
new file mode 100644
index 000000000000..80558f976e3b
--- /dev/null
+++ b/applications/ColossalChat/config/conversation_template/llama2.json
@@ -0,0 +1,7 @@
+{
+ "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ "stop_ids": [
+ 2
+ ]
+}
diff --git a/applications/ColossalChat/config/conversation_template/mistral.json b/applications/ColossalChat/config/conversation_template/mistral.json
new file mode 100644
index 000000000000..b48c3a3f27af
--- /dev/null
+++ b/applications/ColossalChat/config/conversation_template/mistral.json
@@ -0,0 +1,7 @@
+{
+ "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
+ "system_message": null,
+ "stop_ids": [
+ 2
+ ]
+}
diff --git a/applications/ColossalChat/config/conversation_template/zephyr.json b/applications/ColossalChat/config/conversation_template/zephyr.json
new file mode 100644
index 000000000000..2ab14111108b
--- /dev/null
+++ b/applications/ColossalChat/config/conversation_template/zephyr.json
@@ -0,0 +1,7 @@
+{
+ "chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ "stop_ids": [
+ 2
+ ]
+}
diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md
new file mode 100755
index 000000000000..cfed3f1f3a75
--- /dev/null
+++ b/applications/ColossalChat/examples/README.md
@@ -0,0 +1,565 @@
+# Examples
+
+## Table of Contents
+
+- [Examples](#examples)
+ - [Table of Contents](#table-of-contents)
+ - [Install Requirements](#install-requirements)
+  - [Get Started with ColossalRun](#get-started-with-colossalrun)
+ - [Training Configuration](#training-configuration)
+ - [RLHF Stage 1: Supervised Instruction Tuning](#rlhf-training-stage1---supervised-instructs-tuning)
+ - [Step 1: Data Collection](#step-1-data-collection)
+ - [Step 2: Preprocessing](#step-2-preprocessing)
+ - [Step 3: Training](#step-3-training)
+ - [RLHF Stage 2: Training Reward Model](#rlhf-training-stage2---training-reward-model)
+ - [Step 1: Data Collection](#step-1-data-collection-1)
+ - [Step 2: Preprocessing](#step-2-preprocessing-1)
+ - [Step 3: Training](#step-3-training-1)
+ - [Features and Tricks in RM Training](#features-and-tricks-in-rm-training)
+ - [RLHF Stage 3: Proximal Policy Optimization](#rlhf-training-stage3---proximal-policy-optimization)
+ - [Step 1: Data Collection](#step-1-data-collection-2)
+ - [Step 2: Preprocessing](#step-2-preprocessing-2)
+    - [Step 3: Training](#step-3-training-2)
+ - [PPO Training Results](#sample-training-results-using-default-script)
+ - [Reward](#reward)
+ - [KL Divergence](#approximate-kl-divergence)
+ - [Note on PPO Training](#note-on-ppo-training)
+ - [Alternative Option For RLHF: Direct Preference Optimization](#alternative-option-for-rlhf-direct-preference-optimization)
+ - [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning)
+ - [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
+ - [Hardware Requirements](#hardware-requirements)
+ - [Inference example](#inference-example)
+ - [Attention](#attention)
+
+---
+
+## Install Requirements
+
+```shell
+pip install -r requirements.txt
+```
+
+
+## Get Started with ColossalRun
+
+You can use `colossalai run` to launch multi-node training:
+```
+colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
+train.py --OTHER_CONFIGURATIONS
+```
+Here is a sample hostfile:
+
+```
+hostname1
+hostname2
+hostname3
+hostname4
+```
+
+Make sure the master node can access all nodes (including itself) via passwordless SSH. Here are some other useful arguments; a minimal multi-node launch sketch follows the list.
+
+- nnodes: number of nodes used in the training
+- nproc-per-node: number of processes to launch on each node
+- rdzv-endpoint: address of the host node
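+
+As a minimal sketch (assuming the torchrun-style flags listed above; adjust host names and ports to your cluster):
+
+```
+colossalai run --nnodes 2 --nproc_per_node 8 --rdzv-endpoint hostname1:29500 \
+train.py --OTHER_CONFIGURATIONS
+```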
+
+### Training Configuration
+
+This section gives a brief introduction to the different training strategies you can use and how to use them with our boosters and plugins to reduce training time and VRAM consumption. For more details regarding training strategies, please refer to [here](https://colossalai.org/docs/concepts/paradigms_of_parallelism). For details regarding boosters and plugins, please refer to [here](https://colossalai.org/docs/basics/booster_plugins).
+
+
+Gemini
+
+This plugin implements Zero-3 with chunk-based and heterogeneous memory management. It can train large models without much loss in speed. Note that it does not support local gradient accumulation. More details can be found in the [Gemini Doc](https://colossalai.org/docs/features/zero_with_chunk).
+
+The example below shows how to use the gemini plugin in SFT training.
+```
+colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
+ --save_interval 5000 \
+ --save_path $SAVE_DIR \
+ --config_file $CONFIG_FILE \
+ --plugin gemini \
+ --batch_size 4 \
+ --max_epochs 1 \
+ --accumulation_steps 1 \ # the gradient accumulation has to be disabled
+ --lr 2e-5 \
+ --max_len 2048 \
+ --use_wandb
+```
+
+
+
+Gemini-Auto
+
+This option uses gemini and automatically offloads low-priority tensors to the CPU. It does not support local gradient accumulation either. More details can be found in the [Gemini Doc](https://colossalai.org/docs/features/zero_with_chunk).
+
+The example below shows how to use the gemini-auto plugin in SFT training.
+```
+colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
+ --save_interval 5000 \
+ --save_path $SAVE_DIR \
+ --config_file $CONFIG_FILE \
+ --plugin gemini_auto \
+ --batch_size 4 \
+ --max_epochs 1 \
+ --accumulation_steps 1 \ # the gradient accumulation has to be disabled
+ --lr 2e-5 \
+ --max_len 2048 \
+ --use_wandb
+```
+
+
+
+
+
+Zero2
+
+This option distributes the optimizer parameters and gradients across multiple GPUs without offloading weights to the CPU. It uses reduce and gather operations to synchronize gradients and weights. It does not benefit from local gradient accumulation: you can still accumulate gradients if you insist, but doing so will not reduce communication cost. For that reason, combining Zero-2 with pipeline parallelism is not a good idea.
+
+The example below shows how to use the zero2 plugin in SFT training.
+```
+colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
+ --save_interval 5000 \
+ --save_path $SAVE_DIR \
+ --config_file $CONFIG_FILE \
+ --plugin zero2 \
+ --batch_size 4 \
+ --max_epochs 1 \
+ --accumulation_steps 4 \
+ --lr 2e-5 \
+ --max_len 2048 \
+ --use_wandb
+```
+
+
+
+
+Zero2CPU
+
+This option distributes the optimizer parameters and gradients across multiple GPUs and also offloads parameters to the CPU. It does not support local gradient accumulation; you can still accumulate gradients if you insist, but doing so will not reduce communication cost.
+
+The example below shows how to use the zero2-cpu plugin in SFT training.
+```
+colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
+ --save_interval 5000 \
+ --save_path $SAVE_DIR \
+ --config_file $CONFIG_FILE \
+ --plugin zero2_cpu \
+ --batch_size 4 \
+ --max_epochs 1 \
+ --accumulation_steps 4 \
+ --lr 2e-5 \
+ --max_len 2048 \
+ --use_wandb
+```
+
+
+
+Tensor Parallelism
+
+This option supports Tensor Parallelism (TP). Note that if you use TP, zero and pipeline parallelism will be disabled. TP splits large model weights, optimizer parameters and gradients into multiple smaller shards and distributes them across multiple GPUs, hence it is recommended when your model is large (e.g. 20B and above) or when your training algorithm consumes a lot of memory (e.g. PPO).
+
+The example below shows how to use TP in PPO training.
+```
+colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 30039 train_ppo.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --rm_pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --rm_checkpoint_path $REWARD_MODEL_PATH \
+ --prompt_dataset ${prompt_dataset[@]} \
+ --pretrain_dataset ${ptx_dataset[@]} \
+ --ptx_batch_size 1 \
+ --ptx_coef 0.0 \
+ --plugin "zero2" \
+ --save_interval 200 \
+ --save_path $SAVE_DIR \
+ --num_episodes 2000 \
+ --num_collect_steps 4 \
+ --num_update_steps 1 \
+ --experience_batch_size 8 \
+ --train_batch_size 4 \
+ --accumulation_steps 8 \
+ --tp 4 \ # TP size, nproc_per_node must be divisible by it
+ --lr 9e-6 \
+ --mixed_precision "bf16" \
+ --grad_clip 1.0 \
+ --weight_decay 0.01 \
+ --warmup_steps 100 \
+ --grad_checkpoint \
+ --use_wandb
+```
+
+
+
+
+Gradient Checkpointing
+
+This option saves VRAM by selectively recomputing some of the intermediate values on the fly during the backward pass, rather than storing them in memory.
+
+To enable gradient checkpointing, add `--grad_checkpoint` to your training script.
+```
+colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
+ --save_interval 5000 \
+ --save_path $SAVE_DIR \
+ --config_file $CONFIG_FILE \
+ --plugin zero2_cpu \
+ --batch_size 4 \
+ --max_epochs 1 \
+ --accumulation_steps 4 \
+ --lr 2e-5 \
+ --max_len 2048 \
+ --grad_checkpoint \ # This enables gradient checkpointing
+ --use_wandb
+```
+
+
+
+Flash Attention
+
+Details about flash attention can be found in the paper: [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135).
+
+To enable flash attention, add `--use_flash_attn` to your training script.
+```
+colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
+ --save_interval 5000 \
+ --save_path $SAVE_DIR \
+ --config_file $CONFIG_FILE \
+ --plugin zero2_cpu \
+ --batch_size 4 \
+ --max_epochs 1 \
+ --accumulation_steps 4 \
+ --lr 2e-5 \
+ --max_len 2048 \
+ --use_flash_attn \ # This enables flash attention
+ --use_wandb
+```
+
+
+
+Low Rank Adaptation
+
+Details about Low Rank Adaptation (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). It dramatically reduces VRAM consumption at the cost of some model capability, which makes it suitable for training LLMs with constrained resources.
+
+To enable LoRA, set `--lora_rank` to a positive value (usually between 20 and 64).
+```
+colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
+ --save_interval 5000 \
+ --save_path $SAVE_DIR \
+ --config_file $CONFIG_FILE \
+ --plugin zero2_cpu \
+ --batch_size 4 \
+ --max_epochs 1 \
+ --accumulation_steps 4 \
+ --lr 2e-5 \
+ --max_len 2048 \
+ --lora_rank 32 \ # This enables LoRA
+ --use_wandb
+```
+
+
+
+Other Training Arguments
+
+- grad_clip: gradients larger than this value will be clipped.
+- weight_decay: weight decay hyper-parameter.
+- warmup_steps: number of warmup steps used in setting up the learning rate scheduler.
+- pretrain: path to the pretrained model; weights will be loaded from it unless checkpoint_path is provided.
+- tokenizer_dir: where to load the tokenizer from; if not provided, the tokenizer will be loaded from the pretrain model path.
+- dataset: a list of strings, each a path to a folder containing buffered dataset files in arrow format.
+- checkpoint_path: if provided, weights will be loaded from this checkpoint.
+- config_file: path to store the training config file.
+- save_dir: path to store the model checkpoints.
+- max_length: inputs will be padded/truncated to max_length before being fed to the model.
+- max_epochs: number of epochs to train.
+- batch_size: training batch size.
+- mixed_precision: precision to use in training; supports 'fp16' and 'bf16'. Note that some devices may not support 'bf16'; please refer to [Nvidia](https://developer.nvidia.com/) to check compatibility.
+- save_interval: save the model weights as well as optimizer/scheduler states every save_interval steps/episodes.
+- merge_lora_weights: whether to merge lora weights before saving the model.
+- lr: the learning rate used in training.
+- accumulation_steps: accumulate gradients every accumulation_steps.
+- log_dir: path to store the logs.
+- use_wandb: if this flag is set, logs will also be viewable on wandb.
+
+
+
+### RLHF Training Stage1 - Supervised Instructs Tuning
+
+Stage1 is supervised instruction fine-tuning (SFT). This step is a crucial part of the RLHF training process, as it trains the model on human-provided demonstrations so it learns the initial behavior for the task at hand. Here's a detailed guide on how to SFT your LLM with ColossalChat:
+
+#### Step 1: Data Collection
+The first step in Stage 1 is to collect a dataset of human demonstrations in the following format.
+
+```json
+[
+ {"messages":
+ [
+ {
+ "from": "human",
+ "content": "what are some pranks with a pen i can do?"
+ },
+ {
+ "from": "assistant",
+ "content": "Are you looking for practical joke ideas?"
+ },
+ ...
+ ]
+ },
+ ...
+]
+```
+
+#### Step 2: Preprocessing
+Once you have collected your SFT dataset, you will need to preprocess it. This involves four steps: data cleaning, data deduplication, formatting and tokenization. In this section, we will focus on formatting and tokenization.
+
+We provide a flexible way to set the conversation template used for formatting chat data, built on Hugging Face's chat template feature. Please follow the steps below to define your chat template and preprocess your data.
+
+- Step 1 (Optional): Define your conversation template. You need to provide a conversation template config file similar to the config files under the ./config/conversation_template directory. This config should include the following fields; a complete example is shown after this list.
+ ```json
+ {
+    "chat_template": (Optional) A string of chat_template used for formatting chat data. If not set (None), the default chat template of the provided tokenizer will be used. If a path to a huggingface model or local model is provided, the chat_template of that model will be used. To use a custom chat template, you need to set this field manually. For more details on how to write a chat template in Jinja format, please read https://huggingface.co/docs/transformers/main/chat_templating,
+    "system_message": A string of system message to be added at the beginning of the prompt. If none is provided (None), no system message will be added,
+    "stop_ids": (Optional) A list of token ids indicating the end of the assistant's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. If not set, it defaults to the tokenizer's eos token id automatically,
+ }
+ ```
+  On your first run of the data preparation script, you only need to define the "chat_template" (if you want to use a custom chat template) and the "system_message" (if you want to use a custom system message).
+
+- Step 2: Run the data preparation script, [prepare_sft_dataset.sh](./examples/data_preparation_scripts/prepare_sft_dataset.sh). Whether or not you skipped the first step, you need to provide the path to the conversation template config file (via the conversation_template_config arg). If you skipped the first step, an auto-generated conversation template will be stored at the designated file path.
+
+- Step 3 (Optional): Check the correctness of the processed data by manually inspecting the "$SAVE_DIR/jsonl/part-XXXX.jsonl" files.
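+
+For reference (see Step 1), here is a minimal config in this format. It mirrors the ChatML-style templates shipped under ./config/conversation_template; the stop id 2 is only an assumption and should match your tokenizer's eos token id.
+
+```json
+{
+    "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+    "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+    "stop_ids": [
+        2
+    ]
+}
+```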
+
+After finishing the above steps, you will have converted the raw conversations to the designated chat format, tokenized the formatted conversations, computed input_ids, labels and attention_masks, and buffered them into binary dataset files under the "$SAVE_DIR/arrow/part-XXXX" folders.
+
+For example, the Colossal-LLaMA-2 format looks like this:
+```
+ A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
+
+Human: what are some pranks with a pen i can do? Assistant: Are you looking for practical joke ideas?
+...
+```
+
+#### Step 3: Training
+Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./examples/training_scripts/train_sft.sh) to start supervised instruction fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
+
+### RLHF Training Stage2 - Training Reward Model
+
+Stage2 trains a reward model. Human annotators rank different outputs generated for the same prompt, and these rankings supervise the training of the reward model, which learns to assign a score to each response.
+
+#### Step 1: Data Collection
+The preference dataset format used to train the reward model is shown below.
+
+```json
+[
+ {"context": [
+ {
+ "from": "human",
+ "content": "Introduce butterflies species in Oregon."
+ }
+  ],
+ "chosen": [
+ {
+ "from": "assistant",
+ "content": "About 150 species of butterflies live in Oregon, with about 100 species are moths..."
+ },
+ ...
+ ],
+ "rejected": [
+ {
+ "from": "assistant",
+ "content": "Are you interested in just the common butterflies? There are a few common ones which will be easy to find..."
+ },
+ ...
+ ]
+ },
+ ...
+]
+```
+
+#### Step 2: Preprocessing
+Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./examples/data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
+
+#### Step 3: Training
+You can run [train_rm.sh](./examples/training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
+
+#### Features and Tricks in RM Training
+
+- We recommend using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) and [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets for training the reward model.
+- We support two kinds of loss functions: `log_sig` (used by OpenAI) and `log_exp` (used by Anthropic).
+- We log the training accuracy `train/acc`, `reward_chosen` and `reward_rejected` to monitor progress during training.
+- We use a cosine annealing learning rate scheduler for RM training.
+- The value head is a single linear layer whose weights are initialized from the N(0, 1/(d_model + 1)) distribution; see the sketch below.
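+
+A minimal sketch of that initialization (assuming PyTorch; `d_model` is a stand-in for your base model's hidden size, and N(0, 1/(d_model + 1)) is read here with the second argument as the standard deviation):
+
+```python
+import torch.nn as nn
+
+d_model = 4096  # e.g. the hidden size of a 7B LLaMA-style model
+value_head = nn.Linear(d_model, 1)  # maps the last hidden state to a scalar reward
+# initialize the value head weights from N(0, 1/(d_model + 1))
+nn.init.normal_(value_head.weight, mean=0.0, std=1 / (d_model + 1))
+```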
+
+#### Note on Reward Model Training
+
+Before you move on to the next stage, please check the following list to ensure that your reward model is stable and robust. You can check the reward chart and the accuracy chart on wandb.
+- The mean reward for chosen data is much higher than that for rejected data.
+- The accuracy is larger than 0.5 by a significant margin (usually it should be greater than 0.6).
+- Optional: check that the reward is positive for chosen data and negative for rejected data.
+
+Your training reward curves should look similar to the following charts.
+
+
+
+
+### RLHF Training Stage3 - Proximal Policy Optimization
+
+In stage3 we use a reinforcement learning algorithm, Proximal Policy Optimization (PPO), which is the most complex part of the training process:
+
+
+
+
+
+#### Step 1: Data Collection
+PPO uses two kinds of training data: the prompt data and the pretrain data (optional). The first dataset is mandatory; data samples within the prompt dataset end with a line from "human", and the "assistant" needs to generate a response to answer the "human". Note that you can still use conversations that end with a line from the "assistant"; in that case, the last line will be dropped. Here is an example of the prompt dataset format.
+
+```json
+[
+ {"messages":
+ [
+ {
+ "from": "human",
+ "content": "what are some pranks with a pen i can do?"
+ }
+ ...
+ ]
+ },
+]
+```
+
+The second dataset, the pretrain dataset, is optional; provide it if you want to use the ptx loss introduced in the [InstructGPT paper](https://arxiv.org/abs/2203.02155). It follows the format below.
+
+```json
+[
+  {
+    "source": "", # system instruction
+    "Target": "Provide a list of the top 10 most popular mobile games in Asia\nThe top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
+  },
+  ...
+]
+```
+#### Step 2: Preprocessing
+To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./examples/data_preparation_scripts/prepare_prompt_dataset.sh)
+
+You can use the SFT dataset you prepared in the SFT stage, or prepare a new one from a different source, as the ptx dataset. The ptx data is used to calculate the ptx loss, which stabilizes the training according to the [InstructGPT paper](https://arxiv.org/pdf/2203.02155.pdf).
+
+#### Step 3: Training
+You can run [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to start PPO training. Below are some arguments unique to PPO; please refer to the [training configuration](#training-configuration) section for details regarding the other supported training options.
+
+```bash
+--pretrain $PRETRAINED_MODEL_PATH \
+--rm_pretrain $PRETRAINED_MODEL_PATH \ # reward model architecture
+--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+--rm_checkpoint_path $REWARD_MODEL_PATH \ # reward model checkpoint path
+--prompt_dataset ${prompt_dataset[@]} \ # list of strings, the prompt dataset
+--conversation_template_config $CONVERSATION_TEMPLATE_CONFIG_PATH \ # path to the conversation template config file
+--pretrain_dataset ${ptx_dataset[@]} \ # list of strings, the sft dataset
+--ptx_batch_size 1 \ # batch size for calculating the ptx loss
+--ptx_coef 0.0 \ # non-zero to enable the ptx loss
+--num_episodes 2000 \ # number of episodes to train
+--num_collect_steps 1 \
+--num_update_steps 1 \
+--experience_batch_size 8 \
+--train_batch_size 4 \
+--accumulation_steps 2
+```
+
+Each episode has two phases, the collect phase and the update phase. During the collect phase, we collect experiences (answers generated by the actor) and store them in the ExperienceBuffer. During the update phase, the data in the ExperienceBuffer is used to update the parameters of the actor and the critic. The experience buffer size is determined as follows (a worked example follows the two formulas).
+
+- Without tensor parallelism,
+```
+experience buffer size
+= num_process * num_collect_steps * experience_batch_size
+= train_batch_size * accumulation_steps * num_process
+```
+
+- With tensor parallelism,
+```
+num_tp_group = num_process / tp
+experience buffer size
+= num_tp_group * num_collect_steps * experience_batch_size
+= train_batch_size * accumulation_steps * num_tp_group
+```
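+
+As a sanity check, plugging in the values from the TP example script above (4 processes with tp 4, so num_tp_group = 1; num_collect_steps = 4, experience_batch_size = 8, train_batch_size = 4, accumulation_steps = 8) gives:
+
+```
+experience buffer size = 1 * 4 * 8 = 32 = 4 * 8 * 1
+```
+
+Both sides of the equality must match; otherwise the update phase will not consume exactly the experiences collected in each episode.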
+
+### Sample Training Results Using Default Script
+#### Reward
+
+
+#### Approximate KL Divergence
+
+### Note on PPO Training
+#### Q1: My reward is negative
+Answer: Check the reward model you trained in stage 2. If the reward model only generates negative rewards, a negative reward is expected here. However, even if the reward is negative, it should still go up over the course of training.
+
+#### Q2: My actor loss is negative
+Answer: This is normal for actor loss as PPO doesn't restrict the actor loss to be positive.
+
+#### Q3: My reward doesn't go up (decreases)
+Answer: The causes of this problem are two-fold. First, check your reward model: make sure that it gives strong positive rewards for good responses and strong negative rewards for bad ones. Second, try different hyperparameter settings.
+
+#### Q4: Generation is garbage
+Answer: Yes, this happens and is well documented by other implementations. After training for too many episodes, the actor gradually deviates from its original state, which may lead to a decrease in language modeling capabilities. A way to fix this is to add a supervised loss during PPO: set ptx_coef to a non-zero value (between 0 and 1), which balances the PPO loss and the sft loss.
+
+## Alternative Option For RLHF: Direct Preference Optimization
+
+For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in [this paper](https://arxiv.org/abs/2305.18290), offers a low-cost way to perform RLHF and usually requires fewer computational resources than PPO.
+
+### DPO Training Stage1 - Supervised Instructs Tuning
+
+Please refer to the [SFT section](#rlhf-training-stage1---supervised-instructs-tuning) in the PPO part.
+
+### DPO Training Stage2 - DPO Training
+#### Step 1: Data Collection & Preparation
+For DPO training, you only need the preference dataset. Please follow the instructions in the [preference dataset preparation section](#rlhf-training-stage2---training-reward-model) to prepare the preference data for DPO training.
+
+#### Step 2: Training
+You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
+
+#### DPO Result
+
+
+
+
+## Hardware Requirements
+For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model on a dummy dataset with 2048 sequence length and 512 output length for different tp_size values (equal to the number of GPUs). In this experiment, we use H800 GPUs with 80GB VRAM.
+| PPO | tp=8 | tp=4 |
+|-------|---------------|---------------|
+| bs=1 | 18485.19 MB | 42934.45 MB |
+| bs=4 | 25585.65 MB | 42941.93 MB |
+| bs=16 | 41408.28 MB | 56778.97 MB |
+| bs=30 | 64047.42 MB | failed |
+
+For DPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with 2048 sequence length.
+
+- 1 H800 GPU
+ - zero2-cpu, batch size=2, VRAM Usage=49873.90 MB
+ - zero2-cpu, batch size=4, VRAM Usage=60998.22 MB
+- 4 H800 GPUs
+ - zero2, batch size=4, VRAM Usage=67544.47 MB
+
+## Inference example
+
+We support different inference options, including int8 and int4 quantization.
+For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
+
+## Attention
+
+The examples above are demos of the whole training process; you will need to tune the hyper-parameters to achieve good performance.
diff --git a/applications/Chat/examples/community/README.md b/applications/ColossalChat/examples/community/README.md
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/examples/community/README.md
rename to applications/ColossalChat/examples/community/README.md
diff --git a/applications/Chat/examples/community/peft/README.md b/applications/ColossalChat/examples/community/peft/README.md
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/examples/community/peft/README.md
rename to applications/ColossalChat/examples/community/peft/README.md
diff --git a/applications/Chat/examples/community/peft/easy_dataset.py b/applications/ColossalChat/examples/community/peft/easy_dataset.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/examples/community/peft/easy_dataset.py
rename to applications/ColossalChat/examples/community/peft/easy_dataset.py
diff --git a/applications/Chat/examples/community/peft/easy_models.py b/applications/ColossalChat/examples/community/peft/easy_models.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/examples/community/peft/easy_models.py
rename to applications/ColossalChat/examples/community/peft/easy_models.py
diff --git a/applications/Chat/examples/community/peft/train_peft_prompts.py b/applications/ColossalChat/examples/community/peft/train_peft_prompts.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/examples/community/peft/train_peft_prompts.py
rename to applications/ColossalChat/examples/community/peft/train_peft_prompts.py
diff --git a/applications/Chat/examples/community/peft/train_peft_sft.py b/applications/ColossalChat/examples/community/peft/train_peft_sft.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/examples/community/peft/train_peft_sft.py
rename to applications/ColossalChat/examples/community/peft/train_peft_sft.py
diff --git a/applications/Chat/examples/community/ray/README.md b/applications/ColossalChat/examples/community/ray/README.md
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/examples/community/ray/README.md
rename to applications/ColossalChat/examples/community/ray/README.md
diff --git a/applications/Chat/examples/community/ray/ray_job_script.py b/applications/ColossalChat/examples/community/ray/ray_job_script.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/examples/community/ray/ray_job_script.py
rename to applications/ColossalChat/examples/community/ray/ray_job_script.py
diff --git a/applications/Chat/examples/community/ray/train_prompts_on_ray.py b/applications/ColossalChat/examples/community/ray/train_prompts_on_ray.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/examples/community/ray/train_prompts_on_ray.py
rename to applications/ColossalChat/examples/community/ray/train_prompts_on_ray.py
diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py b/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py
new file mode 100644
index 000000000000..64093f88d7ca
--- /dev/null
+++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py
@@ -0,0 +1,268 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Prepare dataset scripts
+
+Usage:
+- For SFT dataset preparation (SFT)
+python prepare_dataset.py --type sft \
+ --data_input_dirs /PATH/TO/SFT/DATASET \
+ --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
+ --tokenizer_dir "" \
+ --data_cache_dir $SAVE_DIR/cache \
+ --data_jsonl_output_dir $SAVE_DIR/jsonl \
+ --data_arrow_output_dir $SAVE_DIR/arrow \
+
+- For prompt dataset preparation (PPO)
+python prepare_dataset.py --type prompt \
+    --data_input_dirs /PATH/TO/PROMPT/DATASET \
+ --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
+ --tokenizer_dir "" \
+ --data_cache_dir $SAVE_DIR/cache \
+ --data_jsonl_output_dir $SAVE_DIR/jsonl \
+ --data_arrow_output_dir $SAVE_DIR/arrow \
+
+- For Preference dataset preparation (DPO and Reward model training)
+python prepare_dataset.py --type preference \
+    --data_input_dirs /PATH/TO/PREFERENCE/DATASET \
+ --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
+ --tokenizer_dir "" \
+ --data_cache_dir $SAVE_DIR/cache \
+ --data_jsonl_output_dir $SAVE_DIR/jsonl \
+ --data_arrow_output_dir $SAVE_DIR/arrow \
+"""
+
+import argparse
+import json
+import math
+import os
+import random
+import time
+from multiprocessing import cpu_count
+
+from coati.dataset import setup_conversation_template, supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
+from datasets import dataset_dict, load_dataset
+from transformers import AutoTokenizer
+
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--type",
+ type=str,
+ required=True,
+ default=None,
+ choices=["sft", "prompt", "preference"],
+        help="Type of dataset, choose from 'sft', 'prompt', 'preference'.",
+ )
+ parser.add_argument(
+ "--data_input_dirs",
+ type=str,
+ required=True,
+ default=None,
+        help="Comma-separated (i.e., ',') list of all data directories containing `.jsonl` data files.",
+ )
+ parser.add_argument(
+ "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
+ )
+ parser.add_argument(
+ "--conversation_template_config",
+ type=str,
+ default="conversation_template_config",
+        help="Path to the conversation template config file; it is loaded if it exists and created otherwise.",
+ )
+ parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
+ parser.add_argument(
+ "--data_jsonl_output_dir",
+ type=str,
+ default="jsonl_output",
+ help="Output directory of spliced dataset with jsonl format",
+ )
+ parser.add_argument(
+ "--data_arrow_output_dir",
+ type=str,
+ default="arrow_output",
+ help="Output directory of spliced dataset with arrow format",
+ )
+ parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
+ parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
+ parser.add_argument(
+ "--num_samples_per_datafile",
+ type=int,
+ default=-1,
+        help="Number of samples to be generated from each data file. -1 denotes all samples.",
+ )
+ args = parser.parse_args()
+
+ if args.num_spliced_dataset_bins >= 100000:
+ raise ValueError("Too many spliced divisions, must be smaller than 100000")
+
+    assert not os.path.exists(args.data_cache_dir), f"Found existing data cache dir {args.data_cache_dir}"
+    assert not os.path.exists(
+        args.data_jsonl_output_dir
+    ), f"Found existing jsonl data output dir {args.data_jsonl_output_dir}"
+    assert not os.path.exists(
+        args.data_arrow_output_dir
+    ), f"Found existing arrow data output dir {args.data_arrow_output_dir}"
+ os.makedirs(args.data_jsonl_output_dir)
+ os.makedirs(args.data_arrow_output_dir)
+
+    # Prepare all input datasets
+ input_data_paths = []
+ input_data_dirs = args.data_input_dirs.split(",")
+ for ds_dir in input_data_dirs:
+ ds_dir = os.path.abspath(ds_dir)
+        assert os.path.exists(ds_dir), f"Data dir {ds_dir} not found"
+ ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")]
+ ds_paths = [os.path.join(ds_dir, name) for name in ds_files]
+ input_data_paths.extend(ds_paths)
+
+    # Prepare the data splits: shard the data into num_spliced_dataset_bins bins using percent slices.
+ train_splits = []
+ split_interval = math.ceil(100 / args.num_spliced_dataset_bins)
+ for i in range(0, 100, split_interval):
+ start = i
+ end = i + split_interval
+ if end > 100:
+ end = 100
+ train_splits.append(f"train[{start}%:{end}%]")
+
+ # Prepare the tokenizer.
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir, use_fast=False, trust_remote_code=True)
+ if os.path.exists(args.conversation_template_config):
+ chat_template_config = json.load(open(args.conversation_template_config, "r", encoding="utf8"))
+ else:
+ chat_template_config = {
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
+ } # Use default system message
+ if args.type == "preference":
+ if "stop_ids" not in chat_template_config:
+ # Ask the user to define stop_ids for PPO training
+ dummy_messages = [
+ {"role": "user", "content": "Hello, how are you?"},
+ {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
+ {"role": "user", "content": "Who made you?"},
+ {"role": "assistant", "content": "I am a chatbot trained by Colossal-AI."},
+ ]
+ dummy_prompt = tokenizer.apply_chat_template(dummy_messages, tokenize=False)
+ tokenized = tokenizer(dummy_prompt, add_special_tokens=False)["input_ids"]
+ tokens = tokenizer.convert_ids_to_tokens(tokenized, skip_special_tokens=False)
+ corresponding_str = [tokenizer.convert_tokens_to_string([token]) for token in tokens]
+ token_id_mapping = [{"token": s, "id": tokenized[i]} for i, s in enumerate(corresponding_str)]
+ stop_ids = input(
+                "For PPO, we recommend providing stop_ids so that generation stops properly during the rollout stage. "
+                "stop_ids are the ids of the repetitive pattern that indicates the end of the assistant's response. "
+                "Here is an example of a formatted prompt and its token-id mapping. You can set stop_ids by entering a list "
+                "of integers separated by spaces, then press `Enter`. Or press `Enter` without input if you are "
+                "not using PPO or prefer not to set stop_ids; in that case, stop_ids will be set to tokenizer.eos_token_id. "
+ f"\nPrompt:\n{dummy_prompt}\nToken-id Mapping:\n{token_id_mapping}\nstop_ids:"
+ )
+ if stop_ids == "":
+ chat_template_config["stop_ids"] = [tokenizer.eos_token_id]
+ else:
+ try:
+ chat_template_config["stop_ids"] = [int(s) for s in stop_ids.split()]
+ except ValueError:
+ raise ValueError("Invalid input, please provide a list of integers.")
+ else:
+        # Set stop_ids to eos_token_id for other dataset types if not already set
+ if "stop_ids" not in chat_template_config:
+ chat_template_config["stop_ids"] = [tokenizer.eos_token_id]
+
+ conversation_template = setup_conversation_template(
+ tokenizer, chat_template_config=chat_template_config, save_path=args.conversation_template_config
+ )
+ if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
+ try:
+            # Some tokenizers don't allow setting pad_token manually, e.g., Qwen
+ tokenizer.pad_token = tokenizer.eos_token
+ except AttributeError as e:
+ logger.warning(f"Unable to set pad token to eos token, {str(e)}")
+ if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
+ logger.warning(
+            "The tokenizer does not have a pad token, which is required; this may lead to unintended behavior during training. Please consider setting it manually."
+ )
+
+ list_dataset = load_dataset(
+ path="json",
+ data_files=input_data_paths,
+ cache_dir=os.path.join(args.data_cache_dir, "raw"),
+ keep_in_memory=False,
+ split=train_splits,
+ num_proc=cpu_count(),
+ )
+
+ if args.type == "sft":
+ preparation_function = supervised_tokenize_sft
+ elif args.type == "prompt":
+ preparation_function = tokenize_prompt_dataset
+ elif args.type == "preference":
+ preparation_function = tokenize_rlhf
+ else:
+        raise ValueError("Unknown dataset type. Please choose one from ['sft', 'prompt', 'preference']")
+
+ for index, dataset in enumerate(list_dataset):
+ assert isinstance(dataset, dataset_dict.Dataset)
+ if len(dataset) == 0:
+            # Hack: skip empty datasets. If a dataset contains fewer samples than the number of ranks, some ranks may get an empty dataset, which leads to errors.
+ continue
+ if args.num_samples_per_datafile > 0:
+ # limit the number of samples in each dataset
+ dataset = dataset.select(
+ random.sample(range(len(dataset)), min(args.num_samples_per_datafile, len(dataset)))
+ )
+ logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
+ dataset = dataset.map(
+ function=preparation_function,
+ fn_kwargs={
+ "tokenizer": tokenizer,
+ "conversation_template": conversation_template,
+ "max_length": args.max_length,
+ },
+ keep_in_memory=False,
+ num_proc=min(len(dataset), cpu_count()),
+ )
+
+ dataset = dataset.filter(
+ lambda data: data["chosen_input_ids" if args.type == "preference" else "input_ids"] is not None
+ )
+
+ # Save each jsonl spliced dataset.
+        output_index = str(index).zfill(5)
+ output_name = f"part-{output_index}"
+ output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl")
+ st = time.time()
+ with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer:
+ count = 0
+ for data_point in dataset:
+ if count % 500 == 0:
+ logger.info(f"processing {count} spliced data points for {fp_writer.name}")
+ count += 1
+ fp_writer.write(json.dumps(data_point, ensure_ascii=False) + "\n")
+ logger.info(
+ f"Current file {fp_writer.name}; "
+ f"Data size: {len(dataset)}; "
+ f"Time cost: {round((time.time() - st) / 60, 6)} minutes."
+ )
+ # Save each arrow spliced dataset
+ output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name)
+ logger.info(f"Start to save {output_arrow_path}")
+ dataset = load_dataset(
+ path="json",
+ data_files=[output_jsonl_path],
+ cache_dir=os.path.join(args.data_cache_dir, "tokenized"),
+ keep_in_memory=False,
+ num_proc=cpu_count(),
+ split="train",
+ )
+ dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(dataset), cpu_count()))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_preference_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_preference_dataset.sh
new file mode 100755
index 000000000000..999d7778be52
--- /dev/null
+++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_preference_dataset.sh
@@ -0,0 +1,13 @@
+SAVE_DIR=""
+
+rm -rf $SAVE_DIR/cache
+rm -rf $SAVE_DIR/jsonl
+rm -rf $SAVE_DIR/arrow
+
+python prepare_dataset.py --type preference \
+ --data_input_dirs "PATH/TO/PREFERENCE/DATA" \
+ --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
+ --tokenizer_dir "" \
+ --data_cache_dir $SAVE_DIR/cache \
+ --data_jsonl_output_dir $SAVE_DIR/jsonl \
+ --data_arrow_output_dir $SAVE_DIR/arrow
diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh
new file mode 100755
index 000000000000..8d3d6c2c2d80
--- /dev/null
+++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh
@@ -0,0 +1,13 @@
+SAVE_DIR=""
+
+rm -rf $SAVE_DIR/cache
+rm -rf $SAVE_DIR/jsonl
+rm -rf $SAVE_DIR/arrow
+
+python prepare_dataset.py --type prompt \
+ --data_input_dirs /PATH/TO/PROMPT/DATASET \
+ --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
+ --tokenizer_dir "" \
+ --data_cache_dir $SAVE_DIR/cache \
+ --data_jsonl_output_dir $SAVE_DIR/jsonl \
+ --data_arrow_output_dir $SAVE_DIR/arrow
diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh
new file mode 100755
index 000000000000..cf937db2a84b
--- /dev/null
+++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh
@@ -0,0 +1,13 @@
+SAVE_DIR=""
+
+rm -rf $SAVE_DIR/cache
+rm -rf $SAVE_DIR/jsonl
+rm -rf $SAVE_DIR/arrow
+
+python prepare_dataset.py --type sft \
+ --data_input_dirs /PATH/TO/SFT/DATASET \
+ --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
+ --tokenizer_dir "" \
+ --data_cache_dir $SAVE_DIR/cache \
+ --data_jsonl_output_dir $SAVE_DIR/jsonl \
+    --data_arrow_output_dir $SAVE_DIR/arrow
diff --git a/applications/ColossalChat/examples/inference/chatio.py b/applications/ColossalChat/examples/inference/chatio.py
new file mode 100755
index 000000000000..26784f3a3411
--- /dev/null
+++ b/applications/ColossalChat/examples/inference/chatio.py
@@ -0,0 +1,168 @@
+"""
+command line IO utils for chatbot
+"""
+
+import abc
+import re
+
+from prompt_toolkit import PromptSession
+from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
+from prompt_toolkit.completion import WordCompleter
+from prompt_toolkit.history import InMemoryHistory
+from rich.console import Console
+from rich.live import Live
+from rich.markdown import Markdown
+
+
+class ChatIO(abc.ABC):
+ @abc.abstractmethod
+ def prompt_for_input(self, role: str) -> str:
+ """Prompt for input from a role."""
+
+ @abc.abstractmethod
+ def prompt_for_output(self, role: str):
+ """Prompt for output from a role."""
+
+ @abc.abstractmethod
+ def stream_output(self, output_stream):
+ """Stream output."""
+
+
+class SimpleChatIO(ChatIO):
+ def prompt_for_input(self, role) -> str:
+ return input(f"{role}: ")
+
+ def prompt_for_output(self, role: str):
+ print(f"{role}: ", end="", flush=True)
+
+ def stream_output(self, output_stream):
+ pre = 0
+ for outputs in output_stream:
+ outputs = outputs.strip()
+ outputs = outputs.split(" ")
+ now = len(outputs) - 1
+ if now > pre:
+ print(" ".join(outputs[pre:now]), end=" ", flush=True)
+ pre = now
+ print(" ".join(outputs[pre:]), flush=True)
+ return " ".join(outputs)
+
+
+class RichChatIO(ChatIO):
+ def __init__(self):
+ self._prompt_session = PromptSession(history=InMemoryHistory())
+ self._completer = WordCompleter(words=["!exit", "!reset"], pattern=re.compile("$"))
+ self._console = Console()
+
+ def prompt_for_input(self, role) -> str:
+ self._console.print(f"[bold]{role}:")
+ prompt_input = self._prompt_session.prompt(
+ completer=self._completer,
+ multiline=False,
+ auto_suggest=AutoSuggestFromHistory(),
+ key_bindings=None,
+ )
+ self._console.print()
+ return prompt_input
+
+ def prompt_for_output(self, role: str) -> str:
+ self._console.print(f"[bold]{role}:")
+
+ def stream_output(self, output_stream):
+ """Stream output from a role."""
+ # Create a Live context for updating the console output
+ with Live(console=self._console, refresh_per_second=60) as live:
+ # Read lines from the stream
+ for outputs in output_stream:
+ accumulated_text = outputs
+ if not accumulated_text:
+ continue
+ # Render the accumulated text as Markdown
+                # NOTE: this is a workaround for rendering "non-standard markdown" in
+                # rich. The chatbot's output treats "\n" as a new line for
+                # better compatibility with real-world text. However, rendering it
+                # as markdown would break the format, because standard markdown
+                # treats a single "\n" in normal text as a space.
+                # Our workaround is to add two spaces at the end of each line.
+                # This is not a perfect solution, as it
+                # introduces trailing spaces (only) in code blocks, but it works well
+                # for console output in particular, because the console generally does not
+                # care about trailing spaces.
+ lines = []
+ for line in accumulated_text.splitlines():
+ lines.append(line)
+ if line.startswith("```"):
+ # Code block marker - do not add trailing spaces, as it would
+ # break the syntax highlighting
+ lines.append("\n")
+ else:
+ lines.append(" \n")
+ markdown = Markdown("".join(lines))
+ # Update the Live console output
+ live.update(markdown)
+ self._console.print()
+ return outputs
+
+
+class DummyChatIO(ChatIO):
+ """
+ Dummy ChatIO class for testing
+ """
+
+ def __init__(self):
+ self.roles = []
+ self._console = Console()
+
+ def prompt_for_input(self, role) -> str:
+ self.roles.append(role)
+ if len(self.roles) == 1:
+ ret = "Hello"
+ elif len(self.roles) == 2:
+ ret = "What's the value of 1+1?"
+ else:
+ ret = "exit"
+ self._console.print(f"[bold]{role}:{ret}")
+ return ret
+
+ def prompt_for_output(self, role: str) -> str:
+ self._console.print(f"[bold]{role}:")
+
+ def stream_output(self, output_stream):
+ """Stream output from a role."""
+ # Create a Live context for updating the console output
+ with Live(console=self._console, refresh_per_second=60) as live:
+ # Read lines from the stream
+ for outputs in output_stream:
+ accumulated_text = outputs
+ if not accumulated_text:
+ continue
+ # Render the accumulated text as Markdown
+                # NOTE: this is a workaround for rendering "non-standard markdown" in
+                # rich. The chatbot's output treats "\n" as a new line for
+                # better compatibility with real-world text. However, rendering it
+                # as markdown would break the format, because standard markdown
+                # treats a single "\n" in normal text as a space.
+                # Our workaround is to add two spaces at the end of each line.
+                # This is not a perfect solution, as it
+                # introduces trailing spaces (only) in code blocks, but it works well
+                # for console output in particular, because the console generally does not
+                # care about trailing spaces.
+ lines = []
+ for line in accumulated_text.splitlines():
+ lines.append(line)
+ if line.startswith("```"):
+ # Code block marker - do not add trailing spaces, as it would
+ # break the syntax highlighting
+ lines.append("\n")
+ else:
+ lines.append(" \n")
+ markdown = Markdown("".join(lines))
+ # Update the Live console output
+ live.update(markdown)
+ self._console.print()
+ return outputs
+
+
+simple_io = SimpleChatIO()
+rich_io = RichChatIO()
+dummy_io = DummyChatIO()
diff --git a/applications/ColossalChat/examples/inference/inference.py b/applications/ColossalChat/examples/inference/inference.py
new file mode 100755
index 000000000000..103bd8d95016
--- /dev/null
+++ b/applications/ColossalChat/examples/inference/inference.py
@@ -0,0 +1,195 @@
+import argparse
+import json
+import os
+from typing import Dict
+
+import torch
+from chatio import dummy_io, rich_io, simple_io
+from coati.dataset.conversation import setup_conversation_template
+from coati.models import generate_streaming
+from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
+
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+
+def get_gpu_memory(max_gpus=None):
+ """
+ Get the available memory for each GPU.
+
+ Args:
+ max_gpus (int, optional): The maximum number of GPUs to consider. Defaults to None.
+
+ Returns:
+ list: A list of available memory for each GPU.
+ """
+ gpu_memory = []
+ num_gpus = torch.cuda.device_count() if max_gpus is None else min(max_gpus, torch.cuda.device_count())
+
+ for gpu_id in range(num_gpus):
+        # Compute available memory as total device memory minus currently allocated memory
+ with torch.cuda.device(gpu_id):
+ device = torch.cuda.current_device()
+ gpu_properties = torch.cuda.get_device_properties(device)
+ total_memory = gpu_properties.total_memory / (1024**3)
+ allocated_memory = torch.cuda.memory_allocated() / (1024**3)
+ available_memory = total_memory - allocated_memory
+ gpu_memory.append(available_memory)
+ return gpu_memory
+
+
+def load_model_and_tokenizer(model_path, tokenizer_path, device="cuda", **kwargs):
+ """
+ Load the model and tokenizer from the specified paths and move the model to the specified device.
+
+ Args:
+ model_path (str): The path to the pre-trained model.
+ tokenizer_path (str): The path to the pre-trained tokenizer.
+ device (str, optional): The device to move the model to. Defaults to "cuda".
+ **kwargs: Additional keyword arguments to be passed to the `AutoModelForCausalLM.from_pretrained` function.
+
+ Returns:
+ tuple: A tuple containing the loaded model and tokenizer.
+ """
+
+ model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ tokenizer.pad_token = tokenizer.eos_token
+ model.to(device)
+
+ return model, tokenizer
+
+
+def _set_default_generate_kwargs(model: PreTrainedModel) -> Dict:
+ """
+ Set default keyword arguments for generation based on the given model.
+
+ Args:
+ model (PreTrainedModel): The model used for generation.
+
+ Returns:
+ Dict: A dictionary containing the default keyword arguments for generation.
+ """
+ unwrapped_model = model
+ new_kwargs = {}
+ # Use huggingface models method directly
+ if hasattr(unwrapped_model, "prepare_inputs_for_generation"):
+ new_kwargs["prepare_inputs_fn"] = unwrapped_model.prepare_inputs_for_generation
+
+ if hasattr(unwrapped_model, "_update_model_kwargs_for_generation"):
+ new_kwargs["update_model_kwargs_fn"] = unwrapped_model._update_model_kwargs_for_generation
+ return new_kwargs
+
+
+def generation_wrapper(*args, **kwargs):
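+    # positional layout mirrors the call generate_streaming(model, input_ids, tokenizer, ...):
+    # args[0] is the model, args[1] the prompt input_ids, args[2] the tokenizer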
+ input_ids = args[1]
+ tokenizer = args[2]
+ for output in generate_streaming(*args, **kwargs):
+ yield tokenizer.batch_decode(output[:, input_ids.size(1) :], skip_special_tokens=True)[0]
+
+
+def main(args):
+ conversation_template_config = json.load(open(args.conversation_template_config, "r", encoding="utf8"))
+
+ max_new_tokens = args.max_new_tokens
+ model_max_length = args.model_max_length
+ model, tokenizer = load_model_and_tokenizer(
+ args.model_path, args.tokenizer_path or args.model_path, local_files_only=True
+ )
+
+ assert max_new_tokens <= model_max_length
+ if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
+ try:
+            # Some tokenizers don't allow setting pad_token manually, e.g., Qwen
+ tokenizer.pad_token = tokenizer.eos_token
+ except AttributeError as e:
+ logger.warning(f"Unable to set pad token to eos token, {str(e)}")
+ tokenizer.padding_side = "left"
+
+ model_kwargs = {
+ "max_new_tokens": max_new_tokens,
+ # 'early_stopping': True,
+ # 'top_k': -1,
+ # 'top_p': 1.0,
+ # 'temperature': 1.0,
+ # 'temperature':0.1,
+ }
+ round = 1
+
+ conv = setup_conversation_template(tokenizer, conversation_template_config, args.conversation_template_config)
+
+ while True:
+ if args.io == "simple":
+ chat_io = simple_io
+ elif args.io == "rich":
+ chat_io = rich_io
+ elif args.io == "dummy":
+ chat_io = dummy_io
+ else:
+ raise ValueError(f"Unknown io type: {args.io}")
+ # raw_text = print(">>> Human:", end=" ")
+ inp = chat_io.prompt_for_input("user")
+
+ if not inp:
+ print("prompt should not be empty!")
+ continue
+
+ if inp.strip() == "clear":
+ conv.clear()
+ os.system("clear")
+ continue
+
+ if inp.strip() == "exit":
+ print("End of chat.")
+ break
+
+ query_text = inp.strip()
+
+ conv.append_message("user", query_text)
+
+ chat_io.prompt_for_output("assistant")
+
+ prompt = conv.get_prompt(add_generation_prompt=True)
+        print(prompt)
+ input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(
+ torch.cuda.current_device()
+ )
+ default_generate_kwargs = _set_default_generate_kwargs(model)
+ model_kwargs.update(default_generate_kwargs)
+ output_stream = generation_wrapper(
+ model,
+ input_ids,
+ tokenizer,
+ max_length=model_max_length,
+ temperature=0.7,
+ early_stopping=True,
+ stop_token_ids=conversation_template_config["stop_ids"],
+ **model_kwargs,
+ )
+
+ # print(f">>> Assistant:", end=" ")
+ outputs = chat_io.stream_output(output_stream)
+
+ conv.append_message("assistant", outputs.strip())
+
+ with open("round.txt", mode="a", encoding="utf-8") as f:
+ f.write("\n\n" + "=" * 10 + "\n")
+ f.write(f"round {round}:\n{conv.save_prompt()}\n\n")
+ f.write("=" * 10 + "\n")
+
+ # print(f">>> Assistant:", end=" ")
+
+ round += 1
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_path", type=str, default=None)
+ parser.add_argument("--tokenizer_path", type=str, default=None)
+ parser.add_argument("--conversation_template_config", type=str, default=None)
+ parser.add_argument("--model_max_length", type=int, default=2048)
+ parser.add_argument("--max_new_tokens", type=int, default=512)
+ parser.add_argument("--io", type=str, default="rich", choices=["simple", "rich", "dummy"])
+ args = parser.parse_args()
+ main(args)
diff --git a/applications/Chat/inference/README.md b/applications/ColossalChat/examples/inference/web_chatbot/README.md
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/inference/README.md
rename to applications/ColossalChat/examples/inference/web_chatbot/README.md
diff --git a/applications/Chat/inference/locustfile.py b/applications/ColossalChat/examples/inference/web_chatbot/locustfile.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/inference/locustfile.py
rename to applications/ColossalChat/examples/inference/web_chatbot/locustfile.py
diff --git a/applications/Chat/inference/requirements.txt b/applications/ColossalChat/examples/inference/web_chatbot/requirements.txt
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/inference/requirements.txt
rename to applications/ColossalChat/examples/inference/web_chatbot/requirements.txt
diff --git a/applications/Chat/inference/server.py b/applications/ColossalChat/examples/inference/web_chatbot/server.py
old mode 100644
new mode 100755
similarity index 79%
rename from applications/Chat/inference/server.py
rename to applications/ColossalChat/examples/inference/web_chatbot/server.py
index 7c6a61b9e7f2..aec342802b02
--- a/applications/Chat/inference/server.py
+++ b/applications/ColossalChat/examples/inference/web_chatbot/server.py
@@ -5,6 +5,7 @@
import torch
import uvicorn
+from coati.models import generate_streaming
from coati.quant import llama_load_quant, low_resource_init
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
@@ -13,10 +14,9 @@
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse
-from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
-from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, update_model_kwargs_fn
-CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."
MAX_LEN = 512
running_lock = Lock()
@@ -54,20 +54,22 @@ class GenerationTaskReq(BaseModel):
)
-def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
- inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
+def generate_streamingly(prompt, max_length, max_new_tokens, top_k, top_p, temperature):
+ input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
# TODO(ver217): streaming generation does not support repetition_penalty now
model_kwargs = {
- "max_generate_tokens": max_new_tokens,
+ "max_new_tokens": max_new_tokens,
"early_stopping": True,
"top_k": top_k,
"top_p": top_p,
"temperature": temperature,
- "prepare_inputs_fn": model.prepare_inputs_for_generation,
+ "prepare_inputs_fn": None,
"update_model_kwargs_fn": update_model_kwargs_fn,
}
is_first_word = True
- generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock)
+ generator = LockedIterator(
+ generate_streaming(model, input_ids, tokenizer, max_length, **model_kwargs), running_lock
+ )
for output in generator:
output = output.cpu()
tokens = tokenizer.convert_ids_to_tokens(output, skip_special_tokens=True)
@@ -101,9 +103,10 @@ async def event_generator(request: Request, generator: Generator):
@app.post("/generate/stream")
@limiter.limit("1/second")
def generate(data: GenerationTaskReq, request: Request):
- prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
+ prompt = prompt_processor.preprocess_prompt(data.history)
event_source = event_generator(
- request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature)
+ request,
+ generate_streamingly(prompt, data.max_length, data.max_new_tokens, data.top_k, data.top_p, data.temperature),
)
return EventSourceResponse(event_source)
@@ -133,6 +136,11 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
"pretrained",
help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
)
+ parser.add_argument(
+ "--tokenizer_path",
+ help="Path to pretrained tokenizer. Can be a local path or a model name from the HuggingFace model hub.",
+ default=None,
+ )
parser.add_argument(
"--quant",
choices=["8bit", "4bit"],
@@ -162,26 +170,29 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
if args.quant == "4bit":
assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint."
- tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
+ if args.tokenizer_path is None:
+ args.tokenizer_path = args.pretrained
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, local_files_only=True)
if args.profanity_file is not None:
censored_words = load_json(args.profanity_file)
else:
censored_words = []
- prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
+ prompt_processor = ChatPromptProcessor(censored_words=censored_words)
if args.quant == "4bit":
with low_resource_init():
- config = LlamaConfig.from_pretrained(args.pretrained)
- model = LlamaForCausalLM(config)
+ config = AutoConfig.from_pretrained(args.pretrained)
+ model = AutoModelForCausalLM(config)
model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
model.cuda()
else:
- model = LlamaForCausalLM.from_pretrained(
+ model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
load_in_8bit=(args.quant == "8bit"),
torch_dtype=torch.float16,
device_map="auto",
+ local_files_only=True,
)
if args.quant != "8bit":
model.half() # seems to fix bugs for some users.
@@ -190,3 +201,8 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
server = uvicorn.Server(config=config)
server.run()
+
+
+"""
+python server.py /home/lcyab/data/models/experiments5/checkpoint/experiment5-2023-10-20-21-53-51/modeling/ --tokenizer_path /mnt/vepfs/lcxyc/leaderboard_models/Colossal-LLaMA-2-7b-base/
+"""
diff --git a/applications/ColossalChat/examples/inference/web_chatbot/utils.py b/applications/ColossalChat/examples/inference/web_chatbot/utils.py
new file mode 100755
index 000000000000..82a1a7255164
--- /dev/null
+++ b/applications/ColossalChat/examples/inference/web_chatbot/utils.py
@@ -0,0 +1,78 @@
+import copy
+import json
+from threading import Lock
+from typing import List
+
+import jieba
+import torch
+from coati.dataset.conversation import default_conversation
+from pydantic import BaseModel, Field
+
+
+def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
+ if "past_key_values" in outputs:
+ model_kwargs["past"] = outputs["past_key_values"]
+ else:
+ model_kwargs["past"] = None
+
+ # update token_type_ids with last value
+ if "token_type_ids" in model_kwargs:
+ token_type_ids = model_kwargs["token_type_ids"]
+ model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+ # update attention mask
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ model_kwargs["attention_mask"] = torch.cat(
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+ )
+
+ return model_kwargs
+
+
+class Dialogue(BaseModel):
+ instruction: str = Field(min_length=1, example="Count up from 1 to 500.")
+ response: str = Field(example="")
+
+
+class ChatPromptProcessor:
+ SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt."
+
+ def __init__(self, censored_words: List[str] = []):
+ self.censored_words = set([word.lower() for word in censored_words])
+ self.conv = copy.deepcopy(default_conversation)
+
+ def preprocess_prompt(self, history: List[Dialogue]) -> str:
+ self.conv.clear()
+ for round in history:
+ self.conv.append_message(self.conv.roles[0], round.instruction)
+ if len(round.instruction) > 0:
+ self.conv.append_message(self.conv.roles[1], round.response)
+ return self.conv.get_prompt()
+
+ def postprocess_output(self, output: str) -> str:
+ return output.strip()
+
+ def has_censored_words(self, text: str) -> bool:
+ if len(self.censored_words) == 0:
+ return False
+ intersection = set(jieba.cut(text.lower())) & self.censored_words
+ return len(intersection) > 0
+
+
+class LockedIterator:
+ def __init__(self, it, lock: Lock) -> None:
+ self.lock = lock
+ self.it = iter(it)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ with self.lock:
+ return next(self.it)
+
+
+def load_json(path: str):
+ with open(path) as f:
+ return json.load(f)
diff --git a/applications/Chat/examples/ray/1mmt_prompt.py b/applications/ColossalChat/examples/ray/1mmt_prompt.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/examples/ray/1mmt_prompt.py
rename to applications/ColossalChat/examples/ray/1mmt_prompt.py
diff --git a/applications/Chat/examples/ray/mmmt_prompt.py b/applications/ColossalChat/examples/ray/mmmt_prompt.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/examples/ray/mmmt_prompt.py
rename to applications/ColossalChat/examples/ray/mmmt_prompt.py
diff --git a/applications/Chat/examples/ray/requirements.txt b/applications/ColossalChat/examples/ray/requirements.txt
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/examples/ray/requirements.txt
rename to applications/ColossalChat/examples/ray/requirements.txt
diff --git a/applications/Chat/examples/ray/test_ci.sh b/applications/ColossalChat/examples/ray/test_ci.sh
similarity index 100%
rename from applications/Chat/examples/ray/test_ci.sh
rename to applications/ColossalChat/examples/ray/test_ci.sh
diff --git a/applications/Chat/examples/requirements.txt b/applications/ColossalChat/examples/requirements.txt
similarity index 51%
rename from applications/Chat/examples/requirements.txt
rename to applications/ColossalChat/examples/requirements.txt
index 5474dfa16b3e..838590f4b103 100644
--- a/applications/Chat/examples/requirements.txt
+++ b/applications/ColossalChat/examples/requirements.txt
@@ -1,3 +1,4 @@
pandas>=1.4.1
sentencepiece
-colossalai==0.3.3
+colossalai
+prompt_toolkit
diff --git a/applications/ColossalChat/examples/training_scripts/hostfile b/applications/ColossalChat/examples/training_scripts/hostfile
new file mode 100755
index 000000000000..d4118dda9783
--- /dev/null
+++ b/applications/ColossalChat/examples/training_scripts/hostfile
@@ -0,0 +1 @@
+10.20.1.82
diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py
new file mode 100755
index 000000000000..b9287eb1a407
--- /dev/null
+++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py
@@ -0,0 +1,326 @@
+import argparse
+import json
+import os
+import resource
+from contextlib import nullcontext
+
+import torch
+from coati.dataset import (
+ DataCollatorForPreferenceDataset,
+ StatefulDistributedSampler,
+ load_tokenized_dataset,
+ setup_distributed_dataloader,
+)
+from coati.models import convert_to_lora_module, disable_dropout
+from coati.trainer import DPOTrainer
+from coati.utils import load_checkpoint
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.logging import get_dist_logger
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.nn.optimizer import HybridAdam
+
+logger = get_dist_logger()
+
+
+def train(args):
+ # check lora compatibility
+ if "gemini" in args.plugin and args.lora_rank > 0:
+        raise ValueError("LoRA is not supported in GeminiPlugin. Please use another plugin")
+ if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
+        raise ValueError("Gradient accumulation is not supported with the gemini_auto plugin. Please use another plugin")
+
+ # ==============================
+ # Initialize Distributed Training
+ # ==============================
+ colossalai.launch_from_torch({})
+ coordinator = DistCoordinator()
+
+ # ==============================
+ # Initialize Booster
+ # ==============================
+ if args.plugin == "ddp":
+ """
+        Default torch DDP plugin without any acceleration,
+        for debugging purposes
+ """
+ plugin = TorchDDPPlugin(find_unused_parameters=True)
+ elif args.plugin == "gemini":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ placement_policy="static",
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ enable_gradient_accumulation=True,
+ )
+ elif args.plugin == "gemini_auto":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ placement_policy="auto",
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2_cpu":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ cpu_offload=True,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "3d":
+ plugin = HybridParallelPlugin(
+ tp_size=args.tp,
+ pp_size=1,
+ zero_stage=0,
+ parallel_output=False,
+ precision=args.mixed_precision,
+ )
+ else:
+ raise ValueError(f"Unknown plugin {args.plugin}")
+
+ booster = Booster(plugin=plugin)
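+    # a second booster wraps the frozen reference model used for the DPO KL terms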
+ ref_booster = Booster(plugin=plugin)
+
+ # ======================================================
+ # Initialize Model, Objective, Optimizer and LR Scheduler
+ # ======================================================
+ # Temp Fix: Disable lazy init due to version conflict
+ # init_ctx = (
+ # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
+ # )
+
+ init_ctx = nullcontext()
+ with init_ctx:
+ if args.use_flash_attn:
+ model = AutoModelForCausalLM.from_pretrained(
+ args.pretrain,
+ torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+ use_flash_attention_2=True,
+ )
+ coordinator.print_on_master(msg="Flash-attention enabled successfully")
+ else:
+ model = AutoModelForCausalLM.from_pretrained(args.pretrain)
+ disable_dropout(model)
+ if args.enable_reference_model:
+ if args.use_flash_attn:
+ ref_model = AutoModelForCausalLM.from_pretrained(
+ args.pretrain,
+ torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+ use_flash_attention_2=True,
+ )
+ else:
+ ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
+ disable_dropout(ref_model)
+ else:
+ ref_model = None
+
+ if args.lora_rank > 0:
+ model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
+
+ if args.grad_checkpoint and args.lora_rank == 0:
+ model.gradient_checkpointing_enable()
+ coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
+ elif args.lora_rank > 0:
+ coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
+
+ # configure tokenizer
+ tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
+ if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
+ try:
+            # Some tokenizers don't allow setting pad_token manually, e.g., Qwen
+ tokenizer.pad_token = tokenizer.eos_token
+ except AttributeError as e:
+ logger.warning(f"Unable to set pad token to eos token, {str(e)}")
+ if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
+ logger.warning(
+            "The tokenizer does not have a pad token, which is required; this may lead to unintended behavior during training. Please consider setting it manually."
+ )
+
+ tokenizer.add_bos_token = False
+ tokenizer.add_eos_token = False
+
+ # configure optimizer
+ optim = HybridAdam(
+ model_params=model.parameters(),
+ lr=args.lr,
+ betas=(0.9, 0.95),
+ weight_decay=args.weight_decay,
+ adamw_mode=True,
+ )
+
+ # configure dataset
+ coordinator.print_on_master(f"Load dataset: {args.dataset}")
+ mode_map = {"train": "train", "valid": "validation", "test": "test"}
+ train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
+ data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
+ train_dataloader = setup_distributed_dataloader(
+ dataset=train_dataset,
+ batch_size=args.batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=data_collator,
+ use_tp=args.tp > 1,
+ )
+
+ num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
+ if args.warmup_steps is None:
+        args.warmup_steps = int(args.max_epochs * 0.025 * num_update_steps_per_epoch)
+ coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
+
+ lr_scheduler = CosineAnnealingWarmupLR(
+ optimizer=optim,
+ total_steps=args.max_epochs * num_update_steps_per_epoch,
+ warmup_steps=args.warmup_steps,
+ eta_min=0.1 * args.lr,
+ )
+
+ default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
+ torch.set_default_dtype(default_dtype)
+ model, optim, _, train_dataloader, lr_scheduler = booster.boost(
+ model=model,
+ optimizer=optim,
+ lr_scheduler=lr_scheduler,
+ dataloader=train_dataloader,
+ )
+ if ref_model is not None:
+ ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
+ torch.set_default_dtype(torch.float)
+
+ coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+ coordinator.print_on_master(
+ f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ start_epoch = 0
+ sampler_start_idx = 0
+ start_step = 0
+ if args.checkpoint_path is not None:
+ if "modeling" in args.checkpoint_path:
+ coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
+ booster.load_model(model, args.checkpoint_path)
+ else:
+ coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
+ start_epoch, start_step, sampler_start_idx = load_checkpoint(
+ load_dir=args.checkpoint_path,
+ booster=booster,
+ model=model,
+ optimizer=optim,
+ lr_scheduler=lr_scheduler,
+ )
+ assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
+ train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
+
+ coordinator.print_on_master(
+ f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
+ )
+ coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
+
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ trainer = DPOTrainer(
+ actor=model,
+ ref_model=ref_model,
+ booster=booster,
+ actor_optim=optim,
+ actor_lr_scheduler=lr_scheduler,
+ tokenizer=tokenizer,
+ max_epochs=args.max_epochs,
+ accumulation_steps=args.accumulation_steps,
+ start_epoch=start_epoch,
+ save_interval=args.save_interval,
+ save_dir=args.save_dir,
+ coordinator=coordinator,
+ )
+
+ trainer.fit(
+ train_preference_dataloader=train_dataloader,
+ eval_preference_dataloader=None,
+ log_dir=args.log_dir,
+ use_wandb=args.use_wandb,
+ )
+
+ if args.lora_rank > 0 and args.merge_lora_weights:
+ from coati.models.lora import LORA_MANAGER
+
+ # NOTE: set model to eval to merge LoRA weights
+ LORA_MANAGER.merge_weights = True
+ model.eval()
+ # save model checkpoint after fitting on only rank0
+ coordinator.print_on_master("Start saving final model checkpoint")
+ booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
+ coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
+
+ coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+
+
+if __name__ == "__main__":
+ # ==============================
+ # Parse Arguments
+ # ==============================
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="gemini",
+        choices=["ddp", "gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
+ help="Choose which plugin to use",
+ )
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
+ parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
+ parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
+ parser.add_argument("--tp", type=int, default=1)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--model_type", type=str, default=None)
+ parser.add_argument("--tokenizer_dir", type=str, default=None)
+ parser.add_argument("--dataset", nargs="+", default=[])
+ parser.add_argument(
+        "--checkpoint_path", type=str, default=None, help="Checkpoint path if you need to resume training from a checkpoint"
+ )
+ parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
+ parser.add_argument("--save_dir", type=str, default="output")
+ parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
+ parser.add_argument("--max_epochs", type=int, default=3)
+ parser.add_argument("--batch_size", type=int, default=4)
+    # NOTE: argparse's type=bool treats any non-empty string as True, so parse the flag explicitly
+    parser.add_argument("--enable_reference_model", type=lambda x: x.lower() == "true", default=True)
+ parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument(
+ "--lora_train_bias",
+ type=str,
+ default="none",
+ help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
+ )
+ parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
+    parser.add_argument("--merge_lora_weights", type=lambda x: x.lower() == "true", default=True)
+ parser.add_argument("--lr", type=float, default=5e-6)
+ parser.add_argument("--accumulation_steps", type=int, default=8)
+ parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--use_wandb", default=False, action="store_true")
+ parser.add_argument("--grad_checkpoint", default=False, action="store_true")
+ parser.add_argument("--use_flash_attn", default=False, action="store_true")
+ args = parser.parse_args()
+    os.makedirs(os.path.dirname(args.config_file) or ".", exist_ok=True)  # dirname may be empty if config_file has no directory component
+ with open(args.config_file, "w") as f:
+ json.dump(args.__dict__, f, indent=4)
+ train(args)
diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.sh b/applications/ColossalChat/examples/training_scripts/train_dpo.sh
new file mode 100755
index 000000000000..80fc30c3d955
--- /dev/null
+++ b/applications/ColossalChat/examples/training_scripts/train_dpo.sh
@@ -0,0 +1,62 @@
+#!/bin/bash
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+set_n_least_used_CUDA_VISIBLE_DEVICES 8
+# export CUDA_VISIBLE_DEVICES=6
+
+PROJECT_NAME="dpo"
+PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
+PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
+PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
+PRETRAINED_MODEL_PATH="" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
+
+declare -a dataset=(
+ YOUR/DATA/DIR/arrow/part-00000
+ YOUR/DATA/DIR/arrow/part-00001
+ YOUR/DATA/DIR/arrow/part-00002
+ YOUR/DATA/DIR/arrow/part-00003
+ YOUR/DATA/DIR/arrow/part-00004
+ YOUR/DATA/DIR/arrow/part-00005
+ YOUR/DATA/DIR/arrow/part-00006
+ YOUR/DATA/DIR/arrow/part-00007
+ YOUR/DATA/DIR/arrow/part-00008
+ YOUR/DATA/DIR/arrow/part-00009
+)
+
+TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
+FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
+SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
+CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+
+colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_dpo.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --checkpoint_path $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
+ --plugin "zero2" \
+ --save_interval 1000 \
+ --save_dir $SAVE_DIR \
+ --config_file $CONFIG_FILE \
+ --max_epochs 1 \
+ --accumulation_steps 4 \
+ --batch_size 2 \
+ --lr 1e-6 \
+ --mixed_precision "bf16" \
+ --grad_clip 1.0 \
+ --weight_decay 0.01 \
+ --warmup_steps 100 \
+ --grad_checkpoint \
+ --use_wandb
diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.py b/applications/ColossalChat/examples/training_scripts/train_ppo.py
new file mode 100755
index 000000000000..7c91fa347847
--- /dev/null
+++ b/applications/ColossalChat/examples/training_scripts/train_ppo.py
@@ -0,0 +1,506 @@
+import argparse
+import json
+import os
+import resource
+from contextlib import nullcontext
+
+import torch
+import torch.distributed as dist
+from coati.dataset import (
+ DataCollatorForPromptDataset,
+ DataCollatorForSupervisedDataset,
+ StatefulDistributedSampler,
+ load_tokenized_dataset,
+ setup_conversation_template,
+ setup_distributed_dataloader,
+)
+from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout
+from coati.trainer import PPOTrainer
+from coati.utils import load_checkpoint
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.logging import get_dist_logger
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.nn.optimizer import HybridAdam
+
+logger = get_dist_logger()
+
+
+def train(args):
+ # check lora compatibility
+ if "gemini" in args.plugin and args.lora_rank > 0:
+        raise ValueError("LoRA is not supported in GeminiPlugin. Please use another plugin")
+ if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
+        raise ValueError("Gradient accumulation is not supported with the gemini_auto plugin. Please use another plugin")
+ # ==============================
+ # Initialize Distributed Training
+ # ==============================
+ colossalai.launch_from_torch({})
+ coordinator = DistCoordinator()
+
+ # ======================================================
+ # Initialize Model, Objective, Optimizer and LR Scheduler
+ # ======================================================
+ # Temp Fix: Disable lazy init due to version conflict
+ # init_ctx = (
+ # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
+ # )
+
+ init_ctx = nullcontext()
+ booster_policy = None
+ with init_ctx:
+ if args.use_flash_attn:
+ actor = AutoModelForCausalLM.from_pretrained(
+ args.pretrain,
+ torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+ use_flash_attention_2=True,
+ local_files_only=True,
+ )
+ ref_model = AutoModelForCausalLM.from_pretrained(
+ args.pretrain,
+ torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+ use_flash_attention_2=True,
+ local_files_only=True,
+ )
+ reward_model = RewardModel(
+ args.rm_pretrain,
+ torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+ use_flash_attention_2=True,
+ )
+ critic = Critic(
+ args.rm_pretrain,
+ torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+ use_flash_attention_2=True,
+ )
+ coordinator.print_on_master(msg="Flash-attention enabled successfully")
+ else:
+ actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
+ ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
+ reward_model = RewardModel(args.rm_pretrain)
+ critic = Critic(args.rm_pretrain)
+ # Disable dropout
+ disable_dropout(actor)
+ disable_dropout(critic)
+
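+    # Tensor parallelism shards the reward/critic models, which requires a shardformer policy matching their architecture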
+ if args.tp > 1:
+ if reward_model.model.config.architectures[0] != critic.model.config.architectures[0]:
+ raise ValueError("Reward model and critic model must have the same architecture")
+ if reward_model.model.config.architectures[0] == "BloomForCausalLM":
+ from colossalai.shardformer.policies.bloom import BloomPolicy
+
+ booster_policy = BloomPolicy()
+ elif reward_model.model.config.architectures[0] == "LlamaForCausalLM":
+ from colossalai.shardformer.policies.llama import LlamaPolicy
+
+ booster_policy = LlamaPolicy()
+ elif reward_model.model.config.architectures[0] == "GPT2LMHeadModel":
+ from colossalai.shardformer.policies.gpt2 import GPT2Policy
+
+ booster_policy = GPT2Policy()
+ elif reward_model.model.config.architectures[0] == "ChatGLMModel":
+ from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
+
+ booster_policy = ChatGLMPolicy()
+ elif reward_model.model.config.architectures[0] == "OPTForCausalLM":
+ from colossalai.shardformer.policies.opt import OPTPolicy
+
+ booster_policy = OPTPolicy()
+ else:
+ raise ValueError("Unknown model architecture for policy")
+
+ if args.lora_rank > 0:
+ actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
+ critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
+
+ if args.grad_checkpoint and args.lora_rank == 0:
+ actor.gradient_checkpointing_enable()
+ critic.model.gradient_checkpointing_enable()
+ coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
+ elif args.lora_rank > 0:
+ coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
+
+ # configure tokenizer
+ tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
+ if os.path.exists(args.conversation_template_config):
+ with open(args.conversation_template_config, "r", encoding="utf8") as f:
+ conversation_template_config = json.load(f)
+ dist.barrier()
+ conversation_template = setup_conversation_template(
+ tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config
+ )
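+        # Token ids that should terminate generation during experience collection.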
+ stop_ids = conversation_template.stop_ids if len(conversation_template.stop_ids) > 0 else None
+ else:
+ raise ValueError("Conversation template config is not provided or incorrect")
+ if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
+ try:
+            # Some tokenizers do not allow setting pad_token manually, e.g., Qwen
+            tokenizer.pad_token = tokenizer.eos_token
+        except AttributeError as e:
+            logger.warning(f"Unable to set pad token to eos token, {str(e)}")
+    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
+        logger.warning(
+            "The tokenizer does not have a pad token, which is required. This may lead to unintended behavior during training; please consider setting it manually."
+ )
+
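+    # The datasets are pre-tokenized, so keep the tokenizer from adding BOS/EOS again when collating.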
+ tokenizer.add_bos_token = False
+ tokenizer.add_eos_token = False
+ tokenizer.padding_side = "left" # left padding for generation (online learning)
+
+ # configure generation config
+ actor.generation_config.update(
+ pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id
+ )
+
+ # configure optimizer
+ coordinator.print_on_master(f"setting up optimizer for actor: lr={args.lr}, weight_decay={args.weight_decay}")
+ actor_optim = HybridAdam(
+ model_params=actor.parameters(),
+ lr=args.lr,
+ betas=(0.9, 0.95),
+ weight_decay=args.weight_decay,
+ adamw_mode=True,
+ )
+
+ coordinator.print_on_master(f"setting up optimizer for critic: lr={args.lr}, weight_decay={args.weight_decay}")
+ critic_optim = HybridAdam(
+ model_params=critic.parameters(),
+ lr=args.critic_lr,
+ betas=(0.9, 0.95),
+ weight_decay=args.weight_decay,
+ adamw_mode=True,
+ )
+
+ # configure dataset
+ coordinator.print_on_master(f"Load dataset: {args.prompt_dataset}")
+ mode_map = {"train": "train", "valid": "validation", "test": "test"}
+ train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map)
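+    # Reserve max_seq_len tokens of the context window for the generated response.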
+ data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len)
+ train_prompt_dataloader = setup_distributed_dataloader(
+ dataset=train_prompt_dataset,
+ batch_size=args.experience_batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=data_collator,
+ use_tp=args.tp > 1,
+ )
+
+ if len(args.ptx_dataset) > 0:
+ train_ptx_dataset = load_tokenized_dataset(dataset_paths=args.ptx_dataset, mode="train", mode_map=mode_map)
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
+ train_pretrain_dataloader = setup_distributed_dataloader(
+ dataset=train_ptx_dataset,
+ batch_size=args.ptx_batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=data_collator,
+ use_tp=args.tp > 1,
+ )
+ else:
+ train_pretrain_dataloader = None
+
+ if args.warmup_steps is None:
+ args.warmup_steps = int(0.025 * args.num_episodes)
+ coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
+
+ actor_lr_scheduler = CosineAnnealingWarmupLR(
+ optimizer=actor_optim,
+ total_steps=args.num_episodes,
+ warmup_steps=args.warmup_steps,
+ eta_min=0.1 * args.lr,
+ )
+
+ critic_lr_scheduler = CosineAnnealingWarmupLR(
+ optimizer=critic_optim,
+ total_steps=args.num_episodes,
+ warmup_steps=args.warmup_steps,
+        eta_min=0.1 * args.critic_lr,
+ )
+
+ # ==============================
+ # Initialize Booster
+ # ==============================
+ if args.plugin == "ddp":
+ """
+        Default torch DDP plugin without any acceleration; used for debugging.
+ """
+ plugin = TorchDDPPlugin(find_unused_parameters=True)
+ elif args.plugin == "gemini":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ placement_policy="static",
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ enable_gradient_accumulation=True,
+ )
+ elif args.plugin == "gemini_auto":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ placement_policy="auto",
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2_cpu":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ cpu_offload=True,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "3d":
+ plugin = HybridParallelPlugin(
+ tp_size=args.tp,
+ pp_size=1,
+ zero_stage=0,
+ parallel_output=False,
+ precision=args.mixed_precision,
+ )
+ custom_plugin = HybridParallelPlugin(
+ tp_size=args.tp,
+ pp_size=1,
+ zero_stage=0,
+ parallel_output=False,
+ precision=args.mixed_precision,
+ custom_policy=booster_policy,
+ )
+ else:
+ raise ValueError(f"Unknown plugin {args.plugin}")
+
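+    # Outside tensor parallelism, the reward and critic models can share the actor's plugin.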
+ if args.plugin != "3d":
+ custom_plugin = plugin
+
+ actor_booster = Booster(plugin=plugin)
+ ref_booster = Booster(plugin=plugin)
+ rm_booster = Booster(plugin=custom_plugin)
+ critic_booster = Booster(plugin=custom_plugin)
+
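+    # Boost the models under a half-precision default dtype, then restore fp32 afterwards.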
+ default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
+ torch.set_default_dtype(default_dtype)
+ actor, actor_optim, _, train_prompt_dataloader, actor_lr_scheduler = actor_booster.boost(
+ model=actor,
+ optimizer=actor_optim,
+ lr_scheduler=actor_lr_scheduler,
+ dataloader=train_prompt_dataloader,
+ )
+
+ critic, critic_optim, _, _, critic_lr_scheduler = critic_booster.boost(
+ model=critic,
+ optimizer=critic_optim,
+ lr_scheduler=critic_lr_scheduler,
+ dataloader=train_prompt_dataloader,
+ )
+ reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
+ ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader)
+
+ torch.set_default_dtype(torch.float)
+
+ coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+ coordinator.print_on_master(
+ f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ sampler_start_idx = 0
+ start_step = 0
+
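+    # Checkpoint paths containing "modeling" are model-only saves (booster.save_model);
+    # other paths are full training checkpoints with optimizer and scheduler states.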
+ if args.rm_checkpoint_path is not None:
+ if "modeling" in args.rm_checkpoint_path:
+ rm_booster.load_model(reward_model, args.rm_checkpoint_path)
+ else:
+ _, _, _ = load_checkpoint(
+ load_dir=args.rm_checkpoint_path,
+ booster=rm_booster,
+ model=reward_model,
+ optimizer=None,
+ lr_scheduler=None,
+ )
+ coordinator.print_on_master(f"Loaded reward model checkpoint {args.rm_checkpoint_path}")
+
+ if args.checkpoint_path is not None:
+ if "modeling" in args.checkpoint_path:
+ actor_booster.load_model(actor, args.checkpoint_path)
+ ref_booster.load_model(ref_model, args.checkpoint_path)
+ coordinator.print_on_master(f"Loaded actor and reference model {args.checkpoint_path}")
+ else:
+ _, start_step, sampler_start_idx = load_checkpoint(
+ load_dir=args.checkpoint_path,
+ booster=actor_booster,
+ model=actor,
+ optimizer=actor_optim,
+ lr_scheduler=actor_lr_scheduler,
+ )
+            # The reference model is frozen; only its weights need to be restored.
+            _, _, _ = load_checkpoint(
+                load_dir=args.checkpoint_path,
+                booster=ref_booster,
+                model=ref_model,
+                optimizer=None,
+                lr_scheduler=None,
+ )
+ assert isinstance(train_prompt_dataloader.sampler, StatefulDistributedSampler)
+ train_prompt_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
+
+ coordinator.print_on_master(
+ f"Loaded actor and reference model checkpoint {args.checkpoint_path} at spisode {start_step}"
+ )
+ coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
+
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ if args.critic_checkpoint_path is not None:
+ if "modeling" in args.critic_checkpoint_path:
+ critic_booster.load_model(critic, args.critic_checkpoint_path)
+ else:
+ _, _, _ = load_checkpoint(
+ load_dir=args.critic_checkpoint_path,
+ booster=critic_booster,
+ model=critic,
+ optimizer=critic_optim,
+ lr_scheduler=critic_lr_scheduler,
+ )
+ coordinator.print_on_master(f"Loaded critic checkpoint {args.critic_checkpoint_path}")
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ # configure trainer
+ trainer = PPOTrainer(
+ actor_booster,
+ critic_booster,
+ actor,
+ critic,
+ reward_model,
+ ref_model,
+ actor_optim,
+ critic_optim,
+ actor_lr_scheduler,
+ critic_lr_scheduler,
+ tokenizer=tokenizer,
+ stop_token_ids=stop_ids,
+ kl_coef=args.kl_coef,
+ ptx_coef=args.ptx_coef,
+ train_batch_size=args.train_batch_size,
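+        # The replay buffer holds exactly one collection phase worth of experience.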
+ buffer_limit=args.num_collect_steps * args.experience_batch_size,
+ max_length=args.max_length,
+ max_new_tokens=args.max_seq_len,
+ use_cache=True,
+ do_sample=True,
+ temperature=0.7,
+ accumulation_steps=args.accumulation_steps,
+ save_dir=args.save_path,
+ save_interval=args.save_interval,
+ top_k=50,
+ use_tp=args.tp > 1,
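+        # Offload the frozen reward/reference models between rollouts, except under Gemini,
+        # which manages parameter placement itself.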
+ offload_inference_models="gemini" not in args.plugin,
+ coordinator=coordinator,
+ )
+
+ trainer.fit(
+ num_episodes=args.num_episodes,
+ num_collect_steps=args.num_collect_steps,
+ num_update_steps=args.num_update_steps,
+ prompt_dataloader=train_prompt_dataloader,
+ pretrain_dataloader=train_pretrain_dataloader,
+ log_dir=args.log_dir,
+ use_wandb=args.use_wandb,
+ )
+
+ if args.lora_rank > 0 and args.merge_lora_weights:
+ from coati.models.lora import LORA_MANAGER
+
+ # NOTE: set model to eval to merge LoRA weights
+ LORA_MANAGER.merge_weights = True
+ actor.eval()
+ critic.eval()
+ # save model checkpoint after fitting on only rank0
+ coordinator.print_on_master("Start saving final actor model checkpoint")
+ actor_booster.save_model(actor, os.path.join(trainer.actor_save_dir, "modeling"), shard=True)
+ coordinator.print_on_master(
+ f"Saved final actor model checkpoint at episodes {args.num_episodes} at folder {args.save_path}"
+ )
+ coordinator.print_on_master("Start saving final critic model checkpoint")
+ critic_booster.save_model(critic, os.path.join(trainer.critic_save_dir, "modeling"), shard=True)
+ coordinator.print_on_master(
+ f"Saved final critic model checkpoint at episodes {args.num_episodes} at folder {args.save_path}"
+ )
+ coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--prompt_dataset", nargs="+", default=[])
+ parser.add_argument("--ptx_dataset", nargs="+", default=[])
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="gemini",
+ choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
+ help="Choose which plugin to use",
+ )
+ parser.add_argument(
+ "--conversation_template_config",
+ type=str,
+ default=None,
+ help="Path \
+ to save conversation template config files.",
+ )
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
+ parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
+ parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
+ parser.add_argument("--tokenizer_dir", type=str, default=None)
+ parser.add_argument("--tp", type=int, default=1)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--rm_pretrain", type=str, default=None)
+ parser.add_argument("--checkpoint_path", type=str, default=None)
+ parser.add_argument("--critic_checkpoint_path", type=str, default=None)
+ parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path")
+ parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
+ parser.add_argument("--num_episodes", type=int, default=1)
+ parser.add_argument("--num_collect_steps", type=int, default=2)
+ parser.add_argument("--num_update_steps", type=int, default=5)
+ parser.add_argument("--save_interval", type=int, default=1000)
+ parser.add_argument("--train_batch_size", type=int, default=16)
+ parser.add_argument("--experience_batch_size", type=int, default=16)
+ parser.add_argument("--ptx_batch_size", type=int, default=4)
+ parser.add_argument("--lora_train_bias", type=str, default="none")
+ parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
+ parser.add_argument("--accumulation_steps", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument("--merge_lora_weights", type=bool, default=True)
+ parser.add_argument("--lr", type=float, default=9e-6)
+ parser.add_argument("--critic_lr", type=float, default=9e-6)
+ parser.add_argument("--kl_coef", type=float, default=0.1)
+ parser.add_argument("--ptx_coef", type=float, default=0.0)
+ parser.add_argument("--max_length", type=int, default=2048)
+ parser.add_argument("--max_seq_len", type=int, default=256)
+ parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--use_wandb", default=False, action="store_true")
+ parser.add_argument("--grad_checkpoint", default=False, action="store_true")
+ parser.add_argument("--use_flash_attn", default=False, action="store_true")
+ args = parser.parse_args()
+ train(args)
diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.sh b/applications/ColossalChat/examples/training_scripts/train_ppo.sh
new file mode 100755
index 000000000000..91633978e6ff
--- /dev/null
+++ b/applications/ColossalChat/examples/training_scripts/train_ppo.sh
@@ -0,0 +1,82 @@
+#!/bin/bash
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+set_n_least_used_CUDA_VISIBLE_DEVICES 8
+
+PROJECT_NAME="ppo"
+
+PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
+PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
+PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
+PRETRAINED_MODEL_PATH="" # local pretrained model path (from RLHF step 1: SFT)
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
+REWARD_MODEL_PATH="" # local reward model path (from RLHF step 2: Train Reward Model)
+CONVERSATION_TEMPLATE_CONFIG_PATH="" # path to the conversation config file
+
+declare -a prompt_dataset=(
+ YOUR/PROMPT/DATA/DIR/arrow/part-00000
+ YOUR/PROMPT/DATA/DIR/arrow/part-00001
+ YOUR/PROMPT/DATA/DIR/arrow/part-00002
+ YOUR/PROMPT/DATA/DIR/arrow/part-00003
+ YOUR/PROMPT/DATA/DIR/arrow/part-00004
+ YOUR/PROMPT/DATA/DIR/arrow/part-00005
+ YOUR/PROMPT/DATA/DIR/arrow/part-00006
+ YOUR/PROMPT/DATA/DIR/arrow/part-00007
+ YOUR/PROMPT/DATA/DIR/arrow/part-00008
+ YOUR/PROMPT/DATA/DIR/arrow/part-00009
+)
+
+declare -a ptx_dataset=(
+ YOUR/SFT/DATA/DIR/arrow/part-00000
+ YOUR/SFT/DATA/DIR/arrow/part-00001
+ YOUR/SFT/DATA/DIR/arrow/part-00002
+ YOUR/SFT/DATA/DIR/arrow/part-00003
+ YOUR/SFT/DATA/DIR/arrow/part-00004
+ YOUR/SFT/DATA/DIR/arrow/part-00005
+ YOUR/SFT/DATA/DIR/arrow/part-00006
+ YOUR/SFT/DATA/DIR/arrow/part-00007
+ YOUR/SFT/DATA/DIR/arrow/part-00008
+ YOUR/SFT/DATA/DIR/arrow/part-00009
+)
+
+TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
+FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
+SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
+CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+
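+# Each episode collects num_collect_steps batches of experience_batch_size prompts,
+# then runs num_update_steps PPO update passes over the collected buffer.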
+colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_ppo.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --rm_pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --rm_checkpoint_path $REWARD_MODEL_PATH \
+ --prompt_dataset ${prompt_dataset[@]} \
+ --conversation_template_config $CONVERSATION_TEMPLATE_CONFIG_PATH \
+ --ptx_coef 0.0 \
+ --plugin "zero2" \
+ --save_interval 500 \
+ --save_path $SAVE_DIR \
+ --num_episodes 2000 \
+ --num_collect_steps 2 \
+ --num_update_steps 1 \
+ --experience_batch_size 4 \
+ --train_batch_size 4 \
+ --accumulation_steps 2 \
+ --lr 9e-6 \
+ --mixed_precision "bf16" \
+    --grad_clip 0.1 \
+ --weight_decay 0.01 \
+ --warmup_steps 40 \
+ --grad_checkpoint \
+ --use_wandb
diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.py b/applications/ColossalChat/examples/training_scripts/train_rm.py
new file mode 100755
index 000000000000..a0c710f2bb7f
--- /dev/null
+++ b/applications/ColossalChat/examples/training_scripts/train_rm.py
@@ -0,0 +1,342 @@
+import argparse
+import json
+import math
+import os
+import resource
+from contextlib import nullcontext
+
+import torch
+from coati.dataset import (
+ DataCollatorForPreferenceDataset,
+ StatefulDistributedSampler,
+ load_tokenized_dataset,
+ setup_distributed_dataloader,
+)
+from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
+from coati.trainer import RewardModelTrainer
+from coati.utils import load_checkpoint
+from transformers import AutoTokenizer
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.logging import get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
+
+logger = get_dist_logger()
+
+
+def train(args):
+ # check lora compatibility
+ if "gemini" in args.plugin and args.lora_rank > 0:
+ raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
+ if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
+ raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
+ # ==============================
+ # Initialize Distributed Training
+ # ==============================
+ colossalai.launch_from_torch({})
+ coordinator = DistCoordinator()
+
+ # ======================================================
+ # Initialize Model, Objective, Optimizer and LR Scheduler
+ # ======================================================
+ # Temp Fix: Disable lazy init due to version conflict
+ # init_ctx = (
+ # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
+ # )
+
+ init_ctx = nullcontext()
+ booster_policy = None
+ with init_ctx:
+ if args.use_flash_attn:
+ model = RewardModel(
+ args.pretrain,
+ torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+ use_flash_attention_2=True,
+ )
+ coordinator.print_on_master(msg="Flash-attention enabled successfully")
+ else:
+ model = RewardModel(args.pretrain)
+
+ if args.tp > 1:
+ if model.model.config.architectures[0] == "BloomForCausalLM":
+ from colossalai.shardformer.policies.bloom import BloomPolicy
+
+ booster_policy = BloomPolicy()
+ elif model.model.config.architectures[0] == "LlamaForCausalLM":
+ from colossalai.shardformer.policies.llama import LlamaPolicy
+
+ booster_policy = LlamaPolicy()
+ elif model.model.config.architectures[0] == "GPT2LMHeadModel":
+ from colossalai.shardformer.policies.gpt2 import GPT2Policy
+
+ booster_policy = GPT2Policy()
+ elif model.model.config.architectures[0] == "ChatGLMModel":
+ from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
+
+ booster_policy = ChatGLMPolicy()
+ elif model.model.config.architectures[0] == "OPTForCausalLM":
+ from colossalai.shardformer.policies.opt import OPTPolicy
+
+ booster_policy = OPTPolicy()
+ else:
+ raise ValueError("Unknown model architecture for policy")
+
+ if args.lora_rank > 0:
+ model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
+
+ # ==============================
+ # Initialize Booster
+ # ==============================
+ if args.plugin == "ddp":
+ """
+        Default torch DDP plugin without any acceleration; used for debugging.
+ """
+ plugin = TorchDDPPlugin(find_unused_parameters=True)
+ elif args.plugin == "gemini":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ placement_policy="static",
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ enable_gradient_accumulation=True,
+ )
+ elif args.plugin == "gemini_auto":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ placement_policy="auto",
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2_cpu":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ cpu_offload=True,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "3d":
+ plugin = HybridParallelPlugin(
+ tp_size=args.tp,
+ pp_size=1,
+ zero_stage=0,
+ parallel_output=False,
+ precision=args.mixed_precision,
+ custom_policy=booster_policy,
+ )
+ else:
+ raise ValueError(f"Unknown plugin {args.plugin}")
+
+ booster = Booster(plugin=plugin)
+
+ if args.grad_checkpoint and args.lora_rank == 0:
+ model.model.gradient_checkpointing_enable() # TODO: support gradient checkpoint for the last linear layer
+ coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
+ elif args.lora_rank > 0:
+ coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
+
+ # configure tokenizer
+ tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
+ if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
+ try:
+            # Some tokenizers do not allow setting pad_token manually, e.g., Qwen
+            tokenizer.pad_token = tokenizer.eos_token
+        except AttributeError as e:
+            logger.warning(f"Unable to set pad token to eos token, {str(e)}")
+    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
+        logger.warning(
+            "The tokenizer does not have a pad token, which is required. This may lead to unintended behavior during training; please consider setting it manually."
+ )
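+    # Right padding for reward modeling; generation-style left padding is not needed here.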
+ tokenizer.padding_side = "right"
+ tokenizer.add_bos_token = False
+ tokenizer.add_eos_token = False
+
+ # configure loss function
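+    # Both options are pairwise ranking losses over (chosen, rejected) reward pairs.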
+ if args.loss_fn == "log_sig":
+ loss_fn = LogSigLoss()
+ elif args.loss_fn == "log_exp":
+ loss_fn = LogExpLoss()
+ else:
+ raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
+
+ # configure optimizer
+ optim = HybridAdam(
+ model_params=model.parameters(),
+ lr=args.lr,
+ betas=(0.9, 0.95),
+ weight_decay=args.weight_decay,
+ adamw_mode=True,
+ )
+
+ # configure dataset
+ coordinator.print_on_master(f"Load dataset: {args.dataset}")
+ mode_map = {"train": "train", "valid": "validation", "test": "test"}
+ train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
+ data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
+ train_dataloader = setup_distributed_dataloader(
+ dataset=train_dataset,
+ batch_size=args.batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=data_collator,
+ use_tp=args.tp > 1,
+ )
+
+    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
+    total_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
+
+    if args.warmup_steps is None:
+        args.warmup_steps = int(0.025 * total_steps)
+        coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
+
+ lr_scheduler = CosineAnnealingWarmupLR(
+ optimizer=optim,
+        total_steps=total_steps,
+ warmup_steps=args.warmup_steps,
+ eta_min=0.1 * args.lr,
+ )
+
+ default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
+ torch.set_default_dtype(default_dtype)
+ model, optim, _, train_dataloader, lr_scheduler = booster.boost(
+ model=model,
+ optimizer=optim,
+ lr_scheduler=lr_scheduler,
+ dataloader=train_dataloader,
+ )
+ torch.set_default_dtype(torch.float)
+
+ coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+ coordinator.print_on_master(
+ f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ start_epoch = 0
+ sampler_start_idx = 0
+ start_step = 0
+ if args.checkpoint_path is not None:
+ if "modeling" in args.checkpoint_path:
+ coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
+ booster.load_model(model, args.checkpoint_path)
+ else:
+ coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
+ start_epoch, start_step, sampler_start_idx = load_checkpoint(
+ load_dir=args.checkpoint_path,
+ booster=booster,
+ model=model,
+ optimizer=optim,
+ lr_scheduler=lr_scheduler,
+ )
+ assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
+ train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
+
+ coordinator.print_on_master(
+ f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
+ )
+ coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
+
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ trainer = RewardModelTrainer(
+ model,
+ booster,
+ optim,
+ lr_scheduler,
+ tokenizer,
+ loss_fn=loss_fn,
+ max_epochs=args.max_epochs,
+ accumulation_steps=args.accumulation_steps,
+ start_epoch=start_epoch,
+ save_interval=args.save_interval,
+ save_dir=args.save_dir,
+ coordinator=coordinator,
+ )
+
+ trainer.fit(
+ train_preference_dataloader=train_dataloader,
+ eval_preference_dataloader=None,
+ log_dir=args.log_dir,
+ use_wandb=args.use_wandb,
+ )
+
+ if args.lora_rank > 0 and args.merge_lora_weights:
+ from coati.models.lora import LORA_MANAGER
+
+ # NOTE: set model to eval to merge LoRA weights
+ LORA_MANAGER.merge_weights = True
+ model.eval()
+ # save model checkpoint after fitting on only rank0
+ coordinator.print_on_master("Start saving final model checkpoint")
+ booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
+ coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
+
+ coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+
+
+if __name__ == "__main__":
+ # ==============================
+ # Parse Arguments
+ # ==============================
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="gemini",
+ choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"],
+ help="Choose which plugin to use",
+ )
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
+ parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
+ parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
+ parser.add_argument("--tp", type=int, default=1)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--tokenizer_dir", type=str, default=None)
+ parser.add_argument("--dataset", nargs="+", default=[])
+ parser.add_argument(
+ "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
+ )
+ parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
+ parser.add_argument("--save_dir", type=str, default="output")
+ parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
+ parser.add_argument("--max_epochs", type=int, default=3)
+ parser.add_argument("--batch_size", type=int, default=4)
+ parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
+ parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"], help="Loss function")
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument(
+ "--lora_train_bias",
+ type=str,
+ default="none",
+ help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
+ )
+ parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
+ parser.add_argument("--merge_lora_weights", type=bool, default=True)
+ parser.add_argument("--lr", type=float, default=5e-6)
+ parser.add_argument("--accumulation_steps", type=int, default=8)
+ parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--use_wandb", default=False, action="store_true")
+ parser.add_argument("--grad_checkpoint", default=False, action="store_true")
+ parser.add_argument("--use_flash_attn", default=False, action="store_true")
+ args = parser.parse_args()
+ os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
+ with open(args.config_file, "w") as f:
+ json.dump(args.__dict__, f, indent=4)
+ train(args)
diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.sh b/applications/ColossalChat/examples/training_scripts/train_rm.sh
new file mode 100755
index 000000000000..e06d9092fe4c
--- /dev/null
+++ b/applications/ColossalChat/examples/training_scripts/train_rm.sh
@@ -0,0 +1,61 @@
+#!/bin/bash
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+set_n_least_used_CUDA_VISIBLE_DEVICES 8
+
+PROJECT_NAME="rm"
+PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
+PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
+PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
+PRETRAINED_MODEL_PATH="" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
+
+declare -a dataset=(
+ YOUR/PREFERENCE/DATA/DIR/arrow/part-00000
+ YOUR/PREFERENCE/DATA/DIR/arrow/part-00001
+ YOUR/PREFERENCE/DATA/DIR/arrow/part-00002
+ YOUR/PREFERENCE/DATA/DIR/arrow/part-00003
+ YOUR/PREFERENCE/DATA/DIR/arrow/part-00004
+ YOUR/PREFERENCE/DATA/DIR/arrow/part-00005
+ YOUR/PREFERENCE/DATA/DIR/arrow/part-00006
+ YOUR/PREFERENCE/DATA/DIR/arrow/part-00007
+ YOUR/PREFERENCE/DATA/DIR/arrow/part-00008
+ YOUR/PREFERENCE/DATA/DIR/arrow/part-00009
+)
+
+TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
+FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
+SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
+CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+
+colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_rm.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
+ --plugin "zero2" \
+ --save_interval 1000 \
+ --save_dir $SAVE_DIR \
+ --config_file $CONFIG_FILE \
+ --max_epochs 3 \
+ --accumulation_steps 1 \
+ --batch_size 8 \
+ --lr 5e-6 \
+ --mixed_precision "bf16" \
+ --grad_clip 1.0 \
+ --weight_decay 0.01 \
+ --warmup_steps 40 \
+ --grad_checkpoint \
+ --use_wandb
diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py
new file mode 100755
index 000000000000..fcd1a429cc5f
--- /dev/null
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.py
@@ -0,0 +1,311 @@
+import argparse
+import json
+import math
+import os
+import resource
+from contextlib import nullcontext
+
+import torch
+from coati.dataset import DataCollatorForSupervisedDataset, load_tokenized_dataset, setup_distributed_dataloader
+from coati.models import convert_to_lora_module
+from coati.trainer import SFTTrainer
+from coati.utils import load_checkpoint
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.logging import get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
+
+logger = get_dist_logger()
+
+
+def train(args):
+ # check lora compatibility
+ if "gemini" in args.plugin and args.lora_rank > 0:
+ raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
+ if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
+ raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
+ # ==============================
+ # Initialize Distributed Training
+ # ==============================
+ colossalai.launch_from_torch({})
+ coordinator = DistCoordinator()
+
+ # ==============================
+ # Initialize Booster
+ # ==============================
+ if args.plugin == "ddp":
+ """
+        Default torch DDP plugin without any acceleration; used for debugging.
+ """
+ plugin = TorchDDPPlugin(find_unused_parameters=True)
+ elif args.plugin == "gemini":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ placement_policy="static",
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ enable_gradient_accumulation=True,
+ )
+ elif args.plugin == "gemini_auto":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ placement_policy="auto",
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2_cpu":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ cpu_offload=True,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "3d":
+ plugin = HybridParallelPlugin(
+ tp_size=args.tp,
+ pp_size=1,
+ zero_stage=0,
+ parallel_output=False,
+ max_norm=args.grad_clip,
+ precision=args.mixed_precision,
+ )
+ else:
+ raise ValueError(f"Unknown plugin {args.plugin}")
+
+ booster = Booster(plugin=plugin)
+
+ # ======================================================
+ # Initialize Model, Objective, Optimizer and LR Scheduler
+ # ======================================================
+ # Temp Fix: Disable lazy init due to version conflict
+ # init_ctx = (
+ # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
+ # )
+
+ init_ctx = nullcontext()
+ with init_ctx:
+ if args.use_flash_attn:
+ model = AutoModelForCausalLM.from_pretrained(
+ args.pretrain,
+ torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+ use_flash_attention_2=True,
+ )
+ coordinator.print_on_master(msg="Flash-attention enabled successfully")
+ else:
+ model = AutoModelForCausalLM.from_pretrained(args.pretrain)
+ if args.lora_rank > 0:
+ model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
+
+ if args.grad_checkpoint and args.lora_rank == 0:
+ # lora layers are not supported by gradient checkpointing
+ model.gradient_checkpointing_enable()
+ coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
+ elif args.lora_rank > 0:
+ coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
+
+ # configure tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.tokenizer_dir or args.pretrain, use_fast=False, trust_remote_code=True
+ )
+ if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
+ try:
+            # Some tokenizers do not allow setting pad_token manually, e.g., Qwen
+            tokenizer.pad_token = tokenizer.eos_token
+        except AttributeError as e:
+            logger.warning(f"Unable to set pad token to eos token, {str(e)}")
+    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
+        logger.warning(
+            "The tokenizer does not have a pad token, which is required. This may lead to unintended behavior during training; please consider setting it manually."
+ )
+
+ tokenizer.add_bos_token = False
+ tokenizer.add_eos_token = False
+
+ coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
+ coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}")
+
+ # configure optimizer
+ optim = HybridAdam(
+ model_params=model.parameters(),
+ lr=args.lr,
+ betas=(0.9, 0.95),
+ weight_decay=args.weight_decay,
+ adamw_mode=True,
+ )
+
+ # configure dataset
+ coordinator.print_on_master(
+ f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_len)
+ train_dataloader = setup_distributed_dataloader(
+ dataset=dataset,
+ batch_size=args.batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=data_collator,
+ use_tp=args.tp > 1,
+ )
+ coordinator.print_on_master(
+ f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+
+    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
+    total_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
+
+    if args.warmup_steps is None:
+        args.warmup_steps = int(0.025 * total_steps)
+        coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
+
+ lr_scheduler = CosineAnnealingWarmupLR(
+ optimizer=optim,
+        total_steps=total_steps,
+ warmup_steps=args.warmup_steps,
+ eta_min=0.1 * args.lr,
+ )
+
+    # Set the default dtype to half precision before boosting; flash attention does not support fp32.
+ default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
+ torch.set_default_dtype(default_dtype)
+ model, optim, _, train_dataloader, lr_scheduler = booster.boost(
+ model=model,
+ optimizer=optim,
+ lr_scheduler=lr_scheduler,
+ dataloader=train_dataloader,
+ )
+ # model = model.to(get_current_device())
+ torch.set_default_dtype(torch.float)
+
+ coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+ coordinator.print_on_master(
+ f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ start_epoch = 0
+ sampler_start_idx = 0
+ start_step = 0
+ if args.checkpoint_path is not None:
+ if "modeling" in args.checkpoint_path:
+ coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
+ booster.load_model(model, args.checkpoint_path)
+ else:
+ coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
+ start_epoch, start_step, sampler_start_idx = load_checkpoint(
+ load_dir=args.checkpoint_path,
+ booster=booster,
+ model=model,
+ optimizer=optim,
+ lr_scheduler=lr_scheduler,
+ )
+ train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
+
+ coordinator.print_on_master(
+ f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
+ )
+ coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
+
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ trainer = SFTTrainer(
+ model=model,
+ booster=booster,
+ optim=optim,
+ lr_scheduler=lr_scheduler,
+ max_epochs=args.max_epochs,
+ accumulation_steps=args.accumulation_steps,
+ start_epoch=start_epoch,
+ save_interval=args.save_interval,
+ save_dir=args.save_path,
+ coordinator=coordinator,
+ )
+
+ trainer.fit(
+ train_dataloader=train_dataloader,
+ eval_dataloader=None,
+ log_dir=args.log_dir,
+ use_wandb=args.use_wandb,
+ )
+
+ if args.lora_rank > 0 and args.merge_lora_weights:
+ from coati.models.lora import LORA_MANAGER
+
+ # NOTE: set model to eval to merge LoRA weights
+ LORA_MANAGER.merge_weights = True
+ model.eval()
+ # save model checkpoint after fitting on only rank0
+ coordinator.print_on_master("Start saving final model checkpoint")
+
+ booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
+ coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}")
+
+ coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+
+
+if __name__ == "__main__":
+ # ==============================
+ # Parse Arguments
+ # ==============================
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="gemini",
+ choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"],
+ help="Choose which plugin to use",
+ )
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
+ parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
+ parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
+ parser.add_argument("--tp", type=int, default=1)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--tokenizer_dir", type=str, default=None)
+ parser.add_argument("--dataset", nargs="+", default=[])
+ parser.add_argument(
+ "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
+ )
+ parser.add_argument("--save_path", type=str, default="output")
+ parser.add_argument("--max_epochs", type=int, default=3)
+ parser.add_argument("--batch_size", type=int, default=4)
+ parser.add_argument("--max_len", type=int, default=512)
+ parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument(
+ "--lora_train_bias",
+ type=str,
+ default="none",
+ help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
+ )
+ parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
+ parser.add_argument("--merge_lora_weights", type=bool, default=True)
+ parser.add_argument("--lr", type=float, default=5e-6)
+ parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
+ parser.add_argument("--accumulation_steps", type=int, default=8)
+ parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--use_wandb", default=False, action="store_true")
+ parser.add_argument("--grad_checkpoint", default=False, action="store_true")
+ parser.add_argument("--use_flash_attn", default=False, action="store_true")
+ args = parser.parse_args()
+ os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
+ with open(args.config_file, "w") as f:
+ json.dump(args.__dict__, f, indent=4)
+ train(args)
diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.sh b/applications/ColossalChat/examples/training_scripts/train_sft.sh
new file mode 100755
index 000000000000..d5c394377616
--- /dev/null
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+
+# export CUDA_VISIBLE_DEVICES=4,5,6
+set_n_least_used_CUDA_VISIBLE_DEVICES 4
+PROJECT_NAME="sft"
+PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
+PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
+PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
+PRETRAINED_MODEL_PATH="" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
+declare -a dataset=(
+ YOUR/SFT/DATA/DIR/arrow/part-00000
+ YOUR/SFT/DATA/DIR/arrow/part-00001
+ YOUR/SFT/DATA/DIR/arrow/part-00002
+ YOUR/SFT/DATA/DIR/arrow/part-00003
+ YOUR/SFT/DATA/DIR/arrow/part-00004
+ YOUR/SFT/DATA/DIR/arrow/part-00005
+ YOUR/SFT/DATA/DIR/arrow/part-00006
+ YOUR/SFT/DATA/DIR/arrow/part-00007
+ YOUR/SFT/DATA/DIR/arrow/part-00008
+ YOUR/SFT/DATA/DIR/arrow/part-00009
+)
+
+TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
+FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
+SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
+CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+
+# The effective batch size for gradient descent is number_of_nodes_in_hostfile * nproc_per_node * batch_size * accumulation_steps
+colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --save_interval 4000 \
+ --dataset ${dataset[@]} \
+ --save_path $SAVE_DIR \
+ --config_file $CONFIG_FILE \
+ --lora_rank 0 \
+ --plugin zero2 \
+ --batch_size 8 \
+ --max_epochs 1 \
+ --accumulation_steps 1 \
+ --lr 2e-5 \
+ --max_len 2048 \
+ --grad_checkpoint \
+ --use_wandb
diff --git a/applications/Chat/pytest.ini b/applications/ColossalChat/pytest.ini
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/pytest.ini
rename to applications/ColossalChat/pytest.ini
diff --git a/applications/ColossalChat/requirements.txt b/applications/ColossalChat/requirements.txt
new file mode 100755
index 000000000000..de5f6160e827
--- /dev/null
+++ b/applications/ColossalChat/requirements.txt
@@ -0,0 +1,24 @@
+transformers==4.34.1
+huggingface_hub==0.17.3
+tqdm
+datasets
+loralib
+colossalai>=0.3.6
+torch>=1.12.1
+langchain
+tokenizers
+fastapi
+sse_starlette
+wandb
+gpustat
+packaging
+autoflake==2.2.1
+black==23.9.1
+tensorboard
+six==1.16.0
+ninja==1.11.1
+sentencepiece==0.1.99
+flash-attn
+tiktoken
diff --git a/applications/Chat/setup.py b/applications/ColossalChat/setup.py
old mode 100644
new mode 100755
similarity index 97%
rename from applications/Chat/setup.py
rename to applications/ColossalChat/setup.py
index eb44b6203ef8..37503920ade6
--- a/applications/Chat/setup.py
+++ b/applications/ColossalChat/setup.py
@@ -32,7 +32,7 @@ def fetch_version():
license="Apache Software License 2.0",
url="https://github.com/hpcaitech/Coati",
install_requires=fetch_requirements("requirements.txt"),
- python_requires=">=3.6",
+ python_requires=">=3.7",
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
diff --git a/applications/Chat/tests/__init__.py b/applications/ColossalChat/tests/__init__.py
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/tests/__init__.py
rename to applications/ColossalChat/tests/__init__.py
diff --git a/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py b/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py
new file mode 100644
index 000000000000..9f85b4beb65d
--- /dev/null
+++ b/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py
@@ -0,0 +1,72 @@
+import argparse
+import json
+import os
+
+sft_seed = {
+ "messages": [
+ {"from": "human", "content": "Give three tips for staying healthy."},
+ {
+ "from": "assistant",
+ "content": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.",
+ },
+ ]
+}
+prompt_seed = {
+ "messages": [
+ {"from": "human", "content": "Describe the impacts of climate change on communities living in coastal areas."},
+ {
+ "from": "assistant",
+ "content": "Climate change has caused an increase in sea levels, which has caused coastal erosion and flooding of low-lying areas. This has led to displacement of people from their homes, as well as increased risk of epidemics of waterborne illnesses. Coastal cities have also seen an increase in extreme weather events such as hurricanes and tropical storms, which can cause extensive damage to infrastructure, homes, and businesses. As a result of climate change, some coastal areas are becoming uninhabitable, forcing communities to seek alternative living arrangements.",
+ },
+ ]
+}
+preference_seed = {
+ "context": [
+ {"from": "human", "content": "What kind of noises did dinosaurs make?"},
+ {
+ "from": "assistant",
+ "content": "Humans and dinosaurs didn't live at the same time, so it's really hard to say. The best place to find out what noises dinosaurs made would be",
+ },
+ {"from": "human", "content": "yes they did"},
+ {
+ "from": "assistant",
+ "content": "to guess, and that would probably require lots of reading and a certain amount of imagination, so we're not really prepared to do that.",
+ },
+ {"from": "human", "content": "you cant read"},
+ ],
+ "chosen": [{"from": "assistant", "content": "You can read?"}],
+ "rejected": [{"from": "assistant", "content": "there's a lot of stuff humans don't know"}],
+}
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--data_dir",
+ type=str,
+ required=True,
+ default=None,
+ help="The output dir",
+ )
+ parser.add_argument(
+ "--data_type",
+ type=str,
+ required=True,
+ default=None,
+ help="The type of data",
+ )
+ args = parser.parse_args()
+ if args.data_type == "sft":
+ seed = sft_seed
+ elif args.data_type == "prompt":
+ seed = prompt_seed
+ elif args.data_type == "preference":
+ seed = preference_seed
+ else:
+ raise ValueError(f"Unknown data type {args.data_type}")
+
+ line = json.dumps(seed, ensure_ascii=False) + "\n"
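+    # Write 1,000 copies of the seed sample into each of three dummy shards.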
+ for idx in [1, 2, 3]:
+ with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
+ for i in range(1000):
+ f.write(line)
diff --git a/applications/ColossalChat/tests/llama.json b/applications/ColossalChat/tests/llama.json
new file mode 100644
index 000000000000..482ff9e6528c
--- /dev/null
+++ b/applications/ColossalChat/tests/llama.json
@@ -0,0 +1,8 @@
+{
+ "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}{{'Human: ' + bos_token + message['content'].strip() + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + bos_token + message['content'].strip() + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant: ' + bos_token }}{% endif %}",
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ "stop_ids": [
+ 29871,
+ 2
+ ]
+}
diff --git a/applications/ColossalChat/tests/opt.json b/applications/ColossalChat/tests/opt.json
new file mode 100644
index 000000000000..6d47666bb056
--- /dev/null
+++ b/applications/ColossalChat/tests/opt.json
@@ -0,0 +1,17 @@
+{
+ "chat_template": "{% for message in messages %}{% if message['role'] == 'user' %}{{'Human: ' + bos_token + message['content'].strip() + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + bos_token + message['content'].strip() + eos_token }}{% endif %}{% endfor %}",
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ "human_line_start": [
+ 2
+ ],
+ "human_line_end": [
+ 2
+ ],
+ "assistant_line_start": [
+ 2
+ ],
+ "assistant_line_end": [
+ 2
+ ],
+ "end_of_system_line_position": 0
+}
diff --git a/applications/ColossalChat/tests/test_data/dpo/test_dpo_data.jsonl b/applications/ColossalChat/tests/test_data/dpo/test_dpo_data.jsonl
new file mode 100644
index 000000000000..2e11a91c643f
--- /dev/null
+++ b/applications/ColossalChat/tests/test_data/dpo/test_dpo_data.jsonl
@@ -0,0 +1 @@
+{"context": [{"from": "human", "content": "If I were to give you some coordinates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinates are within any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give you some details about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]}
diff --git a/applications/ColossalChat/tests/test_data/sft/test_sft_data.jsonl b/applications/ColossalChat/tests/test_data/sft/test_sft_data.jsonl
new file mode 100644
index 000000000000..21c4d9dc76ec
--- /dev/null
+++ b/applications/ColossalChat/tests/test_data/sft/test_sft_data.jsonl
@@ -0,0 +1 @@
+{"messages": [{"from": "human", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "human", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "human", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]}
diff --git a/applications/ColossalChat/tests/test_data_preparation.sh b/applications/ColossalChat/tests/test_data_preparation.sh
new file mode 100755
index 000000000000..a7689cdc6688
--- /dev/null
+++ b/applications/ColossalChat/tests/test_data_preparation.sh
@@ -0,0 +1,260 @@
+#!/usr/bin/env bash
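+# Select the n GPUs with the least memory currently in use (as reported by
+# nvidia-smi) and expose only those via CUDA_VISIBLE_DEVICES.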
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+set_n_least_used_CUDA_VISIBLE_DEVICES 4
+
+set -xu
+
+if [ -z "${SFT_DATASET:-}" ]; then
+    echo "Please set \$SFT_DATASET to the path to the sft dataset."
+    exit 1
+fi
+
+if [ -z "${PROMPT_DATASET:-}" ]; then
+    echo "Please set \$PROMPT_DATASET to the path to the prompt dataset."
+    exit 1
+fi
+
+if [ -z "${PREFERENCE_DATASET:-}" ]; then
+    echo "Please set \$PREFERENCE_DATASET to the path to the preference dataset."
+    exit 1
+fi
+
+NUM_RETRY=3
+BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
+BASE_TEMP_DIR=$BASE_DIR/temp
+TEST_DIR=$BASE_DIR/tests
+EXAMPLES_DIR=$BASE_DIR/examples
+DATA_SAVE_PATH=$BASE_TEMP_DIR/rlhf_data
+CONFIG_DIR=$BASE_DIR/config
+# Other models are skipped to avoid CI timeouts
+MODELS=('llama')
+
+if [ ! -d "$BASE_TEMP_DIR" ]; then
+ mkdir "$BASE_TEMP_DIR"
+ echo "Directory created successfully"
+else
+ echo "Directory already exists"
+fi
+
+if [ ! -d "$DATA_SAVE_PATH" ]; then
+ mkdir "$DATA_SAVE_PATH"
+ echo "Directory created successfully"
+else
+ echo "Directory already exists"
+fi
+
+
+export OMP_NUM_THREADS=8
+
+# install requirements
+pip install -r $EXAMPLES_DIR/requirements.txt
+
+get_data_input_dirs() {
+ local data_type=$1
+ if [[ $data_type == "sft" ]]; then
+ echo "$SFT_DATASET"
+ elif [[ $data_type == "prompt" ]]; then
+ echo "$PROMPT_DATASET"
+ elif [[ $data_type == "preference" ]]; then
+ echo "$PREFERENCE_DATASET"
+ else
+ echo "Unknown data type $data_type"
+ exit 1
+ fi
+}
+
+get_conversation_template_config() {
+ local model=$1
+ if [[ $model == "llama" ]]; then
+ echo "$TEST_DIR/llama.json"
+ elif [[ $model == "opt" ]]; then
+ echo "$TEST_DIR/opt.json"
+ else
+ echo "Unknown model $model"
+ exit 1
+ fi
+}
+
+get_tokenizer_dirs() {
+ local model=$1
+ if [[ $model == "llama" ]]; then
+ echo "hf-internal-testing/llama-tokenizer"
+ elif [[ $model == "opt" ]]; then
+ echo "facebook/opt-125m"
+ else
+ echo "Unknown model $model"
+ exit 1
+ fi
+}
+
+random_choice() {
+ local arr=("$@")
+ local len=${#arr[@]}
+ local idx=$((RANDOM % len))
+ echo ${arr[$idx]}
+}
+
+echo "Prepare dummy data for testing..."
+python $TEST_DIR/generate_dummy_datasets_for_testing.py \
+ --data_dir $(get_data_input_dirs sft) \
+ --data_type "sft"
+
+python $TEST_DIR/generate_dummy_datasets_for_testing.py \
+ --data_dir $(get_data_input_dirs preference) \
+ --data_type "preference"
+
+python $TEST_DIR/generate_dummy_datasets_for_testing.py \
+ --data_dir $(get_data_input_dirs prompt) \
+ --data_type "prompt"
+
+echo "[Test]: testing prepare_preference_dataset.py ..."
+
+# FIXME: This is a hack to skip tests that are not working
+SKIPPED_TESTS=(
+)
+
+# test prepare_preference_dataset
+for model in ${MODELS[@]}; do
+ data_type="preference"
+ if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
+ echo "[Test]: Skipped $model-$data_type"
+ continue
+ fi
+ cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
+ jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
+ arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
+ rm -rf $cache_dir
+ rm -rf $jsonl_dir
+ rm -rf $arrow_dir
+ data_input_dirs=$(get_data_input_dirs $data_type)
+ tokenizer_dir=$(get_tokenizer_dirs $model)
+ conversation_template=$(get_conversation_template_config $model)
+ for i in $(seq $NUM_RETRY); do
+ echo "[Test]: $model-$data_type, attempt $i"
+ python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
+ --type preference \
+ --data_input_dirs $data_input_dirs \
+ --conversation_template_config $conversation_template \
+ --tokenizer_dir $tokenizer_dir \
+ --data_cache_dir $cache_dir \
+ --data_jsonl_output_dir $jsonl_dir \
+ --data_arrow_output_dir $arrow_dir \
+ --max_length 400 \
+ --num_samples_per_datafile 100 \
+ --num_spliced_dataset_bins 1
+ passed=$?
+ if [ $passed -eq 0 ]; then
+ break
+ fi
+ done
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed $model-$data_type"
+ exit 1
+ fi
+done
+
+echo "[Test]: testing prepare_sft_dataset.py ..."
+
+# FIXME: This is a hack to skip tests that are not working
+SKIPPED_TESTS=(
+)
+
+# test prepare_sft_dataset
+for model in ${MODELS[@]}; do
+ data_type="sft"
+ if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
+ echo "[Test]: Skipped $model-$data_type"
+ continue
+ fi
+ cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
+ jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
+ arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
+ data_input_dirs=$(get_data_input_dirs $data_type)
+ tokenizer_dir=$(get_tokenizer_dirs $model)
+ conversation_template=$(get_conversation_template_config $model)
+ for i in $(seq $NUM_RETRY); do
+ rm -rf $cache_dir
+ rm -rf $jsonl_dir
+ rm -rf $arrow_dir
+ echo "[Test]: $model-$data_type, attempt $i"
+ python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
+ --type sft \
+ --data_input_dirs $data_input_dirs \
+ --conversation_template_config $conversation_template \
+ --tokenizer_dir $tokenizer_dir \
+ --data_cache_dir $cache_dir \
+ --data_jsonl_output_dir $jsonl_dir \
+ --data_arrow_output_dir $arrow_dir \
+ --max_length 400 \
+ --num_samples_per_datafile 100 \
+ --num_spliced_dataset_bins 1
+ passed=$?
+ if [ $passed -eq 0 ]; then
+ break
+ fi
+ done
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed $model-$data_type"
+ exit 1
+ fi
+done
+
+echo "[Test]: testing prepare_prompt_dataset.py ..."
+
+# FIXME: This is a hack to skip tests that are not working
+SKIPPED_TESTS=(
+)
+
+# test prepare_prompt_dataset
+for model in ${MODELS[@]}; do
+ data_type="prompt"
+ if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
+ echo "[Test]: Skipped $model-$data_type"
+ continue
+ fi
+ cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
+ jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
+ arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
+ data_input_dirs=$(get_data_input_dirs $data_type)
+ tokenizer_dir=$(get_tokenizer_dirs $model)
+ conversation_template=$(get_conversation_template_config $model)
+ for i in $(seq $NUM_RETRY); do
+ rm -rf $cache_dir
+ rm -rf $jsonl_dir
+ rm -rf $arrow_dir
+ echo "[Test]: $model-$data_type, attempt $i"
+ python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
+ --type prompt \
+ --data_input_dirs $data_input_dirs \
+ --conversation_template_config $conversation_template \
+ --tokenizer_dir $tokenizer_dir \
+ --data_cache_dir $cache_dir \
+ --data_jsonl_output_dir $jsonl_dir \
+ --data_arrow_output_dir $arrow_dir \
+ --max_length 400 \
+ --num_samples_per_datafile 100 \
+ --num_spliced_dataset_bins 1
+ passed=$?
+ if [ $passed -eq 0 ]; then
+ break
+ fi
+ done
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed $model-$data_type"
+ exit 1
+ fi
+done
diff --git a/applications/ColossalChat/tests/test_lora.py b/applications/ColossalChat/tests/test_lora.py
new file mode 100755
index 000000000000..4ea9e1a15c59
--- /dev/null
+++ b/applications/ColossalChat/tests/test_lora.py
@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from coati.models import convert_to_lora_module
+from torch.utils.data import DataLoader, TensorDataset
+
+
+class SimpleNN(nn.Module):
+ def __init__(self, input_size, hidden_size, num_classes):
+ super(SimpleNN, self).__init__()
+ self.fc1 = nn.Linear(input_size, hidden_size)
+ self.relu = nn.ReLU()
+ self.fc2 = nn.Linear(hidden_size, num_classes)
+
+ def forward(self, x):
+ out = self.fc1(x)
+ out = self.relu(out)
+ out = self.fc2(out)
+ return out
+
+
+def test_overfit():
+ input_size = 1000
+ hidden_size = 200
+ num_classes = 5
+ batch_size = 64
+ learning_rate = 0.01
+ num_epochs = 200
+
+ # Synthesized dataset
+ X = torch.randn(batch_size, input_size)
+ Y = torch.randint(0, num_classes, (batch_size,))
+
+ # Convert to DataLoader
+ dataset = TensorDataset(X, Y)
+ loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+ # Build and convert model
+ model = SimpleNN(input_size, hidden_size, num_classes)
+ weight_to_compare = model.fc1.weight.detach().clone()
+ model = convert_to_lora_module(model, lora_rank=30)
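+    # convert_to_lora_module wraps the linear layers with rank-30 LoRA adapters;
+    # the fc1 weight cloned above is used later to check that the base weight stays frozen.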
+
+ # Loss and optimizer
+ criterion = nn.CrossEntropyLoss()
+ optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+
+ # Train the model
+ for _ in range(num_epochs):
+ for i, (inputs, labels) in enumerate(loader):
+ # Forward pass
+ outputs = model(inputs)
+ loss = criterion(outputs, labels)
+ print(loss)
+ # Backward and optimize
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+    # Check that the model has overfitted the synthesized dataset
+    outputs = model(X)
+    _, predicted = torch.max(outputs.data, 1)
+    total = Y.size(0)
+    correct = (predicted == Y).sum().item()
+    assert correct / total > 0.95, "The model has not overfitted to the synthesized dataset"
+    # Only the LoRA adapters should be trained, so the base fc1 weight must stay (almost) unchanged
+    assert (weight_to_compare - model.fc1.weight).abs().sum() < 0.01
+
+
+if __name__ == "__main__":
+ test_overfit()
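For readers unfamiliar with what `convert_to_lora_module` is being tested for: conceptually, LoRA freezes each pretrained weight and learns a low-rank update beside it. The following is an illustrative sketch only, not coati's actual implementation (the class and argument names here are hypothetical):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Hypothetical minimal LoRA wrapper: y = W x + scaling * (B @ A) x."""

    def __init__(self, base: nn.Linear, rank: int, alpha: float = 1.0):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # freeze the pretrained weight
        self.lora_A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, rank))  # zero init: no change at start
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
```

This is why `test_overfit` can demand both high training accuracy and an (almost) unchanged `fc1.weight`: all of the learning capacity lives in the adapters.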
diff --git a/applications/ColossalChat/tests/test_templating.sh b/applications/ColossalChat/tests/test_templating.sh
new file mode 100755
index 000000000000..7fefede47539
--- /dev/null
+++ b/applications/ColossalChat/tests/test_templating.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
+BASE_TEMP_DIR=$BASE_DIR/temp
+EXAMPLES_DIR=$BASE_DIR/examples
+TEST_DATA_DIR=$BASE_DIR/tests/test_data
+DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
+CONFIG_DIR=$BASE_DIR/config
+
+MODELS=("colossal-llama2" "llama2" "zephyr" "mistral" "chatGLM2" "Qwen" "Vicuna" "Yi")
+
+get_pretrain() {
+ local model=$1
+ if [[ $model == "colossal-llama2" ]]; then
+ echo "hpcai-tech/Colossal-LLaMA-2-7b-base"
+ elif [[ $model == "llama2" ]]; then
+ echo "hf-internal-testing/llama-tokenizer"
+ elif [[ $model == "zephyr" ]]; then
+ echo "HuggingFaceH4/zephyr-7b-beta"
+ elif [[ $model == "mistral" ]]; then
+ echo "mistralai/Mistral-7B-Instruct-v0.2"
+ elif [[ $model == "chatGLM2" ]]; then
+ echo "THUDM/chatglm2-6b"
+ elif [[ $model == "Qwen" ]]; then
+ echo "Qwen/Qwen-7B-Chat"
+ elif [[ $model == "Vicuna" ]]; then
+ echo "lmsys/vicuna-7b-v1.5"
+ elif [[ $model == "Yi" ]]; then
+ echo "01-ai/Yi-6B-Chat"
+ else
+ echo "Unknown model $model"
+ exit 1
+ fi
+}
+
+get_conversation_template_config() {
+ local model=$1
+ echo "$CONFIG_DIR/conversation_template/$model.json"
+}
+
+# Test SFT data Preparation
+for model in ${MODELS[@]}; do
+ echo "Testing SFT data templating for $model"
+ SAVE_DIR=$DATA_SAVE_PATH/sft/$model
+ rm -rf $SAVE_DIR/cache
+ rm -rf $SAVE_DIR/jsonl
+ rm -rf $SAVE_DIR/arrow
+ pretrain=$(get_pretrain $model)
+ conversation_template_config=$(get_conversation_template_config $model)
+ python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type sft --data_input_dirs $TEST_DATA_DIR/sft \
+ --tokenizer_dir $pretrain \
+ --conversation_template_config $conversation_template_config \
+ --data_cache_dir $SAVE_DIR/cache \
+ --data_jsonl_output_dir $SAVE_DIR/jsonl \
+ --data_arrow_output_dir $SAVE_DIR/arrow
+ passed=$?
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed in the SFT data templating for $model"
+ exit 1
+ fi
+ python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/sft/test_sft_data.jsonl \
+ --to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type sft
+ passed=$?
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed in the SFT data templating test for $model"
+ exit 1
+ fi
+done
+
+
+# Test DPO/PPO data Preparation
+for model in ${MODELS[@]}; do
+ echo "Testing DPO/PPO data templating for $model"
+ SAVE_DIR=$DATA_SAVE_PATH/dpo/$model
+ rm -rf $SAVE_DIR/cache
+ rm -rf $SAVE_DIR/jsonl
+ rm -rf $SAVE_DIR/arrow
+ pretrain=$(get_pretrain $model)
+ conversation_template_config=$(get_conversation_template_config $model)
+ python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type preference --data_input_dirs $TEST_DATA_DIR/dpo \
+ --tokenizer_dir $pretrain \
+ --conversation_template_config $conversation_template_config \
+ --data_cache_dir $SAVE_DIR/cache \
+ --data_jsonl_output_dir $SAVE_DIR/jsonl \
+ --data_arrow_output_dir $SAVE_DIR/arrow
+ passed=$?
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed in the DPO data templating for $model"
+ exit 1
+ fi
+ python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/dpo/test_dpo_data.jsonl \
+ --to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type dpo
+ passed=$?
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed in the DPO data templating test for $model"
+ exit 1
+ fi
+done
diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh
new file mode 100755
index 000000000000..5ba4904711ea
--- /dev/null
+++ b/applications/ColossalChat/tests/test_train.sh
@@ -0,0 +1,397 @@
+#!/usr/bin/env bash
+
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+set_n_least_used_CUDA_VISIBLE_DEVICES 4
+
+set -xu
+
+
+NUM_RETRY=3
+BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
+EXAMPLES_DIR=$BASE_DIR/examples
+CONFIG_DIR=$BASE_DIR/config
+TEMP_DIR=$BASE_DIR/temp
+TEST_DIR=$BASE_DIR/tests
+MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
+MODELS_DIR=$TEMP_DIR/models_config
+# Other models are skipped to avoid CI timeouts
+MODELS=('llama')
+PLUGINS=('gemini' 'gemini_auto' 'zero2' 'zero2_cpu' '3d')
+LORA_RANK=('0') # only rank 0 is tested to reduce CI execution time; all ranks pass locally
+
+export OMP_NUM_THREADS=8
+
+get_pretrain() {
+ local model=$1
+ if [[ $model == "llama" ]]; then
+ echo "nickypro/tinyllama-110M"
+ elif [[ $model == "opt" ]]; then
+ echo "facebook/opt-125m"
+ else
+ echo "Unknown model $model"
+ exit 1
+ fi
+}
+
+get_tokenizer_dirs() {
+ local model=$1
+ if [[ $model == "llama" ]]; then
+ echo "hf-internal-testing/llama-tokenizer"
+ elif [[ $model == "opt" ]]; then
+ echo "facebook/opt-125m"
+ else
+ echo "Unknown model $model"
+ exit 1
+ fi
+}
+
+
+get_conversation_template_config() {
+ local model=$1
+ if [[ $model == "llama" ]]; then
+ echo "$TEST_DIR/llama.json"
+ elif [[ $model == "opt" ]]; then
+ echo "$TEST_DIR/opt.json"
+ else
+ echo "Unknown model $model"
+ exit 1
+ fi
+}
+
+random_choice() {
+ local arr=("$@")
+ local len=${#arr[@]}
+ local idx=$((RANDOM % len))
+ echo ${arr[$idx]}
+}
+
+
+echo "[Test]: testing sft ..."
+
+SKIPPED_TESTS=(
+ llama-3d-20 # 3d plugin doesn't support lora
+ llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
+ llama-gemini-20 # gemini doesn't support lora
+)
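+# Skip keys are matched as "$model-$plugin-$lora_rank" (or "$model-$plugin"),
+# e.g. llama-3d-20 skips the 3d plugin when lora_rank is 20.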
+
+GRAD_CKPTS=('--grad_checkpoint')
+for lora_rank in ${LORA_RANK[@]}; do
+ for model in ${MODELS[@]}; do
+ for plugin in ${PLUGINS[@]}; do
+ if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
+ echo "[Test]: Skipped $model-$plugin-$lora_rank"
+ continue
+ elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
+ echo "[Test]: Skipped $model-$plugin"
+ continue
+ fi
+ pretrain=$(get_pretrain $model)
+ tokenizer_dir=$(get_tokenizer_dirs $model)
+ grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
+ tp='1'
+ bs='2'
+ if [[ $plugin == "3d" ]]; then
+ tp='4'
+ bs='8'
+ fi
+ grad_accu='2'
+            # gemini_auto and gemini don't support gradient accumulation
+ if [[ $plugin == "gemini_auto" ]]; then
+ grad_accu='1'
+ fi
+
+ for i in $(seq $NUM_RETRY); do
+ echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
+ declare -a dataset=()
+ for split in $(seq -f "%05g" 0 0); do
+ dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
+ done
+ colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
+ --pretrain $pretrain \
+ --tokenizer_dir $tokenizer_dir \
+ --dataset ${dataset[@]} \
+ --save_path $MODEL_SAVE_PATH \
+ --config_file $MODELS_DIR/config.jsonl \
+ --lora_rank $lora_rank \
+ --plugin $plugin \
+ --batch_size $bs \
+ --max_epochs 1 \
+ --accumulation_steps $grad_accu \
+ --tp $tp \
+ --lr 2e-5 \
+ $grad_ckpt \
+ --max_len 400 \
+ --use_flash_attn
+ passed=$?
+ if [ $passed -eq 0 ]; then
+ rm -rf $MODEL_SAVE_PATH/*
+ rm -rf $MODELS_DIR/*
+ break
+ fi
+ done
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed $model-$plugin-$lora_rank"
+ exit 1
+ fi
+ done
+ done
+done
+
+echo "[Test]: testing reward model ..."
+
+SKIPPED_TESTS=(
+ llama-3d-20 # 3d plugin doesn't support lora
+ llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
+ llama-gemini-20 # gemini doesn't support lora
+)
+
+GRAD_CKPTS=('--grad_checkpoint')
+for lora_rank in ${LORA_RANK[@]}; do
+ for model in ${MODELS[@]}; do
+ for plugin in ${PLUGINS[@]}; do
+ if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
+ echo "[Test]: Skipped $model-$plugin-$lora_rank"
+ continue
+ elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
+ echo "[Test]: Skipped $model-$plugin"
+ continue
+ fi
+ pretrain=$(get_pretrain $model)
+ tokenizer_dir=$(get_tokenizer_dirs $model)
+ grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
+ tp='1'
+ bs='2'
+ if [[ $plugin == "3d" ]]; then
+ tp='4'
+ bs='8'
+ fi
+ grad_accu='2'
+            # gemini_auto and gemini don't support gradient accumulation
+ if [[ $plugin == "gemini_auto" ]]; then
+ grad_accu='1'
+ fi
+ for i in $(seq $NUM_RETRY); do
+ echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
+ declare -a dataset=()
+ for split in $(seq -f "%05g" 0 0); do
+ dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
+ done
+ colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_rm.py \
+ --pretrain $pretrain \
+ --tokenizer_dir $tokenizer_dir \
+ --dataset ${dataset[@]} \
+ --save_dir $MODEL_SAVE_PATH \
+ --config_file $MODELS_DIR/config.jsonl \
+ --lora_rank $lora_rank \
+ --plugin $plugin \
+ --batch_size $bs \
+ --max_epochs 1 \
+ --accumulation_steps $grad_accu \
+ --tp $tp \
+ --lr 2e-5 \
+ $grad_ckpt \
+ --max_len 400 \
+ --use_flash_attn
+ passed=$?
+ if [ $passed -eq 0 ]; then
+ rm -rf $MODEL_SAVE_PATH/*
+ rm -rf $MODELS_DIR/*
+ break
+ fi
+ done
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed $model-$plugin-$lora_rank"
+ exit 1
+ fi
+ done
+ done
+done
+
+
+echo "[Test]: testing ppo ..."
+
+
+SKIPPED_TESTS=(
+ llama-3d-20 # 3d plugin doesn't support lora
+ llama-gemini-20 # gemini doesn't support lora
+)
+
+GRAD_CKPTS=('--grad_checkpoint')
+for lora_rank in ${LORA_RANK[@]}; do
+ for model in ${MODELS[@]}; do
+ for plugin in ${PLUGINS[@]}; do
+ if [[ $plugin == "gemini_auto" ]]; then
+ echo "[Test]: Skipped $model-$plugin"
+ continue # gemini_auto plugin doesn't support generation
+ fi
+ if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
+ echo "[Test]: Skipped $model-$plugin-$lora_rank"
+ continue
+ elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
+ echo "[Test]: Skipped $model-$plugin"
+ continue
+ fi
+ pretrain=$(get_pretrain $model)
+ tokenizer_dir=$(get_tokenizer_dirs $model)
+ grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
+ tp='1'
+ bs='4'
+ ebs='8'
+ conversation_template=$(get_conversation_template_config $model)
+ if [[ $plugin == "3d" ]]; then
+ tp='4'
+ bs='16'
+ ebs='32'
+ fi
+ grad_accu='2'
+            # gemini_auto and gemini don't support gradient accumulation
+ if [[ $plugin == "gemini_auto" ]]; then
+ grad_accu='1'
+ fi
+ for i in $(seq $NUM_RETRY); do
+ echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
+ declare -a prompt_dataset=()
+ for split in $(seq -f "%05g" 0 0); do
+ prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
+ done
+ declare -a ptx_dataset=()
+ for split in $(seq -f "%05g" 0 0); do
+ ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
+ done
+ colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
+ --pretrain $pretrain \
+ --rm_pretrain $pretrain \
+ --tokenizer_dir $tokenizer_dir \
+ --conversation_template_config $conversation_template \
+ --prompt_dataset ${prompt_dataset[@]} \
+ --ptx_dataset ${ptx_dataset[@]} \
+ --ptx_batch_size 1 \
+ --ptx_coef 0.2 \
+ --save_path $MODEL_SAVE_PATH \
+ --lora_rank $lora_rank \
+ --plugin $plugin \
+ --num_episodes 5 \
+ --num_collect_steps 1 \
+ --num_update_steps 1 \
+ --experience_batch_size $ebs \
+ --train_batch_size $bs \
+ --accumulation_steps $grad_accu \
+ --lr 9e-6 \
+ --mixed_precision "bf16" \
+ --grad_clip 1.0 \
+ --tp $tp \
+ $grad_ckpt \
+ --max_len 400 \
+ --max_seq_len 10 \
+ --use_flash_attn
+ passed=$?
+ if [ $passed -eq 0 ]; then
+ rm -rf $MODEL_SAVE_PATH/*
+ rm -rf $MODELS_DIR/*
+ break
+ fi
+ done
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed $model-$plugin-$lora_rank"
+ exit 1
+ fi
+ done
+ done
+done
+
+
+echo "[Test]: testing DPO ..."
+
+SKIPPED_TESTS=(
+ llama-3d-20 # 3d plugin doesn't support lora
+ llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
+ llama-gemini-20 # gemini doesn't support lora
+)
+GRAD_CKPTS=('--grad_checkpoint')
+for lora_rank in ${LORA_RANK[@]}; do
+ for model in ${MODELS[@]}; do
+ for plugin in ${PLUGINS[@]}; do
+ if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
+ echo "[Test]: Skipped $model-$plugin-$lora_rank"
+ continue
+ elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
+ echo "[Test]: Skipped $model-$plugin"
+ continue
+ fi
+ pretrain=$(get_pretrain $model)
+ tokenizer_dir=$(get_tokenizer_dirs $model)
+ grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
+ tp='1'
+ bs='2'
+ if [[ $plugin == "3d" ]]; then
+ tp='4'
+ bs='8'
+ fi
+ grad_accu='2'
+            # gemini_auto and gemini don't support gradient accumulation
+ if [[ $plugin == "gemini_auto" ]]; then
+ grad_accu='1'
+ fi
+            # gemini_auto doesn't support generation
+            # (the reference model's logits must be computed via a forward pass in inference mode)
+ if [[ $plugin == "gemini_auto" ]]; then
+ echo "[Test]: Skipped $model-$plugin"
+ continue
+ fi
+ for i in $(seq $NUM_RETRY); do
+ echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
+ declare -a dataset=()
+ for split in $(seq -f "%05g" 0 0); do
+ dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
+ done
+ colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_dpo.py \
+ --pretrain $pretrain \
+ --tokenizer_dir $tokenizer_dir \
+ --dataset ${dataset[@]} \
+ --save_dir $MODEL_SAVE_PATH \
+ --config_file $MODELS_DIR/config.jsonl \
+ --lora_rank $lora_rank \
+ --plugin $plugin \
+ --batch_size $bs \
+ --max_epochs 1 \
+ --accumulation_steps $grad_accu \
+ --tp $tp \
+ --lr 2e-5 \
+ $grad_ckpt \
+ --max_len 400 \
+ --use_flash_attn
+ passed=$?
+ if [ $passed -eq 0 ]; then
+ rm -rf $MODEL_SAVE_PATH/*
+ rm -rf $MODELS_DIR/*
+ break
+ fi
+ done
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed $model-$plugin-$lora_rank"
+ exit 1
+ fi
+ done
+ done
+done
diff --git a/applications/ColossalChat/tests/verify_chat_data.py b/applications/ColossalChat/tests/verify_chat_data.py
new file mode 100644
index 000000000000..98ae0c1b2d28
--- /dev/null
+++ b/applications/ColossalChat/tests/verify_chat_data.py
@@ -0,0 +1,64 @@
+import argparse
+import json
+
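+# Verify that the tokenized output of prepare_dataset.py still contains the
+# expected labels: assistant turns for sft data, chosen/rejected turns for dpo data.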
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--data_source",
+ type=str,
+ required=True,
+ default=None,
+ help="The raw data file",
+ )
+ parser.add_argument(
+ "--to_verify_file",
+ type=str,
+ required=True,
+ default=None,
+ help="The file that contains the data to be verified",
+ )
+ parser.add_argument(
+ "--data_type",
+ type=str,
+ required=True,
+ default=None,
+        help="The data type, chosen from 'sft' or 'dpo'",
+ )
+ args = parser.parse_args()
+
+ # Read data
+ data = []
+ with open(args.data_source, "r", encoding="utf8") as f:
+ for line in f.readlines():
+ data.append(json.loads(line))
+ to_verify_data = []
+ with open(args.to_verify_file, "r", encoding="utf8") as f:
+ for line in f.readlines():
+ to_verify_data.append(json.loads(line))
+
+    if args.data_type == "sft":
+        target_labels = [msg["content"].strip() for msg in data[0]["messages"] if msg["from"] == "assistant"]
+        target_negative_labels = [msg["content"].strip() for msg in data[0]["messages"] if msg["from"] == "human"]
+
+        # Read the decoded labels from the file under verification
+        to_verify_labels = to_verify_data[0]["labels_decode"]
+        for label in target_labels:
+            assert any([label in s for s in to_verify_labels]), f"Label {label} not found in decoded labels {to_verify_labels}"
+        for label in target_negative_labels:
+            assert all(
+                [label not in s for s in to_verify_labels]
+            ), f"Negative label {label} found in decoded labels {to_verify_labels}"
+    elif args.data_type == "dpo":
+        chosen_label = data[0]["chosen"][0]["content"].strip()
+        rejected_label = data[0]["rejected"][0]["content"].strip()
+
+        # Read the decoded labels from the file under verification
+        to_verify_labels_chosen = to_verify_data[0]["chosen_label_decode"]
+        to_verify_labels_rejected = to_verify_data[0]["rejected_label_decode"]
+        assert any(
+            [chosen_label in s for s in to_verify_labels_chosen]
+        ), f"Chosen label {chosen_label} not found in decoded chosen labels {to_verify_labels_chosen}"
+        assert any(
+            [rejected_label in s for s in to_verify_labels_rejected]
+        ), f"Rejected label {rejected_label} not found in decoded rejected labels {to_verify_labels_rejected}"
diff --git a/applications/Chat/version.txt b/applications/ColossalChat/version.txt
old mode 100644
new mode 100755
similarity index 100%
rename from applications/Chat/version.txt
rename to applications/ColossalChat/version.txt
diff --git a/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py b/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py
index a0b1ed1143f0..19907daaff7f 100644
--- a/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py
+++ b/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py
@@ -670,7 +670,7 @@ def calculate_scores_form_logprobs(logprobs: Dict[str, Any]) -> float:
def calculate_scores_form_response(response: str, evaluation: Dict[str, Any]) -> int:
"""
Calculate the score from the response returned by gpt-3.5-turbo or gpt-4.
- Different from text-davinci-003, this fuction directly calculates the score according to the plain response returned by gpt-3.5-turbo or gpt-4.
+ Different from text-davinci-003, this function directly calculates the score according to the plain response returned by gpt-3.5-turbo or gpt-4.
Although text-davinci-003 can return log probabilities, it costs ten times as much as gpt-3.5-turbo.
Args:
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
index 218b05b27fad..c01e02c49a60 100644
--- a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
@@ -109,8 +109,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
else:
module = self.model.model
- layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
- stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
@@ -129,10 +129,10 @@ def get_held_layers(self) -> List[Module]:
stage_manager = self.pipeline_stage_manager
held_layers = []
- layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
- start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py
index c567038ec252..850236726a27 100644
--- a/applications/ColossalMoE/train.py
+++ b/applications/ColossalMoE/train.py
@@ -128,13 +128,13 @@ def parse_args():
parser.add_argument(
"--comm_overlap",
action="store_true",
- help="Use communication overlap for MoE. Recommended to enable for muiti-node training.",
+ help="Use communication overlap for MoE. Recommended to enable for multi-node training.",
)
# hierarchical all-to-all
parser.add_argument(
"--hierarchical_alltoall",
action="store_true",
- help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.",
+ help="Use hierarchical all-to-all for MoE. Recommended to enable for multi-node training.",
)
args = parser.parse_args()
@@ -238,7 +238,6 @@ def main():
lambda x, y: x.loss,
optimizer,
return_loss=True,
- return_outputs=True,
)
# Backward and optimize
if is_pp_last_stage:
@@ -268,7 +267,7 @@ def main():
# ):
# coordinator.print_on_master(f"Apply load balance")
# apply_load_balance(model, optimizer)
- # save ckeckpoint
+ # save checkpoint
if (step + 1) % args.save_interval == 0:
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
save_checkpoint(
diff --git a/applications/ColossalQA/colossalqa/data_loader/document_loader.py b/applications/ColossalQA/colossalqa/data_loader/document_loader.py
index cbcd6ad1d2d3..219c7ca41675 100644
--- a/applications/ColossalQA/colossalqa/data_loader/document_loader.py
+++ b/applications/ColossalQA/colossalqa/data_loader/document_loader.py
@@ -52,7 +52,7 @@ def __init__(self, files: List, **kwargs) -> None:
def load_data(self, path: str) -> None:
"""
Load data. Please refer to https://python.langchain.com/docs/modules/data_connection/document_loaders/
- for sepcific format requirements.
+ for specific format requirements.
Args:
path: path to a file
To load files with glob path, here are some examples.
diff --git a/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py b/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py
index 14e33820d9c9..3629778698fb 100644
--- a/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py
+++ b/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py
@@ -101,7 +101,7 @@ def _call(self, prompt: str, stop=None, **kwargs: Any) -> str:
return resp_text
def text_completion(self, prompt, gen_config, auth_config):
- # Complusory Parameters
+ # Required Parameters
endpoint = auth_config.pop("endpoint")
max_new_tokens = gen_config.pop("max_new_tokens")
# Optional Parameters
diff --git a/applications/ColossalQA/colossalqa/local/llm.py b/applications/ColossalQA/colossalqa/local/llm.py
index bab702d14b13..30a456c3d9c7 100644
--- a/applications/ColossalQA/colossalqa/local/llm.py
+++ b/applications/ColossalQA/colossalqa/local/llm.py
@@ -33,7 +33,7 @@ class ColossalAPI:
def __init__(self, model_type: str, model_path: str, ckpt_path: str = None) -> None:
"""
- Configurate model
+ Configure model
"""
if model_type + model_path + (ckpt_path or "") in ColossalAPI.__instances:
return
@@ -47,7 +47,7 @@ def __init__(self, model_type: str, model_path: str, ckpt_path: str = None) -> N
self.model.load_state_dict(state_dict)
self.model.to(torch.cuda.current_device())
- # Configurate tokenizer
+ # Configure tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.model.eval()
@@ -87,7 +87,7 @@ def generate(self, input: str, **kwargs) -> str:
class VllmAPI:
def __init__(self, host: str = "localhost", port: int = 8077) -> None:
- # Configurate api for model served through web
+ # Configure api for model served through web
self.host = host
self.port = port
self.url = f"http://{self.host}:{self.port}/generate"
diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py
index b23058d6dbe3..6e77bb2aee17 100644
--- a/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py
+++ b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py
@@ -36,7 +36,7 @@ def __init__(
text_splitter_chunk_overlap=10,
) -> None:
"""
- Warpper for multilingual retrieval qa class (Chinese + English)
+ Wrapper for multilingual retrieval qa class (Chinese + English)
Args:
embedding_model_path: local or huggingface embedding model
embedding_model_device:
diff --git a/applications/ColossalQA/colossalqa/retriever.py b/applications/ColossalQA/colossalqa/retriever.py
index 22a75050f03b..6a0c69859ac7 100644
--- a/applications/ColossalQA/colossalqa/retriever.py
+++ b/applications/ColossalQA/colossalqa/retriever.py
@@ -59,7 +59,7 @@ def add_documents(
Add documents to retriever
Args:
docs: the documents to add
- cleanup: choose from "incremental" (update embeddings, skip existing embeddings) and "full" (destory and rebuild retriever)
+ cleanup: choose from "incremental" (update embeddings, skip existing embeddings) and "full" (destroy and rebuild retriever)
mode: choose from "by source" (documents are grouped by source) and "merge" (documents are merged into one vector store)
"""
if cleanup == "full":
diff --git a/applications/ColossalQA/colossalqa/utils.py b/applications/ColossalQA/colossalqa/utils.py
index cd8c3e5acec8..49d99014b372 100644
--- a/applications/ColossalQA/colossalqa/utils.py
+++ b/applications/ColossalQA/colossalqa/utils.py
@@ -49,7 +49,7 @@ def destroy_sql_database(sql_engine: Union[Engine, str]) -> None:
def detect_lang_naive(s):
"""
- Naive function for language detection, should be replaced by an independant layer
+ Naive function for language detection, should be replaced by an independent layer
"""
remove_nota = "[’·°–!\"#$%&'()*+,-./:;<=>?@,。?★、…【】()《》?“”‘’![\\]^_`{|}~]+"
s = re.sub(remove_nota, "", s)
diff --git a/colossalai/_analyzer/fx/tracer/tracer.py b/colossalai/_analyzer/fx/tracer/tracer.py
index 17dce767269d..36e8780af33a 100644
--- a/colossalai/_analyzer/fx/tracer/tracer.py
+++ b/colossalai/_analyzer/fx/tracer/tracer.py
@@ -237,7 +237,7 @@ def _tracer_override(self):
# override the tracer to support custom modules and checkpointing
if self.trace_act_ckpt:
orig_ckpt_func_apply = torch.utils.checkpoint.CheckpointFunction.apply
- orig_ckpt_func_without_reentrant = torch.utils.checkpoint._checkpoint_without_reentrant
+ orig_ckpt_func_without_reentrant = torch.utils.checkpoint._checkpoint_without_reentrant_generator
def checkpoint(run_function, preserve_rng_state=False, *args):
self.ckpt_regions.append(self.ckpt_idx)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index c37a6b4df72d..29cec7cfd146 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -26,7 +26,7 @@
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.d_tensor.api import is_distributed_tensor
@@ -34,7 +34,8 @@
from .pp_plugin_base import PipelinePluginBase
-DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
+DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3
+SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
@@ -53,6 +54,7 @@ def __init__(
shard_config: ShardConfig,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
+ sp_group: ProcessGroup,
use_ddp: bool,
ddp_config: dict,
custom_policy: Policy,
@@ -61,6 +63,7 @@ def __init__(
self.shard_config = shard_config
self.dp_group = dp_group
self.tp_group = tp_group
+ self.sp_group = sp_group
self.use_dpp = use_ddp
self.require_grad_sync = True
@@ -168,13 +171,24 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
Returns:
None
"""
- if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
+
+ if self.shard_config.enable_sequence_parallelism:
+ if self.shard_config.sequence_parallelism_mode == "all_to_all":
+ return
+
+ if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
+ # If sequence parallelism is enabled and mode is split_gather or ring, gradients are synchronized
+ # across the tensor parallelism group.
+ group = self.tp_group
+ else:
+ raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}")
+
if grads is not None:
# Synchronize provided gradient tensors across the tensor parallelism group.
- SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads)
+ SeqParallelUtils.allreduce_partial_data_grad(process_group=group, grads=grads)
else:
# Synchronize gradients from the model across the tensor parallelism group.
- SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module)
+ SeqParallelUtils.allreduce_partial_data_grad(process_group=group, model=self.module)
def forward(self, *args, **kwargs):
if self.convert_fn is not None:
@@ -727,10 +741,9 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
# Get all working gradients and gradients to be synchronized.
all_working_grads = _get_all_working_grads()
grads_to_sync = _get_grads_to_sync(all_working_grads)
-
if self.require_grad_sync and grads_to_sync is not None:
# Synchronize sequence parallelism gradients if required.
- SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync)
+ SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
else:
return
@@ -891,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase):
Args:
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
+ sp_size (int): The size of sequence parallelism.
precision (str, optional): Specifies the precision of parameters during training.
Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
Defaults to 'fp16'.
@@ -903,6 +917,7 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+        sequence_parallelism_mode (str): The sequence parallelism mode. Can only be chosen from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
@@ -930,6 +945,7 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
+ gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
"""
@@ -937,6 +953,7 @@ def __init__(
self,
tp_size: int,
pp_size: int,
+ sp_size: int = None,
precision: str = "fp16",
zero_stage: int = 0,
enable_all_optimization: bool = False,
@@ -944,6 +961,7 @@ def __init__(
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
+ sequence_parallelism_mode: str = None,
enable_sequence_overlap: bool = False,
parallel_output: bool = True,
num_microbatches: Optional[int] = None,
@@ -969,19 +987,47 @@ def __init__(
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
+ gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
) -> None:
super().__init__()
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
- ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+ ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
if enable_sequence_parallelism:
- assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
+            self.sequence_parallelism_mode = sequence_parallelism_mode if sequence_parallelism_mode is not None else "split_gather"
+ assert (
+ self.sequence_parallelism_mode in SUPPORT_SP_MODE
+ ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
+ if self.sequence_parallelism_mode in ["split_gather", "ring"]:
+ assert (
+ tp_size > 1
+                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} requires tensor parallelism (tp_size > 1)"
+ if sp_size != 1:
+ warnings.warn(
+                        f"sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}; the given sp_size will be ignored."
+ )
+ self.sp_size = 1
+ self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+ elif self.sequence_parallelism_mode in ["all_to_all"]:
+ assert (
+ tp_size == 1
+ ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism"
+ assert (
+ pp_size == 1
+ ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with pipeline parallelism"
+ self.sp_size = dist.get_world_size() if sp_size is None else sp_size
+ self.dp_size = dist.get_world_size() // (self.sp_size * pp_size)
+ else:
+ self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+ assert (
+ sp_size == 1 or sp_size is None
+            ), "sp_size can only be set to a value greater than 1 when enable_sequence_parallelism is True"
+ self.sp_size = 1
self.tp_size = tp_size
self.pp_size = pp_size
- self.dp_size = dist.get_world_size() // (tp_size * pp_size)
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
@@ -990,7 +1036,7 @@ def __init__(
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
- self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
+ self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
@@ -1031,9 +1077,14 @@ def __init__(
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+ if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
+ self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+ else:
+ self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
+ sequence_parallel_process_group=self.sp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_all_optimization=self.enable_all_optimization,
@@ -1041,8 +1092,10 @@ def __init__(
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
+ sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
+ gradient_checkpoint_config=gradient_checkpoint_config,
)
self.amp_config = dict(
initial_scale=initial_scale,
@@ -1110,13 +1163,23 @@ def configure(
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
- use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+ use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
+ self.dp_size == 1
+ and self.pp_size == 1
+ and self.enable_sequence_parallelism
+ and self.sequence_parallelism_mode == "all_to_all"
+ )
+ if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+ dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
+ else:
+ dp_group = self.dp_group
model = HybridParallelModule(
model,
precision=self.precision,
shard_config=self.shard_config,
- dp_group=self.dp_group,
+ dp_group=dp_group,
tp_group=self.tp_group,
+ sp_group=self.sp_group,
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
@@ -1146,7 +1209,8 @@ def configure(
tp_process_group=self.tp_group,
)
else:
- if self.dp_size == 1:
+ zero_dp_size = dist.get_world_size(dp_group)
+ if zero_dp_size == 1:
warnings.warn(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you are not intended to use cpu_offload, please consider set zero_stage=0."
@@ -1158,7 +1222,7 @@ def configure(
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
- dp_process_group=self.dp_group,
+ dp_process_group=dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
verbose=True,
@@ -1183,6 +1247,9 @@ def execute_pipeline(
) -> dict:
assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
+ if return_outputs:
+ warnings.warn("return_outputs may lead to significant extra memory consumption.")
+
# Create a context for gradient synchronization based on the optimizer type.
# If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
# This is to avoid redundant gradient reduction in pipeline parallelism (multiple microbatch values should be reduced once),
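Putting the new constraints together: `all_to_all` mode requires `tp_size == 1` and `pp_size == 1`, and the data-parallel size becomes `world_size // sp_size`. A hypothetical configuration sketch (assumes a 4-GPU distributed launch; the exact `launch_from_torch` signature varies across colossalai releases):

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})  # older releases require the config dict
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=2,  # 2-way sequence parallelism; dp_size becomes world_size // 2
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
)
booster = Booster(plugin=plugin)
```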
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index ae372dd034e0..83888e5069a7 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -254,6 +254,9 @@ def __init__(
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+        # TODO: MoE currently only supports partial sequence parallelism
+ self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
@@ -365,6 +368,7 @@ def configure(
shard_config=self.shard_config,
dp_group=self.dp_group,
tp_group=self.tp_group,
+ sp_group=self.sp_group,
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index 1f32541a7b21..43095af50b8f 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -161,7 +161,7 @@ def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:
@staticmethod
def get_coords_along_axis(
- base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int]
+ base_coord: Tuple[int, ...], axis: Union[int, List[int]], indices_at_axis: Union[List[int], List[List[int]]]
) -> List[Tuple[int, ...]]:
"""Get coordinates along the given axis.
@@ -173,13 +173,35 @@ def get_coords_along_axis(
Returns:
List[Tuple[int, ...]]: Coordinates along the axis.
"""
- coords_in_group = []
- for idx in indices_at_axis:
- coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
+ if isinstance(axis, int):
+ axis = [
+ axis,
+ ]
+ assert isinstance(indices_at_axis[0], int)
+ indices_at_axis = [
+ indices_at_axis,
+ ]
+
+ def add_index(base_coord, axis, indices_at_axis):
+ coords_in_group = []
+ for idx in indices_at_axis:
+ coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
+ return coords_in_group
+
+ coords_in_group = [base_coord]
+ for ax, indices_at_ax in zip(axis, indices_at_axis):
+ new_coords_in_group = []
+ for coords in coords_in_group:
+ new_coords_in_group += add_index(coords, ax, indices_at_ax)
+ coords_in_group = new_coords_in_group
+
return coords_in_group
def create_group_along_axis(
- self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
+ self,
+ axis: Union[int, List[int]],
+ indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
+ backend: Optional[str] = None,
) -> ProcessGroup:
"""Create all process groups along the given axis, and return the one which the current process belongs to.
@@ -191,10 +213,21 @@ def create_group_along_axis(
Returns:
ProcessGroup: The process group along the given axis which the current process belongs to.
"""
- indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
+ if isinstance(axis, int):
+ axis = [
+ axis,
+ ]
+ if indices_at_axis is not None:
+ assert isinstance(indices_at_axis[0], int)
+ indices_at_axis = [
+ indices_at_axis,
+ ]
+
+ indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis]
reduced_shape = list(self._shape)
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
- reduced_shape[axis] = 1
+ for ax in axis:
+ reduced_shape[ax] = 1
target_group = None
# use Cartesian product to generate all combinations of coordinates
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
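The generalization above lets one process group span several mesh axes at once, which is what the plugin needs to build the joint dp x sp group for `all_to_all` mode. A standalone sketch of the coordinate walk, independent of the ColossalAI class:

```python
from typing import List, Tuple

def coords_along_axes(base: Tuple[int, ...], axes: List[int],
                      indices: List[List[int]]) -> List[Tuple[int, ...]]:
    # Successively expand the base coordinate along each requested axis
    coords = [base]
    for ax, idxs in zip(axes, indices):
        coords = [c[:ax] + (i,) + c[ax + 1:] for c in coords for i in idxs]
    return coords

# A (dp, pp, tp, sp) = (2, 1, 1, 2) mesh, grouped along the dp and sp axes:
print(coords_along_axes((0, 0, 0, 0), [0, 3], [[0, 1], [0, 1]]))
# [(0, 0, 0, 0), (0, 0, 0, 1), (1, 0, 0, 0), (1, 0, 0, 1)]
```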
diff --git a/colossalai/inference/engine/policies/bloom.py b/colossalai/inference/engine/policies/bloom.py
index f35b50189e82..5bc47c3c1a49 100644
--- a/colossalai/inference/engine/policies/bloom.py
+++ b/colossalai/inference/engine/policies/bloom.py
@@ -114,12 +114,12 @@ def get_held_layers(self) -> List[Module]:
stage_manager = self.pipeline_stage_manager
held_layers = []
- layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+ layers_per_stage = stage_manager.distribute_layers(len(module.h))
if stage_manager.is_first_stage():
held_layers.append(module.word_embeddings)
held_layers.append(module.word_embeddings_layernorm)
held_layers.append(self.model.lm_head)
- start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
diff --git a/colossalai/inference/engine/policies/chatglm2.py b/colossalai/inference/engine/policies/chatglm2.py
index 3e1d94f4785c..c7c6f3b927e1 100644
--- a/colossalai/inference/engine/policies/chatglm2.py
+++ b/colossalai/inference/engine/policies/chatglm2.py
@@ -69,11 +69,11 @@ def get_held_layers(self) -> List[nn.Module]:
stage_manager = self.pipeline_stage_manager
held_layers = []
- layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
+ layers_per_stage = stage_manager.distribute_layers(module.num_layers)
if stage_manager.is_first_stage():
held_layers.append(module.embedding)
held_layers.append(module.output_layer)
- start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.encoder.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
if module.encoder.post_layer_norm:
diff --git a/colossalai/inference/engine/policies/llama.py b/colossalai/inference/engine/policies/llama.py
index 11517d7e8a13..a57a4e50cdb9 100644
--- a/colossalai/inference/engine/policies/llama.py
+++ b/colossalai/inference/engine/policies/llama.py
@@ -194,11 +194,11 @@ def get_held_layers(self) -> List[Module]:
stage_manager = self.pipeline_stage_manager
held_layers = []
- layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
held_layers.append(self.model.lm_head)
- start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py
index 148c3e3fc08a..353e29b3d122 100644
--- a/colossalai/kernel/kernel_loader.py
+++ b/colossalai/kernel/kernel_loader.py
@@ -6,7 +6,7 @@
CpuAdamX86Extension,
FlashAttentionDaoCudaExtension,
FlashAttentionNpuExtension,
- FlashAttentionXformersCudaExtension,
+ FlashAttentionSdpaCudaExtension,
FusedOptimizerCudaExtension,
LayerNormCudaExtension,
MoeCudaExtension,
@@ -65,9 +65,9 @@ def load(self, ext_name: str = None):
else:
usable_exts = []
for ext in exts:
- if ext.is_hardware_available():
+ if ext.is_available():
# make sure the machine is compatible during kernel loading
- ext.assert_hardware_compatible()
+ ext.assert_compatible()
usable_exts.append(ext)
assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
@@ -106,4 +106,20 @@ class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
class FlashAttentionLoader(KernelLoader):
- REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]
+ REGISTRY = [
+ FlashAttentionNpuExtension,
+ FlashAttentionDaoCudaExtension,
+ FlashAttentionSdpaCudaExtension,
+ ]
+
+
+class FlashAttentionWithPaddingMaskLoader(KernelLoader):
+ REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
+
+
+class FlashAttentionWithCustomMaskLoader(KernelLoader):
+ REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
+
+
+class FlashAttentionForFloatAndCustomMaskLoader(KernelLoader):
+ REGISTRY = [FlashAttentionSdpaCudaExtension]
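Loader usage is unchanged by the registry split (compare the `FlashAttentionLoader().load()` call in the `colo_attention.py` file removed below): `load()` walks the registry, keeps the extensions whose `is_available()` returns True and that pass `assert_compatible()`, and builds the first usable kernel. For example:

```python
from colossalai.kernel.kernel_loader import FlashAttentionLoader

# Picks the dao/sdpa/npu flash attention extension depending on the machine;
# raises if no usable kernel exists
attn_kernel = FlashAttentionLoader().load()
```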
diff --git a/colossalai/legacy/tensor/tensor_spec.py b/colossalai/legacy/tensor/tensor_spec.py
index 5bdd384e5e15..44d8d04b9540 100644
--- a/colossalai/legacy/tensor/tensor_spec.py
+++ b/colossalai/legacy/tensor/tensor_spec.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from typing import Optional
from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec
@@ -17,5 +17,5 @@ class ColoTensorSpec:
"""
pg: ProcessGroup
- dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE)
+ dist_attr: Optional[_DistSpec] = field(default_factory=lambda: _DistSpec(DistPlacementPattern.REPLICATE))
compute_attr: Optional[ComputeSpec] = None
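This one-line change fixes a shared mutable default: without `default_factory`, every `ColoTensorSpec` instance would share the same `_DistSpec` object. A minimal reproduction of the pattern, independent of ColossalAI:

```python
from dataclasses import dataclass, field

@dataclass
class Spec:
    # default_factory builds a fresh list per instance instead of sharing one object
    attrs: list = field(default_factory=list)

a, b = Spec(), Spec()
a.attrs.append(1)
assert b.attrs == []  # holds only because of default_factory
```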
diff --git a/colossalai/nn/layer/colo_attention.py b/colossalai/nn/layer/colo_attention.py
deleted file mode 100644
index 0b7011e8e2d8..000000000000
--- a/colossalai/nn/layer/colo_attention.py
+++ /dev/null
@@ -1,209 +0,0 @@
-import enum
-import math
-import warnings
-from dataclasses import dataclass
-from typing import Iterable, Optional, Tuple
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange
-
-from colossalai.accelerator import get_accelerator
-from colossalai.kernel.kernel_loader import FlashAttentionLoader
-
-
-@dataclass
-class SeqLenInfo:
- seqlens: Iterable[int] = None
- indices: torch.Tensor = None
- max_seqlen: int = None
- cu_seqlens: torch.Tensor = None
-
- @staticmethod
- def materialize(
- attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
- ):
- if attn_mask is not None:
- indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
- seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
- else:
- batch_size, tgt_len = size[0], size[1]
- indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
- seqlens = torch.LongTensor([tgt_len] * batch_size, device=device)
- max_seqlen = max(seqlens)
- cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
- return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
-
-
-class AttnMaskType(enum.Enum):
- padding = 1
- causal = 2
- paddedcausal = 3
-
-
-class Unpad(torch.autograd.Function):
- """
- Adapted from
- https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
- """
-
- @staticmethod
- def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
- ctx.save_for_backward(indices)
- # [b, s, ...]
- assert tensor.ndim >= 3
- ctx.bsz = tensor.shape[0]
- out = rearrange(tensor, "b s ... -> (b s) ...")
- ctx.shape = out.shape
- # [ntokens, ...]
- return out[indices]
-
- @staticmethod
- def backward(ctx, grad_output):
- (indices,) = ctx.saved_tensors
- # [ntokens, ...]
- grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
- grad[indices] = grad_output
- grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
- # [b, s, ...]
- return grad, None
-
-
-class Repad(torch.autograd.Function):
- """
- Adapted from
- https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
- """
-
- @staticmethod
- def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
- ctx.save_for_backward(indices)
- # [ntokens, ...]
- tensor = tensor
- out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
- # [b*s, ...]
- out[indices] = tensor
- return out
-
- @staticmethod
- def backward(ctx, grad_output):
- (indices,) = ctx.saved_tensors
- # [b*s, ...]
- grad = grad_output[indices]
- # [ntokens, ...]
- return grad, None, None, None
-
-
-class ColoAttention(torch.nn.Module):
- def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
- super().__init__()
- assert (
- embed_dim % num_heads == 0
- ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
- if scale is not None:
- self.scale = scale
- else:
- self.scale = 1 / math.sqrt(embed_dim // num_heads)
- self.dropout = dropout
-
- self.attn = FlashAttentionLoader().load()
-
- @staticmethod
- def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
- return Unpad.apply(tensor, indices)
-
- @staticmethod
- def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
- return Repad.apply(tensor, indices, batch_size, seq_len)
-
- def forward(
- self,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attn_mask: Optional[torch.Tensor] = None,
- origin_attn_mask: Optional[torch.Tensor] = None,
- attn_mask_type: Optional[AttnMaskType] = None,
- bias: Optional[torch.Tensor] = None,
- ):
- """
- ColoAttention
-
- Args:
- q: (batch, q_seqlen, nheads, headdim)
- k: (batch, kv_seqlen, nheads, headdim)
- v: (batch, kv_seqlen, nheads, headdim)
- origin_attn_mask: (nheads, q_seqlen, kv_seqlen)
- bias: will not be used
- Return:
- attn_out: (batch, q_seqlen, nheads, headdim).
- """
- # if flash attention is not applicable, switch to memory effcient attention
- if self.attn.__name__ == "flash_attention" and (
- query.dtype not in [torch.float16, torch.bfloat16] or bias != None
- ):
- warnings.warn(
- f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation."
- )
- self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda")
-
- padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
- causal = attn_mask_type is not None and attn_mask_type.value > 1
-
- batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
- # unpad
- seq_len_info_q = None
- seq_len_info_kv = None
- if padded:
- # bert style, unpad process
- assert (
- attn_mask is not None
- ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
- assert attn_mask.dim() == 2, (
- "attention mask is supposed to have shape (batch_size, seq_len), "
- + f"but got {attn_mask.dim()} dimensions."
- )
-
- # bert style
- if tgt_len == src_len:
- seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
- if batch_size > 1:
- query, key, value = self.unpad(
- torch.stack([query, key, value], dim=2), seq_len_info_q.indices
- ).unbind(dim=1)
- else:
- query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
- seq_len_info_kv = seq_len_info_q
- else:
- seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
- seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
- if batch_size > 1:
- query = rearrange(query, "b s ... -> c (b s) ...", c=1)
- key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind(
- dim=1
- )
- else:
- query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
-
- out = self.attn(
- query,
- key,
- value,
- seq_len_info_q=seq_len_info_q,
- seq_len_info_kv=seq_len_info_kv,
- origin_attn_mask=origin_attn_mask,
- dropout_p=self.dropout,
- scale=self.scale,
- causal=causal,
- padded=padded,
- )
-
- # repad
- if padded:
- if batch_size > 1:
- out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
- out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)
-
- if len(out.shape) == 4:
- out = rearrange(out, "b s h d -> b s (h d)")
- return out
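
For reference, the forward passes of the `Unpad`/`Repad` helpers deleted above boil down to index-based packing of variable-length sequences. A self-contained round trip in plain torch (no autograd wrapper), using a toy padding mask:

```python
# What the removed Unpad/Repad helpers do in forward: pack valid tokens of a
# padded [B, S, H] batch into a dense [ntokens, H] tensor via flat indices,
# then scatter them back.
import torch

B, S, H = 2, 4, 3
x = torch.arange(B * S * H, dtype=torch.float32).reshape(B, S, H)
# 1 = valid token, 0 = padding (second sequence has length 2)
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 0, 0]])

indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()

# unpad: [B, S, H] -> [ntokens, H]
packed = x.reshape(B * S, H)[indices]
assert packed.shape == (int(mask.sum()), H)

# repad: [ntokens, H] -> [B, S, H], padding positions restored as zeros
out = torch.zeros(B * S, H)
out[indices] = packed
out = out.reshape(B, S, H)
assert torch.equal(out * mask.unsqueeze(-1), out)
```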
diff --git a/colossalai/nn/layer/scaled_softmax.py b/colossalai/nn/layer/scaled_softmax.py
index a8d72ddd90c9..2e802db2dbca 100644
--- a/colossalai/nn/layer/scaled_softmax.py
+++ b/colossalai/nn/layer/scaled_softmax.py
@@ -8,6 +8,14 @@
from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader
+# NOTE: These kernels are compiled for specific GPU archs and are not widely applicable.
+# try:
+# from colossalai._C import scaled_masked_softmax as scaled_masked_softmax, scaled_upper_triangle_masked_softmax_cuda as scaled_upper_triang_masked_softmax
+# except ImportError:
+
+scaled_masked_softmax = None
+scaled_upper_triang_masked_softmax = None
+
class AttnMaskType(enum.Enum):
padding = 1
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index bf2f01b10e9b..58008b98f24e 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -7,7 +7,7 @@
from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils import get_current_device
@@ -327,9 +327,7 @@ def run_forward_only(
self.send_forward(output_obj)
if outputs is not None:
- if isinstance(model, ModelWrapper):
- model = model.unwrap()
- outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
+ outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
def run_forward_backward(
@@ -412,9 +410,7 @@ def run_forward_backward(
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None:
- if isinstance(model, ModelWrapper):
- model = model.unwrap()
- outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
+ outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step(
diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py
index c8f9042084da..b0556669b2bc 100644
--- a/colossalai/pipeline/stage_manager.py
+++ b/colossalai/pipeline/stage_manager.py
@@ -1,6 +1,7 @@
import contextlib
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
+import numpy as np
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -29,6 +30,8 @@ def __init__(
) -> None:
assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"
+ self.num_layers_per_stage = None
+
self.pg_mesh = pg_mesh
self.pipeline_axis = pipeline_axis
self.prev_rank: Optional[Tuple[int, ...]] = None
@@ -69,6 +72,88 @@ def __init__(
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None
+ @property
+ def control_distribute_layers(self) -> bool:
+ return self.num_layers_per_stage is not None
+
+ def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
+ """Set the distribution configuration.
+ This allows user to customize the number of layers for each stage.
+
+ Args:
+ num_model_layers (int): Number of layers in the model.
+ num_layers_per_stage (List[int]): Number of layers for each stage.
+ """
+ assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage])
+ assert sum(num_layers_per_stage) == num_model_layers
+ assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1)
+ self.num_model_layers = num_model_layers
+ self.num_layers_per_stage = num_layers_per_stage
+
+ def distribute_layers(
+ self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
+ ) -> List[int]:
+ """Divide layers into stages"""
+ num_stages = self.num_stages if num_stages is None else num_stages
+ num_model_chunks = (
+ (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
+ )
+
+ if self.control_distribute_layers:
+ assert num_layers == self.num_model_layers
+ return self.num_layers_per_stage
+
+ else:
+ quotient = num_layers // (num_stages * num_model_chunks)
+ remainder = num_layers % (num_stages * num_model_chunks)
+
+ # calculate the num_layers per stage
+ layers_per_stage = [quotient] * num_stages * num_model_chunks
+
+ # deal with the rest layers
+ if remainder > 0:
+ start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
+ for i in range(start_position, start_position + remainder):
+ layers_per_stage[i] += 1
+ return layers_per_stage
+
+ def get_stage_index(
+ self,
+ layers_per_stage: List[int],
+ stage: Optional[int] = None,
+ num_model_chunks: Optional[int] = None,
+ num_stages: Optional[int] = None,
+ ) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
+ """
+ Get the start index and end index of layers for each stage.
+
+ Args:
+ layers_per_stage (List[int]): number of layers for each stage
+ stage (int): the stage index
+ num_stages (int): number of stages
+ num_model_chunks (int): number of model chunks
+
+ Returns:
+ - Tuple[int, int]: the start index and end index of this stage
+ - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk
+
+ """
+ stage = self.stage if stage is None else stage
+ num_model_chunks = (
+ (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
+ )
+ num_stages = self.num_stages if num_stages is None else num_stages
+
+ num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
+
+ stage_indices = []
+ for model_chunk in range(num_model_chunks):
+ start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
+ end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
+ stage_indices.append([start_idx, end_idx])
+
+ return stage_indices[0] if num_model_chunks == 1 else stage_indices
+
def is_first_stage(self, ignore_chunk: bool = False) -> bool:
"""Is the current stage the first stage.
diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py
index 77c2af8d18f7..234e7131728f 100644
--- a/colossalai/shardformer/__init__.py
+++ b/colossalai/shardformer/__init__.py
@@ -1 +1 @@
-from .shard import ShardConfig, ShardFormer
+from .shard import GradientCheckpointConfig, ModelSharder, PipelineGradientCheckpointConfig, ShardConfig, ShardFormer
diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index 56e8b08c4e4a..7b8aa53800f0 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -1,3 +1,5 @@
+from ._operation import all_to_all_comm
+from .attn import AttnMaskType, ColoAttention
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
@@ -23,4 +25,7 @@
"FusedRMSNorm",
"FusedLinear1D_Col",
"ParallelModule",
+ "AttnMaskType",
+ "ColoAttention",
+ "all_to_all_comm",
]
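
`all_to_all_comm`, newly exported here, redistributes sequence-parallel shards: with `scatter_dim=2, gather_dim=1`, each of P ranks trades a `[B, S/P, H]` sequence shard for a `[B, S, H/P]` hidden shard. A local emulation of that semantics (ranks modeled as a plain list, no `torch.distributed`):

```python
# Local emulation of all_to_all_comm(input_, pg, scatter_dim=2, gather_dim=1).
import torch

P, B, S, H = 2, 1, 4, 6
full = torch.arange(B * S * H, dtype=torch.float32).reshape(B, S, H)
shards = list(torch.chunk(full, P, dim=1))  # per-rank [B, S/P, H]


def emulated_all_to_all(inputs, scatter_dim, gather_dim):
    # rank src sends chunk dst of its input to rank dst, then every rank
    # concatenates what it received along gather_dim
    sent = [torch.tensor_split(x, P, dim=scatter_dim) for x in inputs]
    return [torch.cat([sent[src][dst] for src in range(P)], dim=gather_dim)
            for dst in range(P)]


outputs = emulated_all_to_all(shards, scatter_dim=2, gather_dim=1)
for out in outputs:
    assert out.shape == (B, S, H // P)  # full sequence, sharded hidden dim
# rank 0 now holds hidden slice 0 of the full sequence
assert torch.equal(outputs[0], full[:, :, : H // P])
```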
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index 241770901ed7..82d37bb4cf94 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -167,6 +167,97 @@ def backward(ctx, grad_output):
return grad_input, grad_weight, grad_bias, None, None, None
+def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
+ # currently only supports a single tensor as output
+ group_size = dist.get_world_size(process_group)
+ cur_rank = dist.get_rank(process_group)
+
+ # output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]
+
+ # initialization of ring communication
+ recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
+ send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
+ rank_map = list(dist.get_process_group_ranks(process_group))
+ recv_rank = rank_map[recv_rank]
+ send_rank = rank_map[send_rank]
+ recv_tensors = {}
+ send_tensors = {}
+ for k, v in input_to_gather.items():
+ recv_tensors[k] = torch.empty_like(v)
+ send_tensors[k] = v.clone()
+
+ def communicate_step():
+ comm_ops = []
+ for k in recv_tensors:
+ comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group))
+ comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group))
+ return dist.batch_isend_irecv(comm_ops)
+
+ def switch_step():
+ for k in recv_tensors:
+ send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]
+
+ output_tensors = []
+
+ handles = communicate_step()
+ # first round: special case, retrieve from local tensor
+ output_tensors.append(func(**input_to_gather, **input_local))
+ for i in range(group_size - 2):
+ for handle in handles:
+ handle.wait()
+
+ switch_step()
+
+ handles = communicate_step()
+
+ # actual computation
+ output_tensors.append(func(**send_tensors, **input_local))
+
+ # final round: special case, no need to send/recv again
+ for handle in handles:
+ handle.wait()
+ output_tensors.append(func(**recv_tensors, **input_local))
+
+ return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
+
+
+class _GatherForwardReduceScatterBackward(torch.autograd.Function):
+ """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+
+ Args:
+ input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+ process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+ overlap (`bool`): Whether to overlap the all_gather op and gradient calculation in backward.
+
+ """
+
+ @staticmethod
+ def forward(ctx, input_, process_group, dim):
+ ctx.process_group = process_group
+ ctx.dim = dim
+
+ return _gather(input_, dim, process_group)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ dim = ctx.dim
+ process_group = ctx.process_group
+
+ # do reduce-scatter
+ new_shape = list(grad_output.shape)
+ assert (
+ new_shape[dim] % dist.get_world_size(process_group) == 0
+ ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
+ new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
+ grad_list = [
+ item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)
+ ]
+ output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
+ dist.reduce_scatter(output, grad_list, group=process_group)
+
+ return output, None, None
+
+
class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
@@ -178,7 +269,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
- def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
+ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
@@ -186,12 +277,25 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter,
ctx.dim = dim
ctx.overlap = overlap
- input_parallel = _gather(input_, dim, process_group)
+ if ring is True:
+ input_to_gather = {"input": input_}
+ input_local = {"weight": weight}
- if bias is not None:
- output = F.linear(input_parallel, weight, bias)
+ output = _ring_as_gather(
+ F.linear,
+ input_to_gather=input_to_gather,
+ input_local=input_local,
+ process_group=process_group,
+ )
+
+ if bias is not None:
+ output += bias
else:
- output = F.linear(input_parallel, weight)
+ input_parallel = _gather(input_, dim, process_group)
+ if bias is not None:
+ output = F.linear(input_parallel, weight, bias)
+ else:
+ output = F.linear(input_parallel, weight)
return output
@@ -294,11 +398,146 @@ def backward(ctx, grad_output):
# wait until reduce-scatter finished
reducescatter_handle.wait()
- return output, grad_weight, grad_bias, None, None, None, None
+ return output, grad_weight, grad_bias, None, None, None, None, None
+
+
+def _ring_as_reducescatter(
+ func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1
+):
+ # currently only supports a single tensor as output
+ group_size = dist.get_world_size(process_group)
+ cur_rank = dist.get_rank(process_group)
+
+ # initialization of ring communication
+ recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
+ send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
+ rank_map = list(dist.get_process_group_ranks(process_group))
+ recv_rank = rank_map[recv_rank]
+ send_rank = rank_map[send_rank]
+ input_tensors = []
+ for _ in range(group_size):
+ input_tensors.append({})
+ for k, v in input_to_reducescatter.items():
+ input_shape = v.shape
+ assert input_shape[reducescatter_dim] % group_size == 0
+ _input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim))
+ for i in range(group_size):
+ input_tensors[i][k] = _input_tensors[i]
+ input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
+ input_tensors.reverse()
+
+ output_tensor = func(**input_tensors[0], **input_local)
+ recv_tensor = torch.empty_like(output_tensor)
+ send_tensor = output_tensor.clone()
+
+ def communicate_step():
+ recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
+ send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
+ return dist.batch_isend_irecv([recv_op, send_op])
+
+ handles = communicate_step()
+ # first round: special case, retrieve from local tensor
+ for i in range(group_size - 2):
+ # actual computation
+ output_tensor = func(**input_tensors[i + 1], **input_local)
+
+ for handle in handles:
+ handle.wait()
+ output_tensor += recv_tensor
+
+ tmp_tensor = send_tensor
+ send_tensor = output_tensor
+ output_tensor = tmp_tensor
+
+ handles = communicate_step()
+
+ # final round: special case, no need to send/recv again
+ output_tensor = func(**input_tensors[-1], **input_local)
+ for handle in handles:
+ handle.wait()
+ output_tensor += recv_tensor
+ return output_tensor
class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
- """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+ """Reduce-scatter input from sequence parallel in forward and gather gradient in backward with ring
+
+ Args:
+ input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+ process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+ overlap (`bool`): Whether to overlap the all_gather op and gradient calculation in backward.
+
+ """
+
+ @staticmethod
+ def forward(ctx, input_, weight, bias, process_group, dim, ring):
+ ctx.save_for_backward(input_, weight, bias)
+ ctx.use_bias = bias is not None
+ ctx.process_group = process_group
+ ctx.dim = dim
+
+ if ring is True:
+ input_to_reducescatter = {"input": input_}
+ input_local = {"weight": weight}
+
+ if bias is not None:
+ input_to_reducescatter["bias"] = bias
+
+ output = _ring_as_reducescatter(
+ F.linear,
+ input_to_reducescatter=input_to_reducescatter,
+ input_local=input_local,
+ process_group=process_group,
+ )
+ else:
+ if bias is not None:
+ partial_output = F.linear(input_, weight, bias)
+ else:
+ partial_output = F.linear(input_, weight)
+
+ output_shape = list(partial_output.shape)
+ assert (
+ output_shape[dim] % dist.get_world_size(process_group) == 0
+ ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
+ output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)
+
+ output_list = [
+ item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
+ ]
+ output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
+ dist.reduce_scatter(output, output_list, group=process_group)
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_, weight, bias = ctx.saved_tensors
+ use_bias = ctx.use_bias
+ dim = ctx.dim
+ process_group = ctx.process_group
+
+ # To hook into Gemini's '__torch_function__', add a view operation to weight and bias. Used in FusedLayerNorm
+ if use_bias:
+ bias = bias.view(bias.shape)
+
+ grad_output = _gather(grad_output, dim, process_group)
+
+ # TODO Need to fully optimize
+ total_input = input_
+ grad_input = grad_output.matmul(weight)
+ grad_output = grad_output.contiguous()
+ # Convert the tensor shapes to 2D for execution compatibility
+ if len(grad_output.shape) > 2:
+ grad_output = grad_output.view(-1, grad_output.shape[-1])
+ total_input = total_input.view(-1, total_input.shape[-1])
+ grad_weight = grad_output.t().matmul(total_input)
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+ return grad_input, grad_weight, grad_bias, None, None, None
+
+
+class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
+ """Reduce-scatter input from sequence parallel in forward and gather gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
@@ -343,7 +582,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
- def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
+ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
@@ -351,9 +590,24 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter,
ctx.dim = dim
ctx.overlap = overlap
- input_parallel = _gather(input_, dim, process_group)
+ if ring is True:
+ input_to_gather = {}
+ input_local = {}
+ input_to_gather["input"] = input_
+ input_local["other"] = weight
- output = torch.matmul(input_parallel, weight)
+ output = _ring_as_gather(
+ torch.matmul,
+ input_to_gather=input_to_gather,
+ input_local=input_local,
+ process_group=process_group,
+ gather_dim=dim,
+ )
+
+ else:
+ input_parallel = _gather(input_, dim, process_group)
+
+ output = torch.matmul(input_parallel, weight)
if bias is not None:
output = output + bias
@@ -433,7 +687,7 @@ def backward(ctx, grad_output):
# wait until reduce-scatter finished
reducescatter_handle.wait()
- return output, grad_weight, grad_bias, None, None, None, None
+ return output, grad_weight, grad_bias, None, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -448,14 +702,17 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
"""
@staticmethod
- def forward(ctx, input_, dim, process_group):
+ def forward(ctx, input_, dim, process_group, grad_scale=None):
ctx.process_group = process_group
ctx.dim = dim
+ ctx.grad_scale = grad_scale
return _split(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
- return _gather(grad_output, ctx.dim, ctx.process_group), None, None
+ if ctx.grad_scale is not None:
+ grad_output = grad_output * ctx.grad_scale
+ return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None
class _ReduceForward(torch.autograd.Function):
@@ -505,14 +762,50 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
"""
@staticmethod
- def forward(ctx, input_, dim, process_group):
+ def forward(ctx, input_, dim, process_group, grad_scale=None):
ctx.process_group = process_group
ctx.dim = dim
+ ctx.grad_scale = grad_scale
return _gather(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
- return _split(grad_output, ctx.dim, ctx.process_group), None, None
+ if ctx.grad_scale is not None:
+ grad_output = grad_output * ctx.grad_scale
+ return _split(grad_output, ctx.dim, ctx.process_group), None, None, None
+
+
+class _AllToAll(torch.autograd.Function):
+ """All-to-all communication.
+
+ Args:
+ input_: input matrix
+ process_group: communication group
+ scatter_dim: scatter dimension
+ gather_dim: gather dimension
+ """
+
+ @staticmethod
+ def forward(ctx, input_, process_group, scatter_dim, gather_dim):
+ ctx.process_group = process_group
+ ctx.scatter_dim = scatter_dim
+ ctx.gather_dim = gather_dim
+ world_size = dist.get_world_size(process_group)
+ bsz, _, _ = input_.shape
+
+ # using all_to_all_single when batch size is 1
+ if bsz == 1:
+ return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
+ else:
+ return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)
+
+ @staticmethod
+ def backward(ctx, *grad_output):
+ process_group = ctx.process_group
+ scatter_dim = ctx.gather_dim
+ gather_dim = ctx.scatter_dim
+ return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
+ return (return_grad, None, None, None)
class HookParameter(torch.autograd.Function):
@@ -608,6 +901,40 @@ def _reduce_scatter(input_, dim=1, process_group=None):
return output
+def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
+ input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
+ output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
+ dist.all_to_all(output_list, input_list, group=group)
+ return torch.cat(output_list, dim=gather_dim).contiguous()
+
+
+def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
+ inp_shape = list(input_.shape)
+ inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
+ if scatter_dim < 2:
+ input_t = input_.reshape([seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]).contiguous()
+ else:
+ input_t = (
+ input_.reshape([-1, seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :])
+ .transpose(0, 1)
+ .contiguous()
+ )
+
+ output = torch.empty_like(input_t)
+ dist.all_to_all_single(output, input_t, group=group)
+
+ if scatter_dim < 2:
+ output = output.transpose(0, 1).contiguous()
+
+ return output.reshape(
+ inp_shape[:gather_dim]
+ + [
+ inp_shape[gather_dim] * seq_world_size,
+ ]
+ + inp_shape[gather_dim + 1 :]
+ ).contiguous()
+
+
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
@@ -617,31 +944,39 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
def linear_gather_forward_reducescatter_backward(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
):
return _LinearWithGatherForwardReduceScatterBackward.apply(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
)
-def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
- return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
+def gather_forward_reducescatter_backward(input_, process_group, dim):
+ return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)
+
+
+def reducescatter_forward_gather_backward(input_, process_group, dim):
+ return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
+
+
+def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
+ return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring)
def matmul_gather_forward_reducescatter_backward(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
):
return _MatmulWithGatherForwardReduceScatterBackward.apply(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
)
-def gather_forward_split_backward(input_, dim, process_group):
- return _GatherForwardSplitBackward.apply(input_, dim, process_group)
+def gather_forward_split_backward(input_, dim, process_group, grad_scale=None):
+ return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale)
-def split_forward_gather_backward(input_, dim, process_group):
- return _SplitForwardGatherBackward.apply(input_, dim, process_group)
+def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
+ return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)
def reduce_forward(input_, process_group):
@@ -650,3 +985,7 @@ def reduce_forward(input_, process_group):
def reduce_backward(input_, process_group):
return _ReduceBackward.apply(input_, process_group)
+
+
+def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
+ return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
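
The `_ring_as_gather` helper added above overlaps P2P transfers with per-shard matmuls, then rotates the partial outputs back into rank order before concatenating. The rotation is easy to verify locally; the following sketch emulates the arrival order seen by one rank and checks the result against a plain gather-then-matmul:

```python
# Local check of the ring-all-gather schedule in _ring_as_gather: rank r
# computes partials in arrival order (r, r+1, ..., r+P-1 mod P), and the
# final rotation restores rank order before the concatenation.
import torch
import torch.nn.functional as F

group_size, cur_rank = 4, 1
torch.manual_seed(0)
shards = [torch.randn(2, 3, 4) for _ in range(group_size)]  # per-rank [b, s/P, h]
weight = torch.randn(5, 4)                                  # replicated weight

# shards reach cur_rank in the order r, r+1, ..., r+P-1 (mod P)
arrival = [(cur_rank + i) % group_size for i in range(group_size)]
partials = [F.linear(shards[r], weight) for r in arrival]

# the final rotation in _ring_as_gather, before torch.cat along gather_dim=1
rotated = partials[group_size - cur_rank:] + partials[:group_size - cur_rank]
ring_out = torch.cat(rotated, dim=1)

# reference: all-gather the input first, then one big matmul
ref = F.linear(torch.cat(shards, dim=1), weight)
assert torch.allclose(ring_out, ref)
```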
diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
new file mode 100644
index 000000000000..f3f6e59d3d6a
--- /dev/null
+++ b/colossalai/shardformer/layer/attn.py
@@ -0,0 +1,269 @@
+from enum import Enum
+from typing import Callable, Dict, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from colossalai.kernel.kernel_loader import (
+ FlashAttentionForFloatAndCustomMaskLoader,
+ FlashAttentionLoader,
+ FlashAttentionWithCustomMaskLoader,
+ FlashAttentionWithPaddingMaskLoader,
+ KernelLoader,
+)
+
+__all__ = [
+ "AttnMaskType",
+ "ColoAttention",
+]
+
+
+class AttnMaskType(Enum):
+ CUSTOM = 0
+ PADDED = 1
+ CAUSAL = 2
+ PADDED_CAUSAL = 3
+
+
+def invert_mask(mask: torch.Tensor) -> torch.Tensor:
+ """Invert the mask tensor.
+
+ Args:
+ mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Skv]
+
+ Returns:
+ torch.Tensor: Inverted mask tensor.
+ """
+ inverted_mask = 1.0 - mask
+ return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(mask.dtype).min)
+
+
+# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
+def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]:
+ """Get padding information from padding mask.
+
+ Args:
+ padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S]
+
+ Returns:
+ Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices)
+ """
+ seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return max_seqlen_in_batch, cu_seqlens, indices
+
+
+class ColoAttention:
+ _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
+
+ @staticmethod
+ def _init_kernels_dispatch():
+ if ColoAttention._kernel_dispatch_map is None:
+ # fp16/bf16
+ half_dispatch_map = {
+ None: FlashAttentionLoader(),
+ AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
+ AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(),
+ AttnMaskType.CAUSAL: FlashAttentionLoader(),
+ AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(),
+ }
+ # fp32
+ float_dispatch_map = {
+ None: FlashAttentionForFloatAndCustomMaskLoader(),
+ AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
+ AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
+ }
+ ColoAttention._kernel_dispatch_map = {
+ torch.float16: half_dispatch_map,
+ torch.bfloat16: half_dispatch_map,
+ torch.float32: float_dispatch_map,
+ }
+
+ @staticmethod
+ def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
+ ColoAttention._init_kernels_dispatch()
+ if (
+ dtype not in ColoAttention._kernel_dispatch_map
+ or mask_type not in ColoAttention._kernel_dispatch_map[dtype]
+ ):
+ raise ValueError(
+ "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
+ )
+ # lazy load
+ if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
+ ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
+ mask_type
+ ].load()
+ return ColoAttention._kernel_dispatch_map[dtype][mask_type]
+
+ @staticmethod
+ def prepare_attn_kwargs(
+ shape_4d: Tuple[int],
+ dtype: torch.dtype,
+ device: torch.device,
+ q_padding_mask: Optional[torch.Tensor] = None,
+ kv_padding_mask: Optional[torch.Tensor] = None,
+ is_causal: bool = False,
+ ) -> Dict[str, torch.Tensor]:
+ """Return a dictionary of keyword arguments for the attention function. It supports 4 mask types.
+ 1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves.
+ 2. padded mask: recv padding mask and is_causal=False, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
+ 3. causal mask: no padding mask and is_causal=True, return {attention_mask, attention_mask_type}.
+ 4. padded causal mask: recv padding mask and is_causal=True, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
+
+ Args:
+ shape_4d (Tuple[int]): Should be (B, 1, Sq, Skv)
+ dtype (torch.dtype): Dtype of attention mask, generally should be ``hidden_states.dtype``
+ device (torch.device): Device of attention mask, generally should be ``hidden_states.device``
+ q_padding_mask (Optional[torch.Tensor], optional): Padding mask of query. It should be a long tensor or int tensor.
+ The shape should be [B, Sq]. ``1`` means valid token, and ``0`` means padding token. Defaults to None.
+ kv_padding_mask (Optional[torch.Tensor], optional): Padding mask of key and value. It should be a long tensor or int tensor.
+ The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token.
+ If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
+ is_causal (bool, optional): Whether to use causal attention mask. Defaults to False.
+
+ Returns:
+ Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function.
+ """
+ if q_padding_mask is None and not is_causal:
+ return {}
+ assert len(shape_4d) == 4 and shape_4d[1] == 1
+ b, _, s_q, s_kv = shape_4d
+ outputs = {}
+ if (q_padding_mask is None or q_padding_mask.bool().all()) and (
+ kv_padding_mask is None or kv_padding_mask.bool().all()
+ ):
+ # no padding
+ assert is_causal
+ outputs["attention_mask_type"] = AttnMaskType.CAUSAL
+ attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
+ else:
+ if kv_padding_mask is None:
+ # self attention
+ kv_padding_mask = q_padding_mask
+ assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (
+ b,
+ s_kv,
+ ), f"q_padding_mask shape {q_padding_mask.shape} should be ({b}, {s_q}) and kv_padding_mask shape {kv_padding_mask.shape} should be ({b}, {s_kv})."
+ attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device)
+ max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
+ max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
+ outputs.update(
+ {
+ "cu_seqlens_q": cu_seqlens_q,
+ "cu_seqlens_kv": cu_seqlens_kv,
+ "max_seqlen_q": max_seqlen_q,
+ "max_seqlen_kv": max_seqlen_kv,
+ "q_indices": q_indices,
+ "kv_indices": kv_indices,
+ }
+ )
+ if is_causal:
+ outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
+ attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+ else:
+ outputs["attention_mask_type"] = AttnMaskType.PADDED
+ attention_mask = invert_mask(attention_mask).unsqueeze(1)
+ outputs["attention_mask"] = attention_mask
+ return outputs
+
+ @staticmethod
+ def attention(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ attention_mask_type: AttnMaskType = AttnMaskType.CUSTOM,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_kv: Optional[int] = None,
+ q_indices: Optional[torch.Tensor] = None,
+ kv_indices: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+ ) -> torch.Tensor:
+ """Flash Attention function. It supports 4 mask types.
+ 1. custom mask: recv attention_mask
+ 2. padded mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices
+ 3. causal mask: recv attention_mask, attention_mask_type
+ 4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices
+
+ Args:
+ q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D]
+ k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D]
+ v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D]
+ attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None.
+ attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
+ cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths
+ of the sequences in the batch, used to index into q.
+ Shape should be [B+1]. Defaults to None.
+ cu_seqlens_kv (Optional[torch.Tensor], optional): The cumulative sequence lengths
+ of the sequences in the batch, used to index into kv.
+ Shape should be [B+1]. Defaults to None.
+ max_seqlen_q (Optional[int], optional): Maximum query sequence length in the batch. Defaults to None.
+ max_seqlen_kv (Optional[int], optional): Maximum key/value sequence length in the batch. Defaults to None.
+            q_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from the flattened query padding mask.
+                Shape should be [NUM_TOKENS]. Defaults to None.
+            kv_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from the flattened key/value padding mask.
+                Shape should be [NUM_TOKENS]. Defaults to None.
+ dropout_p (float, optional): Dropout probability. Defaults to 0.0.
+ scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None.
+
+ Returns:
+ torch.Tensor: Output tensor. Shape should be [B, N, Sq, D]
+ """
+        # known issue: sdpa does not support attention masks that contain a whole row of masked tokens, which leads to nan
+        # this case is usual when a padding mask is used and self attention is performed
+        # thus, we don't use sdpa when a padding mask is used
+ # sanity check
+ if attention_mask is not None:
+ assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor."
+ if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL):
+ assert (
+ cu_seqlens_q is None
+ and cu_seqlens_kv is None
+ and max_seqlen_q is None
+ and max_seqlen_kv is None
+ and q_indices is None
+ and kv_indices is None
+ )
+ if attention_mask_type == AttnMaskType.CUSTOM:
+ assert not torch.all(attention_mask != 0, dim=-1).any()
+ elif attention_mask_type in (
+ AttnMaskType.PADDED,
+ AttnMaskType.PADDED_CAUSAL,
+ ):
+ assert (
+ cu_seqlens_q is not None
+ and cu_seqlens_kv is not None
+ and max_seqlen_q is not None
+ and max_seqlen_kv is not None
+ and q_indices is not None
+ and kv_indices is not None
+ )
+ else:
+ # if attention_mask is None, attention_mask_type should be the default value
+ assert attention_mask_type == AttnMaskType.CUSTOM
+ # kernel dispatch
+ mask_type = attention_mask_type if attention_mask is not None else None
+ attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
+ is_causal = attention_mask is not None and attention_mask_type in (
+ AttnMaskType.CAUSAL,
+ AttnMaskType.PADDED_CAUSAL,
+ )
+ return attn_func(
+ q,
+ k,
+ v,
+ dropout_p=dropout_p,
+ scale=scale,
+ attention_mask=attention_mask,
+ is_causal=is_causal,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_kv=cu_seqlens_kv,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_kv=max_seqlen_kv,
+ q_indices=q_indices,
+ kv_indices=kv_indices,
+ )
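
A usage sketch of the new `ColoAttention` API, assuming a colossalai build containing this patch: `prepare_attn_kwargs` is plain torch and runs anywhere, while the commented `attention` call additionally needs a usable flash-attention/SDPA kernel at dispatch time.

```python
# Build padded-causal attention kwargs for a batch where the second
# sequence ends with two padding tokens, then (optionally) run attention.
import torch

from colossalai.shardformer.layer import AttnMaskType, ColoAttention

b, n_heads, s, d = 2, 4, 8, 16

padding_mask = torch.ones(b, s, dtype=torch.long)
padding_mask[1, -2:] = 0  # 0 marks padding tokens

attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (b, 1, s, s),
    dtype=torch.float32,
    device=torch.device("cpu"),
    q_padding_mask=padding_mask,
    is_causal=True,
)
assert attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL
assert attn_kwargs["attention_mask"].shape == (b, 1, s, s)
print(attn_kwargs["cu_seqlens_q"])  # tensor([ 0,  8, 14], dtype=torch.int32)

q = k = v = torch.randn(b, n_heads, s, d)
# out = ColoAttention.attention(q, k, v, **attn_kwargs)  # needs a kernel backend
```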
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index eeb0ef39975f..7c8619ad8f5c 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -23,11 +23,13 @@
)
from ._operation import (
+ gather_forward_reducescatter_backward,
gather_forward_split_backward,
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
reduce_forward,
+ reducescatter_forward_gather_backward,
split_forward_gather_backward,
)
from .parallel_module import ParallelModule
@@ -74,7 +76,7 @@ def __init__(
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = False,
- seq_parallel: bool = False,
+ seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
overlap: torch.cuda.Stream = None,
skip_bias_add: bool = False,
@@ -89,7 +91,7 @@ def __init__(
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
- self.seq_parallel = seq_parallel
+ self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.overlap = overlap
self.skip_bias_add = skip_bias_add
@@ -196,12 +198,18 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
- if self.seq_parallel:
+
+ if self.seq_parallel_mode is None:
+ output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+ elif self.seq_parallel_mode == "split_gather":
+ input_parallel = gather_forward_reducescatter_backward(
+ input_parallel, self.process_group, self.seq_parallel_dim
+ )
+ output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
+ elif self.seq_parallel_mode == "ring":
output_parallel = linear_gather_forward_reducescatter_backward(
- input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap
+ input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
)
- else:
- output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if self.gather_output:
# All-gather across the partitions.
@@ -225,7 +233,8 @@ class Linear1D_Row(ParallelModule):
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
- seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
+ seq_parallel_mode (`str`): The sequence parallel mode; sequence parallelism is used when `seq_parallel_mode` is not None. Defaults to None.
+ seq_parallel_dim (`int`): The dimension along which sequence parallelism splits and gathers the sequence.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
@@ -245,7 +254,7 @@ def __init__(
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
- seq_parallel: bool = False,
+ seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
parallel_input: bool = True,
skip_bias_add: bool = False,
@@ -265,7 +274,7 @@ def __init__(
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
- self.seq_parallel = seq_parallel
+ self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group)
@@ -403,18 +412,26 @@ def forward(self, input_: Tensor) -> Tensor:
output_parallel_list[i], group=self.process_group, async_op=True
)
handle_list.append(handle)
- # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
for handle in handle_list:
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
- output_parallel = linear_with_async_comm(input_, self.weight, None, None, False)
- if self.seq_parallel:
- output = linear_reducescatter_forward_gather_backward(
+ if self.seq_parallel_mode is None:
+ output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
+ output = reduce_forward(output_parallel, self.process_group)
+ elif self.seq_parallel_mode == "split_gather":
+ output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
+ output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim
)
- else:
- output = reduce_forward(output_parallel, self.process_group)
+ elif self.seq_parallel_mode == "ring":
+ output = linear_reducescatter_forward_gather_backward(
+ input_,
+ self.weight,
+ process_group=self.process_group,
+ dim=self.seq_parallel_dim,
+ ring=True,
+ )
if not self.skip_bias_add:
if self.bias is not None:
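
How the new `seq_parallel_mode` switch is selected on the sharded linear layers, as a hedged sketch meant to run under `torchrun`: the constructor arguments mirror the signatures in this diff, but the surrounding setup is illustrative rather than a full shardformer recipe, and the return-value handling assumes the default `skip_bias_add=False`.

```python
# Sketch: run with `torchrun --nproc_per_node=2 this_script.py`.
import torch
import torch.distributed as dist

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
pg = dist.group.WORLD  # use the whole world as the tensor-parallel group
tp = dist.get_world_size(pg)

# "split_gather": all-gather activations before the column-parallel matmul
# and reduce-scatter after the row-parallel one; "ring" instead overlaps
# the communication with the matmul via the ring helpers above.
col = Linear1D_Col(in_features=64, out_features=128, process_group=pg,
                   seq_parallel_mode="split_gather", seq_parallel_dim=1,
                   device=torch.device("cuda"))
row = Linear1D_Row(in_features=128, out_features=64, process_group=pg,
                   seq_parallel_mode="split_gather", seq_parallel_dim=1,
                   device=torch.device("cuda"))

# each rank holds a [batch, seq_len / tp, hidden] shard of the sequence
x = torch.randn(2, 8 // tp, 64, device="cuda")
y = row(col(x))  # with skip_bias_add=False each layer returns a single tensor
print(y.shape)   # [2, 8 // tp, 64]: still sequence-sharded after reduce-scatter
```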
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index 12476d050600..dc3634238f74 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -25,12 +25,12 @@
from ._operation import (
gather_forward_split_backward,
- linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
matmul_gather_forward_reducescatter_backward,
matmul_with_async_comm,
reduce_backward,
reduce_forward,
+ reducescatter_forward_gather_backward,
split_forward_gather_backward,
)
from .parallel_module import ParallelModule
@@ -150,7 +150,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
device (`torch.device`): The device of parameters, defaults to None.
n_fused (int): The number items fused, defaults to 3 (QKV).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
- seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
+ seq_parallel_mode (str): If set to ``None``, sequence parallelism is disabled; otherwise the corresponding sequence parallel mode is used. Defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
@@ -175,7 +175,7 @@ def __init__(
process_group: ProcessGroup = None,
async_communication: bool = False,
gather_output: bool = False,
- seq_parallel: bool = False,
+ seq_parallel_mode: str = None,
overlap: bool = False,
skip_bias_add: bool = False,
n_fused: int = 3,
@@ -190,7 +190,7 @@ def __init__(
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
- self.seq_parallel = seq_parallel
+ self.seq_parallel_mode = seq_parallel_mode
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
@@ -312,17 +312,22 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
- if self.seq_parallel:
- input_parallel = input_
- output_parallel = matmul_gather_forward_reducescatter_backward(
- input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap
- )
- else:
+ if self.seq_parallel_mode is None:
# Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group)
output_parallel = matmul_with_async_comm(
input_parallel, self.weight, bias, self.process_group, self.async_communication
)
+ elif self.seq_parallel_mode == "split_gather":
+ input_parallel = input_
+ output_parallel = matmul_gather_forward_reducescatter_backward(
+ input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap
+ )
+ elif self.seq_parallel_mode == "ring":
+ input_parallel = input_
+ output_parallel = matmul_gather_forward_reducescatter_backward(
+ input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True
+ )
if self.gather_output:
# All-gather across the partitions.
@@ -347,7 +352,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
- seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
+ seq_parallel_mode (str): If set to ``None``, sequence parallelism is disabled; otherwise the corresponding sequence parallel mode is used. Defaults to None.
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
@@ -366,7 +371,7 @@ def __init__(
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
- seq_parallel: bool = False,
+ seq_parallel_mode: str = None,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
@@ -385,7 +390,7 @@ def __init__(
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
- self.seq_parallel = seq_parallel
+ self.seq_parallel_mode = seq_parallel_mode
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
@@ -528,11 +533,15 @@ def forward(self, input_: Tensor) -> Tensor:
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
- output_parallel = torch.matmul(input_, self.weight)
- if self.seq_parallel:
- output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
- else:
+ if self.seq_parallel_mode is None:
+ output_parallel = torch.matmul(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group)
+ elif self.seq_parallel_mode == "split_gather":
+ output_parallel = torch.matmul(input_, self.weight)
+ output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
+ elif self.seq_parallel_mode == "ring":
+ output_parallel = torch.matmul(input_, self.weight)
+ output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
if not self.skip_bias_add:
if self.bias is not None:
@@ -702,7 +711,6 @@ def from_native_module(
# process_group=process_group,
# is_transposed=False)
# linear_1d.bias.data.copy_(sharded_bias.data)
- print(linear_1d.weight.shape)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py
index 0d2cc1b3370d..9c6ced4454dc 100644
--- a/colossalai/shardformer/layer/utils.py
+++ b/colossalai/shardformer/layer/utils.py
@@ -35,17 +35,21 @@ def is_sp_partial_derived_param(param):
return getattr(param, "partial_derived", False)
@staticmethod
- def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None):
+ def allreduce_partial_data_grad(
+ process_group: ProcessGroup,
+ model: nn.Module = None,
+ grads: List[torch.Tensor] = None,
+ ):
"""
Allreduce partial derived gradients across the specified process group.
This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.
Args:
- tp_group (ProcessGroup): The process group for gradient synchronization.
+ process_group (ProcessGroup): The process group for gradient synchronization.
model (nn.Module): The model from which gradients will be synchronized.
grads (List[torch.Tensor]): The list of gradients to be synchronized.
-
+ only_sp_partial (bool): Whether to handle all parameters or only those marked as partial derived.
Raises:
AssertionError: If both `model` and `grads` are provided or neither is provided.
"""
@@ -53,22 +57,26 @@ def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None,
assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None."
# Get the size of the process group, which determines whether synchronization is needed.
- tp_size = get_world_size(tp_group) if tp_group is not None else 1
+ group_size = get_world_size(process_group) if process_group is not None else 1
- if tp_size == 1:
+ if group_size == 1:
# If the process group size is 1, no synchronization is required.
return
if model is not None:
# If `model` is provided, extract partial derived gradients from the model's parameters.
grads = []
+
for p in model.parameters():
- if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
- grads.append(p.grad.data)
+ if p.grad is not None:
+ if SeqParallelUtils.is_sp_partial_derived_param(p):
+ grads.append(p.grad.data)
# Flatten and reduce the gradients using the specified process group.
+ if len(grads) == 0:
+ return
coalesced = _flatten_dense_tensors(grads)
- dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+ dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group)
# Unflatten the synchronized gradients and update the model's gradients.
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
@@ -76,7 +84,7 @@ def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None,
else:
# If `grads` are provided explicitly, synchronize those gradients directly.
coalesced = _flatten_dense_tensors(grads)
- dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+ dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
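
The gradient-sync helper above uses the classic coalesced all-reduce: flatten many gradients into one buffer, reduce once, scatter the result back. A local sketch of that pattern, with the collective guarded so the snippet also runs without an initialized process group:

```python
# Coalesced gradient sync, as in SeqParallelUtils.allreduce_partial_data_grad.
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.ones(3), torch.full((2, 2), 2.0)]

coalesced = _flatten_dense_tensors(grads)  # one contiguous buffer
if dist.is_initialized():                  # skipped in a local run
    dist.all_reduce(coalesced, op=dist.ReduceOp.SUM)

# scatter the (reduced) buffer back into the original gradient tensors
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)

print(grads[0].shape, grads[1].shape)  # original shapes preserved
```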
diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index 7411e1d0ec46..0838fcee682e 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -186,13 +186,14 @@ def bert_model_forward(
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config is not None and shard_config.enable_sequence_parallelism:
- hidden_states = split_forward_gather_backward(
- hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
- )
- if encoder_hidden_states is not None:
- encoder_hidden_states = split_forward_gather_backward(
- encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+ if shard_config.sequence_parallelism_mode == "split_gather":
+ hidden_states = split_forward_gather_backward(
+ hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = split_forward_gather_backward(
+ encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+ )
for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
if stage_manager.is_first_stage() and idx == 0:
@@ -240,9 +241,10 @@ def custom_forward(*inputs):
# When sequence parallelism done, gather the output tensor in forward and split it in backward
if shard_config is not None and shard_config.enable_sequence_parallelism:
- hidden_states = gather_forward_split_backward(
- hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
- )
+ if shard_config.sequence_parallelism_mode == "split_gather":
+ hidden_states = gather_forward_split_backward(
+ hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+ )
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
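The BERT change gates the sequence-parallel split/gather on `sequence_parallelism_mode == "split_gather"` instead of applying it whenever sequence parallelism is enabled. A hedged sketch of that gating pattern follows; the helper name is hypothetical, while `shard_config` and the split function mirror the names used in the diff.

```python
def maybe_split_sequence(hidden_states, shard_config, split_fn):
    """Apply the sequence-dim split only in split_gather mode (illustrative helper)."""
    if shard_config is not None and shard_config.enable_sequence_parallelism:
        if shard_config.sequence_parallelism_mode == "split_gather":
            return split_fn(
                hidden_states,
                dim=1,  # BERT activations are [batch, seq, hidden]
                process_group=shard_config.tensor_parallel_process_group,
            )
    return hidden_states
```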
diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py
index d5c10541a28f..bd84c87c667d 100644
--- a/colossalai/shardformer/modeling/blip2.py
+++ b/colossalai/shardformer/modeling/blip2.py
@@ -3,6 +3,8 @@
import torch
import torch.nn as nn
+from colossalai.shardformer.layer import ColoAttention
+
def forward_fn():
def forward(
@@ -62,8 +64,6 @@ def forward(
def get_blip2_flash_attention_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
- from colossalai.nn.layer.colo_attention import ColoAttention
-
def forward(
self: Blip2Attention,
hidden_states: torch.Tensor,
@@ -71,16 +71,25 @@ def forward(
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
-
+ assert head_mask is None, "head_mask is not supported in FlashAttention"
bsz, tgt_len, embed_dim = hidden_states.size()
mixed_qkv = self.qkv(hidden_states)
- mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)
- query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
+ mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+ query_states, key_states, value_states = (
+ mixed_qkv[0],
+ mixed_qkv[1],
+ mixed_qkv[2],
+ )
- attention = ColoAttention(
- embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout.p, scale=self.scale
+ dropout_p = self.dropout.p if self.training else 0.0
+ context_layer = ColoAttention.attention(
+ query_states,
+ key_states,
+ value_states,
+ dropout_p=dropout_p,
+ scale=self.scale,
)
- context_layer = attention(query_states, key_states, value_states)
+ context_layer = context_layer.permute(0, 2, 1, 3).reshape(bsz, tgt_len, self.embed_dim)
output = self.projection(context_layer)
outputs = (output, None)
@@ -93,7 +102,11 @@ def forward(
def get_jit_fused_blip2_QFormer_self_output_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput
- def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ def forward(
+ self: Blip2QFormerSelfOutput,
+ hidden_states: torch.Tensor,
+ input_tensor: torch.Tensor,
+ ) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states)
@@ -105,7 +118,11 @@ def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_ten
def get_jit_fused_blip2_QFormer_output_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput
- def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ def forward(
+ self: Blip2QFormerOutput,
+ hidden_states: torch.Tensor,
+ input_tensor: torch.Tensor,
+ ) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states)
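The BLIP-2 change above replaces the old pattern of instantiating a `ColoAttention` module with a call to the static `ColoAttention.attention` API, disabling dropout outside training. A minimal sketch of the new call shape, assuming `module` stands in for the `Blip2Attention` instance and q/k/v are `[batch, heads, seq, head_dim]` as produced by the permute in the diff:

```python
from colossalai.shardformer.layer import ColoAttention  # the import this diff adds

def blip2_attn_sketch(module, q, k, v):
    dropout_p = module.dropout.p if module.training else 0.0  # no dropout at eval time
    out = ColoAttention.attention(q, k, v, dropout_p=dropout_p, scale=module.scale)
    bsz, _, tgt_len, _ = q.shape
    # fold heads back into the channel dim: [B, H, T, D] -> [B, T, H*D]
    return out.permute(0, 2, 1, 3).reshape(bsz, tgt_len, -1)
```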
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index d94c30d29e71..fe70376e144d 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -213,10 +213,11 @@ def bloom_model_forward(
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
- if shard_config.enable_sequence_parallelism:
- hidden_states = split_forward_gather_backward(
- hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
- )
+ if shard_config and shard_config.enable_sequence_parallelism:
+ if shard_config.sequence_parallelism_mode == "split_gather":
+ hidden_states = split_forward_gather_backward(
+ hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+ )
start_idx, end_idx = stage_index[0], stage_index[1]
for i, (block, layer_past) in enumerate(
@@ -261,10 +262,11 @@ def custom_forward(*inputs):
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# When sequence parallelism done, gather the output tensor in forward and split it in backward
- if shard_config.enable_sequence_parallelism:
- hidden_states = gather_forward_split_backward(
- hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
- )
+ if shard_config and shard_config.enable_sequence_parallelism:
+ if shard_config.sequence_parallelism_mode == "split_gather":
+ hidden_states = gather_forward_split_backward(
+ hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+ )
if stage_manager.is_last_stage():
# Add last hidden state
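BLOOM gets the same `split_gather` gating as BERT, splitting on `dim=1` because its activations are `[batch, seq, hidden]`; ChatGLM below is seq-first and splits on `dim=0`. The stand-ins below use plain `torch.chunk`/`torch.cat` to show only the forward-pass semantics; the real `split_forward_gather_backward` and `gather_forward_split_backward` additionally run the conjugate collective in the backward pass.

```python
import torch

def split_seq(x: torch.Tensor, rank: int, world_size: int, dim: int) -> torch.Tensor:
    # forward semantics only: keep this rank's slice of the sequence axis
    return x.chunk(world_size, dim=dim)[rank]

def gather_seq(shards: list, dim: int) -> torch.Tensor:
    # forward semantics only: reassemble the full sequence from all slices
    return torch.cat(shards, dim=dim)
```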
diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py
index d13bd34926a5..9207b34d0d1c 100644
--- a/colossalai/shardformer/modeling/chatglm2.py
+++ b/colossalai/shardformer/modeling/chatglm2.py
@@ -1,4 +1,5 @@
""" PyTorch ChatGLM model. """
+
from typing import List, Optional, Tuple
import torch
@@ -9,63 +10,49 @@
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
+from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
def get_flash_core_attention_forward():
- from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
-
from .chatglm2_6b.modeling_chatglm import CoreAttention
def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
- pytorch_major_version = int(torch.__version__.split(".")[0])
- if pytorch_major_version >= 2:
- query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
- if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
- context_layer = torch.nn.functional.scaled_dot_product_attention(
- query_layer, key_layer, value_layer, is_causal=True
- )
- else:
- if attention_mask is not None:
- attention_mask = ~attention_mask
- context_layer = torch.nn.functional.scaled_dot_product_attention(
- query_layer, key_layer, value_layer, attention_mask
- )
- context_layer = context_layer.permute(2, 0, 1, 3)
- new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
- context_layer = context_layer.reshape(*new_context_layer_shape)
- else:
- # Raw attention scores
- query_layer = query_layer.permute(1, 0, 2, 3).contiguous()
- key_layer = key_layer.permute(1, 0, 2, 3).contiguous()
- value_layer = value_layer.permute(1, 0, 2, 3).contiguous()
-
- scale = 1.0 / self.norm_factor
- if self.coeff is not None:
- scale = scale * self.coeff
-
- flash_attention_mask = None
- attn_mask_type = None
- if attention_mask is None:
- attn_mask_type = AttnMaskType.causal
- else:
- flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
- if not torch.all(flash_attention_mask):
- attn_mask_type = AttnMaskType.paddedcausal
-
- attention = ColoAttention(
- embed_dim=self.hidden_size_per_partition,
- num_heads=self.num_attention_heads_per_partition,
- dropout=self.attention_dropout.p,
- scale=scale,
+ query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+ attention_mask_type = AttnMaskType.CAUSAL
+ attn_bias = torch.zeros(
+ query_layer.shape[0],
+ 1,
+ query_layer.shape[2],
+ key_layer.shape[2],
+ dtype=query_layer.dtype,
+ device=query_layer.device,
)
- context_layer = attention(
- query_layer, key_layer, value_layer, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
+ temp_mask = (
+ torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
+ .tril(diagonal=0)
+ .expand(query_layer.shape[0], 1, -1, -1)
)
-
- context_layer = context_layer.permute(1, 0, -1).contiguous()
-
+ attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(query_layer.dtype).min)
+ else:
+ attention_mask_type = AttnMaskType.CUSTOM
+ if attention_mask is not None:
+ attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
+ attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min)
+ dropout_p = self.attention_dropout.p if self.training else 0.0
+ context_layer = ColoAttention.attention(
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask=attn_bias,
+ attention_mask_type=attention_mask_type,
+ dropout_p=dropout_p,
+ )
+ context_layer = context_layer.permute(2, 0, 1, 3)
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+ context_layer = context_layer.reshape(*new_context_layer_shape)
return context_layer
return forward
@@ -169,11 +156,17 @@ def chatglm_model_forward(
if self.pre_seq_len is not None:
if past_key_values is None:
past_key_values = self.get_prompt(
- batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
+ batch_size=batch_size,
+ device=input_ids.device,
+ dtype=inputs_embeds.dtype,
)
if attention_mask is not None:
attention_mask = torch.cat(
- [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
+ [
+ attention_mask.new_ones((batch_size, self.pre_seq_len)),
+ attention_mask,
+ ],
+ dim=-1,
)
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
@@ -198,17 +191,23 @@ def chatglm_model_forward(
all_hidden_states = () if output_hidden_states else None
start_idx, end_idx = stage_index[0], stage_index[1]
- if shard_config.enable_sequence_parallelism:
- hidden_states = split_forward_gather_backward(
- hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
- )
+ if shard_config and shard_config.enable_sequence_parallelism:
+ if shard_config.sequence_parallelism_mode == "split_gather":
+ hidden_states = split_forward_gather_backward(
+ hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+ )
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.encoder.gradient_checkpointing and self.encoder.training:
layer_ret = torch.utils.checkpoint.checkpoint(
- layer, hidden_states, attention_mask, rotary_pos_emb, past_key_values[idx], use_cache
+ layer,
+ hidden_states,
+ attention_mask,
+ rotary_pos_emb,
+ past_key_values[idx],
+ use_cache,
)
else:
layer_ret = layer(
@@ -222,10 +221,11 @@ def chatglm_model_forward(
if use_cache:
presents = presents + (kv_cache,)
- if shard_config.enable_sequence_parallelism:
- hidden_states = gather_forward_split_backward(
- hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
- )
+ if shard_config and shard_config.enable_sequence_parallelism:
+ if shard_config.sequence_parallelism_mode == "split_gather":
+ hidden_states = gather_forward_split_backward(
+ hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+ )
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if stage_manager.is_last_stage():
@@ -234,7 +234,14 @@ def chatglm_model_forward(
hidden_states = self.encoder.final_layernorm(hidden_states)
if not return_dict:
return tuple(
- v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
+ v
+ for v in [
+ hidden_states,
+ presents,
+ all_hidden_states,
+ all_self_attentions,
+ ]
+ if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
@@ -368,7 +375,9 @@ def forward(
# Run encoder.
# [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
inputs_embeds = split_forward_gather_backward(
- inputs_embeds, dim=0, process_group=shard_config.tensor_parallel_process_group
+ inputs_embeds,
+ dim=0,
+ process_group=shard_config.tensor_parallel_process_group,
)
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds,
@@ -380,7 +389,9 @@ def forward(
)
hidden_states = gather_forward_split_backward(
- hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+ hidden_states,
+ dim=0,
+ process_group=shard_config.tensor_parallel_process_group,
)
if not return_dict:
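The rewritten ChatGLM core attention drops the PyTorch-version branch and always builds an additive bias for `ColoAttention.attention`: zero where attention is allowed, the dtype minimum where it is masked, so adding the bias to the raw scores before softmax effectively removes masked positions. A minimal sketch of that construction for the equal-length causal case handled above:

```python
import torch

def causal_bias(batch, q_len, kv_len, dtype=torch.float16, device="cpu"):
    """Additive mask: 0 keeps a position, finfo(dtype).min removes it (sketch)."""
    bias = torch.zeros(batch, 1, q_len, kv_len, dtype=dtype, device=device)
    keep = torch.ones(q_len, kv_len, dtype=torch.bool, device=device).tril(diagonal=0)
    bias.masked_fill_(keep.logical_not().expand(batch, 1, -1, -1), torch.finfo(dtype).min)
    return bias
```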
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 1e22d9094eae..1306c8aa6299 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -21,12 +21,82 @@
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
from ..layer._operation import gather_forward_split_backward
+logger = logging.get_logger(__name__)
+
+
+def _get_attention_mask(
+ self: GPT2Model,
+ shard_config: ShardConfig,
+ hidden_states: torch.Tensor,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
+ attention_mask: Optional[torch.FloatTensor],
+ encoder_hidden_states: Optional[torch.Tensor],
+ encoder_attention_mask: Optional[torch.FloatTensor],
+) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
+ batch_size, seq_len = hidden_states.shape[:2]
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ if shard_config.enable_flash_attention:
+ encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
+ (encoder_batch_size, 1, seq_len, encoder_sequence_length),
+ dtype=hidden_states.dtype,
+ device=encoder_hidden_states.device,
+ q_padding_mask=attention_mask,
+ kv_padding_mask=encoder_attention_mask,
+ )
+ else:
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device)
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ if shard_config.enable_flash_attention:
+ encoder_attention_mask = {"attention_mask": None}
+ else:
+ encoder_attention_mask = None
+ # GPT2Attention mask.
+ past_key_values_length = 0
+ if past_key_values is not None and past_key_values[0] is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ if shard_config.enable_flash_attention:
+ if attention_mask is not None:
+ attention_mask = attention_mask.view(batch_size, -1)
+ attention_mask = ColoAttention.prepare_attn_kwargs(
+ (batch_size, 1, seq_len, seq_len + past_key_values_length),
+ hidden_states.dtype,
+ hidden_states.device,
+ attention_mask,
+ is_causal=True,
+ )
+ elif attention_mask is not None:
+ if batch_size <= 0:
+ raise ValueError("batch_size has to be defined and > 0")
+ attention_mask = attention_mask.view(batch_size, -1)
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+ return attention_mask, encoder_attention_mask
+
class GPT2PipelineForwards:
"""
@@ -83,10 +153,10 @@ def gpt2_model_forward(
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
- batch_size = input_ids.shape[0]
+ input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
- batch_size = inputs_embeds.shape[0]
+ inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
@@ -99,38 +169,7 @@ def gpt2_model_forward(
input_shape = hidden_states.size()[:-1]
device = hidden_states.device
hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
- batch_size = hidden_states.shape[0]
-
- # GPT2Attention mask.
- if attention_mask is not None:
- if batch_size <= 0:
- raise ValueError("batch_size has to be defined and > 0")
- attention_mask = attention_mask.view(batch_size, -1)
- # We create a 3D attention mask from a 2D tensor mask.
- # Sizes are [batch_size, 1, 1, to_seq_length]
- # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
- # this attention mask is more simple than the triangular masking of causal attention
- # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
- attention_mask = attention_mask[:, None, None, :]
-
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
- # masked positions, this operation will create a tensor which is 0.0 for
- # positions we want to attend and the dtype's smallest value for masked positions.
- # Since we are adding it to the raw scores before the softmax, this is
- # effectively the same as removing these entirely.
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
- attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-
- # If a 2D or 3D attention mask is provided for the cross-attention
- # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
- if self.config.add_cross_attention and encoder_hidden_states is not None:
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
- encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
- if encoder_attention_mask is None:
- encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
- encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
- else:
- encoder_attention_mask = None
+ hidden_states.shape[0]
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@@ -156,6 +195,16 @@ def gpt2_model_forward(
output_shape = input_shape + (hidden_states.size(-1),)
+ attention_mask, encoder_attention_mask = _get_attention_mask(
+ self,
+ shard_config,
+ hidden_states,
+ past_key_values,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
@@ -169,10 +218,13 @@ def gpt2_model_forward(
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
- if shard_config.enable_sequence_parallelism:
- hidden_states = split_forward_gather_backward(
- hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
- )
+ if shard_config and shard_config.enable_sequence_parallelism:
+ if shard_config.sequence_parallelism_mode == "split_gather":
+ hidden_states = split_forward_gather_backward(
+ hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group,
+ )
# Going through held blocks.
start_idx, end_idx = stage_index[0], stage_index[1]
@@ -180,7 +232,7 @@ def gpt2_model_forward(
block = self.h[i]
torch.cuda.set_device(hidden_states.device)
# Ensure that attention_mask is always on the same device as hidden_states
- if attention_mask is not None:
+ if torch.is_tensor(attention_mask):
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
@@ -227,10 +279,13 @@ def custom_forward(*inputs):
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# When sequence parallelism done, gather the output tensor in forward and split it in backward
- if shard_config.enable_sequence_parallelism:
- hidden_states = gather_forward_split_backward(
- hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
- )
+ if shard_config and shard_config.enable_sequence_parallelism:
+ if shard_config.sequence_parallelism_mode == "split_gather":
+ hidden_states = gather_forward_split_backward(
+ hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group,
+ )
if stage_manager.is_last_stage():
hidden_states = self.ln_f(hidden_states)
@@ -245,7 +300,13 @@ def custom_forward(*inputs):
if not return_dict:
return tuple(
v
- for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+ for v in [
+ hidden_states,
+ presents,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
if v is not None
)
@@ -331,9 +392,11 @@ def gpt2_lmhead_model_forward(
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
- if shard_config.enable_tensor_parallelism:
+ if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
loss = cross_entropy_1d(
- shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+ shift_logits,
+ shift_labels,
+ process_group=shard_config.tensor_parallel_process_group,
)
else:
loss = loss_fct(shift_logits, shift_labels)
@@ -733,27 +796,18 @@ def gpt2_for_sequence_classification_forward(
def get_gpt2_flash_attention_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
- from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
-
- def split_heads(tensor, num_heads, attn_head_size):
- """
- Splits hidden_size dim into attn_head_size and num_heads
- """
- new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
- tensor = tensor.view(new_shape)
- return tensor
-
def forward(
self: GPT2Attention,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[dict] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[dict] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+ assert head_mask is None, "FlashAttention does not support head_mask"
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
@@ -766,10 +820,9 @@ def forward(
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
- query = split_heads(query, self.num_heads, self.head_dim)
- key = split_heads(key, self.num_heads, self.head_dim)
- value = split_heads(value, self.num_heads, self.head_dim)
+ query = self._split_heads(query, self.num_heads, self.head_dim)
+ key = self._split_heads(key, self.num_heads, self.head_dim)
+ value = self._split_heads(value, self.num_heads, self.head_dim)
if layer_past is not None:
past_key, past_value = layer_past
@@ -781,29 +834,14 @@ def forward(
else:
present = None
- if not self.is_cross_attention:
- attn_mask_type = AttnMaskType.causal
- flash_attention_mask = None
- if attention_mask != None:
- flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
- if not torch.all(flash_attention_mask):
- if attn_mask_type == AttnMaskType.causal:
- attn_mask_type == AttnMaskType.paddedcausal
- else:
- attn_mask_type = AttnMaskType.padding
-
- scale = value.size(-1) ** -0.5
+ scale = 1.0
+ if self.scale_attn_weights:
+ scale /= value.size(-1) ** 0.5
if self.scale_attn_by_inverse_layer_idx:
- scale = scale * (1 / float(self.layer_idx + 1))
-
- # use coloattention
- if not hasattr(self, "attention"):
- self.attention = ColoAttention(
- embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
- )
-
- attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
-
+ scale /= float(self.layer_idx + 1)
+ dropout_p = self.attn_dropout.p if self.training else 0.0
+ attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present, None)
@@ -813,9 +851,9 @@ def forward(
return forward
-def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
+def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig):
def forward(
- self,
+ self: GPT2Model,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
@@ -840,12 +878,13 @@ def forward(
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
- batch_size = input_ids.shape[0]
+ input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
- batch_size = inputs_embeds.shape[0]
+ inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
@@ -862,39 +901,201 @@ def forward(
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
- position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+ position_ids = torch.arange(
+ past_length,
+ input_shape[-1] + past_length,
+ dtype=torch.long,
+ device=device,
+ )
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
- # GPT2Attention mask.
- if attention_mask is not None:
- if batch_size <= 0:
- raise ValueError("batch_size has to be defined and > 0")
- attention_mask = attention_mask.view(batch_size, -1)
- # We create a 3D attention mask from a 2D tensor mask.
- # Sizes are [batch_size, 1, 1, to_seq_length]
- # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
- # this attention mask is more simple than the triangular masking of causal attention
- # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
- attention_mask = attention_mask[:, None, None, :]
-
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
- # masked positions, this operation will create a tensor which is 0.0 for
- # positions we want to attend and the dtype's smallest value for masked positions.
- # Since we are adding it to the raw scores before the softmax, this is
- # effectively the same as removing these entirely.
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
- attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-
- # If a 2D or 3D attention mask is provided for the cross-attention
- # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
- if self.config.add_cross_attention and encoder_hidden_states is not None:
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
- encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
- if encoder_attention_mask is None:
- encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
- encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # head_mask has shape n_layer x batch x n_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+ position_embeds = self.wpe(position_ids)
+ hidden_states = inputs_embeds + position_embeds
+
+ if token_type_ids is not None:
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+
+ hidden_states = self.drop(hidden_states)
+
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+
+ attention_mask, encoder_attention_mask = _get_attention_mask(
+ self,
+ shard_config,
+ hidden_states,
+ past_key_values,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(hidden_states.device)
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
+ if layer_past is not None:
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+ # Ensure that attention_mask is always on the same device as hidden_states
+ if torch.is_tensor(attention_mask):
+ attention_mask = attention_mask.to(hidden_states.device)
+ if isinstance(head_mask, torch.Tensor):
+ head_mask = head_mask.to(hidden_states.device)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, use_cache, output_attentions)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ None,
+ attention_mask,
+ head_mask[i],
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask[i],
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+ # Model Parallel: If it's the last layer for that device, put things on the next device
+ if self.model_parallel:
+ for k, v in self.device_map.items():
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+ hidden_states = self.ln_f(hidden_states)
+
+ hidden_states = hidden_states.view(output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ presents,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+ return forward
+
+
+def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ input_ids.shape[0]
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ inputs_embeds.shape[0]
else:
- encoder_attention_mask = None
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
+ if position_ids is not None:
+ position_ids = position_ids.view(-1, input_shape[-1])
+
+ if past_key_values is None:
+ past_length = 0
+ past_key_values = tuple([None] * len(self.h))
+ else:
+ past_length = past_key_values[0][0].size(-2)
+ if position_ids is None:
+ position_ids = torch.arange(
+ past_length,
+ input_shape[-1] + past_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@@ -914,6 +1115,15 @@ def forward(
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
+ attention_mask, encoder_attention_mask = _get_attention_mask(
+ self,
+ shard_config,
+ hidden_states,
+ past_key_values,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
if self.gradient_checkpointing and self.training:
if use_cache:
@@ -931,7 +1141,9 @@ def forward(
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(
- hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+ hidden_states,
+ dim=1,
+ process_group=shard_config.sequence_parallel_process_group,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -942,7 +1154,7 @@ def forward(
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
- if attention_mask is not None:
+ if torch.is_tensor(attention_mask):
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
@@ -996,7 +1208,9 @@ def custom_forward(*inputs):
# When sequence parallelism done, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward(
- hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+ hidden_states,
+ dim=1,
+ process_group=shard_config.sequence_parallel_process_group,
)
hidden_states = self.ln_f(hidden_states)
@@ -1008,7 +1222,13 @@ def custom_forward(*inputs):
if not return_dict:
return tuple(
v
- for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+ for v in [
+ hidden_states,
+ presents,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
if v is not None
)
@@ -1078,15 +1298,11 @@ def forward(
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
- loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
- if shard_config.enable_tensor_parallelism:
- loss = cross_entropy_1d(
- shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
- )
- else:
- loss = loss_fct(shift_logits, shift_labels)
+ loss = cross_entropy_1d(
+ shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+ )
if not shard_config.parallel_output:
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
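In the GPT-2 changes, mask construction moves into `_get_attention_mask`, which returns either a kwargs dict (flash path, via `ColoAttention.prepare_attn_kwargs`) or an additive tensor (eager path). That is why the per-block device move now checks `torch.is_tensor(attention_mask)` instead of `is not None`. A hedged sketch of both pieces, with signatures taken only from the calls visible in this diff:

```python
import torch
from colossalai.shardformer.layer import ColoAttention

def build_attn_kwargs(batch_size, seq_len, past_len, hidden_states, attention_mask):
    # flash path: the returned dict is later splatted as **attention_mask
    if attention_mask is not None:
        attention_mask = attention_mask.view(batch_size, -1)
    return ColoAttention.prepare_attn_kwargs(
        (batch_size, 1, seq_len, seq_len + past_len),
        hidden_states.dtype,
        hidden_states.device,
        attention_mask,
        is_causal=True,
    )

def move_mask(mask, device):
    # the mask may be a dict on the flash path, so only tensors are moved
    return mask.to(device) if torch.is_tensor(mask) else mask
```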
diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py
index 1990d7df3279..5c254d1e76bd 100644
--- a/colossalai/shardformer/modeling/gptj.py
+++ b/colossalai/shardformer/modeling/gptj.py
@@ -19,9 +19,54 @@
from transformers.utils import is_torch_fx_proxy, logging
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
+logger = logging.get_logger(__name__)
+
+
+def _get_attention_mask(
+ self: GPTJModel,
+ shard_config: ShardConfig,
+ hidden_states: torch.Tensor,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
+ attention_mask: Optional[torch.FloatTensor],
+) -> Optional[Union[torch.Tensor, dict]]:
+ batch_size, seq_len = hidden_states.shape[:2]
+ past_key_values_length = 0
+ if past_key_values is not None and past_key_values[0] is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ if shard_config.enable_flash_attention:
+ if attention_mask is not None:
+ attention_mask = attention_mask.view(batch_size, -1)
+ attention_mask = ColoAttention.prepare_attn_kwargs(
+ (batch_size, 1, seq_len, seq_len + past_key_values_length),
+ hidden_states.dtype,
+ hidden_states.device,
+ attention_mask,
+ is_causal=True,
+ )
+ elif attention_mask is not None:
+ if batch_size <= 0:
+ raise ValueError("batch_size has to be defined and > 0")
+ attention_mask = attention_mask.view(batch_size, -1)
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+ return attention_mask
+
class GPTJPipelineForwards:
"""
@@ -96,26 +141,6 @@ def gptj_model_forward(
batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device
- # Attention mask.
- if attention_mask is not None:
- if batch_size <= 0:
- raise ValueError("batch_size has to be defined and > 0")
- attention_mask = attention_mask.view(batch_size, -1)
- # We create a 3D attention mask from a 2D tensor mask.
- # Sizes are [batch_size, 1, 1, to_seq_length]
- # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
- # this attention mask is more simple than the triangular masking of causal attention
- # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
- attention_mask = attention_mask[:, None, None, :]
-
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
- # masked positions, this operation will create a tensor which is 0.0 for
- # positions we want to attend and the dtype's smallest value for masked positions.
- # Since we are adding it to the raw scores before the softmax, this is
- # effectively the same as removing these entirely.
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
- attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_attention_heads x N x N
@@ -139,6 +164,8 @@ def gptj_model_forward(
output_shape = input_shape + (hidden_states.size(-1),)
+ attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
+
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
@@ -154,7 +181,9 @@ def gptj_model_forward(
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward(
- hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+ hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group,
)
# Going through held blocks.
@@ -209,7 +238,9 @@ def custom_forward(*inputs):
# When sequence parallelism done, gather the output tensor in forward and split it in backward
if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward(
- hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+ hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group,
)
if stage_manager.is_last_stage():
@@ -223,7 +254,14 @@ def custom_forward(*inputs):
if stage_manager.is_last_stage():
if not return_dict:
return tuple(
- v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
+ v
+ for v in [
+ hidden_states,
+ presents,
+ all_hidden_states,
+ all_self_attentions,
+ ]
+ if v is not None
)
return BaseModelOutputWithPast(
@@ -530,24 +568,11 @@ def gptj_for_question_answering_forward(
def get_gptj_flash_attention_forward():
from transformers.models.gptj.modeling_gptj import GPTJAttention
- from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
-
- def split_heads(tensor, num_attention_heads, attn_head_size, rotary):
- """
- Splits hidden dim into attn_head_size and num_attention_heads
- """
- new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
- tensor = tensor.view(new_shape)
- if rotary or len(tensor.shape) in [4, 5]:
- return tensor
- else:
- raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
-
def forward(
self: GPTJAttention,
hidden_states: torch.FloatTensor,
layer_past: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[dict] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
@@ -556,13 +581,14 @@ def forward(
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
+ assert head_mask is None, "head_mask is not supported for FlashAttention"
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
- query = split_heads(query, self.num_attention_heads, self.head_dim, True)
- key = split_heads(key, self.num_attention_heads, self.head_dim, True)
- value = split_heads(value, self.num_attention_heads, self.head_dim, False)
+ query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
+ key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
+ value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
# The logic to conditionally copy to GPU could not be traced, so we do this
@@ -591,46 +617,202 @@ def forward(
key = apply_rotary_pos_emb(key, sin, cos)
query = apply_rotary_pos_emb(query, sin, cos)
- # key = key.permute(0, 2, 1, 3)
- # query = query.permute(0, 2, 1, 3)
- key = key.to(dtype=value.dtype) # fp16 compatibility
- query = query.to(dtype=value.dtype)
+ key = key.permute(0, 2, 1, 3)
+ query = query.permute(0, 2, 1, 3)
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
- key = torch.cat((past_key, key), dim=1)
- value = torch.cat((past_value, value), dim=1)
+ key = torch.cat((past_key, key), dim=-2)
+ value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
present = (key, value)
else:
present = None
- # use AttnMaskType and ColoAttention
- attn_mask_type = AttnMaskType.causal
- flash_attention_mask = None
- if attention_mask != None:
- if attn_mask_type == AttnMaskType.causal:
- attn_mask_type == AttnMaskType.paddedcausal
- else:
- attn_mask_type = AttnMaskType.padding
- flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+ dropout_p = self.attn_dropout.p if self.training else 0.0
+ attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p)
+ attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
+ attn_output = self.out_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+ outputs = (attn_output, present, None)
- # use coloattention
- scale = value.size(-1) ** -0.5
+ return outputs # a, present, (attentions)
+
+ return forward
- attention = ColoAttention(
- embed_dim=self.embed_dim, num_heads=self.num_attention_heads, dropout=self.attn_dropout.p, scale=scale
+
+def gptj_model_forward_for_flash_attention(shard_config: ShardConfig):
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ input_ids.shape[0]
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ inputs_embeds.shape[0]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
- attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
- attn_output = self.out_proj(attn_output)
- attn_output = self.resid_dropout(attn_output)
- outputs = (attn_output, present, None)
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
- return outputs # a, present, (attentions)
+ if position_ids is not None:
+ position_ids = position_ids.view(-1, input_shape[-1]).long()
+
+ if past_key_values is None:
+ past_length = 0
+ past_key_values = tuple([None] * len(self.h))
+ else:
+ past_length = past_key_values[0][0].size(-2)
+
+ if position_ids is None:
+ position_ids = torch.arange(
+ past_length,
+ input_shape[-1] + past_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x num_attention_heads x N x N
+ # head_mask has shape n_layer x batch x num_attention_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+
+ hidden_states = inputs_embeds
+
+ if token_type_ids is not None:
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+
+ hidden_states = self.drop(hidden_states)
+
+ attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
+
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(hidden_states.device)
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
+ if layer_past is not None:
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+ # Ensure that attention_mask is always on the same device as hidden_states
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(hidden_states.device)
+ if isinstance(head_mask, torch.Tensor):
+ head_mask = head_mask.to(hidden_states.device)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, use_cache, output_attentions)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ None,
+ attention_mask,
+ position_ids,
+ head_mask[i],
+ )
+ else:
+ outputs = block(
+ hidden_states=hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+ # Model Parallel: If it's the last layer for that device, put things on the next device
+ if self.model_parallel:
+ for k, v in self.device_map.items():
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+ hidden_states = self.ln_f(hidden_states)
+
+ hidden_states = hidden_states.view(output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ presents,
+ all_hidden_states,
+ all_self_attentions,
+ ]
+ if v is not None
+ )
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
return forward
@@ -662,10 +844,10 @@ def forward(
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
- batch_size = input_ids.shape[0]
+ input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
- batch_size = inputs_embeds.shape[0]
+ inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
@@ -684,29 +866,14 @@ def forward(
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
- position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+ position_ids = torch.arange(
+ past_length,
+ input_shape[-1] + past_length,
+ dtype=torch.long,
+ device=device,
+ )
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
- # Attention mask.
- if attention_mask is not None:
- if batch_size <= 0:
- raise ValueError("batch_size has to be defined and > 0")
- attention_mask = attention_mask.view(batch_size, -1)
- # We create a 3D attention mask from a 2D tensor mask.
- # Sizes are [batch_size, 1, 1, to_seq_length]
- # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
- # this attention mask is more simple than the triangular masking of causal attention
- # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
- attention_mask = attention_mask[:, None, None, :]
-
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
- # masked positions, this operation will create a tensor which is 0.0 for
- # positions we want to attend and the dtype's smallest value for masked positions.
- # Since we are adding it to the raw scores before the softmax, this is
- # effectively the same as removing these entirely.
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
- attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_attention_heads x N x N
@@ -725,6 +892,7 @@ def forward(
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
+ attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
if self.gradient_checkpointing and self.training:
if use_cache:
@@ -740,7 +908,9 @@ def forward(
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(
- hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+ hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -801,7 +971,9 @@ def custom_forward(*inputs):
# When sequence parallelism done, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward(
- hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+ hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group,
)
hidden_states = self.ln_f(hidden_states)
@@ -812,7 +984,16 @@ def custom_forward(*inputs):
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
- return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ presents,
+ all_hidden_states,
+ all_self_attentions,
+ ]
+ if v is not None
+ )
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
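
The gpt2 hunks above swap the inline 4D-mask construction for a single `_get_attention_mask(...)` call; that helper is defined earlier in `modeling/gpt2.py` and is not shown in this excerpt. Below is a minimal single-file sketch of the dispatch it performs, with `build_attention_mask` as a hypothetical stand-in: the flash path hands the mask back as a kwargs dict, while the dense path keeps the additive-bias mask that the deleted inline block used to build.

```python
import torch

def build_attention_mask(padding_mask: torch.Tensor, dtype: torch.dtype, use_flash: bool):
    """padding_mask: [batch, seq] with 1 = attend, 0 = masked."""
    if use_flash:
        # flash path: the kernel consumes the padding mask (plus flags) directly,
        # so the helper returns a kwargs dict instead of a tensor
        return {"attention_mask": padding_mask, "is_causal": True}
    # dense path: broadcast to [batch, 1, 1, seq] and turn 0s into large negative
    # additive biases, as the deleted inline block did
    mask = padding_mask[:, None, None, :].to(dtype)
    return (1.0 - mask) * torch.finfo(dtype).min

padding = torch.tensor([[1, 1, 1, 0]])
print(build_attention_mask(padding, torch.float16, use_flash=False).shape)  # [1, 1, 1, 4]
print(build_attention_mask(padding, torch.float16, use_flash=True).keys())
```
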
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index eb8e9f748527..0f1b4ad0a5d5 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -1,22 +1,35 @@
+import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
-from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
+from transformers.models.llama.modeling_llama import (
+ LlamaForCausalLM,
+ LlamaForSequenceClassification,
+ LlamaModel,
+ apply_rotary_pos_emb,
+ repeat_kv,
+)
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer._operation import (
+ all_to_all_comm,
+ gather_forward_split_backward,
+ split_forward_gather_backward,
+)
from colossalai.shardformer.shard import ShardConfig
-from ..layer import cross_entropy_1d
-from ..layer._operation import gather_forward_split_backward
+from ..layer import ColoAttention, cross_entropy_1d
try:
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -106,18 +119,25 @@ def llama_model_forward(
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
- if attention_mask is None:
- attention_mask = torch.ones(
- (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
- )
- if LATEST_VERSION:
- attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+ if shard_config.enable_flash_attention:
+ # in this case, attention_mask is a dict rather than a tensor
+ mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+ attention_mask = ColoAttention.prepare_attn_kwargs(
+ mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
)
else:
- attention_mask = self._prepare_decoder_attention_mask(
- attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
- )
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
+ )
+ if LATEST_VERSION:
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+ )
+ else:
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+ )
if self.gradient_checkpointing and self.training:
if use_cache:
@@ -132,13 +152,25 @@ def llama_model_forward(
next_decoder_cache = () if use_cache else None
start_idx, end_idx = stage_index[0], stage_index[1]
+ num_ckpt_layers = 0
+ if self.gradient_checkpointing and self.training:
+ num_ckpt_layers = end_idx - start_idx
+ # TODO: we could replace the `gradient_checkpointing_enable` fn and keep a per-layer gradient_checkpointing flag (List[bool])
+ if shard_config.gradient_checkpoint_config is not None:
+ num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
+ stage=stage_manager.stage,
+ num_layers=end_idx - start_idx,
+ model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
+ )
+ assert num_ckpt_layers <= end_idx - start_idx
+
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
- if self.gradient_checkpointing and self.training:
+ if idx - start_idx < num_ckpt_layers:
def create_custom_forward(module):
def custom_forward(*inputs):
@@ -263,6 +295,7 @@ def llama_for_causal_lm_forward(
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
+ shard_config=shard_config,
)
past_key_values = None
@@ -279,7 +312,7 @@ def llama_for_causal_lm_forward(
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
- if shard_config.enable_tensor_parallelism:
+ if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
@@ -289,9 +322,6 @@ def llama_for_causal_lm_forward(
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
- if not shard_config.parallel_output:
- logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
-
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
@@ -356,6 +386,7 @@ def llama_for_sequence_classification_forward(
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
+ shard_config=shard_config,
)
if input_ids is not None:
@@ -421,11 +452,9 @@ def llama_for_sequence_classification_forward(
return {"hidden_states": hidden_states}
-def get_llama_flash_attention_forward(shard_config: ShardConfig):
+def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
- from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
-
llama_version = 2
try:
from transformers.models.llama.modeling_llama import repeat_kv
@@ -436,7 +465,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
def forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[dict] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
@@ -444,18 +473,30 @@ def forward(
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
+
+ if sp_mode in ["split_gather", "ring"]:
+ q_len *= sp_size
assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
- query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # sp: all-to-all communication when sequence parallelism is enabled
+ if sp_mode == "all_to_all":
+ query_states = all_to_all_comm(query_states, sp_group)
+ key_states = all_to_all_comm(key_states, sp_group)
+ value_states = all_to_all_comm(value_states, sp_group)
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
-
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
@@ -470,32 +511,14 @@ def forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
- me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
- query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
- key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
- value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
-
- flash_attention_mask = None
- attn_mask_type = AttnMaskType.causal
- if not getattr(shard_config, "causal_lm", False) and attention_mask != None:
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
- )
- flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
- attn_mask_type = AttnMaskType.paddedcausal
-
- if not hasattr(self, "attention"):
- self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
- attn_output = self.attention(
- query_states,
- key_states,
- value_states,
- attn_mask=flash_attention_mask,
- attn_mask_type=attn_mask_type,
- origin_attn_mask=attention_mask,
- )
+ assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
+ attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ # sp: all-to-all communication when sequence parallelism is enabled
+ if sp_mode == "all_to_all":
+ attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
@@ -503,6 +526,137 @@ def forward(
return forward
+def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
+ logger = logging.get_logger(__name__)
+ assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
+
+ def forward(
+ self: LlamaModel,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # with flash attention enabled, attention_mask is a dict of kwargs rather than a tensor
+ mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+ attention_mask = ColoAttention.prepare_attn_kwargs(
+ mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
+ )
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # past_key_value and output_attentions are captured from the enclosing scope
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ return forward
+
+
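
Both llama forwards above build the mask once with `ColoAttention.prepare_attn_kwargs` and then splat the resulting dict into `ColoAttention.attention(query, key, value, **attention_mask)`. Here is a minimal single-process approximation of that calling convention, assuming only `torch.nn.functional.scaled_dot_product_attention` (the real `ColoAttention` dispatches to FlashAttention kernels and also encodes padding via `q_padding_mask`; the helper names below are illustrative):

```python
import torch
import torch.nn.functional as F

def prepare_attn_kwargs(is_causal: bool = True) -> dict:
    # the real helper also materializes padding information; the causal flag
    # alone is enough to show the calling convention
    return {"is_causal": is_causal}

def attention(q, k, v, *, is_causal=True, dropout_p=0.0):
    return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=is_causal)

q = k = v = torch.randn(2, 8, 16, 64)   # [batch, heads, seq, head_dim]
attn_kwargs = prepare_attn_kwargs()
out = attention(q, k, v, **attn_kwargs) # same `**attention_mask` splat as above
print(out.shape)                        # torch.Size([2, 8, 16, 64])
```
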
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
from transformers import LlamaForCausalLM
@@ -578,23 +732,15 @@ def forward(
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
- if shard_config.enable_tensor_parallelism:
- new_vocab_size = logits.shape[-1]
- shift_logits = shift_logits.view(-1, new_vocab_size)
- loss = cross_entropy_1d(
- shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
- )
- else:
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- loss = loss_fct(shift_logits, shift_labels)
- if not shard_config.parallel_output:
- logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
+ new_vocab_size = logits.shape[-1]
+ shift_logits = shift_logits.view(-1, new_vocab_size)
+ loss = cross_entropy_1d(
+ shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+ )
if not return_dict:
output = (logits,) + outputs[1:]
@@ -609,3 +755,261 @@ def forward(
)
return forward
+
+
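
`get_lm_forward_with_dist_cross_entropy` now calls `cross_entropy_1d` unconditionally, because the logits stay sharded along the vocabulary axis. The single-process sketch below shows the math such a vocab-parallel loss has to reproduce, simulating two shards on one device; the explicit reductions stand in for all-reduces over the tensor-parallel group, and all names are illustrative:

```python
import torch
import torch.nn.functional as F

def sharded_cross_entropy(logit_shards, target, vocab_starts):
    # per-shard max, reduced globally (stands in for an all-reduce MAX)
    global_max = torch.stack([s.max(dim=-1).values for s in logit_shards]).max(dim=0).values
    # per-shard sum of exp, reduced globally (stands in for an all-reduce SUM)
    sum_exp = sum((s - global_max[:, None]).exp().sum(dim=-1) for s in logit_shards)
    # only the shard that owns the target id contributes the target logit
    target_logit = torch.zeros_like(global_max)
    for s, start in zip(logit_shards, vocab_starts):
        local = (target - start).clamp(0, s.size(-1) - 1)
        gathered = s.gather(-1, local[:, None]).squeeze(-1)
        hit = (target >= start) & (target < start + s.size(-1))
        target_logit += torch.where(hit, gathered, torch.zeros_like(gathered))
    # CE = logsumexp(logits) - logits[target], averaged over tokens
    return (global_max + sum_exp.log() - target_logit).mean()

logits = torch.randn(4, 10)                  # full [tokens, vocab]
target = torch.randint(0, 10, (4,))
shards = [logits[:, :5], logits[:, 5:]]      # two "tensor-parallel" vocab shards
loss = sharded_cross_entropy(shards, target, vocab_starts=[0, 5])
print(torch.allclose(loss, F.cross_entropy(logits, target)))  # True
```
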
+def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+ # sp: recover the full sequence length when sp mode is split_gather or ring
+ if sp_mode in ["split_gather", "ring"]:
+ q_len *= sp_size
+ if self.config.pretraining_tp > 1:
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+ query_slices = self.q_proj.weight.split(
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+ )
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+ query_states = torch.cat(query_states, dim=-1)
+
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+ key_states = torch.cat(key_states, dim=-1)
+
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+ value_states = torch.cat(value_states, dim=-1)
+
+ else:
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # sp: all-to-all communication when sequence parallelism is enabled
+ if sp_mode == "all_to_all":
+ query_states = all_to_all_comm(query_states, sp_group)
+ key_states = all_to_all_comm(key_states, sp_group)
+ value_states = all_to_all_comm(value_states, sp_group)
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ # sp: all-to-all communication when sequence parallelism is enabled
+ if sp_mode == "all_to_all":
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
+ attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
+ else:
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ if self.config.pretraining_tp > 1:
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+ else:
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+ return attn_output, attn_weights, past_key_value
+
+ return forward
+
+
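
In `all_to_all` mode the q/k/v projections above run on a sequence shard, and `all_to_all_comm` then switches the layout from sequence-sharded/all-heads to full-sequence/head-sharded before the attention kernel (and back afterwards). A single-process simulation of that layout switch, for illustration only; the real op is a `dist.all_to_all` over `sp_group`:

```python
import torch

bsz, seq, heads, head_dim, sp = 2, 8, 4, 16, 2
x = torch.randn(bsz, seq, heads * head_dim)

# before attention, each rank holds a sequence shard: [bsz, seq/sp, heads*head_dim]
seq_shards = list(x.chunk(sp, dim=1))

def simulated_all_to_all(shards):
    # rank j receives the j-th hidden slice of every rank's sequence chunk,
    # ending up with the full sequence but only heads/sp of the heads
    return [torch.cat([s.chunk(sp, dim=-1)[j] for s in shards], dim=1) for j in range(sp)]

head_shards = simulated_all_to_all(seq_shards)
print(seq_shards[0].shape)   # torch.Size([2, 4, 64]): seq-sharded, all heads
print(head_shards[0].shape)  # torch.Size([2, 8, 32]): full seq, head-sharded
```
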
+def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
+ logger = logging.get_logger(__name__)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ # modify past_key_values_length when using sequence parallelism
+ past_key_values_length *= sp_size
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if sp_mode in ["ring", "split_gather"]:
+ inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+ elif sp_mode == "all_to_all":
+ inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ )
+
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length
+ )
+
+ hidden_states = inputs_embeds
+
+ if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # past_key_value and output_attentions are captured from the enclosing scope
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ )
+
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if sp_mode == "ring" or sp_mode == "split_gather":
+ hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+ elif sp_mode == "all_to_all":
+ hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ return forward
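
For the `ring` and `split_gather` paths in `get_llama_seq_parallel_model_forward`, `split_forward_gather_backward` shards the embeddings along the sequence dimension on the way in and `gather_forward_split_backward` reassembles the hidden states on the way out. The sketch below shows the autograd shape of the first op under single-process assumptions; the zero-padded backward stands in for the all-gather a real process group would perform:

```python
import torch

class SplitForwardGatherBackward(torch.autograd.Function):
    """Forward keeps this rank's sequence chunk; backward rebuilds the gradient
    for the full sequence (the mirror image of gather_forward_split_backward)."""

    @staticmethod
    def forward(ctx, x, rank, world_size):
        ctx.rank, ctx.world_size = rank, world_size
        return x.chunk(world_size, dim=1)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_chunk):
        # on a real cluster this is an all-gather of gradient chunks; zero-padding
        # the other ranks' positions keeps the sketch single-process
        chunks = [torch.zeros_like(grad_chunk) for _ in range(ctx.world_size)]
        chunks[ctx.rank] = grad_chunk
        return torch.cat(chunks, dim=1), None, None

x = torch.randn(2, 8, 16, requires_grad=True)
y = SplitForwardGatherBackward.apply(x, 0, 2)
y.sum().backward()
print(y.shape, x.grad.shape)  # torch.Size([2, 4, 16]) torch.Size([2, 8, 16])
```
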
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index d0e267eacd25..a265264303ad 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -18,6 +18,37 @@
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer import ColoAttention
+from colossalai.shardformer.shard import ShardConfig
+
+logger = logging.get_logger(__name__)
+
+
+def _get_attention_mask(
+ self: OPTModel,
+ shard_config: ShardConfig,
+ hidden_states: torch.Tensor,
+ past_key_values_length: int,
+ attention_mask: Optional[torch.FloatTensor],
+):
+ batch_size, seq_length = hidden_states.shape[:2]
+ mask_seq_length = past_key_values_length + seq_length
+ if shard_config.enable_flash_attention:
+ attention_mask = ColoAttention.prepare_attn_kwargs(
+ (batch_size, 1, seq_length, mask_seq_length),
+ hidden_states.dtype,
+ hidden_states.device,
+ attention_mask,
+ is_causal=True,
+ )
+ else:
+ attention_mask = self.decoder._prepare_decoder_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ hidden_states,
+ past_key_values_length,
+ )
+ return attention_mask
class OPTPipelineForwards:
@@ -26,46 +57,6 @@ class OPTPipelineForwards:
under pipeline setting.
"""
- @staticmethod
- def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length):
- # create causal mask
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- from transformers.models.opt.modeling_opt import _make_causal_mask
-
- combined_attention_mask = None
- if input_shape[-1] > 1:
- combined_attention_mask = _make_causal_mask(
- input_shape,
- _dtype,
- device,
- past_key_values_length=past_key_values_length,
- )
-
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, tgt_len=input_shape[-1]).to(
- device
- )
- combined_attention_mask = (
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
- )
-
- return combined_attention_mask
-
- @staticmethod
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
- """
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
- """
- bsz, src_len = mask.size()
- tgt_len = tgt_len if tgt_len is not None else src_len
-
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
- inverted_mask = 1.0 - expanded_mask
-
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
-
@staticmethod
def opt_model_forward(
self: OPTModel,
@@ -81,6 +72,7 @@ def opt_model_forward(
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
@@ -119,7 +111,7 @@ def opt_model_forward(
if decoder.project_in is not None:
inputs_embeds = decoder.project_in(inputs_embeds)
device = input_ids.device if input_ids is not None else inputs_embeds.device
- _dtype = inputs_embeds.dtype
+ inputs_embeds.dtype
else:
if hidden_states is None:
@@ -127,7 +119,7 @@ def opt_model_forward(
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device
- _dtype = hidden_states.dtype
+ hidden_states.dtype
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
# required mask seq length can be calculated via length of past
@@ -141,13 +133,24 @@ def opt_model_forward(
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
)
- causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(
- attention_mask, input_shape, _dtype, device, past_key_values_length
- )
-
if stage_manager.is_first_stage():
+ causal_attention_mask = _get_attention_mask(
+ self,
+ shard_config,
+ inputs_embeds,
+ past_key_values_length,
+ attention_mask,
+ )
pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length)
hidden_states = inputs_embeds + pos_embeds
+ else:
+ causal_attention_mask = _get_attention_mask(
+ self,
+ shard_config,
+ hidden_states,
+ past_key_values_length,
+ attention_mask,
+ )
if decoder.gradient_checkpointing and decoder.training:
if use_cache:
@@ -249,7 +252,16 @@ def custom_forward(*inputs):
if stage_manager.is_last_stage():
if not return_dict:
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_cache,
+ all_hidden_states,
+ all_self_attns,
+ ]
+ if v is not None
+ )
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
@@ -276,6 +288,7 @@ def opt_for_causal_lm_forward(
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForCausalLM.forward.
@@ -303,6 +316,7 @@ def opt_for_causal_lm_forward(
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
+ shard_config=shard_config,
)
if stage_manager.is_last_stage():
logits = self.lm_head(outputs[0]).contiguous()
@@ -347,6 +361,7 @@ def opt_for_sequence_classification_forward(
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForSequenceClassification.forward.
@@ -371,6 +386,7 @@ def opt_for_sequence_classification_forward(
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
+ shard_config=shard_config,
)
if stage_manager.is_last_stage():
@@ -448,6 +464,7 @@ def opt_for_question_answering_forward(
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForQuestionAnswering.forward.
@@ -469,6 +486,7 @@ def opt_for_question_answering_forward(
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
+ shard_config=shard_config,
)
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
@@ -511,49 +529,47 @@ def opt_for_question_answering_forward(
return {"hidden_states": hidden_states}
-def get_opt_flash_attention_forward():
+def get_opt_flash_attention_forward(shard_config: ShardConfig):
from transformers.models.opt.modeling_opt import OPTAttention
- from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
-
def forward(
self: OPTAttention,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[dict] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
-
+ assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
+
bsz, tgt_len, _ = hidden_states.size()
- attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states).view(*attention_input_shape)
+ query_states = self.q_proj(hidden_states)
# get key, value proj
if is_cross_attention and past_key_value is not None:
- # reuse k, v, cross_attentions
- key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape)
- value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape)
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
- key_states = self.k_proj(key_value_states).view(*attention_input_shape)
- value_states = self.v_proj(key_value_states).view(*attention_input_shape)
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
- key_states = self.k_proj(hidden_states).view(*attention_input_shape)
- value_states = self.v_proj(hidden_states).view(*attention_input_shape)
- key_states = torch.cat([past_key_value[0], key_states], dim=1)
- value_states = torch.cat([past_key_value[1], value_states], dim=1)
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
- key_states = self.k_proj(hidden_states).view(*attention_input_shape)
- value_states = self.v_proj(hidden_states).view(*attention_input_shape)
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -565,38 +581,181 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
- src_len = key_states.size(1)
- if layer_head_mask != None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
-
- flash_attention_mask = None
- attn_mask_type = AttnMaskType.causal
- if attention_mask != None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- )
- flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
- if not torch.all(flash_attention_mask):
- attn_mask_type = AttnMaskType.paddedcausal
+ query_states = self._shape(query_states, tgt_len, bsz)
- attention = ColoAttention(
- embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
- )
- attn_output = attention(
- query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
+ dropout_p = self.dropout if self.training else 0.0
+ attn_output = ColoAttention.attention(
+ query_states,
+ key_states,
+ value_states,
+ **attention_mask,
+ dropout_p=dropout_p,
+ scale=self.scaling,
)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
attn_output = self.out_proj(attn_output)
+
return attn_output, None, past_key_value
return forward
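
One detail of the rewritten OPT attention worth calling out: `dropout_p = self.dropout if self.training else 0.0`. Fused attention kernels receive dropout as an argument rather than reading `module.training`, so the gating has to happen at the call site. A small self-contained illustration (hypothetical `TinyAttn` module, standard `scaled_dot_product_attention`):

```python
import torch
import torch.nn.functional as F

class TinyAttn(torch.nn.Module):
    def __init__(self, dropout: float = 0.1):
        super().__init__()
        self.dropout = dropout

    def forward(self, q, k, v):
        # eval() must disable dropout explicitly, exactly as the OPT forward does
        p = self.dropout if self.training else 0.0
        return F.scaled_dot_product_attention(q, k, v, dropout_p=p)

attn = TinyAttn().eval()
q = k = v = torch.randn(1, 2, 4, 8)
torch.testing.assert_close(attn(q, k, v), attn(q, k, v))  # deterministic in eval
```
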
+def get_opt_decoder_forward_for_flash_attention(shard_config: ShardConfig):
+ from transformers.models.opt.modeling_opt import OPTDecoder
+
+ def forward(
+ self: OPTDecoder,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ batch_size, seq_length = input_shape
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+ # required mask seq length can be calculated via length of past
+ mask_seq_length = past_key_values_length + seq_length
+
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+ elif attention_mask.shape[1] != mask_seq_length:
+ raise ValueError(
+ f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+ f"{mask_seq_length} (sum of the lengths of current and past inputs)"
+ )
+ causal_attention_mask = _get_attention_mask(
+ self, shard_config, inputs_embeds, past_key_values_length, attention_mask
+ )
+ pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
+
+ if self.project_in is not None:
+ inputs_embeds = self.project_in(inputs_embeds)
+
+ hidden_states = inputs_embeds + pos_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop:
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ causal_attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if self.final_layer_norm is not None:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ if self.project_out is not None:
+ hidden_states = self.project_out(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ return forward
+
+
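
The new OPT decoder forward keeps HF's LayerDrop behaviour: during training, `torch.rand([])` is compared against `self.layerdrop` and the whole layer is skipped on a hit. A minimal standalone illustration of that loop (see https://arxiv.org/abs/1909.11556):

```python
import torch

torch.manual_seed(0)
layerdrop, training = 0.2, True
layers = [torch.nn.Linear(8, 8) for _ in range(6)]

x, kept = torch.randn(1, 8), 0
for layer in layers:
    # during training, skip each layer with probability `layerdrop`
    if training and torch.rand([]) < layerdrop:
        continue
    x = layer(x)
    kept += 1
print(f"kept {kept}/{len(layers)} layers this pass")
```
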
def get_jit_fused_opt_decoder_layer_forward():
from transformers.models.opt.modeling_opt import OPTDecoderLayer
diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
index ab141a74aef8..e9c256a13571 100644
--- a/colossalai/shardformer/modeling/vit.py
+++ b/colossalai/shardformer/modeling/vit.py
@@ -1,4 +1,3 @@
-import math
from typing import List, Optional, Tuple, Union
import torch
@@ -6,6 +5,7 @@
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer import ColoAttention
def _encoder_forward(
@@ -98,7 +98,9 @@ def pp_forward(
pixel_values = pixel_values.to(expected_dtype)
embedding_output = self.embeddings(
- pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+ pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ interpolate_pos_encoding=interpolate_pos_encoding,
)
hidden_states = embedding_output
else:
@@ -336,34 +338,27 @@ def pp_forward(
def get_vit_flash_self_attention_forward():
from transformers.models.vit.modeling_vit import ViTSelfAttention
- from colossalai.nn.layer.colo_attention import ColoAttention
-
- def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
- new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
- x = x.view(new_x_shape)
- return x
-
def forward(
self: ViTSelfAttention,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ assert head_mask is None, "head_mask is not supported for FlashAttention"
mixed_query_layer = self.query(hidden_states)
- key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size)
- value_layer = transpose_for_scores(
- self.value(hidden_states), self.num_attention_heads, self.attention_head_size
- )
- query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size)
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
- scale = 1.0 / math.sqrt(self.attention_head_size)
- attention = ColoAttention(
- embed_dim=self.all_head_size, num_heads=self.num_attention_heads, dropout=self.dropout.p, scale=scale
- )
- context_layer = attention(query_layer, key_layer, value_layer)
+ dropout_p = self.dropout.p if self.training else 0.0
+ context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, dropout_p=dropout_p)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
- outputs = (context_layer,)
+ outputs = (context_layer, None) if output_attentions else (context_layer,)
return outputs
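
The ViT forward now reuses the module's own `self.transpose_for_scores` and undoes it with the `permute`/`view` pair added after `ColoAttention.attention`. The round trip below shows the two are exact inverses (standard head split/merge, shown standalone for clarity):

```python
import torch

batch, seq, heads, head_dim = 2, 197, 12, 64
hidden = torch.randn(batch, seq, heads * head_dim)

# transpose_for_scores: [B, S, H*D] -> [B, H, S, D]
q = hidden.view(batch, seq, heads, head_dim).permute(0, 2, 1, 3)

# inverse: [B, H, S, D] -> [B, S, H*D]
merged = q.permute(0, 2, 1, 3).contiguous().view(batch, seq, heads * head_dim)
print(torch.equal(merged, hidden))  # True
```
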
diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index cb8b45ae7d01..7ccc79276cf7 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -13,41 +13,74 @@
SequenceClassifierOutput,
)
from transformers.models.whisper.modeling_whisper import (
+ WhisperDecoder,
WhisperEncoder,
WhisperForAudioClassification,
WhisperForConditionalGeneration,
WhisperModel,
+ shift_tokens_right,
)
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer import ColoAttention
+from colossalai.shardformer.shard import ShardConfig
+
+logger = logging.get_logger(__name__)
+
+
+def _get_attention_mask(
+ self: WhisperDecoder,
+ shard_config: ShardConfig,
+ hidden_states: torch.Tensor,
+ past_key_values_length: int,
+ attention_mask: Optional[torch.FloatTensor],
+):
+ batch_size, seq_length = hidden_states.shape[:2]
+ mask_seq_length = past_key_values_length + seq_length
+ if shard_config.enable_flash_attention:
+ attention_mask = ColoAttention.prepare_attn_kwargs(
+ (batch_size, 1, seq_length, mask_seq_length),
+ hidden_states.dtype,
+ hidden_states.device,
+ attention_mask,
+ is_causal=True,
+ )
+ else:
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ hidden_states,
+ past_key_values_length,
+ )
+ return attention_mask
def get_whisper_flash_attention_forward():
from transformers.models.whisper.modeling_whisper import WhisperAttention
- from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
-
- def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
- return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
-
def forward(
self: WhisperAttention,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[dict] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
-
+ assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
+ # for the encoder, attention_mask is None
+ if attention_mask is None:
+ attention_mask = {}
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
+ # get query proj
+ query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
@@ -55,25 +88,25 @@ def forward(
if (
is_cross_attention
and past_key_value is not None
- and past_key_value[0].shape[1] == key_value_states.shape[1]
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
- key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
- value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
- key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
- value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
- key_states = torch.cat([past_key_value[0], key_states], dim=1)
- value_states = torch.cat([past_key_value[1], value_states], dim=1)
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
- key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
- value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -85,42 +118,178 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
- # get query proj
- query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz)
- src_len = key_states.size(1)
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
+ dropout_p = self.dropout if self.training else 0.0
+ attn_output = ColoAttention.attention(
+ query_states,
+ key_states,
+ value_states,
+ **attention_mask,
+ dropout_p=dropout_p,
+ scale=self.scaling,
+ )
+ attn_output = attn_output.transpose(1, 2)
- attn_type = None
- flash_attention_mask = None
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
- if self.is_decoder:
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- )
- flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
- if not torch.all(flash_attention_mask):
- attn_type = AttnMaskType.paddedcausal
- else:
- attn_type = AttnMaskType.causal
+ attn_output = self.out_proj(attn_output)
- attention = ColoAttention(
- embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
- )
- attn_output = attention(
- query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_type
+ return attn_output, None, past_key_value
+
+ return forward
+
+
+def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
+ def forward(
+ self: WhisperDecoder,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ inputs_embeds=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- attn_output = self.out_proj(attn_output)
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
- return attn_output, None, past_key_value
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ attention_mask = _get_attention_mask(self, shard_config, inputs_embeds, past_key_values_length, attention_mask)
+
+ # embed positions
+ if input_ids is not None:
+ positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+ else:
+ positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
+
+ hidden_states = inputs_embeds + positions
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
+ )
+ use_cache = False
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+ if attn_mask is not None:
+ assert attn_mask.size()[0] == (len(self.layers)), (
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop:
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, use_cache)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ None, # encoder attention mask
+ head_mask[idx] if head_mask is not None else None,
+ (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+ None, # past_key_value
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+ ),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ hidden_states = self.layer_norm(hidden_states)
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_cache,
+ all_hidden_states,
+ all_self_attns,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
return forward
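
Throughout this patch, model-specific forwards are installed with the same pattern: a factory closes over configuration (here the `shard_config`) and returns a replacement `forward` that shardformer later attaches via `method_replacement`. A minimal, self-contained sketch of that pattern, with illustrative names only (not the shardformer API):

```python
# Minimal sketch of the "forward factory" pattern used throughout this patch.
# A factory captures configuration in a closure and returns a drop-in
# replacement for a module's forward method.
def get_custom_forward(tag):
    def forward(self, x):
        # `self` binds normally once the function is attached to the class
        print(f"[{tag}] running replaced forward")
        return self.base_forward(x)

    return forward


class TinyModule:
    def base_forward(self, x):
        return x * 2


# mirrors `method_replacement = {"forward": get_..._forward(shard_config)}`
TinyModule.forward = get_custom_forward("demo")
print(TinyModule().forward(3))  # prints the tag, then 6
```
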
@@ -292,6 +461,7 @@ def whisper_encoder_forward(
all_attentions=None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
+ shard_config: Optional[ShardConfig] = None,
):
r"""
Args:
@@ -403,7 +573,9 @@ def custom_forward(*inputs):
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
- last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_states,
+ attentions=all_attentions,
)
else:
@@ -411,7 +583,7 @@ def custom_forward(*inputs):
@staticmethod
def whisper_decoder_forward(
- self,
+ self: WhisperDecoder,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
@@ -427,6 +599,7 @@ def whisper_decoder_forward(
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
+ shard_config: Optional[ShardConfig] = None,
):
r"""
Args:
@@ -535,8 +708,12 @@ def whisper_decoder_forward(
else:
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
- attention_mask = self._prepare_decoder_attention_mask(
- attention_mask, input_shape, inputs_embeds, past_key_values_length
+ attention_mask = _get_attention_mask(
+ self,
+ shard_config,
+ inputs_embeds,
+ past_key_values_length,
+ attention_mask,
)
hidden_states = inputs_embeds + positions
@@ -556,8 +733,12 @@ def whisper_decoder_forward(
)
input_shape = hidden_states.size()[:-1]
- attention_mask = self._prepare_decoder_attention_mask(
- attention_mask, input_shape, hidden_states, past_key_values_length
+ attention_mask = _get_attention_mask(
+ self,
+ shard_config,
+ hidden_states,
+ past_key_values_length,
+ attention_mask,
)
start_idx, end_idx = stage_index[0], stage_index[1]
@@ -590,7 +771,7 @@ def custom_forward(*inputs):
encoder_hidden_states,
None, # encoder attention mask
head_mask[idx] if head_mask is not None else None,
- cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+ (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
None, # past_key_value
)
else:
@@ -626,7 +807,13 @@ def custom_forward(*inputs):
if not return_dict:
return tuple(
v
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ for v in [
+ hidden_states,
+ next_cache,
+ all_hidden_states,
+ all_self_attns,
+ all_cross_attentions,
+ ]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
@@ -666,6 +853,7 @@ def whisper_model_forward(
encoder_hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
+ shard_config: Optional[ShardConfig] = None,
):
r"""
Returns:
@@ -735,7 +923,7 @@ def whisper_model_forward(
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ hidden_states=(encoder_outputs[1] if len(encoder_outputs) > 1 else None),
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
@@ -767,6 +955,7 @@ def whisper_model_forward(
hidden_states=hidden_states,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage,
+ shard_config=shard_config,
)
# Directly return outputs of overloaded Whisper forward if not at last stage.
@@ -810,6 +999,7 @@ def whisper_for_conditional_generation_forward(
encoder_hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
+ shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -870,6 +1060,7 @@ def whisper_for_conditional_generation_forward(
encoder_hidden_states=encoder_hidden_states,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage,
+ shard_config=shard_config,
)
if not in_decoder:
return outputs
@@ -920,6 +1111,7 @@ def whisper_for_audio_classification_forward(
all_attentions=None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
+ shard_config: Optional[ShardConfig] = None,
):
r"""
    This function is adapted from transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward.
diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py
index 1d2b7a570681..d67ab0a3c6bb 100644
--- a/colossalai/shardformer/policies/base_policy.py
+++ b/colossalai/shardformer/policies/base_policy.py
@@ -2,9 +2,8 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union
-import numpy as np
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
@@ -196,50 +195,3 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
"""
return []
-
- @staticmethod
- def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
- """Divide layers into stages"""
- quotient = num_layers // num_stages
- remainder = num_layers % num_stages
-
- # calculate the num_layers per stage
- layers_per_stage = [quotient] * num_stages
-
- # deal with the rest layers
- if remainder > 0:
- start_position = num_stages // 2 - remainder // 2
- for i in range(start_position, start_position + remainder):
- layers_per_stage[i] += 1
- return layers_per_stage
-
- @staticmethod
- def get_stage_index(
- layers_per_stage: List[int],
- stage: int,
- num_model_chunks: int = 1,
- num_stages: int = 0,
- ) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
- """
- Get the start index and end index of layers for each stage.
-
- Args:
- layers_per_stage (List[int]): number of layers for each stage
- stage (int): the stage index
- num_stages (int): number of stages
- num_model_chunks (int): number of model chunks
-
- Returns:
- - Tuple[int, int]: the start index and end index of this stage
- - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk
-
- """
- num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
-
- stage_indices = []
- for model_chunk in range(num_model_chunks):
- start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
- end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
- stage_indices.append([start_idx, end_idx])
-
- return stage_indices[0] if num_model_chunks == 1 else stage_indices
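
The two static helpers removed above are not dropped outright: per the call sites updated throughout this diff, equivalent `distribute_layers` / `get_stage_index` methods now live on the pipeline stage manager. For reference, a self-contained sketch of the removed algorithm with a worked example (the remainder is spread over the middle stages so the first and last stages stay lighter):

```python
import numpy as np

def distribute_layers(num_layers: int, num_stages: int) -> list:
    # same algorithm as the removed helper: even split, remainder in the middle
    quotient, remainder = divmod(num_layers, num_stages)
    layers_per_stage = [quotient] * num_stages
    start = num_stages // 2 - remainder // 2
    for i in range(start, start + remainder):
        layers_per_stage[i] += 1
    return layers_per_stage

def get_stage_index(layers_per_stage, stage):
    # prefix sums give each stage's [start, end) slice of layers
    accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
    return accumulated[stage], accumulated[stage + 1]

print(distribute_layers(10, 4))          # [2, 3, 3, 2]
print(get_stage_index([2, 3, 3, 2], 1))  # (2, 5): stage 1 holds layers 2..4
```
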
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 0ab63b7650c1..0a61d8cff410 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -1,3 +1,4 @@
+import warnings
from functools import partial
from typing import Callable, Dict, List
@@ -66,8 +67,17 @@ def module_policy(self):
else:
norm_cls = col_nn.LayerNorm
- use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+ sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+ assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for Bert"
+ if sp_mode == "ring":
+ warnings.warn(
+                    f"For Bert, sequence parallelism does not support mode {sp_mode} yet; falling back to split_gather"
+ )
+ sp_mode = "split_gather"
+
overlap = self.shard_config.enable_sequence_overlap
+ sp_partial_derived = sp_mode == "split_gather"
+
if self.shard_config.enable_tensor_parallelism:
policy[BertLayer] = ModulePolicyDescription(
attribute_replacement={
@@ -84,17 +94,26 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=col_nn.Linear1D_Col,
- kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+ kwargs={
+ "seq_parallel_mode": sp_mode,
+ "overlap": overlap,
+ },
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
- kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+ kwargs={
+ "seq_parallel_mode": sp_mode,
+ "overlap": overlap,
+ },
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
- kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+ kwargs={
+ "seq_parallel_mode": sp_mode,
+ "overlap": overlap,
+ },
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
@@ -103,7 +122,7 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
- kwargs={"seq_parallel": use_sequence_parallel},
+ kwargs={"seq_parallel_mode": sp_mode},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
@@ -112,12 +131,15 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
- kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+ kwargs={
+ "seq_parallel_mode": sp_mode,
+ "overlap": overlap,
+ },
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
- kwargs={"seq_parallel": use_sequence_parallel},
+ kwargs={"seq_parallel_mode": sp_mode},
),
SubModuleReplacementDescription(
suffix="output.dropout",
@@ -139,7 +161,7 @@ def module_policy(self):
]
)
- if use_sequence_parallel:
+ if sp_mode == "split_gather":
self.append_or_create_method_replacement(
description={"forward": bert_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
@@ -153,12 +175,12 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
target_module=norm_cls,
- kwargs={"sp_partial_derived": use_sequence_parallel},
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
SubModuleReplacementDescription(
suffix="output.LayerNorm",
target_module=norm_cls,
- kwargs={"sp_partial_derived": use_sequence_parallel},
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,
@@ -214,7 +236,9 @@ def add_lm_head_policy(self, base_policy):
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
- suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
+ suffix="decoder",
+ target_module=col_nn.Linear1D_Col,
+ kwargs={"gather_output": True},
),
policy=base_policy,
target_key=BertLMPredictionHead,
@@ -241,7 +265,9 @@ def add_lm_prediction_policy(self, base_policy):
"_load_from_state_dict": col_nn.ParallelModule._load_from_state_dict,
}
self.append_or_create_method_replacement(
- description=method_replacement, policy=base_policy, target_key=BertLMPredictionHead
+ description=method_replacement,
+ policy=base_policy,
+ target_key=BertLMPredictionHead,
)
return base_policy
@@ -263,25 +289,25 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
module = self.model.bert
if stage_manager.is_interleave:
- layers_per_stage = self.distribute_layers(
- len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks
- )
- stage_manager.stage_indices = Policy.get_stage_index(
- layers_per_stage,
- stage_manager.stage,
- num_model_chunks=stage_manager.num_model_chunks,
- num_stages=stage_manager.num_stages,
- )
+ layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+ stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
- "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
+ "forward": partial(
+ new_forward,
+ stage_manager=stage_manager,
+ shard_config=self.shard_config,
+ )
}
else:
- layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
- stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+ stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
- new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+ new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=self.shard_config,
)
}
@@ -300,15 +326,8 @@ def get_held_layers(self) -> List[Module]:
held_layers = []
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
- layers_per_stage = self.distribute_layers(
- len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks
- )
- stage_indices = Policy.get_stage_index(
- layers_per_stage,
- stage_manager.stage,
- num_model_chunks=stage_manager.num_model_chunks,
- num_stages=stage_manager.num_stages,
- )
+ layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+ stage_indices = stage_manager.get_stage_index(layers_per_stage)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embeddings)
for start_idx, end_idx in stage_indices:
@@ -317,10 +336,10 @@ def get_held_layers(self) -> List[Module]:
held_layers.append(module.pooler)
else:
- layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+ layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
if stage_manager.is_first_stage():
held_layers.append(module.embeddings)
- start_idx, end_idx = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.pooler)
@@ -336,7 +355,9 @@ def module_policy(self):
if self.pipeline_stage_manager:
self.set_pipeline_forward(
- model_cls=BertModel, new_forward=BertPipelineForwards.bert_model_forward, policy=policy
+ model_cls=BertModel,
+ new_forward=BertPipelineForwards.bert_model_forward,
+ policy=policy,
)
return policy
@@ -399,7 +420,9 @@ def module_policy(self):
if self.pipeline_stage_manager:
self.set_pipeline_forward(
- model_cls=BertLMHeadModel, new_forward=BertPipelineForwards.bert_lm_head_model_forward, policy=policy
+ model_cls=BertLMHeadModel,
+ new_forward=BertPipelineForwards.bert_lm_head_model_forward,
+ policy=policy,
)
return policy
@@ -437,7 +460,9 @@ def module_policy(self):
if self.pipeline_stage_manager:
self.set_pipeline_forward(
- model_cls=BertForMaskedLM, new_forward=BertPipelineForwards.bert_for_masked_lm_forward, policy=policy
+ model_cls=BertForMaskedLM,
+ new_forward=BertPipelineForwards.bert_for_masked_lm_forward,
+ policy=policy,
)
return policy
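
The bert.py changes above establish a migration that the bloom, chatglm2, and gpt2 policies below repeat: the boolean `seq_parallel` kwarg becomes a `seq_parallel_mode` string, unsupported modes are rejected or downgraded, and layer norms receive the `sp_partial_derived` hint only in `split_gather` mode. A hedged sketch of that resolution logic (illustrative function, not the shard-config API):

```python
import warnings
from typing import Optional, Tuple

def resolve_sp_mode(enable_sp: bool, requested: Optional[str], model: str) -> Tuple[Optional[str], bool]:
    # mirrors the per-policy logic in this diff: None when disabled,
    # all_to_all rejected, ring downgraded to split_gather with a warning
    sp_mode = requested if enable_sp else None
    assert sp_mode != "all_to_all", f"all_to_all sequence parallelism is not supported for {model}"
    if sp_mode == "ring":
        warnings.warn(f"For {model}, sequence parallelism does not support mode {sp_mode} yet; falling back to split_gather")
        sp_mode = "split_gather"
    # layer norms only need the partial-derivation hint in split_gather mode
    sp_partial_derived = sp_mode == "split_gather"
    return sp_mode, sp_partial_derived

print(resolve_sp_mode(True, "ring", "Bert"))   # ('split_gather', True)
print(resolve_sp_mode(False, "ring", "Bert"))  # (None, False)
```
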
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index eddfafdcbcdc..2becadc3fb19 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -1,3 +1,4 @@
+import warnings
from functools import partial
from typing import Callable, Dict, List
@@ -55,8 +56,18 @@ def module_policy(self):
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
- use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+
+ sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+ assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for BLOOM"
+ if sp_mode == "ring":
+ warnings.warn(
+                f"For BLOOM, sequence parallelism does not support mode {sp_mode} yet; falling back to split_gather"
+ )
+ sp_mode = "split_gather"
+
overlap = self.shard_config.enable_sequence_overlap
+ sp_partial_derived = sp_mode == "split_gather"
+
if self.shard_config.enable_tensor_parallelism:
policy[BloomBlock] = ModulePolicyDescription(
attribute_replacement={
@@ -70,12 +81,12 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
- kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+ kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
- kwargs={"seq_parallel": use_sequence_parallel},
+ kwargs={"seq_parallel_mode": sp_mode},
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
@@ -84,12 +95,12 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
- kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+ kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,
- kwargs={"seq_parallel": use_sequence_parallel},
+ kwargs={"seq_parallel_mode": sp_mode},
),
],
)
@@ -132,19 +143,19 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=norm_cls,
- kwargs={"sp_partial_derived": use_sequence_parallel},
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=norm_cls,
- kwargs={"sp_partial_derived": use_sequence_parallel},
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,
target_key=BloomBlock,
)
- if use_sequence_parallel:
+ if sp_mode == "split_gather":
self.append_or_create_method_replacement(
description={"forward": get_bloom_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
@@ -203,8 +214,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
else:
module = self.model.transformer
- layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
- stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ layers_per_stage = stage_manager.distribute_layers(len(module.h))
+ stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@@ -226,11 +237,11 @@ def get_held_layers(self) -> List[Module]:
stage_manager = self.pipeline_stage_manager
held_layers = []
- layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+ layers_per_stage = stage_manager.distribute_layers(len(module.h))
if stage_manager.is_first_stage():
held_layers.append(module.word_embeddings)
held_layers.append(module.word_embeddings_layernorm)
- start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py
index d1ad9f91478b..dabc14bffc95 100644
--- a/colossalai/shardformer/policies/chatglm2.py
+++ b/colossalai/shardformer/policies/chatglm2.py
@@ -1,3 +1,4 @@
+import warnings
from functools import partial
from typing import Callable, Dict, List, Union
@@ -55,8 +56,17 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
norm_cls = col_nn.RMSNorm
else:
norm_cls = col_nn.LayerNorm
- use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+
+ sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+ assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2"
+ if sp_mode == "ring":
+ warnings.warn(
+            f"For ChatGLM2, sequence parallelism does not support mode {sp_mode} yet; falling back to split_gather"
+ )
+ sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
+ sp_partial_derived = sp_mode == "split_gather"
+
if self.shard_config.enable_tensor_parallelism:
policy[ChatGLMModel] = ModulePolicyDescription(
attribute_replacement={},
@@ -91,12 +101,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
- kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0, "overlap": overlap},
+ kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
- kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0},
+ kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0},
),
SubModuleReplacementDescription(
suffix="self_attention.core_attention.attention_dropout",
@@ -110,12 +120,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=norm_cls,
- kwargs={"sp_partial_derived": use_sequence_parallel},
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=norm_cls,
- kwargs={"sp_partial_derived": use_sequence_parallel},
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,
@@ -145,7 +155,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
)
# use sequence parallel
- if use_sequence_parallel:
+ if sp_mode == "split_gather":
self.append_or_create_method_replacement(
description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
@@ -179,10 +189,10 @@ def get_held_layers(self) -> List[nn.Module]:
stage_manager = self.pipeline_stage_manager
held_layers = []
- layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
+ layers_per_stage = stage_manager.distribute_layers(module.num_layers)
if stage_manager.is_first_stage():
held_layers.append(module.embedding)
- start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.encoder.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
if module.encoder.post_layer_norm:
@@ -204,8 +214,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
else:
module = self.model.transformer
- layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages)
- stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ layers_per_stage = stage_manager.distribute_layers(module.num_layers)
+ stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py
index 5c148880f980..fe61c406fae3 100644
--- a/colossalai/shardformer/policies/falcon.py
+++ b/colossalai/shardformer/policies/falcon.py
@@ -161,8 +161,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
else:
module = self.model.transformer
- layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
- stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ layers_per_stage = stage_manager.distribute_layers(len(module.h))
+ stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@@ -181,10 +181,10 @@ def get_held_layers(self) -> List[Module]:
module = self.model.transformer
stage_manager = self.pipeline_stage_manager
held_layers = []
- layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+ layers_per_stage = stage_manager.distribute_layers(len(module.h))
if stage_manager.is_first_stage():
held_layers.append(module.word_embeddings)
- start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 303766993e3d..380a432dc8b8 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -1,3 +1,4 @@
+import warnings
from functools import partial
from typing import Callable, Dict, List
@@ -8,6 +9,7 @@
from ..modeling.gpt2 import (
GPT2PipelineForwards,
get_gpt2_flash_attention_forward,
+ get_gpt_model_forward_for_flash_attn,
get_lm_forward_with_dist_cross_entropy,
gpt2_sequence_parallel_forward_fn,
)
@@ -49,8 +51,25 @@ def module_policy(self):
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
- use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+
+ sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+ assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for GPT2"
+ if sp_mode == "ring":
+ warnings.warn(
+            f"For GPT2, sequence parallelism does not support mode {sp_mode} yet; falling back to split_gather"
+ )
+ sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
+ sp_partial_derived = sp_mode in ["split_gather", "ring"]
+ use_flash_attention = self.shard_config.enable_flash_attention
+        # TODO: sequence parallelism cannot currently be used with FlashAttention
+ if sp_mode in ["split_gather", "ring", "all_to_all"]:
+ if use_flash_attention:
+ warnings.warn(
+                f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention; disabling FlashAttention automatically."
+ )
+ self.shard_config.enable_flash_attention = False
+ use_flash_attention = False
if self.shard_config.enable_tensor_parallelism:
policy[GPT2Model] = ModulePolicyDescription(
sub_module_replacement=[
@@ -75,24 +94,34 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
- kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap},
+ kwargs={
+ "n_fused": 3,
+ "seq_parallel_mode": sp_mode,
+ "overlap": overlap,
+ },
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
- "seq_parallel": use_sequence_parallel,
+ "seq_parallel_mode": sp_mode,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
- kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap},
+ kwargs={
+ "n_fused": 1,
+ "seq_parallel_mode": sp_mode,
+ "overlap": overlap,
+ },
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
- kwargs={"seq_parallel": use_sequence_parallel},
+ kwargs={
+ "seq_parallel_mode": sp_mode,
+ },
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
@@ -124,25 +153,25 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="ln_1",
target_module=norm_cls,
- kwargs={"sp_partial_derived": use_sequence_parallel},
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
SubModuleReplacementDescription(
suffix="ln_2",
target_module=norm_cls,
- kwargs={"sp_partial_derived": use_sequence_parallel},
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
SubModuleReplacementDescription(
suffix="ln_cross_attn",
target_module=norm_cls,
ignore_if_not_exist=True,
- kwargs={"sp_partial_derived": use_sequence_parallel},
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,
target_key=GPT2Block,
)
- if self.shard_config.enable_flash_attention:
+ if use_flash_attention:
self.append_or_create_method_replacement(
description={
"forward": get_gpt2_flash_attention_forward(),
@@ -150,8 +179,12 @@ def module_policy(self):
policy=policy,
target_key=GPT2Attention,
)
+ if not self.shard_config.pipeline_stage_manager:
+ policy[GPT2Model].method_replacement = {
+ "forward": get_gpt_model_forward_for_flash_attn(self.shard_config)
+ }
- if self.shard_config.enable_sequence_parallelism:
+ if sp_mode is not None:
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
return policy
@@ -172,15 +205,8 @@ def get_held_layers(self) -> List[nn.Module]:
held_layers = []
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
- layers_per_stage = self.distribute_layers(
- len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
- )
- stage_indices = Policy.get_stage_index(
- layers_per_stage,
- stage_manager.stage,
- num_model_chunks=stage_manager.num_model_chunks,
- num_stages=stage_manager.num_stages,
- )
+ layers_per_stage = stage_manager.distribute_layers(len(module.h))
+ stage_indices = stage_manager.get_stage_index(layers_per_stage)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.wte)
held_layers.append(module.wpe)
@@ -190,12 +216,12 @@ def get_held_layers(self) -> List[nn.Module]:
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(module.ln_f)
else:
- layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+ layers_per_stage = stage_manager.distribute_layers(len(module.h))
if stage_manager.is_first_stage():
held_layers.append(module.wte)
held_layers.append(module.wpe)
held_layers.append(module.drop)
- start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
@@ -213,24 +239,24 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
module = self.model.transformer
if stage_manager.is_interleave:
- layers_per_stage = self.distribute_layers(
- len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
- )
- stage_manager.stage_indices = Policy.get_stage_index(
- layers_per_stage,
- stage_manager.stage,
- num_model_chunks=stage_manager.num_model_chunks,
- num_stages=stage_manager.num_stages,
- )
+ layers_per_stage = stage_manager.distribute_layers(len(module.h))
+ stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
- "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
+ "forward": partial(
+ new_forward,
+ stage_manager=stage_manager,
+ shard_config=self.shard_config,
+ )
}
else:
- layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
- stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ layers_per_stage = stage_manager.distribute_layers(len(module.h))
+ stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
- new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+ new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=self.shard_config,
)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -245,7 +271,9 @@ def module_policy(self):
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
- model_cls=GPT2Model, new_forward=GPT2PipelineForwards.gpt2_model_forward, policy=policy
+ model_cls=GPT2Model,
+ new_forward=GPT2PipelineForwards.gpt2_model_forward,
+ policy=policy,
)
return policy
@@ -269,12 +297,17 @@ def module_policy(self):
GPT2LMHeadModel: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
- suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": False}
+ suffix="lm_head",
+ target_module=col_nn.Linear1D_Col,
+ kwargs={"gather_output": not self.shard_config.parallel_output},
)
],
- method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
)
}
+ if self.shard_config.parallel_output:
+ addon_module[GPT2LMHeadModel].method_replacement = {
+ "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+ }
module_policy.update(addon_module)
if self.pipeline_stage_manager is not None:
@@ -298,7 +331,12 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
if stage_manager is not None:
if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
first_stage, last_stage = 0, stage_manager.num_stages - 1
- return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
+ return [
+ {
+ first_stage: module.transformer.wte.weight,
+ last_stage: module.lm_head.weight,
+ }
+ ]
return []
@@ -314,7 +352,9 @@ def module_policy(self):
GPT2DoubleHeadsModel: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
- suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
+ suffix="lm_head",
+ target_module=col_nn.Linear1D_Col,
+ kwargs={"gather_output": True},
)
]
)
@@ -349,7 +389,12 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
if stage_manager is not None:
if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
first_stage, last_stage = 0, stage_manager.num_stages - 1
- return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
+ return [
+ {
+ first_stage: module.transformer.wte.weight,
+ last_stage: module.lm_head.weight,
+ }
+ ]
return []
@@ -391,7 +436,10 @@ def module_policy(self):
addon_module = {
GPT2ForTokenClassification: ModulePolicyDescription(
sub_module_replacement=[
- SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput)
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ )
]
)
}
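
The `GPT2LMHeadModel` hunk above also changes how the language-model head is sharded: `gather_output` now tracks `parallel_output`, and the distributed cross-entropy forward is installed only when shard-local logits are kept. A small sketch of that coupling (illustrative names, not the policy API):

```python
# With parallel_output, the column-parallel lm_head keeps shard-local logits
# (gather_output=False) and a distributed cross-entropy consumes them;
# otherwise logits are gathered and the stock loss path is used.
def lm_head_plan(parallel_output: bool) -> dict:
    return {
        "gather_output": not parallel_output,
        "loss": "dist_cross_entropy" if parallel_output else "default_forward",
    }

print(lm_head_plan(True))   # {'gather_output': False, 'loss': 'dist_cross_entropy'}
print(lm_head_plan(False))  # {'gather_output': True, 'loss': 'default_forward'}
```
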
diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py
index 9feb826c4624..eab4c214a41f 100644
--- a/colossalai/shardformer/policies/gptj.py
+++ b/colossalai/shardformer/policies/gptj.py
@@ -6,7 +6,11 @@
import colossalai.shardformer.layer as col_nn
-from ..modeling.gptj import GPTJPipelineForwards, get_gptj_flash_attention_forward
+from ..modeling.gptj import (
+ GPTJPipelineForwards,
+ get_gptj_flash_attention_forward,
+ gptj_model_forward_for_flash_attention,
+)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@@ -71,17 +75,26 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="attn.k_proj",
target_module=col_nn.Linear1D_Col,
- kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+ kwargs={
+ "seq_parallel": use_sequence_parallel,
+ "overlap": overlap,
+ },
),
SubModuleReplacementDescription(
suffix="attn.q_proj",
target_module=col_nn.Linear1D_Col,
- kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+ kwargs={
+ "seq_parallel": use_sequence_parallel,
+ "overlap": overlap,
+ },
),
SubModuleReplacementDescription(
suffix="attn.v_proj",
target_module=col_nn.Linear1D_Col,
- kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+ kwargs={
+ "seq_parallel": use_sequence_parallel,
+ "overlap": overlap,
+ },
),
SubModuleReplacementDescription(
suffix="attn.out_proj",
@@ -143,6 +156,12 @@ def module_policy(self):
policy=policy,
target_key=GPTJAttention,
)
+ if not self.shard_config.pipeline_stage_manager:
+ self.append_or_create_method_replacement(
+ description={"forward": gptj_model_forward_for_flash_attention(self.shard_config)},
+ policy=policy,
+ target_key=GPTJModel,
+ )
return policy
@@ -160,11 +179,11 @@ def get_held_layers(self) -> List[nn.Module]:
stage_manager = self.pipeline_stage_manager
held_layers = []
- layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+ layers_per_stage = stage_manager.distribute_layers(len(module.h))
if stage_manager.is_first_stage():
held_layers.append(module.wte)
held_layers.append(module.drop)
- start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
@@ -181,11 +200,14 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
else:
module = self.model.transformer
- layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
- stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ layers_per_stage = stage_manager.distribute_layers(len(module.h))
+ stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
- new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+ new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=self.shard_config,
)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -203,7 +225,9 @@ def module_policy(self):
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
- model_cls=GPTJModel, new_forward=GPTJPipelineForwards.gptj_model_forward, policy=policy
+ model_cls=GPTJModel,
+ new_forward=GPTJPipelineForwards.gptj_model_forward,
+ policy=policy,
)
return policy
@@ -230,7 +254,9 @@ def module_policy(self):
GPTJForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
- suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
+ suffix="lm_head",
+ target_module=col_nn.Linear1D_Col,
+ kwargs={"gather_output": True},
)
]
)
@@ -239,7 +265,9 @@ def module_policy(self):
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
- model_cls=GPTJForCausalLM, new_forward=GPTJPipelineForwards.gptj_causallm_model_forward, policy=policy
+ model_cls=GPTJForCausalLM,
+ new_forward=GPTJPipelineForwards.gptj_causallm_model_forward,
+ policy=policy,
)
return policy
@@ -256,7 +284,12 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
if stage_manager is not None:
if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
first_stage, last_stage = 0, stage_manager.num_stages - 1
- return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
+ return [
+ {
+ first_stage: module.transformer.wte.weight,
+ last_stage: module.lm_head.weight,
+ }
+ ]
return []
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 42bf0825b045..bb4551b2c31c 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -11,6 +11,9 @@
from ..modeling.llama import (
LlamaPipelineForwards,
get_llama_flash_attention_forward,
+ get_llama_model_forward_for_flash_attn,
+ get_llama_seq_parallel_attention_forward,
+ get_llama_seq_parallel_model_forward,
get_lm_forward_with_dist_cross_entropy,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -44,9 +47,74 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
else:
norm_cls = RMSNorm
- if self.shard_config.enable_sequence_parallelism:
+ if self.pipeline_stage_manager is not None:
self.shard_config.enable_sequence_parallelism = False
- warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+ self.shard_config.enable_sequence_overlap = False
+ self.shard_config.sequence_parallelism_mode = None
+ warnings.warn(
+                "For Llama, sequence parallelism is currently not compatible with pipeline parallelism; disabling sequence parallelism."
+ )
+ sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+ sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
+ sp_group = (
+ self.shard_config.sequence_parallel_process_group if self.shard_config.enable_sequence_parallelism else None
+ )
+ sp_partial_derived = sp_mode in ["split_gather", "ring"]
+
+ use_flash_attention = self.shard_config.enable_flash_attention
+        # Currently, sequence parallelism cannot be used with FlashAttention
+ if sp_mode in ["split_gather", "ring", "all_to_all"]:
+ if use_flash_attention:
+ warnings.warn(
+                    f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention; disabling FlashAttention automatically."
+ )
+ use_flash_attention = False
+
+ if sp_mode in ["split_gather", "ring"]:
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_llama_seq_parallel_model_forward(
+ sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group
+ ),
+ },
+ policy=policy,
+ target_key=LlamaModel,
+ )
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
+ },
+ policy=policy,
+ target_key=LlamaAttention,
+ )
+ elif sp_mode == "all_to_all":
+ decoder_attribute_replacement = {
+ "num_heads": self.model.config.num_attention_heads // sp_size,
+ }
+ if getattr(self.model.config, "num_key_value_heads", False):
+ decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+
+ policy[LlamaAttention] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ )
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
+ },
+ policy=policy,
+ target_key=LlamaAttention,
+ )
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_llama_seq_parallel_model_forward(
+ sp_mode=sp_mode,
+ sp_size=sp_size,
+ sp_group=sp_group,
+ ),
+ },
+ policy=policy,
+ target_key=LlamaModel,
+ )
if self.shard_config.enable_tensor_parallelism:
decoder_attribute_replacement = {
@@ -64,30 +132,37 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
+ kwargs=dict(seq_parallel_mode=sp_mode),
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
+ kwargs=dict(seq_parallel_mode=sp_mode),
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
+ kwargs=dict(seq_parallel_mode=sp_mode),
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
+ kwargs=dict(seq_parallel_mode=sp_mode),
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
+ kwargs=dict(seq_parallel_mode=sp_mode),
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
+ kwargs=dict(seq_parallel_mode=sp_mode),
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
+ kwargs=dict(seq_parallel_mode=sp_mode),
),
],
)
@@ -107,10 +182,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=norm_cls,
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=norm_cls,
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,
@@ -121,20 +198,30 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
description=SubModuleReplacementDescription(
suffix="norm",
target_module=norm_cls,
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
policy=policy,
target_key=LlamaModel,
)
# use flash attention
- if self.shard_config.enable_flash_attention:
+ if use_flash_attention:
self.append_or_create_method_replacement(
description={
- "forward": get_llama_flash_attention_forward(self.shard_config),
+ "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
},
policy=policy,
target_key=LlamaAttention,
)
+ if self.pipeline_stage_manager is None:
+ # replace llama model forward method
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_llama_model_forward_for_flash_attn(self.shard_config),
+ },
+ policy=policy,
+ target_key=LlamaModel,
+ )
return policy
@@ -154,30 +241,20 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
module = self.model.model
if stage_manager.is_interleave:
- layers_per_stage = self.distribute_layers(
- len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
- )
- stage_manager.stage_indices = Policy.get_stage_index(
- layers_per_stage,
- stage_manager.stage,
- num_model_chunks=stage_manager.num_model_chunks,
- num_stages=stage_manager.num_stages,
- )
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
}
else:
- layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
- stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
)
}
- self.append_or_create_method_replacement(
- description=method_replacement, policy=policy, target_key=model_cls
- )
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -194,15 +271,8 @@ def get_held_layers(self) -> List[Module]:
held_layers = []
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
- layers_per_stage = self.distribute_layers(
- len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
- )
- stage_indices = Policy.get_stage_index(
- layers_per_stage,
- stage_manager.stage,
- num_model_chunks=stage_manager.num_model_chunks,
- num_stages=stage_manager.num_stages,
- )
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ stage_indices = stage_manager.get_stage_index(layers_per_stage)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
@@ -211,10 +281,10 @@ def get_held_layers(self) -> List[Module]:
held_layers.append(module.norm)
else:
- layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
- start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
@@ -250,18 +320,23 @@ def module_policy(self):
policy = super().module_policy()
- setattr(self.shard_config, "causal_lm", True)
-
- if self.shard_config.enable_tensor_parallelism:
+ if self.shard_config.enable_tensor_parallelism and not self.shard_config.enable_sequence_parallelism:
# add a new item for casual lm
new_item = {
LlamaForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
- SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=Linear1D_Col,
+ kwargs={"gather_output": not self.shard_config.parallel_output},
+ )
],
- method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
)
}
+ if self.shard_config.parallel_output:
+ new_item[LlamaForCausalLM].method_replacement = {
+ "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+ }
policy.update(new_item)
if self.pipeline_stage_manager:
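
One llama-specific detail worth noting from the hunks above: in `all_to_all` mode the attention heads are divided across the sequence-parallel group, so per-rank head counts shrink by `sp_size`. A worked sketch of that attribute replacement (illustrative helper, not the policy API):

```python
# Per-rank attention head counts are divided by the sequence-parallel
# group size, mirroring the attribute_replacement dict above.
def adjust_heads(num_attention_heads: int, sp_size: int, num_key_value_heads: int = 0) -> dict:
    replaced = {"num_heads": num_attention_heads // sp_size}
    if num_key_value_heads:
        replaced["num_key_value_heads"] = num_key_value_heads // sp_size
    return replaced

print(adjust_heads(32, 4))                         # {'num_heads': 8}
print(adjust_heads(32, 4, num_key_value_heads=8))  # adds 'num_key_value_heads': 2
```
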
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index a542808ba794..98e584be861b 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -9,7 +9,12 @@
from .._utils import getattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
-from ..modeling.opt import OPTPipelineForwards, get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward
+from ..modeling.opt import (
+ OPTPipelineForwards,
+ get_jit_fused_opt_decoder_layer_forward,
+ get_opt_decoder_forward_for_flash_attention,
+ get_opt_flash_attention_forward,
+)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@@ -27,6 +32,7 @@ def __init__(self) -> None:
import transformers
from packaging.version import Version
+ # TODO: remove this version check when transformers>=4.36.0
assert Version(transformers.__version__) <= Version(
"4.33.0"
), "The OPT model should run on a transformers version not greater than 4.33.0."
@@ -111,7 +117,9 @@ def module_policy(self):
# optimization configuration
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
- suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+ suffix="final_layer_norm",
+ target_module=norm_cls,
+ ignore_if_not_exist=True,
),
policy=policy,
target_key=OPTDecoder,
@@ -119,10 +127,14 @@ def module_policy(self):
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
- suffix="self_attn_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+ suffix="self_attn_layer_norm",
+ target_module=norm_cls,
+ ignore_if_not_exist=True,
),
SubModuleReplacementDescription(
- suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+ suffix="final_layer_norm",
+ target_module=norm_cls,
+ ignore_if_not_exist=True,
),
],
policy=policy,
@@ -133,11 +145,19 @@ def module_policy(self):
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
description={
- "forward": get_opt_flash_attention_forward(),
+ "forward": get_opt_flash_attention_forward(self.shard_config),
},
policy=policy,
target_key=OPTAttention,
)
+ if not self.shard_config.pipeline_stage_manager:
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_opt_decoder_forward_for_flash_attention(self.shard_config),
+ },
+ policy=policy,
+ target_key=OPTDecoder,
+ )
# use jit fused operator
if self.shard_config.enable_jit_fused:
@@ -166,12 +186,12 @@ def get_held_layers(self) -> List[nn.Module]:
stage_manager = self.pipeline_stage_manager
held_layers = []
- layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
held_layers.append(module.embed_positions)
held_layers.append(module.project_in)
- start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.final_layer_norm)
@@ -188,9 +208,16 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
else:
module = self.model.model.decoder
- layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
- stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
- method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ stage_index = stage_manager.get_stage_index(layers_per_stage)
+ method_replacement = {
+ "forward": partial(
+ new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=self.shard_config,
+ )
+ }
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)
@@ -203,7 +230,9 @@ def module_policy(self):
policy = super().module_policy()
if self.pipeline_stage_manager:
self.set_pipeline_forward(
- model_cls=OPTModel, new_forward=OPTPipelineForwards.opt_model_forward, policy=policy
+ model_cls=OPTModel,
+ new_forward=OPTPipelineForwards.opt_model_forward,
+ policy=policy,
)
return policy
@@ -223,14 +252,18 @@ def module_policy(self):
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
- suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+ suffix="lm_head",
+ target_module=Linear1D_Col,
+ kwargs=dict(gather_output=True),
),
policy=policy,
target_key=OPTForCausalLM,
)
if self.pipeline_stage_manager:
self.set_pipeline_forward(
- model_cls=OPTForCausalLM, new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, policy=policy
+ model_cls=OPTForCausalLM,
+ new_forward=OPTPipelineForwards.opt_for_causal_lm_forward,
+ policy=policy,
)
return policy
@@ -246,7 +279,12 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
num_stages = self.pipeline_stage_manager.num_stages
if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight):
- return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}]
+ return [
+ {
+ 0: opt_model.model.decoder.embed_tokens.weight,
+ num_stages - 1: opt_model.lm_head.weight,
+ }
+ ]
return []
def postprocess(self):
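
The reformatted `get_shared_params` above relies on weight tying: when the decoder embedding and `lm_head` share one tensor, the pair is registered on the first and last pipeline stages. A minimal sketch of that identity check:

```python
import torch.nn as nn

# Tied weights are detected by object identity, as in get_shared_params above.
embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight  # weight tying

num_stages = 4
if id(embed.weight) == id(lm_head.weight):
    # register the shared tensor on the first and last stages
    shared = [{0: embed.weight, num_stages - 1: lm_head.weight}]
    print(sorted(shared[0]))  # [0, 3]
```
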
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index e183b0632f88..0c8ec15fa0a9 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import warnings
from functools import partial
from typing import Callable, Dict, List, Tuple
@@ -241,15 +243,16 @@ def module_policy(self):
def postprocess(self):
return self.model
- @staticmethod
def distribute_t5_layers(
- num_encoder_layers: int, num_decoder_layers: int, num_stages: int
+ self, num_encoder_layers: int, num_decoder_layers: int, num_stages: int
) -> Tuple[List[int], int]:
"""
Distribute t5 layers into stages when pipeline parallel is used.
Return the layer distribution as a list and the starting stage of decoder.
        If the decoder doesn't exist, the returned decoder starting stage is set to num_stages.
"""
+ stage_manager = self.pipeline_stage_manager
+ assert stage_manager is not None, "Pipeline stage manager is not set."
# number of encoder layers must be a positive integer
if num_encoder_layers <= 0:
@@ -261,7 +264,7 @@ def distribute_t5_layers(
# in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
if num_decoder_layers == 0:
- return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
+ return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages
# the number of stages distributed between encoder and decoder is optimized in this way:
# num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
@@ -272,22 +275,26 @@ def objective(num_encoder_stages):
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
num_decoder_stages = num_stages - num_encoder_stages
- encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
- decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
+ encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages)
+ decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages)
return encoder_distribution + decoder_distribution, num_encoder_stages
- @staticmethod
def get_t5_stage_index(
- layers_per_stage: List[int], stage: int, decoder_starting_stage: int
- ) -> Tuple[bool, int, int]:
+ self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
+ ) -> Tuple[int, int]:
"""
Input the distribution of layers among stages, the current stage and the first stage of decoder.
Return the starting/ending idx of layers in encoder/decoder
"""
+ stage_manager = self.pipeline_stage_manager
+ assert stage_manager is not None, "Pipeline stage manager is not set."
+
if stage < decoder_starting_stage:
- return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+ return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
else:
- return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
+ return stage_manager.get_stage_index(
+ layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage
+ )
def get_held_layers(self) -> List[nn.Module]:
"""Get pipeline layers for current stage."""
@@ -302,12 +309,10 @@ def get_held_layers(self) -> List[nn.Module]:
num_decoder_layers = len(decoder.block) if decoder else 0
held_layers = []
- layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+ layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
)
- start_idx, end_idx = T5BasePolicy.get_t5_stage_index(
- layers_per_stage, stage_manager.stage, decoder_starting_stage
- )
+ start_idx, end_idx = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
if stage_manager.stage < decoder_starting_stage:
# current stage is in t5's encoder
@@ -343,10 +348,10 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
num_encoder_layers = len(encoder.block)
num_decoder_layers = len(decoder.block) if decoder else 0
- layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+ layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
)
- stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
+ stage_index = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
method_replacement = {
"forward": partial(
@@ -386,7 +391,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager is not None and stage_manager.num_stages > 1:
- _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+ _, decoder_starting_stage = self.distribute_t5_layers(
len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages
)
@@ -434,7 +439,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager is not None and stage_manager.num_stages > 1:
- _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+ _, decoder_starting_stage = self.distribute_t5_layers(
len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages
)
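
# The heuristic these t5.py hunks preserve: pick the encoder/decoder stage
# split that best balances layers per stage. A self-contained sketch of that
# argmin objective (split_encoder_decoder is a hypothetical name):
import numpy as np


def split_encoder_decoder(num_encoder_layers: int, num_decoder_layers: int, num_stages: int):
    def objective(num_encoder_stages: int) -> float:
        num_decoder_stages = num_stages - num_encoder_stages
        return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / num_decoder_stages)

    num_encoder_stages = int(np.argmin([objective(e) for e in range(1, num_stages)])) + 1
    return num_encoder_stages, num_stages - num_encoder_stages


print(split_encoder_decoder(12, 12, 4))  # (2, 2): even split for symmetric models
print(split_encoder_decoder(12, 2, 4))   # (3, 1): the encoder gets the extra stages
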
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 584d4e2652c0..905398c4d51e 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -134,10 +134,10 @@ def get_held_layers(self) -> List[nn.Module]:
stage_manager = self.pipeline_stage_manager
held_layers = []
- layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+ layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
if stage_manager.is_first_stage():
held_layers.append(module.embeddings)
- start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.encoder.layer[start_idx:end_idx])
return held_layers
@@ -149,8 +149,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable,
else:
module = self.model.vit
- layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
- stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+ stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
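
# The vit.py hunks delegate layer partitioning to the stage manager. A toy
# version of what such helpers typically compute, assuming one plausible
# policy (later stages absorb the remainder layers):
from typing import List, Tuple


def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    base, remainder = divmod(num_layers, num_stages)
    return [base + (1 if stage >= num_stages - remainder else 0) for stage in range(num_stages)]


def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]


layers_per_stage = distribute_layers(12, 5)  # [2, 2, 2, 3, 3]
print(get_stage_index(layers_per_stage, 3))  # (6, 9)
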
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index b5b5db79d9de..c63f6d1cc549 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -13,6 +13,7 @@
WhisperPipelineForwards,
get_jit_fused_whisper_decoder_layer_forward,
get_jit_fused_whisper_encoder_layer_forward,
+ get_whisper_decoder_forward_for_flash_attention,
get_whisper_flash_attention_forward,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -31,6 +32,7 @@ def __init__(self) -> None:
import transformers
from packaging.version import Version
+ # TODO: remove this version check when transformers>=4.36.0
assert Version(transformers.__version__) <= Version(
"4.33.0"
), "The Whisper model should run on a transformers version not greater than 4.33.0."
@@ -240,6 +242,14 @@ def module_policy(self):
policy=policy,
target_key=WhisperAttention,
)
+ if not self.shard_config.pipeline_stage_manager:
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_whisper_decoder_forward_for_flash_attention(self.shard_config),
+ },
+ policy=policy,
+ target_key=WhisperDecoder,
+ )
# use jit fused operator
if self.shard_config.enable_jit_fused:
@@ -269,7 +279,9 @@ def add_lm_head_policy(self, base_policy):
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
- suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
+ suffix="proj_out",
+ target_module=col_nn.Linear1D_Col,
+ kwargs={"gather_output": True},
),
policy=base_policy,
target_key=WhisperForConditionalGeneration,
@@ -280,15 +292,16 @@ def add_lm_head_policy(self, base_policy):
def postprocess(self):
return self.model
- @staticmethod
def distribute_whisper_layers(
- num_encoder_layers: int, num_decoder_layers: int, num_stages: int
+ self, num_encoder_layers: int, num_decoder_layers: int, num_stages: int
) -> Tuple[List[int], int]:
"""
Distribute whisper layers into stages when pipeline parallel is used.
Return the layer distribution as a list and the starting stage of decoder.
        If decoder doesn't exist, the returned decoder starting stage is set to num_stages.
"""
+ stage_manager = self.pipeline_stage_manager
+ assert stage_manager is not None, "pipeline_stage_manager is None"
# number of encoder layers must be a positive integer
if num_encoder_layers <= 0:
@@ -300,7 +313,7 @@ def distribute_whisper_layers(
# in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist
if num_decoder_layers == 0:
- return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
+ return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages
# the number of stages distributed between encoder and decoder is optimized in this way:
# num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
@@ -311,22 +324,27 @@ def objective(num_encoder_stages):
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
num_decoder_stages = num_stages - num_encoder_stages
- encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
- decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
+ encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages)
+ decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages)
return encoder_distribution + decoder_distribution, num_encoder_stages
- @staticmethod
def get_whisper_stage_index(
- layers_per_stage: List[int], stage: int, decoder_starting_stage: int
- ) -> Tuple[bool, int, int]:
+ self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
+ ) -> Tuple[int, int]:
"""
Input the distribution of layers among stages, the current stage and the first stage of decoder.
Return the starting/ending idx of layers in encoder/decoder
"""
+ stage_manager = self.pipeline_stage_manager
+ assert stage_manager is not None, "pipeline_stage_manager is None"
+
if stage < decoder_starting_stage:
- return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+ return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
else:
- return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
+ return stage_manager.get_stage_index(
+ layers_per_stage[decoder_starting_stage:],
+ stage - decoder_starting_stage,
+ )
def get_held_layers(self) -> List[nn.Module]:
assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
@@ -354,12 +372,10 @@ def get_held_layers(self) -> List[nn.Module]:
num_decoder_layers = 0
held_layers = []
- layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
+ layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
)
- start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(
- layers_per_stage, stage_manager.stage, decoder_starting_stage
- )
+ start_idx, end_idx = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
if stage_manager.stage < decoder_starting_stage:
# current stage is in whisper's encoder
@@ -409,12 +425,10 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
else:
num_decoder_layers = 0
- layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
+ layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
)
- stage_index = WhisperPolicy.get_whisper_stage_index(
- layers_per_stage, stage_manager.stage, decoder_starting_stage
- )
+ stage_index = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
method_replacement = {
"forward": partial(
@@ -422,6 +436,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
stage_manager=stage_manager,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage,
+ shard_config=self.shard_config,
)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -436,7 +451,9 @@ def module_policy(self):
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
- model_cls=WhisperModel, new_forward=WhisperPipelineForwards.whisper_model_forward, policy=policy
+ model_cls=WhisperModel,
+ new_forward=WhisperPipelineForwards.whisper_model_forward,
+ policy=policy,
)
return policy
@@ -493,7 +510,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
stage_manager = self.pipeline_stage_manager
if stage_manager is not None and stage_manager.num_stages > 1:
- _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
+ _, decoder_starting_stage = self.distribute_whisper_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
)
shared_params = []
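
# get_shared_params reports weights that must stay synchronized across
# pipeline stages; the classic case is an input embedding tied to the LM
# head, held by the first stage (0) and the last (num_stages - 1). A toy
# version of the identity check used above:
import torch.nn as nn

embed = nn.Embedding(100, 16)
lm_head = nn.Linear(16, 100, bias=False)
lm_head.weight = embed.weight  # tie the weights, as many HF models do

num_stages = 4
shared_groups = []
if id(embed.weight) == id(lm_head.weight):
    shared_groups.append({0: embed.weight, num_stages - 1: lm_head.weight})
print(len(shared_groups))  # 1: one weight group spans the first and last stage
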
diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py
index acf8a95a41ca..dff2118c1c1a 100644
--- a/colossalai/shardformer/shard/__init__.py
+++ b/colossalai/shardformer/shard/__init__.py
@@ -1,5 +1,6 @@
+from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig
from .shard_config import ShardConfig
from .sharder import ModelSharder
from .shardformer import ShardFormer
-__all__ = ["ShardConfig", "ModelSharder", "ShardFormer"]
+__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "PipelineGradientCheckpointConfig", "GradientCheckpointConfig"]
diff --git a/colossalai/shardformer/shard/grad_ckpt_config.py b/colossalai/shardformer/shard/grad_ckpt_config.py
new file mode 100644
index 000000000000..9c6c2b54ea39
--- /dev/null
+++ b/colossalai/shardformer/shard/grad_ckpt_config.py
@@ -0,0 +1,85 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+
+@dataclass
+class GradientCheckpointConfig:
+ gradient_checkpointing_ratio: float = 0.0
+
+ def get_num_ckpt_layers(self, num_layers: int) -> int:
+ return int(self.gradient_checkpointing_ratio * num_layers)
+
+
+@dataclass
+class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
+ r"""
+    The pipeline gradient checkpoint config gives users finer control over gradient checkpointing in pipeline parallelism.
+    Combined with PipelineStageManager.set_distribution_config, users can fully control the distribution of layers and checkpointed layers across pipeline stages.
+    Refer to https://github.com/hpcaitech/ColossalAI/issues/5509 for more details.
+
+    It provides the following features:
+        1. `gradient_checkpointing_ratio`: controls gradient checkpointing more precisely, e.g., checkpointing 50% of the layers.
+        2. Customized numbers of checkpointed layers per stage. This takes precedence over `gradient_checkpointing_ratio`.
+
+    Args:
+        gradient_checkpointing_ratio (float): The ratio of layers to checkpoint in each stage. It can only be used in pipeline parallelism. Defaults to 0.0.
+ num_stages (Optional[int]): Number of stages in the pipeline. Defaults to None. For sanity check.
+ num_model_chunks (Optional[int]): Number of model chunks (1F1B or Interleaved). Defaults to None. For sanity check.
+ num_model_layers (Optional[int]): Number of model layers. Defaults to None. For sanity check.
+ num_ckpt_layers_per_stage (Optional[List[int]]): Number of checkpointed layers for each stage. Defaults to None.
+
+ Example 1:
+ num_stages = 8
+ num_layers = 80
+ num_model_chunks = 1
+ num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11]
+ num_ckpt_layers_per_stage = [4, 4, 2, 2, 0, 0, 0, 0]
+
+ Example 2:
+ num_stages = 4
+ num_layers = 80
+ num_model_chunks = 2
+ num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11]
+ # device 0 holds num_layers_per_stage[0] and num_layers_per_stage[4] layers
+ ...
+
+ """
+ num_stages: Optional[int] = None
+ num_model_chunks: Optional[int] = None
+ num_model_layers: Optional[int] = None
+ num_ckpt_layers_per_stage: Optional[List[int]] = None
+
+ def __post_init__(self):
+ if self._enable_gradient_checkpointing_ratio:
+ if not (0 <= self.gradient_checkpointing_ratio <= 1):
+                raise ValueError("gradient_checkpointing_ratio must be between 0 and 1")
+
+ if self._enable_customized_ckpt_layers_per_stage:
+ assert (
+ self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None
+ )
+ assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks
+ assert all(
+ [0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage]
+ )
+ self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers
+
+ @property
+ def _enable_gradient_checkpointing_ratio(self) -> bool:
+ return self.gradient_checkpointing_ratio is not None
+
+ @property
+ def _enable_customized_ckpt_layers_per_stage(self) -> bool:
+ return self.num_ckpt_layers_per_stage is not None
+
+ def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int:
+ if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage:
+ raise RuntimeError("No checkpointed layers information is provided")
+
+ if self._enable_customized_ckpt_layers_per_stage:
+            assert stage < self.num_stages and model_chunk_id < self.num_model_chunks
+ num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages]
+ assert num_ckpt_layers <= num_layers
+ return num_ckpt_layers
+ else:
+ return int(self.gradient_checkpointing_ratio * num_layers)
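
# A quick usage sketch of the config defined above, with Example 1's numbers:
# earlier stages checkpoint more of their layers, and the overall ratio is
# derived from the per-stage counts (12 checkpointed layers out of 80 -> 0.15).
from colossalai.shardformer.shard import PipelineGradientCheckpointConfig

cfg = PipelineGradientCheckpointConfig(
    num_stages=8,
    num_model_chunks=1,
    num_model_layers=80,
    num_ckpt_layers_per_stage=[4, 4, 2, 2, 0, 0, 0, 0],
)
print(cfg.get_num_ckpt_layers(stage=0, num_layers=9))  # 4
print(cfg.gradient_checkpointing_ratio)                # 0.15
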
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index da27341d9c29..7489873c2ed6 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -1,3 +1,4 @@
+import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
@@ -6,7 +7,10 @@
from colossalai.pipeline.stage_manager import PipelineStageManager
+from .grad_ckpt_config import GradientCheckpointConfig
+
__all__ = ["ShardConfig"]
+SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
@dataclass
@@ -23,18 +27,22 @@ class ShardConfig:
enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
+ gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
"""
tensor_parallel_process_group: Optional[ProcessGroup] = None
+ sequence_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None
enable_tensor_parallelism: bool = True
+ enable_all_optimization: bool = False
enable_fused_normalization: bool = False
enable_flash_attention: bool = False
enable_jit_fused: bool = False
- enable_all_optimization: bool = False
enable_sequence_parallelism: bool = False
+    sequence_parallelism_mode: Optional[str] = None
enable_sequence_overlap: bool = False
parallel_output: bool = True
+ gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# TODO padding vocab
# make_vocab_size_divisible_by: int = 128
@@ -46,21 +54,56 @@ class ShardConfig:
def tensor_parallel_size(self):
return self._tensor_parallel_size
+ @property
+ def sequence_parallel_size(self):
+ return self._sequence_parallel_size
+
def __post_init__(self):
- if not self.enable_tensor_parallelism and self.enable_sequence_parallelism:
- raise ValueError(
- "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True"
+ # turn on all optimization if all_optimization is set to True
+ if self.enable_all_optimization:
+ self._turn_on_all_optimization()
+
+ if self.enable_sequence_parallelism:
+ self.sequence_parallelism_mode = (
+ "split_gather" if self.sequence_parallelism_mode is None else self.sequence_parallelism_mode
)
- if not self.enable_sequence_parallelism and self.enable_sequence_overlap:
- raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True")
+ assert (
+ self.sequence_parallelism_mode in SUPPORT_SP_MODE
+ ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
+ if self.sequence_parallelism_mode in ["split_gather", "ring"]:
+ assert (
+ self.enable_tensor_parallelism
+ ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
+ elif self.sequence_parallelism_mode in ["all_to_all"]:
+ assert (
+ not self.enable_tensor_parallelism
+ ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
+ if self.enable_sequence_overlap:
+ self.enable_sequence_overlap = False
+ warnings.warn(
+ f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode {self.sequence_parallelism_mode}"
+ )
+ else:
+ if self.sequence_parallelism_mode:
+ self.sequence_parallelism_mode = None
+ warnings.warn(
+ f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False"
+ )
+ assert (
+ not self.enable_sequence_overlap
+ ), f"enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True"
+
+ # get the tensor parallel size
if not self.enable_tensor_parallelism:
self._tensor_parallel_size = 1
else:
- # get the parallel size
self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
- # turn on all optimization if all_optimization is set to True
- if self.enable_all_optimization:
- self._turn_on_all_optimization()
+
+ # get the sequence parallel size
+ if not self.enable_sequence_parallelism:
+ self._sequence_parallel_size = 1
+ else:
+ self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)
def _turn_on_all_optimization(self):
"""
@@ -70,8 +113,10 @@ def _turn_on_all_optimization(self):
self.enable_fused_normalization = True
self.enable_flash_attention = True
self.enable_jit_fused = True
- self.enable_sequence_parallelism = True
- self.enable_sequence_overlap = True
+ # This can cause non-in-place param sharding when used without ZeRO.
+        # It may also slow down training when the sequence length is small. Please enable it manually.
+ # self.enable_sequence_parallelism = True
+ # self.enable_sequence_overlap = True
def _infer(self):
"""
diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py
index abe4a86d8198..667a7b78e4f5 100644
--- a/colossalai/tensor/d_tensor/layout_converter.py
+++ b/colossalai/tensor/d_tensor/layout_converter.py
@@ -440,7 +440,10 @@ def layout_converting(
total_steps = 0
transform_path = []
comm_action_sequence: List[CommSpec] = []
- spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))
+
+ src_shape = source_layout.get_sharded_shape_per_device()
+ dst_shape = target_layout.get_sharded_shape_per_device()
+ spec_pairs = ((str(source_spec.sharding_sequence), src_shape), (str(target_spec.sharding_sequence), dst_shape))
if spec_pairs in self.cached_solution:
# Solution Cache hit
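
# The conversion cache is now keyed by (sharding spec, per-device shape)
# pairs rather than specs alone, so layouts that share a spec string but
# shard to different shapes no longer collide. A toy cache with that key
# (lookup is a hypothetical helper):
cached_solution = {}

def lookup(src_spec: str, src_shape: tuple, dst_spec: str, dst_shape: tuple):
    spec_pairs = ((src_spec, src_shape), (dst_spec, dst_shape))
    return cached_solution.get(spec_pairs)  # None on a miss, a cached path on a hit
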
diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py
index 4f2a4878e7ce..e415b5fc3aa3 100644
--- a/colossalai/testing/comparison.py
+++ b/colossalai/testing/comparison.py
@@ -40,7 +40,12 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
assert torch.all(a == b), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}"
-def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True, ignore_dtype: bool = False):
+def check_state_dict_equal(
+ d1: OrderedDict,
+ d2: OrderedDict,
+ ignore_device: bool = True,
+ ignore_dtype: bool = False,
+):
assert len(list(d1.keys())) == len(
list(d2.keys())
), f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
@@ -94,7 +99,12 @@ def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_devic
def assert_hf_output_close(
- out1: Any, out2: Any, ignore_keys: List[str] = None, track_name: str = "", atol=1e-5, rtol=1e-5
+ out1: Any,
+ out2: Any,
+ ignore_keys: List[str] = None,
+ track_name: str = "",
+ atol=1e-5,
+ rtol=1e-5,
):
"""
Check if two outputs from huggingface are equal.
@@ -113,7 +123,12 @@ def assert_hf_output_close(
if ignore_keys is not None and k in ignore_keys:
continue
assert_hf_output_close(
- out1[k], out2[k], track_name=f"{track_name}.{k}", ignore_keys=ignore_keys, atol=atol, rtol=rtol
+ out1[k],
+ out2[k],
+ track_name=f"{track_name}.{k}",
+ ignore_keys=ignore_keys,
+ atol=atol,
+ rtol=rtol,
)
elif isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)):
# if two values are list
@@ -121,12 +136,17 @@ def assert_hf_output_close(
assert len(out1) == len(out2)
for i in range(len(out1)):
assert_hf_output_close(
- out1[i], out2[i], track_name=f"{track_name}.{i}", ignore_keys=ignore_keys, atol=atol, rtol=rtol
+ out1[i],
+ out2[i],
+ track_name=f"{track_name}.{i}",
+ ignore_keys=ignore_keys,
+ atol=atol,
+ rtol=rtol,
)
elif isinstance(out1, Tensor) and isinstance(out2, Tensor):
if out1.shape != out2.shape:
raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}")
-        assert torch.allclose(
-            out1, out2, atol=atol, rtol=rtol
-        ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}"
+        assert_close(
+            out1,
+            out2,
+            atol=atol,
+            rtol=rtol,
+            msg=f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}",
+        )
else:
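
# torch.testing.assert_close raises with a detailed report on mismatch and
# accepts a custom msg, which replaces the hand-rolled allclose assert, e.g.:
import torch
from torch.testing import assert_close

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.0 + 1e-7])
assert_close(a, b, atol=1e-5, rtol=1e-5, msg="tensor value mismatch")  # passes
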
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index a2433d1b261c..bbbaf13b53ef 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -79,6 +79,7 @@ def __init__(
master_weights: bool = True, # master weights
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
+
self._dtype = self.optim.param_groups[0]["params"][0].dtype
self._logger = get_dist_logger()
self._verbose = verbose
@@ -494,7 +495,6 @@ def backward(self, loss, retain_graph=False):
# clear reduced grads
if self._overlap_communication:
get_accelerator().synchronize()
-
self.zero_grad()
def backward_by_grad(self, tensor, grad):
diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md
index 93045ea6adc6..6d243a80852d 100644
--- a/docs/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -24,6 +24,7 @@
## 新闻
+* [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here)
* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
* [2024/03] [Open-Sora: Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
@@ -71,6 +72,7 @@