Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 13 additions & 18 deletions colossalai/shardformer/modeling/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def opt_model_forward(
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
if self.decoder._use_flash_attention_2:
if self.decoder.config._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
Expand Down Expand Up @@ -542,6 +542,9 @@ def opt_for_question_answering_forward(
def get_opt_flash_attention_forward(shard_config: ShardConfig):
from transformers.models.opt.modeling_opt import OPTAttention

def _shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()

def forward(
self: OPTAttention,
hidden_states: torch.Tensor,
Expand All @@ -568,30 +571,20 @@ def forward(
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
key_states = _shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
value_states = _shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)

query_states = self._shape(query_states, tgt_len, bsz)
query_states = _shape(query_states, tgt_len, bsz, self.num_heads, self.head_dim)

dropout_p = self.dropout if self.training else 0.0
attn_output = ColoAttention.attention(
Expand Down Expand Up @@ -630,6 +623,8 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand Down
1 change: 1 addition & 0 deletions tests/kit/model_zoo/transformers/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def data_gen_for_question_answering():
num_hidden_layers=2,
num_attention_heads=4,
dropout=0,
attn_implementation="eager",
)

# register the following models
Expand Down