4 changes: 1 addition & 3 deletions colossalai/inference/tensor_parallel/modeling/bloom.py
@@ -17,9 +17,7 @@
from transformers.utils import logging

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd
-from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
-from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd


def generate_alibi(n_head, dtype=torch.float16):
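Note: the consolidation above is behavior-preserving; the package-level name is a re-export of the same function object. A minimal sanity check, assuming a ColossalAI build whose Triton kernels import cleanly:

    # Both paths resolve to the same function object, so existing call
    # sites are unaffected by the import consolidation.
    from colossalai.kernel.triton import token_attention_fwd as via_package
    from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd as via_module

    assert via_package is via_module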
10 changes: 6 additions & 4 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -6,10 +6,12 @@
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
-from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
-from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
-from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.kernel.triton import (
+    copy_kv_cache_to_dest,
+    llama_context_attn_fwd,
+    rotary_embedding_fwd,
+    token_attention_fwd,
+)

try:
    from vllm import layernorm_ops, pos_encoding_ops
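The try block that begins above (its body is folded here) guards the optional vllm kernels that the llama forwards can fall back to. A self-contained sketch of that guard pattern; HAS_VLLM_KERNEL is an illustrative flag name, not taken from this diff:

    # Import accelerated ops when vllm is installed; otherwise record their
    # absence so callers can choose a pure-PyTorch path.
    try:
        from vllm import layernorm_ops, pos_encoding_ops
        HAS_VLLM_KERNEL = True
    except ImportError:
        HAS_VLLM_KERNEL = False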
4 changes: 2 additions & 2 deletions colossalai/inference/tensor_parallel/policies/bloom.py
@@ -8,10 +8,10 @@
from ..modeling.bloom import BloomInferenceForwards

try:
-    from colossalai.kernel.triton.fused_layernorm import layer_norm
+    from colossalai.kernel.triton import layer_norm
    HAS_TRITON_NORM = True
except:
-    print("you should install triton from https://github.com/openai/triton")
+    print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
    HAS_TRITON_NORM = False


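A reviewer-style aside on the hunk above: the bare except also swallows unrelated failures such as KeyboardInterrupt. A narrower sketch of the same guard that only reacts to a missing dependency:

    try:
        from colossalai.kernel.triton import layer_norm
        HAS_TRITON_NORM = True
    except ImportError:
        # Only a missing or broken triton install should disable the fast path.
        print("Some of our kernels require triton. You might want to install "
              "triton from https://github.com/openai/triton")
        HAS_TRITON_NORM = False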
26 changes: 12 additions & 14 deletions colossalai/inference/tensor_parallel/policies/llama.py
@@ -1,33 +1,32 @@
from functools import partial

import torch
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaModel,
-    LlamaRMSNorm
-)
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward

try:
-    from colossalai.kernel.triton.rms_norm import rmsnorm_forward
+    from colossalai.kernel.triton import rmsnorm_forward
    HAS_TRITON_RMSNORM = True
except:
    print("you should install triton from https://github.com/openai/triton")
    HAS_TRITON_RMSNORM = False


def get_triton_rmsnorm_forward():
    if HAS_TRITON_RMSNORM:

        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
            return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

        return _triton_rmsnorm_forward
    else:
        return None



class LlamaModelInferPolicy(LlamaForCausalLMPolicy):

    def __init__(self) -> None:
@@ -59,12 +58,11 @@ def module_policy(self):
        else:
            # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
            infer_forward = get_llama_vllm_rmsnorm_forward()
-
        if infer_forward is not None:
            method_replacement = {'forward': partial(infer_forward)}
            self.append_or_create_method_replacement(description=method_replacement,
-                                                     policy=policy,
-                                                     target_key=LlamaRMSNorm)
+                                                     policy=policy,
+                                                     target_key=LlamaRMSNorm)

        return policy
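For context, the method replacement above boils down to swapping LlamaRMSNorm.forward for a kernel-backed implementation at shard time. A simplified sketch that patches the class directly; the shardformer policy machinery is the real mechanism, this is only an illustration using names from the diff:

    from transformers.models.llama.modeling_llama import LlamaRMSNorm

    # get_triton_rmsnorm_forward is defined earlier in this file; it returns
    # None when the Triton kernel is unavailable.
    infer_forward = get_triton_rmsnorm_forward()
    if infer_forward is not None:
        # Every LlamaRMSNorm instance now routes through the Triton kernel.
        LlamaRMSNorm.forward = infer_forward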

7 changes: 0 additions & 7 deletions colossalai/kernel/__init__.py
@@ -1,14 +1,7 @@
from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
-from .triton import llama_context_attn_fwd, bloom_context_attn_fwd
-from .triton import softmax
-from .triton import copy_kv_cache_to_dest

__all__ = [
    "LayerNorm",
    "FusedScaleMaskSoftmax",
    "MultiHeadAttention",
-    "llama_context_attn_fwd",
-    "bloom_context_attn_fwd",
-    "softmax",
-    "copy_kv_cache_to_dest",
]
7 changes: 7 additions & 0 deletions colossalai/kernel/triton/__init__.py
@@ -2,4 +2,11 @@
from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm
from .rms_norm import rmsnorm_forward
+from .rotary_embedding_kernel import rotary_embedding_fwd
from .softmax import softmax
+from .token_attention_kernel import token_attention_fwd
+
+__all__ = [
+    "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward",
+    "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd"
+]
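With these re-exports and the new __all__, colossalai.kernel.triton becomes the single public surface for the inference kernels, matching the removal of the per-kernel re-exports from colossalai/kernel/__init__.py above. A hedged usage sketch, assuming triton and the kernels import cleanly:

    from colossalai.kernel import triton

    # The package now advertises all eight kernels in one place.
    print(sorted(triton.__all__))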