diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 2b30074a5e68..01d10c8dcf95 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -8,6 +8,10 @@
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.cache_utils import Cache
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -17,8 +21,6 @@
     LlamaForCausalLM,
     LlamaForSequenceClassification,
     LlamaModel,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
     apply_rotary_pos_emb,
     repeat_kv,
 )
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 8f8ab25a5b3f..e0aa5fba4a01 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -9,13 +9,15 @@
 )
 
 try:
+    from transformers.modeling_attn_mask_utils import (
+        _prepare_4d_causal_attention_mask,
+        _prepare_4d_causal_attention_mask_for_sdpa,
+    )
     from transformers.models.qwen2.modeling_qwen2 import (
         Qwen2Attention,
         Qwen2ForCausalLM,
         Qwen2ForSequenceClassification,
         Qwen2Model,
-        _prepare_4d_causal_attention_mask,
-        _prepare_4d_causal_attention_mask_for_sdpa,
         apply_rotary_pos_emb,
         repeat_kv,
     )