From 1c2f6fac52554f7cd28763fa51d32a72eefbcc58 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 12 Oct 2023 13:36:55 +0800 Subject: [PATCH 1/4] add smooth function and delete useless code --- .../quant/smoothquant/models/base_model.py | 125 ++++++++++++------ .../quant/smoothquant/models/llama.py | 49 ++++--- 2 files changed, 112 insertions(+), 62 deletions(-) diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py index 326c3df6e038..e4a2326f8d82 100644 --- a/colossalai/inference/quant/smoothquant/models/base_model.py +++ b/colossalai/inference/quant/smoothquant/models/base_model.py @@ -4,15 +4,19 @@ import types import warnings from abc import abstractmethod +from functools import partial from os.path import isdir, isfile, join from typing import Dict, List, Optional, Union import accelerate +import numpy as np import torch import torch.nn as nn import transformers +from datasets import load_dataset from safetensors.torch import save_file as safe_save from torch import device +from tqdm import tqdm from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel from transformers.modeling_utils import no_init_weights from transformers.utils.generic import ContextManagers @@ -22,53 +26,10 @@ from ....tensor_parallel.kvcache_manager import MemoryManager CPU = device("cpu") -CUDA_0 = device("cuda:0") SUPPORTED_MODELS = ["llama"] -def get_module_by_name_suffix(model, module_name: str): - for name, module in model.named_modules(): - if name.endswith(module_name): - return module - - -def simple_dispatch_model(model, device_map): - from accelerate.hooks import AlignDevicesHook, add_hook_to_module - - if "" in device_map: - d = device_map[""] - model = model.to(torch.device(d)) - model.hf_device_map = device_map - return model - - tied_params = accelerate.utils.modeling.find_tied_parameters(model) - if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}: - main_device = "cpu" - else: - main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0] - - cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"] - prev_hook = None - for idx, (n, d) in enumerate(cpu_offload_group): - m = get_module_by_name_suffix(model, n) - _, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook) - # set first cpu offload module's prev_module_hook to the last cpu offload module's hook - if len(cpu_offload_group) > 1: - get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook - - for n, d in device_map.items(): - m = get_module_by_name_suffix(model, n) - if d != "cpu": - d = torch.device(d) - hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True) - add_hook_to_module(m, hook) - accelerate.utils.modeling.retie_parameters(model, tied_params) - model.hf_device_map = device_map - - return model - - class BaseSmoothForCausalLM(nn.Module, PushToHubMixin): layer_type: str = None @@ -166,6 +127,84 @@ def prepare_inputs_for_generation(self, *args, **kwargs): """shortcut for model.prepare_inputs_for_generation""" return self.model.prepare_inputs_for_generation(*args, **kwargs) + def collect_act_scales(self, model, tokenizer, dataset_path, device, num_samples=512, seq_len=512): + dataset = load_dataset("json", data_files=dataset_path, split="train") + dataset = dataset.shuffle(seed=42) + + for i in tqdm(range(num_samples)): + input_ids = tokenizer( + dataset["rows"][0][i]["row"]["text"], return_tensors="pt", max_length=seq_len, truncation=True + ).input_ids.to(device) + model(input_ids) + + def collect_act_dict(self, model, tokenizer, dataset_path, act_dict, device, num_samples=512, seq_len=512): + pbar = tqdm(range(num_samples)) + dataset = load_dataset("json", data_files=dataset_path, split="train") + dataset = dataset.shuffle(seed=42) + for i in pbar: + input_ids = tokenizer( + dataset["rows"][0][i]["row"]["text"], + return_tensors="pt", + max_length=seq_len, + truncation=True, + ).input_ids.to(device) + model(input_ids) + mean_scale = np.mean([v["input"] for v in act_dict.values()]) + pbar.set_description(f"Mean input scale: {mean_scale:.2f}") + + def get_act_scales(self, model, tokenizer, dataset_path, num_samples=512, seq_len=512): + model.eval() + device = next(model.parameters()).device + act_scales = {} + + def stat_tensor(name, tensor): + hidden_dim = tensor.shape[-1] + tensor = tensor.view(-1, hidden_dim).abs().detach() + comming_max = torch.max(tensor, dim=0)[0].float().cpu() + if name in act_scales: + act_scales[name] = torch.max(act_scales[name], comming_max) + else: + act_scales[name] = comming_max + + def stat_input_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + stat_tensor(name, x) + + hooks = [] + for name, m in model.named_modules(): + if isinstance(m, nn.Linear): + hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name))) + + self.collect_act_scales(model, tokenizer, dataset_path, device, num_samples, seq_len) + + for h in hooks: + h.remove() + + return act_scales + + @torch.no_grad() + def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5): + if not isinstance(fcs, list): + fcs = [fcs] + for fc in fcs: + assert isinstance(fc, nn.Linear) + assert ln.weight.numel() == fc.in_features == act_scales.numel() + + device, dtype = fcs[0].weight.device, fcs[0].weight.dtype + act_scales = act_scales.to(device=device, dtype=dtype) + weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0) + weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5) + + scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype) + + ln.weight.div_(scales) + if hasattr(ln, "bias"): + ln.bias.div_(scales) + + for fc in fcs: + fc.weight.mul_(scales.view(1, -1)) + def save_quantized( self, save_dir: str, diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py index b201347825b2..bfd1f7339686 100644 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -7,14 +7,11 @@ from functools import partial from typing import List, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from datasets import load_dataset from torch import nn from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T -from tqdm import tqdm from transformers import PreTrainedModel from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.configuration_llama import LlamaConfig @@ -756,7 +753,7 @@ class SmoothLlamaForCausalLM(BaseSmoothForCausalLM): def __init__(self, model: PreTrainedModel, quantized: bool = False): super().__init__(model, quantized) - def quantized( + def get_act_dict( self, tokenizer, dataset_path, @@ -764,7 +761,7 @@ def quantized( seq_len=512, ): llama_model = self.model - llama_config = llama_model.config + llama_model.config llama_model.eval() device = next(llama_model.parameters()).device @@ -798,23 +795,37 @@ def stat_io_hook(m, x, y, name): if isinstance(m, torch.nn.Linear): hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) - print("Collecting activation scales...") - pbar = tqdm(range(num_samples)) - dataset = load_dataset("json", data_files=dataset_path, split="train") - dataset = dataset.shuffle(seed=42) - for i in pbar: - input_ids = tokenizer( - dataset["rows"][0][i]["row"]["text"], - return_tensors="pt", - max_length=seq_len, - truncation=True, - ).input_ids.to(device) - llama_model(input_ids) - mean_scale = np.mean([v["input"] for v in act_dict.values()]) - pbar.set_description(f"Mean input scale: {mean_scale:.2f}") + self.collect_act_dict(llama_model, tokenizer, dataset_path, act_dict, device, num_samples, seq_len) + for hook in hooks: hook.remove() + return act_dict + + def smooth_fn(self, scales, alpha=0.5): + model = self.model + for name, module in model.named_modules(): + if isinstance(module, LlamaDecoderLayer): + attn_ln = module.input_layernorm + qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj] + qkv_input_scales = scales[name + ".self_attn.q_proj"] + self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha) + + def quantized( + self, + tokenizer, + dataset_path, + num_samples=512, + seq_len=512, + alpha=0.5, + ): + llama_model = self.model + llama_config = llama_model.config + + act_scales = self.get_act_scales(llama_model, tokenizer, dataset_path, num_samples, seq_len) + + self.smooth_fn(act_scales, alpha) + act_dict = self.get_act_dict(tokenizer, dataset_path, num_samples, seq_len) decoder_layer_scales = [] for idx in range(llama_config.num_hidden_layers): From 96195e14836bfcd66e3e9950de5d7779ec332976 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 12 Oct 2023 15:02:56 +0800 Subject: [PATCH 2/4] update datasets --- .../quant/smoothquant/models/base_model.py | 35 ++++++------------- .../quant/smoothquant/models/llama.py | 11 +++--- 2 files changed, 16 insertions(+), 30 deletions(-) diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py index e4a2326f8d82..634c976cb18e 100644 --- a/colossalai/inference/quant/smoothquant/models/base_model.py +++ b/colossalai/inference/quant/smoothquant/models/base_model.py @@ -13,7 +13,6 @@ import torch import torch.nn as nn import transformers -from datasets import load_dataset from safetensors.torch import save_file as safe_save from torch import device from tqdm import tqdm @@ -22,8 +21,8 @@ from transformers.utils.generic import ContextManagers from transformers.utils.hub import PushToHubMixin, cached_file -from ....tensor_parallel.batch_infer_state import BatchInferState -from ....tensor_parallel.kvcache_manager import MemoryManager +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager CPU = device("cpu") @@ -127,32 +126,20 @@ def prepare_inputs_for_generation(self, *args, **kwargs): """shortcut for model.prepare_inputs_for_generation""" return self.model.prepare_inputs_for_generation(*args, **kwargs) - def collect_act_scales(self, model, tokenizer, dataset_path, device, num_samples=512, seq_len=512): - dataset = load_dataset("json", data_files=dataset_path, split="train") - dataset = dataset.shuffle(seed=42) - - for i in tqdm(range(num_samples)): - input_ids = tokenizer( - dataset["rows"][0][i]["row"]["text"], return_tensors="pt", max_length=seq_len, truncation=True - ).input_ids.to(device) + def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512): + for text in tqdm(dataset): + input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) model(input_ids) - def collect_act_dict(self, model, tokenizer, dataset_path, act_dict, device, num_samples=512, seq_len=512): - pbar = tqdm(range(num_samples)) - dataset = load_dataset("json", data_files=dataset_path, split="train") - dataset = dataset.shuffle(seed=42) - for i in pbar: - input_ids = tokenizer( - dataset["rows"][0][i]["row"]["text"], - return_tensors="pt", - max_length=seq_len, - truncation=True, - ).input_ids.to(device) + def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512): + pbar = tqdm(dataset) + for text in pbar: + input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) model(input_ids) mean_scale = np.mean([v["input"] for v in act_dict.values()]) pbar.set_description(f"Mean input scale: {mean_scale:.2f}") - def get_act_scales(self, model, tokenizer, dataset_path, num_samples=512, seq_len=512): + def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512): model.eval() device = next(model.parameters()).device act_scales = {} @@ -176,7 +163,7 @@ def stat_input_hook(m, x, y, name): if isinstance(m, nn.Linear): hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name))) - self.collect_act_scales(model, tokenizer, dataset_path, device, num_samples, seq_len) + self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len) for h in hooks: h.remove() diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py index bfd1f7339686..fc4ecf1e3ba4 100644 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -756,12 +756,11 @@ def __init__(self, model: PreTrainedModel, quantized: bool = False): def get_act_dict( self, tokenizer, - dataset_path, + dataset, num_samples=512, seq_len=512, ): llama_model = self.model - llama_model.config llama_model.eval() device = next(llama_model.parameters()).device @@ -795,7 +794,7 @@ def stat_io_hook(m, x, y, name): if isinstance(m, torch.nn.Linear): hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) - self.collect_act_dict(llama_model, tokenizer, dataset_path, act_dict, device, num_samples, seq_len) + self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len) for hook in hooks: hook.remove() @@ -813,7 +812,7 @@ def smooth_fn(self, scales, alpha=0.5): def quantized( self, tokenizer, - dataset_path, + dataset, num_samples=512, seq_len=512, alpha=0.5, @@ -821,11 +820,11 @@ def quantized( llama_model = self.model llama_config = llama_model.config - act_scales = self.get_act_scales(llama_model, tokenizer, dataset_path, num_samples, seq_len) + act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len) self.smooth_fn(act_scales, alpha) - act_dict = self.get_act_dict(tokenizer, dataset_path, num_samples, seq_len) + act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len) decoder_layer_scales = [] for idx in range(llama_config.num_hidden_layers): From 73876ca4bee30be7afabb0d2a6957ea222da0ec1 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 12 Oct 2023 15:09:56 +0800 Subject: [PATCH 3/4] remove duplicate import --- colossalai/inference/quant/smoothquant/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py index fc4ecf1e3ba4..014fb640e060 100644 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -10,7 +10,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch import nn from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T from transformers import PreTrainedModel from transformers.modeling_outputs import BaseModelOutputWithPast From 22fc3bc2cac78581eebbf14ebd4d2e109fb686c0 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 12 Oct 2023 15:38:23 +0800 Subject: [PATCH 4/4] delete useless file --- .../quant/smoothquant/calibration.py | 53 ------------------- .../quant/smoothquant/models/base_model.py | 3 +- 2 files changed, 1 insertion(+), 55 deletions(-) delete mode 100644 colossalai/inference/quant/smoothquant/calibration.py diff --git a/colossalai/inference/quant/smoothquant/calibration.py b/colossalai/inference/quant/smoothquant/calibration.py deleted file mode 100644 index 66ac49826592..000000000000 --- a/colossalai/inference/quant/smoothquant/calibration.py +++ /dev/null @@ -1,53 +0,0 @@ -# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ - -import functools - -import torch -import torch.nn as nn -from datasets import load_dataset -from tqdm import tqdm - - -def get_act_scales(model, tokenizer, dataset_path, num_samples=512, seq_len=512): - model.eval() - device = next(model.parameters()).device - act_scales = {} - - def stat_tensor(name, tensor): - hidden_dim = tensor.shape[-1] - tensor = tensor.view(-1, hidden_dim).abs().detach() - comming_max = torch.max(tensor, dim=0)[0].float().cpu() - if name in act_scales: - act_scales[name] = torch.max(act_scales[name], comming_max) - else: - act_scales[name] = comming_max - - def stat_input_hook(m, x, y, name): - if isinstance(x, tuple): - x = x[0] - stat_tensor(name, x) - - hooks = [] - for name, m in model.named_modules(): - if isinstance(m, nn.Linear): - hooks.append(m.register_forward_hook(functools.partial(stat_input_hook, name=name))) - - dataset = load_dataset("json", data_files=dataset_path) - - print("text", dataset["train"]["rows"][0][1]["row"]["text"]) - - dataset = dataset.shuffle(seed=42) - - for i in tqdm(range(num_samples)): - input_ids = tokenizer( - dataset["train"]["rows"][0][i]["row"]["text"], - return_tensors="pt", - max_length=seq_len, - truncation=True, - ).input_ids.to(device) - model(input_ids) - - for h in hooks: - h.remove() - - return act_scales diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py index 634c976cb18e..73cdbb39e53f 100644 --- a/colossalai/inference/quant/smoothquant/models/base_model.py +++ b/colossalai/inference/quant/smoothquant/models/base_model.py @@ -92,6 +92,7 @@ def init_batch_state(self, max_output_len=256, **kwargs): batch_infer_state.past_key_values_len = 0 batch_infer_state.is_context_stage = True batch_infer_state.set_cache_manager(self.cache_manager) + batch_infer_state.cache_manager.free_all() return batch_infer_state @abstractmethod @@ -117,8 +118,6 @@ def generate(self, **kwargs): if self.config.model_type == "llama": setattr(self.model.model, "infer_state", batch_infer_state) - batch_infer_state.is_context_stage = True - with torch.inference_mode(): return self.model.generate(**kwargs)