From cb5e01d47d1615bc802f4abcfa2bf5ae8eba2547 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 19 May 2025 14:28:20 +0800 Subject: [PATCH 01/16] upgrade mixtral --- colossalai/shardformer/modeling/mixtral.py | 126 +++++++++++++----- colossalai/shardformer/policies/mixtral.py | 22 +-- .../test_model/test_shard_mixtral.py | 2 +- 3 files changed, 103 insertions(+), 47 deletions(-) diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index a88db87bc601..2bc1e5c7c784 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -1,6 +1,6 @@ import inspect import warnings -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Callable import torch import torch.distributed as dist @@ -14,6 +14,7 @@ ) from transformers.models.mixtral.modeling_mixtral import ( MixtralSparseMoeBlock, + MixtralModel, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, apply_rotary_pos_emb, @@ -215,7 +216,7 @@ class MixtralPipelineForwards: @staticmethod def mixtral_model_forward( - self, + self: MixtralModel, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -225,6 +226,7 @@ def mixtral_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, @@ -340,11 +342,18 @@ def mixtral_model_forward( ) use_cache = False + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None next_decoder_cache = None + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) start_idx, end_idx = stage_index[0], stage_index[1] for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): @@ -370,6 +379,9 @@ def custom_forward(*inputs): None, output_attentions, output_router_logits, + use_cache, + cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -380,6 +392,8 @@ def custom_forward(*inputs): output_attentions, output_router_logits, use_cache, + cache_position, + position_embeddings, ) hidden_states = layer_outputs[0] @@ -559,14 +573,18 @@ def mixtral_for_causal_lm_forward( def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): logger = logging.get_logger(__name__) + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from transformers.models.mixtral.modeling_mixtral import eager_attention_forward def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: @@ -615,9 
+633,10 @@ def forward( # Because the input can be padded, the absolute sequence length depends on the max position id. rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + cos, sin = position_embeddings + # cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) use_sliding_windows = ( _flash_supports_window_size @@ -631,31 +650,31 @@ def forward( ) if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + # if ( + # getattr(self.config, "sliding_window", None) is not None + # and kv_seq_len > self.config.sliding_window + # and cache_has_contents + # ): + # slicing_tokens = 1 - self.config.sliding_window + + # past_key = past_key_value[self.layer_idx][0] + # past_value = past_key_value[self.layer_idx][1] + + # past_key = past_key[:, :, slicing_tokens:, :].contiguous() + # past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + # if past_key.shape[-2] != self.config.sliding_window - 1: + # raise ValueError( + # f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + # f" {past_key.shape}" + # ) + + # if attention_mask is not None: + # attention_mask = attention_mask[:, slicing_tokens:] + # attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -689,14 +708,36 @@ def forward( query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = self._flash_attention_forward( + # attn_output = self._flash_attention_forward( + # query_states, + # key_states, + # value_states, + # attention_mask, + # q_len, + # dropout=dropout_rate, + # use_sliding_windows=use_sliding_windows, + # ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + 
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=dropout_rate, - use_sliding_windows=use_sliding_windows, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama + **kwargs, ) # sp: all-to-all comminucation when introducing sequence parallel @@ -712,7 +753,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights return forward @@ -731,6 +772,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, MoeModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -788,7 +830,7 @@ def forward( " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) - if self._attn_implementation == "flash_attention_2": + if self.config._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self._attn_implementation == "sdpa" and not output_attentions: @@ -820,6 +862,16 @@ def forward( ) hidden_states = inputs_embeds + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -840,6 +892,8 @@ def forward( output_attentions, output_router_logits, use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, ) else: layer_outputs = decoder_layer( @@ -850,6 +904,8 @@ def forward( output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index fab437c01d51..c35b77869aa4 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -43,18 +43,18 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: from transformers.models.mixtral.modeling_mixtral import ( MixtralAttention, MixtralDecoderLayer, - MixtralFlashAttention2, + # MixtralFlashAttention2, MixtralModel, - MixtralSdpaAttention, + # MixtralSdpaAttention, ) - ATTN_IMPLEMENTATION = { - "eager": MixtralAttention, - "flash_attention_2": 
MixtralFlashAttention2, - "sdpa": MixtralSdpaAttention, - } + # ATTN_IMPLEMENTATION = { + # "eager": MixtralAttention, + # "flash_attention_2": MixtralFlashAttention2, + # "sdpa": MixtralSdpaAttention, + # } policy = {} - attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + # attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None @@ -76,7 +76,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: num_kv_heads //= sp_size decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads - policy[attn_cls] = ModulePolicyDescription( + policy[MixtralAttention] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) if self.shard_config.enable_sequence_parallelism: @@ -89,7 +89,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "forward": get_mixtral_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), }, policy=policy, - target_key=attn_cls, + target_key=MixtralAttention, ) self.append_or_create_method_replacement( description={ @@ -330,7 +330,7 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - + held_layers.append(module.rotary_emb) if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None layers_per_stage = stage_manager.distribute_layers(len(module.layers)) diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py index b691130720f7..0d7feff01a1c 100644 --- a/tests/test_shardformer/test_model/test_shard_mixtral.py +++ b/tests/test_shardformer/test_model/test_shard_mixtral.py @@ -169,7 +169,7 @@ def run_mixtral_commom(config: Tuple[int, ...]): (1, 1, 4, 1, 1), (1, 1, 1, 4, 1), (1, 2, 1, 1, 2), - # zero 2 + # # zero 2 (2, 4, 1, 1, 1), (2, 1, 4, 1, 1), (2, 1, 1, 4, 1), From f0a6133f5f70b9276861b2cec2b2fd19b63311ae Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 19 May 2025 14:33:04 +0800 Subject: [PATCH 02/16] fix --- colossalai/shardformer/modeling/mixtral.py | 35 ------------------- colossalai/shardformer/policies/mixtral.py | 8 ----- .../test_model/test_shard_mixtral.py | 2 +- 3 files changed, 1 insertion(+), 44 deletions(-) diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index 2bc1e5c7c784..a12808d8a07d 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -634,7 +634,6 @@ def forward( # Because the input can be padded, the absolute sequence length depends on the max position id. rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 cos, sin = position_embeddings - # cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -649,31 +648,6 @@ def forward( " make sure to upgrade flash-attn library." 
) if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - # cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - # if ( - # getattr(self.config, "sliding_window", None) is not None - # and kv_seq_len > self.config.sliding_window - # and cache_has_contents - # ): - # slicing_tokens = 1 - self.config.sliding_window - - # past_key = past_key_value[self.layer_idx][0] - # past_value = past_key_value[self.layer_idx][1] - - # past_key = past_key[:, :, slicing_tokens:, :].contiguous() - # past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - # if past_key.shape[-2] != self.config.sliding_window - 1: - # raise ValueError( - # f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - # f" {past_key.shape}" - # ) - - # if attention_mask is not None: - # attention_mask = attention_mask[:, slicing_tokens:] - # attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -708,15 +682,6 @@ def forward( query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - # attn_output = self._flash_attention_forward( - # query_states, - # key_states, - # value_states, - # attention_mask, - # q_len, - # dropout=dropout_rate, - # use_sliding_windows=use_sliding_windows, - # ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index c35b77869aa4..ef326c031a29 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -43,18 +43,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: from transformers.models.mixtral.modeling_mixtral import ( MixtralAttention, MixtralDecoderLayer, - # MixtralFlashAttention2, MixtralModel, - # MixtralSdpaAttention, ) - # ATTN_IMPLEMENTATION = { - # "eager": MixtralAttention, - # "flash_attention_2": MixtralFlashAttention2, - # "sdpa": MixtralSdpaAttention, - # } policy = {} - # attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py index 0d7feff01a1c..b691130720f7 100644 --- a/tests/test_shardformer/test_model/test_shard_mixtral.py +++ b/tests/test_shardformer/test_model/test_shard_mixtral.py @@ -169,7 +169,7 @@ def run_mixtral_commom(config: Tuple[int, ...]): (1, 1, 4, 1, 1), (1, 1, 1, 4, 1), (1, 2, 1, 1, 2), - # # zero 2 + # zero 2 (2, 4, 1, 1, 1), (2, 1, 4, 1, 1), (2, 1, 1, 4, 1), From 5c812bd20620cbd441e95cbbe197abfb8e15a4ba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 May 2025 06:39:07 +0000 Subject: [PATCH 03/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/modeling/mixtral.py | 6 +++--- colossalai/shardformer/policies/mixtral.py | 6 +----- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git 
a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index a12808d8a07d..8c182fcf4081 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -1,6 +1,6 @@ import inspect import warnings -from typing import List, Optional, Tuple, Union, Callable +from typing import Callable, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -13,8 +13,8 @@ _prepare_4d_causal_attention_mask_for_sdpa, ) from transformers.models.mixtral.modeling_mixtral import ( - MixtralSparseMoeBlock, MixtralModel, + MixtralSparseMoeBlock, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, apply_rotary_pos_emb, @@ -654,7 +654,7 @@ def forward( # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout + 0.0 if not self.training else self.attention_dropout # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index ef326c031a29..a9584db9bc52 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -40,11 +40,7 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.mixtral.modeling_mixtral import ( - MixtralAttention, - MixtralDecoderLayer, - MixtralModel, - ) + from transformers.models.mixtral.modeling_mixtral import MixtralAttention, MixtralDecoderLayer, MixtralModel policy = {} From a5ec48bb410c7f94a06e5c28ec40b2b4ff12476c Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 19 May 2025 14:58:05 +0800 Subject: [PATCH 04/16] fix --- colossalai/shardformer/modeling/mixtral.py | 6 ------ colossalai/shardformer/policies/mixtral.py | 8 ++++++++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index 8c182fcf4081..1766d0b0dd3e 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -632,16 +632,10 @@ def forward( kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # Because the input can be padded, the absolute sequence length depends on the max position id. 
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - use_sliding_windows = ( - _flash_supports_window_size - and getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - ) if not _flash_supports_window_size: logger.warning_once( "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index a9584db9bc52..48c9d4f71871 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -40,7 +40,15 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: +<<<<<<< Updated upstream from transformers.models.mixtral.modeling_mixtral import MixtralAttention, MixtralDecoderLayer, MixtralModel +======= + from transformers.models.mixtral.modeling_mixtral import ( + MixtralAttention, + MixtralDecoderLayer, + MixtralModel, + ) +>>>>>>> Stashed changes policy = {} From c7031a20b061572c32fda810f1416da2bd047fb2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 May 2025 06:59:06 +0000 Subject: [PATCH 05/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/policies/mixtral.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 48c9d4f71871..7895c74007e4 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -43,11 +43,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: <<<<<<< Updated upstream from transformers.models.mixtral.modeling_mixtral import MixtralAttention, MixtralDecoderLayer, MixtralModel ======= - from transformers.models.mixtral.modeling_mixtral import ( - MixtralAttention, - MixtralDecoderLayer, - MixtralModel, - ) + from transformers.models.mixtral.modeling_mixtral import MixtralAttention, MixtralDecoderLayer, MixtralModel >>>>>>> Stashed changes policy = {} From ec1d6392cfb0483295a084dc5e5173a3ee23b122 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 19 May 2025 16:37:25 +0800 Subject: [PATCH 06/16] upgrade infer --- .../inference/modeling/models/glide_llama.py | 26 ++++++++++++------- .../modeling/models/nopadding_llama.py | 6 ++--- colossalai/shardformer/policies/mixtral.py | 4 ++- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/colossalai/inference/modeling/models/glide_llama.py b/colossalai/inference/modeling/models/glide_llama.py index 0ee78a303004..621984285010 100644 --- a/colossalai/inference/modeling/models/glide_llama.py +++ b/colossalai/inference/modeling/models/glide_llama.py @@ -12,9 +12,9 @@ LlamaAttention, LlamaConfig, LlamaDecoderLayer, - LlamaDynamicNTKScalingRotaryEmbedding, + # LlamaDynamicNTKScalingRotaryEmbedding, LlamaForCausalLM, - LlamaLinearScalingRotaryEmbedding, + # LlamaLinearScalingRotaryEmbedding, LlamaMLP, LlamaModel, LlamaRMSNorm, @@ -173,6 +173,8 @@ def glide_llama_model_forward( position_ids = cache_position.unsqueeze(0) attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + 
position_ids = position_ids + glide_input.n_spec_tokens + position_embeddings = self.rotary_emb(hidden_states, position_ids) # embed positions hidden_states = inputs_embeds @@ -189,9 +191,9 @@ def glide_llama_model_forward( # GlideLlamaDecoderLayer layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, glide_input=glide_input, attention_mask=attention_mask, - position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, @@ -267,7 +269,7 @@ def __init__(self, config: GlideLlamaConfig): self.q_proj = nn.Linear(self.hidden_size, self.large_num_heads * self.large_head_dim, bias=False) self.o_proj = nn.Linear(self.large_num_heads * self.large_head_dim, self.hidden_size, bias=False) - self._init_rope() + # self._init_rope() def _init_rope(self): if self.config.rope_scaling is None: @@ -299,9 +301,10 @@ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, glide_input: GlideInput = None, # Used for glimpsing main model's KV caches attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Optional[torch.Tensor]: @@ -319,8 +322,9 @@ def forward( query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2) # for RoPE - position_ids = position_ids + glide_input.n_spec_tokens - cos, sin = self.rotary_emb(query_states, position_ids) + # position_ids = position_ids + glide_input.n_spec_tokens + # cos, sin = self.rotary_emb(query_states, position_ids) + cos, sin = position_embeddings query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids) query_states = query_states.transpose(1, 2) query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim) @@ -367,9 +371,10 @@ def from_native_module(module: LlamaDecoderLayer, *args, **kwargs) -> "GlideLlam def forward( self, hidden_states: torch.Tensor, + position_embeddings: torch.Tensor = None, + position_ids: Optional[torch.LongTensor] = None, glide_input: GlideInput = None, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -401,8 +406,8 @@ def forward( # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, - position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -425,9 +430,10 @@ def forward( hidden_states = self.cross_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, + position_ids = position_ids, glide_input=glide_input, attention_mask=attention_mask, - position_ids=position_ids, output_attentions=output_attentions, use_cache=True, ) diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index c7c7473acf2c..6c040dd22dc2 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -478,9 +478,9 @@ def from_native_module( attn_oproj=attn_oproj, process_group=process_group, 
model_shard_infer_config=model_shard_infer_config, - num_heads=module.num_heads, - hidden_size=module.hidden_size, - num_key_value_heads=module.num_key_value_heads, + num_heads=module.config.num_attention_heads, + hidden_size=module.config.hidden_size, + num_key_value_heads=module.config.num_key_value_heads, ) return attn_layer diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 7895c74007e4..206d3ba15408 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -40,10 +40,12 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: -<<<<<<< Updated upstream from transformers.models.mixtral.modeling_mixtral import MixtralAttention, MixtralDecoderLayer, MixtralModel +<<<<<<< Updated upstream ======= from transformers.models.mixtral.modeling_mixtral import MixtralAttention, MixtralDecoderLayer, MixtralModel +>>>>>>> Stashed changes +======= >>>>>>> Stashed changes policy = {} From 20c29d5a2be3032181d42a7853e69072aea4598c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 May 2025 08:38:25 +0000 Subject: [PATCH 07/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/inference/modeling/models/glide_llama.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/colossalai/inference/modeling/models/glide_llama.py b/colossalai/inference/modeling/models/glide_llama.py index 621984285010..2fe87aa0caa5 100644 --- a/colossalai/inference/modeling/models/glide_llama.py +++ b/colossalai/inference/modeling/models/glide_llama.py @@ -12,9 +12,7 @@ LlamaAttention, LlamaConfig, LlamaDecoderLayer, - # LlamaDynamicNTKScalingRotaryEmbedding, LlamaForCausalLM, - # LlamaLinearScalingRotaryEmbedding, LlamaMLP, LlamaModel, LlamaRMSNorm, @@ -431,7 +429,7 @@ def forward( hidden_states = self.cross_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, - position_ids = position_ids, + position_ids=position_ids, glide_input=glide_input, attention_mask=attention_mask, output_attentions=output_attentions, From 4ff60cf42858f41b7f4dfa136caa76a677a2c69b Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 19 May 2025 16:54:00 +0800 Subject: [PATCH 08/16] fix --- .../inference/modeling/models/glide_llama.py | 27 ------------------- colossalai/shardformer/policies/mixtral.py | 3 +++ 2 files changed, 3 insertions(+), 27 deletions(-) diff --git a/colossalai/inference/modeling/models/glide_llama.py b/colossalai/inference/modeling/models/glide_llama.py index 2fe87aa0caa5..816c13ab3add 100644 --- a/colossalai/inference/modeling/models/glide_llama.py +++ b/colossalai/inference/modeling/models/glide_llama.py @@ -267,31 +267,6 @@ def __init__(self, config: GlideLlamaConfig): self.q_proj = nn.Linear(self.hidden_size, self.large_num_heads * self.large_head_dim, bias=False) self.o_proj = nn.Linear(self.large_num_heads * self.large_head_dim, self.hidden_size, bias=False) - # self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding( - self.large_head_dim, - max_position_embeddings=self.max_position_embeddings, - ) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - 
self.large_head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - ) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.large_head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -320,8 +295,6 @@ def forward( query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2) # for RoPE - # position_ids = position_ids + glide_input.n_spec_tokens - # cos, sin = self.rotary_emb(query_states, position_ids) cos, sin = position_embeddings query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids) query_states = query_states.transpose(1, 2) diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 206d3ba15408..b9a9de973f48 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -42,10 +42,13 @@ def preprocess(self): def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: from transformers.models.mixtral.modeling_mixtral import MixtralAttention, MixtralDecoderLayer, MixtralModel <<<<<<< Updated upstream +<<<<<<< Updated upstream ======= from transformers.models.mixtral.modeling_mixtral import MixtralAttention, MixtralDecoderLayer, MixtralModel >>>>>>> Stashed changes ======= +>>>>>>> Stashed changes +======= >>>>>>> Stashed changes policy = {} From 5f922cce2a792efca56b5b0b0553e5940cd03335 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 May 2025 08:55:48 +0000 Subject: [PATCH 09/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/inference/modeling/models/glide_llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/inference/modeling/models/glide_llama.py b/colossalai/inference/modeling/models/glide_llama.py index 816c13ab3add..f1b4e85bb4bc 100644 --- a/colossalai/inference/modeling/models/glide_llama.py +++ b/colossalai/inference/modeling/models/glide_llama.py @@ -16,7 +16,6 @@ LlamaMLP, LlamaModel, LlamaRMSNorm, - LlamaRotaryEmbedding, ) from colossalai.inference.spec import GlideInput From cada39b56e4b965938ca13956f9490e8c7a3b246 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 19 May 2025 17:02:28 +0800 Subject: [PATCH 10/16] fix --- colossalai/shardformer/policies/mixtral.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index b9a9de973f48..a9584db9bc52 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -41,15 +41,6 @@ def preprocess(self): def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: from transformers.models.mixtral.modeling_mixtral import MixtralAttention, MixtralDecoderLayer, MixtralModel -<<<<<<< Updated upstream -<<<<<<< Updated upstream -======= - from transformers.models.mixtral.modeling_mixtral import MixtralAttention, MixtralDecoderLayer, MixtralModel ->>>>>>> Stashed changes -======= ->>>>>>> Stashed changes -======= ->>>>>>> Stashed changes policy = {} From 
b6a3e2904eb1ca890a8d64cfa9eb43de3ac6d1d5 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 20 May 2025 15:49:42 +0800 Subject: [PATCH 11/16] upgrade drafter --- .../inference/modeling/models/glide_llama.py | 36 +++++++------------ colossalai/inference/spec/drafter.py | 7 ++-- 2 files changed, 17 insertions(+), 26 deletions(-) diff --git a/colossalai/inference/modeling/models/glide_llama.py b/colossalai/inference/modeling/models/glide_llama.py index f1b4e85bb4bc..0a7247be4b31 100644 --- a/colossalai/inference/modeling/models/glide_llama.py +++ b/colossalai/inference/modeling/models/glide_llama.py @@ -153,15 +153,13 @@ def glide_llama_model_forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + print("past_key_values", type(past_key_values)) if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") + print("past_key_values", type(past_key_values)) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -169,17 +167,17 @@ def glide_llama_model_forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) - position_ids = position_ids + glide_input.n_spec_tokens - position_embeddings = self.rotary_emb(hidden_states, position_ids) + attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) + if hasattr(glide_input, "n_spec_tokens"): + position_ids = position_ids + glide_input.n_spec_tokens # embed positions hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -199,8 +197,6 @@ def glide_llama_model_forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -211,16 +207,11 @@ def glide_llama_model_forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache - ) if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -374,7 +365,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, 
self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, @@ -417,9 +408,6 @@ def forward( outputs = (hidden_states,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/colossalai/inference/spec/drafter.py b/colossalai/inference/spec/drafter.py index 3144b2c90c95..13408791d013 100644 --- a/colossalai/inference/spec/drafter.py +++ b/colossalai/inference/spec/drafter.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn from transformers import PreTrainedTokenizer +from transformers.cache_utils import DynamicCache from colossalai.utils import get_current_device @@ -93,9 +94,8 @@ def speculate( for _ in range(n_spec_tokens): # update past key values - kwargs["past_key_values"] = past_key_values - outputs = self._drafter_model(input_ids, **kwargs) + outputs = self._drafter_model(input_ids, past_key_values=past_key_values, **kwargs) next_token_logits = outputs.logits[:, -1, :] # NOTE Only use greedy search for speculating. @@ -110,10 +110,13 @@ def speculate( break input_ids = next_token_ids[:, None] past_key_values = outputs.past_key_values + speculated_length = len(token_ids) # For now, only support bsz 1 logits = torch.concat(logits, dim=0) token_ids = torch.concat(token_ids, dim=-1) + if isinstance(past_key_values, DynamicCache): + past_key_values = past_key_values.to_legacy_cache() out = DrafterOutput( speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values From f0e31e7602edb90e60e324e035aee5821cd3e8fe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 May 2025 07:50:48 +0000 Subject: [PATCH 12/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/inference/modeling/models/glide_llama.py | 7 +++---- colossalai/inference/spec/drafter.py | 1 - 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/colossalai/inference/modeling/models/glide_llama.py b/colossalai/inference/modeling/models/glide_llama.py index 0a7247be4b31..f5686dd83290 100644 --- a/colossalai/inference/modeling/models/glide_llama.py +++ b/colossalai/inference/modeling/models/glide_llama.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.cache_utils import DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.llama.modeling_llama import ( LlamaAttention, @@ -154,8 +154,8 @@ def glide_llama_model_forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = DynamicCache() - print("past_key_values", type(past_key_values)) + past_key_values = DynamicCache() + print("past_key_values", type(past_key_values)) if cache_position is None: print("past_key_values", type(past_key_values)) @@ -197,7 +197,6 @@ def glide_llama_model_forward( hidden_states = layer_outputs[0] - if output_attentions: all_self_attns += (layer_outputs[1],) diff --git a/colossalai/inference/spec/drafter.py b/colossalai/inference/spec/drafter.py index 13408791d013..81d26be5cbba 100644 --- a/colossalai/inference/spec/drafter.py +++ b/colossalai/inference/spec/drafter.py @@ -110,7 +110,6 @@ def speculate( break input_ids = next_token_ids[:, None] past_key_values = outputs.past_key_values - speculated_length = len(token_ids) # For now, only support bsz 1 
logits = torch.concat(logits, dim=0) From b289151ad4a69f5446cd6eb3196fbb4834ed5dc1 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 20 May 2025 15:54:27 +0800 Subject: [PATCH 13/16] fix --- colossalai/inference/modeling/models/glide_llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colossalai/inference/modeling/models/glide_llama.py b/colossalai/inference/modeling/models/glide_llama.py index f5686dd83290..520eccef4bda 100644 --- a/colossalai/inference/modeling/models/glide_llama.py +++ b/colossalai/inference/modeling/models/glide_llama.py @@ -155,10 +155,8 @@ def glide_llama_model_forward( if use_cache and past_key_values is None: past_key_values = DynamicCache() - print("past_key_values", type(past_key_values)) if cache_position is None: - print("past_key_values", type(past_key_values)) past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device From e80303af25e0003f222574203deccf78c2ec0c5a Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 21 May 2025 15:53:20 +0800 Subject: [PATCH 14/16] upgrade lazy --- colossalai/lazy/pretrained.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/lazy/pretrained.py b/colossalai/lazy/pretrained.py index 226951598aa2..69058a78d2b3 100644 --- a/colossalai/lazy/pretrained.py +++ b/colossalai/lazy/pretrained.py @@ -286,7 +286,8 @@ def new_from_pretrained( config.name_or_path = pretrained_model_name_or_path # Instantiate model. - init_contexts = [no_init_weights(_enable=_fast_init)] + # init_contexts = [no_init_weights(_enable=_fast_init)] + init_contexts = [no_init_weights()] with ContextManagers(init_contexts): model = cls(config, *model_args, **model_kwargs) From 900e8e1f9b5dd9ffbe4b8bdfacf86fee8c6ebb01 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 May 2025 07:55:12 +0000 Subject: [PATCH 15/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/lazy/pretrained.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/lazy/pretrained.py b/colossalai/lazy/pretrained.py index 69058a78d2b3..66f4cf3bbc98 100644 --- a/colossalai/lazy/pretrained.py +++ b/colossalai/lazy/pretrained.py @@ -69,7 +69,7 @@ def new_from_pretrained( _ = kwargs.pop("mirror", None) from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) - _fast_init = kwargs.pop("_fast_init", True) + kwargs.pop("_fast_init", True) torch_dtype = kwargs.pop("torch_dtype", None) subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) From 55ba06b12b2745ccc65a43e5cb956c9d6af2c458 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 21 May 2025 16:12:37 +0800 Subject: [PATCH 16/16] upgrade mixtral --- colossalai/shardformer/modeling/mixtral.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index 1766d0b0dd3e..2d094040a555 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -348,11 +348,10 @@ def mixtral_model_forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if 
output_router_logits else None - next_decoder_cache = None if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device ) start_idx, end_idx = stage_index[0], stage_index[1]
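
Taken together, these patches migrate ColossalAI's Mixtral support to the refactored Hugging Face attention API: the per-backend classes (MixtralFlashAttention2, MixtralSdpaAttention) are no longer used, rotary embeddings are computed once per model forward and threaded through as position_embeddings, caches carry an explicit cache_position, and the attention backend is selected at runtime through ALL_ATTENTION_FUNCTIONS with eager_attention_forward as the fallback. Below is a minimal, illustrative sketch of that dispatch pattern, assuming transformers >= 4.48 and an attention module that exposes q_proj/k_proj/v_proj/o_proj, head_dim, scaling, attention_dropout, layer_idx, and config (as MixtralAttention does). It is not the exact forward installed by the policy in these patches; in particular it omits the sequence-parallel all-to-all handling and the sdpa/output_attentions warning branch that the real replacement forward keeps.

# Illustrative sketch only: backend dispatch via the transformers attention registry.
from typing import Callable, Optional, Tuple

import torch
from transformers.cache_utils import Cache
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.mixtral.modeling_mixtral import apply_rotary_pos_emb, eager_attention_forward


def sketch_mixtral_attention_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor] = None,
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
):
    bsz, q_len, _ = hidden_states.shape
    head_dim = self.head_dim

    # Project and reshape to (batch, heads, seq_len, head_dim).
    query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, head_dim).transpose(1, 2)

    # RoPE (cos, sin) is now computed once in the model forward and passed in,
    # instead of each attention layer owning its own rotary_emb module.
    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # cache_position is required by the new Cache.update API for RoPE models.
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # Pick the configured backend (sdpa, flash_attention_2, ...) from the registry;
    # fall back to the eager implementation. repeat_kv for GQA happens inside the backend.
    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        sliding_window=getattr(self.config, "sliding_window", None),
        **kwargs,
    )

    # Backends return (batch, seq_len, heads, head_dim); flatten heads and project out.
    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    # The new API returns two values; past_key_value is mutated in place, not returned.
    return attn_output, attn_weights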