diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index a84a3097231a..429c4350c1dc 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -666,6 +666,9 @@ def forward(
         if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

+        # TODO: upgrade transformers to 4.44.0 to fix the bug, remove the hard code.
+        self._use_flash_attention_2 = shard_config.enable_flash_attention
+        self._use_sdpa = False if shard_config.enable_flash_attention else self._use_sdpa
         if self._use_flash_attention_2:
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
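
For context, the patch works around the fact that older transformers releases (pre-4.44.0, per the TODO) freeze `_use_flash_attention_2` and `_use_sdpa` at model construction time, so enabling flash attention through ShardFormer afterward leaves `forward` on the SDPA mask-preparation path. Below is a minimal standalone sketch of that flag-syncing logic; the `AttnFlagDemo` class is hypothetical and only mimics the branch structure of the real model, it is not the DeepSeek implementation:

```python
# Hypothetical demo class, not the actual DeepSeek model: shows why the two
# attention flags must be forced in sync inside forward() when flash
# attention is enabled after the model has been constructed.

class AttnFlagDemo:
    def __init__(self, attn_implementation: str = "sdpa"):
        # In older transformers these flags are derived from the config once,
        # at construction time, and never updated afterward.
        self._use_flash_attention_2 = attn_implementation == "flash_attention_2"
        self._use_sdpa = attn_implementation == "sdpa"

    def forward(self, enable_flash_attention: bool) -> str:
        # Workaround mirrored from the patch: override the stale flags so the
        # mask-preparation branch below matches the kernel actually in use.
        self._use_flash_attention_2 = enable_flash_attention
        self._use_sdpa = False if enable_flash_attention else self._use_sdpa

        if self._use_flash_attention_2:
            return "2d mask path (flash attention)"
        elif self._use_sdpa:
            return "4d mask path (sdpa)"
        return "4d mask path (eager)"


if __name__ == "__main__":
    model = AttnFlagDemo("sdpa")  # built with SDPA flags baked in
    # Without the override, this call would still take the SDPA branch and
    # build a 4d mask that the flash-attention kernel cannot consume.
    print(model.forward(enable_flash_attention=True))  # -> flash path
```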