17 changes: 9 additions & 8 deletions colossalai/inference/README.md
@@ -69,11 +69,11 @@ cd lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
pip3 install -e .

# also, install xformers from source:
pip install ninja
# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers

# install flash-attention
git clone --recursive https://github.com/Dao-AILab/flash-attention
cd flash-attention
pip install -e .
```

### Docker
@@ -95,10 +95,11 @@ cd lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
pip3 install -e .

# install xformers from source
pip install ninja
# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
# install flash-attention
git clone --recursive https://github.com/Dao-AILab/flash-attention
cd flash-attention
pip install -e .

```
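After the steps above, a quick import check can confirm that the source-built kernels are usable. This is only a minimal sketch, assuming the lightllm module path used by `llama.py` in this PR and the standard `flash_attn` package name:

```python
# Optional post-install sanity check (illustrative only, not part of the PR).
try:
    # Same module path as the imports in colossalai/inference/tensor_parallel/modeling/llama.py
    from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd
    print("lightllm llama triton kernel: OK")
except ImportError as exc:
    print(f"lightllm kernels unavailable: {exc}")

try:
    import flash_attn  # assumed package name for the flash-attention install above
    print("flash-attention: OK")
except ImportError as exc:
    print(f"flash-attention unavailable: {exc}")
```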

### Dive into fast-inference!
24 changes: 24 additions & 0 deletions colossalai/inference/build.sh
@@ -0,0 +1,24 @@
#!/usr/bin/env bash
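# Run this script from the directory where the local 3rdParty/ checkout should
# live; it installs triton and transformers from PyPI, then builds the pinned
# lightllm commit and flash-attention from source.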

# install triton
pip install triton
pip install transformers

# install lightllm and flash-attention
mkdir 3rdParty
cd 3rdParty
git clone https://github.com/ModelTC/lightllm
cd lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
pip install -e .
cd ..

git clone --recursive https://github.com/Dao-AILab/flash-attention
cd flash-attention
pip install -e .

cd ../../

46 changes: 15 additions & 31 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -8,15 +8,10 @@
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards

from ._utils import copy_kv_to_mem_cache

try:
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_llama2_context_attention_fwd,
)
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_context_attention_fwd,
context_attention_fwd as lightllm_llama_context_attention_fwd,
)
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd

@@ -56,32 +51,20 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
def llama_triton_context_attention(
query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1
):
if num_key_value_groups == 1:
if HAS_LIGHTLLM_KERNEL is False:
llama_context_attn_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
lightllm_context_attention_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
# if num_key_value_groups == 1:
if HAS_LIGHTLLM_KERNEL is False:
llama_context_attn_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run the llama2 model"
lightllm_llama2_context_attention_fwd(
lightllm_llama_context_attention_fwd(
query_states,
key_states,
value_states,
@@ -107,6 +90,7 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)

else:
Llama2TokenAttentionForwards.token_attn(
query_states,
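For context on the dispatch above: `llama_triton_context_attention` falls back to ColossalAI's built-in `llama_context_attn_fwd` when the lightllm kernels are missing, and uses the lightllm flash-attention kernel otherwise. The sketch below illustrates that guarded-import pattern under the assumption that `HAS_LIGHTLLM_KERNEL` is set by the try/except around the lightllm imports; the wrapper function itself is hypothetical, not the PR's exact code.

```python
# Minimal sketch of the guarded-import dispatch assumed above (illustrative;
# the context_attention wrapper is hypothetical, the call arguments mirror the diff).
from colossalai.kernel.triton import llama_context_attn_fwd

try:
    from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
        context_attention_fwd as lightllm_llama_context_attention_fwd,
    )

    HAS_LIGHTLLM_KERNEL = True
except ImportError:
    HAS_LIGHTLLM_KERNEL = False


def context_attention(query_states, key_states, value_states, attn_output, infer_state):
    """Prefer the lightllm context-attention kernel when it is installed."""
    kernel = (
        lightllm_llama_context_attention_fwd if HAS_LIGHTLLM_KERNEL else llama_context_attn_fwd
    )
    kernel(
        query_states,
        key_states,
        value_states,
        attn_output,
        infer_state.start_loc,
        infer_state.seq_len,
        infer_state.max_len_in_batch,
    )
```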