colossalai/inference/tensor_parallel/modeling/llama.py (5 changes: 3 additions & 2 deletions)

@@ -5,7 +5,7 @@
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
 )
-from transformers.models.llama.modeling_llama import LlamaModel, LlamaAttention
+from transformers.models.llama.modeling_llama import LlamaModel, LlamaDecoderLayer, LlamaAttention
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
 from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
 from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
@@ -16,6 +16,7 @@
 class LlamaInferenceForwards:
     """
     This class holds forwards for llama inference.
+    We intend to replace the forward methods of LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM.
     """

     @staticmethod
@@ -168,7 +169,7 @@ def llama_model_forward(

     @staticmethod
     def llama_decoder_layer_forward(
-        self,
+        self: LlamaDecoderLayer,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
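For context, the docstring added above describes the approach: the static forwards on LlamaInferenceForwards are substituted for the stock forward methods of the corresponding Hugging Face modules, which is also why `self` is now annotated as `LlamaDecoderLayer`. Below is a minimal sketch of how such a substitution could be wired up. The `replace_llama_forwards` helper is an illustrative assumption, not ColossalAI's actual registration mechanism, and the import path for LlamaInferenceForwards is inferred from the file shown in this diff.

```python
from types import MethodType

from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

# Assumed import path, inferred from the file changed in this diff.
from colossalai.inference.tensor_parallel.modeling.llama import LlamaInferenceForwards


def replace_llama_forwards(model):
    """Hypothetical helper: rebind the inference forwards onto a loaded model.

    Since the forwards are plain @staticmethod functions, MethodType binds
    each one to a module instance so that ``self`` becomes that module,
    matching the ``self: LlamaDecoderLayer`` annotation in the diff.
    """
    for module in model.modules():
        if isinstance(module, LlamaModel):
            module.forward = MethodType(LlamaInferenceForwards.llama_model_forward, module)
        elif isinstance(module, LlamaDecoderLayer):
            module.forward = MethodType(LlamaInferenceForwards.llama_decoder_layer_forward, module)
        # LlamaAttention would be rebound the same way with its replacement
        # forward, which this diff hunk does not show.
```

Annotating `self` with the concrete module type is a cheap way to keep IDE and type-checker support once these functions live outside any class hierarchy and are only bound to module instances at runtime.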