diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml
index 4ea86b609267..d0b5c2164119 100644
--- a/.github/workflows/run_chatgpt_examples.yml
+++ b/.github/workflows/run_chatgpt_examples.yml
@@ -52,6 +52,7 @@ jobs:
mkdir sft_data
mkdir prompt_data
mkdir preference_data
+ mkdir kto_data
./tests/test_data_preparation.sh
./tests/test_train.sh
env:
@@ -61,3 +62,4 @@ jobs:
SFT_DATASET: ./sft_data
PROMPT_DATASET: ./prompt_data
PREFERENCE_DATASET: ./preference_data
+ KTO_DATASET: ./kto_data
diff --git a/applications/Colossal-LLaMA/prepare_sft_dataset.py b/applications/Colossal-LLaMA/prepare_sft_dataset.py
index a857d6c0c696..fe57907601f6 100644
--- a/applications/Colossal-LLaMA/prepare_sft_dataset.py
+++ b/applications/Colossal-LLaMA/prepare_sft_dataset.py
@@ -10,7 +10,7 @@
import os
from multiprocessing import cpu_count
-from colossal_llama.dataset.conversation import LLaMA2_Conv
+from colossal_llama.dataset.conversation import LLaMA2_Conv, LLaMA3_Conv
from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
from datasets import dataset_dict, load_dataset
from transformers import AddedToken, AutoTokenizer
@@ -75,6 +75,8 @@ def main():
    # Prepare the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
+ default_conversation = LLaMA3_Conv
+
# Fix split issue: https://github.com/huggingface/transformers/issues/23833
if args.llama_version == 2:
tokenizer.add_tokens(AddedToken("", normalized=False, special=True), special_tokens=True)
diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py
index 43a360a9a49c..e74aad33c3e3 100644
--- a/applications/Colossal-LLaMA/train.py
+++ b/applications/Colossal-LLaMA/train.py
@@ -128,6 +128,12 @@ def main() -> None:
parser.add_argument("--zero", type=int, default=1)
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
+ parser.add_argument(
+ "--skip_save_each_epoch",
+ action="store_true",
+ default=False,
+ help="skip saving the model checkpoint after each epoch is completed.",
+ )
args = parser.parse_args()
with open(args.config_file, "w") as f:
@@ -370,11 +376,17 @@ def main() -> None:
)
total_loss.fill_(0.0)
pbar.update()
+
        # Save model checkpoint.
- if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
- step + 1
- ) == len(dataloader):
+ save_model_condition = (
+ args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
+ )
+
+ if not args.skip_save_each_epoch:
+ save_model_condition = save_model_condition or (step + 1) == len(dataloader)
+
+ if save_model_condition:
coordinator.print_on_master("\nStart saving model checkpoint with running states")
if args.use_neft:
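
The refactor above keeps end-of-epoch saving as the default and lets `--skip_save_each_epoch` opt out. A minimal restatement of the condition (the helper name and step values are illustrative, not part of the patch):

```python
# Illustrative restatement of the refactored checkpoint condition.
def should_save(step, save_interval, accumulation_steps, num_batches, skip_save_each_epoch):
    save = save_interval > 0 and (step + 1) % (save_interval * accumulation_steps) == 0
    if not skip_save_each_epoch:
        save = save or (step + 1) == num_batches  # also save at the end of every epoch
    return save

assert should_save(99, 0, 1, 100, skip_save_each_epoch=False)     # epoch end still saves
assert not should_save(99, 0, 1, 100, skip_save_each_epoch=True)  # unless the flag is set
```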
diff --git a/applications/ColossalChat/.gitignore b/applications/ColossalChat/.gitignore
index 33950adc0bb5..757cbb5da051 100755
--- a/applications/ColossalChat/.gitignore
+++ b/applications/ColossalChat/.gitignore
@@ -146,6 +146,9 @@ docs/.build
examples/wandb/
examples/logs/
examples/output/
+examples/training_scripts/logs
+examples/training_scripts/wandb
+examples/training_scripts/output
examples/awesome-chatgpt-prompts/
temp/
diff --git a/applications/ColossalChat/README.md b/applications/ColossalChat/README.md
index b1b8f7eb2760..de27ebaf6be1 100755
--- a/applications/ColossalChat/README.md
+++ b/applications/ColossalChat/README.md
@@ -24,7 +24,9 @@
- [Limitation for LLaMA-finetuned models](#limitation)
- [Limitation of dataset](#limitation)
- [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization)
-- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization)
+- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
+- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
+- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [FAQ](#faq)
- [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq)
@@ -137,17 +139,15 @@ The first step in Stage 1 is to collect a dataset of human demonstrations of the
{"messages":
[
{
- "from": "human",
+ "from": "user",
"content": "what are some pranks with a pen i can do?"
},
{
"from": "assistant",
"content": "Are you looking for practical joke ideas?"
    }
- ...
]
}
- ...
]
```
@@ -173,23 +173,20 @@ Below shows the preference dataset format used in training the reward model.
"from": "human",
"content": "Introduce butterflies species in Oregon."
}
- ]
+ ],
"chosen": [
{
"from": "assistant",
"content": "About 150 species of butterflies live in Oregon, with about 100 species are moths..."
        }
- ...
],
"rejected": [
{
"from": "assistant",
"content": "Are you interested in just the common butterflies? There are a few common ones which will be easy to find..."
        }
- ...
]
}
- ...
]
```
@@ -218,7 +215,6 @@ PPO uses two kind of training data--- the prompt data and the sft data (optional
"from": "human",
"content": "what are some pranks with a pen i can do?"
}
- ...
]
}
]
@@ -284,6 +280,9 @@ Simple Preference Optimization (SimPO) from this [paper](https://arxiv.org/pdf/2
## Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO)
Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pdf/2403.07691) is a reference-model-free alignment method that uses a mixture of the SFT loss and a reinforcement learning loss based on an odds-ratio implicit reward, making training more efficient and stable. Read this [README](./examples/README.md) for more information.
+## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
+We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results. Read this [README](./examples/README.md) for more information.
+
### Inference Quantization and Serving - After Training
We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.
@@ -448,20 +447,6 @@ If you only have a single 24G GPU. Generally, using lora and "zero2-cpu" will be
If you have multiple GPUs each with very limited VRAM, say 8GB, you can try the `3d` plugin option, which supports tensor parallelism; set `--tp` to the number of GPUs that you have.
-## The Plan
-
-- [x] implement PPO fine-tuning
-- [x] implement training reward model
-- [x] support LoRA
-- [x] support inference
-- [x] support llama from [facebook](https://github.com/facebookresearch/llama)
-- [x] implement PPO-ptx fine-tuning
-- [x] support flash-attention
-- [x] implement DPO fine-tuning
-- [ ] integrate with Ray
-- [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL),
-- [ ] support chain-of-thought by [langchain](https://github.com/hwchase17/langchain)
-
### Real-time progress
You will find our progress on the GitHub [project board](https://github.com/orgs/hpcaitech/projects/17/views/1).
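
For context on the KTO section added to the README above, the objective can be sketched in a few lines of PyTorch. This is a hedged sketch of the loss described in the cited paper, with an illustrative function name, argument layout, and a simple batch-level KL estimate; it is not ColossalChat's implementation:

```python
import torch

def kto_loss_sketch(policy_logps, ref_logps, kl_policy_logps, kl_ref_logps,
                    labels, beta=0.1, desirable_weight=1.0, undesirable_weight=1.0):
    # *_logps: per-sample summed log-probs of completions under the policy and
    # the frozen reference model; kl_* use prompts paired with mismatched
    # completions; labels is a bool tensor marking desirable samples.
    kl = (kl_policy_logps - kl_ref_logps).mean().clamp(min=0).detach()  # batch KL estimate
    rewards = beta * (policy_logps - ref_logps)  # implicit reward: beta * log-ratio
    chosen, rejected = rewards[labels], rewards[~labels]
    losses = torch.cat([
        desirable_weight * (1 - torch.sigmoid(chosen - beta * kl)),      # raise utility of desirable outputs
        undesirable_weight * (1 - torch.sigmoid(beta * kl - rejected)),  # lower utility of undesirable ones
    ])
    return losses.mean()
```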
diff --git a/applications/ColossalChat/benchmarks/benchmark_dpo.sh b/applications/ColossalChat/benchmarks/benchmark_dpo.sh
index dfd0ff846c2e..44d821a87fee 100755
--- a/applications/ColossalChat/benchmarks/benchmark_dpo.sh
+++ b/applications/ColossalChat/benchmarks/benchmark_dpo.sh
@@ -19,30 +19,33 @@ PROJECT_NAME="dpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
+BENCHMARK_DATA_DIR="./temp/dpo" # Path to benchmark data
+DATASET_SIZE=320
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
-SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
-CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+declare -a dataset=(
+ $BENCHMARK_DATA_DIR/arrow/part-0
+)
-colossalai run --nproc_per_node 4 --master_port 31313 benchmark_dpo.py \
+# Generate dummy test data
+python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference
+
+
+colossalai run --nproc_per_node 4 --master_port 31313 ../examples/training_scripts/train_dpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
- --config_file $CONFIG_FILE \
+ --dataset ${dataset[@]} \
--plugin "zero2_cpu" \
--max_epochs 1 \
--accumulation_steps 1 \
- --batch_size 8 \
+ --batch_size 4 \
--lr 1e-6 \
--beta 0.1 \
- --gamma 0.6 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 2048 \
- --dataset_size 640 \
--weight_decay 0.01 \
--warmup_steps 60 \
- --disable_reference_model \
- --length_normalization \
--grad_checkpoint \
--use_flash_attn
diff --git a/applications/ColossalChat/benchmarks/benchmark_kto.sh b/applications/ColossalChat/benchmarks/benchmark_kto.sh
new file mode 100755
index 000000000000..82d3e3421acb
--- /dev/null
+++ b/applications/ColossalChat/benchmarks/benchmark_kto.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+set_n_least_used_CUDA_VISIBLE_DEVICES 4
+
+PROJECT_NAME="kto"
+PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
+PRETRAINED_MODEL_PATH="" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
+BENCHMARK_DATA_DIR="./temp/kto" # Path to benchmark data
+DATASET_SIZE=80
+
+TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
+FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
+declare -a dataset=(
+ $BENCHMARK_DATA_DIR/arrow/part-0
+)
+
+# Generate dummy test data
+python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type kto
+
+
+colossalai run --nproc_per_node 2 --master_port 31313 ../examples/training_scripts/train_kto.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
+ --plugin "zero2_cpu" \
+ --max_epochs 1 \
+ --accumulation_steps 1 \
+ --batch_size 2 \
+ --lr 1e-5 \
+ --beta 0.1 \
+ --mixed_precision "bf16" \
+ --grad_clip 1.0 \
+ --max_length 2048 \
+ --weight_decay 0.01 \
+ --warmup_steps 60 \
+ --grad_checkpoint \
+ --use_flash_attn
diff --git a/applications/ColossalChat/benchmarks/benchmark_orpo.py b/applications/ColossalChat/benchmarks/benchmark_orpo.py
deleted file mode 100755
index 1325bada2dca..000000000000
--- a/applications/ColossalChat/benchmarks/benchmark_orpo.py
+++ /dev/null
@@ -1,315 +0,0 @@
-import argparse
-import json
-import os
-import resource
-from contextlib import nullcontext
-
-import torch
-from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler
-from coati.models import convert_to_lora_module, disable_dropout
-from coati.trainer import ORPOTrainer
-from coati.utils import load_checkpoint
-from dummy_dataset import DummyLLMDataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
-from colossalai.cluster import DistCoordinator
-from colossalai.logging import get_dist_logger
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.nn.optimizer import HybridAdam
-
-logger = get_dist_logger()
-
-
-def train(args):
- # check lora compatibility
- if "gemini" in args.plugin and args.lora_rank > 0:
- raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
- if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
- raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
-
- # ==============================
- # Initialize Distributed Training
- # ==============================
- colossalai.launch_from_torch()
- coordinator = DistCoordinator()
-
- # ==============================
- # Initialize Booster
- # ==============================
- if args.plugin == "ddp":
- """
- Default torch ddp plugin without any acceleration, for
- debugging purpose acceleration, for debugging purpose
- """
- plugin = TorchDDPPlugin(find_unused_parameters=True)
- elif args.plugin == "gemini":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- placement_policy="static",
- initial_scale=2**16,
- max_norm=args.grad_clip,
- enable_gradient_accumulation=True,
- enable_flash_attention=args.use_flash_attn,
- )
- elif args.plugin == "gemini_auto":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- placement_policy="auto",
- initial_scale=2**16,
- max_norm=args.grad_clip,
- enable_flash_attention=args.use_flash_attn,
- )
- elif args.plugin == "zero2":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "zero2_cpu":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- cpu_offload=True,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "3d":
- plugin = HybridParallelPlugin(
- tp_size=args.tp,
- pp_size=args.pp,
- sp_size=args.sp,
- sequence_parallelism_mode=args.sp_mode,
- zero_stage=args.zero_stage,
- enable_flash_attention=args.use_flash_attn,
- enable_sequence_parallelism=args.enable_sequence_parallelism,
- cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
- parallel_output=False,
- max_norm=args.grad_clip,
- precision=args.mixed_precision,
- )
- else:
- raise ValueError(f"Unknown plugin {args.plugin}")
-
- booster = Booster(plugin=plugin)
-
- # ======================================================
- # Initialize Model, Objective, Optimizer and LR Scheduler
- # ======================================================
- # Temp Fix: Disable lazy init due to version conflict
- # init_ctx = (
- # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
- # )
-
- init_ctx = nullcontext()
- with init_ctx:
- if args.use_flash_attn:
- model = AutoModelForCausalLM.from_pretrained(
- args.pretrain,
- torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
- use_flash_attention_2=True,
- )
- coordinator.print_on_master(msg="Flash-attention enabled successfully")
- else:
- model = AutoModelForCausalLM.from_pretrained(args.pretrain)
- disable_dropout(model)
- if args.lora_rank > 0:
- model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
-
- if args.grad_checkpoint:
- # Note, for some models, lora may not be compatible with gradient checkpointing
- model.gradient_checkpointing_enable()
- coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-
- # configure tokenizer
- tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
- if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
- try:
- # Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
- tokenizer.pad_token = tokenizer.eos_token
- except AttributeError as e:
- logger.warning(f"Unable to set pad token to eos token, {str(e)}")
- if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
- logger.warning(
- "The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
- )
-
- tokenizer.add_bos_token = False
- tokenizer.add_eos_token = False
-
- # configure optimizer
- optim = HybridAdam(
- model_params=model.parameters(),
- lr=args.lr,
- betas=(0.9, 0.95),
- weight_decay=args.weight_decay,
- adamw_mode=True,
- )
-
- # configure dataset
- coordinator.print_on_master(f"Load dataset: {args.dataset}")
- mode_map = {"train": "train", "valid": "validation", "test": "test"}
- train_dataset = DummyLLMDataset(
- ["chosen_input_ids", "chosen_loss_mask", "rejected_input_ids", "rejected_loss_mask"],
- args.max_length,
- args.dataset_size,
- )
- data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
-
- train_dataloader = plugin.prepare_dataloader(
- dataset=train_dataset,
- batch_size=args.batch_size,
- shuffle=True,
- drop_last=True,
- collate_fn=data_collator,
- distributed_sampler_cls=StatefulDistributedSampler,
- )
-
- num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
- if args.warmup_steps is None:
- args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
- coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
-
- lr_scheduler = CosineAnnealingWarmupLR(
- optimizer=optim,
- total_steps=args.max_epochs * num_update_steps_per_epoch,
- warmup_steps=args.warmup_steps,
- eta_min=0.1 * args.lr,
- )
-
- default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
- torch.set_default_dtype(default_dtype)
- model, optim, _, train_dataloader, lr_scheduler = booster.boost(
- model=model,
- optimizer=optim,
- lr_scheduler=lr_scheduler,
- dataloader=train_dataloader,
- )
- torch.set_default_dtype(torch.float)
-
- coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
- coordinator.print_on_master(
- f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- start_epoch = 0
- sampler_start_idx = 0
- start_step = 0
- if args.checkpoint_path is not None:
- if "modeling" in args.checkpoint_path:
- coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
- booster.load_model(model, args.checkpoint_path)
- else:
- coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
- start_epoch, start_step, sampler_start_idx = load_checkpoint(
- load_dir=args.checkpoint_path,
- booster=booster,
- model=model,
- optimizer=optim,
- lr_scheduler=lr_scheduler,
- )
- assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
- train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
-
- coordinator.print_on_master(
- f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
- )
- coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
-
- coordinator.print_on_master(
- f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- trainer = ORPOTrainer(
- actor=model,
- booster=booster,
- actor_optim=optim,
- actor_lr_scheduler=lr_scheduler,
- tokenizer=tokenizer,
- max_epochs=args.max_epochs,
- accumulation_steps=args.accumulation_steps,
- start_epoch=start_epoch,
- save_interval=None,
- save_dir=None,
- coordinator=coordinator,
- lam=args.lam,
- )
-
- trainer.fit(
- train_preference_dataloader=train_dataloader,
- eval_preference_dataloader=None,
- log_dir=None,
- use_wandb=False,
- )
- coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
-
-
-if __name__ == "__main__":
- # ==============================
- # Parse Arguments
- # ==============================
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--plugin",
- type=str,
- default="gemini",
- choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
- help="Choose which plugin to use",
- )
- parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
- parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
- parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
- parser.add_argument("--tp", type=int, default=1)
- parser.add_argument("--pp", type=int, default=1)
- parser.add_argument("--sp", type=int, default=1)
- parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss")
- parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
- parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
- parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
- parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
- parser.add_argument("--pretrain", type=str, default=None)
- parser.add_argument("--model_type", type=str, default=None)
- parser.add_argument("--tokenizer_dir", type=str, default=None)
- parser.add_argument("--dataset", nargs="+", default=[])
- parser.add_argument(
- "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
- )
- parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
- parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
- parser.add_argument("--max_epochs", type=int, default=3)
- parser.add_argument("--batch_size", type=int, default=4)
- parser.add_argument(
- "--disable_reference_model",
- action="store_true",
- default=False,
- help="Disable the reference model (enabled by default)",
- )
- parser.add_argument("--dataset_size", type=int, default=500)
- parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
- parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument(
- "--lora_train_bias",
- type=str,
- default="none",
- help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
- )
- parser.add_argument("--merge_lora_weights", type=bool, default=True)
- parser.add_argument("--lr", type=float, default=5e-6)
- parser.add_argument("--accumulation_steps", type=int, default=8)
- parser.add_argument("--grad_checkpoint", default=False, action="store_true")
- parser.add_argument("--use_flash_attn", default=False, action="store_true")
- args = parser.parse_args()
- os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
- with open(args.config_file, "w") as f:
- json.dump(args.__dict__, f, indent=4)
- train(args)
diff --git a/applications/ColossalChat/benchmarks/benchmark_orpo.sh b/applications/ColossalChat/benchmarks/benchmark_orpo.sh
index cc6eef5108b2..f8fb264aeaae 100755
--- a/applications/ColossalChat/benchmarks/benchmark_orpo.sh
+++ b/applications/ColossalChat/benchmarks/benchmark_orpo.sh
@@ -15,20 +15,28 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
}
set_n_least_used_CUDA_VISIBLE_DEVICES 2
-PROJECT_NAME="dpo"
+PROJECT_NAME="orpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
+BENCHMARK_DATA_DIR="./temp/orpo" # Path to benchmark data
+DATASET_SIZE=160
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
-CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+declare -a dataset=(
+ $BENCHMARK_DATA_DIR/arrow/part-0
+)
-colossalai run --nproc_per_node 2 --master_port 31313 benchmark_orpo.py \
+# Generate dummy test data
+python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference
+
+
+colossalai run --nproc_per_node 2 --master_port 31313 ../examples/training_scripts/train_orpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
--plugin "zero2" \
- --config_file $CONFIG_FILE \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 4 \
@@ -39,6 +47,5 @@ colossalai run --nproc_per_node 2 --master_port 31313 benchmark_orpo.py \
--max_length 2048 \
--weight_decay 0.01 \
--warmup_steps 60 \
- --dataset_size 160 \
--grad_checkpoint \
--use_flash_attn
diff --git a/applications/ColossalChat/benchmarks/benchmark_sft.py b/applications/ColossalChat/benchmarks/benchmark_sft.py
deleted file mode 100644
index b6438c5039bb..000000000000
--- a/applications/ColossalChat/benchmarks/benchmark_sft.py
+++ /dev/null
@@ -1,315 +0,0 @@
-import argparse
-import json
-import math
-import os
-import resource
-from contextlib import nullcontext
-
-import torch
-from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler
-from coati.models import convert_to_lora_module
-from coati.trainer import SFTTrainer
-from coati.utils import load_checkpoint
-from dummy_dataset import DummyLLMDataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
-from colossalai.cluster import DistCoordinator
-from colossalai.logging import get_dist_logger
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.nn.optimizer import HybridAdam
-
-logger = get_dist_logger()
-
-
-def train(args):
- # check lora compatibility
- if "gemini" in args.plugin and args.lora_rank > 0:
- raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
- if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
- raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
- # ==============================
- # Initialize Distributed Training
- # ==============================
- colossalai.launch_from_torch()
- coordinator = DistCoordinator()
-
- # ==============================
- # Initialize Booster
- # ==============================
- init_ctx = nullcontext()
- with init_ctx:
- if args.use_flash_attn:
- model = AutoModelForCausalLM.from_pretrained(
- args.pretrain,
- torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
- attn_implementation="flash_attention_2",
- trust_remote_code=True,
- )
- else:
- model = AutoModelForCausalLM.from_pretrained(
- args.pretrain,
- torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
- trust_remote_code=True,
- )
- if args.lora_rank > 0:
- model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
-
- if args.plugin == "ddp":
- """
- Default torch ddp plugin without any acceleration, for
- debugging purpose acceleration, for debugging purpose
- """
- plugin = TorchDDPPlugin(find_unused_parameters=True)
- elif args.plugin == "gemini":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- placement_policy="static",
- initial_scale=2**16,
- max_norm=args.grad_clip,
- enable_gradient_accumulation=True if args.accumulation_steps > 1 else False,
- enable_flash_attention=args.use_flash_attn,
- )
- elif args.plugin == "gemini_auto":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- placement_policy="auto",
- initial_scale=2**16,
- max_norm=args.grad_clip,
- enable_flash_attention=args.use_flash_attn,
- )
- elif args.plugin == "zero2":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "zero2_cpu":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- cpu_offload=True,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "3d":
- plugin = HybridParallelPlugin(
- tp_size=args.tp,
- pp_size=args.pp,
- sp_size=args.sp,
- sequence_parallelism_mode=args.sp_mode,
- zero_stage=args.zero_stage,
- enable_flash_attention=args.use_flash_attn,
- enable_sequence_parallelism=args.enable_sequence_parallelism,
- cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
- parallel_output=False,
- max_norm=args.grad_clip,
- precision=args.mixed_precision,
- microbatch_size=args.batch_size,
- )
- else:
- raise ValueError(f"Unknown plugin {args.plugin}")
-
- booster = Booster(plugin=plugin)
-
- # ======================================================
- # Initialize Model, Objective, Optimizer and LR Scheduler
- # ======================================================
- # Temp Fix: Disable lazy init due to version conflict
- # init_ctx = (
- # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
- # )
-
- if args.grad_checkpoint:
- # Note, for some models, lora may not be compatible with gradient checkpointing
- model.gradient_checkpointing_enable()
- coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-
- # configure tokenizer
- tokenizer = AutoTokenizer.from_pretrained(
- args.tokenizer_dir or args.pretrain, use_fast=False, trust_remote_code=True
- )
- if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
- try:
- # Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
- tokenizer.pad_token = tokenizer.eos_token
- except AttributeError as e:
- logger.warning(f"Unable to set pad token to eos token, {str(e)}")
- if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
- logger.warning(
- "The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
- )
-
- tokenizer.add_bos_token = False
- tokenizer.add_eos_token = False
- tokenizer.padding_side = "right"
-
- coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
-
- # configure optimizer
- optim = HybridAdam(
- model_params=model.parameters(),
- lr=args.lr,
- betas=(0.9, 0.95),
- weight_decay=args.weight_decay,
- adamw_mode=True,
- )
-
- # configure dataset
- coordinator.print_on_master(
- f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
- )
- dataset = DummyLLMDataset(["input_ids", "attention_mask", "labels"], args.max_len, args.dataset_size)
- data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_len)
-
- train_dataloader = plugin.prepare_dataloader(
- dataset=dataset,
- batch_size=args.batch_size,
- shuffle=True,
- drop_last=True,
- collate_fn=data_collator,
- distributed_sampler_cls=StatefulDistributedSampler,
- )
- coordinator.print_on_master(
- f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
- )
-
- num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
- math.ceil(args.max_epochs * num_update_steps_per_epoch)
-
- if args.warmup_steps is None:
- args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
- coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
-
- lr_scheduler = CosineAnnealingWarmupLR(
- optimizer=optim,
- total_steps=args.max_epochs * num_update_steps_per_epoch,
- warmup_steps=args.warmup_steps,
- eta_min=0.1 * args.lr,
- )
-
- # Flash attention will be disabled because it does NOT support fp32.
- default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
- torch.set_default_dtype(default_dtype)
- model, optim, _, train_dataloader, lr_scheduler = booster.boost(
- model=model,
- optimizer=optim,
- lr_scheduler=lr_scheduler,
- dataloader=train_dataloader,
- )
- torch.set_default_dtype(torch.float)
-
- coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
- coordinator.print_on_master(
- f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- start_epoch = 0
- sampler_start_idx = 0
- start_step = 0
- if args.checkpoint_path is not None:
- if "modeling" in args.checkpoint_path:
- coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
- booster.load_model(model, args.checkpoint_path)
- else:
- coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
- start_epoch, start_step, sampler_start_idx = load_checkpoint(
- load_dir=args.checkpoint_path,
- booster=booster,
- model=model,
- optimizer=optim,
- lr_scheduler=lr_scheduler,
- )
- train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
-
- coordinator.print_on_master(
- f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
- )
- coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
-
- coordinator.print_on_master(
- f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- trainer = SFTTrainer(
- model=model,
- booster=booster,
- optim=optim,
- lr_scheduler=lr_scheduler,
- max_epochs=args.max_epochs,
- accumulation_steps=args.accumulation_steps,
- start_epoch=start_epoch,
- save_interval=None,
- save_dir=None,
- coordinator=coordinator,
- )
-
- trainer.fit(
- train_dataloader=train_dataloader,
- eval_dataloader=None,
- log_dir=None,
- use_wandb=False,
- )
-
- coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
-
-
-if __name__ == "__main__":
- # ==============================
- # Parse Arguments
- # ==============================
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--plugin",
- type=str,
- default="gemini",
- choices=["gemini", "gemini_auto", "3d", "ddp", "zero2_cpu", "zero2"],
- help="Choose which plugin to use",
- )
- parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
- parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
- parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
- parser.add_argument("--tp", type=int, default=1)
- parser.add_argument("--pp", type=int, default=1)
- parser.add_argument("--sp", type=int, default=1)
- parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
- parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
- parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
- parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
- parser.add_argument("--pretrain", type=str, default=None)
- parser.add_argument("--tokenizer_dir", type=str, default=None)
- parser.add_argument(
- "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
- )
- parser.add_argument("--max_epochs", type=int, default=3)
- parser.add_argument("--batch_size", type=int, default=4)
- parser.add_argument("--max_len", type=int, default=512)
- parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
- parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument(
- "--lora_train_bias",
- type=str,
- default="none",
- help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
- )
- parser.add_argument("--merge_lora_weights", type=bool, default=True)
- parser.add_argument("--lr", type=float, default=5e-6)
- parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
- parser.add_argument("--accumulation_steps", type=int, default=8)
- parser.add_argument("--grad_checkpoint", default=False, action="store_true")
- parser.add_argument("--use_flash_attn", default=False, action="store_true")
- parser.add_argument("--dataset_size", type=int, default=500)
- args = parser.parse_args()
- os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
- with open(args.config_file, "w") as f:
- json.dump(args.__dict__, f, indent=4)
- train(args)
diff --git a/applications/ColossalChat/benchmarks/benchmark_sft.sh b/applications/ColossalChat/benchmarks/benchmark_sft.sh
index 0c80386efec3..efcd428dd21e 100755
--- a/applications/ColossalChat/benchmarks/benchmark_sft.sh
+++ b/applications/ColossalChat/benchmarks/benchmark_sft.sh
@@ -14,21 +14,31 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
-# export CUDA_VISIBLE_DEVICES=3,4
+
PROJECT_NAME="sft"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
+BENCHMARK_DATA_DIR="./temp/sft" # Path to benchmark data
+DATASET_SIZE=640
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+declare -a dataset=(
+ $BENCHMARK_DATA_DIR/arrow/part-0
+)
+
+
+# Generate dummy test data
+python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type sft
+
# the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
-colossalai run --nproc_per_node 4 --master_port 31312 benchmark_sft.py \
+colossalai run --nproc_per_node 1 --master_port 31312 ../examples/training_scripts/train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
- --config_file $CONFIG_FILE \
+ --dataset ${dataset[@]} \
--plugin zero2 \
--batch_size 8 \
--max_epochs 1 \
@@ -36,6 +46,5 @@ colossalai run --nproc_per_node 4 --master_port 31312 benchmark_sft.py \
--lr 5e-5 \
--lora_rank 32 \
--max_len 2048 \
- --dataset_size 640 \
--grad_checkpoint \
--use_flash_attn
diff --git a/applications/ColossalChat/benchmarks/benchmark_simpo.sh b/applications/ColossalChat/benchmarks/benchmark_simpo.sh
new file mode 100755
index 000000000000..47dfc8595e74
--- /dev/null
+++ b/applications/ColossalChat/benchmarks/benchmark_simpo.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+set_n_least_used_CUDA_VISIBLE_DEVICES 4
+
+PROJECT_NAME="simpo"
+PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
+PRETRAINED_MODEL_PATH="" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
+BENCHMARK_DATA_DIR="./temp/simpo" # Path to benchmark data
+DATASET_SIZE=640
+
+TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
+FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
+declare -a dataset=(
+ $BENCHMARK_DATA_DIR/arrow/part-0
+)
+
+# Generate dummy test data
+python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference
+
+
+colossalai run --nproc_per_node 4 --master_port 31313 ../examples/training_scripts/train_dpo.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
+ --plugin "zero2_cpu" \
+ --loss_type "simpo_loss" \
+ --max_epochs 1 \
+ --accumulation_steps 1 \
+ --batch_size 8 \
+ --lr 1e-6 \
+ --beta 0.1 \
+ --gamma 0.6 \
+ --mixed_precision "bf16" \
+ --grad_clip 1.0 \
+ --max_length 2048 \
+ --weight_decay 0.01 \
+ --warmup_steps 60 \
+ --disable_reference_model \
+ --length_normalization \
+ --grad_checkpoint \
+ --use_flash_attn
diff --git a/applications/ColossalChat/benchmarks/dummy_dataset.py b/applications/ColossalChat/benchmarks/dummy_dataset.py
index 070531fd58f3..9af0f164173f 100644
--- a/applications/ColossalChat/benchmarks/dummy_dataset.py
+++ b/applications/ColossalChat/benchmarks/dummy_dataset.py
@@ -1,10 +1,12 @@
-import torch
+from typing import Callable
+
from torch.utils.data import Dataset
class DummyLLMDataset(Dataset):
- def __init__(self, keys, seq_len, size=500):
+    def __init__(self, keys, seq_len, size=500, gen_fn=None):
         self.keys = keys
+        self.gen_fn = gen_fn if gen_fn is not None else {}  # avoid a shared mutable default
self.seq_len = seq_len
self.data = self._generate_data()
self.size = size
@@ -12,11 +14,17 @@ def __init__(self, keys, seq_len, size=500):
def _generate_data(self):
data = {}
for key in self.keys:
- data[key] = torch.ones(self.seq_len, dtype=torch.long)
+ if key in self.gen_fn:
+ data[key] = self.gen_fn[key]
+ else:
+ data[key] = [1] * self.seq_len
return data
def __len__(self):
return self.size
def __getitem__(self, idx):
- return {key: self.data[key] for key in self.keys}
+ return {
+ key: self.data[key] if not isinstance(self.data[key], Callable) else self.data[key](idx)
+ for key in self.keys
+ }
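
As a usage note for the `gen_fn` hook added above: per-key entries may be constants or callables of the sample index, which is how the KTO benchmark alternates desirable and undesirable labels. A small sketch (key names mirror the KTO benchmark; the sizes are arbitrary):

```python
from dummy_dataset import DummyLLMDataset

ds = DummyLLMDataset(
    keys=["prompt", "completion", "label"],
    seq_len=8,  # keys without a gen_fn entry default to [1] * seq_len
    size=4,
    gen_fn={
        "completion": lambda idx: [1] * 4,  # fixed-length dummy completion
        "label": lambda idx: idx % 2,       # alternate undesirable/desirable
    },
)
print(ds[0]["label"], ds[1]["label"])  # -> 0 1
```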
diff --git a/applications/ColossalChat/benchmarks/prepare_dummy_test_dataset.py b/applications/ColossalChat/benchmarks/prepare_dummy_test_dataset.py
new file mode 100644
index 000000000000..f501c53582e6
--- /dev/null
+++ b/applications/ColossalChat/benchmarks/prepare_dummy_test_dataset.py
@@ -0,0 +1,105 @@
+import argparse
+import json
+import os
+import time
+from multiprocessing import cpu_count
+
+from datasets import load_dataset
+from dummy_dataset import DummyLLMDataset
+
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--data_dir",
+ type=str,
+ required=True,
+ default=None,
+ help="The output dir",
+ )
+ parser.add_argument(
+ "--dataset_size",
+ type=int,
+ required=True,
+ default=None,
+ help="The size of data",
+ )
+ parser.add_argument(
+ "--max_length",
+ type=int,
+ required=True,
+ default=None,
+ help="The max length of data",
+ )
+ parser.add_argument(
+ "--data_type",
+ type=str,
+ required=True,
+ default=None,
+ help="The type of data, choose one from ['sft', 'prompt', 'preference', 'kto']",
+ )
+ args = parser.parse_args()
+ if args.data_type == "sft":
+ dataset = DummyLLMDataset(["input_ids", "attention_mask", "labels"], args.max_length, args.dataset_size)
+ elif args.data_type == "prompt":
+        # The prompt dataset for PPO is prepared separately; nothing to generate here
+ pass
+ elif args.data_type == "preference":
+ dataset = DummyLLMDataset(
+ ["chosen_input_ids", "chosen_loss_mask", "rejected_input_ids", "rejected_loss_mask"],
+ args.max_length,
+ args.dataset_size,
+ )
+ elif args.data_type == "kto":
+ dataset = DummyLLMDataset(
+ ["prompt", "completion", "label"],
+ args.max_length - 512,
+ args.dataset_size,
+ gen_fn={
+ "completion": lambda x: [1] * 512,
+ "label": lambda x: x % 2,
+ },
+ )
+ else:
+ raise ValueError(f"Unknown data type {args.data_type}")
+
+ # Save each jsonl spliced dataset.
+ output_index = "0"
+ output_name = f"part-{output_index}"
+ os.makedirs(args.data_dir, exist_ok=True)
+ output_jsonl_path = os.path.join(args.data_dir, "json")
+ output_arrow_path = os.path.join(args.data_dir, "arrow")
+ output_cache_path = os.path.join(args.data_dir, "cache")
+ os.makedirs(output_jsonl_path, exist_ok=True)
+ os.makedirs(output_arrow_path, exist_ok=True)
+ output_jsonl_file_path = os.path.join(output_jsonl_path, output_name + ".jsonl")
+ st = time.time()
+ with open(file=output_jsonl_file_path, mode="w", encoding="utf-8") as fp_writer:
+ count = 0
+ for i in range(len(dataset)):
+ data_point = dataset[i]
+ if count % 500 == 0:
+ logger.info(f"processing {count} spliced data points for {fp_writer.name}")
+ count += 1
+ fp_writer.write(json.dumps(data_point, ensure_ascii=False) + "\n")
+ logger.info(
+ f"Current file {fp_writer.name}; "
+ f"Data size: {len(dataset)}; "
+ f"Time cost: {round((time.time() - st) / 60, 6)} minutes."
+ )
+ # Save each arrow spliced dataset
+ output_arrow_file_path = os.path.join(output_arrow_path, output_name)
+ logger.info(f"Start to save {output_arrow_file_path}")
+ dataset = load_dataset(
+ path="json",
+ data_files=[output_jsonl_file_path],
+ cache_dir=os.path.join(output_cache_path, "tokenized"),
+ keep_in_memory=False,
+ num_proc=cpu_count(),
+ split="train",
+ )
+ dataset.save_to_disk(dataset_path=output_arrow_file_path, num_proc=min(len(dataset), cpu_count()))
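
The `save_to_disk` output above loads back with `datasets.load_from_disk`, which gives a quick sanity check; the path below assumes the defaults from `benchmark_kto.sh`:

```python
from datasets import load_from_disk

ds = load_from_disk("./temp/kto/arrow/part-0")  # BENCHMARK_DATA_DIR/arrow/part-0
print(len(ds))       # DATASET_SIZE rows
print(ds[0].keys())  # dict_keys(['prompt', 'completion', 'label'])
```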
diff --git a/applications/ColossalChat/coati/dataset/__init__.py b/applications/ColossalChat/coati/dataset/__init__.py
index deb7b6d926fb..8e9060a1a1f9 100755
--- a/applications/ColossalChat/coati/dataset/__init__.py
+++ b/applications/ColossalChat/coati/dataset/__init__.py
@@ -1,24 +1,26 @@
from .conversation import Conversation, setup_conversation_template
from .loader import (
+ DataCollatorForKTODataset,
DataCollatorForPreferenceDataset,
DataCollatorForPromptDataset,
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
)
-from .tokenization_utils import supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
+from .tokenization_utils import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
__all__ = [
- "tokenize_prompt_dataset",
+ "tokenize_prompt",
"DataCollatorForPromptDataset",
"is_rank_0",
"DataCollatorForPreferenceDataset",
"DataCollatorForSupervisedDataset",
+ "DataCollatorForKTODataset",
"StatefulDistributedSampler",
"load_tokenized_dataset",
- "supervised_tokenize_pretrain",
- "supervised_tokenize_sft",
+ "tokenize_sft",
"tokenize_rlhf",
+ "tokenize_kto",
"setup_conversation_template",
"Conversation",
]
diff --git a/applications/ColossalChat/coati/dataset/conversation.py b/applications/ColossalChat/coati/dataset/conversation.py
index 37900f3b8d64..a77c220d34af 100755
--- a/applications/ColossalChat/coati/dataset/conversation.py
+++ b/applications/ColossalChat/coati/dataset/conversation.py
@@ -18,6 +18,7 @@ class Conversation:
chat_template: str
stop_ids: List[int]
end_of_assistant: str
+ roles = ["user", "assistant"]
@classmethod
def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
@@ -85,7 +86,7 @@ def append_message(self, role: str, message: str):
Raises:
AssertionError: If the role is not 'user' or 'assistant'.
"""
- assert role in ["user", "assistant"]
+ assert role in self.roles
self.messages.append({"role": role, "content": message})
def copy(self):
diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py
index 48011c941f46..b92cd76adc38 100755
--- a/applications/ColossalChat/coati/dataset/loader.py
+++ b/applications/ColossalChat/coati/dataset/loader.py
@@ -235,6 +235,91 @@ def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch
)
+@dataclass
+class DataCollatorForKTODataset(object):
+ """
+    Collate instances for the KTO dataset.
+    Each input instance is a tokenized dictionary with fields
+    `prompt`(List[int]), `completion`(List[int]) and `label`(bool).
+    Each output instance is a tokenized dictionary with fields
+    `input_ids`(List[int]), `attention_mask`(List[int]), `loss_mask`(List[int]) and `label`(bool),
+    plus the mismatched-completion counterparts `kl_input_ids`, `kl_attention_mask` and `kl_loss_mask`.
+ """
+
+ tokenizer: PreTrainedTokenizer
+ max_length: int = 4096
+ ignore_index: int = -100
+
+ def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
+ """
+
+ Args:
+ instances (`Sequence[Dict[str, List[int]]]`):
+ Mini-batch samples, each sample is stored in an individual dictionary contains the following fields:
+                Mini-batch samples; each sample is stored in an individual dictionary that contains the following fields:
+
+ Returns:
+ (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
+ `input_ids`: `torch.Tensor` of shape (bsz, max_len);
+ `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
+                `loss_mask`, `label`, and the `kl_*` counterparts built from prompts
+                paired with mismatched completions.
+ """
+ assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
+ f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
+ f"but now `{self.tokenizer.pad_token_id}`"
+ )
+        # prepare the matched (prompt, completion) data
+ prompt = [torch.LongTensor(instance["prompt"]) for instance in instances]
+ prompt_zeros = [torch.zeros_like(t) for t in prompt]
+ completion = [torch.LongTensor(instance["completion"]) for instance in instances]
+ completion_ones = [torch.ones_like(t) for t in completion]
+ label = [torch.tensor(instance["label"], dtype=torch.bool) for instance in instances]
+ input_ids = [torch.cat([prompt[i], completion[i]], dim=-1) for i in range(len(instances))]
+ loss_mask = [torch.cat([prompt_zeros[i], completion_ones[i]], dim=-1) for i in range(len(instances))]
+ # right padding
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ sequences=input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id,
+ ) # (bsz, max_len)
+ loss_mask = torch.nn.utils.rnn.pad_sequence(
+ sequences=loss_mask, batch_first=True, padding_value=0
+ ) # (bsz, max_len)
+ to_pad = self.max_length - input_ids.size(1)
+ input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
+ loss_mask = F.pad(loss_mask, (0, to_pad), value=0)
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
+
+        # prepare the KL data: pair each prompt with a mismatched completion y'
+ kl_completion = completion[::-1] # y'
+ kl_completion_ones = [torch.ones_like(t) for t in kl_completion]
+ kl_input_ids = [torch.cat([prompt[i], kl_completion[i]], dim=-1) for i in range(len(instances))]
+ kl_loss_mask = [torch.cat([prompt_zeros[i], kl_completion_ones[i]], dim=-1) for i in range(len(instances))]
+ # right padding
+ kl_input_ids = torch.nn.utils.rnn.pad_sequence(
+ sequences=kl_input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id,
+ ) # (bsz, max_len)
+ kl_loss_mask = torch.nn.utils.rnn.pad_sequence(
+ sequences=kl_loss_mask, batch_first=True, padding_value=0
+ ) # (bsz, max_len)
+ to_pad = self.max_length - kl_input_ids.size(1)
+ kl_input_ids = F.pad(kl_input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
+ kl_loss_mask = F.pad(kl_loss_mask, (0, to_pad), value=0)
+ kl_attention_mask = kl_input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
+ data_dict = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "loss_mask": loss_mask,
+ "label": torch.stack(label),
+ "kl_input_ids": kl_input_ids,
+ "kl_attention_mask": kl_attention_mask,
+ "kl_loss_mask": kl_loss_mask,
+ }
+ return data_dict
+
+
class StatefulDistributedSampler(DistributedSampler):
def __init__(
self,
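
The shape contract of `DataCollatorForKTODataset` is easiest to see on a toy batch. A hypothetical smoke test; the tokenizer choice and token ids are illustrative:

```python
from transformers import AutoTokenizer
from coati.dataset import DataCollatorForKTODataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer with a valid pad token works
tokenizer.pad_token = tokenizer.eos_token
collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=16)

batch = collator([
    {"prompt": [1, 2, 3], "completion": [4, 5], "label": True},
    {"prompt": [6, 7], "completion": [8, 9, 10], "label": False},
])
print(batch["input_ids"].shape)     # torch.Size([2, 16]): prompt + completion, right-padded
print(batch["kl_input_ids"].shape)  # torch.Size([2, 16]): prompts paired with swapped completions
print(batch["label"])               # tensor([ True, False])
```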
diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
index 27addcb0d057..4f890ffc9aa8 100755
--- a/applications/ColossalChat/coati/dataset/tokenization_utils.py
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -23,11 +23,10 @@
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
-def supervised_tokenize_sft(
+def tokenize_sft(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
- ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
@@ -39,54 +38,41 @@ def supervised_tokenize_sft(
Args:
data_point: the data point of the following format
- {"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
+ {"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
        tokenizer: the tokenizer to use
        conversation_template: the conversation template to apply
        ignore_index: the ignore index used when calculating loss during training
max_length: the maximum context length
"""
- if ignore_index is None:
- ignore_index = IGNORE_INDEX
+ ignore_index = IGNORE_INDEX
messages = data_point["messages"]
template = deepcopy(conversation_template)
- template.messages = []
-
- for mess in messages:
- from_str = mess["from"]
- if from_str.lower() == "human":
- from_str = "user"
- elif from_str.lower() == "assistant":
- from_str = "assistant"
- else:
- raise ValueError(f"Unsupported role {from_str.lower()}")
- template.append_message(from_str, mess["content"])
+ if messages[0]["from"] == "system":
+ template.system_message = str(messages[0]["content"])
+ messages.pop(0)
+ template.messages = []
+ for idx, mess in enumerate(messages):
+ if mess["from"] != template.roles[idx % 2]:
+ raise ValueError(
+                f"Messages must alternate between user and assistant and start with a "
+                f"user message. Got the following data:\n{messages}"
+ )
+ template.append_message(mess["from"], mess["content"])
if len(template.messages) % 2 != 0:
+        # Force the sequence to end with an assistant response
template.messages = template.messages[0:-1]
- # `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
- turns = [i for i in range(1, len(messages) // 2 + 1)]
-
- lo, hi = 0, len(turns)
- while lo < hi:
- mid = (lo + hi) // 2
- prompt = template.get_prompt(2 * turns[mid] - 1)
- chunks, require_loss = split_templated_prompt_into_chunks(
- template.messages[: 2 * turns[mid] - 1], prompt, conversation_template.end_of_assistant
- )
- tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
- if max_length - 1 < len(tokenized):
- hi = mid
- else:
- lo = mid + 1
- target_turn_index = lo
-
- # The tokenized length for first turn already exceeds `max_length - 1`.
- if target_turn_index - 1 < 0:
- warnings.warn("The tokenized length for first turn already exceeds `max_length - 1`.")
+    # Tokenize and compute labels, masking positions of non-assistant lines with -100 (ignore_index)
+ prompt = template.get_prompt()
+ chunks, require_loss = split_templated_prompt_into_chunks(
+ template.messages, prompt, conversation_template.end_of_assistant
+ )
+ tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=max_length)
+ if tokenized is None:
return dict(
input_ids=None,
labels=None,
@@ -96,45 +82,18 @@ def supervised_tokenize_sft(
seq_category=None,
)
- target_turn = turns[target_turn_index - 1]
- prompt = template.get_prompt(2 * target_turn)
- chunks, require_loss = split_templated_prompt_into_chunks(
- template.messages[: 2 * target_turn], prompt, conversation_template.end_of_assistant
- )
- tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
-
labels = [ignore_index] * len(tokenized)
for start, end in zip(starts, ends):
- if end == len(tokenized):
- tokenized = tokenized + [tokenizer.eos_token_id]
- labels = labels + [ignore_index]
labels[start:end] = tokenized[start:end]
- # truncate the sequence at the last token that requires loss calculation
- to_truncate_len = 0
- for i in range(len(tokenized) - 1, -1, -1):
- if labels[i] == ignore_index:
- to_truncate_len += 1
- else:
- break
- to_truncate_len = max(len(tokenized) - max_length, to_truncate_len)
- tokenized = tokenized[: len(tokenized) - to_truncate_len]
- labels = labels[: len(labels) - to_truncate_len]
-
if tokenizer.bos_token_id is not None:
+        # Prepend the bos token if the tokenized sequence does not already start with one
if tokenized[0] != tokenizer.bos_token_id:
+ # Some chat templates already include bos token
tokenized = [tokenizer.bos_token_id] + tokenized
- labels = [ignore_index] + labels
+ labels = [-100] + labels
- if tokenizer.eos_token_id is not None:
- # Force to add eos token at the end of the tokenized sequence
- if tokenized[-1] != tokenizer.eos_token_id:
- tokenized = tokenized + [tokenizer.eos_token_id]
- labels = labels + [tokenizer.eos_token_id]
- else:
- labels[-1] = tokenizer.eos_token_id
-
- # For some model without bos/eos may raise the following errors
+ # log decoded inputs and labels for debugging
inputs_decode = tokenizer.decode(tokenized)
start = 0
end = 0
@@ -171,11 +130,10 @@ def supervised_tokenize_sft(
)
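
To make the stricter input contract of the rewritten tokenization functions concrete, here is an illustrative data point (not taken from the repo's datasets) that the new validation accepts:

```python
# Roles must strictly alternate user/assistant, optionally preceded by a single
# system message; "human" is no longer silently remapped to "user".
data_point = {
    "messages": [
        {"from": "system", "content": "You are a helpful assistant."},
        {"from": "user", "content": "what are some pranks with a pen i can do?"},
        {"from": "assistant", "content": "Are you looking for practical joke ideas?"},
    ]
}
# tokenize_sft(data_point, tokenizer, conversation_template) returns input_ids plus
# labels in which non-assistant positions are masked with IGNORE_INDEX (-100).
```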
-def tokenize_prompt_dataset(
+def tokenize_prompt(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
- ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
@@ -183,48 +141,42 @@ def tokenize_prompt_dataset(
"Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line start]"
Args:
data_point: the data point of the following format
- {"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
+ {"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
tokenizer: the tokenizer to use
conversation_template: the conversation template to apply
-        ignore_index: the ignore index when calculate loss during training
max_length: the maximum context length
"""
- if ignore_index is None:
- ignore_index = IGNORE_INDEX
messages = data_point["messages"]
template = deepcopy(conversation_template)
template.messages = []
- for mess in messages:
- from_str = mess["from"]
- if from_str.lower() == "human":
- from_str = "user"
- elif from_str.lower() == "assistant":
- from_str = "assistant"
- else:
- raise ValueError(f"Unsupported role {from_str.lower()}")
+ if messages[0]["from"] == "system":
+ template.system_message = str(messages[0]["content"])
+ messages.pop(0)
- template.append_message(from_str, mess["content"])
+ for idx, mess in enumerate(messages):
+ if mess["from"] != template.roles[idx % 2]:
+ raise ValueError(
+                f"Messages should alternate between user and assistant and start with a line from the user. Got the following data:\n{messages}"
+ )
+ template.append_message(mess["from"], mess["content"])
-    # `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
- target_turn = len(template.messages)
- if target_turn % 2 != 1:
+ if len(template.messages) % 2 != 1:
# exclude the answer if provided. keep only the prompt
- target_turn = target_turn - 1
+ template.messages = template.messages[:-1]
# Prepare data
- prompt = template.get_prompt(target_turn, add_generation_prompt=True)
- chunks, require_loss = split_templated_prompt_into_chunks(
- template.messages[:target_turn], prompt, conversation_template.end_of_assistant
- )
- tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
+ prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True)
+ tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
+
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized
- # Skip overlength data
- if max_length - 1 < len(tokenized):
+ if len(tokenized) > max_length:
return dict(
input_ids=None,
inputs_decode=None,
@@ -235,47 +187,32 @@ def tokenize_prompt_dataset(
# `inputs_decode` can be used to check whether the tokenization is correct.
return dict(
input_ids=tokenized,
- inputs_decode=tokenizer.decode(tokenized),
+ inputs_decode=prompt,
seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None",
)
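
A small sketch of the turn-trimming rule in `tokenize_prompt`, using made-up messages: an even message count means a final assistant answer is present, so it is dropped and generation starts after the last user line.

```python
messages = [
    {"from": "user", "content": "Hi"},
    {"from": "assistant", "content": "Hello!"},
    {"from": "user", "content": "Tell me a joke"},
    {"from": "assistant", "content": "Why did..."},  # provided answer, excluded
]
if len(messages) % 2 != 1:  # even count: an answer is present
    messages = messages[:-1]
assert messages[-1]["from"] == "user"  # the prompt now ends with the user line
```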
-def apply_rlhf_data_format(
- template: Conversation, tokenizer: Any, context_len: int, mask_out_target_assistant_line_end=False
-):
+def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
target_turn = int(len(template.messages) / 2)
prompt = template.get_prompt(target_turn * 2)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[: 2 * target_turn], prompt, template.end_of_assistant
)
- tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
- loss_mask = [0] * len(tokenized)
- mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
- if mask_token is None:
- mask_token = 1 # If the tokenizer doesn't have eos_token or pad_token: Qwen
+ # no truncation applied
+ tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=None)
+ loss_mask = [0] * len(tokenized)
label_decode = []
- for start, end in zip(starts[-1:], ends[-1:]):
- # only the last round (chosen/rejected) counts
- if end == len(tokenized):
- tokenized = tokenized + [tokenizer.eos_token_id]
- loss_mask = loss_mask + [1]
- loss_mask[start:end] = [1] * len(loss_mask[start:end])
- label_decode.append(tokenizer.decode(tokenized[start:end], skip_special_tokens=False))
+ # only the last round (chosen/rejected) is used to calculate loss
+ for i in range(starts[-1], ends[-1]):
+ loss_mask[i] = 1
+ label_decode.append(tokenizer.decode(tokenized[starts[-1] : ends[-1]], skip_special_tokens=False))
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized
loss_mask = [0] + loss_mask
- if tokenizer.eos_token_id is not None:
- # Force to add eos token at the end of the tokenized sequence
- if tokenized[-1] != tokenizer.eos_token_id:
- tokenized = tokenized + [tokenizer.eos_token_id]
- loss_mask = loss_mask + [1]
- else:
- loss_mask[-1] = 1
-
return {"input_ids": tokenized, "loss_mask": loss_mask, "label_decode": label_decode}
@@ -283,39 +220,33 @@ def tokenize_rlhf(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
- ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
A tokenization function to tokenize an original pretraining data point as follows:
- {"context": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
+ {"context": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
"chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}}
"""
- if ignore_index is None:
- ignore_index = IGNORE_INDEX
context = data_point["context"]
template = deepcopy(conversation_template)
template.clear()
- for mess in context:
- from_str = mess["from"]
- if from_str.lower() == "human":
- from_str = "user"
- elif from_str.lower() == "assistant":
- from_str = "assistant"
- else:
- raise ValueError(f"Unsupported role {from_str.lower()}")
+ if context[0]["from"] == "system":
+ template.system_message = str(context[0]["content"])
+ context.pop(0)
- if len(template.messages) > 0 and from_str == template.messages[-1]["role"]:
- # Concate adjacent message from the same role
- template.messages[-1]["content"] = str(template.messages[-1]["content"] + " " + mess["content"])
- else:
- template.append_message(from_str, mess["content"])
+ for idx, mess in enumerate(context):
+ if mess["from"] != template.roles[idx % 2]:
+ raise ValueError(
+                f"Messages should alternate between user and assistant and start with a "
+                f"line from the user. Got the following data:\n{context}"
+ )
+ template.append_message(mess["from"], mess["content"])
if len(template.messages) % 2 != 1:
warnings.warn(
- "Please make sure leading context starts and ends with a line from human\nLeading context: "
+ "Please make sure leading context starts and ends with a line from user\nLeading context: "
+ str(template.messages)
)
return dict(
@@ -326,31 +257,27 @@ def tokenize_rlhf(
rejected_loss_mask=None,
rejected_label_decode=None,
)
- round_of_context = int((len(template.messages) - 1) / 2)
- assert context[-1]["from"].lower() == "human", "The last message in context should be from human."
+ assert context[-1]["from"].lower() == template.roles[0], "The last message in context should be from user."
chosen = deepcopy(template)
rejected = deepcopy(template)
-
- for round in range(len(data_point["chosen"])):
- from_str = data_point["chosen"][round]["from"]
- if from_str.lower() == "human":
- from_str = "user"
- elif from_str.lower() == "assistant":
- from_str = "assistant"
- else:
- raise ValueError(f"Unsupported role {from_str.lower()}")
- chosen.append_message(from_str, data_point["chosen"][round]["content"])
-
- for round in range(len(data_point["rejected"])):
- from_str = data_point["rejected"][round]["from"]
- if from_str.lower() == "human":
- from_str = "user"
- elif from_str.lower() == "assistant":
- from_str = "assistant"
- else:
- raise ValueError(f"Unsupported role {from_str.lower()}")
- rejected.append_message(from_str, data_point["rejected"][round]["content"])
+ chosen_continuation = data_point["chosen"]
+ rejected_continuation = data_point["rejected"]
+ for round in range(len(chosen_continuation)):
+ if chosen_continuation[round]["from"] != template.roles[(round + 1) % 2]:
+ raise ValueError(
+                f"Messages should alternate between user and assistant and start with a "
+                f"line from the user. Got the following data:\n{chosen_continuation}"
+ )
+ chosen.append_message(chosen_continuation[round]["from"], chosen_continuation[round]["content"])
+
+ for round in range(len(rejected_continuation)):
+ if rejected_continuation[round]["from"] != template.roles[(round + 1) % 2]:
+ raise ValueError(
+                f"Messages should alternate between user and assistant and start with a "
+                f"line from the user. Got the following data:\n{rejected_continuation}"
+ )
+ rejected.append_message(rejected_continuation[round]["from"], rejected_continuation[round]["content"])
(
chosen_input_ids,
@@ -361,16 +288,14 @@ def tokenize_rlhf(
rejected_label_decode,
) = (None, None, None, None, None, None)
- chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
+ chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer)
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
chosen_data_packed["input_ids"],
chosen_data_packed["loss_mask"],
chosen_data_packed["label_decode"],
)
- rejected_data_packed = apply_rlhf_data_format(
- rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
- )
+ rejected_data_packed = apply_rlhf_data_format(rejected, tokenizer)
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
rejected_data_packed["input_ids"],
rejected_data_packed["loss_mask"],
@@ -387,7 +312,7 @@ def tokenize_rlhf(
rejected_label_decode=None,
)
# Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
- if chosen_loss_mask[1:].count(1) == 0 or rejected_loss_mask[1:].count(1) == 0:
+ if chosen_loss_mask.count(1) == 0 or rejected_loss_mask.count(1) == 0:
return dict(
chosen_input_ids=None,
chosen_loss_mask=None,
@@ -405,3 +330,66 @@ def tokenize_rlhf(
"rejected_loss_mask": rejected_loss_mask,
"rejected_label_decode": rejected_label_decode,
}
+
+
+def tokenize_kto(
+ data_point: Dict[str, str],
+ tokenizer: PreTrainedTokenizer,
+ conversation_template: Conversation = None,
+ max_length: int = 4096,
+) -> Dict[str, Union[int, str, List[int]]]:
+ """
+    Tokenize a dataset for KTO training.
+    The raw input data is a conversation in the following format:
+    {
+        "prompt": [{"from": "user", "content": "xxx"}...],
+        "completion": {"from": "assistant", "content": "xxx"},
+        "label": true/false
+    }
+    It returns three main fields:
+    the context, which contains the query and the assistant prefix,
+    the completion, which contains only the assistant's answer,
+    and a binary label, which indicates whether the sample is preferred,
+    plus decoded versions of the full prompt and the completion for debugging.
+ """
+ prompt = data_point["prompt"]
+ completion = data_point["completion"]
+ template = deepcopy(conversation_template)
+ template.clear()
+
+ if prompt[0]["from"] == "system":
+ template.system_message = str(prompt[0]["content"])
+ prompt.pop(0)
+
+ if prompt[0].get("from", None) != "user":
+ raise ValueError("conversation should start with user")
+ if completion.get("from", None) != "assistant":
+ raise ValueError("conversation should end with assistant")
+
+ for mess in prompt:
+ if mess.get("from", None) == "user":
+ template.append_message("user", mess["content"])
+ elif mess.get("from", None) == "assistant":
+ template.append_message("assistant", mess["content"])
+ else:
+ raise ValueError(f"Unsupported role {mess.get('from', None)}")
+ generation_prompt = template.get_prompt(len(prompt), add_generation_prompt=True)
+ template.append_message("assistant", completion["content"])
+ full_prompt = template.get_prompt(len(prompt) + 1, add_generation_prompt=False)
+ tokenized_full_prompt = tokenizer(full_prompt, add_special_tokens=False)["input_ids"]
+ if len(tokenized_full_prompt) + 1 > max_length:
+ return dict(prompt=None, completion=None, label=None, input_id_decode=None, completion_decode=None)
+ tokenized_generation_prompt = tokenizer(generation_prompt, add_special_tokens=False)["input_ids"]
+ tokenized_completion = tokenized_full_prompt[len(tokenized_generation_prompt) :]
+ tokenized_completion = deepcopy(tokenized_completion)
+ if tokenizer.bos_token_id is not None and tokenized_generation_prompt[0] != tokenizer.bos_token_id:
+ tokenized_generation_prompt = [tokenizer.bos_token_id] + tokenized_generation_prompt
+ decoded_full_prompt = tokenizer.decode(tokenized_full_prompt, skip_special_tokens=False)
+ decoded_completion = tokenizer.decode(tokenized_completion, skip_special_tokens=False)
+
+ return {
+ "prompt": tokenized_generation_prompt,
+ "completion": tokenized_completion,
+ "label": data_point["label"],
+ "input_id_decode": decoded_full_prompt,
+ "completion_decode": decoded_completion,
+ }
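
A hypothetical record in the shape `tokenize_kto` expects. The function tokenizes prompt and completion in a single pass and slices the completion ids off the tail, so the two pieces concatenate back to the full sequence (up to an optionally prepended bos token).

```python
data_point = {
    "prompt": [{"from": "user", "content": "What is 2 + 2?"}],
    "completion": {"from": "assistant", "content": "4"},
    "label": True,  # True = desirable response, False = undesirable
}
# result = tokenize_kto(data_point, tokenizer, conversation_template)
# result["prompt"] + result["completion"] reproduces the fully tokenized
# prompt, which keeps the KL-term batches consistent with the policy batches.
```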
diff --git a/applications/ColossalChat/coati/dataset/utils.py b/applications/ColossalChat/coati/dataset/utils.py
index f41a4d7724da..42c3191db3a5 100755
--- a/applications/ColossalChat/coati/dataset/utils.py
+++ b/applications/ColossalChat/coati/dataset/utils.py
@@ -88,7 +88,13 @@ def find_first_occurrence_subsequence(seq: torch.Tensor, subseq: torch.Tensor, s
return -1
-def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], require_loss: List[bool]):
+def tokenize_and_concatenate(
+ tokenizer: PreTrainedTokenizer,
+ text: List[str],
+ require_loss: List[bool],
+ max_length: int,
+ discard_non_loss_tokens_at_tail: bool = True,
+):
"""
Tokenizes a list of texts using the provided tokenizer and concatenates the tokenized outputs.
@@ -96,6 +102,13 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenization.
text (List[str]): The list of texts to tokenize.
require_loss (List[bool]): A list of boolean values indicating whether each text requires loss calculation.
+ max_length: used to truncate the input ids
+ discard_non_loss_tokens_at_tail: whether to discard the non-loss tokens at the tail
+
+    If the first round already exceeds max length:
+    - if the user query alone exceeds max length, discard the sample
+    - if only the first assistant response exceeds max length, truncate the response to fit max length
+    Otherwise, keep the first several complete rounds of the conversation until max length is reached.
Returns:
Tuple[List[int], List[int], List[int]]: A tuple containing the concatenated tokenized input ids,
@@ -106,10 +119,18 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
loss_ends = []
for s, r in zip(text, require_loss):
tokenized = tokenizer(s, add_special_tokens=False)["input_ids"]
- if r:
- loss_starts.append(len(input_ids))
- loss_ends.append(len(input_ids) + len(tokenized))
- input_ids.extend(tokenized)
+ if not max_length or len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
+ if r:
+ loss_starts.append(len(input_ids))
+ loss_ends.append(len(input_ids) + len(tokenized))
+ input_ids.extend(tokenized)
+ if max_length and loss_starts[0] >= max_length:
+ return None, None, None
+ if discard_non_loss_tokens_at_tail:
+ input_ids = input_ids[: loss_ends[-1]]
+ if max_length:
+ input_ids = input_ids[:max_length]
+ loss_ends[-1] = min(max_length, loss_ends[-1])
return input_ids, loss_starts, loss_ends
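
A worked toy example of this truncation policy, with pre-tokenized chunks standing in for real tokenizer output: whole chunks are appended while they fit, the first loss chunk is kept even if it overflows, and the tail is then clipped to `max_length` and to the last loss position.

```python
chunks = [[1, 2], [3, 4, 5], [6, 7], [8, 9]]  # stand-ins for tokenized chunks
require_loss = [False, True, False, True]
max_length = 6

input_ids, loss_starts, loss_ends = [], [], []
for tokenized, r in zip(chunks, require_loss):
    if len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
        if r:
            loss_starts.append(len(input_ids))
            loss_ends.append(len(input_ids) + len(tokenized))
        input_ids.extend(tokenized)
input_ids = input_ids[: loss_ends[-1]]  # discard non-loss tokens at the tail
input_ids = input_ids[:max_length]
loss_ends[-1] = min(max_length, loss_ends[-1])
assert input_ids == [1, 2, 3, 4, 5] and loss_starts == [2] and loss_ends == [5]
```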
@@ -125,6 +146,12 @@ def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: s
content_length = (
prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
)
+        # If the rendered content starts with a leading space, keep that space in the loss
+        # calculation, e.g., "Assistant: I am saying..."
+        # If it doesn't start with a leading space, keep only the content itself in the loss
+        # calculation, e.g.,
+        # "Assistant:" followed by '\n' as the line breaker, then
+        # "I am saying..."
if prompt[first_occur - 1] != " ":
chunks.append(prompt[start_idx:first_occur])
chunks.append(prompt[first_occur : first_occur + content_length])
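
A short sketch of the leading-space rule, with an illustrative prompt: when the assistant content is preceded by a space in the rendered prompt, the split keeps that space inside the loss chunk so its token is trained on as well.

```python
prompt = "Assistant: I am saying..."
content = "I am saying..."
first_occur = prompt.find(content)
if prompt[first_occur - 1] != " ":
    pieces = [prompt[:first_occur], content]
else:
    pieces = [prompt[: first_occur - 1], " " + content]
assert pieces == ["Assistant:", " I am saying..."]  # the space stays with the loss chunk
```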
diff --git a/applications/ColossalChat/coati/models/__init__.py b/applications/ColossalChat/coati/models/__init__.py
index 14073207f150..fba0949e3fb8 100755
--- a/applications/ColossalChat/coati/models/__init__.py
+++ b/applications/ColossalChat/coati/models/__init__.py
@@ -1,8 +1,8 @@
from .base import BaseModel
from .critic import Critic
from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
-from .lora import convert_to_lora_module
-from .loss import DpoLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
+from .lora import LoraConfig, convert_to_lora_module, lora_manager
+from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .reward_model import RewardModel
from .utils import disable_dropout
@@ -14,9 +14,11 @@
"ValueLoss",
"LogSigLoss",
"LogExpLoss",
+ "LoraConfig",
+ "lora_manager",
"convert_to_lora_module",
"DpoLoss",
- "generate",
+    "KTOLoss",
+    "generate",
"generate_streaming",
"disable_dropout",
"update_model_kwargs_fn",
diff --git a/applications/ColossalChat/coati/models/base.py b/applications/ColossalChat/coati/models/base.py
index fcea9414b430..cfdffdf289bd 100755
--- a/applications/ColossalChat/coati/models/base.py
+++ b/applications/ColossalChat/coati/models/base.py
@@ -42,7 +42,6 @@ def __init__(self, pretrained: str = None, config: Optional[PretrainedConfig] =
out = self.model(dummy_input)
self.last_hidden_state_size = out.last_hidden_state.shape[-1]
self.model = self.model.cpu()
- # print("self.last_hidden_state_size: ",self.last_hidden_state_size)
def resize_token_embeddings(self, *args, **kwargs):
"""
diff --git a/applications/ColossalChat/coati/models/lora.py b/applications/ColossalChat/coati/models/lora.py
index 9553b00ff2a8..aa5f6ecf8608 100755
--- a/applications/ColossalChat/coati/models/lora.py
+++ b/applications/ColossalChat/coati/models/lora.py
@@ -5,10 +5,11 @@
import dataclasses
import math
import warnings
-from typing import Optional
+from typing import List, Optional, Union
import loralib as lora
import torch
+import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
@@ -18,148 +19,349 @@
@dataclasses.dataclass
-class LoRAManager:
- merge_weights: bool = False
+class LoraManager:
+ able_to_merge: bool = True
-LORA_MANAGER = LoRAManager()
+lora_manager = LoraManager()
-class LoraLinear(lora.LoRALayer, nn.Module):
+@dataclasses.dataclass
+class LoraConfig:
+ r: int = 0
+ lora_alpha: int = 32
+ linear_lora_dropout: float = 0.1
+ embedding_lora_dropout: float = 0.0
+ lora_train_bias: str = "none"
+ lora_initialization_method: str = "kaiming_uniform"
+ target_modules: List = None
+
+ @classmethod
+ def from_file(cls, config_file: str):
+ import json
+
+ with open(config_file, "r") as f:
+ config = json.load(f)
+ return cls(**config)
+
+
+class LoraBase(lora.LoRALayer, nn.Module):
+ def __init__(
+ self,
+ r: int = 0,
+ lora_alpha: int = 32,
+ lora_dropout: float = 0.1,
+ lora_initialization_method: str = "kaiming_uniform",
+ ):
+ nn.Module.__init__(self)
+ lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
+ self.r = r
+ self.lora_alpha = lora_alpha
+ self.lora_dropout = nn.Dropout(lora_dropout)
+ self.merged = False
+ self.lora_initialization_method = lora_initialization_method
+ self.weight = None
+ self.bias = None
+ self.lora_A = None
+ self.lora_B = None
+
+ def reset_parameters(self):
+ if hasattr(self, "lora_A"):
+ if self.lora_initialization_method == "kaiming_uniform" or self.weight.size() != (
+ self.out_features,
+ self.in_features,
+ ):
+ # Initialize A with the default values for nn.Linear and set B to zero.
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+ elif self.lora_initialization_method == "PiSSA":
+ # PiSSA method in this paper: https://arxiv.org/abs/2404.02948
+ # Assume the SVD of the original weights is W = USV^T
+                # Initialize the frozen weight to the residual W - U[:, :r] S[:r] V^T[:r, :],
+                # which stores the less significant part of W.
+                # Only A and B are trainable; they are initialized to
+                # S[:r]^0.5 V^T[:r, :] and U[:, :r] S[:r]^0.5 respectively.
+ # self.scaling = 1.
+ # SVD
+ U, S, Vh = torch.svd_lowrank(
+ self.weight.to(torch.float32).data, self.r, niter=4
+                )  # U: [out_features, r], S: [r], V: [in_features, r]
+ # weight_backup = self.weight.clone()
+
+ # Initialize A, B
+ S = S / self.scaling
+ self.lora_B.data = (U @ torch.diag(torch.sqrt(S))).to(torch.float32).contiguous()
+ self.lora_A.data = (torch.diag(torch.sqrt(S)) @ Vh.T).to(torch.float32).contiguous()
+ # Initialize weight
+ # To reduce floating point error, we use residual instead of directly using U[:, :self.r] @ S[:self.r] @ Vh[:self.r, :]
+ self.weight.data = (
+ ((self.weight - self.scaling * self.lora_B @ self.lora_A)).contiguous().to(self.weight.dtype)
+ )
+ self.lora_A.requires_grad = True
+ self.lora_B.requires_grad = True
+ else:
+ raise ValueError(f"Unknown LoRA initialization method {self.lora_initialization_method}")
+
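
A small numeric sketch of the PiSSA identity used above (shapes and `scaling` are chosen for illustration): the frozen residual plus the scaled low-rank product reconstructs the original weight.

```python
import torch

out_features, in_features, r, scaling = 8, 6, 2, 1.0
W = torch.randn(out_features, in_features)
U, S, V = torch.svd_lowrank(W, q=r, niter=4)  # U: [8, r], S: [r], V: [6, r]
lora_B = U @ torch.diag(S.sqrt())             # [out_features, r]
lora_A = torch.diag(S.sqrt()) @ V.T           # [r, in_features]
W_res = W - scaling * lora_B @ lora_A         # frozen residual weight
assert torch.allclose(W_res + scaling * lora_B @ lora_A, W, atol=1e-5)
```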
+ def train(self, mode: bool = True):
+ """
+ This function runs when model.train() is invoked. It is used to prepare the linear layer for training
+ """
+
+ self.training = mode
+ if mode and self.merged:
+            warnings.warn("Invoking module.train() will unmerge LoRA weights.")
+ raise NotImplementedError("LoRA unmerge is not tested.")
+ elif not mode and not self.merged and lora_manager.able_to_merge:
+            warnings.warn("Invoking module.eval() will merge LoRA weights.")
+ # Merge the weights and mark it
+ if self.r > 0:
+ self.weight.data += self.lora_B @ self.lora_A * self.scaling
+ delattr(self, "lora_A")
+ delattr(self, "lora_B")
+ self.merged = True
+
+ return self
+
+
+class LoraLinear(LoraBase):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
def __init__(
self,
weight: nn.Parameter,
- bias: Optional[nn.Parameter],
+ bias: Union[nn.Parameter, bool],
r: int = 0,
- lora_alpha: int = 1,
+ lora_alpha: int = 32,
lora_dropout: float = 0.0,
- # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
- fan_in_fan_out: bool = False,
+ lora_initialization_method: str = "kaiming_uniform",
):
- nn.Module.__init__(self)
- lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
+ super().__init__(
+ r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method
+ )
self.weight = weight
self.bias = bias
+ if bias is True:
+ self.bias = nn.Parameter(torch.zeros(weight.shape[0]))
+ if bias is not None:
+ self.bias.requires_grad = True
out_features, in_features = weight.shape
self.in_features = in_features
self.out_features = out_features
-
- self.fan_in_fan_out = fan_in_fan_out
+ assert lora_initialization_method in ["kaiming_uniform", "PiSSA"]
+ self.lora_initialization_method = lora_initialization_method
# Actual trainable parameters
if r > 0:
- self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
- self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
+ self.lora_A = nn.Parameter(torch.randn((r, in_features)))
+ self.lora_B = nn.Parameter(torch.randn((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
- if fan_in_fan_out:
- self.weight.data = self.weight.data.T
- def reset_parameters(self):
- if hasattr(self, "lora_A"):
- # Initialize A with the default values for nn.Linear and set B to zero.
- nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
- nn.init.zeros_(self.lora_B)
+ def forward(self, x: torch.Tensor):
+ if self.r > 0 and not self.merged:
+ result = F.linear(x, self.weight, bias=self.bias)
+ result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
+ return result
+ else:
+ return F.linear(x, self.weight, bias=self.bias)
+
+
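
To make the unmerged forward concrete, a sketch with random tensors and assumed shapes: adding the scaled low-rank update at the output is equivalent to merging `scaling * B @ A` into the weight, which is exactly what `eval()` does later.

```python
import torch
import torch.nn.functional as F

x = torch.randn(3, 6)
weight = torch.randn(8, 6)
lora_A, lora_B = torch.randn(2, 6), torch.randn(8, 2)
scaling = 32 / 2  # lora_alpha / r

unmerged = F.linear(x, weight) + (x @ lora_A.t() @ lora_B.t()) * scaling
merged = F.linear(x, weight + scaling * lora_B @ lora_A)
assert torch.allclose(unmerged, merged, atol=1e-4)
```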
+class LoraEmbedding(LoraBase):
+    """Replace in-place ops with out-of-place ops to fit gemini. Convert a torch.nn.Embedding to LoraEmbedding."""
+
+ def __init__(
+ self,
+ weight: nn.Parameter,
+ r: int = 0,
+ lora_alpha: int = 32,
+ lora_dropout: float = 0.1,
+ num_embeddings: int = None,
+ embedding_dim: int = None,
+ padding_idx: Optional[int] = None,
+ max_norm: Optional[float] = None,
+ norm_type: float = 2.0,
+ scale_grad_by_freq: bool = False,
+ sparse: bool = False,
+ lora_initialization_method: str = "kaiming_uniform",
+ ):
+ super().__init__(
+ r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method
+ )
+ self.padding_idx = padding_idx
+ self.max_norm = max_norm
+ self.norm_type = norm_type
+ self.scale_grad_by_freq = scale_grad_by_freq
+ self.sparse = sparse
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+
+ self.weight = weight
+
+ in_features, out_features = num_embeddings, embedding_dim
+ self.in_features = in_features
+ self.out_features = out_features
+ assert lora_initialization_method in ["kaiming_uniform", "PiSSA"]
+ self.lora_initialization_method = lora_initialization_method
+
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(torch.randn((r, in_features)))
+ self.lora_B = nn.Parameter(torch.randn((out_features, r)))
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+
+ # reset parameters
+ nn.init.zeros_(self.lora_A)
+ nn.init.normal_(self.lora_B)
+
+ def _embed(self, x: torch.Tensor, weight) -> torch.Tensor:
+ return F.embedding(
+ x,
+ weight,
+ padding_idx=self.padding_idx,
+ max_norm=self.max_norm,
+ norm_type=self.norm_type,
+ scale_grad_by_freq=self.scale_grad_by_freq,
+ sparse=self.sparse,
+ )
+
+ def forward(self, x: torch.Tensor):
+ base_embedding = self._embed(x, self.weight)
+ # base_embedding.requires_grad = True # force the embedding layer to be trainable for gradient checkpointing
+ if self.r > 0 and not self.merged:
+ lora_A_embedding = self._embed(x, self.lora_A.t())
+ embedding = base_embedding + (lora_A_embedding @ self.lora_B.t()) * self.scaling
+ return embedding
+ else:
+ return base_embedding
def train(self, mode: bool = True):
"""
This function runs when model.train() is invoked. It is used to prepare the linear layer for training
"""
- def T(w):
- return w.T if self.fan_in_fan_out else w
-
self.training = mode
- if LORA_MANAGER.merge_weights:
- if mode and self.merged:
- warnings.warn("Invoke module.train() would unmerge LoRA weights.")
- raise NotImplementedError("LoRA unmerge is not tested.")
- # Make sure that the weights are not merged
- if self.r > 0:
- if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
- # FIXME(csric): temporary fix
- self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
- self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
- self.reset_parameters()
- else:
- self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
- self.merged = False
- elif not mode and not self.merged:
- warnings.warn("Invoke module.eval() would merge LoRA weights.")
- # Merge the weights and mark it
- if self.r > 0:
- self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
- delattr(self, "lora_A")
- delattr(self, "lora_B")
- self.merged = True
+ if mode and self.merged:
+            warnings.warn("Invoking module.train() will unmerge LoRA weights.")
+ raise NotImplementedError("LoRA unmerge is not tested.")
+ elif not mode and not self.merged and lora_manager.able_to_merge:
+            warnings.warn("Invoking module.eval() will merge LoRA weights.")
+ # Merge the weights and mark it
+ if self.r > 0:
+ self.weight.data += self.lora_A.t() @ self.lora_B.t() * self.scaling
+ delattr(self, "lora_A")
+ delattr(self, "lora_B")
+ self.merged = True
return self
- def forward(self, x: torch.Tensor):
- def T(w):
- return w.T if self.fan_in_fan_out else w
-
- if self.r > 0 and not self.merged:
- result = F.linear(x, T(self.weight), bias=self.bias)
- if self.r > 0:
- result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
- return result
- else:
- return F.linear(x, T(self.weight), bias=self.bias)
-
-def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
+def _lora_linear_wrapper(linear: nn.Linear, lora_config: LoraConfig) -> LoraLinear:
"""
Wraps a linear layer with LoRA functionality.
Args:
linear (nn.Linear): The linear layer to be wrapped.
-        lora_rank (int): The rank of the LoRA decomposition.
+        lora_config (LoraConfig): The LoRA configuration; its fields set the rank, the bias
+            training mode ("none", "all", "lora"), and the initialization method ("kaiming_uniform" or "PiSSA").
Returns:
LoraLinear: The wrapped linear layer with LoRA functionality.
"""
assert (
- lora_rank <= linear.in_features
- ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
- lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
+ lora_config.r <= linear.in_features
+ ), f"LoRA rank ({lora_config.r}) must be less than or equal to in features ({linear.in_features})"
+ bias = None
+ if lora_config.lora_train_bias in ["all", "lora"]:
+ bias = linear.bias
+ if bias is None:
+ bias = True
+ lora_linear = LoraLinear(
+ linear.weight, bias, r=lora_config.r, lora_initialization_method=lora_config.lora_initialization_method
+ )
return lora_linear
-def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
+def _convert_to_lora_recursively(module: nn.Module, parent_name: str, lora_config: LoraConfig) -> None:
"""
Recursively converts the given module and its children to LoRA (Low-Rank Approximation) form.
Args:
module (nn.Module): The module to convert to LoRA form.
-        lora_rank (int): The rank of the LoRA approximation.
+        parent_name (str): The name of the parent module, used when logging converted layers.
+        lora_config (LoraConfig): The LoRA configuration; its fields set the rank, the bias
+            training mode ("none", "all", "lora"), and the initialization method ("kaiming_uniform" or "PiSSA").
Returns:
None
"""
for name, child in module.named_children():
if isinstance(child, nn.Linear):
- setattr(module, name, _lora_linear_wrapper(child, lora_rank))
+ if lora_config.target_modules is None or any(
+ [name in target_module for target_module in lora_config.target_modules]
+ ):
+ if dist.is_initialized() and dist.get_rank() == 0:
+ logger.info(f"Converting {parent_name}.{name} to LoRA")
+ setattr(module, name, _lora_linear_wrapper(child, lora_config))
+ elif isinstance(child, nn.Embedding):
+ if lora_config.target_modules is None or any(
+ [name in target_module for target_module in lora_config.target_modules]
+ ):
+ if dist.is_initialized() and dist.get_rank() == 0:
+ logger.info(f"Converting {parent_name}.{name} to LoRA")
+ setattr(
+ module,
+ name,
+ LoraEmbedding(
+ child.weight,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ lora_dropout=lora_config.embedding_lora_dropout,
+ num_embeddings=child.num_embeddings,
+ embedding_dim=child.embedding_dim,
+ padding_idx=child.padding_idx,
+ max_norm=child.max_norm,
+ norm_type=child.norm_type,
+ scale_grad_by_freq=child.scale_grad_by_freq,
+ sparse=child.sparse,
+ lora_initialization_method=lora_config.lora_initialization_method,
+ ),
+ )
else:
- _convert_to_lora_recursively(child, lora_rank)
+ _convert_to_lora_recursively(child, f"{parent_name}.{name}", lora_config)
-def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
+def convert_to_lora_module(module: nn.Module, lora_config: LoraConfig) -> nn.Module:
"""Convert a torch.nn.Module to a LoRA module.
Args:
module (nn.Module): The module to convert.
-        lora_rank (int): LoRA rank.
+        lora_config (LoraConfig): The LoRA configuration; its fields set the rank, the bias
+            training mode ("none", "all", "lora"), and the initialization method ("kaiming_uniform" or "PiSSA").
Returns:
nn.Module: The converted module.
"""
- if lora_rank <= 0:
+ if lora_config.r <= 0:
return module
- _convert_to_lora_recursively(module, lora_rank)
- lora.mark_only_lora_as_trainable(module, lora_train_bias)
+    # Freeze all parameters; if lora_train_bias is "all", keep bias parameters trainable
+ total_parameter_size = 0
+ for name, p in module.named_parameters():
+ p.requires_grad = False
+ if "bias" in name and lora_config.lora_train_bias == "all":
+ p.requires_grad = True
+ total_parameter_size += p.numel()
+ _convert_to_lora_recursively(module, "", lora_config)
+ trainable_parameter_size = 0
+ for name, p in module.named_parameters():
+        if p.requires_grad:
+ trainable_parameter_size += p.numel()
+ if dist.is_initialized() and dist.get_rank() == 0:
+ logger.info(
+ f"Trainable parameter size: {trainable_parameter_size/1024/1024:.2f}M\nOriginal trainable parameter size: {total_parameter_size/1024/1024:.2f}M\nPercentage: {trainable_parameter_size/total_parameter_size*100:.2f}%"
+ )
return module
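
A hypothetical usage sketch of the new config-driven API; the tiny `nn.Sequential` stands in for a real transformer and is not from the patch.

```python
import torch.nn as nn
from coati.models import LoraConfig, convert_to_lora_module

# target_modules=None converts every nn.Linear / nn.Embedding it finds
config = LoraConfig(r=8, lora_alpha=32, lora_train_bias="none")
model = nn.Sequential(nn.Linear(16, 16))  # stand-in for a real transformer
model = convert_to_lora_module(model, config)
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
# The pretrained weight is frozen; the injected lora_A / lora_B factors
# (and, in this implementation, the layer bias) remain trainable.
```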
diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py
index e6872276d37e..bd0bbd36b9bc 100755
--- a/applications/ColossalChat/coati/models/loss.py
+++ b/applications/ColossalChat/coati/models/loss.py
@@ -5,6 +5,7 @@
from typing import Optional, Tuple
import torch
+import torch.distributed as dist
import torch.nn as nn
from .utils import masked_mean
@@ -45,7 +46,10 @@ def forward(
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
skip = False
- ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
+ if action_mask is None:
+ ratio_ = (log_probs - old_log_probs).exp()
+ else:
+ ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
# note that if dropout is disabled (recommended), ratio will always be 1.
if ratio_.mean() > self.skip_threshold:
@@ -55,7 +59,10 @@ def forward(
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
loss = -torch.min(surr1, surr2)
- loss = masked_mean(loss, action_mask)
+ if action_mask is not None:
+ loss = masked_mean(loss, action_mask)
+ else:
+ loss = loss.mean(dim=1)
loss = loss.mean()
return loss, skip, ratio_.max()
@@ -80,8 +87,10 @@ def forward(
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - returns) ** 2
surr2 = (values - returns) ** 2
- loss = torch.max(surr1, surr2) / torch.sum(action_mask)
- loss = torch.sum(loss * action_mask)
+ if action_mask is not None:
+ loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask)
+ else:
+ loss = torch.mean(torch.max(surr1, surr2))
return 0.5 * loss
@@ -201,7 +210,72 @@ def forward(
chosen_odds_masked = torch.sum(chosen_odds * chosen_loss_mask.float()) / torch.sum(chosen_loss_mask)
reject_odds = reject_logp - torch.log(-torch.exp(reject_logp) + 1.0001)
reject_odds_masked = torch.sum(reject_odds * reject_loss_mask.float()) / torch.sum(reject_loss_mask)
- # print("chosen_odds_masked", chosen_odds_masked[0], "reject_odds_masked", reject_odds_masked[0])
log_odds_ratio = chosen_odds_masked - reject_odds_masked
ratio = torch.log(torch.nn.functional.sigmoid(log_odds_ratio))
return ratio.to(dtype=torch.bfloat16), log_odds_ratio
+
+
+class KTOLoss(nn.Module):
+ def __init__(self, beta: float = 0.1, desirable_weight: float = 1.0, undesirable_weight: float = 1.0):
+ """
+ Args:
+ beta: The temperature parameter in the KTO paper.
+ desirable_weight: The weight for the desirable responses.
+            undesirable_weight: The weight for the undesirable responses.
+ """
+ super().__init__()
+ self.beta = beta
+ self.desirable_weight = desirable_weight
+ self.undesirable_weight = undesirable_weight
+
+ def forward(
+ self,
+ chosen_logps: torch.Tensor,
+ rejected_logps: torch.Tensor,
+ kl_logps: torch.Tensor,
+ ref_chosen_logps: torch.Tensor,
+ ref_rejected_logps: torch.Tensor,
+ ref_kl_logps: torch.Tensor,
+ ):
+ """
+ Reference:
+ https://github.com/huggingface/trl/blob/a2adfb836a90d1e37b1253ab43dace05f1241e04/trl/trainer/kto_trainer.py#L585
+
+ Compute the KTO loss for a batch of policy and reference model log probabilities.
+ Args:
+ chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
+ rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
+ kl_logps: KL divergence of the policy model. Shape: (batch_size,)
+ ref_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
+ ref_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
+ ref_kl_logps: KL divergence of the reference model. Shape: (batch_size,)
+            beta: The temperature parameter in the KTO paper.
+ desirable_weight: The weight for the desirable responses.
+ undesirable_weight: The weight for the undesirable responses.
+
+ Refer to the KTO paper for details about hyperparameters https://arxiv.org/pdf/2402.01306
+ """
+ kl = (kl_logps - ref_kl_logps).mean().detach()
+        # all-reduce and average the KL estimate across ranks
+ dist.all_reduce(kl, op=dist.ReduceOp.SUM)
+ kl = (kl / dist.get_world_size()).clamp(min=0)
+
+ if chosen_logps.shape[0] != 0 and ref_chosen_logps.shape[0] != 0:
+ chosen_logratios = chosen_logps - ref_chosen_logps
+ chosen_losses = 1 - nn.functional.sigmoid(self.beta * (chosen_logratios - kl))
+ chosen_rewards = self.beta * chosen_logratios.detach()
+ else:
+ chosen_losses = torch.Tensor([]).to(kl_logps.device)
+ chosen_rewards = torch.Tensor([]).to(kl_logps.device)
+
+ if rejected_logps.shape[0] != 0 and ref_rejected_logps.shape[0] != 0:
+ rejected_logratios = rejected_logps - ref_rejected_logps
+ rejected_losses = 1 - nn.functional.sigmoid(self.beta * (kl - rejected_logratios))
+ rejected_rewards = self.beta * rejected_logratios.detach()
+ else:
+ rejected_losses = torch.Tensor([]).to(kl_logps.device)
+ rejected_rewards = torch.Tensor([]).to(kl_logps.device)
+
+ losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()
+
+ return losses, chosen_rewards, rejected_rewards, kl
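
A single-process numeric sketch of this objective (the real `forward` also all-reduces the KL estimate across ranks, which is omitted here): desirable samples are pushed above the KL baseline and undesirable ones below it. All values are made up.

```python
import torch

beta = 0.1
desirable_weight = undesirable_weight = 1.0
kl = torch.tensor(0.05).clamp(min=0)      # stand-in for the reduced KL term
chosen_logratios = torch.tensor([0.8])    # policy logp - reference logp
rejected_logratios = torch.tensor([0.3])
chosen_losses = 1 - torch.sigmoid(beta * (chosen_logratios - kl))
rejected_losses = 1 - torch.sigmoid(beta * (kl - rejected_logratios))
loss = torch.cat(
    (desirable_weight * chosen_losses, undesirable_weight * rejected_losses), 0
).mean()
```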
diff --git a/applications/ColossalChat/coati/trainer/__init__.py b/applications/ColossalChat/coati/trainer/__init__.py
index 6ce159678fc1..6d0900153e8a 100755
--- a/applications/ColossalChat/coati/trainer/__init__.py
+++ b/applications/ColossalChat/coati/trainer/__init__.py
@@ -1,8 +1,18 @@
from .base import OLTrainer, SLTrainer
from .dpo import DPOTrainer
+from .kto import KTOTrainer
from .orpo import ORPOTrainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer
-__all__ = ["SLTrainer", "OLTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer", "DPOTrainer", "ORPOTrainer"]
+__all__ = [
+ "SLTrainer",
+ "OLTrainer",
+ "RewardModelTrainer",
+ "SFTTrainer",
+ "PPOTrainer",
+ "DPOTrainer",
+ "ORPOTrainer",
+ "KTOTrainer",
+]
diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py
index 3daab54f6019..24ddca6545c8 100755
--- a/applications/ColossalChat/coati/trainer/dpo.py
+++ b/applications/ColossalChat/coati/trainer/dpo.py
@@ -26,7 +26,7 @@
class DPOTrainer(SLTrainer):
"""
- Trainer for PPO algorithm.
+ Trainer for DPO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm
@@ -56,6 +56,7 @@ def __init__(
beta: float = 0.1,
gamma: float = 0.0,
length_normalization: bool = False,
+ apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@@ -67,6 +68,7 @@ def __init__(
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.actor_loss_fn = DpoLoss(beta, gamma)
+ self.apply_loss_mask = apply_loss_mask
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
@@ -135,6 +137,10 @@ def _train(self, epoch: int):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
+ if not self.apply_loss_mask:
+ chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+ reject_loss_mask = reject_loss_mask.fill_(1.0)
+
batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model(
@@ -284,6 +290,9 @@ def _eval(self, epoch: int):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
+ if not self.apply_loss_mask:
+ chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+ reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0]
diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py
new file mode 100755
index 000000000000..6462ba816686
--- /dev/null
+++ b/applications/ColossalChat/coati/trainer/kto.py
@@ -0,0 +1,349 @@
+"""
+KTO trainer
+"""
+
+import os
+from typing import Any, Optional
+
+import torch
+import torch.distributed as dist
+from coati.models.loss import KTOLoss
+from coati.models.utils import calc_masked_log_probs
+from coati.trainer.utils import all_reduce_mean
+from coati.utils import AccumulativeMeanMeter, save_checkpoint
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader
+from tqdm import trange
+from transformers import PreTrainedTokenizerBase
+
+from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+from colossalai.utils import get_current_device
+
+from .base import SLTrainer
+from .utils import is_rank_0, to_device
+
+
+class KTOTrainer(SLTrainer):
+ """
+ Trainer for KTO algorithm.
+
+ Args:
+        actor (Actor): the actor model to train with the KTO algorithm
+        ref_model (Actor): the frozen reference model used for the log-probability ratios
+ booster (Strategy): the strategy to use for training
+ actor_optim (Optimizer): the optimizer to use for actor model
+ actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model
+ tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
+ max_epochs (int, defaults to 1): the max number of epochs to train
+ accumulation_steps (int): the number of steps to accumulate gradients
+ start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
+        save_interval (int): the interval to save model checkpoints, defaults to 0, which means no checkpoint will be saved during training
+ save_dir (str): the directory to save checkpoints
+ coordinator (DistCoordinator): the coordinator to use for distributed logging
+ beta (float, defaults to 0.1): the beta parameter in kto loss
+ desirable_weight (float, defaults to 1.0): the weight for desirable reward
+ undesirable_weight (float, defaults to 1.0): the weight for undesirable reward
+ """
+
+ def __init__(
+ self,
+ actor: Any,
+ ref_model: Any,
+ booster: Booster,
+ actor_optim: Optimizer,
+ actor_lr_scheduler: _LRScheduler,
+ tokenizer: PreTrainedTokenizerBase,
+ max_epochs: int = 1,
+ beta: float = 0.1,
+ desirable_weight: float = 1.0,
+ undesirable_weight: float = 1.0,
+ apply_loss_mask: bool = True,
+ accumulation_steps: int = 1,
+ start_epoch: int = 0,
+ save_interval: int = 0,
+ save_dir: str = None,
+ coordinator: DistCoordinator = None,
+ ) -> None:
+ super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
+ self.ref_model = ref_model
+ self.actor_scheduler = actor_lr_scheduler
+ self.tokenizer = tokenizer
+ self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight)
+ self.apply_loss_mask = apply_loss_mask
+ self.save_interval = save_interval
+ self.coordinator = coordinator
+ self.save_dir = save_dir
+ self.num_train_step = 0
+ self.accumulation_steps = accumulation_steps
+ self.device = get_current_device()
+ self.accumulative_meter = AccumulativeMeanMeter()
+ self.desirable_weight = desirable_weight
+ self.undesirable_weight = undesirable_weight
+ self.beta = beta
+
+ def _before_fit(
+ self,
+ train_preference_dataloader: DataLoader = None,
+ eval_preference_dataloader: DataLoader = None,
+ log_dir: Optional[str] = None,
+ use_wandb: bool = False,
+ ):
+ """
+ Args:
+            train_preference_dataloader (DataLoader): the dataloader for KTO training data
+            eval_preference_dataloader (DataLoader): the dataloader for KTO evaluation data
+ """
+ self.train_dataloader = train_preference_dataloader
+ self.eval_dataloader = eval_preference_dataloader
+ self.writer = None
+ if use_wandb and is_rank_0():
+ assert log_dir is not None, "log_dir must be provided when use_wandb is True"
+ import wandb
+
+ self.wandb_run = wandb.init(project="Coati-kto", sync_tensorboard=True)
+ if log_dir is not None and is_rank_0():
+ import os
+ import time
+
+ from torch.utils.tensorboard import SummaryWriter
+
+ log_dir = os.path.join(log_dir, "kto")
+ log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+ self.writer = SummaryWriter(log_dir=log_dir)
+
+ def _train(self, epoch: int):
+ """
+ Args:
+ epoch int: the number of current epoch
+ """
+ self.model.train()
+ self.accumulative_meter.reset()
+ step_bar = trange(
+ len(self.train_dataloader) // self.accumulation_steps,
+ desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+ disable=not is_rank_0(),
+ )
+ for i, batch in enumerate(self.train_dataloader):
+ batch = to_device(batch, self.device)
+ (input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask) = (
+ batch["input_ids"],
+ batch["attention_mask"],
+ batch["loss_mask"],
+ batch["label"],
+ batch["kl_input_ids"],
+ batch["kl_attention_mask"],
+ batch["kl_loss_mask"],
+ )
+ if not self.apply_loss_mask:
+ loss_mask = loss_mask.fill_(1.0)
+ kl_loss_mask = kl_loss_mask.fill_(1.0)
+
+ batch_size = input_ids.size()[0]
+
+ # actor logits
+ with torch.no_grad():
+                # calculate the KL term with the mismatched (KL) data
+ kl_logits = self.model(
+ input_ids=kl_input_ids,
+ attention_mask=kl_attention_mask,
+ )["logits"]
+
+ logits = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ )["logits"]
+
+ logprob = calc_masked_log_probs(logits, input_ids, loss_mask[:, 1:]).sum(-1)
+ kl_logprob = calc_masked_log_probs(kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
+ chosen_index = [i for i in range(batch_size) if label[i] == 1]
+ rejected_index = [i for i in range(batch_size) if label[i] == 0]
+ chosen_logprob = logprob[chosen_index]
+ rejected_logprob = logprob[rejected_index]
+ with torch.no_grad():
+ ref_kl_logits = self.ref_model(
+ input_ids=kl_input_ids,
+ attention_mask=kl_attention_mask,
+ )["logits"]
+ ref_logits = self.ref_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ )["logits"]
+
+ ref_logprob = calc_masked_log_probs(ref_logits, input_ids, loss_mask[:, 1:]).sum(-1)
+ ref_kl_logprob = calc_masked_log_probs(ref_kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
+ ref_chosen_logprob = ref_logprob[chosen_index]
+ ref_rejected_logprob = ref_logprob[rejected_index]
+
+ loss, chosen_rewards, rejected_rewards, kl = self.kto_loss(
+ chosen_logprob, rejected_logprob, kl_logprob, ref_chosen_logprob, ref_rejected_logprob, ref_kl_logprob
+ )
+
+ self.booster.backward(loss=loss, optimizer=self.optimizer)
+ if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ self.actor_scheduler.step()
+
+ # sync
+ loss_mean = all_reduce_mean(tensor=loss)
+ chosen_reward_mean = chosen_rewards.mean()
+ chosen_rewards_list = [
+ torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(chosen_rewards_list, chosen_reward_mean)
+ rejected_reward_mean = rejected_rewards.mean()
+ rejected_rewards_list = [
+ torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(rejected_rewards_list, rejected_reward_mean)
+ chosen_rewards_list = [i for i in chosen_rewards_list if not i.isnan()]
+ rejected_rewards_list = [i for i in rejected_rewards_list if not i.isnan()]
+ chosen_rewards_mean = (
+ torch.stack(chosen_rewards_list).mean()
+ if len(chosen_rewards_list) > 0
+ else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
+ )
+ rejected_rewards_mean = (
+ torch.stack(rejected_rewards_list).mean()
+ if len(rejected_rewards_list) > 0
+ else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
+ )
+ self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
+
+ if i % self.accumulation_steps == self.accumulation_steps - 1:
+ self.num_train_step += 1
+ step_bar.update()
+ # logging
+ if self.writer and is_rank_0():
+ self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
+ self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
+ self.writer.add_scalar(
+ "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
+ )
+ self.writer.add_scalar(
+ "train/rejected_rewards",
+ self.accumulative_meter.get("rejected_rewards"),
+ self.num_train_step,
+ )
+ self.writer.add_scalar(
+ "train/margin",
+ self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
+ self.num_train_step,
+ )
+ self.accumulative_meter.reset()
+
+            if self.save_dir is not None and self.save_interval > 0 and (self.num_train_step + 1) % self.save_interval == 0:
+ # save checkpoint
+ self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
+ save_checkpoint(
+ save_dir=self.save_dir,
+ booster=self.booster,
+ model=self.model,
+ optimizer=self.optimizer,
+ lr_scheduler=self.actor_scheduler,
+ epoch=epoch,
+ step=i + 1,
+ batch_size=batch_size,
+ coordinator=self.coordinator,
+ )
+ self.coordinator.print_on_master(
+                    f"Saved checkpoint at epoch {epoch} step {i + 1} at folder {self.save_dir}"
+ )
+
+ step_bar.close()
+
+ def _eval(self, epoch: int):
+ """
+ Args:
+ epoch int: the number of current epoch
+ """
+ if self.eval_dataloader is None:
+ self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
+ return
+ self.model.eval()
+ self.accumulative_meter.reset()
+ step_bar = trange(
+            len(self.eval_dataloader) // self.accumulation_steps,
+ desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+ disable=not is_rank_0(),
+ )
+        for i, batch in enumerate(self.eval_dataloader):
+ batch = to_device(batch, self.device)
+ (input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask) = (
+ batch["input_ids"],
+ batch["attention_mask"],
+ batch["loss_mask"],
+ batch["label"],
+ batch["kl_input_ids"],
+ batch["kl_attention_mask"],
+ batch["kl_loss_mask"],
+ )
+
+ if not self.apply_loss_mask:
+ loss_mask = loss_mask.fill_(1.0)
+ kl_loss_mask = kl_loss_mask.fill_(1.0)
+
+ batch_size = input_ids.size()[0]
+
+ # actor logits
+ with torch.no_grad():
+                # calculate the KL term with the mismatched (KL) data
+ kl_logits = self.model(
+ input_ids=kl_input_ids,
+ attention_mask=kl_attention_mask,
+ )["logits"]
+
+ logits = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ )["logits"]
+
+ logprob = calc_masked_log_probs(logits, input_ids, loss_mask[:, 1:]).sum(-1)
+ kl_logprob = calc_masked_log_probs(kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
+ chosen_index = [i for i in range(batch_size) if label[i] == 1]
+ rejected_index = [i for i in range(batch_size) if label[i] == 0]
+ chosen_logprob = logprob[chosen_index]
+ rejected_logprob = logprob[rejected_index]
+ with torch.no_grad():
+ ref_kl_logits = self.ref_model(
+ input_ids=kl_input_ids,
+ attention_mask=kl_attention_mask,
+ )["logits"]
+
+ ref_logits = self.ref_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ )["logits"]
+
+ ref_logprob = calc_masked_log_probs(ref_logits, input_ids, loss_mask[:, 1:]).sum(-1)
+ ref_kl_logprob = calc_masked_log_probs(ref_kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
+ ref_chosen_logprob = ref_logprob[chosen_index]
+ ref_rejected_logprob = ref_logprob[rejected_index]
+
+ loss, chosen_rewards, rejected_rewards, kl = self.kto_loss(
+ chosen_logprob, rejected_logprob, kl_logprob, ref_chosen_logprob, ref_rejected_logprob, ref_kl_logprob
+ )
+
+ # sync
+ loss_mean = all_reduce_mean(tensor=loss)
+ chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
+ rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
+ self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
+ self.accumulative_meter.add(
+ "margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
+ )
+ step_bar.update()
+ msg = "Evaluation Result:\n"
+ for tag in ["loss", "chosen_rewards", "rejected_rewards", "margin"]:
+ msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+ self.coordinator.print_on_master(msg)
+ os.makedirs(self.save_dir, exist_ok=True)
+ with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
+ f.write(msg)
+ step_bar.close()
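
A sketch of the label-based split the trainer performs before calling `kto_loss`, with toy values: rows with `label == 1` are desirable, rows with `label == 0` undesirable.

```python
import torch

logprob = torch.tensor([-1.0, -2.0, -3.0])  # per-sample summed log-probs
label = torch.tensor([1, 0, 1])
chosen_index = [i for i in range(len(label)) if label[i] == 1]
rejected_index = [i for i in range(len(label)) if label[i] == 0]
assert logprob[chosen_index].tolist() == [-1.0, -3.0]
assert logprob[rejected_index].tolist() == [-2.0]
```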
diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py
index 495bb332b514..c2f75771cdff 100644
--- a/applications/ColossalChat/coati/trainer/orpo.py
+++ b/applications/ColossalChat/coati/trainer/orpo.py
@@ -26,7 +26,7 @@
class ORPOTrainer(SLTrainer):
"""
- Trainer for PPO algorithm.
+ Trainer for ORPO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm
@@ -52,6 +52,7 @@ def __init__(
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
lam: float = 0.1,
+ apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@@ -67,6 +68,7 @@ def __init__(
self.save_dir = save_dir
self.num_train_step = 0
self.lam = lam
+ self.apply_loss_mask = apply_loss_mask
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
@@ -130,6 +132,11 @@ def _train(self, epoch: int):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
+
+ if not self.apply_loss_mask:
+ chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+ reject_loss_mask = reject_loss_mask.fill_(1.0)
+
batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
@@ -263,6 +270,11 @@ def _eval(self, epoch: int):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
+
+ if not self.apply_loss_mask:
+ chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+ reject_loss_mask = reject_loss_mask.fill_(1.0)
+
batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
diff --git a/applications/ColossalChat/coati/trainer/ppo.py b/applications/ColossalChat/coati/trainer/ppo.py
index 287767669516..63c813b39ef9 100755
--- a/applications/ColossalChat/coati/trainer/ppo.py
+++ b/applications/ColossalChat/coati/trainer/ppo.py
@@ -102,6 +102,7 @@ def __init__(
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
+ apply_loss_mask: bool = True,
accumulation_steps: int = 1,
save_interval: int = 0,
save_dir: str = None,
@@ -140,6 +141,7 @@ def __init__(
self.actor_optim = actor_optim
self.critic_optim = critic_optim
self.save_interval = save_interval
+ self.apply_loss_mask = apply_loss_mask
self.coordinator = coordinator
self.actor_save_dir = os.path.join(save_dir, "actor")
self.critic_save_dir = os.path.join(save_dir, "critic")
@@ -229,7 +231,10 @@ def _training_step(self, experience: Experience):
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
actor_loss, to_skip, max_ratio = self.actor_loss_fn(
- action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+ action_log_probs,
+ experience.action_log_probs,
+ experience.advantages,
+ action_mask=experience.action_mask if self.apply_loss_mask else None,
)
actor_loss = (1 - self.ptx_coef) * actor_loss
if not to_skip:
@@ -249,7 +254,10 @@ def _training_step(self, experience: Experience):
input_ids=experience.sequences, attention_mask=experience.attention_mask
) # [batch size, prompt_length + response_length]
critic_loss = self.critic_loss_fn(
- values[:, -num_actions:], experience.values, experience.advantages, action_mask=experience.action_mask
+ values[:, -num_actions:],
+ experience.values,
+ experience.advantages,
+ action_mask=experience.action_mask if self.apply_loss_mask else None,
)
critic_loss = critic_loss * self.vf_coef
self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim)
diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py
index 1484f5057a83..d37676ada3e0 100755
--- a/applications/ColossalChat/coati/trainer/sft.py
+++ b/applications/ColossalChat/coati/trainer/sft.py
@@ -41,6 +41,7 @@ def __init__(
lr_scheduler: _LRScheduler,
max_epochs: int = 2,
accumulation_steps: int = 8,
+ apply_loss_mask: bool = True,
start_epoch=0,
save_interval: int = None,
save_dir: str = None,
@@ -55,6 +56,7 @@ def __init__(
self.coordinator = coordinator
self.num_train_step = 0
self.num_eval_step = 0
+ self.apply_loss_mask = apply_loss_mask
self.accumulative_meter = AccumulativeMeanMeter()
def _before_fit(
@@ -100,9 +102,12 @@ def _train(self, epoch: int):
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
batch_size = batch["input_ids"].size(0)
- outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
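+ # if the loss mask is disabled, compute the loss over all tokens by using input_ids as the labels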
+ outputs = self.model(
+ batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+ )
loss = outputs.loss
- step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")
self.booster.backward(loss=loss, optimizer=self.optimizer)
@@ -115,6 +120,7 @@ def _train(self, epoch: int):
self.optimizer.zero_grad()
self.scheduler.step()
+ step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
if self.writer:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
@@ -158,7 +164,11 @@ def _eval(self, epoch: int):
)
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
- outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+ outputs = self.model(
+ batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+ )
loss_mean = all_reduce_mean(tensor=outputs.loss)
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
step_bar.update()
diff --git a/applications/ColossalChat/config/conversation_template/Qwen_Qwen1.5-32B-Chat.json b/applications/ColossalChat/config/conversation_template/Qwen_Qwen1.5-32B-Chat.json
new file mode 100644
index 000000000000..58941a5918ff
--- /dev/null
+++ b/applications/ColossalChat/config/conversation_template/Qwen_Qwen1.5-32B-Chat.json
@@ -0,0 +1,9 @@
+{
+ "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ "stop_ids": [
+ 151645,
+ 151643
+ ],
+ "end_of_assistant": "<|im_end|>"
+}
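As an aside, the `chat_template` above is a standard Jinja template; the following minimal sketch (not part of the patch, using the `jinja2` package directly and a hypothetical two-turn conversation) shows how it renders messages:

```python
from jinja2 import Template

# the Qwen-style template from the JSON above
chat_template = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
print(Template(chat_template).render(messages=messages, add_generation_prompt=True))
# <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
# <|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n
```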
diff --git a/applications/ColossalChat/config/conversation_template/tiny-llama.json b/applications/ColossalChat/config/conversation_template/tiny-llama.json
new file mode 100644
index 000000000000..59196159f930
--- /dev/null
+++ b/applications/ColossalChat/config/conversation_template/tiny-llama.json
@@ -0,0 +1,8 @@
+{
+ "chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ "stop_ids": [
+ 2
+ ],
+ "end_of_assistant": ""
+}
diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md
index bdf4d23f1ad3..904d69cfcc4e 100755
--- a/applications/ColossalChat/examples/README.md
+++ b/applications/ColossalChat/examples/README.md
@@ -9,6 +9,7 @@
- [Install Requirements](#install-requirements)
- [Get Start with ColossalRun](#get-start-with-colossalrun)
- [Training Configuration](#training-configuration)
+ - [Parameter Efficient Finetuning (PEFT)](#parameter-efficient-finetuning-peft)
- [RLHF Stage 1: Supervised Instruction Tuning](#rlhf-training-stage1---supervised-instructs-tuning)
- [Step 1: Data Collection](#step-1-data-collection)
- [Step 2: Preprocessing](#step-2-preprocessing)
@@ -30,6 +31,8 @@
- [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning)
- [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
- [Alternative Option For RLHF: Simple Preference Optimization](#alternative-option-for-rlhf-simple-preference-optimization)
+ - [Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
+ - [Alternative Option For RLHF: Odds Ratio Preference Optimization](#alternative-option-for-rlhf-odds-ratio-preference-optimization)
- [List of Supported Models](#list-of-supported-models)
- [Hardware Requirements](#hardware-requirements)
- [Inference example](#inference-example)
@@ -46,9 +49,6 @@
pip install -r requirements.txt
```
-
-
-
## Get Start with ColossalRun
@@ -82,8 +82,6 @@ Make sure the master node can access all nodes (including itself) by ssh without
This section gives a simple introduction on different training strategies that you can use and how to use them with our boosters and plugins to reduce training time and VRAM consumption. For more details regarding training strategies, please refer to [here](https://colossalai.org/docs/concepts/paradigms_of_parallelism). For details regarding boosters and plugins, please refer to [here](https://colossalai.org/docs/basics/booster_plugins).
-
-
Gemini (Zero3)
@@ -375,35 +373,6 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
Low Rank Adaption
-
-
-Details about Low Rank Adaption (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). It dramatically reduces the VRAM consumption at the cost of sacrifice model capability. It is suitable for training LLM with constrained resources.
-
-
-To enable LoRA, set --lora_rank to a positive value (usually between 20 and 64).
-```
-colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
- --pretrain $PRETRAINED_MODEL_PATH \
- --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
- --dataset ${dataset[@]} \
- --save_interval 5000 \
- --save_path $SAVE_DIR \
- --config_file $CONFIG_FILE \
- --plugin zero2_cpu \
- --batch_size 4 \
- --max_epochs 1 \
- --accumulation_steps 4 \
- --lr 2e-5 \
- --max_len 2048 \
- --lora_rank 32 \ # This enables LoRA
- --use_wandb
-```
-
-
-Other Training Arguments
@@ -418,6 +387,7 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
- save_dir: path to store the model checkpoints.
- max_length: input will be padded/truncated to max_length before feeding to the model.
- max_epochs: number of epochs to train.
+- disable_loss_mask: whether to disable the loss mask. For example, in SFT, if the loss mask is disabled, the model computes the loss across all tokens in the sequence; if it is applied, only tokens corresponding to the assistant's responses contribute to the final loss (see the sketch after this list).
- batch_size: training batch size.
- mixed_precision: precision to use in training. Support 'fp16' and 'bf16'. Note that some devices may not support the 'bf16' option, please refer to [Nvidia](https://developer.nvidia.com/) to check compatibility.
- save_interval: save the model weights as well as optimizer/scheduler states every save_interval steps/episodes.
@@ -428,6 +398,60 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
- use_wandb: if this flag is up, you can view logs on wandb.
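+
+For intuition, below is a minimal sketch (not the trainer's actual code) of what the loss mask does in SFT: tokens outside the assistant's responses get the label -100, which `torch.nn.CrossEntropyLoss` ignores by default, so only assistant tokens contribute to the loss. Disabling the mask keeps every token as a target. The token values here are made up for illustration.
+
+```python
+import torch
+
+IGNORE_INDEX = -100  # CrossEntropyLoss ignores targets with this value by default
+
+input_ids = torch.tensor([1, 15043, 29991, 2, 18637, 727, 29991, 2])
+# 1 marks assistant-response tokens, 0 marks prompt/system tokens
+loss_mask = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
+
+masked_labels = input_ids.clone()
+masked_labels[loss_mask == 0] = IGNORE_INDEX  # loss only on assistant tokens
+
+unmasked_labels = input_ids.clone()  # disable_loss_mask: loss on every token
+```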
+Low Rank Adaptation and PiSSA
+
+
+Details about Low Rank Adaptation (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). Details about Principal Singular Values and Singular Vectors Adaptation (PiSSA) can be found in the paper: [PiSSA: Principal Singular Values and Singular Vectors Adaptation of Large Language Models](https://arxiv.org/abs/2404.02948). Both reduce running-time VRAM consumption as well as training time, at the cost of some overall model performance, which makes them suitable for training LLMs with constrained resources.
+
+To use LoRA/PiSSA in training, please create a config file as in the following example and set `--lora_config` to the path of that configuration file.
+
+```json
+{
+ "r": 128,
+ "embedding_lora_dropout": 0.0,
+ "linear_lora_dropout": 0.1,
+ "lora_alpha": 32,
+ "lora_train_bias": "all",
+ "lora_initialization_method": "PiSSA",
+ "target_modules": ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens"]
+}
+```
+#### LoRA Parameters
+- r: LoRA rank
+- embedding_lora_dropout: dropout probability for the embedding layer
+- linear_lora_dropout: dropout probability for linear layers
+- lora_alpha: LoRA alpha; controls how much the adaptor can deviate from the pretrained model.
+- lora_train_bias: whether to add trainable biases to LoRA layers. Choose from "all" (all layers, including but not limited to LoRA layers, will have trainable biases), "none" (no trainable biases), and "lora" (only LoRA layers will have trainable biases).
+- lora_initialization_method: how to initialize LoRA weights; choose one from ["kaiming_uniform", "PiSSA"], default to "kaiming_uniform". Use "kaiming_uniform" for standard LoRA and "PiSSA" for PiSSA.
+- target_modules: which module(s) should be converted to LoRA layers. If a module's name contains any of the keywords in target_modules and the module is a linear or embedding layer, it will be converted; otherwise, the module will be frozen. Setting this field to None will automatically convert all linear and embedding layers to their LoRA counterparts. Note that this example only works for LLaMA; for other models, you need to modify it (see the sketch after this list for how the matching works).
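+
+As a rough illustration of the matching rule described above (a minimal sketch, not the actual ColossalChat implementation; `should_convert` is a hypothetical helper):
+
+```python
+import torch.nn as nn
+
+def should_convert(name: str, module: nn.Module, target_modules) -> bool:
+    # only linear and embedding layers are eligible for LoRA conversion
+    eligible = isinstance(module, (nn.Linear, nn.Embedding))
+    if target_modules is None:
+        return eligible  # None converts every eligible layer
+    # otherwise, convert only when the module's name contains a keyword
+    return eligible and any(keyword in name for keyword in target_modules)
+
+# e.g. "model.layers.0.self_attn.q_proj" matches the "q_proj" keyword
+```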
+
+
+```
+colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
+ --save_interval 5000 \
+ --save_path $SAVE_DIR \
+ --config_file $CONFIG_FILE \
+ --plugin zero2_cpu \
+ --batch_size 4 \
+ --max_epochs 1 \
+ --accumulation_steps 4 \
+ --lr 2e-5 \
+ --max_len 2048 \
+ --lora_config /PATH/TO/THE/LORA/CONFIG/FILE.json \ # Setting this enables LoRA
+ --use_wandb
+```
+
+
#### Step 3: Training
-Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./examples/training_scripts/train_sft.sh) to start a supervised instructs fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
+Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./training_scripts/train_sft.sh) to start supervised instruction fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
### RLHF Training Stage2 - Training Reward Model
@@ -527,7 +555,7 @@ Below shows the preference dataset format used in training the reward model.
[
{"context": [
{
- "from": "human",
+ "from": "user",
"content": "Introduce butterflies species in Oregon."
}
]
@@ -552,11 +580,11 @@ Below shows the preference dataset format used in training the reward model.
#### Step 2: Preprocessing
-Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./examples/data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
+Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
#### Step 3: Training
-You can run [train_rm.sh](./examples/training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
+You can run [train_rm.sh](./training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
#### Features and Tricks in RM Training
@@ -596,7 +624,7 @@ In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimi
#### Step 1: Data Collection
-PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "human" and thus the "assistant" needs to generate a response to answer to the "human". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format.
+PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory; data samples within the prompt dataset end with a line from the "user", so the "assistant" needs to generate a response to answer the "user". Note that you can still use a conversation that ends with a line from the "assistant"; in that case, the last line will be dropped. Here is an example of the prompt dataset format.
```json
@@ -604,7 +632,7 @@ PPO uses two kinds of training data--- the prompt data and the pretrain data (op
{"messages":
[
{
- "from": "human",
+ "from": "user",
"content": "what are some pranks with a pen i can do?"
}
...
@@ -627,14 +655,14 @@ The second dataset--- pretrained dataset is optional, provide it if you want to
]
```
#### Step 2: Preprocessing
-To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./examples/data_preparation_scripts/prepare_prompt_dataset.sh)
+To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./data_preparation_scripts/prepare_prompt_dataset.sh)
You can use the SFT dataset you prepared in the SFT stage or prepare a new one from a different source for the ptx dataset. The ptx data is used to calculate the ptx loss, which stabilizes the training according to the [InstructGPT paper](https://arxiv.org/pdf/2203.02155.pdf).
#### Step 3: Training
-You can run the [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to start PPO training. Here are some unique arguments for PPO, please refer to the training configuration section for other training configuration. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
+You can run [train_ppo.sh](./training_scripts/train_ppo.sh) to start PPO training. Below are some arguments unique to PPO; please refer to the [training configuration](#training-configuration) section for details regarding the other supported training options.
```bash
@@ -718,7 +746,7 @@ For DPO training, you only need the preference dataset. Please follow the instru
#### Step 2: Training
-You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. Following the trend of recent research on DPO-like alignment methods, we added option for the user to choose from, including whether to do length normalization , reward shaping and whether to use a reference model in calculating implicit reward. Here are those options,
+You can run [train_dpo.sh](./training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. Following the trend of recent research on DPO-like alignment methods, we added options for the user to choose from, including whether to apply length normalization and reward shaping, and whether to use a reference model when calculating the implicit reward. Here are those options:
```
--beta 0.1 \ # the temperature in DPO loss, Default to 0.1
@@ -735,7 +763,7 @@ You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to star
### Alternative Option For RLHF: Simple Preference Optimization
We support the method introduced in the paper [SimPO: Simple Preference Optimization
-with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO). Which is a reference model free aligment method that add length normalization and reward shaping to the DPO loss to enhance training stability and efficiency. As the method doesn't deviate too much from DPO, we add support for length normalization and SimPO reward shaping in our DPO implementation. To use SimPO in alignment, use the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) script, set the `loss_type` to `simpo_loss`, you can also set the value for temperature (`beta`) and reward target margin (`gamma`) but it is optional.
+with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO), a reference-model-free alignment method that adds length normalization and reward shaping to the DPO loss to enhance training stability and efficiency. As the method doesn't deviate much from DPO, we added support for length normalization and SimPO reward shaping in our DPO implementation. To use SimPO in alignment, use the [train_dpo.sh](./training_scripts/train_dpo.sh) script, set `loss_type` to `simpo_loss`, and optionally set the temperature (`beta`) and the reward target margin (`gamma`).
#### SimPO Result
@@ -744,13 +772,50 @@ with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO). Which i
### Alternative Option For RLHF: Odds Ratio Preference Optimization
-We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO). Which is a reference model free aligment method that mixes the SFT loss with a reinforcement learning loss that uses odds ratio as the implicit reward to enhance training stability and efficiency. Simply set the flag to disable the use of the reference model, set the reward target margin and enable length normalization in the DPO training script. To use ORPO in alignment, use the [train_orpo.sh](./examples/training_scripts/train_orpo.sh) script, You can set the value for `lambda` (which determine how strongly the reinforcement learning loss affect the training) but it is optional.
+We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO), a reference-model-free alignment method that mixes the SFT loss with a reinforcement learning loss that uses the odds ratio as the implicit reward to enhance training stability and efficiency. To use ORPO in alignment, use the [train_orpo.sh](./training_scripts/train_orpo.sh) script. You can optionally set the value for `lambda`, which determines how strongly the reinforcement learning loss affects training.

#### ORPO Result
+
+
@@ -310,4 +310,14 @@ If you wish to cite relevant research papers, you can find the reference below.
journal={arXiv},
year={2023}
}
+
+# Distrifusion
+@InProceedings{Li_2024_CVPR,
+ author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Li, Kai and Han, Song},
+ title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ month={June},
+ year={2024},
+ pages={7183-7193}
+}
```
diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 1beb86874826..072ddbcfd298 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM):
enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation.
start_token_size(int): The size of the start tokens, when using StreamingLLM.
generated_token_size(int): The size of the generated tokens, When using StreamingLLM.
+ patched_parallelism_size(int): The patched parallelism size, when using Distrifusion.
"""
# NOTE: arrange configs according to their importance and frequency of usage
@@ -245,6 +246,11 @@ class InferenceConfig(RPC_PARAM):
start_token_size: int = 4
generated_token_size: int = 512
+ # Acceleration for Diffusion Models (PipeFusion or Distrifusion)
+ patched_parallelism_size: int = 1 # for distrifusion
+ # pipeFusion_m_size: int = 1 # for pipefusion
+ # pipeFusion_n_size: int = 1 # for pipefusion
+
def __post_init__(self):
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
self._verify_config()
@@ -288,6 +294,14 @@ def _verify_config(self) -> None:
# Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit.
self.start_token_size = self.block_size
+ # check Distrifusion
+ # TODO(@lry89757) needs a more detailed check
+ if self.patched_parallelism_size > 1:
+ # self.use_patched_parallelism = True
+ self.tp_size = (
+ self.patched_parallelism_size
+ ) # this is not real tensor parallelism; an existing parallelism check forces us to set tp_size to patched_parallelism_size
+
# check prompt template
if self.prompt_template is None:
return
@@ -324,6 +338,7 @@ def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
use_cuda_kernel=self.use_cuda_kernel,
use_spec_dec=self.use_spec_dec,
use_flash_attn=use_flash_attn,
+ patched_parallelism_size=self.patched_parallelism_size,
)
return model_inference_config
@@ -396,6 +411,7 @@ class ModelShardInferenceConfig:
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
+ patched_parallelism_size: int = 1 # for diffusion models (Distrifusion technique)
@dataclass
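As a usage note (a minimal sketch, not part of the patch): setting `patched_parallelism_size` above 1 is what activates the Distrifusion path, and, as `_verify_config` above shows, `tp_size` is then forced to the same value.

```python
from colossalai.inference.config import InferenceConfig

# launch one process per patch shard (e.g. 2 GPUs for patched_parallelism_size=2)
inference_config = InferenceConfig(patched_parallelism_size=2)

# per _verify_config, tp_size is overwritten to satisfy the existing parallelism checks
assert inference_config.tp_size == inference_config.patched_parallelism_size
```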
diff --git a/colossalai/inference/core/diffusion_engine.py b/colossalai/inference/core/diffusion_engine.py
index 75b9889bf28d..8bed508cba55 100644
--- a/colossalai/inference/core/diffusion_engine.py
+++ b/colossalai/inference/core/diffusion_engine.py
@@ -11,7 +11,7 @@
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
-from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import DiffusionSequence
from colossalai.inference.utils import get_model_size, get_model_type
diff --git a/colossalai/inference/modeling/models/diffusion.py b/colossalai/inference/modeling/layers/diffusion.py
similarity index 100%
rename from colossalai/inference/modeling/models/diffusion.py
rename to colossalai/inference/modeling/layers/diffusion.py
diff --git a/colossalai/inference/modeling/layers/distrifusion.py b/colossalai/inference/modeling/layers/distrifusion.py
new file mode 100644
index 000000000000..ea97cceefac9
--- /dev/null
+++ b/colossalai/inference/modeling/layers/distrifusion.py
@@ -0,0 +1,626 @@
+# Code refer and adapted from:
+# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers
+# https://github.com/PipeFusion/PipeFusion
+
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from diffusers.models import attention_processor
+from diffusers.models.attention import Attention
+from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed
+from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
+from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
+from torch import nn
+from torch.distributed import ProcessGroup
+
+from colossalai.inference.config import ModelShardInferenceConfig
+from colossalai.logging import get_dist_logger
+from colossalai.shardformer.layer.parallel_module import ParallelModule
+from colossalai.utils import get_current_device
+
+try:
+ from flash_attn import flash_attn_func
+
+ HAS_FLASH_ATTN = True
+except ImportError:
+ HAS_FLASH_ATTN = False
+
+
+logger = get_dist_logger(__name__)
+
+
+# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_2d.py
+def PixArtAlphaTransformer2DModel_forward(
+ self: PixArtTransformer2DModel,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+):
+ assert hasattr(
+ self, "patched_parallel_size"
+ ), "please check your policy: `Transformer2DModel` must have the attribute `patched_parallel_size`"
+
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 1. Input
+ batch_size = hidden_states.shape[0]
+ height, width = (
+ hidden_states.shape[-2] // self.config.patch_size,
+ hidden_states.shape[-1] // self.config.patch_size,
+ )
+ hidden_states = self.pos_embed(hidden_states)
+
+ timestep, embedded_timestep = self.adaln_single(
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ if self.caption_projection is not None:
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+ # 2. Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ )
+
+ # 3. Output
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)).chunk(
+ 2, dim=1
+ )
+ hidden_states = self.norm_out(hidden_states)
+ # Modulation
+ hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.squeeze(1)
+
+ # unpatchify
+ hidden_states = hidden_states.reshape(
+ shape=(
+ -1,
+ height // self.patched_parallel_size,
+ width,
+ self.config.patch_size,
+ self.config.patch_size,
+ self.out_channels,
+ )
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(
+ -1,
+ self.out_channels,
+ height // self.patched_parallel_size * self.config.patch_size,
+ width * self.config.patch_size,
+ )
+ )
+
+ # enable Distrifusion Optimization
+ if hasattr(self, "patched_parallel_size"):
+ from torch import distributed as dist
+
+ if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
+ self.output_buffer = torch.empty_like(output)
+ if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
+ self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
+ output = output.contiguous()
+ dist.all_gather(self.buffer_list, output, async_op=False)
+ torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
+ output = self.output_buffer
+
+ return (output,)
+
+
+# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_sd3.py
+def SD3Transformer2DModel_forward(
+ self: SD3Transformer2DModel,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ pooled_projections: torch.FloatTensor = None,
+ timestep: torch.LongTensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+) -> Tuple[torch.FloatTensor]:
+
+ assert hasattr(
+ self, "patched_parallel_size"
+ ), "please check your policy: `Transformer2DModel` must have the attribute `patched_parallel_size`"
+
+ height, width = hidden_states.shape[-2:]
+
+ hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
+ temb = self.time_text_embed(timestep, pooled_projections)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ for block in self.transformer_blocks:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+ )
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # unpatchify
+ patch_size = self.config.patch_size
+ height = height // patch_size // self.patched_parallel_size
+ width = width // patch_size
+
+ hidden_states = hidden_states.reshape(
+ shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+ )
+
+ # enable Distrifusion Optimization
+ if hasattr(self, "patched_parallel_size"):
+ from torch import distributed as dist
+
+ if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
+ self.output_buffer = torch.empty_like(output)
+ if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
+ self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
+ output = output.contiguous()
+ dist.all_gather(self.buffer_list, output, async_op=False)
+ torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
+ output = self.output_buffer
+
+ return (output,)
+
+
+# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/patchembed.py
+class DistrifusionPatchEmbed(ParallelModule):
+ def __init__(
+ self,
+ module: PatchEmbed,
+ process_group: Union[ProcessGroup, List[ProcessGroup]],
+ model_shard_infer_config: ModelShardInferenceConfig = None,
+ ):
+ super().__init__()
+ self.module = module
+ self.rank = dist.get_rank(group=process_group)
+ self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+
+ @staticmethod
+ def from_native_module(module: PatchEmbed, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
+ model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+ distrifusion_embed = DistrifusionPatchEmbed(
+ module, process_group, model_shard_infer_config=model_shard_infer_config
+ )
+ return distrifusion_embed
+
+ def forward(self, latent):
+ module = self.module
+ if module.pos_embed_max_size is not None:
+ height, width = latent.shape[-2:]
+ else:
+ height, width = latent.shape[-2] // module.patch_size, latent.shape[-1] // module.patch_size
+
+ latent = module.proj(latent)
+ if module.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
+ if module.layer_norm:
+ latent = module.norm(latent)
+ if module.pos_embed is None:
+ return latent.to(latent.dtype)
+ # Interpolate or crop positional embeddings as needed
+ if module.pos_embed_max_size:
+ pos_embed = module.cropped_pos_embed(height, width)
+ else:
+ if module.height != height or module.width != width:
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim=module.pos_embed.shape[-1],
+ grid_size=(height, width),
+ base_size=module.base_size,
+ interpolation_scale=module.interpolation_scale,
+ )
+ pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+ else:
+ pos_embed = module.pos_embed
+
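+ # keep only this rank's slice of the positional embedding along the patch (sequence) dimension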
+ b, c, h = pos_embed.shape
+ pos_embed = pos_embed.view(b, self.patched_parallelism_size, -1, h)[:, self.rank]
+
+ return (latent + pos_embed).to(latent.dtype)
+
+
+# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/conv2d.py
+class DistrifusionConv2D(ParallelModule):
+
+ def __init__(
+ self,
+ module: nn.Conv2d,
+ process_group: Union[ProcessGroup, List[ProcessGroup]],
+ model_shard_infer_config: ModelShardInferenceConfig = None,
+ ):
+ super().__init__()
+ self.module = module
+ self.rank = dist.get_rank(group=process_group)
+ self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+
+ @staticmethod
+ def from_native_module(module: nn.Conv2d, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
+ model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+ distrifusion_conv = DistrifusionConv2D(module, process_group, model_shard_infer_config=model_shard_infer_config)
+ return distrifusion_conv
+
+ def sliced_forward(self, x: torch.Tensor) -> torch.Tensor:
+
+ b, c, h, w = x.shape
+
+ stride = self.module.stride[0]
+ padding = self.module.padding[0]
+
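+ # compute this rank's horizontal output slice, overlapping padding rows with neighbors (halo) and re-padding only at the true image borders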
+ output_h = x.shape[2] // stride // self.patched_parallelism_size
+ idx = dist.get_rank()
+ h_begin = output_h * idx * stride - padding
+ h_end = output_h * (idx + 1) * stride + padding
+ final_padding = [padding, padding, 0, 0]
+ if h_begin < 0:
+ h_begin = 0
+ final_padding[2] = padding
+ if h_end > h:
+ h_end = h
+ final_padding[3] = padding
+ sliced_input = x[:, :, h_begin:h_end, :]
+ padded_input = F.pad(sliced_input, final_padding, mode="constant")
+ return F.conv2d(
+ padded_input,
+ self.module.weight,
+ self.module.bias,
+ stride=stride,
+ padding="valid",
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ output = self.sliced_forward(input)
+ return output
+
+
+# Code adapted from: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/attention_processor.py
+class DistrifusionFusedAttention(ParallelModule):
+
+ def __init__(
+ self,
+ module: attention_processor.Attention,
+ process_group: Union[ProcessGroup, List[ProcessGroup]],
+ model_shard_infer_config: ModelShardInferenceConfig = None,
+ ):
+ super().__init__()
+ self.counter = 0
+ self.module = module
+ self.buffer_list = None
+ self.kv_buffer_idx = dist.get_rank(group=process_group)
+ self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+ self.handle = None
+ self.process_group = process_group
+ self.warm_step = 5 # for warmup
+
+ @staticmethod
+ def from_native_module(
+ module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ ) -> ParallelModule:
+ model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+ return DistrifusionFusedAttention(
+ module=module,
+ process_group=process_group,
+ model_shard_infer_config=model_shard_infer_config,
+ )
+
+ def _forward(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size = encoder_hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ kv = torch.cat([key, value], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2)
+
+ if self.patched_parallelism_size == 1:
+ full_kv = kv
+ else:
+ if self.buffer_list is None: # buffer not created
+ full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
+ elif self.counter <= self.warm_step:
+ # logger.info(f"warmup: {self.counter}")
+ dist.all_gather(
+ self.buffer_list,
+ kv,
+ group=self.process_group,
+ async_op=False,
+ )
+ full_kv = torch.cat(self.buffer_list, dim=1)
+ else:
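+ # Distrifusion: overwrite this rank's slot with fresh KV, reuse the other ranks' one-step-stale KV, and refresh asynchronously for the next step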
+ # logger.info(f"use old kv to infer: {self.counter}")
+ self.buffer_list[self.kv_buffer_idx].copy_(kv)
+ full_kv = torch.cat(self.buffer_list, dim=1)
+ assert self.handle is None, "the previous step's async KV gather should already have been consumed"
+ self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)
+
+ key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
+
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ # attention
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, dropout_p=0.0, is_causal=False
+ ) # NOTE(@lry89757) for torch >= 2.2, flash attn has already been integrated into scaled_dot_product_attention, https://pytorch.org/blog/pytorch2-2/
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states, encoder_hidden_states
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+
+ if self.handle is not None:
+ self.handle.wait()
+ self.handle = None
+
+ b, l, c = hidden_states.shape
+ kv_shape = (b, l, self.module.to_k.out_features * 2)
+ if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
+
+ self.buffer_list = [
+ torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
+ for _ in range(self.patched_parallelism_size)
+ ]
+
+ self.counter = 0
+
+ attn_parameters = set(inspect.signature(self.module.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks"}
+ unused_kwargs = [
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+ ]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.module.processor.__class__.__name__} and will be ignored."
+ )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+ output = self._forward(
+ self.module,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ self.counter += 1
+
+ return output
+
+
+# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/attn.py
+class DistriSelfAttention(ParallelModule):
+ def __init__(
+ self,
+ module: Attention,
+ process_group: Union[ProcessGroup, List[ProcessGroup]],
+ model_shard_infer_config: ModelShardInferenceConfig = None,
+ ):
+ super().__init__()
+ self.counter = 0
+ self.module = module
+ self.buffer_list = None
+ self.kv_buffer_idx = dist.get_rank(group=process_group)
+ self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+ self.handle = None
+ self.process_group = process_group
+ self.warm_step = 3 # for warmup
+
+ @staticmethod
+ def from_native_module(
+ module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ ) -> ParallelModule:
+ model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+ return DistriSelfAttention(
+ module=module,
+ process_group=process_group,
+ model_shard_infer_config=model_shard_infer_config,
+ )
+
+ def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0):
+ attn = self.module
+ assert isinstance(attn, Attention)
+
+ residual = hidden_states
+
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ query = attn.to_q(hidden_states)
+
+ encoder_hidden_states = hidden_states
+ k = self.module.to_k(encoder_hidden_states)
+ v = self.module.to_v(encoder_hidden_states)
+ kv = torch.cat([k, v], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2)
+
+ if self.patched_parallelism_size == 1:
+ full_kv = kv
+ else:
+ if self.buffer_list is None: # buffer not created
+ full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
+ elif self.counter <= self.warm_step:
+ # logger.info(f"warmup: {self.counter}")
+ dist.all_gather(
+ self.buffer_list,
+ kv,
+ group=self.process_group,
+ async_op=False,
+ )
+ full_kv = torch.cat(self.buffer_list, dim=1)
+ else:
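+ # same stale-KV scheme as DistrifusionFusedAttention: fresh local KV, one-step-stale remote KV, async refresh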
+ # logger.info(f"use old kv to infer: {self.counter}")
+ self.buffer_list[self.kv_buffer_idx].copy_(kv)
+ full_kv = torch.cat(self.buffer_list, dim=1)
+ assert self.handle is None, "the previous step's async KV gather should already have been consumed"
+ self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)
+
+ if HAS_FLASH_ATTN:
+ # flash attn
+ key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
+
+ hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
+ hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
+ else:
+ # naive attn
+ key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ scale: float = 1.0,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+
+ # wait for the async KV all_gather issued in the previous step before touching the buffers
+ if self.handle is not None:
+ self.handle.wait()
+ self.handle = None
+
+ b, l, c = hidden_states.shape
+ kv_shape = (b, l, self.module.to_k.out_features * 2)
+ if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
+
+ self.buffer_list = [
+ torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
+ for _ in range(self.patched_parallelism_size)
+ ]
+
+ self.counter = 0
+
+ output = self._forward(hidden_states, scale=scale)
+
+ self.counter += 1
+ return output
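Stepping back from the patch, the asynchronous stale-KV pattern used by both attention classes above can be reduced to the following self-contained sketch (illustrative only; gloo backend, 2 CPU processes): each rank overwrites its own buffer slot with fresh KV, reads the other slots one step stale, and refreshes them with an all_gather that overlaps the next step's compute.

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # in the real code, the first few steps gather synchronously (warmup);
    # here the remote slots simply start at zero
    buffer_list = [torch.zeros(4) for _ in range(world_size)]
    handle = None
    for step in range(3):
        kv = torch.full((4,), float(rank * 10 + step))  # this rank's fresh KV slice
        if handle is not None:
            handle.wait()  # last step's async gather is now complete
            handle = None
        buffer_list[rank].copy_(kv)       # overwrite own slot with fresh KV
        full_kv = torch.cat(buffer_list)  # other slots may be one step stale
        handle = dist.all_gather(buffer_list, kv, async_op=True)  # background refresh
        print(f"rank {rank} step {step}: {full_kv.tolist()}")
    if handle is not None:
        handle.wait()
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```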
diff --git a/colossalai/inference/modeling/models/pixart_alpha.py b/colossalai/inference/modeling/models/pixart_alpha.py
index d5774946e365..cc2bee5efd4d 100644
--- a/colossalai/inference/modeling/models/pixart_alpha.py
+++ b/colossalai/inference/modeling/models/pixart_alpha.py
@@ -14,7 +14,7 @@
from colossalai.logging import get_dist_logger
-from .diffusion import DiffusionPipe
+from ..layers.diffusion import DiffusionPipe
logger = get_dist_logger(__name__)
diff --git a/colossalai/inference/modeling/models/stablediffusion3.py b/colossalai/inference/modeling/models/stablediffusion3.py
index d1c63a6dc665..b123164039c8 100644
--- a/colossalai/inference/modeling/models/stablediffusion3.py
+++ b/colossalai/inference/modeling/models/stablediffusion3.py
@@ -4,7 +4,7 @@
import torch
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
-from .diffusion import DiffusionPipe
+from ..layers.diffusion import DiffusionPipe
# TODO(@lry89757) temporarily returns only the image; support more output types later
diff --git a/colossalai/inference/modeling/policy/pixart_alpha.py b/colossalai/inference/modeling/policy/pixart_alpha.py
index 356056ba73e7..1150b2432cc5 100644
--- a/colossalai/inference/modeling/policy/pixart_alpha.py
+++ b/colossalai/inference/modeling/policy/pixart_alpha.py
@@ -1,9 +1,17 @@
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
from torch import nn
from colossalai.inference.config import RPC_PARAM
-from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.distrifusion import (
+ DistrifusionConv2D,
+ DistrifusionPatchEmbed,
+ DistriSelfAttention,
+ PixArtAlphaTransformer2DModel_forward,
+)
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
-from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
@@ -12,9 +20,46 @@ def __init__(self) -> None:
def module_policy(self):
policy = {}
+
+ if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1:
+
+ policy[PixArtTransformer2DModel] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="pos_embed.proj",
+ target_module=DistrifusionConv2D,
+ kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+ ),
+ SubModuleReplacementDescription(
+ suffix="pos_embed",
+ target_module=DistrifusionPatchEmbed,
+ kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+ ),
+ ],
+ attribute_replacement={
+ "patched_parallel_size": self.shard_config.extra_kwargs[
+ "model_shard_infer_config"
+ ].patched_parallelism_size
+ },
+ method_replacement={"forward": PixArtAlphaTransformer2DModel_forward},
+ )
+
+ policy[BasicTransformerBlock] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="attn1",
+ target_module=DistriSelfAttention,
+ kwargs={
+ "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
+ },
+ )
+ ]
+ )
+
self.append_or_create_method_replacement(
description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
)
+
return policy
def preprocess(self) -> nn.Module:
diff --git a/colossalai/inference/modeling/policy/stablediffusion3.py b/colossalai/inference/modeling/policy/stablediffusion3.py
index c9877f7dcae6..39b764b92887 100644
--- a/colossalai/inference/modeling/policy/stablediffusion3.py
+++ b/colossalai/inference/modeling/policy/stablediffusion3.py
@@ -1,9 +1,17 @@
+from diffusers.models.attention import JointTransformerBlock
+from diffusers.models.transformers import SD3Transformer2DModel
from torch import nn
from colossalai.inference.config import RPC_PARAM
-from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.distrifusion import (
+ DistrifusionConv2D,
+ DistrifusionFusedAttention,
+ DistrifusionPatchEmbed,
+ SD3Transformer2DModel_forward,
+)
from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward
-from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
@@ -12,6 +20,42 @@ def __init__(self) -> None:
def module_policy(self):
policy = {}
+
+ if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1:
+
+ policy[SD3Transformer2DModel] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="pos_embed.proj",
+ target_module=DistrifusionConv2D,
+ kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+ ),
+ SubModuleReplacementDescription(
+ suffix="pos_embed",
+ target_module=DistrifusionPatchEmbed,
+ kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+ ),
+ ],
+ attribute_replacement={
+ "patched_parallel_size": self.shard_config.extra_kwargs[
+ "model_shard_infer_config"
+ ].patched_parallelism_size
+ },
+ method_replacement={"forward": SD3Transformer2DModel_forward},
+ )
+
+ policy[JointTransformerBlock] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="attn",
+ target_module=DistrifusionFusedAttention,
+ kwargs={
+ "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
+ },
+ )
+ ]
+ )
+
self.append_or_create_method_replacement(
description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe
)
diff --git a/colossalai/shardformer/layer/moe/__init__.py b/colossalai/legacy/moe/layer/__init__.py
similarity index 100%
rename from colossalai/shardformer/layer/moe/__init__.py
rename to colossalai/legacy/moe/layer/__init__.py
diff --git a/colossalai/shardformer/layer/moe/experts.py b/colossalai/legacy/moe/layer/experts.py
similarity index 95%
rename from colossalai/shardformer/layer/moe/experts.py
rename to colossalai/legacy/moe/layer/experts.py
index 1be7a27547ed..8088cf44e473 100644
--- a/colossalai/shardformer/layer/moe/experts.py
+++ b/colossalai/legacy/moe/layer/experts.py
@@ -5,9 +5,9 @@
import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import get_activation
+from colossalai.legacy.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.utils import get_activation
+from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size
@@ -118,7 +118,7 @@ def forward(
Returns:
torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
"""
- x = MoeInGradScaler.apply(x, self.ep_size)
+ x = EPGradScalerIn.apply(x, self.ep_size)
e = x.size(1)
h = x.size(-1)
@@ -157,5 +157,5 @@ def forward(
x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
x = x.reshape(inshape)
x = x.transpose(0, 1).contiguous()
- x = MoeOutGradScaler.apply(x, self.ep_size)
+ x = EPGradScalerOut.apply(x, self.ep_size)
return x
diff --git a/colossalai/shardformer/layer/moe/layers.py b/colossalai/legacy/moe/layer/layers.py
similarity index 99%
rename from colossalai/shardformer/layer/moe/layers.py
rename to colossalai/legacy/moe/layer/layers.py
index e5b0ef97fd87..e43966f68a8c 100644
--- a/colossalai/shardformer/layer/moe/layers.py
+++ b/colossalai/legacy/moe/layer/layers.py
@@ -7,9 +7,9 @@
import torch.nn as nn
import torch.nn.functional as F
+from colossalai.legacy.moe.load_balance import LoadBalancer
+from colossalai.legacy.moe.utils import create_ep_hierarchical_group, get_noise_generator
from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
-from colossalai.moe.load_balance import LoadBalancer
-from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size
diff --git a/colossalai/shardformer/layer/moe/routers.py b/colossalai/legacy/moe/layer/routers.py
similarity index 95%
rename from colossalai/shardformer/layer/moe/routers.py
rename to colossalai/legacy/moe/layer/routers.py
index 1be7a27547ed..8088cf44e473 100644
--- a/colossalai/shardformer/layer/moe/routers.py
+++ b/colossalai/legacy/moe/layer/routers.py
@@ -5,9 +5,9 @@
import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import get_activation
+from colossalai.legacy.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.utils import get_activation
+from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size
@@ -118,7 +118,7 @@ def forward(
Returns:
torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
"""
- x = MoeInGradScaler.apply(x, self.ep_size)
+ x = EPGradScalerIn.apply(x, self.ep_size)
e = x.size(1)
h = x.size(-1)
@@ -157,5 +157,5 @@ def forward(
x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
x = x.reshape(inshape)
x = x.transpose(0, 1).contiguous()
- x = MoeOutGradScaler.apply(x, self.ep_size)
+ x = EPGradScalerOut.apply(x, self.ep_size)
return x
diff --git a/colossalai/moe/load_balance.py b/colossalai/legacy/moe/load_balance.py
similarity index 99%
rename from colossalai/moe/load_balance.py
rename to colossalai/legacy/moe/load_balance.py
index 3dc6c02c7445..7339b1a7b0eb 100644
--- a/colossalai/moe/load_balance.py
+++ b/colossalai/legacy/moe/load_balance.py
@@ -7,7 +7,7 @@
from torch.distributed import ProcessGroup
from colossalai.cluster import ProcessGroupMesh
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.zero.low_level import LowLevelZeroOptimizer
diff --git a/colossalai/moe/manager.py b/colossalai/legacy/moe/manager.py
similarity index 100%
rename from colossalai/moe/manager.py
rename to colossalai/legacy/moe/manager.py
diff --git a/examples/language/openmoe/README.md b/colossalai/legacy/moe/openmoe/README.md
similarity index 100%
rename from examples/language/openmoe/README.md
rename to colossalai/legacy/moe/openmoe/README.md
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/colossalai/legacy/moe/openmoe/benchmark/benchmark_cai.py
similarity index 99%
rename from examples/language/openmoe/benchmark/benchmark_cai.py
rename to colossalai/legacy/moe/openmoe/benchmark/benchmark_cai.py
index b9ef915c32a4..5f9447246ae4 100644
--- a/examples/language/openmoe/benchmark/benchmark_cai.py
+++ b/colossalai/legacy/moe/openmoe/benchmark/benchmark_cai.py
@@ -18,9 +18,9 @@
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
+from colossalai.legacy.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.utils import skip_init
from colossalai.moe.layers import apply_load_balance
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/colossalai/legacy/moe/openmoe/benchmark/benchmark_cai.sh
similarity index 100%
rename from examples/language/openmoe/benchmark/benchmark_cai.sh
rename to colossalai/legacy/moe/openmoe/benchmark/benchmark_cai.sh
diff --git a/examples/language/openmoe/benchmark/benchmark_cai_dist.sh b/colossalai/legacy/moe/openmoe/benchmark/benchmark_cai_dist.sh
similarity index 100%
rename from examples/language/openmoe/benchmark/benchmark_cai_dist.sh
rename to colossalai/legacy/moe/openmoe/benchmark/benchmark_cai_dist.sh
diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/colossalai/legacy/moe/openmoe/benchmark/benchmark_fsdp.py
similarity index 98%
rename from examples/language/openmoe/benchmark/benchmark_fsdp.py
rename to colossalai/legacy/moe/openmoe/benchmark/benchmark_fsdp.py
index b00fbd001022..1ae94dd90977 100644
--- a/examples/language/openmoe/benchmark/benchmark_fsdp.py
+++ b/colossalai/legacy/moe/openmoe/benchmark/benchmark_fsdp.py
@@ -14,7 +14,7 @@
from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.manager import MOE_MANAGER
class RandomDataset(Dataset):
diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/colossalai/legacy/moe/openmoe/benchmark/benchmark_fsdp.sh
similarity index 100%
rename from examples/language/openmoe/benchmark/benchmark_fsdp.sh
rename to colossalai/legacy/moe/openmoe/benchmark/benchmark_fsdp.sh
diff --git a/examples/language/openmoe/benchmark/hostfile.txt b/colossalai/legacy/moe/openmoe/benchmark/hostfile.txt
similarity index 100%
rename from examples/language/openmoe/benchmark/hostfile.txt
rename to colossalai/legacy/moe/openmoe/benchmark/hostfile.txt
diff --git a/examples/language/openmoe/benchmark/utils.py b/colossalai/legacy/moe/openmoe/benchmark/utils.py
similarity index 100%
rename from examples/language/openmoe/benchmark/utils.py
rename to colossalai/legacy/moe/openmoe/benchmark/utils.py
diff --git a/examples/language/openmoe/infer.py b/colossalai/legacy/moe/openmoe/infer.py
similarity index 100%
rename from examples/language/openmoe/infer.py
rename to colossalai/legacy/moe/openmoe/infer.py
diff --git a/examples/language/openmoe/infer.sh b/colossalai/legacy/moe/openmoe/infer.sh
similarity index 100%
rename from examples/language/openmoe/infer.sh
rename to colossalai/legacy/moe/openmoe/infer.sh
diff --git a/examples/language/openmoe/model/__init__.py b/colossalai/legacy/moe/openmoe/model/__init__.py
similarity index 100%
rename from examples/language/openmoe/model/__init__.py
rename to colossalai/legacy/moe/openmoe/model/__init__.py
diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.py b/colossalai/legacy/moe/openmoe/model/convert_openmoe_ckpt.py
similarity index 100%
rename from examples/language/openmoe/model/convert_openmoe_ckpt.py
rename to colossalai/legacy/moe/openmoe/model/convert_openmoe_ckpt.py
diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.sh b/colossalai/legacy/moe/openmoe/model/convert_openmoe_ckpt.sh
similarity index 100%
rename from examples/language/openmoe/model/convert_openmoe_ckpt.sh
rename to colossalai/legacy/moe/openmoe/model/convert_openmoe_ckpt.sh
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/colossalai/legacy/moe/openmoe/model/modeling_openmoe.py
similarity index 99%
rename from examples/language/openmoe/model/modeling_openmoe.py
rename to colossalai/legacy/moe/openmoe/model/modeling_openmoe.py
index 1febacd7d226..5d6e91765883 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/colossalai/legacy/moe/openmoe/model/modeling_openmoe.py
@@ -50,8 +50,8 @@
except:
HAS_FLASH_ATTN = False
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import get_activation, set_moe_args
+from colossalai.legacy.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.utils import get_activation, set_moe_args
from colossalai.shardformer.layer.moe import SparseMLP
if HAS_TRITON:
diff --git a/examples/language/openmoe/model/openmoe_8b_config.json b/colossalai/legacy/moe/openmoe/model/openmoe_8b_config.json
similarity index 100%
rename from examples/language/openmoe/model/openmoe_8b_config.json
rename to colossalai/legacy/moe/openmoe/model/openmoe_8b_config.json
diff --git a/examples/language/openmoe/model/openmoe_base_config.json b/colossalai/legacy/moe/openmoe/model/openmoe_base_config.json
similarity index 100%
rename from examples/language/openmoe/model/openmoe_base_config.json
rename to colossalai/legacy/moe/openmoe/model/openmoe_base_config.json
diff --git a/examples/language/openmoe/model/openmoe_policy.py b/colossalai/legacy/moe/openmoe/model/openmoe_policy.py
similarity index 99%
rename from examples/language/openmoe/model/openmoe_policy.py
rename to colossalai/legacy/moe/openmoe/model/openmoe_policy.py
index f46062128563..ccd566b08594 100644
--- a/examples/language/openmoe/model/openmoe_policy.py
+++ b/colossalai/legacy/moe/openmoe/model/openmoe_policy.py
@@ -9,7 +9,7 @@
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
diff --git a/examples/language/openmoe/requirements.txt b/colossalai/legacy/moe/openmoe/requirements.txt
similarity index 100%
rename from examples/language/openmoe/requirements.txt
rename to colossalai/legacy/moe/openmoe/requirements.txt
diff --git a/examples/language/openmoe/test_ci.sh b/colossalai/legacy/moe/openmoe/test_ci.sh
similarity index 100%
rename from examples/language/openmoe/test_ci.sh
rename to colossalai/legacy/moe/openmoe/test_ci.sh
diff --git a/examples/language/openmoe/train.py b/colossalai/legacy/moe/openmoe/train.py
similarity index 99%
rename from examples/language/openmoe/train.py
rename to colossalai/legacy/moe/openmoe/train.py
index ff0e4bad6ee3..0173f0964453 100644
--- a/examples/language/openmoe/train.py
+++ b/colossalai/legacy/moe/openmoe/train.py
@@ -19,7 +19,7 @@
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
-from colossalai.moe.utils import skip_init
+from colossalai.legacy.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer.layer.moe import apply_load_balance
diff --git a/examples/language/openmoe/train.sh b/colossalai/legacy/moe/openmoe/train.sh
similarity index 100%
rename from examples/language/openmoe/train.sh
rename to colossalai/legacy/moe/openmoe/train.sh
diff --git a/colossalai/moe/utils.py b/colossalai/legacy/moe/utils.py
similarity index 99%
rename from colossalai/moe/utils.py
rename to colossalai/legacy/moe/utils.py
index 3d08ab7dd9b0..d91c41363316 100644
--- a/colossalai/moe/utils.py
+++ b/colossalai/legacy/moe/utils.py
@@ -9,7 +9,7 @@
from torch.distributed.distributed_c10d import get_process_group_ranks
from colossalai.accelerator import get_accelerator
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py
index 0623d19efd5f..e69de29bb2d1 100644
--- a/colossalai/moe/__init__.py
+++ b/colossalai/moe/__init__.py
@@ -1,5 +0,0 @@
-from .manager import MOE_MANAGER
-
-__all__ = [
- "MOE_MANAGER",
-]
diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index 01c837ee36ad..ac422a4da98f 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -290,7 +290,7 @@ def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
return torch.cumsum(inputs, dim=0) - 1
-class MoeInGradScaler(torch.autograd.Function):
+class EPGradScalerIn(torch.autograd.Function):
"""
Scale the gradient back by the number of experts
because the batch size increases in the moe stage
@@ -298,8 +298,7 @@ class MoeInGradScaler(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
- if ctx is not None:
- ctx.ep_size = ep_size
+ ctx.ep_size = ep_size
return inputs
@staticmethod
@@ -311,7 +310,7 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
return grad, None
-class MoeOutGradScaler(torch.autograd.Function):
+class EPGradScalerOut(torch.autograd.Function):
"""
Scale the gradient by the number of experts
because the batch size increases in the moe stage
@@ -331,6 +330,50 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
return grad, None
+class DPGradScalerIn(torch.autograd.Function):
+ """
+ Scale the gradient back by the number of experts
+ because the batch size increases in the moe stage
+ """
+
+ @staticmethod
+ def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+ assert activated_experts != 0, f"shouldn't be called when no expert is activated"
+ ctx.moe_dp_size = moe_dp_size
+ ctx.activated_experts = activated_experts
+ return inputs
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+ assert len(grad_outputs) == 1
+ grad = grad_outputs[0]
+ if ctx.moe_dp_size != ctx.activated_experts:
+ grad.mul_(ctx.activated_experts / ctx.moe_dp_size)
+ return grad, None, None
+
+
+class DPGradScalerOut(torch.autograd.Function):
+ """
+ Scale the gradient by the number of experts
+ because the batch size increases in the moe stage
+ """
+
+ @staticmethod
+ def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+ assert activated_experts != 0, f"shouldn't be called when no expert is activated"
+ ctx.moe_dp_size = moe_dp_size
+ ctx.activated_experts = activated_experts
+ return inputs
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+ assert len(grad_outputs) == 1
+ grad = grad_outputs[0]
+ if ctx.moe_dp_size != ctx.activated_experts:
+ grad.mul_(ctx.moe_dp_size / ctx.activated_experts)
+ return grad, None, None
+
+
def _all_to_all(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
@@ -393,4 +436,7 @@ def all_to_all_uneven(
group=None,
overlap: bool = False,
):
+ assert (
+ inputs.requires_grad
+ ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
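The renamed EP scalers and the new DP scalers are identity ops in forward that only rescale gradients, compensating for the batch-size change across expert-parallel and MoE data-parallel groups. A minimal self-contained sketch of the DP-in contract (a toy class for illustration, not the library code):

```python
from typing import Any, Tuple

import torch
from torch import Tensor


class ToyDPGradScalerIn(torch.autograd.Function):
    """Identity in forward; scales the gradient by activated_experts / moe_dp_size."""

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
        assert activated_experts != 0, "should not be called when no expert is activated"
        ctx.moe_dp_size = moe_dp_size
        ctx.activated_experts = activated_experts
        return inputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
        (grad,) = grad_outputs
        if ctx.moe_dp_size != ctx.activated_experts:
            grad = grad * (ctx.activated_experts / ctx.moe_dp_size)
        return grad, None, None


x = torch.ones(4, requires_grad=True)
y = ToyDPGradScalerIn.apply(x, 2, 1)  # moe_dp_size=2, only one replica activated the expert
y.sum().backward()
print(x.grad)  # tensor([0.5000, 0.5000, 0.5000, 0.5000])
```

The paired `DPGradScalerOut` applies the reciprocal factor, so a token's gradient passes through an expert unscaled overall while the optimizer's DP averaging stays correct.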
diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 141baf3d3770..5872c64856b9 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -139,12 +139,11 @@ def prepare_attn_kwargs(
# no padding
assert is_causal
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
- attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
+ attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
+ if s_q != 1:
+ attention_mask = attention_mask.tril(diagonal=0)
+ attention_mask = attention_mask.expand(b, s_q, s_kv)
else:
- assert q_padding_mask.shape == (
- b,
- s_q,
- ), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})"
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
if kv_padding_mask is None:
# self attention
@@ -156,7 +155,7 @@ def prepare_attn_kwargs(
b,
s_kv,
), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
- attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device)
+ attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
outputs.update(
{
"cu_seqlens_q": cu_seqlens_q,
@@ -169,7 +168,8 @@ def prepare_attn_kwargs(
)
if is_causal:
outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
- attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+ if s_q != 1:
+ attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
else:
outputs["attention_mask_type"] = AttnMaskType.PADDED
attention_mask = invert_mask(attention_mask).unsqueeze(1)
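The rewritten branch above builds the padding mask from `kv_padding_mask` broadcast over the query length (the old code expanded `q_padding_mask` with transposed dimensions) and skips the causal triangle when `s_q == 1`, since a single decode-step query may attend to the whole KV cache. A standalone illustration of that shape logic, with made-up sizes:

```python
import torch

b, s_q, s_kv = 2, 4, 6                 # batch, query length, key/value length (illustrative)
kv_padding_mask = torch.ones(b, s_kv)  # 1 = real token, 0 = padding
kv_padding_mask[0, -2:] = 0            # pad the tail of the first sequence

# Broadcast the key/value padding over every query position: (b, s_q, s_kv).
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(torch.float32)

# Apply the causal triangle only for multi-token queries; a single decoding
# step (s_q == 1) attends to all cached positions.
if s_q != 1:
    attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
print(attention_mask.shape)  # torch.Size([2, 4, 6])
```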
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index 0f6595a7c4d6..000934ad91a2 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -695,6 +695,7 @@ def from_native_module(
process_group=process_group,
weight=module.weight,
bias_=module.bias,
+ n_fused=n_fused,
*args,
**kwargs,
)
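Forwarding `n_fused` matters because a fused weight concatenates several projections (e.g. Q, K, V) along the output dimension, and each must be column-sharded independently or tensor-parallel ranks end up holding whole projections instead of slices of each. A generic sketch of that split (not the library's exact routine; sizes are illustrative):

```python
import torch

n_fused, tp_size, tp_rank = 3, 2, 0  # fused QKV across two tensor-parallel ranks
hidden = 8
fused_weight = torch.randn(hidden, n_fused * hidden)  # columns laid out as [Q | K | V]

# Shard each fused projection separately so every rank receives a slice of
# Q, K and V, rather than naively halving the concatenated matrix.
shards = [
    w.chunk(tp_size, dim=1)[tp_rank]
    for w in fused_weight.chunk(n_fused, dim=1)
]
local_weight = torch.cat(shards, dim=1)
print(local_weight.shape)  # torch.Size([8, 12]) == hidden x (n_fused * hidden // tp_size)
```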
diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py
index 759c8d7b8d59..5b36fc7db3b9 100644
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -116,7 +116,7 @@ def command_model_forward(
# for the other stages, hidden_states is the output of the previous stage
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
- mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+ mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
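This shape fix (mirrored below for Llama, Mistral, and Qwen2) encodes that with a KV cache the query axis spans only the tokens of the current forward pass, while the key/value axis spans cache plus new tokens. A quick sanity check of the arithmetic with illustrative numbers:

```python
# Decoding 1 new token with 127 tokens already cached (illustrative values).
batch_size = 2
seq_length = 1                           # queries: new tokens only
seq_length_with_past = seq_length + 127  # keys/values: cache + new tokens

mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
assert mask_shape == (2, 1, 1, 128)

# The old (batch_size, 1, seq_length_with_past, seq_length_with_past) shape
# padded the query axis too, mismatching hidden states of shape
# (batch_size, seq_length, hidden_size).
```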
diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 6e79ce144cc8..a84a3097231a 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -1,21 +1,38 @@
-from typing import List, Optional, Union
+import warnings
+from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
-
-# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
from torch.nn import CrossEntropyLoss
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_attn_mask_utils import (
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
+from colossalai.moe._operation import (
+ DPGradScalerIn,
+ DPGradScalerOut,
+ EPGradScalerIn,
+ EPGradScalerOut,
+ all_to_all_uneven,
+)
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer._operation import (
+ all_to_all_comm,
+ gather_forward_split_backward,
+ split_forward_gather_backward,
+)
+from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
+from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
# copied from modeling_deepseek.py
@@ -42,30 +59,54 @@ def backward(ctx, grad_output):
class EPDeepseekMoE(nn.Module):
def __init__(self):
- super(EPDeepseekMoE, self).__init__()
+ raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
+
+ def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
+ assert tp_group is not None
+ assert moe_dp_group is not None
+ assert ep_group is not None
- def setup_ep(self, ep_group: ProcessGroup):
- ep_group = ep_group
- self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
- self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
+ self.ep_size = dist.get_world_size(ep_group)
+ self.ep_rank = dist.get_rank(ep_group)
self.num_experts = self.config.n_routed_experts
assert self.num_experts % self.ep_size == 0
+
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
+
set_tensors_to_none(self.experts, exclude=set(held_experts))
+
+ # setup moe_dp group
+ self.moe_dp_group = moe_dp_group
+ self.moe_dp_size = moe_dp_group.size()
+
+ # setup tp group
+ self.tp_group = tp_group
+ if self.tp_group.size() > 1:
+ for expert in held_experts:
+ expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
+ expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
+ expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)
+
for p in self.experts.parameters():
- p.ep_group = ep_group
+ set_moe_tensor_ep_group(p, ep_group)
@staticmethod
- def from_native_module(module: Union["DeepseekMoE", "DeepseekMLP"], *args, **kwargs) -> "EPDeepseekMoE":
+ def from_native_module(
+ module,
+ tp_group: ProcessGroup,
+ moe_dp_group: ProcessGroup,
+ ep_group: ProcessGroup,
+ *args,
+ **kwargs,
+ ) -> "EPDeepseekMoE":
LazyInitContext.materialize(module)
if module.__class__.__name__ == "DeepseekMLP":
return module
module.__class__ = EPDeepseekMoE
- assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
- module.setup_ep(kwargs["ep_group"])
+ module.setup_process_groups(tp_group, moe_dp_group, ep_group)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -91,15 +132,24 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
+ with torch.no_grad():
+ activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
+ for i in range(1, self.ep_size):
+ activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
+ activate_experts = (activate_experts > 0).float()
+ dist.all_reduce(activate_experts, group=self.moe_dp_group)
+
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
- output_states = MoeInGradScaler.apply(output_states, self.ep_size)
+ output_states = EPGradScalerIn.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
if self.num_experts_per_ep == 1:
expert = self.experts[self.expert_start_idx]
+ output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0])
output_states = expert(output_states)
+ output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0])
else:
output_states_splits = output_states.split(output_split_sizes.tolist())
output_states_list = []
@@ -107,10 +157,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if split_states.size(0) == 0: # no token routed to this expert
continue
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
+ split_states = DPGradScalerIn.apply(
+ split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
+ )
split_states = expert(split_states)
+ split_states = DPGradScalerOut.apply(
+ split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
+ )
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
- output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
+ output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
recover_token_idx = torch.empty_like(flat_topk_token_idx)
recover_token_idx[flat_topk_token_idx] = torch.arange(
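`output_split_sizes` is laid out as `ep_size` consecutive blocks of `num_experts_per_ep` counts, so summing the blocks yields per-local-expert token counts; the subsequent `all_reduce` over `moe_dp_group` turns the 0/1 indicator into the number of DP replicas that activated each expert, which is exactly what the `DPGradScaler*` calls consume. A single-process emulation of that bookkeeping (illustrative values, no process groups):

```python
import torch

ep_size, num_experts_per_ep = 2, 2  # illustrative expert-parallel layout
# Tokens each remote rank routed to our local experts, laid out as
# [rank0->expert0, rank0->expert1, rank1->expert0, rank1->expert1].
output_split_sizes = torch.tensor([3, 0, 1, 0])

activate_experts = output_split_sizes[:num_experts_per_ep].clone()
for i in range(1, ep_size):
    activate_experts += output_split_sizes[i * num_experts_per_ep : (i + 1) * num_experts_per_ep]
activate_experts = (activate_experts > 0).float()
print(activate_experts)  # tensor([1., 0.]): expert 0 received tokens, expert 1 did not
# In the real forward, dist.all_reduce(activate_experts, group=moe_dp_group)
# then sums these indicators across DP replicas.
```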
@@ -310,7 +366,14 @@ def custom_forward(*inputs):
next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage():
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
# always return dict for intermediate stages
return {
"hidden_states": hidden_states,
@@ -427,3 +490,265 @@ def deepseek_for_causal_lm_forward(
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
return out
+
+
+def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+ logger = logging.get_logger(__name__)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if sp_mode is not None:
+ assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+ assert (sp_size is not None) and (
+ sp_group is not None
+ ), "Must specify sp_size and sp_group for sequence parallel"
+
+ # DeepseekFlashAttention2 attention does not support output_attentions
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
+
+ # overwrite attention_mask with padding_mask
+ attention_mask = kwargs.pop("padding_mask")
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # sp: modify sp_len when sequence parallel mode is ring
+ if sp_mode in ["split_gather", "ring"]:
+ q_len *= sp_size
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # sp: all-to-all communication when introducing sequence parallel
+ if sp_mode == "all_to_all":
+ query_states = all_to_all_comm(query_states, sp_group)
+ key_states = all_to_all_comm(key_states, sp_group)
+ value_states = all_to_all_comm(value_states, sp_group)
+ bsz, q_len, _ = query_states.size()
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids, unsqueeze_dim=0
+ )
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # in fp32. (DeepseekRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ # Handle the case where the model is quantized
+ if hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ elif torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+ attn_output = self._flash_attention_forward(
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+ )
+ # sp: all-to-all communication when introducing sequence parallel
+ if sp_mode == "all_to_all":
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128)
+ attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256)
+ else:
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ return forward
+
+
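In the `all_to_all` sequence-parallel branch, the attention output of shape `(bsz, q_len, num_heads * head_dim)` is scattered along the sequence dimension and gathered along the hidden dimension, matching the inline `(1, 8, 128) -> (1, 4, 256)` comments for `sp_size=2`. A shape-level, single-process stand-in for that reshuffle (the real `all_to_all_comm` is a collective over `sp_group`):

```python
import torch

sp_size = 2
attn_output = torch.randn(1, 8, 128)  # (bsz, q_len, num_heads * head_dim) on one rank

# Emulate scatter_dim=1, gather_dim=2: keep 1/sp_size of the sequence and
# concatenate sp_size hidden shards. Only the shapes are meaningful here;
# the actual values come from peer ranks.
seq_shards = attn_output.chunk(sp_size, dim=1)
local_out = torch.cat(seq_shards, dim=2)
print(local_out.shape)  # torch.Size([1, 4, 256])
```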
+def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+ logger = logging.get_logger(__name__)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape[:2]
+ elif inputs_embeds is not None:
+ batch_size, seq_length = inputs_embeds.shape[:2]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers."
+ )
+ use_cache = False
+
+ past_key_values_length = 0
+ if use_cache:
+ use_legacy_cache = not isinstance(past_key_values, Cache)
+ if use_legacy_cache:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if self._use_flash_attention_2:
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self._use_sdpa and not output_attentions:
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ if sp_mode in ["ring", "split_gather"]:
+ inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+ elif sp_mode == "all_to_all":
+ inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if sp_mode == "ring" or sp_mode == "split_gather":
+ hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+ elif sp_mode == "all_to_all":
+ hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if use_cache:
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ return forward
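The model-level forward splits `inputs_embeds` along the sequence dimension on entry and gathers `hidden_states` back after the decoder stack; in backward the roles swap, and the `all_to_all` mode additionally rescales gradients by `1 / sp_size` and `sp_size` as passed above. A shape-only sketch of the forward-path bookkeeping (the real helpers are collectives over `sp_group`):

```python
import torch

sp_size, sp_rank = 4, 0                 # illustrative sequence-parallel layout
inputs_embeds = torch.randn(2, 16, 64)  # (batch, seq, hidden) before splitting

# split_forward_gather_backward: each rank keeps one slice of the sequence.
local_embeds = inputs_embeds.chunk(sp_size, dim=1)[sp_rank]
print(local_embeds.shape)  # torch.Size([2, 4, 64])

hidden_states = local_embeds  # ... decoder layers run on the local slice ...

# gather_forward_split_backward: ranks reassemble the full sequence at the end
# (torch.cat stands in for the all-gather across sp_group).
full_hidden = torch.cat([hidden_states] * sp_size, dim=1)
print(full_hidden.shape)  # torch.Size([2, 16, 64])
```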
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 54ff8e321e06..9ffbca517d4c 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -643,7 +643,7 @@ def forward(
# in this case, attention_mask is a dict rather than a tensor
if shard_config.enable_flash_attention:
- mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
+ mask_shape = (inputs_embeds.shape[0], 1, seq_len, past_seen_tokens + seq_len)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
inputs_embeds.dtype,
diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py
index 82e8ef5f9af7..ec1a8a00a58a 100644
--- a/colossalai/shardformer/modeling/mistral.py
+++ b/colossalai/shardformer/modeling/mistral.py
@@ -91,7 +91,7 @@ def mistral_model_forward(
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
- mask_shape = (batch_size, 1, seq_length, seq_length)
+ mask_shape = (batch_size, 1, seq_length, seq_length + past_key_values_length)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 2fbc34302cde..d30ce5ea85cc 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -1,52 +1,105 @@
-from typing import List, Optional
+import inspect
+import warnings
+from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ProcessGroup
-
-# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
from torch.nn import CrossEntropyLoss
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_attn_mask_utils import (
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
from transformers.models.mixtral.modeling_mixtral import (
MixtralSparseMoeBlock,
MoeCausalLMOutputWithPast,
+ MoeModelOutputWithPast,
+ apply_rotary_pos_emb,
load_balancing_loss_func,
+ repeat_kv,
)
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
+from colossalai.moe._operation import (
+ DPGradScalerIn,
+ DPGradScalerOut,
+ EPGradScalerIn,
+ EPGradScalerOut,
+ all_to_all_uneven,
+)
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer._operation import (
+ all_to_all_comm,
+ gather_forward_split_backward,
+ split_forward_gather_backward,
+)
+from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
+from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
+
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func
+
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
- def __init__(self, config):
- self.moe_info = None
- super().__init__(config)
-
- def setup_ep(self, ep_group: ProcessGroup):
- ep_group = ep_group
- self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
- self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
- assert self.num_experts % self.ep_size == 0
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
+
+ def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
+ assert tp_group is not None
+ assert moe_dp_group is not None
+ assert ep_group is not None
+
+ # setup ep group
+ self.ep_size = dist.get_world_size(ep_group)
+ self.ep_rank = dist.get_rank(ep_group)
self.ep_group = ep_group
+
+ if self.num_experts % self.ep_size != 0:
+ raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
+
self.num_experts_per_ep = self.num_experts // self.ep_size
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
+
set_tensors_to_none(self.experts, exclude=set(held_experts))
+
+ # setup moe_dp group
+ self.moe_dp_group = moe_dp_group
+ self.moe_dp_size = moe_dp_group.size()
+
+ # setup global tp group
+ self.tp_group = tp_group
+ if self.tp_group.size() > 1:
+ for expert in held_experts:
+ expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
+ expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
+ expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)
+
for p in self.experts.parameters():
- p.ep_group = ep_group
+ set_moe_tensor_ep_group(p, ep_group)
@staticmethod
- def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
+ def from_native_module(
+ module: MixtralSparseMoeBlock,
+ tp_group: ProcessGroup,
+ moe_dp_group: ProcessGroup,
+ ep_group: ProcessGroup,
+ *args,
+ **kwargs,
+ ) -> "EPMixtralSparseMoeBlock":
+ # TODO: better init
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
- # if "ep_group" in kwargs:
- assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
- module.setup_ep(kwargs["ep_group"])
+ module.setup_process_groups(tp_group, moe_dp_group, ep_group)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -65,20 +118,31 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
selected_experts_idx = selected_experts.argsort()
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
+
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
+ with torch.no_grad():
+ activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
+ for i in range(1, self.ep_size):
+ activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
+ activate_experts = (activate_experts > 0).float()
+ dist.all_reduce(activate_experts, group=self.moe_dp_group)
+
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
# compute expert output
- output_states = MoeInGradScaler.apply(output_states, self.ep_size)
+ output_states = EPGradScalerIn.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
if self.num_experts_per_ep == 1:
# no need to split
expert = self.experts[self.expert_start_idx]
+ output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0])
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
output_states = expert.w2(output_states)
+ output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0])
else:
output_states_splits = output_states.split(output_split_sizes.tolist())
output_states_list = []
@@ -86,12 +150,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if split_states.size(0) == 0:
continue
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
+ split_states = DPGradScalerIn.apply(
+ split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
+ )
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
split_states = expert.w2(split_states)
+ split_states = DPGradScalerOut.apply(
+ split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
+ )
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
- output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
+
+ output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(
selected_experts_idx.size(0), device=selected_experts_idx.device
@@ -107,7 +179,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class MixtralPipelineForwards:
"""
- This class serves as a micro library for forward function substitution of Llama models
+ This class serves as a micro library for forward function substitution of Mixtral models
under pipeline setting.
"""
@@ -300,16 +372,29 @@ def custom_forward(*inputs):
if output_router_logits and past_router_logits is not None:
all_router_logits = past_router_logits + all_router_logits
if stage_manager.is_last_stage():
- return tuple(
- v
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
- if v is not None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+ if v is not None
+ )
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ router_logits=all_router_logits,
)
- # always return dict for imediate stage
- return {
- "hidden_states": hidden_states,
- "past_router_logits": all_router_logits,
- }
+ else:
+ if output_router_logits:
+ return {
+ "hidden_states": hidden_states,
+ "past_router_logits": all_router_logits,
+ }
+ else:
+ return {
+ "hidden_states": hidden_states,
+ }
@staticmethod
def mixtral_for_causal_lm_forward(
@@ -441,3 +526,335 @@ def mixtral_for_causal_lm_forward(
if output_router_logits:
out["past_router_logits"] = outputs["past_router_logits"]
return out
+
+
+def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+ logger = logging.get_logger(__name__)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+ if sp_mode is not None:
+ assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+ assert (sp_size is not None) and (
+ sp_group is not None
+ ), "Must specify sp_size and sp_group for sequence parallel"
+
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
+
+ # overwrite attention_mask with padding_mask
+ attention_mask = kwargs.pop("padding_mask")
+ bsz, q_len, _ = hidden_states.size()
+
+ # sp: modify sp_len when sequence parallel mode is ring
+ if sp_mode in ["split_gather", "ring"]:
+ q_len *= sp_size
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # sp: all-to-all communication when introducing sequence parallel
+ if sp_mode == "all_to_all":
+ query_states = all_to_all_comm(query_states, sp_group)
+ key_states = all_to_all_comm(key_states, sp_group)
+ value_states = all_to_all_comm(value_states, sp_group)
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ use_sliding_windows = (
+ _flash_supports_window_size
+ and getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ )
+ if not _flash_supports_window_size:
+ logger.warning_once(
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
+ " make sure to upgrade flash-attn library."
+ )
+ if past_key_value is not None:
+ # Activate cache slicing only if the config has a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ attn_output = self._flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=dropout_rate,
+ use_sliding_windows=use_sliding_windows,
+ )
+
+ # sp: all-to-all communication when introducing sequence parallel
+ if sp_mode == "all_to_all":
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128)
+ attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256)
+ else:
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+ return attn_output, attn_weights, past_key_value
+
+ return forward
+
+
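The sliding-window branch above keeps only the most recent `sliding_window - 1` cached positions before appending the new key/value states; `slicing_tokens = 1 - self.config.sliding_window` is a negative index counting from the end of the cache. A toy check of that slicing with made-up sizes:

```python
import torch

sliding_window = 4                    # illustrative config value
past_key = torch.randn(1, 8, 10, 16)  # (bsz, num_heads, cached_len, head_dim)

slicing_tokens = 1 - sliding_window   # -3: keep the last sliding_window - 1 entries
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
assert past_key.shape[-2] == sliding_window - 1
# After the new token's key is appended, the cache again holds exactly
# sliding_window positions.
```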
+def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+ logger = logging.get_logger(__name__)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ past_key_values_length = 0
+
+ if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+ if use_cache:
+ use_legacy_cache = not isinstance(past_key_values, Cache)
+ if use_legacy_cache:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
+ if is_padding_right:
+ raise ValueError(
+ "You are attempting to perform batched generation with padding_side='right'"
+ " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+ )
+ if self._attn_implementation == "flash_attention_2":
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self._attn_implementation == "sdpa" and not output_attentions:
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )
+
+ if sp_mode in ["ring", "split_gather"]:
+ inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+ elif sp_mode == "all_to_all":
+ inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_router_logits = () if output_router_logits else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ output_router_logits,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ output_router_logits=output_router_logits,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if output_router_logits:
+ all_router_logits += (layer_outputs[-1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if sp_mode == "ring" or sp_mode == "split_gather":
+ hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+ elif sp_mode == "all_to_all":
+ hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if use_cache:
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+ if v is not None
+ )
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ router_logits=all_router_logits,
+ )
+
+ return forward
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 55822b1505f1..538e96c32c6d 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -136,7 +136,7 @@ def qwen2_model_forward(
# for the other stages, hidden_states is the output of the previous stage
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
- mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+ mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
@@ -651,6 +651,10 @@ def forward(
seq_length_with_past = seq_length
past_key_values_length = 0
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
@@ -668,7 +672,7 @@ def forward(
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
- mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+ mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
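The two `mask_shape` changes above are easiest to verify with concrete numbers: when a KV cache holds `past_key_values_length` tokens, the query axis of the causal mask only covers the `seq_length` new tokens, while the key axis spans the full `seq_length_with_past`. A minimal sketch with made-up sizes (plain PyTorch, not ColossalAI code):

```python
# Illustrative only: why the mask shrinks to (batch, 1, seq_length, seq_length_with_past).
import torch

batch_size, seq_length, past_len = 2, 4, 6      # hypothetical sizes
seq_length_with_past = seq_length + past_len    # mirrors past_key_values[0][0].shape[2] + seq_length

mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
# row i (the i-th new token) may attend to every cached token plus new tokens 0..i
mask = torch.ones(mask_shape).tril(diagonal=past_len).bool()
assert mask.shape == (2, 1, 4, 10)
```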
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index ae9f3603c96e..7b9c759a66c2 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -161,7 +161,7 @@ class PolicyLocation:
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
),
# Deepseek
- "transformers_modules.modeling_deepseek.DeepSeekModel": PolicyLocation(
+ "transformers_modules.modeling_deepseek.DeepseekModel": PolicyLocation(
file_name="deepseek", class_name="DeepseekModelPolicy"
),
"transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
@@ -200,6 +200,9 @@ class PolicyLocation:
"transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM": PolicyLocation(
file_name="mixtral", class_name="MixtralForCausalLMPolicy"
),
+ "transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification": PolicyLocation(
+ file_name="mixtral", class_name="MixtralForSequenceClassificationPolicy"
+ ),
# Qwen2
"transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation(
file_name="qwen2", class_name="Qwen2ModelPolicy"
@@ -240,6 +243,9 @@ def _fullname(obj):
# patch custom models which are not in transformers
# it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub)
# or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory)
+ if module.startswith("peft"):
+ klass = obj.base_model.model.__class__
+ module = klass.__module__
if module.startswith("transformers_modules"):
split_module = module.split(".")
if len(split_module) >= 2:
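The new `peft` branch works because a PEFT wrapper keeps the underlying transformers module reachable at `base_model.model`, so `_fullname` can resolve the policy from the wrapped class. A hedged sketch of that resolution, using `gpt2` purely as a small stand-in model:

```python
# Illustrative only: how a peft-wrapped model exposes its base class.
# Assumes `peft` and `transformers` are installed; gpt2 is just a stand-in.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))

assert peft_model.__class__.__module__.startswith("peft")
klass = peft_model.base_model.model.__class__   # the wrapped GPT2LMHeadModel
print(klass.__module__)                         # transformers.models.gpt2.modeling_gpt2
```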
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index 8ebda357b380..605f69c4a632 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -1,13 +1,20 @@
-import warnings
from functools import partial
from typing import Callable, Dict, List, Union
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
+from transformers.utils import is_flash_attn_greater_or_equal_2_10
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
-from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
+from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
+from colossalai.shardformer.layer.linear import Linear1D_Row
+from colossalai.shardformer.modeling.deepseek import (
+ DeepseekPipelineForwards,
+ EPDeepseekMoE,
+ get_deepseek_flash_attention_forward,
+ get_deepseek_flash_attention_model_forward,
+)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]
@@ -18,6 +25,13 @@ def config_sanity_check(self):
pass
def preprocess(self):
+ self.tie_weight = self.tie_weight_check()
+ self.origin_attn_implement = self.model.config._attn_implementation
+ """
+ Due to a bug in the transformers library's AutoModel/AutoConfig, "attn_implementation" is popped twice
+ (once in modeling_utils.py and once in configuration_utils.py), which causes attn_cls to fall back to sdpa.
+ If flash attention is needed, assign "flash_attention_2" here.
+ """
+ # self.origin_attn_implement = "flash_attention_2"
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
@@ -30,25 +44,118 @@ def preprocess(self):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+
+ ATTN_IMPLEMENTATION = {
+ "eager": "DeepseekAttention",
+ "flash_attention_2": "DeepseekFlashAttention2",
+ "sdpa": "DeepseekSdpaAttention",
+ }
policy = {}
+ attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+ sp_mode = self.shard_config.sequence_parallelism_mode or None
+ sp_size = self.shard_config.sequence_parallel_size or None
+ sp_group = self.shard_config.sequence_parallel_process_group or None
+ sp_partial_derived = sp_mode in ["split_gather", "ring"]
+ if sp_mode == "all_to_all":
+ decoder_attribute_replacement = {
+ "num_heads": self.model.config.num_attention_heads // sp_size,
+ }
+ if getattr(self.model.config, "num_key_value_heads", False):
+ decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+ policy[attn_cls] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ )
if self.shard_config.enable_sequence_parallelism:
- self.shard_config.enable_sequence_parallelism = False
- raise NotImplementedError(
- "Deepseek dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
+ if self.pipeline_stage_manager is not None:
+ # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
+ # if both are enabled, one of them will be ignored
+ raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.")
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_deepseek_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
+ },
+ policy=policy,
+ target_key=attn_cls,
)
-
+ if self.pipeline_stage_manager is None:
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_deepseek_flash_attention_model_forward(
+ self.shard_config,
+ sp_mode=sp_mode,
+ sp_size=sp_size,
+ sp_group=sp_group,
+ ),
+ },
+ policy=policy,
+ target_key="DeepseekModel",
+ )
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = PaddingEmbedding
if self.shard_config.enable_tensor_parallelism:
- raise NotImplementedError("Tensor parallelism is not supported for Deepseek model now.")
+ # tensor parallelism for non-moe params
+ assert (
+ self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
+ assert (
+ self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of key_value heads must be divisible by tensor parallel size."
+ decoder_attribute_replacement = {
+ "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+ // self.shard_config.tensor_parallel_size,
+ }
+
+ policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attn.q_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.k_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.v_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.o_proj",
+ target_module=Linear1D_Row,
+ ),
+ ],
+ )
+ if embedding_cls is not None:
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="embed_tokens",
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ ),
+ policy=policy,
+ target_key="DeepseekModel",
+ )
- if getattr(self.shard_config, "ep_group", None) is not None:
+ if self.shard_config.ep_group:
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="mlp",
target_module=EPDeepseekMoE,
- kwargs={"ep_group": self.shard_config.ep_group},
+ kwargs={
+ "ep_group": self.shard_config.ep_group,
+ "tp_group": self.shard_config.tensor_parallel_process_group,
+ "moe_dp_group": self.shard_config.moe_dp_group,
+ },
)
],
policy=policy,
@@ -62,10 +169,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,
@@ -76,17 +185,39 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
policy=policy,
target_key="DeepseekModel",
)
if self.shard_config.enable_flash_attention:
- warnings.warn(
- "Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
+ # NOTE: AutoModel, which deepseek currently has to be loaded through, has a bug when toggling flash attention
+ from transformers.dynamic_module_utils import get_class_from_dynamic_module
+
+ flash_attn_cls = get_class_from_dynamic_module(
+ "deepseek-ai/deepseek-moe-16b-base--modeling_deepseek.DeepseekFlashAttention2",
+ "deepseek-ai/deepseek-moe-16b-base",
)
- self.shard_config.enable_flash_attention = False
+ class TargetFlashAttn:
+ def __init__(self):
+ raise RuntimeError("This class should not be instantiated")
+
+ @staticmethod
+ def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
+ original_attn.__class__ = flash_attn_cls
+ original_attn._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ return original_attn
+
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="self_attn",
+ target_module=TargetFlashAttn,
+ ),
+ policy=policy,
+ target_key="DeepseekDecoderLayer",
+ )
return policy
def postprocess(self):
@@ -96,6 +227,10 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
+ if self.shard_config.enable_sequence_parallelism:
+ # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
+ # if both are enabled, one of them will be ignored
+ raise NotImplementedError("Pipeline parallelism is not supported with sequence parallelism.")
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "DeepseekModel":
module = self.model
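`TargetFlashAttn.from_native_module` above leans on a plain-Python mechanism: reassigning an instance's `__class__` swaps its methods (here, the attention forward) while leaving every parameter and buffer untouched. A toy sketch with hypothetical classes:

```python
# Toy illustration of the __class__ swap; EagerAttn/FastAttn are made-up stand-ins.
import torch.nn as nn

class EagerAttn(nn.Module):
    def forward(self, x):
        return x          # reference attention path

class FastAttn(EagerAttn):
    def forward(self, x):
        return x          # pretend this dispatches to a flash-attention kernel

attn = EagerAttn()
attn.__class__ = FastAttn  # same weights and buffers, new forward method
assert type(attn) is FastAttn
```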
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index ad93e94694c8..10df143c99da 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -1,13 +1,21 @@
+import warnings
from functools import partial
from typing import Callable, Dict, List, Union
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
+from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
-from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
+from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
+from colossalai.shardformer.layer.linear import Linear1D_Row
+from colossalai.shardformer.modeling.mixtral import (
+ EPMixtralSparseMoeBlock,
+ MixtralPipelineForwards,
+ get_mixtral_flash_attention_forward,
+ get_mixtral_flash_attention_model_forward,
+)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
@@ -18,36 +26,136 @@ def config_sanity_check(self):
pass
def preprocess(self):
- if self.shard_config.enable_tensor_parallelism:
- # Resize embedding
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
-
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
-
+ self.tie_weight = self.tie_weight_check()
+ self.origin_attn_implement = self.model.config._attn_implementation
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+ from transformers.models.mixtral.modeling_mixtral import (
+ MixtralAttention,
+ MixtralDecoderLayer,
+ MixtralFlashAttention2,
+ MixtralModel,
+ MixtralSdpaAttention,
+ )
+
+ ATTN_IMPLEMENTATION = {
+ "eager": MixtralAttention,
+ "flash_attention_2": MixtralFlashAttention2,
+ "sdpa": MixtralSdpaAttention,
+ }
policy = {}
+ attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+
+ sp_mode = self.shard_config.sequence_parallelism_mode or None
+ sp_size = self.shard_config.sequence_parallel_size or None
+ sp_group = self.shard_config.sequence_parallel_process_group or None
+ sp_partial_derived = sp_mode in ["split_gather", "ring"]
+ if sp_mode == "all_to_all":
+ decoder_attribute_replacement = {
+ "num_heads": self.model.config.num_attention_heads // sp_size,
+ }
+ if getattr(self.model.config, "num_key_value_heads", False):
+ decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+ policy[attn_cls] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ )
if self.shard_config.enable_sequence_parallelism:
- self.shard_config.enable_sequence_parallelism = False
- raise NotImplementedError(
- "Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
+ if self.pipeline_stage_manager is not None:
+ # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
+ # if both are enabled, one of them will be ignored
+ raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.")
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_mixtral_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
+ },
+ policy=policy,
+ target_key=attn_cls,
+ )
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_mixtral_flash_attention_model_forward(
+ self.shard_config,
+ sp_mode=sp_mode,
+ sp_size=sp_size,
+ sp_group=sp_group,
+ ),
+ },
+ policy=policy,
+ target_key=MixtralModel,
)
+ embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
- raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
- if getattr(self.shard_config, "ep_group", None) is not None:
+ embedding_cls = VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = PaddingEmbedding
+
+ if self.shard_config.enable_tensor_parallelism:
+ # tensor parallelism for non-moe params
+ assert (
+ self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
+ assert (
+ self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of key_value heads must be divisible by tensor parallel size."
+ decoder_attribute_replacement = {
+ "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+ // self.shard_config.tensor_parallel_size,
+ }
+
+ policy[MixtralDecoderLayer] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attn.q_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.k_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.v_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.o_proj",
+ target_module=Linear1D_Row,
+ ),
+ SubModuleReplacementDescription( # alternatively, the gate could be replicated instead of column-sharded
+ suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True}
+ ),
+ ],
+ )
+
+ if embedding_cls is not None:
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="embed_tokens",
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ ),
+ policy=policy,
+ target_key=MixtralModel,
+ )
+
+ if self.shard_config.ep_group:
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="block_sparse_moe",
target_module=EPMixtralSparseMoeBlock,
- kwargs={"ep_group": self.shard_config.ep_group},
+ kwargs={
+ "ep_group": self.shard_config.ep_group,
+ "tp_group": self.shard_config.tensor_parallel_process_group,
+ "moe_dp_group": self.shard_config.moe_dp_group,
+ },
)
],
policy=policy,
@@ -61,10 +169,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,
@@ -75,13 +185,15 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
policy=policy,
target_key=MixtralModel,
)
if self.shard_config.enable_flash_attention:
- raise NotImplementedError("Flash attention has already been replaced in mixtral.")
+ warnings.warn("Flash attention is natively supported in transformers, will ignore the flag.")
+ self.shard_config.enable_flash_attention = False
return policy
@@ -92,6 +204,10 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
+ if self.shard_config.enable_sequence_parallelism:
+ # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
+ # if both are enabled, one of them will be ignored
+ raise NotImplementedError("Pipeline parallelism is not supported with sequence parallelism.")
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "MixtralModel":
module = self.model
@@ -150,7 +266,7 @@ def get_held_layers(self) -> List[Module]:
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
- """No shared params in llama model"""
+ """No shared params in mixtral model"""
return []
@@ -206,3 +322,40 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
}
]
return []
+
+
+class MixtralForSequenceClassificationPolicy(MixtralPolicy):
+ def module_policy(self):
+ from transformers import MixtralForSequenceClassification
+
+ policy = super().module_policy()
+
+ if self.shard_config.enable_tensor_parallelism:
+ # add a new item for sequence classification
+ new_item = {
+ MixtralForSequenceClassification: ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+ )
+ ]
+ )
+ }
+ policy.update(new_item)
+
+ if self.pipeline_stage_manager:
+ raise NotImplementedError
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_last_stage(ignore_chunk=True):
+ held_layers.append(self.model.score)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in mixtral for sequence classification model"""
+ return []
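As a rough usage sketch, the new policy is reached through ShardFormer's auto dispatch once the class is registered in `auto_policy.py`. The snippet below is a hedged sketch: it assumes a distributed context (e.g. launched via `colossalai run`) and uses a deliberately tiny config.

```python
# Hedged sketch: exercising MixtralForSequenceClassificationPolicy via auto dispatch.
import colossalai
from transformers import MixtralConfig, MixtralForSequenceClassification
from colossalai.shardformer import ShardConfig, ShardFormer

colossalai.launch_from_torch()

config = MixtralConfig(hidden_size=32, intermediate_size=32, num_attention_heads=8,
                       num_hidden_layers=2, vocab_size=1000)
model = MixtralForSequenceClassification(config)

shard_config = ShardConfig(enable_tensor_parallelism=True)
sharded_model, _ = ShardFormer(shard_config=shard_config).optimize(model)
```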
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index b64300366fc3..163d7a7bbb0c 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -46,7 +46,11 @@ class ShardConfig:
make_vocab_size_divisible_by: int = 64
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
+
+ # for moe related
+ moe_dp_group: Optional[ProcessGroup] = None
ep_group: Optional[ProcessGroup] = None
+
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
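A short sketch of how the two MoE-related groups might be wired up; the rank layout is purely illustrative and assumes eight already-initialized processes:

```python
# Illustrative wiring of the new MoE process groups (rank layout is made up).
import colossalai
import torch.distributed as dist
from colossalai.shardformer import ShardConfig

colossalai.launch_from_torch()  # assumes `torchrun`/`colossalai run` with 8 processes

ep_group = dist.new_group(ranks=[0, 1, 2, 3])   # experts parallelized over ranks 0-3
moe_dp_group = dist.new_group(ranks=[0, 4])     # MoE data-parallel replicas

shard_config = ShardConfig(enable_tensor_parallelism=False,
                           ep_group=ep_group, moe_dp_group=moe_dp_group)
```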
diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py
index e24a67f9de3c..8b6d403f1327 100644
--- a/colossalai/zero/low_level/bookkeeping/gradient_store.py
+++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py
@@ -19,7 +19,6 @@ def __init__(self, *args, partition_grad: bool = False):
"""
self._grads_of_params = dict()
# stage 2
- self._partition_grads = partition_grad
self._working_index = 0 if partition_grad else self._local_rank
# for zero2, it's `param_id: [grad_local_rank]`
self.grad_to_param_mapping = dict()
@@ -91,7 +90,7 @@ def get_working_grads_by_group_id(self, group_id: int) -> List:
return grad_list
- def get_working_grad_by_param_id(self, param_id) -> Tensor:
+ def get_working_grad_by_param_id(self, param_id) -> Optional[Tensor]:
"""
Return the working gradient for the specified parameter.
@@ -112,6 +111,7 @@ def reset_grads_by_group_id(self, group_id: int):
def reset_all_gradients(self):
self._grads_of_params = dict()
+ self.grad_to_param_mapping = dict()
def get_param_id_for_grad(self, grad: Tensor) -> Optional[int]:
"""Return the id of a parameter which the gradient slice belongs to
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 6ff235b96a5d..51d7d1eaaa33 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -20,6 +20,7 @@
)
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, TensorBucket
@@ -66,7 +67,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def __init__(
self,
optimizer: Optimizer,
- pg_to_param_list: Dict[ProcessGroup, List[nn.Parameter]] = None,
+ pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.0,
@@ -92,7 +93,7 @@ def __init__(
self._logger = get_dist_logger()
self._verbose = verbose
- if dp_process_group is not None and pg_to_param_list is not None:
+ if (dp_process_group is not None) and (pg_to_param_list is not None):
raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.")
if pg_to_param_list is None:
@@ -338,14 +339,14 @@ def _run_reduction(self):
self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id)
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
- recieved_grad = torch.zeros_like(flat_grads_list[0])
- dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
+ received_grad = torch.zeros_like(flat_grads_list[0])
+ dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
- if recieved_grad.dtype != grad_dtype:
- recieved_grad = recieved_grad.to(grad_dtype)
+ if received_grad.dtype != grad_dtype:
+ received_grad = received_grad.to(grad_dtype)
grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.local_rank]
- self._update_partitoned_grad(bucket_store, grad_in_bucket_current_rank, recieved_grad, group_id, 1)
+ self._update_partitoned_grad(bucket_store, grad_in_bucket_current_rank, received_grad, group_id, 1)
bucket_store.reset()
@@ -649,6 +650,11 @@ def _sync_grad(self):
for group_id in range(self.num_param_groups):
param_group = self._working_param_groups[group_id]
for param in param_group:
+ if is_moe_tensor(param) and param.requires_grad and param.grad is None:
+ # TODO: find a better way of doing this
+ # assign zero grad to unrouted expert to avoid hang during grad reduction
+ param.grad = torch.zeros_like(param)
+
if param.requires_grad and param.grad is not None:
self._add_to_bucket(param, group_id)
@@ -807,8 +813,8 @@ def update_master_params(self, model: nn.Module) -> None:
"""
for p in model.parameters():
p_id = id(p)
- pg = self.param_to_pg[p]
if p_id in self.working_to_master_param:
+ pg = self.param_to_pg[p]
master_param = self.working_to_master_param[p_id]
padding_size = self.get_param_padding_size(p)
working_param = p.data.view(-1)
@@ -869,13 +875,12 @@ def get_padding_map(self) -> Dict[int, Tensor]:
def get_param_grad(self, working_param: nn.Parameter) -> Tensor:
grad_store = self.pid_to_grad_store[id(working_param)]
- partial_grad = grad_store.get_working_grad_by_param_id(id(working_param))
- if partial_grad is None:
+ grad = grad_store.get_working_grad_by_param_id(id(working_param))
+ if grad is None:
return None
- tensor_list = [torch.empty_like(partial_grad) for _ in range(grad_store.world_size)]
- dist.all_gather(tensor_list, partial_grad, group=grad_store.torch_pg)
- grad_flat = torch.cat(tensor_list, dim=0)
- return grad_flat[: working_param.numel()].reshape_as(working_param)
+ grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
+ dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
+ return grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
working_grads = []
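The rewritten `get_param_grad` avoids building a Python list of shards: `all_gather_into_tensor` writes every rank's slice straight into one preallocated buffer, which is then trimmed to the parameter's true size. A standalone sketch of the pattern (sizes are illustrative; NCCL is required, since Gloo does not support this collective):

```python
# Minimal sketch of gathering a padded gradient shard back into parameter shape.
# Run with e.g. `torchrun --nproc_per_node=2 this_script.py`.
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

param_numel = 5                                       # true parameter size (not divisible by 2)
shard = torch.full((3,), float(rank), device="cuda")  # padded local shard: 2 ranks x 3 >= 5

flat = torch.empty(dist.get_world_size() * shard.numel(), device="cuda")
dist.all_gather_into_tensor(flat, shard)
grad = flat[:param_numel]                             # trim the padding, as get_param_grad does
```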
diff --git a/docs/source/en/basics/launch_colossalai.md b/docs/source/en/basics/launch_colossalai.md
index 8a6028d6c49a..32748dae1163 100644
--- a/docs/source/en/basics/launch_colossalai.md
+++ b/docs/source/en/basics/launch_colossalai.md
@@ -131,17 +131,18 @@ with one simple command. There are two ways you can launch multi-node jobs.
This is suitable when you only have a few nodes. Let's say I have two nodes, namely `host1` and `host2`, I can start
multi-node training with the following command. Compared to single-node training, you must specify the `master_addr`
-option, which is auto-set to localhost if running on a single node only.
+option, which is auto-set to localhost if running on a single node only. \
+Additionally, you must ensure that all nodes use the same open SSH port, which can be specified with `--ssh-port`.
:::caution
-`master_addr` cannot be localhost when running on multiple nodes, it should be the hostname or IP address of a node.
+`master_addr` cannot be localhost when running on multiple nodes, it should be the **hostname or IP address** of a node.
:::
```shell
# run on these two nodes
-colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py
+colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py --ssh-port 22
```
- Run with `--hostfile`
diff --git a/docs/source/zh-Hans/basics/launch_colossalai.md b/docs/source/zh-Hans/basics/launch_colossalai.md
index a80d16717e40..9e40f64c2124 100644
--- a/docs/source/zh-Hans/basics/launch_colossalai.md
+++ b/docs/source/zh-Hans/basics/launch_colossalai.md
@@ -116,17 +116,17 @@ colossalai run --nproc_per_node 4 --master_port 29505 test.py
- 通过`--hosts`来启动
这个方式适合节点数不多的情况。假设我们有两个节点,分别为`host`和`host2`。我们可以用以下命令进行多节点训练。
-比起单节点训练,多节点训练需要手动设置`--master_addr` (在单节点训练中`master_addr`默认为`127.0.0.1`)。
+比起单节点训练,多节点训练需要手动设置`--master_addr` (在单节点训练中`master_addr`默认为`127.0.0.1`)。同时,你需要确保每个节点都使用同一个 SSH 端口,可以通过 `--ssh-port` 设置。
:::caution
-多节点训练时,`master_addr`不能为`localhost`或者`127.0.0.1`,它应该是一个节点的名字或者IP地址。
+多节点训练时,`master_addr`不能为`localhost`或者`127.0.0.1`,它应该是一个节点的**名字或者IP地址**。
:::
```shell
# 在两个节点上训练
-colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py
+colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py --ssh-port 22
```
diff --git a/examples/inference/stable_diffusion/README.md b/examples/inference/stable_diffusion/README.md
new file mode 100644
index 000000000000..c11b9804392c
--- /dev/null
+++ b/examples/inference/stable_diffusion/README.md
@@ -0,0 +1,22 @@
+## File Structure
+```
+|- sd3_generation.py: an example of using the ColossalAI Inference Engine to generate images with a diffusion model
+|- compute_metric.py: compares the quality of images generated with and without acceleration methods such as DistriFusion
+|- benchmark_sd3.py: benchmarks the performance of the InferenceEngine
+|- run_benchmark.sh: runs the benchmark commands
+```
+Note: `compute_metric.py` requires extra dependencies, which can be installed with `pip install -r requirements.txt`; the `requirements.txt` is located in `examples/inference/stable_diffusion/`.
+
+## Run Inference
+
+The provided script `sd3_generation.py` shows how to configure and initialize the engine and run inference on a given model. `DiffusionPipeline` is used as the model class, and the script works with Stable Diffusion 3 out of the box.
+
+For a basic setup, you can run the example with:
+```bash
+colossalai run --nproc_per_node 1 sd3_generation.py -m PATH_MODEL -p "hello world"
+```
+
+Run multi-GPU inference (Patched Parallelism), as in the following example using 2 GPUs:
+```bash
+colossalai run --nproc_per_node 2 sd3_generation.py -m PATH_MODEL
+```
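For reference, the engine calls behind `sd3_generation.py` boil down to a few lines. This is a hedged sketch (argument values are illustrative) mirroring the script's use of `InferenceConfig` and `DiffusionGenerationConfig`:

```python
# Hedged sketch of the engine usage in sd3_generation.py; run under `colossalai run`.
import colossalai
from torch import distributed as dist
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

colossalai.launch_from_torch()
inference_config = InferenceConfig(dtype="fp16", patched_parallelism_size=dist.get_world_size())
engine = InferenceEngine("stabilityai/stable-diffusion-3-medium-diffusers",
                         inference_config=inference_config, verbose=False)
image = engine.generate(prompts=["hello world"],
                        generation_config=DiffusionGenerationConfig(num_inference_steps=50))[0]
if dist.get_rank() == 0:
    image.save("out.jpg")
```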
diff --git a/examples/inference/stable_diffusion/benchmark_sd3.py b/examples/inference/stable_diffusion/benchmark_sd3.py
new file mode 100644
index 000000000000..19db57c33c82
--- /dev/null
+++ b/examples/inference/stable_diffusion/benchmark_sd3.py
@@ -0,0 +1,179 @@
+import argparse
+import json
+import time
+from contextlib import nullcontext
+
+import torch
+import torch.distributed as dist
+from diffusers import DiffusionPipeline
+
+import colossalai
+from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+
+GIGABYTE = 1024**3
+MEGABYTE = 1024 * 1024
+
+_DTYPE_MAPPING = {
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+ "fp32": torch.float32,
+}
+
+
+def log_generation_time(log_data, log_file):
+ with open(log_file, "a") as f:
+ json.dump(log_data, f, indent=2)
+ f.write("\n")
+
+
+def warmup(engine, args):
+ for _ in range(args.n_warm_up_steps):
+ engine.generate(
+ prompts=["hello world"],
+ generation_config=DiffusionGenerationConfig(
+ num_inference_steps=args.num_inference_steps, height=args.height[0], width=args.width[0]
+ ),
+ )
+
+
+def profile_context(args):
+ return (
+ torch.profiler.profile(
+ record_shapes=True,
+ with_stack=True,
+ with_modules=True,
+ activities=[
+ torch.profiler.ProfilerActivity.CPU,
+ torch.profiler.ProfilerActivity.CUDA,
+ ],
+ )
+ if args.profile
+ else nullcontext()
+ )
+
+
+def log_and_profile(h, w, avg_time, log_msg, args, model_name, mode, prof=None):
+ log_data = {
+ "mode": mode,
+ "model": model_name,
+ "batch_size": args.batch_size,
+ "patched_parallel_size": args.patched_parallel_size,
+ "num_inference_steps": args.num_inference_steps,
+ "height": h,
+ "width": w,
+ "dtype": args.dtype,
+ "profile": args.profile,
+ "n_warm_up_steps": args.n_warm_up_steps,
+ "n_repeat_times": args.n_repeat_times,
+ "avg_generation_time": avg_time,
+ "log_message": log_msg,
+ }
+
+ if args.log:
+ log_file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}.json"
+ log_generation_time(log_data=log_data, log_file=log_file)
+
+ if args.profile:
+ file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}_prof.json"
+ prof.export_chrome_trace(file)
+
+
+def benchmark_colossalai(rank, world_size, port, args):
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ from colossalai.cluster.dist_coordinator import DistCoordinator
+
+ coordinator = DistCoordinator()
+
+ inference_config = InferenceConfig(
+ dtype=args.dtype,
+ patched_parallelism_size=args.patched_parallel_size,
+ )
+ engine = InferenceEngine(args.model, inference_config=inference_config, verbose=False)
+
+ warmup(engine, args)
+
+ for h, w in zip(args.height, args.width):
+ with profile_context(args) as prof:
+ start = time.perf_counter()
+ for _ in range(args.n_repeat_times):
+ engine.generate(
+ prompts=["hello world"],
+ generation_config=DiffusionGenerationConfig(
+ num_inference_steps=args.num_inference_steps, height=h, width=w
+ ),
+ )
+ end = time.perf_counter()
+
+ avg_time = (end - start) / args.n_repeat_times
+ log_msg = f"[ColossalAI]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
+ coordinator.print_on_master(log_msg)
+
+ if dist.get_rank() == 0:
+ log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "colossalai", prof=prof)
+
+
+def benchmark_diffusers(args):
+ model = DiffusionPipeline.from_pretrained(args.model, torch_dtype=_DTYPE_MAPPING[args.dtype]).to("cuda")
+
+ for _ in range(args.n_warm_up_steps):
+ model(
+ prompt="hello world",
+ num_inference_steps=args.num_inference_steps,
+ height=args.height[0],
+ width=args.width[0],
+ )
+
+ for h, w in zip(args.height, args.width):
+ with profile_context(args) as prof:
+ start = time.perf_counter()
+ for _ in range(args.n_repeat_times):
+ model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=h, width=w)
+ end = time.perf_counter()
+
+ avg_time = (end - start) / args.n_repeat_times
+ log_msg = f"[Diffusers]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
+ print(log_msg)
+
+ log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "diffusers", prof)
+
+
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def benchmark(args):
+ if args.mode == "colossalai":
+ spawn(benchmark_colossalai, nprocs=args.patched_parallel_size, args=args)
+ elif args.mode == "diffusers":
+ benchmark_diffusers(args)
+
+
+"""
+# enable log
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --log
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --log
+
+# enable profiler
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "stabilityai/stable-diffusion-3-medium-diffusers" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
+"""
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
+ parser.add_argument("-p", "--patched_parallel_size", type=int, default=1, help="Patched Parallelism size")
+ parser.add_argument("-n", "--num_inference_steps", type=int, default=50, help="Number of inference steps")
+ parser.add_argument("-H", "--height", type=int, nargs="+", default=[1024, 2048], help="Height list")
+ parser.add_argument("-w", "--width", type=int, nargs="+", default=[1024, 2048], help="Width list")
+ parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type")
+ parser.add_argument("--n_warm_up_steps", type=int, default=3, help="Number of warm up steps")
+ parser.add_argument("--n_repeat_times", type=int, default=5, help="Number of repeat times")
+ parser.add_argument("--profile", default=False, action="store_true", help="Enable torch profiler")
+ parser.add_argument("--log", default=False, action="store_true", help="Enable logging")
+ parser.add_argument("-m", "--model", default="stabilityai/stable-diffusion-3-medium-diffusers", help="Model path")
+ parser.add_argument(
+ "--mode", default="colossalai", choices=["colossalai", "diffusers"], help="Inference framework mode"
+ )
+ args = parser.parse_args()
+ benchmark(args)
diff --git a/examples/inference/stable_diffusion/compute_metric.py b/examples/inference/stable_diffusion/compute_metric.py
new file mode 100644
index 000000000000..14c92501b66d
--- /dev/null
+++ b/examples/inference/stable_diffusion/compute_metric.py
@@ -0,0 +1,80 @@
+# Code from https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py
+import argparse
+import os
+
+import numpy as np
+import torch
+from cleanfid import fid
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio
+from torchvision.transforms import Resize
+from tqdm import tqdm
+
+
+def read_image(path: str):
+ """
+ input: path
+ output: tensor (C, H, W)
+ """
+ img = np.asarray(Image.open(path))
+ if len(img.shape) == 2:
+ img = np.repeat(img[:, :, None], 3, axis=2)
+ img = torch.from_numpy(img).permute(2, 0, 1)
+ return img
+
+
+class MultiImageDataset(Dataset):
+ def __init__(self, root0, root1, is_gt=False):
+ super().__init__()
+ self.root0 = root0
+ self.root1 = root1
+ file_names0 = os.listdir(root0)
+ file_names1 = os.listdir(root1)
+
+ self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")])
+ self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")])
+ self.is_gt = is_gt
+ assert len(self.image_names0) == len(self.image_names1)
+
+ def __len__(self):
+ return len(self.image_names0)
+
+ def __getitem__(self, idx):
+ img0 = read_image(os.path.join(self.root0, self.image_names0[idx]))
+ if self.is_gt:
+ # resize to 1024 x 1024
+ img0 = Resize((1024, 1024))(img0)
+ img1 = read_image(os.path.join(self.root1, self.image_names1[idx]))
+
+ batch_list = [img0, img1]
+ return batch_list
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--batch_size", type=int, default=64)
+ parser.add_argument("--num_workers", type=int, default=8)
+ parser.add_argument("--is_gt", action="store_true")
+ parser.add_argument("--input_root0", type=str, required=True)
+ parser.add_argument("--input_root1", type=str, required=True)
+ args = parser.parse_args()
+
+ psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda")
+ lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda")
+
+ dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt)
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
+
+ progress_bar = tqdm(dataloader)
+ with torch.inference_mode():
+ for i, batch in enumerate(progress_bar):
+ batch = [img.to("cuda") / 255 for img in batch]
+ batch_size = batch[0].shape[0]
+ psnr.update(batch[0], batch[1])
+ lpips.update(batch[0], batch[1])
+ fid_score = fid.compute_fid(args.input_root0, args.input_root1)
+
+ print("PSNR:", psnr.compute().item())
+ print("LPIPS:", lpips.compute().item())
+ print("FID:", fid_score)
diff --git a/examples/inference/stable_diffusion/requirements.txt b/examples/inference/stable_diffusion/requirements.txt
new file mode 100644
index 000000000000..c4e74162dfb5
--- /dev/null
+++ b/examples/inference/stable_diffusion/requirements.txt
@@ -0,0 +1,3 @@
+torchvision
+torchmetrics
+cleanfid
diff --git a/examples/inference/stable_diffusion/run_benchmark.sh b/examples/inference/stable_diffusion/run_benchmark.sh
new file mode 100644
index 000000000000..f3e45a335219
--- /dev/null
+++ b/examples/inference/stable_diffusion/run_benchmark.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+models=("PixArt-alpha/PixArt-XL-2-1024-MS" "stabilityai/stable-diffusion-3-medium-diffusers")
+parallelism=(1 2 4 8)
+resolutions=(1024 2048 3840)
+modes=("colossalai" "diffusers")
+
+CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
+ | tail -n +2 \
+ | nl -v 0 \
+ | tee /dev/tty \
+ | sort -g -k 2 \
+ | awk '{print $1}' \
+ | head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+for model in "${models[@]}"; do
+ for p in "${parallelism[@]}"; do
+ for resolution in "${resolutions[@]}"; do
+ for mode in "${modes[@]}"; do
+ if [[ "$mode" == "colossalai" && "$p" == 1 ]]; then
+ continue
+ fi
+ if [[ "$mode" == "diffusers" && "$p" != 1 ]]; then
+ continue
+ fi
+ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage $p
+
+ cmd="python examples/inference/stable_diffusion/benchmark_sd3.py -m \"$model\" -p $p --mode $mode --log -H $resolution -w $resolution"
+
+ echo "Executing: $cmd"
+ eval $cmd
+ done
+ done
+ done
+done
diff --git a/examples/inference/stable_diffusion/sd3_generation.py b/examples/inference/stable_diffusion/sd3_generation.py
index fe989eed7c2d..9e146c34b937 100644
--- a/examples/inference/stable_diffusion/sd3_generation.py
+++ b/examples/inference/stable_diffusion/sd3_generation.py
@@ -1,18 +1,17 @@
import argparse
-from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline
-from torch import bfloat16, float16, float32
+from diffusers import DiffusionPipeline
+from torch import bfloat16
+from torch import distributed as dist
+from torch import float16, float32
import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
-from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
-from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy
# For Stable Diffusion 3, we'll use the following configuration
-MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0]
-POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0]
+MODEL_CLS = DiffusionPipeline
TORCH_DTYPE_MAP = {
"fp16": float16,
@@ -43,20 +42,27 @@ def infer(args):
max_batch_size=args.max_batch_size,
tp_size=args.tp_size,
use_cuda_kernel=args.use_cuda_kernel,
+ patched_parallelism_size=dist.get_world_size(),
)
- engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True)
+ engine = InferenceEngine(model, inference_config=inference_config, verbose=True)
# ==============================
# Generation
# ==============================
coordinator.print_on_master(f"Generating...")
out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
- out.save("cat.jpg")
+ if dist.get_rank() == 0:
+ out.save(f"cat_parallel_size{dist.get_world_size()}.jpg")
coordinator.print_on_master(out)
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
+
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
+# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
+
+# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
+# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
if __name__ == "__main__":
diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py
index 777d16cb9ea0..ae6d655f40a6 100644
--- a/examples/language/gpt/hybridparallelism/finetune.py
+++ b/examples/language/gpt/hybridparallelism/finetune.py
@@ -1,4 +1,5 @@
import argparse
+from contextlib import nullcontext
from typing import Callable, List, Union
import evaluate
@@ -17,6 +18,7 @@
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
# ==============================
@@ -186,7 +188,6 @@ def main():
help="only gpt2 now",
)
parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
- parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
args = parser.parse_args()
if args.model_type == "gpt2":
@@ -250,10 +251,16 @@ def main():
pad_token_id=data_builder.tokenizer.pad_token_id,
)
- if model_name == "gpt2":
- model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
- else:
- raise RuntimeError
+ init_ctx = (
+ LazyInitContext(default_device=get_accelerator().get_current_device())
+ if isinstance(plugin, GeminiPlugin)
+ else nullcontext()
+ )
+ with init_ctx:
+ if model_name == "gpt2":
+ model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
+ else:
+ raise RuntimeError
# optimizer
no_decay = ["bias", "LayerNorm.weight"]
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index 2b7bd50b8766..e530e2d6a153 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -98,6 +98,7 @@ def main():
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
+ parser.add_argument("--overlap_allgather", action="store_true")
args = parser.parse_args()
colossalai.launch_from_torch()
@@ -199,9 +200,9 @@ def empty_init():
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
precision="bf16",
- dp_outside=False,
overlap_p2p=args.overlap,
enable_metadata_cache=not args.no_cache,
+ overlap_allgather=args.overlap_allgather,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py
index c2883d96c16e..ca9b63d1a14a 100755
--- a/examples/language/opt/opt_benchmark.py
+++ b/examples/language/opt/opt_benchmark.py
@@ -1,4 +1,5 @@
import time
+from contextlib import nullcontext
import torch
import tqdm
@@ -8,9 +9,11 @@
from transformers.utils.versions import require_version
import colossalai
+from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
@@ -62,14 +65,6 @@ def main():
if args.mem_cap > 0:
colo_memory_cap(args.mem_cap)
- # Build OPT model
- config = AutoConfig.from_pretrained(args.model_name_or_path)
- model = OPTForCausalLM(config=config)
- logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
-
- # Enable gradient checkpointing
- model.gradient_checkpointing_enable()
-
# Set plugin
booster_kwargs = {}
if args.plugin == "torch_ddp_fp16":
@@ -82,6 +77,19 @@ def main():
plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
+ # Build OPT model
+ init_ctx = (
+ LazyInitContext(default_device=get_accelerator().get_current_device())
+ if isinstance(plugin, GeminiPlugin)
+ else nullcontext()
+ )
+ config = AutoConfig.from_pretrained(args.model_name_or_path)
+ with init_ctx:
+ model = OPTForCausalLM(config=config)
+ logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+ # Enable gradient checkpointing
+ model.gradient_checkpointing_enable()
# Set optimizer
optimizer = HybridAdam(model.parameters(), lr=args.learning_rate)
diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py
index b5b50305cc34..50dfc7bffd07 100644
--- a/examples/language/opt/opt_train_demo.py
+++ b/examples/language/opt/opt_train_demo.py
@@ -1,3 +1,5 @@
+from contextlib import nullcontext
+
import datasets
import torch
import transformers
@@ -8,9 +10,11 @@
from transformers.utils.versions import require_version
import colossalai
+from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
@@ -78,14 +82,6 @@ def main():
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
- # Build OPT model
- config = AutoConfig.from_pretrained(args.model_name_or_path)
- model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
- logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
-
- # Enable gradient checkpointing
- model.gradient_checkpointing_enable()
-
# Set plugin
booster_kwargs = {}
if args.plugin == "torch_ddp_fp16":
@@ -110,6 +106,21 @@ def main():
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
+ # Build OPT model
+ config = AutoConfig.from_pretrained(args.model_name_or_path)
+ init_ctx = (
+ LazyInitContext(default_device=get_accelerator().get_current_device())
+ if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
+ else nullcontext()
+ )
+ with init_ctx:
+ model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
+ logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+ # Enable gradient checkpointing
+ model.gradient_checkpointing_enable()
+
# Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
dataset = NetflixDataset(tokenizer)
diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py
index 05c17f562635..4adc386192d3 100644
--- a/tests/kit/model_zoo/transformers/__init__.py
+++ b/tests/kit/model_zoo/transformers/__init__.py
@@ -3,28 +3,17 @@
from .blip2 import *
from .bloom import *
from .chatglm2 import *
+from .command import *
+from .deepseek import *
from .falcon import *
from .gpt import *
from .gptj import *
from .llama import *
+from .mistral import *
+from .mixtral import *
from .opt import *
+from .qwen2 import *
from .sam import *
from .t5 import *
from .vit import *
from .whisper import *
-
-try:
- from .mistral import *
-except ImportError:
- print("This version of transformers doesn't support mistral.")
-
-try:
- from .qwen2 import *
-except ImportError:
- print("This version of transformers doesn't support qwen2.")
-
-
-try:
- from .command import *
-except ImportError:
- print("This version of transformers doesn't support Command-R.")
diff --git a/tests/kit/model_zoo/transformers/deepseek.py b/tests/kit/model_zoo/transformers/deepseek.py
new file mode 100644
index 000000000000..ad73640a57c5
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/deepseek.py
@@ -0,0 +1,83 @@
+# modified from tests/kit/model_zoo/transformers/mistral.py
+import torch
+import transformers
+from transformers import AutoConfig
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register single-sentence Deepseek
+# ===============================
+
+
+def data_gen():
+ # Generated from following code snippet
+ #
+ # from transformers import AutoModelForCausalLM, AutoTokenizer
+ # tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1")
+ # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
+ # tokenized_input = tokenizer([input], return_tensors="pt")
+ # input_ids = tokenized_input['input_ids']
+ # attention_mask = tokenized_input['attention_mask']
+ input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
+ attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+def data_gen_for_lm():
+ # LM data gen
+ # the `labels` of an LM are the output tokens; since there is no padding, reuse `input_ids` as `labels`
+ data = data_gen()
+ data["labels"] = data["input_ids"].clone()
+ return data
+
+
+def data_gen_for_sequence_classification():
+ # sequence classification data gen
+ data = data_gen()
+ data["labels"] = torch.tensor([1], dtype=torch.int64)
+ return data
+
+
+# define output transform function
+output_transform_fn = lambda x: x
+
+# define loss function
+loss_fn_for_mixtral_model = lambda x: x[0].mean()
+loss_fn = lambda x: x.loss
+loss_fn_for_seq_classification = lambda output: output.logits.mean()
+
+
+def init_deepseek():
+
+ config = AutoConfig.from_pretrained(
+ "deepseek-ai/deepseek-moe-16b-base",
+ hidden_size=32,
+ intermediate_size=32,
+ moe_intermediate_size=32,
+ num_hidden_layers=2,
+ num_attention_heads=8,
+ num_key_value_heads=8,
+ # vocab_size=2200,
+ first_k_dense_replace=1,
+ attn_implementation="flash_attention_2",
+ torch_dtype="float16",
+ n_routed_experts=8,
+ trust_remote_code=True,
+ )
+
+ if hasattr(config, "pad_token_id"):
+ config.pad_token_id = config.eos_token_id
+ model = transformers.AutoModel.from_config(config, trust_remote_code=True)
+
+ return model
+
+
+model_zoo.register(
+ name="transformers_deepseek",
+ model_fn=init_deepseek,
+ data_gen_fn=data_gen,
+ output_transform_fn=output_transform_fn,
+ loss_fn=loss_fn_for_mixtral_model,
+ model_attribute=ModelAttribute(has_control_flow=True),
+)
diff --git a/tests/kit/model_zoo/transformers/mixtral.py b/tests/kit/model_zoo/transformers/mixtral.py
new file mode 100644
index 000000000000..40e5a7b0232d
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/mixtral.py
@@ -0,0 +1,87 @@
+# modified from tests/kit/model_zoo/transformers/mistral.py
+import torch
+import transformers
+from transformers import MixtralConfig
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register single-sentence Mixtral
+# ===============================
+
+
+def data_gen():
+ # Generated from the following code snippet
+ #
+ # from transformers import AutoModelForCausalLM, AutoTokenizer
+ # tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
+ # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy the length requirement)
+ # tokenized_input = tokenizer([input], return_tensors="pt")
+ # input_ids = tokenized_input['input_ids']
+ # attention_mask = tokenized_input['attention_mask']
+ input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
+ attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+def data_gen_for_lm():
+ # LM data gen
+ # for LM, `labels` are the output tokens; since there is no padding, reuse `input_ids` as `labels`
+ data = data_gen()
+ data["labels"] = data["input_ids"].clone()
+ return data
+
+
+def data_gen_for_sequence_classification():
+ # sequence classification data gen
+ data = data_gen()
+ data["labels"] = torch.tensor([1], dtype=torch.int64)
+ return data
+
+
+# define output transform function
+output_transform_fn = lambda x: x
+
+# define loss function
+loss_fn_for_mixtral_model = lambda x: x[0].mean()
+loss_fn = lambda x: x.loss
+loss_fn_for_seq_classification = lambda output: output.logits.mean()
+
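+# a deliberately tiny Mixtral config so the registered test model stays small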
+config = MixtralConfig(
+ hidden_size=32,
+ intermediate_size=32,
+ num_attention_heads=8,
+ num_hidden_layers=2,
+ vocab_size=1000,
+ attn_implementation="flash_attention_2",
+ torch_dtype="float16",
+ output_router_logits=True,
+)
+
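+# reuse the eos token as pad token so batched inputs can be padded consistently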
+if hasattr(config, "pad_token_id"):
+ config.pad_token_id = config.eos_token_id
+
+model_zoo.register(
+ name="transformers_mixtral",
+ model_fn=lambda: transformers.MixtralModel(config),
+ data_gen_fn=data_gen,
+ output_transform_fn=output_transform_fn,
+ loss_fn=loss_fn_for_mixtral_model,
+ model_attribute=ModelAttribute(has_control_flow=True),
+)
+# model_zoo.register(
+# name="transformers_mixtral_for_casual_lm",
+# model_fn=lambda: transformers.MixtralForCausalLM(config),
+# data_gen_fn=data_gen_for_lm,
+# output_transform_fn=output_transform_fn,
+# loss_fn=loss_fn,
+# model_attribute=ModelAttribute(has_control_flow=True),
+# )
+# model_zoo.register(
+# name="transformers_mixtral_for_sequence_classification",
+# model_fn=lambda: transformers.MixtralForSequenceClassification(config),
+# data_gen_fn=data_gen_for_sequence_classification,
+# output_transform_fn=output_transform_fn,
+# loss_fn=loss_fn_for_seq_classification,
+# model_attribute=ModelAttribute(has_control_flow=True),
+# )
diff --git a/tests/test_legacy/test_moe/moe_utils.py b/tests/test_legacy/test_moe/moe_utils.py
new file mode 100644
index 000000000000..8c133849b000
--- /dev/null
+++ b/tests/test_legacy/test_moe/moe_utils.py
@@ -0,0 +1,136 @@
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+
+from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
+from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
+from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
+from colossalai.legacy.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.utils import get_moe_epsize_param_dict
+from colossalai.legacy.registry import GRADIENT_HANDLER
+from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group
+
+
+def delete_moe_info(model):
+ for _, param in model.named_parameters():
+ if hasattr(param, "ep_group"):
+ delattr(param, "ep_group")
+
+
+class MoeModel(nn.Module):
+ def __init__(self, ep_group: ProcessGroup = None):
+ super().__init__()
+ self.test_embed = nn.Linear(4, 16, bias=False)
+ self.w1 = torch.nn.Parameter(torch.randn(16, 8))
+ if ep_group:
+ set_moe_tensor_ep_group(self.w1, ep_group)
+
+ def forward(self, x):
+ x = self.test_embed(x)
+ x = torch.matmul(x, self.w1)
+
+ return x
+
+
+@GRADIENT_HANDLER.register_module
+class MoeGradientHandler(BaseGradientHandler):
+ """A helper class to handle all-reduce operations in a data parallel group and
+ moe model parallel. A all-reduce collective communication will be operated in
+ :func:`handle_gradient` among a data parallel group.
+ For better performance, it bucketizes the gradients of all parameters that are
+ the same type to improve the efficiency of communication.
+
+ Args:
+ model (Module): Model where the gradients accumulate.
+ optimizer (Optimizer): Optimizer for updating the parameters.
+ """
+
+ def __init__(self, model, optimizer=None):
+ super().__init__(model, optimizer)
+
+ def handle_gradient(self):
+ """A method running an all-reduce operation in a data parallel group.
+ Then running an all-reduce operation for all parameters in experts
+ across moe model parallel group
+ """
+ if dist.get_world_size() > 1:
+ epsize_param_dict = get_moe_epsize_param_dict(self._model)
+
+ # epsize is 1, indicating the params are replicated among processes in data parallelism
+ # use the ParallelMode.DATA to get data parallel group
+ # reduce gradients for all parameters in data parallelism
+ if 1 in epsize_param_dict:
+ bucket_allreduce(param_list=epsize_param_dict[1])
+
+ for ep_size in epsize_param_dict:
+ if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
+ bucket_allreduce(
+ param_list=epsize_param_dict[ep_size], group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group
+ )
+
+
+def assert_not_equal_in_group(tensor, process_group=None):
+ # all gather tensors from different ranks
+ world_size = dist.get_world_size(process_group)
+ tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
+ dist.all_gather(tensor_list, tensor, group=process_group)
+
+ # check if they are equal one by one
+ for i in range(world_size - 1):
+ a = tensor_list[i]
+ b = tensor_list[i + 1]
+ assert not torch.allclose(a, b), (
+ f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
+ )
+
+
+def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
+ model.train()
+ with torch.cuda.amp.autocast(enabled=enable_autocast):
+ if criterion:
+ y = model(data)
+ loss = criterion(y, label)
+ else:
+ loss = model(data, label)
+ loss = loss.float()
+
+ if isinstance(model, LowLevelZeroModel):
+ optimizer.backward(loss)
+ else:
+ loss.backward()
+ return y
+
+
+def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
+ """Sync the parameters of tp model from ep model
+
+ Args:
+ local_model (MoeModule)
+ ep_model (MoeModule)
+ """
+ for (local_name, local_param), (ep_name, ep_param) in zip(
+ local_model.named_parameters(), ep_model.named_parameters()
+ ):
+ if "experts" not in local_name:
+ if assert_grad_flag:
+ assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
+ assert torch.allclose(local_param.grad, ep_param.grad)
+ else:
+ local_param.data.copy_(ep_param.data)
+ continue
+
+ # gather param from ep model
+ param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+ dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
+ all_param = torch.cat(param_list, dim=0)
+ if assert_grad_flag:
+ grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+ dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
+ all_grad = torch.cat(grad_list, dim=0)
+
+ if assert_grad_flag:
+ assert torch.allclose(local_param, all_param)
+ assert torch.allclose(local_param.grad, all_grad)
+ else:
+ local_param.data.copy_(all_param.data)
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_legacy/test_moe/test_grad_handler.py
similarity index 98%
rename from tests/test_moe/test_grad_handler.py
rename to tests/test_legacy/test_moe/test_grad_handler.py
index 25e61b091729..3a782a6dd445 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_legacy/test_moe/test_grad_handler.py
@@ -5,7 +5,7 @@
import colossalai
from colossalai.accelerator import get_accelerator
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.manager import MOE_MANAGER
# from colossalai.shardformer.layer.moe.layers import SparseMLP
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
diff --git a/tests/test_moe/test_moe_group.py b/tests/test_legacy/test_moe/test_moe_group.py
similarity index 95%
rename from tests/test_moe/test_moe_group.py
rename to tests/test_legacy/test_moe/test_moe_group.py
index 89baf1d37b1b..68dac4828fa7 100644
--- a/tests/test_moe/test_moe_group.py
+++ b/tests/test_legacy/test_moe/test_moe_group.py
@@ -4,8 +4,8 @@
import colossalai
from colossalai.accelerator import get_accelerator
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import sync_moe_model_param
+from colossalai.legacy.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.utils import sync_moe_model_param
# from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
diff --git a/tests/test_moe/test_moe_hybrid_zero.py b/tests/test_legacy/test_moe/test_moe_hybrid_zero.py
similarity index 98%
rename from tests/test_moe/test_moe_hybrid_zero.py
rename to tests/test_legacy/test_moe/test_moe_hybrid_zero.py
index 513c4ebda4a5..fdd6d956ef83 100644
--- a/tests/test_moe/test_moe_hybrid_zero.py
+++ b/tests/test_legacy/test_moe/test_moe_hybrid_zero.py
@@ -6,7 +6,7 @@
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeModel
diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_legacy/test_moe/test_moe_load_balance.py
similarity index 99%
rename from tests/test_moe/test_moe_load_balance.py
rename to tests/test_legacy/test_moe/test_moe_load_balance.py
index ddd3ea368964..adf2dbc1ccf3 100644
--- a/tests/test_moe/test_moe_load_balance.py
+++ b/tests/test_legacy/test_moe/test_moe_load_balance.py
@@ -6,7 +6,7 @@
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.manager import MOE_MANAGER
# from colossalai.shardformer.layer.moe import apply_load_balance
from colossalai.tensor.moe_tensor.api import is_moe_tensor
diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py
index b8daf775db0e..1ae17025d31e 100644
--- a/tests/test_lora/test_lora.py
+++ b/tests/test_lora/test_lora.py
@@ -9,7 +9,8 @@
import colossalai
from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir
@@ -20,7 +21,7 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type
model = model_fn()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
- test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
+ test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin(), HybridParallelPlugin(tp_size=1, pp_size=1)]
test_configs = [
{
"lora_config": lora_config,
@@ -59,6 +60,8 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type
# test fwd bwd correctness
test_model = model_load
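+        # unwrap the HybridParallelModule wrappers to reach the underlying model for comparison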
+ if isinstance(model_load, HybridParallelModule):
+ model_load = model_load.module.module
model_copy = copy.deepcopy(model_load)
data = data_gen_fn()
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index 131932dcb3b3..8c411a33fef6 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -1,142 +1,8 @@
import torch
-import torch.distributed as dist
-import torch.nn as nn
-from torch.distributed import ProcessGroup
-from torch.testing import assert_close
-from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
-from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
-from colossalai.legacy.registry import GRADIENT_HANDLER
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import get_moe_epsize_param_dict
-# from colossalai.shardformer.layer.moe import SparseMLP
-from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group
-
-
-def delete_moe_info(model):
- for _, param in model.named_parameters():
- if hasattr(param, "ep_group"):
- delattr(param, "ep_group")
-
-
-class MoeModel(nn.Module):
- def __init__(self, ep_group: ProcessGroup = None):
- super().__init__()
- self.test_embed = nn.Linear(4, 16, bias=False)
- self.w1 = torch.nn.Parameter(torch.randn(16, 8))
- if ep_group:
- set_moe_tensor_ep_group(self.w1, ep_group)
-
- def forward(self, x):
- x = self.test_embed(x)
- x = torch.matmul(x, self.w1)
-
- return x
-
-
-@GRADIENT_HANDLER.register_module
-class MoeGradientHandler(BaseGradientHandler):
- """A helper class to handle all-reduce operations in a data parallel group and
- moe model parallel. A all-reduce collective communication will be operated in
- :func:`handle_gradient` among a data parallel group.
- For better performance, it bucketizes the gradients of all parameters that are
- the same type to improve the efficiency of communication.
-
- Args:
- model (Module): Model where the gradients accumulate.
- optimizer (Optimizer): Optimizer for updating the parameters.
- """
-
- def __init__(self, model, optimizer=None):
- super().__init__(model, optimizer)
-
- def handle_gradient(self):
- """A method running an all-reduce operation in a data parallel group.
- Then running an all-reduce operation for all parameters in experts
- across moe model parallel group
- """
- if dist.get_world_size() > 1:
- epsize_param_dict = get_moe_epsize_param_dict(self._model)
-
- # epsize is 1, indicating the params are replicated among processes in data parallelism
- # use the ParallelMode.DATA to get data parallel group
- # reduce gradients for all parameters in data parallelism
- if 1 in epsize_param_dict:
- bucket_allreduce(param_list=epsize_param_dict[1])
-
- for ep_size in epsize_param_dict:
- if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
- bucket_allreduce(
- param_list=epsize_param_dict[ep_size], group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group
- )
-
-
-def assert_not_equal_in_group(tensor, process_group=None):
- # all gather tensors from different ranks
- world_size = dist.get_world_size(process_group)
- tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
- dist.all_gather(tensor_list, tensor, group=process_group)
-
- # check if they are equal one by one
- for i in range(world_size - 1):
- a = tensor_list[i]
- b = tensor_list[i + 1]
- assert not torch.allclose(a, b), (
- f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
- )
-
-
-def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
- model.train()
- with torch.cuda.amp.autocast(enabled=enable_autocast):
- if criterion:
- y = model(data)
- loss = criterion(y, label)
- else:
- loss = model(data, label)
- loss = loss.float()
-
- if isinstance(model, LowLevelZeroModel):
- optimizer.backward(loss)
- else:
- loss.backward()
- return y
-
-
-def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
- """Sync the parameters of tp model from ep model
-
- Args:
- local_model (MoeModule)
- ep_model (MoeModule)
- """
- for (local_name, local_param), (ep_name, ep_param) in zip(
- local_model.named_parameters(), ep_model.named_parameters()
- ):
- if "experts" not in local_name:
- if assert_grad_flag:
- assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
- assert torch.allclose(local_param.grad, ep_param.grad)
- else:
- local_param.data.copy_(ep_param.data)
- continue
-
- # gather param from ep model
- param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
- dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
- all_param = torch.cat(param_list, dim=0)
- if assert_grad_flag:
- grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
- dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
- all_grad = torch.cat(grad_list, dim=0)
-
- if assert_grad_flag:
- assert torch.allclose(local_param, all_param)
- assert torch.allclose(local_param.grad, all_grad)
- else:
- local_param.data.copy_(all_param.data)
+def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
+ assert loose_close(a, b, dtype), f"{name} not close {a.mean()} {b.mean()}"
def loose_close(a, b, dtype: torch.dtype = torch.float32):
@@ -148,8 +14,18 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
elif dtype is torch.bfloat16:
rtol = 4e-3
atol = 4e-3
+ else:
+ assert dtype is torch.float32
+ rtol = 1e-05
+ atol = 1e-08
a = a.detach().to(dtype)
b = b.detach().to(dtype).to(a.device)
- assert_close(a, b, rtol=rtol, atol=atol)
+ return torch.allclose(a, b, rtol=rtol, atol=atol)
+
+
+def check_model_equal(model1, model2):
+ assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
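+    # compare parameters pairwise; assumes both models yield parameters in the same order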
+ for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
+ assert_loose_close(p1, p2, p1.dtype)
diff --git a/tests/test_moe/test_deepseek_layer.py b/tests/test_moe/test_deepseek_layer.py
index 85cc986959fd..d18ba2eacd84 100644
--- a/tests/test_moe/test_deepseek_layer.py
+++ b/tests/test_moe/test_deepseek_layer.py
@@ -22,6 +22,7 @@ def check_deepseek_moe_layer():
precision="bf16",
tp_size=1,
pp_size=1,
+ zero_stage=1,
ep_size=dist.get_world_size(),
)
@@ -42,7 +43,12 @@ def check_deepseek_moe_layer():
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output = orig_model(x)
model = deepcopy(orig_model)
- model = EPDeepseekMoE.from_native_module(model, ep_group=plugin.ep_group)
+ model = EPDeepseekMoE.from_native_module(
+ model,
+ ep_group=plugin.ep_group,
+ moe_dp_group=plugin.moe_dp_group,
+ tp_group=plugin.tp_group,
+ )
ep_output = model(x)
assert_close(orig_output, ep_output)
orig_loss = orig_output.mean()
@@ -62,7 +68,7 @@ def run_dist(rank: int, world_size: int, port: int):
check_deepseek_moe_layer()
-# @pytest.mark.parametrize("world_size", [2, 4])
+@pytest.mark.skip("tested in corresponding sharderformer")
@pytest.mark.parametrize("world_size", [2])
def test_deepseek_moe_layer(world_size: int):
spawn(run_dist, world_size)
diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py
index 28e6db441411..c81023988377 100644
--- a/tests/test_moe/test_kernel.py
+++ b/tests/test_moe/test_kernel.py
@@ -4,8 +4,6 @@
import torch
from colossalai.accelerator import get_accelerator
-
-# from colossalai.moe import SparseMLP
from colossalai.moe._operation import MoeCombine, MoeDispatch, moe_cumsum
NUM_EXPERTS = 4
diff --git a/tests/test_moe/test_mixtral_layer.py b/tests/test_moe/test_mixtral_layer.py
index b7b0322e08b5..bc41ac4f33e9 100644
--- a/tests/test_moe/test_mixtral_layer.py
+++ b/tests/test_moe/test_mixtral_layer.py
@@ -23,6 +23,7 @@ def check_mixtral_moe_layer():
precision="bf16",
tp_size=1,
pp_size=1,
+ zero_stage=1,
ep_size=dist.get_world_size(),
)
config = MixtralConfig(
@@ -36,7 +37,12 @@ def check_mixtral_moe_layer():
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output, orig_logits = orig_model(x)
model = deepcopy(orig_model)
- model = EPMixtralSparseMoeBlock.from_native_module(model, ep_group=plugin.ep_group)
+ model = EPMixtralSparseMoeBlock.from_native_module(
+ model,
+ ep_group=plugin.ep_group,
+ tp_group=plugin.tp_group,
+ moe_dp_group=plugin.moe_dp_group,
+ )
ep_output, ep_logits = model(x)
assert_close(orig_logits, ep_logits)
assert_close(orig_output, ep_output)
@@ -57,7 +63,8 @@ def run_dist(rank: int, world_size: int, port: int):
check_mixtral_moe_layer()
-@pytest.mark.parametrize("world_size", [2, 4])
+@pytest.mark.skip("tested in corresponding sharderformer")
+@pytest.mark.parametrize("world_size", [2])
def test_mixtral_moe_layer(world_size: int):
spawn(run_dist, world_size)
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 164301695865..89f5d1c64d0d 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -6,31 +6,23 @@
import pytest
import torch
import torch.distributed as dist
-from torch.optim import Adam
+from torch.optim import SGD, Adam
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.checkpoint_io import MoECheckpointIO
-from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import parameterize, spawn
+from colossalai.testing.random import seed_all
from colossalai.testing.utils import spawn
+from tests.test_moe.moe_utils import check_model_equal
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
-def check_model_equal(model1, model2):
- assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
- for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
- if not torch.equal(p1.half(), p2.half()):
- print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
- raise AssertionError(f"Model parameter {name} is not equal")
-
-
def get_optimizer_snapshot(optim):
state = {id(k): deepcopy(v) for k, v in optim.state.items()}
param_groups = []
@@ -89,35 +81,33 @@ def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_grou
num_experts_per_tok=top_k,
num_attention_heads=2,
num_key_value_heads=2,
+ num_hidden_layers=2,
),
MixtralForCausalLM,
],
],
)
def check_moe_checkpoint(test_config):
+ dtype, precision = torch.float16, "fp16"
+ config, model_cls = test_config
+ torch.cuda.set_device(dist.get_rank())
+
context = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
with context as f:
- torch.cuda.set_device(dist.get_rank())
if dist.get_rank() == 0:
broadcast_objects = [f] # any picklable object
else:
broadcast_objects = [None]
dist.broadcast_object_list(broadcast_objects, src=0)
- config = test_config[0]
- model_cls = test_config[1]
- torch.manual_seed(0)
input_ids = torch.randint(0, 100, (2, tokens)).cuda()
- orig_model = model_cls(config).cuda()
+ orig_model = model_cls(config).cuda().to(dtype)
+
+ seed_all(10086)
model = deepcopy(orig_model)
- optimizer = Adam(model.parameters(), lr=1e-3)
+ optimizer = SGD(model.parameters(), lr=1e-3)
plugin = MoeHybridParallelPlugin(
- pp_size=2,
- ep_size=2,
- tp_size=1,
- checkpoint_io=MoECheckpointIO,
- microbatch_size=1,
- zero_stage=1,
+ pp_size=2, ep_size=2, tp_size=1, microbatch_size=1, zero_stage=1, precision=precision
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
@@ -139,13 +129,12 @@ def check_moe_checkpoint(test_config):
booster.save_model(model, model_dir, shard=True)
dist.barrier()
if dist.get_rank() == 0:
- saved_model = model_cls.from_pretrained(model_dir).cuda()
+ saved_model = model_cls.from_pretrained(model_dir).cuda().to(dtype)
check_model_equal(orig_model, saved_model)
- # check_model_equal(model, saved_model)
saved_model.save_pretrained(hf_model_dir)
dist.barrier()
# check load model
- new_model = model_cls(config).cuda()
+ new_model = model_cls(config).cuda().to(dtype)
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
booster.load_model(new_model, hf_model_dir)
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index 9bc11033af6f..e6d2609ee67c 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -1,238 +1,132 @@
-import os
-import warnings
-from typing import Dict
+from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
-from colossalai.accelerator import get_accelerator
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import sync_moe_model_param
-
-# from colossalai.shardformer.layer import SparseMLP
-from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
-from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
-from tests.test_moe.moe_utils import MoeGradientHandler
-
-
-def sync_tp_from_local(tp_model, local_model, assert_grad_flag: bool = False) -> None:
- """Sync the parameters of tp model from local model
-
- Args:
- tp_model (MoeModule)
- local_model (MoeModule)
- """
- for (tp_name, tp_param), (local_name, local_param) in zip(
- tp_model.named_parameters(), local_model.named_parameters()
- ):
- assert tp_name == local_name
- if not is_moe_tensor(tp_param):
- if assert_grad_flag:
- assert torch.allclose(tp_param, local_param)
- assert torch.allclose(tp_param.grad, local_param.grad)
- else:
- tp_param.data.copy_(local_param.data)
- continue
-
- tp_rank = get_ep_rank(tp_param)
- tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape, local_param.shape)) if d1 != d2][0]
- tp_slice = [slice(None)] * tp_dim + [
- slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
- ]
-
- if assert_grad_flag:
- assert torch.allclose(tp_param, local_param[tuple(tp_slice)])
- assert torch.allclose(tp_param.grad, local_param.grad[tuple(tp_slice)])
- else:
- tp_param.data.copy_(local_param[tuple(tp_slice)].data)
-
-
-def sync_tp_from_ep(tp_model, ep_model, assert_grad_flag: bool = False) -> None:
- """Sync the parameters of tp model from ep model
-
- Args:
- tp_model (MoeModule)
- ep_model (MoeModule)
- """
- for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()):
- assert tp_name == ep_name
- if not is_moe_tensor(tp_param):
- if assert_grad_flag:
- assert torch.allclose(tp_param, ep_param)
- assert torch.allclose(tp_param.grad, ep_param.grad)
- else:
- tp_param.data.copy_(ep_param.data)
- continue
-
- # gather param from ep model
- param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
- dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
- all_param = torch.cat(param_list, dim=0)
- if assert_grad_flag:
- grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
- dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
- all_grad = torch.cat(grad_list, dim=0)
-
- # get tp param
- tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2][0] + 1
- tp_rank = get_ep_rank(tp_param)
- tp_slice = [slice(None)] * tp_dim + [
- slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
- ]
- new_tp_param = all_param[tuple(tp_slice)]
- if assert_grad_flag:
- new_grad = all_grad[tuple(tp_slice)]
- if assert_grad_flag:
- assert torch.allclose(tp_param, new_tp_param)
- assert torch.allclose(tp_param.grad, new_grad)
- else:
- tp_param.data.copy_(new_tp_param.data)
-
-
-def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
- """Sync the parameters of tp model from ep model
-
- Args:
- local_model (MoeModule)
- ep_model (MoeModule)
- """
- for (local_name, local_param), (ep_name, ep_param) in zip(
- local_model.named_parameters(), ep_model.named_parameters()
- ):
- assert local_name == ep_name
- if "experts" not in local_name:
- if assert_grad_flag:
- assert torch.allclose(local_param, ep_param)
- assert torch.allclose(local_param.grad, ep_param.grad)
- else:
- local_param.data.copy_(ep_param.data)
- continue
-
- # gather param from ep model
- param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
- dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
- all_param = torch.cat(param_list, dim=0)
- if assert_grad_flag:
- grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
- dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
- all_grad = torch.cat(grad_list, dim=0)
-
- if assert_grad_flag:
- assert torch.allclose(local_param, all_param)
- assert torch.allclose(local_param.grad, all_grad)
- else:
- local_param.data.copy_(all_param.data)
-
-
-def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
- assert batch_size % world_size == 0
-
- colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-
- MOE_MANAGER.__init__()
- MOE_MANAGER.setup(parallel=None)
- local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
- MOE_MANAGER.__init__()
- MOE_MANAGER.setup(parallel="EP")
- enable_hierarchical_comm = config.get("enable_hierarchical_comm", False)
- if enable_hierarchical_comm:
- os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
- ep_model = SparseMLP(
- num_experts=num_experts,
- hidden_size=dim,
- intermediate_size=dim * 2,
- enable_hierarchical_comm=enable_hierarchical_comm,
+from colossalai.booster.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import assert_loose_close
+
+NUM_BATCH = 4
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS = 4
+TOP_K = 2
+
+
+@parameterize("stage", [1])
+@parameterize("ep_size", [2])
+def run_zero_with_original_model(stage: int, ep_size: int):
+ tp_size = dist.get_world_size() // ep_size
+ dtype = torch.bfloat16
+
+ rank = torch.distributed.get_rank()
+ torch.cuda.set_device(dist.get_rank())
+
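+    # init model with the same seed on every rank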
+ seed_all(10086)
+
+ config = MixtralConfig(
+ hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+ intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+ num_hidden_layers=2,
+ num_attention_heads=NUM_HEADS,
+ num_key_value_heads=NUM_HEADS,
+ num_local_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOP_K,
)
- MOE_MANAGER.__init__()
- MOE_MANAGER.setup(parallel="TP")
- tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
- ep_model = ep_model.to(get_accelerator().get_current_device())
- tp_model = tp_model.to(get_accelerator().get_current_device())
- local_model = local_model.to(get_accelerator().get_current_device())
-
- # sync ep param
- sync_moe_model_param(ep_model)
- dist_dict = MOE_MANAGER.parallel_info_dict
- assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group)
- assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group)
- ep_grad_handler = MoeGradientHandler(ep_model)
- # sync local param
- sync_local_from_ep(local_model, ep_model)
- # sync tp param
- sync_tp_from_ep(tp_model, ep_model)
- tp_grad_handler = MoeGradientHandler(tp_model)
-
- rank = dist.get_rank()
- input_data = torch.randn(batch_size, dim, device=get_accelerator().get_current_device())
- micro_batch_size = batch_size // world_size
- index = rank * micro_batch_size
- # NOTE: ep & tp takes in sharded data for each process
- shard_data = input_data.detach()[index : index + micro_batch_size]
-
- out_local = local_model(input_data)
- MOE_MANAGER.reset_loss()
- out_tp = tp_model(shard_data)
- MOE_MANAGER.reset_loss()
- out_ep = ep_model(shard_data)
- MOE_MANAGER.reset_loss()
-
- assert torch.allclose(
- out_tp, out_ep, atol=1e-6
- ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}"
- try:
- out_local_slice = out_local[index : index + micro_batch_size]
- assert torch.allclose(
- out_ep, out_local_slice, atol=1e-6
- ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}"
- except AssertionError:
- """
- e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1
- router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2
- However, in ep mode, there are 2 separate routers dealing with sharded data.
- Assume router 0 handles token [01] and router 1 handles token [23].
- Note that for each router the capacity is only 1 !!!
- Thus, router 0 may yields [0] --> [0] or [1] --> [0], but not both.
- The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature.
- """
- warnings.warn(
- "EP & TP may result in different behavior from local model. " "Please check the comments for details."
+ torch_model = MixtralModel(config).to(dtype).cuda()
+
+ zero_model = deepcopy(torch_model).to(dtype)
+ zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
+ moe_booster = Booster(
+ plugin=MoeHybridParallelPlugin(
+ tp_size=tp_size,
+ moe_tp_size=tp_size,
+ pp_size=1,
+ ep_size=ep_size,
+ zero_stage=stage,
+ overlap_communication=False,
+ initial_scale=1,
)
-
- out_local.mean().backward()
- out_tp.mean().backward()
- tp_grad_handler.handle_gradient()
- out_ep.mean().backward()
- ep_grad_handler.handle_gradient()
-
- assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group)
- assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group)
- sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)
- try:
- sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
- except AssertionError:
- warnings.warn(
- "EP & TP may result in different behavior from local model. " "Please check the comments for details."
+ )
+ zero_model, zero_optimizer, _, _, _ = moe_booster.boost(zero_model, zero_optimizer)
+
+    hybrid_booster = Booster(
+ plugin=HybridParallelPlugin(
+ tp_size=tp_size,
+ pp_size=1,
+ zero_stage=stage,
+ overlap_communication=False,
+ initial_scale=1,
)
+ )
+    hybrid_model, hybrid_optimizer, _, _, _ = hybrid_booster.boost(
+ torch_model, torch.optim.SGD(torch_model.parameters(), lr=1)
+ )
+    # create different inputs per rank
+ seed_all(1453 + rank)
+
+ hybrid_model.train()
+ zero_model.train()
+ for _ in range(2):
+ # zero-dp forward
+ input_data = torch.rand(
+ NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+ ).cuda()
+ zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
+ # zero-dp backward
+ zero_optimizer.backward(zero_output)
+        # hybrid model forward
+ hybrid_output = hybrid_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
+ assert_loose_close(zero_output, hybrid_output, dtype=dtype)
+        # hybrid model backward
+ hybrid_optimizer.backward(hybrid_output)
+
+ # check grad
+ name_to_p = {n: p for n, p in hybrid_model.named_parameters()}
+ for n, p in zero_model.named_parameters():
+ zero_grad = zero_optimizer.get_param_grad(p)
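+            # routing may leave some experts without tokens, so their grads are None; zero-fill them for the later update check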
+ if name_to_p[n].grad is None:
+ name_to_p[n].grad = torch.zeros_like(name_to_p[n])
+ continue
+ if zero_grad.shape != name_to_p[n].grad.shape: # TODO check sharded and sliced moe
+ continue
+ assert_loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
+
+ # zero-dp step
+ zero_optimizer.step()
+
+        # hybrid model step
+ hybrid_optimizer.step()
+
+ # check updated param
+ for n, p in zero_model.named_parameters():
+ if p.data.shape != name_to_p[n].data.shape: # TODO check sharded and sliced moe
+ continue
+ assert_loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
+
+ print(f"{dist.get_rank()} test passed")
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_zero_with_original_model()
-@pytest.mark.skip(reason="moe need to be refactored")
+@pytest.mark.skip("tested in corresponding sharderformer")
@pytest.mark.dist
-@pytest.mark.parametrize("num_experts", [4, 64])
-@pytest.mark.parametrize("batch_size", [16])
-@pytest.mark.parametrize("dim", [64])
-@pytest.mark.parametrize(
- "config",
- [
- {"enable_hierarchical_comm": False},
- {"enable_hierarchical_comm": True},
- ],
-)
+@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
-def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
- spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)
+def test_moe_ep_tp(world_size):
+ spawn(run_dist, world_size)
if __name__ == "__main__":
- test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)
+ test_moe_ep_tp(world_size=4)
diff --git a/tests/test_moe/test_moe_ep_zero.py b/tests/test_moe/test_moe_ep_zero.py
new file mode 100644
index 000000000000..2d4e638b638a
--- /dev/null
+++ b/tests/test_moe/test_moe_ep_zero.py
@@ -0,0 +1,119 @@
+from copy import deepcopy
+
+import pytest
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralModel
+
+import colossalai
+from colossalai.booster.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import assert_loose_close
+
+NUM_BATCH = 4
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS = 2
+TOP_K = 1
+
+
+@parameterize("stage", [1])
+@parameterize("ep_size", [2, 4])
+def run_zero_with_original_model(stage: int, ep_size: int):
+ dtype = torch.bfloat16
+
+ rank = torch.distributed.get_rank()
+ torch.cuda.set_device(dist.get_rank())
+
+ plugin = MoeHybridParallelPlugin(
+ pp_size=1, tp_size=1, ep_size=ep_size, zero_stage=stage, overlap_communication=False, initial_scale=1
+ )
+ booster = Booster(plugin=plugin)
+
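+    # init model with the same seed on every rank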
+ seed_all(10086)
+
+ config = MixtralConfig(
+ hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+ intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+ num_hidden_layers=2,
+ num_attention_heads=NUM_HEADS,
+ num_key_value_heads=NUM_HEADS,
+ num_local_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOP_K,
+ )
+
+ torch_model = MixtralModel(config).to(dtype).cuda()
+
+ zero_model = deepcopy(torch_model).to(dtype)
+ zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
+
+ zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
+
+ ddp_model = DDP(
+ torch_model.cuda(),
+ process_group=plugin.dp_group,
+        find_unused_parameters=True,  # important for torch ddp: not every expert receives tokens on every step
+ ).cuda()
+ ddp_optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1)
+
+    # create different inputs per rank
+ seed_all(1453 + rank)
+
+ ddp_model.train()
+ zero_model.train()
+ for _ in range(2):
+ # zero-dp forward
+ input_data = torch.rand(
+ NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+ ).cuda()
+ zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
+ # zero-dp backward
+ zero_optimizer.backward(zero_output)
+
+ # torch-ddp forward
+ ddp_output = ddp_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
+ assert_loose_close(zero_output, ddp_output, dtype=dtype)
+ # torch-ddp backward
+ ddp_output.backward()
+
+ # check grad
+ name_to_p = {n: p for n, p in ddp_model.named_parameters()}
+ for n, p in zero_model.named_parameters():
+ zero_grad = zero_optimizer.get_param_grad(p)
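+            # routing may leave some experts without tokens, so their grads are None; zero-fill them for the later update check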
+ if name_to_p[n].grad is None:
+ name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
+ continue
+ assert_loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
+
+ # zero-dp step
+ zero_optimizer.step()
+
+ # original model step
+ ddp_optimizer.step()
+
+ # check updated param
+ for n, p in zero_model.named_parameters():
+ assert_loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
+
+ print(f"{dist.get_rank()} test passed")
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_zero_with_original_model()
+
+
+@pytest.mark.skip("tested in corresponding sharderformer")
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [4])
+@rerun_if_address_is_in_use()
+def test_moe_ep_zero(world_size):
+ spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+ test_moe_ep_zero(world_size=4)
diff --git a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py
deleted file mode 100644
index 042b3d8aedc5..000000000000
--- a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py
+++ /dev/null
@@ -1,132 +0,0 @@
-from copy import deepcopy
-
-import pytest
-import torch
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
-from transformers.models.mixtral.configuration_mixtral import MixtralConfig
-from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-
-import colossalai
-from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock
-from colossalai.tensor.moe_tensor.api import is_moe_tensor
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.testing.random import seed_all
-from colossalai.zero import LowLevelZeroOptimizer
-from tests.test_moe.moe_utils import loose_close
-
-tokens, n_experts = 7, 4
-hidden_size = 8
-top_k = 2
-
-
-def split_grad(grad, world_size):
- with torch.no_grad():
- grad = grad.clone().detach().flatten()
- padding_size = (world_size - grad.numel() % world_size) % world_size
- if padding_size > 0:
- grad = torch.nn.functional.pad(grad, [0, padding_size])
- splited_grad = grad.split(grad.numel() // world_size)
- return splited_grad
-
-
-@parameterize("dtype", [torch.float16, torch.bfloat16])
-@parameterize("master_weights", [True, False])
-@parameterize("stage", [1, 2])
-def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.dtype, stage: int):
- rank = torch.distributed.get_rank()
- torch.cuda.set_device(dist.get_rank())
- plugin = MoeHybridParallelPlugin(
- tp_size=1,
- pp_size=1,
- ep_size=dist.get_world_size() // 2,
- )
-
- seed_all(10086)
- config = MixtralConfig(
- hidden_size=hidden_size,
- intermediate_size=hidden_size * 2,
- num_local_experts=n_experts,
- num_experts_per_tok=top_k,
- )
-
- orig_model = MixtralSparseMoeBlock(config).to(dtype).cuda()
-
- ori_model = DDP(orig_model.cuda(), static_graph=True).cuda()
-
- zero_model = deepcopy(orig_model).to(dtype)
- zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group)
-
- zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
- pg_param_list = {plugin.global_dp_group: [], plugin.moe_dp_group: []}
- for p in zero_model.parameters():
- if is_moe_tensor(p):
- pg_param_list[plugin.moe_dp_group].append(p)
- else:
- pg_param_list[plugin.global_dp_group].append(p)
-
- zero_optimizer = LowLevelZeroOptimizer(
- zero_optimizer,
- pg_to_param_list=pg_param_list,
- master_weights=master_weights,
- initial_scale=1,
- overlap_communication=False,
- partition_grad=True,
- )
-
- ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1)
-
- # create
- seed_all(1453 + rank)
-
- for _ in range(2):
- # zero-dp forward
- input_data = torch.rand(1, tokens, hidden_size).cuda()
- zero_output, zero_logits = zero_model(input_data.to(dtype))
-
- # torch-ddp forward
- ori_output, ori_logits = ori_model(input_data.to(dtype))
- loose_close(zero_output, ori_output, dtype=dtype)
-
- # zero-dp backward
- zero_optimizer.backward(zero_output.mean().float())
-
- # torch-ddp backward
- ori_output.mean().backward()
-
- # check grad
- name_to_p = {n: p for n, p in ori_model.module.named_parameters()}
- for n, p in zero_model.named_parameters():
- zero_grad = zero_optimizer.get_param_grad(p)
- if name_to_p[n].grad is None:
- assert zero_grad is None
- continue
-
- loose_close(zero_grad, name_to_p[n].grad, dtype=dtype)
-
- # zero-dp step
- zero_optimizer.step()
-
- # original model step
- ori_optimizer.step()
-
- # check updated param
- for n, p in zero_model.named_parameters():
- loose_close(p.data, name_to_p[n].data, dtype=dtype)
-
-
-def run_dist(rank, world_size, port):
- colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- run_zero_with_original_model(world_size=world_size)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2, 4])
-@rerun_if_address_is_in_use()
-def test_moe_zero_model(world_size):
- spawn(run_dist, world_size)
-
-
-if __name__ == "__main__":
- test_moe_zero_model(world_size=4)
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index 1ffcc541a854..190fee12931b 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -1,6 +1,6 @@
import copy
from contextlib import nullcontext
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Type
import torch
import torch.distributed as dist
@@ -117,7 +117,12 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""):
def build_model_from_hybrid_plugin(
- model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam
+ model_fn: Callable,
+ loss_fn: Callable,
+ test_config: Dict[str, Any],
+ optim_class=Adam,
+ sharded_optim_class=Adam,
+ pluggin_cls: Type[HybridParallelPlugin] = HybridParallelPlugin,
):
use_lazy_init = False
if "use_lazy_init" in test_config:
@@ -149,9 +154,10 @@ def build_model_from_hybrid_plugin(
else:
org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
+
criterion = loss_fn
- plugin = HybridParallelPlugin(**test_config)
+ plugin = pluggin_cls(**test_config)
booster = Booster(plugin=plugin)
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
new file mode 100644
index 000000000000..46da4522fd9d
--- /dev/null
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -0,0 +1,196 @@
+import os
+import shutil
+from copy import deepcopy
+from typing import Tuple
+
+import pytest
+import torch
+import torch.distributed
+import torch.distributed as dist
+from transformers import AutoConfig, AutoModel
+
+import colossalai
+from colossalai.booster.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
+
+NUM_BATCH = 8
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 2
+NUM_LAYERS = 4
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS = 4
+TOP_K = 2
+
+
+CHECKED_CONFIG = [ # FOR_WORLD=4
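+    # each tuple: (zero_stage, ep_size, pp_size, tp_size, sp_size)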
+ (1, 4, 1, 1, 1),
+ (1, 1, 4, 1, 1),
+ (1, 1, 1, 4, 1),
+ (1, 1, 1, 1, 4),
+ (0, 1, 4, 1, 1),
+ (0, 1, 1, 4, 1),
+ (0, 1, 1, 1, 4),
+ (1, 2, 1, 1, 1),
+]
+
+
+@parameterize(
+ "config",
+ [
+ (1, 2, 2, 1, 1),
+ (1, 2, 1, 2, 1),
+ (1, 2, 1, 1, 2),
+ ],
+)
+def run_zero_with_original_model(config: Tuple[int, ...]):
+ stage, ep_size, pp_size, tp_size, sp_size = config
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ dtype, precision = torch.float16, "fp16"
+ torch.cuda.set_device(dist.get_rank())
+
+ plugin = MoeHybridParallelPlugin(
+ pp_size=pp_size,
+ num_microbatches=pp_size,
+ tp_size=tp_size,
+ sp_size=sp_size,
+ ep_size=ep_size,
+ zero_stage=stage,
+ enable_sequence_parallelism=sp_size > 1,
+ sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
+ enable_flash_attention=sp_size > 1,
+ overlap_communication=False,
+ initial_scale=1,
+ precision=precision,
+ find_unused_parameters=True,
+ )
+ dp_size = plugin.dp_size
+
+ booster = Booster(plugin=plugin)
+
+ assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
+ config = AutoConfig.from_pretrained(
+ "deepseek-ai/deepseek-moe-16b-base",
+ hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+ intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+ moe_intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+ num_hidden_layers=4,
+ num_attention_heads=NUM_HEADS,
+ num_key_value_heads=NUM_HEADS,
+ first_k_dense_replace=1,
+ attn_implementation="flash_attention_2",
+ torch_dtype="float16",
+ n_routed_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOP_K,
+ trust_remote_code=True,
+ )
+
+ # init model with the same seed
+ seed_all(10086)
+
+ torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)
+ torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
+
+ parallel_model = deepcopy(torch_model)
+ parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
+ parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
+
+    # create different inputs along the dp axis
+ seed_all(1453 + rank)
+
+ torch_model.train()
+ parallel_model.train()
+ for _ in range(2):
+ # gen random input
+ input_embeddings = torch.rand(
+ NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+ ).cuda()
+ dist.all_reduce(
+ input_embeddings, group=plugin.pp_group
+        )  # pp inputs at stages other than the first don't matter, but they must be replicated for the torch model check
+
+ dist.all_reduce(input_embeddings, group=plugin.tp_group) # tp group duplicate input
+ dist.all_reduce(input_embeddings, group=plugin.sp_group) # sp group duplicate input
+
+ # run the model with hybrid parallel
+ if booster.plugin.stage_manager is not None:
+ # for test with pp
+ data_iter = iter([{"inputs_embeds": input_embeddings}])
+ sharded_output = booster.execute_pipeline(
+ data_iter,
+ parallel_model,
+ lambda x, y: x[0].mean(),
+ parallel_optimizer,
+ return_loss=True,
+ return_outputs=True,
+ )
+ if booster.plugin.stage_manager.is_last_stage():
+ parallel_output = sharded_output["loss"]
+ else:
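+                # placeholder on non-last pp stages; overwritten by the broadcast from the last stage below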
+ parallel_output = torch.tensor(12345.0, device="cuda")
+
+ # broadcast along pp axis
+ dist.broadcast(
+ parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[-1], group=plugin.pp_group
+ )
+ else:
+ # for test without pp
+ parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
+ parallel_optimizer.backward(parallel_output)
+ parallel_optimizer.step()
+ parallel_optimizer.zero_grad()
+ dist.all_reduce(parallel_output, group=plugin.dp_group)
+
+ # ===================================================================================
+        # run the unsharded model on all (different) dp inputs
+ all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
+ dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+ torch_output_sum = 0
+ for input_data_ in all_inputs:
+ torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
+ torch_output.backward()
+ torch_output_sum += torch_output.detach()
+        # average grads across dp to match the zero optimizer's behavior
+ for p in torch_model.parameters():
+ if p.grad is not None:
+ p.grad /= dp_size
+ torch_optimizer.step()
+ torch_optimizer.zero_grad()
+
+ assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
+
+    # save the sharded model, then reload it to verify checkpoint correctness
+ model_dir = "./test_deepseek"
+ if rank == world_size - 1:
+ os.makedirs(model_dir, exist_ok=True)
+
+ dist.barrier()
+ booster.save_model(parallel_model, model_dir, shard=True)
+ dist.barrier()
+
+ saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
+ check_model_equal(torch_model, saved_model)
+ dist.barrier()
+
+ if rank == world_size - 1:
+ shutil.rmtree(model_dir)
+
+ print(f"rank {dist.get_rank()} test passed")
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_zero_with_original_model()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [4])
+@rerun_if_address_is_in_use()
+def test_deepseek(world_size):
+ spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+ test_deepseek(world_size=4)
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
new file mode 100644
index 000000000000..de09eedcbed5
--- /dev/null
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -0,0 +1,190 @@
+import os
+import shutil
+from copy import deepcopy
+from typing import Tuple
+
+import pytest
+import torch
+import torch.distributed
+import torch.distributed as dist
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralModel
+
+import colossalai
+from colossalai.booster.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
+
+NUM_BATCH = 8
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
+NUM_LAYERS = 4
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS = 4
+TOP_K = 1
+
+CHECKED_CONFIG = [ # FOR WORLD=4
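+    # each tuple: (zero_stage, ep_size, pp_size, tp_size, sp_size)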
+ (0, 1, 4, 1, 1),
+ (0, 1, 1, 4, 1),
+ (0, 1, 1, 1, 4),
+ (1, 4, 1, 1, 1),
+ (1, 1, 4, 1, 1),
+ (1, 1, 1, 4, 1),
+ (1, 1, 1, 1, 4),
+ (1, 2, 1, 1, 1),
+]
+
+
+@parameterize(
+ "config",
+ [
+ (1, 2, 2, 1, 1),
+ (1, 2, 1, 2, 1),
+ (1, 2, 1, 1, 2),
+ ],
+)
+def run_zero_with_original_model(config: Tuple[int, ...]):
+ stage, ep_size, pp_size, tp_size, sp_size = config
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ dtype, precision = torch.float16, "fp16"
+ torch.cuda.set_device(dist.get_rank())
+
+ plugin = MoeHybridParallelPlugin(
+ pp_size=pp_size,
+ num_microbatches=pp_size,
+ tp_size=tp_size,
+ sp_size=sp_size,
+ ep_size=ep_size,
+ zero_stage=stage,
+ enable_sequence_parallelism=sp_size > 1,
+ sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
+ overlap_communication=False,
+ initial_scale=1,
+ precision=precision,
+ find_unused_parameters=True,
+ )
+ dp_size = plugin.dp_size
+
+ booster = Booster(plugin=plugin)
+
+ assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
+ config = MixtralConfig(
+ hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+ intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+ num_hidden_layers=NUM_LAYERS,
+ num_attention_heads=NUM_HEADS,
+ num_key_value_heads=NUM_HEADS,
+ num_local_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOP_K,
+ attn_implementation="flash_attention_2",
+ )
+
+ # init model with the same seed
+ seed_all(10086)
+
+ torch_model = MixtralModel(config).to(dtype).cuda()
+ torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
+
+ parallel_model = deepcopy(torch_model)
+ parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
+ parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
+
+    # create different inputs along the dp axis
+ seed_all(1453 + rank)
+
+ torch_model.train()
+ parallel_model.train()
+ for _ in range(2):
+ # gen random input
+ input_embeddings = torch.rand(
+ NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+ ).cuda()
+ dist.all_reduce(
+ input_embeddings, group=plugin.pp_group
+        )  # pp inputs at stages other than the first don't matter, but they must be replicated for the torch model check
+
+ dist.all_reduce(input_embeddings, group=plugin.tp_group) # tp group duplicate input
+ dist.all_reduce(input_embeddings, group=plugin.sp_group) # sp group duplicate input
+
+ # run the model with hybrid parallel
+ if booster.plugin.stage_manager is not None:
+ # for test with pp
+ data_iter = iter([{"inputs_embeds": input_embeddings}])
+ sharded_output = booster.execute_pipeline(
+ data_iter,
+ parallel_model,
+ lambda x, y: x.last_hidden_state.mean(),
+ parallel_optimizer,
+ return_loss=True,
+ return_outputs=True,
+ )
+ if booster.plugin.stage_manager.is_last_stage():
+ parallel_output = sharded_output["loss"]
+ else:
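+                # placeholder on non-last pp stages; overwritten by the broadcast from the last stage below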
+ parallel_output = torch.tensor(12345.0, device="cuda")
+
+ # broadcast along pp axis
+ dist.broadcast(
+ parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[-1], group=plugin.pp_group
+ )
+ else:
+ # for test without pp
+ parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
+ parallel_optimizer.backward(parallel_output)
+ parallel_optimizer.step()
+ parallel_optimizer.zero_grad()
+ dist.all_reduce(parallel_output, group=plugin.dp_group)
+
+ # ===================================================================================
+ # run normal model with all dp(different) inputs
+ all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
+ dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+ torch_output_sum = 0
+ for input_data_ in all_inputs:
+ torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
+ torch_output.backward()
+ torch_output_sum += torch_output.detach()
+        # average grads across dp to match the zero optimizer's behavior
+ for p in torch_model.parameters():
+ if p.grad is not None:
+ p.grad /= dp_size
+ torch_optimizer.step()
+ torch_optimizer.zero_grad()
+
+ assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
+
+    # save the sharded model, then reload it to verify checkpoint correctness
+ model_dir = "./test_mixtral"
+ if rank == world_size - 1:
+ os.makedirs(model_dir, exist_ok=True)
+
+ dist.barrier()
+ booster.save_model(parallel_model, model_dir, shard=True)
+ dist.barrier()
+
+ saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
+ check_model_equal(torch_model, saved_model)
+ dist.barrier()
+
+ if rank == world_size - 1:
+ shutil.rmtree(model_dir)
+
+ print(f"rank {dist.get_rank()} test passed")
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_zero_with_original_model()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [4])
+@rerun_if_address_is_in_use()
+def test_mixtral(world_size):
+ spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+ test_mixtral(world_size=4)
diff --git a/version.txt b/version.txt
index 267577d47e49..2b7c5ae01848 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.4.1
+0.4.2