Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions colossalai/shardformer/modeling/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
Expand Down Expand Up @@ -77,7 +80,7 @@ def mistral_model_forward(
else:
position_ids = position_ids.view(-1, seq_length).long()

if attention_mask is not None and self._use_flash_attention_2 and use_cache:
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
Expand All @@ -97,9 +100,18 @@ def mistral_model_forward(
is_causal=True,
)
else:
if self._use_flash_attention_2:
if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
Expand Down Expand Up @@ -462,7 +474,7 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if attention_mask is not None and self._use_flash_attention_2 and use_cache:
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
Expand All @@ -481,9 +493,18 @@ def forward(
is_causal=True,
)
else:
if self._use_flash_attention_2:
if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
Expand Down
2 changes: 2 additions & 0 deletions colossalai/shardformer/policies/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
MistralDecoderLayer,
MistralFlashAttention2,
MistralModel,
MistralSdpaAttention,
)

ATTN_IMPLEMENTATION = {
"eager": MistralAttention,
"flash_attention_2": MistralFlashAttention2,
"sdpa": MistralSdpaAttention,
}

policy = {}
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ ray
sentencepiece
google
protobuf
transformers>=4.36.2,<4.40.0
transformers==4.39.3
peft>=0.7.1
bitsandbytes>=0.39.0
rpyc==6.0.0
Expand Down