54 changes: 4 additions & 50 deletions colossalai/gptq/cai_gptq/cai_quant_linear.py
@@ -147,49 +147,6 @@ def pack(self, linear, scales, zeros, g_idx=None):
else:
self.g_idx = g_idx

def prepare_buffers(self):
assert self.qweight.device.type == "cuda"
device = self.qweight.device
if self.g_idx is not None:
if self.row_split and torch.equal(
self.g_idx,
torch.tensor(
[(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx = None
elif torch.equal(
self.g_idx,
torch.tensor([i // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx = None

CaiQuantLinear.max_dq_buffer_size = max(CaiQuantLinear.max_dq_buffer_size, self.qweight.numel() * 8)

if self.g_idx is not None:
CaiQuantLinear.max_inner_outer_dim = max(CaiQuantLinear.max_inner_outer_dim, self.infeatures,
self.outfeatures)
CaiQuantLinear.max_input_len = 4096
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
CaiQuantLinear.device_to_buffers['temp_state'] = torch.zeros(
(CaiQuantLinear.max_input_len, CaiQuantLinear.max_inner_outer_dim), dtype=torch.float16, device=device)
CaiQuantLinear.device_to_buffers['temp_dp'] = torch.zeros((1, CaiQuantLinear.max_dq_buffer_size),
dtype=torch.float16,
device=device)

gptq_cuda.prepare_buffers(torch.device(device), CaiQuantLinear.device_to_buffers['temp_state'],
CaiQuantLinear.device_to_buffers['temp_dp'])

# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

torch.cuda.empty_cache()

def init_q4(self):
assert self.qweight.device.type == "cuda"
self.q4_width = self.qweight.shape[1]
@@ -219,21 +176,18 @@ def init_q4(self):
def forward(self, x):
outshape = x.shape[:-1] + (self.outfeatures,)

if HAS_GPTQ_CUDA:
if CaiQuantLinear.prepared_buffers == False:
self.prepare_buffers()
CaiQuantLinear.prepared_buffers = True
if HAS_GPTQ_CUDA and self.bits == 4:

if self.q4 is None:
self.init_q4()

x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device)
gptq_cuda.q4_matmul(x, self.q4, output)
if (self.bias is not None and not self.row_split) or self.tp_size == 1:
gptq_cuda.q4_matmul(x.half(), self.q4, output)
if self.bias is not None and (not self.row_split or self.tp_size == 1):
output.add_(self.bias)
else:
if (self.bias is not None and not self.row_split) or self.tp_size == 1:
if self.bias is not None and (not self.row_split or self.tp_size == 1):
bias = self.bias
else:
bias = None
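The forward() change above does two things: the input is cast to fp16 before the 4-bit CUDA matmul, and the bias predicate is regrouped so a bias is only added when one actually exists. A minimal sketch of the corrected predicate (standalone function for illustration; the real code uses the inline condition shown in the diff):

```python
def should_add_bias(bias, row_split: bool, tp_size: int) -> bool:
    # Old grouping: (bias is not None and not row_split) or tp_size == 1
    # evaluates to True whenever tp_size == 1, even when bias is None, so
    # the CUDA branch could end up calling output.add_(None).
    return bias is not None and (not row_split or tp_size == 1)
```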
2 changes: 1 addition & 1 deletion colossalai/gptq/gptq_tp.py
@@ -95,7 +95,7 @@ def all_reduce_hook(cai_linear, input, output):
model_type_name = model.config.model_type

gptq_model_config = model_config_map[model_type_name]
layers = get_module_by_name_prefix(model.model, gptq_model_config.layer_blocks)
layers = get_module_by_name_prefix(model, gptq_model_config.layer_blocks)

for layer in layers:

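The single change above passes the full causal-LM wrapper instead of model.model, so the layer_blocks prefix from the model config is resolved against the top-level module. The real helper lives in colossalai.gptq; the sketch below is an assumed re-implementation showing what a dotted-prefix lookup like this typically does (the example prefix "model.layers" is likewise an assumption):

```python
import torch.nn as nn

def get_module_by_name_prefix_sketch(model: nn.Module, prefix: str) -> nn.Module:
    # Hypothetical re-implementation for illustration: resolve a dotted
    # prefix such as "model.layers" from the top-level module and return
    # the submodule it names (e.g. the ModuleList of decoder blocks).
    module = model
    for attr in prefix.split("."):
        module = getattr(module, attr)
    return module
```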
60 changes: 60 additions & 0 deletions colossalai/inference/tensor_parallel/engine.py
@@ -1,15 +1,28 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from transformers import BloomForCausalLM, LlamaForCausalLM
from transformers.generation import GenerationConfig
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.tokenization_utils_base import BatchEncoding

from colossalai.gptq.cai_gptq import CaiQuantLinear
from colossalai.gptq.gptq_tp import replace_autogptq_linear
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.auto_policy import get_autopolicy

HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder
gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn('CUDA gptq is not installed')
HAS_GPTQ_CUDA = False

from .batch_infer_state import BatchInferState
from .kvcache_manager import MemoryManager

@@ -66,6 +79,13 @@ def __init__(self,
self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.cache_manager = None

self.max_dq_buffer_size = 1
self.max_inner_outer_dim = 1
self.gptq_temp_state_buffer = None
self.gptq_temp_dq_buffer = None
self.bits = -1
self.use_act_order = False

self.shard_config = shard_config
self.model = None
# optimize the original model by sharding with ShardFormer
@@ -78,6 +98,41 @@ def _init_manager(self) -> None:
self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim,
self.layer_num)

def _post_init_gptq_buffer(self, model: nn.Module) -> None:

for name, submodule in model.named_modules():
if isinstance(submodule, CaiQuantLinear):
self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)

if self.use_act_order:
self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures,
submodule.outfeatures)
self.bits = submodule.bits
if not (HAS_GPTQ_CUDA and self.bits == 4):
return

max_input_len = 1
if self.use_act_order:
max_input_len = self.max_input_len
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim),
dtype=torch.float16,
device=torch.cuda.current_device())
self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size),
dtype=torch.float16,
device=torch.cuda.current_device())

gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer,
self.gptq_temp_dq_buffer)
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

torch.cuda.empty_cache()

def _optimize_model(self, model: nn.Module) -> None:
"""
Optimize the original model by sharding with ShardFormer.
@@ -124,6 +179,11 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
policy = get_autopolicy(model, inference_only=True)

if self.shard_config.inference_gptq:
tp_rank = dist.get_rank(self.shard_config.tensor_parallel_process_group)
replace_autogptq_linear(model, tp_size=self.tp_size, tp_rank=tp_rank)
self._post_init_gptq_buffer(model)
self.model, _ = shardformer.optimize(model, policy)
self.model = self.model.cuda()

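For context on the sizing logic added to the engine: a 4-bit GPTQ qweight packs eight quantized weights into each int32, which is where the qweight.numel() * 8 buffer size comes from, and act-order additionally needs a temp_state buffer wide enough for the largest in/out feature dimension. A standalone sketch of that pass (the function name and duck-typed qweight check are illustrative; the engine iterates CaiQuantLinear modules directly, as shown in the diff):

```python
import torch
import torch.nn as nn

def gptq_scratch_sizes(model: nn.Module, use_act_order: bool = False):
    # Illustrative restatement of the sizing pass in _post_init_gptq_buffer:
    # every int32 element of a 4-bit qweight packs 8 quantized values, so a
    # worst-case fp16 dequantization scratch buffer needs numel() * 8 entries.
    max_dq_buffer_size, max_inner_outer_dim = 1, 1
    for module in model.modules():
        qweight = getattr(module, "qweight", None)
        if qweight is None or qweight.dtype != torch.int32:
            continue
        max_dq_buffer_size = max(max_dq_buffer_size, qweight.numel() * 8)
        if use_act_order:
            # act-order also needs room to reorder activations up to the
            # largest in/out feature dimension of any quantized layer.
            max_inner_outer_dim = max(max_inner_outer_dim,
                                      module.infeatures, module.outfeatures)
    return max_dq_buffer_size, max_inner_outer_dim
```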
21 changes: 20 additions & 1 deletion colossalai/inference/tensor_parallel/policies/bloom.py
@@ -3,6 +3,9 @@
import torch
from torch.nn import LayerNorm

import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy

from ..modeling.bloom import BloomInferenceForwards
@@ -33,7 +36,23 @@ def __init__(self) -> None:

def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
policy = super().module_policy()
policy = {}
if not self.shard_config.inference_gptq:
policy = super().module_policy()
else:
policy[BloomModel] = ModulePolicyDescription(
attribute_replacement={
"num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
method_replacement={
"build_alibi_tensor": build_bloom_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group)
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
)
])
# NOTE set inference mode to shard config
self.shard_config._infer()

37 changes: 23 additions & 14 deletions colossalai/inference/tensor_parallel/policies/llama.py
@@ -1,14 +1,13 @@
from functools import partial

import torch
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaModel,
LlamaRMSNorm
)
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

from colossalai.shardformer.layer import VocabParallelEmbedding1D
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward

try:
@@ -18,23 +17,34 @@
print("you should install triton from https://github.com/openai/triton")
HAS_TRITON_RMSNORM = False


def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:

def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

return _triton_rmsnorm_forward
else:
return None



class LlamaModelInferPolicy(LlamaForCausalLMPolicy):

def __init__(self) -> None:
super().__init__()

def module_policy(self):
policy = super().module_policy()
policy = {}
if not self.shard_config.inference_gptq:
policy = super().module_policy()
else:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
),
policy=policy,
target_key=LlamaModel)
self.shard_config._infer()

infer_forward = LlamaInferenceForwards.llama_model_forward
@@ -59,12 +69,11 @@ def module_policy(self):
else:
# NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
infer_forward = get_llama_vllm_rmsnorm_forward()

if infer_forward is not None:
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaRMSNorm)
policy=policy,
target_key=LlamaRMSNorm)

return policy

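The policy above replaces LlamaRMSNorm.forward with a Triton kernel when available, falling back to the vLLM rmsnorm otherwise. As a reference for what either replacement computes, here is plain-PyTorch RMSNorm (weight and eps correspond to LlamaRMSNorm's weight and variance_epsilon; this is a readability sketch, not the fused kernel):

```python
import torch

def rmsnorm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Root-mean-square normalization over the last dimension, then a learned scale.
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normalized = hidden_states.to(torch.float32) * torch.rsqrt(variance + eps)
    return weight * normalized.to(weight.dtype)
```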
2 changes: 1 addition & 1 deletion colossalai/shardformer/shard/shard_config.py
@@ -33,9 +33,9 @@ class ShardConfig:
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
inference_only: bool = False
inference_gptq: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False

# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
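A hedged usage sketch for the new flag: since ShardConfig is a dataclass, enabling GPTQ inference should amount to setting inference_gptq alongside inference_only on the config handed to the TP inference engine (whether any additional fields are required is an assumption not covered by this diff):

```python
from colossalai.shardformer import ShardConfig

# Illustrative only: fields not shown in the diff keep their defaults.
shard_config = ShardConfig(
    inference_only=True,   # shardformer runs in inference mode
    inference_gptq=True,   # engine swaps in CaiQuantLinear and prepares GPTQ buffers
)
```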