diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index d55634a6f00b..ce099c61bda7 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -5,7 +5,7 @@ from transformers.modeling_outputs import ( BaseModelOutputWithPast, ) -from transformers.models.llama.modeling_llama import LlamaModel, LlamaAttention +from transformers.models.llama.modeling_llama import LlamaModel, LlamaDecoderLayer, LlamaAttention from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest from colossalai.kernel.triton.context_attention import llama_context_attn_fwd @@ -16,6 +16,7 @@ class LlamaInferenceForwards: """ This class holds forwards for llama inference. + We intend to replace the forward methods for LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM. """ @staticmethod @@ -168,7 +169,7 @@ def llama_model_forward( @staticmethod def llama_decoder_layer_forward( - self, + self: LlamaDecoderLayer, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None,