2 changes: 0 additions & 2 deletions colossalai/inference/tensor_parallel/modeling/__init__.py
@@ -1,5 +1,3 @@
import _utils

from .bloom import BloomInferenceForwards
from .chatglm2 import ChatGLM2InferenceForwards
from .llama import LlamaInferenceForwards
59 changes: 58 additions & 1 deletion colossalai/inference/tensor_parallel/modeling/_utils.py
@@ -1,10 +1,67 @@
"""
Utils for model inference
"""
import os

import torch

from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest


def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
"""
This function copies the key and value cache to the memory cache
Args:
layer_id : id of current layer
key_buffer : key cache
value_buffer : value cache
context_mem_index : index of memory cache in kv cache manager
mem_manager : cache manager
"""
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
return


def init_to_get_rotary(self, base=10000, use_elem=False):
"""
This function initializes the rotary positional embedding. It is compatible with all models and is called in ShardFormer.
Args:
self : model that holds the rotary positional embedding
base : base used to compute the inverse rotary frequencies
use_elem : set to True for chatglm-based models, which apply rotary embedding to only half of each head's dimensions
"""
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
if not hasattr(self.config, "rope_scaling"):
rope_scaling_factor = 1.0
else:
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0

if hasattr(self.config, "max_sequence_length"):
max_seq_len = self.config.max_sequence_length
elif hasattr(self.config, "max_position_embeddings"):
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
else:
max_seq_len = 2048 * rope_scaling_factor
base = float(base)

# NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)

if ntk_alpha is not None:
ntk_alpha = float(ntk_alpha)
assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1"
if ntk_alpha > 1:
print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
max_seq_len *= ntk_alpha
base = base * (ntk_alpha ** (self.config.head_dim_ / (self.config.head_dim_ - 2)))  # Base change formula

n_elem = self.config.head_dim_
if use_elem:
n_elem //= 2

inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
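Note on the function above: the NTK-aware branch rescales the RoPE base before the inverse frequencies are computed, which stretches the low frequencies so longer contexts can be handled without retraining. A minimal standalone sketch of the same computation (argument names such as head_dim and max_seq_len are illustrative, not the exact config fields used by the PR):

import torch


def build_rotary_cache(head_dim, max_seq_len, base=10000.0, ntk_alpha=1.0, rope_scaling_factor=1.0):
    # NTK-aware scaling: enlarge the base (and the usable window) when alpha > 1.
    if ntk_alpha > 1:
        max_seq_len *= ntk_alpha
        base = base * (ntk_alpha ** (head_dim / (head_dim - 2)))

    # Standard RoPE inverse frequencies over the even dimensions of each head.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    t = torch.arange(int(max_seq_len), dtype=torch.float32) / rope_scaling_factor
    freqs = torch.outer(t, inv_freq)
    return torch.cos(freqs), torch.sin(freqs)


# cos/sin caches for a 128-dim head, a 4096-token window, and NTK alpha of 2
cos_cached, sin_cached = build_rotary_cache(128, 4096, ntk_alpha=2.0)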
24 changes: 5 additions & 19 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -5,12 +5,9 @@
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import (
copy_kv_cache_to_dest,
llama_context_attn_fwd,
rotary_embedding_fwd,
token_attention_fwd,
)
from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd

from ._utils import copy_kv_to_mem_cache

try:
from vllm import layernorm_ops, pos_encoding_ops
@@ -46,12 +43,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
return q_embed, k_embed


def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
return


class LlamaInferenceForwards:
"""
This class holds forwards for llama inference.
@@ -285,11 +276,6 @@ def llama_flash_attn_kvcache_forward(
rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)

def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
return

query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
@@ -298,7 +284,7 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
# first token generation

# copy key and value calculated in current step to memory manager
_copy_kv_to_mem_cache(
copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_states,
value_states,
@@ -331,7 +317,7 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
else:
# if decode is not contiguous, use triton kernel to copy key and value cache
# k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
_copy_kv_to_mem_cache(
copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_states,
value_states,
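As context for the two copy_kv_to_mem_cache call sites above: during prefill the keys/values of the whole prompt are written to the indices allocated by the cache manager, while during decoding the single-token keys/values either land in a contiguous tail slot (written in place) or, when the destination indices are scattered, are copied with the triton kernel wrapped by copy_kv_to_mem_cache. A minimal sketch of that branch, assuming a simplified cache-manager layout (the infer_state field names below are illustrative, not the exact BatchInferState API):

def write_decode_kv(layer_id, key_states, value_states, infer_state, mem_manager):
    if infer_state.decode_is_contiguous:
        # Destination slots form a contiguous range: write in place with a slice.
        start, end = infer_state.decode_mem_start, infer_state.decode_mem_end
        mem_manager.key_buffer[layer_id][start:end].copy_(key_states)
        mem_manager.value_buffer[layer_id][start:end].copy_(value_states)
    else:
        # Scattered destination indices: fall back to the triton copy kernel.
        copy_kv_to_mem_cache(
            layer_id, key_states, value_states, infer_state.decode_mem_index, mem_manager
        )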
39 changes: 18 additions & 21 deletions colossalai/inference/tensor_parallel/policies/chatglm2.py
@@ -1,29 +1,27 @@
from functools import partial

import torch

from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
ChatGLMModel,
GLMBlock,
GLMTransformer,
SelfAttention,
)

# import colossalai
from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy

from ..modeling.chatglm2 import ChatGLM2InferenceForwards, _init_to_get_rotary
from ..modeling._utils import init_to_get_rotary
from ..modeling.chatglm2 import ChatGLM2InferenceForwards

try:
from colossalai.kernel.triton.rms_norm import rmsnorm_forward
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
HAS_TRITON_RMSNORM = False


class ChatGLM2InferPolicy(ChatGLMModelPolicy):

def __init__(self) -> None:
super().__init__()

@@ -32,45 +30,44 @@ def module_policy(self):
self.shard_config._infer()

model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
method_replacement = {'forward': model_infer_forward}
method_replacement = {"forward": model_infer_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel)

encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
method_replacement = {'forward': encoder_infer_forward}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=GLMTransformer)
method_replacement = {"forward": encoder_infer_forward}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=GLMTransformer
)

encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
method_replacement = {'forward': encoder_layer_infer_forward}
method_replacement = {"forward": encoder_layer_infer_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock)

attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
method_replacement = {'forward': attn_infer_forward}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=SelfAttention)
method_replacement = {"forward": attn_infer_forward}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=SelfAttention
)

# for rmsnorm and others, we need to check the shape
return policy

def postprocess(self):
_init_to_get_rotary(self.model)
init_to_get_rotary(self.model)
return self.model


class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy):

def __init__(self) -> None:
super().__init__()

def module_policy(self):
policy = super().module_policy()
model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward
method_replacement = {'forward': partial(model_infer_forward)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=ChatGLMForConditionalGeneration)
method_replacement = {"forward": partial(model_infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=ChatGLMForConditionalGeneration
)
return policy

def postprocess(self):
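The policy changes above all use the same ShardFormer mechanism: build a method_replacement dict that maps "forward" to an inference-specific function, then register it for a target module class with append_or_create_method_replacement. Conceptually this amounts to rebinding forward on the sharded modules; the replace_forward helper below is a hypothetical stand-in shown only to make that mechanism concrete, not the real ShardFormer API:

import types


def replace_forward(module, new_forward):
    # Hypothetical helper: bind an inference-optimized forward onto a module instance,
    # mirroring what the policy's method replacement achieves at shard time.
    module.forward = types.MethodType(new_forward, module)


# Usage sketch (module path is illustrative):
# replace_forward(model.transformer.encoder, ChatGLM2InferenceForwards.chatglm_encoder_forward)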
25 changes: 15 additions & 10 deletions colossalai/inference/tensor_parallel/policies/llama.py
@@ -3,11 +3,12 @@
import torch
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

from colossalai.shardformer.layer import VocabParallelEmbedding1D
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription

# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

from ..modeling._utils import init_to_get_rotary
from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward

try:
@@ -50,38 +51,38 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 1},
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 1},
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 1},
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=RowCaiQuantLinear,
kwargs={'split_num': 1},
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 1},
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 1},
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=RowCaiQuantLinear,
kwargs={'split_num': 1},
)
kwargs={"split_num": 1},
),
],
)

@@ -117,3 +118,7 @@ def module_policy(self):
)

return policy

def postprocess(self):
init_to_get_rotary(self.model.model)
return self.model
10 changes: 6 additions & 4 deletions colossalai/kernel/triton/__init__.py
@@ -3,6 +3,12 @@

HAS_TRITON = True

except ImportError:
HAS_TRITON = False
print("Triton is not installed. Please install Triton to use Triton kernels.")

# There may still be an import error even if triton is installed.
if HAS_TRITON:
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm
@@ -23,7 +29,3 @@
"token_attention_fwd",
"gptq_fused_linear_triton",
]

except ImportError:
HAS_TRITON = False
print("Triton is not installed. Please install Triton to use Triton kernels.")
25 changes: 0 additions & 25 deletions examples/inference/bench_llama.py
@@ -15,30 +15,6 @@
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def init_to_get_rotary(self, base=10000):
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
if not hasattr(self.config, "rope_scaling"):
rope_scaling_factor = 1.0
else:
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
if hasattr(self.config, "max_sequence_length"):
max_seq_len = self.config.max_sequence_length
elif hasattr(self.config, "max_position_embeddings"):
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
else:
max_seq_len = 2048 * rope_scaling_factor
base = float(base)
inv_freq = 1.0 / (
base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)
)
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
return


def print_perf_stats(latency_set, config, bs, warmup=3):
# trim warmup queries
latency_set = list(latency_set)
@@ -66,7 +42,6 @@ def run_llama_test(args):
tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
tokenizer.pad_token_id = tokenizer.unk_token_id
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
init_to_get_rotary(model.model, base=10000)
model = model.half()

model_config = model.config