From 8a3c0eea543f7449cbf856e5393819b2a8cdc93e Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Fri, 6 Oct 2023 16:16:43 +0800 Subject: [PATCH 1/7] add smoothquant llama --- .../inference/quant/smoothquant/__init__.py | 0 .../quant/smoothquant/calibration.py | 58 ++ .../quant/smoothquant/models/base_model.py | 428 ++++++++ .../quant/smoothquant/models/linear.py | 132 ++- .../quant/smoothquant/models/llama.py | 944 ++++++++++++++++-- .../inference/quant/smoothquant/smooth.py | 52 + colossalai/kernel/triton/__init__.py | 3 + .../triton/int8_rotary_embedding_kernel.py | 46 +- colossalai/kernel/triton/smooth_attention.py | 649 ++++++++++++ examples/inference/smoothquant_conversion.py | 135 +++ .../test_smoothquant/test_llama_attention.py | 48 +- tests/test_smoothquant/test_llama_mlp.py | 2 +- 12 files changed, 2344 insertions(+), 153 deletions(-) create mode 100644 colossalai/inference/quant/smoothquant/__init__.py create mode 100644 colossalai/inference/quant/smoothquant/calibration.py create mode 100644 colossalai/inference/quant/smoothquant/models/base_model.py create mode 100644 colossalai/inference/quant/smoothquant/smooth.py create mode 100644 colossalai/kernel/triton/smooth_attention.py create mode 100644 examples/inference/smoothquant_conversion.py diff --git a/colossalai/inference/quant/smoothquant/__init__.py b/colossalai/inference/quant/smoothquant/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/quant/smoothquant/calibration.py b/colossalai/inference/quant/smoothquant/calibration.py new file mode 100644 index 000000000000..0b61beed6733 --- /dev/null +++ b/colossalai/inference/quant/smoothquant/calibration.py @@ -0,0 +1,58 @@ +import functools + +import torch +import torch.nn as nn +from datasets import load_dataset +from tqdm import tqdm + + +def get_act_scales(model, tokenizer, dataset_path, num_samples=512, seq_len=512): + model.eval() + device = next(model.parameters()).device + act_scales = {} + + def stat_tensor(name, tensor): + hidden_dim = tensor.shape[-1] + tensor = tensor.view(-1, hidden_dim).abs().detach() + comming_max = torch.max(tensor, dim=0)[0].float().cpu() + if name in act_scales: + act_scales[name] = torch.max(act_scales[name], comming_max) + else: + act_scales[name] = comming_max + + def stat_input_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + stat_tensor(name, x) + + hooks = [] + for name, m in model.named_modules(): + if isinstance(m, nn.Linear): + hooks.append(m.register_forward_hook(functools.partial(stat_input_hook, name=name))) + # print("data path", dataset_path) + # dataset = load_dataset("csv", data_files=dataset_path, split="train") + # dataset = dataset.shuffle(seed=42) + + # dataset = load_dataset("/home/lcxk/data3/datasets/cc_news.py", data_files=dataset_path) + dataset = load_dataset("json", data_files=dataset_path) + + print("text", dataset["train"]["rows"][0][1]["row"]["text"]) + # for test in dataset["train"]["rows"]: + # print(test) + dataset = dataset.shuffle(seed=42) + # print("text", dataset["rows"][0]) + + for i in tqdm(range(num_samples)): + # print("text", dataset[i]) + input_ids = tokenizer( + dataset["train"]["rows"][0][i]["row"]["text"], + return_tensors="pt", + max_length=seq_len, + truncation=True, + ).input_ids.to(device) + model(input_ids) + + for h in hooks: + h.remove() + + return act_scales diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py new file mode 100644 index 
000000000000..6aec56d557ce --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/base_model.py @@ -0,0 +1,428 @@ +import os +from abc import abstractmethod +from logging import getLogger +from os.path import isdir, isfile, join +from typing import Dict, List, Optional, Union + +import accelerate +import torch +import torch.nn as nn +import transformers +from safetensors.torch import save_file as safe_save +from torch import device +from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel +from transformers.modeling_utils import no_init_weights +from transformers.utils.generic import ContextManagers +from transformers.utils.hub import PushToHubMixin, cached_file + +logger = getLogger(__name__) + + +CPU = device("cpu") +CUDA_0 = device("cuda:0") + +SUPPORTED_MODELS = ["llama"] + + +def get_module_by_name_suffix(model, module_name: str): + for name, module in model.named_modules(): + if name.endswith(module_name): + return module + + +def simple_dispatch_model(model, device_map): + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + + if "" in device_map: + d = device_map[""] + model = model.to(torch.device(d)) + model.hf_device_map = device_map + return model + + tied_params = accelerate.utils.modeling.find_tied_parameters(model) + if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}: + main_device = "cpu" + else: + main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0] + + cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"] + prev_hook = None + for idx, (n, d) in enumerate(cpu_offload_group): + m = get_module_by_name_suffix(model, n) + _, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook) + # set first cpu offload module's prev_module_hook to the last cpu offload module's hook + if len(cpu_offload_group) > 1: + get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook + + for n, d in device_map.items(): + m = get_module_by_name_suffix(model, n) + if d != "cpu": + d = torch.device(d) + hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True) + add_hook_to_module(m, hook) + accelerate.utils.modeling.retie_parameters(model, tied_params) + model.hf_device_map = device_map + + return model + + +class BaseSmoothForCausalLM(nn.Module, PushToHubMixin): + def __init__(self, model: PreTrainedModel, quantized: bool = False): + super().__init__() + + self.model = model + self.model_type = self.model.config.model_type + self._quantized = quantized + self.config = self.model.config + + @property + def quantized(self): + return self._quantized + + @abstractmethod + @torch.inference_mode() + def quantize( + self, + examples: List[Dict[str, Union[List[int], torch.LongTensor]]], + ): + if self.quantized: + raise EnvironmentError("can't execute quantize because the model is quantized.") + + def to(self, device: Union[str, torch.device]): + self.model.to(device) + return self + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def generate(self, **kwargs): + """shortcut for model.generate""" + with torch.inference_mode(): + return self.model.generate(**kwargs) + + def prepare_inputs_for_generation(self, *args, **kwargs): + """shortcut for model.prepare_inputs_for_generation""" + return self.model.prepare_inputs_for_generation(*args, **kwargs) + + @classmethod + def make_smooth_model(cls, model): + raise NotImplementedError("not implememented smooth 
model") + + def save_quantized( + self, + save_dir: str, + model_file_base_name: str = None, + use_safetensors: bool = False, + safetensors_metadata: Optional[Dict[str, str]] = None, + ): + """save quantized model and configs to local disk""" + os.makedirs(save_dir, exist_ok=True) + + if not self.quantized: + raise EnvironmentError("can only save quantized model, please execute .quantize first.") + + self.model.to(CPU) + + model_base_name = model_file_base_name or f"smooth-" + if use_safetensors: + model_save_name = model_base_name + ".safetensors" + state_dict = self.model.state_dict() + state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} + if safetensors_metadata is None: + safetensors_metadata = {} + elif not isinstance(safetensors_metadata, dict): + raise TypeError("safetensors_metadata must be a dictionary.") + else: + logger.debug(f"Received safetensors_metadata: {safetensors_metadata}") + new_safetensors_metadata = {} + converted_keys = False + for key, value in safetensors_metadata.items(): + if not isinstance(key, str) or not isinstance(value, str): + converted_keys = True + try: + new_key = str(key) + new_value = str(value) + except Exception as e: + raise TypeError( + f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}" + ) + if new_key in new_safetensors_metadata: + logger.warning( + f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting." + ) + new_safetensors_metadata[new_key] = new_value + safetensors_metadata = new_safetensors_metadata + if converted_keys: + logger.debug( + f"One or more safetensors_metadata keys or values had to be converted to str(). 
Final safetensors_metadata: {safetensors_metadata}" + ) + + # Format is required to enable Accelerate to load the metadata + # otherwise it raises an OSError + safetensors_metadata["format"] = "pt" + + safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata) + else: + model_save_name = model_base_name + ".bin" + torch.save(self.model.state_dict(), join(save_dir, model_save_name)) + + self.model.config.save_pretrained(save_dir) + + def save_pretrained( + self, + save_dir: str, + use_safetensors: bool = False, + safetensors_metadata: Optional[Dict[str, str]] = None, + **kwargs, + ): + """alias of save_quantized""" + logger.warning("you are using save_pretrained, which will re-direct to save_quantized.") + self.save_quantized(save_dir, use_safetensors, safetensors_metadata) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + max_memory: Optional[dict] = None, + trust_remote_code: bool = False, + torch_dtype: torch.dtype = torch.float16, + **model_init_kwargs, + ): + if not torch.cuda.is_available(): + raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.") + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + + # Parameters related to loading from Hugging Face Hub + cache_dir = model_init_kwargs.pop("cache_dir", None) + force_download = model_init_kwargs.pop("force_download", False) + resume_download = model_init_kwargs.pop("resume_download", False) + proxies = model_init_kwargs.pop("proxies", None) + local_files_only = model_init_kwargs.pop("local_files_only", False) + use_auth_token = model_init_kwargs.pop("use_auth_token", None) + revision = model_init_kwargs.pop("revision", None) + subfolder = model_init_kwargs.pop("subfolder", "") + model_init_kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + } + + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs) + if config.model_type not in SUPPORTED_MODELS: + raise TypeError(f"{config.model_type} isn't supported yet.") + + # enforce some values despite user specified + model_init_kwargs["torch_dtype"] = torch_dtype + model_init_kwargs["trust_remote_code"] = trust_remote_code + if max_memory: + if "disk" in max_memory: + raise NotImplementedError("disk offload not support yet.") + with accelerate.init_empty_weights(): + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + model.tie_weights() + + max_memory = accelerate.utils.get_balanced_memory( + model, + max_memory=max_memory, + no_split_module_classes=[cls.layer_type], + dtype=model_init_kwargs["torch_dtype"], + low_zero=False, + ) + model_init_kwargs["device_map"] = accelerate.infer_auto_device_map( + model, + max_memory=max_memory, + no_split_module_classes=[cls.layer_type], + dtype=model_init_kwargs["torch_dtype"], + ) + model_init_kwargs["low_cpu_mem_usage"] = True + + del model + else: + model_init_kwargs["device_map"] = None + model_init_kwargs["low_cpu_mem_usage"] = False + + torch.cuda.empty_cache() + + merged_kwargs = {**model_init_kwargs, **cached_file_kwargs} + model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs) + + model_config = 
model.config.to_dict() + seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] + if any([k in model_config for k in seq_len_keys]): + for key in seq_len_keys: + if key in model_config: + model.seqlen = model_config[key] + break + else: + logger.warning("can't get model's sequence length from model config, will set to 4096.") + model.seqlen = 4096 + model.eval() + + return cls(model, False) + + @classmethod + def from_quantized( + cls, + model_name_or_path: Optional[str], + device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None, + max_memory: Optional[dict] = None, + device: Optional[Union[str, int]] = None, + low_cpu_mem_usage: bool = False, + torch_dtype: Optional[torch.dtype] = None, + model_basename: Optional[str] = None, + use_safetensors: bool = False, + trust_remote_code: bool = False, + **kwargs, + ): + """load quantized model from local disk""" + + # Parameters related to loading from Hugging Face Hub + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + + # == step1: prepare configs and file names == # + config = AutoConfig.from_pretrained( + model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs + ) + + if config.model_type not in SUPPORTED_MODELS: + raise TypeError(f"{config.model_type} isn't supported yet.") + + extensions = [] + if use_safetensors: + extensions.append(".safetensors") + else: + extensions += [".bin", ".pt"] + + model_name_or_path = str(model_name_or_path) + is_local = isdir(model_name_or_path) + + resolved_archive_file = None + if is_local: + model_save_name = join(model_name_or_path, model_basename) + for ext in extensions: + if isfile(model_save_name + ext): + resolved_archive_file = model_save_name + ext + break + else: # remote + for ext in extensions: + resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs) + if resolved_archive_file is not None: + break + + if resolved_archive_file is None: # Could not find a model file to use + raise FileNotFoundError(f"Could not find model in {model_name_or_path}") + + model_save_name = resolved_archive_file + + # == step2: convert model to gptq-model (replace Linear with QuantLinear) == # + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + + transformers.modeling_utils._init_weights = False + + init_contexts = [no_init_weights()] + if low_cpu_mem_usage: + init_contexts.append(accelerate.init_empty_weights(include_buffers=False)) + + with ContextManagers(init_contexts): + model = AutoModelForCausalLM.from_config( + config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype + ) + cls.make_smooth_model(model) + model.tie_weights() + + # == step3: load checkpoint and dispatch == # + if 
isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + raise ValueError( + "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or " + "'sequential'." + ) + if isinstance(device_map, dict): + max_memory = None + else: + if device is None and not device_map and not max_memory: + device_map = "auto" + if device is not None: + device = torch.device(device) + if not max_memory and not device_map: + device_map = {"": device.index if device.type == "cuda" else device.type} + if not isinstance(device_map, dict) and device_map != "sequential": + max_memory = accelerate.utils.get_balanced_memory( + model=model, + max_memory=max_memory, + no_split_module_classes=[cls.layer_type], + low_zero=(device_map == "balanced_low_0"), + ) + if not isinstance(device_map, dict): + device_map = accelerate.infer_auto_device_map( + model, max_memory=max_memory, no_split_module_classes=[cls.layer_type] + ) + + accelerate.utils.modeling.load_checkpoint_in_model( + model, checkpoint=model_save_name, device_map=device_map, offload_state_dict=True, offload_buffers=True + ) + model = simple_dispatch_model(model, device_map) + + # == step4: set seqlen == # + model_config = model.config.to_dict() + seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] + if any([k in model_config for k in seq_len_keys]): + for key in seq_len_keys: + if key in model_config: + model.seqlen = model_config[key] + break + else: + logger.warning("can't get model's sequence length from model config, will set to 4096.") + model.seqlen = 4096 + + return cls( + model, + True, + ) + + def __getattr__(self, item): + try: + return super().__getattr__(item) + except: + return getattr(self.model, item) + + +__all__ = ["BaseSmoothForCausalLM"] diff --git a/colossalai/inference/quant/smoothquant/models/linear.py b/colossalai/inference/quant/smoothquant/models/linear.py index 1c01c6222e7a..6cf681601dbf 100644 --- a/colossalai/inference/quant/smoothquant/models/linear.py +++ b/colossalai/inference/quant/smoothquant/models/linear.py @@ -1,4 +1,5 @@ import torch +from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32 from torch_int.functional.quantization import quantize_per_tensor_absmax try: @@ -19,9 +20,18 @@ def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): self.register_buffer( "weight", - torch.randint(-127, 127, (self.out_features, self.in_features), dtype=torch.int8, requires_grad=False), + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False), ) - self.register_buffer("bias", torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False)) self.register_buffer("a", torch.tensor(alpha)) def to(self, *args, **kwargs): @@ -44,6 +54,122 @@ def from_float(module: torch.nn.Linear, input_scale): int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) alpha = input_scale * weight_scale int8_module.weight = int8_weight - int8_module.bias.data.copy_(module.bias.to(torch.float)) + if module.bias is not None: + int8_module.bias.data.copy_(module.bias.to(torch.float)) + int8_module.a = alpha + return int8_module + + +class W8A8B8O8Linear(torch.nn.Module): + # For qkv_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = 
out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + self.register_buffer("b", torch.tensor(beta)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item()) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale, output_scale): + int8_module = W8A8B8O8Linear(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale / output_scale + int8_module.weight = int8_weight + int8_module.a = alpha + + if module.bias is not None: + int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias) + int8_module.bias = int8_bias + beta = bias_scale / output_scale + int8_module.b = beta + + return int8_module + + +class W8A8BFP32OFP32Linear(torch.nn.Module): + # For fc2 and out_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + + def _apply(self, fn): + # prevent the bias from being converted to half + super()._apply(fn) + self.bias = self.bias.to(torch.float32) + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + self.bias = self.bias.to(torch.float32) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight int8_module.a = alpha + int8_module.input_scale = input_scale + int8_module.weight_scale = weight_scale + + if module.bias is not None: + int8_module.bias = module.bias.to(torch.float32) + return int8_module diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py index 34449dbfe03d..831b85e14d5d 100644 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -1,16 +1,39 @@ # Code modified from smoothquant: https://github.com/mit-han-lab/smoothquant -from typing import Optional, Tuple - +import math +import os +import types +from collections import defaultdict +from functools import partial +from typing import List, Optional, Tuple, Union + 
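# Illustrative sketch (not part of the original files): the W8A8 linear layers in
# linear.py quantize weights per tensor with absmax and fold the input and weight
# scales into one multiplier. A minimal pure-PyTorch reference of that scheme, with
# a hypothetical helper name and the int8 GEMM emulated in float:
def _absmax_int8_linear_reference(x, w):
    # per-tensor absmax scales map values onto the signed int8 range [-127, 127]
    s_x = x.abs().max() / 127.0
    s_w = w.abs().max() / 127.0
    x_q = (x / s_x).round().clamp(-128, 127)
    w_q = (w / s_w).round().clamp(-128, 127)
    # rescaling the integer GEMM by alpha = s_x * s_w approximates x @ w.T;
    # W8A8B8O8Linear.from_float additionally divides alpha by the output scale so
    # the result can itself be stored as int8
    return (x_q @ w_q.t()) * (s_x * s_w)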
+import numpy as np import torch +import torch.nn as nn +from datasets import load_dataset from torch import nn from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T -from torch_int.nn.linear import W8A8B8O8Linear, W8A8BFP32OFP32Linear -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP - -from colossalai.kernel.triton import int8_rotary_embedding_fwd - -from .linear import W8A8BFP32O32LinearSiLU +from torch_int.nn.fused import LayerNormQ +from tqdm import tqdm +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LLAMA_INPUTS_DOCSTRING, + LlamaAttention, + LlamaDecoderLayer, + LlamaMLP, + repeat_kv, + rotate_half, +) +from transformers.utils import add_start_docstrings_to_model_forward + +from colossalai.kernel.triton import ( + int8_rotary_embedding_fwd, + smooth_llama_context_attn_fwd, + smooth_token_attention_fwd, +) + +from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear class LLamaSmoothquantAttention(nn.Module): @@ -30,8 +53,6 @@ def __init__( f" and `num_heads`: {num_heads})." ) - self.attention_weight_scale = 1.0 - self.qk_bmm = BMM_S8T_S8N_F32T(1.0) self.pv_bmm = BMM_S8T_S8N_S8T(1.0) @@ -40,36 +61,50 @@ def __init__( self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size) self.out_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size) - self.q_output_scale = torch.tensor([1.0]) - self.k_output_scale = torch.tensor([1.0]) - self.rotary_output_scale = torch.tensor([1.0]) + self.register_buffer("q_output_scale", torch.tensor([1.0])) + self.register_buffer("k_output_scale", torch.tensor([1.0])) + self.register_buffer("v_output_scale", torch.tensor([1.0])) + self.register_buffer("q_rotary_output_scale", torch.tensor([1.0])) + self.register_buffer("k_rotary_output_scale", torch.tensor([1.0])) + self.register_buffer("qk_output_scale", torch.tensor([1.0])) + self.register_buffer("attn_output_scale", torch.tensor([1.0])) + @staticmethod def pack( - self, module: LlamaAttention, - input_scale: float, + attn_input_scale: float, q_output_scale: float, k_output_scale: float, v_output_scale: float, + q_rotary_output_scale: float, + k_rotary_output_scale: float, out_input_scale: float, - rotary_output_scale: float, ): - int8_module = LLamaSmoothquantAttention(module.hidden_size, module.head_dim) - int8_module.q_output_scale = q_output_scale - int8_module.k_output_scale = k_output_scale - int8_module.rotary_output_scale = rotary_output_scale - q_output_scale = q_output_scale * module.scaling - module.q_proj.weight *= module.scaling - module.q_proj.bias *= module.scaling - int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, input_scale, q_output_scale) - - int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, input_scale, k_output_scale) - int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, input_scale, v_output_scale) - int8_module.out_proj = W8A8BFP32OFP32Linear.from_float(module.out_proj, out_input_scale) + int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads) + + int8_module.q_output_scale = torch.tensor(q_output_scale) + int8_module.k_output_scale = torch.tensor(k_output_scale) + int8_module.v_output_scale = torch.tensor(v_output_scale) + + int8_module.q_rotary_output_scale = torch.tensor(q_rotary_output_scale) + int8_module.k_rotary_output_scale = torch.tensor(k_rotary_output_scale) + + # q_output_scale = q_output_scale * module.scaling + 
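        # Note (illustration, not part of the original file): from_float folds the
        # quantization scales into a single multiplier per op:
        #   W8A8B8O8Linear (int8 output):       alpha = input_scale * weight_scale / output_scale
        #   W8A8BFP32OFP32Linear (fp32 output):  alpha = input_scale * weight_scale
        # qk_output_scale and attn_output_scale below store s_q * s_k and
        # (1 / 127) * s_v / s_out, the same values used for the BMM scales.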
# module.q_proj.weight *= module.scaling + # module.q_proj.bias *= module.scaling + int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, attn_input_scale, q_output_scale) + int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale) + int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale) + int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale) + # print("qout_scale k out scale:", q_output_scale, k_output_scale) int8_module.qk_bmm = BMM_S8T_S8N_F32T.from_scale(q_output_scale, k_output_scale) # alpha = s_prob * s_v / s_out, where s_prob = 1 / 127 int8_module.pv_bmm = BMM_S8T_S8N_S8T.from_scale(1.0 / 127, v_output_scale, out_input_scale) + + int8_module.qk_output_scale = torch.tensor(q_output_scale * k_output_scale) + int8_module.attn_output_scale = torch.tensor(1.0 / 127 * v_output_scale / out_input_scale) + return int8_module def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): @@ -79,12 +114,13 @@ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def forward( self, hidden_states: torch.Tensor, - rotary_emb: Tuple[torch.Tensor], - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + rotary_emb: Tuple[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, seq_len, _ = hidden_states.size() # get query proj @@ -94,101 +130,181 @@ def forward( cos = rotary_emb[0] sin = rotary_emb[1] + int8_rotary_embedding_fwd( query_states.view(-1, self.num_heads, self.head_dim), cos, sin, self.q_output_scale, - self.rotary_output_scale, + self.q_rotary_output_scale, ) int8_rotary_embedding_fwd( key_states.view(-1, self.num_heads, self.head_dim), cos, sin, self.k_output_scale, - self.rotary_output_scale, + self.k_rotary_output_scale, ) - if past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(key_states, -1, bsz) - value_states = self._shape(value_states, -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(key_states, -1, bsz) - value_states = self._shape(value_states, -1, bsz) - - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - - query_states = self._shape(query_states, seq_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = self.qk_bmm(query_states, key_states) - - if attn_weights.size() != (bsz * self.num_heads, seq_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, seq_len, src_len)}, but is" - f" {attn_weights.size()}" + if past_key_value is None: + proj_shape = (bsz * seq_len, -1, self.head_dim) + + query_states = query_states.view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + # attn_output = torch.empty(bsz * seq_len, self.num_heads, self.head_dim, dtype=torch.int8, device="cuda") + attn_output 
= torch.empty_like(query_states) + + b_start_loc = torch.arange(start=0, end=bsz * seq_len, step=seq_len, dtype=torch.int, device="cuda") + b_seq_len = torch.full([bsz], seq_len, dtype=torch.int, device="cuda") + + smooth_llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + self.q_rotary_output_scale.item(), + self.k_rotary_output_scale.item(), + self.v_output_scale.item(), + self.attn_output_scale.item(), + b_start_loc, + b_seq_len, + seq_len, ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, seq_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, seq_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, seq_len, src_len) + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - attn_weights = attn_weights.view(bsz * self.num_heads, seq_len, src_len) - - attn_probs = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + if use_cache: + past_key_value = ( + key_states.view(bsz, seq_len, -1, self.head_dim), + value_states.view(bsz, seq_len, -1, self.head_dim), ) - attn_probs = layer_head_mask.view(1, -1, 1, 1) * attn_probs.view(bsz, self.num_heads, seq_len, src_len) - attn_probs = attn_probs.view(bsz * self.num_heads, seq_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_probs_reshaped = attn_probs.view(bsz, self.num_heads, seq_len, src_len) - attn_probs = attn_probs_reshaped.view(bsz * self.num_heads, seq_len, src_len) else: - attn_probs_reshaped = None - - # (A_row V_row)_row = (A_row V_col ^T)_row - attn_probs.mul_(127).round_() - attn_probs = attn_probs.to(torch.int8) - - value_states = value_states.transpose(1, 2).contiguous() - attn_output = self.pv_bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, seq_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, seq_len, self.head_dim)}, but is" - f" {attn_output.size()}" + total_seq_len = past_key_value[0].shape[1] + seq_len + key_states = torch.cat([past_key_value[0], key_states.view(bsz, seq_len, -1, self.head_dim)], dim=1) + value_states = torch.cat([past_key_value[1], value_states.view(bsz, seq_len, -1, self.head_dim)], dim=1) + + proj_shape = (bsz * seq_len, -1, self.head_dim) + kv_shape = (bsz * total_seq_len, -1, self.head_dim) + query_states = query_states.view(*proj_shape) + key_states = key_states.view(*kv_shape) + value_states = value_states.view(*kv_shape) + attn_output = torch.empty_like(query_states) + + b_start_loc = torch.arange( + start=0, end=bsz * total_seq_len, step=total_seq_len, dtype=torch.int, device="cuda" ) - - attn_output = attn_output.view(bsz, self.num_heads, seq_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) + b_seq_len = torch.full([bsz], seq_len, dtype=torch.int, device="cuda") * total_seq_len + block_loc = torch.arange(total_seq_len, dtype=torch.int, device="cuda").expand(bsz, -1) + smooth_token_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + self.q_rotary_output_scale.item(), + 
self.k_rotary_output_scale.item(), + self.v_output_scale.item(), + self.attn_output_scale.item(), + block_loc, + b_start_loc, + b_seq_len, + total_seq_len, + ) + if use_cache: + past_key_value = ( + key_states.view(bsz, total_seq_len, -1, self.head_dim), + value_states.view(bsz, total_seq_len, -1, self.head_dim), + ) + # if use_cache: + # past_key_value = (key_states, value_states) + + # if use_cache: + # if past_key_value is not None: + # # reuse k, v, self_attention + # key_states = self._shape(key_states, -1, bsz) + # value_states = self._shape(value_states, -1, bsz) + # key_states = torch.cat([past_key_value[0], key_states], dim=2) + # value_states = torch.cat([past_key_value[1], value_states], dim=2) + # else: + # # self_attention + # key_states = self._shape(key_states, -1, bsz) + # value_states = self._shape(value_states, -1, bsz) + + # if use_cache: + # if past_key_value is not None: + # # reuse k, v, self_attention + # key_states = self._shape(key_states, -1, bsz) + # value_states = self._shape(value_states, -1, bsz) + # key_states = torch.cat([past_key_value[0], key_states], dim=2) + # value_states = torch.cat([past_key_value[1], value_states], dim=2) + # else: + # # self_attention + # key_states = self._shape(key_states, -1, bsz) + # value_states = self._shape(value_states, -1, bsz) + + # past_key_value = (key_states, value_states) + + # proj_shape = (bsz * self.num_heads, -1, self.head_dim) + + # query_states = self._shape(query_states, seq_len, bsz).view(*proj_shape) + # key_states = key_states.view(*proj_shape) + # value_states = value_states.view(*proj_shape) + + # src_len = key_states.size(1) + # print("q states:", query_states.shape, query_states.device, query_states.is_contiguous(), query_states.dtype) + # print("key states:", key_states.shape, key_states.device, key_states.is_contiguous(), key_states.dtype) + + # attn_weights = self.qk_bmm(query_states, key_states) + + # if attn_weights.size() != (bsz * self.num_heads, seq_len, src_len): + # raise ValueError( + # f"Attention weights should be of size {(bsz * self.num_heads, seq_len, src_len)}, but is" + # f" {attn_weights.size()}" + # ) + + # if attention_mask is not None: + # if attention_mask.size() != (bsz, 1, seq_len, src_len): + # raise ValueError( + # f"Attention mask should be of size {(bsz, 1, seq_len, src_len)}, but is {attention_mask.size()}" + # ) + # attn_weights = attn_weights.view(bsz, self.num_heads, seq_len, src_len) + attention_mask + # attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + # attn_weights = attn_weights.view(bsz * self.num_heads, seq_len, src_len) + + # attn_probs = nn.functional.softmax(attn_weights, dim=-1) + + # if output_attentions: + # # this operation is a bit awkward, but it's required to + # # make sure that attn_weights keeps its gradient. 
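        # Illustrative note (not part of the original file): the commented-out eager
        # path kept in this function makes the requantization explicit. Softmax
        # probabilities in [0, 1] become int8 via round(p * 127), i.e. s_prob = 1/127,
        # and the int8 prob @ value product is rescaled by s_prob * s_v / s_out, which
        # is exactly the attn_output_scale passed to the fused Triton kernels above.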
+ # # In order to do so, attn_weights have to be reshaped + # # twice and have to be reused in the following + # attn_probs_reshaped = attn_probs.view(bsz, self.num_heads, seq_len, src_len) + # attn_probs = attn_probs_reshaped.view(bsz * self.num_heads, seq_len, src_len) + # else: + # attn_probs_reshaped = None + + # # (A_row V_row)_row = (A_row V_col ^T)_row + # attn_probs.mul_(127).round_() + # attn_probs = attn_probs.to(torch.int8) + + # value_states = value_states.transpose(1, 2).contiguous() + # attn_output = self.pv_bmm(attn_probs, value_states) + + # if attn_output.size() != (bsz * self.num_heads, seq_len, self.head_dim): + # raise ValueError( + # f"`attn_output` should be of size {(bsz, self.num_heads, seq_len, self.head_dim)}, but is" + # f" {attn_output.size()}" + # ) + + # attn_output = attn_output.view(bsz, self.num_heads, seq_len, self.head_dim) + # attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned aross GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, seq_len, self.num_heads * self.head_dim).contiguous() + # attn_output = attn_output.reshape(bsz, seq_len, self.num_heads * self.head_dim).contiguous() + + attn_output = attn_output.view(bsz, seq_len, self.num_heads * self.head_dim) attn_output = self.out_proj(attn_output) - return attn_output, attn_probs_reshaped, past_key_value + return attn_output, None, past_key_value class LlamaSmoothquantMLP(nn.Module): @@ -197,10 +313,10 @@ def __init__(self, intermediate_size, hidden_size): self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size) self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size) self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size) - self.down_proj_input_scale = 1.0 + self.register_buffer("down_proj_input_scale", torch.tensor([1.0])) + @staticmethod def pack( - self, mlp_module: LlamaMLP, gate_proj_input_scale: float, up_proj_input_scale: float, @@ -214,7 +330,7 @@ def pack( int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale) int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale) int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale) - self.down_proj_input_scale = down_proj_input_scale + int8_module.down_proj_input_scale = torch.tensor(down_proj_input_scale) return int8_module def forward( @@ -225,7 +341,617 @@ def forward( gate_out = self.gate_proj(hidden_states) up_out = self.up_proj(hidden_states) inter_out = gate_out * up_out - inter_out = inter_out.div_(self.down_proj_input_scale).round().clamp(-128, 127).to(torch.int8) + inter_out = inter_out.div_(self.down_proj_input_scale.item()).round().clamp(-128, 127).to(torch.int8) down_out = self.down_proj(inter_out) down_out = down_out.view(*x_shape[:-1], -1) return down_out + + +class LlamaSmoothquantDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LLamaSmoothquantAttention(config.hidden_size, config.num_attention_heads) + + self.mlp = LlamaSmoothquantMLP(config.intermediate_size, config.hidden_size) + self.input_layernorm = LayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + + self.post_attention_layernorm = LayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + + @staticmethod + def pack( + module: LlamaDecoderLayer, + attn_input_scale: 
float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + q_rotary_output_scale: float, + k_rotary_output_scale: float, + out_input_scale: float, + gate_input_scale: float, + up_input_scale: float, + down_input_scale: float, + ): + config = module.self_attn.config + int8_decoder_layer = LlamaSmoothquantDecoderLayer(config) + int8_decoder_layer.self_attn = LLamaSmoothquantAttention.pack( + module.self_attn, + attn_input_scale, + q_output_scale, + k_output_scale, + v_output_scale, + q_rotary_output_scale, + k_rotary_output_scale, + out_input_scale, + ) + + int8_decoder_layer.mlp = LlamaSmoothquantMLP.pack( + module.mlp, + gate_input_scale, + up_input_scale, + down_input_scale, + ) + + int8_decoder_layer.input_layernorm = LayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + + int8_decoder_layer.post_attention_layernorm = LayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + + return int8_decoder_layer + + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: Tuple[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + rotary_emb=rotary_emb, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + ) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class LlamaApplyRotary(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
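        # (Clarifying note, not in the original file.) This wrapper only re-applies the
        # standard rotary formula, x_embed = x * cos + rotate_half(x) * sin; it exists
        # as a separate nn.Module so the calibration hooks registered in
        # convert_llama_to_smoothquant can record post-rotary activation ranges
        # (q_rotary_output_scale / k_rotary_output_scale).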
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + x_embed = (x * cos) + (rotate_half(x) * sin) + + return x_embed + + +def llama_decoder_layer_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states = self.q_apply_rotary(query_states, cos, sin, position_ids) + key_states = self.k_apply_rotary(key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if 
attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def init_to_get_rotary(config, base=10000, use_elem=False): + """ + This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer + Args: + base : calculation arg + use_elem : activated when using chatglm-based models + """ + config.head_dim_ = config.hidden_size // config.num_attention_heads + if not hasattr(config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = config.rope_scaling.factor if config.rope_scaling is not None else 1.0 + + if hasattr(config, "max_sequence_length"): + max_seq_len = config.max_sequence_length + elif hasattr(config, "max_position_embeddings"): + max_seq_len = config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + try: + ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1)) + assert ntk_alpha >= 1 + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (config.head_dim_ / (config.head_dim_ - 2))) # Base change formula + except: + pass + + n_elem = config.head_dim_ + if use_elem: + n_elem //= 2 + + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + _cos_cached = torch.cos(freqs).to(torch.float) + _sin_cached = torch.sin(freqs).to(torch.float) + return _cos_cached, _sin_cached + + +@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) +def llama_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is 
not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device) + padding_mask = None + else: + if 0 in attention_mask: + padding_mask = attention_mask + else: + padding_mask = None + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if past_key_values_length == 0: + position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + else: + position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1) + position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids + ) + else: + layer_outputs = decoder_layer( + hidden_states, + rotary_emb=(position_cos, position_sin), + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last 
decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +def convert_llama_to_smoothquant( + llama_model, + tokenizer, + dataset_path, + num_samples=512, + seq_len=512, +): + llama_config = llama_model.config + + llama_model.eval() + device = next(llama_model.parameters()).device + # print("model:", llama_model) + act_dict = defaultdict(dict) + + def stat_io_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + if name not in act_dict or "input" not in act_dict[name]: + act_dict[name]["input"] = x.detach().abs().max().item() + else: + act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item()) + if isinstance(y, tuple): + y = y[0] + if name not in act_dict or "output" not in act_dict[name]: + act_dict[name]["output"] = y.detach().abs().max().item() + else: + act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item()) + + for name, m in llama_model.named_modules(): + if isinstance(m, LlamaAttention): + setattr(m, "q_apply_rotary", LlamaApplyRotary()) + setattr(m, "k_apply_rotary", LlamaApplyRotary()) + m.forward = types.MethodType(llama_decoder_layer_forward, m) + + hooks = [] + for name, m in llama_model.named_modules(): + if isinstance(m, LlamaApplyRotary): + hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) + if isinstance(m, torch.nn.Linear): + hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) + + print("Collecting activation scales...") + pbar = tqdm(range(num_samples)) + dataset = load_dataset("json", data_files=dataset_path, split="train") + dataset = dataset.shuffle(seed=42) + for i in pbar: + input_ids = tokenizer( + dataset["rows"][0][i]["row"]["text"], + return_tensors="pt", + max_length=seq_len, + truncation=True, + ).input_ids.to(device) + llama_model(input_ids) + mean_scale = np.mean([v["input"] for v in act_dict.values()]) + pbar.set_description(f"Mean input scale: {mean_scale:.2f}") + for hook in hooks: + hook.remove() + + decoder_layer_scales = [] + + for idx in range(llama_config.num_hidden_layers): + scale_dict = {} + scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127 + scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127 + scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127 + scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127 + + scale_dict["q_rotary_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127 + + scale_dict["k_rotary_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127 + + scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127 + # mlp scales + scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127 + scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127 + scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 + + decoder_layer_scales.append(scale_dict) + + for i, layer in 
enumerate(llama_model.model.layers): + orig_layer = layer + llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i]) + + llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model) + + cos, sin = init_to_get_rotary(llama_config) + llama_model.model.register_buffer("_cos_cached", cos) + llama_model.model.register_buffer("_sin_cached", sin) + return decoder_layer_scales, act_dict + + +# class SmoothLlamaForCausalLM(BaseSmoothForCausalLM): +# def __init__(self, model: PreTrainedModel, quantized: bool = False): +# super().__init__(model, quantized) + +# def quantized( +# self, +# tokenizer, +# dataset_path, +# num_samples=512, +# seq_len=512, +# ): +# llama_model = self.model +# llama_config = llama_model.config + +# llama_model.eval() +# device = next(llama_model.parameters()).device +# # print("model:", llama_model) +# act_dict = defaultdict(dict) + +# def stat_io_hook(m, x, y, name): +# if isinstance(x, tuple): +# x = x[0] +# if name not in act_dict or "input" not in act_dict[name]: +# act_dict[name]["input"] = x.detach().abs().max().item() +# else: +# act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item()) +# if isinstance(y, tuple): +# y = y[0] +# if name not in act_dict or "output" not in act_dict[name]: +# act_dict[name]["output"] = y.detach().abs().max().item() +# else: +# act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item()) + +# for name, m in llama_model.named_modules(): +# if isinstance(m, LlamaAttention): +# setattr(m, "q_apply_rotary", LlamaApplyRotary()) +# setattr(m, "k_apply_rotary", LlamaApplyRotary()) +# m.forward = types.MethodType(llama_decoder_layer_forward, m) + +# hooks = [] +# for name, m in llama_model.named_modules(): +# if isinstance(m, LlamaApplyRotary): +# hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) +# if isinstance(m, torch.nn.Linear): +# hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) + +# print("Collecting activation scales...") +# pbar = tqdm(range(num_samples)) +# dataset = load_dataset("json", data_files=dataset_path, split="train") +# dataset = dataset.shuffle(seed=42) +# for i in pbar: +# input_ids = tokenizer( +# dataset["rows"][0][i]["row"]["text"], +# return_tensors="pt", +# max_length=seq_len, +# truncation=True, +# ).input_ids.to(device) +# llama_model(input_ids) +# mean_scale = np.mean([v["input"] for v in act_dict.values()]) +# pbar.set_description(f"Mean input scale: {mean_scale:.2f}") +# for hook in hooks: +# hook.remove() + +# decoder_layer_scales = [] + +# for idx in range(llama_config.num_hidden_layers): +# scale_dict = {} +# scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127 +# scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127 +# scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127 +# scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127 + +# scale_dict["q_rotary_output_scale"] = ( +# act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127 +# ) + +# scale_dict["k_rotary_output_scale"] = ( +# act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127 +# ) + +# scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127 +# # mlp scales +# scale_dict["gate_input_scale"] = 
act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127 +# scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127 +# scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 + +# decoder_layer_scales.append(scale_dict) + +# for i, layer in enumerate(llama_model.model.layers): +# orig_layer = layer +# llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i]) + +# llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model) + +# cos, sin = init_to_get_rotary(llama_config) +# llama_model.model.register_buffer("_cos_cached", cos) +# llama_model.model.register_buffer("_sin_cached", sin) + +# def make_smooth_model(cls, llama_model): +# super().make_smooth_model() + +# llama_config = llama_model.config + +# for i, layer in enumerate(llama_model.model.layers): +# llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer(llama_config) + +# llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model) +# cos, sin = init_to_get_rotary(llama_config) +# llama_model.model.register_buffer("_cos_cached", cos) +# llama_model.model.register_buffer("_sin_cached", sin) diff --git a/colossalai/inference/quant/smoothquant/smooth.py b/colossalai/inference/quant/smoothquant/smooth.py new file mode 100644 index 000000000000..120e7818a2f4 --- /dev/null +++ b/colossalai/inference/quant/smoothquant/smooth.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from transformers.models.bloom.modeling_bloom import BloomBlock +from transformers.models.opt.modeling_opt import OPTDecoderLayer + + +@torch.no_grad() +def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5): + if not isinstance(fcs, list): + fcs = [fcs] + assert isinstance(ln, nn.LayerNorm) + for fc in fcs: + assert isinstance(fc, nn.Linear) + assert ln.weight.numel() == fc.in_features == act_scales.numel() + + device, dtype = fcs[0].weight.device, fcs[0].weight.dtype + act_scales = act_scales.to(device=device, dtype=dtype) + weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0) + weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5) + + scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype) + + ln.weight.div_(scales) + ln.bias.div_(scales) + + for fc in fcs: + fc.weight.mul_(scales.view(1, -1)) + + +@torch.no_grad() +def smooth_lm(model, scales, alpha=0.5): + for name, module in model.named_modules(): + if isinstance(module, OPTDecoderLayer): + attn_ln = module.self_attn_layer_norm + qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj] + qkv_input_scales = scales[name + ".self_attn.q_proj"] + smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha) + + ffn_ln = module.final_layer_norm + fc1 = module.fc1 + fc1_input_scales = scales[name + ".fc1"] + smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha) + elif isinstance(module, BloomBlock): + attn_ln = module.input_layernorm + qkv = module.self_attention.query_key_value + qkv_input_scales = scales[name + ".self_attention.query_key_value"] + smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha) + + ffn_ln = module.post_attention_layernorm + fc1 = module.mlp.dense_h_to_4h + fc1_input_scales = scales[name + ".mlp.dense_h_to_4h"] + smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha) diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 09f7a5592253..877e914c6de0 100644 --- 
a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -10,6 +10,7 @@ from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd from .rms_norm import rmsnorm_forward from .rotary_embedding_kernel import rotary_embedding_fwd + from .smooth__attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd from .softmax import softmax from .token_attention_kernel import token_attention_fwd @@ -24,6 +25,8 @@ "token_attention_fwd", "gptq_fused_linear_triton", "int8_rotary_embedding_fwd", + "smooth_llama_context_attn_fwd", + "smooth_token_attention_fwd", ] except ImportError: diff --git a/colossalai/kernel/triton/int8_rotary_embedding_kernel.py b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py index dfad8a973ed6..126a52968088 100644 --- a/colossalai/kernel/triton/int8_rotary_embedding_kernel.py +++ b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py @@ -89,29 +89,29 @@ def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale): assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" BLOCK_HEAD = 4 BLOCK_SEQ = 32 - grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) + (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) if head_dim >= 128: - num_warps = 8 + pass else: - num_warps = 4 - - _rotary_kernel[grid]( - q, - input_scale.item(), - output_scale.item(), - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - cos.stride(0), - cos.stride(1), - total_len, - HEAD_NUM=head_num, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - HEAD_DIM=head_dim, - num_warps=num_warps, - num_stages=1, - ) + pass + + # _rotary_kernel[grid]( + # q, + # input_scale.item(), + # output_scale.item(), + # cos, + # sin, + # q.stride(0), + # q.stride(1), + # q.stride(2), + # cos.stride(0), + # cos.stride(1), + # total_len, + # HEAD_NUM=head_num, + # BLOCK_HEAD=BLOCK_HEAD, + # BLOCK_SEQ=BLOCK_SEQ, + # HEAD_DIM=head_dim, + # num_warps=num_warps, + # num_stages=1, + # ) return diff --git a/colossalai/kernel/triton/smooth_attention.py b/colossalai/kernel/triton/smooth_attention.py new file mode 100644 index 000000000000..c726d4b03a8c --- /dev/null +++ b/colossalai/kernel/triton/smooth_attention.py @@ -0,0 +1,649 @@ +import math + +import torch + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + """ + this function is modified from + https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 + """ + + @triton.jit + def _context_flash_attention_kernel( + Q, + K, + V, + q_input_scale, + k_input_scale, + v_input_scale, + pv_output_scale, + sm_scale, + B_Start_Loc, + B_Seqlen, + TMP, + alibi_ptr, + Out, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, + # suggtest set-up 64, 128, 256, 512 + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + batch_id = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + # get batch info + cur_batch_seq_len = 
tl.load(B_Seqlen + batch_id) + cur_batch_start_index = tl.load(B_Start_Loc + batch_id) + block_start_loc = BLOCK_M * start_m + + load_p_ptrs = ( + Q + + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + q = q.to(tl.float32) * q_input_scale + + k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + if alibi_ptr is not None: + alibi_m = tl.load(alibi_ptr + cur_head) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k = tl.load( + k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) + k = k.to(tl.float32) * k_input_scale + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + if alibi_ptr is not None: + alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) + qk -= alibi_loc * alibi_m + + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + + v = v.to(tl.float32) * v_input_scale + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + acc = (acc / pv_output_scale).to(tl.int8) + off_o = ( + (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + @torch.no_grad() + def smooth_llama_context_attn_fwd( + q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len + ): + BLOCK = 32 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk, "context process only supports equal query, key, value length" + assert Lk == Lv, "context process only supports equal query, key, value length" + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / math.sqrt(Lk) + batch, head = b_seq_len.shape[0], q.shape[1] + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + # num_warps = 4 + _context_flash_attention_kernel[grid]( + q, + k, + v, + q_input_scale, + k_input_scale, + v_input_scale, + pv_output_scale, + sm_scale, + b_start_loc, + 
b_seq_len, + tmp, + None, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @triton.jit + def _token_attn_1_kernel( + Q, + K, + q_input_scale, + k_input_scale, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + q = tl.load(Q + off_q + start_mark) + q = q.to(tl.float32) * q_input_scale + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + k = k.to(tl.float32) * k_input_scale + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + @triton.jit + def _token_attn_1_alibi_kernel( + Q, + K, + q_input_scale, + k_input_scale, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + alibi_m = tl.load(alibi + 
current_head) + q = tl.load(Q + off_q + start_mark) + q = q.to(tl.float32) * q_input_scale + + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + k = k.to(tl.float32) * k_input_scale + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + @torch.no_grad() + def token_attn_fwd_1( + q, + k, + attn_out, + q_input_scale, + k_input_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + alibi=None, + ): + BLOCK = 32 + # shape constraints + q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] + assert q_head_dim == k_head_dim + assert k_head_dim in {16, 32, 64, 128} + sm_scale = 1.0 / (k_head_dim**0.5) + + batch, head_num = kv_cache_loc.shape[0], q.shape[1] + + grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK)) + + num_warps = 4 if k_head_dim <= 64 else 8 + num_warps = 2 + + if alibi is not None: + _token_attn_1_alibi_kernel[grid]( + q, + k, + q_input_scale, + k_input_scale, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + else: + _token_attn_1_kernel[grid]( + q, + k, + q_input_scale, + k_input_scale, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @triton.jit + def _token_attn_softmax_fwd( + softmax_logics, + kv_cache_start_loc, + kv_cache_seqlen, + softmax_prob_out, + logics_head_dim_stride, + logics_batch_stride, + prob_head_dim_stride, + prob_batch_stride, + BLOCK_SIZE: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + col_offsets = tl.arange(0, BLOCK_SIZE) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + row = tl.load( + softmax_logics + + current_head * logics_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, + mask=col_offsets < current_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) + + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + + tl.store( + softmax_prob_out + + current_head * prob_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, + softmax_output, + mask=col_offsets < current_batch_seq_len, + ) + return + + 
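# Aside: the _token_attn_softmax_fwd kernel above is a numerically stable softmax applied
# independently to each (head, sequence) row of the attention logits, with positions beyond the
# cached sequence length masked to -inf. A rough PyTorch equivalent for a single batch element is
# sketched below; the function name and tensor layout are illustrative assumptions, not part of
# this patch.

import torch

def token_softmax_reference(logits: torch.Tensor, seq_len: int) -> torch.Tensor:
    # logits: [head_num, max_seq_len]; only the first seq_len columns hold valid scores
    row = logits[:, :seq_len].float()
    row = row - row.max(dim=-1, keepdim=True).values  # subtract the row max for stability
    prob = row.exp()
    return prob / prob.sum(dim=-1, keepdim=True)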
@torch.no_grad() + def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len): + BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len) + batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0] + + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + + _token_attn_softmax_fwd[(batch, head_num)]( + softmax_logics, + kv_cache_start_loc, + kv_cache_seqlen, + softmax_prob_out, + softmax_logics.stride(0), + softmax_logics.stride(1), + softmax_prob_out.stride(0), + softmax_prob_out.stride(1), + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return + + @triton.jit + def _token_attn_2_kernel( + Prob, + V, + attn_out, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + prob_head_dim_stride, + prob_batch_stride, + v_batch_stride, + v_head_stride, + v_head_dim_stride, + attn_out_batch_stride, + attn_out_head_stride, + attn_out_head_dim_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride + p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride + v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride + + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + for start_n in range(0, current_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + p_value = tl.load( + Prob + p_offs + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_loc = tl.load( + kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_value = tl.load( + V + v_offs + v_loc[:, None] * v_batch_stride, + mask=(start_n + offs_n[:, None]) < current_batch_seq_len, + other=0.0, + ) + v_value = v_value.to(tl.float32) * v_input_scale + acc += tl.sum(p_value[:, None] * v_value, 0) + + acc = (acc / pv_output_scale).to(tl.int8) + off_o = ( + current_batch * attn_out_batch_stride + + current_head * attn_out_head_stride + + offs_d * attn_out_head_dim_stride + ) + out_ptrs = attn_out + off_o + tl.store(out_ptrs, acc) + return + + @torch.no_grad() + def token_attn_fwd_2( + prob, + v, + attn_out, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + ): + if triton.__version__ >= "2.1.0": + BLOCK = 128 + else: + BLOCK = 64 + batch, head = kv_cache_loc.shape[0], v.shape[1] + grid = (batch, head) + num_warps = 4 + dim = v.shape[-1] + + _token_attn_2_kernel[grid]( + prob, + v, + attn_out, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + prob.stride(0), + prob.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + attn_out.stride(0), + attn_out.stride(1), + attn_out.stride(2), + HEAD_DIM=dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return 
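# Aside: taken together, token_attn_fwd_1 (dequantized q.k scores), token_attn_softmax_fwd
# (row-wise softmax) and token_attn_fwd_2 (prob.v accumulation, requantized by pv_output_scale)
# implement decode-time attention over the int8 KV cache; the smooth_token_attention_fwd wrapper
# defined next chains the three stages. A compact float reference for a single batch element,
# ignoring the kv_cache_loc indirection, is sketched below; names and layouts are illustrative
# assumptions, not part of this patch.

import torch

def smooth_token_attention_reference(q_int8, k_int8, v_int8, q_scale, k_scale, v_scale, out_scale):
    # q_int8: [head_num, head_dim] query of the current decode step
    # k_int8, v_int8: [seq_len, head_num, head_dim] cached int8 keys and values
    q = q_int8.float() * q_scale
    k = k_int8.float() * k_scale
    v = v_int8.float() * v_scale
    scores = torch.einsum("hd,shd->hs", q, k) / (q.shape[-1] ** 0.5)  # stage 1: scaled dot products
    prob = torch.softmax(scores, dim=-1)  # stage 2: softmax over the cached tokens
    out = torch.einsum("hs,shd->hd", prob, v)  # stage 3: probability-weighted sum of values
    return (out / out_scale).to(torch.int8)  # requantize with the output scale, as the kernel does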
+ + @torch.no_grad() + def smooth_token_attention_fwd( + q, + k, + v, + attn_out, + q_input_scale, + k_input_scale, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=None, + ): + head_num = k.shape[1] + batch_size = kv_cache_seq_len.shape[0] + calcu_shape1 = (batch_size, head_num, k.shape[2]) + total_token_num = k.shape[0] + + att_m_tensor = torch.empty((head_num, total_token_num), dtype=torch.float32, device="cuda") + + token_attn_fwd_1( + q.view(calcu_shape1), + k, + att_m_tensor, + q_input_scale, + k_input_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=alibi, + ) + + prob = torch.empty_like(att_m_tensor) + + token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) + att_m_tensor = None + token_attn_fwd_2( + prob, + v, + attn_out.view(calcu_shape1), + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + ) + + prob = None + + return diff --git a/examples/inference/smoothquant_conversion.py b/examples/inference/smoothquant_conversion.py new file mode 100644 index 000000000000..066e9ea1abae --- /dev/null +++ b/examples/inference/smoothquant_conversion.py @@ -0,0 +1,135 @@ +import argparse +import os + +import torch +from transformers import AutoModelForCausalLM, LlamaTokenizer + +from colossalai.inference.quant.smoothquant.models.llama import convert_llama_to_smoothquant + + +def build_model_and_tokenizer(model_name): + tokenizer = LlamaTokenizer.from_pretrained(model_name, model_max_length=512) + kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"} + model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) + model = model.to(torch.float32) + # config = { + # "architectures": ["LLaMAForCausalLM"], + # "bos_token_id": 0, + # "eos_token_id": 1, + # "hidden_act": "silu", + # "hidden_size": 4096, + # "initializer_range": 0.02, + # "intermediate_size": 11008, + # "max_position_embeddings": 2048, + # "max_sequence_length": 2048, + # "model_type": "llama", + # "num_attention_heads": 32, + # "num_hidden_layers": 2, + # "num_key_value_heads": 32, + # "pad_token_id": -1, + # "pretraining_tp": 1, + # "rms_norm_eps": 1e-06, + # "torch_dtype": "float32", + # "transformers_version": "4.32.1", + # "use_cache": True, + # "vocab_size": 32000, + # } + # config = LlamaConfig(**config) + # model = LlamaForCausalLM(config) + + return model, tokenizer + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-name", type=str, default="facebook/opt-1.3b", help="model name") + parser.add_argument( + "--output-path", + type=str, + default="act_scales/opt-1.3b.pt", + help="where to save the act scales", + ) + parser.add_argument( + "--dataset-path", + type=str, + default="dataset/val.jsonl.zst", + help="location of the calibration dataset, we use the validation set of the Pile dataset", + ) + parser.add_argument("--num-samples", type=int, default=512) + parser.add_argument("--seq-len", type=int, default=512) + args = parser.parse_args() + return args + + +@torch.no_grad() +def main(): + args = parse_args() + model_path = "/home/lcxk/data3/llama-7b-hf" + dataset_path = "/home/lcxk/data3/datasets/cc_news.json" + num_samples = 10 + seq_len = 512 + print("data path", dataset_path) + data_files = {"train": dataset_path} + + # # dataset = load_dataset("/home/lcxk/data3/datasets/cc_news.py", data_files=dataset_path) + # dataset = load_dataset("json", 
data_files=dataset_path) + + # print("text", dataset["train"]["rows"][0][1``]["row"]["text"]) + # # for test in dataset["train"]["rows"]: + # # print(test) + # dataset = dataset.shuffle(seed=42) + # # print("text", dataset["rows"][0]) + + model, tokenizer = build_model_and_tokenizer(model_path) + print("config:", model.config) + if not os.path.exists(dataset_path): + print(f"Cannot find the dataset at {args.dataset_path}") + print("Please download the Pile dataset and put the validation set at the path") + print( + "You can download the validation dataset of the Pile at https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst" + ) + raise FileNotFoundError + + # act_scales = get_act_scales(model, tokenizer, dataset_path, num_samples, seq_len) + + # os.makedirs(os.path.dirname(output_path), exist_ok=True) + # torch.save(act_scales, output_path) + + # act_scales = torch.load(output_path) + # smooth_lm(model, act_scales, 0.5) + # # tokenizer = AutoTokenizer.from_pretrained(model_path) + + if not os.path.exists(dataset_path): + print(f"Cannot find the dataset at (dataset_path)") + print("Please download the Pile dataset and put the validation set at the path") + print( + "You can download the validation dataset of the Pile at https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst" + ) + raise FileNotFoundError + + decoder_layer_scales, raw_scales = convert_llama_to_smoothquant( + model, tokenizer, dataset_path, num_samples=num_samples, seq_len=seq_len + ) + model = model.cuda() + generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True) + + print("decoder layer scales:", decoder_layer_scales) + # output_path = Path(args.output_path) / (Path(args.model_name).name + "-smoothquant.pt") + input_tokens = tokenizer(["New York City"], return_tensors="pt").to(model.device) + + max_batch_size = 1 + max_input_len = 7 + input_tokens = { + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), + "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), + } + + print("input:", input) + # gen_args = {"max_input_"} + out = model.generate(**input_tokens, **generate_kwargs) + text = tokenizer.batch_decode(out) + print("text:", text) + + +if __name__ == "__main__": + main() diff --git a/tests/test_smoothquant/test_llama_attention.py b/tests/test_smoothquant/test_llama_attention.py index 26f35e20c6b2..6149f43ce970 100644 --- a/tests/test_smoothquant/test_llama_attention.py +++ b/tests/test_smoothquant/test_llama_attention.py @@ -55,7 +55,7 @@ def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): reason="triton requires cuda version to be higher than 11.4 or not install torch_int", ) def test_llama_context_attention(): - head_num = 8 + head_num = 2 seq_len = 32 head_dim = 64 dtype = torch.float @@ -63,41 +63,55 @@ def test_llama_context_attention(): smooth_attn = LLamaSmoothquantAttention(head_num * head_dim, head_num) - smooth_attn.q_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8) - smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8) - smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8) - smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8) + smooth_attn.q_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + 
smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + + ones = torch.ones(hidden_size, hidden_size, dtype=torch.float, device="cuda") smooth_attn = smooth_attn.to("cuda") - input = torch.randint(-127, 127, (1, seq_len, head_num * head_dim), dtype=torch.int8, device="cuda") + input = torch.randint(-20, 20, (1, seq_len, head_num * head_dim), dtype=torch.int8, device="cuda") + input_scale = 1 / 20.0 + + output = torch.matmul(input.to(torch.float) * input_scale, ones) + qkv_max_out = torch.max(torch.abs(output)) / 127 + smooth_attn.q_proj.a = torch.tensor(input_scale / qkv_max_out) + smooth_attn.k_proj.a = torch.tensor(input_scale / qkv_max_out) + smooth_attn.v_proj.a = torch.tensor(input_scale / qkv_max_out) q = smooth_attn.q_proj(input) k = smooth_attn.k_proj(input) v = smooth_attn.v_proj(input) - cos_shape = (seq_len, head_dim // 2) cos = torch.ones(cos_shape, dtype=dtype, device="cuda") sin = torch.zeros(cos_shape, dtype=dtype, device="cuda") - - in_scale = torch.tensor([1.0], device="cuda") - out_scale = torch.tensor([1.0], device="cuda") - + in_scale = torch.tensor([qkv_max_out], device="cuda") + out_scale = torch.tensor([(qkv_max_out + 2)], device="cuda") int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale, out_scale) int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale, out_scale) - - q = q.to(torch.float) - k = k.to(torch.float) - v = v.to(torch.float) + q = q.to(torch.float) * out_scale + k = k.to(torch.float) * out_scale + v = v.to(torch.float) * out_scale torch_out = torch_context_attention(q.clone(), k.clone(), v.clone(), 1, seq_len, head_num, head_dim) - torch_out = (torch_out).to(torch.int8).view(-1, seq_len, head_num * head_dim) + + output = torch.matmul(torch_out.view(-1, seq_len, head_num * head_dim), ones) + o_input_scale = torch.max(torch.abs(output)) / 127 + smooth_attn.qk_bmm.a = torch.tensor(1 / qkv_max_out * 1 / o_input_scale) + smooth_attn.pv_bmm.a = torch.tensor(1 / 127.0 * qkv_max_out / o_input_scale) + smooth_attn.out_proj.a = torch.tensor([1.0 / o_input_scale]) + smooth_attn = smooth_attn.to("cuda") + + torch_out = (torch_out * smooth_attn.out_proj.a).to(torch.int8).view(-1, seq_len, head_num * head_dim) torch_out = smooth_attn.out_proj(torch_out) + smooth_out, _, _ = smooth_attn(input, (cos, sin)) smooth_out = smooth_out.to(torch.float) torch_out = torch_out.to(torch.float) assert torch.allclose( - smooth_out.cpu(), torch_out.cpu(), rtol=1e-2, atol=1e-2 + torch_out.cpu(), smooth_out.cpu(), rtol=1e-2, atol=1e-2 ), "outputs from triton and torch are not matched" diff --git a/tests/test_smoothquant/test_llama_mlp.py b/tests/test_smoothquant/test_llama_mlp.py index ec0aaaba0198..236edb10cb7f 100644 --- a/tests/test_smoothquant/test_llama_mlp.py +++ b/tests/test_smoothquant/test_llama_mlp.py @@ -70,7 +70,7 @@ def test_llama_mlp(): x.to(torch.float), ) - smooth_mlp.down_proj_input_scale = max_inter.item() / 127 + smooth_mlp.down_proj_input_scale = torch.tensor(max_inter.item() / 127) smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size) smooth_mlp.up_proj.a = torch.tensor(1 / 127) smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127)) From dc4f5068a03b64c34eb83184e1b7ae16bea949c7 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Mon, 9 Oct 2023 18:01:43 +0800 Subject: [PATCH 2/7] fix attention accuracy --- .../quant/smoothquant/models/llama.py | 25 ++++----- colossalai/kernel/triton/__init__.py | 2 +- .../triton/int8_rotary_embedding_kernel.py | 46 
++++++++-------- .../test_smoothquant/test_llama_attention.py | 53 ++++++++++++------- .../test_sq_rotary_embedding.py | 2 +- 5 files changed, 73 insertions(+), 55 deletions(-) diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py index 831b85e14d5d..b2a0467c1623 100644 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -66,7 +66,7 @@ def __init__( self.register_buffer("v_output_scale", torch.tensor([1.0])) self.register_buffer("q_rotary_output_scale", torch.tensor([1.0])) self.register_buffer("k_rotary_output_scale", torch.tensor([1.0])) - self.register_buffer("qk_output_scale", torch.tensor([1.0])) + # self.register_buffer("qk_output_scale", torch.tensor([1.0])) self.register_buffer("attn_output_scale", torch.tensor([1.0])) @staticmethod @@ -96,14 +96,14 @@ def pack( int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale) int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale) int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale) - # print("qout_scale k out scale:", q_output_scale, k_output_scale) - int8_module.qk_bmm = BMM_S8T_S8N_F32T.from_scale(q_output_scale, k_output_scale) + # # print("qout_scale k out scale:", q_output_scale, k_output_scale) + # int8_module.qk_bmm = BMM_S8T_S8N_F32T.from_scale(q_output_scale, k_output_scale) - # alpha = s_prob * s_v / s_out, where s_prob = 1 / 127 - int8_module.pv_bmm = BMM_S8T_S8N_S8T.from_scale(1.0 / 127, v_output_scale, out_input_scale) + # # alpha = s_prob * s_v / s_out, where s_prob = 1 / 127 + # int8_module.pv_bmm = BMM_S8T_S8N_S8T.from_scale(1.0 / 127, v_output_scale, out_input_scale) - int8_module.qk_output_scale = torch.tensor(q_output_scale * k_output_scale) - int8_module.attn_output_scale = torch.tensor(1.0 / 127 * v_output_scale / out_input_scale) + # int8_module.qk_output_scale = torch.tensor(q_output_scale * k_output_scale) + int8_module.attn_output_scale = torch.tensor(out_input_scale) return int8_module @@ -135,15 +135,15 @@ def forward( query_states.view(-1, self.num_heads, self.head_dim), cos, sin, - self.q_output_scale, - self.q_rotary_output_scale, + self.q_output_scale.item(), + self.q_rotary_output_scale.item(), ) int8_rotary_embedding_fwd( key_states.view(-1, self.num_heads, self.head_dim), cos, sin, - self.k_output_scale, - self.k_rotary_output_scale, + self.k_output_scale.item(), + self.k_rotary_output_scale.item(), ) if past_key_value is None: @@ -153,7 +153,7 @@ def forward( key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) - # attn_output = torch.empty(bsz * seq_len, self.num_heads, self.head_dim, dtype=torch.int8, device="cuda") + # attn_output = torch.empty(bsz * seq_len, self.num_heads, self.head_dim, dtype=torch.float32, device="cuda") attn_output = torch.empty_like(query_states) b_start_loc = torch.arange(start=0, end=bsz * seq_len, step=seq_len, dtype=torch.int, device="cuda") @@ -172,6 +172,7 @@ def forward( b_seq_len, seq_len, ) + if use_cache: past_key_value = ( key_states.view(bsz, seq_len, -1, self.head_dim), diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 877e914c6de0..0278e98dbc5e 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -10,7 +10,7 @@ from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd from .rms_norm 
import rmsnorm_forward from .rotary_embedding_kernel import rotary_embedding_fwd - from .smooth__attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd + from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd from .softmax import softmax from .token_attention_kernel import token_attention_fwd diff --git a/colossalai/kernel/triton/int8_rotary_embedding_kernel.py b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py index 126a52968088..537dd164d1ab 100644 --- a/colossalai/kernel/triton/int8_rotary_embedding_kernel.py +++ b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py @@ -89,29 +89,29 @@ def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale): assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" BLOCK_HEAD = 4 BLOCK_SEQ = 32 - (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) + grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) if head_dim >= 128: - pass + num_warps = 8 else: - pass - - # _rotary_kernel[grid]( - # q, - # input_scale.item(), - # output_scale.item(), - # cos, - # sin, - # q.stride(0), - # q.stride(1), - # q.stride(2), - # cos.stride(0), - # cos.stride(1), - # total_len, - # HEAD_NUM=head_num, - # BLOCK_HEAD=BLOCK_HEAD, - # BLOCK_SEQ=BLOCK_SEQ, - # HEAD_DIM=head_dim, - # num_warps=num_warps, - # num_stages=1, - # ) + num_warps = 4 + + _rotary_kernel[grid]( + q, + input_scale, + output_scale, + cos, + sin, + q.stride(0), + q.stride(1), + q.stride(2), + cos.stride(0), + cos.stride(1), + total_len, + HEAD_NUM=head_num, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SEQ=BLOCK_SEQ, + HEAD_DIM=head_dim, + num_warps=num_warps, + num_stages=1, + ) return diff --git a/tests/test_smoothquant/test_llama_attention.py b/tests/test_smoothquant/test_llama_attention.py index 6149f43ce970..f8c79145c952 100644 --- a/tests/test_smoothquant/test_llama_attention.py +++ b/tests/test_smoothquant/test_llama_attention.py @@ -42,11 +42,10 @@ def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): xq = xq.transpose(1, 2) keys = keys.transpose(1, 2) values = values.transpose(1, 2) - sm_scale = 1 / math.sqrt(head_dim) - scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale - scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float) - + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) + scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq) output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) + return output @@ -67,6 +66,9 @@ def test_llama_context_attention(): smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.out_proj.weight[:, 1:hidden_size] = torch.zeros(hidden_size - 1, device="cuda").to(torch.int8) + + qkv_weight_scale = 1.0 ones = torch.ones(hidden_size, hidden_size, dtype=torch.float, device="cuda") @@ -77,41 +79,56 @@ def test_llama_context_attention(): output = torch.matmul(input.to(torch.float) * input_scale, ones) qkv_max_out = torch.max(torch.abs(output)) / 127 - smooth_attn.q_proj.a = torch.tensor(input_scale / qkv_max_out) - smooth_attn.k_proj.a = torch.tensor(input_scale / qkv_max_out) - smooth_attn.v_proj.a = torch.tensor(input_scale / qkv_max_out) + 
smooth_attn.q_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out) + smooth_attn.k_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out) + smooth_attn.v_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out) q = smooth_attn.q_proj(input) k = smooth_attn.k_proj(input) v = smooth_attn.v_proj(input) + cos_shape = (seq_len, head_dim // 2) cos = torch.ones(cos_shape, dtype=dtype, device="cuda") sin = torch.zeros(cos_shape, dtype=dtype, device="cuda") in_scale = torch.tensor([qkv_max_out], device="cuda") - out_scale = torch.tensor([(qkv_max_out + 2)], device="cuda") - int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale, out_scale) - int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale, out_scale) + out_scale = torch.tensor([qkv_max_out], device="cuda") + int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item()) + int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item()) + q = q.to(torch.float) * out_scale k = k.to(torch.float) * out_scale v = v.to(torch.float) * out_scale torch_out = torch_context_attention(q.clone(), k.clone(), v.clone(), 1, seq_len, head_num, head_dim) + attn_out_max = torch.max(torch.abs(torch_out)) / 127 output = torch.matmul(torch_out.view(-1, seq_len, head_num * head_dim), ones) - o_input_scale = torch.max(torch.abs(output)) / 127 - smooth_attn.qk_bmm.a = torch.tensor(1 / qkv_max_out * 1 / o_input_scale) - smooth_attn.pv_bmm.a = torch.tensor(1 / 127.0 * qkv_max_out / o_input_scale) - smooth_attn.out_proj.a = torch.tensor([1.0 / o_input_scale]) - smooth_attn = smooth_attn.to("cuda") + smooth_attn.q_output_scale = torch.tensor(qkv_max_out) + smooth_attn.k_output_scale = torch.tensor(qkv_max_out) + + smooth_attn.v_output_scale = torch.tensor(qkv_max_out) + smooth_attn.q_rotary_output_scale = torch.tensor(qkv_max_out) + smooth_attn.k_rotary_output_scale = torch.tensor(qkv_max_out) + + smooth_attn.attn_output_scale = torch.tensor(attn_out_max) + smooth_attn.out_proj.a = torch.tensor([attn_out_max]) + + torch_out = ( + (torch_out / smooth_attn.attn_output_scale) + .round() + .clamp(-128, 127) + .to(torch.int8) + .view(-1, seq_len, head_num * head_dim) + ) - torch_out = (torch_out * smooth_attn.out_proj.a).to(torch.int8).view(-1, seq_len, head_num * head_dim) torch_out = smooth_attn.out_proj(torch_out) + torch_out = torch_out.to(torch.float) + smooth_attn = smooth_attn.to("cuda") smooth_out, _, _ = smooth_attn(input, (cos, sin)) smooth_out = smooth_out.to(torch.float) - torch_out = torch_out.to(torch.float) assert torch.allclose( - torch_out.cpu(), smooth_out.cpu(), rtol=1e-2, atol=1e-2 + torch_out.cpu(), smooth_out.cpu(), rtol=1e-1, atol=1e-1 ), "outputs from triton and torch are not matched" diff --git a/tests/test_smoothquant/test_sq_rotary_embedding.py b/tests/test_smoothquant/test_sq_rotary_embedding.py index ee030065d66e..4cc76f00474d 100644 --- a/tests/test_smoothquant/test_sq_rotary_embedding.py +++ b/tests/test_smoothquant/test_sq_rotary_embedding.py @@ -50,7 +50,7 @@ def test_rotary_emb(): x = x / input_scale x = x.to(torch.int8) - int8_rotary_embedding_fwd(x, cos, sin, input_scale, output_scale) + int8_rotary_embedding_fwd(x, cos, sin, input_scale.item(), output_scale.item()) y_triton = x.to(torch.float) * output_scale assert torch.allclose(y_triton, y_torch, atol=2e-1, rtol=1e-2, equal_nan=True) From d732e7e85a1b382a753b338b3551c03b5358aaa1 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 
10 Oct 2023 18:18:42 +0800 Subject: [PATCH 3/7] fix accuracy --- .../quant/smoothquant/models/llama.py | 216 ++++++++---------- colossalai/kernel/triton/smooth_attention.py | 29 +-- 2 files changed, 106 insertions(+), 139 deletions(-) diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py index b2a0467c1623..bee403c91185 100644 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -13,7 +13,6 @@ from datasets import load_dataset from torch import nn from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T -from torch_int.nn.fused import LayerNormQ from tqdm import tqdm from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.configuration_llama import LlamaConfig @@ -24,6 +23,7 @@ LlamaMLP, repeat_kv, rotate_half, + LlamaRotaryEmbedding, ) from transformers.utils import add_start_docstrings_to_model_forward @@ -32,6 +32,7 @@ smooth_llama_context_attn_fwd, smooth_token_attention_fwd, ) +import torch.nn.functional as F from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear @@ -59,16 +60,25 @@ def __init__( self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size) self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size) self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size) - self.out_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size) + self.o_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size) self.register_buffer("q_output_scale", torch.tensor([1.0])) self.register_buffer("k_output_scale", torch.tensor([1.0])) self.register_buffer("v_output_scale", torch.tensor([1.0])) self.register_buffer("q_rotary_output_scale", torch.tensor([1.0])) self.register_buffer("k_rotary_output_scale", torch.tensor([1.0])) - # self.register_buffer("qk_output_scale", torch.tensor([1.0])) - self.register_buffer("attn_output_scale", torch.tensor([1.0])) - + self.register_buffer("out_input_scale", torch.tensor([1.0])) + self.register_buffer("attn_input_scale", torch.tensor([1.0])) + + self._init_rope() + self.num_key_value_heads = num_heads + def _init_rope(self): + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=2048, + base=10000.0, + ) + @staticmethod def pack( module: LlamaAttention, @@ -81,6 +91,8 @@ def pack( out_input_scale: float, ): int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads) + # self.register_buffer("attn_input_scale", torch.tensor([1.0])) + int8_module.attn_input_scale = torch.tensor(attn_input_scale) int8_module.q_output_scale = torch.tensor(q_output_scale) int8_module.k_output_scale = torch.tensor(k_output_scale) @@ -89,21 +101,17 @@ def pack( int8_module.q_rotary_output_scale = torch.tensor(q_rotary_output_scale) int8_module.k_rotary_output_scale = torch.tensor(k_rotary_output_scale) - # q_output_scale = q_output_scale * module.scaling - # module.q_proj.weight *= module.scaling - # module.q_proj.bias *= module.scaling int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, attn_input_scale, q_output_scale) int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale) int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale) int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale) - # # print("qout_scale k out scale:", q_output_scale, k_output_scale) - # int8_module.qk_bmm = BMM_S8T_S8N_F32T.from_scale(q_output_scale, 
k_output_scale) - # # alpha = s_prob * s_v / s_out, where s_prob = 1 / 127 - # int8_module.pv_bmm = BMM_S8T_S8N_S8T.from_scale(1.0 / 127, v_output_scale, out_input_scale) - # int8_module.qk_output_scale = torch.tensor(q_output_scale * k_output_scale) - int8_module.attn_output_scale = torch.tensor(out_input_scale) + # int8_module.q_proj = module.q_proj + # int8_module.k_proj = module.k_proj + # int8_module.v_proj = module.v_proj + # int8_module.o_proj = module.o_proj + int8_module.out_input_scale = torch.tensor(out_input_scale) return int8_module @@ -122,8 +130,8 @@ def forward( use_cache: bool = False, padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, seq_len, _ = hidden_states.size() - # get query proj + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) @@ -147,17 +155,22 @@ def forward( ) if past_key_value is None: - proj_shape = (bsz * seq_len, -1, self.head_dim) + proj_shape = (bsz*q_len, self.num_heads, self.head_dim) query_states = query_states.view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) - # attn_output = torch.empty(bsz * seq_len, self.num_heads, self.head_dim, dtype=torch.float32, device="cuda") + # attn_output = torch.empty(bsz*q_len, self.num_heads, self.head_dim, dtype=torch.float16, device="cuda") attn_output = torch.empty_like(query_states) - b_start_loc = torch.arange(start=0, end=bsz * seq_len, step=seq_len, dtype=torch.int, device="cuda") - b_seq_len = torch.full([bsz], seq_len, dtype=torch.int, device="cuda") + b_start_loc = torch.zeros((bsz,), dtype=torch.int32, device="cuda") + b_seq_len = torch.ones((bsz,), dtype=torch.int32, device="cuda") + + b_seq_len[0] = q_len + + for i in range(1, bsz): + b_start_loc[i] = b_start_loc[i - 1] + b_seq_len[i - 1] smooth_llama_context_attn_fwd( query_states, @@ -167,23 +180,22 @@ def forward( self.q_rotary_output_scale.item(), self.k_rotary_output_scale.item(), self.v_output_scale.item(), - self.attn_output_scale.item(), + self.out_input_scale.item(), b_start_loc, b_seq_len, - seq_len, + q_len, ) - if use_cache: past_key_value = ( - key_states.view(bsz, seq_len, -1, self.head_dim), - value_states.view(bsz, seq_len, -1, self.head_dim), + key_states.view(bsz, q_len, -1, self.head_dim), + value_states.view(bsz, q_len, -1, self.head_dim), ) else: - total_seq_len = past_key_value[0].shape[1] + seq_len - key_states = torch.cat([past_key_value[0], key_states.view(bsz, seq_len, -1, self.head_dim)], dim=1) - value_states = torch.cat([past_key_value[1], value_states.view(bsz, seq_len, -1, self.head_dim)], dim=1) + total_seq_len = past_key_value[0].shape[1] + q_len + key_states = torch.cat([past_key_value[0], key_states.view(bsz, q_len, -1, self.head_dim)], dim=1) + value_states = torch.cat([past_key_value[1], value_states.view(bsz, q_len, -1, self.head_dim)], dim=1) - proj_shape = (bsz * seq_len, -1, self.head_dim) + proj_shape = (bsz * q_len, -1, self.head_dim) kv_shape = (bsz * total_seq_len, -1, self.head_dim) query_states = query_states.view(*proj_shape) key_states = key_states.view(*kv_shape) @@ -193,7 +205,7 @@ def forward( b_start_loc = torch.arange( start=0, end=bsz * total_seq_len, step=total_seq_len, dtype=torch.int, device="cuda" ) - b_seq_len = torch.full([bsz], seq_len, dtype=torch.int, device="cuda") * total_seq_len + b_seq_len = torch.full([bsz], q_len, dtype=torch.int, 
device="cuda") * total_seq_len block_loc = torch.arange(total_seq_len, dtype=torch.int, device="cuda").expand(bsz, -1) smooth_token_attention_fwd( query_states, @@ -214,99 +226,41 @@ def forward( key_states.view(bsz, total_seq_len, -1, self.head_dim), value_states.view(bsz, total_seq_len, -1, self.head_dim), ) - # if use_cache: - # past_key_value = (key_states, value_states) - - # if use_cache: - # if past_key_value is not None: - # # reuse k, v, self_attention - # key_states = self._shape(key_states, -1, bsz) - # value_states = self._shape(value_states, -1, bsz) - # key_states = torch.cat([past_key_value[0], key_states], dim=2) - # value_states = torch.cat([past_key_value[1], value_states], dim=2) - # else: - # # self_attention - # key_states = self._shape(key_states, -1, bsz) - # value_states = self._shape(value_states, -1, bsz) - - # if use_cache: - # if past_key_value is not None: - # # reuse k, v, self_attention - # key_states = self._shape(key_states, -1, bsz) - # value_states = self._shape(value_states, -1, bsz) - # key_states = torch.cat([past_key_value[0], key_states], dim=2) - # value_states = torch.cat([past_key_value[1], value_states], dim=2) - # else: - # # self_attention - # key_states = self._shape(key_states, -1, bsz) - # value_states = self._shape(value_states, -1, bsz) - - # past_key_value = (key_states, value_states) - - # proj_shape = (bsz * self.num_heads, -1, self.head_dim) - - # query_states = self._shape(query_states, seq_len, bsz).view(*proj_shape) - # key_states = key_states.view(*proj_shape) - # value_states = value_states.view(*proj_shape) - - # src_len = key_states.size(1) - # print("q states:", query_states.shape, query_states.device, query_states.is_contiguous(), query_states.dtype) - # print("key states:", key_states.shape, key_states.device, key_states.is_contiguous(), key_states.dtype) - - # attn_weights = self.qk_bmm(query_states, key_states) - - # if attn_weights.size() != (bsz * self.num_heads, seq_len, src_len): - # raise ValueError( - # f"Attention weights should be of size {(bsz * self.num_heads, seq_len, src_len)}, but is" - # f" {attn_weights.size()}" - # ) - - # if attention_mask is not None: - # if attention_mask.size() != (bsz, 1, seq_len, src_len): - # raise ValueError( - # f"Attention mask should be of size {(bsz, 1, seq_len, src_len)}, but is {attention_mask.size()}" - # ) - # attn_weights = attn_weights.view(bsz, self.num_heads, seq_len, src_len) + attention_mask - # attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - # attn_weights = attn_weights.view(bsz * self.num_heads, seq_len, src_len) - - # attn_probs = nn.functional.softmax(attn_weights, dim=-1) - - # if output_attentions: - # # this operation is a bit awkward, but it's required to - # # make sure that attn_weights keeps its gradient. 
- # # In order to do so, attn_weights have to be reshaped - # # twice and have to be reused in the following - # attn_probs_reshaped = attn_probs.view(bsz, self.num_heads, seq_len, src_len) - # attn_probs = attn_probs_reshaped.view(bsz * self.num_heads, seq_len, src_len) - # else: - # attn_probs_reshaped = None - - # # (A_row V_row)_row = (A_row V_col ^T)_row - # attn_probs.mul_(127).round_() - # attn_probs = attn_probs.to(torch.int8) - - # value_states = value_states.transpose(1, 2).contiguous() - # attn_output = self.pv_bmm(attn_probs, value_states) - - # if attn_output.size() != (bsz * self.num_heads, seq_len, self.head_dim): - # raise ValueError( - # f"`attn_output` should be of size {(bsz, self.num_heads, seq_len, self.head_dim)}, but is" - # f" {attn_output.size()}" - # ) - - # attn_output = attn_output.view(bsz, self.num_heads, seq_len, self.head_dim) - # attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. - # attn_output = attn_output.reshape(bsz, seq_len, self.num_heads * self.head_dim).contiguous() - - attn_output = attn_output.view(bsz, seq_len, self.num_heads * self.head_dim) - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value + attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaLayerNormQ(torch.nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.input_scale = 1.0 + self.variance_epsilon = eps + self.register_buffer('weight', torch.ones(dim, dtype=torch.float32)) + + def forward(self, x): + + input_dtype = x.dtype + hidden_states = x.to(torch.float32) + ln_output_fp = torch.nn.functional.layer_norm( + x, x.shape[-1:], self.weight, None, self.variance_epsilon) + ln_output_int8 = ln_output_fp.round().clamp(-128, 127).to(torch.int8) + return ln_output_int8 + + + @staticmethod + def from_float(module: torch.nn.LayerNorm, output_scale: float): + assert module.weight.shape[0] == module.weight.numel() + # assert module.bias.shape[0] == module.bias.numel() + q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon ) + q_module.weight = module.weight / output_scale + # q_module.bias = module.bias / output_scale + return q_module class LlamaSmoothquantMLP(nn.Module): def __init__(self, intermediate_size, hidden_size): @@ -355,9 +309,9 @@ def __init__(self, config: LlamaConfig): self.self_attn = LLamaSmoothquantAttention(config.hidden_size, config.num_attention_heads) self.mlp = LlamaSmoothquantMLP(config.intermediate_size, config.hidden_size) - self.input_layernorm = LayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps) @staticmethod def pack( @@ -375,6 +329,8 @@ def pack( ): config = module.self_attn.config int8_decoder_layer = LlamaSmoothquantDecoderLayer(config) + + int8_decoder_layer.input_layernorm = LlamaLayerNormQ.from_float(module.input_layernorm, attn_input_scale) int8_decoder_layer.self_attn = LLamaSmoothquantAttention.pack( module.self_attn, attn_input_scale, @@ -386,6 +342,13 @@ def 
pack( out_input_scale, ) + + # int8_decoder_layer.input_layernorm = module.input_layernorm + # int8_decoder_layer.self_attn = module.self_attn + + + int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float(module.post_attention_layernorm, gate_input_scale) + int8_decoder_layer.mlp = LlamaSmoothquantMLP.pack( module.mlp, gate_input_scale, @@ -393,9 +356,9 @@ def pack( down_input_scale, ) - int8_decoder_layer.input_layernorm = LayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + # int8_decoder_layer.post_attention_layernorm = module.post_attention_layernorm + # int8_decoder_layer.mlp = module.mlp - int8_decoder_layer.post_attention_layernorm = LayerNormQ(config.hidden_size, eps=config.rms_norm_eps) return int8_decoder_layer @@ -439,6 +402,7 @@ def forward( use_cache=use_cache, padding_mask=padding_mask, ) + hidden_states = residual + hidden_states # Fully Connected residual = hidden_states diff --git a/colossalai/kernel/triton/smooth_attention.py b/colossalai/kernel/triton/smooth_attention.py index c726d4b03a8c..ee0df6a74eaa 100644 --- a/colossalai/kernel/triton/smooth_attention.py +++ b/colossalai/kernel/triton/smooth_attention.py @@ -73,7 +73,7 @@ def _context_flash_attention_kernel( + offs_d[None, :] * stride_qd ) q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - q = q.to(tl.float32) * q_input_scale + q = q.to(tl.float16) * q_input_scale.to(tl.float16) k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd @@ -95,7 +95,7 @@ def _context_flash_attention_kernel( mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0, ) - k = k.to(tl.float32) * k_input_scale + k = k.to(tl.float16) * k_input_scale.to(tl.float16) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -131,13 +131,13 @@ def _context_flash_attention_kernel( other=0.0, ) - v = v.to(tl.float32) * v_input_scale + v = v.to(tl.float16) * v_input_scale.to(tl.float16) p = p.to(v.dtype) acc += tl.dot(p, v) # update m_i and l_i l_i = l_i_new m_i = m_i_new - acc = (acc / pv_output_scale).to(tl.int8) + acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8) off_o = ( (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od ) @@ -145,24 +145,27 @@ def _context_flash_attention_kernel( tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) return + + @torch.no_grad() def smooth_llama_context_attn_fwd( q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len ): - BLOCK = 32 + + BLOCK = 128 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk, "context process only supports equal query, key, value length" assert Lk == Lv, "context process only supports equal query, key, value length" assert Lk in {16, 32, 64, 128} - + BLOCK_N = 128 sm_scale = 1.0 / math.sqrt(Lk) batch, head = b_seq_len.shape[0], q.shape[1] grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 - # num_warps = 4 + _context_flash_attention_kernel[grid]( q, k, @@ -245,7 +248,7 @@ def _token_attn_1_kernel( for start_mark in range(0, block_mask, 1): q = tl.load(Q + off_q + start_mark) - q = q.to(tl.float32) * q_input_scale + q = q.to(tl.float16) * q_input_scale.to(tl.float16) offs_n_new = 
current_batch_start_index + offs_n k_loc = tl.load( kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, @@ -254,7 +257,7 @@ def _token_attn_1_kernel( ) off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - k = k.to(tl.float32) * k_input_scale + k = k.to(tl.float16) * k_input_scale.to(tl.float16) att_value = tl.sum(q[None, :] * k, 1) att_value *= sm_scale off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride @@ -308,7 +311,7 @@ def _token_attn_1_alibi_kernel( for start_mark in range(0, block_mask, 1): alibi_m = tl.load(alibi + current_head) q = tl.load(Q + off_q + start_mark) - q = q.to(tl.float32) * q_input_scale + q = q.to(tl.float16) * q_input_scale.to(tl.float16) offs_n_new = current_batch_start_index + offs_n k_loc = tl.load( @@ -318,7 +321,7 @@ def _token_attn_1_alibi_kernel( ) off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - k = k.to(tl.float32) * k_input_scale + k = k.to(tl.float16) * k_input_scale.to(tl.float16) att_value = tl.sum(q[None, :] * k, 1) att_value *= sm_scale att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) @@ -531,10 +534,10 @@ def _token_attn_2_kernel( mask=(start_n + offs_n[:, None]) < current_batch_seq_len, other=0.0, ) - v_value = v_value.to(tl.float32) * v_input_scale + v_value = v_value.to(tl.float16) * v_input_scale.to(tl.float16) acc += tl.sum(p_value[:, None] * v_value, 0) - acc = (acc / pv_output_scale).to(tl.int8) + acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8) off_o = ( current_batch * attn_out_batch_stride + current_head * attn_out_head_stride From ca98077b738345c537cf1dadad946a9f6a114399 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 11 Oct 2023 17:57:48 +0800 Subject: [PATCH 4/7] add kv cache and save pretrained --- .../quant/smoothquant/calibration.py | 11 +- .../quant/smoothquant/models/base_model.py | 136 +++-- .../quant/smoothquant/models/llama.py | 510 ++++++++---------- 3 files changed, 318 insertions(+), 339 deletions(-) diff --git a/colossalai/inference/quant/smoothquant/calibration.py b/colossalai/inference/quant/smoothquant/calibration.py index 0b61beed6733..66ac49826592 100644 --- a/colossalai/inference/quant/smoothquant/calibration.py +++ b/colossalai/inference/quant/smoothquant/calibration.py @@ -1,3 +1,5 @@ +# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ + import functools import torch @@ -29,21 +31,14 @@ def stat_input_hook(m, x, y, name): for name, m in model.named_modules(): if isinstance(m, nn.Linear): hooks.append(m.register_forward_hook(functools.partial(stat_input_hook, name=name))) - # print("data path", dataset_path) - # dataset = load_dataset("csv", data_files=dataset_path, split="train") - # dataset = dataset.shuffle(seed=42) - # dataset = load_dataset("/home/lcxk/data3/datasets/cc_news.py", data_files=dataset_path) dataset = load_dataset("json", data_files=dataset_path) print("text", dataset["train"]["rows"][0][1]["row"]["text"]) - # for test in dataset["train"]["rows"]: - # print(test) + dataset = dataset.shuffle(seed=42) - # print("text", dataset["rows"][0]) for i in tqdm(range(num_samples)): - # print("text", dataset[i]) input_ids = tokenizer( dataset["train"]["rows"][0][i]["row"]["text"], return_tensors="pt", 
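Note on the calibration hunk above: it drives the model over a handful of JSON-loaded text samples while forward hooks record the absolute-max activation entering every nn.Linear; later in this series those maxima are divided by 127 to become int8 quantization scales. A minimal, self-contained sketch of that hook pattern follows — the helper name and the toy model are illustrative only, not part of the patch:

import functools

import torch
import torch.nn as nn


def collect_input_abs_max(model, sample_batches):
    # Record the absolute-max input activation seen by every nn.Linear.
    # Dividing each value by 127 afterwards yields a per-layer int8 scale.
    stats = {}

    def hook(module, inputs, output, name):
        x = inputs[0] if isinstance(inputs, tuple) else inputs
        stats[name] = max(stats.get(name, 0.0), x.detach().abs().max().item())

    handles = [
        m.register_forward_hook(functools.partial(hook, name=name))
        for name, m in model.named_modules()
        if isinstance(m, nn.Linear)
    ]
    with torch.inference_mode():
        for batch in sample_batches:
            model(batch)
    for h in handles:
        h.remove()
    return stats


# Toy usage: two Linear layers and three random calibration batches.
toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
scales = collect_input_abs_max(toy, [torch.randn(4, 16) for _ in range(3)])
print({name: value / 127 for name, value in scales.items()})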
diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py index 6aec56d557ce..326c3df6e038 100644 --- a/colossalai/inference/quant/smoothquant/models/base_model.py +++ b/colossalai/inference/quant/smoothquant/models/base_model.py @@ -1,6 +1,9 @@ +# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ + import os +import types +import warnings from abc import abstractmethod -from logging import getLogger from os.path import isdir, isfile, join from typing import Dict, List, Optional, Union @@ -15,8 +18,8 @@ from transformers.utils.generic import ContextManagers from transformers.utils.hub import PushToHubMixin, cached_file -logger = getLogger(__name__) - +from ....tensor_parallel.batch_infer_state import BatchInferState +from ....tensor_parallel.kvcache_manager import MemoryManager CPU = device("cpu") CUDA_0 = device("cuda:0") @@ -67,6 +70,8 @@ def simple_dispatch_model(model, device_map): class BaseSmoothForCausalLM(nn.Module, PushToHubMixin): + layer_type: str = None + def __init__(self, model: PreTrainedModel, quantized: bool = False): super().__init__() @@ -74,11 +79,61 @@ def __init__(self, model: PreTrainedModel, quantized: bool = False): self.model_type = self.model.config.model_type self._quantized = quantized self.config = self.model.config + self.cache_manager = None + self.max_total_token_num = 0 @property def quantized(self): return self._quantized + def init_cache_manager(self, max_total_token_num=2048): + if self.config.model_type == "llama": + head_num = self.config.num_key_value_heads + layer_num = self.config.num_hidden_layers + head_dim = self.config.hidden_size // head_num + + self.cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num) + self.max_total_token_num = max_total_token_num + + def init_batch_state(self, max_output_len=256, **kwargs): + input_ids = kwargs["input_ids"] + batch_size = len(input_ids) + + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + start_index = 0 + max_len_in_batch = -1 + + for i in range(batch_size): + seq_len = len(input_ids[i]) + seq_lengths[i] = seq_len + seq_start_indexes[i] = start_index + start_index += seq_len + max_len_in_batch = seq_len if seq_len > max_len_in_batch else max_len_in_batch + + if "max_total_token_num" in kwargs.keys(): + max_total_token_num = kwargs["max_total_token_num"] + self.init_cache_manager(max_total_token_num) + + if "max_new_tokens" in kwargs.keys(): + max_output_len = kwargs["max_new_tokens"] + + if batch_size * (max_len_in_batch + max_output_len) > self.max_total_token_num: + max_total_token_num = batch_size * (max_len_in_batch + max_output_len) + warnings.warn(f"reset max tokens to {max_total_token_num}") + self.init_cache_manager(max_total_token_num) + + block_loc = torch.empty((batch_size, max_len_in_batch + max_output_len), dtype=torch.long, device="cuda") + batch_infer_state = BatchInferState(batch_size, max_len_in_batch) + batch_infer_state.seq_len = seq_lengths.to("cuda") + batch_infer_state.start_loc = seq_start_indexes.to("cuda") + batch_infer_state.block_loc = block_loc + batch_infer_state.decode_layer_id = 0 + batch_infer_state.past_key_values_len = 0 + batch_infer_state.is_context_stage = True + batch_infer_state.set_cache_manager(self.cache_manager) + return batch_infer_state + @abstractmethod @torch.inference_mode() def quantize( @@ -97,6 +152,13 @@ def forward(self, 
*args, **kwargs): def generate(self, **kwargs): """shortcut for model.generate""" + + batch_infer_state = self.init_batch_state(**kwargs) + if self.config.model_type == "llama": + setattr(self.model.model, "infer_state", batch_infer_state) + + batch_infer_state.is_context_stage = True + with torch.inference_mode(): return self.model.generate(**kwargs) @@ -104,14 +166,10 @@ def prepare_inputs_for_generation(self, *args, **kwargs): """shortcut for model.prepare_inputs_for_generation""" return self.model.prepare_inputs_for_generation(*args, **kwargs) - @classmethod - def make_smooth_model(cls, model): - raise NotImplementedError("not implememented smooth model") - def save_quantized( self, save_dir: str, - model_file_base_name: str = None, + model_basename: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None, ): @@ -123,7 +181,7 @@ def save_quantized( self.model.to(CPU) - model_base_name = model_file_base_name or f"smooth-" + model_base_name = model_basename # or f"smooth-" if use_safetensors: model_save_name = model_base_name + ".safetensors" state_dict = self.model.state_dict() @@ -133,7 +191,7 @@ def save_quantized( elif not isinstance(safetensors_metadata, dict): raise TypeError("safetensors_metadata must be a dictionary.") else: - logger.debug(f"Received safetensors_metadata: {safetensors_metadata}") + print(f"Received safetensors_metadata: {safetensors_metadata}") new_safetensors_metadata = {} converted_keys = False for key, value in safetensors_metadata.items(): @@ -147,13 +205,13 @@ def save_quantized( f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}" ) if new_key in new_safetensors_metadata: - logger.warning( + print( f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting." ) new_safetensors_metadata[new_key] = new_value safetensors_metadata = new_safetensors_metadata if converted_keys: - logger.debug( + print( f"One or more safetensors_metadata keys or values had to be converted to str(). 
Final safetensors_metadata: {safetensors_metadata}" ) @@ -176,7 +234,7 @@ def save_pretrained( **kwargs, ): """alias of save_quantized""" - logger.warning("you are using save_pretrained, which will re-direct to save_quantized.") + warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.") self.save_quantized(save_dir, use_safetensors, safetensors_metadata) @classmethod @@ -267,7 +325,7 @@ def skip(*args, **kwargs): model.seqlen = model_config[key] break else: - logger.warning("can't get model's sequence length from model config, will set to 4096.") + warnings.warn("can't get model's sequence length from model config, will set to 4096.") model.seqlen = 4096 model.eval() @@ -277,12 +335,12 @@ def skip(*args, **kwargs): def from_quantized( cls, model_name_or_path: Optional[str], + model_basename: Optional[str] = None, device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None, max_memory: Optional[dict] = None, device: Optional[Union[str, int]] = None, low_cpu_mem_usage: bool = False, torch_dtype: Optional[torch.dtype] = None, - model_basename: Optional[str] = None, use_safetensors: bool = False, trust_remote_code: bool = False, **kwargs, @@ -360,46 +418,30 @@ def skip(*args, **kwargs): init_contexts = [no_init_weights()] if low_cpu_mem_usage: - init_contexts.append(accelerate.init_empty_weights(include_buffers=False)) + init_contexts.append(accelerate.init_empty_weights(include_buffers=True)) with ContextManagers(init_contexts): model = AutoModelForCausalLM.from_config( config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype ) - cls.make_smooth_model(model) - model.tie_weights() + if config.model_type == "llama": + from .llama import LlamaSmoothquantDecoderLayer, init_to_get_rotary, llama_model_forward - # == step3: load checkpoint and dispatch == # - if isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: - raise ValueError( - "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or " - "'sequential'." 
- ) - if isinstance(device_map, dict): - max_memory = None - else: - if device is None and not device_map and not max_memory: - device_map = "auto" - if device is not None: - device = torch.device(device) - if not max_memory and not device_map: - device_map = {"": device.index if device.type == "cuda" else device.type} - if not isinstance(device_map, dict) and device_map != "sequential": - max_memory = accelerate.utils.get_balanced_memory( - model=model, - max_memory=max_memory, - no_split_module_classes=[cls.layer_type], - low_zero=(device_map == "balanced_low_0"), - ) - if not isinstance(device_map, dict): - device_map = accelerate.infer_auto_device_map( - model, max_memory=max_memory, no_split_module_classes=[cls.layer_type] - ) + llama_config = model.config + + for i, layer in enumerate(model.model.layers): + model.model.layers[i] = LlamaSmoothquantDecoderLayer(llama_config) + + model.model.forward = types.MethodType(llama_model_forward, model.model) + cos, sin = init_to_get_rotary(llama_config) + model.model.register_buffer("_cos_cached", cos) + model.model.register_buffer("_sin_cached", sin) + model.tie_weights() accelerate.utils.modeling.load_checkpoint_in_model( - model, checkpoint=model_save_name, device_map=device_map, offload_state_dict=True, offload_buffers=True + model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True ) - model = simple_dispatch_model(model, device_map) + model = model.to("cuda") # == step4: set seqlen == # model_config = model.config.to_dict() @@ -410,7 +452,7 @@ def skip(*args, **kwargs): model.seqlen = model_config[key] break else: - logger.warning("can't get model's sequence length from model config, will set to 4096.") + warnings.warn("can't get model's sequence length from model config, will set to 4096.") model.seqlen = 4096 return cls( diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py index bee403c91185..9bfacd1c6a3b 100644 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -10,10 +10,12 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from datasets import load_dataset from torch import nn from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T from tqdm import tqdm +from transformers import PreTrainedModel from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import ( @@ -21,19 +23,21 @@ LlamaAttention, LlamaDecoderLayer, LlamaMLP, + LlamaRotaryEmbedding, repeat_kv, rotate_half, - LlamaRotaryEmbedding, ) from transformers.utils import add_start_docstrings_to_model_forward from colossalai.kernel.triton import ( + copy_kv_cache_to_dest, int8_rotary_embedding_fwd, smooth_llama_context_attn_fwd, smooth_token_attention_fwd, ) -import torch.nn.functional as F +from ....tensor_parallel.batch_infer_state import BatchInferState +from .base_model import BaseSmoothForCausalLM from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear @@ -72,13 +76,14 @@ def __init__( self._init_rope() self.num_key_value_heads = num_heads + def _init_rope(self): self.rotary_emb = LlamaRotaryEmbedding( self.head_dim, max_position_embeddings=2048, base=10000.0, ) - + @staticmethod def pack( module: LlamaAttention, @@ -92,26 +97,25 @@ def pack( ): int8_module = LLamaSmoothquantAttention(module.hidden_size, 
module.num_heads) # self.register_buffer("attn_input_scale", torch.tensor([1.0])) - int8_module.attn_input_scale = torch.tensor(attn_input_scale) + int8_module.attn_input_scale = torch.tensor([attn_input_scale]) - int8_module.q_output_scale = torch.tensor(q_output_scale) - int8_module.k_output_scale = torch.tensor(k_output_scale) - int8_module.v_output_scale = torch.tensor(v_output_scale) + int8_module.q_output_scale = torch.tensor([q_output_scale]) + int8_module.k_output_scale = torch.tensor([k_output_scale]) + int8_module.v_output_scale = torch.tensor([v_output_scale]) - int8_module.q_rotary_output_scale = torch.tensor(q_rotary_output_scale) - int8_module.k_rotary_output_scale = torch.tensor(k_rotary_output_scale) + int8_module.q_rotary_output_scale = torch.tensor([q_rotary_output_scale]) + int8_module.k_rotary_output_scale = torch.tensor([k_rotary_output_scale]) int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, attn_input_scale, q_output_scale) int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale) int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale) int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale) - # int8_module.q_proj = module.q_proj # int8_module.k_proj = module.k_proj # int8_module.v_proj = module.v_proj # int8_module.o_proj = module.o_proj - int8_module.out_input_scale = torch.tensor(out_input_scale) + int8_module.out_input_scale = torch.tensor([out_input_scale]) return int8_module @@ -129,6 +133,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, padding_mask: Optional[torch.LongTensor] = None, + infer_state: Optional[BatchInferState] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -154,23 +159,34 @@ def forward( self.k_rotary_output_scale.item(), ) - if past_key_value is None: - proj_shape = (bsz*q_len, self.num_heads, self.head_dim) + # NOTE might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if infer_state.decode_layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_len # seq_len - query_states = query_states.view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + return - # attn_output = torch.empty(bsz*q_len, self.num_heads, self.head_dim, dtype=torch.float16, device="cuda") - attn_output = torch.empty_like(query_states) + query_states = query_states.view(-1, self.num_heads, self.head_dim) + key_states = key_states.view(-1, self.num_heads, self.head_dim) + value_states = value_states.view(-1, self.num_heads, self.head_dim) - b_start_loc = torch.zeros((bsz,), dtype=torch.int32, device="cuda") - b_seq_len = torch.ones((bsz,), dtype=torch.int32, device="cuda") + if infer_state.is_context_stage: + # first token generation - b_seq_len[0] = q_len + # copy key and value calculated in current step to memory manager + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.context_mem_index, + infer_state.cache_manager, + ) - for i in range(1, bsz): 
- b_start_loc[i] = b_start_loc[i - 1] + b_seq_len[i - 1] + attn_output = torch.empty_like(query_states) smooth_llama_context_attn_fwd( query_states, @@ -181,51 +197,50 @@ def forward( self.k_rotary_output_scale.item(), self.v_output_scale.item(), self.out_input_scale.item(), - b_start_loc, - b_seq_len, + infer_state.start_loc, + infer_state.seq_len, q_len, ) - if use_cache: - past_key_value = ( - key_states.view(bsz, q_len, -1, self.head_dim), - value_states.view(bsz, q_len, -1, self.head_dim), - ) + else: - total_seq_len = past_key_value[0].shape[1] + q_len - key_states = torch.cat([past_key_value[0], key_states.view(bsz, q_len, -1, self.head_dim)], dim=1) - value_states = torch.cat([past_key_value[1], value_states.view(bsz, q_len, -1, self.head_dim)], dim=1) - - proj_shape = (bsz * q_len, -1, self.head_dim) - kv_shape = (bsz * total_seq_len, -1, self.head_dim) - query_states = query_states.view(*proj_shape) - key_states = key_states.view(*kv_shape) - value_states = value_states.view(*kv_shape) + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(key_states) + cache_v.copy_(value_states) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) + + # (batch_size, seqlen, nheads, headdim) attn_output = torch.empty_like(query_states) - b_start_loc = torch.arange( - start=0, end=bsz * total_seq_len, step=total_seq_len, dtype=torch.int, device="cuda" - ) - b_seq_len = torch.full([bsz], q_len, dtype=torch.int, device="cuda") * total_seq_len - block_loc = torch.arange(total_seq_len, dtype=torch.int, device="cuda").expand(bsz, -1) smooth_token_attention_fwd( query_states, - key_states, - value_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output, self.q_rotary_output_scale.item(), self.k_rotary_output_scale.item(), self.v_output_scale.item(), - self.attn_output_scale.item(), - block_loc, - b_start_loc, - b_seq_len, - total_seq_len, + self.out_input_scale.item(), + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, ) - if use_cache: - past_key_value = ( - key_states.view(bsz, total_seq_len, -1, self.head_dim), - value_states.view(bsz, total_seq_len, -1, self.head_dim), - ) attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) attn_output = self.o_proj(attn_output) @@ -241,27 +256,25 @@ def __init__(self, dim, eps=1e-5): super().__init__() self.input_scale = 1.0 self.variance_epsilon = eps - self.register_buffer('weight', torch.ones(dim, dtype=torch.float32)) + self.register_buffer("weight", torch.ones(dim, dtype=torch.float32)) def forward(self, x): - - input_dtype = x.dtype - hidden_states = x.to(torch.float32) - ln_output_fp = torch.nn.functional.layer_norm( - x, x.shape[-1:], self.weight, None, self.variance_epsilon) + x.dtype + x.to(torch.float32) + 
ln_output_fp = torch.nn.functional.layer_norm(x, x.shape[-1:], self.weight, None, self.variance_epsilon) ln_output_int8 = ln_output_fp.round().clamp(-128, 127).to(torch.int8) return ln_output_int8 - @staticmethod def from_float(module: torch.nn.LayerNorm, output_scale: float): assert module.weight.shape[0] == module.weight.numel() # assert module.bias.shape[0] == module.bias.numel() - q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon ) + q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon) q_module.weight = module.weight / output_scale # q_module.bias = module.bias / output_scale return q_module + class LlamaSmoothquantMLP(nn.Module): def __init__(self, intermediate_size, hidden_size): super().__init__() @@ -285,7 +298,7 @@ def pack( int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale) int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale) int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale) - int8_module.down_proj_input_scale = torch.tensor(down_proj_input_scale) + int8_module.down_proj_input_scale = torch.tensor([down_proj_input_scale]) return int8_module def forward( @@ -330,7 +343,7 @@ def pack( config = module.self_attn.config int8_decoder_layer = LlamaSmoothquantDecoderLayer(config) - int8_decoder_layer.input_layernorm = LlamaLayerNormQ.from_float(module.input_layernorm, attn_input_scale) + int8_decoder_layer.input_layernorm = LlamaLayerNormQ.from_float(module.input_layernorm, attn_input_scale) int8_decoder_layer.self_attn = LLamaSmoothquantAttention.pack( module.self_attn, attn_input_scale, @@ -342,12 +355,12 @@ def pack( out_input_scale, ) - # int8_decoder_layer.input_layernorm = module.input_layernorm # int8_decoder_layer.self_attn = module.self_attn - - int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float(module.post_attention_layernorm, gate_input_scale) + int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float( + module.post_attention_layernorm, gate_input_scale + ) int8_decoder_layer.mlp = LlamaSmoothquantMLP.pack( module.mlp, @@ -359,7 +372,6 @@ def pack( # int8_decoder_layer.post_attention_layernorm = module.post_attention_layernorm # int8_decoder_layer.mlp = module.mlp - return int8_decoder_layer def forward( @@ -372,6 +384,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, padding_mask: Optional[torch.LongTensor] = None, + infer_state: Optional[BatchInferState] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -401,6 +414,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, padding_mask=padding_mask, + infer_state=infer_state, ) hidden_states = residual + hidden_states @@ -610,10 +624,44 @@ def llama_model_forward( seq_length_with_past = seq_length past_key_values_length = 0 + infer_state = self.infer_state + if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + # NOT READY FOR PRIME TIME + # dummy but work, revise it + past_key_values_length = infer_state.cache_manager.past_key_values_length + # past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + # NOTE: differentiate with prefill stage + # block_loc 
require different value-assigning method for two different stage + if infer_state.is_context_stage: + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( @@ -663,7 +711,7 @@ def llama_model_forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None - + infer_state.decode_layer_id = 0 for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -671,17 +719,7 @@ def llama_model_forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids - ) + raise NotImplementedError("not implement gradient_checkpointing and training options ") else: layer_outputs = decoder_layer( hidden_states, @@ -692,9 +730,11 @@ def custom_forward(*inputs): output_attentions=output_attentions, use_cache=use_cache, padding_mask=padding_mask, + infer_state=infer_state, ) hidden_states = layer_outputs[0] + infer_state.decode_layer_id += 1 if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) @@ -708,6 +748,10 @@ def custom_forward(*inputs): if output_hidden_states: all_hidden_states += (hidden_states,) + infer_state.is_context_stage = False + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + next_cache = next_decoder_cache if use_cache else None if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -719,204 +763,102 @@ def custom_forward(*inputs): ) -def convert_llama_to_smoothquant( - llama_model, - tokenizer, - dataset_path, - num_samples=512, - seq_len=512, -): - llama_config = llama_model.config - - llama_model.eval() - device = 
next(llama_model.parameters()).device - # print("model:", llama_model) - act_dict = defaultdict(dict) - - def stat_io_hook(m, x, y, name): - if isinstance(x, tuple): - x = x[0] - if name not in act_dict or "input" not in act_dict[name]: - act_dict[name]["input"] = x.detach().abs().max().item() - else: - act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item()) - if isinstance(y, tuple): - y = y[0] - if name not in act_dict or "output" not in act_dict[name]: - act_dict[name]["output"] = y.detach().abs().max().item() - else: - act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item()) - - for name, m in llama_model.named_modules(): - if isinstance(m, LlamaAttention): - setattr(m, "q_apply_rotary", LlamaApplyRotary()) - setattr(m, "k_apply_rotary", LlamaApplyRotary()) - m.forward = types.MethodType(llama_decoder_layer_forward, m) - - hooks = [] - for name, m in llama_model.named_modules(): - if isinstance(m, LlamaApplyRotary): - hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) - if isinstance(m, torch.nn.Linear): - hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) - - print("Collecting activation scales...") - pbar = tqdm(range(num_samples)) - dataset = load_dataset("json", data_files=dataset_path, split="train") - dataset = dataset.shuffle(seed=42) - for i in pbar: - input_ids = tokenizer( - dataset["rows"][0][i]["row"]["text"], - return_tensors="pt", - max_length=seq_len, - truncation=True, - ).input_ids.to(device) - llama_model(input_ids) - mean_scale = np.mean([v["input"] for v in act_dict.values()]) - pbar.set_description(f"Mean input scale: {mean_scale:.2f}") - for hook in hooks: - hook.remove() - - decoder_layer_scales = [] - - for idx in range(llama_config.num_hidden_layers): - scale_dict = {} - scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127 - scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127 - scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127 - scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127 - - scale_dict["q_rotary_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127 - - scale_dict["k_rotary_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127 - - scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127 - # mlp scales - scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127 - scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127 - scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 - - decoder_layer_scales.append(scale_dict) - - for i, layer in enumerate(llama_model.model.layers): - orig_layer = layer - llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i]) - - llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model) - - cos, sin = init_to_get_rotary(llama_config) - llama_model.model.register_buffer("_cos_cached", cos) - llama_model.model.register_buffer("_sin_cached", sin) - return decoder_layer_scales, act_dict - - -# class SmoothLlamaForCausalLM(BaseSmoothForCausalLM): -# def __init__(self, model: PreTrainedModel, quantized: bool = False): -# super().__init__(model, quantized) - -# def 
quantized( -# self, -# tokenizer, -# dataset_path, -# num_samples=512, -# seq_len=512, -# ): -# llama_model = self.model -# llama_config = llama_model.config - -# llama_model.eval() -# device = next(llama_model.parameters()).device -# # print("model:", llama_model) -# act_dict = defaultdict(dict) - -# def stat_io_hook(m, x, y, name): -# if isinstance(x, tuple): -# x = x[0] -# if name not in act_dict or "input" not in act_dict[name]: -# act_dict[name]["input"] = x.detach().abs().max().item() -# else: -# act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item()) -# if isinstance(y, tuple): -# y = y[0] -# if name not in act_dict or "output" not in act_dict[name]: -# act_dict[name]["output"] = y.detach().abs().max().item() -# else: -# act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item()) - -# for name, m in llama_model.named_modules(): -# if isinstance(m, LlamaAttention): -# setattr(m, "q_apply_rotary", LlamaApplyRotary()) -# setattr(m, "k_apply_rotary", LlamaApplyRotary()) -# m.forward = types.MethodType(llama_decoder_layer_forward, m) - -# hooks = [] -# for name, m in llama_model.named_modules(): -# if isinstance(m, LlamaApplyRotary): -# hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) -# if isinstance(m, torch.nn.Linear): -# hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) - -# print("Collecting activation scales...") -# pbar = tqdm(range(num_samples)) -# dataset = load_dataset("json", data_files=dataset_path, split="train") -# dataset = dataset.shuffle(seed=42) -# for i in pbar: -# input_ids = tokenizer( -# dataset["rows"][0][i]["row"]["text"], -# return_tensors="pt", -# max_length=seq_len, -# truncation=True, -# ).input_ids.to(device) -# llama_model(input_ids) -# mean_scale = np.mean([v["input"] for v in act_dict.values()]) -# pbar.set_description(f"Mean input scale: {mean_scale:.2f}") -# for hook in hooks: -# hook.remove() - -# decoder_layer_scales = [] - -# for idx in range(llama_config.num_hidden_layers): -# scale_dict = {} -# scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127 -# scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127 -# scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127 -# scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127 - -# scale_dict["q_rotary_output_scale"] = ( -# act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127 -# ) - -# scale_dict["k_rotary_output_scale"] = ( -# act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127 -# ) - -# scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127 -# # mlp scales -# scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127 -# scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127 -# scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 - -# decoder_layer_scales.append(scale_dict) - -# for i, layer in enumerate(llama_model.model.layers): -# orig_layer = layer -# llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i]) - -# llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model) - -# cos, sin = init_to_get_rotary(llama_config) -# llama_model.model.register_buffer("_cos_cached", cos) -# 
llama_model.model.register_buffer("_sin_cached", sin) - -# def make_smooth_model(cls, llama_model): -# super().make_smooth_model() - -# llama_config = llama_model.config - -# for i, layer in enumerate(llama_model.model.layers): -# llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer(llama_config) - -# llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model) -# cos, sin = init_to_get_rotary(llama_config) -# llama_model.model.register_buffer("_cos_cached", cos) -# llama_model.model.register_buffer("_sin_cached", sin) +class SmoothLlamaForCausalLM(BaseSmoothForCausalLM): + layer_type = "LlamaDecoderLayer" + + def __init__(self, model: PreTrainedModel, quantized: bool = False): + super().__init__(model, quantized) + + def quantized( + self, + tokenizer, + dataset_path, + num_samples=512, + seq_len=512, + ): + llama_model = self.model + llama_config = llama_model.config + + llama_model.eval() + device = next(llama_model.parameters()).device + # print("model:", llama_model) + act_dict = defaultdict(dict) + + def stat_io_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + if name not in act_dict or "input" not in act_dict[name]: + act_dict[name]["input"] = x.detach().abs().max().item() + else: + act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item()) + if isinstance(y, tuple): + y = y[0] + if name not in act_dict or "output" not in act_dict[name]: + act_dict[name]["output"] = y.detach().abs().max().item() + else: + act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item()) + + for name, m in llama_model.named_modules(): + if isinstance(m, LlamaAttention): + setattr(m, "q_apply_rotary", LlamaApplyRotary()) + setattr(m, "k_apply_rotary", LlamaApplyRotary()) + m.forward = types.MethodType(llama_decoder_layer_forward, m) + + hooks = [] + for name, m in llama_model.named_modules(): + if isinstance(m, LlamaApplyRotary): + hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) + if isinstance(m, torch.nn.Linear): + hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) + + print("Collecting activation scales...") + pbar = tqdm(range(num_samples)) + dataset = load_dataset("json", data_files=dataset_path, split="train") + dataset = dataset.shuffle(seed=42) + for i in pbar: + input_ids = tokenizer( + dataset["rows"][0][i]["row"]["text"], + return_tensors="pt", + max_length=seq_len, + truncation=True, + ).input_ids.to(device) + llama_model(input_ids) + mean_scale = np.mean([v["input"] for v in act_dict.values()]) + pbar.set_description(f"Mean input scale: {mean_scale:.2f}") + for hook in hooks: + hook.remove() + + decoder_layer_scales = [] + + for idx in range(llama_config.num_hidden_layers): + scale_dict = {} + scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127 + scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127 + scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127 + scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127 + + scale_dict["q_rotary_output_scale"] = ( + act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127 + ) + + scale_dict["k_rotary_output_scale"] = ( + act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127 + ) + + scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127 + # mlp scales + 
scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127 + scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127 + scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 + + decoder_layer_scales.append(scale_dict) + + for i, layer in enumerate(llama_model.model.layers): + orig_layer = layer + llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i]) + + llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model) + + cos, sin = init_to_get_rotary(llama_config) + llama_model.model.register_buffer("_cos_cached", cos.to(self.model.device)) + llama_model.model.register_buffer("_sin_cached", sin.to(self.model.device)) From ff43f5a63639194c2d4dae3f6d90cf50d2c71783 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 11 Oct 2023 18:00:38 +0800 Subject: [PATCH 5/7] refactor example --- examples/inference/smoothquant_conversion.py | 99 +++----------------- 1 file changed, 15 insertions(+), 84 deletions(-) diff --git a/examples/inference/smoothquant_conversion.py b/examples/inference/smoothquant_conversion.py index 066e9ea1abae..96f6e3730ebf 100644 --- a/examples/inference/smoothquant_conversion.py +++ b/examples/inference/smoothquant_conversion.py @@ -2,58 +2,31 @@ import os import torch -from transformers import AutoModelForCausalLM, LlamaTokenizer +from transformers import LlamaTokenizer -from colossalai.inference.quant.smoothquant.models.llama import convert_llama_to_smoothquant +from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM def build_model_and_tokenizer(model_name): tokenizer = LlamaTokenizer.from_pretrained(model_name, model_max_length=512) kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"} - model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) + model = SmoothLlamaForCausalLM.from_pretrained(model_name, **kwargs) model = model.to(torch.float32) - # config = { - # "architectures": ["LLaMAForCausalLM"], - # "bos_token_id": 0, - # "eos_token_id": 1, - # "hidden_act": "silu", - # "hidden_size": 4096, - # "initializer_range": 0.02, - # "intermediate_size": 11008, - # "max_position_embeddings": 2048, - # "max_sequence_length": 2048, - # "model_type": "llama", - # "num_attention_heads": 32, - # "num_hidden_layers": 2, - # "num_key_value_heads": 32, - # "pad_token_id": -1, - # "pretraining_tp": 1, - # "rms_norm_eps": 1e-06, - # "torch_dtype": "float32", - # "transformers_version": "4.32.1", - # "use_cache": True, - # "vocab_size": 32000, - # } - # config = LlamaConfig(**config) - # model = LlamaForCausalLM(config) - return model, tokenizer def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--model-name", type=str, default="facebook/opt-1.3b", help="model name") + parser.add_argument("--model-name", type=str, help="model name") parser.add_argument( "--output-path", type=str, - default="act_scales/opt-1.3b.pt", - help="where to save the act scales", + help="where to save the checkpoint", ) parser.add_argument( "--dataset-path", type=str, - default="dataset/val.jsonl.zst", - help="location of the calibration dataset, we use the validation set of the Pile dataset", + help="location of the calibration dataset", ) parser.add_argument("--num-samples", type=int, default=512) parser.add_argument("--seq-len", type=int, default=512) @@ -64,71 +37,29 @@ def parse_args(): @torch.no_grad() def main(): args = parse_args() - model_path = 
"/home/lcxk/data3/llama-7b-hf" - dataset_path = "/home/lcxk/data3/datasets/cc_news.json" + model_path = args.model_name + dataset_path = args.dataset_path + output_path = args.output_path num_samples = 10 seq_len = 512 - print("data path", dataset_path) - data_files = {"train": dataset_path} - - # # dataset = load_dataset("/home/lcxk/data3/datasets/cc_news.py", data_files=dataset_path) - # dataset = load_dataset("json", data_files=dataset_path) - - # print("text", dataset["train"]["rows"][0][1``]["row"]["text"]) - # # for test in dataset["train"]["rows"]: - # # print(test) - # dataset = dataset.shuffle(seed=42) - # # print("text", dataset["rows"][0]) model, tokenizer = build_model_and_tokenizer(model_path) - print("config:", model.config) if not os.path.exists(dataset_path): print(f"Cannot find the dataset at {args.dataset_path}") - print("Please download the Pile dataset and put the validation set at the path") - print( - "You can download the validation dataset of the Pile at https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst" - ) - raise FileNotFoundError - - # act_scales = get_act_scales(model, tokenizer, dataset_path, num_samples, seq_len) - - # os.makedirs(os.path.dirname(output_path), exist_ok=True) - # torch.save(act_scales, output_path) - - # act_scales = torch.load(output_path) - # smooth_lm(model, act_scales, 0.5) - # # tokenizer = AutoTokenizer.from_pretrained(model_path) - - if not os.path.exists(dataset_path): - print(f"Cannot find the dataset at (dataset_path)") - print("Please download the Pile dataset and put the validation set at the path") - print( - "You can download the validation dataset of the Pile at https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst" - ) raise FileNotFoundError - decoder_layer_scales, raw_scales = convert_llama_to_smoothquant( - model, tokenizer, dataset_path, num_samples=num_samples, seq_len=seq_len - ) model = model.cuda() - generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True) + model.quantized(tokenizer, dataset_path, num_samples=num_samples, seq_len=seq_len) - print("decoder layer scales:", decoder_layer_scales) - # output_path = Path(args.output_path) / (Path(args.model_name).name + "-smoothquant.pt") - input_tokens = tokenizer(["New York City"], return_tensors="pt").to(model.device) + model.save_quantized(output_path, model_basename="llama-7b") - max_batch_size = 1 - max_input_len = 7 - input_tokens = { - "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), - "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), - } + model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b") - print("input:", input) - # gen_args = {"max_input_"} + generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True) + input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda") out = model.generate(**input_tokens, **generate_kwargs) text = tokenizer.batch_decode(out) - print("text:", text) + print("out is:", text) if __name__ == "__main__": From 7c4d0efeb4e42ccf03fac82743431c3ff2a8714a Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 12 Oct 2023 09:06:27 +0800 Subject: [PATCH 6/7] delete smooth --- .../inference/quant/smoothquant/smooth.py | 52 ------------------- 1 file changed, 52 deletions(-) delete mode 100644 colossalai/inference/quant/smoothquant/smooth.py diff --git a/colossalai/inference/quant/smoothquant/smooth.py b/colossalai/inference/quant/smoothquant/smooth.py deleted file mode 100644 index 120e7818a2f4..000000000000 
--- a/colossalai/inference/quant/smoothquant/smooth.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch -import torch.nn as nn -from transformers.models.bloom.modeling_bloom import BloomBlock -from transformers.models.opt.modeling_opt import OPTDecoderLayer - - -@torch.no_grad() -def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5): - if not isinstance(fcs, list): - fcs = [fcs] - assert isinstance(ln, nn.LayerNorm) - for fc in fcs: - assert isinstance(fc, nn.Linear) - assert ln.weight.numel() == fc.in_features == act_scales.numel() - - device, dtype = fcs[0].weight.device, fcs[0].weight.dtype - act_scales = act_scales.to(device=device, dtype=dtype) - weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0) - weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5) - - scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype) - - ln.weight.div_(scales) - ln.bias.div_(scales) - - for fc in fcs: - fc.weight.mul_(scales.view(1, -1)) - - -@torch.no_grad() -def smooth_lm(model, scales, alpha=0.5): - for name, module in model.named_modules(): - if isinstance(module, OPTDecoderLayer): - attn_ln = module.self_attn_layer_norm - qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj] - qkv_input_scales = scales[name + ".self_attn.q_proj"] - smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha) - - ffn_ln = module.final_layer_norm - fc1 = module.fc1 - fc1_input_scales = scales[name + ".fc1"] - smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha) - elif isinstance(module, BloomBlock): - attn_ln = module.input_layernorm - qkv = module.self_attention.query_key_value - qkv_input_scales = scales[name + ".self_attention.query_key_value"] - smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha) - - ffn_ln = module.post_attention_layernorm - fc1 = module.mlp.dense_h_to_4h - fc1_input_scales = scales[name + ".mlp.dense_h_to_4h"] - smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha) From e7ad57a052203bbeed2ed41b71e516a64d063f53 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 12 Oct 2023 09:59:54 +0800 Subject: [PATCH 7/7] refactor code --- .../inference/quant/smoothquant/models/llama.py | 17 ++--------------- ...quant_conversion.py => smoothquant_llama.py} | 0 2 files changed, 2 insertions(+), 15 deletions(-) rename examples/inference/{smoothquant_conversion.py => smoothquant_llama.py} (100%) diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py index 9bfacd1c6a3b..b201347825b2 100644 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -245,10 +245,7 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, None, None class LlamaLayerNormQ(torch.nn.Module): @@ -259,8 +256,6 @@ def __init__(self, dim, eps=1e-5): self.register_buffer("weight", torch.ones(dim, dtype=torch.float32)) def forward(self, x): - x.dtype - x.to(torch.float32) ln_output_fp = torch.nn.functional.layer_norm(x, x.shape[-1:], self.weight, None, self.variance_epsilon) ln_output_int8 = ln_output_fp.round().clamp(-128, 127).to(torch.int8) return ln_output_int8 @@ -424,15 +419,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = 
residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs + return hidden_states, None, None class LlamaApplyRotary(nn.Module): diff --git a/examples/inference/smoothquant_conversion.py b/examples/inference/smoothquant_llama.py similarity index 100% rename from examples/inference/smoothquant_conversion.py rename to examples/inference/smoothquant_llama.py
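For reference, a condensed sketch of the workflow the renamed example script (examples/inference/smoothquant_llama.py) exercises — quantize, save, reload, generate. The checkpoint, calibration-file, and output paths are placeholders to substitute:

import torch
from transformers import LlamaTokenizer

from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM

model_name = "/path/to/llama-7b-hf"          # placeholder checkpoint path
dataset_path = "/path/to/calibration.json"   # placeholder calibration data
output_path = "/path/to/smoothquant-llama"   # placeholder save directory

tokenizer = LlamaTokenizer.from_pretrained(model_name, model_max_length=512)
model = SmoothLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model = model.to(torch.float32).cuda()

# Collect activation scales and swap in int8 SmoothQuant decoder layers.
model.quantized(tokenizer, dataset_path, num_samples=10, seq_len=512)
model.save_quantized(output_path, model_basename="llama-7b")

# Reload the int8 checkpoint and generate through the fused Triton kernels.
model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b")
input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda")
out = model.generate(**input_tokens, max_new_tokens=16, do_sample=False, use_cache=True)
print(tokenizer.batch_decode(out))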