diff --git a/applications/Chat/benchmarks/benchmark_large_bloom_lora_dummy.py b/applications/Chat/benchmarks/benchmark_large_bloom_lora_dummy.py new file mode 100644 index 000000000000..1c04474a01cd --- /dev/null +++ b/applications/Chat/benchmarks/benchmark_large_bloom_lora_dummy.py @@ -0,0 +1,236 @@ +import argparse +import os +import resource +from contextlib import contextmanager +from copy import deepcopy + +import psutil +import torch +import torch.distributed as dist +import torch.nn as nn +from coati.models.base import RewardModel +from coati.models.bloom import BLOOMActor, BLOOMCritic +from coati.trainer import PPOTrainer +from coati.trainer.callbacks import PerformanceEvaluator +from coati.trainer.strategies import ColossalAIStrategy, Strategy, TPZeroStrategy +from torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import AutoTokenizer +from transformers.modeling_utils import no_init_weights +from transformers.models.bloom.configuration_bloom import BloomConfig + +from colossalai.nn.optimizer import HybridAdam + + +def get_model_numel(model: nn.Module, strategy: Strategy) -> int: + numel = sum(p.numel() for p in model.parameters()) + if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init: + numel *= dist.get_world_size() + return numel + + +def preprocess_batch(samples) -> dict: + input_ids = torch.stack(samples) + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + return {'input_ids': input_ids, 'attention_mask': attention_mask} + + +def preprocess_ptx_batch(samples) -> dict: + batch = preprocess_batch(samples) + batch['labels'] = batch['input_ids'] + return batch + + +def print_rank_0(*args, **kwargs) -> None: + if dist.get_rank() == 0: + print(*args, **kwargs) + + +def get_max_memory() -> int: + return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + + +@contextmanager +def low_precision_init(target_dtype: torch.dtype = torch.float16): + dtype = torch.get_default_dtype() + try: + torch.set_default_dtype(target_dtype) + yield + finally: + torch.set_default_dtype(dtype) + + +def print_model_numel(model_dict: dict) -> None: + B = 1024**3 + M = 1024**2 + K = 1024 + outputs = '' + for name, numel in model_dict.items(): + outputs += f'{name}: ' + if numel >= B: + outputs += f'{numel / B:.2f} B\n' + elif numel >= M: + outputs += f'{numel / M:.2f} M\n' + elif numel >= K: + outputs += f'{numel / K:.2f} K\n' + else: + outputs += f'{numel}\n' + print_rank_0(outputs) + + +def get_gpt_config(model_name: str) -> BloomConfig: + model_map = { + '350m': BloomConfig(hidden_size=1024, n_layer=24, n_head=16), + '560m': BloomConfig.from_pretrained('bigscience/bloom-560m'), + '1.1b': BloomConfig.from_pretrained('bigscience/bloom-1b1'), + '1.7b': BloomConfig.from_pretrained('bigscience/bloom-1b7'), + '3b': BloomConfig.from_pretrained('bigscience/bloom-3b'), + '7b': BloomConfig.from_pretrained('bigscience/bloom-7b1'), + '66b': BloomConfig(hidden_size=9216, n_layer=64, n_head=72), + '175b': BloomConfig(hidden_size=12288, n_layer=96, n_head=128), + } + try: + return model_map[model_name] + except KeyError: + raise ValueError(f'Unknown model "{model_name}"') + + +def main(args): + if args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) + elif args.strategy == 'colossalai_gemini_cpu': + strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5) + elif args.strategy == 'colossalai_gemini_reshard': + strategy = 
ColossalAIStrategy(stage=3, placement_policy='cuda_reshard', initial_scale=2**5) + elif args.strategy == 'tp_zero2': + strategy = TPZeroStrategy(args.tp_size, zero_stage=2, initial_scale=2**5) + elif args.strategy == 'tp_zero2_cpu': + strategy = TPZeroStrategy(args.tp_size, zero_stage=2, initial_scale=2**5, cpu_offload=True) + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac) + + model_config = get_gpt_config(args.model) + critic_config = get_gpt_config(args.critic_model) + with strategy.model_init_context(), low_precision_init(): + actor = BLOOMActor(config=model_config, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + critic = BLOOMCritic(config=critic_config, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + + initial_model = BLOOMActor(config=model_config, checkpoint=args.grad_checkpoint) + reward_model = BLOOMCritic(config=critic_config, checkpoint=args.grad_checkpoint) + reward_model = RewardModel(reward_model.model, reward_model.value_head) + + if args.use_kernels: + from coati.kernels import convert_to_xformer_model + actor, critic, initial_model, reward_model = map(convert_to_xformer_model, + (actor, critic, initial_model, reward_model)) + + actor_numel = get_model_numel(actor, strategy) + critic_numel = get_model_numel(critic, strategy) + initial_model_numel = get_model_numel(initial_model, strategy) + reward_model_numel = get_model_numel(reward_model, strategy) + print_model_numel({ + 'Actor': actor_numel, + 'Critic': critic_numel, + 'Initial model': initial_model_numel, + 'Reward model': reward_model_numel + }) + performance_evaluator = PerformanceEvaluator(actor_numel, + critic_numel, + initial_model_numel, + reward_model_numel, + enable_grad_checkpoint=False, + ignore_episodes=1) + + actor_optim = HybridAdam(actor.parameters(), lr=5e-6) + critic_optim = HybridAdam(critic.parameters(), lr=5e-6) + + tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') + tokenizer.pad_token = tokenizer.eos_token + + with low_precision_init(): + (actor, actor_optim), (critic, critic_optim), initial_model, reward_model = strategy.prepare( + (actor, actor_optim), (critic, critic_optim), initial_model, reward_model) + + print_rank_0(f'Mem after prepare: {psutil.Process(os.getpid()).memory_full_info().rss /1024**3:.2f} GB') + print_rank_0(f'CUDA Mem after prepare: {torch.cuda.memory_allocated() / 1024**3:.2f} GB') + # TODO(ver217): load checkpoint here + + trainer = PPOTrainer(strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + ptx_coef=args.ptx_coef, + max_epochs=args.max_epochs, + train_batch_size=args.train_batch_size, + offload_inference_models=args.offload_inference_models, + max_length=512, + do_sample=True, + temperature=1.0, + top_k=50, + use_cache=True, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + callbacks=[performance_evaluator]) + + random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device()) + ptx_prompts = torch.randint(tokenizer.vocab_size, (1000, 512), device=torch.cuda.current_device()) + dataloader = DataLoader(random_prompts, + batch_size=args.experience_batch_size, + shuffle=True, + collate_fn=preprocess_batch) + ptx_dataloader = DataLoader(ptx_prompts, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=preprocess_ptx_batch) + + trainer.fit(dataloader, + ptx_dataloader, + num_episodes=args.num_episodes, + 
max_timesteps=args.update_timesteps, + update_timesteps=args.update_timesteps) + + print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB') + print_rank_0(f'Peak Mem: {get_max_memory()/1024**2:.2f} GB') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-m', '--model', default='350m') + parser.add_argument('-c', '--critic_model', default='350m') + parser.add_argument('-s', + '--strategy', + choices=[ + 'colossalai_gemini', + 'colossalai_gemini_reshard', + 'colossalai_gemini_cpu', + 'tp_zero2', + 'tp_zero2_cpu', + ], + default='colossalai_gemini_reshard') + parser.add_argument('-t', '--tp_size', type=int, default=1) + parser.add_argument('-e', '--num_episodes', type=int, default=3) + parser.add_argument('-u', '--update_timesteps', type=int, default=1) + parser.add_argument('--max_epochs', type=int, default=1) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('-l', '--lora_rank', type=int, default=0) + parser.add_argument('--cuda_mem_frac', type=float, default=1.0) + parser.add_argument('-o', '--offload_inference_models', action='store_true', default=False) + parser.add_argument('-k', + '--use_kernels', + action='store_true', + default=False, + help='This uses xformers kernels, which can save memory and accelerate training.') + parser.add_argument('-g', + '--grad_checkpoint', + default=False, + action='store_true', + help='This uses gradient checkpointing, which can save memory and slow down training.') + parser.add_argument('-p', '--ptx_coef', type=float, default=0.0) + args = parser.parse_args() + main(args) diff --git a/applications/Chat/benchmarks/benchmark_large_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_large_opt_lora_dummy.py new file mode 100644 index 000000000000..49147d3950d6 --- /dev/null +++ b/applications/Chat/benchmarks/benchmark_large_opt_lora_dummy.py @@ -0,0 +1,210 @@ +import argparse +from copy import deepcopy + +import torch +import torch.distributed as dist +import torch.nn as nn +from coati.models.base import RewardModel +from coati.models.opt import OPTActor, OPTCritic +from coati.trainer import PPOTrainer +from coati.trainer.callbacks import PerformanceEvaluator +from coati.trainer.strategies import ColossalAIStrategy, Strategy, TPZeroStrategy +from torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import AutoTokenizer +from transformers.modeling_utils import no_init_weights +from transformers.models.opt.configuration_opt import OPTConfig + +from colossalai.nn.optimizer import HybridAdam + + +def get_model_numel(model: nn.Module, strategy: Strategy) -> int: + numel = sum(p.numel() for p in model.parameters()) + if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init: + numel *= dist.get_world_size() + return numel + + +def preprocess_batch(samples) -> dict: + input_ids = torch.stack(samples) + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + return {'input_ids': input_ids, 'attention_mask': attention_mask} + + +def print_rank_0(*args, **kwargs) -> None: + if dist.get_rank() == 0: + print(*args, **kwargs) + + +def print_model_numel(model_dict: dict) -> None: + B = 1024**3 + M = 1024**2 + K = 1024 + outputs = '' + for name, numel in model_dict.items(): + outputs += f'{name}: ' + if numel >= B: + outputs += f'{numel / B:.2f} B\n' + elif numel >= M: + outputs += f'{numel / M:.2f} M\n' + elif numel >= K: 
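+ # note: B/M/K above are binary multiples (1024**3, 1024**2, 1024), so a '1.00 B' entry means 2**30 parameters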
+ outputs += f'{numel / K:.2f} K\n' + else: + outputs += f'{numel}\n' + print_rank_0(outputs) + + +def get_gpt_config(model_name: str) -> OPTConfig: + model_map = { + '125m': OPTConfig.from_pretrained('facebook/opt-125m'), + '350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16), + '700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20), + '1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'), + '2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'), + '3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32), + '5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32), + '6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'), + '10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32), + '13b': OPTConfig.from_pretrained('facebook/opt-13b'), + } + try: + return model_map[model_name] + except KeyError: + raise ValueError(f'Unknown model "{model_name}"') + + +def main(args): + if args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) + elif args.strategy == 'colossalai_gemini_cpu': + strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5) + elif args.strategy == 'colossalai_gemini_reshard': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda_reshard', initial_scale=2**5) + elif args.strategy == 'tp_zero2': + strategy = TPZeroStrategy(args.tp_size, zero_stage=2, initial_scale=2**5) + elif args.strategy == 'tp_zero2_cpu': + strategy = TPZeroStrategy(args.tp_size, zero_stage=2, initial_scale=2**5, cpu_offload=True) + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac) + + model_config = get_gpt_config(args.model) + critic_config = get_gpt_config(args.critic_model) + with strategy.model_init_context(), no_init_weights(): + actor = OPTActor(config=model_config, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + actor.model.tie_weights() + critic = OPTCritic(config=critic_config, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + critic.model.tie_weights() + + initial_model = OPTActor(config=model_config, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + initial_model.model.tie_weights() + reward_model = OPTCritic(config=critic_config, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + reward_model.model.tie_weights() + reward_model = RewardModel(reward_model.model, reward_model.value_head) + + if args.use_kernels: + from coati.kernels import convert_to_xformer_model + actor, critic, initial_model, reward_model = map(convert_to_xformer_model, + (actor, critic, initial_model, reward_model)) + + actor_numel = get_model_numel(actor, strategy) + critic_numel = get_model_numel(critic, strategy) + initial_model_numel = get_model_numel(initial_model, strategy) + reward_model_numel = get_model_numel(reward_model, strategy) + print_model_numel({ + 'Actor': actor_numel, + 'Critic': critic_numel, + 'Initial model': initial_model_numel, + 'Reward model': reward_model_numel + }) + performance_evaluator = PerformanceEvaluator(actor_numel, + critic_numel, + initial_model_numel, + reward_model_numel, + enable_grad_checkpoint=False, + ignore_episodes=1) + + if args.strategy.startswith('colossalai'): + actor_optim = HybridAdam(actor.parameters(), lr=5e-6) + critic_optim = 
HybridAdam(critic.parameters(), lr=5e-6) + else: + actor_optim = Adam(actor.parameters(), lr=5e-6) + critic_optim = Adam(critic.parameters(), lr=5e-6) + + tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') + tokenizer.pad_token = tokenizer.eos_token + + (actor, actor_optim), (critic, critic_optim), initial_model, reward_model = strategy.prepare( + (actor, actor_optim), (critic, critic_optim), initial_model, reward_model) + + # TODO(ver217): load checkpoint here + + trainer = PPOTrainer(strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + ptx_coef=0, + max_epochs=args.max_epochs, + train_batch_size=args.train_batch_size, + offload_inference_models=args.offload_inference_models, + max_length=512, + do_sample=True, + temperature=1.0, + top_k=50, + use_cache=True, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + callbacks=[performance_evaluator]) + + random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device()) + dataloader = DataLoader(random_prompts, + batch_size=args.experience_batch_size, + shuffle=True, + collate_fn=preprocess_batch) + + trainer.fit(dataloader, + None, + num_episodes=args.num_episodes, + max_timesteps=args.max_timesteps, + update_timesteps=args.update_timesteps) + + print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='125m') + parser.add_argument('--critic_model', default='125m') + parser.add_argument('--strategy', + choices=[ + 'colossalai_gemini', + 'colossalai_gemini_reshard', + 'colossalai_gemini_cpu', + 'tp_zero2', + 'tp_zero2_cpu', + ], + default='colossalai_gemini_reshard') + parser.add_argument('--tp_size', type=int, default=1) + parser.add_argument('--num_episodes', type=int, default=3) + parser.add_argument('--max_timesteps', type=int, default=1) + parser.add_argument('--update_timesteps', type=int, default=1) + parser.add_argument('--max_epochs', type=int, default=1) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0) + parser.add_argument('--cuda_mem_frac', type=float, default=1.0) + parser.add_argument('--offload_inference_models', action='store_true', default=False) + parser.add_argument('--use_kernels', + action='store_true', + default=False, + help='This uses xformers kernels, which can save memory and accelerate training.') + parser.add_argument('--grad_checkpoint', + default=False, + action='store_true', + help='This uses gradient checkpointing, which can save memory and slow down training.') + args = parser.parse_args() + main(args) diff --git a/applications/Chat/benchmarks/bloom_memory.py b/applications/Chat/benchmarks/bloom_memory.py new file mode 100644 index 000000000000..7a4c176ba5d6 --- /dev/null +++ b/applications/Chat/benchmarks/bloom_memory.py @@ -0,0 +1,143 @@ +import argparse +import os +import resource +from contextlib import contextmanager +from copy import deepcopy + +import psutil +import torch +import torch.distributed as dist +import torch.nn as nn +from coati.models.base import RewardModel +from coati.models.bloom import BLOOMActor, BLOOMCritic +from coati.trainer import PPOTrainer +from coati.trainer.callbacks import PerformanceEvaluator +from coati.trainer.strategies import ColossalAIStrategy, Strategy, TPZeroStrategy +from 
torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import AutoTokenizer +from transformers.modeling_utils import no_init_weights +from transformers.models.bloom.configuration_bloom import BloomConfig + +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam + + +def get_model_numel(model: nn.Module, strategy: Strategy) -> int: + numel = sum(p.numel() for p in model.parameters()) + if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init: + numel *= dist.get_world_size() + return numel + + +def get_max_memory() -> int: + return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + + +def get_memory(): + return psutil.Process(os.getpid()).memory_full_info().rss + + +@contextmanager +def low_precision_init(target_dtype: torch.dtype = torch.float16): + dtype = torch.get_default_dtype() + try: + torch.set_default_dtype(target_dtype) + yield + finally: + torch.set_default_dtype(dtype) + + +def format_model_numel(model_dict: dict) -> str: + B = 1024**3 + M = 1024**2 + K = 1024 + outputs = '' + for name, numel in model_dict.items(): + outputs += f'{name}: ' + if numel >= B: + outputs += f'{numel / B:.2f} B\n' + elif numel >= M: + outputs += f'{numel / M:.2f} M\n' + elif numel >= K: + outputs += f'{numel / K:.2f} K\n' + else: + outputs += f'{numel}\n' + return outputs + + +def get_gpt_config(model_name: str) -> BloomConfig: + model_map = { + '350m': BloomConfig(hidden_size=1024, n_layer=24, n_head=16), + '560m': BloomConfig.from_pretrained('bigscience/bloom-560m'), + '1.1b': BloomConfig.from_pretrained('bigscience/bloom-1b1'), + '1.7b': BloomConfig.from_pretrained('bigscience/bloom-1b7'), + '3b': BloomConfig.from_pretrained('bigscience/bloom-3b'), + '7b': BloomConfig.from_pretrained('bigscience/bloom-7b1'), + '66b': BloomConfig(hidden_size=9216, n_layer=64, n_head=72), + '175b': BloomConfig(hidden_size=12288, n_layer=96, n_head=128), + } + try: + return model_map[model_name] + except KeyError: + raise ValueError(f'Unknown model "{model_name}"') + + +def main(args): + if args.strategy == 'gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) + elif args.strategy == 'gemini_cpu': + strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5) + elif args.strategy == 'gemini_reshard': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda_reshard', initial_scale=2**5) + elif args.strategy == 'tp_zero2': + strategy = TPZeroStrategy(args.tp_size, zero_stage=2, initial_scale=2**5) + elif args.strategy == 'tp_zero2_cpu': + strategy = TPZeroStrategy(args.tp_size, zero_stage=2, initial_scale=2**5, cpu_offload=True) + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + coordinator = DistCoordinator() + model_config = get_gpt_config(args.model) + with strategy.model_init_context(), no_init_weights(), low_precision_init(): + actor = BLOOMActor(config=model_config, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + actor.model.tie_weights() + + actor_numel = get_model_numel(actor, strategy) + coordinator.print_on_master(format_model_numel({'Actor': actor_numel})) + coordinator.print_on_master(f'Mem after lazy init: {get_memory()/1024**3:.2f} GB') + with low_precision_init(): + if args.init_optim: + actor_optim = HybridAdam(actor.parameters(), lr=5e-6) + (actor, actor_optim) = strategy.prepare((actor, actor_optim)) + else: + actor = strategy.prepare(actor) + + coordinator.print_on_master(f'Mem: 
{get_memory()/1024**3:.2f} GB') + coordinator.print_on_master(f'Peak mem: {get_max_memory()/1024**2:.2f} GB') + coordinator.print_on_master(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-m', '--model', default='350m') + parser.add_argument('-s', + '--strategy', + choices=[ + 'gemini', + 'gemini_reshard', + 'gemini_cpu', + 'tp_zero2', + 'tp_zero2_cpu', + ], + default='gemini_reshard') + parser.add_argument('-t', '--tp_size', type=int, default=1) + parser.add_argument('-l', '--lora_rank', type=int, default=0) + parser.add_argument('-g', + '--grad_checkpoint', + default=False, + action='store_true', + help='This uses gradient checkpointing, which can save memory and slow down training.') + parser.add_argument('-o', '--init_optim', action='store_true', default=False) + args = parser.parse_args() + main(args) diff --git a/applications/Chat/coati/experience_maker/naive.py b/applications/Chat/coati/experience_maker/naive.py index e5bb029e63d0..f3041bcf4c61 100644 --- a/applications/Chat/coati/experience_maker/naive.py +++ b/applications/Chat/coati/experience_maker/naive.py @@ -1,6 +1,11 @@ import torch +import torch.nn as nn from coati.models.generation import generate_with_actor from coati.models.utils import calc_action_log_probs, compute_reward, normalize +from torch.nn import Module + +from coati.models.base import Actor +from colossalai.utils import get_current_device from .base import Experience, ExperienceMaker @@ -10,6 +15,18 @@ class NaiveExperienceMaker(ExperienceMaker): Naive experience maker. """ + def __init__(self, + actor: Actor, + critic: Module, + reward_model: Module, + initial_model: Actor, + kl_coef: float = 0.1, + offload: bool = False, + is_colossalai_strategy: bool = False) -> None: + super().__init__(actor, critic, reward_model, initial_model, kl_coef) + self.offload = offload + self.is_colossalai_strategy = is_colossalai_strategy + @torch.no_grad() def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience: self.actor.eval() @@ -17,18 +34,36 @@ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experie self.initial_model.eval() self.reward_model.eval() + self.actor.module.inference_mode(True) + if not self.is_colossalai_strategy: + self.actor.to(get_current_device()) sequences, attention_mask, action_mask = generate_with_actor(self.actor, input_ids, return_action_mask=True, **generate_kwargs) + self.actor.module.inference_mode(False) num_actions = action_mask.size(1) actor_output = self.actor(sequences, attention_mask) action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions) + if self.offload: + self.actor.to('cpu') + if not self.is_colossalai_strategy: + self.initial_model.to(get_current_device()) base_model_output = self.initial_model(sequences, attention_mask) base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions) + if self.offload: + self.initial_model.to('cpu') + if not self.is_colossalai_strategy: + self.critic.to(get_current_device()) value = self.critic(sequences, action_mask, attention_mask) + if self.offload: + self.critic.to('cpu') + if not self.is_colossalai_strategy: + self.reward_model.to(get_current_device()) r = self.reward_model(sequences, attention_mask) + if self.offload: + self.reward_model.to('cpu') reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask) advantage = 
reward - value diff --git a/applications/Chat/coati/models/bloom/bloom_actor.py b/applications/Chat/coati/models/bloom/bloom_actor.py index d7577f096493..2f3f674600fa 100644 --- a/applications/Chat/coati/models/bloom/bloom_actor.py +++ b/applications/Chat/coati/models/bloom/bloom_actor.py @@ -32,4 +32,8 @@ def __init__(self, model = BloomForCausalLM(BloomConfig()) if checkpoint: model.gradient_checkpointing_enable() + model.lm_head.lora_ignore = True super().__init__(model, lora_rank, lora_train_bias) + + def inference_mode(self, enable: bool = True): + self.model.lm_head.gather_output = not enable diff --git a/applications/Chat/coati/models/bloom/bloom_rm.py b/applications/Chat/coati/models/bloom/bloom_rm.py index 22cfab441abb..2293ad7ba22b 100644 --- a/applications/Chat/coati/models/bloom/bloom_rm.py +++ b/applications/Chat/coati/models/bloom/bloom_rm.py @@ -23,7 +23,8 @@ def __init__(self, config: Optional[BloomConfig] = None, checkpoint: bool = False, lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + lora_train_bias: str = 'none', + freeze_exclude: list = []) -> None: if pretrained is not None: model = BloomModel.from_pretrained(pretrained) elif config is not None: @@ -34,4 +35,8 @@ def __init__(self, model.gradient_checkpointing_enable() value_head = nn.Linear(model.config.hidden_size, 1) value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1)) + if len(freeze_exclude) > 0: + for i, layer in enumerate(model.h): + if i not in freeze_exclude: + layer.requires_grad_(False) super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/bloom/triton_attention_forward.py b/applications/Chat/coati/models/bloom/triton_attention_forward.py new file mode 100644 index 000000000000..abb54c7cb503 --- /dev/null +++ b/applications/Chat/coati/models/bloom/triton_attention_forward.py @@ -0,0 +1,106 @@ +from typing import Optional, Tuple + +import torch +from torch.nn import functional as F +from transformers.models.bloom.configuration_bloom import BloomConfig +from transformers.models.bloom.modeling_bloom import BloomAttention + +from colossalai.kernel.triton.ops import compute_attention_for_bloom + + +def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: + """ + Dropout add function + + Args: + x (`torch.tensor`, *required*): + input tensor + residual (`torch.tensor`, *required*): + esidual tensor + prob (`float`, *required*): + dropout probability + training (`bool`, *required*): + training mode + """ + out = F.dropout(x, p=prob, training=training) + out = residual + out + return out + + +class TritonBloomAttention(BloomAttention): + + def __init__(self, config: BloomConfig): + super(TritonBloomAttention, self).__init__(config) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + q_length = query_layer.shape[1] + batch_size = query_layer.shape[0] + num_heads = query_layer.shape[2] + + query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + 
key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, head_dim, kv_length] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=2) + value_layer = torch.cat((past_value, value_layer), dim=1) + + _, _, kv_length = key_layer.shape + alibi = alibi.view(batch_size, num_heads, q_length, -1) + + context_layer = compute_attention_for_bloom( + q=query_layer.view(batch_size, self.num_heads, q_length, self.head_dim), + k=key_layer.view(batch_size, self.num_heads, self.head_dim, kv_length), + v=value_layer.view(batch_size, self.num_heads, kv_length, self.head_dim), + alibi=alibi, + beta=self.beta, + scale=self.inv_norm_factor, + attention_mask=attention_mask, + drop_out=self.hidden_dropout, + head_mask=head_mask, + layer_past=layer_past, + use_cache=True, + ) + + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices):int((i + 1) * slices)], + self.dense.weight[:, int(i * slices):int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + + outputs = (output_tensor, present) + if output_attentions: + # the fused kernel does not materialize attention probabilities, so return a placeholder instead + outputs += (None,) + + return outputs diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py index 0156e2284e52..57b2f75c8970 100644 --- a/applications/Chat/coati/models/generation.py +++ b/applications/Chat/coati/models/generation.py @@ -5,6 +5,8 @@ import torch.nn as nn import torch.nn.functional as F +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc try: from transformers.generation_logits_process import ( @@ -38,6 +40,19 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: return unfinished_sequences.max() == 0 + +def gather_logits(logits: torch.Tensor): + if gpc.get_world_size(ParallelMode.PARALLEL_1D) <= 1: + return logits + # gather logits + logits = logits.contiguous() + rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + tensor_list = [torch.zeros_like(logits) for _ in range(world_size)] + tensor_list[rank] = logits + dist.all_gather(tensor_list, logits, group=gpc.get_group(ParallelMode.PARALLEL_1D)) + return torch.cat(tensor_list, dim=-1).contiguous() + + def sample(model: nn.Module, input_ids: torch.Tensor, max_length: int, @@ -62,6 +77,7 @@ def sample(model: nn.Module, outputs = model(**model_inputs) next_token_logits = outputs['logits'][:, -1, :] + next_token_logits = gather_logits(next_token_logits) # pre-process distribution next_token_logits = logits_processor(input_ids, next_token_logits) # sample @@ -148,12 +164,12 @@ def generate(model: nn.Module, @torch.no_grad() -def generate_with_actor(actor_model: nn.Module, - input_ids: torch.Tensor, 
- return_action_mask: bool = True, - **kwargs - ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], - Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: +def generate_with_actor( + actor_model: nn.Module, + input_ids: torch.Tensor, + return_action_mask: bool = True, + **kwargs +) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: """Generate token sequence with actor model. Refer to `generate` for more details. """ # generate sequences diff --git a/applications/Chat/coati/models/llama/llama_rm.py b/applications/Chat/coati/models/llama/llama_rm.py index f936019d62d2..22343f4e638b 100644 --- a/applications/Chat/coati/models/llama/llama_rm.py +++ b/applications/Chat/coati/models/llama/llama_rm.py @@ -23,7 +23,8 @@ def __init__(self, config: Optional[LlamaConfig] = None, checkpoint: bool = False, lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + lora_train_bias: str = 'none', + freeze_exclude: list = []) -> None: if pretrained is not None: model = LlamaModel.from_pretrained(pretrained) @@ -36,5 +37,8 @@ def __init__(self, model.gradient_checkpointing_enable() value_head = nn.Linear(model.config.hidden_size, 1) value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1)) - + if len(freeze_exclude) > 0: + for i, layer in enumerate(model.layers): + if i not in freeze_exclude: + layer.requires_grad_(False) super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py index 2a9059e6901e..4f1c81a5fc9a 100644 --- a/applications/Chat/coati/models/lora.py +++ b/applications/Chat/coati/models/lora.py @@ -106,7 +106,7 @@ def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None: for name, child in module.named_children(): - if isinstance(child, nn.Linear): + if isinstance(child, nn.Linear) and not getattr(child, 'lora_ignore', False): setattr(module, name, lora_linear_wrapper(child, lora_rank)) else: convert_to_lora_recursively(child, lora_rank) diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py index b9f15f894a1f..6fbd487f5071 100644 --- a/applications/Chat/coati/models/utils.py +++ b/applications/Chat/coati/models/utils.py @@ -2,9 +2,55 @@ import loralib as lora import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_1d._utils import ( + gather_forward_split_backward, + reduce_grad, + reduce_input, + split_forward_gather_backward, +) + + +def _reduce_max(input_, parallel_mode): + # skip if only one rank involved + if gpc.get_world_size(parallel_mode) == 1: + return input_ + group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) + dist.all_reduce(input_, op=dist.ReduceOp.MAX, group=group) + + return input_ + + +class _ReduceInput(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. 
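+ Note: despite the generic class name, the reduction op here is MAX (see `_reduce_max`); it is exposed through `reduce_input_max` below.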
+ """ + + @staticmethod + def symbolic(graph, input_): + return _reduce_max(input_) + + @staticmethod + def forward(ctx, input_, parallel_mode): + return _reduce_max(input_, parallel_mode) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +def reduce_input_max(input_, parallel_mode): + return _ReduceInput.apply(input_, parallel_mode) + def compute_approx_kl(log_probs: torch.Tensor, log_probs_base: torch.Tensor, @@ -46,10 +92,37 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T return log_probs_labels.squeeze(-1) -def calc_action_log_probs(output: torch.Tensor, - sequences: torch.LongTensor, - num_actions: int - ) -> torch.Tensor: +def dist_log_probs_from_logits(parallel_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + logits_max = torch.max(parallel_logits, dim=-1)[0] + logits_max = reduce_input_max(logits_max, ParallelMode.PARALLEL_1D) + + # minus the max to avoid the result of sum of exp is too large and the log is nan + parallel_logits = parallel_logits - logits_max.unsqueeze(-1) + + exp_sum = torch.sum(torch.exp(parallel_logits), dim=-1) + exp_sum = reduce_input(exp_sum, ParallelMode.PARALLEL_1D) + parallel_log_probs = parallel_logits - torch.log(exp_sum.unsqueeze(-1)) + + # create mask + rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + partition_vocab_size = parallel_logits.shape[-1] + global_vocab_size = partition_vocab_size * world_size + delta = (global_vocab_size + world_size - 1) // world_size + down_threshold = rank * delta + up_threshold = down_threshold + delta + mask = (labels < down_threshold) | (labels >= up_threshold) + masked_labels = labels.clone() - down_threshold + masked_labels[mask] = 0 + + parallel_log_prob_labels = parallel_log_probs.gather(dim=-1, index=masked_labels.unsqueeze(-1)) + parallel_log_prob_labels = parallel_log_prob_labels.squeeze(-1) + parallel_log_prob_labels[mask] = 0 + parallel_log_prob_labels = reduce_input(parallel_log_prob_labels, ParallelMode.PARALLEL_1D) + return parallel_log_prob_labels + + +def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor: """Calculate action log probs. Args: @@ -61,7 +134,10 @@ def calc_action_log_probs(output: torch.Tensor, torch.Tensor: Action log probs. 
""" logits = output['logits'] - log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1: + log_probs = dist_log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + else: + log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) return log_probs[:, -num_actions:] diff --git a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py index 925455444597..23b272ed0c08 100644 --- a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py +++ b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py @@ -5,12 +5,15 @@ import torch.distributed as dist from coati.experience_maker import Experience +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + from .base import Callback def get_world_size() -> int: if dist.is_initialized(): - return dist.get_world_size() + return gpc.get_world_size(ParallelMode.DATA) return 1 diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index e2e44e62533e..9748ad0199ae 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -17,7 +17,7 @@ from .base import Trainer from .callbacks import Callback -from .strategies import Strategy +from .strategies import ColossalAIStrategy, Strategy from .utils import is_rank_0, to_device @@ -71,7 +71,13 @@ def __init__(self, offload_inference_models: bool = True, callbacks: List[Callback] = [], **generate_kwargs) -> None: - experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef) + experience_maker = NaiveExperienceMaker(actor, + critic, + reward_model, + initial_model, + kl_coef, + offload=offload_inference_models, + is_colossalai_strategy=type(strategy) is ColossalAIStrategy) replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) super().__init__(strategy, max_epochs, dataloader_pin_memory, callbacks, **generate_kwargs) @@ -145,17 +151,10 @@ def fit(self, time += 1 prompts = next(iter(self.prompt_dataloader)) self._on_make_experience_start() - if self.offload_inference_models: - # TODO(ver217): this may be controlled by strategy if they are prepared by strategy - self.experience_maker.initial_model.to(self.device) - self.experience_maker.reward_model.to(self.device) experience = self._make_experience(prompts) self._on_make_experience_end(experience) self.replay_buffer.append(experience) if time % update_timesteps == 0: - if self.offload_inference_models: - self.experience_maker.initial_model.to('cpu') - self.experience_maker.reward_model.to('cpu') self._learn() self.replay_buffer.clear() self._on_episode_end(episode) @@ -166,26 +165,38 @@ def training_step(self, experience: Experience) -> Dict[str, float]: self.critic.train() # policy loss num_actions = experience.action_mask.size(1) - actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask) - action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions) - actor_loss = self.actor_loss_fn(action_log_probs, - experience.action_log_probs, - experience.advantages, - action_mask=experience.action_mask) - + if type(self.strategy) is not ColossalAIStrategy: + self.actor.to(self.device) # ptx loss if self.ptx_coef != 0: batch 
= next(iter(self.pretrain_dataloader)) batch = to_device(batch, self.device) - ptx_log_probs = self.actor(batch['input_ids'], - attention_mask=batch['attention_mask'])['logits'] - ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels']) - actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef) + ptx_log_probs = self.actor(batch['input_ids'], attention_mask=batch['attention_mask'])['logits'] + ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels']) * self.ptx_coef + if type(self.strategy) is not ColossalAIStrategy: + # gemini does not support grad accumulation + self.strategy.backward(ptx_loss, self.actor, self.actor_optim) + actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask) + action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions) + actor_loss: torch.Tensor = self.actor_loss_fn(action_log_probs, + experience.action_log_probs, + experience.advantages, + action_mask=experience.action_mask) + if self.ptx_coef != 0: + actor_loss = actor_loss * (1 - self.ptx_coef) + if type(self.strategy) is ColossalAIStrategy: + # gemini does not support grad accumulation + actor_loss = actor_loss + ptx_loss self.strategy.backward(actor_loss, self.actor, self.actor_optim) self.strategy.optimizer_step(self.actor_optim) self.actor_optim.zero_grad() + if self.offload_inference_models: + self.actor.to('cpu') + + if type(self.strategy) is not ColossalAIStrategy: + self.critic.to(self.device) # value loss values = self.critic(experience.sequences, action_mask=experience.action_mask, @@ -199,6 +210,9 @@ def training_step(self, experience: Experience) -> Dict[str, float]: self.strategy.optimizer_step(self.critic_optim) self.critic_optim.zero_grad() + if self.offload_inference_models: + self.critic.to('cpu') + return {'reward': experience.reward.mean().item()} diff --git a/applications/Chat/coati/trainer/rm.py b/applications/Chat/coati/trainer/rm.py index cdae5108ab00..59f18df80305 100644 --- a/applications/Chat/coati/trainer/rm.py +++ b/applications/Chat/coati/trainer/rm.py @@ -42,6 +42,7 @@ def __init__( valid_dataloader: DataLoader, eval_dataloader: DataLoader, max_epochs: int = 1, + scheduler: Optional[lr_scheduler._LRScheduler] = None, callbacks: List[Callback] = [], ) -> None: super().__init__(strategy, max_epochs, callbacks=callbacks) @@ -53,7 +54,7 @@ def __init__( self.model = model self.loss_fn = loss_fn self.optimizer = optim - self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__() // 100) + self.scheduler = scheduler def eval_acc(self, dataloader): dist = 0 @@ -103,7 +104,8 @@ def fit(self): self.optimizer.zero_grad() cnt += 1 if cnt == 100: - self.scheduler.step() + if self.scheduler: + self.scheduler.step() dist, acc = self.eval_acc(self.valid_dataloader) cnt = 0 if is_rank_0(): diff --git a/applications/Chat/coati/trainer/strategies/__init__.py b/applications/Chat/coati/trainer/strategies/__init__.py index f258c9b8a873..d01bc8da5fd9 100644 --- a/applications/Chat/coati/trainer/strategies/__init__.py +++ b/applications/Chat/coati/trainer/strategies/__init__.py @@ -2,5 +2,7 @@ from .colossalai import ColossalAIStrategy from .ddp import DDPStrategy from .naive import NaiveStrategy +from .tp_zero import TPZeroStrategy +from .zero_dp import ZeroDPStrategy -__all__ = ['Strategy', 'NaiveStrategy', 'DDPStrategy', 'ColossalAIStrategy'] +__all__ = ['Strategy', 'NaiveStrategy', 'DDPStrategy', 'ColossalAIStrategy', 'TPZeroStrategy', 'ZeroDPStrategy'] diff --git 
a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py index 06f81f21ab26..31dfc9b7c169 100644 --- a/applications/Chat/coati/trainer/strategies/base.py +++ b/applications/Chat/coati/trainer/strategies/base.py @@ -2,8 +2,10 @@ from contextlib import nullcontext from typing import Any, List, Optional, Tuple, Union +import loralib as lora import torch import torch.nn as nn +from coati.models.lora import LoRAModule from coati.replay_buffer import ReplayBuffer from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -73,11 +75,14 @@ def prepare( if isinstance(arg, tuple): assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"' model, optimizer = arg + is_lora = isinstance(model, LoRAModule) and model.lora_rank > 0 model = self.setup_model(model) + if is_lora: + lora.mark_only_lora_as_trainable(model) optimizer = self.setup_optimizer(optimizer, model) rets.append((model, optimizer)) elif isinstance(arg, nn.Module): - rets.append(self.setup_model(model)) + rets.append(self.setup_model(arg)) else: raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}') diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index cfdab2806a25..9769778c9f10 100644 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -9,6 +9,7 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase import colossalai +from colossalai.lazy import LazyInitContext from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import CPUAdam, HybridAdam from colossalai.tensor import ProcessGroup, ShardSpec @@ -79,7 +80,7 @@ def __init__( max_norm: float = 0.0, norm_type: float = 2.0) -> None: super().__init__(seed) - assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"' + assert placement_policy in ('cpu', 'cuda', 'cuda_reshard'), f'Unsupported placement policy "{placement_policy}"' assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"' self.stage = stage # TODO(ver217): support shard_init when using from_pretrained() @@ -122,13 +123,7 @@ def setup_distributed(self) -> None: def model_init_context(self): if self.stage == 3: - world_size = dist.get_world_size() - shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None - default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None - return ColoInitContext(device=get_current_device(), - dtype=torch.half, - default_pg=shard_pg, - default_dist_spec=default_dist_spec) + return LazyInitContext() return super().model_init_context() def setup_model(self, model: nn.Module) -> nn.Module: diff --git a/applications/Chat/coati/trainer/strategies/tp/__init__.py b/applications/Chat/coati/trainer/strategies/tp/__init__.py new file mode 100644 index 000000000000..c84298152be6 --- /dev/null +++ b/applications/Chat/coati/trainer/strategies/tp/__init__.py @@ -0,0 +1,28 @@ +from torch.nn import Module + +from colossalai.lazy import LazyTensor + +from .policy import POLICY_MAP + + +def _replace_recursively(model: Module) -> None: + recurse: bool = True + if type(model) in POLICY_MAP: + policy = POLICY_MAP[type(model)]() + recurse = policy.replace(model) + if recurse: + for child in model.children(): + _replace_recursively(child) + + +def tp_parallelize(model: Module) -> None: + 
_replace_recursively(model) + for p in model.parameters(): + if isinstance(p, LazyTensor): + p.materialize() + for buf in model.buffers(): + if isinstance(buf, LazyTensor): + buf.materialize() + + +__all__ = ["tp_parallelize"] diff --git a/applications/Chat/coati/trainer/strategies/tp/parallel.py b/applications/Chat/coati/trainer/strategies/tp/parallel.py new file mode 100644 index 000000000000..10cae44973ba --- /dev/null +++ b/applications/Chat/coati/trainer/strategies/tp/parallel.py @@ -0,0 +1,138 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from coati.models.lora import LoraLinear +from torch import Tensor + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_1d._utils import ( + gather_forward_split_backward, + reduce_grad, + reduce_input, + split_forward_gather_backward, +) +from colossalai.nn.layer.utils import divide + + +def linear_1d_col_fn(self: nn.Linear, input_: Tensor, gather_output: bool = False) -> Tensor: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # TODO(ver217): this relies on GPC + input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + # Matrix multiply. + + output_parallel = F.linear(input_parallel, self.weight, self.bias) + if gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + else: + output = output_parallel + return output + + +def linear_1d_row_fn(self: nn.Linear, input_: Tensor, parallel_input: bool = True) -> Tensor: + # Set up backprop all-reduce. + if parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) + input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) + + output_parallel = F.linear(input_, self.weight) + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + if self.bias is not None: + output = output + self.bias + return output + + +def lora_linear_1d_col_fn(self: LoraLinear, input_: Tensor, gather_output: bool = False) -> Tensor: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # TODO(ver217): this relies on GPC + input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + # Matrix multiply. + + output_parallel = F.linear(input_parallel, self.weight, self.bias) + if self.r > 0: + lora_output_parallel = (self.lora_dropout(input_parallel) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling + output_parallel = output_parallel + lora_output_parallel + + if gather_output: + # All-gather across the partitions. 
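+ # gather_forward_split_backward concatenates each rank's partial output along the last dim in the forward pass and splits the gradient back in the backward pass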
+ output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + else: + output = output_parallel + return output + + +def lora_linear_1d_row_fn(self: LoraLinear, input_: Tensor, parallel_input: bool = True) -> Tensor: + # Set up backprop all-reduce. + if parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) + input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) + + output_parallel = F.linear(input_, self.weight) + if self.r > 0: + lora_output_parallel = (self.lora_dropout(input_) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling + output_parallel = output_parallel + lora_output_parallel + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + if self.bias is not None: + output = output + self.bias + return output + + +def vocab_parallel_embedding_fn(self: nn.Embedding, input_: Tensor) -> Tensor: + tp_size = gpc.tensor_parallel_size + tp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + vocab_size_per_partition = divide(self.num_embeddings, tp_size) + vocab_start_index = tp_rank * vocab_size_per_partition + vocab_end_index = vocab_start_index + vocab_size_per_partition + # Build the mask. + input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index) + # Mask the input. + masked_input = input_.clone() - vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, self.max_norm, self.norm_type, + self.scale_grad_by_freq, self.sparse) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + return output + + +def vocab_parallel_lm_head_fn(self: nn.Linear, input_: Tensor) -> Tensor: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in VocabParallelLMHead1D forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + # Matrix multiply. + output_parallel = F.linear(input_parallel, self.weight, self.bias) + # gather_output = getattr(self, 'gather_output', True) + gather_output = False + if gather_output: + # All-gather across the partitions. 
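+ # NOTE: gather_output is hard-coded to False above, so callers receive vocab-partitioned logits; the full vocabulary is recovered downstream (gather_logits in generation.py or dist_log_probs_from_logits in models/utils.py)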
+ output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + else: + output = output_parallel + return output diff --git a/applications/Chat/coati/trainer/strategies/tp/policy.py b/applications/Chat/coati/trainer/strategies/tp/policy.py new file mode 100644 index 000000000000..a4c4c7ac72e5 --- /dev/null +++ b/applications/Chat/coati/trainer/strategies/tp/policy.py @@ -0,0 +1,213 @@ +from functools import partial +from types import MethodType +from typing import Dict, Type + +import torch.nn as nn +from coati.models.bloom.triton_attention_forward import TritonBloomAttention +from coati.models.lora import LoraLinear +from torch.nn import Module +from torch.nn import functional as F +from transformers.models.bloom.configuration_bloom import BloomConfig +from transformers.models.bloom.modeling_bloom import BloomAttention, BloomForCausalLM, BloomMLP +from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.lazy import LazyTensor +from colossalai.nn.layer.utils import divide + +from .parallel import ( + linear_1d_col_fn, + linear_1d_row_fn, + lora_linear_1d_col_fn, + lora_linear_1d_row_fn, + vocab_parallel_embedding_fn, + vocab_parallel_lm_head_fn, +) + + +class Policy: + + def replace(self, module: nn.Module) -> bool: + """Modfiy the module in place + + Args: + module (nn.Module): Module to be modified + + Returns: + bool: Whether to recurse into the module's children + """ + pass + + +class Linear1DColPolicy(Policy): + + def __init__(self, gather_output: bool = False) -> None: + super().__init__() + self.gather_output = gather_output + self.tp_size = gpc.tensor_parallel_size + self.tp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + def replace(self, module: nn.Module) -> bool: + assert isinstance(module, (nn.Linear, LoraLinear)) + # shard params + # TODO(ver217): this should be done via DTensor + divide(module.out_features, self.tp_size) + if isinstance(module.weight, LazyTensor): + module.weight.materialize() + module.weight.data = module.weight.chunk(self.tp_size, dim=0)[self.tp_rank].data.clone() + if module.bias is not None: + if isinstance(module.bias, LazyTensor): + module.bias.materialize() + module.bias.data = module.bias.chunk(self.tp_size, dim=0)[self.tp_rank].data.clone() + if isinstance(module, LoraLinear): + if isinstance(module.lora_B, LazyTensor): + module.lora_B.materialize() + module.lora_B.data = module.lora_B.chunk(self.tp_size, dim=0)[self.tp_rank].data.clone() + # replace forward + fwd = lora_linear_1d_col_fn if isinstance(module, LoraLinear) else linear_1d_col_fn + module.forward = MethodType(partial(fwd, gather_output=self.gather_output), module) + return False + + +class Linear1DRowPolicy(Policy): + + def __init__(self, parallel_input: bool = True) -> None: + super().__init__() + self.parallel_input = parallel_input + self.tp_size = gpc.tensor_parallel_size + self.tp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + def replace(self, module: nn.Module) -> bool: + assert isinstance(module, (nn.Linear, LoraLinear)) + divide(module.in_features, gpc.tensor_parallel_size) + if isinstance(module.weight, LazyTensor): + module.weight.materialize() + module.weight.data = module.weight.chunk(self.tp_size, dim=1)[self.tp_rank].data.clone() + if module.bias is not None and isinstance(module.bias, LazyTensor): + module.bias.materialize() + if isinstance(module, LoraLinear): + if 
isinstance(module.lora_A, LazyTensor): + module.lora_A.materialize() + module.lora_A.data = module.lora_A.chunk(self.tp_size, dim=1)[self.tp_rank].data.clone() + fwd = lora_linear_1d_row_fn if isinstance(module, LoraLinear) else linear_1d_row_fn + module.forward = MethodType(partial(fwd, parallel_input=self.parallel_input), module) + return False + + +class Embedding1DPolicy(Policy): + + def __init__(self) -> None: + super().__init__() + self.tp_size = gpc.tensor_parallel_size + self.tp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + def replace(self, module: Module) -> bool: + assert isinstance(module, nn.Embedding) + divide(module.num_embeddings, self.tp_size) + if isinstance(module.weight, LazyTensor): + module.weight.materialize() + module.weight.data = module.weight.chunk(self.tp_size, dim=0)[self.tp_rank].data.clone() + module.forward = MethodType(vocab_parallel_embedding_fn, module) + return False + + +def bloom_attn_fwd(module: BloomAttention, *args, alibi=None, **kwargs): + if alibi is not None: + tp_size = gpc.tensor_parallel_size + tp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + alibi = alibi.chunk(tp_size, dim=0)[tp_rank] + if module.training or True: + return BloomAttention.forward(module, *args, alibi=alibi, **kwargs) + else: + return TritonBloomAttention.forward(module, *args, alibi=alibi, **kwargs) + + +class BloomAttentionPolicy(Policy): + + def replace(self, module: Module) -> bool: + assert isinstance(module, BloomAttention) + assert module.num_heads % gpc.tensor_parallel_size == 0 + assert module.hidden_size % gpc.tensor_parallel_size == 0 + module.num_heads = module.num_heads // gpc.tensor_parallel_size + module.hidden_size = module.hidden_size // gpc.tensor_parallel_size + col_policy = Linear1DColPolicy(gather_output=False) + row_policy = Linear1DRowPolicy(parallel_input=True) + col_policy.replace(module.query_key_value) + row_policy.replace(module.dense) + module.forward = MethodType(bloom_attn_fwd, module) + return False + + +class BloomMLPPolicy(Policy): + + def replace(self, module: Module) -> bool: + assert isinstance(module, BloomMLP) + col_policy = Linear1DColPolicy(gather_output=False) + row_policy = Linear1DRowPolicy(parallel_input=True) + col_policy.replace(module.dense_h_to_4h) + row_policy.replace(module.dense_4h_to_h) + return False + + +class BloomForCausalLMPolicy(Policy): + + def replace(self, module: Module) -> bool: + assert isinstance(module, BloomForCausalLM) + module.lm_head.gather_output = True + module.lm_head.forward = MethodType(vocab_parallel_lm_head_fn, module.lm_head) + return True + + +class OPTAttentionPolicy(Policy): + + def replace(self, module: nn.Module) -> bool: + assert isinstance(module, OPTAttention) + # reset attr + assert module.num_heads % gpc.tensor_parallel_size == 0 + assert module.embed_dim % gpc.tensor_parallel_size == 0 + module.num_heads = module.num_heads // gpc.tensor_parallel_size + module.embed_dim = module.embed_dim // gpc.tensor_parallel_size + col_policy = Linear1DColPolicy(gather_output=False) + row_policy = Linear1DRowPolicy(parallel_input=True) + for layer in (module.k_proj, module.v_proj, module.q_proj): + col_policy.replace(layer) + row_policy.replace(module.out_proj) + return False + + +class OPTDecoderLayerPolicy(Policy): + + def replace(self, module: Module) -> bool: + assert isinstance(module, OPTDecoderLayer) + attn_policy = OPTAttentionPolicy() + attn_policy.replace(module.self_attn) + col_policy = Linear1DColPolicy(gather_output=False) + row_policy = 
Linear1DRowPolicy(parallel_input=True) + col_policy.replace(module.fc1) + row_policy.replace(module.fc2) + return False + + +class OPTDecoderPolicy(Policy): + + def replace(self, module: Module) -> bool: + assert isinstance(module, OPTDecoder) + col_policy = Linear1DColPolicy(gather_output=True) + if module.project_in is not None: + col_policy.replace(module.project_in) + if module.project_out is not None: + col_policy.replace(module.project_out) + decoder_layer_policy = OPTDecoderLayerPolicy() + for decoder_layer in module.layers: + decoder_layer_policy.replace(decoder_layer) + return False + + +POLICY_MAP: Dict[Type[nn.Module], Type[Policy]] = { + OPTDecoder: OPTDecoderPolicy, + BloomForCausalLM: BloomForCausalLMPolicy, + BloomMLP: BloomMLPPolicy, + BloomAttention: BloomAttentionPolicy, + nn.Embedding: Embedding1DPolicy, +} diff --git a/applications/Chat/coati/trainer/strategies/tp_zero.py b/applications/Chat/coati/trainer/strategies/tp_zero.py new file mode 100644 index 000000000000..388e729a1b4f --- /dev/null +++ b/applications/Chat/coati/trainer/strategies/tp_zero.py @@ -0,0 +1,133 @@ +from types import MethodType +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.optim as optim +from coati.models.base import get_base_model +from torch import Tensor +from torch.nn import Module +from torch.optim import Optimizer + +import colossalai +from colossalai.booster.plugin.low_level_zero_plugin import ( + SUPPORTED_PRECISION, + LowLevelZeroModel, + LowLevelZeroOptimizer, +) +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.lazy import LazyInitContext +from colossalai.utils import get_current_device + +from .naive import NaiveStrategy +from .tp import tp_parallelize + + +class TPZeroStrategy(NaiveStrategy): + + def __init__(self, + tp_size: int, + zero_stage: int = 2, + precision: str = 'fp16', + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + reduce_bucket_size_in_m: int = 12, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + cpu_offload: bool = False, + seed: int = 42) -> None: + assert zero_stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training' + assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training' + self.tp_size = tp_size + self.zero_stage = zero_stage + self.precision = precision + self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload) + self.optim_kwargs = dict(initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + max_norm=max_norm, + norm_type=norm_type) + self.seed = seed + super().__init__() + + def setup_distributed(self) -> None: + config = dict(parallel=dict(tensor=dict( + mode='1d', + size=self.tp_size, + ))) + colossalai.launch_from_torch(config, seed=self.seed) + self.zero_optim_config['zero_process_group'] = gpc.get_group(ParallelMode.DATA) + self.zero_optim_config['tp_process_group'] = gpc.get_group(ParallelMode.PARALLEL_1D) + + def model_init_context(self): + return 
LazyInitContext(default_device=get_current_device()) + + def setup_model(self, model: torch.nn.Module) -> torch.nn.Module: + tp_parallelize(model) + model.to('cpu') + model = LowLevelZeroModel(model, self.zero_stage, self.precision) + model.to('cpu') + return model + + def unwrap_model(self, model: LowLevelZeroModel) -> torch.nn.Module: + assert isinstance(model, LowLevelZeroModel) + return model.module + + def setup_optimizer(self, optimizer: Optimizer, model: LowLevelZeroModel) -> Optimizer: + optim_ = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs) + + # Hack to/cuda/cpu of model + + def model_to(m: LowLevelZeroModel, device: str): + for b in m.buffers(): + b.data = b.data.to(device) + optim_.to(device) + for p in m.parameters(): + if p.device != device: + p.data = p.data.to(device) + return m + + model.to = MethodType(model_to, model) + model.cuda = MethodType(lambda m: model_to(m, 'cuda'), model) + model.cpu = MethodType(lambda m: model_to(m, 'cpu'), model) + return optim_ + + def backward(self, loss: Tensor, model: Module, optimizer: Optimizer, **kwargs) -> None: + optimizer.backward(loss) + + def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None: + optimizer.step() + + def save_model(self, model: Module, path: str, only_rank0: bool = True) -> None: + if gpc.get_local_rank(ParallelMode.DATA) == 0: + path = f'{path}.tr{gpc.get_local_rank(ParallelMode.PARALLEL_1D)}' + state_dict = self.unwrap_model(model).state_dict() + torch.save(state_dict, path) + + def load_model(self, model: Module, path: str, map_location: Any = None, strict: bool = True) -> None: + path = f'{path}.tr{gpc.get_local_rank(ParallelMode.PARALLEL_1D)}' + state_dict = torch.load(path, map_location=map_location) + self.unwrap_model(model).load_state_dict(state_dict, strict=strict) + + def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: + path = f'{path}.r{gpc.get_global_rank()}' + super().save_optimizer(optimizer, path, only_rank0) + + def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None: + path = f'{path}.r{gpc.get_global_rank()}' + super().load_optimizer(optimizer, path, map_location) diff --git a/applications/Chat/coati/trainer/strategies/zero_dp.py b/applications/Chat/coati/trainer/strategies/zero_dp.py new file mode 100644 index 000000000000..4204a441ac64 --- /dev/null +++ b/applications/Chat/coati/trainer/strategies/zero_dp.py @@ -0,0 +1,104 @@ +from typing import Any, Optional + +import torch +from torch import Tensor +from torch.nn import Module +from torch.optim import Optimizer + +import colossalai +from colossalai.booster.plugin.low_level_zero_plugin import ( + SUPPORTED_PRECISION, + LowLevelZeroModel, + LowLevelZeroOptimizer, +) +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + +from .naive import NaiveStrategy + + +class ZeroDPStrategy(NaiveStrategy): + + def __init__(self, + zero_size: int, + zero_stage: int = 2, + precision: str = 'fp16', + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + reduce_bucket_size_in_m: int = 12, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + cpu_offload: bool = False, + seed: int = 42) -> None: + assert zero_stage in (1, 2), f'LowLevelZeroPlugin only 
supports stage 1/2 training' + assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training' + self.zero_size = zero_size + self.zero_stage = zero_stage + self.precision = precision + self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload) + self.optim_kwargs = dict(initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + max_norm=max_norm, + norm_type=norm_type) + self.seed = seed + super().__init__() + + def setup_distributed(self) -> None: + # original tp group is now zero group + config = dict(parallel=dict(tensor=dict( + mode='1d', + size=self.zero_size, + ))) + colossalai.launch_from_torch(config, seed=self.seed) + self.zero_optim_config['zero_process_group'] = gpc.get_group(ParallelMode.PARALLEL_1D) + self.zero_optim_config['dp_process_group'] = gpc.get_group(ParallelMode.DATA) + + def setup_model(self, model: torch.nn.Module) -> torch.nn.Module: + model = LowLevelZeroModel(model, self.zero_stage, self.precision) + return model + + def unwrap_model(self, model: LowLevelZeroModel) -> torch.nn.Module: + assert isinstance(model, LowLevelZeroModel) + return model.module + + def setup_optimizer(self, optimizer: Optimizer, model: LowLevelZeroModel) -> Optimizer: + return LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs) + + def backward(self, loss: Tensor, model: Module, optimizer: Optimizer, **kwargs) -> None: + optimizer.backward(loss) + + def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None: + optimizer.step() + + def save_model(self, model: Module, path: str, only_rank0: bool = True) -> None: + if gpc.get_global_rank() == 0: + state_dict = self.unwrap_model(model).state_dict() + torch.save(state_dict, path) + + def load_model(self, model: Module, path: str, map_location: Any = None, strict: bool = True) -> None: + state_dict = torch.load(path, map_location=map_location) + self.unwrap_model(model).load_state_dict(state_dict, strict=strict) + + def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: + if gpc.get_local_rank(ParallelMode.DATA) == 0: + path = f'{path}.r{gpc.get_local_rank(ParallelMode.PARALLEL_1D)}' + super().save_optimizer(optimizer, path, only_rank0) + + def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None: + path = f'{path}.r{gpc.get_local_rank(ParallelMode.PARALLEL_1D)}' + super().load_optimizer(optimizer, path, map_location) diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py index 48b12336fa67..aafe7277bdbd 100644 --- a/applications/Chat/examples/train_reward_model.py +++ b/applications/Chat/examples/train_reward_model.py @@ -14,10 +14,10 @@ from coati.models.opt import OPTRM from coati.models.roberta import RoBERTaRM from coati.trainer import RewardModelTrainer -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, ZeroDPStrategy from coati.utils import prepare_llama_tokenizer_and_embedding from datasets import load_dataset -from torch.optim import Adam +from torch.optim import Adam, Optimizer, lr_scheduler from torch.utils.data 
import DataLoader from torch.utils.data.distributed import DistributedSampler from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer @@ -36,23 +36,19 @@ def train(args): strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') elif args.strategy == 'colossalai_zero2': strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + elif args.strategy == 'zero_dp': + strategy = ZeroDPStrategy(args.zero_size) else: raise ValueError(f'Unsupported strategy "{args.strategy}"') # configure model with strategy.model_init_context(): if args.model == 'bloom': - model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) - elif args.model == 'opt': - model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) - elif args.model == 'gpt2': - model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) - elif args.model == 'deberta': - model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank, + freeze_exclude=args.freeze_exclude).cuda() elif args.model == 'llama': - model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) - elif args.model == 'roberta': - model = RoBERTaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank, + freeze_exclude=args.freeze_exclude).cuda() else: raise ValueError(f'Unsupported model "{args.model}"') @@ -63,18 +59,10 @@ def train(args): model = model.to(torch.float16) # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - elif args.model == 'bloom': + if args.model == 'bloom': tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - elif args.model == 'deberta': - tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large') elif args.model == 'llama': tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) - elif args.model == 'roberta': - tokenizer = RobertaTokenizer.from_pretrained("roberta-base") else: raise ValueError(f'Unsupported model "{args.model}"') max_len = args.max_len @@ -89,7 +77,6 @@ def train(args): optim = HybridAdam(model.parameters(), lr=5e-6) else: optim = Adam(model.parameters(), lr=5e-6) - # configure loss function if args.loss_fn == 'log_sig': loss_fn = LogSigLoss() @@ -165,6 +152,7 @@ def train(args): batch_size=args.batch_size, pin_memory=True) + scheduler = lr_scheduler.CosineAnnealingLR(optim, len(train_dataloader) // 100) (model, optim) = strategy.prepare((model, optim)) trainer = RewardModelTrainer(model=model, strategy=strategy, @@ -173,6 +161,7 @@ def train(args): train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader, + scheduler=scheduler, max_epochs=args.max_epochs) trainer.fit() @@ -187,10 +176,13 @@ def train(args): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='colossalai_zero2') - parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta', 'llama', 'roberta'], default='bloom') + parser.add_argument('-s', + '--strategy', + choices=['naive', 
'ddp', 'colossalai_gemini', 'colossalai_zero2', 'zero_dp'], + default='zero_dp'), + parser.add_argument('-z', '--zero_size', type=int, default=1) + parser.add_argument('-f', '--freeze_exclude', type=int, default=[], nargs='*') + parser.add_argument('--model', choices=['bloom', 'llama'], default='bloom') parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--model_path', type=str, default=None) parser.add_argument('--need_optim_ckpt', type=bool, default=False) diff --git a/applications/Chat/requirements.txt b/applications/Chat/requirements.txt index af7ff67861eb..444206405900 100644 --- a/applications/Chat/requirements.txt +++ b/applications/Chat/requirements.txt @@ -2,8 +2,6 @@ transformers>=4.20.1 tqdm datasets loralib -colossalai>=0.2.4 -torch<2.0.0, >=1.12.1 langchain tokenizers fastapi diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 94d722080367..6609a15577ad 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -14,6 +14,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.utils import get_current_device from colossalai.zero import zero_model_wrapper, zero_optim_wrapper +from colossalai.zero.low_level import LowLevelZeroOptimizer as ZeroOptimizer from .dp_plugin_base import DPPluginBase from .torch_ddp_plugin import TorchDDPCheckpointIO @@ -61,7 +62,7 @@ def __init__(self, module: nn.Module, stage: int, precision: str) -> None: module = zero_model_wrapper(module, zero_stage=stage) if self.dtype is not None: module = module.to(self.dtype) - module = module.to(get_current_device()) + # module = module.to(get_current_device()) self.module = module self.convert_fn = None if self.dtype is not None: @@ -82,13 +83,16 @@ def __init__(self, zero_optim_config: dict, optim_kwargs: dict, verbose: bool = False) -> None: - optimizer = zero_optim_wrapper(module, - optimizer, - optim_config=zero_optim_config, - **optim_kwargs, - verbose=verbose) + optimizer: ZeroOptimizer = zero_optim_wrapper(module, + optimizer, + optim_config=zero_optim_config, + **optim_kwargs, + verbose=verbose) super().__init__(optimizer) + def to(self, device): + self.optim.to(device) + def backward(self, loss: Tensor, *args, **kwargs): self.optim.backward(loss) @@ -208,10 +212,7 @@ def configure( if optimizer is not None and \ not isinstance(optimizer, OptimizerWrapper): - optimizer = LowLevelZeroOptimizer(model.unwrap(), - optimizer, - self.zero_optim_config, - self.optim_kwargs, + optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs, self.verbose) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/cluster/__init__.py b/colossalai/cluster/__init__.py index 2fbdfd3cc999..44f571ca2501 100644 --- a/colossalai/cluster/__init__.py +++ b/colossalai/cluster/__init__.py @@ -1,5 +1,6 @@ from .device_mesh_manager import DeviceMeshManager from .dist_coordinator import DistCoordinator from .process_group_manager import ProcessGroupManager +from .process_group_mesh import ProcessGroupMesh -__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager'] +__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager', 'ProcessGroupMesh'] diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py new file mode 100644 index 000000000000..1dfd261d5d01 --- /dev/null +++ 
b/colossalai/cluster/process_group_mesh.py @@ -0,0 +1,203 @@ +import itertools +from functools import reduce +from operator import mul +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +def prod(nums: List[int]) -> int: + """Product of a list of numbers. + + Args: + nums (List[int]): A list of numbers. + + Returns: + int: The product of the numbers. + """ + return reduce(mul, nums) + + +class ProcessGroupMesh: + """A helper class to manage the process group mesh. It only describes how to organize process groups, and it's decoupled with parallel method. + It just initialize process groups and cache them. The parallel method should manage them and use them to do the parallel computation. + + We use a ND-tuple to represent the process group mesh. And a ND-coordinate is to represent each process. + For example, ``(0, 1, 0)`` represents the process whose rank is 2 in a 3D process group mesh with size ``(2, 2, 2)``. + + Args: + *size (int): The size of each dimension of the process group mesh. The product of the size must be equal to the world size. + + Attributes: + shape (Tuple[int, ...]): The shape of the process group mesh. + rank (int): The rank of the current process. + """ + + def __init__(self, *size: int) -> None: + assert dist.is_initialized(), "Please initialize torch.distributed first." + assert prod(size) == dist.get_world_size(), "The product of the size must be equal to the world size." + self._shape = size + self._rank = dist.get_rank() + self._coord = ProcessGroupMesh.unravel(self._rank, self._shape) + self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {} + self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {} + + @property + def shape(self) -> Tuple[int, ...]: + return self._shape + + @property + def rank(self) -> int: + return self._rank + + def size(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]: + """Get the size of the process group mesh. + + Args: + dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None. + + Returns: + Union[int, Tuple[int, ...]]: Size of the target dimension or the whole process group mesh. + """ + if dim is None: + return self._shape + else: + return self._shape[dim] + + def coordinate(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]: + """Get the coordinate of the process group mesh. + + Args: + dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None. + + Returns: + Union[int, Tuple[int, ...]]: Coordinate of the target dimension or the whole process group mesh. + """ + if dim is None: + return self._coord + else: + return self._coord[dim] + + @staticmethod + def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]: + """Convert a rank to a coordinate. + + Args: + rank (int): Rank to be converted. + shape (Tuple[int, ...]): Shape of the process group mesh. + + Returns: + Tuple[int, ...]: Coordinate of the rank. + """ + return np.unravel_index(rank, shape) + + @staticmethod + def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int: + """Convert a coordinate to a rank. + + Args: + coords (Tuple[int, ...]): Coordinate to be converted. + shape (Tuple[int, ...]): Shape of the process group mesh. + + Returns: + int: Rank of the coordinate. 
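+
+        Example: ``ravel((0, 1, 0), (2, 2, 2))`` returns ``2`` (row-major order); it is the inverse of ``unravel``.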
+ """ + return np.ravel_multi_index(coord, shape) + + def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup: + """Get the process group with the given ranks. It the process group doesn't exist, it will be created. + + Args: + ranks_in_group (List[int]): Ranks in the process group. + backend (Optional[str], optional): Backend of the process group. Defaults to None. + + Returns: + ProcessGroup: The process group with the given ranks. + """ + ranks_in_group = sorted(ranks_in_group) + if tuple(ranks_in_group) not in self._group_to_ranks: + group = dist.new_group(ranks_in_group, backend=backend) + self._ranks_to_group[tuple(ranks_in_group)] = group + self._group_to_ranks[group] = tuple(ranks_in_group) + return self._ranks_to_group[tuple(ranks_in_group)] + + def get_ranks_in_group(self, group: ProcessGroup) -> List[int]: + """Get the ranks in the given process group. The process group must be created by this class. + + Args: + group (ProcessGroup): The process group. + + Returns: + List[int]: Ranks in the process group. + """ + return list(self._group_to_ranks[group]) + + @staticmethod + def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int, + indices_at_axis: List[int]) -> List[Tuple[int, ...]]: + """Get coordinates along the given axis. + + Args: + base_coord (Tuple[int, ...]): Base coordinate which the coordinates along the axis are based on. + axis (int): Axis along which the coordinates are generated. + indices_at_axis (List[int]): Indices at the axis. + + Returns: + List[Tuple[int, ...]]: Coordinates along the axis. + """ + coords_in_group = [] + for idx in indices_at_axis: + coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1:]) + return coords_in_group + + def create_group_along_axis(self, + axis: int, + indices_at_axis: Optional[List[int]] = None, + backend: Optional[str] = None) -> ProcessGroup: + """Create all process groups along the given axis, and return the one which the current process belongs to. + + Args: + axis (int): Axis along which the process groups are created. + indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. + backend (Optional[str], optional): Backend of the process group. Defaults to None. + + Returns: + ProcessGroup: The process group along the given axis which the current process belongs to. + """ + indices_at_axis = indices_at_axis or list(range(self._shape[axis])) + reduced_shape = list(self._shape) + # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis` + reduced_shape[axis] = 1 + target_group = None + # use Cartesian product to generate all combinations of coordinates + for base_coord in itertools.product(*[range(s) for s in reduced_shape]): + coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis) + ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) + group = self.get_group(ranks_in_group, backend=backend) + if self._rank in ranks_in_group: + target_group = group + return target_group + + def get_group_along_axis(self, + axis: int, + indices_at_axis: Optional[List[int]] = None, + backend: Optional[str] = None) -> ProcessGroup: + """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created. + + Args: + axis (int): Axis along which the process groups are created. + indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. 
+ backend (Optional[str], optional): Backend of the process group. Defaults to None. + + Returns: + ProcessGroup: The process group along the given axis which the current process belongs to. + """ + indices_at_axis = indices_at_axis or list(range(self._shape[axis])) + coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis) + ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) + if ranks_in_group not in self._ranks_to_group: + # no need to cache it explicitly, since it will be cached in `create_group_along_axis` + return self.create_group_along_axis(axis, indices_at_axis, backend=backend) + return self._ranks_to_group[ranks_in_group] diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/ops.py new file mode 100644 index 000000000000..e8d60dc14f63 --- /dev/null +++ b/colossalai/kernel/triton/ops.py @@ -0,0 +1,366 @@ +from typing import Optional, Tuple + +import torch +from torch import nn +from torch.nn import functional as F + +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + from .qkv_matmul_kernel import qkv_gemm_4d_kernel, qkv_gemm_4d_kernel_alibi + from .softmax_kernel import softmax_kernel + + def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float): + r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels + Args: + q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) + k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) + v (torch.Tensor): V embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) + input_mask (torch.Tensor): mask for softmax layer, shape should be (batch, num_heads, seq_lem, seq_len) + scale: the float scale value which is used to multiply with Q*K^T before doing softmax + + Return: + output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size) + """ + assert len(q.shape) == 4, "the shape of q val must be 4" + batches, M, H, K = q.shape + assert q.shape == k.shape, "the shape of q and the shape of k must be equal" + assert q.shape == v.shape, "the shape of q and the shape of v must be equal" + assert q.shape[-1] == k.shape[-1], "the last dimension of q and k must be equal" + + N = k.shape[1] + + # head_size * num_of_head + d_model = q.shape[-1] * q.shape[-2] + + score_output = torch.empty( + (batches, H, M, N), device=q.device, dtype=q.dtype) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * + triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + qkv_gemm_4d_kernel[grid]( + q, k, score_output, + M, N, K, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(3), k.stride(1), + score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3), + scale=scale, + # currently manually setting, later on we can use auto-tune config to match best setting + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=32, + BLOCK_SIZE_K=32, + GROUP_SIZE_M=8, + ) + + softmax_output = torch.empty( + score_output.shape, device=score_output.device, dtype=score_output.dtype) + score_output_shape = score_output.shape + + score_output = score_output.view(-1, score_output.shape[-1]) + n_rows, n_cols = 
score_output.shape + + if n_rows <= 350000: + + block_size = max(triton.next_power_of_2(n_cols), 2) + num_warps = 4 + if block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + else: + num_warps = 4 + + softmax_kernel[(n_rows, )]( + softmax_output, + score_output, + score_output.stride(0), + n_cols, + mask_ptr = input_mask, + num_warps=num_warps, + BLOCK_SIZE=block_size, + ) + + else: + #TODO: change softmax kernel functions to make it suitable for large size dimension + softmax_output = torch.nn.functional.softmax(score_output, dim=-1) + softmax_output = softmax_output.view(*score_output_shape) + + batches, H, M, K = softmax_output.shape + N = v.shape[-1] + + output = torch.empty( + (batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * + triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + qkv_gemm_4d_kernel[grid]( + softmax_output, v, output, + M, N, K, + softmax_output.stride(0), + softmax_output.stride(1), + softmax_output.stride(2), + softmax_output.stride(3), + v.stride(0), + v.stride(2), + v.stride(1), + v.stride(3), + output.stride(0), + output.stride(2), + output.stride(1), + output.stride(3), + BLOCK_SIZE_M=128, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=64, + GROUP_SIZE_M=8, + scale=-1, + ) + return output.view(batches, -1, d_model) + + def compute_attention_for_bloom(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alibi: torch.Tensor, + beta: torch.float32 = 1, + scale: torch.float32 = 1.2, + attention_mask: torch.Tensor = None, + drop_out: torch.float32 = -1, + head_mask: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = True + ): + r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels to used for bloom attention + Args: + q (torch.Tensor): Q embedding in attention layer, shape should be (batch, num_heads, seq_len, head_size) + k (torch.Tensor): K embedding in attention layer, shape should be (batch, num_heads, head_size, kv_length) + v (torch.Tensor): V embedding in attention layer, shape should be (batch, num_heads, kv_length, head_size) + alibi(torch.Tensor): bias for qk^T GEMM, shape should be (batch, H, q_length, kv_length) + input_mask (torch.Tensor): mask for softmax layer, shape should be (batch, num_heads, seq_lem, seq_len) + scale: the float scale value which is used to multiply with Q*K^T before doing softmax + beta: the float value for alibi bias matrix + + Return: + output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size) + """ + + assert len(q.shape) == len(k.shape), "the dimensions must be matched" + assert len(q.shape) == len(v.shape), "the dimensions must be matched" + assert len(q.shape) == 4, "the length of input q must be 4, which is (batch, seq_len, num_heads, head_dim)" + + batches, H, M, K = q.shape + d_model = q.shape[1] * q.shape[3] + + # k shape: (batches, num_heads, head_dim (K), seq_k(N)) + N = k.shape[-1] + + score_output = torch.empty( + (batches, H, M, N), device=q.device, dtype=q.dtype) + + assert len(score_output) == len(alibi), "the length of alibi and score output should be matched" + assert score_output.shape[:-1] == alibi.shape[:-1], "the shape of alibi and score outout also should be the same" + alibi = alibi.expand(batches, H, M, N) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * + triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + qkv_gemm_4d_kernel[grid]( + 
q, k, + score_output, + M, N, K, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3), + scale=scale, + # currently manually setting, later on we can use auto-tune config to match best setting + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=32, + BLOCK_SIZE_K=32, + GROUP_SIZE_M=8, + num_stages=4, + ) + + score_output += beta * alibi + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = score_output.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16: + score_output = score_output.to(torch.float) + + if attention_mask is not None: + score_output = torch.masked_fill(score_output, attention_mask, torch.finfo(score_output.dtype).min) + + softmax_leading_size = batches * H * M + + if softmax_leading_size <= 350000: + score_output = score_output.to(input_dtype) + softmax_output = torch.empty( + score_output.shape, device=score_output.device, dtype=score_output.dtype) + score_output_shape = score_output.shape + + score_output = score_output.view(-1, score_output.shape[-1]) + n_rows, n_cols = score_output.shape + + block_size = max(triton.next_power_of_2(n_cols), 2) + num_warps = 4 + if block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + else: + num_warps = 4 + + softmax_kernel[(n_rows, )]( + softmax_output, + score_output, + score_output.stride(0), + n_cols, + mask_ptr = None, + num_warps=num_warps, + BLOCK_SIZE=block_size, + ) + else: + # TODO: fix softmax layer to make kernel to be suitable for large size cases + softmax_output = F.softmax(score_output, dim=-1, dtype=torch.float32).to(input_dtype) + + if drop_out > 0 and drop_out < 1: + softmax_output = F.dropout(softmax_output, drop_out, False, False).to(input_dtype) + + if head_mask is not None: + softmax_output = softmax_output * head_mask + softmax_output = softmax_output.to(input_dtype) + + batches, H, M, K = softmax_output.shape + N = v.shape[-1] + + output = torch.empty( + (batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * + triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + qkv_gemm_4d_kernel[grid]( + softmax_output, v, output, + M, N, K, + softmax_output.stride(0), + softmax_output.stride(1), + softmax_output.stride(2), + softmax_output.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + output.stride(0), + output.stride(2), + output.stride(1), + output.stride(3), + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=32, + BLOCK_SIZE_K=32, + GROUP_SIZE_M=8, + scale=-1, + ) + + return output.view(batches, -1 , d_model) + + + def self_attention_compute_using_triton(qkv, + input_mask, + layer_past, + alibi, + scale, + head_size, + triangular=False, + use_flash=False): + + assert qkv.is_contiguous() + assert alibi is None, "current triton self-attention does not support alibi" + batches = qkv.shape[0] + d_model = qkv.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = qkv[:, :, :d_model] + k = qkv[:, :, d_model:d_model * 2] + v = qkv[:, :, d_model * 2:] + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + v = v.view(batches, -1, num_of_heads, head_size) + + data_output_triton = 
self_attention_forward_without_fusion( + q, k, v, input_mask, scale) + + return data_output_triton + + + def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor: + if mask is not None: + assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask" + assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention" + + hidden_dim = input.shape[-1] + output = torch.empty_like(input) + input = input.view(-1, hidden_dim) + if mask is not None: + mask = mask.view(-1, hidden_dim) + assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same" + + num_rows, num_cols = input.shape + block_size = max(triton.next_power_of_2(num_cols), 2) + num_warps = 16 + if block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + else: + num_warps = 4 + + if num_rows <= 350000: + grid = (num_rows,) + softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps) + else: + grid = lambda meta: () + + grid = lambda meta: ( + triton.cdiv(num_rows, meta["BLOCK_M"]), + ) + + BLOCK_M = 32 + if block_size >= 4096: + BLOCK_M = 4 + elif block_size >= 2048: + BLOCK_M = 8 + + softmax_kernel_2[grid](output_ptr = output, + input_ptr = input, + row_stride = input.stride(0), + n_rows = num_rows, + n_cols = num_cols, + mask_ptr = mask, + # currently manually setting up size + BLOCK_M = 32, + BLOCK_SIZE = block_size) + + return output \ No newline at end of file diff --git a/colossalai/kernel/triton/qkv_matmul_kernel.py b/colossalai/kernel/triton/qkv_matmul_kernel.py new file mode 100644 index 000000000000..a7a7322f7276 --- /dev/null +++ b/colossalai/kernel/triton/qkv_matmul_kernel.py @@ -0,0 +1,217 @@ +import torch +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + + +if HAS_TRITON: + ''' + this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + ''' + @triton.jit + def qkv_gemm_4d_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_ab, + stride_ah, + stride_am, + stride_ak, + stride_bb, + stride_bh, + stride_bk, + stride_bn, + stride_cb, + stride_ch, + stride_cm, + stride_cn, + scale, + # Meta-parameters + BLOCK_SIZE_M : tl.constexpr = 64, + BLOCK_SIZE_N : tl.constexpr = 32, + BLOCK_SIZE_K : tl.constexpr = 32, + GROUP_SIZE_M : tl.constexpr = 8, + ): + r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer, + where score_matrix is softmax(Q*V^T/sqrt(hidden_size)) + Args: + a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K) + b_ptr(torch.Tensor): pointer to input tensor array (bs, N, h, K) or (bs, h, N, K) + c_ptr(torch.Tensor): pointer to output tensor array (bs, M, h, N) or (bs, h, M, N) + stride_ab(tl.constexpr): stride for bs-dimention for tensor array A + stride_ah(tl.constexpr): stride for h-dimention for tensor array A + stride_am(tl.constexpr): stride for m-dimention for tensor array A + stride_ak(tl.constexpr): stride for k-dimention for tensor array A + stride_bb(tl.constexpr): stride for bs-dimention for tensor array B + stride_bh(tl.constexpr): stride for h-dimention for tensor array B + stride_bk(tl.constexpr): stride for k-dimention for tensor array B + stride_bn(tl.constexpr): stride for n-dimention for tensor array B + 
stride_cb(tl.constexpr): stride for bs-dimention for tensor array output + stride_ch(tl.constexpr): stride for h-dimention for tensor array output + stride_cm(tl.constexpr): stride for m-dimention for tensor array output + stride_cn(tl.constexpr): stride for n-dimention for tensor array output + BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a + BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b + BLOCK_SIZE_K : tiling size for K-dimension of a and b + GROUP_SIZE_M : group size for reducing cache miss, more details: + """ + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + batch = tl.program_id(axis = 0) + head = tl.program_id(axis = 1) + pid = tl.program_id(axis = 2) + + # the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah + + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)) + b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K) + b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=a_mask, other=0.) + b = tl.load(b_ptrs, mask=b_mask, other=0.) 
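+            # multiply the (BLOCK_SIZE_M, BLOCK_SIZE_K) and (BLOCK_SIZE_K, BLOCK_SIZE_N) tiles and accumulate the partial result in fp32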
+ accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + accumulator = accumulator.to(c_ptr.dtype.element_ty) + if scale > 0: + accumulator = accumulator * scale.to(c_ptr.dtype.element_ty) + + + offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] + + stride_cn * offs_accumu_n[None, :]) + accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N) + tl.store(c_ptrs, accumulator, mask=accumulator_mask) + + + @triton.jit + def qkv_gemm_4d_kernel_alibi( + a_ptr, + b_ptr, + alibi_ptr, + c_ptr, + M, + N, + K, + stride_ab, + stride_ah, + stride_am, + stride_ak, + stride_bb, + stride_bh, + stride_bk, + stride_bn, + stride_cb, + stride_ch, + stride_cm, + stride_cn, + scale, + beta, + # Meta-parameters + BLOCK_SIZE_M : tl.constexpr = 64, + BLOCK_SIZE_N : tl.constexpr = 32, + BLOCK_SIZE_K : tl.constexpr = 32, + GROUP_SIZE_M : tl.constexpr = 8, + ): + r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer, + where score_matrix is softmax(Q*V^T/sqrt(hidden_size)) + Args: + a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K) + b_ptr(torch.Tensor): pointer to input tensor array (bs, N, h, K) or (bs, h, N, K) + c_ptr(torch.Tensor): pointer to output tensor array (bs, M, h, N) or (bs, h, M, N) + stride_ab(tl.constexpr): stride for bs-dimention for tensor array A + stride_ah(tl.constexpr): stride for h-dimention for tensor array A + stride_am(tl.constexpr): stride for m-dimention for tensor array A + stride_ak(tl.constexpr): stride for k-dimention for tensor array A + stride_bb(tl.constexpr): stride for bs-dimention for tensor array B + stride_bh(tl.constexpr): stride for h-dimention for tensor array B + stride_bk(tl.constexpr): stride for k-dimention for tensor array B + stride_bn(tl.constexpr): stride for n-dimention for tensor array B + stride_cb(tl.constexpr): stride for bs-dimention for tensor array output + stride_ch(tl.constexpr): stride for h-dimention for tensor array output + stride_cm(tl.constexpr): stride for m-dimention for tensor array output + stride_cn(tl.constexpr): stride for n-dimention for tensor array output + BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a + BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b + BLOCK_SIZE_K : tiling size for K-dimension of a and b + GROUP_SIZE_M : group size for reducing cache miss, more details: + """ + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + batch = tl.program_id(axis = 0) + head = tl.program_id(axis = 1) + pid = tl.program_id(axis = 2) + + # the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah + + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)) + b_ptrs = (b_ptr + batch * stride_bb + head * 
stride_bh +
+                  (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))
+
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        for k in range(0, K, BLOCK_SIZE_K):
+            a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K)
+            b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N)
+            a = tl.load(a_ptrs, mask=a_mask, other=0.)
+            b = tl.load(b_ptrs, mask=b_mask, other=0.)
+            accumulator += tl.dot(a, b)
+            a_ptrs += BLOCK_SIZE_K * stride_ak
+            b_ptrs += BLOCK_SIZE_K * stride_bk
+
+        accumulator = accumulator.to(c_ptr.dtype.element_ty)
+        if scale > 0:
+            accumulator = accumulator * scale.to(c_ptr.dtype.element_ty)
+
+        offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] +
+                  stride_cn * offs_accumu_n[None, :])
+        accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N)
+
+        # load alibi bias, indexed with the output tensor's strides
+        alibi_ptrs = (alibi_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] +
+                      stride_cn * offs_accumu_n[None, :])
+        alibi_vals = tl.load(alibi_ptrs, mask=accumulator_mask, other=0.)
+        accumulator += (alibi_vals * beta.to(c_ptr.dtype.element_ty))
+
+        accumulator = accumulator.to(c_ptr.dtype.element_ty)
+
+        tl.store(c_ptrs, accumulator, mask=accumulator_mask)
diff --git a/colossalai/kernel/triton/softmax_kernel.py b/colossalai/kernel/triton/softmax_kernel.py
new file mode 100644
index 000000000000..c215890badff
--- /dev/null
+++ b/colossalai/kernel/triton/softmax_kernel.py
@@ -0,0 +1,44 @@
+try:
+    import triton
+    import triton.language as tl
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+    '''
+    softmax kernel is modified based on
+    https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
+    '''
+    @triton.jit
+    def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
+        r""" the kernel function for implementing softmax operator
+        Args:
+            output_ptr: the output after finishing softmax operation, (N, hidden_dim)
+            input_ptr: the tensor of input, shape should be (N, hidden_dim)
+            row_stride: stride between consecutive rows of the input
+            n_cols(tl.constexpr): the number of cols of input
+            mask_ptr: pointer to an optional additive mask laid out like the input rows, or None
+            BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
+        """
+        row_idx = tl.program_id(0)
+        row_start_ptr = input_ptr + row_idx * row_stride
+        col_offsets = tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
+        row_minus_max = row - tl.max(row, axis=0)
+
+        if mask_ptr is not None:
+            # load mask into SRAM
+            mask_ptrs = (mask_ptr + (row_idx * row_stride)) + col_offsets
+            mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
+
+            # update
+            row_minus_max = row_minus_max + mask
+
+        numerator = tl.exp(row_minus_max)
+        denominator = tl.sum(numerator, axis=0)
+        softmax_output = numerator / denominator
+        output_row_start_ptr = output_ptr + row_idx * row_stride
+        output_ptrs = output_row_start_ptr + col_offsets
+        # Write back output to DRAM
+        tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
\ No newline at end of file
diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py
index 76f550dc4392..ce0da111c2f7 100644
--- 
a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -1,5 +1,6 @@ +from contextlib import contextmanager from types import MethodType -from typing import Callable, Optional, Union +from typing import Callable, Dict, Optional, Union import torch import torch.distributed as dist @@ -8,8 +9,9 @@ from torch.utils._pytree import tree_map from colossalai._analyzer._subclasses import MetaTensor -from colossalai.tensor.d_tensor.d_tensor import DTensor -from colossalai.tensor.d_tensor.layout import Layout +from colossalai.device.device_mesh import DeviceMesh +# from colossalai.tensor.d_tensor import distribute_tensor +from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html _NORMAL_FACTORY = [ @@ -60,12 +62,15 @@ class _MyTensor(Tensor): """ _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None + default_device: Optional[torch.device] = None + def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor': cls._pre_op_fn() if concrete_data is not None: # uniform api as LazyTensor data = concrete_data else: + kwargs['device'] = cls.default_device data = func(*args, **kwargs) return Tensor._make_subclass(cls, data, require_grad=data.requires_grad) @@ -141,6 +146,8 @@ class LazyTensor(torch.Tensor): _meta_data: Optional[MetaTensor] = None # shape, dtype, device _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None + default_device: Optional[torch.device] = None + @staticmethod def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs): if concrete_data is not None: @@ -158,6 +165,8 @@ def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs): return r def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs): + if func.__name__ in _NORMAL_FACTORY: + kwargs = {**kwargs, 'device': LazyTensor.default_device} self._factory_method = (func, args, kwargs) # (func, args, kwargs) self._op_buffer = [] # (func, args, kwargs, replace) self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data @@ -172,7 +181,7 @@ def materialize(self) -> torch.Tensor: self.clean() return _convert_cls(self, target) - def distribute(self, layout: Layout) -> torch.Tensor: + def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor: """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout. 
Args: @@ -183,7 +192,7 @@ def distribute(self, layout: Layout) -> torch.Tensor: """ target = self._materialize_data() self.clean() - local_tensor = DTensor(target, layout).local_tensor + local_tensor = distribute_tensor(target, device_mesh, sharding_spec) return _convert_cls(self, local_tensor) def clean(self) -> None: @@ -205,16 +214,11 @@ def _materialize_data(self) -> torch.Tensor: if self._materialized_data is None: # apply factory method func, args, kwargs = self._factory_method - # apply cached sequence self._pre_op_fn() - try: - init_val = func(*tree_map(self._replace_with_materialized, args), - **tree_map(self._replace_with_materialized, kwargs)) - except TypeError as e: - print(f'init fn: {func.__name__}') - raise e + init_val = func(*tree_map(self._replace_with_materialized, args), + **tree_map(self._replace_with_materialized, kwargs)) self._materialized_data = self._rerun_ops(init_val) return self._materialized_data @@ -304,6 +308,7 @@ def wrap(y, i=None): else: # out of place op, create new lazy tensor fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i] + fn.__name__ = func.__name__ lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs) return lazy_y elif type(y) is Tensor: @@ -434,14 +439,21 @@ class LazyInitContext: """ _replaced: bool = False - def __init__(self, tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor): + def __init__(self, + tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor, + default_device: Optional[Union[torch.device, str, int]] = None): + assert tensor_cls is LazyTensor or tensor_cls is _MyTensor self.overrides = {} self.tensor_cls = tensor_cls + self.old_default_device = LazyTensor.default_device + self.default_device = default_device def __enter__(self): if LazyInitContext._replaced: raise RuntimeError(f'LazyInitContext is not reentrant') LazyInitContext._replaced = True + self.old_default_device = self.tensor_cls.default_device + self.tensor_cls.default_device = self.default_device def wrap_factory_method(target): # factory functions (eg. torch.empty()) @@ -517,6 +529,7 @@ def wrapper(*args, **kwargs): setattr(torch, name, wrapper) def __exit__(self, exc_type, exc_val, exc_tb): + self.tensor_cls.default_device = self.old_default_device LazyInitContext._replaced = False for name, (wrapper, orig) in self.overrides.items(): setattr(torch, name, orig) @@ -536,7 +549,10 @@ def apply_fn(name: str, p: LazyTensor): return _apply_to_lazy_module(module, apply_fn, verbose) @staticmethod - def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module: + def distribute(module: nn.Module, + device_mesh: DeviceMesh, + sharding_spec_dict: Dict[str, ShardingSpec], + verbose: bool = False) -> nn.Module: """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. 
Args: @@ -546,7 +562,7 @@ def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> n """ def apply_fn(name: str, p: LazyTensor): - p.distribute(layout_dict[name]) + p.distribute(device_mesh, sharding_spec_dict[name]) return _apply_to_lazy_module(module, apply_fn, verbose) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 094320c4aff4..9324c072c13b 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -22,6 +22,7 @@ from .gemini_hook import GeminiZeROHook from .gemini_mgr import GeminiManager from .memory_tracer import MemStats, OrderedParamGenerator +from .placement_policy import CUDAPlacementPolicy from .utils import get_temp_total_chunk_on_cuda try: @@ -74,6 +75,7 @@ def __init__(self, self.param2name: Dict[nn.Parameter, str] = dict() self.name2param: Dict[str, nn.Parameter] = dict() self.scatter_after_inference = scatter_after_inference + self.prefetch_before_inference = type(gemini_manager._placement_policy) is CUDAPlacementPolicy self.mixed_precision = mixed_precision self._logger = get_dist_logger() @@ -90,7 +92,7 @@ def __init__(self, self._init_chunks(param_order=param_order, strict_ddp_mode=strict_ddp_mode, - cpu_offload=self.gemini_manager.policy_name != 'cuda', + cpu_offload=not self.gemini_manager.policy_name.startswith('cuda'), pin_memory=pin_memory) for name, param in module.named_parameters(): @@ -172,7 +174,7 @@ def _inference_forward(self, *args, **kwargs): """This function is only triggered for inference. """ fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook) - if not self.scatter_after_inference: + if self.prefetch_before_inference: # gather all chunks for chunk in self.chunk_manager.get_chunks(self.fp16_params): self.chunk_manager.access_chunk(chunk) diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 84a868872f88..6ff29f2d02d3 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -7,6 +7,7 @@ from colossalai.utils import get_current_device from colossalai.utils.memory import colo_device_memory_capacity +from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager from .memory_tracer import ChunkMemStatsCollector @@ -42,7 +43,7 @@ def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, f start = time() for chunk in can_evict_chunks: self.chunk_manager.release_chunk(chunk) - self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + # self.chunk_manager.move_chunk(chunk, torch.device('cpu')) volume += chunk.chunk_mem return volume, time() - start @@ -63,6 +64,15 @@ def get_default_device() -> torch.device: return get_current_device() +class CUDAReshardPlacementPolicy(CUDAPlacementPolicy): + + def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: + start = time() + for chunk in can_evict_chunks: + self.chunk_manager.release_chunk(chunk) + return 0, time() - start + + class AutoPlacementPolicy(PlacementPolicy): need_mem_stats: bool = True @@ -227,7 +237,8 @@ class PlacementPolicyFactory: 'cpu': CPUPlacementPolicy, 'cuda': CUDAPlacementPolicy, 'auto': AutoPlacementPolicy, - 'const': ConstPlacementPolicy + 'const': ConstPlacementPolicy, + 'cuda_reshard': CUDAReshardPlacementPolicy, } @staticmethod diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py index 218f7603bc54..a56b52f57c30 100644 --- a/colossalai/zero/low_level/_utils.py +++ 
b/colossalai/zero/low_level/_utils.py @@ -109,11 +109,14 @@ def split_by_dtype(tensor_list): return buckets -def reduce_tensor_dp_group(tensor: torch.Tensor, - dtype: Optional[torch.dtype] = None, - dst_local_rank: Optional[int] = None, - dst_global_rank: Optional[int] = None, - group: Optional[dist.ProcessGroup] = None): +def reduce_tensor_dp_group( + tensor: torch.Tensor, + dtype: Optional[torch.dtype] = None, + dst_local_rank: Optional[int] = None, + dst_global_rank: Optional[int] = None, + group: Optional[dist.ProcessGroup] = None, + inner_dp_group: Optional[dist.ProcessGroup] = None, +): """ Reduce the tensor in the data parallel process group @@ -150,6 +153,11 @@ def reduce_tensor_dp_group(tensor: torch.Tensor, else: dist.reduce(tensor=tensor_to_reduce, dst=dst_global_rank, group=group) + if inner_dp_group is not None: + inner_dp_size = dist.get_world_size(group=inner_dp_group) + tensor_to_reduce.div_(inner_dp_size) + dist.all_reduce(tensor_to_reduce, group=inner_dp_group) + # recover the original dtype if tensor.dtype != dtype and tensor is not tensor_to_reduce: local_rank = dist.get_rank(group=group) @@ -232,7 +240,7 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2): for g, p in zip(gradients, params): # Pipeline parallelism may replicate parameters. Avoid multi-counting. tp_param_flag = False - if is_model_parallel_parameter(p) or (isinstance(p, ColoParameter) and not p.is_replicate()): + if is_model_parallel_parameter(p) or (type(p) is ColoParameter and not p.is_replicate()): tp_param_flag = True if tp_param_flag or mp_rank == 0: param_norm = g.data.double().norm(2) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index ee03c0f0ae15..2a0b14072ced 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -56,28 +56,40 @@ def check_local_overflow(self) -> bool: return False +def _get_ranks_in_group(group: ProcessGroup) -> list: + from torch.distributed.distributed_c10d import _get_default_group, _pg_group_ranks + if group is None: + group = _get_default_group() + group_rank_map = _pg_group_ranks[group] + return list(group_rank_map.keys()) + + class LowLevelZeroOptimizer(ColossalaiOptimizer): """Optimizer used for ZeRO-1 and ZeRO-2. 
""" def __init__( - self, - optimizer: Optimizer, - initial_scale: int = 2**16, # grad scaler config - min_scale: int = 1, - growth_factor: float = 2., - backoff_factor: float = .5, - growth_interval: int = 2000, - hysteresis: int = 2, - max_scale: int = 2**24, - clip_grad_norm: float = 0.0, # grad clipping - verbose: bool = False, - reduce_bucket_size: int = 1024 * 1024, # communication - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = False, - partition_grad: bool = False, # stage 2 flag - cpu_offload: bool = False, # cpu offload - forced_dtype: Optional[torch.dtype] = None): + self, + optimizer: Optimizer, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2., + backoff_factor: float = .5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = False, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + forced_dtype: Optional[torch.dtype] = None, + zero_process_group: Optional[ProcessGroup] = None, + dp_process_group: Optional[ProcessGroup] = None, + tp_process_group: Optional[ProcessGroup] = None, + ): # TODO: add support for # 1. fp16 master weights @@ -95,30 +107,14 @@ def __init__( self._cpu_offload = cpu_offload - colo_pg = self._search_colo_process_group() - if isinstance(colo_pg, ProcessGroup): - self._local_rank = colo_pg.dp_local_rank() - self._world_size = colo_pg.dp_world_size() - self._dp_global_ranks = colo_pg.get_ranks_in_dp() - self._dp_torch_group = colo_pg.dp_process_group() - self._mp_torch_group = None - if colo_pg.tp_world_size() > 1: - self._mp_torch_group = colo_pg.tp_process_group() - elif colo_pg is None: - dp_parallel_mode = ParallelMode.DATA - mp_parallel_mode = ParallelMode.MODEL - - self._dp_parallel_mode = dp_parallel_mode - self._mp_parallel_mode = mp_parallel_mode - self._local_rank = gpc.get_local_rank(dp_parallel_mode) - self._world_size = gpc.get_world_size(dp_parallel_mode) - self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode) - self._dp_torch_group = gpc.get_group(dp_parallel_mode) - self._mp_torch_group = None - if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1: - self._mp_torch_group = gpc.get_group(mp_parallel_mode) - else: - raise NotImplementedError + if zero_process_group is None: + assert dp_process_group is None and tp_process_group is None + self._local_rank = dist.get_rank(zero_process_group) + self._world_size = dist.get_world_size(zero_process_group) + self._dp_global_ranks = _get_ranks_in_group(zero_process_group) + self._dp_torch_group = zero_process_group + self._mp_torch_group = tp_process_group + self._inner_dp_group = dp_process_group # working and master params for mixed precision training self._working_param_groups = dict() @@ -181,7 +177,7 @@ def __init__( tensor_list = self._param_store.get_params_by_rank_group(rank, group_id) with torch.no_grad(): flat_tensor = flatten(tensor_list) - flat_tensor = flat_tensor.data.cuda() + # flat_tensor = flat_tensor.data.cuda() self._param_store.add_flat_param_by_rank_group(rank, group_id, flat_tensor) # sync parameters @@ -192,7 +188,7 @@ def __init__( # create a copy of fp32 master weights of the parameters for which this rank is responsible working_flat_current_rank = 
self._param_store.get_flat_param_by_rank_group(self._local_rank, group_id) - master_flat_current_rank = working_flat_current_rank.float() + master_flat_current_rank = working_flat_current_rank.data.float() device = 'cpu' if self._cpu_offload else get_current_device() master_flat_current_rank = master_flat_current_rank.to(device) master_flat_current_rank.requires_grad = True @@ -255,7 +251,7 @@ def _search_colo_process_group(self): for param_group in self.optim.param_groups: group_params = param_group['params'] for param in group_params: - if isinstance(param, ColoParameter): + if type(param) is ColoParameter: colo_flag = True if colo_pg is None: colo_pg = param.get_process_group() @@ -325,7 +321,8 @@ def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank): dtype=self._communication_dtype, dst_local_rank=reduce_rank, dst_global_rank=reduce_global_rank, - group=self._dp_torch_group) + group=self._dp_torch_group, + inner_dp_group=self._inner_dp_group) # update the reduced tensor if reduce_rank is None or reduce_rank == self._local_rank: @@ -599,3 +596,12 @@ def _reduce_grad_stage2(self): # left in the communication bucket for reduce_rank in range(self._world_size): self._run_reduction(reduce_rank) + + def to(self, device): + for group_id in range(len(self.optim.param_groups)): + # sync parameters + for rank in range(self._world_size): + flat_tensor = self._param_store.get_flat_param_by_rank_group(rank, group_id) + flat_tensor.data = flat_tensor.data.to(device) + tensor_list = self._param_store.get_params_by_rank_group(rank, group_id) + sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list) diff --git a/examples/tutorial/fastfold/FastFold b/examples/tutorial/fastfold/FastFold index 05681304651b..eba496808a91 160000 --- a/examples/tutorial/fastfold/FastFold +++ b/examples/tutorial/fastfold/FastFold @@ -1 +1 @@ -Subproject commit 05681304651b1b29d7d887db169045ea3dd28fce +Subproject commit eba496808a91bbcd9661cf832349a418b197015f diff --git a/tests/test_cluster/test_process_group_mesh.py b/tests/test_cluster/test_process_group_mesh.py new file mode 100644 index 000000000000..13b7119424e4 --- /dev/null +++ b/tests/test_cluster/test_process_group_mesh.py @@ -0,0 +1,151 @@ +import pytest +import torch.distributed as dist + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.testing import spawn + + +def check_process_group_mesh_with_gpc(): + from colossalai.context import ParallelMode + from colossalai.core import global_context as gpc + + DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 + pg_mesh = ProcessGroupMesh(1, 2, 2) + + # check world size + assert gpc.get_world_size(ParallelMode.TENSOR) == pg_mesh.size( + TP_DIM), f'{gpc.get_world_size(ParallelMode.TENSOR)} != {pg_mesh.size(TP_DIM)}' + assert gpc.get_world_size(ParallelMode.PIPELINE) == pg_mesh.size(PP_DIM) + assert gpc.get_world_size(ParallelMode.DATA) == pg_mesh.size(DP_DIM) + + # check local rank (coordinate) + assert gpc.get_local_rank(ParallelMode.TENSOR) == pg_mesh.coordinate( + TP_DIM), f'{gpc.get_local_rank(ParallelMode.TENSOR)} != {pg_mesh.coordinate(TP_DIM)}' + assert gpc.get_local_rank(ParallelMode.PIPELINE) == pg_mesh.coordinate(PP_DIM) + assert gpc.get_local_rank(ParallelMode.DATA) == pg_mesh.coordinate(DP_DIM) + + # check ranks in group + tp_group = pg_mesh.get_group_along_axis(TP_DIM) + assert gpc.get_ranks_in_group(ParallelMode.TENSOR) == pg_mesh.get_ranks_in_group(tp_group) + pp_group = pg_mesh.get_group_along_axis(PP_DIM) + assert gpc.get_ranks_in_group(ParallelMode.PIPELINE) == 
pg_mesh.get_ranks_in_group(pp_group) + dp_group = pg_mesh.get_group_along_axis(DP_DIM) + assert gpc.get_ranks_in_group(ParallelMode.DATA) == pg_mesh.get_ranks_in_group(dp_group) + + # check prev rank + coord = pg_mesh.coordinate() + if not gpc.is_first_rank(ParallelMode.TENSOR): + assert coord[TP_DIM] != 0 + prev_coord = coord[:TP_DIM] + (coord[TP_DIM] - 1,) + coord[TP_DIM + 1:] + assert gpc.get_prev_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(prev_coord, pg_mesh.shape) + if not gpc.is_first_rank(ParallelMode.PIPELINE): + assert coord[PP_DIM] != 0 + prev_coord = coord[:PP_DIM] + (coord[PP_DIM] - 1,) + coord[PP_DIM + 1:] + assert gpc.get_prev_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(prev_coord, pg_mesh.shape) + + # check next rank + if not gpc.is_last_rank(ParallelMode.TENSOR): + assert coord[TP_DIM] != pg_mesh.size(TP_DIM) - 1 + next_coord = coord[:TP_DIM] + (coord[TP_DIM] + 1,) + coord[TP_DIM + 1:] + assert gpc.get_next_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(next_coord, pg_mesh.shape) + if not gpc.is_last_rank(ParallelMode.PIPELINE): + assert coord[PP_DIM] != pg_mesh.size(PP_DIM) - 1 + next_coord = coord[:PP_DIM] + (coord[PP_DIM] + 1,) + coord[PP_DIM + 1:] + assert gpc.get_next_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(next_coord, pg_mesh.shape) + + +def check_process_group_mesh_with_cases(): + DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 + DP_SIZE, PP_SIZE, TP_SIZE = 1, 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0, 0), + 1: (0, 0, 1), + 2: (0, 1, 0), + 3: (0, 1, 1), + } + TP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + PP_RANKS_IN_GROUP = { + 0: [0, 2], + 1: [1, 3], + 2: [0, 2], + 3: [1, 3], + } + DP_RANKS_IN_GROUP = { + 0: [0], + 1: [1], + 2: [2], + 3: [3], + } + + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE, TP_SIZE) + + rank = dist.get_rank() + assert rank == pg_mesh.rank + + # check world size + assert pg_mesh.size(TP_DIM) == 2 + assert pg_mesh.size(PP_DIM) == 2 + assert pg_mesh.size(DP_DIM) == 1 + + # check coordinate + assert pg_mesh.coordinate(TP_DIM) == RANK_TO_COORDINATE[rank][TP_DIM] + assert pg_mesh.coordinate(PP_DIM) == RANK_TO_COORDINATE[rank][PP_DIM] + assert pg_mesh.coordinate(DP_DIM) == RANK_TO_COORDINATE[rank][DP_DIM] + + # check ranks in group + tp_group = pg_mesh.get_group_along_axis(TP_DIM) + assert pg_mesh.get_ranks_in_group(tp_group) == TP_RANKS_IN_GROUP[rank] + pp_group = pg_mesh.get_group_along_axis(PP_DIM) + assert pg_mesh.get_ranks_in_group(pp_group) == PP_RANKS_IN_GROUP[rank] + dp_group = pg_mesh.get_group_along_axis(DP_DIM) + assert pg_mesh.get_ranks_in_group(dp_group) == DP_RANKS_IN_GROUP[rank] + + # check prev rank + if RANK_TO_COORDINATE[rank][TP_DIM] != 0: + prev_coord = RANK_TO_COORDINATE[rank][:TP_DIM] + (RANK_TO_COORDINATE[rank][TP_DIM] - 1,) + \ + RANK_TO_COORDINATE[rank][TP_DIM + 1:] + prev_rank = TP_RANKS_IN_GROUP[rank][TP_RANKS_IN_GROUP[rank].index(rank) - 1] + assert pg_mesh.ravel(prev_coord, pg_mesh.shape) == prev_rank + if RANK_TO_COORDINATE[rank][PP_DIM] != 0: + prev_coord = RANK_TO_COORDINATE[rank][:PP_DIM] + (RANK_TO_COORDINATE[rank][PP_DIM] - 1,) + \ + RANK_TO_COORDINATE[rank][PP_DIM + 1:] + prev_rank = PP_RANKS_IN_GROUP[rank][PP_RANKS_IN_GROUP[rank].index(rank) - 1] + assert pg_mesh.ravel(prev_coord, pg_mesh.shape) == prev_rank + + # check next rank + if RANK_TO_COORDINATE[rank][TP_DIM] != TP_SIZE - 1: + next_coord = RANK_TO_COORDINATE[rank][:TP_DIM] + (RANK_TO_COORDINATE[rank][TP_DIM] + 1,) + \ + RANK_TO_COORDINATE[rank][TP_DIM + 1:] + next_rank = 
TP_RANKS_IN_GROUP[rank][TP_RANKS_IN_GROUP[rank].index(rank) + 1] + assert pg_mesh.ravel(next_coord, pg_mesh.shape) == next_rank + if RANK_TO_COORDINATE[rank][PP_DIM] != PP_SIZE - 1: + next_coord = RANK_TO_COORDINATE[rank][:PP_DIM] + (RANK_TO_COORDINATE[rank][PP_DIM] + 1,) + \ + RANK_TO_COORDINATE[rank][PP_DIM + 1:] + next_rank = PP_RANKS_IN_GROUP[rank][PP_RANKS_IN_GROUP[rank].index(rank) + 1] + assert pg_mesh.ravel(next_coord, pg_mesh.shape) == next_rank + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(parallel=dict(data=1, pipeline=2, tensor=dict(mode='1d', size=2))), + rank=rank, + world_size=world_size, + port=port, + host='localhost') + # TODO(ver217): this function should be removed when gpc is removed + check_process_group_mesh_with_gpc() + check_process_group_mesh_with_cases() + + +@pytest.mark.dist +def test_process_group_mesh(): + spawn(run_dist, 4) + + +if __name__ == '__main__': + test_process_group_mesh() diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py new file mode 100644 index 000000000000..b316404a58db --- /dev/null +++ b/tests/test_kernels/test_self_attention.py @@ -0,0 +1,136 @@ +import pytest +from packaging import version +import torch +from torch import nn +import torch.nn.functional as F + +from colossalai.kernel.triton.ops import self_attention_compute_using_triton +from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel + +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") +def test_qkv_matmul(): + qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) + scale = 1.2 + head_size = 32 + batches = qkv.shape[0] + d_model = qkv.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = qkv[:, :, :d_model] + k = qkv[:, :, d_model:d_model * 2] + + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + q_copy = q.clone() + k_copy = k.clone() + q = torch.transpose(q, 1, 2).contiguous() + k = torch.transpose(k, 1, 2).contiguous() + k = torch.transpose(k, 2, 3).contiguous() + + torch_output = torch.einsum('bnij,bnjk->bnik', q, k) + torch_output *= 1.2 + + q, k = q_copy, k_copy + batches, M, H, K = q.shape + N = k.shape[1] + score_output = torch.empty( + (batches, H, M, N), device=q.device, dtype=q.dtype) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * + triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + K = q.shape[3] + qkv_gemm_4d_kernel[grid]( + q, k, score_output, + M, N, K, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(3), k.stride(1), + score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3), + scale=scale, + # currently manually setting, later on we can use auto-tune config to match best setting + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=32, + BLOCK_SIZE_K=32, + GROUP_SIZE_M=8, + ) + + check = torch.allclose(torch_output.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5) + assert check is True, "the outputs of triton and torch are not matched" + + +def self_attention_compute_using_torch(qkv, + input_mask, + scale, + head_size + ): + + batches = qkv.shape[0] + d_model = qkv.shape[-1] // 3 + 
num_of_heads = d_model // head_size + + q = qkv[:, :, :d_model] + k = qkv[:, :, d_model:d_model * 2] + v = qkv[:, :, d_model * 2:] + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + v = v.view(batches, -1, num_of_heads, head_size) + + q = torch.transpose(q, 1, 2).contiguous() + k = torch.transpose(k, 1, 2).contiguous() + v = torch.transpose(v, 1, 2).contiguous() + + k = torch.transpose(k, -1, -2).contiguous() + + score_output = torch.einsum('bnij,bnjk->bnik', q, k) + score_output *= scale + + softmax_output = F.softmax(score_output, dim = -1) + res = torch.einsum('bnij,bnjk->bnik', softmax_output, v) + res = torch.transpose(res, 1, 2) + res = res.contiguous() + + + return res.view(batches, -1, d_model), score_output, softmax_output + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") +def test_self_attention_test(): + + qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) + data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch( + qkv.clone(), + input_mask = None, + scale = 1.2, + head_size = 32 + ) + + data_output_triton = self_attention_compute_using_triton( + qkv.clone(), + alibi=None, + head_size=32, + scale=1.2, + input_mask=None, + layer_past=None, + use_flash=False, + triangular=True) + + check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2) + assert check is True, "the triton output is not matched with torch output" + + +if __name__ == "__main__": + test_qkv_matmul() + test_self_attention_test() \ No newline at end of file diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/test_softmax.py new file mode 100644 index 000000000000..843d811d019c --- /dev/null +++ b/tests/test_kernels/test_softmax.py @@ -0,0 +1,27 @@ +import pytest +from packaging import version +import torch +from torch import nn + +from colossalai.kernel.triton.ops import softmax + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") +def test_softmax_op(): + data_samples = [ + torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32), + torch.randn((320, 320, 78), device = "cuda", dtype = torch.float32), + torch.randn((2345, 4, 5, 64), device = "cuda", dtype = torch.float16) + ] + + for data in data_samples: + module = nn.Softmax(dim = -1) + data_torch_out = module(data) + data_triton_out = softmax(data) + check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3) + assert check is True, "softmax outputs from triton and torch are not matched" + + +if __name__ == "__main__": + test_softmax_op() \ No newline at end of file
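Note on the inner_dp_group argument added to reduce_tensor_dp_group above: the gradient is first reduced within the ZeRO process group, then divided by the inner data-parallel world size and all-reduced across inner_dp_group. Below is a minimal numeric sketch of why the two stages compose into one flat average, in plain Python with no torch.distributed; it assumes uniform group sizes and that the first-stage reduce averages within the ZeRO group, and the group layout shown is hypothetical.

# Minimal sketch of the two-level reduction behind `inner_dp_group`.
# Assumption: the first-stage reduce averages within the ZeRO group.
zero_groups = [[1.0, 3.0], [5.0, 7.0]]    # two hypothetical ZeRO groups, two ranks each

# stage 1: reduce (average) inside each ZeRO group
stage1 = [sum(g) / len(g) for g in zero_groups]            # [2.0, 6.0]

# stage 2: divide by the inner-DP size, then all-reduce (sum) across inner_dp_group,
# mirroring `tensor_to_reduce.div_(inner_dp_size)` followed by `dist.all_reduce`
inner_dp_size = len(stage1)
stage2 = sum(x / inner_dp_size for x in stage1)            # 4.0

# with uniform group sizes this equals one flat average over every replica
flat_average = sum(v for g in zero_groups for v in g) / sum(len(g) for g in zero_groups)
assert stage2 == flat_average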
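The tables in test_process_group_mesh.py above (RANK_TO_COORDINATE and the *_RANKS_IN_GROUP dicts) encode a row-major flattening of (DP, PP, TP) coordinates into global ranks for a (1, 2, 2) mesh, with the last axis varying fastest. Below is a minimal standalone sketch of that mapping; the ravel/unravel helpers are illustrative stand-ins, not ProcessGroupMesh methods.

# Minimal sketch of row-major coordinate <-> rank flattening for the
# (DP_SIZE, PP_SIZE, TP_SIZE) = (1, 2, 2) mesh used in the test.
from itertools import product


def ravel(coord, shape):
    # row-major: the last axis (TP) varies fastest
    rank = 0
    for c, size in zip(coord, shape):
        rank = rank * size + c
    return rank


def unravel(rank, shape):
    coord = []
    for size in reversed(shape):
        coord.append(rank % size)
        rank //= size
    return tuple(reversed(coord))


shape = (1, 2, 2)
assert [ravel(c, shape) for c in product(*map(range, shape))] == [0, 1, 2, 3]
assert unravel(3, shape) == (0, 1, 1)    # matches RANK_TO_COORDINATE[3]
# the TP group of rank 3 is every rank sharing its (dp, pp) prefix (0, 1):
tp_group_of_3 = [ravel((0, 1, t), shape) for t in range(shape[2])]
assert tp_group_of_3 == [2, 3]           # matches TP_RANKS_IN_GROUP[3]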