diff --git a/LICENSE b/LICENSE index 59d456c5b8a1..b3eb43520a6f 100644 --- a/LICENSE +++ b/LICENSE @@ -477,3 +477,53 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + + ---------------- LICENSE FOR torch-int ---------------- + + MIT License + + Copyright (c) 2022 Guangxuan Xiao + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + + ---------------- LICENSE FOR smoothquant ---------------- + + MIT License + + Copyright (c) 2022 MIT HAN Lab + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py index 6c58c59307a6..1926ec78aba8 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py @@ -6,25 +6,20 @@ import torch import torch.nn.functional as F +from einops import rearrange +from flash_attn.bert_padding import pad_input, unpad_input +from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func +from flash_attn.ops.rms_norm import rms_norm from transformers.models.llama.modeling_llama import ( - LlamaRMSNorm, LlamaAttention, - LlamaModel, LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, apply_rotary_pos_emb, repeat_kv, ) from colossalai.logging import get_dist_logger -from einops import rearrange - -from flash_attn.bert_padding import pad_input, unpad_input -from flash_attn.flash_attn_interface import ( - flash_attn_func, - flash_attn_varlen_kvpacked_func, -) -from flash_attn.ops.rms_norm import rms_norm - logger = get_dist_logger() @@ -65,6 +60,7 @@ def attention_forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention. diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 6c165857506c..20a931b816ea 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -245,6 +245,7 @@ class GeminiPlugin(DPPluginBase): chunk_config_dict (dict, optional): chunk configuration dictionary. chunk_init_device (torch.device, optional): device to initialize the chunk. placement_policy (str, optional): "static" and "auto". Defaults to "static". + enable_gradient_accumulation (bool, optional): Whether to enable gradient accumulation. When set to True, gradient will be stored after doing backward pass. Defaults to False. shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement. If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0. offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement. @@ -257,7 +258,7 @@ class GeminiPlugin(DPPluginBase): warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8. steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9. precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'. - master_weights (bool, optional): master weights. Defaults to True. + master_weights (bool, optional): Whether to keep fp32 master parameter weights in optimizer. Defaults to True. pin_memory (bool, optional): use pin memory on CPU. Defaults to False. force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False. @@ -291,6 +292,7 @@ def __init__( chunk_config_dict: Optional[dict] = None, chunk_init_device: Optional[torch.device] = None, placement_policy: str = "static", + enable_gradient_accumulation: bool = False, shard_param_frac: float = 1.0, # only for static placement offload_optim_frac: float = 0.0, # only for static placement offload_param_frac: float = 0.0, # only for static placement @@ -323,6 +325,7 @@ def __init__( chunk_config_dict=chunk_config_dict, chunk_init_device=(chunk_init_device or get_current_device()), placement_policy=placement_policy, + enable_gradient_accumulation=enable_gradient_accumulation, shard_param_frac=shard_param_frac, offload_optim_frac=offload_optim_frac, offload_param_frac=offload_param_frac, diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 088b67c8c533..dc78fe8c094c 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -335,4 +335,4 @@ def get_checkpoint_io(self) -> CheckpointIO: def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: assert isinstance(optimizer, LowLevelZeroOptimizer) - return optimizer.optim.no_sync() + return optimizer.no_sync() diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index 9a965dc982a4..d0c281e057b3 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -4,7 +4,7 @@ ## Introduction -`Colossal Inference` is a module that contains colossal-ai designed inference framework, featuring high performance, steady and easy usability. `Colossal Inference` incorporated the advantages of the latest open-source inference systems, including TGI, vLLM, FasterTransformer, LightLLM and flash attention. while combining the design of Colossal AI, especially Shardformer, to reduce the learning curve for users. +`Colossal Inference` is a module that contains colossal-ai designed inference framework, featuring high performance, steady and easy usability. `Colossal Inference` incorporated the advantages of the latest open-source inference systems, including LightLLM, TGI, vLLM, FasterTransformer and flash attention. while combining the design of Colossal AI, especially Shardformer, to reduce the learning curve for users. ## Design @@ -62,6 +62,12 @@ triton==2.0.0.dev20221202 vllm # for install flash-attention, please use commit hash: 67ae6fd74b4bc99c36b2ce524cf139c35663793c flash-attention + +# install lightllm since we depend on lightllm triton kernels +git clone https://github.com/ModelTC/lightllm +git checkout 28c1267cfca536b7b4f28e921e03de735b003039 +cd lightllm +pip3 install -e . ``` ### Docker @@ -73,6 +79,17 @@ You can use docker run to use docker container to set-up environment docker pull hpcaitech/colossalai-inference:v2 docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash +# enter into docker container +cd /path/to/CollossalAI +pip install -e . + +# install lightllm +git clone https://github.com/ModelTC/lightllm +git checkout 28c1267cfca536b7b4f28e921e03de735b003039 +cd lightllm +pip3 install -e . + + ``` ### Dive into fast-inference! @@ -94,7 +111,7 @@ For various models, experiments were conducted using multiple batch sizes under ### Single GPU Performance: -Currently the stats below are calculated based on A100 (single GPU), and we calculate token latency based on average values of context-forward and decoding forward process, which means we combine both of processes to calculate token generation times. We are actively developing new features and methods to furthur optimize the performance of LLM models. Please stay tuned. +Currently the stats below are calculated based on A100 (single GPU), and we calculate token latency based on average values of context-forward and decoding forward process, which means we combine both of processes to calculate token generation times. We are actively developing new features and methods to further optimize the performance of LLM models. Please stay tuned. #### Llama diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py index db33ae6fe998..35891307e754 100644 --- a/colossalai/inference/__init__.py +++ b/colossalai/inference/__init__.py @@ -1,3 +1,3 @@ from .pipeline import PPInferEngine -__all__ = ['PPInferEngine'] +__all__ = ["PPInferEngine"] diff --git a/colossalai/inference/pipeline/__init__.py b/colossalai/inference/pipeline/__init__.py index aff4568f7d08..41af9f3ef948 100644 --- a/colossalai/inference/pipeline/__init__.py +++ b/colossalai/inference/pipeline/__init__.py @@ -1,3 +1,3 @@ from .engine import PPInferEngine -__all__ = ['PPInferEngine'] +__all__ = ["PPInferEngine"] diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py index 97dfc6336bea..9c47909f70f0 100644 --- a/colossalai/inference/pipeline/benchmark/benchmark.py +++ b/colossalai/inference/pipeline/benchmark/benchmark.py @@ -1,28 +1,32 @@ +import argparse +import time + import torch import torch.distributed as dist import transformers import colossalai -import time from colossalai.inference import PPInferEngine from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy -import argparse -GIGABYTE = 1024 ** 3 + +GIGABYTE = 1024**3 MEGABYTE = 1024 * 1024 colossalai.launch_from_torch(config={}) -def data_gen(batch_size: int=4, seq_len: int=512): + +def data_gen(batch_size: int = 4, seq_len: int = 512): input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32) attention_mask = torch.ones((1, seq_len), dtype=torch.int32) data = dict(input_ids=input_ids, attention_mask=attention_mask) for k, v in data.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: new_shape = [1] * v.dim() new_shape[0] = batch_size - data[k] = v.to('cuda').repeat(*new_shape) + data[k] = v.to("cuda").repeat(*new_shape) return data + def print_details_info(timestamps, model_config, args, whole_end2end): if dist.get_rank() == 0: prefill = [] @@ -31,32 +35,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end): for timestamp in timestamps: prefill.append(timestamp[1] - timestamp[0]) encoder.append( - sum(timestamp[i + 1] - timestamp[i] for i in range(1,len(timestamp) - 1)) / (len(timestamp) - 2)) + sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2) + ) end2end.append(timestamp[-1] - timestamp[0]) print(whole_end2end) - with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f: - mb_avg_end2end = sum(end2end)/len(end2end) - mb_avg_latency = mb_avg_end2end/(args.new_length * args.mb_size) - whole_avg_latency = whole_end2end/(args.new_length * args.batch_size) + with open( + f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log", + "w+", + ) as f: + mb_avg_end2end = sum(end2end) / len(end2end) + mb_avg_latency = mb_avg_end2end / (args.new_length * args.mb_size) + whole_avg_latency = whole_end2end / (args.new_length * args.batch_size) num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers) num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size - if args.dtype in ['fp16','bf16']: + if args.dtype in ["fp16", "bf16"]: num_bytes = 2 else: num_bytes = 4 - f.write(f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n") - f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000)) - f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000)) - f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end*1000)) + f.write( + f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n" + ) + f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000)) + f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000)) + f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000)) f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000)) - f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end*1000)) + f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000)) f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000)) - f.write("Throughput: {} tokens/s\n".format((1000/(whole_avg_latency * 1000)))) - f.write("flops: {0:8.2f} TFlops/s\n".format(1/whole_avg_latency * num_parameters * num_bytes / 1e12)) + f.write("Throughput: {} tokens/s\n".format((1000 / (whole_avg_latency * 1000)))) + f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12)) f.write("----------------------------------------------------------\n") - if torch.cuda.is_available(): current_device = torch.cuda.current_device() @@ -66,7 +75,10 @@ def print_details_info(timestamps, model_config, args, whole_end2end): max_memory_allocated = torch.cuda.max_memory_allocated() memory_reserved = torch.cuda.memory_reserved() max_memory_reserved = torch.cuda.max_memory_reserved() - with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f: + with open( + f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log", + "a", + ) as f: f.write( f"\nCurrently using GPU: {current_device}\n" f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n" @@ -77,29 +89,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end): f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n" ) -if __name__ == '__main__': + +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--model', default='toy', help='the size of model') - parser.add_argument('-b', '--batch_size', type=int, default=8, help='batch size') - parser.add_argument('-s', '--seq_len', type=int, default=8, help='sequence length') - parser.add_argument('--new_length', type=int, default=4, help='new tokens length') - parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size') - parser.add_argument('--pp_size', type=int, default=2, help='pipeline size') - parser.add_argument('--log_path', type=str, default='./log' ,help='where to store the benchmark log') - parser.add_argument('--dtype', type=str, default='fp16', help='data type') + parser.add_argument("--model", default="toy", help="the size of model") + parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size") + parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length") + parser.add_argument("--new_length", type=int, default=4, help="new tokens length") + parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size") + parser.add_argument("--pp_size", type=int, default=2, help="pipeline size") + parser.add_argument("--log_path", type=str, default="./log", help="where to store the benchmark log") + parser.add_argument("--dtype", type=str, default="fp16", help="data type") args = parser.parse_args() - if args.model == 'toy': + if args.model == "toy": model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8)) - elif args.model == '7b': - model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-7b-hf')) - elif args.model == '13b': - model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-13b-hf')) + elif args.model == "7b": + model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-7b-hf")) + elif args.model == "13b": + model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-13b-hf")) else: raise NotImplementedError - - - engine = PPInferEngine(pp_size=args.pp_size, dtype=args.dtype, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True) + + engine = PPInferEngine( + pp_size=args.pp_size, + dtype=args.dtype, + micro_batch_size=args.mb_size, + new_length=args.new_length, + model=model, + model_policy=LlamaForCausalLMPipelinePolicy(), + verbose=True, + ) data = data_gen(args.batch_size, args.seq_len) torch.cuda.synchronize() @@ -109,4 +129,3 @@ def print_details_info(timestamps, model_config, args, whole_end2end): whole_end2end = time.time() - whole_end2end print_details_info(timestamps, model.config, args, whole_end2end) - diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py index 048ead2bccda..4f42385caf8f 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/pipeline/engine.py @@ -1,5 +1,3 @@ -from typing import Callable, List, Optional, Set, Union - import torch import torch.nn as nn @@ -13,7 +11,7 @@ class PPInferEngine: - ''' + """ PPInferEngine is a class that handles the pipeline parallel inference. Args: @@ -41,12 +39,12 @@ class PPInferEngine: output = engine.inference([tokenized_input]) ``` - ''' + """ def __init__( self, pp_size: int, - dtype: str = 'fp16', + dtype: str = "fp16", pp_model: nn.Module = None, model: nn.Module = None, model_policy: Policy = None, @@ -54,7 +52,7 @@ def __init__( micro_batch_size: int = 1, micro_batch_buffer_size: int = None, verbose: bool = False, - # TODO: implement early_stopping, and various gerneration options + # TODO: implement early_stopping, and various gerneration options early_stopping: bool = False, do_sample: bool = False, num_beams: int = 1, @@ -63,15 +61,16 @@ def __init__( self.pp_size = pp_size self.pg_mesh = ProcessGroupMesh(pp_size) self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) - self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size, - micro_batch_buffer_size or pp_size) + self.mb_manager = MicroBatchManager( + self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size + ) self.verbose = verbose self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) - assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'" - if dtype == 'fp16': + assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" + if dtype == "fp16": model.half() - elif dtype == 'bf16': + elif dtype == "bf16": model.to(torch.bfloat16) self.model = pp_model or self._shardformer(model, model_policy) diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py index b6b008442cfd..49d1bf3f42cb 100644 --- a/colossalai/inference/pipeline/microbatch_manager.py +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -3,7 +3,7 @@ import torch -__all__ = 'MicroBatchManager' +__all__ = "MicroBatchManager" class Status(Enum): @@ -13,7 +13,7 @@ class Status(Enum): COOLDOWN = 4 -class MicroBatchDescription(): +class MicroBatchDescription: """ This is the class to record the infomation of each microbatch, and also do some update operation. This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more @@ -30,14 +30,14 @@ def __init__( output_dict: Dict[str, torch.Tensor], new_length: int, ) -> None: - assert output_dict.get('hidden_states') is not None - self.mb_length = output_dict['hidden_states'].shape[-2] + assert output_dict.get("hidden_states") is not None + self.mb_length = output_dict["hidden_states"].shape[-2] self.target_length = self.mb_length + new_length self.kv_cache = () def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): if output_dict is not None: - self._update_kvcache(output_dict['past_key_values']) + self._update_kvcache(output_dict["past_key_values"]) def _update_kvcache(self, kv_cache: Tuple): assert type(kv_cache) == tuple @@ -64,7 +64,6 @@ def cur_length(self): Return the current sequnence length of micro batch """ - pass class HeadMicroBatchDescription(MicroBatchDescription): @@ -80,13 +79,14 @@ class HeadMicroBatchDescription(MicroBatchDescription): """ - def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], - new_length: int) -> None: + def __init__( + self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int + ) -> None: super().__init__(inputs_dict, output_dict, new_length) assert inputs_dict is not None - assert inputs_dict.get('input_ids') is not None and inputs_dict.get('attention_mask') is not None - self.input_ids = inputs_dict['input_ids'] - self.attn_mask = inputs_dict['attention_mask'] + assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None + self.input_ids = inputs_dict["input_ids"] + self.attn_mask = inputs_dict["attention_mask"] self.new_tokens = None def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): @@ -104,7 +104,8 @@ def _update_newtokens(self, new_token: torch.Tensor): def _update_attnmask(self): self.attn_mask = torch.cat( - (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1) + (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1 + ) @property def cur_length(self): @@ -127,8 +128,9 @@ class BodyMicroBatchDescription(MicroBatchDescription): output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. """ - def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], - new_length: int) -> None: + def __init__( + self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int + ) -> None: super().__init__(inputs_dict, output_dict, new_length) def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): @@ -146,8 +148,8 @@ def cur_length(self): return self.kv_cache[0][0].shape[-2] + 1 -class MicroBatchManager(): - ''' +class MicroBatchManager: + """ MicroBatchManager is a class that manages the micro batch. Args: @@ -156,7 +158,7 @@ class MicroBatchManager(): micro_batch_size (int): the micro batch size. micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. - ''' + """ def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int): self.stage = stage diff --git a/colossalai/inference/pipeline/modeling/gpt2.py b/colossalai/inference/pipeline/modeling/gpt2.py index f490710c1f7f..d2bfcb8b6842 100644 --- a/colossalai/inference/pipeline/modeling/gpt2.py +++ b/colossalai/inference/pipeline/modeling/gpt2.py @@ -1,7 +1,6 @@ from typing import Dict, List, Optional, Tuple, Union import torch -from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model from transformers.utils import logging @@ -10,41 +9,41 @@ class GPT2PipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of GPT2 models under pipeline setting. - ''' + """ @staticmethod def gpt2_model_forward( - self: GPT2Model, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: - + self: GPT2Model, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. # Please refer to original code of transformers for more details. logger = logging.get_logger(__name__) # Preprocess passed in arguments if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -96,7 +95,7 @@ def gpt2_model_forward( # positions we want to attend and the dtype's smallest value for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention @@ -137,7 +136,8 @@ def gpt2_model_forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False presents = () if use_cache else None @@ -166,7 +166,6 @@ def gpt2_model_forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, use_cache, output_attentions) @@ -218,61 +217,64 @@ def custom_forward(*inputs): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - return {'hidden_states': hidden_states, 'past_key_values': presents} + return {"hidden_states": hidden_states, "past_key_values": presents} @staticmethod def gpt2_lmhead_model_forward( - self: GPT2LMHeadModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: + self: GPT2LMHeadModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - - This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. - Please refer to original code of transformers for more details. - """ + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. + Please refer to original code of transformers for more details. + """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict # If is first stage and after warmup, go throught lm_head first if stage_manager.is_first_stage() and hidden_states is not None: lm_logits = self.lm_head(hidden_states) - return {'logits': lm_logits} + return {"logits": lm_logits} # Not first stage or before warmup, go through gpt2 model - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) return outputs diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/pipeline/modeling/llama.py index eeda96df25fd..f46e1fbdd7b3 100644 --- a/colossalai/inference/pipeline/modeling/llama.py +++ b/colossalai/inference/pipeline/modeling/llama.py @@ -1,8 +1,6 @@ -from typing import List, Optional, Tuple +from typing import List, Optional import torch -from torch.nn import CrossEntropyLoss, MSELoss -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel from transformers.utils import logging @@ -10,10 +8,10 @@ class LlamaPipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of Llama models under pipeline setting. - ''' + """ def llama_model_forward( self: LlamaModel, @@ -34,10 +32,10 @@ def llama_model_forward( # Preprocess passed in arguments if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -70,10 +68,9 @@ def llama_model_forward( seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: - position_ids = torch.arange(past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device) + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -81,16 +78,18 @@ def llama_model_forward( # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), - dtype=torch.bool, - device=hidden_states.device) - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, - past_key_values_length) + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False # decoder layers @@ -112,7 +111,6 @@ def llama_model_forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, None) @@ -152,7 +150,7 @@ def custom_forward(*inputs): next_cache = next_decoder_cache if use_cache else None # always return dict for imediate stage - return {'hidden_states': hidden_states, 'past_key_values': next_cache} + return {"hidden_states": hidden_states, "past_key_values": next_cache} def llama_for_causal_lm_forward( self: LlamaForCausalLM, @@ -171,45 +169,45 @@ def llama_for_causal_lm_forward( stage_index: Optional[List[int]] = None, ): r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - Returns: + Returns: - Example: + Example: - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" logger = logging.get_logger(__name__) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False # If is first stage and after warmup, go throught lm_head first if stage_manager.is_first_stage() and hidden_states is not None: lm_logits = self.lm_head(hidden_states) - return {'logits': lm_logits} + return {"logits": lm_logits} # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = LlamaPipelineForwards.llama_model_forward( diff --git a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py index e51090200f83..51e6425b113e 100644 --- a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py +++ b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py @@ -11,7 +11,6 @@ class GPT2LMHeadModelPipelinePolicy(GPT2Policy): - def __init__(self) -> None: super().__init__() @@ -22,18 +21,22 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: addon_module = { - GPT2LMHeadModel: - ModulePolicyDescription(sub_module_replacement=[ + GPT2LMHeadModel: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) - ]) + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ) + ] + ) } module_policy.update(addon_module) if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=GPT2LMHeadModel, - new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, - policy=module_policy) + self.set_pipeline_forward( + model_cls=GPT2LMHeadModel, + new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, + policy=module_policy, + ) return module_policy def get_held_layers(self) -> List[nn.Module]: @@ -45,7 +48,7 @@ def get_held_layers(self) -> List[nn.Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''The weights of wte and lm_head are shared.''' + """The weights of wte and lm_head are shared.""" module = self.model stage_manager = self.pipeline_stage_manager if stage_manager is not None: @@ -56,16 +59,16 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if not self.pipeline_stage_manager: raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == 'GPT2Model': + if self.model.__class__.__name__ == "GPT2Model": module = self.model else: module = self.model.transformer layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/inference/pipeline/policy/llama_ppinfer.py b/colossalai/inference/pipeline/policy/llama_ppinfer.py index bb359de0bb6f..6e12ed61bf7b 100644 --- a/colossalai/inference/pipeline/policy/llama_ppinfer.py +++ b/colossalai/inference/pipeline/policy/llama_ppinfer.py @@ -1,19 +1,15 @@ -from functools import partial -from typing import Callable, Dict, List, Union +from typing import List -import torch.nn as nn -from torch import Tensor from torch.nn import Module -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from colossalai.shardformer.layer import Linear1D_Col +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription from colossalai.shardformer.policies.llama import LlamaPolicy from ..modeling.llama import LlamaPipelineForwards class LlamaForCausalLMPipelinePolicy(LlamaPolicy): - def __init__(self) -> None: super().__init__() @@ -25,19 +21,21 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm new_item = { - LlamaForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ + LlamaForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) } policy.update(new_item) if self.pipeline_stage_manager: # set None as default - self.set_pipeline_forward(model_cls=LlamaForCausalLM, - new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy + ) return policy diff --git a/colossalai/inference/pipeline/utils.py b/colossalai/inference/pipeline/utils.py index 1a6e8a519397..c26aa4e40b71 100644 --- a/colossalai/inference/pipeline/utils.py +++ b/colossalai/inference/pipeline/utils.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Set +from typing import Set import torch.nn as nn @@ -30,6 +30,6 @@ def get_suffix_name(suffix: str, name: str): suffix (str): The suffix of the suffix module name (str): The name of the current module """ - point = '' if suffix is '' else '.' - suffix_name = suffix + f'[{name}]' if name.isdigit() else suffix + f'{point}{name}' + point = "" if suffix is "" else "." + suffix_name = suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}" return suffix_name diff --git a/colossalai/inference/quant/smoothquant/__init__.py b/colossalai/inference/quant/smoothquant/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/quant/smoothquant/models/__init__.py b/colossalai/inference/quant/smoothquant/models/__init__.py new file mode 100644 index 000000000000..77541d8610c5 --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/__init__.py @@ -0,0 +1,12 @@ +try: + import torch_int + + HAS_TORCH_INT = True +except ImportError: + HAS_TORCH_INT = False + raise ImportError( + "Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int" + ) + +if HAS_TORCH_INT: + from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py new file mode 100644 index 000000000000..6a1d96ecec8f --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/base_model.py @@ -0,0 +1,488 @@ +# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ +# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py +# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py + +import os +import warnings +from abc import abstractmethod +from functools import partial +from os.path import isdir, isfile, join +from typing import Dict, List, Optional, Union + +import accelerate +import numpy as np +import torch +import torch.nn as nn +import transformers +from safetensors.torch import save_file as safe_save +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel +from transformers.modeling_utils import no_init_weights +from transformers.utils.generic import ContextManagers +from transformers.utils.hub import PushToHubMixin, cached_file + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager + +SUPPORTED_MODELS = ["llama"] + + +class BaseSmoothForCausalLM(nn.Module, PushToHubMixin): + layer_type: str = None + + def __init__(self, model: PreTrainedModel, quantized: bool = False): + super().__init__() + + self.model = model + self.model_type = self.model.config.model_type + self._quantized = quantized + self.config = self.model.config + self.cache_manager = None + self.max_total_token_num = 0 + + @property + def quantized(self): + return self._quantized + + def init_cache_manager(self, max_total_token_num=2048): + if self.config.model_type == "llama": + head_num = self.config.num_key_value_heads + layer_num = self.config.num_hidden_layers + head_dim = self.config.hidden_size // head_num + + self.cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num) + self.max_total_token_num = max_total_token_num + + def init_batch_state(self, max_output_len=256, **kwargs): + input_ids = kwargs["input_ids"] + batch_size = len(input_ids) + + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + start_index = 0 + max_len_in_batch = -1 + + for i in range(batch_size): + seq_len = len(input_ids[i]) + seq_lengths[i] = seq_len + seq_start_indexes[i] = start_index + start_index += seq_len + max_len_in_batch = seq_len if seq_len > max_len_in_batch else max_len_in_batch + + if "max_total_token_num" in kwargs.keys(): + max_total_token_num = kwargs["max_total_token_num"] + self.init_cache_manager(max_total_token_num) + + if "max_new_tokens" in kwargs.keys(): + max_output_len = kwargs["max_new_tokens"] + + if batch_size * (max_len_in_batch + max_output_len) > self.max_total_token_num: + max_total_token_num = batch_size * (max_len_in_batch + max_output_len) + warnings.warn(f"reset max tokens to {max_total_token_num}") + self.init_cache_manager(max_total_token_num) + + block_loc = torch.empty((batch_size, max_len_in_batch + max_output_len), dtype=torch.long, device="cuda") + batch_infer_state = BatchInferState(batch_size, max_len_in_batch) + batch_infer_state.seq_len = seq_lengths.to("cuda") + batch_infer_state.start_loc = seq_start_indexes.to("cuda") + batch_infer_state.block_loc = block_loc + batch_infer_state.decode_layer_id = 0 + batch_infer_state.past_key_values_len = 0 + batch_infer_state.is_context_stage = True + batch_infer_state.set_cache_manager(self.cache_manager) + batch_infer_state.cache_manager.free_all() + return batch_infer_state + + @abstractmethod + @torch.inference_mode() + def quantize( + self, + examples: List[Dict[str, Union[List[int], torch.LongTensor]]], + ): + if self.quantized: + raise EnvironmentError("can't execute quantize because the model is quantized.") + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def generate(self, **kwargs): + """shortcut for model.generate""" + + batch_infer_state = self.init_batch_state(**kwargs) + if self.config.model_type == "llama": + setattr(self.model.model, "infer_state", batch_infer_state) + + with torch.inference_mode(): + return self.model.generate(**kwargs) + + def prepare_inputs_for_generation(self, *args, **kwargs): + """shortcut for model.prepare_inputs_for_generation""" + return self.model.prepare_inputs_for_generation(*args, **kwargs) + + def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512): + for text in tqdm(dataset): + input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) + model(input_ids) + + def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512): + pbar = tqdm(dataset) + for text in pbar: + input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) + model(input_ids) + mean_scale = np.mean([v["input"] for v in act_dict.values()]) + pbar.set_description(f"Mean input scale: {mean_scale:.2f}") + + # Adatped from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py + def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512): + model.eval() + device = next(model.parameters()).device + act_scales = {} + + def stat_tensor(name, tensor): + hidden_dim = tensor.shape[-1] + tensor = tensor.view(-1, hidden_dim).abs().detach() + comming_max = torch.max(tensor, dim=0)[0].float().cpu() + if name in act_scales: + act_scales[name] = torch.max(act_scales[name], comming_max) + else: + act_scales[name] = comming_max + + def stat_input_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + stat_tensor(name, x) + + hooks = [] + for name, m in model.named_modules(): + if isinstance(m, nn.Linear): + hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name))) + + self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len) + + for h in hooks: + h.remove() + + return act_scales + + # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py + @torch.no_grad() + def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5): + if not isinstance(fcs, list): + fcs = [fcs] + for fc in fcs: + assert isinstance(fc, nn.Linear) + assert ln.weight.numel() == fc.in_features == act_scales.numel() + + device, dtype = fcs[0].weight.device, fcs[0].weight.dtype + act_scales = act_scales.to(device=device, dtype=dtype) + weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0) + weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5) + + scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype) + + ln.weight.div_(scales) + if hasattr(ln, "bias"): + ln.bias.div_(scales) + + for fc in fcs: + fc.weight.mul_(scales.view(1, -1)) + + @classmethod + def create_quantized_model(model): + raise NotImplementedError("Not implement create_quantized_model method") + + # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py + def save_quantized( + self, + save_dir: str, + model_basename: str, + use_safetensors: bool = False, + safetensors_metadata: Optional[Dict[str, str]] = None, + ): + """save quantized model and configs to local disk""" + os.makedirs(save_dir, exist_ok=True) + + if not self.quantized: + raise EnvironmentError("can only save quantized model, please execute .quantize first.") + + self.model.to("cpu") + + model_base_name = model_basename # or f"smooth-" + if use_safetensors: + model_save_name = model_base_name + ".safetensors" + state_dict = self.model.state_dict() + state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} + if safetensors_metadata is None: + safetensors_metadata = {} + elif not isinstance(safetensors_metadata, dict): + raise TypeError("safetensors_metadata must be a dictionary.") + else: + print(f"Received safetensors_metadata: {safetensors_metadata}") + new_safetensors_metadata = {} + converted_keys = False + for key, value in safetensors_metadata.items(): + if not isinstance(key, str) or not isinstance(value, str): + converted_keys = True + try: + new_key = str(key) + new_value = str(value) + except Exception as e: + raise TypeError( + f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}" + ) + if new_key in new_safetensors_metadata: + print( + f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting." + ) + new_safetensors_metadata[new_key] = new_value + safetensors_metadata = new_safetensors_metadata + if converted_keys: + print( + f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}" + ) + + # Format is required to enable Accelerate to load the metadata + # otherwise it raises an OSError + safetensors_metadata["format"] = "pt" + + safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata) + else: + model_save_name = model_base_name + ".bin" + torch.save(self.model.state_dict(), join(save_dir, model_save_name)) + + self.model.config.save_pretrained(save_dir) + + # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py + def save_pretrained( + self, + save_dir: str, + use_safetensors: bool = False, + safetensors_metadata: Optional[Dict[str, str]] = None, + **kwargs, + ): + """alias of save_quantized""" + warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.") + self.save_quantized(save_dir, use_safetensors, safetensors_metadata) + + # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + max_memory: Optional[dict] = None, + trust_remote_code: bool = False, + torch_dtype: torch.dtype = torch.float16, + **model_init_kwargs, + ): + if not torch.cuda.is_available(): + raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.") + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + + # Parameters related to loading from Hugging Face Hub + cache_dir = model_init_kwargs.pop("cache_dir", None) + force_download = model_init_kwargs.pop("force_download", False) + resume_download = model_init_kwargs.pop("resume_download", False) + proxies = model_init_kwargs.pop("proxies", None) + local_files_only = model_init_kwargs.pop("local_files_only", False) + use_auth_token = model_init_kwargs.pop("use_auth_token", None) + revision = model_init_kwargs.pop("revision", None) + subfolder = model_init_kwargs.pop("subfolder", "") + model_init_kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + } + + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs) + if config.model_type not in SUPPORTED_MODELS: + raise TypeError(f"{config.model_type} isn't supported yet.") + + # enforce some values despite user specified + model_init_kwargs["torch_dtype"] = torch_dtype + model_init_kwargs["trust_remote_code"] = trust_remote_code + if max_memory: + if "disk" in max_memory: + raise NotImplementedError("disk offload not support yet.") + with accelerate.init_empty_weights(): + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + model.tie_weights() + + max_memory = accelerate.utils.get_balanced_memory( + model, + max_memory=max_memory, + no_split_module_classes=[cls.layer_type], + dtype=model_init_kwargs["torch_dtype"], + low_zero=False, + ) + model_init_kwargs["device_map"] = accelerate.infer_auto_device_map( + model, + max_memory=max_memory, + no_split_module_classes=[cls.layer_type], + dtype=model_init_kwargs["torch_dtype"], + ) + model_init_kwargs["low_cpu_mem_usage"] = True + + del model + else: + model_init_kwargs["device_map"] = None + model_init_kwargs["low_cpu_mem_usage"] = False + + torch.cuda.empty_cache() + + merged_kwargs = {**model_init_kwargs, **cached_file_kwargs} + model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs) + + model_config = model.config.to_dict() + seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] + if any([k in model_config for k in seq_len_keys]): + for key in seq_len_keys: + if key in model_config: + model.seqlen = model_config[key] + break + else: + warnings.warn("can't get model's sequence length from model config, will set to 4096.") + model.seqlen = 4096 + model.eval() + + return cls(model, False) + + # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py + @classmethod + def from_quantized( + cls, + model_name_or_path: Optional[str], + model_basename: Optional[str] = None, + device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None, + max_memory: Optional[dict] = None, + device: Optional[Union[str, int]] = None, + low_cpu_mem_usage: bool = False, + torch_dtype: Optional[torch.dtype] = None, + use_safetensors: bool = False, + trust_remote_code: bool = False, + **kwargs, + ): + """load quantized model from local disk""" + + # Parameters related to loading from Hugging Face Hub + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + + # == step1: prepare configs and file names == # + config = AutoConfig.from_pretrained( + model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs + ) + + if config.model_type not in SUPPORTED_MODELS: + raise TypeError(f"{config.model_type} isn't supported yet.") + + extensions = [] + if use_safetensors: + extensions.append(".safetensors") + else: + extensions += [".bin", ".pt"] + + model_name_or_path = str(model_name_or_path) + is_local = isdir(model_name_or_path) + + resolved_archive_file = None + if is_local: + model_save_name = join(model_name_or_path, model_basename) + for ext in extensions: + if isfile(model_save_name + ext): + resolved_archive_file = model_save_name + ext + break + else: # remote + for ext in extensions: + resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs) + if resolved_archive_file is not None: + break + + if resolved_archive_file is None: # Could not find a model file to use + raise FileNotFoundError(f"Could not find model in {model_name_or_path}") + + model_save_name = resolved_archive_file + + # == step2: convert model to quantized-model (replace Linear) == # + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + + transformers.modeling_utils._init_weights = False + + init_contexts = [no_init_weights()] + if low_cpu_mem_usage: + init_contexts.append(accelerate.init_empty_weights(include_buffers=True)) + + with ContextManagers(init_contexts): + model = AutoModelForCausalLM.from_config( + config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype + ) + cls.create_quantized_model(model) + model.tie_weights() + + # == step3: load checkpoint to quantized-model == # + accelerate.utils.modeling.load_checkpoint_in_model( + model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True + ) + + # == step4: set seqlen == # + model_config = model.config.to_dict() + seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] + if any([k in model_config for k in seq_len_keys]): + for key in seq_len_keys: + if key in model_config: + model.seqlen = model_config[key] + break + else: + warnings.warn("can't get model's sequence length from model config, will set to 4096.") + model.seqlen = 4096 + + return cls( + model, + True, + ) + + def __getattr__(self, item): + try: + return super().__getattr__(item) + except: + return getattr(self.model, item) + + +__all__ = ["BaseSmoothForCausalLM"] diff --git a/colossalai/inference/quant/smoothquant/models/linear.py b/colossalai/inference/quant/smoothquant/models/linear.py new file mode 100644 index 000000000000..969c390a0849 --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/linear.py @@ -0,0 +1,179 @@ +# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py + +import torch +from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32 +from torch_int.functional.quantization import quantize_per_tensor_absmax + +try: + from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + + smoothquant_cuda = SmoothquantBuilder().load() + HAS_SMOOTHQUANT_CUDA = True +except ImportError: + HAS_SMOOTHQUANT_CUDA = False + raise ImportError("CUDA smoothquant linear is not installed") + + +class W8A8BFP32O32LinearSiLU(torch.nn.Module): + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight + if module.bias is not None: + int8_module.bias.data.copy_(module.bias.to(torch.float)) + int8_module.a = alpha + return int8_module + + +# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py +class W8A8B8O8Linear(torch.nn.Module): + # For qkv_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + self.register_buffer("b", torch.tensor(beta)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item()) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale, output_scale): + int8_module = W8A8B8O8Linear(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale / output_scale + int8_module.weight = int8_weight + int8_module.a = alpha + + if module.bias is not None: + int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias) + int8_module.bias = int8_bias + beta = bias_scale / output_scale + int8_module.b = beta + + return int8_module + + +# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py +class W8A8BFP32OFP32Linear(torch.nn.Module): + # For fc2 and out_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + + def _apply(self, fn): + # prevent the bias from being converted to half + super()._apply(fn) + self.bias = self.bias.to(torch.float32) + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + self.bias = self.bias.to(torch.float32) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight + int8_module.a = alpha + int8_module.input_scale = input_scale + int8_module.weight_scale = weight_scale + + if module.bias is not None: + int8_module.bias = module.bias.to(torch.float32) + + return int8_module diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py new file mode 100644 index 000000000000..4c3d6dcc0b23 --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -0,0 +1,849 @@ +import math +import os +import types +from collections import defaultdict +from functools import partial +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T +from transformers import PreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LLAMA_INPUTS_DOCSTRING, + LlamaAttention, + LlamaDecoderLayer, + LlamaMLP, + LlamaRotaryEmbedding, + repeat_kv, + rotate_half, +) +from transformers.utils import add_start_docstrings_to_model_forward + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton import ( + copy_kv_cache_to_dest, + int8_rotary_embedding_fwd, + smooth_llama_context_attn_fwd, + smooth_token_attention_fwd, +) + +from .base_model import BaseSmoothForCausalLM +from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear + + +class LLamaSmoothquantAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads})." + ) + + self.qk_bmm = BMM_S8T_S8N_F32T(1.0) + self.pv_bmm = BMM_S8T_S8N_S8T(1.0) + + self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.o_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size) + + self.register_buffer("q_output_scale", torch.tensor([1.0])) + self.register_buffer("k_output_scale", torch.tensor([1.0])) + self.register_buffer("v_output_scale", torch.tensor([1.0])) + self.register_buffer("q_rotary_output_scale", torch.tensor([1.0])) + self.register_buffer("k_rotary_output_scale", torch.tensor([1.0])) + self.register_buffer("out_input_scale", torch.tensor([1.0])) + self.register_buffer("attn_input_scale", torch.tensor([1.0])) + + self._init_rope() + self.num_key_value_heads = num_heads + + def _init_rope(self): + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=2048, + base=10000.0, + ) + + @staticmethod + def pack( + module: LlamaAttention, + attn_input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + q_rotary_output_scale: float, + k_rotary_output_scale: float, + out_input_scale: float, + ): + int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads) + + int8_module.attn_input_scale = torch.tensor([attn_input_scale]) + + int8_module.q_output_scale = torch.tensor([q_output_scale]) + int8_module.k_output_scale = torch.tensor([k_output_scale]) + int8_module.v_output_scale = torch.tensor([v_output_scale]) + + int8_module.q_rotary_output_scale = torch.tensor([q_rotary_output_scale]) + int8_module.k_rotary_output_scale = torch.tensor([k_rotary_output_scale]) + + int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, attn_input_scale, q_output_scale) + int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale) + int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale) + int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale) + + int8_module.out_input_scale = torch.tensor([out_input_scale]) + + return int8_module + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + @torch.no_grad() + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: Tuple[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + cos = rotary_emb[0] + sin = rotary_emb[1] + + int8_rotary_embedding_fwd( + query_states.view(-1, self.num_heads, self.head_dim), + cos, + sin, + self.q_output_scale.item(), + self.q_rotary_output_scale.item(), + ) + int8_rotary_embedding_fwd( + key_states.view(-1, self.num_heads, self.head_dim), + cos, + sin, + self.k_output_scale.item(), + self.k_rotary_output_scale.item(), + ) + + # NOTE might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if infer_state.decode_layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_len # seq_len + + def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + return + + query_states = query_states.view(-1, self.num_heads, self.head_dim) + key_states = key_states.view(-1, self.num_heads, self.head_dim) + value_states = value_states.view(-1, self.num_heads, self.head_dim) + + if infer_state.is_context_stage: + # first token generation + + # copy key and value calculated in current step to memory manager + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.context_mem_index, + infer_state.cache_manager, + ) + + attn_output = torch.empty_like(query_states) + + smooth_llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + self.q_rotary_output_scale.item(), + self.k_rotary_output_scale.item(), + self.v_output_scale.item(), + self.out_input_scale.item(), + infer_state.start_loc, + infer_state.seq_len, + q_len, + ) + + else: + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(key_states) + cache_v.copy_(value_states) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) + + # (batch_size, seqlen, nheads, headdim) + attn_output = torch.empty_like(query_states) + + smooth_token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + self.q_rotary_output_scale.item(), + self.k_rotary_output_scale.item(), + self.v_output_scale.item(), + self.out_input_scale.item(), + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + + attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output, None, None + + +class LlamaLayerNormQ(torch.nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.input_scale = 1.0 + self.variance_epsilon = eps + self.register_buffer("weight", torch.ones(dim, dtype=torch.float32)) + + def forward(self, x): + ln_output_fp = torch.nn.functional.layer_norm(x, x.shape[-1:], self.weight, None, self.variance_epsilon) + ln_output_int8 = ln_output_fp.round().clamp(-128, 127).to(torch.int8) + return ln_output_int8 + + @staticmethod + def from_float(module: torch.nn.LayerNorm, output_scale: float): + assert module.weight.shape[0] == module.weight.numel() + q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon) + q_module.weight = module.weight / output_scale + return q_module + + +class LlamaSmoothquantMLP(nn.Module): + def __init__(self, intermediate_size, hidden_size): + super().__init__() + self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size) + self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size) + self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size) + self.register_buffer("down_proj_input_scale", torch.tensor([1.0])) + + @staticmethod + def pack( + mlp_module: LlamaMLP, + gate_proj_input_scale: float, + up_proj_input_scale: float, + down_proj_input_scale: float, + ): + int8_module = LlamaSmoothquantMLP( + mlp_module.intermediate_size, + mlp_module.hidden_size, + ) + + int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale) + int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale) + int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale) + int8_module.down_proj_input_scale = torch.tensor([down_proj_input_scale]) + return int8_module + + def forward( + self, + hidden_states: torch.Tensor, + ): + x_shape = hidden_states.shape + gate_out = self.gate_proj(hidden_states) + up_out = self.up_proj(hidden_states) + inter_out = gate_out * up_out + inter_out = inter_out.div_(self.down_proj_input_scale.item()).round().clamp(-128, 127).to(torch.int8) + down_out = self.down_proj(inter_out) + down_out = down_out.view(*x_shape[:-1], -1) + return down_out + + +class LlamaSmoothquantDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LLamaSmoothquantAttention(config.hidden_size, config.num_attention_heads) + + self.mlp = LlamaSmoothquantMLP(config.intermediate_size, config.hidden_size) + self.input_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + + self.post_attention_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + + @staticmethod + def pack( + module: LlamaDecoderLayer, + attn_input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + q_rotary_output_scale: float, + k_rotary_output_scale: float, + out_input_scale: float, + gate_input_scale: float, + up_input_scale: float, + down_input_scale: float, + ): + config = module.self_attn.config + int8_decoder_layer = LlamaSmoothquantDecoderLayer(config) + + int8_decoder_layer.input_layernorm = LlamaLayerNormQ.from_float(module.input_layernorm, attn_input_scale) + int8_decoder_layer.self_attn = LLamaSmoothquantAttention.pack( + module.self_attn, + attn_input_scale, + q_output_scale, + k_output_scale, + v_output_scale, + q_rotary_output_scale, + k_rotary_output_scale, + out_input_scale, + ) + + int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float( + module.post_attention_layernorm, gate_input_scale + ) + + int8_decoder_layer.mlp = LlamaSmoothquantMLP.pack( + module.mlp, + gate_input_scale, + up_input_scale, + down_input_scale, + ) + + return int8_decoder_layer + + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: Tuple[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + rotary_emb=rotary_emb, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + infer_state=infer_state, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, None, None + + +class LlamaApplyRotary(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + x_embed = (x * cos) + (rotate_half(x) * sin) + + return x_embed + + +# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +def llama_decoder_layer_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states = self.q_apply_rotary(query_states, cos, sin, position_ids) + key_states = self.k_apply_rotary(key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def init_to_get_rotary(config, base=10000, use_elem=False): + """ + This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer + Args: + base : calculation arg + use_elem : activated when using chatglm-based models + """ + config.head_dim_ = config.hidden_size // config.num_attention_heads + if not hasattr(config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = config.rope_scaling.factor if config.rope_scaling is not None else 1.0 + + if hasattr(config, "max_sequence_length"): + max_seq_len = config.max_sequence_length + elif hasattr(config, "max_position_embeddings"): + max_seq_len = config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + try: + ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1)) + assert ntk_alpha >= 1 + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (config.head_dim_ / (config.head_dim_ - 2))) # Base change formula + except: + pass + + n_elem = config.head_dim_ + if use_elem: + n_elem //= 2 + + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + _cos_cached = torch.cos(freqs).to(torch.float) + _sin_cached = torch.sin(freqs).to(torch.float) + return _cos_cached, _sin_cached + + +# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) +def llama_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + infer_state = self.infer_state + + if past_key_values is not None: + # NOT READY FOR PRIME TIME + # dummy but work, revise it + past_key_values_length = infer_state.cache_manager.past_key_values_length + # past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + if infer_state.is_context_stage: + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device) + padding_mask = None + else: + if 0 in attention_mask: + padding_mask = attention_mask + else: + padding_mask = None + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + raise NotImplementedError("not implement gradient_checkpointing and training options ") + + if past_key_values_length == 0: + position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + else: + position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1) + position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + infer_state.decode_layer_id = 0 + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + rotary_emb=(position_cos, position_sin), + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + infer_state=infer_state, + ) + + hidden_states = layer_outputs[0] + infer_state.decode_layer_id += 1 + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + infer_state.is_context_stage = False + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class SmoothLlamaForCausalLM(BaseSmoothForCausalLM): + layer_type = "LlamaDecoderLayer" + + def __init__(self, model: PreTrainedModel, quantized: bool = False): + super().__init__(model, quantized) + + # Adatped from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py + def get_act_dict( + self, + tokenizer, + dataset, + num_samples=512, + seq_len=512, + ): + llama_model = self.model + + llama_model.eval() + device = next(llama_model.parameters()).device + # print("model:", llama_model) + act_dict = defaultdict(dict) + + def stat_io_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + if name not in act_dict or "input" not in act_dict[name]: + act_dict[name]["input"] = x.detach().abs().max().item() + else: + act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item()) + if isinstance(y, tuple): + y = y[0] + if name not in act_dict or "output" not in act_dict[name]: + act_dict[name]["output"] = y.detach().abs().max().item() + else: + act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item()) + + for name, m in llama_model.named_modules(): + if isinstance(m, LlamaAttention): + setattr(m, "q_apply_rotary", LlamaApplyRotary()) + setattr(m, "k_apply_rotary", LlamaApplyRotary()) + m.forward = types.MethodType(llama_decoder_layer_forward, m) + + hooks = [] + for name, m in llama_model.named_modules(): + if isinstance(m, LlamaApplyRotary): + hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) + if isinstance(m, torch.nn.Linear): + hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) + + self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len) + + for hook in hooks: + hook.remove() + return act_dict + + def smooth_fn(self, scales, alpha=0.5): + model = self.model + for name, module in model.named_modules(): + if isinstance(module, LlamaDecoderLayer): + attn_ln = module.input_layernorm + qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj] + qkv_input_scales = scales[name + ".self_attn.q_proj"] + self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha) + + def create_quantized_model(model): + llama_config = model.config + for i, layer in enumerate(model.model.layers): + model.model.layers[i] = LlamaSmoothquantDecoderLayer(llama_config) + + model.model.forward = types.MethodType(llama_model_forward, model.model) + cos, sin = init_to_get_rotary(llama_config) + model.model.register_buffer("_cos_cached", cos) + model.model.register_buffer("_sin_cached", sin) + + def quantized( + self, + tokenizer, + dataset, + num_samples=512, + seq_len=512, + alpha=0.5, + ): + llama_model = self.model + llama_config = llama_model.config + + act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len) + + self.smooth_fn(act_scales, alpha) + + act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len) + decoder_layer_scales = [] + + for idx in range(llama_config.num_hidden_layers): + scale_dict = {} + scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127 + scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127 + scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127 + scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127 + + scale_dict["q_rotary_output_scale"] = ( + act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127 + ) + scale_dict["k_rotary_output_scale"] = ( + act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127 + ) + + scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127 + + scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127 + scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127 + scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 + + decoder_layer_scales.append(scale_dict) + + for i, layer in enumerate(llama_model.model.layers): + orig_layer = layer + llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i]) + + llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model) + + cos, sin = init_to_get_rotary(llama_config) + llama_model.model.register_buffer("_cos_cached", cos.to(self.model.device)) + llama_model.model.register_buffer("_sin_cached", sin.to(self.model.device)) diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py index ac185f1b6529..de150311cc08 100644 --- a/colossalai/inference/tensor_parallel/batch_infer_state.py +++ b/colossalai/inference/tensor_parallel/batch_infer_state.py @@ -5,7 +5,7 @@ from .kvcache_manager import MemoryManager - +# adapted from: lightllm/server/router/model_infer/infer_batch.py @dataclass class BatchInferState: r""" @@ -41,6 +41,7 @@ def total_token_num(self): def set_cache_manager(self, manager: MemoryManager): self.cache_manager = manager + # adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1 @staticmethod def init_block_loc( b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index e4c4a2d70cd7..216b134f5fab 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -21,6 +21,8 @@ "BloomForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration", + "LlamaGPTQForCausalLM", + "BloomGPTQForCausalLM", ] @@ -213,11 +215,14 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" model_name = model.__class__.__name__ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." + + model = model.model if self.shard_config.inference_gptq else model + policy = get_autopolicy(model, inference_only=True) self.model, _ = shardformer.optimize(model, policy) if self.shard_config.inference_gptq: - self._post_init_gptq_buffer(model) + self._post_init_gptq_buffer(self.model) self.model = self.model.cuda() diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py index e74a3a491a7b..c9e7aaae0844 100644 --- a/colossalai/inference/tensor_parallel/kvcache_manager.py +++ b/colossalai/inference/tensor_parallel/kvcache_manager.py @@ -1,7 +1,9 @@ -# Adapted from lightllm/common/mem_manager.py -# of the ModelTC/lightllm GitHub repository -# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py - +""" +Refered/Modified from lightllm/common/mem_manager.py +of the ModelTC/lightllm GitHub repository +https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py +we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design. +""" import torch from transformers.utils import logging diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index 4b1bc601f436..b8274d3c660f 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -6,8 +6,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd -from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( ChatGLMForConditionalGeneration, @@ -20,6 +18,14 @@ from ._utils import copy_kv_to_mem_cache +try: + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd + from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd + HAS_LIGHTLLM_KERNEL = True +except: + print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") + HAS_LIGHTLLM_KERNEL = False + # This func is same as Llama model init_to_get_rotary, we should move them into _utils.py def _init_to_get_rotary(self, base=10000): @@ -433,17 +439,17 @@ def chatglm_flash_attn_kvcache_forward( cos, sin = infer_state.position_cos, infer_state.position_sin - Llama2Forwards.rotary_emb_fwd( + chatglm2_rotary_emb_fwd( query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin ) if self.multi_query_attention: - Llama2Forwards.rotary_emb_fwd( + chatglm2_rotary_emb_fwd( key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head), cos, sin, ) else: - Llama2Forwards.rotary_emb_fwd( + chatglm2_rotary_emb_fwd( key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin, @@ -474,7 +480,7 @@ def chatglm_flash_attn_kvcache_forward( attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) # NOTE: no bug in context attn fwd (del it ) - llama2_context_attn_fwd( + lightllm_llama2_context_attention_fwd( query_layer, key_layer, value_layer, diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index ac4ae72f3d18..a3937f6f10ba 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -5,12 +5,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton import ( - llama2_context_attn_fwd, - llama_context_attn_fwd, - rotary_embedding_fwd, - token_attention_fwd, -) +from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards from ._utils import copy_kv_to_mem_cache @@ -29,6 +24,17 @@ ) HAS_VLLM_KERNERL = False +try: + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_llama2_context_attention_fwd, + ) + from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd + + HAS_LIGHTLLM_KERNEL = True +except: + print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") + HAS_LIGHTLLM_KERNEL = False + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -280,8 +286,8 @@ def llama_flash_attn_kvcache_forward( cos, sin = infer_state.position_cos, infer_state.position_sin # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) - rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) - rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) query_states = query_states.reshape(-1, self.num_heads, self.head_dim) key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) @@ -312,7 +318,7 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager.past_key_values_length, ) else: - llama2_context_attn_fwd( + lightllm_llama2_context_attention_fwd( query_states, key_states, value_states, @@ -371,6 +377,7 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager.past_key_values_length, infer_state.other_kv_index, ) + attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 507c1203dd6b..7e163efe0173 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -12,8 +12,7 @@ from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward try: - from colossalai.kernel.triton import rmsnorm_forward - + from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward HAS_TRITON_RMSNORM = True except: print("you should install triton from https://github.com/openai/triton") @@ -22,9 +21,8 @@ def get_triton_rmsnorm_forward(): if HAS_TRITON_RMSNORM: - def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) + return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) return _triton_rmsnorm_forward else: diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp index 0ab250218da3..be9300c545c2 100644 --- a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp @@ -35,23 +35,19 @@ SOFTWARE void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, size_t _param_size, bool param_half_precision, bool grad_half_precision, - float loss_scale) { - size_t rounded_size = 0; + bool momentum_half_precision, + bool variance_half_precision, float loss_scale) { + size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); float betta1_minus1 = 1 - _betta1; float betta2_minus1 = 1 - _betta2; float step_size = -1 * _alpha / _bias_correction1; float w_decay = -1 * _alpha * _weight_decay; - __half *params_cast_h = NULL; - __half *grads_cast_h = NULL; - - if (param_half_precision) { - params_cast_h = reinterpret_cast<__half *>(_params); - } - if (grad_half_precision) { - grads_cast_h = reinterpret_cast<__half *>(grads); - } + __half *params_cast_h = reinterpret_cast<__half *>(_params); + __half *grads_cast_h = reinterpret_cast<__half *>(grads); + __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg); + __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq); #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) AVX_Data betta1_4; @@ -77,7 +73,6 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, if (_weight_decay > 0) weight_decay_4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); - rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); for (size_t t = 0; t < rounded_size; t += TILE) { size_t copy_size = TILE; @@ -87,28 +82,23 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, #pragma omp parallel for for (size_t i = t; i < offset; i += SIMD_WIDTH) { AVX_Data grad_4; - if (grad_half_precision) { - grad_4.data = SIMD_LOAD_HALF(grads_cast_h + i); - } else { - grad_4.data = SIMD_LOAD(grads + i); - } + this->simd_load(grad_half_precision, grads + i, grads_cast_h + i, grad_4); if (loss_scale > 0) { AVX_Data loss_scale_vec; loss_scale_vec.data = SIMD_SET(loss_scale); grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data); } AVX_Data momentum_4; - momentum_4.data = SIMD_LOAD(_exp_avg + i); + this->simd_load(momentum_half_precision, _exp_avg + i, + momentum_cast_h + i, momentum_4); AVX_Data variance_4; - variance_4.data = SIMD_LOAD(_exp_avg_sq + i); + this->simd_load(variance_half_precision, _exp_avg_sq + i, + variance_cast_h + i, variance_4); AVX_Data param_4; - if (param_half_precision) { - param_4.data = SIMD_LOAD_HALF(params_cast_h + i); - } else { - param_4.data = SIMD_LOAD(_params + i); - } + this->simd_load(param_half_precision, _params + i, params_cast_h + i, + param_4); if (_weight_decay > 0 && !_adamw_mode) { grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data); @@ -130,13 +120,12 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, } param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data); - if (param_half_precision) { - SIMD_STORE_HALF((float *)(params_cast_h + i), param_4.data); - } else { - SIMD_STORE(_params + i, param_4.data); - } - SIMD_STORE(_exp_avg + i, momentum_4.data); - SIMD_STORE(_exp_avg_sq + i, variance_4.data); + this->simd_store(param_half_precision, _params + i, params_cast_h + i, + param_4); + this->simd_store(momentum_half_precision, _exp_avg + i, + momentum_cast_h + i, momentum_4); + this->simd_store(variance_half_precision, _exp_avg_sq + i, + variance_cast_h + i, variance_4); } } #endif @@ -154,8 +143,10 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, } float param = param_half_precision ? (float)params_cast_h[k] : _params[k]; - float momentum = _exp_avg[k]; - float variance = _exp_avg_sq[k]; + float momentum = + momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k]; + float variance = variance_half_precision ? (float)variance_cast_h[k] + : _exp_avg_sq[k]; if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } @@ -178,8 +169,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, params_cast_h[k] = (__half)param; else _params[k] = param; - _exp_avg[k] = momentum; - _exp_avg_sq[k] = variance; + if (momentum_half_precision) + momentum_cast_h[k] = (__half)(momentum); + else + _exp_avg[k] = momentum; + if (variance_half_precision) + variance_cast_h[k] = (__half)(variance); + else + _exp_avg_sq[k] = variance; } } } @@ -188,17 +185,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, size_t _param_size, bool param_half_precision, bool grad_half_precision, - float loss_scale) { - size_t rounded_size = 0; + bool momentum_half_precision, + bool variance_half_precision, float loss_scale) { + size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4); - __half *params_cast_h = NULL; - __half *grads_cast_h = NULL; - if (param_half_precision) { - params_cast_h = reinterpret_cast<__half *>(_params); - } - if (grad_half_precision) { - grads_cast_h = reinterpret_cast<__half *>(grads); - } + __half *params_cast_h = reinterpret_cast<__half *>(_params); + __half *grads_cast_h = reinterpret_cast<__half *>(grads); + __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg); + __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq); #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) AVX_Data betta1_4; @@ -228,7 +222,6 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, if (_weight_decay > 0) weight_decay_4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); - rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4); for (size_t t = 0; t < rounded_size; t += TILE) { size_t copy_size = TILE; @@ -243,26 +236,21 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, AVX_Data param_4[4]; #pragma unroll 4 for (int j = 0; j < 4; j++) { - if (grad_half_precision) { - grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j); - } else { - grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j); - } + this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j, + grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]); if (loss_scale > 0) { AVX_Data loss_scale_vec; loss_scale_vec.data = SIMD_SET(loss_scale); grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); } - - momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); - variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); - - if (param_half_precision) { - param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j); - } else { - param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j); - } + this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_load(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); + this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); if (_weight_decay > 0 && !_adamw_mode) { grad_4[j].data = @@ -285,14 +273,13 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, } param_4[j].data = SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data); - if (param_half_precision) { - SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j), - param_4[j].data); - } else { - SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data); - } - SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data); - SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data); + this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); + this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_store(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); } } } @@ -302,24 +289,26 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, : _params + rounded_size), (grad_half_precision ? (float *)(grads_cast_h + rounded_size) : grads + rounded_size), - (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), + (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size) + : _exp_avg + rounded_size), + (variance_half_precision ? (float *)(variance_cast_h + rounded_size) + : _exp_avg_sq + rounded_size), (_param_size - rounded_size), param_half_precision, - grad_half_precision, loss_scale); + grad_half_precision, momentum_half_precision, + variance_half_precision, loss_scale); } void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, size_t _param_size, bool param_half_precision, bool grad_half_precision, - float loss_scale) { - size_t rounded_size = 0; - __half *params_cast_h = NULL; - __half *grads_cast_h = NULL; - if (param_half_precision) { - params_cast_h = reinterpret_cast<__half *>(_params); - } - if (grad_half_precision) { - grads_cast_h = reinterpret_cast<__half *>(grads); - } + bool momentum_half_precision, + bool variance_half_precision, float loss_scale) { + size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8); + __half *params_cast_h = reinterpret_cast<__half *>(_params); + __half *grads_cast_h = reinterpret_cast<__half *>(grads); + __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg); + __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq); + #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) AVX_Data betta1_4; betta1_4.data = SIMD_SET(_betta1); @@ -348,7 +337,6 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, if (_weight_decay > 0) weight_decay_4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); - rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8); for (size_t t = 0; t < rounded_size; t += TILE) { size_t copy_size = TILE; @@ -363,26 +351,21 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, AVX_Data param_4[8]; #pragma unroll 8 for (int j = 0; j < 8; j++) { - if (grad_half_precision) { - grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j); - } else { - grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j); - } + this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j, + grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]); if (loss_scale > 0) { AVX_Data loss_scale_vec; loss_scale_vec.data = SIMD_SET(loss_scale); grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); } - - momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); - variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); - - if (param_half_precision) { - param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j); - } else { - param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j); - } + this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_load(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); + this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); if (_weight_decay > 0 && !_adamw_mode) { grad_4[j].data = @@ -405,15 +388,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, param_4[j].data = SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data); - if (param_half_precision) { - SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j), - param_4[j].data); - } else { - SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data); - } - - SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data); - SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data); + this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); + this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_store(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); } } } @@ -423,9 +404,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, : _params + rounded_size), (grad_half_precision ? (float *)(grads_cast_h + rounded_size) : grads + rounded_size), - (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), + (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size) + : _exp_avg + rounded_size), + (variance_half_precision ? (float *)(variance_cast_h + rounded_size) + : _exp_avg_sq + rounded_size), (_param_size - rounded_size), param_half_precision, - grad_half_precision, loss_scale); + grad_half_precision, momentum_half_precision, + variance_half_precision, loss_scale); } void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2, @@ -447,7 +432,9 @@ void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2, this->update_state(lr, epsilon, weight_decay, bias_correction); this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.numel(), (params.options().dtype() == at::kHalf), - (grads.options().dtype() == at::kHalf), loss_scale); + (grads.options().dtype() == at::kHalf), + (exp_avg.options().dtype() == at::kHalf), + (exp_avg_sq.options().dtype() == at::kHalf), loss_scale); } namespace py = pybind11; diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/colossalai/kernel/cuda_native/csrc/cpu_adam.h index 4247da942775..bf9b85997c78 100644 --- a/colossalai/kernel/cuda_native/csrc/cpu_adam.h +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.h @@ -50,9 +50,9 @@ SOFTWARE #define SIMD_DIV(x, y) _mm512_div_ps(x, y) #define SIMD_LOAD_HALF(x) \ _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) -#define SIMD_STORE_HALF(x, d) \ - _mm256_store_ps( \ - x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) +#define SIMD_STORE_HALF(x, d) \ + _mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \ + d, _MM_FROUND_TO_NEAREST_INT))) #elif defined(__AVX256__) or defined(__AVX2__) #define SIMD_WIDTH 8 @@ -66,9 +66,9 @@ SOFTWARE #define SIMD_SQRT(x) _mm256_sqrt_ps(x) #define SIMD_DIV(x, y) _mm256_div_ps(x, y) #define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) -#define SIMD_STORE_HALF(x, d) \ - _mm_store_ps( \ - x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) +#define SIMD_STORE_HALF(x, d) \ + _mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \ + d, _MM_FROUND_TO_NEAREST_INT))) #endif @@ -83,11 +83,12 @@ union AVX_Data { #endif -#define STEP(SPAN) \ - void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \ - float *_exp_avg_sq, size_t _param_size, \ - bool param_half_precision = false, \ - bool grad_half_precision = false, float loss_scale = -1); +#define STEP(SPAN) \ + void Step_##SPAN( \ + float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \ + size_t _param_size, bool param_half_precision = false, \ + bool grad_half_precision = false, bool momentum_half_precision = false, \ + bool variance_half_precision = false, float loss_scale = -1); class Adam_Optimizer { public: @@ -141,6 +142,24 @@ class Adam_Optimizer { } } + inline void simd_load(bool is_half, float *ptr, __half *h_ptr, + AVX_Data &data) { + if (is_half) { + data.data = SIMD_LOAD_HALF(h_ptr); + } else { + data.data = SIMD_LOAD(ptr); + } + } + + inline void simd_store(bool is_half, float *ptr, __half *h_ptr, + AVX_Data &data) { + if (is_half) { + SIMD_STORE_HALF(h_ptr, data.data); + } else { + SIMD_STORE(ptr, data.data); + } + } + void step(size_t step, float lr, float beta1, float beta2, float epsilon, float weight_decay, bool bias_correction, torch::Tensor ¶ms, torch::Tensor &grads, torch::Tensor &exp_avg, diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp b/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp new file mode 100644 index 000000000000..8444272940b4 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp @@ -0,0 +1,8 @@ +#include + +#include "linear.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32, + "Linear SiLU (INT8)"); +} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu new file mode 100644 index 000000000000..a30d02a4cf42 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu @@ -0,0 +1,162 @@ +// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu + +#include "linear.h" +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // FP32 + float alpha, // FP32 + float beta // FP32 +) { + auto M = input.size(0); + auto N = weight.size(0); + auto K = input.size(1); + + using ElementOutput = float; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + using ElementInputA = int8_t; // <- data type of elements in input matrix A + using ElementInputB = int8_t; // <- data type of elements in input matrix B + + // The code section below describes matrix layout of input and output + // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major + // for Matrix C + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + +#if CUDA_ARCH >= 800 + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits< + ElementOutput>::value, // <- this is the number of elements per + // vectorized memory access. For half + // precision, it's 8 elements. This + // becomes the vector width of math + // instructions in epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue // <- data type for alpha in linear combination + // function + >; + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; +#elif CUDA_ARCH >= 750 + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits< + ElementOutput>::value, // <- this is the number of elements per + // vectorized memory access. For half + // precision, it's 8 elements. This + // becomes the vector width of math + // instructions in epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue // <- data type for alpha in linear combination + // function + >; + + using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, + DefaultGemmCfg::InstructionShape, + EpilogueOp>; +#elif CUDA_ARCH >= 700 + #define USE_TORCH_SILU + using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassSimt, cutlass::arch::Sm70, + ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassSimt, cutlass::arch::Sm70, + DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, + DefaultGemmCfg::InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>; +#else + #error "Unsupported cuda arch" +#endif + + auto input_size = cutlass::MatrixCoord(M, K); + auto weight_size = cutlass::MatrixCoord(K, N); + auto output_size = cutlass::MatrixCoord(M, N); + + auto device = input.device(); + // use the broadcasted bias as the output + auto out = bias.to(device).view({1, -1}).repeat({M, 1}); + + // constexpr int kSparse = Gemm::kSparse; + // How many elements of A are covered per ElementE + // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; + // The size of individual meta data + // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + cutlass::gemm::GemmCoord problem_size(M, N, K); + + cutlass::TensorRef input_ref( + input.data_ptr(), LayoutInputA::packed(input_size)); + cutlass::TensorRef weight_ref( + weight.data_ptr(), LayoutInputB::packed(weight_size)); + cutlass::TensorRef out_ref( + out.data_ptr(), LayoutOutput::packed(output_size)); + + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + input_ref, // <- reference to matrix A on device + weight_ref, // <- reference to matrix B on device + out_ref, // <- reference to matrix C on device + out_ref, // <- reference to matrix D on device + {alpha, beta}, 1}; + Gemm gemm_op; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm_op(); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot run"); + } +#ifdef USE_TORCH_SILU +#undef USE_TORCH_SILU + out = torch::silu(out); +#endif + return out; +} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h new file mode 100644 index 000000000000..b62a27f3f8f3 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h @@ -0,0 +1,12 @@ +#include +#include + +#include +#include + +torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // FP32 + float alpha, // FP32 + float beta // FP32 +); diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index f065b2100fa8..1fe292289f3d 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -9,24 +9,24 @@ # There may exist import error even if we have triton installed. if HAS_TRITON: - from .context_attention import bloom_context_attn_fwd, llama2_context_attn_fwd, llama_context_attn_fwd + from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd from .copy_kv_cache_dest import copy_kv_cache_to_dest from .fused_layernorm import layer_norm from .gptq_triton import gptq_fused_linear_triton - from .rms_norm import rmsnorm_forward - from .rotary_embedding_kernel import rotary_embedding_fwd + from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd + from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd from .softmax import softmax from .token_attention_kernel import token_attention_fwd __all__ = [ "llama_context_attn_fwd", - "llama2_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", - "rmsnorm_forward", "copy_kv_cache_to_dest", - "rotary_embedding_fwd", "token_attention_fwd", "gptq_fused_linear_triton", + "int8_rotary_embedding_fwd", + "smooth_llama_context_attn_fwd", + "smooth_token_attention_fwd", ] diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 01d54566483a..1b4f6e44b0f2 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -238,329 +238,5 @@ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): num_warps=num_warps, num_stages=1, ) - return - - @triton.jit - def _fwd_kernel_latest( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - @triton.jit - def _fwd_kernel_old( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, - TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_tmp_b, - stride_tmp_h, - stride_tmp_s, - kv_group_num, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s - # t_ptrs = TMP + offs_m - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - - return - - @torch.no_grad() - def llama2_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): - if triton.__version__ >= "2.1.0": - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lq**0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel_latest[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - kv_group_num=kv_group_num, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - elif triton.__version__ == "2.0.0": - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - - sm_scale = 1.0 / (Lq**0.5) - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) - num_warps = 4 if Lk <= 64 else 8 - # num_warps = 4 - _fwd_kernel_old[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - tmp, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - kv_group_num=kv_group_num, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return + + return \ No newline at end of file diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py index 02edcc9a903a..0ce6b09e54dc 100644 --- a/colossalai/kernel/triton/copy_kv_cache_dest.py +++ b/colossalai/kernel/triton/copy_kv_cache_dest.py @@ -11,6 +11,7 @@ if HAS_TRITON: + # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py @triton.jit def _fwd_copy_kv_cache_dest( kv_cache_ptr, @@ -42,6 +43,7 @@ def _fwd_copy_kv_cache_dest( tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num) return + # adepted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py @torch.no_grad() def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): seq_len = dest_index_ptr.shape[0] diff --git a/colossalai/kernel/triton/gptq_triton.py b/colossalai/kernel/triton/gptq_triton.py index 8460103e261d..2dc1fe04438a 100644 --- a/colossalai/kernel/triton/gptq_triton.py +++ b/colossalai/kernel/triton/gptq_triton.py @@ -267,6 +267,7 @@ def cai_gptq_matmul_248_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ @autotune( configs=[ triton.Config( diff --git a/colossalai/kernel/triton/int8_rotary_embedding_kernel.py b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py new file mode 100644 index 000000000000..537dd164d1ab --- /dev/null +++ b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py @@ -0,0 +1,117 @@ +# Adapted from ModelTC https://github.com/ModelTC/lightllm +import torch +import triton +import triton.language as tl + + +@triton.jit +def _rotary_kernel( + q, + input_scale, + output_scale, + Cos, + Sin, + q_bs_stride, + q_h_stride, + q_d_stride, + cos_bs_stride, + cos_d_stride, + total_len, + HEAD_NUM: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + current_head_index = tl.program_id(0) + current_seq_index = tl.program_id(1) + + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + + current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) + + off_q0 = ( + current_seq_range[:, None, None] * q_bs_stride + + current_head_range[None, :, None] * q_h_stride + + dim_range0[None, None, :] * q_d_stride + ) + off_q1 = ( + current_seq_range[:, None, None] * q_bs_stride + + current_head_range[None, :, None] * q_h_stride + + dim_range1[None, None, :] * q_d_stride + ) + + off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride + + q0 = tl.load( + q + off_q0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0, + ) + q1 = tl.load( + q + off_q1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0, + ) + + cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) + + q0 = q0.to(tl.float32) * input_scale + q1 = q1.to(tl.float32) * input_scale + + out0 = (q0 * cos - q1 * sin) / output_scale + out1 = (q0 * sin + q1 * cos) / output_scale + + out0 = out0.to(tl.int8) + out1 = out1.to(tl.int8) + + tl.store( + q + off_q0, + out0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + ) + tl.store( + q + off_q1, + out1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + ) + + return + + +@torch.no_grad() +def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale): + total_len = q.shape[0] + head_num = q.shape[1] + head_dim = q.shape[2] + assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" + BLOCK_HEAD = 4 + BLOCK_SEQ = 32 + grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + _rotary_kernel[grid]( + q, + input_scale, + output_scale, + cos, + sin, + q.stride(0), + q.stride(1), + q.stride(2), + cos.stride(0), + cos.stride(1), + total_len, + HEAD_NUM=head_num, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SEQ=BLOCK_SEQ, + HEAD_DIM=head_dim, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/colossalai/kernel/triton/rms_norm.py b/colossalai/kernel/triton/rms_norm.py deleted file mode 100644 index d5d6f9d85df1..000000000000 --- a/colossalai/kernel/triton/rms_norm.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch - -try: - import triton - import triton.language as tl - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - - -if HAS_TRITON: - """ - this kernel function is modified from - https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py - """ - - @triton.jit - def _rms_norm_fwd_fused( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK_SIZE: tl.constexpr, - ): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - Y += row * stride - X += row * stride - # Compute variance - _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - _var += x * x - var = tl.sum(_var, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - # Normalize and apply linear transformation - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) - x_hat = x * rstd - y = x_hat * w - # Write output - tl.store(Y + cols, y.to(tl.float16), mask=mask) - - def rmsnorm_forward(x, weight, eps): - # allocate output - y = torch.empty_like(x) - # reshape input data into 2D tensor - x_arg = x.view(-1, x.shape[-1]) - M, N = x_arg.shape - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - # print("BLOCK_SIZE:", BLOCK_SIZE) - if N > BLOCK_SIZE: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - # heuristics for number of warps - num_warps = min(max(BLOCK_SIZE // 256, 1), 8) - # print(BLOCK_SIZE, num_warps, "block_size, numwarps") - BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2 - num_warps = 8 - # enqueue kernel - _rms_norm_fwd_fused[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) - return y diff --git a/colossalai/kernel/triton/rotary_embedding_kernel.py b/colossalai/kernel/triton/rotary_embedding_kernel.py deleted file mode 100644 index fd74ba817551..000000000000 --- a/colossalai/kernel/triton/rotary_embedding_kernel.py +++ /dev/null @@ -1,212 +0,0 @@ -# Adapted from ModelTC https://github.com/ModelTC/lightllm -import torch -import triton -import triton.language as tl - - -@triton.jit -def _rotary_kernel( - q, - Cos, - Sin, - q_bs_stride, - q_h_stride, - q_d_stride, - cos_bs_stride, - cos_d_stride, - total_len, - HEAD_NUM: tl.constexpr, - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - HEAD_DIM: tl.constexpr, -): - current_head_index = tl.program_id(0) - current_seq_index = tl.program_id(1) - - current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - dim_range0 = tl.arange(0, HEAD_DIM // 2) - dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) - - off_q0 = ( - current_seq_range[:, None, None] * q_bs_stride - + current_head_range[None, :, None] * q_h_stride - + dim_range0[None, None, :] * q_d_stride - ) - off_q1 = ( - current_seq_range[:, None, None] * q_bs_stride - + current_head_range[None, :, None] * q_h_stride - + dim_range1[None, None, :] * q_d_stride - ) - - off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride - - q0 = tl.load( - q + off_q0, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0, - ) - q1 = tl.load( - q + off_q1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) - - out0 = q0 * cos - q1 * sin - out1 = q0 * sin + q1 * cos - - tl.store( - q + off_q0, - out0, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - ) - tl.store( - q + off_q1, - out1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - ) - - return - - -@torch.no_grad() -def rotary_embedding_fwd(q, cos, sin): - total_len = q.shape[0] - head_num = q.shape[1] - head_dim = q.shape[2] - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - BLOCK_HEAD = 4 - BLOCK_SEQ = 32 - grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - _rotary_kernel[grid]( - q, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - cos.stride(0), - cos.stride(1), - total_len, - HEAD_NUM=head_num, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - HEAD_DIM=head_dim, - num_warps=num_warps, - num_stages=1, - ) - return - - -class Llama2Forwards: - @staticmethod - @triton.jit - def _rotary_kernel( - Q, - Cos, - Sin, - stride_qbs, - stride_qh, - stride_qd, - stride_cosbs, - stride_cosd, - stride_sinbs, - stride_sind, - max_total_len, - H, # N_CTX - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - ): - cur_head_index = tl.program_id(0) - cur_seq_index = tl.program_id(1) - - cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 - dim_range1 = dim_range0 + 1 - off_q0 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range0[None, None, :] * stride_qd - ) - off_q1 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range1[None, None, :] * stride_qd - ) - - cos_range = tl.arange(0, BLOCK_DMODEL // 2) - off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd - - q0 = tl.load( - Q + off_q0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), - other=0.0, - ) - q1 = tl.load( - Q + off_q1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out0 = q0 * cos - q1 * sin - out1 = q0 * sin + q1 * cos - - tl.store( - Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H) - ) - tl.store( - Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H) - ) - - return - - @staticmethod - @torch.no_grad() - def rotary_emb_fwd(q, cos, sin): - total_len = q.shape[0] - head_num = q.shape[1] - head_dim = q.shape[2] // 2 - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - BLOCK_HEAD = 4 - BLOCK_SEQ = 32 - grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - Llama2Forwards._rotary_kernel[grid]( - q, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - cos.stride(0), - cos.stride(1), - sin.stride(0), - sin.stride(1), - total_len, - head_num, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=head_dim, - num_warps=num_warps, - num_stages=1, - ) - return diff --git a/colossalai/kernel/triton/self_attention_nofusion.py b/colossalai/kernel/triton/self_attention_nofusion.py index 4b56c8afd67f..50d6786bd940 100644 --- a/colossalai/kernel/triton/self_attention_nofusion.py +++ b/colossalai/kernel/triton/self_attention_nofusion.py @@ -12,6 +12,7 @@ from .qkv_matmul_kernel import qkv_gemm_4d_kernel from .softmax import softmax_kernel + # adpeted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py#L312 def self_attention_forward_without_fusion( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float ): @@ -141,6 +142,7 @@ def self_attention_forward_without_fusion( ) return output.view(batches, -1, d_model) + # modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/attention.py#L212 def self_attention_compute_using_triton( qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False ): diff --git a/colossalai/kernel/triton/smooth_attention.py b/colossalai/kernel/triton/smooth_attention.py new file mode 100644 index 000000000000..071de58e20c0 --- /dev/null +++ b/colossalai/kernel/triton/smooth_attention.py @@ -0,0 +1,652 @@ +import math + +import torch + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + """ + this functions are modified from https://github.com/ModelTC/lightllm + """ + + # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py + @triton.jit + def _context_flash_attention_kernel( + Q, + K, + V, + q_input_scale, + k_input_scale, + v_input_scale, + pv_output_scale, + sm_scale, + B_Start_Loc, + B_Seqlen, + TMP, + alibi_ptr, + Out, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, + # suggtest set-up 64, 128, 256, 512 + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + batch_id = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + # get batch info + cur_batch_seq_len = tl.load(B_Seqlen + batch_id) + cur_batch_start_index = tl.load(B_Start_Loc + batch_id) + block_start_loc = BLOCK_M * start_m + + load_p_ptrs = ( + Q + + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + q = q.to(tl.float16) * q_input_scale.to(tl.float16) + + k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + if alibi_ptr is not None: + alibi_m = tl.load(alibi_ptr + cur_head) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k = tl.load( + k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) + k = k.to(tl.float16) * k_input_scale.to(tl.float16) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + if alibi_ptr is not None: + alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) + qk -= alibi_loc * alibi_m + + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + + v = v.to(tl.float16) * v_input_scale.to(tl.float16) + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8) + off_o = ( + (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + @torch.no_grad() + def smooth_llama_context_attn_fwd( + q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len + ): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk, "context process only supports equal query, key, value length" + assert Lk == Lv, "context process only supports equal query, key, value length" + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / math.sqrt(Lk) + batch, head = b_seq_len.shape[0], q.shape[1] + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + + _context_flash_attention_kernel[grid]( + q, + k, + v, + q_input_scale, + k_input_scale, + v_input_scale, + pv_output_scale, + sm_scale, + b_start_loc, + b_seq_len, + tmp, + None, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py + @triton.jit + def _token_attn_1_kernel( + Q, + K, + q_input_scale, + k_input_scale, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + q = tl.load(Q + off_q + start_mark) + q = q.to(tl.float16) * q_input_scale.to(tl.float16) + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + k = k.to(tl.float16) * k_input_scale.to(tl.float16) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py + @triton.jit + def _token_attn_1_alibi_kernel( + Q, + K, + q_input_scale, + k_input_scale, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + alibi_m = tl.load(alibi + current_head) + q = tl.load(Q + off_q + start_mark) + q = q.to(tl.float16) * q_input_scale.to(tl.float16) + + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + k = k.to(tl.float16) * k_input_scale.to(tl.float16) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + @torch.no_grad() + def token_attn_fwd_1( + q, + k, + attn_out, + q_input_scale, + k_input_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + alibi=None, + ): + BLOCK = 32 + # shape constraints + q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] + assert q_head_dim == k_head_dim + assert k_head_dim in {16, 32, 64, 128} + sm_scale = 1.0 / (k_head_dim**0.5) + + batch, head_num = kv_cache_loc.shape[0], q.shape[1] + + grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK)) + + num_warps = 4 if k_head_dim <= 64 else 8 + num_warps = 2 + + if alibi is not None: + _token_attn_1_alibi_kernel[grid]( + q, + k, + q_input_scale, + k_input_scale, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + else: + _token_attn_1_kernel[grid]( + q, + k, + q_input_scale, + k_input_scale, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py + @triton.jit + def _token_attn_softmax_fwd( + softmax_logics, + kv_cache_start_loc, + kv_cache_seqlen, + softmax_prob_out, + logics_head_dim_stride, + logics_batch_stride, + prob_head_dim_stride, + prob_batch_stride, + BLOCK_SIZE: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + col_offsets = tl.arange(0, BLOCK_SIZE) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + row = tl.load( + softmax_logics + + current_head * logics_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, + mask=col_offsets < current_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) + + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + + tl.store( + softmax_prob_out + + current_head * prob_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, + softmax_output, + mask=col_offsets < current_batch_seq_len, + ) + return + + @torch.no_grad() + def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len): + BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len) + batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0] + + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + + _token_attn_softmax_fwd[(batch, head_num)]( + softmax_logics, + kv_cache_start_loc, + kv_cache_seqlen, + softmax_prob_out, + softmax_logics.stride(0), + softmax_logics.stride(1), + softmax_prob_out.stride(0), + softmax_prob_out.stride(1), + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return + + # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py + @triton.jit + def _token_attn_2_kernel( + Prob, + V, + attn_out, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + prob_head_dim_stride, + prob_batch_stride, + v_batch_stride, + v_head_stride, + v_head_dim_stride, + attn_out_batch_stride, + attn_out_head_stride, + attn_out_head_dim_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride + p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride + v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride + + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + for start_n in range(0, current_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + p_value = tl.load( + Prob + p_offs + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_loc = tl.load( + kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_value = tl.load( + V + v_offs + v_loc[:, None] * v_batch_stride, + mask=(start_n + offs_n[:, None]) < current_batch_seq_len, + other=0.0, + ) + v_value = v_value.to(tl.float16) * v_input_scale.to(tl.float16) + acc += tl.sum(p_value[:, None] * v_value, 0) + + acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8) + off_o = ( + current_batch * attn_out_batch_stride + + current_head * attn_out_head_stride + + offs_d * attn_out_head_dim_stride + ) + out_ptrs = attn_out + off_o + tl.store(out_ptrs, acc) + return + + @torch.no_grad() + def token_attn_fwd_2( + prob, + v, + attn_out, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + ): + if triton.__version__ >= "2.1.0": + BLOCK = 128 + else: + BLOCK = 64 + batch, head = kv_cache_loc.shape[0], v.shape[1] + grid = (batch, head) + num_warps = 4 + dim = v.shape[-1] + + _token_attn_2_kernel[grid]( + prob, + v, + attn_out, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + prob.stride(0), + prob.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + attn_out.stride(0), + attn_out.stride(1), + attn_out.stride(2), + HEAD_DIM=dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @torch.no_grad() + def smooth_token_attention_fwd( + q, + k, + v, + attn_out, + q_input_scale, + k_input_scale, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=None, + ): + head_num = k.shape[1] + batch_size = kv_cache_seq_len.shape[0] + calcu_shape1 = (batch_size, head_num, k.shape[2]) + total_token_num = k.shape[0] + + att_m_tensor = torch.empty((head_num, total_token_num), dtype=torch.float32, device="cuda") + + token_attn_fwd_1( + q.view(calcu_shape1), + k, + att_m_tensor, + q_input_scale, + k_input_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=alibi, + ) + + prob = torch.empty_like(att_m_tensor) + + token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) + att_m_tensor = None + token_attn_fwd_2( + prob, + v, + attn_out.view(calcu_shape1), + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + ) + + prob = None + + return diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index c27394f0f9cf..8dc919bad125 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -12,401 +12,78 @@ HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -if HAS_TRITON: - - @triton.jit - def _token_attn_1_kernel( - Q, - K, - sm_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - q_batch_stride, - q_head_stride, - q_head_dim_stride, - k_batch_stride, - k_head_stride, - k_head_dim_stride, - attn_head_stride, - attn_batch_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - start_n = tl.program_id(2) - - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = max_kv_cache_len - - off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) - offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load( - kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride - k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride - tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) - return - - @triton.jit - def _token_attn_1_alibi_kernel( - Q, - K, - sm_scale, - alibi, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - q_batch_stride, - q_head_stride, - q_head_dim_stride, - k_batch_stride, - k_head_stride, - k_head_dim_stride, - attn_head_stride, - attn_batch_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - start_n = tl.program_id(2) - - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = max_kv_cache_len - - off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) +try: + from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import ( + token_att_fwd as lightllm_llama2_token_att_fwd, + ) + from lightllm.models.llama2.triton_kernel.token_attention_nopad_reduceV import ( + token_att_fwd2 as lightllm_llama2_token_att_fwd2, + ) + from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import ( + token_softmax_fwd as lightllm_llama2_token_softmax_fwd, + ) + + from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fw2 + from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd + from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd + from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd + + HAS_TRITON_TOKEN_ATTENTION = True +except ImportError: + print("unable to import lightllm kernels") + HAS_TRITON_TOKEN_ATTENTION = False - for start_mark in range(0, block_mask, 1): - alibi_m = tl.load(alibi + current_head) - q = tl.load(Q + off_q + start_mark) - offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load( - kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride - k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) - off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride - tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) - return +if HAS_TRITON: @torch.no_grad() - def token_attn_fwd_1( - q, k, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, alibi=None + def token_attention_fwd( + q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None ): - BLOCK = 32 - # shape constraints - q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] - assert q_head_dim == k_head_dim - assert k_head_dim in {16, 32, 64, 128} - sm_scale = 1.0 / (k_head_dim**0.5) - - batch, head_num = kv_cache_loc.shape[0], q.shape[1] - - grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK)) + head_num = k.shape[1] + batch_size = kv_cache_seq_len.shape[0] + calcu_shape1 = (batch_size, head_num, k.shape[2]) + total_token_num = k.shape[0] - num_warps = 4 if k_head_dim <= 64 else 8 - num_warps = 2 + att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - if alibi is not None: - _token_attn_1_alibi_kernel[grid]( - q, + if alibi is None: + lightllm_llama_token_att_fwd( + q.view(calcu_shape1), k, - sm_scale, - alibi, + att_m_tensor, kv_cache_loc, kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - attn_out.stride(0), - attn_out.stride(1), - HEAD_DIM=k_head_dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, + kv_cache_seq_len, + max_len_in_batch, ) else: - _token_attn_1_kernel[grid]( - q, + lightllm_bloom_token_att_fwd( + q.view(calcu_shape1), k, - sm_scale, + att_m_tensor, + alibi, kv_cache_loc, kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - attn_out.stride(0), - attn_out.stride(1), - HEAD_DIM=k_head_dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - @triton.jit - def _token_attn_softmax_fwd( - softmax_logics, - kv_cache_start_loc, - kv_cache_seqlen, - softmax_prob_out, - logics_head_dim_stride, - logics_batch_stride, - prob_head_dim_stride, - prob_batch_stride, - BLOCK_SIZE: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - - col_offsets = tl.arange(0, BLOCK_SIZE) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - row = tl.load( - softmax_logics - + current_head * logics_head_dim_stride - + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, - mask=col_offsets < current_batch_seq_len, - other=-float("inf"), - ).to(tl.float32) - - row_minus_max = row - tl.max(row, axis=0) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - - tl.store( - softmax_prob_out - + current_head * prob_head_dim_stride - + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, - softmax_output, - mask=col_offsets < current_batch_seq_len, - ) - return - - @torch.no_grad() - def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len): - BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len) - batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0] - - num_warps = 4 - if BLOCK_SIZE >= 2048: - num_warps = 8 - if BLOCK_SIZE >= 4096: - num_warps = 16 - - _token_attn_softmax_fwd[(batch, head_num)]( - softmax_logics, - kv_cache_start_loc, - kv_cache_seqlen, - softmax_prob_out, - softmax_logics.stride(0), - softmax_logics.stride(1), - softmax_prob_out.stride(0), - softmax_prob_out.stride(1), - num_warps=num_warps, - BLOCK_SIZE=BLOCK_SIZE, - ) - return - - @triton.jit - def _token_attn_2_kernel( - Prob, - V, - attn_out, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - prob_head_dim_stride, - prob_batch_stride, - v_batch_stride, - v_head_stride, - v_head_dim_stride, - attn_out_batch_stride, - attn_out_head_stride, - attn_out_head_dim_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride - p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride - v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride - - acc = tl.zeros([HEAD_DIM], dtype=tl.float32) - for start_n in range(0, current_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load( - Prob + p_offs + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0, - ) - v_loc = tl.load( - kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0, - ) - v_value = tl.load( - V + v_offs + v_loc[:, None] * v_batch_stride, - mask=(start_n + offs_n[:, None]) < current_batch_seq_len, - other=0.0, + kv_cache_seq_len, + max_len_in_batch, ) - acc += tl.sum(p_value[:, None] * v_value, 0) - - acc = acc.to(tl.float16) - off_o = ( - current_batch * attn_out_batch_stride - + current_head * attn_out_head_stride - + offs_d * attn_out_head_dim_stride - ) - out_ptrs = attn_out + off_o - tl.store(out_ptrs, acc) - return - - @torch.no_grad() - def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len): - if triton.__version__ >= "2.1.0": - BLOCK = 128 - else: - BLOCK = 64 - batch, head = kv_cache_loc.shape[0], v.shape[1] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - - _token_attn_2_kernel[grid]( - prob, - v, - attn_out, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - attn_out.stride(0), - attn_out.stride(1), - attn_out.stride(2), - HEAD_DIM=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - @torch.no_grad() - def token_attention_fwd( - q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None - ): - head_num = k.shape[1] - batch_size = kv_cache_seq_len.shape[0] - calcu_shape1 = (batch_size, head_num, k.shape[2]) - total_token_num = k.shape[0] - - att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - - token_attn_fwd_1( - q.view(calcu_shape1), - k, - att_m_tensor, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - alibi=alibi, - ) prob = torch.empty_like(att_m_tensor) - token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) + lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) att_m_tensor = None - token_attn_fwd_2( + lightllm_llama_token_att_fw2( prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch ) - prob = None - return class Llama2TokenAttentionForwards: @staticmethod @triton.jit + + # this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L8 def _fwd_kernel( Logics, V, @@ -478,6 +155,7 @@ def _fwd_kernel( tl.store(out_ptrs, acc) return + # this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L36 @staticmethod @torch.no_grad() def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index): @@ -514,277 +192,6 @@ def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_i ) return - @staticmethod - @triton.jit - def _fwd_kernel_token_softmax( - Logics, - B_Start_Loc, - B_Seqlen, - Prob_Out, - stride_logic_h, - stride_logic_bs, - stride_prob_h, - stride_prob_bs, - BLOCK_SIZE: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - col_offsets = tl.arange(0, BLOCK_SIZE) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - row = tl.load( - Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, - mask=col_offsets < cur_batch_seq_len, - other=-float("inf"), - ).to(tl.float32) - - row_minus_max = row - tl.max(row, axis=0) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - - tl.store( - Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs, - softmax_output, - mask=col_offsets < cur_batch_seq_len, - ) - return - - @staticmethod - @torch.no_grad() - def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): - BLOCK_SIZE = triton.next_power_of_2(max_input_len) - batch, head_num = B_Start_Loc.shape[0], Logics.shape[0] - - num_warps = 4 - if BLOCK_SIZE >= 2048: - num_warps = 8 - if BLOCK_SIZE >= 4096: - num_warps = 16 - - Llama2TokenAttentionForwards._fwd_kernel_token_softmax[(batch, head_num)]( - Logics, - B_Start_Loc, - B_Seqlen, - Prob_Out, - Logics.stride(0), - Logics.stride(1), - Prob_Out.stride(0), - Prob_Out.stride(1), - num_warps=num_warps, - BLOCK_SIZE=BLOCK_SIZE, - ) - return - - @staticmethod - @triton.jit - def _fwd_kernel_token_att1( - Q, - K, - sm_scale, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, - Att_Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - att_stride_h, - att_stride_bs, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_n = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - cur_batch_start_index = max_input_len - cur_batch_seq_len - cur_batch_end_index = max_input_len - - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) - offs_n_new = cur_batch_start_index + offs_n - k_loc = tl.load( - B_Loc + stride_b_loc_b * cur_batch + stride_b_loc_s * offs_n_new, - mask=offs_n_new < cur_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd - k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs - tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) - return - - @staticmethod - @torch.no_grad() - def token_att_fwd(q, k, att_out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): - BLOCK = 32 - # shape constraints - Lq, Lk = q.shape[-1], k.shape[-1] - assert Lq == Lk - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lk**0.5) - - batch, head_num = B_Loc.shape[0], q.shape[1] - - grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK)) - kv_group_num = q.shape[1] // k.shape[1] - - num_warps = 4 if Lk <= 64 else 8 - num_warps = 2 - - Llama2TokenAttentionForwards._fwd_kernel_token_att1[grid]( - q, - k, - sm_scale, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, - att_out, - B_Loc.stride(0), - B_Loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - att_out.stride(0), - att_out.stride(1), - kv_group_num=kv_group_num, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - @staticmethod - @triton.jit - def _fwd_kernel_token_att2( - Prob, - V, - Out, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, # B_Start_Loc cumsum of input lens if continuous - stride_b_loc_b, - stride_b_loc_s, - stride_ph, - stride_pbs, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = max_input_len - cur_batch_seq_len - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - v_loc_off = cur_batch * stride_b_loc_b + (cur_batch_start_index + offs_n) * stride_b_loc_s - p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs - v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load( - Prob + p_offs + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0 - ) - v_loc = tl.load( - B_Loc + v_loc_off + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0 - ) - v_value = tl.load( - V + v_offs + v_loc[:, None] * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - acc += tl.sum(p_value[:, None] * v_value, 0) - - acc = acc.to(tl.float16) - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - @staticmethod - @torch.no_grad() - def token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): - if triton.__version__ >= "2.1.0": - BLOCK = 128 - else: - BLOCK = 64 - batch, head = B_Loc.shape[0], prob.shape[0] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - - kv_group_num = prob.shape[0] // v.shape[1] - - Llama2TokenAttentionForwards._fwd_kernel_token_att2[grid]( - prob, - v, - out, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, - B_Loc.stride(0), - B_Loc.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - kv_group_num=kv_group_num, - BLOCK_DMODEL=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - # this is the interface of llama2 attn forward @staticmethod @torch.no_grad() @@ -796,7 +203,7 @@ def token_attn( calcu_shape1 = (batch_size, head_num, head_dim) att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - Llama2TokenAttentionForwards.token_att_fwd( + lightllm_llama2_token_att_fwd( q, k, att_m_tensor, @@ -808,12 +215,12 @@ def token_attn( if triton.__version__ == "2.0.0": prob = torch.empty_like(att_m_tensor) - Llama2TokenAttentionForwards.token_softmax_fwd( + lightllm_llama2_token_softmax_fwd( att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch ) att_m_tensor = None - Llama2TokenAttentionForwards.token_att_fwd2( + lightllm_llama2_token_att_fwd2( prob, v, attn_out.view(calcu_shape1), diff --git a/colossalai/legacy/context/parallel_context.py b/colossalai/legacy/context/parallel_context.py index 48bf8ab279e8..b95405a33092 100644 --- a/colossalai/legacy/context/parallel_context.py +++ b/colossalai/legacy/context/parallel_context.py @@ -54,7 +54,7 @@ def __init__(self): # logging self._verbose = False - self._logger = get_dist_logger() + self._logger = None @property def config(self): @@ -68,6 +68,12 @@ def verbose(self): def verbose(self, verbose_: bool): self._verbose = verbose_ + @property + def logger(self): + if self._logger is None: + self._logger = get_dist_logger() + return self._logger + def load_config(self, config: Union[dict, str]): """Loads the configuration from either a dict or a file. @@ -527,7 +533,7 @@ def set_device(self, device_ordinal: int = None): torch.cuda.set_device(device_ordinal) if self._verbose: - self._logger.info(f"process rank {global_rank} is bound to device {device_ordinal}") + self.logger.info(f"process rank {global_rank} is bound to device {device_ordinal}") def set_seed(self, seed: int): """Sets seeds for all random libraries. @@ -563,19 +569,19 @@ def set_seed(self, seed: int): seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()]) if self._verbose: - self._logger.info( + self.logger.info( f"initialized seed on rank {global_rank}, " f"numpy: {seed}, python random: {seed}, {seed_str}," f"the default parallel seed is {ParallelMode.DATA}." ) else: if self._verbose: - self._logger.info( + self.logger.info( f"initialized seed on rank {global_rank}, " f"numpy: {seed}, python random: {seed}, pytorch: {seed}", ranks=[0], ) - self._logger.info( + self.logger.info( "WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states", ranks=[0], ) diff --git a/colossalai/legacy/tensor/process_group.py b/colossalai/legacy/tensor/process_group.py index ec6043163336..230849f17576 100644 --- a/colossalai/legacy/tensor/process_group.py +++ b/colossalai/legacy/tensor/process_group.py @@ -31,7 +31,7 @@ def get(self, rank_list: List[int], backend: str = "nccl"): return self.dict[processgroup_key] -PYTORCHPGDICT_ = PyTorchProcessGroupDict() +PYTORCHPGDICT_ = None class ProcessGroup: @@ -59,6 +59,9 @@ def __init__( if not torch.distributed.is_initialized(): self.is_init = False return + global PYTORCHPGDICT_ + if PYTORCHPGDICT_ is None: + PYTORCHPGDICT_ = PyTorchProcessGroupDict() assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized" diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 1bdb81e2d6ec..c3c0180e8516 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -9,7 +9,8 @@ class CPUAdam(NVMeOptimizer): - """Implements Adam algorithm. + """ + Implements Adam algorithm. Supports parameters updating on both GPU and CPU, depending on the device of parameters. But the parameters and gradients should on the same device: @@ -146,8 +147,7 @@ def step(self, closure=None, div_scale: float = -1): assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu" assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" self._pre_update(p, "exp_avg", "exp_avg_sq") - # FIXME(ver217): CPU adam kernel only supports fp32 states now - if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float: + if p.grad.dtype is torch.bfloat16: # cpu adam kernel does not support bf16 now bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index 7dc4590dc3f2..c7a309b872ce 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -122,8 +122,7 @@ def step(self, closure=None, div_scale: float = -1): assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu" assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" self._pre_update(p, "exp_avg", "exp_avg_sq") - # FIXME(ver217): CPU adam kernel only supports fp32 states now - if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float: + if p.grad.dtype is torch.bfloat16: # cpu adam kernel does not support bf16 now bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 67e198ca0347..f822c1819adc 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -167,7 +167,7 @@ def _p2p_comm( group: ProcessGroup, comm_dtype: torch.dtype = torch.float16, ): - """ + """ Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication. Agrs: @@ -176,7 +176,7 @@ def _p2p_comm( peer (int): rank of the peer group (ProcessGroup): process group comm_dtype (torch.dtype): dtype of the tensor to be sent - + Returns: torch.Tensor: tensor received from previous stage """ @@ -302,7 +302,9 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None: cur_rank = self.stage_manager.get_rank() _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) - def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16) -> None: + def p2p_communicate( + self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16 + ) -> None: """ Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch. @@ -313,5 +315,7 @@ def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, if peer is None: peer = self.stage_manager.get_next_rank() cur_rank = self.stage_manager.get_rank() - recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype) + recv_tensor = _p2p_comm( + output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype + ) return recv_tensor diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 8f6acd5fcf4b..1f4bbe9f8dad 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -1,6 +1,6 @@ import time from functools import partial -from typing import Any, Iterable, List, Optional, Union +from typing import Any, Iterable, Optional, Union import torch import torch.cuda @@ -16,7 +16,7 @@ from .base import PipelineSchedule -class ActionIntervalBuffer(): +class ActionIntervalBuffer: """ The buffer to save the interval hidden states and new token for stage to use. @@ -70,8 +70,9 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.batch = batch self.batch_size = get_batch_size(batch) self.microbatch_offset = 0 - assert self.batch_size % self.microbatch_size == 0, \ - f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}" + assert ( + self.batch_size % self.microbatch_size == 0 + ), f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}" self.num_microbatches = self.batch_size // self.microbatch_size self.round = self.num_microbatches // self.stage_manager.num_stages @@ -86,26 +87,26 @@ def load_micro_batch(self) -> Any: return tree_map(partial(to_device, device=get_current_device()), micro_batch) def _prepare_inputs_for_interval_stage(self): - ''' + """ Prepare inputs for interval stage, for all the interval stage, the inputs is just the past_key_values Returns: dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None` - ''' - model_inputs = { - 'past_key_values': self.mb_manager.cur_kv_cache - } if self.mb_manager.cur_kv_cache is not None else None + """ + model_inputs = ( + {"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None + ) return model_inputs def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): - ''' + """ Prepare inputs for new token, the inputs is a dict with `input_ids`, `attention_mask` and `past_key_values` `input_ids` is the new token, `attention_mask` is the previous mask add `1` in the end, `past_key_values` is the past_key_values save in the micro batch manager Returns: dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}` - ''' + """ new_mask = self.mb_manager.cur_descrption.attn_mask past_key_values = self.mb_manager.cur_descrption.kv_cache @@ -117,12 +118,12 @@ def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor: return input_ids def _recv_pre_stage(self) -> Any: - ''' + """ Receive the output from previous stage Returns: Any: The output from previous stage - ''' + """ if self.stage_manager.num_stages == 2: return self.comm.p2p_recv() return self.comm.recv_forward() @@ -138,7 +139,7 @@ def _load_stage_action(self, model: Module) -> None: output_dict = model_forward(model, inputs_dict, None) self.mb_manager.step(inputs_dict, output_dict, None) - self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + self.action_interval_buffer.hidden_states = output_dict["hidden_states"] def _gen_token_action(self, model: Module): """ @@ -146,13 +147,15 @@ def _gen_token_action(self, model: Module): """ hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" - hidden_states = {'hidden_states': hidden_states} + hidden_states = {"hidden_states": hidden_states} logits = model_forward(model, None, hidden_states) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) - assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" - new_token = self._get_token_id(logits['logits']) + assert ( + "logits" in logits + ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" + new_token = self._get_token_id(logits["logits"]) self.mb_manager.step(None, None, new_token) self.action_interval_buffer.new_token = new_token @@ -168,17 +171,17 @@ def _head_encoding_action(self, model: Module): output_dict = model_forward(model, inputs_dict, None) self.mb_manager.step(inputs_dict, output_dict, None) - self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + self.action_interval_buffer.hidden_states = output_dict["hidden_states"] def _body_encoding_action(self, model: Module): hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When not first stage, the hidden states should not be None" inputs_dict = self._prepare_inputs_for_interval_stage() - hidden_states = {'hidden_states': hidden_states} + hidden_states = {"hidden_states": hidden_states} output_dict = model_forward(model, inputs_dict, hidden_states) self.mb_manager.step(inputs_dict, output_dict, None) - self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + self.action_interval_buffer.hidden_states = output_dict["hidden_states"] def _comm_action(self, recv_pre: bool) -> torch.Tensor: """ @@ -246,10 +249,13 @@ def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.T whole_timestamp = [] - #run by round + # run by round for _ in range(self.round): - self.timestamps = [[] for _ in range(self.stage_manager.num_stages) - ] if self.verbose and self.stage_manager.is_first_stage() else None + self.timestamps = ( + [[] for _ in range(self.stage_manager.num_stages)] + if self.verbose and self.stage_manager.is_first_stage() + else None + ) self.action_interval_buffer.clear() while self.mb_manager.is_micro_batch_done() is False: actions = self._gen_action(model) @@ -286,8 +292,11 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t whole_timestamp = [] # run by round for _ in range(self.round): - self.timestamps = [[] for _ in range(self.stage_manager.num_stages) - ] if self.verbose and self.stage_manager.is_first_stage() else None + self.timestamps = ( + [[] for _ in range(self.stage_manager.num_stages)] + if self.verbose and self.stage_manager.is_first_stage() + else None + ) while self.mb_manager.is_micro_batch_done() is False: inputs_dict = None new_token = None @@ -307,13 +316,17 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t hidden_states = self.comm.recv_forward() if self.stage_manager.is_first_stage(): # First just generate a new token - assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" + assert ( + hidden_states is not None + ), "When first stage in GENERATE phase, the hidden states should not be None" logits = model_forward(model, None, hidden_states) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) - assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" - new_token = self._get_token_id(logits['logits']) + assert ( + "logits" in logits + ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" + new_token = self._get_token_id(logits["logits"]) self.mb_manager.step(None, None, new_token) # If the current micro batch is not DONE, go through blocks if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN): @@ -327,9 +340,11 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t self.mb_manager.step(inputs_dict, output_dict, None) # Current microbatch is not DONE, send hidden_state to next stage - if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (Status.GENERATE, - Status.COOLDOWN): - self.comm.send_forward({'hidden_states': output_dict['hidden_states']}) + if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in ( + Status.GENERATE, + Status.COOLDOWN, + ): + self.comm.send_forward({"hidden_states": output_dict["hidden_states"]}) self.mb_manager.next() diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 4bd7d5208a64..63b28701e879 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -77,7 +77,7 @@ Following are the description `ShardConfig`'s arguments: - `enable_sequence_parallelism`: Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. -- `enable_sequence_overlap`: Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when `enable_sequence_parallelism` is True. Defaults to False. +- `enable_sequence_overlap`: Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when `enable_sequence_parallelism` is True. Defaults to False. - `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False. diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 2db83b912112..5a50e7379cdc 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -100,35 +100,24 @@ def pp_forward( embedding_output = self.embeddings( pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding ) + hidden_states = embedding_output else: assert ( hidden_states is not None ), f"Current stage is {stage_manager.stage}, hidden_states should not be None" - # Go through encoder + encoder_outputs = _encoder_forward( + encoder=self.encoder, + start_idx=stage_index[0], + end_idx=stage_index[1], + hidden_states=hidden_states, + head_mask=head_mask, + return_dict=return_dict, + stage_manager=stage_manager, + ) if not stage_manager.is_last_stage(): - hidden_states = _encoder_forward( - encoder=self.encoder, - start_idx=stage_index[0], - end_idx=stage_index[1], - hidden_states=embedding_output, - head_mask=head_mask, - return_dict=return_dict, - stage_manager=stage_manager, - ) - return {"hidden_states": hidden_states} - else: - encoder_outputs = _encoder_forward( - encoder=self.encoder, - start_idx=stage_index[0], - end_idx=stage_index[1], - hidden_states=hidden_states, - head_mask=head_mask, - return_dict=return_dict, - stage_manager=stage_manager, - ) + return {"hidden_states": encoder_outputs} - # Go through rest layers sequence_output = encoder_outputs[0] sequence_output = self.layernorm(sequence_output) pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/colossalai/testing/__init__.py b/colossalai/testing/__init__.py index c6956e81fbde..b84ba55a7a13 100644 --- a/colossalai/testing/__init__.py +++ b/colossalai/testing/__init__.py @@ -9,6 +9,7 @@ ) from .pytest_wrapper import run_on_environment_flag from .utils import ( + DummyDataloader, clear_cache_before_run, free_port, parameterize, @@ -34,4 +35,5 @@ "run_on_environment_flag", "check_state_dict_equal", "assert_hf_output_close", + "DummyDataloader", ] diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py index fdbda9a598bf..839e7aab3567 100644 --- a/colossalai/testing/utils.py +++ b/colossalai/testing/utils.py @@ -273,3 +273,24 @@ def _clear_cache(*args, **kwargs): return _clear_cache return _wrap_func + + +class DummyDataloader: + def __init__(self, data_gen_fn: Callable, length: int = 10): + self.data_gen_fn = data_gen_fn + self.length = length + self.step = 0 + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.data_gen_fn() + else: + raise StopIteration + + def __len__(self): + return self.length diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index c8be773b2c4f..d3309fc5364f 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -434,6 +434,21 @@ def copy_tensor_to_chunk_slice( if update_ptr: tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape) + def add_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: + """ + Add data slice to the memory space indexed by the input tensor in the chunk. + Only used when accumulating gradient chunks. + + Args: + tensor (torch.Tensor): the tensor used to retrieve meta information + data_slice (torch.Tensor): the tensor to be added to the chunk + """ + # sanity check + assert self.is_gathered + + tensor_info = self.tensors_info[tensor] + self.cuda_global_chunk[tensor_info.offset : tensor_info.end].add_(data_slice.data.flatten()) + def get_valid_length(self) -> int: """Get the valid length of the chunk's payload.""" if self.keep_gathered: diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 713c11742e15..d3c512fe978d 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -5,7 +5,7 @@ import torch.distributed as dist from torch.distributed import ProcessGroup -from colossalai.utils import get_current_device +from colossalai.utils import free_storage, get_current_device from .chunk import Chunk, ChunkFullError, TensorState @@ -255,3 +255,37 @@ def init_grad_chunk(self, chunk: Chunk) -> Chunk: self.accessed_chunks.add(grad_chunk) self.accessed_mem += grad_chunk.chunk_mem return grad_chunk + + def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk: + """Rearrange gradients accumulated in chunk.grad_chunk, and getP prepared for gradient reduction.""" + + assert chunk.grad_chunk is not None + + # Make a backup for gradient accumulated before. + # Here backup gradients should be multiplied, since it will be divided after gradient reduction. + if chunk.grad_chunk.is_gathered: + accumulated_grad = chunk.grad_chunk.cuda_global_chunk.clone().detach().mul_(chunk.pg_size) + accumulated_grad_gathered = True + else: + if chunk.grad_chunk.cuda_shard is not None: + accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size) + else: + accumulated_grad = ( + chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size) + ) + accumulated_grad_gathered = False + + # Reset grad_chunk, and chunk.grad_chunk will be accessed. + grad_chunk = self.init_grad_chunk(chunk) + grad_chunk.cuda_global_chunk.zero_() + + # Add backup gradients to grad_chunk. + if accumulated_grad_gathered: + grad_chunk.cuda_global_chunk.add_(accumulated_grad) + else: + grad_chunk.cuda_global_chunk[grad_chunk.shard_begin : grad_chunk.shard_end].add_(accumulated_grad) + + # Release accumulated_grad + free_storage(accumulated_grad) + + return grad_chunk diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index a4871f7e4b40..df7e1163c3d9 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -59,6 +59,7 @@ def __init__( chunk_config_dict: Optional[dict] = None, chunk_init_device: torch.device = torch.device("cpu"), placement_policy: str = "static", + enable_gradient_accumulation: bool = False, shard_param_frac: float = 1.0, # only for static placement offload_optim_frac: float = 0.0, # only for static placement offload_param_frac: float = 0.0, # only for static placement @@ -119,6 +120,11 @@ def __init__( self.reuse_fp16_chunk = master_weights self.master_weights = master_weights + self.enable_gradient_accumulation = enable_gradient_accumulation + if self.enable_gradient_accumulation: + self.reuse_fp16_chunk = False + self.accumulating_grads = False # Whether model is accumulating gradients + self._logger = get_dist_logger() if self.gemini_manager._premade_memstats_: @@ -298,6 +304,8 @@ def _post_backward(self): f"{error_str}", ) self._setup_grads_ptr() + if self.enable_gradient_accumulation and not self.accumulating_grads: + self.accumulating_grads = True # Turn on the state of gradient accumulation. self._logger.debug( f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}" ) @@ -327,7 +335,15 @@ def grad_handle(self, p, grad): ) grad_chunk = chunk if not self.reuse_fp16_chunk: - grad_chunk = self.chunk_manager.init_grad_chunk(chunk) + if not self.accumulating_grads: + grad_chunk = self.chunk_manager.init_grad_chunk(chunk) + else: + assert chunk.grad_chunk is not None + if chunk.grad_chunk not in self.chunk_manager.accessed_chunks: + grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk) + else: + grad_chunk = chunk.grad_chunk + # hold -> compute -> hold after bwd grad_chunk.tensor_trans_state(p, TensorState.COMPUTE) grad_chunk.tensor_trans_state(p, TensorState.HOLD_AFTER_BWD) @@ -336,7 +352,10 @@ def grad_handle(self, p, grad): chunk.tensor_trans_state(p, TensorState.HOLD) grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE) - grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk) + if not self.accumulating_grads: + grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk) + else: + grad_chunk.add_tensor_to_chunk_slice(p, grad) reduced = self.chunk_manager.reduce_chunk(grad_chunk) if reduced: if not self.reuse_fp16_chunk: @@ -354,7 +373,7 @@ def grad_handle(self, p, grad): if chunk.l2_norm_flag: grad_chunk.set_l2_norm() self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True) - if not self.master_weights: + if not (self.master_weights) or (self.enable_gradient_accumulation): self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) return empty_grad diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 3c42e96cb803..0d0298e067f3 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -263,6 +263,7 @@ def step(self, *args, **kwargs): self.zero_grad() if self.module.master_weights: self._update_fp16_params() + self.module.accumulating_grads = False return ret def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0): diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md index feb37fc15de2..fa360a4b9213 100644 --- a/docs/source/en/basics/booster_plugins.md +++ b/docs/source/en/basics/booster_plugins.md @@ -15,7 +15,7 @@ We currently provide the following plugins: - [Torch FSDP Plugin](#torch-fsdp-plugin): It is a wrapper of `torch.distributed.fsdp.FullyShardedDataParallel` and can be used to train models with zero-dp. - [Low Level Zero Plugin](#low-level-zero-plugin): It wraps the `colossalai.zero.low_level.LowLevelZeroOptimizer` and can be used to train models with zero-dp. It only supports zero stage-1 and stage-2. - [Gemini Plugin](#gemini-plugin): It wraps the [Gemini](../features/zero_with_chunk.md) which implements Zero-3 with chunk-based and heterogeneous memory management. -- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It provides a tidy interface that integrates the power of Shardformer, pipeline manager, mixied precision training, TorchDDP and Zero stage 1/2 feature. With this plugin, transformer models can be easily trained with any combination of tensor parallel, pipeline parallel and data parallel (DDP/Zero) efficiently, along with various kinds of optimization tools for acceleration and memory saving. Detailed information about supported parallel strategies and optimization tools is explained in the section below. +- [Hybrid Parallel Plugin](#hybrid-parallel-plugin): It provides a tidy interface that integrates the power of Shardformer, pipeline manager, mixied precision training, TorchDDP and Zero stage 1/2 feature. With this plugin, transformer models can be easily trained with any combination of tensor parallel, pipeline parallel and data parallel (DDP/Zero) efficiently, along with various kinds of optimization tools for acceleration and memory saving. Detailed information about supported parallel strategies and optimization tools is explained in the section below. More plugins are coming soon. diff --git a/docs/source/en/features/gradient_accumulation_with_booster.md b/docs/source/en/features/gradient_accumulation_with_booster.md index 347cd6e519bb..ea97dd92e885 100644 --- a/docs/source/en/features/gradient_accumulation_with_booster.md +++ b/docs/source/en/features/gradient_accumulation_with_booster.md @@ -1,6 +1,6 @@ # Gradient Accumulation -Author: [Mingyan Jiang](https://github.com/jiangmingyan) +Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Baizhou Zhang](https://github.com/Fridge003) **Prerequisite** - [Training Booster](../basics/booster_api.md) @@ -126,6 +126,7 @@ for idx, (img, label) in enumerate(train_dataloader): ``` + ### Step 6. Invoke Training Scripts To verify gradient accumulation, we can just check the change of parameter values. When gradient accumulation is set, parameters are only updated in the last step. You can run the script using this command: ```shell @@ -142,4 +143,30 @@ iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0 iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=) ``` + +## Gradient Accumulation on GeminiPlugin + +Currently the plugins supporting `no_sync()` method include `TorchDDPPlugin` and `LowLevelZeroPlugin` set to stage 1. `GeminiPlugin` doesn't support `no_sync()` method, but it can enable synchronized gradient accumulation in a torch-like way. + +To enable gradient accumulation feature, the argument `enable_gradient_accumulation` should be set to `True` when initializing `GeminiPlugin`. Following is the pseudocode snippet of enabling gradient accumulation for `GeminiPlugin`: + +```python +... +plugin = GeminiPlugin(..., enable_gradient_accumulation=True) +booster = Booster(plugin=plugin) +... + +... +for idx, (input, label) in enumerate(train_dataloader): + output = gemini_model(input.cuda()) + train_loss = criterion(output, label.cuda()) + train_loss = train_loss / GRADIENT_ACCUMULATION + booster.backward(train_loss, gemini_optimizer) + + if idx % (GRADIENT_ACCUMULATION - 1) == 0: + gemini_optimizer.step() # zero_grad is automatically done +... +``` + + diff --git a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md index 3ad9b2e07a95..824308f94654 100644 --- a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md +++ b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md @@ -1,6 +1,6 @@ # 梯度累积 -作者: [Mingyan Jiang](https://github.com/jiangmingyan) +作者: [Mingyan Jiang](https://github.com/jiangmingyan), [Baizhou Zhang](https://github.com/Fridge003) **前置教程** - [训练中使用Booster](../basics/booster_api.md) @@ -93,6 +93,7 @@ model, optimizer, criterion, train_dataloader, _ = booster.boost(model=model, dataloader=train_dataloader) ``` + ### 步骤 5. 使用booster训练 使用booster构建一个普通的训练循环,验证梯度累积。 `param_by_iter` 记录分布训练的信息。 ```python @@ -144,4 +145,29 @@ iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0 iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=) ``` +## 在Gemini插件中使用梯度累积 + +目前支持`no_sync()`方法的插件包括 `TorchDDPPlugin` 和 `LowLevelZeroPlugin`(需要设置参数`stage`为1). `GeminiPlugin` 不支持 `no_sync()` 方法, 但是它可以通过和`pytorch`类似的方式来使用同步的梯度累积。 + +为了开启梯度累积功能,在初始化`GeminiPlugin`的时候需要将参数`enable_gradient_accumulation`设置为`True`。以下是 `GeminiPlugin` 进行梯度累积的伪代码片段: + +```python +... +plugin = GeminiPlugin(..., enable_gradient_accumulation=True) +booster = Booster(plugin=plugin) +... + +... +for idx, (input, label) in enumerate(train_dataloader): + output = gemini_model(input.cuda()) + train_loss = criterion(output, label.cuda()) + train_loss = train_loss / GRADIENT_ACCUMULATION + booster.backward(train_loss, gemini_optimizer) + + if idx % (GRADIENT_ACCUMULATION - 1) == 0: + gemini_optimizer.step() # zero_grad is automatically done +... +``` + + diff --git a/examples/inference/_utils.py b/examples/inference/_utils.py new file mode 100644 index 000000000000..67d897836113 --- /dev/null +++ b/examples/inference/_utils.py @@ -0,0 +1,19 @@ +def print_perf_stats(latency_set, config, bs, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = ( + getattr(config, "num_layers") if hasattr(config, "num_layers") else getattr(config, "num_hidden_layers") + ) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 # float16 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) + print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs)) diff --git a/examples/inference/bench_bloom.py b/examples/inference/bench_bloom.py index 738f43dc0619..054641f6eebf 100644 --- a/examples/inference/bench_bloom.py +++ b/examples/inference/bench_bloom.py @@ -3,6 +3,7 @@ import time import torch +from _utils import print_perf_stats from transformers import BloomForCausalLM, BloomTokenizerFast import colossalai @@ -14,25 +15,6 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" -def print_perf_stats(latency_set, config, bs, warmup=3): - # trim warmup queries - latency_set = list(latency_set) - latency_set = latency_set[warmup:] - count = len(latency_set) - - if count > 0: - latency_set.sort() - avg = sum(latency_set) / count - num_layers = getattr(config, "num_layers", config.num_hidden_layers) - num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 - num_bytes = 2 # float16 - - print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) - print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) - print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) - print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs)) - - def bench_bloom(args): model_path = args.path max_batch_size = args.batch_size diff --git a/examples/inference/bench_chatglm2.py b/examples/inference/bench_chatglm2.py new file mode 100644 index 000000000000..f3678d29ff93 --- /dev/null +++ b/examples/inference/bench_chatglm2.py @@ -0,0 +1,116 @@ +import argparse +import os +import time + +import torch +from _utils import print_perf_stats +from transformers import AutoTokenizer + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration +from colossalai.testing import rerun_if_address_is_in_use, spawn + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +def run_chatglm2_test(args): + chatglm2_model_path = args.path + max_batch_size = args.batch_size + max_input_len = args.input_len + max_output_len = args.output_len + args.test_mode + + print("max_batch_size : " + str(max_batch_size)) + + tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) + model = ChatGLMForConditionalGeneration.from_pretrained(chatglm2_model_path, pad_token_id=tokenizer.eos_token_id) + model = model.half() + model.config + + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) + + generate_kwargs = dict(max_new_tokens=1, do_sample=False) + input_tokens = { + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), + "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), + } + + iters = 10 + prefill_times = [] + + warmup = 3 + + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print("generation time {} s".format(str(end - start))) + print(out_len - max_input_len) + prefill_times.append((end - start) / (out_len - max_input_len)) + + prefill_times = prefill_times[warmup:] + prefill_time_avg = sum(prefill_times) / len(prefill_times) + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + + times = [] + decoder_times = [] + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print("generation time {} s".format(str(end - start))) + print(out_len - max_input_len) + times.append((end - start) / (out_len - max_input_len)) + if args.test_mode == "decoder_test": + decoder_times.append((end - start - prefill_time_avg) / (out_len - max_input_len - 1)) + + times = times[warmup:] + latency = sum(times) / len(times) + print("total process latency is : " + str(latency) + " s") + print("total throughput is : " + str(1 / latency * max_batch_size)) + + if args.test_mode == "decoder_test": + decoder_times = decoder_times[warmup:] + latency = sum(decoder_times) / len(decoder_times) + + print("decoder process latency is : " + str(latency) + " s") + print("decoder throughput is : " + str(1 / latency * max_batch_size)) + + print_perf_stats(times, model.config, max_batch_size) + + +def check_chatglm2(rank, world_size, port, args): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_chatglm2_test(args) + + +@rerun_if_address_is_in_use() +def test_chatglm2(args): + spawn(check_chatglm2, args.tp_size, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=256, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") + parser.add_argument( + "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"] + ) + + args = parser.parse_args() + + test_chatglm2(args) diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index 90d49f6a264a..f3e742dfbb59 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -3,7 +3,7 @@ import time import torch -from torch.profiler import ProfilerActivity, profile, record_function +from _utils import print_perf_stats from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai @@ -15,48 +15,52 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" -def print_perf_stats(latency_set, config, bs, warmup=3): - # trim warmup queries - latency_set = list(latency_set) - latency_set = latency_set[warmup:] - count = len(latency_set) - - if count > 0: - latency_set.sort() - avg = sum(latency_set) / count - num_layers = getattr(config, "num_layers", config.num_hidden_layers) - num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 - num_bytes = 2 - - print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) - print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) - print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) - - def run_llama_test(args): llama_model_path = args.path max_batch_size = args.batch_size max_input_len = args.input_len max_output_len = args.output_len + args.test_mode + + print("max_batch_size : " + str(max_batch_size)) tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) tokenizer.pad_token_id = tokenizer.unk_token_id model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) model = model.half() - model_config = model.config + model.config shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) - generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + generate_kwargs = dict(max_new_tokens=1, do_sample=False) input_tokens = { "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), } iters = 10 - times = [] + prefill_times = [] + + warmup = 3 + + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print("generation time {} s".format(str(end - start))) + print(out_len - max_input_len) + prefill_times.append((end - start) / (out_len - max_input_len)) + prefill_times = prefill_times[warmup:] + prefill_time_avg = sum(prefill_times) / len(prefill_times) + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + + times = [] + decoder_times = [] for i in range(iters): torch.cuda.synchronize() start = time.time() @@ -65,17 +69,24 @@ def run_llama_test(args): end = time.time() out_len = outputs.shape[1] print("generation time {} s".format(str(end - start))) + print(out_len - max_input_len) times.append((end - start) / (out_len - max_input_len)) + if args.test_mode == "decoder_test": + decoder_times.append((end - start - prefill_time_avg) / (out_len - max_input_len - 1)) + + times = times[warmup:] + latency = sum(times) / len(times) + print("total process latency is : " + str(latency) + " s") + print("total throughput is : " + str(1 / latency * max_batch_size)) + + if args.test_mode == "decoder_test": + decoder_times = decoder_times[warmup:] + latency = sum(decoder_times) / len(decoder_times) - print("outputs, ", len(outputs)) - print_perf_stats(times, model_config, max_batch_size) + print("decoder process latency is : " + str(latency) + " s") + print("decoder throughput is : " + str(1 / latency * max_batch_size)) - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - with record_function("model_inference"): - torch.cuda.synchronize() - outputs = infer_engine.generate(input_tokens, **generate_kwargs) - torch.cuda.synchronize() - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + print_perf_stats(times, model.config, max_batch_size) def check_llama(rank, world_size, port, args): @@ -95,8 +106,11 @@ def test_llama(args): parser.add_argument("-p", "--path", type=str, help="Model path", required=True) parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") - parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--input_len", type=int, default=256, help="Maximum input length") parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") + parser.add_argument( + "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"] + ) args = parser.parse_args() diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py index 43e118cc0aa5..cfa3171374dd 100644 --- a/examples/inference/gptq_bloom.py +++ b/examples/inference/gptq_bloom.py @@ -1,12 +1,11 @@ import argparse -import logging import os import time import torch -from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig -from auto_gptq.nn_modules.qlinear import GeneralQuantLinear -from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM, LlamaTokenizer +from _utils import print_perf_stats +from auto_gptq import AutoGPTQForCausalLM +from transformers import BloomTokenizerFast import colossalai from colossalai.inference.tensor_parallel.engine import TPInferEngine @@ -14,30 +13,10 @@ from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' - - -def print_perf_stats(latency_set, config, bs, warmup=3): - # trim warmup queries - latency_set = list(latency_set) - latency_set = latency_set[warmup:] - count = len(latency_set) - - if count > 0: - latency_set.sort() - avg = sum(latency_set) / count - num_layers = getattr(config, "num_layers", config.num_hidden_layers) - num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 - num_bytes = 2 # float16 - - print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) - print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) - print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) - print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs)) +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" def bench_bloom(args): - pretrained_model_dir = args.path quantized_model_dir = args.quantized_path max_batch_size = args.batch_size @@ -48,9 +27,9 @@ def bench_bloom(args): tokenizer.pad_token = tokenizer.eos_token # load quantized model to the first GPU - model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, - device=torch.cuda.current_device(), - inject_fused_attention=False) + model = AutoGPTQForCausalLM.from_quantized( + quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False + ) model = model.half() @@ -60,22 +39,22 @@ def bench_bloom(args): generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { - "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), - "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), + "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), } # init TPInferEngine and shard the original model # To benchmark torch original, comment out the line of optimizing model - shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, - inference_only=True, - inference_gptq=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True + ) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) # prepare data for generation generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)), - "attention_mask": torch.ones((max_batch_size, max_input_len)) + "attention_mask": torch.ones((max_batch_size, max_input_len)), } for t in input_tokens: if torch.is_tensor(input_tokens[t]): @@ -99,7 +78,7 @@ def bench_bloom(args): def check_bloom(rank, world_size, port, args): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") bench_bloom(args) @@ -111,12 +90,12 @@ def test_bloom(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-p', '--path', type=str, help='Model path', required=True) - parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True) - parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') - parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') - parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') - parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True) + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") args = parser.parse_args() diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py index 1bdee448c742..35a6049ad409 100644 --- a/examples/inference/gptq_llama.py +++ b/examples/inference/gptq_llama.py @@ -3,12 +3,12 @@ import time import torch +from _utils import print_perf_stats from auto_gptq import AutoGPTQForCausalLM from transformers import LlamaTokenizer import colossalai from colossalai.inference.tensor_parallel.engine import TPInferEngine -from colossalai.inference.tensor_parallel.modeling._utils import init_to_get_rotary from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn @@ -16,25 +16,6 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" -def print_perf_stats(latency_set, config, bs, warmup=3): - # trim warmup queries - latency_set = list(latency_set) - latency_set = latency_set[warmup:] - count = len(latency_set) - - if count > 0: - latency_set.sort() - avg = sum(latency_set) / count - num_layers = getattr(config, "num_layers", config.num_hidden_layers) - num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 - num_bytes = 2 - - print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) - print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) - print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) - print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs)) - - def run_llama_test(args): pretrained_model_dir = args.path quantized_model_dir = args.quantized_path @@ -50,8 +31,6 @@ def run_llama_test(args): quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False ) - init_to_get_rotary(model.model.model, base=10000) - model_config = model.config shard_config = ShardConfig( enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True diff --git a/examples/inference/smoothquant_llama.py b/examples/inference/smoothquant_llama.py new file mode 100644 index 000000000000..ce7a00aa2739 --- /dev/null +++ b/examples/inference/smoothquant_llama.py @@ -0,0 +1,69 @@ +import argparse +import os + +import torch +from datasets import load_dataset +from transformers import LlamaTokenizer + +from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM + + +def build_model_and_tokenizer(model_name): + tokenizer = LlamaTokenizer.from_pretrained(model_name, model_max_length=512) + kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"} + model = SmoothLlamaForCausalLM.from_pretrained(model_name, **kwargs) + model = model.to(torch.float32) + return model, tokenizer + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-name", type=str, help="model name") + parser.add_argument( + "--output-path", + type=str, + help="where to save the checkpoint", + ) + parser.add_argument( + "--dataset-path", + type=str, + help="location of the calibration dataset", + ) + parser.add_argument("--num-samples", type=int, default=512) + parser.add_argument("--seq-len", type=int, default=512) + args = parser.parse_args() + return args + + +@torch.no_grad() +def main(): + args = parse_args() + model_path = args.model_name + dataset_path = args.dataset_path + output_path = args.output_path + num_samples = 10 + seq_len = 512 + + model, tokenizer = build_model_and_tokenizer(model_path) + if not os.path.exists(dataset_path): + print(f"Cannot find the dataset at {args.dataset_path}") + raise FileNotFoundError + dataset = load_dataset("json", data_files=dataset_path, split="train") + + model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len) + model = model.cuda() + + model.save_quantized(output_path, model_basename="llama-7b") + + model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b") + model = model.cuda() + + generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True) + input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda") + out = model.generate(**input_tokens, **generate_kwargs) + text = tokenizer.batch_decode(out) + print("out is:", text) + + +if __name__ == "__main__": + main() diff --git a/op_builder/smoothquant.py b/op_builder/smoothquant.py new file mode 100644 index 000000000000..d562a4c4f626 --- /dev/null +++ b/op_builder/smoothquant.py @@ -0,0 +1,52 @@ +import torch + +from .builder import Builder +from .utils import append_nvcc_threads + + +class SmoothquantBuilder(Builder): + NAME = "cu_smoothquant" + PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant" + + def __init__(self): + super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH) + + def include_dirs(self): + ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()] + return ret + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "smoothquant/binding.cpp", + "smoothquant/linear.cu", + ] + ] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + compute_capability = torch.cuda.get_device_capability() + cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10 + + extra_cuda_flags = [ + "-v", + f"-DCUDA_ARCH={cuda_arch}", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", + ] + + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) + + def builder(self): + try: + super().builder() + except: + warnings.warn("build smoothquant lib not successful") diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 9aa5f2822e40..19cb7a154a01 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -11,3 +11,6 @@ ninja torch>=1.12 safetensors einops +sentencepiece +google +protobuf diff --git a/tests/components_to_test/__init__.py b/tests/components_to_test/__init__.py deleted file mode 100644 index 65eaa72d6e84..000000000000 --- a/tests/components_to_test/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -from . import ( - beit, - bert, - gpt2, - hanging_param_model, - inline_op_model, - nested_model, - repeated_computed_layers, - resnet, - simple_net, -) -from .utils import run_fwd, run_fwd_bwd - -from . import albert # isort:skip - -__all__ = [ - "bert", - "gpt2", - "hanging_param_model", - "inline_op_model", - "nested_model", - "repeated_computed_layers", - "resnet", - "simple_net", - "run_fwd_bwd", - "albert", - "beit", - "run_fwd", -] diff --git a/tests/components_to_test/albert.py b/tests/components_to_test/albert.py deleted file mode 100644 index 0ba4d19655cd..000000000000 --- a/tests/components_to_test/albert.py +++ /dev/null @@ -1,62 +0,0 @@ -import torch -from transformers import AlbertConfig, AlbertForSequenceClassification - -from .bert import get_bert_data_loader -from .registry import non_distributed_component_funcs - - -@non_distributed_component_funcs.register(name="albert") -def get_training_components(): - hidden_dim = 8 - num_head = 4 - sequence_length = 12 - num_layer = 2 - vocab_size = 32 - - def bert_model_builder(checkpoint: bool = False): - config = AlbertConfig( - vocab_size=vocab_size, - gradient_checkpointing=checkpoint, - hidden_size=hidden_dim, - intermediate_size=hidden_dim * 4, - num_attention_heads=num_head, - max_position_embeddings=sequence_length, - num_hidden_layers=num_layer, - hidden_dropout_prob=0.0, - attention_probs_dropout_prob=0.0, - ) - print("building AlbertForSequenceClassification model") - - # adapting huggingface BertForSequenceClassification for single unittest calling interface - class ModelAdaptor(AlbertForSequenceClassification): - def forward(self, input_ids, labels): - """ - inputs: data, label - outputs: loss - """ - return super().forward(input_ids=input_ids, labels=labels)[0] - - model = ModelAdaptor(config) - # if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"): - # model.gradient_checkpointing_enable() - - return model - - is_distributed = torch.distributed.is_initialized() - trainloader = get_bert_data_loader( - n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed, - ) - testloader = get_bert_data_loader( - n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed, - ) - - criterion = None - return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/beit.py b/tests/components_to_test/beit.py deleted file mode 100644 index d33474ea9a6b..000000000000 --- a/tests/components_to_test/beit.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -from timm.models.beit import Beit - -from colossalai.utils.cuda import get_current_device - -from .registry import non_distributed_component_funcs -from .utils.dummy_data_generator import DummyDataGenerator - - -class DummyDataLoader(DummyDataGenerator): - img_size = 64 - num_channel = 3 - num_class = 10 - batch_size = 4 - - def generate(self): - data = torch.randn( - ( - DummyDataLoader.batch_size, - DummyDataLoader.num_channel, - DummyDataLoader.img_size, - DummyDataLoader.img_size, - ), - device=get_current_device(), - ) - label = torch.randint( - low=0, high=DummyDataLoader.num_class, size=(DummyDataLoader.batch_size,), device=get_current_device() - ) - return data, label - - -@non_distributed_component_funcs.register(name="beit") -def get_training_components(): - def model_builder(checkpoint=False): - model = Beit( - img_size=DummyDataLoader.img_size, num_classes=DummyDataLoader.num_class, embed_dim=32, depth=2, num_heads=4 - ) - return model - - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - - criterion = torch.nn.CrossEntropyLoss() - return model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/bert.py b/tests/components_to_test/bert.py deleted file mode 100644 index f0061ad18c84..000000000000 --- a/tests/components_to_test/bert.py +++ /dev/null @@ -1,89 +0,0 @@ -import torch -import transformers -from packaging import version -from torch.utils.data import SequentialSampler -from transformers import BertConfig, BertForSequenceClassification - -from .registry import non_distributed_component_funcs - - -def get_bert_data_loader( - n_class, - batch_size, - total_samples, - sequence_length, - device=torch.device("cpu:0"), - is_distributed=False, -): - train_data = torch.randint( - low=0, - high=n_class, - size=(total_samples, sequence_length), - device=device, - dtype=torch.long, - ) - train_label = torch.randint(low=0, high=2, size=(total_samples,), device=device, dtype=torch.long) - train_dataset = torch.utils.data.TensorDataset(train_data, train_label) - if is_distributed: - sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) - else: - sampler = SequentialSampler(train_dataset) - train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler) - return train_loader - - -@non_distributed_component_funcs.register(name="bert") -def get_training_components(): - hidden_dim = 8 - num_head = 4 - sequence_length = 12 - num_layer = 2 - vocab_size = 32 - - def bert_model_builder(checkpoint: bool = False): - config = BertConfig( - vocab_size=vocab_size, - gradient_checkpointing=checkpoint, - hidden_size=hidden_dim, - intermediate_size=hidden_dim * 4, - num_attention_heads=num_head, - max_position_embeddings=sequence_length, - num_hidden_layers=num_layer, - hidden_dropout_prob=0.0, - attention_probs_dropout_prob=0.0, - ) - print("building BertForSequenceClassification model") - - # adapting huggingface BertForSequenceClassification for single unittest calling interface - class ModelAdaptor(BertForSequenceClassification): - def forward(self, input_ids, labels): - """ - inputs: data, label - outputs: loss - """ - return super().forward(input_ids=input_ids, labels=labels)[0] - - model = ModelAdaptor(config) - if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"): - model.gradient_checkpointing_enable() - - return model - - is_distributed = torch.distributed.is_initialized() - trainloader = get_bert_data_loader( - n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed, - ) - testloader = get_bert_data_loader( - n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed, - ) - - criterion = None - return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/gpt2.py b/tests/components_to_test/gpt2.py deleted file mode 100644 index 7f826497d2ab..000000000000 --- a/tests/components_to_test/gpt2.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch -import torch.nn as nn -from transformers import GPT2Config, GPT2LMHeadModel - -from colossalai.utils.cuda import get_current_device - -from .registry import non_distributed_component_funcs -from .utils.dummy_data_generator import DummyDataGenerator - - -class DummyDataLoader(DummyDataGenerator): - vocab_size = 128 - batch_size = 4 - seq_len = 64 - - def generate(self): - input_ids = torch.randint( - 0, - DummyDataLoader.vocab_size, - (DummyDataLoader.batch_size, DummyDataLoader.seq_len), - device=get_current_device(), - ) - return input_ids, input_ids - - -class GPTLMModel(nn.Module): - def __init__( - self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50304, - checkpoint=False, - ): - super().__init__() - self.checkpoint = checkpoint - self.model = GPT2LMHeadModel( - GPT2Config( - n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size, - resid_pdrop=0.0, - embd_pdrop=0.0, - attn_pdrop=0.0, - ) - ) - if checkpoint: - self.model.gradient_checkpointing_enable() - - def forward(self, input_ids): - # Only return lm_logits - attention_mask = torch.ones_like(input_ids) - return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] - - -def gpt2_micro(checkpoint=True): - return GPTLMModel( - checkpoint=checkpoint, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128 - ) - - -def gpt2_s(checkpoint=True): - return GPTLMModel(checkpoint=checkpoint) - - -def gpt2_m(checkpoint=True): - return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) - - -class GPTLMLoss(nn.Module): - def __init__(self): - super().__init__() - self.loss_fn = nn.CrossEntropyLoss() - - def forward(self, logits, labels): - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - -@non_distributed_component_funcs.register(name="gpt2") -def get_training_components(): - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - - criterion = GPTLMLoss() - return gpt2_micro, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/hanging_param_model.py b/tests/components_to_test/hanging_param_model.py deleted file mode 100644 index 5531c8d081a0..000000000000 --- a/tests/components_to_test/hanging_param_model.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from colossalai.legacy.nn import CheckpointModule - -from .registry import non_distributed_component_funcs -from .utils.dummy_data_generator import DummyDataGenerator - - -class HangingParamModule(CheckpointModule): - """ - Hanging Parameter: a parameter dose not belong to a leaf Module. - It has subordinate nn.modules and a nn.Parameter. - """ - - def __init__(self, checkpoint=False) -> None: - super().__init__(checkpoint=checkpoint) - self.proj1 = nn.Linear(4, 8) - self.weight = nn.Parameter(torch.randn(8, 8)) - self.proj2 = nn.Linear(8, 4) - - def forward(self, x): - x = self.proj1(x) - x = F.linear(x, self.weight) - x = self.proj2(x) - return x - - -class DummyDataLoader(DummyDataGenerator): - def generate(self): - data = torch.rand(16, 4) - label = torch.randint(low=0, high=2, size=(16,)) - return data, label - - -@non_distributed_component_funcs.register(name="hanging_param_model") -def get_training_components(): - def model_builder(checkpoint=False): - return HangingParamModule(checkpoint) - - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - - criterion = torch.nn.CrossEntropyLoss() - from colossalai.nn.optimizer import HybridAdam - - return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/inline_op_model.py b/tests/components_to_test/inline_op_model.py deleted file mode 100644 index 8bfa9cf34353..000000000000 --- a/tests/components_to_test/inline_op_model.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch -import torch.nn as nn - -from colossalai.legacy.nn import CheckpointModule - -from .registry import non_distributed_component_funcs -from .utils.dummy_data_generator import DummyDataGenerator - - -class InlineOpModule(CheckpointModule): - """ - a module with inline Ops - """ - - def __init__(self, checkpoint=False) -> None: - super().__init__(checkpoint=checkpoint) - self.proj1 = nn.Linear(4, 8) - self.proj2 = nn.Linear(8, 8) - - def forward(self, x): - x = self.proj1(x) - # inline add_ - x.add_(10) - x = self.proj2(x) - # inline relu_ - x = torch.relu_(x) - x = self.proj2(x) - return x - - -class DummyDataLoader(DummyDataGenerator): - def generate(self): - data = torch.rand(16, 4) - label = torch.randint(low=0, high=2, size=(16,)) - return data, label - - -@non_distributed_component_funcs.register(name="inline_op_model") -def get_training_components(): - def model_builder(checkpoint=False): - return InlineOpModule(checkpoint) - - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - - criterion = torch.nn.CrossEntropyLoss() - from colossalai.nn.optimizer import HybridAdam - - return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/registry.py b/tests/components_to_test/registry.py deleted file mode 100644 index ec561b7831ad..000000000000 --- a/tests/components_to_test/registry.py +++ /dev/null @@ -1,38 +0,0 @@ -#!/usr/bin/env python - - -class Registry: - def __init__(self): - self._registry = dict() - - def register(self, name): - assert name not in self._registry - - def _register(callable_): - self._registry[name] = callable_ - - return _register - - def get_callable(self, name: str): - return self._registry[name] - - def __iter__(self): - self._idx = 0 - self._len = len(self._registry) - self._names = list(self._registry.keys()) - return self - - def __next__(self): - if self._idx < self._len: - key = self._names[self._idx] - callable_ = self._registry[key] - self._idx += 1 - return callable_ - else: - raise StopIteration - - -non_distributed_component_funcs = Registry() -model_parallel_component_funcs = Registry() - -__all__ = ["non_distributed_component_funcs", "model_parallel_component_funcs"] diff --git a/tests/components_to_test/repeated_computed_layers.py b/tests/components_to_test/repeated_computed_layers.py deleted file mode 100644 index 3da64de3fb64..000000000000 --- a/tests/components_to_test/repeated_computed_layers.py +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env python - -import torch -import torch.nn as nn - -from colossalai.legacy.nn import CheckpointModule - -from .registry import non_distributed_component_funcs -from .utils.dummy_data_generator import DummyDataGenerator - - -class NetWithRepeatedlyComputedLayers(CheckpointModule): - """ - This model is to test with layers which go through forward pass multiple times. - In this model, the fc1 and fc2 call forward twice - """ - - def __init__(self, checkpoint=False) -> None: - super().__init__(checkpoint=checkpoint) - self.fc1 = nn.Linear(5, 5) - self.fc2 = nn.Linear(5, 5) - self.fc3 = nn.Linear(5, 2) - self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3] - - def forward(self, x): - for layer in self.layers: - x = layer(x) - return x - - -class DummyDataLoader(DummyDataGenerator): - def generate(self): - data = torch.rand(16, 5) - label = torch.randint(low=0, high=2, size=(16,)) - return data, label - - -@non_distributed_component_funcs.register(name="repeated_computed_layers") -def get_training_components(): - def model_builder(checkpoint=False): - return NetWithRepeatedlyComputedLayers(checkpoint) - - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - - criterion = torch.nn.CrossEntropyLoss() - return model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/resnet.py b/tests/components_to_test/resnet.py deleted file mode 100644 index a43becc16233..000000000000 --- a/tests/components_to_test/resnet.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -from pathlib import Path - -import torch -from torchvision.datasets import CIFAR10 -from torchvision.models import resnet18 -from torchvision.transforms import transforms - -from colossalai.legacy.utils import get_dataloader - -from .registry import non_distributed_component_funcs - - -def get_cifar10_dataloader(train): - # build dataloaders - dataset = CIFAR10( - root=Path(os.environ["DATA"]), - download=True, - train=train, - transform=transforms.Compose( - [transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))] - ), - ) - dataloader = get_dataloader(dataset=dataset, shuffle=True, batch_size=16, drop_last=True) - return dataloader - - -@non_distributed_component_funcs.register(name="resnet18") -def get_resnet_training_components(): - def model_builder(checkpoint=False): - return resnet18(num_classes=10) - - trainloader = get_cifar10_dataloader(train=True) - testloader = get_cifar10_dataloader(train=False) - - criterion = torch.nn.CrossEntropyLoss() - return model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/simple_net.py b/tests/components_to_test/simple_net.py deleted file mode 100644 index 0f0ac5cff49a..000000000000 --- a/tests/components_to_test/simple_net.py +++ /dev/null @@ -1,53 +0,0 @@ -import torch -import torch.nn as nn - -from colossalai.legacy.nn import CheckpointModule -from colossalai.utils.cuda import get_current_device - -from .registry import non_distributed_component_funcs -from .utils.dummy_data_generator import DummyDataGenerator - - -class SimpleNet(CheckpointModule): - """ - In this no-leaf module, it has subordinate nn.modules and a nn.Parameter. - """ - - def __init__(self, checkpoint=False) -> None: - super().__init__(checkpoint=checkpoint) - self.embed = nn.Embedding(20, 4) - self.proj1 = nn.Linear(4, 8) - self.ln1 = nn.LayerNorm(8) - self.proj2 = nn.Linear(8, 4) - self.ln2 = nn.LayerNorm(4) - self.classifier = nn.Linear(4, 4) - - def forward(self, x): - x = self.embed(x) - x = self.proj1(x) - x = self.ln1(x) - x = self.proj2(x) - x = self.ln2(x) - x = self.classifier(x) - return x - - -class DummyDataLoader(DummyDataGenerator): - def generate(self): - data = torch.randint(low=0, high=20, size=(16,), device=get_current_device()) - label = torch.randint(low=0, high=2, size=(16,), device=get_current_device()) - return data, label - - -@non_distributed_component_funcs.register(name="simple_net") -def get_training_components(): - def model_builder(checkpoint=False): - return SimpleNet(checkpoint) - - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - - criterion = torch.nn.CrossEntropyLoss() - from colossalai.nn.optimizer import HybridAdam - - return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/utils/__init__.py b/tests/components_to_test/utils/__init__.py deleted file mode 100644 index 150124b58800..000000000000 --- a/tests/components_to_test/utils/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .dummy_data_generator import DummyDataGenerator -from .executor import run_fwd, run_fwd_bwd diff --git a/tests/components_to_test/utils/dummy_data_generator.py b/tests/components_to_test/utils/dummy_data_generator.py deleted file mode 100644 index 7b3af46c8f35..000000000000 --- a/tests/components_to_test/utils/dummy_data_generator.py +++ /dev/null @@ -1,24 +0,0 @@ -from abc import ABC, abstractmethod - - -class DummyDataGenerator(ABC): - def __init__(self, length=10): - self.length = length - - @abstractmethod - def generate(self): - pass - - def __iter__(self): - self.step = 0 - return self - - def __next__(self): - if self.step < self.length: - self.step += 1 - return self.generate() - else: - raise StopIteration - - def __len__(self): - return self.length diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py index c08fd365d871..62b9123b59b0 100644 --- a/tests/kit/model_zoo/__init__.py +++ b/tests/kit/model_zoo/__init__.py @@ -1,4 +1,5 @@ -from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers +from . import custom, diffusers, timm, torchaudio, torchrec, torchvision, transformers +from .executor import run_fwd, run_fwd_bwd from .registry import model_zoo -__all__ = ["model_zoo"] +__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd"] diff --git a/tests/kit/model_zoo/custom/__init__.py b/tests/kit/model_zoo/custom/__init__.py new file mode 100644 index 000000000000..1f8ac324d4d6 --- /dev/null +++ b/tests/kit/model_zoo/custom/__init__.py @@ -0,0 +1,4 @@ +from .hanging_param_model import * +from .nested_model import * +from .repeated_computed_layers import * +from .simple_net import * diff --git a/tests/kit/model_zoo/custom/base.py b/tests/kit/model_zoo/custom/base.py new file mode 100644 index 000000000000..4a0f505826f1 --- /dev/null +++ b/tests/kit/model_zoo/custom/base.py @@ -0,0 +1,26 @@ +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + + +class CheckpointModule(nn.Module): + def __init__(self, checkpoint: bool = False): + super().__init__() + self.checkpoint = checkpoint + self._use_checkpoint = checkpoint + + def _forward(self, *args, **kwargs): + raise NotImplementedError("CheckpointModule should implement _forward method instead of origin forward") + + def forward(self, *args, **kwargs): + if self._use_checkpoint: + return checkpoint(self._forward, *args, **kwargs) + else: + return self._forward(*args, **kwargs) + + def train(self, mode: bool = True): + self._use_checkpoint = self.checkpoint + return super().train(mode=mode) + + def eval(self): + self._use_checkpoint = False + return super().eval() diff --git a/tests/kit/model_zoo/custom/hanging_param_model.py b/tests/kit/model_zoo/custom/hanging_param_model.py new file mode 100644 index 000000000000..a8ace5f35e6a --- /dev/null +++ b/tests/kit/model_zoo/custom/hanging_param_model.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..registry import model_zoo +from .base import CheckpointModule + + +class HangingParamModule(CheckpointModule): + """ + Hanging Parameter: a parameter dose not belong to a leaf Module. + It has subordinate nn.modules and a nn.Parameter. + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.proj1 = nn.Linear(4, 8) + self.weight = nn.Parameter(torch.randn(8, 8)) + self.proj2 = nn.Linear(8, 4) + + def forward(self, x): + x = self.proj1(x) + x = F.linear(x, self.weight) + x = self.proj2(x) + return x + + +def data_gen(): + return dict(x=torch.rand(16, 4)) + + +def loss_fn(x): + outputs = x["x"] + label = torch.randint(low=0, high=2, size=(16,), device=outputs.device) + return F.cross_entropy(x["x"], label) + + +def output_transform(x: torch.Tensor): + return dict(x=x) + + +model_zoo.register( + name="custom_hanging_param_model", + model_fn=HangingParamModule, + data_gen_fn=data_gen, + output_transform_fn=output_transform, + loss_fn=loss_fn, +) diff --git a/tests/components_to_test/nested_model.py b/tests/kit/model_zoo/custom/nested_model.py similarity index 50% rename from tests/components_to_test/nested_model.py rename to tests/kit/model_zoo/custom/nested_model.py index 44577456dec5..2eb1c8398a29 100644 --- a/tests/components_to_test/nested_model.py +++ b/tests/kit/model_zoo/custom/nested_model.py @@ -2,10 +2,8 @@ import torch.nn as nn import torch.nn.functional as F -from colossalai.legacy.nn import CheckpointModule - -from .registry import non_distributed_component_funcs -from .utils import DummyDataGenerator +from ..registry import model_zoo +from .base import CheckpointModule class SubNet(nn.Module): @@ -32,20 +30,24 @@ def forward(self, x): return x -class DummyDataLoader(DummyDataGenerator): - def generate(self): - data = torch.rand(16, 5) - label = torch.randint(low=0, high=2, size=(16,)) - return data, label +def data_gen(): + return dict(x=torch.rand(16, 5)) + + +def loss_fn(x): + outputs = x["x"] + label = torch.randint(low=0, high=2, size=(16,), device=outputs.device) + return F.cross_entropy(x["x"], label) -@non_distributed_component_funcs.register(name="nested_model") -def get_training_components(): - def model_builder(checkpoint=False): - return NestedNet(checkpoint) +def output_transform(x: torch.Tensor): + return dict(x=x) - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - criterion = torch.nn.CrossEntropyLoss() - return model_builder, trainloader, testloader, torch.optim.Adam, criterion +model_zoo.register( + name="custom_nested_model", + model_fn=NestedNet, + data_gen_fn=data_gen, + output_transform_fn=output_transform, + loss_fn=loss_fn, +) diff --git a/tests/kit/model_zoo/custom/repeated_computed_layers.py b/tests/kit/model_zoo/custom/repeated_computed_layers.py new file mode 100644 index 000000000000..781fecc51427 --- /dev/null +++ b/tests/kit/model_zoo/custom/repeated_computed_layers.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..registry import model_zoo +from .base import CheckpointModule + + +class NetWithRepeatedlyComputedLayers(CheckpointModule): + """ + This model is to test with layers which go through forward pass multiple times. + In this model, the fc1 and fc2 call forward twice + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.fc1 = nn.Linear(5, 5) + self.fc2 = nn.Linear(5, 5) + self.fc3 = nn.Linear(5, 2) + self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3] + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def data_gen(): + return dict(x=torch.rand(16, 5)) + + +def loss_fn(x): + outputs = x["x"] + label = torch.randint(low=0, high=2, size=(16,), device=outputs.device) + return F.cross_entropy(x["x"], label) + + +def output_transform(x: torch.Tensor): + return dict(x=x) + + +model_zoo.register( + name="custom_repeated_computed_layers", + model_fn=NetWithRepeatedlyComputedLayers, + data_gen_fn=data_gen, + output_transform_fn=output_transform, + loss_fn=loss_fn, +) diff --git a/tests/kit/model_zoo/custom/simple_net.py b/tests/kit/model_zoo/custom/simple_net.py new file mode 100644 index 000000000000..ae68fccf9c61 --- /dev/null +++ b/tests/kit/model_zoo/custom/simple_net.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..registry import model_zoo +from .base import CheckpointModule + + +class SimpleNet(CheckpointModule): + """ + In this no-leaf module, it has subordinate nn.modules and a nn.Parameter. + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.embed = nn.Embedding(20, 4) + self.proj1 = nn.Linear(4, 8) + self.ln1 = nn.LayerNorm(8) + self.proj2 = nn.Linear(8, 4) + self.ln2 = nn.LayerNorm(4) + self.classifier = nn.Linear(4, 4) + + def forward(self, x): + x = self.embed(x) + x = self.proj1(x) + x = self.ln1(x) + x = self.proj2(x) + x = self.ln2(x) + x = self.classifier(x) + return x + + +def data_gen(): + return dict(x=torch.randint(low=0, high=20, size=(16,))) + + +def loss_fn(x): + outputs = x["x"] + label = torch.randint(low=0, high=2, size=(16,), device=outputs.device) + return F.cross_entropy(x["x"], label) + + +def output_transform(x: torch.Tensor): + return dict(x=x) + + +model_zoo.register( + name="custom_simple_net", + model_fn=SimpleNet, + data_gen_fn=data_gen, + output_transform_fn=output_transform, + loss_fn=loss_fn, +) diff --git a/tests/components_to_test/utils/executor.py b/tests/kit/model_zoo/executor.py similarity index 51% rename from tests/components_to_test/utils/executor.py rename to tests/kit/model_zoo/executor.py index 631401e022e6..033d6d12dd07 100644 --- a/tests/components_to_test/utils/executor.py +++ b/tests/kit/model_zoo/executor.py @@ -1,7 +1,15 @@ +from typing import Callable, Dict, Optional, Union + import torch +from torch.nn import Module +from torch.optim import Optimizer + +from colossalai.interface import OptimizerWrapper -def run_fwd(model, data, label, criterion) -> torch.Tensor: +def run_fwd( + model: Module, data: Dict, output_transform_fn: Callable, criterion: Optional[Callable] = None +) -> torch.Tensor: """run_fwd run fwd for the model @@ -14,18 +22,22 @@ def run_fwd(model, data, label, criterion) -> torch.Tensor: Returns: torch.Tensor: loss of fwd """ + outputs = model(**data) + outputs = output_transform_fn(outputs) if criterion: - y = model(data) - y = y.float() - loss = criterion(y, label) + loss = criterion(outputs) else: - loss = model(data, label) - - loss = loss.float() + loss = next(iter(outputs.values())).sum() return loss -def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor: +def run_fwd_bwd( + model: Module, + data: Dict, + output_transform_fn: Callable, + criterion: Optional[Callable] = None, + optimizer: Optional[Union[Optimizer, OptimizerWrapper]] = None, +) -> torch.Tensor: """run_fwd_bwd run fwd and bwd for the model @@ -38,7 +50,7 @@ def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor: Returns: torch.Tensor: loss of fwd """ - loss = run_fwd(model, data, label, criterion) + loss = run_fwd(model, data, output_transform_fn, criterion) if optimizer: optimizer.backward(loss) else: diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index 8b90a3c7372c..6dd3e102c20f 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -359,9 +359,9 @@ def data_gen_for_qa(): # define loss funciton loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss( - x.last_hidden_state, torch.ones_like(x.last_hidden_state) + x["last_hidden_state"], torch.ones_like(x["last_hidden_state"]) ) -loss_fn = lambda x: x.loss +loss_fn = lambda x: x["loss"] config = transformers.BertConfig( hidden_size=128, diff --git a/tests/kit/model_zoo/transformers/blip2.py b/tests/kit/model_zoo/transformers/blip2.py index 887b11c7f54e..0be9268307ce 100644 --- a/tests/kit/model_zoo/transformers/blip2.py +++ b/tests/kit/model_zoo/transformers/blip2.py @@ -35,7 +35,7 @@ def data_gen(): output_transform_fn = lambda x: x # define loss funciton -loss_fn_blip2_model = lambda x: x.loss +loss_fn_blip2_model = lambda x: x["loss"] config = transformers.Blip2Config() config.vision_config.patch_size = 14 diff --git a/tests/kit/model_zoo/transformers/bloom.py b/tests/kit/model_zoo/transformers/bloom.py index 12dcd71d5d1b..07f1d497777d 100644 --- a/tests/kit/model_zoo/transformers/bloom.py +++ b/tests/kit/model_zoo/transformers/bloom.py @@ -69,11 +69,11 @@ def data_gen_for_question_answering(): # define loss function loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss( - x.last_hidden_state, torch.ones_like(x.last_hidden_state) + x["last_hidden_state"], torch.ones_like(x["last_hidden_state"]) ) -loss_fn_for_causal_lm = lambda x: x.loss -loss_fn_for_classification = lambda x: x.loss -loss_fn_for_question_answering = lambda x: x.loss +loss_fn_for_causal_lm = lambda x: x["loss"] +loss_fn_for_classification = lambda x: x["loss"] +loss_fn_for_question_answering = lambda x: x["loss"] config = transformers.BloomConfig( n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, pad_token_id=50256 diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index f4369cb7d171..0b178d58ce33 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -30,9 +30,9 @@ def data_gen_for_conditional_generation(): # define loss function loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss( - x.last_hidden_state, torch.ones_like(x.last_hidden_state) + x["last_hidden_state"], torch.ones_like(x["last_hidden_state"]) ) -loss_fn = lambda x: x.loss +loss_fn = lambda x: x["loss"] config = ChatGLMConfig( num_layers=2, diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 2af6176fbe4a..5e98c02fd4fc 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -87,13 +87,14 @@ def date_gen_for_double_heads(): # define loss function loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss( - x.last_hidden_state, torch.ones_like(x.last_hidden_state) + x["last_hidden_state"], torch.ones_like(x["last_hidden_state"]) ) -loss_fn = lambda x: x.loss +loss_fn = lambda x: x["loss"] config = transformers.GPT2Config( n_layer=2, n_head=4, + n_embd=128, vocab_size=50258, attn_pdrop=0, embd_pdrop=0, diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index bc229b17e08c..041de6b90f8d 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -42,9 +42,9 @@ def data_gen_for_casual_lm(): output_transform_fn = lambda x: x # function to get the loss - loss_fn = lambda output: output.last_hidden_state.mean() - loss_fn_for_casual_lm = lambda output: output.loss - loss_fn_for_seq_classification = lambda output: output.logits.mean() + loss_fn = lambda output: output["last_hidden_state"].mean() + loss_fn_for_casual_lm = lambda output: output["loss"] + loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = LlamaConfig( num_hidden_layers=4, diff --git a/tests/kit/model_zoo/transformers/opt.py b/tests/kit/model_zoo/transformers/opt.py index 07ca41ef21ae..2da94a4fcc0f 100644 --- a/tests/kit/model_zoo/transformers/opt.py +++ b/tests/kit/model_zoo/transformers/opt.py @@ -45,9 +45,9 @@ def data_gen_for_question_answering(): output_transform_fn = lambda x: x loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss( - x.last_hidden_state, torch.ones_like(x.last_hidden_state) + x["last_hidden_state"], torch.ones_like(x["last_hidden_state"]) ) -loss_fn_for_lm = lambda x: x.loss +loss_fn_for_lm = lambda x: x["loss"] config = transformers.OPTConfig( hidden_size=128, num_hidden_layers=2, diff --git a/tests/kit/model_zoo/transformers/sam.py b/tests/kit/model_zoo/transformers/sam.py index b928a8f14e75..7e756abe91b8 100644 --- a/tests/kit/model_zoo/transformers/sam.py +++ b/tests/kit/model_zoo/transformers/sam.py @@ -40,7 +40,7 @@ def data_gen(): output_transform_fn = lambda x: x # define loss funciton -loss_fn = lambda x: x.iou_scores.mean() +loss_fn = lambda x: x["iou_scores"].mean() config = transformers.SamConfig() config.vision_config.num_hidden_layers = 2 diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index 1b63cccc42ee..2ccfb0356c2b 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -44,9 +44,9 @@ def data_gen_for_t5_model(): output_transform_fn = lambda x: x # define loss function -loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean() -loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean() -loss_fn_for_conditional_generation = lambda x: x.loss +loss_fn_for_t5_model = lambda x: x["last_hidden_state"].mean() +loss_fn_for_encoder_only = lambda x: x["last_hidden_state"].mean() +loss_fn_for_conditional_generation = lambda x: x["loss"] # define model config config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0) diff --git a/tests/kit/model_zoo/transformers/vit.py b/tests/kit/model_zoo/transformers/vit.py index f1990751b016..223559d73a55 100644 --- a/tests/kit/model_zoo/transformers/vit.py +++ b/tests/kit/model_zoo/transformers/vit.py @@ -34,9 +34,9 @@ def data_gen_for_masked_image_modeling(): output_transform_fn = lambda x: x # function to get the loss -loss_fn_for_vit_model = lambda x: x.pooler_output.mean() -loss_fn_for_image_classification = lambda x: x.logits.mean() -loss_fn_for_masked_image_modeling = lambda x: x.loss +loss_fn_for_vit_model = lambda x: x["pooler_output"].mean() +loss_fn_for_image_classification = lambda x: x["logits"].mean() +loss_fn_for_masked_image_modeling = lambda x: x["loss"] # register the following models # transformers.ViTModel, diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py index 928be4468c01..d69bebe6cc04 100644 --- a/tests/kit/model_zoo/transformers/whisper.py +++ b/tests/kit/model_zoo/transformers/whisper.py @@ -53,8 +53,8 @@ def data_gen_for_audio_classification(): output_transform_fn = lambda x: x # define loss funciton -loss_fn = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state)) -loss_fn_attr = lambda x: x.loss +loss_fn = lambda x: torch.nn.functional.mse_loss(x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])) +loss_fn_attr = lambda x: x["loss"] config = transformers.WhisperConfig( classifier_proj_size=256, diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 9cc12f96bd4d..104ca254c572 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -14,6 +14,8 @@ _AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"] # These models have no parameters _LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"] +# These models will cause stuck, to be fixed +_STUCK_MODELS = ["transformers_albert_for_multiple_choice"] def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: @@ -53,7 +55,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): """ passed_models = [] failed_info = {} # (model_name, error) pair - ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS skipped_models = [] for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py index 47cf9e78d138..ad8e32b48bae 100644 --- a/tests/test_infer/test_pipeline_infer.py +++ b/tests/test_infer/test_pipeline_infer.py @@ -1,9 +1,6 @@ -from copy import deepcopy - import pytest import torch import torch.distributed as dist -import torch.nn as nn import transformers import colossalai @@ -20,27 +17,29 @@ def data_gen(): inputs = data_gen() for k, v in inputs.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: new_shape = [1] * v.dim() new_shape[0] = 16 - inputs[k] = v.to('cuda').repeat(*new_shape) + inputs[k] = v.to("cuda").repeat(*new_shape) def pipeline_inference_test(pp_size, new_length, micro_batch_size): model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8)) - engine = PPInferEngine(pp_size=pp_size, - model=model, - model_policy=GPT2LMHeadModelPipelinePolicy(), - new_length=new_length, - micro_batch_size=micro_batch_size) + engine = PPInferEngine( + pp_size=pp_size, + model=model, + model_policy=GPT2LMHeadModelPipelinePolicy(), + new_length=new_length, + micro_batch_size=micro_batch_size, + ) output = engine.inference([inputs]) if dist.get_rank() == 0: assert len(output[0]) == new_length, f"{len(output)}, {new_length}" -@parameterize('pp_size', [4]) -@parameterize('new_length', [4, 8, 16]) -@parameterize('micro_batch_size', [1, 4]) +@parameterize("pp_size", [4]) +@parameterize("new_length", [4, 8, 16]) +@parameterize("micro_batch_size", [1, 4]) @clear_cache_before_run() def run_pipeline_inference_test(pp_size, new_length, micro_batch_size): pipeline_inference_test(pp_size, new_length, micro_batch_size) @@ -48,7 +47,7 @@ def run_pipeline_inference_test(pp_size, new_length, micro_batch_size): def check_pipeline_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_pipeline_inference_test() @@ -59,5 +58,5 @@ def test_pipeline_inference(): spawn(check_pipeline_inference, nprocs=4) -if __name__ == '__main__': +if __name__ == "__main__": test_pipeline_inference() diff --git a/tests/test_infer_ops/triton/test_llama2_token_attn.py b/tests/test_infer_ops/triton/test_llama2_token_attn.py deleted file mode 100644 index 0537a3d76129..000000000000 --- a/tests/test_infer_ops/triton/test_llama2_token_attn.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): - xq = xq.view(bs, 1, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - xv = xv.view(bs, seqlen, num_head, head_dim) - - logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5) - prob = torch.softmax(logics, dim=1) - prob = prob.view(bs, seqlen, num_head, 1) - - return torch.sum(prob * xv, dim=1, keepdim=False) - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test(): - Z, head_num, seq_len, head_dim = 2, 32, 2048, 128 - dtype = torch.float16 - - # attn out: 2,4096 - q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda") - max_kv_cache_len = seq_len - kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") - kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - other_kv_index = 2048 - - kv_cache_seq_len[:] = seq_len - kv_cache_start_loc[0] = 0 - kv_cache_start_loc[1] = seq_len - - for i in range(Z): - kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda") - - Llama2TokenAttentionForwards.token_attn( - q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, other_kv_index - ) - torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim) - assert torch.allclose(torch_out, o, atol=1e-3, rtol=0) - - -if __name__ == "__main__": - test() diff --git a/tests/test_infer_ops/triton/test_token_attn_1.py b/tests/test_infer_ops/triton/test_token_attn_1.py deleted file mode 100644 index fc5f8cd6c9dc..000000000000 --- a/tests/test_infer_ops/triton/test_token_attn_1.py +++ /dev/null @@ -1,74 +0,0 @@ -import math - -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1 - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -def torch_attn(xq, xk, bs, seqlen, num_head, head_dim): - xq = xq.view(bs, 1, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - keys = xk - xq = xq.transpose(1, 2) - keys = keys.transpose(1, 2) - scores = ( - (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1) - ) - return scores - - -def torch_attn_1(xq, xk, seqlen, num_head, head_dim): - xq = xq.view(1, num_head, head_dim) - xk = xk.view(seqlen, num_head, head_dim) - logics = torch.sum(xq * xk, dim=-1, keepdim=False) - - logics = logics.transpose(0, 1) / math.sqrt(head_dim) - return logics - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_attn_1(): - pass - - batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 - - dtype = torch.float16 - - q = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda") - - b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda") - kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - - for i in range(batch_size): - kv_cache_start_loc[i] = i * seq_len - kv_cache_seq_len[i] = seq_len - b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") - - token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) - - torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze() - o = attn_out.squeeze() - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) - - -if __name__ == "__main__": - test_attn_1() diff --git a/tests/test_infer_ops/triton/test_token_attn_2.py b/tests/test_infer_ops/triton/test_token_attn_2.py deleted file mode 100644 index 2dd756f2ba91..000000000000 --- a/tests/test_infer_ops/triton/test_token_attn_2.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2 - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -def torch_attn(V, P, bs, seqlen, num_head, head_dim): - V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2) - P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1) - attn_out = torch.matmul(P, V) - - return attn_out - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_token_attn_2(): - pass - - batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 - dtype = torch.float16 - - V = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) - Prob = ( - torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda") - .normal_(mean=0.4, std=0.2) - .reshape(head_num, batch_size, seq_len) - .softmax(-1) - .reshape(head_num, batch_size * seq_len) - ) - attn_out = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda") - - kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - kv_cache_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda") - for i in range(batch_size): - kv_cache_start_loc[i] = i * seq_len - kv_cache_seq_len[i] = seq_len - kv_cache_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") - - token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) - - torch_out = torch_attn(V, Prob, batch_size, seq_len, head_num, head_dim).squeeze() - o = attn_out - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) - - -if __name__ == "__main__": - test_token_attn_2() diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py index 9c7a53798317..a7fc3d29b77a 100644 --- a/tests/test_infer_ops/triton/test_token_attn_fwd.py +++ b/tests/test_infer_ops/triton/test_token_attn_fwd.py @@ -3,16 +3,13 @@ from packaging import version try: - pass - from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd - HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6") def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): diff --git a/tests/test_legacy/test_amp/test_naive_fp16.py b/tests/test_legacy/test_amp/test_naive_fp16.py index 76f9ff07407f..fe16bc4d480a 100644 --- a/tests/test_legacy/test_amp/test_naive_fp16.py +++ b/tests/test_legacy/test_amp/test_naive_fp16.py @@ -6,7 +6,7 @@ import colossalai from colossalai.legacy.amp import convert_to_apex_amp, convert_to_naive_amp from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo def check_equal(a, b): @@ -25,13 +25,12 @@ def run_naive_amp(): torch.backends.cudnn.deterministic = True # create layer - test_models = ["repeated_computed_layers", "nested_model", "resnet18"] + test_models = ["custom_repeated_computed_layers", "custom_nested_model", "torchvision_resnet18"] for test_name in test_models: - get_component_func = non_distributed_component_funcs.get_callable(test_name) - model_builder, train_dataloader, _, optim_class, _ = get_component_func() + model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(test_name).values())) # create model - naive_amp_model = model_builder(checkpoint=True).cuda() + naive_amp_model = model_builder().cuda() apex_amp_model = copy.deepcopy(naive_amp_model) # create optimizer @@ -48,13 +47,12 @@ def run_naive_amp(): apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) # create data - data_iter = iter(train_dataloader) - data, label = next(data_iter) - data = data.cuda() + data = data_gen_fn() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} # forward pass - naive_amp_output = naive_amp_model(data) - apex_amp_output = apex_amp_model(data) + naive_amp_output = naive_amp_model(**data) + apex_amp_output = apex_amp_model(**data) assert_close_loose(naive_amp_output, apex_amp_output) # backward diff --git a/tests/test_legacy/test_amp/test_torch_fp16.py b/tests/test_legacy/test_amp/test_torch_fp16.py index 47b303745e4e..5e2e1ede5725 100644 --- a/tests/test_legacy/test_amp/test_torch_fp16.py +++ b/tests/test_legacy/test_amp/test_torch_fp16.py @@ -6,7 +6,7 @@ import colossalai from colossalai.legacy.amp import convert_to_apex_amp, convert_to_torch_amp from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo def run_torch_amp(): @@ -18,13 +18,12 @@ def run_torch_amp(): torch.backends.cudnn.deterministic = True # create layer - test_models = ["resnet18", "simple_net"] + test_models = ["torchvision_resnet18", "custom_simple_net"] for test_name in test_models: - get_component_func = non_distributed_component_funcs.get_callable(test_name) - model_builder, train_dataloader, _, optim_class, _ = get_component_func() + model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(test_name).values())) # create model - torch_amp_model = model_builder(checkpoint=True).cuda() + torch_amp_model = model_builder().cuda() apex_amp_model = copy.deepcopy(torch_amp_model) # create optimizer @@ -41,13 +40,12 @@ def run_torch_amp(): apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) # create data - data_iter = iter(train_dataloader) - data, label = next(data_iter) - data = data.cuda() + data = data_gen_fn() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} # forward pass - torch_amp_output = torch_amp_model(data) - apex_amp_output = apex_amp_model(data) + torch_amp_output = torch_amp_model(**data) + apex_amp_output = apex_amp_model(**data) assert_close_loose(torch_amp_output, apex_amp_output) for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()): diff --git a/tests/test_legacy/test_engine/test_engine.py b/tests/test_legacy/test_engine/test_engine.py index b07fe8abe86e..1bb0b49c5362 100644 --- a/tests/test_legacy/test_engine/test_engine.py +++ b/tests/test_legacy/test_engine/test_engine.py @@ -1,10 +1,11 @@ import pytest +import torch import colossalai from colossalai.legacy.amp import AMP_TYPE from colossalai.legacy.core import global_context as gpc -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from tests.components_to_test.registry import non_distributed_component_funcs +from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo CONFIG = dict( parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), fp16=dict(mode=None), clip_grad_norm=1.0 @@ -15,29 +16,29 @@ @parameterize("amp_mode", [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None]) def run_train(model_name, amp_mode): # FIXME: test bert - get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values())) + train_dataloader = DummyDataloader(data_gen_fn) + criterion = lambda x: x.sum() gpc.config.fp16["mode"] = amp_mode - model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - model = model_builder(checkpoint=False) + model = model_builder() engine, train_dataloader, *args = colossalai.legacy.initialize( model=model, - optimizer=optimizer_class(model.parameters(), lr=1e-3), + optimizer=torch.optim.Adam(model.parameters(), lr=1e-3), criterion=criterion, train_dataloader=train_dataloader, ) try: engine.train() - for data, label in train_dataloader: + for data in train_dataloader: engine.zero_grad() - data = data.cuda() - label = label.cuda() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} if criterion: - output = engine(data) - loss = engine.criterion(output, label) + output = engine(**data) + loss = engine.criterion(output) else: - loss = engine(data, label) + loss = engine(**data) engine.backward(loss) engine.step() break diff --git a/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py index d19b12a5b044..d75ddbff7cf3 100644 --- a/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py +++ b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py @@ -5,9 +5,9 @@ from colossalai.legacy.amp.amp_type import AMP_TYPE from colossalai.legacy.trainer import Trainer from colossalai.logging import get_dist_logger -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import MultiTimer -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo BATCH_SIZE = 4 IMG_SIZE = 32 @@ -16,12 +16,14 @@ CONFIG = dict(fp16=dict(mode=AMP_TYPE.TORCH)) -@parameterize("model_name", ["repeated_computed_layers", "resnet18", "nested_model"]) +@parameterize("model_name", ["custom_repeated_computed_layers", "torchvision_resnet18", "custom_nested_model"]) def run_trainer(model_name): - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values())) model = model_builder() - optimizer = optimizer_class(model.parameters(), lr=1e-3) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + train_dataloader = DummyDataloader(data_gen_fn) + test_dataloader = DummyDataloader(data_gen_fn) + criterion = lambda x: x.sum() engine, train_dataloader, *_ = colossalai.legacy.initialize( model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader ) diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py index 8131ea3234d8..6bbe3e4e8172 100644 --- a/tests/test_optimizer/test_adam_kernel.py +++ b/tests/test_optimizer/test_adam_kernel.py @@ -13,9 +13,7 @@ _FUSED_ALLOWED_P_G_TYPES = [ (torch.float, torch.half), (torch.float, torch.float), - (torch.half, torch.float), (torch.half, torch.half), - (torch.bfloat16, torch.float), (torch.float, torch.bfloat16), (torch.bfloat16, torch.bfloat16), ] @@ -23,7 +21,6 @@ _CPU_ALLOWED_P_G_TYPES = [ (torch.float, torch.half), (torch.float, torch.float), - (torch.half, torch.float), (torch.half, torch.half), ] @@ -138,8 +135,8 @@ def check_adam_kernel( master_exp_avg_sq = torch.zeros_like(master_p) p = master_p.clone().to(p_dtype) g = master_g.clone().to(g_dtype) - exp_avg = master_exp_avg.clone() - exp_avg_sq = master_exp_avg_sq.clone() + exp_avg = master_exp_avg.clone().to(p_dtype) + exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype) for step in range(1, 1 + n_steps): torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq) diff --git a/tests/test_optimizer/test_adam_optim.py b/tests/test_optimizer/test_adam_optim.py index 59b40a0afa3c..68d71e3c4194 100644 --- a/tests/test_optimizer/test_adam_optim.py +++ b/tests/test_optimizer/test_adam_optim.py @@ -21,8 +21,6 @@ (torch.float, torch.float), # pure fp32 (torch.float, torch.half), # fp16 amp (torch.float, torch.bfloat16), # bfloat16 amp - # (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16 - # (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16 ] N_STEPS = 3 diff --git a/tests/test_optimizer/test_nvme.py b/tests/test_optimizer/test_nvme.py index a68a9c51855f..4ff16bb9b7c9 100644 --- a/tests/test_optimizer/test_nvme.py +++ b/tests/test_optimizer/test_nvme.py @@ -2,7 +2,7 @@ from colossalai.nn.optimizer import CPUAdam, HybridAdam from colossalai.testing import clear_cache_before_run, parameterize -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo def move_some_params_to_cuda(model, torch_model): @@ -22,8 +22,7 @@ def check_params_equal(model, torch_model): @parameterize("nvme_offload_dir", ["./offload", None]) @parameterize("adam_cls", [CPUAdam, HybridAdam]) def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls): - get_components_func = non_distributed_component_funcs.get_callable("simple_net") - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry("custom_simple_net").values())) model = model_builder() torch_model = model_builder() move_some_params_to_cuda(model, torch_model) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 66d77b48aa0c..6acbe4ff523d 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup from torch.nn import Module from torch.optim import Adam, Optimizer +from torch.testing import assert_close from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin @@ -160,7 +161,7 @@ def _criterion(outputs, inputs): input_shape = data["input_ids"].shape for k, v in data.items(): if v.shape == input_shape: - data[k] = v.repeat((1, ) * (v.dim() - 1) + (times,)) + data[k] = v.repeat((1,) * (v.dim() - 1) + (times,)) sharded_model.train() if booster.plugin.stage_manager is not None: @@ -207,15 +208,11 @@ def check_output_hidden_state( else: sharded_hidden_state = sharded_output.last_hidden_state - assert torch.allclose( - org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol - ), f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" + assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol) def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3): - assert torch.allclose( - org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol - ), f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}" + assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol) def check_weight( @@ -242,9 +239,7 @@ def check_weight( if verbose and dist.get_rank() == 0: print(f"'{suffix}' weight: {org_weight}, {sharded_weight}") - assert torch.allclose( - org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol - ), f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}" + assert_close(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol) def get_grad_tensors_for_check( @@ -310,9 +305,7 @@ def check_grad( if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") - assert torch.allclose( - org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol - ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" + assert_close(org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol) def unwrap_model( @@ -337,6 +330,4 @@ def check_all_grad_tensors(check_tensors): shard_grad = check_info["shard_grad"] rtol = check_info["rtol"] atol = check_info["atol"] - assert torch.allclose( - org_grad, shard_grad, atol=atol, rtol=rtol - ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" + assert_close(org_grad, shard_grad, atol=atol, rtol=rtol) diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 1c934bd22340..3a8af2d6d481 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -43,7 +43,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config["precision"] == "fp32": - atol, rtol = 1e-5, 1e-3 + atol, rtol = 2e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 row_layer_grads = get_grad_tensors_for_check( @@ -62,7 +62,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config["precision"] == "fp32": - atol, rtol = 1e-5, 1e-3 + atol, rtol = 2e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 @@ -154,15 +154,6 @@ def run_vit_test(test_config): "precision": "fp32", "initial_scale": 1, }, - { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": False, - "use_lazy_init": False, - "precision": "fp32", - "initial_scale": 1, - }, ], ) def run_vit_3d_test(test_config): diff --git a/tests/test_smoothquant/test_llama_attention.py b/tests/test_smoothquant/test_llama_attention.py new file mode 100644 index 000000000000..f8c79145c952 --- /dev/null +++ b/tests/test_smoothquant/test_llama_attention.py @@ -0,0 +1,136 @@ +import pytest +import torch +from packaging import version + +try: + from colossalai.kernel.triton import int8_rotary_embedding_fwd + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +try: + from colossalai.inference.quant.smoothquant.models import LLamaSmoothquantAttention + + HAS_TORCH_INT = True +except ImportError: + HAS_TORCH_INT = False + print("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int") + + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + +import math + +import torch +from torch.nn import functional as F + + +def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): + """ + adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 + """ + xq = xq.view(bs, seqlen, num_head, head_dim) + xk = xk.view(bs, seqlen, num_head, head_dim) + xv = xv.view(bs, seqlen, num_head, head_dim) + mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() + mask[mask == 0.0] = -100000000.0 + mask = mask.repeat(bs, num_head, 1, 1) + keys = xk + values = xv + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) + scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq) + output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) + + return output + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_TORCH_INT, + reason="triton requires cuda version to be higher than 11.4 or not install torch_int", +) +def test_llama_context_attention(): + head_num = 2 + seq_len = 32 + head_dim = 64 + dtype = torch.float + hidden_size = head_num * head_dim + + smooth_attn = LLamaSmoothquantAttention(head_num * head_dim, head_num) + + smooth_attn.q_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.out_proj.weight[:, 1:hidden_size] = torch.zeros(hidden_size - 1, device="cuda").to(torch.int8) + + qkv_weight_scale = 1.0 + + ones = torch.ones(hidden_size, hidden_size, dtype=torch.float, device="cuda") + + smooth_attn = smooth_attn.to("cuda") + + input = torch.randint(-20, 20, (1, seq_len, head_num * head_dim), dtype=torch.int8, device="cuda") + input_scale = 1 / 20.0 + + output = torch.matmul(input.to(torch.float) * input_scale, ones) + qkv_max_out = torch.max(torch.abs(output)) / 127 + smooth_attn.q_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out) + smooth_attn.k_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out) + smooth_attn.v_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out) + + q = smooth_attn.q_proj(input) + k = smooth_attn.k_proj(input) + v = smooth_attn.v_proj(input) + + cos_shape = (seq_len, head_dim // 2) + cos = torch.ones(cos_shape, dtype=dtype, device="cuda") + sin = torch.zeros(cos_shape, dtype=dtype, device="cuda") + in_scale = torch.tensor([qkv_max_out], device="cuda") + out_scale = torch.tensor([qkv_max_out], device="cuda") + int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item()) + int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item()) + + q = q.to(torch.float) * out_scale + k = k.to(torch.float) * out_scale + v = v.to(torch.float) * out_scale + torch_out = torch_context_attention(q.clone(), k.clone(), v.clone(), 1, seq_len, head_num, head_dim) + attn_out_max = torch.max(torch.abs(torch_out)) / 127 + + output = torch.matmul(torch_out.view(-1, seq_len, head_num * head_dim), ones) + smooth_attn.q_output_scale = torch.tensor(qkv_max_out) + smooth_attn.k_output_scale = torch.tensor(qkv_max_out) + + smooth_attn.v_output_scale = torch.tensor(qkv_max_out) + smooth_attn.q_rotary_output_scale = torch.tensor(qkv_max_out) + smooth_attn.k_rotary_output_scale = torch.tensor(qkv_max_out) + + smooth_attn.attn_output_scale = torch.tensor(attn_out_max) + smooth_attn.out_proj.a = torch.tensor([attn_out_max]) + + torch_out = ( + (torch_out / smooth_attn.attn_output_scale) + .round() + .clamp(-128, 127) + .to(torch.int8) + .view(-1, seq_len, head_num * head_dim) + ) + + torch_out = smooth_attn.out_proj(torch_out) + torch_out = torch_out.to(torch.float) + + smooth_attn = smooth_attn.to("cuda") + smooth_out, _, _ = smooth_attn(input, (cos, sin)) + smooth_out = smooth_out.to(torch.float) + + assert torch.allclose( + torch_out.cpu(), smooth_out.cpu(), rtol=1e-1, atol=1e-1 + ), "outputs from triton and torch are not matched" + + +if __name__ == "__main__": + test_llama_context_attention() diff --git a/tests/test_smoothquant/test_llama_mlp.py b/tests/test_smoothquant/test_llama_mlp.py new file mode 100644 index 000000000000..236edb10cb7f --- /dev/null +++ b/tests/test_smoothquant/test_llama_mlp.py @@ -0,0 +1,84 @@ +import warnings + +import pytest +import torch +from packaging import version + +try: + from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + + smoothquant_cuda = SmoothquantBuilder().load() + HAS_SMOOTHQUANT_CUDA = True +except: + warnings.warn("CUDA smoothquant linear is not installed") + HAS_SMOOTHQUANT_CUDA = False + + +try: + from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP + + HAS_TORCH_INT = True +except: + HAS_TORCH_INT = False + warnings.warn("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int") + + +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +def torch_llama_mlp(gate_proj, up_proj, down_proj, x): + gate_out = torch.mm(x, gate_proj) + silu = torch.nn.SiLU() + gate_out = silu(gate_out) + up_out = torch.mm(x, up_proj) + + o_out = gate_out * up_out + + max_up = torch.max(torch.abs(o_out)) + min_up = torch.min(torch.abs(o_out)) + + torch_out = torch.mm(o_out, down_proj) + + return (torch_out, max_up, min_up) + + +@pytest.mark.skipif( + not CUDA_SUPPORT or not HAS_SMOOTHQUANT_CUDA or not HAS_TORCH_INT, + reason="smoothquant linear not installed properly or not install torch_int", +) +def test_llama_mlp(): + hidden_size = 256 + intermediate_size = 512 + + smooth_mlp = LlamaSmoothquantMLP(intermediate_size, hidden_size) + + smooth_mlp.gate_proj.weight = torch.ones((intermediate_size, hidden_size), dtype=torch.int8, device="cuda") + + smooth_mlp.up_proj.weight = torch.randint( + -10, 10, (intermediate_size, hidden_size), dtype=torch.int8, device="cuda" + ) + smooth_mlp.down_proj.weight = torch.randint( + -10, 10, (hidden_size, intermediate_size), dtype=torch.int8, device="cuda" + ) + + x = torch.ones((1, 256), dtype=torch.int8, device="cuda") + + torch_out, max_inter, min_inter = torch_llama_mlp( + smooth_mlp.gate_proj.weight.transpose(0, 1).to(torch.float) / hidden_size, + smooth_mlp.up_proj.weight.transpose(0, 1).to(torch.float) / 127, + smooth_mlp.down_proj.weight.transpose(0, 1).to(torch.float) / 127, + x.to(torch.float), + ) + + smooth_mlp.down_proj_input_scale = torch.tensor(max_inter.item() / 127) + smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size) + smooth_mlp.up_proj.a = torch.tensor(1 / 127) + smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127)) + + smooth_out = smooth_mlp(x) + + assert torch.allclose(torch_out, smooth_out, rtol=1e-02, atol=1e-01) + + +if __name__ == "__main__": + test_llama_mlp() diff --git a/tests/test_smoothquant/test_smoothquant_linear.py b/tests/test_smoothquant/test_smoothquant_linear.py new file mode 100644 index 000000000000..58a0b82f6759 --- /dev/null +++ b/tests/test_smoothquant/test_smoothquant_linear.py @@ -0,0 +1,39 @@ +import warnings + +import pytest +import torch + +try: + from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + + smoothquant_cuda = SmoothquantBuilder().load() + HAS_SMOOTHQUANT_CUDA = True +except: + warnings.warn("CUDA smoothquant linear is not installed") + HAS_SMOOTHQUANT_CUDA = False + + +@pytest.mark.skipif( + not HAS_SMOOTHQUANT_CUDA, + reason="smoothquant linear not installed properly", +) +def test_linear(): + a = torch.randint(-127, 127, (128, 512), dtype=torch.int8, device="cuda") + b = torch.randint(-127, 127, (512, 256), dtype=torch.int8, device="cuda") + c = torch.rand(256, dtype=torch.float, device="cuda") + + alpha = 1 / 127 + beta = 1.0 + torch_out = torch.mm(a.to(torch.float) * alpha, b.to(torch.float)) + c + + silu = torch.nn.SiLU() + torch_out = silu(torch_out) + + b = b.transpose(0, 1).contiguous() + cuda_out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(a, b, c, alpha, beta) + + assert torch.allclose(torch_out, cuda_out, rtol=1e-02, atol=1e-02) + + +if __name__ == "__main__": + test_linear() diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_smoothquant/test_sq_rotary_embedding.py similarity index 73% rename from tests/test_infer_ops/triton/test_rotary_embedding.py rename to tests/test_smoothquant/test_sq_rotary_embedding.py index 7e05ccafbfc4..4cc76f00474d 100644 --- a/tests/test_infer_ops/triton/test_rotary_embedding.py +++ b/tests/test_smoothquant/test_sq_rotary_embedding.py @@ -6,9 +6,7 @@ from packaging import version try: - pass - - from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd + from colossalai.kernel.triton import int8_rotary_embedding_fwd HAS_TRITON = True except ImportError: @@ -36,7 +34,7 @@ def test_rotary_emb(): SEQ_LEN = 1 HEAD_NUM = 32 HEAD_DIM = 128 - dtype = torch.half + dtype = torch.float # create data x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM) x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") @@ -45,10 +43,16 @@ def test_rotary_emb(): sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") # forward pass y_torch = torch_rotary_emb(x, cos, sin) - rotary_embedding_fwd(x, cos, sin) - y_triton = x - # compare - assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=0) + + input_scale = torch.max(torch.abs(x)) / 127 + output_scale = torch.max(torch.abs(y_torch)) / 127 + + x = x / input_scale + x = x.to(torch.int8) + + int8_rotary_embedding_fwd(x, cos, sin, input_scale.item(), output_scale.item()) + y_triton = x.to(torch.float) * output_scale + assert torch.allclose(y_triton, y_torch, atol=2e-1, rtol=1e-2, equal_nan=True) if __name__ == "__main__": diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 2fb2bcbc851a..b8d3f45e0f34 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -12,8 +12,7 @@ from colossalai.utils.cuda import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration -from tests.components_to_test import run_fwd_bwd -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo, run_fwd_bwd PLACEMENT_CONFIGS = [ {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 @@ -38,7 +37,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("keep_gather", [False, True]) -@parameterize("model_name", ["gpt2", "bert"]) +@parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("use_grad_checkpoint", [False, True]) @parameterize("master_weights", [False, True]) def exam_gpt_fwd_bwd( @@ -49,17 +48,22 @@ def exam_gpt_fwd_bwd( master_weights: bool = True, ): init_device = get_current_device() - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) set_seed(42) - model = model_builder(use_grad_checkpoint) + model = model_builder() set_seed(42) - torch_model = model_builder(use_grad_checkpoint).cuda() + torch_model = model_builder().cuda() for torch_p, p in zip(torch_model.parameters(), model.parameters()): torch_p.data.copy_(p.data) + if use_grad_checkpoint: + model.gradient_checkpointing_enable() + torch_model.gradient_checkpointing_enable() + world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]["chunk_size"] = 5000 @@ -77,25 +81,22 @@ def exam_gpt_fwd_bwd( torch_model = DDP(torch_model, device_ids=[rank]) set_seed(rank) - for i, (input_ids, label) in enumerate(train_dataloader): - # you can only test a single fwd + bwd. - # after bwd param is grad for Gemini, due to the chunk reuse optimization. - if i > 0: - break - input_ids, label = input_ids.cuda(), label.cuda() - torch_optim.zero_grad() - zero_optim.zero_grad() + data = data_gen_fn() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} + + torch_optim.zero_grad() + zero_optim.zero_grad() - # set random seed is same as torch_model.eval() - set_seed(42) - torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) - set_seed(42) - loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + # set random seed is same as torch_model.eval() + set_seed(42) + torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim) + set_seed(42) + loss = run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim) - assert torch.equal(torch_loss, loss) + assert_close(torch_loss.float(), loss.float()) - check_grad(model, torch_model) + check_grad(model, torch_model) def run_dist(rank, world_size, port): diff --git a/tests/test_zero/test_gemini/test_gemini_use_rmt.py b/tests/test_zero/test_gemini/test_gemini_use_rmt.py index 2fa2d50a6caa..90ad62d1ac78 100644 --- a/tests/test_zero/test_gemini/test_gemini_use_rmt.py +++ b/tests/test_zero/test_gemini/test_gemini_use_rmt.py @@ -3,38 +3,34 @@ import torch.distributed as dist import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from colossalai.zero import GeminiDDP from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer -from tests.components_to_test import run_fwd_bwd -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo, run_fwd_bwd # run gemini use the runtime memory tracer @parameterize("placement_policy", ["auto"]) @parameterize("keep_gather", [False]) -@parameterize("model_name", ["repeated_computed_layers", "bert", "albert", "gpt2"]) +@parameterize("model_name", ["transformers_bert_for_sequence_classification"]) @parameterize("use_grad_checkpoint", [False, True]) def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False): set_seed(42) - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values())) - model = model_builder(use_grad_checkpoint).cuda() + model = model_builder().cuda() + if use_grad_checkpoint: + model.gradient_checkpointing_enable() print(f"model_name {model_name}") - runtime_mem_tracer = RuntimeMemTracer(model) - for i, (input_ids, label) in enumerate(train_dataloader): - if i > 0: - break - input_ids, label = input_ids.cuda(), label.cuda() - # mem tracing - if i == 0: - run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer) + runtime_mem_tracer = RuntimeMemTracer(model) + data = data_gen_fn() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} + run_fwd_bwd(runtime_mem_tracer, data, output_transform_fn, optimizer=runtime_mem_tracer) memstats = runtime_mem_tracer.memstats() runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list print("runtime tracer non model data points: ", len(runtime_tracer_non_model_data)) @@ -62,16 +58,17 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ ) set_seed(dist.get_rank()) - for i, (input_ids, label) in enumerate(train_dataloader): + train_dataloader = DummyDataloader(data_gen_fn) + for i, data in enumerate(train_dataloader): # you can only test a single fwd + bwd. # after bwd param is grad for Gemini, due to the chunk reuse optimization. # print(f'iteration {i}') if i > 4: break - input_ids, label = input_ids.cuda(), label.cuda() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} set_seed(42) - run_fwd_bwd(model, input_ids, label, criterion, model) + run_fwd_bwd(model, data, output_transform_fn, optimizer=model) gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list("cuda") diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py new file mode 100644 index 000000000000..5e36b18389b1 --- /dev/null +++ b/tests/test_zero/test_gemini/test_grad_accum.py @@ -0,0 +1,145 @@ +import pytest +import torch +import torch.distributed as dist +from apex import amp +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed +from colossalai.utils.cuda import get_current_device +from colossalai.zero import GeminiDDP, GeminiOptimizer +from colossalai.zero.gemini.chunk import search_chunk_configuration +from tests.kit.model_zoo import model_zoo, run_fwd + +PLACEMENT_CONFIGS = [ + {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half + {"placement_policy": "auto"}, +] + + +def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): + chunk_manager = model.chunk_manager + grad_chunk_list = [] + device_list = [] + + # Access gradient chunks. + for p in model.parameters(): + grad_chunk = chunk_manager.get_chunk(p).grad_chunk + if grad_chunk not in grad_chunk_list: + chunk_manager.access_chunk(grad_chunk) + grad_chunk_list.append(grad_chunk) + device_list.append(model.grads_device[p]) + + # Compare gradients. + for p0, p1 in zip(model.parameters(), torch_model.parameters()): + assert_close(p0, p1.grad, rtol=2e-3, atol=2e-2) + + # Release gradient chunks and move them to gradient device. + for grad_chunk, device in zip(grad_chunk_list, device_list): + chunk_manager.release_chunk(grad_chunk) + chunk_manager.move_chunk(grad_chunk, device, force_copy=True) + + +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("keep_gathered", [False, True]) +@parameterize("model_name", ["transformers_gpt_lm"]) +@parameterize("master_weights", [False, True]) +def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str, master_weights: bool): + init_device = get_current_device() + model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) + + set_seed(42) + gemini_model = model_builder() + + set_seed(42) + torch_model = model_builder().cuda() + for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + config_dict, *_ = search_chunk_configuration(gemini_model, search_range_m=1, search_interval=100) + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gathered + gemini_model = GeminiDDP( + gemini_model, + config_dict, + init_device, + pin_memory=True, + enable_gradient_accumulation=True, + master_weights=master_weights, + **placement_config, + ) + optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3) + gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1) + + rank = dist.get_rank() + + # setting master_weights to False will cause overflow after optimizer.step() + amp_config = dict( + opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, min_loss_scale=1, max_loss_scale=1, master_weights=True + ) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = amp.initialize(torch_model, torch_optim, **amp_config) + torch_model = DDP(torch_model, device_ids=[rank]) + + set_seed(rank) + accum_iter = 4 + train_dataloader = DummyDataloader(data_gen_fn) + for i, data in enumerate(train_dataloader): + delay_unscale = False if (i + 1) % accum_iter == 0 else True + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} + + set_seed(42 + rank) + torch_loss = run_fwd(torch_model, data, output_transform_fn, loss_fn) + torch_loss = torch_loss / accum_iter + with amp.scale_loss(torch_loss, torch_optim, delay_unscale=delay_unscale) as scaled_loss: + scaled_loss.backward() + + set_seed(42 + rank) + gemini_loss = run_fwd(gemini_model, data, output_transform_fn, loss_fn) + gemini_loss = gemini_loss / accum_iter + gemini_optim.backward(gemini_loss) + + assert torch.allclose(torch_loss.float(), gemini_loss.float(), rtol=1e-3, atol=1e-5) + + check_grad(gemini_model, torch_model) + + if (i + 1) % accum_iter == 0: + torch_optim.step() + gemini_optim.step() + torch_optim.zero_grad() + + # check updated param + torch_dict = torch_model.state_dict() + gemini_dict = gemini_model.state_dict(only_rank_0=False) + + for key, value in gemini_dict.items(): + torch_key = "module." + key + torch_value = torch_dict[torch_key].to(value.device).to(value.dtype) + assert_close(value, torch_value, rtol=1e-3, atol=2e-3) + + if i == accum_iter: + break + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + exam_gemini_grad_acc() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_grad_accumulation(): + spawn(run_dist, 2) + + +if __name__ == "__main__": + test_grad_accumulation() diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py index a3af81646a18..c3a36d3bafa1 100644 --- a/tests/test_zero/test_gemini/test_grad_clip.py +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -7,12 +7,11 @@ import colossalai from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration -from tests.components_to_test import run_fwd_bwd -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo, run_fwd_bwd PLACEMENT_CONFIGS = [ { @@ -51,11 +50,13 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("placement_config", PLACEMENT_CONFIGS) -@parameterize("model_name", ["gpt2"]) -def exam_grad_clipping(placement_config, model_name: str): +@parameterize("model_name", ["transformers_gpt_lm"]) +@parameterize("master_weights", [True, False]) +def exam_grad_clipping(placement_config, model_name: str, master_weights: bool): set_seed(1912) - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) torch_model = model_builder().cuda() amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=32) @@ -82,6 +83,7 @@ def exam_grad_clipping(placement_config, model_name: str): chunk_config_dict=config_dict, chunk_init_device=init_device, pin_memory=True, + master_weights=master_weights, **placement_config, ) @@ -92,18 +94,17 @@ def exam_grad_clipping(placement_config, model_name: str): torch_model.train() set_seed(dist.get_rank() * 3 + 128) - for i, (data, label) in enumerate(train_dataloader): + train_dataloader = DummyDataloader(data_gen_fn) + for i, data in enumerate(train_dataloader): if i > 2: break - data = data.cuda() - label = label.cuda() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} zero_optim.zero_grad() torch_optim.zero_grad() - torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim) - loss = run_fwd_bwd(model, data, label, criterion, zero_optim) - assert_close(torch_loss, loss) + run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim) + run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim) import apex.amp as apex_amp @@ -111,7 +112,8 @@ def exam_grad_clipping(placement_config, model_name: str): torch_optim.step() zero_optim.step() - check_param(model, torch_model) + if master_weights: + check_param(model, torch_model) def run_dist(rank, world_size, port): diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 2b2b246a9f54..e20428b67b41 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -9,13 +9,12 @@ import colossalai from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from colossalai.utils.cuda import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration -from tests.components_to_test import run_fwd_bwd -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd PLACEMENT_CONFIGS = [ {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 @@ -53,12 +52,11 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict): @parameterize("placement_config", PLACEMENT_CONFIGS) -@parameterize("model_name", ["gpt2"]) +@parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("model_init_func", [single_chunk_init, multi_chunk_init]) def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable): set_seed(19360226) - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values())) torch_model = model_builder().cuda() amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128) @@ -79,29 +77,27 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal torch_model.eval() set_seed(dist.get_rank() * 3 + 128) - train_dataloader = iter(train_dataloader) + train_dataloader = iter(DummyDataloader(data_gen_fn)) def train_iter(): - input_ids, label = next(train_dataloader) - input_ids, label = input_ids.cuda(), label.cuda() + data = next(train_dataloader) + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} zero_optim.zero_grad() torch_optim.zero_grad() - torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) - loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss, rtol=1e-5, atol=1e-5) + torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, optimizer=torch_optim) + loss = run_fwd_bwd(model, data, output_transform_fn, optimizer=zero_optim) + assert_close(torch_loss.float(), loss.float(), rtol=1e-5, atol=1e-5) zero_optim.step() torch_optim.step() check_param(model, torch_model) def inference_iter(): - input_ids, label = next(train_dataloader) - input_ids, label = input_ids.cuda(), label.cuda() + data = next(train_dataloader) + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} with torch.no_grad(): - torch_output = torch_model(input_ids) - torch_loss = criterion(torch_output.float(), label) - zero_output = model(input_ids) - zero_loss = criterion(zero_output.float(), label) - assert_close(torch_loss, zero_loss) + torch_loss = run_fwd(torch_model, data, output_transform_fn) + zero_loss = run_fwd(model, data, output_transform_fn) + assert_close(torch_loss.float(), zero_loss.float(), rtol=1e-5, atol=1e-5) train_iter() inference_iter() diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 8e8e508ff483..887e495e6187 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -7,13 +7,12 @@ import colossalai from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from colossalai.utils.cuda import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration -from tests.components_to_test import run_fwd_bwd -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo, run_fwd_bwd PLACEMENT_CONFIGS = [ {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 @@ -31,14 +30,17 @@ ] # this model is large enough to slice to chunks -TEST_MODELS = ["gpt2"] +TEST_MODELS = ["transformers_gpt_lm"] # these models are too small, all parameters in these models are compacted into one chunk -EXAMPLE_MODELS = ["albert", "beit", "bert", "hanging_param_model", "nested_model", "repeated_computed_layers"] +EXAMPLE_MODELS = [ + "transformers_bert_for_sequence_classification", + "custom_hanging_param_model", + "custom_nested_model", + "custom_repeated_computed_layers", +] # bfloat16 cannot represent them exactly BF16_IGNORED_KEYS = [ - "albert.embeddings.word_embeddings.weight", - "albert.embeddings.position_embeddings.weight", "masked_bias", ] @@ -54,7 +56,7 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty temp_zero_value = zero_dict[key].to(device=value.device) if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS): continue - rtol, atol = 1e-3, 4e-3 + rtol, atol = 2e-3, 6e-3 if dtype is torch.bfloat16: rtol, atol = 4e-3, 8e-3 # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) @@ -70,12 +72,15 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("model_name", TEST_MODELS) @parameterize("mixed_precision", [torch.half, torch.bfloat16]) -def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype): +@parameterize("master_weights", [True, False]) +def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool): set_seed(42) - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) torch_model = model_builder().cuda() + # apex no master weights leads to nan, so we don't use it amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) @@ -90,7 +95,9 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = False - model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision) + model = GeminiDDP( + model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights + ) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128) @@ -99,31 +106,36 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt torch_model.eval() set_seed(dist.get_rank() * 3 + 128) - rtol, atol = 1e-4, 1e-5 - for i, (input_ids, label) in enumerate(train_dataloader): + rtol, atol = 4e-2, 4e-2 + train_dataloader = iter(DummyDataloader(data_gen_fn)) + for i, data in enumerate(train_dataloader): if i > 2: break - input_ids, label = input_ids.cuda(), label.cuda() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} zero_optim.zero_grad() torch_optim.zero_grad() - torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) - loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss, rtol=rtol, atol=atol) + torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim) + loss = run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim) + # as no master weights leads to error accumulation, we don't check the loss + if master_weights: + assert_close(torch_loss.float(), loss.float(), rtol=rtol, atol=atol) zero_optim.step() torch_optim.step() - check_param(model, torch_model, mixed_precision) + if master_weights: + check_param(model, torch_model, mixed_precision) -@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("placement_config", [PLACEMENT_CONFIGS[3]]) @parameterize("model_name", EXAMPLE_MODELS) -@parameterize("mixed_precision", [torch.half, torch.bfloat16]) +@parameterize("mixed_precision", [torch.half]) def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype): set_seed(2008) - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) torch_model = model_builder().cuda() amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=2) @@ -151,23 +163,19 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. torch_model.eval() set_seed(dist.get_rank() * 3 + 128) - rtol, atol = 1.5e-6, 2e-5 - if mixed_precision is torch.bfloat16: - rtol, atol = 2e-3, 2e-3 - for i, (input_ids, label) in enumerate(train_dataloader): + + train_dataloader = DummyDataloader(data_gen_fn) + for i, data in enumerate(train_dataloader): if i > 2: break - input_ids = input_ids.cuda() - label = label.cuda() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} zero_optim.zero_grad() torch_optim.zero_grad() - torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) - loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss, rtol=rtol, atol=atol) # atol should be 2e-5 for torch lower than 1.12 - + run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim) + run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim) zero_optim.step() torch_optim.step() diff --git a/tests/test_zero/test_gemini/test_runtime_mem_tracer.py b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py index 8e0f6ae36c46..9d00521be694 100644 --- a/tests/test_zero/test_gemini/test_runtime_mem_tracer.py +++ b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py @@ -4,10 +4,9 @@ import pytest import torch -from colossalai.testing import clear_cache_before_run +from colossalai.testing import DummyDataloader, clear_cache_before_run from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer -from tests.components_to_test import run_fwd_bwd -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo, run_fwd_bwd @pytest.mark.skip("this is not used") @@ -16,21 +15,22 @@ def test_runtime_mem_tracer(): test_models = ["gpt2", "bert", "simple_net", "repeated_computed_layers", "nested_model", "albert"] for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, _, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, *_ = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) - model = model_builder(checkpoint=False).cuda() + model = model_builder().cuda() model_bk = deepcopy(model) runtime_mem_tracer = RuntimeMemTracer(model) - for i, (data, label) in enumerate(train_dataloader): + train_dataloader = DummyDataloader(data_gen_fn) + for i, data in enumerate(train_dataloader): if i > 1: break - data = data.cuda() - label = label.cuda() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} - run_fwd_bwd(runtime_mem_tracer, data, label, criterion, optimizer=runtime_mem_tracer) + run_fwd_bwd(runtime_mem_tracer, data, output_transform_fn, optimizer=runtime_mem_tracer) for p1, p2 in zip(model_bk.parameters(), model.parameters()): torch.allclose(p1.to(torch.half), p2) diff --git a/tests/test_zero/test_gemini/test_search.py b/tests/test_zero/test_gemini/test_search.py index e22e5ece42a5..e99f6d59ba8e 100644 --- a/tests/test_zero/test_gemini/test_search.py +++ b/tests/test_zero/test_gemini/test_search.py @@ -5,40 +5,37 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo def exam_search_chunk_size(): - world_size = torch.distributed.get_world_size() - - get_components_func = non_distributed_component_funcs.get_callable("gpt2") - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, *_ = next( + iter(model_zoo.get_sub_registry("transformers_gpt_lm").values()) + ) # make sure torch_model and model has the same parameter values model = model_builder() config_dict, *_ = search_chunk_configuration( - model, search_range_m=1, search_interval=16, min_chunk_size_m=0, filter_exlarge_params=True + model, search_range_m=1, search_interval=128, min_chunk_size_m=0, filter_exlarge_params=True ) for key in config_dict: chunk_size = config_dict[key]["chunk_size"] - if world_size == 1 or True: - assert chunk_size == 31616 - else: - assert chunk_size == 1024 + assert chunk_size == 527872 def exam_chunk_manager(): world_size = torch.distributed.get_world_size() - get_components_func = non_distributed_component_funcs.get_callable("gpt2") - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, *_ = next( + iter(model_zoo.get_sub_registry("transformers_gpt_lm").values()) + ) sharded_ddp_model = model_builder() chunk_manager = init_chunk_manager( sharded_ddp_model, get_current_device(), - hidden_dim=16, + hidden_dim=128, search_range_m=1, min_chunk_size_m=0, filter_exlarge_params=True, @@ -46,7 +43,7 @@ def exam_chunk_manager(): ) config_dict = chunk_manager.dp_degree_chunk_size_dict assert len(config_dict) == 1 - assert config_dict[world_size] == 31616 + assert config_dict[world_size] == 527872 def run_dist(rank, world_size, port): diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index bf16a301cd8a..cbf5169fc621 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -7,7 +7,7 @@ from colossalai.utils import set_seed from colossalai.zero import GeminiDDP from colossalai.zero.gemini.chunk import search_chunk_configuration -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo PLACEMENT_CONFIGS = [ {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 @@ -26,15 +26,16 @@ def ignore_the_first_parameter(model: torch.nn.Module): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("keep_gathered", [True, False]) -@parameterize("model_name", ["gpt2", "bert"]) +@parameterize("model_name", ["transformers_gpt_lm", "transformers_bert_for_sequence_classification"]) @parameterize("master_weights", [False, True]) def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool): set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values())) model = model_builder() + model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 + torch_model = model_builder() for torch_p, p in zip(torch_model.parameters(), model.parameters()): torch_p.data.copy_(p.data) @@ -54,29 +55,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str, master_wei temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) - -@parameterize("placement_config", PLACEMENT_CONFIGS) -@parameterize("keep_gathered", [True, False]) -@parameterize("model_name", ["gpt2", "bert"]) -@parameterize("master_weights", [False, True]) -def exam_load_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool): - set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - model = model_builder() - - set_seed(451) - torch_model = model_builder() # get a different model - - world_size = torch.distributed.get_world_size() - config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]["chunk_size"] = 5000 - config_dict[world_size]["keep_gathered"] = keep_gathered - - model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights) - - torch_dict = torch_model.state_dict() + # check load state dict model.load_state_dict(torch_dict, strict=False) zero_dict = model.state_dict(only_rank_0=False) @@ -85,23 +64,7 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str, maste temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) - -@parameterize("placement_config", PLACEMENT_CONFIGS) -@parameterize("model_name", ["gpt2", "bert"]) -@parameterize("master_weights", [False, True]) -def exam_state_dict_shard(placement_config, model_name: str, master_weights: bool): - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - model = model_builder() - - model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 - - config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - model = GeminiDDP(model, config_dict, **placement_config, master_weights=master_weights) - model.train() - - zero_dict = model.state_dict(only_rank_0=False) + # check state dict shard accumulated_keys = set() # ensure number of shards > 1 for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): @@ -116,8 +79,6 @@ def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_state_dict() - exam_load_state_dict() - exam_state_dict_shard() @pytest.mark.dist diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py index c65c6d292467..87cb1cdfe43f 100644 --- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py +++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py @@ -8,7 +8,7 @@ from colossalai.utils import set_seed from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo PLACEMENT_CONFIGS = [ {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 @@ -22,8 +22,9 @@ @parameterize("keep_gathered", [True, False]) def exam_zero_optim_state_dict(placement_config, keep_gathered): set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable("gpt2") - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, *_ = next( + iter(model_zoo.get_sub_registry("transformers_gpt_lm").values()) + ) model = model_builder() @@ -41,15 +42,15 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered): set_seed(dist.get_rank() * 3 + 128) model.train() - for i, (input_ids, label) in enumerate(train_dataloader): - if i > 0: - break - optim.zero_grad() - logits = model(input_ids) - logits = logits.float() - loss = criterion(logits, input_ids) - optim.backward(loss) - optim.step() + data = data_gen_fn() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} + + optim.zero_grad() + outputs = model(**data) + outputs = output_transform_fn(outputs) + loss = next(iter(outputs.values())).sum() + optim.backward(loss) + optim.step() optim_state_dict = optim.state_dict() optim.load_state_dict(optim_state_dict) diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index ebda9f6f25c5..e2196cfbf0f2 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -106,7 +106,8 @@ def exam_zero_1_2(): @parameterize("dtype", [torch.float16, torch.bfloat16]) -def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype): +@parameterize("master_weights", [True, False]) +def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): """ In this test, two pairs of model and optimizers are created. 1. zero: use sharded optimizer and fp16 parameters @@ -131,7 +132,11 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype): # in `check_sharded_param_consistency.py`, we will test whether # level 1 and 2 will produce exactly the same results zero_optimizer = LowLevelZeroOptimizer( - zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=1024 * 1024 + zero_optimizer, + overlap_communication=True, + initial_scale=1, + reduce_bucket_size=1024 * 1024, + master_weights=master_weights, ) torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)