[Chat] Rlhf support SimPO #5850 (Merged)

Changes from all commits (17 commits):
82aecd6  add SimPO (YeAnbang)
4b59d87  Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into main (YeAnbang)
0b2d627  fix dataloader (YeAnbang)
f3de5a0  remove debug code (YeAnbang)
c8d1b4a  add orpo (YeAnbang)
8aad064  fix style (YeAnbang)
384c640  fix colossalai, transformers version (YeAnbang)
afa5306  fix colossalai, transformers version (YeAnbang)
b117274  fix colossalai, transformers version (YeAnbang)
e752776  Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into r… (YeAnbang)
a8af6cc  fix torch colossalai version (YeAnbang)
ff53520  update transformers version (YeAnbang)
16f3451  Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into r… (YeAnbang)
d888c37  add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Sup… (YeAnbang)
f6ef5c3  fix style (YeAnbang)
33f1520  Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into r… (YeAnbang)
8a9721b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
New file (+340 lines):

@@ -0,0 +1,340 @@
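# Benchmark script for preference optimization (DPO and, with this PR, SimPO):
# it trains via DPOTrainer on a synthetic DummyLLMDataset so that throughput and
# memory usage can be measured without preparing a real preference dataset.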
import argparse
import json
import os
import resource
from contextlib import nullcontext

import torch
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler
from coati.models import convert_to_lora_module, disable_dropout
from coati.trainer import DPOTrainer
from coati.utils import load_checkpoint
from dummy_dataset import DummyLLMDataset
from transformers import AutoModelForCausalLM, AutoTokenizer

import colossalai
from colossalai.booster import Booster

# TorchDDPPlugin is imported as well, since the "ddp" branch below instantiates it
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam

logger = get_dist_logger()


def train(args):
    # check lora compatibility
    if "gemini" in args.plugin and args.lora_rank > 0:
        raise ValueError("LoRA is not supported in GeminiPlugin. Please use another plugin.")
    if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
        raise ValueError("Gradient accumulation is not supported with the gemini_auto plugin. Please use another plugin.")

    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()

    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == "ddp":
        """
        Default torch DDP plugin without any acceleration, for debugging purposes
        """
        plugin = TorchDDPPlugin(find_unused_parameters=True)
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="static",
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_gradient_accumulation=True,
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="auto",
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "zero2_cpu":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            cpu_offload=True,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            sp_size=args.sp,
            sequence_parallelism_mode=args.sp_mode,
            zero_stage=args.zero_stage,
            enable_flash_attention=args.use_flash_attn,
            enable_sequence_parallelism=args.enable_sequence_parallelism,
            cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
            parallel_output=False,
            max_norm=args.grad_clip,
            precision=args.mixed_precision,
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")
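
    # Plugin summary (as these ColossalAI plugins behave at the time of this PR):
    # "gemini"/"gemini_auto" use chunk-based heterogeneous memory management
    # (ZeRO-3-style), "zero2"/"zero2_cpu" use ZeRO stage 2 with optional CPU
    # offload of optimizer states, and "3d" combines tensor/pipeline/sequence
    # parallelism with an optional ZeRO stage.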

    booster = Booster(plugin=plugin)
    ref_booster = Booster(plugin=plugin)

    # ======================================================
    # Initialize Model, Objective, Optimizer and LR Scheduler
    # ======================================================
    # Temp Fix: Disable lazy init due to version conflict
    # init_ctx = (
    #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
    # )

    init_ctx = nullcontext()
    with init_ctx:
        if args.use_flash_attn:
            model = AutoModelForCausalLM.from_pretrained(
                args.pretrain,
                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                use_flash_attention_2=True,
            )
            coordinator.print_on_master(msg="Flash-attention enabled successfully")
        else:
            model = AutoModelForCausalLM.from_pretrained(args.pretrain)
        disable_dropout(model)
        if not args.disable_reference_model:
            if args.use_flash_attn:
                ref_model = AutoModelForCausalLM.from_pretrained(
                    args.pretrain,
                    torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                    use_flash_attention_2=True,
                )
            else:
                ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
            disable_dropout(ref_model)
        else:
            ref_model = None
        if args.lora_rank > 0:
            model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)

    if args.grad_checkpoint:
        # Note: for some models, LoRA may not be compatible with gradient checkpointing
        model.gradient_checkpointing_enable()
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
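
    # Dropout is disabled on both the policy and the reference model above: a
    # common practice for DPO-style training, keeping chosen/rejected
    # log-probabilities deterministic and comparable across the two models.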

    # configure tokenizer
    tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
    if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
        try:
            # Some tokenizers (e.g., Qwen) do not allow setting pad_token manually
            tokenizer.pad_token = tokenizer.eos_token
        except AttributeError as e:
            logger.warning(f"Unable to set pad token to eos token, {str(e)}")
    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
        logger.warning(
            "The tokenizer does not have a pad token, which is required. This may lead to unintended behavior during training; please consider setting it manually."
        )

    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False
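
    # BOS/EOS auto-insertion is turned off, presumably because preference samples
    # are expected to arrive already tokenized with special tokens included.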

    # configure optimizer
    optim = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    # configure dataset
    # mode_map is carried over from the real training script and is unused here
    mode_map = {"train": "train", "valid": "validation", "test": "test"}
    train_dataset = DummyLLMDataset(
        ["chosen_input_ids", "chosen_loss_mask", "rejected_input_ids", "rejected_loss_mask"],
        args.max_length,
        args.dataset_size,
    )
    data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)

    train_dataloader = plugin.prepare_dataloader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=data_collator,
        distributed_sampler_cls=StatefulDistributedSampler,
    )
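
    # StatefulDistributedSampler tracks the position within the epoch, so a
    # resumed run (see the checkpoint-loading block below) can skip already-seen
    # samples via set_start_index.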

    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
    if args.warmup_steps is None:
        # default to warming up over 2.5% of the total number of update steps
        args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
        coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")

    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optim,
        total_steps=args.max_epochs * num_update_steps_per_epoch,
        warmup_steps=args.warmup_steps,
        eta_min=0.1 * args.lr,
    )

    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    model, optim, _, train_dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optim,
        lr_scheduler=lr_scheduler,
        dataloader=train_dataloader,
    )
    if ref_model is not None:
        ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
    torch.set_default_dtype(torch.float)

    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
    )
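
    # Note that the reference model is boosted with its own Booster above: it is
    # wrapped by the same parallelism plugin as the policy but carries no
    # optimizer or LR scheduler, since it only provides frozen log-probabilities.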

    start_epoch = 0
    sampler_start_idx = 0
    start_step = 0
    if args.checkpoint_path is not None:
        if "modeling" in args.checkpoint_path:
            coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
            booster.load_model(model, args.checkpoint_path)
        else:
            coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
            start_epoch, start_step, sampler_start_idx = load_checkpoint(
                load_dir=args.checkpoint_path,
                booster=booster,
                model=model,
                optimizer=optim,
                lr_scheduler=lr_scheduler,
            )
            assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
            train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)

            coordinator.print_on_master(
                f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
            )
            coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")

        coordinator.print_on_master(
            f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
        )

    trainer = DPOTrainer(
        actor=model,
        ref_model=ref_model,
        booster=booster,
        actor_optim=optim,
        actor_lr_scheduler=lr_scheduler,
        tokenizer=tokenizer,
        max_epochs=args.max_epochs,
        accumulation_steps=args.accumulation_steps,
        start_epoch=start_epoch,
        save_interval=None,
        save_dir=None,
        coordinator=coordinator,
        beta=args.beta,
        gamma=args.gamma,
        length_normalization=args.length_normalization,
    )
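
    # SimPO (Meng et al., 2024) is realized here as a variant of the DPO objective:
    #   DPO:   -log sigmoid(beta * [log pi(y_w|x)/pi_ref(y_w|x) - log pi(y_l|x)/pi_ref(y_l|x)])
    #   SimPO: -log sigmoid((beta/|y_w|) * log pi(y_w|x) - (beta/|y_l|) * log pi(y_l|x) - gamma)
    # i.e., length-normalized policy log-probabilities with a target reward margin
    # gamma and no reference model; this is what the beta, gamma, and
    # length_normalization arguments above (plus --disable_reference_model) control.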

    trainer.fit(
        train_preference_dataloader=train_dataloader,
        eval_preference_dataloader=None,
        log_dir=None,
        use_wandb=False,
    )
    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
        help="Choose which plugin to use",
    )
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--pp", type=int, default=1)
    parser.add_argument("--sp", type=int, default=1)
    parser.add_argument("--loss_type", type=str, default="dpo_loss", help="dpo_loss or simpo_loss")
    parser.add_argument("--beta", type=float, default=0.1, help="beta in DPO loss")
    parser.add_argument("--gamma", type=float, default=0.0, help="gamma in SimPO loss")
    parser.add_argument("--length_normalization", default=False, action="store_true")
    parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
    parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
    parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
    parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--model_type", type=str, default=None)
    parser.add_argument("--tokenizer_dir", type=str, default=None)
    parser.add_argument(
        "--checkpoint_path", type=str, default=None, help="Checkpoint path if resuming training from a checkpoint"
    )
    parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
    parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
    parser.add_argument("--max_epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--dataset_size", type=int, default=500)
    parser.add_argument(
        "--disable_reference_model",
        action="store_true",
        default=False,
        help="Disable the reference model (enabled by default)",
    )
    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument(
        "--lora_train_bias",
        type=str,
        default="none",
        help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
    )
    parser.add_argument("--merge_lora_weights", type=bool, default=True)
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--accumulation_steps", type=int, default=8)
    parser.add_argument("--grad_checkpoint", default=False, action="store_true")
    parser.add_argument("--use_flash_attn", default=False, action="store_true")
    args = parser.parse_args()

    # fool-proof hyperparameter setup: SimPO requires length normalization
    # and a positive target reward margin gamma
    if args.loss_type == "simpo_loss":
        args.length_normalization = True
        args.gamma = args.gamma if args.gamma > 0 else 1.4
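
    # The 1.4 fallback for gamma appears to be a mid-range choice; the SimPO
    # paper tunes the target reward margin per model, with reported values
    # roughly between 0.25 and 1.6.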

    # guard against a bare filename (dirname would be empty, and makedirs("") raises)
    if os.path.dirname(args.config_file):
        os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
    with open(args.config_file, "w") as f:
        json.dump(args.__dict__, f, indent=4)
    train(args)
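
A minimal sketch of how this benchmark might be launched (the file name and model path are placeholders, not taken from this PR; colossalai.launch_from_torch() picks up the environment variables set by torchrun):

    torchrun --standalone --nproc_per_node=8 benchmark_simpo.py \
        --pretrain /path/to/model \
        --loss_type simpo_loss \
        --disable_reference_model \
        --plugin zero2 \
        --batch_size 4 \
        --max_length 2048 \
        --config_file ./config/benchmark.json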