From 6442da8be95f077d0d1b0aa843412a696801962a Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Thu, 31 Aug 2023 11:39:56 +0800
Subject: [PATCH] fix docstring in llama modeling

---
 colossalai/inference/tensor_parallel/modeling/llama.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index d55634a6f00b..ce099c61bda7 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -5,7 +5,7 @@
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
 )
-from transformers.models.llama.modeling_llama import LlamaModel, LlamaAttention
+from transformers.models.llama.modeling_llama import LlamaModel, LlamaDecoderLayer, LlamaAttention
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
 from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
 from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
@@ -16,6 +16,7 @@ class LlamaInferenceForwards:
     """
     This class holds forwards for llama inference.
+    We intend to replace the forward methods of LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM.
     """
 
     @staticmethod
@@ -168,7 +169,7 @@ def llama_model_forward(
 
     @staticmethod
     def llama_decoder_layer_forward(
-        self,
+        self: LlamaDecoderLayer,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
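
For context, the docstring added above says these staticmethod forwards are meant to replace the corresponding HuggingFace forward methods. Below is a minimal illustrative sketch of one way such a replacement could be wired up; it is not part of this patch. The helper name apply_llama_inference_forwards and the attention method name llama_attention_forward are hypothetical (only llama_model_forward and llama_decoder_layer_forward appear in this diff), and ColossalAI may instead install these forwards through its shard policy machinery.

    # Illustrative sketch only -- not part of this patch.
    from transformers.models.llama.modeling_llama import (
        LlamaAttention,
        LlamaDecoderLayer,
        LlamaModel,
    )

    from colossalai.inference.tensor_parallel.modeling.llama import LlamaInferenceForwards


    def apply_llama_inference_forwards() -> None:
        # Assigning a plain function (what a staticmethod yields when accessed
        # from the class) to the class attribute `forward` makes it an ordinary
        # instance method, so `self` receives the module as usual.
        LlamaModel.forward = LlamaInferenceForwards.llama_model_forward
        LlamaDecoderLayer.forward = LlamaInferenceForwards.llama_decoder_layer_forward
        # Hypothetical method name; the attention forward is not shown in this diff.
        LlamaAttention.forward = LlamaInferenceForwards.llama_attention_forward

Note how the `self: LlamaDecoderLayer` annotation introduced by this patch fits that pattern: once rebound, the staticmethod behaves as an instance method of the annotated class.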