From 5161429074531a0f7d47277aaf1361a0017748c5 Mon Sep 17 00:00:00 2001
From: haze188
Date: Mon, 12 Aug 2024 09:02:25 +0000
Subject: [PATCH] [misc] Bypass the huggingface bug to solve the mask mismatch
 problem

---
 colossalai/shardformer/modeling/deepseek.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index a84a3097231a..429c4350c1dc 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -666,6 +666,9 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        # TODO: upgrade transformers to 4.44.0 to fix the bug, remove the hard code.
+        self._use_flash_attention_2 = shard_config.enable_flash_attention
+        self._use_sdpa = False if shard_config.enable_flash_attention else self._use_sdpa
         if self._use_flash_attention_2:
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None