From eadcffce08466379ef8b800b4c6608375d4a3ee7 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 6 Jun 2024 16:44:06 +0800 Subject: [PATCH] [shardformer] fix import --- colossalai/shardformer/modeling/llama.py | 6 ++++-- colossalai/shardformer/modeling/qwen2.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 2b30074a5e68..01d10c8dcf95 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -8,6 +8,10 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.cache_utils import Cache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -17,8 +21,6 @@ LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, apply_rotary_pos_emb, repeat_kv, ) diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 8f8ab25a5b3f..e0aa5fba4a01 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -9,13 +9,15 @@ ) try: + from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, + ) from transformers.models.qwen2.modeling_qwen2 import ( Qwen2Attention, Qwen2ForCausalLM, Qwen2ForSequenceClassification, Qwen2Model, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, apply_rotary_pos_emb, repeat_kv, )