49 changes: 34 additions & 15 deletions colossalai/shardformer/modeling/t5.py
@@ -43,6 +43,7 @@ def t5_stack_forward(
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = None,
cache_position=None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
position_bias: Optional[torch.Tensor] = None,
@@ -68,15 +69,6 @@ def t5_stack_forward(
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if use_cache is True:
if not in_decoder:
raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False

stage = stage_manager.stage
in_decoder = self.is_decoder
@@ -122,18 +114,24 @@ def t5_stack_forward(
device = hidden_states.device

# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
mask_seq_length = seq_length

# initialize past_key_values with `None` if past does not exist
if past_key_values is None:
past_key_values = [None] * len(self.block)

past_key_values_length = 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
)

if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=device)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
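Aside (not part of the diff): the `cache_position` default introduced above reduces to a plain `torch.arange` over the current sequence, because the pipeline path forces `use_cache=False` and therefore never has a past length to offset by. A minimal sketch, with an assumed toy `seq_length`:

```python
import torch

seq_length = 16                 # toy value; in the real code this comes from input_shape
past_key_values_length = 0      # always 0 here: use_cache is forced off for pipeline models
cache_position = torch.arange(
    past_key_values_length, past_key_values_length + seq_length
)
# tensor([0, 1, ..., 15]) — one absolute position per token of the current chunk
```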
@@ -146,6 +144,21 @@ def t5_stack_forward(
else:
encoder_extended_attention_mask = None

if self.config.is_decoder:
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
None,
output_attentions,
)
elif attention_mask is not None:
causal_mask = attention_mask[:, None, None, :]
causal_mask = causal_mask.to(dtype=hidden_states.dtype)
causal_mask = (1.0 - causal_mask) * torch.finfo(hidden_states.dtype).min
else:
causal_mask = None

# Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
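Aside (not part of the diff): in the hunk above, the decoder path delegates mask construction to `_update_causal_mask`, while the encoder `elif` branch converts a 2-D padding mask into the additive 4-D form that is added to the attention scores. A minimal sketch of that encoder-side conversion (toy mask values, `additive_mask` plays the role of the diff's `causal_mask`):

```python
import torch

def to_additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # [batch, seq_len] -> [batch, 1, 1, seq_len], broadcastable over heads and query positions
    mask = attention_mask[:, None, None, :].to(dtype=dtype)
    # positions marked 1 add 0 to the scores; positions marked 0 add a large negative value
    return (1.0 - mask) * torch.finfo(dtype).min

additive_mask = to_additive_mask(torch.tensor([[1, 1, 1, 0]]), torch.float32)
# additive_mask[0, 0, 0] -> tensor([0., 0., 0., -3.4028e+38])
```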
@@ -158,7 +171,6 @@ def t5_stack_forward(
start_idx, end_idx = stage_index[0], stage_index[1]

for i in range(start_idx, end_idx):
past_key_value = past_key_values[i]
layer_module = self.block[i]
layer_head_mask = head_mask[i]
cross_attn_layer_head_mask = cross_attn_head_mask[i]
@@ -168,7 +180,7 @@ def t5_stack_forward(
layer_outputs = self._gradient_checkpointing_func(
layer_module.forward,
hidden_states,
extended_attention_mask,
causal_mask,
position_bias,
encoder_hidden_states,
encoder_extended_attention_mask,
@@ -178,20 +190,24 @@ def t5_stack_forward(
None, # past_key_value is always None with gradient checkpointing
use_cache,
output_attentions,
return_dict,
cache_position,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask=extended_attention_mask,
attention_mask=causal_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value,
past_key_value=None,
use_cache=use_cache,
output_attentions=output_attentions,
return_dict=return_dict,
cache_position=cache_position,
)

# layer_outputs is a tuple with:
@@ -669,6 +685,7 @@ def forward(
query_length: Optional[int] = None,
use_cache: bool = False,
output_attentions: bool = False,
cache_position=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
@@ -805,6 +822,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
output_attentions: bool = False,
cache_position=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
@@ -815,6 +833,7 @@ def forward(
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
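Aside (not part of the diff): the remaining hunks only thread the new `cache_position` argument through the layer wrapper down to the attention module. The pass-through pattern, sketched with assumed toy modules rather than the real T5 classes:

```python
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def forward(self, hidden_states, cache_position=None):
        # a real implementation would use cache_position to index into the KV cache
        return hidden_states

class ToySelfAttentionLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.SelfAttention = ToyAttention()

    def forward(self, hidden_states, cache_position=None):
        # pass-through, mirroring T5LayerSelfAttention -> T5Attention in the hunks above
        return self.SelfAttention(hidden_states, cache_position=cache_position)

layer = ToySelfAttentionLayer()
out = layer(torch.zeros(1, 4, 8), cache_position=torch.arange(4))
```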