diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index dd3afecd6723..76534b5d5d2e 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -43,7 +43,7 @@ def _get_attention_mask(
             is_causal=True,
         )
     else:
-        attention_mask = self.decoder._prepare_decoder_attention_mask(
+        attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             (batch_size, seq_length),
             hidden_states,
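
This hunk replaces a call to the private `self.decoder._prepare_decoder_attention_mask` method with the free function `_prepare_4d_causal_attention_mask`, presumably tracking the upstream transformers refactor that moved causal-mask construction into `transformers.modeling_attn_mask_utils` (the hunk itself does not show the corresponding import). A minimal sketch of the new call, assuming transformers >= 4.35; the tensor shapes and values below are illustrative, not taken from the diff:

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

# Illustrative sizes; the real values come from hidden_states.shape[:2].
batch_size, seq_length, hidden_size = 2, 8, 16
past_key_values_length = 0

# 2D padding mask: 1 = attend, 0 = masked.
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
hidden_states = torch.randn(batch_size, seq_length, hidden_size)

# Expands the 2D mask into the 4D additive causal mask the decoder expects,
# shaped (batch_size, 1, seq_length, seq_length + past_key_values_length).
mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask,
    (batch_size, seq_length),
    hidden_states,
    past_key_values_length,
)
print(mask_4d.shape)  # torch.Size([2, 1, 8, 8])
```

Since the helper takes the mask, input shape, and embeddings explicitly, it needs no decoder instance, which is why the `self.decoder.` prefix drops out of the call site.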