Merged
41 changes: 25 additions & 16 deletions colossalai/inference/tensor_parallel/engine.py
@@ -1,7 +1,6 @@
from typing import Any, Callable, List, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from transformers import BloomForCausalLM, LlamaForCausalLM
from transformers.generation import GenerationConfig
@@ -74,9 +73,14 @@ def __init__(
model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
)
self.layer_num = num_hidden_layers
self.multi_query_group_num = (
model.config.multi_query_group_num if hasattr(model.config, "multi_query_group_num") else 0
)

self.multi_query_group_num = 0

if hasattr(model.config, "multi_query_group_num"):
self.multi_query_group_num = model.config.multi_query_group_num

if hasattr(model.config, "num_key_value_heads"):
self.multi_query_group_num = model.config.num_key_value_heads

self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.cache_manager = None
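
> The two `hasattr` checks above fold both grouped-query-attention config conventions into one field: ChatGLM-style configs expose `multi_query_group_num`, while Llama-2-style configs expose `num_key_value_heads`, which wins when both are present. A minimal sketch of the same resolution logic, with a hypothetical `SimpleNamespace` standing in for `model.config`:

```python
from types import SimpleNamespace

def resolve_kv_group_num(config) -> int:
    # ChatGLM-style configs carry `multi_query_group_num`; Llama-2-style
    # configs carry `num_key_value_heads`. The later check wins, mirroring
    # the order of the hasattr branches in __init__ above.
    group_num = 0
    if hasattr(config, "multi_query_group_num"):
        group_num = config.multi_query_group_num
    if hasattr(config, "num_key_value_heads"):
        group_num = config.num_key_value_heads
    return group_num

llama2_cfg = SimpleNamespace(num_key_value_heads=8)  # hypothetical stand-in
assert resolve_kv_group_num(llama2_cfg) == 8
assert resolve_kv_group_num(SimpleNamespace()) == 0  # plain MHA model
```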
@@ -97,6 +101,7 @@ def _init_manager(self) -> None:
assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
self.head_num //= self.tp_size # update sharded number of heads

if self.multi_query_group_num:
# NOTE the logic of MQA tensor parallelism should be specified.
assert (
@@ -116,22 +121,25 @@ def _init_manager(self) -> None:

def _post_init_gptq_buffer(self, model: nn.Module) -> None:
from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear

HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder

gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn('CUDA gptq is not installed')
warnings.warn("CUDA gptq is not installed")
HAS_GPTQ_CUDA = False

for name, submodule in model.named_modules():
if isinstance(submodule, CaiQuantLinear):
self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)

if self.use_act_order:
self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures,
submodule.outfeatures)
self.max_inner_outer_dim = max(
self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
)
self.bits = submodule.bits
if not (HAS_GPTQ_CUDA and self.bits == 4):
return
@@ -141,15 +149,16 @@ def _post_init_gptq_buffer(self, model: nn.Module) -> None:
max_input_len = self.max_input_len
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim),
dtype=torch.float16,
device=torch.cuda.current_device())
self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size),
dtype=torch.float16,
device=torch.cuda.current_device())

gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer,
self.gptq_temp_dq_buffer)
self.gptq_temp_state_buffer = torch.zeros(
(max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
)
self.gptq_temp_dq_buffer = torch.zeros(
(1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
)

gptq_cuda.prepare_buffers(
torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
)
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
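
> For context on the buffer sizing in `_post_init_gptq_buffer`: GPTQ packs eight 4-bit weights into each int32 entry of `qweight`, so an fp16 dequantization buffer needs `qweight.numel() * 8` elements, which is what `max_dq_buffer_size` tracks. A rough sketch under that assumption (the shape below is illustrative, not taken from the PR):

```python
import torch

def dq_buffer_elems(qweight: torch.Tensor) -> int:
    # Eight packed 4-bit values per int32 entry -> one fp16 element each
    # after dequantization.
    return qweight.numel() * 8

# Hypothetical 4-bit GPTQ linear: 1024 in-features packed along dim 0.
qweight = torch.zeros((1024 // 8, 4096), dtype=torch.int32)
assert dq_buffer_elems(qweight) == 1024 * 4096
```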
2 changes: 1 addition & 1 deletion colossalai/inference/tensor_parallel/modeling/_utils.py
@@ -45,7 +45,7 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
base = float(base)

# NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", None))
ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)

if ntk_alpha is not None:
ntk_alpha = float(ntk_alpha)
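
> The one-line change above fixes a crash: `float(None)` raises `TypeError` whenever `INFER_NTK_ALPHA` is unset, so the parse now happens only after the `None` check. For reference, a minimal sketch of the NTK-aware base adjustment the alpha feeds into, following the linked Reddit post (names here are illustrative, not the module's API):

```python
import os

def ntk_scaled_base(base: float, head_dim: int) -> float:
    ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
    if ntk_alpha is None:
        return base  # no scaling requested
    alpha = float(ntk_alpha)  # safe: only parsed when the variable is set
    # Stretch the RoPE base so low-frequency components extrapolate to
    # longer contexts: base' = base * alpha^(d / (d - 2)) for head dim d.
    return base * alpha ** (head_dim / (head_dim - 2))

print(ntk_scaled_base(10000.0, 128))
```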
82 changes: 56 additions & 26 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -5,7 +5,13 @@
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd
from colossalai.kernel.triton import (
llama2_context_attn_fwd,
llama_context_attn_fwd,
rotary_embedding_fwd,
token_attention_fwd,
)
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards

from ._utils import copy_kv_to_mem_cache

@@ -138,6 +144,7 @@ def llama_model_forward(
seq_len = infer_state.seq_len
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
infer_state.other_kv_index = infer_state.block_loc[0, seq_length_with_past - 1].item()

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@@ -261,8 +268,8 @@ def llama_flash_attn_kvcache_forward(
# key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]

query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)

# NOTE might want to revise
# need some way to record the length of past key values cache
@@ -274,11 +281,11 @@
# print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )

rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)

query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim)

if infer_state.is_context_stage:
# first token generation
@@ -294,15 +301,26 @@

attn_output = torch.empty_like(query_states)

llama_context_attn_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
)
if self.num_key_value_groups == 1:
llama_context_attn_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
)
else:
llama2_context_attn_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
)
else:
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
@@ -330,17 +348,29 @@
# (batch_size, seqlen, nheads, headdim)
attn_output = torch.empty_like(query_states)

token_attention_fwd(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
)

if self.num_key_value_groups == 1:
token_attention_fwd(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
)
else:
Llama2TokenAttentionForwards.token_attn(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
infer_state.other_kv_index,
)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)

attn_output = self.o_proj(attn_output)
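
> Both dispatch sites above branch on `self.num_key_value_groups`, which in HuggingFace's `LlamaAttention` is `num_heads // num_key_value_heads`: `1` means plain multi-head attention and keeps the original kernels, while `> 1` means Llama-2-style grouped-query attention and routes to the llama2 kernels. As a point of reference, a naive (non-Triton) GQA fallback simply broadcasts each KV head across the query heads that share it; a sketch:

```python
import torch

def repeat_kv(x: torch.Tensor, n_groups: int) -> torch.Tensor:
    # x: (tokens, num_key_value_heads, head_dim). Broadcast every KV head
    # to the n_groups query heads that attend to it.
    tokens, kv_heads, head_dim = x.shape
    return (
        x[:, :, None, :]
        .expand(tokens, kv_heads, n_groups, head_dim)
        .reshape(tokens, kv_heads * n_groups, head_dim)
    )

num_heads, num_kv_heads, head_dim = 32, 8, 128
groups = num_heads // num_kv_heads  # 4 query heads per KV head
k = torch.randn(16, num_kv_heads, head_dim)
assert repeat_kv(k, groups).shape == (16, num_heads, head_dim)
```

> The Triton kernels avoid materializing this expansion; the sketch only illustrates the head-count bookkeeping.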
3 changes: 2 additions & 1 deletion colossalai/kernel/triton/__init__.py
@@ -9,7 +9,7 @@

# Import errors may still occur even if triton is installed.
if HAS_TRITON:
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
from .context_attention import bloom_context_attn_fwd, llama2_context_attn_fwd, llama_context_attn_fwd
from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm
from .gptq_triton import gptq_fused_linear_triton
@@ -20,6 +20,7 @@

__all__ = [
"llama_context_attn_fwd",
"llama2_context_attn_fwd",
"bloom_context_attn_fwd",
"softmax",
"layer_norm",
69 changes: 69 additions & 0 deletions tests/test_infer/test_llama2_infer.py
@@ -0,0 +1,69 @@
import os

import pytest
import torch
from packaging import version
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
TPSIZE = 2
BATCH_SIZE = 8
MAX_INPUT_LEN = 12
MAX_OUTPUT_LEN = 100

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


@parameterize(
"test_config",
[
{
"tp_size": TPSIZE,
}
],
)
def run_llama_test(test_config):
llama_config = LlamaConfig(
num_hidden_layers=2, num_key_value_heads=8, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024
)
model = LlamaForCausalLM(llama_config)
model = model.half()

shard_config = ShardConfig(
enable_tensor_parallelism=test_config["tp_size"] > 1, inference_only=True
)
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)

input_tokens = {
"input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
"attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
}
outputs = infer_engine.generate(input_tokens, **generate_kwargs)

assert outputs is not None


def check_llama(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_llama_test()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
spawn(check_llama, TPSIZE)


if __name__ == "__main__":
test_llama()