Merged: changes from all 31 commits
19 changes: 18 additions & 1 deletion colossalai/inference/README.md
@@ -4,7 +4,7 @@

## Introduction

`Colossal Inference` is a module that contains Colossal-AI's purpose-built inference framework, featuring high performance, stability and ease of use. It incorporates the strengths of the latest open-source inference systems, including TGI, vLLM, FasterTransformer, LightLLM and flash attention, while adopting the design of Colossal-AI, especially Shardformer, to reduce the learning curve for users.
`Colossal Inference` is a module that contains Colossal-AI's purpose-built inference framework, featuring high performance, stability and ease of use. It incorporates the strengths of the latest open-source inference systems, including LightLLM, TGI, vLLM, FasterTransformer and flash attention, while adopting the design of Colossal-AI, especially Shardformer, to reduce the learning curve for users.
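
To make the intended workflow concrete, here is a rough usage sketch of the tensor-parallel inference engine this module provides. It is only a sketch under assumptions: the `TPInferEngine` entry point, the `ShardConfig` flags, and the `generate` call below are inferred rather than confirmed by this diff, so treat every name as illustrative and check the examples referenced later in this README for the real API.

```python
# Hypothetical usage sketch; class names, constructor arguments and flags are assumptions.
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference import TPInferEngine  # assumed entry point of this module
from colossalai.shardformer import ShardConfig

model = LlamaForCausalLM.from_pretrained("path/to/llama", torch_dtype=torch.float16)
tokenizer = LlamaTokenizer.from_pretrained("path/to/llama")

# Shardformer shards the model across tensor-parallel ranks for inference.
shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
engine = TPInferEngine(model, shard_config, max_batch_size=4, max_input_len=1024, max_output_len=128)

inputs = tokenizer(["An example prompt"], return_tensors="pt")
outputs = engine.generate(inputs, do_sample=False, max_new_tokens=64)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```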

## Design

@@ -62,6 +62,12 @@ triton==2.0.0.dev20221202
vllm
# to install flash-attention, please use commit hash: 67ae6fd74b4bc99c36b2ce524cf139c35663793c
flash-attention

# install lightllm since we depend on its triton kernels
git clone https://github.com/ModelTC/lightllm
cd lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
pip3 install -e .
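
# optional sanity check (assumes both packages expose these module names)
python3 -c "import lightllm, flash_attn; print('lightllm and flash-attn import fine')"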
```

### Docker
@@ -73,6 +79,17 @@ You can use docker run to use docker container to set-up environment
docker pull hpcaitech/colossalai-inference:v2
docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash

# inside the docker container, install ColossalAI from source
cd /path/to/ColossalAI
pip install -e .

# install lightllm
git clone https://github.com/ModelTC/lightllm
cd lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
pip3 install -e .
```

### Dive into fast-inference!
3 changes: 2 additions & 1 deletion colossalai/inference/tensor_parallel/batch_infer_state.py
@@ -5,7 +5,7 @@

from .kvcache_manager import MemoryManager


# adapted from: lightllm/server/router/model_infer/infer_batch.py
@dataclass
class BatchInferState:
r"""
@@ -41,6 +41,7 @@ def total_token_num(self):
def set_cache_manager(self, manager: MemoryManager):
self.cache_manager = manager

# adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1
@staticmethod
def init_block_loc(
b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
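For reviewers unfamiliar with the lightllm helper referenced in the new comment, the following is a minimal sketch of what `init_block_loc` is expected to do, based on the linked `infer_utils.py`: it right-aligns each sequence's newly allocated KV-cache slot indices inside the per-sequence block-location table. This is a simplified reference, not the exact implementation.

```python
import torch

def init_block_loc_sketch(b_loc: torch.Tensor, seq_len: torch.Tensor,
                          max_len_in_batch: int, alloc_mem_index: torch.Tensor) -> None:
    """Fill b_loc so that row i holds sequence i's cache slots, right-aligned.

    b_loc:           (batch, max_len_in_batch) table of KV-cache slot indices
    seq_len:         (batch,) current token count of each sequence
    alloc_mem_index: flat tensor of allocated slot indices, one entry per token
    """
    start = 0
    for i, cur_len in enumerate(seq_len.tolist()):
        # the last `cur_len` columns of row i receive this sequence's slot indices
        b_loc[i, max_len_in_batch - cur_len : max_len_in_batch] = alloc_mem_index[start : start + cur_len]
        start += cur_len
```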
10 changes: 6 additions & 4 deletions colossalai/inference/tensor_parallel/kvcache_manager.py
@@ -1,7 +1,9 @@
# Adapted from lightllm/common/mem_manager.py
# of the ModelTC/lightllm GitHub repository
# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py

"""
Referred/Modified from lightllm/common/mem_manager.py
of the ModelTC/lightllm GitHub repository
https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design.
"""
import torch
from transformers.utils import logging

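As background for this change, here is a minimal sketch of the kind of KV-cache memory manager the referenced lightllm file provides: preallocated key/value buffers per layer plus a free-slot bitmap, with `alloc`/`free` handing out and reclaiming flat slot indices. Names and shapes below are illustrative assumptions, not the exact class adapted in this PR.

```python
import torch

class KVCacheMemoryManagerSketch:
    """Illustrative KV-cache slot allocator (simplified, single-device)."""

    def __init__(self, size: int, dtype: torch.dtype, head_num: int, head_dim: int, layer_num: int):
        # one flat pool of `size` token slots shared by all sequences
        self.mem_state = torch.ones(size, dtype=torch.bool)  # True = free
        self.key_buffer = [torch.empty(size, head_num, head_dim, dtype=dtype) for _ in range(layer_num)]
        self.value_buffer = [torch.empty(size, head_num, head_dim, dtype=dtype) for _ in range(layer_num)]

    def alloc(self, need_size: int) -> torch.Tensor:
        free_idx = torch.nonzero(self.mem_state).flatten()
        if free_idx.numel() < need_size:
            raise RuntimeError(f"KV cache exhausted: need {need_size}, have {free_idx.numel()}")
        chosen = free_idx[:need_size]
        self.mem_state[chosen] = False
        return chosen

    def free(self, indices: torch.Tensor) -> None:
        # return slots to the pool once a sequence finishes
        self.mem_state[indices] = True
```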
18 changes: 12 additions & 6 deletions colossalai/inference/tensor_parallel/modeling/chatglm2.py
@@ -6,8 +6,6 @@
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd
from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
@@ -20,6 +18,14 @@

from ._utils import copy_kv_to_mem_cache

try:
    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd
    from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd
    HAS_LIGHTLLM_KERNEL = True
except:
    print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
    HAS_LIGHTLLM_KERNEL = False


# This func is the same as the Llama model's init_to_get_rotary; we should move both into _utils.py
def _init_to_get_rotary(self, base=10000):
@@ -433,17 +439,17 @@ def chatglm_flash_attn_kvcache_forward(

cos, sin = infer_state.position_cos, infer_state.position_sin

Llama2Forwards.rotary_emb_fwd(
chatglm2_rotary_emb_fwd(
query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin
)
if self.multi_query_attention:
Llama2Forwards.rotary_emb_fwd(
chatglm2_rotary_emb_fwd(
key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head),
cos,
sin,
)
else:
Llama2Forwards.rotary_emb_fwd(
chatglm2_rotary_emb_fwd(
key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
cos,
sin,
@@ -474,7 +480,7 @@ def chatglm_flash_attn_kvcache_forward(
attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))

# NOTE: the context attn fwd path has no known bug (this note can be removed)
llama2_context_attn_fwd(
lightllm_llama2_context_attention_fwd(
query_layer,
key_layer,
value_layer,
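Since this file swaps the Llama2 rotary kernel for lightllm's ChatGLM2-specific `rotary_emb_fwd`, a plain-PyTorch reference of what a rotary position embedding applies may help review. This is only a sketch: it uses the split-halves pairing convention, while the actual lightllm kernel may pair dimensions differently (e.g. interleaved) and operates in place on the GPU.

```python
import torch

def rotary_emb_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to x of shape (tokens, heads, head_dim).

    cos, sin: (tokens, head_dim // 2), already indexed by each token's position.
    """
    half = x.shape[-1] // 2
    x0, x1 = x[..., :half], x[..., half:]
    cos = cos.unsqueeze(1)  # broadcast over the head dimension
    sin = sin.unsqueeze(1)
    # rotate each (x0, x1) pair by the position-dependent angle
    return torch.cat((x0 * cos - x1 * sin, x1 * cos + x0 * sin), dim=-1)
```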
25 changes: 16 additions & 9 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -5,12 +5,7 @@
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import (
llama2_context_attn_fwd,
llama_context_attn_fwd,
rotary_embedding_fwd,
token_attention_fwd,
)
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards

from ._utils import copy_kv_to_mem_cache
@@ -29,6 +24,17 @@
)
HAS_VLLM_KERNERL = False

try:
    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
        context_attention_fwd as lightllm_llama2_context_attention_fwd,
    )
    from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd

    HAS_LIGHTLLM_KERNEL = True
except:
    print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
    HAS_LIGHTLLM_KERNEL = False


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@@ -280,8 +286,8 @@ def llama_flash_attn_kvcache_forward(
cos, sin = infer_state.position_cos, infer_state.position_sin
# print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )

rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)

query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
@@ -312,7 +318,7 @@ def llama_flash_attn_kvcache_forward(
infer_state.cache_manager.past_key_values_length,
)
else:
llama2_context_attn_fwd(
lightllm_llama2_context_attention_fwd(
query_states,
key_states,
value_states,
@@ -371,6 +377,7 @@ def llama_flash_attn_kvcache_forward(
infer_state.cache_manager.past_key_values_length,
infer_state.other_kv_index,
)

attn_output = attn_output.view(bsz, q_len, self.hidden_size)

attn_output = self.o_proj(attn_output)
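For reviewers who want to sanity-check the call sites above, here is a naive PyTorch reference of the varlen "context" (prompt prefill) attention that kernels such as `lightllm_llama2_context_attention_fwd` implement in Triton. It is a simplification under assumptions: sequences are packed back to back, query and key/value head counts are equal (the real Llama2 kernel also handles grouped KV heads), and no KV-cache writing or paging is shown.

```python
import torch

def naive_context_attention(q, k, v, o, b_start_loc, b_seq_len):
    """q, k, v, o: (total_tokens, num_heads, head_dim); sequences packed contiguously."""
    scale = q.shape[-1] ** -0.5
    for start, length in zip(b_start_loc.tolist(), b_seq_len.tolist()):
        qi = q[start : start + length]  # (L, H, D)
        ki = k[start : start + length]
        vi = v[start : start + length]
        scores = torch.einsum("qhd,khd->hqk", qi, ki) * scale
        # causal mask: each query token attends only to itself and earlier tokens
        causal = torch.triu(torch.ones(length, length, dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(causal, float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        o[start : start + length] = torch.einsum("hqk,khd->qhd", probs, vi)
```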
6 changes: 2 additions & 4 deletions colossalai/inference/tensor_parallel/policies/llama.py
@@ -12,8 +12,7 @@
from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward

try:
    from colossalai.kernel.triton import rmsnorm_forward

    from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
    HAS_TRITON_RMSNORM = True
except:
    print("you should install triton from https://github.com/openai/triton")
@@ -22,9 +21,8 @@

def get_triton_rmsnorm_forward():
    if HAS_TRITON_RMSNORM:

        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
            return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
            return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

        return _triton_rmsnorm_forward
    else:
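To clarify what the swapped-in lightllm Triton kernel computes, here is a plain-PyTorch RMSNorm reference matching the call signature used above, namely (hidden_states, weight, eps). It is a numerical reference only; the Triton kernel fuses this into a single GPU pass.

```python
import torch

def rmsnorm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # normalize each token vector by its root-mean-square, then apply the learned scale
    variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    normed = hidden_states.to(torch.float32) * torch.rsqrt(variance + eps)
    return (weight * normed).to(hidden_states.dtype)
```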
7 changes: 1 addition & 6 deletions colossalai/kernel/triton/__init__.py
@@ -9,26 +9,21 @@

# Import errors may still occur even if triton is installed.
if HAS_TRITON:
from .context_attention import bloom_context_attn_fwd, llama2_context_attn_fwd, llama_context_attn_fwd
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm
from .gptq_triton import gptq_fused_linear_triton
from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd
from .rms_norm import rmsnorm_forward
from .rotary_embedding_kernel import rotary_embedding_fwd
from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd
from .softmax import softmax
from .token_attention_kernel import token_attention_fwd

__all__ = [
"llama_context_attn_fwd",
"llama2_context_attn_fwd",
"bloom_context_attn_fwd",
"softmax",
"layer_norm",
"rmsnorm_forward",
"copy_kv_cache_to_dest",
"rotary_embedding_fwd",
"token_attention_fwd",
"gptq_fused_linear_triton",
"int8_rotary_embedding_fwd",