From 11f74bf9e814ed349ea0916b6725b6cdc44ee046 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 31 Aug 2023 14:14:05 +0800 Subject: [PATCH 1/7] add installation req --- colossalai/inference/README.md | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index 5eb89447abc0..eafeb7c57206 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -55,18 +55,23 @@ dependencies ```bash pytorch= 1.13.1 (gpu) +cuda>= 11.6 transformers= 4.30.2 triton==2.0.0.dev20221202 -vllm= -flash-attention= +# to install vllm, please use this branch: https://github.com/tiandiao123/vllm/tree/setup_branch +vllm +# to install flash-attention, please use commit hash: 67ae6fd74b4bc99c36b2ce524cf139c35663793c +flash-attention ``` ### Docker -You can use our official docker container as well. +You can use docker run to use docker container to set-up environment + +``` +docker pull colossal-inference:v2 +docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace colosssal-inference:v2 /bin/bash -```bash -docker.. ``` ### Dive into fast-inference! 
From b0d55b1e608e45c3b54de9e7f2f3a7b0cfd84fd4 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 31 Aug 2023 14:23:13 +0800 Subject: [PATCH 2/7] fix --- colossalai/inference/README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index eafeb7c57206..11fe89cfe514 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -69,8 +69,10 @@ flash-attention You can use docker run to use docker container to set-up environment ``` -docker pull colossal-inference:v2 -docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace colosssal-inference:v2 /bin/bash +# env: python==3.8, cuda 11.6, triton==2.0.0, vllm kernels support, flash-attention-2 kernels support +docker pull hpcaitech/colossalai-inference:v2 +docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash + ``` From 2704a27cdd45e9ca0840cf08dc98d5c2f7c280b0 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 31 Aug 2023 14:25:31 +0800 Subject: [PATCH 3/7] slight change --- colossalai/inference/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index 11fe89cfe514..6e0abcd24388 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -69,7 +69,7 @@ flash-attention You can use docker run to use docker container to set-up environment ``` -# env: python==3.8, cuda 11.6, triton==2.0.0, vllm kernels support, flash-attention-2 kernels support +# env: python==3.8, cuda 11.6, pytorch == 1.13.1, triton==2.0.0.dev20221202, vllm kernels support, flash-attention-2 kernels support docker pull hpcaitech/colossalai-inference:v2 docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash From 88c2b0faa52df6aae7b5b79a7bf97332212b28dd Mon Sep 17 00:00:00 2001 From: 
"cuiqing.li" Date: Thu, 31 Aug 2023 14:37:44 +0800 Subject: [PATCH 4/7] remove empty --- colossalai/inference/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index 6e0abcd24388..7228c51aa484 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -73,7 +73,6 @@ You can use docker run to use docker container to set-up environment docker pull hpcaitech/colossalai-inference:v2 docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash - ``` ### Dive into fast-inference! From ffea979624660e28f5264bf8bafd8c0bd4a23b0b Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 31 Aug 2023 14:56:33 +0800 Subject: [PATCH 5/7] add rmsnorm policy --- .../tensor_parallel/modeling/llama.py | 51 +++++++++++-------- .../tensor_parallel/policies/llama.py | 9 +++- colossalai/shardformer/modeling/llama.py | 25 +-------- 3 files changed, 38 insertions(+), 47 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index adb2ad8a0170..74497dfffdbe 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -3,7 +3,13 @@ import numpy as np import torch from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, apply_rotary_pos_emb +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaModel, + apply_rotary_pos_emb, + LlamaRMSNorm +) from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton.context_attention import llama_context_attn_fwd @@ -11,7 +17,8 @@ from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd try: - from vllm import pos_encoding_ops 
+ from vllm import pos_encoding_ops, layernorm_ops + rms_norm = layernorm_ops.rms_norm rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox HAS_VLLM_KERNERL = True except: @@ -45,14 +52,6 @@ def llama_model_forward( batch_size = input_ids.shape[0] # input_ids.shape[0] - # infer_state = BatchInferState(batch_size, input_ids.shape[1]) - # infer_state.batch_size = batch_size - # # NOTE: dummy implementation here for testing, just assume all inputs same length - # infer_state.block_loc = self.block_loc - # infer_state.start_loc = self.start_loc - # infer_state.seq_len = self.seq_len - # infer_state.max_len_in_batch = self.max_len_in_batch - infer_state = self.infer_state b_seq_len_numpy = infer_state.seq_len.cpu().numpy() position_ids = torch.from_numpy( @@ -276,10 +275,6 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.decode_mem_index, infer_state.cache_manager) - # this is worse than destcopy - # torch.Tensor.copy_(infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],key_states) - # torch.Tensor.copy_(infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],value_states) - # FIXME might want to revise # need some way to record the length of past key values cache # since we won't return past_key_value_cache right now @@ -291,14 +286,6 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, if infer_state.is_context_stage: # first token generation - # attn_output, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(query_states, - # key_states, - # value_states, - # 0, - # 1/math.sqrt(self.head_dim), - # causal, - # False) - attn_output = torch.empty_like(query_states) # calcu_shape for context_attention_fwd @@ -325,3 +312,23 @@ def 
_copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, # return past_key_value as None return attn_output, None, None + +def get_llama_vllm_rmsnorm_forward(): + + if HAS_VLLM_KERNERL: + + def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + x = hidden_states + out = torch.empty_like(x) + rms_norm( + out, + x, + self.weight.data, + self.variance_epsilon, + ) + + return out + + return _vllm_rmsnorm_forward + else: + return None \ No newline at end of file diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 997f5fe48a54..952b5e4e8454 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -11,7 +11,7 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel + from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm policy = super().module_policy() self.shard_config._infer() @@ -36,5 +36,12 @@ def module_policy(self): self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaAttention) + + infer_forward = LlamaInferenceForwards.get_llama_vllm_rmsnorm_forward + if infer_forward is not None: + method_replacement = {'forward': partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=LlamaRMSNorm) return policy diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 2224539d273e..08220eb73427 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -12,7 +12,6 @@ LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, - LlamaRMSNorm, apply_rotary_pos_emb, ) from transformers.utils import 
logging @@ -21,10 +20,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager try: - from vllm import layernorm_ops, pos_encoding_ops - rms_norm = layernorm_ops.rms_norm + from vllm import pos_encoding_ops rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox - rms_norm = layernorm_ops.rms_norm HAS_VLLM_KERNERL = True except: print("fall back to original rotary_embedding_neox of huggingface") @@ -477,23 +474,3 @@ def forward( return forward - -def get_llama_vllm_rmsnorm_forward(): - - if HAS_VLLM_KERNERL: - - def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - x = hidden_states - out = torch.empty_like(x) - rms_norm( - out, - x, - self.weight.data, - self.variance_epsilon, - ) - - return out - - return _vllm_rmsnorm_forward - else: - return None From 63c9e49c096993d6670652a977889d40998e4783 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 31 Aug 2023 15:02:00 +0800 Subject: [PATCH 6/7] add --- colossalai/inference/tensor_parallel/policies/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 952b5e4e8454..76ed1117c0f4 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -3,6 +3,7 @@ from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy from ..modeling.llama import LlamaInferenceForwards +from ..modeling.llama import get_llama_vllm_rmsnorm_forward class LlamaModelInferPolicy(LlamaForCausalLMPolicy): @@ -37,7 +38,7 @@ def module_policy(self): policy=policy, target_key=LlamaAttention) - infer_forward = LlamaInferenceForwards.get_llama_vllm_rmsnorm_forward + infer_forward = get_llama_vllm_rmsnorm_forward() if infer_forward is not None: method_replacement = {'forward': partial(infer_forward)} self.append_or_create_method_replacement(description=method_replacement, From 
06b3d152a7729f2dc9c6686a167bdf8ff8e3af19 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 31 Aug 2023 15:38:47 +0800 Subject: [PATCH 7/7] clean codes --- .../inference/tensor_parallel/modeling/llama.py | 3 +-- .../inference/tensor_parallel/policies/llama.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 74497dfffdbe..7c77785b24e8 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -314,9 +314,8 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, return attn_output, None, None def get_llama_vllm_rmsnorm_forward(): - + if HAS_VLLM_KERNERL: - def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): x = hidden_states out = torch.empty_like(x) diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 76ed1117c0f4..c569a0e3163a 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -1,4 +1,5 @@ from functools import partial +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy @@ -12,7 +13,6 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm policy = super().module_policy() self.shard_config._infer() @@ -38,11 +38,12 @@ def module_policy(self): policy=policy, target_key=LlamaAttention) - infer_forward = get_llama_vllm_rmsnorm_forward() - if infer_forward is not None: - method_replacement = {'forward': partial(infer_forward)} - 
self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=LlamaRMSNorm) + # TODO(@tiandiao123): enabling the vllm rms_norm replacement caused a precision issue; re-enable once fixed + # infer_forward = get_llama_vllm_rmsnorm_forward() + # if infer_forward is not None: + # method_replacement = {'forward': partial(infer_forward)} + # self.append_or_create_method_replacement(description=method_replacement, + # policy=policy, + # target_key=LlamaRMSNorm) return policy