From 40fff9a787fbf23d1e5be93a6f54578f21a4c222 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 14 May 2025 15:48:59 +0800 Subject: [PATCH 1/2] upgrade opt --- colossalai/shardformer/modeling/opt.py | 39 ++++++++++++++----------- tests/kit/model_zoo/transformers/opt.py | 1 + 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 3ea4db9e2f70..ea53418d7455 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -128,7 +128,7 @@ def opt_model_forward( # required mask seq length can be calculated via length of past mask_seq_length = past_key_values_length + seq_length # embed positions - if self.decoder._use_flash_attention_2: + if self.decoder.config._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None attention_mask = ( @@ -542,6 +542,9 @@ def opt_for_question_answering_forward( def get_opt_flash_attention_forward(shard_config: ShardConfig): from transformers.models.opt.modeling_opt import OPTAttention + def _shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int): + return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous() + def forward( self: OPTAttention, hidden_states: torch.Tensor, @@ -568,30 +571,30 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = _shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) + value_states = _shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + # if self.is_decoder: + # # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # # Further calls to cross_attention layer can then reuse all cross-attention + # # key/value_states (first "if" case) + # # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # # all previous decoder key/value_states. Further calls to uni-directional self-attention + # # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # # if encoder bi-directional self-attention `past_key_value` is always `None` + # past_key_value = (key_states, value_states) - query_states = self._shape(query_states, tgt_len, bsz) + query_states = _shape(query_states, tgt_len, bsz, self.num_heads, self.head_dim) dropout_p = self.dropout if self.training else 0.0 attn_output = ColoAttention.attention( @@ -630,6 +633,8 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( diff --git a/tests/kit/model_zoo/transformers/opt.py b/tests/kit/model_zoo/transformers/opt.py index 2da94a4fcc0f..5ffc227f979e 100644 --- a/tests/kit/model_zoo/transformers/opt.py +++ b/tests/kit/model_zoo/transformers/opt.py @@ -53,6 +53,7 @@ def data_gen_for_question_answering(): num_hidden_layers=2, num_attention_heads=4, dropout=0, + attn_implementation="eager", ) # register the following models From b2da8e92f8214c60904f69246ae25df23a283e7d Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 14 May 2025 15:50:41 +0800 Subject: [PATCH 2/2] fix --- colossalai/shardformer/modeling/opt.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index ea53418d7455..aa39bf40c298 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -584,16 +584,6 @@ def forward( key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) - # if self.is_decoder: - # # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # # Further calls to cross_attention layer can then reuse all cross-attention - # # key/value_states (first "if" case) - # # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # # all previous decoder key/value_states. Further calls to uni-directional self-attention - # # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # # if encoder bi-directional self-attention `past_key_value` is always `None` - # past_key_value = (key_states, value_states) - query_states = _shape(query_states, tgt_len, bsz, self.num_heads, self.head_dim) dropout_p = self.dropout if self.training else 0.0